In [1]:
import sys
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from IPython.display import Audio


ROOT_PATH=Path("..").resolve().absolute()
DEVICE=torch.device("cpu")#
sys.path.append(str(ROOT_PATH))

%load_ext autoreload
%autoreload 2
from main.evaluate_separation import evaluate_data


In [3]:
for path in Path("../grid_search_steps/").glob("*"):
    print("\n---------------------------------")
    smin, smax, rho, num_steps = path.name.split("-")
    print(f"sigma-min={smin} sigma-max={smax}, rho={rho}, num-steps={num_steps}\n")
    evaluate_data(path)


---------------------------------
sigma-min=0.0001 sigma-max=20.0, rho=7.0, num-steps=75

SI-SNR 1:  4.020451545715332
SI-SNR 2:  1.7643216848373413
SI-SNRi 1:  1.2339680194854736
SI-SNRi 2:  4.076409935951233
SDR mean:  3.110156536102295

---------------------------------
sigma-min=0.0001 sigma-max=20.0, rho=7.0, num-steps=100

SI-SNR 1:  4.122089385986328
SI-SNR 2:  2.251411199569702
SI-SNRi 1:  1.3356058597564697
SI-SNRi 2:  4.563499450683594
SDR mean:  3.4070937633514404

---------------------------------
sigma-min=0.0001 sigma-max=20.0, rho=7.0, num-steps=500

SI-SNR 1:  4.944915294647217
SI-SNR 2:  1.9749281406402588
SI-SNRi 1:  2.1584317684173584
SI-SNRi 2:  4.28701639175415
SDR mean:  3.6377131938934326

---------------------------------
sigma-min=0.0001 sigma-max=20.0, rho=7.0, num-steps=250

SI-SNR 1:  4.277606010437012
SI-SNR 2:  2.3484227657318115
SI-SNRi 1:  1.4911224842071533
SI-SNRi 2:  4.660511016845703
SDR mean:  3.4997260570526123

---------------------------------
s

In [6]:
for path in Path("grid-search-schedule/").glob("*"):
    print("\n---------------------------------")
    smin, smax, rho = path.name.split("-")
    print(f"sigma-min={smin} sigma-max={smax}, rho={rho}\n")
    results = evaluate_data(path)
    print(results["sdr"])
    


---------------------------------
sigma-min=0.0001 sigma-max=20.0, rho=7.0

3.6207435131073

---------------------------------
sigma-min=0.1 sigma-max=1.0, rho=5.0

0.7884336113929749

---------------------------------
sigma-min=0.0001 sigma-max=1.0, rho=7.0

2.2821261882781982

---------------------------------
sigma-min=0.1 sigma-max=20.0, rho=5.0

1.1403182744979858

---------------------------------
sigma-min=0.05 sigma-max=1.0, rho=5.0

1.603183388710022

---------------------------------
sigma-min=0.1 sigma-max=5.0, rho=5.0

1.4022505283355713

---------------------------------
sigma-min=0.05 sigma-max=1.0, rho=9.0

1.6622364521026611

---------------------------------
sigma-min=0.0001 sigma-max=5.0, rho=5.0

3.4981727600097656

---------------------------------
sigma-min=0.1 sigma-max=20.0, rho=9.0

1.1558301448822021

---------------------------------
sigma-min=0.0001 sigma-max=5.0, rho=9.0

3.479304075241089

---------------------------------
sigma-min=0.1 sigma-max=5.0, rho=

In [25]:
#!tar czf "sep.tar.xz"  "separations-musdb-irene-2/"

from pathlib import Path
import torchaudio
import tqdm
from main.separation import enforce_mixture_consistency

def load_separation_data(separation_path):
    separation_folder = Path(separation_path)
    seps1, seps2, oris1, oris2, ms = [], [], [], [], []
    chunks = []
    
    for chunk_folder in tqdm.tqdm(list(separation_folder.glob("*"))):
        original_tracks = [torchaudio.load(ori) for ori in sorted(list(chunk_folder.glob("ori*.wav")))]
        separated_tracks = [torchaudio.load(sep) for sep in sorted(list(chunk_folder.glob("sep*.wav")))]

        original_tracks, sample_rates_ori = zip(*original_tracks)
        separated_tracks, sample_rates_sep = zip(*separated_tracks)

        assert len({*sample_rates_ori, *sample_rates_sep}) == 1
        sample_rate = sample_rates_ori[0]

        assert len(original_tracks) == len(separated_tracks)
        assert len(original_tracks) == 2
        ori1, ori2 = original_tracks
        sep1, sep2 = separated_tracks

        m = ori1 + ori2

        #if torch.amax(torch.abs(ori1)) < 1e-3 or torch.amax(torch.abs(ori2)) < 1e-3:
        #    continue
        
        chunks.append(int(chunk_folder.name))
        seps1.append(sep1)
        seps2.append(sep2)
        oris1.append(ori1)
        oris2.append(ori2)
        ms.append(m)

    seps1 = torch.stack(seps1, dim=0)
    seps2 = torch.stack(seps2, dim=0)
    seps = torch.stack([seps1, seps2], dim=1)

    oris1 = torch.stack(oris1, dim=0)
    oris2 = torch.stack(oris2, dim=0)
    oris = torch.stack([oris1, oris2], dim=1)

    ms = torch.stack(ms, dim=0)
    return [seps1,seps2], [oris1,oris2], ms, chunks

