# In [None]:
from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, Trainer, TrainingArguments, EarlyStoppingCallback
import numpy as np
import torch
from sklearn.metrics import accuracy_score
from sklearn.utils.class_weight import compute_class_weight
from datasets import load_dataset
from collections import Counter

# Dataset and its two splits used below.
full_dataset = load_dataset("coastalcph/tydi_xor_rc")
train_dataset = full_dataset["train"]
val_dataset = full_dataset["validation"]

# Tokenizer shared by preprocessing and answer extraction.
model_checkpoint = "distilbert/distilbert-base-multilingual-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# Binary token labelling: "O" = outside the answer span, "ANS" = inside it.
label_list = ["O", "ANS"]
label_to_id = {name: idx for idx, name in enumerate(label_list)}
id_to_label = dict(enumerate(label_list))

# Window settings passed to the tokenizer (`max_length=` and `stride=`) when a
# long context overflows into several features.
max_length = 384
doc_stride = 128
# When the tokenizer pads on the right, the question is passed first and the
# context second; otherwise the order is swapped.
pad_on_right = tokenizer.padding_side == "right"

# Languages evaluated separately at the end of the script.
languages = ["ar", "ko", "te"]

def analyze_class_distribution(dataset, lang):
    """Print how many examples are answerable vs. not answerable.

    Args:
        dataset: Iterable of examples; each must have an ``"answer_start"``
            key, where ``-1`` marks an example without an answer.
        lang: Name used in the printed header (shown lower-cased).

    Returns:
        list[int]: One label per example, in iteration order: ``0`` for
        not answerable, ``1`` for answerable.
    """
    labels = [0 if example["answer_start"] == -1 else 1 for example in dataset]

    counter = Counter(labels)
    total = len(labels)

    print(f"\n--- Class Distribution for {lang.lower()} ---")
    if total == 0:
        # Nothing to report; avoids dividing by zero below.
        print("No examples.")
        return labels

    print(f"Class 0 (Not Answerable): {counter[0]} ({counter[0]/total:.2%})")
    print(f"Class 1 (Answerable): {counter[1]} ({counter[1]/total:.2%})")

    majority = max(counter[0], counter[1])
    minority = min(counter[0], counter[1])
    # If one class is missing entirely the imbalance is unbounded, not 1:1.
    ratio = majority / minority if minority else float("inf")
    print(f"Imbalance Ratio: {ratio:.2f}:1")

    return labels

