# Introduction & Setup

In this notebook, I examine the behavior of a transformer trained to solve the [Majority Element](https://leetcode.com/problems/majority-element/) LeetCode problem: returning the most common integer in an array of N elements, all of which are integers between 0 and a specified limit. (Unlike the LeetCode version, which guarantees an element that appears more than N/2 times, here the majority element is simply the most frequent value.)

Specifically, we will look for attention heads that perform "counting". My best guess is that we will see something like the following: information about how many times a value occurs in the sequence is stored at the position of that value's last occurrence.
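To make this hypothesis concrete, here is a minimal pure-Python sketch of the quantity such a head might store: for each position, how many times the token at that position has occurred so far. At a value's last occurrence, this running count equals the value's total count. `running_counts` and `count_at_last_occurrence` are hypothetical helpers for illustration, not part of the notebook's pipeline.

```python
from typing import Dict, List


def running_counts(seq: List[int]) -> List[int]:
    """For each position i, count how often seq[i] occurs in seq[:i + 1]."""
    seen: Dict[int, int] = {}
    counts = []
    for token in seq:
        seen[token] = seen.get(token, 0) + 1
        counts.append(seen[token])
    return counts


def count_at_last_occurrence(seq: List[int]) -> Dict[int, int]:
    """Running count of each value at its last occurrence (i.e. its total count)."""
    # Later positions overwrite earlier ones, so each value keeps its final count
    return {token: c for token, c in zip(seq, running_counts(seq))}


example_seq = [6, 4, 2, 5, 7, 9, 4, 6, 0]
print(running_counts(example_seq))            # [1, 1, 1, 1, 1, 1, 2, 2, 1]
print(count_at_last_occurrence(example_seq))  # {6: 2, 4: 2, 2: 1, 5: 1, 7: 1, 9: 1, 0: 1}
```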

Throughout the notebook I will use the [TransformerLens](https://github.com/neelnanda-io/TransformerLens) library developed by Neel Nanda, as well as the [CircuitsVis](https://github.com/alan-cooney/CircuitsVis) library maintained by Alan Cooney. TransformerLens implements its HookPoints with PyTorch hooks, which let us read (and modify) the activations of every layer as information flows through the network.

We will use this to examine the transformer heads and MLP layers in detail as they perform inference.
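As a rough illustration of what a HookPoint does, the sketch below imitates the idea in plain PyTorch: an identity module placed on an activation, with forward hooks that either read the activation (caching) or replace it (ablation). `ToyHookPoint` and `ToyMLP` are hypothetical stand-ins, not TransformerLens classes; TransformerLens builds its own hook management (e.g. `run_with_cache`, `run_with_hooks`) on top of this PyTorch mechanism.

```python
import torch
import torch.nn as nn


class ToyHookPoint(nn.Identity):
    """Identity layer whose only job is to give hooks a place to attach."""


class ToyMLP(nn.Module):
    def __init__(self, d_model: int = 4, d_mlp: int = 3):
        super().__init__()
        self.W_in = nn.Linear(d_model, d_mlp)
        self.hook_post = ToyHookPoint()  # exposes the post-activation values
        self.W_out = nn.Linear(d_mlp, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        post = self.hook_post(torch.relu(self.W_in(x)))
        return self.W_out(post)


torch.manual_seed(0)
toy_mlp = ToyMLP()
toy_x = torch.randn(2, 5, 4)  # [batch, pos, d_model]
toy_cache = {}


def cache_hook(module, inputs, output):
    toy_cache["post"] = output.detach()  # read only: returning None keeps the activation


def zero_neuron_0_hook(module, inputs, output):
    output = output.clone()
    output[..., 0] = 0.0  # returning a tensor replaces the activation
    return output


handle = toy_mlp.hook_post.register_forward_hook(cache_hook)
clean_out = toy_mlp(toy_x)
handle.remove()

handle = toy_mlp.hook_post.register_forward_hook(zero_neuron_0_hook)
ablated_out = toy_mlp(toy_x)
handle.remove()

# The ablated run should match W_out applied to the cached activations with neuron 0 zeroed
zeroed = toy_cache["post"].clone()
zeroed[..., 0] = 0.0
print(torch.allclose(ablated_out, toy_mlp.W_out(zeroed)))  # True
```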

In [1]:
import circuitsvis as cv

In [2]:
from IPython import get_ipython
ipython = get_ipython()
# Automatically reload edited modules (e.g. the HookedTransformer code) without restarting the kernel
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")



In [46]:
import plotly.io as pio
pio.renderers.default = "notebook_connected" # or use "browser" if you want plots to open with browser


import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
from dataclasses import dataclass
from einops import rearrange
from tqdm.notebook import tqdm_notebook
import plotly.express as px
from torch.utils.data import DataLoader, Dataset

from torchtyping import TensorType as TT
from typing import List, Tuple

import transformer_lens as tl
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint  # Hooking utilities
from utils import set_seed

# Basic plotting functions

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)
    #return px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)
    #return px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)
    #return px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs)

device = torch.device("cpu")
MODEL_FILENAME = "majority_element_model.pt"
global_seed = 23
set_seed(global_seed)

Random seed set as 23


# Model Training

This section is optional: if you would like to jump directly to the interpretability section, you can skip it, or simply read its outputs without running it. Training variations only requires simple edits to the cells below, but make sure the interpretability section then loads the correct model file.

## Setup

### Model Definition

TransformerLens provides a basic, configurable decoder-only transformer implementation with HookPoints already integrated. We will use this for our training.

In [4]:
config = tl.EasyTransformerConfig(
    d_model=64,
    d_head=32,
    n_heads=2,
    d_mlp=48,
    n_layers=1,
    n_ctx=128, # context window size
    #use_local_attn=True,
    #attn_types=["local","global"],
    #window_size=2,
    act_fn="solu_ln",
    #attn_only=True,
    d_vocab=10,
    normalization_type="LN",
    seed=global_seed,
)

model = tl.EasyTransformer(config).to(device)

Moving model to device:  cpu


In [5]:
from torchinfo import summary

batch_size = 256
summary(model)

Layer (type:depth-idx)                   Param #
HookedTransformer                        --
├─Embed: 1-1                             640
├─HookPoint: 1-2                         --
├─PosEmbed: 1-3                          8,192
├─HookPoint: 1-4                         --
├─ModuleList: 1-5                        --
│    └─TransformerBlock: 2-1             --
│    │    └─LayerNorm: 3-1               128
│    │    └─LayerNorm: 3-2               128
│    │    └─Attention: 3-3               16,640
│    │    └─MLP: 3-4                     6,352
│    │    └─HookPoint: 3-5               --
│    │    └─HookPoint: 3-6               --
│    │    └─HookPoint: 3-7               --
│    │    └─HookPoint: 3-8               --
│    │    └─HookPoint: 3-9               --
├─LayerNorm: 1-6                         128
│    └─HookPoint: 2-2                    --
│    └─HookPoint: 2-3                    --
├─Unembed: 1-7                           650
Total params: 32,858
Trainable params: 32,858
Non-trainable params: 0
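As a sanity check on the summary, the per-module counts can be reproduced by hand from the config. The formulas below are my reading of the parameter shapes for this configuration (biases on Q, K, V, O and on both MLP matrices, plus an extra LayerNorm inside the `solu_ln` activation); treat them as an assumption checked against the table above, not as library documentation.

```python
def param_counts(d_vocab=10, n_ctx=128, d_model=64, d_head=32, n_heads=2, d_mlp=48):
    """Per-module parameter counts for the config above (assumed parameter layout)."""
    def layer_norm(d):
        return 2 * d  # weight + bias

    attention = (
        3 * n_heads * d_model * d_head  # W_Q, W_K, W_V
        + 3 * n_heads * d_head          # b_Q, b_K, b_V
        + n_heads * d_head * d_model    # W_O
        + d_model                       # b_O
    )
    mlp = (
        d_model * d_mlp + d_mlp      # W_in, b_in
        + d_mlp * d_model + d_model  # W_out, b_out
        + layer_norm(d_mlp)          # LayerNorm inside the solu_ln activation
    )
    return {
        "Embed": d_vocab * d_model,
        "PosEmbed": n_ctx * d_model,
        "LayerNorm x2 (block)": 2 * layer_norm(d_model),
        "Attention": attention,
        "MLP": mlp,
        "LayerNorm (final)": layer_norm(d_model),
        "Unembed": d_model * d_vocab + d_vocab,  # W_U, b_U
    }


param_table = param_counts()
print(param_table)
print(sum(param_table.values()))  # 32858
```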

### Create Data

In [6]:
from typing import Tuple

@dataclass
class DataArgs():
    max_seq_len: int
    batch_size: int
    num_workers: int
    dataset_size: int
    max_val: int

In [7]:
class NumSequenceDataset(Dataset):
    def __init__(self, x, y):
        """Initialize the dataset
        Args:
            x (list): list of input sequences
            y (list): list of output sequences
        """
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

In [8]:
def build_data(seq_len: int, dataset_size: int, max_val: int = 9) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Builds a dataset of random integer sequences and, for each sequence,
        the majority element at every position in the sequence

    Args:
        seq_len (int): Length of individual sequences
        dataset_size (int): Number of items in the dataset
        max_val (int, optional): Largest integer that can appear in a sequence.
            Defaults to 9.

    Returns:
        Tuple[List[torch.Tensor], List[torch.Tensor]]: X and Y lists of sequences
    """
    x, y = [], []
    for _ in range(dataset_size):

        seq = torch.randint(0, max_val+1, (seq_len,))
        x.append(seq)

        # Majority element is the most frequent element in the sequence up to and
        # including each position; argmax returns the first maximum, so ties go
        # to the smallest value
        y_list = [torch.bincount(seq[:i]).argmax() for i in range(1, seq_len+1)]
        y.append(torch.tensor(y_list))

    return x, y

In [9]:
data_args = DataArgs(
    max_seq_len = config.n_ctx,
    batch_size = 256,
    num_workers = 4,
    dataset_size = 2048,
    max_val = config.d_vocab-1,
)

x, y = build_data(data_args.max_seq_len, data_args.dataset_size, data_args.max_val)
dataset = NumSequenceDataset(x, y)

train_set, val_set = torch.utils.data.random_split(dataset, [int(len(dataset)*0.8), int(len(dataset) - int(len(dataset)*0.8))])

trainloader = DataLoader(train_set, batch_size=data_args.batch_size, shuffle=True) #, num_workers=4)
valloader = DataLoader(val_set, batch_size=data_args.batch_size, shuffle=False) #, num_workers=4)


In order to provide more signal to the transformer, each item in Y is the majority element of the sequence up to and including that position. For example, if the sequence is [6, 4, 2, 5, 7, 9, 4, 6, 0] then the corresponding Y sequence is [6, 4, 2, 2, 2, 2, 4, 4, 4]. (Note that in our definition of this problem, if two or more elements are tied for majority, the smallest one is chosen.)
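The per-position labels can also be computed incrementally. The sketch below is a pure-Python alternative to recounting every prefix with `bincount`: after one token is added, only that token can overtake the current majority element, so a single comparison per step suffices. `running_majority` is a hypothetical helper; it reproduces the example above, including the smallest-value tie-break.

```python
from typing import Dict, List, Optional


def running_majority(seq: List[int]) -> List[int]:
    """Majority element of seq[:i + 1] for every i; ties go to the smaller value."""
    counts: Dict[int, int] = {}
    best: Optional[int] = None
    labels = []
    for token in seq:
        counts[token] = counts.get(token, 0) + 1
        # Only the token just added can overtake the current majority element
        if best is None or counts[token] > counts[best] or (
            counts[token] == counts[best] and token < best
        ):
            best = token
        labels.append(best)
    return labels


print(running_majority([6, 4, 2, 5, 7, 9, 4, 6, 0]))  # [6, 4, 2, 2, 2, 2, 4, 4, 4]
```

This is linear in the sequence length, whereas recounting every prefix is quadratic; for 128-token sequences either approach is fast enough, so `build_data` works fine as written.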

In [10]:
# Examine data
dataset[5]

(tensor([8, 3, 4, 7, 1, 1, 3, 6, 9, 9, 6, 0, 6, 9, 4, 7, 8, 0, 6, 9, 9, 8, 1, 2,
         4, 4, 3, 0, 6, 1, 9, 1, 0, 4, 0, 6, 6, 2, 2, 8, 8, 5, 0, 9, 3, 5, 5, 1,
         7, 2, 5, 9, 3, 8, 8, 2, 0, 5, 9, 6, 6, 6, 3, 4, 3, 6, 1, 1, 7, 6, 3, 9,
         1, 2, 0, 9, 3, 2, 9, 3, 6, 6, 4, 1, 4, 9, 9, 1, 2, 6, 0, 3, 4, 9, 0, 9,
         1, 9, 5, 7, 7, 2, 0, 2, 0, 7, 0, 5, 2, 9, 6, 9, 0, 9, 3, 8, 6, 1, 2, 6,
         9, 4, 2, 6, 7, 9, 7, 3]),
 tensor([8, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 6, 6, 6, 6, 6, 6, 6, 6, 9, 9, 9, 9,
         9, 9, 9, 9, 6, 6, 9, 9, 9, 9, 9, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
         6, 6, 6, 9, 9, 9, 9, 9, 9, 9, 9, 9, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
         6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 9,
         9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
         9, 9, 9, 9, 9, 9, 9, 9]))

### Optimizer & Training Hyperparameter Setup

In [11]:
@dataclass
class TrainingArgs():
    batch_size: int
    epochs: int
    optimizer: type  # optimizer class, e.g. torch.optim.AdamW
    lr: float
    betas: Tuple[float, float]
    weight_decay: float
    track: bool
    cuda: bool

In [12]:
args = TrainingArgs(
    batch_size = 128,
    epochs = 60,
    optimizer = torch.optim.AdamW,
    lr = 0.001,
    betas = (0.99, 0.999),
    weight_decay = 0.01,
    track = False,
    cuda = False
)

# Note: args.betas is not passed to the optimizer
optimizer = args.optimizer(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

### Define Training Loop

In [13]:
loss_fn = nn.CrossEntropyLoss()

MODEL_FILENAME = "./majority_element_model.pt"


def train_transformer(model, optimizer, trainloader, args, valloader=valloader) -> Tuple[nn.Module, List[float], List[float]]:
    """Trains a transformer model
    Args:
        model (torch.nn.Module): Transformer model
        optimizer (torch.optim.Optimizer): Optimizer
        trainloader (torch.utils.data.DataLoader): Training data loader
        args (TrainingArgs): Training arguments
        valloader (torch.utils.data.DataLoader): Validation data loader

    Returns:
        model: Trained model
        list: Training loss for each batch
        list: Validation loss for each batch
    """
    epochs = args.epochs

    loss_list = []
    val_loss_list = []
    # OneCycleLR sets the learning rate itself (peaking at max_lr=0.01),
    # overriding the lr the optimizer was created with
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=0.01, steps_per_epoch=len(trainloader), epochs=args.epochs)

    progress_bar = tqdm_notebook(range(epochs))
    for epoch in progress_bar:
        model.train()
        for (x, y) in trainloader:

            logits = model(x.long(), return_type="logits")
            # Flatten batch and sequence dims: every position is a prediction target
            logits = rearrange(logits, 'B S V -> (B S) V')
            y = rearrange(y, 'B S -> (B S)')

            loss = F.cross_entropy(logits, y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            loss_list.append(loss.item())

            progress_bar.set_description(f"Epoch = {epoch}, Loss = {loss.item():.4f}")

        model.eval()
        with torch.inference_mode():  # no gradients needed for validation
            for (x, y) in valloader:

                logits = model(x.long(), return_type="logits")
                logits = rearrange(logits, 'B S V -> (B S) V')
                y = rearrange(y, 'B S -> (B S)')

                loss = F.cross_entropy(logits, y)
                val_loss_list.append(loss.item())

    # Saved under a separate name so the pre-trained model file loaded in the
    # interpretability section is not overwritten
    save_path = MODEL_FILENAME + "_retrained"
    print(f"Saving model to: {save_path}")
    torch.save(model, save_path)
    return model, loss_list, val_loss_list
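Because `OneCycleLR` sets the learning rate itself, the `lr=0.001` in `TrainingArgs` is overridden, and with its default `cycle_momentum=True` the scheduler also cycles AdamW's first beta. The sketch below replays the schedule on a throwaway `nn.Linear` (a stand-in for the transformer) with PyTorch's default `div_factor=25`, `final_div_factor=1e4` and `pct_start=0.3`: the learning rate should start at 0.01 / 25 = 4e-4, peak at 0.01 about 30% of the way through, and end near 4e-8, while beta1 moves between 0.95 and 0.85. `steps_per_epoch=7` is an assumption based on the 1,638-sample training split and a batch size of 256.

```python
import torch
import torch.nn as nn

demo_steps_per_epoch, demo_epochs = 7, 60  # assumed: ceil(1638 / 256) = 7 batches per epoch
demo_net = nn.Linear(2, 2)  # stand-in for the transformer; only the optimizer matters here
demo_optimizer = torch.optim.AdamW(demo_net.parameters(), lr=0.001, weight_decay=0.01)
demo_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    demo_optimizer, max_lr=0.01, steps_per_epoch=demo_steps_per_epoch, epochs=demo_epochs)

lrs = [demo_optimizer.param_groups[0]["lr"]]
beta1s = [demo_optimizer.param_groups[0]["betas"][0]]
for _ in range(demo_steps_per_epoch * demo_epochs - 1):
    demo_optimizer.step()  # no gradients exist, so the parameters are left unchanged
    demo_scheduler.step()
    lrs.append(demo_optimizer.param_groups[0]["lr"])
    beta1s.append(demo_optimizer.param_groups[0]["betas"][0])

peak_step = lrs.index(max(lrs))
print(f"lr: start {lrs[0]:.1e}, peak {max(lrs):.1e} at step {peak_step}, end {lrs[-1]:.1e}")
print(f"beta1: min {min(beta1s):.2f}, max {max(beta1s):.2f}")
```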

## Train Model

In [14]:
model, loss_list, val_loss = train_transformer(model, optimizer, trainloader, args=args)


Saving model to: ./majority_element_model.pt


In [15]:
fig = px.line(y=loss_list, template="simple_white")
fig.update_layout(title="Cross entropy loss on Majority Element", yaxis_range=[0, max(loss_list)], hovermode="x unified")
fig.show()

In [16]:
fig = px.line(y=val_loss, template="simple_white")
fig.update_layout(title="Validation loss on Majority Element", yaxis_range=[0, max(val_loss)], hovermode="x unified")
fig.show()

## Test Inference

The cell below tests inference on a fresh random sequence. The model is not perfectly accurate, but it predicts correctly most of the time.

In [17]:
example_data = torch.randint(0, data_args.max_val+1, (data_args.max_seq_len,))
with torch.inference_mode():
    all_logits = model(example_data.long())
    logits = all_logits[0, -1]
    print(example_data)
    print(f"True majority element: {torch.bincount(example_data).argmax().item()}")
    print(f"Model prediction for majority element: {int(logits.argmax(dim=-1).squeeze())}")

tensor([6, 2, 5, 9, 4, 2, 4, 9, 9, 4, 1, 7, 2, 0, 9, 0, 1, 4, 0, 7, 3, 9, 1, 5,
        9, 9, 1, 9, 4, 4, 6, 8, 9, 5, 8, 5, 7, 9, 1, 0, 0, 0, 9, 4, 9, 8, 7, 4,
        7, 7, 2, 8, 2, 4, 4, 8, 5, 6, 2, 5, 6, 0, 1, 6, 2, 4, 1, 4, 2, 6, 6, 7,
        3, 2, 8, 2, 4, 9, 0, 6, 3, 8, 4, 2, 4, 8, 6, 1, 0, 5, 0, 4, 6, 3, 7, 6,
        1, 2, 0, 3, 7, 2, 0, 3, 9, 6, 0, 2, 8, 3, 7, 0, 7, 0, 3, 6, 7, 1, 6, 1,
        3, 1, 3, 1, 9, 8, 1, 5])
True majority element: 4
Model prediction for majority element: 4


# Model Interpretability
In order to interpret our model, we'll examine the two main parts of the transformer decoder block: 1. The attention heads, and 2. The MLP layers. This is an exploratory analysis, and as such not all layers we examine here have interesting or easily-interpretable information. Nevertheless, it can be beneficial to visualize them and observe their behavior for future analyses.

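Below, we'll load the saved model, create an analysis version of it, and load the saved state dictionary into the analysis model. Calling `load_and_process_state_dict` with `fold_ln=True` folds each LayerNorm's learned weight and bias into the layer that follows it, which is why the analysis config uses `normalization_type="LNPre"` (normalization without learned parameters); `center_writing_weights` and `center_unembed` additionally center the weights that write to the residual stream and the unembedding.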

In [19]:
trained_model = torch.load(MODEL_FILENAME, map_location=torch.device('cpu'))

analysis_cfg = tl.EasyTransformerConfig(
    d_model=64,
    d_head=32,
    n_heads=2,
    d_mlp=48,
    n_layers=1,
    n_ctx=128,
    act_fn="solu_ln",
    d_vocab=10,
    normalization_type="LNPre",
    seed=global_seed,
)
analysis_model = tl.EasyTransformer(analysis_cfg)
state_dict = trained_model.state_dict()
analysis_model.load_and_process_state_dict(state_dict, fold_ln=True, center_writing_weights=True, center_unembed=True)
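Folding works because a LayerNorm's learned scale and bias are followed by a linear layer, so they can be absorbed into that layer's weights and bias, leaving only the parameter-free normalization. The toy check below illustrates the algebra on random tensors; it is a sketch of the idea, not TransformerLens's `load_and_process_state_dict` implementation (which also applies the centering options). All names are toy stand-ins.

```python
import torch

torch.manual_seed(0)
d_model, d_out, eps = 8, 5, 1e-5
toy_x = torch.randn(4, d_model)
ln_w, ln_b = torch.randn(d_model), torch.randn(d_model)          # learned LayerNorm weight and bias
toy_W, toy_b = torch.randn(d_model, d_out), torch.randn(d_out)   # layer reading the LN output


def normalize(t: torch.Tensor) -> torch.Tensor:
    """LayerNorm without learned parameters: center, then divide by the RMS."""
    t = t - t.mean(-1, keepdim=True)
    return t / (t.pow(2).mean(-1, keepdim=True) + eps).sqrt()


toy_original = (normalize(toy_x) * ln_w + ln_b) @ toy_W + toy_b
W_folded = ln_w[:, None] * toy_W  # scale each input row of W by the LN weight
b_folded = ln_b @ toy_W + toy_b   # push the LN bias through W into the layer's bias
toy_folded = normalize(toy_x) @ W_folded + b_folded
print(torch.allclose(toy_original, toy_folded, atol=1e-4))
```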

Next, we'll set up a test sequence (and its per-position majority labels) for the interpretability section below, and test inference on it:

In [36]:
test_x = torch.tensor(
    [1, 5, 3, 6, 0, 9, 6, 4, 2, 3, 7, 5, 2, 0, 6, 8, 8, 1, 1, 4, 0, 6, 1, 3,
        9, 3, 2, 6, 2, 2, 5, 3, 7, 3, 2, 5, 0, 8, 2, 6, 8, 1, 0, 6, 3, 7, 4, 6,
        2, 2, 0, 5, 7, 6, 8, 2, 2, 3, 8, 9, 4, 0, 4, 6, 7, 5, 5, 5, 2, 3, 9, 2,
        4, 9, 4, 6, 5, 8, 9, 1, 0, 7, 6, 4, 0, 1, 0, 9, 3, 6, 2, 0, 6, 1, 8, 2,
        9, 2, 7, 1, 9, 0, 0, 3, 7, 3, 3, 3, 9, 7, 1, 2, 4, 7, 1, 9, 8, 2, 6, 7,
        8, 7, 1, 8, 8, 4, 0, 4]
)

test_y = torch.tensor(
    [1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3, 3, 2, 0, 6, 6, 6, 6, 1, 1, 0, 6, 1, 1,
        1, 1, 1, 6, 6, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 6,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2]
)

with torch.inference_mode():
    all_logits = analysis_model(test_x.long())
    logits = all_logits[0, -1]
    print(test_x)
    print(test_y)
    print(f"True majority element: {torch.bincount(test_x).argmax().item()}")
    print(f"Model prediction for majority element: {int(logits.argmax(dim=-1).squeeze())}")

tensor([1, 5, 3, 6, 0, 9, 6, 4, 2, 3, 7, 5, 2, 0, 6, 8, 8, 1, 1, 4, 0, 6, 1, 3,
        9, 3, 2, 6, 2, 2, 5, 3, 7, 3, 2, 5, 0, 8, 2, 6, 8, 1, 0, 6, 3, 7, 4, 6,
        2, 2, 0, 5, 7, 6, 8, 2, 2, 3, 8, 9, 4, 0, 4, 6, 7, 5, 5, 5, 2, 3, 9, 2,
        4, 9, 4, 6, 5, 8, 9, 1, 0, 7, 6, 4, 0, 1, 0, 9, 3, 6, 2, 0, 6, 1, 8, 2,
        9, 2, 7, 1, 9, 0, 0, 3, 7, 3, 3, 3, 9, 7, 1, 2, 4, 7, 1, 9, 8, 2, 6, 7,
        8, 7, 1, 8, 8, 4, 0, 4])
tensor([1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3, 3, 2, 0, 6, 6, 6, 6, 1, 1, 0, 6, 1, 1,
        1, 1, 1, 6, 6, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 6,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
True majority element: 2
Model prediction for majority element: 2


Next, we'll use the TransformerLens `run_with_cache` method to collect the activations as we pass the input through the model. The cache is indexed by activation name and layer number, for example `cache["post", 0]` for the MLP neuron activations in layer 0.

In [37]:
logits, cache = analysis_model.run_with_cache(test_x.long(), return_type="logits")

## Logit Contribution for Assorted Layers

Our next step is to examine the contribution of the various layers to the residual stream and the subsequent logits.

The logits are produced by unembedding the final residual stream. The residual stream is a running sum: it starts as the sum of the token and positional embeddings, and each layer adds its own output to it. A final LayerNorm is applied before unembedding. Because everything except that LayerNorm is additive, we can split the logits into the contribution of each layer, dividing each component by the final LayerNorm's scale (which is computed once for the whole stream). We do this below for the token embedding, the attention layer and the MLP.
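The decomposition is exact once the shared LayerNorm scale is fixed. The sketch below checks this on random toy tensors, assuming an `LNPre`-style final normalization (centering, then dividing by the root-mean-square, with no learned weights, as after `fold_ln`): centering each component separately, dividing by the scale of the *full* stream and unembedding gives per-component logits that sum, together with the unembedding bias, to the full logits. All names are toy stand-ins.

```python
import torch

torch.manual_seed(0)
n_components, n_pos, d_model, d_vocab, eps = 3, 5, 8, 10, 1e-5
components = torch.randn(n_components, n_pos, d_model)  # e.g. embed, attn_out, mlp_out
toy_W_U, toy_b_U = torch.randn(d_model, d_vocab), torch.randn(d_vocab)

# Full forward pass: sum the components, apply a parameter-free LayerNorm, unembed
resid = components.sum(0)
centered = resid - resid.mean(-1, keepdim=True)
scale = (centered.pow(2).mean(-1, keepdim=True) + eps).sqrt()  # one scale per position
full_logits = centered / scale @ toy_W_U + toy_b_U

# Per-component logits: center each component, divide by the shared scale, unembed
comp_centered = components - components.mean(-1, keepdim=True)
comp_logits = comp_centered / scale @ toy_W_U  # [n_components, n_pos, d_vocab]

print(torch.allclose(comp_logits.sum(0) + toy_b_U, full_logits, atol=1e-4))
```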

In [38]:
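# Each layer's output written to the residual stream
# (the positional embedding, cache["pos_embed"], is not included in this breakdown)
resid_components = [cache["embed"], cache["attn_out", 0], cache["mlp_out", 0]]
labels = ["embed", "A0", "M0"]
resid_stack = torch.stack(resid_components, 0)
# Center over d_model, mirroring the centering step of the final LayerNorm
resid_stack = resid_stack - resid_stack.mean(-1, keepdim=True)
print(resid_stack.shape)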

torch.Size([3, 1, 128, 64])


In [39]:
# Unembed each component, dividing by the final LayerNorm scale of the full residual stream
logit_components = resid_stack[:, 0] @ analysis_model.unembed.W_U / cache["scale", None, "ln_final"][0]
print(logit_components.shape)

torch.Size([3, 128, 10])


In [40]:
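example_input = test_x.long()
# Center over the vocabulary (softmax is unchanged by a constant shift)
logit_components = logit_components - logit_components.mean(-1, keepdim=True)
# For each position p >= 1, plot the logit each component assigns to the token at position p - 1
line(logit_components[:, torch.arange(1, analysis_model.cfg.n_ctx), example_input[:-1]].T)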

Above, we can see the embedding contributions to the residual stream in blue, and the attention block's contributions in red. The MLP's contributions are in green, and are responsible for the greatest-magnitude contribution to the logits. Most of these seem to be negative, but there are several spikes at various sequence positions.

Looking at the sequence at these specific positions doesn't reveal much right away. We'll look at the attention layer first, and then examine the MLP in more detail to see what we can learn.

## Attention Analysis

Below, I investigate the attention layer in order to see if there are any clearly interpretable patterns. 

### Attention Patterns by Position

Here we can see the attention pattern of each of the two heads in our attention layer (only the first 20 positions are shown). The patterns look extremely uniform: at each position, the heads pay roughly equal attention to all previous positions. If attention were exactly uniform, the output at position p would be the average of the value vectors at positions 0 through p, which amounts to weighting each token's value vector by that token's frequency in the prefix. This hints that the heads pass on relative frequencies rather than raw counts. We'll examine this more closely later to see what the heads might actually be doing.
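The identity behind that reasoning can be checked directly. The sketch below uses random toy tensors and the simplifying assumption that value vectors depend only on the token (in the real model, the positional embedding also feeds into the values): perfectly uniform causal attention gives the same output as multiplying each prefix's token frequencies by the per-token value vectors.

```python
import torch

torch.manual_seed(0)
d_vocab, d_head, n_pos = 10, 4, 20
toy_tokens = torch.randint(0, d_vocab, (n_pos,))
toy_V = torch.randn(d_vocab, d_head)  # one value vector per token (positional part ignored)

# Perfectly uniform causal attention: position p attends equally to positions 0..p
uniform = torch.tril(torch.ones(n_pos, n_pos))
uniform = uniform / uniform.sum(-1, keepdim=True)
head_out = uniform @ toy_V[toy_tokens]  # [n_pos, d_head]

# The same output written in terms of each token's frequency in the prefix
freqs = torch.stack([
    torch.bincount(toy_tokens[: p + 1], minlength=d_vocab).float() / (p + 1)
    for p in range(n_pos)
])  # [n_pos, d_vocab]
print(torch.allclose(head_out, freqs @ toy_V, atol=1e-5))
```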

In [25]:
print(type(cache))
attention_pattern = cache["pattern", 0, "attn"]
attention_pattern = attention_pattern.squeeze(0)
print(attention_pattern.shape)

token_labels = [str(int(t)) for t in test_x]

print("Layer 0 Head Attention Patterns:")
cv.attention.attention_heads(tokens=token_labels[:20], attention=attention_pattern[:,:20, :20])

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([2, 128, 128])
Layer 0 Head Attention Patterns:


### Attention Circuit Examination
Following the [Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) paper, let's break down the attention head into QK (attention pattern) and OV (output value) circuits. The QK circuit consists of the embedding, query, key, and embedding-transpose matrices, and the OV circuit consists of the embedding, value, output, and unembedding matrices. However, we'll modify the latter to include the MLP input matrix instead of the unembedding, so that we can see the impact of the attention head on the MLP.
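Both circuits can be computed for every head at once with `torch.einsum`. The sketch below uses random toy weights shaped like one layer of this model, assuming the per-layer layout the cells below index into: `W_Q`, `W_K`, `W_V` as `[n_heads, d_model, d_head]`, `W_O` as `[n_heads, d_head, d_model]`, and `W_in` as `[d_model, d_mlp]`. Note that these circuits ignore the positional embedding and the LayerNorm in front of the MLP.

```python
import torch

torch.manual_seed(0)
d_vocab, d_model, d_head, n_heads, d_mlp = 10, 64, 32, 2, 48
f64 = dict(dtype=torch.float64)
# Random toy weights with the per-layer shapes assumed for this model
W_E = torch.randn(d_vocab, d_model, **f64)
W_Q = torch.randn(n_heads, d_model, d_head, **f64)
W_K = torch.randn(n_heads, d_model, d_head, **f64)
W_V = torch.randn(n_heads, d_model, d_head, **f64)
W_O = torch.randn(n_heads, d_head, d_model, **f64)
W_in = torch.randn(d_model, d_mlp, **f64)

# QK circuit for every head: [head, query token, key token]
QK_all = torch.einsum("qm,hmd,hnd,kn->hqk", W_E, W_Q, W_K, W_E)
# OV circuit into the MLP for every head: [head, source token, neuron]
OV_all = torch.einsum("sm,hmd,hdn,nk->hsk", W_E, W_V, W_O, W_in)
print(QK_all.shape, OV_all.shape)  # torch.Size([2, 10, 10]) torch.Size([2, 10, 48])
```

Slicing `QK_all[0]` or `OV_all[0]` gives the head-0 matrices that the cells below compute with explicit matrix products.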

#### QK Circuit

The full QK circuit for head 0, shown below, is a bit difficult to interpret. For now, we'll move on and see if other parts of the network offer more insight.

In [55]:
QK = analysis_model.W_E @ analysis_model.W_Q[0, 0] @ analysis_model.W_K[0, 0].T @ analysis_model.W_E.T
imshow(QK, yaxis="Query", xaxis="Key", title="Full QK Circuit for Layer 0, Head 0")

#### OV Circuit

Below, we'll visualize the impact of attention head 0 on the MLP layer. Interestingly, there seems to be a distinct pattern: each neuron is strongly driven by only one of the vocabulary items, and each vocabulary item strongly drives only one or two neurons.

Since the MLP makes the largest contribution to the logits (mostly negative, with a few strong positive spikes), it is probably amplifying particular values passed to it by the attention layer.

In [78]:
OV = analysis_model.W_E @ analysis_model.W_V[0, 0] @ analysis_model.W_O[0, 0] @ analysis_model.W_in[0]
imshow(OV, yaxis="Input Vocab", xaxis="Neuron", title="OV Circuit Impact on MLP",aspect="auto")

The visualization above is somewhat difficult to interpret, but most neurons receive only weak input from most tokens. The strong entries are scattered across the vocabulary, suggesting that each vocabulary item mostly drives a few specific neurons.

In [34]:
# Let's print this for reference when we examine the MLP neurons
test_x, test_y

(tensor([1, 5, 3, 6, 0, 9, 6, 4, 2, 3, 7, 5, 2, 0, 6, 8, 8, 1, 1, 4, 0, 6, 1, 3,
         9, 3, 2, 6, 2, 2, 5, 3, 7, 3, 2, 5, 0, 8, 2, 6, 8, 1, 0, 6, 3, 7, 4, 6,
         2, 2, 0, 5, 7, 6, 8, 2, 2, 3, 8, 9, 4, 0, 4, 6, 7, 5, 5, 5, 2, 3, 9, 2,
         4, 9, 4, 6, 5, 8, 9, 1, 0, 7, 6, 4, 0, 1, 0, 9, 3, 6, 2, 0, 6, 1, 8, 2,
         9, 2, 7, 1, 9, 0, 0, 3, 7, 3, 3, 3, 9, 7, 1, 2, 4, 7, 1, 9, 8, 2, 6, 7,
         8, 7, 1, 8, 8, 4, 0, 4]),
 tensor([1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3, 3, 2, 0, 6, 6, 6, 6, 1, 1, 0, 6, 1, 1,
         1, 1, 1, 6, 6, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 6,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2]))

### MLP Analysis

Let's visualize the MLP activations for `test_x` at each sequence position.

In [79]:
imshow(cache["post", 0][0], yaxis="Pos", xaxis="Neuron", title="Neuron activations for single inputs", aspect="auto")

Now we seem to be getting somewhere. Some observations:
- There's a clear change at position 29: most activations drop off and mainly neuron 44 remains. This is exactly the point at which, according to the sequence above, the network should register the number `2` as the majority element (a state that continues for almost the entire remaining sequence).
- From then on, fewer neurons are active, and neuron 44 stays almost constantly active while `2` remains the majority element.
- This suggests that neuron 44 in MLP layer 0 represents the number 2!
- The OV circuit above supports this even more strongly: neuron 44 is the only neuron that the attention layer affects positively when the value `2` is present.

Let's graph this neuron specifically to see how it evolves with position.

In [80]:
line(cache["post", 0][0, :, 44])

The activation rises fairly high and stays there for the rest of the sequence, apart from a few dips. To explain the dips, let's look at the sequence at those points.

In [89]:
print(f"Sequence at positions 30-34: {test_x[30:35]} True majority: {test_y[30:35]}")
print(f"Sequence at positions 45-48: {test_x[45:49]} True majority: {test_y[45:49]}")

Sequence at positions 30-34: tensor([5, 3, 7, 3, 2]) True majority: tensor([2, 2, 2, 3, 2])
Sequence at positions 45-48: tensor([7, 4, 6, 2]) True majority: tensor([2, 2, 6, 2])


The activation drops when other values become more common, and returns to high activation once the predominance of `2` in the sequence is restored.

Roughly, the output of neuron 44 seems to drop somewhat whenever the number of `2`s stops increasing (perhaps reflecting decreased confidence that 2 is the majority element). However, the activations don't go up monotonically with the number of `2`s, indicating that perhaps the network is simply keeping track of the relative predominance of 2 vs. other elements rather than strictly counting.
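One way to separate these two hypotheses is to compute both candidate quantities at every position and compare each against the neuron 44 activations (for example with a correlation). `count_and_margin` is a hypothetical helper; with the notebook's data you would call it on `test_x.tolist()` with `target=2`.

```python
from typing import Dict, List, Tuple


def count_and_margin(seq: List[int], target: int) -> List[Tuple[int, int]]:
    """For each prefix seq[:i + 1], return (count of target, count of target minus
    the count of the most frequent other value)."""
    counts: Dict[int, int] = {}
    out = []
    for token in seq:
        counts[token] = counts.get(token, 0) + 1
        rival = max((c for v, c in counts.items() if v != target), default=0)
        target_count = counts.get(target, 0)
        out.append((target_count, target_count - rival))
    return out


toy_seq = [2, 3, 2, 3, 3, 1, 2]
print(count_and_margin(toy_seq, target=2))
# [(1, 1), (1, 0), (2, 1), (2, 0), (2, -1), (2, -1), (3, 0)]
```

A raw count can only go up, while the margin falls whenever a rival value gains ground, so dips like the ones above would fit the margin better.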

To confirm that this neuron is indeed performing the function of tracking the predominance of `2`s, let's perform an ablation experiment.

### Ablation Experiment

In [90]:
with torch.inference_mode():
    all_logits = analysis_model(test_x.long())
    logits = all_logits[0, -1]
    print(test_x)
    print(test_y)
    print(f"True majority element: {torch.bincount(test_x).argmax().item()}")
    print(f"Model prediction for majority element: {int(logits.argmax(dim=-1).squeeze())}")

tensor([1, 5, 3, 6, 0, 9, 6, 4, 2, 3, 7, 5, 2, 0, 6, 8, 8, 1, 1, 4, 0, 6, 1, 3,
        9, 3, 2, 6, 2, 2, 5, 3, 7, 3, 2, 5, 0, 8, 2, 6, 8, 1, 0, 6, 3, 7, 4, 6,
        2, 2, 0, 5, 7, 6, 8, 2, 2, 3, 8, 9, 4, 0, 4, 6, 7, 5, 5, 5, 2, 3, 9, 2,
        4, 9, 4, 6, 5, 8, 9, 1, 0, 7, 6, 4, 0, 1, 0, 9, 3, 6, 2, 0, 6, 1, 8, 2,
        9, 2, 7, 1, 9, 0, 0, 3, 7, 3, 3, 3, 9, 7, 1, 2, 4, 7, 1, 9, 8, 2, 6, 7,
        8, 7, 1, 8, 8, 4, 0, 4])
tensor([1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3, 3, 2, 0, 6, 6, 6, 6, 1, 1, 0, 6, 1, 1,
        1, 1, 1, 6, 6, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 6,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
True majority element: 2
Model prediction for majority element: 2


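We'll define two helper functions: a hook that sets the post-activation values of chosen neurons to 0 (zero-ablation), and a wrapper that runs the model with and without that hook. We'll then compare the cross-entropy loss against `test_y` with and without the ablation.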

In [97]:
layer_to_ablate = 0
neurons_to_ablate = [44]
analysis_model.reset_hooks()

def multi_neuron_ablation_hook(
    value: TT["batch", "pos", "neuron_index"], 
    hook: HookPoint,
    neurons_to_ablate: list[int] = [],
    complement: bool = False
) -> TT["batch", "pos", "neuron_index"]:
    #print(f"Shape of the value tensor: {value.shape}")
    if complement:
        neurons_to_ablate = [n for n in range(value.shape[-1]) if n not in neurons_to_ablate]
    for n in neurons_to_ablate:
        value[:, :, n] = 0.
    return value

def ablate_neurons_in_layer(layer_to_ablate: int, neurons_to_ablate: list[int], complement: bool = False):

    analysis_model.reset_hooks()
    original_logits = analysis_model(test_x, return_type="logits")
    ablated_logits = analysis_model.run_with_hooks(
        test_x, 
        return_type="logits", 
        fwd_hooks=[(
            utils.get_act_name("post", layer_to_ablate), 
            functools.partial(multi_neuron_ablation_hook, neurons_to_ablate=neurons_to_ablate, complement=complement)
        )]
    )
    return original_logits, ablated_logits

original_logits, ablated_logits = ablate_neurons_in_layer(layer_to_ablate, neurons_to_ablate, complement=False)
print(f"Original Loss: {F.cross_entropy(original_logits[0], test_y.long()).item():.3f}")
print(f"Ablated Loss: {F.cross_entropy(ablated_logits[0], test_y.long()).item():.3f}")

Original Loss: 0.004
Ablated Loss: 1.768
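For reference, the core of `multi_neuron_ablation_hook` can also be written without a Python loop or in-place mutation, by multiplying the activations with a boolean mask over the neuron dimension. `zero_ablate` is a hypothetical helper shown here on a toy tensor; since a TransformerLens hook may return a replacement tensor, something like `lambda value, hook: zero_ablate(value, [44])` should work as a hook function, though that wiring is an assumption rather than something run in this notebook.

```python
import torch
from typing import List


def zero_ablate(value: torch.Tensor, neurons: List[int], complement: bool = False) -> torch.Tensor:
    """Zero the given neurons (last dimension) of a [batch, pos, neuron] tensor,
    or every other neuron if complement=True. Returns a new tensor."""
    keep = torch.ones(value.shape[-1], dtype=torch.bool, device=value.device)
    keep[neurons] = False  # True marks neurons left untouched
    if complement:
        keep = ~keep
    return value * keep.to(value.dtype)


toy_acts = torch.arange(1.0, 7.0).reshape(1, 2, 3)  # 1 sequence, 2 positions, 3 neurons
print(zero_ablate(toy_acts, [1]))                   # neuron 1 zeroed
print(zero_ablate(toy_acts, [1], complement=True))  # only neuron 1 kept
```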


The result of our ablation seems fairly clear: Ablating Neuron 44 above resulted in a dramatically higher loss. What if we ablate the complement (all neurons aside from 44)?

In [100]:
org_logits, abl_logits = ablate_neurons_in_layer(layer_to_ablate, neurons_to_ablate, complement=True)
print(f"Original Loss: {F.cross_entropy(org_logits[0], test_y.long()).item():.3f}")
print(f"Ablated Loss: {F.cross_entropy(abl_logits[0], test_y.long()).item():.3f}")

print(f"Input:\n {test_x}")
print(f"True majority element at each point:\n {test_y}")
print(f"Original model prediction for majority element:\n {org_logits.argmax(dim=-1).squeeze()}")
print(f"Ablated model prediction for majority element:\n {abl_logits.argmax(dim=-1).squeeze()}")

Original Loss: 0.004
Ablated Loss: 0.460
Input:
 tensor([1, 5, 3, 6, 0, 9, 6, 4, 2, 3, 7, 5, 2, 0, 6, 8, 8, 1, 1, 4, 0, 6, 1, 3,
        9, 3, 2, 6, 2, 2, 5, 3, 7, 3, 2, 5, 0, 8, 2, 6, 8, 1, 0, 6, 3, 7, 4, 6,
        2, 2, 0, 5, 7, 6, 8, 2, 2, 3, 8, 9, 4, 0, 4, 6, 7, 5, 5, 5, 2, 3, 9, 2,
        4, 9, 4, 6, 5, 8, 9, 1, 0, 7, 6, 4, 0, 1, 0, 9, 3, 6, 2, 0, 6, 1, 8, 2,
        9, 2, 7, 1, 9, 0, 0, 3, 7, 3, 3, 3, 9, 7, 1, 2, 4, 7, 1, 9, 8, 2, 6, 7,
        8, 7, 1, 8, 8, 4, 0, 4])
True majority element at each point:
 tensor([1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3, 3, 2, 0, 6, 6, 6, 6, 1, 1, 0, 6, 1, 1,
        1, 1, 1, 6, 6, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 6,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
Original model prediction for majority element:

Loss did go up (from 0.004 to 0.460), but much less than when neuron 44 alone was ablated (1.768), and the final prediction was still correct. However, the ablated model seems to have become quite bad at predicting any value other than `2`, even where those values are the correct answer.

### Checking Other Values

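Let's build a sequence that makes each value the majority element in turn, to see which neurons respond to each value. It starts with two `0`s, followed by a run of 9 + i copies of each value i from 1 to 9, for 2 + 81 + 45 = 128 tokens in total (exactly the context length):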

In [115]:
comprehensive_sequence = [0 for _ in range(2)]
for i in range(1, data_args.max_val+1):
    for _ in range(9+i):
        comprehensive_sequence.append(i)
print(comprehensive_sequence)
comprehensive_sequence = torch.tensor(comprehensive_sequence)

[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]
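Because ties go to the smaller value, each value i ≥ 2 only becomes the majority element at the very last position of its run: the previous value has 8 + i occurrences, and i needs 9 + i to overtake it. The sketch below redefines a hypothetical incremental `running_majority` helper (so the block stands alone) and lists the positions where the label changes. Most of each run is therefore still labelled with the previous value, which is worth keeping in mind when reading the activation map below.

```python
from typing import Dict, List, Optional, Tuple


def running_majority(seq: List[int]) -> List[int]:
    """Majority element of seq[:i + 1] for every i; ties go to the smaller value."""
    counts: Dict[int, int] = {}
    best: Optional[int] = None
    labels = []
    for token in seq:
        counts[token] = counts.get(token, 0) + 1
        if best is None or counts[token] > counts[best] or (
            counts[token] == counts[best] and token < best
        ):
            best = token
        labels.append(best)
    return labels


def switch_points(seq: List[int]) -> List[Tuple[int, int]]:
    """(position, new majority element) wherever the running majority changes."""
    labels = running_majority(seq)
    return [(i, lab) for i, lab in enumerate(labels) if i == 0 or lab != labels[i - 1]]


# Same construction as comprehensive_sequence above (max_val = 9)
comp_seq = [0, 0] + [v for v in range(1, 10) for _ in range(9 + v)]
print(len(comp_seq))  # 128
print(switch_points(comp_seq))
# [(0, 0), (4, 1), (22, 2), (34, 3), (47, 4), (61, 5), (76, 6), (92, 7), (109, 8), (127, 9)]
```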


In [116]:
logits, cache = analysis_model.run_with_cache(comprehensive_sequence.long(), return_type="logits")
imshow(cache["post", 0][0], yaxis="Pos", xaxis="Neuron", title="Neuron activations for single inputs", aspect="auto")

This appears to be a map of which neurons correspond to which values! A predominance of `0` elements activates neuron 34, `1` activates neuron 8, `2` activates neuron 44, and so on. Activations slowly drop once another element becomes more clearly predominant. Sometimes multiple neurons are active, but there's usually an obvious primary one. Let's check our OV circuit again to see if this matches.

In [117]:
OV = analysis_model.W_E @ analysis_model.W_V[0, 0] @ analysis_model.W_O[0, 0] @ analysis_model.W_in[0]
imshow(OV, yaxis="Input Vocab", xaxis="Neuron", title="OV Circuit Impact on MLP",aspect="auto")

It does indeed! We can now map each value to a specific neuron (or pair of neurons). As an aside, we can also see that our MLP is larger than this task needs: most neurons are barely used.
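The value-to-neuron map can also be read off programmatically by taking each row's top-k entries of the OV matrix. The sketch below uses a toy matrix with planted entries (illustrative indices; the first three mirror the neurons reported above for 0, 1 and 2, the rest are made up); with the notebook's data you could pass `OV.detach()` instead. This reads weights, not activations, and ignores the MLP's LayerNorm, so it is only a heuristic.

```python
import torch
from typing import Dict, List, Tuple


def value_to_neuron_map(ov: torch.Tensor, k: int = 1) -> Tuple[Dict[int, List[int]], List[int]]:
    """For each input value (row of a [d_vocab, d_mlp] OV matrix), return the k neurons
    (columns) receiving the largest input, plus the neurons in no value's top k."""
    top = ov.topk(k, dim=-1).indices  # [d_vocab, k]
    used = set(top.flatten().tolist())
    unused = [n for n in range(ov.shape[-1]) if n not in used]
    return {v: top[v].tolist() for v in range(ov.shape[0])}, unused


# Toy OV matrix: weak noise plus one strong entry per value (illustrative indices)
torch.manual_seed(0)
toy_ov = 0.1 * torch.randn(10, 48)
planted = [34, 8, 44, 3, 17, 21, 29, 12, 40, 5]
for value, neuron in enumerate(planted):
    toy_ov[value, neuron] = 5.0

mapping, unused_neurons = value_to_neuron_map(toy_ov, k=1)
print(mapping)
print(f"{len(unused_neurons)} of {toy_ov.shape[-1]} neurons are no value's top neuron")
```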

# Conclusion

We have found that the most salient patterns involved in solving the Majority Element task are single MLP neurons (or small groups of neurons), each associated with one value.

Instead of "counting" specifically, it appears that these neurons simply signal the strength of a value's predominance. It might be interesting to see how this scales with more complex sequences or more values.

There are several directions for future investigation:
- Attempt activation patching to see if we can make the model conclude that particular (incorrect) values are in the majority
- Try the model without an MLP to see what the attention pattern looks like
- Try a larger vocabulary to see what changes (possible superposition?)

More investigations will be added to this notebook (and other notebooks as well, as I look at other tasks)!