In [None]:
import sys
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from IPython.display import Audio


# Parent of the current working directory; resolve() already returns an
# absolute path, so no further .absolute() call is needed.
ROOT_PATH = Path("..").resolve()
DEVICE = torch.device("cpu")
# Make modules under ROOT_PATH importable (e.g. `main.dataset`, imported
# further down). Guarded so that re-running this cell does not keep stacking
# duplicate entries onto sys.path.
if str(ROOT_PATH) not in sys.path:
    sys.path.append(str(ROOT_PATH))

# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Compute separation metrics

In [None]:
import torchmetrics.functional.audio as tma
from pathlib import Path
import torchaudio
import torch
import tqdm


def si_snr(preds, target):
    """Return the mean scale-invariant SNR of ``preds`` w.r.t. ``target`` as a float.

    Both tensors are moved to the CPU, scored with torchmetrics'
    ``scale_invariant_signal_noise_ratio`` and the resulting values are
    averaged into a single Python number.
    """
    scores = tma.scale_invariant_signal_noise_ratio(
        preds=preds.cpu(),
        target=target.cpu(),
    )
    return scores.mean().item()

def si_sdr(preds, target):
    """Return the mean scale-invariant SDR of ``preds`` w.r.t. ``target`` as a float.

    Both tensors are moved to the CPU, scored with torchmetrics'
    ``scale_invariant_signal_distortion_ratio`` and the resulting values are
    averaged into a single Python number.
    """
    scores = tma.scale_invariant_signal_distortion_ratio(
        preds=preds.cpu(),
        target=target.cpu(),
    )
    return scores.mean().item()

def sdr(preds, target):
    """Return the mean SDR of ``preds`` w.r.t. ``target`` as a float.

    Both tensors are moved to the CPU, scored with torchmetrics'
    ``signal_distortion_ratio`` and the resulting values are averaged into a
    single Python number.
    """
    scores = tma.signal_distortion_ratio(
        preds=preds.cpu(),
        target=target.cpu(),
    )
    return scores.mean().item()

In [None]:
# Imported here, under an alias, because a bare `import tqdm` does not load the
# `tqdm.autonotebook` submodule and a later cell rebinds the name `tqdm` to the
# progress-bar class; either would break `tqdm.autonotebook.tqdm` on a re-run.
from tqdm.autonotebook import tqdm as progress_bar

separation_folder = Path("separations")
SILENCE_THRESHOLD = 1e-5  # peak amplitude below which a source is treated as silent

# One sub-directory per example, each holding ori1/ori2 (reference sources) and
# sep1/sep2 (separated estimates). Stray files are ignored; sorting makes the
# concatenation order independent of the filesystem.
example_dirs = sorted(p for p in separation_folder.glob("*") if p.is_dir())

seps1, seps2, oris1, oris2, ms = [], [], [], [], []
for file in progress_bar(example_dirs):
    ori1, _ = torchaudio.load(file / "ori1.wav")
    ori2, _ = torchaudio.load(file / "ori2.wav")
    sep1, _ = torchaudio.load(file / "sep1.wav")
    sep2, _ = torchaudio.load(file / "sep2.wav")
    m = ori1 + ori2  # the mixture the separator was given

    # `ori1 - m` is `-ori2` (and `ori2 - m` is `-ori1`) up to rounding, so this
    # skips examples in which one of the two sources is (almost) silent.
    if (torch.amax(torch.abs(ori1 - m)) < SILENCE_THRESHOLD
            or torch.amax(torch.abs(ori2 - m)) < SILENCE_THRESHOLD):
        continue

    seps1.append(sep1)
    seps2.append(sep2)
    oris1.append(ori1)
    oris2.append(ori2)
    ms.append(m)

if not seps1:
    # torch.cat on an empty list fails with an unhelpful message; say what is wrong.
    raise RuntimeError(f"No usable separation examples found in {separation_folder.resolve()}")

