### A Tutorial on Modern Hopfield Networks

#### Before You Begin

This tutorial explores the concepts introduced in the paper [**"Dense Associative Memory for Pattern Recognition"**](https://arxiv.org/abs/1606.01164) by Dmitry Krotov and John J. Hopfield (2016). You will need to implement key formulas from the paper to complete the exercises. Reading it beforehand is highly recommended to understand the theoretical foundation.

#### What are Hopfield Networks?

Hopfield Networks are a form of recurrent neural network that serves as **associative memory**. The core idea is simple: you store a set of patterns in the network, and when you later provide a partial or corrupted version of one of those patterns, the network can retrieve the original, complete pattern.

Think of it like human memory. If you see a blurry picture of a cat, your brain can often fill in the missing details and recognize it as a cat. Hopfield networks attempt to model this process mathematically.

#### From Classical to Modern

The **Classical Hopfield Network**, proposed by John Hopfield in 1982, had a significant limitation: its **storage capacity**. For random, uncorrelated patterns it could reliably store only a small number of them, approximately 14% of the number of neurons in the network ($K_{max} \approx 0.14N$).

This is where **Modern Hopfield Networks** come in. Based on work by Krotov and Hopfield (2016), these networks use a generalized energy function that allows for a much greater storage capacity, scaling polynomially with the number of neurons ($K_{max} \propto N^{n-1}$, where $n$ is the degree of the polynomial used in the energy function). This enhancement makes them capable of tackling more complex problems, like recalling high-resolution images.
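To get a feel for the two scaling laws, here is a small arithmetic sketch. The helper names `classical_capacity` and `dense_capacity_scaling` are hypothetical, and the second one deliberately omits the proportionality constant, so it shows only how the bound grows with $N$ and $n$, not an actual pattern count.

```python
def classical_capacity(n_neurons: int) -> float:
    """Approximate capacity of a classical Hopfield network: 0.14 * N."""
    return 0.14 * n_neurons


def dense_capacity_scaling(n_neurons: int, degree: int) -> int:
    """N^(n-1): how capacity scales with N, up to an unspecified constant."""
    return n_neurons ** (degree - 1)


n_neurons = 28 * 28  # one neuron per pixel of a 28x28 image

print(f"classical: about {classical_capacity(n_neurons):.0f} patterns")
for degree in (2, 3, 4):
    scaling = dense_capacity_scaling(n_neurons, degree)
    print(f"degree {degree}: capacity scales like {scaling:,}")
```

For degree 2 the scaling is linear in $N$, as in the classical model; each additional degree multiplies the bound by another factor of $N$.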

In this tutorial, we will:

1.  Implement and understand the Classical Hopfield Network.
2.  Demonstrate its limitations on a simple image-recall task.
3.  Implement the Modern Hopfield Network.
4.  Show its superior performance on the same task.

Let's start with some utility functions for handling and displaying images.

In [None]:
from typing import Self
from random import randint
from collections.abc import Callable

import PIL.Image
import torch
from torchvision.utils import make_grid
from torchvision.datasets import MNIST
from torchvision.transforms.functional import pil_to_tensor
from matplotlib import pyplot as plt

In [None]:
def preprocess(image: PIL.Image.Image) -> torch.Tensor:
    """Converts a PIL image to a binary tensor."""
    return pil_to_tensor(image).squeeze(0).div(255).round().to(torch.int32)


def image_to_spin_pattern(image: torch.Tensor) -> torch.Tensor:
    """Converts a binary image tensor (0s and 1s) to a spin pattern (-1s and 1s)."""
    return (image * 2 - 1).to(torch.int32)


def spin_pattern_to_image(tensor: torch.Tensor) -> torch.Tensor:
    """Converts a spin pattern back to a binary image tensor."""
    return ((tensor + 1) / 2).to(torch.int32)


def obscure(image: torch.Tensor) -> torch.Tensor:
    """Obscures the bottom half of an image."""
    rows = image.shape[0]
    obscured_image = image.clone()
    obscured_image[rows // 2 :, ...] = 0
    return obscured_image


def display(images: torch.Tensor | list[torch.Tensor], title: str | None = None) -> None:
    """Displays a single image or a grid of images."""
    if isinstance(images, list):
        images = torch.stack(images)

    if images.min() < 0:
        images = spin_pattern_to_image(images)

    match images.ndim:
        case 3:
            grid = (
                make_grid(images.unsqueeze(1) * 255, pad_value=255)
                .permute((1, 2, 0))
                .numpy()
            )
            plot = plt.imshow(grid, cmap="gray")

        case 2:
            plot = plt.imshow(images.numpy(), cmap="gray")

        case other:
            raise ValueError(f"Invalid input dimensionality. Expected 2 or 3-dimensional tensor, got {other}")

    plot.axes.set_title(title or "")
    plot.axes.set_axis_off()
    plt.show()

## Part 1: The Classical Hopfield Network

The classical model operates based on an **energy function**. The network's state is represented by a vector of "spins" (neurons), each being either +1 or -1. The stored patterns are considered stable, low-energy states. When a new pattern is presented, the network updates its state iteratively to descend the energy landscape until it settles into the nearest energy minimum, which corresponds to one of the stored memories.

### Key Concepts:

1.  **Energy Function**: The energy $E$ of a state $\sigma$ is given by a quadratic formula:
    $$E = -\frac{1}{2}\sum_{i,j} T_{ij}\sigma_i\sigma_j$$
    The network seeks to find a state $\sigma$ that minimizes this value.

2.  **Storing Patterns (Hebbian Learning)**: The weights $T_{ij}$ are determined by the patterns to be stored ($\xi^1, \xi^2, ..., \xi^K$). The learning rule is a form of Hebb's rule: "neurons that fire together, wire together." In practice, we sum the outer products of the patterns:
    $$T_{ij} = \sum_{\mu=1}^{K} \xi_i^\mu \xi_j^\mu$$
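    This is what the `fit` method below implements, with two additions: the weights are divided by the number of stored patterns $K$, and the diagonal is set to zero so that no neuron connects to itself.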

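3.  **Retrieving Patterns (Update Rule)**: The network retrieves a pattern by updating its neurons one at a time (asynchronously). A neuron $\sigma_i$ keeps its current value if flipping it would raise the energy, and flips otherwise. Equivalently, it takes the sign of its weighted input, $\sigma_i \leftarrow \mathrm{sign}\left(\sum_j T_{ij}\sigma_j - \theta_i\right)$, where $\theta_i$ is the neuron's firing threshold (`biases` in the code). With symmetric weights, a zero diagonal, and zero thresholds, such an update never increases $E$. The `predict_async` method performs one sweep of this update over all neurons.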
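The three ingredients above can be tried out on a toy problem before tackling the exercises. The following is a minimal sketch, not the tutorial's reference solution: the helper names `hebbian_weights`, `energy`, and `async_sweep` are hypothetical, thresholds are assumed to be zero, the weights are left unnormalized, and the two 6-neuron patterns are made up for illustration.

```python
import torch


def hebbian_weights(patterns: torch.Tensor) -> torch.Tensor:
    """T_ij = sum_mu xi_i^mu xi_j^mu, with self-connections removed."""
    weights = patterns.T @ patterns
    weights.fill_diagonal_(0)
    return weights


def energy(weights: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
    """E = -1/2 * sum_ij T_ij sigma_i sigma_j."""
    return -0.5 * (state @ weights @ state)


def async_sweep(weights: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
    """Updates each neuron once, in index order (assumes zero thresholds)."""
    state = state.clone()
    for i in range(state.numel()):
        state[i] = 1.0 if weights[i] @ state >= 0 else -1.0
    return state


# Two made-up patterns over six neurons (illustrative data only).
patterns = torch.tensor(
    [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0], [1.0, 1.0, 1.0, -1.0, -1.0, -1.0]]
)
weights = hebbian_weights(patterns)

# Corrupt the first pattern by flipping a single neuron.
corrupted = patterns[0].clone()
corrupted[0] = -corrupted[0]

recalled = async_sweep(weights, corrupted)
print("energy of stored pattern:   ", energy(weights, patterns[0]).item())
print("energy of corrupted pattern:", energy(weights, corrupted).item())
print("recovered stored pattern:   ", torch.equal(recalled, patterns[0]))
```

In this toy setting the stored pattern sits at a lower energy than its corrupted copy, and a single sweep moves the corrupted state back to it.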

In [None]:
class ClassicalHopfieldNetwork:
    size: int
    weights: torch.Tensor
    biases: torch.Tensor

    def __init__(self, size: int, neuron_fire_threshold: float) -> None:
        assert 0 < neuron_fire_threshold < 1

        self.size = size
        self.weights = torch.zeros((size, size), dtype=torch.float32)
        self.biases = torch.full((size,), neuron_fire_threshold, dtype=torch.float32)

    def fit(self, data: torch.Tensor) -> Self:
        """Stores patterns in the network using Hebbian learning."""
        match data.shape:
            case n_records, n_dimensions:
                assert n_records > 0
                assert n_dimensions == self.size

            case other:
                raise ValueError(f"Invalid input shape. Expected 2-dimensional tensor (n_records x n_features), got {other}")

        assert data.dtype == torch.int32

        # TODO:
        ########## YOUR CODE GOES HERE (EXERCISE 1) ##########
        # Your job is to fill in the Hebbian learning rule.
        # 1. Calculate the weights using the formula T_ij = sum(xi_i * xi_j).
        # 2. Normalize the weights by the number of records.
        # 3. Set the diagonal of the weights matrix to 0 (no self-connections).
        ######################################################

        return self

    def predict(self, data: torch.Tensor) -> torch.Tensor:  # TODO: Remove
        """Synchronous update rule (less common)."""
        assert data.shape == (self.size,)
        assert data.dtype == torch.int32

        state = data
        new_state = torch.sign(self.weights @ state.to(torch.float32) - self.biases).to(
            torch.int32
        )

        # Iterate until the state stabilizes or starts oscillating between
        # a configuration and its negation.
        while not (torch.equal(state, new_state) or torch.equal(state, -new_state)):
            state = new_state
            new_state = torch.sign(
                self.weights @ state.to(torch.float32) - self.biases
            ).to(torch.int32)

        return state

    def predict_async(self, data: torch.Tensor) -> torch.Tensor:
        """Asynchronously updates neurons to minimize energy."""
        assert data.shape == (self.size,)
        assert data.dtype == torch.int32

        state = data.reshape(-1, 1).to(torch.float32, copy=True)

        # Iterate through each neuron and update its state.
        for i in range(self.size):
            # Calculate the weighted sum of inputs for neuron i.
            preactivation = self.weights[i, :] @ state - self.biases[i]
            # Update the neuron's state based on the sign.
            activation = torch.sign(preactivation)
            state[i] = activation

        return state.flatten().to(torch.int32)

    def predict_async_stochastic(  # TODO: Remove
        self, data: torch.Tensor, max_iterations: int = 1000
    ) -> torch.Tensor:
        """Updates neurons in a random order."""
        assert data.shape == (self.size,)
        assert data.dtype == torch.int32

        state = data.reshape(-1, 1).to(torch.float32, copy=True)

        # Update neurons in a random order for a number of iterations.
        for _ in range(max_iterations):
            i = randint(0, self.size - 1)
            preactivation = self.weights[i, :] @ state - self.biases[i]
            activation = torch.sign(preactivation)
            state[i] = activation

        return state.flatten().to(torch.int32)

### A Simple Example: Storing and Recalling 3x3 Patterns

Let's test our classical network on a very simple task. We'll create two 3x3 patterns, a "circle" and a "cross," and store them in a 9-neuron network.

In [None]:
circle = torch.tensor(((-1, -1, -1), (-1, 1, -1), (-1, -1, -1)), dtype=torch.int32)
cross = torch.tensor(((-1, 1, -1), (1, -1, 1), (-1, 1, -1)), dtype=torch.int32)
cross_and_circle_network_input = torch.stack((circle.flatten(), cross.flatten()))

In [None]:
display([cross, circle])

In [None]:
cross_and_circle_network = ClassicalHopfieldNetwork(9, 0.3).fit(cross_and_circle_network_input)

Now, let's see if the network can recall one of the patterns from a corrupted input. We'll start with a completely blank image and see which memory it converges to.

In [None]:
blank_image = torch.full((3, 3), -1, dtype=torch.int32)
recalled_image = cross_and_circle_network.predict_async(blank_image.flatten()).reshape(3, 3)
display([blank_image, recalled_image], "Crosses & Circles - example of inference")

In [None]:
display(cross_and_circle_network.predict(torch.zeros(9, dtype=torch.int32)).reshape(3, 3))

The 3x3 example was trivial. Let's try a more challenging task: recalling handwritten digits from the MNIST dataset. We will store one example for each of the 10 digits (0 through 9) in the network. Then, we will take one of the digits, obscure its bottom half, and ask the network to reconstruct the original image.

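This is where we expect to see the limitations of the classical model. Ten patterns is well below the $0.14N \approx 110$ estimate for a network of 784 neurons (28x28 pixels), but that estimate assumes random, uncorrelated patterns. MNIST digits share a large, mostly dark background and are therefore strongly correlated, which makes them much harder to store without interference.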
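How much harder correlated patterns are can be gauged by comparing pattern overlaps. The sketch below uses synthetic data only: `mean_abs_overlap` is a hypothetical helper, and the 15% fraction of active neurons is an illustrative assumption standing in for "mostly dark" images, not a measured MNIST statistic.

```python
import torch


def mean_abs_overlap(patterns: torch.Tensor) -> float:
    """Mean |xi^mu . xi^nu| / N over all pairs of distinct patterns."""
    n_patterns, n_neurons = patterns.shape
    overlaps = (patterns @ patterns.T) / n_neurons
    off_diagonal = ~torch.eye(n_patterns, dtype=torch.bool)
    return overlaps[off_diagonal].abs().mean().item()


generator = torch.Generator().manual_seed(0)
n_patterns, n_neurons = 10, 28 * 28

# Unbiased random spins: each neuron is +1 or -1 with equal probability.
random_patterns = (
    torch.randint(0, 2, (n_patterns, n_neurons), generator=generator).float() * 2 - 1
)

# Sparse spins: roughly 15% of neurons are +1, the rest are -1 (assumed value).
sparse_patterns = (
    torch.rand((n_patterns, n_neurons), generator=generator) < 0.15
).float() * 2 - 1

print("random patterns:", mean_abs_overlap(random_patterns))
print("sparse patterns:", mean_abs_overlap(sparse_patterns))
```

Unbiased random patterns are nearly orthogonal, while patterns that share a common background overlap heavily. In the classical update rule, every such overlap adds crosstalk from the other stored memories.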

In [None]:
def extract_digit_examples(mnist: MNIST) -> dict[int, torch.Tensor]:
    """Extracts one example image for each digit from the MNIST dataset."""
    all_digits = dict()

    for image, label in mnist:
        all_digits.setdefault(label, image)

        if all(i in all_digits for i in range(10)):
            break

    return all_digits

In [None]:
mnist_train = MNIST("./data", train=True, transform=preprocess, download=True)
mnist_examples = extract_digit_examples(mnist_train)

In [None]:
display([mnist_examples[digit] for digit in range(10)], title="Memories stored in the network")

In [None]:
mnist_network_input = image_to_spin_pattern(torch.stack([digit.flatten() for digit in mnist_examples.values()]))

In [None]:
mnist_network = ClassicalHopfieldNetwork(
    size=28 * 28,
    neuron_fire_threshold=0.5,
).fit(mnist_network_input)

In [None]:
digit = image_to_spin_pattern(mnist_examples[4])
digit_obscured = image_to_spin_pattern(obscure(mnist_examples[4]))
digit_recalled = mnist_network.predict_async(digit_obscured.flatten()).reshape(28, 28)

In [None]:
display([digit, digit_obscured, digit_recalled], title="Classical Hopfield Network - MNIST digit recall")

As you can see, the classical network struggles to produce a clean reconstruction. The recalled image is noisy and distorted. This is a classic symptom of the network's limited capacity: the energy minima for the 10 digits interfere with each other.

## Part 2: The Modern Hopfield Network

Modern Hopfield Networks, also known as **Dense Associative Memories**, solve the capacity problem by generalizing the energy function. Instead of a simple quadratic function, they use a function $F$ that can create a much "sharper" and more complex energy landscape.

### The New Energy Function

The energy is now defined as:
$$E = -\sum_{\mu=1}^{K} F(\xi^\mu \cdot \sigma)$$
Here, $\xi^\mu \cdot \sigma$ is the dot product (or overlap) between a stored memory and the current state. The function $F$ determines the shape of the energy wells.

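By choosing a rapidly growing function for $F$, like a high-degree polynomial ($F(x) = x^n$), the network can store vastly more patterns. For a polynomial of degree $n$, the capacity grows as $K_{max} \approx \alpha_n N^{n-1}$, where $\alpha_n$ is a numerical constant. For $n = 2$ the energy reduces to the quadratic form of the classical model.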
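One way to see why a higher degree gives sharper energy wells is to look at how much each memory contributes to $\sum_\mu F(\xi^\mu \cdot \sigma)$. The sketch below is illustrative only: `memory_shares` is a hypothetical helper and the three overlap values are made up.

```python
import torch


def memory_shares(overlaps: torch.Tensor, degree: int) -> torch.Tensor:
    """Fraction of sum_mu F(overlap_mu) owed to each memory, for F(x) = x^degree."""
    # float64 keeps high powers of the overlaps finite.
    contributions = overlaps.to(torch.float64) ** degree
    return contributions / contributions.sum()


# Made-up overlaps of one state with three stored memories.
overlaps = torch.tensor([20.0, 12.0, 8.0])

quadratic_shares = memory_shares(overlaps, degree=2)
sharp_shares = memory_shares(overlaps, degree=15)

print("degree 2: ", quadratic_shares)
print("degree 15:", sharp_shares)
```

With the quadratic function, the two weaker memories still account for roughly a third of the total. At degree 15 the best-matching memory dominates almost completely, so the other memories barely disturb the retrieval.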

### The New Retrieval Rule

With the new energy function, the update rule for a neuron $\sigma_i$ becomes: **flip the neuron if and only if doing so decreases the total energy of the system**.

The `predict_async` method in our `ModernHopfieldNetwork` class implements this by calculating the total energy of the network with neuron $i$ as `+1` and as `-1`, and then choosing the state that results in a lower energy.
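Comparing the two energies is equivalent to taking the sign of their difference, which is how the Krotov and Hopfield paper writes the update: $\sigma_i \leftarrow \mathrm{sign}\left[\sum_\mu \left(F(\xi_i^\mu + \sum_{j \neq i} \xi_j^\mu \sigma_j) - F(-\xi_i^\mu + \sum_{j \neq i} \xi_j^\mu \sigma_j)\right)\right]$. The sketch below applies that form to a toy problem. It is not the class method from this tutorial: `dense_sweep` is a hypothetical standalone helper, thresholds are ignored, and the patterns are made up.

```python
import torch


def dense_sweep(patterns: torch.Tensor, state: torch.Tensor, degree: int) -> torch.Tensor:
    """One asynchronous sweep for F(x) = x^degree (hypothetical helper)."""
    patterns = patterns.to(torch.float64)
    state = state.to(torch.float64).clone()
    for i in range(state.numel()):
        # Overlap of every memory with the state, leaving neuron i out.
        partial = patterns @ state - patterns[:, i] * state[i]
        # Energy gap E(sigma_i = -1) - E(sigma_i = +1), summed over memories.
        drive = (
            (partial + patterns[:, i]) ** degree - (partial - patterns[:, i]) ** degree
        ).sum()
        state[i] = 1.0 if drive >= 0 else -1.0
    return state


# Two made-up patterns over six neurons (illustrative data only).
patterns = torch.tensor(
    [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0], [1.0, 1.0, 1.0, -1.0, -1.0, -1.0]]
)

# Corrupt the first pattern by flipping a single neuron.
corrupted = patterns[0].clone()
corrupted[0] = -corrupted[0]

recalled = dense_sweep(patterns, corrupted, degree=3)
print(torch.equal(recalled, patterns[0].to(torch.float64)))
```

Only the term for neuron $i$ changes between the two candidate states, so the overlaps without neuron $i$ are computed once and reused, rather than evaluating the full energy twice.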

In [None]:
class InteractionFunction:
    """A wrapper for the function F and its derivative."""
    function: Callable[[torch.Tensor], torch.Tensor]
    derivative: Callable[[torch.Tensor], torch.Tensor]

    def __init__(
        self,
        function: Callable[[torch.Tensor], torch.Tensor],
        derivative: Callable[[torch.Tensor], torch.Tensor],
    ) -> None:
        self.function = function
        self.derivative = derivative

    def __call__(self, argument: torch.Tensor) -> torch.Tensor:
        return self.function(argument)


class PolynomialInteraction(InteractionFunction):
    """Interaction function F(x) = x^n."""
    degree: int

    def __init__(self, degree: int) -> None:
        super().__init__(
            function=lambda x: x**degree,
            derivative=lambda x: degree * x ** (degree - 1),
        )
        self.degree = degree


class ExponentialInteraction(InteractionFunction):
    """Interaction function F(x) = exp(x)."""
    def __init__(self) -> None:
        super().__init__(function=lambda x: x.exp(), derivative=lambda x: x.exp())


class ModernHopfieldNetwork:
    size: int
    interaction: InteractionFunction
    training_data: torch.Tensor
    weights: torch.Tensor
    biases: torch.Tensor

    def __init__(self, size: int, neuron_fire_threshold: float, interaction: InteractionFunction) -> None:
        assert 0 < neuron_fire_threshold < 1

        self.size = size
        self.interaction = interaction
        self.training_data = torch.zeros((1, size), dtype=torch.int32)
        self.weights = torch.zeros((size, size), dtype=torch.float32)
        self.biases = torch.full((size,), neuron_fire_threshold, dtype=torch.float32)

    def fit(self, data: torch.Tensor) -> Self:
        """Stores the patterns. For this model, 'fitting' is just memorizing the data."""
        match data.shape:
            case n_records, n_dimensions:
                assert n_records > 0
                assert n_dimensions == self.size

            case other:
                raise ValueError(f"Invalid input shape. Expected 2-dimensional tensor (n_records x n_features), got {other}")

        assert data.dtype == torch.int32
        self.training_data = data.to(torch.float32, copy=True)

        return self
    
    def energy(self, state: torch.Tensor) -> torch.Tensor:
        """Calculates the energy of a given state."""

        # TODO:
        ########## YOUR CODE GOES HERE (EXERCISE 2) ##########
        # Your job is to fill in the modern energy function.
        # 1. Calculate the overlaps (projections) of the state onto each memory pattern.
        # 2. Apply the interaction function F to the vector of overlaps.
        # 3. Sum the results and negate to get the final energy.
        ######################################################
        pass

    def predict_async(self, data: torch.Tensor) -> torch.Tensor:
        """Updates neurons by directly minimizing the energy function."""
        assert data.shape == (self.size,)
        assert data.dtype == torch.int32

        state = data.reshape(-1, 1).to(torch.float32, copy=True)

        # TODO:
        ########## YOUR CODE GOES HERE (EXERCISE 3) ##########
        # Your job is to fill in the modern update rule.
        # For each neuron i in the network:
        #   1. Calculate the network's energy if neuron i were +1.
        #   2. Calculate the network's energy if neuron i were -1.
        #   3. Set the state of neuron i to the value that resulted in lower energy.
        ######################################################

        return state.flatten().to(torch.int32)
    
    def predict_async_stochastic(self, data: torch.Tensor, max_iterations: int) -> torch.Tensor:  # TODO: Remove
        """Updates neurons in random order by directly minimizing energy."""
        assert data.shape == (self.size,)
        assert data.dtype == torch.int32

        state = data.reshape(-1, 1).to(torch.float32, copy=True)

        for _ in range(max_iterations):
            i = randint(0, self.size - 1)

            state[i] = 1
            energy_with_one = self.energy(state)

            state[i] = -1
            energy_with_minus_one = self.energy(state)

            state[i] = 1 if energy_with_one < energy_with_minus_one else -1

        return state.flatten().to(torch.int32)

### Testing the Modern Network on MNIST

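Now, let's repeat the same experiment with our `ModernHopfieldNetwork`. We will use a `PolynomialInteraction` with a degree of 15. This creates a much sharper energy landscape, which should allow the network to distinguish between the stored digits far more effectively. Note that overlaps in a 784-neuron network can reach several hundred, and raising such values to the 15th power exceeds the `float32` range (about $3.4 \times 10^{38}$), so your energy computation may need `float64` or rescaled overlaps.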

In [None]:
mnist_modern_network = ModernHopfieldNetwork(
    size=28 * 28,
    neuron_fire_threshold=0.5,
    interaction=PolynomialInteraction(15),
).fit(mnist_network_input)

In [None]:
digit = image_to_spin_pattern(mnist_examples[4])
digit_obscured = image_to_spin_pattern(obscure(mnist_examples[4]))
digit_recalled = mnist_modern_network.predict_async(digit_obscured.flatten()).reshape(28, 28)
display([digit, digit_obscured, digit_recalled], title="Modern Hopfield Network - MNIST digit recall")