In [1]:
import numpy as np
import numpy.linalg as linalg
import os
import librosa as lb
import soundfile as sf
import museval
from matplotlib import pyplot as plt
import itertools
import random
from tqdm import tqdm

In [2]:
def leastnormsoln(r, s):
    """Estimate a unit-diagonal mixing matrix for a single frame.

    Finds the matrix ``L`` with ones on its diagonal such that ``L @ s == r``,
    i.e. for every source ``i``::

        r[i] - s[i] = sum_{j != i} L[i, j] * s[j]

    The off-diagonal entries are the unknowns of an underdetermined linear
    system (``n`` equations, ``n * (n - 1)`` unknowns); ``numpy.linalg.lstsq``
    returns its minimum-norm solution.

    Parameters
    ----------
    r : array_like, shape (n,)
        Observed value of each of the ``n`` sources.
    s : array_like, shape (n,)
        Reference value of each source.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        Matrix with ones on the diagonal and the estimated coefficients
        elsewhere.

    Raises
    ------
    ValueError
        If ``r`` and ``s`` are not 1-D sequences of the same length.
    """
    r = np.asarray(r)
    s = np.asarray(s)
    if r.ndim != 1 or r.shape != s.shape:
        raise ValueError("r and s must be 1-D sequences of the same length")

    n = s.shape[0]
    if n < 2:
        # No off-diagonal entries to estimate.
        return np.eye(n)

    # Column vector of residuals that the off-diagonal terms must explain.
    b = (r - s).reshape(-1, 1)

    # Row i only involves the unknowns L[i, j] (j != i), whose coefficients
    # are the other sources' values, so A is block-diagonal by rows.
    A = np.zeros((n, n * (n - 1)), dtype=np.result_type(s.dtype, np.float64))
    for i in range(n):
        A[i, i * (n - 1):(i + 1) * (n - 1)] = np.delete(s, i)

    soln, residual, rank, singular = linalg.lstsq(A, b, rcond=-1)

    # Boolean-mask assignment fills the off-diagonal slots in row-major order,
    # which is exactly the order in which the unknowns were laid out above.
    lam = np.eye(n, dtype=soln.dtype)
    lam[~np.eye(n, dtype=bool)] = soln[:, 0]
    return lam

In [3]:
def norm(A):
    """Return two matrix norms of ``A`` as the tuple ``(l1, frobenius)``.

    ``l1`` is ``numpy.linalg.norm`` with ``ord=1`` and ``frobenius`` is the
    same call with ``ord='fro'``.
    """
    return tuple(linalg.norm(A, ord=kind) for kind in (1, 'fro'))

In [5]:
def framewise(R, S):
    """Estimate one mixing matrix per column of ``R`` / ``S``.

    Column ``k`` of ``R`` and column ``k`` of ``S`` are passed to
    ``leastnormsoln``; the results are returned as a list with
    ``R.shape[1]`` entries, in column order.
    """
    n_frames = R.shape[1]
    return [leastnormsoln(R[:, k], S[:, k]) for k in range(n_frames)]

In [32]:
# Folders holding the bleed-contaminated stems and the matching clean stems.
bleed_path = '/home/rajesh/Desktop/Datasets/musdb18hq_bleeded/test/Sambasevam Shanmugam - Kaathaadi/'
clean_path = '/home/rajesh/Desktop/Datasets/musdb18hq/test/Sambasevam Shanmugam - Kaathaadi/'
dest_path = '/home/rajesh/Desktop/'

# lb.load is called with its default arguments.
# NOTE(review): per the librosa docs the defaults resample to 22050 Hz and
# down-mix to mono -- confirm that is intended for these stems.
bvocals, fs = lb.load(os.path.join(bleed_path, 'vocals.wav'))
bbass, fs = lb.load(os.path.join(bleed_path, 'bass.wav'))
bdrums, fs = lb.load(os.path.join(bleed_path, 'drums.wav'))
bother, fs = lb.load(os.path.join(bleed_path, 'other.wav'))

vocals, fs = lb.load(os.path.join(clean_path, 'vocals.wav'))
bass, fs = lb.load(os.path.join(clean_path, 'bass.wav'))
drums, fs = lb.load(os.path.join(clean_path, 'drums.wav'))
other, fs = lb.load(os.path.join(clean_path, 'other.wav'))

# Trim every stem to the shortest one so the rows stacked below all have the
# same length (checking only the bass stems would miss a shorter vocal/drum/
# other stem and produce ragged rows).
n = min(len(stem) for stem in (bvocals, bbass, bdrums, bother,
                               vocals, bass, drums, other))

# Row order in both matrices: vocals, bass, drums, other.
R = np.array([bvocals[:n], bbass[:n], bdrums[:n], bother[:n]])
S = np.array([vocals[:n], bass[:n], drums[:n], other[:n]])

In [33]:
# Estimate one mixing matrix per sample (column) of R/S; `lmbda` is a list
# with R.shape[1] entries.
lmbda = framewise(R, S)

