In [None]:
!pip install pytorch-lightning==1.9.0 --quiet
!pip install pandas --quiet
print("Dependencies installed!")

In [None]:
import os
# Sanity check: confirm the Kaggle LibriSpeech input is attached and list
# whatever entries sit directly under it.
print("Checking LibriSpeech dataset...")
librispeech_path = "/kaggle/input/librispeech-asr-corpus/"
if not os.path.exists(librispeech_path):
    print("LibriSpeech not found. Add it via 'Add Data' button")
else:
    print("LibriSpeech dataset found")
    print("\nAvailable splits:")
    for split_name in os.listdir(librispeech_path):
        print(f"  - {split_name}")

In [None]:
%%writefile utils.py

class TextProcess:
    """Character-level codec between transcripts and integer label sequences.

    Vocabulary (29 symbols): apostrophe -> 0, space -> 1, ``a``..``z`` -> 2..27
    and ``_`` -> 28. Index 28 is the value this notebook passes to
    ``nn.CTCLoss(blank=28)``, so ``_`` stands for the CTC blank and must never
    be produced as a *target* label.

    Attributes:
        char_map: dict mapping each character to its integer index.
        int_map: inverse of ``char_map`` (index -> character).
    """

    # Character reserved for the CTC blank symbol (index 28).
    BLANK_CHAR = '_'

    def __init__(self):
        alphabet = "abcdefghijklmnopqrstuvwxyz"

        self.char_map = {"'": 0, ' ': 1}
        for index, char in enumerate(alphabet, start=2):
            self.char_map[char] = index
        self.char_map[self.BLANK_CHAR] = 28

        self.int_map = {index: char for char, index in self.char_map.items()}

    def text_to_int_sequence(self, text):
        """Encode ``text`` as a list of label indices.

        The text is lower-cased first. Characters outside the vocabulary are
        dropped, and so is the blank character ``_``: emitting the blank index
        inside a CTC target sequence would be invalid.

        Args:
            text: transcript string.

        Returns:
            List of ints in the range 0..27.
        """
        return [
            self.char_map[c]
            for c in text.lower()
            if c in self.char_map and c != self.BLANK_CHAR
        ]

    def int_to_text_sequence(self, labels):
        """Decode an iterable of label indices back into a string.

        Indices that are not in the vocabulary are skipped silently.

        Args:
            labels: iterable of integer indices.

        Returns:
            The decoded string.
        """
        return ''.join(self.int_map.get(i, '') for i in labels)

In [None]:
%%writefile model.py

import torch
import torch.nn as nn
from torch.nn import functional as F

class ActDropNormCNN1D(nn.Module):
    """LayerNorm -> GELU -> Dropout block for the output of a 1-D convolution.

    The input's axes 1 and 2 are swapped before normalising, so axis 1 of the
    input must have size ``n_feats`` (it becomes the last axis that
    ``LayerNorm(n_feats)`` normalises over).

    Args:
        n_feats: size of the feature axis handed to ``LayerNorm``.
        dropout: dropout probability.
        keep_shape: when True, swap axes 1 and 2 back before returning so the
            output layout matches the input layout.
    """

    def __init__(self, n_feats, dropout, keep_shape=False):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(n_feats)
        self.keep_shape = keep_shape

    def forward(self, x):
        """Normalise, activate and apply dropout; see the class docstring for layout."""
        normed = self.norm(x.transpose(1, 2))
        activated = self.dropout(F.gelu(normed))
        return activated.transpose(1, 2) if self.keep_shape else activated

