# Introduction & Setup

In this notebook, I examine the behavior of a small transformer trained to solve the [Majority Element](https://leetcode.com/problems/majority-element/) LeetCode problem. As posed here, the task is to return the most common integer in an array of N elements, each of which is an integer between 0 and a specified limit.

Specifically, we will look for attention heads that perform "counting" behavior. My best guess is that we will see something like this: information about how many times a number k occurs in the sequence is stored at the position of the last occurrence of k.

Throughout the notebook I will use the [TransformerLens](https://github.com/neelnanda-io/TransformerLens) library developed by Neel Nanda, as well as the [CircuitsVis](https://github.com/alan-cooney/CircuitsVis) library maintained by Alan Cooney. TransformerLens uses PyTorch hooks in its implementation of HookPoints, which expose the activations of every layer as information flows through the network.

We will use this to examine the transformer heads and MLP layers in detail as they perform inference.
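
To make the idea of a hook point concrete, here is a toy, pure-Python analogue. It is only a sketch of the concept and not how TransformerLens implements HookPoints: `ToyHookPoint`, `toy_forward`, and `toy_cache` are hypothetical names invented for this illustration. The hook point is an identity step through which an activation passes; registered functions can either record the value (as a cache does) or return a replacement (as an ablation does).

```python
from typing import Callable, Dict, List, Optional

# A hook receives the activation and may return a replacement,
# or None to leave the activation unchanged.
Hook = Callable[[List[float]], Optional[List[float]]]


class ToyHookPoint:
    """Identity step that lets callers observe or edit the value passing through."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.hooks: List[Hook] = []

    def __call__(self, value: List[float]) -> List[float]:
        for hook in self.hooks:
            result = hook(value)
            if result is not None:
                value = result
        return value


def toy_forward(x: List[float], hook_point: ToyHookPoint) -> float:
    """Stand-in 'model': double the input, pass it through the hook point, sum it."""
    hidden = [2.0 * v for v in x]
    hidden = hook_point(hidden)
    return sum(hidden)


hook_point = ToyHookPoint("hidden")
toy_cache: Dict[str, List[float]] = {}


def save_hook(value: List[float]) -> None:
    # Read-only hook: store a copy of the activation and change nothing.
    toy_cache[hook_point.name] = list(value)


def ablate_first(value: List[float]) -> List[float]:
    # Editing hook: zero out the first hidden unit.
    return [0.0] + value[1:]


hook_point.hooks.append(save_hook)
clean = toy_forward([1.0, 2.0, 3.0], hook_point)

hook_point.hooks.append(ablate_first)
ablated = toy_forward([1.0, 2.0, 3.0], hook_point)

print(clean, ablated, toy_cache)
```

The same two patterns, recording and replacing, are what the caching and ablation cells later in the notebook rely on.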

In [1]:
import circuitsvis as cv

In [2]:
from IPython import get_ipython
ipython = get_ipython()
# Code to automatically update the HookedTransformer code as it is edited without restarting the kernel
# run_line_magic replaces the deprecated ipython.magic(...) call
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")


In [217]:
import plotly.io as pio
pio.renderers.default = "notebook_connected" # or use "browser" if you want plots to open with browser


import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange
from tqdm.notebook import tqdm_notebook
import plotly.express as px
from torch.utils.data import DataLoader, Dataset

from torchtyping import TensorType as TT
from typing import List, Tuple

import transformer_lens as tl
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint  # Hooking utilities

# Basic plotting functions

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)
    #return px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)
    #return px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)
    #return px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model Training

## Setup

### Model Definition

TransformerLens provides a basic, configurable transformer implementation with HookPoints already integrated. We will use this for our training.

In [4]:
config = tl.EasyTransformerConfig(
    d_model=64,
    d_head=32,
    n_heads=2,
    d_mlp=256,
    n_layers=1,
    n_ctx=128, # context window size
    #use_local_attn=True,
    #attn_types=["local","global"],
    #window_size=2,
    act_fn="solu_ln",
    #attn_only=True,
    d_vocab=10,
    normalization_type="LN",
    seed=23,
)

model = tl.EasyTransformer(config).to(device)

Moving model to device:  cpu


In [168]:
from torchinfo import summary

batch_size = 256
summary(model)

Layer (type:depth-idx)                   Param #
HookedTransformer                        --
├─Embed: 1-1                             640
├─HookPoint: 1-2                         --
├─PosEmbed: 1-3                          8,192
├─HookPoint: 1-4                         --
├─ModuleList: 1-5                        --
│    └─TransformerBlock: 2-1             --
│    │    └─LayerNorm: 3-1               128
│    │    └─LayerNorm: 3-2               128
│    │    └─Attention: 3-3               16,640
│    │    └─MLP: 3-4                     33,600
│    │    └─HookPoint: 3-5               --
│    │    └─HookPoint: 3-6               --
│    │    └─HookPoint: 3-7               --
│    │    └─HookPoint: 3-8               --
│    │    └─HookPoint: 3-9               --
├─LayerNorm: 1-6                         128
│    └─HookPoint: 2-2                    --
│    └─HookPoint: 2-3                    --
├─Unembed: 1-7                           650
Total params: 60,106
Trainable params: 60,106
Non-train
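
The per-module counts in the summary can be reproduced by hand from the config. The sketch below assumes the usual TransformerLens parameter layout (biases on Q, K, V and O, and an extra LayerNorm over the hidden units inside a `solu_ln` MLP); that layout is an assumption on my part, but it matches the reported numbers exactly.

```python
# Sizes taken from the config cell above
d_model, d_head, n_heads, d_mlp, n_ctx, d_vocab = 64, 32, 2, 256, 128, 10

embed = d_vocab * d_model                        # token embedding matrix
pos_embed = n_ctx * d_model                      # positional embedding matrix
layer_norm = 2 * d_model                         # scale and bias of one LayerNorm

attn_weights = 4 * n_heads * d_model * d_head    # W_Q, W_K, W_V, W_O
attn_biases = 3 * n_heads * d_head + d_model     # b_Q, b_K, b_V, then b_O
attention = attn_weights + attn_biases

mlp_weights = 2 * d_model * d_mlp                # W_in, W_out
mlp_biases = d_mlp + d_model                     # b_in, b_out
mlp_inner_ln = 2 * d_mlp                         # assumed LayerNorm inside solu_ln
mlp = mlp_weights + mlp_biases + mlp_inner_ln

unembed = d_model * d_vocab + d_vocab            # W_U and b_U

# Two LayerNorms in the block plus the final LayerNorm
total = embed + pos_embed + 2 * layer_norm + attention + mlp + layer_norm + unembed
print(embed, pos_embed, attention, mlp, unembed, total)
```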

### Create Data

In [6]:
from typing import Tuple

@dataclass
class DataArgs():
    max_seq_len: int
    batch_size: int
    num_workers: int
    dataset_size: int
    max_val: int

In [7]:
class NumSequenceDataset(Dataset):
    def __init__(self, x, y):
        """Initialize the dataset
        Args:
            x (list): list of input sequences
            y (list): list of output sequences
        """
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

In [8]:
def build_data(seq_len: int, dataset_size: int, max_val: int = 9) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Builds a dataset of random integer sequences and, for each position, the
        majority element of the sequence up to and including that position

    Args:
        seq_len (int): Length of individual sequences
        dataset_size (int): Number of items in the dataset
        max_val (int, optional): Largest integer that can appear in a sequence 
            (inclusive). Defaults to 9.

    Returns:
        Tuple[List[torch.Tensor], List[torch.Tensor]]: X and Y lists of sequences
    """    
    x, y = [], []
    for _ in range(dataset_size):

        seq = torch.randint(0, max_val + 1, (seq_len,))
        x.append(seq)

        # The label at each position is the most frequent element of the prefix ending there.
        # argmax returns the first maximum, so ties go to the smallest value.
        y_list = [torch.bincount(seq[:end]).argmax() for end in range(1, seq_len + 1)]
        y.append(torch.tensor(y_list))
        
    return x, y

In [9]:
data_args = DataArgs(
    max_seq_len = config.n_ctx,
    batch_size = 256,
    num_workers = 4,
    dataset_size = 2048,
    max_val = config.d_vocab-1,
)

x, y = build_data(data_args.max_seq_len, data_args.dataset_size, data_args.max_val)
dataset = NumSequenceDataset(x, y)

train_set, val_set = torch.utils.data.random_split(dataset, [int(len(dataset)*0.8), int(len(dataset) - int(len(dataset)*0.8))])

trainloader = DataLoader(train_set, batch_size=data_args.batch_size, shuffle=True, num_workers=data_args.num_workers)
valloader = DataLoader(val_set, batch_size=data_args.batch_size, shuffle=False, num_workers=data_args.num_workers)


In [10]:
# Examine data
# In order to provide more signal to the transformer, each item in Y represents the 
# majority element up to that point in the sequence. For example, if the sequence is 
# [6, 4, 2, 5, 7, 9, 4, 6, 0] then the corresponding Y sequence is
# [6, 4, 2, 2, 2, 2, 4, 4, 4]. (Note that for our definition of this problem, if two 
# elements are tied for majority, the smallest one is chosen.)
dataset[5]

(tensor([6, 4, 2, 5, 7, 9, 4, 6, 0, 8, 7, 9, 3, 1, 0, 7, 4, 7, 2, 3, 8, 8, 8, 0,
         5, 2, 3, 9, 2, 1, 3, 2, 7, 0, 7, 2, 0, 7, 3, 6, 8, 1, 3, 5, 3, 7, 5, 4,
         4, 5, 8, 6, 7, 3, 5, 0, 8, 4, 5, 0, 0, 9, 3, 8, 1, 4, 2, 9, 7, 0, 1, 5,
         0, 9, 7, 0, 0, 0, 7, 5, 0, 8, 4, 6, 7, 1, 5, 6, 6, 2, 6, 3, 9, 0, 6, 8,
         9, 8, 8, 2, 6, 3, 1, 5, 0, 6, 8, 7, 2, 1, 2, 5, 4, 1, 9, 6, 4, 2, 3, 1,
         0, 9, 5, 2, 4, 2, 3, 3]),
 tensor([6, 4, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 0, 7, 4, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 2, 2, 2, 2, 2, 2, 7, 2, 2, 7, 7, 7, 7, 7, 7, 7, 3, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 3, 3, 3, 3, 3, 3, 7, 7, 7, 7,
         0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]))
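
The labeling rule can be cross-checked with a small pure-Python reference that does not depend on PyTorch. This is a sketch written for illustration; `prefix_majority` is a hypothetical helper, not part of the notebook's pipeline. It keeps a running count per value and, at each position, picks the smallest value among those with the highest count.

```python
from typing import List


def prefix_majority(seq: List[int], max_val: int = 9) -> List[int]:
    """Most frequent value of every prefix of seq; ties go to the smallest value."""
    counts = [0] * (max_val + 1)
    labels = []
    for value in seq:
        counts[value] += 1
        # list.index returns the first position of the maximum,
        # i.e. the smallest value among those tied for the highest count.
        labels.append(counts.index(max(counts)))
    return labels


example = [6, 4, 2, 5, 7, 9, 4, 6, 0]
print(prefix_majority(example))
```

Applied to the first nine tokens of the item shown above, it reproduces the first nine labels.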

### Optimizer & Training Hyperparameter Setup

In [11]:
@dataclass
class TrainingArgs():
    batch_size: int
    epochs: int
    optimizer: torch.optim.Optimizer
    lr: float
    betas: Tuple[float, float]
    weight_decay: float
    track: bool
    cuda: bool

In [12]:
args = TrainingArgs(
    batch_size = 128,
    epochs = 60,
    optimizer = torch.optim.AdamW,
    lr = 0.001,
    betas = (0.99, 0.999),
    weight_decay = 0.01,
    track = False,
    cuda = False
)

# Note: args.betas is not passed to the optimizer here, so AdamW uses its default betas.
optimizer = args.optimizer(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

### Define Training Loop

In [13]:
from typing import Callable

loss_fn = nn.CrossEntropyLoss()

MODEL_FILENAME = "./majority_element_model.pt"


def train_transformer(model, optimizer, trainloader, args) -> Tuple[nn.Module, List[float], List[float]]:
    """Trains a transformer model
    Args:
        model (torch.nn.Module): Transformer model
        optimizer (torch.optim.Optimizer): Optimizer
        trainloader (torch.utils.data.DataLoader): Training data loader
        args (TrainingArgs): Training arguments

    Returns:
        model: Trained model
        list: Training loss for every batch
        list: Validation loss for every validation batch

    Note: validation uses the global `valloader` defined above.
    """
    epochs = args.epochs
    
    loss_list = []
    val_loss_list = []
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=0.01, steps_per_epoch=len(trainloader), epochs=args.epochs)
    
    progress_bar = tqdm_notebook(range(epochs))
    for epoch in progress_bar:
        model.train()
        for (x, y) in trainloader:
        
            logits = model(x.long(), return_type="logits")
            logits = rearrange(logits, 'B S V -> (B S) V')
            y = rearrange(y, 'B S -> (B S)')

            loss = F.cross_entropy(logits, y)
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            loss_list.append(loss.item())

            progress_bar.set_description(f"Epoch = {epoch}, Loss = {loss.item():.4f}")

        model.eval()
        with torch.no_grad():  # gradients are not needed for validation
            for (x, y) in valloader:

                logits = model(x.long(), return_type="logits")
                logits = rearrange(logits, 'B S V -> (B S) V')
                y = rearrange(y, 'B S -> (B S)')

                loss = F.cross_entropy(logits, y)
                val_loss_list.append(loss.item())

    print(f"Saving model to: {MODEL_FILENAME}")
    torch.save(model, MODEL_FILENAME)
    return model, loss_list, val_loss_list

## Train Model

In [14]:
model, loss_list, val_loss = train_transformer(model, optimizer, trainloader, args=args)


Saving model to: ./majority_element_model.pt


In [15]:
fig = px.line(y=loss_list, template="simple_white")
fig.update_layout(title="Cross entropy loss on Majority Element", yaxis_range=[0, max(loss_list)])
fig.show()

In [16]:
fig = px.line(y=val_loss, template="simple_white")
fig.update_layout(title="Validation loss on Majority Element", yaxis_range=[0, max(val_loss)])
fig.show()

## Test Inference

Inference can be tested below. The model is not perfect, but it predicts the majority element correctly in most cases.

In [17]:
import sys
sys.path.append('common_modules/')

import sample_methods as s

example_data = torch.randint(0, data_args.max_val+1, (data_args.max_seq_len,))
with torch.inference_mode():
    all_logits = model(example_data.long())
    logits = all_logits[0, -1]
    print(example_data)
    print(f"True majority element: {torch.bincount(example_data).argmax().item()}")
    print(f"Model prediction for majority element: {int(logits.argmax(dim=-1).squeeze())}")

tensor([6, 7, 3, 5, 9, 8, 0, 7, 1, 5, 3, 0, 9, 4, 2, 5, 4, 6, 5, 9, 9, 0, 6, 2,
        7, 1, 3, 5, 8, 4, 5, 5, 8, 7, 7, 4, 5, 7, 1, 9, 1, 7, 9, 7, 1, 5, 4, 9,
        9, 0, 4, 0, 4, 7, 5, 1, 0, 4, 5, 6, 1, 1, 8, 7, 7, 9, 5, 4, 9, 3, 0, 2,
        8, 7, 0, 9, 1, 3, 2, 0, 1, 0, 4, 5, 9, 3, 8, 1, 6, 6, 9, 2, 5, 4, 1, 5,
        5, 8, 2, 4, 1, 6, 5, 0, 3, 1, 7, 6, 7, 0, 7, 9, 8, 2, 6, 6, 0, 6, 5, 6,
        4, 1, 2, 9, 3, 6, 0, 4])
True majority element: 5
Model prediction for majority element: 5


Let's also set up a batch for the interpretability section below, and test inference on it:

In [74]:
example_batch = next(iter(valloader))
batch_index = 5

example_data = example_batch[0][batch_index]

with torch.inference_mode():
    all_logits = model(example_data.long())
    logits = all_logits[0, -1]
    print(example_data)
    print(example_batch[1][batch_index])
    print(f"True majority element: {torch.bincount(example_data).argmax().item()}")
    print(f"Model prediction for majority element: {int(logits.argmax(dim=-1).squeeze())}")

tensor([1, 5, 3, 6, 0, 9, 6, 4, 2, 3, 7, 5, 2, 0, 6, 8, 8, 1, 1, 4, 0, 6, 1, 3,
        9, 3, 2, 6, 2, 2, 5, 3, 7, 3, 2, 5, 0, 8, 2, 6, 8, 1, 0, 6, 3, 7, 4, 6,
        2, 2, 0, 5, 7, 6, 8, 2, 2, 3, 8, 9, 4, 0, 4, 6, 7, 5, 5, 5, 2, 3, 9, 2,
        4, 9, 4, 6, 5, 8, 9, 1, 0, 7, 6, 4, 0, 1, 0, 9, 3, 6, 2, 0, 6, 1, 8, 2,
        9, 2, 7, 1, 9, 0, 0, 3, 7, 3, 3, 3, 9, 7, 1, 2, 4, 7, 1, 9, 8, 2, 6, 7,
        8, 7, 1, 8, 8, 4, 0, 4])
tensor([1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3, 3, 2, 0, 6, 6, 6, 6, 1, 1, 0, 6, 1, 1,
        1, 1, 1, 6, 6, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 6,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
True majority element: 2
Model prediction for majority element: 2


# Model Interpretability

## Attention Patterns

Here, we'll use the TransformerLens `run_with_cache` method to collect the activations as we pass input through the model. The cache is indexed by activation name and layer number, for example `cache["pattern", 0, "attn"]` for the attention pattern of layer 0.

In [75]:
loss, cache = model.run_with_cache(example_batch[0][batch_index].long(), return_type="loss")

Below, we can see the attention pattern of each of the two heads in our attention layer. Perhaps unsurprisingly, they appear extremely uniform: at each position, the heads seem to pay roughly equal attention to all previous positions. We'll examine this more closely later to see what the heads might actually be doing.

In [76]:
print(type(cache))
attention_pattern = cache["pattern", 0, "attn"]
attention_pattern = attention_pattern.squeeze(0)
print(attention_pattern.shape)

token_labels = [str(int(t)) for t in example_batch[0][batch_index]]
#print(token_labels.shape)

print("Layer 0 Head Attention Patterns:")
cv.attention.attention_heads(tokens=token_labels[:20], attention=attention_pattern[:,:20, :20])

# Use the following if outputs not displaying in notebook or if you want to save the html file
#html = cv.attention.attention_heads(tokens=token_labels, attention=attention_pattern)
#with open("cv_attn_2.html", "w") as f:
#    f.write(str(html))

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([2, 128, 128])
Layer 0 Head Attention Patterns:
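
A uniform causal pattern is at least compatible with counting. The idealized sketch below is not a claim about what this trained model does. It assumes perfectly uniform attention and one-hot value vectors, and all function names are hypothetical. Under those assumptions, averaging over all positions up to the current one yields the running frequency of each token, and the argmax of those frequencies is the prefix majority.

```python
from typing import List


def uniform_causal_attention(seq_len: int) -> List[List[float]]:
    """Row i puts weight 1/(i+1) on positions 0..i and 0 on later positions."""
    return [
        [1.0 / (i + 1) if j <= i else 0.0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


def one_hot(token: int, vocab: int) -> List[float]:
    return [1.0 if v == token else 0.0 for v in range(vocab)]


def attend(pattern: List[List[float]], values: List[List[float]]) -> List[List[float]]:
    """Attention-weighted sum of the value vectors for each query position."""
    width = len(values[0])
    return [
        [sum(w * vec[d] for w, vec in zip(row, values)) for d in range(width)]
        for row in pattern
    ]


tokens = [6, 4, 2, 5, 7, 9, 4, 6, 0]
values = [one_hot(t, 10) for t in tokens]            # idealized value vectors
freqs = attend(uniform_causal_attention(len(tokens)), values)

# The most frequent token so far (smallest value on ties) at each position
predicted = [row.index(max(row)) for row in freqs]
print(predicted)
```

In this idealized setting the head output already contains everything needed to read off the majority element, which would leave the comparison between frequencies to later layers.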


## Logit Contribution for Assorted Layers

The logits output by the transformer result from an unembedding of the content of the residual stream, which itself consists of a series of addition operations performed on the initial embedding--each layer adds its own values to the stream--as well as LayerNorm operations. We can examine the contribution of each layer to the residual stream directly, as follows.

Note: Depending on your specific setup, the graphs below might open in a browser instead of the Jupyter notebook.

In [77]:
resid_components = [cache["embed"], cache["attn_out", 0], cache["mlp_out", 0]]
labels = ["embed", "A0", "M0"]
resid_stack = torch.stack(resid_components, 0)
resid_stack = resid_stack - resid_stack.mean(-1, keepdim=True)
print(resid_stack.shape)

torch.Size([3, 1, 128, 64])


In [78]:
fold_W_U = model.ln_final.w[:, None] * model.unembed.W_U
logit_components = resid_stack[:, 0] @ fold_W_U/cache["scale", None, "ln_final"][0]
print(logit_components.shape)

torch.Size([3, 128, 10])
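
The two cells above rely on the final LayerNorm being linear once its scale is held fixed. Here is a sketch of the reasoning, assuming the standard LayerNorm definition with weight $w$, bias $b$, and per-position scale $s$ (the quantity read from `cache["scale", None, "ln_final"]`), and writing $x_c$ for each component added to the residual stream:

$$
\mathrm{LN}(x) = \frac{x - \bar{x}}{s} \odot w + b, \qquad x = \sum_c x_c
$$

$$
\mathrm{logits} = \mathrm{LN}(x)\, W_U + b_U = \sum_c \frac{x_c - \bar{x}_c}{s}\, \big(\mathrm{diag}(w)\, W_U\big) + \big(b\, W_U + b_U\big)
$$

Subtracting the mean is a linear operation, so it can be applied to each component separately, which is what `resid_stack - resid_stack.mean(-1, keepdim=True)` does. The matrix $\mathrm{diag}(w)\, W_U$ is `fold_W_U`, and each term of the sum is one slice of `logit_components`. The constant term $b\, W_U + b_U$ is shared by all components and is left out of the code. The stack holds three components (`embed`, `attn_out`, `mlp_out`); the positional embedding also writes to the residual stream, so the three slices are not expected to sum to the full logits.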


In [79]:
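# Caution: example_input is taken from dataset[1], whereas `cache` (and therefore
# logit_components) was computed on example_batch[0][batch_index].
# The indexing below selects, at each position p >= 1, the logit of the token
# found at position p-1 of example_input.
example_input = dataset[1][0].long()
logit_components = logit_components - logit_components.mean(-1, keepdim=True)
line(logit_components[:, torch.arange(1, model.cfg.n_ctx), example_input[:-1]].T)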

Above, we can see the embedding contributions to the residual stream in blue, and the attention block's contributions in red. The MLP's contributions are in green, and are responsible for the greatest-magnitude contribution to the logits. Most of these seem to be negative, but there are several spikes at sequence positions 9, 56, 71, 82, 92, 101, 110, and 119.

In [80]:
print(example_batch[0][batch_index])
print(example_batch[1][batch_index])
print(example_batch[0][batch_index][8:11])
print(example_batch[0][batch_index][55:58])
print(example_batch[0][batch_index][70:73])
print(example_batch[0][batch_index][81:84])
print(example_batch[0][batch_index][91:94])
print(example_batch[0][batch_index][100:103])
print(example_batch[0][batch_index][109:112])
print(example_batch[0][batch_index][118:121])

tensor([1, 5, 3, 6, 0, 9, 6, 4, 2, 3, 7, 5, 2, 0, 6, 8, 8, 1, 1, 4, 0, 6, 1, 3,
        9, 3, 2, 6, 2, 2, 5, 3, 7, 3, 2, 5, 0, 8, 2, 6, 8, 1, 0, 6, 3, 7, 4, 6,
        2, 2, 0, 5, 7, 6, 8, 2, 2, 3, 8, 9, 4, 0, 4, 6, 7, 5, 5, 5, 2, 3, 9, 2,
        4, 9, 4, 6, 5, 8, 9, 1, 0, 7, 6, 4, 0, 1, 0, 9, 3, 6, 2, 0, 6, 1, 8, 2,
        9, 2, 7, 1, 9, 0, 0, 3, 7, 3, 3, 3, 9, 7, 1, 2, 4, 7, 1, 9, 8, 2, 6, 7,
        8, 7, 1, 8, 8, 4, 0, 4])
tensor([1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3, 3, 2, 0, 6, 6, 6, 6, 1, 1, 0, 6, 1, 1,
        1, 1, 1, 6, 6, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 6,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
tensor([2, 3, 7])
tensor([2, 2, 3])
tensor([9, 2, 4])
tensor([7, 6, 4])
tensor([0, 6, 1])
tensor([9, 0, 0])
tensor([7, 1, 2])
tensor([

Looking at the sequence at these specific positions doesn't show much right away. All except one of the activation spikes occur when the number at the activated position is different from the previous number, but it's usually the case that any given number is different from the previous one. Let's examine the MLP in more detail to see what we can learn.

## Layer Analyses

Below, we'll create an analysis version of the model and load in the state dictionary from the trained model.

In [81]:
analysis_cfg = tl.EasyTransformerConfig(
    d_model=64,
    d_head=32,
    n_heads=2,
    d_mlp=256,
    n_layers=1,
    n_ctx=128, # context window size
    #use_local_attn=True,
    #attn_types=["local","global"],
    #window_size=2,
    act_fn="solu_ln",
    #attn_only=True,
    d_vocab=10,
    normalization_type="LNPre",
    seed=23,
)
analysis_model = tl.EasyTransformer(analysis_cfg)
state_dict = model.state_dict()
analysis_model.load_and_process_state_dict(state_dict, fold_ln=True, center_writing_weights=True, center_unembed=True)

### Attention Circuit Examination

Following the [Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) paper, let's break down the attention head into $W_Q W_K$ (attention pattern) and $W_V W_O$ (output value) circuits.

#### $W_Q W_K$ Circuit

In [85]:
QK = analysis_model.W_E @ analysis_model.W_Q[0, 0] @ analysis_model.W_K[0, 0].T @ analysis_model.W_E.T
imshow(QK, yaxis="Query", xaxis="Key", title="Full QK Circuit for Attn Layer 0")
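
To read the heatmap: entry (i, j) of the full QK matrix is the dot product between the query computed from token i's embedding and the key computed from token j's embedding. The sketch below checks this on tiny hand-written matrices using plain Python lists. The sizes and values are hypothetical, and it ignores biases, positional embeddings, and the scaling by the square root of `d_head` that real attention scores include.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def dot(u: List[float], v: List[float]) -> float:
    return sum(x * y for x, y in zip(u, v))


# Hypothetical sizes: 3 tokens, d_model = 2, d_head = 2
W_E = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # [d_vocab, d_model]
W_Q = [[1.0, 2.0], [0.0, 1.0]]               # [d_model, d_head]
W_K = [[0.5, 0.0], [1.0, 1.0]]               # [d_model, d_head]

queries = matmul(W_E, W_Q)                   # one query vector per token
keys = matmul(W_E, W_K)                      # one key vector per token

# Same chain of products as the notebook cell: W_E W_Q W_K^T W_E^T
QK = matmul(matmul(matmul(W_E, W_Q), transpose(W_K)), transpose(W_E))

for i in range(3):
    for j in range(3):
        assert abs(QK[i][j] - dot(queries[i], keys[j])) < 1e-12

print(QK)
```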

#### $W_V W_O$ Circuit

In [87]:
OV = analysis_model.W_E @ analysis_model.W_V[0, 0] @ analysis_model.W_O[0, 0] @ analysis_model.W_in[0]
imshow(OV, yaxis="Input Vocab", xaxis="Neuron", title="Full OV Circuit for Attn Layer 0")

The above visualizations are somewhat difficult to interpret, but most entries are close to zero: through this head, most neurons receive little input from most vocabulary tokens. The strong entries are scattered across the vocabulary, which suggests that the effect of any specific token is carried mostly by a few specific neurons.

In [83]:
line(OV[:, torch.randint(0, 256, (5,))])

In [88]:
example_batch[0][batch_index], example_batch[1][batch_index]

(tensor([1, 5, 3, 6, 0, 9, 6, 4, 2, 3, 7, 5, 2, 0, 6, 8, 8, 1, 1, 4, 0, 6, 1, 3,
         9, 3, 2, 6, 2, 2, 5, 3, 7, 3, 2, 5, 0, 8, 2, 6, 8, 1, 0, 6, 3, 7, 4, 6,
         2, 2, 0, 5, 7, 6, 8, 2, 2, 3, 8, 9, 4, 0, 4, 6, 7, 5, 5, 5, 2, 3, 9, 2,
         4, 9, 4, 6, 5, 8, 9, 1, 0, 7, 6, 4, 0, 1, 0, 9, 3, 6, 2, 0, 6, 1, 8, 2,
         9, 2, 7, 1, 9, 0, 0, 3, 7, 3, 3, 3, 9, 7, 1, 2, 4, 7, 1, 9, 8, 2, 6, 7,
         8, 7, 1, 8, 8, 4, 0, 4]),
 tensor([1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3, 3, 2, 0, 6, 6, 6, 6, 1, 1, 0, 6, 1, 1,
         1, 1, 1, 6, 6, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 6,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2]))

### MLP 0 Analysis

In [92]:
imshow(cache["post", 0][0], yaxis="Pos", xaxis="Neuron", title="Neuron activations for single inputs")

Now we seem to be getting somewhere. Some observations:
- There's a clear change at position 29, with most activations dropping off and a few remaining. This is exactly the point at which--according to the above sequence--the network should register the number `2` as the majority element (a state that continues for almost the entire remaining sequence).
- Activations become less widespread for most neurons, and neurons 93, 103, and 117 remain almost constantly activated once `2` becomes the majority element.
- This seems to suggest that for MLP layer 0, a linear combination of neurons 93, 103, and 117 represents the number 2!

Let's graph these neurons specifically to see how they evolve with position.

In [91]:
line(cache["post", 0][0, :, 93])
line(cache["post", 0][0, :, 103])
line(cache["post", 0][0, :, 117])

Activations here either increase steadily (neuron 93) or rise to a fairly high level and stay there for the remainder of the sequence.

In [98]:
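n93 = cache["post", 0][0, :, 93]
n103 = cache["post", 0][0, :, 103]
n117 = cache["post", 0][0, :, 117]
# Sum out of place: the slices above are views into the cache, so writing the
# result into n93 would overwrite the cached activations of neuron 93.
combined = n93 + n103 + n117
line(combined)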

In [106]:
# Running count of 2s up to and including each position
seq = example_batch[0][batch_index]
count_of_2s = torch.cumsum((seq == 2).long(), dim=0)
line(count_of_2s)

Roughly, the combined outputs of the three neurons seem to drop somewhat whenever the number of `2`s stops increasing (perhaps reflecting decreased confidence that 2 is the majority element). However, the activations don't go up monotonically with the number of `2`s, indicating that perhaps the network is simply keeping track of the relative predominance of 2 vs. other elements rather than strictly counting.
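
One way to make "relative predominance" concrete is the lead of a value over its strongest competitor: its running count minus the largest running count among all other values. The helper below is a hypothetical, pure-Python sketch of that quantity (`running_lead` is not part of the notebook). Plotting it next to the combined neuron activations would be one way to probe the hypothesis.

```python
from typing import List


def running_lead(seq: List[int], value: int, max_val: int = 9) -> List[int]:
    """After each position: count(value) minus the largest count among the other values."""
    counts = [0] * (max_val + 1)
    lead = []
    for token in seq:
        counts[token] += 1
        best_other = max(c for v, c in enumerate(counts) if v != value)
        lead.append(counts[value] - best_other)
    return lead


print(running_lead([2, 3, 2, 3, 3, 2, 2], value=2))
```

A positive lead means the value is the strict majority element at that position, zero means it is tied with another value, and a negative lead means another value is ahead.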

To confirm that these three neurons are indeed performing the function of tracking the predominance of `2`s, let's perform an ablation experiment.

### Ablation Experiment

In [127]:
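with torch.inference_mode():
    all_logits = model(example_batch[0][batch_index].long())
    # return_type="loss" is TransformerLens's built-in next-token loss (how well each
    # position predicts the following input token), not the loss against the
    # majority-element labels, so it is not comparable to the training loss.
    loss = model(example_batch[0][batch_index].long(), return_type="loss")
    logits = all_logits[0, -1]
    print(example_batch[0][batch_index])
    print(example_batch[1][batch_index])
    print(f"True majority element: {torch.bincount(example_batch[0][batch_index]).argmax().item()}")
    print(f"Model prediction for majority element: {int(logits.argmax(dim=-1).squeeze())}")
    print(f"Model loss for majority element: {loss.item()}")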

tensor([1, 5, 3, 6, 0, 9, 6, 4, 2, 3, 7, 5, 2, 0, 6, 8, 8, 1, 1, 4, 0, 6, 1, 3,
        9, 3, 2, 6, 2, 2, 5, 3, 7, 3, 2, 5, 0, 8, 2, 6, 8, 1, 0, 6, 3, 7, 4, 6,
        2, 2, 0, 5, 7, 6, 8, 2, 2, 3, 8, 9, 4, 0, 4, 6, 7, 5, 5, 5, 2, 3, 9, 2,
        4, 9, 4, 6, 5, 8, 9, 1, 0, 7, 6, 4, 0, 1, 0, 9, 3, 6, 2, 0, 6, 1, 8, 2,
        9, 2, 7, 1, 9, 0, 0, 3, 7, 3, 3, 3, 9, 7, 1, 2, 4, 7, 1, 9, 8, 2, 6, 7,
        8, 7, 1, 8, 8, 4, 0, 4])
tensor([1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3, 3, 2, 0, 6, 6, 6, 6, 1, 1, 0, 6, 1, 1,
        1, 1, 1, 6, 6, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 6,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
True majority element: 2
Model prediction for majority element: 2
Model loss for majority element: 11.001136779785156


In [215]:
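layer_to_ablate = 0
# The three neurons discussed above, plus neuron 206
neurons_to_ablate = [93, 103, 117, 206]
# Alternatives for comparison: four random neurons, or every neuron in the layer
#neurons_to_ablate = torch.randint(0, 256, (4,)).tolist()
#neurons_to_ablate = [n for n in range(0, 256)]
analysis_model.reset_hooks()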

def multi_neuron_ablation_hook(
    value: TT["batch", "pos", "neuron_index"],
    hook: HookPoint
) -> TT["batch", "pos", "neuron_index"]:
    print(f"Shape of the value tensor: {value.shape}")
    for n in neurons_to_ablate:
        value[:, :, n] = 0.
    return value


original_logits, original_l = analysis_model(example_batch[0][batch_index], return_type="both")
ablated_logits, ablated_l = analysis_model.run_with_hooks(
    example_batch[0][batch_index], 
    return_type="both", 
    fwd_hooks=[(
        utils.get_act_name("post", layer_to_ablate), 
        multi_neuron_ablation_hook
    )]
)
print(f"Original Loss: {F.cross_entropy(original_logits[0], example_batch[1][batch_index].long()).item():.3f}")
print(f"Ablated Loss: {F.cross_entropy(ablated_logits[0], example_batch[1][batch_index].long()).item():.3f}")

Shape of the value tensor: torch.Size([1, 128, 256])
Original Loss: 0.002
Ablated Loss: 1.840
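
To put these two numbers in context, a mean cross-entropy can be converted into the geometric-mean probability assigned to the correct label, and compared with the loss of guessing uniformly over the 10 possible values. The arithmetic below is a small sketch; `mean_correct_probability` is a hypothetical helper, and the loss values are copied from the output above.

```python
import math


def mean_correct_probability(mean_cross_entropy: float) -> float:
    """Geometric-mean probability given to the correct label, from a mean cross-entropy in nats."""
    return math.exp(-mean_cross_entropy)


uniform_loss = math.log(10)                      # uniform guess over d_vocab = 10 classes
original_p = mean_correct_probability(0.002)     # original loss from the output above
ablated_p = mean_correct_probability(1.840)      # ablated loss from the output above

print(f"Uniform-guess loss: {uniform_loss:.3f}")
print(f"Original: {original_p:.3f}, ablated: {ablated_p:.3f}")
```

On this sequence the ablated loss remains below the uniform-guess loss of about 2.303, so the ablated model still does better than chance on average, even though the typical probability it gives the correct label falls from nearly 1 to roughly 0.16.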


In [216]:
abl_logits = ablated_logits[0]
org_logits = original_logits[0]

print(f"Input:\n {example_batch[0][batch_index]}")
print(f"True majority element at each point:\n {example_batch[1][batch_index]}")
print(f"Original model prediction for majority element:\n {org_logits.argmax(dim=-1).squeeze()}")
print(f"Ablated model prediction for majority element:\n {abl_logits.argmax(dim=-1).squeeze()}")

print(f"True majority element: {torch.bincount(example_batch[0][batch_index]).argmax().item()}")
print(f"Original model prediction for majority element: {int(org_logits[-1].argmax(dim=-1).squeeze())}")
print(f"Ablated model prediction for majority element: {int(abl_logits[-1].argmax(dim=-1).squeeze())}")

print(f"Original Loss: {F.cross_entropy(original_logits[0], example_batch[1][batch_index].long()).item():.3f}")
print(f"Ablated Loss: {F.cross_entropy(ablated_logits[0], example_batch[1][batch_index].long()).item():.3f}")

Input:
 tensor([1, 5, 3, 6, 0, 9, 6, 4, 2, 3, 7, 5, 2, 0, 6, 8, 8, 1, 1, 4, 0, 6, 1, 3,
        9, 3, 2, 6, 2, 2, 5, 3, 7, 3, 2, 5, 0, 8, 2, 6, 8, 1, 0, 6, 3, 7, 4, 6,
        2, 2, 0, 5, 7, 6, 8, 2, 2, 3, 8, 9, 4, 0, 4, 6, 7, 5, 5, 5, 2, 3, 9, 2,
        4, 9, 4, 6, 5, 8, 9, 1, 0, 7, 6, 4, 0, 1, 0, 9, 3, 6, 2, 0, 6, 1, 8, 2,
        9, 2, 7, 1, 9, 0, 0, 3, 7, 3, 3, 3, 9, 7, 1, 2, 4, 7, 1, 9, 8, 2, 6, 7,
        8, 7, 1, 8, 8, 4, 0, 4])
True majority element at each point:
 tensor([1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3, 3, 2, 0, 6, 6, 6, 6, 1, 1, 0, 6, 1, 1,
        1, 1, 1, 6, 6, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 6,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2])
Original model prediction for majority element:
 tensor([1, 1, 1, 1, 0, 0, 6, 6, 6, 3, 3