# 03 — Multilingual mBERT Fine-Tuning & Cross-Lingual Transfer

Fine-tune `bert-base-multilingual-cased` (mBERT) under multiple training regimes to study cross-lingual transfer for sentiment analysis.

**Experiments:**
- **A**: mBERT trained on English only → evaluate on EN (in-language) and on FR, NL (zero-shot cross-lingual transfer)
- **B**: mBERT trained on all three languages (EN+FR+NL) → evaluate on each language
- **C**: mBERT trained on French only → evaluate on EN, FR, NL
- **D**: mBERT trained on Dutch only → evaluate on EN, FR, NL

**Prerequisite:** Run `01_data_exploration.ipynb` first.

In [None]:
import os
import json
import numpy as np
import pandas as pd
import torch
from datasets import load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import evaluate

# Experiment-wide configuration used by every cell below.
SEED = 42  # forwarded to TrainingArguments(seed=...)
MAX_LENGTH = 256  # tokenizer truncation / fixed padding length
MBERT = "bert-base-multilingual-cased"  # pretrained checkpoint id
DATA_DIR = "./data"  # per-language datasets are loaded from here
MODEL_DIR = "./models"  # fine-tuned models are saved under here

# Report the runtime environment so results can be compared across machines.
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

## Shared Utilities

In [None]:
import os
# Disable the Hugging Face tokenizers library's internal parallelism before the
# tokenizer is created (presumably to avoid its fork-related warnings when
# datasets/dataloaders spawn processes — confirm if this is still needed).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# `evaluate` metric objects, loaded once and reused by compute_metrics() on
# every evaluation pass. Loaded in the order: accuracy, f1, precision, recall.
accuracy_metric, f1_metric, precision_metric, recall_metric = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "f1", "precision", "recall")
)


def compute_metrics(eval_pred):
    """Score one evaluation pass for the Hugging Face ``Trainer``.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by ``Trainer``. The
            predicted class is the argmax over the last axis of ``logits``.

    Returns:
        One flat dict that merges the outputs of the accuracy, F1, precision
        and recall metrics, each called with its default settings. On a key
        collision the later metric wins. The rest of the notebook reads the
        ``accuracy`` and ``f1`` entries.
    """
    logits, labels = eval_pred
    predicted = np.argmax(logits, axis=-1)

    merged = {}
    for metric in (accuracy_metric, f1_metric, precision_metric, recall_metric):
        merged.update(metric.compute(predictions=predicted, references=labels))
    return merged


# A single mBERT tokenizer instance: tokenize_dataset() uses it, and
# run_experiment() saves it next to every fine-tuned model.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MBERT)


def tokenize_dataset(dataset):
    """Tokenize a ``datasets`` split with the shared mBERT tokenizer.

    Every example is truncated or padded to exactly ``MAX_LENGTH`` tokens.
    All original columns except ``label`` are removed, and the result is
    formatted to return PyTorch tensors.

    Args:
        dataset: Split with a ``text`` column. A ``label`` column, if present,
            is kept as the training target.

    Returns:
        The tokenized dataset: tokenizer outputs plus ``label``.
    """
    dropped = [name for name in dataset.column_names if name != "label"]

    def encode_batch(batch):
        # Fixed-length padding. The Trainer in run_experiment is built without
        # a tokenizer or data collator, so examples are presumably padded up
        # front so they can be stacked into batches as-is.
        return tokenizer(
            batch["text"],
            max_length=MAX_LENGTH,
            padding="max_length",
            truncation=True,
        )

    tokenized = dataset.map(encode_batch, batched=True, remove_columns=dropped)
    tokenized.set_format("torch")
    return tokenized


