# 05 - Model Evaluation

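This notebook evaluates the trained BERT emotion classifier:
- Load the saved model checkpoint (`best_model.pth`)
- Evaluate on the held-out test set (accuracy, F1 score, classification report, confusion matrix)
- Inspect sample predictions and misclassified examples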

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.data import DataLoader

from src.config import EMOTION_LABELS, MAX_LENGTH
from src.data.dataset import EmotionDataset, get_tokenizer, load_emotion_data
from src.models.bert_classifier import BertClassifier
from src.training.trainer import Trainer
from src.training.utils import get_device, load_checkpoint

## Load Trained Model

In [None]:
device = get_device()

# Checkpoint to restore; passed to load_checkpoint as its `filename` argument.
CHECKPOINT_FILE = 'best_model.pth'

# Restore the saved weights into a freshly built classifier, place it on the
# target device, then switch it to eval mode for inference.
model = load_checkpoint(BertClassifier(), filename=CHECKPOINT_FILE, device=device).to(device)
model.eval()

print("Model loaded successfully!")

## Evaluate on Test Set

In [None]:
# Only the held-out split is needed here; the first two returned splits are discarded.
# resample=False presumably disables any rebalancing in the loader — confirm in src.data.dataset.
_, _, test_df = load_emotion_data(resample=False)
n_test_rows = len(test_df)
print(f"Test set size: {n_test_rows}")

# Score the loaded model with the project's Trainer; `results` is a metrics dict
# read by the next cell.
trainer = Trainer(model=model)
results = trainer.evaluate(test_df)

In [None]:
# Summarize the metrics returned by trainer.evaluate(); only the numeric
# lines need f-string formatting.
print("\nTest Results:")
print(f"  Accuracy: {results['accuracy']:.4f}")
print(f"  F1 Score: {results['f1_score']:.4f}")
print("\nClassification Report:")
print(results['classification_report'])

## Confusion Matrix

In [None]:
# Collect per-example predictions on the test set. `all_preds` is reused by the
# error-analysis cell below, and `tokenizer` by predict_emotion().
tokenizer = get_tokenizer()
test_dataset = EmotionDataset(test_df, tokenizer)
# DataLoader does not shuffle by default, so predictions follow the dataset order
# (assumes EmotionDataset preserves test_df's row order — TODO confirm).
test_loader = DataLoader(test_dataset, batch_size=4)

# Re-assert eval mode before inference; trainer.evaluate() may have changed the
# model's mode — NOTE(review): unverified, this call is a cheap safeguard.
model.eval()

all_preds, all_labels = [], []
with torch.no_grad():
    for batch_input, batch_label in test_loader:
        mask = batch_input['attention_mask'].to(device)
        # input_ids carry an extra singleton dim from the dataset; drop it.
        input_id = batch_input['input_ids'].squeeze(1).to(device)
        output = model(input_id, mask)
        all_preds.extend(output.argmax(dim=1).cpu().numpy())
        all_labels.extend(batch_label.numpy())

# Pin the class order to the label ids so the matrix always has one row/column
# per class, in the same order as the tick labels. Without `labels=`, a class
# missing from both labels and predictions would shrink the matrix and shift
# every tick label.
class_ids = sorted(EMOTION_LABELS)
class_names = [EMOTION_LABELS[i] for i in class_ids]

cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set(title='Confusion Matrix', xlabel='Predicted', ylabel='Actual')
fig.tight_layout()
plt.show()

## Sample Predictions

In [None]:
# Interactive prediction on free-form text.
# Relies on `tokenizer`, `model` and `device` defined in earlier cells.

def predict_emotion(text):
    """Predict the emotion of a single text.

    Args:
        text: Raw input string. It is padded/truncated to MAX_LENGTH tokens.

    Returns:
        A tuple ``(label, confidence)``. ``label`` is the EMOTION_LABELS entry
        for the arg-max class. ``confidence`` is that class's softmax
        probability, as a Python float.
    """
    inputs = tokenizer(
        text,
        padding='max_length',
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors='pt'
    )

    with torch.no_grad():
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        output = model(input_ids, attention_mask)
        probs = torch.softmax(output, dim=1)
        pred_class = output.argmax(dim=1).item()
        # Single-example batch: row 0 holds this text's class probabilities.
        confidence = probs[0, pred_class].item()

    return EMOTION_LABELS[pred_class], confidence

# Test predictions
test_texts = [
    "I just got promoted at work! Best day ever!",
    "I can't believe they would do this to me.",
    "Missing my grandmother who passed away last year.",
    "The test results came back... I'm so scared.",
    "You're the best thing that ever happened to me.",
    "Wow, I never expected that ending!"
]

print("Sample Predictions:\n")
for text in test_texts:
    emotion, conf = predict_emotion(text)
    print(f"Text: {text}")
    print(f"  → {emotion} ({conf:.1%} confidence)\n")

## Error Analysis

In [None]:
# Find misclassified examples
test_df_copy = test_df.copy()
test_df_copy['predicted'] = [EMOTION_LABELS[p] for p in all_preds]
test_df_copy['actual'] = test_df_copy['category']
test_df_copy['correct'] = test_df_copy['predicted'] == test_df_copy['actual']

# Show some misclassified examples
misclassified = test_df_copy[~test_df_copy['correct']].sample(min(10, len(test_df_copy[~test_df_copy['correct']])))
print("Sample Misclassified Examples:\n")
for _, row in misclassified.iterrows():
    print(f"Text: {row['text'][:100]}...")
    print(f"  Actual: {row['actual']} | Predicted: {row['predicted']}\n")