In [75]:
import numpy as np
import librosa as lb
import soundfile as sf
from matplotlib import pyplot as plt
from tqdm import tqdm
import museval

In [47]:
x = np.array([[1, 2, 3, 4],
             [4, 5, 6, 4],
             [1, 7, 3, 4],
             [2, 2, 3, 4],
             [1, 2, 0, 4]]).T

In [66]:
A_act = np.array([[1, 0.2, 0.1, 0.05],
             [0.2, 1, 0.1, 0.05],
             [0.1, 0.2, 1, 0.15],
             [0.4, 0.2, 0.1, 1]])

In [67]:
s = np.dot(A_act, x)

In [68]:
s

array([[1.9, 5.8, 2.9, 2.9, 1.6],
       [2.7, 6.6, 7.7, 2.9, 2.4],
       [4.1, 8. , 5.1, 4.2, 1.1],
       [5.1, 7.2, 6.1, 5.5, 4.8]])

In [31]:
import numpy as np

def vector_factorisation(x, gamma1, gamma2, max_iter=1000, tol=1e-50):
    n, l = x.shape
    s = np.copy(x)
    A = np.eye(n)
    epsilon = 1e-10

    for i in range(max_iter):
        # update A
        for j in range(n):
            A[j, j] = 1  # diagonal elements
            for k in range(n):
                if k != j:
                    num = x[j] @ s[k] - A[k] @ s + A[j, k] * s[j] * s[k]
                    A[j, k] = max(gamma1, min(gamma2, np.squeeze(num / (np.dot(s[j], s[j]) + epsilon))[0]))



        # update s
        s = np.linalg.lstsq(A, x, rcond=None)[0]

        # calculate error
        error = np.linalg.norm(x - A @ s) ** 2

        if i > 0 and abs(error - prev_error) < tol:
            break

        prev_error = error

    return A, s


In [53]:
x = x

gamma1 = 0.05
gamma2 = 0.4
A, s = vector_factorisation(x, gamma1, gamma2)
error = np.linalg.norm(x - np.dot(A, s)) ** 2
x

array([[1, 4, 1, 2, 1],
       [2, 5, 7, 2, 2],
       [3, 6, 3, 3, 0],
       [4, 4, 4, 4, 4]])

In [59]:
np.round(s)

array([[-1.,  1., -2., -0., -0.],
       [ 0.,  3.,  7.,  0.,  1.],
       [ 2.,  4., -0.,  2., -2.],
       [ 3.,  1.,  2.,  3.,  4.]])

In [70]:
A

array([[1.        , 0.4       , 0.4       , 0.4       ],
       [0.05      , 1.        , 0.39734594, 0.4       ],
       [0.05      , 0.4       , 1.        , 0.4       ],
       [0.05      , 0.4       , 0.4       , 1.        ]])

In [61]:
error

3.0847884543329836e-29

In [65]:
np.round(np.dot(A, s))

array([[ 1.,  4.,  1.,  2.,  1.],
       [ 2.,  5.,  7.,  2.,  2.],
       [ 3.,  6.,  3.,  3., -0.],
       [ 4.,  4.,  4.,  4.,  4.]])

In [72]:
np.linalg.norm(A_act - A) ** 2

0.843414610983212

In [76]:
bleed_path = '/home/rajesh/Desktop/Datasets/musdb18hq_bleeded/train/Music Delta - Hendrix/'
bvocals, fs = lb.load(bleed_path+'vocals.wav')
bbass, fs = lb.load(bleed_path+'bass.wav')
bdrums, fs = lb.load(bleed_path+'drums.wav')
bother, fs = lb.load(bleed_path+'other.wav')

In [77]:
clean_path = '/home/rajesh/Desktop/Datasets/musdb18hq/train/Music Delta - Hendrix/'
vocals, fs = lb.load(clean_path+'vocals.wav')
bass, fs = lb.load(clean_path+'bass.wav')
drums, fs = lb.load(clean_path+'drums.wav')
other, fs = lb.load(clean_path+'other.wav')

In [78]:
if len(bbass) > len(bass):
    n = len(bass)
else:
    n = len(bbass)

In [79]:
R = np.array([bvocals[:n], bbass[:n], bdrums[:n], bother[:n]]) #two min
R.shape

(4, 437559)

In [80]:
S = np.array([vocals[:n], bass[:n], drums[:n], other[:n]])
S.shape

(4, 437559)

In [102]:
gamma1 = 0.0005
gamma2 = 0.08
A, s = vector_factorisation(R, gamma1, gamma2)

In [103]:
A

