In [1]:
import torch
import torch.nn as nn
import nussl
from nussl.datasets import transforms as nussl_tfm
from nussl.ml.networks.modules import BatchNorm, RecurrentStack, Embedding, STFT, LearnedFilterBank, AmplitudeToDB
from models.MaskInference import MaskInference
from models.UNet import UNetSpect
from models.Filterbank import Filterbank
from utils import utils, data
from pathlib import Path
import yaml, argparse
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Load the YAML experiment configuration into the `configs` dictionary.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open('config/test_auto.yml', 'r') as f:
    configs = yaml.safe_load(f)

In [3]:
# Project helper from `utils`; presumably configures logging -- TODO confirm.
utils.logger()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# STFT settings come from the config. The Filterbank model is handed
# stft_params=None instead of the configured settings.
stft_params = nussl.STFTParams(**configs['stft_params'])
if configs['model_type'] == 'Filterbank':
    stft_params = None

# NOTE: this transform pipeline has to change for waveform-domain models.
tfm = nussl_tfm.Compose([
    # Combine 'bass', 'drums' and 'other' into a single summed source.
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
    nussl_tfm.MagnitudeSpectrumApproximation(),
    # Keep only index 1 of 'source_magnitudes'
    # (presumably the non-summed source -- TODO confirm which one it is).
    nussl_tfm.IndexSources('source_magnitudes', 1),
    nussl_tfm.ToSeparationModel(),
])

# NOTE: the dataset construction below has to change for waveform-domain models.
# NOTE(review): the *training* generator reads from configs['test_folder'];
# confirm this is intentional for this test configuration.
train_data = data.on_the_fly(
    stft_params,
    transform=tfm,
    fg_path=configs['test_folder'],
    **configs['train_generator_params']
)
train_dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=configs['batch_size'],
    num_workers=1,
)

val_data = data.on_the_fly(
    stft_params,
    transform=tfm,
    fg_path=configs['valid_folder'],
    **configs['valid_generator_params']
)
val_dataloader = torch.utils.data.DataLoader(
    val_data,
    batch_size=configs['batch_size'],
    num_workers=1,
)

In [4]:
# Name of the loss requested by the config file.
loss_type = configs['loss_type']

# Supported loss names mapped to nussl loss classes; 'L2' and 'MSE' both
# resolve to the same MSELoss class.
loss_dict = dict(
    L1=nussl.ml.train.loss.L1Loss,
    L2=nussl.ml.train.loss.MSELoss,
    MSE=nussl.ml.train.loss.MSELoss,
)

# Reject unknown loss names before instantiating the chosen class.
assert loss_type in loss_dict, f'Loss type must be one of {loss_dict.keys()}'
loss_fn = loss_dict[loss_type]()

In [5]:
def train_step(engine, batch):
    """Run one optimisation step on ``batch`` and report its loss.

    Uses the notebook-level ``model``, ``optimizer`` and ``loss_fn``.

    Args:
        engine: The engine invoking this step; not used here.
        batch: Mapping passed to the model; must contain 'source_magnitudes'.

    Returns:
        dict: ``{'loss': value}`` where ``value`` is the scalar loss for
        this batch as returned by ``.item()``.
    """
    # Clear gradients accumulated by the previous step.
    optimizer.zero_grad()

    # Forward pass: score the model's estimates against the target magnitudes.
    estimates = model(batch)['estimates']
    batch_loss = loss_fn(estimates, batch['source_magnitudes'])

    # Backward pass followed by the parameter update.
    batch_loss.backward()
    optimizer.step()

    return {'loss': batch_loss.item()}

def val_step(engine, batch):
    """Compute the loss for ``batch`` without updating the model.

    Uses the notebook-level ``model`` and ``loss_fn``.

    Args:
        engine: The engine invoking this step; not used here.
        batch: Mapping passed to the model; must contain 'source_magnitudes'.

    Returns:
        dict: ``{'loss': value}`` where ``value`` is the scalar loss for
        this batch as returned by ``.item()``.
    """
    # The forward pass runs with gradient tracking disabled.
    with torch.no_grad():
        estimates = model(batch)['estimates']

    batch_loss = loss_fn(estimates, batch['source_magnitudes'])
    return {'loss': batch_loss.item()}

In [6]:
# Model architecture requested by the config file.
model_type = configs['model_type']

# Supported architecture names and the classes that implement them.
model_dict = dict(
    Mask=MaskInference,
    UNet=UNetSpect,
    Filterbank=Filterbank,
)
assert model_type in model_dict, f'Model type must be one of {model_dict.keys()}'

# Build the chosen model; each architecture takes different positional arguments.
if model_type == 'Mask':
    # window_length // 2 + 1 is presumably the number of one-sided STFT
    # frequency bins -- TODO confirm against MaskInference.build.
    model = MaskInference.build(
        stft_params.window_length // 2 + 1,
        **configs['model_params']
    )
elif model_type == 'UNet':
    model = UNetSpect.build(**configs['model_params'])
elif model_type == 'Filterbank':
    # 128 is the first positional argument of Filterbank.build
    # (presumably the number of filters -- TODO confirm).
    model = Filterbank.build(128, **configs['model_params'])

# Move the model to the selected device, then create its optimizer.
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), **configs['optimizer_params'])

In [7]:
# Build the nussl training and validation engines around the step functions.
trainer, validator = nussl.ml.train.create_train_and_validation_engines(
    train_step,
    val_step,
    device=device,
)

# Absolute path of the folder that checkpoints are written to.
checkpoint_folder = Path('models/' + configs['save_name']).absolute()

# Attach nussl's handlers in this order: stdout reporting, then
# validation + checkpointing, then the progress bar.
nussl.ml.train.add_stdout_handler(trainer, validator)
nussl.ml.train.add_validate_and_checkpoint(
    checkpoint_folder,
    model,
    optimizer,
    train_data,
    trainer,
    val_dataloader,
    validator,
)
nussl.ml.train.add_progress_bar_handler(trainer, validator)

  from tqdm.autonotebook import tqdm


In [11]:
# trainer.run(train_dataloader, **configs['train_params'])

In [12]:
# Pull a single batch from the training loader for interactive inspection.
# next(iter(...)) raises StopIteration when the loader is empty, instead of
# silently leaving `batch` undefined (or stale from an earlier run) the way
# a `for ... break` loop would.
batch = next(iter(train_dataloader))

# Cast every tensor entry to float and move it to the selected device;
# non-tensor entries are left untouched.
for key in batch:
    if torch.is_tensor(batch[key]):
        batch[key] = batch[key].float().to(device)

In [15]:
# Display the dimensions of the target magnitude tensor in the sampled batch.
batch['source_magnitudes'].size()

torch.Size([10, 1724, 257, 1, 1])