In [15]:
!pip install transformers torch torchaudio torchvision evaluate jiwer openai-whisper --quiet

In [16]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchaudio
import os
from pathlib import Path
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration
from tqdm.auto import tqdm 
import pickle
import jiwer 
import numpy as np 

## Config 

In [17]:
DATA_PATH = "/kaggle/input/librispeech-train-clean-100/" 
# MODEL_NAME = "openai/whisper-small.en"
MODEL_NAME = "/kaggle/input/whisper-small-epochs-8_librispeech-100h/pytorch/default/1/whisper-small-epochs-8"
SAVE_TO = "/kaggle/working/models/"
SAVE_EVERY = 3000
GRADIENT_ACCUMULATION_STEPS = 64
EVALUATION_SIZE = 150 
LEARNING_RATE = 1e-4
EPOCHS = 1
BATCH_SIZE = 16 
TARGET_SAMPLE_RATE = 16000
CHECKPOINT_EVERY_N_EPOCHS = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [18]:
torch.cuda.empty_cache()

## Load Feature Extractor and Tokenizer

In [19]:
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME, language="english", task="transcribe")

## Load dataset

In [20]:
def load_librispeech_item(fileid: str, path: str, ext_audio: str, ext_txt: str):
    speaker_id, chapter_id, utterance_id = fileid.split("-")

    file_text = f"{speaker_id}-{chapter_id}{ext_txt}"
    file_text = os.path.join(path, speaker_id, chapter_id, file_text)

    fileid_audio = f"{speaker_id}-{chapter_id}-{utterance_id}"
    file_audio = f"{fileid_audio}{ext_audio}"
    file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)

    waveform, sample_rate = torchaudio.load(file_audio)
    if sample_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)
        waveform = resampler(waveform)

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    waveform = waveform.squeeze(0)
    transcript = None
    with open(file_text) as ft:
        for line in ft:
            fileid_text, text = line.strip().split(" ", 1)
            if fileid_audio == fileid_text:
                transcript = text
                break
        else:
            raise FileNotFoundError(f"Translation not found for {fileid_audio} in {file_text}")

    return {"waveform": waveform.numpy(), "transcript": transcript}

In [21]:
class LibriSpeechDataset(Dataset): 
    _ext_txt = ".trans.txt"
    _ext_audio = ".flac"

    def __init__(self, data_type='train'):
        base_path = DATA_PATH
        if data_type == 'train':
            self.url = os.path.join(base_path, 'train-clean-100')
        elif data_type == 'dev':
            self.url = os.path.join(base_path, 'dev-clean')
        elif data_type == 'test':
            self.url = os.path.join(base_path, 'test-clean')
        else:
            raise ValueError("data_type must be 'train', 'dev', or 'test'")

        if not os.path.isdir(self.url):
             raise FileNotFoundError(f"Dataset directory not found: {self.url}")

        _dataset_path = Path(self.url).resolve()

        self.walker = []
        print(f"Scanning dataset in: {_dataset_path}")
        found_files = 0
        for p in _dataset_path.glob('*/*/*' + self._ext_audio):
             self.walker.append((str(p.stem), str(_dataset_path)))
             found_files += 1
        print(f"Found {found_files} audio files.")

        if not self.walker:
             print(f"Warning: No audio files found in {self.url} with pattern * G * / * / *{self._ext_audio}")

        self.walker = sorted(self.walker)

    def __len__(self):
        return len(self.walker)

    def __getitem__(self, n):
        fileid, _path = self.walker[n]
        item = load_librispeech_item(fileid, _path, self._ext_audio, self._ext_txt)
        while item is None:
             print(f"Warning: Skipping item at index {n} due to loading error.")
             n = (n + 1) % len(self.walker) 
             if n == self.walker[n][0]: 
                 raise RuntimeError("Could not load any valid items from the dataset.")
             fileid, _path = self.walker[n]
             item = load_librispeech_item(fileid, _path, self._ext_audio, self._ext_txt)
        return item

In [22]:
train_set = LibriSpeechDataset('train')
dev_set = LibriSpeechDataset('dev')
test_set = LibriSpeechDataset('test')
print(f"Dataset lengths: Train={len(train_set)}, Dev={len(dev_set)}, Test={len(test_set)}")
if len(train_set) == 0 or len(dev_set) == 0 or len(test_set) == 0:
     print("Warning: One or more dataset splits are empty. Check dataset paths and file structure.")

