#!/usr/bin/env python
# coding: utf-8

# ModernBERT-Large Fine-tuning for 20 Newsgroups Text Classification

An implementation of ModernBERT-large fine-tuning for multi-class text classification
using the 20 newsgroups dataset, optimized for Kaggle T4 x2 GPUs (2 × 16 GB VRAM).

## 1. Design Decisions

| Parameter | Value | Justification |
|-----------|-------|---------------|
| **Model** | answerdotai/ModernBERT-large | 395M params, 28 layers, RoPE, 8192 context |
| **Learning Rate** | 3e-5 | Lower than base for stability with large model |
| **Batch Size** | 8 per GPU × 2 GPUs = 16 | Sized to fit T4 16GB with FP16 at seq len 512 |
| **Gradient Accum** | 4 | Effective batch size 64 |
| **Epochs** | 4 | Large model benefits from more epochs |
| **Max Length** | 512 | Covers most newsgroup posts in full (see token coverage in Section 4) |
| **Multi-GPU** | DataParallel on 2 T4s | Splits each batch across both GPUs (speedup is sub-linear) |
| **Layer Freezing** | Bottom 50% (14/28 layers) | Saves memory on T4 |

## 2. Imports & Environment Setup

In [None]:


import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Prevent fork warnings with DataLoader

import json
import random
import time
from typing import Tuple
from dataclasses import dataclass, field
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from torch.optim import AdamW

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoConfig,
    get_linear_schedule_with_warmup
)
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix
)
from tqdm import tqdm

# Report the runtime environment so GPU allocation problems surface early.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs: {num_gpus}")
    # One line per visible device: its name and total memory in GiB.
    for i, gpu_props in enumerate(map(torch.cuda.get_device_properties, range(num_gpus))):
        gib = gpu_props.total_memory / 1024**3
        print(f"  GPU {i}: {gpu_props.name} ({gib:.1f} GB)")
    # Let cuDNN autotune kernels; input shapes are fixed by max_length padding.
    torch.backends.cudnn.benchmark = True

## 3. Configuration

In [None]:


"""
Configuration for ModernBERT-Large Fine-tuning on 20 Newsgroups
(Optimized for Kaggle T4 x2 — 2 × 16 GB VRAM)

Design Decisions:
-----------------
1. Model: answerdotai/ModernBERT-large
   - 395M parameters, 28 encoder layers
   - Modern bidirectional encoder pre-trained on 2 trillion tokens
   - Uses RoPE and alternating local-global attention
   
2. Learning Rate: 3e-5
   - Slightly lower than base model for training stability
   - Larger models are more sensitive to high LR
   
3. Batch Size: 8 per GPU (16 total across 2 T4s)
   - With gradient accumulation of 4, effective batch = 64
   - Sized to fit T4 16GB with FP16 + layer freezing at max_length 512
   
4. Epochs: 4
   - Large model with frozen layers benefits from more training

5. Layer Freezing: Bottom 50% (14/28 encoder layers)
   - Critical for fitting 395M model on T4 GPUs
   - Lower layers capture universal language features

6. Multi-GPU: DataParallel
   - Automatically splits batches across 2 T4 GPUs
   - Simple, no code changes needed for the training loop
"""

