<a target="_blank" href="https://colab.research.google.com/github/ezponda/deep-learning-course/blob/main/book/appendix/A3_model_evaluation.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Model Evaluation & Debugging

Beyond accuracy: metrics, visualization, and troubleshooting.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import seaborn as sns

---

## Why Accuracy Isn't Enough

Consider a fraud detection model with 99% accuracy. Sounds great, right?

But if only 1% of transactions are fraudulent, a model that **always predicts "not fraud"** achieves 99% accuracy!

**Accuracy hides important information** when classes are imbalanced.
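
Here is a minimal sketch with made-up data: 1,000 hypothetical transactions, 10 of them fraudulent, scored by a "model" that always predicts "not fraud". The numbers are illustrative assumptions, not real data.

```python
import numpy as np
from sklearn.metrics import accuracy_score, recall_score

# Hypothetical dataset: 1,000 transactions, 10 of them fraudulent (1%)
y_true = np.zeros(1000, dtype=int)
y_true[:10] = 1

# A "model" that always predicts "not fraud"
y_pred = np.zeros(1000, dtype=int)

acc = accuracy_score(y_true, y_pred)
rec = recall_score(y_true, y_pred)   # fraction of fraud cases actually caught
print(f"accuracy = {acc:.2%}, recall on fraud = {rec:.0%}")
```

Accuracy is 99%, but recall on the fraud class is 0%. The model never catches a single fraudulent transaction.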

---

## The Confusion Matrix

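For binary classification, a confusion matrix counts **all four possible outcomes**. Rows are actual classes and columns are predicted classes, which is the layout scikit-learn's `confusion_matrix` uses: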

```
                    Predicted
                 Neg      Pos
              ┌────────┬────────┐
    Actual    │   TN   │   FP   │  Negative
              ├────────┼────────┤
              │   FN   │   TP   │  Positive
              └────────┴────────┘

TN = True Negative  (correct rejection)
FP = False Positive (false alarm, Type I error)
FN = False Negative (miss, Type II error)
TP = True Positive  (correct detection)
```

In [None]:
# Example: Binary classification results
y_true = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1])
y_pred = np.array([0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1])

cm = confusion_matrix(y_true, y_pred)

plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Predicted Neg', 'Predicted Pos'],
            yticklabels=['Actual Neg', 'Actual Pos'])
plt.title('Confusion Matrix', fontsize=14)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

tn, fp, fn, tp = cm.ravel()
print(f"True Negatives: {tn}, False Positives: {fp}")
print(f"False Negatives: {fn}, True Positives: {tp}")
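
To see what `confusion_matrix` is counting, you can tally the four cells directly with boolean masks. This minimal sketch uses the same labels as the cell above, re-declared so the block runs on its own:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1])
y_pred = np.array([0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1])

# Count each outcome with element-wise boolean masks
tp = np.sum((y_true == 1) & (y_pred == 1))   # correct detection
tn = np.sum((y_true == 0) & (y_pred == 0))   # correct rejection
fp = np.sum((y_true == 0) & (y_pred == 1))   # false alarm
fn = np.sum((y_true == 1) & (y_pred == 0))   # miss

manual_cm = np.array([[tn, fp], [fn, tp]])   # rows = actual, columns = predicted
print(manual_cm)
print("matches sklearn:", np.array_equal(manual_cm, confusion_matrix(y_true, y_pred)))
```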

---

## Precision, Recall, and F1-Score

### Precision (Positive Predictive Value)

"Of all predicted positives, how many are actually positive?"

$$\text{Precision} = \frac{TP}{TP + FP}$$

**High precision = few false alarms**

### Recall (Sensitivity, True Positive Rate)

"Of all actual positives, how many did we catch?"

$$\text{Recall} = \frac{TP}{TP + FN}$$

**High recall = few misses**

### F1-Score (Harmonic Mean)

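The harmonic mean of precision and recall. Unlike the arithmetic mean, it is high only when *both* are high. For example, precision 1.0 with recall 0.1 gives F1 ≈ 0.18, not 0.55: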

$$F1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$$

In [None]:
print("Classification Report:")
print(classification_report(y_true, y_pred, target_names=['Negative', 'Positive']))
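
The report's numbers for the positive class follow directly from the formulas above. This short sketch applies them by hand to the confusion-matrix counts and compares the results against scikit-learn's `precision_score`, `recall_score`, and `f1_score`:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

# Same labels as the confusion-matrix example
y_true = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1])
y_pred = np.array([0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1])

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

# Apply the formulas directly to the counts
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

print(f"manual:  precision={precision:.3f}, recall={recall:.3f}, F1={f1:.3f}")
print(f"sklearn: precision={precision_score(y_true, y_pred):.3f}, "
      f"recall={recall_score(y_true, y_pred):.3f}, F1={f1_score(y_true, y_pred):.3f}")
```

With TP = 8, FP = 2, and FN = 2, all three metrics come out to 0.8 for this example.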

### Precision vs Recall: The Trade-off

| Scenario | Prioritize | Why |
|----------|------------|-----|
| **Spam filter** | Precision | Don't want important emails in spam |
| **Cancer screening** | Recall | Don't want to miss any cases |
| **Fraud detection** | Depends on cost | Balance false alarms vs. missed fraud |
| **Search engine** | Precision | Users want relevant results first |

In [None]:
# Illustrative, hand-picked values (not computed from a model) showing how
# precision and recall typically move as the decision threshold changes
scenarios = [
    ('High Threshold\n(Conservative)', [0.95, 0.40, 0.56]),   # high precision, low recall
    ('Balanced Threshold\n(Default)', [0.75, 0.75, 0.75]),
    ('Low Threshold\n(Aggressive)', [0.45, 0.95, 0.61]),      # low precision, high recall
]

fig, axes = plt.subplots(1, 3, figsize=(14, 4))
for ax, (title, values) in zip(axes, scenarios):
    ax.bar(['Precision', 'Recall', 'F1'], values, color=['green', 'orange', 'blue'])
    ax.set_ylim(0, 1)
    ax.set_title(title, fontsize=12)
    ax.axhline(0.5, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()
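
The bars above use hand-picked numbers. With real scores, the trade-off appears when you sweep the decision threshold. Here is a minimal sketch on simulated scores. The beta distributions are an assumption, chosen only so that positives tend to score higher than negatives:

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

rng = np.random.default_rng(0)
# Simulated scores (assumed distributions): negatives skew low, positives skew high
y_true = np.array([0] * 500 + [1] * 500)
scores = np.concatenate([rng.beta(2, 5, 500), rng.beta(5, 2, 500)])

results = {}
for threshold in [0.3, 0.5, 0.7]:
    y_pred = (scores >= threshold).astype(int)   # predict positive at or above the threshold
    p = precision_score(y_true, y_pred, zero_division=0)
    r = recall_score(y_true, y_pred)
    results[threshold] = (p, r)
    print(f"threshold={threshold:.1f}: precision={p:.2f}, recall={r:.2f}, "
          f"F1={f1_score(y_true, y_pred):.2f}")
```

Raising the threshold can only keep recall the same or lower it, because fewer samples are predicted positive. Precision usually rises, but it is not guaranteed to be monotonic.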

---

## ROC Curve and AUC

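The **ROC curve** (Receiver Operating Characteristic) plots, for every possible decision threshold:
- **True Positive Rate** (Recall) on the y-axis
- **False Positive Rate** (1 - Specificity) on the x-axis

**AUC** (Area Under the ROC Curve) summarizes the whole curve in one number. It equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one:
- AUC = 1.0 → Perfect ranking
- AUC = 0.5 → Random guessing
- AUC < 0.5 → Worse than random (the ranking is inverted; flipping the scores gives 1 - AUC)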

In [None]:
# Simulated probability scores
np.random.seed(42)
y_true_roc = np.array([0]*50 + [1]*50)  # binary labels: 50 negatives, 50 positives
# Good model: positive class has higher scores
y_scores = np.concatenate([
    np.random.beta(2, 5, 50),  # Negative class: lower scores
    np.random.beta(5, 2, 50)   # Positive class: higher scores
])

fpr, tpr, thresholds = roc_curve(y_true_roc, y_scores)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, 'b-', linewidth=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], 'k--', label='Random classifier (AUC = 0.50)')
plt.fill_between(fpr, tpr, alpha=0.3)
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('ROC Curve', fontsize=14)
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()
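
One way to see what AUC measures is to compare every positive score with every negative score and count how often the positive one is higher, with ties counting half. This sketch uses simulated scores similar to the ones above and checks the result against scikit-learn's `roc_auc_score`:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(42)
# Simulated scores: negatives skew low, positives skew high
neg_scores = rng.beta(2, 5, 50)
pos_scores = rng.beta(5, 2, 50)

# Compare every positive with every negative (50 x 50 pairs); ties count as half
diff = pos_scores[:, None] - neg_scores[None, :]
pairwise_auc = np.mean(diff > 0) + 0.5 * np.mean(diff == 0)

y_true = np.array([0] * 50 + [1] * 50)
y_scores = np.concatenate([neg_scores, pos_scores])
print(f"pairwise ranking estimate: {pairwise_auc:.4f}")
print(f"roc_auc_score:             {roc_auc_score(y_true, y_scores):.4f}")
```

The two numbers agree up to floating-point rounding. AUC depends only on how the scores are *ranked*, not on any particular threshold.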

---

## Multi-class Metrics

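For multi-class problems, precision, recall, and F1 are first computed per class (one-vs-rest) and then combined in one of three ways:

| Averaging | Method |
|-----------|--------|
| **Macro** | Unweighted mean of the per-class metrics (treats all classes equally) |
| **Weighted** | Mean of the per-class metrics, weighted by support (number of true samples per class) |
| **Micro** | Pool TP, FP, FN over all classes, then compute the metric once |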
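
The averaging choices differ most when classes are imbalanced. Here is a minimal sketch on a small, made-up 3-class example in which class 0 has 6 of the 10 samples. It uses the `average` parameter of scikit-learn's `precision_score`, `recall_score`, and `f1_score`:

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# Hypothetical imbalanced 3-class example: class 0 dominates (6 of 10 samples)
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 0, 0, 0, 0, 0, 1, 0, 2])

results = {}
for avg in ['macro', 'weighted', 'micro']:
    results[avg] = {
        'precision': precision_score(y_true, y_pred, average=avg),
        'recall': recall_score(y_true, y_pred, average=avg),
        'f1': f1_score(y_true, y_pred, average=avg),
    }
    print(f"{avg:>8}: " + ", ".join(f"{k}={v:.3f}" for k, v in results[avg].items()))

print("accuracy:", accuracy_score(y_true, y_pred))
```

Under macro averaging, the two small classes pull recall down to about 0.67, while the weighted and micro averages are dominated by the large class. For single-label multi-class data, micro-averaged precision, recall, and F1 all equal accuracy (0.8 here). This happens because every wrong prediction counts as exactly one FP for the predicted class and one FN for the true class.
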

In [None]:
# Multi-class example
y_true_multi = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2])
y_pred_multi = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 0, 0, 2, 0, 1, 1])

cm_multi = confusion_matrix(y_true_multi, y_pred_multi)

plt.figure(figsize=(7, 6))
sns.heatmap(cm_multi, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Class 0', 'Class 1', 'Class 2'],
            yticklabels=['Class 0', 'Class 1', 'Class 2'])
plt.title('Multi-class Confusion Matrix', fontsize=14)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

print(classification_report(y_true_multi, y_pred_multi, 
                            target_names=['Class 0', 'Class 1', 'Class 2']))
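
You can read the per-class one-vs-rest counts straight off the multi-class confusion matrix. The diagonal holds TP, the rest of each column holds FP, and the rest of each row holds FN. This sketch reuses the labels above:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Same multi-class labels as the cell above
y_true_multi = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 0, 1, 2])
y_pred_multi = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 0, 0, 2, 0, 1, 1])

cm = confusion_matrix(y_true_multi, y_pred_multi)   # rows = actual, columns = predicted

# One-vs-rest counts for every class at once
tp = np.diag(cm)              # correctly predicted as class k
fp = cm.sum(axis=0) - tp      # predicted as k, but actually another class
fn = cm.sum(axis=1) - tp      # actually k, but predicted as another class

precision_per_class = tp / (tp + fp)
recall_per_class = tp / (tp + fn)
print("TP:", tp, "FP:", fp, "FN:", fn)
print("precision per class:", precision_per_class.round(2))
print("recall per class:   ", recall_per_class.round(2))
```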

---

## Common Training Problems & Solutions

### Problem: Loss Not Decreasing

| Symptom | Likely Cause | Solution |
|---------|--------------|----------|
| Loss decreases very slowly | Learning rate too low | Increase LR |
| Loss oscillates wildly | Learning rate too high | Decrease LR |
| Loss is NaN | Exploding gradients | Lower LR, add gradient clipping |
| Loss stuck at a constant value | Dead neurons (ReLU) | Use LeakyReLU, check initialization |
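
Gradient clipping rescales gradients whose norm grows too large before they are used in the weight update. Keras optimizers expose this through arguments such as `clipnorm` and `global_clipnorm` (e.g. `keras.optimizers.Adam(global_clipnorm=1.0)`). The NumPy sketch below only illustrates the idea of clipping by global norm. The function name and the gradient values are hypothetical:

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Hypothetical helper: rescale a list of gradient arrays so that their
    combined L2 norm is at most max_norm. Returns (clipped_grads, original_norm)."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if global_norm <= max_norm:
        return grads, global_norm
    scale = max_norm / global_norm
    return [g * scale for g in grads], global_norm

# Hypothetical "exploding" gradients for two weight tensors
grads = [np.array([30.0, 40.0]), np.array([[0.0, 0.0]])]
clipped, norm_before = clip_by_global_norm(grads, max_norm=5.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(f"norm before: {norm_before:.1f}, after: {norm_after:.1f}")
```

The clipped gradients keep their direction and only shrink in length. A single huge gradient therefore can no longer throw the weights far off course.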

In [None]:
# Synthetic loss curves illustrating typical learning-rate problems (not from a real model)
epochs = np.arange(50)

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Too low
loss_low = 2.5 - 0.02 * epochs + np.random.normal(0, 0.05, 50)
axes[0].plot(epochs, loss_low, 'b-', linewidth=2)
axes[0].set_title('Learning Rate Too Low\n(Barely decreasing)', fontsize=12)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_ylim(0, 3)

# Good
loss_good = 2.5 * np.exp(-0.1 * epochs) + 0.1 + np.random.normal(0, 0.03, 50)
axes[1].plot(epochs, loss_good, 'g-', linewidth=2)
axes[1].set_title('Good Learning Rate\n(Smooth decrease)', fontsize=12)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_ylim(0, 3)

# Too high
loss_high = 1.5 + 0.8 * np.sin(epochs * 0.5) + np.random.normal(0, 0.2, 50)
axes[2].plot(epochs, loss_high, 'r-', linewidth=2)
axes[2].set_title('Learning Rate Too High\n(Oscillating)', fontsize=12)
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Loss')
axes[2].set_ylim(0, 3)

plt.tight_layout()
plt.show()
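
The curves above are synthetic. You can reproduce the same three regimes with a real update rule: plain gradient descent on the toy loss f(w) = w², chosen only because it can be analyzed by hand. Each step multiplies w by (1 - 2·lr). Training therefore converges for 0 < lr < 1, overshoots and flips the sign of w once lr > 0.5, and diverges once lr > 1:

```python
import numpy as np

def gradient_descent(lr, w0=1.0, steps=50):
    """Minimize the toy loss f(w) = w**2 with plain gradient descent.

    Returns the loss after each of `steps` updates.
    """
    w = w0
    losses = []
    for _ in range(steps):
        grad = 2 * w            # derivative of w**2
        w = w - lr * grad       # same as w *= (1 - 2 * lr)
        losses.append(w ** 2)
    return np.array(losses)

for name, lr in [('too low', 0.01), ('good', 0.4), ('too high', 1.05)]:
    losses = gradient_descent(lr)
    print(f"{name:>8} (lr={lr}): loss after 1 step = {losses[0]:.4f}, "
          f"after 50 steps = {losses[-1]:.3g}")
```

In this toy problem, the too-high learning rate makes the loss grow steadily. In real networks, the same overshooting usually shows up as the noisy, oscillating curve plotted above, and it can end in NaN.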

### Problem: Overfitting

| Symptom | Solution |
|---------|----------|
| Val loss increases while train loss decreases | Early stopping |
| Large gap between train/val accuracy | Add dropout, regularization |
| Model memorizes training data | Get more data, augmentation |
| Near-perfect training accuracy but much lower validation accuracy | Reduce model complexity |
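
Early stopping works by tracking the best validation loss and stopping once `patience` epochs pass without improvement. Below is a simplified pure-Python sketch of that logic on a made-up loss history. It illustrates the idea only and is not Keras's actual `EarlyStopping` implementation:

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of validation losses.

    Training stops once `patience` consecutive epochs fail to improve on the
    best loss by more than `min_delta`.
    """
    best_loss = float('inf')
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss, best_epoch, wait = loss, epoch, 0   # new best: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch   # never triggered: ran all epochs

# Made-up validation loss: improves, then rises as the model starts to overfit
val_loss = [1.00, 0.80, 0.65, 0.58, 0.55, 0.56, 0.58, 0.61, 0.65, 0.70]
stop, best = early_stopping_epoch(val_loss, patience=3)
print(f"stopped at epoch {stop}, best weights from epoch {best}")
```

Training stops at epoch 7, and the best weights come from epoch 4. Rolling back to those weights is what `restore_best_weights=True` does in Keras. The extra epochs only confirm that the validation loss has stopped improving.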

### Problem: Underfitting

| Symptom | Solution |
|---------|----------|
| Both train and val loss stay high | Increase model capacity |
| Model too simple for data | Add more layers/neurons |
| Training stops too early | Increase epochs, tune LR |
| Poor feature representation | Better preprocessing, feature engineering |

---

## Debugging Checklist

```
□ Data
  □ Are inputs normalized? (mean ~0, std ~1)
  □ Are labels correct? (spot-check a few examples)
  □ Is there data leakage? (val/test data in training)
  □ Are classes balanced? (if not, use class weights)

□ Model
  □ Can it overfit a tiny dataset? (sanity check)
  □ Are activation functions appropriate?
  □ Is output layer correct for the task?
  
□ Training
  □ Is loss decreasing? (if not, check LR)
  □ Is validation loss tracked? (use callbacks)
  □ Are gradients flowing? (no NaN, not all zeros)
  
□ Evaluation
  □ Using appropriate metrics? (not just accuracy)
  □ Test set never seen during training?
  □ Results reproducible? (set random seeds)
```
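
The normalization and leakage checks can be handled together. Compute normalization statistics on the training split only, then reuse them for validation and test data so that nothing about those splits leaks into preprocessing. Here is a minimal NumPy sketch with hypothetical features on very different scales:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical raw features on very different scales (e.g. age in years, income in dollars)
X_train = np.column_stack([rng.normal(40, 10, 200), rng.normal(50_000, 15_000, 200)])
X_val = np.column_stack([rng.normal(40, 10, 50), rng.normal(50_000, 15_000, 50)])

# Fit normalization statistics on the training split only, to avoid leaking validation data
mean = X_train.mean(axis=0)
std = X_train.std(axis=0) + 1e-8   # small epsilon guards against zero-variance features

X_train_norm = (X_train - mean) / std
X_val_norm = (X_val - mean) / std   # reuse the training statistics

print("train mean:", X_train_norm.mean(axis=0).round(3), "std:", X_train_norm.std(axis=0).round(3))
print("val   mean:", X_val_norm.mean(axis=0).round(3), "std:", X_val_norm.std(axis=0).round(3))
```

The validation features come out close to, but not exactly at, mean 0 and std 1. That is expected, because they were scaled with the training statistics rather than their own.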

---

## Keras Callbacks for Monitoring

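```python
import numpy as np
from tensorflow import keras

# Placeholder data and model so the snippet runs end to end;
# substitute your own training/validation splits and architecture
rng = np.random.default_rng(0)
X_train, y_train = rng.normal(size=(800, 20)), rng.integers(0, 2, size=800)
X_val, y_val = rng.normal(size=(200, 20)), rng.integers(0, 2, size=200)

model = keras.Sequential([
    keras.Input(shape=(20,)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

callbacks = [
    # Stop after 10 epochs without val_loss improvement; roll back to the best weights
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    ),

    # Halve the learning rate after 5 epochs without val_loss improvement
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5
    ),

    # Save the model with the lowest val_loss seen so far
    keras.callbacks.ModelCheckpoint(
        'best_model.keras',
        monitor='val_loss',
        save_best_only=True
    ),

    # Write logs for TensorBoard
    keras.callbacks.TensorBoard(
        log_dir='./logs'
    )
]

history = model.fit(X_train, y_train,
                    validation_data=(X_val, y_val),
                    callbacks=callbacks,
                    epochs=100)
```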

---

## Key Takeaways

1. **Don't rely on accuracy alone** — use precision, recall, F1 for imbalanced data
2. **Confusion matrix reveals errors** — see exactly what's being confused
3. **ROC-AUC for probability outputs** — threshold-independent performance measure
4. **Watch both train and val loss** — the gap tells you about overfitting
5. **Use the debugging checklist** — systematic approach to fixing problems
6. **Keras callbacks automate monitoring** — EarlyStopping, ReduceLROnPlateau