<a href="https://colab.research.google.com/github/kavish-24/Konkani_Mentall_Health/blob/main/WhiperFineTuneSmall.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install transformers datasets evaluate jiwer torch torchaudio accelerate tensorboard

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading rapidfuzz-3.14.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m58.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer, evaluate
Successfully installed evaluate-0.4.6 jiwer-4.0.0 rapidfuzz-3.14.2


In [8]:
import os
from transformers import WhisperTokenizer

# Load tokenizer (use whisper-small since you're fine-tuning it)
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")

# Path where your transcripts are stored
transcript_dir = "/content/drive/MyDrive/training/10 Aug"  # <-- change to your directory

for filename in os.listdir(transcript_dir):
    if filename.endswith(".txt"):  # process only text files
        filepath = os.path.join(transcript_dir, filename)

        # Read the transcript file
        with open(filepath, "r", encoding="utf-8") as f:
            text = f.read().strip()

        # Tokenize and count
        tokens = tokenizer(text, return_tensors="pt").input_ids
        token_count = tokens.shape[1]

        print(f"{filename}: {token_count} tokens")


News_100817_segment_014_transcript.txt: 418 tokens
News_100817_segment_003_transcript.txt: 521 tokens
News_100817_segment_001_transcript.txt: 421 tokens
News_100817_segment_012_transcript.txt: 441 tokens
News_100817_segment_007_transcript.txt: 406 tokens
News_100817_segment_005_transcript.txt: 403 tokens
News_100817_segment_010_transcript.txt: 249 tokens
News_100817_segment_011_transcript.txt: 401 tokens
News_100817_segment_015_transcript.txt: 92 tokens
News_100817_segment_002_transcript.txt: 47 tokens
News_100817_segment_008_transcript.txt: 395 tokens
News_100817_segment_009_transcript.txt: 135 tokens
News_100817_segment_013_transcript.txt: 388 tokens
News_100817_segment_006_transcript.txt: 440 tokens
News_100817_segment_004_transcript.txt: 242 tokens


In [7]:
"""
Fine-tune Whisper Small model for Konkani language with Marathi support
"""

import os
import torch
from datasets import Dataset, Audio
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
import numpy as np
import librosa

# Configuration
AUDIO_DIR = "/content/drive/MyDrive/Anju Project (1)/Audio Prudent media (1)/audio segment/10 Aug"  # Update with your audio directory
TRANSCRIPT_DIR = "/content/drive/MyDrive/training/10 Aug"  # Update with your transcript directory
MODEL_NAME = "openai/whisper-small"
OUTPUT_DIR = "/content/drive/MyDrive/whisper-small-konkani"
LANGUAGE = "konkani"
TASK = "transcribe"

# Training mode
CONTINUE_FROM_CHECKPOINT = False  # Set to True to continue training from existing model

# Training parameters
BATCH_SIZE = 8  # Adjust based on your GPU memory
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 1e-5
WARMUP_STEPS = 500
MAX_STEPS = 5000
EVAL_STEPS = 500
SAVE_STEPS = 500


def load_audio_transcript_pairs(audio_dir, transcript_dir):
    """
    Load audio files and their corresponding transcripts
    """
    data = []

    # Get all audio files
    audio_files = [f for f in os.listdir(audio_dir) if f.endswith('.mp3')]

    for audio_file in audio_files:
        # Extract the base name to match with transcript
        # From "Konkani Prime News_100817_segment_001.mp3"
        # to "News_100817_segment_001_transcript.txt"
        base_name = audio_file.replace("Konkani Prime ", "").replace(".mp3", "")
        transcript_file = f"{base_name}_transcript.txt"

        audio_path = os.path.join(audio_dir, audio_file)
        transcript_path = os.path.join(transcript_dir, transcript_file)

        # Check if transcript exists
        if os.path.exists(transcript_path):
            with open(transcript_path, 'r', encoding='utf-8') as f:
                transcript = f.read().strip()

            data.append({
                'audio': audio_path,
                'sentence': transcript
            })
        else:
            print(f"Warning: Transcript not found for {audio_file}")

    print(f"Loaded {len(data)} audio-transcript pairs")
    return data


