In [1]:
# -*- coding: utf-8 -*-
"""Advanced Transformer-based Speech Recognition System (Optimized)"""

'Advanced Transformer-based Speech Recognition System (Optimized)'

In [2]:
import os
import torch
import math
import string
import torch.nn as nn
import torchaudio
import numpy as np
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset, Sampler
from torch.nn.utils.rnn import pad_sequence
from torchaudio.datasets import LIBRISPEECH
from typing import Tuple, List, Dict, Optional
from pathlib import Path


In [3]:
# Process-wide runtime setup. These are side effects on the environment and
# on torch's global state, so this cell should run before any CUDA work.

# Allocator option passed to PyTorch through the environment.
# NOTE(review): this is set after `import torch`; presumably it is still picked
# up because the CUDA allocator reads it lazily — confirm on the target version.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Enable cuDNN autotuning of convolution algorithms.
# NOTE(review): batches in this notebook have variable time length, which may
# trigger repeated re-tuning — verify this is a net win.
torch.backends.cudnn.benchmark = True
# Drop any cached CUDA blocks left over from a previous run of the notebook.
torch.cuda.empty_cache()

In [4]:
# -------------------- Configuration --------------------
class Config:
    """Central store for every tunable value used by the notebook.

    All values are plain class attributes; the module-level ``config``
    instance created below is what the rest of the code reads.
    """

    # Audio parameters
    sample_rate = 16000
    n_mels = 64
    win_length = 320       # samples -> 20 ms at 16 kHz
    hop_length = 160       # samples -> 10 ms at 16 kHz
    n_fft = 320
    max_audio_length = 10  # Seconds

    # Vocabulary
    # The CTC blank is deliberately NOT listed here: every consumer in this
    # notebook builds its symbol table as ``['<blank>'] + config.vocab`` and
    # sizes the classifier as ``len(config.vocab) + 1``. Listing '<blank>'
    # here as well created a second, dead blank class whose index decoded to
    # the literal text "<blank>".
    vocab = list(string.ascii_lowercase) + [' ', "'"]

    # Model architecture
    cnn_channels = 32
    num_cnn_layers = 2
    encoder_layers = 4
    attention_heads = 4
    ff_dim = 1024
    dropout = 0.1
    emb_dim = 256

    # Training parameters
    batch_size = 4
    gradient_accumulation = 4  # optimizer steps every N batches
    epochs = 30
    lr = 3e-4
    weight_decay = 1e-5
    max_grad_norm = 5.0
    warmup_steps = 1000  # NOTE(review): not referenced by the visible cells

    # Augmentation
    noise_snr = (15, 20)  # (min, max) SNR in dB for additive Gaussian noise
    time_mask = 20        # max width handed to TimeMasking
    freq_mask = 8         # max width handed to FrequencyMasking

config = Config()

In [5]:
# -------------------- Audio Augmentations --------------------
class AddGaussianNoise(nn.Module):
    """Add white Gaussian noise at a random signal-to-noise ratio.

    Active only in training mode; in eval mode the input is returned as is.

    Args:
        min_snr: Lower bound of the SNR range, in dB.
        max_snr: Upper bound of the SNR range, in dB.
    """

    def __init__(self, min_snr=5, max_snr=20):
        super().__init__()
        self.min_snr = min_snr
        self.max_snr = max_snr

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return ``waveform`` plus noise scaled to a sampled SNR.

        One SNR value is drawn per call and applied to the whole tensor.
        """
        if not self.training:
            return waveform

        # Sample on the waveform's device so CUDA inputs do not hit a
        # cross-device operation.
        snr_db = torch.empty(1, device=waveform.device).uniform_(self.min_snr, self.max_snr)
        # SNR is applied to *power* below, so the dB conversion is /10.
        # (/20 is the amplitude ratio and would halve the effective SNR in dB.)
        snr_linear = 10 ** (snr_db / 10)
        signal_power = torch.mean(waveform ** 2)
        noise_power = signal_power / snr_linear
        noise = torch.randn_like(waveform) * torch.sqrt(noise_power)
        return waveform + noise

class SpecAugment(nn.Module):
    """Time masking followed by frequency masking, applied in training only.

    Mask widths are taken from ``config.time_mask`` and ``config.freq_mask``.
    """

    def __init__(self):
        super().__init__()
        masking = torchaudio.transforms
        self.time_mask = masking.TimeMasking(config.time_mask)
        self.freq_mask = masking.FrequencyMasking(config.freq_mask)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Return the masked spectrogram (or ``spec`` untouched in eval mode)."""
        if not self.training:
            return spec
        time_masked = self.time_mask(spec)
        return self.freq_mask(time_masked)

