<div class="alert alert-warning">

# PS 4 - Perceptron for Digit Recognition

In this problem set, we will implement a **perceptron** — a simple neural network — to classify handwritten digits from the MNIST dataset.

## Reminder - The Problem

Given an image of a handwritten digit (0-9), can we train a simple neural network to recognize which digit it is? This is a classic machine learning problem that connects to how the brain might learn to categorize visual patterns.

## Reminder - The Perceptron

The perceptron is the simplest neural network:
1. It takes **input features** (here, pixel values from an image)
2. Each input is multiplied by a **weight**
3. The weighted inputs are summed: $y = \sum_i w_i \times x_i$
4. The output is **thresholded**: output 1 if $y > 0$, else output 0

Learning happens by adjusting weights when the perceptron makes mistakes:
- If it should have output 1 but didn't: **increase** weights for active inputs
- If it should have output 0 but didn't: **decrease** weights for active inputs

## This Problem Set

We'll:
1. Load and visualize MNIST handwritten digit images
2. Implement the perceptron output function
3. Implement the perceptron learning rule
4. Train a perceptron to distinguish "0" from "1"
5. Explore what the trained weights reveal about digit features
6. Extend to classifying all digit pairs

Your goal is to understand how a simple neural network learns from examples and what its weights represent.</div>
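
As a worked illustration of the perceptron steps and the learning rule above, consider a hypothetical three-pixel input. The numbers are made up for illustration and are not taken from MNIST.

$$
W = (0.5,\ -1,\ 2), \qquad X = (1,\ 1,\ 0)
$$

$$
y = \sum_i w_i \times x_i = 0.5 \times 1 + (-1) \times 1 + 2 \times 0 = -0.5 \le 0 \;\Rightarrow\; \text{output } 0
$$

If the correct label were 1, the perceptron should have output 1 but didn't, so the weights of the active inputs increase:

$$
W \leftarrow W + X = (1.5,\ 0,\ 2), \qquad y = 1.5 \times 1 + 0 \times 1 + 2 \times 0 = 1.5 > 0 \;\Rightarrow\; \text{output } 1
$$

The third weight is unchanged because its input is inactive ($x_3 = 0$).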

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import zipfile

# 1. Loading and Understanding the Data

We will train a perceptron to classify images of handwritten digits 0-9 from the MNIST dataset.

Each image is 28×28 pixels, which we flatten into a vector of length 784. Each pixel value indicates whether that location was white (0) or black (1).

**Useful functions for this problem set:**
- `x.flatten()` — convert an N-dimensional array to 1D
- `np.dot(a, b)` — compute dot product of two vectors
- `np.reshape(x, shape)` — reshape a vector to a matrix
- `plt.imshow(x)` — visualize a matrix as an image
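
A minimal sketch of how the first three functions fit together, using a made-up 2×3 array in place of a real 28×28 image (the values and the `weights` vector are assumptions for illustration only):

```python
import numpy as np

# A hypothetical 2x3 "image" (made-up values, not MNIST data)
image = np.array([[0, 1, 0],
                  [1, 1, 0]])

vector = image.flatten()               # 2D -> 1D, row by row
restored = np.reshape(vector, (2, 3))  # 1D -> back to the 2D layout

weights = np.array([1.0, 2.0, 3.0, -1.0, 0.5, 4.0])  # one weight per pixel
weighted_sum = np.dot(weights, vector)                # sum of w_i * x_i

print(vector)        # [0 1 0 1 1 0]
print(weighted_sum)  # 1.5
```

Only the pixels equal to 1 contribute to the dot product, so the result is 2.0 - 1.0 + 0.5 = 1.5.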

Run the cell below to load images of "0"s and "1"s. Make sure the `images` folder (or `images.zip`) is in the same directory as this notebook.

In [None]:
DIM = (28, 28)  # Image dimensions

def load_image_files(n, path="images/", zip_path="images.zip"):
    """Load images of digit n from the MNIST dataset."""
    if not os.path.exists(path):
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(path[:-1])
    images = []
    for f in sorted(os.listdir(os.path.join(path, str(n)))):
        p = os.path.join(path, str(n), f)
        if os.path.isfile(p):
            i = np.loadtxt(p)
            assert i.shape == DIM
            images.append(i.flatten())
    return images

# Load images of 0s and 1s
A = load_image_files(0)
B = load_image_files(1)

N = len(A[0])  # Total number of features (pixels)
print(f'Loaded {len(A)} images of "0" and {len(B)} images of "1"')
print(f'Each image has {N} pixels ({DIM[0]}x{DIM[1]})')

<div class="alert alert-success">

# 2. Let's Visualize the Images

Let's explore what we have loaded by examining the images of "0"s (stored in `A`) and "1"s (stored in `B`).

We will create a figure with 2 rows × 3 columns showing:
- Top row: 3 randomly chosen images of "0"
- Bottom row: 3 randomly chosen images of "1"

Requirements:
- Use `np.reshape(image, (28, 28))` to convert each vector back to a 2D image
- Use `imshow()` or `matshow()` with `cmap='gray'` to display
- Add a title to each subplot indicating which digit is shown
- Remove x and y tick marks (these are images, not plots)
- Add an overall figure title

</div>

In [None]:
figure, ax = plt.subplots(2, 3, figsize=(10, 7))
figure.suptitle('Example Handwritten Digit Images')

for i in range(3):
    ax[0, i].matshow(A[np.random.randint(len(A))].reshape(28, 28), cmap='gray')
    ax[0, i].set_title('Digit: 0')
    ax[0, i].set_xticks([])
    ax[0, i].set_yticks([])

    ax[1, i].matshow(B[np.random.randint(len(B))].reshape(28, 28), cmap='gray')
    ax[1, i].set_title('Digit: 1')
    ax[1, i].set_xticks([])
    ax[1, i].set_yticks([])

plt.tight_layout()
plt.show()

<div class="alert alert-success">

# 3. Exercise 1: Perceptron Output Function

Write the `compute_output` function that computes a perceptron's output given:
1. Weights $W = (w_1, \ldots, w_N)$
2. An image $X = (x_1, \ldots, x_N)$ as a vector of N=784 features

**Algorithm:**
1. Compute $y = \sum_i w_i \times x_i$ (use `np.dot()` for efficiency)
2. Return 1 if $y > 0$, otherwise return 0

Test your function with the provided test case.

</div>

In [None]:
# YOUR CODE - Exercise 1

In [None]:
# Test your function
test_weights = (np.arange(N) / N) - 0.5
test_results = [compute_output(test_weights, A[i]) for i in range(3)]
print(f'Test outputs for 3 images of 0: {test_results}')
print('(Expected: all 1s with these test weights)')

<div class="alert alert-success">

# 4. Exercise 2: Compute Accuracy

Write a function `compute_accuracy` that computes the classification accuracy of the perceptron on a set of images.

**Parameters:**
- `W`: weight vector
- `images`: list of image vectors
- `labels`: list of correct labels (0 or 1)

**Returns:** Proportion of correctly classified images

</div>
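
A minimal sketch of the "proportion correct" computation, assuming the perceptron outputs have already been collected into an array. The `predictions` and `labels` values below are made up for illustration.

```python
import numpy as np

predictions = np.array([1, 0, 0, 1, 1])  # hypothetical perceptron outputs
labels = np.array([1, 0, 1, 1, 0])       # hypothetical correct labels

# Element-wise comparison gives a boolean array; its mean is the
# fraction of True entries, i.e. the proportion classified correctly.
accuracy = np.mean(predictions == labels)
print(accuracy)  # 0.6
```

The comparison also works when the labels are floats (as produced by `np.zeros` or `np.ones`) and the outputs are integers, since `1 == 1.0` evaluates to true.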

In [None]:
# YOUR CODE - Exercise 2

In [None]:
# Test your function
test_weights = np.random.randn(N)
acc_0 = compute_accuracy(test_weights, A[:100], np.zeros(100))
acc_1 = compute_accuracy(test_weights, B[:100], np.ones(100))
print(f'Accuracy on "0" images: {acc_0:.2f}')
print(f'Accuracy on "1" images: {acc_1:.2f}')
print('(With random weights, expect ~50% on average over many draws; a single draw can be far from 50%)')

<div class="alert alert-success">

# 5. Exercise 3: Perceptron Learning Rule

Write the function `update_weights_single_image` that updates the perceptron weights after seeing one training example.

**Perceptron Learning Algorithm:**
1. Compute the predicted label using current weights
2. If prediction matches true label: do nothing
3. If prediction is wrong:
   - If true label is 1: $W = W + X$ (increase weights for active pixels)
   - If true label is 0: $W = W - X$ (decrease weights for active pixels)

</div>
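
The two error cases above can also be written as a single expression. Here $t$ denotes the true label and $\hat{y}$ the predicted label; these symbols are introduced only for this note and are not part of the required function interface.

$$
W \leftarrow W + (t - \hat{y})\,X,
\qquad
t - \hat{y} =
\begin{cases}
+1 & \text{if } t = 1,\ \hat{y} = 0 \quad (W \leftarrow W + X)\\
-1 & \text{if } t = 0,\ \hat{y} = 1 \quad (W \leftarrow W - X)\\
\phantom{+}0 & \text{if } t = \hat{y} \quad (\text{no change})
\end{cases}
$$

Either form describes the same rule, so an implementation may use whichever is easier to read.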

In [None]:
# YOUR CODE - Exercise 3

**Test your function:** Start with zero weights and train on a single "1" image. Visualize the resulting weights — they should look like the training image!

In [None]:
# Train on one example
weights = np.zeros(N)
weights = update_weights_single_image(weights, B[0], 1)

# Visualize
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ax[0].imshow(B[0].reshape(28, 28), cmap='gray')
ax[0].set_title('Training image (a "1")')
ax[0].axis('off')

ax[1].imshow(weights.reshape(28, 28))
ax[1].set_title('Weights after training')
ax[1].axis('off')
plt.show()

<div class="alert alert-success">

# 6. Exercise 4: Training on Multiple Images

Write a function `update_weights_multiple_images` that applies the learning rule to a batch of training examples in sequence.
Your function should take the current weight vector `W`, a list of image vectors, and an array with one label per image, and return the updated weights `W` after training on the images in order.
</div>

In [None]:
# YOUR CODE - Exercise 4

In [None]:
# Test your function: Train on a small batch and check that weights change appropriately
weights01 = np.zeros(N)
weights10 = np.zeros(N)

# Train on 3 images of "0" and 3 images of "1"
test_images01 = A[:3] + B[:3]
test_labels01 = np.array([0, 0, 0, 1, 1, 1])
test_images10 = B[:3] + A[:3]
test_labels10 = np.array([1, 1, 1, 0, 0, 0])

weights01 = update_weights_multiple_images(weights01, test_images01, test_labels01)
weights10 = update_weights_multiple_images(weights10, test_images10, test_labels10)

# Visualize
fig, axes = plt.subplots(1, 4, figsize=(12, 4))

# Difference between the first "1" and the first "0"
axes[0].imshow(B[0].reshape(28,28)-A[0].reshape(28,28), cmap='gray')
axes[0].set_title('First "1" minus first "0"')
axes[0].axis('off')

# weights10 came from the batch ordered "1"s then "0"s
axes[1].imshow(weights10.reshape(28, 28), cmap='gray')
axes[1].set_title('Weights after training on batch 10')
axes[1].axis('off')

# weights01 came from the batch ordered "0"s then "1"s
axes[2].imshow(weights01.reshape(28, 28), cmap='gray')
axes[2].set_title('Weights after training on batch 01')
axes[2].axis('off')

axes[3].imshow(B[0].reshape(28,28), cmap='gray')
axes[3].set_title('First training "1"')
axes[3].axis('off')

plt.tight_layout()
plt.show()

print('The left two images should look identical;')
print('the right two images should look identical.')
print('Try to reason about why!')
print('Or use code to figure it out')

# 7. Training the Perceptron

We'll train the perceptron in batches. Each step:
1. Randomly select some "0" images and some "1" images
2. Update weights on this batch
3. Measure accuracy on the full dataset
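
The random selection in step 1 can be sketched on a toy list. Here `pool` is a made-up stand-in for a list of images, and `batch_size` is a hypothetical name; sampling with `replace=False` ensures no item is picked twice within one batch.

```python
import numpy as np

pool = ["img_a", "img_b", "img_c", "img_d", "img_e", "img_f"]  # stand-in for images
batch_size = 3

# Draw distinct indices, then look up the corresponding items
idx = np.random.choice(len(pool), size=batch_size, replace=False)
batch = [pool[j] for j in idx]

print(batch)  # three distinct items; the selection differs between runs
```
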

The `train_perceptron` function is provided below.

In [None]:
def train_perceptron(train_0, train_1, N_samples=5, steps=200):
    """
    Train a perceptron to distinguish two digit classes.
    
    Parameters
    ----------
    train_0, train_1 : lists
        Training images for each class
    N_samples : int
        Number of samples per class per training step
    steps : int
        Number of training steps
        
    Returns
    -------
    performance : array
        Accuracy at each step
    weights : array
        Final trained weights
    """
    performance = np.empty(steps)
    train_labels = np.concatenate([np.zeros(N_samples), np.ones(N_samples)])
    
    # Full dataset for evaluation
    full_sample = train_0 + train_1
    full_labels = np.concatenate([np.zeros(len(train_0)), np.ones(len(train_1))])

    # Initialize random weights
    weights = np.random.normal(0, 1, size=N)
    
    for i in range(steps):
        # Sample random training batch
        idx_0 = np.random.choice(len(train_0), size=N_samples, replace=False)
        idx_1 = np.random.choice(len(train_1), size=N_samples, replace=False)
        examples = [train_0[j] for j in idx_0] + [train_1[j] for j in idx_1]
        
        # Update weights
        weights = update_weights_multiple_images(weights, examples, train_labels)
        
        # Evaluate
        performance[i] = compute_accuracy(weights, full_sample, full_labels)
        
    return performance, weights

<div class="alert alert-success">

# 8. Exercise 5: Training Curves

Train the perceptron with different batch sizes (N_samples = 1, 5, 25) and plot the learning curves.

**Your task:**
1. Create a figure with 3 subplots (3 rows × 1 column)
2. For each batch size, plot accuracy vs. training step
3. Set y-axis limits to (0.7, 1.0)
4. Add appropriate titles and labels

**Questions to answer:**
1. How does batch size affect learning speed and stability?
2. Does the perceptron reach 100% accuracy? What does this tell you about linear separability?

</div>

In [None]:
# YOUR CODE - Exercise 5

### YOUR NOTES - Exercise 5

<div class="alert alert-success">

# 9. Exercise 6: Visualizing Trained Weights

Reshape the trained weights into a 28×28 image and visualize them.

**Your task:**
1. Train a perceptron with N_samples=25, steps=200
2. Reshape the weights to (28, 28) and display with `imshow()`
3. Add a colorbar

**Questions to answer:**
- What do large positive weights mean?
- What do large negative weights mean?
- What do weights near zero mean?
- Why does the weight pattern look the way it does?

</div>

In [None]:
# YOUR CODE - Exercise 6

### YOUR NOTES - Exercise 6

<div class="alert alert-success">

# 10. Exercise 7: Which Pixels Matter?

How many pixels actually matter for classification? Let's find out by zeroing out the smallest weights.

**Your task:**
1. Start with trained weights
2. Progressively set the smallest weights (by absolute value) to zero
3. Measure accuracy at each step
4. Plot accuracy vs. number of weights set to zero

**Question:** What proportion of pixels are truly diagnostic for distinguishing "0" from "1"?

</div>
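
One possible way to find the smallest weights by absolute value is `np.argsort` applied to `np.abs(...)`. The sketch below uses a made-up five-element weight vector; `w`, `k`, and `pruned` are hypothetical names, and this is only one of several reasonable approaches.

```python
import numpy as np

w = np.array([0.2, -3.0, 0.05, 1.5, -0.1])  # hypothetical weights

# Indices ordered from smallest to largest absolute value
order = np.argsort(np.abs(w))

k = 2                    # number of weights to zero out
pruned = w.copy()        # copy so the original weights are preserved
pruned[order[:k]] = 0.0  # zero the k smallest-magnitude weights

print(order)   # [2 4 0 3 1]
print(pruned)  # entries at positions 2 and 4 are now zero
```

Working on a copy keeps the trained weights intact, which matters when the same weights are pruned repeatedly with increasing `k`.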

In [None]:
# YOUR CODE - Exercise 7

### YOUR NOTES - Exercise 7

<div class="alert alert-success">

# 11. Exercise 8: All Digit Pairs

Train a perceptron for every pair of digits (0 vs 1, 0 vs 2, ..., 8 vs 9) and visualize which pairs are hardest to distinguish.

**Your task:**
1. Create a 10×10 matrix where entry (i,j) is the accuracy for distinguishing digit i from digit j
2. Visualize as a heatmap with colorbar
3. Label axes with digit values

**Note:** This may take several minutes to run.

**Question:** Which digit pairs are hardest to distinguish? Does this match your intuition?

</div>
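
A minimal sketch of the bookkeeping for the 10×10 matrix, assuming the accuracy for digit i vs. digit j is reused for j vs. i. The function `placeholder_accuracy` is a hypothetical stand-in that returns made-up numbers; in the exercise it would be replaced by actual training on the two digit classes.

```python
import numpy as np
from itertools import combinations

def placeholder_accuracy(i, j):
    """Hypothetical stand-in for 'train on digits i and j, return accuracy'."""
    return 1.0 - 0.01 * abs(i - j)  # made-up number, NOT a real result

n_digits = 10
# Start with NaN everywhere; the diagonal stays NaN (no i-vs-i task)
acc = np.full((n_digits, n_digits), np.nan)

pairs = list(combinations(range(n_digits), 2))  # (0, 1), (0, 2), ..., (8, 9)
for i, j in pairs:
    acc[i, j] = placeholder_accuracy(i, j)
    acc[j, i] = acc[i, j]  # mirror so the matrix is symmetric

print(len(pairs))  # 45
```

Iterating over unordered pairs trains 45 perceptrons instead of 90, which roughly halves the running time.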

In [None]:
# YOUR CODE - Exercise 8

### YOUR NOTES - Exercise 8

# 12. Summary

In this problem set, you implemented a perceptron from scratch and explored:

| Concept | What you learned |
|---------|------------------|
| **Perceptron output** | Weighted sum + threshold → binary decision |
| **Learning rule** | Adjust weights when wrong: W ± X |
| **Training** | Iterative improvement through error correction |
| **Weight interpretation** | Weights encode which features matter for each class |
| **Linear separability** | Perceptron works when classes can be separated by a hyperplane (a line in 2D) |

**Key insights:**
1. The perceptron learns by adjusting weights based on errors
2. Trained weights reveal which pixels are diagnostic for classification
3. Some digit pairs are harder to distinguish than others
4. The perceptron is limited to linearly separable problems
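
A standard example of a problem that is not linearly separable is XOR on two binary inputs. This example is added for illustration and is not part of the problem set's data. With the perceptron as defined here (output 1 if $\sum_i w_i \times x_i > 0$, else 0), the required outputs impose:

$$
\begin{aligned}
X = (1, 0) \to 1 &: \quad w_1 > 0\\
X = (0, 1) \to 1 &: \quad w_2 > 0\\
X = (1, 1) \to 0 &: \quad w_1 + w_2 \le 0
\end{aligned}
$$

Adding the first two conditions gives $w_1 + w_2 > 0$, which contradicts the third, so no choice of weights classifies all XOR inputs correctly.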