In [1]:
# run once per environment/kernel (remove -q if you want verbose output)
!pip install -q torch torchaudio transformers datasets evaluate librosa soundfile pandas tqdm accelerate


[0m

In [2]:
# Cell 2
import os
import torch
import pandas as pd
import evaluate
import numpy as np
from tqdm.auto import tqdm
from datasets import load_dataset, concatenate_datasets, Audio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
    TrainerCallback,
)

# ============================================
# Configuration (same as your original)
# ============================================
LANGS = ["hindi", "gujarati"]
MODEL_ID = "openai/whisper-medium"
OUTPUT_DIR = "./whisper-kathbath-hi-gu"
MIN_DUR = 2.0
MAX_DUR = 30.0
USE_QUALITY_FILTER = True    # your original setting (kept)
SEED = 42


2025-10-08 06:11:32.758578: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-08 06:11:33.616754: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9360] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-10-08 06:11:33.616829: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-10-08 06:11:33.616909: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1537] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-10-08 06:11:33.870459: I tensorflow/core/platform/cpu_feature_g

In [3]:
# Cell 3
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device} ({torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'})")


Using device: cuda (Quadro RTX 5000)


In [4]:
# Cell 4
datasets_list = []
for lang in LANGS:
    ds = load_dataset("ai4bharat/Kathbath", lang)
    print(f"✅ Loaded {lang}: {ds}")
    datasets_list.append(ds)


Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/32 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/33 [00:00<?, ?it/s]

✅ Loaded hindi: DatasetDict({
    valid: Dataset({
        features: ['fname', 'text', 'audio_filepath', 'lang', 'duration', 'gender', 'speaker_id'],
        num_rows: 3151
    })
    train: Dataset({
        features: ['fname', 'text', 'audio_filepath', 'lang', 'duration', 'gender', 'speaker_id'],
        num_rows: 91752
    })
})


Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/27 [00:00<?, ?it/s]

✅ Loaded gujarati: DatasetDict({
    valid: Dataset({
        features: ['fname', 'text', 'audio_filepath', 'lang', 'duration', 'gender', 'speaker_id'],
        num_rows: 2905
    })
    train: Dataset({
        features: ['fname', 'text', 'audio_filepath', 'lang', 'duration', 'gender', 'speaker_id'],
        num_rows: 66865
    })
})


In [5]:
# Cell 5
from datasets import DatasetDict

# Combine train splits across languages
train_datasets = [d["train"] for d in datasets_list]
train_data = concatenate_datasets(train_datasets)

# Build validation list by checking available split names per dataset
val_splits = []
for d in datasets_list:
    if "validation" in d:
        val_splits.append(d["validation"])
    elif "valid" in d:    # Kathbath uses 'valid'
        val_splits.append(d["valid"])
    elif "test" in d:
        val_splits.append(d["test"])
    else:
        # fallback: take 5% of train as validation for that language
        print("⚠️ No val/test/validation split for one dataset — taking 5% from train as val.")
        val_splits.append(d["train"].train_test_split(test_size=0.05, seed=SEED)["test"])

val_data = concatenate_datasets(val_splits)

# Cast the audio filepath column to Audio BUT disable automatic decoding (decode=False)
audio_feature = Audio(sampling_rate=16000, decode=False)
train_data = train_data.cast_column("audio_filepath", audio_feature)
val_data = val_data.cast_column("audio_filepath", audio_feature)

print(f"✅ Combined sizes -> train: {len(train_data)} | val: {len(val_data)}")


✅ Combined sizes -> train: 158617 | val: 6056


In [6]:
# Cell 6
def filter_fn(batch):
    # Use metadata duration (no audio decoding)
    dur = batch.get("duration", 0.0)
    if dur < MIN_DUR or dur > MAX_DUR:
        return False
    # Kathbath doesn't have a 'quality' column by default; safe check:
    if USE_QUALITY_FILTER and batch.get("quality") and batch.get("quality") != "high":
        return False
    return True

print("🔍 Filtering dataset (by duration)...")
# Use small num_proc to avoid multiprocessing issues in notebooks
train_data = train_data.filter(filter_fn, desc="Filtering Train", num_proc=1)
val_data = val_data.filter(filter_fn, desc="Filtering Validation", num_proc=1)
print(f"✅ After filtering -> train: {len(train_data)} | val: {len(val_data)}")


🔍 Filtering dataset (by duration)...
✅ After filtering -> train: 158609 | val: 6056


In [7]:
# Cell 7
processor = WhisperProcessor.from_pretrained(MODEL_ID, language="hindi", task="transcribe")
# processor includes tokenizer & feature_extractor
print("✅ Whisper processor loaded.")


✅ Whisper processor loaded.


In [8]:
# Cell 8 (REPLACEMENT) — robust audio loader + preprocess
import io
import soundfile as sf
import librosa
import numpy as np

# keep a small failure count so we don't flood the output
_FAILURE_LOG_LIMIT = 5
failure_count = 0

def _load_audio_field(field):
    """
    Accepts:
      - a string file path
      - a dict containing 'bytes' (raw audio bytes) OR 'path' OR 'array' (+ optional sampling_rate)
      - a list/ndarray of samples
    Returns:
      (audio_array (1D np.float32), sr)
    Raises:
      Exception on failure (caller handles/logs)
    """
    # case: path string
    if isinstance(field, str):
        audio_array, sr = sf.read(field, dtype="float32")
        # soundfile returns shape (n,) for mono or (n, channels) for multi-channel
        if audio_array.ndim > 1:
            # average channels to mono
            audio_array = np.mean(audio_array, axis=1)
        return audio_array.astype("float32"), int(sr)

    # case: dict with bytes / path / array
    if isinstance(field, dict):
        # raw bytes stored inside dataset (common when audio embedded)
        if "bytes" in field:
            bio = io.BytesIO(field["bytes"])
            audio_array, sr = sf.read(bio, dtype="float32")
            if audio_array.ndim > 1:
                audio_array = np.mean(audio_array, axis=1)
            return audio_array.astype("float32"), int(sr)
        # some datasets store {'path': '/...'}
        if "path" in field:
            audio_array, sr = sf.read(field["path"], dtype="float32")
            if audio_array.ndim > 1:
                audio_array = np.mean(audio_array, axis=1)
            return audio_array.astype("float32"), int(sr)
        # pre-decoded array (list or numpy)
        if "array" in field:
            arr = np.array(field["array"], dtype="float32")
            sr = field.get("sampling_rate", 16000)
            return arr, int(sr)

    # case: already list/ndarray
    if isinstance(field, (list, np.ndarray)):
        arr = np.array(field, dtype="float32")
        return arr, 16000

    raise ValueError("Unsupported audio field format")

# The prepare_dataset returns a dict with only new columns (so remove_columns works)
def prepare_dataset(example):
    global failure_count
    field = example.get("audio_filepath")  # whatever column name you have

    try:
        audio_array, sr = _load_audio_field(field)
    except Exception as e:
        # don't print raw bytes; give a concise message and count limited occurrences
        failure_count += 1
        if failure_count <= _FAILURE_LOG_LIMIT:
            fname = example.get("fname", "unknown")
            print(f"⚠️ Failed to load audio for {fname}: {str(e)}")
        # return None entries to be filtered out later
        return {"input_features": None, "labels": None}

    # resample to 16k if needed
    if sr != 16000:
        audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=16000)
        sr = 16000

    # create input features — processor.feature_extractor expects raw waveform (1D numpy)
    input_feats = processor.feature_extractor(audio_array, sampling_rate=16000).input_features[0]

    # tokenise the text -> labels
    text = example.get("text", "")
    label_ids = processor.tokenizer(text).input_ids

    return {"input_features": input_feats, "labels": label_ids}

print("🎧 Preprocessing audio samples (robust loader)...")
# Use num_proc=1 in notebooks for stability. Keep remove_columns as before.
train_data = train_data.map(prepare_dataset, remove_columns=train_data.column_names, num_proc=1, desc="Processing Train")
val_data = val_data.map(prepare_dataset, remove_columns=val_data.column_names, num_proc=1, desc="Processing Validation")

# Drop examples where preprocessing failed
train_data = train_data.filter(lambda ex: ex["input_features"] is not None, num_proc=1)
val_data = val_data.filter(lambda ex: ex["input_features"] is not None, num_proc=1)

print(f"✅ Preprocessing done -> train: {len(train_data)} | val: {len(val_data)}")
if failure_count > _FAILURE_LOG_LIMIT:
    print(f"⚠️ Note: {failure_count} total load failures (only first {_FAILURE_LOG_LIMIT} logged).")


🎧 Preprocessing audio samples (robust loader)...
✅ Preprocessing done -> train: 158609 | val: 6056


In [9]:
# Cell 9
print("⚙️ Loading model...")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID)
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.gradient_checkpointing_enable()
model = model.to(device)
print("✅ Model loaded to device.")


