In [None]:
# --- Imports (all of them, so the notebook survives Restart & Run All) ---
# Standard library
from pathlib import Path

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import nussl
import scaper
import torch
import torch.nn as nn
import torchaudio
import tqdm
from IPython.display import Audio, display
from nussl.datasets import transforms as nussl_tfm
from torch.utils.data import DataLoader

# Helpers from https://github.com/source-separation/tutorial
from common import data
from common import viz

# --- Configuration ---
# Root of the MUSDB18 dataset. Note that "~" is NOT expanded in this string.
dataset_path = "~/userdata/datasets/musdb"

# Seed the global RNGs so PyTorch weight initialisation, DataLoader shuffling and
# any NumPy-based randomness are repeatable between runs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)


## 1. Load dataset

In [None]:
dataset_path = "~/userdata/datasets/musdb"
musdb_train = nussl.datasets.MUSDB18(dataset_path, subsets='train', split="train")
musdb_valid = nussl.datasets.MUSDB18(dataset_path, subsets='train', split="valid")
musdb_test = nussl.datasets.MUSDB18(dataset_path, subsets='test')
len(musdb_train), len(musdb_valid), len(musdb_test)

In [None]:
train_item = musdb_train[0]

In [None]:
train_item['mix'].audio_data

In [None]:
train_item['sources']

In [None]:
# Dataset is usually subscriptable,
display(Audio(musdb_train[0]['mix'].audio_data, rate=44100))

In [None]:
display(Audio(train_item['sources']['vocals'].audio_data, rate=44100))

In [None]:
# Adapted from the tutorial's common.data.prepare_data: write every source of
# every song to <dataset_path>/<split>/<source>/<song>.wav, so that
# data.on_the_fly can build new mixtures from these files in the cells below.
splits = {"train": musdb_train, "valid": musdb_valid, "test": musdb_test}
for split_name, musdb in splits.items():
  # pathlib does not expand "~" by itself.
  split_folder = (Path(dataset_path) / split_name).expanduser()
  split_folder.mkdir(parents=True, exist_ok=True)
  for item in tqdm.tqdm(musdb, desc=split_name):
    song_name = item['mix'].file_name
    for source_name, source_signal in item['sources'].items():
      source_folder = split_folder / source_name
      source_folder.mkdir(exist_ok=True)
      source_signal.write_audio_to_file(str(source_folder / (song_name + '.wav')))


In [None]:
# data.on_the_fly (tutorial helper) builds a dataset whose items are new mixtures
# generated from the per-source wav files under fg_path (written by the cell above).
# With transform=None the items are not transformed, so the sources can be plotted directly.

# STFT settings (the same values are used again for the model further below).
stft_params = nussl.STFTParams(window_length=512, hop_length=128, window_type='sqrt_hann')
duration = 10  # length of each generated mixture, presumably seconds — TODO confirm

trainset = data.on_the_fly(stft_params, transform=None, fg_path=dataset_path+"/train", num_mixtures=500, duration=duration)
item = trainset[0]
viz.show_sources(item['sources'])

In [None]:
# Metadata describing how this mixture was generated (presumably the scaper JAMS annotation).
print(item['metadata']['jam'])

In [None]:
# Inspect the structure of a single untransformed item.
print(item)
print(item.keys())

### Transform Data
- The `nussl.core.AudioSignal` items must be transformed into the format the model expects:
    1. We want a one-vs-all separation system, so all sources except the target are combined into a single source.
        - For a vocals separator, mix drums, bass, and other into a single source.
        - For a drums separator, mix vocals, bass, and other into a single source.
    2. We want to use magnitude spectrograms instead of raw waveform samples.


In [None]:
from nussl.datasets import transforms as nussl_tfm

item = trainset[0]
# Sum vocals, drums and other into one source, leaving bass on its own
# (the one-vs-all setup described above, here with bass as the target).
sum_sources = nussl_tfm.SumSources([['vocals', 'drums', 'other']])
transformed_item = sum_sources(item)
print(transformed_item['sources'])
viz.show_sources(transformed_item['sources'])

In [None]:
# MagnitudeSpectrumApproximation adds 'source_magnitudes': the magnitude
# spectrograms (np.abs of the STFT) of the sources, stacked on the last axis.
msa = nussl_tfm.MagnitudeSpectrumApproximation()

item = trainset[0]

transformed_item = msa(item)
print(transformed_item.keys())
source_magnitudes = transformed_item['source_magnitudes']
print(source_magnitudes.shape)

