# Setup


In [1]:
!pip install transformer-lens
!pip install fancy_einsum
!pip install datasets
!pip install git+https://github.com/neelnanda-io/neel-plotly

Collecting transformer-lens
  Downloading transformer_lens-1.11.0-py3-none-any.whl (119 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.1/119.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate>=0.23.0 (from transformer-lens)
  Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer-lens)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=2.7.1 (from transformer-lens)
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.6.0 (from transformer-lens)
  Downloa

In [15]:
import google.colab
IN_COLAB = True
TRAIN_MODEL = True

In [16]:
import plotly.io as pio
pio.renderers.default = "colab"
pio.templates['plotly'].layout.xaxis.title.font.size = 20
pio.templates['plotly'].layout.yaxis.title.font.size = 20
pio.templates['plotly'].layout.title.font.size = 30

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def add_lines(figure):
    figure.add_vline(memorization_end_epoch, line_dash="dash", opacity=0.7)
    figure.add_vline(circuit_formation_end_epoch, line_dash="dash", opacity=0.7)
    figure.add_vline(cleanup_end_epoch, line_dash="dash", opacity=0.7)
    return figure

In [17]:
import torch
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import copy

from IPython.display import HTML
#import wandb

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [18]:
#!wandb login --relogin

# Model Training

## Config

In [19]:
p = 113
frac_train = 0.3

# Optimizer config
lr = 1e-3
wd = 1.
betas = (0.9, 0.98)

num_epochs = 25000
checkpoint_every = 100

DATA_SEED = 598

## Define Task
(a^m + b^n) mod 113

In [20]:
m= 4
n= 4
#wandb_name = "m4_n4"

Input format:
|a|b|=|

In [21]:
a_vector = einops.repeat(torch.arange(p), "i -> (i j)", j=p)
b_vector = einops.repeat(torch.arange(p), "j -> (i j)", i=p)
equals_vector = einops.repeat(torch.tensor(113), " -> (i j)", i=p, j=p)

dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1)
labels = (dataset[:, 0]**m + dataset[:, 1]**n) % p
labels = labels.to('cuda')

Convert this to a train + test set - 30% in the training set

In [22]:
torch.manual_seed(DATA_SEED)
indices = torch.randperm(p*p)
cutoff = int(p*p*frac_train)
train_indices = indices[:cutoff]
test_indices = indices[cutoff:]

train_data = dataset[train_indices]
train_labels = labels[train_indices]
test_data = dataset[test_indices]
test_labels = labels[test_indices]

train_data = train_data.to('cuda')
train_labels = train_labels.to('cuda')
test_data = test_data.to('cuda')
test_labels = test_labels.to('cuda')

print(train_data[:5])
print(train_labels[:5])
print(train_data.shape)
print(test_data[:5])
print(test_labels[:5])
print(test_data.shape)

tensor([[ 21,  31, 113],
        [ 30,  98, 113],
        [ 47,  10, 113],
        [ 86,  21, 113],
        [ 99,  83, 113]], device='cuda:0')
tensor([93, 17, 58, 10, 12], device='cuda:0')
torch.Size([3830, 3])
tensor([[ 43,  40, 113],
        [ 31,  42, 113],
        [ 39,  63, 113],
        [ 35,  61, 113],
        [112, 102, 113]], device='cuda:0')
tensor([ 84, 100,  75,  49,  65], device='cuda:0')
torch.Size([8939, 3])


## Define Model

In [23]:
cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 4,
    d_model = 128,
    d_head = 32,
    d_mlp = 512,
    act_fn = "relu",
    normalization_type=None,
    d_vocab=p+1,
    d_vocab_out=p,
    n_ctx=3,
    init_weights=True,
    device="cuda",
    n_devices=1,
    seed = 999,
)

model = HookedTransformer(cfg)

Disable the biases, as we don't need them for this task and it makes things easier to interpret.

In [24]:
for name, param in model.named_parameters():
    if "b_" in name:
        param.requires_grad = False

