Here we train our first version of the GAN.



## Initialize Wave-U-Net

We start by loading the necessary packages

Wave-U-Net is named ``generator``

In [1]:
# Import same packages as the train script in Wave-U-Net-Pytorch
import argparse
import os
import time
from functools import partial

import torch
import pickle
import numpy as np

import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from torch.nn import L1Loss
from tqdm import tqdm
from torchsummary import summary
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
# install torchaudio if not already installed
# ! pip install torchaudio
import torchaudio

import matplotlib.pyplot as plt
from typing import Tuple, List, Dict, Optional

!pip install sktime
from sktime.transformations.panel.rocket import MiniRocketMultivariate

# Mount Google Drive and move into the project directory so that relative paths
# used below (e.g. 'Wave-U-Net-Pytorch') resolve from the project root.
# NOTE(review): this cell uses ".../My Drive/..." while the data-loading cell uses
# ".../MyDrive/..."; confirm both point at the same Drive folder.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/My Drive/git_projects/spring_2025_dl_audio_project


# add a path to Wave-U-Net
# (the Wave-U-Net-Pytorch checkout is expected inside the project directory; its
# `model` and `utils` packages are imported from there)
import sys
sys.path.append('Wave-U-Net-Pytorch')

import model.utils as model_utils
import utils
from model.waveunet import Waveunet

# Check to see if we have a GPU available
print("GPU:", torch.cuda.is_available())

Collecting sktime
  Downloading sktime-0.36.1-py3-none-any.whl.metadata (34 kB)
Collecting scikit-base<0.13.0,>=0.6.1 (from sktime)
  Downloading scikit_base-0.12.2-py3-none-any.whl.metadata (8.8 kB)
Downloading sktime-0.36.1-py3-none-any.whl (37.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.0/37.0 MB[0m [31m50.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading scikit_base-0.12.2-py3-none-any.whl (142 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.7/142.7 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: scikit-base, sktime
Successfully installed scikit-base-0.12.2 sktime-0.36.1
Mounted at /content/drive
/content/drive/My Drive/git_projects/spring_2025_dl_audio_project
GPU: True


In [2]:
# I run these commands in the terminal that you get when you pay for Colab.

# %pip install musdb  # has some helpful data structures, also installs ffmpeg and stempeg
# %pip uninstall stempeg    # musdb installs the wrong version of stempeg'

We define the parameters of the model.

In [3]:
# Hyper-parameters for the first Wave-U-Net (speech + accompaniment -> vocal).
model_config = dict(
    num_inputs=256,                    # two 128-bin mel spectrograms (speech and accompaniment) stacked on the channel axis
    num_outputs=128,                   # one 128-bin mel spectrogram (the vocal)
    num_channels=[1024, 2048, 4096],   # 512*2, 512*4, 512*8: feature channels per U-Net level
    instruments=["vocal"],             # single output head named "vocal"
    kernel_size=3,                     # must be odd
    target_output_size=256,            # requested output frames; Waveunet reports the size it actually uses
    conv_type="normal",                # "normal" satisfies Waveunet's assertions
    res="fixed",                       # fixed resampling
    separate=False,                    # with False, the model returns a dict keyed by instrument name
    depth=1,                           # conv layers per block
    strides=2,                         # down/up-sampling stride
)

Load the model onto the GPU (if one is available). The commented-out code in the next cell can be re-enabled to measure how much GPU memory one training step uses and to print a summary of the model.

In [4]:
# Ensure that you have a CUDA-enabled device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate and move the model to GPU
generator = Waveunet(**model_config).to(device)

# # Set up a dummy optimizer and loss function
# optimizer = Adam(generator.parameters(), lr=1e-3)
# loss_fn = L1Loss()

# # Define a dummy batch size
# batch_size = 256

# # Create a dummy input tensor with the required shape
# # model.num_inputs corresponds to the number of channels (256 in your config)
# # model.input_size is the computed length (353, for instance)
# dummy_input = torch.randn(batch_size, generator.num_inputs, generator.input_size, device=device)

# # Create a dummy target tensor with the shape that your model outputs.
# # For a single output branch (vocal), the output shape should be:
# # (batch_size, num_outputs, model.output_size)
# # model.num_outputs is 128 and model.output_size is computed (257 in your case)
# dummy_target = torch.randn(batch_size, generator.num_outputs, generator.output_size, device=device)

# # Reset GPU peak memory stats
# torch.cuda.reset_peak_memory_stats(device)

# # Run a single forward and backward pass
# optimizer.zero_grad()
# # If separate is False, the model returns a dictionary; pass the correct key.
# output = generator(dummy_input)["vocal"]
# loss = loss_fn(output, dummy_target)
# loss.backward()
# optimizer.step()

# # Retrieve GPU memory stats
# peak_memory = torch.cuda.max_memory_allocated(device)
# current_memory = torch.cuda.memory_allocated(device)
# print("Peak GPU memory allocated (bytes):", peak_memory)
# print("Current GPU memory allocated (bytes):", current_memory)

# # Optionally, print a detailed memory summary
# print(torch.cuda.memory_summary(device=device))


# summary(generator, input_size=(generator.num_inputs,  generator.input_size))


Using valid convolutions with 289 inputs and 257 outputs


Optionally compile the model to potentially decrease training time.

If we compile the model, we have to unwrap the compiled module before saving it after training. Use the following code:

```python
orig_generator = generator._orig_mod
path = ""
torch.save(orig_generator.state_dict(), path + "generator_state_dict.pt")
```


## Initialize MiniRocket
The required package (sktime, which provides `MiniRocketMultivariate`) was installed and imported in the first cell.

### CPU Core Allocation for MiniRocketMultivariate

- The implementation of `MiniRocketMultivariate` runs on the **CPU**.
- We need to decide how many cores to allocate for it.
- Some cores will be used by MiniRocket itself, while others are needed for data preparation (e.g., generating spectrograms).
- This allocation likely needs to be **tuned for optimal performance**.
- As a starting point, we detect the number of available cores and give all but one of them to MiniRocket. The share for data preparation is left commented out for now.
- Note: We avoid using *all* available cores to leave some resources for the operating system and other background processes.


In [5]:
import multiprocessing

# Number of logical CPU cores visible to this runtime.
num_cores = multiprocessing.cpu_count()
print(num_cores)
# Give MiniRocket every core but one (leave one for the OS / data preparation),
# but never fewer than one: on a single-core runtime `num_cores - 1` would be 0,
# which is not a valid n_jobs value.
minirocket_n_jobs = max(1, num_cores - 1)
# dataloader_n_jobs = num_cores - minirocket_n_jobs - 1

12


Create the MiniRocket model

In [6]:
import pandas as pd
from time import time

# MiniRocket Discriminator using tsai library
class TsaiMiniRocketDiscriminator(nn.Module):
    """Conditional GAN discriminator: MiniRocket features followed by a small MLP.

    Each input spectrogram is turned into a fixed-length feature vector by
    sktime's ``MiniRocketMultivariate`` (NumPy in, NumPy/DataFrame out, on the
    CPU). The vocal and accompaniment feature vectors are concatenated and
    scored by ``self.classifier``, which ends in a sigmoid.

    NOTE(review): features are computed from ``.detach().cpu().numpy()``, so no
    gradient flows back into the spectrograms passed to ``forward``. Only the
    classifier is trainable, and a generator receives no adversarial gradient
    through this module.
    """

    def __init__(
        self,
        freq_bins=256,
        time_frames=256,
        num_kernels=10000,  # number of MiniRocket convolutional kernels
        hidden_dim=1024,    # width of the first classifier layer
        output_dim=1
    ):
        super().__init__()

        # MiniRocket feature extractor; fitted once via fit_rocket() before use.
        # `minirocket_n_jobs` is a notebook-level global defined in an earlier cell.
        self.rocket = MiniRocketMultivariate(num_kernels=num_kernels, n_jobs=minirocket_n_jobs)
        self.fitted = False
        self.freq_bins = freq_bins
        self.time_frames = time_frames
        self.num_kernels = num_kernels

        # Kept for reference; not used by the methods below.
        self.example_input = np.zeros((1, freq_bins, time_frames))

        # Feature width produced for ONE spectrogram. It starts at num_kernels and
        # is replaced by the real transform width after fitting; sktime may round
        # num_kernels (down to a multiple of 84), so the two can differ.
        self.feature_dim = num_kernels

        self.classifier = nn.Sequential(
            # LazyLinear infers its input size (2 * feature_dim, vocal + accompaniment)
            # on the first forward pass.
            nn.LazyLinear(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Second hidden layer
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Final real/fake probability
            nn.Linear(hidden_dim // 2, output_dim),
            nn.Sigmoid()
        )

    def fit_rocket(self, spectrograms):
        """Fit MiniRocket once, on a single batch rather than the whole dataset.

        Args:
            spectrograms: tensor laid out as (n_instances, n_channels, series_length),
                the layout sktime expects. In this notebook it is a batch of mel
                spectrograms (mel bins act as channels, frames as time).

        This is best effort. On failure the error is printed and the module is
        still marked as fitted, so fitting is not retried. Later transforms will
        then most likely fail, and extract_features falls back to zero features.
        """
        if self.fitted:
            return
        try:
            sample_data = spectrograms.detach().cpu().numpy()
            self.rocket.fit(sample_data)
            self.fitted = True

            # Probe the transform once to learn the actual feature width.
            test_transform = self.rocket.transform(sample_data)
            self.feature_dim = test_transform.shape[1]

            print(f"MiniRocket fitted. Feature dimension: {self.feature_dim}")

        except Exception as e:
            print(f"Error fitting MiniRocket: {e}")
            # Mark as fitted to avoid repeated attempts
            self.fitted = True

    def extract_features(self, spectrogram):
        """Return MiniRocket features for a batch as a float32 tensor.

        Args:
            spectrogram: (batch, channels, time) tensor; channels should match the
                data MiniRocket was fitted on. Fits MiniRocket on it if not yet fitted.

        Returns:
            Tensor of shape (batch, self.feature_dim) on ``spectrogram.device``.
            On any error the error is printed and zeros of that same shape are
            returned, so the classifier's input size stays consistent.
        """
        try:
            if not self.fitted:
                self.fit_rocket(spectrogram)

            spec_np = spectrogram.detach().cpu().numpy()
            features = self.rocket.transform(spec_np)

            # sktime may hand back a DataFrame or an ndarray; np.asarray accepts both.
            # Force float32 so the result matches the classifier's parameter dtype.
            return torch.as_tensor(
                np.asarray(features), dtype=torch.float32, device=spectrogram.device
            )

        except Exception as e:
            print(f"Error in feature extraction: {e}")
            # Use the fitted feature width (not num_kernels) so the fallback has
            # the same shape the LazyLinear layer was initialised with.
            return torch.zeros((spectrogram.shape[0], self.feature_dim),
                               device=spectrogram.device)

    def forward(self, vocals, accompaniment):
        """Score (vocal, accompaniment) pairs as real (close to 1) or fake (close to 0).

        Args:
            vocals: (batch, mel_bins, frames) spectrograms of the vocal.
            accompaniment: (batch, mel_bins, frames) spectrograms of the accompaniment.

        Returns:
            (batch, output_dim) tensor of probabilities from the final sigmoid.
        """
        vocal_features = self.extract_features(vocals)
        accomp_features = self.extract_features(accompaniment)

        # Conditional GAN: the discriminator sees the vocal together with its accompaniment.
        combined_features = torch.cat([vocal_features, accomp_features], dim=1)

        return self.classifier(combined_features)



In [7]:
# del discriminator
# Build the MiniRocket-based discriminator with the class defaults
# (num_kernels=10000, hidden_dim=1024). Its MiniRocket transform is fitted later,
# on one batch of speech, before training starts.
discriminator = TsaiMiniRocketDiscriminator()
# We probably do not need to compile the model
# (presumably because most of its work is the CPU-side sktime transform, which
# torch.compile would not speed up).

# Import Data into Session

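First, we run the code that defines the custom Dataset classes. The datasets were built previously and saved as pickled Dataset objects in .pt files. These class definitions must exist in the session before `torch.load` can restore those objects. In the next cell, we load the Dataset objects.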

In [8]:
class MusdbDataset(Dataset):
  """Paired (accompaniment, vocal) log-mel spectrogram chunks built from MusDB tracks.

  Each track's non-vocal stems (stems[1] + stems[2] + stems[3]) and its vocal stem
  (stems[4]) are mixed down to mono and converted to log-mel spectrograms. Each
  spectrogram is then cut along time into non-overlapping chunks of `steps` frames;
  the trailing partial chunk is dropped.

  Attributes:
    mel_specs: (num_chunks, 2, n_mels, steps); index 0 = accompaniment, 1 = vocal.
    sample_rates: (num_chunks,) sample rate of the track each chunk came from.

  NOTE(review): `librosa` is not imported anywhere in this notebook; import it
  before constructing a new instance. Unpickling a saved instance (torch.load)
  does not run __init__ and does not need it. Earlier versions of this class
  reshaped (n_mels, T) directly to (num_slices, n_mels, steps), which scrambles
  mel bins and frames across chunks; .pt files made by that version should be rebuilt.
  """

  def __init__(self, musDB, steps = 256):
    spec_chunks = []   # one (num_slices, 2, n_mels, steps) tensor per track
    rate_chunks = []   # one (num_slices,) tensor per track

    print("Tracks in MusDB:", len(musDB))

    for track in musDB:
      stems, rate = track.stems, track.rate

      # separate the vocal from the other instruments and convert to mono signals
      audio_novocal = librosa.to_mono(np.transpose(stems[1] + stems[2] + stems[3]))
      audio_vocal = librosa.to_mono(np.transpose(stems[4]))

      # compute log mel spectrograms and convert to pytorch tensors
      logmelspec_novocal = torch.from_numpy(self._mel_spectrogram(audio_novocal, rate))
      logmelspec_vocal = torch.from_numpy(self._mel_spectrogram(audio_vocal, rate))

      num_slices = logmelspec_novocal.shape[1] // steps

      novocal_chunks = self._chunk_time(logmelspec_novocal, num_slices, steps)
      vocal_chunks = self._chunk_time(logmelspec_vocal, num_slices, steps)

      spec_chunks.append(torch.stack((novocal_chunks, vocal_chunks), dim=1))
      # one sample rate per chunk, so len(sample_rates) == len(mel_specs)
      rate_chunks.append(torch.full((num_slices,), float(rate)))

    # Concatenate once at the end instead of growing a tensor inside the loop.
    if spec_chunks:
      self.mel_specs = torch.cat(spec_chunks, 0)
      self.sample_rates = torch.cat(rate_chunks, 0)
    else:
      self.mel_specs = torch.zeros(0, 2, 128, steps)
      self.sample_rates = torch.zeros(0)

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    # returns tuple (mel spectrogram of accompaniment, mel spectrogram of vocal, rate)
    return self.mel_specs[ndx, 0], self.mel_specs[ndx, 1], self.sample_rates[ndx]

  @staticmethod
  def _chunk_time(mel, num_slices, steps):
    # (n_mels, T) -> (num_slices, n_mels, steps): split along time only. Reshape to
    # (n_mels, num_slices, steps) first and then swap the first two axes; reshaping
    # straight to (num_slices, n_mels, steps) would interleave bins and frames.
    mel = mel[0:128, 0:num_slices * steps]
    return torch.reshape(mel, (mel.shape[0], num_slices, steps)).transpose(0, 1)

  def _mel_spectrogram(self, audio, rate):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))


class LibriSpeechDataset(Dataset):
  """Log-mel spectrogram chunks of LibriSpeech .flac files (speaker/chapter/file layout).

  Every .flac file under `path/<speaker>/<chapter>/` is loaded at 44.1 kHz and
  converted to a log-mel spectrogram. The spectrogram is cut along time into
  non-overlapping chunks of `steps` frames; the trailing partial chunk is dropped.
  New files stop being opened once at least `num_specs` chunks have been
  collected. The last file may overshoot, so crop afterwards if an exact count
  is needed. Directories are visited in sorted order so the selection is
  reproducible.

  Attributes:
    mel_specs: (num_chunks, n_mels, steps)
    sample_rates: (num_chunks,) sample rate of the file each chunk came from.

  NOTE(review): `librosa` is not imported anywhere in this notebook; import it
  before constructing a new instance (unpickling with torch.load does not run __init__).
  """

  def __init__(self, path, steps = 256, num_specs = 7647):
    spec_chunks = []   # one (num_slices, n_mels, steps) tensor per file
    rate_chunks = []   # one (num_slices,) tensor per file
    num_chunks = 0
    num_files_opened = 0

    for speaker_dir in sorted(os.listdir(path)):
      speaker_path = os.path.join(path, speaker_dir)
      for chapter_dir in sorted(os.listdir(speaker_path)):
        chapter_path = os.path.join(speaker_path, chapter_dir)
        for file in sorted(os.listdir(chapter_path)):
          # only .flac files, and only until we have the desired number of spectrograms
          if not file.endswith('.flac') or num_chunks >= num_specs:
            continue

          try:
            # get audio file and convert to log mel spectrogram
            speech, rate = librosa.load(os.path.join(chapter_path, file), sr = 44100)
            mel_spec = torch.from_numpy(self._mel_spectrogram(speech, rate))

            # number of full (n_mels x steps) chunks in this file
            num_slices = mel_spec.shape[1] // steps

            # chop off the last bit so that number of stft steps is a multiple of step size,
            # then split along time: (n_mels, T) -> (n_mels, num_slices, steps) -> (num_slices, n_mels, steps)
            mel_spec = mel_spec[:, 0:num_slices * steps]
            mel_spec = torch.transpose(
                torch.reshape(mel_spec, (mel_spec.shape[0], num_slices, steps)), 0, 1)
          except Exception:
            print("failed to open " + file)
            continue

          spec_chunks.append(mel_spec)
          # one sample rate per chunk, so len(sample_rates) == len(mel_specs)
          rate_chunks.append(torch.full((num_slices,), float(rate)))
          num_chunks += num_slices
          num_files_opened += 1

    # Concatenate once at the end instead of growing a tensor inside the loop.
    if spec_chunks:
      self.mel_specs = torch.cat(spec_chunks, 0)
      self.sample_rates = torch.cat(rate_chunks, 0)
    else:
      self.mel_specs = torch.zeros(0, 128, steps)
      self.sample_rates = torch.zeros(0)

    print("opened " + str(num_files_opened) + " files")
    print("collected " + str(self.mel_specs.shape[0]) + " chunks")

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    return self.mel_specs[ndx], self.sample_rates[ndx]

  def _mel_spectrogram(self, audio, rate):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

In [10]:
# Folder in Google Drive that holds the pre-built, pickled Dataset objects.
path = "/content/drive/MyDrive/git_projects/spring_2025_dl_audio_project_data/"

# Full paths to the saved MusdbDataset and LibriSpeechDataset objects.
musdbDataset_path = f"{path}Copy of musdb18_DatasetObject.pt"
librispeechDataset_path = f"{path}Copy of LibriSpeechDatasetObject.pt"

# weights_only=False is required because these files contain whole pickled Python
# objects, not just tensors. Unpickling can execute arbitrary code, so only load
# files you created yourself.
musdb_dataset, librispeech_dataset = (
    torch.load(dataset_file, weights_only=False)
    for dataset_file in (musdbDataset_path, librispeechDataset_path)
)




FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/git_projects/spring_2025_dl_audio_project_data/Copy of musdb18_DatasetObject.pt'

In [None]:
# This fixes the problem with the sample rates
musdb_dataset.sample_rates = torch.full((len(musdb_dataset),), 44100)
librispeech_dataset.sample_rates = torch.full((len(musdb_dataset),), 44100)

# Because of the way the librispeech dataset was constructed, it is slightly longer
# than the musbd dataset. Crop the librispeech dataset with these lines
librispeech_dataset.mel_specs = librispeech_dataset.mel_specs[0:len(musdb_dataset)]
librispeech_dataset.sample_rates = librispeech_dataset.sample_rates[0:len(musdb_dataset)]

### Explore these datasets

In [None]:
# --- Explore the Datasets ---
# For each dataset: overall size and stored tensor shapes, then its first item.
print("=== MusDB Dataset Exploration ===")
for label, value in (
    ("Length:", len(musdb_dataset)),
    ("mel_specs shape:", musdb_dataset.mel_specs.shape),
    ("sample_rates shape:", musdb_dataset.sample_rates.shape),
):
    print(label, value)
print()

# MusDB items are (accompaniment, vocal, sample_rate).
accompaniment, vocal, sample_rate = musdb_dataset[0]
for label, value in (
    ("Sample 0 - Accompaniment shape:", accompaniment.size()),
    ("Sample 0 - Vocal shape:", vocal.size()),
    ("Sample 0 - Sample rate:", sample_rate),
):
    print(label, value)
print()

print("=== LibriSpeech Dataset Exploration ===")
for label, value in (
    ("Length:", len(librispeech_dataset)),
    ("mel_specs shape:", librispeech_dataset.mel_specs.shape),
    ("sample_rates shape:", librispeech_dataset.sample_rates.shape),
):
    print(label, value)
print()

# LibriSpeech items are (speech, sample_rate).
speech, sample_rate = librispeech_dataset[0]
for label, value in (
    ("Sample 0 - Speech shape:", speech.size()),
    ("Sample 0 - Sample rate:", sample_rate),
):
    print(label, value)

## Dataset Helpers Explanation
Why New Dataset Helpers?

We have created new dataset helper classes (i.e., AccompanimentData, VocalData, and SpeechData) so that we can control how the data is padded and later shuffled.

- **Separation of Data:**
We separated the vocal and accompaniment data from the MusDB dataset. In our experiments, we might want to shuffle the speech data independently of the combined music data.

- **Shuffling Considerations:**
For the vocal and accompaniment data, we want to maintain their pairing so that they are shuffled in the same order. In contrast, we want the speech data to be shuffled independently.

- **Future Extensions:**
In the future, we may add another helper class that combines the vocal and accompaniment data to ensure synchronized shuffling in our data loaders.

This modular approach gives us flexibility in handling and preprocessing the data for our GAN training.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

def _fit_time_axis(spec, output_length):
    """Centre-pad (or centre-crop) the last axis of `spec` to exactly `output_length`.

    Padding is zero-valued; when the amount is odd, the extra zero goes in front
    (e.g. 256 -> 289 adds 17 zeros in front and 16 behind). Inputs that are
    already `output_length` long are returned unchanged.
    """
    delta = output_length - spec.size(-1)
    if delta > 0:
        right_pad_len = delta // 2
        left_pad_len = delta - right_pad_len
        return F.pad(spec, (left_pad_len, right_pad_len), "constant", 0)
    if delta < 0:
        # Longer than needed: keep the central `output_length` frames.
        start = (-delta) // 2
        return spec[..., start:start + output_length]
    return spec


class AccompanimentData(Dataset):
    """Accompaniment half of each MusDB item, raw and padded to the Wave-U-Net input length.

    Items are dicts: {"no_pad": spectrogram as stored, "pad": same, fitted to output_length}.
    """

    def __init__(self, musdb_dataset, output_length=289):
        self.musdb_dataset = musdb_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.musdb_dataset)

    def __getitem__(self, index):
        accompaniment, _, _ = self.musdb_dataset[index]  # e.g. [128, 256]
        return {"no_pad": accompaniment,
                "pad": _fit_time_axis(accompaniment, self.output_length)}


class VocalData(Dataset):
    """Vocal half of each MusDB item, raw and padded to the Wave-U-Net input length.

    Items are dicts: {"no_pad": spectrogram as stored, "pad": same, fitted to output_length}.
    """

    def __init__(self, musdb_dataset, output_length=289):
        self.musdb_dataset = musdb_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.musdb_dataset)

    def __getitem__(self, index):
        _, vocal, _ = self.musdb_dataset[index]  # e.g. [128, 256]
        return {"no_pad": vocal, "pad": _fit_time_axis(vocal, self.output_length)}


class SpeechData(Dataset):
    """LibriSpeech spectrograms, raw and padded to the Wave-U-Net input length.

    Items are dicts: {"no_pad": spectrogram, "pad": same, fitted to output_length}.
    """

    def __init__(self, librispeech_dataset, output_length=289):
        self.librispeech_dataset = librispeech_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.librispeech_dataset)

    def __getitem__(self, index):
        speech, _ = self.librispeech_dataset[index]
        # If speech has multiple slices, pick the first slice
        if speech.dim() == 3:
            speech = speech[0]
        return {"no_pad": speech, "pad": _fit_time_axis(speech, self.output_length)}


In [None]:
# print(AccompanimentData(musdb_dataset)[0])
# print(VocalData(musdb_dataset)[0])
# print(SpeechData(librispeech_dataset)[0])

### DataLoader Explanation
What is a DataLoader and Why Do We Need It?

A DataLoader in PyTorch is a utility that wraps a dataset and provides:

- **Batching:** It divides your dataset into batches so that you can train your models with mini-batch gradient descent.

- **Shuffling:** It shuffles the data at every epoch (if specified) to help reduce overfitting and ensure the model sees a diverse set of examples.

- **Parallel Data Loading:** It can load data in parallel using multiple worker processes, speeding up training.

In our case, we create separate DataLoaders for:

- The accompaniment data (paired with vocals) from the MusDB dataset.

- The vocal data (paired with accompaniment) from the MusDB dataset.

- The speech data from the LibriSpeech dataset.

This lets us shuffle the speech data independently, while keeping the vocal/accompaniment pairs synchronized during training.

In [None]:
# Define batch size
batch_size = 32  # Change as needed

# Create data loaders
accompaniment_loader = DataLoader(
    AccompanimentData(musdb_dataset),
    batch_size=batch_size,
    shuffle=False,
    drop_last=True
)
vocal_loader = DataLoader(
    VocalData(musdb_dataset),
    batch_size=batch_size,
    shuffle=False,
    drop_last=True
)
speech_loader = DataLoader(
    SpeechData(librispeech_dataset),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)

# # Print how many batches each DataLoader contains
# print("Accompaniment loader length:", len(accompaniment_loader))
# print("Vocal loader length:", len(vocal_loader))
# print("Speech loader length:", len(speech_loader))

# # Optionally, fetch and print the shape of the first batch
# accompaniment_batch = next(iter(accompaniment_loader))
# vocal_batch = next(iter(vocal_loader))
# speech_batch = next(iter(speech_loader))
# print(accompaniment_batch["pad"])

# print("Accompaniment first batch shape:", accompaniment_batch.shape)
# print("Vocal first batch shape:", vocal_batch.shape)
# print("Speech first batch shape:", speech_batch.shape)


## Second Generator Model
Here we initialize the second generator model. Its purpose is to convert the generated vocals back to normal speech for the cycle GAN. We again use Wave-U-Net, but with a different configuration. The main difference is that this model does not receive the music along with the vocal track.

In [None]:
# Hyper-parameters for the second Wave-U-Net (generated vocal -> speech).
# Note: this rebinds `model_config`; the first generator has already been built.
model_config = dict(
    num_inputs=128,                    # a single 128-bin mel spectrogram (the generated vocal)
    num_outputs=128,                   # 128-bin mel spectrogram of the reconstructed speech
    num_channels=[512, 1024, 2048],    # 256*2, 256*4, 256*8: half the first generator's widths, since the input has half the channels
    instruments=["speech"],            # single output head named "speech"
    kernel_size=3,                     # must be odd
    target_output_size=256,            # requested output frames (post-processing may crop)
    conv_type="normal",                # "normal" satisfies Waveunet's assertions
    res="fixed",                       # fixed resampling
    separate=False,                    # with False, the model returns a dict keyed by instrument name
    depth=1,                           # conv layers per block
    strides=2,                         # down/up-sampling stride
)

# Instantiate the second generator on the same device as the first.
generator_2 = Waveunet(**model_config).to(device)

## Transform Input to generator_2

The output of the first generator is a (batch_size, 128, 257) tensor, but `generator_2` expects a (batch_size, 128, 289) tensor. We therefore pad the last dimension with 16 zeros on each side.

In [None]:
def transform_for_gen_2(batch, output_length = 289):
    """Zero-pad the last (time) axis of `batch` up to `output_length` frames.

    The padding is split evenly, with any odd extra zero placed in front
    (257 -> 289 adds 16 zeros on each side). Inputs that are already
    `output_length` long or longer are returned unchanged.
    """
    shortfall = output_length - batch.size(-1)
    if shortfall <= 0:
        return batch
    back = shortfall // 2
    front = shortfall - back
    return F.pad(batch, (front, back), "constant", 0)


## Train the Cycle GAN
The models are ``generator`` and ``discriminator`` and ``generator_2``.


In [None]:
from tqdm import tqdm
# Add a virtual batch?
def train_epoch(
    generator,
    generator_2,
    discriminator,
    optimizer_D,
    optimizer_G,
    optimizer_G2,
    accompaniment_loader,
    vocal_loader,
    speech_loader,
    l1_loss,
    adversarial_loss,
    device,
    lambda_l1=10,
    lambda_cycle=10):
    """Run one epoch of cycle-GAN training, then print and return the average losses.

    Each step:
      1. `generator` maps [speech ; accompaniment] (concatenated on the channel axis,
         both padded) to a fake vocal spectrogram.
      2. `discriminator` is updated on (real vocal, accompaniment) vs.
         (fake vocal, accompaniment) pairs with `adversarial_loss`.
      3. `generator` and `generator_2` are updated together on
         adversarial + lambda_l1 * L1(fake vocal, real vocal)
         + lambda_cycle * L1(generator_2(fake vocal), speech).

    Args:
        generator, generator_2: models returning dicts keyed by "vocal" / "speech".
        discriminator: called as discriminator(vocals, accompaniment); its output is
            fed to `adversarial_loss` together with 0/1 labels of shape (B, 1).
        optimizer_D, optimizer_G, optimizer_G2: optimizers of the three models.
        accompaniment_loader, vocal_loader: must yield aligned batches (same order)
            of {"no_pad", "pad"} dicts; they are zipped together.
        speech_loader: yields {"no_pad", "pad"} dicts of speech spectrograms.
        l1_loss, adversarial_loss: loss modules (L1Loss / BCELoss in this notebook).
        device: device every batch is moved to.
        lambda_l1, lambda_cycle: weights of the vocal L1 and cycle L1 terms.

    Returns:
        dict of per-batch averages with keys "loss_D", "loss_G", "loss_G_adv",
        "loss_G_L1" and "loss_cycle" (all 0.0 when no batch was produced).

    NOTE(review): TsaiMiniRocketDiscriminator computes its features from
    `.detach().cpu().numpy()`, so with it `loss_G_adv` gives no gradient to
    `generator`; only the L1 and cycle terms actually train the generators.
    """
    totals = dict.fromkeys(
        ("loss_D", "loss_G", "loss_G_adv", "loss_G_L1", "loss_cycle"), 0.0)
    num_batches = 0

    # zip stops at the shortest loader; pass that length so tqdm can show progress.
    num_steps = min(len(accompaniment_loader), len(vocal_loader), len(speech_loader))
    train_loader = tqdm(
        enumerate(zip(zip(accompaniment_loader, vocal_loader), speech_loader)),
        total=num_steps,
    )

    for _, ((accompaniment, vocal), speech) in train_loader:
        # Move data to device and cast to float32
        accompaniment_pad_device = accompaniment["pad"].to(device, dtype=torch.float32)
        speech_pad_device = speech["pad"].to(device, dtype=torch.float32)
        accompaniment_no_pad = accompaniment["no_pad"].to(device, dtype=torch.float32)
        vocals_no_pad = vocal["no_pad"].to(device, dtype=torch.float32)
        speech_no_pad_device = speech["no_pad"].to(device, dtype=torch.float32)

        # Generator input: speech and accompaniment stacked along the channel axis.
        generator_input = torch.cat([speech_pad_device, accompaniment_pad_device], dim=1)

        B = accompaniment_pad_device.size(0)
        real_labels = torch.ones((B, 1), device=device)
        fake_labels = torch.zeros((B, 1), device=device)

        # ---------------------
        # Train the Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        pred_real = discriminator(vocals_no_pad, accompaniment_no_pad)
        loss_D_real = adversarial_loss(pred_real, real_labels)

        fake_singing = generator(generator_input)["vocal"]
        # Crop to the 256 frames of the real spectrograms.
        fake_singing_crop = fake_singing[:, :, :256]

        # Detach so the discriminator update can never backpropagate into the generator.
        pred_fake = discriminator(fake_singing_crop.detach(), accompaniment_no_pad)
        loss_D_fake = adversarial_loss(pred_fake, fake_labels)

        loss_D = 0.5 * (loss_D_real + loss_D_fake)
        loss_D.backward()
        optimizer_D.step()

        # ---------------------
        # Train the Generators
        # ---------------------
        # Both generators are stepped below, so both must start from zero gradients.
        optimizer_G.zero_grad()
        optimizer_G2.zero_grad()

        # Same 256-frame crop the discriminator was trained on.
        pred_fake_for_G = discriminator(fake_singing_crop, accompaniment_no_pad)

        # Reconstruct the speech from the generated singing (cycle consistency).
        fake_singing_padded = transform_for_gen_2(fake_singing)
        reconstructed_speech = generator_2(fake_singing_padded)["speech"][:, :, :256]

        loss_G_adv = adversarial_loss(pred_fake_for_G, real_labels)
        loss_cycle = l1_loss(reconstructed_speech, speech_no_pad_device)
        loss_G_L1 = l1_loss(fake_singing_crop, vocals_no_pad)

        loss_G = loss_G_adv + lambda_l1 * loss_G_L1 + lambda_cycle * loss_cycle
        loss_G.backward()

        optimizer_G.step()
        optimizer_G2.step()

        totals["loss_D"] += loss_D.item()
        totals["loss_G"] += loss_G.item()
        totals["loss_G_adv"] += loss_G_adv.item()
        totals["loss_G_L1"] += loss_G_L1.item()
        totals["loss_cycle"] += loss_cycle.item()
        num_batches += 1

    # max(..., 1) avoids ZeroDivisionError when the loaders yield no batch.
    averages = {name: total / max(num_batches, 1) for name, total in totals.items()}

    print(
        f"Epoch Averages:\n"
        f"Loss_D: {averages['loss_D']:.4f}  "
        f"Loss_G_total: {averages['loss_G']:.4f}  "
        f"Loss_G_adv: {averages['loss_G_adv']:.4f}  "
        f"Loss_G_L1 (vocal): {averages['loss_G_L1']:.4f}  "
        f"Loss_Cycle (speech): {averages['loss_cycle']:.4f}"
    )
    return averages



In [None]:
# Re-derive the training device (same rule as at the top of the notebook).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optional: compile both generators. 'max-autotune' spends extra compile time
# searching for faster kernels. The compiled wrappers expose the original modules
# as `._orig_mod`, which is what should be saved after training.
generator = torch.compile(generator.to(device), mode='max-autotune')
generator_2 = torch.compile(generator_2, mode='max-autotune')

discriminator = discriminator.to(device)

# Peek at one batch from each loader; MiniRocket is fitted on the unpadded speech batch.
accompaniment_batch = next(iter(accompaniment_loader))
vocal_batch = next(iter(vocal_loader))
speech_batch = next(iter(speech_loader))["no_pad"]
print("fitting discriminator")
discriminator.fit_rocket(speech_batch)

# BCE on the discriminator's sigmoid output; L1 for the reconstruction terms.
adversarial_loss, l1_loss = nn.BCELoss().to(device), nn.L1Loss().to(device)

# Adam with beta1 = 0.5 for all three models; the discriminator's learning rate
# is half that of the generators.
optimizer_G, optimizer_G2, optimizer_D = (
    optim.Adam(params, lr=learning_rate, betas=(0.5, 0.999))
    for params, learning_rate in (
        (generator.parameters(), 1e-4),
        (generator_2.parameters(), 1e-4),
        (discriminator.parameters(), 5e-5),
    )
)


In [None]:
# Number of training epochs
num_epochs = 1
# This seems to sometimes help
import gc
gc.collect()
torch.cuda.empty_cache()

for i in range(num_epochs):
    print(f"start epoch {i}")
    train_epoch(
        generator,
        generator_2,
        discriminator,
        optimizer_D,
        optimizer_G,
        optimizer_G2,
        accompaniment_loader,
        vocal_loader,
        speech_loader,
        l1_loss,
        adversarial_loss,
        device
    )

## Save the models

In [None]:
# Assuming we have compiled the generator
# orig_generator = generator._orig_mod
# path = ""
# torch.save(orig_generator.state_dict(), path + "generator_state_dict.pt")
# # Save the discriminator state dict
# torch.save(discriminator.state_dict(), path + "discriminator_state_dict.pt")