In [6]:
# -------------------- Audio Processing --------------------
class AudioProcessor:
    """Turn a raw waveform into a dB-scaled mel spectrogram.

    The transform pipeline is built once and reused. Previously every call
    re-created ``MelSpectrogram`` (window + mel filterbank) and the whole
    ``nn.Sequential``, i.e. once per dataset sample.
    """

    # Cache of built pipelines. The key contains the train flag and every
    # config value the pipeline depends on, so editing ``config`` and
    # re-running a notebook cell cannot leave a stale pipeline in use.
    _pipelines: Dict[tuple, nn.Sequential] = {}

    @staticmethod
    def _get_pipeline(train: bool) -> nn.Sequential:
        """Return the (cached) transform stack for train or eval processing."""
        key = (
            train,
            config.sample_rate,
            config.n_mels,
            config.n_fft,
            config.win_length,
            config.hop_length,
            tuple(config.noise_snr),
            config.time_mask,
            config.freq_mask,
        )
        pipeline = AudioProcessor._pipelines.get(key)
        if pipeline is None:
            stages = [
                torchaudio.transforms.MelSpectrogram(
                    sample_rate=config.sample_rate,
                    n_mels=config.n_mels,
                    n_fft=config.n_fft,
                    win_length=config.win_length,
                    hop_length=config.hop_length
                ),
                torchaudio.transforms.AmplitudeToDB(),
            ]
            if train:
                # Noise is added to the waveform, masking to the spectrogram.
                stages = [AddGaussianNoise(*config.noise_snr)] + stages
                stages.append(SpecAugment())
            # Freshly built modules are in training mode, which is what
            # switches the augmentation stages on.
            pipeline = nn.Sequential(*stages)
            AudioProcessor._pipelines[key] = pipeline
        return pipeline

    @staticmethod
    def process(waveform: torch.Tensor, train: bool = True) -> torch.Tensor:
        """Crop ``waveform`` to the configured maximum length and featurize it.

        Args:
            waveform: Audio samples with time on the last axis.
            train: When true, noise and SpecAugment are applied as well.

        Returns:
            The output of the transform stack (dB-scaled mel spectrogram).
        """
        max_samples = config.sample_rate * config.max_audio_length
        if waveform.shape[-1] > max_samples:
            waveform = waveform[..., :max_samples]

        return AudioProcessor._get_pipeline(bool(train))(waveform)

