# Jupyter Notebook for Discrete Audio Generation

## Install Necessary Libraries

In [None]:
# First, we need to install the required libraries for the project.
!pip install x-transformers encodec

## Import Libraries

In [2]:
# Import necessary modules for training and evaluation.
import torch
from tqdm import tqdm
from x_transformers import TransformerWrapper, Decoder
from encodec import EncodecModel
from torch import nn
import soundfile as sf
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
from encodec.utils import convert_audio
import IPython.display as ipd

## Set Parameters

In [3]:
# Define some global parameters for the project.
BANDWIDTH = 1.5
LEVELS = 2 # 2 for bandwidth 1.5, 8 for bandwith 6.0
TIMESTEPS = 125

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Download NSYNTH_GUITAR_MP3 dataset

In [4]:
!git clone https://github.com/SonyCSLParis/test-lfs.git
!bash ./test-lfs/download.sh NSYNTH_GUITAR_MP3

Cloning into 'test-lfs'...
remote: Enumerating objects: 42, done.[K
remote: Counting objects: 100% (42/42), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 42 (delta 5), reused 40 (delta 3), pack-reused 0 (from 0)[K
Unpacking objects: 100% (42/42), 5.92 KiB | 433.00 KiB/s, done.
--2024-10-21 13:37:15--  https://media.githubusercontent.com/media/SonyCSLParis/test-lfs/refs/heads/master/NSYNTH_GUITAR_MP3.zip
Resolving media.githubusercontent.com (media.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to media.githubusercontent.com (media.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 334999208 (319M) [application/zip]
Saving to: ‘NSYNTH_GUITAR_MP3.zip’


2024-10-21 13:37:57 (64.1 MB/s) - ‘NSYNTH_GUITAR_MP3.zip’ saved [334999208/334999208]

Fix archive (-F) - assume mostly intact archive
Zip entry offsets do not need adjusting
 copying:

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)




  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_synthetic_004-108-050.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_acoustic_024-097-127.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_014-074-025.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_synthetic_000-098-050.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_013-067-127.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_acoustic_034-092-025.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_003-091-025.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_synthetic_011-107-127.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_synthetic_005-021-100.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_acoustic_020-047-050.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_acoustic_020-073-100.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guit

## Dataset Class Definition

In [5]:
# Define a dataset class to handle discrete audio representations.
class DiscreteAudioRepDataset(Dataset):
    def __init__(self, root_dir, model, lazy_encode=True,
                 extensions=[".wav", ".mp3", ".flac"], max_samples=-1):
        """
        Args:
            root_dir (string): Directory with all the audio files.
            model: The EnCodec model for encoding audio.
            lazy_encode (bool): If True, encodes audio on-demand (when __getitem__ is called).
                               If False, encodes all audio at initialization.
            extensions (list): List of audio file extensions to include.
        """
        self.root_dir = root_dir
        self.extensions = extensions
        self.model = model
        self.lazy_encode = lazy_encode
        self.audio_files = []

        # Walk through all subfolders to gather audio files
        for root, _, files in os.walk(root_dir):
            for file in files:
                if any(file.endswith(ext) for ext in self.extensions):
                    self.audio_files.append(os.path.join(root, file))

        if max_samples < 0:
          max_samples = len(self.audio_files)

        self.audio_files = self.audio_files[:max_samples]

        # If not lazy encoding, encode all audio files during initialization
        if not self.lazy_encode:
            self.encoded_data = [self._encode_audio(file) for file in tqdm(self.audio_files, desc="Encoding Audio")]

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        if self.lazy_encode:
            filename = self.audio_files[idx]
            return self._encode_audio(filename)[..., :TIMESTEPS*LEVELS]
        else:
            encoded_audio = self.encoded_data[idx]
            return encoded_audio[:TIMESTEPS * LEVELS]

    def _encode_audio(self, filename):
        waveform, sample_rate = sf.read(filename)
        waveform = torch.tensor(waveform, dtype=torch.float32)[None, None, :]  # Add batch dimension
        waveform = convert_audio(waveform, sample_rate, self.model.sample_rate, self.model.channels)

        with torch.no_grad():
            discrete_reps = self.model.encode(waveform.to(device))

        discrete_reps = discrete_reps[0][0].contiguous().permute(0, 2, 1).reshape(-1)
        return discrete_reps.cpu()

## Load EnCodec Model and Convert Files

In [6]:
# Load the EnCodec model to transform the audio to discrete representation.
codec = EncodecModel.encodec_model_24khz().to(device)
codec.set_target_bandwidth(BANDWIDTH)

# ## Load Dataset
# Load the NSYNTH dataset and prepare DataLoader for training and validation.
audio_folder_train = "./NSYNTH_GUITAR_MP3/nsynth-guitar-train"
audio_folder_val = "./NSYNTH_GUITAR_MP3/nsynth-guitar-valid"

dataset = DiscreteAudioRepDataset(root_dir=audio_folder_train, model=codec,
                                  lazy_encode=False, max_samples=-1)

dataset_val = DiscreteAudioRepDataset(root_dir=audio_folder_val, model=codec,
                                      lazy_encode=False, max_samples=-1)

Encoding Audio: 100%|██████████| 32690/32690 [05:42<00:00, 95.57it/s] 
Encoding Audio: 100%|██████████| 2081/2081 [00:22<00:00, 93.69it/s]


In [7]:
# Create Dataloaders for training and validation.
dataloader = DataLoader(dataset, batch_size=125, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=125, shuffle=True)

## Define Transformer Model

In [8]:
# Define the Transformer model using TransformerWrapper and Decoder.
model = TransformerWrapper(
    emb_dropout=0.1,
    num_tokens=1024,
    max_seq_len=LEVELS*TIMESTEPS,
    attn_layers=Decoder(
        dim=256,
        depth=6,
        heads=4,
        rotary_pos_emb=True,
        attn_dropout=0.1,
        ff_dropout=0.1
    )
).to(device)

## (Optional) Load Pretrained Weights
### If available, you can load pretrained weights for the Transformer model.

In [None]:
# Optionally load pretrained weights
model.load_state_dict(torch.load('./model_gen_autoreg_transformer.pth', map_location=device))

## Training Loop

In [10]:
# Define a training loop for training the Transformer model.
epochs = 100
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

for epoch in tqdm(range(epochs), desc="Epochs"):
    train_loss, val_loss = 0, 0
    model.train()
    total_correct, total_predictions = 0, 0
    count = 0

    for batch in dataloader:
        if batch.dtype != torch.long:
            batch = batch.long()
        start_tokens = torch.zeros((batch.shape[0], 1), dtype=torch.long, device=batch.device)
        batch = torch.cat([start_tokens, batch], dim=1)
        discrete_reps = batch.to(device)

        logits = model(discrete_reps)
        logits = logits.permute(0, 2, 1)
        loss = criterion(logits[..., :-1], discrete_reps[..., 1:])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
        count += 1

        preds = logits[..., :-1].argmax(dim=1)
        targets = discrete_reps[..., 1:]
        correct = (preds == targets).sum().item()
        total_correct += correct
        total_predictions += targets.numel()

    avg_train_loss = train_loss / count
    precision = total_correct / total_predictions
    print(f'Epoch {epoch + 1}, Train Loss: {avg_train_loss:.4f}, Precision: {precision:.4f}')
    torch.save(model.state_dict(), 'model_gen_autoreg_transformer.pth')
    
    # Validation loop
    model.eval()
    val_loss = 0
    val_total_correct, val_total_predictions = 0, 0
    count_val = 0

    with torch.no_grad():
        for batch_val in dataloader_val:
            if batch_val.dtype != torch.long:
                batch_val = batch_val.long()
            start_tokens_val = torch.zeros((batch_val.shape[0], 1), dtype=torch.long, device=batch_val.device)
            batch_val = torch.cat([start_tokens_val, batch_val], dim=1)
            discrete_reps_val = batch_val.to(device)

            logits_val = model(discrete_reps_val)
            logits_val = logits_val.permute(0, 2, 1)
            loss_val = criterion(logits_val[..., :-1], discrete_reps_val[..., 1:])
            val_loss += loss_val.item()
            count_val += 1

            preds_val = logits_val[..., :-1].argmax(dim=1)
            targets_val = discrete_reps_val[..., 1:]
            correct_val = (preds_val == targets_val).sum().item()
            val_total_correct += correct_val
            val_total_predictions += targets_val.numel()

    avg_val_loss = val_loss / count_val
    val_precision = val_total_correct / val_total_predictions
    print(f'Epoch {epoch + 1}, Validation Loss: {avg_val_loss:.4f}, Validation Precision: {val_precision:.4f}')


Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1, Train Loss: 4.6227, Precision: 0.2465


Epochs:   1%|          | 1/100 [00:35<57:49, 35.05s/it]

Epoch 1, Validation Loss: 3.9979, Validation Precision: 0.2892
Epoch 2, Train Loss: 3.3278, Precision: 0.3626


Epochs:   2%|▏         | 2/100 [01:10<57:14, 35.05s/it]

Epoch 2, Validation Loss: 3.3031, Validation Precision: 0.3395
Epoch 3, Train Loss: 2.8671, Precision: 0.3976


Epochs:   3%|▎         | 3/100 [01:45<56:52, 35.18s/it]

Epoch 3, Validation Loss: 2.9399, Validation Precision: 0.3666
Epoch 4, Train Loss: 2.6267, Precision: 0.4177


Epochs:   4%|▍         | 4/100 [02:20<56:28, 35.30s/it]

Epoch 4, Validation Loss: 2.7269, Validation Precision: 0.3868
Epoch 5, Train Loss: 2.4770, Precision: 0.4334


Epochs:   5%|▌         | 5/100 [02:56<56:04, 35.41s/it]

Epoch 5, Validation Loss: 2.6014, Validation Precision: 0.4068
Epoch 6, Train Loss: 2.3660, Precision: 0.4478


Epochs:   6%|▌         | 6/100 [03:32<55:34, 35.48s/it]

Epoch 6, Validation Loss: 2.5088, Validation Precision: 0.4234
Epoch 7, Train Loss: 2.2797, Precision: 0.4607


Epochs:   7%|▋         | 7/100 [04:07<55:00, 35.49s/it]

Epoch 7, Validation Loss: 2.4057, Validation Precision: 0.4389
Epoch 8, Train Loss: 2.2055, Precision: 0.4738


Epochs:   8%|▊         | 8/100 [04:43<54:26, 35.51s/it]

Epoch 8, Validation Loss: 2.3370, Validation Precision: 0.4524
Epoch 9, Train Loss: 2.1394, Precision: 0.4862


Epochs:   9%|▉         | 9/100 [05:18<53:55, 35.56s/it]

Epoch 9, Validation Loss: 2.2746, Validation Precision: 0.4628
Epoch 10, Train Loss: 2.0815, Precision: 0.4969


Epochs:  10%|█         | 10/100 [05:54<53:19, 35.55s/it]

Epoch 10, Validation Loss: 2.2304, Validation Precision: 0.4731
Epoch 11, Train Loss: 2.0314, Precision: 0.5060


Epochs:  11%|█         | 11/100 [06:29<52:44, 35.56s/it]

Epoch 11, Validation Loss: 2.1881, Validation Precision: 0.4817
Epoch 12, Train Loss: 1.9873, Precision: 0.5141


Epochs:  12%|█▏        | 12/100 [07:05<52:09, 35.56s/it]

Epoch 12, Validation Loss: 2.1499, Validation Precision: 0.4886
Epoch 13, Train Loss: 1.9481, Precision: 0.5210


Epochs:  13%|█▎        | 13/100 [07:41<51:36, 35.59s/it]

Epoch 13, Validation Loss: 2.1203, Validation Precision: 0.4950
Epoch 14, Train Loss: 1.9132, Precision: 0.5275


Epochs:  14%|█▍        | 14/100 [08:16<51:03, 35.62s/it]

Epoch 14, Validation Loss: 2.0975, Validation Precision: 0.4987
Epoch 15, Train Loss: 1.8812, Precision: 0.5331


Epochs:  15%|█▌        | 15/100 [08:52<50:25, 35.60s/it]

Epoch 15, Validation Loss: 2.0531, Validation Precision: 0.5044
Epoch 16, Train Loss: 1.8524, Precision: 0.5383


Epochs:  16%|█▌        | 16/100 [09:28<49:53, 35.64s/it]

Epoch 16, Validation Loss: 2.0269, Validation Precision: 0.5087
Epoch 17, Train Loss: 1.8251, Precision: 0.5431


Epochs:  17%|█▋        | 17/100 [10:03<49:19, 35.65s/it]

Epoch 17, Validation Loss: 2.0044, Validation Precision: 0.5133
Epoch 18, Train Loss: 1.7993, Precision: 0.5478


Epochs:  18%|█▊        | 18/100 [10:39<48:41, 35.63s/it]

Epoch 18, Validation Loss: 1.9764, Validation Precision: 0.5182
Epoch 19, Train Loss: 1.7754, Precision: 0.5523


Epochs:  19%|█▉        | 19/100 [11:15<48:05, 35.62s/it]

Epoch 19, Validation Loss: 1.9605, Validation Precision: 0.5212
Epoch 20, Train Loss: 1.7525, Precision: 0.5565


Epochs:  20%|██        | 20/100 [11:50<47:30, 35.63s/it]

Epoch 20, Validation Loss: 1.9364, Validation Precision: 0.5255
Epoch 21, Train Loss: 1.7311, Precision: 0.5603


Epochs:  21%|██        | 21/100 [12:26<46:54, 35.62s/it]

Epoch 21, Validation Loss: 1.9167, Validation Precision: 0.5285
Epoch 22, Train Loss: 1.7115, Precision: 0.5639


Epochs:  22%|██▏       | 22/100 [13:01<46:16, 35.59s/it]

Epoch 22, Validation Loss: 1.8999, Validation Precision: 0.5317
Epoch 23, Train Loss: 1.6926, Precision: 0.5673


Epochs:  23%|██▎       | 23/100 [13:37<45:39, 35.57s/it]

Epoch 23, Validation Loss: 1.8781, Validation Precision: 0.5349
Epoch 24, Train Loss: 1.6748, Precision: 0.5705


Epochs:  24%|██▍       | 24/100 [14:12<45:02, 35.56s/it]

Epoch 24, Validation Loss: 1.8597, Validation Precision: 0.5381
Epoch 25, Train Loss: 1.6581, Precision: 0.5736


Epochs:  25%|██▌       | 25/100 [14:48<44:27, 35.57s/it]

Epoch 25, Validation Loss: 1.8586, Validation Precision: 0.5403
Epoch 26, Train Loss: 1.6423, Precision: 0.5764


Epochs:  26%|██▌       | 26/100 [15:23<43:48, 35.53s/it]

Epoch 26, Validation Loss: 1.8329, Validation Precision: 0.5437
Epoch 27, Train Loss: 1.6273, Precision: 0.5791


Epochs:  27%|██▋       | 27/100 [15:59<43:11, 35.51s/it]

Epoch 27, Validation Loss: 1.8179, Validation Precision: 0.5458
Epoch 28, Train Loss: 1.6130, Precision: 0.5817


Epochs:  28%|██▊       | 28/100 [16:34<42:37, 35.52s/it]

Epoch 28, Validation Loss: 1.8125, Validation Precision: 0.5467
Epoch 29, Train Loss: 1.5998, Precision: 0.5842


Epochs:  29%|██▉       | 29/100 [17:10<42:05, 35.57s/it]

Epoch 29, Validation Loss: 1.8039, Validation Precision: 0.5476
Epoch 30, Train Loss: 1.5869, Precision: 0.5864


Epochs:  30%|███       | 30/100 [17:46<41:29, 35.57s/it]

Epoch 30, Validation Loss: 1.7924, Validation Precision: 0.5507
Epoch 31, Train Loss: 1.5746, Precision: 0.5888


Epochs:  31%|███       | 31/100 [18:21<40:51, 35.53s/it]

Epoch 31, Validation Loss: 1.7865, Validation Precision: 0.5521
Epoch 32, Train Loss: 1.5625, Precision: 0.5910


Epochs:  32%|███▏      | 32/100 [18:57<40:13, 35.50s/it]

Epoch 32, Validation Loss: 1.7723, Validation Precision: 0.5546
Epoch 33, Train Loss: 1.5515, Precision: 0.5931


Epochs:  33%|███▎      | 33/100 [19:32<39:40, 35.53s/it]

Epoch 33, Validation Loss: 1.7631, Validation Precision: 0.5558
Epoch 34, Train Loss: 1.5410, Precision: 0.5950


Epochs:  34%|███▍      | 34/100 [20:08<39:03, 35.50s/it]

Epoch 34, Validation Loss: 1.7606, Validation Precision: 0.5565
Epoch 35, Train Loss: 1.5309, Precision: 0.5971


Epochs:  35%|███▌      | 35/100 [20:43<38:28, 35.51s/it]

Epoch 35, Validation Loss: 1.7587, Validation Precision: 0.5569
Epoch 36, Train Loss: 1.5202, Precision: 0.5989


Epochs:  36%|███▌      | 36/100 [21:19<37:52, 35.52s/it]

Epoch 36, Validation Loss: 1.7535, Validation Precision: 0.5583
Epoch 37, Train Loss: 1.5115, Precision: 0.6007


Epochs:  37%|███▋      | 37/100 [21:54<37:18, 35.54s/it]

Epoch 37, Validation Loss: 1.7386, Validation Precision: 0.5608
Epoch 38, Train Loss: 1.5030, Precision: 0.6023


Epochs:  38%|███▊      | 38/100 [22:30<36:43, 35.55s/it]

Epoch 38, Validation Loss: 1.7350, Validation Precision: 0.5617
Epoch 39, Train Loss: 1.4933, Precision: 0.6039


Epochs:  39%|███▉      | 39/100 [23:05<36:11, 35.60s/it]

Epoch 39, Validation Loss: 1.7284, Validation Precision: 0.5624
Epoch 40, Train Loss: 1.4855, Precision: 0.6057


Epochs:  40%|████      | 40/100 [23:41<35:38, 35.65s/it]

Epoch 40, Validation Loss: 1.7296, Validation Precision: 0.5619
Epoch 41, Train Loss: 1.4773, Precision: 0.6073


Epochs:  41%|████      | 41/100 [24:17<35:04, 35.67s/it]

Epoch 41, Validation Loss: 1.7220, Validation Precision: 0.5638
Epoch 42, Train Loss: 1.4686, Precision: 0.6088


Epochs:  42%|████▏     | 42/100 [24:53<34:29, 35.68s/it]

Epoch 42, Validation Loss: 1.7153, Validation Precision: 0.5658
Epoch 43, Train Loss: 1.4616, Precision: 0.6103


Epochs:  43%|████▎     | 43/100 [25:28<33:54, 35.69s/it]

Epoch 43, Validation Loss: 1.7176, Validation Precision: 0.5648
Epoch 44, Train Loss: 1.4539, Precision: 0.6117


Epochs:  44%|████▍     | 44/100 [26:04<33:15, 35.63s/it]

Epoch 44, Validation Loss: 1.7054, Validation Precision: 0.5676
Epoch 45, Train Loss: 1.4461, Precision: 0.6131


Epochs:  45%|████▌     | 45/100 [26:39<32:36, 35.58s/it]

Epoch 45, Validation Loss: 1.7092, Validation Precision: 0.5677
Epoch 46, Train Loss: 1.4398, Precision: 0.6145


Epochs:  46%|████▌     | 46/100 [27:15<32:00, 35.57s/it]

Epoch 46, Validation Loss: 1.7075, Validation Precision: 0.5665
Epoch 47, Train Loss: 1.4321, Precision: 0.6160


Epochs:  47%|████▋     | 47/100 [27:50<31:24, 35.56s/it]

Epoch 47, Validation Loss: 1.6915, Validation Precision: 0.5692
Epoch 48, Train Loss: 1.4260, Precision: 0.6171


Epochs:  48%|████▊     | 48/100 [28:26<30:47, 35.52s/it]

Epoch 48, Validation Loss: 1.6988, Validation Precision: 0.5691
Epoch 49, Train Loss: 1.4195, Precision: 0.6183


Epochs:  49%|████▉     | 49/100 [29:01<30:12, 35.55s/it]

Epoch 49, Validation Loss: 1.6827, Validation Precision: 0.5718
Epoch 50, Train Loss: 1.4134, Precision: 0.6195


Epochs:  50%|█████     | 50/100 [29:37<29:37, 35.55s/it]

Epoch 50, Validation Loss: 1.6898, Validation Precision: 0.5706
Epoch 51, Train Loss: 1.4072, Precision: 0.6209


Epochs:  51%|█████     | 51/100 [30:13<29:05, 35.62s/it]

Epoch 51, Validation Loss: 1.6842, Validation Precision: 0.5717
Epoch 52, Train Loss: 1.4015, Precision: 0.6221


Epochs:  52%|█████▏    | 52/100 [30:48<28:29, 35.62s/it]

Epoch 52, Validation Loss: 1.6809, Validation Precision: 0.5714
Epoch 53, Train Loss: 1.3959, Precision: 0.6230


Epochs:  53%|█████▎    | 53/100 [31:24<27:52, 35.58s/it]

Epoch 53, Validation Loss: 1.6796, Validation Precision: 0.5717
Epoch 54, Train Loss: 1.3898, Precision: 0.6243


Epochs:  54%|█████▍    | 54/100 [32:00<27:17, 35.59s/it]

Epoch 54, Validation Loss: 1.6717, Validation Precision: 0.5718
Epoch 55, Train Loss: 1.3849, Precision: 0.6254


Epochs:  55%|█████▌    | 55/100 [32:35<26:42, 35.61s/it]

Epoch 55, Validation Loss: 1.6785, Validation Precision: 0.5726
Epoch 56, Train Loss: 1.3788, Precision: 0.6263


Epochs:  56%|█████▌    | 56/100 [33:11<26:10, 35.69s/it]

Epoch 56, Validation Loss: 1.6757, Validation Precision: 0.5731
Epoch 57, Train Loss: 1.3739, Precision: 0.6274


Epochs:  57%|█████▋    | 57/100 [33:47<25:33, 35.67s/it]

Epoch 57, Validation Loss: 1.6714, Validation Precision: 0.5740
Epoch 58, Train Loss: 1.3689, Precision: 0.6285


Epochs:  58%|█████▊    | 58/100 [34:22<24:57, 35.66s/it]

Epoch 58, Validation Loss: 1.6767, Validation Precision: 0.5739
Epoch 59, Train Loss: 1.3641, Precision: 0.6294


Epochs:  59%|█████▉    | 59/100 [34:58<24:21, 35.64s/it]

Epoch 59, Validation Loss: 1.6803, Validation Precision: 0.5726
Epoch 60, Train Loss: 1.3594, Precision: 0.6304


Epochs:  60%|██████    | 60/100 [35:33<23:44, 35.61s/it]

Epoch 60, Validation Loss: 1.6703, Validation Precision: 0.5744
Epoch 61, Train Loss: 1.3546, Precision: 0.6312


Epochs:  61%|██████    | 61/100 [36:09<23:08, 35.61s/it]

Epoch 61, Validation Loss: 1.6686, Validation Precision: 0.5749
Epoch 62, Train Loss: 1.3497, Precision: 0.6322


Epochs:  62%|██████▏   | 62/100 [36:45<22:33, 35.62s/it]

Epoch 62, Validation Loss: 1.6759, Validation Precision: 0.5739
Epoch 63, Train Loss: 1.3461, Precision: 0.6332


Epochs:  63%|██████▎   | 63/100 [37:20<21:59, 35.66s/it]

Epoch 63, Validation Loss: 1.6646, Validation Precision: 0.5755
Epoch 64, Train Loss: 1.3411, Precision: 0.6341


Epochs:  64%|██████▍   | 64/100 [37:56<21:26, 35.73s/it]

Epoch 64, Validation Loss: 1.6779, Validation Precision: 0.5743
Epoch 65, Train Loss: 1.3369, Precision: 0.6349


Epochs:  65%|██████▌   | 65/100 [38:32<20:51, 35.74s/it]

Epoch 65, Validation Loss: 1.6723, Validation Precision: 0.5751
Epoch 66, Train Loss: 1.3330, Precision: 0.6357


Epochs:  66%|██████▌   | 66/100 [39:08<20:15, 35.75s/it]

Epoch 66, Validation Loss: 1.6677, Validation Precision: 0.5761
Epoch 67, Train Loss: 1.3286, Precision: 0.6366


Epochs:  67%|██████▋   | 67/100 [39:44<19:38, 35.72s/it]

Epoch 67, Validation Loss: 1.6629, Validation Precision: 0.5765
Epoch 68, Train Loss: 1.3246, Precision: 0.6375


Epochs:  68%|██████▊   | 68/100 [40:19<19:01, 35.67s/it]

Epoch 68, Validation Loss: 1.6661, Validation Precision: 0.5766
Epoch 69, Train Loss: 1.3201, Precision: 0.6381


Epochs:  69%|██████▉   | 69/100 [40:55<18:25, 35.66s/it]

Epoch 69, Validation Loss: 1.6761, Validation Precision: 0.5750
Epoch 70, Train Loss: 1.3169, Precision: 0.6389


Epochs:  70%|███████   | 70/100 [41:30<17:49, 35.66s/it]

Epoch 70, Validation Loss: 1.6713, Validation Precision: 0.5748
Epoch 71, Train Loss: 1.3130, Precision: 0.6399


Epochs:  71%|███████   | 71/100 [42:06<17:14, 35.68s/it]

Epoch 71, Validation Loss: 1.6658, Validation Precision: 0.5771
Epoch 72, Train Loss: 1.3090, Precision: 0.6405


Epochs:  72%|███████▏  | 72/100 [42:42<16:37, 35.63s/it]

Epoch 72, Validation Loss: 1.6550, Validation Precision: 0.5792
Epoch 73, Train Loss: 1.3051, Precision: 0.6413


Epochs:  73%|███████▎  | 73/100 [43:17<16:02, 35.66s/it]

Epoch 73, Validation Loss: 1.6592, Validation Precision: 0.5784
Epoch 74, Train Loss: 1.3014, Precision: 0.6423


Epochs:  74%|███████▍  | 74/100 [43:53<15:27, 35.69s/it]

Epoch 74, Validation Loss: 1.6625, Validation Precision: 0.5789
Epoch 75, Train Loss: 1.2983, Precision: 0.6428


Epochs:  75%|███████▌  | 75/100 [44:29<14:51, 35.65s/it]

Epoch 75, Validation Loss: 1.6578, Validation Precision: 0.5784
Epoch 76, Train Loss: 1.2945, Precision: 0.6435


Epochs:  76%|███████▌  | 76/100 [45:04<14:14, 35.60s/it]

Epoch 76, Validation Loss: 1.6732, Validation Precision: 0.5769
Epoch 77, Train Loss: 1.2912, Precision: 0.6444


Epochs:  77%|███████▋  | 77/100 [45:40<13:37, 35.56s/it]

Epoch 77, Validation Loss: 1.6585, Validation Precision: 0.5788
Epoch 78, Train Loss: 1.2881, Precision: 0.6450


Epochs:  78%|███████▊  | 78/100 [46:15<13:01, 35.51s/it]

Epoch 78, Validation Loss: 1.6651, Validation Precision: 0.5783
Epoch 79, Train Loss: 1.2848, Precision: 0.6457


Epochs:  79%|███████▉  | 79/100 [46:51<12:26, 35.56s/it]

Epoch 79, Validation Loss: 1.6646, Validation Precision: 0.5781
Epoch 80, Train Loss: 1.2817, Precision: 0.6462


Epochs:  80%|████████  | 80/100 [47:26<11:51, 35.58s/it]

Epoch 80, Validation Loss: 1.6697, Validation Precision: 0.5784
Epoch 81, Train Loss: 1.2787, Precision: 0.6469


Epochs:  81%|████████  | 81/100 [48:02<11:15, 35.54s/it]

Epoch 81, Validation Loss: 1.6692, Validation Precision: 0.5774
Epoch 82, Train Loss: 1.2758, Precision: 0.6474


Epochs:  82%|████████▏ | 82/100 [48:37<10:39, 35.54s/it]

Epoch 82, Validation Loss: 1.6616, Validation Precision: 0.5778
Epoch 83, Train Loss: 1.2725, Precision: 0.6481


Epochs:  83%|████████▎ | 83/100 [49:13<10:04, 35.57s/it]

Epoch 83, Validation Loss: 1.6527, Validation Precision: 0.5806
Epoch 84, Train Loss: 1.2689, Precision: 0.6489


Epochs:  84%|████████▍ | 84/100 [49:49<09:28, 35.56s/it]

Epoch 84, Validation Loss: 1.6725, Validation Precision: 0.5790
Epoch 85, Train Loss: 1.2659, Precision: 0.6494


Epochs:  85%|████████▌ | 85/100 [50:24<08:53, 35.56s/it]

Epoch 85, Validation Loss: 1.6652, Validation Precision: 0.5792
Epoch 86, Train Loss: 1.2639, Precision: 0.6500


Epochs:  86%|████████▌ | 86/100 [51:00<08:17, 35.55s/it]

Epoch 86, Validation Loss: 1.6742, Validation Precision: 0.5777
Epoch 87, Train Loss: 1.2608, Precision: 0.6506


Epochs:  87%|████████▋ | 87/100 [51:35<07:41, 35.52s/it]

Epoch 87, Validation Loss: 1.6744, Validation Precision: 0.5777
Epoch 88, Train Loss: 1.2571, Precision: 0.6513


Epochs:  88%|████████▊ | 88/100 [52:10<07:05, 35.49s/it]

Epoch 88, Validation Loss: 1.6614, Validation Precision: 0.5798
Epoch 89, Train Loss: 1.2548, Precision: 0.6519


Epochs:  89%|████████▉ | 89/100 [52:46<06:30, 35.53s/it]

Epoch 89, Validation Loss: 1.6725, Validation Precision: 0.5783
Epoch 90, Train Loss: 1.2530, Precision: 0.6522


Epochs:  90%|█████████ | 90/100 [53:22<05:54, 35.50s/it]

Epoch 90, Validation Loss: 1.6762, Validation Precision: 0.5792
Epoch 91, Train Loss: 1.2499, Precision: 0.6528


Epochs:  91%|█████████ | 91/100 [53:57<05:19, 35.48s/it]

Epoch 91, Validation Loss: 1.6660, Validation Precision: 0.5799
Epoch 92, Train Loss: 1.2474, Precision: 0.6536


Epochs:  92%|█████████▏| 92/100 [54:32<04:43, 35.46s/it]

Epoch 92, Validation Loss: 1.6820, Validation Precision: 0.5769
Epoch 93, Train Loss: 1.2441, Precision: 0.6541


Epochs:  93%|█████████▎| 93/100 [55:08<04:08, 35.50s/it]

Epoch 93, Validation Loss: 1.6853, Validation Precision: 0.5765
Epoch 94, Train Loss: 1.2420, Precision: 0.6547


Epochs:  94%|█████████▍| 94/100 [55:44<03:33, 35.52s/it]

Epoch 94, Validation Loss: 1.6771, Validation Precision: 0.5775
Epoch 95, Train Loss: 1.2388, Precision: 0.6552


Epochs:  95%|█████████▌| 95/100 [56:19<02:57, 35.47s/it]

Epoch 95, Validation Loss: 1.6656, Validation Precision: 0.5785
Epoch 96, Train Loss: 1.2367, Precision: 0.6557


Epochs:  96%|█████████▌| 96/100 [56:54<02:21, 35.48s/it]

Epoch 96, Validation Loss: 1.6687, Validation Precision: 0.5796
Epoch 97, Train Loss: 1.2344, Precision: 0.6563


Epochs:  97%|█████████▋| 97/100 [57:30<01:46, 35.52s/it]

Epoch 97, Validation Loss: 1.6690, Validation Precision: 0.5799
Epoch 98, Train Loss: 1.2313, Precision: 0.6569


Epochs:  98%|█████████▊| 98/100 [58:05<01:11, 35.50s/it]

Epoch 98, Validation Loss: 1.6801, Validation Precision: 0.5781
Epoch 99, Train Loss: 1.2296, Precision: 0.6572


Epochs:  99%|█████████▉| 99/100 [58:41<00:35, 35.48s/it]

Epoch 99, Validation Loss: 1.6857, Validation Precision: 0.5778
Epoch 100, Train Loss: 1.2273, Precision: 0.6577


Epochs: 100%|██████████| 100/100 [59:16<00:00, 35.57s/it]

Epoch 100, Validation Loss: 1.6781, Validation Precision: 0.5795





## Evaluation and Audio Generation

In [11]:
# Evaluate the model and generate audio from the trained Transformer.
model.eval()
num_samples = 5
seq_len = LEVELS * TIMESTEPS
temperature = 1.0
os.makedirs('generated_audio', exist_ok=True)

for i in range(num_samples):
    print(f"Generating sample {i+1}/{num_samples}")
    start_token = 0
    generated = [start_token]
    for _ in tqdm(range(seq_len), desc="Generating Tokens", leave=False):
        input_seq = torch.tensor([generated], dtype=torch.long).to(device)
        with torch.no_grad():
            logits = model(input_seq)[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_token)
    generated_sequence = torch.tensor(generated[1:], dtype=torch.long).to(device)
    codes = generated_sequence.view(1, -1, LEVELS).transpose(1, 2)

    with torch.no_grad():
        decoded_audio = codec.decode([(codes, None)])
    decoded_audio = decoded_audio.squeeze().cpu().numpy().astype(np.float32)

    output_filename = f'generated_audio/sample_{i+1}.wav'
    sf.write(output_filename, decoded_audio, samplerate=codec.sample_rate)
    print(f"Saved {output_filename}")

Generating sample 1/5


                                                                     

Saved generated_audio/sample_1.wav
Generating sample 2/5


                                                                     

Saved generated_audio/sample_2.wav
Generating sample 3/5


                                                                     

Saved generated_audio/sample_3.wav
Generating sample 4/5


                                                                     

Saved generated_audio/sample_4.wav
Generating sample 5/5


                                                                     

Saved generated_audio/sample_5.wav


## Play Generated Audio

In [12]:
# Use IPython audio player to play generated audio samples.
for i in range(1, num_samples + 1):
    output_filename = f'generated_audio/sample_{i}.wav'
    print(f"Playing {output_filename}")
    ipd.display(ipd.Audio(output_filename))

Playing generated_audio/sample_1.wav


Playing generated_audio/sample_2.wav


Playing generated_audio/sample_3.wav


Playing generated_audio/sample_4.wav


Playing generated_audio/sample_5.wav
