In [1]:
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import numpy as np
import os
from pathlib import Path
import glob
from typing import List, Tuple
from tqdm import tqdm

In [2]:
class AudioDataset(Dataset):
    """Dataset for loading and preprocessing audio files"""
    
    def __init__(
        self, 
        audio_dir: str,
        sample_rate: int = 24000,
        segment_length: float = 4.0,  # seconds
        file_extensions: List[str] = ['.mp3', '.wav', '.flac']
    ):
        self.audio_dir = Path(audio_dir)
        self.sample_rate = sample_rate
        self.segment_samples = int(segment_length * sample_rate)
        
        # Find all audio files
        self.audio_files = []
        for ext in file_extensions:
            self.audio_files.extend(glob.glob(str(self.audio_dir / f"**/*{ext}"), recursive=True))
        
        print(f"Found {len(self.audio_files)} audio files")
        
    def __len__(self):
        return len(self.audio_files)
    
    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        
        try:
            # Load audio
            waveform, orig_sr = torchaudio.load(audio_path)
            
            # Resample if needed
            if orig_sr != self.sample_rate:
                resampler = torchaudio.transforms.Resample(orig_sr, self.sample_rate)
                waveform = resampler(waveform)
            
            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            
            # Random crop to fixed length
            if waveform.shape[1] > self.segment_samples:
                start_idx = torch.randint(0, waveform.shape[1] - self.segment_samples, (1,))
                waveform = waveform[:, start_idx:start_idx + self.segment_samples]
            elif waveform.shape[1] < self.segment_samples:
                # Pad if too short
                pad_length = self.segment_samples - waveform.shape[1]
                waveform = F.pad(waveform, (0, pad_length))
            
            # Normalize - FIXED: Handle edge case where max is 0
            max_val = torch.max(torch.abs(waveform))
            if max_val > 0:
                waveform = waveform / max_val
            
            return waveform
            
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            # Return a zero tensor if loading fails
            return torch.zeros(1, self.segment_samples)

In [3]:
class ConvEncoder(nn.Module):
    """
    Convolutional encoder that downsamples audio and extracts features
    Similar to DAC encoder structure
    """
    def __init__(self, dim: int = 64):
        super().__init__()
        # Progressive downsampling with increasing channels
        self.layers = nn.ModuleList([
            nn.Conv1d(1, dim, kernel_size=7, stride=1, padding=3),
            nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1),    # /2
            nn.Conv1d(dim, dim*2, kernel_size=3, stride=2, padding=1),  # /4  
            nn.Conv1d(dim*2, dim*4, kernel_size=3, stride=2, padding=1), # /8
            nn.Conv1d(dim*4, dim*8, kernel_size=3, stride=2, padding=1), # /16
        ])
        
        # FIXED: Use GroupNorm instead of BatchNorm for better stability
        self.norms = nn.ModuleList([
            nn.GroupNorm(4, dim),
            nn.GroupNorm(4, dim), 
            nn.GroupNorm(8, dim*2),
            nn.GroupNorm(16, dim*4),
            nn.GroupNorm(32, dim*8),
        ])
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer, norm in zip(self.layers, self.norms):
            x = layer(x)
            x = norm(x)
            x = F.leaky_relu(x, 0.2)
        return x

In [4]:
class VectorQuantizer(nn.Module):
    """
    Single layer Vector Quantizer
    Maps continuous features to discrete codebook entries
    """
    def __init__(self, codebook_size: int, codebook_dim: int, beta: float = 0.25):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.beta = beta
        
        # Initialize codebook with random vectors
        self.codebook = nn.Embedding(codebook_size, codebook_dim)
        self.codebook.weight.data.uniform_(-1/codebook_size, 1/codebook_size)
        
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size, dim, time = x.shape
        x_flat = x.permute(0, 2, 1).contiguous().view(-1, dim)
        
        # IMPROVED: More efficient distance calculation
        distances = torch.cdist(x_flat, self.codebook.weight)
        indices = torch.argmin(distances, dim=1)
        quantized_flat = self.codebook(indices)
        
        quantized = quantized_flat.view(batch_size, time, dim).permute(0, 2, 1)
        indices = indices.view(batch_size, time)
        
        # VQ losses
        codebook_loss = F.mse_loss(quantized_flat, x_flat.detach())
        commitment_loss = F.mse_loss(x_flat, quantized_flat.detach())
        vq_loss = codebook_loss + self.beta * commitment_loss
        
        # Straight-through estimator
        quantized = x + (quantized - x).detach()
        
        return quantized, indices, vq_loss

