<a href="https://colab.research.google.com/github/kavish-24/Konkani_Mental_Health/blob/main/WhiperFineTuneSmall.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
pip install transformers datasets evaluate jiwer torch torchaudio accelerate tensorboard

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m80.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer, evaluate
Successfully installed evaluate-0.4.6 jiwer-4.0.0 rapidfuzz-3.14.3


In [None]:
import os
from transformers import WhisperTokenizer

# Load tokenizer (use whisper-small since you're fine-tuning it)
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")

# Path where your transcripts are stored
transcript_dir = "/content/drive/MyDrive/training/10 Aug"  # <-- change to your directory

for filename in os.listdir(transcript_dir):
    if filename.endswith(".txt"):  # process only text files
        filepath = os.path.join(transcript_dir, filename)

        # Read the transcript file
        with open(filepath, "r", encoding="utf-8") as f:
            text = f.read().strip()

        # Tokenize and count
        tokens = tokenizer(text, return_tensors="pt").input_ids
        token_count = tokens.shape[1]

        print(f"{filename}: {token_count} tokens")


News_100817_segment_014_transcript.txt: 418 tokens
News_100817_segment_003_transcript.txt: 521 tokens
News_100817_segment_001_transcript.txt: 421 tokens
News_100817_segment_012_transcript.txt: 441 tokens
News_100817_segment_007_transcript.txt: 406 tokens
News_100817_segment_005_transcript.txt: 403 tokens
News_100817_segment_010_transcript.txt: 249 tokens
News_100817_segment_011_transcript.txt: 401 tokens
News_100817_segment_015_transcript.txt: 92 tokens
News_100817_segment_002_transcript.txt: 47 tokens
News_100817_segment_008_transcript.txt: 395 tokens
News_100817_segment_009_transcript.txt: 135 tokens
News_100817_segment_013_transcript.txt: 388 tokens
News_100817_segment_006_transcript.txt: 440 tokens
News_100817_segment_004_transcript.txt: 242 tokens


In [None]:
"""
Fine-tune Whisper Small model for Konkani language with Marathi support
With support for stopping and resuming training
FIXED: Gradient accumulation and checkpointing issues
"""

import os
import torch
from datasets import Dataset, Audio
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
import numpy as np
import librosa
import json
from pathlib import Path

# Configuration
AUDIO_DIR = "/content/drive/MyDrive/Anju Project (1)/Audio_segments"
TRANSCRIPT_DIR = "/content/drive/MyDrive/Anju Project (1)/training_through_bot"
MODEL_NAME = "openai/whisper-small"
OUTPUT_DIR = "/content/drive/MyDrive/whisper-small-konkani"
LANGUAGE = "konkani"
TASK = "transcribe"

# Training mode - Set this to continue from last checkpoint
RESUME_TRAINING = True  # Set to True to resume from last checkpoint
START_FRESH = False     # Set to True to ignore existing checkpoints and start over

# Training parameters
BATCH_SIZE = 8  # Adjust based on your GPU memory
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 1e-5
WARMUP_STEPS = 500
MAX_STEPS = 10000  # Total training steps
EVAL_STEPS = 500
SAVE_STEPS = 500
SAVE_TOTAL_LIMIT = 3  # Keep only last 3 checkpoints to save space


def get_latest_checkpoint(output_dir):
    """
    Find the latest checkpoint in the output directory
    """
    checkpoint_dirs = []
    if os.path.exists(output_dir):
        for item in os.listdir(output_dir):
            item_path = os.path.join(output_dir, item)
            if os.path.isdir(item_path) and item.startswith("checkpoint-"):
                checkpoint_dirs.append(item_path)

    if not checkpoint_dirs:
        return None

    # Sort by checkpoint number
    checkpoint_dirs.sort(key=lambda x: int(x.split("-")[-1]))
    latest_checkpoint = checkpoint_dirs[-1]

    print(f"Found latest checkpoint: {latest_checkpoint}")
    return latest_checkpoint