def sdr(preds, targets):
    preds = preds.cpu()
    targets = targets.cpu()

    signal = torch.norm(targets) ** 2
    e_error = torch.norm(preds - targets) ** 2
    return 10 * torch.log10( signal/e_error )


(seps1,seps2),(oris1, oris2), ms, chunks = load_separation_data("separations-musdb-dpm2")

tmp = list(zip(chunks, seps1,seps2,oris1,oris2, ms))
tmp = sorted(tmp, key=lambda x:x[0])

bass_sdrs, drums_sdrs = [], []


def is_non_zero(source_waveforms):
  """Return shape (source,) weights for signals that are nonzero."""
  source_norms = torch.sqrt(torch.mean(source_waveforms**2, axis=-1))
  return source_norms > 1e-8

for (chunk,s1,s2,o1,o2,m) in tmp:
    #if torch.amax(torch.abs(o1)) < 1e-3 or torch.amax(torch.abs(o2)) < 1e-3:
    #    continue
    
    if not is_non_zero(o1) or not is_non_zero(o2):
        continue
        
    s1,s2 = enforce_mixture_consistency(m, torch.stack([s1.view(1,1,-1), s2.view(1,1,-1)], dim=1))
    print(f"[{chunk:02d}]: \t({torch.amax(torch.abs(o1)):.4f}) \t({torch.amax(torch.abs(o2)):.4f})\t bass={sdr(s1,o1):.3f}\t drums={sdr(s2,o2):.4f}")
    bass_sdrs.append(sdr(s1,o1))
    drums_sdrs.append(sdr(s2,o2))

sdr1 = sum(bass_sdrs)/len(bass_sdrs)
sdr2 = sum(drums_sdrs)/len(drums_sdrs)
print(sdr1, sdr2)
print((sdr1+sdr2)*0.5)


100%|██████████| 48/48 [00:00<00:00, 147.30it/s]


[00]: 	(0.7027) 	(1.0000)	 bass=-1.833	 drums=6.1318
[01]: 	(0.6571) 	(1.0000)	 bass=4.425	 drums=2.1312
[02]: 	(0.6682) 	(1.0000)	 bass=1.939	 drums=3.7263
[03]: 	(0.5316) 	(1.0000)	 bass=7.625	 drums=2.7491
[04]: 	(0.4069) 	(1.0000)	 bass=3.983	 drums=4.8827
[05]: 	(0.1375) 	(0.4310)	 bass=0.931	 drums=6.1348
[06]: 	(0.4774) 	(1.0000)	 bass=0.663	 drums=6.8117
[07]: 	(0.5299) 	(1.0000)	 bass=3.077	 drums=6.5311
[08]: 	(0.3761) 	(0.8795)	 bass=1.650	 drums=5.5166
[09]: 	(0.3382) 	(1.0000)	 bass=-1.004	 drums=7.4564
[10]: 	(0.2447) 	(1.0000)	 bass=-2.237	 drums=5.7632
[11]: 	(0.7815) 	(1.0000)	 bass=-1.355	 drums=5.9645
[12]: 	(0.3378) 	(0.9728)	 bass=-4.939	 drums=6.4330
[14]: 	(0.4099) 	(1.0000)	 bass=0.986	 drums=4.3455
[15]: 	(0.1440) 	(0.6007)	 bass=-20.446	 drums=7.9618
[16]: 	(0.4570) 	(0.9467)	 bass=8.674	 drums=7.6329
[17]: 	(0.4325) 	(0.9502)	 bass=-0.954	 drums=5.7250
[18]: 	(0.6845) 	(1.0000)	 bass=0.459	 drums=6.0074
[19]: 	(0.4694) 	(0.8108)	 bass=7.217	 drums=7.2800
[20]

