## Setup

In [None]:
# Flip to True to retrain from scratch (and overwrite PTH_LOCATION); False reloads the saved run.
TRAIN_MODEL: bool = False

In [None]:
import plotly.io as pio

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import pandas as pd
import math

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
def imshow(image, title=None, xaxis=None, yaxis=None, x_points=None, y_points=None, renderer=None, **kwargs):
    """Show a 2-D tensor/array as a diverging (RdBu) heatmap whose colour scale is centred on 0.

    Args:
        image: 2-D tensor or array; converted with ``utils.to_numpy``.
        title: Figure title.
        xaxis, yaxis: Axis labels.
        x_points, y_points: Optional tick values for the x / y axes.
        renderer: Plotly renderer name forwarded to ``Figure.show`` (same as ``line``/``scatter``).
        **kwargs: Extra keyword arguments forwarded to ``px.imshow``.
    """
    px.imshow(
        utils.to_numpy(image),
        title=title,
        labels={'x': xaxis, 'y': yaxis},
        x=x_points,
        y=y_points,
        color_continuous_midpoint=0.,
        color_continuous_scale='RdBu',
        aspect='auto',
        **kwargs,
    ).show(renderer)

def line(tensor, renderer=None, xaxis='', yaxis='', **kwargs):
    """Line plot of a tensor/array via ``px.line``.

    Args:
        tensor: Data to plot; converted with ``utils.to_numpy`` and passed as ``px.line``'s data.
        renderer: Plotly renderer name forwarded to ``Figure.show``.
        xaxis, yaxis: Axis labels.
        **kwargs: Extra keyword arguments forwarded to ``px.line`` (e.g. ``title``).
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis='', yaxis='', caxis='', renderer=None, **kwargs):
    """Scatter plot of ``y`` against ``x`` via ``px.scatter``.

    Args:
        x, y: Tensors/arrays converted with ``utils.to_numpy``.
        xaxis, yaxis, caxis: Labels for the x axis, y axis and colour scale.
        renderer: Plotly renderer name forwarded to ``Figure.show``.
        **kwargs: Extra keyword arguments forwarded to ``px.scatter`` (e.g. ``color``, ``title``).
    """
    axis_labels = {'x': xaxis, 'y': yaxis, 'color': caxis}
    figure = px.scatter(x=utils.to_numpy(x), y=utils.to_numpy(y), labels=axis_labels, **kwargs)
    figure.show(renderer)

In [None]:
def std_heatmap(z):
    """Build (without showing) a heatmap trace in the notebook's standard style.

    Uses the RdBu colour scale centred on 0 and hides the colour bar, so the trace can be
    dropped into a subplot grid with ``fig.add_trace``.
    """
    return go.Heatmap(z=z, colorscale='RdBu', zmid=0., showscale=False)

In [None]:
# Checkpoint file: written by the training cells when TRAIN_MODEL is True, read back otherwise
PTH_LOCATION: str = 'modadd.pth'

## Model Training

### Config

In [None]:
# Task: addition modulo p; a fraction frac_train of the p * p pairs is used for training
p: int = 113
frac_train: float = 0.3

# AdamW hyper-parameters
lr: float = 1e-3
wd: float = 1.0
betas: tuple = (0.9, 0.98)

# Full-batch training length, and how often losses / weights are recorded
num_epochs: int = 25_000
checkpoint_every: int = 100

# Seed for the train/test permutation
DATA_SEED: int = 598

### Define Task

Define the dataset & labels

In [None]:
# Input format is |a|b|=| ; pair (a, b) ends up at row a * p + b
a_vector = einops.repeat(torch.arange(p), 'i -> (i j)', j=p)
b_vector = einops.repeat(torch.arange(p), 'j -> (i j)', i=p)
# '=' is token id p (the vocabulary is 0..p-1 plus '=', d_vocab = p + 1); derive it from p
# instead of hard-coding 113 so changing p keeps the dataset consistent
equals_vector = einops.repeat(torch.tensor(p), ' -> (i j)', i=p, j=p)

print(a_vector, b_vector, equals_vector)

In [None]:
# One row per (a, b) pair with columns [a, b, '='] -> shape (p * p, 3)
dataset = torch.column_stack((a_vector, b_vector, equals_vector))
print(dataset[:5])
print(dataset.shape)

In [None]:
# Target for each row: (a + b) mod p
labels = dataset[:, :2].sum(dim=-1) % p
print(labels[:5])
print(labels.shape)

In [None]:
# Shuffle all p * p pairs with a fixed seed; the first frac_train of them form the training set
torch.manual_seed(DATA_SEED)
indices = torch.randperm(p * p)
cutoff = int(p * p * frac_train)
train_indices, test_indices = indices[:cutoff], indices[cutoff:]

train_data, train_labels = dataset[train_indices], labels[train_indices]
test_data, test_labels = dataset[test_indices], labels[test_indices]

# Peek at each split: first rows, first labels, shape
for split_data, split_labels in ((train_data, train_labels), (test_data, test_labels)):
    print(split_data[:5])
    print(split_labels[:5])
    print(split_data.shape)

### Define Model

In [None]:
cfg = HookedTransformerConfig(
    # Architecture: one attention + MLP block, no LayerNorm
    n_layers=1,
    d_model=128,
    n_heads=4,
    d_head=32,
    d_mlp=512,
    act_fn='relu',
    normalization_type=None,
    # Vocabulary: numbers 0..p-1 plus the '=' token (id p); outputs are only the p residues
    d_vocab=p + 1,
    d_vocab_out=p,
    n_ctx=3,  # sequences are |a|b|=|
    # Initialisation and placement
    init_weights=True,
    seed=999,
    device='cpu',
)

In [None]:
# Fresh (untrained) transformer built from cfg; weights are replaced later if a saved run is loaded
model = HookedTransformer(cfg)

In [None]:
# Freeze every bias (any parameter whose name contains 'b_') so it never receives gradients:
# biases aren't needed for this task and would make the model harder to interpret.
for param_name, weight in model.named_parameters():
    if 'b_' not in param_name:
        continue
    weight.requires_grad_(False)

### Define Optimizer + Loss

In [None]:
# AdamW (decoupled weight decay) over all parameters; the frozen biases never get gradients
optimizer = optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)

In [None]:
def loss_fn(logits, labels):
    """Mean cross-entropy of the final-position prediction, computed in float64.

    Args:
        logits: (batch, pos, d_vocab_out) model output; only the last position is scored.
            An already-sliced (batch, d_vocab_out) tensor is also accepted.
        labels: (batch,) integer class ids.

    Returns:
        0-d float64 tensor: mean negative log-probability of the correct labels.
    """
    final_logits = logits[:, -1] if logits.ndim == 3 else logits
    # float64 keeps tiny losses late in training from being swamped by rounding
    log_probs = final_logits.double().log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return -correct_log_probs.mean()

# Losses of the untrained model, to compare with the uniform-guess loss ln(p) below.
# inference_mode: no autograd graph over the whole dataset just to print two numbers.
with torch.inference_mode():
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    print(train_loss)
    test_logits = model(test_data)
    test_loss = loss_fn(test_logits, test_labels)
    print(test_loss)

In [None]:
# Check what loss would be if guess was uniformly random
print("Uniform loss:")
print(np.log(p))

### Actually Train

In [None]:
# Full-batch training: every step uses the whole training set at once (no minibatches / SGD).
# Every `checkpoint_every` epochs, record train/test loss, the epoch and a copy of the weights.
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []

if TRAIN_MODEL:
    for epoch in tqdm.tqdm(range(num_epochs)):
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        if (epoch + 1) % checkpoint_every == 0:
            # train_loss was computed before this epoch's optimizer step; test_loss is computed after it
            train_losses.append(train_loss.item())
            with torch.inference_mode():
                test_logits = model(test_data)
                test_loss = loss_fn(test_logits, test_labels)
                test_losses.append(test_loss.item())

            checkpoint_epochs.append(epoch)
            # state_dict() shares storage with the live parameters, so deep-copy it to freeze this snapshot
            model_checkpoints.append(copy.deepcopy(model.state_dict()))
            print(f"Epoch {epoch} Train Loss {train_loss.item()} Test Loss {test_loss.item()}")

In [None]:
if TRAIN_MODEL:
    # Bundle the final weights, config, every checkpoint, the loss curves and the exact
    # train/test split so the analysis can later be reproduced without retraining.
    run_artifacts = {
        'model': model.state_dict(),
        'config': model.cfg,
        'checkpoints': model_checkpoints,
        'checkpoint_epochs': checkpoint_epochs,
        'test_losses': test_losses,
        'train_losses': train_losses,
        'train_indices': train_indices,
        'test_indices': test_indices,
    }
    torch.save(run_artifacts, PTH_LOCATION)

In [None]:
if not TRAIN_MODEL:
    # The file pickles a HookedTransformerConfig, which PyTorch >= 2.6's default
    # weights_only=True refuses to load. Unpickling can execute arbitrary code:
    # only load checkpoint files you created / trust.
    cached_data = torch.load(PTH_LOCATION, map_location='cpu', weights_only=False)
    model.load_state_dict(cached_data['model'])
    model_checkpoints = cached_data['checkpoints']
    checkpoint_epochs = cached_data['checkpoint_epochs']
    test_losses = cached_data['test_losses']
    train_losses = cached_data['train_losses']
    train_indices = cached_data["train_indices"]
    test_indices = cached_data["test_indices"]

    # Rebuild the splits from the loaded indices: the tensors built earlier came from this
    # session's permutation, which need not match the split the saved model was trained on.
    train_data = dataset[train_indices]
    train_labels = labels[train_indices]
    test_data = dataset[test_indices]
    test_labels = labels[test_indices]

In [None]:
# Train/test loss against the epoch at which each checkpoint was recorded. The losses span
# several orders of magnitude, so a log y-axis keeps the late (grokking) phase visible.
train_loss_trace = go.Scatter(
    x=checkpoint_epochs,
    y=train_losses,
    mode='lines', name='Train Loss'
)
test_loss_trace = go.Scatter(
    x=checkpoint_epochs,
    y=test_losses,
    mode='lines', name='Test Loss'
)
fig = go.Figure()
fig.add_trace(train_loss_trace)
fig.add_trace(test_loss_trace)
fig.update_layout(xaxis_title='Epoch', yaxis_title='Loss', yaxis_type='log')
fig.show()

In [None]:
# Training loss of the trained/loaded model (no autograd graph needed for a read-out)
with torch.inference_mode():
    final_train_loss = loss_fn(model(train_data), train_labels)
final_train_loss

## Analysing the Model

### Setting up for analysis

In [None]:
# Run the model on every (a, b) pair (row a * p + b) and cache all hooked activations;
# params holds the weights by state_dict name (e.g. 'embed.W_E') for manual recomputation.
original_logits, cache = model.run_with_cache(dataset)
params = model.state_dict()

You'll often be confused about what each hook is. To test, you can always print the hooks, multiply them by some parameters and compare them with other hooks.

In [None]:
# Every cached hook name with the shape of its activation
for hook_name, activation in cache.items():
    print(f"{hook_name}: {activation.shape}")

In [None]:
# Every parameter / buffer name in the model's state dict, one per line
for key in params:
    print(key)

### Analyzing the embedding matrix W_E

In [None]:
# Number embeddings only: drop the last row (token id p, the '=' token) -> shape (p, d_model)
W_E = params['embed.W_E'][:-1].detach()
imshow(W_E)

This is confusing. We can perform SVD on this matrix to get information.

In [None]:
# Thin SVD W_E = U @ diag(S) @ V.T. torch.svd is deprecated; torch.linalg.svd returns V^T
# (Vh) instead of V, so transpose it back to keep the U, S, V convention used below.
U, S, Vh = torch.linalg.svd(W_E, full_matrices=False)
V = Vh.T

In [None]:
# Left singular vectors, then the same columns scaled by their singular values
for singular_matrix in (U, U * S):
    imshow(singular_matrix)
# Singular value spectrum
line(S)

Only a few singular vectors really matter. If we keep only the first 12 columns of U (indices 0–11, i.e. the 12 largest singular values), we should capture most of W_E. To check whether this is the case, we can compare the full reconstruction with the truncated one.

In [None]:
# Keep only the first 12 columns of U (the 12 largest singular values) and zero the rest
U_ = U.clone()
U_[:, 12:] = 0.
# (U - U_) @ diag(S) @ V.T is exactly the part of W_E dropped by the rank-12 truncation;
# its mean absolute value measures how much the truncation loses
((U-U_) @ torch.diag(S) @ V.T).abs().mean()

There is some information loss, but not much: the mean absolute difference is about one order of magnitude smaller than the largest values of W_E. There might be some impact, but for now let's analyze just the first 12 columns of U.

In [None]:
# One line per kept singular vector (columns 0-11 of U), plotted against the number 0..p-1
line(U[:, :12])

We can apply the Fast Fourier Transform to find the waves represented by each column.

In [None]:
# FFT of each kept singular vector over the number axis
U_freq = np.fft.fft(U[:, :12], axis=0)

# Real and imaginary parts of the spectrum, one column per singular vector
for part_fn, part_title in ((np.real, "real freq"), (np.imag, 'imag freq')):
    imshow(part_fn(U_freq), title=part_title, xaxis='column')

I got GPT-4 to write some code that finds the wave that fits the data. I don't really understand what it's doing, but it works great!

I should really learn how the FFT, discrete and normal Fourier transforms work. They seem super useful.

In [None]:
def analyze_and_plot(data, num=1000):
    """Fit one cosine (the dominant non-DC FFT component) to a 1-D signal and plot the fit.

    Args:
        data: 1-D real signal (array, list, or torch tensor; tensors are detached and moved
            to CPU first).
        num: Number of points at which the fitted cosine is evaluated for plotting.

    Returns:
        (fig, formula): a plotly Figure with the original data and the fitted cosine, and
        the string ``"f(x) = A * cos(2 * pi * f * x + phi)"`` with the fitted values.
        The mean (DC term) of ``data`` is not part of the fit.

    Raises:
        ValueError: if ``data`` has fewer than 2 samples (there is no non-DC frequency).
    """
    # Accept torch tensors (including ones that require grad) as well as arrays/lists
    if hasattr(data, 'detach'):
        data = data.detach().cpu().numpy()
    data = np.asarray(data)
    n_samples = len(data)
    if n_samples < 2:
        raise ValueError("analyze_and_plot needs at least 2 samples")

    fft_data = np.fft.fft(data)
    amplitudes = np.abs(fft_data)

    # Strongest frequency, skipping the DC term at index 0
    dominant_frequency_index = np.argmax(amplitudes[1:]) + 1
    # For real input, bins k and n - k are complex conjugates; use the lower-half one
    if dominant_frequency_index > n_samples / 2:
        dominant_frequency_index = n_samples - dominant_frequency_index

    # Frequency in cycles per data point
    dominant_frequency = dominant_frequency_index / n_samples
    # A real cosine's energy is split between bins k and n - k, hence the n/2 scaling; the
    # Nyquist bin (k == n/2, even n only) has no mirror and must be scaled by n instead.
    if 2 * dominant_frequency_index == n_samples:
        dominant_amplitude = amplitudes[dominant_frequency_index] / n_samples
    else:
        dominant_amplitude = amplitudes[dominant_frequency_index] / (n_samples / 2)
    phase_shift = np.angle(fft_data[dominant_frequency_index])

    # Sample positions of the data, and a finer grid (num points) for the fitted curve
    x_values_data = np.linspace(0, n_samples - 1, n_samples)
    x_values_reconstructed = np.linspace(0, n_samples - 1, num)
    reconstructed_function = dominant_amplitude * np.cos(
        2 * np.pi * dominant_frequency * x_values_reconstructed + phase_shift
    )

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x_values_data, y=data, mode='lines', name='Original data'))
    fig.add_trace(go.Scatter(x=x_values_reconstructed, y=reconstructed_function, mode='lines', name='Reconstructed function'))

    return fig, f"f(x) = {dominant_amplitude} * cos(2 * pi * {dominant_frequency} * x + {phase_shift})"

In [None]:
fig = make_subplots(rows=2, cols=6)
# Column k of U goes to subplot (row k % 2, col k // 2), so each consecutive pair shares a column
for component in range(12):
    grid_row, grid_col = component % 2, component // 2
    fit_figure = analyze_and_plot(U[:, component])[0]
    for trace in fit_figure['data']:
        fig.add_trace(trace, row=grid_row + 1, col=grid_col + 1)
fig.update_layout(height=600)
fig.show()

All pairs (0-1 up to 10-11) represent pairs of sine-cosine, where one wave is shifted +pi/2 to the right (or left). I'll probably have to save the frequencies, amplitudes and phase-shifts, because I believe I'll need them for analysis later.

Nonetheless, the analysis of W_E is basically done. Let's move on to other layers!

### Analyzing the self-attention layer

Neel had a super good idea: plot the attention patterns in a 113 x 113 grid. Keep an eye out for cool visualizations like this, and always try to find innovative ways to visualize data -- that can save you a lot of time.

In [None]:
# Post-softmax attention pattern of layer 0 for every input sequence
attn_pattern = cache['pattern', 0]
print(attn_pattern.shape)

In [None]:
# Split the batch axis (row a * p + b) into an (a, b) grid:
# (p*p, head, q_pos, k_pos) -> (head, a, b, q_pos, k_pos)
total_square_attn = attn_pattern.reshape(p, p, *attn_pattern.shape[1:]).permute(2, 0, 1, 3, 4)
print(total_square_attn.shape)

In [None]:
# Keep only what the '=' query (last position) pays to the 'a' and 'b' keys:
# (head, a, b, q_pos, k_pos) -> (head, a, b, key in {a, b})
square_attn = total_square_attn[..., -1, :-1]
print(square_attn.shape)

# Row 1: attention to 'a', row 2: attention to 'b'; one column per head
fig_square_attn = make_subplots(
    rows=2, cols=4,
    subplot_titles=[
        f"Head {col} Input {'a' if row%2==0 else 'b'}"
        for row in range(2) for col in range(4)
    ],
    vertical_spacing=0.1,
    horizontal_spacing=0.05
)

for row, col in itertools.product(range(2), range(4)):
    fig_square_attn.add_trace(
        std_heatmap(square_attn[col, :, :, row]),
        row=row+1, col=col+1,
    )
fig_square_attn.update_layout(height=800)
fig_square_attn.show()

This is extremely periodic. Let's take the 2-D FFT to check whether and how this can be decomposed into waves, then ask GPT-4 to help us out.

In [None]:
# 2-D FFT over the (a, b) grid, i.e. the two axes of length p (1 and 2)
fft_square_attn = np.fft.fftn(square_attn, axes=(1, 2))
print(fft_square_attn.shape)

In [None]:
fig_fft_square_attn = make_subplots(
    # One row per (head, input) pair; columns are: attention pattern, real FFT, imag FFT
    rows=8, cols=3,
    subplot_titles=[
        f"Head {row//2} Input {'a' if row%2==0 else 'b'} {'graph' if col%3==0 else ('real' if col%3==1 else 'imag')}"
        for row in range(8) for col in range(3)
    ],
    vertical_spacing=0.02,
    horizontal_spacing=0.05
)

for row in range(8):
    # Row r shows head r // 2, with input 'a' for even r and 'b' for odd r
    head, a_or_b = divmod(row, 2)
    # Move the zero frequency to the centre of the 2-D spectrum
    centered_fft = np.fft.fftshift(fft_square_attn[head, :, :, a_or_b])
    panels = (square_attn[head, :, :, a_or_b], np.real(centered_fft), np.imag(centered_fft))
    for col, z in enumerate(panels):
        fig_fft_square_attn.add_trace(
            std_heatmap(z),
            row=row+1, col=col+1,
        )
fig_fft_square_attn.update_layout(height=3200)
fig_fft_square_attn.show()

The pairs 0-1 and 2-3 look very similar. I'm trying to understand what they're doing. They're representing waves, and depending on the inputs it'll give more importance to 'a' or 'b'. How does this work, and what does this mean? 

1. One graph is the transpose of the other. The following example will make this clear: Our inputs are (a, b). Attention is calculated irrespective of the other values, and since 'n' in pos 'a' is (practically) the same as 'n' in pos 'b' (I did this analysis somewhere), picking (b, a) will reverse the attention.

2. (3, 1) has attention (0.2, 0.8). Why is this? Why are we getting 0.2 of the Value of 3 and 0.8 of the Value of 1? What does this mean, and what is happening here?

We can see that only a few frequencies matter. Let's create a function to reduce the attention heads' patterns to the bare minimum (only the most important frequencies), and then check if they do indeed approximate the real attention patterns well.

In [None]:
def reduce_matrix_fft(matrix, n_top=1):
    '''
    Approximate a wavy real 2-D matrix using only its axis-aligned frequencies.

    Takes the 2-D FFT, keeps the 2*n_top + 1 largest-magnitude coefficients of the
    zero-frequency column (frequencies along axis 0) and of the zero-frequency row
    (frequencies along axis 1), zeroes every other coefficient, and inverts.

    Args:
        matrix: 2-D real array or CPU tensor.
        n_top: How many frequencies to keep per axis (2*n_top + 1 coefficients, i.e.
            conjugate pairs plus one extra, usually the DC term).

    Returns:
        float64 torch tensor of the same shape (real part of the inverse FFT).
    '''
    spectrum = np.fft.fft2(matrix)
    n_keep = 2 * n_top + 1

    # Coefficients with zero frequency along axis 1 / along axis 0
    first_col = spectrum[:, 0]
    first_row = spectrum[0, :]
    # argsort is ascending, so the last n_keep entries are the largest magnitudes
    keep_in_col = np.abs(first_col).argsort()[-n_keep:]
    keep_in_row = np.abs(first_row).argsort()[-n_keep:]

    truncated = np.zeros_like(matrix, dtype=np.complex128)
    truncated[keep_in_col, 0] = first_col[keep_in_col]
    truncated[0, keep_in_row] = first_row[keep_in_row]

    return torch.from_numpy(np.real(np.fft.ifft2(truncated)))

In [None]:
# Check if approximation is good, and the effect of 'n_top' on the approximation
for head in range(4):
    errors = []
    for n_top in range(114):
    # Take attention pattern of head 'head', 'a' if a_or_b = 0, 'b' if = 1
        head_attn = square_attn[head, :, :, 0]
        error = torch.abs(reduce_matrix_fft(head_attn, n_top=n_top) - head_attn).mean()
        errors.append(100 * error / head_attn.mean())
    line(errors)

**MORE IDEAS TO TRY**: 

1. Look at the relation between equivalent sums in the attention heads and after the QK - WV circuit. Hopefully equivalent sums will have equivalent vectors there, which will make things easier to interpret. Be sure to look *before* the residual stream.

2. Find similar relations in the attention patterns themselves? Maybe?

### Investigating the circuit from W_E to the attention patterns

For 'a' and 'b', what matters is what's at the end of **(W_E + W_pos) -> W_K**

For '=', what matters is what's at the end of **(W_E + W_pos) -> W_Q**

Our preliminary analysis was very important to understand what's going on, and I believe that the tools we've used before will be of use here too. Now I'll try to analyze what comes right before the attention is calculated. For starters, I want to plot all inputs in **pos 0** (a), all inputs in **pos 1** (b), and the final form of '=' after passing through all the intermediate layers.

Let's first recompute the path to the keys and queries by hand, to get some practice with batched multiplication.

In [None]:
# One-hot encode the token ids (float so they can be multiplied with the weights);
# num_classes pins the width to d_vocab instead of inferring it from the largest id present
dataset_one_hot = F.one_hot(dataset, num_classes=cfg.d_vocab).float()
print(dataset_one_hot.shape, params['embed.W_E'].shape)


def _fraction_close(ours, reference):
    """Fraction of entries of `ours` that match `reference` up to float32 round-off."""
    # Exact == is too strict: different kernels / summation orders give tiny differences
    return torch.isclose(ours, reference, rtol=1e-4, atol=1e-5).float().mean()


# Pass dataset through embedding matrix
dataset_embed = einops.einsum(dataset_one_hot, params['embed.W_E'], 'batch pos d_in, d_in d_model -> batch pos d_model')

# Check it matches the hook after W_E
print(_fraction_close(dataset_embed, cache['embed']))

# Add positional embedding
dataset_pos = dataset_embed + params['pos_embed.W_pos']

# Check it matches the residual stream before attn
print(_fraction_close(dataset_pos, cache['resid_pre', 0]))

# Keys and queries, including the b_K / b_Q biases that hook_k / hook_q contain
dataset_K = einops.einsum(dataset_pos, params['blocks.0.attn.W_K'], 'batch pos d_model, head d_model d_head -> batch pos head d_head') + params['blocks.0.attn.b_K']
dataset_Q = einops.einsum(dataset_pos, params['blocks.0.attn.W_Q'], 'batch pos d_model, head d_model d_head -> batch pos head d_head') + params['blocks.0.attn.b_Q']

# Check they match hook_k and hook_q
print(_fraction_close(dataset_K, cache['k', 0]))
print(_fraction_close(dataset_Q, cache['q', 0]))

Now I want to display the things we're interested in: the query of '=' for each head, and the keys of 'a' and 'b' for all numbers 0-112.

In [None]:
# Every sequence has the same '=' token at the same position, so its query is identical across
# the batch; take the first sequence ([0, 0, p]) -> shape (head, d_head)
q_equals = cache['q', 0][0, -1, :, :]

imshow(q_equals, yaxis='heads', title="Query of '=' for each head")

In [None]:
# Compare the key each number gets in position 'a' with the key it gets in position 'b'
k = cache['k', 0]
for head in range(4):
    # Rows a * p + 0 (a = 0..p-1): key of number a in position 0 ('a')
    k_a = k[::p, 0, head, :]
    # Rows 0 * p + b (b = 0..p-1): key of number b in position 1 ('b')
    k_b = k[:p, 1, head, :]

    # Largest absolute difference, as a percentage of k_a's mean absolute value
    max_error = (k_a - k_b).abs().max()
    relative_pct = (100 * max_error / k_a.abs().mean()).item()
    print(f"Head {head}: max |k_a - k_b| = {relative_pct:.4f}% of mean |k_a|")

As we can see, the difference between the Keys of 'a' and 'b' is never greater than 1% (which was expected, since addition is commutative). Thus, we may restrict our analysis to position 'a', since 'b''s analysis will be practically identical.

In [None]:
fig_keys_graph_and_fft = make_subplots(
    # One row per head; columns are: keys of 'a', real FFT, imag FFT
    rows=4, cols=3,
    subplot_titles=[
        f"Head {head}, 'a' {'graph' if col%3==0 else ('real' if col%3==1 else 'imag')}"
        for head in range(4) for col in range(3)
    ],
    vertical_spacing=0.05,
    horizontal_spacing=0.05,
)

for head in range(4):
    # Keys of 'a' (position 0) for every number 0..p-1: shape (p, d_head)
    graph = utils.to_numpy(cache['k', 0][::p, 0, head, :])
    # FFT over the number axis only; shift only that axis so the d_head columns stay
    # aligned with the 'graph' panel
    centered_fft = np.fft.fftshift(np.fft.fft(graph, axis=0), axes=0)
    panels = (graph, np.real(centered_fft), np.imag(centered_fft))
    for col, z in enumerate(panels):
        fig_keys_graph_and_fft.add_trace(
            std_heatmap(z),
            row=head+1, col=col+1
        )
fig_keys_graph_and_fft.update_layout(height=1600)
fig_keys_graph_and_fft.show()

IMPORTANT IDEAS:

1. The graphs above are **extremely** periodic, and the frequencies are the same throughout all the head dimensions. This means that, **very likely**, all of the 113 arrays are just one specific array multiplied by a scalar (which is a function of many frequencies). Since what we do is take the dot product between this and '=''s Query, the resulting dot product will be a constant (fixed for all 113) multiplied by that scalar determined by the frequencies.

2. Maybe one reason why the attention patterns weren't as periodic as the graphs above was because of the softmax. To check if I'm right, perform FFT on the attention pattern *before* softmax.

3. You might be asking yourself, future Nicolas, why I didn't perform the 2-D FFT, since it's **obviously** periodic in both dimensions. I asked myself the same question, and apparently the FFT disagrees.

We can check both of these hypotheses by computing the dot product between Queries and Keys and then applying the FFT again. If the frequencies are similar, then (I believe) I'll be proven correct. Either way, what we *really* need to analyze is that dot product, so let's get started right now! Since we've shown elsewhere that 'a' and 'b' are basically identical (any difference between them is smaller than 10^-3), we'll restrict our analysis to position 'a'.

In [None]:
# Query of '=' (identical for every sequence, so take the first) -> (head, d_head)
query_equals = cache['q', 0][0, -1, :, :]
# Keys of 'a' for every number: rows a * p + 0 -> (p, head, d_head)
keys_a = cache['k', 0][::p, 0, :, :]

print("keys_a:", keys_a.shape)
print("query_equals:", query_equals.shape)

# Unnormalized attention score of '=' on 'a': plain q.k (no attention scaling, no softmax)
attn_a = einops.einsum(keys_a, query_equals, 'num head d_head, head d_head -> num head')

# |FFT| of the scores over the number axis
imshow(np.abs(np.fft.fft(attn_a, axis=0)), xaxis='heads', yaxis='frequency', title="|FFT| of q.k for 'a'")

# Same for the keys themselves, with the heads concatenated, to check they share frequencies
concat_keys_a = einops.rearrange(keys_a, 'nums head d_head -> nums (head d_head)')
imshow(np.abs(np.fft.fft(concat_keys_a, axis=0)), xaxis='(head, d_head)', yaxis='frequency', title="|FFT| of keys of 'a'")

imshow(attn_a, xaxis='heads', yaxis='number a', title="q.k for 'a'")

I WAS CORRECT FUCK YEAH

This means now the only things I have to analyze are:

1. The attention pattern before softmax
2. How different numbers add to '=''s residual stream
3. What the MLP does

**IDEAS**: 

1. You haven't looked at the Values matrices at all, they might contain some useful insight

2. PLOT 'attn_a' AS LINES AND DO CURVE-FITTING YOU MORON

**HYPOTHESIS**: Look at how clean these lines are. Super fucking clean if I may say so. Maybe what softmax is doing is 'corrupt' the pretty lines, that's why we don't get frequencies as nice as the ones here after the softmax. I believe this is true. Actually no, look below. One of the problems is with the Values, they don't have frequencies as clean as these.

In [None]:
# Plot attn_a and 
fig_attn_a = make_subplots(
    rows=4, cols=1,
    subplot_titles=[f'head {head}' for head in range(4)],
    vertical_spacing=0.05
)

for head in range(4):
    figure, data = analyze_and_plot(attn_a[:, head], num=113)
    print(data)
    for trace in figure['data']:
        fig_attn_a.add_trace(
            trace,
            row=head + 1, col=1
        )
fig_attn_a.update_layout(height=800)
fig_attn_a.show()

Uhhh, this is weird. WHY DOES JUST ONE FREQUENCY FIT THE WAVE SO WELL?? I'll get GPT-4 to write code that gets the 4 main frequencies (when I get it back ;-;), but this is already surprisingly accurate. What if we try to understand what's going on by using just these simplifications? Maybe we'll actually get pretty far, and we won't need GPT-4's help.

**IDEA**: With the cosine wave on hand, calculate and plot the attention graph from the pure waves and compare it to the original. If it's close enough, celebrate! You can also go through the whole process and get the outputs based on these. If they're decently high, you'll be basically done!

In [None]:
# The indexes are for the heads 0-4. Maybe I should cut some digits
head_cos_wave = {
    'dominant_amplitude': [6.329031410795562, 8.48166410569167, 6.196583114665829, 7.722471474643398],
    'dominant_frequency': [0.36283185840707965, 0.36283185840707965, 0.07079646017699115, 0.07079646017699115],
    'phase_shift': [2.3973650351858686, -0.8454673633717475, 2.9620057309991714, -0.1647056964663721],
}

# Waves for unnormalized attention of 'a' and 'b'
x_0_to_112 = torch.linspace(start=0, end=112, steps=113)
cos_x = torch.stack([
    (head_cos_wave['dominant_amplitude'][head] * torch.cos(2*torch.pi*x_0_to_112*head_cos_wave['dominant_frequency'][head] + head_cos_wave['phase_shift'][head]))
    for head in range(4)]
)

# Display waves
line(einops.rearrange(cos_x, 'a b -> b a'))

In [None]:
# Reconstruct attention scores from the waves
reconstructed_attn_scores = torch.tensor([[[
    [cos_x[head][a] / math.sqrt(32), cos_x[head][b] / math.sqrt(32)] # Divide by d_head for normalization
    for b in range(113)]
    for a in range(113)]
    for head in range(4)]
)

In [None]:
# Calculate attention pattern
reconstructed_attn_pattern = reconstructed_attn_scores.softmax(-1)

# Compare reconstructed and original attn patterns
imshow(reconstructed_attn_pattern[3, :, :, 1])
imshow(square_attn[3, :, :, 1])

This looks SOOO good. They're so similar!!! I'll pass it through the rest of the neural net to see if I get a good result. If so, I'm basically done!

In [None]:
def pass_mlp(resid_mid):
    """Run the '=' residual stream through the MLP (with residual add) and the unembedding.

    Args:
        resid_mid: (batch, d_model) residual stream of the '=' position after attention.

    Returns:
        (batch, d_vocab_out) logits.
    """
    mlp_pre = resid_mid @ params['blocks.0.mlp.W_in'] + params['blocks.0.mlp.b_in']
    mlp_post = F.relu(mlp_pre)  # cfg.act_fn == 'relu'
    mlp_out = mlp_post @ params['blocks.0.mlp.W_out'] + params['blocks.0.mlp.b_out']
    resid_post = resid_mid + mlp_out
    # cfg.normalization_type is None: no final LayerNorm before the unembedding
    return resid_post @ params['unembed.W_U'] + params['unembed.b_U']


# Rearrange the pattern to (batch, key in {a, b}, head, 1) so it broadcasts against the values;
# the batch index a * p + b matches the dataset ordering
reshaped_reconstructed_attn_pattern = einops.rearrange(reconstructed_attn_pattern, 'head a b a_b -> (a b) a_b head 1')

# Values of the 'a' and 'b' positions: (p*p, 2, head, d_head)
v_a_and_b = cache['v', 0][:, :-1]

# Attention-weighted sum of the values: z has shape (p*p, head, d_head)
reconstructed_z = (reshaped_reconstructed_attn_pattern * v_a_and_b).sum(1)
print(reconstructed_z.shape)

# Project through W_O and add its bias: attn_out has shape (p*p, d_model)
print(params['blocks.0.attn.W_O'].shape)
reconstructed_attn_out = einops.einsum(reconstructed_z, params['blocks.0.attn.W_O'], 'b h d_head, h d_head d_model -> b d_model') + params['blocks.0.attn.b_O']

# resid_mid of '=' = residual stream before attention + attention output: (p*p, d_model)
reconstructed_resid_mid = cache['resid_pre', 0][:, -1] + reconstructed_attn_out

reconstructed_logits = pass_mlp(reconstructed_resid_mid)
reconstructed_probs = reconstructed_logits.softmax(-1)

# Group predictions into equivalence classes by (a + b) % p. Class s holds the pairs
# (a, (s - a) % p) for a = 0..p-1, i.e. dataset rows a * p + (s - a) % p.
class_ids = torch.arange(p)[:, None]
member_a = torch.arange(p)[None, :]
class_rows = member_a * p + (class_ids - member_a) % p

# (class/sum, member a, d_vocab_out) = (p, p, p)
sorted_probs = utils.to_numpy(reconstructed_probs[class_rows])

# Stack the classes for visualization: p rows per class
stack_probs = einops.rearrange(sorted_probs, 'sum num d_vocab -> (sum num) d_vocab')

imshow(stack_probs)

**IT'S FUCKING DONE AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA**

### Analyzing the model after the attention layer

**STOP DISPLAYING HUGE GRAPHS, IT LAGS THE FUCK OUT OF VS CODE**

**IDEA I ONLY HAD TOO LATE**: Look at the residual stream for '=' after the attention layer: I expect equivalent sums (0 + 3 and 2 + 1) to have very similar vectors. If they don't, then it'll definitely be the job of the MLP to make sure they're similar. Since the loss is very small, *somewhere* after the attention equivalent sums **must** have very similar '='s vectors.

Let's start by looking at the residual stream after the MLP. There the model should already have 'sorted' equivalent sums in equivalent vectors

In [None]:
# Get residual stream of '=' after MLP
resid_post = cache['resid_post', 0][:, -1]

# Store all vectors in equivalence classes (same class if same sum)
ordered_resid_post = [[] for _ in range(p)]

for a in range(p):
    for b in range(p):
        sum = (a + b) % p
        index = a * p + b
        # Put vector of sum 'sum' in index class 'sum'
        ordered_resid_post[sum].append(utils.to_numpy(resid_post[index]))

# Transform into ndarray with shape (113, 113, 128) -> (class/sum, num, d_model)
# 'num' is just a number from 0 to 112, not necessarily in order because the for loop above doesn't track that
ordered_resid_post = np.array(ordered_resid_post)

# Stack them for improved visualization
stack_orp = einops.rearrange(ordered_resid_post, 'sum num d_model -> (sum num) d_model')
imshow(stack_orp[24*128 : 27*128], title="Stacked classes 24-26 of resid_post") # Only look at a few classes because visualizing everything lags my pc

# Also concatenate them because I'll need it for later
cat_orp = einops.rearrange(ordered_resid_post, 'sum num d_model -> num (sum d_model)')

This is exactly what I wanted to see: vectors which represent the same sum are, *indeed*, similar. Unfortunately, this doesn't seem to be the case for the vectors before the MLP (look below), which will be a pain in the ass. Hopefully they have a cool representation.

Below we're doing the same thing as above, but with the residual stream before the MLP

In [None]:
# Get residual stream of '=' before MLP
resid_mid = cache['resid_mid', 0][:, -1]

# Store all p * p vectors in equivalence classes (same class if same sum)
ordered_resid_mid = [[] for _ in range(p)]

for a in range(p):
    for b in range(p):
        sum = (a + b) % p
        index = a * p + b
        # Put vector of sum 'sum' in index class 'sum'
        ordered_resid_mid[sum].append(utils.to_numpy(resid_mid[index]))

# Transform into ndarray with shape (113, 113, 128) -> (class/sum, num, d_model)
# 'num' is just a number from 0 to 112, not necessarily in order because the for loop above doesn't track that
ordered_resid_mid = np.array(ordered_resid_mid)

# Concatenate them for (hopefully) better visualization
# Here concatenating is better because there's no apparent relation between vectors of the same class
cat_orm = einops.rearrange(ordered_resid_mid, 'sum num d_model -> num (sum d_model)')
imshow(cat_orm[:, 24*128 : 27*128], title="Concatenated classes 24-26 of resid_mid")

This is a random mess AAAAAAAAAAAAAAAAAAA. What we wanted was to see homogeneous columns, which would imply that vectors from the same class are similar. Well, tbh it's not *that* bad: you can see the wavy patterns pretty well, and the FFT will do a good job with this.

**IDEA**: Maybe what the MLP does is extract a certain subspace from the vectors. If so, FFT should make 'resid_mid' interpretable. Hopefully. Tonight, I'll pray.

In [None]:
# Perform FFT on concatenated Ordered Residual Mid to hopefully extract some info
fft_cat_orm = np.fft.fft(cat_orm[:, 24*128 : 27*128], axis=0) # Only doing with these sections because lag

# Display abs of FFT of concatenated ORM
imshow(np.abs(fft_cat_orm), title="FFT of classes 24-26 of resid_mid")

WTF, I didn't expect this at all. Why can they all be represented with the same frequencies? This is extremely surprising.

**IMPORTANT REALIZATION** *A priori*, there's no reason for them to be ordered in any way, because the order in which I picked members from each equivalence class wasn't orderly. It wasn't random either, but I presume (until I think better of it) that the ordering shouldn't matter. Therefore, I should try to find another way to extract information from vectors in the same equivalence class that doesn't use the FFT (at least not at first). Some possible ideas are:

1. Pass it through the first layer of the MLP to see if the data gets more orderly, and how that happens.
2. Other ideas

**To recapitulate**: each 128-long block in the x axis represents one class. For some reason, all classes can be represented using the same frequencies (or at least it seems so). This is very surprising and unexpected, because *a priori* there's no reason for different equivalence classes to be represented by the same frequencies (or to have a similar wave-structure). Think of why this could be the case:

1. The attention heads represent stuff as waves so something something all waves are the same what changes is where they are something something.

Let's check how many frequencies we need to make an omelette:

In [None]:
# Largest |FFT coefficient| of each frequency (row), taken over all columns
max_fft_cat_orm = np.abs(fft_cat_orm).max(axis=1)

# The input is real, so the spectrum is conjugate-symmetric (X[k] == conj(X[p - k])) and
# frequencies 0..p//2 *inclusive* form the non-redundant half. range(p // 2) alone would skip
# frequency p // 2 (56 for p = 113).
for i in range(p // 2 + 1):
    if max_fft_cat_orm[i] > 10:  # The biggest below 10 is less than 7, can be ignored
        print(i, max_fft_cat_orm[i])

9 different frequencies. This feels tractable. It'd be cool if we managed to reverse-engineer what the MLP is doing by finding the differences in frequencies.

Speaking of which, I haven't used the FFT on 'resid_post'. Maybe it'll give some insight. However, I need to sleep now. FUCK SLEEPING WE MUST DO SCIENCE!

In [None]:
# Take the FFT of the concatenated 'resid_post' along the 'num' axis (axis 0); its columns already look very similar
fft_cat_orp = np.fft.fft(cat_orp, axis=0)

# Largest |FFT coefficient| of each frequency (row), taken over all columns
max_fft_cat_orp = np.abs(fft_cat_orp).max(axis=1)

# Report the big ones. For real input the spectrum is conjugate-symmetric, so only frequencies
# 0..p//2 (inclusive) are non-redundant; the others mirror them.
for i in range(p // 2 + 1):
    if max_fft_cat_orp[i] > 100:
        print(i, max_fft_cat_orp[i])

# Only classes 24-26: plotting the full matrix is too slow
imshow(np.abs(fft_cat_orp[:, 24*128 : 27*128]), title="FFT of classes 24-26 of resid_post")

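This is very good news: they can be explained with basically just one frequency (all others are at least one order of magnitude smaller). Let's try reducing the "Ordered Residual Post" to just its largest frequency component. The next cell keeps only frequency 0, the DC term. So the reduction amounts to replacing each column by its mean over the inputs of the class.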

In [None]:
# Keep a single frequency of fft_cat_orp: row 0, the DC term (presumably the only row that cleared
# the threshold in the previous cell). Every other frequency is zeroed out.
red_fft_cat_orp = np.zeros(fft_cat_orp.shape, dtype=np.complex128)
red_fft_cat_orp[0, :] = fft_cat_orp[0, :]

# Back to the 'num' domain. With only the DC term left, every row equals the column mean; the
# imaginary part is negligible because the original signal was real.
red_cat_orp = np.fft.ifft(red_fft_cat_orp, axis=0).real

# Show the reduced version and the original (classes 24-26 only)
imshow(red_cat_orp[:, 24 * 128:27 * 128], title='Reduced Concatenated Ordered Residual Post')
imshow(cat_orp[:, 24 * 128:27 * 128], title='Concatenated Ordered Residual Post')

One is heavily reduced, but they look basically the same. To see whether they do the same job, we can plot the probability distributions obtained from the output logits. If the distributions are the same, we'll know that the reduction doesn't lose significant information.

In [None]:
# Un-concatenate both ORPs to one row per input, ordered (sum, num), then keep every p-th row:
# one representative input per sum class (inputs of the same class should give the same output)
red_stack_orp = einops.rearrange(
    red_cat_orp, 'num (sum d_model) -> (sum num) d_model', sum=p, num=p
)[::p]
stack_orp = einops.rearrange(ordered_resid_post, 'sum num d_out -> (sum num) d_out')[::p]

# Unembed with W_U *and* b_U so these are the model's actual output logits (pass_mlp also adds b_U).
# .to(W_U) casts to W_U's dtype and device: red_stack_orp is float64, and params may live on a GPU.
logits = (
    torch.from_numpy(stack_orp).to(params['unembed.W_U']) @ params['unembed.W_U']
    + params['unembed.b_U']
)
red_logits = (
    torch.from_numpy(red_stack_orp).to(params['unembed.W_U']) @ params['unembed.W_U']
    + params['unembed.b_U']
)

# Softmax over the vocabulary to get probability distributions
probs = logits.softmax(dim=-1)
red_probs = red_logits.softmax(dim=-1)

# Display both to compare the distributions (rows: sum class, columns: output token)
imshow(probs, title="Original Probability Distribution")
imshow(red_probs, title="Reduced Probability Distribution")

YES YESSSSSSSSSSSSSSSSSSSSSSSSS THIS IS ALL THAT FUCKING MATTERS YEAAAAAAAAAAAAAAAAAAH FUCK EVERYTHING FUCK EVERYONE ONLY ONE FREQUENCY FUCKING MATTERS!!!!

**RESULT**: Yes, only that frequency matters. What impressed me is that the predictions aren't just approximately right: all of the probabilities are 1 (or .9999 lol). This means that we only need to learn how to get from **resid_mid** to the reduced **resid_post**. Hopefully inspecting the MLP will tell us how it's done, so that we can reverse-engineer the process and find out what the vectors in **resid_mid** are.

### Analyzing 'resid_mid'

We've found out that the MLP takes **resid_mid** and returns an array where inputs with an equivalent sum have very similar vectors. However, things are not so clear in the residual stream right after the attention layer (and before the MLP). So our job is to figure out how to sort the vectors in **resid_mid** into their respective sums/classes without the MLP.

If we manage to get a good understanding of how the different classes can be sorted (based on FFT, SVD), we'll be able to combine this with our knowledge of the self-attention layer to finish the interpretation of the model! 

In [None]:
# 'sorted' reads better than 'ordered'. This is an alias of the same ndarray, not a copy,
# so changes made through either name are visible through the other.
sorted_resid_mid = ordered_resid_mid
print(np.shape(sorted_resid_mid))  # (sum, num, d_model)

Let me first check if the sorting algorithm I wrote is working. To do that I'll use the same algorithm on the output, which will make it very evident whether I made a mistake or not

In [None]:
# Sanity-check the class-sorting scheme on the model's own output: if it is correct, the top
# prediction of every vector in class s must be s.

# Output probabilities at the last position ('=') for every input in `dataset`.
# no_grad: only the values are needed, not an autograd graph over the whole dataset.
with torch.no_grad():
    unsorted_out = model(dataset)[:, -1, :].softmax(dim=-1)

# Group by class s = (a + b) % p, assuming row a * p + b of `dataset` holds the pair (a, b).
# Entry [s, k] is the input with a = k and b = (s - k) % p, which is exactly the order a nested
# "for a: for b:" append loop produces. Vectorised, and avoids shadowing the builtin `sum`.
# NOTE: despite the name, these are probabilities (post-softmax), not logits.
out_by_ab = utils.to_numpy(unsorted_out).reshape(p, p, -1)  # (a, b, prob)
class_idx = np.arange(p)[:, None]   # s
member_idx = np.arange(p)[None, :]  # k (== a)
sorted_logits = out_by_ab[member_idx, (class_idx - member_idx) % p]  # (class, n, prob)

# We expect an ordered list from 0 to 112
print(sorted_logits[:, 0].argmax(axis=-1))

In [None]:
# Batched SVD over the class axis: each class is a (num, d_model) = (113, 128) matrix.
# np.linalg.svd returns (U, S, Vh). The third factor is V^H, so the right singular vectors are its ROWS.
SVD_rm = np.linalg.svd(sorted_resid_mid)
U_rm, S_rm, V_rm = SVD_rm  # U: (113, 113, 113), S: (113, 113), Vh: (113, 128, 128)

# For each singular-value index, summarise its size across classes
df = pd.DataFrame({
    'max': S_rm.max(axis=0),
    'min': S_rm.min(axis=0),
    'mean': S_rm.mean(axis=0),
})
px.line(df)

The 3 graphs are very similar, which implies that the data is somewhat similar too. This is very good news! It means that each sum/class has a similar representation, which makes things more interpretable.

We can see a very sharp drop from 4 to 5, so it seems like a good threshold. If we cannot get what we want from just the first 5 columns we'll use more. (no need, the first 5 carry basically all info)

In [None]:
# Number of leading SVD components to display
up_to = 5
# Number of classes to sample
n_classes = 8

# Sample distinct classes (randint could pick the same class twice). Unseeded, so the
# selection changes between runs.
indexes = np.random.choice(p, size=n_classes, replace=False)

fig_SVD_rm = make_subplots(
    rows=3, cols=n_classes,
    horizontal_spacing=0.02,
    vertical_spacing=0.05,
    row_titles=['U', 'U_fft', 'V'],
    column_titles=[str(i) for i in indexes]
)

for col in range(n_classes):
    # Class shown in this column
    n = indexes[col]

    # First up_to left singular vectors (columns of U) scaled by their singular values: (num, up_to)
    us_n = U_rm[n][:, :up_to] * S_rm[n][:up_to]
    # First up_to right singular vectors scaled the same way. np.linalg.svd returns V^H, whose ROWS
    # are the right singular vectors, so take rows and transpose to (d_model, up_to). Taking columns
    # of V^H would instead mix one coordinate from every singular vector.
    vs_n = V_rm[n][:up_to].T * S_rm[n][:up_to]

    fig_SVD_rm.add_trace(std_heatmap(us_n), row=1, col=col+1)
    # |FFT| of the scaled U columns along the 'num' axis
    fig_SVD_rm.add_trace(std_heatmap(np.abs(np.fft.fft(us_n, axis=0))), row=2, col=col+1)
    fig_SVD_rm.add_trace(std_heatmap(vs_n), row=3, col=col+1)

fig_SVD_rm.update_layout(height=800)
fig_SVD_rm.show()

**OBSERVATION**: They all use (practically) **the same** frequency. This is great news!!! Or is it? I already knew, from the previous analysis, that they all shared similar frequencies. Nonetheless, this does give me some new information. 

**IDEA**: Reduce the FFT to only the positions that repeatedly appear, make a reduced version of the arrays and pass them all through the MLP. If it results in correct probabilities we may have figured it out!

In [None]:
def pass_mlp(resid_mid):
    '''
    Pass a batch through the layer-0 MLP and W_U, i.e. go from resid_mid to logits.

    Computes resid_mid + MLP(resid_mid), where MLP is Linear -> ReLU -> Linear, and unembeds the
    result with W_U and b_U. No LayerNorm is applied anywhere.
    NOTE(review): this presumably relies on the model having no normalization; confirm against its config.

    Args:
        resid_mid: numpy array or torch tensor whose last dimension is d_model. Any leading
            dimensions are treated as batch dimensions, e.g. (113, 113, 128) or (N, 128).

    Returns:
        torch.Tensor of logits with shape (*batch, params['unembed.W_U'].shape[-1]), on the same
        device as the parameters.
    '''
    W_in = params['blocks.0.mlp.W_in']
    # as_tensor accepts numpy arrays and tensors alike (torch.tensor warns when copy-constructing
    # from a tensor), and putting the input on the weights' device avoids a CPU/GPU mismatch
    resid_mid = torch.as_tensor(resid_mid, dtype=torch.float32, device=W_in.device)
    # First linear layer
    mlp_in = resid_mid @ W_in + params['blocks.0.mlp.b_in']
    # ReLU
    mlp_relu = torch.relu(mlp_in)
    # Second linear layer
    mlp_out = mlp_relu @ params['blocks.0.mlp.W_out'] + params['blocks.0.mlp.b_out']
    # Add back onto the residual stream
    resid_out = resid_mid + mlp_out
    # Unembed to logits
    logits = resid_out @ params['unembed.W_U'] + params['unembed.b_U']
    return logits

In [None]:
def reduce_fft(U, freq_index=(41, 8, 8, 19, 54)):
    '''
    Keep only a few frequencies of U along axis 1 and return the real inverse transform.

    For each of the first len(freq_index) columns, keep the DC term plus one chosen frequency
    (and its mirror) of the FFT along axis 1. Every other frequency is zeroed, and all remaining
    columns become zero.

    Args:
        U: real array of shape (batch, n, k), e.g. left singular vectors of shape (batch, 113, 113).
            Must have k >= len(freq_index).
        freq_index: frequency kept in each leading column; column j keeps freq_index[j].
            The default keeps the frequencies found for resid_mid: (41, 8, 8, 19, 54).

    Returns:
        Real ndarray with the same shape as U.
    '''
    U = np.asarray(U)
    n = U.shape[1]
    n_cols = len(freq_index)

    U_fft = np.fft.fft(U, axis=1)

    # Start from an all-zero spectrum and copy in the DC term of each kept column
    U_fft_red = np.zeros_like(U_fft)
    U_fft_red[:, 0, :n_cols] = U_fft[:, 0, :n_cols]

    for col, freq in enumerate(freq_index):
        # The FFT of real input is conjugate-symmetric, so keep the frequency together with its
        # mirror n - freq (mod n). This makes the inverse FFT real up to rounding. The modulo also
        # keeps freq == 0 in range.
        for f in (freq % n, (n - freq) % n):
            U_fft_red[:, f, col] = U_fft[:, f, col]

    # Take the inverse FFT; the imaginary part is rounding noise
    return np.real(np.fft.ifft(U_fft_red, axis=1))

In [None]:
# Reduce the left singular vectors: only the DC term plus one frequency survives in the first 5 columns
U_rm_red = reduce_fft(U_rm)

# Rectangular "diagonal" Sigma per class with shape (class, num, d_model) = (113, 113, 128), so that
# U @ Sigma @ V_rm rebuilds each (num, d_model) matrix (V_rm is V^H from np.linalg.svd).
# Shapes are taken from the SVD outputs instead of hard-coding 113 / 128.
diag_idx = np.arange(S_rm.shape[-1])
S_rm_diag = np.zeros(S_rm.shape + V_rm.shape[-1:])
S_rm_diag[:, diag_idx, diag_idx] = S_rm

# Reduced sorted_resid_mid: (113, 113, 128)
red_srm = U_rm_red @ S_rm_diag @ V_rm

# Flatten the classes into one row per input, ordered (sum, num)
stack_red_srm = einops.rearrange(red_srm, 'sum num d_model -> (sum num) d_model')

# One input per class is enough to compare; plotting all p*p rows is very slow
imshow(
    pass_mlp(stack_red_srm[::p]).softmax(dim=-1),
    title="Probabilities from reduced resid_mid (one input per class)",
)

In [None]:
# Score the reduced resid_mid end to end by passing every input through the MLP + unembed.
# red_srm is (sum, num, d_model), so the probabilities are (sum, num, vocab), and the correct
# answer for every input in row s is token s.
red_mlp_probs = pass_mlp(red_srm).softmax(dim=-1)
# red_mlp_correct_prob[k, s] = probability given to the correct token s by the k-th input of class s
red_mlp_correct_prob = torch.diagonal(red_mlp_probs, dim1=0, dim2=2)
# Fraction of inputs whose most likely token is the correct sum
red_mlp_accuracy = (
    red_mlp_probs.argmax(dim=-1) == torch.arange(p, device=red_mlp_probs.device)[:, None]
).float().mean()
# Both metrics look at the *correct* token specifically. The max of the class-averaged
# distribution would also be high if the model confidently predicted a wrong token.
red_mlp_correct_prob.mean().item(), red_mlp_accuracy.item()

***NICE***

If you zoom in you can see that almost all inputs return the correct probabilities; only a few thin strips give the wrong prediction. This means that our dimensionality reduction **WORKED!**, and interpreting the model will become much easier. Now we're really close to writing down the actual frequencies, I can't believe we're doing so well **FUCK YEAH SCIENCE**.

We've analyzed 'resid_mid', but I wonder what info we can get by analyzing the encoding of '=' and what gets added to it by each head

**IDEA**: Since the encoding of '=' is the same for all inputs, vectors coming from the attention layer (before being added to the encoding of '=') and resid_mid should have all the same properties other than their means. We can compare red_resid_mid with resid_mid, both subtracted from the encoding of '=', to check if my hypothesis holds.

In [None]:
# Defining/mentioning variables we'll use

# Get original and reduced sorted_resid_mid (113, 113, 128) -> (group, num, d_model)
# torch.tensor copies both arrays. NOTE(review): red_srm is float64 (it is a product of float64
# arrays), while sorted_resid_mid is presumably float32. Cast before combining them (TODO confirm).
orig_sorted_resid_mid = torch.tensor(sorted_resid_mid) # Easier to deal with tensor for now
red_sorted_resid_mid = torch.tensor(red_srm) # Renaming because why not

# Get vector of '=' on resid_pre (128,)
# cache['resid_pre', 0] is presumably the layer-0 residual stream before attention. [0, -1] takes the
# first input's last position, which presumably holds the '=' token. The inline comment assumes this
# vector is the same for every input. TODO confirm.
resid_pre_equals = cache['resid_pre', 0][0, -1] # All are equal, so we can pick just the first