<a href="https://colab.research.google.com/github/milik0/speech-recognition/blob/main/RECOP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Reconnaissance Automatique de la Parole - Jour 1

In [40]:
import torchaudio
import torch
from torchaudio.datasets import LIBRISPEECH
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [41]:
!pip install Levenshtein



In [42]:
import os

def manage_data(batch_size, num_workers):
    train_dataset = LIBRISPEECH("./data", url="dev-clean", download=True)
    print("Data Downloaded !")

    print(f"Dataset size: {len(train_dataset)} samples")
    print(f"First sample shape: {train_dataset[0][0].shape}")

    # Create DataLoader for batch processing
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return train_loader

def load_data():
    print("Downloading data")

    os.makedirs("data", exist_ok=True)
    # Use 'dev-clean' for testing, 'train-clean-100' for training
    train_loader = manage_data(batch_size=8, num_workers=2)

load_data()

Downloading data
Data Downloaded !
Dataset size: 2703 samples
First sample shape: torch.Size([1, 93680])


In [52]:
import torch
import torch.nn.functional as F

# -----------------------------
# DEVICE MANAGEMENT
# -----------------------------
def get_device():
    if torch.cuda.is_available():
        print("Using CUDA GPU")
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        print("Using Apple MPS")
        return torch.device("mps")
    else:
        print("Using CPU")
        return torch.device("cpu")

device = get_device()


Using CUDA GPU


In [53]:
import torchaudio
from torchaudio import transforms

def _MFCC(waveform, sample_rate):
    transform = transforms.MFCC(
        sample_rate=sample_rate,
        n_mfcc=13,
        melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23, "center": False}
    )
    mfcc = transform(waveform)
    return mfcc

# Models

In [54]:
import torch.nn as nn
import torchaudio

class MLP(nn.Module):
    def __init__(self, input_size=13, hidden_size=256, output_size=29, num_layers=3):
        super(MLP, self).__init__()

        layers = []
        layers.append(nn.Linear(input_size, hidden_size))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(0.2))

        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))

        layers.append(nn.Linear(hidden_size, output_size))

        self.network = nn.Sequential(*layers)
        self.loss_fn = nn.CTCLoss(blank=28, zero_infinity=True)

    def forward(self, x):
        return self.network(x)

    def loss(self, log_probs, targets, input_lengths, target_lengths):
        return self.loss_fn(log_probs, targets, input_lengths, target_lengths)

In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self, n_classes=30, n_mels=40):  # 30 phonemes or characters
        super().__init__()

        # 1D convolutions along time axis
        # Input channels = n_mels (frequency bins treated as channels)
        self.conv1 = nn.Conv1d(n_mels, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)

        self.pool = nn.MaxPool1d(2)  # pool along time axis

        self.dropout = nn.Dropout(0.3)

        # final classifier
        self.classifier = nn.Linear(512, n_classes)

    def forward(self, x):
        # x: [B, 1, F, T] from spectrogram (mel_spec gives [1, n_mels, time])
        # Reshape to [B, F, T] for 1D conv along time
        if x.dim() == 4:
            # Input is [B, 1, F, T]
            x = x.squeeze(1)  # -> [B, F, T]
        elif x.dim() == 3:
            # Input is already [B, F, T] or needs transpose
            # Check if last dim is n_mels (needs transpose)
            if x.shape[-1] == self.conv1.in_channels:
                x = x.transpose(1, 2)  # [B, T, F] -> [B, F, T]

        x = F.relu(self.conv1(x))
        x = self.pool(x)  # -> [B, 128, T/2]

        x = F.relu(self.conv2(x))
        x = self.pool(x)  # -> [B, 256, T/4]

        x = F.relu(self.conv3(x))
        x = self.pool(x)  # -> [B, 512, T/8]

        # Transpose to [B, T', C] for time-distributed classification
        x = x.transpose(1, 2)  # -> [B, T', 512]

        x = self.dropout(x)
        out = self.classifier(x)  # [B, T', n_classes]

        return out


In [56]:
import torch.nn as nn

