In [None]:
# ===========================================================
# SECTION 1: INSTALLATION
# ===========================================================
# Install required libraries with GPU support and SafeTensors compatibility
!pip install -q torch==2.5.1+cu121 torchvision==0.20.1+cu121 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
!pip install -q "datasets==2.16.1" librosa==0.10.2.post1
!pip install -q transformers jiwer accelerate soundfile safetensors
!pip install -qU gradio==4.44.0

In [None]:
# ===========================================================
# SECTION 2: ENVIRONMENT SETUP
# ===========================================================
# Ask the PyTorch CUDA allocator for expandable segments, which reduces GPU
# memory fragmentation (highly recommended for T4 GPU).
import os

os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

In [None]:
# ===========================================================
# SECTION 3: IMPORT LIBRARIES
# ===========================================================
import os
from dataclasses import dataclass
from getpass import getpass
from typing import Dict, List, Union

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from datasets import load_dataset, Audio
from huggingface_hub import login
from jiwer import wer, cer
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer, EarlyStoppingCallback

In [None]:
# ===========================================================
# SECTION 4: AUTHENTICATION
# ===========================================================
# Log in to Hugging Face Hub.
# The access token is a secret, so it is not written into the notebook: it is
# read from the HF_TOKEN environment variable, and when that is unset or empty
# the user is asked for it through a hidden (non-echoing) prompt.
HF_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(token=HF_TOKEN)

In [None]:
# ===========================================================
# SECTION 5: LOAD DATASET
# ===========================================================
# Fetch the Persian speech dataset and have its audio column decoded at 16 kHz,
# the sampling rate the rest of the notebook hands to the processor.
dataset = load_dataset("SeyedAli/Persian-Speech-Dataset").cast_column(
    "audio", Audio(sampling_rate=16000)
)

In [None]:
# ===========================================================
# SECTION 6: CONFIGURATION
# ===========================================================
# Checkpoint to fine-tune from, and the Hub repository that receives the result.
BASE_MODEL = "jonatasgrosman/wav2vec2-large-xlsr-53-persian"
HF_USERNAME = "Your USERNAME"  # placeholder: replace with your Hugging Face account name
MODEL_REPO_NAME = "/".join([HF_USERNAME, "wav2vec2-large-xlsr53-fa-finetuned-gpu-final"])

In [None]:
# ===========================================================
# SECTION 7: LOAD PROCESSOR
# ===========================================================
# Processor published with BASE_MODEL, requested with the SafeTensors flag set.
# It is used below for audio feature extraction, label tokenization and decoding.
processor = Wav2Vec2Processor.from_pretrained(
    BASE_MODEL,
    use_safetensors=True,
)

In [None]:
# ===========================================================
# SECTION 8: DATA PREPROCESSING
# ===========================================================
# Preprocess the data with peak normalization and length truncation to save memory
MAX_LENGTH = 160000  # maximum number of samples kept per clip (~10 seconds at 16kHz)


def prepare_dataset(batch):
    """Turn one dataset example into model inputs and CTC label ids.

    Args:
        batch: A single example with an ``"audio"`` entry whose ``"array"``
            holds the waveform, and a ``"transcript"`` string.

    Returns:
        The same mapping with two keys added: ``"input_values"`` (the
        processed waveform) and ``"labels"`` (token ids of the transcript).
    """
    audio = batch["audio"]
    waveform = np.asarray(audio["array"], dtype=np.float32)

    # Keep at most MAX_LENGTH samples to prevent memory issues.
    # NOTE(review): only the audio is shortened; the transcript below is still
    # tokenized in full, so clips longer than MAX_LENGTH keep labels for speech
    # that was cut away. Filtering such clips out may be preferable - confirm.
    waveform = waveform[:MAX_LENGTH]

    # Peak-normalize; an all-zero clip is left untouched to avoid dividing by zero.
    peak = np.max(np.abs(waveform))
    if peak > 0:
        waveform = waveform / peak

    # Process input features
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
    batch["input_values"] = inputs.input_values[0]

    # Tokenize the transcript with the processor's tokenizer directly. This is
    # what the deprecated `processor.as_target_processor()` context switched to.
    batch["labels"] = processor.tokenizer(batch["transcript"]).input_ids

    return batch

