In [None]:
# Unified Jupyter Notebook for Discrete Audio Generation

# ## Install Necessary Libraries
# First, we need to install the required libraries for the project.
!pip install torchaudio x-transformers encodec

In [None]:
# ## Import Libraries
# Import necessary modules for training and evaluation.
import torch
from tqdm import tqdm
from x_transformers import TransformerWrapper, Decoder
from encodec import EncodecModel
from torch import nn
import soundfile as sf
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
from encodec.utils import convert_audio
import IPython.display as ipd

In [2]:
# ## Set Parameters
# Define some global parameters for the project.
LEVELS = 2       # number of code values kept per time step in the flattened token sequence
TIMESTEPS = 125  # number of time steps kept per example (sequence length = LEVELS * TIMESTEPS)
SEED = 0         # fixed seed so shuffling, dropout and sampling are repeatable across runs

# Seed PyTorch's random number generators before any stochastic work happens.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# ## Dataset Class Definition
# Define a dataset class to handle discrete audio representations.
class DiscreteAudioRepDataset(Dataset):
    """Dataset of flattened discrete code sequences, one per audio file under a directory tree.

    Every file below ``root_dir`` whose name ends with one of ``extensions``
    (compared case-insensitively) becomes one example. Examples are 1-D tensors
    truncated to ``TIMESTEPS * LEVELS`` tokens.
    """

    def __init__(self, root_dir, model, lazy_encode=True, extensions=[".wav", ".mp3", ".flac"]):
        """
        Args:
            root_dir (string): Directory with all the audio files (searched recursively).
            model: The EnCodec model for encoding audio.
            lazy_encode (bool): If True, encodes audio on-demand (when __getitem__ is called).
                               If False, encodes all audio at initialization.
            extensions (list): List of audio file extensions to include.
        """
        self.root_dir = root_dir
        # Copy so that changes to one instance's list never leak into the shared default.
        self.extensions = list(extensions)
        self.model = model
        self.lazy_encode = lazy_encode

        # Walk through all subfolders to gather audio files. Matching is
        # case-insensitive (".WAV" counts as ".wav") and the result is sorted so
        # that index -> file is stable regardless of filesystem listing order.
        suffixes = tuple(ext.lower() for ext in self.extensions)
        self.audio_files = sorted(
            os.path.join(root, file)
            for root, _, files in os.walk(root_dir)
            for file in files
            if file.lower().endswith(suffixes)
        )

        # If not lazy encoding, encode all audio files during initialization
        if not self.lazy_encode:
            self.encoded_data = [self._encode_audio(file) for file in tqdm(self.audio_files, desc="Encoding Audio")]

    def __len__(self):
        """Return the number of audio files found under ``root_dir``."""
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Return the first ``TIMESTEPS * LEVELS`` tokens of example ``idx`` as a 1-D tensor."""
        if self.lazy_encode:
            tokens = self._encode_audio(self.audio_files[idx])
        else:
            tokens = self.encoded_data[idx]
        # Files yielding fewer tokens than the limit are returned shorter, not padded.
        return tokens[:TIMESTEPS * LEVELS]

    def _encode_audio(self, filename):
        """Read ``filename``, encode it with ``self.model`` and return a flat 1-D CPU tensor of codes."""
        # always_2d=True gives a (frames, channels) array for mono and multi-channel
        # files alike, so the layout below is correct in both cases.
        waveform, sample_rate = sf.read(filename, always_2d=True)
        # (frames, channels) -> (channels, frames) -> (1, channels, frames)
        waveform = torch.tensor(waveform, dtype=torch.float32).t().contiguous()[None]
        waveform = convert_audio(waveform, sample_rate, self.model.sample_rate, self.model.channels)

        with torch.no_grad():
            discrete_reps = self.model.encode(waveform.to(device))

        # Only the first item of the first returned element is used. The permute
        # requires it to be rank 3 -- presumably (batch, codebooks, time); confirm
        # against the EnCodec API. Swapping the last two axes before flattening
        # interleaves the values so that all entries of one step are adjacent.
        discrete_reps = discrete_reps[0][0].contiguous().permute(0, 2, 1).reshape(-1)
        return discrete_reps.cpu()

In [None]:
# ## Load EnCodec Model
# Instantiate the EnCodec codec, move it to the selected device and set its
# target bandwidth; the datasets below use it to encode audio.
codec = EncodecModel.encodec_model_24khz()
codec = codec.to(device)
codec.set_target_bandwidth(1.5)

# ## Load Dataset
# Point at the NSYNTH train/validation folders. Each split is encoded up front
# (lazy_encode=False) and then wrapped in a shuffling DataLoader.
audio_folder_train, audio_folder_val = "./NSYNTH/nsynth-train", "./NSYNTH/nsynth-valid"

# Training split
dataset = DiscreteAudioRepDataset(audio_folder_train, codec, lazy_encode=False)
dataloader = DataLoader(dataset, shuffle=True, batch_size=350)

# Validation split
dataset_val = DiscreteAudioRepDataset(audio_folder_val, codec, lazy_encode=False)
dataloader_val = DataLoader(dataset_val, shuffle=True, batch_size=350)

In [None]:
# ## Define Transformer Model
# Build the decoder stack first, then wrap it with TransformerWrapper
# and move the finished model to the selected device.
decoder_stack = Decoder(
    dim=256,
    depth=6,
    heads=4,
    rotary_pos_emb=True,
    attn_dropout=0.1,
    ff_dropout=0.1,
)

model = TransformerWrapper(
    num_tokens=1024,
    max_seq_len=LEVELS * TIMESTEPS,
    attn_layers=decoder_stack,
    emb_dropout=0.1,
)
model = model.to(device)

In [0]:
# Optionally load pretrained weights
# The checkpoint is only loaded when the file exists, so a fresh checkout without
# downloaded weights can still run the notebook top to bottom.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
PRETRAINED_WEIGHTS_PATH = './model_weights/model_gen_autoreg_transformer.pth'

if os.path.exists(PRETRAINED_WEIGHTS_PATH):
    model.load_state_dict(torch.load(PRETRAINED_WEIGHTS_PATH, map_location=device))
    print(f"Loaded pretrained weights from {PRETRAINED_WEIGHTS_PATH}")
else:
    print(f"No pretrained weights found at {PRETRAINED_WEIGHTS_PATH}; using freshly initialized weights.")

In [ ]:
# ## Training Loop
# Train the Transformer with next-token prediction: a start token is prepended to
# every sequence, the model sees all tokens but the last, and is scored on
# predicting each following token.
epochs = 1000
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

for epoch in tqdm(range(epochs), desc="Epochs"):
    model.train()
    train_loss = 0.0
    total_correct, total_predictions = 0, 0
    count = 0

    for batch in dataloader:
        batch = batch.long().to(device)
        # NOTE(review): the start token reuses id 0, which is presumably also a
        # valid code id in the 1024-token vocabulary -- confirm this is intended.
        start_tokens = torch.zeros((batch.shape[0], 1), dtype=torch.long, device=device)
        discrete_reps = torch.cat([start_tokens, batch], dim=1)

        # Feed everything except the final token: its prediction has no target,
        # and dropping it keeps the input within the configured max_seq_len.
        inputs = discrete_reps[:, :-1]
        targets = discrete_reps[:, 1:]

        # Clear gradients first so leftovers from an interrupted run cannot leak in.
        optimizer.zero_grad()
        logits = model(inputs).permute(0, 2, 1)  # CrossEntropyLoss wants (batch, classes, positions)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        count += 1

        # Bookkeeping only; no need to track gradients here.
        with torch.no_grad():
            total_correct += (logits.argmax(dim=1) == targets).sum().item()
        total_predictions += targets.numel()

    if count == 0:
        raise RuntimeError("The training dataloader produced no batches; check the training data folder.")

    avg_train_loss = train_loss / count
    # Fraction of positions whose most likely token equals the target (i.e. accuracy).
    accuracy = total_correct / total_predictions
    print(f'Epoch {epoch + 1}, Train Loss: {avg_train_loss:.4f}, Accuracy: {accuracy:.4f}')
    # Checkpoint after every epoch (overwrites the previous file).
    torch.save(model.state_dict(), 'model.pth')

In [None]:
# ## Evaluation and Audio Generation
# Sample token sequences from the trained Transformer one token at a time,
# decode each sequence with the codec and write the result to a wav file.
model.eval()
num_samples = 5
seq_len = LEVELS * TIMESTEPS
temperature = 1.0
os.makedirs('generated_audio', exist_ok=True)

for i in range(num_samples):
    print(f"Generating sample {i+1}/{num_samples}")
    start_token = 0

    # Running sequence kept as a (1, length) tensor, beginning with the start token.
    tokens = torch.full((1, 1), start_token, dtype=torch.long, device=device)
    with torch.no_grad():
        for _ in tqdm(range(seq_len), desc="Generating Tokens", leave=False):
            # Only the logits at the last position matter for the next token.
            next_logits = model(tokens)[:, -1, :] / temperature
            next_probs = torch.softmax(next_logits, dim=-1)
            sampled = torch.multinomial(next_probs, num_samples=1)
            tokens = torch.cat([tokens, sampled], dim=1)

        # Drop the start token, regroup the flat sequence into LEVELS values per
        # step and swap the last two axes before handing it to the codec.
        codes = tokens[:, 1:].reshape(1, -1, LEVELS).transpose(1, 2)
        decoded_audio = codec.decode([(codes, None)])

    decoded_audio = decoded_audio.squeeze().cpu().numpy().astype(np.float32)

    output_filename = f'generated_audio/sample_{i+1}.wav'
    sf.write(output_filename, decoded_audio, samplerate=codec.sample_rate)
    print(f"Saved {output_filename}")

In [ ]:
# ## Play Generated Audio
# Collect the paths of the generated samples, then show an IPython audio player for each.
sample_paths = [f'generated_audio/sample_{i}.wav' for i in range(1, num_samples + 1)]

for output_filename in sample_paths:
    print(f"Playing {output_filename}")
    ipd.display(ipd.Audio(output_filename))