# 🚀 Model Training - Hateful Meme Detection

Train the CLIP + Cross-Attention Fusion model.

**Contents:**
1. Configuration
2. Load Data
3. Create Model
4. Training Setup
5. Training Loop
6. Training Curves
7. Save Final Model

In [None]:
# Install dependencies (if needed)
# !pip install torch torchvision transformers albumentations scikit-learn tqdm

In [None]:
import sys
sys.path.append('..')

import torch
import json
from pathlib import Path
from transformers import CLIPProcessor

from src.model import HatefulMemeClassifier, create_model
from src.dataset import create_dataloaders
from src.losses import FocalLoss

# Report the runtime environment: library version, CUDA availability and,
# when a GPU is present, the name of device 0.
_cuda_ok = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Configuration

In [None]:
# Configuration: every path and hyperparameter used by the cells below.
CONFIG = dict(
    # Paths
    data_path='../data/hateful_memes',
    output_path='../outputs',
    model_path='../models',

    # Model
    clip_model='openai/clip-vit-base-patch32',
    hidden_dim=512,
    num_heads=8,
    dropout=0.3,

    # Training
    epochs=10,
    batch_size=32,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.1,
    gradient_clip=1.0,
    patience=3,

    # Loss
    focal_alpha=0.6412,  # Based on class distribution
    focal_gamma=2.0,
)

# Make sure the output locations exist (no error if they already do).
for _dir_key in ('output_path', 'model_path'):
    Path(CONFIG[_dir_key]).mkdir(exist_ok=True)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

## 2. Load Data

In [None]:
# Instantiate the CLIP processor named in CONFIG, then hand it to the project's
# dataloader factory together with the data location and batch size.
processor = CLIPProcessor.from_pretrained(CONFIG['clip_model'])

_loader_kwargs = {
    'data_path': CONFIG['data_path'],
    'processor': processor,
    'batch_size': CONFIG['batch_size'],
    'num_workers': 4,
}
dataloaders = create_dataloaders(**_loader_kwargs)

# The factory result is indexed by split name; only train/val are used here.
train_loader, val_loader = dataloaders['train'], dataloaders['val']

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

## 3. Create Model

In [None]:
# Assemble the architecture settings from CONFIG and build the model on `device`.
_model_config = {
    'clip_model_name': CONFIG['clip_model'],
    'hidden_dim': CONFIG['hidden_dim'],
    'num_heads': CONFIG['num_heads'],
    'dropout': CONFIG['dropout'],
    'freeze_clip': True,
}
model = create_model(config=_model_config, device=device)

# Summarise how many parameters exist and how many of them will be trained.
params = model.count_parameters()
print(f"\nModel Parameters:")
print(f"  Total: {params['total']:,}")
print(f"  Trainable: {params['trainable']:,} ({params['trainable_pct']:.2f}%)")

## 4. Training Setup

In [None]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

# Loss: focal loss parameterised from CONFIG.
criterion = FocalLoss(gamma=CONFIG['focal_gamma'], alpha=CONFIG['focal_alpha'])

# Optimizer: only parameters that still require gradients are optimised.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(
    trainable_params,
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)

# Scheduler: one-cycle policy stepped once per batch, so the schedule length
# is (batches per epoch) x (number of epochs).
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * CONFIG['epochs']
scheduler = OneCycleLR(
    optimizer,
    max_lr=CONFIG['learning_rate'],
    total_steps=total_steps,
    pct_start=CONFIG['warmup_ratio'],
)

print("Training setup complete!")

## 5. Training Loop

In [None]:
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from tqdm.notebook import tqdm
import time