class SpeechRecognition(nn.Module):
    """Acoustic model: strided Conv1d -> two dense layers -> unidirectional LSTM -> classifier.

    ``hyper_parameters`` holds the default constructor arguments; callers
    unpack it as ``SpeechRecognition(**hyper_parameters)``.
    """

    hyper_parameters = {
        "num_classes": 29,
        "n_feats": 81,
        "dropout": 0.1,
        "hidden_size": 1024,
        "num_layers": 1
    }

    def __init__(self, hidden_size, num_classes, n_feats, num_layers, dropout):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        # Kernel 10, stride 2, padding 5: the convolution halves the time axis.
        conv = nn.Conv1d(n_feats, n_feats, 10, 2, padding=10 // 2)
        self.cnn = nn.Sequential(conv, ActDropNormCNN1D(n_feats, dropout))

        # Two identical Linear -> LayerNorm -> GELU -> Dropout stages, both 128 wide.
        dense_layers = []
        for in_features in (n_feats, 128):
            dense_layers += [
                nn.Linear(in_features, 128),
                nn.LayerNorm(128),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
        self.dense = nn.Sequential(*dense_layers)

        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=0.0,
            bidirectional=False,
        )
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout)
        self.final_fc = nn.Linear(hidden_size, num_classes)

    def _init_hidden(self, batch_size):
        """Return zeroed ``(h0, c0)``, each of shape (num_layers, batch_size, hidden_size).

        The tensors are created on the default device; callers move them.
        """
        state_shape = (self.num_layers * 1, batch_size, self.hidden_size)
        return tuple(torch.zeros(state_shape) for _ in range(2))

    def forward(self, x, hidden):
        """Run the network.

        Args:
            x: input with a singleton axis 1 that is squeezed away; after the
                squeeze, axis 1 must have size ``n_feats`` for the convolution.
            hidden: ``(h0, c0)`` tuple for the LSTM, e.g. from ``_init_hidden``.

        Returns:
            ``(logits, (hn, cn))``. ``logits`` is time-major because the batch
            and time axes are swapped before the LSTM.
        """
        features = self.cnn(x.squeeze(1))
        features = self.dense(features)
        time_major = features.transpose(0, 1)
        lstm_out, (hn, cn) = self.lstm(time_major, hidden)
        activated = self.dropout2(F.gelu(self.layer_norm2(lstm_out)))
        return self.final_fc(activated), (hn, cn)

In [None]:
%%writefile dataset.py

import torch
import torchaudio
import torch.nn as nn
import pandas as pd
import numpy as np
from utils import TextProcess

class SpecAugment(nn.Module):
    """Randomly applied frequency/time masking for spectrograms.

    Args:
        rate: threshold compared against a uniform draw in [0, 1); masking is
            applied when ``rate`` is greater than the draw.
        policy: 1 = one frequency mask + one time mask, 2 = two of each,
            3 = pick policy 1 or policy 2 with a fair coin on every call.
        freq_mask: ``freq_mask_param`` for ``FrequencyMasking``.
        time_mask: ``time_mask_param`` for ``TimeMasking``.

    Raises:
        KeyError: if ``policy`` is not 1, 2 or 3.
    """

    def __init__(self, rate, policy=3, freq_mask=15, time_mask=35):
        super().__init__()
        self.rate = rate

        def mask_pair():
            # One frequency mask followed by one time mask.
            return [
                torchaudio.transforms.FrequencyMasking(freq_mask_param=freq_mask),
                torchaudio.transforms.TimeMasking(time_mask_param=time_mask),
            ]

        self.specaug = nn.Sequential(*mask_pair())
        self.specaug2 = nn.Sequential(*mask_pair(), *mask_pair())

        self._forward = {1: self.policy1, 2: self.policy2, 3: self.policy3}[policy]

    def forward(self, x):
        """Dispatch to the policy chosen at construction time."""
        return self._forward(x)

    def policy1(self, x):
        """Apply the single mask pair when ``rate`` exceeds a fresh uniform draw."""
        if torch.rand(1, 1).item() < self.rate:
            return self.specaug(x)
        return x

    def policy2(self, x):
        """Apply the double mask pair when ``rate`` exceeds a fresh uniform draw."""
        if torch.rand(1, 1).item() < self.rate:
            return self.specaug2(x)
        return x

    def policy3(self, x):
        """Flip a coin: draw > 0.5 selects ``policy1``, otherwise ``policy2``."""
        chosen = self.policy1 if torch.rand(1, 1).item() > 0.5 else self.policy2
        return chosen(x)

class LogMelSpec(nn.Module):
    """Mel spectrogram followed by a natural logarithm.

    Args:
        sample_rate: sample rate passed to ``MelSpectrogram``.
        n_mels: number of mel bins.
        win_length: analysis window length in samples.
        hop_length: hop between successive windows in samples.
    """

    def __init__(self, sample_rate=8000, n_mels=128, win_length=160, hop_length=80):
        super(LogMelSpec, self).__init__()
        self.transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_mels=n_mels,
            win_length=win_length, hop_length=hop_length)

    def forward(self, x):
        """Return ``log(mel(x) + 1e-14)`` as a torch tensor.

        ``torch.log`` is used instead of ``np.log`` so the result stays a torch
        tensor on the input's device and also works for tensors that require
        grad. The 1e-14 offset keeps empty bins away from ``log(0)``.
        """
        x = self.transform(x)
        return torch.log(x + 1e-14)

