# Comparison of BERT, RoBERTa, and DeBERTa on Zindi Tweets Data

This notebook provides a clean, modular comparison framework for fine-tuning three state-of-the-art
transformer models on a tweets classification task from Zindi.

**Models compared:**
- **BERT** (`bert-base-uncased`): Bidirectional Encoder Representations from Transformers
- **RoBERTa** (`roberta-base`): Robustly Optimized BERT Approach
- **DeBERTa** (`microsoft/deberta-v3-base`): Decoding-enhanced BERT with Disentangled Attention

**Sections:**
1. Setup and Imports
2. Hyperparameter Configuration
3. Data Loading and Exploration
4. Text Preprocessing
5. Dataset and DataLoader
6. Training Pipeline
7. Train All Models
8. Evaluation and Metrics
9. Comparison and Visualization
10. Save Results


## 1. Setup and Imports

In [None]:
# ── Install required libraries (uncomment in Colab / fresh environments)
# !pip install transformers datasets evaluate accelerate


In [None]:
import os
import re
import time
import json
import copy
import random
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns
from tqdm.auto import tqdm

import torch
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score,
    recall_score, classification_report, confusion_matrix,
)

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup,
)
import evaluate

sns.set_theme(style='whitegrid')
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s',
                    datefmt='%H:%M:%S')
logger = logging.getLogger(__name__)

print('PyTorch version:', torch.__version__)


In [None]:
# ── Device configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)
if DEVICE.type == 'cuda':
    print('GPU:', torch.cuda.get_device_name(0))
    print('Memory (GB):', round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1))


In [None]:
# ── Reproducibility
SEED = 42

def set_seed(seed: int = 42) -> None:
    """Fix random seeds for full reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(SEED)
print(f'Random seed set to {SEED}')


## 2. Hyperparameter Configuration

Centralise all tunable parameters in one place for easy experimentation.


In [None]:
CONFIG = {
    # Data
    'data_path': 'Train.csv',       # Path to your Zindi tweets CSV
    'text_col': 'tweet',            # Column containing tweet text
    'label_col': 'label',           # Column containing class labels
    'test_size': 0.15,              # Fraction for test split
    'val_size': 0.15,               # Fraction of training data for validation

    # Tokenisation
    'max_length': 128,              # Max tokens (tweets are usually short)

    # Training
    'batch_size': 16,
    'num_epochs': 3,
    'learning_rate': 2e-5,
    'warmup_ratio': 0.1,            # Fraction of steps for linear warmup
    'weight_decay': 0.01,
    'gradient_clip': 1.0,           # Max gradient norm
    'fp16': torch.cuda.is_available(),  # Mixed precision (GPU only)

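    # Early stopping: stop after `patience` epochs without a val-F1 improvement
    # (with num_epochs=3 and patience=3 it can never trigger; lower one or raise the other)
    'patience': 3,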

    # Output
    'checkpoint_dir': 'checkpoints',
    'results_dir': 'results',
}

# ── Model registry
MODEL_REGISTRY = {
    'BERT':    'bert-base-uncased',
    'RoBERTa': 'roberta-base',
    'DeBERTa': 'microsoft/deberta-v3-base',
}

Path(CONFIG['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
Path(CONFIG['results_dir']).mkdir(parents=True, exist_ok=True)
print('Configuration loaded.')
print(CONFIG)


## 3. Data Loading and Exploration

In [None]:
# ── Load Zindi tweets dataset
data_path = Path(CONFIG['data_path'])

if not data_path.exists():
    print(f'WARNING: {data_path} not found.')
    print('Please set CONFIG["data_path"] to the correct path of your Zindi CSV.')
    # Create a small synthetic dataset for demonstration purposes
    print('\nCreating a synthetic demo dataset...')
    demo_texts = [
        'Covid-19 vaccines are saving lives! #health',
        'Stay home stay safe #COVID19',
        'This pandemic is a hoax created by the government',
        'Please wear your mask in public spaces',
        'RT @WHO: Wash your hands frequently https://t.co/abc123',
        'I got my vaccine today feeling great! 💉',
        'Conspiracy theorists spreading misinformation again @username',
        'New variant detected in multiple countries #omicron',
    ] * 50
    demo_labels = ([0, 0, 1, 0, 0, 0, 1, 0] * 50)
    df = pd.DataFrame({'tweet': demo_texts, 'label': demo_labels})
    print(f'Demo dataset created: {df.shape}')
else:
    df = pd.read_csv(data_path)
    print(f'Loaded: {data_path}  |  Shape: {df.shape}')

# Standardise column names
df.columns = [c.strip().lower().replace(' ', '_') for c in df.columns]

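# Use the configured column names if present (they must be lowercase after the
# standardisation above); otherwise fall back to the first / last column.
# Check the printed choice: the first column may be an ID column rather than the text.
TEXT_COL = CONFIG['text_col'] if CONFIG['text_col'] in df.columns else df.columns[0]
LABEL_COL = CONFIG['label_col'] if CONFIG['label_col'] in df.columns else df.columns[-1]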
print(f'Text column: "{TEXT_COL}" | Label column: "{LABEL_COL}"')


In [None]:
# ── Exploratory Data Analysis
print('=== Dataset overview ===')
display(df.head(5))
print('Shape:', df.shape)
print('\nMissing values:')
print(df[[TEXT_COL, LABEL_COL]].isnull().sum())

# Drop rows with missing text or label
n_before = len(df)
df = df.dropna(subset=[TEXT_COL, LABEL_COL]).copy()
df[TEXT_COL] = df[TEXT_COL].astype(str).str.strip()
print(f'\nDropped {n_before - len(df)} rows with missing values. Remaining: {len(df)}')


In [None]:
# ── Class distribution
label_counts = df[LABEL_COL].value_counts().sort_index()
NUM_CLASSES = df[LABEL_COL].nunique()
print(f'Number of classes: {NUM_CLASSES}')
print(label_counts)

fig, ax = plt.subplots(figsize=(7, 4))
label_counts.plot(kind='bar', ax=ax, color='steelblue', edgecolor='white')
ax.set_title('Class Distribution')
ax.set_xlabel('Label')
ax.set_ylabel('Count')
plt.xticks(rotation=0)
plt.tight_layout()
plt.savefig(f"{CONFIG['results_dir']}/class_distribution.png", dpi=120)
plt.show()


In [None]:
# ── Tweet length analysis
df['tweet_len'] = df[TEXT_COL].str.split().str.len()
print('Tweet word-count statistics:')
print(df['tweet_len'].describe().round(1))

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(df['tweet_len'], bins=40, color='coral', edgecolor='white')
ax.axvline(df['tweet_len'].median(), color='navy', linestyle='--', label=f'Median: {df["tweet_len"].median():.0f}')
ax.set_title('Tweet Length Distribution (words)')
ax.set_xlabel('Word count')
ax.set_ylabel('Frequency')
ax.legend()
plt.tight_layout()
plt.savefig(f"{CONFIG['results_dir']}/tweet_length.png", dpi=120)
plt.show()


In [None]:
# ── Sample tweets per class
for label in sorted(df[LABEL_COL].unique()):
    print(f'\n--- Class {label} (sample) ---')
    samples = df[df[LABEL_COL] == label][TEXT_COL].head(3).tolist()
    for s in samples:
        print(' •', s)


## 4. Text Preprocessing

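Twitter-specific cleaning pipeline that:
- Strips a leading retweet marker (`RT`)
- Replaces URLs with `[URL]`
- Replaces `@mentions` with `[USER]`
- Normalises hashtags (keeps the word, removes `#`)
- Collapses runs of repeated punctuation (`!!!` becomes `!`)
- Leaves emojis in place (BERT's WordPiece vocabulary maps many of them to `[UNK]`, whereas RoBERTa's byte-level BPE can encode any emoji)
- Strips extra whitespace

`[URL]` and `[USER]` are plain-text placeholders, not registered special tokens, so each tokenizer splits them into ordinary subword pieces.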


In [None]:
def clean_tweet(text: str) -> str:
    """Apply Twitter-specific text cleaning."""
    # Retweet marker
    text = re.sub(r'^RT\s+', '', text, flags=re.IGNORECASE)
    # URLs
    text = re.sub(r'https?://\S+|www\.\S+', '[URL]', text)
    # Mentions
    text = re.sub(r'@\w+', '[USER]', text)
    # Hashtags: keep the word without the '#'
    text = re.sub(r'#(\w+)', r'\1', text)
    # Collapse runs of 2+ punctuation marks to a single mark (e.g. '!!!' -> '!').
    # A mixed run keeps its last character ('?!' -> '!', '...' -> '.').
    text = re.sub(r'([!?.]){2,}', r'\1', text)
    # Extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# Apply cleaning
df['cleaned_text'] = df[TEXT_COL].apply(clean_tweet)

# Show before/after examples
print('Cleaning examples:')
for _, row in df.head(3).iterrows():
    print(f'  BEFORE: {row[TEXT_COL]}')
    print(f'  AFTER : {row["cleaned_text"]}')
    print()
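
The `([!?.]){2,}` pattern in `clean_tweet` collapses any run of two or more marks to the *last* mark captured, so a mixed run such as `?!` becomes `!` and an ellipsis becomes a full stop. If that loses a signal you care about, a stricter variant using a backreference only collapses runs of the *same* character. A minimal sketch comparing the two; `collapse_any_run` and `collapse_same_char` are hypothetical helper names, not part of the notebook:

```python
import re

def collapse_any_run(text: str) -> str:
    # Pattern used in clean_tweet: any run of 2+ marks -> the last mark captured
    return re.sub(r'([!?.]){2,}', r'\1', text)

def collapse_same_char(text: str) -> str:
    # Stricter variant: only runs of the *same* mark collapse ('!!!' -> '!');
    # mixed runs such as '?!' are left untouched
    return re.sub(r'([!?.])\1+', r'\1', text)

punct_example = 'Wow!!! Really?! Wait...'
print(collapse_any_run(punct_example))    # Wow! Really! Wait.
print(collapse_same_char(punct_example))  # Wow! Really?! Wait.
```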


In [None]:
# ── Encode labels as integers
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
df['encoded_label'] = le.fit_transform(df[LABEL_COL])
label_map = dict(enumerate(le.classes_))
print('Label mapping:', label_map)
NUM_CLASSES = len(label_map)
print('NUM_CLASSES:', NUM_CLASSES)
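
`LabelEncoder` assigns integer codes in sorted order of the unique labels, and `inverse_transform` (used later when exporting predictions) maps codes back. A stdlib sketch of the same round-trip on toy string labels; `encode_labels` is a hypothetical helper, not part of scikit-learn, and the variable names avoid clashing with the notebook's `label_map`:

```python
from typing import Dict, List, Tuple

def encode_labels(labels: List[str]) -> Tuple[List[int], Dict[int, str]]:
    # Sorted unique classes -> 0..n-1, mirroring LabelEncoder's ordering
    classes = sorted(set(labels))
    to_code = {c: i for i, c in enumerate(classes)}
    codes = [to_code[lab] for lab in labels]
    code_to_label = dict(enumerate(classes))  # same shape as the notebook's label_map
    return codes, code_to_label

toy_codes, toy_map = encode_labels(['neutral', 'anti', 'pro', 'anti'])
print(toy_map)                        # {0: 'anti', 1: 'neutral', 2: 'pro'}
print(toy_codes)                      # [1, 0, 2, 0]
print([toy_map[c] for c in toy_codes])  # inverse transform: ['neutral', 'anti', 'pro', 'anti']
```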


In [None]:
# ── Train / Validation / Test split
train_val_df, test_df = train_test_split(
    df, test_size=CONFIG['test_size'],
    stratify=df['encoded_label'], random_state=SEED,
)
train_df, val_df = train_test_split(
    train_val_df, test_size=CONFIG['val_size'],
    stratify=train_val_df['encoded_label'], random_state=SEED,
)

print(f'Train: {len(train_df):,}  |  Val: {len(val_df):,}  |  Test: {len(test_df):,}')

# Save splits for reproducibility
train_df.to_csv(f"{CONFIG['results_dir']}/train_split.csv", index=False)
val_df.to_csv(f"{CONFIG['results_dir']}/val_split.csv", index=False)
test_df.to_csv(f"{CONFIG['results_dir']}/test_split.csv", index=False)
print('Splits saved.')
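
Because `val_size` is applied to the remaining 85% rather than to the full dataset, the effective proportions are roughly 72.25% train / 12.75% validation / 15% test. The sketch below reproduces the counts, assuming scikit-learn's rule of rounding the test count up (`ceil`) and giving the remainder to the training side; `split_sizes` is a hypothetical helper:

```python
import math

def split_sizes(n: int, test_size: float, val_size: float):
    # Outer split: test share rounded up, remainder goes to train+val
    n_test = math.ceil(test_size * n)
    n_train_val = n - n_test
    # Inner split: val_size is a fraction of train+val, not of n
    n_val = math.ceil(val_size * n_train_val)
    n_train = n_train_val - n_val
    return n_train, n_val, n_test

# The 400-row synthetic demo dataset with CONFIG's 0.15 / 0.15
n_tr, n_va, n_te = split_sizes(400, 0.15, 0.15)
print(n_tr, n_va, n_te)  # 289 51 60
print([round(x / 400, 4) for x in (n_tr, n_va, n_te)])  # [0.7225, 0.1275, 0.15]
```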


## 5. Dataset and DataLoader

A single `TweetDataset` class works for all three tokenizers.


In [None]:
class TweetDataset(Dataset):
    """
    PyTorch Dataset for tweet classification.

    Tokenises on-the-fly so the same dataset class works with
    BERT, RoBERTa, and DeBERTa tokenizers.
    """

    def __init__(self, texts: List[str], labels: List[int],
                 tokenizer, max_length: int = 128) -> None:
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        encoding = self.tokenizer(
            self.texts[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        item = {k: v.squeeze(0) for k, v in encoding.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item


def build_dataloaders(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame,
    tokenizer,
    text_col: str = 'cleaned_text',
    label_col: str = 'encoded_label',
    batch_size: int = 16,
    max_length: int = 128,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build train / val / test DataLoaders for a given tokenizer."""
    train_ds = TweetDataset(train_df[text_col].tolist(), train_df[label_col].tolist(),
                            tokenizer, max_length)
    val_ds   = TweetDataset(val_df[text_col].tolist(),   val_df[label_col].tolist(),
                            tokenizer, max_length)
    test_ds  = TweetDataset(test_df[text_col].tolist(),  test_df[label_col].tolist(),
                            tokenizer, max_length)

    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  pin_memory=pin)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, pin_memory=pin)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, pin_memory=pin)
    return train_loader, val_loader, test_loader
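
`TweetDataset` pads every example to `max_length=128`, although most tweets need far fewer tokens. A common alternative is *dynamic padding*: tokenise without padding and pad each batch only to its longest sequence in a custom `collate_fn` (transformers also provides `DataCollatorWithPadding` for this). A minimal stdlib sketch of the padding step, assuming lists of token ids and a pad id of 0 (the real pad id differs per tokenizer; RoBERTa's is 1); `pad_batch` is hypothetical:

```python
from typing import Dict, List

def pad_batch(seqs: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    # Pad to the longest sequence in *this* batch, not to a global max_length
    longest = max(len(s) for s in seqs)
    input_ids = [s + [pad_id] * (longest - len(s)) for s in seqs]
    # 1 marks real tokens, 0 marks padding (attention skips padded positions)
    attention_mask = [[1] * len(s) + [0] * (longest - len(s)) for s in seqs]
    return {'input_ids': input_ids, 'attention_mask': attention_mask}

batch_demo = pad_batch([[101, 7592, 102], [101, 102]])
print(batch_demo['input_ids'])       # [[101, 7592, 102], [101, 102, 0]]
print(batch_demo['attention_mask'])  # [[1, 1, 1], [1, 1, 0]]
```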


## 6. Training Pipeline

A **unified** `train_transformer` function that works for all three transformers with:
- AdamW optimiser, with weight decay on all parameters except biases and LayerNorm weights
- Linear warmup + linear decay learning-rate schedule
- Gradient clipping to prevent exploding gradients
- Mixed-precision (FP16) training when a GPU is available (disabled for DeBERTa-v3, whose learning rate is also capped at 1e-5)
- Early stopping based on validation weighted F1
- Best-model checkpointing

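The warmup-then-decay schedule scales the base learning rate by a per-step multiplier: a linear ramp from 0 to 1 over the warmup steps, then a linear fall to 0 at the last step. The sketch below is assumed to match the multiplier `get_linear_schedule_with_warmup` applies; treat it as an illustration rather than the library source. The step counts assume the 289-example demo train split, `batch_size=16`, 3 epochs and `warmup_ratio=0.1`:

```python
import math

def linear_warmup_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    # Warmup: ramp linearly from 0 to 1
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # Decay: fall linearly from 1 to 0 at total_steps
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

base_lr = 2e-5                          # CONFIG['learning_rate']
steps_per_epoch = math.ceil(289 / 16)   # 19 batches for the demo train split
total = steps_per_epoch * 3             # 57 optimizer steps
warmup = int(total * 0.1)               # 5 warmup steps
for s in (0, 2, warmup, 31, total):
    # multipliers: 0.0, 0.4, 1.0, 0.5, 0.0
    print(s, base_lr * linear_warmup_multiplier(s, warmup, total))
```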

In [None]:
def compute_metrics_from_preds(preds: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
    """Compute accuracy, F1 (macro & weighted), precision and recall."""
    return {
        'accuracy':           accuracy_score(labels, preds),
        'f1_macro':           f1_score(labels, preds, average='macro',    zero_division=0),
        'f1_weighted':        f1_score(labels, preds, average='weighted', zero_division=0),
        'precision_macro':    precision_score(labels, preds, average='macro',    zero_division=0),
        'precision_weighted': precision_score(labels, preds, average='weighted', zero_division=0),
        'recall_macro':       recall_score(labels, preds, average='macro',    zero_division=0),
        'recall_weighted':    recall_score(labels, preds, average='weighted', zero_division=0),
    }


def evaluate_loader(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> Tuple[float, Dict[str, float], np.ndarray, np.ndarray]:
    """
    Run model on *loader* and return:
      (avg_loss, metrics_dict, all_preds, all_labels)
    """
    model.eval()
    total_loss = 0.0
    all_preds: List[int] = []
    all_labels: List[int] = []

    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            total_loss += outputs.loss.item()
            preds = outputs.logits.argmax(dim=-1).cpu().numpy()
            all_preds.extend(preds.tolist())
            all_labels.extend(batch['labels'].cpu().numpy().tolist())

    avg_loss = total_loss / len(loader)
    metrics = compute_metrics_from_preds(np.array(all_preds), np.array(all_labels))
    return avg_loss, metrics, np.array(all_preds), np.array(all_labels)
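
Checkpoints are selected on *weighted* F1, which averages per-class F1 weighted by class support, while *macro* F1 gives every class equal weight; on imbalanced tweet data the two can diverge noticeably. A stdlib sketch on a hypothetical 8-vs-2 toy example, following the same `zero_division=0` convention as `compute_metrics_from_preds`:

```python
from typing import Dict, List

def per_class_f1(labels: List[int], preds: List[int], cls: int) -> float:
    tp = sum(1 for y, p in zip(labels, preds) if y == cls and p == cls)
    fp = sum(1 for y, p in zip(labels, preds) if y != cls and p == cls)
    fn = sum(1 for y, p in zip(labels, preds) if y == cls and p != cls)
    # zero_division=0: undefined precision/recall/F1 count as 0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

def macro_and_weighted_f1(labels: List[int], preds: List[int]) -> Dict[str, float]:
    classes = sorted(set(labels) | set(preds))
    f1s = {c: per_class_f1(labels, preds, c) for c in classes}
    support = {c: labels.count(c) for c in classes}  # true examples per class
    macro = sum(f1s.values()) / len(classes)
    weighted = sum(f1s[c] * support[c] for c in classes) / len(labels)
    return {'f1_macro': macro, 'f1_weighted': weighted}

toy_labels = [0] * 8 + [1] * 2
toy_preds = [0] * 9 + [1]  # one minority example missed
print(macro_and_weighted_f1(toy_labels, toy_preds))
# class 0: F1 = 16/17, class 1: F1 = 2/3 -> macro ~0.804, weighted ~0.886
```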


In [None]:
def train_transformer(
    model_name: str,
    model_display: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_classes: int,
    config: dict,
    device: torch.device,
) -> Dict:
    """
    Unified fine-tuning function for BERT, RoBERTa, and DeBERTa.

    Parameters
    ----------
    model_name     : HuggingFace model identifier
    model_display  : Human-readable name for logging
    train_loader   : Training DataLoader (already tokenised for this model)
    val_loader     : Validation DataLoader
    num_classes    : Number of output classes
    config         : CONFIG dict
    device         : torch.device

    Returns
    -------
    dict with training history and the best model state dict.
    """
    logger.info('\n%s\n  Training: %s  (%s)\n%s',
                '='*60, model_display, model_name, '='*60)

    # ── Load model
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_classes, ignore_mismatched_sizes=True,
    )
    model.to(device)

    # ── Model-specific settings
    # DeBERTa-v3: disable fp16 (fp16 instability has been reported for this model)
    # and cap the learning rate at 1e-5 (a smaller LR for DeBERTa-v3; cf. He et al., 2021).
    # Both are decided *before* the optimizer is built: torch optimizers write the
    # default 'lr' into each param-group dict, so re-creating AdamW from the same
    # groups with a new lr= would silently keep the old rate.
    use_fp16 = config['fp16'] and device.type == 'cuda'
    lr = config['learning_rate']
    if 'deberta-v3' in model_name.lower():
        use_fp16 = False
        if lr > 1e-5:
            lr = 1e-5
            logger.info('[%s] Using DeBERTa-v3 learning rate: %.0e', model_display, lr)

    # ── Optimiser: AdamW with weight decay on non-bias/norm params
    no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
    param_groups = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': config['weight_decay']},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = torch.optim.AdamW(param_groups, lr=lr)

    # ── Learning-rate schedule: linear warmup + linear decay
    total_steps = len(train_loader) * config['num_epochs']
    warmup_steps = int(total_steps * config['warmup_ratio'])
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps,
    )

    # ── Mixed precision (GPU only; disabled above for DeBERTa-v3)
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

    # ── Training state
    history = {
        'train_loss': [], 'val_loss': [],
        'train_f1': [],   'val_f1': [],
        'train_acc': [],  'val_acc': [],
        'epoch_times': [],
    }
    best_val_f1 = -1.0
    best_state = None
    patience_counter = 0
    ckpt_path = Path(config['checkpoint_dir']) / f"{model_display.replace('/', '_')}_best.pt"

    train_start = time.time()

    for epoch in range(config['num_epochs']):
        epoch_start = time.time()
        model.train()
        total_loss = 0.0
        all_preds, all_labels = [], []

        for batch in tqdm(train_loader,
                          desc=f'[{model_display}] Epoch {epoch+1}/{config["num_epochs"]}',
                          leave=False):
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()

            if use_fp16:
                with torch.cuda.amp.autocast():
                    outputs = model(**batch)
                    loss = outputs.loss
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip'])
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(**batch)
                loss = outputs.loss
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip'])
                optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            preds = outputs.logits.detach().argmax(dim=-1).cpu().numpy()
            all_preds.extend(preds.tolist())
            all_labels.extend(batch['labels'].cpu().numpy().tolist())

        train_loss = total_loss / len(train_loader)
        train_metrics = compute_metrics_from_preds(np.array(all_preds), np.array(all_labels))
        val_loss, val_metrics, _, _ = evaluate_loader(model, val_loader, device)
        epoch_time = time.time() - epoch_start

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_f1'].append(train_metrics['f1_weighted'])
        history['val_f1'].append(val_metrics['f1_weighted'])
        history['train_acc'].append(train_metrics['accuracy'])
        history['val_acc'].append(val_metrics['accuracy'])
        history['epoch_times'].append(epoch_time)

        logger.info(
            '[%s] Epoch %d/%d  train_loss=%.4f  val_loss=%.4f  '
            'train_f1=%.4f  val_f1=%.4f  val_acc=%.4f  time=%.1fs',
            model_display, epoch + 1, config['num_epochs'],
            train_loss, val_loss,
            train_metrics['f1_weighted'], val_metrics['f1_weighted'],
            val_metrics['accuracy'], epoch_time,
        )

        # ── Checkpoint best model (weights copied to CPU so the best states of
        #    all three models do not accumulate in GPU memory)
        if val_metrics['f1_weighted'] > best_val_f1:
            best_val_f1 = val_metrics['f1_weighted']
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, ckpt_path)
            patience_counter = 0
            logger.info('  ✓ New best val_f1=%.4f, checkpoint saved to %s', best_val_f1, ckpt_path)
        else:
            patience_counter += 1
            if patience_counter >= config['patience']:
                logger.info('  Early stopping triggered at epoch %d.', epoch + 1)
                break

    total_time = time.time() - train_start
    history['total_time'] = total_time
    history['best_val_f1'] = best_val_f1
    history['model_display'] = model_display
    history['best_state'] = best_state

    logger.info('[%s] Training complete in %.1fs  best_val_f1=%.4f',
                model_display, total_time, best_val_f1)
    return history
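
The early-stopping rule stops once `patience` consecutive epochs fail to beat the best validation F1. With the default `num_epochs=3` and `patience=3` it can never fire, because at most two epochs can follow the first best. A stdlib simulation of the same counter logic over toy per-epoch validation scores makes this easy to check before changing either setting; `simulate_early_stopping` is a hypothetical helper:

```python
from typing import List, Tuple

def simulate_early_stopping(val_f1s: List[float], patience: int) -> Tuple[int, int, float]:
    """Return (epochs_run, best_epoch, best_f1), epochs counted from 1."""
    best_f1, best_epoch, counter = -1.0, 0, 0
    for epoch, f1 in enumerate(val_f1s, start=1):
        if f1 > best_f1:  # strict improvement resets the counter
            best_f1, best_epoch, counter = f1, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch, best_f1
    return len(val_f1s), best_epoch, best_f1

# 3 epochs, patience 3: every epoch runs even though F1 peaks at epoch 1
print(simulate_early_stopping([0.80, 0.79, 0.78], patience=3))  # (3, 1, 0.8)
# 6 epochs, patience 2: stops after two non-improving epochs
print(simulate_early_stopping([0.70, 0.75, 0.74, 0.73, 0.76, 0.77], patience=2))  # (4, 2, 0.75)
```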


## 7. Train BERT, RoBERTa, and DeBERTa

Each model uses its own tokenizer, but the same training function.


In [None]:
all_histories: Dict[str, dict] = {}
all_models: Dict[str, torch.nn.Module] = {}
all_tokenizers: Dict[str, object] = {}
all_test_loaders: Dict[str, DataLoader] = {}

for display_name, hf_name in MODEL_REGISTRY.items():
    print(f'\n{"="*60}')
    print(f'  Loading tokenizer: {display_name} ({hf_name})')
    print(f'{"="*60}')

    tokenizer = AutoTokenizer.from_pretrained(hf_name)
    train_loader, val_loader, test_loader = build_dataloaders(
        train_df, val_df, test_df, tokenizer,
        batch_size=CONFIG['batch_size'],
        max_length=CONFIG['max_length'],
    )

    history = train_transformer(
        model_name=hf_name,
        model_display=display_name,
        train_loader=train_loader,
        val_loader=val_loader,
        num_classes=NUM_CLASSES,
        config=CONFIG,
        device=DEVICE,
    )

    # Reload best model weights
    best_model = AutoModelForSequenceClassification.from_pretrained(
        hf_name, num_labels=NUM_CLASSES, ignore_mismatched_sizes=True,
    )
    best_model.load_state_dict(history['best_state'])
    best_model.to(DEVICE)

    all_histories[display_name] = history
    all_models[display_name] = best_model
    all_tokenizers[display_name] = tokenizer
    all_test_loaders[display_name] = test_loader

    # Free GPU memory between models
    torch.cuda.empty_cache()

print('\nAll models trained successfully!')


## 8. Evaluation and Metrics

Evaluate each model on the held-out test set and collect:
- Accuracy, F1 (macro & weighted), Precision, Recall
- Confusion matrix
- Model parameter count
- Average inference time per sample (ms)


In [None]:
def get_inference_time_ms(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    n_warmup: int = 5,
) -> float:
    """Average per-sample inference time in milliseconds."""
    model.eval()
    times = []
    with torch.no_grad():
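        for i, batch in enumerate(loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            # CUDA kernels run asynchronously: synchronise before reading the clock,
            # otherwise mostly kernel-launch time is measured.
            if device.type == 'cuda':
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            _ = model(**batch)
            if device.type == 'cuda':
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            if i >= n_warmup:
                times.append((t1 - t0) * 1000 / batch['input_ids'].size(0))
            if len(times) >= 50:  # cap at 50 timed batches for speed
                break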
    return float(np.mean(times)) if times else 0.0
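
`get_inference_time_ms` averages per-batch *ratios* (ms divided by batch size), so a small final batch counts as much as a full one. Dividing total timed milliseconds by total timed samples is an alternative that weights every sample equally. A sketch over precomputed `(elapsed_ms, batch_size)` pairs with toy timings; both helper names are hypothetical, and the warmup skip mirrors the notebook's `n_warmup`:

```python
from typing import List, Tuple

def mean_of_ratios(timings: List[Tuple[float, int]], n_warmup: int) -> float:
    # What get_inference_time_ms computes: mean of per-batch ms/sample
    ratios = [ms / bs for ms, bs in timings[n_warmup:]]
    return sum(ratios) / len(ratios) if ratios else 0.0

def pooled_ms_per_sample(timings: List[Tuple[float, int]], n_warmup: int) -> float:
    # Alternative: total timed ms / total timed samples
    timed = timings[n_warmup:]
    samples = sum(bs for _, bs in timed)
    return sum(ms for ms, _ in timed) / samples if samples else 0.0

# Two full batches of 16 and a final batch of 4 (toy timings, no warmup skipped)
toy_timings = [(32.0, 16), (32.0, 16), (20.0, 4)]
print(mean_of_ratios(toy_timings, n_warmup=0))        # 3.0
print(pooled_ms_per_sample(toy_timings, n_warmup=0))  # ~2.33
```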


In [None]:
test_results: Dict[str, dict] = {}

for display_name, model in all_models.items():
    print(f'\nEvaluating {display_name}...')
    test_loader = all_test_loaders[display_name]

    _, metrics, preds, labels = evaluate_loader(model, test_loader, DEVICE)

    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Inference time
    inf_time_ms = get_inference_time_ms(model, test_loader, DEVICE)

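    # Confusion matrix (explicit label list keeps a fixed NUM_CLASSES x NUM_CLASSES shape
    # even if a class is absent from both labels and predictions)
    cm = confusion_matrix(labels, preds, labels=list(range(NUM_CLASSES)))

    # Detailed classification report
    report = classification_report(
        labels, preds,
        labels=list(range(NUM_CLASSES)),
        target_names=[str(label_map[i]) for i in range(NUM_CLASSES)],
        zero_division=0,
    )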

    test_results[display_name] = {
        **metrics,
        'total_params': total_params,
        'trainable_params': trainable_params,
        'inference_ms': inf_time_ms,
        'total_train_time_s': all_histories[display_name]['total_time'],
        'best_val_f1': all_histories[display_name]['best_val_f1'],
        'confusion_matrix': cm,
        'preds': preds,
        'labels': labels,
        'report': report,
    }

    print(f'  Accuracy:    {metrics["accuracy"]:.4f}')
    print(f'  F1 Macro:    {metrics["f1_macro"]:.4f}')
    print(f'  F1 Weighted: {metrics["f1_weighted"]:.4f}')
    print(f'  Params (M):  {total_params/1e6:.1f}')
    print(f'  Infer. (ms): {inf_time_ms:.2f}')
    print('\nClassification report:')
    print(report)


## 9. Comparison and Visualization

In [None]:
# ── Summary DataFrame
summary_rows = []
for name, res in test_results.items():
    summary_rows.append({
        'Model':               name,
        'Accuracy (%)':        round(res['accuracy'] * 100, 2),
        'F1 Macro (%)':        round(res['f1_macro'] * 100, 2),
        'F1 Weighted (%)':     round(res['f1_weighted'] * 100, 2),
        'Precision Macro (%)': round(res['precision_macro'] * 100, 2),
        'Recall Macro (%)':    round(res['recall_macro'] * 100, 2),
        'Params (M)':          round(res['total_params'] / 1e6, 1),
        'Train Time (s)':      round(res['total_train_time_s'], 1),
        'Infer. Time (ms)':    round(res['inference_ms'], 2),
    })

summary_df = pd.DataFrame(summary_rows).sort_values('F1 Weighted (%)', ascending=False)
display(summary_df.reset_index(drop=True))


In [None]:
# ── Training curves for all models
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
colors = {'BERT': 'steelblue', 'RoBERTa': 'coral', 'DeBERTa': 'seagreen'}

for name, hist in all_histories.items():
    epochs = range(1, len(hist['train_loss']) + 1)
    c = colors.get(name, 'gray')
    axes[0].plot(epochs, hist['val_loss'],  label=name, color=c)
    axes[1].plot(epochs, hist['val_f1'],    label=name, color=c)

axes[0].set_title('Validation Loss')
axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Loss')
axes[0].legend()
axes[1].set_title('Validation F1 (weighted)')
axes[1].set_xlabel('Epoch'); axes[1].set_ylabel('F1')
axes[1].legend()
fig.suptitle('Training Curves: All Models', fontsize=14)
plt.tight_layout()
plt.savefig(f"{CONFIG['results_dir']}/training_curves.png", dpi=120)
plt.show()


In [None]:
# ── Comparison bar charts
metrics_to_plot = [
    ('Accuracy (%)',    'Accuracy'),
    ('F1 Macro (%)',    'F1 Macro'),
    ('F1 Weighted (%)', 'F1 Weighted'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
model_names_list = summary_df['Model'].tolist()
bar_colors = [colors.get(n, 'gray') for n in model_names_list]

for ax, (col, title) in zip(axes, metrics_to_plot):
    vals = summary_df.set_index('Model')[col].reindex(model_names_list)
    bars = ax.bar(model_names_list, vals, color=bar_colors, edgecolor='white')
    ax.bar_label(bars, fmt='%.1f', padding=3, fontsize=10)
    ax.set_title(title)
    ax.set_ylabel('%')
    ax.set_ylim(0, min(100, vals.max() * 1.15))
    ax.tick_params(axis='x', labelrotation=15)  # rotate existing labels without resetting tick text

fig.suptitle('Model Comparison: Classification Metrics', fontsize=14)
plt.tight_layout()
plt.savefig(f"{CONFIG['results_dir']}/metric_comparison.png", dpi=120)
plt.show()


In [None]:
# ── Performance vs Efficiency scatter plots
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

for name in model_names_list:
    res = test_results[name]
    c = colors.get(name, 'gray')
    f1 = res['f1_weighted'] * 100
    axes[0].scatter(res['total_train_time_s'], f1, s=120, color=c, label=name, zorder=3)
    axes[0].annotate(name, (res['total_train_time_s'], f1),
                     textcoords='offset points', xytext=(6, 4), fontsize=9)
    axes[1].scatter(res['inference_ms'], f1, s=120, color=c, label=name, zorder=3)
    axes[1].annotate(name, (res['inference_ms'], f1),
                     textcoords='offset points', xytext=(6, 4), fontsize=9)

axes[0].set_xlabel('Training Time (s)'); axes[0].set_ylabel('F1 Weighted (%)')
axes[0].set_title('F1 vs Training Time')
axes[1].set_xlabel('Inference Time (ms/sample)'); axes[1].set_ylabel('F1 Weighted (%)')
axes[1].set_title('F1 vs Inference Time')

for ax in axes:
    ax.grid(True, alpha=0.4)
    ax.legend()

plt.tight_layout()
plt.savefig(f"{CONFIG['results_dir']}/performance_vs_efficiency.png", dpi=120)
plt.show()


In [None]:
# ── Confusion matrices
n_models = len(test_results)
fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 5))
if n_models == 1:
    axes = [axes]

class_label_names = [str(label_map[i]) for i in range(NUM_CLASSES)]

for ax, (name, res) in zip(axes, test_results.items()):
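    cm_norm = res['confusion_matrix'].astype(float)
    row_sums = cm_norm.sum(axis=1, keepdims=True)
    # out= zero-initialises rows whose class never occurs in the test split;
    # without it, np.divide leaves those entries uninitialised
    cm_norm = np.divide(cm_norm, row_sums, out=np.zeros_like(cm_norm), where=row_sums != 0)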
    sns.heatmap(
        cm_norm, annot=True, fmt='.2f', cmap='Blues',
        xticklabels=class_label_names,
        yticklabels=class_label_names,
        ax=ax, cbar=False,
    )
    ax.set_title(f'{name}: Confusion Matrix')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')

plt.tight_layout()
plt.savefig(f"{CONFIG['results_dir']}/confusion_matrices.png", dpi=120)
plt.show()
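
Row-normalising the confusion matrix turns counts into per-class recall: each row is divided by the number of true examples of that class. A class with no test examples has an all-zero row, which should stay zero rather than become undefined. A stdlib sketch of that rule on toy counts; `row_normalise` is a hypothetical helper:

```python
from typing import List

def row_normalise(cm: List[List[int]]) -> List[List[float]]:
    out = []
    for row in cm:
        total = sum(row)
        # Empty rows (no true examples for that class) stay all-zero
        out.append([v / total if total else 0.0 for v in row])
    return out

print(row_normalise([[8, 2, 0], [1, 3, 0], [0, 0, 0]]))
# [[0.8, 0.2, 0.0], [0.25, 0.75, 0.0], [0.0, 0.0, 0.0]]
```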


## 10. Save Results

Save all artefacts for reproducibility and Zindi submission.


In [None]:
# ── Save summary CSV
summary_csv_path = f"{CONFIG['results_dir']}/model_comparison.csv"
summary_df.to_csv(summary_csv_path, index=False)
print(f'Comparison table saved to {summary_csv_path}')
display(summary_df.reset_index(drop=True))


In [None]:
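# ── Export predictions on the held-out test split in a Zindi-style CSV
# NOTE: test_df is a split of the *labelled* training CSV, so this file is for
# inspection only. For a real Zindi submission, run the best model on Zindi's
# unlabelled test file and use the column names from its sample submission.
best_model_name = summary_df.iloc[0]['Model']
print(f'Best model: {best_model_name} (highest test F1 Weighted)')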

best_preds = test_results[best_model_name]['preds']
# Decode integer labels back to original class labels
decoded_preds = le.inverse_transform(best_preds)

submission_df = test_df.reset_index(drop=True).copy()
submission_df['predicted_label'] = decoded_preds

submission_path = f"{CONFIG['results_dir']}/zindi_submission.csv"
# Keep only ID-like column + prediction; adjust 'id' column name to match Zindi format
id_cols = [c for c in submission_df.columns if 'id' in c.lower()]
out_cols = (id_cols if id_cols else []) + ['predicted_label']
submission_df[out_cols].to_csv(submission_path, index=False)
print(f'Zindi submission saved to {submission_path}')
display(submission_df[out_cols].head())
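
For an actual Zindi submission, predictions must come from the competition's unlabelled test file and be written with the column names of its sample submission. A stdlib sketch of the writing step, assuming an `ID` column and a `label` target column (both hypothetical; check the competition page) and decoding integer predictions through a code-to-label map such as the notebook's `label_map`. It writes to an in-memory buffer for illustration; in practice you would open a file instead:

```python
import csv
import io
from typing import Dict, List

def write_submission(ids: List[str], pred_codes: List[int], code_to_label: Dict[int, object],
                     id_col: str = 'ID', target_col: str = 'label') -> str:
    # zip() would silently drop rows on a length mismatch, so check explicitly
    if len(ids) != len(pred_codes):
        raise ValueError(f'{len(ids)} ids but {len(pred_codes)} predictions')
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator='\n')
    writer.writerow([id_col, target_col])
    for row_id, code in zip(ids, pred_codes):
        # Decode integer class codes back to the original labels
        writer.writerow([row_id, code_to_label[code]])
    return buf.getvalue()

print(write_submission(['t1', 't2', 't3'], [0, 1, 0], {0: 'neg', 1: 'pos'}))
```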


In [None]:
# ── Save all predictions per model
all_preds_rows = []
for name, res in test_results.items():
    for i, (pred, true) in enumerate(zip(res['preds'], res['labels'])):
        all_preds_rows.append({'model': name, 'index': i,
                               'true_label': int(true), 'predicted_label': int(pred)})
all_preds_df = pd.DataFrame(all_preds_rows)
all_preds_df.to_csv(f"{CONFIG['results_dir']}/all_model_predictions.csv", index=False)
print(f"All predictions saved to {CONFIG['results_dir']}/all_model_predictions.csv")


In [None]:
# ── Save trained models
for name, model in all_models.items():
    save_dir = Path(CONFIG['checkpoint_dir']) / name.replace('/', '_')
    save_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(save_dir))
    all_tokenizers[name].save_pretrained(str(save_dir))
    print(f'{name} model + tokenizer saved to {save_dir}')


## Summary and Recommendations

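After comparing BERT, RoBERTa, and DeBERTa on the Zindi tweets classification task, the measured results are in the comparison table (`summary_df`) from section 9. The table below summarises the qualitative strengths and trade-offs commonly associated with each model: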

| Model | Strengths | Trade-offs |
|-------|-----------|------------|
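| **BERT** | Widely supported; smallest of the three (~110M parameters) | WordPiece vocabulary maps many emojis and unusual symbols to `[UNK]`, which can hurt on noisy tweet text |
| **RoBERTa** | Byte-level BPE can encode any input string; pre-trained longer on more data without the NSP objective | Slightly larger than BERT (~125M parameters, mainly from its larger vocabulary) |
| **DeBERTa** | Strong NLU benchmark results; disentangled attention, with ELECTRA-style pre-training in v3 | Largest of the three (~184M parameters, much of it in the 128K-token embedding matrix); typically the slowest inference; fp16 disabled in this notebook |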

**General guidance:**
- For **highest accuracy**, prefer **DeBERTa-v3** or **RoBERTa**.
- For **fastest inference**, prefer **BERT** or **RoBERTa**.
- For **resource-constrained** environments, **BERT** is the best baseline.

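The winning model for your specific dataset is the one with the highest **F1 Weighted** score in the comparison table above. Each model is trained once with a single seed, so small gaps between models may be within run-to-run variation.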

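Picking the winner on test F1 and then reporting that same test F1 is slightly optimistic, because the test split was also used for selection. One alternative is to choose on the `best_val_f1` already stored in each training history and only report the test score of that choice. A sketch with toy scores; the plain dictionaries merely mimic values from `all_histories` and `test_results`, and `select_on_validation` is a hypothetical helper:

```python
from typing import Dict, Tuple

def select_on_validation(val_f1: Dict[str, float], test_f1: Dict[str, float]) -> Tuple[str, float]:
    # Choose by validation score; the test score is only read afterwards
    winner = max(val_f1, key=val_f1.get)
    return winner, test_f1[winner]

toy_val = {'BERT': 0.81, 'RoBERTa': 0.84, 'DeBERTa': 0.83}
toy_test = {'BERT': 0.80, 'RoBERTa': 0.82, 'DeBERTa': 0.85}
print(select_on_validation(toy_val, toy_test))  # ('RoBERTa', 0.82)
```

Here validation picks RoBERTa even though DeBERTa has the higher toy test score, which is exactly the situation where selecting on test would overstate the reported result.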