def load_audio_transcript_pairs(audio_dir, transcript_dir):
    """
    Load audio files and their corresponding transcripts
    Recursively searches through all subdirectories
    """
    data = []
    audio_dir_path = Path(audio_dir)
    transcript_dir_path = Path(transcript_dir)

    # Find all audio files recursively
    audio_files = list(audio_dir_path.rglob("*.mp3"))

    print(f"Found {len(audio_files)} audio files")

    for audio_path in audio_files:
        # Get relative path structure
        rel_path = audio_path.relative_to(audio_dir_path)

        # Create corresponding transcript path
        # Change extension from .mp3 to .txt
        transcript_rel_path = rel_path.with_suffix('.txt')
        transcript_path = transcript_dir_path / transcript_rel_path

        # Check if transcript exists
        if transcript_path.exists():
            try:
                with open(transcript_path, 'r', encoding='utf-8') as f:
                    transcript = f.read().strip()

                if transcript:  # Only add if transcript is not empty
                    data.append({
                        'audio': str(audio_path),
                        'sentence': transcript
                    })
            except Exception as e:
                print(f"Error reading transcript {transcript_path}: {e}")
        else:
            print(f"Warning: Transcript not found for {audio_path}")
            print(f"  Expected at: {transcript_path}")

    print(f"Successfully loaded {len(data)} audio-transcript pairs")
    return data


def load_and_resample_audio(audio_path, target_sr=16000):
    """
    Load and resample audio file to target sample rate
    """
    try:
        audio_array, sr = librosa.load(audio_path, sr=target_sr)
        return audio_array
    except Exception as e:
        print(f"Error loading {audio_path}: {e}")
        return None