seps1 = torch.cat(seps1, dim=0)
seps2 = torch.cat(seps2, dim=0)
oris1 = torch.cat(oris1, dim=0)
oris2 = torch.cat(oris2, dim=0)
ms = torch.cat(ms, dim=0)

# For every metric print the score of each separated source and its improvement
# ("i") over using the unprocessed mixture as the estimate. Each score is
# computed once and reused for the improvement line.
for position, (label, metric) in enumerate((("SI-SDR", si_sdr), ("SI-SNR", si_snr), ("SDR", sdr))):
    if position:
        print()
    piano_score = metric(seps1, oris1)
    basso_score = metric(seps2, oris2)
    print(f"{label} piano: ", piano_score)
    print(f"{label} basso: ", basso_score)
    print(f"{label}i piano: ", piano_score - metric(ms, oris1))
    print(f"{label}i basso: ", basso_score - metric(ms, oris2))
    
    
    


# Compute data statistics

In [None]:
from audio_data_pytorch import WAVDataset
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm 
from audio_data_pytorch.transforms import Loudness

dataset = WAVDataset("/data/MusDB/data/drums/train")


def fn(x):
    """Per-item preprocessing hook; currently the identity (Loudness(44100, -20) was the alternative)."""
    return x


# np.linalg.norm of every (preprocessed) dataset item, one entry per item.
per_item_norms = [np.linalg.norm(fn(item)) for item in tqdm(dataset)]
wavs_norms = pd.Series(per_item_norms)

In [None]:
# Rescale the norms so that their mean is 1, then summarise and plot the
# resulting distribution.
mean_norm = wavs_norms.mean()
wav_norm_normalized = wavs_norms / mean_norm
print(wav_norm_normalized.describe())
wav_norm_normalized.hist(bins=30)

In [None]:
# Standard deviation of every dataset item, collected into a Series and summarised.
per_item_stds = [item.std().item() for item in tqdm(dataset)]
wavs_stds = pd.Series(per_item_stds)
wavs_stds.describe()

In [None]:
wav = dataset[0]
# Second-order finite difference; np.diff works along the last axis, so the
# result is two samples shorter than `wav` on that axis.
wav_diff = np.diff(np.diff(wav))
plt.plot(wav)
# Rescaled to the same overall norm as `wav` so the two curves are comparable.
plt.plot(wav_diff/np.linalg.norm(wav_diff)*np.linalg.norm(wav))
# Trim the LAST axis to match wav_diff. `wav[:-2]` would trim the first axis,
# which is only right for 1-D input; other cells index items as `[0, a:b]`,
# i.e. items have a leading dimension in front of the sample axis.
si_snr(wav[..., :-2], torch.from_numpy(wav_diff))

In [None]:
from main.dataset import ChunkedWAVDataset 
from audio_data_pytorch import AllTransform
# Transform applied to the validation items: same source and target rate
# (44100), mono=True.
val_transforms = AllTransform(source_rate=44100, target_rate=44100, mono=True)

# Chunked view of the test split (presumably chunks of at most 22050 samples,
# judging by the parameter name — TODO confirm against ChunkedWAVDataset).
dataset_val = ChunkedWAVDataset(
    path="/data/MusDB/data/drums/test",
    max_chunk_size=22050,
    recursive=True,
    transforms=val_transforms,
)

In [None]:
import scipy
# `import scipy` alone is not guaranteed to load the `signal` subpackage (older
# SciPy releases raise AttributeError on `scipy.signal`), so import it explicitly.
import scipy.signal

conv1d = scipy.signal.fftconvolve  # FFT-based convolution, used below for sliding dot products
norm = np.linalg.norm


