In [1]:
from wavenet_model import WaveNetModel

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import torchaudio

import os
from IPython.display import Audio, display

import stempeg
import numpy as np

import librosa
from auraloss.time import SNRLoss, SISDRLoss, SDSDRLoss, ESRLoss
from auraloss.freq import STFTLoss, MelSTFTLoss, STFTMagnitudeLoss

import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Pick the compute device once; later cells move the model and tensors onto it.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
class UNet(nn.Module):
    """Small fully-convolutional network mapping 2-channel images to 2-channel images.

    The stages run strictly one after another (encoder -> middle -> decoder ->
    1x1 output convolution); despite the class name there are no skip
    connections between encoder and decoder.

    Every 2x2 convolution uses ``padding=1`` and therefore grows each spatial
    dimension by one, while each stride-1 max-pool shrinks it by one, so an
    input of size ``(N, 2, H, W)`` yields an output of size ``(N, 2, H+4, W+4)``.
    """

    @staticmethod
    def _conv_stage(in_channels, mid_channels, out_channels, pool):
        """Build conv -> ReLU -> conv -> ReLU, optionally followed by a stride-1 max-pool."""
        layers = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=2, padding=1),
            nn.ReLU(inplace=True),
        ]
        if pool:
            layers.append(nn.MaxPool2d(kernel_size=2, stride=1))
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()

        # Stages are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layers are applied.
        self.encoder = self._conv_stage(2, 64, 128, pool=True)
        self.middle = self._conv_stage(128, 256, 256, pool=True)
        self.decoder = self._conv_stage(256, 128, 64, pool=False)

        # 1x1 convolution projecting the 64 feature maps back to 2 channels.
        self.output_layer = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x):
        """Apply the four stages in sequence and return the final feature map."""
        for stage in (self.encoder, self.middle, self.decoder, self.output_layer):
            x = stage(x)
        return x

In [4]:
def process_musdb(subset, root='musdb18', max_files=1, starts=(30, 60, 45),
                  duration=10, min_vocal_level=0.05):
    """Load fixed-offset excerpts from stem files and keep those with audible vocals.

    For every selected file in ``<root>/<subset>`` one excerpt is read per
    entry in ``starts``. From each excerpt stem 0 is used as the mix, stems
    1-3 are summed as the "noise" (everything that is not vocals) and stem 4
    is used as the vocals. Each array is transposed so that indexing the
    first axis selects a channel.

    Args:
        subset: ``'train'`` or ``'test'``; names the sub-directory of ``root``.
        root: Directory containing the ``train``/``test`` folders.
        max_files: Maximum number of files to read, or ``None`` for all of
            them. The default of 1 matches the original behaviour.
        starts: Excerpt start offsets passed to ``stempeg.read_stems``
            (presumably seconds - confirm against the stempeg docs).
        duration: Excerpt length passed to ``stempeg.read_stems``.
        min_vocal_level: An excerpt is kept only if the mean of
            ``|left| + |right|`` of its vocal stem is at least this value.

    Returns:
        Three index-aligned lists ``(mix, noise, vocals)`` of
        ``(audio, sample_rate)`` tuples. All three may be empty if no excerpt
        passes the vocal-level filter.

    Raises:
        AssertionError: If ``subset`` is not ``'train'`` or ``'test'``.
    """
    assert subset in ['train', 'test'], "subset must be 'train' or 'test'"

    subset_dir = os.path.join(root, subset)

    # Filter out the Jupyter checkpoint folder *before* truncating: slicing
    # first could select only that folder and silently yield no data.
    # os.listdir returns entries in arbitrary order, so sort to make the
    # selection reproducible across machines and runs.
    filenames = sorted(name for name in os.listdir(subset_dir)
                       if name != '.ipynb_checkpoints')
    if max_files is not None:
        filenames = filenames[:max_files]

    mix_out = []
    noise_out = []
    vocals_out = []

    for filename in filenames:
        path = os.path.join(subset_dir, filename)

        for start in starts:
            audio, sample_rate = stempeg.read_stems(path,
                                                    out_type=np.float32,
                                                    start=start,
                                                    duration=duration)

            vocal_stem = audio[4].T

            # Skip excerpts in which the singer is (nearly) silent; they
            # would teach the model to output nothing.
            vocal_level = np.mean(np.abs(vocal_stem[0]) + np.abs(vocal_stem[1]))
            if vocal_level >= min_vocal_level:
                mix_out.append((audio[0].T, sample_rate))
                noise_out.append((audio[1].T + audio[2].T + audio[3].T, sample_rate))
                vocals_out.append((vocal_stem, sample_rate))

    return mix_out, noise_out, vocals_out