def get_featurizer(sample_rate, n_feats=81):
    """Build the log-mel front end with a 160-sample window and an 80-sample hop.

    Args:
        sample_rate: sample rate of the audio that will be featurised.
        n_feats: number of mel bins.

    Returns:
        A ``LogMelSpec`` module.
    """
    featurizer = LogMelSpec(
        sample_rate=sample_rate,
        n_mels=n_feats,
        win_length=160,
        hop_length=80,
    )
    return featurizer

class Data(torch.utils.data.Dataset):
    """Dataset of (log-mel spectrogram, label sequence) pairs read from a JSON-lines manifest.

    Each manifest row must provide ``key`` (path to an audio file) and ``text``
    (its transcript). ``parameters`` holds the default keyword arguments;
    callers should copy it before applying overrides.
    """

    parameters = {
        "sample_rate": 8000, "n_feats": 81,
        "specaug_rate": 0.5, "specaug_policy": 3,
        "time_mask": 70, "freq_mask": 15
    }

    def __init__(self, json_path, sample_rate, n_feats, specaug_rate, specaug_policy,
                 time_mask, freq_mask, valid=False, shuffle=True, text_to_int=True, log_ex=True):
        """
        Args:
            json_path: path of the JSON-lines manifest.
            sample_rate: rate that audio is resampled to and that the mel
                transform is configured for.
            n_feats: number of mel bins.
            specaug_rate, specaug_policy, time_mask, freq_mask: forwarded to
                ``SpecAugment`` (training mode only).
            valid: when True, no augmentation is applied.
            shuffle, text_to_int: accepted for interface compatibility; this
                class does not read them.
            log_ex: print a message whenever a sample is skipped.
        """
        self.log_ex = log_ex
        self.sample_rate = sample_rate
        # Resample transforms keyed by source rate, built on first use.
        self._resamplers = {}
        self.text_process = TextProcess()

        print("Loading data json file from", json_path)
        self.data = pd.read_json(json_path, lines=True)

        transforms = [
            LogMelSpec(sample_rate=sample_rate, n_mels=n_feats, win_length=160, hop_length=80)
        ]
        if not valid:
            transforms.append(SpecAugment(specaug_rate, specaug_policy, freq_mask, time_mask))
        self.audio_transforms = torch.nn.Sequential(*transforms)

    def __len__(self):
        return len(self.data)

    def _load_example(self, idx):
        """Load and validate one sample; raise ``Exception`` if it is unusable."""
        file_path = self.data['key'].iloc[idx]
        waveform, source_rate = torchaudio.load(file_path)

        # Resample to the configured rate (previously hard-coded to 8000 Hz).
        if source_rate != self.sample_rate:
            resampler = self._resamplers.get(source_rate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(source_rate, self.sample_rate)
                self._resamplers[source_rate] = resampler
            waveform = resampler(waveform)

        label = self.text_process.text_to_int_sequence(self.data['text'].iloc[idx])
        spectrogram = self.audio_transforms(waveform)
        # Halved frame count reported as the input length for the loss.
        spec_len = spectrogram.shape[-1] // 2
        label_len = len(label)

        if spec_len < label_len:
            raise Exception('spectrogram len is smaller than label len')
        if spectrogram.shape[0] > 1:
            raise Exception('dual channel, skipping audio file %s' % file_path)
        if spectrogram.shape[2] > 1650:
            raise Exception('spectrogram to big. size %s' % spectrogram.shape[2])
        if label_len == 0:
            raise Exception('label len is zero... skipping %s' % file_path)
        return spectrogram, label, spec_len, label_len

    def _path_for_log(self, idx):
        """Best-effort lookup of the audio path for a log message (None if unavailable)."""
        try:
            return self.data['key'].iloc[idx]
        except (IndexError, KeyError):
            return None

    def __getitem__(self, idx):
        """Return ``(spectrogram, label, spec_len, label_len)`` for ``idx``.

        If the sample cannot be used, fall back to a neighbour: ``idx - 1``, or
        ``idx + 1`` when ``idx`` is 0. Indices already tried are remembered so
        the search stops instead of bouncing between bad samples forever.

        Raises:
            RuntimeError: if no usable sample is found along the fallback path.
        """
        if torch.is_tensor(idx):
            idx = idx.item()

        tried = set()
        while idx not in tried:
            tried.add(idx)
            try:
                return self._load_example(idx)
            except Exception as e:
                if self.log_ex:
                    print(str(e), self._path_for_log(idx))
                idx = idx - 1 if idx != 0 else idx + 1
        raise RuntimeError(
            'no usable sample found; tried indices %s' % sorted(tried))

    def describe(self):
        """Return ``DataFrame.describe()`` of the loaded manifest."""
        return self.data.describe()

def collate_fn_padd(data):
    """Collate variable-length samples into zero-padded batch tensors.

    Args:
        data: iterable of ``(spectrogram, label, input_length, label_length)``
            tuples; entries whose spectrogram is None are dropped.

    Returns:
        ``(spectrograms, labels, input_lengths, label_lengths)`` where
        ``spectrograms`` has a singleton axis 1 and is padded along its last
        axis, ``labels`` is a float tensor padded along axis 1, and the two
        length collections are plain Python lists.
    """
    usable = [
        (spec, lab, in_len, lab_len)
        for (spec, lab, in_len, lab_len) in data
        if spec is not None
    ]

    # pad_sequence pads along the first axis, so put time first for now.
    time_major = [spec.squeeze(0).transpose(0, 1) for spec, _, _, _ in usable]
    targets = [torch.Tensor(lab) for _, lab, _, _ in usable]
    input_lengths = [in_len for _, _, in_len, _ in usable]
    label_lengths = [lab_len for _, _, _, lab_len in usable]

    padded = nn.utils.rnn.pad_sequence(time_major, batch_first=True)
    # Insert the singleton axis and swap the last two axes back.
    spectrograms = padded.unsqueeze(1).permute(0, 1, 3, 2)
    labels = nn.utils.rnn.pad_sequence(targets, batch_first=True)

    return spectrograms, labels, input_lengths, label_lengths

In [None]:
%%writefile train.py

import os
import ast
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from argparse import ArgumentParser
from model import SpeechRecognition
from dataset import Data, collate_fn_padd
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

class SpeechModule(LightningModule):
    """Lightning wrapper that trains a ``SpeechRecognition`` network with CTC loss.

    Args:
        model: the network to train; stored as ``self.model``.
        args: parsed command-line namespace. This class reads
            ``learning_rate``, ``batch_size``, ``data_workers``, ``train_file``,
            ``valid_file`` and ``dparams_override`` from it.
    """

    def __init__(self, model, args):
        super(SpeechModule, self).__init__()
        self.model = model
        # Index 28 is the blank symbol ('_' in TextProcess). zero_infinity
        # zeroes infinite losses so one bad sample cannot poison the gradients.
        self.criterion = nn.CTCLoss(blank=28, zero_infinity=True)
        self.args = args

    def forward(self, x, hidden):
        return self.model(x, hidden)

    def configure_optimizers(self):
        """AdamW plus ReduceLROnPlateau driven by the logged ``val_loss``.

        Lightning requires a ``monitor`` entry for ``ReduceLROnPlateau``;
        returning the bare scheduler raises a MisconfigurationException.
        """
        self.optimizer = optim.AdamW(self.model.parameters(), self.args.learning_rate)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min',
            factor=0.50, patience=6)
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {
                "scheduler": self.scheduler,
                "monitor": "val_loss",
            },
        }

    def step(self, batch):
        """Compute the CTC loss for one collated batch."""
        spectrograms, labels, input_lengths, label_lengths = batch
        bs = spectrograms.shape[0]
        # _init_hidden creates its tensors on the default device; move them.
        h0, c0 = (state.to(self.device) for state in self.model._init_hidden(bs))
        output, _ = self(spectrograms, (h0, c0))
        output = F.log_softmax(output, dim=2)
        return self.criterion(output, labels, input_lengths, label_lengths)

    def training_step(self, batch, batch_idx):
        loss = self.step(batch)
        self.log('train_loss', loss)
        self.log('lr', self.optimizer.param_groups[0]['lr'])
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.step(batch)
        self.log('val_loss', loss)
        return loss

    def _dataset_params(self):
        """Return the ``Data`` defaults merged with CLI overrides.

        Works on a copy so the class-level ``Data.parameters`` is not mutated.
        """
        params = dict(Data.parameters)
        params.update(self.args.dparams_override)
        return params

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the shared batch, worker and collate settings."""
        return DataLoader(dataset=dataset,
                          batch_size=self.args.batch_size,
                          num_workers=self.args.data_workers,
                          shuffle=shuffle,
                          pin_memory=True,
                          collate_fn=collate_fn_padd)

    def train_dataloader(self):
        train_dataset = Data(json_path=self.args.train_file, **self._dataset_params())
        # The manifest is written in directory order, so shuffle for training.
        return self._make_loader(train_dataset, shuffle=True)

    def val_dataloader(self):
        test_dataset = Data(json_path=self.args.valid_file, **self._dataset_params(), valid=True)
        return self._make_loader(test_dataset, shuffle=False)

def checkpoint_callback(args):
    """Build the ``ModelCheckpoint`` callback.

    Keeps the three checkpoints with the lowest ``val_loss`` in
    ``args.save_model_path``; the epoch and ``val_loss`` are embedded in each
    file name.
    """
    options = dict(
        dirpath=args.save_model_path,
        save_top_k=3,
        verbose=True,
        monitor='val_loss',
        mode='min',
        filename='speech-{epoch:02d}-{val_loss:.2f}',
    )
    return ModelCheckpoint(**options)

def main(args):
    """Build the model, the Lightning module and the trainer, then fit.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
    """
    # Copy the defaults so overrides do not mutate the class-level dict.
    h_params = dict(SpeechRecognition.hyper_parameters)
    h_params.update(args.hparams_override)
    model = SpeechRecognition(**h_params)

    if args.load_model_from:
        speech_module = SpeechModule.load_from_checkpoint(args.load_model_from, model=model, args=args)
    else:
        speech_module = SpeechModule(model, args)

    logger = TensorBoardLogger(args.logdir, name='speech_recognition')

    # Fall back to the CPU when no GPU is requested or CUDA is unavailable,
    # instead of failing while the Trainer is being constructed.
    use_gpu = args.gpus > 0 and torch.cuda.is_available()
    if args.gpus > 0 and not use_gpu:
        print("CUDA is not available; training on CPU instead.")

    trainer = Trainer(
        max_epochs=args.epochs,
        accelerator='gpu' if use_gpu else 'cpu',
        devices=args.gpus if use_gpu else 1,
        logger=logger,
        gradient_clip_val=1.0,
        val_check_interval=args.valid_every,
        callbacks=[checkpoint_callback(args)]
    )
    trainer.fit(speech_module)

if __name__ == "__main__":
    # Command-line interface: (flags, options) pairs registered in --help order.
    cli_options = (
        (('-g', '--gpus'), dict(default=1, type=int)),
        (('-w', '--data_workers'), dict(default=2, type=int)),
        (('--train_file',), dict(required=True, type=str)),
        (('--valid_file',), dict(required=True, type=str)),
        (('--valid_every',), dict(default=1000, type=int)),
        (('--save_model_path',), dict(required=True, type=str)),
        (('--load_model_from',), dict(default=None, type=str)),
        (('--logdir',), dict(default='tb_logs', type=str)),
        (('--epochs',), dict(default=10, type=int)),
        (('--batch_size',), dict(default=32, type=int)),
        (('--learning_rate',), dict(default=1e-3, type=float)),
        (("--hparams_override",), dict(default="{}", type=str)),
        (("--dparams_override",), dict(default="{}", type=str)),
    )
    parser = ArgumentParser()
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()
    # The two override flags arrive as Python-literal dict strings.
    for override_name in ('hparams_override', 'dparams_override'):
        setattr(args, override_name, ast.literal_eval(getattr(args, override_name)))

    os.makedirs(args.save_model_path, exist_ok=True)
    main(args)

In [None]:

import os
import json
from pathlib import Path
from tqdm import tqdm

def parse_librispeech_transcript(trans_file):
    """Parse a LibriSpeech ``*.trans.txt`` file.

    Each useful line has the form ``<utterance-id> <TRANSCRIPT>``; lines with
    no space (including blank lines) are ignored.

    Args:
        trans_file: path to the transcript file (read as UTF-8).

    Returns:
        Dict mapping utterance id to the lower-cased transcript.
    """
    transcripts = {}
    with open(trans_file, 'r', encoding='utf-8') as f:
        for line in f:
            file_id, separator, text = line.strip().partition(' ')
            if separator:
                transcripts[file_id] = text.lower()
    return transcripts

def find_audio_files_recursive(base_path, max_samples=None):
    """Recursively collect FLAC files that have a transcript in the same directory.

    Directories and files are visited in sorted order so that the result, and
    therefore which samples survive ``max_samples`` and the later train/valid
    split, is the same on every run.

    Args:
        base_path: root directory to search.
        max_samples: stop once this many samples were collected; a falsy value
            (None or 0) means no limit.

    Returns:
        List of ``{"key": <audio path>, "text": <transcript>}`` dicts.
    """
    data = []
    flac_suffix = '.flac'

    print(f" Searching for audio files in: {base_path}\n")

    for root, dirs, files in os.walk(base_path):
        dirs.sort()  # in-place sort makes os.walk descend in a fixed order
        files = sorted(files)
        audio_files = [name for name in files if name.endswith(flac_suffix)]

        for file in files:
            if not file.endswith('.trans.txt'):
                continue
            trans_file = os.path.join(root, file)

            try:
                transcripts = parse_librispeech_transcript(trans_file)
            except (OSError, UnicodeError) as e:
                # Best effort: report the unreadable transcript and carry on.
                print(f"  Error reading {trans_file}: {e}")
                continue

            # Pair every audio file in this directory with its transcript.
            for audio_file in audio_files:
                file_id = audio_file[:-len(flac_suffix)]
                if file_id not in transcripts:
                    continue
                data.append({
                    "key": os.path.join(root, audio_file),
                    "text": transcripts[file_id]
                })

                # Print progress every 100 files
                if len(data) % 100 == 0:
                    print(f"   Found {len(data)} audio files...", end='\r')

                if max_samples and len(data) >= max_samples:
                    print(f"\nâœ… Reached max samples limit ({max_samples})")
                    return data

    return data

# ===========================================================================
# MAIN EXECUTION
# ===========================================================================

print("="*70)
print(" PREPARING LIBRISPEECH DATASET")
print("="*70)

base_path = "/kaggle/input/librispeech-asr-corpus"


def _first_files_with_suffix(search_root, suffix, limit):
    """Return up to ``limit`` paths under ``search_root`` whose names end with ``suffix``."""
    matches = []
    for root, dirs, files in os.walk(search_root):
        for file in files:
            if file.endswith(suffix):
                matches.append(os.path.join(root, file))
                if len(matches) >= limit:
                    return matches
    return matches


# Check if path exists
if not os.path.exists(base_path):
    print(f" Path not found: {base_path}")
    print("\n Available paths:")
    if os.path.exists("/kaggle/input/"):
        for item in os.listdir("/kaggle/input/"):
            print(f" {item}")
else:
    print(f" Found dataset at: {base_path}\n")

    # Show directory structure (first 15 entries, three children each)
    print(" Directory structure:")
    for item in sorted(os.listdir(base_path))[:15]:
        item_path = os.path.join(base_path, item)
        if os.path.isdir(item_path):
            print(f"  {item}/")
            # List the directory once; report, rather than hide, a failure.
            try:
                sub_items = os.listdir(item_path)
            except OSError as e:
                print(f"      (could not list contents: {e})")
                continue
            for sub in sub_items[:3]:
                print(f"      â””â”€ {sub}")
            if len(sub_items) > 3:
                print(f"      â””â”€ ... and {len(sub_items) - 3} more")
        else:
            print(f"   ðŸ“„ {item}")

    print("\n" + "="*70)
    print("CREATING TRAINING DATASET")
    print("="*70)

    # Create training data
    # This will search the entire dataset recursively
    train_data = find_audio_files_recursive(
        base_path,
        max_samples=5000  # Limit to 5000 for quick testing
    )

    if len(train_data) > 0:
        # Split into train and validation (90% train, 10% valid)
        split_idx = int(len(train_data) * 0.9)

        train_samples = train_data[:split_idx]
        valid_samples = train_data[split_idx:]

        # Ensure we have at least some validation samples
        if len(valid_samples) < 100 and len(train_samples) > 100:
            valid_samples = train_samples[-100:]
            train_samples = train_samples[:-100]

        print("\n Saving dataset files...")

        # Save both manifests as JSON lines (one record per line)
        for manifest_name, samples in (("train_data.json", train_samples),
                                       ("valid_data.json", valid_samples)):
            with open(manifest_name, "w") as f:
                for item in samples:
                    f.write(json.dumps(item) + "\n")

        print("\n" + "="*70)
        print("DATASET SUMMARY")
        print("="*70)
        print(f" Training samples: {len(train_samples)}")
        print(f" Validation samples: {len(valid_samples)}")
        print(f" Total samples: {len(train_data)}")

        # Show sample
        print("\n Sample training data:")
        sample = train_samples[0]
        print(f"   Audio: {sample['key']}")
        print(f"   Text: {sample['text'][:100]}..." if len(sample['text']) > 100 else f"   Text: {sample['text']}")

        # Verify audio file
        print("\n Verifying sample audio...")
        try:
            import torchaudio
            waveform, sample_rate = torchaudio.load(sample['key'])
            print("   Audio loaded successfully!")
            print(f"   Sample rate: {sample_rate} Hz")
            print(f"   Duration: {waveform.shape[1] / sample_rate:.2f} seconds")
            print(f"   Channels: {waveform.shape[0]}")
        except Exception as e:
            print(f"   Error loading audio: {e}")

    else:
        print("\n NO AUDIO FILES FOUND!")
        print("\n Debugging information:")
        print("\nSearching for .flac files...")
        flac_files = _first_files_with_suffix(base_path, '.flac', 5)

        if flac_files:
            print(f" Found {len(flac_files)} .flac files (showing first 5):")
            for f in flac_files[:5]:
                print(f"   {f}")
        else:
            print(" No .flac files found at all!")

        print("\nSearching for .trans.txt files...")
        trans_files = _first_files_with_suffix(base_path, '.trans.txt', 3)

        if trans_files:
            print(" Found transcript files (showing first 3):")
            for f in trans_files[:3]:
                print(f"   {f}")
        else:
            print(" No .trans.txt files found!")

print("\n" + "="*70)
print("DATA PREPARATION COMPLETE")
print("="*70)

# ===========================================================================
# TIPS FOR FULL DATASET
# ===========================================================================
"""
To use the FULL LibriSpeech dataset (not just 5000 samples):

