# **1: Install required packages**

In [None]:
!pip install transformers==4.45.2 datasets==3.0.1 evaluate==0.4.3 jiwer==3.0.4 soundfile==0.12.1 librosa==0.10.2 accelerate==0.34.2 --quiet

# **2: Import libraries**

In [None]:
import os

import evaluate
import librosa
import numpy as np
import soundfile as sf
import torch
from datasets import Audio, load_dataset
from transformers import Trainer, TrainingArguments
from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer, WhisperFeatureExtractor

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print("Device:", device)
# Return cached-but-unused CUDA allocator memory (e.g. from earlier runs in this kernel).
torch.cuda.empty_cache()  # Clear CUDA memory

# **3: Download and prepare data (small size for free Colab)**

In [None]:
train_dataset = load_dataset("mozilla-foundation/common_voice_11_0", "fa", split="train[:500]")
eval_dataset = load_dataset("mozilla-foundation/common_voice_11_0", "fa", split="validation[:50]")

def normalize_audio(example):
    audio = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    if sr != 16000:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
    example["audio"]["array"] = audio
    example["audio"]["sampling_rate"] = 16000
    return example

train_dataset = train_dataset.map(normalize_audio)
eval_dataset = eval_dataset.map(normalize_audio)

def prepare_text(example):
    example["sentence"] = example["sentence"].strip()
    return example

train_dataset = train_dataset.map(prepare_text)
eval_dataset = eval_dataset.map(prepare_text)

# Filter invalid examples
def filter_invalid(example):
    return len(example["audio"]["array"]) > 0 and len(example["sentence"]) > 0

train_dataset = train_dataset.filter(filter_invalid)
eval_dataset = eval_dataset.filter(filter_invalid)

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="fa", task="transcribe")

def prepare_dataset(example):
    audio = example["audio"]
    features = feature_extractor(audio["array"], sampling_rate=16000).input_features[0]
    labels = tokenizer(example["sentence"], add_special_tokens=True).input_ids
    example["input_features"] = features
    example["labels"] = labels
    return example

train_dataset = train_dataset.map(prepare_dataset, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(prepare_dataset, remove_columns=eval_dataset.column_names)

# Debug: Inspect dataset
print("Train sample shape:", np.array(train_dataset[0]["input_features"]).shape, "labels len:", len(train_dataset[0]["labels"]))
print("Eval sample shape:", np.array(eval_dataset[0]["input_features"]).shape, "labels len:", len(eval_dataset[0]["labels"]))

# **4: Load Whisper-Tiny model**

In [None]:
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# Decode Persian transcription by default. With language/task set, generate() builds
# the language/task prompt itself; additionally setting forced_decoder_ids conflicts
# with them (transformers warns and ignores forced_decoder_ids), so clear it instead.
model.generation_config.language = "fa"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None
model.generation_config.max_length = 448  # matches Whisper's max_target_positions
model.to(device)
print("✅ Model loaded.")

# **5: Define WER metric**

In [None]:
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    """Compute word error rate (percent) for one Trainer evaluation pass.

    Args:
        pred: the EvalPrediction passed by Trainer. ``pred.predictions`` holds
            either logits (batch, seq, vocab), possibly as the first element of a
            tuple, or token ids (batch, seq) if logits were already reduced.
            ``pred.label_ids`` holds label token ids, with -100 marking padding.

    Returns:
        dict with a single "wer" key on a 0-100 scale.
    """
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Model outputs may arrive as a tuple whose first element is the logits.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    pred_ids = torch.as_tensor(pred_ids)
    if pred_ids.ndim == 3:
        pred_ids = pred_ids.argmax(dim=-1)
    label_ids = torch.as_tensor(label_ids)

    pad_id = tokenizer.pad_token_id
    label_padding = label_ids == -100
    # Trainer pads all evaluation batches to a common length with -100. Predictions at
    # those positions have no reference token (padded logits argmax to token 0, a real
    # character), so blank them before decoding instead of counting them as words.
    if pred_ids.shape == label_ids.shape:
        pred_ids = pred_ids.masked_fill(label_padding, pad_id)
    pred_ids = pred_ids.masked_fill(pred_ids < 0, pad_id)
    label_ids = label_ids.masked_fill(label_padding, pad_id)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Debug: Print sample predictions
    for i in range(min(3, len(pred_str))):
        print(f"Prediction {i}: {pred_str[i]}")
        print(f"Reference {i}: {label_str[i]}")

    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

# **6: Training arguments and data collator (optimized for free Colab)**

In [None]:
output_dir = "./whisper-fa-finetuned"
os.makedirs(output_dir, exist_ok=True)  # Create output directory if it doesn't exist

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=1,             # Minimum value to prevent memory errors
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,             # Effective batch size of 4 per device
    num_train_epochs=4,                        # Moderate number of epochs
    learning_rate=1e-5,                        # Reduced lr for stability
    warmup_steps=100,                          # Add warmup to prevent divergence
    lr_scheduler_type="cosine",
    fp16=torch.cuda.is_available(),
    eval_strategy="steps",                     # `evaluation_strategy` is deprecated (removed in transformers 4.46)
    eval_steps=20,
    save_steps=40,                             # must be a multiple of eval_steps for load_best_model_at_end
    logging_strategy="steps",
    logging_steps=20,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    remove_unused_columns=False,
    report_to="none",
    save_total_limit=1,                        # Limit checkpoints to save disk space
)

