# Fine-tuning Whisper for Vietnamese ASR
This notebook demonstrates how to fine-tune OpenAI's Whisper model on a Vietnamese speech dataset using Hugging Face Transformers. The workflow includes environment setup, data loading, preprocessing, model training, and evaluation.

## 1. Environment Setup
Install the required libraries: `transformers`, `datasets`, `torchaudio`, and `jiwer` for evaluation.

In [1]:
!pip install transformers datasets torchaudio jiwer --quiet

In [1]:
import os
import gc
import psutil
import pandas as pd
from datasets import load_dataset, Dataset, Audio, load_from_disk
from transformers import WhisperProcessor, WhisperForConditionalGeneration, TrainingArguments, Trainer
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

  from .autonotebook import tqdm as notebook_tqdm


## 2. Load and Prepare Data
Assume your CSV files (`fpt_train.csv`, `fpt_val.csv`, `fpt_test.csv`) are in the current directory and contain columns: `path` (audio file path) and `transcription` (text).

In [None]:
def load_csv_to_dataset(csv_path):
    df = pd.read_csv(csv_path)
    ds = Dataset.from_pandas(df)
    ds = ds.cast_column("path", Audio(sampling_rate=16000))
    return ds

train_dataset = load_csv_to_dataset("fpt_train.csv")
val_dataset = load_csv_to_dataset("fpt_val.csv")

## 3. Load Whisper Model and Processor

In [3]:
model_name = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="vi", task="transcribe")

## 4. Preprocessing Function
Tokenize transcriptions and prepare input features.

In [13]:
def prepare_dataset(batch):
    audio = batch["path"]
    if isinstance(audio, list):
        input_features = [processor.feature_extractor(a["array"], sampling_rate=a["sampling_rate"]).input_features[0] for a in audio]
    else:
        input_features = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    if isinstance(batch["transcription"], list):
        labels = [processor.tokenizer(text).input_ids for text in batch["transcription"]]
    else:
        labels = processor.tokenizer(batch["transcription"]).input_ids
    return {"input_features": input_features, "labels": labels}

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch

def prepare_dataset(batch):
    audio = batch["path"]
    if isinstance(audio, list):
        input_features = [processor.feature_extractor(a["array"], sampling_rate=a["sampling_rate"]).input_features[0] for a in audio]
    else:
        input_features = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    if isinstance(batch["transcription"], list):
        labels = [processor.tokenizer(text).input_ids for text in batch["transcription"]]
    else:
        labels = processor.tokenizer(batch["transcription"]).input_ids
    return {"input_features": input_features, "labels": labels}

train_dataset = train_dataset.map(
    prepare_dataset,
    batched=True,
    batch_size=4,
    num_proc=1,
    remove_columns=train_dataset.column_names,
    load_from_cache_file=True,
    keep_in_memory=False,
    desc="Processing training data"
)
gc.collect()
val_dataset = val_dataset.map(
    prepare_dataset,
    batched=True,
    batch_size=4,
    num_proc=1,
    remove_columns=val_dataset.column_names,
    load_from_cache_file=True,
    keep_in_memory=False,
    desc="Processing validation data"
)
gc.collect()
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

val_dataset_processed = val_dataset.save_to_disk("val_dataset_processed")
train_dataset_processed = train_dataset.save_to_disk("train_dataset_processed")

Processing training data: 100%|██████████| 20735/20735 [02:03<00:00, 167.95 examples/s]

Processing validation data: 100%|██████████| 2592/2592 [00:13<00:00, 190.81 examples/s]

Saving the dataset (5/5 shards): 100%|██████████| 2592/2592 [00:03<00:00, 845.67 examples/s]
Saving the dataset (0/40 shards):   0%|          | 0/20735 [00:00<?, ? examples/s]
Saving the dataset (40/40 shards): 100%|██████████| 20735/20735 [01:23<00:00, 247.90 examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 2592/2592 [00:03<00:00, 861.85 examples/s]



## 5. Training Arguments and Trainer

In [5]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
    
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [6]:
train_dataset = load_from_disk("train_dataset_processed")
val_dataset = load_from_disk("val_dataset_processed")
test_dataset = load_from_disk("test_dataset_processed")

training_args = TrainingArguments(
    output_dir="./whisper-vi-finetuned",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    eval_strategy="steps",
    num_train_epochs=3,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=100,
    learning_rate=1e-4,
    warmup_steps=500,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),
    dataloader_pin_memory=False,
    dataloader_num_workers=0,
    push_to_hub=False,
    remove_unused_columns=True,
    prediction_loss_only=True,
)

def compute_metrics(pred):
    try:
        from jiwer import wer
        pred_ids = pred.predictions
        label_ids = pred.label_ids
        label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
        pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        wer_score = wer(label_str, pred_str)
        del pred_str, label_str
        gc.collect()
        return {"wer": wer_score}
    except Exception as e:
        print(f"Error computing metrics: {e}")
        return {"wer": 1.0}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor,
)

## 6. Start Training

In [7]:
trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss
1000,0.892,0.864725
2000,0.6371,0.668017
3000,0.405,0.586669
4000,0.3762,0.516293
5000,0.309,0.445573
6000,0.114,0.451866
7000,0.1237,0.412356




TrainOutput(global_step=7776, training_loss=0.43403253560203586, metrics={'train_runtime': 15227.2317, 'train_samples_per_second': 4.085, 'train_steps_per_second': 0.511, 'total_flos': 1.79514548269056e+19, 'train_loss': 0.43403253560203586, 'epoch': 3.0})

## 7. Evaluate on Test Set

In [20]:
test_dataset = load_csv_to_dataset("fpt_test.csv")
test_dataset = test_dataset.map(
    prepare_dataset,
    batched=True,
    batch_size=4,
    num_proc=1,
    remove_columns=test_dataset.column_names,
    load_from_cache_file=True,
    keep_in_memory=False,
    desc="Processing validation data"
)
gc.collect()
test_dataset_processed = test_dataset.save_to_disk("test_dataset_processed")

Processing validation data: 100%|██████████| 2592/2592 [00:15<00:00, 170.58 examples/s]
Processing validation data: 100%|██████████| 2592/2592 [00:15<00:00, 170.58 examples/s]
Saving the dataset (5/5 shards): 100%|██████████| 2592/2592 [00:03<00:00, 746.66 examples/s]



In [None]:
# Evaluate on the processed test set and print WER and loss
from datasets import load_from_disk

test_dataset = load_from_disk("test_dataset_processed")
results = trainer.evaluate(test_dataset=test_dataset)
print(f"Test Loss: {results['eval_loss']:.4f}")
if 'eval_wer' in results:
    print(f"Test WER: {results['eval_wer']:.4f}")
else:
    print("WER not computed. Check compute_metrics function.")

In [None]:
# Save the fine-tuned model and processor for later use
model.save_pretrained("./whisper-vi-finetuned")
processor.save_pretrained("./whisper-vi-finetuned")
print("Model and processor saved to ./whisper-vi-finetuned")

---
This notebook provides a basic pipeline for fine-tuning Whisper on Vietnamese ASR data. You can further customize preprocessing, augmentation, and hyperparameters as needed.