In [1]:
# 2023 (c) LINE Corporation
# Authors: Robin Scheibler
# MIT License
import argparse
import json
import math
import os
import time
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import yaml
from omegaconf import OmegaConf
from pesq import pesq
from pystoi import stoi

# from sdes.sdes import MixSDE
from datasets import NoisyDataset, WSJ0_mix, musdb_mix
from pl_model import DiffSepModel
import musdb

import IPython.display as ipd

  from .autonotebook import tqdm as notebook_tqdm


In [64]:
# Evaluation configuration and model loading.
output_dir_base = Path("results")
batch_size = 16            # max number of segments separated per forward pass
max_len_s = 1              # segment length in seconds
sr = 44100                 # sampling rate (Hz) used to convert seconds to samples
max_data = max_len_s * sr  # segment length in samples
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

ckpt_path = 'exp/musdb/2023-11-22_18-04-49_/checkpoints/epoch-099_mse-0.000.ckpt'
# map_location restores the weights directly onto `device` rather than onto
# the device the checkpoint was saved from (which may not exist here).
model = DiffSepModel.load_from_checkpoint(ckpt_path, map_location=device)
# transfer to the selected device and switch to inference mode
model = model.to(device)
model.eval()


DiffSepModel(
  (score_model): ScoreModelNCSNpp(
    (backbone): NCSNpp(
      (act): SiLU()
      (output_layer): Conv2d(10, 8, kernel_size=(1, 1), stride=(1, 1))
      (pyramid_upsample): Upsample()
      (pyramid_downsample): Downsample()
      (all_modules): ModuleList(
        (0): GaussianFourierProjection()
        (1): Linear(in_features=128, out_features=256, bias=True)
        (2): Linear(in_features=256, out_features=256, bias=True)
        (3): Conv2d(10, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): ResnetBlockBigGANpp(
          (GroupNorm_0): GroupNorm(16, 64, eps=1e-06, affine=True)
          (Conv_0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (Dense_0): Linear(in_features=256, out_features=64, bias=True)
          (GroupNorm_1): GroupNorm(16, 64, eps=1e-06, affine=True)
          (Dropout_0): Dropout(p=0.0, inplace=False)
          (Conv_1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       

In [65]:
# Open the "test" subset of the local MUSDB18 copy; tracks are indexed below.
musdb_list = musdb.DB(
    subsets="test",
    root="data/musdb18",
)

In [66]:
# Separate MUSDB18 test tracks in fixed-size chunks and stitch the estimates.
#
# Only the first channel of every stem is used. Processing starts `start_s`
# seconds into each track; the rest is cut into chunks of up to `batch_size`
# segments of `max_data` samples, and each chunk is separated with a single
# call to `model.separate`. The last segment is zero-padded to `max_data` and
# that padding is trimmed from the stitched estimate.
#
# `max_tracks` / `max_chunks` bound the amount of work for quick checks: the
# defaults (1, 1) separate one chunk of the first track. Set either to None to
# process everything. After the cell, `mix_full`, `tgt_full` and `est_full`
# hold the last processed track.
# TODO: persist per-track estimates when running over the whole test set.
start_s = 30     # seconds skipped at the start of each track
max_tracks = 1   # None = all tracks
max_chunks = 1   # None = whole track

offset = start_s * sr
chunk_len = max_data * batch_size
n_tracks = len(musdb_list) if max_tracks is None else min(max_tracks, len(musdb_list))

for idx in range(n_tracks):
    # NOTE(review): assumes each stem is (n_samples, n_channels) and stem 0 is
    # the mixture; transpose + [[0]] keeps channel 0 as a (1, n_samples) tensor.
    data = [
        torch.from_numpy(x).float().transpose(0, 1).to(device)[[0]]
        for x in musdb_list[idx].stems
    ]

    mix_full = data[0].unsqueeze(0)
    tgt_full = torch.cat(data[1:], dim=0).unsqueeze(0)

    # Samples available after the skipped intro, and how many chunks cover them.
    n_avail = max(mix_full.shape[-1] - offset, 0)
    n_chunks = math.ceil(n_avail / chunk_len)
    if max_chunks is not None:
        n_chunks = min(n_chunks, max_chunks)
    if n_chunks == 0:
        continue  # track not longer than start_s: nothing to separate

    est_chunks = []
    for t in range(n_chunks):
        start = offset + t * chunk_len
        stop = start + chunk_len
        mix = list(mix_full[:, :, start:stop].split(max_data, dim=2))
        tgt = list(tgt_full[:, :, start:stop].split(max_data, dim=2))
        # Only the final segment of the track can be short; pad it so all
        # segments can be stacked into one batch.
        if mix[-1].shape[-1] != max_data:
            mix[-1] = torch.nn.functional.pad(mix[-1], (0, max_data - mix[-1].shape[-1]))
            tgt[-1] = torch.nn.functional.pad(tgt[-1], (0, max_data - tgt[-1].shape[-1]))

        # Segments become the batch dimension.
        mix = torch.cat(mix, dim=0)
        tgt = torch.cat(tgt, dim=0)

        # Inference only: avoid building an autograd graph through the sampler.
        with torch.no_grad():
            batch, *stats = model.normalize_batch((mix, tgt))
            mix, target = batch
            est, *_ = model.separate(mix)
            est = model.denormalize_batch(est, *stats)

        # Lay the batch segments back end to end along the time axis.
        est_chunks.append(torch.cat(est.split(1, dim=0), dim=2).squeeze())

    # Concatenate chunks along time and drop the zero padding of the last segment.
    n_done = min(n_avail, n_chunks * chunk_len)
    est_full = torch.cat(est_chunks, dim=-1)[..., :n_done]