### Loading dataset and instantiating

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install -q evaluate

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
!pip install -q jiwer

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/3.1 MB[0m [31m40.0 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.1/3.1 MB[0m [31m57.8 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m36.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
import torch
from datasets import load_from_disk


In [5]:
final_dataset = load_from_disk("/content/drive/MyDrive/audio-to-text/DALI/dataset_processed")

In [6]:
final_dataset[0].keys()

dict_keys(['input_features', 'attention_mask', 'labels'])

In [12]:
from typing import Any, Dict, List, Union
import torch

class DataCollatorSpeechSeq2SeqWithPadding:
    def __init__(self, processor: Any, label_pad_token_id: int = -100):
        self.processor = processor
        self.label_pad_token_id = label_pad_token_id

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # Extract and pad audio features
        input_features = [
            {"input_features": feature["input_features"]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(
            input_features,
            return_tensors="pt"
        )

        # Extract and pad tokenized labels
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            return_tensors="pt"
        )

        # Replace padding token with -100 for loss ignoring
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), self.label_pad_token_id
        )

        # Remove BOS if it's automatically added (to avoid duplication)
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


In [None]:
#!pip install -U transformers


In [17]:
from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration

processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/456 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/417k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

In [18]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor)


In [25]:
from sklearn.model_selection import KFold

k = 5
kf = KFold(n_splits=k, shuffle=True, random_state=42)

## Training

In [32]:
import evaluate

wer_metric = evaluate.load("wer")


Downloading builder script: 0.00B [00:00, ?B/s]

In [34]:
import torch
import torch.nn.utils as nn_utils
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm import tqdm
import os
import shutil
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

best_wer = 1.0
best_model_dir = None

for fold, (train_idx, val_idx) in enumerate(kf.split(final_dataset)):

    print(f"\n=== Fold {fold+1}/{kf.n_splits} ===")

    train_dataset = final_dataset.select(train_idx.tolist())
    val_dataset = final_dataset.select(val_idx.tolist())

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=data_collator)

    model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")

    for name, param in model.named_parameters():
      if name.startswith("encoder.conv"):
        param.requires_grad = False
      else:
        param.requires_grad = True


    model.to(device)

    optimizer = AdamW(model.parameters(), lr=5e-5)
    num_epochs = 10
    num_training_steps = num_epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=num_training_steps,
    )

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        print(f"Epoch {epoch+1}/{num_epochs}")
        for step, batch in enumerate(tqdm(train_loader)):
            batch = {k: v.to(device) for k, v in batch.items()}
            try:
                outputs = model(
                    input_features=batch["input_features"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )
                loss = outputs.loss

                if not torch.isfinite(loss):
                    print(f"Warning: Non-finite loss at step {step}, skipping batch")
                    optimizer.zero_grad()
                    continue

                loss.backward()
                nn_utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                total_loss += loss.item()

            except Exception as e:
                print(f"Exception during forward pass at step {step}: {e}")
                raise

        avg_train_loss = total_loss / len(train_loader)
        print(f"Avg training loss: {avg_train_loss:.4f}")

        # Validation
        model.eval()
        all_preds = []
        all_refs = []

        for batch in tqdm(val_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                generated_ids = model.generate(
                    input_features=batch["input_features"],
                    attention_mask=batch["attention_mask"],
                    max_length=128,
                )

            preds = processor.batch_decode(generated_ids, skip_special_tokens=True)
            labels = batch["labels"].cpu().numpy()
            # Replace -100 with pad_token_id before decoding
            labels = np.where(labels == -100, processor.tokenizer.pad_token_id, labels)
            refs = processor.batch_decode(labels, skip_special_tokens=True)

            all_preds.extend(preds)
            all_refs.extend(refs)

        wer = wer_metric.compute(predictions=all_preds, references=all_refs)
        print(f"Validation WER: {wer:.4f}")

        if wer < best_wer:
            if best_model_dir and os.path.exists(best_model_dir):
                shutil.rmtree(best_model_dir)

            best_wer = wer
            timestamp = int(time.time())
            best_model_dir = f"./checkpoint-fold{fold}-epoch{epoch}-wer{wer:.4f}-{timestamp}"
            os.makedirs(best_model_dir, exist_ok=True)
            model.save_pretrained(best_model_dir)
            processor.save_pretrained(best_model_dir)
            print(f"Saved best model to {best_model_dir}")

print(f"\nTraining finished. Best WER: {best_wer:.4f}")



=== Fold 1/5 ===
Epoch 1/10


100%|██████████| 130/130 [00:49<00:00,  2.60it/s]


Avg training loss: 4.2810


100%|██████████| 33/33 [00:27<00:00,  1.20it/s]


Validation WER: 0.8759
Saved best model to ./checkpoint-fold0-epoch0-wer0.8759-1752023021
Epoch 2/10


100%|██████████| 130/130 [00:47<00:00,  2.71it/s]


Avg training loss: 3.8562


100%|██████████| 33/33 [00:26<00:00,  1.23it/s]


Validation WER: 0.8239
Saved best model to ./checkpoint-fold0-epoch1-wer0.8239-1752023096
Epoch 3/10


100%|██████████| 130/130 [00:48<00:00,  2.70it/s]


Avg training loss: 3.3917


100%|██████████| 33/33 [00:28<00:00,  1.17it/s]


Validation WER: 0.7725
Saved best model to ./checkpoint-fold0-epoch2-wer0.7725-1752023173
Epoch 4/10


100%|██████████| 130/130 [00:47<00:00,  2.76it/s]


Avg training loss: 2.9710


100%|██████████| 33/33 [00:25<00:00,  1.30it/s]


Validation WER: 0.7269
Saved best model to ./checkpoint-fold0-epoch3-wer0.7269-1752023245
Epoch 5/10


100%|██████████| 130/130 [00:47<00:00,  2.71it/s]


Avg training loss: 2.6405


100%|██████████| 33/33 [00:26<00:00,  1.25it/s]


Validation WER: 0.6893
Saved best model to ./checkpoint-fold0-epoch4-wer0.6893-1752023320
Epoch 6/10


100%|██████████| 130/130 [00:48<00:00,  2.69it/s]


Avg training loss: 2.3892


100%|██████████| 33/33 [00:26<00:00,  1.23it/s]


Validation WER: 0.6340
Saved best model to ./checkpoint-fold0-epoch5-wer0.6340-1752023395
Epoch 7/10


100%|██████████| 130/130 [00:46<00:00,  2.77it/s]


Avg training loss: 2.1887


100%|██████████| 33/33 [00:24<00:00,  1.32it/s]


Validation WER: 0.6315
Saved best model to ./checkpoint-fold0-epoch6-wer0.6315-1752023468
Epoch 8/10


100%|██████████| 130/130 [00:49<00:00,  2.60it/s]


Avg training loss: 2.0590


100%|██████████| 33/33 [00:25<00:00,  1.30it/s]


Validation WER: 0.6188
Saved best model to ./checkpoint-fold0-epoch7-wer0.6188-1752023543
Epoch 9/10


100%|██████████| 130/130 [00:48<00:00,  2.70it/s]


Avg training loss: 1.9954


100%|██████████| 33/33 [00:24<00:00,  1.34it/s]


Validation WER: 0.6054
Saved best model to ./checkpoint-fold0-epoch8-wer0.6054-1752023616
Epoch 10/10


100%|██████████| 130/130 [00:48<00:00,  2.68it/s]


Avg training loss: 1.9741


100%|██████████| 33/33 [00:25<00:00,  1.31it/s]


Validation WER: 0.5993
Saved best model to ./checkpoint-fold0-epoch9-wer0.5993-1752023691

=== Fold 2/5 ===
Epoch 1/10


100%|██████████| 130/130 [00:48<00:00,  2.66it/s]


Avg training loss: 4.1891


100%|██████████| 33/33 [00:27<00:00,  1.22it/s]


Validation WER: 0.9369
Epoch 2/10


100%|██████████| 130/130 [00:50<00:00,  2.59it/s]


Avg training loss: 3.7429


100%|██████████| 33/33 [00:26<00:00,  1.26it/s]


Validation WER: 0.8687
Epoch 3/10


100%|██████████| 130/130 [00:49<00:00,  2.64it/s]


Avg training loss: 3.2988


100%|██████████| 33/33 [00:26<00:00,  1.27it/s]


Validation WER: 0.8657
Epoch 4/10


100%|██████████| 130/130 [00:48<00:00,  2.70it/s]


Avg training loss: 2.8876


100%|██████████| 33/33 [00:25<00:00,  1.29it/s]


Validation WER: 0.8661
Epoch 5/10


100%|██████████| 130/130 [00:48<00:00,  2.69it/s]


Avg training loss: 2.5452


100%|██████████| 33/33 [00:25<00:00,  1.30it/s]


Validation WER: 0.8356
Epoch 6/10


100%|██████████| 130/130 [00:47<00:00,  2.71it/s]


Avg training loss: 2.2993


100%|██████████| 33/33 [00:24<00:00,  1.32it/s]


Validation WER: 0.7105
Epoch 7/10


100%|██████████| 130/130 [00:48<00:00,  2.71it/s]


Avg training loss: 2.1079


100%|██████████| 33/33 [00:24<00:00,  1.35it/s]


Validation WER: 0.7256
Epoch 8/10


100%|██████████| 130/130 [00:48<00:00,  2.68it/s]


Avg training loss: 1.9894


100%|██████████| 33/33 [00:24<00:00,  1.34it/s]


Validation WER: 0.6998
Epoch 9/10


100%|██████████| 130/130 [00:48<00:00,  2.68it/s]


Avg training loss: 1.9153


100%|██████████| 33/33 [00:25<00:00,  1.28it/s]


Validation WER: 0.6786
Epoch 10/10


100%|██████████| 130/130 [00:48<00:00,  2.70it/s]


Avg training loss: 1.8635


100%|██████████| 33/33 [00:24<00:00,  1.33it/s]


Validation WER: 0.7002

=== Fold 3/5 ===
Epoch 1/10


100%|██████████| 130/130 [00:49<00:00,  2.64it/s]


Avg training loss: 4.2189


100%|██████████| 33/33 [00:28<00:00,  1.17it/s]


Validation WER: 0.9107
Epoch 2/10


100%|██████████| 130/130 [00:46<00:00,  2.77it/s]


Avg training loss: 3.7673


100%|██████████| 33/33 [00:28<00:00,  1.15it/s]


Validation WER: 0.8999
Epoch 3/10


100%|██████████| 130/130 [00:48<00:00,  2.68it/s]


Avg training loss: 3.3626


100%|██████████| 33/33 [00:26<00:00,  1.25it/s]


Validation WER: 0.8220
Epoch 4/10


100%|██████████| 130/130 [00:48<00:00,  2.70it/s]


Avg training loss: 2.9484


100%|██████████| 33/33 [00:27<00:00,  1.21it/s]


Validation WER: 0.7880
Epoch 5/10


100%|██████████| 130/130 [00:48<00:00,  2.68it/s]


Avg training loss: 2.6235


100%|██████████| 33/33 [00:27<00:00,  1.21it/s]


Validation WER: 0.7924
Epoch 6/10


100%|██████████| 130/130 [00:47<00:00,  2.75it/s]


Avg training loss: 2.2955


100%|██████████| 33/33 [00:26<00:00,  1.24it/s]


Validation WER: 0.7506
Epoch 7/10


100%|██████████| 130/130 [00:47<00:00,  2.71it/s]


Avg training loss: 2.1767


100%|██████████| 33/33 [00:25<00:00,  1.30it/s]


Validation WER: 0.6997
Epoch 8/10


100%|██████████| 130/130 [00:48<00:00,  2.66it/s]


Avg training loss: 2.0608


100%|██████████| 33/33 [00:26<00:00,  1.26it/s]


Validation WER: 0.6916
Epoch 9/10


100%|██████████| 130/130 [00:46<00:00,  2.79it/s]


Avg training loss: 1.9708


100%|██████████| 33/33 [00:26<00:00,  1.25it/s]


Validation WER: 0.6815
Epoch 10/10


100%|██████████| 130/130 [00:49<00:00,  2.62it/s]


Avg training loss: 1.9615


100%|██████████| 33/33 [00:26<00:00,  1.24it/s]


Validation WER: 0.6768

=== Fold 4/5 ===
Epoch 1/10


100%|██████████| 130/130 [00:47<00:00,  2.72it/s]


Avg training loss: 4.2452


100%|██████████| 33/33 [00:30<00:00,  1.08it/s]


Validation WER: 1.0406
Epoch 2/10


100%|██████████| 130/130 [00:49<00:00,  2.63it/s]


Avg training loss: 3.7492


100%|██████████| 33/33 [00:29<00:00,  1.12it/s]


Validation WER: 0.9840
Epoch 3/10


100%|██████████| 130/130 [00:48<00:00,  2.71it/s]


Avg training loss: 3.3036


100%|██████████| 33/33 [00:30<00:00,  1.08it/s]


Validation WER: 0.9242
Epoch 4/10


100%|██████████| 130/130 [00:48<00:00,  2.71it/s]


Avg training loss: 2.9432


100%|██████████| 33/33 [00:32<00:00,  1.02it/s]


Validation WER: 0.9391
Epoch 5/10


100%|██████████| 130/130 [00:48<00:00,  2.69it/s]


Avg training loss: 2.5529


100%|██████████| 33/33 [00:31<00:00,  1.05it/s]


Validation WER: 0.9263
Epoch 6/10


100%|██████████| 130/130 [00:48<00:00,  2.71it/s]


Avg training loss: 2.3149


100%|██████████| 33/33 [00:32<00:00,  1.03it/s]


Validation WER: 0.8278
Epoch 7/10


100%|██████████| 130/130 [00:47<00:00,  2.72it/s]


Avg training loss: 2.1427


100%|██████████| 33/33 [00:29<00:00,  1.14it/s]


Validation WER: 0.8267
Epoch 8/10


100%|██████████| 130/130 [00:48<00:00,  2.69it/s]


Avg training loss: 2.0418


100%|██████████| 33/33 [00:29<00:00,  1.10it/s]


Validation WER: 0.7911
Epoch 9/10


100%|██████████| 130/130 [00:48<00:00,  2.68it/s]


Avg training loss: 1.9351


100%|██████████| 33/33 [00:28<00:00,  1.14it/s]


Validation WER: 0.7648
Epoch 10/10


100%|██████████| 130/130 [00:48<00:00,  2.70it/s]


Avg training loss: 1.9188


100%|██████████| 33/33 [00:28<00:00,  1.17it/s]


Validation WER: 0.7815

=== Fold 5/5 ===
Epoch 1/10


100%|██████████| 130/130 [00:49<00:00,  2.62it/s]


Avg training loss: 4.3480


100%|██████████| 33/33 [00:26<00:00,  1.23it/s]


Validation WER: 0.8120
Epoch 2/10


100%|██████████| 130/130 [00:47<00:00,  2.72it/s]


Avg training loss: 3.8954


100%|██████████| 33/33 [00:26<00:00,  1.23it/s]


Validation WER: 0.8022
Epoch 3/10


100%|██████████| 130/130 [00:48<00:00,  2.68it/s]


Avg training loss: 3.4191


100%|██████████| 33/33 [00:26<00:00,  1.23it/s]


Validation WER: 0.7746
Epoch 4/10


100%|██████████| 130/130 [00:47<00:00,  2.73it/s]


Avg training loss: 3.0212


100%|██████████| 33/33 [00:25<00:00,  1.28it/s]


Validation WER: 0.7400
Epoch 5/10


100%|██████████| 130/130 [00:47<00:00,  2.76it/s]


Avg training loss: 2.6600


100%|██████████| 33/33 [00:27<00:00,  1.19it/s]


Validation WER: 0.7368
Epoch 6/10


100%|██████████| 130/130 [00:46<00:00,  2.77it/s]


Avg training loss: 2.4094


100%|██████████| 33/33 [00:26<00:00,  1.22it/s]


Validation WER: 0.6942
Epoch 7/10


100%|██████████| 130/130 [00:47<00:00,  2.72it/s]


Avg training loss: 2.2023


100%|██████████| 33/33 [00:25<00:00,  1.30it/s]


Validation WER: 0.6456
Epoch 8/10


100%|██████████| 130/130 [00:48<00:00,  2.67it/s]


Avg training loss: 2.1005


100%|██████████| 33/33 [00:26<00:00,  1.24it/s]


Validation WER: 0.6463
Epoch 9/10


100%|██████████| 130/130 [00:46<00:00,  2.77it/s]


Avg training loss: 2.0400


100%|██████████| 33/33 [00:26<00:00,  1.25it/s]


Validation WER: 0.6351
Epoch 10/10


100%|██████████| 130/130 [00:48<00:00,  2.67it/s]


Avg training loss: 1.9715


100%|██████████| 33/33 [00:26<00:00,  1.24it/s]

Validation WER: 0.6229

Training finished. Best WER: 0.5993





In [None]:
def transcribe_wav(wav_path, model_path):
    processor = Speech2TextProcessor.from_pretrained(model_path)
    model = Speech2TextForConditionalGeneration.from_pretrained(model_path)
    model.to("cuda" if torch.cuda.is_available() else "cpu")

    waveform, sr = torchaudio.load(wav_path)
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform).squeeze().numpy()

    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    generated_ids = model.generate(input_features=inputs["input_features"], max_length=128)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
