<a href="https://colab.research.google.com/github/jjaw89/spring_2025_dl_audio_project/blob/main/cycle_GAN_train_shuffled_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Here we train our first version of the GAN.



## Initialize Wave-U-Net

We start by loading the necessary packages

Wave-U-Net is named ``generator``

In [None]:
# === Check to see if we are in colab ===
import sys

if 'google.colab' in sys.modules:
    colab = True
else:
    colab = False
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
#    !{sys.executable} -m pip install musdb
    %cd /content/drive/My\ Drive/git_projects/spring_2025_dl_audio_project

# === Built-in modules ===
import os
import time
import argparse
import pickle
from functools import partial
from datetime import datetime
from typing import Tuple, List, Dict, Optional

# === Third-party modules ===
import numpy as np
import pandas as pd
import librosa
import torchaudio
import matplotlib.pyplot as plt
from tqdm import tqdm

# === PyTorch ===
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import L1Loss
from torch.nn.utils import spectral_norm
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary



# === Colab-specific installation (only run if in Colab) ===
# Installs sktime (MiniRocket) and musdb, then removes the pip copy of stempeg.
if colab:
    !pip install sktime
    !{sys.executable} -m pip install musdb
    !{sys.executable} -m pip uninstall -y stempeg  # musdb installs the wrong version of stempeg

# === Audio/MIR ===
# NOTE(review): the cd below presumably makes a stempeg checkout located in
# git_projects/ importable after the pip copy was removed above — confirm.
if colab:
  %cd /content/drive/My Drive/git_projects/
import stempeg
import musdb

# === Time series transforms ===
from sktime.transformations.panel.rocket import MiniRocketMultivariate

# The Wave-U-Net modules below (model.*, utils) resolve relative to the
# Wave-U-Net-Pytorch directory: in Colab via the cd, elsewhere via sys.path.
if colab:
  %cd /content/drive/My\ Drive/git_projects/spring_2025_dl_audio_project/Wave-U-Net-Pytorch

# === Local module imports ===
# NOTE(review): "workspace/hdd_project_data" is a relative path, while the
# dataset cell later appends the absolute "/workspace/hdd_project_data/stempeg"
# — confirm which form is intended.
if not colab:
  sys.path.append('Wave-U-Net-Pytorch')
  sys.path.append("workspace/hdd_project_data")


import model.utils as model_utils
import utils
from model.waveunet import Waveunet

# === Device setup ===
# Training in this notebook requires CUDA; stop here with a clear message otherwise.
cuda = torch.cuda.is_available()
if not cuda:
    raise Exception("GPU not available. Please check your setup.")
device = torch.device('cuda')
print("GPU:", torch.cuda.get_device_name(0))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/.shortcut-targets-by-id/14hS-tzi4BinLtBzv8urx6eA2XOhsMNJI/git_projects/spring_2025_dl_audio_project
Collecting stempeg>=0.2.3 (from musdb)
  Using cached stempeg-0.2.3-py3-none-any.whl.metadata (9.0 kB)
Using cached stempeg-0.2.3-py3-none-any.whl (963 kB)
Installing collected packages: stempeg
Successfully installed stempeg-0.2.3
Found existing installation: stempeg 0.2.3
Uninstalling stempeg-0.2.3:
  Successfully uninstalled stempeg-0.2.3
/content/drive/.shortcut-targets-by-id/14hS-tzi4BinLtBzv8urx6eA2XOhsMNJI/git_projects
/content/drive/.shortcut-targets-by-id/14hS-tzi4BinLtBzv8urx6eA2XOhsMNJI/git_projects/spring_2025_dl_audio_project/Wave-U-Net-Pytorch
GPU: NVIDIA A100-SXM4-40GB


## Initialize miniRocket
We start by loading the necessary packages

### CPU Core Allocation for MiniRocketMultivariate

- The implementation of `MiniRocketMultivariate` runs on the **CPU**.
- We need to decide how many cores to allocate for it.
- Some cores will be used by MiniRocket itself, while others are needed for data preparation (e.g., generating spectrograms).
- This allocation likely needs to be **tuned for optimal performance**.
- As a starting point, we detect the number of available cores and split them evenly.
- Note: We avoid using *all* available cores to leave some resources for the operating system and other background processes.


Create the MiniRocket model

