# Phase 3 - Mistral Extension - XLM-R Variety Classifier

## What this notebook is for
This notebook trains and evaluates a **variety identification model** using **XLM-RoBERTa** to classify English text into the target varieties:
- **en-AU**, **en-IN**, **en-UK**

This model is a **prerequisite** for the Mistral extension pipeline, where we need a reliable way to:
- assign or validate variety labels (e.g. with the predicted variety),
- compute variety-aware statistics,
- optionally route examples to variety-specific components in downstream experiments.
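A minimal sketch of confidence-gated routing, assuming the classifier's raw logits are available per example (for instance from `trainer.predict`). The function `route`, the `min_confidence=0.6` threshold, and the adapter names are hypothetical illustrations, not part of this project's code.

```python
import math

# Index order mirrors LABEL_MAP in the training cell: en-UK=0, en-AU=1, en-IN=2
ID2VARIETY = {0: "en-UK", 1: "en-AU", 2: "en-IN"}

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route(logits, components, fallback="shared", min_confidence=0.6):
    """Return (component, predicted_variety, confidence) for one example.

    Falls back to `fallback` when the top probability is below `min_confidence`
    or no component is registered for the predicted variety.
    """
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    variety = ID2VARIETY[best]
    if probs[best] < min_confidence or variety not in components:
        return fallback, variety, probs[best]
    return components[variety], variety, probs[best]

# Hypothetical component names, for illustration only
adapters = {"en-UK": "adapter_uk", "en-AU": "adapter_au", "en-IN": "adapter_in"}
print(route([0.2, 3.1, -1.0], adapters))  # confident en-AU prediction -> "adapter_au"
print(route([0.0, 0.0, 0.0], adapters))   # uniform probabilities -> "shared"
```

The fallback keeps low-confidence examples (e.g. short or variety-neutral texts) on a shared path instead of forcing a possibly wrong variety-specific choice.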

## Inputs
- BESSTIE-style data with at least `text`, a variety label (`variety` in the `unswnlporg/BESSTIE` release loaded below; `variety_name` or `variety_id` in processed versions), and split info (train/val/test indices).
- (If applicable) existing `data/processed/...` indices defining train/test sets.

## What this notebook does (high level)
1. Load the dataset and keep only examples with valid variety labels (**en-AU/en-IN/en-UK**)
2. Preprocess + tokenize using **XLM-R**
3. Train a **3-class classifier** for variety prediction
4. Evaluate performance: accuracy and F1 (weighted F1 for checkpoint selection; macro F1 plus per-class precision/recall/F1 on the test set)
5. Save the trained classifier checkpoint and evaluation artifacts
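Step 1 can be made explicit with a filter before label mapping; the training cell looks up `LABEL_MAP[v]` directly, which raises a `KeyError` on any other variety string. A plain-Python sketch, assuming records are dicts with `text` and a `variety` field (the helper name is hypothetical):

```python
# Mirrors LABEL_MAP in the training cell
LABEL_MAP = {"en-UK": 0, "en-AU": 1, "en-IN": 2}

def keep_valid_varieties(records, variety_key="variety"):
    """Drop records whose variety is not one of the three targets; attach integer labels.

    Returns (kept_records, number_dropped).
    """
    kept, dropped = [], 0
    for rec in records:
        variety = rec.get(variety_key)
        if variety not in LABEL_MAP:
            dropped += 1
            continue
        kept.append({"text": rec["text"], "labels": LABEL_MAP[variety]})
    return kept, dropped

records = [
    {"text": "arvo at the beach", "variety": "en-AU"},
    {"text": "an unrelated example", "variety": "en-US"},
    {"text": "kindly do the needful", "variety": "en-IN"},
]
kept, dropped = keep_valid_varieties(records)
print(kept, dropped)  # two kept records with labels 1 and 2; one dropped
```

With the `datasets` library, the same predicate can be applied before `map`, e.g. `raw_dataset.filter(lambda ex: ex["variety"] in LABEL_MAP)`.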

## Outputs saved (artifacts)
- `checkpoints/` → trained XLM-R variety classifier weights
- `metrics/` → CSV metrics (overall + per-class breakdown)
- `predictions/` → CSV predictions (true variety, predicted variety, probabilities)
- `plots/` → confusion matrix + paper-style bar charts (if enabled)
- `analysis/` → example-level error analysis (hard confusions AU↔UK, AU/UK↔IN)
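A sketch of the `predictions/` artifact, assuming the logits and label ids come from `trainer.predict(test_ds)` (its `predictions` and `label_ids` fields). Because the tokenized dataset drops the `text` column, the texts would have to come from the untokenized split (e.g. `raw_dataset["validation"]["text"]`). The column names and the output path are hypothetical.

```python
import csv
import io
import math

VARIETIES = ["en-UK", "en-AU", "en-IN"]  # index order of LABEL_MAP

def softmax(row):
    """Numerically stable softmax over one row of logits."""
    top = max(row)
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def write_predictions(fh, texts, label_ids, logits):
    """Write one row per example: text, true and predicted variety, per-class probabilities."""
    writer = csv.writer(fh)
    writer.writerow(["text", "true_variety", "pred_variety"] + [f"p_{v}" for v in VARIETIES])
    for text, true_id, row in zip(texts, label_ids, logits):
        probs = softmax(list(row))
        pred_id = max(range(len(probs)), key=probs.__getitem__)
        writer.writerow([text, VARIETIES[true_id], VARIETIES[pred_id]]
                        + [f"{p:.4f}" for p in probs])

# Toy usage with an in-memory buffer; in the notebook `fh` could be
# open("predictions/test_predictions.csv", "w", newline="")  (hypothetical path)
buf = io.StringIO()
write_predictions(buf, ["g'day mate", "cheers"], [1, 0],
                  [[0.1, 2.5, -0.3], [1.2, 0.4, -2.0]])
print(buf.getvalue())
```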

## How this connects to the Mistral extension
Downstream (Mistral extension), this classifier can be used to:
- **predict variety** for routing / conditioning,
- **measure** whether improvements are consistent across varieties,
- support **variety-aware** prompting or adapter selection (depending on the extension design).
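One way to measure whether gains are consistent across varieties is to score both systems per variety and compare. A sketch, assuming each evaluated example carries a variety (gold, or predicted by this classifier) and a correctness flag; the helper names are hypothetical and the data below are toy values, not results.

```python
from collections import defaultdict

def per_variety_accuracy(varieties, correct):
    """Accuracy per variety, given each example's variety and whether the system was right."""
    hits, totals = defaultdict(int), defaultdict(int)
    for variety, ok in zip(varieties, correct):
        totals[variety] += 1
        hits[variety] += int(ok)
    return {v: hits[v] / totals[v] for v in totals}

def consistency_report(baseline, candidate):
    """Per-variety change from baseline to candidate, plus whether no variety regressed."""
    deltas = {v: candidate[v] - baseline[v] for v in baseline}
    return {
        "deltas": deltas,
        "worst_variety": min(deltas, key=deltas.get),
        "no_regression": all(d >= 0 for d in deltas.values()),
    }

varieties = ["en-AU", "en-AU", "en-IN", "en-UK"]
base = per_variety_accuracy(varieties, [True, False, True, False])
cand = per_variety_accuracy(varieties, [True, True, True, False])
print(consistency_report(base, cand))
```

Reporting the worst variety alongside the average makes it harder for a gain on one variety to hide a regression on another.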

## Reproducibility
- All runs should use fixed `seed` values and consistent split/index files.
- Keep output naming consistent with the project run naming scheme.
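`TrainingArguments` already accepts a `seed` (default 42) and `transformers` provides `set_seed`. For code that runs outside the Trainer (custom splits, sampling examples for error analysis), a minimal helper could look like this, assuming NumPy is installed and PyTorch optionally so; the function name is hypothetical.

```python
import random

import numpy as np

def set_all_seeds(seed: int = 42) -> None:
    """Seed Python's and NumPy's global RNGs, plus PyTorch's if it is installed."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable

set_all_seeds(42)
first = (random.random(), np.random.rand())
set_all_seeds(42)
print(first == (random.random(), np.random.rand()))  # same draws after re-seeding
```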

In [None]:
!pip uninstall unsloth unsloth-zoo -y
!pip install --upgrade --no-cache-dir unsloth unsloth-zoo
!pip install bitsandbytes -U


In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset, ClassLabel
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, f1_score, classification_report

# ==========================================
# 1. CONFIGURATION
# ==========================================
MODEL_NAME = "xlm-roberta-base"
ROUTER_OUTPUT_DIR = "./variety_router_full_finetune_stratified"
MAX_SEQ_LENGTH = 512
# Explicitly mapping the VARIETY strings to integer labels
LABEL_MAP = {"en-UK": 0, "en-AU": 1, "en-IN": 2}

os.makedirs(ROUTER_OUTPUT_DIR, exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"

# ==========================================
# 2. DATA LOADING & CLEANING
# ==========================================
print("Loading BESSTIE dataset...")
raw_dataset = load_dataset("unswnlporg/BESSTIE")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token_id = 1  # XLM-RoBERTa's <pad> token

def preprocess_for_router(examples):
    # CRITICAL: We only use 'variety' for labels, ignoring the original 'label' column
    labels = [LABEL_MAP[v] for v in examples["variety"]]

    # Tokenize text according to the benchmark baseline
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        padding="max_length"
    )

    # Add our new three-class mapped labels
    tokenized["labels"] = labels
    return tokenized

print("Preprocessing and Tokenizing (Removing existing task labels)...")
# We remove ALL original columns to ensure the model only sees 'input_ids', 'attention_mask', and our new 'labels'
tokenized_full = raw_dataset.map(
    preprocess_for_router,
    batched=True,
    remove_columns=raw_dataset["train"].column_names # Removes 'text', 'label', 'variety', 'source', 'task'
)

# --- Casting to ClassLabel for Stratification ---
train_dataset = tokenized_full["train"].class_encode_column("labels")

# 80/20 Stratified Split from the original 'train' set
train_val_split = train_dataset.train_test_split(
    test_size=0.2,
    seed=42,
    stratify_by_column="labels"
)

train_ds = train_val_split["train"]
val_ds = train_val_split["test"]
test_ds = tokenized_full["validation"] # Original 'validation' used as final Test set

print(f"Counts -> Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")

# ==========================================
# 3. MODEL (FULL FINE-TUNING - NO LORA)
# ==========================================
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABEL_MAP)  # 3 variety classes
).to(device)

print(f"Total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# ==========================================
# 4. TRAINING ARGUMENTS
# ==========================================
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        # Weighted F1 (support-weighted mean over the 3 classes); used for checkpoint selection
        "f1": f1_score(labels, preds, average="weighted")
    }

training_args = TrainingArguments(
    output_dir = ROUTER_OUTPUT_DIR,
    eval_strategy = "epoch",
    save_strategy = "epoch",
    load_best_model_at_end = True,
    metric_for_best_model = "f1",  # weighted F1 from compute_metrics
    save_total_limit = 1,
    learning_rate = 2e-5,  # lower LR for full fine-tuning (1e-4 was used for LoRA)
    per_device_train_batch_size = 32,
    num_train_epochs = 20,  # upper bound; early stopping may end training earlier
    weight_decay = 0.01,  # standard weight decay for full fine-tuning
    fp16 = torch.cuda.is_available(),  # mixed precision needs a GPU; fp16=True fails on CPU
    optim = "adamw_torch",  # standard AdamW (8-bit optimizer dropped along with PEFT)
    seed = 42,  # explicit (also the Trainer default) for reproducibility
    logging_steps = 50,
    report_to = "none"
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = train_ds,
    eval_dataset = val_ds,
    processing_class = tokenizer,
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics = compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=8, early_stopping_threshold=0.01)]  # stop after 8 evaluations without an absolute F1 gain > 0.01
)

# ==========================================
# 5. EXECUTION & FINAL TEST EVALUATION
# ==========================================
print("Starting Full Fine-tuning...")
trainer.train()

print("\n" + "="*40)
print("--- FINAL EVALUATION ON ORIGINAL 'VALIDATION' (TEST) SET ---")
print("="*40)

# Evaluate on the Test set (original validation split)
test_results = trainer.predict(test_ds)
preds = np.argmax(test_results.predictions, axis=-1)
y_true = test_results.label_ids

# labels=[0, 1, 2] forces the report to list all 3 varieties, even if one is never predicted
print(classification_report(
    y_true,
    preds,
    labels=[0, 1, 2],
    target_names=list(LABEL_MAP.keys())
))

trainer.save_model("./3variety_router_full_finetune_stratified_best")

Loading BESSTIE dataset...


Preprocessing and Tokenizing (Removing existing task labels)...


Counts -> Train: 14208 | Val: 3552 | Test: 2428
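These counts follow from the 80/20 split of the 17,760 original training examples: 0.2 × 17,760 = 3,552 for validation and 14,208 for training. Stratifying by label keeps each variety's share the same in both parts. A plain-Python sketch of the idea (not the `datasets` implementation, so the exact indices and per-class rounding may differ; the function name is hypothetical):

```python
import random
from collections import defaultdict

def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx) so each label keeps roughly the same share in both parts."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for i, y in enumerate(labels):
        by_label[y].append(i)
    train_idx, test_idx = [], []
    for idx in by_label.values():
        rng.shuffle(idx)
        n_test = round(len(idx) * test_size)  # per-class share of the held-out part
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return sorted(train_idx), sorted(test_idx)

labels = [0] * 50 + [1] * 30 + [2] * 20  # toy label distribution
train_idx, val_idx = stratified_split(labels)
print(len(train_idx), len(val_idx))  # 80 20
```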


XLMRobertaForSequenceClassification LOAD REPORT from: xlm-roberta-base
Key                         | Status     | 
----------------------------+------------+-
lm_head.bias                | UNEXPECTED | 
roberta.pooler.dense.bias   | UNEXPECTED | 
roberta.pooler.dense.weight | UNEXPECTED | 
lm_head.dense.weight        | UNEXPECTED | 
lm_head.layer_norm.bias     | UNEXPECTED | 
lm_head.layer_norm.weight   | UNEXPECTED | 
lm_head.dense.bias          | UNEXPECTED | 
classifier.out_proj.weight  | MISSING    | 
classifier.out_proj.bias    | MISSING    | 
classifier.dense.bias       | MISSING    | 
classifier.dense.weight     | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


Total trainable parameters: 278,045,955
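The parameter count above can be reproduced by hand, assuming the standard `xlm-roberta-base` configuration (vocabulary 250,002, hidden size 768, 12 layers, feed-forward size 3,072, 514 positions, one token type). These config values are assumptions, not read from the notebook. The sum matches only without the pretrained pooler, consistent with the pooler weights being reported as UNEXPECTED.

```python
# Assumed xlm-roberta-base config values
VOCAB, HIDDEN, LAYERS, FFN, MAX_POS, TYPE_VOCAB = 250_002, 768, 12, 3_072, 514, 1
NUM_LABELS = 3

def linear(n_in, n_out):
    """Weights plus bias of a dense layer."""
    return n_in * n_out + n_out

layer_norm = 2 * HIDDEN  # gain + bias

embeddings = (VOCAB + MAX_POS + TYPE_VOCAB) * HIDDEN + layer_norm
per_layer = (
    3 * linear(HIDDEN, HIDDEN)   # query, key, value projections
    + linear(HIDDEN, HIDDEN)     # attention output projection
    + layer_norm
    + linear(HIDDEN, FFN)        # feed-forward up-projection
    + linear(FFN, HIDDEN)        # feed-forward down-projection
    + layer_norm
)
encoder = embeddings + LAYERS * per_layer
# Classification head: dense (hidden -> hidden) + out_proj (hidden -> 3); no pooler
head = linear(HIDDEN, HIDDEN) + linear(HIDDEN, NUM_LABELS)

total = encoder + head
print(f"{encoder:,} + {head:,} = {total:,}")  # 277,453,056 + 592,899 = 278,045,955
```

Almost 70% of the parameters (192M) sit in the token embedding matrix, which is why full fine-tuning of XLM-R is much heavier than the 3-class task alone suggests.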
Starting Full Fine-tuning...


| Epoch | Training Loss | Validation Loss | Accuracy | F1 |
|---|---|---|---|---|
| 1 | 0.703456 | 0.617916 | 0.72607 | 0.71532 |
| 2 | 0.543852 | 0.538215 | 0.780124 | 0.775272 |
| 3 | 0.378794 | 0.488664 | 0.821791 | 0.819956 |
| 4 | 0.246446 | 0.513288 | 0.833615 | 0.835582 |
| 5 | 0.182091 | 0.502095 | 0.869369 | 0.869032 |
| 6 | 0.135582 | 0.552878 | 0.870777 | 0.870431 |
| 7 | 0.111891 | 0.670187 | 0.875563 | 0.876398 |
| 8 | 0.075669 | 0.763386 | 0.880068 | 0.879935 |
| 9 | 0.074139 | 0.794933 | 0.881757 | 0.881538 |
| 10 | 0.054466 | 0.891875 | 0.885417 | 0.885307 |
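In `EarlyStoppingCallback`, `early_stopping_threshold` is an absolute difference in the monitored metric (here weighted F1), not a percentage, and training stops once the count of evaluations without a large-enough gain reaches `early_stopping_patience`. A simplified replay of that counter over the logged F1 values, assuming the best score is updated on any improvement (a sketch, not the library's code):

```python
def patience_counters(scores, threshold=0.01):
    """Replay a patience counter over per-epoch scores (higher is better).

    The counter resets when a score beats the best so far by more than
    `threshold` (absolute); otherwise it increments. The best score is
    updated on any improvement, however small.
    """
    best, counter, history = None, 0, []
    for score in scores:
        if best is None or (score > best and score - best > threshold):
            counter = 0
        else:
            counter += 1
        if best is None or score > best:
            best = score
        history.append(counter)
    return history

# Weighted F1 per epoch, copied from the training log
f1_log = [0.71532, 0.775272, 0.819956, 0.835582, 0.869032,
          0.870431, 0.876398, 0.879935, 0.881538, 0.885307]
print(patience_counters(f1_log))  # [0, 0, 0, 0, 0, 1, 2, 3, 4, 5]
```

From epoch 6 onward every gain is below 0.01, so the counter grows even though F1 still rises slightly, while validation loss climbs from 0.50 to 0.89.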


There were missing keys in the checkpoint model loaded: ['roberta.embeddings.LayerNorm.weight', 'roberta.embeddings.LayerNorm.bias', 'roberta.encoder.layer.0.attention.output.LayerNorm.weight', 'roberta.encoder.layer.0.attention.output.LayerNorm.bias', 'roberta.encoder.layer.0.output.LayerNorm.weight', 'roberta.encoder.layer.0.output.LayerNorm.bias', 'roberta.encoder.layer.1.attention.output.LayerNorm.weight', 'roberta.encoder.layer.1.attention.output.LayerNorm.bias', 'roberta.encoder.layer.1.output.LayerNorm.weight', 'roberta.encoder.layer.1.output.LayerNorm.bias', 'roberta.encoder.layer.2.attention.output.LayerNorm.weight', 'roberta.encoder.layer.2.attention.output.LayerNorm.bias', 'roberta.encoder.layer.2.output.LayerNorm.weight', 'roberta.encoder.layer.2.output.LayerNorm.bias', 'roberta.encoder.layer.3.attention.output.LayerNorm.weight', 'roberta.encoder.layer.3.attention.output.LayerNorm.bias', 'roberta.encoder.layer.3.output.LayerNorm.weight', 'roberta.encoder.layer.3.output.Laye


--- FINAL EVALUATION ON ORIGINAL 'VALIDATION' (TEST) SET ---


              precision    recall  f1-score   support

       en-UK       0.83      0.88      0.85       776
       en-AU       0.88      0.83      0.86       742
       en-IN       0.93      0.92      0.93       910

    accuracy                           0.88      2428
   macro avg       0.88      0.88      0.88      2428
weighted avg       0.88      0.88      0.88      2428
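The `macro avg` F1 is the unweighted mean of the per-class scores: (0.85 + 0.86 + 0.93) / 3 = 0.88. To see which variety pairs drive the lower en-UK and en-AU scores (the confusion-matrix plot and hard-confusion analysis planned above), a dependency-free sketch could work from `y_true` and `preds` of the test cell; the helper names are hypothetical and the labels below are toy values.

```python
VARIETIES = ["en-UK", "en-AU", "en-IN"]  # index order of LABEL_MAP

def confusion_matrix(y_true, y_pred, n=len(VARIETIES)):
    """cm[t][p] counts examples with true class t predicted as class p."""
    cm = [[0] * n for _ in range(n)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm

def per_class_f1(cm):
    """F1 per class from a square confusion matrix (0.0 where undefined)."""
    scores = []
    for k in range(len(cm)):
        tp = cm[k][k]
        fp = sum(row[k] for row in cm) - tp
        fn = sum(cm[k]) - tp
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        scores.append(2 * precision * recall / (precision + recall) if precision + recall else 0.0)
    return scores

def top_confusions(cm, k=3):
    """Largest off-diagonal cells as ((true_variety, predicted_variety), count)."""
    pairs = [((VARIETIES[t], VARIETIES[p]), cm[t][p])
             for t in range(len(cm)) for p in range(len(cm)) if t != p]
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)[:k]

# Toy labels; in the notebook these would be `y_true` and `preds`
y_true = [0, 0, 0, 1, 1, 1, 2, 2]
y_pred = [0, 1, 0, 1, 0, 1, 2, 2]
cm = confusion_matrix(y_true, y_pred)
f1s = per_class_f1(cm)
print(cm)                       # [[2, 1, 0], [1, 2, 0], [0, 0, 2]]
print(sum(f1s) / len(f1s))      # macro F1, about 0.778
print(top_confusions(cm, k=2))  # [(('en-UK', 'en-AU'), 1), (('en-AU', 'en-UK'), 1)]
```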


