In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import os
from IPython.display import Audio, display

import stempeg
import numpy as np

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Prefer the GPU when CUDA is usable in this environment; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Display the selected device; the recorded output reflects the machine that last ran this cell.
device

'cuda'

In [4]:
class WaveNetBlock(nn.Module):
    """Gated dilated-convolution unit with a residual path and a skip output.

    Two parallel dilated convolutions (a tanh "filter" and a sigmoid "gate")
    are multiplied element-wise. The gated activation is then projected by two
    1x1 convolutions: one back to ``in_channels`` so it can be added to the
    input (residual path), and one to ``out_channels`` for the skip output.

    Args:
        in_channels: Channel count of the input and of the residual output.
        out_channels: Channel count of the gated activation and skip output.
        kernel_size: Kernel size of the two dilated convolutions.
        dilation: Dilation of the two dilated convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.dilation = dilation

        # Both gated branches share the same geometry; 'same' padding keeps the
        # output length equal to the input length so the residual add is valid.
        dilated_kwargs = dict(dilation=dilation, padding='same')
        self.conv_filter = nn.Conv1d(in_channels, out_channels, kernel_size, **dilated_kwargs)
        self.conv_gate = nn.Conv1d(in_channels, out_channels, kernel_size, **dilated_kwargs)

        # 1x1 projections applied to the gated activation.
        self.conv_res = nn.Conv1d(out_channels, in_channels, kernel_size=1)
        self.conv_skip = nn.Conv1d(out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``(x + residual, skip)`` for input ``x``.

        ``x`` must be accepted by ``nn.Conv1d`` with ``in_channels`` channels.
        """
        gated = self.conv_filter(x).tanh() * self.conv_gate(x).sigmoid()
        return x + self.conv_res(gated), self.conv_skip(gated)

