In [17]:
import numpy as np
import math
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter
import torch
import torch.nn as nn

full_dataset = load_dataset("coastalcph/tydi_xor_rc")

model_checkpoint = "distilbert/distilbert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

def preprocess_function(examples):
    """Tokenize a batch of question/context pairs and attach integer labels.

    Each question is paired with its context and encoded with the
    module-level ``tokenizer``. Only the context (second sequence) may be
    truncated, and every encoding is padded to 512 tokens. The
    ``answerable`` column is cast to ``int`` and stored under ``"label"``.
    """
    encoded = tokenizer(
        examples["question"],
        examples["context"],
        max_length=512,
        padding="max_length",
        truncation="only_second",
    )
    encoded["label"] = list(map(int, examples["answerable"]))
    return encoded

def compute_metrics(pred):
    """Return accuracy, F1, precision and recall for a prediction object.

    ``pred.predictions`` is reduced to class ids with an argmax over axis 1
    and compared against ``pred.label_ids``. Precision, recall and F1 use
    ``average='binary'``.
    """
    y_true = pred.label_ids
    y_pred = np.argmax(pred.predictions, axis=1)

    prf_scores = precision_recall_fscore_support(y_true, y_pred, average='binary')
    precision, recall, f1 = prf_scores[0], prf_scores[1], prf_scores[2]

    metrics = {'accuracy': accuracy_score(y_true, y_pred)}
    metrics['f1'] = f1
    metrics['precision'] = precision
    metrics['recall'] = recall
    return metrics

def analyze_class_distribution(dataset, lang):
    """Print the binary class balance of ``dataset`` and return its labels.

    Args:
        dataset: Iterable of examples, each indexable by ``'answerable'``.
        lang: Tag used (upper-cased) in the printed section header.

    Returns:
        A list with ``int(example['answerable'])`` for every example, in
        iteration order.

    An empty dataset or a split in which one of the two classes is missing
    no longer raises or reports a misleading ``1.00:1`` ratio; the ratio is
    reported as undefined instead.
    """
    labels = [int(example['answerable']) for example in dataset]
    counter = Counter(labels)
    total = len(labels)

    print(f"\n--- Class Distribution for {lang.upper()} ---")

    if total == 0:
        # Percentages and the ratio would both divide by zero.
        print("No examples found; class distribution is undefined.")
        return labels

    # Counter returns 0 for a missing key, so an absent class is counted as 0
    # rather than silently dropped (as iterating counter.values() would do).
    not_answerable = counter[0]
    answerable = counter[1]

    print(f"Class 0 (Not Answerable): {not_answerable} ({not_answerable/total:.2%})")
    print(f"Class 1 (Answerable): {answerable} ({answerable/total:.2%})")

    minority = min(not_answerable, answerable)
    if minority == 0:
        print("Imbalance Ratio: undefined (one class is absent)")
    else:
        majority = max(not_answerable, answerable)
        print(f"Imbalance Ratio: {majority / minority:.2f}:1")

    return labels

def compute_class_weights(labels):
    """Compute 'balanced' class weights for ``labels`` as a float32 tensor.

    The weights are ordered like ``np.unique(labels)``; the label-to-weight
    mapping is printed before the tensor is returned.
    """
    classes = np.unique(labels)
    weights = compute_class_weight('balanced', classes=classes, y=labels)

    weight_by_class = {cls: weight for cls, weight in zip(classes, weights)}
    print(f"Computed class weights: {weight_by_class}")

    return torch.tensor(weights, dtype=torch.float32)

