<a href="https://colab.research.google.com/github/jjaw89/spring_2025_dl_audio_project/blob/main/cycle_gan_train_shuffled_data_local_machine.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Here we train our first version of the GAN.



## Initialize Wave-U-Net

We start by loading the necessary packages

Wave-U-Net is named ``generator``

In [1]:
# === Check to see if we are in colab ===
import sys

if 'google.colab' in sys.modules:
    colab = True
else:
    colab = False
if colab:
    from google.colab import drive
    drive.mount('/content/drive')
#    !{sys.executable} -m pip install musdb
    %cd /content/drive/My\ Drive/git_projects/spring_2025_dl_audio_project

# === Built-in modules ===
import os
import time
import argparse
import pickle
from functools import partial
from datetime import datetime
from typing import Tuple, List, Dict, Optional

# === Third-party modules ===
import numpy as np
import pandas as pd
import librosa
import torchaudio
import matplotlib.pyplot as plt
from tqdm import tqdm

# === PyTorch ===
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import L1Loss
from torch.nn.utils import spectral_norm
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary



# === Colab-specific installation (only run if in Colab) ===
if colab:
    !pip install sktime
    !{sys.executable} -m pip install musdb
    !{sys.executable} -m pip uninstall -y stempeg  # musdb installs the wrong version of stempeg

# === Audio/MIR ===
if colab:
  %cd /content/drive/My Drive/git_projects/
import stempeg
import musdb

# === Time series transforms ===
from sktime.transformations.panel.rocket import MiniRocketMultivariate

if colab:
  %cd /content/drive/My\ Drive/git_projects/spring_2025_dl_audio_project/Wave-U-Net-Pytorch

# === Local module imports ===
if not colab:
  sys.path.append('Wave-U-Net-Pytorch')
  sys.path.append("workspace/hdd_project_data")


import model.utils as model_utils
import utils
from model.waveunet import Waveunet

# === Device setup ===
cuda = torch.cuda.is_available()
if cuda:
    device = torch.device('cuda')
    print("GPU:", torch.cuda.get_device_name(0))
else:
    raise Exception("GPU not available. Please check your setup.")




GPU: NVIDIA GeForce RTX 4090


## Initialize miniRocket
We start by loading the necessary packages

### CPU Core Allocation for MiniRocketMultivariate

- The implementation of `MiniRocketMultivariate` runs on the **CPU**.
- We need to decide how many cores to allocate for it.
- Some cores will be used by MiniRocket itself, while others are needed for data preparation (e.g., generating spectrograms).
- This allocation likely needs to be **tuned for optimal performance**.
- As a starting point, we detect the number of available cores and split them evenly.
- Note: We avoid using *all* available cores to leave some resources for the operating system and other background processes.


Create the MiniRocket model

In [2]:


# MiniRocket Discriminator using tsai library
class TsaiMiniRocketDiscriminator(nn.Module):
    def __init__(
        self,
        freq_bins=256,
        time_frames=256,
        num_kernels=10000,  # number of convolutional kernels
        hidden_dim=1024,    # Increased to handle larger feature dimension
        output_dim=1,
        accompaniment = True   # whether or not we feed accompaniment
    ):
        super(TsaiMiniRocketDiscriminator, self).__init__()

        # This is the mini rocket transformer which extracts features
        self.rocket = MiniRocketMultivariate(num_kernels=num_kernels, n_jobs=minirocket_n_jobs)
        # tsai's miniRocketClassifier is implemented with MiniRocketMultivariate as well
        self.fitted = False   # fit before training
        self.freq_bins = freq_bins
        self.time_frames = time_frames
        self.num_kernels = num_kernels
        self.accompaniment = accompaniment

        # For 2D data handling - process each sample with proper dimensions
        self.example_input = np.zeros((1, freq_bins, time_frames))

        self.feature_dim = num_kernels  # For vocals + accompaniment

        if self.accompaniment:
          classifier_input_dim = 19992
        else:
          classifier_input_dim = 9996

        # Example feature reducing layers
        self.classifier = nn.Sequential(
            # First reduce the massive dimension to something manageable
            nn.Dropout(0.3),
            spectral_norm(nn.Linear(classifier_input_dim, hidden_dim)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Second hidden layer
            nn.Dropout(0.3),
            spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Final classification layer
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, output_dim),
            nn.Sigmoid()
            # nn.Tanh()
        )
        #         # Example feature reducing layers
        # self.classifier = nn.Sequential(
        #     # First reduce the massive dimension to something manageable
        #     nn.Linear(19992, hidden_dim),
        #     nn.LeakyReLU(0.2),
        #     nn.Dropout(0.3),

        #     # Second hidden layer
        #     nn.Linear(hidden_dim, hidden_dim // 2),
        #     nn.LeakyReLU(0.2),
        #     nn.Dropout(0.3),

        #     # Final classification layer
        #     nn.Linear(hidden_dim // 2, output_dim),
        #     nn.Sigmoid()
        # )

    def fit_rocket(self, spectrograms):
        """
            Fit MiniRocket with just one piece of vocal training data (not the entire training dataset)
        """
        if not self.fitted:
            try:
                # Reshape for MiniRocket - it expects (n_instances, n_dimensions, series_length)
                # flatten the freq_bins dimension to create a multivariate time series
                batch_size = spectrograms.shape[0]

                # Convert first to numpy for sktime processing
                sample_data = spectrograms.detach().cpu().numpy()
                # print(sample_data.shape)
                # Reshape to sktime's expected format - reduce to single sample for fitting
                # sample_data = sample_data[:, 0]  # Take one sample, remove channel dim

                # Fit on this sample
                self.rocket.fit(sample_data)
                self.fitted = True

                # Test transform to get feature dimension
                test_transform = self.rocket.transform(sample_data)
                self.feature_dim = test_transform.shape[1]

                print(f"MiniRocket fitted. Feature dimension: {self.feature_dim}")

            except Exception as e:
                print(f"Error fitting MiniRocket: {e}")
                # Use a fallback if fitting fails
                self.fitted = True  # Mark as fitted to avoid repeated attempts

    def extract_features(self, spectrogram):
        """Extract MiniRocket features from a spectrogram"""
        try:
            # Ensure rocket is fitted
            if not self.fitted:
                self.fit_rocket(spectrogram)

            # Convert to numpy for sktime
            spec_np = spectrogram.detach().cpu().numpy()

            # Remove channel dimension expected by sktime
            # print(spec_np.shape)
            # spec_np = spec_np[:, 0]  # [batch_size, freq_bins, time_frames]
            # print(spec_np.shape)

            # This step extracts features using the convolutional kernels, numbers specified by num_kernels
            # print("1")
            features = self.rocket.transform(spec_np)
            # print("2")
            # Convert back to torch tensor
            # print("features:", features.shape)
            # print(features.head())
            features_tensor = torch.tensor(features.values).to(spectrogram.device)
            # print("features:", features.shape)
            # print("3")
            return features_tensor

        except Exception as e:
            print(f"Error in feature extraction: {e}")
            # Return zeros as fallback
            return torch.zeros((spectrogram.shape[0], self.num_kernels),
                              device=spectrogram.device)

    def forward(self, vocals, accompaniment = None):
        """
        Forward pass of the discriminator

        Args:
            vocals: Spectrograms of shape [batch_size, channels, freq_bins, time_frames]
            accompaniment: Spectrograms of shape [batch_size, channels, freq_bins, time_frames]
        """
        # Extract features from both spectrograms
        # start_time = time()
        vocal_features = self.extract_features(vocals)

        if self.accompaniment:
          accomp_features = self.extract_features(accompaniment)
          # print("extract:", time()-start_time)
          # Concatenate features (conditional GAN)
          output_features = torch.cat([vocal_features, accomp_features], dim=1)
          # print(combined_features.size())
        else:
          output_features = vocal_features

        # Classify as real/fake
        # print(combined_features.size())
        validity = self.classifier(output_features)

        return validity



# Import Data into Session

First, we run the code that defines the custom Dataset objects. The Datasets were compiled previously and saved in .pt files. In the next cell, we load those Dataset objects.

In [3]:
if colab:
  %cd /content/drive/My Drive/git_projects/spring_2025_dl_audio_projects
else:
  sys.path.append('/workspace/hdd_project_data/stempeg')

import stempeg


class MusdbDataset(Dataset):

  def __init__(self, musDB, window_size = 256, step_size = 128):
    self.mel_specs = torch.zeros(1, 2, 128, window_size)
    self.sample_rates = torch.tensor([0])

    num_songs = 0

    for track in musDB:
      stems, rate = track.stems, track.rate

      num_songs += 1

      # separate the vocal from other instruments and conver to mono signal
      audio_novocal = librosa.to_mono(np.transpose(stems[1] + stems[2] + stems[3]))
      audio_vocal = librosa.to_mono(np.transpose(stems[4]))

      # compute log mel spectrogram and convert to pytorch tensor
      logmelspec_novocal = torch.from_numpy(self._mel_spectrogram(audio_novocal, rate))
      logmelspec_vocal = torch.from_numpy(self._mel_spectrogram(audio_vocal, rate))

      start_ndx = 0

      for step in range(window_size // step_size):
        cropped_logmelspec_novocal = logmelspec_novocal[:, start_ndx:]
        cropped_logmelspec_vocal = logmelspec_vocal[:, start_ndx:]
        num_slices = cropped_logmelspec_novocal.shape[1] // window_size

        # chop off the last bit so that number of stft steps is a multiple of window_size
        cropped_logmelspec_novocal = cropped_logmelspec_novocal[: , 0:num_slices*window_size]
        cropped_logmelspec_vocal = cropped_logmelspec_vocal[:, 0:num_slices*window_size]

        # reshape tensors into chunks of size 128x(window_size)
        # first dimension is number of chunks
        cropped_logmelspec_novocal = torch.transpose(torch.reshape(cropped_logmelspec_novocal, (128, num_slices, window_size)), 0, 1)
        cropped_logmelspec_vocal = torch.transpose(torch.reshape(cropped_logmelspec_vocal, (128, num_slices, window_size)), 0, 1)

        # unsqueeze and concatenate these tensors. Then concatenate to the big tensor
        logmels = torch.cat((cropped_logmelspec_novocal.unsqueeze(1), cropped_logmelspec_vocal.unsqueeze(1)), 1)
        self.mel_specs = torch.cat((self.mel_specs, logmels), 0)
        self.sample_rates = torch.cat((self.sample_rates, torch.full((num_slices,), rate)), 0)

        if num_songs % 5 == 0:
          print(str(num_songs) + " songs processed; produced " + str(self.mel_specs.shape[0]) + " spectrograms")

    # remove the all zeros slice that we initialized with
    self.mel_specs = self.mel_specs[1: , : , : , :]
    self.sample_rates = self.sample_rates[1:]

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    # returns tuple (mel spectrogram of accompaniment, mel spectrogram of vocal, rate)
    return self.mel_specs[ndx, 0], self.mel_specs[ndx, 1], self.sample_rates[ndx]

  def _mel_spectrogram(self, audio, rate):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

  def cat(self, other_ds):
    self.mel_specs = torch.cat((self.mel_specs, other_ds.mel_specs), 0)
    self.sample_rates = torch.cat((self.sample_rates, other_ds.sample_rates), 0)



class SingingDataset(Dataset):

  def __init__(self, musDB, window_size = 256, step_size = 128):
    self.mel_specs = torch.zeros(1, 128, window_size)
    self.sample_rates = torch.tensor([0])

    num_songs = 0

    for track in musDB:
      stems, rate = track.stems, track.rate

      num_songs += 1

      # load the vocal
      vocal = librosa.to_mono(np.transpose(stems[4]))

      # compute log mel spectrogram and convert to pytorch tensor
      mel_spec = torch.from_numpy(self._mel_spectrogram(vocal, rate))

      start_ndx = 0
      for step in range(window_size // step_size):
        cropped_mel_spec = mel_spec[:, start_ndx:]
        num_slices = cropped_mel_spec.shape[1] // window_size

        # chop off the last bit so that number of stft steps is a multiple of window_size
        cropped_mel_spec = cropped_mel_spec[:, 0:num_slices*window_size]

        # reshape tensors into chunks of size 128x(window_size)
        # first dimension is number of chunks
        cropped_mel_spec = torch.transpose(torch.reshape(cropped_mel_spec, (128, num_slices, window_size)), 0, 1)

        # concatenate to the big tensor
        self.mel_specs = torch.cat((self.mel_specs, cropped_mel_spec), 0)
        self.sample_rates = torch.cat((self.sample_rates, torch.full((num_slices,), rate)), 0)


    if num_songs % 5 == 0:
        print(str(num_songs) + " songs processed; produced " + str(self.mel_specs.shape[0]) + " spectrograms")

    # remove the all zeros slice that we initialized with
    self.mel_specs = self.mel_specs[1: , : , :]
    self.sample_rates = self.sample_rates[1:]

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    # returns tuple (mel spectrogram of accompaniment, mel spectrogram of vocal, rate)
    return self.mel_specs[ndx], self.sample_rates[ndx]

  def _mel_spectrogram(self, audio, rate):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))



