# Advanced Bengali Regional Dialect ASR - Encoder-Only Fine-tuning
## With Enhanced Training Features:
- ✅ **Encoder-only fine-tuning** (decoder frozen, encoder trainable)
- ✅ **On-the-fly data augmentation** (fresh each epoch)
- ✅ **Weighted sampling** for class balance  
- ✅ **Stratified train/val split**
- ✅ **Early stopping and learning rate scheduling**
- ✅ **Comprehensive metrics tracking**
- ✅ **Optimized for P100 GPU**

**Strategy:** Freeze decoder, train encoder only (~25% parameters trainable)
This adapts the acoustic encoder to regional dialects while preserving the decoder's Bengali language model.

**Dataset:** BengaliAI Regional ASR

## 1. Install Required Packages

In [None]:
!apt-get install -y libsndfile1
!pip install numpy scipy librosa soundfile
!pip install audiomentations --no-build-isolation

In [None]:
!pip install transformers==4.47.0 datasets accelerate evaluate jiwer tensorboard librosa soundfile audiomentations -q

In [None]:
!pip install evaluate

## 2. Import Libraries

In [None]:
# Fix protobuf compatibility issues
import os
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'

# Suppress CUDA warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [None]:
!pip uninstall -y protobuf
!pip install protobuf==3.20.3

In [None]:
import os
import pandas as pd
import numpy as np
import random
import torch
import librosa
from collections import Counter
from torch.utils.data import Dataset as TorchDataset, WeightedRandomSampler
from datasets import Dataset, DatasetDict
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback
)
import evaluate
from dataclasses import dataclass
from typing import Any, Dict, List, Union

# Audio augmentation
try:
    from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Gain
    USE_AUDIOMENTATIONS = True
    print("✓ audiomentations library loaded")
except ImportError:
    USE_AUDIOMENTATIONS = False
    print("⚠ audiomentations not available, using basic augmentation")


## 3. Configuration

In [None]:
# Model and paths
MODEL_NAME = "bengaliAI/tugstugi_bengaliai-asr_whisper-medium"  # ⭐ CHANGED: Standard Bangla model
TRAIN_AUDIO_PATH = "/kaggle/input/shobdotori-regspeech12-compact-v2/shobdotori_regspeech12_compact_v2/Train"
TRAIN_ANNOTATION_PATH = "/kaggle/input/shobdotori-regspeech12-compact-v2/shobdotori_regspeech12_compact_v2/Train_annotation"
OUTPUT_DIR = "./whisper-bengali-encoder-only-finetune"  # ⭐ CHANGED: New directory name

# Training configuration (optimized for P100)
SEED = 42
CONFIG = {
    'batch_size': 1,              
    'gradient_accumulation': 4,   
    'learning_rate': 1e-5,        
    'weight_decay': 0.01,
    'warmup_steps': 500,
    'num_epochs': 6,
    'eval_steps': 500,
    'save_steps': None,           # ⭐ CHANGED: No intermediate saves
    'save_total_limit': None,     # ⭐ CHANGED: Not needed
    'early_stopping_patience': 3,
    'use_augmentation': True,
    'augmentation_prob': 0.5,
    'use_weighted_sampling': True,
    'fp16': True,
}



# Set random seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## Audio Augmentation Setup

## 4. Setup Data Augmentation Pipeline

In [None]:
# Waveform augmentation: use audiomentations when it imported, otherwise a
# small NumPy-only fallback (noise + gain).
if USE_AUDIOMENTATIONS:
    augment = Compose([
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.5),
        TimeStretch(min_rate=0.9, max_rate=1.1, p=0.5),
        PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
        Gain(min_gain_db=-6, max_gain_db=6, p=0.3),
    ])

    def apply_augmentation(audio, sample_rate, probability=0.5):
        """Run the audiomentations pipeline on `audio` with the given probability.

        Args:
            audio: Waveform samples passed to the pipeline as `samples`.
            sample_rate: Sampling rate forwarded to the pipeline.
            probability: Chance that the pipeline is applied at all.

        Returns:
            The pipeline output, or `audio` unchanged when the draw skips it.
        """
        skip = np.random.random() >= probability
        if skip:
            return audio
        return augment(samples=audio, sample_rate=sample_rate)
