In [1]:
import nussl
import torch
from nussl.datasets import transforms as nussl_tfm
from models import MaskInference
from utils import utils, data
from pathlib import Path

In [2]:
utils.logger()

# Use the GPU whenever torch can see one; otherwise run on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Mixtures are synthesized on the fly, so this only has to be large enough that
# the training dataset never runs out of items.
MAX_MIXTURES = 10 ** 8

stft_params = nussl.STFTParams(window_length=512, hop_length=128)

# Per-item preprocessing:
#   1. merge bass/drums/other into one accompaniment source,
#   2. MagnitudeSpectrumApproximation (produces the 'source_magnitudes' used below),
#   3. keep only index 1 of 'source_magnitudes' as the target
#      (presumably vocals after the merge — TODO confirm source ordering),
#   4. convert the item into the form the separation model consumes.
tfm = nussl_tfm.Compose(
    [
        nussl_tfm.SumSources([['bass', 'drums', 'other']]),
        nussl_tfm.MagnitudeSpectrumApproximation(),
        nussl_tfm.IndexSources('source_magnitudes', 1),
        nussl_tfm.ToSeparationModel(),
    ]
)

# Stem folders used for on-the-fly mixing. Expand "~" explicitly so the paths
# point at the home directory even if the code consuming them never calls
# expanduser itself.
# NOTE(review): running this cell raised a ScaperError whose "available labels"
# were song names (e.g. 'ANiMAL - Rockshow'). That suggests the folders are laid
# out as <split>/<song>/<stem>.wav, while the mixer presumably expects one
# sub-folder per label (e.g. <split>/<stem>/<song>.wav). TODO confirm against
# utils.data.on_the_fly and reorganize the data if so.
train_folder = str(Path("~/audio_isolation/data/train").expanduser())
val_folder = str(Path("~/audio_isolation/data/valid").expanduser())

# Fail fast with a clear message instead of a library error deep inside mixing.
for data_folder in (train_folder, val_folder):
    if not Path(data_folder).is_dir():
        raise FileNotFoundError(f"data folder not found: {data_folder}")

# Training mixtures are generated on demand; MAX_MIXTURES only bounds the
# nominal dataset length.
train_data = data.on_the_fly(stft_params, transform=tfm,
    fg_path=train_folder, num_mixtures=MAX_MIXTURES, coherent_prob=1.0)
train_dataloader = torch.utils.data.DataLoader(
    train_data, num_workers=1, batch_size=10)

# Small validation set: 10 mixtures, i.e. a single batch of 10.
val_data = data.on_the_fly(stft_params, transform=tfm,
    fg_path=val_folder, num_mixtures=10, coherent_prob=1.0)
val_dataloader = torch.utils.data.DataLoader(
    val_data, num_workers=1, batch_size=10)

# Number of frequency bins in a one-sided STFT with this window length.
nf = stft_params.window_length // 2 + 1

# Positional build arguments. Assumed order (as in the nussl tutorial's
# MaskInference.build): num_features, num_audio_channels, hidden_size,
# num_layers, bidirectional, dropout, num_sources, mask activation.
# TODO confirm against models.MaskInference.
# The model is moved to DEVICE *before* the optimizer is built. The engines
# below are created with device=DEVICE (presumably moving each batch there),
# so the parameters must live on the same device.
model = MaskInference.build(nf, 1, 50, 1, True, 0.0, 1, 'sigmoid').to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nussl.ml.train.loss.L1Loss()

def train_step(engine, batch):
    """Run one optimization step on ``batch`` and report its loss.

    Uses the module-level ``model``, ``optimizer`` and ``loss_fn``.

    Args:
        engine: Engine passed in by create_train_and_validation_engines (unused).
        batch: Dict with the model inputs and the 'source_magnitudes' target.

    Returns:
        dict: ``{'L1Loss': float, 'loss': float}``, the same scalar under both
        keys ('loss' is presumably the key nussl's handlers read — TODO confirm).
    """
    # Validation switches the model to eval mode; switch back before training.
    model.train()
    optimizer.zero_grad()
    output = model(batch)  # forward pass
    loss = loss_fn(output['estimates'], batch['source_magnitudes'])

    loss.backward()  # backward pass + gradient step
    optimizer.step()

    loss_value = loss.item()
    return {'L1Loss': loss_value, 'loss': loss_value}

