In [None]:
# -*- coding: utf-8 -*-
"""Transformer-based Speech Recognition with Noise Augmentation"""

In [None]:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import string
import math
from torch.utils.data import DataLoader

In [None]:
# Create the dataset root up front: the LibriSpeech loaders later in this
# notebook are pointed at "./data" with download=True. exist_ok=True keeps
# this cell safe to re-run.
os.makedirs("./data", exist_ok=True)

In [None]:
# -------------------- Noise Augmentation --------------------
class AddGaussianNoise(nn.Module):
    """Additive white-noise augmentation that is active only in training mode.

    Args:
        noise_level: Scale applied to the standard-normal noise before it is
            added to the input.
    """

    def __init__(self, noise_level=0.005):
        super().__init__()
        self.noise_level = noise_level

    def forward(self, waveform):
        """Return ``waveform`` plus scaled Gaussian noise, or unchanged in eval mode."""
        # Evaluation must see the clean signal, so bail out early.
        if not self.training:
            return waveform
        scaled_noise = torch.randn_like(waveform) * self.noise_level
        return waveform + scaled_noise

In [None]:
# -------------------- Audio Transforms --------------------
# Training pipeline: waveform noise -> mel spectrogram -> frequency/time masking.
# AddGaussianNoise only perturbs the input while the module is in training
# mode, which is the default state of a freshly built nn.Sequential.
train_audio_transforms = nn.Sequential(
    AddGaussianNoise(noise_level=0.01),
    torchaudio.transforms.MelSpectrogram(n_mels=128, sample_rate=16000),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
    torchaudio.transforms.TimeMasking(time_mask_param=100),
)

# Validation pipeline: the same mel front-end with no augmentation.
valid_audio_transforms = torchaudio.transforms.MelSpectrogram(
    n_mels=128,
    sample_rate=16000,
)


In [None]:
# -------------------- Text Processing --------------------
class TextTransform:
    """Character-level codec between transcripts and integer label sequences.

    The vocabulary is the apostrophe (index 0), a space placeholder
    ``'<SPACE>'`` (index 1) and the lowercase ASCII letters (indices 2-27).
    """

    def __init__(self):
        self.chars = ["'", '<SPACE>'] + list(string.ascii_lowercase)
        self.char_map = {}
        self.index_map = {}
        for position, symbol in enumerate(self.chars):
            self.char_map[symbol] = position
            self.index_map[position] = symbol
        # Decoding should emit a real blank, not the placeholder token.
        self.index_map[self.char_map['<SPACE>']] = ' '

    def text_to_int(self, text):
        """Encode ``text`` (lower-cased first) as a list of vocabulary indices.

        Any character outside the vocabulary, including a literal space, is
        mapped to the ``'<SPACE>'`` index.
        """
        space_index = self.char_map['<SPACE>']
        encoded = []
        for symbol in text.lower():
            encoded.append(self.char_map.get(symbol, space_index))
        return encoded

    def int_to_text(self, labels):
        """Decode a sequence of vocabulary indices back into a string.

        Raises:
            KeyError: if an index is not part of the vocabulary.
        """
        # index_map already stores ' ' for the space index, so the joined
        # string can never contain the literal '<SPACE>' token.
        return ''.join(self.index_map[index] for index in labels)


text_transform = TextTransform()