else:
    def apply_augmentation(audio, sample_rate, probability=0.5):
        """Fallback augmentation: optional Gaussian noise and random gain.

        Args:
            audio: NumPy waveform.
            sample_rate: Unused here; kept so both variants share a signature.
            probability: Chance that any augmentation is attempted.

        Returns:
            `audio` unchanged when skipped, otherwise the augmented waveform
            clipped to [-1.0, 1.0].
        """
        # The order of random draws below matters for reproducibility.
        if not (np.random.random() < probability):
            return audio

        add_noise = np.random.random() < 0.5
        if add_noise:
            audio = audio + np.random.normal(0, 0.005, audio.shape)

        change_gain = np.random.random() < 0.3
        if change_gain:
            audio = audio * np.random.uniform(0.7, 1.3)

        return np.clip(audio, -1.0, 1.0)

print("✓ Augmentation pipeline initialized")
print("  Applied ON-THE-FLY during training (fresh augmentations each epoch)")


## 5. Load and Analyze Dataset

In [None]:
# Get all regional folders
regions = [d for d in os.listdir(TRAIN_AUDIO_PATH) if os.path.isdir(os.path.join(TRAIN_AUDIO_PATH, d))]
print(f"Found {len(regions)} regional dialects: {regions}")

In [None]:
# Load all training data
def load_dataset_from_folders():
    """Collect (audio path, transcript, region) rows for every region folder.

    For each name in the module-level `regions` list, reads
    `<TRAIN_ANNOTATION_PATH>/<region>.csv` and uses its 'audio' and 'text'
    columns. Rows whose audio file does not exist under
    `<TRAIN_AUDIO_PATH>/<region>/` are dropped.

    Returns:
        A DataFrame with columns 'audio' (full file path), 'text' and 'region'.
    """
    records = []

    for region in regions:
        annotations = pd.read_csv(
            os.path.join(TRAIN_ANNOTATION_PATH, f"{region}.csv")
        )
        region_dir = os.path.join(TRAIN_AUDIO_PATH, region)

        for _, row in annotations.iterrows():
            audio_path = os.path.join(region_dir, row['audio'])
            # Skip annotations that point at a file we do not have.
            if not os.path.exists(audio_path):
                continue
            records.append({
                'audio': audio_path,
                'text': row['text'],
                'region': region
            })

    return pd.DataFrame(records)

print("Loading dataset...")
df = load_dataset_from_folders()
print(f"Total samples: {len(df)}")
print(f"\nSamples per region:")
print(df['region'].value_counts())

## Calculate Class Weights for Balanced Sampling

## 6. Calculate Sample Weights for Balanced Training

In [None]:
# Calculate inverse-frequency class weights per region and print a summary.
# NOTE: the following cell reassigns df['sample_weight'] in both of its
# branches, so the normalized weights stored here are overwritten before
# training; this cell mainly serves as a printed overview.
region_counts = Counter(df['region'])
total_samples = len(df)

# Raw weight: total_samples / (number_of_regions * samples_in_region)
class_weights = {region: total_samples / (len(region_counts) * count) 
                 for region, count in region_counts.items()}

# Rescale so the weights sum to the number of regions (mean weight of 1.0)
sum_weights = sum(class_weights.values())
class_weights = {k: v/sum_weights * len(class_weights) for k, v in class_weights.items()}

# Report regions from highest weight to lowest
print("\nClass weights (for balanced sampling):")
for region, weight in sorted(class_weights.items(), key=lambda x: x[1], reverse=True):
    print(f"  {region}: {weight:.3f} (n={region_counts[region]})")