def val_step(engine, batch):
    """Compute the loss on ``batch`` without updating the model.

    Args:
        engine: Engine passed in by create_train_and_validation_engines (unused).
        batch: Dict with the model inputs and the 'source_magnitudes' target.

    Returns:
        dict: ``{'L1Loss': float, 'loss': float}``, the same scalar under both keys.
    """
    # Eval mode disables training-only behaviour such as dropout.
    model.eval()
    # Keep the loss inside no_grad too, so no autograd state is built at all.
    with torch.no_grad():
        output = model(batch)  # forward pass
        loss = loss_fn(output['estimates'], batch['source_magnitudes'])
    loss_value = loss.item()
    return {'L1Loss': loss_value, 'loss': loss_value}

# Wrap the step functions in training/validation engines; device=DEVICE is
# presumably where each batch is moved before the step runs.
trainer, validator = nussl.ml.train.create_train_and_validation_engines(
    train_step, val_step, device=DEVICE)

# Checkpoints are written relative to the directory the notebook runs from.
output_folder = Path.cwd()

# nussl handlers: print training progress, and periodically run validation on
# val_dataloader and checkpoint the model and optimizer state.
nussl.ml.train.add_stdout_handler(trainer, validator)
nussl.ml.train.add_validate_and_checkpoint(
    output_folder, model, optimizer, train_data,
    trainer, val_dataloader, validator)

# 25 epochs of 10 batches each.
trainer.run(train_dataloader, epoch_length=10, max_epochs=25)

ScaperError: Label value must match one of the available labels: ['A Classic Education - NightOwl', 'ANiMAL - Clinic A', 'ANiMAL - Easy Tiger', 'ANiMAL - Rockshow', "Actions - Devil's Words", 'Actions - One Minute Smile', 'Actions - South Of The Water', 'Aimee Norwich - Child', 'Alexander Ross - Goodbye Bolero', 'Alexander Ross - Velvet Curtain', 'Angela Thomas Wade - Milk Cow Blues', 'Atlantis Bound - It Was My Fault For Waiting', 'Auctioneer - Our Future Faces', 'AvaLuna - Waterduct', 'BigTroubles - Phantom', 'Bill Chudziak - Children Of No-one', 'Black Bloc - If You Want Success', 'Celestial Shore - Die For Us', 'Chris Durban - Celebrate', 'Clara Berry And Wooldog - Air Traffic', 'Clara Berry And Wooldog - Stella', 'Clara Berry And Wooldog - Waltz For My Victims', 'Cnoc An Tursa - Bannockburn', 'Creepoid - OldTree', 'Dark Ride - Burning Bridges', 'Dreamers Of The Ghetto - Heavy Love', 'Drumtracks - Ghost Bitch', 'Faces On Film - Waiting For Ga', 'Fergessen - Back From The Start', 'Fergessen - Nos Palpitants', 'Fergessen - The Wind', 'Flags - 54', 'Giselle - Moss', 'Grants - PunchDrunk', 'Helado Negro - Mitad Del Mundo', 'Hezekiah Jones - Borrowed Heart', 'Hollow Ground - Left Blind', 'Hop Along - Sister Cities', 'Invisible Familiars - Disturbing Wildlife', 'James May - All Souls Moon', 'James May - Dont Let Go', 'James May - If You Say', 'James May - On The Line', 'Jay Menon - Through My Eyes', 'Johnny Lokke - Promises & Lies', 'Johnny Lokke - Whisper To A Scream', 'Jokers, Jacks & Kings - Sea Of Leaves', 'Leaf - Come Around', 'Leaf - Summerghost', 'Leaf - Wicked', 'Lushlife - Toynbee Suite', 'Matthew Entwistle - Dont You Ever', 'Meaxic - Take A Step', 'Meaxic - You Listen', 'Music Delta - 80s Rock', 'Music Delta - Beatles', 'Music Delta - Britpop', 'Music Delta - Country1', 'Music Delta - Country2', 'Music Delta - Disco', 'Music Delta - Gospel', 'Music Delta - Grunge', 'Music Delta - Hendrix', 'Music Delta - Punk', 'Music Delta - Reggae', 'Music Delta - Rock', 
'Music Delta - Rockabilly', 'Night Panther - Fire', 'North To Alaska - All The Same', 'Patrick Talbot - A Reason To Leave', 'Patrick Talbot - Set Me Free', "Phre The Eon - Everybody's Falling Apart", 'Port St Willow - Stay Even', 'Remember December - C U Next Time', 'Secret Mountains - High Horse', 'Skelpolu - Human Mistakes', 'Skelpolu - Together Alone', 'Snowmine - Curfews', "Spike Mullings - Mike's Sulking", 'St Vitus - Word Gets Around', 'Steven Clark - Bounty', 'Strand Of Oaks - Spacestation', 'Sweet Lights - You Let Me Down', 'Swinging Steaks - Lost My Way', 'The Districts - Vermont', 'The Long Wait - Back Home To Blue', 'The Scarlet Brand - Les Fleurs Du Mal', 'The So So Glos - Emergency', "The Wrong'Uns - Rothko", 'Tim Taler - Stalker', 'Titanium - Haunted Age', 'Traffic Experiment - Once More (With Feeling)', 'Traffic Experiment - Sirens', 'Triviul - Angelsaint', 'Triviul - Dorothy', 'Voelund - Comfort Lives In Belief', 'Wall Of Death - Femme', 'Young Griffo - Blood To Bone', 'Young Griffo - Facade', 'Young Griffo - Pennies']