class WaveNet(nn.Module):
    """Stack of gated dilated-convolution layers with a summed-skip output head.

    The input is lifted to ``residual_channels`` by a 1x1 convolution, passed
    ``num_blocks`` times through ``num_layers`` ``WaveNetBlock`` layers with
    dilations 1, 2, 4, ..., ``2 ** (num_layers - 1)``, and the skip outputs of
    every layer application are summed and fed to a two-convolution head.

    Args:
        num_blocks: Number of passes over the dilated layer stack (>= 1).
        num_layers: Number of dilated layers in the stack (>= 1).
        in_channels: Channel count of the network input.
        out_channels: Channel count of the network output.
        residual_channels: Channel count used inside the dilated stack; the
            per-layer skip tensors also have this many channels.
        skip_channels: Width of the hidden layer of the output head.
        kernel_size: Kernel size of the dilated convolutions.

    Raises:
        ValueError: If ``num_blocks`` or ``num_layers`` is smaller than 1.

    NOTE(review): the *same* ``num_layers`` layers are reused on every pass, so
    all passes share weights. The canonical WaveNet gives each stack its own
    layers -- confirm the sharing is intended before changing it, since doing
    so alters the parameter set.
    """

    def __init__(self, num_blocks, num_layers, in_channels, out_channels, residual_channels, skip_channels, kernel_size):
        super(WaveNet, self).__init__()

        # Without at least one layer application forward() has no skip tensor
        # to feed the head; fail here with a clear message instead.
        if num_blocks < 1 or num_layers < 1:
            raise ValueError(
                f"num_blocks and num_layers must both be >= 1 "
                f"(got num_blocks={num_blocks}, num_layers={num_layers})"
            )

        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.kernel_size = kernel_size

        self.start_conv = nn.Conv1d(in_channels, residual_channels, kernel_size=1)

        # Dilation doubles with each layer: 1, 2, 4, ...
        self.blocks = nn.ModuleList([
            WaveNetBlock(residual_channels, residual_channels, kernel_size, 2 ** i)
            for i in range(num_layers)
        ])

        # Output head: summed skips (residual_channels) -> skip_channels -> out_channels.
        self.end_conv1 = nn.Conv1d(residual_channels, skip_channels, kernel_size=3, padding=1)
        self.end_conv2 = nn.Conv1d(skip_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map ``x`` to the network output.

        ``x`` must be accepted by ``nn.Conv1d`` with ``in_channels`` channels.
        """
        x = self.start_conv(x)

        # Accumulate the skip outputs as we go rather than collecting them in a
        # list, so no reference to each individual skip tensor is kept around.
        skip_total = None
        for _ in range(self.num_blocks):
            for layer in self.blocks:
                x, skip = layer(x)
                skip_total = skip if skip_total is None else skip_total + skip

        if skip_total is None:
            raise ValueError("WaveNet has no layers to apply (num_blocks or blocks is empty)")

        output = torch.relu(skip_total)
        output = self.end_conv1(output)
        output = torch.relu(output)
        output = self.end_conv2(output)

        return output

In [5]:
def EnergyConservingLoss(input_mix, input_voice, input_noise, generated_voice):
    """Combined L1 loss on the voice estimate and on the noise it implies.

    The noise estimate is taken to be whatever remains of the mix once the
    generated voice is subtracted. The voice term is weighted twice as heavily
    as the noise term.

    Args:
        input_mix: Mixture signal.
        input_voice: Reference voice signal.
        input_noise: Reference noise signal.
        generated_voice: Voice signal produced by the model.

    Returns:
        ``2 * L1(generated_voice, input_voice)
        + L1(input_mix - generated_voice, input_noise)``,
        with each L1 term averaged over all elements.
    """
    l1 = nn.L1Loss()

    implied_noise = input_mix - generated_voice
    voice_term = 2 * l1(generated_voice, input_voice)
    noise_term = l1(implied_noise, input_noise)

    return voice_term + noise_term

In [6]:
def process_musdb(subset, root='musdb18', starts=(35, 60), duration=5):
    """Load fixed-length excerpts of the mix, noise and vocal stems of a subset.

    For every visible entry in ``<root>/<subset>`` one excerpt is read per
    offset in ``starts``. With the defaults this is one excerpt at 35 (intended
    as a sparser/quieter region) and one at 60 (intended as a more
    populated/louder region), each with a duration of 5.

    Args:
        subset: ``'train'`` or ``'test'``.
        root: Directory containing the ``train`` and ``test`` folders.
        starts: Start offsets passed to ``stempeg.read_stems``, one excerpt each.
        duration: Duration passed to ``stempeg.read_stems`` for every excerpt.

    Returns:
        ``(mix, noise, vocals)``: three index-aligned lists of
        ``(array, sample_rate)`` tuples, ordered by filename and then by
        offset. Each array is the transposed stem as returned by stempeg
        (presumably ``(channels, samples)`` -- confirm for the installed
        stempeg version).

    Raises:
        AssertionError: If ``subset`` is not ``'train'`` or ``'test'``.

    NOTE(review): stems 0, 3 and 4 are used as mix, noise and vocals. In the
    MUSDB18 stem layout stream 3 appears to be "other" only (no drums/bass),
    so ``mix - vocals`` would not equal ``noise`` -- confirm this is intended.
    """
    assert subset in ['train', 'test']

    subset_dir = os.path.join(root, subset)

    mix = []
    noise = []
    vocals = []

    # os.listdir returns entries in arbitrary order; sorting makes clip indices
    # (e.g. train_mix[0]) reproducible across machines and runs.
    for filename in sorted(os.listdir(subset_dir)):

        # Skip hidden entries such as '.ipynb_checkpoints' or '.DS_Store',
        # which are not audio and would make the decoder fail.
        if filename.startswith('.'):
            continue

        track_path = os.path.join(subset_dir, filename)

        for start in starts:
            audio, sample_rate = stempeg.read_stems(track_path,
                                                    stem_id=[0, 3, 4],
                                                    out_type=np.float32,
                                                    start=start,
                                                    duration=duration)

            # audio is indexed in the order of stem_id: mix, noise, vocals.
            mix.append((audio[0].T, sample_rate))
            noise.append((audio[1].T, sample_rate))
            vocals.append((audio[2].T, sample_rate))

    return mix, noise, vocals

In [None]:
# Build the separator (3 passes over a 10-layer dilated stack, 2 channels in
# and out) and move its parameters to the selected device.
model = WaveNet(
    num_blocks=3,
    num_layers=10,
    in_channels=2,
    out_channels=2,
    residual_channels=64,
    skip_channels=256,
    kernel_size=3,
)
model.to(device)

# Three index-aligned lists of (clip, sample_rate) pairs.
train_mix, train_noise, train_vocals = process_musdb('train')

# Objective: mean squared error between the model output and the vocal stem.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 10


def _as_batch(clip):
    """Copy one clip into a tensor, add a leading batch axis and move it to `device`."""
    return torch.tensor(clip).unsqueeze(0).to(device=device)


for epoch in range(num_epochs):
    total_loss = 0.0

    # One optimizer step per clip (batch of one), always in list order.
    for mix_pair, vocal_pair, noise_pair in zip(train_mix, train_vocals, train_noise):
        optimizer.zero_grad()

        input_mix = _as_batch(mix_pair[0])
        input_voice = _as_batch(vocal_pair[0])
        # Only needed by the alternative objective mentioned below.
        input_noise = _as_batch(noise_pair[0])

        output = model(input_mix)

        # Alternative objective:
        # EnergyConservingLoss(input_mix, input_voice, input_noise, output)
        loss = criterion(output, input_voice)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_mix)
    print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}")

print("Training finished!")

In [None]:
# Listen to the first training mixture at its own sample rate.
Audio(data=train_mix[0][0], rate=train_mix[0][1])

In [None]:
# Run the trained model on the first training mixture and play its estimate.
# Inference only: disable autograd so no graph is recorded for this forward
# pass (it would only consume memory here).
with torch.no_grad():
    isolated = model(torch.tensor(train_mix[0][0]).unsqueeze(0).to(device=device))

# Play back at the sample rate of the clip that was actually fed to the model
# (index 0, matching the input above).
Audio(isolated.squeeze(0).cpu().numpy(), rate=train_mix[0][1])