In [5]:
# Load the training excerpts with process_musdb (defined in an earlier cell).
# Each list holds (audio, sample_rate) tuples and the three lists are index-aligned.
train_mix, train_noise, train_vocals = process_musdb('train')

# process_musdb drops excerpts whose vocal stem is too quiet, so it can return
# empty lists. Fail here with a clear message instead of an IndexError or
# ZeroDivisionError in a later cell.
assert len(train_mix) == len(train_noise) == len(train_vocals)
assert len(train_mix) > 0, "process_musdb('train') returned no excerpts"

In [6]:
# WaveNet hyper-parameters collected in one mapping so they are easy to scan and tweak.
# NOTE(review): dtype is a CPU tensor type while the model is moved to `device`;
# confirm WaveNetModel does not allocate tensors of this type on the wrong device.
wavenet_kwargs = dict(
    layers=7,
    blocks=2,
    dilation_channels=64,
    residual_channels=64,
    skip_channels=128,
    end_channels=128,
    classes=2,
    output_length=0,
    kernel_size=2,
    dtype=torch.FloatTensor,
    bias=False,
)

model = WaveNetModel(**wavenet_kwargs)
model = model.to(device)

In [None]:
# Loss: STFT-based distance between the predicted and the reference vocal waveform.
criterion = STFTLoss(fft_size=4096)

# Optimizer over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Number of passes over the training excerpts.
num_epochs = 1000

# The per-epoch average below divides by len(train_mix); refuse to start
# with an empty dataset instead of failing with ZeroDivisionError.
if len(train_mix) == 0:
    raise ValueError("train_mix is empty: there is nothing to train on")

# Make the mode explicit in case an earlier cell switched the model to eval.
model.train()

for epoch in range(num_epochs):
    total_loss = 0.0

    for i in range(len(train_mix)):

        optimizer.zero_grad()

        # unsqueeze(0) adds a leading batch dimension of size 1.
        input_mix = torch.tensor(train_mix[i][0]).unsqueeze(0).to(device=device)
        input_voice = torch.tensor(train_vocals[i][0]).unsqueeze(0).to(device=device)

        # Forward pass
        output = model(input_mix)

        # The model returns fewer samples than it receives, so trim the
        # target to the actual output length rather than a hard-coded count.
        # NOTE(review): like the original, this keeps the *first* samples of
        # the target; if the model drops leading samples the *last* ones may
        # be the correct alignment - confirm against WaveNetModel.
        target_voice = input_voice[..., :output.shape[-1]]

        # Calculate the loss
        loss = criterion(output, target_voice)

        total_loss += loss.item()

        # Backpropagation
        loss.backward()

        # Update weights
        optimizer.step()

    avg_loss = total_loss / len(train_mix)
    print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}")

print("Training finished!")

Epoch [1/1000], Avg Loss: 4.9690
Epoch [2/1000], Avg Loss: 2.3918
Epoch [3/1000], Avg Loss: 2.0793
Epoch [4/1000], Avg Loss: 1.9779
Epoch [5/1000], Avg Loss: 2.0345
Epoch [6/1000], Avg Loss: 1.9322
Epoch [7/1000], Avg Loss: 1.8723
Epoch [8/1000], Avg Loss: 1.8984
Epoch [9/1000], Avg Loss: 1.8180
Epoch [10/1000], Avg Loss: 1.7657
Epoch [11/1000], Avg Loss: 1.7355
Epoch [12/1000], Avg Loss: 1.6941
Epoch [13/1000], Avg Loss: 1.6485
Epoch [14/1000], Avg Loss: 1.6425
Epoch [15/1000], Avg Loss: 1.6886
Epoch [16/1000], Avg Loss: 1.6534
Epoch [17/1000], Avg Loss: 1.6211
Epoch [18/1000], Avg Loss: 1.6470
Epoch [19/1000], Avg Loss: 1.6910
Epoch [20/1000], Avg Loss: 1.6114
Epoch [21/1000], Avg Loss: 1.5770
Epoch [22/1000], Avg Loss: 1.7741
Epoch [23/1000], Avg Loss: 1.6268
Epoch [24/1000], Avg Loss: 1.7201
Epoch [25/1000], Avg Loss: 1.7571
Epoch [26/1000], Avg Loss: 1.7459
Epoch [27/1000], Avg Loss: 1.9820
Epoch [28/1000], Avg Loss: 1.8285
Epoch [29/1000], Avg Loss: 1.7571
Epoch [30/1000], Avg Lo