@dataclass
class Config:
    """All hyperparameters and runtime settings for the fine-tuning run.

    Besides the declared fields, ``__post_init__`` derives:

    * ``total_batch_size`` -- samples per forward pass across all GPUs
      (``batch_size * max(1, num_gpus)``).
    * ``effective_batch_size`` -- samples per optimizer step
      (``total_batch_size * gradient_accumulation_steps``).

    When ``device`` is not a CUDA device, FP16 is disabled and ``num_gpus``
    is forced to 0 so no multi-GPU batch scaling or DataParallel is applied.

    Raises:
        ValueError: if ``batch_size`` or ``gradient_accumulation_steps`` < 1.
    """

    # Model Configuration
    model_name: str = "answerdotai/ModernBERT-large"
    num_labels: int = 20

    # Training Hyperparameters
    learning_rate: float = 3e-5
    batch_size: int = 8  # Per-GPU batch size (total = batch_size × num_gpus)
    num_epochs: int = 4
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    gradient_accumulation_steps: int = 4  # Effective batch = 8 × 2 GPUs × 4 = 64

    # Layer Freezing
    freeze_layers: bool = True
    freeze_ratio: float = 0.5  # Freeze bottom 50% of encoder layers (14/28)

    # Data Configuration
    dataset_name: str = "SetFit/20_newsgroups"
    max_length: int = 512  # 91% token coverage (vs 76% at 256)

    # Training Settings
    seed: int = 42
    use_fp16: bool = True
    save_model: bool = True
    output_dir: str = "./output"

    # Device (auto-detected)
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
    num_gpus: int = field(default_factory=lambda: torch.cuda.device_count() if torch.cuda.is_available() else 0)

    def __post_init__(self):
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.gradient_accumulation_steps < 1:
            raise ValueError(
                f"gradient_accumulation_steps must be >= 1, got {self.gradient_accumulation_steps}"
            )
        if not str(self.device).startswith("cuda"):
            # CUDA autocast and DataParallel do not apply off-GPU, even when GPUs
            # exist but another device was requested explicitly.
            self.use_fp16 = False
            self.num_gpus = 0
        # Scale batch size across GPUs
        self.total_batch_size = self.batch_size * max(1, self.num_gpus)
        self.effective_batch_size = self.total_batch_size * self.gradient_accumulation_steps

    def to_dict(self) -> dict:
        """Return the settings (including derived batch sizes) for logging."""
        return {
            "model_name": self.model_name,
            "num_labels": self.num_labels,
            "dataset_name": self.dataset_name,
            "learning_rate": self.learning_rate,
            "batch_size_per_gpu": self.batch_size,
            "num_gpus": self.num_gpus,
            "total_batch_size": self.total_batch_size,
            "effective_batch_size": self.effective_batch_size,
            "num_epochs": self.num_epochs,
            "warmup_ratio": self.warmup_ratio,
            "weight_decay": self.weight_decay,
            "max_grad_norm": self.max_grad_norm,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
            "max_length": self.max_length,
            "freeze_layers": self.freeze_layers,
            "freeze_ratio": self.freeze_ratio,
            "seed": self.seed,
            "use_fp16": self.use_fp16,
            "device": self.device,
        }