Scanning dataset in: /kaggle/input/librispeech-train-clean-100/train-clean-100
Found 28539 audio files.
Scanning dataset in: /kaggle/input/librispeech-train-clean-100/dev-clean
Found 2703 audio files.
Scanning dataset in: /kaggle/input/librispeech-train-clean-100/test-clean
Found 2620 audio files.
Dataset lengths: Train=28539, Dev=2703, Test=2620


In [23]:
class DataCollatorSpeechSeq2SeqWithPadding:
    def __init__(self, feature_extractor, tokenizer):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer

    def __call__(self, batch):
        batch = [item for item in batch if item is not None]
        if not batch:
             return {}

        waveforms = [item["waveform"] for item in batch]
        transcripts = [item["transcript"] for item in batch]

        input_features = self.feature_extractor(
            waveforms,
            sampling_rate=TARGET_SAMPLE_RATE,
            return_tensors="pt",

            padding="max_length",  
            truncation=True,    
        ).input_features

        labels = self.tokenizer(
            transcripts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=448
        ).input_ids

        return {
            "input_features": input_features,
            "labels": labels,
            "raw_texts": transcripts
        }

data_collator = DataCollatorSpeechSeq2SeqWithPadding(feature_extractor, tokenizer)

In [24]:
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, collate_fn=data_collator, shuffle=True, num_workers=2, pin_memory=True)
dev_loader = DataLoader(dev_set, batch_size=BATCH_SIZE, collate_fn=data_collator, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=data_collator, shuffle=False, num_workers=2, pin_memory=True)

## Model

In [25]:
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
model.freeze_encoder()
model.to(device)

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 768)
      (layers): ModuleList(
        (0-11): 12 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        

## Evaluate

In [26]:
def evaluate(model, dataloader, tokenizer, feature_extractor):
    print("Evaluating...")
    model.eval()
    total_loss = 0
    all_preds = []
    all_refs = []
    num_batches = 0

    with torch.no_grad():
        eval_batches = len(dataloader) 
        for i, batch in enumerate(tqdm(dataloader, total=eval_batches, desc="Evaluating")):
            if i >= eval_batches:
                 break
            if not batch: 
                 continue

            input_features = batch["input_features"].to(device)
            labels = batch["labels"].to(device)
            raw_texts = batch["raw_texts"]
            outputs = model(input_features=input_features, labels=labels)
            loss = outputs.loss
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Warning: Encountered {loss.item()} loss in evaluation batch {i}. Skipping.")
                continue
            total_loss += loss.item()

            decoder_start_token_id = model.config.decoder_start_token_id
            decoder_input_ids = torch.full(
                 (input_features.size(0), 1),
                 decoder_start_token_id,
                 dtype=torch.long,
                 device=device
            )

            generated_ids = model.generate(
                input_features,
                max_length=150 
            )

            generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            all_preds.extend(generated_texts)
            all_refs.extend(raw_texts) 

            num_batches += 1
            if num_batches == 0: 
                 print("Warning: No valid batches found during evaluation.")
                 return float('inf'), float('inf')


    avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
    normalized_refs = [text.lower() for text in all_refs]
    normalized_preds = [pred.lower() for pred in all_preds]
    wer_score = jiwer.wer(normalized_refs, normalized_preds) if all_refs and all_preds else float('inf')

    print(f"Evaluation Results - Loss: {avg_loss:.4f}, WER: {wer_score:.4f}")
    print("Example predictions:")
    for i in range(min(3, len(all_refs))):
         print(f"  Ref: {all_refs[i]}")
         print(f"  Pred: {all_preds[i]}")
         print("-" * 10)

    model.train() 
    return avg_loss, wer_score

## Optimizer

In [27]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE) # AdamW thường tốt hơn Adam

## Training

In [28]:
steps = 0
gradient_steps = 0
total_loss_accumulated = 0.0

os.makedirs(SAVE_TO, exist_ok=True)

losses = []
epoch_train_losses = []
val_losses = []
val_wers = []

current_val_loss = float('inf')
current_val_wer = float('inf')
# val_loss, val_wer = evaluate(model, dev_loader, tokenizer, feature_extractor)
# print(f"Initial Eval - Loss: {val_loss:.4f}, WER: {val_wer:.4f}")
# val_losses.append(val_loss)
# val_wers.append(val_wer)