In [None]:
# Listen to the first training mix at its native sample rate.
mix_audio, mix_rate = train_mix[0]
Audio(mix_audio, rate=mix_rate)

In [None]:
# Run the model on the first training mix and listen to the result.
# This is inference only: eval mode disables train-time layers (if the model
# has any) and torch.no_grad() avoids building an autograd graph for the
# whole excerpt.
was_training = model.training
model.eval()
with torch.no_grad():
    isolated = model(torch.tensor(train_mix[0][0]).unsqueeze(0).to(device=device))
model.train(was_training)  # restore whatever mode the model was in before

Audio(isolated.squeeze(0).cpu().numpy(), rate=train_mix[0][1])

In [None]:
# Inspect the raw sample values of the first training mix
# (the array half of its (audio, sample_rate) tuple).
first_mix_audio = train_mix[0][0]
first_mix_audio

In [None]:
# Show the model output scaled by a constant gain.
# NOTE(review): 0.24 looks like an ad-hoc factor for comparing amplitudes
# with the mix printed above - confirm or replace with a named constant.
isolated.mul(0.24)

In [None]:
# Sanity check: push the first mix through an STFT and straight back through
# the inverse STFT, then listen to the reconstruction.
mix_audio, mix_rate = train_mix[0]
mix_spectrum = librosa.stft(mix_audio)
mix_roundtrip = librosa.istft(mix_spectrum)
Audio(mix_roundtrip, rate=mix_rate)

In [None]:
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    """Plot each channel of a waveform against time in its own subplot.

    Args:
        waveform: Audio samples as a ``torch.Tensor`` or anything
            ``numpy.asarray`` accepts. Either 2-D ``(channels, frames)`` or
            1-D ``(frames,)``, which is treated as a single channel. Tensors
            may live on any device and may require grad; they are detached
            and copied to the CPU before plotting.
        sample_rate: Samples per second, used to convert frame indices to time.
        title: Figure title.
        xlim: Optional x-axis limits applied to every subplot.
        ylim: Optional y-axis limits applied to every subplot.

    Returns:
        None. The figure is shown with ``plt.show(block=False)``.
    """
    if isinstance(waveform, torch.Tensor):
        # .numpy() alone fails for CUDA tensors and tensors that track gradients.
        waveform = waveform.detach().cpu().numpy()
    else:
        waveform = np.asarray(waveform)

    if waveform.ndim == 1:
        # Mono signal given without a channel axis.
        waveform = waveform[np.newaxis, :]

    num_channels, num_frames = waveform.shape
    time_axis = np.arange(num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a list.
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f'Channel {c+1}')
        if xlim:
            axes[c].set_xlim(xlim)
        if ylim:
            axes[c].set_ylim(ylim)
    # Label the shared time axis once, on the bottom subplot.
    axes[-1].set_xlabel('Time (s)')
    figure.suptitle(title)
    plt.show(block=False)

In [None]:
# Plot the first training mix, one subplot per channel.
mix_audio, mix_rate = train_mix[0]
mix_tensor = torch.tensor(mix_audio)
plot_waveform(mix_tensor, sample_rate=mix_rate)

In [None]:
# Plot the model output for the first mix: take the first batch entry, then
# detach it and move it to the CPU so it can be converted for matplotlib.
isolated_waveform = isolated[0].squeeze(0).detach().cpu()
plot_waveform(isolated_waveform, sample_rate=train_mix[0][1])

In [None]:
# Shape of the last forward-pass result; `output` is the variable left over
# from the final iteration of the training loop, so that cell must run first.
output.size()