Here we train our first version of the GAN.



## Initialize Wave-U-Net

We start by loading the necessary packages.

In [None]:
# Import same packages as the train script in Wave-U-Net-Pytorch
import argparse
import os
import time
from functools import partial

import torch 
import pickle
import numpy as np

import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from tqdm import tqdm

# add a path to Wave-U-Net
import sys
sys.path.append('Wave-U-Net-Pytorch')

import model.utils as model_utils
import utils
from model.waveunet import Waveunet


We define the parameters of the model.

In [2]:
# Wave-U-Net hyper-parameters, passed as keyword arguments to Waveunet.
model_config = dict(
    num_inputs=256,                            # input channels: two spectrograms x 128 mel bins
    num_outputs=128,                           # output channels: one 128-bin mel spectrogram
    num_channels=[256 * 2, 256 * 4, 256 * 8],  # channel progression across U-Net levels
    instruments=["vocal"],                     # single output source: vocals only
    kernel_size=3,                             # must be odd
    target_output_size=256,                    # requested output frames; the model reports the valid size it uses (257 here)
    conv_type="normal",                        # "normal" satisfies Waveunet's assertion on conv_type
    res="fixed",                               # fixed resampling
    separate=False,                            # presumably one shared network rather than one per instrument
    depth=1,                                   # conv layers per block
    strides=2,                                 # down/up-sampling stride
)

Using valid convolutions with 289 inputs and 257 outputs
input_size (length of input): 289
num_inputs (number of channels in the input): 256


Load the model, check how much GPU memory it will use during training, and print a summary of the model.

In [4]:
import torch
from torch.optim import Adam
from torch.nn import L1Loss

from torchsummary import summary

# Use the GPU when available. The torch.cuda memory statistics below only make
# sense on a CUDA device, so they are skipped on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"

# Instantiate the generator and move it to the selected device
generator = Waveunet(**model_config).to(device)

# Optimizer and loss used for the memory probe below
optimizer = Adam(generator.parameters(), lr=1e-3)
loss_fn = L1Loss()

# Batch size whose training-step memory footprint is measured
batch_size = 256

# Random input of shape (batch_size, generator.num_inputs, generator.input_size).
# With the config above the model reports input_size = 289.
dummy_input = torch.randn(batch_size, generator.num_inputs, generator.input_size, device=device)

# Random target matching the vocal output: (batch_size, generator.num_outputs, generator.output_size).
# With the config above the model reports output_size = 257.
dummy_target = torch.randn(batch_size, generator.num_outputs, generator.output_size, device=device)

# Keep a CPU copy of the freshly initialised weights. The probe's optimizer.step()
# would otherwise leave the generator updated on random noise before training starts.
initial_state = {name: tensor.detach().cpu().clone() for name, tensor in generator.state_dict().items()}

if use_cuda:
    # Reset GPU peak memory stats so the peak reflects only the step below
    torch.cuda.reset_peak_memory_stats(device)

# One full training step. optimizer.step() is included so that Adam's
# per-parameter state is counted in the peak.
optimizer.zero_grad()
# If separate is False, the model returns a dictionary; pass the correct key.
output = generator(dummy_input)["vocal"]
loss = loss_fn(output, dummy_target)
loss.backward()
optimizer.step()

if use_cuda:
    peak_memory = torch.cuda.max_memory_allocated(device)
    current_memory = torch.cuda.memory_allocated(device)
    print("Peak GPU memory allocated (bytes):", peak_memory)
    print("Current GPU memory allocated (bytes):", current_memory)
    # Detailed allocator statistics
    print(torch.cuda.memory_summary(device=device))
else:
    print("CUDA is not available; skipping GPU memory statistics.")

# Undo the probe: restore the initial weights, drop the gradients, and start from
# a fresh optimizer so no Adam state from the random batch carries into training.
generator.load_state_dict(initial_state)
generator.zero_grad(set_to_none=True)
optimizer = Adam(generator.parameters(), lr=1e-3)
del initial_state
if use_cuda:
    torch.cuda.empty_cache()

# Pass the device type explicitly so the summary also runs on a CPU-only runtime.
summary(generator, input_size=(generator.num_inputs, generator.input_size), device=device.type)


Using valid convolutions with 289 inputs and 257 outputs
Peak GPU memory allocated (bytes): 6275877376
Current GPU memory allocated (bytes): 783944192
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 765570 KiB |   5985 MiB |  24314 MiB |  23567 MiB |
|       from large pool | 764032 KiB |   5984 MiB |  24312 MiB |  23566 MiB |
|       from small pool |   1538 KiB |      1 MiB |      2 MiB |      0 MiB |
|---------------------------------------------------------------------------|
| Active memory         | 765570 KiB |   5985 MiB |  24314 MiB |  23567 MiB |
|       from large pool | 764032 KiB |   5984 MiB |  24312 MiB |  235

Optionally compile the model to potentially decrease training time.

In [6]:
# Wrap the generator with torch.compile; the uncompiled module stays reachable as generator._orig_mod.
generator = torch.compile(model=generator, mode="max-autotune")

`torch.compile` wraps the model. To save the weights after training, we therefore unwrap it first and save the state dict of the original module:

```python
orig_generator = generator._orig_mod
path = ""
torch.save(orig_generator.state_dict(), path + "generator_state_dict.pt")
```


## Initialize MiniRocket
We start by loading the necessary packages.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchaudio
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict, Optional
from sktime.transformations.panel.rocket import MiniRocketMultivariate

In [None]:
# MiniRocket discriminator. Despite "Tsai" in the class name, the feature
# extractor is sktime's MiniRocketMultivariate (tsai's MiniRocket classifier is
# built on the same transform).
class TsaiMiniRocketDiscriminator(nn.Module):
    """Conditional GAN discriminator: MiniRocket features followed by an MLP head.

    Each spectrogram batch is turned into a fixed-length MiniRocket feature vector
    per sample (frequency bins act as channels, time frames as the series axis).
    The vocal and accompaniment features are concatenated and classified as
    real/fake by an MLP that ends in a sigmoid.

    NOTE(review): the MiniRocket transform runs in numpy on detached tensors, so no
    gradient flows from this discriminator back into its inputs (and hence not into
    the generator). Only the MLP head is trainable.
    """

    def __init__(
        self,
        freq_bins=256,
        time_frames=256,
        num_kernels=10000,  # number of convolutional kernels requested from MiniRocket
        hidden_dim=1024,    # width of the first hidden layer
        output_dim=1
    ):
        super().__init__()

        # The MiniRocket transformer that extracts the features
        self.rocket = MiniRocketMultivariate(num_kernels=num_kernels)
        self.fitted = False   # the transform must be fitted before features can be extracted
        self.freq_bins = freq_bins
        self.time_frames = time_frames
        self.num_kernels = num_kernels
        # Number of features MiniRocket actually emits per spectrogram; this can be
        # fewer than num_kernels (e.g. 10000 -> 9996), so the classifier is sized by it.
        self.n_features = self._num_rocket_features(num_kernels)

        # Kept for compatibility; freq_bins/time_frames are used only here.
        self.example_input = np.zeros((1, freq_bins, time_frames))

        feature_dim = 2 * self.n_features  # vocals + accompaniment

        self.classifier = nn.Sequential(
            # First reduce the large feature dimension to something manageable
            nn.Linear(feature_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Second hidden layer
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Final classification layer
            nn.Linear(hidden_dim // 2, output_dim),
            nn.Sigmoid()
        )

    @staticmethod
    def _num_rocket_features(num_kernels):
        """Feature count sktime's MiniRocket produces for a requested num_kernels.

        Per the sktime documentation, values below 84 are raised to 84, and other
        values are rounded down to a multiple of 84.
        """
        return max(84, (num_kernels // 84) * 84)

    @staticmethod
    def _to_panel(spectrograms):
        """Convert a spectrogram batch to sktime's (n_instances, n_channels, series_length) numpy layout.

        Accepts (batch, channels, freq_bins, time_frames), in which case channel 0 is
        used, or (batch, freq_bins, time_frames). The tensor is detached first so that
        generator outputs that require grad can be converted.
        """
        panel = spectrograms.detach().cpu().numpy()
        if panel.ndim == 4:
            panel = panel[:, 0]  # drop the channel axis -> (batch, freq_bins, time_frames)
        elif panel.ndim != 3:
            raise ValueError(
                f"expected a 3-D or 4-D spectrogram batch, got shape {tuple(panel.shape)}"
            )
        return panel

    def fit_rocket(self, spectrograms):
        """
            Fit MiniRocket on the first sample of `spectrograms` (not the entire training dataset).

            Errors are reported and the transform is marked as fitted anyway, to avoid
            retrying on every call. Subsequent feature extraction then falls back to zeros.
        """
        if self.fitted:
            return
        try:
            # Fit on a single sample in sktime's layout
            sample_data = self._to_panel(spectrograms)[:1]
            self.rocket.fit(sample_data)
            self.fitted = True

            # Transform the sample once to record the real feature dimension
            test_transform = self.rocket.transform(sample_data)
            self.feature_dim = test_transform.shape[1]

            print(f"MiniRocket fitted. Feature dimension: {self.feature_dim}")
            if self.feature_dim != self.n_features:
                print(
                    f"Warning: MiniRocket produced {self.feature_dim} features, "
                    f"but the classifier expects {self.n_features} per spectrogram."
                )

        except Exception as e:
            print(f"Error fitting MiniRocket: {e}")
            # Best-effort: mark as fitted to avoid repeated attempts
            self.fitted = True

    def extract_features(self, spectrogram):
        """Extract MiniRocket features from a spectrogram batch.

        Returns a float32 tensor of shape (batch, n_features) on the input's device.
        If extraction fails, the error is printed and zeros of that shape are returned.
        """
        try:
            # Ensure rocket is fitted
            if not self.fitted:
                self.fit_rocket(spectrogram)

            # Apply the fitted convolutional kernels
            features = self.rocket.transform(self._to_panel(spectrogram))

            # sktime may return a pandas DataFrame; np.asarray handles it and plain ndarrays
            features = np.asarray(features, dtype=np.float32)
            return torch.as_tensor(features, device=spectrogram.device)

        except Exception as e:
            print(f"Error in feature extraction: {e}")
            # Zeros with the same width the classifier was built for
            return torch.zeros((spectrogram.shape[0], self.n_features),
                               device=spectrogram.device)

    def forward(self, vocals, accompaniment):
        """
        Forward pass of the discriminator

        Args:
            vocals: spectrograms of shape [batch_size, channels, freq_bins, time_frames]
                or [batch_size, freq_bins, time_frames]
            accompaniment: spectrograms with the same layout as `vocals`

        Returns:
            Tensor of shape [batch_size, output_dim] with sigmoid real/fake scores.
        """
        # Extract features from both spectrograms (non-differentiable, see class docstring)
        vocal_features = self.extract_features(vocals)
        accomp_features = self.extract_features(accompaniment)

        # Concatenate features (conditional GAN)
        combined_features = torch.cat([vocal_features, accomp_features], dim=1)

        # Classify as real/fake
        validity = self.classifier(combined_features)

        return validity



## Load data
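**Q:** Does this only load the MUSDB18 dataset? At the moment, yes: the cells below load only the MUSDB18 `train` subset, and loading LibriSpeech is still a TODO.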


In [None]:
from google.colab import drive
drive.mount('/content/drive')

# when you install musdb, pip automatically installs a version of stempeg that
# contains a small bug. To work around this, download the stempeg folder from
# the github to your drive.

%pip install musdb  # has some helpful data structures, also installs ffmpeg and stempeg
%pip uninstall stempeg    # musdb installs the wrong version of stempeg'

# The path below should be changed to the location of the stempeg package in
# your Drive
%cd '/content/drive/MyDrive/DeepLearningBootcamp'

import stempeg
import musdb

import librosa
from torch.utils.data import Dataset

In [None]:
class MusdbDataset(Dataset):
  """Paired log-mel spectrogram slices (accompaniment, vocals) built from MUSDB tracks.

  Each track is mixed down to mono and converted into two log-mel spectrograms:
  accompaniment (stems 1-3 summed) and vocals (stem 4). Both are then cut into
  non-overlapping slices of `steps` frames, and a trailing partial slice is dropped.

  Attributes:
    mel_specs: float32 tensor of shape (num_slices, 2, n_mels, steps).
      Channel 0 is the accompaniment and channel 1 the vocals.
    sample_rates: float32 tensor of shape (num_slices,). Entry i is the sample rate
      of the track that slice i came from.
  """

  def __init__(self, musDB, steps = 256, n_mels = 128):
    """
    Args:
      musDB: iterable of tracks exposing `stems` and `rate` (e.g. a slice of a musdb.DB).
      steps: number of spectrogram frames per slice.
      n_mels: number of mel bins (128 is librosa's default and the previous fixed value).
    """
    slices = []
    rates = []

    for track in musDB:
      stems, rate = track.stems, track.rate

      # separate the vocals from the other instruments and convert both to mono
      audio_novocal = librosa.to_mono(np.transpose(stems[1] + stems[2] + stems[3]))
      audio_vocal = librosa.to_mono(np.transpose(stems[4]))

      # log-mel spectrograms as float32 tensors of shape (n_mels, frames)
      logmelspec_novocal = torch.from_numpy(self._mel_spectrogram(audio_novocal, rate, n_mels)).float()
      logmelspec_vocal = torch.from_numpy(self._mel_spectrogram(audio_vocal, rate, n_mels)).float()

      # cut each into (num_slices, n_mels, steps) chunks along the time axis
      novocal_slices = self._slice_frames(logmelspec_novocal[:n_mels], steps)
      vocal_slices = self._slice_frames(logmelspec_vocal[:n_mels], steps)

      # (num_slices, 2, n_mels, steps): channel 0 accompaniment, channel 1 vocals
      slices.append(torch.stack((novocal_slices, vocal_slices), dim=1))
      # one rate per slice, so __getitem__ can index it with the slice index
      rates.append(torch.full((novocal_slices.shape[0],), float(rate)))

    # concatenate once at the end instead of growing a tensor inside the loop
    if slices:
      self.mel_specs = torch.cat(slices, 0)
      self.sample_rates = torch.cat(rates, 0)
    else:
      self.mel_specs = torch.zeros(0, 2, n_mels, steps)
      self.sample_rates = torch.zeros(0)

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    # returns tuple (mel spectrogram of accompaniment, mel spectrogram of vocal, rate)
    return self.mel_specs[ndx, 0], self.mel_specs[ndx, 1], self.sample_rates[ndx]

  @staticmethod
  def _slice_frames(spec, steps):
    """Split a (n_mels, frames) tensor into (frames // steps, n_mels, steps) chunks.

    Slice i holds all mel bins for frames [i * steps, (i + 1) * steps); leftover
    frames are dropped.
    """
    n_mels, num_frames = spec.shape
    num_slices = num_frames // steps
    # Reshape to (n_mels, num_slices, steps) first so that each chunk is one
    # contiguous time window; then move the slice axis to the front. A direct
    # reshape to (num_slices, n_mels, steps) would interleave mel bins and time.
    trimmed = spec[:, :num_slices * steps]
    return trimmed.reshape(n_mels, num_slices, steps).permute(1, 0, 2).contiguous()

  def _mel_spectrogram(self, audio, rate, n_mels = 128):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate, n_mels = n_mels))

In [None]:
# Open the MUSDB18 training subset stored on the mounted Google Drive.
music = musdb.DB(root="/content/drive/MyDrive/DeepLearningBootcamp/musdb18_data", subsets="train")

In [None]:
# Build a dataset from the first 10 training tracks and report how many
# spectrogram slices they produce.
data = MusdbDataset(music[:10])
print(len(data))


### Load LibriSpeech
**TODO:** to be completed.

In [None]:
# TODO: load LibriSpeech here. (This cell previously held stray text that raised NameError.)