# Setup

In [29]:
!nvidia-smi 
import sys
sys.path.append('../')

Mon Oct 31 13:49:26 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Quadro P...  On   | 00000000:00:05.0  On |                  N/A |
| 46%   32C    P8     6W / 105W |    871MiB /  8119MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Imports

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops

import json

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = "vscode"

# EasyTransformer interpretability tooling
from utils.hook_points import HookPoint, HookedRootModule


In [31]:
if torch.cuda.is_available:
  print('Good to go!')
else:
  print('Training might be rather slow')

Good to go!


## Helper Functions 

### Plotting

In [32]:
# This is mostly a bunch of over-engineered mess to hack Plotly into producing 
# the pretty pictures I want, I recommend not reading too closely unless you 
# want Plotly hacking practice
def to_numpy(tensor, flat=False):
    if type(tensor)!=torch.Tensor:
        return tensor
    if flat:
        return tensor.flatten().detach().cpu().numpy()
    else:
        return tensor.detach().cpu().numpy()

def imshow(tensor, xaxis=None, yaxis=None, animation_name='Snapshot', **kwargs):
    # if tensor.shape[0]==p*p:
    #     tensor = unflatten_first(tensor)
    tensor = torch.squeeze(tensor)
    px.imshow(to_numpy(tensor, flat=False), 
              labels={'x':xaxis, 'y':yaxis, 'animation_name':animation_name}, 
              **kwargs).show()

# Set default colour scheme
imshow_pos = partial(imshow, color_continuous_scale='Blues')
# Creates good defaults for showing divergent colour scales (ie with both 
# positive and negative values, where 0 is white)

imshow = partial(imshow, color_continuous_scale='RdBu', color_continuous_midpoint=0.0)
# Presets a bunch of defaults to imshow to make it suitable for showing heatmaps 
# of activations with x axis being input 1 and y axis being input 2.

def imshow_fourier(tensor, fourier_basis_names, title='', animation_name='snapshot', facet_labels=[], **kwargs):
    # Set nice defaults for plotting functions in the 2D fourier basis
    # tensor is assumed to already be in the Fourier Basis
    tensor = torch.squeeze(tensor)
    fig=px.imshow(to_numpy(tensor),
            x=fourier_basis_names, 
            y=fourier_basis_names, 
            labels={'x':'x Component', 
                    'y':'y Component', 
                    'animation_frame':animation_name},
            title=title,
            color_continuous_midpoint=0., 
            color_continuous_scale='RdBu', 
            **kwargs)
    fig.update(data=[{'hovertemplate':"%{x}x * %{y}y<br>Value:%{z:.4f}"}])
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    fig.show()


inputs_heatmap = partial(imshow, xaxis='Input 1', yaxis='Input 2', color_continuous_scale='RdBu', color_continuous_midpoint=0.0)