In [5]:
class ResidualVectorQuantizer(nn.Module):
    """
    Residual Vector Quantizer with multiple layers
    Each layer quantizes the residual from previous layers
    """
    def __init__(
        self, 
        input_dim: int = 512,  # FIXED: Make input dimension configurable
        n_layers: int = 12,
        codebook_size: int = 1024, 
        codebook_dim: int = 8,
        beta: float = 0.25
    ):
        super().__init__()
        self.n_layers = n_layers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        
        # Create multiple VQ layers
        self.quantizers = nn.ModuleList([
            VectorQuantizer(codebook_size, codebook_dim, beta) 
            for _ in range(n_layers)
        ])
        
        # Project encoder features to codebook dimension
        self.input_proj = nn.Conv1d(input_dim, codebook_dim, kernel_size=1)
        
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
        x = self.input_proj(x)
        quantized = torch.zeros_like(x)
        residual = x
        all_indices = []
        total_loss = 0.0
        
        for quantizer in self.quantizers:
            q, indices, loss = quantizer(residual)
            quantized += q
            residual -= q
            all_indices.append(indices)
            total_loss += loss
        
        return quantized, all_indices, total_loss
    
    def encode(self, x: torch.Tensor) -> List[torch.Tensor]:
        with torch.no_grad():
            _, indices, _ = self.forward(x)
        return indices
    
    def decode(self, indices_list: List[torch.Tensor]) -> torch.Tensor:
        quantized = None
        for i, indices in enumerate(indices_list):
            q = self.quantizers[i].codebook(indices)
            q = q.permute(0, 2, 1)
            quantized = q if quantized is None else quantized + q
        return quantized

