<a href="https://colab.research.google.com/github/hissain/awesome_bangla_asr/blob/main/finetuning/fine_tuning_wishper_bn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Dependency install cell, left commented out. Uncomment it on a fresh runtime
# (e.g. Colab) before running the rest of the notebook.
# NOTE(review): prefer `%pip` over `!pip` so packages land in the running
# kernel's environment, and pin versions: the training cell relies on
# transformers TrainingArguments names that change between releases.
#!pip install -q torch torchaudio transformers datasets evaluate jiwer accelerate

In [2]:
# Interactive Hugging Face Hub login. In a notebook this renders a token-entry
# widget (the VBox output below this cell).
# Presumably required because the Common Voice dataset loaded later needs an
# authenticated account that has accepted its terms — TODO confirm.
# NOTE(review): this cell shows In [2] while the training cell shows In [1], so it
# ran *after* training was attempted; run cells top-to-bottom on a fresh kernel.
import huggingface_hub
huggingface_hub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [1]:
import torch
from datasets import load_dataset, Audio, IterableDataset
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

# --- Configuration -----------------------------------------------------------
MODEL_NAME = "openai/whisper-small"                   # checkpoint to fine-tune
DATASET_NAME = "mozilla-foundation/common_voice_17_0"
DATASET_CONFIG = "bn"                                 # Bengali subset
SUBSET_SIZE = 10                                      # streamed examples to use (start small)
MAX_INPUT_LENGTH = 30.0                               # seconds

# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# The Whisper processor bundles the feature extractor and the tokenizer; the
# tokenizer is configured for Bengali transcription.
processor = WhisperProcessor.from_pretrained(
    MODEL_NAME,
    task="transcribe",
    language="bengali",
)

def prepare_dataset(batch, max_label_length=128):
    """Turn one dataset example into Whisper model inputs.

    Intended for ``IterableDataset.map`` without ``batched=True``, so ``batch`` is
    a single example despite its (Hugging Face conventional) name.

    Args:
        batch: example dict with an ``"audio"`` entry (``"array"`` and
            ``"sampling_rate"`` keys, as produced by the ``Audio`` feature) and a
            ``"sentence"`` transcript string.
        max_label_length: maximum number of label tokens kept; longer
            transcripts are truncated.

    Returns:
        The same dict with two added keys:
        ``"input_features"`` - log-Mel features of the clip (leading batch axis
        removed) and ``"labels"`` - 1-D ``torch.long`` tensor of token ids,
        *unpadded*. Padding is left to the data collator so padded positions
        can be masked out of the loss.
    """
    audio = batch["audio"]
    # Call the feature extractor directly. The old `max_duration=` kwarg was not a
    # Whisper option and was silently swallowed; Whisper's extractor always
    # pads/truncates to its own fixed 30 s window.
    extracted = processor.feature_extractor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        return_tensors="pt",
    )
    batch["input_features"] = extracted.input_features[0]

    # Do NOT pad here. Whisper's pad token has the same id as its end-of-text
    # token, so padding to a fixed length would put many extra EOS targets into
    # the loss. The collator pads per batch and masks padding with -100.
    token_ids = processor.tokenizer(
        batch["sentence"].strip(),
        max_length=max_label_length,
        truncation=True,
    ).input_ids
    batch["labels"] = torch.tensor(token_ids, dtype=torch.long)
    return batch

# Stream the train split and keep only the first SUBSET_SIZE examples.
streaming_dataset = load_dataset(
    DATASET_NAME,
    DATASET_CONFIG,
    split="train",
    streaming=True,
).take(SUBSET_SIZE)

# Decode/resample audio to 16 kHz, then preprocess lazily while streaming.
streaming_dataset = streaming_dataset.cast_column("audio", Audio(sampling_rate=16000))
streaming_dataset = streaming_dataset.map(prepare_dataset, remove_columns=["audio", "sentence"])

# Deterministic 80/20 split: example indices 0, 5, 10, ... go to eval, the rest
# to train. IterableDataset.filter(with_indices=True) keeps both splits lazy,
# native `datasets` objects instead of re-wrapping a Python generator via lambdas.
EVAL_EVERY = 5
train_dataset = streaming_dataset.filter(lambda _example, idx: idx % EVAL_EVERY != 0, with_indices=True)
eval_dataset = streaming_dataset.filter(lambda _example, idx: idx % EVAL_EVERY == 0, with_indices=True)