class GRU(nn.Module):
    def __init__(self, input_dim=40, hidden_dim=128, num_layers=2, n_classes=30):
        super().__init__()

        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True
        )

        self.classifier = nn.Linear(hidden_dim * 2, n_classes)

    def forward(self, x):
        # x: [B, T, F]
        out, _ = self.gru(x)       # out: [B, T, hidden*2]
        logits = self.classifier(out)
        return logits              # [B, T, n_classes]


In [65]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [B, T, d_model]
        x = x + self.pe[:, :x.size(1), :]
        return x

class Transformer(nn.Module):
    def __init__(self, input_dim=40, d_model=256, nhead=8, num_layers=6, dim_feedforward=1024, n_classes=29, dropout=0.1):
        """
        Transformer-based speech recognition model

        Args:
            input_dim: Input feature dimension (e.g., 40 for mel spectrograms)
            d_model: Dimension of the model
            nhead: Number of attention heads
            num_layers: Number of transformer encoder layers
            dim_feedforward: Dimension of feedforward network
            n_classes: Number of output classes
            dropout: Dropout rate
        """
        super().__init__()

        # Input projection
        self.input_projection = nn.Linear(input_dim, d_model)

        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output classifier
        self.classifier = nn.Linear(d_model, n_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [B, T, F] where F is input_dim (e.g., 40 mel bins)

        # Project to d_model
        x = self.input_projection(x)  # [B, T, d_model]

        # Add positional encoding
        x = self.pos_encoder(x)

        # Apply dropout
        x = self.dropout(x)

        # Transformer encoding
        x = self.transformer_encoder(x)  # [B, T, d_model]

        # Classification
        logits = self.classifier(x)  # [B, T, n_classes]

        return logits

# Train

In [66]:
import torchaudio
import torch
from torchaudio.datasets import LIBRISPEECH
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Model selection parameter - change this to 'MLP', 'CNN', 'GRU', or 'Transformer'
MODEL_TYPE = 'Transformer'  # Options: 'MLP', 'CNN', 'GRU', 'Transformer'

def collate_fn(batch):
    """Custom collate function to handle variable-length audio"""
    waveforms = [item[0] for item in batch]
    sample_rates = [item[1] for item in batch]
    transcripts = [item[2] for item in batch]
    speaker_ids = [item[3] for item in batch]
    chapter_ids = [item[4] for item in batch]
    utterance_ids = [item[5] for item in batch]

    return waveforms, sample_rates, transcripts, speaker_ids, chapter_ids, utterance_ids

def manage_data(batch_size, num_workers):
    train_dataset = LIBRISPEECH("./data", url="train-clean-100", download=True)
    print("Data Downloaded !")

    print(f"Dataset size: {len(train_dataset)} samples")
    print(f"First sample shape: {train_dataset[0][0].shape}")

    # Create DataLoader for batch processing
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return train_loader

def train_model(batch_size, num_workers, num_epochs=2, save_path='model_checkpoint.pth', model_type=MODEL_TYPE):

    # -------------------------------
    # GPU / MPS / CPU management
    # -------------------------------
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using device: {device}")

    train_loader = manage_data(batch_size, num_workers)

    # Initialize model based on type
    if model_type == 'MLP':
        model = MLP(input_size=13, hidden_size=128, output_size=29)
        print("Using MLP")
    elif model_type == 'CNN':
        model = CNN(n_classes=29, n_mels=40)
        print("Using CNN")
    elif model_type == 'GRU':
        model = GRU(input_dim=40, hidden_dim=128, num_layers=2, n_classes=29)
        print("Using GRU")
    elif model_type == 'Transformer':
        model = Transformer(input_dim=40, d_model=256, nhead=8, num_layers=6, n_classes=29)
        print("Using Transformer")
    else:
        raise ValueError("Unknown model_type")

    # Move model to GPU/MPS/CPU
    model = model.to(device)
    model.train(True)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    ctc_loss = torch.nn.CTCLoss(blank=28, zero_infinity=True)

    char2idx = {c: i for i, c in enumerate("abcdefghijklmnopqrstuvwxyz '")}

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for batch_idx, (waveforms, sample_rates, transcripts, _, _, _) in enumerate(train_loader):

            batch_log_probs = []
            batch_targets = []
            batch_input_lengths = []
            batch_target_lengths = []

            for i, (waveform, sample_rate) in enumerate(zip(waveforms, sample_rates)):

                # Move waveform to device
                waveform = waveform.to(device)

                # -------------------------------
                # Feature extraction (on device)
                # -------------------------------
                if model_type == 'MLP':
                    features = _MFCC(waveform, sample_rate)  # should return on same device
                    features = features.squeeze(0).transpose(0, 1)

                elif model_type == 'CNN':
                    mel_spec_transform = torchaudio.transforms.MelSpectrogram(
                        sample_rate=sample_rate,
                        n_mels=40,
                        n_fft=400,
                        hop_length=80
                    ).to(device)

                    features = mel_spec_transform(waveform)
                    features = features.clamp(min=1e-9).log2()
                    features = features.unsqueeze(0)

                elif model_type == 'GRU':
                    mel_spec_transform = torchaudio.transforms.MelSpectrogram(
                        sample_rate=sample_rate,
                        n_mels=40,
                        n_fft=400,
                        hop_length=80
                    ).to(device)

                    features = mel_spec_transform(waveform)
                    features = features.clamp(min=1e-9).log2()
                    # GRU expects [B, T, F] -> transpose
                    features = features.squeeze(0).transpose(0, 1).unsqueeze(0)  # [1, T, 40]

                elif model_type == 'Transformer':
                    mel_spec_transform = torchaudio.transforms.MelSpectrogram(
                        sample_rate=sample_rate,
                        n_mels=40,
                        n_fft=400,
                        hop_length=80
                    ).to(device)

                    features = mel_spec_transform(waveform)
                    features = features.clamp(min=1e-9).log2()
                    # Transformer expects [B, T, F]
                    features = features.squeeze(0).transpose(0, 1).unsqueeze(0)  # [1, T, 40]

                # Target on device
                target = torch.tensor(
                    [char2idx[c] for c in transcripts[i].lower() if c in char2idx],
                    dtype=torch.long,
                    device=device
                )

                # Length checks
                if model_type == 'MLP':
                    input_len = features.size(0)
                elif model_type == 'CNN':
                    input_len = features.size(-1) // 8
                elif model_type == 'GRU':
                    input_len = features.size(1)  # Time dimension for GRU
                elif model_type == 'Transformer':
                    input_len = features.size(1)  # Time dimension for Transformer

                if input_len <= len(target):
                    print(f"Skipping sample {i}: input_len={input_len}, target_len={len(target)}")
                    continue

                # Forward
                logits = model(features)
                if model_type == 'CNN':
                    logits = logits.squeeze(0)
                elif model_type == 'GRU':
                    logits = logits.squeeze(0)
                elif model_type == 'Transformer':
                    logits = logits.squeeze(0)

                log_probs = F.log_softmax(logits, dim=1)

                batch_log_probs.append(log_probs)
                batch_targets.append(target)
                batch_input_lengths.append(log_probs.size(0))
                batch_target_lengths.append(len(target))

            if len(batch_log_probs) == 0:
                continue

            # -------------------------------
            # Pad sequences (on device)
            # -------------------------------
            max_input_len = max(batch_input_lengths)
            padded_log_probs = torch.zeros(
                max_input_len, len(batch_log_probs), 29,
                device=device
            )

            for i, log_probs in enumerate(batch_log_probs):
                padded_log_probs[:log_probs.size(0), i, :] = log_probs

            concatenated_targets = torch.cat(batch_targets)
            input_lengths = torch.tensor(batch_input_lengths, dtype=torch.long, device=device)
            target_lengths = torch.tensor(batch_target_lengths, dtype=torch.long, device=device)

            optimizer.zero_grad()
            loss = ctc_loss(padded_log_probs, concatenated_targets, input_lengths, target_lengths)

            if torch.isnan(loss):
                print(f"NaN loss detected at batch {batch_idx}")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx+1}, Avg Loss: {total_loss/num_batches:.4f}")

        print(f"\nEpoch {epoch+1} completed. Avg loss = {total_loss/num_batches:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    return model


# Evaluation

In [67]:
import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.datasets import LIBRISPEECH
from torch.utils.data import DataLoader
import Levenshtein


def decode_predictions(log_probs, idx2char, blank_idx=28):
    predictions = torch.argmax(log_probs, dim=1)

    decoded = []
    prev_char = None
    for pred in predictions:
        pred_idx = pred.item()
        if pred_idx != blank_idx and pred_idx != prev_char:
            decoded.append(idx2char[pred_idx])
        prev_char = pred_idx
    return ''.join(decoded)

def calculate_wer(reference, hypothesis):
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if len(ref_words) == 0:
        return 0.0 if len(hyp_words) == 0 else 1.0
    distance = Levenshtein.distance(' '.join(ref_words), ' '.join(hyp_words))
    return distance / len(ref_words)

def calculate_cer(reference, hypothesis):
    if len(reference) == 0:
        return 0.0 if len(hypothesis) == 0 else 1.0
    distance = Levenshtein.distance(reference, hypothesis)
    return distance / len(reference)


# ============================================================
#   E V A L U A T I O N    W I T H    G P U / M P S / C P U
# ============================================================

def evaluate_model(model, dataloader, model_type='MLP', device=None, num_samples=None):

    # -----------------------------
    # Device selection
    # -----------------------------
    if device is None:
        device = torch.device(
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
    print(f"\nEvaluating on device: {device}")

    model = model.to(device)
    model.eval()

    chars = "abcdefghijklmnopqrstuvwxyz '"
    char2idx = {c: i for i, c in enumerate(chars)}
    idx2char = {i: c for c, i in char2idx.items()}

    total_wer = 0.0
    total_cer = 0.0
    num_processed = 0

    print("\n" + "="*80)
    print("EVALUATION RESULTS")
    print("="*80)

    with torch.no_grad():
        for batch_idx, (waveforms, sample_rates, transcripts, _, _, _) in enumerate(dataloader):

            for i, (waveform, sample_rate, transcript) in enumerate(
                zip(waveforms, sample_rates, transcripts)
            ):

                waveform = waveform.to(device)

                # ---------------------------------------------------
                # Feature extraction (GPU where possible)
                # ---------------------------------------------------
                if model_type == 'MLP':

                    features = _MFCC(waveform, sample_rate)  # your MFCC fn
                    features = features.squeeze(0).transpose(0, 1)
                    features = features.to(device)

                elif model_type == 'CNN':

                    mel_spec_transform = torchaudio.transforms.MelSpectrogram(
                        sample_rate=sample_rate,
                        n_mels=40,
                        n_fft=400,
                        hop_length=160
                    )

                    # Some torchaudio kernels cannot run on MPS/GPU → move after transform
                    mel_spec_transform = mel_spec_transform.to(device)

                    features = mel_spec_transform(waveform)
                    features = features.clamp(min=1e-9).log2()
                    features = features.unsqueeze(0).to(device)

                elif model_type == 'GRU':

                    mel_spec_transform = torchaudio.transforms.MelSpectrogram(
                        sample_rate=sample_rate,
                        n_mels=40,
                        n_fft=400,
                        hop_length=160
                    )

                    mel_spec_transform = mel_spec_transform.to(device)

                    features = mel_spec_transform(waveform)
                    features = features.clamp(min=1e-9).log2()
                    # GRU expects [B, T, F]
                    features = features.squeeze(0).transpose(0, 1).unsqueeze(0).to(device)  # [1, T, 40]

                elif model_type == 'Transformer':

                    mel_spec_transform = torchaudio.transforms.MelSpectrogram(
                        sample_rate=sample_rate,
                        n_mels=40,
                        n_fft=400,
                        hop_length=160
                    )

                    mel_spec_transform = mel_spec_transform.to(device)

                    features = mel_spec_transform(waveform)
                    features = features.clamp(min=1e-9).log2()
                    # Transformer expects [B, T, F]
                    features = features.squeeze(0).transpose(0, 1).unsqueeze(0).to(device)  # [1, T, 40]

                else:
                    raise ValueError(f"Unknown model_type: {model_type}")

                # Forward → logits
                logits = model(features)
                if model_type == 'CNN':
                    logits = logits.squeeze(0)
                elif model_type == 'GRU':
                    logits = logits.squeeze(0)
                elif model_type == 'Transformer':
                    logits = logits.squeeze(0)

                log_probs = F.log_softmax(logits, dim=1)

                # Decode text
                predicted_text = decode_predictions(log_probs, idx2char)
                reference_text = transcript.lower()

                # Metrics
                wer = calculate_wer(reference_text, predicted_text)
                cer = calculate_cer(reference_text, predicted_text)

                total_wer += wer
                total_cer += cer
                num_processed += 1

                # Print first few samples
                if num_processed <= 5:
                    print(f"\nSample {num_processed}:")
                    print(f"  Reference:  {reference_text}")
                    print(f"  Predicted:  {predicted_text}")
                    print(f"  WER: {wer:.2%}, CER: {cer:.2%}")

                if num_samples and num_processed >= num_samples:
                    break

            if num_samples and num_processed >= num_samples:
                break

    avg_wer = total_wer / num_processed if num_processed > 0 else 0
    avg_cer = total_cer / num_processed if num_processed > 0 else 0

    print("\n" + "="*80)
    print(f"AVERAGE METRICS (over {num_processed} samples)")
    print(f"  Average WER: {avg_wer:.2%}")
    print(f"  Average CER: {avg_cer:.2%}")
    print("="*80 + "\n")

    return avg_wer, avg_cer


# ============================================================
#     L O A D   &   E V A L U A T E
# ============================================================

def load_and_evaluate(model_path=None, model_type='MLP', batch_size=8, num_workers=2, num_samples=50):

    # -----------------------------
    # Device selection
    # -----------------------------
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"\nUsing device: {device}")

    # Dataset
    eval_dataset = LIBRISPEECH("./data", url="dev-clean", download=True)
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn
    )

    # Model init
    print(f"Initializing {model_type} model...")
    if model_type == 'MLP':
        model = MLP(input_size=13, hidden_size=128, output_size=29)
    elif model_type == 'CNN':
        model = CNN(n_classes=29, n_mels=40)
    elif model_type == 'GRU':
        model = GRU(input_dim=40, hidden_dim=128, num_layers=2, n_classes=29)
    elif model_type == 'Transformer':
        model = Transformer(input_dim=40, d_model=256, nhead=8, num_layers=6, n_classes=29)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Load weights
    if model_path:
        print(f"Loading model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        print("Warning: No model_path provided → evaluating random weights")

    # Evaluate
    return evaluate_model(
        model,
        eval_loader,
        model_type=model_type,
        device=device,
        num_samples=num_samples
    )




In [None]:
# MLP
train_model(batch_size=8, num_workers=2)

Using device: cuda
Data Downloaded !
Dataset size: 28539 samples
First sample shape: torch.Size([1, 225360])
Using Transformer
Epoch 1, Batch 10, Avg Loss: 6.4464
Epoch 1, Batch 20, Avg Loss: 4.6797
Epoch 1, Batch 30, Avg Loss: 4.0824
Epoch 1, Batch 40, Avg Loss: 3.7850
Epoch 1, Batch 50, Avg Loss: 3.6060
Epoch 1, Batch 60, Avg Loss: 3.4881
Epoch 1, Batch 70, Avg Loss: 3.4064
Epoch 1, Batch 80, Avg Loss: 3.3411


In [63]:
load_and_evaluate(
    model_path='gru.pth',
    model_type='Transformers',
    num_samples=20
)




Using device: cuda
Initializing GRU model...
Loading model from gru.pth

Evaluating on device: cuda

EVALUATION RESULTS

Sample 1:
  Reference:  mister quilter is the apostle of the middle classes and we are glad to welcome his gospel
  Predicted:  itecuters ipuo melcasis er  loe is guto
  WER: 347.06%, CER: 65.17%

Sample 2:
  Reference:  nor is mister quilter's manner less interesting than his matter
  Predicted:  norismeteoers mam lisenvrtin then is mater
  WER: 280.00%, CER: 44.44%

Sample 3:
  Reference:  he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind
  Predicted:  etois the e thetueesonno er wi cemesrs bein befrs solas orntoeting orso e cr marl ote m
  WER: 318.75%, CER: 59.30%

Sample 4:
  Reference:  he has grave doubts whether sir frederick leighton's work is really greek after all and can discover in it but little of rocky ithaca
  Predicted:  hes gra

(3.259282575603885, 0.5988340362371821)