class WeightedTrainer(Trainer):
    """Trainer whose loss is cross-entropy with optional per-class weights."""

    def __init__(self, class_weights=None, **kwargs):
        """Forward ``kwargs`` to ``Trainer`` and remember ``class_weights``."""
        super().__init__(**kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the (optionally class-weighted) cross-entropy loss.

        When ``return_outputs`` is true the model outputs are returned
        alongside the loss as ``(loss, outputs)``.
        """
        targets = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get('logits')

        # ``weight=None`` is CrossEntropyLoss's unweighted default.
        weight = None
        if self.class_weights is not None:
            # Move the weights to wherever the logits live before use.
            weight = self.class_weights.to(logits.device)
        criterion = nn.CrossEntropyLoss(weight=weight)

        num_labels = self.model.config.num_labels
        loss = criterion(logits.view(-1, num_labels), targets.view(-1))

        if return_outputs:
            return (loss, outputs)
        return loss

def apply_focal_loss_trainer(alpha=0.25, gamma=2.0):
    """Build and return a ``Trainer`` subclass that trains with focal loss.

    ``alpha`` and ``gamma`` become the default constructor arguments of the
    returned class; the class itself (not an instance) is returned.
    """

    class FocalLossTrainer(Trainer):
        """Trainer computing ``alpha * (1 - pt) ** gamma * CE``, averaged."""

        def __init__(self, alpha=alpha, gamma=gamma, **kwargs):
            super().__init__(**kwargs)
            self.alpha = alpha
            self.gamma = gamma

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            targets = inputs.get("labels")
            outputs = model(**inputs)
            logits = outputs.get('logits')

            num_labels = self.model.config.num_labels
            per_example = nn.CrossEntropyLoss(reduction='none')
            ce_loss = per_example(logits.view(-1, num_labels), targets.view(-1))

            # exp(-CE) is the probability assigned to the true class.
            pt = torch.exp(-ce_loss)
            # NOTE(review): alpha is a single scalar applied to every example,
            # so it only rescales the loss; confirm this is intended.
            focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
            loss = focal_loss.mean()

            if return_outputs:
                return (loss, outputs)
            return loss

    return FocalLossTrainer

# Per-language experiment driver: for each language code, evaluate a freshly
# loaded model, fine-tune it with an imbalance-aware loss, evaluate again, and
# store both metric dicts plus class statistics in `results[lang]`.
languages = ['ar', 'ko', 'te']
results = {}

for lang in languages:
    print(f"Language: {lang.upper()}")

    # Keep only the rows of each split whose 'lang' column matches.
    train_dataset = full_dataset["train"].filter(lambda example: example['lang'] == lang)
    val_dataset = full_dataset["validation"].filter(lambda example: example['lang'] == lang)

    # Prints the class balance of each split and returns the integer labels.
    train_labels = analyze_class_distribution(train_dataset, f"{lang}_train")
    val_labels = analyze_class_distribution(val_dataset, f"{lang}_validation")

    # Computed for every language, but only passed to the WeightedTrainer
    # branch below.
    class_weights = compute_class_weights(train_labels)

    tokenized_train = train_dataset.map(preprocess_function, batched=True)
    tokenized_val = val_dataset.map(preprocess_function, batched=True)

    # A new model is loaded from the base checkpoint for each language.
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

    print(f"\nPerforming Zero-Shot Evaluation for {lang.upper()}...")

    # Evaluation before any training on this language.
    # NOTE(review): the captured run log reports the classifier layers as
    # newly initialized, so this baseline reflects an untrained head.
    zero_shot_trainer = Trainer(
        model=model,
        eval_dataset=tokenized_val,
        compute_metrics=compute_metrics,
    )
    zero_shot_results = zero_shot_trainer.evaluate()

    print(f"Zero-shot results: Acc={zero_shot_results['eval_accuracy']:.4f}, "
          f"F1={zero_shot_results['eval_f1']:.4f}, "
          f"Precision={zero_shot_results['eval_precision']:.4f}, "
          f"Recall={zero_shot_results['eval_recall']:.4f}")

    print(f"\nFine-Tuning Model with Class Balancing for {lang.upper()}...")

    # Evaluate and checkpoint once per epoch; the best checkpoint by eval F1
    # is requested at the end of training.
    # NOTE(review): eval_steps=50 is presumably unused while
    # eval_strategy="epoch" — confirm against the TrainingArguments docs.
    training_args = TrainingArguments(
        output_dir=f"./results_{lang}",
        eval_strategy="epoch",
        learning_rate=2e-5,
        eval_steps=50,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        weight_decay=0.01,
        logging_dir=f'./logs_{lang}',
        logging_steps=50,
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        save_total_limit=2,
        dataloader_pin_memory=False,
    )

    # Ratio of the most frequent to the least frequent training label.
    imbalance_ratio = max(Counter(train_labels).values()) / min(Counter(train_labels).values())

    # Above a 10:1 ratio use focal loss; otherwise use class-weighted
    # cross-entropy.
    if imbalance_ratio > 10:
        print(f"Severe imbalance detected (ratio: {imbalance_ratio:.2f}:1). Using Focal Loss.")
        TrainerClass = apply_focal_loss_trainer(alpha=0.25, gamma=2.0)
        trainer = TrainerClass(
            model=model,
            args=training_args,
            train_dataset=tokenized_train,
            eval_dataset=tokenized_val,
            compute_metrics=compute_metrics,
        )
    else:
        print(f"Moderate imbalance detected (ratio: {imbalance_ratio:.2f}:1). Using Class Weights.")
        trainer = WeightedTrainer(
            class_weights=class_weights,
            model=model,
            args=training_args,
            train_dataset=tokenized_train,
            eval_dataset=tokenized_val,
            compute_metrics=compute_metrics,
        )

    trainer.train()

    print(f"\nEvaluating Fine-Tuned Model for {lang.upper()}...")
    final_eval_results = trainer.evaluate()

    print(f"Fine-tuned results: Acc={final_eval_results['eval_accuracy']:.4f}, "
          f"F1={final_eval_results['eval_f1']:.4f}, "
          f"Precision={final_eval_results['eval_precision']:.4f}, "
          f"Recall={final_eval_results['eval_recall']:.4f}")

    # Everything the reporting code after this loop reads for this language.
    results[lang] = {
        "zero_shot": zero_shot_results,
        "fine_tuned": final_eval_results,
        "class_distribution": {
            "train": dict(Counter(train_labels)),
            "val": dict(Counter(val_labels))
        },
        "imbalance_ratio": imbalance_ratio
    }

# Per-language report: training class counts, then zero-shot metrics,
# fine-tuned metrics, and the signed difference between the two.
for lang, res in results.items():
    print(f"\nLanguage: {lang.upper()}")
    print("-" * 40)

    train_dist = res['class_distribution']['train']
    class_0_count = train_dist.get(0, 0)
    class_1_count = train_dist.get(1, 0)
    print(f"Training Set - Class 0: {class_0_count}, Class 1: {class_1_count}")
    print(f"Imbalance Ratio: {res['imbalance_ratio']:.2f}:1")

    zs_metrics = res['zero_shot']
    ft_metrics = res['fine_tuned']

    # Both metric tables share one layout; only the heading and source differ.
    sections = (
        ("\nZero-Shot Performance:", zs_metrics),
        ("\nFine-Tuned Performance (Balanced):", ft_metrics),
    )
    for heading, scores in sections:
        print(heading)
        print(f"  Accuracy: {scores['eval_accuracy']:.4f}")
        print(f"  F1: {scores['eval_f1']:.4f}")
        print(f"  Precision: {scores['eval_precision']:.4f}")
        print(f"  Recall: {scores['eval_recall']:.4f}")

    accuracy_gain = ft_metrics['eval_accuracy'] - zs_metrics['eval_accuracy']
    f1_gain = ft_metrics['eval_f1'] - zs_metrics['eval_f1']
    precision_gain = ft_metrics['eval_precision'] - zs_metrics['eval_precision']
    recall_gain = ft_metrics['eval_recall'] - zs_metrics['eval_recall']

    print("\nImprovements:")
    print(f"  Accuracy: {accuracy_gain:+.4f}")
    print(f"  F1: {f1_gain:+.4f}")
    print(f"  Precision: {precision_gain:+.4f}")
    print(f"  Recall: {recall_gain:+.4f}")

# Summary: mean (fine-tuned minus zero-shot) difference per metric, averaged
# over every language stored in `results`.
print("SUMMARY STATISTICS")

avg_improvements = {}
for metric in ('accuracy', 'f1', 'precision', 'recall'):
    key = f"eval_{metric}"
    deltas = [res['fine_tuned'][key] - res['zero_shot'][key] for res in results.values()]
    avg_improvements[metric] = np.mean(deltas)

print("Average Improvements Across Languages:")
for metric, improvement in avg_improvements.items():
    print(f"  {metric.capitalize()}: {improvement:+.4f}")

Language: AR

--- Class Distribution for AR_TRAIN ---
Class 0 (Not Answerable): 255 (9.97%)
Class 1 (Answerable): 2303 (90.03%)
Imbalance Ratio: 9.03:1

--- Class Distribution for AR_VALIDATION ---
Class 0 (Not Answerable): 52 (12.53%)
Class 1 (Answerable): 363 (87.47%)
Imbalance Ratio: 6.98:1
Computed class weights: {0: 5.015686274509804, 1: 0.5553625705601389}


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Performing Zero-Shot Evaluation for AR...


Zero-shot results: Acc=0.3807, F1=0.4990, Precision=0.8533, Recall=0.3526

Fine-Tuning Model with Class Balancing for AR...
Moderate imbalance detected (ratio: 9.03:1). Using Class Weights.


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.2537,0.148033,0.980723,0.98892,0.994429,0.983471
2,0.1494,0.150443,0.980723,0.98892,0.994429,0.983471
3,0.1123,0.149198,0.980723,0.98892,0.994429,0.983471



Evaluating Fine-Tuned Model for AR...


Fine-tuned results: Acc=0.9807, F1=0.9889, Precision=0.9944, Recall=0.9835
Language: KO

--- Class Distribution for KO_TRAIN ---
Class 0 (Not Answerable): 63 (2.60%)
Class 1 (Answerable): 2359 (97.40%)
Imbalance Ratio: 37.44:1

--- Class Distribution for KO_VALIDATION ---
Class 0 (Not Answerable): 19 (5.34%)
Class 1 (Answerable): 337 (94.66%)
Imbalance Ratio: 17.74:1
Computed class weights: {0: 19.22222222222222, 1: 0.5133531157270029}


Map:   0%|          | 0/2422 [00:00<?, ? examples/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Performing Zero-Shot Evaluation for KO...


Zero-shot results: Acc=0.0646, F1=0.0235, Precision=1.0000, Recall=0.0119

Fine-Tuning Model with Class Balancing for KO...
Severe imbalance detected (ratio: 37.44:1). Using Focal Loss.


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.007,0.014385,0.946629,0.972583,0.946629,1.0
2,0.0062,0.013551,0.946629,0.972583,0.946629,1.0
3,0.002,0.013622,0.949438,0.973913,0.951841,0.997033



Evaluating Fine-Tuned Model for KO...


Fine-tuned results: Acc=0.9494, F1=0.9739, Precision=0.9518, Recall=0.9970
Language: TE

--- Class Distribution for TE_TRAIN ---
Class 0 (Not Answerable): 45 (3.32%)
Class 1 (Answerable): 1310 (96.68%)
Imbalance Ratio: 29.11:1

--- Class Distribution for TE_VALIDATION ---
Class 0 (Not Answerable): 93 (24.22%)
Class 1 (Answerable): 291 (75.78%)
Imbalance Ratio: 3.13:1
Computed class weights: {0: 15.055555555555555, 1: 0.517175572519084}


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Performing Zero-Shot Evaluation for TE...


Zero-shot results: Acc=0.2422, F1=0.0136, Precision=0.5000, Recall=0.0069

Fine-Tuning Model with Class Balancing for TE...
Severe imbalance detected (ratio: 29.11:1). Using Focal Loss.


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.0037,0.028354,0.895833,0.934853,0.888545,0.986254
2,0.0043,0.027777,0.911458,0.944079,0.905363,0.986254
3,0.0017,0.028787,0.929688,0.955075,0.925806,0.986254



Evaluating Fine-Tuned Model for TE...


Fine-tuned results: Acc=0.9297, F1=0.9551, Precision=0.9258, Recall=0.9863

Language: AR
----------------------------------------
Training Set - Class 0: 255, Class 1: 2303
Imbalance Ratio: 9.03:1

Zero-Shot Performance:
  Accuracy: 0.3807
  F1: 0.4990
  Precision: 0.8533
  Recall: 0.3526

Fine-Tuned Performance (Balanced):
  Accuracy: 0.9807
  F1: 0.9889
  Precision: 0.9944
  Recall: 0.9835

Improvements:
  Accuracy: +0.6000
  F1: +0.4899
  Precision: +0.1411
  Recall: +0.6309

Language: KO
----------------------------------------
Training Set - Class 0: 63, Class 1: 2359
Imbalance Ratio: 37.44:1

Zero-Shot Performance:
  Accuracy: 0.0646
  F1: 0.0235
  Precision: 1.0000
  Recall: 0.0119

Fine-Tuned Performance (Balanced):
  Accuracy: 0.9494
  F1: 0.9739
  Precision: 0.9518
  Recall: 0.9970

Improvements:
  Accuracy: +0.8848
  F1: +0.9505
  Precision: -0.0482
  Recall: +0.9852

Language: TE
----------------------------------------
Training Set - Class 0: 45, Class 1: 1310
Imbalance Ra