<a href="https://colab.research.google.com/github/hissain/awesome_bangla_asr/blob/main/finetuning/fine_tuning_wishper_bn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q torch torchaudio transformers datasets evaluate jiwer accelerate

In [2]:
import huggingface_hub
# Authenticate this session with the Hugging Face Hub. When run in this
# notebook the call rendered an interactive login widget (see the cell output).
# NOTE(review): presumably required because the Common Voice dataset is gated
# on the Hub — confirm.
huggingface_hub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import torch
from datasets import load_dataset, Audio, IterableDataset
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

# --- Experiment configuration -------------------------------------------------
# Pretrained checkpoint to fine-tune.
MODEL_NAME = "openai/whisper-small"

# Dataset identifier and configuration name handed to `load_dataset`.
DATASET_NAME = "mozilla-foundation/common_voice_17_0"
DATASET_CONFIG = "bn"

# Number of streamed examples to keep (small on purpose: start with 10 samples).
SUBSET_SIZE = 10

# Maximum audio length, in seconds.
MAX_INPUT_LENGTH = 30.0

# --- Processor ----------------------------------------------------------------
# Whisper processor for the checkpoint, set up for Bengali transcription.
processor = WhisperProcessor.from_pretrained(MODEL_NAME, language="bengali", task="transcribe")

def prepare_dataset(batch):
    """Convert one Common Voice example into Whisper training inputs.

    Args:
        batch: A single example holding an ``"audio"`` dict (with ``"array"``
            and ``"sampling_rate"`` entries) and a ``"sentence"`` transcript.

    Returns:
        The same dict with two entries added: ``"input_features"`` (the
        processor output for this one clip) and ``"labels"`` (token ids of the
        stripped transcript, truncated to at most 128 tokens).
    """
    audio = batch["audio"]
    sampling_rate = audio["sampling_rate"]

    # The Whisper feature extractor has no `max_duration` argument; unknown
    # keyword arguments are silently dropped, so the cap never took effect.
    # Apply it explicitly by trimming the waveform before feature extraction.
    # NOTE: the transcript is not trimmed to match, so for caps below the clip
    # length it is better to filter such clips out of the dataset instead.
    max_samples = int(MAX_INPUT_LENGTH * sampling_rate)
    waveform = audio["array"][:max_samples]

    # Process audio
    inputs = processor(
        waveform,
        sampling_rate=sampling_rate,
        return_tensors="pt"
    )
    # The processor returns a batch of one; keep the single clip's features.
    batch["input_features"] = inputs.input_features[0]

    # Process text (Common Voice uses "sentence" field)
    batch["labels"] = processor.tokenizer(
        batch["sentence"].strip(),
        max_length=128,
        truncation=True
    ).input_ids
    return batch

# Open the training split as a stream, then keep only the first SUBSET_SIZE examples.
streaming_dataset = load_dataset(
    DATASET_NAME, DATASET_CONFIG, split="train", streaming=True
)
streaming_dataset = streaming_dataset.take(SUBSET_SIZE)

# Cast the audio column to 16 kHz and apply the preprocessing lazily while
# streaming; the raw audio and transcript columns are dropped afterwards.
streaming_dataset = streaming_dataset.cast_column("audio", Audio(sampling_rate=16000))
streaming_dataset = streaming_dataset.map(
    prepare_dataset,
    remove_columns=["audio", "sentence"],
)

# Manual 80/20 split for a stream: every fifth example (index 0, 5, ...) goes
# to evaluation, all others go to training.
train_dataset = IterableDataset.from_generator(
    lambda: (example for index, example in enumerate(streaming_dataset) if index % 5)
)
eval_dataset = IterableDataset.from_generator(
    lambda: (example for index, example in enumerate(streaming_dataset) if not index % 5)
)

# Model: load the pretrained checkpoint to fine-tune.
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
# Clear the forced decoder ids stored in the checkpoint config.
model.config.forced_decoder_ids = None
# The key/value cache is incompatible with gradient checkpointing, which the
# training arguments enable; disable it up front instead of relying on the
# library to override it with a warning on every run.
model.config.use_cache = False

# Training args optimized for streaming
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-bn-stream",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,  # Better memory usage (effective batch of 2 * 8 = 16 per device)
    learning_rate=1e-5,
    max_steps=100,  # Training length is given in steps, not epochs
    fp16=torch.cuda.is_available(),  # Mixed precision needs a CUDA device; fall back to fp32 otherwise
    gradient_checkpointing=True,
    eval_strategy="steps",  # Current name of the deprecated `evaluation_strategy` argument
    eval_steps=20,
    save_strategy="no",
    report_to="none",
    remove_unused_columns=False,  # Important for streaming
    dataloader_pin_memory=False,  # Reduces memory pressure
)

def streaming_collator(features, label_pad_id=-100, decoder_start_token_id=None):
    """Collate preprocessed examples into one padded training batch.

    Args:
        features: Non-empty list of examples, each holding ``"input_features"``
            (equally shaped tensors or array-likes) and ``"labels"`` (a
            sequence of token ids).
        label_pad_id: Value used to right-pad ``labels`` to a common length.
            The default ``-100`` is the index the cross-entropy loss ignores,
            so padded positions do not contribute to the loss.
        decoder_start_token_id: Token stripped from the front of the labels
            when every sequence in the batch starts with it. Defaults to
            ``model.config.decoder_start_token_id``.

    Returns:
        A dict with ``"input_features"`` (the per-example features stacked
        along a new leading dimension) and ``"labels"`` (an integer tensor of
        the padded label sequences).
    """
    if decoder_start_token_id is None:
        decoder_start_token_id = model.config.decoder_start_token_id

    label_seqs = [list(f["labels"]) for f in features]

    # The model prepends the decoder start token itself when it shifts the
    # labels right to build decoder inputs; leaving it in the labels would
    # duplicate it, so drop it when the tokenizer already added it everywhere.
    if all(seq and seq[0] == decoder_start_token_id for seq in label_seqs):
        label_seqs = [seq[1:] for seq in label_seqs]

    # Pad labels to the maximum length in the batch
    max_label_length = max(len(seq) for seq in label_seqs)
    labels = [seq + [label_pad_id] * (max_label_length - len(seq)) for seq in label_seqs]

    return {
        # as_tensor accepts tensors unchanged and converts lists/arrays.
        "input_features": torch.stack([torch.as_tensor(f["input_features"]) for f in features]),
        "labels": torch.tensor(labels, dtype=torch.long),  # Create tensor after padding
    }

# Wire the model, the streamed train/eval splits and the custom collator
# into a sequence-to-sequence trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=streaming_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Launch fine-tuning; as the last expression of the cell, the returned
# training summary is displayed as the cell output.
trainer.train()

Reading metadata...: 21228it [00:01, 20655.87it/s]
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...
Reading metadata...: 21228it [00:00, 22971.54it/s]
