<a target="_blank" href="https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Grokking_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Mech Interp of Binarized Neural Networks


Question being explored: Recent papers have shown that binary and ternary transformer based networks with weights of {-1,1} or {-1,0,1} can achieve similar results to full precision networks. Are these networks simply simulating a full precision network or are they learning different and possibly more interpretable algorithms due to their discretized nature. 

Setup: A 1 layer transformer with all weights binarized except for the embed and unembed. The specific implementation is based off of the BitNet paper and code is in the BitNet folder. 

Findings: It seems like for modular addition, binary transformers exhibit grokking in a very similar way and seem to be learning fundamentally the same algorithm and is more or less just emulating a full precision network.

# Setup
(No need to read)

In [1]:
TRAIN_MODEL = False

In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer-lens
    %pip install circuitsvis

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [3]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: notebook_connected


In [4]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [5]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML


torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.



In [6]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache


torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.


torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.



Plotting helper functions:

In [7]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [8]:
# Define the location to save the model, using a relative path
PTH_LOCATION = "workspace/_scratch/grokking_demo.pth"

# Create the directory if it does not exist
os.makedirs(Path(PTH_LOCATION).parent, exist_ok=True)

# Model Training

## Config

In [9]:
p = 113
frac_train = 0.3

# Optimizer config
lr = 1e-3
wd = 1.
betas = (0.9, 0.98)

# num_epochs = 25000
num_epochs = 10000
checkpoint_every = 100

DATA_SEED = 598

## Define Task
* Define modular addition
* Define the dataset & labels

Input format:
|a|b|=|

In [10]:
a_vector = einops.repeat(torch.arange(p), "i -> (i j)", j=p)
b_vector = einops.repeat(torch.arange(p), "j -> (i j)", i=p)
equals_vector = einops.repeat(torch.tensor(113), " -> (i j)", i=p, j=p)


In [11]:
dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).cuda()
print(dataset[:5])
print(dataset.shape)

tensor([[  0,   0, 113],
        [  0,   1, 113],
        [  0,   2, 113],
        [  0,   3, 113],
        [  0,   4, 113]], device='cuda:0')
torch.Size([12769, 3])


In [12]:
labels = (dataset[:, 0] + dataset[:, 1]) % p
print(labels.shape)
print(labels[:5])

torch.Size([12769])
tensor([0, 1, 2, 3, 4], device='cuda:0')


Convert this to a train + test set - 30% in the training set

In [13]:
torch.manual_seed(DATA_SEED)
indices = torch.randperm(p*p)
cutoff = int(p*p*frac_train)
train_indices = indices[:cutoff]
test_indices = indices[cutoff:]

train_data = dataset[train_indices]
train_labels = labels[train_indices]
test_data = dataset[test_indices]
test_labels = labels[test_indices]
print(train_data[:5])
print(train_labels[:5])
print(train_data.shape)
print(test_data[:5])
print(test_labels[:5])
print(test_data.shape)

tensor([[ 21,  31, 113],
        [ 30,  98, 113],
        [ 47,  10, 113],
        [ 86,  21, 113],
        [ 99,  83, 113]], device='cuda:0')
tensor([ 52,  15,  57, 107,  69], device='cuda:0')
torch.Size([3830, 3])
tensor([[ 43,  40, 113],
        [ 31,  42, 113],
        [ 39,  63, 113],
        [ 35,  61, 113],
        [112, 102, 113]], device='cuda:0')
tensor([ 83,  73, 102,  96, 101], device='cuda:0')
torch.Size([8939, 3])


## Define Model

In [14]:

cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 4,
    d_model = 128,
    d_head = 32,
    d_mlp = 512,
    act_fn = "relu",
    normalization_type=None,
    d_vocab=p+1,
    d_vocab_out=p,
    n_ctx=3,
    init_weights=True,
    device="cuda",
    seed = 999,
)

In [15]:
model = HookedTransformer(cfg)

In [16]:
from bitnet import BitNetTransformer

model2 = BitNetTransformer(dim=128, depth=1, heads=4, in_features=p+1, out_features=p, random_seed=999).to("cuda")