def compute_cossim(wav:np.ndarray, chunk:np.ndarray, min_norm:float=0.0):
    """Cosine similarity between ``chunk`` and every same-length window of ``wav``.

    Entry ``i`` of the result compares ``chunk`` with the window
    ``wav[i - len(chunk)//2 : i - len(chunk)//2 + len(chunk)]``; parts of a
    window that fall outside ``wav`` count as zeros (``mode="same"``).

    Args:
        wav: signal to search in (used as a 1-D array in this notebook).
        chunk: query signal (1-D as well).
        min_norm: windows whose L2 norm is not above this value are treated as
            silent and get similarity 0. The default 0.0 only excludes windows
            whose norm is exactly zero.

    Returns:
        Array with the same shape as ``wav`` holding the similarities. It is all
        zeros when ``chunk`` itself has zero norm.
    """
    # Local import keeps the function independent of notebook-level aliases.
    from scipy.signal import fftconvolve

    chunk_norm = np.linalg.norm(chunk)
    # Convolving with the flipped chunk is a sliding dot product (cross-correlation).
    dot_wav = fftconvolve(wav, np.flip(chunk), mode="same")
    cossim = np.zeros_like(dot_wav)
    if chunk_norm == 0:
        # A silent query has no direction; avoid 0/0 producing NaN.
        return cossim

    # Sliding sum of squares gives the energy of each window. FFT round-off can
    # make it slightly negative, so clamp before the square root instead of
    # patching the NaNs (and warnings) afterwards.
    window_energy = fftconvolve(wav**2, np.ones_like(chunk), mode="same")
    wav_norm = np.sqrt(np.maximum(window_energy, 0.0))

    mask = wav_norm > min_norm
    cossim[mask] = dot_wav[mask] / (wav_norm[mask] * chunk_norm)
    return cossim

# Query signal: element [0] of the first validation item, as a NumPy array
# (presumably the first/only channel — TODO confirm).
chunk = dataset_val[0][0].numpy()
plt.plot(chunk)
plt.show()

# For every item of `dataset`, record where `chunk` matches best and how well:
# cossims[k] == (position of the maximum, similarity at that position).
# NOTE: `wav`, `cossim` and `i` stay bound to the values of the LAST item after
# the loop; later cells read `wav` and `cossim` from here.
cossims = []
for wav in tqdm(dataset):
    wav = wav[0].numpy()
    cossim = compute_cossim(wav, chunk)
    i = cossim.argmax()  # index of the highest cosine similarity in this item
    cossims.append( (i, cossim[i]) )
    

#wav = dataset[1][0].numpy()
#chunk = dataset_val[100][0].numpy()

# NOTE(review): this re-defines compute_cossim, which is already defined earlier
# in this cell; the later definition shadows the earlier one. Keep a single copy.
def compute_cossim(wav:np.ndarray, chunk:np.ndarray, min_norm:float=0.0):
    """Cosine similarity between ``chunk`` and every same-length window of ``wav``.

    Entry ``i`` of the result compares ``chunk`` with the window
    ``wav[i - len(chunk)//2 : i - len(chunk)//2 + len(chunk)]``; parts of a
    window that fall outside ``wav`` count as zeros (``mode="same"``).

    Args:
        wav: signal to search in (used as a 1-D array in this notebook).
        chunk: query signal (1-D as well).
        min_norm: windows whose L2 norm is not above this value are treated as
            silent and get similarity 0. The default 0.0 only excludes windows
            whose norm is exactly zero.

    Returns:
        Array with the same shape as ``wav`` holding the similarities. It is all
        zeros when ``chunk`` itself has zero norm.
    """
    # Local import keeps the function independent of notebook-level aliases.
    from scipy.signal import fftconvolve

    chunk_norm = np.linalg.norm(chunk)
    # Convolving with the flipped chunk is a sliding dot product (cross-correlation).
    dot_wav = fftconvolve(wav, np.flip(chunk), mode="same")
    cossim = np.zeros_like(dot_wav)
    if chunk_norm == 0:
        # A silent query has no direction; avoid 0/0 producing NaN.
        return cossim

    # Sliding sum of squares gives the energy of each window. FFT round-off can
    # make it slightly negative, so clamp before the square root instead of
    # patching the NaNs (and warnings) afterwards.
    window_energy = fftconvolve(wav**2, np.ones_like(chunk), mode="same")
    wav_norm = np.sqrt(np.maximum(window_energy, 0.0))

    mask = wav_norm > min_norm
    cossim[mask] = dot_wav[mask] / (wav_norm[mask] * chunk_norm)
    return cossim