In [7]:
# -------------------- Neural Modules --------------------
class DepthwiseConv(nn.Module):
    """Depthwise-separable 2-D convolution block with channel LayerNorm.

    Stack: 3x3 depthwise conv -> 1x1 pointwise conv -> GELU -> dropout,
    followed by a LayerNorm taken over the channel axis.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the pointwise conv.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        depthwise = nn.Conv2d(
            in_channels, in_channels, 3, padding=1, groups=in_channels
        )
        pointwise = nn.Conv2d(in_channels, out_channels, 1)
        self.layers = nn.Sequential(
            depthwise,
            pointwise,
            nn.GELU(),
            nn.Dropout(config.dropout),
        )
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a 4-D input with channels on axis 1."""
        features = self.layers(x)
        # LayerNorm normalizes the last axis, so move channels there and back.
        channels_last = features.movedim(1, -1)
        normalized = self.norm(channels_last)
        return normalized.movedim(-1, 1)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table added to batch-first inputs.

    Args:
        d_model: Feature size of the inputs (even indices get sine,
            odd indices get cosine).
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions pre-computed in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        steps = torch.arange(max_len).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = steps * inv_freq

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Buffer, not parameter: follows .to(device) but is never trained.
        self.register_buffer('pe', table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first ``x.size(1)`` table rows to ``x`` and apply dropout."""
        seq_len = x.size(1)
        with_positions = x + self.pe[:seq_len]
        return self.dropout(with_positions)

class TransformerEncoder(nn.Module):
    """Positional encoding followed by a pre-norm Transformer encoder stack.

    All sizes come from the module-level ``config``.
    """

    def __init__(self):
        super().__init__()
        self.pos_encoder = PositionalEncoding(config.emb_dim, config.dropout)

        layer_options = dict(
            d_model=config.emb_dim,
            nhead=config.attention_heads,
            dim_feedforward=config.ff_dim,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        prototype_layer = nn.TransformerEncoderLayer(**layer_options)
        self.encoder = nn.TransformerEncoder(prototype_layer, config.encoder_layers)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Encode ``x``; ``mask`` is forwarded as ``src_key_padding_mask``."""
        positioned = self.pos_encoder(x)
        return self.encoder(positioned, src_key_padding_mask=mask)

class ASRTransformer(nn.Module):
    """CNN front-end + Transformer encoder + per-frame character classifier.

    The classifier emits ``len(config.vocab) + 1`` scores per frame.
    """

    def __init__(self):
        super().__init__()
        conv_blocks = [DepthwiseConv(1, config.cnn_channels)]
        for _ in range(config.num_cnn_layers - 1):
            conv_blocks.append(
                DepthwiseConv(config.cnn_channels, config.cnn_channels)
            )
        self.frontend = nn.Sequential(*conv_blocks)

        flattened_dim = config.cnn_channels * config.n_mels
        self.projection = nn.Linear(flattened_dim, config.emb_dim)
        self.encoder = TransformerEncoder()

        num_classes = len(config.vocab) + 1
        self.classifier = nn.Sequential(
            nn.LayerNorm(config.emb_dim),
            nn.Linear(config.emb_dim, num_classes),
        )

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Return per-frame class scores.

        Args:
            x: 4-D input with a single channel on axis 1; the collate function
                in this notebook supplies (batch, 1, time, n_mels).
            lengths: Number of valid frames of each batch item.
        """
        feats = self.frontend(x)
        # Swap axes 1 and 2, then merge the last two axes into one feature axis.
        feats = feats.transpose(1, 2).flatten(start_dim=2)
        feats = self.projection(feats)
        padding_mask = self.create_mask(lengths, feats.size(1))
        encoded = self.encoder(feats, padding_mask)
        return self.classifier(encoded)

    def create_mask(self, lengths: torch.Tensor, max_len: int) -> torch.Tensor:
        """Boolean (batch, max_len) mask that is True at padded positions."""
        positions = torch.arange(max_len, device=lengths.device)
        return positions.unsqueeze(0) >= lengths.unsqueeze(1)

In [8]:
# -------------------- Training System --------------------
class CTCTrainer:
    """Owns the model, optimizer, AMP scaler and CTC loss for train/eval.

    Class index 0 is the CTC blank; indices 1.. map onto ``config.vocab``.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = ASRTransformer().to(self.device)
        self.optimizer = AdamW(self.model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        # Loss scaling only applies to CUDA mixed precision; build the scaler
        # for the device actually in use and leave it disabled on CPU.
        self.scaler = torch.amp.GradScaler(self.device.type, enabled=self.device.type == "cuda")
        self.criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    def train_step(self, batch: Tuple, batch_idx: int) -> float:
        """Run forward/backward on one batch with gradient accumulation.

        The optimizer steps every ``config.gradient_accumulation`` batches.
        Gradients of a trailing partial window are left accumulated; the
        caller decides what to do with them at the end of an epoch.

        Args:
            batch: ``(specs, labels, spec_lens, label_lens)`` tensors.
            batch_idx: Zero-based index of the batch within the epoch.

        Returns:
            The un-normalized (not divided by accumulation) batch loss.
        """
        self.model.train()
        specs, labels, spec_lens, label_lens = [t.to(self.device, non_blocking=True) for t in batch]

        with torch.autocast(device_type=self.device.type):
            logits = self.model(specs, spec_lens)
            log_probs = torch.log_softmax(logits, dim=-1)
            # CTCLoss expects time-major log-probs, hence the permute.
            loss = self.criterion(
                log_probs.permute(1, 0, 2),
                labels,
                spec_lens,
                label_lens
            ) / config.gradient_accumulation

        self.scaler.scale(loss).backward()

        if (batch_idx + 1) % config.gradient_accumulation == 0:
            # Unscale before clipping so the threshold applies to true grads.
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), config.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad(set_to_none=True)

            # Every 10th optimizer step. The previous test
            # (`batch_idx % (10 * accumulation) == 0`) could never be true
            # inside this branch, so the cache was never released.
            if (
                self.device.type == "cuda"
                and (batch_idx + 1) % (10 * config.gradient_accumulation) == 0
            ):
                torch.cuda.empty_cache()

        return loss.item() * config.gradient_accumulation

    def validate(self, valid_loader) -> Tuple[float, float]:
        """Evaluate on ``valid_loader``.

        Returns:
            ``(mean loss per utterance, mean per-utterance WER)``. Both are
            NaN when the loader yields no samples.
        """
        self.model.eval()
        total_loss, total_wer, total_samples = 0.0, 0.0, 0

        with torch.no_grad():
            for batch in valid_loader:
                specs, labels, spec_lens, label_lens = [t.to(self.device) for t in batch]

                with torch.autocast(device_type=self.device.type):
                    logits = self.model(specs, spec_lens)
                    log_probs = torch.log_softmax(logits, dim=-1)
                    loss = self.criterion(
                        log_probs.permute(1, 0, 2),
                        labels,
                        spec_lens,
                        label_lens
                    )
                total_loss += loss.item() * specs.size(0)

                preds = logits.argmax(dim=-1).cpu()
                frame_counts = spec_lens.cpu().tolist()
                labels_cpu = labels.cpu().numpy()
                label_counts = label_lens.cpu().tolist()

                for pred, n_frames, label, n_labels in zip(
                    preds, frame_counts, labels_cpu, label_counts
                ):
                    # Only the valid frames are decoded; padded frames still
                    # produce (meaningless) outputs.
                    hyp = self._decode_ctc(pred[:n_frames])
                    # References are plain label sequences: they must NOT be
                    # CTC-collapsed or double letters ("hello") are lost.
                    ref = self._labels_to_text(label[:n_labels])
                    total_wer += word_error_rate(ref, hyp)
                    total_samples += 1

        if total_samples == 0:
            return float('nan'), float('nan')
        # total_loss was weighted by batch size, so normalize by samples seen.
        return total_loss / total_samples, total_wer / total_samples

    def _decode_ctc(self, sequence: np.ndarray) -> str:
        """Greedy CTC decode: merge repeats, drop blanks, map to characters.

        Accepts any iterable of integer-like class indices.
        """
        chars = ['<blank>'] + config.vocab
        decoded = []
        prev = 0
        for idx in sequence:
            idx = int(idx)
            if idx != prev:
                # Also skip any vocabulary entry spelled '<blank>' so it can
                # never leak into the text.
                if idx != 0 and chars[idx] != '<blank>':
                    decoded.append(chars[idx])
                prev = idx
        return ''.join(decoded).strip()

    def _labels_to_text(self, sequence) -> str:
        """Map a reference label sequence to text without merging repeats."""
        chars = ['<blank>'] + config.vocab
        symbols = (chars[int(idx)] for idx in sequence if int(idx) != 0)
        return ''.join(s for s in symbols if s != '<blank>').strip()

def word_error_rate(ref: str, hyp: str) -> float:
    """Word-level edit distance between ``hyp`` and ``ref``, per reference word.

    Both strings are split on whitespace. An empty reference yields 0.0.
    """
    reference_tokens = ref.split()
    hypothesis_tokens = hyp.split()
    if not reference_tokens:
        return 0.0
    edit_count = levenshtein(reference_tokens, hypothesis_tokens)
    return edit_count / len(reference_tokens)

def levenshtein(a: List[str], b: List[str]) -> int:
    """Edit distance between two token lists (unit-cost insert/delete/substitute)."""
    # Distance from the empty prefix of `a` to every prefix of `b`.
    previous_row = list(range(len(b) + 1))
    for row_no, token_a in enumerate(a, start=1):
        current_row = [row_no]
        for col_no, token_b in enumerate(b, start=1):
            deletion = previous_row[col_no] + 1
            insertion = current_row[col_no - 1] + 1
            substitution = previous_row[col_no - 1] + (0 if token_a == token_b else 1)
            current_row.append(min(deletion, insertion, substitution))
        previous_row = current_row
    return previous_row[-1]


In [9]:
# -------------------- Data Pipeline --------------------
class ASRDataset(Dataset):
    """LibriSpeech subset yielding ``(spectrogram, label indices)`` pairs.

    Args:
        subset: LibriSpeech split name (e.g. 'train-clean-100'). Augmentation
            is enabled when the name contains 'train'.

    Attributes:
        lengths: Per-utterance sample count, capped at the configured maximum
            audio length; used for length-bucketed batching.
    """

    def __init__(self, subset: str):
        self.subset = subset
        self.dataset = LIBRISPEECH(root="./data", url=subset, download=True)
        # Index 0 is reserved for the CTC blank.
        self.vocab = {c: i for i, c in enumerate(['<blank>'] + config.vocab)}
        # NOTE: this decodes every utterance once up front, which is slow on
        # large splits but gives exact lengths.
        self.lengths = [self._get_trimmed_length(i) for i in range(len(self.dataset))]

    def _get_trimmed_length(self, idx: int) -> int:
        """Number of audio samples of item ``idx`` after the length cap.

        The length is read from the waveform itself. The fifth field of a
        LIBRISPEECH item is the chapter id, not a sample count, so it must
        not be used here.
        """
        waveform = self.dataset[idx][0]
        return min(int(waveform.shape[-1]), config.sample_rate * config.max_audio_length)

    def __getitem__(self, idx: int) -> Tuple:
        """Return ``(spec, labels)`` with ``spec`` transposed to time-major."""
        waveform, _, text, *_ = self.dataset[idx]
        spec = AudioProcessor.process(waveform, train='train' in self.subset)
        # Characters outside the vocabulary are mapped to a space.
        # Explicit dtype keeps the tensor integer even for an empty transcript.
        labels = torch.tensor(
            [self.vocab.get(c, self.vocab[' ']) for c in text.lower()],
            dtype=torch.long,
        )
        return spec.squeeze(0).T, labels

    def __len__(self) -> int:
        return len(self.dataset)

def collate_fn(batch: List) -> Tuple:
    """Pad a list of ``(spec, labels)`` pairs into batch tensors.

    Returns:
        ``(specs, labels, spec_lens, label_lens)`` where ``specs`` gains a
        singleton axis at position 1, padding is zeros, and the two length
        tensors hold the unpadded sizes as int64.
    """
    features, targets = zip(*batch)

    frame_counts = [feat.size(0) for feat in features]
    target_counts = [len(target) for target in targets]

    padded_features = pad_sequence(features, batch_first=True).unsqueeze(1)
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=0)

    return (
        padded_features,
        padded_targets,
        torch.tensor(frame_counts, dtype=torch.long),
        torch.tensor(target_counts, dtype=torch.long),
    )