# Build the run configuration from its defaults and echo every setting.
config = Config()
print("Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in config.to_dict().items()))

## 4. Dataset Exploration & Statistical Overview

View the dataset structure, class distribution, and text length statistics
before training.

In [None]:


"""
Dataset Exploration for 20 Newsgroups

Shows:
- Dataset structure and splits
- Class distribution (train & test)
- Text length statistics (chars, words, tokens)
- Sample documents from the first five classes
"""

def explore_dataset(config):
    """Load the dataset and print a statistical overview before training.

    Prints split sizes, per-class train/test counts with class-balance
    metrics, character and word length statistics for the training texts,
    tokenized-length statistics on a reproducible random sample of training
    documents, and a 200-character preview from each of the first five classes.

    Args:
        config: ``Config`` providing ``dataset_name``, ``model_name``,
            ``max_length`` and ``seed``.

    Returns:
        The object returned by ``load_dataset(config.dataset_name)``, so the
        data-loading step can reuse it without downloading again.
    """
    print("\n" + "="*70)
    print("DATASET EXPLORATION: 20 Newsgroups")
    print("="*70)

    # Load raw dataset
    dataset = load_dataset(config.dataset_name)
    train_data = dataset['train']
    test_data = dataset['test']

    label_names = [
        'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc',
        'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x',
        'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball',
        'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med',
        'sci.space', 'soc.religion.christian', 'talk.politics.guns',
        'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'
    ]

    # Materialize the columns once as plain lists. Newer `datasets` versions
    # may return a lazy column object, which `random.sample` may not accept.
    texts = list(train_data['text'])
    train_labels = list(train_data['label'])
    test_labels = list(test_data['label'])

    # --- Basic Info ---
    print(f"\n{'─'*50}")
    print(f"  Dataset: {config.dataset_name}")
    print(f"  Number of classes: {len(label_names)}")
    print(f"  Train samples: {len(train_data):,}")
    print(f"  Test samples:  {len(test_data):,}")
    print(f"  Total samples: {len(train_data) + len(test_data):,}")
    print(f"  Features: {list(train_data.features.keys())}")
    print(f"{'─'*50}")

    # --- Class Distribution ---
    print(f"\n{'─'*50}")
    print("  CLASS DISTRIBUTION")
    print(f"{'─'*50}")

    train_counts = Counter(train_labels)
    test_counts = Counter(test_labels)

    print(f"\n  {'Category':<35} {'Train':>6} {'Test':>6} {'Total':>6}")
    print(f"  {'─'*55}")
    for i, name in enumerate(label_names):
        tr = train_counts.get(i, 0)
        te = test_counts.get(i, 0)
        bar = '█' * (tr // 20)
        print(f"  {name:<35} {tr:>6} {te:>6} {tr+te:>6}  {bar}")

    print(f"  {'─'*55}")
    print(f"  {'TOTAL':<35} {len(train_data):>6} {len(test_data):>6} {len(train_data)+len(test_data):>6}")

    # Class balance metrics
    train_counts_list = [train_counts.get(i, 0) for i in range(len(label_names))]
    print(f"\n  Train class balance:")
    print(f"    Min samples/class: {min(train_counts_list)}")
    print(f"    Max samples/class: {max(train_counts_list)}")
    print(f"    Mean samples/class: {np.mean(train_counts_list):.1f}")
    print(f"    Std samples/class: {np.std(train_counts_list):.1f}")
    print(f"    Imbalance ratio (max/min): {max(train_counts_list)/max(min(train_counts_list),1):.2f}")

    # --- Text Length Statistics ---
    print(f"\n{'─'*50}")
    print("  TEXT LENGTH STATISTICS (Training Set)")
    print(f"{'─'*50}")

    char_lengths = [len(t) for t in texts]
    word_lengths = [len(t.split()) for t in texts]

    for metric_name, lengths in [("Character lengths", char_lengths), ("Word counts", word_lengths)]:
        arr = np.array(lengths)
        print(f"\n  {metric_name}:")
        print(f"    Min:    {arr.min():>8,}")
        print(f"    Max:    {arr.max():>8,}")
        print(f"    Mean:   {arr.mean():>8,.1f}")
        print(f"    Median: {np.median(arr):>8,.1f}")
        print(f"    Std:    {arr.std():>8,.1f}")
        print(f"    P25:    {np.percentile(arr, 25):>8,.1f}")
        print(f"    P75:    {np.percentile(arr, 75):>8,.1f}")
        print(f"    P95:    {np.percentile(arr, 95):>8,.1f}")

    # Token-level stats with tokenizer
    print(f"\n  Tokenized lengths (using {config.model_name} tokenizer):")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Tokenize only a sample for speed. A dedicated RNG seeded from the config
    # makes the sample reproducible without touching the global `random` state.
    sample_size = min(2000, len(texts))
    sample_texts = random.Random(config.seed).sample(texts, sample_size)
    # One batched call instead of per-document encode(); special tokens are
    # included, matching encode()'s defaults.
    token_lengths = [len(ids) for ids in tokenizer(sample_texts)['input_ids']]
    arr = np.array(token_lengths)

    print(f"    (Sampled {sample_size:,} documents)")
    print(f"    Min:    {arr.min():>8,}")
    print(f"    Max:    {arr.max():>8,}")
    print(f"    Mean:   {arr.mean():>8,.1f}")
    print(f"    Median: {np.median(arr):>8,.1f}")
    print(f"    P95:    {np.percentile(arr, 95):>8,.1f}")

    # Coverage at reference thresholds, always including the configured one
    print(f"\n  Token coverage at different max_length:")
    for ml in sorted({128, 256, 512, config.max_length}):
        coverage = (arr <= ml).sum() / len(arr) * 100
        print(f"    max_length={ml}: {coverage:.1f}% of documents fully covered")
    print(f"    → Using max_length={config.max_length}")

    # --- Sample Documents ---
    print(f"\n{'─'*50}")
    print("  SAMPLE DOCUMENTS (first 200 chars)")
    print(f"{'─'*50}")

    # Index of the first training document for each label, found in one pass
    first_index = {}
    for idx, lbl in enumerate(train_labels):
        first_index.setdefault(lbl, idx)

    num_preview = min(5, len(label_names))
    for i in range(num_preview):
        j = first_index.get(i)
        if j is None:
            continue
        text_preview = texts[j][:200].replace('\n', ' ')
        print(f"\n  [{label_names[i]}]")
        print(f"  \"{text_preview}...\"")

    print(f"\n{'─'*50}")
    print(f"  (Showing {num_preview} of {len(label_names)} classes)")
    print("="*70 + "\n")

    return dataset


# Explore once and keep the returned splits so the data-loading cell can reuse them.
dataset = explore_dataset(config=config)

## 5. Data Loading & Tokenization

In [None]:


"""
Data Loading and Preprocessing for 20 Newsgroups

Design Decisions:
-----------------
1. Tokenization: ModernBERT-large tokenizer
2. Padding: max_length for uniform batch shapes (better for DataParallel)
3. DataLoader: 4 worker processes and pinned memory when running on CUDA (single-process otherwise)
"""

def get_label_names() -> list:
    """Return a fresh list of the 20 newsgroup category names.

    Position ``i`` in the list is used as the display name of integer label ``i``.
    """
    # Whitespace-separated names in label-id order; split() yields a new list per call.
    return """
        alt.atheism comp.graphics comp.os.ms-windows.misc
        comp.sys.ibm.pc.hardware comp.sys.mac.hardware comp.windows.x
        misc.forsale rec.autos rec.motorcycles rec.sport.baseball
        rec.sport.hockey sci.crypt sci.electronics sci.med
        sci.space soc.religion.christian talk.politics.guns
        talk.politics.mideast talk.politics.misc talk.religion.misc
    """.split()


def load_and_prepare_data(config, dataset=None) -> Tuple[DataLoader, DataLoader]:
    """Tokenize the 20 Newsgroups splits and wrap them in DataLoaders.

    Args:
        config: ``Config`` supplying ``dataset_name``, ``model_name``,
            ``max_length``, ``device`` and the batch-size fields.
        dataset: Optional already-loaded object with ``'train'`` and
            ``'test'`` splits; fetched with ``load_dataset`` when None.

    Returns:
        ``(train_loader, test_loader)``. Both use ``config.total_batch_size``;
        the train loader shuffles and drops the final incomplete batch, the
        test loader keeps every sample in order.
    """
    print(f"\nPreparing data for training...")

    # Reuse the splits from the exploration step when they are handed in
    splits = dataset if dataset is not None else load_dataset(config.dataset_name)
    train_split, test_split = splits['train'], splits['test']

    print(f"  Train size: {len(train_split)}")
    print(f"  Test size: {len(test_split)}")

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    # Fixed-length padding keeps every batch the same shape
    encode_kwargs = dict(
        truncation=True,
        padding='max_length',
        max_length=config.max_length,
        return_tensors=None,
    )

    def encode_batch(examples):
        return tokenizer(examples['text'], **encode_kwargs)

    print("  Tokenizing datasets...")
    train_split = train_split.map(encode_batch, batched=True, desc="Tokenizing train")
    test_split = test_split.map(encode_batch, batched=True, desc="Tokenizing test")

    # Expose only model inputs and targets, as torch tensors
    for split in (train_split, test_split):
        split.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    on_cuda = config.device == 'cuda'
    shared_loader_kwargs = dict(
        batch_size=config.total_batch_size,
        num_workers=4 if on_cuda else 0,
        pin_memory=on_cuda,
    )
    # drop_last keeps every training batch full, so DataParallel splits it evenly
    train_loader = DataLoader(train_split, shuffle=True, drop_last=True, **shared_loader_kwargs)
    test_loader = DataLoader(test_split, shuffle=False, **shared_loader_kwargs)

    print(f"  DataLoaders ready — Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")
    print(f"  Batch size per GPU: {config.batch_size}, Total batch: {config.total_batch_size}, "
          f"Effective: {config.effective_batch_size}")

    return train_loader, test_loader


# Build the DataLoaders, handing over the already-downloaded splits.
train_loader, test_loader = load_and_prepare_data(config=config, dataset=dataset)

## 6. Model

Initialize ModernBERT-large with classification head, layer freezing, and multi-GPU support.

In [None]:


"""
ModernBERT-Large Model for Text Classification

Design Decisions:
-----------------
1. Architecture: ModernBERT-large (395M params) + classification head
2. Layer Freezing: Freeze embeddings + bottom 14/28 encoder layers
3. Multi-GPU: Wrap with DataParallel for dual T4s
"""

def _find_backbone(model):
    """Return the pretrained encoder wrapped by a sequence-classification model."""
    for attr in ('model', 'bert'):
        if hasattr(model, attr):
            return getattr(model, attr)
    # Hugging Face models expose their encoder as `base_model`; without it,
    # treat the model itself as the backbone.
    return getattr(model, 'base_model', model)


def _find_encoder_layers(backbone):
    """Return the ordered transformer blocks of ``backbone``, or None."""
    # ModernBERT keeps its blocks directly on the backbone as `.layers` (no
    # `.encoder` wrapper); BERT-style models nest them under `.encoder.layer`.
    encoder = getattr(backbone, 'encoder', None)
    candidates = (
        getattr(backbone, 'layers', None),
        getattr(encoder, 'layers', None),
        getattr(encoder, 'layer', None),
    )
    for layers in candidates:
        if isinstance(layers, nn.ModuleList):
            return layers
    return None


def _freeze_bottom_layers(model, freeze_ratio):
    """Freeze the embeddings and the lowest ``freeze_ratio`` share of encoder blocks."""
    backbone = _find_backbone(model)

    embeddings = getattr(backbone, 'embeddings', None)
    if embeddings is not None:
        for param in embeddings.parameters():
            param.requires_grad = False
        print("  ✓ Froze embedding layer")

    layers = _find_encoder_layers(backbone)
    if layers is None:
        print("  ⚠ Warning: Could not identify encoder layers for freezing")
        return

    num_layers = len(layers)
    # Clamp so a ratio outside [0, 1] cannot over- or under-run the layer list
    num_freeze = min(max(int(num_layers * freeze_ratio), 0), num_layers)
    for layer in layers[:num_freeze]:
        for param in layer.parameters():
            param.requires_grad = False
    print(f"  ✓ Froze {num_freeze}/{num_layers} encoder layers")


def get_model(config):
    """Build the classifier, freeze its lower layers and place it on the device.

    Args:
        config: ``Config`` supplying ``model_name``, ``num_labels``, ``seed``,
            ``freeze_layers``, ``freeze_ratio``, ``device`` and ``num_gpus``.

    Returns:
        The model on ``config.device``, wrapped in ``nn.DataParallel`` when
        ``config.num_gpus > 1`` and the device is CUDA.
    """
    print(f"\nLoading model: {config.model_name}")
    print(f"  Number of classes: {config.num_labels}")

    # Seed before construction: weights not present in the checkpoint
    # (presumably the new classification head) are randomly initialised.
    torch.manual_seed(config.seed)

    model_config = AutoConfig.from_pretrained(
        config.model_name,
        num_labels=config.num_labels,
        finetuning_task="text-classification"
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        config.model_name,
        config=model_config
    )

    # Layer freezing for efficiency
    if config.freeze_layers:
        _freeze_bottom_layers(model, config.freeze_ratio)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params

    print(f"\n  Parameter Summary:")
    print(f"    Total:     {total_params:>12,}")
    print(f"    Trainable: {trainable_params:>12,} ({100*trainable_params/total_params:.1f}%)")
    print(f"    Frozen:    {frozen_params:>12,} ({100*frozen_params/total_params:.1f}%)")

    # Multi-GPU support with DataParallel (only meaningful on CUDA)
    model.to(config.device)
    if config.num_gpus > 1 and str(config.device).startswith("cuda"):
        model = nn.DataParallel(model)
        print(f"\n  ✓ DataParallel enabled across {config.num_gpus} GPUs")

    return model


# Build the (possibly DataParallel-wrapped) classifier from the run config.
model = get_model(config=config)

## 7. Trainer

Training loop with AdamW, warmup scheduler, gradient accumulation, FP16, and multi-GPU.

In [None]:


"""
Training Module for ModernBERT-Large Fine-tuning (Dual T4 Optimized)

Key features:
- DataParallel multi-GPU support
- Gradient accumulation (effective batch 64)
- FP16 mixed precision via torch.amp
- Linear warmup + decay scheduler
"""

class Trainer:
    """Fine-tuning loop for ModernBERT-large on one or more GPUs.

    Uses AdamW over trainable parameters only, a linear warmup/decay schedule,
    gradient accumulation, gradient-norm clipping and optional FP16 autocast
    with loss scaling. ``model`` may be wrapped in ``nn.DataParallel``.
    """

    def __init__(self, model, config, train_loader):
        """
        Args:
            model: Classifier to train (optionally DataParallel-wrapped).
            config: ``Config`` with optimisation settings.
            train_loader: DataLoader yielding dicts with ``input_ids``,
                ``attention_mask`` and ``label`` tensors.
        """
        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.device = config.device
        self.grad_accum_steps = config.gradient_accumulation_steps

        # Access underlying model for parameter filtering (DataParallel wraps it)
        base_model = model.module if hasattr(model, 'module') else model

        # Only optimize trainable parameters (frozen layers are excluded)
        self.optimizer = AdamW(
            filter(lambda p: p.requires_grad, base_model.parameters()),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # train_epoch steps once per full accumulation window and once more for
        # a trailing partial window, i.e. ceil(batches / accum) steps per epoch.
        # Flooring here would drive the LR to 0 before training finishes.
        steps_per_epoch = (len(train_loader) + self.grad_accum_steps - 1) // self.grad_accum_steps
        self.total_steps = steps_per_epoch * config.num_epochs
        self.warmup_steps = int(self.total_steps * config.warmup_ratio)

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.total_steps
        )

        # Modern AMP API
        self.scaler = GradScaler("cuda") if config.use_fp16 else None
        self.use_fp16 = config.use_fp16

        self.history = {'train_loss': [], 'learning_rate': []}

        print(f"\nTraining Configuration:")
        print(f"  Device: {self.device} × {config.num_gpus} GPUs")
        print(f"  Total optimizer steps: {self.total_steps}")
        print(f"  Warmup steps: {self.warmup_steps}")
        print(f"  Gradient accumulation: {self.grad_accum_steps}")
        print(f"  Effective batch size: {config.effective_batch_size}")
        print(f"  Mixed precision (FP16): {self.use_fp16}")

    def _get_trainable_params(self):
        """Get trainable parameters from model (handles DataParallel)."""
        base_model = self.model.module if hasattr(self.model, 'module') else self.model
        return filter(lambda p: p.requires_grad, base_model.parameters())

    def _compute_loss(self, batch):
        """Run one forward pass and return the batch's mean loss."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        labels = batch['label'].to(self.device)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        # DataParallel returns one loss per GPU; mean() also accepts a scalar.
        return outputs.loss.mean()

    def _optimizer_step(self):
        """Clip gradients, update weights and LR, then clear gradients."""
        if self.use_fp16:
            # Unscale first so clipping sees the true gradient magnitudes
            self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(
            self._get_trainable_params(),
            self.config.max_grad_norm
        )
        if self.use_fp16:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()

    def train_epoch(self, epoch):
        """Train for one epoch with gradient accumulation.

        Returns:
            Mean batch loss over the epoch, or NaN if the loader was empty.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        num_steps = len(self.train_loader)
        # Batches in the trailing window that is shorter than grad_accum_steps
        remaining = num_steps % self.grad_accum_steps
        full_windows_end = num_steps - remaining

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {epoch+1}/{self.config.num_epochs}",
            leave=True
        )

        self.optimizer.zero_grad()

        for step, batch in enumerate(progress_bar):
            if self.use_fp16:
                with autocast("cuda"):
                    batch_loss = self._compute_loss(batch)
            else:
                batch_loss = self._compute_loss(batch)

            # Divide by the real size of this accumulation window so a trailing
            # partial window contributes a full-strength averaged gradient.
            window = self.grad_accum_steps if step < full_windows_end else remaining
            loss = batch_loss / window

            if self.use_fp16:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step at the end of every window, including a trailing partial one
            if (step + 1) % self.grad_accum_steps == 0 or step + 1 == num_steps:
                self._optimizer_step()

            loss_value = batch_loss.item()
            total_loss += loss_value
            num_batches += 1

            progress_bar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'lr': f'{self.scheduler.get_last_lr()[0]:.2e}'
            })

        if num_batches == 0:
            return float('nan')
        return total_loss / num_batches

    def train(self):
        """Run all epochs, record loss/LR history and report GPU memory."""
        print("\n" + "="*60)
        print("Starting Training")
        print("="*60 + "\n")

        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            epoch_start = time.time()

            train_loss = self.train_epoch(epoch)
            self.history['train_loss'].append(train_loss)

            current_lr = self.scheduler.get_last_lr()[0]
            self.history['learning_rate'].append(current_lr)

            epoch_time = time.time() - epoch_start

            print(f"\nEpoch {epoch+1}/{self.config.num_epochs} - "
                  f"Train Loss: {train_loss:.4f} - "
                  f"LR: {current_lr:.2e} - "
                  f"Time: {epoch_time:.1f}s")

            # Memory report (uses this trainer's config, not a module global)
            if torch.cuda.is_available():
                for i in range(self.config.num_gpus):
                    allocated = torch.cuda.memory_allocated(i) / 1024**3
                    reserved = torch.cuda.memory_reserved(i) / 1024**3
                    print(f"  GPU {i} memory: {allocated:.1f} GB allocated, {reserved:.1f} GB reserved")

        total_time = time.time() - start_time
        print(f"\nTraining Complete! Total time: {total_time/60:.1f} minutes")

        return self.history

