<a href="https://colab.research.google.com/github/niharali/VQ-VAE/blob/main/VQ_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchaudio
!pip install torch



In [None]:
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchaudio.datasets import LIBRISPEECH
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

**Step 1: Load the LibriSpeech Dataset**

In [None]:
# Fetch the LibriSpeech splits through torchaudio, downloading them when they
# are not already under ./ (any other speech corpus, e.g. one from Kaggle,
# could be substituted here).
train_dataset = LIBRISPEECH(
    root="./",
    url="train-clean-100",
    download=True,
)
test_dataset = LIBRISPEECH(
    root="./",
    url="test-clean",
    download=True,
)

# Batched iterators over both splits: the training split is shuffled, the
# evaluation split keeps its order.
# NOTE(review): no collate_fn is given here, so samples are stacked as-is;
# presumably utterances differ in length -- confirm batching works.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=16,
    shuffle=False,
)

100%|██████████| 5.95G/5.95G [02:59<00:00, 35.6MB/s]
100%|██████████| 331M/331M [00:10<00:00, 32.1MB/s]


**Step 2: Preprocess Data (Mel-Spectrogram Transformation)**

We transform the audio data into Mel-spectrograms:

In [None]:
# Shared waveform -> Mel-spectrogram transform: 80 mel bins, 400-sample
# analysis windows and a 160-sample hop, configured for 16 kHz audio.
mel_spectrogram = transforms.MelSpectrogram(
    sample_rate=16000,
    n_mels=80,
    win_length=400,
    hop_length=160,
)

def process_audio(data):
    """Convert one dataset sample into a Mel-spectrogram.

    Args:
        data: An indexable sample whose first element is the waveform
            tensor. torchaudio's LIBRISPEECH yields six fields per sample
            (waveform, sample_rate, transcript, speaker_id, chapter_id,
            utterance_id); only the waveform is used.

    Returns:
        The output of ``mel_spectrogram`` for the waveform, with dimension 0
        squeezed away when it has size 1.
    """
    # Index rather than unpack a fixed number of fields: the previous
    # five-way unpack raised "too many values to unpack (expected 5)" on
    # LibriSpeech's six-field samples.
    waveform = data[0]
    return mel_spectrogram(waveform).squeeze(0)

**Step 3: Define the VQ-VAE Model**