def load_and_resample_audio(audio_path, target_sr=16000):
    """
    Load and resample audio file to target sample rate
    """
    try:
        audio_array, sr = librosa.load(audio_path, sr=target_sr)
        return audio_array
    except Exception as e:
        print(f"Error loading {audio_path}: {e}")
        return None


def prepare_dataset(audio_dir, transcript_dir, test_size=0.1):
    """
    Prepare dataset from audio and transcript directories
    """
    # Load data
    data = load_audio_transcript_pairs(audio_dir, transcript_dir)

    # Create dataset with audio paths (not loaded yet)
    dataset = Dataset.from_dict({
        'audio_path': [item['audio'] for item in data],
        'sentence': [item['sentence'] for item in data]
    })

    # Split into train and test
    dataset = dataset.train_test_split(test_size=test_size, seed=42)

    return dataset


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator for speech-to-text models
    """
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Get labels
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 to ignore in loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Remove BOS token if present
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


def prepare_data(batch, processor):
    """
    Prepare data for training
    """
    # Load and resample audio using librosa
    audio_array = load_and_resample_audio(batch["audio_path"], target_sr=16000)

    if audio_array is None:
        # Return empty features if audio loading fails
        batch["input_features"] = np.zeros((80, 3000))
        batch["labels"] = []
        return batch

    # Compute input features
    batch["input_features"] = processor.feature_extractor(
        audio_array, sampling_rate=16000
    ).input_features[0]

    # Encode target text
    batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids

    return batch


def compute_metrics(pred, processor, metric):
    """
    Compute WER metric
    """
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with pad token
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # Decode predictions and labels
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Compute WER
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}


def main():
    """
    Main training function
    """
    # Determine which model to load
    if CONTINUE_FROM_CHECKPOINT and os.path.exists(OUTPUT_DIR):
        print(f"Loading model from existing checkpoint: {OUTPUT_DIR}")
        model_path = OUTPUT_DIR
        # Check if it's a valid checkpoint
        if not os.path.exists(os.path.join(OUTPUT_DIR, "config.json")):
            print(f"Warning: No valid checkpoint found in {OUTPUT_DIR}, starting from base model")
            model_path = MODEL_NAME
    else:
        print(f"Starting fresh from base model: {MODEL_NAME}")
        model_path = MODEL_NAME

    print("Loading model and processor...")

    # Load feature extractor, tokenizer, and processor
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
    tokenizer = WhisperTokenizer.from_pretrained(
        model_path,
        language=LANGUAGE,
        task=TASK
    )
    processor = WhisperProcessor.from_pretrained(
        model_path,
        language=LANGUAGE,
        task=TASK
    )

    # Load model
    model = WhisperForConditionalGeneration.from_pretrained(model_path)

    # Configure model for Konkani
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []
    model.config.use_cache = False

    # Set language and task
    model.generation_config.language = LANGUAGE
    model.generation_config.task = TASK

    print("Preparing dataset...")
    dataset = prepare_dataset(AUDIO_DIR, TRANSCRIPT_DIR)

    # Prepare dataset
    print("Processing audio and text...")
    dataset = dataset.map(
        lambda batch: prepare_data(batch, processor),
        remove_columns=dataset["train"].column_names,
        num_proc=1  # Changed to 1 to avoid multiprocessing issues with audio loading
    )

    # Data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # Metric
    metric = evaluate.load("wer")

    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        max_steps=MAX_STEPS,
        gradient_checkpointing=True,
        fp16=True,
        eval_strategy="steps",
        per_device_eval_batch_size=BATCH_SIZE,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=SAVE_STEPS,
        eval_steps=EVAL_STEPS,
        logging_steps=100,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        push_to_hub=False,
    )

    # Trainer
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=data_collator,
        compute_metrics=lambda pred: compute_metrics(pred, processor, metric),
        tokenizer=processor.feature_extractor,
    )

    print("Starting training...")
    trainer.train()

    # Save final model
    print("Saving model...")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)

    print(f"Training complete! Model saved to {OUTPUT_DIR}")


if __name__ == "__main__":
    main()

Starting fresh from base model: openai/whisper-small
Loading model and processor...
Preparing dataset...
Loaded 15 audio-transcript pairs
Processing audio and text...


Map:   0%|          | 0/13 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

  trainer = Seq2SeqTrainer(


Starting training...


You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


ValueError: Labels' sequence length 521 cannot exceed the maximum allowed length of 448 tokens.