## 8. Train the Model

In [None]:


# Set random seed for reproducibility
def set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's random generators.

    Args:
        seed: Integer seed applied to every generator, including all CUDA
            devices when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(seed=config.seed)

# Build the trainer, run every epoch and keep the loss/learning-rate history
trainer = Trainer(model=model, config=config, train_loader=train_loader)
history = trainer.train()

## 9. Evaluation

In [None]:


"""
Evaluation Module

Metrics: Accuracy, Macro/Weighted F1, Per-class report, Confusion matrix
"""

@torch.no_grad()
def evaluate(model, test_loader, config):
    """Evaluate ``model`` on ``test_loader`` and print a full metric report.

    Args:
        model: Classifier (optionally DataParallel-wrapped) whose outputs
            expose ``logits``.
        test_loader: DataLoader yielding ``input_ids``, ``attention_mask`` and
            ``label`` tensors.
        config: ``Config`` supplying ``device`` and ``use_fp16``.

    Returns:
        Dict with the test loss (mean per-sample cross-entropy), accuracy,
        macro/weighted precision/recall/F1, the classification report text,
        the confusion matrix, raw predictions/labels and the label names.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    print("\n" + "="*60)
    print("Evaluating on Test Set")
    print("="*60 + "\n")

    model.eval()

    label_names = get_label_names()
    class_ids = list(range(len(label_names)))

    pred_chunks = []
    label_chunks = []
    loss_sum = 0.0
    num_samples = 0

    for batch in tqdm(test_loader, desc="Evaluating"):
        input_ids = batch['input_ids'].to(config.device)
        attention_mask = batch['attention_mask'].to(config.device)
        labels = batch['label'].to(config.device)

        if config.use_fp16:
            with autocast("cuda"):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        else:
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # Sum per-sample cross-entropy in FP32 so the final mean is exact even
        # when the last batch is smaller (a mean of batch means mis-weights it).
        logits = outputs.logits.float()
        loss_sum += torch.nn.functional.cross_entropy(logits, labels, reduction='sum').item()
        num_samples += labels.size(0)

        pred_chunks.append(logits.argmax(dim=-1).cpu().numpy())
        label_chunks.append(labels.cpu().numpy())

    if num_samples == 0:
        raise ValueError("test_loader produced no samples to evaluate")

    all_predictions = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    # Calculate metrics (zero_division=0: a never-predicted class scores 0)
    accuracy = accuracy_score(all_labels, all_predictions)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='macro', zero_division=0
    )
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='weighted', zero_division=0
    )
    avg_loss = loss_sum / num_samples

    # Pin the label set so rows always line up with label_names, even if a
    # class is missing from this split.
    report = classification_report(
        all_labels, all_predictions,
        labels=class_ids, target_names=label_names, digits=4, zero_division=0
    )
    conf_matrix = confusion_matrix(all_labels, all_predictions, labels=class_ids)

    # Print results
    print("\n" + "="*60)
    print("EVALUATION RESULTS")
    print("="*60)
    print(f"\n[Overall Metrics]")
    print(f"  Test Loss: {avg_loss:.4f}")
    print(f"  Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"\n[Macro Averages]")
    print(f"  Precision: {precision_macro:.4f}")
    print(f"  Recall: {recall_macro:.4f}")
    print(f"  F1 Score: {f1_macro:.4f}")
    print(f"\n[Weighted Averages]")
    print(f"  Precision: {precision_weighted:.4f}")
    print(f"  Recall: {recall_weighted:.4f}")
    print(f"  F1 Score: {f1_weighted:.4f}")
    print("\n" + "="*60)
    print("CLASSIFICATION REPORT (Per-Class)")
    print("="*60)
    print(report)

    return {
        'test_loss': avg_loss,
        'accuracy': accuracy,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'classification_report': report,
        'confusion_matrix': conf_matrix,
        'predictions': all_predictions,
        'labels': all_labels,
        'label_names': label_names
    }


# Score the fine-tuned model on the held-out test split
results = evaluate(model=model, test_loader=test_loader, config=config)

## 10. Final Summary

In [None]:


# Headline metrics, printed as one framed block
summary_lines = [
    "\n" + "=" * 70,
    "FINAL RESULTS SUMMARY",
    "=" * 70,
    f"  Model: {config.model_name}",
    f"  GPUs: {config.num_gpus} × T4",
    f"  Test Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)",
    f"  Macro F1 Score: {results['f1_macro']:.4f}",
    f"  Weighted F1 Score: {results['f1_weighted']:.4f}",
    "=" * 70,
]
print(*summary_lines, sep="\n")

## 11. Save Model

In [None]:


if config.save_model:
    os.makedirs(config.output_dir, exist_ok=True)
    model_path = os.path.join(config.output_dir, "model")

    # DataParallel keeps the real model under `.module`; persist the unwrapped one
    save_model = getattr(model, 'module', model)
    save_model.save_pretrained(model_path)
    print(f"Model saved to: {model_path}")

    # Store the tokenizer next to the weights so the directory is self-contained
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.save_pretrained(model_path)
    print(f"Tokenizer saved to: {model_path}")

    # Per-epoch loss / learning-rate history
    history_path = os.path.join(config.output_dir, "training_history.json")
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=2)
    print(f"Training history saved to: {history_path}")

    # Scalar evaluation metrics, in a fixed key order
    metrics = {'model': config.model_name, 'num_gpus': config.num_gpus}
    for metric_name in ('test_loss', 'accuracy',
                        'precision_macro', 'recall_macro', 'f1_macro',
                        'precision_weighted', 'recall_weighted', 'f1_weighted'):
        metrics[metric_name] = results[metric_name]
    metrics_path = os.path.join(config.output_dir, "evaluation_metrics.json")
    with open(metrics_path, 'w') as f:
        json.dump(metrics, f, indent=2)
    print(f"Evaluation metrics saved to: {metrics_path}")