# Fine-Tuning OpenAI Whisper-base model on Hindi ASR Dataset

**Note** : I have used L4 GPU for training the model. These models require a lot of GPU memory. So you can use Google Colab or any other cloud service to train the model. I used Lightning AI cloud platform for training the model. Lightning AI provides some free credits to train on some good GPU's. [Lighting AI](https://lightning.ai/)

In [1]:
%pip install datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio huggingface_hub 

zsh:1: no matches found: datasets[audio]
Note: you may need to restart the kernel to use updated packages.


In [None]:
from huggingface_hub import login

# Paste your Hugging Face token here (you can get it from https://huggingface.co/settings/tokens)
login(token="your token here")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/zeus/.cache/huggingface/token
Login successful


In [3]:
%pip install accelerate -U

Note: you may need to restart the kernel to use updated packages.


In [4]:
%pip install transformers[torch]

zsh:1: no matches found: transformers[torch]
Note: you may need to restart the kernel to use updated packages.


In [5]:
!pip install librosa jiwer



In [6]:
from datasets import load_dataset, DatasetDict
from transformers import (
    WhisperTokenizer,
    WhisperProcessor,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
from datasets import Audio
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch
import evaluate
import librosa

In [None]:
model_id = 'openai/whisper-base'                  # you can change this to any other Whisper model tiny, base, small, medium, large
out_dir = '/teamspace/studios/this_studio/whisper_base_hindi'
epochs = 10
batch_size = 32

In [None]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("SPRINGLab/IndicTTS-Hindi", split="train")

# Shuffle the dataset for randomness
dataset = dataset.shuffle(seed=42)

# First split: 6k samples (5k train + 1k valid), rest discarded
dataset_small = dataset.select(range(6000))         # Only fine-tuning on 5k samples for faster training

# Second split: train = 5k, valid = 1k
train_valid_split = dataset_small.train_test_split(test_size=1000, seed=42)

train_dataset = train_valid_split['train']
valid_dataset = train_valid_split['test']

Loading dataset shards:   0%|          | 0/17 [00:00<?, ?it/s]

In [9]:
print(f"Train size: {len(train_dataset)}")
print(f"Valid size: {len(valid_dataset)}")

Train size: 5000
Valid size: 1000


In [10]:
train_dataset

Dataset({
    features: ['audio', 'text', 'gender'],
    num_rows: 5000
})

In [11]:
train_dataset[0]

{'audio': {'path': 'train_hindifullfemale_03648.wav',
  'array': array([0., 0., 0., ..., 0., 0., 0.]),
  'sampling_rate': 48000},
 'text': 'उन्होंने महसूस किया, कि वे काफ़ी भावुक हो गईं थी, और अपनी पोती सारा से कहा, कि उनकी तस्वीर, फ़ेसबुक पर पोस्ट कर दें.',
 'gender': 0}

In [None]:
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)

tokenizer = WhisperTokenizer.from_pretrained(model_id, language='Hindi', task='transcribe')   # change the language to accordingly

processor = WhisperProcessor.from_pretrained(model_id, language='Hindi', task='transcribe')


*Whisper models are pretrained on 16KHz, so you must preprocess your audio files to 16KHz.*

In [None]:
hindi_dataset_train = train_dataset.cast_column('audio', Audio(sampling_rate=16000))       
hindi_dataset_valid = valid_dataset.cast_column('audio', Audio(sampling_rate=16000))

In [None]:
def prepare_dataset(batch):
    audio = batch['audio']
    batch['input_features'] = feature_extractor(
        audio['array'],
        sampling_rate=audio['sampling_rate']
    ).input_features[0]

    # Truncate labels to max decoder length = 448
    batch['labels'] = tokenizer(
        batch['text'],
        padding='max_length',
        truncation=True,
        max_length=448,   # ✅ Max target length for Whisper decoder
        return_tensors='pt'
    ).input_ids[0].tolist()

    return batch

