In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import os
from IPython.display import Audio, display

import stempeg
import numpy as np

import warnings
warnings.filterwarnings('ignore')

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
device

'cuda'

In [4]:
class WaveNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super(WaveNetBlock, self).__init__()
        self.dilation = dilation
        self.conv_filter = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding='same')
        self.conv_gate = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding='same')
        self.conv_res = nn.Conv1d(out_channels, in_channels, 1)  # Residual connection
        self.conv_skip = nn.Conv1d(out_channels, out_channels, 1)  # Skip connection

    def forward(self, x):
        # Apply dilated convolutions
        filter_output = torch.tanh(self.conv_filter(x))
        gate_output = torch.sigmoid(self.conv_gate(x))
        gated_output = filter_output * gate_output
        
        # Residual and skip connections
        residual = self.conv_res(gated_output)
        skip_connection = self.conv_skip(gated_output)
        output = x + residual
        return output, skip_connection

class WaveNet(nn.Module):
    def __init__(self, num_blocks, num_layers, in_channels, out_channels, residual_channels, skip_channels, kernel_size):
        super(WaveNet, self).__init__()
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.kernel_size = kernel_size
        
        self.start_conv = nn.Conv1d(in_channels, residual_channels, kernel_size=1)
        
        self.blocks = nn.ModuleList([
            WaveNetBlock(residual_channels, residual_channels, kernel_size, 2 ** i)
            for i in range(num_layers)
        ])
        
        self.end_conv1 = nn.Conv1d(residual_channels, skip_channels, kernel_size=3, padding=1)
        self.end_conv2 = nn.Conv1d(skip_channels, out_channels, kernel_size=3, padding=1)
        
    def forward(self, x):
        x = self.start_conv(x)
        skip_connections = []

        for _ in range(self.num_blocks):
            for layer in self.blocks:
                x, skip = layer(x)
                skip_connections.append(skip)

        output = torch.relu(sum(skip_connections))
        output = self.end_conv1(output)
        output = torch.relu(output)
        output = self.end_conv2(output)
        
        return output


# Create an instance of the WaveNet model
model = WaveNet(num_blocks=3, num_layers=10, in_channels=2, out_channels=2,
                residual_channels=16, skip_channels=64, kernel_size=3)
model.to(device) 

WaveNet(
  (start_conv): Conv1d(2, 16, kernel_size=(1,), stride=(1,))
  (blocks): ModuleList(
    (0): WaveNetBlock(
      (conv_filter): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=same)
      (conv_gate): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=same)
      (conv_res): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
      (conv_skip): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
    )
    (1): WaveNetBlock(
      (conv_filter): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
      (conv_gate): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
      (conv_res): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
      (conv_skip): Conv1d(16, 16, kernel_size=(1,), stride=(1,))
    )
    (2): WaveNetBlock(
      (conv_filter): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=same, dilation=(4,))
      (conv_gate): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=same, dilation=(4,))
      (conv_res): Conv1d(16

In [5]:
def EnergyConservingLoss(input_mix, input_voice, input_noise, generated_voice):
    
    L = nn.L1Loss()
    
    voice_diff = L(input_voice, generated_voice)
    noise_diff = L(input_noise, (input_mix - generated_voice))
    
    loss = voice_diff + noise_diff
    
    return loss

In [6]:
def process_musdb(subset):
    
    assert subset in ['train', 'test']
    
    mix = []
    noise = []
    vocals = []

    for filename in os.listdir('musdb18/{}'.format(subset))[:40]:
        
        audio, sample_rate = stempeg.read_stems('musdb18/{}/'.format(subset) + filename, 
                                                stem_id=[0, 3, 4],
                                                out_type=np.float32,
                                                start=10,
                                                duration=10)

        mix.append((audio[0].T, sample_rate))
        noise.append((audio[1].T, sample_rate))
        vocals.append((audio[2].T, sample_rate))
                
    return mix, noise, vocals

In [7]:
model = WaveNet(num_blocks=3, num_layers=10, in_channels=2, out_channels=2,
                residual_channels=16, skip_channels=64, kernel_size=3)
model.to(device) 

train_mix, train_noise, train_vocals = process_musdb('train')

optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10

for epoch in range(num_epochs):
    total_loss = 0.0

    for i in range(len(train_mix)):
        optimizer.zero_grad()
        
        input_mix = torch.tensor(train_mix[i][0]).unsqueeze(0).to(device=device)
        input_voice = torch.tensor(train_vocals[i][0]).unsqueeze(0).to(device=device)
        input_noise = torch.tensor(train_noise[i][0]).unsqueeze(0).to(device=device)
        
        # Forward pass
        output = model(input_mix)

        # Calculate loss
        loss = EnergyConservingLoss(input_mix, input_voice, input_noise, output)
        total_loss += loss.item()

        # Backpropagation
        loss.backward()
        optimizer.step()

    avg_loss = total_loss / len(train_mix)
    print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}")

print("Training finished!")

Epoch [1/10], Avg Loss: 0.1090
Epoch [2/10], Avg Loss: 0.0860
Epoch [3/10], Avg Loss: 0.0837
Epoch [4/10], Avg Loss: 0.0821
Epoch [5/10], Avg Loss: 0.0823
Epoch [6/10], Avg Loss: 0.0810
Epoch [7/10], Avg Loss: 0.0809
Epoch [8/10], Avg Loss: 0.0801
Epoch [9/10], Avg Loss: 0.0804
Epoch [10/10], Avg Loss: 0.0816
Training finished!


In [8]:
Audio(train_mix[3][0], rate=train_mix[3][1])

In [9]:
isolated = model(torch.tensor(train_mix[3][0]).unsqueeze(0).to(device=device))
Audio(isolated.squeeze(0).cpu().detach().numpy(), rate=train_mix[3][1])