In [None]:
# Wrap the best checkpoint in a mask-estimation separator. The empty
# AudioSignal is only a placeholder; the evaluation loop assigns each test
# mixture to separator.audio_signal.
checkpoint_path = 'checkpoints/best.model.pth'
separator = nussl.separation.deep.DeepMaskEstimation(
    nussl.AudioSignal(),
    model_path=checkpoint_path,
    device=DEVICE,
)



In [None]:
import json

# Evaluate on the MUSDB18 test split with the accompaniment stems summed, so
# each item has a 'vocals' and a 'bass+drums+other' source. A separate name
# avoids clobbering the training transform `tfm`.
eval_tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
])
test_dataset = nussl.datasets.MUSDB18(subsets=['test'], transform=eval_tfm)

# Just do 5 items for speed. Change to 50 for actual experiment.
NUM_EVAL_ITEMS = 5

# Per-track score files are written to output_folder (set in the training
# cell). Resolve and create it once, not on every loop iteration.
output_folder = Path(output_folder).absolute()
output_folder.mkdir(parents=True, exist_ok=True)

for i in range(min(NUM_EVAL_ITEMS, len(test_dataset))):
    item = test_dataset[i]
    separator.audio_signal = item['mix']
    separated = separator()

    # The separator's first estimate is treated as vocals; the accompaniment
    # estimate is whatever remains of the mixture.
    vocals_estimate = separated[0]
    estimates_by_key = {
        'vocals': vocals_estimate,
        'bass+drums+other': item['mix'] - vocals_estimate,
    }

    # Align references and estimates in the dataset's source order.
    source_keys = list(item['sources'].keys())
    sources = [item['sources'][k] for k in source_keys]
    estimates = [estimates_by_key[k] for k in source_keys]

    evaluator = nussl.evaluation.BSSEvalScale(
        sources, estimates, source_labels=source_keys
    )
    scores = evaluator.evaluate()

    # Swap only the extension: str.replace('wav', 'json') would also rewrite
    # any "wav" that appears inside the track name itself.
    score_name = Path(sources[0].file_name).with_suffix('.json').name
    output_file = output_folder / score_name
    with open(output_file, 'w') as f:
        json.dump(scores, f, indent=4)

In [None]:
import glob
import numpy as np

