# ModernBERT Radical Optimization - 20 Newsgroups
## A complete rewrite using the Hugging Face Trainer API with ModernBERT-optimized settings.
**Goal:** Break the 71% accuracy ceiling (Target: >83%).

In [None]:
import os
# Disable parallelism to avoid deadlocks in some environments
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    TrainerCallback
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [None]:
# ==================================================================
# 1. Configuration — the "Radical" setup
# ==================================================================

# --- Hub identifiers -----------------------------------------------
MODEL_NAME = "answerdotai/ModernBERT-base"
DATASET_NAME = "SetFit/20_newsgroups"

# --- Sequence length & batching ------------------------------------
# Inputs are truncated to MAX_LENGTH tokens.
MAX_LENGTH = 512
# BATCH_SIZE is per device (sized for a T4). The effective batch size is
# BATCH_SIZE * GRAD_ACCUM * num_gpus, i.e. 16 * 2 * 2 = 64 on two GPUs.
BATCH_SIZE, GRAD_ACCUM = 16, 2

# --- Optimization schedule -----------------------------------------
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 0.01
EPOCHS = 4
WARMUP_RATIO = 0.1  # share of training steps used for LR warm-up

# --- Regularization ------------------------------------------------
# Label smoothing softens the one-hot targets; used here to improve
# generalization.
LABEL_SMOOTHING = 0.1

In [None]:
# ------------------------------------------------------------------
# 2. Compilation Fix Callback
# ------------------------------------------------------------------
class DisableCompiledMLPCallback(TrainerCallback):
    """
    Keep ModernBERT on its eager (non-``torch.compile``) MLP path.

    In the Hugging Face ModernBERT implementation, each encoder layer picks
    between a ``torch.compile``-d ``compiled_mlp`` method and the eager
    ``mlp(mlp_norm(x))`` path based on ``config.reference_compile``. When
    that flag is left unset, compilation is enabled automatically if Triton
    is available. Compilation can conflict with some multi-GPU setups
    (e.g. ``DataParallel``). This callback therefore sets
    ``reference_compile = False`` on every config reachable from the model
    before training starts.

    Why not ``module.compiled_mlp = module.mlp``:
    ``compiled_mlp`` is a method on the layer class, and
    ``nn.Module.__setattr__`` stores an assigned module in ``_modules``
    rather than the instance ``__dict__``. Attribute lookup therefore keeps
    returning the compiled method. The assignment only aliased the MLP
    weights under an extra ``compiled_mlp.*`` state-dict prefix, and
    safetensors serialization may reject such shared tensors. Had it taken
    effect, it would also have skipped the ``mlp_norm`` LayerNorm.
    """

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """
        Disable ``reference_compile`` on the model's config(s).

        Args:
            args, state, control: Standard ``TrainerCallback`` arguments
                (unused here).
            model: The model being trained; ``None`` makes this a no-op.
        """
        if model is None:
            return
        # Unwrap DDP/DataParallel if present.
        model = getattr(model, "module", model)

        patched = 0
        seen_configs = set()
        for module in model.modules():
            config = getattr(module, "config", None)
            # Wrapper and backbone typically share one config object;
            # patch each config only once.
            if config is None or id(config) in seen_configs:
                continue
            seen_configs.add(id(config))
            # A config without the attribute is not a ModernBERT config, so
            # leave it alone. Both None (auto) and True allow compilation,
            # so force an explicit False.
            if getattr(config, "reference_compile", False) is not False:
                config.reference_compile = False
                patched += 1
        if patched > 0:
            print(
                f"[Callback] Set reference_compile=False on {patched} config(s); "
                "ModernBERT will use its eager MLP path."
            )

# ------------------------------------------------------------------
# 3. Data Preparation
# ------------------------------------------------------------------
def label_names_from_text(label_ids, label_texts):
    """
    Build a list of class names indexed by integer label id.

    Args:
        label_ids: Iterable of integer class ids.
        label_texts: Iterable of class names aligned with ``label_ids``.

    Returns:
        A list ``names`` where ``names[i]`` is the first name seen for id
        ``i``. Ids in ``0..max(id)`` that never occur are named ``str(i)``.
        Empty input gives ``[]``.
    """
    id_to_name = {}
    for label_id, name in zip(label_ids, label_texts):
        id_to_name.setdefault(int(label_id), name)
    if not id_to_name:
        return []
    return [id_to_name.get(i, str(i)) for i in range(max(id_to_name) + 1)]