⚙️ Loading model...
✅ Model loaded to device.


In [10]:
# Cell 10
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    # preds are token ids from generate
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # If label_ids contain -100 (masked), replace with pad token id for decoding
    label_ids = np.where(label_ids == -100, processor.tokenizer.pad_token_id, label_ids)

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wers, cers = [], []
    sample_metrics = []
    for ps, ls in zip(pred_str, label_str):
        wer = wer_metric.compute(predictions=[ps], references=[ls])
        cer = cer_metric.compute(predictions=[ps], references=[ls])
        wers.append(wer)
        cers.append(cer)
        sample_metrics.append({"prediction": ps, "reference": ls, "wer": wer, "cer": cer})

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    pd.DataFrame(sample_metrics).to_csv(os.path.join(OUTPUT_DIR, "eval_predictions.csv"), index=False)

    return {"wer": float(np.mean(wers)), "cer": float(np.mean(cers))}


In [34]:
# Cell 11 — Custom DataCollator to handle input_features & labels
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Separate labels and inputs
        labels = [f["labels"] for f in features]
        inputs = [{"input_features": f["input_features"]} for f in features]

        batch = self.processor.feature_extractor.pad(
            inputs,
            padding=self.padding,
            return_tensors="pt"
        )

        with self.processor.as_target_processor():
            labels_batch = self.processor.tokenizer.pad(
                {"input_ids": labels},
                padding=self.padding,
                return_tensors="pt"
            )

        batch["labels"] = labels_batch["input_ids"]
        return batch