# Load the pretrained Whisper checkpoint onto the selected device.
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)

# Don't force decoder prompt ids. Clear them on the legacy model config *and* on
# the generation config, which is what `generate()` reads in recent transformers.
model.config.forced_decoder_ids = None
model.generation_config.forced_decoder_ids = None
# Keep generation consistent with the processor's language/task settings, so
# later `generate()` calls transcribe Bengali instead of auto-detecting.
model.generation_config.language = "bengali"
model.generation_config.task = "transcribe"

# Training configuration. The training data is a streamed IterableDataset with no
# len(), so run length is given in optimizer steps via `max_steps`.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-bn-stream",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,  # effective train batch: 2 * 8 = 16 per device
    learning_rate=1e-5,
    max_steps=100,
    fp16=torch.cuda.is_available(),  # fp16 mixed precision only when CUDA is present
    # `fp16_backend` is deprecated and "auto" is already the default backend, so it is omitted.
    gradient_checkpointing=True,  # trade extra compute for lower activation memory
    # `evaluation_strategy` was renamed to `eval_strategy` (the old name is
    # removed in recent transformers releases).
    eval_strategy="steps",
    eval_steps=20,
    save_strategy="no",
    report_to="none",
    remove_unused_columns=False,  # keep all columns; the custom collator selects what it needs
    dataloader_pin_memory=torch.cuda.is_available(),
)

def streaming_collator(features, pad_token_id=None, decoder_start_token_id=None):
    """Batch preprocessed examples for Whisper seq2seq training.

    Args:
        features: list of example dicts, each with ``"input_features"``
            (equal-shaped feature arrays/tensors) and ``"labels"`` (1-D token
            ids as a tensor or list, optionally already right-padded with the
            pad token).
        pad_token_id: id used as label padding. Defaults to the tokenizer's pad id.
        decoder_start_token_id: id the model prepends when shifting labels
            right. Defaults to ``model.config.decoder_start_token_id``.

    Returns:
        dict with ``"input_features"`` (stacked) and ``"labels"`` (``torch.long``,
        shape ``(batch, max_len)``).

        In each label row, every position after the first pad token is set to
        -100, the loss's ignore index. The first pad token is kept because, for
        Whisper, the pad id equals the end-of-text id the model must learn to
        emit. If every row starts with the decoder start token, that column is
        dropped, because the model prepends the token again itself.

    Tensors are returned on the CPU. Trainer moves batches to the device, and
    returning CUDA tensors breaks ``dataloader_pin_memory=True``.
    """
    if pad_token_id is None:
        pad_token_id = processor.tokenizer.pad_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = model.config.decoder_start_token_id

    # Accept tensors or plain lists (datasets may hand back either).
    label_seqs = [torch.as_tensor(f["labels"], dtype=torch.long) for f in features]
    max_length = max(len(seq) for seq in label_seqs)

    labels = torch.full((len(label_seqs), max_length), -100, dtype=torch.long)
    for row, seq in enumerate(label_seqs):
        labels[row, : len(seq)] = seq
        pad_positions = (seq == pad_token_id).nonzero(as_tuple=True)[0]
        if len(pad_positions) > 0:
            # Keep the first pad/EOS as a target and mask any padding after it.
            first_pad = int(pad_positions[0])
            labels[row, first_pad + 1 :] = -100

    # Avoid a duplicated start-of-transcript token in the decoder inputs.
    if decoder_start_token_id is not None and bool((labels[:, 0] == decoder_start_token_id).all()):
        labels = labels[:, 1:]

    input_features = torch.stack([torch.as_tensor(f["input_features"]) for f in features])
    return {"input_features": input_features, "labels": labels}

# Assemble the trainer from the model, its arguments, the streamed train/eval
# splits and the custom batch collator, then launch training.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=streaming_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

train_result = trainer.train()
train_result



ValueError: Tried to use `fp16` but it is not supported on cpu