In [None]:
# -------------------- Data Processing --------------------
def data_processing(data, data_type="train"):
    """Collate raw dataset samples into padded model inputs and CTC targets.

    Args:
        data: Iterable of samples whose first element is the waveform and
            whose third element is the transcript; remaining fields are
            ignored.
        data_type: ``'train'`` selects the augmenting transform pipeline; any
            other value selects the plain validation transform.

    Returns:
        Tuple ``(specs, labels, input_lengths, label_lengths)``: the padded
        spectrogram batch with a channel axis inserted at dim 1 and the time
        axis last, the zero-padded label batch, and two Python lists with the
        per-sample lengths.
    """
    if data_type == 'train':
        transform = train_audio_transforms
    else:
        transform = valid_audio_transforms

    specs, labels = [], []
    input_lengths, label_lengths = [], []

    for waveform, _rate, utterance, *_rest in data:
        # Put time first so pad_sequence pads along the time axis.
        # assumes waveform is (1, num_samples) so squeeze(0) drops the channel — TODO confirm
        features = transform(waveform).squeeze(0).transpose(0, 1)
        targets = torch.tensor(text_transform.text_to_int(utterance))

        specs.append(features)
        labels.append(targets)
        # Frame count is halved; presumably this mirrors the stride-2
        # convolution at the front of the model (verify if stride changes).
        input_lengths.append(features.shape[0] // 2)
        label_lengths.append(len(targets))

    padded_specs = nn.utils.rnn.pad_sequence(specs, batch_first=True)
    # (batch, time, feat) -> (batch, 1, feat, time)
    padded_specs = padded_specs.unsqueeze(1).transpose(2, 3)
    padded_labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
    return padded_specs, padded_labels, input_lengths, label_lengths

In [None]:
# -------------------- Model Components --------------------
class CNNLayerNorm(nn.Module):
    """LayerNorm over dim 2 of a 4-D tensor.

    ``nn.LayerNorm`` normalises the last axis, so the input is transposed to
    bring dim 2 last, normalised, and transposed back.

    Args:
        n_feats: Size of dim 2 of the tensors passed to ``forward``.
    """

    def __init__(self, n_feats):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        """Normalise ``x`` along dim 2 and return a contiguous tensor of the same shape."""
        swapped = x.transpose(2, 3).contiguous()
        normalised = self.layer_norm(swapped)
        return normalised.transpose(2, 3).contiguous()

class ResidualCNN(nn.Module):
    """Pre-activation residual block: two (norm -> GELU -> dropout -> conv) stages.

    The input is added back to the output, so the convolutions must preserve
    the tensor shape (equal channel counts and a stride of 1 with the
    ``kernel // 2`` padding used here).

    Args:
        in_ch: Input channels of the first convolution.
        out_ch: Output channels of both convolutions.
        kernel: Convolution kernel size.
        stride: Convolution stride.
        dropout: Dropout probability applied before each convolution.
        n_feats: Size of dim 2 of the input, normalised by ``CNNLayerNorm``.
    """

    def __init__(self, in_ch, out_ch, kernel, stride, dropout, n_feats):
        super().__init__()
        same_pad = kernel // 2
        self.cnn1 = nn.Conv2d(in_ch, out_ch, kernel, stride, padding=same_pad)
        self.cnn2 = nn.Conv2d(out_ch, out_ch, kernel, stride, padding=same_pad)
        self.dropout = nn.Dropout(dropout)
        self.ln1 = CNNLayerNorm(n_feats)
        self.ln2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        """Run both stages on ``x`` and add the skip connection."""
        out = x
        for norm, conv in ((self.ln1, self.cnn1), (self.ln2, self.cnn2)):
            out = norm(out)
            out = F.gelu(out)
            out = self.dropout(out)
            out = conv(out)
        return out + x

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first inputs.

    Even feature indices carry ``sin(pos * w_k)`` and odd indices carry
    ``cos(pos * w_k)``. Both even and odd ``d_model`` are supported: with an
    odd width there is one fewer cosine column than sine columns.

    Args:
        d_model: Feature width of the tensors passed to ``forward``.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions pre-computed; inputs must not be longer.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # Only d_model // 2 odd slots exist; for odd d_model that is one fewer
        # than the number of frequencies, so trim to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the encoding to ``x`` (indexed by ``x.size(0)``) and apply dropout."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


In [None]:

# -------------------- Main Model --------------------
class SpeechRecognitionModel(nn.Module):
    """CNN front-end followed by a Transformer encoder and a per-frame classifier.

    Args:
        n_cnn_layers: Number of ``ResidualCNN`` blocks after the first conv.
        n_rnn_layers: Number of Transformer encoder layers (the name is kept
            for interface compatibility).
        rnn_dim: Transformer model width (``d_model``).
        n_class: Number of output classes per frame.
        n_feats: Size of the feature axis (dim 2) of the input spectrogram.
        stride: Stride of the first convolution on both feature and time axes.
        dropout: Dropout probability used throughout the network.
    """

    def __init__(self, n_cnn_layers=3, n_rnn_layers=5, rnn_dim=512,
                 n_class=29, n_feats=128, stride=2, dropout=0.2):
        super().__init__()
        # Feature size after the kernel-3 / padding-1 convolution below:
        # floor((n_feats + 2*1 - 3) / stride) + 1. This equals n_feats // 2
        # for the default even n_feats with stride 2, and stays correct for
        # other strides or an odd n_feats.
        cnn_feats = (n_feats - 1) // stride + 1

        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=1)
        self.rescnn = nn.Sequential(*[
            ResidualCNN(32, 32, 3, 1, dropout, cnn_feats)
            for _ in range(n_cnn_layers)
        ])
        self.linear = nn.Linear(32 * cnn_feats, rnn_dim)
        self.pos_encoder = PositionalEncoding(rnn_dim, dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=rnn_dim,
            nhead=8,
            dim_feedforward=2048,
            dropout=dropout,
            activation='gelu'
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_rnn_layers)
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim, rnn_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x, input_lengths):
        """Compute per-frame class scores.

        Args:
            x: Spectrogram batch of shape ``(batch, 1, n_feats, time)``.
            input_lengths: Per-sample count of valid frames *after* the
                strided convolution (list of ints or 1-D tensor); frames at
                or beyond this index are hidden from self-attention.

        Returns:
            Unnormalised scores of shape ``(batch, frames, n_class)``.
        """
        x = self.cnn(x)
        x = self.rescnn(x)
        batch, ch, feat_dim, seq_len = x.size()
        # (batch, ch, feat, time) -> (batch, time, ch * feat)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.view(batch, seq_len, ch * feat_dim)
        x = self.linear(x)
        # The encoder layer is sequence-first (batch_first is left at False).
        x = x.permute(1, 0, 2)
        x = self.pos_encoder(x)

        # Key-padding mask, True = ignore. Built in one vectorised comparison
        # instead of a per-row Python loop.
        max_len = x.size(0)
        lengths = torch.as_tensor(input_lengths, device=x.device)
        frame_index = torch.arange(max_len, device=x.device)
        mask = frame_index.unsqueeze(0) >= lengths.unsqueeze(1)

        x = self.transformer_encoder(x, src_key_padding_mask=mask)
        x = x.permute(1, 0, 2)
        return self.classifier(x)