class BucketSampler(Sampler):
    def __init__(self, lengths: List[int], batch_size: int, shuffle: bool = True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle
        
        indices = np.argsort(lengths).tolist()
        self.batches = [indices[i:i+batch_size] for i in range(0, len(indices), batch_size)]
        if shuffle:
            np.random.shuffle(self.batches)
    
    def __iter__(self):
        for batch in self.batches:
            yield batch
        if self.shuffle:
            np.random.shuffle(self.batches)
    
    def __len__(self):
        return len(self.batches)

In [None]:
def main():
    """Download LibriSpeech, train the CTC model and checkpoint every epoch.

    Prints progress every 10 batches and a summary line per epoch. Returns
    early (after printing a message) if the datasets cannot be loaded.
    """
    # Create data directory if it doesn't exist
    os.makedirs("./data", exist_ok=True)

    try:
        train_dataset = ASRDataset('train-clean-100')
        valid_dataset = ASRDataset('dev-clean')
    except Exception as e:
        print(f"Failed to load datasets: {e}")
        print("Please check your internet connection and disk space.")
        return

    train_sampler = BucketSampler(train_dataset.lengths, config.batch_size)
    valid_sampler = BucketSampler(valid_dataset.lengths, config.batch_size, shuffle=False)

    # Pinned host memory only helps (and is only supported) with CUDA.
    use_pinned_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=use_pinned_memory,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_sampler=valid_sampler,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=use_pinned_memory,
    )

    trainer = CTCTrainer()

    for epoch in range(config.epochs):
        # Training
        trainer.model.train()
        # The trainer steps the optimizer every `gradient_accumulation`
        # batches. When the epoch length is not a multiple of that, gradients
        # of the trailing partial window would otherwise be mixed into the
        # first optimizer step of the next epoch.
        trainer.optimizer.zero_grad(set_to_none=True)
        total_loss = 0.0
        for batch_idx, batch in enumerate(train_loader):
            loss = trainer.train_step(batch, batch_idx)
            total_loss += loss
            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch {epoch} Batch {batch_idx} Loss: {loss:.3f}")

        # Validation
        val_loss, val_wer = trainer.validate(valid_loader)
        print(f"Epoch {epoch} | Train Loss: {total_loss/len(train_loader):.3f} | Val Loss: {val_loss:.3f} | WER: {val_wer:.3f}")

        # Save checkpoint
        torch.save(trainer.model.state_dict(), f"asr_epoch_{epoch}.pt")