# Attach each row's region weight as a per-sample weight
df['sample_weight'] = df['region'].map(class_weights)


In [None]:
# Calculate class weights (inverse frequency)
def calculate_sample_weights(df):
    """Add an inverse-frequency 'sample_weight' column based on 'region'.

    Each region's weight is `len(df) / (number_of_regions * region_count)`.
    The column is written onto `df` in place and a per-region summary is
    printed.

    Args:
        df: DataFrame with a 'region' column.

    Returns:
        The same DataFrame object, now carrying 'sample_weight'.
    """
    per_region = df['region'].value_counts()
    n_rows = len(df)
    n_regions = len(per_region)

    # weight = total_samples / (n_classes * class_count)
    weights = {
        name: n_rows / (n_regions * n_in_region)
        for name, n_in_region in per_region.items()
    }

    # Look up every row's weight through its region
    df['sample_weight'] = df['region'].map(weights)

    print("Sample weights by region:")
    for name, value in weights.items():
        print(f"  {name}: {value:.3f}")

    return df

if CONFIG['use_weighted_sampling']:
    df = calculate_sample_weights(df)
    print("\n✓ Sample weights calculated")
else:
    df['sample_weight'] = 1.0
    print("\n⚠ Using uniform sampling (no weights)")

## 7. Train/Validation Split (Stratified)

In [None]:
from sklearn.model_selection import train_test_split

# Stratified split to ensure all dialects in both sets
train_df, val_df = train_test_split(
    df, 
    test_size=0.1, 
    random_state=SEED,
    stratify=df['region']  # Stratify by region
)

print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")
print(f"\nTraining set distribution:")
print(train_df['region'].value_counts())
print(f"\nValidation set distribution:")
print(val_df['region'].value_counts())

## 8. Convert to HuggingFace Dataset

In [None]:
# Convert to Hugging Face Dataset
train_dataset = Dataset.from_pandas(train_df[['audio', 'text', 'region', 'sample_weight']])
val_dataset = Dataset.from_pandas(val_df[['audio', 'text', 'region', 'sample_weight']])

dataset = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset
})

print(dataset)

## 9. Load Model and Processor

In [None]:
# Load processor
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME, language="Bengali", task="transcribe")
processor = WhisperProcessor.from_pretrained(MODEL_NAME, language="Bengali", task="transcribe")

print("✓ Processor loaded")
print(f"  Feature extractor sampling rate: {feature_extractor.sampling_rate}")
print(f"  Tokenizer language: {tokenizer.language}")

## 10. Prepare Data with Augmentation

In [None]:
# Preprocessing function with augmentation (BATCHED)
def prepare_dataset(batch, is_training=True):
    """Turn a batch of (audio path, text) pairs into model inputs.

    Intended for use with `set_transform`, so it runs every time items are
    fetched and any augmentation is re-drawn on each access.

    Args:
        batch: Mapping with "audio" (file path or list of paths) and "text"
            (transcript or list of transcripts).
        is_training: When True and CONFIG['use_augmentation'] is set, the
            waveform is passed through `apply_augmentation`.

    Returns:
        Dict with "input_features" (one feature array per item) and
        "labels" (one list of token ids per item).
    """
    def _as_list(value):
        # Accept a single item as well as a batched list.
        return value if isinstance(value, list) else [value]

    augment_audio = is_training and CONFIG['use_augmentation']

    features = []
    token_ids = []
    for path, transcript in zip(_as_list(batch["audio"]), _as_list(batch["text"])):
        # Decode and resample to 16 kHz
        waveform, sr = librosa.load(path, sr=16000)

        # Training-only augmentation, drawn fresh on every call
        if augment_audio:
            waveform = apply_augmentation(waveform, sr, CONFIG['augmentation_prob'])

        extracted = feature_extractor(waveform, sampling_rate=sr)
        features.append(extracted.input_features[0])

        token_ids.append(tokenizer(transcript).input_ids)

    return {"input_features": features, "labels": token_ids}