tensor([[ 0.,  0.,  0.,  0., -0., -1.,  0., -1.,  1.,  1., -1.,  0., -1.,  0.,
         -1.,  1., -0.,  1.,  0., -1.],
        [ 1., -0.,  1., -0., -1.,  1., -1., -1.,  1., -1., -1., -1., -1.,  1.,
         -1.,  0.,  1.,  1.,  1., -1.],
        [-1., -0., -1., -1.,  0.,  0., -1.,  1., -1., -1.,  0., -1., -0., -0.,
          1., -0., -0., -1., -1.,  1.],
        [-1.,  1., -1., -1.,  1., -1.,  1.,  1., -0., -1., -1., -1., -1.,  1.,
          0., -1.,  1.,  0., -1.,  1.],
        [ 1.,  1., -1.,  1., -1.,  1., -1., -0.,  0., -1.,  1., -0., -0.,  1.,
         -0., -1.,  1., -1.,  1., -0.],
        [-1., -1.,  1., -1., -0., -1.,  1.,  1., -1., -1., -0., -1., -1.,  0.,
          1., -1., -1.,  0.,  0., -0.],
        [ 0., -1.,  1., -1.,  1., -1., -1.,  0.,  1., -0., -1.,  0., -1., -1.,
         -1.,  1.,  1.,  1.,  0.,  1.],
        [-1., -1., -1.,  1.,  0.,  1., -0.,  1., -1., -1.,  1., -1.,  1.,  0.,
          1., -0., -0., -1., -1., -0.],
        [ 1.,  1., -1.,  0.,  1.,  1., -1.,  1.,

Disable the biases, as we don't need them for this task and it makes things easier to interpret.

In [17]:
for name, param in model2.named_parameters():
    if "b_" in name:
        param.requires_grad = False


## Define Optimizer + Loss

In [18]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)

In [19]:
optimizer2  = torch.optim.AdamW(model2.parameters(), lr=lr, weight_decay=0, betas=betas)

In [20]:
def loss_fn(logits, labels):
    if len(logits.shape)==3:
        logits = logits[:, -1]
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()
train_logits = model2(train_data)
print(train_logits.shape)
train_loss = loss_fn(train_logits, train_labels)
print(train_loss)
test_logits = model2(test_data)
test_loss = loss_fn(test_logits, test_labels)
print(test_loss)

torch.Size([3830, 3, 113])
tensor(4.9126, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(4.9016, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)


In [21]:
print("Uniform loss:")
print(np.log(p))

Uniform loss:
4.727387818712341


## Actually Train

**Weird Decision:** Training the model with full batch training rather than stochastic gradient descent. We do this so to make training smoother and reduce the number of slingshots.

In [22]:
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []

if TRAIN_MODEL:
    for epoch in tqdm.tqdm(range(num_epochs)):
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        train_losses.append(train_loss.item())

        optimizer.step()
        optimizer.zero_grad()

        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

        if ((epoch+1)%checkpoint_every)==0:
            checkpoint_epochs.append(epoch)
            model_checkpoints.append(copy.deepcopy(model.state_dict()))
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")

In [23]:
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
if True:
    for epoch in tqdm.tqdm(range(num_epochs)):
        train_logits = model2(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        train_losses.append(train_loss.item())

        optimizer2.step()
        optimizer2.zero_grad()

        with torch.inference_mode():
            test_logits = model2(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

        if ((epoch+1)%checkpoint_every)==0:
            # checkpoint_epochs.append(epoch)
            # model_checkpoints.append(copy.deepcopy(model.state_dict()))
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")

  0%|          | 0/10000 [00:00<?, ?it/s]

Epoch 99 Train Loss 2.908867800903007 Test Loss 5.384417223469513
Epoch 199 Train Loss 1.9082372552763511 Test Loss 5.911050040762759
Epoch 299 Train Loss 1.4226051982414765 Test Loss 6.3801427088020315
Epoch 399 Train Loss 1.2139898677038876 Test Loss 6.856175713242716
Epoch 499 Train Loss 0.9311827185862703 Test Loss 7.154155988761241
Epoch 599 Train Loss 0.7926346814811845 Test Loss 7.406866094158965
Epoch 699 Train Loss 0.6328064819476724 Test Loss 7.656983641108373
Epoch 799 Train Loss 0.5386195533025419 Test Loss 7.85244080546728
Epoch 899 Train Loss 0.483151151929906 Test Loss 7.980981875429242
Epoch 999 Train Loss 0.45889368534966607 Test Loss 8.104028788801582
Epoch 1099 Train Loss 0.3857090213262157 Test Loss 8.087395736664488
Epoch 1199 Train Loss 0.4052757432169338 Test Loss 8.24701438749579
Epoch 1299 Train Loss 0.3176503247031622 Test Loss 8.337920421518396
Epoch 1399 Train Loss 0.28707986447098593 Test Loss 8.297116156439191
Epoch 1499 Train Loss 0.26994059565921585 Test

In [24]:
print(model2.cache)
model2.state_dict

{'attn_pattern_BitAttention': tensor([[[[0.0000e+00, 0.0000e+00, 1.0000e+00],
          [7.1184e-07, 1.0000e+00, 1.2295e-33],
          [1.0000e+00, 0.0000e+00, 0.0000e+00]],

         [[1.1351e-43, 1.2782e-29, 1.0000e+00],
          [1.0000e+00, 1.1390e-38, 5.5719e-07],
          [1.0000e+00, 0.0000e+00, 0.0000e+00]],

         [[1.0000e+00, 0.0000e+00, 0.0000e+00],
          [9.2810e-30, 1.0000e+00, 0.0000e+00],
          [0.0000e+00, 1.0000e+00, 0.0000e+00]],

         [[1.0000e+00, 0.0000e+00, 2.8614e-37],
          [6.1435e-19, 2.2685e-14, 1.0000e+00],
          [0.0000e+00, 1.0000e+00, 0.0000e+00]]],


        [[[1.4552e-30, 1.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 1.0000e+00],
          [0.0000e+00, 1.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 1.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 1.0000e+00],
          [2.2934e-36, 1.0000e+00, 0.0000e+00]],

         [[1.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 1.0000e+0

<bound method Module.state_dict of BitNetTransformer(
  (emb): Embedding(114, 128)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): BitAttention(
        (to_qkv): ModuleList(
          (0-2): 3 x BitLinear(in_features=128, out_features=128, bias=False)
        )
        (to_out): BitLinear(in_features=128, out_features=128, bias=False)
      )
    )
    (ffn_layers): ModuleList(
      (0): BitFeedForward(
        (layer1): BitLinear(in_features=128, out_features=512, bias=False)
        (activation): ReLU()
        (layer2): BitLinear(in_features=512, out_features=128, bias=False)
      )
    )
  )
  (to_logits): Sequential(
    (0): RMSNorm()
    (1): Linear(in_features=128, out_features=113, bias=False)
  )
)>

In [25]:
model2.emb
W_E = model2.emb.weight
W_E = W_E[:-1]
print(W_E.shape)

torch.Size([113, 128])


In [26]:
cache = model2.cache
print(cache.keys())

dict_keys(['attn_pattern_BitAttention', 'pre_activation_BitLinear', 'post_activation_BitLinear'])


In [27]:
torch.save(
    {
        "model":model.state_dict(),
        "config": model.cfg,
        "checkpoints": model_checkpoints,
        "checkpoint_epochs": checkpoint_epochs,
        "test_losses": test_losses,
        "train_losses": train_losses,
        "train_indices": train_indices,
        "test_indices": test_indices,
    },
    PTH_LOCATION)

In [28]:
if not TRAIN_MODEL:
    cached_data = torch.load(PTH_LOCATION)
    model.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data["checkpoints"]
    checkpoint_epochs = cached_data["checkpoint_epochs"]
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']
    train_indices = cached_data["train_indices"]
    test_indices = cached_data["test_indices"]

## Show Model Training Statistics, Check that it groks!

In [29]:
%pip install git+https://github.com/neelnanda-io/neel-plotly.git
from neel_plotly.plot import line


os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.



Collecting git+https://github.com/neelnanda-io/neel-plotly.git
  Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-oylf7xy9
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-oylf7xy9
  Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc24b26f8dec991908479d7445dae496b3430b7
  Preparing metadata (setup.py) ... [?25ldone
Note: you may need to restart the kernel to use updated packages.


In [30]:

line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Training Curve for Modular Addition", line_labels=['train', 'test'], toggle_x=True, toggle_y=True)





In [31]:
# print out all parameters of model2
for name, param in model2.named_parameters():
    print(name, param.shape)

emb.weight torch.Size([114, 128])
transformer.layers.0.to_qkv.0.weight torch.Size([128, 128])
transformer.layers.0.to_qkv.1.weight torch.Size([128, 128])
transformer.layers.0.to_qkv.2.weight torch.Size([128, 128])
transformer.layers.0.to_out.weight torch.Size([128, 128])
transformer.ffn_layers.0.layer1.weight torch.Size([512, 128])
transformer.ffn_layers.0.layer2.weight torch.Size([128, 512])
to_logits.0.gamma torch.Size([128])
to_logits.1.weight torch.Size([113, 128])


In [32]:
print(model2.transformer.layers[0].to_qkv[2].weight.size())
print(model2.transformer.layers[0].to_out.weight.size())
print(model2.transformer.ffn_layers[0].layer1.weight.size())

torch.Size([128, 128])
torch.Size([128, 128])
torch.Size([512, 128])


In [33]:
W_V = model2.transformer.layers[0].to_qkv[2].weight
W_O = model2.transformer.layers[0].to_out.weight
W_mlp_in = model2.transformer.ffn_layers[0].layer1.weight
W_mlp_in

Parameter containing:
tensor([[ 3.5047e-01,  2.2517e-01,  2.1924e-01,  ...,  4.8396e-01,
         -3.4303e-02, -2.6806e-01],
        [-1.0415e-01, -1.7365e-04, -3.1212e-01,  ..., -6.9204e-01,
          3.3560e-01,  4.3170e-01],
        [ 1.3337e-01,  1.8390e-01,  1.9944e-01,  ...,  8.9012e-01,
          1.4895e-01, -1.0316e+00],
        ...,
        [ 1.0652e-01,  2.9674e-01, -4.0166e-01,  ..., -6.4110e-02,
         -9.4392e-02, -5.2126e-01],
        [-8.5568e-02,  3.7147e-03,  4.6724e-01,  ..., -8.6406e-01,
         -5.9624e-02, -1.1661e+00],
        [ 5.0394e-01, -1.4464e-01, -6.2496e-01,  ...,  5.0246e-02,
          5.7082e-03, -2.4015e-01]], device='cuda:0', requires_grad=True)

In [34]:
torch.sign(W_mlp_in) @ torch.sign(W_O) @ torch.sign(W_V)

tensor([[ 168.,   76.,   72.,  ..., -144., -112.,  428.],
        [  -8.,   24., -176.,  ..., -252., -364.,  -20.],
        [-184.,   84., -120.,  ..., -344.,  168.,  -20.],
        ...,
        [ -96.,  212.,  304.,  ...,  -80.,   80.,   68.],
        [ -12., -208.,  -20.,  ...,   44.,  428.,  304.],
        [   8.,  -96.,   32.,  ...,  276.,  100.,   -4.]], device='cuda:0',
       grad_fn=<MmBackward0>)

In [76]:
import plotly.express as px

# Convert the tensor to numpy and then to int for visualization
W_V_sign_np = (torch.sign(W_mlp_in)).detach().cpu().numpy().astype(int)

# Create a black and white color scale
colorscale = [[0, 'white'], [1, 'black']]

# Create the image
fig = px.imshow(W_V_sign_np, color_continuous_scale=colorscale, range_color=[-1,1])
# add a title
fig.update_layout(title="Binarized W_V visualized")

# Show the image
fig.show()

# Analysing the Model

## Standard Things to Try

In [36]:
original_logits= model2(dataset)
print(original_logits.numel())
cache = model2.cache
print(cache)

4328691
{'attn_pattern_BitAttention': tensor([[[[1.9188e-08, 1.9188e-08, 1.0000e+00],
          [1.9188e-08, 1.9188e-08, 1.0000e+00],
          [5.0000e-01, 5.0000e-01, 0.0000e+00]],

         [[5.0000e-01, 5.0000e-01, 0.0000e+00],
          [5.0000e-01, 5.0000e-01, 0.0000e+00],
          [5.0000e-01, 5.0000e-01, 0.0000e+00]],

         [[5.0000e-01, 5.0000e-01, 2.8834e-12],
          [5.0000e-01, 5.0000e-01, 2.8834e-12],
          [5.0000e-01, 5.0000e-01, 0.0000e+00]],

         [[5.0000e-01, 5.0000e-01, 8.6014e-33],
          [5.0000e-01, 5.0000e-01, 8.6014e-33],
          [0.0000e+00, 0.0000e+00, 1.0000e+00]]],


        [[[0.0000e+00, 1.0000e+00, 0.0000e+00],
          [0.0000e+00, 1.0000e+00, 0.0000e+00],
          [0.0000e+00, 1.0000e+00, 0.0000e+00]],

         [[1.0000e+00, 0.0000e+00, 0.0000e+00],
          [1.0000e+00, 0.0000e+00, 0.0000e+00],
          [1.0000e+00, 0.0000e+00, 0.0000e+00]],

         [[1.9171e-12, 1.0000e+00, 1.1056e-23],
          [1.0000e+00, 1.5776e-27, 3

Get key weight matrices:

In [37]:
W_E = model.embed.W_E[:-1]
print("W_E", W_E.shape)
W_neur = W_E @ model.blocks[0].attn.W_V @ model.blocks[0].attn.W_O @ model.blocks[0].mlp.W_in
print("W_neur", W_neur.shape)
W_logit = model.blocks[0].mlp.W_out @ model.unembed.W_U
print("W_logit", W_logit.shape)

W_E torch.Size([113, 128])
W_neur torch.Size([4, 113, 512])
W_logit torch.Size([512, 113])


In [38]:
original_loss = loss_fn(original_logits, labels).item()
print("Original Loss:", original_loss)

Original Loss: 1.6028283887077162


### Looking at Activations

Helper variable:

In [39]:
neuron_acts = cache["post_activation_BitLinear"][:, -1, :]
neuron_pre_acts = cache["pre_activation_BitLinear"][:, -1, :]
print(neuron_acts.size())

torch.Size([12769, 512])


Get all shapes:

In [40]:
for param_name, param in cache.items():
    print(param_name, param.shape)

attn_pattern_BitAttention torch.Size([12769, 4, 3, 3])
pre_activation_BitLinear torch.Size([12769, 3, 512])
post_activation_BitLinear torch.Size([12769, 3, 512])


In [41]:
imshow(cache["attn_pattern_BitAttention"].mean(dim=0)[:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [42]:
dataset[:4]

tensor([[  0,   0, 113],
        [  0,   1, 113],
        [  0,   2, 113],
        [  0,   3, 113]], device='cuda:0')

In [43]:
cache["attn_pattern_BitAttention"].shape

torch.Size([12769, 4, 3, 3])

In [44]:
imshow(cache["attn_pattern_BitAttention"][:, 0, -1, 0].reshape(p, p), title="Attention for Head 0 from a -> =", xaxis="b", yaxis="a")

In [45]:
imshow(
    einops.rearrange(cache["attn_pattern_BitAttention"][:, :, -1, 0], "(a b) head -> head a b", a=p, b=p),
    title="Attention for Head 0 from a -> =", xaxis="b", yaxis="a", facet_col=0)

Plotting neuron activations

In [46]:
imshow(
    einops.rearrange(neuron_pre_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

### Singular Value Decomposition

In [47]:
W_E.shape

torch.Size([113, 128])

In [48]:
U, S, Vh = torch.svd(W_E)
line(S, title="Singular Values")
imshow(U, title="Principal Components on the Input")

In [49]:
# Control - random Gaussian matrix
U, S, Vh = torch.svd(torch.randn_like(W_E))
line(S, title="Singular Values Random")
imshow(U, title="Principal Components Random")

## Explaining Algorithm

### Analyse the Embedding - It's a Lookup Table!

In [50]:
U, S, Vh = torch.svd(W_E)
line(U[:, :15].T, title="Principal Components of the embedding", xaxis="Input Vocabulary")





In [51]:
fourier_basis = []
fourier_basis_names = []
fourier_basis.append(torch.ones(p))
fourier_basis_names.append("Constant")
for freq in range(1, p//2+2):
    fourier_basis.append(torch.sin(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Sin {freq}")
    fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Cos {freq}")
fourier_basis = torch.stack(fourier_basis, dim=0).cuda()
fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)
imshow(fourier_basis, xaxis="Input", yaxis="Component", y=fourier_basis_names)


In [52]:
line(fourier_basis[:8], xaxis="Input", line_labels=fourier_basis_names[:8], title="First 8 Fourier Components")
line(fourier_basis[25:29], xaxis="Input", line_labels=fourier_basis_names[25:29], title="Middle Fourier Components")









In [53]:
imshow(fourier_basis @ fourier_basis.T, title="All Fourier Vectors are Orthogonal")

### Analyse the Embedding

In [54]:
imshow(fourier_basis @ W_E, yaxis="Fourier Component", xaxis="Residual Stream", y=fourier_basis_names, title="Embedding in Fourier Basis")

In [55]:
import plotly.graph_objects as go

# Compute the norm
norm_values = (fourier_basis @ W_E).norm(dim=-1)

# Convert the tensor to a numpy array for plotting
norm_values_np = norm_values.detach().cpu().numpy()

# Create the plot
fig = go.Figure(data=go.Scatter(y=norm_values_np))

# Set the title and labels
fig.update_layout(title='Norm of Tensor', xaxis_title='Index', yaxis_title='Norm')

# Show the plot
fig.show()

In [56]:
line((fourier_basis @ W_E).norm(dim=-1), xaxis="Fourier Component", x=fourier_basis_names, title="Norms of Embedding in Fourier Basis")

In [57]:
# key_freqs = [17, 25, 32, 47]
key_freq_indices = [9,10,23,24,59,60,65,66,67,68,79,80,97,98]
fourier_embed = fourier_basis @ W_E
key_fourier_embed = fourier_embed[key_freq_indices]
print("key_fourier_embed", key_fourier_embed.shape)
imshow(key_fourier_embed @ key_fourier_embed.T, title="Dot Product of embedding of key Fourier Terms")

key_fourier_embed torch.Size([14, 128])


### Key Frequencies

In [58]:
import neel_plotly as npx
key_cos = [num for num in key_freq_indices if num % 2 == 0]
npx.line(fourier_basis[key_cos], title="Cos of key freqs")





In [59]:
npx.line(fourier_basis[key_cos].mean(0), title="Constructive Interference")

## Analyse Neurons

In [60]:
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

In [61]:
imshow(
    einops.rearrange(neuron_acts[:, 0], "(a b) -> a b", a=p, b=p),
    title="First neuron act", xaxis="b", yaxis="a",)

In [64]:
imshow(fourier_basis @ neuron_acts[:, 0].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transformer of neuron 0", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [65]:
imshow(fourier_basis @ neuron_acts[:, 5].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transformer of neuron 5", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

### Neuron Clusters

In [67]:
fourier_neuron_acts = fourier_basis @ einops.rearrange(neuron_acts, "(a b) neuron -> neuron a b", a=p, b=p) @ fourier_basis.T
# Center these by removing the mean - doesn't matter!
fourier_neuron_acts[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts.shape)

fourier_neuron_acts torch.Size([512, 115, 115])


In [68]:
neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).cuda()
for freq in range(0, p//2):
    for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:
        for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:
            neuron_freq_norm[freq] += fourier_neuron_acts[:, x, y]**2
neuron_freq_norm = neuron_freq_norm / fourier_neuron_acts.pow(2).sum(dim=[-1, -2])[None, :]
imshow(neuron_freq_norm, xaxis="Neuron", yaxis="Freq", y=torch.arange(1, p//2+1), title="Neuron Frac Explained by Freq")

In [69]:
line(neuron_freq_norm.max(dim=0).values.sort().values, xaxis="Neuron", title="Max Neuron Frac Explained over Freqs")