# Run prepare_dataset over every example and drop the listed source columns
# once the model inputs and labels have been derived from them.
source_columns = ["audio", "speaker_id", "gender", "ipa", "emotion", "transcript"]
dataset = dataset.map(prepare_dataset, remove_columns=source_columns)

In [None]:
# ===========================================================
# SECTION 9: TRAIN-TEST SPLIT
# ===========================================================
# Hold out 10% of the "train" split for evaluation; the seed is fixed at 42.
split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]

In [None]:
# ===========================================================
# SECTION 10: DATA COLLATOR
# ===========================================================
# Custom data collator for CTC loss
@dataclass
class DataCollatorCTC:
    """Pad a list of preprocessed examples into a single training batch.

    The audio inputs and the label ids are padded separately, each by the
    processor component that produced them.
    """

    # Processor whose feature extractor pads the audio and whose tokenizer pads the labels.
    processor: Wav2Vec2Processor
    # Padding argument forwarded unchanged to both pad() calls.
    padding: bool = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate examples carrying ``"input_values"`` and ``"labels"``.

        Args:
            features: Examples as produced by the preprocessing step, each
                with an ``"input_values"`` and a ``"labels"`` entry.

        Returns:
            The padded audio batch with a ``"labels"`` tensor added, in which
            padded label positions are set to -100.
        """
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(
            input_features, padding=self.padding, return_tensors="pt"
        )

        # Pad the labels with the tokenizer directly instead of going through
        # the deprecated `as_target_processor()` context manager.
        labels_batch = self.processor.tokenizer.pad(
            label_features, padding=self.padding, return_tensors="pt"
        )

        # Positions the tokenizer padded (attention_mask != 1) become -100, the
        # sentinel that compute_metrics later maps back to the pad token.
        batch["labels"] = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        return batch


data_collator = DataCollatorCTC(processor=processor)

In [None]:
# ===========================================================
# SECTION 11: LOAD MODEL
# ===========================================================
# Load the pre-trained CTC model from BASE_MODEL with SafeTensors.
# The pad token id and vocabulary size are taken from the processor's tokenizer.
print("Loading model with SafeTensors...")
model_load_kwargs = {
    "ctc_loss_reduction": "mean",
    "pad_token_id": processor.tokenizer.pad_token_id,
    "vocab_size": len(processor.tokenizer),
    "use_safetensors": True,
}
model = Wav2Vec2ForCTC.from_pretrained(BASE_MODEL, **model_load_kwargs)

# Post-load switches: the CTC zero-infinity flag on the config, and gradient checkpointing.
model.config.ctc_zero_infinity = True
model.gradient_checkpointing_enable()

In [None]:
# ===========================================================
# SECTION 12: METRICS
# ===========================================================
# Define WER and CER metrics using jiwer
def compute_metrics(pred):
    """Compute word and character error rates for one evaluation pass.

    Args:
        pred: Object exposing ``predictions`` (model scores; the argmax over
            the last axis gives the predicted token ids) and ``label_ids``
            (reference ids with -100 at padded positions).

    Returns:
        Dict with ``"wer"`` and ``"cer"`` computed by jiwer, using the decoded
        labels as reference and the decoded predictions as hypothesis.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # Build a new array instead of assigning into pred.label_ids, so the
    # caller's label array is not modified as a side effect of scoring.
    label_ids = np.where(
        pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids
    )

    pred_str = processor.batch_decode(pred_ids)
    # References are decoded with group_tokens=False, unlike the predictions.
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    return {"wer": wer(label_str, pred_str), "cer": cer(label_str, pred_str)}

In [None]:
# ===========================================================
# SECTION 13: TRAINING ARGUMENTS
# ===========================================================
# Training configuration tuned to prevent OOM on a T4 GPU.
# Keyword arguments are grouped by concern; the values are unchanged.
training_args = TrainingArguments(
    # --- output location and reporting ---
    output_dir="./wav2vec2-fa-gpu",
    overwrite_output_dir=True,
    report_to="none",
    # --- memory footprint: 2 samples per step x 16 accumulation steps ---
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    fp16=True,
    dataloader_num_workers=2,
    # --- optimizer and learning-rate schedule ---
    optim="adamw_torch",
    learning_rate=1e-4,
    weight_decay=0.01,
    num_train_epochs=10,
    warmup_steps=300,
    lr_scheduler_type="cosine",
    # --- evaluation, logging and checkpointing (all step-based) ---
    eval_strategy="steps",
    eval_steps=200,
    logging_steps=50,
    save_strategy="steps",
    save_steps=400,
    save_total_limit=3,
    save_safetensors=True,
    # --- best-checkpoint selection: lower WER is better ---
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    # --- Hugging Face Hub upload ---
    push_to_hub=True,
    hub_model_id=MODEL_REPO_NAME,
    hub_strategy="checkpoint",
    hub_private_repo=False,
)

