<a href="https://colab.research.google.com/github/jjaw89/spring_2025_dl_audio_project/blob/main/Greg_cycle_GAN_train_2_with_listening.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Here we train our first version of the GAN.



## Initialize Wave-U-Net

We start by loading the necessary packages

Wave-U-Net is named ``generator``

In [None]:
# Mount Google Drive so the project checkout stored there is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Work from the project root: later cells append relative paths such as
# 'Wave-U-Net-Pytorch' to sys.path, which only resolve from this directory.
%cd /content/drive/My Drive/git_projects/spring_2025_dl_audio_project

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/14hS-tzi4BinLtBzv8urx6eA2XOhsMNJI/git_projects/spring_2025_dl_audio_project


In [None]:
# Import same packages as the train script in Wave-U-Net-Pytorch
import argparse
import os
import time
from functools import partial
from datetime import datetime



import torch
import pickle
import numpy as np

import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils import spectral_norm
from torch.optim import Adam
from torch.nn import L1Loss
from tqdm import tqdm
from torchsummary import summary
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
# install torchaudio if not already installed
# ! pip install torchaudio
import torchaudio

import matplotlib.pyplot as plt
from typing import Tuple, List, Dict, Optional

!pip install sktime
from sktime.transformations.panel.rocket import MiniRocketMultivariate


# add a path to Wave-U-Net
import sys
sys.path.append('Wave-U-Net-Pytorch')
sys.path.append("workspace/hdd_project_data")

import model.utils as model_utils
import utils
from model.waveunet import Waveunet

# Check to see if we have a GPU available
print("GPU:", torch.cuda.is_available())

Collecting sktime
  Downloading sktime-0.37.0-py3-none-any.whl.metadata (34 kB)
Collecting scikit-base<0.13.0,>=0.6.1 (from sktime)
  Downloading scikit_base-0.12.2-py3-none-any.whl.metadata (8.8 kB)