In [None]:
# -------------------- Training Setup --------------------
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Every tunable value of the run lives in this one dict.
params = dict(
    batch_size=16,
    epochs=20,
    lr=3e-4,
    n_cnn_layers=3,
    n_rnn_layers=5,
    rnn_dim=512,
    n_class=29,
    n_feats=128,
    stride=2,
    dropout=0.2,
)

# Only the architecture-related entries are forwarded to the model.
model = SpeechRecognitionModel(
    n_cnn_layers=params["n_cnn_layers"],
    n_rnn_layers=params["n_rnn_layers"],
    rnn_dim=params["rnn_dim"],
    n_class=params["n_class"],
    n_feats=params["n_feats"],
    stride=params["stride"],
    dropout=params["dropout"],
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=params["lr"])

# The 28 characters of TextTransform occupy indices 0-27, so index 28 (the
# last of the 29 classes) is used as the CTC blank.
criterion = nn.CTCLoss(blank=28).to(device)

In [None]:

# -------------------- Decoding & Evaluation Metrics --------------------
def decode(outputs, blank_label=28, collapse_repeated=True):
    """Greedy CTC decoding of per-frame scores into text.

    For every sequence the highest-scoring class per frame is taken,
    consecutive duplicates are merged (when ``collapse_repeated`` is true) and
    blank frames are dropped before the remaining indices are converted to
    characters. Dropping the blank matters for correctness: it is not part of
    the text vocabulary, so passing it to ``int_to_text`` raises ``KeyError``.

    Args:
        outputs: Scores of shape ``(batch, frames, n_class)``.
        blank_label: Class index reserved for the CTC blank. The default
            matches the ``blank=28`` used for the CTC loss in this notebook.
        collapse_repeated: Merge runs of the same index into one symbol.

    Returns:
        List with one decoded string per sequence in the batch.
    """
    _, best_paths = torch.max(outputs, dim=2)
    texts = []
    for path in best_paths.tolist():
        kept = []
        previous = None
        for index in path:
            is_repeat = collapse_repeated and index == previous
            if index != blank_label and not is_repeat:
                kept.append(index)
            # Track blanks too, so "a <blank> a" still decodes to "aa".
            previous = index
        texts.append(text_transform.int_to_text(kept))
    return texts

def levenshtein_distance(a, b):
    """Return the edit distance between two sequences.

    Insertions, deletions and substitutions each cost 1. ``a`` and ``b`` may
    be strings or lists of comparable items (e.g. words).
    """
    # Rolling-row dynamic programme: previous_row[j] is the distance between
    # the prefix of ``a`` consumed so far and the first j items of ``b``.
    previous_row = list(range(len(b) + 1))
    for i, item_a in enumerate(a, start=1):
        current_row = [i]
        for j, item_b in enumerate(b, start=1):
            substitution_cost = 0 if item_a == item_b else 1
            current_row.append(min(
                previous_row[j] + 1,                      # deletion
                current_row[j - 1] + 1,                   # insertion
                previous_row[j - 1] + substitution_cost,  # match / substitution
            ))
        previous_row = current_row
    return previous_row[-1]

def wer(ref, hyp):
    """Word error rate: word-level edit distance over the reference word count.

    Both strings are split on whitespace. An empty reference uses a
    denominator of 1 so the call never divides by zero.
    """
    reference_tokens, hypothesis_tokens = ref.split(), hyp.split()
    edits = levenshtein_distance(reference_tokens, hypothesis_tokens)
    denominator = max(len(reference_tokens), 1)
    return edits / denominator

