!/usr/bin/env python
coding: utf-8

!/usr/bin/env python
coding: utf-8


## 1. Imports & Environment Setup

In [1]:


import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Prevent fork warnings with DataLoader
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json
import random
import time
import copy
import optuna
import gc
from typing import Tuple
from dataclasses import dataclass, field
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from torch.optim import AdamW

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoConfig,
    get_linear_schedule_with_warmup
)
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix
)
from tqdm import tqdm

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs: {num_gpus}")
    for i in range(num_gpus):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)} "
              f"({torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB)")
    torch.backends.cudnn.benchmark = True

PyTorch version: 2.9.0+cu126
CUDA available: True
Number of GPUs: 2
  GPU 0: Tesla T4 (14.6 GB)
  GPU 1: Tesla T4 (14.6 GB)


## 2. Configuration

In [2]:


@dataclass
class Config:
    """Configuration class with all hyperparameters and settings."""
    
    # Model Configuration
    model_name: str = "answerdotai/ModernBERT-large"
    num_labels: int = 20
    
    # Training Hyperparameters
    learning_rate: float = 3e-5
    batch_size: int = 16  # Per-GPU batch size (total = batch_size × num_gpus)
    num_epochs: int = 4
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    
    # Layer Freezing
    freeze_layers: bool = True
    freeze_ratio: float = 0.5  # Freeze bottom 50% of encoder layers (14/28)
    
    # Data Configuration
    dataset_name: str = "SetFit/20_newsgroups"
    max_length: int = 256  
    
    # Training Settings
    seed: int = 42
    use_fp16: bool = True
    save_model: bool = True
    output_dir: str = "/kaggle/tmp/output" if os.path.exists("/kaggle/working") else "./output"
    
    # Device (auto-detected)
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
    num_gpus: int = field(default_factory=lambda: torch.cuda.device_count() if torch.cuda.is_available() else 0)
    
    def __post_init__(self):
        if self.device == "cpu":
            self.use_fp16 = False
        # Scale batch size across GPUs
        self.total_batch_size = self.batch_size * max(1, self.num_gpus)
            
    def to_dict(self) -> dict:
        return {
            "model_name": self.model_name,
            "num_labels": self.num_labels,
            "learning_rate": self.learning_rate,
            "batch_size_per_gpu": self.batch_size,
            "num_gpus": self.num_gpus,
            "total_batch_size": self.total_batch_size,
            "num_epochs": self.num_epochs,
            "warmup_ratio": self.warmup_ratio,
            "weight_decay": self.weight_decay,
            "max_length": self.max_length,
            "freeze_layers": self.freeze_layers,
            "freeze_ratio": self.freeze_ratio,
            "seed": self.seed,
            "use_fp16": self.use_fp16,
            "device": self.device,
        }

# Initialize configuration
config = Config()

print("Configuration:")
for key, value in config.to_dict().items():
    print(f"  {key}: {value}")

Configuration:
  model_name: answerdotai/ModernBERT-large
  num_labels: 20
  learning_rate: 3e-05
  batch_size_per_gpu: 16
  num_gpus: 2
  total_batch_size: 32
  num_epochs: 4
  warmup_ratio: 0.1
  weight_decay: 0.01
  max_length: 256
  freeze_layers: True
  freeze_ratio: 0.5
  seed: 42
  use_fp16: True
  device: cuda


## 3. Dataset Exploration & Statistical Overview

In [3]:


def explore_dataset(config):
    """Load and display comprehensive dataset statistics."""
    print("\n" + "="*70)
    print("DATASET EXPLORATION: 20 Newsgroups")
    print("="*70)
    
    # Load raw dataset
    dataset = load_dataset(config.dataset_name)
    train_data = dataset['train']
    test_data = dataset['test']
    
    label_names = [
        'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc',
        'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x',
        'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball',
        'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med',
        'sci.space', 'soc.religion.christian', 'talk.politics.guns',
        'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'
    ]
    
    # --- Basic Info ---
    print(f"\n{'─'*50}")
    print(f"  Dataset: {config.dataset_name}")
    print(f"  Number of classes: {len(label_names)}")
    print(f"  Train samples: {len(train_data):,}")
    print(f"  Test samples:  {len(test_data):,}")
    print(f"  Total samples: {len(train_data) + len(test_data):,}")
    print(f"  Features: {list(train_data.features.keys())}")
    print(f"{'─'*50}")
    
    # --- Class Distribution ---
    print(f"\n{'─'*50}")
    print("  CLASS DISTRIBUTION")
    print(f"{'─'*50}")
    
    train_labels = train_data['label']
    test_labels = test_data['label']
    train_counts = Counter(train_labels)
    test_counts = Counter(test_labels)
    
    print(f"\n  {'Category':<35} {'Train':>6} {'Test':>6} {'Total':>6}")
    print(f"  {'─'*55}")
    for i, name in enumerate(label_names):
        tr = train_counts.get(i, 0)
        te = test_counts.get(i, 0)
        bar = '█' * (tr // 20)
        print(f"  {name:<35} {tr:>6} {te:>6} {tr+te:>6}  {bar}")
    
    print(f"  {'─'*55}")
    print(f"  {'TOTAL':<35} {len(train_data):>6} {len(test_data):>6} {len(train_data)+len(test_data):>6}")
    
    # Class balance metrics
    train_counts_list = [train_counts.get(i, 0) for i in range(len(label_names))]
    print(f"\n  Train class balance:")
    print(f"    Min samples/class: {min(train_counts_list)}")
    print(f"    Max samples/class: {max(train_counts_list)}")
    print(f"    Mean samples/class: {np.mean(train_counts_list):.1f}")
    print(f"    Std samples/class: {np.std(train_counts_list):.1f}")
    print(f"    Imbalance ratio (max/min): {max(train_counts_list)/max(min(train_counts_list),1):.2f}")
    
    # --- Text Length Statistics ---
    print(f"\n{'─'*50}")
    print("  TEXT LENGTH STATISTICS (Training Set)")
    print(f"{'─'*50}")
    
    texts = train_data['text']
    char_lengths = [len(t) for t in texts]
    word_lengths = [len(t.split()) for t in texts]
    
    for metric_name, lengths in [("Character lengths", char_lengths), ("Word counts", word_lengths)]:
        arr = np.array(lengths)
        print(f"\n  {metric_name}:")
        print(f"    Min:    {arr.min():>8,}")
        print(f"    Max:    {arr.max():>8,}")
        print(f"    Mean:   {arr.mean():>8,.1f}")
        print(f"    Median: {np.median(arr):>8,.1f}")
        print(f"    Std:    {arr.std():>8,.1f}")
        print(f"    P25:    {np.percentile(arr, 25):>8,.1f}")
        print(f"    P75:    {np.percentile(arr, 75):>8,.1f}")
        print(f"    P95:    {np.percentile(arr, 95):>8,.1f}")
    
    # Token-level stats with tokenizer
    print(f"\n  Tokenized lengths (using {config.model_name} tokenizer):")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    
    # Sample for speed (full tokenization on large dataset is slow)
    sample_size = min(2000, len(texts))
    sample_texts = random.sample(texts, sample_size)
    token_lengths = [len(tokenizer.encode(t)) for t in sample_texts]
    arr = np.array(token_lengths)
    
    print(f"    (Sampled {sample_size:,} documents)")
    print(f"    Min:    {arr.min():>8,}")
    print(f"    Max:    {arr.max():>8,}")
    print(f"    Mean:   {arr.mean():>8,.1f}")
    print(f"    Median: {np.median(arr):>8,.1f}")
    print(f"    P95:    {np.percentile(arr, 95):>8,.1f}")
    
    # Coverage at different max_length thresholds
    print(f"\n  Token coverage at different max_length:")
    for ml in [128, 256, 512]:
        coverage = (arr <= ml).sum() / len(arr) * 100
        print(f"    max_length={ml}: {coverage:.1f}% of documents fully covered")
    print(f"    → Using max_length={config.max_length}")
    
    # --- Sample Documents ---
    print(f"\n{'─'*50}")
    print("  SAMPLE DOCUMENTS (first 200 chars)")
    print(f"{'─'*50}")
    
    # Show 1 sample per first 5 classes
    for i in range(min(5, len(label_names))):
        # Find first document with this label
        for j, lbl in enumerate(train_labels):
            if lbl == i:
                text_preview = texts[j][:200].replace('\n', ' ')
                print(f"\n  [{label_names[i]}]")
                print(f"  \"{text_preview}...\"")
                break
    
    print(f"\n{'─'*50}")
    print(f"  (Showing 5 of {len(label_names)} classes)")
    print("="*70 + "\n")
    
    return dataset


# Run dataset exploration
dataset = explore_dataset(config)


DATASET EXPLORATION: 20 Newsgroups


README.md:   0%|          | 0.00/734 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


train.jsonl:   0%|          | 0.00/14.8M [00:00<?, ?B/s]

test.jsonl:   0%|          | 0.00/8.91M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/11314 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7532 [00:00<?, ? examples/s]


──────────────────────────────────────────────────
  Dataset: SetFit/20_newsgroups
  Number of classes: 20
  Train samples: 11,314
  Test samples:  7,532
  Total samples: 18,846
  Features: ['text', 'label', 'label_text']
──────────────────────────────────────────────────

──────────────────────────────────────────────────
  CLASS DISTRIBUTION
──────────────────────────────────────────────────

  Category                             Train   Test  Total
  ───────────────────────────────────────────────────────
  alt.atheism                            480    319    799  ████████████████████████
  comp.graphics                          584    389    973  █████████████████████████████
  comp.os.ms-windows.misc                591    394    985  █████████████████████████████
  comp.sys.ibm.pc.hardware               590    392    982  █████████████████████████████
  comp.sys.mac.hardware                  578    385    963  ████████████████████████████
  comp.windows.x                        

config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/694 [00:00<?, ?B/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (37451 > 8192). Running this sequence through the model will result in indexing errors


    (Sampled 2,000 documents)
    Min:           2
    Max:      45,247
    Mean:      488.7
    Median:    133.0
    P95:     1,076.1

  Token coverage at different max_length:
    max_length=128: 48.8% of documents fully covered
    max_length=256: 74.9% of documents fully covered
    max_length=512: 89.3% of documents fully covered
    → Using max_length=256

──────────────────────────────────────────────────
  SAMPLE DOCUMENTS (first 200 chars)
──────────────────────────────────────────────────

  [alt.atheism]
  " Don't be so sure.  Look what happened to Japanese citizens in the US during World War II.  If you're prepared to say "Let's round these people up and stick them in a concentration camp without trial"..."

  [comp.graphics]
  " Do you have Weitek's address/phone number?  I'd like to get some information about this chip. ..."

  [comp.os.ms-windows.misc]
  "I have win 3.0 and downloaded several icons and BMP's but I can't figure out how to change the "wallpaper" or use the

## 4. Data Loading & Tokenization

In [4]:


"""
Data Loading and Preprocessing for 20 Newsgroups

Design Decisions:
-----------------
1. Tokenization: ModernBERT-large tokenizer
2. Padding: max_length for uniform batch shapes (better for DataParallel)
3. DataLoader: 4 workers per GPU, pin memory
"""

def get_label_names() -> list:
    """Get the 20 newsgroup category names."""
    return [
        'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc',
        'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x',
        'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball',
        'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med',
        'sci.space', 'soc.religion.christian', 'talk.politics.guns',
        'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'
    ]


def load_and_prepare_data(config, dataset=None) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Load 20 newsgroups dataset, tokenize, and create DataLoaders."""
    print(f"\nPreparing data for training...")
    
    # Use pre-loaded dataset if available (from exploration step)
    if dataset is None:
        dataset = load_dataset(config.dataset_name)
    
    train_dataset = dataset['train']
    test_split = dataset['test'].train_test_split(test_size=0.5, seed=config.seed)
    val_dataset = test_split['train']
    test_dataset = test_split['test']
    
    print(f"  Train size: {len(train_dataset)}")
    print(f"  Validation size: {len(val_dataset)}")
    print(f"  Test size: {len(test_dataset)}")
    
    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    
    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            truncation=True,
            padding='max_length',
            max_length=config.max_length,
            return_tensors=None
        )
    
    # Apply tokenization
    print("  Tokenizing datasets...")
    train_dataset = train_dataset.map(tokenize_function, batched=True, desc="Tokenizing train")
    val_dataset = val_dataset.map(tokenize_function, batched=True, desc="Tokenizing val")
    test_dataset = test_dataset.map(tokenize_function, batched=True, desc="Tokenizing test")
    
    # Set format for PyTorch
    columns = ['input_ids', 'attention_mask', 'label']
    train_dataset.set_format(type='torch', columns=columns)
    val_dataset.set_format(type='torch', columns=columns)
    test_dataset.set_format(type='torch', columns=columns)
    
    # DataLoader config — use total_batch_size (accounts for multi-GPU)
    num_workers = 4 if config.device == 'cuda' else 0
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.total_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True if config.device == 'cuda' else False,
        drop_last=True  # Avoids uneven batch splits across GPUs
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.total_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True if config.device == 'cuda' else False
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.total_batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True if config.device == 'cuda' else False
    )
    
    print(f"  DataLoaders ready — Train batches: {len(train_loader)}, Val batches: {len(val_loader)}, Test batches: {len(test_loader)}")
    
    
    return train_loader, val_loader, test_loader


# Load data (reuses the dataset from exploration to avoid re-downloading)
train_loader, val_loader, test_loader = load_and_prepare_data(config, dataset=dataset)


Preparing data for training...
  Train size: 11314
  Validation size: 3766
  Test size: 3766
  Tokenizing datasets...


Tokenizing train:   0%|          | 0/11314 [00:00<?, ? examples/s]

Tokenizing val:   0%|          | 0/3766 [00:00<?, ? examples/s]

Tokenizing test:   0%|          | 0/3766 [00:00<?, ? examples/s]

  DataLoaders ready — Train batches: 353, Val batches: 118, Test batches: 118


## 5. Model

In [5]:


def get_model(config):
    """Initialize ModernBERT-large with layer freezing and optional DataParallel."""
    print(f"\nLoading model: {config.model_name}")
    print(f"  Number of classes: {config.num_labels}")
    
    model_config = AutoConfig.from_pretrained(
        config.model_name,
        num_labels=config.num_labels,
        finetuning_task="text-classification"
    )
    
    model = AutoModelForSequenceClassification.from_pretrained(
        config.model_name,
        config=model_config,
        attn_implementation="eager"  # Required: compiled attention breaks DataParallel
    )
    
    # Layer freezing for efficiency
    if config.freeze_layers:
        # Freeze embeddings
        if hasattr(model, 'model') and hasattr(model.model, 'embeddings'):
            for param in model.model.embeddings.parameters():
                param.requires_grad = False
            print("  ✓ Froze embedding layer")
        elif hasattr(model, 'bert') and hasattr(model.bert, 'embeddings'):
            for param in model.bert.embeddings.parameters():
                param.requires_grad = False
            print("  ✓ Froze embedding layer")
        
        # Freeze bottom encoder layers
        encoder_layers = None
        if hasattr(model, 'model') and hasattr(model.model, 'encoder'):
            encoder = model.model.encoder
            if hasattr(encoder, 'layers'):
                encoder_layers = encoder.layers
            elif hasattr(encoder, 'layer'):
                encoder_layers = encoder.layer
        elif hasattr(model, 'bert') and hasattr(model.bert, 'encoder'):
            encoder = model.bert.encoder
            if hasattr(encoder, 'layer'):
                encoder_layers = encoder.layer
        
        if encoder_layers is not None:
            num_layers = len(encoder_layers)
            num_freeze = int(num_layers * config.freeze_ratio)
            for i, layer in enumerate(encoder_layers):
                if i < num_freeze:
                    for param in layer.parameters():
                        param.requires_grad = False
            print(f"  ✓ Froze {num_freeze}/{num_layers} encoder layers")
        else:
            print("  ⚠ Warning: Could not identify encoder layers for freezing")
    
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params
    
    print(f"\n  Parameter Summary:")
    print(f"    Total:     {total_params:>12,}")
    print(f"    Trainable: {trainable_params:>12,} ({100*trainable_params/total_params:.1f}%)")
    print(f"    Frozen:    {frozen_params:>12,} ({100*frozen_params/total_params:.1f}%)")
    
    # Multi-GPU support with DataParallel
    model.to(config.device)
    if config.num_gpus > 1:
        model = nn.DataParallel(model)
        print(f"\n  ✓ DataParallel enabled across {config.num_gpus} GPUs")
    
    return model


# Initialize model
model = get_model(config)


Loading model: answerdotai/ModernBERT-large
  Number of classes: 20


model.safetensors:   0%|          | 0.00/1.58G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/172 [00:00<?, ?it/s]

[1mModernBertForSequenceClassification LOAD REPORT[0m from: answerdotai/ModernBERT-large
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


  ✓ Froze embedding layer

  Parameter Summary:
    Total:      395,851,796
    Trainable:  344,273,940 (87.0%)
    Frozen:      51,577,856 (13.0%)

  ✓ DataParallel enabled across 2 GPUs


## 6. Trainer

In [6]:


class Trainer:
    
    def __init__(self, model, config, train_loader, val_loader):
        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = config.device

        
        # Access underlying model for parameter filtering (DataParallel wraps it)
        base_model = model.module if hasattr(model, 'module') else model
        
        # Only optimize trainable parameters
        self.optimizer = AdamW(
            filter(lambda p: p.requires_grad, base_model.parameters()),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        
        # Total optimizer steps = one step per batch, per epoch
        self.total_steps = len(train_loader) * config.num_epochs
        self.warmup_steps = int(self.total_steps * config.warmup_ratio)
        
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.total_steps
        )
        
        # Modern AMP API
        self.scaler = GradScaler("cuda") if config.use_fp16 else None
        self.use_fp16 = config.use_fp16
        
        self.history = {'train_loss': [], 'learning_rate': []}
        
        print(f"\nTraining Configuration:")
        print(f"  Device: {self.device} × {config.num_gpus} GPUs")
        print(f"  Total optimizer steps: {self.total_steps}")
        print(f"  Warmup steps: {self.warmup_steps}")
        print(f"  Mixed precision (FP16): {self.use_fp16}")
    
    def _get_trainable_params(self):
        """Get trainable parameters from model (handles DataParallel)."""
        base_model = self.model.module if hasattr(self.model, 'module') else self.model
        return filter(lambda p: p.requires_grad, base_model.parameters())

    @torch.no_grad()
    def evaluate_val(self):
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []
        for batch in self.val_loader:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['label'].to(self.device)
            if self.use_fp16:
                with autocast("cuda"):
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            else:
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss.mean() if outputs.loss.dim() > 0 else outputs.loss
            total_loss += loss.item()
            
            preds = torch.argmax(outputs.logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
        acc = accuracy_score(all_labels, all_preds)
        return total_loss / len(self.val_loader), acc
    
    def train_epoch(self, epoch):
        """Train for one epoch, stepping the optimizer on every batch."""
        self.model.train()
        total_loss = 0
        num_batches = 0
        
        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {epoch+1}/{self.config.num_epochs}",
            leave=True
        )
        
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['label'].to(self.device)
            
            self.optimizer.zero_grad()
            
            if self.use_fp16:
                with autocast("cuda"):
                    outputs = self.model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels
                    )
                    # DataParallel returns averaged loss across GPUs
                    loss = outputs.loss.mean()
                
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self._get_trainable_params(),
                    self.config.max_grad_norm
                )
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                loss = outputs.loss.mean()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self._get_trainable_params(),
                    self.config.max_grad_norm
                )
                self.optimizer.step()
            
            self.scheduler.step()
            
            total_loss += loss.item()
            num_batches += 1
            
            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'lr': f'{self.scheduler.get_last_lr()[0]:.2e}'
            })
        
        return total_loss / num_batches
    
    def train(self):
        """Full training loop."""
        print("\n" + "="*60)
        print("Starting Training")
        print("="*60 + "\n")
        
        start_time = time.time()
        
        for epoch in range(self.config.num_epochs):
            epoch_start = time.time()
            
            train_loss = self.train_epoch(epoch)
            self.history['train_loss'].append(train_loss)
            
            current_lr = self.scheduler.get_last_lr()[0]
            self.history['learning_rate'].append(current_lr)
            
            epoch_time = time.time() - epoch_start
            
            print(f"\nEpoch {epoch+1}/{self.config.num_epochs} - "
                  f"Train Loss: {train_loss:.4f} - "
                  f"LR: {current_lr:.2e} - "
                  f"Time: {epoch_time:.1f}s")
            
            # Memory report
            if torch.cuda.is_available():
                for i in range(config.num_gpus):
                    allocated = torch.cuda.memory_allocated(i) / 1024**3
                    reserved = torch.cuda.memory_reserved(i) / 1024**3
                    print(f"  GPU {i} memory: {allocated:.1f} GB allocated, {reserved:.1f} GB reserved")
        val_loss, val_acc = self.evaluate_val()
        print(f"\nFinal Validation Loss: {val_loss:.4f} | Validation Accuracy: {val_acc:.4f}")
        
        total_time = time.time() - start_time
        print(f"\nTraining Complete! Total time: {total_time/60:.1f} minutes")
        
        return val_loss, val_acc, self.history

## 8. Optuna Hyperparameter Optimization

In [7]:


import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def objective(trial):
    set_seed(config.seed + trial.number)
    
    lr = trial.suggest_float('learning_rate', 1e-5, 1e-4, log=True)
    wd = trial.suggest_float('weight_decay', 0.001, 0.1)
    wr = trial.suggest_float('warmup_ratio', 0.0, 0.2)
    
    trial_config = copy.deepcopy(config)
    trial_config.learning_rate = lr
    trial_config.weight_decay = wd
    trial_config.warmup_ratio = wr
    
    print(f"\n--- Starting Trial {trial.number} ---")
    print(f"Params: lr={lr:.2e}, weight_decay={wd:.4f}, warmup_ratio={wr:.2f}")
    
    trial_model = get_model(trial_config)
    trainer = Trainer(trial_model, trial_config, train_loader, val_loader)
    val_loss, val_acc, history = trainer.train()
    
    # Save checkpoint
    checkpoint_dir = os.path.join(trial_config.output_dir, f"trial_{trial.number}")
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    save_model = trial_model.module if hasattr(trial_model, 'module') else trial_model
    save_model.save_pretrained(checkpoint_dir)
    tokenizer = AutoTokenizer.from_pretrained(trial_config.model_name)
    tokenizer.save_pretrained(checkpoint_dir)
    
    trial.set_user_attr('checkpoint_dir', checkpoint_dir)
    
    # Aggressively clean up memory of this trial's model objects
    del trial_model
    del trainer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Delete model checkpoints from disk if they definitely aren't Top 3
    # Wait until there's at least 3 completed trials before deciding:
    completed_trials = [t for t in trial.study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    if len(completed_trials) >= 3:
        # Sort by value (val_acc), descending (larger is better)
        completed_trials.sort(key=lambda x: x.value if x.value is not None else float('-inf'), reverse=True)
        best_3_values = [t.value for t in completed_trials[:3] if t.value is not None]
        
        if len(best_3_values) == 3 and val_acc < min(best_3_values):
            import shutil
            # This trial is worse than the current Top 3, immediately delete the checkpoint from disk
            if os.path.exists(checkpoint_dir):
                shutil.rmtree(checkpoint_dir)
                trial.set_user_attr('checkpoint_deleted', True)
        
    return val_acc

# Run QRS
os.makedirs(config.output_dir, exist_ok=True)
sampler = optuna.samplers.QMCSampler(seed=config.seed)
study = optuna.create_study(direction="maximize", sampler=sampler, study_name="modernbert_hpo")

print("\n" + "="*70)
print("STARTING OPTUNA QUASI-RANDOM SEARCH")
print("="*70)
study.optimize(objective, n_trials=8)

# Visualization
try:
    fig = optuna.visualization.matplotlib.plot_slice(study)
    plt.tight_layout()
    plt.savefig(os.path.join(config.output_dir, "qrs_hyperparameters_vs_loss.png"), dpi=300)
    plt.close()
    
    fig2 = optuna.visualization.matplotlib.plot_param_importances(study)
    plt.tight_layout()
    plt.savefig(os.path.join(config.output_dir, "qrs_param_importances.png"), dpi=300)
    plt.close()
    print(f"Saved Optuna visualizations to {config.output_dir}")
except Exception as e:
    print(f"Failed to generate Optuna plots: {e}")

  sampler = optuna.samplers.QMCSampler(seed=config.seed)
[I 2026-02-21 17:26:38,112] A new study created in memory with name: modernbert_hpo



STARTING OPTUNA QUASI-RANDOM SEARCH

--- Starting Trial 0 ---
Params: lr=2.37e-05, weight_decay=0.0951, warmup_ratio=0.15

Loading model: answerdotai/ModernBERT-large
  Number of classes: 20


Loading weights:   0%|          | 0/172 [00:00<?, ?it/s]

[1mModernBertForSequenceClassification LOAD REPORT[0m from: answerdotai/ModernBERT-large
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


  ✓ Froze embedding layer

  Parameter Summary:
    Total:      395,851,796
    Trainable:  344,273,940 (87.0%)
    Frozen:      51,577,856 (13.0%)

  ✓ DataParallel enabled across 2 GPUs

Training Configuration:
  Device: cuda × 2 GPUs
  Total optimizer steps: 1412
  Warmup steps: 206
  Mixed precision (FP16): True

Starting Training



Epoch 1/4: 100%|██████████| 353/353 [07:27<00:00,  1.27s/it, loss=0.6913, lr=2.08e-05]



Epoch 1/4 - Train Loss: 1.5282 - LR: 2.08e-05 - Time: 447.6s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 2/4: 100%|██████████| 353/353 [07:27<00:00,  1.27s/it, loss=0.5416, lr=1.39e-05]



Epoch 2/4 - Train Loss: 0.5888 - LR: 1.39e-05 - Time: 447.8s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 3/4: 100%|██████████| 353/353 [07:26<00:00,  1.26s/it, loss=0.1470, lr=6.93e-06]



Epoch 3/4 - Train Loss: 0.2183 - LR: 6.93e-06 - Time: 446.0s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 4/4: 100%|██████████| 353/353 [07:24<00:00,  1.26s/it, loss=0.0918, lr=0.00e+00]


Epoch 4/4 - Train Loss: 0.0983 - LR: 0.00e+00 - Time: 444.9s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved






Final Validation Loss: 1.2520 | Validation Accuracy: 0.7461

Training Complete! Total time: 30.7 minutes


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[I 2026-02-21 17:57:25,778] Trial 0 finished with value: 0.7461497610196495 and parameters: {'learning_rate': 2.368863950364079e-05, 'weight_decay': 0.0951207163345817, 'warmup_ratio': 0.146398788362281}. Best is trial 0 with value: 0.7461497610196495.



--- Starting Trial 1 ---
Params: lr=3.97e-05, weight_decay=0.0010, warmup_ratio=0.00

Loading model: answerdotai/ModernBERT-large
  Number of classes: 20


Loading weights:   0%|          | 0/172 [00:00<?, ?it/s]

[1mModernBertForSequenceClassification LOAD REPORT[0m from: answerdotai/ModernBERT-large
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


  ✓ Froze embedding layer

  Parameter Summary:
    Total:      395,851,796
    Trainable:  344,273,940 (87.0%)
    Frozen:      51,577,856 (13.0%)

  ✓ DataParallel enabled across 2 GPUs

Training Configuration:
  Device: cuda × 2 GPUs
  Total optimizer steps: 1412
  Warmup steps: 0
  Mixed precision (FP16): True

Starting Training



Epoch 1/4: 100%|██████████| 353/353 [07:29<00:00,  1.27s/it, loss=0.6993, lr=2.98e-05]



Epoch 1/4 - Train Loss: 1.0659 - LR: 2.98e-05 - Time: 449.4s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 2/4: 100%|██████████| 353/353 [07:27<00:00,  1.27s/it, loss=0.1956, lr=1.98e-05]



Epoch 2/4 - Train Loss: 0.4777 - LR: 1.98e-05 - Time: 447.4s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 3/4: 100%|██████████| 353/353 [07:24<00:00,  1.26s/it, loss=0.0622, lr=9.92e-06]



Epoch 3/4 - Train Loss: 0.1881 - LR: 9.92e-06 - Time: 444.7s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 4/4: 100%|██████████| 353/353 [07:23<00:00,  1.26s/it, loss=0.1119, lr=0.00e+00]


Epoch 4/4 - Train Loss: 0.1020 - LR: 0.00e+00 - Time: 443.4s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved






Final Validation Loss: 1.3294 | Validation Accuracy: 0.7411

Training Complete! Total time: 30.7 minutes


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[I 2026-02-21 18:28:11,724] Trial 1 finished with value: 0.7411046202867764 and parameters: {'learning_rate': 3.968793330444374e-05, 'weight_decay': 0.001, 'warmup_ratio': 0.0}. Best is trial 0 with value: 0.7461497610196495.



--- Starting Trial 2 ---
Params: lr=3.16e-05, weight_decay=0.0505, warmup_ratio=0.10

Loading model: answerdotai/ModernBERT-large
  Number of classes: 20


Loading weights:   0%|          | 0/172 [00:00<?, ?it/s]

[1mModernBertForSequenceClassification LOAD REPORT[0m from: answerdotai/ModernBERT-large
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


  ✓ Froze embedding layer

  Parameter Summary:
    Total:      395,851,796
    Trainable:  344,273,940 (87.0%)
    Frozen:      51,577,856 (13.0%)

  ✓ DataParallel enabled across 2 GPUs

Training Configuration:
  Device: cuda × 2 GPUs
  Total optimizer steps: 1412
  Warmup steps: 141
  Mixed precision (FP16): True

Starting Training



Epoch 1/4: 100%|██████████| 353/353 [07:29<00:00,  1.27s/it, loss=0.9250, lr=2.63e-05]



Epoch 1/4 - Train Loss: 1.3515 - LR: 2.63e-05 - Time: 449.0s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 2/4: 100%|██████████| 353/353 [07:28<00:00,  1.27s/it, loss=0.3536, lr=1.76e-05]



Epoch 2/4 - Train Loss: 0.5169 - LR: 1.76e-05 - Time: 448.1s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 3/4: 100%|██████████| 353/353 [07:25<00:00,  1.26s/it, loss=0.0280, lr=8.78e-06]



Epoch 3/4 - Train Loss: 0.1868 - LR: 8.78e-06 - Time: 445.7s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 4/4: 100%|██████████| 353/353 [07:24<00:00,  1.26s/it, loss=0.0117, lr=0.00e+00]


Epoch 4/4 - Train Loss: 0.0934 - LR: 0.00e+00 - Time: 444.8s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved






Final Validation Loss: 1.3307 | Validation Accuracy: 0.7411

Training Complete! Total time: 30.7 minutes


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[I 2026-02-21 18:59:00,532] Trial 2 finished with value: 0.7411046202867764 and parameters: {'learning_rate': 3.16227766016838e-05, 'weight_decay': 0.0505, 'warmup_ratio': 0.1}. Best is trial 0 with value: 0.7461497610196495.



--- Starting Trial 3 ---
Params: lr=5.62e-05, weight_decay=0.0258, warmup_ratio=0.05

Loading model: answerdotai/ModernBERT-large
  Number of classes: 20


Loading weights:   0%|          | 0/172 [00:00<?, ?it/s]

[1mModernBertForSequenceClassification LOAD REPORT[0m from: answerdotai/ModernBERT-large
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


  ✓ Froze embedding layer

  Parameter Summary:
    Total:      395,851,796
    Trainable:  344,273,940 (87.0%)
    Frozen:      51,577,856 (13.0%)

  ✓ DataParallel enabled across 2 GPUs

Training Configuration:
  Device: cuda × 2 GPUs
  Total optimizer steps: 1412
  Warmup steps: 70
  Mixed precision (FP16): True

Starting Training



Epoch 1/4: 100%|██████████| 353/353 [07:28<00:00,  1.27s/it, loss=0.5291, lr=4.44e-05]



Epoch 1/4 - Train Loss: 1.2155 - LR: 4.44e-05 - Time: 448.8s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 2/4: 100%|██████████| 353/353 [07:27<00:00,  1.27s/it, loss=0.3493, lr=2.96e-05]



Epoch 2/4 - Train Loss: 0.5187 - LR: 2.96e-05 - Time: 447.5s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 3/4: 100%|██████████| 353/353 [07:23<00:00,  1.26s/it, loss=0.1728, lr=1.48e-05]



Epoch 3/4 - Train Loss: 0.2012 - LR: 1.48e-05 - Time: 444.0s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 4/4: 100%|██████████| 353/353 [07:22<00:00,  1.25s/it, loss=0.0746, lr=0.00e+00]


Epoch 4/4 - Train Loss: 0.1006 - LR: 0.00e+00 - Time: 442.1s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved






Final Validation Loss: 1.4788 | Validation Accuracy: 0.7366

Training Complete! Total time: 30.6 minutes


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[I 2026-02-21 19:29:44,403] Trial 3 finished with value: 0.7365905469994689 and parameters: {'learning_rate': 5.6234132519034995e-05, 'weight_decay': 0.025750000000000002, 'warmup_ratio': 0.05}. Best is trial 0 with value: 0.7461497610196495.



--- Starting Trial 4 ---
Params: lr=1.78e-05, weight_decay=0.0753, warmup_ratio=0.15

Loading model: answerdotai/ModernBERT-large
  Number of classes: 20


Loading weights:   0%|          | 0/172 [00:00<?, ?it/s]

[1mModernBertForSequenceClassification LOAD REPORT[0m from: answerdotai/ModernBERT-large
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


  ✓ Froze embedding layer

  Parameter Summary:
    Total:      395,851,796
    Trainable:  344,273,940 (87.0%)
    Frozen:      51,577,856 (13.0%)

  ✓ DataParallel enabled across 2 GPUs

Training Configuration:
  Device: cuda × 2 GPUs
  Total optimizer steps: 1412
  Warmup steps: 211
  Mixed precision (FP16): True

Starting Training



Epoch 1/4: 100%|██████████| 353/353 [07:29<00:00,  1.27s/it, loss=1.3765, lr=1.57e-05]



Epoch 1/4 - Train Loss: 1.5985 - LR: 1.57e-05 - Time: 449.7s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 2/4: 100%|██████████| 353/353 [07:28<00:00,  1.27s/it, loss=0.4642, lr=1.05e-05]



Epoch 2/4 - Train Loss: 0.6278 - LR: 1.05e-05 - Time: 448.9s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 3/4: 100%|██████████| 353/353 [07:27<00:00,  1.27s/it, loss=0.1410, lr=5.23e-06]



Epoch 3/4 - Train Loss: 0.2566 - LR: 5.23e-06 - Time: 447.2s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 4/4: 100%|██████████| 353/353 [07:26<00:00,  1.26s/it, loss=0.1130, lr=0.00e+00]


Epoch 4/4 - Train Loss: 0.1136 - LR: 0.00e+00 - Time: 446.2s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved






Final Validation Loss: 1.2011 | Validation Accuracy: 0.7379

Training Complete! Total time: 30.8 minutes


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[I 2026-02-21 20:00:37,485] Trial 4 finished with value: 0.7379182156133829 and parameters: {'learning_rate': 1.7782794100389212e-05, 'weight_decay': 0.07525000000000001, 'warmup_ratio': 0.15000000000000002}. Best is trial 0 with value: 0.7461497610196495.



--- Starting Trial 5 ---
Params: lr=2.37e-05, weight_decay=0.0381, warmup_ratio=0.12

Loading model: answerdotai/ModernBERT-large
  Number of classes: 20


Loading weights:   0%|          | 0/172 [00:00<?, ?it/s]

[1mModernBertForSequenceClassification LOAD REPORT[0m from: answerdotai/ModernBERT-large
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


  ✓ Froze embedding layer

  Parameter Summary:
    Total:      395,851,796
    Trainable:  344,273,940 (87.0%)
    Frozen:      51,577,856 (13.0%)

  ✓ DataParallel enabled across 2 GPUs

Training Configuration:
  Device: cuda × 2 GPUs
  Total optimizer steps: 1412
  Warmup steps: 176
  Mixed precision (FP16): True

Starting Training



Epoch 1/4: 100%|██████████| 353/353 [07:29<00:00,  1.27s/it, loss=0.5933, lr=2.03e-05]



Epoch 1/4 - Train Loss: 1.4037 - LR: 2.03e-05 - Time: 449.4s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 2/4: 100%|██████████| 353/353 [07:28<00:00,  1.27s/it, loss=0.7164, lr=1.35e-05]



Epoch 2/4 - Train Loss: 0.5367 - LR: 1.35e-05 - Time: 448.6s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 3/4: 100%|██████████| 353/353 [07:27<00:00,  1.27s/it, loss=0.2082, lr=6.77e-06]



Epoch 3/4 - Train Loss: 0.2086 - LR: 6.77e-06 - Time: 447.0s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 4/4: 100%|██████████| 353/353 [07:26<00:00,  1.26s/it, loss=0.0040, lr=0.00e+00]


Epoch 4/4 - Train Loss: 0.0994 - LR: 0.00e+00 - Time: 446.3s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved






Final Validation Loss: 1.2590 | Validation Accuracy: 0.7411

Training Complete! Total time: 30.8 minutes


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[I 2026-02-21 20:31:30,095] Trial 5 finished with value: 0.7411046202867764 and parameters: {'learning_rate': 2.3713737056616547e-05, 'weight_decay': 0.038125000000000006, 'warmup_ratio': 0.125}. Best is trial 0 with value: 0.7461497610196495.



--- Starting Trial 6 ---
Params: lr=7.50e-05, weight_decay=0.0876, warmup_ratio=0.03

Loading model: answerdotai/ModernBERT-large
  Number of classes: 20


Loading weights:   0%|          | 0/172 [00:00<?, ?it/s]

[1mModernBertForSequenceClassification LOAD REPORT[0m from: answerdotai/ModernBERT-large
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


  ✓ Froze embedding layer

  Parameter Summary:
    Total:      395,851,796
    Trainable:  344,273,940 (87.0%)
    Frozen:      51,577,856 (13.0%)

  ✓ DataParallel enabled across 2 GPUs

Training Configuration:
  Device: cuda × 2 GPUs
  Total optimizer steps: 1412
  Warmup steps: 35
  Mixed precision (FP16): True

Starting Training



Epoch 1/4: 100%|██████████| 353/353 [07:29<00:00,  1.27s/it, loss=0.9636, lr=5.77e-05]



Epoch 1/4 - Train Loss: 1.1469 - LR: 5.77e-05 - Time: 449.2s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 2/4: 100%|██████████| 353/353 [07:27<00:00,  1.27s/it, loss=0.6849, lr=3.84e-05]



Epoch 2/4 - Train Loss: 0.4661 - LR: 3.84e-05 - Time: 447.5s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 3/4: 100%|██████████| 353/353 [07:24<00:00,  1.26s/it, loss=0.2408, lr=1.92e-05]



Epoch 3/4 - Train Loss: 0.1647 - LR: 1.92e-05 - Time: 444.1s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 4/4: 100%|██████████| 353/353 [07:22<00:00,  1.25s/it, loss=0.0940, lr=0.00e+00]


Epoch 4/4 - Train Loss: 0.0903 - LR: 0.00e+00 - Time: 442.0s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved






Final Validation Loss: 1.6026 | Validation Accuracy: 0.7475

Training Complete! Total time: 30.6 minutes


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[I 2026-02-21 21:02:13,941] Trial 6 finished with value: 0.7474774296335635 and parameters: {'learning_rate': 7.498942093324561e-05, 'weight_decay': 0.08762500000000001, 'warmup_ratio': 0.025}. Best is trial 6 with value: 0.7474774296335635.



--- Starting Trial 7 ---
Params: lr=4.22e-05, weight_decay=0.0134, warmup_ratio=0.18

Loading model: answerdotai/ModernBERT-large
  Number of classes: 20


Loading weights:   0%|          | 0/172 [00:00<?, ?it/s]

[1mModernBertForSequenceClassification LOAD REPORT[0m from: answerdotai/ModernBERT-large
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


  ✓ Froze embedding layer

  Parameter Summary:
    Total:      395,851,796
    Trainable:  344,273,940 (87.0%)
    Frozen:      51,577,856 (13.0%)

  ✓ DataParallel enabled across 2 GPUs

Training Configuration:
  Device: cuda × 2 GPUs
  Total optimizer steps: 1412
  Warmup steps: 247
  Mixed precision (FP16): True

Starting Training



Epoch 1/4: 100%|██████████| 353/353 [07:29<00:00,  1.27s/it, loss=0.7775, lr=3.83e-05]



Epoch 1/4 - Train Loss: 1.4470 - LR: 3.83e-05 - Time: 449.1s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 2/4: 100%|██████████| 353/353 [07:28<00:00,  1.27s/it, loss=0.6612, lr=2.56e-05]



Epoch 2/4 - Train Loss: 0.5993 - LR: 2.56e-05 - Time: 448.1s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 3/4: 100%|██████████| 353/353 [07:25<00:00,  1.26s/it, loss=0.0343, lr=1.28e-05]



Epoch 3/4 - Train Loss: 0.2180 - LR: 1.28e-05 - Time: 445.1s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved


Epoch 4/4: 100%|██████████| 353/353 [07:23<00:00,  1.26s/it, loss=0.0843, lr=0.00e+00]


Epoch 4/4 - Train Loss: 0.0964 - LR: 0.00e+00 - Time: 443.4s
  GPU 0 memory: 6.8 GB allocated, 13.5 GB reserved
  GPU 1 memory: 0.0 GB allocated, 9.4 GB reserved






Final Validation Loss: 1.4307 | Validation Accuracy: 0.7416

Training Complete! Total time: 30.7 minutes


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

[I 2026-02-21 21:33:00,303] Trial 7 finished with value: 0.741635687732342 and parameters: {'learning_rate': 4.216965034285826e-05, 'weight_decay': 0.013375000000000001, 'warmup_ratio': 0.17500000000000002}. Best is trial 6 with value: 0.7474774296335635.
  fig = optuna.visualization.matplotlib.plot_slice(study)
  plt.tight_layout()
  fig2 = optuna.visualization.matplotlib.plot_param_importances(study)


Saved Optuna visualizations to /kaggle/tmp/output


## 9. Evaluation

In [8]:


@torch.no_grad()
def evaluate(model, test_loader, config):
    """Comprehensive evaluation on test set."""
    print("\n" + "="*60)
    print("Evaluating on Test Set")
    print("="*60 + "\n")
    
    model.eval()
    
    all_predictions = []
    all_labels = []
    total_loss = 0
    
    for batch in tqdm(test_loader, desc="Evaluating"):
        input_ids = batch['input_ids'].to(config.device)
        attention_mask = batch['attention_mask'].to(config.device)
        labels = batch['label'].to(config.device)
        
        if config.use_fp16:
            with autocast("cuda"):
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
        else:
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
        
        # Handle DataParallel loss
        loss = outputs.loss.mean() if outputs.loss.dim() > 0 else outputs.loss
        total_loss += loss.item()
        predictions = torch.argmax(outputs.logits, dim=-1)
        
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    
    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro'
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted'
    )
    avg_loss = total_loss / len(test_loader)
    
    label_names = get_label_names()
    report = classification_report(all_labels, all_predictions, target_names=label_names, digits=4)
    conf_matrix = confusion_matrix(all_labels, all_predictions)
    
    # Print results
    print("\n" + "="*60)
    print("EVALUATION RESULTS")
    print("="*60)
    print(f"\n[Overall Metrics]")
    print(f"  Test Loss: {avg_loss:.4f}")
    print(f"  Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"\n[Macro Averages]")
    print(f"  Precision: {precision_macro:.4f}")
    print(f"  Recall: {recall_macro:.4f}")
    print(f"  F1 Score: {f1_macro:.4f}")
    print(f"\n[Weighted Averages]")
    print(f"  Precision: {precision_weighted:.4f}")
    print(f"  Recall: {recall_weighted:.4f}")
    print(f"  F1 Score: {f1_weighted:.4f}")
    print("\n" + "="*60)
    print("CLASSIFICATION REPORT (Per-Class)")
    print("="*60)
    print(report)
    
    return {
        'test_loss': avg_loss,
        'accuracy': accuracy,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'classification_report': report,
        'confusion_matrix': conf_matrix,
        'predictions': all_predictions,
        'labels': all_labels,
        'label_names': label_names
    }

## 10. Final Evaluation & Model Selection

In [9]:

completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE and not t.user_attrs.get('checkpoint_deleted', False)]
completed_trials.sort(key=lambda t: t.value if t.value is not None else float('-inf'), reverse=True)
top_trials = completed_trials[:3]

print("\n" + "="*70)
print("FINAL EVALUATION RESULTS - TOP 3 MODELS")
print("="*70)

for idx, t in enumerate(top_trials):
    print(f"\n{'='*50}")
    print(f"--- Loading Top {idx+1} Model (Trial {t.number}) ---")
    print(f"Params: {t.params}")
    print(f"Validation Accuracy: {t.value:.4f}")
    print(f"{'='*50}")
    
    ckpt_dir = t.user_attrs['checkpoint_dir']
    
    # Load and Evaluate
    best_model = AutoModelForSequenceClassification.from_pretrained(ckpt_dir)
    best_model.to(config.device)
    if config.num_gpus > 1:
        best_model = nn.DataParallel(best_model)
        
    results = evaluate(best_model, test_loader, config)
    
    print(f"\n[Top {idx+1} Model Test Metrics]")
    print(f"  Test Loss: {results['test_loss']:.4f}")
    print(f"  Test Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)")
    print(f"  Macro F1: {results['f1_macro']:.4f}")
    
    final_dir = os.path.join(config.output_dir, f"top{idx+1}_model")
    if os.path.exists(final_dir):
        import shutil
        shutil.rmtree(final_dir)
    os.rename(ckpt_dir, final_dir)
    print(f"Saved Top {idx+1} Model to {final_dir}")
    
    if idx == 0 and os.path.exists("/kaggle/working"):
        kaggle_working_dir = "/kaggle/working/top1_model"
        if os.path.exists(kaggle_working_dir):
            import shutil
            shutil.rmtree(kaggle_working_dir)
        import shutil
        shutil.copytree(final_dir, kaggle_working_dir)
        print(f"Also copied Top 1 Model to {kaggle_working_dir}")
    
    metrics_path = os.path.join(final_dir, "evaluation_metrics.json")
    with open(metrics_path, 'w') as f:
        # Avoid saving non-serializable objects (like numpy arrays from evaluation)
        json_results = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in results.items() if k not in ['confusion_matrix']}
        json.dump(json_results, f, indent=2)

    # Save trial info
    with open(os.path.join(final_dir, "trial_info.json"), 'w') as f:
        json.dump({"trial_id": t.number, "params": t.params, "val_acc": t.value}, f, indent=2)

# Cleanup other trial weights to save space
import shutil
for t in study.trials:
    if 'checkpoint_dir' in t.user_attrs:
        ckpt_dir = t.user_attrs['checkpoint_dir']
        if os.path.exists(ckpt_dir):
            shutil.rmtree(ckpt_dir)

# Keep the global 'model' pointing to the Best Model (Top 1) for Section 12 OOD evaluation
best_ckpt_dir = os.path.join(config.output_dir, "top1_model")
model = AutoModelForSequenceClassification.from_pretrained(best_ckpt_dir)
model.to(config.device)
if config.num_gpus > 1:
    model = nn.DataParallel(model)


FINAL EVALUATION RESULTS - TOP 3 MODELS

--- Loading Top 1 Model (Trial 6) ---
Params: {'learning_rate': 7.498942093324561e-05, 'weight_decay': 0.08762500000000001, 'warmup_ratio': 0.025}
Validation Accuracy: 0.7475


Loading weights:   0%|          | 0/174 [00:00<?, ?it/s]


Evaluating on Test Set



Evaluating: 100%|██████████| 118/118 [00:49<00:00,  2.39it/s]



EVALUATION RESULTS

[Overall Metrics]
  Test Loss: 1.6971
  Accuracy: 0.7326 (73.26%)

[Macro Averages]
  Precision: 0.7326
  Recall: 0.7221
  F1 Score: 0.7246

[Weighted Averages]
  Precision: 0.7415
  Recall: 0.7326
  F1 Score: 0.7344

CLASSIFICATION REPORT (Per-Class)
                          precision    recall  f1-score   support

             alt.atheism     0.5089    0.5443    0.5260       158
           comp.graphics     0.7796    0.7323    0.7552       198
 comp.os.ms-windows.misc     0.6071    0.7286    0.6623       210
comp.sys.ibm.pc.hardware     0.7174    0.7293    0.7233       181
   comp.sys.mac.hardware     0.7753    0.7005    0.7360       197
          comp.windows.x     0.8426    0.8426    0.8426       197
            misc.forsale     0.8984    0.8358    0.8660       201
               rec.autos     0.5667    0.7391    0.6415       184
         rec.motorcycles     0.8176    0.7394    0.7765       188
      rec.sport.baseball     0.8539    0.8172    0.8352       186


Loading weights:   0%|          | 0/174 [00:00<?, ?it/s]


Evaluating on Test Set



Evaluating: 100%|██████████| 118/118 [00:48<00:00,  2.44it/s]


EVALUATION RESULTS

[Overall Metrics]
  Test Loss: 1.2554
  Accuracy: 0.7493 (74.93%)

[Macro Averages]
  Precision: 0.7424
  Recall: 0.7388
  F1 Score: 0.7385

[Weighted Averages]
  Precision: 0.7518
  Recall: 0.7493
  F1 Score: 0.7486

CLASSIFICATION REPORT (Per-Class)
                          precision    recall  f1-score   support

             alt.atheism     0.5621    0.5443    0.5531       158
           comp.graphics     0.7487    0.7374    0.7430       198
 comp.os.ms-windows.misc     0.7212    0.7143    0.7177       210
comp.sys.ibm.pc.hardware     0.7173    0.7569    0.7366       181
   comp.sys.mac.hardware     0.7709    0.7005    0.7340       197
          comp.windows.x     0.8750    0.8528    0.8638       197
            misc.forsale     0.8626    0.9055    0.8835       201
               rec.autos     0.5949    0.7663    0.6698       184
         rec.motorcycles     0.7933    0.7553    0.7738       188
      rec.sport.baseball     0.7608    0.8548    0.8051       186





Loading weights:   0%|          | 0/174 [00:00<?, ?it/s]


Evaluating on Test Set



Evaluating: 100%|██████████| 118/118 [00:48<00:00,  2.45it/s]



EVALUATION RESULTS

[Overall Metrics]
  Test Loss: 1.4491
  Accuracy: 0.7435 (74.35%)

[Macro Averages]
  Precision: 0.7407
  Recall: 0.7341
  F1 Score: 0.7353

[Weighted Averages]
  Precision: 0.7511
  Recall: 0.7435
  F1 Score: 0.7453

CLASSIFICATION REPORT (Per-Class)
                          precision    recall  f1-score   support

             alt.atheism     0.5862    0.5380    0.5611       158
           comp.graphics     0.7590    0.7475    0.7532       198
 comp.os.ms-windows.misc     0.7225    0.7190    0.7208       210
comp.sys.ibm.pc.hardware     0.6927    0.7348    0.7131       181
   comp.sys.mac.hardware     0.7527    0.6954    0.7230       197
          comp.windows.x     0.8807    0.7868    0.8311       197
            misc.forsale     0.8848    0.8408    0.8622       201
               rec.autos     0.5469    0.7609    0.6364       184
         rec.motorcycles     0.7512    0.8032    0.7763       188
      rec.sport.baseball     0.8201    0.8333    0.8267       186


Loading weights:   0%|          | 0/174 [00:00<?, ?it/s]

## 12. Out-of-Distribution (OOD) Detection — "Null / Other" Class

Strategy: Maximum Softmax Probability (MSP) with temperature scaling.

A sample is labelled "null/other" when:
    max( softmax( logits / T ) ) < tau

  T (temperature): reshapes prob mass before scoring.
                   T > 1 softens probs (often improves separation);
                   T = 1 is the vanilla MSP baseline.

  tau (threshold): directly controls the FP / FN trade-off:
      increase tau -> stricter  -> fewer OOD accepted (fewer FP),
                                   more ID rejected   (more  FN)
      decrease tau -> looser   -> fewer ID rejected   (fewer FN),
                                   more OOD accepted  (more  FP)

In-distribution (ID) : 20 Newsgroups test set (same split as Section 9)
Out-of-distribution  : AG News — 4-class news topics, completely different domain

In [10]:


import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve, average_precision_score
from scipy.special import softmax as scipy_softmax

## 12.1 Core helpers

In [11]:


@torch.no_grad()
def collect_logits_and_labels(mdl, loader, cfg, has_labels=True):
    """Run the model; return (logits ndarray shape (N,20), labels ndarray | None)."""
    mdl.eval()
    all_logits, all_labels = [], []
    for batch in tqdm(loader, desc="  collecting", leave=False):
        ids  = batch["input_ids"].to(cfg.device)
        mask = batch["attention_mask"].to(cfg.device)
        out  = mdl(input_ids=ids, attention_mask=mask)
        all_logits.append(out.logits.cpu().float().numpy())
        if has_labels and "label" in batch:
            all_labels.extend(batch["label"].cpu().numpy())
    logits = np.concatenate(all_logits, axis=0)
    labels = np.array(all_labels, dtype=int) if all_labels else None
    return logits, labels


def msp_scores(logits, temperature=1.0):
    """Maximum Softmax Probability: higher score -> more confident -> more likely ID."""
    probs = scipy_softmax(logits / max(temperature, 1e-6), axis=1)
    return probs.max(axis=1)

print("\n[1/3] Collecting in-distribution logits (20 Newsgroups test set) ...")
id_logits_ood, id_true_ood = collect_logits_and_labels(
    model, test_loader, config, has_labels=True
)


[1/3] Collecting in-distribution logits (20 Newsgroups test set) ...


                                                               

## 12.2 OOD data

In [12]:


def load_ood_data(cfg, n_samples=2000, seed=42):
    """Sample n_samples texts from AG News (test split) and return a DataLoader."""
    print("  Loading OOD dataset: AG News ...")
    raw = load_dataset("ag_news", split="test")
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(raw), size=min(n_samples, len(raw)), replace=False).tolist()
    raw = raw.select(idx)

    _tok = AutoTokenizer.from_pretrained(cfg.model_name)

    def _tokenize(examples):
        return _tok(
            examples["text"],
            truncation=True, padding="max_length",
            max_length=cfg.max_length, return_tensors=None,
        )

    tok_data = raw.map(_tokenize, batched=True, desc="  tokenising OOD",
                       remove_columns=raw.column_names)
    tok_data.set_format(type="torch", columns=["input_ids", "attention_mask"])

    ood_loader = DataLoader(
        tok_data,
        batch_size=cfg.total_batch_size,
        shuffle=False, num_workers=0,
        pin_memory=(cfg.device == "cuda"),
    )
    print(f"  OOD samples: {len(tok_data)} | batches: {len(ood_loader)}")
    return ood_loader

print("\n[2/3] Loading & collecting OOD logits (AG News) ...")
ood_loader_sec12 = load_ood_data(config, n_samples=2000)
ood_logits_ood, _ = collect_logits_and_labels(
    model, ood_loader_sec12, config, has_labels=False
)


[2/3] Loading & collecting OOD logits (AG News) ...
  Loading OOD dataset: AG News ...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

  tokenising OOD:   0%|          | 0/2000 [00:00<?, ? examples/s]

  OOD samples: 2000 | batches: 63


                                                             

## 12.3 Evaluation at one temperature

In [13]:


def evaluate_ood_detection(id_logits, id_true, ood_logits, temperature, save_dir):
    """
    Evaluate OOD detection performance using MSP at a given temperature T.

    Metrics reported
    ----------------
    AUROC      : area under ROC (ID=1, OOD=0)
    AP         : average precision
    FPR@TPR70  : false-pos rate when 70% of ID samples are accepted
    Table      : per-tau FPR / FNR / retained-ID-accuracy / % ID retained
    Plots      : ROC curve, score distributions, FP-FN trade-off curve
    """
    id_conf  = msp_scores(id_logits,  temperature)
    ood_conf = msp_scores(ood_logits, temperature)

    y_true  = np.concatenate([np.ones(len(id_conf)),  np.zeros(len(ood_conf))])
    y_score = np.concatenate([id_conf,                ood_conf])

    auroc = roc_auc_score(y_true, y_score)
    acc = max(accuracy_score(y_true, y_score >= tau) for tau in np.arange(0.1, 1.0, 0.05))

    fpr_arr, tpr_arr, _ = roc_curve(y_true, y_score)
    idx70 = np.searchsorted(tpr_arr, 0.70)
    fpr70 = float(fpr_arr[min(idx70, len(fpr_arr) - 1)])

    print(f"\n  T={temperature:.2f}  |  "
          f"AUROC={auroc:.4f}  Accuracy={acc:.4f}  FPR@TPR70={fpr70:.4f}")

    # Threshold table
    id_preds = id_logits.argmax(axis=1)
    print(f"\n  {'tau':>6} | {'FPR':>8} | {'FNR':>8} | "
          f"{'ID acc (retained)':>20} | {'% ID kept':>10}")
    print(f"  {'─' * 62}")
    for tau in np.arange(0.10, 1.0, 0.10):
        fpr_val = float((ood_conf >= tau).mean())
        fnr_val = float((id_conf  <  tau).mean())
        mask    = id_conf >= tau
        acc_val = (float((id_preds[mask] == id_true[mask]).mean())
                   if mask.sum() > 0 else float("nan"))
        pct     = float(mask.mean()) * 100
        print(f"  {tau:6.2f} | {fpr_val:8.4f} | {fnr_val:8.4f} | "
              f"{acc_val:>20.4f} | {pct:>9.1f}%")

    # Plots
    os.makedirs(save_dir, exist_ok=True)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle(
        f"OOD Detection (MSP) — Temperature T = {temperature:.2f}", fontsize=13
    )

    # 1. ROC curve
    ax = axes[0]
    ax.plot(fpr_arr, tpr_arr, lw=2, color="#4C72B0", label=f"AUROC = {auroc:.3f}")
    ax.plot([0, 1], [0, 1], "--", color="grey", lw=1)
    ax.axvline(fpr70, color="#DD8452", linestyle=":", lw=1.5,
               label=f"FPR@TPR70 = {fpr70:.3f}")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve")
    ax.legend(); ax.grid(alpha=0.3)

    # 2. Confidence score distributions
    ax = axes[1]
    ax.hist(ood_conf, bins=60, alpha=0.65, color="#DD8452",
            label="OOD (AG News)", density=True)
    ax.hist(id_conf,  bins=60, alpha=0.65, color="#4C72B0",
            label="ID (20 Newsgroups)", density=True)
    ax.set_xlabel("Max Softmax Probability")
    ax.set_ylabel("Density")
    ax.set_title("Confidence Score Distributions")
    ax.legend(); ax.grid(alpha=0.3)

    # 3. FP / FN trade-off vs threshold tau
    tau_range = np.linspace(0, 1, 300)
    fp_curve  = [(ood_conf >= t).mean() for t in tau_range]
    fn_curve  = [(id_conf  <  t).mean() for t in tau_range]
    ax = axes[2]
    ax.plot(tau_range, fp_curve, lw=2, color="#DD8452",
            label="FPR — OOD wrongly accepted")
    ax.plot(tau_range, fn_curve, lw=2, color="#4C72B0",
            label="FNR — ID wrongly rejected")
    ax.set_xlabel("Threshold tau")
    ax.set_ylabel("Error Rate")
    ax.set_title("FP / FN Trade-off vs Threshold")
    ax.legend(); ax.grid(alpha=0.3)

    fig.tight_layout()
    tag  = f"T{int(temperature * 100):04d}"
    path = os.path.join(save_dir, f"ood_detection_{tag}.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"  Plot saved -> {path}")

    return dict(auroc=auroc, acc=acc, fpr70=fpr70,
                id_conf=id_conf, ood_conf=ood_conf, temperature=temperature)

print("\n[3/3] Evaluating across temperatures ...")
temperatures_ood = [3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
all_ood_results  = {}

for T in temperatures_ood:
    print(f"\n{'─' * 58}")
    print(f"  Temperature T = {T}")
    print(f"{'─' * 58}")
    all_ood_results[T] = evaluate_ood_detection(
        id_logits_ood, id_true_ood, ood_logits_ood,
        temperature=T, save_dir=config.output_dir,
    )


[3/3] Evaluating across temperatures ...

──────────────────────────────────────────────────────────
  Temperature T = 3.0
──────────────────────────────────────────────────────────

  T=3.00  |  AUROC=0.7080  Accuracy=0.6870  FPR@TPR70=0.3835

     tau |      FPR |      FNR |    ID acc (retained) |  % ID kept
  ──────────────────────────────────────────────────────────────
    0.10 |   1.0000 |   0.0289 |               0.7525 |      97.1%
    0.20 |   0.9475 |   0.0475 |               0.7653 |      95.2%
    0.30 |   0.7855 |   0.0799 |               0.7833 |      92.0%
    0.40 |   0.6400 |   0.1394 |               0.8143 |      86.1%
    0.50 |   0.5235 |   0.2108 |               0.8486 |      78.9%
    0.60 |   0.4180 |   0.2746 |               0.8755 |      72.5%
    0.70 |   0.3295 |   0.3394 |               0.9007 |      66.1%
    0.80 |   0.2470 |   0.4206 |               0.9225 |      57.9%
    0.90 |   0.1750 |   0.5404 |               0.9555 |      46.0%
  Plot saved -> /ka

## 12.4 OOD Detection Summery

In [14]:

print(f"  {'T':>5} | {'AUROC':>7} | {'Accuracy':>8} | {'FPR@TPR70':>10}")
print(f"  {'─' * 42}")
best_T_ood = max(all_ood_results, key=lambda t: all_ood_results[t]["auroc"])
for T, r in sorted(all_ood_results.items()):
    flag = "  <- best AUROC" if T == best_T_ood else ""
    print(f"  {T:5.2f} | {r['auroc']:7.4f} | {r['acc']:8.4f} | {r['fpr70']:10.4f}{flag}")

      T |   AUROC | Accuracy |  FPR@TPR70
  ──────────────────────────────────────────
   3.00 |  0.7080 |   0.6870 |     0.3835
   4.00 |  0.7186 |   0.6934 |     0.3540
   5.00 |  0.7253 |   0.6996 |     0.3365
   6.00 |  0.7297 |   0.7038 |     0.3275
   7.00 |  0.7327 |   0.7050 |     0.3210
   8.00 |  0.7349 |   0.7040 |     0.3165  <- best AUROC
