# Reconnaissance Automatique de la Parole - Jour 1

In [1]:
!pip install Levenshtein optuna



In [2]:
import torchaudio
import torch
from torchaudio.datasets import LIBRISPEECH
import torch.nn.functional as F
from torchaudio import transforms
from torch.utils.data import DataLoader
import os
import torch.nn as nn
import Levenshtein
import math
import optuna
import torch.optim as optim
from torch.utils.data import random_split

In [3]:
# Model selection parameter - change this to 'MLP', 'CNN', 'GRU', or 'Transformer'
MODEL_TYPE = 'MLP'  # Options: 'MLP', 'CNN', 'GRU', 'Transformer'

In [4]:
def collate_fn(batch):
    """Custom collate function to handle variable-length audio"""
    waveforms = [item[0] for item in batch]
    sample_rates = [item[1] for item in batch]
    transcripts = [item[2] for item in batch]
    speaker_ids = [item[3] for item in batch]
    chapter_ids = [item[4] for item in batch]
    utterance_ids = [item[5] for item in batch]

    return waveforms, sample_rates, transcripts, speaker_ids, chapter_ids, utterance_ids

def manage_data(batch_size, num_workers):
    train_dataset = LIBRISPEECH("./data", url="dev-clean", download=True)
    print("Data Downloaded !")

    print(f"Dataset size: {len(train_dataset)} samples")
    print(f"First sample shape: {train_dataset[0][0].shape}")

    # Create DataLoader for batch processing
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return train_loader

def load_data():
    print("Downloading data")

    os.makedirs("data", exist_ok=True)
    # Use 'dev-clean' for testing, 'train-clean-100' for training
    train_loader = manage_data(batch_size=8, num_workers=2)

load_data()

Downloading data
Data Downloaded !
Dataset size: 2703 samples
First sample shape: torch.Size([1, 93680])


In [5]:
# -----------------------------
# DEVICE MANAGEMENT
# -----------------------------
def get_device():
    if torch.cuda.is_available():
        print("Using CUDA GPU")
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        print("Using Apple MPS")
        return torch.device("mps")
    else:
        print("Using CPU")
        return torch.device("cpu")

device = get_device()


Using CUDA GPU


# Models

In [6]:
def _MFCC(waveform, sample_rate):
    transform = transforms.MFCC(
        sample_rate=sample_rate,
        n_mfcc=13,
        melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23, "center": False}
    )
    mfcc = transform(waveform)
    return mfcc

In [7]:
class MLP(nn.Module):
    def __init__(self, input_size=13, hidden_size=256, output_size=29, num_layers=3):
        super(MLP, self).__init__()

        layers = []
        layers.append(nn.Linear(input_size, hidden_size))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(0.2))

        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))

        layers.append(nn.Linear(hidden_size, output_size))

        self.network = nn.Sequential(*layers)
        self.loss_fn = nn.CTCLoss(blank=28, zero_infinity=True)

    def forward(self, x):
        return self.network(x)

    def loss(self, log_probs, targets, input_lengths, target_lengths):
        return self.loss_fn(log_probs, targets, input_lengths, target_lengths)

In [8]:
class CNN(nn.Module):
    def __init__(self, n_classes=30, n_mels=40):  # 30 phonemes or characters
        super().__init__()

        # 1D convolutions along time axis
        # Input channels = n_mels (frequency bins treated as channels)
        self.conv1 = nn.Conv1d(n_mels, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)

        self.pool = nn.MaxPool1d(2)  # pool along time axis

        self.dropout = nn.Dropout(0.3)

        # final classifier
        self.classifier = nn.Linear(512, n_classes)

    def forward(self, x):
        # x: [B, 1, F, T] from spectrogram (mel_spec gives [1, n_mels, time])
        # Reshape to [B, F, T] for 1D conv along time
        if x.dim() == 4:
            # Input is [B, 1, F, T]
            x = x.squeeze(1)  # -> [B, F, T]
        elif x.dim() == 3:
            # Input is already [B, F, T] or needs transpose
            # Check if last dim is n_mels (needs transpose)
            if x.shape[-1] == self.conv1.in_channels:
                x = x.transpose(1, 2)  # [B, T, F] -> [B, F, T]

        x = F.relu(self.conv1(x))
        x = self.pool(x)  # -> [B, 128, T/2]

        x = F.relu(self.conv2(x))
        x = self.pool(x)  # -> [B, 256, T/4]

        x = F.relu(self.conv3(x))
        x = self.pool(x)  # -> [B, 512, T/8]

        # Transpose to [B, T', C] for time-distributed classification
        x = x.transpose(1, 2)  # -> [B, T', 512]

        x = self.dropout(x)
        out = self.classifier(x)  # [B, T', n_classes]

        return out


