Here we train our first version of the GAN.



## Initialize Wave-U-Net

We start by loading the necessary packages

Wave-U-Net is named ``generator``

In [2]:
# Import same packages as the train script in Wave-U-Net-Pytorch
import argparse
import os
import time
from functools import partial
from datetime import datetime



import torch
import pickle
import numpy as np

import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from torch.nn import L1Loss
from tqdm import tqdm
from torchsummary import summary
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
# install torchaudio if not already installed
# ! pip install torchaudio
import torchaudio

import matplotlib.pyplot as plt
from typing import Tuple, List, Dict, Optional

# !pip install sktime
from sktime.transformations.panel.rocket import MiniRocketMultivariate
# 
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/My Drive/git_projects/spring_2025_dl_audio_project


# add a path to Wave-U-Net
import sys
sys.path.append('Wave-U-Net-Pytorch')
sys.path.append("workspace/hdd_project_data")

import model.utils as model_utils
import utils
from model.waveunet import Waveunet

# Check to see if we have a GPU available
print("GPU:", torch.cuda.is_available())

GPU: True


In [3]:
# I run these commands in the terminal that you get when you pay for Colab.

# %pip install musdb  # has some helpful data structures, also installs ffmpeg and stempeg
# %pip uninstall stempeg    # musdb installs the wrong version of stempeg'

We define the parameters of the model.

In [4]:
model_config = {
    "num_inputs": 256,               # 128 mel bins per spectrogram, but we have to spectrograms
    "num_outputs": 128,              # Output also has 128 mel bins
    "num_channels": [512*2, 512*4, 512*8],    # Example channel progression
    "instruments": ["vocal"],        # Only output vocal, so no music branch
    "kernel_size": 3,                # Must be odd
    "target_output_size": 256,       # Desired output time frames (post-processing may crop)
    "conv_type": "normal",           # Set to "normal" to meet assertion requirements
    "res": "fixed",                  # Use fixed resampling
    "separate": False,                # Separate branch for vocal
    "depth": 1,                      # Number of conv layers per block
    "strides": 2                   # Down/up-sampling stride
}

Load the model, check how much GPU memory it will use during training, and print a summary of the model.

In [5]:
# Ensure that you have a CUDA-enabled device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate and move the model to GPU
generator = Waveunet(**model_config).to(device)

# # Set up a dummy optimizer and loss function
# optimizer = Adam(generator.parameters(), lr=1e-3)
# loss_fn = L1Loss()

# # Define a dummy batch size
# batch_size = 256

# # Create a dummy input tensor with the required shape
# # model.num_inputs corresponds to the number of channels (256 in your config)
# # model.input_size is the computed length (353, for instance)
# dummy_input = torch.randn(batch_size, generator.num_inputs, generator.input_size, device=device)

# # Create a dummy target tensor with the shape that your model outputs.
# # For a single output branch (vocal), the output shape should be:
# # (batch_size, num_outputs, model.output_size)
# # model.num_outputs is 128 and model.output_size is computed (257 in your case)
# dummy_target = torch.randn(batch_size, generator.num_outputs, generator.output_size, device=device)

# # Reset GPU peak memory stats
# torch.cuda.reset_peak_memory_stats(device)

# # Run a single forward and backward pass
# optimizer.zero_grad()
# # If separate is False, the model returns a dictionary; pass the correct key.
# output = generator(dummy_input)["vocal"]
# loss = loss_fn(output, dummy_target)
# loss.backward()
# optimizer.step()

# # Retrieve GPU memory stats
# peak_memory = torch.cuda.max_memory_allocated(device)
# current_memory = torch.cuda.memory_allocated(device)
# print("Peak GPU memory allocated (bytes):", peak_memory)
# print("Current GPU memory allocated (bytes):", current_memory)

# # Optionally, print a detailed memory summary
# print(torch.cuda.memory_summary(device=device))


# summary(generator, input_size=(generator.num_inputs,  generator.input_size))


Using valid convolutions with 289 inputs and 257 outputs


Optionally compile the model to potentially decrease training time.

If we compile the model, to save it after training, we have to uncompile it using the following code:

```python
orig_generator = generator._orig_mod
path = ""
torch.save(orig_generator.state_dict(), path + "generator_state_dict.pt")


## Initialize miniRocket
We start by loading the necessary packages

### CPU Core Allocation for MiniRocketMultivariate

- The implementation of `MiniRocketMultivariate` runs on the **CPU**.
- We need to decide how many cores to allocate for it.
- Some cores will be used by MiniRocket itself, while others are needed for data preparation (e.g., generating spectrograms).
- This allocation likely needs to be **tuned for optimal performance**.
- As a starting point, we detect the number of available cores and split them evenly.
- Note: We avoid using *all* available cores to leave some resources for the operating system and other background processes.


In [6]:

# dataloader_n_jobs = num_cores - minirocket_n_jobs - 1

Create the MiniRocket model

In [7]:
import pandas as pd
from time import time

# MiniRocket Discriminator using tsai library
class TsaiMiniRocketDiscriminator(nn.Module):
    def __init__(
        self,
        freq_bins=256,
        time_frames=256,
        num_kernels=10000,  # number of convolutional kernels
        hidden_dim=1024,    # Increased to handle larger feature dimension
        output_dim=1
    ):
        super(TsaiMiniRocketDiscriminator, self).__init__()

        # This is the mini rocket transformer which extracts features
        self.rocket = MiniRocketMultivariate(num_kernels=num_kernels, n_jobs=minirocket_n_jobs)
        # tsai's miniRocketClassifier is implemented with MiniRocketMultivariate as well
        self.fitted = False   # fit before training
        self.freq_bins = freq_bins
        self.time_frames = time_frames
        self.num_kernels = num_kernels

        # For 2D data handling - process each sample with proper dimensions
        self.example_input = np.zeros((1, freq_bins, time_frames))

        self.feature_dim = num_kernels  # For vocals + accompaniment

        # Example feature reducing layers
        self.classifier = nn.Sequential(
            # First reduce the massive dimension to something manageable
            nn.LazyLinear( hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Second hidden layer
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Final classification layer
            nn.Linear(hidden_dim // 2, output_dim),
            nn.Sigmoid()
        )

    def fit_rocket(self, spectrograms):
        """
            Fit MiniRocket with just one piece of vocal training data (not the entire training dataset)
        """
        if not self.fitted:
            try:
                # Reshape for MiniRocket - it expects (n_instances, n_dimensions, series_length)
                # flatten the freq_bins dimension to create a multivariate time series
                batch_size = spectrograms.shape[0]

                # Convert first to numpy for sktime processing
                sample_data = spectrograms.detach().cpu().numpy()
                # print(sample_data.shape)
                # Reshape to sktime's expected format - reduce to single sample for fitting
                # sample_data = sample_data[:, 0]  # Take one sample, remove channel dim

                # Fit on this sample
                self.rocket.fit(sample_data)
                self.fitted = True

                # Test transform to get feature dimension
                test_transform = self.rocket.transform(sample_data)
                self.feature_dim = test_transform.shape[1]

                print(f"MiniRocket fitted. Feature dimension: {self.feature_dim}")

            except Exception as e:
                print(f"Error fitting MiniRocket: {e}")
                # Use a fallback if fitting fails
                self.fitted = True  # Mark as fitted to avoid repeated attempts

    def extract_features(self, spectrogram):
        """Extract MiniRocket features from a spectrogram"""
        try:
            # Ensure rocket is fitted
            if not self.fitted:
                self.fit_rocket(spectrogram)

            # Convert to numpy for sktime
            spec_np = spectrogram.detach().cpu().numpy()

            # Remove channel dimension expected by sktime
            # print(spec_np.shape)
            # spec_np = spec_np[:, 0]  # [batch_size, freq_bins, time_frames]
            # print(spec_np.shape)

            # This step extracts features using the convolutional kernels, numbers specified by num_kernels
            # print("1")
            features = self.rocket.transform(spec_np)
            # print("2")
            # Convert back to torch tensor
            # print("features:", features.shape)
            # print(features.head())
            features_tensor = torch.tensor(features.values).to(spectrogram.device)
            # print("features:", features.shape)
            # print("3")
            return features_tensor

        except Exception as e:
            print(f"Error in feature extraction: {e}")
            # Return zeros as fallback
            return torch.zeros((spectrogram.shape[0], self.num_kernels),
                              device=spectrogram.device)

    def forward(self, vocals, accompaniment):
        """
        Forward pass of the discriminator

        Args:
            vocals: Spectrograms of shape [batch_size, channels, freq_bins, time_frames]
            accompaniment: Spectrograms of shape [batch_size, channels, freq_bins, time_frames]
        """
        # Extract features from both spectrograms
        # start_time = time()
        vocal_features = self.extract_features(vocals)
        accomp_features = self.extract_features(accompaniment)
        # print("extract:", time()-start_time)
        # Concatenate features (conditional GAN)
        combined_features = torch.cat([vocal_features, accomp_features], dim=1)
        # print(combined_features.size())

        # Classify as real/fake
        validity = self.classifier(combined_features)

        return validity



In [8]:
# del discriminator
# discriminator = TsaiMiniRocketDiscriminator()
# We probably do not need to compile the model

# Import Data into Session

First, we run the code that defines the custom Dataset objects. The Datasets were compiled previously and saved in .pt files. In the next cell, we load those Dataset objects.

In [9]:
# %cd /content/drive/My Drive/git_projects/
# sys.path.append('/workspace/hdd_project_data/stempeg')
# import stempeg
import musdb
import torch
import librosa
import numpy as np
from torch.utils.data import Dataset

class MusdbDataset(Dataset):

  def __init__(self, musDB, window_size = 256, step_size = 128):
    self.mel_specs = torch.zeros(1, 2, 128, window_size)
    self.sample_rates = torch.tensor([0])

    num_songs = 0

    for track in musDB:
      stems, rate = track.stems, track.rate

      num_songs += 1

      # separate the vocal from other instruments and conver to mono signal
      audio_novocal = librosa.to_mono(np.transpose(stems[1] + stems[2] + stems[3]))
      audio_vocal = librosa.to_mono(np.transpose(stems[4]))

      # compute log mel spectrogram and convert to pytorch tensor
      logmelspec_novocal = torch.from_numpy(self._mel_spectrogram(audio_novocal, rate))
      logmelspec_vocal = torch.from_numpy(self._mel_spectrogram(audio_vocal, rate))

      start_ndx = 0

      for step in range(window_size // step_size):
        cropped_logmelspec_novocal = logmelspec_novocal[:, start_ndx:]
        cropped_logmelspec_vocal = logmelspec_vocal[:, start_ndx:]
        num_slices = cropped_logmelspec_novocal.shape[1] // window_size

        # chop off the last bit so that number of stft steps is a multiple of window_size
        cropped_logmelspec_novocal = cropped_logmelspec_novocal[: , 0:num_slices*window_size]
        cropped_logmelspec_vocal = cropped_logmelspec_vocal[:, 0:num_slices*window_size]

        # reshape tensors into chunks of size 128x(window_size)
        # first dimension is number of chunks
        cropped_logmelspec_novocal = torch.transpose(torch.reshape(cropped_logmelspec_novocal, (128, num_slices, window_size)), 0, 1)
        cropped_logmelspec_vocal = torch.transpose(torch.reshape(cropped_logmelspec_vocal, (128, num_slices, window_size)), 0, 1)

        # unsqueeze and concatenate these tensors. Then concatenate to the big tensor
        logmels = torch.cat((cropped_logmelspec_novocal.unsqueeze(1), cropped_logmelspec_vocal.unsqueeze(1)), 1)
        self.mel_specs = torch.cat((self.mel_specs, logmels), 0)
        self.sample_rates = torch.cat((self.sample_rates, torch.full((num_slices,), rate)), 0)

        if num_songs % 5 == 0:
          print(str(num_songs) + " songs processed; produced " + str(self.mel_specs.shape[0]) + " spectrograms")

    # remove the all zeros slice that we initialized with
    self.mel_specs = self.mel_specs[1: , : , : , :]
    self.sample_rates = self.sample_rates[1:]

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    # returns tuple (mel spectrogram of accompaniment, mel spectrogram of vocal, rate)
    return self.mel_specs[ndx, 0], self.mel_specs[ndx, 1], self.sample_rates[ndx]

  def _mel_spectrogram(self, audio, rate):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

  def cat(self, other_ds):
    self.mel_specs = torch.cat((self.mel_specs, other_ds.mel_specs), 0)
    self.sample_rates = torch.cat((self.sample_rates, other_ds.sample_rates), 0)

import torch
import librosa
import numpy as np
from torch.utils.data import Dataset

class SingingDataset(Dataset):

  def __init__(self, musDB, window_size = 256, step_size = 128):
    self.mel_specs = torch.zeros(1, 128, window_size)
    self.sample_rates = torch.tensor([0])

    num_songs = 0

    for track in musDB:
      stems, rate = track.stems, track.rate

      num_songs += 1

      # load the vocal
      vocal = librosa.to_mono(np.transpose(stems[4]))

      # compute log mel spectrogram and convert to pytorch tensor
      mel_spec = torch.from_numpy(self._mel_spectrogram(vocal, rate))

      start_ndx = 0
      for step in range(window_size // step_size):
        cropped_mel_spec = mel_spec[:, start_ndx:]
        num_slices = cropped_mel_spec.shape[1] // window_size

        # chop off the last bit so that number of stft steps is a multiple of window_size
        cropped_mel_spec = cropped_mel_spec[:, 0:num_slices*window_size]

        # reshape tensors into chunks of size 128x(window_size)
        # first dimension is number of chunks
        cropped_mel_spec = torch.transpose(torch.reshape(cropped_mel_spec, (128, num_slices, window_size)), 0, 1)

        # concatenate to the big tensor
        self.mel_specs = torch.cat((self.mel_specs, cropped_mel_spec), 0)
        self.sample_rates = torch.cat((self.sample_rates, torch.full((num_slices,), rate)), 0)


    if num_songs % 5 == 0:
        print(str(num_songs) + " songs processed; produced " + str(self.mel_specs.shape[0]) + " spectrograms")

    # remove the all zeros slice that we initialized with
    self.mel_specs = self.mel_specs[1: , : , :]
    self.sample_rates = self.sample_rates[1:]

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    # returns tuple (mel spectrogram of accompaniment, mel spectrogram of vocal, rate)
    return self.mel_specs[ndx], self.sample_rates[ndx]

  def _mel_spectrogram(self, audio, rate):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

import torch
import librosa
import numpy as np
import os
from torch.utils.data import Dataset

class LibriSpeechDataset(Dataset):

    def __init__(self, path, window_size = 256, step_size = 128, num_specs = 7647*2):
        self.mel_specs = self.mel_specs = torch.zeros(1, 128, window_size)
        self.sample_rates = torch.tensor([0])

        num_files_opened = 0

        for speaker_dir in os.listdir(path):
            speaker_path = path + "/" + speaker_dir
            for chapter_dir in os.listdir(speaker_path):
                chapter_path = speaker_path + "/" + chapter_dir
                for file in os.listdir(chapter_path):
                    # checks file extension and stops when we hit desired number of spectrograms (num_specs)
                    if file.endswith('.flac') and self.mel_specs.shape[0] - 1 < num_specs:
                        # get audio file and convert to log mel spectrogram
                        speech, rate = librosa.load(chapter_path + "/" + file, sr = 44100)
                        mel_spec = torch.from_numpy(self._mel_spectrogram(speech, rate))
                        start_ndx = 0

                        num_files_opened += 1

                        for step in range(window_size // step_size):
                            cropped_mel_spec = mel_spec[:, start_ndx:]

                            # Saves the total number of 128 x (window_size) spectrograms
                            num_slices = cropped_mel_spec.shape[1] // window_size

                            # chop off the last bit so that number of stft steps is a multiple of window_size
                            cropped_mel_spec = cropped_mel_spec[ : , 0 : num_slices*window_size]

                            # reshape the tensor to have many spectrograms of size 128 x (steps)
                            cropped_mel_spec = torch.transpose(torch.reshape(cropped_mel_spec, (128, num_slices, window_size)), 0, 1)

                            # concatenate tensor to the full tensor in the Dataset object
                            self.mel_specs = torch.cat((self.mel_specs, cropped_mel_spec), 0)
                            self.sample_rates = torch.cat((self.sample_rates, torch.full((num_slices,), rate)), 0)

                            # increment start_ndx
                            start_ndx += step_size


                        if num_files_opened % 50 == 0:
                            print("opened " + str(num_files_opened) + " files and produced " + str(self.mel_specs.shape[0]) + " spectrograms")


        # chop off the zero layer we initialized with
        self.mel_specs = self.mel_specs[1:]
        self.sample_rates = self.sample_rates[1:]

    def __len__(self):
        return self.mel_specs.shape[0]

    def __getitem__(self, ndx):
        return self.mel_specs[ndx], self.sample_rates[ndx]

    def _mel_spectrogram(self, audio, rate):
        # compute the log-mel-spectrogram of the audio at the given sample rate
        return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

In [10]:
path = "/workspace/hdd_project_data/"


# The string below is the path to the saved LibriSpeechDataset in your Drive
librispeechDataset_path = path + "LibriSpeechDataset_withOverlap.pt"


librispeech_dataset = torch.load(librispeechDataset_path, weights_only=False)


# The string below is the path to the saved MusdbDataset in your Drive
musdbDataset_path = path + "musdb_noOverlap_test.pt"
musdb_dataset = torch.load(musdbDataset_path, weights_only=False)

In [11]:
# This fixes the problem with the sample rates
musdb_dataset.sample_rates = torch.full((len(musdb_dataset),), 44100)
librispeech_dataset.sample_rates = torch.full((len(musdb_dataset),), 44100)

# Because of the way the librispeech dataset was constructed, it is slightly longer
# than the musbd dataset. Crop the librispeech dataset with these lines
librispeech_dataset.mel_specs = librispeech_dataset.mel_specs[0:len(musdb_dataset)]
librispeech_dataset.sample_rates = librispeech_dataset.sample_rates[0:len(musdb_dataset)]

### Explore these datasets

In [12]:
# --- Explore the Datasets ---
print("=== MusDB Dataset Exploration ===")
print("Length:", len(musdb_dataset))
print("mel_specs shape:", musdb_dataset.mel_specs.shape)
print("sample_rates shape:", musdb_dataset.sample_rates.shape)
print()
accompaniment, vocal, sample_rate = musdb_dataset[0]
print("Sample 0 - Accompaniment shape:", accompaniment.size())
print("Sample 0 - Vocal shape:", vocal.size())
print("Sample 0 - Sample rate:", sample_rate)
print()

print("=== LibriSpeech Dataset Exploration ===")
print("Length:", len(librispeech_dataset))
print("mel_specs shape:", librispeech_dataset.mel_specs.shape)
print("sample_rates shape:", librispeech_dataset.sample_rates.shape)
print()
speech, sample_rate = librispeech_dataset[0]
print("Sample 0 - Speech shape:", speech.size())
print("Sample 0 - Sample rate:", sample_rate)

=== MusDB Dataset Exploration ===
Length: 4167
mel_specs shape: torch.Size([4167, 2, 128, 256])
sample_rates shape: torch.Size([4167])

Sample 0 - Accompaniment shape: torch.Size([128, 256])
Sample 0 - Vocal shape: torch.Size([128, 256])
Sample 0 - Sample rate: tensor(44100)

=== LibriSpeech Dataset Exploration ===
Length: 4167
mel_specs shape: torch.Size([4167, 128, 256])
sample_rates shape: torch.Size([4167])

Sample 0 - Speech shape: torch.Size([128, 256])
Sample 0 - Sample rate: tensor(44100)


## Dataset Helpers Explanation
Why New Dataset Helpers?

We have created new dataset helper classes (i.e., AccompanimentData, VocalData, and SpeechData) so that we can control how the data is padded and later shuffled.

- **Separation of Data:**
We separated the vocal and accompaniment data from the MusDB dataset. In our experiments, we might want to shuffle the speech data independently of the combined music data.

- **Shuffling Considerations:**
For the vocal and accompaniment data, we want to maintain their pairing so that they are shuffled in the same order. In contrast, we want the speech data to be shuffled independently.

- **Future Extensions:**
In the future, we may add another helper class that combines the vocal and accompaniment data to ensure synchronized shuffling in our data loaders.

This modular approach gives us flexibility in handling and preprocessing the data for our GAN training.

In [13]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

class AccompanimentData(Dataset):
    def __init__(self, musdb_dataset, output_length=289):
        self.musdb_dataset = musdb_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.musdb_dataset)

    def __getitem__(self, index):
        accompaniment, _, _ = self.musdb_dataset[index]  # shape: [128, 256]
        current_len = accompaniment.size(-1)             # 256
        delta = self.output_length - current_len         # 289 - 256 = 33

        # If delta is positive, pad. Otherwise, you might want to crop or handle differently.
        if delta > 0:
            # Half the remainder goes to the front
            left_pad_len = (delta // 2) + (delta % 2)  # 17
            right_pad_len = delta // 2                # 16
            accompaniment_pad = F.pad(accompaniment,
                                  (left_pad_len, right_pad_len),
                                  "constant", 0)
        return {"no_pad" : accompaniment, "pad" : accompaniment_pad}


class VocalData(Dataset):
    def __init__(self, musdb_dataset, output_length=289):
        self.musdb_dataset = musdb_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.musdb_dataset)

    def __getitem__(self, index):
        _, vocal, _ = self.musdb_dataset[index]  # shape: [128, 256]
        current_len = vocal.size(-1)
        delta = self.output_length - current_len

        if delta > 0:
            left_pad_len = (delta // 2) + (delta % 2)
            right_pad_len = delta // 2
            vocal_pad = F.pad(vocal, (left_pad_len, right_pad_len), "constant", 0)
        return {"no_pad" : vocal, "pad" : vocal_pad}



class SpeechData(Dataset):
    def __init__(self, librispeech_dataset, output_length=289):
        self.librispeech_dataset = librispeech_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.librispeech_dataset)

    def __getitem__(self, index):
        speech, _ = self.librispeech_dataset[index]
        # If speech has multiple slices, pick the first slice
        if speech.dim() == 3:
            speech = speech[0]  # shape: [128, 256]
        current_len = speech.size(-1)
        delta = self.output_length - current_len

        if delta > 0:
            left_pad_len = (delta // 2) + (delta % 2)
            right_pad_len = delta // 2
            speech_pad = F.pad(speech, (left_pad_len, right_pad_len), "constant", 0)
        return {"no_pad" : speech, "pad" : speech_pad}


In [14]:
# print(AccompanimentData(musdb_dataset)[0])
# print(VocalData(musdb_dataset)[0])
# print(SpeechData(librispeech_dataset)[0])

### DataLoader Explanation
What is a DataLoader and Why Do We Need It?

A DataLoader in PyTorch is a utility that wraps a dataset and provides:

- **Batching:** It divides your dataset into batches so that you can train your models with mini-batch gradient descent.

- **Shuffling:** It shuffles the data at every epoch (if specified) to help reduce overfitting and ensure the model sees a diverse set of examples.

- **Parallel Data Loading:** It can load data in parallel using multiple worker processes, speeding up training.

In our case, we create separate DataLoaders for:

- The accompaniment data (paired with vocals) from the MusDB dataset.

- The vocal data (paired with accompaniment) from the MusDB dataset.

- The speech data from the LibriSpeech dataset.

This lets us shuffle the speech data independently, while keeping the vocal/accompaniment pairs synchronized during training.

In [15]:
# Define batch size
batch_size = 32  # Change as needed

# Create data loaders
accompaniment_loader = DataLoader(
    AccompanimentData(musdb_dataset),
    batch_size=batch_size,
    shuffle=False,
    drop_last=True
)
vocal_loader = DataLoader(
    VocalData(musdb_dataset),
    batch_size=batch_size,
    shuffle=False,
    drop_last=True
)
speech_loader = DataLoader(
    SpeechData(librispeech_dataset),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)

# # Print how many batches each DataLoader contains
# print("Accompaniment loader length:", len(accompaniment_loader))
# print("Vocal loader length:", len(vocal_loader))
# print("Speech loader length:", len(speech_loader))

# # Optionally, fetch and print the shape of the first batch
# accompaniment_batch = next(iter(accompaniment_loader))
# vocal_batch = next(iter(vocal_loader))
# speech_batch = next(iter(speech_loader))
# print(accompaniment_batch["pad"])

# print("Accompaniment first batch shape:", accompaniment_batch.shape)
# print("Vocal first batch shape:", vocal_batch.shape)
# print("Speech first batch shape:", speech_batch.shape)


## Second Generator Model
Here we initialize the second generator model whose purpose is to convert the generated vocals back to normal speach for the cycle GAN. We again use Wave-U-Net, but with a different configuration. The main difference is that we will not input the music along with the vocal track.

## Transform Input to generator_2

The output of the generator model is a (batch_size, 128, 257) tensor. The model expects a tensor of size (batch_size, 128, 289). We need to pad the last dimension with 16 zeros on each size.

In [16]:
def transform_for_gen_2(batch, output_length = 289):
  current_len = batch.size(-1)
  delta = output_length - current_len

  if delta > 0:
      left_pad_len = (delta // 2) + (delta % 2)
      right_pad_len = delta // 2
      batch = F.pad(batch, (left_pad_len, right_pad_len), "constant", 0)
  return batch


## Train the Cycle GAN
The models are ``generator`` and ``discriminator`` and ``generator_2``.


In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# ----- Single Epoch Training Function -----
def train_epoch(
    generator,
    generator_2,
    discriminator,
    optimizer_D,
    optimizer_G,
    optimizer_G2,
    accompaniment_loader,
    vocal_loader,
    speech_loader,
    l1_loss,
    lambda_l1,
    lambda_cycle,
    adversarial_loss,
    device,
    virtual_batch_size=1  # Set > 1 for gradient accumulation across mini-batches.
    
):
    total_loss_D = 0.0
    total_loss_G = 0.0
    total_loss_G_adv = 0.0
    total_loss_G_L1 = 0.0
    total_loss_cycle = 0.0
    num_batches = 0

    # Optionally record gradient norms per batch for diagnosing vanishing gradients.
    grad_norms_D = []
    grad_norms_G = []

    # Wrap the loader in tqdm for progress reporting.
    train_loader = enumerate(zip(zip(accompaniment_loader, vocal_loader), speech_loader))
    
    for batch_number, ((accompaniment, vocal), speech) in tqdm(train_loader, desc="Training Batches"):
        # Move data to device and cast to float32.
        accompaniment_pad = accompaniment["pad"].type(torch.float32).to(device)  # shape: [B, 128, 289]
        speech_pad = speech["pad"].type(torch.float32).to(device)                # shape: [B, 128, 289]

        # Prepare generator input by concatenating speech and accompaniment along the channel dim.
        generator_input = torch.cat([speech_pad, accompaniment_pad], dim=1)  # shape: [B, 256, 289]

        # ---------------------
        # Train the Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        B = accompaniment_pad.size(0)
        real_labels = torch.full((B, 1), 1.0, device=device)
        fake_labels = torch.zeros((B, 1), device=device)

        # Get non-padded data.
        accompaniment_no_pad = accompaniment["no_pad"].type(torch.float32).to(device)
        vocals_no_pad = vocal["no_pad"].type(torch.float32).to(device)

        # Discriminator on real samples.
        pred_real = discriminator(vocals_no_pad, accompaniment_no_pad)
        loss_D_real = adversarial_loss(pred_real, real_labels)

        # Generate fake singing.
        raw_fake_singing = generator(generator_input)["vocal"]
        # Clone the output to ensure safety with subsequent operations.
        fake_singing = raw_fake_singing.clone()
        fake_singing_crop = fake_singing[:, :, :256].clone()

        # Discriminator on fake samples.
        pred_fake = discriminator(fake_singing_crop, accompaniment_no_pad)
        loss_D_fake = adversarial_loss(pred_fake, fake_labels)

        loss_D = 0.5 * (loss_D_real + loss_D_fake)
        loss_D.backward()

        # Record discriminator gradient norms.
        grad_norm = 0.0
        count = 0
        for p in discriminator.parameters():
            if p.grad is not None:
                grad_norm += p.grad.norm().item()
                count += 1
        if count > 0:
            grad_norms_D.append(grad_norm / count)

        optimizer_D.step()

        # ---------------------
        # Train the Generators
        # ---------------------
        # Reset generator gradients if not accumulating.
        if virtual_batch_size == 1:
            optimizer_G.zero_grad()
            optimizer_G2.zero_grad()

        # Generator adversarial loss.
        pred_fake_for_G = discriminator(fake_singing, accompaniment_no_pad)
        
        # Prepare input for second generator (assume transform_for_gen_2 is defined).
        fake_singing_padded = transform_for_gen_2(fake_singing)
        raw_reconstructed_speech = generator_2(fake_singing_padded)["speech"]
        reconstructed_speech = raw_reconstructed_speech[:, :, :256].clone()
        speech_no_pad = speech["no_pad"].type(torch.float32).to(device)

        loss_G_adv = adversarial_loss(pred_fake_for_G, real_labels)
        loss_cycle = l1_loss(reconstructed_speech, speech_no_pad)
        loss_G_L1 = l1_loss(fake_singing_crop, vocals_no_pad)

        loss_G = loss_G_adv + lambda_l1 * loss_G_L1 + lambda_cycle * loss_cycle

        loss_G.backward()

        # Record generator gradient norms.
        grad_norm = 0.0
        count = 0
        for p in generator.parameters():
            if p.grad is not None:
                grad_norm += p.grad.norm().item()
                count += 1
        if count > 0:
            grad_norms_G.append(grad_norm / count)

        # If using virtual batch accumulation.
        if (batch_number + 1) % virtual_batch_size == 0:
            optimizer_G.step()
            optimizer_G2.step()
            optimizer_G.zero_grad()
            optimizer_G2.zero_grad()

        total_loss_D += loss_D.item()
        total_loss_G += loss_G.item()
        total_loss_G_adv += loss_G_adv.item()
        total_loss_G_L1 += loss_G_L1.item()
        total_loss_cycle += loss_cycle.item()
        num_batches += 1

    epoch_metrics = {
        "loss_D": total_loss_D / num_batches,
        "loss_G_total": total_loss_G / num_batches,
        "loss_G_adv": total_loss_G_adv / num_batches,
        "loss_G_L1": total_loss_G_L1 / num_batches,
        "loss_cycle": total_loss_cycle / num_batches,
        "avg_grad_norm_D": sum(grad_norms_D) / len(grad_norms_D) if grad_norms_D else 0.0,
        "avg_grad_norm_G": sum(grad_norms_G) / len(grad_norms_G) if grad_norms_G else 0.0,
    }
    return epoch_metrics

# ----- Multi-Epoch Training Function -----
def train(
    generator,
    generator_2,
    discriminator,
    optimizer_D,
    optimizer_G,
    optimizer_G2,
    accompaniment_loader,
    vocal_loader,
    speech_loader,
    l1_loss,
    lambda_l1,
    lambda_cycle,
    adversarial_loss,
    device,
    num_epochs,
    virtual_batch_size,
    log_dir
):
    writer = SummaryWriter(log_dir=log_dir)
    global_step = 0

    for epoch in range(num_epochs):
        print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")
        epoch_metrics = train_epoch(
            generator,
            generator_2,
            discriminator,
            optimizer_D,
            optimizer_G,
            optimizer_G2,
            accompaniment_loader,
            vocal_loader,
            speech_loader,
            l1_loss,
            lambda_l1,
            lambda_cycle,
            adversarial_loss,
            device,
            virtual_batch_size
        )
        print(f"Epoch {epoch+1} Metrics:")
        print(f"  Loss_D:         {epoch_metrics['loss_D']:.4f}")
        print(f"  Loss_G_total:   {epoch_metrics['loss_G_total']:.4f}")
        print(f"  Loss_G_adv:     {epoch_metrics['loss_G_adv']:.4f}")
        print(f"  Loss_G_L1:      {epoch_metrics['loss_G_L1']:.4f}")
        print(f"  Loss_Cycle:     {epoch_metrics['loss_cycle']:.4f}")
        print(f"  Grad Norm D:    {epoch_metrics['avg_grad_norm_D']:.4f}")
        print(f"  Grad Norm G:    {epoch_metrics['avg_grad_norm_G']:.4f}")

        # Log metrics to TensorBoard.
        writer.add_scalar("Loss/Discriminator", epoch_metrics["loss_D"], epoch)
        writer.add_scalar("Loss/Generator_total", epoch_metrics["loss_G_total"], epoch)
        writer.add_scalar("Loss/Generator_adversarial", epoch_metrics["loss_G_adv"], epoch)
        writer.add_scalar("Loss/Generator_L1", epoch_metrics["loss_G_L1"], epoch)
        writer.add_scalar("Loss/Cycle", epoch_metrics["loss_cycle"], epoch)
        writer.add_scalar("Gradients/Discriminator", epoch_metrics["avg_grad_norm_D"], epoch)
        writer.add_scalar("Gradients/Generator", epoch_metrics["avg_grad_norm_G"], epoch)

        global_step += 1

    writer.close()




In [18]:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "runs/" + "cycleGAN_experiment_" + now

# ---------------- hyper‑parameters in ONE place ----------------
train_parameters = {
    # optimisation
    "lr_G":          1e-4,
    "lr_G2":         1e-4,
    "lr_D":          1e-8,
    "betas":         (0.5, 0.999),

    # loss weights
    "lambda_l1":     5,
    "lambda_cycle":  5,

    # schedule
    "num_epochs":    20,
    "virtual_batch_size": 8,

    # bookkeeping
    "log_dir":       f"runs/cycleGAN_experiment_{now}",
    "model_dir":     "models",
}


In [19]:
# ----- Example Setup and Training Invocation -----

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Model configurations for generator and generator_2.
model_config_gen = {
    "num_inputs": 256,  # Two spectrograms concatenated (2 * 128 mel bins)
    "num_outputs": 128,
    "num_channels": [512*2, 512*4, 512*8],
    "instruments": ["vocal"],
    "kernel_size": 3,
    "target_output_size": 256,
    "conv_type": "normal",
    "res": "fixed",
    "separate": False,
    "depth": 1,
    "strides": 2
}
generator = Waveunet(**model_config_gen).to(device)

model_config_gen2 = {
    "num_inputs": 128,  # One spectrogram input
    "num_outputs": 128,
    "num_channels": [256*2, 256*4, 256*8],
    "instruments": ["speech"],
    "kernel_size": 3,
    "target_output_size": 256,
    "conv_type": "normal",
    "res": "fixed",
    "separate": False,
    "depth": 1,
    "strides": 2
}
generator_2 = Waveunet(**model_config_gen2).to(device)
minirocket_n_jobs = 22 # Instantiate the discriminator.
discriminator = TsaiMiniRocketDiscriminator().to(device)

# Assume dataloaders `accompaniment_loader`, `vocal_loader`, and `speech_loader` are defined.

# Optionally, prepare the discriminator (e.g., pre-fitting on some speech data).
accompaniment_batch = next(iter(accompaniment_loader))
vocal_batch = next(iter(vocal_loader))
speech_batch = next(iter(speech_loader))["no_pad"]
print("Fitting discriminator...")
discriminator.fit_rocket(speech_batch)

# Loss functions.
adversarial_loss = nn.BCELoss().to(device)
l1_loss = nn.L1Loss().to(device)

optimizer_G  = optim.Adam(generator.parameters(),  lr=train_parameters["lr_G"],  betas=train_parameters["betas"])
optimizer_G2 = optim.Adam(generator_2.parameters(), lr=train_parameters["lr_G2"], betas=train_parameters["betas"])
optimizer_D  = optim.Adam(discriminator.parameters(), lr=train_parameters["lr_D"], betas=train_parameters["betas"])



Using device: cuda
Using valid convolutions with 289 inputs and 257 outputs
Using valid convolutions with 289 inputs and 257 outputs
Fitting discriminator...
MiniRocket fitted. Feature dimension: 9996


In [20]:
import gc
gc.collect()
torch.cuda.empty_cache()

# Start training.
train(
    generator,
    generator_2,
    discriminator,
    optimizer_D,
    optimizer_G,
    optimizer_G2,
    accompaniment_loader,
    vocal_loader,
    speech_loader,
    l1_loss,
    train_parameters["lambda_l1"],
    train_parameters["lambda_cycle"],
    adversarial_loss,
    device,
    num_epochs          = train_parameters["num_epochs"],
    virtual_batch_size  = train_parameters["virtual_batch_size"],
    log_dir             = train_parameters["log_dir"],
)



=== Epoch 1/20 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 1 Metrics:
  Loss_D:         0.6951
  Loss_G_total:   243.4683
  Loss_G_adv:     0.7334
  Loss_G_L1:      23.6128
  Loss_Cycle:     24.9342
  Grad Norm D:    0.2922
  Grad Norm G:    249.7935

=== Epoch 2/20 ===


Training Batches: 130it [00:36,  3.54it/s]


Epoch 2 Metrics:
  Loss_D:         0.6955
  Loss_G_total:   148.0159
  Loss_G_adv:     0.7296
  Loss_G_L1:      15.8821
  Loss_Cycle:     13.5751
  Grad Norm D:    0.2750
  Grad Norm G:    183.3576

=== Epoch 3/20 ===


Training Batches: 130it [00:36,  3.55it/s]


Epoch 3 Metrics:
  Loss_D:         0.6948
  Loss_G_total:   124.2751
  Loss_G_adv:     0.7266
  Loss_G_L1:      15.1369
  Loss_Cycle:     9.5728
  Grad Norm D:    0.2731
  Grad Norm G:    89.1967

=== Epoch 4/20 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 4 Metrics:
  Loss_D:         0.6944
  Loss_G_total:   124.7123
  Loss_G_adv:     0.7245
  Loss_G_L1:      15.0036
  Loss_Cycle:     9.7940
  Grad Norm D:    0.2735
  Grad Norm G:    104.9015

=== Epoch 5/20 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 5 Metrics:
  Loss_D:         0.6937
  Loss_G_total:   122.6212
  Loss_G_adv:     0.7230
  Loss_G_L1:      14.9567
  Loss_Cycle:     9.4229
  Grad Norm D:    0.2711
  Grad Norm G:    101.0466

=== Epoch 6/20 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 6 Metrics:
  Loss_D:         0.6921
  Loss_G_total:   122.4530
  Loss_G_adv:     0.7221
  Loss_G_L1:      15.0646
  Loss_Cycle:     9.2816
  Grad Norm D:    0.2725
  Grad Norm G:    127.8588

=== Epoch 7/20 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 7 Metrics:
  Loss_D:         0.6917
  Loss_G_total:   114.3243
  Loss_G_adv:     0.7211
  Loss_G_L1:      15.0874
  Loss_Cycle:     7.6332
  Grad Norm D:    0.2716
  Grad Norm G:    92.4090

=== Epoch 8/20 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 8 Metrics:
  Loss_D:         0.6909
  Loss_G_total:   115.1173
  Loss_G_adv:     0.7202
  Loss_G_L1:      15.1406
  Loss_Cycle:     7.7388
  Grad Norm D:    0.2723
  Grad Norm G:    133.7462

=== Epoch 9/20 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 9 Metrics:
  Loss_D:         0.6900
  Loss_G_total:   115.1285
  Loss_G_adv:     0.7190
  Loss_G_L1:      15.0184
  Loss_Cycle:     7.8635
  Grad Norm D:    0.2740
  Grad Norm G:    154.5852

=== Epoch 10/20 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 10 Metrics:
  Loss_D:         0.6894
  Loss_G_total:   112.3885
  Loss_G_adv:     0.7190
  Loss_G_L1:      14.9336
  Loss_Cycle:     7.4003
  Grad Norm D:    0.2728
  Grad Norm G:    132.3393

=== Epoch 11/20 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 11 Metrics:
  Loss_D:         0.6882
  Loss_G_total:   109.2319
  Loss_G_adv:     0.7182
  Loss_G_L1:      14.9160
  Loss_Cycle:     6.7867
  Grad Norm D:    0.2729
  Grad Norm G:    112.1448

=== Epoch 12/20 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 12 Metrics:
  Loss_D:         0.6877
  Loss_G_total:   111.4107
  Loss_G_adv:     0.7180
  Loss_G_L1:      14.9067
  Loss_Cycle:     7.2318
  Grad Norm D:    0.2750
  Grad Norm G:    143.5424

=== Epoch 13/20 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 13 Metrics:
  Loss_D:         0.6878
  Loss_G_total:   108.7171
  Loss_G_adv:     0.7183
  Loss_G_L1:      14.8102
  Loss_Cycle:     6.7896
  Grad Norm D:    0.2762
  Grad Norm G:    117.7255

=== Epoch 14/20 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 14 Metrics:
  Loss_D:         0.6865
  Loss_G_total:   107.1933
  Loss_G_adv:     0.7183
  Loss_G_L1:      14.7806
  Loss_Cycle:     6.5144
  Grad Norm D:    0.2743
  Grad Norm G:    107.0674

=== Epoch 15/20 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 15 Metrics:
  Loss_D:         0.6864
  Loss_G_total:   107.5444
  Loss_G_adv:     0.7182
  Loss_G_L1:      14.8236
  Loss_Cycle:     6.5417
  Grad Norm D:    0.2754
  Grad Norm G:    111.4470

=== Epoch 16/20 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 16 Metrics:
  Loss_D:         0.6847
  Loss_G_total:   111.0083
  Loss_G_adv:     0.7186
  Loss_G_L1:      14.8393
  Loss_Cycle:     7.2186
  Grad Norm D:    0.2750
  Grad Norm G:    154.9176

=== Epoch 17/20 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 17 Metrics:
  Loss_D:         0.6844
  Loss_G_total:   108.3326
  Loss_G_adv:     0.7186
  Loss_G_L1:      14.8154
  Loss_Cycle:     6.7074
  Grad Norm D:    0.2735
  Grad Norm G:    126.0532

=== Epoch 18/20 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 18 Metrics:
  Loss_D:         0.6839
  Loss_G_total:   109.9369
  Loss_G_adv:     0.7188
  Loss_G_L1:      14.7647
  Loss_Cycle:     7.0789
  Grad Norm D:    0.2747
  Grad Norm G:    151.6303

=== Epoch 19/20 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 19 Metrics:
  Loss_D:         0.6832
  Loss_G_total:   108.0518
  Loss_G_adv:     0.7188
  Loss_G_L1:      14.7560
  Loss_Cycle:     6.7106
  Grad Norm D:    0.2723
  Grad Norm G:    123.9540

=== Epoch 20/20 ===


Training Batches: 130it [00:36,  3.52it/s]

Epoch 20 Metrics:
  Loss_D:         0.6830
  Loss_G_total:   107.1174
  Loss_G_adv:     0.7199
  Loss_G_L1:      14.7880
  Loss_Cycle:     6.4915
  Grad Norm D:    0.2736
  Grad Norm G:    115.7214





In [23]:
gc.collect()
torch.cuda.empty_cache()

# Start training.
train(
    generator,
    generator_2,
    discriminator,
    optimizer_D,
    optimizer_G,
    optimizer_G2,
    accompaniment_loader,
    vocal_loader,
    speech_loader,
    l1_loss,
    train_parameters["lambda_l1"],
    train_parameters["lambda_cycle"],
    adversarial_loss,
    device,
    num_epochs          = 50, #train_parameters["num_epochs"],
    virtual_batch_size  = train_parameters["virtual_batch_size"],
    log_dir             = train_parameters["log_dir"],
)



=== Epoch 1/50 ===


Training Batches: 0it [00:00, ?it/s]

Training Batches: 130it [00:36,  3.52it/s]


Epoch 1 Metrics:
  Loss_D:         0.6821
  Loss_G_total:   108.2040
  Loss_G_adv:     0.7204
  Loss_G_L1:      14.7449
  Loss_Cycle:     6.7519
  Grad Norm D:    0.2734
  Grad Norm G:    143.7835

=== Epoch 2/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 2 Metrics:
  Loss_D:         0.6818
  Loss_G_total:   106.9059
  Loss_G_adv:     0.7201
  Loss_G_L1:      14.7240
  Loss_Cycle:     6.5132
  Grad Norm D:    0.2719
  Grad Norm G:    127.6750

=== Epoch 3/50 ===


Training Batches: 130it [00:36,  3.54it/s]


Epoch 3 Metrics:
  Loss_D:         0.6813
  Loss_G_total:   107.5395
  Loss_G_adv:     0.7207
  Loss_G_L1:      14.7562
  Loss_Cycle:     6.6076
  Grad Norm D:    0.2716
  Grad Norm G:    138.9725

=== Epoch 4/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 4 Metrics:
  Loss_D:         0.6813
  Loss_G_total:   107.1295
  Loss_G_adv:     0.7199
  Loss_G_L1:      14.7537
  Loss_Cycle:     6.5282
  Grad Norm D:    0.2717
  Grad Norm G:    142.0719

=== Epoch 5/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 5 Metrics:
  Loss_D:         0.6807
  Loss_G_total:   106.4080
  Loss_G_adv:     0.7202
  Loss_G_L1:      14.6990
  Loss_Cycle:     6.4386
  Grad Norm D:    0.2716
  Grad Norm G:    142.7587

=== Epoch 6/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 6 Metrics:
  Loss_D:         0.6802
  Loss_G_total:   105.4599
  Loss_G_adv:     0.7216
  Loss_G_L1:      14.7674
  Loss_Cycle:     6.1803
  Grad Norm D:    0.2712
  Grad Norm G:    130.4185

=== Epoch 7/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 7 Metrics:
  Loss_D:         0.6794
  Loss_G_total:   105.5919
  Loss_G_adv:     0.7210
  Loss_G_L1:      14.7146
  Loss_Cycle:     6.2596
  Grad Norm D:    0.2716
  Grad Norm G:    141.1670

=== Epoch 8/50 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 8 Metrics:
  Loss_D:         0.6793
  Loss_G_total:   104.1546
  Loss_G_adv:     0.7212
  Loss_G_L1:      14.7654
  Loss_Cycle:     5.9213
  Grad Norm D:    0.2718
  Grad Norm G:    125.7804

=== Epoch 9/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 9 Metrics:
  Loss_D:         0.6787
  Loss_G_total:   106.4841
  Loss_G_adv:     0.7222
  Loss_G_L1:      14.7878
  Loss_Cycle:     6.3646
  Grad Norm D:    0.2747
  Grad Norm G:    163.2046

=== Epoch 10/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 10 Metrics:
  Loss_D:         0.6787
  Loss_G_total:   104.9057
  Loss_G_adv:     0.7215
  Loss_G_L1:      14.7025
  Loss_Cycle:     6.1344
  Grad Norm D:    0.2731
  Grad Norm G:    151.0257

=== Epoch 11/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 11 Metrics:
  Loss_D:         0.6780
  Loss_G_total:   103.8345
  Loss_G_adv:     0.7231
  Loss_G_L1:      14.7461
  Loss_Cycle:     5.8762
  Grad Norm D:    0.2733
  Grad Norm G:    133.3871

=== Epoch 12/50 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 12 Metrics:
  Loss_D:         0.6766
  Loss_G_total:   102.6991
  Loss_G_adv:     0.7226
  Loss_G_L1:      14.7059
  Loss_Cycle:     5.6894
  Grad Norm D:    0.2743
  Grad Norm G:    124.8220

=== Epoch 13/50 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 13 Metrics:
  Loss_D:         0.6763
  Loss_G_total:   107.2014
  Loss_G_adv:     0.7250
  Loss_G_L1:      14.7326
  Loss_Cycle:     6.5627
  Grad Norm D:    0.2768
  Grad Norm G:    181.9534

=== Epoch 14/50 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 14 Metrics:
  Loss_D:         0.6756
  Loss_G_total:   102.8215
  Loss_G_adv:     0.7238
  Loss_G_L1:      14.7393
  Loss_Cycle:     5.6803
  Grad Norm D:    0.2769
  Grad Norm G:    131.8778

=== Epoch 15/50 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 15 Metrics:
  Loss_D:         0.6750
  Loss_G_total:   102.8494
  Loss_G_adv:     0.7256
  Loss_G_L1:      14.6616
  Loss_Cycle:     5.7632
  Grad Norm D:    0.2773
  Grad Norm G:    131.0210

=== Epoch 16/50 ===


Training Batches: 130it [00:36,  3.54it/s]


Epoch 16 Metrics:
  Loss_D:         0.6745
  Loss_G_total:   103.8186
  Loss_G_adv:     0.7248
  Loss_G_L1:      14.7330
  Loss_Cycle:     5.8858
  Grad Norm D:    0.2776
  Grad Norm G:    150.1938

=== Epoch 17/50 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 17 Metrics:
  Loss_D:         0.6742
  Loss_G_total:   104.5843
  Loss_G_adv:     0.7261
  Loss_G_L1:      14.6689
  Loss_Cycle:     6.1027
  Grad Norm D:    0.2769
  Grad Norm G:    163.2186

=== Epoch 18/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 18 Metrics:
  Loss_D:         0.6745
  Loss_G_total:   102.0902
  Loss_G_adv:     0.7257
  Loss_G_L1:      14.6769
  Loss_Cycle:     5.5960
  Grad Norm D:    0.2770
  Grad Norm G:    128.0683

=== Epoch 19/50 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 19 Metrics:
  Loss_D:         0.6729
  Loss_G_total:   102.6454
  Loss_G_adv:     0.7280
  Loss_G_L1:      14.6292
  Loss_Cycle:     5.7543
  Grad Norm D:    0.2779
  Grad Norm G:    138.5407

=== Epoch 20/50 ===


Training Batches: 130it [00:36,  3.51it/s]


Epoch 20 Metrics:
  Loss_D:         0.6729
  Loss_G_total:   102.3846
  Loss_G_adv:     0.7266
  Loss_G_L1:      14.6598
  Loss_Cycle:     5.6718
  Grad Norm D:    0.2774
  Grad Norm G:    140.5238

=== Epoch 21/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 21 Metrics:
  Loss_D:         0.6725
  Loss_G_total:   101.9019
  Loss_G_adv:     0.7279
  Loss_G_L1:      14.6028
  Loss_Cycle:     5.6320
  Grad Norm D:    0.2783
  Grad Norm G:    141.6563

=== Epoch 22/50 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 22 Metrics:
  Loss_D:         0.6716
  Loss_G_total:   103.2474
  Loss_G_adv:     0.7277
  Loss_G_L1:      14.6876
  Loss_Cycle:     5.8163
  Grad Norm D:    0.2793
  Grad Norm G:    147.4989

=== Epoch 23/50 ===


Training Batches: 130it [00:37,  3.49it/s]


Epoch 23 Metrics:
  Loss_D:         0.6711
  Loss_G_total:   102.9706
  Loss_G_adv:     0.7282
  Loss_G_L1:      14.5812
  Loss_Cycle:     5.8673
  Grad Norm D:    0.2790
  Grad Norm G:    154.3014

=== Epoch 24/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 24 Metrics:
  Loss_D:         0.6709
  Loss_G_total:   101.1025
  Loss_G_adv:     0.7292
  Loss_G_L1:      14.5789
  Loss_Cycle:     5.4957
  Grad Norm D:    0.2803
  Grad Norm G:    124.4891

=== Epoch 25/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 25 Metrics:
  Loss_D:         0.6707
  Loss_G_total:   104.4628
  Loss_G_adv:     0.7290
  Loss_G_L1:      14.6416
  Loss_Cycle:     6.1051
  Grad Norm D:    0.2804
  Grad Norm G:    180.3266

=== Epoch 26/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 26 Metrics:
  Loss_D:         0.6694
  Loss_G_total:   101.5440
  Loss_G_adv:     0.7293
  Loss_G_L1:      14.6602
  Loss_Cycle:     5.5027
  Grad Norm D:    0.2816
  Grad Norm G:    136.5232

=== Epoch 27/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 27 Metrics:
  Loss_D:         0.6686
  Loss_G_total:   101.2138
  Loss_G_adv:     0.7307
  Loss_G_L1:      14.5815
  Loss_Cycle:     5.5151
  Grad Norm D:    0.2818
  Grad Norm G:    141.4417

=== Epoch 28/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 28 Metrics:
  Loss_D:         0.6693
  Loss_G_total:   100.6406
  Loss_G_adv:     0.7300
  Loss_G_L1:      14.5577
  Loss_Cycle:     5.4244
  Grad Norm D:    0.2833
  Grad Norm G:    112.4078

=== Epoch 29/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 29 Metrics:
  Loss_D:         0.6685
  Loss_G_total:   103.1823
  Loss_G_adv:     0.7304
  Loss_G_L1:      14.6182
  Loss_Cycle:     5.8721
  Grad Norm D:    0.2835
  Grad Norm G:    164.5597

=== Epoch 30/50 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 30 Metrics:
  Loss_D:         0.6677
  Loss_G_total:   101.8889
  Loss_G_adv:     0.7329
  Loss_G_L1:      14.5670
  Loss_Cycle:     5.6642
  Grad Norm D:    0.2846
  Grad Norm G:    148.0628

=== Epoch 31/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 31 Metrics:
  Loss_D:         0.6676
  Loss_G_total:   101.0748
  Loss_G_adv:     0.7329
  Loss_G_L1:      14.5631
  Loss_Cycle:     5.5053
  Grad Norm D:    0.2850
  Grad Norm G:    133.3183

=== Epoch 32/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 32 Metrics:
  Loss_D:         0.6666
  Loss_G_total:   102.3803
  Loss_G_adv:     0.7330
  Loss_G_L1:      14.6010
  Loss_Cycle:     5.7285
  Grad Norm D:    0.2852
  Grad Norm G:    159.9388

=== Epoch 33/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 33 Metrics:
  Loss_D:         0.6668
  Loss_G_total:   103.1469
  Loss_G_adv:     0.7327
  Loss_G_L1:      14.5590
  Loss_Cycle:     5.9239
  Grad Norm D:    0.2852
  Grad Norm G:    172.2963

=== Epoch 34/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 34 Metrics:
  Loss_D:         0.6662
  Loss_G_total:   101.1477
  Loss_G_adv:     0.7340
  Loss_G_L1:      14.6093
  Loss_Cycle:     5.4735
  Grad Norm D:    0.2848
  Grad Norm G:    137.0367

=== Epoch 35/50 ===


Training Batches: 130it [00:36,  3.51it/s]


Epoch 35 Metrics:
  Loss_D:         0.6658
  Loss_G_total:   101.2842
  Loss_G_adv:     0.7335
  Loss_G_L1:      14.5799
  Loss_Cycle:     5.5302
  Grad Norm D:    0.2860
  Grad Norm G:    147.6029

=== Epoch 36/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 36 Metrics:
  Loss_D:         0.6649
  Loss_G_total:   101.8851
  Loss_G_adv:     0.7344
  Loss_G_L1:      14.5820
  Loss_Cycle:     5.6482
  Grad Norm D:    0.2875
  Grad Norm G:    150.8804

=== Epoch 37/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 37 Metrics:
  Loss_D:         0.6644
  Loss_G_total:   102.2075
  Loss_G_adv:     0.7346
  Loss_G_L1:      14.5928
  Loss_Cycle:     5.7018
  Grad Norm D:    0.2880
  Grad Norm G:    161.3325

=== Epoch 38/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 38 Metrics:
  Loss_D:         0.6644
  Loss_G_total:   100.1289
  Loss_G_adv:     0.7364
  Loss_G_L1:      14.5414
  Loss_Cycle:     5.3371
  Grad Norm D:    0.2885
  Grad Norm G:    139.1900

=== Epoch 39/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 39 Metrics:
  Loss_D:         0.6634
  Loss_G_total:   100.3359
  Loss_G_adv:     0.7357
  Loss_G_L1:      14.5074
  Loss_Cycle:     5.4126
  Grad Norm D:    0.2891
  Grad Norm G:    131.5290

=== Epoch 40/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 40 Metrics:
  Loss_D:         0.6627
  Loss_G_total:   100.4140
  Loss_G_adv:     0.7367
  Loss_G_L1:      14.5624
  Loss_Cycle:     5.3730
  Grad Norm D:    0.2902
  Grad Norm G:    126.7075

=== Epoch 41/50 ===


Training Batches: 130it [00:37,  3.50it/s]


Epoch 41 Metrics:
  Loss_D:         0.6624
  Loss_G_total:   101.7910
  Loss_G_adv:     0.7370
  Loss_G_L1:      14.5302
  Loss_Cycle:     5.6806
  Grad Norm D:    0.2905
  Grad Norm G:    168.7773

=== Epoch 42/50 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 42 Metrics:
  Loss_D:         0.6625
  Loss_G_total:   100.0049
  Loss_G_adv:     0.7379
  Loss_G_L1:      14.4773
  Loss_Cycle:     5.3761
  Grad Norm D:    0.2909
  Grad Norm G:    141.1132

=== Epoch 43/50 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 43 Metrics:
  Loss_D:         0.6617
  Loss_G_total:   101.8431
  Loss_G_adv:     0.7379
  Loss_G_L1:      14.5095
  Loss_Cycle:     5.7115
  Grad Norm D:    0.2925
  Grad Norm G:    166.5422

=== Epoch 44/50 ===


Training Batches: 130it [00:35,  3.67it/s]


Epoch 44 Metrics:
  Loss_D:         0.6611
  Loss_G_total:   99.6885
  Loss_G_adv:     0.7381
  Loss_G_L1:      14.5127
  Loss_Cycle:     5.2774
  Grad Norm D:    0.2920
  Grad Norm G:    124.1801

=== Epoch 45/50 ===


Training Batches: 130it [00:35,  3.68it/s]


Epoch 45 Metrics:
  Loss_D:         0.6608
  Loss_G_total:   101.4193
  Loss_G_adv:     0.7392
  Loss_G_L1:      14.4631
  Loss_Cycle:     5.6729
  Grad Norm D:    0.2927
  Grad Norm G:    161.6244

=== Epoch 46/50 ===


Training Batches: 130it [00:35,  3.62it/s]


Epoch 46 Metrics:
  Loss_D:         0.6605
  Loss_G_total:   99.7569
  Loss_G_adv:     0.7410
  Loss_G_L1:      14.5312
  Loss_Cycle:     5.2720
  Grad Norm D:    0.2924
  Grad Norm G:    139.5405

=== Epoch 47/50 ===


Training Batches: 130it [00:36,  3.53it/s]


Epoch 47 Metrics:
  Loss_D:         0.6597
  Loss_G_total:   99.9714
  Loss_G_adv:     0.7403
  Loss_G_L1:      14.5144
  Loss_Cycle:     5.3318
  Grad Norm D:    0.2924
  Grad Norm G:    150.7815

=== Epoch 48/50 ===


Training Batches: 130it [00:36,  3.52it/s]


Epoch 48 Metrics:
  Loss_D:         0.6594
  Loss_G_total:   99.5843
  Loss_G_adv:     0.7401
  Loss_G_L1:      14.4895
  Loss_Cycle:     5.2794
  Grad Norm D:    0.2954
  Grad Norm G:    138.6163

=== Epoch 49/50 ===


Training Batches: 130it [00:37,  3.51it/s]


Epoch 49 Metrics:
  Loss_D:         0.6590
  Loss_G_total:   100.6547
  Loss_G_adv:     0.7402
  Loss_G_L1:      14.4550
  Loss_Cycle:     5.5279
  Grad Norm D:    0.2944
  Grad Norm G:    153.5300

=== Epoch 50/50 ===


Training Batches: 130it [00:37,  3.51it/s]

Epoch 50 Metrics:
  Loss_D:         0.6586
  Loss_G_total:   99.7650
  Loss_G_adv:     0.7421
  Loss_G_L1:      14.4586
  Loss_Cycle:     5.3460
  Grad Norm D:    0.2949
  Grad Norm G:    146.5494





## Save the models

In [24]:
# assert False
path = "models/" 
torch.save(generator.state_dict(), path + "generator_state_dict_" + now + ".pt")
torch.save(generator_2.state_dict(), path + "generator_2_state_dict_" + now + ".pt")
torch.save(discriminator.state_dict(), path + "discriminator_state_dict_" + now + ".pt")

# ------------- package everything to save -------------
export_dict = {
    "train_parameters": train_parameters,
    "model_config_gen": model_config_gen,      # Wave‑U‑Net (speech+accomp → vocal)
    "model_config_gen2": model_config_gen2,    # Wave‑U‑Net (vocal → speech)
}
import json

# (optional) ensure JSON‑serialisable: convert tuples → lists
def _convert(obj):
    if isinstance(obj, tuple):
        return list(obj)
    if isinstance(obj, dict):
        return {k: _convert(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_convert(x) for x in obj]
    return obj

export_dict = _convert(export_dict)

with open(f"{path}/training_record_{now}.json", "w") as fp:
    json.dump(export_dict, fp, indent=2)

