# 📈 Model Evaluation - Hateful Meme Detection

Evaluation of the trained model on the validation (dev) split.

**Contents:**
1. Load Model
2. Generate Predictions
3. Classification Report
4. Confusion Matrix
5. ROC & PR Curves
6. Threshold Analysis
7. Error Analysis
8. Inference Demo
9. Summary

In [None]:
import os
import sys
sys.path.append('..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    classification_report, confusion_matrix,
    roc_curve, auc, precision_recall_curve,
    average_precision_score
)
from transformers import CLIPProcessor
from PIL import Image

from src.model import create_model
from src.dataset import create_dataloaders
from src.losses import FocalLoss

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

## 1. Load Model

In [None]:
# Configuration (paths are relative to the notebook's working directory)
MODEL_PATH = '../models/best_model.pth'
DATA_PATH = '../data/hateful_memes'

# --- Restore the trained model ----------------------------------------------
# The checkpoint is read as a mapping: 'model_state_dict' is required, while
# 'model_config' is optional and is forwarded to create_model (None if absent).
# NOTE(review): depending on the torch version / weights_only default,
# torch.load may unpickle arbitrary objects — only load trusted checkpoints.
checkpoint = torch.load(MODEL_PATH, map_location=device)
saved_config = checkpoint.get('model_config')

model = create_model(config=saved_config, device=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # evaluation mode (affects layers such as dropout / batch-norm)

print("Model loaded successfully!")

# --- Validation data ---------------------------------------------------------
# NOTE(review): the processor checkpoint is hard-coded; presumably it must match
# the CLIP backbone the model was trained with — TODO confirm against model_config.
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
dataloaders = create_dataloaders(DATA_PATH, processor, batch_size=32)
val_loader = dataloaders['val']

## 2. Generate Predictions

In [None]:
from tqdm.notebook import tqdm

# Collect ground-truth labels and predicted probabilities for the whole
# validation set. Every batch is flattened to 1-D so that y_true / y_prob are
# shape (N,) whether the model emits logits shaped (B,) or (B, 1); a (N, 1)
# y_prob would otherwise broadcast to (N, N) in later element-wise comparisons.
true_batches = []
prob_batches = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        logits = model(pixel_values, input_ids, attention_mask)
        # Assumes the model returns raw logits for the "hateful" class
        # (hence the sigmoid) — TODO confirm against src.model.
        probs = torch.sigmoid(logits.float()).reshape(-1)

        true_batches.append(batch['label'].reshape(-1).cpu().numpy())
        prob_batches.append(probs.cpu().numpy())

if not prob_batches:
    raise RuntimeError("Validation loader produced no batches; nothing to evaluate.")

y_true = np.concatenate(true_batches).astype(int)
y_prob = np.concatenate(prob_batches)
y_pred = (y_prob > 0.5).astype(int)  # default 0.5 cut-off; tuned thresholds come later

print(f"Total samples: {len(y_true)}")

## 3. Classification Report

In [None]:
print("Classification Report")
print("=" * 50)
# labels=[0, 1] keeps target_names aligned with the two classes even when one
# class never appears in y_true / y_pred (otherwise sklearn raises a ValueError);
# zero_division=0 reports 0 instead of warning for undefined precision/recall.
print(classification_report(
    y_true,
    y_pred,
    labels=[0, 1],
    target_names=['Not Hateful', 'Hateful'],
    zero_division=0,
))

## 4. Confusion Matrix

In [None]:
# savefig does not create missing directories, so make sure the target exists.
os.makedirs('../results/figures', exist_ok=True)

CLASS_NAMES = ['Not Hateful', 'Hateful']

# labels=[0, 1] pins the matrix to 2x2 even if a class is absent from both
# y_true and y_pred, so the fixed tick labels always line up with the cells.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

# Row-normalize (each row is one actual class). Rows with no samples are left
# at 0 instead of producing NaN from a 0/0 division.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm, row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panels = [
    (cm, 'd', 'Confusion Matrix (Counts)'),
    (cm_norm, '.1%', 'Confusion Matrix (Normalized)'),
]
for ax, (matrix, fmt, title) in zip(axes, panels):
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues', ax=ax,
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
    ax.set_title(title)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')

plt.tight_layout()
plt.savefig('../results/figures/confusion_matrix.png', dpi=150)
plt.show()

## 5. ROC & PR Curves

In [None]:
# savefig does not create missing directories, so make sure the target exists.
os.makedirs('../results/figures', exist_ok=True)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# ROC Curve. roc_curve counts a sample as positive when score >= threshold.
fpr, tpr, thresholds_roc = roc_curve(y_true, y_prob)
roc_auc = auc(fpr, tpr)

# Find the optimal threshold with Youden's J statistic (J = TPR - FPR).
# Index 0 is the point sklearn prepends, where nothing is predicted positive;
# its threshold is infinite / above every score and unusable. It is skipped
# so the chosen threshold is always a real score. This changes nothing when
# any J > 0, because J[0] == 0.
youden_j = tpr - fpr
optimal_idx = int(np.argmax(youden_j[1:])) + 1
optimal_threshold = float(thresholds_roc[optimal_idx])

axes[0].plot(fpr, tpr, 'b-', lw=2, label=f'ROC (AUC = {roc_auc:.4f})')
axes[0].plot([0, 1], [0, 1], 'k--', lw=1)
axes[0].scatter(fpr[optimal_idx], tpr[optimal_idx], c='red', s=100,
                label=f'Optimal Threshold = {optimal_threshold:.3f}')
axes[0].fill_between(fpr, tpr, alpha=0.3)
axes[0].set_xlabel('False Positive Rate')
axes[0].set_ylabel('True Positive Rate')
axes[0].set_title('ROC Curve')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# PR Curve; the dashed baseline is the positive-class rate (a random classifier's precision).
precision, recall, thresholds_pr = precision_recall_curve(y_true, y_prob)
avg_precision = average_precision_score(y_true, y_prob)

axes[1].plot(recall, precision, 'g-', lw=2, label=f'PR (AP = {avg_precision:.4f})')
axes[1].fill_between(recall, precision, alpha=0.3, color='green')
axes[1].axhline(y=y_true.mean(), color='gray', linestyle='--', label='Baseline')
axes[1].set_xlabel('Recall')
axes[1].set_ylabel('Precision')
axes[1].set_title('Precision-Recall Curve')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../results/figures/roc_pr_curves.png', dpi=150)
plt.show()

print(f"\nOptimal Threshold: {optimal_threshold:.4f}")
print(f"At optimal: TPR = {tpr[optimal_idx]:.4f}, FPR = {fpr[optimal_idx]:.4f}")

## 6. Threshold Analysis

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('../results/figures', exist_ok=True)

# Sweep a fixed grid of decision thresholds: 0.10, 0.15, ..., 0.85 (16 points).
# linspace is used instead of a float-step arange, whose length and end point
# are subject to floating-point rounding.
thresholds = np.linspace(0.10, 0.85, 16)

rows = []
for t in thresholds:
    preds = (y_prob > t).astype(int)
    # zero_division=0: at high thresholds there may be no positive predictions,
    # which makes precision (and hence F1) undefined; report 0 instead of warning.
    rows.append({
        'threshold': t,
        'f1': f1_score(y_true, preds, zero_division=0),
        'precision': precision_score(y_true, preds, zero_division=0),
        'recall': recall_score(y_true, preds, zero_division=0),
    })

df_metrics = pd.DataFrame(rows)

plt.figure(figsize=(10, 5))
plt.plot(df_metrics['threshold'], df_metrics['f1'], 'b-', label='F1')
plt.plot(df_metrics['threshold'], df_metrics['precision'], 'g-', label='Precision')
plt.plot(df_metrics['threshold'], df_metrics['recall'], 'r-', label='Recall')
plt.axvline(x=optimal_threshold, color='gray', linestyle='--', label=f'Optimal ({optimal_threshold:.2f})')
plt.xlabel('Threshold')
plt.ylabel('Score')
plt.title('Metrics vs Threshold')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('../results/figures/threshold_analysis.png', dpi=150)
plt.show()

# Best threshold for F1 on the sweep grid
best_idx = df_metrics['f1'].idxmax()
print(f"\nBest threshold for F1: {df_metrics.loc[best_idx, 'threshold']:.2f}")
print(f"F1 at best threshold: {df_metrics.loc[best_idx, 'f1']:.4f}")

## 7. Error Analysis

In [None]:
# Use optimal threshold for predictions
y_pred_opt = (y_prob > optimal_threshold).astype(int)

# Create results dataframe
import json
with open(f'{DATA_PATH}/dev.jsonl', 'r') as f:
    dev_data = [json.loads(line) for line in f]

results_df = pd.DataFrame({
    'id': [d['id'] for d in dev_data],
    'text': [d['text'][:80] for d in dev_data],
    'true_label': y_true,
    'pred_label': y_pred_opt,
    'probability': y_prob,
    'correct': y_true == y_pred_opt
})

# False positives and negatives
fp = results_df[(results_df['true_label'] == 0) & (results_df['pred_label'] == 1)]
fn = results_df[(results_df['true_label'] == 1) & (results_df['pred_label'] == 0)]

print(f"False Positives: {len(fp)} (Not Hateful predicted as Hateful)")
print(f"False Negatives: {len(fn)} (Hateful predicted as Not Hateful)")

print("\n--- Top False Positives (High Confidence) ---")
print(fp.nlargest(5, 'probability')[['id', 'probability', 'text']])

print("\n--- Top False Negatives (Low Confidence) ---")
print(fn.nsmallest(5, 'probability')[['id', 'probability', 'text']])

## 8. Inference Demo

In [None]:
from src.inference import HatefulMemePredictor

# Initialize the predictor with the ROC-derived threshold.
# NOTE(review): whether HatefulMemePredictor compares with > or >= is not
# visible here — confirm it matches roc_curve's >= convention.
predictor = HatefulMemePredictor(
    model_path=MODEL_PATH,
    device=str(device),
    threshold=optimal_threshold,
)

# Test on a sample: the first record of dev.jsonl.
sample = dev_data[0]
image_path = f"{DATA_PATH}/{sample['img']}"
text = sample['text']

result = predictor.get_detailed_analysis(image_path, text)

print("\n" + "="*50)
print("INFERENCE DEMO")
print("="*50)
print(f"Text: {text}")
print(f"True Label: {'Hateful' if sample['label'] == 1 else 'Not Hateful'}")
print(f"Predicted: {result['label']}")
print(f"Probability: {result['probability']:.4f}")
print(f"Confidence: {result['confidence']:.2%}")
print(f"Zone: {result['zone']}")

# Show image. Open it in a context manager so the file handle is released.
# Convert to RGB so grayscale/palette images display in true colors; a
# single-channel array would otherwise be rendered through imshow's colormap.
with Image.open(image_path) as img:
    image_rgb = img.convert('RGB')

plt.figure(figsize=(6, 6))
plt.imshow(image_rgb)
plt.axis('off')
plt.title(f"Prediction: {result['label']} ({result['probability']:.2%})")
plt.show()

## 9. Summary

**Key Results:**
- Model achieves competitive performance on a challenging multimodal task
- Decision threshold chosen by maximizing Youden's J statistic (TPR − FPR) on the validation ROC curve
- Error analysis reveals model struggles with subtle/implicit hate

**Recommendations:**
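- Use the ROC-derived threshold (not the default 0.5) for deployment. Confirm it on a held-out split first: it was tuned on the same validation data these metrics are reported on, so the numbers here are optimistic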
- Implement tiered response system (safe/warning/remove)
- Human review for edge cases in warning zone