def cer(ref, hyp):
    """Character error rate: character-level edit distance over the reference length.

    An empty reference uses a denominator of 1 so the call never divides by zero.
    """
    edits = levenshtein_distance(ref, hyp)
    denominator = max(len(ref), 1)
    return edits / denominator

In [None]:

# -------------------- Training & Validation Loops --------------------
def train(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Callable as ``model(specs, input_lens)`` returning scores of
            shape ``(batch, frames, n_class)``.
        loader: Sized iterable of ``(specs, labels, input_lens, label_lens)``.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss called as ``criterion(log_probs, labels, input_lens,
            label_lens)`` with frame-first log-probabilities.
        device: Device the spectrograms and labels are moved to.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    n_batches = len(loader)
    running_loss = 0

    for step, batch in enumerate(loader):
        specs, labels, input_lens, label_lens = batch
        specs = specs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        scores = model(specs, input_lens)
        # (batch, frames, class) -> (frames, batch, class) log-probabilities.
        log_probs = F.log_softmax(scores, dim=2).transpose(0, 1)

        loss = criterion(log_probs, labels, input_lens, label_lens)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        # Progress line every 100 batches.
        if step % 100 == 0:
            print(f"Batch {step}/{n_batches} Loss: {batch_loss:.4f}")

    return running_loss / n_batches

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and report loss, CER and WER.

    Each reference transcript is cut to its true label length, and each
    hypothesis is decoded only from the frames declared valid for that
    utterance. Without this, the zero padding added when labels are batched
    would be decoded as extra characters (index 0 is the apostrophe in the
    text vocabulary), and frames from the padded region would leak into the
    hypothesis, inflating both error rates.

    Args:
        model: Callable as ``model(specs, input_lens)`` returning scores of
            shape ``(batch, frames, n_class)``.
        loader: Sized iterable of ``(specs, labels, input_lens, label_lens)``.
        criterion: Loss called with frame-first log-probabilities.
        device: Device the spectrograms and labels are moved to.

    Returns:
        Tuple ``(avg_loss, avg_cer, avg_wer)``. The loss is averaged over
        batches; CER and WER are averaged over the utterances actually seen,
        so the loader does not need a ``dataset`` attribute.
    """
    model.eval()
    total_loss = total_cer = total_wer = 0
    n_utterances = 0
    with torch.no_grad():
        for specs, labels, input_lens, label_lens in loader:
            specs, labels = specs.to(device), labels.to(device)
            outputs = model(specs, input_lens)
            # (batch, frames, class) -> (frames, batch, class) log-probabilities.
            outputs = F.log_softmax(outputs, dim=2).transpose(0, 1)

            loss = criterion(outputs, labels, input_lens, label_lens)
            total_loss += loss.item()

            for idx, (n_frames, n_labels) in enumerate(zip(input_lens, label_lens)):
                # Decode one utterance at a time, restricted to its valid frames.
                hyp = decode(outputs[:n_frames, idx].unsqueeze(0))[0]
                # Strip the batch padding from the reference before decoding.
                ref = text_transform.int_to_text(labels[idx, :n_labels].tolist())
                total_cer += cer(ref, hyp)
                total_wer += wer(ref, hyp)
                n_utterances += 1

    avg_loss = total_loss / len(loader)
    avg_cer = total_cer / n_utterances
    avg_wer = total_wer / n_utterances
    print(f"Validation Loss: {avg_loss:.4f} | CER: {avg_cer:.4f} | WER: {avg_wer:.4f}")
    return avg_loss, avg_cer, avg_wer

In [None]:
# -------------------- Main Execution --------------------
if __name__ == "__main__":
    # Named collate functions, one per split, so each loader gets the
    # matching audio-transform pipeline.
    def _collate_train(batch):
        return data_processing(batch, "train")

    def _collate_valid(batch):
        return data_processing(batch, "valid")

    # LibriSpeech is downloaded into ./data on first use.
    train_dataset = torchaudio.datasets.LIBRISPEECH(
        root="./data",
        url="train-clean-100",
        download=True,
    )
    test_dataset = torchaudio.datasets.LIBRISPEECH(
        root="./data",
        url="test-clean",
        download=True,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=params["batch_size"],
        shuffle=True,
        collate_fn=_collate_train,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=params["batch_size"],
        collate_fn=_collate_valid,
    )

    # Keep only the checkpoint with the lowest word error rate seen so far.
    best_wer = float('inf')
    for epoch in range(params["epochs"]):
        print(f"\nEpoch {epoch+1}/{params['epochs']}")
        train_loss = train(model, train_loader, optimizer, criterion, device)
        val_loss, val_cer, val_wer = validate(model, test_loader, criterion, device)

        if val_wer < best_wer:
            best_wer = val_wer
            torch.save(model.state_dict(), "best_model_transformer.pth")
            print("Saved new best model!")