In [9]:
class GRU(nn.Module):
    def __init__(self, input_dim=40, hidden_dim=128, num_layers=2, n_classes=30):
        super().__init__()

        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True
        )

        self.classifier = nn.Linear(hidden_dim * 2, n_classes)

    def forward(self, x):
        # x: [B, T, F]
        out, _ = self.gru(x)       # out: [B, T, hidden*2]
        logits = self.classifier(out)
        return logits              # [B, T, n_classes]

In [10]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()

        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [B, T, d_model]
        x = x + self.pe[:, :x.size(1), :]
        return x

class Transformer(nn.Module):
    def __init__(self, input_dim=40, d_model=256, nhead=8, num_layers=6, dim_feedforward=1024, n_classes=29, dropout=0.1):
        """
        Transformer-based speech recognition model

        Args:
            input_dim: Input feature dimension (e.g., 40 for mel spectrograms)
            d_model: Dimension of the model
            nhead: Number of attention heads
            num_layers: Number of transformer encoder layers
            dim_feedforward: Dimension of feedforward network
            n_classes: Number of output classes
            dropout: Dropout rate
        """
        super().__init__()

        # Input projection
        self.input_projection = nn.Linear(input_dim, d_model)

        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output classifier
        self.classifier = nn.Linear(d_model, n_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [B, T, F] where F is input_dim (e.g., 40 mel bins)

        # Project to d_model
        x = self.input_projection(x)  # [B, T, d_model]

        # Add positional encoding
        x = self.pos_encoder(x)

        # Apply dropout
        x = self.dropout(x)

        # Transformer encoding
        x = self.transformer_encoder(x)  # [B, T, d_model]

        # Classification
        logits = self.classifier(x)  # [B, T, n_classes]

        return logits

# Train

In [11]:
import torch
import torchaudio
import torch.nn.functional as F

def train_model(batch_size, num_workers, num_epochs=10, save_path='model_checkpoint.pth', model_type='Transformer'):
    # -------------------------------
    # 1. GPU / MPS / CPU management
    # -------------------------------
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using device: {device}")

    train_loader = manage_data(batch_size, num_workers)

    # -------------------------------
    # 2. Initialize Transforms ONCE and move to Device
    # -------------------------------
    # We define these outside the loop to avoid re-creating them
    # and to ensure the internal window buffers are on the GPU.

    # For CNN / GRU / Transformer
    melspec_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, # Assuming 16k, adjust if your data differs
        n_mels=40,
        n_fft=400,
        hop_length=80
    ).to(device)

    # For MLP (Replacing your _MFCC function with the official class to be safe)
    mfcc_transform = torchaudio.transforms.MFCC(
        sample_rate=16000,
        n_mfcc=13,
        melkwargs={"n_fft": 400, "hop_length": 80, "n_mels": 40}
    ).to(device)

    # -------------------------------
    # 3. Initialize Model
    # -------------------------------
    if model_type == 'MLP':
        model = MLP(input_size=13, hidden_size=128, output_size=29)
        print("Using MLP")
    elif model_type == 'CNN':
        model = CNN(n_classes=29, n_mels=40)
        print("Using CNN")
    elif model_type == 'GRU':
        model = GRU(input_dim=40, hidden_dim=128, num_layers=2, n_classes=29)
        print("Using GRU")
    elif model_type == 'Transformer':
        model = Transformer(input_dim=40, d_model=256, nhead=8, num_layers=6, n_classes=29)
        print("Using Transformer")
    else:
        raise ValueError("Unknown model_type")

    model = model.to(device)
    model.train(True)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    ctc_loss = torch.nn.CTCLoss(blank=28, zero_infinity=True)
    char2idx = {c: i for i, c in enumerate("abcdefghijklmnopqrstuvwxyz '")}

    # -------------------------------
    # 4. Training Loop
    # -------------------------------
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (waveforms, sample_rates, transcripts, _, _, _) in enumerate(train_loader):

            batch_log_probs = []
            batch_targets = []
            batch_input_lengths = []
            batch_target_lengths = []

            for i, (waveform, sample_rate) in enumerate(zip(waveforms, sample_rates)):

                # Move waveform to device
                waveform = waveform.to(device)

                # -------------------------------
                # Feature extraction (Using pre-loaded transforms)
                # -------------------------------
                if model_type == 'MLP':
                    # Use the pre-initialized MFCC transform
                    features = mfcc_transform(waveform)
                    features = features.squeeze(0).transpose(0, 1)

                elif model_type in ['CNN', 'GRU', 'Transformer']:
                    # Use the pre-initialized MelSpec transform
                    features = melspec_transform(waveform)
                    features = features.clamp(min=1e-9).log2()

                    if model_type == 'CNN':
                        features = features.unsqueeze(0)
                    elif model_type in ['GRU', 'Transformer']:
                        # Expects [B, T, F] -> transpose
                        features = features.squeeze(0).transpose(0, 1).unsqueeze(0)

                # Target processing
                target = torch.tensor(
                    [char2idx[c] for c in transcripts[i].lower() if c in char2idx],
                    dtype=torch.long,
                    device=device
                )

                # Calculate input length for CTC
                if model_type == 'MLP':
                    input_len = features.size(0)
                elif model_type == 'CNN':
                    input_len = features.size(-1) // 8 # Adjust based on your CNN strides
                elif model_type in ['GRU', 'Transformer']:
                    input_len = features.size(1)

                # Skip if target is longer than input (CTC requirement)
                if input_len <= len(target):
                    continue

                # Forward Pass
                logits = model(features)

                # Squeeze batch dim if necessary (logic from your original code)
                if model_type in ['CNN', 'GRU', 'Transformer']:
                    logits = logits.squeeze(0)

                log_probs = F.log_softmax(logits, dim=1)

                batch_log_probs.append(log_probs)
                batch_targets.append(target)
                batch_input_lengths.append(log_probs.size(0))
                batch_target_lengths.append(len(target))

            if len(batch_log_probs) == 0:
                continue

            # -------------------------------
            # Pad sequences and Compute Loss
            # -------------------------------
            max_input_len = max(batch_input_lengths)
            padded_log_probs = torch.zeros(
                max_input_len, len(batch_log_probs), 29,
                device=device
            )

            for i, log_probs in enumerate(batch_log_probs):
                padded_log_probs[:log_probs.size(0), i, :] = log_probs

            concatenated_targets = torch.cat(batch_targets)
            input_lengths = torch.tensor(batch_input_lengths, dtype=torch.long, device=device)
            target_lengths = torch.tensor(batch_target_lengths, dtype=torch.long, device=device)

            optimizer.zero_grad()
            loss = ctc_loss(padded_log_probs, concatenated_targets, input_lengths, target_lengths)

            if torch.isnan(loss):
                print(f"NaN loss detected at batch {batch_idx}")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx+1}, Avg Loss: {total_loss/num_batches:.4f}")

        print(f"\nEpoch {epoch+1} completed. Avg loss = {total_loss/num_batches:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    return model

# Evaluation

In [None]:
def decode_predictions(log_probs, idx2char, blank_idx=28):
    predictions = torch.argmax(log_probs, dim=1)

    decoded = []
    prev_char = None
    for pred in predictions:
        pred_idx = pred.item()
        if pred_idx != blank_idx and pred_idx != prev_char:
            decoded.append(idx2char[pred_idx])
        prev_char = pred_idx
    return ''.join(decoded)

def calculate_wer(reference, hypothesis):
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if len(ref_words) == 0:
        return 0.0 if len(hyp_words) == 0 else 1.0
    distance = Levenshtein.distance(' '.join(ref_words), ' '.join(hyp_words))
    return distance / len(ref_words)

def calculate_cer(reference, hypothesis):
    if len(reference) == 0:
        return 0.0 if len(hypothesis) == 0 else 1.0
    distance = Levenshtein.distance(reference, hypothesis)
    return distance / len(reference)


# ============================================================
#   E V A L U A T I O N    W I T H    G P U / M P S / C P U
# ============================================================

def evaluate_model(model, dataloader, model_type='MLP', device=None, num_samples=None):
    # -----------------------------
    # Device selection
    # -----------------------------
    if device is None:
        device = torch.device(
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
    print(f"\nEvaluating on device: {device}")

    model = model.to(device)
    model.eval()

    chars = "abcdefghijklmnopqrstuvwxyz '"
    char2idx = {c: i for i, c in enumerate(chars)}
    idx2char = {i: c for c, i in char2idx.items()}

    # Check sample rate - LibriSpeech is 16000Hz
    sample_rate = 16000

    # Transform for MLP
    mfcc_transform = torchaudio.transforms.MFCC(
        sample_rate=sample_rate,
        n_mfcc=13,
        melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23, "center": False}
    ).to(device)

    # Transform for CNN / GRU / Transformer
    melspec_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_mels=40,
        n_fft=400,
        hop_length=160
    ).to(device)

    total_wer = 0.0
    total_cer = 0.0
    num_processed = 0

    print("\n" + "="*80)
    print("EVALUATION RESULTS")
    print("="*80)

    with torch.no_grad():
        for batch_idx, (waveforms, sample_rates, transcripts, _, _, _) in enumerate(dataloader):

            for i, (waveform, sr, transcript) in enumerate(zip(waveforms, sample_rates, transcripts)):

                # Move waveform to device
                waveform = waveform.to(device)

                # ---------------------------------------------------
                # Feature extraction (Using pre-loaded transforms)
                # ---------------------------------------------------
                if model_type == 'MLP':
                    # Use the pre-initialized MFCC transform
                    features = mfcc_transform(waveform)
                    features = features.squeeze(0).transpose(0, 1)

                elif model_type in ['CNN', 'GRU', 'Transformer']:
                    # Use the pre-initialized MelSpec transform
                    features = melspec_transform(waveform)
                    features = features.clamp(min=1e-9).log2()

                    if model_type == 'CNN':
                        features = features.unsqueeze(0)
                    elif model_type in ['GRU', 'Transformer']:
                        # Expects [B, T, F] -> transpose
                        features = features.squeeze(0).transpose(0, 1).unsqueeze(0)

                else:
                    raise ValueError(f"Unknown model_type: {model_type}")

                # Forward → logits
                logits = model(features)

                # Remove batch dim if needed
                if model_type in ['CNN', 'GRU', 'Transformer']:
                    logits = logits.squeeze(0)

                log_probs = F.log_softmax(logits, dim=1)

                # Decode text
                predicted_text = decode_predictions(log_probs, idx2char)
                reference_text = transcript.lower()

                # Metrics
                wer = calculate_wer(reference_text, predicted_text)
                cer = calculate_cer(reference_text, predicted_text)

                total_wer += wer
                total_cer += cer
                num_processed += 1

                # Print first few samples
                if num_processed <= 5:
                    print(f"\nSample {num_processed}:")
                    print(f"  Reference:  {reference_text}")
                    print(f"  Predicted:  {predicted_text}")
                    print(f"  WER: {wer:.2%}, CER: {cer:.2%}")

                if num_samples and num_processed >= num_samples:
                    break

            if num_samples and num_processed >= num_samples:
                break

    avg_wer = total_wer / num_processed if num_processed > 0 else 0
    avg_cer = total_cer / num_processed if num_processed > 0 else 0

    print("\n" + "="*80)
    print(f"AVERAGE METRICS (over {num_processed} samples)")
    print(f"  Average WER: {avg_wer:.2%}")
    print(f"  Average CER: {avg_cer:.2%}")
    print("="*80 + "\n")

    return avg_wer, avg_cer


