# Tutorial 15: Cross-Entropy — Code

This notebook provides interactive code examples for the Cross-Entropy tutorial.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="whitegrid")

## 1. Cross-Entropy Implementation

This corresponds to Exercise B1. We'll create a numerically stable function to calculate cross-entropy.
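
For reference, the quantity the function below computes is the cross-entropy of the predicted distribution Q relative to the true distribution P, using the natural logarithm (matching `np.log`):

$$
H(P, Q) = -\sum_{i} P_i \ln Q_i
$$

For a one-hot P this reduces to $-\ln Q_k$, where $k$ is the index of the true class.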

In [None]:
def cross_entropy(P, Q):
    """
    Calculates the cross-entropy between two probability distributions.
    
    Args:
        P (np.array): The true distribution (e.g. one-hot).
        Q (np.array): The predicted distribution.
        
    Returns:
        float: The cross-entropy loss.
    """
    # Clip Q into [epsilon, 1 - epsilon] so np.log never receives 0,
    # which would yield -inf (or nan once multiplied by a zero in P).
    epsilon = 1e-9
    Q = np.clip(Q, epsilon, 1. - epsilon)
    
    return -np.sum(P * np.log(Q))

### Verification

Let's test the function with the example from Exercise A3.

In [None]:
P_banana = np.array([0, 1, 0])
Q_banana = np.array([0.2, 0.5, 0.3])

loss = cross_entropy(P_banana, Q_banana)
print(f"Calculated Loss: {loss:.3f}")
print(f"Expected Loss (-ln(0.5)): {-np.log(0.5):.3f}")
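
To see what the clipping step buys us, here is a minimal sketch comparing the clipped and unclipped computations when the prediction assigns zero probability to the true class. The helper name `cross_entropy_clipped` and the input `Q_zero` are illustrative and not part of the exercises.

```python
import numpy as np

def cross_entropy_clipped(P, Q, epsilon=1e-9):
    # Same logic as cross_entropy above, repeated so this cell runs on its own.
    Q = np.clip(Q, epsilon, 1.0 - epsilon)
    return -np.sum(P * np.log(Q))

P = np.array([0.0, 1.0, 0.0])
Q_zero = np.array([0.7, 0.0, 0.3])  # hypothetical prediction: zero mass on the true class

# Without clipping, log(0) = -inf and the loss becomes infinite.
with np.errstate(divide="ignore"):
    unclipped = -np.sum(P * np.log(Q_zero))

# With clipping, the loss is capped at -ln(epsilon), about 20.72 for epsilon = 1e-9.
clipped = cross_entropy_clipped(P, Q_zero)

print(f"Unclipped loss: {unclipped}")
print(f"Clipped loss:   {clipped:.2f}")
```

The cap is large but finite, so a single degenerate prediction cannot turn an averaged loss into `inf`.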

## 2. Visualizing BCE vs. MSE Loss

This plot, from Exercise B2, is crucial for understanding why Cross-Entropy is preferred over Mean Squared Error for classification. It shows how each loss function penalizes incorrect predictions.

In [None]:
# Predicted probabilities for the correct class (where true label y=1)
p_hat = np.linspace(0.01, 0.99, 200)

# Binary Cross-Entropy Loss for y=1 is -log(p_hat)
bce_loss = -np.log(p_hat)

# Mean Squared Error Loss for y=1 is (1-p_hat)^2
mse_loss = (1 - p_hat)**2

plt.figure(figsize=(12, 7))
plt.plot(p_hat, bce_loss, label='Binary Cross-Entropy Loss', color='darkred', lw=2.5)
plt.plot(p_hat, mse_loss, label='Mean Squared Error Loss', color='darkblue', lw=2.5, linestyle='--')

plt.xlabel("Predicted Probability for Correct Class (p̂)", fontsize=12)
plt.ylabel("Loss Value", fontsize=12)
plt.title("BCE vs. MSE Loss (When True Label is 1)", fontsize=14)
plt.legend()
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.ylim(0, 5) # Limit y-axis to better see the shapes
plt.show()

### Analysis of the Plot

1.  **When the prediction is good (p̂ → 1)**: Both losses approach 0.
2.  **When the prediction is bad (p̂ → 0)**:
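    -   **Cross-Entropy** grows without bound, since -log(p̂) → ∞. A confident mistake is penalized heavily, and with a sigmoid output the gradient with respect to the logit, p̂ - 1, stays close to -1, so the model keeps learning from its mistake.
    -   **MSE** stays bounded: it never exceeds 1. With a sigmoid output, its gradient with respect to the logit, -2 p̂ (1 - p̂)², shrinks toward 0 as p̂ → 0, meaning the model learns very slowly from its most confident mistakes. This is the primary reason MSE is a poor choice for classification.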
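
The gradient behavior described above can be checked numerically. This is a minimal sketch under one assumption the plot itself does not make: that p̂ is produced by a sigmoid applied to a logit z. The names `sigmoid`, `grad_bce`, and `grad_mse` are illustrative.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Logits ranging from a very confident wrong prediction to a confident right one (y = 1).
z = np.array([-6.0, -3.0, 0.0, 3.0])
p_hat = sigmoid(z)

# Gradients with respect to the logit z, for true label y = 1:
#   BCE: d/dz [-log(sigmoid(z))]     = p_hat - 1
#   MSE: d/dz [(1 - sigmoid(z))**2]  = -2 * p_hat * (1 - p_hat)**2
grad_bce = p_hat - 1.0
grad_mse = -2.0 * p_hat * (1.0 - p_hat) ** 2

for zi, pi, gb, gm in zip(z, p_hat, grad_bce, grad_mse):
    print(f"z={zi:+.1f}  p_hat={pi:.4f}  dBCE/dz={gb:+.4f}  dMSE/dz={gm:+.4f}")
```

At z = -6 the BCE gradient is close to -1 while the MSE gradient is close to 0, which is the slow-learning effect on confident mistakes.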