In [1]:
!git clone "https://github.com/deshcrete/leaf-technical-worksheets.git"
!mv leaf-technical-worksheets/* .
!mv week-1-Interpreting-NNs/* .

^C


'mv' is not recognized as an internal or external command,
operable program or batch file.
'mv' is not recognized as an internal or external command,
operable program or batch file.


Cloning into 'leaf-technical-worksheets'...


In [None]:
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

# Import helper functions - all the implementation details are hidden here!
from helper_interp import (
    # Model and training
    create_model, create_optimizer, train_model,
    load_data, create_dataloaders, plot_training_history,
    # Visualization
    visualize_network_architecture, visualize_sample_digits, visualize_prediction,
    # Activations
    plot_activations, get_sample_by_digit, get_sample_by_index,
    # Probes
    train_probe, visualize_probe_weights, compare_probes_across_layers
)

## Neural Network Architecture

Our MNIST classifier has the following structure:
- **Input**: 784 neurons (28x28 flattened image)
- **Hidden Layer 1**: 32 neurons + ReLU
- **Hidden Layer 2**: 16 neurons + ReLU
- **Output**: 10 neurons (one per digit 0-9)

In [None]:
# Visualize the network architecture.
# `visualize_network_architecture` comes from helper_interp and is called with
# no arguments. It is the cell's last expression, so any value it returns is
# echoed below the cell.
visualize_network_architecture()

In [None]:
# Load the MNIST data.
# The CSV folder defaults to Colab's bundled sample data. To run the notebook
# anywhere else, set the MNIST_DATA_DIR environment variable (e.g. with
# `%env MNIST_DATA_DIR=path/to/csvs`) before running this cell.
DATA_DIR = Path(os.environ.get("MNIST_DATA_DIR", "/content/sample_data"))
TRAIN_CSV = DATA_DIR / "mnist_train_small.csv"
TEST_CSV = DATA_DIR / "mnist_test.csv"

# Fail early with an actionable message instead of an opaque loader error.
missing_csvs = [str(path) for path in (TRAIN_CSV, TEST_CSV) if not path.is_file()]
if missing_csvs:
    raise FileNotFoundError(
        f"MNIST CSV file(s) not found: {missing_csvs}. "
        "Set MNIST_DATA_DIR to the folder that contains them."
    )

# Paths are passed as plain strings, exactly as the helper received them before.
train_imgs, train_lbls, test_imgs, test_lbls = load_data(str(TRAIN_CSV), str(TEST_CSV))
train_loader, test_loader = create_dataloaders(train_imgs, train_lbls, test_imgs, test_lbls)

print(f"Loaded {len(train_imgs)} training samples and {len(test_imgs)} test samples")

## Visualizing Input Digits

Let's see what MNIST digits look like - these 28x28 grayscale images are flattened to 784 values before entering the network.

In [None]:
# Visualize sample digits (one of each 0-9).
# Uses the training images and labels produced by the data-loading cell, so
# that cell must have been run first.
visualize_sample_digits(train_imgs, train_lbls)

In [None]:
# =============================================================================
# TRAINING PARAMETERS - You can modify these!
# =============================================================================
EPOCHS = 20           # Number of training epochs
LEARNING_RATE = 0.001  # Learning rate for optimizer
SEED = 42             # Random seed; change it to see run-to-run variation

# Seed the random number generators right before the model is created so that
# re-running this cell (or Restart & Run All) reproduces the same model, and
# therefore the same activations and probe accuracies further down.
random.seed(SEED)
np.random.seed(SEED)
# NOTE(review): assumes helper_interp builds and trains the model with
# PyTorch (suggested by the dataloader/optimizer helper names) - confirm.
torch.manual_seed(SEED)

# Create and train the model
model = create_model()
optimizer = create_optimizer(model, lr=LEARNING_RATE)
history = train_model(model, optimizer, train_loader, test_loader, epochs=EPOCHS)

In [None]:
# Plot training progress.
# `history` is the value returned by `train_model` in the training cell above;
# run that cell first.
plot_training_history(history)

## Visualizing Output Distributions

The network outputs 10 logits which are converted to probabilities via softmax. Let's see how the model's predictions look for different digits.

In [None]:
# =============================================================================
# CHOOSE A DIGIT TO VISUALIZE - Change this value!
# =============================================================================
DIGIT_TO_VIEW = 7  # Try values 0-9

# Fetch one test-set sample of the chosen digit, then show the trained model's
# prediction for it. `image` and `label` are notebook-level names that the
# activation-explorer cell later reassigns.
image, label = get_sample_by_digit(test_imgs, test_lbls, DIGIT_TO_VIEW)
visualize_prediction(model, image, label)

## Understanding Activations

**What are activations?** As data flows through a neural network, each layer transforms the input and produces an output called an "activation". These activations are the intermediate representations the network builds as it processes information.

**Why do they matter?** By examining activations, we can understand:
- What features each layer has learned to detect
- How the network transforms raw pixels into abstract concepts
- Which neurons "fire" (have high values) for different inputs

Our network has 5 activation points, numbered 0–4 (the same 0–4 range the probe experiments below use for `LAYER_TO_PROBE`):
- **Layer 0**: After first Linear (784 → 32)
- **Layer 1**: After first ReLU (32 neurons, negative values zeroed)
- **Layer 2**: After second Linear (32 → 16)
- **Layer 3**: After second ReLU (16 neurons)
- **Layer 4**: Output logits (16 → 10)

In [None]:
# =============================================================================
# INTERACTIVE ACTIVATION EXPLORER
# =============================================================================
# Choose how to select your input:
#   - By digit (0-9): set DIGIT_TO_EXPLORE and leave INDEX_TO_EXPLORE = None.
#     Shows activations for a sample of that digit.
#   - By index: set INDEX_TO_EXPLORE to a position in the test set.
#     Shows activations for that specific sample (DIGIT_TO_EXPLORE is ignored).

DIGIT_TO_EXPLORE = 5     # Change this! (0-9)
INDEX_TO_EXPLORE = None  # Change this! (0 to len(test_imgs) - 1), or None to select by digit

if INDEX_TO_EXPLORE is None:
    # OPTION 1: Select by digit
    image, label = get_sample_by_digit(test_imgs, test_lbls, DIGIT_TO_EXPLORE)
else:
    # OPTION 2: Select by index. The bound comes from the loaded data rather
    # than a guessed constant.
    if not 0 <= INDEX_TO_EXPLORE < len(test_imgs):
        raise IndexError(
            f"INDEX_TO_EXPLORE must be between 0 and {len(test_imgs) - 1}, "
            f"got {INDEX_TO_EXPLORE}"
        )
    image, label = get_sample_by_index(test_imgs, test_lbls, INDEX_TO_EXPLORE)

# Kept at top level as the cell's last expression so its output is displayed
# exactly as before.
plot_activations(model, image, label)

In [None]:
# =============================================================================
# COMPARE ACTIVATIONS: Two digits side by side
# =============================================================================
# See how different digits activate different neurons!
# The two activation plots are produced by two consecutive calls, one per digit.

COMPARE_DIGIT_1 = 1  # A "simple" digit
COMPARE_DIGIT_2 = 8  # A "complex" digit

# Fetch one test-set sample for each of the two digits.
img1, lbl1 = get_sample_by_digit(test_imgs, test_lbls, COMPARE_DIGIT_1)
img2, lbl2 = get_sample_by_digit(test_imgs, test_lbls, COMPARE_DIGIT_2)

# Print the header first, then plot the first digit followed by the second.
print(f"Comparing digit {COMPARE_DIGIT_1} vs digit {COMPARE_DIGIT_2}:")
plot_activations(model, img1, lbl1)
plot_activations(model, img2, lbl2)

## Introduction to Probes

**What is a probe?** A probe is a simple classifier (usually linear) trained on the activations of a neural network to detect whether a specific concept is represented in those activations.

**Why use probes?** Probes help us understand:
- What concepts the network has learned to represent
- At which layer these concepts emerge
- Whether the network encodes features we didn't explicitly train it on

**How it works:**
1. Extract activations from a specific layer for many inputs
2. Label inputs by the concept we're testing (e.g., "has a loop" vs "no loop")
3. Train a simple linear classifier on these activations
4. If the probe achieves high accuracy (well above what always guessing the larger group would score), the concept is likely linearly decodable from that layer's activations

**Example concepts to probe:**
- Digits with loops (0, 6, 8, 9) vs without (1, 2, 3, 4, 5, 7)
- Even vs odd digits
- Any digit pair distinction

In [None]:
# =============================================================================
# EXPERIMENT 1: Probe for "digits with loops"
# =============================================================================
# Digits with loops: 0, 6, 8, 9
# Digits without loops: 1, 2, 3, 4, 5, 7
#
# Try changing these groups to test different concepts!
#
# The names defined here are reused further down: the next cell reads `probe`
# and LAYER_TO_PROBE, and Experiment 2 reads POSITIVE_DIGITS / NEGATIVE_DIGITS.

POSITIVE_DIGITS = [0, 6, 8, 9]  # Digits WITH loops
NEGATIVE_DIGITS = [1, 2, 3, 4, 5, 7]  # Digits WITHOUT loops
# Presumably an index into the activation points listed under
# "Understanding Activations" (2 = after the second Linear) - confirm in helper_interp.
LAYER_TO_PROBE = 2  # Which layer to probe (0-4)

# `train_probe` returns the trained probe and its accuracy, unpacked here.
probe, accuracy = train_probe(
    model, train_loader, test_loader,
    positive_digits=POSITIVE_DIGITS,
    negative_digits=NEGATIVE_DIGITS,
    layer_num=LAYER_TO_PROBE
)

# `accuracy` is formatted with `.1%`, i.e. it is treated as a fraction.
print(f"\nThe probe can distinguish {POSITIVE_DIGITS} from {NEGATIVE_DIGITS}")
print(f"at layer {LAYER_TO_PROBE} with {accuracy:.1%} accuracy!")

In [None]:
# Visualize what the probe learned (which neurons it relies on).
# Depends on `probe` and LAYER_TO_PROBE from the Experiment 1 cell; re-run that
# cell first if either was changed so the plot title matches the probe shown.
visualize_probe_weights(probe, title=f"Probe Weights for Layer {LAYER_TO_PROBE}")

In [None]:
# =============================================================================
# EXPERIMENT 2: Compare probe accuracy across all layers
# =============================================================================
# This shows at which layer the concept becomes most detectable.
# Reuses POSITIVE_DIGITS / NEGATIVE_DIGITS from Experiment 1, so it probes the
# same concept. The return value is stored in `results` and is not read again
# in the cells shown here.

results = compare_probes_across_layers(
    model, train_loader, test_loader,
    positive_digits=POSITIVE_DIGITS,
    negative_digits=NEGATIVE_DIGITS
)

## Try Your Own Experiments!

Modify the cells below to test different hypotheses about what the network has learned.

In [None]:
# =============================================================================
# EXPERIMENT 3: Even vs Odd digits
# =============================================================================
# Try probing for even vs odd - does the network represent this concept?

EVEN_DIGITS = [0, 2, 4, 6, 8]
ODD_DIGITS = [1, 3, 5, 7, 9]
# Named like LAYER_TO_PROBE in Experiment 1, but kept separate so changing it
# here does not silently alter the other experiments.
EVEN_ODD_LAYER = 2  # Try different layers! (0-4)

probe_even_odd, acc_even_odd = train_probe(
    model, train_loader, test_loader,
    positive_digits=EVEN_DIGITS,
    negative_digits=ODD_DIGITS,
    layer_num=EVEN_ODD_LAYER
)

print(f"\nEven vs Odd probe accuracy: {acc_even_odd:.1%}")

In [None]:
# =============================================================================
# EXPERIMENT 4: Single digit pair
# =============================================================================
# How well can the network distinguish between two specific digits?

# Each side is a one-element list because `train_probe` is given digit groups.
DIGIT_A = [3]  # Change this!
DIGIT_B = [8]  # Change this!
# Named like LAYER_TO_PROBE in Experiment 1, but kept separate so changing it
# here does not silently alter the other experiments.
PAIR_LAYER = 2  # Which layer to probe (0-4)

probe_pair, acc_pair = train_probe(
    model, train_loader, test_loader,
    positive_digits=DIGIT_A,
    negative_digits=DIGIT_B,
    layer_num=PAIR_LAYER
)

print(f"\n{DIGIT_A} vs {DIGIT_B} probe accuracy: {acc_pair:.1%}")
visualize_probe_weights(probe_pair, title=f"Probe for {DIGIT_A} vs {DIGIT_B}")

In [None]:
# =============================================================================
# EXPERIMENT 5: Your own hypothesis!
# =============================================================================
# What other concepts might the network have learned?
# Ideas:
#   - Digits > 5 vs <= 5
#   - Digits with curves (0,2,3,5,6,8,9) vs straight lines (1,4,7)
#   - Digits with horizontal lines (2,4,5,7) vs without
#
# Fill in your own groups below, then run the cell - nothing to uncomment:

MY_POSITIVE_DIGITS = []  # Fill this in!
MY_NEGATIVE_DIGITS = []  # Fill this in!
MY_LAYER = 2             # Which layer to probe (0-4)

# The probe is trained only once both groups are filled in, so the untouched
# cell still runs cleanly under Restart & Run All.
if MY_POSITIVE_DIGITS and MY_NEGATIVE_DIGITS:
    probe_custom, acc_custom = train_probe(
        model, train_loader, test_loader,
        positive_digits=MY_POSITIVE_DIGITS,
        negative_digits=MY_NEGATIVE_DIGITS,
        layer_num=MY_LAYER
    )
    print(f"Custom probe accuracy: {acc_custom:.1%}")
else:
    print("Fill in MY_POSITIVE_DIGITS and MY_NEGATIVE_DIGITS above, then re-run this cell.")