In [1]:
from wavenet_model import WaveNetModel

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

import os
from IPython.display import Audio, display

import stempeg
import numpy as np

from auraloss.time import SNRLoss, SISDRLoss, SDSDRLoss

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Pick the compute device once: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
def process_musdb(subset, root='musdb18', starts=(30, 60, 45), duration=10,
                  min_vocal_level=0.05, max_files=1):
    """Load short excerpts from stem files and keep those with audible vocals.

    For every selected file in ``<root>/<subset>`` one excerpt of ``duration``
    seconds is read at each offset in ``starts``. An excerpt is kept only when
    the mean of ``|vocals[0]| + |vocals[1]|`` reaches ``min_vocal_level``.

    Parameters
    ----------
    subset : str
        Either ``'train'`` or ``'test'``; names the sub-directory to read.
    root : str
        Directory that contains the ``train``/``test`` sub-directories.
    starts : sequence of numbers
        Start offsets handed to ``stempeg.read_stems`` (one excerpt per
        offset, in this order). The defaults are the offsets the notebook
        originally hard-coded.
    duration : number
        Excerpt length handed to ``stempeg.read_stems``.
    min_vocal_level : float
        Minimum mean absolute vocal level for an excerpt to be kept.
    max_files : int or None
        Maximum number of files to read; ``None`` reads every file.

    Returns
    -------
    tuple of three lists ``(mix, noise, vocals)``
        Parallel lists of ``(audio, sample_rate)`` tuples. ``mix`` is stem 0,
        ``noise`` is the sum of stems 1-3 and ``vocals`` is stem 4, each
        transposed (presumably to ``(channels, samples)`` -- confirm against
        the stempeg documentation).

    Raises
    ------
    AssertionError
        If ``subset`` is neither ``'train'`` nor ``'test'``.
    """
    assert subset in ['train', 'test']

    subset_dir = os.path.join(root, subset)

    # Hidden entries such as '.ipynb_checkpoints' are removed *before* the file
    # limit is applied; previously the limit could select only that entry and
    # the function silently returned nothing. Sorting makes the selection
    # reproducible, because os.listdir returns names in arbitrary order.
    filenames = sorted(name for name in os.listdir(subset_dir)
                       if not name.startswith('.'))
    if max_files is not None:
        filenames = filenames[:max_files]

    mix_out = []
    noise_out = []
    vocals_out = []

    for filename in filenames:
        path = os.path.join(subset_dir, filename)

        for start in starts:
            audio, sample_rate = stempeg.read_stems(path,
                                                    out_type=np.float32,
                                                    start=start,
                                                    duration=duration)

            vocals = audio[4].T

            # Skip excerpts in which the vocal stem is (nearly) silent.
            vocal_level = np.mean(abs(vocals[0]) + abs(vocals[1]))
            if vocal_level >= min_vocal_level:
                mix_out.append((audio[0].T, sample_rate))
                noise_out.append((audio[1].T + audio[2].T + audio[3].T, sample_rate))
                vocals_out.append((vocals, sample_rate))

    return mix_out, noise_out, vocals_out

In [4]:
# Load the training excerpts as three parallel lists (mixture, summed
# accompaniment stems, vocals) returned by process_musdb.
train_mix, train_noise, train_vocals = process_musdb('train')

# Fail fast when no excerpt was loaded or survived filtering: an empty training
# set cannot be trained on, and averaging the loss over it is impossible.
assert len(train_mix) == len(train_noise) == len(train_vocals) > 0, \
    "process_musdb('train') returned no usable excerpts"

In [5]:
# Instantiate the WaveNet with a small configuration.
# NOTE(review): the exact meaning of classes=2 and output_length=0 is defined
# by wavenet_model.WaveNetModel, which is not part of this notebook -- confirm.
model = WaveNetModel(
    blocks=2,
    layers=2,
    kernel_size=2,
    residual_channels=128,
    dilation_channels=128,
    skip_channels=512,
    end_channels=512,
    classes=2,
    output_length=0,
    bias=False,
    dtype=torch.FloatTensor,
)

# Move the model to the device selected earlier in the notebook.
model = model.to(device)

In [None]:
# L1 distance between the estimated vocals (mixture minus model output) and
# the reference vocal excerpt.
criterion = nn.L1Loss()

# NOTE(review): the recorded run shows the average loss rising from 0.2426 to
# 1.1063 between the first two epochs; lr=0.01 may be too high -- confirm.
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Number of training epochs
num_epochs = 50

if len(train_mix) == 0:
    raise ValueError("train_mix is empty; there is nothing to train on")

# Convert each excerpt to a tensor (with a leading batch dimension) once,
# instead of repeating the conversion in every epoch.
train_pairs = [
    (torch.tensor(mix_clip[0]).unsqueeze(0),
     torch.tensor(vocal_clip[0]).unsqueeze(0))
    for mix_clip, vocal_clip in zip(train_mix, train_vocals)
]

for epoch in range(num_epochs):
    total_loss = 0.0

    for mix_cpu, voice_cpu in train_pairs:

        optimizer.zero_grad()

        input_mix = mix_cpu.to(device=device)
        input_voice = voice_cpu.to(device=device)

        # Forward pass
        output = model(input_mix)

        # Crop the mixture and the target to the length the model actually
        # produced rather than a hard-coded sample count, so excerpts of any
        # duration or sample rate work.
        # NOTE(review): assumes the last axis of `output` is time -- confirm.
        out_len = output.shape[-1]

        # Calculate the loss: the model output is subtracted from the mixture
        # and the residual is compared with the vocals.
        loss = criterion(input_mix[:, :, :out_len] - output,
                         input_voice[:, :, :out_len])

        total_loss += loss.item()

        # Backpropagation
        loss.backward()

        # Update weights
        optimizer.step()

    avg_loss = total_loss / len(train_pairs)
    print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}")

print("Training finished!")

Epoch [1/50], Avg Loss: 0.2426
Epoch [2/50], Avg Loss: 1.1063


In [None]:
# Play back the first training mixture at its own sample rate, as a reference
# for the separated result.
Audio(data=train_mix[0][0], rate=train_mix[0][1])

In [None]:
# Run the model on the first training mixture and listen to the residual
# (mixture minus model output), i.e. the quantity the training loss compares
# with the vocals.
first_mix = torch.tensor(train_mix[0][0])

# Inference only: disable autograd so no graph is kept for the whole excerpt.
with torch.no_grad():
    isolated = model(first_mix.unsqueeze(0).to(device=device))

predicted = isolated[0].squeeze(0).cpu().detach().numpy()

# Crop the mixture to the model's actual output length instead of a hard-coded
# sample count, so excerpts of any duration or sample rate work.
residual = first_mix[:, :predicted.shape[-1]].cpu().detach().numpy() - predicted

Audio(residual, rate=train_mix[0][1])