In [None]:
!python3 -m pip install -U git+https://github.com/facebookresearch/demucs#egg=demucs

In [15]:
import torch
import numpy as np
import torchaudio
import tempfile
from cdpam import CDPAM
from demucs.separate import Separator
from IPython.display import Audio
from datasets import load_dataset, Dataset
from walloc import walloc
from spauq.core.metrics import spauq_eval
from fastprogress.fastprogress import progress_bar
class Config:
    """Empty placeholder class; attributes can be attached to its instances."""

In [2]:
# Shared objects used by the evaluation cells below.
# Prefer the GPU but fall back to the CPU so the notebook still runs on
# machines without CUDA (CDPAM otherwise defaults to a CUDA device).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Demucs separator; every option other than the device is left at its default.
separator = Separator(device=device)
# CDPAM perceptual audio metric, placed on the same device as the audio tensors.
cdpam_loss = CDPAM(dev=device)
# Validation split of the MUSDB18-HQ dataset hosted on the Hugging Face Hub.
MUSDB = load_dataset("danjacobellis/musdb18HQ", split='validation')

  state = torch.load(modfolder,map_location="cpu")['state']


Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

In [3]:
def pad(audio, p=2**16):
    """Zero-pad the last dimension of ``audio`` up to the next multiple of ``p``.

    Args:
        audio: Tensor whose last dimension is the one to pad. Any number of
            leading dimensions is accepted (e.g. ``(B, C, L)`` or ``(C, L)``).
        p: Block size; the returned length is the smallest multiple of ``p``
            that is >= the input length.

    Returns:
        ``audio`` itself when its length is already a multiple of ``p``,
        otherwise a new tensor with zeros appended at the end of the last
        dimension.
    """
    length = audio.shape[-1]
    # The outer modulo maps "already aligned" (remainder 0) to zero padding
    # instead of a full extra block of p samples.
    padding_size = (p - (length % p)) % p
    if padding_size > 0:
        audio = torch.nn.functional.pad(audio, (0, padding_size), mode='constant', value=0)
    return audio

