In [1]:
from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, Trainer, TrainingArguments, EarlyStoppingCallback
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.utils.class_weight import compute_class_weight
from datasets import load_dataset
from collections import Counter, defaultdict

full_dataset = load_dataset("coastalcph/tydi_xor_rc")

model_checkpoint = "distilbert/distilbert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

label_list = ["O", "ANS"]
label_to_id = {l: i for i, l in enumerate(label_list)}
id_to_label = {i: l for l, i in label_to_id.items()}

max_length = 384
doc_stride = 128
pad_on_right = tokenizer.padding_side == "right"

languages = ["ar", "ko", "te"]

train_dataset = full_dataset["train"]
val_dataset = full_dataset["validation"]

def analyze_class_distribution(dataset, lang):
    labels = []
    for example in dataset:
        if example["answer_start"] == -1:
            labels.append(0)
        else:
            labels.append(1)
    
    counter = Counter(labels)
    total = len(labels)
    
    print(f"\n--- Class Distribution for {lang.lower()} ---")
    print(f"Class 0 (Not Answerable): {counter[0]} ({counter[0]/total:.2%})")
    print(f"Class 1 (Answerable): {counter[1]} ({counter[1]/total:.2%})")
    print(f"Imbalance Ratio: {max(counter.values()) / min(counter.values()):.2f}:1")
    
    return labels

def compute_class_weights(labels):
    unique_labels = np.unique(labels)
    class_weights = compute_class_weight('balanced', classes=unique_labels, y=labels)
    class_weights[1] *= 150
    class_weight_dict = dict(zip(unique_labels, class_weights))
    
    print(f"Computed class weights: {class_weight_dict}")
    return torch.tensor(class_weights, dtype=torch.float32)

def create_token_labels(examples):
    examples["question"] = [q.lstrip() for q in examples["question"]]
    tokenized = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_mapping = tokenized["overflow_to_sample_mapping"]
    offsets_mapping = tokenized["offset_mapping"]

    labels = []
    langs = [] 
    for i, offsets in enumerate(offsets_mapping):
        sequence_ids = tokenized.sequence_ids(i)
        sample_idx = sample_mapping[i]
        answer_start = examples["answer_start"][sample_idx]
        answer_text = examples["answer"][sample_idx]
        answer_end = -1 if answer_start == -1 else answer_start + len(answer_text)
        
        langs.append(examples["lang"][sample_idx])

        example_labels = []
        context_id = 1 if pad_on_right else 0
        
        for idx, offset in enumerate(offsets):
            if sequence_ids[idx] is None:
                example_labels.append(-100)
            elif sequence_ids[idx] != context_id:
                example_labels.append(-100)
            else:
                if answer_start == -1 or offset is None:
                    example_labels.append(label_to_id["O"])
                else:
                    start, end = offset
                    if start >= answer_end or end <= answer_start:
                        example_labels.append(label_to_id["O"])
                    else:
                        example_labels.append(label_to_id["ANS"])
        labels.append(example_labels)

    tokenized["labels"] = labels
    tokenized["lang"] = langs  
    tokenized["offset_mapping"] = offsets_mapping
    return tokenized

def extract_answer_from_predictions(input_ids, predictions, offset_mapping, sequence_ids, context_id=1):
    spans = []
    start = None
    for i, (pred, sid) in enumerate(zip(predictions, sequence_ids)):
        if sid == context_id and pred == label_to_id["ANS"]:
            if start is None:
                start = i
        else:
            if start is not None:
                spans.append((start, i-1))
                start = None
    if start is not None:
        spans.append((start, len(predictions)-1))
    best_text = ""
    best_len = -1
    for s, e in spans:
        if offset_mapping[s] is not None and offset_mapping[e] is not None:
            tokens = input_ids[s:e+1]
            text = tokenizer.decode(tokens, skip_special_tokens=True).strip()
            if len(text) > best_len:
                best_text = text
                best_len = len(text)
    return best_text

def compute_token_metrics(eval_preds):
    logits, labels = eval_preds
    preds = np.argmax(logits, axis=-1)
    
    true_labels = []
    pred_labels = []
    
    for p, l in zip(preds, labels):
        for pi, li in zip(p, l):
            if li != -100:
                true_labels.append(li)
                pred_labels.append(pi)
    
    acc = accuracy_score(true_labels, pred_labels)
    
    return {
        "accuracy": acc
    }