def prepare_dataset(audio_dir, transcript_dir, test_size=0.1):
    """
    Prepare dataset from audio and transcript directories
    """
    # Load data
    data = load_audio_transcript_pairs(audio_dir, transcript_dir)

    if not data:
        raise ValueError("No audio-transcript pairs found! Please check your directory paths.")

    # Create dataset with audio paths (not loaded yet)
    dataset = Dataset.from_dict({
        'audio_path': [item['audio'] for item in data],
        'sentence': [item['sentence'] for item in data]
    })

    # Split into train and test with fixed seed for reproducibility
    dataset = dataset.train_test_split(test_size=test_size, seed=42)

    print(f"Train set: {len(dataset['train'])} samples")
    print(f"Test set: {len(dataset['test'])} samples")

    return dataset


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator for speech-to-text models
    """
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Split inputs and labels
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Get labels
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 to ignore in loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Remove BOS token if present
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch


def prepare_data(batch, processor, max_label_length=448):
    """
    Prepare data for training
    Skips samples with labels exceeding max_label_length
    """
    # Load and resample audio using librosa
    audio_array = load_and_resample_audio(batch["audio_path"], target_sr=16000)

    if audio_array is None:
        # Mark as invalid
        batch["input_features"] = None
        batch["labels"] = None
        batch["is_valid"] = False
        batch["skip_reason"] = "audio_load_failed"
        return batch

    # Encode target text first to check length
    labels = processor.tokenizer(batch["sentence"]).input_ids

    # Check if labels exceed maximum length
    if len(labels) > max_label_length:
        batch["input_features"] = None
        batch["labels"] = None
        batch["is_valid"] = False
        batch["skip_reason"] = f"labels_too_long_{len(labels)}_tokens"
        return batch

    # Compute input features
    batch["input_features"] = processor.feature_extractor(
        audio_array, sampling_rate=16000
    ).input_features[0]

    batch["labels"] = labels
    batch["is_valid"] = True
    batch["skip_reason"] = None

    return batch


def compute_metrics(pred, processor, metric):
    """
    Compute WER metric
    """
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Replace -100 with pad token
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # Decode predictions and labels
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Compute WER
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}


def main():
    """
    Main training function with resume capability
    """
    # Determine starting point
    checkpoint_to_resume = None
    model_path = MODEL_NAME

    if START_FRESH:
        print("=" * 60)
        print("STARTING FRESH TRAINING")
        print("Ignoring any existing checkpoints")
        print("=" * 60)
        model_path = MODEL_NAME
    elif RESUME_TRAINING:
        # Look for latest checkpoint
        latest_checkpoint = get_latest_checkpoint(OUTPUT_DIR)

        if latest_checkpoint:
            print("=" * 60)
            print("RESUMING TRAINING FROM CHECKPOINT")
            print(f"Checkpoint: {latest_checkpoint}")
            print("=" * 60)
            checkpoint_to_resume = latest_checkpoint
            model_path = latest_checkpoint
        else:
            # Check if there's a saved model in OUTPUT_DIR
            if os.path.exists(OUTPUT_DIR) and os.path.exists(os.path.join(OUTPUT_DIR, "config.json")):
                print("=" * 60)
                print("RESUMING FROM SAVED MODEL")
                print(f"Model directory: {OUTPUT_DIR}")
                print("=" * 60)
                model_path = OUTPUT_DIR
            else:
                print("=" * 60)
                print("NO CHECKPOINT FOUND - STARTING FRESH")
                print("=" * 60)
                model_path = MODEL_NAME
    else:
        print("=" * 60)
        print("STARTING NEW TRAINING")
        print("=" * 60)
        model_path = MODEL_NAME

    print("\nLoading model and processor...")

    # Load feature extractor, tokenizer, and processor
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
    tokenizer = WhisperTokenizer.from_pretrained(
        model_path,
        language=LANGUAGE,
        task=TASK
    )
    processor = WhisperProcessor.from_pretrained(
        model_path,
        language=LANGUAGE,
        task=TASK
    )

    # Load model
    model = WhisperForConditionalGeneration.from_pretrained(model_path)

    # Configure model for Konkani
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    # CRITICAL FIX: Set use_cache to False for training
    model.config.use_cache = False

    # Ensure model is in training mode
    model.train()

    # Set language and task
    model.generation_config.language = LANGUAGE
    model.generation_config.task = TASK

    print("\nPreparing dataset...")
    dataset = prepare_dataset(AUDIO_DIR, TRANSCRIPT_DIR)

    # Prepare dataset
    print("\nProcessing audio and text...")

    # Store original audio paths for reporting
    train_audio_paths = dataset["train"]["audio_path"]
    test_audio_paths = dataset["test"]["audio_path"]

    dataset = dataset.map(
        lambda batch: prepare_data(batch, processor, max_label_length=448),
        remove_columns=dataset["train"].column_names,
        num_proc=1  # Changed to 1 to avoid multiprocessing issues
    )

    # Filter out invalid samples and collect skipped files
    print("\nFiltering valid samples...")
    skipped_files = {"train": [], "test": []}

    # Process train set
    train_valid_indices = []
    for idx, (is_valid, skip_reason, audio_path) in enumerate(zip(
        dataset["train"]["is_valid"],
        dataset["train"]["skip_reason"],
        train_audio_paths
    )):
        if is_valid:
            train_valid_indices.append(idx)
        else:
            skipped_files["train"].append({
                "path": audio_path,
                "reason": skip_reason
            })

    # Process test set
    test_valid_indices = []
    for idx, (is_valid, skip_reason, audio_path) in enumerate(zip(
        dataset["test"]["is_valid"],
        dataset["test"]["skip_reason"],
        test_audio_paths
    )):
        if is_valid:
            test_valid_indices.append(idx)
        else:
            skipped_files["test"].append({
                "path": audio_path,
                "reason": skip_reason
            })

    # Filter datasets
    dataset["train"] = dataset["train"].select(train_valid_indices)
    dataset["test"] = dataset["test"].select(test_valid_indices)

    # Remove helper columns
    dataset = dataset.remove_columns(["is_valid", "skip_reason"])

    print(f"\nFiltered dataset:")
    print(f"Train set: {len(dataset['train'])} samples (skipped {len(skipped_files['train'])})")
    print(f"Test set: {len(dataset['test'])} samples (skipped {len(skipped_files['test'])})")

    # Report skipped files
    total_skipped = len(skipped_files["train"]) + len(skipped_files["test"])
    if total_skipped > 0:
        print(f"\n{'=' * 60}")
        print(f"SKIPPED {total_skipped} FILES")
        print(f"{'=' * 60}")

        # Group by reason
        skip_reasons = {}
        for split in ["train", "test"]:
            for item in skipped_files[split]:
                reason = item["reason"]
                if reason not in skip_reasons:
                    skip_reasons[reason] = []
                skip_reasons[reason].append(item["path"])

        for reason, paths in skip_reasons.items():
            print(f"\n{reason}: {len(paths)} files")
            for path in paths[:5]:  # Show first 5
                print(f"  - {path}")
            if len(paths) > 5:
                print(f"  ... and {len(paths) - 5} more")

        # Save full list to file
        skip_report_path = os.path.join(OUTPUT_DIR, "skipped_files_report.json")
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        with open(skip_report_path, 'w', encoding='utf-8') as f:
            json.dump(skipped_files, f, indent=2, ensure_ascii=False)
        print(f"\nFull report saved to: {skip_report_path}")
        print(f"{'=' * 60}\n")

    # Data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # Metric
    metric = evaluate.load("wer")

    # Training arguments with FIXES
    training_args = Seq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        max_steps=MAX_STEPS,

        # CRITICAL FIXES for gradient issues
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},  # FIX: Use non-reentrant checkpointing

        fp16=True,
        eval_strategy="steps",
        per_device_eval_batch_size=BATCH_SIZE,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=SAVE_STEPS,
        eval_steps=EVAL_STEPS,
        logging_steps=100,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        push_to_hub=False,
        save_total_limit=SAVE_TOTAL_LIMIT,

        # Additional fixes
        dataloader_num_workers=0,  # FIX: Avoid multiprocessing issues
        remove_unused_columns=False,  # Keep all columns

        # Resume from checkpoint
        resume_from_checkpoint=checkpoint_to_resume,
    )

    # Trainer
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=data_collator,
        compute_metrics=lambda pred: compute_metrics(pred, processor, metric),
        processing_class=processor.feature_extractor,  # Updated parameter name
    )

    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)
    print(f"Total steps: {MAX_STEPS}")
    print(f"Save every: {SAVE_STEPS} steps")
    print(f"Evaluate every: {EVAL_STEPS} steps")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
    print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
    print("=" * 60 + "\n")

    # Train with resume capability
    trainer.train(resume_from_checkpoint=checkpoint_to_resume)

    # Save final model
    print("\nSaving final model...")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)

    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print(f"Model saved to: {OUTPUT_DIR}")
    print("=" * 60)


if __name__ == "__main__":
    main()

NO CHECKPOINT FOUND - STARTING FRESH

Loading model and processor...

Preparing dataset...
Found 1100 audio files
  Expected at: /content/drive/MyDrive/Anju Project (1)/training_through_bot/Aug/Konkani Prime News_100817/Konkani Prime News_100817_segment_008.txt
  Expected at: /content/drive/MyDrive/Anju Project (1)/training_through_bot/July/Konkani Update news_170717/Konkani Update news_170717_segment_002.txt
  Expected at: /content/drive/MyDrive/Anju Project (1)/training_through_bot/July/Konkani Update news_170717/Konkani Update news_170717_segment_006.txt
  Expected at: /content/drive/MyDrive/Anju Project (1)/training_through_bot/July/konkani Prime news_030717/konkani Prime news_030717_segment_006.txt
  Expected at: /content/drive/MyDrive/Anju Project (1)/training_through_bot/July/konkani Prime news_030717/konkani Prime news_030717_segment_012.txt
  Expected at: /content/drive/MyDrive/Anju Project (1)/training_through_bot/July/konkani Prime news_030717/konkani Prime news_030717_segme

Map:   0%|          | 0/900 [00:00<?, ? examples/s]

Map:   0%|          | 0/101 [00:00<?, ? examples/s]


Filtering valid samples...

Filtered dataset:
Train set: 889 samples (skipped 11)
Test set: 100 samples (skipped 1)

SKIPPED 12 FILES

labels_too_long_459_tokens: 1 files
  - /content/drive/MyDrive/Anju Project (1)/Audio_segments/November/konkani prime 15 nov 17/konkani prime 15 nov 17_segment_001.mp3

labels_too_long_457_tokens: 1 files
  - /content/drive/MyDrive/Anju Project (1)/Audio_segments/October/9th Oct 17_Konk Prime News/9th Oct 17_Konk Prime News_segment_016.mp3

labels_too_long_627_tokens: 1 files
  - /content/drive/MyDrive/Anju Project (1)/Audio_segments/May/Konk Prime News_040517/Konk Prime News_040517_segment_009.mp3

labels_too_long_469_tokens: 1 files
  - /content/drive/MyDrive/Anju Project (1)/Audio_segments/May/Konk Prime_010517/Konk Prime_010517_segment_016.mp3

labels_too_long_462_tokens: 1 files
  - /content/drive/MyDrive/Anju Project (1)/Audio_segments/October/30th Oct 17_Konk Prime News/30th Oct 17_Konk Prime News_segment_002.mp3

labels_too_long_488_tokens: 1 f

You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss


In [3]:
!apt-get install tree -y
!tree -a "/content/drive/MyDrive/Anju Project (1)/Audio_segments"


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tree is already the newest version (2.0.2-1).
0 upgraded, 0 newly installed, 0 to remove and 41 not upgraded.
[01;34m/content/drive/MyDrive/Anju Project (1)/Audio_segments[0m
├── [01;34mAug[0m
│   ├── [01;34mKonkani Prime News_070817[0m
│   │   ├── [01;35mKonkani Prime News_070817_segment_001.mp3[0m
│   │   ├── [01;35mKonkani Prime News_070817_segment_002.mp3[0m
│   │   ├── [01;35mKonkani Prime News_070817_segment_003.mp3[0m
│   │   ├── [01;35mKonkani Prime News_070817_segment_004.mp3[0m
│   │   ├── [01;35mKonkani Prime News_070817_segment_005.mp3[0m
│   │   ├── [01;35mKonkani Prime News_070817_segment_006.mp3[0m
│   │   ├── [01;35mKonkani Prime News_070817_segment_007.mp3[0m
│   │   ├── [01;35mKonkani Prime News_070817_segment_008.mp3[0m
│   │   ├── [01;35mKonkani Prime News_070817_segment_009.mp3[0m
│   │   ├── [01;35mKonkani Prime News_070817_segment_010.mp3[0m


In [1]:
!apt-get install tree -y
!tree -a "/content/drive/MyDrive/Anju Project (1)/training_through_bot"


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following NEW packages will be installed:
  tree
0 upgraded, 1 newly installed, 0 to remove and 41 not upgraded.
Need to get 47.9 kB of archives.
After this operation, 116 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 tree amd64 2.0.2-1 [47.9 kB]
Fetched 47.9 kB in 1s (54.9 kB/s)
Selecting previously unselected package tree.
(Reading database ... 121713 files and directories currently installed.)
Preparing to unpack .../tree_2.0.2-1_amd64.deb ...
Unpacking tree (2.0.2-1) ...
Setting up tree (2.0.2-1) ...
Processing triggers for man-db (2.10.2-1) ...
[01;34m/content/drive/MyDrive/Anju Project (1)/training_through_bot[0m
├── [01;34mAug[0m
│   ├── [01;34mKonkani Prime News_070817[0m
│   │   ├── [00mKonkani Prime News_070817_segment_001.txt[0m
│   │   ├── [00mKonkani Prime News_070817_segment_002.txt[0m
│   │   ├── [00mKonka