In [1]:
# 2023 (c) LINE Corporation
# Authors: Robin Scheibler
# MIT License
import argparse
import json
import math
import os
import time
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import yaml
from omegaconf import OmegaConf
from pesq import pesq
from pystoi import stoi

# from sdes.sdes import MixSDE
from datasets import NoisyDataset, WSJ0_mix, musdb_mix
from pl_model import DiffSepModel
import musdb
import museval
import soundfile as sf
import IPython.display as ipd

  from .autonotebook import tqdm as notebook_tqdm


In [27]:
# Separation settings shared by the cells below.
batch_size = 16           # number of max_data-sample windows stacked into one model call
max_len_s = 3             # window length in seconds
sr = 16000                # sample rate (matches the musdb18_16000 data / checkpoint dirs)
max_data = max_len_s*sr   # window length in samples
# Fall back to CPU so the notebook still loads on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Checkpoint to evaluate. Set the DIFFSEP_CKPT environment variable to use another
# checkpoint instead of editing this machine-specific absolute path.
ckpt_path = os.environ.get(
    'DIFFSEP_CKPT',
    '/home/qor6271/Desktop/diff-music-sep/exp/musdb18_16000/mixsde_score_fn/checkpoints/epoch-5799_si_sdr-0.000.ckpt',
)
# map_location restores the weights straight onto `device` (required when a
# checkpoint saved on GPU is opened on a CPU-only machine).
model = DiffSepModel.load_from_checkpoint(ckpt_path, map_location=device)
# transfer to the selected device and switch to inference mode
model = model.to(device)
model.eval()