def compute_class_weights(labels, ans_boost=100.0, num_classes=2):
    """Return per-class loss weights as a float32 tensor.

    Starts from scikit-learn's ``'balanced'`` weights for the classes present
    in ``labels``, then multiplies the weight of class ``1`` by ``ans_boost``.

    NOTE(review): ``labels`` are example-level here (answerable or not), but
    the result is used as token-level "O"/"ANS" weights in the loss. The large
    ``ans_boost`` presumably compensates for ANS tokens being much rarer than
    answerable examples; confirm this is intended.

    Args:
        labels: Sequence of integer class ids in ``[0, num_classes)``.
        ans_boost: Extra multiplier applied to class ``1`` (default 100).
        num_classes: Length of the returned vector, so it always matches the
            number of classifier outputs even if a class is absent.

    Returns:
        torch.Tensor: Shape ``(num_classes,)``, dtype float32. Classes absent
        from ``labels`` get weight 1.0 (before the boost).

    Raises:
        ValueError: If ``labels`` is empty or holds ids outside
            ``[0, num_classes)``.
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        raise ValueError("compute_class_weights: labels must not be empty")

    unique_labels = np.unique(labels)
    if unique_labels[0] < 0 or unique_labels[-1] >= num_classes:
        raise ValueError(
            f"compute_class_weights: label ids must lie in [0, {num_classes}), "
            f"got {unique_labels.tolist()}"
        )

    balanced = compute_class_weight('balanced', classes=unique_labels, y=labels)

    # Scatter into a full-length vector, indexed by class id rather than by
    # position among the classes that happen to be present.
    class_weights = np.ones(num_classes, dtype=np.float64)
    class_weights[unique_labels] = balanced
    if num_classes > 1:
        class_weights[1] *= ans_boost

    class_weight_dict = {idx: float(w) for idx, w in enumerate(class_weights)}
    print(f"Computed class weights: {class_weight_dict}")
    return torch.tensor(class_weights, dtype=torch.float32)

def create_token_labels(examples):
    """Tokenize a batch into sliding-window features with O/ANS token labels.

    Each example may yield several features because overflowing tokens are
    returned with a stride. Tokens that are not part of the context (special
    tokens, question tokens, padding) get ``-100``. Context tokens whose
    character span overlaps the gold answer get ``label_to_id["ANS"]``. All
    other context tokens get ``label_to_id["O"]``, as do all context tokens
    of examples with ``answer_start == -1``.

    Side effect: leading whitespace is stripped from ``examples["question"]``
    in place.

    Returns:
        The tokenizer output, with ``overflow_to_sample_mapping`` removed and
        these per-feature columns added: ``labels``, ``lang`` (copied from the
        source example) and ``offset_mapping``.
    """
    examples["question"] = [question.lstrip() for question in examples["question"]]

    first_key, second_key = ("question", "context") if pad_on_right else ("context", "question")
    tokenized = tokenizer(
        examples[first_key],
        examples[second_key],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_mapping = tokenized.pop("overflow_to_sample_mapping")
    offsets_mapping = tokenized.pop("offset_mapping")

    ignore_id = -100
    outside_id = label_to_id["O"]
    answer_id = label_to_id["ANS"]
    # Segment id of the context: second sequence when padding on the right.
    context_seq = 1 if pad_on_right else 0

    all_labels = []
    all_langs = []
    for feature_idx, offsets in enumerate(offsets_mapping):
        seq_ids = tokenized.sequence_ids(feature_idx)
        source = sample_mapping[feature_idx]
        span_start = examples["answer_start"][source]
        unanswerable = span_start == -1
        span_end = -1 if unanswerable else span_start + len(examples["answer"][source])

        # Carry the source example's language onto every feature it produced.
        all_langs.append(examples["lang"][source])

        feature_labels = []
        for seq_id, offset in zip(seq_ids, offsets):
            if seq_id != context_seq:
                # Special tokens (None) and question tokens are not scored.
                feature_labels.append(ignore_id)
            elif unanswerable or offset is None:
                feature_labels.append(outside_id)
            else:
                tok_start, tok_end = offset
                overlaps = tok_start < span_end and tok_end > span_start
                feature_labels.append(answer_id if overlaps else outside_id)
        all_labels.append(feature_labels)

    tokenized["labels"] = all_labels
    tokenized["lang"] = all_langs
    tokenized["offset_mapping"] = offsets_mapping
    return tokenized

def extract_answer_from_predictions(input_ids, predictions, offset_mapping, sequence_ids, context_id=1,
                                    context=None, ans_label=None):
    """Return the text of the first contiguous run of predicted ANS tokens.

    Only tokens whose segment id equals ``context_id`` are considered.

    Args:
        input_ids: Token ids of one feature.
        predictions: Predicted label id per token.
        offset_mapping: Per-token ``(char_start, char_end)`` pairs, or ``None``.
        sequence_ids: Per-token segment id (``None`` for special tokens).
        context_id: Segment id that marks context tokens.
        context: Optional context string the offsets refer to. When given and
            offsets are available, the answer is sliced from it by character
            offsets, which reproduces the original text exactly. Otherwise the
            span's tokens are decoded with ``tokenizer``.
        ans_label: Label id for "ANS"; defaults to ``label_to_id["ANS"]``.

    Returns:
        str: The stripped answer text. Returns ``""`` when no context token is
        predicted as ANS, or when the span's offset entries are ``None``.
    """
    if ans_label is None:
        ans_label = label_to_id["ANS"]

    ans_indices = [i for i, (pred, seq_id) in enumerate(zip(predictions, sequence_ids))
                   if pred == ans_label and seq_id == context_id]
    if not ans_indices:
        return ""

    # Grow the span from the first ANS token while indices stay consecutive.
    start_idx = end_idx = ans_indices[0]
    for idx in ans_indices[1:]:
        if idx != end_idx + 1:
            break
        end_idx = idx

    if offset_mapping is not None:
        start_offset = offset_mapping[start_idx]
        end_offset = offset_mapping[end_idx]
        if start_offset is None or end_offset is None:
            return ""
        if context is not None:
            char_start, char_end = start_offset[0], end_offset[1]
            return context[char_start:char_end].strip()

    tokens = input_ids[start_idx:end_idx + 1]
    return tokenizer.decode(tokens, skip_special_tokens=True).strip()

def compute_token_metrics(eval_preds, positive_label=None):
    """Token-level metrics for ``Trainer(compute_metrics=...)``.

    Positions whose label is ``-100`` are excluded from every metric.

    Args:
        eval_preds: ``(logits, labels)``. Logits carry the class axis last;
            labels have the logits' shape without that axis.
        positive_label: Label id treated as positive for precision, recall
            and F1; defaults to ``label_to_id["ANS"]``.

    Returns:
        dict with ``accuracy`` plus ``ans_precision``, ``ans_recall`` and
        ``ans_f1`` for the positive class. All values are 0.0 when no
        position is scored.
    """
    if positive_label is None:
        positive_label = label_to_id["ANS"]

    logits, labels = eval_preds
    preds = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    # Vectorised masking instead of a per-token Python loop.
    mask = labels != -100
    true_labels = labels[mask]
    pred_labels = preds[mask]

    if true_labels.size == 0:
        return {"accuracy": 0.0, "ans_precision": 0.0, "ans_recall": 0.0, "ans_f1": 0.0}

    accuracy = float(np.mean(true_labels == pred_labels))

    # Accuracy is dominated by the majority "O" class, so also report
    # precision/recall/F1 for the answer class.
    pred_pos = pred_labels == positive_label
    true_pos = true_labels == positive_label
    tp = int(np.sum(pred_pos & true_pos))
    precision = tp / int(pred_pos.sum()) if pred_pos.any() else 0.0
    recall = tp / int(true_pos.sum()) if true_pos.any() else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

    return {
        "accuracy": accuracy,
        "ans_precision": precision,
        "ans_recall": recall,
        "ans_f1": f1,
    }

def _sequence_ids_from_input_ids(input_ids):
    """Rebuild per-token segment ids for one feature from its token ids.

    Tokens before the first [SEP] get ``0`` and tokens after it get ``1``;
    [CLS]/[SEP]/[PAD] get ``None``. If the feature has no [SEP], every token
    gets ``None``. Numbering by position means ``1 if pad_on_right else 0``
    selects the context, the same rule used in ``create_token_labels``.
    """
    special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id}
    # Located once per feature rather than once per token.
    first_sep = next(
        (pos for pos, token_id in enumerate(input_ids) if token_id == tokenizer.sep_token_id),
        None,
    )
    sequence_ids = []
    for idx, token_id in enumerate(input_ids):
        if first_sep is None or token_id in special_ids:
            sequence_ids.append(None)
        else:
            sequence_ids.append(0 if idx < first_sep else 1)
    return sequence_ids


def _feature_to_example_indices(original_dataset):
    """Map each tokenized feature of ``original_dataset`` to its example index.

    Re-tokenizes the questions and contexts with the same ordering,
    truncation, ``max_length`` and ``doc_stride`` as ``create_token_labels``.
    It then returns the tokenizer's ``overflow_to_sample_mapping``, whose
    values index into ``original_dataset``. Padding is omitted because it
    does not change how many windows are produced.
    """
    if len(original_dataset) == 0:
        return []
    questions = [q.lstrip() for q in original_dataset["question"]]
    contexts = list(original_dataset["context"])
    encoded = tokenizer(
        questions if pad_on_right else contexts,
        contexts if pad_on_right else questions,
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
    )
    return list(encoded["overflow_to_sample_mapping"])


def compute_exact_match(trainer, eval_dataset, original_dataset):
    """Example-level exact match between predicted and gold answer strings.

    ``eval_dataset`` must hold, in order, the features that
    ``create_token_labels`` produced for ``original_dataset``. Predictions from
    all windows of one example are pooled, and the first non-empty window
    answer is compared with the example's ``"answer"``. The comparison
    ignores case and surrounding whitespace.

    Args:
        trainer: Object whose ``predict(eval_dataset)`` returns logits in
            ``.predictions`` and (optionally) labels in ``.label_ids``.
        eval_dataset: Tokenized features with ``input_ids`` and, ideally,
            ``offset_mapping``.
        original_dataset: Untokenized examples with ``question``, ``context``
            and ``answer``.

    Returns:
        tuple: ``(em_score, exact_matches, total)``, where ``total`` is the
        number of examples scored.

    Raises:
        ValueError: If ``eval_dataset`` does not have exactly the number of
            features that re-tokenizing ``original_dataset`` produces. A
            mismatch means features cannot be attributed to examples.
    """
    predictions = trainer.predict(eval_dataset)
    pred_labels = np.argmax(predictions.predictions, axis=-1)

    feature_to_example = _feature_to_example_indices(original_dataset)
    if len(feature_to_example) != len(eval_dataset):
        raise ValueError(
            f"compute_exact_match: eval_dataset has {len(eval_dataset)} features but "
            f"original_dataset tokenizes into {len(feature_to_example)}; they must "
            "come from the same examples in the same order"
        )

    context_id = 1 if pad_on_right else 0
    all_input_ids = eval_dataset["input_ids"]
    if "offset_mapping" in eval_dataset.column_names:
        all_offsets = eval_dataset["offset_mapping"]
    else:
        all_offsets = [None] * len(eval_dataset)

    sample_to_predictions = {}
    for i, (input_ids, offsets) in enumerate(zip(all_input_ids, all_offsets)):
        sequence_ids = _sequence_ids_from_input_ids(input_ids)
        pred_answer = extract_answer_from_predictions(
            input_ids, pred_labels[i], offsets, sequence_ids,
            context_id=context_id,
        )
        sample_to_predictions.setdefault(feature_to_example[i], []).append(pred_answer)

    gold_answers = list(original_dataset["answer"])
    exact_matches = 0
    for sample_idx, pred_answers in sample_to_predictions.items():
        true_answer = gold_answers[sample_idx] or ""
        # Heuristic: the first window that predicts a non-empty span wins.
        pred_answer = next((p for p in pred_answers if p), "")
        if true_answer.strip().lower() == pred_answer.strip().lower():
            exact_matches += 1
    total = len(sample_to_predictions)

    # Count predictions only at positions that carry a real label (context
    # tokens); padding and question tokens would otherwise dominate.
    label_ids = predictions.label_ids
    if label_ids is not None:
        scored = pred_labels[np.asarray(label_ids) != -100]
    else:
        scored = pred_labels.ravel()
    unique, counts = np.unique(scored, return_counts=True)
    pred_dist = {int(u): int(c) for u, c in zip(unique, counts)}
    print(f"Prediction distribution: {pred_dist}")

    em_score = exact_matches / total if total > 0 else 0
    return em_score, exact_matches, total

class WeightedTrainer(Trainer):
    """``Trainer`` that can apply per-class weights to the token loss.

    When ``class_weights`` is given, the loss is a weighted cross-entropy over
    all token positions, ignoring label ``-100``. Otherwise the model's own
    loss is returned unchanged.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Optional 1-D tensor with one weight per label id, or None.
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the (optionally class-weighted) loss, plus outputs if requested."""
        if self.class_weights is None or "labels" not in inputs:
            outputs = model(**inputs)
            loss = outputs.loss
        else:
            # Pop the labels so the model does not also compute (and then
            # discard) its own unweighted loss.
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.get("logits")
            loss_fct = torch.nn.CrossEntropyLoss(
                weight=self.class_weights.to(logits.device),
                ignore_index=-100,
            )
            loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        return (loss, outputs) if return_outputs else loss

print("\n" + "=" * 60)
print("Training multilingual sequence labeler")
print("=" * 60)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Example-level labels (1 = has an answer span, 0 = answer_start == -1),
# read straight from the column rather than decoding whole rows.
all_labels = [0 if start == -1 else 1 for start in train_dataset["answer_start"]]

class_weights = compute_class_weights(all_labels)

# Drop every raw column except "lang"; create_token_labels re-emits "lang"
# once per tokenized feature.
columns_to_remove = [col for col in train_dataset.column_names if col != "lang"]
tokenized_train, tokenized_val = (
    split.map(create_token_labels, batched=True, remove_columns=columns_to_remove)
    for split in (train_dataset, val_dataset)
)

# Token-classification model with one output per entry of label_list.
model_tc = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(label_list),
    id2label=id_to_label,
    label2id=label_to_id,
)

args_tc = TrainingArguments(
    output_dir="seq-lab-multilingual",
    # Optimisation
    learning_rate=3e-5,
    weight_decay=0.01,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluation / checkpointing once per epoch, keeping the best by accuracy
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Logging / integrations
    logging_steps=1,
    report_to=[],
    push_to_hub=False,
)

data_collator = DataCollatorForTokenClassification(tokenizer)

trainer_tc = WeightedTrainer(
    model=model_tc,
    args=args_tc,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_token_metrics,
    class_weights=class_weights,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

# Set to True to fine-tune; otherwise reuse weights previously saved to
# `checkpoint_path`.
RUN_TRAINING = False
checkpoint_path = "seq-lab-multilingual"

if RUN_TRAINING:
    print(f"\nTraining...")
    trainer_tc.train()
    # Trainer's per-epoch saves go to checkpoint-* subfolders of output_dir.
    # Write the in-memory model (the best one, with load_best_model_at_end=True)
    # to checkpoint_path itself so the load below finds it.
    trainer_tc.save_model(checkpoint_path)

print(f"\nLoading from checkpoint...")
model_tc = AutoModelForTokenClassification.from_pretrained(checkpoint_path)
trainer_tc.model = model_tc

print("\nEvaluating per language...")
results_by_lang = {}

for lang in languages:
    print("\n" + "=" * 60)
    print(f"Evaluating on {lang.lower()}")
    print("=" * 60)

    # Select this language's raw examples and their tokenized features; only
    # the "lang" column is read by the predicate.
    lang_val = val_dataset.filter(lambda value, target=lang: value == target, input_columns="lang")
    tokenized_lang_val = tokenized_val.filter(lambda value, target=lang: value == target, input_columns="lang")

    print("\nEvaluating token-level metrics...")
    token_metrics = trainer_tc.evaluate(tokenized_lang_val)

    print("Computing exact match...")
    em_score, em_count, em_total = compute_exact_match(trainer_tc, tokenized_lang_val, lang_val)

    results_by_lang[lang] = dict(
        token_metrics,
        exact_match=em_score,
        exact_match_count=f"{em_count}/{em_total}",
    )

    print("\n" + "=" * 60)
    print(f"Results for {lang.lower()}:")
    print("=" * 60)
    print(f"Token-level Accuracy:  {token_metrics['eval_accuracy']:.4f}")
    print(f"Exact Match:           {em_score:.4f} ({em_count}/{em_total})")
    print("=" * 60 + "\n")

# Summary table: one row per language with token accuracy and exact match.
print(f"{'Language':<12} {'Acc':<8}  {'EM':<8}")
print(f"{'-'*60}")
for lang in languages:
    metrics = results_by_lang[lang]
    # Pad each number to its header's width, with the same separators as
    # the header row, so the columns line up and never run together.
    print(f"{lang.lower():<12} "
          f"{metrics['eval_accuracy']:<8.4f}  "
          f"{metrics['exact_match']:<8.4f}")
print(f"{'='*60}\n")

results_by_lang