In [1]:
import torch
import torch.nn as nn
import nussl
from nussl.datasets import transforms as nussl_tfm
from nussl.ml.networks.modules import BatchNorm, RecurrentStack, Embedding, STFT, LearnedFilterBank, AmplitudeToDB
from models.MaskInference import MaskInference
from models.UNet import UNetSpect
from models.Filterbank import Filterbank
from utils import utils, data, viz
from pathlib import Path
import yaml, argparse
import numpy as np
import matplotlib.pyplot as plt
import tqdm

In [2]:
# Load the YAML experiment config into the `configs` dictionary.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open('config/test_filterbank.yml', 'r') as f:
    configs = yaml.safe_load(f)

In [3]:
utils.logger()
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Map the config's `model_type` string to a model class. Models listed in
# `waveform_models` take raw audio rather than STFT magnitudes.
model_type = configs['model_type']
model_dict = {
    'Mask': MaskInference,
    'UNet': UNetSpect,
    'Filterbank': Filterbank,
}
waveform_models = ['Filterbank']
assert model_type in model_dict, f'Model type must be one of {model_dict.keys()}'

# Choose the features, the training target and the model output key.
# Both branches sum bass/drums/other into one source, pick source index 1
# as the target, and convert the item for the separation model.
if model_type in waveform_models:
    # Waveform models work on raw audio, so no STFT parameters are needed.
    stft_params = None
    target_key, output_key = 'source_audio', 'audio'
    feature_tfms = [nussl_tfm.GetAudio(), nussl_tfm.IndexSources(target_key, 1)]
else:
    stft_params = nussl.STFTParams(**configs['stft_params'])
    target_key, output_key = 'source_magnitudes', 'estimates'
    feature_tfms = [nussl_tfm.MagnitudeSpectrumApproximation(), nussl_tfm.IndexSources(target_key, 1)]

tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
    *feature_tfms,
    nussl_tfm.ToSeparationModel(),
])


# Shrink the experiment for overfitting: batch size 1, ten training mixtures
# and a single validation mixture of length `duration` (presumably seconds).
configs['batch_size'] = 1
configs['train_generator_params'].update(num_mixtures=10)
configs['valid_generator_params'].update(num_mixtures=1)

duration = 5

# Both splits use the same transform and foreground folder.
# NOTE(review): training data comes from configs['test_folder'] — confirm intended.
shared_data_kwargs = dict(transform=tfm, fg_path=configs['test_folder'], duration=duration)
shared_loader_kwargs = dict(num_workers=1, batch_size=configs['batch_size'])

train_data = data.on_the_fly(stft_params, **shared_data_kwargs, **configs['train_generator_params'])
train_dataloader = torch.utils.data.DataLoader(train_data, **shared_loader_kwargs)

val_data = data.on_the_fly(stft_params, **shared_data_kwargs, **configs['valid_generator_params'])
val_dataloader = torch.utils.data.DataLoader(val_data, **shared_loader_kwargs)

In [4]:
overfit_selection = 1  # index of the training batch (and test item) to overfit on

In [5]:
# Resolve the configured loss name to a nussl loss class; 'L2' and 'MSE'
# are aliases for the same MSE loss.
loss_type = configs['loss_type']
loss_dict = dict(
    L1=nussl.ml.train.loss.L1Loss,
    L2=nussl.ml.train.loss.MSELoss,
    MSE=nussl.ml.train.loss.MSELoss,
)
assert loss_type in loss_dict, f'Loss type must be one of {loss_dict.keys()}'
loss_fn = loss_dict[loss_type]()

In [6]:
def train_step(engine, batch):
    """Run one optimisation step on ``batch`` and return its loss.

    Passed to nussl's ``create_train_and_validation_engines`` as the training
    function, and also called directly in the overfitting loop.

    Args:
        engine: the engine driving training (unused here).
        batch: dict holding at least ``target_key`` plus whatever the model's
            forward pass reads.

    Returns:
        dict with a single ``'loss'`` float.
    """
    # val_step switches the model to eval mode; switch back so dropout /
    # batch-norm layers behave as in training.
    model.train()
    optimizer.zero_grad()

    # Forward pass
    output = model(batch)
    loss = loss_fn(output[output_key], batch[target_key])

    # Backward pass
    loss.backward()
    optimizer.step()

    return {'loss': loss.item()}


def val_step(engine, batch):
    """Evaluate the model on ``batch`` without updating any weights.

    Args:
        engine: the engine driving validation (unused here).
        batch: dict holding at least ``target_key`` plus whatever the model's
            forward pass reads.

    Returns:
        dict with a single ``'loss'`` float.
    """
    # Eval mode disables dropout and uses running batch-norm statistics, so
    # validation loss is deterministic and not biased by training behaviour.
    model.eval()
    with torch.no_grad():
        output = model(batch)
        loss = loss_fn(output[output_key], batch[target_key])
    return {'loss': loss.item()}