# Register on-the-fly transforms: items are processed when they are fetched,
# so training augmentation is re-drawn on every access instead of being cached.
# (This setup used to appear twice back to back; once is sufficient.)
print("Setting up on-the-fly data transforms...")
dataset['train'].set_transform(lambda x: prepare_dataset(x, is_training=True))
dataset['validation'].set_transform(lambda x: prepare_dataset(x, is_training=False))

# Assign to dataset_processed for compatibility with the cells below
dataset_processed = dataset

print("✓ Transforms set (augmentation will be applied FRESH each epoch)")
print(dataset_processed)


## 11. Define Data Collator

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate per-item features and labels into one padded training batch.

    The audio features are padded by the processor's feature extractor and the
    token ids by its tokenizer; label padding is replaced with -100 so those
    positions are ignored by the loss.
    """

    # Object exposing `.feature_extractor` and `.tokenizer`
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build the batch dict (padded inputs plus "labels") from `features`."""
        text_tokenizer = self.processor.tokenizer

        # Audio side: pad the input features into tensors
        audio_items = [{"input_features": item["input_features"]} for item in features]
        batch = self.processor.feature_extractor.pad(audio_items, return_tensors="pt")

        # Text side: pad the label token ids
        label_items = [{"input_ids": item["labels"]} for item in features]
        padded_labels = text_tokenizer.pad(label_items, return_tensors="pt")

        # Positions outside the attention mask are padding -> -100 (ignored in loss)
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # Drop the first column when every sequence starts with the BOS token.
        # NOTE(review): this compares against tokenizer.bos_token_id; confirm
        # that is the token this tokenizer actually prepends to labels.
        starts_with_bos = (labels[:, 0] == text_tokenizer.bos_token_id).all().cpu().item()
        if starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Instantiate the data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
print("✓ Data collator initialized")

## 12. Define Evaluation Metrics

In [None]:
!pip install jiwer

In [None]:
# Load metrics
wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    """Compute WER and a positional character accuracy for an evaluation run.

    Args:
        pred: Prediction object with `predictions` (generated token ids) and
            `label_ids` (reference token ids, with -100 at ignored positions).

    Returns:
        Dict with:
          - "wer": word error rate in percent.
          - "char_accuracy": percentage of reference characters matched by the
            prediction at the same string position (a plain position-by-position
            comparison, not an edit-distance alignment).
    """
    pad_id = tokenizer.pad_token_id

    # -100 is not a valid token id and cannot be decoded. Swap it for the pad
    # token on copies so the caller's arrays are left untouched. Predictions
    # are cleaned as well because padded batches may carry -100 there too.
    pred_ids = np.where(pred.predictions == -100, pad_id, pred.predictions)
    label_ids = np.where(pred.label_ids == -100, pad_id, pred.label_ids)

    # Decode predictions and labels
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # Calculate WER (reported as a percentage)
    wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)

    # Character-level accuracy: zip stops at the shorter string, so missing or
    # extra characters simply do not count as matches.
    total_chars = sum(len(reference) for reference in label_str)
    correct_chars = sum(
        sum(1 for p, l in zip(hypothesis, reference) if p == l)
        for hypothesis, reference in zip(pred_str, label_str)
    )
    char_accuracy = 100 * correct_chars / total_chars if total_chars > 0 else 0

    return {
        "wer": wer,
        "char_accuracy": char_accuracy
    }

print("✓ Metrics defined")

## 13. Load Model for Encoder-Only Fine-tuning

In [None]:
# Load base model
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)

# Configure model
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# Key/value caching is not needed for training (and must stay off if gradient
# checkpointing is ever re-enabled below).
model.config.use_cache = False