Change the `max_samples=5000` argument in the `find_audio_files_recursive(...)` call above to:
    max_samples=None  # Use ALL samples

This will take longer but give better accuracy:
- Full train-clean-100: ~28,000 samples (2-3 hours to prepare)
- Full train-clean-360: ~104,000 samples (8-10 hours to prepare)
"""

In [None]:
import json

# Read the first record back from the training manifest and show it.
print("Checking sample data...")
with open("train_data.json", 'r') as f:
    first_record = f.readline()
    sample = json.loads(first_record)
    print("\n Sample training data:")
    print(f"Audio: {sample['key']}")
    print(f"Text: {sample['text']}")

# Verify audio file exists by actually decoding it
import torchaudio
try:
    waveform, sample_rate = torchaudio.load(sample['key'])
    print("\n Audio loaded successfully!")
    print(f"Sample rate: {sample_rate} Hz")
    duration_seconds = waveform.shape[1] / sample_rate
    print(f"Duration: {duration_seconds:.2f} seconds")
    channel_count = waveform.shape[0]
    print(f"Channels: {channel_count}")
except Exception as e:
    print(f" Error loading audio: {e}")

In [None]:
import os

# Create directories
os.makedirs("models", exist_ok=True)
os.makedirs("tb_logs", exist_ok=True)

print(" Starting training...")
print("\nConfiguration:")
print(f"  â€¢ GPU: Enabled")
print(f"  â€¢ Batch size: 16")
print(f"  â€¢ Epochs: 10")
print(f"  â€¢ Learning rate: 0.001")
print("\n" + "="*50 + "\n")

# Start training
!python train.py \
    --train_file train_data.json \
    --valid_file valid_data.json \
    --save_model_path models/ \
    --batch_size 16 \
    --epochs 10 \
    --learning_rate 0.001 \
    --gpus 1 \
    --data_workers 2 \
    --valid_every 500

In [None]:
# Load the TensorBoard notebook extension and open it on tb_logs, which is
# train.py's default --logdir.
%load_ext tensorboard
%tensorboard --logdir tb_logs

In [None]:
import torch
from model import SpeechRecognition
import os

import re

CHECKPOINT_DIR = "models/"


def _val_loss_from_name(filename):
    """Extract the validation loss embedded in a checkpoint file name.

    train.py names checkpoints 'speech-{epoch:02d}-{val_loss:.2f}', which
    Lightning renders with a 'val_loss=<number>' fragment. Returns +inf when
    no such fragment is present, so those files sort last.
    """
    match = re.search(r"val_loss=(\d+(?:\.\d+)?)", filename)
    return float(match.group(1)) if match else float("inf")


print("Finding best model checkpoint...")
# Reset first so a stale path from an earlier run is never reused by later cells.
checkpoint_path = None
checkpoints = []
if os.path.isdir(CHECKPOINT_DIR):
    checkpoints = [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith('.ckpt')]

if checkpoints:
    # Lowest validation loss wins; ties go to the most recently created file.
    best_checkpoint = min(
        checkpoints,
        key=lambda name: (_val_loss_from_name(name),
                          -os.path.getctime(os.path.join(CHECKPOINT_DIR, name))),
    )
    print(f" Found checkpoint: {best_checkpoint}")

    checkpoint_path = os.path.join(CHECKPOINT_DIR, best_checkpoint)
    print(f"\n Model saved at: {checkpoint_path}")
    print(f"File size: {os.path.getsize(checkpoint_path) / (1024*1024):.2f} MB")
else:
    print(" No checkpoints found")

In [None]:
import torch
import torchaudio
from model import SpeechRecognition
from dataset import get_featurizer
from utils import TextProcess
import json

# Load model
print("Loading model for inference...")
h_params = SpeechRecognition.hyper_parameters
model = SpeechRecognition(**h_params)

# Load checkpoint. The checkpoint stores SpeechModule's state dict, where the
# network lives under the attribute `model`, so every key carries a "model."
# prefix. Strip it and load strictly: with the prefixed keys and strict=False
# nothing matched and the network silently kept its random initial weights.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
key_prefix = "model."
state_dict = {
    key[len(key_prefix):]: value
    for key, value in checkpoint['state_dict'].items()
    if key.startswith(key_prefix)
}
model.load_state_dict(state_dict)
model.eval()

# Load sample
with open("valid_data.json", 'r') as f:
    sample = json.loads(f.readline())

print(f"\n Testing on: {sample['key']}")
print(f" True text: {sample['text']}")

# Predict
SAMPLE_RATE = 8000
BLANK_INDEX = 28  # CTC blank, '_' in TextProcess

waveform, sr = torchaudio.load(sample['key'])
if sr != SAMPLE_RATE:
    resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
    waveform = resampler(waveform)

featurizer = get_featurizer(SAMPLE_RATE)
# The featurizer returns (channel, n_mels, time). The model squeezes axis 1 and
# then needs (batch, n_mels, time), so add exactly one singleton axis at
# position 1; the mono channel axis acts as the batch axis.
log_mel = featurizer(waveform).unsqueeze(1)

with torch.no_grad():
    hidden = model._init_hidden(1)
    output, _ = model(log_mel, hidden)
    output = torch.nn.functional.softmax(output, dim=2)

    # Greedy decode: best class per frame, collapse repeats, drop blanks.
    # The output is time-major (time, batch=1, classes).
    arg_maxes = torch.argmax(output, dim=2).squeeze(1).tolist()
    decode = []
    for i, index in enumerate(arg_maxes):
        if index != BLANK_INDEX:
            if i == 0 or index != arg_maxes[i-1]:
                decode.append(index)

    text_process = TextProcess()
    predicted_text = text_process.int_to_text_sequence(decode)

    print(f" Predicted: {predicted_text}")

In [None]:
import shutil

# Place a copy of the selected checkpoint in Kaggle's working directory so it
# shows up as a downloadable notebook output.
output_path = "/kaggle/working/speech_recognition_model.ckpt"
shutil.copy(checkpoint_path, output_path)

print(f" Model saved to: {output_path}")
print("\n".join([
    " Download from 'Output' tab in Kaggle",
    "\nTo use this model:",
    "1. Download the .ckpt file",
    "2. Load with PyTorch",
    "3. Use for inference",
]))

In [None]:
def _count_records(manifest_path):
    """Count the non-empty lines (one JSON record each) in a manifest file."""
    with open(manifest_path, 'r') as f:
        return sum(1 for line in f if line.strip())


# These counts were never defined anywhere in the notebook; derive them from
# the manifests written during data preparation.
train_count = _count_records("train_data.json")
valid_count = _count_records("valid_data.json")

print("="*60)
print(" TRAINING COMPLETE!")
print("="*60)
print(f"\n Model checkpoint: {checkpoint_path}")
print(f" Download ready: /kaggle/working/speech_recognition_model.ckpt")
print(f"\n Training Stats:")
print(f"  â€¢ Training samples: {train_count}")
print(f"  â€¢ Validation samples: {valid_count}")
print(f"  â€¢ Epochs: 10")
print(f"  â€¢ Batch size: 16")
print(f"\n Next Steps:")
print("  1. Download model from Output tab")
print("  2. Test on your own audio")
print("  3. Fine-tune with more epochs if needed")
print("  4. Try larger dataset for better accuracy")
print("\n" + "="*60)