
# ModernBERT Fine-tuning for 20 Newsgroups Text Classification

Optimized for **Kaggle T4 x2** (2× NVIDIA Tesla T4, 16GB VRAM each).
Uses `DataParallel` to leverage both GPUs and
T4-optimized settings (Tensor Cores, cuDNN, FP16).

## 1. Design Decisions (ModernBERT Full Fine-tuning + T4 x2)

| Parameter | Value | Justification |
|-----------|-------|---------------|
| **Model** | answerdotai/ModernBERT-base | 22-layer modernized BERT with RoPE, GeGLU, 8192 context |
| **Multi-GPU** | DataParallel (2× T4) | Simple multi-GPU, ~1.8× speedup |
| **Frozen Layers** | None (all 22 trainable) | Full fine-tuning for maximum accuracy |
| **Learning Rate (encoder)** | 2e-5 | Lower LR prevents catastrophic forgetting with all layers trainable |
| **Learning Rate (head)** | 1e-3 | 50× encoder LR; the head is randomly initialized and needs fast convergence |
| **Batch Size** | 32 (16 per GPU) | Good balance of gradient quality and memory |
| **Epochs** | 3 | Full fine-tuning converges faster; 3 epochs prevents overfitting |
| **Max Length** | 512 | Captures more newsgroup post context; ModernBERT handles efficiently |
| **Weight Decay** | 0.01 | Moderate regularization, not too aggressive for full fine-tuning |
| **Warmup** | 10% of steps | Sufficient warmup for all-layers training stability |
| **Optimizer** | AdamW | Weight-decoupled Adam for transformers |
| **Scheduler** | Cosine warmup + decay | Smooth convergence, better than linear for ModernBERT |
| **FP16** | Yes (Tensor Cores) | T4 has FP16 Tensor Cores (~2× throughput) |
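
The warmup and scheduler rows can be made concrete with a small sketch. Assuming the usual shape of a cosine schedule with warmup (a linear ramp to the base LR, then a half-cosine decay to zero), the LR multiplier at each step can be computed in plain Python. The function name `cosine_warmup_multiplier` and the 1000-step run are illustrative and not part of the notebook.

```python
import math


def cosine_warmup_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier in [0, 1]: linear warmup, then half-cosine decay to zero."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Hypothetical run: 1000 total steps, 10% warmup, and the table's encoder LR.
base_lr = 2e-5
for step in (0, 50, 100, 550, 1000):
    print(step, base_lr * cosine_warmup_multiplier(step, 100, 1000))
```

The same multiplier scales every parameter group's base LR, so the encoder and the head follow the same curve at different magnitudes.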

## 2. Imports

In [None]:


import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import random
import time
from typing import Tuple
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn

# T4 has CUDA compute capability 7.5 (Triton needs >= 7.0), so torch.compile is supported.
# Enable the cuDNN auto-tuner, which selects the fastest convolution algorithms on T4.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from torch.optim import AdamW

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoConfig,
    get_cosine_schedule_with_warmup,
    DataCollatorWithPadding
)
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix
)
from tqdm import tqdm

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs: {num_gpus}")
    for i in range(num_gpus):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)} ({torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB)")

## 3. Configuration

All hyperparameters are centralized here, optimized for Kaggle T4 x2.

In [None]:


"""
Configuration for ModernBERT Full Fine-tuning on 20 Newsgroups
Optimized for ModernBERT-base + Kaggle T4 x2 (2× NVIDIA Tesla T4, 16GB each)

Full Fine-tuning Strategy (Maximum Accuracy):
----------------------------------------------
1. ALL layers trainable (no freezing)
   - Full fine-tuning allows ModernBERT to fully adapt its representations
   - With ~12k examples, overfitting is managed via low LR + weight decay + few epochs
   
2. Discriminative Learning Rates: 2e-5 (encoder), 1e-3 (head)
   - Encoder LR of 2e-5 prevents catastrophic forgetting of pretrained knowledge
   - Head LR of 1e-3 (50×) allows the randomly initialized classifier to converge fast
   
3. Weight Decay: 0.01
   - Moderate regularization: prevents overfitting without constraining the model
   - Less aggressive than 0.1 which was over-regularizing with full fine-tuning
   
4. LR Schedule: Cosine with 10% warmup
   - Cosine annealing provides smooth convergence
   - 10% warmup stabilizes gradients when all layers are trainable
   
5. Batch Size: 32 total (16 per GPU)
   - Good gradient quality per step
   - Comfortably fits in T4 VRAM with FP16 (all layers need gradients)
   
6. Max Length: 512
   - Newsgroup posts average ~300 tokens; 512 captures most content
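   - ModernBERT handles longer sequences efficiently via alternating local/global attention
     (FlashAttention 2 needs Ampere or newer GPUs, so it is not expected to be used on T4)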
   
7. Epochs: 3
   - Full fine-tuning converges faster than partial fine-tuning
   - 3 epochs prevents overfitting (all ~149M params are updating)
   
T4 x2 Hardware Optimizations:
------------------------------
1. FP16 Mixed Precision (T4 Tensor Cores → ~2× throughput)
2. DataParallel across both T4s
3. num_workers=4 (Kaggle CPU cores), pin_memory=True
4. cuDNN auto-tune benchmark
"""

@dataclass
class Config:
    """Configuration class with all hyperparameters and settings."""
    
    # Model Configuration
    model_name: str = "answerdotai/ModernBERT-base"
    num_labels: int = 20
    
    # Layer Freezing Configuration
    num_layers_to_freeze: int = 0  # 0 = all layers trainable (full fine-tuning)
    
    # Training Hyperparameters: full fine-tuning for maximum accuracy
    learning_rate_encoder: float = 2e-5   # Low LR prevents catastrophic forgetting
    learning_rate_head: float = 1e-3      # 50× encoder LR for randomly initialized head
    batch_size: int = 32  # 32 total → 16 per GPU with DataParallel
    num_epochs: int = 3   # Full fine-tuning converges fast; 3 epochs prevents overfitting
    warmup_ratio: float = 0.1   # 10% warmup for stability with all layers trainable
    weight_decay: float = 0.01  # Moderate regularization for full fine-tuning
    label_smoothing: float = 0.1  # Prevents overconfident predictions on noisy labels
    max_grad_norm: float = 1.0
    
    # Data Configuration
    dataset_name: str = "SetFit/20_newsgroups"
    max_length: int = 512  # ModernBERT handles 512 efficiently; captures more post content
    num_workers: int = 4   # Kaggle has 4 CPU cores
    
    # Training Settings
    seed: int = 42
    use_fp16: bool = True   # T4 Tensor Cores excel at FP16
    use_multi_gpu: bool = True  # Enable DataParallel for 2× T4
    save_model: bool = True
    output_dir: str = "./output"
    
    # Device (auto-detected)
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
    
    def __post_init__(self):
        if self.device == "cpu":
            self.use_fp16 = False
            self.use_multi_gpu = False
        # Only use multi-GPU if more than 1 GPU is available
        if torch.cuda.is_available() and torch.cuda.device_count() < 2:
            self.use_multi_gpu = False
            
    def to_dict(self) -> dict:
        return {
            "model_name": self.model_name,
            "num_labels": self.num_labels,
            "num_layers_to_freeze": self.num_layers_to_freeze,
            "learning_rate_encoder": self.learning_rate_encoder,
            "learning_rate_head": self.learning_rate_head,
            "batch_size": self.batch_size,
            "num_epochs": self.num_epochs,
            "warmup_ratio": self.warmup_ratio,
            "weight_decay": self.weight_decay,
            "label_smoothing": self.label_smoothing,
            "max_length": self.max_length,
            "num_workers": self.num_workers,
            "seed": self.seed,
            "use_fp16": self.use_fp16,
            "use_multi_gpu": self.use_multi_gpu,
            "device": self.device,
            "num_gpus": torch.cuda.device_count() if torch.cuda.is_available() else 0,
        }

# Initialize configuration
config = Config()

print("Configuration:")
for key, value in config.to_dict().items():
    print(f"  {key}: {value}")
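
The batch size, epoch count, and warmup ratio above determine the schedule length used later by the trainer. The following sketch works through that arithmetic. It assumes 11,314 training examples (the size of the standard 20 Newsgroups train split; the notebook itself only says ~12k), and `schedule_steps` is a hypothetical helper name.

```python
import math


def schedule_steps(num_examples: int, batch_size: int, num_epochs: int,
                   warmup_ratio: float, num_gpus: int) -> dict:
    """Derive per-GPU batch size and step counts from the config values."""
    # The DataLoader keeps the last partial batch, hence ceil.
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    return {
        "per_gpu_batch": batch_size // max(1, num_gpus),
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": int(total_steps * warmup_ratio),
    }


# Assumed dataset size; the other values come from the Config defaults.
plan = schedule_steps(num_examples=11_314, batch_size=32, num_epochs=3,
                      warmup_ratio=0.1, num_gpus=2)
print(plan)
```

With only about a thousand optimizer steps in total, the 10% warmup covers roughly the first hundred steps.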

## 4. Data Loading

Load the 20 Newsgroups dataset from Hugging Face and tokenize it with the ModernBERT tokenizer.

In [None]:


"""
Data Loading and Preprocessing for 20 Newsgroups

T4 x2 Optimizations:
---------------------
1. num_workers=4: Kaggle provides 4 CPU cores, so use all of them for data loading
2. pin_memory=True: Enables fast CPU→GPU transfers via page-locked memory
3. prefetch_factor=2: Pre-load 2 batches per worker to keep GPUs fed
4. persistent_workers=True: Avoid worker respawn overhead between epochs
"""

def get_label_names() -> list:
    """Get the 20 newsgroup category names."""
    return [
        'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc',
        'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x',
        'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball',
        'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med',
        'sci.space', 'soc.religion.christian', 'talk.politics.guns',
        'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'
    ]


def load_and_prepare_data(config) -> Tuple[DataLoader, DataLoader, list]:
    """Load 20 newsgroups dataset, tokenize with DYNAMIC padding, create DataLoaders.
    
    Returns:
        train_loader, test_loader, label_names
    """
    print(f"Loading dataset: {config.dataset_name}")
    
    # Load raw dataset from HuggingFace
    dataset = load_dataset(config.dataset_name)
    
    train_dataset = dataset['train']
    test_dataset = dataset['test']
    
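    # Prefer label names from the dataset schema; fall back to the hardcoded list
    # when 'label' is a plain integer column rather than a ClassLabel feature.
    label_names = getattr(train_dataset.features['label'], 'names', None) or get_label_names()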
    
    print(f"Train size: {len(train_dataset)}")
    print(f"Test size: {len(test_dataset)}")
    print(f"Labels ({len(label_names)}): {label_names}")
    
    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    
    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            truncation=True,
            padding=False,  # DYNAMIC padding via DataCollatorWithPadding
            max_length=config.max_length,
            return_tensors=None
        )
    
    # Apply tokenization
    print("Tokenizing datasets...")
    train_dataset = train_dataset.map(tokenize_function, batched=True, desc="Tokenizing train")
    test_dataset = test_dataset.map(tokenize_function, batched=True, desc="Tokenizing test")
    
    # Keep only needed columns
    keep_columns = ['input_ids', 'attention_mask', 'label']
    remove_cols = [c for c in train_dataset.column_names if c not in keep_columns]
    if remove_cols:
        train_dataset = train_dataset.remove_columns(remove_cols)
        test_dataset = test_dataset.remove_columns(remove_cols)
    
    # Dynamic padding collator: pads each batch to its own max length
    # instead of padding ALL sequences to 512. This is the #1 accuracy fix.
    data_collator = DataCollatorWithPadding(tokenizer)
    
    # Create DataLoaders, optimized for T4 x2
    use_cuda = config.device == 'cuda'
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=use_cuda,
        prefetch_factor=2 if config.num_workers > 0 else None,
        persistent_workers=True if config.num_workers > 0 else False,
        collate_fn=data_collator,
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=use_cuda,
        prefetch_factor=2 if config.num_workers > 0 else None,
        persistent_workers=True if config.num_workers > 0 else False,
        collate_fn=data_collator,
    )
    
    print(f"Created DataLoaders - Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")
    print(f"  Batch size: {config.batch_size} (effective per-GPU: {config.batch_size // max(1, torch.cuda.device_count()) if use_cuda else config.batch_size})")
    print(f"  Workers: {config.num_workers}, pin_memory: {use_cuda}")
    print(f"  Padding: DYNAMIC (per-batch) via DataCollatorWithPadding")
    
    return train_loader, test_loader, label_names


# Load data
train_loader, test_loader, label_names = load_and_prepare_data(config)
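
To illustrate what dynamic padding saves compared with padding everything to `max_length`, here is a minimal pure-Python sketch. It is not the `DataCollatorWithPadding` implementation; `pad_batch` and the toy token ids are hypothetical, and sequences are assumed to be already truncated.

```python
from typing import List


def pad_batch(sequences: List[List[int]], pad_id: int, fixed_length: int = 0) -> List[List[int]]:
    """Right-pad token-id lists to the batch maximum, or to fixed_length if given."""
    target = fixed_length or max(len(s) for s in sequences)
    return [s + [pad_id] * (target - len(s)) for s in sequences]


# Toy batch of three already-tokenized posts (ids are made up).
batch = [[101, 7, 8, 102], [101, 9, 102], [101, 5, 6, 7, 8, 102]]
real_tokens = sum(len(s) for s in batch)

dynamic = pad_batch(batch, pad_id=0)                  # pad to longest in batch
fixed = pad_batch(batch, pad_id=0, fixed_length=512)  # pad to max_length

dynamic_pad = sum(len(row) for row in dynamic) - real_tokens
fixed_pad = sum(len(row) for row in fixed) - real_tokens
print(f"dynamic: width={len(dynamic[0])}, pad tokens={dynamic_pad}")
print(f"fixed:   width={len(fixed[0])}, pad tokens={fixed_pad}")
```

Every pad token still flows through the model's matrix multiplications, so short batches benefit most from per-batch padding.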

## 5. Model

Initialize the ModernBERT model with a classification head, apply optional layer freezing
(disabled by default, since `num_layers_to_freeze = 0`), and wrap it in DataParallel for multi-GPU training.

In [None]:


"""
ModernBERT Model for Text Classification with Layer Freezing + Multi-GPU

T4 x2 Optimizations:
---------------------
1. DataParallel wraps the model to split batches across both T4 GPUs
2. torch.compile (optional): T4 supports Triton for kernel fusion
3. Layer freezing reduces memory footprint → larger effective batch sizes
"""

def disable_modernbert_compiled_mlp(model: nn.Module) -> int:
    """Disable ModernBERT's internal compiled MLP to avoid Dynamo/FX conflicts.

    Some ModernBERT variants wire an attribute like `compiled_mlp` that is a
    Dynamo-optimized callable. If the overall model is later traced/compiled
    (directly or indirectly), PyTorch can raise:
    "Detected that you are using FX to symbolically trace a dynamo-optimized function".

    This function forces the eager MLP path when possible.
    """

    import inspect

    patched = 0

    for module in model.modules():
        if not hasattr(module, "compiled_mlp"):
            continue

        compiled = getattr(module, "compiled_mlp")

        # 1) If this is a torch.compile() OptimizedModule, unwrap to eager.
        eager = getattr(compiled, "_orig_mod", None)
        if eager is not None:
            try:
                setattr(module, "compiled_mlp", eager)
                patched += 1
                continue
            except Exception:
                pass

        # 2) Prefer swapping to the eager MLP module if it exists.
        if hasattr(module, "mlp") and callable(getattr(module, "mlp")):
            try:
                setattr(module, "compiled_mlp", getattr(module, "mlp"))
                patched += 1
                continue
            except Exception:
                pass

        # 3) If it's a wrapped callable, unwrap it.
        if callable(compiled):
            try:
                unwrapped = inspect.unwrap(compiled)
                if unwrapped is not compiled:
                    setattr(module, "compiled_mlp", unwrapped)
                    patched += 1
                    continue
            except Exception:
                pass

        # 4) Last resort: wrap callable to be Dynamo-disabled.
        try:
            import torch._dynamo  # type: ignore

            if callable(compiled):
                setattr(module, "compiled_mlp", torch._dynamo.disable(compiled))
                patched += 1
        except Exception:
            pass

    return patched

def get_model(config):
    """Initialize ModernBERT model with layer freezing and optional multi-GPU."""
    print(f"Loading model: {config.model_name}")
    print(f"Number of classes: {config.num_labels}")
    
    model_config = AutoConfig.from_pretrained(
        config.model_name,
        num_labels=config.num_labels,
        finetuning_task="text-classification"
    )
    
    model = AutoModelForSequenceClassification.from_pretrained(
        config.model_name,
        config=model_config
    )

    # Avoid Dynamo/FX tracing conflicts seen with some ModernBERT builds.
    patched = disable_modernbert_compiled_mlp(model)
    if patched:
        print(f"Disabled internal compiled MLP in {patched} module(s)")
    
    # =========================================================
    # Layer Freezing / Full Fine-tuning Strategy
    # =========================================================
    # ModernBERT-base architecture:
    #   - Embeddings layer
    #   - 22 transformer layers (model.model.layers[0..21])
    #   - Classification head (model.classifier or model.head)
    #
    # If num_layers_to_freeze == 0: ALL parameters are trainable
    # Otherwise: freeze embeddings + first N encoder layers
    # =========================================================
    
    num_total_layers = len(model.model.layers)
    num_to_freeze = config.num_layers_to_freeze
    num_to_train = num_total_layers - num_to_freeze
    
    if num_to_freeze == 0:
        # ── Full fine-tuning: all parameters trainable ──
        print(f"\nFull Fine-tuning Mode:")
        print(f"  All {num_total_layers} encoder layers + embeddings + head are trainable")
        for param in model.parameters():
            param.requires_grad = True
    else:
        # ── Partial fine-tuning: freeze early layers ──
        # Step 1: Freeze ALL parameters first
        for param in model.parameters():
            param.requires_grad = False
        
        # Step 2: Unfreeze the last few encoder layers
        print(f"\nLayer Freezing Configuration:")
        print(f"  Total encoder layers: {num_total_layers}")
        print(f"  Frozen layers: {num_to_freeze} (layers 0-{num_to_freeze - 1})")
        print(f"  Trainable layers: {num_to_train} (layers {num_to_freeze}-{num_total_layers - 1})")
        
        for i in range(num_to_freeze, num_total_layers):
            for param in model.model.layers[i].parameters():
                param.requires_grad = True
        
        # Step 3: Unfreeze the classification head
        head_unfrozen = False
        for name, param in model.named_parameters():
            if any(keyword in name for keyword in ['classifier', 'head', 'cls', 'score']):
                param.requires_grad = True
                head_unfrozen = True
        
        if not head_unfrozen:
            print("WARNING: Could not identify classification head parameters to unfreeze!")
            for name, param in model.named_parameters():
                if 'layers' not in name and 'embeddings' not in name:
                    param.requires_grad = True
        
        # Step 4: Also unfreeze the final layer norm if it exists
        for name, param in model.named_parameters():
            if 'final_norm' in name or 'norm' in name.split('.')[-1]:
                parts = name.split('.')
                is_in_frozen_layer = False
                for j, part in enumerate(parts):
                    if part == 'layers' and j + 1 < len(parts):
                        try:
                            layer_idx = int(parts[j + 1])
                            if layer_idx < num_to_freeze:
                                is_in_frozen_layer = True
                        except ValueError:
                            pass
                if not is_in_frozen_layer:
                    param.requires_grad = True
    
    # Print parameter statistics
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params
    
    print(f"\nParameter Statistics:")
    print(f"  Total parameters:     {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")
    print(f"  Frozen parameters:    {frozen_params:,} ({100*frozen_params/total_params:.1f}%)")
    
    # Print which parameter groups are trainable
    print(f"\nTrainable parameter groups:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"  ✓ {name} [{param.numel():,} params]")
    
    # =========================================================
    # Multi-GPU: Wrap with DataParallel for 2× T4
    # =========================================================
    if config.use_multi_gpu and torch.cuda.device_count() > 1:
        print(f"\n🚀 Wrapping model in DataParallel across {torch.cuda.device_count()} GPUs")
        model = nn.DataParallel(model)
    
    model.to(config.device)
    
    # Print GPU memory usage after loading
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            reserved = torch.cuda.memory_reserved(i) / 1024**3
            print(f"  GPU {i} memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved")
    
    return model


# Initialize model
model = get_model(config)
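
The partial fine-tuning branch of `get_model` decides trainability from parameter names. A simplified, name-only sketch of that rule is shown here. The parameter names in the example are illustrative assumptions about ModernBERT's naming, and `layer_index` / `is_trainable` are hypothetical helpers, not part of the notebook.

```python
from typing import Optional


def layer_index(param_name: str) -> Optional[int]:
    """Return the encoder layer index encoded in a parameter name, or None."""
    parts = param_name.split(".")
    for i, part in enumerate(parts[:-1]):
        if part == "layers" and parts[i + 1].isdigit():
            return int(parts[i + 1])
    return None


def is_trainable(param_name: str, num_to_freeze: int) -> bool:
    """Simplified freezing rule: embeddings and the first N layers are frozen."""
    if num_to_freeze == 0:
        return True  # full fine-tuning
    idx = layer_index(param_name)
    if idx is not None:
        return idx >= num_to_freeze
    # Non-layer parameters: embeddings stay frozen; head and final norm train.
    return "embeddings" not in param_name


# Illustrative names only; real names depend on the model implementation.
names = [
    "model.embeddings.tok_embeddings.weight",
    "model.layers.3.mlp.Wi.weight",
    "model.layers.20.attn.Wo.weight",
    "model.final_norm.weight",
    "classifier.weight",
]
for name in names:
    print(f"{name}: trainable={is_trainable(name, num_to_freeze=16)}")
```

With `num_to_freeze=0`, as in this notebook's default config, every name is reported as trainable.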

## 6. Trainer

Training loop with discriminative learning rates, AdamW optimizer,
cosine LR schedule, FP16 with T4 Tensor Cores, and multi-GPU support.
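
As a rough illustration of discriminative learning rates, the sketch below splits named parameters into encoder and head groups using the same keyword test as the trainer and returns them in the param-group dictionary format that optimizers such as AdamW accept. Strings stand in for tensors so the example runs without PyTorch; the parameter names and `build_param_groups` are assumptions for illustration.

```python
from typing import Any, Dict, Iterable, List, Tuple

HEAD_KEYWORDS = ("classifier", "head", "cls", "score")


def build_param_groups(named_params: Iterable[Tuple[str, Any]],
                       encoder_lr: float, head_lr: float,
                       weight_decay: float) -> List[Dict[str, Any]]:
    """Split (name, param) pairs into encoder and head optimizer groups."""
    encoder, head = [], []
    for name, param in named_params:
        is_head = any(keyword in name for keyword in HEAD_KEYWORDS)
        (head if is_head else encoder).append(param)
    return [
        {"params": encoder, "lr": encoder_lr, "weight_decay": weight_decay},
        {"params": head, "lr": head_lr, "weight_decay": weight_decay},
    ]


# Placeholder parameters: names are illustrative, values stand in for tensors.
named = [
    ("model.embeddings.tok_embeddings.weight", "p0"),
    ("model.layers.0.attn.Wqkv.weight", "p1"),
    ("head.dense.weight", "p2"),
    ("classifier.weight", "p3"),
    ("classifier.bias", "p4"),
]
groups = build_param_groups(named, encoder_lr=2e-5, head_lr=1e-3, weight_decay=0.01)
for group in groups:
    print(group["lr"], group["params"])
```

Substring matching is convenient but fragile: any encoder parameter whose name happened to contain one of the keywords would silently receive the head learning rate, so printing the group sizes, as the trainer does, is a useful sanity check.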

In [None]:


"""
Training Module for ModernBERT Fine-tuning on T4 x2

T4 x2 Optimizations:
---------------------
1. FP16 with GradScaler: T4 Tensor Cores give ~2× throughput
2. DataParallel handles batch splitting across GPUs automatically
3. Loss is computed once on the gathered logits (DataParallel gathers outputs on GPU 0)
4. Gradient accumulation is not implemented here; add it if you want larger effective batch sizes
5. Wall-clock timing via time.time() per epoch
"""

class Trainer:
    """Trainer class for ModernBERT fine-tuning with multi-GPU and FP16."""
    
    def __init__(self, model, config, train_loader):
        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.device = config.device
        self.is_parallel = isinstance(model, nn.DataParallel)
        
        # Get the underlying model for parameter grouping
        base_model = model.module if self.is_parallel else model
        
        # Create parameter groups with discriminative learning rates
        encoder_params = []
        head_params = []
        
        for name, param in base_model.named_parameters():
            if not param.requires_grad:
                continue
            if any(keyword in name for keyword in ['classifier', 'head', 'cls', 'score']):
                head_params.append(param)
            else:
                encoder_params.append(param)
        
        self.optimizer = AdamW([
            {
                'params': encoder_params,
                'lr': config.learning_rate_encoder,
                'weight_decay': config.weight_decay
            },
            {
                'params': head_params,
                'lr': config.learning_rate_head,
                'weight_decay': config.weight_decay  # Regularize the head too
            }
        ])
        
        self.total_steps = len(train_loader) * config.num_epochs
        self.warmup_steps = int(self.total_steps * config.warmup_ratio)
        
        # Label smoothing loss: prevents overconfident predictions on noisy labels
        self.criterion = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
        
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.total_steps
        )
        
        self.scaler = GradScaler('cuda') if config.use_fp16 else None
        self.use_fp16 = config.use_fp16
        
        self.history = {'train_loss': [], 'train_accuracy': [], 'learning_rate': []}
        
        # Only count multiple GPUs when the model is actually wrapped in DataParallel
        num_gpus = torch.cuda.device_count() if self.is_parallel else 1
        print(f"\nTraining Configuration:")
        print(f"  Device: {self.device} ({'DataParallel on ' + str(num_gpus) + ' GPUs' if self.is_parallel else 'single GPU'})")
        print(f"  Total batch size: {config.batch_size}")
        print(f"  Per-GPU batch size: {config.batch_size // num_gpus}")
        print(f"  Total steps: {self.total_steps}")
        print(f"  Warmup steps: {self.warmup_steps}")
        print(f"  Encoder LR: {config.learning_rate_encoder}")
        print(f"  Head LR: {config.learning_rate_head}")
        print(f"  Encoder params: {sum(p.numel() for p in encoder_params):,}")
        print(f"  Head params: {sum(p.numel() for p in head_params):,}")
        print(f"  Mixed precision (FP16): {self.use_fp16}")
    
    def train_epoch(self, epoch):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0
        num_batches = 0
        correct = 0
        total = 0

        # Some ModernBERT builds create/overwrite `compiled_mlp` lazily.
        # Re-disable it once per epoch to avoid nested Dynamo/FX issues.
        base_model = self.model.module if self.is_parallel else self.model
        try:
            disable_modernbert_compiled_mlp(base_model)
        except Exception:
            pass
        
        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {epoch+1}/{self.config.num_epochs}",
            leave=True
        )
        
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(self.device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(self.device, non_blocking=True)
            labels = batch['labels'].to(self.device, non_blocking=True)
            
            self.optimizer.zero_grad(set_to_none=True)  # More memory efficient than zero_grad()
            
            if self.use_fp16:
                with autocast('cuda'):
                    outputs = self.model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                    )
                    # Compute loss with label smoothing (cleaner than model's built-in)
                    loss = self.criterion(outputs.logits, labels)

                preds = torch.argmax(outputs.logits, dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.numel()
                
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                # Compute loss with label smoothing
                loss = self.criterion(outputs.logits, labels)

                preds = torch.argmax(outputs.logits, dim=-1)
                correct += (preds == labels).sum().item()
                total += labels.numel()
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
                self.optimizer.step()
            
            self.scheduler.step()
            
            total_loss += loss.item()
            num_batches += 1
            running_acc = (correct / total) if total else 0.0
            
            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{running_acc:.4f}',
                'lr': f'{self.scheduler.get_last_lr()[0]:.2e}'
            })

        epoch_loss = total_loss / num_batches
        epoch_acc = (correct / total) if total else 0.0
        return epoch_loss, epoch_acc
    
    def train(self):
        """Full training loop."""
        print("\n" + "="*60)
        print("Starting Training")
        if self.is_parallel:
            print(f"  Using {torch.cuda.device_count()} GPUs via DataParallel")
        print("="*60 + "\n")
        
        start_time = time.time()
        
        for epoch in range(self.config.num_epochs):
            epoch_start = time.time()
            
            train_loss, train_acc = self.train_epoch(epoch)
            self.history['train_loss'].append(train_loss)
            self.history['train_accuracy'].append(train_acc)
            
            current_lr = self.scheduler.get_last_lr()[0]
            self.history['learning_rate'].append(current_lr)
            
            epoch_time = time.time() - epoch_start
            
            # Print GPU memory stats per epoch
            mem_info = ""
            if torch.cuda.is_available():
                for i in range(torch.cuda.device_count()):
                    allocated = torch.cuda.memory_allocated(i) / 1024**3
                    peak = torch.cuda.max_memory_allocated(i) / 1024**3
                    mem_info += f" | GPU{i}: {allocated:.1f}GB (peak {peak:.1f}GB)"
            
            print(f"\nEpoch {epoch+1}/{self.config.num_epochs} - "
                f"Train Loss: {train_loss:.4f} - "
                f"Train Acc: {train_acc:.4f} ({train_acc*100:.2f}%) - "
                f"LR: {current_lr:.2e} - "
                f"Time: {epoch_time:.1f}s{mem_info}")
        
        total_time = time.time() - start_time
        print(f"\nTraining Complete! Total time: {total_time/60:.1f} minutes")
        
        # Final memory summary
        if torch.cuda.is_available():
            print("\nGPU Memory Summary:")
            for i in range(torch.cuda.device_count()):
                peak = torch.cuda.max_memory_allocated(i) / 1024**3
                total = torch.cuda.get_device_properties(i).total_memory / 1024**3
                print(f"  GPU {i}: Peak {peak:.2f} GB / {total:.1f} GB ({100*peak/total:.0f}% utilization)")
        
        return self.history
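
The FP16 path in `train_epoch` relies on `GradScaler` to keep small gradients from underflowing. The following is a deliberately simplified pure-Python model of dynamic loss scaling, not the real `GradScaler`. The class name `ToyLossScaler` is hypothetical, and the default values are intended to mirror the documented `GradScaler` defaults.

```python
class ToyLossScaler:
    """Simplified model of dynamic loss scaling (illustration only)."""

    def __init__(self, init_scale: float = 65536.0, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale; return True if the optimizer step should be applied."""
        if found_inf:
            # Overflow: skip this step and shrink the scale.
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            # A long run of clean steps: try a larger scale again.
            self.scale *= self.growth_factor
            self.good_steps = 0
        return True


# Short growth interval so the behaviour is visible in a few steps.
scaler = ToyLossScaler(growth_interval=3)
overflow_pattern = [False, True, False, False, False]
applied = [scaler.update(found_inf) for found_inf in overflow_pattern]
print(applied, scaler.scale)
```

Because overflowing steps are skipped, the number of optimizer updates can be slightly smaller than the number of scheduler steps in the training loop above.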

## 7. Train the Model

In [None]:


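# Set random seed for reproducibility.
# Note: the model (including its randomly initialized head) was created in section 5,
# before this call, so seeding here only affects randomness from this point on (e.g. shuffling).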
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(config.seed)

# Initialize trainer and train
trainer = Trainer(model, config, train_loader)
history = trainer.train()

## 8. Evaluation

Evaluate the trained model on the test set.

**Important:** We unwrap DataParallel and run inference on a single GPU.
DataParallel.replicate() creates fresh model copies per forward pass,
re-introducing ModernBERT's compiled_mlp (torch.compile) attributes
which conflict with FX tracing. Single-GPU eval avoids this entirely
and is standard practice (DataParallel mainly benefits training).
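
The evaluation cell below reports both macro and weighted averages. As a hedged illustration of how the two differ, this pure-Python sketch computes per-class F1 and both averages on a tiny made-up example. It is a simplification that only considers classes present in `y_true`, and it is not the scikit-learn implementation; the helper names are hypothetical.

```python
from collections import Counter
from typing import Dict, List, Tuple


def per_class_f1(y_true: List[int], y_pred: List[int]) -> Dict[int, Tuple[float, int]]:
    """Map each class in y_true to (F1, support)."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1
            fn[t] += 1
    stats = {}
    for c in sorted(set(y_true)):
        denom = 2 * tp[c] + fp[c] + fn[c]
        f1 = 2 * tp[c] / denom if denom else 0.0
        stats[c] = (f1, tp[c] + fn[c])
    return stats


def macro_and_weighted_f1(y_true: List[int], y_pred: List[int]) -> Tuple[float, float]:
    """Macro: unweighted mean over classes. Weighted: mean weighted by support."""
    stats = per_class_f1(y_true, y_pred)
    macro = sum(f1 for f1, _ in stats.values()) / len(stats)
    total = sum(n for _, n in stats.values())
    weighted = sum(f1 * n for f1, n in stats.values()) / total
    return macro, weighted


# Made-up, imbalanced toy labels: class 0 has 4 examples, class 1 has 2.
y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 0]
macro_f1, weighted_f1 = macro_and_weighted_f1(y_true, y_pred)
print(f"macro F1={macro_f1:.4f}, weighted F1={weighted_f1:.4f}")
```

Macro F1 treats every newsgroup equally, so it drops more sharply than the weighted score when a small class is classified poorly.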

In [None]:


"""
Evaluation Module for ModernBERT Fine-tuning

Metrics:
- Accuracy: Overall correctness
- Macro F1: Balanced performance across all classes
- Per-class metrics: Identify weak categories
- Confusion matrix: Reveals class confusion patterns

Note: Evaluation runs on a single GPU (unwrapped from DataParallel)
to avoid Dynamo/FX tracing conflicts with ModernBERT's compiled MLPs.
"""

@torch.no_grad()
def evaluate(model, test_loader, config):
    """Comprehensive evaluation on test set (single-GPU, Dynamo-safe)."""
    print("\n" + "="*60)
    print("Evaluating on Test Set")
    print("="*60 + "\n")
    
    # ─── Step 1: Unwrap DataParallel ───────────────────────────────
    # DataParallel.replicate() deep-copies the model per GPU on every
    # forward pass, re-creating compiled_mlp attrs we tried to remove.
    # Solution: just use the base model on GPU 0 for inference.
    is_parallel = isinstance(model, nn.DataParallel)
    eval_model = model.module if is_parallel else model
    
    if is_parallel:
        print("  Unwrapped DataParallel → evaluating on single GPU")
    
    # ─── Step 2: Fully disable torch.compile / Dynamo ─────────────
    # Reset Dynamo state and disable compiled MLPs on the base model
    # to prevent any FX-vs-Dynamo conflicts during inference.
    try:
        import torch._dynamo
        torch._dynamo.reset()
        # Suppress any remaining Dynamo errors as fallback
        torch._dynamo.config.suppress_errors = True
    except Exception:
        pass
    
    patched = disable_modernbert_compiled_mlp(eval_model)
    if patched:
        print(f"  Disabled compiled MLP in {patched} module(s)")
    
    # ─── Step 3: Set eval mode and device ─────────────────────────
    eval_model.eval()
    eval_model.to(config.device)
    
    all_predictions = []
    all_labels = []
    total_loss = 0
    
    # ─── Step 4: Inference loop ───────────────────────────────────
    for batch in tqdm(test_loader, desc="Evaluating"):
        input_ids = batch['input_ids'].to(config.device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(config.device, non_blocking=True)
        labels = batch['labels'].to(config.device, non_blocking=True)
        
        with autocast('cuda', enabled=config.use_fp16):
            outputs = eval_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
        
        total_loss += outputs.loss.item()
        predictions = torch.argmax(outputs.logits, dim=-1)
        
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    
    # ─── Step 5: Calculate metrics ────────────────────────────────
    accuracy = accuracy_score(all_labels, all_predictions)
    # zero_division=0 scores a class with no predicted samples as 0 without emitting a warning.
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )
    # Mean of the per-batch mean losses; a smaller final batch counts the same as a full one.
    avg_loss = total_loss / len(test_loader)
    
    report = classification_report(all_labels, all_predictions, target_names=label_names, digits=4)
    conf_matrix = confusion_matrix(all_labels, all_predictions)
    
    # ─── Step 6: Print results ────────────────────────────────────
    print("\n" + "="*60)
    print("EVALUATION RESULTS")
    print("="*60)
    print("\n[Overall Metrics]")
    print(f"  Test Loss: {avg_loss:.4f}")
    print(f"  Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print("\n[Macro Averages]")
    print(f"  Precision: {precision_macro:.4f}")
    print(f"  Recall: {recall_macro:.4f}")
    print(f"  F1 Score: {f1_macro:.4f}")
    print("\n[Weighted Averages]")
    print(f"  Precision: {precision_weighted:.4f}")
    print(f"  Recall: {recall_weighted:.4f}")
    print(f"  F1 Score: {f1_weighted:.4f}")
    print("\n" + "="*60)
    print("CLASSIFICATION REPORT (Per-Class)")
    print("="*60)
    print(report)
    
    return {
        'test_loss': avg_loss,
        'accuracy': accuracy,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'classification_report': report,
        'confusion_matrix': conf_matrix,
        'predictions': all_predictions,
        'labels': all_labels,
        'label_names': label_names
    }


# Evaluate
results = evaluate(model, test_loader, config)
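
`evaluate` returns the confusion matrix but nothing above inspects it. As a minimal sketch, the helper below lists the largest off-diagonal cells, which are the class pairs the model confuses most often. The name `top_confusions` and the toy 3-class matrix are hypothetical and used only for illustration. The matrix is assumed to follow scikit-learn's layout (rows are true labels, columns are predictions).

```python
import numpy as np


def top_confusions(conf_matrix, label_names, k=5):
    """Return the k largest off-diagonal cells as (true, predicted, count) tuples."""
    cm = np.asarray(conf_matrix).copy()
    np.fill_diagonal(cm, 0)  # ignore correct predictions on the diagonal
    flat_order = np.argsort(cm, axis=None)[::-1][:k]  # largest counts first
    rows, cols = np.unravel_index(flat_order, cm.shape)
    return [
        (label_names[r], label_names[c], int(cm[r, c]))
        for r, c in zip(rows, cols)
        if cm[r, c] > 0  # skip empty cells when fewer than k errors exist
    ]


# Toy 3-class example (made-up counts); rows are true labels, columns are predictions.
toy_cm = np.array([[50, 4, 1],
                   [9, 40, 0],
                   [2, 3, 45]])
toy_names = ["alt.atheism", "sci.space", "rec.autos"]

for true_name, pred_name, count in top_confusions(toy_cm, toy_names, k=3):
    print(f"{true_name} -> {pred_name}: {count}")
```

On the real run, the equivalent call would be `top_confusions(results['confusion_matrix'], results['label_names'])`.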

## 9. Final Summary

In [None]:


print("\n" + "="*70)
print("FINAL RESULTS SUMMARY")
print("="*70)
print(f"  Model: {config.model_name}")
print(f"  GPUs: {torch.cuda.device_count() if torch.cuda.is_available() else 0}× {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
print(f"  Frozen layers: {config.num_layers_to_freeze} / 22")
print(f"  Trainable layers: {22 - config.num_layers_to_freeze} + classification head")
print(f"  Batch size: {config.batch_size} total")
print(f"  Test Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)")
print(f"  Macro F1 Score: {results['f1_macro']:.4f}")
print(f"  Weighted F1 Score: {results['f1_weighted']:.4f}")
print("="*70)

## 10. Save Model (Optional)

In [None]:


if config.save_model:
    os.makedirs(config.output_dir, exist_ok=True)
    
    # Save model (unwrap DataParallel if needed)
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    model_path = os.path.join(config.output_dir, "model")
    base_model.save_pretrained(model_path)
    print(f"Model saved to: {model_path}")
    
    # Save tokenizer for easy reloading
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.save_pretrained(model_path)
    print(f"Tokenizer saved to: {model_path}")
    
    # Save training history
    history_path = os.path.join(config.output_dir, "training_history.json")
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=2)
    print(f"Training history saved to: {history_path}")
    
    # Save evaluation metrics
    metrics = {
        'model_name': config.model_name,
        'num_layers_frozen': config.num_layers_to_freeze,
        'num_gpus': torch.cuda.device_count() if torch.cuda.is_available() else 0,
        'batch_size': config.batch_size,
        'test_loss': results['test_loss'],
        'accuracy': results['accuracy'],
        'precision_macro': results['precision_macro'],
        'recall_macro': results['recall_macro'],
        'f1_macro': results['f1_macro'],
        'precision_weighted': results['precision_weighted'],
        'recall_weighted': results['recall_weighted'],
        'f1_weighted': results['f1_weighted'],
    }
    metrics_path = os.path.join(config.output_dir, "evaluation_metrics.json")
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=2)
    print(f"Evaluation metrics saved to: {metrics_path}")
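
Once the model and tokenizer are saved to the same directory, they can be reloaded for inference. The sketch below is one possible way to do that, not part of the original notebook. It assumes the checkpoint was saved from a Hugging Face sequence-classification model, so that `AutoModelForSequenceClassification` can load it. The helper names `load_for_inference`, `classify`, and `top_k_labels` are hypothetical, and the caller supplies the list of class names.

```python
import numpy as np


def top_k_labels(logits, label_names, k=3):
    """Return the k most probable (label, probability) pairs for one row of logits."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max()  # subtract the max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()
    order = np.argsort(probs)[::-1][:k]  # indices by descending probability
    return [(label_names[i], float(probs[i])) for i in order]


def load_for_inference(model_dir, device="cpu"):
    """Reload the saved model and tokenizer (assumes a sequence-classification checkpoint)."""
    # Imported lazily so that defining this helper does not require transformers.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.to(device).eval()
    return model, tokenizer


def classify(texts, model, tokenizer, label_names, max_length=512, k=3):
    """Tokenize a list of texts, run the model, and decode the top-k labels per text."""
    import torch

    device = next(model.parameters()).device
    enc = tokenizer(texts, padding=True, truncation=True,
                    max_length=max_length, return_tensors="pt")
    enc = {name: tensor.to(device) for name, tensor in enc.items()}
    with torch.no_grad():
        logits = model(**enc).logits.float().cpu().numpy()
    return [top_k_labels(row, label_names, k) for row in logits]
```

A typical call would be `model, tokenizer = load_for_inference(model_path)` followed by `classify(["some post text"], model, tokenizer, label_names)`, where `model_path` and `label_names` are the values used earlier in the notebook.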