In [1]:
import warnings

import numpy as np
from scipy.optimize import minimize

In [267]:
def objective(params, X, m, gamma1, gamma2):
    """Joint reconstruction loss for a flattened ``(A, S)`` parameter vector.

    ``params`` packs an ``n x n`` matrix ``A`` followed by an ``n x l`` matrix
    ``S`` (both row-major), where ``(n, l) = X.shape``.  The loss is

        ||X - A S||_F^2 + ||m - sum_i S[i]||_2^2

    i.e. ``A S`` should reproduce ``X`` and the rows of ``S`` should add up
    to ``m``.

    ``gamma1`` and ``gamma2`` are accepted so the signature matches the
    ``args`` tuple passed to ``minimize``, but they do not enter the loss.
    """
    rows, cols = X.shape
    split = rows * rows
    mixing = params[:split].reshape((rows, rows))
    sources = params[split:].reshape((rows, cols))

    residual = X - mixing @ sources
    mix_error = m - sources.sum(axis=0)
    return np.linalg.norm(residual) ** 2 + np.linalg.norm(mix_error) ** 2

def update_A(S, X, gamma2):
    """Ridge-regularised least-squares update of the mixing matrix.

    Computes

        A = (X S^T + gamma2 * E) (S S^T + gamma2 * I)^{-1},   E = I - 1/n

    and then pins the diagonal of ``A`` to 1.

    Parameters
    ----------
    S : ndarray, shape (n, l)
        Current source estimate.
    X : ndarray, shape (n, l)
        Observed signals.
    gamma2 : float
        Regularisation strength.

    Returns
    -------
    ndarray, shape (n * n,)
        The updated ``A`` flattened row-major.
    """
    n = X.shape[0]
    E = np.eye(n) - 1 / n
    rhs = np.dot(X, S.T) + gamma2 * E
    # S S^T is (n, n), so the ridge term must be the n x n identity; an
    # l x l identity would fail to broadcast whenever l != n.
    gram = np.dot(S, S.T) + gamma2 * np.eye(n)
    # A @ gram = rhs  <=>  gram^T @ A^T = rhs^T; solving is more stable than
    # forming the explicit inverse.
    A = np.linalg.solve(gram.T, rhs.T).T
    np.fill_diagonal(A, 1)
    return A.ravel()

def update_S(A, X):
    """Least-squares source estimate for a fixed mixing matrix.

    Solves ``A @ S ~= X`` for ``S`` with ``numpy.linalg.lstsq`` (machine-
    precision ``rcond``) and returns ``S`` flattened row-major.
    """
    solution, _residuals, _rank, _singular_values = np.linalg.lstsq(A, X, rcond=None)
    return solution.ravel()

def minimize_function(X, m, gamma1, gamma2):
    """Jointly estimate the mixing matrix ``A`` and sources ``S`` from ``X``.

    Minimises ``objective`` (``||X - A S||_F^2 + ||m - sum_i S[i]||^2``) over
    every entry of ``A`` and ``S`` at once with scipy's bounded Powell method.
    This is a single joint optimisation, not an alternating scheme.

    Parameters
    ----------
    X : ndarray, shape (n, l)
        Observed signals, one row per channel.
    m : ndarray, shape (l,)
        Mixture that the estimated sources should sum to.
    gamma1, gamma2 : float
        Lower / upper bound for every off-diagonal entry of ``A``.  Diagonal
        entries are fixed to 1 and ``S`` is unbounded.

    Returns
    -------
    A_opt : ndarray, shape (n, n)
    S_opt : ndarray, shape (n, l)

    Emits a ``RuntimeWarning`` if the optimiser reports non-convergence.
    """
    n, l = X.shape

    # Start from A = I and S = X.  Off-diagonals are clipped into
    # [gamma1, gamma2] so the starting point is feasible even if gamma1 > 0.
    A_init = np.clip(np.eye(n), gamma1, gamma2)
    np.fill_diagonal(A_init, 1.0)
    S_init = X.copy()

    # Bounds in the same row-major order as `params`: diagonal of A pinned to
    # 1, off-diagonals in [gamma1, gamma2], S free.
    A_bounds = [(1, 1) if i == j else (gamma1, gamma2) for i in range(n) for j in range(n)]
    S_bounds = [(-np.inf, np.inf)] * (n * l)
    bounds = A_bounds + S_bounds

    # Single joint Powell solve over all entries of A and S.
    params = np.concatenate([A_init.ravel(), S_init.ravel()])
    res = minimize(objective, params, args=(X, m, gamma1, gamma2), method='Powell', bounds=bounds)
    if not res.success:
        warnings.warn(
            f"minimize_function: optimiser did not converge: {res.message}",
            RuntimeWarning,
            stacklevel=2,
        )

    # Unpack the optimised parameter vector.
    A_opt = res.x[:n**2].reshape((n, n))
    S_opt = res.x[n**2:].reshape((n, l))

    return A_opt, S_opt