In [None]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour codebook quantizer with a straight-through gradient.

    The input is flattened so that its LAST axis is the embedding axis, i.e.
    every run of ``embedding_dim`` values along the last axis is treated as
    one vector to quantize. Callers holding channels-first tensors need to
    move the channel axis last before calling this module.

    Args:
        embedding_dim: Length of each codebook vector.
        num_embeddings: Number of vectors in the codebook.
        commitment_cost: Weight of the commitment term in the returned loss.
    """

    def __init__(self, embedding_dim, num_embeddings, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # Codebook, initialised uniformly in [-1/K, 1/K] for K embeddings.
        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)

    def forward(self, inputs):
        """Quantize ``inputs`` to their nearest codebook vectors.

        Args:
            inputs: Tensor whose last axis has length ``embedding_dim``.

        Returns:
            A ``(quantized, loss)`` tuple. ``quantized`` has the shape of
            ``inputs`` and holds the selected codebook vectors, while passing
            gradients straight through to ``inputs``. ``loss`` is the
            codebook loss plus ``commitment_cost`` times the commitment loss.
        """
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # permute, are accepted instead of raising.
        flat_input = inputs.reshape(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # Squared Euclidean distance via ||x||^2 + ||e||^2 - 2 x.e
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(codebook ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, codebook.t()))

        # Index of the closest codebook entry for every input vector.
        encoding_indices = torch.argmin(distances, dim=1)

        # Direct lookup selects the same rows as multiplying by a one-hot
        # matrix, but follows the codebook's dtype/device and avoids
        # materialising an (N, num_embeddings) tensor.
        quantized = self.embedding(encoding_indices).view_as(inputs)

        # Commitment term pulls the encoder output toward the codebook;
        # the codebook term pulls the chosen entries toward the encoder output.
        e_latent_loss = torch.mean((quantized.detach() - inputs) ** 2)
        q_latent_loss = torch.mean((quantized - inputs.detach()) ** 2)
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the quantized tensor,
        # backward gradient flows to ``inputs`` unchanged.
        quantized = inputs + (quantized - inputs).detach()
        return quantized, loss

# Define the VQ-VAE model
class VQVAE(nn.Module):
    """1-D convolutional VQ-VAE: encoder -> vector quantizer -> decoder.

    Args:
        input_dim: Number of input channels (the decoder emits the same).
        hidden_dim: Channel width of the intermediate conv layers.
        embedding_dim: Channel width of the latent / codebook vectors.
        num_embeddings: Number of codebook vectors.
    """

    # Two stride-2 convolutions shrink the time axis by this factor, and the
    # two stride-2 transposed convolutions restore it.
    _TIME_STRIDE = 4

    def __init__(self, input_dim, hidden_dim, embedding_dim, num_embeddings):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, embedding_dim, kernel_size=4, stride=2, padding=1),
        )
        self.quantizer = VectorQuantizer(embedding_dim, num_embeddings)
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(embedding_dim, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(hidden_dim, input_dim, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        """Reconstruct ``x`` through the quantized bottleneck.

        Args:
            x: Channels-first input, ``(batch, input_dim, frames)`` (an
                unbatched ``(input_dim, frames)`` tensor is also accepted).

        Returns:
            ``(x_recon, vq_loss)`` where ``x_recon`` has the same number of
            frames as ``x`` and ``vq_loss`` is the quantizer's loss.
        """
        num_frames = x.size(-1)

        # Right-pad the time axis to a multiple of the total stride so the
        # decoder output is at least as long as the input; otherwise the
        # reconstruction comes back shorter and cannot be compared with x.
        remainder = num_frames % self._TIME_STRIDE
        if remainder:
            x = nn.functional.pad(x, (0, self._TIME_STRIDE - remainder))

        z_e = self.encoder(x)

        # The quantizer treats the LAST axis as the embedding axis, while the
        # encoder emits channels-first features. Move channels last (and make
        # the result contiguous) so each time step's channel vector is what
        # gets matched against the codebook, then restore the layout.
        z_q, vq_loss = self.quantizer(z_e.transpose(-1, -2).contiguous())
        z_q = z_q.transpose(-1, -2)

        x_recon = self.decoder(z_q)
        # Drop the frames that correspond to the padding added above.
        return x_recon[..., :num_frames], vq_loss


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU, so the
# model and the tensors sent to it can live on the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The model is instantiated in the training-setup cell below, after
# input_dim, hidden_dim, embedding_dim and num_embeddings are defined;
# constructing it here raised NameError because those names do not exist yet.

**Step 4: Define the Training Loop**

In [None]:
# Hyper-parameters of the VQ-VAE and its optimiser.
input_dim = 80  # number of Mel-spectrogram bins
hidden_dim = 128
embedding_dim = 64
num_embeddings = 512
learning_rate = 0.001

# Build the network, place it on the selected device, then create the Adam
# optimiser over its parameters and the mean-squared-error criterion used
# for the reconstruction term.
model = VQVAE(
    input_dim,
    hidden_dim,
    embedding_dim,
    num_embeddings,
)
model = model.to(device)
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
)
criterion = nn.MSELoss()

# Training loop
def train(model, dataloader, epochs=10):
    """Train ``model`` on single samples converted with ``process_audio``.

    Uses the module-level ``optimizer`` and ``criterion``. The loss is the
    reconstruction criterion plus the model's vector-quantization loss.

    Args:
        model: Module returning ``(reconstruction, vq_loss)``.
        dataloader: Iterable of samples accepted by ``process_audio``.
            NOTE(review): each item is passed to ``process_audio`` and given
            one extra leading axis, so this presumably expects raw dataset
            samples rather than collated batches -- confirm with the caller.
        epochs: Number of passes over ``dataloader``.
    """
    model.train()
    # Send inputs to wherever the model lives instead of hard-coding CUDA,
    # which fails without a GPU or when the model sits on the CPU.
    model_device = next(model.parameters()).device
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader):
            optimizer.zero_grad()

            # Preprocess data and add a leading batch axis
            audio_features = process_audio(data).unsqueeze(0).to(model_device)

            # Forward pass
            recon, vq_loss = model(audio_features)
            loss = criterion(recon, audio_features) + vq_loss

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Report the mean loss of every block of 100 iterations
            running_loss += loss.item()
            if i % 100 == 99:
                print(f"Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss / 100}")
                running_loss = 0.0


In [None]:
# Training loop
def train(model, dataloader, epochs=10):
    """Train ``model`` on batches prepared by ``process_audio_with_padding``.

    Uses the module-level ``optimizer`` and ``criterion``. The loss is the
    reconstruction criterion plus the model's vector-quantization loss.

    Args:
        model: Module returning ``(reconstruction, vq_loss)``.
        dataloader: Iterable of items accepted by
            ``process_audio_with_padding``.
        epochs: Number of passes over ``dataloader``.
    """
    model.train()  # Set model to training mode
    # Send inputs to wherever the model lives instead of hard-coding CUDA,
    # which fails without a GPU or when the model sits on the CPU.
    model_device = next(model.parameters()).device
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader):
            optimizer.zero_grad()

            # Preprocess data and pad/trim as needed
            # NOTE(review): process_audio_with_padding is not defined in the
            # visible part of this notebook -- confirm it exists before use.
            audio_features = process_audio_with_padding(data).to(model_device)  # or process_audio_with_trimming

            # Forward pass
            recon, vq_loss = model(audio_features)
            loss = criterion(recon, audio_features) + vq_loss

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Print loss every 100 batches
            if i % 100 == 99:
                print(f"Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss / 100}")
                running_loss = 0.0

        # Release cached GPU memory held by the allocator after each epoch
        torch.cuda.empty_cache()


In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence

# Custom collate function to pad sequences in a batch
def custom_collate_fn(batch):
    """Collate raw samples into one zero-padded Mel-spectrogram batch.

    Args:
        batch: Iterable of dataset samples accepted by ``process_audio``.

    Returns:
        Tensor with the batch axis first and the (padded) frame axis last.
        NOTE(review): assumes ``process_audio`` returns a 2-D
        ``(n_mels, frames)`` tensor, giving ``(batch, n_mels, max_frames)``
        -- confirm for multi-channel audio.
    """
    # pad_sequence pads along the FIRST axis and requires all trailing axes
    # to match, but spectrograms differ in their LAST (frame) axis. Swap the
    # frame axis to the front before padding.
    audio_features = [process_audio(data).transpose(0, -1) for data in batch]

    # Pad with zeros so that all sequences have the same number of frames
    padded_audio_features = pad_sequence(audio_features, batch_first=True, padding_value=0)

    # Swap back so the frame axis is last again, as Conv1d-style
    # channels-first consumers expect.
    return padded_audio_features.transpose(1, -1)

# Rebuild both loaders so batches are assembled by custom_collate_fn
# (training shuffled, evaluation in order).
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=custom_collate_fn,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=custom_collate_fn,
)


**Step 5: Train the Model**

In [None]:
def train(model, dataloader, epochs=10):
    """Train ``model`` for ``epochs`` passes over ``dataloader``.

    Uses the module-level ``optimizer`` and ``criterion``. The loss is the
    reconstruction criterion plus the model's vector-quantization loss.

    Args:
        model: Module returning ``(reconstruction, vq_loss)``.
        dataloader: Iterable yielding either ready feature tensors (e.g. a
            DataLoader built with ``collate_fn=custom_collate_fn``) or lists
            of raw samples, which are collated here.
        epochs: Number of passes over ``dataloader``.
    """
    model.train()
    # Send inputs to wherever the model lives instead of hard-coding CUDA,
    # which fails without a GPU or when the model sits on the CPU.
    model_device = next(model.parameters()).device
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader):
            optimizer.zero_grad()

            # A loader that already uses custom_collate_fn yields padded
            # feature tensors; collating those a second time would feed
            # tensors to process_audio and fail. Only collate raw samples.
            if isinstance(data, torch.Tensor):
                audio_features = data
            else:
                audio_features = custom_collate_fn(data)
            audio_features = audio_features.to(model_device)

            # Forward pass
            recon, vq_loss = model(audio_features)
            loss = criterion(recon, audio_features) + vq_loss

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Report the mean loss of every block of 100 batches
            running_loss += loss.item()
            if i % 100 == 99:
                print(f"Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss / 100}")
                running_loss = 0.0


In [None]:
# Train the model for 10 epochs on the training loader. `train` resolves to
# whichever definition was executed most recently in the notebook.
train(model, train_loader, epochs=10)

ValueError: too many values to unpack (expected 5)

**Step 6: Test and Evaluate the Model**

In [None]:
# Testing function
def test(model, dataloader):
    """Print the mean loss of ``model`` over ``dataloader`` without training.

    Uses the module-level ``criterion``. The loss is the reconstruction
    criterion plus the model's vector-quantization loss.

    Args:
        model: Module returning ``(reconstruction, vq_loss)``.
        dataloader: Iterable yielding either ready feature tensors (e.g. a
            DataLoader built with a collate function that already produces
            features) or single raw samples accepted by ``process_audio``.
    """
    model.eval()
    # Send inputs to wherever the model lives instead of hard-coding CUDA,
    # which fails without a GPU or when the model sits on the CPU.
    model_device = next(model.parameters()).device
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for data in dataloader:
            # Already-collated batches are feature tensors and must not be
            # passed through process_audio again; raw samples still are.
            if isinstance(data, torch.Tensor):
                audio_features = data
            else:
                audio_features = process_audio(data).unsqueeze(0)
            audio_features = audio_features.to(model_device)

            recon, vq_loss = model(audio_features)
            loss = criterion(recon, audio_features) + vq_loss
            total_loss += loss.item()
            num_batches += 1

    # Count batches while iterating so an empty loader is reported instead
    # of raising ZeroDivisionError.
    if num_batches == 0:
        print("Test Loss: n/a (dataloader yielded no batches)")
        return

    print(f"Test Loss: {total_loss / num_batches}")

# Evaluate the model on the test loader; prints the mean test loss.
test(model, test_loader)
