# Mech Interp of Binarized Neural Networks


Question being explored: Recent papers have shown that binary and ternary transformer-based networks, with weights in {-1, 1} or {-1, 0, 1}, can achieve results similar to full-precision networks. Are these networks simply simulating a full-precision network, or are they learning different and possibly more interpretable algorithms due to their discretized nature?

Setup: A 1-layer transformer with all weights binarized except for the embed and unembed. The implementation is based on the BitNet paper, and the code is in the BitNet folder.

# Setup
(No need to read)

In [None]:
TRAIN_MODEL = False

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer-lens
    %pip install circuitsvis

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

In [None]:
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import os
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import plotly.graph_objects as go
import plotly.io as pio


In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Plotting helper functions:

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [None]:
# Define the location to save the model, using a relative path
PTH_LOCATION = "workspace/_scratch/grokking_demo.pth"

# Create the directory if it does not exist
os.makedirs(Path(PTH_LOCATION).parent, exist_ok=True)

# Model Training

## Config

In [None]:
p = 113
frac_train = 0.3

# Optimizer config
lr = 1e-3
wd = 1.
betas = (0.9, 0.98)

num_epochs = 10000
checkpoint_every = 100

DATA_SEED = 598

## Define Task
* Define modular addition
* Define the dataset & labels

Input format:
|a|b|=|

In [None]:
a_vector = einops.repeat(torch.arange(p), "i -> (i j)", j=p)
b_vector = einops.repeat(torch.arange(p), "j -> (i j)", i=p)
equals_vector = einops.repeat(torch.tensor(p), " -> (i j)", i=p, j=p)  # the "=" token uses index p


In [None]:
dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).cuda()
print(dataset[:5])
print(dataset.shape)

In [None]:
labels = (dataset[:, 0] + dataset[:, 1]) % p
print(labels.shape)
print(labels[:5])

Convert this to a train + test set - 30% in the training set

In [None]:
torch.manual_seed(DATA_SEED)
indices = torch.randperm(p*p)
cutoff = int(p*p*frac_train)
train_indices = indices[:cutoff]
test_indices = indices[cutoff:]

train_data = dataset[train_indices]
train_labels = labels[train_indices]
test_data = dataset[test_indices]
test_labels = labels[test_indices]
print(train_data[:5])
print(train_labels[:5])
print(train_data.shape)
print(test_data[:5])
print(test_labels[:5])
print(test_data.shape)

## Define Model

In [None]:

cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 4,
    d_model = 128,
    d_head = 32,
    d_mlp = 512,
    act_fn = "relu",
    normalization_type=None,
    d_vocab=p+1,
    d_vocab_out=p,
    n_ctx=3,
    init_weights=True,
    device="cuda",
    seed = 999,
)

In [None]:
from bitnet import BitNetTransformer

model2 = BitNetTransformer(dim=128, depth=1, heads=4, in_features=p+1, out_features=p, random_seed=999,mode='binary').to("cuda")

Disable the biases, as we don't need them for this task and it makes things easier to interpret.

In [None]:
for name, param in model2.named_parameters():
    if "b_" in name:
        param.requires_grad = False


## Define Optimizer + Loss

In [None]:
optimizer2  = torch.optim.AdamW(model2.parameters(), lr=lr, weight_decay=wd, betas=betas)

In [None]:
def loss_fn(logits, labels):
    if len(logits.shape)==3:
        logits = logits[:, -1]
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()
train_logits = model2(train_data)
print(train_logits.shape)
train_loss = loss_fn(train_logits, train_labels)
print(train_loss)
test_logits = model2(test_data)
test_loss = loss_fn(test_logits, test_labels)
print(test_loss)

In [None]:
print("Uniform loss:")
print(np.log(p))

## Actually Train

In [None]:
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
if True:
    for epoch in tqdm.tqdm(range(num_epochs)):
        train_logits = model2(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        train_losses.append(train_loss.item())

        optimizer2.step()
        optimizer2.zero_grad()

        with torch.inference_mode():
            test_logits = model2(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

        if ((epoch+1)%checkpoint_every)==0:
            # checkpoint_epochs.append(epoch)
            # model_checkpoints.append(copy.deepcopy(model.state_dict()))
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")

In [None]:
model2.state_dict

Above is the architecture of the model used. All the linear layers between the embed and unembed are BitLinear. Also note that an RMS norm is added before the unembedding; this was also found to be necessary for the model to actually train.

In [None]:
model2.emb
W_E = model2.emb.weight
W_E = W_E[:-1]
print(W_E.shape)

In [None]:
cache = model2.cache
print(cache.keys())

## Show Model Training Statistics, Check that it groks!

In [None]:
%pip install git+https://github.com/neelnanda-io/neel-plotly.git
from neel_plotly.plot import line

In [None]:
fig = line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=False, 
     title="Training Curve for Modular Addition", line_labels=['train', 'test'], return_fig=True)


In [None]:
fig.show()
fig.write_image("workspace/_scratch/training_curve.pdf")


We see from the training curve that the model does indeed grok. The train loss does not get as low as that of the full-precision model, likely because of the precision lost to binarization.
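
One way to put a number on the gap between fitting the training set and generalizing is to record the first epoch at which each loss curve drops below a threshold. The sketch below does this on made-up curves; the helper `first_epoch_below`, the threshold of 0.1, and the toy values are all assumptions for illustration and not results from this notebook.

```python
def first_epoch_below(losses, threshold):
    """Return the index of the first loss under `threshold`, or None if there is none."""
    for epoch, loss in enumerate(losses):
        if threshold > loss:
            return epoch
    return None

# Synthetic curves (illustration only): train loss falls quickly,
# test loss stays high for a while and then drops much later.
toy_train = [4.7, 2.0, 0.5, 0.05, 0.04, 0.04, 0.03, 0.03]
toy_test = [4.7, 5.5, 6.0, 6.2, 5.0, 1.0, 0.08, 0.05]

fit_epoch = first_epoch_below(toy_train, 0.1)
generalize_epoch = first_epoch_below(toy_test, 0.1)
print(fit_epoch, generalize_epoch)
```

With the real run, the same helper could be called on `train_losses` and `test_losses`; a large gap between the two epochs is the signature of grokking.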

In [None]:
# print out all parameters of model2
for name, param in model2.named_parameters():
    print(name, param.shape)

In [None]:
print(model2.transformer.layers[0].to_qkv[2].weight.size())
print(model2.transformer.layers[0].to_out.weight.size())
print(model2.transformer.ffn_layers[0].layer1.weight.size())

In [None]:
W_V = model2.transformer.layers[0].to_qkv[2].weight
W_K = model2.transformer.layers[0].to_qkv[1].weight
W_Q = model2.transformer.layers[0].to_qkv[0].weight
W_O = model2.transformer.layers[0].to_out.weight
W_mlp_in = model2.transformer.ffn_layers[0].layer1.weight
W_mlp_out = model2.transformer.ffn_layers[0].layer2.weight

In [None]:
torch.sign(W_mlp_in) @ torch.sign(W_O) @ torch.sign(W_V)

In [None]:
import plotly.express as px
import plotly.subplots as sp

W_V = model2.transformer.layers[0].to_qkv[2].weight
W_K = model2.transformer.layers[0].to_qkv[1].weight
W_Q = model2.transformer.layers[0].to_qkv[0].weight
W_O = model2.transformer.layers[0].to_out.weight
W_mlp_in = model2.transformer.ffn_layers[0].layer1.weight
W_mlp_out = model2.transformer.ffn_layers[0].layer2.weight

def visualize_weights(W, title):
    # Convert the tensor to numpy and then to int for visualization
    W_np = (torch.sign(W)).detach().cpu().numpy().astype(int)

    # Create the image
    fig = px.imshow(W_np, color_continuous_scale=["white", "black"], range_color=[-1,1])

    # add a title
    fig.update_layout(title=title)

    # Show the image
    return fig

# Visualize each weight matrix
visualize_weights(W_V, "Binarized W_V visualized").show()
visualize_weights(W_K, "Binarized W_K visualized").show()
visualize_weights(W_Q, "Binarized W_Q visualized").show()
visualize_weights(W_O, "Binarized W_O visualized").show()
visualize_weights(W_mlp_in, "Binarized W_mlp_in visualized").show()
visualize_weights(W_mlp_out, "Binarized W_mlp_out visualized").show()

We see that no clear patterns can be discerned when we visualize the binarized weight matrices.

# Analysing the Model

## Standard Things to Try

In [None]:
original_logits= model2(dataset)
print(original_logits.numel())
cache = model2.cache
print(cache)

Get key weight matrices:

In [None]:
W_E = model2.emb.weight

In [None]:
original_loss = loss_fn(original_logits, labels).item()
print("Original Loss:", original_loss)

### Looking at Activations

Helper variable:

In [None]:
neuron_acts = cache["post_activation_BitLinear"][:, -1, :]
neuron_pre_acts = cache["pre_activation_BitLinear"][:, -1, :]
print(neuron_acts.size())

Get all shapes:

In [None]:
for param_name, param in cache.items():
    print(param_name, param.shape)

In [None]:
imshow(cache["attn_pattern_BitAttention"].mean(dim=0)[:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [None]:
dataset[:4]

In [None]:
cache["attn_pattern_BitAttention"].shape

In [None]:
import plotly.graph_objects as go
import plotly.io as pio

# Get the data
data = cache["attn_pattern_BitAttention"][:, 0, -1, 0].reshape(p, p)

# Create the figure
fig = go.Figure(data=go.Heatmap(z=data.cpu(), colorscale='Blues'))

# Set the title, labels, and size
fig.update_layout(title='Attention Score for Head 0', xaxis_title='b', yaxis_title='a', autosize=False, width=500, height=500)

# Show the figure
fig.show()

# Save the figure
pio.write_image(fig, 'workspace/_scratch/attention.pdf')


In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Create subplot structure
fig = make_subplots(rows=1, cols=4, subplot_titles=[f'Head {i}' for i in range(4)])

# Create individual plots
for i in range(4):
    img = einops.rearrange(cache["attn_pattern_BitAttention"][:, i, -1, 0], "(a b) -> a b", a=p, b=p).cpu()
    fig.add_trace(go.Heatmap(z=img, colorscale='Blues', showscale=False), row=1, col=i+1)

# Update layout
fig.update_layout(height=400, width=800, title_text="Attention for Each Head from a -> =")
fig.show()




Plotting neuron activations

In [None]:
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

### Singular Value Decomposition

In [None]:
W_E.shape
# take off the last row
W_E = W_E[:-1]
W_E.shape

In [None]:
U, S, Vh = torch.svd(W_E)
line(S, title="Singular Values")
imshow(U, title="Principal Components on the Input")

One difference compared to the full precision grokked model is that there seem to be more components.
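
To make "more components" less of an eyeball judgment, one option is to count how many singular values are needed to capture a fixed share of the squared spectrum. The sketch below is a minimal NumPy version run on synthetic matrices; the helper `n_components_for`, the 90% cutoff, and the toy matrices are assumptions, not part of the original analysis.

```python
import numpy as np

def n_components_for(singular_values, frac=0.9):
    """Smallest k such that the top-k squared singular values reach `frac` of the total."""
    energy = np.sort(np.asarray(singular_values, dtype=float))[::-1] ** 2
    cumulative = np.cumsum(energy) / energy.sum()
    return int(np.searchsorted(cumulative, frac) + 1)

rng = np.random.default_rng(0)
# Synthetic stand-ins (not the notebook's W_E):
# a rank-4 matrix with a little noise, and a pure Gaussian matrix as control.
low_rank = rng.standard_normal((113, 4)) @ rng.standard_normal((4, 128))
low_rank += 0.01 * rng.standard_normal((113, 128))
noise = rng.standard_normal((113, 128))

k_low = n_components_for(np.linalg.svd(low_rank, compute_uv=False))
k_noise = n_components_for(np.linalg.svd(noise, compute_uv=False))
print(k_low, k_noise)
```

Applied to the singular values `S` of `W_E` for the binarized and full precision models, the two counts would give a direct comparison.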

In [None]:
# Control - random Gaussian matrix
U, S, Vh = torch.svd(torch.randn_like(W_E))
line(S, title="Singular Values Random")
imshow(U, title="Principal Components Random")

## Explaining Algorithm

### Analyse the Embedding - It's a Lookup Table!

In [None]:
U, S, Vh = torch.svd(W_E)
line(U[:, :15].T, title="Principal Components of the embedding", xaxis="Input Vocabulary")

In [None]:
fourier_basis = []
fourier_basis_names = []
fourier_basis.append(torch.ones(p))
fourier_basis_names.append("Constant")
for freq in range(1, p//2+1):  # frequencies 1..p//2, giving p basis vectors in total
    fourier_basis.append(torch.sin(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Sin {freq}")
    fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Cos {freq}")
fourier_basis = torch.stack(fourier_basis, dim=0).cuda()
fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)
imshow(fourier_basis, xaxis="Input", yaxis="Component", y=fourier_basis_names)


In [None]:
line(fourier_basis[:8], xaxis="Input", line_labels=fourier_basis_names[:8], title="First 8 Fourier Components")
line(fourier_basis[25:29], xaxis="Input", line_labels=fourier_basis_names[25:29], title="Middle Fourier Components")

In [None]:
imshow(fourier_basis @ fourier_basis.T, title="All Fourier Vectors are Orthogonal")

### Analyse the Embedding

In [None]:
imshow(fourier_basis @ W_E, yaxis="Fourier Component", xaxis="Residual Stream", y=fourier_basis_names, title="Embedding in Fourier Basis")

Once again this is a lot less "clean" than the full precision model, but it still seems to be fundamentally the same thing.
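
As a reference point for what a perfectly "clean" embedding looks like in this basis, the sketch below builds the same kind of Fourier basis in NumPy and projects a synthetic embedding whose columns are pure frequency-14 waves. The function `make_fourier_basis`, the matrix `toy_W_E`, and the choice of frequency 14 are assumptions made for illustration.

```python
import numpy as np

def make_fourier_basis(p):
    """Rows: constant, then (sin k, cos k) for k = 1..p//2, each scaled to unit norm."""
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, p // 2 + 1):
        rows.append(np.sin(2 * np.pi * k * x / p))
        rows.append(np.cos(2 * np.pi * k * x / p))
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)

p = 113
basis = make_fourier_basis(p)  # shape (p, p) when p is odd

# Synthetic embedding: every column is a random mix of sin and cos at frequency 14.
rng = np.random.default_rng(0)
x = np.arange(p)
waves = np.stack([np.sin(2 * np.pi * 14 * x / p), np.cos(2 * np.pi * 14 * x / p)], axis=1)
toy_W_E = waves @ rng.standard_normal((2, 8))  # shape (p, 8)

# Norm of each Fourier component of the embedding, as in the notebook's plots.
norms = np.linalg.norm(basis @ toy_W_E, axis=1)
top_two = sorted(np.argsort(norms)[-2:].tolist())
print(basis.shape, top_two)
```

In this idealized case all of the norm sits in the rows for Sin 14 and Cos 14 (indices 27 and 28); the binarized model's embedding spreads some norm over other rows as well.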

In [None]:
# Compute the norm
norm_values = (fourier_basis @ W_E).norm(dim=-1)

# Convert the tensor to a numpy array for plotting
norm_values_np = norm_values.detach().cpu().numpy()

# Create the plot
fig = go.Figure(data=go.Bar(y=norm_values_np))

# Set the title and labels
fig.update_layout(title='Norm of Tensor', xaxis_title='Index', yaxis_title='Norm')

# Show the plot
fig.show()

In [None]:
import numpy as np

# Compute the norm
norm_values = (fourier_basis @ W_E).norm(dim=-1)

# Convert the tensor to a numpy array for plotting
norm_values_np = norm_values.detach().cpu().numpy()

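# Split indices into sin and cos rows of fourier_basis:
# row 0 is the constant, odd rows are sin, even rows are cos
indices = np.arange(len(norm_values_np))
sin_indices = indices[indices % 2 == 1]
cos_indices = indices[indices % 2 == 0]  # includes the constant term at index 0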

# Create the plot with alternating colors
fig = go.Figure(data=[
    go.Bar(x=sin_indices, y=norm_values_np[sin_indices], name='sin', marker_color='blue'),
    go.Bar(x=cos_indices, y=norm_values_np[cos_indices], name='cos', marker_color='red')
])

# Set the title and labels
fig.update_layout(
    title='Fourier Components of Embedding Matrix',
    xaxis_title='Frequency',
    yaxis_title='Norm',
    barmode='group',
    xaxis=dict(
        tickmode='array',
        tickvals=list(range(0, len(norm_values_np), 20)),  # One tick every 20 basis rows (two rows per frequency)
        ticktext=list(range(0, len(norm_values_np) // 2, 10))  # Label each tick with its frequency, in steps of 10
    )
)

# Show the plot
fig.show()
pio.write_image(fig, 'workspace/_scratch/fourier_emb.pdf')

In [None]:
line((fourier_basis @ W_E).norm(dim=-1), xaxis="Fourier Component", yaxis="Norm of Fourier Component", x=fourier_basis_names, title="Norms of Embedding in Fourier Basis")

In [None]:
# key_freqs = [17, 25, 32, 47]
fourier_embed = fourier_basis @ W_E
# key freq indices are those for which the fourier_embed are higher than 0.1
key_freq_indices = (fourier_embed.norm(dim=-1) > 0.1).nonzero().squeeze().tolist()
print(key_freq_indices)
key_fourier_embed = fourier_embed[key_freq_indices]
print("key_fourier_embed", key_fourier_embed.shape)
imshow(key_fourier_embed @ key_fourier_embed.T, title="Dot Product of embedding of key Fourier Terms")

One difference this graph shows compared to the full precision model is that the terms are not as orthogonal. I hypothesize that this is due to the lack of precision of binarization.
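
A single number can make the "not as orthogonal" comparison easier to track across models. The sketch below computes the mean absolute cosine similarity between distinct rows; the helper `mean_abs_offdiag_cosine` and the toy inputs are assumptions for illustration, not part of the original notebook.

```python
import numpy as np

def mean_abs_offdiag_cosine(rows):
    """Mean |cosine similarity| between distinct rows; 0 means perfectly orthogonal."""
    unit = rows / np.linalg.norm(rows, axis=1, keepdims=True)
    cos = unit @ unit.T
    off_diag = cos[~np.eye(len(rows), dtype=bool)]
    return float(np.abs(off_diag).mean())

rng = np.random.default_rng(0)
orthogonal_rows = np.eye(6, 128)  # six exactly orthogonal directions
noisy_rows = orthogonal_rows + 0.3 * rng.standard_normal((6, 128))  # toy perturbation

score_clean = mean_abs_offdiag_cosine(orthogonal_rows)
score_noisy = mean_abs_offdiag_cosine(noisy_rows)
print(score_clean, score_noisy)
```

The same function could be applied to `key_fourier_embed` after converting it with `.detach().cpu().numpy()`.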

### Key Frequencies

In [None]:
import neel_plotly as npx
key_cos = [num for num in key_freq_indices if num % 2 == 0]
npx.line(fourier_basis[key_cos], title="Cos of key freqs")

In [None]:
npx.line(fourier_basis[key_cos].mean(0), title="Constructive Interference")

## Analyse Neurons

In [None]:
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

In [None]:
# Get the data
data = einops.rearrange(neuron_acts[:, 0], "(a b) -> a b", a=p, b=p)

# Create the figure
fig = go.Figure(data=go.Heatmap(z=data.cpu(), colorscale='Blues'))

# Set the title, labels, and size
fig.update_layout(title='First neuron act', xaxis_title='b', yaxis_title='a', autosize=False, width=500, height=500)

# Show the figure
fig.show()

pio.write_image(fig, 'workspace/_scratch/activation.pdf')

In [None]:
imshow(fourier_basis @ neuron_acts[:, 0].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 0", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

In [None]:
imshow(fourier_basis @ neuron_acts[:, 4].reshape(p, p) @ fourier_basis.T, title="2D Fourier Transform of neuron 4", xaxis="b", yaxis="a", x=fourier_basis_names, y=fourier_basis_names)

### Neuron Clusters

In [None]:
fourier_neuron_acts = fourier_basis @ einops.rearrange(neuron_acts, "(a b) neuron -> neuron a b", a=p, b=p) @ fourier_basis.T
# Center these by removing the mean - doesn't matter!
fourier_neuron_acts[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts.shape)

In [None]:
neuron_freq_norm = torch.zeros(p//2, cfg.d_mlp).cuda()
for freq in range(0, p//2):
    for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:
        for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:
            neuron_freq_norm[freq] += fourier_neuron_acts[:, x, y]**2
neuron_freq_norm = neuron_freq_norm / fourier_neuron_acts.pow(2).sum(dim=[-1, -2])[None, :]
imshow(neuron_freq_norm, xaxis="Neuron", yaxis="Freq", y=torch.arange(1, p//2+1), title="Neuron Frac Explained by Freq")
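
The quantity plotted above is, for each neuron, the share of its mean-removed 2D Fourier energy that falls on the constant, sin and cos terms of one frequency. A minimal NumPy sketch of that computation on one synthetic neuron follows; `frac_explained`, `toy_neuron`, and the choice of frequency 5 are assumptions for illustration.

```python
import numpy as np

def make_fourier_basis(p):
    """Rows: constant, then (sin k, cos k) for k = 1..p//2, each with unit norm."""
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, p // 2 + 1):
        rows += [np.sin(2 * np.pi * k * x / p), np.cos(2 * np.pi * k * x / p)]
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)

def frac_explained(acts, basis, freq):
    """Share of the energy of a (p, p) activation map carried by one frequency."""
    coeffs = basis @ acts @ basis.T   # 2D Fourier coefficients, indexed (a, b)
    coeffs[0, 0] = 0.0                # drop the mean, as in the notebook
    idx = [0, 2 * freq - 1, 2 * freq]  # constant, sin freq, cos freq
    block = coeffs[np.ix_(idx, idx)]
    return float((block ** 2).sum() / (coeffs ** 2).sum())

p = 113
basis = make_fourier_basis(p)
a = np.arange(p)[:, None]
b = np.arange(p)[None, :]
# Synthetic neuron: ReLU of a single-frequency wave in (a + b).
toy_neuron = np.maximum(0.0, np.cos(2 * np.pi * 5 * (a + b) / p))

frac_at_5 = frac_explained(toy_neuron, basis, 5)
frac_at_7 = frac_explained(toy_neuron, basis, 7)
print(round(frac_at_5, 2), round(frac_at_7, 4))
```

Even for this idealized neuron the fraction at its own frequency stays below 1, because the ReLU adds higher harmonics; that is a useful baseline when reading the per-neuron maxima.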

In [None]:
import plotly.graph_objects as go

# Calculate the maximum values along the 0th dimension
max_values = neuron_freq_norm.max(0).values

# Create an array of indices
indices = list(range(len(max_values)))

# Create the bar graph
fig = go.Figure(data=[go.Bar(x=indices, y=max_values.cpu().numpy())])

# Add labels and title
fig.update_layout(title='Max Values of neuron_freq_norm', xaxis_title='Index', yaxis_title='Max Value')

# Display the plot
fig.show()

# Create the histogram
fig = go.Figure(data=go.Histogram(x=max_values.cpu(), nbinsx=10, histnorm=''))

# Set the bin size
fig.update_traces(xbins=dict(start=0, end=1, size=0.05))

# Set the title and labels
fig.update_layout(title='FVE by degree 2 polynomials (Binary)', xaxis_title='Fraction of Variance Explained', yaxis_title='Number of Neurons')

# Show the figure
fig.show()
pio.write_image(fig, 'workspace/_scratch/fve.pdf')

In [None]:
line(neuron_freq_norm.max(dim=0).values.sort().values, xaxis="Neuron", title="Max Neuron Frac Explained over Freqs")

In [None]:
neuron_freq_norm.shape

In [None]:
import plotly.graph_objects as go

# Get the data
neuron_acts_square = einops.rearrange(neuron_acts, "(a b) neur -> a b neur", a=p, b=p).clone()
neuron_acts_square -= einops.reduce(neuron_acts_square, "a b neur -> 1 1 neur", "mean")
neuron_acts_square_fourier = einsum("a b neur, fa a, fb b -> fa fb neur", neuron_acts_square, fourier_basis, fourier_basis)

# Create the data for the heatmap
data = neuron_acts_square_fourier.norm(dim=-1)

# Create the figure
fig = go.Figure(data=go.Heatmap(z=data.cpu(), colorscale='Blues', x=fourier_basis_names, y=fourier_basis_names))

# Set the title and labels
fig.update_layout(title={
        'text': "Norms of neuron activations<br>by Fourier Component",
        'y':0.9,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': dict(
            size=15
        )
    }, xaxis_title='Fourier Component b', yaxis_title='Fourier Component a',autosize=False, width=500, height=500)

# Show the figure
fig.show()

pio.write_image(fig, 'workspace/_scratch/neuron_acts_fourier.pdf')

### Summary of Results

Overall, all of the graphs I've generated seem to align with the graphs of the full precision models from the original reverse engineering modular addition code from https://youtu.be/o0FppeD_xXQ?si=ObA2aISAUQI_H2GC

While this investigation is not exactly thorough, I found no evidence to support my initial hypothesis that binarized transformers learn a more discretized and more interpretable representation. Instead, all of the evidence suggests that the binarized setup learns an algorithm that is mostly the same as the one learned by the full precision model.

From this preliminary investigation, I further hypothesize that this result arises because the embed and unembed layers at the start and end of the model are not binarized, so they are still free to learn the Fourier transform, which is the most important part of the algorithm. Furthermore, I think that in general binarized networks will end up learning approximations of full precision networks. This is because the optimization techniques used, such as the straight-through estimator of the gradient used by BitNet, treat the binarization mechanism as a continuous function to be optimized over.
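
To illustrate the straight-through estimator mentioned above, here is a minimal PyTorch sketch, assuming plain sign binarization. It is not the code from the BitNet folder, and the `binarize_ste` helper is hypothetical; the BitLinear layer described in the BitNet paper has further pieces that are omitted here.

```python
import torch

def binarize_ste(w):
    """Forward: sign of w in {-1, +1}. Backward: gradient flows as if this were the identity."""
    w_bin = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))
    # The detached term carries the value change but contributes no gradient.
    return w + (w_bin - w).detach()

torch.manual_seed(0)
w = torch.randn(4, 3, requires_grad=True)  # latent full precision weights
x = torch.randn(5, 3)

w_q = binarize_ste(w)
y = x @ w_q.T            # the forward pass only sees (numerically) +/-1 weights
y.sum().backward()

# The gradient reaching the latent weights equals the gradient with respect to
# the binarized weights: the sum of the inputs over the batch, for every row.
expected_grad = x.sum(0).expand(4, 3)
print(torch.allclose(w_q.detach().abs(), torch.ones(4, 3)))
print(torch.allclose(w.grad, expected_grad))
```

Because the optimizer updates the continuous latent weights with these pass-through gradients, training behaves much like training a full precision network whose weights are rounded at the last moment, which is consistent with the hypothesis above.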