# One log-magnitude panel per source. The loop covers however many sources are
# stacked, instead of one copy-pasted call per source.
n_sources = source_magnitudes.shape[-1]
fig, axes = plt.subplots(n_sources, 1, figsize=(10, 5 * n_sources), squeeze=False)
for source_idx, ax in enumerate(axes[:, 0]):
  ax.imshow(np.log10(source_magnitudes[..., source_idx]), origin='lower', aspect='auto')
  ax.set_title("source index {}".format(source_idx))

In [None]:
# Keep only source index 0 of the stacked 'source_magnitudes' produced by MSA.
index_sources = nussl_tfm.IndexSources('source_magnitudes', 0)
transformed_item = index_sources(msa(item))
print(transformed_item['source_magnitudes'].shape)
plt.imshow(np.log10(transformed_item['source_magnitudes'].squeeze()), origin='lower', aspect='auto')


In [None]:
# Check the container type before conversion for the model.
type(transformed_item['source_magnitudes'])

In [None]:
# Convert the item into the format nussl separation models consume
# (presumably torch tensors — compare with the type checked above).
to_tensor = nussl_tfm.ToSeparationModel()
item = trainset[0]
transformed_item = to_tensor(index_sources(msa(item)))
print(transformed_item.keys())

In [None]:
transformed_item['source_magnitudes'].shape

In [None]:
# Chain all transforms: sum bass+drums+other into one source, compute magnitude
# spectrograms, keep source index 1, then convert for the separation model.
tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
    nussl_tfm.MagnitudeSpectrumApproximation(),
    nussl_tfm.IndexSources('source_magnitudes', 1),
    nussl_tfm.ToSeparationModel(),
])
item = trainset[0]
# Compare the value types of every key before and after the pipeline.
print("Before transforms")
for key in item:
    print(key, type(item[key]))
print("\nAfter transforms")
item = tfm(item)
for key in item:
    print(key, type(item[key]))

In [None]:
stft_params = nussl.STFTParams(window_length=512, hop_length=128, window_type='sqrt_hann')
tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),
    nussl_tfm.MagnitudeSpectrumApproximation(),
    nussl_tfm.IndexSources('source_magnitudes', 1),
    nussl_tfm.ToSeparationModel(),
])
duration = 5
trainset = trainset = data.on_the_fly(stft_params, 
                                      transform=tfm, 
                                      fg_path=dataset_path+"/train", 
                                      num_mixtures=10000000,
                                      time_stretch=None,
                                      duration=duration)
item = trainset[0]
print(item.keys())

In [None]:
len(trainset)

In [None]:
validset = data.on_the_fly(stft_params, transform=tfm, fg_path=dataset_path+"/valid", num_mixtures=64,time_stretch=None, duration=duration)
testset = data.on_the_fly(stft_params, transform=tfm, fg_path=dataset_path+"/test", num_mixtures=32,time_stretch=None, duration=duration)

In [None]:
stft_params.window_length

## 2. Design the model

In [None]:
import torch
import torch.nn as nn
import torchaudio

In [None]:
class Separator(nn.Module):
  """Mask-based single-target separator operating on magnitude spectrograms.

  The mixture magnitude is converted to dB and batch-normalised per frequency
  bin. It then goes through a bidirectional LSTM over time and is projected to a
  sigmoid mask, which is multiplied with the input magnitude.

  Args:
    num_freq: number of frequency bins (window_length // 2 + 1 for the STFT used here).
    hidden_size: LSTM hidden size per direction.
    num_layers: number of stacked LSTM layers (default 3).
    dropout: dropout between LSTM layers (default 0.3; no effect when num_layers == 1).
  """
  def __init__(self, num_freq, hidden_size, num_layers=3, dropout=0.3):
    super().__init__()
    self.amp_to_db = torchaudio.transforms.AmplitudeToDB(stype='magnitude')
    # BatchNorm2d normalises over its channel axis; forward() moves frequency there.
    self.batch_norm = nn.BatchNorm2d(num_freq)
    self.rnn = nn.LSTM(input_size=num_freq, hidden_size=hidden_size, num_layers=num_layers,
                       bidirectional=True, batch_first=True, dropout=dropout)
    # The bidirectional LSTM emits 2 * hidden_size features per time step.
    self.linear = nn.Linear(hidden_size * 2, num_freq)

  def forward(self, x):
    """Estimate the target magnitude spectrogram.

    Args:
      x: mixture magnitude laid out as (batch, time, freq, channels).

    Returns:
      dict with 'mask' of shape (batch, time, freq, 1), values in [0, 1], and
      'estimation' = mask applied to every channel of x.
    """
    x = x.float()
    db_spec = self.amp_to_db(x)
    # (B, T, F, C) -> (B, F, T, C) so frequency is BatchNorm2d's channel axis.
    norm_spec = self.batch_norm(db_spec.permute(0, 2, 1, 3))
    # Back to (B, T, F, C); only the first channel is fed to the LSTM.
    norm_spec = norm_spec.permute(0, 2, 1, 3)[..., 0]

    hidden, _ = self.rnn(norm_spec)
    mask = self.linear(hidden).sigmoid().unsqueeze(-1)
    masked_output = x * mask

    return {'mask': mask, 'estimation': masked_output}

