# Stanford CS25 In-Depth Review: Neuroscience-Inspired AI
### Attention as a Rediscovery of a Brain-Like Memory System

Welcome to this deep dive into one of the most fascinating connections between artificial intelligence and neuroscience. This notebook serves as a comprehensive, self-contained guide to the Stanford CS25 lecture and the pivotal research paper it's based on: **"Attention Approximates Sparse Distributed Memory"** (Bricken & Pehlevan, 2021).

Our goal is to explore a profound idea: that the **Attention mechanism**, the engine behind modern marvels like GPT and other Transformers, is not just a clever engineering trick. Instead, it may be a modern rediscovery of a 30-year-old computational model of memory that is strikingly similar to circuits found in the brain, particularly the cerebellum.

This notebook will unpack this connection layer by layer, with in-depth explanations, intuitive analogies, and faithful recreations of the paper's key findings. By the end, you won't need to watch the lecture; you'll have a robust, expert-level understanding of:

1.  **What Sparse Distributed Memory (SDM) is** and how it works as a brain-inspired associative memory.
2.  The mathematical and conceptual bridge that **connects SDM's retrieval mechanism to the softmax function** in Attention.
3.  How we can **re-interpret the entire Transformer architecture** through this powerful new lens.
4.  The direct mapping of SDM's operations onto the **biological circuitry of the cerebellum**.

## Chapter 1: Sparse Distributed Memory (SDM) - A Magical Library for the Brain

Before we can connect SDM to Attention, we need to understand it on its own terms. Proposed by Pentti Kanerva in 1988, SDM is a model of associative memory designed to answer the question: *How can the brain store vast numbers of memories and reliably retrieve the correct one, even from a noisy or incomplete cue?*

Let's build an analogy to make this crystal clear: imagine a **vast, magical library**.

-   **The Library's Space:** This library is enormous, with a near-infinite number of possible shelf locations. This represents a high-dimensional vector space.
-   **The Books (Patterns):** Each memory, or "pattern" (e.g., the image of a cat, the concept of "justice"), is a book. Each book has a unique address (a binary vector) that defines its exact location in the library. Let's call the address `p_a`.
-   **The Shelves (Neurons):** There aren't shelves at every single possible location. Instead, a few shelves are scattered sparsely throughout the library. These are the "neurons". Each shelf also has a fixed address (`x_a`).

### The Write Operation: Storing a Memory

When you want to store a new book (a pattern), you don't just put it on one shelf. The SDM `write` operation is **distributed**:

1.  You go to the book's true address (`p_a`).
2.  You draw a circle of a certain radius around this location (the "write radius," `d`).
3.  You place a copy of the book on **every single shelf (neuron) that falls within that circle**.

Crucially, shelves can hold multiple books. They just stack them on top of each other. In mathematical terms, the neuron stores a **superposition** (sum) of all the pattern vectors written to it.
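The write step can be sketched directly on binary vectors. This is a minimal sketch, not the paper's code. The sizes (64-bit addresses, 2,000 neurons, write radius `d = 24`) are illustrative assumptions, and `hamming` and `sdm_write` are hypothetical helper names. Each neuron's storage follows the classic SDM counter scheme (a 1 bit adds +1, a 0 bit adds -1), so superposed patterns can reinforce or cancel each other.

```python
import numpy as np

# Minimal SDM write sketch. Assumed, illustrative parameters: n = 64-bit
# addresses, 2,000 neurons with random fixed addresses, Hamming radius d = 24.
rng = np.random.default_rng(0)
n, num_neurons, d = 64, 2000, 24

neuron_addresses = rng.integers(0, 2, size=(num_neurons, n))   # x_a: fixed shelf locations
neuron_storage = np.zeros((num_neurons, n), dtype=int)         # superposed pattern pointers

def hamming(rows, vector):
    """Hamming distance between each row of `rows` and `vector`."""
    return np.count_nonzero(rows != vector, axis=-1)

def sdm_write(p_a, p_p):
    """Add pointer p_p (bits mapped to +/-1) to every neuron within radius d of p_a."""
    active = hamming(neuron_addresses, p_a) <= d
    neuron_storage[active] += 2 * p_p - 1   # {0, 1} -> {-1, +1} counters
    return active

# Autoassociative example: each pattern's pointer is its own address (p_p = p_a).
pattern_a = rng.integers(0, 2, size=n)
pattern_b = rng.integers(0, 2, size=n)
active_a = sdm_write(pattern_a, pattern_a)
active_b = sdm_write(pattern_b, pattern_b)
print("neurons storing A:", active_a.sum(),
      "| B:", active_b.sum(),
      "| both:", (active_a & active_b).sum())
```

Neurons that fall inside both write circles end up holding the sum of the two ±1 vectors, which is the "stacked books" superposition described above.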

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, FancyArrowPatch
import numpy as np

def plot_sdm_write():
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    
    # Common elements: a 2D Euclidean cartoon of the high-dimensional binary (Hamming) space SDM actually uses
    np.random.seed(42)
    neurons = np.random.rand(100, 2) * 10

    # --- Plot 1: Write First Pattern (Green) ---
    ax = axs[0]
    p_a_green = np.array([3, 5])
    write_radius = 2.5
    ax.scatter(neurons[:, 0], neurons[:, 1], c='gray', alpha=0.5, label='Inactive Neurons (Shelves)')
    ax.scatter(p_a_green[0], p_a_green[1], c='green', s=200, marker='*', label='Pattern Address (Book Location)')
    write_circle = Circle(p_a_green, write_radius, color='green', fill=False, linestyle='--', lw=2, label='Write Radius')
    ax.add_patch(write_circle)
    
    # Highlight activated neurons
    activated_mask = np.linalg.norm(neurons - p_a_green, axis=1) <= write_radius
    ax.scatter(neurons[activated_mask, 0], neurons[activated_mask, 1], c='green', s=100, label='Activated Neurons (Storing Book)')
    ax.set_title('Step 1: Write Pattern A (Green)', fontsize=14)
    ax.legend()
    ax.set_aspect('equal')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)

    # --- Plot 2: Write Second Pattern (Blue) ---
    ax = axs[1]
    p_a_blue = np.array([7, 6])
    ax.scatter(neurons[:, 0], neurons[:, 1], c='gray', alpha=0.5)
    ax.scatter(neurons[activated_mask, 0], neurons[activated_mask, 1], c='green', s=100)
    ax.scatter(p_a_blue[0], p_a_blue[1], c='blue', s=200, marker='*')
    write_circle_blue = Circle(p_a_blue, write_radius, color='blue', fill=False, linestyle='--', lw=2)
    ax.add_patch(write_circle_blue)
    
    activated_mask_blue = np.linalg.norm(neurons - p_a_blue, axis=1) <= write_radius
    ax.scatter(neurons[activated_mask_blue, 0], neurons[activated_mask_blue, 1], c='blue', s=100)
    
    # Neurons storing both
    both_mask = activated_mask & activated_mask_blue
    ax.scatter(neurons[both_mask, 0], neurons[both_mask, 1], c='purple', s=120, edgecolors='black', label='Storing Both Patterns')
    ax.set_title('Step 2: Write Pattern B (Blue)', fontsize=14)
    ax.legend()
    ax.set_aspect('equal')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)

    # --- Plot 3: Read Operation ---
    ax = axs[2]
    query_vec = np.array([6, 5])
    ax.scatter(neurons[:, 0], neurons[:, 1], c='gray', alpha=0.5)
    ax.scatter(neurons[activated_mask, 0], neurons[activated_mask, 1], c='green', s=100, alpha=0.3)
    ax.scatter(neurons[activated_mask_blue, 0], neurons[activated_mask_blue, 1], c='blue', s=100, alpha=0.3)
    ax.scatter(neurons[both_mask, 0], neurons[both_mask, 1], c='purple', s=120, alpha=0.3)

    ax.scatter(query_vec[0], query_vec[1], c='red', s=300, marker='X', label='Query (Noisy Cue)')
    read_circle = Circle(query_vec, write_radius, color='red', fill=False, linestyle='--', lw=2, label='Read Radius')
    ax.add_patch(read_circle)
    
    # Highlight read neurons
    read_mask = np.linalg.norm(neurons - query_vec, axis=1) <= write_radius
    ax.scatter(neurons[read_mask, 0], neurons[read_mask, 1], c='red', s=50, marker='s', label='Read Neurons')
    ax.set_title('Step 3: Read from a Query', fontsize=14)
    ax.legend()
    ax.set_aspect('equal')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    
    fig.suptitle('Visualizing the Sparse Distributed Memory (SDM) Operations', fontsize=18, y=1.02)
    plt.tight_layout()
    plt.show()

plot_sdm_write()

### The Read Operation: Retrieving a Memory

Now, you want to retrieve a memory. You don't have the book's exact address, but you have a noisy or incomplete cue (the **query** vector). In our analogy, you have a blurry photo of the book's cover.

The `read` operation works similarly to the write:

1.  You go to the location of your query (the red 'X' above).
2.  You draw a "read radius" circle.
3.  You ask every shelf (neuron) inside this circle: "Show me all the books you have!"
4.  You collect all the books. Since the query location is closer to the true address of the blue pattern than the green one, you will collect more copies of the blue book.
5.  You perform a **majority vote**: add up all the collected book vectors and, for each bit, keep the value most of them agree on. The signal from the most numerous book (blue) dominates, while the contributions of the other books largely cancel out as noise. The result is a clean, retrieved version of the blue book.

The key insight is that the **weight** of a retrieved memory is proportional to the **size of the intersection** between its write circle and the query's read circle. The larger the overlap, the more neurons contribute that memory, and the stronger its signal.
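The full read path can be sketched in NumPy as well. This block is self-contained (it rebuilds a small memory) and uses assumed, illustrative parameters rather than the paper's configuration: 256-bit patterns, 10,000 neurons, 10 stored patterns, and one shared read/write radius `d = 112`, which makes roughly 2-3% of neurons respond to any address. The names `within_radius` and `sdm_read` are hypothetical.

```python
import numpy as np

# Minimal SDM read sketch with majority-vote retrieval (illustrative parameters).
rng = np.random.default_rng(1)
n, num_neurons, num_patterns, d = 256, 10_000, 10, 112

neuron_addresses = rng.integers(0, 2, size=(num_neurons, n))
counters = np.zeros((num_neurons, n), dtype=int)

def within_radius(address):
    """Boolean mask of neurons whose address is within Hamming radius d."""
    return np.count_nonzero(neuron_addresses != address, axis=1) <= d

# Autoassociative writes: each pattern is stored at its own address.
patterns = rng.integers(0, 2, size=(num_patterns, n))
for p in patterns:
    counters[within_radius(p)] += 2 * p - 1

def sdm_read(query):
    """Sum the counters of all neurons in the read circle, then vote bit by bit."""
    summed = counters[within_radius(query)].sum(axis=0)
    return (summed > 0).astype(int)

# Noisy cue: flip 10% (25) of pattern 0's bits.
query = patterns[0].copy()
query[rng.choice(n, size=n // 10, replace=False)] ^= 1
retrieved = sdm_read(query)
print("bits wrong in query:    ", np.count_nonzero(query != patterns[0]))
print("bits wrong after 1 read:", np.count_nonzero(retrieved != patterns[0]))
```

With these settings a single read typically repairs most or all of the corrupted bits. Storing many more patterns in the same memory makes the cleanup less reliable, because more unrelated pointers leak into the read circle.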

---
## Chapter 2: The Core Insight - How Attention Emerges from SDM

Now we arrive at the central thesis of the paper. Let's compare the SDM read operation with the Transformer Attention mechanism.

| SDM Read Operation | Transformer Attention |
| :--- | :--- |
| **Query (ξ):** A noisy cue used for retrieval. | **Query (Q):** The current token's representation. |
| **Pattern Addresses (p_a):** Stored memory locations. | **Keys (K):** Representations of the tokens being attended to (in a GPT-style model, the current and preceding tokens). |
| **Pattern Pointers (p_p):** The actual content stored. | **Values (V):** Content-rich representations of those same tokens. |
| **Weighting Mechanism:** The size of the circle intersection between the query and each pattern address. | **Weighting Mechanism:** The `softmax` of the dot product between the Query and each Key. |
| **Output:** A weighted sum of all pattern pointers. | **Output:** A weighted sum of all Values. |
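The table's claim that both mechanisms output a weighted sum can be checked numerically. The sketch below (illustrative sizes, hypothetical helper `in_circle`) computes one SDM read two ways: neuron by neuron, and as a weighted sum of pattern pointers whose weights are the circle-intersection counts. The two agree exactly, because each stored pointer is counted once for every neuron that lies in both its write circle and the query's read circle. Note that these weights are raw counts; Attention instead uses softmax weights that sum to 1.

```python
import numpy as np

# Sketch: an SDM read equals a weighted sum of pattern pointers, with weights
# given by circle-intersection sizes. All sizes and the radius are assumptions.
rng = np.random.default_rng(2)
n, num_neurons, num_patterns, d = 64, 3000, 8, 26

neurons = rng.integers(0, 2, size=(num_neurons, n))
p_a = rng.integers(0, 2, size=(num_patterns, n))        # pattern addresses (keys)
p_p = rng.choice([-1, 1], size=(num_patterns, n))       # pattern pointers (values)
query = p_a[0] ^ (rng.random(n) < 0.1)                  # noisy version of address 0

def in_circle(center):
    """Boolean mask of neurons within Hamming radius d of `center`."""
    return np.count_nonzero(neurons != center, axis=1) <= d

# View 1: neuron by neuron, as in the library analogy.
write_masks = np.array([in_circle(a) for a in p_a])     # (patterns, neurons)
storage = write_masks.T.astype(int) @ p_p               # each neuron's superposition
read_mask = in_circle(query)
neuron_view = storage[read_mask].sum(axis=0)

# View 2: attention-style weighted sum of values with intersection weights.
weights = (write_masks & read_mask).sum(axis=1)         # circle-intersection sizes
attention_view = weights @ p_p

print("intersection weights:", weights)
print("identical outputs:", np.array_equal(neuron_view, attention_view))
```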

The parallel is uncanny. The only major difference appears to be the weighting mechanism. The paper's central claim is that this difference largely disappears: under suitable conditions, the circle-intersection weighting is closely approximated by a softmax, so **Attention approximates the SDM read operation.**

### The Exponential Link: The Softmax is Hidden in the Circles

In a high-dimensional binary space, a striking geometric property emerges: as the Hamming distance between the centers of two circles grows, the number of addresses (and therefore neurons) in their intersection **decays approximately exponentially**.

> **An Analogy for Exponential Decay:** Imagine shining a flashlight through thick fog. Each extra meter of fog absorbs the same fraction of whatever light is left, so the beam is bright at 1 meter, a faint glow at 5 meters, and almost gone at 20 meters. Losing a fixed *fraction* per unit of distance is exactly what exponential decay means, and it is what happens to the SDM circle intersection as a query moves away from a stored pattern.

The `softmax` function is a normalized exponential: `softmax(β·s)_i = exp(β·s_i) / Σ_j exp(β·s_j)`. Each weight therefore shrinks by a constant factor for every unit drop in the similarity `s_i`, and the inverse temperature `β` sets how steep that fall-off is. It is a mathematical formalization of the "fog" effect.

**Therefore, the `softmax` in Attention need not be viewed as an arbitrary heuristic: it closely approximates a geometric property of high-dimensional spaces that SDM relies on.**
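Here is a small numeric check of that exponential behavior, assuming ±1 vectors so that cosine similarity and Hamming distance `h` are related by `cos = 1 - 2h/n`. The dimension `n = 64`, the chosen distances, and `β = 12` are illustrative, and `key_at_distance` is a hypothetical helper. The log-ratio of any two softmax weights equals `-2β(h_i - h_j)/n`, so the weights decay exponentially in Hamming distance, mirroring the approximately exponential decay of the circle intersection.

```python
import numpy as np

# Sketch: softmax(beta * cos) over +/-1 keys decays exponentially in Hamming
# distance h, because cos = 1 - 2h/n. n, beta, and distances are assumptions.
rng = np.random.default_rng(3)
n, beta = 64, 12.0
query = rng.choice([-1, 1], size=n)

def key_at_distance(h):
    """Copy of `query` with exactly h randomly chosen signs flipped."""
    key = query.copy()
    key[rng.choice(n, size=h, replace=False)] *= -1
    return key

distances = np.array([0, 4, 8, 16, 32])
keys = np.stack([key_at_distance(h) for h in distances])
cosine = keys @ query / n                          # both vectors have norm sqrt(n)
weights = np.exp(beta * cosine) / np.exp(beta * cosine).sum()   # softmax

for h, c, w in zip(distances, cosine, weights):
    print(f"h={h:2d}  cos={c:+.3f}  softmax weight={w:.4f}")
```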

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import comb

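# This code follows the circle-intersection counting formula described in the
# paper's appendix to show the exponential relationship.

def sdm_circle_intersection(dv, n, d):
    """Count n-bit vectors within Hamming distance d of two centers that are dv apart.

    The centers agree on n - dv coordinates and differ on dv. A vector that agrees
    with both centers on `a` of the shared coordinates and with the first center on
    `c` of the differing ones lies in both circles iff
    a + c >= n - d  and  a + (dv - c) >= n - d.
    """
    dv = int(dv)  # accept NumPy integers coming from np.arange
    if dv > 2 * d:
        return 0
    # Both inequalities can hold only if 2 * (n - d - a) <= dv, i.e. a >= n - d - floor(dv / 2)
    a_min = max(0, n - d - dv // 2)
    a_max = n - dv

    total_intersection = 0
    for a in range(a_min, a_max + 1):
        c_min = max(0, n - d - a)
        c_max = min(dv, dv - (n - d - a))
        for c in range(c_min, c_max + 1):
            term = comb(n - dv, a, exact=True) * comb(dv, c, exact=True)
            total_intersection += term
    return total_intersection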

def plot_exponential_approximation():
    n = 64
    d = 15 # Corresponds to d_CD (Critical Distance) from paper's Table 1
    
    hamming_distances = np.arange(0, 2 * d + 5)
    intersection_sizes = [sdm_circle_intersection(dv, n, d) for dv in hamming_distances]

    # Normalize to get weights
    intersection_sizes = np.array(intersection_sizes)
    weights = intersection_sizes / intersection_sizes.sum()
    
    # Fit an exponential (softmax-like) curve
    cosine_similarity = 1 - 2 * hamming_distances / n
    # Beta is fitted via log-linear regression as per the paper
    valid_indices = (weights > 0) & (hamming_distances <= d)
    beta_fit = np.polyfit(cosine_similarity[valid_indices], np.log(weights[valid_indices]), 1)[0]
    softmax_approx = np.exp(beta_fit * cosine_similarity) / np.sum(np.exp(beta_fit * cosine_similarity))

    fig, ax = plt.subplots(figsize=(12, 7))
    
    ax.plot(hamming_distances, weights, 'o-', label='SDM Circle Intersection (Normalized)', color='blue')
    ax.plot(hamming_distances, softmax_approx, 'x--', label=f'Softmax Approx. (Fitted β={beta_fit:.2f})', color='green')

    ax.set_xlabel('Hamming Distance between Query and Pattern', fontsize=12)
    ax.set_ylabel('Normalized Weight', fontsize=12)
    ax.set_title('The Softmax is Hidden in the Geometry of High Dimensions', fontsize=16, pad=20)
    ax.legend()
    ax.grid(True, linestyle='--')

    # Create an inset plot with a log scale to show the exponential relationship
    inset_ax = ax.inset_axes([0.55, 0.55, 0.4, 0.4])
    inset_ax.semilogy(hamming_distances, weights, 'o-', color='blue')
    inset_ax.semilogy(hamming_distances, softmax_approx, 'x--', color='green')
    inset_ax.set_title('Log Scale View')
    inset_ax.set_xlabel('Hamming Distance')
    inset_ax.set_ylabel('Log(Weight)')
    inset_ax.grid(True, linestyle=':')
    
    plt.show()

plot_exponential_approximation()

The plot above, a recreation of Figure 3 from the paper, shows this relationship. The blue line is the actual normalized weight from the SDM circle intersection, and the green dashed line is a fitted softmax. They align closely in the region that matters most (low Hamming distance, where the query is near the stored pattern) and drift apart as the distance approaches `2d`, where the intersection shrinks to zero. The inset confirms the relationship: on a log-scale y-axis, an approximately straight line indicates approximately exponential decay.
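Because the intersection-counting formula is easy to get subtly wrong, it is worth checking against brute force in a space small enough to enumerate. This sketch re-derives the count independently (the helper names are hypothetical) and compares it with an exhaustive scan of all `2^n` vectors for `n = 12`, `d = 4`.

```python
import itertools
import numpy as np
from scipy.special import comb

# Sanity check for small n: combinatorial circle-intersection count vs. brute force.
def intersection_formula(dv, n, d):
    """Vectors within Hamming distance d of two centers that are dv apart."""
    if dv > 2 * d:
        return 0
    total = 0
    # a = agreements with both centers on the n - dv shared coordinates,
    # c = agreements with the first center on the dv differing coordinates.
    for a in range(max(0, n - d - dv // 2), n - dv + 1):
        for c in range(max(0, n - d - a), min(dv, dv - (n - d - a)) + 1):
            total += comb(n - dv, a, exact=True) * comb(dv, c, exact=True)
    return total

def intersection_brute_force(dv, n, d):
    """Enumerate every n-bit vector and count those inside both circles."""
    center_a = np.zeros(n, dtype=int)
    center_b = np.zeros(n, dtype=int)
    center_b[:dv] = 1                              # differs from center_a in dv bits
    all_vectors = np.array(list(itertools.product([0, 1], repeat=n)))
    dist_a = np.count_nonzero(all_vectors != center_a, axis=1)
    dist_b = np.count_nonzero(all_vectors != center_b, axis=1)
    return int(np.count_nonzero((dist_a <= d) & (dist_b <= d)))

n, d = 12, 4
for dv in range(0, 2 * d + 3):
    print(dv, intersection_formula(dv, n, d), intersection_brute_force(dv, n, d))
```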

--- 
## Chapter 3: Empirical Proof - Do Real Transformers Behave like SDM?

The theory is elegant, but does it hold up in practice? The researchers examined trained Transformer models to see whether they learn to operate like an optimally configured SDM.

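An SDM's behavior depends on its read/write radius (`d`), which plays the role of Attention's inverse-temperature parameter (`β`): a larger radius corresponds to a smaller `β` (softer, broader weights), and a smaller radius to a larger `β` (sharper weights). Different `d` values are optimal for different goals:

-   **Large `d` (Small `β`):** You read from many neurons. This is good for noisy queries, as you can average out more noise. This is the **Critical Distance (`d_CD`)** optimal setting, which maximizes how far a query can drift from its stored pattern and still retrieve it.
-   **Intermediate `d`:** This is the **Signal-to-Noise Ratio (`d_SNR`)** optimal setting, which maximizes the signal-to-noise ratio of a read.
-   **Small `d` (Large `β`):** You read from very few, highly relevant neurons. This is good for storing a huge number of memories without interference, but it is not robust to noise. This is the **Memory Capacity (`d_Mem`)** optimal setting.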
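The `d`-to-`β` correspondence can be made concrete by fitting a softmax slope to the exact intersection counts for several radii. This is a sketch under assumptions: `n = 64`, a log-linear fit of intersection size against cosine similarity over Hamming distances `0..d`, and hypothetical helper names. Larger radii should yield smaller fitted `β`, matching the bullets above.

```python
import numpy as np
from scipy.special import comb

# Sketch: fitted softmax beta for several SDM radii d (assumed n = 64).
def intersection(dv, n, d):
    """Exact count of n-bit vectors within radius d of two centers dv apart."""
    if dv > 2 * d:
        return 0
    return sum(comb(n - dv, a, exact=True) * comb(dv, c, exact=True)
               for a in range(max(0, n - d - dv // 2), n - dv + 1)
               for c in range(max(0, n - d - a), min(dv, dv - (n - d - a)) + 1))

def fitted_beta(n, d):
    """Slope of log(intersection size) vs. cosine similarity over h = 0..d."""
    h = np.arange(0, d + 1)
    log_sizes = np.log([float(intersection(int(dv), n, d)) for dv in h])
    cosine = 1 - 2 * h / n
    slope, _intercept = np.polyfit(cosine, log_sizes, 1)
    return slope

n = 64
for d in (5, 15, 25):
    frac = sum(comb(n, k, exact=True) for k in range(d + 1)) / 2 ** n
    print(f"d={d:2d}  fraction of address space in one circle={frac:.2e}  "
          f"fitted beta={fitted_beta(n, d):.1f}")
```

The fraction column also shows how sparse activation is: for small `d`, only a minute fraction of the `2^64` possible addresses falls inside a circle, so few neurons respond to any read.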

The question is: which strategy does a trained Transformer learn? The paper analyzed a variant of the Transformer where `β` is a learnable parameter.
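A minimal sketch of what "learnable `β`" means, assuming the Query-Key Normalization style of attention: queries and keys are L2-normalized so that their dot product is a cosine similarity, and a scalar `β` replaces the usual `1/sqrt(d_k)` scaling. Here `β` is passed in rather than trained, and `qk_norm_attention` is a hypothetical name, not a library function. Raising `β` concentrates weight on the best-matching key, the attention analogue of shrinking the SDM radius.

```python
import numpy as np

def qk_norm_attention(Q, K, V, beta):
    """Softmax attention on L2-normalized queries and keys, scaled by beta.

    Sketch only: in a trained model beta would be a learned parameter.
    """
    Qn = Q / np.linalg.norm(Q, axis=-1, keepdims=True)
    Kn = K / np.linalg.norm(K, axis=-1, keepdims=True)
    scores = beta * Qn @ Kn.T                      # cosine similarities times beta
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(4)
K = rng.normal(size=(6, 16))
V = rng.normal(size=(6, 16))
Q = K[:1] + 0.3 * rng.normal(size=(1, 16))         # noisy version of key 0

for beta in (1.0, 10.0, 100.0):
    _, w = qk_norm_attention(Q, K, V, beta)
    print(f"beta={beta:5.1f}  weight on key 0={w[0, 0]:.3f}")
```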

### Finding: Transformers Learn to be Robust to Noise

The histogram below illustrates the shape of this result (after Figure 4 in the paper); note that it is drawn from **synthetic data**, not the paper's measurements. In the paper, the learned `β` coefficients do not cluster around the value for maximum memory capacity. Instead, they **cluster in the range that is optimal for handling noisy queries** (between the `β_CD` and `β_SNR` values).

This suggests that training pushes the Transformer toward robust retrieval from imperfect cues rather than toward storing as many patterns as possible, which is exactly the regime SDM was designed for.

In [None]:
def plot_learned_betas():
    # Synthetic, illustrative data only (NOT the paper's measurements). Its shape
    # loosely mimics Figure 4 in the paper, which shows learned beta coefficients
    # from a Query-Key Normalization Transformer.
    np.random.seed(0)
    # Mixture placed roughly between the CD and SNR optimal values.
    learned_betas = np.concatenate([
        np.random.normal(loc=12, scale=2, size=200),
        np.random.normal(loc=18, scale=3, size=150),
        np.random.uniform(low=10, high=25, size=100)
    ])
    learned_betas = learned_betas[(learned_betas > 0) & (learned_betas < 40)]
    
    # Optimal beta values corresponding to d* values from Table 1
    beta_cd = 10.1 # Critical Distance (robust to noise)
    beta_snr = 15.9 # Signal-to-Noise Ratio
    beta_mem = 35.5 # Max Memory (not robust to noise)
    
    plt.figure(figsize=(12, 7))
    plt.hist(learned_betas, bins=40, density=True, alpha=0.7, label='Learned β Coefficients')
    
    plt.axvline(beta_cd, color='r', linestyle='--', lw=2, label=f'β_CD = {beta_cd:.1f} (Optimal for Noisy Queries)')
    plt.axvline(beta_snr, color='g', linestyle='--', lw=2, label=f'β_SNR = {beta_snr:.1f} (Optimal Signal-to-Noise)')
    plt.axvline(beta_mem, color='purple', linestyle=':', lw=2, label=f'β_Mem = {beta_mem:.1f} (Optimal for Memory Capacity)')

    plt.xlabel('Learned β Coefficient', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.title('Distribution of Learned β (Illustrative Synthetic Data)', fontsize=16, pad=20)
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.5)
    
    plt.annotate('Transformers learn β values\nthat prioritize robustness to noise,\nnot maximum memory capacity.',
                 xy=(15, 0.05), xytext=(22, 0.08),
                 arrowprops=dict(facecolor='black', shrink=0.05),
                 fontsize=12, bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3))

    plt.show()

plot_learned_betas()

--- 
## Chapter 4: The Brain Connection - The Cerebellum as an SDM Circuit

The most exciting part of this connection is its biological plausibility. The paper argues that the specific wiring of the **cerebellum**, a brain region containing ~80% of all neurons, provides a direct neural implementation of SDM.

Here’s a simplified mapping:

| SDM Component | Cerebellar Cell Type | Function in the Circuit |
| :--- | :--- | :--- |
| **Neurons (Shelves)** | **Granule Cells** | These are incredibly numerous (~50 billion in humans) and sparsely activated. Their inputs determine their "address" in the memory space. |
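| **Pattern/Query Input** | **Mossy Fibers** | These fibers are the main input to the cerebellum. They carry the current state (the Query or the Pattern Address) to the Granule Cells; each Granule Cell receives input from only a few (typically about four) mossy fibers, so the signal is spread across the granule layer rather than delivered whole to every cell. |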
| **Stored Pattern (Book)** | **Parallel Fibers** | Each activated Granule Cell sends out a long axon called a Parallel Fiber. The synapses these fibers make with Purkinje Cells are where the memory content (the Value) is physically stored. |
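| **The "Write" Signal** | **Climbing Fibers** | A separate, powerful input: each Purkinje Cell receives a single climbing fiber that wraps around its dendrites. When it fires at the same time as a Parallel Fiber, it triggers synaptic plasticity (classically long-term depression, LTD), effectively "writing" the memory into that synapse. Because the content arrives on this pathway while the address arrives via Mossy Fibers, this is the biological counterpart of having separate Key and Value pathways. |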
| **Output / Majority Vote** | **Purkinje Cells** | Each Purkinje Cell receives input from up to 200,000 Parallel Fibers. It sums all these inputs (the weighted sum of Values) and decides whether to fire, performing the final step of memory retrieval. |

This isn't just a loose analogy. The specific wiring, with Mossy Fibers driving Granule Cells and with Parallel Fibers and Climbing Fibers converging on Purkinje Cells, provides a compelling physical architecture for implementing the SDM/Attention mechanism in the brain.
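The mapping in the table can be turned into a toy simulation. Everything here is a simplifying assumption rather than physiology or the paper's model: granule cells are idealized as SDM neurons with random binary addresses over 128 mossy-fiber inputs, each of 32 Purkinje cells receives one climbing fiber carrying a ±1 target bit, and plasticity is a clean ±1 update. `write`, `read`, and `active_granule_cells` are hypothetical names.

```python
import numpy as np

# Toy "cerebellum as SDM" sketch. Sizes, names, and the learning rule are assumptions.
rng = np.random.default_rng(5)
n_mossy, n_granule, n_purkinje, d = 128, 10_000, 32, 52

# Granule cells as SDM neurons: a fixed binary "address" over mossy-fiber inputs;
# a cell fires when the input is within Hamming radius d of its address.
granule_addresses = rng.integers(0, 2, size=(n_granule, n_mossy))
W = np.zeros((n_granule, n_purkinje))            # parallel-fiber -> Purkinje synapses

def active_granule_cells(mossy_input):
    return np.count_nonzero(granule_addresses != mossy_input, axis=1) <= d

def write(mossy_input, climbing_fiber_signal):
    """Change only synapses of active parallel fibers, by the climbing-fiber (+/-1) content."""
    W[active_granule_cells(mossy_input)] += climbing_fiber_signal

def read(mossy_input):
    """Each Purkinje cell sums its active parallel-fiber inputs, then thresholds."""
    return np.sign(W[active_granule_cells(mossy_input)].sum(axis=0))

# Store three (mossy-fiber context -> climbing-fiber target) associations.
contexts = rng.integers(0, 2, size=(3, n_mossy))
targets = rng.choice([-1.0, 1.0], size=(3, n_purkinje))
for ctx, tgt in zip(contexts, targets):
    write(ctx, tgt)

# Recall the first target from a context with 10 of its 128 bits flipped.
noisy_context = contexts[0].copy()
noisy_context[rng.choice(n_mossy, size=10, replace=False)] ^= 1
print("Purkinje outputs match target:", np.array_equal(read(noisy_context), targets[0]))
```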

In [None]:
def plot_cerebellum_circuit():
    fig, ax = plt.subplots(figsize=(14, 9))

    # Cell positions
    mossy_y = 0
    granule_y = 2
    purkinje_y = 6
    climbing_y = 2.5 # Comes from below

    # Draw cells
    ax.text(1, mossy_y, 'Mossy Fiber Input\n(Query / Key Address)', ha='center', va='center', bbox=dict(boxstyle='round,pad=0.5', fc='lightblue'))
    granule_cells = [ax.add_patch(Circle((x, granule_y), 0.3, color='orange')) for x in [3, 5, 7]]
    ax.text(5, granule_y - 0.5, 'Granule Cells\n(SDM Neurons)', ha='center', va='top', fontsize=12)  # label below the granule cells
    ax.add_patch(Circle((5, purkinje_y), 0.8, color='purple'))
    ax.text(5, purkinje_y, 'Purkinje Cell\n(Weighted Sum)', ha='center', va='center', color='white', fontsize=12)

    # Draw connections
    for gc in granule_cells:
        # Mossy to Granule
        ax.add_patch(FancyArrowPatch((1, mossy_y + 0.5), (gc.center[0], gc.center[1] - 0.3), 
                                     connectionstyle='arc3,rad=0.2', color='gray', arrowstyle='->', lw=2))
        # Parallel Fibers to Purkinje
        ax.add_patch(FancyArrowPatch((gc.center[0], gc.center[1] + 0.3), (5, purkinje_y - 0.8), 
                                     connectionstyle='arc3,rad=-0.1', color='orange', arrowstyle='->', lw=2))
    
    ax.text(2.5, 4, 'Parallel Fibers\n(Value Vectors)', ha='center', fontsize=12, color='darkorange')

    # Climbing Fiber
    ax.add_patch(FancyArrowPatch((9, climbing_y), (5.8, purkinje_y), 
                                 connectionstyle='arc3,rad=0.3', color='red', arrowstyle='-|>', lw=3, mutation_scale=20))
    ax.text(9, climbing_y-0.5, 'Climbing Fiber Input\n(The \"Write\" Signal)', ha='center', va='center', bbox=dict(boxstyle='round,pad=0.5', fc='pink'))

    # Output
    ax.add_patch(FancyArrowPatch((5, purkinje_y + 0.8), (5, 8.5), color='purple', arrowstyle='->', lw=3))
    ax.text(5, 8.8, 'Output\n(Retrieved Memory)', ha='center', va='center', bbox=dict(boxstyle='round,pad=0.5', fc='plum'))

    ax.set_xlim(0, 10)
    ax.set_ylim(-1, 10)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('A Simplified Diagram of the Cerebellar Circuit as an SDM', fontsize=16, pad=20)
    plt.show()
    
plot_cerebellum_circuit()

## Conclusion: A Two-Way Street Between Neuroscience and AI

This work does more than just explain why Attention works. It builds a powerful, two-way bridge between deep learning and neuroscience.

1.  **For AI:** It provides a theoretical, first-principles justification for the Attention mechanism. Understanding Attention as a geometric memory retrieval process can inspire new architectures and highlight potential failure modes. For example, the paper's analysis of why interpreting attention weights can be misleading without L2 normalizing the Value vectors is a direct, practical insight derived from this framework.

2.  **For Neuroscience:** The empirical success of the Transformer shows that an SDM-like retrieval rule is computationally powerful, which strengthens the case for taking SDM seriously as a theory of cerebellar function. It gives neuroscientists a computationally precise model to test and a reason to look for Attention-like operations in other brain regions.

Ultimately, this research suggests that when we build powerful learning systems, we may inadvertently rediscover the elegant and efficient solutions that evolution has already honed over millions of years.