array([[1.        , 0.08      , 0.08      , 0.04947618],
       [0.08      , 1.        , 0.08      , 0.04987831],
       [0.08      , 0.08      , 1.        , 0.04009501],
       [0.08      , 0.08      , 0.08      , 1.        ]])

In [104]:
s

array([[-1.24165251e-06, -2.00534027e-05, -8.51605912e-07, ...,
        -4.79479387e-05, -1.39942263e-05,  9.53296534e-06],
       [ 3.45760699e-05,  1.28867105e-05,  2.30465073e-05, ...,
        -4.81764492e-05, -3.29970375e-05,  5.60306879e-06],
       [-2.64251944e-05,  5.33046264e-05,  2.67467438e-04, ...,
        -2.95898636e-05,  6.28946254e-06, -3.27101676e-06],
       [ 3.88650171e-06, -2.39571105e-05, -2.30537250e-05, ...,
        -4.61341820e-05, -1.79287287e-05, -4.21447131e-06]])

In [105]:
R

array([[-3.9729321e-07, -1.5943402e-05,  2.1248899e-05, ...,
        -5.6451787e-05, -1.7017877e-05,  9.5110136e-06],
       [ 3.2556574e-05,  1.4351868e-05,  4.3225893e-05, ...,
        -5.6680568e-05, -3.4507673e-05,  5.8938140e-06],
       [-2.3602612e-05,  5.1770730e-05,  2.6831869e-04, ...,
        -3.9129565e-05,  1.8113088e-06, -2.2291133e-06],
       [ 4.4392395e-06, -2.0266076e-05,  1.1926210e-07, ...,
        -5.6191322e-05, -2.1184873e-05, -3.2652699e-06]], dtype=float32)

In [106]:
out = '/home/rajesh/Desktop/'
sf.write(out+'pred_vocal.wav', s[0], fs)
sf.write(out+'pred_bass.wav', s[1], fs)
sf.write(out+'pred_drums.wav', s[2], fs)
sf.write(out+'pred_other.wav', s[3], fs)

In [107]:
def get_metrics(y):
    avg_y = []
    for i in range(len(y)):
        x = y[~np.isnan(y)]
        avg = sum(x)/len(x)
        avg_y.append(avg)
    return avg_y

In [108]:
def compute_sdr(true, reconstructed, fs):
    t = np.array([true])
    r = np.array([reconstructed])

    sdr, isr, sir, sar = museval.evaluate(t, r, win=fs, hop=fs)
        
    avg_sdr = get_metrics(sdr)
    avg_isr = get_metrics(isr) #Source to Spatial Distortion Image
    avg_sir = get_metrics(sir)
    avg_sar = get_metrics(sar)

    return sum(avg_sdr)/len(avg_sdr)

In [109]:
v_sdr = compute_sdr(vocals[:n], s[0], fs)
b_sdr = compute_sdr(bass[:n], s[1], fs)
d_sdr = compute_sdr(drums[:n], s[2], fs)
o_sdr = compute_sdr(other[:n], s[3], fs)

sdr = (v_sdr + b_sdr + d_sdr + o_sdr)/4
sdr, v_sdr, b_sdr, d_sdr, o_sdr

(26.16621468139518,
 24.76758269105062,
 26.917228591883482,
 28.178508484063016,
 24.801538958583592)

In [110]:
v_sdr = compute_sdr(vocals[:n], bvocals[:n], fs)
b_sdr = compute_sdr(bass[:n], bbass[:n], fs)
d_sdr = compute_sdr(drums[:n], bdrums[:n], fs)
o_sdr = compute_sdr(other[:n], bother[:n], fs)

sdr = (v_sdr + b_sdr + d_sdr + o_sdr)/4
sdr, v_sdr, b_sdr, d_sdr, o_sdr

(13.643590131533998,
 12.391711369834548,
 14.655533676232949,
 17.10209808779055,
 10.425017392277942)

# SARAGA?

In [112]:
bleed_path = '/home/rajesh/Desktop/Datasets/Saraga3stem/train/Aaniraimekkani/'
vocal, fs = lb.load(bleed_path+'vocals.wav')
mridangam, fs = lb.load(bleed_path+'mridangam.wav')
violin, fs = lb.load(bleed_path+'violin.wav')

In [114]:
R = np.array([vocal, mridangam, violin])

In [None]:
gamma1 = 0.0005
gamma2 = 0.1
A, s = vector_factorisation(R, gamma1, gamma2)

In [None]:
A

In [None]:
s

In [None]:
R

In [None]:
out = '/home/rajesh/Desktop/'
sf.write(out+'vocal_pred.wav', s[0], fs)
sf.write(out+'mridangam_pred.wav', s[1], fs)
sf.write(out+'violin_pred.wav', s[2], fs)