## Define Optimizer + Loss

In [25]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, betas=betas)

In [26]:
def loss_fn(logits, labels):
    if len(logits.shape)==3:
        logits = logits[:, -1]
    logits = logits.to(torch.float64).to('cuda:0')
    log_probs = logits.log_softmax(dim=-1).to('cuda:0')
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()
train_logits = model(train_data).to('cuda:0')
train_labels = train_labels.to('cuda:0')
train_loss = loss_fn(train_logits, train_labels)
print(train_loss)
test_logits = model(test_data)
test_loss = loss_fn(test_logits, test_labels)
print(test_loss)
print("Uniform loss:")
print(np.log(p))

tensor(4.7362, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(4.7318, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
Uniform loss:
4.727387818712341


## Training with full batch

In [27]:
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
if TRAIN_MODEL:

    #wandb.init(project="che-de-moivre", name=wandb_name)

    for epoch in tqdm.tqdm(range(num_epochs)):
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        train_losses.append(train_loss.item())

        optimizer.step()
        optimizer.zero_grad()

        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

        if ((epoch+1)%checkpoint_every)==0:
            checkpoint_epochs.append(epoch)
            model_checkpoints.append(copy.deepcopy(model.state_dict()))
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")
            """
            wandb.log({
            "train_loss": train_loss.item(),
            "test_loss": test_loss.item(),
            "train_logit": train_logits,
            "test_logit": test_logits,
            })

            # weights
            for name, param in model.named_parameters():
                    wandb.log({f"weights_{name}": wandb.Histogram(param.data.cpu().numpy())})
            """

  0%|          | 0/25000 [00:00<?, ?it/s]

Epoch 99 Train Loss 0.09769160211850055 Test Loss 0.1879997152501541
Epoch 199 Train Loss 0.00406000303667101 Test Loss 0.01918378795302881
Epoch 299 Train Loss 0.001248733847137967 Test Loss 0.009900643663197498
Epoch 399 Train Loss 0.00039496988975130573 Test Loss 0.007062510671076046
Epoch 499 Train Loss 0.0001286241916246885 Test Loss 0.005821951513291045
Epoch 599 Train Loss 4.254310859837514e-05 Test Loss 0.0048554604098111085
Epoch 699 Train Loss 1.4272511328052307e-05 Test Loss 0.004061594150196693
Epoch 799 Train Loss 4.943566610062293e-06 Test Loss 0.003295708387666033
Epoch 899 Train Loss 1.7916612938821952e-06 Test Loss 0.0027047591947078173
Epoch 999 Train Loss 7.033562041651425e-07 Test Loss 0.002044240020345832
Epoch 1099 Train Loss 3.160014620573091e-07 Test Loss 0.0015713373384511466
Epoch 1199 Train Loss 1.7199437978534486e-07 Test Loss 0.0012317761695687835
Epoch 1299 Train Loss 1.1726094471721117e-07 Test Loss 0.0010421603932240647
Epoch 1399 Train Loss 9.7951768833

## Check if model groks

In [28]:
from neel_plotly.plot import line
line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=False, title="Training Curve for Modular Addition", line_labels=['train loss', 'test loss'], toggle_x=True, toggle_y=True)

# Analysing the Model

Get key weight matrices:

In [29]:
W_E = model.embed.W_E[:-1]
print("W_E", W_E.shape)
W_neur = W_E @ model.blocks[0].attn.W_V @ model.blocks[0].attn.W_O @ model.blocks[0].mlp.W_in
print("W_neur", W_neur.shape)
W_logit = model.blocks[0].mlp.W_out.to('cuda:0') @ model.unembed.W_U.to('cuda:0')
print("W_logit", W_logit.shape)

W_E torch.Size([113, 128])
W_neur torch.Size([4, 113, 512])
W_logit torch.Size([512, 113])


### Looking at Activations

Helper variable:

In [30]:
original_logits, cache = model.run_with_cache(dataset)
pattern_a = cache["pattern", 0, "attn"][:, :, -1, 0]
pattern_b = cache["pattern", 0, "attn"][:, :, -1, 1]
neuron_acts = cache["post", 0, "mlp"][:, -1, :]
neuron_pre_acts = cache["pre", 0, "mlp"][:, -1, :]

Model architecture:

In [31]:
for param_name, param in cache.items():
    print(param_name, param.shape)

hook_embed torch.Size([12769, 3, 128])
hook_pos_embed torch.Size([12769, 3, 128])
blocks.0.hook_resid_pre torch.Size([12769, 3, 128])
blocks.0.attn.hook_q torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_k torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_v torch.Size([12769, 3, 4, 32])
blocks.0.attn.hook_attn_scores torch.Size([12769, 4, 3, 3])
blocks.0.attn.hook_pattern torch.Size([12769, 4, 3, 3])
blocks.0.attn.hook_z torch.Size([12769, 3, 4, 32])
blocks.0.hook_attn_out torch.Size([12769, 3, 128])
blocks.0.hook_resid_mid torch.Size([12769, 3, 128])
blocks.0.mlp.hook_pre torch.Size([12769, 3, 512])
blocks.0.mlp.hook_post torch.Size([12769, 3, 512])
blocks.0.hook_mlp_out torch.Size([12769, 3, 128])
blocks.0.hook_resid_post torch.Size([12769, 3, 128])


In [32]:
imshow(cache["pattern", 0].mean(dim=0)[:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [33]:
imshow(cache["pattern", 0][5][:, -1, :], title="Average Attention Pattern per Head", xaxis="Source", yaxis="Head", x=['a', 'b', '='])

In [34]:
imshow(
    einops.rearrange(cache["pattern", 0][:, :, -1, 0], "(a b) head -> head a b", a=p, b=p),
    title="Attention for Head 0 from a -> =", xaxis="b", yaxis="a", facet_col=0)

In [35]:
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron activations", xaxis="b", yaxis="a", facet_col=0)

### Singular Value Decomposition

In [36]:
U, S, Vh = torch.svd(W_E)
line(S, title="Singular Values")
imshow(U, title="Principal Components on the Input")

In [37]:
# Control - random Gaussian matrix
U, S, Vh = torch.svd(torch.randn_like(W_E))
line(S, title="Singular Values Random")
imshow(U, title="Principal Components Random")

## Explaining Algorithm

### Analyse the Embedding

In [38]:
U, S, Vh = torch.svd(W_E)
line(U[:, :8].T, title="Principal Components of the embedding", xaxis="Input Vocabulary")

In [39]:
fourier_basis = []
fourier_basis_names = []
fourier_basis.append(torch.ones(p))
fourier_basis_names.append("Constant")
for freq in range(1, p//2+1):
    fourier_basis.append(torch.sin(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Sin {freq}")
    fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))
    fourier_basis_names.append(f"Cos {freq}")
fourier_basis = torch.stack(fourier_basis, dim=0).cuda()
fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)

line(fourier_basis[:8], xaxis="Input", line_labels=fourier_basis_names[:8], title="First 8 Fourier Components")
line(fourier_basis[25:29], xaxis="Input", line_labels=fourier_basis_names[25:29], title="Middle Fourier Components")

In [40]:
imshow(fourier_basis @ fourier_basis.T, title="Verify All Fourier Vectors are Orthogonal")

### Analyse the Embedding

In [41]:
imshow(fourier_basis @ W_E, yaxis="Fourier Component", xaxis="Residual Stream", y=fourier_basis_names, title="Embedding in Fourier Basis")

In [42]:
line((fourier_basis @ W_E).norm(dim=-1), xaxis="Fourier Component", x=fourier_basis_names, title="Norms of Embedding in Fourier Basis")

 Key Frequencies: Models might not utilize the same frequencies even with exact settings. Adjust this according to the plot above for a meaningful ablation (progress measure evaluation).  

In [43]:
key_freqs = list(range(1, 57))

## Analyse Neurons

In [44]:
imshow(
    einops.rearrange(neuron_acts[:, :5], "(a b) neuron -> neuron a b", a=p, b=p),
    title="First 5 neuron acts", xaxis="b", yaxis="a", facet_col=0)

### Neuron Clusters

In [45]:
fourier_neuron_acts = fourier_basis @ einops.rearrange(neuron_acts, "(a b) neuron -> neuron a b", a=p, b=p) @ fourier_basis.T
# Center these by removing the mean
fourier_neuron_acts[:, 0, 0] = 0.
print("fourier_neuron_acts", fourier_neuron_acts.shape)

fourier_neuron_acts torch.Size([512, 113, 113])


In [46]:
neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).cuda()
for freq in range(0, p//2):
    for x in [0, 2*(freq+1) - 1, 2*(freq+1)]:
        for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:
            neuron_freq_norm[freq] += fourier_neuron_acts[:, x, y]**2
neuron_freq_norm = neuron_freq_norm / fourier_neuron_acts.pow(2).sum(dim=[-1, -2])[None, :]
imshow(neuron_freq_norm, xaxis="Neuron", yaxis="Freq", y=torch.arange(1, p//2+1), title="Neuron Frac Explained by Freq")

In [47]:
line(neuron_freq_norm.max(dim=0).values.sort().values, xaxis="Neuron", title="Max Neuron Frac Explained over Freqs")

# Progress Measures

## Setup Code

Code to run a metric over every checkpoint

In [48]:
metric_cache = {}

def get_metrics(model, metric_cache, metric_fn, name, reset=False):
    if reset or (name not in metric_cache) or (len(metric_cache[name]) == 0):
        metric_cache[name] = []

        for c, sd in enumerate(tqdm.tqdm((model_checkpoints))):
            model.reset_hooks()
            model.load_state_dict(sd)
            out = metric_fn(model)

            if type(out) == torch.Tensor:
                out = utils.to_numpy(out)
            metric_cache[name].append(out)
        model.load_state_dict(model_checkpoints[-1])
        try:
            metric_cache[name] = torch.tensor(metric_cache[name])
        except:
            metric_cache[name] = torch.tensor(np.array(metric_cache[name]))

### Loss Curves

Adjust vertical lines accordingly

In [49]:
memorization_end_epoch = 1500
circuit_formation_end_epoch = 13300
cleanup_end_epoch = 16600

In [50]:
fig = line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=False, title="Training Curve for Modular Addition", line_labels=['train', 'test'], toggle_x=True, toggle_y=True, return_fig=True)
add_lines(fig)

### Logit Periodicity

In [51]:
all_logits = original_logits[:, -1, :]
print(all_logits.shape)
all_logits = einops.rearrange(all_logits, "(a b) c -> a b c", a=p, b=p)
print(all_logits.shape)

torch.Size([12769, 113])
torch.Size([113, 113, 113])


#### Getting predicted logits by each key frequencies

In [52]:
coses = {}
for freq in key_freqs:
    #print("Freq:", freq)
    a = torch.arange(p)[:, None, None]
    b = torch.arange(p)[None, :, None]
    c = torch.arange(p)[None, None, :]
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()
    cube_predicted_logits /= cube_predicted_logits.norm()
    coses[freq] = cube_predicted_logits

#### Cosine Similarity: A metric evaluating how well predicted logits are explained by each key frequencies

In [53]:
approximated_logits = torch.zeros_like(all_logits)
for freq in key_freqs:
    print("Freq:", freq)
    coeff = (all_logits.to('cuda:0') * coses[freq].to('cuda:0')).sum()
    print("Coeff:", coeff)
    cosine_sim = coeff.to('cuda:0') / all_logits.to('cuda:0').norm()
    print("Cosine Sim:", cosine_sim)
    approximated_logits = approximated_logits.to('cuda:0')
    approximated_logits += coeff.to('cuda:0') * coses[freq]
residual = all_logits.to('cuda:0') - approximated_logits.to('cuda:0')
print("Residual size:", residual.norm())
print("Residual fraction of norm:", residual.to('cuda:0').norm()/all_logits.to('cuda:0').norm())

Freq: 1
Coeff: tensor(40.9939, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.0014, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 2
Coeff: tensor(-25.5154, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(-0.0009, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 3
Coeff: tensor(4.1757, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.0001, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 4
Coeff: tensor(-5.8618, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(-0.0002, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 5
Coeff: tensor(-1.0132, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(-3.5402e-05, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 6
Coeff: tensor(-2.8688, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(-0.0001, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 7
Coeff: tensor(26.8485, device='cuda:0', grad_fn=<SumBackward0>)
Cosine Sim: tensor(0.0009, device='cuda:0', grad_fn=<DivBackward0>)
Freq: 8
Coef

#### Cosine Similarity on a random vector

In [54]:
random_logit_cube = torch.randn_like(all_logits)
print((all_logits * random_logit_cube).sum() / random_logit_cube.norm()/all_logits.norm())

tensor(0.0014, device='cuda:0', grad_fn=<DivBackward0>)


#### Look During Training

In [55]:
cos_cube = []
for freq in range(1, p//2 + 1):
    a = torch.arange(p)[:, None, None]
    b = torch.arange(p)[None, :, None]
    c = torch.arange(p)[None, None, :]
    cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()
    cube_predicted_logits /= cube_predicted_logits.norm()
    cos_cube.append(cube_predicted_logits)
cos_cube = torch.stack(cos_cube, dim=0)
print(cos_cube.shape)

torch.Size([56, 113, 113, 113])


In [56]:
def get_cos_coeffs(model):
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
    vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
    return vals


get_metrics(model, metric_cache, get_cos_coeffs, "cos_coeffs")
print(metric_cache["cos_coeffs"].shape)

  0%|          | 0/250 [00:00<?, ?it/s]

torch.Size([250, 56])



Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:261.)



In [57]:
fig = line(metric_cache["cos_coeffs"].T, line_labels=[f"Freq {i}" for i in range(1, p//2+1)], title="Coefficients with Predicted Logits", xaxis="Epoch", x=checkpoint_epochs, yaxis="Coefficient", return_fig=True)
add_lines(fig)

In [58]:
def get_cos_sim(model):
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
    vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
    return vals / logits.norm()

Run this  cell again if graph not shown

In [68]:
get_metrics(model, metric_cache, get_cos_sim, "cos_sim")
fig = line(metric_cache["cos_sim"].T, line_labels=[f"Freq {i}" for i in range(1, p//2+1)], title="Cosine Sim with Predicted Logits", xaxis="Epoch", x=checkpoint_epochs, yaxis="Cosine Sim", return_fig=True)
add_lines(fig)

In [60]:
def get_residual_cos_sim(model):
    logits = model(dataset)[:, -1]
    logits = einops.rearrange(logits, "(a b) c -> a b c", a=p, b=p)
    vals = (cos_cube * logits[None, :, :, :]).sum([-3, -2, -1])
    residual = logits - (vals[:, None, None, None] * cos_cube).sum(dim=0)
    return residual.norm() / logits.norm()

Run this cell again if graph not shown

In [69]:
get_metrics(model, metric_cache, get_residual_cos_sim, "residual_cos_sim")
fig = line([metric_cache["cos_sim"][:, i] for i in range(p//2)]+[metric_cache["residual_cos_sim"]], line_labels=[f"Freq {i}" for i in range(1, p//2+1)]+["residual"], title="Cosine Sim with Predicted Logits + Residual", xaxis="Epoch", x=checkpoint_epochs, yaxis="Cosine Sim", return_fig=True)
add_lines(fig)

## Fourier Norms of neuron activations

In [62]:
neuron_acts_square = einops.rearrange(neuron_acts, "(a b) neur -> a b neur", a=p, b=p).clone()
# Center it
neuron_acts_square -= einops.reduce(neuron_acts_square, "a b neur -> 1 1 neur", "mean")
neuron_acts_square_fourier = einsum("a b neur, fa a, fb b -> fa fb neur", neuron_acts_square, fourier_basis, fourier_basis)
imshow(neuron_acts_square_fourier.norm(dim=-1), xaxis="Fourier Component b", yaxis="Fourier Component a", title="Norms of neuron activations by Fourier Component", x=fourier_basis_names, y=fourier_basis_names)

In [63]:
imshow(
    neuron_acts_square_fourier.norm(dim=-1)[1:, 1:],
    xaxis="Fourier Component b",
    yaxis="Fourier Component a",
    title="Norms of neuron activations by Fourier Component Excluding Constant",
    x=fourier_basis_names[1:],
    y=fourier_basis_names[1:]
)

## Excluded Loss

In [64]:
approx_neuron_acts = torch.zeros_like(neuron_acts)
# approx_neuron_acts += neuron_acts.mean(dim=0)
a = torch.arange(p)[:, None]
b = torch.arange(p)[None, :]
for freq in key_freqs:
    cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()
    cos_apb_vec /= cos_apb_vec.norm()
    cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec
    sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()
    sin_apb_vec /= sin_apb_vec.norm()
    sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")
    approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
excluded_neuron_acts = neuron_acts - approx_neuron_acts
excluded_logits = excluded_neuron_acts @ W_logit
print(loss_fn(excluded_logits[train_indices], train_labels))

tensor(0.0004, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)


We found empirically that for m=n where n is a fixed positive even integer, the models always learn to only utilize cosine frequencies, even with different hyperparameters. Uncomment the sine part if you're interested in ablating sine frequencies too

In [65]:
def get_excluded_loss(model):
    logits, cache = model.run_with_cache(dataset)
    logits = logits[:, -1, :]
    neuron_acts = cache["post", 0, "mlp"][:, -1, :]
    approx_neuron_acts = torch.zeros_like(neuron_acts)
    # approx_neuron_acts += neuron_acts.mean(dim=0)
    a = torch.arange(p)[:, None]
    b = torch.arange(p)[None, :]
    for freq in key_freqs:
        cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()
        cos_apb_vec /= cos_apb_vec.norm()
        cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
        approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec

        cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()
        cos_apb_vec /= cos_apb_vec.norm()
        cos_apb_vec = einops.rearrange(cos_apb_vec, "a b -> (a b) 1")
        approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec
        """
        sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()
        sin_apb_vec /= sin_apb_vec.norm()
        sin_apb_vec = einops.rearrange(sin_apb_vec, "a b -> (a b) 1")
        approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec
        """
    excluded_neuron_acts = neuron_acts - approx_neuron_acts
    residual_stream_final = excluded_neuron_acts @ model.blocks[0].mlp.W_out + cache["resid_mid", 0][:, -1, :]
    excluded_logits = residual_stream_final @ model.unembed.W_U
    excluded_loss = loss_fn(excluded_logits[train_indices], train_labels)
    """
    wandb.log({
            "excluded_logits":excluded_logits,
            "excluded_loss": excluded_loss
            })
    """
    return excluded_loss
get_excluded_loss(model)

tensor(0.0074, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)

In [66]:
get_metrics(model, metric_cache, get_excluded_loss, "excluded_loss", reset=True)
print(metric_cache["excluded_loss"].shape)

  0%|          | 0/250 [00:00<?, ?it/s]

torch.Size([250])


Excluded Loss : Train Loss with ablation

In [70]:
fig = line([train_losses[::100], test_losses[::100], metric_cache["excluded_loss"]], x=np.arange(0, len(train_losses), 100), xaxis="Epoch", yaxis="Loss", log_y=True, title="Excluded Loss Curve", line_labels=['train', 'test', "excluded_loss"], toggle_x=True, toggle_y=True, return_fig=True)

add_lines(fig)