# Mapping take a lot of time to run, so we can use multiprocessing to speed it up
# Note: If you have a lot of CPU cores, you can increase num_proc to speed up the mapping process 
hindi_dataset_train = hindi_dataset_train.map(
    prepare_dataset,
    num_proc=2
)

hindi_dataset_valid = hindi_dataset_valid.map(
    prepare_dataset,
    num_proc=2
)

Map (num_proc=2):   0%|          | 0/5000 [00:00<?, ? examples/s]

Map (num_proc=2):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [15]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{'input_features': feature['input_features']} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors='pt')

        label_features = [{'input_ids': feature['labels']} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors='pt')

        labels = labels_batch['input_ids'].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch['labels'] = labels

        return batch

Data Collator does the following things:
- Converting the audio input features and the tokenized transcriptions to PyTorch tensors.
- Changing all the masked tokens to -100 so the loss function ignores them while loss calculation.

In [16]:
model = WhisperForConditionalGeneration.from_pretrained(model_id)

model.generation_config.task = 'transcribe'

model.generation_config.forced_decoder_ids = None

In [17]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [None]:
metric = evaluate.load('wer')   # define the metric to be used for evaluation

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {'wer': wer}

In [None]:
training_args = Seq2SeqTrainingArguments(                      # Training arguments for the model
    output_dir=out_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=1,
    learning_rate=0.00001,
    warmup_steps=1000,
    bf16=False,
    fp16=True,
    num_train_epochs=epochs,
    eval_strategy ='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    predict_with_generate=True,
    generation_max_length=225,
    report_to=['tensorboard'],
    load_best_model_at_end=True,
    metric_for_best_model='wer',
    greater_is_better=False,
    dataloader_num_workers=8,
    save_total_limit=2,
    lr_scheduler_type='constant',
    seed=42,
    data_seed=42
)

In [20]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=hindi_dataset_train,
    eval_dataset=hindi_dataset_valid,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [None]:
trainer.train()   # Train the model

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss,Wer
1,0.1362,0.061018,35.979045
2,0.0477,0.047074,30.17693
3,0.0346,0.04156,27.844223
4,0.0268,0.03854,26.094692
5,0.0208,0.037542,25.600474
6,0.0162,0.038051,25.234753
7,0.0124,0.03833,25.051893
8,0.0093,0.039591,24.587328
9,0.0068,0.042442,24.839379
10,0.005,0.043711,24.829495


You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50

TrainOutput(global_step=1570, training_loss=0.031567544770088926, metrics={'train_runtime': 3287.2143, 'train_samples_per_second': 15.21, 'train_steps_per_second': 0.478, 'total_flos': 3.242999808e+18, 'train_loss': 0.031567544770088926, 'epoch': 10.0})

# Inference of the model

In [26]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import librosa

# Load from base model for processor, and fine-tuned model for model weights
base_model_path = "openai/whisper-base"  # or your original base: e.g., "whisper_base_hindi"
fine_tuned_path = "/teamspace/studios/this_studio/whisper_base_hindi/checkpoint-1570/"

processor = WhisperProcessor.from_pretrained(base_model_path)
model = WhisperForConditionalGeneration.from_pretrained(fine_tuned_path)
model.eval()


WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(512, 512, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 512)
      (layers): ModuleList(
        (0-5): 6 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=False)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          

In [27]:
# Load audio with librosa (mono, 16kHz)
audio_file = "/teamspace/studios/this_studio/man_sound.wav"
waveform, _ = librosa.load(audio_file, sr=16000)

# Preprocess
input_features = processor(waveform, sampling_rate=16000, return_tensors="pt").input_features

In [28]:
# Predict
with torch.no_grad():
    predicted_ids = model.generate(input_features)

# Decode
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
print("Transcription:", transcription)

Transcription: चुकि उनका मन जो आप बी जानती है कि छूटे वच्चे रखते हैं, चन्जर वच्य होते.