def compute_f1(trainer, eval_dataset, original_dataset):
    predictions = trainer.predict(eval_dataset)
    pred_logits = predictions.predictions
    pred_labels = np.argmax(pred_logits, axis=-1)
    
    f1_scores = []
    
    sample_to_predictions = defaultdict(list)
    
    for i in range(len(eval_dataset)):
        input_ids = eval_dataset[i]["input_ids"]
        offset_mapping = eval_dataset[i].get("offset_mapping", None)
        
        sequence_ids = []
        for idx in range(len(input_ids)):
            sid = None
            if input_ids[idx] == tokenizer.cls_token_id or input_ids[idx] == tokenizer.sep_token_id or input_ids[idx] == tokenizer.pad_token_id:
                sid = None
            else:
                sep_indices = [j for j, token_id in enumerate(input_ids) if token_id == tokenizer.sep_token_id]
                if sep_indices:
                    if pad_on_right:
                        sid = 0 if idx < sep_indices[0] else 1
                    else:
                        sid = 1 if idx < sep_indices[0] else 0
            sequence_ids.append(sid)
        
        pred_answer = extract_answer_from_predictions(
            input_ids, pred_labels[i], offset_mapping, sequence_ids, 
            context_id=1 if pad_on_right else 0
        )
        
        if hasattr(eval_dataset, 'features') and 'overflow_to_sample_mapping' in eval_dataset.features:
            original_idx = eval_dataset[i].get('overflow_to_sample_mapping', i)
        else:
            original_idx = i % len(original_dataset)
        
        sample_to_predictions[original_idx].append(pred_answer)
    
    for sample_idx, pred_answers in sample_to_predictions.items():
        if sample_idx < len(original_dataset):
            true_answer = original_dataset[sample_idx]["answer"].strip().lower()
            
            best_f1 = 0.0
            for pred_answer in pred_answers:
                pred_answer = pred_answer.strip().lower()
                if pred_answer == "":
                    f1_scores.append(0.0)
                    continue
                true_chars = set(range(len(true_answer)))
                pred_chars = set(range(len(pred_answer)))
                common = min(len(true_chars), len(pred_chars))
                precision = common / len(pred_chars) if len(pred_chars) > 0 else 0.0
                recall = common / len(true_chars) if len(true_chars) > 0 else 0.0
                f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
                if f1 > best_f1:
                    best_f1 = f1
            f1_scores.append(best_f1)
    
    pred_labels_flat = pred_labels.flatten()
    valid_preds = pred_labels_flat[pred_labels_flat != -100]
    unique, counts = np.unique(valid_preds, return_counts=True)
    pred_dist = dict(zip(unique, counts))
    print(f"Prediction distribution: {pred_dist}")
    
    avg_f1 = np.mean(f1_scores) if f1_scores else 0.0
    return avg_f1, len(f1_scores)

def focal_loss(logits, labels, alpha=None, gamma=2.0, ignore_index=-100, label_smoothing=0.1):
    ce_loss = F.cross_entropy(
        logits, 
        labels, 
        weight=alpha, 
        reduction='none', 
        ignore_index=ignore_index,
        label_smoothing=label_smoothing
    )
    pt = torch.exp(-ce_loss)
    focal_loss = ((1 - pt) ** gamma) * ce_loss
    return focal_loss.mean()

class WeightedTrainer(Trainer):
    def __init__(self, *args, class_weights=None, focal_gamma=2.0, label_smoothing=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        self.focal_gamma = focal_gamma
        self.label_smoothing = label_smoothing
    
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        
        alpha = self.class_weights.to(model.device) if self.class_weights is not None else None
        loss = focal_loss(
            logits.view(-1, self.model.config.num_labels), 
            labels.view(-1),
            alpha=alpha,
            gamma=self.focal_gamma,
            ignore_index=-100,
            label_smoothing=self.label_smoothing
        )
        
        return (loss, outputs) if return_outputs else loss

print(f"\n{'='*60}")
print(f"Training multilingual sequence labeler")
print(f"{'='*60}")

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

all_labels = []
for example in train_dataset:
    if example["answer_start"] == -1:
        all_labels.append(0)
    else:
        all_labels.append(1)

class_weights = compute_class_weights(all_labels)

columns_to_remove = [col for col in train_dataset.column_names if col != "lang"]
tokenized_train = train_dataset.map(
    create_token_labels, 
    batched=True, 
    remove_columns=columns_to_remove
)
tokenized_val = val_dataset.map(
    create_token_labels, 
    batched=True, 
    remove_columns=columns_to_remove
)

model_tc = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint, 
    num_labels=len(label_list), 
    id2label=id_to_label, 
    label2id=label_to_id,
    dropout=0.2,
    attention_dropout=0.2,
)

args_tc = TrainingArguments(
    output_dir="seq-lab-multilingual",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    num_train_epochs=50,
    weight_decay=0.01,
    warmup_steps=500,
    warmup_ratio=0.1,
    max_grad_norm=1.0,
    logging_steps=50,
    save_strategy="no",
    metric_for_best_model="accuracy",
    greater_is_better=True,
    report_to=[],
    push_to_hub=False,
    lr_scheduler_type="cosine",
    adam_epsilon=1e-8,
    adam_beta1=0.9,
    adam_beta2=0.999,
    fp16=True,
)

data_collator = DataCollatorForTokenClassification(tokenizer)