def get_training_args(output_dir, num_epochs=3, save_total_limit=1):
    """Build the ``TrainingArguments`` shared by every experiment.

    The effective train batch size is 8 examples x 4 gradient-accumulation
    steps = 32 per device. The model is evaluated and checkpointed once per
    epoch. At the end, the checkpoint with the highest validation ``f1`` is
    reloaded.

    Args:
        output_dir: Directory where ``Trainer`` writes its checkpoints.
        num_epochs: Number of training epochs.
        save_total_limit: Maximum number of checkpoints kept in
            ``output_dir``; older ones are deleted. Because
            ``load_best_model_at_end`` is on, the best checkpoint is always
            retained. The default of 1 therefore keeps at most the best and
            the most recent checkpoint, rather than one full mBERT checkpoint
            (weights plus optimizer state) per epoch for every experiment.
            Pass ``None`` to keep every checkpoint.

    Returns:
        A configured ``transformers.TrainingArguments``.
    """
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
        # Mixed precision only when a CUDA device is present.
        fp16=torch.cuda.is_available(),
        # Trades extra compute for lower activation memory.
        gradient_checkpointing=True,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        max_grad_norm=1.0,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=save_total_limit,
        load_best_model_at_end=True,
        # Key produced by compute_metrics(); selects the checkpoint to reload.
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_steps=100,
        report_to="none",
        seed=SEED,
        dataloader_num_workers=0,
        dataloader_pin_memory=False,
    )

## Load Data

In [None]:
# Load the saved per-language datasets and the pooled EN+FR+NL dataset
# (used by experiment B). Loaded in the order: en, fr, nl, combined.
en_data, fr_data, nl_data, combined_data = (
    load_from_disk(os.path.join(DATA_DIR, subdir))
    for subdir in ("en", "fr", "nl", "combined")
)

# Tokenize each language's test split once; every experiment reuses them.
test_sets = {
    code: tokenize_dataset(splits["test"])
    for code, splits in (("en", en_data), ("fr", fr_data), ("nl", nl_data))
}

print("Test set sizes:")
for lang, ds in test_sets.items():
    print(f"  {lang}: {len(ds)}")

## Helper: Train mBERT and Evaluate on All Languages

In [None]:
all_results = []  # One result dict per experiment; read by the summary and heatmap cells.


def run_experiment(experiment_name, train_data, val_data, output_dir):
    """Fine-tune a fresh mBERT classifier and evaluate it on every test language.

    Args:
        experiment_name: Label used in the logs and as the ``experiment``
            column of the results table. Re-running an experiment with the
            same name replaces its earlier entry in ``all_results`` instead
            of adding a duplicate row.
        train_data: Untokenized training split with ``text`` and ``label``
            columns.
        val_data: Untokenized validation split. It is used for the per-epoch
            evaluation and for choosing the best checkpoint.
        output_dir: Directory for Trainer checkpoints and for the final model
            and tokenizer.

    Returns:
        Dict with ``experiment`` plus ``{lang}_accuracy`` and ``{lang}_f1``
        for each language in ``test_sets``.
    """
    import gc

    from transformers import set_seed

    print(f"\n{'='*60}")
    print(f"Experiment: {experiment_name}")
    print(f"Training size: {len(train_data)}")
    print(f"{'='*60}")

    # Trainer only seeds the RNGs inside its own __init__, which runs after the
    # model below already exists. Without seeding here, the randomly
    # initialised classification head would depend on whatever ran earlier in
    # the notebook, so runs would not be reproducible.
    set_seed(SEED)
    model = AutoModelForSequenceClassification.from_pretrained(MBERT, num_labels=2)

    tok_train = tokenize_dataset(train_data)
    tok_val = tokenize_dataset(val_data)

    trainer = Trainer(
        model=model,
        args=get_training_args(output_dir),
        train_dataset=tok_train,
        eval_dataset=tok_val,
        compute_metrics=compute_metrics,
    )

    trainer.train()

    # Evaluate the (best, reloaded) model on all three test languages.
    exp_results = {"experiment": experiment_name}
    for lang, test_ds in test_sets.items():
        metrics = trainer.evaluate(test_ds, metric_key_prefix=f"test_{lang}")
        exp_results[f"{lang}_accuracy"] = metrics[f"test_{lang}_accuracy"]
        exp_results[f"{lang}_f1"] = metrics[f"test_{lang}_f1"]
        print(f"  {lang.upper()} test - Accuracy: {metrics[f'test_{lang}_accuracy']:.4f}, F1: {metrics[f'test_{lang}_f1']:.4f}")

    # Save the final model together with the shared tokenizer.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # A re-run experiment cell replaces its previous row (keeping its
    # position) instead of appending a duplicate to the results table.
    for idx, previous in enumerate(all_results):
        if previous["experiment"] == experiment_name:
            all_results[idx] = exp_results
            break
    else:
        all_results.append(exp_results)

    # Release model, optimizer and trainer state before the next experiment
    # allocates its own copy.
    del trainer, model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return exp_results

