# MLP Memory Experiment

The purpose of this experiment is to determine whether or not a single input token can be memorized by the in-weight memory model. \
If our mechanism of updating weights can't even memorize a single-token sequence, then it is unlikely to be useful for multi-token sequences.

In the experiment we:
1. Construct an MLP with 1 hidden layer.
2. Add episodic memory weights to the hidden layer (only 2 memory neurons).
3. For $n$ iterations:
   <ol type="a" style="margin-top: 0; padding-left: 20px;">
       <li>Reset the episodic memory weights to 0.</li>
       <li>Generate a random input token x={0|1} and a target y=(x + 1) % 2.</li>
       <li>Freeze the normal weights, and train the episodic memory weights to predict x given a 0.</li>
       <li>Freeze the episodic memory weights, and train the normal weights to predict y given a 0.</li>
   </ol>

Note that the input to the network is always 0 in both training phases, so the only way that the network can learn to predict the target is by memorizing the input with the episodic memory weights, and then extracting that information during inference. \
If the memory mechanism is working, then we should expect the loss of the target (y) prediction to drop to 0 with enough iterations.

In [1]:
from typing import List

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

## Define a model structure

The MemoryMLP is structured as follows:

<img src="media/mem_network.png" width="500">

The blue weights are the normal weights, the green weights are the "episodic memory weights" (EM weights), and the red weights are frozen at initialization. \
In addition to the standard input -> output mapping, this architecture incorporates another head I have labeled the "memory output". \
This head is used to reconstruct the input when training the episodic memory weights. \
The weights of that head (the red weights) are fixed so that the weights in the hidden layer are forced to encode information about the input.

In [2]:
class MemoryMLP(nn.Module):
    """MLP whose hidden layers combine standard neurons with "episodic memory" neurons.

    Each hidden layer is the concatenation of a standard linear layer (blue
    weights) and a memory linear layer (green weights). Two bias-free heads
    read the last hidden activation: the standard output head (blue) and the
    memory output head (red), whose weights are frozen at initialization.

    Args:
        vocab_size: Number of distinct input tokens (size of the embedding table).
        output_dim: Size of the last dimension of both heads' outputs.
        hidden_dims: Number of standard neurons in each hidden layer.
        memory_dims: Number of memory neurons in each hidden layer. Must have
            the same length as ``hidden_dims``.
        embed_dim: Dimensionality of the token embedding.

    Raises:
        ValueError: If ``hidden_dims`` and ``memory_dims`` differ in length.
    """

    def __init__(
        self,
        vocab_size: int,
        output_dim: int,
        hidden_dims: List[int],
        memory_dims: List[int],
        embed_dim: int = 32,
    ):
        super().__init__()

        # zip() below would silently drop the extra entries of the longer list,
        # building a smaller network than requested, so reject mismatches early.
        if len(hidden_dims) != len(memory_dims):
            raise ValueError(
                "hidden_dims and memory_dims must have the same length, got "
                f"{len(hidden_dims)} and {len(memory_dims)}"
            )

        # Model input is an integer
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Standard layers will contain the normal neurons of a network
        self.standard_layers = nn.ModuleList()
        # Memory layers contain the memory specific neurons
        self.memory_layers = nn.ModuleList()

        # Every hidden layer outputs its standard and memory neurons side by
        # side, so the next layer's input width is the sum of both.
        combined_dims = [h + m for h, m in zip(hidden_dims, memory_dims)]
        layer_sizes = [embed_dim] + combined_dims
        for in_features, hidden_dim, memory_dim in zip(layer_sizes, hidden_dims, memory_dims):
            self.standard_layers.append(nn.Linear(in_features, hidden_dim))  # Blue weights
            self.memory_layers.append(nn.Linear(in_features, memory_dim))  # Green weights

        self.standard_layers.append(nn.Linear(layer_sizes[-1], output_dim, bias=False))  # Final layer blue weights
        self.memory_layers.append(nn.Linear(layer_sizes[-1], output_dim, bias=False))  # Red weights
        # The memory head is never trained.
        self.memory_layers[-1].requires_grad_(False)

        self.reset_memory()

    def reset_memory(self):
        """Zero the weights and biases of every memory layer except the memory output head."""
        # Modify the parameters in place without recording the operation in autograd.
        with torch.no_grad():
            for layer in self.memory_layers[:-1]:
                layer.weight.zero_()
                if layer.bias is not None:
                    layer.bias.zero_()

    def forward(self, x):
        """Run both heads on a tensor of token indices.

        Args:
            x: Integer tensor of token indices, looked up in the embedding table.

        Returns:
            A tuple ``(output, mem_output)``: the standard head's output and the
            frozen memory head's output, both computed from the same final
            hidden activation.
        """
        z = self.embedding(x)
        for standard_layer, memory_layer in zip(self.standard_layers[:-1], self.memory_layers[:-1]):
            # Concatenate standard and memory neurons into a single hidden activation.
            z = torch.cat([standard_layer(z), memory_layer(z)], dim=-1)
            z = F.gelu(z)
        output = self.standard_layers[-1](z)
        mem_output = self.memory_layers[-1](z)
        return output, mem_output

    def get_normal_params(self):
        """Return the parameters of all standard layers plus the embedding weight."""
        standard_params = [param for layer in self.standard_layers for param in layer.parameters()]
        standard_params.append(self.embedding.weight)
        return standard_params

    def get_memory_params(self):
        """Return the parameters of all memory layers, including the frozen memory head.

        The memory head's weights have ``requires_grad=False``, so they never
        receive a gradient.
        """
        return [param for layer in self.memory_layers for param in layer.parameters()]