def line(x, y=None, hover=None, xaxis='', yaxis='', save=False, **kwargs):
    if type(y)==torch.Tensor:
        y = to_numpy(y, flat=True)
    if type(x)==torch.Tensor:
        x=to_numpy(x, flat=True)
    fig = px.line(x, y=y, hover_name=hover, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    fig.show()
    if save:
        fig.write_image(save)

def scatter(x, y, **kwargs):
    px.scatter(x=to_numpy(x, flat=True), y=to_numpy(y, flat=True), **kwargs).show()

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, save=False, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()
    if save:
        fig.write_image(save)

def animate_lines(lines_list, snapshot_index = None, snapshot='snapshot', hover=None, xaxis='x', yaxis='y', **kwargs):
    if type(lines_list)==list:
        lines_list = torch.stack(lines_list, axis=0)
    lines_list = to_numpy(lines_list, flat=False)
    if snapshot_index is None:
        snapshot_index = np.arange(lines_list.shape[0])
    if hover is not None:
        hover = [i for j in range(len(snapshot_index)) for i in hover]
    rows=[]
    for i in range(lines_list.shape[0]):
        for j in range(lines_list.shape[1]):
            rows.append([lines_list[i][j], snapshot_index[i], j])
    df = pd.DataFrame(rows, columns=[yaxis, snapshot, xaxis])
    px.line(df, x=xaxis, y=yaxis, animation_frame=snapshot, range_y=[lines_list.min(), lines_list.max()], hover_name=hover,**kwargs).show()

### Fourier Transforms

In [33]:
def fft2d(mat, fourier_basis):
    # Converts a pxpx... tensor into the 2D Fourier basis.
    # Output has the same shape as the original
    shape = mat.shape
    p = fourier_basis.shape[0]
    fourier_mat = torch.einsum('xyz,fx,Fy->fFz', mat, fourier_basis, fourier_basis)
    return fourier_mat

## Define Groups

### Define Group Parent Class

In [34]:
class Group:
    """
    parent class for all groups
    """
    def __init__(self, index, order, fourier_order):
        self.index = index
        self.order = order 
        self.fourier_order = fourier_order  
        self.compute_multiplication_table()

    def compose(self, x, y):
        raise NotImplementedError

    def inverse(self, x):
        raise NotImplementedError
    
    def compute_multiplication_table(self):
        table = torch.zeros((self.order, self.order), dtype=torch.int64)
        for i in range(self.order):
            for j in range(self.order):
                table[i, j] = self.compose(i, j)
        self.multiplication_table = table
    
    def get_all_data(self, shuffle_seed=False):
        data=torch.zeros((self.order*self.order, 3), dtype=torch.int64)
        for i in range(self.order):
            for j in range(self.order):
                data[i*self.order+j, 0] = i
                data[i*self.order+j, 1] = j
                data[i*self.order+j, 2] = self.multiplication_table[i, j]
        if shuffle_seed:
            torch.manual_seed(shuffle_seed)
            shuffled_indices = torch.randperm(self.order*self.order)
            data = data[shuffled_indices]
        return data
    

    # Fourier basis for cylic-y groups. Should refactor into a different parent class eventually.
    def compute_fourier_basis(self):
        # compute a (frequency, position) tensor encoding the fourier basis
        fourier_basis = []
        fourier_basis.append(torch.ones(self.order)/np.sqrt(self.order))
        fourier_basis_names = ['Const']
        # Note that if p is even, we need to explicitly add a term for cos(kpi), ie 
        # alternating +1 and -1
        for i in range(1, self.fourier_order):
            fourier_basis.append(torch.cos(2*torch.pi*torch.arange(self.order)*i/self.order))
            fourier_basis.append(torch.sin(2*torch.pi*torch.arange(self.order)*i/self.order))
            fourier_basis[-2]/=fourier_basis[-2].norm()
            fourier_basis[-1]/=fourier_basis[-1].norm()
            fourier_basis_names.append(f'cos {i}')
            fourier_basis_names.append(f'sin {i}')

        self.fourier_basis = torch.stack(fourier_basis, dim=0).cuda()
        self.fourier_basis_names = fourier_basis_names  
    
    def animate_fourier_basis(self):
        animate_lines(self.fourier_basis, snapshot_index=self.fourier_basis_names, snapshot='Fourier Component', title='Graphs of Fourier Components (Use Slider)')
    


### Define Individual Groups

In [35]:

class CyclicGroup(Group):
    def __init__(self, index):
        super().__init__(index = index, order = index, fourier_order = index//2+1)        

    def compose(self, x, y):
        return (x+y)%self.order

    def inverse(self, x):
        return -x%self.order
        

class DihedralGroup(Group):
    """
    Dihedral group of order 2*index. First index elements are rotations, second are reflections.
    i.e. indexed as [e, r, r^2, ..., r^p, s, rs, r^2s, ..., r^ps]
    """
    def __init__(self, index):
        super().__init__(index = index, order = 2*index, fourier_order = index)        

    def idx_to_cpts(self, x):
        r = x % self.index
        # this could be rewritten in a single line
        if x >= self.index: 
            s = 1
        else: 
            s = 0
        return r, s

    def cpts_to_idx(self, r, s):
        return r + s*self.index

    def compose(self, x, y):
        x_r, x_s = self.idx_to_cpts(x)
        y_r, y_s = self.idx_to_cpts(y)
        if x_s == 0:
            z_r = (x_r + y_r) % self.index
            z_s = y_s
        elif x_s == 1:
            z_r = (x_r - y_r) % self.index
            z_s = (1 + y_s) % 2
        return self.cpts_to_idx(z_r, z_s)

#class SymettricGroup(Group):

## Define Model

In [36]:
class BilinearNet(HookedRootModule):
    """
    A completely linear network. W_E_a and W_E_b are embedding layers, whose outputs are elementwise multiplied. The result is unembedded by W_U.
    """
    def __init__(self, layers, n, seed=0):
        # embed_dim : dimension of the embedding
        # n : group order
        super().__init__()
        torch.manual_seed(seed)

        embed_dim = layers['embed_dim']
        
        # initialise parameters
        self.W_E_a = nn.Parameter(torch.randn(n, embed_dim)/np.sqrt(embed_dim))
        self.W_E_b = nn.Parameter(torch.randn(n, embed_dim)/np.sqrt(embed_dim))
        self.W_U = nn.Parameter(torch.randn(embed_dim, n)/np.sqrt(embed_dim))

        self.x_embed = HookPoint()
        self.y_embed = HookPoint()
        self.product = HookPoint()
        self.out = HookPoint()
        
        # We need to call the setup function of HookedRootModule to build an 
        # internal dictionary of modules and hooks, and to give each hook a name
        super().setup()

    def forward(self, data):
        x = data[:, 0] # (batch) 
        x_embed = self.x_embed(self.W_E_a[x]) # (batch, embed_dim)
        y = data[:, 1]
        y_embed = self.y_embed(self.W_E_b[y]) # (batch, embed_dim)
        product = self.product(x_embed * y_embed) # (batch, embed_dim)
        out = self.out(product @ self.W_U) # (batch, n)
        return out

class OneLayerMLP(HookedRootModule):
    def __init__(self, layers, n, seed=0):
        # embed_dim: dimension of the embedding
        # hidden : hidden dimension size
        # n : group order
        super().__init__()
        torch.manual_seed(seed)

        embed_dim = layers['embed_dim']
        hidden = layers['hidden_dim']

        # xavier initialise parameters
        self.W_E_a = nn.Parameter(torch.randn(n, embed_dim)/np.sqrt(embed_dim))
        self.W_E_b = nn.Parameter(torch.randn(n, embed_dim)/np.sqrt(embed_dim))
        self.W = nn.Parameter(torch.randn(2*embed_dim, hidden)/np.sqrt(2*embed_dim))
        self.relu = nn.ReLU()
        self.W_U = nn.Parameter(torch.randn(hidden, n)/np.sqrt(hidden))

        # hookpoints
        self.x_embed = HookPoint()
        self.y_embed = HookPoint()
        self.embed_stack = HookPoint()
        self.hidden = HookPoint()
        self.out = HookPoint()

        # We need to call the setup function of HookedRootModule to build an 
        # internal dictionary of modules and hooks, and to give each hook a name
        super().setup()

    def forward(self, data):
        x = data[:, 0] # (batch)
        x_embed = self.x_embed(self.W_E_a[x]) # (batch, embed_dim)
        y = data[:, 1] # (batch)
        y_embed = self.y_embed(self.W_E_b[y]) # (batch, embed_dim)
        embed_stack = self.embed_stack(torch.hstack((x_embed, y_embed))) # (batch, 2*embed_dim)
        hidden = self.hidden(self.relu(embed_stack @ self.W)) # (batch, hidden)
        out = self.out(hidden @ self.W_U) # (batch, n)
        return out

## Generate Data and Loss Functions 

In [37]:
def generate_train_test_data(group, frac_train, seed=False):
    data = group.get_all_data(seed).cuda()
    train_size = int(frac_train*data.shape[0])
    train = data[:train_size]
    test = data[train_size:]
    train_data = train[:, :2]
    train_labels = train[:, 2]
    test_data = test[:, :2]
    test_labels = test[:, 2]
    return train_data, test_data, train_labels, test_labels

def loss_fn(logits, labels):
    loss = F.cross_entropy(logits, labels)
    return loss

## Model Training


In [38]:
def load_cfg(task_dir):
    cfg_file = open(f"{task_dir}/cfg.json")
    cfg = json.load(cfg_file)

    seed = cfg['seed'] # TODO: don't set seed to 0, or generating train test data gets broken (not shuffled) - fix this
    frac_train = cfg['frac_train']
    layers = cfg['layers']
    lr = cfg['lr']
    group_param = cfg['group_parameter']
    weight_decay = cfg['weight_decay']
    num_epochs = cfg['num_epochs']
    group_type = eval(cfg['group'])
    architecture_type = eval(cfg['model'])
    return seed, frac_train, layers, lr, group_param, weight_decay, num_epochs, group_type, architecture_type

In [39]:
train = False

task_dir = "1L_MLP_dihedral"
seed, frac_train, layers, lr, group_param, weight_decay, num_epochs, group_type, architecture_type = load_cfg(task_dir)
group = group_type(group_param)

if train:

    train_data, test_data, train_labels, test_labels = generate_train_test_data(group, frac_train, seed)

    train_losses = []
    test_losses = []
    train_accs = []
    test_accs = []

    model = architecture_type(layers, group.order, seed)
    model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in tqdm(range(num_epochs)):
        train_logits = model(train_data)
        train_loss = loss_fn(train_logits, train_labels)
        train_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_losses.append(train_loss.item())
        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())
            train_acc = (train_logits.argmax(1)==train_labels).sum()/len(train_labels)
            test_acc = (test_logits.argmax(1)==test_labels).sum()/len(test_labels)
            train_accs.append(train_acc.item())
            test_accs.append(test_acc.item())
        if epoch%1000 == 0:
            print(f"Epoch:{epoch}, Train: L: {train_losses[-1]:.6f} A: {train_accs[-1]*100:.4f}, Test: L: {test_losses[-1]:.6f} A: {test_accs[-1]*100:.4f}%")
        #if epoch%10000 == 0 and epoch>0:
            #lines([train_losses, test_losses], log_y=True, labels=['train loss', 'test loss'])
            #lines([train_accs, test_accs], log_y=False, labels=['train acc', 'test acc'])

    # Save model
    torch.save(model.state_dict(), f"{task_dir}/model.pt")
    lines([train_losses, test_losses], log_y=True, labels=['train loss', 'test loss'], save=f"{task_dir}/loss.png")
    lines([train_accs, test_accs], log_y=False, labels=['train acc', 'test acc'], save=f"{task_dir}/acc.png")


# Interpretability

Let's reverse engineer the 1 layer MLP trained on modular addition

## Interpretability Set Up 

In [51]:
task_dir = "1L_MLP_dihedral"
seed, frac_train, width, lr, group_param, weight_decay, num_epochs, group_type, architecture_type = load_cfg(task_dir)
group = group_type(group_param)
model = architecture_type(width, group.order, seed).cuda()
model.load_state_dict(torch.load(f"{task_dir}/model.pt"))
model.eval()
all_data, all_labels, _, _ = generate_train_test_data(group, frac_train = 1)
logits, activations = model.run_with_cache(all_data, return_cache_object=False)

## Interpretability Helper Functions

## Look at embeddings

The embedding is sparse in the Fourier basis, and throws away all Fourier components apart from a handful of frequencies (the number of frequencies and their values are arbitrary, and vary between training runs)

I call the frequencies with non-trivial norm the key frequencies (here: 2, 24, 48, 53, 64). I don't know whats going on with 17 and 33.

In [52]:
group.compute_fourier_basis()
group.animate_fourier_basis()

In [53]:
# Embeddings
# W_E (position, embed_dim), fourier_basis (frequency, position)

W_E_a_norm_embeddings = (model.W_E_a.T @ group.fourier_basis.T).pow(2).sum(0)
W_E_b_norm_embeddings = (model.W_E_b.T @ group.fourier_basis.T).pow(2).sum(0)

line(W_E_a_norm_embeddings, 
      hover=group.fourier_basis_names,
      title='Norm of embedding coef of each Fourier Component in W_E_a',
      xaxis='Fourier Component',
      yaxis='Norm')

line(W_E_b_norm_embeddings, 
    hover=group.fourier_basis_names,
    title='Norm of embedding coef of each Fourier Component in W_E_b',
    xaxis='Fourier Component',
    yaxis='Norm')

lines([W_E_a_norm_embeddings, W_E_b_norm_embeddings], hover=group.fourier_basis_names, title=f'Norm of embedding coef of each Fourier Component in W_E_a/b', labels=['W_E_a', 'W_E_b'], xaxis='Fourier Component', yaxis='norm')

In [54]:
# Unembedding - here we are taking the 
# W_U (embed_dim, position), fourier_basis (frequency, position)

W_U_norm_embeddings = (model.W_U @ group.fourier_basis.T).pow(2).sum(0)

line(W_U_norm_embeddings, 
      hover=group.fourier_basis_names,
      title='Norm of embedding coef of each Fourier Component in W_U',
      xaxis='Fourier Component',
      yaxis='Norm')

## Understanding Activation Patterns

Let's first look at the relevant activations (the hidden and output layers) in the standard basis.

### Understanding Individual Neuron Activations

In [55]:
# first, reshape from (batch) to (group.order, group.order) so that we may easily fourier transform on each of the two input dimensions.

plot_neurons = [0,35,67]
activations_to_plot = ['hidden', 'out']
for key, value in activations.items():
    activations[key] = value.reshape(group.order, group.order, -1)

for key, value in activations.items():
    if key in activations_to_plot:
        for neuron in plot_neurons:
            imshow(activations[key][:, :, neuron], title=f'{key} activations, neuron {neuron}', xaxis='position 1', yaxis='position 2')


### Understanding Individual Hidden Layer Neurons

Given the embeddings are fourier, and the activations are fourier - we should be able to understand the internal activations in the fourier basis. We see below that each neuron seems to activate for a linear combination of the following 9 terms for some frequency $w$: 

$1, \cos(wx), \sin(wx), \cos(wy), \sin(wy), \cos(wx)\cos(wy), \cos(wx)\sin(wy), \sin(wx)\cos(wy), \sin(wx)\sin(wy)$. 

In particular, the neurons cluster into different **representations**. 

Neuron 67 seems to activate on three different frequencies...

In [56]:
hidden_activations = activations['hidden']
fourier_activations = fft2d(hidden_activations, group.fourier_basis)

for neuron in plot_neurons:
    imshow_fourier(fourier_activations[:, :, neuron], group.fourier_basis_names, title=f'hidden activations, neuron {neuron}')

### Understanding all hidden layer neurons

Let's sum up the activations of all neurons in the network, and look at the result. Note we first center the activation to have mean zero. From here on, everything is linear, as the unembedding is linear. Each neurons average value corresponds to a constant vector added to the logits, so the total effect of all neurons is adding a constant vector to the logits, which doesn't affect probabilites.

In [57]:
hidden_activations_centred = hidden_activations - einops.reduce(hidden_activations, 'order1 order2 neuron -> 1 neuron', 'mean')
fourier_activations_centred = fft2d(hidden_activations_centred, group.fourier_basis)
imshow_fourier(fourier_activations_centred.pow(2).mean(2), group.fourier_basis_names, title=f'Average Norm of 2d Fourier Coefficients of hidden activations (excl constant)')

## Neuron Clustering

Each neuron seems to activate for a particular frequency. Let's see how much of the neurons activation and variance this frequency can explain.

We then the fraction of the variance explained.



In [58]:
def extract_freq_2d(tensor, freq):
    # Takes in a pxpx... tensor, returns a 3x3x... tensor of the linear and quadratic terms of frequency freq
    index_1d = [0, 2*freq-1, 2*freq]
    # Some dumb manipulation to use fancy array indexing rules
    # Gets the rows and columns in index_1d
    return tensor[[[i]*3 for i in index_1d], [index_1d]*3]

neuron_freqs = []
neuron_frac_explained = []
for n in range(fourier_activations.shape[2]):
    best_frac_explained = -1
    best_freq = -1
    for freq in range(1, group.fourier_basis.shape[0]//2):
        frac_explained = (extract_freq_2d(fourier_activations[:, :, n], freq).pow(2).sum() / fourier_activations[:, :, n].pow(2).sum()).item()
        if frac_explained > best_frac_explained:
            best_frac_explained = frac_explained
            best_freq = freq
    neuron_freqs.append(best_freq)
    neuron_frac_explained.append(best_frac_explained)

scatter(x=neuron_freqs, 
        y=neuron_frac_explained, 
        labels={'x':'Neuron frequency', 
                'y':'Frac explained'},
        color_continuous_scale='Viridis')

neuron_freqs = np.array(neuron_freqs)
neuron_frac_explained = np.array(neuron_frac_explained)
key_freqs, neuron_freq_counts = np.unique(neuron_freqs, return_counts=True)


for i in range(len(key_freqs)):
    print(f'Cluster {i}: freq {key_freqs[i]}. {(neuron_freqs==key_freqs[i]).sum()} neurons')

Cluster 0: freq 1. 4 neurons
Cluster 1: freq 4. 4 neurons
Cluster 2: freq 10. 9 neurons
Cluster 3: freq 11. 1 neurons
Cluster 4: freq 12. 1 neurons
Cluster 5: freq 13. 2 neurons
Cluster 6: freq 15. 1 neurons
Cluster 7: freq 16. 8 neurons
Cluster 8: freq 17. 7 neurons
Cluster 9: freq 20. 3 neurons
Cluster 10: freq 23. 5 neurons
Cluster 11: freq 24. 7 neurons
Cluster 12: freq 25. 6 neurons
Cluster 13: freq 26. 8 neurons
Cluster 14: freq 31. 3 neurons
Cluster 15: freq 33. 3 neurons
Cluster 16: freq 42. 5 neurons
Cluster 17: freq 44. 8 neurons
Cluster 18: freq 45. 7 neurons
Cluster 19: freq 46. 8 neurons
Cluster 20: freq 48. 9 neurons
Cluster 21: freq 50. 9 neurons
Cluster 22: freq 51. 10 neurons


### TODO: Validations

To validate that the neurons in the specific frequency clusters aren't doing anything with the other frequencies, we can set terms for any other frequency to zero. The resulting loss is close to baseline.

## Understanding Logit Computation

TLDR: The network uses $W_U$ to cancel out all 2D Fourier components other than the directions corresponding to $\cos(w(x+y)),\sin(w(x+y))$, and then multiplies these directions by $\cos(wz),\sin(wz)$ respectively and sums to get the output logits.

To see that it cancels out other directions, we can transform the neuron activations and logits to the 2D Fourier Basis, and look at the norm of the vector corresponding to each fourier component - **we see that the quadratic terms have *much* higher norm in the logits than neuron activations, and linear terms are close to zero.**

In [59]:
imshow_fourier(fourier_activations_centred.pow(2).mean(2), group.fourier_basis_names, title=f'Average Norm of 2d Fourier Coefficients of hidden activations (excl constant)')
fourier_logits = fft2d(logits.reshape(group.order, group.order, -1), group.fourier_basis)
imshow_fourier(fourier_logits.pow(2).mean(2), group.fourier_basis_names, title='Average Norm of Fourier Components of Logits')

To see it really is calculating $\cos(w(x+y))\cos(wz)+\sin(w(x+y))\sin(wz)$, we first recall the above plot showing norms of fourier components in the unembedding matrix is sparse. The unembedding matrix has shape (hidden, group.order). Let's redo that plot, but restrict to hidden rows corresponding to neurons of a fixed frequency. We should see that the frequencies match.

In [60]:
W_U_norm_embeddings_freq = []
for freq in key_freqs:
      W_U_norm_embeddings_freq.append((model.W_U[neuron_freqs==freq, :] @ group.fourier_basis.T).pow(2).sum(0))

lines(W_U_norm_embeddings_freq, 
      hover=group.fourier_basis_names,
      title=f'Norm of embedding coef of each Fourier Component in W_U on frequency x neurons',
      labels=[str(x) for x in key_freqs],
      xaxis='Fourier Component',
      yaxis='Norm')

Finally, we should check that the direction of $W_U$ corresponding to $\cos(wz)$ extracts the $\cos(w(x+y))$ component of the neuron activations, and likewise for sin. We aim to show the network is computing $\cos(w(x+y-z)) = \cos(w(x+y))\cos(wz) + \sin(w(x+y))\sin(wz)$

So, when plotted the direction corresponding to $cos(wz)$ in the 2D fourier basis, we should see a coeffiecent of $+C$ on $\cos(wx)\cos(wy)$, and $-C$ on $\sin(wx)\sin(wy)$.
And, when plotting the direction corresponding to $\sin(wz)$ in the 2D fourier basis, we should see a coeffiecent $+S$ on $\cos(wx)\sin(wy)$, and $S$ on $\sin(wx)\cos(wy)$.

TODO: Take a fourier logit frequency, and plot as a function of x and y the corresponding components. Might help see what the fucky frequencies are doing. They do have small norm compared to the actual frequencies though.

In [61]:
fourier_unembed = model.W_U @ group.fourier_basis.T # (hidden, position) @ (position, frequency) -> (hidden, frequency)
fourier_logits = hidden_activations @ fourier_unembed # (batch, hidden) @ (hidden, frequency) -> (batch, frequency)
fourier_logits = fourier_logits.reshape(group.order, group.order, group.order) 

for freq in key_freqs[:2]:
    imshow_fourier(fft2d(fourier_logits[:, :, 2*freq-1].unsqueeze(dim=2), group.fourier_basis), group.fourier_basis_names, title=f'Component of logits in direction cos({freq})')
    imshow_fourier(fft2d(fourier_logits[:, :, 2*freq].unsqueeze(dim=2), group.fourier_basis), group.fourier_basis_names, title=f'Component of logits in direction sin({freq})')


RuntimeError: shape '[106, 106, 106]' is invalid for input of size 1179780


## Look at cosine similarity to representations

### Over the whole model

### Over individual neurons