In [1]:
import random
from pathlib import Path

import numpy as np
import nussl
import torch
from nussl.datasets import transforms as nussl_tfm

from models.Waveform import Waveform
from utils import utils, data

In [2]:
# One-time dataset preparation (presumably downloads MUSDB18-HQ and lays it out
# under data/musdb18hq/). Left commented out so Run All does not re-download.
# NOTE(review): musdb_root below is a machine-specific absolute path — adjust it
# for your environment before uncommenting.
#data.prepare_musdbhq(folder='data/musdb18hq/',musdb_root='/SFS/user/ry/stonekev/.nussl/',download=True)

In [3]:
utils.logger()
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MAX_MIXTURES = int(1e8)  # We'll set this to some impossibly high number for on the fly mixing.

# Seed the Python, NumPy and torch global RNGs up front so model initialisation
# (and, presumably, the on-the-fly mixer and DataLoader workers, which derive
# their seeds from torch) are reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

stft_params = nussl.STFTParams(window_length=512, hop_length=128)

# Training target: sum the accompaniment stems into a single source, keep the
# time-domain audio, and select source index 1 of 'source_audio'
# (presumably 'vocals' — TODO confirm the source ordering).
tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
    nussl_tfm.GetAudio(),
    nussl_tfm.IndexSources('source_audio', 1),
    nussl_tfm.ToSeparationModel(),
])

# Expand '~' explicitly instead of relying on the data helpers to do it.
train_folder = str(Path("~/audio_isolation/data/musdb18hq/train").expanduser())
# NOTE(review): the MUSDB18-HQ *test* split is used as the validation set,
# i.e. it drives checkpoint selection during training.
val_folder = str(Path("~/audio_isolation/data/musdb18hq/test").expanduser())

train_data = data.on_the_fly(stft_params, transform=tfm,
    fg_path=train_folder, num_mixtures=MAX_MIXTURES, coherent_prob=1.0)
train_dataloader = torch.utils.data.DataLoader(
    train_data, num_workers=1, batch_size=10)

val_data = data.on_the_fly(stft_params, transform=tfm,
    fg_path=val_folder, num_mixtures=10, coherent_prob=1.0)
val_dataloader = torch.utils.data.DataLoader(
    val_data, num_workers=1, batch_size=10)

In [4]:
# Draw one training item and look at the shape of its mixture waveform.
first_item = train_data[0]
sample = first_item['mix_audio']
sample.size()

torch.Size([1, 220500])

In [5]:
# Peek at the raw mixture waveform of the first training item.
sample

tensor([[ 0.0000,  0.0009,  0.0016,  ..., -0.0013, -0.0005,  0.0000]],
       dtype=torch.float64)

In [6]:
# Smoke test: push one training mixture (with a leading batch axis) through a
# freshly built Waveform model. The model is cast to float64 because the
# dataset yields float64 audio.
sample = train_data[0]['mix_audio'][None, ...]
sample_model = Waveform(1025, 1, 50, 2, True, 0.3, 1, 2048, 512)
sample_model = sample_model.double()
sample_out = sample_model(sample)
print(sample_out)

{'estimates': tensor([[[[-8.4832e-09],
          [ 9.1094e-04],
          [ 1.6035e-03],
          ...,
          [-1.2705e-03],
          [-5.4304e-04],
          [ 1.2505e-07]]]], dtype=torch.float64, grad_fn=<PermuteBackward0>)}


In [7]:
# Shape of the smoke-test estimates; the last axis is presumably the number of
# estimated sources — TODO confirm against models.Waveform.
sample_out['estimates'].shape

torch.Size([1, 1, 220500, 1])

In [8]:
# Show the module structure of the smoke-test model.
sample_model