Downloading sktime-0.37.0-py3-none-any.whl (37.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.0/37.0 MB[0m [31m67.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading scikit_base-0.12.2-py3-none-any.whl (142 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.7/142.7 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: scikit-base, sktime
Successfully installed scikit-base-0.12.2 sktime-0.37.0
GPU: True


In [None]:
# These commands can also be run from the terminal available on paid Colab plans.
# They are issued through sys.executable so that pip targets the kernel's own interpreter.

!{sys.executable} -m pip install musdb  # has some helpful data structures, also installs ffmpeg and stempeg
!{sys.executable} -m pip uninstall -y stempeg    # musdb installs the wrong version of stempeg

Collecting musdb
  Downloading musdb-0.4.2-py2.py3-none-any.whl.metadata (10 kB)
Collecting stempeg>=0.2.3 (from musdb)
  Downloading stempeg-0.2.3-py3-none-any.whl.metadata (9.0 kB)
Collecting pyaml (from musdb)
  Downloading pyaml-25.1.0-py3-none-any.whl.metadata (12 kB)
Collecting ffmpeg-python>=0.2.0 (from stempeg>=0.2.3->musdb)
  Downloading ffmpeg_python-0.2.0-py3-none-any.whl.metadata (1.7 kB)
Downloading musdb-0.4.2-py2.py3-none-any.whl (13 kB)
Downloading stempeg-0.2.3-py3-none-any.whl (963 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m963.5/963.5 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyaml-25.1.0-py3-none-any.whl (26 kB)
Downloading ffmpeg_python-0.2.0-py3-none-any.whl (25 kB)
Installing collected packages: pyaml, ffmpeg-python, stempeg, musdb
Successfully installed ffmpeg-python-0.2.0 musdb-0.4.2 pyaml-25.1.0 stempeg-0.2.3
Found existing installation: stempeg 0.2.3
Uninstalling stempeg-0.2.3:
  Successfully uninstalled stempeg-0

## Initialize miniRocket
We start by loading the necessary packages

### CPU Core Allocation for MiniRocketMultivariate

- The implementation of `MiniRocketMultivariate` runs on the **CPU**.
- We need to decide how many cores to allocate for it.
- Some cores will be used by MiniRocket itself, while others are needed for data preparation (e.g., generating spectrograms).
- This allocation likely needs to be **tuned for optimal performance**.
- As a starting point, we detect the number of available cores and split them evenly.
- Note: We avoid using *all* available cores to leave some resources for the operating system and other background processes.


Create the MiniRocket model

In [None]:
import pandas as pd
from time import time

# MiniRocket discriminator built on sktime's MiniRocketMultivariate transformer
class TsaiMiniRocketDiscriminator(nn.Module):
    """Conditional discriminator: fixed MiniRocket features followed by a trainable MLP.

    The vocal and accompaniment spectrograms are each converted to NumPy and passed
    through one shared sktime ``MiniRocketMultivariate`` transformer. The two feature
    vectors are concatenated and scored by ``self.classifier``.

    Feature extraction runs in NumPy on detached tensors, so no gradient flows from
    the output back to the input spectrograms; only the classifier is trainable.

    Args:
        freq_bins: Number of frequency bins; only used to size ``example_input``.
        time_frames: Number of time frames; only used to size ``example_input``.
        num_kernels: Requested number of MiniRocket kernels.
        hidden_dim: Width of the first hidden layer (the second is half of it).
        output_dim: Number of outputs of the final layer.
        n_jobs: Number of CPU workers for MiniRocket. When ``None`` the notebook-level
            global ``minirocket_n_jobs`` is used if it exists, otherwise 1.
    """

    def __init__(
        self,
        freq_bins=256,
        time_frames=256,
        num_kernels=10000,  # number of convolutional kernels
        hidden_dim=1024,    # Increased to handle larger feature dimension
        output_dim=1,
        n_jobs=None,
    ):
        super(TsaiMiniRocketDiscriminator, self).__init__()

        if n_jobs is None:
            # Keep honouring the notebook-level setting, but do not crash with a
            # NameError when the cell that defines it has not been run.
            n_jobs = globals().get("minirocket_n_jobs", 1)

        # This is the mini rocket transformer which extracts features
        self.rocket = MiniRocketMultivariate(num_kernels=num_kernels, n_jobs=n_jobs)
        self.fitted = False   # fit before training
        self.freq_bins = freq_bins
        self.time_frames = time_frames
        self.num_kernels = num_kernels

        # For 2D data handling - process each sample with proper dimensions
        self.example_input = np.zeros((1, freq_bins, time_frames))

        # Number of features MiniRocket emits per spectrogram. This is the width the
        # classifier is built for, and also the width of the zero fallback below.
        self.feature_dim = self._rocket_feature_dim(num_kernels)

        # NOTE(review): each hidden stage is followed by two consecutive Dropout(0.3)
        # layers. They are kept as-is so that existing state_dict keys still match.
        self.classifier = nn.Sequential(
            # First reduce the massive dimension to something manageable
            # (vocal features + accompaniment features, hence the factor 2)
            nn.Dropout(0.3),
            spectral_norm(nn.Linear(2 * self.feature_dim, hidden_dim)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Second hidden layer
            nn.Dropout(0.3),
            spectral_norm(nn.Linear(hidden_dim, hidden_dim // 2)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            # Final classification layer
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, output_dim),
            nn.Sigmoid()
        )

    @staticmethod
    def _rocket_feature_dim(num_kernels):
        """Return the number of features MiniRocket produces for ``num_kernels``.

        MiniRocket works with a fixed bank of 84 kernels, so the feature count is
        ``num_kernels`` rounded down to a multiple of 84 (and never less than 84).
        For the default of 10000 this gives 9996.
        """
        base_kernels = 84
        return base_kernels * max(1, num_kernels // base_kernels)

    def fit_rocket(self, spectrograms):
        """Fit MiniRocket on one batch of spectrograms (not the whole training set).

        ``spectrograms`` is handed to sktime unchanged, which expects an array of
        shape (n_instances, n_dimensions, series_length).

        Fitting is attempted only once. On failure the error is printed and the
        transformer is still marked as fitted so the attempt is not repeated; later
        calls to ``extract_features`` will then return the zero fallback.
        """
        if self.fitted:
            return

        try:
            # Convert to numpy for sktime processing
            sample_data = spectrograms.detach().cpu().numpy()

            self.rocket.fit(sample_data)
            self.fitted = True

            # Test transform to confirm the feature dimension the classifier was built for
            actual_dim = self.rocket.transform(sample_data).shape[1]
            print(f"MiniRocket fitted. Feature dimension: {actual_dim}")
            if actual_dim != self.feature_dim:
                print(
                    f"Warning: MiniRocket produced {actual_dim} features but the "
                    f"classifier expects {self.feature_dim} per input."
                )

        except Exception as e:
            print(f"Error fitting MiniRocket: {e}")
            # Use a fallback if fitting fails
            self.fitted = True  # Mark as fitted to avoid repeated attempts

    def extract_features(self, spectrogram):
        """Extract MiniRocket features from a batch of spectrograms.

        Returns a float32 tensor of shape (batch, feature_dim) on the same device as
        ``spectrogram``. If the transform fails, the error is printed and an all-zero
        tensor of the same shape is returned so that the classifier can still run.
        """
        try:
            # Ensure rocket is fitted
            if not self.fitted:
                self.fit_rocket(spectrogram)

            # Convert to numpy for sktime
            spec_np = spectrogram.detach().cpu().numpy()

            # This step extracts features using the convolutional kernels
            features = self.rocket.transform(spec_np)

            # np.asarray accepts both a DataFrame and a plain array; float32 matches
            # the dtype of the classifier weights.
            features_np = np.asarray(features, dtype=np.float32)
            return torch.tensor(features_np).to(spectrogram.device)

        except Exception as e:
            print(f"Error in feature extraction: {e}")
            # Return zeros as fallback, sized to what the classifier expects
            return torch.zeros((spectrogram.shape[0], self.feature_dim),
                              device=spectrogram.device)

    def forward(self, vocals, accompaniment):
        """
        Forward pass of the discriminator

        Args:
            vocals: Batch of vocal spectrograms, passed to MiniRocket as
                (n_instances, n_dimensions, series_length).
            accompaniment: Batch of accompaniment spectrograms, same layout.

        Returns:
            Tensor of shape (batch, output_dim) with values in (0, 1).
        """
        # Extract features from both spectrograms
        vocal_features = self.extract_features(vocals)
        accomp_features = self.extract_features(accompaniment)

        # Concatenate features (conditional GAN)
        combined_features = torch.cat([vocal_features, accomp_features], dim=1)

        # Classify as real/fake
        validity = self.classifier(combined_features)

        return validity



In [None]:
# del discriminator
# discriminator = TsaiMiniRocketDiscriminator()
# We probably do not need to compile the model

# Import Data into Session

First, we run the code that defines the custom Dataset objects. The Datasets were compiled previously and saved in .pt files. In the next cell, we load those Dataset objects.

In [None]:
%cd /content/drive/My Drive/git_projects/
# sys.path.append('/workspace/hdd_project_data/stempeg')
import stempeg
import musdb
import torch
import librosa
import numpy as np
from torch.utils.data import Dataset

class MusdbDataset(Dataset):
  """Fixed-width log-mel windows of (accompaniment, vocal) pairs cut from musdb tracks.

  Attributes:
    mel_specs: tensor of shape (num_windows, 2, 128, window_size); index 0 on the
      second axis is the accompaniment, index 1 the vocal.
    sample_rates: tensor of shape (num_windows,) holding each window's sample rate.
  """

  def __init__(self, musDB, window_size = 256, step_size = 128):
    # Collect per-song tensors and concatenate once at the end; concatenating onto
    # a growing tensor inside the loop copies all previous data on every step.
    spec_chunks = []
    rate_chunks = []

    for num_songs, track in enumerate(musDB, start = 1):
      stems, rate = track.stems, track.rate

      # separate the vocal from other instruments and convert to mono signal
      audio_novocal = librosa.to_mono(np.transpose(stems[1] + stems[2] + stems[3]))
      audio_vocal = librosa.to_mono(np.transpose(stems[4]))

      # compute log mel spectrogram and convert to pytorch tensor
      logmelspec_novocal = torch.from_numpy(self._mel_spectrogram(audio_novocal, rate))
      logmelspec_vocal = torch.from_numpy(self._mel_spectrogram(audio_vocal, rate))

      # cut both spectrograms into the same overlapping windows
      novocal_windows = self._slice_windows(logmelspec_novocal, window_size, step_size)
      vocal_windows = self._slice_windows(logmelspec_vocal, window_size, step_size)

      # pair them up along a new axis 1: (num_windows, 2, n_mels, window_size)
      logmels = torch.stack((novocal_windows, vocal_windows), 1)
      spec_chunks.append(logmels)
      rate_chunks.append(torch.full((logmels.shape[0],), rate))

      if num_songs % 5 == 0:
        num_specs = sum(chunk.shape[0] for chunk in spec_chunks)
        print(str(num_songs) + " songs processed; produced " + str(num_specs) + " spectrograms")

    if spec_chunks:
      self.mel_specs = torch.cat(spec_chunks, 0)
      self.sample_rates = torch.cat(rate_chunks, 0)
    else:
      # no tracks at all: keep the documented shapes, just with zero windows
      self.mel_specs = torch.zeros(0, 2, 128, window_size)
      self.sample_rates = torch.zeros(0, dtype = torch.int64)

  @staticmethod
  def _slice_windows(spec, window_size, step_size):
    """Cut a (n_bins, n_frames) spectrogram into windows of ``window_size`` frames.

    One pass is made per start offset 0, step_size, 2*step_size, ... (there are
    ``window_size // step_size`` passes); within a pass the windows do not overlap
    and the incomplete tail is dropped. Returns (num_windows, n_bins, window_size).
    """
    n_bins = spec.shape[0]
    passes = []
    for step in range(window_size // step_size):
      # each pass starts step_size frames later than the previous one
      cropped = spec[:, step * step_size:]
      num_slices = cropped.shape[1] // window_size

      # chop off the last bit so that the number of frames is a multiple of window_size
      cropped = cropped[:, 0:num_slices * window_size]

      # reshape into chunks of size n_bins x window_size; first dimension is the chunk index
      passes.append(torch.transpose(torch.reshape(cropped, (n_bins, num_slices, window_size)), 0, 1))

    if not passes:
      return spec.new_zeros((0, n_bins, window_size))
    return torch.cat(passes, 0)

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    # returns tuple (mel spectrogram of accompaniment, mel spectrogram of vocal, rate)
    return self.mel_specs[ndx, 0], self.mel_specs[ndx, 1], self.sample_rates[ndx]

  def _mel_spectrogram(self, audio, rate):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

  def cat(self, other_ds):
    """Append the windows and sample rates of ``other_ds`` to this dataset in place."""
    self.mel_specs = torch.cat((self.mel_specs, other_ds.mel_specs), 0)
    self.sample_rates = torch.cat((self.sample_rates, other_ds.sample_rates), 0)

import torch
import librosa
import numpy as np
from torch.utils.data import Dataset

class SingingDataset(Dataset):
  """Fixed-width log-mel windows of the vocal stem of musdb tracks.

  Attributes:
    mel_specs: tensor of shape (num_windows, 128, window_size).
    sample_rates: tensor of shape (num_windows,) holding each window's sample rate.
  """

  def __init__(self, musDB, window_size = 256, step_size = 128):
    # Collect per-song tensors and concatenate once at the end; concatenating onto
    # a growing tensor inside the loop copies all previous data on every step.
    spec_chunks = []
    rate_chunks = []

    for num_songs, track in enumerate(musDB, start = 1):
      stems, rate = track.stems, track.rate

      # load the vocal
      vocal = librosa.to_mono(np.transpose(stems[4]))

      # compute log mel spectrogram and convert to pytorch tensor
      mel_spec = torch.from_numpy(self._mel_spectrogram(vocal, rate))

      windows = self._slice_windows(mel_spec, window_size, step_size)
      spec_chunks.append(windows)
      rate_chunks.append(torch.full((windows.shape[0],), rate))

      # progress report every fifth song
      if num_songs % 5 == 0:
        num_specs = sum(chunk.shape[0] for chunk in spec_chunks)
        print(str(num_songs) + " songs processed; produced " + str(num_specs) + " spectrograms")

    if spec_chunks:
      self.mel_specs = torch.cat(spec_chunks, 0)
      self.sample_rates = torch.cat(rate_chunks, 0)
    else:
      # no tracks at all: keep the documented shapes, just with zero windows
      self.mel_specs = torch.zeros(0, 128, window_size)
      self.sample_rates = torch.zeros(0, dtype = torch.int64)

  @staticmethod
  def _slice_windows(spec, window_size, step_size):
    """Cut a (n_bins, n_frames) spectrogram into windows of ``window_size`` frames.

    One pass is made per start offset 0, step_size, 2*step_size, ... (there are
    ``window_size // step_size`` passes); within a pass the windows do not overlap
    and the incomplete tail is dropped. Returns (num_windows, n_bins, window_size).
    """
    n_bins = spec.shape[0]
    passes = []
    for step in range(window_size // step_size):
      # each pass starts step_size frames later than the previous one
      cropped = spec[:, step * step_size:]
      num_slices = cropped.shape[1] // window_size

      # chop off the last bit so that the number of frames is a multiple of window_size
      cropped = cropped[:, 0:num_slices * window_size]

      # reshape into chunks of size n_bins x window_size; first dimension is the chunk index
      passes.append(torch.transpose(torch.reshape(cropped, (n_bins, num_slices, window_size)), 0, 1))

    if not passes:
      return spec.new_zeros((0, n_bins, window_size))
    return torch.cat(passes, 0)

  def __len__(self):
    return self.mel_specs.shape[0]

  def __getitem__(self, ndx):
    # returns tuple (mel spectrogram of vocal, rate)
    return self.mel_specs[ndx], self.sample_rates[ndx]

  def _mel_spectrogram(self, audio, rate):
    # compute the log-mel-spectrogram of the audio at the given sample rate
    return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

import torch
import librosa
import numpy as np
import os
from torch.utils.data import Dataset

class LibriSpeechDataset(Dataset):
    """Fixed-width log-mel windows cut from LibriSpeech ``.flac`` recordings.

    ``path`` is walked as ``<path>/<speaker>/<chapter>/<file>.flac``. Files are read
    in sorted order until at least ``num_specs`` windows have been produced. The cap
    is checked before each file, so the final size can exceed ``num_specs`` by up to
    one file's worth of windows.

    Attributes:
        mel_specs: tensor of shape (num_windows, 128, window_size).
        sample_rates: tensor of shape (num_windows,) holding each window's sample rate.
    """

    def __init__(self, path, window_size = 256, step_size = 128, num_specs = 7647*2):
        # Collect per-file tensors and concatenate once at the end; concatenating onto
        # a growing tensor inside the loop copies all previous data on every step.
        spec_chunks = []
        rate_chunks = []
        num_produced = 0
        num_files_opened = 0

        for flac_path in self._iter_flac_files(path):
            # stop walking the tree once the desired number of spectrograms is reached
            if num_produced >= num_specs:
                break

            # get audio file (resampled to 44.1 kHz) and convert to log mel spectrogram
            speech, rate = librosa.load(flac_path, sr = 44100)
            mel_spec = torch.from_numpy(self._mel_spectrogram(speech, rate))
            num_files_opened += 1

            windows = self._slice_windows(mel_spec, window_size, step_size)
            spec_chunks.append(windows)
            rate_chunks.append(torch.full((windows.shape[0],), rate))
            num_produced += windows.shape[0]

            if num_files_opened % 50 == 0:
                print("opened " + str(num_files_opened) + " files and produced " + str(num_produced) + " spectrograms")

        if spec_chunks:
            self.mel_specs = torch.cat(spec_chunks, 0)
            self.sample_rates = torch.cat(rate_chunks, 0)
        else:
            # nothing found: keep the documented shapes, just with zero windows
            self.mel_specs = torch.zeros(0, 128, window_size)
            self.sample_rates = torch.zeros(0, dtype = torch.int64)

    @staticmethod
    def _iter_flac_files(path):
        """Yield the paths of all ``.flac`` files two directory levels below ``path``.

        Entries are visited in sorted order so the selection is reproducible, and
        anything that is not a directory at the speaker or chapter level is skipped.
        """
        for speaker_dir in sorted(os.listdir(path)):
            speaker_path = os.path.join(path, speaker_dir)
            if not os.path.isdir(speaker_path):
                continue
            for chapter_dir in sorted(os.listdir(speaker_path)):
                chapter_path = os.path.join(speaker_path, chapter_dir)
                if not os.path.isdir(chapter_path):
                    continue
                for file in sorted(os.listdir(chapter_path)):
                    if file.endswith('.flac'):
                        yield os.path.join(chapter_path, file)

    @staticmethod
    def _slice_windows(spec, window_size, step_size):
        """Cut a (n_bins, n_frames) spectrogram into windows of ``window_size`` frames.

        One pass is made per start offset 0, step_size, 2*step_size, ... (there are
        ``window_size // step_size`` passes); within a pass the windows do not overlap
        and the incomplete tail is dropped. Returns (num_windows, n_bins, window_size).
        """
        n_bins = spec.shape[0]
        passes = []
        for step in range(window_size // step_size):
            # each pass starts step_size frames later than the previous one
            cropped = spec[:, step * step_size:]
            num_slices = cropped.shape[1] // window_size

            # chop off the last bit so that the number of frames is a multiple of window_size
            cropped = cropped[:, 0:num_slices * window_size]

            # reshape into chunks of size n_bins x window_size; first dimension is the chunk index
            passes.append(torch.transpose(torch.reshape(cropped, (n_bins, num_slices, window_size)), 0, 1))

        if not passes:
            return spec.new_zeros((0, n_bins, window_size))
        return torch.cat(passes, 0)

    def __len__(self):
        return self.mel_specs.shape[0]

    def __getitem__(self, ndx):
        # returns tuple (mel spectrogram of speech, rate)
        return self.mel_specs[ndx], self.sample_rates[ndx]

    def _mel_spectrogram(self, audio, rate):
        # compute the log-mel-spectrogram of the audio at the given sample rate
        return librosa.power_to_db(librosa.feature.melspectrogram(y = audio, sr = rate))

/content/drive/.shortcut-targets-by-id/14hS-tzi4BinLtBzv8urx6eA2XOhsMNJI/git_projects


In [None]:
# Folder on Drive that holds the pre-built dataset files.
path = "/content/drive/My Drive/git_projects/spring_2025_dl_audio_project_data/"


# The string below is the path to the saved LibriSpeechDataset in your Drive
librispeechDataset_path = path + "LibriSpeechDataset_withOverlap.pt"


# weights_only=False makes torch.load unpickle full Python objects. Only load files
# from a trusted source; unpickling also needs the saved object's class to be defined
# in this session.
librispeech_dataset = torch.load(librispeechDataset_path, weights_only=False)


# The string below is the path to the saved MusdbDataset in your Drive
musdbDataset_path = path + "musdb_withOverlap_train.pt"
musdb_dataset = torch.load(musdbDataset_path, weights_only=False)

In [None]:
# This fixes the problem with the sample rates
#musdb_dataset.sample_rates = torch.full((len(musdb_dataset),), 44100)
#librispeech_dataset.sample_rates = torch.full((len(musdb_dataset),), 44100)

# Because of the way the librispeech dataset was constructed, it is slightly longer
# than the musdb dataset. Crop the librispeech dataset with these lines so both have
# the same number of items. Slicing past the end is a no-op, so re-running this cell
# is harmless, but it does not lengthen librispeech if it is the shorter of the two.
librispeech_dataset.mel_specs = librispeech_dataset.mel_specs[0:len(musdb_dataset)]
librispeech_dataset.sample_rates = librispeech_dataset.sample_rates[0:len(musdb_dataset)]

### Explore these datasets

In [None]:
# --- Explore the Datasets ---
# Sanity check: print sizes and the shape of the first item of each loaded dataset.
print("=== MusDB Dataset Exploration ===")
print("Length:", len(musdb_dataset))
print("mel_specs shape:", musdb_dataset.mel_specs.shape)
print("sample_rates shape:", musdb_dataset.sample_rates.shape)
print()
# A MusDB item unpacks into three values: accompaniment, vocal and sample rate.
accompaniment, vocal, sample_rate = musdb_dataset[0]
print("Sample 0 - Accompaniment shape:", accompaniment.size())
print("Sample 0 - Vocal shape:", vocal.size())
print("Sample 0 - Sample rate:", sample_rate)
print()

print("=== LibriSpeech Dataset Exploration ===")
print("Length:", len(librispeech_dataset))
print("mel_specs shape:", librispeech_dataset.mel_specs.shape)
print("sample_rates shape:", librispeech_dataset.sample_rates.shape)
print()
# A LibriSpeech item unpacks into two values: speech and sample rate.
speech, sample_rate = librispeech_dataset[0]
print("Sample 0 - Speech shape:", speech.size())
print("Sample 0 - Sample rate:", sample_rate)

## Dataset Helpers Explanation
Why New Dataset Helpers?

We have created new dataset helper classes (i.e., AccompanimentData, VocalData, and SpeechData) so that we can control how the data is padded and later shuffled.

- **Separation of Data:**
We separated the vocal and accompaniment data from the MusDB dataset. In our experiments, we might want to shuffle the speech data independently of the combined music data.

- **Shuffling Considerations:**
For the vocal and accompaniment data, we want to maintain their pairing so that they are shuffled in the same order. In contrast, we want the speech data to be shuffled independently.

- **Future Extensions:**
In the future, we may add another helper class that combines the vocal and accompaniment data to ensure synchronized shuffling in our data loaders.

This modular approach gives us flexibility in handling and preprocessing the data for our GAN training.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

class AccompanimentData(Dataset):
    """Serves the accompaniment spectrogram of each item, raw and zero-padded.

    Each item is a dict with the keys ``"no_pad"`` (the spectrogram as stored) and
    ``"pad"`` (the same tensor zero-padded along its last axis to ``output_length``).

    Args:
        musdb_dataset: Indexable dataset whose items unpack as
            (accompaniment, vocal, sample_rate).
        output_length: Target size of the last axis of the padded tensor.
    """

    def __init__(self, musdb_dataset, output_length=289):
        self.musdb_dataset = musdb_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.musdb_dataset)

    def __getitem__(self, index):
        accompaniment, _, _ = self.musdb_dataset[index]
        delta = self.output_length - accompaniment.size(-1)

        # When the input is already at least output_length wide there is nothing to
        # pad: the tensor is returned unchanged under "pad" as well.
        accompaniment_pad = accompaniment
        if delta > 0:
            # Split the padding evenly; an odd remainder goes to the front.
            right_pad_len = delta // 2
            left_pad_len = delta - right_pad_len
            accompaniment_pad = F.pad(accompaniment,
                                      (left_pad_len, right_pad_len),
                                      "constant", 0)
        return {"no_pad" : accompaniment, "pad" : accompaniment_pad}


class VocalData(Dataset):
    """Serves the vocal spectrogram of each item, raw and zero-padded.

    Each item is a dict with the keys ``"no_pad"`` (the spectrogram as stored) and
    ``"pad"`` (the same tensor zero-padded along its last axis to ``output_length``).

    Args:
        musdb_dataset: Indexable dataset whose items unpack as
            (accompaniment, vocal, sample_rate).
        output_length: Target size of the last axis of the padded tensor.
    """

    def __init__(self, musdb_dataset, output_length=289):
        self.musdb_dataset = musdb_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.musdb_dataset)

    def __getitem__(self, index):
        _, vocal, _ = self.musdb_dataset[index]
        delta = self.output_length - vocal.size(-1)

        # When the input is already at least output_length wide there is nothing to
        # pad: the tensor is returned unchanged under "pad" as well.
        vocal_pad = vocal
        if delta > 0:
            # Split the padding evenly; an odd remainder goes to the front.
            right_pad_len = delta // 2
            left_pad_len = delta - right_pad_len
            vocal_pad = F.pad(vocal, (left_pad_len, right_pad_len), "constant", 0)
        return {"no_pad" : vocal, "pad" : vocal_pad}



class SpeechData(Dataset):
    """Serves the speech spectrogram of each item, raw and zero-padded.

    Each item is a dict with the keys ``"no_pad"`` (the spectrogram as stored) and
    ``"pad"`` (the same tensor zero-padded along its last axis to ``output_length``).

    Args:
        librispeech_dataset: Indexable dataset whose items unpack as
            (speech, sample_rate).
        output_length: Target size of the last axis of the padded tensor.
    """

    def __init__(self, librispeech_dataset, output_length=289):
        self.librispeech_dataset = librispeech_dataset
        self.output_length = output_length

    def __len__(self):
        return len(self.librispeech_dataset)

    def __getitem__(self, index):
        speech, _ = self.librispeech_dataset[index]
        # If speech has multiple slices, pick the first slice
        if speech.dim() == 3:
            speech = speech[0]
        delta = self.output_length - speech.size(-1)

        # When the input is already at least output_length wide there is nothing to
        # pad: the tensor is returned unchanged under "pad" as well.
        speech_pad = speech
        if delta > 0:
            # Split the padding evenly; an odd remainder goes to the front.
            right_pad_len = delta // 2
            left_pad_len = delta - right_pad_len
            speech_pad = F.pad(speech, (left_pad_len, right_pad_len), "constant", 0)
        return {"no_pad" : speech, "pad" : speech_pad}


In [None]:
# print(AccompanimentData(musdb_dataset)[0])
# print(VocalData(musdb_dataset)[0])
# print(SpeechData(librispeech_dataset)[0])

### DataLoader Explanation
What is a DataLoader and Why Do We Need It?

A DataLoader in PyTorch is a utility that wraps a dataset and provides:

- **Batching:** It divides your dataset into batches so that you can train your models with mini-batch gradient descent.

- **Shuffling:** It shuffles the data at every epoch (if specified) to help reduce overfitting and ensure the model sees a diverse set of examples.

- **Parallel Data Loading:** It can load data in parallel using multiple worker processes, speeding up training.

In our case, we create separate DataLoaders for:

- The accompaniment data (paired with vocals) from the MusDB dataset.

- The vocal data (paired with accompaniment) from the MusDB dataset.

- The speech data from the LibriSpeech dataset.

This lets us shuffle the speech data independently, while keeping the vocal/accompaniment pairs synchronized during training.

In [None]:


# # Print how many batches each DataLoader contains
# print("Accompaniment loader length:", len(accompaniment_loader))
# print("Vocal loader length:", len(vocal_loader))
# print("Speech loader length:", len(speech_loader))

# # Optionally, fetch and print the shape of the first batch
# accompaniment_batch = next(iter(accompaniment_loader))
# vocal_batch = next(iter(vocal_loader))
# speech_batch = next(iter(speech_loader))
# print(accompaniment_batch["pad"])

# print("Accompaniment first batch shape:", accompaniment_batch.shape)
# print("Vocal first batch shape:", vocal_batch.shape)
# print("Speech first batch shape:", speech_batch.shape)


## Second Generator Model
Here we initialize the second generator model, whose purpose is to convert the generated vocals back to normal speech for the cycle GAN. We again use Wave-U-Net, but with a different configuration. The main difference is that we do not input the music along with the vocal track.

## Transform Input to generator_2

The output of the generator model is a (batch_size, 128, 257) tensor, while `generator_2` expects a tensor of size (batch_size, 128, 289). We therefore pad the last dimension with 16 zeros on each side.

In [None]:
def transform_for_gen_2(batch, output_length = 289):
  """Zero-pad the last axis of ``batch`` up to ``output_length``.

  The padding is split evenly between both ends; when the amount is odd the extra
  element goes to the front. A batch that is already at least ``output_length``
  wide is returned unchanged.
  """
  missing = output_length - batch.size(-1)
  if missing <= 0:
    return batch

  trailing = missing // 2
  leading = missing - trailing
  return F.pad(batch, (leading, trailing), mode="constant", value=0)


## Train the Cycle GAN
The models are ``generator`` and ``discriminator`` and ``generator_2``.


In [None]:
import IPython.display as ipd
def convert_to_audio(mel_spectrograms, n_fft=2048, hop_length=512, power=2.0, n_iter=32,
                     sr=44100, from_db=False):
    """Reconstruct waveforms from mel spectrograms with Griffin-Lim.

    Args:
        mel_spectrograms: Iterable of torch tensors, one mel spectrogram each.
        n_fft: FFT size used when inverting the mel filter bank.
        hop_length: Hop length passed to Griffin-Lim.
        power: Exponent of the mel spectrogram passed to ``mel_to_stft``.
        n_iter: Number of Griffin-Lim iterations.
        sr: Sample rate the mel spectrograms were computed at.
        from_db: Set to True when the inputs are in decibels, i.e. the output of
            ``librosa.power_to_db``; they are converted back to power first.
            ``mel_to_stft`` expects a non-negative mel spectrogram, so dB-scaled
            input must not be passed with ``from_db=False``.

    Returns:
        List of NumPy waveforms, one per input spectrogram.
    """
    audio_files = []

    for mel_spec in mel_spectrograms:
        mel_np = mel_spec.detach().cpu().numpy()
        if from_db:
            # undo librosa.power_to_db before inverting the mel filter bank
            mel_np = librosa.db_to_power(mel_np)

        # Convert Mel spectrogram back to linear spectrogram
        linear_spec = librosa.feature.inverse.mel_to_stft(
            mel_np, sr=sr, n_fft=n_fft, power=power
        )

        # Apply Griffin-Lim algorithm for phase reconstruction
        audio = librosa.griffinlim(
            linear_spec, hop_length=hop_length, n_iter=n_iter
        )
        audio_files.append(audio)

    return audio_files

def display_audio(audio_file, rate=44100):
    """Show an inline audio player for ``audio_file`` in the notebook.

    ``audio_file`` is passed straight to ``IPython.display.Audio`` together with the
    playback sample rate ``rate``.
    """
    ipd.display(ipd.Audio(audio_file, rate=rate))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# ----- Single Epoch Training Function -----
def train_epoch(
    generator,
    generator_2,
    discriminator,
    optimizer_D,
    optimizer_G,
    optimizer_G2,
    accompaniment_loader,
    vocal_loader,
    speech_loader,
    l1_loss,
    lambda_l1,
    lambda_cycle,
    adversarial_loss,
    device,
    virtual_batch_size=1,
    save_output=False,
):
    """Run one epoch of cycle-GAN training and return the averaged metrics.

    Per batch:
      1. The discriminator is trained on (real vocal, accompaniment) against
         (generated vocal, accompaniment), but only when its loss exceeds 0.5.
      2. Both generators are trained on a convex combination of the adversarial
         loss and the cycle-consistency loss (speech -> vocal -> speech).

    Generator gradients are accumulated over ``virtual_batch_size`` batches before
    each optimizer step; a partial group left at the end of the epoch is also stepped.

    Args:
        generator: Callable returning a dict with key ``"vocal"``.
        generator_2: Callable returning a dict with key ``"speech"``.
        discriminator: Module called as ``discriminator(vocal, accompaniment)``.
        optimizer_D, optimizer_G, optimizer_G2: Optimizers for the three models.
        accompaniment_loader, vocal_loader, speech_loader: Loaders yielding dicts
            with keys ``"pad"`` and ``"no_pad"``; they are iterated in lock-step.
        l1_loss: Loss used for cycle consistency.
        lambda_l1: Currently unused; kept for interface compatibility.
        lambda_cycle: Weight of the cycle loss in the convex combination.
        adversarial_loss: Loss applied to discriminator outputs.
        device: Device the batches are moved to.
        virtual_batch_size: Number of batches per generator optimizer step.
        save_output: Currently unused; kept for interface compatibility.

    Returns:
        Dict with keys ``loss_D``, ``loss_G_adv``, ``loss_cycle``, ``loss_G``,
        ``avg_grad_norm_D`` and ``avg_grad_norm_G`` (all zeros for an empty epoch).
        ``avg_grad_norm_D`` only covers batches in which the discriminator was updated.

    NOTE(review): if ``discriminator`` is the MiniRocket-based class from this
    notebook, its features are computed in NumPy on detached tensors, so
    ``loss_G_adv`` cannot back-propagate into ``generator``; only the cycle loss
    would train the generators.
    """
    # guard against 0 or negative values, which would break the modulo below
    virtual_batch_size = max(1, int(virtual_batch_size))

    total_loss_D = total_loss_G = total_loss_G_adv = total_loss_cycle = 0.0
    # Gradient norms per batch, for diagnosing vanishing gradients.
    grad_norms_D = []
    grad_norms_G = []
    num_batches = 0

    # ---- batch loop ----
    for ((accomp, voc), speech) in tqdm(
        zip(zip(accompaniment_loader, vocal_loader), speech_loader),
        desc="Training Batches"
    ):
        # Move data to device
        x_acc = accomp["pad"].float().to(device)       # [B,128,289]
        x_speech = speech["pad"].float().to(device)    # [B,128,289]
        x_in = torch.cat([x_speech, x_acc], dim=1)     # [B,256,289]

        # Noisy (smoothed) discriminator labels
        B = x_acc.size(0)
        real_labels = torch.rand(B, 1, device=device) * 0.2 + 0.8  # [0.8,1.0]
        fake_labels = torch.rand(B, 1, device=device) * 0.2        # [0.0,0.2]

        # 1) Train D
        optimizer_D.zero_grad()
        acc_np = accomp["no_pad"].float().to(device)
        voc_np = voc["no_pad"].float().to(device)

        pred_real = discriminator(voc_np, acc_np)
        loss_D_real = adversarial_loss(pred_real, real_labels)

        fake = generator(x_in)["vocal"]
        # keep the first 256 frames so the fake matches the width of the real vocals
        fake_crop = torch.narrow(fake, 2, 0, 256)

        # detach: the discriminator update must not push gradients into the generator
        pred_fake = discriminator(fake_crop.detach(), acc_np)
        loss_D_fake = adversarial_loss(pred_fake, fake_labels)

        loss_D = 0.5 * (loss_D_real + loss_D_fake)
        # only update D while it is still doing poorly, so it does not overpower G
        if loss_D.item() > 0.5:
            loss_D.backward()
            # measure here, before the generator backward pass adds to D's gradients
            norm_D = _mean_grad_norm(discriminator)
            if norm_D is not None:
                grad_norms_D.append(norm_D)
            optimizer_D.step()

        # 2) Train G & G2
        # start of an accumulation group: clear gradients left from the previous one
        if num_batches % virtual_batch_size == 0:
            optimizer_G.zero_grad()
            optimizer_G2.zero_grad()

        # score the same cropped fake that the discriminator was trained on
        pred_for_G = discriminator(fake_crop, acc_np)
        loss_G_adv = adversarial_loss(pred_for_G, real_labels)

        # cycle-consistency: generated vocal -> generator_2 -> reconstructed speech
        fake_pad = transform_for_gen_2(fake)
        rec = generator_2(fake_pad)["speech"]
        rec_crop = torch.narrow(rec, 2, 0, 256)
        speech_np = speech["no_pad"].float().to(device)

        loss_cycle = l1_loss(rec_crop, speech_np)

        # convex combination
        loss_G = (loss_G_adv + lambda_cycle * loss_cycle) / (1 + lambda_cycle)
        # scale so the accumulated gradient is the mean over the group
        (loss_G / virtual_batch_size).backward()

        norm_G = _mean_grad_norm(generator)
        if norm_G is not None:
            grad_norms_G.append(norm_G)

        if (num_batches + 1) % virtual_batch_size == 0:
            optimizer_G.step()
            optimizer_G2.step()

        # Accumulate metrics
        total_loss_D      += loss_D.item()
        total_loss_G_adv  += loss_G_adv.item()
        total_loss_cycle  += loss_cycle.item()
        total_loss_G      += loss_G.item()
        num_batches       += 1

    # apply the gradients of an incomplete final accumulation group
    if num_batches % virtual_batch_size != 0:
        optimizer_G.step()
        optimizer_G2.step()

    # avoid dividing by zero when the loaders yielded no batches
    denom = max(num_batches, 1)
    return {
        "loss_D":      total_loss_D / denom,
        "loss_G_adv":  total_loss_G_adv / denom,
        "loss_cycle":  total_loss_cycle / denom,
        "loss_G":      total_loss_G / denom,
        "avg_grad_norm_D": sum(grad_norms_D) / len(grad_norms_D) if grad_norms_D else 0.0,
        "avg_grad_norm_G": sum(grad_norms_G) / len(grad_norms_G) if grad_norms_G else 0.0,
    }


def _mean_grad_norm(module):
    """Mean gradient norm over the parameters of ``module`` that have a gradient, or None."""
    norms = [p.grad.norm().item() for p in module.parameters() if p.grad is not None]
    return sum(norms) / len(norms) if norms else None

# ----- Multi-Epoch Training Function -----
def train(
    generator,
    generator_2,
    discriminator,
    optimizer_D,
    optimizer_G,
    optimizer_G2,
    accompaniment_loader,
    vocal_loader,
    speech_loader,
    l1_loss,
    lambda_l1,
    lambda_cycle,
    adversarial_loss,
    device,
    num_epochs,
    virtual_batch_size,
    log_dir,
    save_audio = True
):
    """Run `train_epoch` for `num_epochs` epochs and log the per-epoch metrics.

    After every epoch the averaged metrics returned by `train_epoch` are
    printed and written to TensorBoard, using the epoch index as the step.

    Args:
        generator, generator_2, discriminator: models, forwarded to `train_epoch`.
        optimizer_D, optimizer_G, optimizer_G2: optimizers, forwarded to `train_epoch`.
        accompaniment_loader, vocal_loader, speech_loader: data loaders,
            forwarded to `train_epoch`.
        l1_loss, adversarial_loss: loss callables, forwarded to `train_epoch`.
        lambda_l1, lambda_cycle: loss weights, forwarded to `train_epoch`.
        device: device handle, forwarded to `train_epoch`.
        num_epochs: number of epochs to run.
        virtual_batch_size: forwarded to `train_epoch`.
        log_dir: directory handed to the TensorBoard `SummaryWriter`.
        save_audio: when truthy, `train_epoch` is called with
            `save_output=True` on the final epoch only; when falsy,
            `save_output` is False for every epoch.

    Returns:
        None.
    """
    writer = SummaryWriter(log_dir=log_dir)

    # try/finally so the event file is flushed and closed even when an epoch
    # raises (out-of-memory, keyboard interrupt, ...).
    try:
        for epoch in range(num_epochs):
            # Honour the caller's flag instead of overwriting it: output is
            # requested on the last epoch only, and only if the caller wants it.
            save_output = bool(save_audio) and epoch == num_epochs - 1

            print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")
            epoch_metrics = train_epoch(
                generator,
                generator_2,
                discriminator,
                optimizer_D,
                optimizer_G,
                optimizer_G2,
                accompaniment_loader,
                vocal_loader,
                speech_loader,
                l1_loss,
                lambda_l1,
                lambda_cycle,
                adversarial_loss,
                device,
                virtual_batch_size,
                save_output = save_output
            )

            print(f"Epoch {epoch+1} Metrics:")
            print(f"  Loss_D:         {epoch_metrics['loss_D']:.4f}")
            print(f"  Loss_G_adv:     {epoch_metrics['loss_G_adv']:.4f}")
            print(f"  Loss_Cycle:     {epoch_metrics['loss_cycle']:.4f}")
            print(f"  Grad Norm D:    {epoch_metrics['avg_grad_norm_D']:.4f}")
            print(f"  Grad Norm G:    {epoch_metrics['avg_grad_norm_G']:.4f}")

            # Log metrics to TensorBoard.
            writer.add_scalar("Loss/Discriminator", epoch_metrics["loss_D"], epoch)
            writer.add_scalar("Loss/Generator_adversarial", epoch_metrics["loss_G_adv"], epoch)
            writer.add_scalar("Loss/Cycle", epoch_metrics["loss_cycle"], epoch)
            writer.add_scalar("Gradients/Discriminator", epoch_metrics["avg_grad_norm_D"], epoch)
            writer.add_scalar("Gradients/Generator", epoch_metrics["avg_grad_norm_G"], epoch)
    finally:
        writer.close()

    return #audio_files

In [None]:
# Timestamp used to tag this run's TensorBoard directory and saved files.
now = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"runs/cycleGAN_experiment_{now}"

# ---------------- hyper-parameters in ONE place ----------------
train_parameters = dict(
    # optimisation
    lr_G=1e-4,
    lr_G2=1e-4,
    lr_D=2e-4,
    betas=(0.9, 0.999),
    # loss weights
    lambda_l1=1,
    lambda_cycle=0.001,
    # schedule
    num_epochs=20,
    virtual_batch_size=1,
    # bookkeeping
    log_dir=log_dir,
    model_dir="models",
)

# Batch size and worker count shared by all three loaders; change as needed.
batch_size = 32
num_workers = 1

# Options that are identical for every loader below.
_shared_loader_options = dict(
    batch_size=batch_size,
    drop_last=True,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=True,
)

# The two MUSDB loaders are not shuffled, the speech loader is.
# NOTE(review): presumably this keeps accompaniment and vocal batches
# aligned with each other -- confirm.
accompaniment_loader = DataLoader(
    AccompanimentData(musdb_dataset), shuffle=False, **_shared_loader_options
)
vocal_loader = DataLoader(
    VocalData(musdb_dataset), shuffle=False, **_shared_loader_options
)
speech_loader = DataLoader(
    SpeechData(librispeech_dataset), shuffle=True, **_shared_loader_options
)


In [None]:
# ----- Model, loss and optimizer setup -----

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Model configurations for generator and generator_2.
model_config_gen = {
    "num_inputs": 256,  # Two spectrograms concatenated (2 * 128 mel bins)
    "num_outputs": 128,
    "num_channels": [512*2, 512*4, 512*8],
    "instruments": ["vocal"],
    "kernel_size": 3,
    "target_output_size": 256,
    "conv_type": "normal",
    "res": "fixed",
    "separate": False,
    "depth": 1,
    "strides": 2
}
generator = Waveunet(**model_config_gen).to(device)

model_config_gen2 = {
    "num_inputs": 128,  # One spectrogram input
    "num_outputs": 128,
    "num_channels": [256*2, 256*4, 256*8],
    "instruments": ["speech"],
    "kernel_size": 3,
    "target_output_size": 256,
    "conv_type": "normal",
    "res": "fixed",
    "separate": False,
    "depth": 1,
    "strides": 2
}
generator_2 = Waveunet(**model_config_gen2).to(device)

# Check the number of cores and leave two of them free.
import multiprocessing
num_cores = multiprocessing.cpu_count()
print("Number of CPU cores:", num_cores)
# Clamp to at least one job: on machines with two or fewer cores the plain
# subtraction would give 0 or a negative value, which parallel back-ends
# either reject or interpret as "use every core".
minirocket_n_jobs = max(1, num_cores - 2)

# Instantiate the discriminator.
discriminator = TsaiMiniRocketDiscriminator().to(device)

# Assume dataloaders `accompaniment_loader`, `vocal_loader`, and `speech_loader` are defined.

# Pull one batch from each loader; only the speech batch is used below,
# to fit the discriminator's MiniRocket transform.
accompaniment_batch = next(iter(accompaniment_loader))
vocal_batch = next(iter(vocal_loader))
speech_batch = next(iter(speech_loader))["no_pad"]
print("Fitting discriminator...")
discriminator.fit_rocket(speech_batch)

# Loss functions.
adversarial_loss = nn.BCELoss().to(device)
# adversarial_loss = nn.HingeEmbeddingLoss().to(device)
l1_loss = nn.L1Loss().to(device)

# One Adam optimizer per network; learning rates and betas come from `train_parameters`.
optimizer_G  = optim.Adam(generator.parameters(),  lr=train_parameters["lr_G"],  betas=train_parameters["betas"])
optimizer_G2 = optim.Adam(generator_2.parameters(), lr=train_parameters["lr_G2"], betas=train_parameters["betas"])
optimizer_D  = optim.Adam(discriminator.parameters(), lr=train_parameters["lr_D"], betas=train_parameters["betas"])



Using device: cuda
Using valid convolutions with 289 inputs and 257 outputs
Using valid convolutions with 289 inputs and 257 outputs
Number of CPU cores: 12
Fitting discriminator...
MiniRocket fitted. Feature dimension: 9996


In [None]:
import gc

# Free memory held by earlier cells before the long training run.
gc.collect()
torch.cuda.empty_cache()

# To resume from a saved checkpoint instead of the freshly initialised
# networks, load the state dicts into the models before running this cell
# (the checkpoint-loading cell further down shows the exact calls).

# Compile the per-epoch routine; `train` looks the name up when it runs,
# so it picks up the compiled version.
train_epoch = torch.compile(train_epoch, mode="default")

# Start training.
# NOTE: `train` ends with a bare `return`, so `audio_files` is None here.
audio_files = train(
    generator=generator,
    generator_2=generator_2,
    discriminator=discriminator,
    optimizer_D=optimizer_D,
    optimizer_G=optimizer_G,
    optimizer_G2=optimizer_G2,
    accompaniment_loader=accompaniment_loader,
    vocal_loader=vocal_loader,
    speech_loader=speech_loader,
    l1_loss=l1_loss,
    lambda_l1=train_parameters["lambda_l1"],
    lambda_cycle=train_parameters["lambda_cycle"],
    adversarial_loss=adversarial_loss,
    device=device,
    num_epochs=train_parameters["num_epochs"],
    virtual_batch_size=train_parameters["virtual_batch_size"],
    log_dir=train_parameters["log_dir"],
)



=== Epoch 1/20 ===


W0421 03:31:55.842000 1534 torch/_logging/_internal.py:1089] [59/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][4].grad", "L['self'].param_groups[0]['params'][5].grad"] will be copied during cudagraphs execution.If using cudagraphs and the grad tensor addresses will be the same across runs, use torch._dynamo.decorators.mark_static_address to elide this copy.',)
('Grad tensors ["L['self'].param_groups[0]['params'][0].grad", "L['self'].param_groups[0]['params'][1].grad", "L['self'].param_groups[0]['params'][2].grad", "L['self'].param_groups[0]['params'][3].grad", "L['self'].param_groups[0]['params'][5].grad", "L['self'].param_groups[0]['params'][6].grad", "L['self'].param_groups[0]['params'][7].grad", "L['self']

Epoch 1 Metrics:
  Loss_D:         0.4842
  Loss_G_adv:     1.1048
  Loss_Cycle:     7.0766
  Grad Norm D:    2.4591
  Grad Norm G:    0.0024

=== Epoch 2/20 ===


Training Batches: 356it [03:47,  1.56it/s]


Epoch 2 Metrics:
  Loss_D:         0.4405
  Loss_G_adv:     1.1420
  Loss_Cycle:     3.9150
  Grad Norm D:    2.4319
  Grad Norm G:    0.0025

=== Epoch 3/20 ===


Training Batches: 356it [03:47,  1.56it/s]


Epoch 3 Metrics:
  Loss_D:         0.4319
  Loss_G_adv:     1.1209
  Loss_Cycle:     3.5731
  Grad Norm D:    2.3441
  Grad Norm G:    0.0028

=== Epoch 4/20 ===


Training Batches: 356it [03:47,  1.56it/s]


Epoch 4 Metrics:
  Loss_D:         0.4249
  Loss_G_adv:     1.2809
  Loss_Cycle:     3.2521
  Grad Norm D:    2.5393
  Grad Norm G:    0.0023

=== Epoch 5/20 ===


Training Batches: 356it [03:47,  1.56it/s]


Epoch 5 Metrics:
  Loss_D:         0.4298
  Loss_G_adv:     1.2300
  Loss_Cycle:     2.9690
  Grad Norm D:    2.4777
  Grad Norm G:    0.0025

=== Epoch 6/20 ===


Training Batches: 356it [03:48,  1.56it/s]


Epoch 6 Metrics:
  Loss_D:         0.4362
  Loss_G_adv:     1.2028
  Loss_Cycle:     2.6877
  Grad Norm D:    2.4459
  Grad Norm G:    0.0024

=== Epoch 7/20 ===


Training Batches: 356it [03:47,  1.56it/s]


Epoch 7 Metrics:
  Loss_D:         0.4389
  Loss_G_adv:     1.1794
  Loss_Cycle:     2.5222
  Grad Norm D:    2.4191
  Grad Norm G:    0.0025

=== Epoch 8/20 ===


Training Batches: 356it [03:47,  1.56it/s]


Epoch 8 Metrics:
  Loss_D:         0.4343
  Loss_G_adv:     1.3114
  Loss_Cycle:     2.3966
  Grad Norm D:    2.5415
  Grad Norm G:    0.0025

=== Epoch 9/20 ===


Training Batches: 356it [03:48,  1.56it/s]


Epoch 9 Metrics:
  Loss_D:         0.4308
  Loss_G_adv:     1.2529
  Loss_Cycle:     2.2958
  Grad Norm D:    2.4512
  Grad Norm G:    0.0026

=== Epoch 10/20 ===


Training Batches: 356it [03:49,  1.55it/s]


Epoch 10 Metrics:
  Loss_D:         0.4315
  Loss_G_adv:     1.2226
  Loss_Cycle:     2.2390
  Grad Norm D:    2.4098
  Grad Norm G:    0.0029

=== Epoch 11/20 ===


Training Batches: 356it [03:49,  1.55it/s]


Epoch 11 Metrics:
  Loss_D:         0.4338
  Loss_G_adv:     1.1986
  Loss_Cycle:     2.1319
  Grad Norm D:    2.3752
  Grad Norm G:    0.0027

=== Epoch 12/20 ===


Training Batches: 356it [03:50,  1.55it/s]


Epoch 12 Metrics:
  Loss_D:         0.4355
  Loss_G_adv:     1.1898
  Loss_Cycle:     2.1117
  Grad Norm D:    2.3683
  Grad Norm G:    0.0029

=== Epoch 13/20 ===


Training Batches: 356it [03:50,  1.55it/s]


Epoch 13 Metrics:
  Loss_D:         0.4385
  Loss_G_adv:     1.1687
  Loss_Cycle:     2.0194
  Grad Norm D:    2.3345
  Grad Norm G:    0.0028

=== Epoch 14/20 ===


Training Batches: 356it [03:51,  1.54it/s]


Epoch 14 Metrics:
  Loss_D:         0.4458
  Loss_G_adv:     1.0294
  Loss_Cycle:     2.0288
  Grad Norm D:    2.1235
  Grad Norm G:    0.0033

=== Epoch 15/20 ===


Training Batches: 356it [03:50,  1.54it/s]


Epoch 15 Metrics:
  Loss_D:         0.4448
  Loss_G_adv:     1.0268
  Loss_Cycle:     2.0092
  Grad Norm D:    2.1083
  Grad Norm G:    0.0035

=== Epoch 16/20 ===


Training Batches: 356it [03:50,  1.55it/s]


Epoch 16 Metrics:
  Loss_D:         0.4441
  Loss_G_adv:     1.0236
  Loss_Cycle:     1.9993
  Grad Norm D:    2.0914
  Grad Norm G:    0.0037

=== Epoch 17/20 ===


Training Batches: 356it [03:49,  1.55it/s]


Epoch 17 Metrics:
  Loss_D:         0.4452
  Loss_G_adv:     1.0175
  Loss_Cycle:     1.9056
  Grad Norm D:    2.0842
  Grad Norm G:    0.0034

=== Epoch 18/20 ===


Training Batches: 356it [03:49,  1.55it/s]


Epoch 18 Metrics:
  Loss_D:         0.4462
  Loss_G_adv:     1.0191
  Loss_Cycle:     1.7466
  Grad Norm D:    2.0875
  Grad Norm G:    0.0023

=== Epoch 19/20 ===


Training Batches: 356it [03:49,  1.55it/s]


Epoch 19 Metrics:
  Loss_D:         0.4327
  Loss_G_adv:     1.1646
  Loss_Cycle:     1.7541
  Grad Norm D:    2.2819
  Grad Norm G:    0.0028

=== Epoch 20/20 ===


Training Batches: 356it [03:49,  1.55it/s]

Epoch 20 Metrics:
  Loss_D:         0.4369
  Loss_G_adv:     1.1427
  Loss_Cycle:     1.6998
  Grad Norm D:    2.2500
  Grad Norm G:    0.0026





In [None]:
# Timestamp suffix of the checkpoint files to restore, and the directory they live in.
state_time = "20250421-033033"
model_dir  = "models"

# ----  load the raw weight dictionaries ----
disc_state = torch.load(f"{model_dir}/discriminator_state_dict_{state_time}.pt", map_location=device)
gen_state  = torch.load(f"{model_dir}/generator_state_dict_{state_time}.pt",         map_location=device)
gen2_state = torch.load(f"{model_dir}/generator_2_state_dict_{state_time}.pt",       map_location=device)

# ----   rebuild the model objects ----
# NOTE(review): the rebuilt discriminator has not had `fit_rocket` called on
# it; confirm that the loaded state dict restores everything it needs.
discriminator = TsaiMiniRocketDiscriminator().to(device)
generator     = Waveunet(**model_config_gen).to(device)
generator_2   = Waveunet(**model_config_gen2).to(device)

# ---- rebind the optimizers to the rebuilt models ----
# The optimizers created during setup hold references to the parameters of the
# previous model objects. Without this step, any training run after this cell
# would compute gradients on the new models but step the old, discarded weights.
optimizer_G  = optim.Adam(generator.parameters(),  lr=train_parameters["lr_G"],  betas=train_parameters["betas"])
optimizer_G2 = optim.Adam(generator_2.parameters(), lr=train_parameters["lr_G2"], betas=train_parameters["betas"])
optimizer_D  = optim.Adam(discriminator.parameters(), lr=train_parameters["lr_D"], betas=train_parameters["betas"])

# ---- load the weights into those models ----
discriminator.load_state_dict(disc_state)
generator.load_state_dict(gen_state)
generator_2.load_state_dict(gen2_state)


Using valid convolutions with 289 inputs and 257 outputs
Using valid convolutions with 289 inputs and 257 outputs


<All keys matched successfully>

In [None]:
# Held-out splits stored on Google Drive.
# NOTE: weights_only=False unpickles arbitrary Python objects -- only load
# files from a trusted source.
_test_data_root = "/content/drive/MyDrive/git_projects/spring_2025_dl_audio_project_data"
librispeech_test = torch.load(f"{_test_data_root}/librispeech_longClip_test.pt", weights_only=False)
musdb_test = torch.load(f"{_test_data_root}/musdb_noOverlap_test.pt", weights_only=False)

In [None]:
# Wrap the held-out splits in the same dataset classes that feed the training loaders.
vocal_data_test, acc_data_test, speech_data_test = (
    VocalData(musdb_test),
    AccompanimentData(musdb_test),
    SpeechData(librispeech_test),
)



In [None]:
# Pick one accompaniment / vocal example (index 9) and one speech example
# (index 18) and move them to the training device as float32 tensors.
test_acc = acc_data_test[9]["pad"].to(device=device, dtype=torch.float32)
test_voc = vocal_data_test[9]["pad"].to(device=device, dtype=torch.float32)
# Keep the first 289 frames of the speech spectrogram (the printed shape is [128, 289]).
test_speech = librispeech_test[18][0][:, :289].to(device=device, dtype=torch.float32)

print(test_speech.shape)

# Sample rate handed to librosa and to the audio players below.
# NOTE(review): assumes the mel spectrograms were computed at 44.1 kHz -- confirm.
sr = 44100

# NumPy copies on the CPU for librosa.
test_acc_np = test_acc.detach().cpu().numpy()
test_voc_np = test_voc.detach().cpu().numpy()
test_speech_np = test_speech.detach().cpu().numpy()

# Invert each mel spectrogram back to a waveform for listening.
acc = librosa.feature.inverse.mel_to_audio(M=test_acc_np, sr=sr)
voc = librosa.feature.inverse.mel_to_audio(M=test_voc_np, sr=sr)
speech = librosa.feature.inverse.mel_to_audio(M=test_speech_np, sr=sr)

torch.Size([128, 289])


In [None]:
from IPython.display import Audio

# Listen to the accompaniment clip reconstructed from its mel spectrogram.
Audio(acc, rate=sr)

In [None]:
# Inference only: disable autograd so the forward pass does not build (and
# keep alive) a computation graph on the GPU.
with torch.no_grad():
    # Add a batch dimension to each spectrogram and concatenate speech and
    # accompaniment along the channel axis before feeding the generator.
    generated_vocal = generator(
        torch.cat([test_speech.unsqueeze(0), test_acc.unsqueeze(0)], dim=1)
    )["vocal"]

In [None]:
def mel_to_audio(spec, sr = 44100):
    """Invert a mel-spectrogram tensor to a waveform with librosa.

    Args:
        spec: torch tensor holding the mel spectrogram; it is detached and
            copied to the CPU before conversion.
        sr: sample rate passed to librosa.

    Returns:
        Whatever `librosa.feature.inverse.mel_to_audio` returns for the
        converted array. All other inversion settings are librosa's defaults.
    """
    mel_array = spec.detach().cpu().numpy()
    return librosa.feature.inverse.mel_to_audio(M=mel_array, sr=sr)

In [None]:
# Convert the generated batch to audio and listen to its first (only) item.
audio = mel_to_audio(generated_vocal)[0]
Audio(audio, rate=sr)

In [None]:

# Cycle step: map the generated vocal back to speech with the second generator.
# `transform_for_gen_2` must already be defined in the notebook namespace.
# Inference only, so autograd is disabled to avoid building a graph.
with torch.no_grad():
    fake_pad = transform_for_gen_2(generated_vocal)
    raw_rec = generator_2(fake_pad)["speech"]

In [None]:
# Convert the first reconstructed spectrogram of the batch to audio and play it.
reconstructed_speech = mel_to_audio(raw_rec[0])
Audio(reconstructed_speech, rate=sr)

In [None]:
# Display the converted audio clips, if training produced any.
# `train` ends with a bare `return`, so `audio_files` can be None; fall back
# to an empty list instead of raising TypeError when iterating.
print(audio_files)
for audio in audio_files or []:
    print(audio)
    display_audio(audio)

In [None]:
# gc.collect()
# torch.cuda.empty_cache()

# # Start training.
# train(
#     generator,
#     generator_2,
#     discriminator,
#     optimizer_D,
#     optimizer_G,
#     optimizer_G2,
#     accompaniment_loader,
#     vocal_loader,
#     speech_loader,
#     l1_loss,
#     train_parameters["lambda_l1"],
#     train_parameters["lambda_cycle"],
#     adversarial_loss,
#     device,
#     num_epochs          = 50, #train_parameters["num_epochs"],
#     virtual_batch_size  = train_parameters["virtual_batch_size"],
#     log_dir             = train_parameters["log_dir"],
# )


## Save the models

In [None]:
# Save the current weights of all three networks, tagged with the run timestamp `now`.
path = "models/"
# Create the target directory first; torch.save does not create missing directories.
os.makedirs(path, exist_ok=True)
torch.save(generator.state_dict(), path + "generator_state_dict_" + now + ".pt")
torch.save(generator_2.state_dict(), path + "generator_2_state_dict_" + now + ".pt")
torch.save(discriminator.state_dict(), path + "discriminator_state_dict_" + now + ".pt")

# ------------- package everything to save -------------
# Hyper-parameters and both generator configurations, written out next to the weights.
export_dict = {
    "train_parameters": train_parameters,
    "model_config_gen": model_config_gen,      # Wave-U-Net (speech+accomp -> vocal)
    "model_config_gen2": model_config_gen2,    # Wave-U-Net (vocal -> speech)
}
import json

# (optional) ensure JSON‑serialisable: convert tuples → lists
def _convert(obj):
    if isinstance(obj, tuple):
        return list(obj)
    if isinstance(obj, dict):
        return {k: _convert(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_convert(x) for x in obj]
    return obj

export_dict = _convert(export_dict)

# `path` already ends with "/", so join the pieces instead of inserting a
# second separator (which produced "models//training_record_...").
with open(os.path.join(path, f"training_record_{now}.json"), "w") as fp:
    json.dump(export_dict, fp, indent=2)



In [None]:
# Fresh timestamp and TensorBoard directory for this additional training run.
now = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "runs/" + "cycleGAN_experiment_" + now

gc.collect()
torch.cuda.empty_cache()

# Start training.
train(
    generator,
    generator_2,
    discriminator,
    optimizer_D,
    optimizer_G,
    optimizer_G2,
    accompaniment_loader,
    vocal_loader,
    speech_loader,
    l1_loss,
    train_parameters["lambda_l1"],
    train_parameters["lambda_cycle"],
    adversarial_loss,
    device,
    num_epochs          = 50, #train_parameters["num_epochs"],
    virtual_batch_size  = train_parameters["virtual_batch_size"],
    # Use the directory computed above; train_parameters["log_dir"] still
    # points at the first run, whose curves would otherwise be overlaid.
    log_dir             = log_dir,
)