<a href="https://colab.research.google.com/github/matpaolacci/masked-lm-for-audio/blob/main/vq_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Configurations

## Set the paths

In [None]:
import os, sys

# Project folders on Google Drive
DATASET_HOME_DIR = '/content/drive/MyDrive/DLProj/Dataset'
UTILITIES_DIR = '/content/drive/MyDrive/DLProj/Utilities'
CHECKPOINTS_DIR = '/content/drive/MyDrive/DLProj/Checkpoints'

# Two versions of the Slakh2100 dataset live under DATASET_HOME_DIR:
#   RAW_DATASET_DIR -> the original dataset (dataset.zip and its extracted content)
#   WAV_DATASET_DIR -> the same dataset with the stems converted to .wav
RAW_DATASET_DIR, WAV_DATASET_DIR = (
    os.path.join(DATASET_HOME_DIR, sub_dir) for sub_dir in ("OriginalVersion", "WavVersion")
)

# Make the helper modules stored in UTILITIES_DIR importable
sys.path.append(UTILITIES_DIR)

## Set Configuration variables

In [None]:
# @title  {"display-mode":"form"}
PREPARE_DATASET = False # @param {"type":"boolean"}
# Colab form flag. NOTE(review): presumably gates a call to prepareDataset()
# (extract + convert to WAV + CSV indexing) in a cell not visible here — confirm.

In [None]:
# @title Choose if you want to continue to train a pre-saved model and if you want continue the training or only evaluate it {"form-width":"30%","display-mode":"form"}
EVALUATE_OR_TRAIN = "Evaluate model" # @param ["Start new training", "Evaluate model","Continue training"]
CHECKPOINT_DIRECTORY_NAME = "2024-09-17_112711" # @param {"type":"string","placeholder":"Enter the directory name 'YYYY-MM-DD_HHMMSS'"}
MODEL_EPOCH_LABEL = 4852 # @param {"type":"integer"}
PRE_SAVED_MODEL_TYPE = "BEST_RECON_ERR" # @param ["BEST_RECON_ERR","BEST_PERPLX","BOTH_PERPLX_RECON_ERR"]
EPOCHS_RE_TRAIN = 3000 # @param {"type":"integer","placeholder":"Enter the number of epochs for training"}

# Derived settings (do not edit): paths of the checkpoint selected in the form above.
# The checkpoint/hyper-parameter paths are only meaningful for "Continue training"
# and "Evaluate model"; a new training ignores the selected checkpoint.
model_checkpoints_dir = os.path.join(CHECKPOINTS_DIR, CHECKPOINT_DIRECTORY_NAME)
PATH_TO_PRE_SAVED_MODEL = os.path.join(model_checkpoints_dir, f'VQVAE_{MODEL_EPOCH_LABEL}_{PRE_SAVED_MODEL_TYPE}')
PATH_TO_HYPERPARAMETERS = os.path.join(model_checkpoints_dir, 'hyper_params.json')

# The three form options are mutually exclusive: resume training from the checkpoint,
# only evaluate it, or (both flags False) start a new training.
USE_PRE_SAVED_MODEL = EVALUATE_OR_TRAIN == "Continue training"
ONLY_EVALUATION = EVALUATE_OR_TRAIN == "Evaluate model"
assert not (USE_PRE_SAVED_MODEL and ONLY_EVALUATION), \
    "'Continue training' and 'Evaluate model' cannot be selected at the same time"

# Evaluation reads the same checkpoint selected in the form above.
EVAL_CHECKPOINT_DIR = model_checkpoints_dir
PATH_TO_MODEL_TO_EVALUATE = PATH_TO_PRE_SAVED_MODEL
PATH_TO_HYPERPARAMETERS_EVALUATION = PATH_TO_HYPERPARAMETERS

## Import python libraries

In [None]:
# Install the required python packages
!pip install -r $UTILITIES_DIR/requirements.txt

import zipfile, yaml
import flacconverter as fc
from enum import Enum
from tqdm import tqdm
from datetime import datetime, timedelta
import gc
import json

import torch, torchaudio, pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# For getting data structure from list
import ast

# Libraries for plotting
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# To plot the graphs of loss and perplexity
from scipy.signal import savgol_filter
from matplotlib.ticker import MaxNLocator

## Create checkpoints directory
A new directory, named after the start time of the run, is created for each training run. It will contain the best models saved across the epochs, the training log and a file with the hyperparameters used.

In [None]:
# Each run gets its own checkpoint folder named after the run's start time.
# The clock is shifted by a fixed +2 hours. NOTE(review): if this is meant to be a
# local timezone, a fixed offset does not follow daylight-saving changes.
_run_start = datetime.now() + timedelta(hours=2)
datetime_start = _run_start.strftime("%Y-%m-%d_%H%M%S")
CURR_CHECKPOINT_DIR = os.path.join(CHECKPOINTS_DIR, datetime_start)

# Evaluation-only sessions write nothing, so the folder is created only for training
if not ONLY_EVALUATION:
    os.makedirs(CURR_CHECKPOINT_DIR, exist_ok=True)

## Set log file and log function

In [None]:
# Training log of the current run, stored inside the run's checkpoint folder
LOG_FILE_PATH = os.path.join(CURR_CHECKPOINT_DIR, 'training.log')

def log(message: str):
    """Append the line '[YYYY-mm-dd HH:MM:SS] - <message>' to LOG_FILE_PATH.

    Does nothing in evaluation-only mode, where no checkpoint folder is created.
    The timestamp is datetime.now() shifted by a fixed +2 hours.
    """
    if not ONLY_EVALUATION:
        stamp = (datetime.now() + timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S")
        with open(LOG_FILE_PATH, "a") as log_file:
            log_file.write(f'[{stamp}] - {message}\n')

# Prepare the data
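In this section we prepare the Slakh2100 dataset (available [here](http://www.slakh.com/)). The downloaded archive is extracted and its audio is converted to WAV. Its stems and mixes are then indexed in CSV files.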

In [None]:
class Instrument(Enum):
    """Instrument classes that can be selected when indexing the dataset.

    Each value must match, character for character, an ``inst_class`` string read
    from a track's metadata.yaml, because createDatasetForVqVae compares the
    ``.value`` strings against that field to decide which stems to keep.
    """
    BASS = 'Bass'
    BRASS = 'Brass'
    CHROMATIC_PERCUSSION = 'Chromatic Percussion'
    DRUMS = 'Drums'
    ETHNIC = 'Ethnic'
    GUITAR = 'Guitar'
    ORGAN = 'Organ'
    PERCUSSIVE = 'Percussive'
    PIANO = 'Piano'
    PIPE = 'Pipe'
    REED = 'Reed'
    SOUND_EFFECTS = 'Sound Effects'
    STRINGS = 'Strings'
    STRINGS_CONTINUED = 'Strings (continued)'
    SYNTH_LEAD = 'Synth Lead'
    SYNTH_PAD = 'Synth Pad'

def extractZipDataset():
    """Unpack RAW_DATASET_DIR/dataset.zip into RAW_DATASET_DIR itself."""
    archive_path = os.path.join(RAW_DATASET_DIR, 'dataset.zip')
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(RAW_DATASET_DIR)

    # Italian for "Files extracted to ..."
    print(f'File estratti in {RAW_DATASET_DIR}')

def convertToWav(baseDir: str, outDir: str, n_threads: int = 2):
    """Convert the audio files of a dataset split in `baseDir` to .wav files in `outDir`.

    Thin wrapper around the project helper ``flacconverter.to_wav`` (presumably a
    FLAC -> WAV converter, as its name suggests). `n_threads` is forwarded to it;
    the default of 2 keeps the previous hard-coded behavior.
    """
    fc.to_wav(baseDir, outDir, n_threads=n_threads)

def createDatasetForVqVae(instruments_set: set[str], split_sets=('train', 'test', 'validation')):
    """Index the WAV version of the dataset into two semicolon-separated CSV files per split.

    For every split name in `split_sets` the following files are written into
    DATASET_HOME_DIR:

    * ``<split>_stems.csv`` with columns ``file_path;instrument_class;midi_program;track_name``:
      one row per stem whose ``inst_class`` (read from the track's metadata.yaml) is in
      `instruments_set`;
    * ``<split>_mixes.csv`` with columns ``file_path;midi_programs;instruments_classes;track_name``:
      one row per track mix, the MIDI programs and instrument classes of all the stems of
      the track being written as Python list literals.

    A split is skipped entirely when either of its CSV files already exists, so existing
    indexes are never overwritten. Audio files referenced by metadata.yaml but missing on
    disk are left out and reported at the end of each split.

    Args:
        instruments_set: instrument classes (e.g. ``Instrument.BASS.value``) whose stems are kept.
        split_sets: names of the split sub-folders of WAV_DATASET_DIR to process.
    """
    for set_dir in split_sets:
        csv_stems_file_path = os.path.join(DATASET_HOME_DIR, set_dir + "_stems.csv")
        csv_mixes_file_path = os.path.join(DATASET_HOME_DIR, set_dir + "_mixes.csv")

        # Never overwrite an existing index: skip the whole split instead
        existing_csv = next(
            (path for path in (csv_stems_file_path, csv_mixes_file_path) if os.path.exists(path)),
            None,
        )
        if existing_csv is not None:
            print(f"The file \'{existing_csv}\' already exists! Skipping the \'{set_dir}\' set (delete the file to regenerate it).")
            continue

        split_dir = os.path.join(WAV_DATASET_DIR, set_dir)

        # Just creating the label to be added to tqdm
        tqdm_description = f"Creating {os.path.basename(csv_stems_file_path)} and {os.path.basename(csv_mixes_file_path)} files"

        # Audio files referenced by metadata.yaml that do not exist on disk
        not_existing_stems = []

        # Both CSVs are opened once per split instead of once per written row
        with open(csv_stems_file_path, 'w') as csv_stems_file, \
                open(csv_mixes_file_path, 'w') as csv_mixes_file:
            csv_stems_file.write("file_path;instrument_class;midi_program;track_name\n")
            csv_mixes_file.write("file_path;midi_programs;instruments_classes;track_name\n")

            for track_dir_name in tqdm(os.listdir(split_dir), desc=tqdm_description):
                track_dir_path = os.path.join(split_dir, track_dir_name)
                if not os.path.isdir(track_dir_path):
                    continue

                with open(os.path.join(track_dir_path, "metadata.yaml"), 'r') as file:
                    yamldata = yaml.safe_load(file)

                # The MIDI programs of all the stems of the track
                midi_programs_list: list[str] = []

                # The instrument classes (e.g. guitar, piano) of the track: a set, since several
                #   MIDI programs (e.g. Electric Guitar, Classic Guitar) can share the same class
                instruments_classes_set: set[str] = set()

                for stem_name, stem_info in yamldata["stems"].items():
                    instr_class: str = stem_info["inst_class"]
                    midi_program: str = stem_info["midi_program_name"]
                    midi_programs_list.append(midi_program)
                    instruments_classes_set.add(instr_class)

                    if instr_class not in instruments_set:
                        continue

                    stem_file_path = os.path.join(track_dir_path, 'stems', stem_name) + ".wav"
                    if not os.path.exists(stem_file_path):
                        not_existing_stems.append(stem_file_path)
                        continue

                    # NOTE(review): the stems CSV stores the lower-cased track name while the
                    #   mixes CSV stores it unchanged; kept as-is for existing indexes.
                    csv_stems_file.write(f"{stem_file_path};{instr_class};{midi_program};{track_dir_name.lower()}\n")

                mixes_file_path = os.path.join(track_dir_path, 'mix.wav')
                if not os.path.exists(mixes_file_path):
                    not_existing_stems.append(mixes_file_path)
                    continue

                csv_mixes_file.write(f"{mixes_file_path};{midi_programs_list};{list(instruments_classes_set)};{track_dir_name}\n")

        # Report the missing files, if any
        if not_existing_stems:
            print(f"The following files don't exist in the {set_dir} set:")
            for err in not_existing_stems:
                print(f"{' ' * 4}{err}")


def prepareDataset():
    """Build the WAV dataset and its CSV indexes from the raw dataset archive.

    Steps: extract RAW_DATASET_DIR/dataset.zip, convert every split to .wav under
    WAV_DATASET_DIR, and finally index the Bass, Piano, Drums and Guitar stems
    (plus the mixes) with createDatasetForVqVae.
    """
    extractZipDataset()

    # Convert each split to wav (`split`, not `dir`, so the builtin is not shadowed)
    for split in ('train', 'test', 'validation'):
        convertToWav(os.path.join(RAW_DATASET_DIR, split), os.path.join(WAV_DATASET_DIR, split))

    createDatasetForVqVae({Instrument.BASS.value, Instrument.PIANO.value, Instrument.DRUMS.value, Instrument.GUITAR.value})

    print("Dataset is ready!")

In [None]:
#createDatasetForVqVae({Instrument.BASS.value, Instrument.PIANO.value, Instrument.DRUMS.value, Instrument.GUITAR.value} )

# Look at the data


## Composition of the train set

In [None]:
FIG_SIZE = (15, 7)

# Pre-rendered histograms of the three splits (only the train one is displayed here)
pathToHistogramTrain, pathToHistogramTest, pathToHistogramValidation = (
    os.path.join(RAW_DATASET_DIR, f"histograms/{split_name}.png")
    for split_name in ("train", "test", "validation")
)

trainHist = mpimg.imread(pathToHistogramTrain)
plt.figure(figsize=FIG_SIZE)
plt.imshow(trainHist)
plt.axis('off')  # it is a picture of a chart: the image axes carry no information
plt.show()

## Let's plot the waveforms of some tracks

In [None]:
def plot_waveform(waveform, graph_title: str, from_sec:int = None, to_sec: int = None, sample_rate = 44100):
    """Plot each channel of an audio waveform against time.

    Args:
        waveform: tensor of shape (channels, frames), or (frames,) for a single channel.
            It may live on any device and may require grad: it is detached and copied
            to the CPU before plotting.
        graph_title: title of the figure.
        from_sec: left limit of the x axis in seconds (default 0).
        to_sec: right limit of the x axis in seconds (default: the whole duration).
        sample_rate: samples per second, used to convert frame indices to seconds.
    """
    # Tensor.numpy() fails on CUDA tensors and on tensors that require grad (e.g. a
    # reconstruction straight out of the model), so normalise the input first.
    waveform = waveform.detach().cpu()
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    num_frames = waveform.shape[1]
    time_axis = (torch.arange(0, num_frames) / sample_rate).numpy()

    plt.figure(figsize=(15, 5))
    for i, channel in enumerate(waveform):
        plt.plot(time_axis, channel.numpy(), label=f'Channel {i+1}')

    # Visible range on the x axis
    duration_in_sec = num_frames / sample_rate
    plt.xlim(from_sec if from_sec is not None else 0, to_sec if to_sec is not None else duration_in_sec)

    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.title(graph_title)
    plt.legend()
    plt.show()

# Define the PyTorch datasets

In [None]:
# The header of the csv are
#   *_stems.csv -> ['file_path', 'instrument_class', 'midi_program', 'track_name']
#   *_mixes.csv -> ['file_path', 'midi_programs', 'instruments_classes', 'track_name']

# TODO: to be fixed. NOTE(review): the datasets below rely on this constant and never
#   compare it with the sample rate returned by torchaudio.load.
SAMPLE_RATE = 44_100  # Hz

def get_alert_message(path, audio_length):
    """Return the error message for an audio file shorter than `audio_length` seconds."""
    return 'The file "{}" is shorter than {}sec'.format(path, audio_length)

class StemsDataset(Dataset):
    """Dataset of single-instrument stems listed in a ``*_stems.csv`` index.

    Each item is the first `audio_length` seconds of one stem, i.e. the waveform
    returned by torchaudio.load (channels, frames) cut to
    ``int(SAMPLE_RATE * audio_length)`` frames.

    Args:
        csv_audio_path: semicolon-separated CSV whose first column is the audio file path.
        audio_length: seconds of audio returned per item (may be fractional).
        custom_size: if given, keep a random sample (``random_state=1``) of this many rows,
            clamped to the range [1, number of rows].
        batch_size: if given, drop trailing rows so the length is a multiple of it.
        transform: optional callable applied to the loaded waveform before cutting it.
    """

    def __init__(self, csv_audio_path, audio_length, custom_size=None, batch_size=None, transform=None):
        self.audio_csv = pd.read_csv(csv_audio_path, delimiter=';')
        if custom_size is not None:
            sample_size = max(1, min(len(self.audio_csv), custom_size))
            self.audio_csv = self.audio_csv.sample(n=sample_size, random_state=1)

        self.audio_length = audio_length
        self.transform = transform

        # In this way we ensure the last batch have the right size
        if batch_size is not None:
            assert batch_size <= len(self.audio_csv)
            self.audio_csv = self.audio_csv.head(len(self.audio_csv) - (len(self.audio_csv) % batch_size))

    def __len__(self):
        return len(self.audio_csv)

    def __getitem__(self, idx):
        audio_path = self.audio_csv.iloc[idx, 0]
        # NOTE(review): the file's own sample rate is ignored, SAMPLE_RATE is assumed
        waveform, _ = torchaudio.load(audio_path)
        # int() keeps the slice bound valid when audio_length is fractional
        num_samples = int(SAMPLE_RATE * self.audio_length)

        assert waveform.shape[1] >= SAMPLE_RATE * self.audio_length, get_alert_message(audio_path, self.audio_length)

        if self.transform:
            waveform = self.transform(waveform)

        assert waveform.shape[1]/SAMPLE_RATE >= self.audio_length, f"The transformed audio is too short ({waveform.shape[1]/SAMPLE_RATE}sec)!"

        # take the first audio_length seconds
        return waveform[:, :num_samples]

In [None]:
class MixesDataset(Dataset):
    """Dataset of full-track mixes listed in a ``*_mixes.csv`` index.

    Each item is the first `audio_length` seconds of one mix, i.e. the waveform
    returned by torchaudio.load (channels, frames) cut to
    ``int(SAMPLE_RATE * audio_length)`` frames. Only the first CSV column (the audio
    path) is read; the MIDI programs / instrument classes columns are not parsed.

    Args:
        csv_audio_path: semicolon-separated CSV whose first column is the audio file path.
        audio_length: seconds of audio returned per item (may be fractional).
        custom_size: if given, keep a random sample (``random_state=1``) of this many rows,
            clamped to the range [1, number of rows].
        batch_size: if given, drop trailing rows so the length is a multiple of it.
        transform: optional callable applied to the loaded waveform before cutting it.
    """

    def __init__(self, csv_audio_path, audio_length, custom_size=None, batch_size=None, transform=None):
        self.audio_csv = pd.read_csv(csv_audio_path, delimiter=';')
        if custom_size is not None:
            sample_size = max(1, min(len(self.audio_csv), custom_size))
            self.audio_csv = self.audio_csv.sample(n=sample_size, random_state=1)

        self.audio_length = audio_length
        self.transform = transform

        # In this way we ensure the last batch have the right size
        if batch_size is not None:
            assert batch_size <= len(self.audio_csv)
            self.audio_csv = self.audio_csv.head(len(self.audio_csv) - (len(self.audio_csv) % batch_size))

    def __len__(self):
        return len(self.audio_csv)

    def __getitem__(self, idx):
        audio_path = self.audio_csv.iloc[idx, 0]
        # NOTE(review): the file's own sample rate is ignored, SAMPLE_RATE is assumed
        waveform, _ = torchaudio.load(audio_path)
        # int() keeps the slice bound valid when audio_length is fractional
        num_samples = int(SAMPLE_RATE * self.audio_length)

        assert waveform.shape[1] >= SAMPLE_RATE * self.audio_length, get_alert_message(audio_path, self.audio_length)

        if self.transform:
            waveform = self.transform(waveform)

        assert waveform.shape[1]/SAMPLE_RATE >= self.audio_length, f"The transformed audio is too short ({waveform.shape[1]/SAMPLE_RATE}sec)!"

        # take the first audio_length seconds
        return waveform[:, :num_samples]

In [None]:
class SingleTrackTestDataset(Dataset):
    """Split one whole track into consecutive, model-sized chunks.

    The track at row `index_of_the_track` of the CSV is zero-padded at the end so its
    length is a multiple of ``int(SAMPLE_RATE * model_input_length) * model_batch_size``.
    It is then cut into chunks of ``int(SAMPLE_RATE * model_input_length)`` samples, so
    the number of chunks is a multiple of `model_batch_size`. Each item has shape
    (1, int(SAMPLE_RATE * model_input_length)).

    Attributes:
        track_name: value of the 4th CSV column of the selected row.
        padding_elements_to_add_to_end: number of zero samples appended to the track.
            NOTE: when the track length is already a multiple of the chunk size, a whole
            extra chunk of zeros is appended, so this value is always > 0.
        single_track_dataset: tensor of shape (n_chunks, int(SAMPLE_RATE * model_input_length)).
    """

    def __init__(self, csv_audio_path, index_of_the_track, model_input_length, model_batch_size):
        audio_csv = pd.read_csv(csv_audio_path, delimiter=';')

        # The message must use the local `audio_csv`: the instance has no `audio_csv`
        # attribute, and referencing one would turn this check into an AttributeError.
        assert len(audio_csv) > index_of_the_track, f"The provided index {{{index_of_the_track}}} is out of range: the dataset contains only {{{len(audio_csv)}}} tracks!"

        audio_path = audio_csv.iloc[index_of_the_track, 0]
        waveform, _ = torchaudio.load(audio_path)

        self.track_name = audio_csv.iloc[index_of_the_track, 3]

        # Samples fed to the model per item, and per full batch of items
        samples_per_input = int(SAMPLE_RATE * model_input_length)
        chunk_length = samples_per_input * model_batch_size
        remainder = waveform.shape[1] % chunk_length
        # Kept without a final "% chunk_length": the value is always > 0, which callers
        # presumably rely on (e.g. to slice the padding off the reconstruction).
        self.padding_elements_to_add_to_end = chunk_length - remainder

        # resize the waveform to adapt it to the audio_lenght set for the model and batch_size for the model
        waveform = torch.nn.functional.pad(waveform, (0, self.padding_elements_to_add_to_end))
        n_waveforms = waveform.shape[1] // samples_per_input

        # It contains n (divisible by batch_size) chunk of the selected track.
        # NOTE(review): view() assumes a mono track; with more channels the element
        #   count no longer matches and the reshape fails.
        self.single_track_dataset = waveform.view(n_waveforms, samples_per_input)

    def __len__(self):
        return self.single_track_dataset.shape[0]

    def __getitem__(self, idx):
        # Add the channel dimension expected by the model: (1, samples_per_input)
        return self.single_track_dataset[idx].unsqueeze(0)

# VQ-VAE: The model architecture

## The Vector Quantizer module
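The loss is the sum of two terms:
- **codebook loss**, which moves the selected codebook vectors towards the encoder outputs;
- **commitment loss**, weighted by `commitment_cost`, which makes the encoder outputs commit to (stay close to) their selected codebook vectors.

In the EMA variant (`VectorQuantizerEMA`), exponential moving averages update the codebook instead, so only the commitment loss is used.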

In [None]:
class VectorQuantizer(nn.Module):
    """Vector-quantization bottleneck with a codebook learned by gradient descent.

    Every embedding_dim-dimensional vector of the input sequence is replaced by its
    nearest codebook entry (squared Euclidean distance).

    Args:
        num_embeddings: number of codebook entries (K).
        embedding_dim: size of each entry; must equal the channel dimension of the input.
        commitment_cost: weight of the commitment term in the returned loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super().__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._commitment_cost = commitment_cost

        # Codebook, initialised uniformly in [-1/K, 1/K]
        self._embedding = nn.Embedding(num_embeddings, embedding_dim)
        bound = 1 / num_embeddings
        self._embedding.weight.data.uniform_(-bound, bound)

    def forward(self, inputs):
        """Quantize a (batch, channels, width) tensor.

        Returns a 4-tuple:
            * loss: codebook loss + commitment_cost * commitment loss (scalar);
            * the quantized tensor in (batch, channels, width) layout;
            * perplexity: exp of the entropy of the average code usage in the batch;
            * the same quantized values in (batch, width, channels) layout.
        Both quantized outputs carry straight-through gradients to `inputs`.
        """
        # (B, C, W) -> (B, W, C): one vector per time step, then flatten to (B*W, C)
        bwc = inputs.permute(0, 2, 1).contiguous()
        vectors = bwc.view(-1, self._embedding_dim)
        codebook = self._embedding.weight

        # Squared Euclidean distances ||x||^2 + ||e||^2 - 2 x.e, shape (B*W, K)
        sq_norm_inputs = torch.sum(vectors ** 2, dim=1, keepdim=True)
        sq_norm_codes = torch.sum(codebook.t() ** 2, dim=0, keepdim=True)
        cross = torch.matmul(vectors, codebook.t())
        distances = sq_norm_inputs + sq_norm_codes - 2 * cross

        # Index of the closest code for every vector, shape (B*W, 1)
        nearest = torch.argmin(distances, dim=1).unsqueeze(1)
        codes = self._embedding(nearest).view(bwc.shape)

        # The codebook loss pulls the codes towards the encoder outputs; the (weighted)
        # commitment loss keeps the encoder outputs close to their codes.
        commitment_loss = F.mse_loss(codes.detach(), bwc)
        codebook_loss = F.mse_loss(codes, bwc.detach())
        loss = codebook_loss + self._commitment_cost * commitment_loss

        # Straight-through estimator: the forward pass uses the codes, the backward
        # pass copies the gradient unchanged onto the encoder outputs.
        codes = bwc + (codes - bwc).detach()

        # one_hot[i, j] == 1 iff code j was picked for vector i
        one_hot = torch.zeros(nearest.shape[0], self._num_embeddings, device=inputs.device)
        one_hot.scatter_(1, nearest, 1)

        # Perplexity of the average code usage (K when all codes are used evenly)
        usage = one_hot.mean(dim=0)
        perplexity = torch.exp(-torch.sum(usage * torch.log(usage + 1e-10)))

        # (B, W, C) -> (B, C, W) for the decoder, plus the (B, W, C) version
        return loss, codes.permute(0, 2, 1).contiguous(), perplexity, codes


In [None]:
class VectorQuantizerEMA(nn.Module):
    """Vector-quantization bottleneck whose codebook is learned with exponential moving averages.

    Every embedding_dim-dimensional vector of the input sequence is replaced by its
    nearest codebook entry (squared Euclidean distance). In training mode each code is
    moved towards the EMA mean of the vectors assigned to it (EMA `decay`, Laplace
    smoothing `epsilon`) instead of being trained by gradients, so the returned loss
    only contains the commitment term.

    Args:
        num_embeddings: number of codebook entries (K).
        embedding_dim: size of each entry; must equal the channel dimension of the input.
        commitment_cost: weight of the commitment loss.
        decay: EMA decay factor.
        epsilon: Laplace-smoothing constant added to every cluster size.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        self._commitment_cost = commitment_cost
        self._decay = decay
        self._epsilon = epsilon

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) # TODO: Verifica con float64
        self._embedding.weight.data.normal_()

        # EMA statistics: number of vectors assigned to each code, and their sum
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

    def forward(self, inputs):
        """Quantize a (batch, channels, width) tensor, updating the codebook in training mode.

        Returns a 4-tuple:
            * loss: commitment_cost * commitment loss (scalar);
            * the quantized tensor in (batch, channels, width) layout;
            * perplexity: exp of the entropy of the average code usage in the batch;
            * the same quantized values in (batch, width, channels) layout.
        Both quantized outputs carry straight-through gradients to `inputs`.
        """
        # convert inputs from BCW -> BWC
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Squared Euclidean distances between the inputs and the codes
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)                  # a^2
                    + torch.sum(self._embedding.weight.t()**2, dim=0, keepdim=True) # b^2
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))     # -2ab

        # Get the index of the nearest code (from codebook) for each embedding output by encoder
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

        # Quantize and unflatten (uses the codebook as it was before this step's update)
        quantized = self._embedding(encoding_indices).view(input_shape)

        # Create a one-hot matrix where the ones at (i, j) position
        #   indicates that the j-th code, from codebook, it was selected
        #   for the i-th embedding vector output by the encoder.
        one_hot_enc = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        one_hot_enc.scatter_(1, encoding_indices, 1)

        # Use EMA to update the embedding vectors (codebook)
        if self.training:
            # The codebook is not learned by gradient, so no autograd graph is needed.
            # Updating the registered tensors in place (instead of wrapping fresh
            # nn.Parameter objects every step) keeps module.parameters() -- and any
            # optimizer built from it -- referring to the live codebook.
            with torch.no_grad():
                self._ema_cluster_size.copy_(
                    self._ema_cluster_size * self._decay
                    + (1 - self._decay) * torch.sum(one_hot_enc, 0))

                # Laplace smoothing of the cluster size (avoids division by zero for unused codes)
                n = torch.sum(self._ema_cluster_size)
                self._ema_cluster_size.copy_(
                    (self._ema_cluster_size + self._epsilon)
                    / (n + self._num_embeddings * self._epsilon) * n)

                dw = torch.matmul(one_hot_enc.t(), flat_input)
                self._ema_w.copy_(self._ema_w * self._decay + (1 - self._decay) * dw)

                # Each code becomes the EMA mean of the vectors assigned to it
                self._embedding.weight.copy_(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Only the commitment loss: the codebook is updated by EMA above
        e_latent_loss = F.mse_loss(quantized.detach(), inputs) # commitment loss
        loss = self._commitment_cost * e_latent_loss

        # Straight-through estimator trick for gradient backpropagation
        quantized = inputs + (quantized - inputs).detach()

        # Calculates the average probability of the utilization of each code from codebook
        avg_probs = torch.mean(one_hot_enc, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BWC -> BCW (the BWC version is returned as well)
        return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, quantized

## The ResNet module

In [None]:
class Residual(nn.Module):
    """Residual unit: ReLU -> Conv1d(k=3) -> ReLU -> Conv1d(k=1), summed with the input.

    Both convolutions are bias-free and keep the sequence length; the sum requires
    num_hiddens == in_channels.

    NOTE(review): the first ReLU is in-place, so running `self._block(x)` overwrites `x`
    with relu(x) before the addition. The skip connection therefore carries relu(x),
    not x. Kept as-is: presumably the saved checkpoints were trained this way.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super().__init__()

        conv_k3 = nn.Conv1d(in_channels, num_residual_hiddens,
                            kernel_size=3, stride=1, padding=1, bias=False)
        conv_k1 = nn.Conv1d(num_residual_hiddens, num_hiddens,
                            kernel_size=1, stride=1, bias=False)
        self._block = nn.Sequential(nn.ReLU(True), conv_k3, nn.ReLU(True), conv_k1)

    def forward(self, x):
        branch = self._block(x)  # also turns `x` into relu(x) in place (see class note)
        return x + branch


class ResidualStack(nn.Module):
    """`num_residual_layers` Residual units applied one after the other, followed by a ReLU."""

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super().__init__()

        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList(
            Residual(in_channels, num_hiddens, num_residual_hiddens)
            for _ in range(num_residual_layers)
        )

    def forward(self, x):
        for residual_unit in self._layers:
            x = residual_unit(x)
        return F.relu(x)

## Helper computing the channel sizes (and decoder output paddings) of the convolutional layers

In [None]:
from math import log2

def get_levels(encoder_output_embeddings, num_layers, reverse=False, model_input_sequence_length=None, padding=1, kernel_size=4, stride=2):
    '''Compute the channel sizes (and, for the decoder, the output paddings) of the
    strided convolution blocks.

    The number of channels doubles at every encoder block and ends at
    `encoder_output_embeddings`; the first encoder block always takes 1 channel.

    Args:
        encoder_output_embeddings: channels produced by the deepest block (a power of two).
        num_layers: number of halving (encoder) / doubling (decoder) blocks;
            log2(encoder_output_embeddings) must be >= num_layers.
        reverse: False for the encoder levels, True for the decoder levels.
        model_input_sequence_length: length of the model input, required when reverse=True
            to find which decoder blocks need output_padding to restore an odd length.
        padding, kernel_size, stride: parameters of the encoder Conv1d blocks.

    Returns:
        A zip iterator of
        * (in_channels, out_channels) pairs in encoder order when reverse=False;
        * (in_channels, out_channels, needs_output_padding) triples in decoder order
          (deepest block first) when reverse=True.

    Example with encoder_output_embeddings=128, num_layers=4, model_input_sequence_length=132
    (lengths through the encoder: 132 -> 66 -> 33 -> 16 -> 8, the 33 -> 16 step drops a sample):

        >>> list(get_levels(128, 4))
        [(1, 16), (16, 32), (32, 64), (64, 128)]
        >>> list(get_levels(128, 4, reverse=True, model_input_sequence_length=132))
        [(128, 64, False), (64, 32, True), (32, 16, False), (16, 1, False)]
    '''
    assert log2(encoder_output_embeddings) - (int(log2(encoder_output_embeddings))) == 0
    assert log2(encoder_output_embeddings) >= num_layers
    assert model_input_sequence_length is None or model_input_sequence_length >= encoder_output_embeddings

    # The decoder paddings can only be computed from the input sequence length
    assert (not reverse) or model_input_sequence_length is not None

    # Channel sizes, collected from the deepest level down to the first one
    input_hiddens = []
    output_hiddens = []
    # Output paddings, collected in encoder order (first level first)
    output_padding = []

    channels = encoder_output_embeddings
    layers_left = num_layers
    sequence_length = model_input_sequence_length
    while channels >= 2 and layers_left > 0:
        # The level visited last is the first encoder block, which reads the 1-channel waveform
        is_first_level = not (channels // 2 >= 2 and layers_left - 1 > 0)
        input_hiddens.append(1 if is_first_level else channels // 2)
        output_hiddens.append(channels)

        if reverse:
            # Conv1d output length is floor((L + 2p - k) / s) + 1: when the division is not
            # exact a sample is lost, and the matching ConvTranspose1d needs output_padding=1.
            numerator = sequence_length + 2 * padding - kernel_size
            output_padding.append(numerator % stride != 0)
            sequence_length = numerator // stride + 1

        channels //= 2
        layers_left -= 1

    if reverse:
        # Decoder order is deepest level first: channels as collected, paddings reversed
        return zip(output_hiddens, input_hiddens, output_padding[::-1])
    # Encoder order is first level first
    return zip(input_hiddens[::-1], output_hiddens[::-1])

## The Encoder architecture

In [None]:
class EncoderBlock(nn.Module):
    """Strided Conv1d (kernel 4, stride 2, padding 1) followed by `activation_function`.

    Maps (batch, in_channels, L) to (batch, out_channels, floor(L / 2)).
    """

    def __init__(self, in_channels, out_channels, activation_function):
        super().__init__()

        self._activation_function = activation_function
        self._conv = nn.Conv1d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        return self._activation_function(self._conv(x))


class Encoder(nn.Module):
    def __init__(self,
                 in_channels,
                 num_hiddens, # the dimensionality of the embeddings output by the encoder.
                 num_halving_layers,
                 num_residual_layers,
                 num_residual_hiddens):
        '''Map a waveform to a shorter sequence of num_hiddens-dimensional embeddings.

        The input goes through `num_halving_layers` EncoderBlocks (each halving the length,
        channel sizes from get_levels: the first block reads a single channel), then a
        length-preserving Conv1d(k=3) and a ResidualStack. A sequence of length N
        becomes a sequence of length ~N / 2**num_halving_layers with num_hiddens channels.

        NOTE(review): `in_channels` is not used; the first block always takes 1 channel.
        '''
        super().__init__()
        self._block_activation_function = nn.ReLU()

        strided_blocks = []
        for block_in, block_out in get_levels(num_hiddens, num_halving_layers):
            # The deepest block (the one producing num_hiddens channels) stays linear
            if block_out == num_hiddens:
                activation = nn.Identity()
            else:
                activation = self._block_activation_function
            strided_blocks.append(
                EncoderBlock(in_channels=block_in, out_channels=block_out, activation_function=activation)
            )

        refine_conv = nn.Conv1d(num_hiddens, num_hiddens, kernel_size=3, stride=1, padding=1)
        residual_stack = ResidualStack(in_channels=num_hiddens,
                                       num_hiddens=num_hiddens,
                                       num_residual_layers=num_residual_layers,
                                       num_residual_hiddens=num_residual_hiddens)

        self._layers = strided_blocks + [refine_conv, residual_stack]
        self._encoder = nn.Sequential(*self._layers)

    def forward(self, inputs):
        return self._encoder(inputs)

## The Decoder architecture

In [None]:
class DecoderBlock(nn.Module):
    """ConvTranspose1d (kernel 4, stride 2, padding 1) followed by `activation_function`.

    Maps (batch, in_channels, L) to (batch, out_channels, 2 * L + output_padding),
    undoing the length halving of an EncoderBlock.
    """

    def __init__(self, in_channels, out_channels, activation_function, output_padding):
        super().__init__()

        self._conv = nn.ConvTranspose1d(in_channels, out_channels,
                                        kernel_size=4, stride=2, padding=1,
                                        output_padding=output_padding)
        self._activation_function = activation_function

    def forward(self, x):
        return self._activation_function(self._conv(x))

class Decoder(nn.Module):
    def __init__(self,
                 in_channels,
                 model_input_sequence_length,
                 num_hiddens, # the dimensionality of the embeddings output by the encoder.
                 num_doubling_layer,
                 num_residual_layers,
                 num_residual_hiddens):
        '''Map a sequence of num_hiddens-dimensional embeddings back to a 1-channel waveform.

        The input goes through a length-preserving Conv1d(k=3) and a ResidualStack, then
        through `num_doubling_layer` DecoderBlocks (each doubling the length; channel sizes
        and output paddings from get_levels(reverse=True), so the final length matches
        model_input_sequence_length). A final Tanh keeps the output in [-1, 1].

        NOTE(review): `in_channels` is not used; the first conv always reads num_hiddens channels.
        '''
        super().__init__()

        self._block_activation_function = nn.ReLU()
        self._output_activation_func = nn.Tanh()  # Output values in range [-1, 1]

        pre_conv = nn.Conv1d(num_hiddens, num_hiddens, kernel_size=3, stride=1, padding=1)
        residual_stack = ResidualStack(in_channels=num_hiddens,
                                       num_hiddens=num_hiddens,
                                       num_residual_layers=num_residual_layers,
                                       num_residual_hiddens=num_residual_hiddens)

        levels = get_levels(num_hiddens,
                            num_doubling_layer,
                            reverse=True,
                            model_input_sequence_length=model_input_sequence_length)
        upsampling_blocks = []
        for block_in, block_out, needs_padding in levels:
            # The block producing the single output channel stays linear (Tanh comes in forward)
            if block_out == 1:
                activation = nn.Identity()
            else:
                activation = self._block_activation_function
            upsampling_blocks.append(
                DecoderBlock(in_channels=block_in,
                             out_channels=block_out,
                             activation_function=activation,
                             output_padding=1 if needs_padding else 0)
            )

        self._layers = [pre_conv, residual_stack] + upsampling_blocks
        self._decoder = nn.Sequential(*self._layers)

    def forward(self, x):
        return self._output_activation_func(self._decoder(x))

## The VQ-VAE model

In [None]:
class VQVAE(nn.Module):
    def __init__(self,
                 model_input_sequence_length,
                 num_hiddens, # the dimensionality of the embeddings output by the encoder.
                 num_halving_layers,
                 num_residual_layers,
                 num_residual_hiddens,
                 num_embeddings,
                 embedding_dim,
                 commitment_cost,
                 decay):
        '''Vector-quantized auto-encoder for raw audio.

        Data flow:
            Encoder (first argument 1 - presumably one input channel) -> `num_hiddens` channels
            -> 1x1 conv to `embedding_dim` channels
            -> vector quantizer with `num_embeddings` codes (EMA variant when decay > 0.0)
            -> 1x1 conv back to `num_hiddens` channels
            -> Decoder.
        The encoder halves and the decoder doubles the sequence length
        `num_halving_layers` times.
        '''
        super().__init__()

        # Residual-stack settings shared by the encoder and the decoder.
        residual_settings = (num_residual_layers, num_residual_hiddens)

        self._encoder = Encoder(1, num_hiddens, num_halving_layers, *residual_settings)

        # Bring the encoder features to the codebook dimensionality before quantization.
        self._pre_vq_conv = nn.Conv1d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)

        # decay > 0.0 selects the codebook updated with exponential moving averages.
        self._vq_vae = (
            VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
            if decay > 0.0
            else VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        )

        # Bring the quantized vectors back to the width the decoder expects.
        self._post_vq_conv = nn.Conv1d(in_channels=embedding_dim,
                                       out_channels=num_hiddens,
                                       kernel_size=1,
                                       stride=1)

        self._decoder = Decoder(num_hiddens,
                                model_input_sequence_length,
                                num_hiddens,
                                num_halving_layers,
                                *residual_settings)

    def forward(self, x):
        '''Return `(vq_loss, reconstruction, perplexity)` for the batch `x`.'''
        latents = self._pre_vq_conv(self._encoder(x))
        vq_loss, quantized, perplexity, _ = self._vq_vae(latents)
        reconstruction = self._decoder(self._post_vq_conv(quantized))
        return vq_loss, reconstruction, perplexity

# Training of the VQ-VAE

## Select the device (cpu or gpu)

In [None]:
# @title
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
print(f"You are using {device}!")

## Define the hyper parameters

In [None]:
# Resolve the hyper-parameters of this session:
#   - "Continue training":  reload the configuration saved with the checkpoint, overriding its epochs;
#   - "Evaluate model":     reload the configuration of the model to evaluate;
#   - "Start new training": use the defaults defined below.
if USE_PRE_SAVED_MODEL or ONLY_EVALUATION:
    _hyper_params_source = PATH_TO_HYPERPARAMETERS if USE_PRE_SAVED_MODEL else PATH_TO_HYPERPARAMETERS_EVALUATION
    with open(_hyper_params_source, 'r') as hyper_params_file:
        HYPER_PARAMS = json.load(hyper_params_file)
    if USE_PRE_SAVED_MODEL:
        # The new run lasts for the number of epochs chosen in the form.
        HYPER_PARAMS["epochs"] = EPOCHS_RE_TRAIN
else:
    HYPER_PARAMS = {
        "batch_size": 4,
        "epochs": 8000,

        # How many items each dataset split is limited to
        "train_size": 16,
        "validation_size": 4,
        "test_size": 4,

        # Length in seconds of every audio input
        "audio_length": 14,

        # Whether the training inputs go through the silence-removal transform
        "remove_silence": True,

        # Width of the encoder output (channels of each encoded vector)
        "num_hiddens": 128,

        # Length of the quantized sequence is about audio_length * 44100 / 2^num_halving_layers
        "num_halving_layers": 4,

        # Residual stack configuration
        "num_residual_hiddens": 64,
        "num_residual_layers": 3,

        # Dimensionality of the codebook vectors
        "embedding_dim": 64,

        # Codebook size
        "num_embeddings": 512,

        # Weight of the commitment term, which keeps the encoder close to the chosen codes
        "commitment_cost": 0.25,
        "decay": 0.99,
        "learning_rate": 0.001,
        "use_spectral": True
    }

# A training run stores its configuration next to its checkpoints;
# an evaluation-only session writes nothing.
if not ONLY_EVALUATION:
    with open(os.path.join(CURR_CHECKPOINT_DIR, "hyper_params.json"), 'w') as hyp_params_file:
        json.dump(HYPER_PARAMS, hyp_params_file, indent=4)

## Create the Datasets

### Define dataset transformation function

In [None]:
def trim_silence(waveform: torch.Tensor):
    '''Drop every sample that is exactly 0 and return the remaining samples as a (1, M) tensor.

    Note: *all* exactly-zero samples are removed wherever they occur, not only leading or
    trailing silence. Masking flattens the input in row-major order, so a multi-row
    (e.g. multi-channel) waveform comes back as a single row.
    '''
    non_zero_mask = waveform.ne(0)
    kept_samples = torch.masked_select(waveform, non_zero_mask)
    return kept_samples.unsqueeze(0)

### Build the PyTorch dataset objects

In [None]:
# @title Select what dataset you want to use {"form-width":"20%","display-mode":"form"}
USE_STEMS = False # @param {"type":"boolean"}

# Build the train / test / validation datasets from the csv indexes stored in DATASET_HOME_DIR,
# either at stem level (StemsDataset) or at mix level (MixesDataset).
# Only the training split goes through the silence-removal transform.
_audio_length = HYPER_PARAMS['audio_length']
_train_transform = trim_silence if HYPER_PARAMS['remove_silence'] else None

if not USE_STEMS:
    train_csv_path = os.path.join(DATASET_HOME_DIR, "train_mixes.csv")
    test_csv_path = os.path.join(DATASET_HOME_DIR, "test_mixes.csv")
    validation_csv_path = os.path.join(DATASET_HOME_DIR, "validation_mixes.csv")

    # NOTE: per the original author, silence removal was introduced after the best training run.
    train_set = MixesDataset(train_csv_path, _audio_length,
                             custom_size=HYPER_PARAMS['train_size'],
                             transform=_train_transform)
    test_set = MixesDataset(test_csv_path, _audio_length)
    validation_set = MixesDataset(validation_csv_path, _audio_length,
                                  custom_size=HYPER_PARAMS['validation_size'])
else:
    train_csv_path = os.path.join(DATASET_HOME_DIR, "train_stems.csv")
    test_csv_path = os.path.join(DATASET_HOME_DIR, "test_stems.csv")
    validation_csv_path = os.path.join(DATASET_HOME_DIR, "validation_stems.csv")

    train_set = StemsDataset(train_csv_path, _audio_length,
                             batch_size=HYPER_PARAMS['batch_size'],
                             transform=_train_transform)
    test_set = StemsDataset(test_csv_path, _audio_length)
    validation_set = StemsDataset(validation_csv_path, _audio_length,
                                  batch_size=HYPER_PARAMS['batch_size'])

def _count_csv_lines(csv_path):
    '''Return the number of newline characters in `csv_path`, i.e. the count `wc -l` reports.

    The count is done in Python rather than through a shell `!wc -l $path` escape. That way
    the path is never interpolated unquoted into a shell command, and the cell also runs
    outside IPython. A csv header line, if present, is counted too.
    '''
    with open(csv_path, 'rb') as csv_file:
        return sum(chunk.count(b'\n') for chunk in iter(lambda: csv_file.read(1 << 20), b''))


train_csv_entries = _count_csv_lines(train_csv_path)
test_csv_entries = _count_csv_lines(test_csv_path)
validation_csv_entries = _count_csv_lines(validation_csv_path)

print(f"------- Original datasets statistics -------\n")
print(f"The train csv file has {{{train_csv_entries}}} entries.")
print(f"The test csv file has {{{test_csv_entries}}} entries.")
print(f"The validation csv file has {{{validation_csv_entries}}} entries.")

print(f"\n\n------- Built datasets statistics -------\n")
print(f"Train set has {{{len(train_set)}}} entries.")
print(f"Test set has {{{len(test_set)}}} entries.")
print(f"Validation set has {{{len(validation_set)}}} entries.")

## Create the Dataloaders

In [None]:
# Only the training data is shuffled. The normalized reconstruction error divides by the
# variance of each batch, so the validation metric (which drives checkpoint selection) depends
# on how samples are grouped into batches; a fixed order keeps it comparable across epochs.
train_vqvae_dataloader = DataLoader(train_set, batch_size=HYPER_PARAMS['batch_size'], shuffle=True)
validation_vqvae_dataloader = DataLoader(validation_set, batch_size=HYPER_PARAMS['batch_size'], shuffle=False)
test_vqvae_dataloader = DataLoader(test_set, batch_size=HYPER_PARAMS["batch_size"], shuffle=False)

## Create the model to train
If `USE_PRE_SAVED_MODEL` is set to *True*, the selected pre-saved model is loaded and its training continues; otherwise a new model is created. When only evaluating, this cell does nothing.

In [None]:
if not ONLY_EVALUATION:
    # Build a fresh model from the resolved hyper-parameters.
    model = VQVAE(
        HYPER_PARAMS["audio_length"] * SAMPLE_RATE,
        HYPER_PARAMS["num_hiddens"],
        HYPER_PARAMS["num_halving_layers"],
        HYPER_PARAMS["num_residual_layers"],
        HYPER_PARAMS["num_residual_hiddens"],
        HYPER_PARAMS["num_embeddings"],
        HYPER_PARAMS["embedding_dim"],
        HYPER_PARAMS["commitment_cost"],
        HYPER_PARAMS["decay"]
    ).to(device)

    # "Continue training": start from the weights of the selected checkpoint.
    if USE_PRE_SAVED_MODEL:
        # Load the tensors directly onto the selected device; a hard-coded 'cuda' here would
        # make the cell fail on CPU-only runtimes.
        model.load_state_dict(
            torch.load(PATH_TO_PRE_SAVED_MODEL, map_location=device)
        )

    optimizer = optim.Adam(model.parameters(), lr=HYPER_PARAMS['learning_rate'], amsgrad=False)

## Pretty print of the Encoder and Decoder configuration

In [None]:
# @title
from prettytable import PrettyTable

model_input_sequence_length = SAMPLE_RATE*HYPER_PARAMS["audio_length"]


def conv_out_sequence_len(in_dim):
    '''Output length of a Conv1d with kernel_size=4, stride=2, padding=1, dilation=1.

    Standard formula: floor((L + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1.
    NOTE(review): these kernel settings are assumed to match the encoder's halving
    convolution - confirm against the encoder definition.
    '''
    return int((float(in_dim + (2 * 1) - (4 - 1) - 1) / 2) + 1)


conv_layers = HYPER_PARAMS["num_halving_layers"]
encoder_output_embeddings = HYPER_PARAMS["num_hiddens"]

print(f"\nInfo: The sequence output by the Encoder is {2**conv_layers} times shorter than the input sequence to the model.\n")
print("+--------------------- Display Encoder layers setting ---------------------+\n")

enc_table = PrettyTable()
enc_table.field_names = ["input_embeddings", "out_embeddings", "out_sequence_length"]

# The encoder level with 1-based index lay_i has halved the sequence lay_i times.
for lay_i, (i, o) in enumerate(get_levels(encoder_output_embeddings, conv_layers, reverse=False), start=1):
    enc_table.add_row((i, o, int(model_input_sequence_length/2**(lay_i))))

print(enc_table)

print("\n\n+--------------------- Display Decoder layers setting ---------------------+\n")

dec_table = PrettyTable()
dec_table.field_names = ["input_embeddings", "out_embeddings", "output_padding", "out_sequence_length"]

# The decoder level with 0-based index lay_i is followed by (conv_layers - lay_i - 1) further
# doublings, so its *output* is 2**(conv_layers - lay_i - 1) times shorter than the model input.
# The last level restores the full length.
for lay_i, (i, o, pad) in enumerate(get_levels(encoder_output_embeddings, conv_layers, reverse=True, model_input_sequence_length=model_input_sequence_length)):
    dec_table.add_row((i, o, pad, int(model_input_sequence_length/2**(conv_layers-lay_i-1))))

print(dec_table)

## Define the Spectral Loss

In [None]:
def compute_stft_magnitude(x, n_fft=1024, hop_length=256, win_length=1024):
    '''Return the magnitude of the Hann-windowed STFT of `x`.

    Args:
        x: real signal accepted by `torch.stft` (1-D, or 2-D as (batch, time)).
        n_fft, hop_length, win_length: STFT settings; the defaults are the ones the
            spectral loss has always used.

    Returns:
        Real tensor of shape (..., n_fft // 2 + 1, n_frames).
    '''
    # Create the window on x's own device and dtype, so the function works for any device
    # (not only the global `device`) and for non-float32 inputs.
    window = torch.hann_window(win_length, device=x.device, dtype=x.dtype)
    stft_x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                        window=window, return_complex=True)
    return torch.abs(stft_x)


def spectral_loss(x_original, x_reconstructed):
    '''Mean squared error between the STFT magnitudes of two signals.

    Inputs are presumably shaped (batch, 1, time): the channel axis is squeezed before the
    STFT and re-added to the magnitudes afterwards.
    '''
    mag_x_original = compute_stft_magnitude(x_original.squeeze(1)).unsqueeze(1)
    mag_x_reconstructed = compute_stft_magnitude(x_reconstructed.squeeze(1)).unsqueeze(1)

    # L2 distance between the two magnitude spectrograms
    return F.mse_loss(mag_x_original, mag_x_reconstructed)

## Define a function that plots training statistics

In [None]:
def _smooth_curve(values, window_length, polyorder):
    '''Savitzky-Golay smoothing that degrades gracefully on short series.

    `savgol_filter` (default mode 'interp') needs window_length <= len(values) and
    window_length > polyorder. The window is shrunk to the largest odd length that fits the
    series, and the raw values are returned when even that is too short to smooth.
    '''
    values = np.asarray(values, dtype=float)
    n_points = len(values)
    window = min(window_length, n_points if n_points % 2 == 1 else n_points - 1)
    if window <= polyorder:
        return values
    return savgol_filter(values, window, polyorder)


def plot_loss_and_perplexity(train_res_recon_error, train_perplexity, valid_recon_error, valid_perplexity):
    '''Plot smoothed per-epoch training/validation NMSE and perplexity in a 2x2 grid.

    The figure is saved as "train_statistics.png" in CURR_CHECKPOINT_DIR. Each argument is the
    per-epoch history of one metric. Runs shorter than the smoothing window (201 epochs) are
    smoothed with a shorter window instead of making `savgol_filter` raise.
    '''
    WINDOW_LENGTH = 201
    POLYORDER = 7

    # Create a figure with a 2x2 grid of graphs
    f = plt.figure(figsize=(16, 12))

    # (series, title, is_error): errors use a log y-axis, perplexities a denser linear one.
    panels = [
        (train_res_recon_error, 'Smoothed NMSE (Train)', True),
        (train_perplexity, 'Train Smoothed Average codebook usage (Perplexity)', False),
        (valid_recon_error, 'Smoothed NMSE (Validation)', True),
        (valid_perplexity, 'Validation Smoothed Average codebook usage (Perplexity)', False),
    ]

    for position, (series, title, is_error) in enumerate(panels, start=1):
        ax = f.add_subplot(2, 2, position)
        ax.plot(_smooth_curve(series, WINDOW_LENGTH, POLYORDER))
        if is_error:
            ax.set_yscale('log')
        else:
            ax.yaxis.set_major_locator(MaxNLocator(nbins=10))
        ax.set_title(title)
        ax.set_xlabel('epochs')

    f.savefig(os.path.join(CURR_CHECKPOINT_DIR, "train_statistics.png"))

## Define a per-epoch training function

In [None]:
def train_one_epoch(train_dataloader: DataLoader, epoch_index):
    '''Run one optimization pass over `train_dataloader` with the global `model` and `optimizer`.

    Per batch, the minimized loss is
        MSE(reconstruction, input) / Var(input)  +  VQ loss  [+ spectral loss if enabled].

    Args:
        train_dataloader: yields input batches; each is moved to the global `device`.
        epoch_index: current epoch number (currently unused).

    Returns:
        (mean normalized reconstruction error, mean perplexity) over the batches of the epoch.
    '''
    recon_errors, perplexities = [], []

    for batch in train_dataloader:
        batch = batch.to(device)
        variance = torch.var(batch, correction=0)

        # Gradients must not accumulate across batches.
        optimizer.zero_grad()

        vq_loss, reconstruction, perplexity = model(batch)

        # Reconstruction error normalized by the variance of the batch (NMSE).
        recon_error = F.mse_loss(reconstruction, batch) / variance
        spectral_term = spectral_loss(batch, reconstruction) if HYPER_PARAMS["use_spectral"] else 0

        total_loss = recon_error + vq_loss + spectral_term
        total_loss.backward()
        optimizer.step()

        recon_errors.append(recon_error.item())
        perplexities.append(perplexity.item())

    return np.mean(recon_errors), np.mean(perplexities)

## Define the main training function

In [None]:
KEEP_LAST_N_MODELS = 1               # checkpoints kept per category ("RECON_ERROR", "PERPLX", "BOTH")
THRESHOLD_ERROR_TO_SAVE_MODEL = 0.9  # no checkpoint is written until the validation NMSE is below this


def _run_validation():
    '''Evaluate the global `model` on `validation_vqvae_dataloader` without gradients.

    Returns:
        (mean normalized reconstruction error, mean perplexity) over the validation batches.
    '''
    # Evaluation mode: disables dropout and uses population statistics for batch normalization.
    model.eval()

    recon_errors = []
    perplexities = []

    # No gradients are needed, which also reduces memory consumption.
    with torch.no_grad():
        for valid_inputs in validation_vqvae_dataloader:
            valid_inputs = valid_inputs.to(device)
            data_variance = torch.var(valid_inputs, correction=0)
            _, valid_data_recon, valid_perplexity = model(valid_inputs)
            recon_error = F.mse_loss(valid_data_recon, valid_inputs) / data_variance
            recon_errors.append(recon_error.item())
            perplexities.append(valid_perplexity.item())

    return np.mean(recon_errors), np.mean(perplexities)


def training_loop():
    '''Train the global `model` for HYPER_PARAMS["epochs"] epochs and checkpoint the best ones.

    After each epoch the model is evaluated on the validation set. Once the validation
    reconstruction error is below THRESHOLD_ERROR_TO_SAVE_MODEL, a checkpoint is written to
    CURR_CHECKPOINT_DIR whenever the error decreases and/or the perplexity increases compared
    with the best values seen so far. The file name records which metric improved, and only
    the KEEP_LAST_N_MODELS most recent files of each category are kept. Finally the training
    curves are plotted and saved.
    '''
    best_valid_recon_error = 1_000_000
    best_valid_perplexity = 1

    # Paths of the checkpoints currently on disk, per category (oldest first)
    saved_models = {
        "RECON_ERROR": [],
        "PERPLX": [],
        "BOTH": []
    }

    # Per-epoch history, used for the final plots
    train_recon_error_per_epoch = []
    train_perplexity_per_epoch = []
    valid_recon_error_per_epoch = []
    valid_perplexity_per_epoch = []

    with tqdm(total=HYPER_PARAMS["epochs"]) as pbar:
        for epoch_number in range(1, HYPER_PARAMS["epochs"] + 1):
            log('EPOCH {}'.format(epoch_number))

            # Make sure gradient tracking is on, and do a pass over the data
            model.train()
            avg_train_recon_error, avg_train_perplexity = train_one_epoch(train_vqvae_dataloader, epoch_number)
            train_recon_error_per_epoch.append(avg_train_recon_error)
            train_perplexity_per_epoch.append(avg_train_perplexity)

            avg_valid_recon_error, avg_valid_perplexity = _run_validation()
            valid_recon_error_per_epoch.append(avg_valid_recon_error)
            valid_perplexity_per_epoch.append(avg_valid_perplexity)

            # Log error and perplexity
            log(f"Train reconstruction error: {avg_train_recon_error} - Validation reconstruction error: {avg_valid_recon_error}")
            log(f"Train perplexity: {avg_train_perplexity} - Validation perplexity: {avg_valid_perplexity}")

            pbar.set_postfix({
                'Train reconstruction err': f'{avg_train_recon_error}',
                'Train perplexity': f'{avg_train_perplexity}'
            })

            if avg_valid_recon_error < THRESHOLD_ERROR_TO_SAVE_MODEL:
                recon_improved = avg_valid_recon_error < best_valid_recon_error
                perplexity_improved = avg_valid_perplexity > best_valid_perplexity

                if recon_improved or perplexity_improved:
                    if recon_improved:
                        best_valid_recon_error = avg_valid_recon_error
                    if perplexity_improved:
                        best_valid_perplexity = avg_valid_perplexity

                    # The suffixes match the PRE_SAVED_MODEL_TYPE options used to rebuild
                    # PATH_TO_PRE_SAVED_MODEL as 'VQVAE_{epoch}_{PRE_SAVED_MODEL_TYPE}'.
                    if recon_improved and perplexity_improved:
                        model_name = 'VQVAE_{}_BOTH_PERPLX_RECON_ERR'.format(epoch_number)
                        select_storage = "BOTH"
                    elif recon_improved:
                        model_name = 'VQVAE_{}_BEST_RECON_ERR'.format(epoch_number)
                        select_storage = "RECON_ERROR"
                    else:
                        model_name = 'VQVAE_{}_BEST_PERPLX'.format(epoch_number)
                        select_storage = "PERPLX"

                    model_path = os.path.join(CURR_CHECKPOINT_DIR, model_name)
                    torch.save(model.state_dict(), model_path)
                    saved_models[select_storage].append(model_path)

                    # Delete the oldest checkpoint of this category beyond the retention limit
                    if len(saved_models[select_storage]) > KEEP_LAST_N_MODELS:
                        os.remove(saved_models[select_storage][0])
                        saved_models[select_storage] = saved_models[select_storage][1:]

            pbar.update(1)

    plot_loss_and_perplexity(train_recon_error_per_epoch, train_perplexity_per_epoch, valid_recon_error_per_epoch, valid_perplexity_per_epoch)

## Train!

In [None]:
# Run the full training procedure unless this session only evaluates a saved model.
_should_train = not ONLY_EVALUATION
if _should_train:
    training_loop()

# Quantitative Evaluation

In [None]:
# @title
# If we are only evaluating then we must load a pre-trained model
#   else we use the newly trained model
if ONLY_EVALUATION:
    model = VQVAE(
        HYPER_PARAMS["audio_length"] * SAMPLE_RATE,
        HYPER_PARAMS["num_hiddens"],
        HYPER_PARAMS["num_halving_layers"],
        HYPER_PARAMS["num_residual_layers"],
        HYPER_PARAMS["num_residual_hiddens"],
        HYPER_PARAMS["num_embeddings"],
        HYPER_PARAMS["embedding_dim"],
        HYPER_PARAMS["commitment_cost"],
        HYPER_PARAMS["decay"]
    ).to(device)

    # Load the stored tensors onto the selected device, so evaluation also works on
    # CPU-only runtimes; a hard-coded 'cuda' fails there.
    model.load_state_dict(
        torch.load(
            PATH_TO_MODEL_TO_EVALUATE,
            map_location=device
        )
    )

# Evaluation mode for every following cell (the return value is discarded to keep output clean).
_ = model.eval()

In [None]:
# @title Select the dataset {"run":"auto"}
dataset_to_evaluate_on = test_vqvae_dataloader # @param ["train_vqvae_dataloader","validation_vqvae_dataloader","test_vqvae_dataloader"] {"type":"raw"}
# Fresh iterator over the selected dataloader; the next cell pulls one batch from it per run.
data_iterator = iter(dataset_to_evaluate_on)
# 1-based number of the batch the next cell will fetch (only used for printing).
batch_number = 1

In [None]:
# @title Re-run this cell if you want to access the next batch

# Fetch the next batch. Once the selected dataset has been fully consumed, start again from
# its first batch instead of failing with StopIteration.
try:
    batch_test_inputs = next(data_iterator)
except StopIteration:
    print("End of the dataset reached: restarting from the first batch.")
    data_iterator = iter(dataset_to_evaluate_on)
    batch_number = 1
    batch_test_inputs = next(data_iterator)
batch_test_inputs = batch_test_inputs.to(device)

print(f"Batch number: {batch_number}")
batch_number += 1

# Inference only: no autograd graph is needed (saves memory on long inputs).
with torch.no_grad():
    test_vq_loss, test_data_recon, test_perplexity = model(batch_test_inputs)
    batch_data_variance = torch.var(batch_test_inputs, correction=0)
    # Reconstruction error normalized by the variance of the batch (NMSE)
    test_recon_error = F.mse_loss(test_data_recon, batch_test_inputs) / batch_data_variance

print(f"Test perplexity: {test_perplexity}")
print(f"Test data_variance: {batch_data_variance}")
print(f"Test vq_loss: {test_vq_loss}")
print(f"Test recon_error: {test_recon_error}")

In [None]:
# @title ## Play with Test data {"run":"auto","vertical-output":true}
num = 1 # @param {"type":"slider","min":0,"max":4,"step":1}
import IPython.display as ipd
import librosa

# The slider range can exceed the current batch (the default batch size is 4, and the last
# batch of a dataset may be smaller), so report the valid range instead of an IndexError.
_batch_len = batch_test_inputs.shape[0]
if not 0 <= num < _batch_len:
    raise ValueError(f"'num' must be in [0, {_batch_len - 1}] for the current batch, got {num}.")

# First channel of the selected item, as numpy arrays for playback
audio_numpy_v = batch_test_inputs.detach().cpu()[num][0].numpy()
audio_numpy_r = test_data_recon.detach().cpu()[num][0].numpy()

# Use IPython.display to play the original and the reconstructed audio at the dataset's sample rate
ipd.display(ipd.Audio(audio_numpy_v, rate=SAMPLE_RATE))
ipd.display(ipd.Audio(audio_numpy_r, rate=SAMPLE_RATE))
plot_waveform(batch_test_inputs.cpu()[num], graph_title="Original waveform")
plot_waveform(test_data_recon.detach().cpu()[num], graph_title="Reconstructed waveform")

In [None]:
# @title ## Reconstruct entire tracks from Train, Validation or Test set {"form-width":"20%"}

def reconstruct_the_track(dataloader, padding_added_to_the_end: int):
    '''Run the model over every chunk of one track and stitch the reconstruction back together.

    Args:
        dataloader: yields consecutive chunks of a single track in order (it must not shuffle);
            presumably each batch is shaped (batch, 1, samples).
        padding_added_to_the_end: number of padding samples the dataset appended to the track;
            they are removed from the stitched reconstruction.

    Returns:
        (reconstructed_track, mean perplexity, mean normalized reconstruction error), where
        reconstructed_track is a CPU tensor of shape (1, n_samples).
    '''
    batch_recon_errors = []
    batch_perplexities = []
    reconstructed_chunks = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            _, batch_recon, batch_perplexity = model(batch)

            # Normalize by the variance of the batch being reconstructed.
            batch_variance = torch.var(batch, correction=0)
            recon_error = F.mse_loss(batch_recon, batch) / batch_variance
            batch_recon_errors.append(recon_error.item())
            batch_perplexities.append(batch_perplexity.item())

            # Flatten in memory order so consecutive chunks line up; reshape(-1) also copes
            # with a final batch smaller than batch_size.
            reconstructed_chunks.append(batch_recon.reshape(-1).cpu())

    # Concatenate once at the end instead of re-copying the growing track on every batch.
    reconstructed_track = torch.cat(reconstructed_chunks, dim=0) if reconstructed_chunks else torch.tensor([])

    # Slicing with [:-0] would return an empty tensor, so only strip real padding.
    if padding_added_to_the_end > 0:
        reconstructed_track = reconstructed_track[:-padding_added_to_the_end]

    return reconstructed_track.unsqueeze(0), np.mean(batch_perplexities), np.mean(batch_recon_errors)


def reconstruct_the_dataset(dataset_csv: str, selected_dataset: str, n_tracks_to_reconstruct: int):
    '''Reconstruct the first `n_tracks_to_reconstruct` tracks listed in `dataset_csv`.

    Each reconstructed track is written as "<track_name>.wav" (at SAMPLE_RATE) into
    EVAL_CHECKPOINT_DIR/ReconstrutedTracks/<selected_dataset>. The same directory gets a
    ";"-separated "Statistics.csv" with the per-track reconstruction error and perplexity and,
    on its last row, their averages over the processed tracks.

    Args:
        dataset_csv: csv index of the split, passed to SingleTrackTestDataset.
        selected_dataset: split label ("Train", "Validation" or "Test"); used as output sub-directory.
        n_tracks_to_reconstruct: number of tracks (indices 0 .. n-1) to process.

    Raises:
        AssertionError: if the output directory or the statistics file already exists.
    '''
    working_dir = os.path.join(EVAL_CHECKPOINT_DIR, "ReconstrutedTracks", selected_dataset)
    assert not os.path.exists(working_dir), f"The working directory '{working_dir}' already exists!"
    os.makedirs(working_dir, exist_ok=False)

    print(f"Reconstructed audio tracks wirth statistics are saved at '{working_dir}'")

    statistics_file = os.path.join(working_dir, "Statistics.csv")
    assert not os.path.exists(statistics_file), "The Statistics file already exists"

    # Without spectral loss the suffix must be empty: formatting `None` here would append the
    # text "None" to the csv header.
    label_for_recon_error = f"Reconstrucion Error{' (With Spectral loss)' if HYPER_PARAMS['use_spectral'] else ''}"
    with open(statistics_file, 'w') as f:
        f.write(f"Track name;{label_for_recon_error};Perplexity;Average {label_for_recon_error};Average Perplexity\n")

    recon_error_along_dataset = []
    perplexity_along_dataset = []

    for track_i in tqdm(range(n_tracks_to_reconstruct)):
        single_track_test_set = SingleTrackTestDataset(
                                    dataset_csv,
                                    track_i,
                                    HYPER_PARAMS['audio_length'],
                                    HYPER_PARAMS['batch_size']
                                )

        # Chunks must stay in order to be stitched back into the track.
        single_track_dataloader = DataLoader(
                                    single_track_test_set,
                                    batch_size=HYPER_PARAMS['batch_size'],
                                    shuffle=False
                                )

        reconstructed_track, perplexity, recon_error = reconstruct_the_track(
                                                            single_track_dataloader,
                                                            single_track_test_set.padding_elements_to_add_to_end
                                                    )
        track_name = single_track_test_set.track_name

        # Clean memory
        del single_track_dataloader, single_track_test_set
        torch.cuda.empty_cache()

        # Save the reconstructed track into file
        torchaudio.save(os.path.join(working_dir,f'{track_name}.wav'), reconstructed_track, SAMPLE_RATE)

        # Store statistics
        recon_error_along_dataset.append(recon_error)
        perplexity_along_dataset.append(perplexity)

        # Append one row per track, so rows already written survive an interrupted run.
        with open(statistics_file, 'a') as f:
            f.write(f"{track_name};{recon_error};{perplexity};-;-\n")

    # Final row: averages over all processed tracks
    with open(statistics_file, 'a') as f:
        f.write(f"-;-;-;{np.mean(recon_error_along_dataset)};{np.mean(perplexity_along_dataset)}\n")

# @title  {"form-width":"30%"}
selected_dataset_to_reconstruct = "Test" # @param ["Train","Validation","Test"]
n_tracks_to_reconstruct = 31 # @param {"type":"integer","placeholder":"Insert the number of tracks you want to reconstruct from the selected dataset"}
selected_csv = None

# Each form option maps to (csv index, built dataset), so the size check and its message
# refer to the split that was actually chosen.
_reconstruction_sources = {
    "Train": (train_csv_path, train_set),
    "Validation": (validation_csv_path, validation_set),
    "Test": (test_csv_path, test_set),
}

if selected_dataset_to_reconstruct in _reconstruction_sources:
    selected_csv, _selected_set = _reconstruction_sources[selected_dataset_to_reconstruct]
    assert n_tracks_to_reconstruct <= len(_selected_set), \
        f"There are not so many data in the {selected_dataset_to_reconstruct} set"

reconstruct_the_dataset(selected_csv, selected_dataset_to_reconstruct, n_tracks_to_reconstruct)