In [3]:
import nussl
import torch
from nussl.datasets import transforms as nussl_tfm
from models.Waveform import Waveform
#from models.MaskInference import MaskInference
from utils import utils, data
from pathlib import Path

In [4]:
# Download MUSDB18-HQ into the user's nussl cache and prepare it under data/musdb18hq/.
# The cache root is derived from the home directory (~/.nussl/, which is also where
# nussl.datasets.MUSDB18 looks by default) instead of a hard-coded absolute path
# tied to one machine/account. The trailing '/' of the original argument is kept.
# NOTE(review): download=True fetched ~21 GB here; confirm prepare_musdbhq skips the
# download when the archive is already present before re-running this cell.
MUSDB_ROOT = Path.home() / '.nussl'
data.prepare_musdbhq(folder='data/musdb18hq/', musdb_root=str(MUSDB_ROOT) + '/', download=True)

100%|██████████████████████████████████████████████████████████████████████████████████| 21.1G/21.1G [4:07:55<00:00, 1.52MB/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 86/86 [08:21<00:00,  5.83s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [01:35<00:00,  6.85s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [05:33<00:00,  6.67s/it]


In [None]:
utils.logger()
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MAX_MIXTURES = int(1e8) # We'll set this to some impossibly high number for on the fly mixing.

# Seed torch so weight initialisation and dropout are reproducible on Restart & Run All.
# NOTE(review): on-the-fly mixing may draw from its own RNG (numpy / scaper);
# confirm it is seeded as well if fully reproducible mixtures are needed.
SEED = 0
torch.manual_seed(SEED)

stft_params = nussl.STFTParams(window_length=512, hop_length=128)

tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
    nussl_tfm.GetAudio(),
    #nussl_tfm.MagnitudeSpectrumApproximation(),
    nussl_tfm.IndexSources('source_audio', 1),
    nussl_tfm.ToSeparationModel(),
])

# Expand '~' here so the folders are valid whether or not data.on_the_fly expands
# them itself (expanding an already-absolute path is a no-op).
# NOTE(review): validation draws from the tutorial 'test' split, and the best
# checkpoint chosen on it is later scored on the MUSDB18 test subset. If these
# overlap, model selection leaks test data; consider a separate 'valid' split.
TUTORIAL_DATA = Path('~/audio_isolation/data/tutorial').expanduser()
train_folder = str(TUTORIAL_DATA / 'train')
val_folder = str(TUTORIAL_DATA / 'test')

train_data = data.on_the_fly(stft_params, n_channels=2, transform=tfm, 
    fg_path=train_folder, num_mixtures=MAX_MIXTURES, coherent_prob=1.0)
train_dataloader = torch.utils.data.DataLoader(
    train_data, num_workers=1, batch_size=10)

val_data = data.on_the_fly(stft_params, n_channels=2, transform=tfm, 
    fg_path=val_folder, num_mixtures=10, coherent_prob=1.0)
val_dataloader = torch.utils.data.DataLoader(
    val_data, num_workers=1, batch_size=10)

In [None]:
import torchaudio

# Despite its name, `mix` is a single isolated stem (bass) from the MUSDB18-HQ
# training split prepared above. The path is relative to the same working
# directory as prepare_musdbhq(folder='data/musdb18hq/'), replacing a hard-coded
# absolute path from one developer's machine. It stays a str because the next
# cell passes it straight to IPython's Audio.
mix = str(Path('data/musdb18hq') / 'train' / 'bass' / 'A Classic Education - NightOwl.wav')
waveform, sample_rate = torchaudio.load(mix)

In [None]:
from IPython.display import Audio, Video

# Render an inline player for the file loaded in the previous cell.
player = Audio(mix)
player

In [None]:
# With torchaudio's default channels_first=True this is (channels, frames).
waveform.size()

In [None]:
# Shape of the 'mix_audio' tensor the training pipeline yields for one example.
# NOTE(review): with on-the-fly mixing, each index access may synthesise a new mixture.
first_item = train_data[0]
first_item['mix_audio'].shape

In [None]:
# The previous cell already shows the mixture shape. Inspect the training target
# instead: train_step compares the model's 'estimates' against
# batch['source_audio'], so the two shapes should line up.
train_data[0]['source_audio'].shape

In [None]:
# Waveform.build takes its configuration positionally; see models/Waveform.py
# for what each argument means.
# NOTE(review): 1025 == 2048 // 2 + 1, which suggests 2048/512 are the model's
# own STFT window/hop lengths; confirm against the Waveform.build signature.
STFT_WINDOW, STFT_HOP = 2048, 512
model = Waveform.build(
    STFT_WINDOW // 2 + 1, 2, 50, 2, True, 0.3, 1, STFT_WINDOW, STFT_HOP,
    activation='sigmoid',
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nussl.ml.train.loss.L1Loss()

In [None]:
# Inspect the configuration the model was built with.
model_config = model.config
model_config

In [None]:
# nf = stft_params.window_length // 2 + 1
# model = Waveform.build(nf, 1, 50, 1, True, 0.0, 1, 'sigmoid')
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# loss_fn = nussl.ml.train.loss.L1Loss()

def train_step(engine, batch):
    """Run one optimisation step on ``batch`` and report its loss.

    Uses the notebook-level ``model``, ``optimizer`` and ``loss_fn``. The model's
    ``'estimates'`` output is compared against ``batch['source_audio']``. The loss
    is back-propagated and the optimizer is stepped.

    Args:
        engine: The ignite training engine (unused).
        batch: Dict containing at least ``'source_audio'`` plus whatever inputs
            ``model`` consumes.

    Returns:
        dict: The same scalar loss under both ``'L1Loss'`` and ``'loss'``.
    """
    optimizer.zero_grad()  # drop gradients accumulated by the previous step

    predicted = model(batch)['estimates']
    target = batch['source_audio']
    step_loss = loss_fn(predicted, target)

    step_loss.backward()
    optimizer.step()

    scalar_loss = step_loss.item()
    return {'L1Loss': scalar_loss, 'loss': scalar_loss}

def val_step(engine, batch):
    """Run one validation iteration on ``batch`` and report its loss.

    Evaluates the notebook-level ``model`` without tracking gradients. It compares
    ``output['estimates']`` against ``batch['source_audio']`` with the
    notebook-level ``loss_fn``.

    The model is switched to eval mode for the forward pass so stochastic layers
    such as dropout (the model is built with a non-zero dropout argument — TODO
    confirm the argument position) do not perturb the validation loss that drives
    best-checkpoint selection. Its previous train/eval mode is restored
    afterwards, so subsequent training iterations behave exactly as before.

    Args:
        engine: The ignite validation engine (unused).
        batch: Dict containing at least ``'source_audio'`` plus whatever inputs
            ``model`` consumes.

    Returns:
        dict: The same scalar loss under both ``'L1Loss'`` and ``'loss'``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(batch)  # forward pass
            loss = loss_fn(output['estimates'], batch['source_audio'])
    finally:
        # Restore the caller's mode even if the forward pass raises.
        model.train(was_training)

    loss_value = loss.item()
    return {'L1Loss': loss_value, 'loss': loss_value}

# Build the ignite training/validation engines around the two step functions.
trainer, validator = nussl.ml.train.create_train_and_validation_engines(
    train_step, val_step, device=DEVICE
)

# Save checkpoints and logs relative to the notebook's working directory.
output_folder = Path.cwd()

# nussl handlers: print training progress, then validate and checkpoint the model.
nussl.ml.train.add_stdout_handler(trainer, validator)
nussl.ml.train.add_validate_and_checkpoint(
    output_folder, model, optimizer, train_data, trainer, val_dataloader, validator
)

# The on-the-fly training stream is effectively endless (MAX_MIXTURES), so an
# "epoch" is a fixed number of iterations.
EPOCH_LENGTH = 10
MAX_EPOCHS = 25
trainer.run(train_dataloader, epoch_length=EPOCH_LENGTH, max_epochs=MAX_EPOCHS)

In [None]:
# NOTE(review): this mask-estimation separator is overwritten by the
# DeepAudioEstimation separator in the next cell before it is ever used. The
# Waveform model is trained against 'source_audio', so this cell looks like a
# leftover from the mask-inference setup and can probably be removed.
mask_checkpoint = 'checkpoints/best.model.pth'
separator = nussl.separation.deep.DeepMaskEstimation(
    nussl.AudioSignal(),
    model_path=mask_checkpoint,
    device=DEVICE,
)

In [None]:
# The model is trained to output source audio directly (train_step compares its
# 'estimates' with 'source_audio'), so presumably it is wrapped in an
# audio-estimation separator loaded from the best checkpoint.
best_checkpoint = 'checkpoints/best.model.pth'
separator = nussl.separation.deep.DeepAudioEstimation(
    nussl.AudioSignal(),
    model_path=best_checkpoint,
    device=DEVICE,
)

In [None]:
import json

# Evaluation transform: only merge bass/drums/other into one accompaniment source.
# It is named test_tfm so it does not shadow the training transform `tfm`.
test_tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
])
test_dataset = nussl.datasets.MUSDB18(subsets=['test'], transform=test_tfm)

# Just do 5 items for speed. Raise this for the actual experiment (the MUSDB18
# test subset has 50 tracks); it is capped at the dataset size below.
N_EVAL_ITEMS = 5

# Per-track score files are written to output_folder (set in the training cell).
# Create the folder once rather than on every loop iteration.
output_folder = Path(output_folder).absolute()
output_folder.mkdir(exist_ok=True)

for i in range(min(N_EVAL_ITEMS, len(test_dataset))):
    item = test_dataset[i]
    separator.audio_signal = item['mix']
    # The separator's first output is treated as the vocals estimate (presumably
    # the single source the model was trained on). The accompaniment is the residual.
    vocals_estimate = separator()[0]

    source_keys = list(item['sources'].keys())
    estimates_by_key = {
        'vocals': vocals_estimate,
        'bass+drums+other': item['mix'] - vocals_estimate,
    }

    # Align references and estimates by source name for BSSEval.
    sources = [item['sources'][k] for k in source_keys]
    estimates = [estimates_by_key[k] for k in source_keys]

    evaluator = nussl.evaluation.BSSEvalScale(
        sources, estimates, source_labels=source_keys
    )
    scores = evaluator.evaluate()

    # with_suffix only swaps the extension; str.replace('wav', 'json') would also
    # rewrite any 'wav' that appears inside the track name itself.
    output_file = output_folder / Path(sources[0].file_name).with_suffix('.json').name
    with open(output_file, 'w') as f:
        json.dump(scores, f, indent=4)

In [None]:
import glob
import numpy as np

# Collect the per-track score files written by the evaluation loop. It writes
# to output_folder, which the training cell sets to the current working directory.
# sorted() makes the file order deterministic. The assert fails loudly instead of
# handing an empty list to aggregate_score_files when nothing was evaluated yet.
json_files = sorted(glob.glob("*.json"))
assert json_files, "no *.json score files found in the working directory; run the evaluation cell first"

df = nussl.evaluation.aggregate_score_files(
    json_files, aggregator=np.nanmedian)
# Attach the aggregated scores to the model; presumably these end up in the
# checkpoint metadata ('evaluation') inspected after re-saving below.
nussl.evaluation.associate_metrics(separator.model, df, test_dataset)
report_card = nussl.evaluation.report_card(
    df, report_each_source=True)
print(report_card)

In [None]:
# Write the separator's model back over the best checkpoint, presumably so the
# saved file also carries the metrics attached by associate_metrics above.
best_checkpoint_path = 'checkpoints/best.model.pth'
separator.model.save(best_checkpoint_path)

In [None]:
# Reload the re-saved checkpoint to inspect its metadata. map_location=DEVICE lets
# a checkpoint whose tensors were saved from a GPU load on a CPU-only machine.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which may reject the
# non-tensor metadata stored here. If loading fails, pass weights_only=False
# (only for checkpoints you trust).
model_checkpoint = torch.load('checkpoints/best.model.pth', map_location=DEVICE)

In [None]:
# List what metadata sections the checkpoint carries.
checkpoint_metadata = model_checkpoint['metadata']
checkpoint_metadata.keys()

In [None]:
# Show the evaluation results stored alongside the checkpoint.
saved_evaluation = model_checkpoint['metadata']['evaluation']
saved_evaluation