def train_epoch(model, loader, criterion, optimizer, scheduler, device):
    """Run one optimisation pass over ``loader`` and summarise it.

    For every batch the model is run forward, the loss is back-propagated,
    gradients are clipped to ``CONFIG['gradient_clip']`` (read from the
    notebook-level ``CONFIG``), and both the optimizer and the scheduler are
    stepped.

    Args:
        model: Module called as ``model(pixel_values, input_ids, attention_mask)``.
        loader: Iterable of dict batches with the keys ``pixel_values``,
            ``input_ids``, ``attention_mask`` and ``label``.
        criterion: Callable taking ``(outputs, labels)`` and returning the loss.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        dict with ``loss`` (sum of batch losses divided by ``len(loader)``),
        ``accuracy`` and ``f1``, the latter two computed over the whole epoch
        from predictions thresholded at 0.5 after a sigmoid.
    """
    model.train()
    running_loss = 0
    predictions, targets = [], []

    batch_keys = ('pixel_values', 'input_ids', 'attention_mask', 'label')
    for batch in tqdm(loader, desc="Training"):
        pixel_values, input_ids, attention_mask, labels = (
            batch[key].to(device) for key in batch_keys
        )

        optimizer.zero_grad()
        logits = model(pixel_values, input_ids, attention_mask)
        loss = criterion(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['gradient_clip'])
        optimizer.step()
        scheduler.step()  # per-batch stepping, not per-epoch

        running_loss += loss.item()
        hard_preds = (torch.sigmoid(logits) > 0.5).float()
        predictions.extend(hard_preds.cpu().numpy())
        targets.extend(labels.cpu().numpy())

    mean_loss = running_loss / len(loader)
    return {
        'loss': mean_loss,
        'accuracy': accuracy_score(targets, predictions),
        'f1': f1_score(targets, predictions),
    }

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    The model is put in eval mode and every batch is processed under
    ``torch.no_grad()``.

    Args:
        model: Module called as ``model(pixel_values, input_ids, attention_mask)``.
        loader: Iterable of dict batches with the keys ``pixel_values``,
            ``input_ids``, ``attention_mask`` and ``label``.
        criterion: Callable taking ``(outputs, labels)`` and returning the loss.
        device: Device the batch tensors are moved to.

    Returns:
        dict with ``loss`` (sum of batch losses divided by ``len(loader)``),
        ``accuracy`` and ``f1`` (from sigmoid outputs thresholded at 0.5) and
        ``auc`` (from the raw sigmoid probabilities).
    """
    model.eval()
    loss_sum = 0
    hard_preds, scores, targets = [], [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating"):
            logits = model(
                batch['pixel_values'].to(device),
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),
            )
            labels = batch['label'].to(device)

            loss_sum += criterion(logits, labels).item()

            # Keep both the probabilities (for AUC) and the 0/1 decisions.
            probs = torch.sigmoid(logits)
            hard_preds.extend((probs > 0.5).float().cpu().numpy())
            scores.extend(probs.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    mean_loss = loss_sum / len(loader)
    return {
        'loss': mean_loss,
        'accuracy': accuracy_score(targets, hard_preds),
        'f1': f1_score(targets, hard_preds),
        'auc': roc_auc_score(targets, scores),
    }

In [None]:
# Training loop: train and validate each epoch, record the metrics, keep the
# weights with the best validation F1 on disk, and stop once F1 has failed to
# improve for CONFIG['patience'] consecutive epochs.
best_f1 = 0
patience_counter = 0
history = {
    key: []
    for key in ('train_loss', 'train_f1', 'val_loss', 'val_f1', 'val_auc')
}

print("Starting training...\n")

for epoch in range(CONFIG['epochs']):
    print(f"Epoch {epoch+1}/{CONFIG['epochs']}")
    print("-" * 40)

    train_metrics = train_epoch(model, train_loader, criterion, optimizer, scheduler, device)
    val_metrics = validate(model, val_loader, criterion, device)

    # Append this epoch's numbers to the running history.
    epoch_record = {
        'train_loss': train_metrics['loss'],
        'train_f1': train_metrics['f1'],
        'val_loss': val_metrics['loss'],
        'val_f1': val_metrics['f1'],
        'val_auc': val_metrics['auc'],
    }
    for key, value in epoch_record.items():
        history[key].append(value)

    print(f"Train - Loss: {train_metrics['loss']:.4f}, F1: {train_metrics['f1']:.4f}")
    print(f"Val   - Loss: {val_metrics['loss']:.4f}, F1: {val_metrics['f1']:.4f}, AUC: {val_metrics['auc']:.4f}")

    # Strict improvement only: a tie counts against the patience budget.
    improved = val_metrics['f1'] > best_f1
    if improved:
        best_f1 = val_metrics['f1']
        torch.save(model.state_dict(), f"{CONFIG['model_path']}/best_model.pth")
        print(f"  >> New best model! F1: {best_f1:.4f}")
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"  No improvement ({patience_counter}/{CONFIG['patience']})")

    print()

    # Early stopping
    if patience_counter >= CONFIG['patience']:
        print("Early stopping triggered!")
        break

print(f"\nTraining complete! Best F1: {best_f1:.4f}")

## 6. Training Curves

In [None]:
import matplotlib.pyplot as plt

# One row of three panels: loss, F1 and validation AUC, each against epoch number.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
loss_ax, f1_ax, auc_ax = axes

# History lists have one entry per completed epoch; epochs are shown 1-based.
epochs = range(1, len(history['train_loss']) + 1)

# Loss and F1 share the same layout: train in blue, validation in red.
paired_panels = (
    (loss_ax, 'Loss', 'train_loss', 'val_loss'),
    (f1_ax, 'F1 Score', 'train_f1', 'val_f1'),
)
for ax, title, train_key, val_key in paired_panels:
    ax.plot(epochs, history[train_key], 'b-', label='Train')
    ax.plot(epochs, history[val_key], 'r-', label='Val')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.legend()

# AUC is only computed on the validation split, so it gets a single curve.
auc_ax.plot(epochs, history['val_auc'], 'g-')
auc_ax.set_title('Validation AUC')
auc_ax.set_xlabel('Epoch')

plt.tight_layout()
plt.savefig(f"{CONFIG['output_path']}/training_curves.png", dpi=150)
plt.show()

## 7. Save Final Model

In [None]:
# Save complete checkpoint: weights, the config needed to rebuild the model,
# the training config, the metric history and the best validation F1.
#
# When this cell runs, `model` holds the weights of the LAST epoch. After early
# stopping that is CONFIG['patience'] epochs past the best one, so storing the
# live weights next to `best_f1` would pair the metric with weights that did
# not produce it. The training loop writes best_model.pth on every F1
# improvement, so those weights are stored instead when available.
best_model_file = Path(CONFIG['model_path']) / 'best_model.pth'

# best_f1 starts at 0 and only increases when best_model.pth is written, so
# best_f1 > 0 means the file belongs to this run and is not a stale leftover.
if best_f1 > 0 and best_model_file.is_file():
    checkpoint_state = torch.load(best_model_file, map_location=device)
    weights_source = f"best epoch ({best_model_file})"
else:
    checkpoint_state = model.state_dict()
    weights_source = "last epoch (no best checkpoint from this run)"

checkpoint = {
    'model_state_dict': checkpoint_state,
    'model_config': {
        'clip_model_name': CONFIG['clip_model'],
        'hidden_dim': CONFIG['hidden_dim'],
        'num_heads': CONFIG['num_heads'],
        'dropout': CONFIG['dropout'],
        'freeze_clip': True
    },
    'training_config': CONFIG,
    'history': history,
    'best_f1': best_f1
}

torch.save(checkpoint, f"{CONFIG['model_path']}/final_checkpoint.pth")
print(f"Checkpoint weights: {weights_source}")
print(f"Model saved to {CONFIG['model_path']}/final_checkpoint.pth")