class LibriSpeechDataset(Dataset):

    def __init__(self, path, window_size = 256, step_size = 128, num_specs = 7647*2):
        self.mel_specs = self.mel_specs = torch.zeros(1, 128, window_size)
        self.sample_rates = torch.tensor([0])

        num_files_opened = 0

        for speaker_dir in os.listdir(path):
            speaker_path = path + "/" + speaker_dir
            for chapter_dir in os.listdir(speaker_path):
                chapter_path = speaker_path + "/" + chapter_dir
                for file in os.listdir(chapter_path):
                    # checks file extension and stops when we hit desired number of spectrograms (num_specs)
                    if file.endswith('.flac') and self.mel_specs.shape[0] - 1 < num_specs:
                        # get audio file and convert to log mel spectrogram
                        speech, rate = librosa.load(chapter_path + "/" + file, sr = 44100)
                        mel_spec = torch.from_numpy(self._mel_spectrogram(speech, rate))
                        start_ndx = 0

                        num_files_opened += 1

                        for step in range(window_size // step_size):
                            cropped_mel_spec = mel_spec[:, start_ndx:]

                            # Saves the total number of 128 x (window_size) spectrograms
                            num_slices = cropped_mel_spec.shape[1] // window_size

                            # chop off the last bit so that number of stft steps is a multiple of window_size
                            cropped_mel_spec = cropped_mel_spec[ : , 0 : num_slices*window_size]

                            # reshape the tensor to have many spectrograms of size 128 x (steps)
                            cropped_mel_spec = torch.transpose(torch.reshape(cropped_mel_spec, (128, num_slices, window_size)), 0, 1)

                            # concatenate tensor to the full tensor in the Dataset object
                            self.mel_specs = torch.cat((self.mel_specs, cropped_mel_spec), 0)
                            self.sample_rates = torch.cat((self.sample_rates, torch.full((num_slices,), rate)), 0)

                            # increment start_ndx
                            start_ndx += step_size


                        if num_files_opened % 50 == 0:
                            print("opened " + str(num_files_opened) + " files and produced " + str(self.mel_specs.shape[0]) + " spectrograms")


        # chop off the zero layer we initialized with
        self.mel_specs = self.mel_specs[1:]
        self.sample_rates = self.sample_rates[1:]

    def __len__(self):
        return self.mel_specs.shape[0]

    def __getitem__(self, ndx):
        return self.mel_specs[ndx], self.sample_rates[ndx]

    def _mel_spectrogram(self, audio, rate):
        # compute the log-mel-spectrogram of the audio at the given sample rate
        return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

### Explore these datasets

## Dataset Helpers Explanation
Why New Dataset Helpers?

We have created new dataset helper classes (i.e., AccompanimentData, VocalData, and SpeechData) so that we can control how the data is padded and later shuffled.

- **Separation of Data:**
We separated the vocal and accompaniment data from the MusDB dataset. In our experiments, we might want to shuffle the speech data independently of the combined music data.

- **Shuffling Considerations:**
For the vocal and accompaniment data, we want to maintain their pairing so that they are shuffled in the same order. In contrast, we want the speech data to be shuffled independently.

- **Future Extensions:**
In the future, we may add another helper class that combines the vocal and accompaniment data to ensure synchronized shuffling in our data loaders.

This modular approach gives us flexibility in handling and preprocessing the data for our GAN training.

In [4]:
class AccompanimentVocalData(Dataset):
  def __init__(self, musdb_dataset, output_length = 289):
    self.musdb = musdb_dataset
    self.out_len = output_length

  def __len__(self):
    return len(self.musdb)

  def __getitem__(self, ndx):
    acc, voc, _ = self.musdb[ndx]
    delta = self.out_len - acc.size(-1)

    if delta > 0:
      # Half the remainder goes to the front
      left_pad_len = (delta // 2) + (delta % 2)  # 17
      right_pad_len = delta // 2                # 16
      acc_pad = F.pad(acc, (left_pad_len, right_pad_len), "constant", 0)
      voc_pad = F.pad(voc, (left_pad_len, right_pad_len), "constant", 0)

    return {"acc_no_pad" : acc,
            "voc_no_pad" : voc,
            "acc_pad": acc_pad,
            "voc_pad" : voc_pad}

class SpeechData(Dataset):
    def __init__(self, librispeech_dataset, output_length=289):
        self.librispeech_dataset = librispeech_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.librispeech_dataset)

    def __getitem__(self, index):
        speech, _ = self.librispeech_dataset[index]
        # If speech has multiple slices, pick the first slice
        if speech.dim() == 3:
            speech = speech[0]  # shape: [128, 256]
        current_len = speech.size(-1)
        delta = self.output_length - current_len

        if delta > 0:
            left_pad_len = (delta // 2) + (delta % 2)
            right_pad_len = delta // 2
            speech_pad = F.pad(speech, (left_pad_len, right_pad_len), "constant", 0)
        return {"no_pad" : speech, "pad" : speech_pad}

def transform_for_gen_2(batch, output_length = 289):
  current_len = batch.size(-1)
  delta = output_length - current_len

  if delta > 0:
      left_pad_len = (delta // 2) + (delta % 2)
      right_pad_len = delta // 2
      batch = F.pad(batch, (left_pad_len, right_pad_len), "constant", 0)
  return batch



### DataLoader Explanation
What is a DataLoader and Why Do We Need It?

A DataLoader in PyTorch is a utility that wraps a dataset and provides:

- **Batching:** It divides your dataset into batches so that you can train your models with mini-batch gradient descent.

- **Shuffling:** It shuffles the data at every epoch (if specified) to help reduce overfitting and ensure the model sees a diverse set of examples.

- **Parallel Data Loading:** It can load data in parallel using multiple worker processes, speeding up training.

In our case, we create separate DataLoaders for:

- The accompaniment data (paired with vocals) from the MusDB dataset.

- The vocal data (paired with accompaniment) from the MusDB dataset.

- The speech data from the LibriSpeech dataset.

This lets us shuffle the speech data independently, while keeping the vocal/accompaniment pairs synchronized during training.

In [5]:


# # Print how many batches each DataLoader contains
# print("Accompaniment loader length:", len(accompaniment_loader))
# print("Vocal loader length:", len(vocal_loader))
# print("Speech loader length:", len(speech_loader))

# # Optionally, fetch and print the shape of the first batch
# accompaniment_batch = next(iter(accompaniment_loader))
# vocal_batch = next(iter(vocal_loader))
# speech_batch = next(iter(speech_loader))
# print(accompaniment_batch["pad"])

# print("Accompaniment first batch shape:", accompaniment_batch.shape)
# print("Vocal first batch shape:", vocal_batch.shape)
# print("Speech first batch shape:", speech_batch.shape)


## Second Generator Model
Here we initialize the second generator model whose purpose is to convert the generated vocals back to normal speach for the cycle GAN. We again use Wave-U-Net, but with a different configuration. The main difference is that we will not input the music along with the vocal track.

## Transform Input to generator_2

The output of the generator model is a (batch_size, 128, 257) tensor. The model expects a tensor of size (batch_size, 128, 289). We need to pad the last dimension with 16 zeros on each size.

## Train the Cycle GAN
The models are ``generator`` and ``discriminator`` and ``generator_2``.


In [6]:
##### Code to convert spectrograms to audio and play audio ######

import IPython.display as ipd
def convert_to_audio(mel_spectrograms, n_fft=2048, hop_length=512, power=2.0, n_iter=32):
    audio_files = []
    sr = 44100

    for mel_spec in mel_spectrograms:
        # Convert Mel spectrogram back to linear spectrogram
        # mel_s
        mel_spec = mel_spec.detach().cpu().numpy() # Remove batch dimension
        print(mel_spec)
        # print(mel_spec.shape)
        linear_spec = librosa.feature.inverse.mel_to_stft(
            mel_spec, sr=sr, n_fft=n_fft, power=power
        )

        # Apply Griffin-Lim algorithm for phase reconstruction
        audio = librosa.griffinlim(
            linear_spec, hop_length=hop_length, n_iter=n_iter
        )
        print(audio)
        audio_files.append(audio)
        # break

    return audio_files

def display_audio(audio_file):
    # y, sr = librosa.load(audio_file, sr=44100)
    ipd.display(ipd.Audio(audio_file, rate=44100))

# A Different Architecture

The code below defines a training loop with a slightly different cycle GAN architecture. In this model, we have
- generator_1 - a Wave-U-Net model trained to accept speech and an accompaniment to produce a fake vocal performance
- generator_2 - a Wave-U-Net model trained to acced singing and produce fake speech
- discriminator_1 - a MiniRocket classifier that determines real vs fake vocal performances
- discriminator_2 - a MiniRocket classifier that determines real vs fake speech.

The training loop iterates the following process on a triple of (vocal, accompaniment, and speech).


A. Speech + Accompaniment --> Vocal --> Reconstructed Speech
1.   Feed the speech and accompaniment into generator_1 to produce a fake vocal
2.   discriminator_1 distinguishes between real vocals and the output of generator_1
3.   generator_2 takes the fake singing output and generates reconstructed speech
4.   Compute L_1 loss between the input speech and the reconstructed speech

B. Vocal --> Fake Speech + Real Accompaniment --> Reconstructed Vocal


1. A real vocal is fed into generator_2 which produces fake speech (the corresponding accompaniment is used later)
2. discriminator_2 distinguishes between real speech and the fake speech produced by generator_2
3. generator_1 takes the fake speech and the real accompaniment and produces a reconstructed vocal
4. Compute L_1 loss between input vocal and the reconstructed vocal



In [7]:
# # ----- Single Epoch Training Function -----
# def cycle_train_epoch(
#     generator_vocal,
#     generator_speech,
#     discriminator_vocal,
#     discriminator_speech,
#     optimizer_GV,
#     optimizer_GS,
#     optimizer_DV,
#     optimizer_DS,
#     accompaniment_loader,
#     vocal_loader,
#     speech_loader,
#     l1_loss,
#     lambda_l1,
#     lambda_cycle,
#     adversarial_loss,
#     device,
#     virtual_batch_size,
#     clip_length,
#     input_size_generators,
#     save_output=False,
#     smart_discriminator=False
# ):
#     total_loss_DV = total_loss_DS = total_loss_GV = total_loss_GS = 0
#     total_loss_GV_adv = total_loss_GS_adv = total_loss_cycle_vocal = total_loss_cycle_speech = 0.0
#     # Optionally record gradient norms per batch for diagnosing vanishing gradients.
#     grad_norms_DV = []
#     grad_norms_DS = []
#     grad_norms_GV = []
#     grad_norms_GS = []
#     num_batches = 0

#     # ---- batch loop ----
#     for ((accomp, voc), speech) in tqdm(
#         zip(zip(accompaniment_loader, vocal_loader), speech_loader),
#         desc="Training Batches"
#     ):
#         # Move data to device
#         x_acc = accomp["pad"].float().to(device)       # [B,128,289]
#         x_speech = speech["pad"].float().to(device)    # [B,128,289]
#         x_voc = voc["pad"].float().to(device)
#         x_in = torch.cat([x_speech, x_acc], dim=1)     # [B,256,289]

#         # Discriminator real/fake labels
#         B = x_acc.size(0)
#         # real_labels = torch.ones(B, 1, device=device)
#         # fake_labels = torch.zeros(B, 1, device=device)

#         real_labels = torch.rand(B,1, device=device) * 0.2 + 0.8  # [0.8,1.0]
#         fake_labels = torch.rand(B,1, device=device) * 0.2        # [0.0,0.2]

#         # real_labels = torch.ones(B, 1, device=device)
#         # fake_labels = -torch.ones(B, 1, device=device)

#         # 1) Train disciminator_vocal
#         optimizer_DV.zero_grad()
#         acc_np = accomp["no_pad"].float().to(device)
#         voc_np = voc["no_pad"].float().to(device)
#         speech_np = speech["no_pad"].float().to(device)

#         pred_real = discriminator_vocal(voc_np, acc_np)
#         loss_DV_real = adversarial_loss(pred_real, real_labels)

#         raw_fake_vocal = generator_vocal(x_in)["vocal"]
#         fake_vocal = raw_fake_vocal.clone()
#         fake_vocal_crop = torch.narrow(fake_vocal, 2, 0, clip_length).clone()

#         pred_fake_vocal = discriminator_vocal(fake_vocal_crop, acc_np)
#         loss_DV_fake = adversarial_loss(pred_fake_vocal, fake_labels)

#         ################ I THINK WE ARE USING THE WRONG LOSS FUNCTION FOR ADVERSARIAL LOSS #########
#         loss_DV = 0.5 * (loss_DV_real + loss_DV_fake)
#         if loss_DV.item() > 0.5  or smart_discriminator:
#             loss_DV.backward()
#             optimizer_DV.step()

#         # 2) Train GV
#         if virtual_batch_size == 1:
#             optimizer_GV.zero_grad()
#             optimizer_GS.zero_grad()


#         pred_for_GV = discriminator_vocal(fake_vocal, acc_np)
#         loss_GV_adv = adversarial_loss(pred_for_GV, real_labels)


#         # cycle‑consistency
#         fake_vocal_pad = transform_for_gen_2(fake_vocal, input_size_generators)  # you must define this
#         raw_rec_speech = generator_speech(fake_vocal_pad)["speech"]
#         rec_speech = raw_rec_speech.clone()
#         rec_speech_crop = torch.narrow(rec_speech, 2, 0, clip_length).clone()


#         loss_cycle_speech = l1_loss(rec_speech_crop, speech_np)

#         # convex combination
#         loss_GV = (loss_GV_adv + lambda_cycle * loss_cycle_speech) / (1 + lambda_cycle)
#         loss_GV.backward()


#         # 3) Train discriminator_speech
#         optimizer_DS.zero_grad()
#         # pred_real = discriminator_speech(speech_np, acc_np)
#         pred_real_speech = discriminator_speech(speech_np)
#         loss_DS_real = adversarial_loss(pred_real_speech, real_labels)

#         raw_fake_speech = generator_speech(x_voc)["speech"]
#         fake_speech = raw_fake_speech.clone()
#         fake_speech_crop = torch.narrow(fake_speech, 2, 0, clip_length).clone()

#         # pred_fake_speech = discriminator_speech(fake_speech_crop, acc_np)
#         pred_fake_speech = discriminator_speech(fake_speech_crop)
#         loss_DS_fake = adversarial_loss(pred_fake_speech, fake_labels)

#         loss_DS = 0.5 * (loss_DS_real + loss_DS_fake)
#         if loss_DS.item() > 0.5 or smart_discriminator:
#             loss_DS.backward()
#             optimizer_DS.step()


#         # 4) Train GS

#         # pred_for_GS = discriminator_speech(fake_speech, acc_np)
#         pred_for_GS = discriminator_speech(fake_speech)
#         loss_GS_adv = adversarial_loss(pred_for_GS, real_labels)

#         # cycle‑consistency
#         fake_speech_pad = transform_for_gen_2(fake_speech, output_length=input_size_generators)  # you must define this
#         fake_speech_with_acc = torch.cat([fake_speech_pad, x_acc], dim=1)
#         raw_rec_vocal = generator_vocal(fake_speech_with_acc)["vocal"]
#         rec_vocal = raw_rec_vocal.clone()
#         rec_vocal_crop = torch.narrow(rec_vocal, 2, 0, clip_length).clone()

#         loss_cycle_vocal = l1_loss(rec_vocal_crop, voc_np)

#         # convex combination
#         loss_GS = (loss_GS_adv + lambda_cycle * loss_cycle_vocal) / (1 + lambda_cycle)
#         loss_GS.backward()



#         # Record discriminator gradient norms.
#         grad_norm = 0.0
#         count = 0
#         for p in discriminator_vocal.parameters():
#             if p.grad is not None:
#                 grad_norm += p.grad.norm().item()
#                 count += 1
#         if count > 0:
#             grad_norms_DV.append(grad_norm / count)

#         grad_norm = 0.0
#         count = 0
#         for p in discriminator_speech.parameters():
#             if p.grad is not None:
#                 grad_norm += p.grad.norm().item()
#                 count += 1
#         if count > 0:
#             grad_norms_DS.append(grad_norm / count)


#         grad_norm = 0.0
#         count = 0

#         for p in generator_vocal.parameters():
#             if p.grad is not None:
#                 grad_norm += p.grad.norm().item()
#                 count += 1
#         if count > 0:
#             grad_norms_GV.append(grad_norm / count)

#         grad_norm = 0.0
#         count = 0
#         for p in generator_speech.parameters():
#             if p.grad is not None:
#                 grad_norm += p.grad.norm().item()
#                 count += 1
#         if count > 0:
#             grad_norms_GS.append(grad_norm / count)




#         if (num_batches + 1) % virtual_batch_size == 0:
#             optimizer_GV.step()
#             optimizer_GS.step()






#         # Accumulate metrics
#         total_loss_DV     += loss_DV.item()
#         total_loss_DS     += loss_DS.item()
#         total_loss_GV_adv  += loss_GV_adv.item()
#         total_loss_GS_adv += loss_GS_adv.item()
#         total_loss_cycle_vocal += loss_cycle_vocal.item()
#         total_loss_cycle_speech += loss_cycle_speech.item()
#         total_loss_GV      += loss_GV.item()
#         total_loss_GS   += loss_GS.item()
#         num_batches       += 1
#         # audio_files = []
#         # if save_output:
#         #     fake_singing_list = [fake_crop[i] for i in range(fake_crop.shape[0])]
#         #     audio_files = convert_to_audio(fake_singing_list)
#     return {
#         "loss_DV":      total_loss_DV / num_batches,
#         "loss_DS":      total_loss_DS / num_batches,
#         "loss_GV":      total_loss_GV / num_batches,
#         "loss_GS":      total_loss_GS / num_batches,
#         "loss_GV_adv":  total_loss_GV_adv / num_batches,
#         "loss_GS_adv":  total_loss_GS_adv / num_batches,
#         "loss_cycle_vocal":  total_loss_cycle_vocal / num_batches,
#         "loss_cycle_speech":  total_loss_cycle_speech / num_batches,
#         "avg_grad_norm_DV": sum(grad_norms_DV) / len(grad_norms_DV) if grad_norms_DV else 0.0,
#         "avg_grad_norm_DS": sum(grad_norms_DS) / len(grad_norms_DS) if grad_norms_DS else 0.0,
#         "avg_grad_norm_GV": sum(grad_norms_GV) / len(grad_norms_GV) if grad_norms_GV else 0.0,
#         "avg_grad_norm_GS": sum(grad_norms_GS) / len(grad_norms_GS) if grad_norms_GS else 0.0
#     }#, audio_files

# # ----- Multi-Epoch Training Function -----
# def cycle_train(
#     generator_vocal,
#     generator_speech,
#     discriminator_vocal,
#     discriminator_speech,
#     optimizer_DV,
#     optimizer_DS,
#     optimizer_GV,
#     optimizer_GS,
#     accompaniment_loader,
#     vocal_loader,
#     speech_loader,
#     l1_loss,
#     lambda_l1,
#     lambda_cycle,
#     adversarial_loss,
#     device,
#     num_epochs,
#     virtual_batch_size,
#     log_dir,
#     clip_length,
#     input_size_generators,
#     save_audio = True,
#     smart_discriminator = False
# ):
#     writer = SummaryWriter(log_dir=log_dir)
#     global_step = 0

#     for epoch in range(num_epochs):
#         save_audio = True if epoch == num_epochs-1 else False
#         print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")
#         epoch_metrics = cycle_train_epoch(
#             generator_vocal,
#             generator_speech,
#             discriminator_vocal,
#             discriminator_speech,
#             optimizer_DV,
#             optimizer_DS,
#             optimizer_GV,
#             optimizer_GS,
#             accompaniment_loader,
#             vocal_loader,
#             speech_loader,
#             l1_loss,
#             lambda_l1,
#             lambda_cycle,
#             adversarial_loss,
#             device,
#             virtual_batch_size,
#             clip_length,
#             input_size_generators,
#             save_output = save_audio,
#             smart_discriminator = smart_discriminator
#         )
#         print(f"Epoch {epoch+1} Metrics:")
#         print(f"  Loss_DV:         {epoch_metrics['loss_DV']:.4f}")
#         print(f"  Loss_DS:         {epoch_metrics['loss_DS']:.4f}")
#         # print(f"  Loss_GV_total:   {epoch_metrics['loss_GV']:.4f}")
#         # print(f"  Loss_GS_total:   {epoch_metrics['loss_GS']:.4f}")
#         print(f"  Loss_GV_adv:     {epoch_metrics['loss_GV_adv']:.4f}")
#         print(f"  Loss_GS_adv:     {epoch_metrics['loss_GS_adv']:.4f}")
#         print(f"  Loss_Cycle Vocal:     {epoch_metrics['loss_cycle_vocal']:.4f}")
#         print(f"  Loss_Cycle Speech:     {epoch_metrics['loss_cycle_speech']:.4f}")
#         print(f"  Grad Norm DV:    {epoch_metrics['avg_grad_norm_DV']:.4f}")
#         print(f"  Grad Norm DS:    {epoch_metrics['avg_grad_norm_DS']:.4f}")
#         print(f"  Grad Norm GV:    {epoch_metrics['avg_grad_norm_GV']:.4f}")
#         print(f"  Grad Norm GS:    {epoch_metrics['avg_grad_norm_GS']:.4f}")

#         # Log metrics to TensorBoard.
#         writer.add_scalar("Loss/Discriminator", epoch_metrics["loss_DV"], epoch)
#         # writer.add_scalar("Loss/Generator_total", epoch_metrics["loss_G_total"], epoch)
#         writer.add_scalar("Loss/Generator_adversarial", epoch_metrics["loss_GV_adv"], epoch)
#         # writer.add_scalar("Loss/Generator_L1", epoch_metrics["loss_G_L1"], epoch)
#         writer.add_scalar("Loss/Cycle", epoch_metrics["loss_cycle_vocal"], epoch)
#         writer.add_scalar("Gradients/Discriminator", epoch_metrics["avg_grad_norm_DV"], epoch)
#         writer.add_scalar("Gradients/Generator", epoch_metrics["avg_grad_norm_GV"], epoch)

#         global_step += 1

#     writer.close()
#     return #audio_files

In [8]:
# ----- Single Epoch Training Function -----
def cycle_train_epoch(
    generator_vocal,
    generator_speech,
    discriminator_vocal,
    discriminator_speech,
    optimizer_DV,
    optimizer_DS,
    optimizer_GV,
    optimizer_GS,
    acc_voc_loader,
    speech_loader,
    l1_loss,
    mse_loss,
    lambda_l1,
    lambda_cycle,
    lambda_identity,
    bce_loss,
    device,
    virtual_batch_size,
    clip_length,
    input_size_generators,
    save_output=False,
    smart_discriminator=False,
    batch_size = 32

):

    total_loss_DV = total_loss_DS = total_loss_GV = total_loss_GS  = 0
    total_loss_GV_identity = total_loss_GS_identity = 0
    total_loss_adv_vocal = total_loss_adv_speech = total_loss_cycle_vocal = total_loss_cycle_speech = total_loss_GV_identity = total_loss_GS_identity =  0.0
    # Optionally record gradient norms per batch for diagnosing vanishing gradients.
    grad_norms_DV = []
    grad_norms_DS = []
    grad_norms_GV = []
    grad_norms_GS = []
    num_batches = 0


    alpha = 0.9      # Tuneable constant to gate the discriminator training
    running_loss_DV = 0.0
    running_loss_DS = 0.0
    dv_threshold = 0.6
    ds_threshold = 0.6

    optimizer_DV.zero_grad()
    optimizer_DS.zero_grad()
    optimizer_GV.zero_grad()
    optimizer_GS.zero_grad()

    # real_labels = torch.ones(batch_size, 1, device=device, requires_grad = False)
    # fake_labels = torch.zeros(batch_size, 1, device=device, requires_grad = False)

    num_DV_backwards = num_DS_backwards = 0

    # ---- batch loop ----
    for (acc_voc, speech) in tqdm(
        zip(acc_voc_loader, speech_loader),
        desc="Training Batches"
    ):
        # Read in data
        x_acc = acc_voc["acc_pad"].float().to(device)       # [B,128,289]
        x_voc = acc_voc["voc_pad"].float().to(device)
        x_speech = speech["pad"].float().to(device)    # [B,128,289]
        x_in = torch.cat([x_speech, x_acc], dim=1)     # [B,256,289]

        real_labels = torch.ones(batch_size, 1, device=device, requires_grad = False)
        fake_labels = torch.zeros(batch_size, 1, device=device, requires_grad = False)

        ############ START COPIED CODE ######################
        acc_np = acc_voc["acc_no_pad"].float().to(device)
        voc_np = acc_voc["voc_no_pad"].float().to(device)
        speech_np = speech["no_pad"].float().to(device)

        # Compute transformations with generators
        raw_fake_vocal = generator_vocal(x_in)["vocal"]
        fake_vocal = raw_fake_vocal.clone()
        fake_vocal_crop = torch.narrow(fake_vocal, 2, 0, clip_length).clone()

        raw_fake_speech = generator_speech(x_voc)["speech"]
        fake_speech = raw_fake_speech.clone()
        fake_speech_crop = torch.narrow(fake_speech, 2, 0, clip_length).clone()

        # Generate reconstructed speech/vocal
        fake_vocal_pad = transform_for_gen_2(fake_vocal, input_size_generators)  # you must define this
        raw_rec_speech = generator_speech(fake_vocal_pad)["speech"]
        rec_speech = raw_rec_speech.clone()
        rec_speech_crop = torch.narrow(rec_speech, 2, 0, clip_length).clone()

        fake_speech_pad = transform_for_gen_2(fake_speech, input_size_generators)  # you must define this
        fake_speech_with_acc = torch.cat([fake_speech_pad, x_acc], dim=1)
        raw_rec_vocal = generator_vocal(fake_speech_with_acc)["vocal"]
        rec_vocal = raw_rec_vocal.clone()
        rec_vocal_crop = torch.narrow(rec_vocal, 2, 0, clip_length).clone()

        # Identity generation
        identity_vocal = generator_vocal(torch.cat([x_voc, x_acc], dim=1))["vocal"]
        identity_vocal_crop = torch.narrow(identity_vocal, 2, 0, clip_length).clone()

        identity_speech = generator_speech(x_speech)["speech"]
        identity_speech_crop = torch.narrow(identity_speech, 2, 0, clip_length).clone()

        # Compute losses
        pred_real_vocal = discriminator_vocal(voc_np, acc_np)
        pred_fake_vocal_D = discriminator_vocal(fake_vocal_crop.detach(), acc_np)
        pred_real_speech = discriminator_speech(speech_np)
        pred_fake_speech_D = discriminator_speech(fake_speech_crop.detach())

        loss_DV_fake = bce_loss(pred_fake_vocal_D, fake_labels)
        loss_DV_real = bce_loss(pred_real_vocal, real_labels)
        loss_DS_fake = bce_loss(pred_fake_speech_D, fake_labels)
        loss_DS_real = bce_loss(pred_real_speech, real_labels)

        # Minimizing adv losses is teaching the gens to trick the discs (labels are swapped)
        pred_fake_vocal = discriminator_vocal(fake_vocal_crop, acc_np)
        pred_fake_speech = discriminator_speech(fake_speech_crop)
        loss_adv_vocal = mse_loss(pred_fake_vocal, real_labels)
        loss_adv_speech = mse_loss(pred_fake_speech, real_labels)

        loss_cycle_vocal = l1_loss(rec_vocal_crop, voc_np)
        loss_cycle_speech = l1_loss(rec_speech_crop, speech_np)

        loss_identity_vocal = l1_loss(identity_vocal_crop, voc_np)
        loss_identity_speech = l1_loss(identity_speech_crop, speech_np)

        # NOTE: COULD INCLUDE IDENTITY LOSS

        loss_DV = 0.5 * (loss_DV_real + loss_DV_fake)
        loss_DS = 0.5 * (loss_DS_real + loss_DS_fake)
        loss_GV = (loss_adv_vocal + lambda_cycle * loss_cycle_vocal + lambda_identity * loss_identity_vocal) / (1 + lambda_cycle + lambda_identity)
        loss_GS = (loss_adv_speech + lambda_cycle * loss_cycle_speech + lambda_identity * loss_identity_speech) / (1 + lambda_cycle + lambda_identity)

#
        running_loss_DV = alpha * running_loss_DV + (1 - alpha) * loss_DV.item()
        running_loss_DS = alpha * running_loss_DS + (1 - alpha) * loss_DS.item()


        # Update generators
        ((loss_GV + loss_GS) / virtual_batch_size).backward()


        # Record gradients & take steps
        if (num_batches + 1) % virtual_batch_size == 0:
          # Record gradients
          grad_norm = 0.0
          count = 0
          for p in generator_vocal.parameters():
              if p.grad is not None:
                  grad_norm += p.grad.norm().item()
                  count += 1
          if count > 0:
              grad_norms_GV.append(grad_norm / count)

          grad_norm = 0.0
          count = 0
          for p in generator_speech.parameters():
              if p.grad is not None:
                  grad_norm += p.grad.norm().item()
                  count += 1
          if count > 0:
              grad_norms_GS.append(grad_norm / count)

          # Take steps
          optimizer_GV.step()
          optimizer_GS.step()
          optimizer_GV.zero_grad()
          optimizer_GS.zero_grad()


        # Update discriminators
        if running_loss_DV > dv_threshold or smart_discriminator:
          (loss_DV / virtual_batch_size).backward()
          num_DV_backwards += 1
          if (num_DV_backwards+1) % virtual_batch_size == 0:
            grad_norm = 0.0
            count = 0
            for p in discriminator_vocal.parameters():
                if p.grad is not None:
                    grad_norm += p.grad.norm().item()
                    count += 1
            if count > 0:
                grad_norms_DV.append(grad_norm / count)

            optimizer_DV.step()
            optimizer_DV.zero_grad()
            num_DV_backwards = 0

        if running_loss_DS > ds_threshold or smart_discriminator:
          (loss_DS / virtual_batch_size).backward()
          num_DS_backwards += 1
          if (num_DS_backwards+1) % virtual_batch_size == 0:
            # record gradients
            grad_norm = 0.0
            count = 0
            for p in discriminator_speech.parameters():
                if p.grad is not None:
                    grad_norm += p.grad.norm().item()
                    count += 1
            if count > 0:
                grad_norms_DS.append(grad_norm / count)

            optimizer_DS.step()
            optimizer_DS.zero_grad()
            num_DS_backwards = 0

        # Accumulate metrics
        total_loss_DV     += loss_DV.item()
        total_loss_DS     += loss_DS.item()
        total_loss_adv_vocal  += loss_adv_vocal.item()
        total_loss_adv_speech += loss_adv_speech.item()
        total_loss_cycle_vocal += loss_cycle_vocal.item()
        total_loss_cycle_speech += loss_cycle_speech.item()
        total_loss_GV_identity += loss_identity_vocal.item()
        total_loss_GS_identity += loss_identity_speech.item()
        total_loss_GV      += loss_GV.item()
        total_loss_GS   += loss_GS.item()
        num_batches       += 1
        # audio_files = []
        # if save_output:
        #     fake_singing_list = [fake_crop[i] for i in range(fake_crop.shape[0])]
        #     audio_files = convert_to_audio(fake_singing_list)
    return {
        "loss_DV":      total_loss_DV / num_batches,
        "loss_DS":      total_loss_DS / num_batches,
        "loss_GV":      total_loss_GV / num_batches,
        "loss_GS":      total_loss_GS / num_batches,
        "loss_adv_vocal":  total_loss_adv_vocal / num_batches,
        "loss_adv_speech":  total_loss_adv_speech / num_batches,
        "loss_cycle_vocal":  total_loss_cycle_vocal / num_batches,
        "loss_cycle_speech":  total_loss_cycle_speech / num_batches,
        "loss_identity_vocal":  total_loss_GV_identity / num_batches,
        "loss_identity_speech":  total_loss_GS_identity / num_batches,
        "avg_grad_norm_DV": sum(grad_norms_DV) / len(grad_norms_DV) if grad_norms_DV else 0.0,
        "avg_grad_norm_DS": sum(grad_norms_DS) / len(grad_norms_DS) if grad_norms_DS else 0.0,
        "avg_grad_norm_GV": sum(grad_norms_GV) / len(grad_norms_GV) if grad_norms_GV else 0.0,
        "avg_grad_norm_GS": sum(grad_norms_GS) / len(grad_norms_GS) if grad_norms_GS else 0.0,
        "num_DV_updates" : len(grad_norms_DV),
        "num_DS_updates" : len(grad_norms_DS)
    }#, audio_files

# ----- Multi-Epoch Training Function -----
def cycle_train(
    generator_vocal,
    generator_speech,
    discriminator_vocal,
    discriminator_speech,
    optimizer_DV,
    optimizer_DS,
    optimizer_GV,
    optimizer_GS,
    acc_voc_loader,
    speech_loader,
    l1_loss,
    mse_loss,
    lambda_l1,
    lambda_cycle,
    lambda_identity,
    bce_loss,
    device,
    num_epochs,
    virtual_batch_size,
    clip_length,
    input_size_generators,
    log_dir,
    save_audio = True,
    smart_discriminator = False,
    batch_size = 32
):
    writer = SummaryWriter(log_dir=log_dir)
    global_step = 0

    for epoch in range(num_epochs):
        save_audio = True if epoch == num_epochs-1 else False
        print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")
        epoch_metrics = cycle_train_epoch(
            generator_vocal,
            generator_speech,
            discriminator_vocal,
            discriminator_speech,
            optimizer_DV,
            optimizer_DS,
            optimizer_GV,
            optimizer_GS,
            acc_voc_loader,
            speech_loader,
            l1_loss,
            mse_loss,
            lambda_l1,
            lambda_cycle,
            lambda_identity,
            bce_loss,
            device,
            virtual_batch_size,
            clip_length,
            input_size_generators,
            save_output = save_audio,
            smart_discriminator = smart_discriminator,
            batch_size = batch_size
        )
        print(f"Epoch {epoch+1} Metrics:")
        print(f"  Loss_DV:         {epoch_metrics['loss_DV']:.4f}")
        print(f"  Loss_DS:         {epoch_metrics['loss_DS']:.4f}")
        # print(f"  Loss_GV_total:   {epoch_metrics['loss_GV']:.4f}")
        # print(f"  Loss_GS_total:   {epoch_metrics['loss_GS']:.4f}")
        print(f"  Loss_adv_vocal:     {epoch_metrics['loss_adv_vocal']:.4f}")
        print(f"  Loss_adv_speech:     {epoch_metrics['loss_adv_speech']:.4f}")
        print(f"  Loss_Cycle Vocal:     {epoch_metrics['loss_cycle_vocal']:.4f}")
        print(f"  Loss_Cycle Speech:     {epoch_metrics['loss_cycle_speech']:.4f}")
        print(f"  Loss_Identity Vocal:     {epoch_metrics['loss_identity_vocal']:.4f}")
        print(f"  Loss_Identity Speech:     {epoch_metrics['loss_identity_speech']:.4f}")
        print(f"  Grad Norm DV:    {epoch_metrics['avg_grad_norm_DV']:.4f}")
        print(f"  Grad Norm DS:    {epoch_metrics['avg_grad_norm_DS']:.4f}")
        print(f"  Grad Norm GV:    {epoch_metrics['avg_grad_norm_GV']:.4f}")
        print(f"  Grad Norm GS:    {epoch_metrics['avg_grad_norm_GS']:.4f}")
        print(f"  num_DV_updates:    {epoch_metrics['num_DV_updates']:.4f}")
        print(f"  num_DS_updates:    {epoch_metrics['num_DS_updates']:.4f}")

        # Log metrics to TensorBoard.
        writer.add_scalar("Loss/Discriminator", epoch_metrics["loss_DV"], epoch)
        # writer.add_scalar("Loss/Generator_total", epoch_metrics["loss_G_total"], epoch)
        writer.add_scalar("Loss/Generator_adversarial", epoch_metrics["loss_adv_vocal"], epoch)
        # writer.add_scalar("Loss/Generator_L1", epoch_metrics["loss_G_L1"], epoch)
        writer.add_scalar("Loss/Cycle", epoch_metrics["loss_cycle_vocal"], epoch)
        writer.add_scalar("Gradients/Discriminator", epoch_metrics["avg_grad_norm_DV"], epoch)
        writer.add_scalar("Gradients/Generator", epoch_metrics["avg_grad_norm_GV"], epoch)

        global_step += 1

    writer.close()
    return #audio_files

## Variable dataset
By default we train shorter models on colab and longer models in a docker container on a local machine.

In [9]:
if colab:
    path = "/content/drive/My Drive/git_projects/spring_2025_dl_audio_project_data/"
else:
    path = "/workspace/hdd_project_data/"

# Dataset paths for short clips
if colab:
    librispeech_data = "LibriSpeechDataset_withOverlap.pt"
    musdb_data = "musdb_noOverlap_test.pt"
else:
    # librispeech_data = "librispeech_longClip_train_small.pt"
    # musdb_data = "musdb_longClip_train.pt"
    librispeech_data = "LibriSpeechDataset_withOverlap.pt"
    musdb_data = "musdb_noOverlap_test.pt"

librispeechDataset_path = path + librispeech_data
musdbDataset_path = path + musdb_data

librispeech_dataset = torch.load(librispeechDataset_path, weights_only=False)
musdb_dataset = torch.load(musdbDataset_path, weights_only=False)


In [10]:
# Because of the way the librispeech dataset was constructed, it is slightly longer
# than the musbd dataset. Crop the librispeech dataset with these lines
librispeech_dataset.mel_specs = librispeech_dataset.mel_specs[0:len(musdb_dataset)]
librispeech_dataset.sample_rates = librispeech_dataset.sample_rates[0:len(musdb_dataset)]

### Record length of clips

In [11]:
# --- Explore the Datasets ---
print("=== MusDB Dataset Exploration ===")
print("Length:", len(musdb_dataset))
print("mel_specs shape:", musdb_dataset.mel_specs.shape)
print("sample_rates shape:", musdb_dataset.sample_rates.shape)
print()
accompaniment, vocal, sample_rate = musdb_dataset[0]
print("Sample 0 - Accompaniment shape:", accompaniment.size())
print("Sample 0 - Vocal shape:", vocal.size())
print("Sample 0 - Sample rate:", sample_rate)
print()

print("=== LibriSpeech Dataset Exploration ===")
print("Length:", len(librispeech_dataset))
print("mel_specs shape:", librispeech_dataset.mel_specs.shape)
print("sample_rates shape:", librispeech_dataset.sample_rates.shape)
print()
speech, sample_rate = librispeech_dataset[0]
print("Sample 0 - Speech shape:", speech.size())
print("Sample 0 - Sample rate:", sample_rate)

musdb_length = musdb_dataset.mel_specs.shape[-1]
librispeech_length = librispeech_dataset.mel_specs.shape[-1]

if musdb_length == librispeech_length:
    clip_length = musdb_length
else:
    raise ValueError("The lengths of the datasets do not match. Please check the dataset lengths.")
print()
print("=========================")
print("Training clip length:", clip_length)

=== MusDB Dataset Exploration ===
Length: 4167
mel_specs shape: torch.Size([4167, 2, 128, 256])
sample_rates shape: torch.Size([4167])

Sample 0 - Accompaniment shape: torch.Size([128, 256])
Sample 0 - Vocal shape: torch.Size([128, 256])
Sample 0 - Sample rate: tensor(44100)

=== LibriSpeech Dataset Exploration ===
Length: 4167
mel_specs shape: torch.Size([4167, 128, 256])
sample_rates shape: torch.Size([4167])

Sample 0 - Speech shape: torch.Size([128, 256])
Sample 0 - Sample rate: tensor(44100)

Training clip length: 256


## Initialize Generator Models
We initilaze generator models so we can find the necessary padding for the loader.

In [12]:
# Model configurations for generator and generator_2.
model_config_gen_vocal= {
    "num_inputs": 256,  # Two spectrograms concatenated (2 * 128 mel bins)
    "num_outputs": 128,
    "num_channels": [512*2, 512*4, 512*8],
    "instruments": ["vocal"],
    "kernel_size": 3,
    "target_output_size": clip_length,
    "conv_type": "normal",
    "res": "fixed",
    "separate": False,
    "depth": 1,
    "strides": 2
}

model_config_gen_speech = {
    "num_inputs": 128,  # One spectrogram input
    "num_outputs": 128,
    "num_channels": [256*2, 256*4, 256*8],
    "instruments": ["speech"],
    "kernel_size": 3,
    "target_output_size": clip_length,
    "conv_type": "normal",
    "res": "fixed",
    "separate": False,
    "depth": 1,
    "strides": 2
}
generator_vocal = Waveunet(**model_config_gen_vocal).to(device)
generator_speech = Waveunet(**model_config_gen_speech).to(device)

input_size_generators = generator_vocal.input_size
output_size_generators = generator_vocal.output_size
print("Input size for generators:", input_size_generators)
print("Output size for generators:", output_size_generators)

Using valid convolutions with 289 inputs and 257 outputs
Using valid convolutions with 289 inputs and 257 outputs
Input size for generators: 289
Output size for generators: 257


### Create dataloader
We create the dataloaders for the training loop using the correct input_size for the models.

In [13]:
# Define batch size, virtual batch size, and the number of workers for DataLoaders
batch_size = 64  # Change as needed
target_virtual_batch_size = 256
num_workers = 1

now = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "runs/" + "cycleGAN_experiment_" + now

if target_virtual_batch_size % batch_size != 0:
    raise ValueError("virtual_batch_size must be a multiple of batch_size.")
else:
    virtual_batch_size = target_virtual_batch_size // batch_size

# ---------------- hyper‑parameters in ONE place ----------------
train_parameters = {
    # data loaders
    "batch_size":     batch_size,
    "virtual_batch_size": virtual_batch_size,
    "num_workers":    num_workers,
    "clip_length": clip_length,
    "input_size_generators":    input_size_generators,

    # optimisation
    "lr_G":          1e-4,
    "lr_G2":         1e-4,
    "lr_D":          1e-4,
    "betas":         (0.9, 0.999),

    # loss weights
    "lambda_l1":     1,
    "lambda_cycle": .01,
    "lambda_identity": 0.0,

    # schedule
    "num_epochs":    50,

    # bookkeeping
    "log_dir":       f"runs/cycleGAN_experiment_{now}",
    "model_dir":     "models",
}


In [14]:
################ DO WE WANT TO SHUFFLE THE DATASETS? ################

# Create data loaders
acc_voc_loader = DataLoader(
    AccompanimentVocalData(musdb_dataset, output_length=input_size_generators),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True
)


speech_loader = DataLoader(
    SpeechData(librispeech_dataset, output_length=input_size_generators),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True
)


### Initialize Discriminator Models
We initialize the discriminator models and call ``fit_rocket``.

In [15]:
# check the number of cores
import multiprocessing
num_cores = multiprocessing.cpu_count()
print("Number of CPU cores:", num_cores)

################ DO WE WANT BOTH DISCRIMINATORS TO USE ALL THE CORES ###############
# The discriminators will not be running at the same time,
# so it seems safe to give them both more than half the cores.
minirocket_n_jobs = num_cores-2

# Instantiate the discriminators.
discriminator_vocal = TsaiMiniRocketDiscriminator().to(device)
discriminator_speech = TsaiMiniRocketDiscriminator(freq_bins = 128,
                                                   hidden_dim = 512,
                                                   accompaniment = False).to(device)

# Optionally, prepare the discriminator (e.g., pre-fitting on some speech data).
acc_voc_batch = next(iter(acc_voc_loader))
speech_batch = next(iter(speech_loader))["no_pad"]

print()
print("Fitting discriminator_vocal...")
discriminator_vocal.fit_rocket(speech_batch)
print()
print("Fitting discriminator_speech...")
discriminator_speech.fit_rocket(speech_batch)

Number of CPU cores: 24

Fitting discriminator_vocal...
MiniRocket fitted. Feature dimension: 9996

Fitting discriminator_speech...
MiniRocket fitted. Feature dimension: 9996


In [16]:
# Loss functions.
bce_loss = nn.BCELoss().to(device)
l1_loss = nn.L1Loss().to(device)
mse_loss = nn.MSELoss().to(device)

# Optimizers.
optimizer_GV  = optim.Adam(generator_vocal.parameters(),  lr=train_parameters["lr_G"],  betas=train_parameters["betas"])
optimizer_GS = optim.Adam(generator_speech.parameters(), lr=train_parameters["lr_G2"], betas=train_parameters["betas"])
optimizer_DV  = optim.Adam(discriminator_vocal.parameters(), lr=train_parameters["lr_D"], betas=train_parameters["betas"])
optimizer_DS  = optim.Adam(discriminator_speech.parameters(), lr=train_parameters["lr_D"], betas=train_parameters["betas"])


In [17]:
# Clear the CUDA cache to free up memory. Sometimes PyTorch doesn't release memory immediately.
# This can help prevent out-of-memory errors.

import gc
gc.collect()
torch.cuda.empty_cache()


In [18]:
##### Optionally load a model checkpoint #####
# NEED TO ADD SECOND DISCRIMINATOR IF WE END UP LOADING A MODEL

# state_time = "20250416-143134"
# model_dir  = "models"

# # ----  load the raw weight dictionaries ----
# disc_state = torch.load(f"{model_dir}/discriminator_state_dict_{state_time}.pt", map_location=device)
# gen_state  = torch.load(f"{model_dir}/generator_state_dict_{state_time}.pt",         map_location=device)
# gen2_state = torch.load(f"{model_dir}/generator_2_state_dict_{state_time}.pt",       map_location=device)

# # ----   rebuild the model objects ----
# discriminator = TsaiMiniRocketDiscriminator().to(device)
# generator     = Waveunet(**model_config_gen).to(device)
# generator_2   = Waveunet(**model_config_gen2).to(device)

# # ---- load the weights into those models ----
# discriminator.load_state_dict(disc_state)
# generator.load_state_dict(gen_state)
# generator_2.load_state_dict(gen2_state)



## Train the model


In [19]:
cycle_train_epoch = torch.compile(cycle_train_epoch, mode="default")

# # Start training.
# audio_files = cycle_train(
#     generator_vocal,
#     generator_speech,
#     discriminator_vocal,
#     discriminator_speech,
#     optimizer_DV,
#     optimizer_DS,
#     optimizer_GV,
#     optimizer_GS,
#     accompaniment_loader,
#     vocal_loader,
#     speech_loader,
#     l1_loss,
#     train_parameters["lambda_l1"],
#     train_parameters["lambda_cycle"],
#     adversarial_loss,
#     device,
#     num_epochs          = train_parameters["num_epochs"],
#     virtual_batch_size  = train_parameters["virtual_batch_size"],
#     log_dir             = train_parameters["log_dir"],
#     clip_length = train_parameters["clip_length"],
#     input_size_generators = train_parameters["input_size_generators"],
# )

audio_files = cycle_train(
    generator_vocal,
    generator_speech,
    discriminator_vocal,
    discriminator_speech,
    optimizer_DV,
    optimizer_DS,
    optimizer_GV,
    optimizer_GS,
    acc_voc_loader,
    speech_loader,
    l1_loss,
    mse_loss,
    train_parameters["lambda_l1"],
    train_parameters["lambda_cycle"],
    train_parameters["lambda_identity"],
    bce_loss,
    device,
    num_epochs          = train_parameters["num_epochs"],
    virtual_batch_size  = train_parameters["virtual_batch_size"],
    log_dir             = train_parameters["log_dir"],
    clip_length = train_parameters["clip_length"],
    input_size_generators = train_parameters["input_size_generators"],
    batch_size = train_parameters["batch_size"]
)





=== Epoch 1/50 ===


Training Batches: 3it [00:08,  2.31s/it]W0424 02:22:32.157000 1138029 torch/_logging/_internal.py:1089] [64/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self'].param_groups[0]['params'][8].grad", "L['self'].param_groups[0]['params'][11].grad", "L['self'].param_groups[0]['params'][12].grad", "L['self'].param_groups[0]['params'][13].grad", "L['self'].param_groups[0]['params'][14].grad", "L['self'].param_groups[0]['params'][16].grad", "L['self'].param_groups[0]['params'][17].grad", "L['self'].param_groups[0]['params'][18].grad", "L['self'].param_groups[0]['params'][19].grad", "L['self'].param_groups[0]['p

Epoch 1 Metrics:
  Loss_DV:         0.9376
  Loss_DS:         0.7542
  Loss_adv_vocal:     0.1835
  Loss_adv_speech:     0.1702
  Loss_Cycle Vocal:     22.2182
  Loss_Cycle Speech:     27.2024
  Loss_Identity Vocal:     26.9812
  Loss_Identity Speech:     29.3930
  Grad Norm DV:    1.4584
  Grad Norm DS:    0.8551
  Grad Norm GV:    0.0635
  Grad Norm GS:    0.0476
  num_DV_updates:    15.0000
  num_DS_updates:    15.0000

=== Epoch 2/50 ===


Training Batches: 65it [01:02,  1.05it/s]


Epoch 2 Metrics:
  Loss_DV:         0.7612
  Loss_DS:         0.6453
  Loss_adv_vocal:     0.1250
  Loss_adv_speech:     0.1626
  Loss_Cycle Vocal:     8.8480
  Loss_Cycle Speech:     10.7420
  Loss_Identity Vocal:     11.6693
  Loss_Identity Speech:     13.4380
  Grad Norm DV:    0.3839
  Grad Norm DS:    0.7347
  Grad Norm GV:    0.0661
  Grad Norm GS:    0.0851
  num_DV_updates:    15.0000
  num_DS_updates:    9.0000

=== Epoch 3/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 3 Metrics:
  Loss_DV:         0.7699
  Loss_DS:         0.7043
  Loss_adv_vocal:     0.1001
  Loss_adv_speech:     0.0980
  Loss_Cycle Vocal:     7.3752
  Loss_Cycle Speech:     8.9802
  Loss_Identity Vocal:     9.9666
  Loss_Identity Speech:     11.3633
  Grad Norm DV:    0.1395
  Grad Norm DS:    0.2608
  Grad Norm GV:    0.0404
  Grad Norm GS:    0.0489
  num_DV_updates:    17.0000
  num_DS_updates:    16.0000

=== Epoch 4/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 4 Metrics:
  Loss_DV:         0.7447
  Loss_DS:         0.6076
  Loss_adv_vocal:     0.1171
  Loss_adv_speech:     0.1484
  Loss_Cycle Vocal:     6.5652
  Loss_Cycle Speech:     8.7287
  Loss_Identity Vocal:     9.1298
  Loss_Identity Speech:     11.5009
  Grad Norm DV:    0.1405
  Grad Norm DS:    0.7926
  Grad Norm GV:    0.0335
  Grad Norm GS:    0.0378
  num_DV_updates:    16.0000
  num_DS_updates:    5.0000

=== Epoch 5/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 5 Metrics:
  Loss_DV:         0.7596
  Loss_DS:         0.6865
  Loss_adv_vocal:     0.1022
  Loss_adv_speech:     0.0977
  Loss_Cycle Vocal:     6.1793
  Loss_Cycle Speech:     8.5918
  Loss_Identity Vocal:     8.6609
  Loss_Identity Speech:     11.7040
  Grad Norm DV:    0.1470
  Grad Norm DS:    0.3359
  Grad Norm GV:    0.0517
  Grad Norm GS:    0.0635
  num_DV_updates:    16.0000
  num_DS_updates:    14.0000

=== Epoch 6/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 6 Metrics:
  Loss_DV:         0.7515
  Loss_DS:         0.6264
  Loss_adv_vocal:     0.1068
  Loss_adv_speech:     0.1224
  Loss_Cycle Vocal:     5.8547
  Loss_Cycle Speech:     8.1877
  Loss_Identity Vocal:     8.6195
  Loss_Identity Speech:     11.7772
  Grad Norm DV:    0.1304
  Grad Norm DS:    0.2587
  Grad Norm GV:    0.0553
  Grad Norm GS:    0.0659
  num_DV_updates:    16.0000
  num_DS_updates:    12.0000

=== Epoch 7/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 7 Metrics:
  Loss_DV:         0.7519
  Loss_DS:         0.6170
  Loss_adv_vocal:     0.1067
  Loss_adv_speech:     0.1282
  Loss_Cycle Vocal:     5.7083
  Loss_Cycle Speech:     7.4126
  Loss_Identity Vocal:     8.9702
  Loss_Identity Speech:     11.7695
  Grad Norm DV:    0.1333
  Grad Norm DS:    0.4631
  Grad Norm GV:    0.0645
  Grad Norm GS:    0.0809
  num_DV_updates:    16.0000
  num_DS_updates:    8.0000

=== Epoch 8/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 8 Metrics:
  Loss_DV:         0.7483
  Loss_DS:         0.6551
  Loss_adv_vocal:     0.1072
  Loss_adv_speech:     0.1059
  Loss_Cycle Vocal:     5.6196
  Loss_Cycle Speech:     6.7103
  Loss_Identity Vocal:     9.7548
  Loss_Identity Speech:     12.0990
  Grad Norm DV:    0.1302
  Grad Norm DS:    0.2654
  Grad Norm GV:    0.0714
  Grad Norm GS:    0.0874
  num_DV_updates:    16.0000
  num_DS_updates:    14.0000

=== Epoch 9/50 ===


Training Batches: 65it [01:01,  1.05it/s]


Epoch 9 Metrics:
  Loss_DV:         0.7445
  Loss_DS:         0.5905
  Loss_adv_vocal:     0.1073
  Loss_adv_speech:     0.1427
  Loss_Cycle Vocal:     5.4045
  Loss_Cycle Speech:     6.2074
  Loss_Identity Vocal:     10.3633
  Loss_Identity Speech:     12.5486
  Grad Norm DV:    0.1377
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0662
  Grad Norm GS:    0.0745
  num_DV_updates:    16.0000
  num_DS_updates:    0.0000

=== Epoch 10/50 ===


Training Batches: 65it [01:00,  1.07it/s]


Epoch 10 Metrics:
  Loss_DV:         0.7388
  Loss_DS:         0.5949
  Loss_adv_vocal:     0.1079
  Loss_adv_speech:     0.1407
  Loss_Cycle Vocal:     5.3081
  Loss_Cycle Speech:     6.0121
  Loss_Identity Vocal:     11.1510
  Loss_Identity Speech:     13.1262
  Grad Norm DV:    0.1401
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0764
  Grad Norm GS:    0.0852
  num_DV_updates:    16.0000
  num_DS_updates:    0.0000

=== Epoch 11/50 ===


Training Batches: 65it [01:00,  1.08it/s]


Epoch 11 Metrics:
  Loss_DV:         0.7297
  Loss_DS:         0.5966
  Loss_adv_vocal:     0.1105
  Loss_adv_speech:     0.1403
  Loss_Cycle Vocal:     5.0644
  Loss_Cycle Speech:     5.8144
  Loss_Identity Vocal:     11.7128
  Loss_Identity Speech:     13.4122
  Grad Norm DV:    0.1492
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0676
  Grad Norm GS:    0.0759
  num_DV_updates:    16.0000
  num_DS_updates:    0.0000

=== Epoch 12/50 ===


Training Batches: 65it [01:00,  1.08it/s]


Epoch 12 Metrics:
  Loss_DV:         0.7196
  Loss_DS:         0.5992
  Loss_adv_vocal:     0.1131
  Loss_adv_speech:     0.1399
  Loss_Cycle Vocal:     4.9233
  Loss_Cycle Speech:     5.6913
  Loss_Identity Vocal:     12.0971
  Loss_Identity Speech:     13.7862
  Grad Norm DV:    0.1717
  Grad Norm DS:    3.5323
  Grad Norm GV:    0.0648
  Grad Norm GS:    0.0777
  num_DV_updates:    15.0000
  num_DS_updates:    1.0000

=== Epoch 13/50 ===


Training Batches: 65it [01:00,  1.08it/s]


Epoch 13 Metrics:
  Loss_DV:         0.7076
  Loss_DS:         0.6783
  Loss_adv_vocal:     0.1151
  Loss_adv_speech:     0.1032
  Loss_Cycle Vocal:     4.8858
  Loss_Cycle Speech:     5.6705
  Loss_Identity Vocal:     12.4270
  Loss_Identity Speech:     14.1455
  Grad Norm DV:    0.1894
  Grad Norm DS:    0.5048
  Grad Norm GV:    0.0678
  Grad Norm GS:    0.0807
  num_DV_updates:    15.0000
  num_DS_updates:    10.0000

=== Epoch 14/50 ===


Training Batches: 65it [01:00,  1.08it/s]


Epoch 14 Metrics:
  Loss_DV:         0.6958
  Loss_DS:         0.6868
  Loss_adv_vocal:     0.1177
  Loss_adv_speech:     0.1006
  Loss_Cycle Vocal:     4.8185
  Loss_Cycle Speech:     5.5698
  Loss_Identity Vocal:     12.8244
  Loss_Identity Speech:     14.4943
  Grad Norm DV:    0.2335
  Grad Norm DS:    0.1832
  Grad Norm GV:    0.0695
  Grad Norm GS:    0.0855
  num_DV_updates:    14.0000
  num_DS_updates:    17.0000

=== Epoch 15/50 ===


Training Batches: 65it [01:00,  1.08it/s]


Epoch 15 Metrics:
  Loss_DV:         0.6837
  Loss_DS:         0.5845
  Loss_adv_vocal:     0.1209
  Loss_adv_speech:     0.1552
  Loss_Cycle Vocal:     4.7537
  Loss_Cycle Speech:     5.4078
  Loss_Identity Vocal:     13.0199
  Loss_Identity Speech:     14.8254
  Grad Norm DV:    0.2643
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0668
  Grad Norm GS:    0.0802
  num_DV_updates:    13.0000
  num_DS_updates:    0.0000

=== Epoch 16/50 ===


Training Batches: 65it [01:00,  1.08it/s]


Epoch 16 Metrics:
  Loss_DV:         0.6765
  Loss_DS:         0.5847
  Loss_adv_vocal:     0.1187
  Loss_adv_speech:     0.1547
  Loss_Cycle Vocal:     4.7333
  Loss_Cycle Speech:     5.2163
  Loss_Identity Vocal:     13.2911
  Loss_Identity Speech:     15.1695
  Grad Norm DV:    0.2457
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0712
  Grad Norm GS:    0.0809
  num_DV_updates:    13.0000
  num_DS_updates:    0.0000

=== Epoch 17/50 ===


Training Batches: 65it [01:00,  1.08it/s]


Epoch 17 Metrics:
  Loss_DV:         0.6725
  Loss_DS:         0.5868
  Loss_adv_vocal:     0.1192
  Loss_adv_speech:     0.1541
  Loss_Cycle Vocal:     4.6115
  Loss_Cycle Speech:     4.9358
  Loss_Identity Vocal:     13.4486
  Loss_Identity Speech:     15.2745
  Grad Norm DV:    0.2758
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0670
  Grad Norm GS:    0.0783
  num_DV_updates:    13.0000
  num_DS_updates:    0.0000

=== Epoch 18/50 ===


Training Batches: 65it [01:00,  1.08it/s]


Epoch 18 Metrics:
  Loss_DV:         0.6643
  Loss_DS:         0.5891
  Loss_adv_vocal:     0.1254
  Loss_adv_speech:     0.1535
  Loss_Cycle Vocal:     4.6615
  Loss_Cycle Speech:     4.9173
  Loss_Identity Vocal:     13.4406
  Loss_Identity Speech:     15.5071
  Grad Norm DV:    0.2974
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0797
  Grad Norm GS:    0.0942
  num_DV_updates:    12.0000
  num_DS_updates:    0.0000

=== Epoch 19/50 ===


Training Batches: 65it [01:00,  1.08it/s]


Epoch 19 Metrics:
  Loss_DV:         0.6574
  Loss_DS:         0.5880
  Loss_adv_vocal:     0.1161
  Loss_adv_speech:     0.1531
  Loss_Cycle Vocal:     4.7692
  Loss_Cycle Speech:     5.0284
  Loss_Identity Vocal:     13.6385
  Loss_Identity Speech:     15.5429
  Grad Norm DV:    0.2191
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0853
  Grad Norm GS:    0.0961
  num_DV_updates:    13.0000
  num_DS_updates:    0.0000

=== Epoch 20/50 ===


Training Batches: 65it [01:00,  1.07it/s]


Epoch 20 Metrics:
  Loss_DV:         0.6652
  Loss_DS:         0.5864
  Loss_adv_vocal:     0.1168
  Loss_adv_speech:     0.1543
  Loss_Cycle Vocal:     4.6666
  Loss_Cycle Speech:     4.8712
  Loss_Identity Vocal:     13.8038
  Loss_Identity Speech:     15.4432
  Grad Norm DV:    0.5098
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0768
  Grad Norm GS:    0.0875
  num_DV_updates:    9.0000
  num_DS_updates:    0.0000

=== Epoch 21/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 21 Metrics:
  Loss_DV:         0.6512
  Loss_DS:         0.5859
  Loss_adv_vocal:     0.1151
  Loss_adv_speech:     0.1546
  Loss_Cycle Vocal:     4.6111
  Loss_Cycle Speech:     4.8728
  Loss_Identity Vocal:     13.8935
  Loss_Identity Speech:     15.5365
  Grad Norm DV:    0.1760
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0815
  Grad Norm GS:    0.0951
  num_DV_updates:    14.0000
  num_DS_updates:    0.0000

=== Epoch 22/50 ===


Training Batches: 65it [01:04,  1.00it/s]


Epoch 22 Metrics:
  Loss_DV:         0.6682
  Loss_DS:         0.5874
  Loss_adv_vocal:     0.1195
  Loss_adv_speech:     0.1546
  Loss_Cycle Vocal:     4.4655
  Loss_Cycle Speech:     4.7049
  Loss_Identity Vocal:     14.2675
  Loss_Identity Speech:     15.6490
  Grad Norm DV:    0.8710
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0728
  Grad Norm GS:    0.0855
  num_DV_updates:    6.0000
  num_DS_updates:    0.0000

=== Epoch 23/50 ===


Training Batches: 65it [01:05,  1.01s/it]


Epoch 23 Metrics:
  Loss_DV:         0.7053
  Loss_DS:         0.5876
  Loss_adv_vocal:     0.1231
  Loss_adv_speech:     0.1548
  Loss_Cycle Vocal:     4.2435
  Loss_Cycle Speech:     4.5239
  Loss_Identity Vocal:     14.3928
  Loss_Identity Speech:     15.6804
  Grad Norm DV:    0.1640
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0574
  Grad Norm GS:    0.0646
  num_DV_updates:    18.0000
  num_DS_updates:    0.0000

=== Epoch 24/50 ===


Training Batches: 65it [01:02,  1.03it/s]


Epoch 24 Metrics:
  Loss_DV:         0.6602
  Loss_DS:         0.5880
  Loss_adv_vocal:     0.1376
  Loss_adv_speech:     0.1542
  Loss_Cycle Vocal:     4.1605
  Loss_Cycle Speech:     4.4450
  Loss_Identity Vocal:     14.5471
  Loss_Identity Speech:     15.8410
  Grad Norm DV:    0.5372
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0572
  Grad Norm GS:    0.0672
  num_DV_updates:    8.0000
  num_DS_updates:    0.0000

=== Epoch 25/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 25 Metrics:
  Loss_DV:         0.7013
  Loss_DS:         0.5903
  Loss_adv_vocal:     0.1058
  Loss_adv_speech:     0.1540
  Loss_Cycle Vocal:     4.1312
  Loss_Cycle Speech:     4.3841
  Loss_Identity Vocal:     14.6932
  Loss_Identity Speech:     15.9611
  Grad Norm DV:    0.1423
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0561
  Grad Norm GS:    0.0646
  num_DV_updates:    17.0000
  num_DS_updates:    0.0000

=== Epoch 26/50 ===


Training Batches: 65it [01:02,  1.03it/s]


Epoch 26 Metrics:
  Loss_DV:         0.5991
  Loss_DS:         0.5908
  Loss_adv_vocal:     0.1612
  Loss_adv_speech:     0.1529
  Loss_Cycle Vocal:     4.2206
  Loss_Cycle Speech:     4.4338
  Loss_Identity Vocal:     14.8403
  Loss_Identity Speech:     16.0961
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0706
  Grad Norm GS:    0.0803
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 27/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 27 Metrics:
  Loss_DV:         0.5986
  Loss_DS:         0.5887
  Loss_adv_vocal:     0.1606
  Loss_adv_speech:     0.1532
  Loss_Cycle Vocal:     4.4594
  Loss_Cycle Speech:     4.6309
  Loss_Identity Vocal:     14.8279
  Loss_Identity Speech:     16.1248
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0886
  Grad Norm GS:    0.0994
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 28/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 28 Metrics:
  Loss_DV:         0.6050
  Loss_DS:         0.5898
  Loss_adv_vocal:     0.1560
  Loss_adv_speech:     0.1536
  Loss_Cycle Vocal:     4.1776
  Loss_Cycle Speech:     4.3523
  Loss_Identity Vocal:     15.0676
  Loss_Identity Speech:     16.1501
  Grad Norm DV:    1.4555
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0657
  Grad Norm GS:    0.0744
  num_DV_updates:    2.0000
  num_DS_updates:    0.0000

=== Epoch 29/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 29 Metrics:
  Loss_DV:         0.7953
  Loss_DS:         0.5922
  Loss_adv_vocal:     0.0695
  Loss_adv_speech:     0.1530
  Loss_Cycle Vocal:     3.9934
  Loss_Cycle Speech:     4.1876
  Loss_Identity Vocal:     15.0940
  Loss_Identity Speech:     16.2680
  Grad Norm DV:    0.3756
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0577
  Grad Norm GS:    0.0654
  num_DV_updates:    16.0000
  num_DS_updates:    0.0000

=== Epoch 30/50 ===


Training Batches: 65it [01:03,  1.02it/s]


Epoch 30 Metrics:
  Loss_DV:         0.6454
  Loss_DS:         0.5914
  Loss_adv_vocal:     0.1458
  Loss_adv_speech:     0.1528
  Loss_Cycle Vocal:     3.9604
  Loss_Cycle Speech:     4.1514
  Loss_Identity Vocal:     15.2302
  Loss_Identity Speech:     16.4775
  Grad Norm DV:    0.1641
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0705
  Grad Norm GS:    0.0856
  num_DV_updates:    13.0000
  num_DS_updates:    0.0000

=== Epoch 31/50 ===


Training Batches: 65it [01:04,  1.00it/s]


Epoch 31 Metrics:
  Loss_DV:         0.6830
  Loss_DS:         0.5924
  Loss_adv_vocal:     0.1044
  Loss_adv_speech:     0.1522
  Loss_Cycle Vocal:     3.9661
  Loss_Cycle Speech:     4.1412
  Loss_Identity Vocal:     15.3255
  Loss_Identity Speech:     16.5839
  Grad Norm DV:    0.2015
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0661
  Grad Norm GS:    0.0790
  num_DV_updates:    14.0000
  num_DS_updates:    0.0000

=== Epoch 32/50 ===


Training Batches: 65it [01:05,  1.01s/it]


Epoch 32 Metrics:
  Loss_DV:         0.6527
  Loss_DS:         0.5921
  Loss_adv_vocal:     0.1130
  Loss_adv_speech:     0.1522
  Loss_Cycle Vocal:     3.9507
  Loss_Cycle Speech:     4.1363
  Loss_Identity Vocal:     15.2732
  Loss_Identity Speech:     16.4658
  Grad Norm DV:    0.1571
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0671
  Grad Norm GS:    0.0753
  num_DV_updates:    14.0000
  num_DS_updates:    0.0000

=== Epoch 33/50 ===


Training Batches: 65it [01:05,  1.00s/it]


Epoch 33 Metrics:
  Loss_DV:         0.6412
  Loss_DS:         0.5931
  Loss_adv_vocal:     0.1201
  Loss_adv_speech:     0.1519
  Loss_Cycle Vocal:     3.9302
  Loss_Cycle Speech:     4.0832
  Loss_Identity Vocal:     15.2730
  Loss_Identity Speech:     16.4065
  Grad Norm DV:    0.4005
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0663
  Grad Norm GS:    0.0765
  num_DV_updates:    8.0000
  num_DS_updates:    0.0000

=== Epoch 34/50 ===


Training Batches: 65it [01:04,  1.01it/s]


Epoch 34 Metrics:
  Loss_DV:         0.6631
  Loss_DS:         0.5919
  Loss_adv_vocal:     0.1058
  Loss_adv_speech:     0.1519
  Loss_Cycle Vocal:     3.9058
  Loss_Cycle Speech:     4.0346
  Loss_Identity Vocal:     15.3053
  Loss_Identity Speech:     16.4001
  Grad Norm DV:    0.1557
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0646
  Grad Norm GS:    0.0753
  num_DV_updates:    14.0000
  num_DS_updates:    0.0000

=== Epoch 35/50 ===


Training Batches: 65it [01:02,  1.03it/s]


Epoch 35 Metrics:
  Loss_DV:         0.5845
  Loss_DS:         0.5927
  Loss_adv_vocal:     0.1610
  Loss_adv_speech:     0.1521
  Loss_Cycle Vocal:     3.8472
  Loss_Cycle Speech:     3.9885
  Loss_Identity Vocal:     15.3635
  Loss_Identity Speech:     16.4611
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0632
  Grad Norm GS:    0.0734
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 36/50 ===


Training Batches: 65it [01:03,  1.03it/s]


Epoch 36 Metrics:
  Loss_DV:         0.5839
  Loss_DS:         0.5932
  Loss_adv_vocal:     0.1605
  Loss_adv_speech:     0.1519
  Loss_Cycle Vocal:     3.7870
  Loss_Cycle Speech:     3.9021
  Loss_Identity Vocal:     15.4856
  Loss_Identity Speech:     16.5651
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0615
  Grad Norm GS:    0.0687
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 37/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 37 Metrics:
  Loss_DV:         0.5850
  Loss_DS:         0.5934
  Loss_adv_vocal:     0.1606
  Loss_adv_speech:     0.1521
  Loss_Cycle Vocal:     3.7468
  Loss_Cycle Speech:     3.8545
  Loss_Identity Vocal:     15.4837
  Loss_Identity Speech:     16.5825
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0678
  Grad Norm GS:    0.0781
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 38/50 ===


Training Batches: 65it [01:00,  1.07it/s]


Epoch 38 Metrics:
  Loss_DV:         0.5858
  Loss_DS:         0.5941
  Loss_adv_vocal:     0.1603
  Loss_adv_speech:     0.1517
  Loss_Cycle Vocal:     3.9553
  Loss_Cycle Speech:     4.1058
  Loss_Identity Vocal:     15.5530
  Loss_Identity Speech:     16.5472
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0821
  Grad Norm GS:    0.0870
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 39/50 ===


Training Batches: 65it [01:00,  1.07it/s]


Epoch 39 Metrics:
  Loss_DV:         0.5856
  Loss_DS:         0.5921
  Loss_adv_vocal:     0.1605
  Loss_adv_speech:     0.1522
  Loss_Cycle Vocal:     3.8160
  Loss_Cycle Speech:     3.9516
  Loss_Identity Vocal:     15.3300
  Loss_Identity Speech:     16.1929
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0669
  Grad Norm GS:    0.0723
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 40/50 ===


Training Batches: 65it [01:02,  1.05it/s]


Epoch 40 Metrics:
  Loss_DV:         0.5850
  Loss_DS:         0.5936
  Loss_adv_vocal:     0.1603
  Loss_adv_speech:     0.1515
  Loss_Cycle Vocal:     3.7362
  Loss_Cycle Speech:     3.8449
  Loss_Identity Vocal:     15.2846
  Loss_Identity Speech:     16.2293
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0656
  Grad Norm GS:    0.0778
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 41/50 ===


Training Batches: 65it [01:03,  1.03it/s]


Epoch 41 Metrics:
  Loss_DV:         0.5858
  Loss_DS:         0.5933
  Loss_adv_vocal:     0.1601
  Loss_adv_speech:     0.1513
  Loss_Cycle Vocal:     3.6524
  Loss_Cycle Speech:     3.7824
  Loss_Identity Vocal:     15.3166
  Loss_Identity Speech:     16.3056
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0612
  Grad Norm GS:    0.0721
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 42/50 ===


Training Batches: 65it [01:03,  1.03it/s]


Epoch 42 Metrics:
  Loss_DV:         0.5857
  Loss_DS:         0.5947
  Loss_adv_vocal:     0.1600
  Loss_adv_speech:     0.1515
  Loss_Cycle Vocal:     3.6192
  Loss_Cycle Speech:     3.7523
  Loss_Identity Vocal:     15.3814
  Loss_Identity Speech:     16.4027
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0567
  Grad Norm GS:    0.0671
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 43/50 ===


Training Batches: 65it [01:02,  1.03it/s]


Epoch 43 Metrics:
  Loss_DV:         0.5869
  Loss_DS:         0.5962
  Loss_adv_vocal:     0.1596
  Loss_adv_speech:     0.1511
  Loss_Cycle Vocal:     3.5411
  Loss_Cycle Speech:     3.6782
  Loss_Identity Vocal:     15.4344
  Loss_Identity Speech:     16.4820
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0527
  Grad Norm GS:    0.0610
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 44/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 44 Metrics:
  Loss_DV:         0.5868
  Loss_DS:         0.5968
  Loss_adv_vocal:     0.1594
  Loss_adv_speech:     0.1497
  Loss_Cycle Vocal:     3.5193
  Loss_Cycle Speech:     3.6635
  Loss_Identity Vocal:     15.5225
  Loss_Identity Speech:     16.5929
  Grad Norm DV:    0.0000
  Grad Norm DS:    3.2303
  Grad Norm GV:    0.0629
  Grad Norm GS:    0.0713
  num_DV_updates:    0.0000
  num_DS_updates:    1.0000

=== Epoch 45/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 45 Metrics:
  Loss_DV:         0.5869
  Loss_DS:         0.6359
  Loss_adv_vocal:     0.1595
  Loss_adv_speech:     0.1241
  Loss_Cycle Vocal:     3.4375
  Loss_Cycle Speech:     3.5955
  Loss_Identity Vocal:     15.5891
  Loss_Identity Speech:     16.6670
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.5599
  Grad Norm GV:    0.0623
  Grad Norm GS:    0.0697
  num_DV_updates:    0.0000
  num_DS_updates:    7.0000

=== Epoch 46/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 46 Metrics:
  Loss_DV:         0.5877
  Loss_DS:         0.7362
  Loss_adv_vocal:     0.1591
  Loss_adv_speech:     0.0820
  Loss_Cycle Vocal:     3.4071
  Loss_Cycle Speech:     3.5757
  Loss_Identity Vocal:     15.6448
  Loss_Identity Speech:     16.6896
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.2503
  Grad Norm GV:    0.0596
  Grad Norm GS:    0.0658
  num_DV_updates:    0.0000
  num_DS_updates:    17.0000

=== Epoch 47/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 47 Metrics:
  Loss_DV:         0.5878
  Loss_DS:         0.6237
  Loss_adv_vocal:     0.1586
  Loss_adv_speech:     0.1330
  Loss_Cycle Vocal:     3.4112
  Loss_Cycle Speech:     3.5839
  Loss_Identity Vocal:     15.7098
  Loss_Identity Speech:     16.7080
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.2128
  Grad Norm GV:    0.0651
  Grad Norm GS:    0.0744
  num_DV_updates:    0.0000
  num_DS_updates:    12.0000

=== Epoch 48/50 ===


Training Batches: 65it [01:03,  1.03it/s]


Epoch 48 Metrics:
  Loss_DV:         0.5878
  Loss_DS:         0.6248
  Loss_adv_vocal:     0.1586
  Loss_adv_speech:     0.1228
  Loss_Cycle Vocal:     3.3621
  Loss_Cycle Speech:     3.5565
  Loss_Identity Vocal:     15.7191
  Loss_Identity Speech:     16.7114
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.2977
  Grad Norm GV:    0.0634
  Grad Norm GS:    0.0707
  num_DV_updates:    0.0000
  num_DS_updates:    10.0000

=== Epoch 49/50 ===


Training Batches: 65it [01:02,  1.04it/s]


Epoch 49 Metrics:
  Loss_DV:         0.5879
  Loss_DS:         0.6375
  Loss_adv_vocal:     0.1589
  Loss_adv_speech:     0.1125
  Loss_Cycle Vocal:     3.3509
  Loss_Cycle Speech:     3.5582
  Loss_Identity Vocal:     15.7626
  Loss_Identity Speech:     16.7232
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.2291
  Grad Norm GV:    0.0630
  Grad Norm GS:    0.0748
  num_DV_updates:    0.0000
  num_DS_updates:    13.0000

=== Epoch 50/50 ===


Training Batches: 65it [01:03,  1.03it/s]

Epoch 50 Metrics:
  Loss_DV:         0.5885
  Loss_DS:         0.5934
  Loss_adv_vocal:     0.1581
  Loss_adv_speech:     0.1359
  Loss_Cycle Vocal:     3.3135
  Loss_Cycle Speech:     3.5201
  Loss_Identity Vocal:     15.7926
  Loss_Identity Speech:     16.7192
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0584
  Grad Norm GS:    0.0659
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000





In [20]:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "runs/" + "cycleGAN_experiment_" + now
audio_files = cycle_train(
    generator_vocal,
    generator_speech,
    discriminator_vocal,
    discriminator_speech,
    optimizer_DV,
    optimizer_DS,
    optimizer_GV,
    optimizer_GS,
    acc_voc_loader,
    speech_loader,
    l1_loss,
    mse_loss,
    train_parameters["lambda_l1"],
    train_parameters["lambda_cycle"],
    train_parameters["lambda_identity"],
    bce_loss,
    device,
    num_epochs          = train_parameters["num_epochs"],
    virtual_batch_size  = train_parameters["virtual_batch_size"],
    log_dir             = train_parameters["log_dir"],
    clip_length = train_parameters["clip_length"],
    input_size_generators = train_parameters["input_size_generators"],
    batch_size = train_parameters["batch_size"]
)



=== Epoch 1/50 ===


Training Batches: 65it [01:03,  1.02it/s]


Epoch 1 Metrics:
  Loss_DV:         0.5887
  Loss_DS:         0.5927
  Loss_adv_vocal:     0.1584
  Loss_adv_speech:     0.1357
  Loss_Cycle Vocal:     3.2867
  Loss_Cycle Speech:     3.5220
  Loss_Identity Vocal:     15.8233
  Loss_Identity Speech:     16.7654
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0600
  Grad Norm GS:    0.0673
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 2/50 ===


Training Batches: 65it [01:03,  1.03it/s]


Epoch 2 Metrics:
  Loss_DV:         0.5887
  Loss_DS:         0.5928
  Loss_adv_vocal:     0.1579
  Loss_adv_speech:     0.1355
  Loss_Cycle Vocal:     3.2379
  Loss_Cycle Speech:     3.4977
  Loss_Identity Vocal:     15.8232
  Loss_Identity Speech:     16.7900
  Grad Norm DV:    0.0000
  Grad Norm DS:    0.0000
  Grad Norm GV:    0.0604
  Grad Norm GS:    0.0716
  num_DV_updates:    0.0000
  num_DS_updates:    0.0000

=== Epoch 3/50 ===


Training Batches: 48it [00:46,  1.03it/s]


KeyboardInterrupt: 

In [None]:
# now = datetime.now().strftime("%Y%m%d-%H%M%S")
# log_dir = "runs/" + "cycleGAN_experiment_" + now
# audio_files = cycle_train(
#     generator_vocal,
#     generator_speech,
#     discriminator_vocal,
#     discriminator_speech,
#     optimizer_DV,
#     optimizer_DS,
#     optimizer_GV,
#     optimizer_GS,
#     accompaniment_loader,
#     vocal_loader,
#     speech_loader,
#     l1_loss,
#     mse_loss,
#     train_parameters["lambda_l1"],
#     train_parameters["lambda_cycle"],
#     bce_loss,
#     device,
#     num_epochs          = train_parameters["num_epochs"],
#     virtual_batch_size  = train_parameters["virtual_batch_size"],
#     log_dir             = log_dir,
#     clip_length = train_parameters["clip_length"],
#     input_size_generators = train_parameters["input_size_generators"],
#     batch_size = train_parameters["batch_size"]
# )


In [None]:
# state_time = "20250421-033033"
# model_dir  = "models"

# # ----  load the raw weight dictionaries ----
# disc_state = torch.load(f"{model_dir}/discriminator_state_dict_{state_time}.pt", map_location=device)
# gen_state  = torch.load(f"{model_dir}/generator_state_dict_{state_time}.pt",         map_location=device)
# gen2_state = torch.load(f"{model_dir}/generator_2_state_dict_{state_time}.pt",       map_location=device)

# # ----   rebuild the model objects ----
# discriminator = TsaiMiniRocketDiscriminator().to(device)
# generator     = Waveunet(**model_config_gen_vocal).to(device)
# generator_2   = Waveunet(**model_config_gen_).to(device)

# # ---- load the weights into those models ----
# discriminator.load_state_dict(disc_state)
# generator.load_state_dict(gen_state)
# generator_2.load_state_dict(gen2_state)


In [None]:
# # display the converted audios
# print(audio_files)
# for audio in audio_files:
#     print(audio)
#     display_audio(audio)

In [23]:
# gc.collect()
# torch.cuda.empty_cache()

# # Start training.
# train(
#     generator,
#     generator_2,
#     discriminator,
#     optimizer_D,
#     optimizer_G,
#     optimizer_G2,
#     accompaniment_loader,
#     vocal_loader,
#     speech_loader,
#     l1_loss,
#     train_parameters["lambda_l1"],
#     train_parameters["lambda_cycle"],
#     adversarial_loss,
#     device,
#     num_epochs          = 50, #train_parameters["num_epochs"],
#     virtual_batch_size  = train_parameters["virtual_batch_size"],
#     log_dir             = train_parameters["log_dir"],
# )
print(now)

20250424-031428


## Save the models

In [21]:
# assert False
path = "models/"
torch.save(generator_vocal.state_dict(), path + "generator_state_vocal_dict_" + now + ".pt")
torch.save(generator_speech.state_dict(), path + "generator_state_speech_dict_" + now + ".pt")
torch.save(discriminator_speech.state_dict(), path + "discriminator_speech_state_dict_" + now + ".pt")
torch.save(discriminator_vocal.state_dict(), path + "discriminator_vocal_state_dict_" + now + ".pt")

# ------------- package everything to save -------------
export_dict = {
    "train_parameters": train_parameters,
    "model_config_gen_speech": model_config_gen_speech,      # Wave‑U‑Net (speech+accomp → vocal)
    "model_config_gen_vocal": model_config_gen_vocal,    # Wave‑U‑Net (vocal → speech)
}
import json

# (optional) ensure JSON‑serialisable: convert tuples → lists
def _convert(obj):
    if isinstance(obj, tuple):
        return list(obj)
    if isinstance(obj, dict):
        return {k: _convert(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_convert(x) for x in obj]
    return obj

export_dict = _convert(export_dict)

with open(f"{path}/training_record_{now}.json", "w") as fp:
    json.dump(export_dict, fp, indent=2)



In [None]:
# now = datetime.now().strftime("%Y%m%d-%H%M%S")
# log_dir = "runs/" + "cycleGAN_experiment_" + now

# gc.collect()
# torch.cuda.empty_cache()

# # Start training.
# train(
#     generator_vocal,
#     generator_speech,
#     discriminator,
#     optimizer_D,
#     optimizer_G,
#     optimizer_G2,
#     accompaniment_loader,
#     vocal_loader,
#     speech_loader,
#     l1_loss,
#     train_parameters["lambda_l1"],
#     train_parameters["lambda_cycle"],
#     adversarial_loss,
#     device,
#     num_epochs          = 50, #train_parameters["num_epochs"],
#     virtual_batch_size  = train_parameters["virtual_batch_size"],
#     log_dir             = train_parameters["log_dir"],
# )