if __name__ == "__main__":
    main()

100%|██████████| 5.95G/5.95G [03:12<00:00, 33.2MB/s] 
100%|██████████| 322M/322M [00:10<00:00, 33.4MB/s] 


Epoch 0 Batch 9 Loss: 3.506
Epoch 0 Batch 19 Loss: 3.435
Epoch 0 Batch 29 Loss: 2.942
Epoch 0 Batch 39 Loss: 2.912
Epoch 0 Batch 49 Loss: 2.972
Epoch 0 Batch 59 Loss: 2.951
Epoch 0 Batch 69 Loss: 2.995
Epoch 0 Batch 79 Loss: 2.917
Epoch 0 Batch 89 Loss: 2.914
Epoch 0 Batch 99 Loss: 2.948
Epoch 0 Batch 109 Loss: 2.895
Epoch 0 Batch 119 Loss: 2.912
Epoch 0 Batch 129 Loss: 2.966
Epoch 0 Batch 139 Loss: 2.842
Epoch 0 Batch 149 Loss: 2.900
Epoch 0 Batch 159 Loss: 2.885
Epoch 0 Batch 169 Loss: 2.841
Epoch 0 Batch 179 Loss: 2.977
Epoch 0 Batch 189 Loss: 2.900
Epoch 0 Batch 199 Loss: 2.890
Epoch 0 Batch 209 Loss: 2.902
Epoch 0 Batch 219 Loss: 2.925
Epoch 0 Batch 229 Loss: 2.967
Epoch 0 Batch 239 Loss: 2.949
Epoch 0 Batch 249 Loss: 3.068
Epoch 0 Batch 259 Loss: 2.962
Epoch 0 Batch 269 Loss: 2.880
Epoch 0 Batch 279 Loss: 2.893
Epoch 0 Batch 289 Loss: 2.911
Epoch 0 Batch 299 Loss: 2.862
Epoch 0 Batch 309 Loss: 2.893
Epoch 0 Batch 319 Loss: 2.870
Epoch 0 Batch 329 Loss: 2.861
Epoch 0 Batch 339 Los