In [1]:
import nussl
import torch
from nussl.datasets import transforms as nussl_tfm
from models.MaskInference import MaskInference
from utils import data, viz, utils
from pathlib import Path

In [150]:
# # Prepare MUSDB
# data.prepare_musdb('~/.nussl/tutorial/')
# data.prepare_musdbhq(folder='data/musdb18hq/',musdb_root='/SFS/user/ry/stonekev/.nussl/',download=True)

In [12]:
model_path = nussl.efz_utils.download_trained_model(
    'mask-inference-wsj2mix-model-v1.pth')

MetadataError: No matching metadata for file mask-inference-wsj2mix-model-v1.pth at url http://nussl.ci.northwestern.edu/audio-metadata.json!

In [2]:
utils.logger()
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MAX_MIXTURES = int(1e8) # We'll set this to some impossibly high number for on the fly mixing.

stft_params = nussl.STFTParams(window_length=512, hop_length=128)

tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
    nussl_tfm.MagnitudeSpectrumApproximation(),
    nussl_tfm.IndexSources('source_magnitudes', 1),
    nussl_tfm.ToSeparationModel(),
])

train_folder = "~/.nussl/tutorial/train"
val_folder = "~/.nussl/tutorial/valid"

train_data = data.on_the_fly(stft_params, transform=tfm, 
    fg_path=train_folder, num_mixtures=MAX_MIXTURES, coherent_prob=1.0)
train_dataloader = torch.utils.data.DataLoader(
    train_data, num_workers=1, batch_size=10)

val_data = data.on_the_fly(stft_params, transform=tfm, 
    fg_path=val_folder, num_mixtures=10, coherent_prob=1.0)
val_dataloader = torch.utils.data.DataLoader(
    val_data, num_workers=1, batch_size=10)

In [10]:
train_data[0]['source_magnitudes'].shape

dict_keys(['index', 'mix_magnitude', 'ideal_binary_mask', 'source_magnitudes'])

In [8]:
train_data[0]['source_magnitudes'].shape

torch.Size([1724, 257, 1, 1])

In [3]:
nf = stft_params.window_length // 2 + 1
model = MaskInference.build(nf, 1, 50, 1, True, 0.0, 1, 'sigmoid')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nussl.ml.train.loss.L1Loss()

In [4]:
model

SeparationModel(
  (layers): ModuleDict(
    (model): MaskInference(
      (amplitude_to_db): AmplitudeToDB()
      (input_normalization): BatchNorm(
        (batch_norm): BatchNorm1d(257, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (recurrent_stack): RecurrentStack(
        (rnn): LSTM(257, 50, batch_first=True, bidirectional=True)
      )
      (embedding): Embedding(
        (linear): Linear(in_features=100, out_features=257, bias=True)
      )
    )
    (mask): Alias()
    (estimates): Alias()
  )
)
Number of parameters: 150071

In [152]:
def train_step(engine, batch):
    optimizer.zero_grad()
    output = model(batch) # forward pass
    loss = loss_fn(
        output['estimates'],
        batch['source_magnitudes']
    )
    
    loss.backward() # backwards + gradient step
    optimizer.step()
    
    loss_vals = {
        'L1Loss': loss.item(),
        'loss': loss.item()
    }
    
    return loss_vals

def val_step(engine, batch):
    with torch.no_grad():
        output = model(batch) # forward pass
    loss = loss_fn(
        output['estimates'],
        batch['source_magnitudes']
    )    
    loss_vals = {
        'L1Loss': loss.item(), 
        'loss': loss.item()
    }
    return loss_vals

In [163]:
# Create the engines
trainer, validator = nussl.ml.train.create_train_and_validation_engines(
    train_step, val_step, device=DEVICE
)

# We'll save the output relative to this notebook.
output_folder = Path('outputs').absolute()

# Adding handlers from nussl that print out details about model training
# run the validation step, and save the models.
nussl.ml.train.add_stdout_handler(trainer, validator)
nussl.ml.train.add_validate_and_checkpoint(output_folder, model, 
    optimizer, train_data, trainer, val_dataloader, validator)

In [164]:
# trainer.run(
#     train_dataloader, 
#     epoch_length=10, 
#     max_epochs=1
# )

In [165]:
separator = nussl.separation.deep.DeepMaskEstimation(
    nussl.AudioSignal(), model_path='checkpoints/best.model.pth',
    device=DEVICE,
)



In [167]:
test_folder = "~/.nussl/tutorial/test/"
test_data = data.mixer(stft_params, transform=None, 
    fg_path=test_folder, num_mixtures=MAX_MIXTURES, coherent_prob=1.0)
item = test_data[0]

separator.audio_signal = item['mix']
estimates = separator()
# Since our model only returns one source, let's tack on the
# residual (which should be accompaniment)
estimates.append(item['mix'] - estimates[0])

#viz.show_sources(estimates)

In [168]:
import json

tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
])
test_dataset = nussl.datasets.MUSDB18(subsets=['test'], transform=tfm)