In [6]:
class VocosDecoder(nn.Module):
    """
    Vocos-style decoder for efficient audio reconstruction
    Uses transposed convolutions to upsample back to original resolution
    """
    def __init__(self, input_dim: int = 8, dim: int = 64):
        super().__init__()
        
        # FIXED: Match input dimension from quantizer
        self.input_proj = nn.Conv1d(input_dim, dim*8, kernel_size=1)
        
        # Progressive upsampling with decreasing channels
        self.layers = nn.ModuleList([
            nn.ConvTranspose1d(dim*8, dim*4, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(dim*4, dim*2, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(dim*2, dim, kernel_size=4, stride=2, padding=1),
            nn.ConvTranspose1d(dim, dim, kernel_size=4, stride=2, padding=1),
            nn.Conv1d(dim, 1, kernel_size=7, stride=1, padding=3),
        ])
        
        # FIXED: Use GroupNorm for better stability
        self.norms = nn.ModuleList([
            nn.GroupNorm(16, dim*4),
            nn.GroupNorm(8, dim*2), 
            nn.GroupNorm(4, dim),
            nn.GroupNorm(4, dim),
        ])
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.input_proj(x)
        
        for i, layer in enumerate(self.layers[:-1]):
            x = layer(x)
            x = self.norms[i](x)
            x = F.leaky_relu(x, 0.2)
        
        x = self.layers[-1](x)
        return torch.tanh(x)

In [7]:
class AcousticCodec(nn.Module):
    """
    Acoustic Codec for converting audio waveform to discrete tokens and back
    Based on DAC (Descript Audio Codec) with Vocos decoder
    """
    def __init__(
        self,
        sample_rate: int = 24000,
        n_layers: int = 12,           # RVQ layers
        codebook_size: int = 1024,    # entries per codebook
        codebook_dim: int = 8,        # dimension of each entry
        encoder_dim: int = 64,        # base encoder dimension
        decoder_dim: int = 64,        # base decoder dimension
    ):
        super().__init__()
        
        self.sample_rate = sample_rate
        self.n_layers = n_layers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        
        # Main components
        self.encoder = ConvEncoder(encoder_dim)
        
        # FIXED: Pass correct input dimension to RVQ
        encoder_output_dim = encoder_dim * 8  # From ConvEncoder final layer
        self.quantizer = ResidualVectorQuantizer(
            input_dim=encoder_output_dim,
            n_layers=n_layers,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim
        )
        
        self.decoder = VocosDecoder(
            input_dim=codebook_dim,
            dim=decoder_dim
        )
    
    def encode(self, audio: torch.Tensor) -> List[torch.Tensor]:
        """
        Encode audio to discrete tokens
        
        Args:
            audio: (batch, 1, time) - Raw audio waveform
            
        Returns:
            tokens: List of (batch, time_compressed) - Discrete tokens for each layer
        """
        features = self.encoder(audio)
        tokens = self.quantizer.encode(features)
        return tokens
    
    def decode(self, tokens: List[torch.Tensor]) -> torch.Tensor:
        """
        Decode tokens back to audio
        
        Args:
            tokens: List of (batch, time_compressed) - Discrete tokens for each layer
            
        Returns:
            audio: (batch, 1, time) - Reconstructed audio
        """
        quantized = self.quantizer.decode(tokens)
        audio = self.decoder(quantized)
        return audio
    
    def forward(self, audio: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
        """
        Full encode-decode cycle
        
        Returns:
            reconstructed_audio: Reconstructed audio
            tokens: List of discrete tokens for each RVQ layer  
            vq_loss: Loss for training quantizer
        """
        # Encode to features
        features = self.encoder(audio)
        
        # Quantize features
        quantized, tokens, vq_loss = self.quantizer(features)
        
        # Decode to audio
        reconstructed = self.decoder(quantized)
        
        return reconstructed, tokens, vq_loss

In [8]:
# IMPROVED: Simpler training function
def train_codec(
    audio_dir: str,
    num_epochs: int = 10,
    batch_size: int = 8,
    learning_rate: float = 1e-4,
    sample_rate: int = 24000,
    segment_length: float = 4.0,
    save_path: str = "acoustic_codec.pth"
):
    """
    Simplified training function for the acoustic codec
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Create dataset and dataloader
    dataset = AudioDataset(
        audio_dir=audio_dir,
        sample_rate=sample_rate,
        segment_length=segment_length
    )
    
    if len(dataset) == 0:
        raise ValueError(f"No audio files found in {audio_dir}")
    
    dataloader = DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=2,  # Reduced for stability
        pin_memory=True if device.type == 'cuda' else False
    )
    
    # Create model
    codec = AcousticCodec(sample_rate=sample_rate).to(device)
    optimizer = optim.Adam(codec.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    
    print(f"Model parameters: {sum(p.numel() for p in codec.parameters()):,}")
    
    # Training loop
    codec.train()
    for epoch in range(num_epochs):
        total_loss = 0
        total_rec_loss = 0
        total_vq_loss = 0
        
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        
        for batch_idx, audio in enumerate(progress_bar):
            audio = audio.to(device)
            
            # Forward pass
            reconstructed_audio, tokens, vq_loss = codec(audio)
            
            # Losses
            reconstruction_loss = F.l1_loss(reconstructed_audio, audio)
            loss = reconstruction_loss + 0.1 * vq_loss  # Weight VQ loss lower
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            
            # ADDED: Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(codec.parameters(), 1.0)
            
            optimizer.step()
            
            # Track losses
            total_loss += loss.item()
            total_rec_loss += reconstruction_loss.item()
            total_vq_loss += vq_loss.item()
            
            # Update progress bar
            progress_bar.set_postfix({
                'Loss': f"{loss.item():.4f}",
                'Rec': f"{reconstruction_loss.item():.4f}",
                'VQ': f"{vq_loss.item():.4f}"
            })
        
        scheduler.step()
        
        # Print epoch summary
        avg_loss = total_loss / len(dataloader)
        avg_rec = total_rec_loss / len(dataloader)
        avg_vq = total_vq_loss / len(dataloader)
        
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Rec={avg_rec:.4f}, VQ={avg_vq:.4f}")
        
        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'model_state_dict': codec.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'loss': avg_loss
            }, f"codec_epoch_{epoch+1}.pth")
    
    # Save final model
    torch.save(codec.state_dict(), save_path)
    print(f"Training completed! Model saved to {save_path}")
    
    return codec

In [9]:
# ADDED: Testing function
def test_codec(model_path: str, test_audio_path: str, output_path: str = "reconstructed.wav"):
    """
    Test the trained codec on a single audio file
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load model
    codec = AcousticCodec().to(device)
    codec.load_state_dict(torch.load(model_path, map_location=device))
    codec.eval()
    
    # Load test audio
    audio, sr = torchaudio.load(test_audio_path)
    
    # Resample if needed
    if sr != codec.sample_rate:
        resampler = torchaudio.transforms.Resample(sr, codec.sample_rate)
        audio = resampler(audio)
    
    # Convert to mono and add batch dimension
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)
    audio = audio.unsqueeze(0).to(device)  # Add batch dimension
    
    # Test encoding/decoding
    with torch.no_grad():
        # Full reconstruction
        reconstructed, tokens, vq_loss = codec(audio)
        
        # Just encoding
        encoded_tokens = codec.encode(audio)
        
        # Just decoding  
        decoded_audio = codec.decode(encoded_tokens)
        
        print(f"Original shape: {audio.shape}")
        print(f"Reconstructed shape: {reconstructed.shape}")
        print(f"Number of token layers: {len(tokens)}")
        print(f"Token shapes: {[t.shape for t in tokens]}")
        print(f"VQ Loss: {vq_loss.item():.4f}")
    
    # Save reconstructed audio
    torchaudio.save(output_path, reconstructed.squeeze(0).cpu(), codec.sample_rate)
    print(f"Reconstructed audio saved to {output_path}")

In [10]:
# Example usage
if __name__ == "__main__":
     # For training:
    codec = train_codec(
        audio_dir="/kaggle/input/tts-dataset/FINAL_TTS_DATA/",
        num_epochs=10,
        batch_size=4,  # Start small
        learning_rate=1e-4
    )
    
    # For testing:
    test_codec(
        model_path="acoustic_codec.pth",
        test_audio_path="/kaggle/input/tts-dataset/FINAL_TTS_DATA/processed_00xcPM_229.mp3"
    )

Using device: cuda
Found 8000 audio files
Model parameters: 1,345,417


Epoch 1/10: 100%|██████████| 2000/2000 [06:34<00:00,  5.07it/s, Loss=0.0165, Rec=0.0121, VQ=0.0435]  


Epoch 1: Loss=1.2942, Rec=0.0208, VQ=12.7345


Epoch 2/10: 100%|██████████| 2000/2000 [06:34<00:00,  5.08it/s, Loss=0.0136, Rec=0.0092, VQ=0.0440]


Epoch 2: Loss=0.0136, Rec=0.0086, VQ=0.0503


Epoch 3/10: 100%|██████████| 2000/2000 [06:33<00:00,  5.08it/s, Loss=0.0091, Rec=0.0061, VQ=0.0295]


Epoch 3: Loss=0.0096, Rec=0.0069, VQ=0.0272


Epoch 4/10: 100%|██████████| 2000/2000 [06:34<00:00,  5.08it/s, Loss=0.0074, Rec=0.0061, VQ=0.0139]


Epoch 4: Loss=0.0083, Rec=0.0063, VQ=0.0199


Epoch 5/10: 100%|██████████| 2000/2000 [06:34<00:00,  5.07it/s, Loss=0.0062, Rec=0.0049, VQ=0.0136]


Epoch 5: Loss=0.0072, Rec=0.0057, VQ=0.0151


Epoch 6/10: 100%|██████████| 2000/2000 [06:33<00:00,  5.08it/s, Loss=0.0051, Rec=0.0041, VQ=0.0101]


Epoch 6: Loss=0.0071, Rec=0.0058, VQ=0.0128


Epoch 7/10: 100%|██████████| 2000/2000 [06:34<00:00,  5.08it/s, Loss=0.0071, Rec=0.0060, VQ=0.0108]


Epoch 7: Loss=0.0063, Rec=0.0052, VQ=0.0111


Epoch 8/10: 100%|██████████| 2000/2000 [06:33<00:00,  5.08it/s, Loss=0.0073, Rec=0.0065, VQ=0.0072]


Epoch 8: Loss=0.0061, Rec=0.0051, VQ=0.0095


Epoch 9/10: 100%|██████████| 2000/2000 [06:33<00:00,  5.08it/s, Loss=0.0060, Rec=0.0051, VQ=0.0094]


Epoch 9: Loss=0.0057, Rec=0.0048, VQ=0.0087


Epoch 10/10: 100%|██████████| 2000/2000 [06:34<00:00,  5.07it/s, Loss=0.0060, Rec=0.0049, VQ=0.0105]


Epoch 10: Loss=0.0055, Rec=0.0047, VQ=0.0081
Training completed! Model saved to acoustic_codec.pth
Original shape: torch.Size([1, 1, 115226])
Reconstructed shape: torch.Size([1, 1, 115232])
Number of token layers: 12
Token shapes: [torch.Size([1, 7202]), torch.Size([1, 7202]), torch.Size([1, 7202]), torch.Size([1, 7202]), torch.Size([1, 7202]), torch.Size([1, 7202]), torch.Size([1, 7202]), torch.Size([1, 7202]), torch.Size([1, 7202]), torch.Size([1, 7202]), torch.Size([1, 7202]), torch.Size([1, 7202])]
VQ Loss: 0.0062
Reconstructed audio saved to reconstructed.wav


In [13]:
    # For testing:
    test_codec(
        model_path="acoustic_codec.pth",
        test_audio_path="/kaggle/input/tts-dataset/FINAL_TTS_DATA/processed_01kOfp_182.mp3"
    )

Original shape: torch.Size([1, 1, 132456])
Reconstructed shape: torch.Size([1, 1, 132464])
Number of token layers: 12
Token shapes: [torch.Size([1, 8279]), torch.Size([1, 8279]), torch.Size([1, 8279]), torch.Size([1, 8279]), torch.Size([1, 8279]), torch.Size([1, 8279]), torch.Size([1, 8279]), torch.Size([1, 8279]), torch.Size([1, 8279]), torch.Size([1, 8279]), torch.Size([1, 8279]), torch.Size([1, 8279])]
VQ Loss: 0.0048
Reconstructed audio saved to reconstructed.wav