# Glob the folder the evaluation loop writes to (output_folder), not whatever
# the current directory happens to be. Sort for a deterministic row order.
# NOTE: output_folder is the notebook directory, so any other *.json there is
# still picked up. Keep unrelated JSON out of it.
score_pattern = str(Path(glob.escape(str(output_folder))) / '*.json')
json_files = sorted(glob.glob(score_pattern))
if not json_files:
    raise FileNotFoundError(f"no score files matched {score_pattern}")

df = nussl.evaluation.aggregate_score_files(
    json_files, aggregator=np.nanmedian)
# NOTE(review): on the last run this cell raised
#   TypeError: Could not convert ["BKS - Bulldozer_vocals.json..."] to numeric
# i.e. a numeric reduction was applied to the string file-name column. That
# matches pandas >= 2.0 no longer dropping non-numeric columns in reductions,
# inside nussl's helpers. TODO confirm; pinning pandas < 2.0 (or a nussl
# release that supports pandas 2) may be required.
nussl.evaluation.associate_metrics(separator.model, df, test_dataset)
report_card = nussl.evaluation.report_card(
    df, report_each_source=True)
print(report_card)

TypeError: Could not convert ["BKS - Bulldozer_vocals.jsonArise - Run Run Run_vocals.jsonAl James - Schoolboy Facination_vocals.jsonAM Contra - Heart Peripheral_vocals.jsonAngels In Amplifiers - I'm Alright_vocals.jsonBKS - Bulldozer_vocals.jsonArise - Run Run Run_vocals.jsonAl James - Schoolboy Facination_vocals.jsonAM Contra - Heart Peripheral_vocals.jsonAngels In Amplifiers - I'm Alright_vocals.json"] to numeric

In [None]:
# Re-save the model so the checkpoint on disk presumably also carries the
# metadata attached by associate_metrics (cf. the 'evaluation' key on reload).
checkpoint_path = 'checkpoints/best.model.pth'
separator.model.save(checkpoint_path)

'checkpoints/best.model.pth'

In [None]:
# Reload the checkpoint to inspect its metadata. map_location='cpu' lets a
# checkpoint saved from a CUDA run be opened on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model_checkpoint = torch.load('checkpoints/best.model.pth', map_location='cpu')

In [None]:
# Top-level metadata sections stored alongside the weights.
checkpoint_metadata = model_checkpoint['metadata']
checkpoint_metadata.keys()

dict_keys(['config', 'nussl_version', 'evaluation', 'test_dataset'])

In [None]:
# Aggregate evaluation statistics (mean / median / std per metric) stored in the checkpoint.
evaluation_stats = model_checkpoint['metadata']['evaluation']
evaluation_stats

{'SI-SDR': {'mean': 1.3912683270056592,
  'median': 1.0962115067756693,
  'std': 4.804988720552033},
 'SI-SIR': {'mean': 1.6625557918403284,
  'median': 1.3613294923719872,
  'std': 4.892045534713316},
 'SI-SAR': {'mean': 13.869686333945719,
  'median': 13.402982023422254,
  'std': 3.3645639731478334},
 'SD-SDR': {'mean': -1.7625529689954553,
  'median': -1.8067065858755935,
  'std': 3.4507046858153183},
 'SNR': {'mean': 3.551795662930993,
  'median': 3.6060124104499063,
  'std': 2.5600226865690074},
 'SRR': {'mean': 1.610932020926893,
  'median': 1.4595263614094136,
  'std': 2.073824737064923},
 'SI-SDRi': {'mean': 1.3778954420719303,
  'median': 1.2909034958714753,
  'std': 0.5523867936330886},
 'SD-SDRi': {'mean': -1.775784962047302,
  'median': -1.6789136928611672,
  'std': 1.8198928173376951},
 'SNRi': {'mean': 3.551795662930992,
  'median': 3.6060124104499063,
  'std': 2.5809225218167002},
 'MIX-SI-SDR': {'mean': 0.013372884933728813,
  'median': 0.04601839003699992,
  'std': 5.0