In [7]:
# Set up the model selected by `model_type`, and an Adam optimizer for it.
# The explicit else stops a stale `model` from an earlier run being reused
# silently when `model_type` matches none of the branches.
if model_type == 'Mask':
    # window_length//2 + 1 is presumably the one-sided STFT bin count the
    # mask network is sized for — confirm against MaskInference.build.
    model = MaskInference.build(stft_params.window_length//2+1, **configs['model_params']).to(device)
elif model_type == 'UNet':
    model = UNetSpect.build(**configs['model_params']).to(device)
elif model_type == 'Filterbank':
    model = Filterbank.build(**configs['model_params']).to(device)
else:
    raise ValueError(f"Unsupported model_type {model_type!r}; expected 'Mask', 'UNet' or 'Filterbank'")

optimizer = torch.optim.Adam(model.parameters(), **configs['optimizer_params'])

In [8]:
# Take batch number `overfit_selection` from the training loader; the model
# is overfit on this single batch.
for batch_index, candidate in enumerate(train_dataloader):
    if batch_index == overfit_selection:
        batch = candidate
        break
else:
    # Without this, an out-of-range index would silently leave `batch` as the
    # last batch produced by the loop.
    raise IndexError(f'overfit_selection={overfit_selection} is out of range for train_dataloader')

# Cast tensor entries to float and move them to `device`; other entries pass through.
batch = {
    key: value.float().to(device) if torch.is_tensor(value) else value
    for key, value in batch.items()
}

In [9]:
# Raise the learning rate for the overfitting check, and rebuild Adam so it
# starts from fresh moment estimates with the new rate.
configs['optimizer_params'].update(lr=1e-1)
optimizer = torch.optim.Adam(
    model.parameters(),
    **configs['optimizer_params'],
)

In [10]:
# Wrap the step functions in nussl's training and validation engines.
trainer, validator = nussl.ml.train.create_train_and_validation_engines(
    train_step,
    val_step,
    device=device,
)

In [11]:
# Overfit the selected batch by calling train_step directly, outside the
# engine loop, and report the loss every 20 iterations.
N_ITERATIONS = 200
loss_history = []  # per-iteration training loss, for bookkeeping

for step in range(N_ITERATIONS):
    step_loss = train_step(trainer, batch)['loss']
    loss_history.append(step_loss)
    if step % 20 == 0:
        print(f'Loss: {step_loss:.6f} at iteration {step}')

Loss: 11.452027 at iteration 0
Loss: 0.185101 at iteration 20
Loss: 0.021633 at iteration 40
Loss: 0.014579 at iteration 60
Loss: 0.012950 at iteration 80
Loss: 0.012403 at iteration 100
Loss: 0.011581 at iteration 120



KeyboardInterrupt



In [None]:
# Short epochs and a near-zero learning rate: running the engine here
# presumably just writes checkpoints of the already-overfit weights.
configs['train_params'].update(epoch_length=2)
configs['optimizer_params'].update(lr=1e-10)
optimizer = torch.optim.Adam(model.parameters(), **configs['optimizer_params'])

# Model outputs and checkpoints are saved under ./overfit
checkpoint_folder = Path('overfit').absolute()

# nussl handlers: stdout logging, validate-and-checkpoint, progress bar.
# NOTE(review): re-running this cell likely registers each handler again.
nussl.ml.train.add_stdout_handler(trainer, validator)
nussl.ml.train.add_validate_and_checkpoint(
    checkpoint_folder, model, optimizer, train_data,
    trainer, val_dataloader, validator,
)
nussl.ml.train.add_progress_bar_handler(trainer, validator)

In [None]:
# Run the engine with the train_params from the config (epoch_length was overridden earlier).
trainer.run(
    train_dataloader,
    **configs['train_params'],
)

In [None]:
# Pull the target and the model's estimate for the overfit batch into NumPy.
# .cpu() is required before .numpy() when `device` is 'cuda'; no_grad avoids
# building an autograd graph for this inspection-only forward pass.
target = batch[target_key].detach().cpu().numpy()
with torch.no_grad():
    output = model(batch)
estimates = output[output_key].detach().cpu().numpy()

In [None]:
# Compare target and estimate ranges: max, min and std, target first each time.
for stat_fn in (np.max, np.min, np.std):
    for arr in (target, estimates):
        print(stat_fn(arr))

In [None]:
# Load the latest checkpoint into a nussl separator on the CPU.
# Waveform models use DeepAudioEstimation; all others use DeepMaskEstimation.
separator_cls = (
    nussl.separation.deep.DeepAudioEstimation
    if model_type in waveform_models
    else nussl.separation.deep.DeepMaskEstimation
)
separator = separator_cls(
    nussl.AudioSignal(), model_path='overfit/checkpoints/latest.model.pth',
    device='cpu',
)

In [None]:
# Build a mixture to run through the separator. transform=None keeps the raw
# item, so `['mix']` is presumably an AudioSignal rather than model-ready tensors.
# (A previous data.mixer(...) call here was overwritten immediately, and it
# clobbered the global `tfm`, so it has been removed.)
# NOTE(review): whether item `overfit_selection` matches the mixture the model
# was overfit on depends on on_the_fly's seeding — confirm.
test_folder = configs['test_folder']
test_data = data.on_the_fly(stft_params, transform=None, fg_path=test_folder, **configs['train_generator_params'], duration=duration)

signal = test_data[overfit_selection]['mix']

In [None]:
# Separate the mixture. Whatever the model did not assign to the estimated
# source is appended as a residual, then everything is plotted.
separator.audio_signal = signal
estimates = separator()
estimates += [signal - estimates[0]]
viz.show_sources(estimates)