In [4]:
# Evaluate Demucs separation on band-limited input.
#
# MUSDB is walked in groups of `rows_per_track` consecutive rows. The first row
# of each group is resampled down and back up and then fed to the separator;
# each of the remaining rows is used as the reference for the separated stem
# with the same 'instrument' name. Four scores are stored per reference:
# SSDR / SRDR (the 'SSR' / 'SRR' entries returned by spauq_eval), PSNR and a
# CDPAM-based score.
max_duration = 44100 * 180  # keep at most 180 s worth of samples at 44.1 kHz
rows_per_track = 5          # 1 separator input + 4 reference stems per group
SSDR = []
SRDR = []
PSNR = []
CDPAM_dB = []
# Down-sample to 44100*3//32 (= 4134) Hz and back to 44100 Hz; the round trip
# is what degrades the separator input.
res_red = torchaudio.transforms.Resample(44100, 44100*3//32).to(device)
res_inc = torchaudio.transforms.Resample(44100*3//32, 44100).to(device)
with torch.no_grad():
    for i_sample in progress_bar(range(0, len(MUSDB), rows_per_track)):
        # One row of scores per group; the rows are filled in place below.
        ssdr_row, srdr_row, psnr_row, cdpam_row = [], [], [], []
        SSDR.append(ssdr_row)
        SRDR.append(srdr_row)
        PSNR.append(psnr_row)
        CDPAM_dB.append(cdpam_row)
        for i_instr in range(rows_per_track):
            sample = MUSDB[i_sample + i_instr]
            instr = sample['instrument']
            x, fs = torchaudio.load(sample['audio']['bytes'])
            x = x[:, :max_duration]
            x_padded = pad(x.unsqueeze(0), 2**16).to(device)
            if i_instr == 0:
                # First row of the group: degrade it and separate once; the
                # result is reused for the following rows of the same group.
                mix = res_inc(res_red(x_padded[0]))
                pred = separator.separate_tensor(mix)
                continue

            # pred[1] is indexed by instrument name. The estimate and the
            # reference can differ in length, so both are cut to the shorter one.
            # NOTE(review): the reference is the zero-padded signal, so the
            # padding is included in every score below - confirm this is intended.
            estimate = pred[1][instr]
            n_common = min(x_padded.shape[-1], estimate.shape[-1])
            reference = x_padded[0, :, :n_common]
            estimate = estimate[:, :n_common]

            SDR = spauq_eval(
                reference=reference.to("cpu"),
                estimate=estimate.to("cpu"),
                fs=fs
            )
            # PSNR with a peak-to-peak range of 2
            # (presumably waveforms in [-1, 1] - TODO confirm).
            mse = torch.nn.functional.mse_loss(reference, estimate).item()
            psnr = 20*np.log10(2) - 10*np.log10(mse)
            cdpam = cdpam_loss.forward(wav_in=reference, wav_out=estimate)

            ssdr_row.append(SDR['SSR'])
            srdr_row.append(SDR['SRR'])
            psnr_row.append(psnr)
            # NOTE(review): stored as -log10 of the mean CDPAM value (no factor
            # of 10), despite the "_dB" suffix - confirm the intended scale.
            cdpam_row.append(-np.log10(cdpam.mean().item()))
# One row per group, one column per scored stem (in dataset order).
SSDR = torch.tensor(SSDR)
SRDR = torch.tensor(SRDR)
PSNR = torch.tensor(PSNR)
CDPAM_dB = torch.tensor(CDPAM_dB)



In [16]:
# Flatten the per-track score tensors into a single table with one
# "<metric>_<instrument>" column per metric/instrument pair.
metrics = dict(SSDR=SSDR, SRDR=SRDR, PSNR=PSNR, CDPAM_dB=CDPAM_dB)

# Column i of every score tensor is labelled instruments[i].
# NOTE(review): this presumes the stems of each track were scored in exactly
# this order - verify against the dataset's row order.
instruments = ['other', 'drums', 'bass', 'vocals']

data = {
    f"{metric_name}_{instr}": metric_tensor[:, i].tolist()
    for metric_name, metric_tensor in metrics.items()
    for i, instr in enumerate(instruments)
}
dataset = Dataset.from_dict(data)

In [28]:
# Mean of every metric/instrument column over all evaluated tracks.
dataset.to_pandas().mean(axis=0)

SSDR_other         10.704823
SSDR_drums         16.199578
SSDR_bass          14.704946
SSDR_vocals         6.966281
SRDR_other          2.198909
SRDR_drums          5.626160
SRDR_bass          -1.134002
SRDR_vocals        -6.824239
PSNR_other         35.039341
PSNR_drums         37.322374
PSNR_bass          40.381729
PSNR_vocals        37.350482
CDPAM_dB_other      4.355985
CDPAM_dB_drums      4.126468
CDPAM_dB_bass       4.931243
CDPAM_dB_vocals     4.275728
dtype: float64

In [29]:
# Per-track scores: one row per evaluated track, one column per metric/instrument pair.
dataset.to_pandas()

Unnamed: 0,SSDR_other,SSDR_drums,SSDR_bass,SSDR_vocals,SRDR_other,SRDR_drums,SRDR_bass,SRDR_vocals,PSNR_other,PSNR_drums,PSNR_bass,PSNR_vocals,CDPAM_dB_other,CDPAM_dB_drums,CDPAM_dB_bass,CDPAM_dB_vocals
0,7.125585,26.222903,21.796037,4.307801,0.188382,11.379605,9.957638,-1.223913,35.672066,39.531105,40.318915,38.146466,3.097354,3.409934,4.35777,3.885645
1,9.097131,13.015608,21.178335,15.093844,2.851003,4.18781,10.648108,6.624113,33.39687,36.739286,43.064438,37.809442,4.047241,3.62845,5.386743,4.322017
2,11.998509,9.891652,10.299504,2.222695,4.502746,2.159502,4.131414,-6.806898,32.232472,38.612324,37.387804,34.487823,4.334126,4.152385,4.570494,4.375951
3,5.169636,18.361984,10.315825,0.00751,-2.014159,8.036357,4.394193,-28.936085,32.064585,36.302704,34.574281,38.475836,3.835742,4.245982,4.408003,4.49542
4,14.43887,0.003954,12.718195,7.588516,4.944657,-29.042175,6.100142,0.69641,41.33152,41.230927,40.650744,46.630686,4.611304,4.751829,5.221698,4.931275
5,11.627238,23.552683,27.723527,0.0,3.447574,11.522928,15.527971,-80.0,38.692351,40.53961,40.718782,51.640048,4.19585,4.371332,5.099035,5.444118
6,12.63362,22.977482,6.122392,6.987067,3.555213,9.336339,3.092606,1.764556,33.176431,40.459815,42.460533,35.266687,4.736176,4.395675,5.055411,4.182013
7,16.583538,14.069318,22.219807,1.401862,6.410601,6.580141,10.789367,-15.513867,33.884896,43.915548,43.967196,35.471537,4.572604,4.490681,4.21281,4.355609
8,8.223966,8.35228,19.526706,7.496367,0.062056,1.861201,10.22156,2.02172,34.382981,36.776081,39.814354,33.10497,4.247733,3.96893,5.05987,4.01067
9,12.362068,19.337205,6.358176,5.257035,4.51122,8.22451,1.570637,1.119307,33.318047,38.778647,38.164281,35.219505,4.322004,4.068309,4.602257,4.263597