Waveform(
  (stft): STFT()
  (amplitude_to_db): AmplitudeToDB()
  (input_normalization): BatchNorm(
    (batch_norm): BatchNorm1d(1025, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (recurrent_stack): RecurrentStack(
    (rnn): LSTM(1025, 50, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  )
  (embedding): Embedding(
    (linear): Linear(in_features=100, out_features=1025, bias=True)
  )
)

In [9]:
# Build the trainable model wrapped in nussl's SeparationModel. Going by the
# printed repr, 1025 is the BatchNorm/LSTM feature size, 50 the LSTM hidden
# size, 2 the number of LSTM layers, True bidirectional and 0.3 the dropout.
# The remaining args (1, 1, 2048, 512) are presumably audio channels, number of
# sources, FFT size and hop length — TODO confirm against models.Waveform.
model = Waveform.build(1025, 1, 50, 2, True, 0.3, 1, 2048, 512)
loss_fn = nussl.ml.train.loss.L1Loss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [11]:
# Show the SeparationModel wrapper produced by Waveform.build.
model

SeparationModel(
  (layers): ModuleDict(
    (model): Waveform(
      (stft): STFT()
      (amplitude_to_db): AmplitudeToDB()
      (input_normalization): BatchNorm(
        (batch_norm): BatchNorm1d(1025, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (recurrent_stack): RecurrentStack(
        (rnn): LSTM(1025, 50, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
      )
      (embedding): Embedding(
        (linear): Linear(in_features=100, out_features=1025, bias=True)
      )
    )
    (estimates): Alias()
  )
)
Number of parameters: 597175

In [12]:
def train_step(engine, batch):
    """Process function for the training engine: one optimisation step.

    Reads the module-level ``model``, ``optimizer`` and ``loss_fn``.

    Args:
        engine: the ignite engine invoking this step (unused).
        batch: dict from the training DataLoader; must contain the model's
            inputs and the ``'source_audio'`` target.

    Returns:
        dict with the scalar L1 loss under both ``'L1Loss'`` and ``'loss'``
        (presumably ``'loss'`` is what nussl's handlers report).
    """
    # val_step switches the model to eval mode, so switch back on every step.
    # Dropout (0.3 in the LSTM) and BatchNorm batch statistics must be active
    # while training.
    model.train()
    optimizer.zero_grad()
    output = model(batch)  # forward pass
    loss = loss_fn(output['estimates'], batch['source_audio'])
    loss.backward()  # backwards + gradient step
    optimizer.step()

    loss_value = loss.item()
    return {'L1Loss': loss_value, 'loss': loss_value}


def val_step(engine, batch):
    """Process function for the validation engine: loss without updates.

    Reads the module-level ``model`` and ``loss_fn``.

    Args:
        engine: the ignite engine invoking this step (unused).
        batch: dict from the validation DataLoader; must contain the model's
            inputs and the ``'source_audio'`` target.

    Returns:
        dict with the scalar L1 loss under both ``'L1Loss'`` and ``'loss'``.
    """
    # Eval mode disables dropout and freezes BatchNorm running statistics.
    # That makes the validation loss deterministic and stops validation
    # batches from leaking into the normalisation stats used at inference.
    model.eval()
    with torch.no_grad():
        output = model(batch)  # forward pass
        loss = loss_fn(output['estimates'], batch['source_audio'])
    loss_value = loss.item()
    return {'L1Loss': loss_value, 'loss': loss_value}

# Checkpoints and logs are written relative to this notebook's directory.
output_folder = Path('.').absolute()

# Wrap the step functions in nussl's ignite engines (DEVICE is passed through).
trainer, validator = nussl.ml.train.create_train_and_validation_engines(
    train_step, val_step, device=DEVICE)

# Print epoch summaries, and hook up validation on val_dataloader plus
# checkpointing of the model/optimizer into output_folder.
nussl.ml.train.add_stdout_handler(trainer, validator)
nussl.ml.train.add_validate_and_checkpoint(
    output_folder, model, optimizer, train_data,
    trainer, val_dataloader, validator)

# Short demo run: 2 epochs of 10 iterations each.
trainer.run(train_dataloader, max_epochs=2, epoch_length=10)

04/27/2023 07:52:32 PM | engine.py:874 Engine run starting with max_epochs=2.
04/27/2023 07:52:52 PM | engine.py:874 Engine run starting with max_epochs=1.
04/27/2023 07:52:58 PM | engine.py:972 Epoch[1] Complete. Time taken: 00:00:04.572
04/27/2023 07:52:58 PM | engine.py:988 Engine run complete. Time taken: 00:00:06.036
04/27/2023 07:53:03 PM | trainer.py:311 

EPOCH SUMMARY 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 
- Epoch number: 0001 / 0002 
- Training loss:   0.093904 
- Validation loss: 0.115035 
- Epoch took: 0:00:31.065096 
- Time since start: 0:00:31.065127 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 
Saving to /Users/dev/audio_isolation/checkpoints/best.model.pth. 
Output @ /Users/dev/audio_isolation 

04/27/2023 07:53:03 PM | engine.py:972 Epoch[1] Complete. Time taken: 00:00:26.213
04/27/2023 07:53:16 PM | engine.py:874 Engine run starting with max_epochs=1.
04/27/2023 07:53:21 PM | engine.py:972 Epoch[1] Complete. Time taken: 00:00:04.596
04/27/2023 07:53:21 PM | engine.py:988 Engine run comp

State:
	iteration: 20
	epoch: 2
	epoch_length: 10
	max_epochs: 2
	output: <class 'dict'>
	batch: <class 'dict'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
	epoch_history: <class 'dict'>
	iter_history: <class 'dict'>
	past_iter_history: <class 'dict'>
	saved_model_path: /Users/dev/audio_isolation/checkpoints/best.model.pth
	output_folder: <class 'pathlib.PosixPath'>

In [None]:
# Wrap the saved checkpoint in a nussl separator. It starts from an empty
# AudioSignal; the audio to separate is assigned later via `audio_signal`.
# NOTE(review): DeepMaskEstimation appears to be nussl's mask-based separator,
# while this model is trained to output waveforms directly ('estimates' vs.
# 'source_audio'). TODO confirm it is the right separator class for Waveform.
checkpoint_path = 'checkpoints/best.model.pth'
separator = nussl.separation.deep.DeepMaskEstimation(
    nussl.AudioSignal(),
    model_path=checkpoint_path,
    device=DEVICE,
)



In [None]:
from utils import viz

# Separate a single on-the-fly test mixture and plot the estimates.
# NOTE(review): when last run this cell raised
#   RuntimeError: Calculated padded input size per channel: (2). Kernel size: (3) ...
# Possible causes: the input the (mask-based) separator builds does not match what
# the waveform model expects, or the test mixture's channel count differs from
# the mono training data. TODO confirm.
test_folder = "~/audio_isolation/data/tutorial/test/"
test_data = data.mixer(
    stft_params,
    transform=None,
    fg_path=test_folder,
    num_mixtures=MAX_MIXTURES,
    coherent_prob=1.0,
)
item = test_data[0]
mix = item['mix']

separator.audio_signal = mix
estimates = separator()
# The model predicts one source only; whatever remains of the mixture is
# treated as the accompaniment and appended as the second estimate.
estimates.append(mix - estimates[0])

viz.show_sources(estimates)

RuntimeError: Calculated padded input size per channel: (2). Kernel size: (3). Kernel size can't be greater than actual input size

In [None]:
import json

# Evaluation transform: only merge the accompaniment stems, so each item's
# references are 'bass+drums+other' and 'vocals' (the keys used below).
tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
])
# Alternative: evaluate on the official MUSDB18 test split instead:
#   nussl.datasets.MUSDB18(subsets=['test'], transform=tfm)
test_dataset = data.mixer(stft_params, transform=tfm,
    fg_path=test_folder, num_mixtures=MAX_MIXTURES, coherent_prob=1.0)

# Just do 5 items for speed. Change to 50 for actual experiment.
N_EVAL_ITEMS = 5

# Score files go next to the notebook; the aggregation cell globs them from there.
output_folder = Path(output_folder).absolute()
output_folder.mkdir(exist_ok=True)

for i in range(N_EVAL_ITEMS):
    item = test_dataset[i]
    separator.audio_signal = item['mix']
    estimates = separator()

    source_keys = list(item['sources'].keys())
    # The model estimates vocals; the residual of the mixture is the accompaniment.
    estimates = {
        'vocals': estimates[0],
        'bass+drums+other': item['mix'] - estimates[0]
    }

    # Align references and estimates by source name.
    sources = [item['sources'][k] for k in source_keys]
    estimates = [estimates[k] for k in source_keys]

    evaluator = nussl.evaluation.BSSEvalScale(
        sources, estimates, source_labels=source_keys
    )
    scores = evaluator.evaluate()

    # Name each score file after the first reference source, via Path.stem rather
    # than str.replace('wav', 'json'). replace() also rewrites any 'wav' inside the
    # name and leaves extension-less names without '.json'. On-the-fly mixtures may
    # carry no file name at all (None), and random draws may repeat a song, so the
    # item index is appended to keep every file distinct.
    source_name = sources[0].file_name
    stem = Path(source_name).stem if source_name else 'mixture'
    output_file = output_folder / f'{stem}_{i:04d}.json'
    with open(output_file, 'w') as f:
        json.dump(scores, f, indent=4)

In [None]:
import glob
import numpy as np

# Collect the per-mixture score files written by the evaluation loop (in the
# working directory) in a stable, sorted order.
json_files = sorted(glob.glob("*.json"))
assert json_files, (
    "no *.json score files found in the working directory; "
    "run the evaluation cell first")

df = nussl.evaluation.aggregate_score_files(
    json_files, aggregator=np.nanmedian)
# Attach the aggregated metrics to the separator's model, presumably into its
# metadata (read back later under 'evaluation').
nussl.evaluation.associate_metrics(separator.model, df, test_dataset)
report_card = nussl.evaluation.report_card(
    df, report_each_source=True)
print(report_card)

In [None]:
# Re-save the separator's model to the training checkpoint path, overwriting it.
# This is presumably done so the metrics attached by associate_metrics are persisted
# (they are read back from metadata['evaluation'] further down).
separator.model.save('checkpoints/best.model.pth')

In [None]:
# Load the checkpoint only to inspect its metadata. map_location='cpu' keeps this
# working on a CPU-only machine even when the model was saved from a GPU
# (DEVICE may be 'cuda').
# NOTE: torch.load unpickles the file, so only load checkpoints you produced or trust.
model_checkpoint = torch.load('checkpoints/best.model.pth', map_location='cpu')

In [None]:
# List the top-level metadata fields stored in the checkpoint.
model_checkpoint['metadata'].keys()

In [None]:
# Show the evaluation results stored in the checkpoint's metadata.
model_checkpoint['metadata']['evaluation']