DiffSepModel(
  (score_model): ScoreModelNCSNpp(
    (backbone): NCSNpp(
      (act): SiLU()
      (output_layer): Conv2d(10, 8, kernel_size=(1, 1), stride=(1, 1))
      (pyramid_upsample): Upsample()
      (pyramid_downsample): Downsample()
      (all_modules): ModuleList(
        (0): GaussianFourierProjection()
        (1): Linear(in_features=128, out_features=256, bias=True)
        (2): Linear(in_features=256, out_features=256, bias=True)
        (3): Conv2d(10, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): ResnetBlockBigGANpp(
          (GroupNorm_0): GroupNorm(16, 64, eps=1e-06, affine=True)
          (Conv_0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (Dense_0): Linear(in_features=256, out_features=64, bias=True)
          (GroupNorm_1): GroupNorm(16, 64, eps=1e-06, affine=True)
          (Dropout_0): Dropout(p=0.0, inplace=False)
          (Conv_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       

# Separation of MUSDB18 test tracks loaded with the `musdb` package

In [6]:
# Test split of MUSDB18, read through the `musdb` package.
musdb_list = musdb.DB(subsets="test", root='data/musdb18')

In [66]:
# Separate MUSDB18 test tracks (first channel only). Each model call gets a batch
# of `batch_size` windows of `max_data` samples, starting OFFSET_S seconds into
# the track. MAX_TRACKS / MAX_CHUNKS replace the two debugging `break`s of the
# previous version of this cell; the defaults (1, 1) keep its behaviour of
# separating only the first chunk of the first track. Use None for a full run.
MAX_TRACKS = 1
MAX_CHUNKS = 1
OFFSET_S = 30
chunk_len = max_data * batch_size  # samples handed to the model per call

for idx in range(len(musdb_list)):
    if MAX_TRACKS is not None and idx >= MAX_TRACKS:
        break
    track = musdb_list[idx]

    # Each stem -> first channel, shape (1, n). The stems come at the track's
    # native rate (track.rate) while the windows above are defined at `sr`, so
    # resample before chunking.
    stems = [
        torchaudio.functional.resample(
            torch.from_numpy(stem).float().transpose(0, 1).to(device)[[0]],
            orig_freq=track.rate,
            new_freq=sr,
        )
        for stem in track.stems
    ]
    mix_full = stems[0].unsqueeze(0)                     # (1, 1, n)
    tgt_full = torch.cat(stems[1:], dim=0).unsqueeze(0)  # (1, n_stems - 1, n)

    offset = OFFSET_S * sr
    n_valid = max(mix_full.shape[-1] - offset, 0)
    # ceil division: never schedules a chunk that starts past the end of the track
    n_chunks = math.ceil(n_valid / chunk_len)
    if MAX_CHUNKS is not None:
        n_chunks = min(n_chunks, MAX_CHUNKS)

    est_chunks = []
    for chunk_idx in range(n_chunks):
        start = offset + chunk_idx * chunk_len
        mix_windows = list(mix_full[:, :, start:start + chunk_len].split(max_data, dim=2))
        tgt_windows = list(tgt_full[:, :, start:start + chunk_len].split(max_data, dim=2))
        # zero-pad the trailing short window to a full window
        pad = max_data - mix_windows[-1].shape[-1]
        if pad > 0:
            mix_windows[-1] = torch.nn.functional.pad(mix_windows[-1], (0, pad))
            tgt_windows[-1] = torch.nn.functional.pad(tgt_windows[-1], (0, pad))

        batch, *stats = model.normalize_batch(
            (torch.cat(mix_windows, dim=0), torch.cat(tgt_windows, dim=0))
        )
        mix_batch, _ = batch

        # NOTE(review): `separate()` in the torchaudio section uses model.separate's
        # return value directly as a tensor, whereas this cell unpacks it as a
        # tuple -- only one of the two can match the model; confirm.
        with torch.no_grad():  # inference only: no autograd graph needed
            est_chunk, *_ = model.separate(mix_batch)
            est_chunk = model.denormalize_batch(est_chunk, *stats)
        # stitch windows back along time; presumably (B, n_src, L) -> (n_src, B * L)
        est_chunk = torch.cat(est_chunk.split(1, dim=0), dim=2).squeeze()
        est_chunks.append(est_chunk)

    # Concatenate once and drop the zero padding so est_full lines up with
    # mix_full[..., offset:]. With several tracks, est_full holds the last one.
    est_full = torch.cat(est_chunks, dim=1)[:, :n_valid] if est_chunks else None

# Separation and BSSEval scoring of 16 kHz test songs loaded with `torchaudio`

In [44]:
def separate(filename, device='cuda', start_s=30, stop_s=40, sr=16000):
    """Separate a stereo excerpt of one song with the global ``model``, one channel at a time.

    Loads ``mixture.wav``, ``drums.wav``, ``bass.wav``, ``other.wav`` and
    ``vocals.wav`` from the directory ``filename``, keeps samples
    ``[start_s * sr, stop_s * sr)`` and separates the two mixture channels
    independently. Relies on the notebook globals ``model``, ``max_data``
    (window length in samples) and ``batch_size`` (windows per model call).

    Args:
        filename: song directory containing the five wav files.
        device: device each batch of mixture windows is moved to before the model call.
        start_s, stop_s: excerpt boundaries in seconds; the defaults reproduce the
            previously hard-coded 30 s - 40 s excerpt.
        sr: rate used to turn the boundaries into sample indices (assumed to be
            the sample rate of the wav files).

    Returns:
        ``(est, tgt, mix)``, each a two-element list ``[channel 0, channel 1]``:
        ``est[c]`` is the model output for channel ``c`` trimmed to the excerpt
        length, ``tgt[c]`` the reference stems (drums, bass, other, vocals, one
        per row) and ``mix[c]`` the ``(1, n)`` mixture.

    Raises:
        ValueError: if the excerpt is empty (e.g. the song is shorter than ``start_s``).
    """
    chunk_len = max_data * batch_size  # samples handed to the model per call

    def separate_one_channel(mix):
        # (1, n) mixture -> per-source estimate, computed batch_size windows at a time
        mix_full = mix.unsqueeze(0)  # (1, 1, n)
        n_samples = mix_full.shape[-1]
        if n_samples == 0:
            raise ValueError(f'empty excerpt [{start_s}s, {stop_s}s) for {filename}')

        est_chunks = []
        # Step over the signal directly: the former `n // chunk_len + 1` loop ran the
        # model on an all-padding chunk whenever n was a multiple of chunk_len.
        for start in range(0, n_samples, chunk_len):
            windows = list(mix_full[:, :, start:start + chunk_len].split(max_data, dim=2))
            pad = max_data - windows[-1].shape[-1]
            if pad > 0:
                # zero-pad the trailing short window to a full window
                windows[-1] = torch.nn.functional.pad(windows[-1], (0, pad))

            batch = torch.cat(windows, dim=0).to(device)
            # NOTE(review): unlike the musdb cell, no normalize_batch/denormalize_batch
            # is applied here; calculate_sdr rescales the estimates by std instead.
            with torch.no_grad():  # inference only: don't build an autograd graph
                est = model.separate(batch)
            # stitch windows back along time; presumably (B, n_src, L) -> (n_src, B * L)
            est = torch.cat(est.split(1, dim=0), dim=2).squeeze().detach().cpu()
            est_chunks.append(est)

        # concatenate once (not inside the loop) and drop the padding
        return torch.cat(est_chunks, dim=1)[:, :n_samples]

    stem_files = ['mixture.wav', 'drums.wav', 'bass.wav', 'other.wav', 'vocals.wav']
    data = torch.cat(
        [torchaudio.load(f'{filename}/{inst}')[0] for inst in stem_files], dim=0
    )
    # Rows alternate channel 0 / channel 1 per file (assumes every file is stereo):
    # even rows are the first channels, odd rows the second, in `stem_files` order.
    excerpt = slice(int(start_s * sr), int(stop_s * sr))
    data0 = data[::2, excerpt]
    data1 = data[1::2, excerpt]

    mix = [data0[[0]], data1[[0]]]
    est = [separate_one_channel(mix[0]), separate_one_channel(mix[1])]
    tgt = [data0[1:], data1[1:]]

    return est, tgt, mix

In [38]:
class Audio():
    """Minimal wrapper exposing a signal as ``.audio`` (stand-in for a musdb target)."""

    def __init__(self, audio):
        self.audio = audio


class Track():
    """Minimal duck-typed track object handed to ``museval.eval_mus_track``.

    Presumably mirrors the attributes museval reads from a musdb track:
    ``targets`` (name -> Audio), ``rate``, ``name`` and ``subset``.

    Args:
        name: track name.
        mixture, drums, bass, other, vocals: signals, each wrapped in ``Audio``.
        subset: subset label stored on the track (default ``'test'``).
        rate: sample rate of the signals in Hz. The default 16000 keeps the
            previous hard-coded value; pass the real rate for other data.
    """

    def __init__(self, name, mixture, drums, bass, other, vocals, subset='test', rate=16000):
        self.targets = {
            "mixture": Audio(mixture),
            "drums": Audio(drums),
            "bass": Audio(bass),
            "other": Audio(other),
            "vocals": Audio(vocals),
        }
        self.rate = rate
        self.name = name
        self.subset = subset
      
def calculate_sdr(name, est, tgt, mix, output_dir="eval_sdr"):
    """Score a two-channel separation result with ``museval.eval_mus_track``.

    Args:
        name: track name stored on the ``Track`` passed to museval.
        est, tgt: two-element lists ``[channel 0, channel 1]`` of CPU tensors whose
            rows are the stems in drums, bass, other, vocals order (as returned
            by ``separate``).
        mix: two-element list ``[channel 0, channel 1]`` of ``(1, n)`` mixture tensors.
        output_dir: forwarded to ``museval.eval_mus_track`` (default ``"eval_sdr"``).

    Returns:
        Whatever ``museval.eval_mus_track`` returns (its printed form is the
        per-target SDR/SIR/ISR/SAR table).
    """
    inst = ["drums", "bass", "other", "vocals"]

    def to_stereo(ch0, ch1):
        # two (1, n) tensors -> (n, 2) numpy array (samples x channels)
        return torch.cat([ch0, ch1], dim=0).transpose(0, 1).numpy()

    mixture = to_stereo(mix[0], mix[1])
    references = []
    pred_dict = {}
    for i, key in enumerate(inst):
        tgt_ = to_stereo(tgt[0][[i]], tgt[1][[i]])
        est_ = to_stereo(est[0][[i]], est[1][[i]])
        # Match the estimate's overall std to the reference's. NOTE(review):
        # presumably compensates for `separate` not denormalizing the model output.
        # Skip the rescale for a silent estimate instead of producing NaNs (0 / 0).
        est_std = np.std(est_)
        if est_std > 0:
            est_ = est_ / est_std * np.std(tgt_)
        pred_dict[key] = est_
        references.append(tgt_)

    # Track wraps every signal in Audio itself, so pass raw arrays only
    # (the mixture used to be wrapped twice).
    track = Track(name, mixture, *references)

    score = museval.eval_mus_track(track, pred_dict, output_dir=output_dir)
    return score

In [32]:
root = 'data/musdb18_16000/test/'
# Seeded pick from the *sorted* listing: os.listdir order is arbitrary, so an
# unseeded choice from it cannot be reproduced on Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
song = rng.choice(sorted(os.listdir(root)))
filename = os.path.join(root, song)
# To evaluate a specific song instead:
# filename = 'data/musdb18_16000/test/Carlos Gonzalez - A Place For Us'

    
        

In [39]:
if os.path.isdir(filename):
    est, tgt, mix = separate(filename)
    # Use the real song name rather than the literal 'song' so results for
    # different songs are distinguishable (presumably museval also uses the
    # track name for its per-track output file).
    song_name = os.path.basename(os.path.normpath(filename))
    scores = calculate_sdr(song_name, est, tgt, mix)

    print(scores)
else:
    print(f'{filename} is not a directory; nothing to evaluate')

drums           ==> SDR:   2.561  SIR:   2.686  ISR:   9.202  SAR:   3.471  
bass            ==> SDR:  -1.557  SIR:  -8.028  ISR:   2.677  SAR:   2.917  
other           ==> SDR:  -0.102  SIR:  -4.170  ISR:   4.806  SAR:   4.609  
vocals          ==> SDR:  -3.337  SIR:  -3.685  ISR:   3.933  SAR:   0.722  



In [45]:
# Score several checkpoints on the same excerpt. `separate` reads the global
# `model`, so each iteration rebinds it; after the loop `model` is the last
# checkpoint in `file_names`.
ckpt_root = '/home/qor6271/Desktop/diff-music-sep/exp/musdb18_16000'
file_names = ['sigma_003_03', 'sigma_003_05', 'sigma_005_03', 'sigma_005_05']
scores_by_ckpt = {}
if os.path.isdir(filename):
    song_name = os.path.basename(os.path.normpath(filename))
    for st in file_names:
        model = DiffSepModel.load_from_checkpoint(
            f'{ckpt_root}/{st}/checkpoints/epoch-999_si_sdr-0.000.ckpt',
            map_location=device,
        )
        model = model.to(device)
        model.eval()

        est, tgt, mix = separate(filename)
        # one track name per checkpoint so the runs don't share a name
        scores = calculate_sdr(f'{song_name}_{st}', est, tgt, mix)
        scores_by_ckpt[st] = scores

        # label each table: the bare tables are indistinguishable otherwise
        print(f'== {st} ==')
        print(scores)
else:
    print(f'{filename} is not a directory; nothing to evaluate')

drums           ==> SDR:   1.442  SIR:   2.030  ISR:   7.555  SAR:   2.491  
bass            ==> SDR:  -1.831  SIR:  -8.849  ISR:   2.176  SAR:   2.123  
other           ==> SDR:   0.012  SIR:  -3.738  ISR:   4.953  SAR:   4.887  
vocals          ==> SDR:  -3.886  SIR:  -9.733  ISR:   3.079  SAR:   2.376  

drums           ==> SDR:   1.421  SIR:   3.342  ISR:   6.943  SAR:   0.649  
bass            ==> SDR:  -1.722  SIR:  -7.618  ISR:   2.251  SAR:   0.434  
other           ==> SDR:   0.144  SIR:  -3.517  ISR:   5.061  SAR:   4.817  
vocals          ==> SDR:  -4.277  SIR: -10.147  ISR:   2.231  SAR:   0.526  

drums           ==> SDR:   1.321  SIR:   1.561  ISR:   7.522  SAR:   2.342  
bass            ==> SDR:  -1.666  SIR:  -8.455  ISR:   2.474  SAR:   2.235  
other           ==> SDR:  -0.012  SIR:  -3.685  ISR:   5.045  SAR:   5.410  
vocals          ==> SDR:  -4.122  SIR: -10.416  ISR:   2.539  SAR:   2.528  

drums           ==> SDR:   0.963  SIR:   0.944  ISR:   6.865  SAR:   0.59

In [None]:
# Channel-0 estimates: waveform plus an audio player for each stem.
stem_names = ['drums', 'bass', 'other', 'vocals']
fig, axes = plt.subplots(len(stem_names), 1)
for ax, stem_name, wav in zip(axes, stem_names, est[0]):
    ax.plot(wav)
    ax.set_title(f'estimate: {stem_name}')
    display(ipd.Audio(wav, rate=16000))
fig.tight_layout()

In [None]:
# NOTE(review): same view as the previous cell (presumably re-run after `est`
# changed); consider deleting the duplicate.
# Channel-0 estimates: waveform plus an audio player for each stem.
stem_names = ['drums', 'bass', 'other', 'vocals']
fig, axes = plt.subplots(len(stem_names), 1)
for ax, stem_name, wav in zip(axes, stem_names, est[0]):
    ax.plot(wav)
    ax.set_title(f'estimate: {stem_name}')
    display(ipd.Audio(wav, rate=16000))
fig.tight_layout()

In [None]:
# NOTE(review): identical to the two cells above; consider deleting the duplicate.
# Channel-0 estimates: waveform plus an audio player for each stem.
stem_names = ['drums', 'bass', 'other', 'vocals']
fig, axes = plt.subplots(len(stem_names), 1)
for ax, stem_name, wav in zip(axes, stem_names, est[0]):
    ax.plot(wav)
    ax.set_title(f'estimate: {stem_name}')
    display(ipd.Audio(wav, rate=16000))
fig.tight_layout()

In [None]:
# Channel-0 reference stems for comparison with the estimates.
stem_names = ['drums', 'bass', 'other', 'vocals']
fig, axes = plt.subplots(len(stem_names), 1)
for ax, stem_name, wav in zip(axes, stem_names, tgt[0]):
    ax.plot(wav)
    ax.set_title(f'reference: {stem_name}')
    display(ipd.Audio(wav, rate=16000))
fig.tight_layout()

In [None]:
# Plot and play the reference stems (first channel) of a random test song.
root = 'data/musdb18_16000/test/'
# seeded pick from the sorted listing so the chosen song is reproducible
rng = np.random.default_rng(0)
song = rng.choice(sorted(os.listdir(root)))
filename = os.path.join(root, song)
# filename = 'data/musdb18_16000/test/We Fell From The Sky - Not You'

wav_files = ['mixture.wav', 'drums.wav', 'bass.wav', 'other.wav', 'vocals.wav']
# first channel of each file; data[0] is the mixture, data[1:] the stems
data = [torchaudio.load(os.path.join(filename, inst))[0][0] for inst in wav_files]

fig, axes = plt.subplots(4, 1)
for ax, wav_name, wav in zip(axes, wav_files[1:], data[1:]):
    ax.plot(wav.numpy())
    ax.set_title(wav_name)
    display(ipd.Audio(wav, rate=16000))
fig.tight_layout()

print(song)


In [29]:
# Geometric noise range: sigma(t) = sigma_min * rho**t with rho = sigma_max / sigma_min.
sigma_min, sigma_max = 0.05, 0.5
rho = sigma_max / sigma_min
logrho = np.log(rho)

t = np.linspace(0, 1, num=10)  # 10 evenly spaced times in [0, 1]

# Presumably the SDE diffusion coefficient g(t) = sigma(t) * sqrt(2 log rho) on that grid.
sigma_min * rho**t * np.sqrt(2 * logrho)

array([0.1072983 , 0.13858109, 0.17898435, 0.23116718, 0.2985639 ,
       0.3856101 , 0.4980346 , 0.64323642, 0.83077178, 1.07298301])

In [34]:
# Presumably the marginal noise variance of the MixSDE with gamma = 2:
#   sigma_min**2 * (rho**(2t) - exp(-2 * gamma * t)) * log(rho) / (gamma + log(rho))
# evaluated on the grid `t` -- confirm against the SDE definition.
gamma = 2  # the constant previously written inline as the literal 2
sigma_min**2 * (rho**(2*t) - np.exp(-2*gamma*t)) * logrho / (gamma + logrho)

array([0.        , 0.00137392, 0.00317278, 0.00585735, 0.01013281,
       0.01713476, 0.02873139, 0.04802231, 0.08016726, 0.13376629])

In [53]:
# Alternative noise range explored in the next cells.
sigma_min, sigma_max = 0.03, 0.3
rho = sigma_max / sigma_min
logrho = np.log(rho)

In [54]:
# Geometric schedule scaled by sqrt(2 * (log rho + 2)) instead of sqrt(2 * log rho).
# NOTE(review): the extra 2 presumably plays the role of gamma -- confirm.
sigma_min * rho**t * np.sqrt(2 * (logrho + 2))

array([0.08800371, 0.11366116, 0.14679904, 0.18959825, 0.24487555,
       0.31626894, 0.40847704, 0.52756839, 0.68138077, 0.88003711])

In [55]:
# sigma(t)**2 on the grid `t` for the current sigma range.
np.square(sigma_min * rho**t)

array([0.0009    , 0.00150129, 0.0025043 , 0.00417743, 0.00696837,
       0.01162395, 0.01938991, 0.03234432, 0.05395358, 0.09      ])

In [2]:

root = 'data/musdb18_16000/test/'
# seeded pick from the sorted listing so the chosen song is reproducible
rng = np.random.default_rng(0)
song = rng.choice(sorted(os.listdir(root)))
filename = os.path.join(root, song)

# Overall standard deviation (over all channels) of the mixture and each stem.
mix = torchaudio.load(os.path.join(filename, 'mixture.wav'))[0]
print(f'{"mixture":8s} std = {mix.std().item():.4f}')
for inst in ['drums', 'bass', 'other', 'vocals']:
    stem = torchaudio.load(os.path.join(filename, f'{inst}.wav'))[0]
    print(f'{inst:8s} std = {stem.std().item():.4f}')

tensor(0.1026)
tensor(0.0391)
tensor(0.0570)
tensor(0.0356)
tensor(0.0667)
