In [None]:
!python3 -m pip install -U git+https://github.com/facebookresearch/demucs#egg=demucs

In [1]:
import torch
import numpy as np
import torchaudio
import tempfile
from cdpam import CDPAM
from demucs.separate import Separator
from IPython.display import Audio
from datasets import load_dataset, Dataset
from walloc import walloc
from spauq.core.metrics import spauq_eval
from fastprogress.fastprogress import progress_bar
class Config: pass

In [2]:
device = "cuda"
separator = Separator()
cdpam_loss = CDPAM()
MUSDB = load_dataset("danjacobellis/musdb18HQ", split='validation')

  state = torch.load(modfolder,map_location="cpu")['state']


Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

In [3]:
def pad(audio, p=2**16):
    B,C,L = audio.shape
    padding_size = (p - (L % p)) % p
    if padding_size > 0:
        audio = torch.nn.functional.pad(audio, (0, padding_size), mode='constant', value=0)
    return audio

In [4]:
max_duration=44100*180
SSDR = []
SRDR = []
PSNR = []
CDPAM_dB = []
res_red = torchaudio.transforms.Resample(44100, 16536).to(device)
res_inc = torchaudio.transforms.Resample(16536, 44100).to(device)
with torch.no_grad():
    for i_sample in progress_bar(range(0, len(MUSDB), 5)):
        SSDR.append([])
        SRDR.append([])
        PSNR.append([])
        CDPAM_dB.append([])
        y = []
        for i_instr in range(5):
            sample = MUSDB[i_sample+i_instr]
            instr = sample['instrument']
            x, fs = torchaudio.load(sample['audio']['bytes'])
            x = x[:,:max_duration]
            L = x.shape[-1]
            x_padded = pad(x.unsqueeze(0), 2**16).to(device)
            if i_instr == 0:
                mix = res_inc(res_red(x_padded[0]))
                pred = separator.separate_tensor(mix)
            else:
                ℓ = min(x_padded.shape[-1], pred[1][instr].shape[-1])
                SDR = spauq_eval(
                    reference=x_padded[0,:,:ℓ].to("cpu"),
                    estimate=pred[1][instr][:,:ℓ].to("cpu"),
                    fs = fs
                )
                psnr = 20*np.log10(2) - 10*np.log10(torch.nn.functional.mse_loss(
                    x_padded[0,:,:ℓ],
                    pred[1][instr][:,:ℓ]
                ).item())
                cdpam = cdpam_loss.forward(
                    wav_in=x_padded[0,:,:ℓ],
                    wav_out=pred[1][instr][:,:ℓ]
                )
                SSDR[-1].append(SDR['SSR'])
                SRDR[-1].append(SDR['SRR'])
                PSNR[-1].append(psnr)
                CDPAM_dB[-1].append(-np.log10(cdpam.mean().item()))
SSDR = torch.tensor(SSDR)
SRDR = torch.tensor(SRDR)
PSNR = torch.tensor(PSNR)
CDPAM_dB = torch.tensor(CDPAM_dB)



In [5]:
metrics = {
    'SSDR': SSDR,
    'SRDR': SRDR,
    'PSNR': PSNR,
    'CDPAM_dB': CDPAM_dB,
}

instruments = ['other', 'drums', 'bass', 'vocals']

data = {}
for metric_name, metric_tensor in metrics.items():
    for i, instr in enumerate(instruments):
        column_name = f"{metric_name}_{instr}"
        data[column_name] = metric_tensor[:, i].tolist()
dataset = Dataset.from_dict(data)

In [8]:
dataset.push_to_hub("danjacobellis/LSDIR_demucs_2xRR",split='validation')

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/danjacobellis/LSDIR_demucs_2xRR/commit/1c7afb2b6422a7ee6a2d9e8ade801ee38fe44f52', commit_message='Upload dataset', commit_description='', oid='1c7afb2b6422a7ee6a2d9e8ade801ee38fe44f52', pr_url=None, pr_revision=None, pr_num=None)

In [6]:
dataset.to_pandas().mean(axis=0)

SSDR_other         13.883515
SSDR_drums         19.460045
SSDR_bass          16.062864
SSDR_vocals        12.708179
SRDR_other          4.994606
SRDR_drums          8.426462
SRDR_bass           0.655329
SRDR_vocals        -1.130510
PSNR_other         37.345101
PSNR_drums         39.134829
PSNR_bass          41.124241
PSNR_vocals        40.735187
CDPAM_dB_other      4.693773
CDPAM_dB_drums      4.460550
CDPAM_dB_bass       5.002654
CDPAM_dB_vocals     4.772007
dtype: float64

In [7]:
dataset.to_pandas()

Unnamed: 0,SSDR_other,SSDR_drums,SSDR_bass,SSDR_vocals,SRDR_other,SRDR_drums,SRDR_bass,SRDR_vocals,PSNR_other,PSNR_drums,PSNR_bass,PSNR_vocals,CDPAM_dB_other,CDPAM_dB_drums,CDPAM_dB_bass,CDPAM_dB_vocals
0,12.179016,29.05174,22.537349,8.820852,5.11629,12.705761,10.571674,3.382954,38.349524,40.726144,41.361966,40.802299,3.121012,4.292292,4.349785,4.305566
1,15.7656,15.524216,21.67002,18.33288,7.126939,5.981851,10.498622,8.600775,36.932263,38.111488,43.228338,39.429012,4.888429,4.012938,5.408183,4.596148
2,18.795009,12.882469,11.079025,7.363884,7.699246,4.870228,5.297716,0.193858,35.165179,40.497086,38.38121,38.361912,4.761102,4.475166,4.587887,4.92555
3,12.736026,20.373618,12.704955,0.028049,5.618602,8.992358,5.17739,-24.050797,35.900309,37.084709,35.128201,40.611965,4.433118,4.435271,4.843884,4.939514
4,17.252186,6.456883,13.602947,8.007977,6.76237,1.040346,6.811704,4.398951,42.925821,47.066172,43.121652,49.394346,5.58977,5.361248,5.243661,5.372389
5,13.608643,27.648416,27.884736,0.0,5.293326,13.962954,16.422733,-80.0,40.30318,42.814516,41.70897,56.489552,4.458266,4.76727,5.083093,5.764425
6,17.797342,25.588447,9.612875,12.449044,6.57212,10.800437,4.082094,5.209088,35.811663,41.71601,43.283742,37.020934,5.152307,4.542123,5.198038,4.475703
7,26.805629,13.280812,22.895593,9.005526,11.694765,6.623464,10.920478,-0.653791,37.870708,44.593784,44.217242,38.677367,5.497861,4.582979,4.184532,4.815382
8,8.740922,12.094929,22.299714,16.560555,2.282385,4.462507,11.807274,7.374029,35.742848,38.981929,41.654942,36.004199,4.525406,4.344902,5.248436,4.49189
9,15.845414,20.982093,7.130796,9.821832,6.80872,9.562315,2.103353,3.652058,35.6356,39.941055,38.888905,37.293479,4.725924,4.223544,4.675255,4.697709