def data_collator(features, label_pad_token_id=-100, decoder_start_token_id=None):
    """Batch pre-extracted Whisper input features and tokenized labels for Trainer.

    Args:
        features: list of dicts, each with "input_features" (the log-mel array built
            by prepare_dataset) and "labels" (list of token ids).
        label_pad_token_id: value used to pad labels. -100 is the ignore index of the
            model's cross-entropy loss, so padding does not contribute to training.
        decoder_start_token_id: id of <|startoftranscript|>. Defaults to the id that
            the notebook's tokenizer assigns to that token.

    Returns:
        dict with "input_features" (float32 tensor) and "labels" (int64 tensor).
    """
    if decoder_start_token_id is None:
        decoder_start_token_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

    input_features = np.asarray([feature["input_features"] for feature in features], dtype=np.float32)

    # The model builds decoder inputs by shifting labels right and prepending
    # decoder_start_token_id itself, so a leading <|startoftranscript|> in the labels
    # would end up doubled in the decoder input. Drop it here.
    label_seqs = []
    for feature in features:
        seq = list(feature["labels"])
        if seq and seq[0] == decoder_start_token_id:
            seq = seq[1:]
        label_seqs.append(seq)

    max_label_length = max((len(seq) for seq in label_seqs), default=0)
    labels = [seq + [label_pad_token_id] * (max_label_length - len(seq)) for seq in label_seqs]
    return {
        "input_features": torch.from_numpy(input_features),
        "labels": torch.tensor(labels, dtype=torch.long),
    }

# **7: Train model**

In [None]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce decoder logits to predicted token ids before Trainer stores them.

    Trainer concatenates whatever this returns across every evaluation batch, so
    keeping argmax ids instead of (batch, seq, vocab) logits keeps that small.
    ``labels`` is part of the Trainer hook signature and is unused here.
    """
    # Logits may arrive as a tuple whose first element is the LM logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)

# NOTE(review): plain Trainer evaluates with a teacher-forced forward pass (the decoder
# sees the reference tokens), so this WER is optimistic compared with free-running
# generate(); Seq2SeqTrainer with predict_with_generate=True would measure the latter.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=data_collator,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

# Baseline evaluation
print("Baseline Evaluation (before training):")
baseline_results = trainer.evaluate()
baseline_wer = baseline_results.get("eval_wer")
# Guard the format spec: applying :.2f to an "N/A" fallback would raise ValueError.
print(f"Baseline WER: {baseline_wer:.2f}%" if baseline_wer is not None else "Baseline WER: N/A")

trainer.train()
print("✅ Training finished.")

# Cell 8: Evaluate and save model
eval_results = trainer.evaluate()
final_wer = eval_results.get("eval_wer")
# Guard the format spec: applying :.2f to an "N/A" fallback would raise ValueError.
print(
    f"Final Evaluation Results: WER = {final_wer:.2f}%"
    if final_wer is not None
    else "Final Evaluation Results: WER = N/A"
)

# Bundle the exact feature extractor and tokenizer used for training (no re-download).
processor = WhisperProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
# trainer.model is the model passed to Trainer; with load_best_model_at_end it holds
# the best checkpoint's weights.
trainer.model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

# Verify checkpoint: save_pretrained writes model.safetensors by default; the
# pytorch_model.bin name only appears with safe_serialization=False.
saved_weights = [
    name
    for name in ("model.safetensors", "pytorch_model.bin")
    if os.path.isfile(os.path.join(output_dir, name))
]
if saved_weights:
    print(f"✅ Checkpoint verified: {saved_weights[0]} exists in {output_dir}")
else:
    print(f"⚠️ Checkpoint verification failed: no model.safetensors or pytorch_model.bin found in {output_dir}")

print(f"✅ Model & processor saved in {output_dir}")

# **9: Transcription function**

In [None]:
def transcribe_audio(audio_path, model_path=output_dir, device=device):
    """Transcribe one audio file with a saved Whisper checkpoint.

    Args:
        audio_path: path to an audio file readable by soundfile.
        model_path: directory holding the saved model and processor.
        device: torch device to run inference on.

    Returns:
        The decoded transcription with special tokens removed.

    Note: the model and processor are reloaded on every call; cache them if
    transcribing many files.
    """
    model = WhisperForConditionalGeneration.from_pretrained(model_path).to(device)
    processor = WhisperProcessor.from_pretrained(model_path)

    audio, sampling_rate = sf.read(audio_path, dtype="float32")
    # soundfile returns (frames, channels) for multi-channel files; both the resampler
    # (which works on the last axis) and Whisper's feature extractor need mono.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sampling_rate != 16000:
        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)

    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        predicted_ids = model.generate(inputs["input_features"])
    transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
    return transcription

# **10: Final report**

In [None]:
print("\n📋 Final Project Report")
print("="*50)
# save_pretrained writes model.safetensors by default (pytorch_model.bin only with
# safe_serialization=False), so accept either weights file as proof of a saved model.
report_weights = [
    name
    for name in ("model.safetensors", "pytorch_model.bin")
    if os.path.isfile(os.path.join(output_dir, name))
]
if report_weights:
    print("✅ Model successfully trained and saved")
    print("Directory contents:")
    for item in os.listdir(output_dir):
        print(f"  - {item}")
else:
    print("⚠️ Model was not trained or saved")
print("\n🔸 Model: Whisper-Tiny for Persian")
print(f"🔸 Training data: {len(train_dataset)} samples")
print(f"🔸 Evaluation data: {len(eval_dataset)} samples")
print(f"🔸 Device: {str(device).upper()}")
report_wer = eval_results.get("eval_wer")
# Guard the format spec: applying :.2f to an "N/A" fallback would raise ValueError.
print(
    f"🔸 Word Error Rate (WER): {report_wer:.2f}%"
    if report_wer is not None
    else "🔸 Word Error Rate (WER): N/A"
)
print("🎉 Project completed!")