# Just do 5 items for speed. Change to 50 for actual experiment.
for i in range(5):
    item = test_dataset[i]
    separator.audio_signal = item['mix']
    estimates = separator()

    source_keys = list(item['sources'].keys())
    estimates = {
        'vocals': estimates[0],
        'bass+drums+other': item['mix'] - estimates[0]
    }

    sources = [item['sources'][k] for k in source_keys]
    estimates = [estimates[k] for k in source_keys]

    evaluator = nussl.evaluation.BSSEvalScale(
        sources, estimates, source_labels=source_keys
    )
    scores = evaluator.evaluate()
    output_folder = Path(output_folder).absolute()
    output_folder.mkdir(exist_ok=True)
    output_file = output_folder / sources[0].file_name.replace('wav', 'json')
    with open(output_file, 'w') as f:
        json.dump(scores, f, indent=4)

In [192]:
import glob
import numpy as np

json_files = glob.glob(f'outputs/*.json')
df = nussl.evaluation.aggregate_score_files(
    json_files, aggregator=np.nanmedian)
#nussl.evaluation.associate_metrics(separator.model, df, test_dataset)

In [193]:
df

Unnamed: 0,source,file,SI-SDR,SI-SIR,SI-SAR,SD-SDR,SNR,SRR,SI-SDRi,SD-SDRi,SNRi,MIX-SI-SDR,MIX-SD-SDR,MIX-SNR
0,vocals,Al James - Schoolboy Facination_vocals.json,-0.784535,-0.598966,13.001692,-3.031669,2.547792,0.905951,0.30355,-1.943431,3.687373,-1.088084,-1.088238,-1.139581
1,vocals,Angels In Amplifiers - I'm Alright_vocals.json,-1.787292,-1.642771,13.076481,-3.987777,2.021374,0.022919,0.411055,-1.789207,4.153718,-2.198346,-2.19857,-2.132344
2,vocals,BKS - Bulldozer_vocals.json,-5.427542,-5.334472,11.316941,-6.423873,-0.629328,0.459113,0.568335,-0.42778,5.434137,-5.995878,-5.996093,-6.063465
3,vocals,Arise - Run Run Run_vocals.json,-8.229635,-8.162679,9.938586,-9.085124,-2.230849,-1.594217,0.094274,-0.761121,6.175299,-8.32391,-8.324004,-8.406149
4,vocals,AM Contra - Heart Peripheral_vocals.json,-3.116621,-3.043049,14.635663,-4.313564,0.819025,1.880799,0.685322,-0.511602,4.596596,-3.801943,-3.801961,-3.77757
5,bass+drums+other,Al James - Schoolboy Facination_vocals.json,1.404598,1.66205,13.808934,-2.30701,3.663484,0.100976,0.22448,-3.486975,2.523903,1.180118,1.179965,1.139581
6,bass+drums+other,Angels In Amplifiers - I'm Alright_vocals.json,2.391235,2.609235,15.493062,-1.530708,4.132972,0.73177,0.298231,-3.623489,2.000628,2.093005,2.092781,2.132344
7,bass+drums+other,BKS - Bulldozer_vocals.json,6.372875,6.723571,17.476411,-0.169282,5.403282,0.92049,0.291079,-6.250863,-0.660183,6.081796,6.08158,6.063465
8,bass+drums+other,Arise - Run Run Run_vocals.json,8.279771,8.582419,19.999837,1.098618,6.166107,2.025606,-0.139036,-7.320096,-2.240042,8.418808,8.418714,8.406149
9,bass+drums+other,AM Contra - Heart Peripheral_vocals.json,4.402479,4.625504,17.41295,-1.710477,4.557014,-0.486884,0.634297,-5.478641,0.779443,3.768182,3.768164,3.77757