# Initialize collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)


In [35]:
# Cell 12 — Seq2SeqTrainingArguments compatible with older transformers
from transformers import Seq2SeqTrainingArguments
import torch

training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=6000,
    fp16=torch.cuda.is_available(),
    gradient_checkpointing=True,
    evaluation_strategy="steps",     # compatible
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    logging_steps=100,
    predict_with_generate=True,
    generation_max_length=225,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",  # ⚠ must provide metric
    greater_is_better=False,
    dataloader_num_workers=2,
    group_by_length=False,
    remove_unused_columns=False,
    seed=SEED,
)


TypeError: Seq2SeqTrainingArguments.__init__() got an unexpected keyword argument 'evaluation_strategy'

In [27]:
# Cell 13 — Initialize Trainer with custom collator and sampler
from transformers import Seq2SeqTrainer
from torch.utils.data import RandomSampler
from types import MethodType

# Initialize Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    processing_class=processor,
    compute_metrics=compute_metrics,
)

# Force trainer to use RandomSampler (avoids input_ids length issues)
def _get_train_sampler_override(self, dataset):
    return RandomSampler(dataset)

trainer._get_train_sampler = MethodType(_get_train_sampler_override, trainer)
print("✅ Trainer initialized with custom sampler and collator")


✅ Training arguments configured successfully.


In [29]:
# Cell 14 — Fix: force random sampler
from torch.utils.data import RandomSampler
from types import MethodType

def _get_train_sampler_override(self, dataset):
    return RandomSampler(dataset)

# Override trainer's method
trainer._get_train_sampler = MethodType(_get_train_sampler_override, trainer)

print("✅ Trainer sampler fixed — ready for training!")


✅ Trainer sampler fixed — ready for training!


In [32]:
# Cell 15 — Reinitialize accelerate and attach to trainer
from accelerate import Accelerator
from accelerate.state import AcceleratorState

# Reset any stale state
try:
    AcceleratorState._reset_state()
except Exception:
    pass

# Create fresh Accelerator and attach to trainer
accelerator = Accelerator()
trainer.accelerator = accelerator

# Ensure internal state points to the fresh accelerator
try:
    trainer.accelerator.state = accelerator.state
except Exception:
    pass

print("✅ New Accelerator attached to trainer.")
print("Accelerator device:", accelerator.device)
print("Distributed type:", getattr(accelerator.state, "distributed_type", "<not set>"))


✅ New Accelerator attached to trainer.
Accelerator device: cuda
Distributed type: DistributedType.NO


In [33]:
# Cell 16 — Dry-run 1-step
print("Running 1-step dry-run to validate pipeline after reinitializing accelerate...")
trainer.args.max_steps = 1
trainer.train()
print("✅ Dry-run finished successfully.")


Running 1-step dry-run to validate pipeline after reinitializing accelerate...


  0%|          | 0/1 [00:00<?, ?it/s]

ValueError: You should supply an encoding or a list of encodings to this method that includes input_ids, but you provided ['input_features']

In [31]:
# Cell 18 — Full Whisper fine-tuning
print("🚀 Starting full training...")
trainer.train()

print("📈 Evaluating model...")
metrics = trainer.evaluate()
print("✅ Final Metrics:", metrics)

print("💾 Saving model + processor...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)
print(f"🎯 Model and processor saved at {OUTPUT_DIR}")


🚀 Starting full training...


AttributeError: `AcceleratorState` object has no attribute `distributed_type`. This happens if `AcceleratorState._reset_state()` was called and an `Accelerator` or `PartialState` was not reinitialized.