In [None]:
chunk_size = chunk.shape[-1]

# cossims[k] holds (best_position, best_similarity) for dataset[k]; pick the
# item whose best window is most similar to the query chunk (first one on ties).
wav_idx, (chunk_idx, val) = max(enumerate(cossims), key=lambda item: item[1][1])

# With mode="same", position i refers to the window starting at
# i - chunk_size//2 and spanning chunk_size samples. Building the stop from the
# start keeps the length right for odd chunk sizes too.
start = int(chunk_idx) - chunk_size // 2
stop = start + chunk_size
track = dataset[wav_idx][0]
# A negative start would wrap around when slicing, so copy the in-range part
# into a zero buffer instead (matching the implicit zero padding of the search).
lo, hi = max(start, 0), min(stop, track.shape[-1])
wav_chunk = track.new_zeros(chunk_size)
wav_chunk[lo - start:hi - start] = track[lo:hi]

print(wav_idx, chunk_idx, val)

# Overlay the unit-norm query and the unit-norm best match, then listen to the match.
plt.plot(chunk/norm(chunk))
plt.plot(wav_chunk/norm(wav_chunk))
display(Audio(wav_chunk, rate=44100))

# Direct cosine similarity of the extracted window; should agree with `val`.
np.dot(wav_chunk, chunk)/(norm(wav_chunk)*norm(chunk))

In [None]:
i= 1000001

# `dot_wav` and `wav_norm` only exist as locals inside compute_cossim, so
# rebuild them here for the current `wav` (the last item of the search loop).
dot_wav = conv1d(wav, np.flip(chunk), mode="same")
wav_norm = np.sqrt(np.maximum(conv1d(wav**2, np.ones_like(chunk), mode="same"), 0.0))

# Window that mode="same" associates with position i.
window_start = i - chunk_size//2
chunk_slice = slice(window_start, window_start + chunk_size)

# Each line prints an FFT-based value next to its direct computation; the two
# numbers on a line should agree up to round-off.
print(wav_norm[i], np.linalg.norm(wav[chunk_slice]))
print(np.dot(wav[chunk_slice], chunk), dot_wav[i])

In [None]:
# Play the query chunk, then plot the similarity curve left over from the last
# item processed by the search loop.
query_audio = Audio(chunk, rate=44100)
display(query_audio)
plt.plot(cossim)

In [None]:
# `chunk_norm`, `dot_wav` and `wav_norm` only exist as locals inside
# compute_cossim, so rebuild them here for the current `wav`.
chunk_norm = norm(chunk)
dot_wav = conv1d(wav, np.flip(chunk), mode="same")
wav_norm = np.sqrt(np.maximum(conv1d(wav**2, np.ones_like(chunk), mode="same"), 0.0))

# Positions whose similarity exceeds 0.1; inspect the 1001st one, or the last
# available one when there are fewer hits.
hits = (cossim>0.1).nonzero()[0]
if hits.size == 0:
    raise ValueError("no position with cosine similarity above 0.1")
i = hits[min(1000, hits.size - 1)]
print(i,":",cossim[i])

plt.plot(chunk/chunk_norm)
# Window that mode="same" associates with position i.
window_start = i - chunk_size//2
wav_chunk = wav[window_start:window_start + chunk_size]
plt.plot(wav_chunk/wav_norm[i])
print(wav.shape)
display(Audio(wav_chunk, rate=44100))

print(dot_wav[i])
# Direct cosine similarity of the window; should agree with cossim[i].
np.dot(wav_chunk, chunk)/(norm(wav_chunk)*norm(chunk))

In [None]:
# Linear ramp 0, 1/chunk_size, 2/chunk_size, ... used as a cross-fade weight.
interpol = np.arange(chunk_size) / chunk_size
# Cross-fade that starts as the query chunk and moves towards the matched window.
crossfade = wav_chunk * interpol + (1 - interpol) * chunk
display(Audio(crossfade, rate=44100))
print(interpol)