# Gradient checkpointing is currently disabled; uncomment to save memory
#model.gradient_checkpointing_enable()
# ===== FREEZE DECODER, TRAIN ENCODER ONLY =====
print("\n" + "="*60)
print("FREEZING DECODER - TRAINING ENCODER ONLY")
print("="*60)

# Freeze the entire decoder
for param in model.model.decoder.parameters():
    param.requires_grad = False

# proj_out is the output (vocabulary) projection on the decoder side. In
# Whisper its weight is tied to the decoder token embedding by default, so
# re-enabling gradients on it would silently un-freeze part of the decoder.
# Keep it frozen to honour the encoder-only strategy.
if hasattr(model, 'proj_out'):
    for param in model.proj_out.parameters():
        param.requires_grad = False

# Keep encoder trainable (it's trainable by default, but let's be explicit)
for param in model.model.encoder.parameters():
    param.requires_grad = True

# Count trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
encoder_params = sum(p.numel() for p in model.model.encoder.parameters())
decoder_params = sum(p.numel() for p in model.model.decoder.parameters())

print(f"\nTotal parameters: {all_params:,}")
print(f"Encoder parameters: {encoder_params:,}")
print(f"Decoder parameters: {decoder_params:,}")
print(f"\nTrainable parameters: {trainable_params:,}")
print(f"Trainable %: {100 * trainable_params / all_params:.2f}%")
print(f"\n✓ Decoder is FROZEN (not trainable)")
print(f"✓ Encoder is TRAINABLE (will adapt to dialectal acoustics)")
print("="*60)



## 14. Define Training Arguments (Optimized for P100)

In [None]:
# Training arguments: values come from CONFIG; evaluation and checkpoint
# saving are both switched off, so only the explicit save after training
# persists the model.
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=CONFIG['batch_size'],
    per_device_eval_batch_size=CONFIG['batch_size'],
    gradient_accumulation_steps=CONFIG['gradient_accumulation'],
    learning_rate=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
    warmup_steps=CONFIG['warmup_steps'],
    num_train_epochs=CONFIG['num_epochs'],
    
    # Evaluation and saving - NO INTERMEDIATE CHECKPOINTS
    # With evaluation disabled, CONFIG['eval_steps'] is not used here.
    eval_strategy="no",
    #eval_steps=CONFIG['eval_steps'],
    save_strategy="no",                      # No intermediate saves
    save_total_limit=None,                   # Irrelevant while saving is off
    load_best_model_at_end=False,            # Cannot be used without checkpoints
    
    # Optimization
    fp16=CONFIG['fp16'],
    gradient_checkpointing=False,             # Disabled; set True to trade compute for memory
    optim="adamw_torch",
    
    # Generation (only relevant when evaluate/predict is called)
    predict_with_generate=True,
    generation_max_length=225,
    generation_num_beams=5,
    
    # Logging
    logging_steps=50,
    logging_dir=f"{OUTPUT_DIR}/logs",
    report_to=[],                            # No tensorboard, to save space
    
    # Other
    remove_unused_columns=False,             # Keep raw columns; the transform needs "audio"/"text"
    label_names=["labels"],
    push_to_hub=False,
    seed=SEED,
)
print("\n" + "="*60)
print("TRAINING CONFIGURATION - ENCODER ONLY")
print("="*60)
print(f"Batch size per device: {CONFIG['batch_size']}")
print(f"Gradient accumulation steps: {CONFIG['gradient_accumulation']}")
print(f"Effective batch size: {CONFIG['batch_size'] * CONFIG['gradient_accumulation']}")
print(f"Learning rate: {CONFIG['learning_rate']}")
print(f"Number of epochs: {CONFIG['num_epochs']}")
print(f"Training strategy: ENCODER ONLY (Decoder frozen)")
print(f"Checkpoint saving: FINAL MODEL ONLY")
print("="*60)



## 15. Initialize Trainer (early stopping currently disabled: no evaluation, no callbacks)