model = Separator(num_freq=stft_params.window_length//2+1, hidden_size=32)


In [None]:
# Toy LSTM (1 input feature, 2 hidden units, 1 layer) to illustrate what nn.LSTM returns.
rnn_example = nn.LSTM(input_size=1, hidden_size=2, num_layers=1)

In [None]:
# Values 0..11 shaped (1, 12, 1).
# NOTE(review): nn.LSTM defaults to batch_first=False, so this is read as
# seq_len=1, batch=12, features=1 — not as a 12-step sequence.
input_dummy = torch.arange(12).view(1,-1,1).float()
input_dummy

In [None]:
# output: top-layer hidden state at every step; hidden_states: the (h_n, c_n) tuple.
output, hidden_states = rnn_example(input_dummy)
output

In [None]:
h_state, c_state = hidden_states
h_state, c_state

In [None]:
train_loader = DataLoader(trainset, shuffle=True, batch_size=16, num_workers=4)
valid_loader = DataLoader(validset, batch_size=32)
# batch = next(iter(train_loader))

In [None]:
batch.keys()

In [None]:
mix_spec = batch['mix_magnitude']
model(mix_spec)

In [None]:
batch = next(iter(train_loader))

In [None]:
batch['mix_magnitude'].shape

In [None]:
self = model
x = batch['mix_magnitude']

db_spec = self.amp_to_db(x.float())
db_spec = db_spec.permute(0,2,1,3)
norm_spec = self.batch_norm(db_spec)
norm_spec = norm_spec.permute(0,2,1,3)[..., 0]

hidden, _ = self.rnn(norm_spec)
mask = self.linear(hidden).sigmoid().unsqueeze(-1)
masked_output = x.float() * mask

In [None]:
# Input mixture magnitude of the first example; transpose puts frequency on the vertical axis.
plt.imshow(x[0].transpose(0,1), aspect='auto', origin='lower')

In [None]:
# dB spectrogram; db_spec is already (batch, freq, time, channel) after the permute.
plt.imshow(db_spec[0], aspect='auto', origin='lower')

In [None]:
# Batch-normalised LSTM input (freq x time).
plt.imshow(norm_spec[0].detach().permute(1,0), aspect='auto', origin='lower')

In [None]:
# LSTM output features (2 * hidden_size rows) over time.
plt.imshow(hidden[0].detach().permute(1,0), aspect='auto', origin='lower')

In [None]:
# Predicted sigmoid mask.
plt.imshow(mask[0].detach().permute(1,0,2), aspect='auto', origin='lower')

In [None]:
# Masked mixture, i.e. the model's estimate of the target magnitude.
plt.imshow(masked_output[0].detach().permute(1,0,2), aspect='auto', origin='lower')

In [None]:
# Ground-truth target magnitude for the same example.
plt.imshow(batch['source_magnitudes'][0][...,0].transpose(0,1), aspect='auto', origin='lower')

In [None]:
# Difference between target and estimate (linear magnitude).
diff= batch['source_magnitudes'][0][...,0].transpose(0,1) - masked_output[0].detach().permute(1,0,2)
plt.imshow(diff, aspect='auto', origin='lower')

In [None]:
def spec_l1_loss(pred, target):
  """Mean absolute error between a predicted and a target spectrogram (broadcasting allowed)."""
  abs_err = (pred - target).abs()
  return abs_err.mean()

In [None]:
def train_loop(model, optimizer, train_loader, valid_loader, loss_func, num_iter, valid_iter, device):
  """Train `model` for a fixed number of steps, validating every `valid_iter` steps.

  Args:
    model: module whose forward takes a batch of mixture magnitudes and returns a
      dict containing an 'estimation' tensor.
    optimizer: optimizer over `model`'s parameters.
    train_loader: iterable of batches (dicts with 'mix_magnitude' and
      'source_magnitudes'). If it is exhausted before `num_iter` steps, a new
      pass over it is started instead of aborting with StopIteration.
    valid_loader: DataLoader of validation batches with the same keys. Its
      `.dataset` must support `len()`.
    loss_func: callable `(prediction, target) -> scalar tensor`. It is assumed
      to be a batch mean, which the validation weighting below relies on.
    num_iter: number of optimisation steps.
    valid_iter: validation period in steps; validation also runs at step 0.
    device: device the model and the batches are moved to.

  Returns:
    dict with 'train' (the loss of every step) and 'valid' (the per-sample mean
    validation loss of every validation run).
  """
  model = model.to(device)
  train_loss_record = []
  valid_loss_record = []
  model.train()
  iter_train_loader = iter(train_loader)
  for itr in tqdm.tqdm(range(num_iter)):
    try:
      batch = next(iter_train_loader)
    except StopIteration:
      # Loader exhausted: start another pass instead of crashing mid-training.
      iter_train_loader = iter(train_loader)
      batch = next(iter_train_loader)
    optimizer.zero_grad()
    pred = model(batch['mix_magnitude'].to(device))
    # Supervise against index 0 of the last axis of 'source_magnitudes'.
    loss = loss_func(pred['estimation'], batch['source_magnitudes'][..., 0].to(device))
    loss.backward()
    # Clip the total gradient norm to 0.5 before the update.
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    train_loss_record.append(loss.item())
    if itr % valid_iter == 0:
      model.eval()
      valid_loss = 0.0
      with torch.no_grad():
        for valid_batch in valid_loader:
          pred = model(valid_batch['mix_magnitude'].to(device))
          loss = loss_func(pred['estimation'], valid_batch['source_magnitudes'][..., 0].to(device))
          # Accumulate (not overwrite): weight each batch mean by its size so the
          # division below yields a per-sample mean over the whole validation set.
          valid_loss += loss.item() * len(valid_batch['mix_magnitude'])
      valid_loss_record.append(valid_loss / len(valid_loader.dataset))
      model.train()
  return {'train': train_loss_record, 'valid': valid_loss_record}

    
train_loader = DataLoader(trainset, shuffle=True, batch_size=32, num_workers=4)
valid_loader = DataLoader(validset, batch_size=32, num_workers=0)

# Fall back to the CPU when no GPU is available so the cell still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Separator(num_freq=stft_params.window_length//2+1, hidden_size=256)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
result = train_loop(model, optimizer, train_loader, valid_loader, spec_l1_loss,
                    num_iter=100, valid_iter=50, device=device)
# Save from the CPU so the checkpoint can be loaded without a GPU.
model = model.to('cpu')
torch.save(model.state_dict(), 'vocal_separator_lstm.pt')

In [None]:
torch.save(model.state_dict(), 'vocal_separator_lstm.pt')

In [None]:
ckpt = torch.load('vocal_separator_lstm_large.pt')

In [None]:
model.load_state_dict(ckpt)

In [None]:
plt.plot(result['train'])
# plt.plot(list(range(0, 10000, 200)), result['valid'])

In [None]:
len(result['train'])

## 3. Test on custom audio

In [None]:
# audio_path = "01 범 내려온다_Tiger is Coming.wav"
audio_path = "/home/teo/userdata/datasets/musdb/test/Zeno - Signs.stem.mp4"
audio_signal = nussl.AudioSignal(audio_path)

audio_signal.stft_params = stft_params

In [None]:
spec = audio_signal.to_mono().stft()
magnitude_spec = np.abs(spec)
input_tensor = torch.Tensor(magnitude_spec).float()
input_tensor = torch.stack([input_tensor, input_tensor], dim=0).permute(0,2,1,3)
print(input_tensor.shape)
model.eval()
model.to('cuda')
with torch.no_grad():
  result = model(input_tensor.to('cuda'))

In [None]:
# Predicted mask for the song, frequency on the vertical axis.
plt.imshow(result['mask'][0].cpu().permute(1,0,2), aspect='auto', origin='lower')

In [None]:
# Transpose back to the layout of `spec` (same axes as the STFT computed above),
# then reuse the mixture phase to form a complex STFT for the estimate.
masked_spec = result['estimation'][0].cpu().numpy().transpose(1,0,2)
masked_spec = masked_spec* np.exp(1j * np.angle(spec))

print(masked_spec.shape)
# Wrap the estimated STFT in an AudioSignal and invert it back to a waveform.
recon_signal = nussl.AudioSignal(stft=masked_spec, sample_rate=audio_signal.sample_rate, stft_params=stft_params)
recon_audio = recon_signal.istft()

In [None]:
display(Audio(recon_audio, rate=recon_signal.sample_rate))

In [None]:
mix_spec = batch['mix_magnitude']
mix_spec = torchaudio.transforms.AmplitudeToDB()(mix_spec)
print(mix_spec.shape) # N T F C
mix_spec = mix_spec.permute(0,2,1,3) # N F T C
nn.BatchNorm2d(257)(mix_spec)

In [None]:
batch['source_magnitudes'].shape