### Loading the dataset and instantiating the processor and data collator

In [None]:
# Make Google Drive available under /content/drive; the processed dataset is read from there below.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

In [None]:
!pip install -q evaluate

In [None]:
!pip install -q jiwer

In [None]:
import torch
from datasets import load_from_disk


In [None]:
# Location of the pre-processed DALI dataset saved with `Dataset.save_to_disk`.
DATASET_DIR = "/content/drive/MyDrive/audio-to-text/DALI/dataset_processed"
final_dataset = load_from_disk(DATASET_DIR)

In [None]:
# Inspect which columns a single example carries (the collator below reads
# "input_features" and "labels").
first_example = final_dataset[0]
first_example.keys()

In [None]:
from typing import Any, Dict, List, Union
import torch

class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech-to-text examples into a padded, model-ready batch.

    Each feature dict must provide ``input_features`` (passed to
    ``processor.feature_extractor.pad``) and ``labels`` (token ids passed to
    ``processor.tokenizer.pad``).

    Args:
        processor: Object exposing ``feature_extractor`` and ``tokenizer``
            with Hugging Face style ``pad`` methods (e.g. a
            ``Speech2TextProcessor``).
        label_pad_token_id: Value written into label padding positions so the
            loss ignores them (-100 matches PyTorch cross-entropy's default
            ``ignore_index``).
    """

    def __init__(self, processor: Any, label_pad_token_id: int = -100):
        self.processor = processor
        self.label_pad_token_id = label_pad_token_id

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Pad audio features and labels for a list of examples.

        Returns:
            The feature-extractor batch (``input_features`` and
            ``attention_mask``) with an added ``labels`` tensor whose padding
            positions hold ``label_pad_token_id``.
        """
        # Request the mask explicitly: the training and validation loops
        # index batch["attention_mask"] unconditionally.
        batch = self.processor.feature_extractor.pad(
            [{"input_features": feature["input_features"]} for feature in features],
            return_attention_mask=True,
            return_tensors="pt",
        )

        labels_batch = self.processor.tokenizer.pad(
            [{"input_ids": feature["labels"]} for feature in features],
            return_tensors="pt",
        )

        # Replace padding positions with the ignore index so the loss skips them.
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch["attention_mask"].ne(1), self.label_pad_token_id
        )

        # Drop a leading BOS column if every sequence starts with it, to avoid
        # duplicating the decoder start token. Guard against tokenizers
        # without a BOS token (comparing a tensor to None yields a plain bool)
        # and against a batch whose label sequences are all empty.
        bos_token_id = getattr(self.processor.tokenizer, "bos_token_id", None)
        if (
            bos_token_id is not None
            and labels.size(1) > 0
            and bool((labels[:, 0] == bos_token_id).all())
        ):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


In [None]:
#!pip install -U transformers


In [None]:
from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration

# Feature extractor + tokenizer matching the pretrained LibriSpeech checkpoint.
processor = Speech2TextProcessor.from_pretrained(
    pretrained_model_name_or_path="facebook/s2t-small-librispeech-asr"
)


In [None]:
# Pads each batch; label padding uses -100 so the loss ignores it.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    label_pad_token_id=-100,
)


In [None]:
from sklearn.model_selection import KFold

# k-fold cross-validation; shuffling with a fixed seed keeps the splits reproducible.
k = 5
kf = KFold(n_splits=k, random_state=42, shuffle=True)

## Training

In [None]:
import evaluate

# Word error rate metric used to score validation transcripts.
wer_metric = evaluate.load(path="wer")


In [None]:
import torch
import torch.nn.utils as nn_utils
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm import tqdm
import os
import shutil
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _freeze_conv_subsampler(model):
    """Freeze the encoder's convolutional front-end; leave all else trainable.

    In transformers' Speech2TextForConditionalGeneration the encoder is nested
    under ``model.`` (names look like ``model.encoder.conv.conv_layers.0.weight``
    -- verify with ``model.named_parameters()``), so a bare
    ``name.startswith("encoder.conv")`` test would match nothing. Matching the
    dotted path component works with or without that prefix.
    """
    for name, param in model.named_parameters():
        param.requires_grad = ".encoder.conv." not in f".{name}"