# ============================================================
#     L O A D   &   E V A L U A T E
# ============================================================

def load_and_evaluate(model_path=None, model_type='MLP', batch_size=8, num_workers=2, num_samples=50):

    # -----------------------------
    # Device selection
    # -----------------------------
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"\nUsing device: {device}")

    # Dataset
    eval_dataset = LIBRISPEECH("./data", url="test-clean", download=True)
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn
    )

    # Model init
    print(f"Initializing {model_type} model...")
    if model_type == 'MLP':
        model = MLP(input_size=13, hidden_size=128, output_size=29)
    elif model_type == 'CNN':
        model = CNN(n_classes=29, n_mels=40)
    elif model_type == 'GRU':
        model = GRU(input_dim=40, hidden_dim=128, num_layers=2, n_classes=29)
    elif model_type == 'Transformer':
        model = Transformer(input_dim=40, d_model=256, nhead=8, num_layers=6, n_classes=29)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Load weights
    if model_path:
        print(f"Loading model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        print("Warning: No model_path provided → evaluating random weights")

    # Evaluate
    return evaluate_model(
        model,
        eval_loader,
        model_type=model_type,
        device=device,
        num_samples=num_samples
    )

In [13]:
# MLP
train_model(batch_size=8, num_workers=2, save_path=f'{MODEL_TYPE}.pth', model_type=f'{MODEL_TYPE}')

Using device: cuda
Data Downloaded !
Dataset size: 2703 samples
First sample shape: torch.Size([1, 93680])
Using MLP
Epoch 1, Batch 10, Avg Loss: 38.0441
Epoch 1, Batch 20, Avg Loss: 31.3048
Epoch 1, Batch 30, Avg Loss: 28.6114
Epoch 1, Batch 40, Avg Loss: 26.7469
Epoch 1, Batch 50, Avg Loss: 24.7071
Epoch 1, Batch 60, Avg Loss: 23.2611
Epoch 1, Batch 70, Avg Loss: 22.3891
Epoch 1, Batch 80, Avg Loss: 21.6588
Epoch 1, Batch 90, Avg Loss: 20.9450
Epoch 1, Batch 100, Avg Loss: 20.4959
Epoch 1, Batch 110, Avg Loss: 19.8567
Epoch 1, Batch 120, Avg Loss: 19.6249
Epoch 1, Batch 130, Avg Loss: 19.2286
Epoch 1, Batch 140, Avg Loss: 18.7564
Epoch 1, Batch 150, Avg Loss: 18.2734
Epoch 1, Batch 160, Avg Loss: 17.8654
Epoch 1, Batch 170, Avg Loss: 17.5288
Epoch 1, Batch 180, Avg Loss: 17.2822
Epoch 1, Batch 190, Avg Loss: 16.9694
Epoch 1, Batch 200, Avg Loss: 16.7779
Epoch 1, Batch 210, Avg Loss: 16.4548
Epoch 1, Batch 220, Avg Loss: 16.2362
Epoch 1, Batch 230, Avg Loss: 15.9949
Epoch 1, Batch 240

MLP(
  (network): Sequential(
    (0): Linear(in_features=13, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=128, out_features=29, bias=True)
  )
  (loss_fn): CTCLoss()
)

In [14]:
load_and_evaluate(
    model_path=f'{MODEL_TYPE}.pth',
    model_type=f'{MODEL_TYPE}',
    num_samples=20
)


Using device: cuda
Initializing MLP model...
Loading model from MLP.pth

Evaluating on device: cuda

EVALUATION RESULTS

Sample 1:
  Reference:  he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick peppered flour fattened sauce
  Predicted:  l
  WER: 560.71%, CER: 99.37%

Sample 2:
  Reference:  stuff it into you his belly counselled him
  Predicted:  l
  WER: 512.50%, CER: 97.62%

Sample 3:
  Reference:  after early nightfall the yellow lamps would light up here and there the squalid quarter of the brothels
  Predicted:  l
  WER: 572.22%, CER: 99.04%

Sample 4:
  Reference:  hello bertie any good in your mind
  Predicted:  l
  WER: 471.43%, CER: 97.06%

Sample 5:
  Reference:  number ten fresh nelly is waiting on you good night husband
  Predicted:  l
  WER: 527.27%, CER: 98.31%

AVERAGE METRICS (over 20 samples)
  Average WER: 528.66%
  Average CER: 98.72%



(5.286643914802452, 0.9872262922764088)

# Part 5 : Optuna Optimization

In [15]:
def collate_fn(batch):
    """Custom collate function to handle variable-length audio"""
    waveforms = [item[0] for item in batch]
    sample_rates = [item[1] for item in batch]
    transcripts = [item[2] for item in batch]

    # We only need these 3 for training/tuning
    return waveforms, sample_rates, transcripts

def get_data_loaders(batch_size=8, num_workers=2):
    """Downloads (if needed) and splits the dataset"""
    # Using dev-clean for tuning speed.
    os.makedirs("./data", exist_ok=True)
    dataset = LIBRISPEECH("./data", url="dev-clean", download=True)

    # Split: 80% Train, 20% Validation
    # We use a smaller subset for hyperparameter tuning to save time
    # (e.g., using only 50% of the data to find good params quickly)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_set, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, collate_fn=collate_fn
    )
    return train_loader, val_loader