In [None]:
# ===========================================================
# SECTION 14: TRAINER SETUP
# ===========================================================
# Wire model, data, collator, metrics and early stopping into a Trainer.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    # NOTE(review): newer transformers releases appear to rename this argument
    # to `processing_class` - confirm against the installed version.
    tokenizer=processor.feature_extractor,
    callbacks=[early_stopping],
)

In [None]:
# ===========================================================
# SECTION 15: START TRAINING
# ===========================================================
# Release cached GPU memory, announce the run, then start training.
torch.cuda.empty_cache()
# The emoji is restored from its mis-decoded byte sequence.
print("🚀 Starting final training without OOM on GPU...")
trainer.train()

In [None]:
# ===========================================================
# SECTION 16: SAVE AND PUSH TO HUB
# ===========================================================
# Save the model and processor, then push to Hugging Face Hub
trainer.save_model()
# Reuse the trainer's configured output_dir instead of repeating the path
# literal, so the processor files cannot drift away from the saved model.
processor.save_pretrained(trainer.args.output_dir)
trainer.push_to_hub(
    commit_message="Final fine-tune of Wav2Vec2-large-xlsr-53-persian with memory-optimized settings and SafeTensors",
    tags=["automatic-speech-recognition", "persian", "farsi", "wav2vec2", "safetensors"]
)

In [None]:
# ===========================================================
# SECTION 17: FINAL EVALUATION
# ===========================================================
# Evaluate the fine-tuned model and report both error rates as percentages.
# The check-mark emoji are restored from their mis-decoded byte sequences.
metrics = trainer.evaluate()
print(f"\n✅ Final WER: {metrics['eval_wer']:.2%}")
print(f"✅ Final CER: {metrics['eval_cer']:.2%}")

In [None]:
# ===========================================================
# SECTION 18: INFERENCE FUNCTION
# ===========================================================
# Function to transcribe audio files
def transcribe_audio(audio_path):
    """Transcribe a single audio file with the notebook's model and processor.

    Args:
        audio_path: Path of an audio file that librosa can load; it is
            resampled to 16 kHz on load.

    Returns:
        The decoded transcription string for the file.
    """
    speech, _ = librosa.load(audio_path, sr=16000)

    # Move the inputs to whichever device the model is on rather than assuming
    # "cuda", so the function also works when the model sits on the CPU.
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt").to(model.device)

    # Make sure inference runs in evaluation mode, not training mode.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]

In [None]:
# ===========================================================
# SECTION 19: TEST ON SAMPLE
# ===========================================================
# Test the model on a sample from the evaluation set
example = eval_dataset[0]

# Labels are reference ids, so decode them with group_tokens=False (as
# compute_metrics does); the default grouping would merge doubled letters.
original_text = processor.decode(
    example["labels"], skip_special_tokens=True, group_tokens=False
)

# np.asarray accepts a list, a NumPy array or a CPU tensor alike.
input_values = np.asarray(example["input_values"], dtype=np.float32)

# Write 32-bit float samples. The default WAV subtype is 16-bit PCM, which
# clips everything outside [-1, 1], and the processed input_values are not
# confined to that range.
sf.write("test_sample.wav", input_values, 16000, subtype="FLOAT")

result = transcribe_audio("test_sample.wav")

# The emoji and Persian labels are restored from their mis-decoded byte sequences.
print("\n📝 تست نمونه:")
print(f"متن اصلی: {original_text}")
print(f"تشخیص مدل: {result}")

In [None]:
# ===========================================================
# SECTION 20: COMPLETION MESSAGE
# ===========================================================
# Final status lines; the emoji are restored from their mis-decoded byte sequences.
print("\n✅ Training completed successfully!")
print(f"🔗 Model uploaded to: https://huggingface.co/{MODEL_REPO_NAME}")