trainer_tc = WeightedTrainer(
    model=model_tc,
    args=args_tc,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_token_metrics,
    class_weights=class_weights,
    focal_gamma=2.5,
    label_smoothing=0.1,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

print(f"\nTraining...")
trainer_tc.train()

print(f"\nEvaluating per language...")
results_by_lang = {}

for lang in languages:
    print(f"\n{'='*60}")
    print(f"Evaluating on {lang.lower()}")
    print(f"{'='*60}")
    
    lang_val = val_dataset.filter(lambda ex: ex["lang"] == lang)
    tokenized_lang_val = tokenized_val.filter(lambda ex: ex["lang"] == lang)
    
    print(f"\nEvaluating token-level metrics...")
    token_metrics = trainer_tc.evaluate(tokenized_lang_val)
    
    print(f"Computing F1...")
    f1_score, f1_count = compute_f1(trainer_tc, tokenized_lang_val, lang_val)
    
    results_by_lang[lang] = {
        **token_metrics,
        "f1": f1_score,
        "f1_count": f1_count
    }
    
    print(f"\n{'='*60}")
    print(f"Results for {lang.lower()}:")
    print(f"{'='*60}")
    print(f"Token-level Accuracy:  {token_metrics['eval_accuracy']:.4f}")
    print(f"F1 Score:              {f1_score:.4f} (n={f1_count})")
    print(f"{'='*60}\n")

print(f"\n{'='*60}")
print("CROSS-LANGUAGE COMPARISON")
print(f"{'='*60}")
print(f"{'Language':<12} {'Acc':<8}  {'F1':<8}")
print(f"{'-'*60}")
lang_names = {"ar": "Arabic", "ko": "Korean", "te": "Telugu"}
for lang in languages:
    metrics = results_by_lang[lang]
    print(f"{lang_names.get(lang, lang).upper():<12} "
          f"{metrics['eval_accuracy']:.4f}   "
          f"{metrics['f1']:.4f}")
print(f"{'='*60}\n")

results_by_lang


Training multilingual sequence labeler
Training samples: 15343
Validation samples: 3011
Computed class weights: {0: 5.555032585083273, 1: 82.41834980661795}


Map:   0%|          | 0/3011 [00:00<?, ? examples/s]

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert/distilbert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  super().__init__(*args, **kwargs)
Using EarlyStoppingCallback without load_best_model_at_end=True. Once training is finished, the best model will not be loaded automatically.



Training...


Epoch,Training Loss,Validation Loss,Accuracy
1,2.6902,2.664846,0.500284
2,2.565,2.639022,0.606317
3,2.5359,2.619267,0.726457
4,2.5553,2.607516,0.717664
5,2.4888,2.614565,0.85192
6,2.5215,2.612075,0.847645
7,2.4485,2.636699,0.900989
8,2.3987,2.63715,0.886161
9,2.3886,2.636163,0.927144
10,2.3825,2.644253,0.916865



Evaluating per language...

Evaluating on ar


Filter:   0%|          | 0/3072 [00:00<?, ? examples/s]


Evaluating token-level metrics...


Computing F1...
Prediction distribution: {0: 147919, 1: 19505}

Results for ar:
Token-level Accuracy:  0.9488
F1 Score:              0.4709 (n=126)


Evaluating on ko


Filter:   0%|          | 0/3072 [00:00<?, ? examples/s]


Evaluating token-level metrics...


Computing F1...
Prediction distribution: {0: 120693, 1: 18315}

Results for ko:
Token-level Accuracy:  0.9504
F1 Score:              0.5195 (n=105)


Evaluating on te


Filter:   0%|          | 0/3072 [00:00<?, ? examples/s]


Evaluating token-level metrics...


Computing F1...
Prediction distribution: {0: 128793, 1: 20199}

Results for te:
Token-level Accuracy:  0.9563
F1 Score:              0.7060 (n=285)


CROSS-LANGUAGE COMPARISON
Language     Acc       F1      
------------------------------------------------------------
ARABIC       0.9488   0.4709
KOREAN       0.9504   0.5195
TELUGU       0.9563   0.7060



{'ar': {'eval_loss': 2.7534103393554688,
  'eval_accuracy': 0.9487802150400204,
  'eval_runtime': 1.5483,
  'eval_samples_per_second': 281.594,
  'eval_steps_per_second': 35.522,
  'epoch': 20.0,
  'f1': 0.47086206945954434,
  'f1_count': 126},
 'ko': {'eval_loss': 2.6023454666137695,
  'eval_accuracy': 0.9503908725966618,
  'eval_runtime': 1.2988,
  'eval_samples_per_second': 278.718,
  'eval_steps_per_second': 35.417,
  'epoch': 20.0,
  'f1': 0.5195303797508997,
  'f1_count': 105},
 'te': {'eval_loss': 2.7963624000549316,
  'eval_accuracy': 0.9562619752656332,
  'eval_runtime': 1.4434,
  'eval_samples_per_second': 268.81,
  'eval_steps_per_second': 33.948,
  'epoch': 20.0,
  'f1': 0.7059633890748965,
  'f1_count': 285}}