In [2]:
evaluate_data("separations-musdb-dpm2")

100%|██████████| 48/48 [00:00<00:00, 124.57it/s]


SI-SDR piano:  0.47045114636421204
SI-SDR basso:  3.94543194770813
SI-SDRi piano:  3.2123315632343292
SI-SDRi basso:  1.348285436630249

SI-SNR piano:  0.471749871969223
SI-SNR basso:  3.9458775520324707
SI-SNRi piano:  3.21295365691185
SI-SNRi basso:  1.3492193222045898

SDR mean:  3.3860299489753944
3.6767312666719043e-10
SDRi mean:  3.3860299486077214


## Compute separation metrics

In [None]:
import torchmetrics.functional.audio as tma
from pathlib import Path
import torchaudio
import torch
import tqdm


def si_snr(preds, target):
    return tma.scale_invariant_signal_noise_ratio(preds=preds.cpu(), target=target.cpu()).mean().item()

def si_sdr(preds, target):
    return tma.scale_invariant_signal_distortion_ratio(preds=preds.cpu(), target=target.cpu()).mean().item()

def sdr(preds, target):
    return tma.signal_distortion_ratio(preds=preds.cpu(), target=target.cpu()).mean().item()

In [None]:
separation_folder = Path("separations")

seps1, seps2, oris1, oris2, ms = [], [], [], [], []
for file in tqdm.autonotebook.tqdm(list(separation_folder.glob("*"))):
    ori1, _ = torchaudio.load(file / "ori1.wav")
    ori2, _ = torchaudio.load(file / "ori2.wav")
    sep1, _ = torchaudio.load(file / "sep1.wav")
    sep2, _ = torchaudio.load(file / "sep2.wav")
    m = ori1 + ori2
    
    if torch.amax(torch.abs(ori1 - m)) < 1e-5 or torch.amax(torch.abs(ori2 - m)) < 1e-5:
        continue
    
    seps1.append(sep1)
    seps2.append(sep2)
    oris1.append(ori1) 
    oris2.append(ori2)
    ms.append(m)

seps1 = torch.cat(seps1, dim=0)
seps2 = torch.cat(seps2, dim=0)
oris1 = torch.cat(oris1, dim=0)
oris2 = torch.cat(oris2, dim=0)
ms = torch.cat(ms, dim=0)

print("SI-SDR piano: ",si_sdr(seps1, oris1))
print("SI-SDR basso: ",si_sdr(seps2, oris2))
print("SI-SDRi piano: ", si_sdr(seps1, oris1) - si_sdr(ms, oris1))
print("SI-SDRi basso: ", si_sdr(seps2, oris2) - si_sdr(ms, oris2))
print()
print("SI-SNR piano: ", si_snr(seps1, oris1))
print("SI-SNR basso: ", si_snr(seps2, oris2))
print("SI-SNRi piano: ", si_snr(seps1, oris1) - si_snr(ms, oris1))
print("SI-SNRi basso: ", si_snr(seps2, oris2) - si_snr(ms, oris2))
print()
print("SDR piano: ",sdr(seps1, oris1))
print("SDR basso: ",sdr(seps2, oris2))
print("SDRi piano: ",sdr(seps1, oris1) - sdr(ms, oris1) )
print("SDRi basso: ",sdr(seps2, oris2) - sdr(ms, oris2) )

#print(f"SI-SNRi (1): {si_snri_1/count}")
#print(f"SI-SNRi (2): {si_snri_2/count}")
#    print(f"SI-SNR (mix): {si_snr(sep1+sep2, m)}\n")
    
    
    


  0%|          | 0/35 [00:00<?, ?it/s]

SI-SDR piano:  -1.4496734142303467
SI-SDR basso:  2.3900609016418457
SI-SDRi piano:  0.8651337623596191
SI-SDRi basso:  -3.06253719329834

SI-SNR piano:  -1.6719918251037598
SI-SNR basso:  1.7373285293579102
SI-SNRi piano:  0.477435827255249
SI-SNRi basso:  -3.596161365509033

SDR piano:  0.2750040888786316
SDR basso:  2.827162981033325
SDRi piano:  1.6498166918754578
SDRi basso:  -3.1223719120025635


# Compute data statistics

In [None]:
from audio_data_pytorch import WAVDataset
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm 
from audio_data_pytorch.transforms import Loudness

dataset = WAVDataset("/data/MusDB/data/drums/train")

