
# MNIST Softmax Classifier — Forward/Backward Pass (Step‑by‑Step) + Animations

You’ll build and **understand** a single‑layer neural network (softmax classifier / multinomial logistic regression) for MNIST **from scratch** in NumPy.  
We’ll cover:
1. Data prep (flattening 28×28 → 784).
2. Forward pass: \( \mathbf{Z} = \mathbf{X}\mathbf{W} + \mathbf{b} \), \( \hat{\mathbf{Y}} = \mathrm{softmax}(\mathbf{Z}) \).
3. Loss: average cross‑entropy.
4. Backprop: gradients for \( \mathbf{W}, \mathbf{b} \) via the compact **softmax + CE** identity.
5. Training loop with mini‑batch gradient descent.
6. **Matplotlib animations** for loss and for a single example’s probability vector during training.

> **Network layout**: `Input 784 → Linear (784×10) → Softmax → 10 classes`. No hidden layers.
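Counting parameters makes the layout concrete. The only learnable tensors are \(W\) (784×10) and \(b\) (1×10). Here is a quick check with zero‑filled placeholder arrays of those shapes. They are placeholders only, not the trained values:

```python
import numpy as np

num_features, num_classes = 784, 10

# Placeholder parameters with the shapes used throughout this notebook (values irrelevant)
W = np.zeros((num_features, num_classes))
b = np.zeros((1, num_classes))

num_params = W.size + b.size
print(num_params)  # 7850 = 784*10 + 10
```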


In [None]:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

np.random.seed(42)



## 1. Math — Forward pass and Loss

Let \(X\in\mathbb{R}^{m\times 784}\) be a batch of \(m\) flattened images, \(W\in\mathbb{R}^{784\times 10}\) the weight matrix, and \(b\in\mathbb{R}^{1\times 10}\) the bias row vector, which is broadcast (added) to every row.  
The **logits** (pre‑softmax scores) are
$$
\mathbf{Z} = \mathbf{X}\mathbf{W} + \mathbf{b} \quad\in\mathbb{R}^{m\times 10}.
$$
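Softmax (row‑wise) yields class probabilities. Subtracting each row's maximum leaves the result unchanged, because numerator and denominator are both scaled by \(e^{-\max_j Z_{i,j}}\). It also keeps \(e^{Z}\) from overflowing: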
$$
\hat{\mathbf{Y}}_{i,c} = \frac{e^{Z_{i,c} - \max_j Z_{i,j}}}{\sum_{k=1}^{10} e^{Z_{i,k}-\max_j Z_{i,j}}}.
$$
With one‑hot labels \(Y\in\mathbb{R}^{m\times 10}\), the average **cross‑entropy** loss is
$$
\mathcal{L} = -\frac{1}{m}\sum_{i=1}^m \sum_{c=1}^{10} Y_{i,c}\,\log \hat{Y}_{i,c}.
$$
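The three formulas above can be run on a toy batch. This is a minimal sketch with made‑up numbers: 2 examples, 3 features and 4 classes stand in for \(m\), 784 and 10. It also shows that shifting a row of logits by a constant leaves softmax unchanged:

```python
import numpy as np

def stable_softmax(Z):
    # Subtract each row's max before exponentiating; the probabilities are unchanged
    Z_shift = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z_shift)
    return E / E.sum(axis=1, keepdims=True)

# Toy batch: m=2 examples, 3 features, 4 classes (stand-ins for 784 and 10)
X = np.array([[1.0, 0.0, 2.0],
              [0.5, -1.0, 0.0]])
W = np.array([[ 0.2, -0.1,  0.0, 0.3],
              [ 0.0,  0.4, -0.2, 0.1],
              [-0.3,  0.1,  0.5, 0.0]])
b = np.array([[0.1, 0.0, -0.1, 0.0]])    # shape (1, 4), broadcast over rows
Y = np.array([[0, 0, 1, 0],               # one-hot labels
              [1, 0, 0, 0]], dtype=float)

Z = X @ W + b                              # logits, shape (2, 4)
Y_hat = stable_softmax(Z)                  # probabilities, each row sums to 1
loss = -np.mean(np.sum(Y * np.log(Y_hat), axis=1))   # average cross-entropy

# Adding 1000 to every logit would overflow a naive exp, but not the shifted version
same = np.allclose(stable_softmax(Z + 1000.0), Y_hat)
print(Z.shape, Y_hat.sum(axis=1), round(float(loss), 4), same)
```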



## 2. Backprop — compact gradient
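For softmax and cross‑entropy together, the gradient of the mean loss with respect to the logits collapses to
$$
\boxed{\;\frac{\partial \mathcal{L}}{\partial \mathbf{Z}} = \frac{1}{m}\left(\hat{\mathbf{Y}} - \mathbf{Y}\right)\;}
$$
Each example contributes \(\hat{\mathbf{Y}}_i - \mathbf{Y}_i\). The factor \(\frac{1}{m}\) comes from averaging the loss over the batch.
Then, since \(\mathbf{Z} = \mathbf{X}\mathbf{W} + \mathbf{b}\) with \(\mathbf{b}\) added to every row, we have \(\partial\mathcal{L}/\partial\mathbf{W} = \mathbf{X}^\top\,\partial\mathcal{L}/\partial\mathbf{Z}\). Likewise, \(\partial\mathcal{L}/\partial\mathbf{b}\) is the column sum of \(\partial\mathcal{L}/\partial\mathbf{Z}\):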
$$
\frac{\partial \mathcal{L}}{\partial \mathbf{W}} = \frac{1}{m}\mathbf{X}^\top(\hat{\mathbf{Y}}-\mathbf{Y}), 
\qquad
\frac{\partial \mathcal{L}}{\partial \mathbf{b}} = \frac{1}{m}\sum_{i=1}^m (\hat{\mathbf{Y}}_i-\mathbf{Y}_i).
$$
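The softmax + cross‑entropy identity can be checked by hand. What follows is a short sketch of the standard derivation, done per example and then averaged.

Take one example with logits \(\mathbf{z}\), probabilities \(\hat{y}_c = e^{z_c}/\sum_k e^{z_k}\), one‑hot label \(\mathbf{y}\) and loss \(\ell = -\sum_c y_c \log \hat{y}_c\). Since \(\log \hat{y}_c = z_c - \log\sum_k e^{z_k}\), we get \(\partial \log \hat{y}_c / \partial z_j = \delta_{cj} - \hat{y}_j\). Therefore
$$
\frac{\partial \ell}{\partial z_j} = -\sum_{c=1}^{10} y_c\,(\delta_{cj} - \hat{y}_j) = -y_j + \hat{y}_j \sum_{c=1}^{10} y_c = \hat{y}_j - y_j,
$$
where the last step uses \(\sum_c y_c = 1\) for a one‑hot label.

Now average over the batch. We have \(\mathcal{L} = \frac{1}{m}\sum_{i=1}^m \ell_i\), and each \(\ell_i\) depends only on row \(i\) of \(\mathbf{Z}\). Hence \(\partial \mathcal{L}/\partial Z_{i,j} = \frac{1}{m}(\hat{Y}_{i,j} - Y_{i,j})\). The chain rule through \(Z_{i,j} = \sum_{p} X_{i,p} W_{p,j} + b_j\) then gives
$$
\frac{\partial \mathcal{L}}{\partial W_{p,j}} = \sum_{i=1}^m X_{i,p}\,\frac{\partial \mathcal{L}}{\partial Z_{i,j}} = \frac{1}{m}\left[\mathbf{X}^\top(\hat{\mathbf{Y}}-\mathbf{Y})\right]_{p,j},
\qquad
\frac{\partial \mathcal{L}}{\partial b_j} = \sum_{i=1}^m \frac{\partial \mathcal{L}}{\partial Z_{i,j}} = \frac{1}{m}\sum_{i=1}^m (\hat{Y}_{i,j}-Y_{i,j}).
$$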

**Update rule (GD):**
$$
\mathbf{W} \leftarrow \mathbf{W} - \eta\,\frac{\partial \mathcal{L}}{\partial \mathbf{W}},\qquad
\mathbf{b} \leftarrow \mathbf{b} - \eta\,\frac{\partial \mathcal{L}}{\partial \mathbf{b}}.
$$
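A finite‑difference gradient check is a cheap way to confirm the formulas above before relying on them for MNIST. This sketch uses a tiny random problem; the sizes, the seed and the helper names `loss` and `numeric_grad` are illustrative assumptions. After the check, it applies one step of the update rule:

```python
import numpy as np

rng = np.random.default_rng(0)
m, d, k = 5, 4, 3                          # small stand-ins for (batch, 784, 10)
X = rng.normal(size=(m, d))
Y = np.eye(k)[rng.integers(0, k, size=m)]  # random one-hot labels
W = rng.normal(scale=0.1, size=(d, k))
b = np.zeros((1, k))

def loss():
    # Mean cross-entropy at the current W, b (read from the enclosing scope)
    Z = X @ W + b
    Z = Z - Z.max(axis=1, keepdims=True)
    log_probs = Z - np.log(np.exp(Z).sum(axis=1, keepdims=True))
    return -np.mean(np.sum(Y * log_probs, axis=1))

def numeric_grad(f, A, eps=1e-6):
    # Central differences: nudge one entry of A in place, evaluate f, then restore it
    G = np.zeros_like(A)
    for idx in np.ndindex(*A.shape):
        old = A[idx]
        A[idx] = old + eps
        f_plus = f()
        A[idx] = old - eps
        f_minus = f()
        A[idx] = old
        G[idx] = (f_plus - f_minus) / (2 * eps)
    return G

# Analytic gradients: dW = X^T (Y_hat - Y) / m, db = column sums of (Y_hat - Y) / m
Z = X @ W + b
P = np.exp(Z - Z.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
dW = X.T @ (P - Y) / m
db = (P - Y).sum(axis=0, keepdims=True) / m

max_err_W = np.max(np.abs(dW - numeric_grad(loss, W)))
max_err_b = np.max(np.abs(db - numeric_grad(loss, b)))
print(f"max abs error: dW {max_err_W:.1e}, db {max_err_b:.1e}")

# One gradient-descent step with a small learning rate should lower the loss
eta = 0.1
loss_before = loss()
W -= eta * dW
b -= eta * db
loss_after = loss()
print(f"loss {loss_before:.4f} -> {loss_after:.4f}")
```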



## 3. Load MNIST (Keras) and Prepare Data

> If your environment is offline, this cell may fail to download MNIST. In that case, replace it with any local MNIST loader or pre‑placed arrays.


In [4]:

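# Prefer the project-local loader; fall back to Keras' built-in MNIST download
try:
    from utils import get_mnist_corrected
    x_train, y_train, x_test, y_test = get_mnist_corrected()
except ImportError:
    from tensorflow.keras.datasets import mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize to [0,1], flatten to (m, 784)
x_train = (x_train.astype(np.float32) / 255.0).reshape(x_train.shape[0], -1)
x_test  = (x_test.astype(np.float32)  / 255.0).reshape(x_test.shape[0], -1)

num_features = x_train.shape[1]  # 784
num_classes = 10

# One-hot labels: row r of the identity matrix is the one-hot vector for class r
y_train_oh = np.eye(num_classes, dtype=np.float32)[y_train]
y_test_oh  = np.eye(num_classes, dtype=np.float32)[y_test]

x_train.shape, y_train_oh.shape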


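The same preprocessing can be checked on a synthetic stand‑in, with no download needed. This sketch uses three random 28×28 `uint8` "images" and made‑up labels, purely for illustration. It confirms the value range and the one‑hot construction. It also confirms that `reshape` flattens row by row, so image row \(r\) lands in columns \(28r\) to \(28r+27\):

```python
import numpy as np

# Synthetic stand-in for MNIST: three 28x28 uint8 "images" with made-up labels
rng = np.random.default_rng(0)
imgs = rng.integers(0, 256, size=(3, 28, 28), dtype=np.uint8)
labels = np.array([7, 2, 0])

# Same preprocessing as above: scale to [0, 1], then flatten each image to 784 values
X = (imgs.astype(np.float32) / 255.0).reshape(imgs.shape[0], -1)   # (3, 784)

# Row r of np.eye(10) is the one-hot vector for class r, so fancy indexing builds Y
Y = np.eye(10, dtype=np.float32)[labels]                             # (3, 10)

# reshape is row-major (C order): image row 5 occupies columns 140..167
row_matches = np.array_equal(X[0, 28 * 5 : 28 * 6], imgs[0, 5].astype(np.float32) / 255.0)

print(X.shape, float(X.min()) >= 0.0, float(X.max()) <= 1.0)
print(Y.argmax(axis=1))   # recovers the labels
print(row_matches)
```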


## 4. Model pieces (NumPy)


In [None]:

def softmax(z):
    # Numerically stable row-wise softmax
    z_shift = z - np.max(z, axis=1, keepdims=True)
    exp_z = np.exp(z_shift)
    return exp_z / np.sum(exp_z, axis=1, keepdims=True)

def cross_entropy(y_true_oh, y_prob):
    # Mean CE over batch
    return -np.mean(np.sum(y_true_oh * np.log(y_prob + 1e-9), axis=1))

def accuracy(y_true, y_prob):
    preds = np.argmax(y_prob, axis=1)
    return np.mean(preds == y_true)
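The `1e-9` in `cross_entropy` guards against `log(0)`, but it caps the per‑example loss at about \(-\log(10^{-9}) \approx 20.72\). An alternative, swapped in here and not what this notebook uses, computes the loss directly from the logits via log‑softmax. That way a probability of exactly zero is never formed. This is a minimal sketch; `log_softmax` and `cross_entropy_from_logits` are hypothetical helper names:

```python
import numpy as np

def log_softmax(z):
    # Row-wise log softmax: (z - max) - log(sum(exp(z - max)))
    z_shift = z - np.max(z, axis=1, keepdims=True)
    return z_shift - np.log(np.sum(np.exp(z_shift), axis=1, keepdims=True))

def cross_entropy_from_logits(y_true_oh, z):
    # Mean cross-entropy over the batch, taking logits instead of probabilities
    return -np.mean(np.sum(y_true_oh * log_softmax(z), axis=1))

# Extreme logits: the true class (index 2) gets a very low score
z = np.array([[1000.0, 0.0, -1000.0]])
y = np.array([[0.0, 0.0, 1.0]])

ce_logits = cross_entropy_from_logits(y, z)   # exact: 2000
p = np.exp(log_softmax(z))                    # underflows to [1, 0, 0]
ce_eps = -np.mean(np.sum(y * np.log(p + 1e-9), axis=1))   # epsilon version saturates
print(ce_logits, ce_eps)
```

For ordinary logits the two versions agree to within the epsilon; they differ only when the model is confidently wrong.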



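## 5. Training Loop (Mini‑batch by default)

We record `loss_history`, `acc_history`, and snapshots of **a single example’s** probability vector each epoch to animate later. Setting `batch_size` to the training‑set size (`m`) turns the loop into full‑batch gradient descent.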


In [None]:

# Hyperparameters
learning_rate = 0.1
epochs = 25           # keep modest for demo/animation speed
batch_size = 1024     # use mini-batches for speed/stability

m = x_train.shape[0]

# Initialize W, b
rng = np.random.default_rng(42)
W = rng.normal(0, 0.01, size=(num_features, num_classes))
b = np.zeros((1, num_classes), dtype=np.float32)

# Choose one fixed example to visualize probabilities over time
viz_idx = 0
viz_x = x_train[viz_idx:viz_idx+1]          # (1,784)
viz_label = y_train[viz_idx]

loss_history = []
acc_history = []
viz_prob_history = []   # list of (10,) arrays

# Mini-batch iterator; shuffles with the same Generator (rng) used to initialize W
def iterate_minibatches(X, Y, batch_size, shuffle=True):
    n = X.shape[0]
    indices = rng.permutation(n) if shuffle else np.arange(n)
    for start in range(0, n, batch_size):
        batch_idx = indices[start:start + batch_size]  # last batch may be smaller
        yield X[batch_idx], Y[batch_idx]

for epoch in range(epochs):
    # Training epoch (mini-batch SGD)
    for Xb, Yb in iterate_minibatches(x_train, y_train_oh, batch_size, shuffle=True):
        scores = Xb @ W + b                 # (B,10)
        probs  = softmax(scores)            # (B,10)

        # Gradients
        d_scores = probs - Yb               # (B,10)
        dW = (Xb.T @ d_scores) / Xb.shape[0]
        db = np.sum(d_scores, axis=0, keepdims=True) / Xb.shape[0]

        # Update
        W -= learning_rate * dW
        b -= learning_rate * db

    # End-of-epoch metrics (full train set for simplicity)
    train_scores = x_train @ W + b
    train_probs  = softmax(train_scores)
    L = cross_entropy(y_train_oh, train_probs)
    A = accuracy(y_train, train_probs)
    loss_history.append(L)
    acc_history.append(A)

    # Store viz probabilities for the fixed example
    viz_prob = softmax(viz_x @ W + b)[0]  # (10,)
    viz_prob_history.append(viz_prob.copy())

    print(f"Epoch {epoch+1:02d}/{epochs}  |  loss={L:.4f}  acc={A*100:.2f}%")

# Final test accuracy (optional)
test_probs = softmax(x_test @ W + b)
test_acc = accuracy(y_test, test_probs)
print(f"Test accuracy: {test_acc*100:.2f}%")
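The body of the inner loop is one gradient‑descent step. Pulled out into a function, it can be sanity‑checked on a small problem before spending time on the full training set. In this sketch, `train_step` and `mean_ce` are hypothetical names, and the 2‑D Gaussian blobs are synthetic data chosen for illustration:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def train_step(W, b, Xb, Yb, lr):
    # One mini-batch gradient-descent step; returns updated copies of W and b
    probs = softmax(Xb @ W + b)
    d_scores = (probs - Yb) / Xb.shape[0]      # dL/dZ for the batch-mean loss
    dW = Xb.T @ d_scores
    db = d_scores.sum(axis=0, keepdims=True)
    return W - lr * dW, b - lr * db

def mean_ce(W, b, X, Y):
    p = softmax(X @ W + b)
    return -np.mean(np.sum(Y * np.log(p + 1e-9), axis=1))

# Synthetic 3-class data: Gaussian blobs around three centers in 2-D
rng = np.random.default_rng(0)
centers = np.array([[2.0, 0.0], [-2.0, 0.0], [0.0, 2.0]])
labels = rng.integers(0, 3, size=300)
X = centers[labels] + rng.normal(scale=0.5, size=(300, 2))
Y = np.eye(3)[labels]

W, b = np.zeros((2, 3)), np.zeros((1, 3))
losses = [mean_ce(W, b, X, Y)]                 # starts at log(3) with zero weights
for epoch in range(20):
    perm = rng.permutation(len(X))
    for start in range(0, len(X), 64):
        idx = perm[start:start + 64]
        W, b = train_step(W, b, X[idx], Y[idx], lr=0.5)
    losses.append(mean_ce(W, b, X, Y))

acc = np.mean(np.argmax(X @ W + b, axis=1) == labels)
print(f"loss {losses[0]:.3f} -> {losses[-1]:.3f}, train acc {acc:.2%}")
```

With all‑zero weights, every class gets probability \(1/3\), so the first loss should be \(\log 3 \approx 1.099\). A first value far from that usually points to a bug in the forward pass.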



## 6. Animation — Loss over Epochs

This animates the full training‑set loss recorded at the end of each epoch.


In [None]:

# Create a loss animation
fig1, ax1 = plt.subplots(figsize=(6,4))
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Training Loss Over Epochs")
line1, = ax1.plot([], [])  # no explicit colors

ax1.set_xlim(1, len(loss_history))
ax1.set_ylim(min(loss_history)*0.95, max(loss_history)*1.05)

def init_loss():
    line1.set_data([], [])
    return (line1,)

def update_loss(frame):
    xdata = np.arange(1, frame+2)
    ydata = loss_history[:frame+1]
    line1.set_data(xdata, ydata)
    return (line1,)

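ani1 = animation.FuncAnimation(fig1, update_loss, frames=len(loss_history),
                               init_func=init_loss, interval=300, blit=True)
loss_anim_html = HTML(ani1.to_jshtml())
plt.close(fig1)  # avoid an extra static copy of the figure under the animation
loss_anim_html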
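`to_jshtml()` embeds the animation only in the notebook. To share it as a file, Matplotlib's `PillowWriter` can save it as a GIF (it requires Pillow). This sketch is written as a standalone script, so it forces the headless `Agg` backend and uses a made‑up decaying curve in place of `loss_history`. Inside the notebook, skip the `matplotlib.use` line and call `ani1.save(...)` directly:

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")                      # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

# Hypothetical loss curve standing in for the notebook's loss_history
loss_history = list(1.5 * np.exp(-0.3 * np.arange(10)) + 0.3)

fig, ax = plt.subplots(figsize=(6, 4))
ax.set_xlim(1, len(loss_history))
ax.set_ylim(min(loss_history) * 0.95, max(loss_history) * 1.05)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
line, = ax.plot([], [])

def update(frame):
    # Frame k draws epochs 1..k+1
    line.set_data(np.arange(1, frame + 2), loss_history[:frame + 1])
    return (line,)

ani = animation.FuncAnimation(fig, update, frames=len(loss_history), interval=300)
out_path = os.path.join(tempfile.mkdtemp(), "loss.gif")
ani.save(out_path, writer=animation.PillowWriter(fps=3))
plt.close(fig)
print(out_path, os.path.getsize(out_path) > 0)
```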



## 7. Animation — Probability Vector for One Example

We track the predicted probability distribution over the 10 digits for a **fixed** training image across epochs.


In [None]:

prob_array = np.stack(viz_prob_history, axis=0)  # (epochs, 10)
x_classes = np.arange(10)

fig2, ax2 = plt.subplots(figsize=(6,4))
ax2.set_xlabel("Class")
ax2.set_ylabel("Probability")
ax2.set_title(f"Predicted Probabilities for Sample idx={viz_idx} (true={viz_label})")
bar_container = ax2.bar(x_classes, prob_array[0])  # default colors

ax2.set_ylim(0.0, 1.0)
ax2.set_xticks(x_classes)

def update_bars(frame):
    probs = prob_array[frame]
    for rect, h in zip(bar_container, probs):
        rect.set_height(float(h))
    return bar_container.patches

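ani2 = animation.FuncAnimation(fig2, update_bars, frames=prob_array.shape[0],
                               interval=300, blit=False)
prob_anim_html = HTML(ani2.to_jshtml())
plt.close(fig2)  # avoid an extra static copy of the figure under the animation
prob_anim_html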
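A static heatmap shows the same history in a single image, with epochs across and classes up. This can be easier to compare than scrubbing through frames. In this sketch, the `prob_array` is a hypothetical 8‑epoch history in which class 5 gains probability. In the notebook, use the real `prob_array` and `viz_label`, and drop the `matplotlib.use` and `plt.close` lines so the figure displays:

```python
import matplotlib
matplotlib.use("Agg")                       # headless backend for running as a script
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical history: 8 epochs x 10 classes, with class 5 gaining probability
epochs, true_label = 8, 5
logits = np.zeros((epochs, 10))
logits[:, true_label] = np.linspace(0.0, 4.0, epochs)
prob_array = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)   # (epochs, 10)

fig, ax = plt.subplots(figsize=(6, 3))
# Transpose so rows are classes and columns are epochs; color encodes probability
im = ax.imshow(prob_array.T, aspect="auto", origin="lower", vmin=0.0, vmax=1.0,
               extent=(0.5, epochs + 0.5, -0.5, 9.5))
ax.set_xlabel("Epoch")
ax.set_ylabel("Class")
ax.set_yticks(np.arange(10))
ax.set_title(f"Predicted probabilities over training (true={true_label})")
fig.colorbar(im, ax=ax, label="Probability")
plt.close(fig)
```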



## 8. Inference Helper (Try a few test digits)


In [None]:

def predict_proba(X):
    return softmax(X @ W + b)

def predict(X):
    return np.argmax(predict_proba(X), axis=1)

preds = predict(x_test[:10])
print("Predictions for first 10 test digits:", preds)
print("Ground truth:", y_test[:10])
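Ten predictions are a spot check. A confusion matrix summarizes which digits get mistaken for which. This sketch uses a hypothetical NumPy helper named `confusion_matrix`; it is not scikit‑learn's function of the same name. It is shown on made‑up labels for 3 classes. In the notebook, `confusion_matrix(y_test, predict(x_test))` would give the 10×10 version:

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes=10):
    # cm[t, p] counts examples whose true class is t and predicted class is p
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)   # unbuffered, so repeated (t, p) pairs all count
    return cm

# Hypothetical labels and predictions for 8 examples over 3 classes
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 2, 2, 0])
cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)

# Per-class recall: correct predictions (diagonal) divided by each true class's count
recall = np.diag(cm) / cm.sum(axis=1)
print(recall)
```

A plain `cm[y_true, y_pred] += 1` would undercount here, because buffered fancy‑index assignment applies each repeated index pair only once. That is why `np.add.at` is used instead.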



## 9. Summary

- **Architecture**: `784 → Linear (W,b) → Softmax → 10`.
- **Forward**: \(Z = XW + b\), \(\hat{Y}=\mathrm{softmax}(Z)\).
- **Loss**: mean cross‑entropy.
- **Backprop**: \(\frac{\partial \mathcal{L}}{\partial Z} = \frac{1}{m}(\hat{Y}-Y)\),  
  \(\frac{\partial \mathcal{L}}{\partial W} = \frac{1}{m}X^\top(\hat{Y}-Y)\),  
  \(\frac{\partial \mathcal{L}}{\partial b} = \frac{1}{m}\sum_i(\hat{Y}_i-Y_i)\).
- **Training**: mini‑batch gradient descent on \(W,b\).
- **Animations**: loss over epochs; probability bars for one example.