## Create the model and optimizers

In [3]:
# Two possible tokens: 0 and 1
vocab_size = 2

# Use the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One hidden layer: 128 standard neurons plus 2 episodic memory neurons
model = MemoryMLP(
    vocab_size=vocab_size,
    output_dim=vocab_size,
    hidden_dims=[128],
    memory_dims=[2],
    embed_dim=64,
).to(device)
# Keep the memory output head (red weights) frozen
model.memory_layers[-1].requires_grad_(False)

# Both heads are trained as classifiers over the vocabulary
criterion = nn.CrossEntropyLoss()
# Separate optimizers so each training phase only updates its own set of weights
std_optimizer = torch.optim.Adam(model.get_normal_params(), lr=1e-4)
mem_optimizer = torch.optim.SGD(model.get_memory_params(), lr=10)

## Train

In [5]:
def _apply_update(optimizer, loss):
    """Clear the optimizer's gradients, backpropagate `loss`, and take one optimizer step."""
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# Per-sample loss and accuracy of the target (y) prediction
loss_hist = []
accuracy_hist = []
steps = 5000

bar = tqdm(range(steps))
for sample_idx in bar:
    # Draw a random token; the target is the next token, wrapping around the vocabulary
    X = torch.randint(0, vocab_size, (1,), dtype=torch.long, device=device)
    y = (X + 1) % vocab_size

    # The network is only ever fed this constant input (0 is chosen arbitrarily)
    zero_input = torch.zeros_like(X)

    # Start each sample with the memory layers reset to 0
    model.reset_memory()

    # Memory phase: train the memory layers to reproduce X from the zero input.
    # More steps here leads to faster convergence
    for _ in range(1):
        out, mem_out = model(zero_input)
        loss = criterion(mem_out, X)
        _apply_update(mem_optimizer, loss)

    # Standard phase: train the standard layers to predict the target (X + 1) from the zero input.
    # Because it is not given X as an input, this should only be possible by using the memory
    out, mem_out = model(zero_input)
    loss = criterion(out, y)

    # Record metrics for this sample before the standard weights are updated
    loss_hist.append(loss.item())
    is_correct = (out.argmax(dim=-1) == y).item()
    accuracy_hist.append(int(is_correct))

    _apply_update(std_optimizer, loss)

    # Refresh the progress bar every 100 samples with means over the last 100 samples
    if sample_idx % 100 == 0:
        recent_loss = np.mean(loss_hist[-100:])
        recent_accuracy = np.mean(accuracy_hist[-100:])
        bar.set_description(
            f"Loss: {recent_loss:.4f} Accuracy: {recent_accuracy:.4f}")

Loss: 0.0176 Accuracy: 1.0000: 100%|██████████| 5000/5000 [00:11<00:00, 423.65it/s]