## Custom Trainer with Weighted Sampling

In [None]:
from torch.utils.data import WeightedRandomSampler

class WeightedSeq2SeqTrainer(Seq2SeqTrainer):
    """
    Seq2SeqTrainer that draws training samples with a WeightedRandomSampler
    for balanced regional sampling.

    Args:
        sample_weights: One weight per training sample (same order as the
            training dataset), or None to fall back to the default sampler.
        *args, **kwargs: Forwarded unchanged to Seq2SeqTrainer.
    """
    def __init__(self, sample_weights=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sample_weights = sample_weights
    
    def _get_train_sampler(self, dataset=None):
        """Return the weighted sampler, or the parent's default sampler.

        `dataset` is optional because transformers releases differ in how they
        invoke this hook: older ones (including the 4.47.0 pinned by this
        notebook) call it with no arguments, newer ones pass the training
        dataset. Requiring the argument raised a TypeError on the former.
        """
        if self.sample_weights is not None and CONFIG['use_weighted_sampling']:
            print("  ✓ Using WeightedRandomSampler (resamples every epoch)")
            # Sampling with replacement: one "epoch" draws as many items as
            # there are weights, favouring higher-weighted samples.
            return WeightedRandomSampler(
                weights=self.sample_weights,
                num_samples=len(self.sample_weights),
                replacement=True
            )
        # Forward the dataset only when we were given one, so the call matches
        # whichever signature the installed parent class has.
        if dataset is None:
            return super()._get_train_sampler()
        return super()._get_train_sampler(dataset)

# Prepare sample weights tensor for trainer
train_sample_weights = torch.DoubleTensor(train_df['sample_weight'].values)

print("✓ Custom trainer with weighted sampling ready")
print(f"  Using weighted sampling: {CONFIG['use_weighted_sampling']}")
print("  ⚡ WeightedRandomSampler resamples each epoch")


In [None]:
# Initialize trainer with weighted sampling
trainer = WeightedSeq2SeqTrainer(
    sample_weights=train_sample_weights,
    args=training_args,
    model=model,
    train_dataset=dataset_processed['train'],  # ← Uses PROCESSED dataset with augmentation
    eval_dataset=dataset_processed['validation'],  # ← Uses PROCESSED dataset
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[],
)

print("✓ Trainer initialized with enhanced features")
print("  ✓ Augmentation: Applied via prepare_dataset (fresh each epoch)")
print("  ✓ Weighted sampling: Applied via WeightedRandomSampler (every epoch)")


## 16. Train the Model

In [None]:
print("="*60)
print("STARTING ENCODER-ONLY TRAINING")
print("="*60)
print(f"\nDataset:")
print(f"  Training samples: {len(dataset_processed['train'])}")
print(f"  Validation samples: {len(dataset_processed['validation'])}")
print(f"\nTraining Strategy:")
print(f"  ✓ Base Model: {MODEL_NAME}")
print(f"  ✓ Encoder: TRAINABLE (learns dialectal acoustics)")
print(f"  ✓ Decoder: FROZEN (preserves standard Bangla)")
print(f"  ✓ Trainable params: {trainable_params:,} ({100 * trainable_params / all_params:.1f}%)")
print(f"\nTraining will start now...\n")

# Train
trainer.train()



In [None]:
print("\n" + "="*60)
print("TRAINING COMPLETED")
print("="*60)
# ⭐ ADD THIS SECTION ⭐
# Save only the final model
print("\nSaving final model...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)
print(f"✓ Final model saved to {OUTPUT_DIR}")

# Save training metrics
import json
with open(f"{OUTPUT_DIR}/training_metrics.json", "w") as f:
    json.dump(trainer.state.log_history, f, indent=2)
print("✓ Training metrics saved")

In [None]:
# Save model info
model_info = {
    "base_model": MODEL_NAME,
    "training_strategy": "encoder_only",
    "decoder_frozen": True,
    "encoder_trainable": True,
    "total_params": all_params,
    "trainable_params": trainable_params,
    "trainable_percentage": 100 * trainable_params / all_params,
    "encoder_params": encoder_params,
    "decoder_params": decoder_params,
    "training_samples": len(dataset_processed['train']),
    "validation_samples": len(dataset_processed['validation']),
    "config": CONFIG
}

with open(f"{OUTPUT_DIR}/model_info.json", "w") as f:
    json.dump(model_info, f, indent=2)
print("✓ Model info saved")

print("\n" + "="*60)
print("ALL DONE!")
print("="*60)
print(f"\nModel saved to: {OUTPUT_DIR}")
print(f"Strategy: Encoder-only training (Decoder frozen)")
print(f"Ready for inference on test set!")

## 17. Evaluate on Validation Set

In [None]:
# print("\nEvaluating on validation set...")
# eval_results = trainer.evaluate()

# print("\n" + "="*60)
# print("VALIDATION RESULTS")
# print("="*60)
# for key, value in eval_results.items():
#     if isinstance(value, float):
#         print(f"{key}: {value:.4f}")
#     else:
#         print(f"{key}: {value}")
# print("="*60)

## 18. Save Model and Processor

In [None]:
# Save final model
final_model_dir = f"{OUTPUT_DIR}/final-model"
trainer.save_model(final_model_dir)
processor.save_pretrained(final_model_dir)

print("="*60)
print("MODEL SAVED")
print("="*60)
print(f"Model saved to: {final_model_dir}")
print("✓ Model weights")
print("✓ Processor configuration")
print("="*60)

## 19. Training Summary and Statistics

In [None]:
# # Get training history from logs
# import json

# log_history = trainer.state.log_history

# # Extract key metrics
# train_losses = [x['loss'] for x in log_history if 'loss' in x]
# eval_wers = [x['eval_wer'] for x in log_history if 'eval_wer' in x]
# eval_char_accs = [x['eval_char_accuracy'] for x in log_history if 'eval_char_accuracy' in x]

# print("\n" + "="*60)
# print("TRAINING STATISTICS")
# print("="*60)
# print(f"\nTotal training steps: {trainer.state.global_step}")
# print(f"Epochs completed: {trainer.state.epoch}")
# print(f"\nFinal training loss: {train_losses[-1]:.4f}")
# print(f"Best validation WER: {min(eval_wers):.2f}%")
# print(f"Best character accuracy: {max(eval_char_accs):.2f}%")
# print(f"\nImprovement over training:")
# if len(eval_wers) > 1:
#     print(f"  WER improved by: {eval_wers[0] - min(eval_wers):.2f}%")
#     print(f"  Char accuracy improved by: {max(eval_char_accs) - eval_char_accs[0]:.2f}%")
# print("="*60)

## 20. Quick Test on Sample Predictions

In [None]:
# # Test on a few validation samples
# print("\nTesting on sample predictions...")
# print("="*60)

# # Get a few samples from each dialect
# sample_indices = [0, 10, 20, 30, 40]  # Adjust as needed

# for idx in sample_indices[:5]:  # Show first 5
#     if idx >= len(dataset_processed['validation']):
#         break
    
#     sample = dataset_processed['validation'][idx]
    
#     # Get prediction
#     inputs = {
#         'input_features': torch.tensor(sample['input_features']).unsqueeze(0).to(model.device)
#     }
    
#     with torch.no_grad():
#         generated_ids = model.generate(**inputs, max_length=225, num_beams=5)
    
#     prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
#     reference = tokenizer.decode(sample['labels'], skip_special_tokens=True)
    
#     print(f"\nSample {idx + 1}:")
#     print(f"  Reference:  {reference}")
#     print(f"  Prediction: {prediction}")

# print("\n" + "="*60)

## 21. Test on Sample Files (First 10)