In [16]:
class TunableGRU(nn.Module):
    def __init__(self, input_dim=40, hidden_dim=128, num_layers=2, dropout=0.1, n_classes=29):
        super().__init__()

        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0
        )

        # Hidden dim * 2 because of bidirectional
        self.classifier = nn.Linear(hidden_dim * 2, n_classes)

    def forward(self, x):
        # x: [Batch, Time, Features]
        out, _ = self.gru(x)
        logits = self.classifier(out)
        return logits

In [None]:
def objective(trial):
    # --- A. Device Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- B. Hyperparameter Suggestion ---
    # Model Params
    hidden_dim = trial.suggest_categorical("hidden_dim", [128, 256, 512])
    num_layers = trial.suggest_int("num_layers", 1, 3)
    dropout = trial.suggest_float("dropout", 0.1, 0.5)

    # Training Params
    lr = trial.suggest_float("lr", 1e-4, 5e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [8, 16])

    # --- C. Data & Transform Setup ---
    train_loader, val_loader = get_data_loaders(batch_size=batch_size)

    melspec_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000,
        n_mels=40,
        n_fft=400,
        hop_length=160
    ).to(device)

    char2idx = {c: i for i, c in enumerate("abcdefghijklmnopqrstuvwxyz '")}

    # --- D. Model & Optimizer ---
    model = TunableGRU(
        input_dim=40,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        dropout=dropout,
        n_classes=29
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    ctc_loss = nn.CTCLoss(blank=28, zero_infinity=True)

    # --- E. Training Loop (Shortened for Tuning) ---
    # We train for fewer epochs (e.g., 5) just to check convergence speed
    num_epochs = 5

    for epoch in range(num_epochs):
        model.train()

        for batch_idx, (waveforms, _, transcripts) in enumerate(train_loader):
            # 1. Feature Extraction
            specs = []
            targets = []
            target_lengths = []

            for i, waveform in enumerate(waveforms):
                # Move to GPU *before* transform to avoid device mismatch
                waveform = waveform.to(device)

                # Transform
                spec = melspec_transform(waveform) # [1, n_mels, time]
                spec = spec.squeeze(0).transpose(0, 1) # [time, n_mels]

                # STABILIZATION: Log & Normalize
                spec = spec.clamp(min=1e-9).log2()
                spec = (spec - spec.mean()) / (spec.std() + 1e-5)

                specs.append(spec)

                # Targets
                t_idx = [char2idx[c] for c in transcripts[i].lower() if c in char2idx]
                targets.extend(t_idx)
                target_lengths.append(len(t_idx))

            if not specs: continue

            # Pad sequences
            specs_padded = nn.utils.rnn.pad_sequence(specs, batch_first=True).to(device)
            input_lengths = torch.tensor([s.size(0) for s in specs], dtype=torch.long, device=device)

            targets_tensor = torch.tensor(targets, dtype=torch.long, device=device)
            target_lengths_tensor = torch.tensor(target_lengths, dtype=torch.long, device=device)

            # 2. Forward Pass
            optimizer.zero_grad()
            logits = model(specs_padded) # [B, T, Classes]

            # CTC expects [T, B, Classes] + LogSoftmax
            log_probs = F.log_softmax(logits, dim=2).transpose(0, 1)

            # 3. Loss & Backprop
            loss = ctc_loss(log_probs, targets_tensor, input_lengths, target_lengths_tensor)

            if torch.isnan(loss):
                # Prune trials that explode immediately
                raise optuna.TrialPruned()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Clip gradients
            optimizer.step()

        # --- F. Validation Step (End of Epoch) ---
        model.eval()
        val_loss_accum = 0.0
        val_batches = 0

        with torch.no_grad():
            for waveforms, _, transcripts in val_loader:
                specs = []
                targets = []
                target_lengths = []

                for i, waveform in enumerate(waveforms):
                    waveform = waveform.to(device)
                    spec = melspec_transform(waveform).squeeze(0).transpose(0, 1)
                    spec = spec.clamp(min=1e-9).log2()
                    spec = (spec - spec.mean()) / (spec.std() + 1e-5)
                    specs.append(spec)

                    t_idx = [char2idx[c] for c in transcripts[i].lower() if c in char2idx]
                    targets.extend(t_idx)
                    target_lengths.append(len(t_idx))

                if not specs: continue

                specs_padded = nn.utils.rnn.pad_sequence(specs, batch_first=True).to(device)
                input_lengths = torch.tensor([s.size(0) for s in specs], dtype=torch.long, device=device)
                targets_tensor = torch.tensor(targets, dtype=torch.long, device=device)
                target_lengths_tensor = torch.tensor(target_lengths, dtype=torch.long, device=device)

                logits = model(specs_padded)
                log_probs = F.log_softmax(logits, dim=2).transpose(0, 1)

                loss = ctc_loss(log_probs, targets_tensor, input_lengths, target_lengths_tensor)
                if not torch.isnan(loss) and not torch.isinf(loss):
                    val_loss_accum += loss.item()
                    val_batches += 1

        avg_val_loss = val_loss_accum / val_batches if val_batches > 0 else float('inf')

        # --- G. Reporting & Pruning ---
        trial.report(avg_val_loss, epoch)

        if trial.should_prune():
            print(f"Trial {trial.number} pruned at epoch {epoch} with loss {avg_val_loss:.4f}")
            raise optuna.TrialPruned()

    return avg_val_loss

In [None]:
if __name__ == "__main__":
    print("Starting Optuna Study...")

    # TPESampler is efficient at finding good regions
    study = optuna.create_study(direction="minimize", sampler=optuna.samplers.TPESampler())

    # Run 10-20 trials
    study.optimize(objective, n_trials=10)

    print("\n" + "="*50)
    print("HYPERPARAMETER TUNING COMPLETE")
    print("="*50)
    print(f"Best Validation Loss: {study.best_value:.4f}")
    print("Best Hyperparameters:")
    for key, value in study.best_params.items():
        print(f"  {key}: {value}")

[I 2025-11-24 10:21:26,447] A new study created in memory with name: no-name-b6db3b8c-73cc-49a5-8e54-a21d57285cae


Starting Optuna Study...


[I 2025-11-24 10:24:08,331] Trial 0 finished with value: 1.7167845456039204 and parameters: {'hidden_dim': 256, 'num_layers': 1, 'dropout': 0.3406807446255472, 'lr': 0.0006922706622600342, 'batch_size': 8}. Best is trial 0 with value: 1.7167845456039204.
[I 2025-11-24 10:29:13,853] Trial 1 finished with value: 1.2736535282695995 and parameters: {'hidden_dim': 256, 'num_layers': 3, 'dropout': 0.4406820799990565, 'lr': 0.0011901830103776144, 'batch_size': 16}. Best is trial 1 with value: 1.2736535282695995.