In [34]:
# Undo the mixing sample by sample: invert the matrix estimated for column
# `col` and apply it to that column of R.
s_aprx = [
    np.dot(np.linalg.inv(lmbda[col]), R[:, col])
    for col in range(R.shape[1])
]

In [35]:
# Stack the per-sample estimates and transpose so that each row holds one
# source's signal.
# NOTE(review): this rebinds s_aprx, so re-running this cell on its own
# transposes the array back -- run it once, right after the cell that builds
# the list.
s_aprx = np.array(s_aprx).T

In [36]:
out = '/home/rajesh/Desktop/'
# Row k of s_aprx is written to the k-th file name below.
for row, fname in enumerate(('pred_vocal.wav', 'pred_bass.wav',
                             'pred_drums.wav', 'pred_other.wav')):
    sf.write(out+fname, s_aprx[row], fs)

### BLOCKWISE

In [37]:
def blockleastnormsoln(r, s):
    """Estimate one unit-diagonal mixing matrix for a whole block of samples.

    Solves, in the least-squares sense over all samples of the block,
    ``r[i] - s[i] = sum_{j != i} L[i, j] * s[j]`` for every source ``i`` and
    returns ``L`` with ones on its diagonal.

    Parameters
    ----------
    r : array_like, shape (n_src, n_samples)
        Observed signals, one row per source. Input that is not 2-D is
        flattened and treated as four equally long sources laid end to end.
    s : array_like
        Reference signals with the same number of elements and layout as ``r``.

    Returns
    -------
    numpy.ndarray, shape (n_src, n_src)
        Matrix with ones on the diagonal and the fitted coefficients elsewhere.

    Raises
    ------
    ValueError
        If ``r`` and ``s`` differ in size, or the flattened length is not a
        multiple of the number of sources.
    """
    r = np.asarray(r)
    s = np.asarray(s)

    # 2-D input carries its own source count; anything else keeps the
    # historical assumption of four sources.
    n_src = r.shape[0] if r.ndim == 2 else 4

    r = r.flatten()
    s = s.flatten()
    if r.shape[0] != s.shape[0]:
        raise ValueError("r and s must contain the same number of samples")
    if n_src < 2:
        # Nothing off-diagonal to estimate.
        return np.eye(n_src)
    if r.shape[0] % n_src:
        raise ValueError("flattened length must be a multiple of the number of sources")

    b = r - s

    # Integer division: number of samples contributed by each source.
    val = r.shape[0] // n_src
    src = s.reshape(n_src, val)

    # Rows [i*val, (i+1)*val) are the equations for source i; their only
    # non-zero columns are the other sources' signals (ascending source order).
    n_unknown = n_src - 1
    A = np.zeros((n_src * val, n_src * n_unknown),
                 dtype=np.result_type(s.dtype, np.float64))
    for i in range(n_src):
        A[i * val:(i + 1) * val, i * n_unknown:(i + 1) * n_unknown] = np.delete(src, i, axis=0).T

    soln, residual, rank, singular = linalg.lstsq(A, b, rcond=-1)

    # Off-diagonal slots are filled in row-major order, matching the column
    # layout of A.
    lam = np.eye(n_src, dtype=soln.dtype)
    lam[~np.eye(n_src, dtype=bool)] = soln
    return lam

In [47]:
# Block length expressed in seconds ...
time = 20
block = int(time * fs)

# ... or directly in samples; this second assignment is the one that takes
# effect.
samples = 7009232 #use seconds or samples
block = samples

# One mixing matrix per consecutive run of `block` columns of R/S.
block_lamda = [
    blockleastnormsoln(R[:, start:start+block], S[:, start:start+block])
    for start in range(0, R.shape[1], block)
]

In [48]:
# Undo the mixing block by block. block_lamda[k] was estimated from columns
# [k*block, (k+1)*block) of R, so the sample offset is derived from the block
# index. (Stepping through the list of matrices by `block` would visit only
# the first matrix.)
s_aprx = []
for k, A in enumerate(block_lamda):
    start = k * block
    r = R[:, start:start+block]
    A_inv = np.linalg.inv(A)
    s_aprx.append(np.dot(A_inv, r))

In [49]:
# Join the per-block estimates along the time axis so every block is kept
# (for a single block this equals that block's estimate). The isinstance
# guard makes the cell safe to re-run after s_aprx has already been joined.
if isinstance(s_aprx, list):
    s_aprx = np.hstack(s_aprx)

In [50]:
out = '/home/rajesh/Desktop/'
# Pair each output file name with the matching row of s_aprx and write it.
wav_names = ['pred_vocal.wav', 'pred_bass.wav', 'pred_drums.wav', 'pred_other.wav']
for wav_name, signal in zip(wav_names, s_aprx[:4]):
    sf.write(out+wav_name, signal, fs)

In [52]:
# Display the estimated mixing matrices (one per block).
block_lamda

[array([[1.        , 0.10000024, 0.09999914, 0.10000384],
        [0.10000004, 1.        , 0.09999926, 0.1000045 ],
        [0.1       , 0.10000022, 1.        , 0.10000418],
        [0.09999999, 0.10000023, 0.09999929, 1.        ]])]