In [144]:
# Ground-truth mixing matrix for the synthetic check: unit diagonal with small
# positive off-diagonal coefficients.
A = np.array([
    [1.0, 0.02, 0.22],
    [0.1, 1.0, 0.02],
    [0.1, 0.1, 1.0],
])

In [150]:
# Synthetic ground-truth sources (3 rows x 10 samples) mixed through A.
# A fixed seed keeps the experiment reproducible across kernel restarts.
SEED = 0
rng = np.random.default_rng(SEED)
S = rng.random((3, 10))
X = A @ S

In [151]:
# The mixture is the plain sum of the ground-truth source rows.
m = S.sum(axis=0)

In [157]:
# Off-diagonal entries of A are constrained to [gamma1, gamma2].
gamma1, gamma2 = 0, 1
A_opt, S_opt = minimize_function(X, m, gamma1, gamma2)

# Reconstruction error ||X - A S|| and mixture-consistency error ||m - sum(S)||.
l1 = np.linalg.norm(X - A_opt @ S_opt)
l2 = np.linalg.norm(m - S_opt.sum(axis=0))
A_opt, l1, l2, l1 + l2

(array([[1.        , 0.06843177, 0.12250687],
        [0.09254738, 1.        , 0.10953054],
        [0.10842322, 0.04973531, 1.        ]]),
 2.082960350869624e-06,
 2.736596681014895e-06,
 4.8195570318845185e-06)

In [149]:
# Recovery error of the mixing matrix and of the sources (Frobenius norms).
tuple(np.linalg.norm(truth - estimate) for truth, estimate in ((A, A_opt), (S, S_opt)))

(0.1420167871496621, 0.35506382445659007)

In [40]:
import librosa as lb
import soundfile as sf
from matplotlib import pyplot as plt
import os

In [168]:
# Dataset root is configurable through the MUSDB_ROOT environment variable so
# the notebook is not tied to one machine; the default is the original path.
DATASET_ROOT = os.environ.get('MUSDB_ROOT', '/home/rajesh/Desktop/Datasets')
TRACK = 'Music Delta - Hendrix'
STEMS = ('vocals', 'bass', 'drums', 'other')

# The trailing '' keeps a trailing separator, so `path + 'name.wav'` still works.
bleed_path = os.path.join(DATASET_ROOT, 'musdb18hq_bleeded', 'train', TRACK, '')
clean_path = os.path.join(DATASET_ROOT, 'musdb18hq', 'train', TRACK, '')

# lb.load returns (signal, sample_rate); keep both so the rates can be checked.
bleed_loaded = {stem: lb.load(bleed_path + stem + '.wav') for stem in STEMS}
clean_loaded = {stem: lb.load(clean_path + stem + '.wav') for stem in STEMS + ('mixture',)}

# Every signal must share one sample rate; otherwise `fs` would silently be
# whichever file happened to load last.
sample_rates = {sr for _, sr in [*bleed_loaded.values(), *clean_loaded.values()]}
assert len(sample_rates) == 1, f'stems loaded at different sample rates: {sample_rates}'
fs = sample_rates.pop()

bvocals, bbass, bdrums, bother = (bleed_loaded[stem][0] for stem in STEMS)
vocals, bass, drums, other = (clean_loaded[stem][0] for stem in STEMS)
mixture = clean_loaded['mixture'][0]

In [292]:
# Trim every signal to the shortest length so that all stems and the mixture
# line up sample-for-sample.  All nine signals are considered, not only the
# bass pair and the mixture.
n_ = min(
    len(signal)
    for signal in (vocals, bass, drums, other, bvocals, bbass, bdrums, bother, mixture)
)

vocals = vocals[:n_]
bass = bass[:n_]
drums = drums[:n_]
other = other[:n_]

bvocals = bvocals[:n_]
bbass = bbass[:n_]
bdrums = bdrums[:n_]
bother = bother[:n_]

m = mixture[:n_]

In [293]:
# Number of samples in one second of audio at the loaded sample rate.
fs * 1

22050

In [294]:
# One row per stem: X holds the bled observations, S the clean sources.
X = np.stack([bvocals, bbass, bdrums, bother])
S = np.stack([vocals, bass, drums, other])

X.shape, S.shape, m.shape

((4, 437559), (4, 437559), (437559,))

In [295]:
# Analyse a short window only: minimize_function optimises n*n + n*l
# parameters, so the cost grows with the window length l.
WINDOW = slice(1200, 1225)

# Re-slice from the n_-trimmed source signals instead of from X/S/m
# themselves, so re-running this cell yields the same window rather than an
# empty array.
X = np.array([bvocals, bbass, bdrums, bother])[:, WINDOW]
S = np.array([vocals, bass, drums, other])[:, WINDOW]
m = mixture[:n_][WINDOW]

