# Tutorial 15: Cross-Entropy - Code & Analysis

This notebook provides an in-depth exploration of Cross-Entropy, including:
1.  **Numerical Stability**: Why naive implementations fail and how to fix them (LogSumExp).
2.  **Gradient Analysis**: Why Cross-Entropy beats Mean Squared Error (MSE) for classification.
3.  **Focal Loss**: How to modify Cross-Entropy to handle class imbalance.
4.  **Softmax Temperature**: How temperature affects prediction confidence.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F

sns.set_theme(style="whitegrid")

## 1. Numerical Stability (LogSumExp)

A naive implementation of Softmax + Cross-Entropy often leads to numerical instability (overflow/underflow). 

$$ \text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}} $$

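If $x_i$ is large (e.g., 1000), $e^{x_i}$ overflows to `inf` (in float64 this happens once $x_i$ exceeds roughly 709), and the ratio becomes `inf / inf = nan`. If every $x_i$ is large and negative (e.g., -1000), each $e^{x_i}$ underflows to 0 and the ratio becomes `0 / 0`.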

In [None]:
def naive_softmax(logits):
    return np.exp(logits) / np.sum(np.exp(logits))

def stable_softmax(logits):
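    # Max-shift trick (the same shift used in LogSumExp): subtract c = max(logits)
    # so every exponent is <= 0 and exp() cannot overflow. The result is unchanged:
    # e^(x-c) / sum(e^(x-c)) == e^x / sum(e^x)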
    c = np.max(logits)
    exp_logits = np.exp(logits - c)
    return exp_logits / np.sum(exp_logits)

# Test with large values
logits_large = np.array([1000.0, 1001.0, 1002.0])

print("Naive Softmax (Large Input):")
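# NumPy does not raise on overflow by default: exp(1000) becomes inf and
# inf / inf becomes nan, so a try/except would never trigger here.
# errstate silences the resulting RuntimeWarnings; the output is [nan nan nan].
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(logits_large))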
    
print("\nStable Softmax (Large Input):")
print(stable_softmax(logits_large))
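
The same shift can be applied in log space, which is where the LogSumExp name comes from. The block below is a minimal sketch of a cross-entropy computed directly from logits; `log_softmax` and `cross_entropy_from_logits` are illustrative names that do not appear elsewhere in this notebook.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis (illustrative helper)."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    # Every shifted value is <= 0, so exp() cannot overflow, and the max
    # contributes exp(0) = 1, so the sum is >= 1 and log() is safe.
    log_sum_exp = np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    return shifted - log_sum_exp

def cross_entropy_from_logits(logits, target):
    """Cross-entropy of one example; `target` is the true class index."""
    return -log_softmax(logits)[target]

loss_large = cross_entropy_from_logits([1000.0, 1001.0, 1002.0], target=0)
loss_tiny_prob = cross_entropy_from_logits([0.0, -1000.0], target=1)
print(f"Large logits:     {loss_large:.4f}")
print(f"Vanishing p_true: {loss_tiny_prob:.4f}")
```

Staying in log space matters for the second case: the probability of the true class underflows to exactly 0, so taking `np.log` of a softmax output would give `-inf`, while the log-space version returns a finite loss.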

## 2. Visualizing BCE vs. MSE Loss

Why don't we use MSE for classification? Let's plot both losses as a function of the predicted probability, then compare their gradients.

In [None]:
# Predicted probabilities for the correct class (where true label y=1)
p_hat = np.linspace(0.001, 0.999, 500)

# Binary Cross-Entropy Loss for y=1 is -log(p_hat)
bce_loss = -np.log(p_hat)

# Mean Squared Error Loss for y=1 is (1-p_hat)^2
mse_loss = (1 - p_hat)**2

plt.figure(figsize=(10, 6))
plt.plot(p_hat, bce_loss, label='Binary Cross-Entropy Loss', color='#d62728', lw=3)
plt.plot(p_hat, mse_loss, label='Mean Squared Error Loss', color='#1f77b4', lw=3, linestyle='--')

plt.xlabel("Predicted Probability for Correct Class (p̂)", fontsize=12)
plt.ylabel("Loss Value", fontsize=12)
plt.title("BCE vs. MSE Loss (When True Label is 1)", fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.ylim(0, 5) 
plt.show()
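
To put numbers on the curves, the short sketch below evaluates both losses at a few probabilities for the correct class (true label 1). The chosen probabilities are arbitrary illustration values.

```python
import numpy as np

# Predicted probability for the correct class, from uncertain to confidently wrong.
probs = np.array([0.5, 0.1, 0.01, 0.001])
bce_values = -np.log(probs)        # grows without bound as p -> 0
mse_values = (1.0 - probs) ** 2    # can never exceed 1

for p_val, bce_val, mse_val in zip(probs, bce_values, mse_values):
    print(f"p = {p_val:.3f}   BCE = {bce_val:.3f}   MSE = {mse_val:.3f}")
```

MSE saturates just below 1 however wrong the prediction is, while BCE keeps increasing the penalty as the prediction becomes more confidently wrong.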

### Gradient Magnitude Analysis

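Models learn via gradients: where the gradient vanishes, the weights stop updating.

Let $p = \sigma(z)$, with BCE loss $-[y \log p + (1-y)\log(1-p)]$ and MSE loss $(p - y)^2$.
-   Gradient magnitude of BCE w.r.t. the logit $z$: $|p - y|$
-   Gradient magnitude of MSE w.r.t. the logit $z$: $|2(p - y)\,p(1-p)|$

The extra factor $p(1-p)$ in the MSE gradient goes to zero whenever $p$ is close to 0 or 1, including when the model is confidently wrong.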
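
For reference, both magnitudes follow from the chain rule with $\frac{dp}{dz} = p(1-p)$, assuming the per-example losses $L_{\text{BCE}} = -[y \log p + (1-y)\log(1-p)]$ and $L_{\text{MSE}} = (p - y)^2$ (the MSE form that produces the factor of 2 used in the code):

$$ \frac{\partial L_{\text{BCE}}}{\partial z} = \frac{\partial L_{\text{BCE}}}{\partial p} \cdot \frac{dp}{dz} = \frac{p - y}{p(1-p)} \cdot p(1-p) = p - y $$

$$ \frac{\partial L_{\text{MSE}}}{\partial z} = \frac{\partial L_{\text{MSE}}}{\partial p} \cdot \frac{dp}{dz} = 2(p - y) \cdot p(1-p) $$

The sigmoid derivative cancels in the BCE case and survives in the MSE case.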

In [None]:
def sigmoid(z):
    return 1 / (1 + np.exp(-z))

z = np.linspace(-10, 10, 500)
p = sigmoid(z)

# Gradient of BCE w.r.t z (target y=1)
# dL/dz = p - 1
grad_bce = np.abs(p - 1)

# Gradient of MSE w.r.t z (target y=1)
# dL/dz = -2(1-p) * p(1-p)
grad_mse = np.abs(-2 * (1 - p) * p * (1 - p))

plt.figure(figsize=(10, 6))
plt.plot(z, grad_bce, label='Gradient of BCE', color='#d62728', lw=3)
plt.plot(z, grad_mse, label='Gradient of MSE', color='#1f77b4', lw=3, linestyle='--')

plt.xlabel("Logit input (z)", fontsize=12)
plt.ylabel("Gradient Magnitude |dL/dz|", fontsize=12)
plt.title("Gradient Strength: BCE vs MSE", fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)

# Highlight the "Confidently Wrong" area
plt.axvspan(-10, -5, color='gray', alpha=0.15)
plt.text(-7.5, 0.5, "Confidently Wrong\n(Vanishing MSE Grad)", fontsize=11, ha='center', fontweight='bold')

plt.show()
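
The analytic gradients plotted above can be sanity-checked numerically. The sketch below compares them against central finite differences for target y = 1; `central_difference`, `bce`, and `mse` are hypothetical helpers written only for this check.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce(z, y):
    p = sigmoid(z)
    return -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

def mse(z, y):
    return (sigmoid(z) - y) ** 2

def central_difference(loss_fn, z, y, eps=1e-5):
    """Approximate dL/dz numerically (hypothetical helper for this check)."""
    return (loss_fn(z + eps, y) - loss_fn(z - eps, y)) / (2.0 * eps)

z_check = np.array([-8.0, -2.0, 0.0, 2.0])
y_true = 1.0
p_check = sigmoid(z_check)

num_bce = central_difference(bce, z_check, y_true)
num_mse = central_difference(mse, z_check, y_true)
ana_bce = p_check - y_true
ana_mse = 2.0 * (p_check - y_true) * p_check * (1.0 - p_check)

print("BCE max abs error:", np.max(np.abs(num_bce - ana_bce)))
print("MSE max abs error:", np.max(np.abs(num_mse - ana_mse)))
print("|grad| at z = -8:  BCE =", abs(ana_bce[0]), " MSE =", abs(ana_mse[0]))
```

At z = -8 the model is confidently wrong: the BCE gradient magnitude is close to 1, while the MSE gradient is below 0.001.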

## 3. Focal Loss: Handling Imbalance

When you have many easy examples (background) and few hard examples (objects), standard Cross-Entropy can be overwhelmed by the easy negatives.

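**Focal Loss** multiplies the Cross-Entropy term by a modulating factor $(1 - p_t)^\gamma$, where $p_t$ is the predicted probability of the ground-truth class and $\gamma \ge 0$. Easy examples ($p_t$ near 1) are down-weighted, and $\gamma = 0$ recovers standard Cross-Entropy.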

$$ FL(p_t) = -(1 - p_t)^\gamma \log(p_t) $$

In [None]:
p_t = np.linspace(0.001, 0.999, 500)
ce_loss = -np.log(p_t)

# Focal Loss with different gammas
gamma_0 = ce_loss # Gamma=0 is just CE
gamma_1 = -((1 - p_t)**1) * np.log(p_t)
gamma_2 = -((1 - p_t)**2) * np.log(p_t)
gamma_5 = -((1 - p_t)**5) * np.log(p_t)

plt.figure(figsize=(10, 6))
plt.plot(p_t, gamma_0, label='Cross Entropy (Gamma=0)', color='black', linestyle='--', lw=2)
plt.plot(p_t, gamma_1, label='Focal Loss (Gamma=1)', color='#fdae61', lw=2.5)
plt.plot(p_t, gamma_2, label='Focal Loss (Gamma=2)', color='#d7191c', lw=2.5)
plt.plot(p_t, gamma_5, label='Focal Loss (Gamma=5)', color='#2c7bb6', lw=2.5)

plt.xlabel("Probability of Ground Truth Class (p_t)", fontsize=12)
plt.ylabel("Loss", fontsize=12)
plt.title("Focal Loss vs Cross Entropy", fontsize=14)
plt.legend()
plt.grid(True, alpha=0.5)
plt.ylim(0, 5)

plt.text(0.6, 2.0, "Easy Examples (High p_t)\nare down-weighted", fontsize=11, bbox=dict(facecolor='white', alpha=0.8))

plt.show()
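
The curves above take $p_t$ as given. In training, $p_t$ comes from the logits, so a usable loss function has to compute it stably. The following is a minimal NumPy sketch under that assumption; the function name `focal_loss` and the example logits are illustrative, and the class-balancing weight used in some focal loss variants is left out to match the formula in this section.

```python
import numpy as np

def focal_loss(logits, targets, gamma=2.0):
    """Mean focal loss for integer class targets (illustrative sketch).

    logits:  array of shape (N, C)
    targets: integer array of shape (N,)
    """
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets)
    # Stable log-softmax: shift by the row max before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # log p_t: log-probability assigned to the true class of each sample.
    log_pt = log_probs[np.arange(len(targets)), targets]
    pt = np.exp(log_pt)
    return np.mean(-((1.0 - pt) ** gamma) * log_pt)

example_logits = np.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
example_targets = np.array([0, 1])

ce_value = focal_loss(example_logits, example_targets, gamma=0.0)  # plain CE
fl_value = focal_loss(example_logits, example_targets, gamma=2.0)
print(f"gamma=0 (Cross-Entropy): {ce_value:.4f}")
print(f"gamma=2 (Focal Loss):    {fl_value:.4f}")
```

Both samples here are classified correctly with moderate confidence, so the focal loss is much smaller than the Cross-Entropy value.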

## 4. PyTorch: Categorical vs Sparse

A common source of confusion is the target format. `nn.CrossEntropyLoss` takes raw logits and, by default, integer class indices (the "sparse" format in Keras terminology), whereas many custom implementations expect one-hot vectors.

In [None]:
# 1. Standard PyTorch CrossEntropyLoss
# Expects: Logits (N, C) and Target Indices (N)
criterion = nn.CrossEntropyLoss()

logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]) # 2 samples, 3 classes
targets = torch.tensor([0, 1]) # Class 0 for first, Class 1 for second

loss = criterion(logits, targets)
print(f"Standard Loss (indices): {loss.item():.4f}")

# 2. If you have one-hot or soft targets (e.g. from Mixup augmentation):
# PyTorch 1.10+ also accepts class probabilities of shape (N, C) as the target.
# On older versions, compute -(target * log_softmax(logits)).sum(dim=1) manually.
# For plain hard labels, just use indices.

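# Let's verify manually. The per-sample loss is -log softmax(logits)[target],
# and the default reduction="mean" averages it over the batch, so the check
# must average over both samples to match the value printed above.
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(len(targets)), targets].mean()
print(f"Manual Check: {manual_loss.item():.4f}")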
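
For one-hot or Mixup-style soft targets, the loss can be written out by hand in a way that does not depend on the PyTorch version. The sketch below assumes targets are given as per-class probabilities of shape (N, C); `soft_target_cross_entropy` is a hypothetical helper name, and the 70/30 mix is an arbitrary illustration.

```python
import torch
import torch.nn.functional as F

def soft_target_cross_entropy(logits, target_probs):
    """Mean cross-entropy against per-class target probabilities of shape (N, C)."""
    log_probs = F.log_softmax(logits, dim=1)
    return -(target_probs * log_probs).sum(dim=1).mean()

logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
targets = torch.tensor([0, 1])

# One-hot targets should reproduce the index-based loss.
one_hot = F.one_hot(targets, num_classes=3).float()
loss_indices = F.cross_entropy(logits, targets)
loss_one_hot = soft_target_cross_entropy(logits, one_hot)

# A Mixup-style soft target: 70% class 0 and 30% class 1 for both samples.
mixed = torch.tensor([[0.7, 0.3, 0.0], [0.7, 0.3, 0.0]])
loss_mixed = soft_target_cross_entropy(logits, mixed)

print(f"Indices: {loss_indices.item():.4f}")
print(f"One-hot: {loss_one_hot.item():.4f}")
print(f"Mixed:   {loss_mixed.item():.4f}")
```

The first two values agree because a one-hot row selects exactly one log-probability per sample, which is what the index-based loss does.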