model.train()

Evaluating...


Evaluating:   0%|          | 0/169 [00:00<?, ?it/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Evaluation Results - Loss: 0.0543, WER: 0.0437
Example predictions:
  Ref: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
  Pred: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
----------
  Ref: NOR IS MISTER QUILTER'S MANNER LESS INTERESTING THAN HIS MATTER
  Pred: NOR IS MISTER QUILTER'S MANNER LESS INTERESTING THAN HIS MATTER
----------
  Ref: HE TELLS US THAT AT THIS FESTIVE SEASON OF THE YEAR WITH CHRISTMAS AND ROAST BEEF LOOMING BEFORE US SIMILES DRAWN FROM EATING AND ITS RESULTS OCCUR MOST READILY TO THE MIND
  Pred: HE TELLS US THAT AT THIS FESTIVE SEASON OF THE YEAR WITH CHRISTMAS AND ROAST BEEF LOOMING BEFORE US SIMILIES DRAWN FROM EATING AND ITS RESULTS OCCUR MOST READILY TO THE MIND
----------
Initial Eval - Loss: 0.0543, WER: 0.0437


WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 768)
      (layers): ModuleList(
        (0-11): 12 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        

In [None]:
for epoch in range(EPOCHS):
    print(f"\n--- Epoch: {epoch + 1}/{EPOCHS} ---")
    model.train() 
    epoch_loss_sum = 0.0
    batches_in_epoch = 0

    training_loop = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

    for batch in training_loop:
        if not batch: 
            continue

        input_features = batch["input_features"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_features=input_features, labels=labels)
        loss = outputs.loss

        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: Encountered {loss.item()} loss at step {steps}. Skipping batch.")
            optimizer.zero_grad()
            continue

        epoch_loss_sum += loss.item() * GRADIENT_ACCUMULATION_STEPS # Tính lại loss gốc trước khi scale
        batches_in_epoch += 1

        loss = loss / GRADIENT_ACCUMULATION_STEPS
        total_loss_accumulated += loss.item() # Cộng loss đã scale để tính loss của gradient step
        loss.backward()
        
        steps += 1
        if steps % GRADIENT_ACCUMULATION_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            gradient_steps += 1

            avg_accumulated_loss = total_loss_accumulated 
            losses.append(avg_accumulated_loss * GRADIENT_ACCUMULATION_STEPS) 
            total_loss_accumulated = 0.0 

            training_loop.set_description(
                 f"Epoch {epoch + 1} | Step: {steps} | Grad Steps: {gradient_steps} | Batch Loss: {loss.item() * GRADIENT_ACCUMULATION_STEPS:.4f} | Val Loss: {current_val_loss:.4f} | Val WER: {current_val_wer:.4f}"
            )

    avg_epoch_train_loss = epoch_loss_sum / batches_in_epoch if batches_in_epoch > 0 else float('inf')
    epoch_train_losses.append(avg_epoch_train_loss)
    print(f"Epoch {epoch + 1} finished. Average Training Loss: {avg_epoch_train_loss:.4f}")

    print(f"Evaluating at the end of Epoch {epoch + 1}...")
    current_val_loss, current_val_wer = evaluate(model, dev_loader, tokenizer, feature_extractor)
    val_losses.append(current_val_loss)
    val_wers.append(current_val_wer)
    print(f"Epoch {epoch + 1} Eval - Loss: {current_val_loss:.4f}, WER: {current_val_wer:.4f}")

    if (epoch + 1) % CHECKPOINT_EVERY_N_EPOCHS == 0 or (epoch + 1) == EPOCHS:
        print(f"\nSaving checkpoint at the end of Epoch {epoch + 1}...")
        save_path = os.path.join(SAVE_TO, f'whisper-epoch-{epoch + 1}')
        model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
        feature_extractor.save_pretrained(save_path)
        print(f"Checkpoint saved to {save_path}")

        try:
            with open(os.path.join(SAVE_TO, f'losses_grad_step_epoch_{epoch+1}.pkl'), 'wb') as f:
                 pickle.dump(losses, f)
            with open(os.path.join(SAVE_TO, f'epoch_train_losses_epoch_{epoch+1}.pkl'), 'wb') as f:
                 pickle.dump(epoch_train_losses, f)
            with open(os.path.join(SAVE_TO, f'val_losses_epoch_{epoch+1}.pkl'), 'wb') as f:
                 pickle.dump(val_losses, f)
            with open(os.path.join(SAVE_TO, f'val_wers_epoch_{epoch+1}.pkl'), 'wb') as f:
                 pickle.dump(val_wers, f)
            print("Metrics saved for Epoch", epoch + 1)
        except Exception as e:
            print(f"Error saving metrics for Epoch {epoch + 1}: {e}")

    model.train()

## Final review after completing all epochs

In [30]:
print("--- Training Finished ---")
print("Performing final evaluation on Dev set...")
final_val_loss, final_val_wer = evaluate(model, dev_loader, tokenizer, feature_extractor)
print(f"Final Dev Eval - Loss: {final_val_loss:.4f}, WER: {final_val_wer:.4f}")

print("Performing final evaluation on Test set...")
final_test_loss, final_test_wer = evaluate(model, test_loader, tokenizer, feature_extractor)
print(f"Final Test Eval - Loss: {final_test_loss:.4f}, WER: {final_test_wer:.4f}")
print("Saving final model...")
final_save_path = os.path.join(SAVE_TO, 'whisper-final')
model.save_pretrained(final_save_path)
tokenizer.save_pretrained(final_save_path)
feature_extractor.save_pretrained(final_save_path)


--- Training Finished ---
Performing final evaluation on Dev set...
Evaluating...


Evaluating:   0%|          | 0/169 [00:00<?, ?it/s]

Evaluation Results - Loss: 0.0567, WER: 0.0499
Example predictions:
  Ref: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
  Pred: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
----------
  Ref: NOR IS MISTER QUILTER'S MANNER LESS INTERESTING THAN HIS MATTER
  Pred: NOR IS MISTER QUILTER'S MANNER LESS INTERESTING THAN HIS MATTER
----------
  Ref: HE TELLS US THAT AT THIS FESTIVE SEASON OF THE YEAR WITH CHRISTMAS AND ROAST BEEF LOOMING BEFORE US SIMILES DRAWN FROM EATING AND ITS RESULTS OCCUR MOST READILY TO THE MIND
  Pred:  HE TELLS US THAT AT THIS FESTIVE SEASON OF THE YEAR WITH CHRISTMAS AND ROAST BEEF LOOMING BEFORE US SIMILIES DRAWN FROM EATING AND ITS RESULTS OCCURMOSE READILY TO THE MINE
----------
Final Dev Eval - Loss: 0.0567, WER: 0.0499
Performing final evaluation on Test set...
Evaluating...


Evaluating:   0%|          | 0/164 [00:00<?, ?it/s]

Evaluation Results - Loss: 0.0563, WER: 0.0487
Example predictions:
  Ref: HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE
  Pred: HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE
----------
  Ref: STUFF IT INTO YOU HIS BELLY COUNSELLED HIM
  Pred: STUFF IT INTO YOU HIS BELLY COUNSELLED HIM
----------
  Ref: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
  Pred: AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
----------
Final Test Eval - Loss: 0.0563, WER: 0.0487
Saving final model...


['/kaggle/working/models/whisper-final/preprocessor_config.json']

In [31]:
with open(os.path.join(SAVE_TO, 'losses.pkl'), 'wb') as f:
    pickle.dump(losses, f)
with open(os.path.join(SAVE_TO, 'val_losses.pkl'), 'wb') as f:
    pickle.dump(val_losses + [final_val_loss], f)
with open(os.path.join(SAVE_TO, 'val_wers.pkl'), 'wb') as f:
    pickle.dump(val_wers + [final_val_wer], f)
with open(os.path.join(SAVE_TO, 'test_results.pkl'), 'wb') as f:
    pickle.dump({'loss': final_test_loss, 'wer': final_test_wer}, f)
print("Final model and metrics saved.")

Final model and metrics saved.


In [33]:
losses, val_losses+[final_val_loss], val_wers + [final_val_wer], final_test_loss, final_test_wer

([0.08220819412963465,
  1.4517763387411833,
  1.4123799963854253,
  0.6676259930245578,
  0.9075691448524594],
 [0.05432921249595796, 0.0566839841139343],
 [0.04365648321752877, 0.049924635123708684],
 0.05625625495321867,
 0.048691418137553254)