fn = lambda x:x #Loudness(44100, -20)
wavs_norms = pd.Series([np.linalg.norm(fn(wav)) for wav in tqdm(dataset)])

In [None]:
wav_norm_normalized = wavs_norms / wavs_norms.mean()
print(wav_norm_normalized.describe())
wav_norm_normalized.hist(bins=30)

In [None]:
wavs_stds = pd.Series([wav.std().item() for wav in tqdm(dataset)])
wavs_stds.describe()

In [None]:
wav = dataset[0]#,int(1e6):int(1e6)+ 50].numpy()
wav_diff = np.diff(np.diff(wav))
plt.plot(wav)
plt.plot(wav_diff/np.linalg.norm(wav_diff)*np.linalg.norm(wav))
si_snr(wav[:-2], torch.from_numpy(wav_diff))

In [None]:
from main.dataset import ChunkedWAVDataset 
from audio_data_pytorch import AllTransform
dataset_val = ChunkedWAVDataset(
    path="/data/MusDB/data/drums/test", 
    max_chunk_size=22050,
    recursive=True,
    transforms=AllTransform(source_rate=44100, target_rate=44100, mono=True),
)

In [None]:
import scipy
conv1d = scipy.signal.fftconvolve
norm = np.linalg.norm


def compute_cossim(wav:np.ndarray, chunk:np.ndarray):
    chunk_size = chunk.shape[-1]
    chunk_norm = np.linalg.norm(chunk)
    dot_wav = conv1d(wav, np.flip(chunk), mode="same")
    wav_norm = np.sqrt(conv1d(wav**2, np.ones_like(chunk), mode="same"))
    wav_norm[np.isnan(wav_norm)] = 0.0
    
    cossim = np.zeros_like(dot_wav)
    mask = wav_norm != 0
    cossim[mask] = dot_wav[mask] / (wav_norm[mask] * chunk_norm)
    return cossim

chunk = dataset_val[0][0].numpy()
plt.plot(chunk)
plt.show()

cossims = []
for wav in tqdm(dataset):
    wav = wav[0].numpy()
    cossim = compute_cossim(wav, chunk)
    i = cossim.argmax()
    cossims.append( (i, cossim[i]) )
    

#wav = dataset[1][0].numpy()
#chunk = dataset_val[100][0].numpy()

def compute_cossim(wav:np.ndarray, chunk:np.ndarray):
    chunk_size = chunk.shape[-1]
    chunk_norm = np.linalg.norm(chunk)
    dot_wav = conv1d(wav, np.flip(chunk), mode="same")
    wav_norm = np.sqrt(conv1d(wav**2, np.ones_like(chunk), mode="same"))
    wav_norm[np.isnan(wav_norm)] = 0.0
    
    cossim = np.zeros_like(dot_wav)
    mask = wav_norm != 0
    cossim[mask] = dot_wav[mask] / (wav_norm[mask] * chunk_norm)
    return cossim

In [None]:
a,b = zip(*cossims)
chunk_size = chunk.shape[-1]
wav_idx, chunk_idx, val = max(zip(range(len(a)),a,b),key=lambda x:x[-1])

wav_chunk = dataset[wav_idx][0,chunk_idx-chunk_size//2:chunk_idx+chunk_size//2]

print(wav_idx, chunk_idx, val)

plt.plot(chunk/norm(chunk))
plt.plot(wav_chunk/norm(wav_chunk))
display(Audio(wav_chunk, rate=44100))

np.dot(wav_chunk, chunk)/(norm(wav_chunk)*norm(chunk))

In [None]:
i= 1000001
chunk_slice = slice(i-chunk_size//2,i+chunk_size//2)
print(wav_norm[i], np.linalg.norm(wav[chunk_slice]))
print(np.dot(wav[chunk_slice], chunk), dot_wav[i])

In [None]:
display(Audio(chunk, rate=44100))
plt.plot(cossim)

In [None]:
i = (cossim>0.1).nonzero()[0][1000]#cossim.argmax()
print(i,":",cossim[i])

plt.plot(chunk/chunk_norm)
wav_chunk = wav[i-chunk_size//2:i+chunk_size//2]
plt.plot(wav_chunk/wav_norm[i])
print(wav.shape)
display(Audio(wav_chunk, rate=44100))

print(dot_wav[i])
np.dot(wav_chunk, chunk)/(norm(wav_chunk)*norm(chunk))

In [None]:
interpol = np.arange(chunk_size)/chunk_size
display(Audio(wav_chunk*interpol + (1-interpol)*chunk, rate=44100))
print(interpol)