In [1]:
import nussl
import torch
import yaml
from nussl.datasets import transforms as nussl_tfm
from models.MaskInference import MaskInference
from utils import data, viz, utils
from pathlib import Path

In [2]:
# Load the YAML experiment configuration into the `configs` dictionary.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open('config/stft_mask.yml', 'r') as f:
    configs = yaml.safe_load(f)

In [3]:
# Configure logging via the project helper and pick the compute device.
utils.logger()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# STFT settings come straight from the 'stft_params' section of the config.
stft_params = nussl.STFTParams(**configs['stft_params'])

# Transform pipeline applied to every generated item:
#   1. SumSources groups 'bass', 'drums' and 'other' together,
#   2. MagnitudeSpectrumApproximation adds the magnitude targets,
#   3. IndexSources keeps index 1 of 'source_magnitudes'
#      (presumably the vocals stem -- TODO confirm against the source ordering),
#   4. ToSeparationModel converts the item for consumption by the model.
tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
    nussl_tfm.MagnitudeSpectrumApproximation(),
    nussl_tfm.IndexSources('source_magnitudes', 1),
    nussl_tfm.ToSeparationModel(),
])

# The training generator previously always read from configs['test_folder'],
# which trains the model on the held-out test split. Prefer a dedicated
# 'train_folder' entry when the config provides one; fall back to the old
# value so configs without that key keep working.
# NOTE(review): confirm the config defines 'train_folder'.
if 'train_folder' in configs:
    train_folder = configs['train_folder']
else:
    train_folder = configs['test_folder']

train_data = data.on_the_fly(stft_params, transform=tfm, fg_path=train_folder, **configs['train_generator_params'])
train_dataloader = torch.utils.data.DataLoader(train_data, num_workers=1, batch_size=configs['batch_size'])

val_data = data.on_the_fly(stft_params, transform=tfm, fg_path=configs['valid_folder'], **configs['valid_generator_params'])
val_dataloader = torch.utils.data.DataLoader(val_data, num_workers=1, batch_size=configs['batch_size'])

In [5]:
# L1 loss between the model's estimates and the target source magnitudes.
loss_fn = nussl.ml.train.loss.L1Loss()

def train_step(engine, batch):
    """Run one optimisation step on a single batch.

    Relies on the notebook-level ``model``, ``optimizer`` and ``loss_fn``.

    Args:
        engine: The engine driving the loop (unused here).
        batch: Mapping passed to the model; must contain 'source_magnitudes'.

    Returns:
        dict with the scalar loss under both 'loss_L1' and 'loss'.
    """
    # val_step switches the model to eval mode, so restore train mode here.
    model.train()
    optimizer.zero_grad()

    #Forward pass
    output = model(batch)
    loss = loss_fn(output['estimates'], batch['source_magnitudes'])

    #Backward pass
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    return {'loss_L1': loss_value, 'loss': loss_value}

def val_step(engine, batch):
    """Compute the loss on a single validation batch without updating weights.

    Relies on the notebook-level ``model`` and ``loss_fn``.

    Args:
        engine: The engine driving the loop (unused here).
        batch: Mapping passed to the model; must contain 'source_magnitudes'.

    Returns:
        dict with the scalar loss under both 'loss_L1' and 'loss'.
    """
    # Evaluate with inference behaviour (e.g. dropout disabled) so the
    # validation loss is not perturbed by train-mode layers.
    model.eval()
    with torch.no_grad():
        output = model(batch)
        loss = loss_fn(output['estimates'], batch['source_magnitudes'])

    loss_value = loss.item()
    return {'loss_L1': loss_value, 'loss': loss_value}

In [6]:
#Set up the model and optimizer
# The first positional argument is derived from the STFT window length
# (presumably the number of frequency bins -- TODO confirm against MaskInference.build).
num_features = stft_params.window_length // 2 + 1

# `device` is handed to the engines below; the model has to live on the same
# device as the batches it receives, so move it there before the optimizer
# collects its parameters. On CPU this is a no-op.
model = MaskInference.build(num_features, **configs['model_params']).to(device)
optimizer = torch.optim.Adam(model.parameters(), **configs['optimizer_params'])

# Create nussl ML engine
trainer, validator = nussl.ml.train.create_train_and_validation_engines(train_step, val_step, device=device)

# Folder handed to nussl for saving checkpoints (resolved to an absolute path)
checkpoint_folder = Path('models').absolute()

# Adding handlers from nussl that print out details about model training
# run the validation step, and save the models.
nussl.ml.train.add_stdout_handler(trainer, validator)
nussl.ml.train.add_validate_and_checkpoint(checkpoint_folder, model, optimizer, train_data, trainer, val_dataloader, validator)

In [6]:
# Start training on the training dataloader; the keyword arguments for the
# run are taken from the 'train_params' section of the YAML config.
run_kwargs = configs['train_params']
trainer.run(train_dataloader, **run_kwargs)

04/21/2023 11:48:08 AM | engine.py:874 Engine run starting with max_epochs=10.
04/21/2023 11:48:18 AM | engine.py:874 Engine run starting with max_epochs=1.
04/21/2023 11:48:24 AM | engine.py:972 Epoch[1] Complete. Time taken: 00:00:06.579
04/21/2023 11:48:24 AM | engine.py:988 Engine run complete. Time taken: 00:00:06.621
04/21/2023 11:48:24 AM | trainer.py:311 

EPOCH SUMMARY 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 
- Epoch number: 0001 / 0010 
- Training loss:   0.000932 
- Validation loss: 0.000755 
- Epoch took: 0:00:16.165835 
- Time since start: 0:00:16.165910 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 
Saving to /SFS/user/ry/stonekev/audio/audio_isolation/models/checkpoints/checkpoints/best.model.pth. 
Output @ /SFS/user/ry/stonekev/audio/audio_isolation/models/checkpoints 

04/21/2023 11:48:24 AM | engine.py:972 Epoch[1] Complete. Time taken: 00:00:16.139
04/21/2023 11:48:27 AM | engine.py:874 Engine run starting with max_epochs=1.
04/21/2023 11:48:34 AM | engine.py:972 Epoch[1] Complete. Time t