# Jupyter Notebook for Autoregressive Audio Generation

## Install Necessary Libraries

This section installs the libraries the project depends on:
- `x-transformers` for building the Transformer model.
- `encodec` for the EnCodec neural audio codec.
- `soundfile` for reading and writing audio files.

In [None]:
# Install the required libraries for the project.
!pip install x-transformers encodec soundfile

## Import Libraries

We import the modules needed for training and evaluation: PyTorch, `x-transformers`, and `encodec`, plus helpers for audio I/O, data loading, and playback.

In [2]:
# Import necessary modules for training and evaluation.
import torch
from tqdm import tqdm
from x_transformers import TransformerWrapper, Decoder
from encodec import EncodecModel
from torch import nn
import soundfile as sf
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
from encodec.utils import convert_audio
import IPython.display as ipd

## Set Parameters

Here we define the global parameters for the project: the EnCodec target bandwidth, the number of codebook levels, and the number of timesteps (codec frames) in each sequence. We also select the GPU as the device when one is available.

In [3]:
# Define some global parameters for the project.
BANDWIDTH = 1.5  # Target bandwidth in kbps for EnCodec
LEVELS = 2       # Number of codebooks used in EnCodec (2 for bandwidth 1.5)
TIMESTEPS = 125  # Number of timesteps in each sequence

# Set device to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda
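
The relationship between these three parameters can be checked with a little arithmetic. The sketch below assumes the published characteristics of the 24 kHz EnCodec model (75 codec frames per second and 1024-entry codebooks, i.e. 10 bits per code). The names `FRAME_RATE_HZ` and `CODEBOOK_SIZE` are illustrative and are not defined elsewhere in this notebook.

```python
import math

FRAME_RATE_HZ = 75      # assumption: frame rate of the 24 kHz EnCodec model
CODEBOOK_SIZE = 1024    # assumption: number of entries in each codebook
BANDWIDTH_KBPS = 1.5    # same value as BANDWIDTH above
TIMESTEPS = 125         # same value as TIMESTEPS above

bits_per_code = math.log2(CODEBOOK_SIZE)            # 10 bits per code
bps_per_codebook = FRAME_RATE_HZ * bits_per_code    # 750 bit/s per codebook
levels = round(BANDWIDTH_KBPS * 1000 / bps_per_codebook)

tokens_per_sequence = levels * TIMESTEPS            # length of one training sequence
seconds_per_sequence = TIMESTEPS / FRAME_RATE_HZ    # audio covered by one sequence

print(levels, tokens_per_sequence, round(seconds_per_sequence, 2))  # 2 250 1.67
```

Under these assumptions, a bandwidth of 1.5 kbps corresponds to 2 codebooks, and each training sequence holds 250 tokens covering about 1.67 seconds of audio.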


## Download NSYNTH_GUITAR_MP3 dataset

We download the guitar subset of the NSynth dataset, which we'll use for training our model. The dataset contains audio samples of guitar notes.

In [4]:
# Clone the repository containing the NSYNTH_GUITAR_MP3 dataset
!git clone https://github.com/SonyCSLParis/test-lfs.git
# Run the download script to get the dataset
!bash ./test-lfs/download.sh NSYNTH_GUITAR_MP3

Cloning into 'test-lfs'...
remote: Enumerating objects: 42, done.
remote: Counting objects: 100% (42/42), done.
remote: Compressing objects: 100% (34/34), done.
remote: Total 42 (delta 5), reused 40 (delta 3), pack-reused 0 (from 0)
Unpacking objects: 100% (42/42), 5.92 KiB | 433.00 KiB/s, done.
--2024-10-21 13:37:15--  https://media.githubusercontent.com/media/SonyCSLParis/test-lfs/refs/heads/master/NSYNTH_GUITAR_MP3.zip
Resolving media.githubusercontent.com (media.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to media.githubusercontent.com (media.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 334999208 (319M) [application/zip]
Saving to: ‘NSYNTH_GUITAR_MP3.zip’


2024-10-21 13:37:57 (64.1 MB/s) - ‘NSYNTH_GUITAR_MP3.zip’ saved [334999208/334999208]

Fix archive (-F) - assume mostly intact archive
Zip entry offsets do not need adjusting
 copying:


  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_synthetic_004-108-050.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_acoustic_024-097-127.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_014-074-025.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_synthetic_000-098-050.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_013-067-127.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_acoustic_034-092-025.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_003-091-025.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_synthetic_011-107-127.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_synthetic_005-021-100.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_acoustic_020-047-050.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-test/guitar_acoustic_010-098-100.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guita

## Dataset Class Definition

We define a custom `Dataset` class `DiscreteAudioRepDataset` to handle the discrete audio representations. This class loads audio files, encodes them using EnCodec, and provides sequences of discrete tokens suitable for training our Transformer model.

In [5]:
# Define a dataset class to handle discrete audio representations.
class DiscreteAudioRepDataset(Dataset):
    def __init__(self, root_dir, encoder, lazy_encode=True,
                 extensions=(".wav", ".mp3", ".flac"), max_samples=-1):
        """
        Args:
            root_dir (string): Directory with all the audio files.
            encoder: The EnCodec model for encoding audio.
            lazy_encode (bool): If True, encodes audio on-demand (when __getitem__ is called).
                               If False, encodes all audio at initialization.
            extensions (tuple): Audio file extensions to include.
            max_samples (int): Maximum number of samples to load. -1 for all samples.
        """
        self.root_dir = root_dir
        self.extensions = extensions
        self.encoder = encoder
        self.lazy_encode = lazy_encode
        self.audio_files = []

        # Walk through all subfolders to gather audio files
        for root, _, files in os.walk(root_dir):
            for file in files:
                if any(file.endswith(ext) for ext in self.extensions):
                    self.audio_files.append(os.path.join(root, file))

        if max_samples < 0:
            max_samples = len(self.audio_files)

        self.audio_files = self.audio_files[:max_samples]

        # If not lazy encoding, encode all audio files during initialization
        if not self.lazy_encode:
            self.encoded_data = [self._encode_audio(file) for file in tqdm(self.audio_files, desc="Encoding Audio")]

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        if self.lazy_encode:
            # Encode the audio file on-the-fly
            filename = self.audio_files[idx]
            return self._encode_audio(filename)[..., :TIMESTEPS*LEVELS]
        else:
            # Return the pre-encoded audio data
            encoded_audio = self.encoded_data[idx]
            return encoded_audio[:TIMESTEPS * LEVELS]

    def _encode_audio(self, filename):
        # Read the audio file (assumed mono, as in NSynth, so the shape is [num_samples])
        waveform, sample_rate = sf.read(filename)
        # Convert to tensor and add batch and channel dimensions -> [1, 1, num_samples]
        waveform = torch.tensor(waveform, dtype=torch.float32)[None, None, :]
        # Resample and adjust channels if necessary
        waveform = convert_audio(waveform, sample_rate, self.encoder.sample_rate, self.encoder.channels)

        with torch.no_grad():
            # Encode the audio using EnCodec
            encoded_frames = self.encoder.encode(waveform.to(device))

        # Interleave the codebooks frame by frame: [1, LEVELS, T] -> [1, T, LEVELS] -> flat [T * LEVELS]
        codes = encoded_frames[0][0].contiguous().permute(0, 2, 1).reshape(-1)
        return codes.cpu()
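
To make the token layout concrete, here is a small pure-Python sketch of the frame-by-frame interleaving that `permute(0, 2, 1).reshape(-1)` performs on the `[codebooks, frames]` codes, together with the inverse operation needed before decoding. The helpers `interleave` and `deinterleave` and the toy code values are illustrative only; they are not part of the notebook or of EnCodec.

```python
def interleave(codes):
    """codes[k][t] -> flat list ordered frame by frame: t0k0, t0k1, t1k0, ..."""
    num_levels, num_frames = len(codes), len(codes[0])
    return [codes[k][t] for t in range(num_frames) for k in range(num_levels)]


def deinterleave(flat, num_levels):
    """Inverse of interleave: flat list -> codes[k][t]."""
    return [flat[k::num_levels] for k in range(num_levels)]


codes = [[10, 11, 12],   # codebook 0, frames 0..2
         [20, 21, 22]]   # codebook 1, frames 0..2

flat = interleave(codes)
print(flat)                                   # [10, 20, 11, 21, 12, 22]
restored = deinterleave(flat, num_levels=2)
print(restored == codes)                      # True
```

With this layout, truncating the flat sequence to `TIMESTEPS * LEVELS` tokens keeps whole frames, and every even position holds a code from the first codebook.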


## Load EnCodec Model and Convert Files

We load the EnCodec model, which is a neural audio codec that compresses audio signals into discrete latent codes. We set the target bandwidth and prepare our datasets for training and validation by encoding the audio files into discrete tokens.

In [6]:
# Load the EnCodec model to transform the audio to discrete representation.
codec = EncodecModel.encodec_model_24khz().to(device)
codec.set_target_bandwidth(BANDWIDTH)  # Set the target bandwidth (e.g., 1.5 kbps)

# Load Dataset
# Prepare the training and validation datasets by encoding the audio files
audio_folder_train = "./NSYNTH_GUITAR_MP3/nsynth-guitar-train"
audio_folder_val = "./NSYNTH_GUITAR_MP3/nsynth-guitar-valid"

dataset = DiscreteAudioRepDataset(root_dir=audio_folder_train, encoder=codec,
                                  lazy_encode=False, max_samples=-1)

dataset_val = DiscreteAudioRepDataset(root_dir=audio_folder_val, encoder=codec,
                                      lazy_encode=False, max_samples=-1)

Encoding Audio: 100%|██████████| 32690/32690 [05:42<00:00, 95.57it/s] 
Encoding Audio: 100%|██████████| 2081/2081 [00:22<00:00, 93.69it/s]


In [7]:
# Create Dataloaders for training and validation.
dataloader = DataLoader(dataset, batch_size=125, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=125, shuffle=False)  # No shuffling needed for validation

## Define Transformer Model

We define our Transformer model using the `x-transformers` library. The model is an autoregressive Transformer decoder with rotary positional embeddings, suitable for modeling sequences of discrete tokens.

In [8]:
# Define the Transformer model using TransformerWrapper and Decoder.
model = TransformerWrapper(
    emb_dropout=0.1,
    num_tokens=1024,              # Vocabulary size: EnCodec codebook size (codes 0..1023)
    max_seq_len=LEVELS*TIMESTEPS, # Maximum sequence length
    attn_layers=Decoder(
        dim=256,                  # Dimension of the model
        depth=6,                  # Number of Transformer layers
        heads=4,                  # Number of attention heads
        rotary_pos_emb=True,      # Use rotary positional embeddings
        attn_dropout=0.1,
        ff_dropout=0.1
    )
).to(device)
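
Rotary positional embeddings encode position by rotating pairs of query and key features by an angle proportional to the token position, so that attention scores depend only on the relative offset between tokens. The pure-Python sketch below illustrates this property on a single 2-D feature pair. The function `rotate` and the angle step `theta` are illustrative and do not reflect the internals of `x-transformers`.

```python
import math


def rotate(vec, position, theta=0.1):
    """Rotate a 2-D feature pair by position * theta radians."""
    x, y = vec
    angle = position * theta
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)


def dot(a, b):
    """Plain dot product of two 2-D vectors."""
    return a[0] * b[0] + a[1] * b[1]


q, k = (1.0, 2.0), (0.5, -1.0)

# Same relative offset (2 positions) at different absolute positions.
score_near = dot(rotate(q, 3), rotate(k, 1))
score_far = dot(rotate(q, 10), rotate(k, 8))
# A different relative offset (3 positions) gives a different score.
score_other = dot(rotate(q, 3), rotate(k, 0))

print(math.isclose(score_near, score_far))    # True
print(math.isclose(score_near, score_other))  # False
```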

## (Optional) Load Pretrained Weights

If we have a pretrained model saved, we can load its weights to continue training or for inference.

In [None]:
# Optionally load pretrained weights
# Uncomment the following line if you have a pretrained model saved
# model.load_state_dict(torch.load('./model_gen_autoreg_transformer.pth', map_location=device))

## Training and Validation loop

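We define the training loop for our Transformer model. The model is trained to minimize the cross-entropy loss between the predicted token distribution and the true next token in the sequence. We use teacher forcing during training, providing the true previous tokens as input. The value logged as "Precision" is the next-token accuracy: the fraction of positions at which the most likely predicted token equals the true token.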
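
The shift between inputs and targets under teacher forcing can be illustrated on a toy sequence. This is a minimal sketch with made-up token values; `START_TOKEN` mirrors the id 0 that the notebook prepends, and the remaining names are illustrative.

```python
START_TOKEN = 0                  # same id the notebook prepends to each sequence
tokens = [517, 23, 901, 4]       # toy stand-in for an encoded audio clip

sequence = [START_TOKEN] + tokens
inputs = sequence[:-1]           # what the model conditions on
targets = sequence[1:]           # what it must predict at each position

for position, target in enumerate(targets):
    context = sequence[: position + 1]   # true tokens seen so far
    print(f"position {position}: context={context} -> target={target}")
```

Because the decoder is causal, feeding the full sequence and discarding the last logit produces the same input/target pairs as slicing the inputs beforehand.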

In [10]:
# Define the training loop for the Transformer model.
epochs = 100
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

for epoch in tqdm(range(epochs), desc="Epochs"):
    train_loss, val_loss = 0, 0
    model.train()
    total_correct, total_predictions = 0, 0
    count = 0

    # Training loop
    for batch in dataloader:
        # Ensure batch is of type long (integer tokens)
        if batch.dtype != torch.long:
            batch = batch.long()
        # Prepend a start token (id 0, which is also a valid EnCodec code) to each sequence
        start_tokens = torch.zeros((batch.shape[0], 1), dtype=torch.long, device=batch.device)
        batch = torch.cat([start_tokens, batch], dim=1)
        discrete_reps = batch.to(device)

        # Forward pass
        logits = model(discrete_reps)
        logits = logits.permute(0, 2, 1)  # Rearrange logits for loss computation
        
        # Compute loss: the logit at position t predicts token t+1,
        # so drop the last logit and the leading start token
        loss = criterion(logits[..., :-1], discrete_reps[..., 1:])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
        count += 1

        # Compute training next-token accuracy (logged as "Precision")
        preds = logits[..., :-1].argmax(dim=1)
        targets = discrete_reps[..., 1:]
        correct = (preds == targets).sum().item()
        total_correct += correct
        total_predictions += targets.numel()

    avg_train_loss = train_loss / count
    precision = total_correct / total_predictions
    print(f'Epoch {epoch + 1}, Train Loss: {avg_train_loss:.4f}, Precision: {precision:.4f}')
    
    # Save the model after each epoch
    torch.save(model.state_dict(), 'model_gen_autoreg_transformer.pth')
    
    # Validation loop
    model.eval()
    val_loss = 0
    val_total_correct, val_total_predictions = 0, 0
    count_val = 0

    with torch.no_grad():
        for batch_val in dataloader_val:
            if batch_val.dtype != torch.long:
                batch_val = batch_val.long()
            start_tokens_val = torch.zeros((batch_val.shape[0], 1), dtype=torch.long, device=batch_val.device)
            batch_val = torch.cat([start_tokens_val, batch_val], dim=1)
            discrete_reps_val = batch_val.to(device)

            # Forward pass
            logits_val = model(discrete_reps_val)
            logits_val = logits_val.permute(0, 2, 1)
            loss_val = criterion(logits_val[..., :-1], discrete_reps_val[..., 1:])
            val_loss += loss_val.item()
            count_val += 1

            # Compute validation next-token accuracy (logged as "Validation Precision")
            preds_val = logits_val[..., :-1].argmax(dim=1)
            targets_val = discrete_reps_val[..., 1:]
            correct_val = (preds_val == targets_val).sum().item()
            val_total_correct += correct_val
            val_total_predictions += targets_val.numel()

    avg_val_loss = val_loss / count_val
    val_precision = val_total_correct / val_total_predictions
    print(f'Epoch {epoch + 1}, Validation Loss: {avg_val_loss:.4f}, Validation Precision: {val_precision:.4f}')


Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1, Train Loss: 4.6227, Precision: 0.2465


Epochs:  10%|█         | 10/100 [05:54<53:19, 35.55s/it]

Epoch 10, Validation Loss: 2.2304, Validation Precision: 0.4731
Epoch 11, Train Loss: 2.0314, Precision: 0.5060


Epochs:  20%|██        | 20/100 [11:50<47:30, 35.63s/it]

Epoch 20, Validation Loss: 1.9364, Validation Precision: 0.5255
Epoch 21, Train Loss: 1.7311, Precision: 0.5603


Epochs:  30%|███       | 30/100 [17:46<41:29, 35.57s/it]

Epoch 30, Validation Loss: 1.7924, Validation Precision: 0.5507
Epoch 31, Train Loss: 1.5746, Precision: 0.5888


Epochs:  40%|████      | 40/100 [23:41<35:38, 35.65s/it]

Epoch 40, Validation Loss: 1.7296, Validation Precision: 0.5619
Epoch 41, Train Loss: 1.4773, Precision: 0.6073


Epochs:  50%|█████     | 50/100 [29:37<29:37, 35.55s/it]

Epoch 50, Validation Loss: 1.6898, Validation Precision: 0.5706
Epoch 51, Train Loss: 1.4072, Precision: 0.6209


Epochs:  60%|██████    | 60/100 [35:33<23:44, 35.61s/it]

Epoch 60, Validation Loss: 1.6703, Validation Precision: 0.5744
Epoch 61, Train Loss: 1.3546, Precision: 0.6312


Epochs:  70%|███████   | 70/100 [41:30<17:49, 35.66s/it]

Epoch 70, Validation Loss: 1.6713, Validation Precision: 0.5748
Epoch 71, Train Loss: 1.3130, Precision: 0.6399


Epochs:  80%|████████  | 80/100 [47:26<11:51, 35.58s/it]

Epoch 80, Validation Loss: 1.6697, Validation Precision: 0.5784
Epoch 81, Train Loss: 1.2787, Precision: 0.6469


Epochs:  90%|█████████ | 90/100 [53:22<05:54, 35.50s/it]

Epoch 90, Validation Loss: 1.6762, Validation Precision: 0.5792
Epoch 91, Train Loss: 1.2499, Precision: 0.6528


Epoch 99, Validation Loss: 1.6857, Validation Precision: 0.5778
Epoch 100, Train Loss: 1.2273, Precision: 0.6577


Epochs: 100%|██████████| 100/100 [59:16<00:00, 35.57s/it]

Epoch 100, Validation Loss: 1.6781, Validation Precision: 0.5795
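
To put the logged losses in context, the following sketch compares the final validation loss with the loss of a uniform guess over the vocabulary and converts it to a perplexity. It assumes the loss is measured in nats, which is what `nn.CrossEntropyLoss` reports; the variable names are illustrative.

```python
import math

VOCAB_SIZE = 1024          # num_tokens of the model
final_val_loss = 1.6781    # validation loss logged at epoch 100

uniform_loss = math.log(VOCAB_SIZE)      # loss of a uniform guess over all tokens
perplexity = math.exp(final_val_loss)    # effective number of equally likely tokens

print(f"uniform baseline: {uniform_loss:.2f} nats")
print(f"validation perplexity: {perplexity:.2f}")
```

A uniform guess costs about 6.93 nats per token, so a validation loss of 1.6781 means the model has narrowed each prediction to roughly 5.4 equally likely tokens on average.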





## Audio Generation

After training, we use the model to generate new audio sequences. Generation begins from the start token; at each step we sample the next token from the model's output distribution, scaled by a temperature, and append it to the sequence. The generated tokens are then decoded back into audio using EnCodec.
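
The effect of the sampling temperature can be seen on a toy example. This is a minimal pure-Python sketch of temperature-scaled sampling, not the notebook's implementation; `softmax_with_temperature`, `sample_token`, and the logit values are illustrative.

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    """Turn raw scores into probabilities; a lower temperature sharpens them."""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)                               # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


def sample_token(logits, temperature=1.0, rng=random):
    """Draw one token id from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


toy_logits = [2.0, 1.0, 0.1]                         # scores for a 3-token vocabulary
rng = random.Random(0)                               # seeded for reproducibility

probs_default = softmax_with_temperature(toy_logits, 1.0)
probs_sharp = softmax_with_temperature(toy_logits, 0.5)
samples = [sample_token(toy_logits, 1.0, rng) for _ in range(10)]

print(probs_default)
print(probs_sharp)    # more mass on the highest-scoring token
print(samples)
```

Temperatures below 1.0 make generation more conservative, while values above 1.0 flatten the distribution and increase variety at the risk of noisier audio.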

In [11]:
# Generate audio from the trained Transformer model.
model.eval()
num_samples = 5             # Number of audio samples to generate
seq_len = LEVELS * TIMESTEPS  # Length of each generated sequence
temperature = 1.0           # Sampling temperature
os.makedirs('generated_audio', exist_ok=True)

for i in range(num_samples):
    print(f"Generating sample {i+1}/{num_samples}")
    start_token = 0
    generated = [start_token]
    
    # Generate a sequence of tokens
    for _ in tqdm(range(seq_len), desc="Generating Tokens", leave=False):
        input_seq = torch.tensor([generated], dtype=torch.long).to(device)
        with torch.no_grad():
            logits = model(input_seq)[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            # Sample the next token
            next_token = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_token)
    
    # Remove the start token and convert to tensor
    generated_sequence = torch.tensor(generated[1:], dtype=torch.long).to(device)
    
    # De-interleave into [1, LEVELS, TIMESTEPS], the [batch, codebooks, frames] layout EnCodec decodes
    codes = generated_sequence.view(1, -1, LEVELS).transpose(1, 2)

    with torch.no_grad():
        # Decode the sequence of tokens back into audio waveform
        decoded_audio = codec.decode([(codes, None)])
    decoded_audio = decoded_audio.squeeze().cpu().numpy().astype(np.float32)

    # Save the generated audio to a WAV file
    output_filename = f'generated_audio/sample_{i+1}.wav'
    sf.write(output_filename, decoded_audio, samplerate=codec.sample_rate)
    print(f"Saved {output_filename}")

Generating sample 1/5
Saved generated_audio/sample_1.wav
Generating sample 2/5
Saved generated_audio/sample_2.wav
Generating sample 3/5
Saved generated_audio/sample_3.wav
Generating sample 4/5
Saved generated_audio/sample_4.wav
Generating sample 5/5
Saved generated_audio/sample_5.wav


## Play Generated Audio

Finally, we can load and play the generated audio samples to listen to the results of our autoregressive generation.

In [12]:
# Use IPython audio player to play generated audio samples.
for i in range(1, num_samples + 1):
    output_filename = f'generated_audio/sample_{i}.wav'  # Must match the filenames written during generation
    print(f"Playing {output_filename}")
    ipd.display(ipd.Audio(output_filename))

Playing generated_audio/sample_1.wav


Playing generated_audio/sample_2.wav


Playing generated_audio/sample_3.wav


Playing generated_audio/sample_4.wav


Playing generated_audio/sample_5.wav
