In [1]:
# 2023 (c) LINE Corporation
# Authors: Robin Scheibler
# MIT License
import argparse
import json
import math
import os
import time
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import yaml
from omegaconf import OmegaConf
from pesq import pesq
from pystoi import stoi

# from sdes.sdes import MixSDE
from datasets import NoisyDataset, WSJ0_mix, musdb_mix
from pl_model import DiffSepModel
import musdb
import museval
import soundfile as sf
import IPython.display as ipd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 16  # number of max_data-long pieces per model call
max_len_s = 3
sr = 16000
max_data = max_len_s * sr  # samples per model input piece
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Experiment root. Override with the DIFFSEP_EXP_DIR environment variable
# instead of editing this machine-specific default.
EXP_DIR = os.environ.get('DIFFSEP_EXP_DIR', '/home/qor6271/Desktop/diff-music-sep/exp')
CKPT_PATH = os.path.join(
    EXP_DIR,
    'musdb18_16000',
    '2023-12-13_22-48-23_experiment-music-separation-16000-ouvekh_model.sde.sigma_max-0.5_model.sde.sigma_min-0.01_model.sde.theta_int_max-3.0_model.sde.theta_min-0.1_model.sde.theta_rho-4',
    'checkpoints',
    'epoch-5999_si_sdr-0.000.ckpt',
)
# model = DiffSepModel.load_from_checkpoint('/home/qor6271/Desktop/diff-music-sep/exp/musdb18_16000/2023-12-11_22-15-21_experiment-music-separation-16000-ouvekh_model.sde.sigma_max-0.5_model.sde.sigma_min-0.01_model.sde.theta_int_max-3.0_model.sde.theta_min-0.1_model.sde.theta_rho-4/checkpoints/epoch-4999_si_sdr-0.000.ckpt')

# map_location lets a GPU-saved checkpoint load on the CPU fallback.
model = DiffSepModel.load_from_checkpoint(CKPT_PATH, map_location=device)
# transfer to the selected device and switch to inference mode
model = model.to(device)
model.eval();  # trailing ';' hides the very long module repr