def prepare_data():
    """
    Load the dataset, attach class names to ``label`` and tokenize the text.

    Returns:
        A ``(tokenized_datasets, tokenizer)`` tuple:
        ``tokenized_datasets`` is the ``DatasetDict`` with the raw ``text``
        column removed and torch formatting enabled; ``tokenizer`` is the
        tokenizer that produced it.
    """
    from datasets import ClassLabel

    print("Loading dataset...")
    # Load 20 Newsgroups
    dataset = load_dataset(DATASET_NAME)

    # If `label` is a plain integer column with a parallel `label_text`
    # column, cast it to ClassLabel. The class names then travel with the
    # features, and downstream code can read `features["label"].names`.
    train_split = dataset["train"]
    if (
        not hasattr(train_split.features["label"], "names")
        and "label_text" in train_split.column_names
    ):
        ids, texts = [], []
        for split in dataset.values():
            ids.extend(split["label"])
            texts.extend(split["label_text"])
        names = label_names_from_text(ids, texts)
        if names:
            dataset = dataset.cast_column("label", ClassLabel(names=names))

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tokenize_function(examples):
        # No padding here; batches are padded dynamically by the collator.
        return tokenizer(
            examples["text"],
            truncation=True,
            padding=False,
            max_length=MAX_LENGTH,
        )

    print("Tokenizing...")
    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    # Remove raw text column to save memory
    tokenized_datasets = tokenized_datasets.remove_columns(["text"])

    # The label column is already called `label`, which the collator and
    # Trainer understand; just switch to torch tensors.
    tokenized_datasets.set_format("torch")

    return tokenized_datasets, tokenizer

# ------------------------------------------------------------------
# 4. Metrics
# ------------------------------------------------------------------
def compute_metrics(eval_pred):
    """
    Compute accuracy plus support-weighted precision, recall and F1.

    Args:
        eval_pred: ``(predictions, labels)`` pair from the Trainer.
            ``predictions`` holds per-class scores, or is a tuple whose
            first element holds them when the model returns extra outputs.

    Returns:
        dict with keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    predictions, labels = eval_pred
    # Models that return extra outputs yield a tuple; keep only the scores.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.argmax(predictions, axis=-1)

    # zero_division=0: a class that is never predicted (common early in
    # training) counts as 0 instead of raising UndefinedMetricWarning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average="weighted", zero_division=0
    )
    acc = accuracy_score(labels, predictions)

    return {
        "accuracy": acc,
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }

In [None]:
# ------------------------------------------------------------------
# 5. Main Training Execution
# ------------------------------------------------------------------
# Share of the official training split held out as a validation set.
# Checkpoint selection (load_best_model_at_end) uses this split, so the
# test split stays unseen until the final report.
VALIDATION_FRACTION = 0.1
SPLIT_SEED = 42


def _resolve_label_names(ds):
    """
    Return class names for *ds*, ordered by integer label id.

    Uses the ``ClassLabel`` names when the ``label`` feature has them.
    Otherwise the table is rebuilt from a ``label_text`` column, falling
    back to the id rendered as a string. A plain integer ``label`` feature
    has no ``.names`` attribute, so reading it directly would raise.
    """
    feature = ds.features["label"]
    if hasattr(feature, "names"):
        return list(feature.names)
    raw = ds.with_format(None)  # plain Python values instead of tensors
    ids = [int(i) for i in raw["label"]]
    if "label_text" in raw.column_names:
        texts = raw["label_text"]
    else:
        texts = [str(i) for i in ids]
    id_to_name = {}
    for label_id, name in zip(ids, texts):
        id_to_name.setdefault(label_id, name)
    if not id_to_name:
        return []
    return [id_to_name.get(i, str(i)) for i in range(max(id_to_name) + 1)]


# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Load Data
tokenized_datasets, tokenizer = prepare_data()
full_train_ds = tokenized_datasets["train"]
test_ds = tokenized_datasets["test"]

# Carve a validation split out of the training data. Selecting the best
# checkpoint on the test split and then reporting test accuracy would bias
# the reported number upward.
split = full_train_ds.train_test_split(test_size=VALIDATION_FRACTION, seed=SPLIT_SEED)
train_ds = split["train"]
eval_ds = split["test"]

# Model Init
# Get label mappings (works for ClassLabel features and for plain integer
# labels that come with a `label_text` column).
labels = _resolve_label_names(full_train_ds)
num_labels = len(labels)
label2id = {name: i for i, name in enumerate(labels)}
id2label = {i: name for i, name in enumerate(labels)}

print(f"Initializing {MODEL_NAME} with {num_labels} labels...")
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    # Optional: attn_implementation="flash_attention_2" on GPUs with compute
    # capability >= 8 (Ampere or newer).
)

# Define Training Arguments (The Radical Shift)
training_args = TrainingArguments(
    output_dir="./modernbert_radical_output",

    # Batch Size & Optimization
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE * 2,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=EPOCHS,
    warmup_ratio=WARMUP_RATIO,

    # Generalization Tricks
    label_smoothing_factor=LABEL_SMOOTHING,

    # Evaluation & Saving
    # `eval_strategy` replaces the old `evaluation_strategy` keyword, which
    # the transformers releases that support ModernBERT no longer accept.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,  # bound disk usage; the best checkpoint is kept
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,

    # Efficiency
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=4,
    group_by_length=True,  # batches similar lengths together -> less padding

    # Logging
    logging_steps=50,
    report_to="none",  # Disable wandb for this script
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` kwarg
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
    callbacks=[DisableCompiledMLPCallback()],
)

# Train
print("Starting Radical Training...")
trainer.train()

# Final Eval on the held-out test split (never used for model selection).
print("\nFinal Evaluation (test split)...")
metrics = trainer.evaluate(eval_dataset=test_ds, metric_key_prefix="test")
print(metrics)

# Save Model
trainer.save_model("./modernbert_radical_final")
print("Model saved to ./modernbert_radical_final")