def _train_one_epoch(model, loader, optimizer, lr_scheduler, device):
    """Run one optimisation pass over *loader* and return the mean loss.

    Batches whose loss is non-finite are skipped (no optimizer/scheduler step)
    and excluded from the mean. Returns ``nan`` if every batch was skipped.
    """
    model.train()
    total_loss = 0.0
    applied_steps = 0

    for step, batch in enumerate(tqdm(loader)):
        batch = {key: tensor.to(device) for key, tensor in batch.items()}
        try:
            outputs = model(
                input_features=batch["input_features"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
        except Exception as e:
            print(f"Exception during forward pass at step {step}: {e}")
            raise

        loss = outputs.loss
        if not torch.isfinite(loss):
            print(f"Warning: Non-finite loss at step {step}, skipping batch")
            optimizer.zero_grad()
            continue

        loss.backward()
        nn_utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        applied_steps += 1

    return total_loss / applied_steps if applied_steps else float("nan")


def _evaluate_wer(model, loader, device, label_pad_token_id=-100):
    """Generate transcripts for *loader* and return their WER against the labels.

    Uses the notebook-level ``processor`` for decoding and ``wer_metric`` for
    scoring.
    """
    model.eval()
    predictions = []
    references = []
    pad_token_id = processor.tokenizer.pad_token_id

    with torch.no_grad():
        for batch in tqdm(loader):
            batch = {key: tensor.to(device) for key, tensor in batch.items()}
            generated_ids = model.generate(
                input_features=batch["input_features"],
                attention_mask=batch["attention_mask"],
                max_length=128,
            )
            # The collator writes label_pad_token_id (-100) into padding, which
            # is not a valid token id; map it to the pad token before decoding.
            labels = batch["labels"]
            labels = labels.masked_fill(labels == label_pad_token_id, pad_token_id).cpu()

            predictions.extend(processor.batch_decode(generated_ids, skip_special_tokens=True))
            references.extend(processor.batch_decode(labels, skip_special_tokens=True))

    return wer_metric.compute(predictions=predictions, references=references)


# WER is unbounded above (insertions can push it past 1.0), so start from +inf
# to guarantee at least one checkpoint is saved.
best_wer = float("inf")
best_model_dir = None

# NOTE(review): each fold is scored on a different validation subset, so the
# "best" checkpoint compares WERs across different splits.
for fold, (train_idx, val_idx) in enumerate(kf.split(final_dataset)):

    print(f"\n=== Fold {fold+1}/{kf.n_splits} ===")

    train_dataset = final_dataset.select(train_idx.tolist())
    val_dataset = final_dataset.select(val_idx.tolist())

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=data_collator)

    # Fresh pretrained weights per fold so folds do not leak into each other.
    model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
    _freeze_conv_subsampler(model)
    model.to(device)

    optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-5)
    num_epochs = 10
    num_training_steps = num_epochs * len(train_loader)
    # NOTE(review): with fewer than 500 total steps the LR never leaves warmup.
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=num_training_steps,
    )

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, lr_scheduler, device)
        print(f"Avg training loss: {avg_train_loss:.4f}")

        wer = _evaluate_wer(
            model, val_loader, device, label_pad_token_id=data_collator.label_pad_token_id
        )
        print(f"Validation WER: {wer:.4f}")

        if wer < best_wer:
            # Keep only one checkpoint on disk: remove the previous best.
            if best_model_dir and os.path.exists(best_model_dir):
                shutil.rmtree(best_model_dir)

            best_wer = wer
            timestamp = int(time.time())
            best_model_dir = f"./checkpoint-fold{fold}-epoch{epoch}-wer{wer:.4f}-{timestamp}"
            os.makedirs(best_model_dir, exist_ok=True)
            model.save_pretrained(best_model_dir)
            processor.save_pretrained(best_model_dir)
            print(f"Saved best model to {best_model_dir}")

print(f"\nTraining finished. Best WER: {best_wer:.4f}")


In [None]:
def transcribe_wav(wav_path, model_path, max_length=128):
    """Transcribe one audio file with a saved Speech2Text checkpoint.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.
        model_path: Directory (or hub id) containing both the processor and
            the model, e.g. a checkpoint directory written by the training loop.
        max_length: Maximum generated sequence length (default 128, the value
            used during validation).

    Returns:
        The decoded transcript as a string.
    """
    # torchaudio is not imported in any earlier visible cell, so import it here.
    import torchaudio

    target_sr = 16000
    processor = Speech2TextProcessor.from_pretrained(model_path)
    model = Speech2TextForConditionalGeneration.from_pretrained(model_path)
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    waveform, sr = torchaudio.load(wav_path)
    # torchaudio.load returns (channels, frames); average the channels so
    # stereo files become a single 1-D signal instead of a 2-D array.
    waveform = waveform.mean(dim=0)
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)

    inputs = processor(waveform.numpy(), sampling_rate=target_sr, return_tensors="pt")
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    with torch.no_grad():
        generated_ids = model.generate(
            input_features=inputs["input_features"],
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
        )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