DiffSepModel(
  (score_model): ScoreModelNCSNpp(
    (backbone): NCSNpp(
      (act): SiLU()
      (output_layer): Conv2d(10, 8, kernel_size=(1, 1), stride=(1, 1))
      (pyramid_upsample): Upsample()
      (pyramid_downsample): Downsample()
      (all_modules): ModuleList(
        (0): GaussianFourierProjection()
        (1): Linear(in_features=128, out_features=256, bias=True)
        (2): Linear(in_features=256, out_features=256, bias=True)
        (3): Conv2d(10, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): ResnetBlockBigGANpp(
          (GroupNorm_0): GroupNorm(16, 64, eps=1e-06, affine=True)
          (Conv_0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (Dense_0): Linear(in_features=256, out_features=64, bias=True)
          (GroupNorm_1): GroupNorm(16, 64, eps=1e-06, affine=True)
          (Dropout_0): Dropout(p=0.0, inplace=False)
          (Conv_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       

# Evaluation on MUSDB18 through the `musdb` package

In [6]:
# Test split of MUSDB18, read through the `musdb` package.
musdb_list = musdb.DB(subsets="test", root='data/musdb18')

In [66]:
# Quick sanity check of the model on the raw MUSDB18 stems (channel 0 only).
# NOTE(review): the 30 s offset below is expressed in 44.1 kHz samples, while
# `max_data` was sized for `sr` = 16 kHz and nothing is resampled here, so each
# piece is shorter than max_len_s seconds of audio. Confirm this is intended.
MUSDB_SR = 44100
START_OFFSET = MUSDB_SR * 30  # skip the first 30 s of every track
MAX_TRACKS = 1  # debug limit; set to None to run over every test track
MAX_CHUNKS = 1  # debug limit; set to None to process each track to its end
chunk_len = max_data * batch_size  # samples sent to the model per call

for idx in range(len(musdb_list)):
    if MAX_TRACKS is not None and idx >= MAX_TRACKS:
        break
    data = musdb_list[idx].stems

    # transpose(0, 1) then [[0]] keeps only channel 0 as a (1, time) tensor
    # (assumes each stem array is (time, channels) -- TODO confirm)
    data = [torch.from_numpy(x).float().transpose(0, 1).to(device)[[0]] for x in data]

    mix_full = data[0].unsqueeze(0)
    tgt_full = torch.cat(data[1:], dim=0).unsqueeze(0)

    # Number of model calls needed to cover the audio *after* the offset.
    n_chunks = math.ceil(max(mix_full.shape[-1] - START_OFFSET, 0) / chunk_len)
    if MAX_CHUNKS is not None:
        n_chunks = min(n_chunks, MAX_CHUNKS)

    est_chunks = []
    for t in range(n_chunks):
        start = START_OFFSET + t * chunk_len
        stop = start + chunk_len
        mix = list(mix_full[:, :, start:stop].split(max_data, dim=2))
        tgt = list(tgt_full[:, :, start:stop].split(max_data, dim=2))
        # zero-pad the last, shorter piece so every batch item is max_data long
        if mix[-1].shape[-1] != max_data:
            mix[-1] = torch.nn.functional.pad(mix[-1], (0, max_data - mix[-1].shape[-1]))
            tgt[-1] = torch.nn.functional.pad(tgt[-1], (0, max_data - tgt[-1].shape[-1]))

        mix = torch.cat(mix, dim=0)
        tgt = torch.cat(tgt, dim=0)
        batch, *stats = model.normalize_batch((mix, tgt))

        mix, target = batch

        est, *_ = model.separate(mix)

        est = model.denormalize_batch(est, *stats)
        # join the batch items back along time
        # (assumes est is (batch, n_src, time) -- TODO confirm)
        est_chunks.append(torch.cat(est.split(1, dim=0), dim=2).squeeze())

    if est_chunks:
        est_full = torch.cat(est_chunks, dim=1)

# Evaluation on the 16 kHz WAV stems loaded with `torchaudio`

In [10]:
def separate(filename, device='cuda', start_s=20, end_s=60, rate=16000):
    """Separate a stereo excerpt of one track, processing each channel independently.

    Loads ``mixture.wav`` and the four stem WAVs from ``filename`` with
    torchaudio, keeps the ``[start_s, end_s)`` excerpt, and runs the
    module-level ``model`` (``sampler='ddim'``) on each channel.

    Args:
        filename: directory containing mixture.wav, drums.wav, bass.wav,
            other.wav and vocals.wav.
        device: device the model input batches are moved to.
        start_s, end_s: excerpt bounds in seconds (``end_s=None`` keeps
            everything after ``start_s``).
        rate: samples per second used to convert the bounds to sample
            indices. The WAVs are not resampled, so this must match the files.

    Returns:
        ``(est, tgt, mix)``, each a 2-element list ``[channel0, channel1]``:
        separated sources, reference stems (drums, bass, other, vocals) and
        the (1, time) mixture for that channel.
    """
    chunk_len = max_data * batch_size  # samples per model call

    def separate_one_channel(mix):
        mix_full = mix.unsqueeze(0)  # (1, time) -> (1, 1, time)
        n_samples = mix_full.shape[-1]
        # ceil instead of floor + 1: no extra all-padding call when the length
        # is an exact multiple of chunk_len
        n_chunks = max(1, math.ceil(n_samples / chunk_len))

        est_chunks = []
        for t in range(n_chunks):
            chunk = mix_full[:, :, t * chunk_len:(t + 1) * chunk_len]
            pieces = list(chunk.split(max_data, dim=2))
            # zero-pad the last, shorter piece to max_data samples
            if pieces[-1].shape[-1] != max_data:
                pieces[-1] = torch.nn.functional.pad(pieces[-1], (0, max_data - pieces[-1].shape[-1]))

            batch = torch.cat(pieces, dim=0).to(device)
            (batch, _), *stats = model.normalize_batch((batch, None))
            est = model.separate(batch, sampler='ddim')
            est = model.denormalize_batch(est, *stats)
            # join the batch items back along time and move the result to CPU
            est_chunks.append(torch.cat(est.split(1, dim=0), dim=2).squeeze().detach().cpu())

        # concatenate once, then drop the zero padding added at the end
        return torch.cat(est_chunks, dim=1)[:, :n_samples]

    stems = ['mixture.wav', 'drums.wav', 'bass.wav', 'other.wav', 'vocals.wav']
    # Rows alternate channel 0 / channel 1 of each file (assumes stereo WAVs).
    data = torch.cat([torchaudio.load(f'{filename}/{stem}')[0] for stem in stems], dim=0)
    start = int(start_s * rate)
    stop = None if end_s is None else int(end_s * rate)
    data0 = data[::2, start:stop]   # channel 0 of every file
    data1 = data[1::2, start:stop]  # channel 1 of every file

    mix = [data0[[0]], data1[[0]]]
    est = [separate_one_channel(mix[0]), separate_one_channel(mix[1])]
    tgt = [data0[1:], data1[1:]]

    return est, tgt, mix

In [6]:
class Audio():
    """Minimal stand-in for a musdb audio target: only holds ``.audio``."""

    def __init__(self, audio):
        self.audio = audio

class Track():
    """Minimal stand-in for a musdb track, built for ``museval.eval_mus_track``.

    Exposes ``targets`` (name -> Audio), ``rate``, ``name`` and ``subset``.
    NOTE(review): museval appears to read only these attributes; confirm
    against the installed museval version.

    Args:
        name: track name (museval uses it to name its JSON output).
        mixture, drums, bass, other, vocals: audio arrays, each wrapped in Audio.
        subset: dataset subset label, default 'test'.
        rate: sample rate of the arrays, default 16000.
    """

    def __init__(self, name, mixture, drums, bass, other, vocals, subset='test', rate=16000):
        self.targets = {
            "mixture": Audio(mixture),
            "drums": Audio(drums),
            "bass": Audio(bass),
            "other": Audio(other),
            "vocals": Audio(vocals),
        }
        self.rate = rate
        self.name = name
        self.subset = subset
      
def calculate_sdr(name, est, tgt, mix, output_dir="eval_sdr"):
    """Score separated stems of a stereo excerpt with ``museval.eval_mus_track``.

    Each estimate is rescaled to the reference's overall standard deviation
    before scoring.

    Args:
        name: track name handed to museval (used for its JSON output file).
        est, tgt: 2-element lists ``[channel0, channel1]`` of per-source
            tensors, rows in the order drums, bass, other, vocals (the
            layout returned by ``separate``).
        mix: 2-element list ``[channel0, channel1]`` of (1, time) mixtures.
        output_dir: directory passed to ``museval.eval_mus_track``.

    Returns:
        The score object returned by ``museval.eval_mus_track``.
    """
    inst = ["drums", "bass", "other", "vocals"]

    def to_stereo(per_channel, row):
        # stack row `row` of both channels -> (time, 2) numpy array
        return torch.cat([per_channel[0][[row]], per_channel[1][[row]]], dim=0).transpose(0, 1).numpy()

    # Track wraps every target in Audio itself, so pass plain arrays only.
    track_list = [to_stereo(mix, 0)]
    pred_dict = {}
    for i, source in enumerate(inst):
        tgt_ = to_stereo(tgt, i)
        est_ = to_stereo(est, i)
        # Match the estimate's scale to the reference. Leave a silent
        # estimate untouched instead of turning it into NaNs via 0/0.
        est_std = np.std(est_)
        if est_std > 0:
            est_ = est_ / est_std * np.std(tgt_)
        pred_dict[source] = est_
        track_list.append(tgt_)

    track = Track(name, *track_list)

    return museval.eval_mus_track(track, pred_dict, output_dir=output_dir)

In [11]:
# Test track to evaluate. Set `song = None` to draw a random track from `root`
# instead; the RNG is seeded so re-running the notebook picks the same one.
root = 'data/musdb18_16000/test/'
song = 'Carlos Gonzalez - A Place For Us'
if song is None:
    song = str(np.random.default_rng(0).choice(sorted(os.listdir(root))))
filename = os.path.join(root, song)


        

In [12]:
# Separate the selected track and print its museval scores; nothing runs
# when the track directory is missing.
track_dir_exists = os.path.isdir(filename)
if track_dir_exists:
    separated = separate(filename)
    est, tgt, mix = separated
    scores = calculate_sdr('song', *separated)

    print(scores)

drums           ==> SDR: -14.935  SIR: -16.569  ISR:   7.413  SAR:  -0.483  
bass            ==> SDR:  -2.192  SIR:  -3.338  ISR:   4.650  SAR:  -0.106  
other           ==> SDR:   1.694  SIR:   2.978  ISR:   8.379  SAR:   2.893  
vocals          ==> SDR:   1.781  SIR:   3.758  ISR:   8.321  SAR:   1.752  



In [28]:
models_2 = ['mixsde_gamma_2','ouvekh_rho_0_gamma_2','ouvekh_rho_4_gamma_2']
models_3 = ['mixsde_gamma_3','ouvekh_rho_0_gamma_3','ouvekh_rho_4_gamma_3']
inst = ["drums","bass","other","vocals"]

# Same excerpt for every model so the scores are comparable.
filename = 'data/musdb18_16000/test/Carlos Gonzalez - A Place For Us'
# sf.write does not create directories; make sure the output folder exists.
os.makedirs('sample', exist_ok=True)

for model_ in models_2:
    # `separate` reads the module-level `model`, so rebind it here.
    model = DiffSepModel.load_from_checkpoint(f'exp/musdb18_16000/{model_}/checkpoints/epoch-1999_si_sdr-0.000.ckpt')
    model = model.to(device)
    model.eval()
    # Re-seed before sampling (presumably stochastic -- confirm) so each
    # model sees the same random draws and reruns reproduce the scores.
    torch.manual_seed(0)
    est, tgt, mix = separate(filename)
    scores = calculate_sdr('song', est, tgt, mix)

    print(scores)
    for i, source in enumerate(inst):
        # (2, time) -> (time, 2): one column per channel for soundfile
        stereo = torch.cat([est[0][[i]], est[1][[i]]], dim=0).transpose(0, 1).numpy()
        sf.write(f'sample/{model_}_{source}_ddim.wav', stereo, sr)

drums           ==> SDR:  -0.720  SIR:  -1.707  ISR:   5.702  SAR:   0.629  
bass            ==> SDR:  -3.647  SIR: -11.063  ISR:   2.525  SAR:   0.747  
other           ==> SDR:   0.818  SIR:   0.560  ISR:   7.027  SAR:   2.441  
vocals          ==> SDR:  -0.118  SIR:   1.356  ISR:   6.281  SAR:   0.996  

drums           ==> SDR:  -0.673  SIR:  -1.086  ISR:   5.944  SAR:   0.087  
bass            ==> SDR:  -3.851  SIR: -10.544  ISR:   2.644  SAR:   0.344  
other           ==> SDR:   1.185  SIR:   1.010  ISR:   7.213  SAR:   2.149  
vocals          ==> SDR:   0.311  SIR:   1.973  ISR:   6.489  SAR:   0.727  

drums           ==> SDR:  -0.335  SIR:   0.255  ISR:   6.706  SAR:   0.112  
bass            ==> SDR:  -3.493  SIR:  -9.467  ISR:   2.866  SAR:   0.013  
other           ==> SDR:   1.506  SIR:   1.142  ISR:   7.645  SAR:   2.817  
vocals          ==> SDR:   0.668  SIR:   2.203  ISR:   6.969  SAR:   1.107  



In [29]:
inst = ["drums","bass","other","vocals"]

# Same excerpt for every model so the scores are comparable.
filename = 'data/musdb18_16000/test/Carlos Gonzalez - A Place For Us'
# sf.write does not create directories; make sure the output folder exists.
os.makedirs('sample', exist_ok=True)

for model_ in models_3:
    # `separate` reads the module-level `model`, so rebind it here.
    model = DiffSepModel.load_from_checkpoint(f'exp/musdb18_16000/{model_}/checkpoints/epoch-1999_si_sdr-0.000.ckpt')
    model = model.to(device)
    model.eval()
    if os.path.isdir(filename):
        # Re-seed before sampling (presumably stochastic -- confirm) so each
        # model sees the same random draws and reruns reproduce the scores.
        torch.manual_seed(0)
        est, tgt, mix = separate(filename)
        scores = calculate_sdr('song', est, tgt, mix)

        print(scores)
        for i, source in enumerate(inst):
            # (2, time) -> (time, 2): one column per channel for soundfile
            stereo = torch.cat([est[0][[i]], est[1][[i]]], dim=0).transpose(0, 1).numpy()
            sf.write(f'sample/{model_}_{source}_ddim.wav', stereo, sr)

drums           ==> SDR:  -0.550  SIR:   0.353  ISR:   6.390  SAR:  -0.495  
bass            ==> SDR:  -3.505  SIR:  -9.671  ISR:   2.802  SAR:   0.129  
other           ==> SDR:   0.743  SIR:   0.852  ISR:   6.916  SAR:   1.685  
vocals          ==> SDR:   0.018  SIR:   1.743  ISR:   6.815  SAR:   0.818  

drums           ==> SDR:  -0.169  SIR:   1.083  ISR:   6.615  SAR:  -0.417  
bass            ==> SDR:  -3.395  SIR:  -9.129  ISR:   2.881  SAR:   0.006  
other           ==> SDR:   1.077  SIR:   1.248  ISR:   7.067  SAR:   1.564  
vocals          ==> SDR:   0.466  SIR:   2.344  ISR:   6.896  SAR:   0.522  

drums           ==> SDR:  -0.039  SIR:   1.280  ISR:   6.846  SAR:  -0.616  
bass            ==> SDR:  -3.627  SIR:  -8.725  ISR:   2.801  SAR:  -0.886  
other           ==> SDR:   1.164  SIR:   1.700  ISR:   7.262  SAR:   1.543  
vocals          ==> SDR:   0.620  SIR:   2.401  ISR:   6.933  SAR:   0.331  



In [None]:
drums           ==> SDR:  -1.051  SIR:  -2.930  ISR:   4.316  SAR:  -0.486  
bass            ==> SDR:   0.174  SIR:   0.005  ISR:   6.112  SAR:   3.556  
other           ==> SDR:   2.551  SIR:   3.346  ISR:   9.955  SAR:   6.315  
vocals          ==> SDR:  -0.594  SIR:   1.210  ISR:   2.119  SAR:  -0.338  

In [None]:
# Separated sources for channel 0: one waveform panel and audio player per stem.
stem_names = ["drums", "bass", "other", "vocals"]
f, ax = plt.subplots(4, 1, sharex=True, figsize=(10, 8))
for i, stem in enumerate(stem_names):
    ax[i].plot(est[0][i])
    ax[i].set_title(f'estimate: {stem} (channel 0)')
    ipd.display(ipd.Audio(est[0][i], rate=16000))
ax[-1].set_xlabel('sample')
f.tight_layout()

In [None]:
# Reference stems for channel 0: one waveform panel and audio player per stem.
stem_names = ["drums", "bass", "other", "vocals"]
f, ax = plt.subplots(4, 1, sharex=True, figsize=(10, 8))
for i, stem in enumerate(stem_names):
    ax[i].plot(tgt[0][i])
    ax[i].set_title(f'reference: {stem} (channel 0)')
    ipd.display(ipd.Audio(tgt[0][i], rate=16000))
ax[-1].set_xlabel('sample')
f.tight_layout()

In [9]:
import sdes

In [10]:
def separate(filename, device='cuda', start_s=20, end_s=40, rate=16000):
    """Separate a stereo excerpt using the target-informed DDIM sampler.

    NOTE(review): this definition shadows the earlier ``separate``. Unlike
    that one, it passes the ground-truth stems (``tgt``) and a random
    ``noise_true`` to ``sdes.get_ddim_sampler``. If the sampler uses them
    (TODO confirm), the scores are an oracle check, not blind separation.

    Args:
        filename: directory containing mixture.wav, drums.wav, bass.wav,
            other.wav and vocals.wav.
        device: device the model input batches are moved to.
        start_s, end_s: excerpt bounds in seconds (``end_s=None`` keeps
            everything after ``start_s``).
        rate: samples per second used to convert the bounds to sample
            indices. The WAVs are not resampled, so this must match the files.

    Returns:
        ``(est, tgt, mix)``, each a 2-element list ``[channel0, channel1]``:
        separated sources, reference stems (drums, bass, other, vocals) and
        the (1, time) mixture for that channel.
    """
    chunk_len = max_data * batch_size  # samples per model call

    def pad_last(pieces):
        # zero-pad the last, shorter piece to max_data samples (in place)
        if pieces[-1].shape[-1] != max_data:
            pieces[-1] = torch.nn.functional.pad(pieces[-1], (0, max_data - pieces[-1].shape[-1]))

    def separate_one_channel(mix, tgt):
        mix_full = mix.unsqueeze(0)
        tgt_full = tgt.unsqueeze(0)
        n_samples = mix_full.shape[-1]
        # ceil instead of floor + 1: no extra all-padding call when the length
        # is an exact multiple of chunk_len
        n_chunks = max(1, math.ceil(n_samples / chunk_len))

        est_chunks = []
        for t in range(n_chunks):
            window = slice(t * chunk_len, (t + 1) * chunk_len)
            mix_pieces = list(mix_full[:, :, window].split(max_data, dim=2))
            tgt_pieces = list(tgt_full[:, :, window].split(max_data, dim=2))
            if mix_pieces[-1].shape[-1] != max_data:
                pad_last(mix_pieces)
                pad_last(tgt_pieces)

            mix_b = torch.cat(mix_pieces, dim=0).to(device)
            tgt_b = torch.cat(tgt_pieces, dim=0).to(device)
            (mix_b, tgt_b), *stats = model.normalize_batch((mix_b, tgt_b))

            # blind alternative: sdes.get_ddim_sampler(model.sde, model, mix_b)
            sampler = sdes.get_ddim_sampler(model.sde, model, mix_b, noise_true=torch.randn_like(tgt_b), tgt=tgt_b)

            est, *_ = sampler()
            est = model.denormalize_batch(est, *stats)
            # join the batch items back along time and move the result to CPU
            est_chunks.append(torch.cat(est.split(1, dim=0), dim=2).squeeze().detach().cpu())

        # concatenate once, then drop the zero padding added at the end
        return torch.cat(est_chunks, dim=1)[:, :n_samples]

    stems = ['mixture.wav', 'drums.wav', 'bass.wav', 'other.wav', 'vocals.wav']
    # Rows alternate channel 0 / channel 1 of each file (assumes stereo WAVs).
    data = torch.cat([torchaudio.load(f'{filename}/{stem}')[0] for stem in stems], dim=0)
    start = int(start_s * rate)
    stop = None if end_s is None else int(end_s * rate)
    data0 = data[::2, start:stop]   # channel 0 of every file
    data1 = data[1::2, start:stop]  # channel 1 of every file

    mix = [data0[[0]], data1[[0]]]
    tgt = [data0[1:], data1[1:]]
    est = [separate_one_channel(mix[0], tgt[0]), separate_one_channel(mix[1], tgt[1])]

    return est, tgt, mix

In [11]:
# root = 'data/musdb18_16000/test/'
# # for song in os.listdir(root):
# song = np.random.choice(os.listdir(root))
# filename = f'{root}/{song}'
# filename = 'data/musdb18_16000/test/Carlos Gonzalez - A Place For Us'

# Separate the selected track with the current `separate` and print its
# museval scores; nothing runs when the track directory is missing.
track_dir_exists = os.path.isdir(filename)
if track_dir_exists:
    separated = separate(filename)
    est, tgt, mix = separated
    scores = calculate_sdr('song', *separated)

    print(scores)

drums           ==> SDR:   1.608  SIR:   6.953  ISR:   7.706  SAR:   0.115  
bass            ==> SDR:   2.919  SIR:   6.918  ISR:   8.531  SAR:   3.186  
other           ==> SDR:   5.816  SIR:   8.821  ISR:  14.315  SAR:   7.101  
vocals          ==> SDR:   1.611  SIR:   7.441  ISR:   7.241  SAR:   1.317  



In [None]:
# Audio players for the four separated sources of channel 0, in stem order.
for stem_idx in range(4):
    display(ipd.Audio(est[0][stem_idx], rate=16000))

In [None]:
# Per-target SDR: NaN-ignoring median over museval's evaluation frames, in the
# order museval lists the targets. Sized from the data instead of fixed at 4.
metric = 'SDR'
targets = scores.scores["targets"]
sdr_scores = [0] * len(targets)
for i, t in enumerate(targets):
    sdr_scores[i] += np.nanmedian([float(f["metrics"][metric]) for f in t["frames"]])