In [296]:
# Off-diagonal entries of A are constrained to [gamma1, gamma2].
gamma1, gamma2 = 0, 0.4
A_opt, S_opt = minimize_function(X, m, gamma1, gamma2)

# Reconstruction error ||X - A S|| and mixture-consistency error ||m - sum(S)||.
l1 = np.linalg.norm(X - A_opt @ S_opt)
l2 = np.linalg.norm(m - S_opt.sum(axis=0))
A_opt, l1, l2, l1 + l2

(array([[1.        , 0.28874028, 0.11956632, 0.10669804],
        [0.39097117, 1.        , 0.16794349, 0.02533739],
        [0.39989507, 0.12364107, 1.        , 0.174168  ],
        [0.35299177, 0.06508365, 0.03121272, 1.        ]]),
 5.1904028291821654e-05,
 3.612653996495037e-05,
 8.803056825677202e-05)

In [286]:
# How closely the clean stems sum to the mixture over this window.
np.linalg.norm(m - S.sum(axis=0))

0.0002051982

In [291]:
# Scratch arithmetic; equals the 1200-sample window offset used when slicing.
50 * 24

1200

In [25]:
out = '/home/rajesh/Desktop/'
# Row i of S_opt is written to the i-th file name below.
for row, filename in enumerate(('pred_vocal.wav', 'pred_bass.wav', 'pred_drums.wav', 'pred_other.wav')):
    sf.write(out + filename, S_opt[row], fs)

In [None]:
def get_metrics(y):
    """Per-row mean of a metric array, ignoring NaN entries.

    Parameters
    ----------
    y : array_like, shape (n_rows, n_frames) or (n_frames,)
        Metric values, one row per source; a 1-D input is treated as a
        single row.  NaN entries (e.g. frames where the metric is undefined)
        are skipped.

    Returns
    -------
    list of float
        One mean per row; ``nan`` for a row whose entries are all NaN.
    """
    values = np.atleast_2d(np.asarray(y, dtype=float))
    avg_y = []
    # Average each row on its own (the previous loop ignored its index and
    # repeated the global mean for every row).
    for row in values:
        valid = row[~np.isnan(row)]
        avg_y.append(float(valid.mean()) if valid.size else float('nan'))
    return avg_y

In [None]:
def compute_sdr(true, reconstructed, fs):
    """Mean framewise SDR of one estimated signal against its reference.

    Wraps ``museval.evaluate`` with window and hop both set to ``fs`` samples
    (one-second frames when ``fs`` is the sample rate).  Each source's SDR is
    averaged over frames with ``get_metrics``, and the mean over sources is
    returned.

    Parameters
    ----------
    true : array_like
        Reference signal.
    reconstructed : array_like
        Estimated signal, same length as ``true``.
    fs : int
        Window and hop length passed to ``museval.evaluate``.

    Returns
    -------
    float
        Mean SDR.

    NOTE(review): ``museval`` is never imported in this notebook; add
    ``import museval`` to the imports cell before running this.
    """
    references = np.array([true])
    estimates = np.array([reconstructed])

    # museval also returns ISR/SIR/SAR; only SDR is reported, so the other
    # metrics are not averaged.
    sdr, _isr, _sir, _sar = museval.evaluate(references, estimates, win=fs, hop=fs)

    per_source_sdr = get_metrics(sdr)
    return sum(per_source_sdr) / len(per_source_sdr)

In [None]:
# SDR of each bled stem measured against its clean counterpart.  The trimming
# cell defines the common length as `n_` (`n` is not defined at notebook level).
clean_stems = (vocals, bass, drums, other)
bled_stems = (bvocals, bbass, bdrums, bother)
v_sdr, b_sdr, d_sdr, o_sdr = [
    compute_sdr(ref[:n_], est[:n_], fs) for ref, est in zip(clean_stems, bled_stems)
]

sdr = (v_sdr + b_sdr + d_sdr + o_sdr)/4
sdr, v_sdr, b_sdr, d_sdr, o_sdr

In [None]:
# SDR of each row of S against the corresponding clean stem.
# NOTE(review): `n` appears never to be defined at notebook level (the
# trimming cell defines `n_`), so this cell likely raises NameError on a
# fresh kernel.
# NOTE(review): at this point `S` is the 25-sample window [1200:1225] of the
# *clean* stems, while `vocals[:n]` etc. are the full-length clean signals, so
# the lengths will not match and this compares clean against clean.
# Presumably the intent was to score the estimate `S_opt` against the same
# window of the references -- confirm before fixing.
v_sdr = compute_sdr(vocals[:n], S[0], fs)
b_sdr = compute_sdr(bass[:n], S[1], fs)
d_sdr = compute_sdr(drums[:n], S[2], fs)
o_sdr = compute_sdr(other[:n], S[3], fs)

sdr = (v_sdr + b_sdr + d_sdr + o_sdr)/4
sdr, v_sdr, b_sdr, d_sdr, o_sdr