In [None]:
# MiniRocket discriminator built on sktime's MiniRocketMultivariate transform
class TsaiMiniRocketDiscriminator(nn.Module):
    """Real/fake discriminator on top of non-trainable MiniRocket features.

    MiniRocket is fitted once, on the first batch it sees. After that it maps
    each input of shape (n_instances, n_dimensions, series_length) to a
    fixed-length feature vector, computed on the CPU through NumPy. A small
    spectrally-normalised MLP turns the vocal features into a sigmoid
    validity score. When ``accompaniment`` is True, the accompaniment
    features are concatenated to the vocal features first.

    Because features go through NumPy, no gradient flows from the score back
    into the input spectrograms.
    """

    def __init__(
        self,
        freq_bins=256,
        time_frames=256,
        num_kernels=10000,  # number of convolutional kernels
        hidden_dim=1024,    # width of the first hidden layer
        output_dim=1,
        accompaniment = True,   # whether or not we feed accompaniment
        n_jobs=None,        # MiniRocket CPU workers; None -> notebook-level minirocket_n_jobs
    ):
        super(TsaiMiniRocketDiscriminator, self).__init__()

        if n_jobs is None:
            # Backwards-compatible default: the CPU allocation chosen elsewhere in the notebook.
            n_jobs = minirocket_n_jobs

        # Feature extractor; fitted lazily in fit_rocket.
        self.rocket = MiniRocketMultivariate(num_kernels=num_kernels, n_jobs=n_jobs)
        self.fitted = False
        self.freq_bins = freq_bins
        self.time_frames = time_frames
        self.num_kernels = num_kernels
        self.accompaniment = accompaniment

        self.example_input = np.zeros((1, freq_bins, time_frames))

        # Features MiniRocket produces per input (e.g. 10000 kernels -> 9996).
        self.feature_dim = self._num_features(num_kernels)
        classifier_input_dim = self.feature_dim * (2 if accompaniment else 1)

        # NOTE: consecutive Dropout layers compound (two p=0.3 layers drop ~51%).
        # The layer list is left unchanged so the Sequential indices, and hence
        # the state_dict keys of saved checkpoints, stay the same.
        self.classifier = nn.Sequential(
            # First reduce the massive dimension to something manageable
            nn.Dropout(0.3),
            spectral_norm(nn.Linear(classifier_input_dim, hidden_dim)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Second hidden layer
            nn.Dropout(0.3),
            spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Final classification layer
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, output_dim),
            nn.Sigmoid()
        )

    @staticmethod
    def _num_features(num_kernels):
        """Return how many features MiniRocketMultivariate emits for ``num_kernels``.

        sktime rounds ``num_kernels`` down to a multiple of 84, with a minimum of 84.
        """
        return max(84, (num_kernels // 84) * 84)

    def fit_rocket(self, spectrograms):
        """Fit MiniRocket on a single batch (not the entire training dataset).

        Also records the resulting feature dimension in ``self.feature_dim``.
        If fitting fails, the error is printed and the transformer is still
        marked as fitted, so the fit is not retried.
        """
        if self.fitted:
            return
        try:
            # MiniRocket expects (n_instances, n_dimensions, series_length)
            sample_data = spectrograms.detach().cpu().numpy()
            self.rocket.fit(sample_data)
            self.fitted = True

            # Transform once to learn the feature dimension
            self.feature_dim = self.rocket.transform(sample_data).shape[1]
            print(f"MiniRocket fitted. Feature dimension: {self.feature_dim}")
        except Exception as e:
            print(f"Error fitting MiniRocket: {e}")
            self.fitted = True  # Mark as fitted to avoid repeated attempts

    def extract_features(self, spectrogram):
        """Return MiniRocket features for a batch of spectrograms.

        Args:
            spectrogram: tensor passed to MiniRocket as
                (n_instances, n_dimensions, series_length).

        Returns:
            float32 tensor [batch, self.feature_dim] on ``spectrogram.device``.
            On failure, the error is printed and a zero tensor of that shape
            is returned, so the classifier still receives a valid input size.
        """
        try:
            if not self.fitted:
                self.fit_rocket(spectrogram)

            spec_np = spectrogram.detach().cpu().numpy()
            features = self.rocket.transform(spec_np)
            # sktime returns a DataFrame; the classifier weights are float32.
            return torch.as_tensor(
                np.asarray(features), dtype=torch.float32, device=spectrogram.device
            )
        except Exception as e:
            print(f"Error in feature extraction: {e}")
            return torch.zeros((spectrogram.shape[0], self.feature_dim),
                               device=spectrogram.device)

    def forward(self, vocals, accompaniment = None):
        """Score vocals (optionally conditioned on accompaniment) as real or fake.

        Args:
            vocals: spectrogram batch, e.g. [batch, n_mels, frames].
            accompaniment: spectrogram batch of the same shape. Required when
                the discriminator was built with ``accompaniment=True``.

        Returns:
            tensor [batch, output_dim] of sigmoid scores.

        Raises:
            ValueError: if accompaniment is required but not given.
        """
        if self.accompaniment and accompaniment is None:
            raise ValueError("accompaniment is required when accompaniment=True")

        output_features = self.extract_features(vocals)
        if self.accompaniment:
            # Conditional GAN: concatenate vocal and accompaniment features
            accomp_features = self.extract_features(accompaniment)
            output_features = torch.cat([output_features, accomp_features], dim=1)

        return self.classifier(output_features)



# Import Data into Session

First, we run the code that defines the custom Dataset objects. The Datasets were compiled previously and saved in .pt files. In the next cell, we load those Dataset objects.

In [None]:
if colab:
  %cd /content/drive/My Drive/git_projects/spring_2025_dl_audio_projects
else:
  sys.path.append('/workspace/hdd_project_data/stempeg')

import stempeg


class MusdbDataset(Dataset):
  """Paired (accompaniment, vocal) log-mel windows cut from MUSDB tracks.

  For each track, stems 1-3 are summed and mixed to mono as the
  accompaniment, and stem 4 is mixed to mono as the vocal. Both are turned
  into log-mel spectrograms and cut into ``window_size``-frame windows.
  There are ``window_size // step_size`` passes, and pass k starts at frame
  ``k * step_size``, so windows from different passes overlap.

  Attributes:
    mel_specs: tensor [N, 2, 128, window_size]; channel 0 is the
      accompaniment, channel 1 the vocal.
    sample_rates: tensor [N] holding each window's sample rate.
  """

  def __init__(self, musDB, window_size = 256, step_size = 128):
    # Collect chunks and concatenate once at the end (repeated torch.cat is quadratic).
    spec_chunks = []
    rate_chunks = []
    num_songs = 0
    num_windows = 0

    for track in musDB:
      stems, rate = track.stems, track.rate
      num_songs += 1

      # separate the vocal from other instruments and convert to mono signal
      audio_novocal = librosa.to_mono(np.transpose(stems[1] + stems[2] + stems[3]))
      audio_vocal = librosa.to_mono(np.transpose(stems[4]))

      # compute log mel spectrogram and convert to pytorch tensor
      logmelspec_novocal = torch.from_numpy(self._mel_spectrogram(audio_novocal, rate))
      logmelspec_vocal = torch.from_numpy(self._mel_spectrogram(audio_vocal, rate))

      # Both spectrograms have the same length, so their passes line up one-to-one.
      novocal_passes = self._slice_windows(logmelspec_novocal, window_size, step_size)
      vocal_passes = self._slice_windows(logmelspec_vocal, window_size, step_size)
      for novocal, vocal in zip(novocal_passes, vocal_passes):
        spec_chunks.append(torch.stack((novocal, vocal), dim=1))
        rate_chunks.append(torch.full((novocal.shape[0],), rate))
        num_windows += novocal.shape[0]

      if num_songs % 5 == 0:
        print(str(num_songs) + " songs processed; produced " + str(num_windows) + " spectrograms")

    if spec_chunks:
      self.mel_specs = torch.cat(spec_chunks, 0)
      self.sample_rates = torch.cat(rate_chunks, 0)
    else:
      self.mel_specs = torch.zeros(0, 2, 128, window_size)
      self.sample_rates = torch.zeros(0, dtype=torch.long)

  @staticmethod
  def _slice_windows(spec, window_size, step_size):
    """Cut a [n_mels, T] spectrogram into windows, one tensor per pass.

    Pass k starts at frame ``k * step_size`` and yields a tensor of shape
    [n, n_mels, window_size]. Trailing frames that do not fill a whole
    window are dropped.
    """
    n_mels = spec.shape[0]
    passes = []
    for k in range(window_size // step_size):
      cropped = spec[:, k * step_size:]
      num_slices = cropped.shape[1] // window_size
      # chop off the last bit so that number of stft steps is a multiple of window_size
      cropped = cropped[:, :num_slices * window_size]
      # [n_mels, n*W] -> [n, n_mels, W]
      passes.append(torch.transpose(torch.reshape(cropped, (n_mels, num_slices, window_size)), 0, 1))
    return passes

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    # returns tuple (mel spectrogram of accompaniment, mel spectrogram of vocal, rate)
    return self.mel_specs[ndx, 0], self.mel_specs[ndx, 1], self.sample_rates[ndx]

  def _mel_spectrogram(self, audio, rate):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

  def cat(self, other_ds):
    """Append another dataset's windows and sample rates to this one in place."""
    self.mel_specs = torch.cat((self.mel_specs, other_ds.mel_specs), 0)
    self.sample_rates = torch.cat((self.sample_rates, other_ds.sample_rates), 0)



class SingingDataset(Dataset):
  """Log-mel windows of the isolated vocal stem (stems[4]) of MUSDB tracks.

  The vocal is mixed to mono, converted to a log-mel spectrogram and cut into
  ``window_size``-frame windows. There are ``window_size // step_size``
  passes, and pass k starts at frame ``k * step_size``.

  Attributes:
    mel_specs: tensor [N, 128, window_size].
    sample_rates: tensor [N] holding each window's sample rate.
  """

  def __init__(self, musDB, window_size = 256, step_size = 128):
    # Collect chunks and concatenate once at the end (repeated torch.cat is quadratic).
    spec_chunks = []
    rate_chunks = []
    num_songs = 0
    num_windows = 0

    for track in musDB:
      stems, rate = track.stems, track.rate
      num_songs += 1

      # load the vocal
      vocal = librosa.to_mono(np.transpose(stems[4]))

      # compute log mel spectrogram and convert to pytorch tensor
      mel_spec = torch.from_numpy(self._mel_spectrogram(vocal, rate))

      for windows in self._slice_windows(mel_spec, window_size, step_size):
        spec_chunks.append(windows)
        rate_chunks.append(torch.full((windows.shape[0],), rate))
        num_windows += windows.shape[0]

      # progress report inside the track loop, so it fires every 5 songs
      if num_songs % 5 == 0:
        print(str(num_songs) + " songs processed; produced " + str(num_windows) + " spectrograms")

    if spec_chunks:
      self.mel_specs = torch.cat(spec_chunks, 0)
      self.sample_rates = torch.cat(rate_chunks, 0)
    else:
      self.mel_specs = torch.zeros(0, 128, window_size)
      self.sample_rates = torch.zeros(0, dtype=torch.long)

  @staticmethod
  def _slice_windows(spec, window_size, step_size):
    """Cut a [n_mels, T] spectrogram into windows, one tensor per pass.

    Pass k starts at frame ``k * step_size`` and yields a tensor of shape
    [n, n_mels, window_size]. Trailing frames that do not fill a whole
    window are dropped.
    """
    n_mels = spec.shape[0]
    passes = []
    for k in range(window_size // step_size):
      cropped = spec[:, k * step_size:]
      num_slices = cropped.shape[1] // window_size
      # chop off the last bit so that number of stft steps is a multiple of window_size
      cropped = cropped[:, :num_slices * window_size]
      # [n_mels, n*W] -> [n, n_mels, W]
      passes.append(torch.transpose(torch.reshape(cropped, (n_mels, num_slices, window_size)), 0, 1))
    return passes

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    # returns tuple (mel spectrogram of vocal, rate)
    return self.mel_specs[ndx], self.sample_rates[ndx]

  def _mel_spectrogram(self, audio, rate):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))



class LibriSpeechDataset(Dataset):
    """Log-mel windows cut from LibriSpeech ``.flac`` files.

    Files are read from ``path/<speaker>/<chapter>/*.flac`` in sorted order,
    so the selection is reproducible. Each file is loaded at 44.1 kHz,
    converted to a log-mel spectrogram and cut into ``window_size``-frame
    windows. There are ``window_size // step_size`` passes, and pass k starts
    at frame ``k * step_size``. Loading stops once at least ``num_specs``
    windows exist; the last file opened may overshoot.

    Attributes:
        mel_specs: tensor [N, 128, window_size].
        sample_rates: tensor [N] (44100 for every window).
    """

    def __init__(self, path, window_size = 256, step_size = 128, num_specs = 7647*2):
        # Collect chunks and concatenate once at the end (repeated torch.cat is quadratic).
        spec_chunks = []
        rate_chunks = []
        num_windows = 0
        num_files_opened = 0

        for file_path in self._flac_files(path):
            # stop when we hit the desired number of spectrograms (num_specs)
            if num_windows >= num_specs:
                break

            # get audio file and convert to log mel spectrogram
            speech, rate = librosa.load(file_path, sr = 44100)
            mel_spec = torch.from_numpy(self._mel_spectrogram(speech, rate))
            num_files_opened += 1

            for windows in self._slice_windows(mel_spec, window_size, step_size):
                spec_chunks.append(windows)
                rate_chunks.append(torch.full((windows.shape[0],), rate))
                num_windows += windows.shape[0]

            if num_files_opened % 50 == 0:
                print("opened " + str(num_files_opened) + " files and produced " + str(num_windows) + " spectrograms")

        if spec_chunks:
            self.mel_specs = torch.cat(spec_chunks, 0)
            self.sample_rates = torch.cat(rate_chunks, 0)
        else:
            self.mel_specs = torch.zeros(0, 128, window_size)
            self.sample_rates = torch.zeros(0, dtype=torch.long)

    @staticmethod
    def _flac_files(path):
        """Yield ``.flac`` paths under ``path/<speaker>/<chapter>/`` in sorted order.

        Entries that are not directories at the speaker and chapter levels
        are skipped.
        """
        for speaker_dir in sorted(os.listdir(path)):
            speaker_path = os.path.join(path, speaker_dir)
            if not os.path.isdir(speaker_path):
                continue
            for chapter_dir in sorted(os.listdir(speaker_path)):
                chapter_path = os.path.join(speaker_path, chapter_dir)
                if not os.path.isdir(chapter_path):
                    continue
                for file in sorted(os.listdir(chapter_path)):
                    if file.endswith('.flac'):
                        yield os.path.join(chapter_path, file)

    @staticmethod
    def _slice_windows(spec, window_size, step_size):
        """Cut a [n_mels, T] spectrogram into windows, one tensor per pass.

        Pass k starts at frame ``k * step_size`` and yields a tensor of shape
        [n, n_mels, window_size]. Trailing frames that do not fill a whole
        window are dropped.
        """
        n_mels = spec.shape[0]
        passes = []
        for k in range(window_size // step_size):
            cropped = spec[:, k * step_size:]
            num_slices = cropped.shape[1] // window_size
            # chop off the last bit so that number of stft steps is a multiple of window_size
            cropped = cropped[:, :num_slices * window_size]
            # [n_mels, n*W] -> [n, n_mels, W]
            passes.append(torch.transpose(torch.reshape(cropped, (n_mels, num_slices, window_size)), 0, 1))
        return passes

    def __len__(self):
        return self.mel_specs.shape[0]

    def __getitem__(self, ndx):
        return self.mel_specs[ndx], self.sample_rates[ndx]

    def _mel_spectrogram(self, audio, rate):
        # compute the log-mel-spectrogram of the audio at the given sample rate
        return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

[Errno 2] No such file or directory: '/content/drive/My Drive/git_projects/spring_2025_dl_audio_projects'
/content/drive/.shortcut-targets-by-id/14hS-tzi4BinLtBzv8urx6eA2XOhsMNJI/git_projects/spring_2025_dl_audio_project/Wave-U-Net-Pytorch


### Explore these datasets

## Dataset Helpers Explanation
Why New Dataset Helpers?

We have created new dataset helper classes (AccompanimentVocalData and SpeechData) so that we can control how the data is padded and later shuffled.

- **Separation of Data:**
We separated the vocal and accompaniment data from the MusDB dataset. In our experiments, we might want to shuffle the speech data independently of the combined music data.

- **Shuffling Considerations:**
For the vocal and accompaniment data, we want to maintain their pairing so that they are shuffled in the same order. In contrast, we want the speech data to be shuffled independently.

- **Synchronized Pairs:**
AccompanimentVocalData returns the vocal and accompaniment of the same window together, so shuffling a single data loader keeps each pair synchronized.

This modular approach gives us flexibility in handling and preprocessing the data for our GAN training.

In [None]:


class AccompanimentVocalData(Dataset):
  """Wrap a MusdbDataset and return each (accompaniment, vocal) pair raw and padded.

  Items are dicts with the keys "acc_no_pad", "voc_no_pad", "acc_pad" and
  "voc_pad". The padded tensors are zero-padded on the last (time) axis to
  ``output_length`` frames, centred, with any odd extra frame on the left.
  Inputs already at least ``output_length`` long are returned unpadded
  (same rule as transform_for_gen_2).
  """

  def __init__(self, musdb_dataset, output_length = 289):
    self.musdb = musdb_dataset
    self.out_len = output_length

  def __len__(self):
    return len(self.musdb)

  def __getitem__(self, ndx):
    acc, voc, _ = self.musdb[ndx]
    delta = self.out_len - acc.size(-1)

    # Default to the unpadded tensors so the keys are always bound
    # (previously an UnboundLocalError when delta <= 0).
    acc_pad, voc_pad = acc, voc
    if delta > 0:
      # Half the remainder goes to the front (e.g. 17 left / 16 right for 256 -> 289)
      left_pad_len = (delta // 2) + (delta % 2)
      right_pad_len = delta // 2
      acc_pad = F.pad(acc, (left_pad_len, right_pad_len), "constant", 0)
      voc_pad = F.pad(voc, (left_pad_len, right_pad_len), "constant", 0)

    return {"acc_no_pad" : acc,
            "voc_no_pad" : voc,
            "acc_pad": acc_pad,
            "voc_pad" : voc_pad
            }


class SpeechData(Dataset):
    """Wrap a LibriSpeechDataset and return each speech spectrogram raw and padded.

    Items are dicts ``{"no_pad": speech, "pad": padded}``. The padded tensor
    is zero-padded on the last (time) axis to ``output_length`` frames,
    centred, with any odd extra frame on the left. Inputs already at least
    ``output_length`` long are returned unpadded.
    """

    def __init__(self, librispeech_dataset, output_length=289):
        self.librispeech_dataset = librispeech_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.librispeech_dataset)

    def __getitem__(self, index):
        speech, _ = self.librispeech_dataset[index]
        # If speech has multiple slices, pick the first slice
        if speech.dim() == 3:
            speech = speech[0]  # shape: [128, 256]
        delta = self.output_length - speech.size(-1)

        # Default to the unpadded tensor so "pad" is always bound
        # (previously an UnboundLocalError when delta <= 0).
        speech_pad = speech
        if delta > 0:
            left_pad_len = (delta // 2) + (delta % 2)
            right_pad_len = delta // 2
            speech_pad = F.pad(speech, (left_pad_len, right_pad_len), "constant", 0)
        return {"no_pad" : speech, "pad" : speech_pad}

def transform_for_gen_2(batch, output_length = 289):
    """Centre-pad the last axis of ``batch`` with zeros up to ``output_length``.

    When the padding is odd, the extra zero goes on the left. A batch that is
    already at least ``output_length`` long is returned as the same object.
    """
    shortfall = output_length - batch.size(-1)
    if shortfall <= 0:
        return batch
    right = shortfall // 2
    left = shortfall - right
    return F.pad(batch, (left, right), "constant", 0)



### DataLoader Explanation
What is a DataLoader and Why Do We Need It?

A DataLoader in PyTorch is a utility that wraps a dataset and provides:

- **Batching:** It divides your dataset into batches so that you can train your models with mini-batch gradient descent.

- **Shuffling:** It shuffles the data at every epoch (if specified) to help reduce overfitting and ensure the model sees a diverse set of examples.

- **Parallel Data Loading:** It can load data in parallel using multiple worker processes, speeding up training.

In our case, we create separate DataLoaders for:

- The paired accompaniment and vocal data from the MusDB dataset (via AccompanimentVocalData).

- The speech data from the LibriSpeech dataset (via SpeechData).

This lets us shuffle the speech data independently, while keeping the vocal/accompaniment pairs synchronized during training.

In [None]:


# # Print how many batches each DataLoader contains
# print("Accompaniment loader length:", len(accompaniment_loader))
# print("Vocal loader length:", len(vocal_loader))
# print("Speech loader length:", len(speech_loader))

# # Optionally, fetch and print the shape of the first batch
# accompaniment_batch = next(iter(accompaniment_loader))
# vocal_batch = next(iter(vocal_loader))
# speech_batch = next(iter(speech_loader))
# print(accompaniment_batch["pad"])

# print("Accompaniment first batch shape:", accompaniment_batch.shape)
# print("Vocal first batch shape:", vocal_batch.shape)
# print("Speech first batch shape:", speech_batch.shape)


## Second Generator Model
Here we initialize the second generator model, whose purpose is to convert the generated vocals back to normal speech for the cycle GAN. We again use Wave-U-Net, but with a different configuration. The main difference is that we will not input the music along with the vocal track.

## Transform Input to generator_2

The output of the generator model is a (batch_size, 128, 257) tensor. The model expects a tensor of size (batch_size, 128, 289), so we pad the last dimension with 16 zeros on each side.

## Train the Cycle GAN
The models are ``generator`` and ``discriminator`` and ``generator_2``.


In [None]:
##### Code to convert spectrograms to audio and play audio ######

import IPython.display as ipd
def convert_to_audio(mel_spectrograms, n_fft=2048, hop_length=512, power=2.0, n_iter=32,
                     sr=44100, from_db=True):
    """Invert (log-)mel spectrograms to waveforms with Griffin-Lim.

    Args:
        mel_spectrograms: iterable of 2-D torch tensors [n_mels, frames].
        n_fft, hop_length: STFT parameters. The defaults match librosa's
            ``melspectrogram`` defaults, which the datasets in this notebook use.
        power: exponent of the mel magnitudes (2.0 = power spectrogram).
        n_iter: number of Griffin-Lim iterations.
        sr: sample rate the spectrograms were computed at.
        from_db: the datasets store ``librosa.power_to_db`` output, and
            ``mel_to_stft`` expects power values, so by default the dB
            scaling is undone first. Pass False for power-scale inputs.

    Returns:
        list of 1-D numpy waveforms, one per input spectrogram.
    """
    audio_files = []

    for mel_spec in mel_spectrograms:
        mel = mel_spec.detach().cpu().numpy()
        if from_db:
            mel = librosa.db_to_power(mel)

        # Convert Mel spectrogram back to linear spectrogram
        linear_spec = librosa.feature.inverse.mel_to_stft(
            mel, sr=sr, n_fft=n_fft, power=power
        )

        # Apply Griffin-Lim algorithm for phase reconstruction
        audio = librosa.griffinlim(
            linear_spec, hop_length=hop_length, n_iter=n_iter
        )
        audio_files.append(audio)

    return audio_files

def display_audio(audio_file):
    """Render an inline notebook audio player for a waveform at 44.1 kHz."""
    player = ipd.Audio(audio_file, rate=44100)
    ipd.display(player)

# A Different Architecture

The code below defines a training loop with a slightly different cycle GAN architecture. In this model, we have
- generator_1 - a Wave-U-Net model trained to accept speech and an accompaniment to produce a fake vocal performance
- generator_2 - a Wave-U-Net model trained to accept singing and produce fake speech
- discriminator_1 - a MiniRocket classifier that determines real vs fake vocal performances
- discriminator_2 - a MiniRocket classifier that determines real vs fake speech.

The training loop iterates the following process on a triple of (vocal, accompaniment, and speech).


A. Speech + Accompaniment --> Vocal --> Reconstructed Speech
1.   Feed the speech and accompaniment into generator_1 to produce a fake vocal
2.   discriminator_1 distinguishes between real vocals and the output of generator_1
3.   generator_2 takes the fake singing output and generates reconstructed speech
4.   Compute L_1 loss between the input speech and the reconstructed speech

B. Vocal --> Fake Speech + Real Accompaniment --> Reconstructed Vocal


1. A real vocal is fed into generator_2 which produces fake speech (the corresponding accompaniment is used later)
2. discriminator_2 distinguishes between real speech and the fake speech produced by generator_2
3. generator_1 takes the fake speech and the real accompaniment and produces a reconstructed vocal
4. Compute L_1 loss between input vocal and the reconstructed vocal



In [None]:
# ----- Single Epoch Training Function -----
import contextlib


@contextlib.contextmanager
def _frozen(*modules):
    """Temporarily set ``requires_grad=False`` on every parameter of ``modules``.

    The previous per-parameter flags are restored on exit.
    """
    saved = [(p, p.requires_grad) for m in modules for p in m.parameters()]
    for p, _ in saved:
        p.requires_grad_(False)
    try:
        yield
    finally:
        for p, flag in saved:
            p.requires_grad_(flag)


def _mean_grad_norm(module):
    """Return the mean L2 norm of populated parameter gradients, or None if none are set."""
    norms = [p.grad.norm().item() for p in module.parameters() if p.grad is not None]
    return sum(norms) / len(norms) if norms else None


def cycle_train_epoch(
    generator_vocal,
    generator_speech,
    discriminator_vocal,
    discriminator_speech,
    optimizer_DV,
    optimizer_DS,
    optimizer_GV,
    optimizer_GS,
    acc_voc_loader,
    speech_loader,
    l1_loss,
    mse_loss,
    lambda_l1,
    lambda_cycle,
    bce_loss,
    device,
    virtual_batch_size,
    clip_length,
    input_size_generators,
    save_output=False,
    smart_discriminator=False,
    batch_size = 32
):
    """Run one epoch of cycle-GAN training.

    Each step takes a paired (accompaniment, vocal) batch and a speech batch
    and runs both cycles:
      speech + accompaniment -> fake vocal -> reconstructed speech
      vocal -> fake speech (+ accompaniment) -> reconstructed vocal

    Generators minimise MSE-to-"real" adversarial losses plus L1 cycle losses,
    weighted by ``lambda_cycle``. Discriminators minimise BCE. Gradients are
    accumulated over ``virtual_batch_size`` backward passes per optimiser
    step. A discriminator is trained only while the exponential moving
    average of its loss exceeds 0.6; with ``smart_discriminator=True`` it is
    trained on every batch.

    The discriminators are frozen while the generator-adversarial terms are
    computed, so that objective does not leak gradients into their
    parameters. NOTE: with the MiniRocket discriminators in this notebook,
    features are extracted via NumPy, so the adversarial terms carry no
    gradient back to the generators either.

    ``lambda_l1``, ``save_output`` and ``batch_size`` are currently unused and
    kept for API compatibility. Labels are sized from the predictions, so a
    short final batch is handled.

    Returns:
        dict of epoch-averaged losses, average gradient norms and the number
        of discriminator optimiser steps.
    """
    total_loss_DV = total_loss_DS = total_loss_GV = total_loss_GS = 0.0
    total_loss_adv_vocal = total_loss_adv_speech = 0.0
    total_loss_cycle_vocal = total_loss_cycle_speech = 0.0
    # Mean gradient norms recorded at each optimiser step (vanishing-gradient diagnostics).
    grad_norms_DV, grad_norms_DS, grad_norms_GV, grad_norms_GS = [], [], [], []
    num_batches = 0

    alpha = 0.9          # EMA factor for the running discriminator losses (gates D training)
    running_loss_DV = 0.0
    running_loss_DS = 0.0
    dv_threshold = 0.6
    ds_threshold = 0.6

    for optimizer in (optimizer_DV, optimizer_DS, optimizer_GV, optimizer_GS):
        optimizer.zero_grad()

    # Backward passes accumulated into each discriminator since its last step.
    num_DV_backwards = num_DS_backwards = 0

    for acc_voc, speech in tqdm(zip(acc_voc_loader, speech_loader), desc="Training Batches"):
        # Padded generator inputs
        x_acc = acc_voc["acc_pad"].float().to(device)       # [B,128,289]
        x_voc = acc_voc["voc_pad"].float().to(device)
        x_speech = speech["pad"].float().to(device)          # [B,128,289]
        x_in = torch.cat([x_speech, x_acc], dim=1)           # [B,256,289]

        # Unpadded discriminator inputs / cycle targets. These must share the
        # device and dtype of the generator outputs they are compared with.
        acc_np = acc_voc["acc_no_pad"].float().to(device)
        voc_np = acc_voc["voc_no_pad"].float().to(device)
        speech_np = speech["no_pad"].float().to(device)

        # Forward translations
        fake_vocal = generator_vocal(x_in)["vocal"]
        fake_vocal_crop = torch.narrow(fake_vocal, 2, 0, clip_length)
        fake_speech = generator_speech(x_voc)["speech"]
        fake_speech_crop = torch.narrow(fake_speech, 2, 0, clip_length)

        # Cycle reconstructions
        fake_vocal_pad = transform_for_gen_2(fake_vocal, input_size_generators)
        rec_speech_crop = torch.narrow(generator_speech(fake_vocal_pad)["speech"], 2, 0, clip_length)
        fake_speech_pad = transform_for_gen_2(fake_speech, input_size_generators)
        fake_speech_with_acc = torch.cat([fake_speech_pad, x_acc], dim=1)
        rec_vocal_crop = torch.narrow(generator_vocal(fake_speech_with_acc)["vocal"], 2, 0, clip_length)

        # Discriminator losses (generator outputs detached)
        pred_real_vocal = discriminator_vocal(voc_np, acc_np)
        pred_fake_vocal_D = discriminator_vocal(fake_vocal_crop.detach(), acc_np)
        pred_real_speech = discriminator_speech(speech_np)
        pred_fake_speech_D = discriminator_speech(fake_speech_crop.detach())

        loss_DV = 0.5 * (bce_loss(pred_real_vocal, torch.ones_like(pred_real_vocal))
                         + bce_loss(pred_fake_vocal_D, torch.zeros_like(pred_fake_vocal_D)))
        loss_DS = 0.5 * (bce_loss(pred_real_speech, torch.ones_like(pred_real_speech))
                         + bce_loss(pred_fake_speech_D, torch.zeros_like(pred_fake_speech_D)))

        # Adversarial losses: generators try to make the discriminators say "real".
        # The discriminators are frozen here so this objective cannot accumulate
        # into their .grad and later be applied by optimizer_DV / optimizer_DS.
        with _frozen(discriminator_vocal, discriminator_speech):
            pred_fake_vocal = discriminator_vocal(fake_vocal_crop, acc_np)
            pred_fake_speech = discriminator_speech(fake_speech_crop)
        loss_adv_vocal = mse_loss(pred_fake_vocal, torch.ones_like(pred_fake_vocal))
        loss_adv_speech = mse_loss(pred_fake_speech, torch.ones_like(pred_fake_speech))

        loss_cycle_vocal = l1_loss(rec_vocal_crop, voc_np)
        loss_cycle_speech = l1_loss(rec_speech_crop, speech_np)

        # NOTE: COULD INCLUDE IDENTITY LOSS
        loss_GV = (loss_adv_vocal + lambda_cycle * loss_cycle_speech) / (1 + lambda_cycle)
        loss_GS = (loss_adv_speech + lambda_cycle * loss_cycle_vocal) / (1 + lambda_cycle)

        running_loss_DV = alpha * running_loss_DV + (1 - alpha) * loss_DV.item()
        running_loss_DS = alpha * running_loss_DS + (1 - alpha) * loss_DS.item()

        # Generators: accumulate, then step every virtual_batch_size batches
        ((loss_GV + loss_GS) / virtual_batch_size).backward()
        if (num_batches + 1) % virtual_batch_size == 0:
            for generator, norms in ((generator_vocal, grad_norms_GV),
                                     (generator_speech, grad_norms_GS)):
                norm = _mean_grad_norm(generator)
                if norm is not None:
                    norms.append(norm)
            optimizer_GV.step()
            optimizer_GS.step()
            optimizer_GV.zero_grad()
            optimizer_GS.zero_grad()

        # Discriminators: gated by their running loss. Each steps after exactly
        # virtual_batch_size accumulated backward passes.
        if running_loss_DV > dv_threshold or smart_discriminator:
            (loss_DV / virtual_batch_size).backward()
            num_DV_backwards += 1
            if num_DV_backwards >= virtual_batch_size:
                norm = _mean_grad_norm(discriminator_vocal)
                if norm is not None:
                    grad_norms_DV.append(norm)
                optimizer_DV.step()
                optimizer_DV.zero_grad()
                num_DV_backwards = 0

        if running_loss_DS > ds_threshold or smart_discriminator:
            (loss_DS / virtual_batch_size).backward()
            num_DS_backwards += 1
            if num_DS_backwards >= virtual_batch_size:
                norm = _mean_grad_norm(discriminator_speech)
                if norm is not None:
                    grad_norms_DS.append(norm)
                optimizer_DS.step()
                optimizer_DS.zero_grad()
                num_DS_backwards = 0

        # Accumulate metrics
        total_loss_DV += loss_DV.item()
        total_loss_DS += loss_DS.item()
        total_loss_adv_vocal += loss_adv_vocal.item()
        total_loss_adv_speech += loss_adv_speech.item()
        total_loss_cycle_vocal += loss_cycle_vocal.item()
        total_loss_cycle_speech += loss_cycle_speech.item()
        total_loss_GV += loss_GV.item()
        total_loss_GS += loss_GS.item()
        num_batches += 1

    denom = max(num_batches, 1)  # avoid ZeroDivisionError on an empty loader

    def _avg(values):
        return sum(values) / len(values) if values else 0.0

    return {
        "loss_DV":      total_loss_DV / denom,
        "loss_DS":      total_loss_DS / denom,
        "loss_GV":      total_loss_GV / denom,
        "loss_GS":      total_loss_GS / denom,
        "loss_adv_vocal":  total_loss_adv_vocal / denom,
        "loss_adv_speech":  total_loss_adv_speech / denom,
        "loss_cycle_vocal":  total_loss_cycle_vocal / denom,
        "loss_cycle_speech":  total_loss_cycle_speech / denom,
        "avg_grad_norm_DV": _avg(grad_norms_DV),
        "avg_grad_norm_DS": _avg(grad_norms_DS),
        "avg_grad_norm_GV": _avg(grad_norms_GV),
        "avg_grad_norm_GS": _avg(grad_norms_GS),
        "num_DV_updates" : len(grad_norms_DV),
        "num_DS_updates" : len(grad_norms_DS)
    }

# ----- Multi-Epoch Training Function -----
def _print_epoch_metrics(epoch, metrics):
    """Print one epoch's averaged losses, gradient norms and update counts.

    Args:
        epoch: 0-based epoch index (printed 1-based).
        metrics: dict as returned by ``cycle_train_epoch``.
    """
    print(f"Epoch {epoch+1} Metrics:")
    for label, key in (
        ("Loss_DV", "loss_DV"),
        ("Loss_DS", "loss_DS"),
        ("Loss_adv_vocal", "loss_adv_vocal"),
        ("Loss_adv_speech", "loss_adv_speech"),
        ("Loss_Cycle Vocal", "loss_cycle_vocal"),
        ("Loss_Cycle Speech", "loss_cycle_speech"),
        ("Grad Norm DV", "avg_grad_norm_DV"),
        ("Grad Norm DS", "avg_grad_norm_DS"),
        ("Grad Norm GV", "avg_grad_norm_GV"),
        ("Grad Norm GS", "avg_grad_norm_GS"),
    ):
        print(f"  {label + ':':<20}{metrics[key]:.4f}")
    # The update counts are integers (list lengths); printing them with
    # ``:.4f`` used to show misleading values such as "8.0000".
    print(f"  {'num_DV_updates:':<20}{metrics['num_DV_updates']:d}")
    print(f"  {'num_DS_updates:':<20}{metrics['num_DS_updates']:d}")


def _log_epoch_metrics(writer, epoch, metrics):
    """Write one epoch's metrics to TensorBoard with ``epoch`` as the step.

    The vocal-side tags keep their original names so curves remain comparable
    with earlier runs; the speech-side tags add the metrics that were
    previously computed but never logged.
    """
    tag_to_key = {
        "Loss/Discriminator": "loss_DV",
        "Loss/Generator_adversarial": "loss_adv_vocal",
        "Loss/Cycle": "loss_cycle_vocal",
        "Gradients/Discriminator": "avg_grad_norm_DV",
        "Gradients/Generator": "avg_grad_norm_GV",
        "Loss/Discriminator_speech": "loss_DS",
        "Loss/Generator_adversarial_speech": "loss_adv_speech",
        "Loss/Cycle_speech": "loss_cycle_speech",
        "Gradients/Discriminator_speech": "avg_grad_norm_DS",
        "Gradients/Generator_speech": "avg_grad_norm_GS",
    }
    for tag, key in tag_to_key.items():
        writer.add_scalar(tag, metrics[key], epoch)


# ----- Multi-Epoch Training Function -----
def cycle_train(
    generator_vocal,
    generator_speech,
    discriminator_vocal,
    discriminator_speech,
    optimizer_DV,
    optimizer_DS,
    optimizer_GV,
    optimizer_GS,
    acc_voc_loader,
    speech_loader,
    l1_loss,
    mse_loss,
    lambda_l1,
    lambda_cycle,
    bce_loss,
    device,
    num_epochs,
    virtual_batch_size,
    clip_length,
    input_size_generators,
    log_dir,
    save_audio=True,
    smart_discriminator=False,
    batch_size=32,
):
    """Train the vocal/speech CycleGAN for ``num_epochs`` epochs.

    Each epoch calls ``cycle_train_epoch`` with the given models, optimizers,
    loaders, losses and hyper-parameters, then prints the returned averaged
    metrics and writes them to TensorBoard under ``log_dir`` (step = 0-based
    epoch index).

    Args:
        save_audio: When true, ``save_output=True`` is passed to
            ``cycle_train_epoch`` on the final epoch only; when false it is
            never requested. All other arguments are forwarded unchanged.

    Returns:
        None.
    """
    writer = SummaryWriter(log_dir=log_dir)
    # try/finally so the event file is flushed and closed even when training
    # is interrupted (e.g. KeyboardInterrupt in the notebook).
    try:
        for epoch in range(num_epochs):
            # Honour the caller's save_audio flag instead of overwriting it;
            # output is only requested on the last epoch.
            save_output = save_audio and epoch == num_epochs - 1
            print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")
            epoch_metrics = cycle_train_epoch(
                generator_vocal,
                generator_speech,
                discriminator_vocal,
                discriminator_speech,
                optimizer_DV,
                optimizer_DS,
                optimizer_GV,
                optimizer_GS,
                acc_voc_loader,
                speech_loader,
                l1_loss,
                mse_loss,
                lambda_l1,
                lambda_cycle,
                bce_loss,
                device,
                virtual_batch_size,
                clip_length,
                input_size_generators,
                save_output=save_output,
                smart_discriminator=smart_discriminator,
                batch_size=batch_size,
            )
            _print_epoch_metrics(epoch, epoch_metrics)
            _log_epoch_metrics(writer, epoch, epoch_metrics)
    finally:
        writer.close()
    return  # audio_files

## Variable dataset
By default we train on shorter clips in Colab and on longer clips in a Docker container on a local machine.

In [None]:
# Choose the data root and dataset files for the current environment.
# Going by the file names, Colab uses the short-clip datasets from Google
# Drive, while the local Docker setup uses the long-clip datasets on the HDD.
if colab:
    path = "/content/drive/My Drive/git_projects/spring_2025_dl_audio_project_data/"
    librispeech_data = "LibriSpeechDataset_withOverlap.pt"
    musdb_data = "musdb_noOverlap_test.pt"
else:
    path = "/workspace/hdd_project_data/"
    librispeech_data = "librispeech_longClip_train_small.pt"
    musdb_data = "musdb_longClip_train.pt"
    # Short-clip alternatives:
    # librispeech_data = "LibriSpeechDataset_withOverlap.pt"
    # musdb_data = "musdb_noOverlap_test.pt"

librispeechDataset_path = f"{path}{librispeech_data}"
musdbDataset_path = f"{path}{musdb_data}"

# weights_only=False unpickles whole Dataset objects (arbitrary code on load),
# so only point these paths at trusted files.
librispeech_dataset = torch.load(librispeechDataset_path, weights_only=False)
musdb_dataset = torch.load(musdbDataset_path, weights_only=False)


In [None]:
# The LibriSpeech dataset was built with slightly more clips than the MusDB
# dataset; truncate its tensors so both datasets hold the same number of clips.
n_musdb_clips = len(musdb_dataset)
librispeech_dataset.mel_specs = librispeech_dataset.mel_specs[:n_musdb_clips]
librispeech_dataset.sample_rates = librispeech_dataset.sample_rates[:n_musdb_clips]

### Record length of clips

In [None]:
# --- Explore the Datasets ---
print("=== MusDB Dataset Exploration ===")
print("Length:", len(musdb_dataset))
print("mel_specs shape:", musdb_dataset.mel_specs.shape)
print("sample_rates shape:", musdb_dataset.sample_rates.shape)
print()
accompaniment, vocal, sample_rate = musdb_dataset[0]
print("Sample 0 - Accompaniment shape:", accompaniment.size())
print("Sample 0 - Vocal shape:", vocal.size())
print("Sample 0 - Sample rate:", sample_rate)
print()

print("=== LibriSpeech Dataset Exploration ===")
print("Length:", len(librispeech_dataset))
print("mel_specs shape:", librispeech_dataset.mel_specs.shape)
print("sample_rates shape:", librispeech_dataset.sample_rates.shape)
print()
speech, sample_rate = librispeech_dataset[0]
print("Sample 0 - Speech shape:", speech.size())
print("Sample 0 - Sample rate:", sample_rate)

# The clip length is the last axis of mel_specs (time frames per clip).
# A single clip_length is used for both datasets from here on, so the two
# datasets must agree on it.
musdb_length = musdb_dataset.mel_specs.shape[-1]
librispeech_length = librispeech_dataset.mel_specs.shape[-1]

if musdb_length != librispeech_length:
    # Report both values: this compares clip lengths, not dataset sizes.
    raise ValueError(
        "The clip lengths of the datasets do not match: "
        f"musdb_dataset has {musdb_length} frames per clip, "
        f"librispeech_dataset has {librispeech_length}. "
        "Please check the dataset files."
    )
clip_length = musdb_length
print()
print("=========================")
print("Training clip length:", clip_length)

=== MusDB Dataset Exploration ===
Length: 4167
mel_specs shape: torch.Size([4167, 2, 128, 256])
sample_rates shape: torch.Size([4167])

Sample 0 - Accompaniment shape: torch.Size([128, 256])
Sample 0 - Vocal shape: torch.Size([128, 256])
Sample 0 - Sample rate: tensor(44100)

=== LibriSpeech Dataset Exploration ===
Length: 4167
mel_specs shape: torch.Size([4167, 128, 256])
sample_rates shape: torch.Size([4167])

Sample 0 - Speech shape: torch.Size([128, 256])
Sample 0 - Sample rate: tensor(44100)

Training clip length: 256


## Initialize Generator Models
We initialize the generator models first so we can determine the input size (i.e. the padding) the data loaders must produce.

In [None]:
# Model configurations for generator_vocal and generator_speech.
model_config_gen_vocal = {
    "num_inputs": 256,  # Two spectrograms concatenated (2 * 128 mel bins)
    "num_outputs": 128,
    "num_channels": [512*2, 512*4, 512*8],
    "instruments": ["vocal"],
    "kernel_size": 3,
    "target_output_size": clip_length,
    "conv_type": "normal",
    "res": "fixed",
    "separate": False,
    "depth": 1,
    "strides": 2
}

model_config_gen_speech = {
    "num_inputs": 128,  # One spectrogram input
    "num_outputs": 128,
    "num_channels": [256*2, 256*4, 256*8],
    "instruments": ["speech"],
    "kernel_size": 3,
    "target_output_size": clip_length,
    "conv_type": "normal",
    "res": "fixed",
    "separate": False,
    "depth": 1,
    "strides": 2
}
generator_vocal = Waveunet(**model_config_gen_vocal).to(device)
generator_speech = Waveunet(**model_config_gen_speech).to(device)

# Both DataLoaders are built with one shared output_length (taken from the
# vocal generator), so the two generators must agree on their input/output
# sizes; fail early here instead of with a shape error mid-training.
vocal_sizes = (generator_vocal.input_size, generator_vocal.output_size)
speech_sizes = (generator_speech.input_size, generator_speech.output_size)
if vocal_sizes != speech_sizes:
    raise ValueError(
        "Generator sizes differ: generator_vocal (input, output) = "
        f"{vocal_sizes}, generator_speech (input, output) = {speech_sizes}."
    )

input_size_generators = generator_vocal.input_size
output_size_generators = generator_vocal.output_size
print("Input size for generators:", input_size_generators)
print("Output size for generators:", output_size_generators)

Using valid convolutions with 289 inputs and 257 outputs
Using valid convolutions with 289 inputs and 257 outputs
Input size for generators: 289
Output size for generators: 257


### Create dataloader
We create the dataloaders for the training loop using the correct input_size for the models.

In [None]:
# Define batch size, virtual batch size, and the number of workers for DataLoaders
batch_size = 32  # Change as needed
target_virtual_batch_size = 256  # target number of samples per virtual batch
num_workers = 1

now = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"runs/cycleGAN_experiment_{now}"

# NOTE: despite its name, ``virtual_batch_size`` holds the number of real
# batches per virtual batch (256 // 32 = 8), not a sample count. It is
# presumably the number of batches cycle_train_epoch accumulates gradients over
# before each optimizer step — confirm there.
if target_virtual_batch_size % batch_size != 0:
    raise ValueError(
        f"target_virtual_batch_size ({target_virtual_batch_size}) must be a "
        f"multiple of batch_size ({batch_size})."
    )
virtual_batch_size = target_virtual_batch_size // batch_size

# ---------------- hyper-parameters in ONE place ----------------
train_parameters = {
    # data loaders
    "batch_size": batch_size,
    "virtual_batch_size": virtual_batch_size,
    "num_workers": num_workers,
    "clip_length": clip_length,
    "input_size_generators": input_size_generators,

    # optimisation
    "lr_G": 1e-4,   # generator_vocal
    "lr_G2": 1e-4,  # generator_speech
    "lr_D": 2e-4,   # both discriminators
    "betas": (0.9, 0.999),

    # loss weights
    "lambda_l1": 1,
    "lambda_cycle": 0.001,

    # schedule
    "num_epochs": 5,

    # bookkeeping -- reuse log_dir so the two copies cannot drift apart
    "log_dir": log_dir,
    "model_dir": "models",
}


In [None]:
# Open question: do we want to shuffle the datasets? Currently both loaders
# shuffle independently, so a given accompaniment/vocal batch is paired with
# an unrelated speech batch each epoch.

# Settings shared by both loaders. DataLoader raises a ValueError for
# persistent_workers=True with num_workers=0, so only keep workers alive when
# worker processes actually exist.
_loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=num_workers > 0,
)

# Create data loaders; both pad/crop to the generators' input size.
acc_voc_loader = DataLoader(
    AccompanimentVocalData(musdb_dataset, output_length=input_size_generators),
    **_loader_kwargs,
)

speech_loader = DataLoader(
    SpeechData(librispeech_dataset, output_length=input_size_generators),
    **_loader_kwargs,
)


### Initialize Discriminator Models
We initialize the discriminator models and call ``fit_rocket``.

In [None]:
# Check the number of CPU cores.
import multiprocessing
num_cores = multiprocessing.cpu_count()
print("Number of CPU cores:", num_cores)

# The discriminators never run at the same time, so each may use most of the
# cores; leave two free. Clamp to at least 1: on machines with <= 2 cores
# num_cores - 2 would be 0 or negative, which joblib-style n_jobs arguments
# reject (0) or interpret as "all cores" (negative).
# NOTE(review): minirocket_n_jobs is not passed to either discriminator below;
# confirm whether TsaiMiniRocketDiscriminator should receive it.
minirocket_n_jobs = max(1, num_cores - 2)

# Instantiate the discriminators.
discriminator_vocal = TsaiMiniRocketDiscriminator().to(device)
discriminator_speech = TsaiMiniRocketDiscriminator(
    freq_bins=128,
    hidden_dim=512,
    accompaniment=False,
).to(device)

# Fit the MiniRocket feature extractors on one batch of data before training.
# NOTE(review): acc_voc_batch is fetched but not used, and discriminator_vocal
# is fitted on speech rather than vocal clips — confirm this is intended.
acc_voc_batch = next(iter(acc_voc_loader))
speech_batch = next(iter(speech_loader))["no_pad"]

print()
print("Fitting discriminator_vocal...")
discriminator_vocal.fit_rocket(speech_batch)
print()
print("Fitting discriminator_speech...")
discriminator_speech.fit_rocket(speech_batch)

Number of CPU cores: 12

Fitting discriminator_vocal...
MiniRocket fitted. Feature dimension: 9996

Fitting discriminator_speech...
MiniRocket fitted. Feature dimension: 9996


In [None]:
# Loss criteria (parameter-free modules, moved to the training device).
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]; confirm the
# discriminators end in a sigmoid.
bce_loss = nn.BCELoss().to(device)
l1_loss = nn.L1Loss().to(device)
mse_loss = nn.MSELoss().to(device)

# All four optimizers are Adam with the shared betas; only the learning rate
# differs (lr_G: vocal generator, lr_G2: speech generator, lr_D: discriminators).
_make_adam = partial(optim.Adam, betas=train_parameters["betas"])
optimizer_GV = _make_adam(generator_vocal.parameters(), lr=train_parameters["lr_G"])
optimizer_GS = _make_adam(generator_speech.parameters(), lr=train_parameters["lr_G2"])
optimizer_DV = _make_adam(discriminator_vocal.parameters(), lr=train_parameters["lr_D"])
optimizer_DS = _make_adam(discriminator_speech.parameters(), lr=train_parameters["lr_D"])


In [None]:
import gc

# Release memory before training starts: gc.collect() first drops unreachable
# Python objects (and any tensors they still reference), then empty_cache()
# hands PyTorch's cached-but-unused CUDA blocks back to the driver. PyTorch
# does not always do this on its own, and releasing the memory can help avoid
# out-of-memory errors.
gc.collect()
torch.cuda.empty_cache()


In [None]:
##### Optionally load a model checkpoint #####
# NEED TO ADD SECOND DISCRIMINATOR IF WE END UP LOADING A MODEL

# state_time = "20250416-143134"
# model_dir  = "models"

# # ----  load the raw weight dictionaries ----
# disc_state = torch.load(f"{model_dir}/discriminator_state_dict_{state_time}.pt", map_location=device)
# gen_state  = torch.load(f"{model_dir}/generator_state_dict_{state_time}.pt",         map_location=device)
# gen2_state = torch.load(f"{model_dir}/generator_2_state_dict_{state_time}.pt",       map_location=device)

# # ----   rebuild the model objects ----
# discriminator = TsaiMiniRocketDiscriminator().to(device)
# generator     = Waveunet(**model_config_gen).to(device)
# generator_2   = Waveunet(**model_config_gen2).to(device)

# # ---- load the weights into those models ----
# discriminator.load_state_dict(disc_state)
# generator.load_state_dict(gen_state)
# generator_2.load_state_dict(gen2_state)



## Train the model


In [None]:
# Compile the per-epoch training function. cycle_train looks this global name
# up at call time, so rebinding it makes every epoch use the compiled version.
# Unwrap first (via torch.compile's private _torchdynamo_orig_callable
# attribute; if it is absent the function is compiled as-is) so re-running
# this cell does not compile an already-compiled wrapper.
# NOTE(review): compiling an entire Python loop that iterates DataLoaders,
# calls .item() and steps optimizers triggers graph breaks and recompiles
# (see the dynamo warnings in this cell's output); compiling the models
# instead may be cheaper.
cycle_train_epoch = torch.compile(
    getattr(cycle_train_epoch, "_torchdynamo_orig_callable", cycle_train_epoch),
    mode="default",
)

# Start training. cycle_train currently returns None (audio export is disabled).
audio_files = cycle_train(
    generator_vocal,
    generator_speech,
    discriminator_vocal,
    discriminator_speech,
    optimizer_DV,
    optimizer_DS,
    optimizer_GV,
    optimizer_GS,
    acc_voc_loader,
    speech_loader,
    l1_loss,
    mse_loss,
    train_parameters["lambda_l1"],
    train_parameters["lambda_cycle"],
    bce_loss,
    device,
    num_epochs          = train_parameters["num_epochs"],
    virtual_batch_size  = train_parameters["virtual_batch_size"],
    log_dir             = train_parameters["log_dir"],
    clip_length = train_parameters["clip_length"],
    input_size_generators = train_parameters["input_size_generators"],
    batch_size = train_parameters["batch_size"]
)



=== Epoch 1/5 ===


Training Batches: 7it [01:12,  3.51s/it]W0423 22:33:34.238000 670 torch/_logging/_internal.py:1089] [64/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][11].grad", "L['self'].param_groups[0]['params'][12].grad", "L['self'].param_groups[0]['params'][13].grad", "L['self'].param_groups[0]['params'][14].grad", "L['self'].param_groups[0]['params'][16].grad", "L['self'].param_groups[0]['params'][17].grad", "L['self'].param_groups[0]['params'][18].grad", "L['self'].param_groups[0]['params'][19].grad", "L['self'].param_groups[0]['param

KeyboardInterrupt: 

In [None]:
# state_time = "20250421-033033"
# model_dir  = "models"

# # ----  load the raw weight dictionaries ----
# disc_state = torch.load(f"{model_dir}/discriminator_state_dict_{state_time}.pt", map_location=device)
# gen_state  = torch.load(f"{model_dir}/generator_state_dict_{state_time}.pt",         map_location=device)
# gen2_state = torch.load(f"{model_dir}/generator_2_state_dict_{state_time}.pt",       map_location=device)

# # ----   rebuild the model objects ----
# discriminator = TsaiMiniRocketDiscriminator().to(device)
# generator     = Waveunet(**model_config_gen).to(device)
# generator_2   = Waveunet(**model_config_gen2).to(device)

# # ---- load the weights into those models ----
# discriminator.load_state_dict(disc_state)
# generator.load_state_dict(gen_state)
# generator_2.load_state_dict(gen2_state)


In [None]:
# # display the converted audios
# print(audio_files)
# for audio in audio_files:
#     print(audio)
#     display_audio(audio)

In [None]:
# gc.collect()
# torch.cuda.empty_cache()

# # Start training.
# train(
#     generator,
#     generator_2,
#     discriminator,
#     optimizer_D,
#     optimizer_G,
#     optimizer_G2,
#     accompaniment_loader,
#     vocal_loader,
#     speech_loader,
#     l1_loss,
#     train_parameters["lambda_l1"],
#     train_parameters["lambda_cycle"],
#     adversarial_loss,
#     device,
#     num_epochs          = 50, #train_parameters["num_epochs"],
#     virtual_batch_size  = train_parameters["virtual_batch_size"],
#     log_dir             = train_parameters["log_dir"],
# )


## Save the models

In [None]:
# # assert False
# path = "models/"
# torch.save(generator.state_dict(), path + "generator_state_dict_" + now + ".pt")
# torch.save(generator_2.state_dict(), path + "generator_2_state_dict_" + now + ".pt")
# torch.save(discriminator.state_dict(), path + "discriminator_state_dict_" + now + ".pt")

# # ------------- package everything to save -------------
# export_dict = {
#     "train_parameters": train_parameters,
#     "model_config_gen": model_config_gen,      # Wave‑U‑Net (speech+accomp → vocal)
#     "model_config_gen2": model_config_gen2,    # Wave‑U‑Net (vocal → speech)
# }
# import json

# # (optional) ensure JSON‑serialisable: convert tuples → lists
# def _convert(obj):
#     if isinstance(obj, tuple):
#         return list(obj)
#     if isinstance(obj, dict):
#         return {k: _convert(v) for k, v in obj.items()}
#     if isinstance(obj, list):
#         return [_convert(x) for x in obj]
#     return obj

# export_dict = _convert(export_dict)

# with open(f"{path}/training_record_{now}.json", "w") as fp:
#     json.dump(export_dict, fp, indent=2)



In [None]:
# now = datetime.now().strftime("%Y%m%d-%H%M%S")
# log_dir = "runs/" + "cycleGAN_experiment_" + now

# gc.collect()
# torch.cuda.empty_cache()

# # Start training.
# train(
#     generator_vocal,
#     generator_speech,
#     discriminator,
#     optimizer_D,
#     optimizer_G,
#     optimizer_G2,
#     accompaniment_loader,
#     vocal_loader,
#     speech_loader,
#     l1_loss,
#     train_parameters["lambda_l1"],
#     train_parameters["lambda_cycle"],
#     adversarial_loss,
#     device,
#     num_epochs          = 50, #train_parameters["num_epochs"],
#     virtual_batch_size  = train_parameters["virtual_batch_size"],
#     log_dir             = train_parameters["log_dir"],
# )