## Experiment A: mBERT Trained on English Only (Zero-Shot Transfer)

In [None]:
# English-only fine-tuning: EN is in-language; FR and NL measure zero-shot transfer.
exp_a = run_experiment(
    experiment_name="mBERT (EN only)",
    output_dir=os.path.join(MODEL_DIR, "mbert-en-only"),
    train_data=en_data["train"],
    val_data=en_data["validation"],
)

## Experiment B: mBERT Trained on All Three Languages

In [None]:
# Joint fine-tuning on the pooled EN+FR+NL data.
exp_b = run_experiment(
    experiment_name="mBERT (EN+FR+NL)",
    output_dir=os.path.join(MODEL_DIR, "mbert-multilingual"),
    train_data=combined_data["train"],
    val_data=combined_data["validation"],
)

## Experiment C: mBERT Trained on French Only

In [None]:
# French-only fine-tuning: FR is in-language; EN and NL measure transfer.
exp_c = run_experiment(
    experiment_name="mBERT (FR only)",
    output_dir=os.path.join(MODEL_DIR, "mbert-fr-only"),
    train_data=fr_data["train"],
    val_data=fr_data["validation"],
)

## Experiment D: mBERT Trained on Dutch Only

In [None]:
# Dutch-only fine-tuning: NL is in-language; EN and FR measure transfer.
exp_d = run_experiment(
    experiment_name="mBERT (NL only)",
    output_dir=os.path.join(MODEL_DIR, "mbert-nl-only"),
    train_data=nl_data["train"],
    val_data=nl_data["validation"],
)

## Cross-Lingual Results Summary

In [None]:
# One row per experiment (training configuration); one column per test metric.
df_results = pd.DataFrame(all_results)

print("Cross-Lingual Transfer Results (F1 Scores)")
print("=" * 60)

f1_cols = ["experiment", "en_f1", "fr_f1", "nl_f1"]
print(df_results[f1_cols].to_string(index=False, float_format="{:.4f}".format))

print("\nCross-Lingual Transfer Results (Accuracy)")
print("=" * 60)

acc_cols = ["experiment", "en_accuracy", "fr_accuracy", "nl_accuracy"]
print(df_results[acc_cols].to_string(index=False, float_format="{:.4f}".format))

# Save results. to_csv does not create missing folders, so make sure
# ./results exists; otherwise this cell could fail right after the expensive
# training runs.
os.makedirs("./results", exist_ok=True)
df_results.to_csv("./results/crosslingual_results.csv", index=False)
print("\nResults saved to ./results/crosslingual_results.csv")

## Quick Visualization: Cross-Lingual Transfer Heatmap

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Create heatmap of F1 scores: rows are training configurations, columns are
# evaluation languages.
heatmap_data = df_results.set_index("experiment")[["en_f1", "fr_f1", "nl_f1"]]
heatmap_data.columns = ["English", "French", "Dutch"]

fig, ax = plt.subplots(figsize=(8, 5))
sns.heatmap(
    heatmap_data,
    annot=True,
    fmt=".3f",
    cmap="YlOrRd",
    # Fixed colour range so figures are comparable across runs; scores below
    # vmin share the lowest colour, but the annotations still show the value.
    vmin=0.5,
    vmax=1.0,
    ax=ax,
)
ax.set_title("mBERT Cross-Lingual Transfer — F1 Scores")
ax.set_ylabel("Training Configuration")
ax.set_xlabel("Evaluation Language")

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("./results/figures", exist_ok=True)
plt.savefig("./results/figures/crosslingual_heatmap.png", dpi=150, bbox_inches="tight")
plt.show()