# Setup

In [1]:
!nvidia-smi 
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

Mon Nov  7 12:40:30 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Quadro P...  On   | 00000000:00:05.0  On |                  N/A |
| 46%   30C    P8     6W / 105W |    787MiB /  8119MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops

import json

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *


In [3]:
# `torch.cuda.is_available` is a function and has to be called: the bare
# function object is always truthy, so without the call the CPU warning below
# could never be printed.
cuda_ready = torch.cuda.is_available()
if cuda_ready:
  print('Good to go!')
else:
  print('Training might be rather slow')

Good to go!


# Interpretability

Let's reverse engineer the 1 layer MLP trained on dihedral composition.

## Interpretability Set Up 

In [4]:
task_dir = "1L_MLP_dihedral" #_cached
# Unpack the run configuration stored for this task directory.
seed, frac_train, width, lr, group_param, weight_decay, num_epochs, group_type, architecture_type = load_cfg(task_dir)
group = group_type(group_param)

# Use the GPU when there is one, otherwise fall back to the CPU instead of
# crashing on a hard-coded `.cuda()`. `map_location` lets a checkpoint that was
# saved from a GPU be deserialised on a CPU-only machine.
# NOTE(review): this assumes the project helpers below do not place tensors on
# CUDA themselves -- confirm before relying on the CPU path.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = architecture_type(width, group.order, seed).to(device)
model.load_state_dict(torch.load(f"{task_dir}/model.pt", map_location=device))
model.eval()

# frac_train = 1 requests the full dataset as the "train" split; the two test
# outputs are discarded.
all_data, _, all_labels, _ = generate_train_test_data(group, frac_train = 1)
logits, activations = model.run_with_cache(all_data, return_cache_object=False)
# Keep the logits alongside the cached activations so later cells can treat
# them uniformly.
activations['logits'] = logits

## Interpretability Helper Functions

## Look at embeddings

The embedding is sparse in the Fourier basis, and throws away all Fourier components apart from a handful of frequencies (the number of frequencies and their values are arbitrary, and vary between training runs)

I call the frequencies with non-trivial norm the key frequencies (here: 8, 12, 21, 23).

In [5]:
# Presumably populates `group.fourier_basis` / `group.fourier_basis_names`,
# which the cells below read -- confirm against utils.groups.
group.compute_fourier_basis()
# Last expression of the cell, so whatever this call returns is what the
# notebook displays.
group.animate_fourier_basis()

In [6]:
# Embeddings
# W_E (position, embed_dim), fourier_basis (frequency, position)


def _fourier_power(weight_half):
    """Return, per Fourier component, the sum over embedding dims of the squared coefficient."""
    return (weight_half.T @ group.fourier_basis.T).pow(2).sum(0)


# Each embedding matrix is split along the position axis into two equal halves,
# named for the rotation and reflection elements.
W_E_a_rot, W_E_a_ref = torch.tensor_split(model.W_E_a, 2, dim=0)
W_E_b_rot, W_E_b_ref = torch.tensor_split(model.W_E_b, 2, dim=0)

W_E_a_rot_norm_embeddings = _fourier_power(W_E_a_rot)
W_E_a_ref_norm_embeddings = _fourier_power(W_E_a_ref)
W_E_a_norm_embeddings = W_E_a_rot_norm_embeddings + W_E_a_ref_norm_embeddings

W_E_b_rot_norm_embeddings = _fourier_power(W_E_b_rot)
W_E_b_ref_norm_embeddings = _fourier_power(W_E_b_ref)
W_E_b_norm_embeddings = W_E_b_rot_norm_embeddings + W_E_b_ref_norm_embeddings

# One plot per embedding matrix, then both overlaid for comparison.
line(
    W_E_a_norm_embeddings,
    hover=group.fourier_basis_names,
    title='Norm of embedding coef of each Fourier Component in W_E_a',
    xaxis='Fourier Component',
    yaxis='Norm',
)

line(
    W_E_b_norm_embeddings,
    hover=group.fourier_basis_names,
    title='Norm of embedding coef of each Fourier Component in W_E_b',
    xaxis='Fourier Component',
    yaxis='Norm',
)

lines(
    [W_E_a_norm_embeddings, W_E_b_norm_embeddings],
    hover=group.fourier_basis_names,
    title=f'Norm of embedding coef of each Fourier Component in W_E_a/b',
    labels=['W_E_a', 'W_E_b'],
    xaxis='Fourier Component',
    yaxis='norm',
)

In [7]:
# Unembedding
# W_U (embed_dim, position), fourier_basis (frequency, position)

W_U_rot, W_U_ref = torch.tensor_split(model.W_U, 2, dim=1)

# NOTE(review): the two halves are added *before* the coefficients are squared,
# whereas the W_E cell squares each half separately and then adds the results.
# Presumably intentional -- confirm, since opposite-sign components cancel here.
W_U_halves_summed = W_U_rot + W_U_ref
W_U_fourier_coefs = W_U_halves_summed @ group.fourier_basis.T
W_U_norm_embeddings = W_U_fourier_coefs.pow(2).sum(0)

line(
    W_U_norm_embeddings,
    hover=group.fourier_basis_names,
    title='Norm of embedding coef of each Fourier Component in W_U',
    xaxis='Fourier Component',
    yaxis='Norm',
)

## Understanding Activation Patterns

Let's first look at the relevant activations (the hidden and output layers) in the standard basis.

### Understanding Individual Neuron Activations

In [8]:
# Reshape every cached activation so its leading axes are
# (group.order, group.order); this makes it easy to Fourier transform along
# each of the two inputs separately.

plot_neurons = [0, 1, 2]
activations_to_plot = ['hidden', 'logits']

# The cache entries are overwritten in place with their reshaped versions.
for key in list(activations.keys()):
    activations[key] = activations[key].reshape(group.order, group.order, -1)

# Plot the selected neurons for the selected activations, in cache order.
for key, value in activations.items():
    if key not in activations_to_plot:
        continue
    for neuron in plot_neurons:
        imshow(value[:, :, neuron], title=f'{key} activations, neuron {neuron}', input1='position 1', input2='position 2')


### Understanding Individual Hidden Layer Neurons

Given that both the embeddings and the activations are naturally described in the Fourier basis, we should be able to understand the internal activations in that basis too. We see below that each neuron seems to activate for a linear combination of the following 9 terms for some frequency $w$: 

$1, \cos(wx), \sin(wx), \cos(wy), \sin(wy), \cos(wx)\cos(wy), \cos(wx)\sin(wy), \sin(wx)\cos(wy), \sin(wx)\sin(wy)$. 

In particular, the neurons cluster into different **representations**. 

Neuron 67 seems to activate on three different frequencies...

In [9]:
def untile2d(mat, p):
    """Cut the first two axes of ``mat`` at index ``p`` and stack the four quadrants.

    The new leading axis orders the quadrants as top-left, top-right,
    bottom-left, bottom-right. Any trailing axes of ``mat`` are carried along
    unchanged. ``torch.stack`` requires the four quadrants to share one shape.
    """
    halves = (slice(None, p), slice(p, None))
    quadrants = [mat[rows, cols] for rows in halves for cols in halves]
    return torch.stack(quadrants)

def tile2d(tensor):
    """Reassemble four stacked quadrants into one 2x2 block grid.

    ``tensor[0]`` and ``tensor[1]`` are joined side by side to form the top
    row of blocks, ``tensor[2]`` and ``tensor[3]`` form the bottom row, and the
    two rows are then joined along the first axis.
    """
    top_row = torch.cat([tensor[0], tensor[1]], dim=1)
    bottom_row = torch.cat([tensor[2], tensor[3]], dim=1)
    return torch.cat([top_row, bottom_row], dim=0)

# Relies on an earlier cell having reshaped activations['hidden'] so that its
# two leading axes are (group.order, group.order); untile2d then cuts each of
# those axes at group.index.
hidden_activations = untile2d(activations['hidden'], group.index)
fourier_activations = fft2d(hidden_activations, group.fourier_basis, stack=True)

# Human-readable name of each quadrant, in the order untile2d stacks them.
region_to_eng = dict(enumerate(['rot, rot', 'rot, ref', 'ref, rot', 'ref, ref']))

plot_neurons = list(range(9))
# Only the second quadrant is plotted here.
region = 1
for neuron in plot_neurons:
    imshow_fourier(fourier_activations[region, :, :, neuron], group.fourier_basis_names, title=f'hidden activations, neuron {neuron}, region {region_to_eng[region]}')

### Understanding all hidden layer neurons

Let's sum up the activations of all neurons in the network, and look at the result. Note we first center the activations to have mean zero. From here on, everything is linear, as the unembedding is linear. Each neuron's average value corresponds to a constant vector added to the logits, so the total effect of all the neurons' averages is adding a constant vector to the logits, which doesn't affect probabilities.

In [10]:
# Mean of every neuron over all regions and both inputs, kept with shape
# (1, neuron) so that it broadcasts against hidden_activations.
neuron_means = einops.reduce(hidden_activations, 'region order1 order2 neuron -> 1 neuron', 'mean')
hidden_activations_centred = hidden_activations - neuron_means
fourier_activations_centred = fft2d(hidden_activations_centred, group.fourier_basis, stack=True)

# Average the squared coefficients over regions (axis 0) and neurons (last axis).
centred_power = fourier_activations_centred.pow(2).mean([0, -1])
imshow_fourier(centred_power, group.fourier_basis_names, title=f'Average Norm of 2d Fourier Coefficients of hidden activations (excl constant)')

## Neuron Clustering

Each neuron seems to activate for a particular frequency. Let's see how much of each neuron's activation and variance this frequency can explain.



In [11]:
def extract_freq_2d(tensor, freq):
    """Pick out the 3x3 block of terms belonging to a single frequency.

    Takes a p x p x ... tensor and returns the 3 x 3 x ... sub-tensor built
    from rows and columns ``0``, ``2*freq - 1`` and ``2*freq`` -- the constant,
    linear and quadratic terms of frequency ``freq``.
    """
    wanted = [0, 2 * freq - 1, 2 * freq]
    # Select the three rows first, then the three columns of that result.
    picked_rows = tensor[wanted]
    return picked_rows[:, wanted]

# For every hidden neuron, find the single frequency whose constant, linear and
# quadratic 2D Fourier terms explain the largest share of the neuron's total
# squared Fourier coefficients, summed over all regions.
neuron_freqs = []
neuron_frac_explained = []

# extract_freq_2d reads basis rows 2*freq-1 and 2*freq, so the largest usable
# frequency is (rows - 1) // 2. The previous bound, range(1, rows // 2),
# stopped one short and never tested the highest frequency (rows 51/52 when
# the basis has 53 rows).
max_freq = (group.fourier_basis.shape[0] - 1) // 2

for n in range(fourier_activations.shape[3]):
    best_frac_explained = -1
    best_freq = -1
    neuron_norm = fourier_activations[:, :, :, n].pow(2).sum()
    for freq in range(1, max_freq + 1):
        # Accumulate the explained fraction over the leading (region) axis.
        frac_explained = 0
        for region in fourier_activations:
            frac_explained += (extract_freq_2d(region[:, :, n], freq).pow(2).sum() / neuron_norm).item()
        if frac_explained > best_frac_explained:
            best_frac_explained = frac_explained
            best_freq = freq
    neuron_freqs.append(best_freq)
    neuron_frac_explained.append(best_frac_explained)

scatter(x=neuron_freqs, 
        y=neuron_frac_explained, 
        labels={'x':'Neuron frequency', 
                'y':'Frac explained'},
        color_continuous_scale='Viridis')

# Later cells index with boolean masks such as `neuron_freqs == freq`, which
# needs arrays rather than lists.
neuron_freqs = np.array(neuron_freqs)
neuron_frac_explained = np.array(neuron_frac_explained)
key_freqs, neuron_freq_counts = np.unique(neuron_freqs, return_counts=True)


# np.unique already returned the size of every cluster, so reuse those counts.
for i, (cluster_freq, cluster_size) in enumerate(zip(key_freqs, neuron_freq_counts)):
    print(f'Cluster {i}: freq {cluster_freq}. {cluster_size} neurons')

Cluster 0: freq 2. 18 neurons
Cluster 1: freq 3. 16 neurons
Cluster 2: freq 9. 10 neurons
Cluster 3: freq 10. 15 neurons
Cluster 4: freq 11. 12 neurons
Cluster 5: freq 19. 15 neurons
Cluster 6: freq 21. 14 neurons
Cluster 7: freq 25. 28 neurons


### TODO: Validations

To validate that the neurons in the specific frequency clusters aren't doing anything with the other frequencies, we can set terms for any other frequency to zero. The resulting loss is close to baseline.

## Understanding Logit Computation

TLDR: The network uses $W_U$ to cancel out all 2D Fourier components other than the directions corresponding to $\cos(w(x+y)),\sin(w(x+y))$, and then multiplies these directions by $\cos(wz),\sin(wz)$ respectively and sums to get the output logits.

To see that it cancels out other directions, we can transform the neuron activations and logits to the 2D Fourier Basis, and look at the norm of the vector corresponding to each fourier component - **we see that the quadratic terms have *much* higher norm in the logits than neuron activations, and linear terms are close to zero.**

In [12]:
# Hidden-activation plot repeated here so it can be compared with the logits plot below.
hidden_power = fourier_activations_centred.pow(2).mean([0, -1])
imshow_fourier(hidden_power, group.fourier_basis_names, title=f'Average Norm of 2d Fourier Coefficients of hidden activations (excl constant)')

# Same pipeline for the logits: reshape to an (order, order) grid, split into
# quadrants, then transform both input axes.
logits_grid = logits.reshape(group.order, group.order, -1)
fourier_logits = fft2d(untile2d(logits_grid, group.index), group.fourier_basis, stack=True)
logits_power = fourier_logits.pow(2).mean([0, -1])
imshow_fourier(logits_power, group.fourier_basis_names, title='Average Norm of Fourier Components of Logits')

To see it really is calculating $\cos(w(x+y))\cos(wz)+\sin(w(x+y))\sin(wz)$, we first recall the above plot showing norms of fourier components in the unembedding matrix is sparse. The unembedding matrix has shape (hidden, group.order). Let's redo that plot, but restrict to hidden rows corresponding to neurons of a fixed frequency. We should see that the frequencies match.

In [13]:
# Split W_U once: the two halves do not depend on the frequency being plotted,
# so there is no reason to recompute them on every loop iteration.
W_U_rot, W_U_ref = torch.tensor_split(model.W_U, 2, dim=1)

W_U_norm_embeddings_freq = []
for freq in key_freqs:
    # Boolean mask selecting the rows of W_U whose neuron was assigned to this
    # frequency; built once and shared by both halves.
    in_cluster = neuron_freqs == freq
    W_U_cluster = W_U_rot[in_cluster, :] + W_U_ref[in_cluster, :]
    W_U_norm_embeddings_freq.append((W_U_cluster @ group.fourier_basis.T).pow(2).sum(0))

lines(W_U_norm_embeddings_freq, 
      hover=group.fourier_basis_names,
      title=f'Norm of embedding coef of each Fourier Component in W_U on frequency x neurons',
      labels=[str(x) for x in key_freqs],
      xaxis='Fourier Component',
      yaxis='Norm')

Finally we check what precisely the network is computing. It is doing four different things. Composition of 

1) rot rot
2) rot ref
3) ref rot
4) ref ref

It does this using a number of frequencies (we just show 1)

There are two types of output logit 
1) rotations
2) reflections

and two relevant directions in the logit space corresponding to each frequency (see above)

1) cos(z) 
2) sin(z)





In [17]:
logits_rot, logits_ref = torch.tensor_split(logits, 2, dim=-1)

# Fourier transform the output axis of the first half of the logits:
# (batch, position) @ (position, frequency) -> (batch, frequency),
# then lay the batch axis out as an (order, order) grid of input pairs.
fourier_logits = (logits_rot @ group.fourier_basis.T).reshape(group.order, group.order, group.index)
untiled_fourier_logits = untile2d(fourier_logits, group.index)
print(untiled_fourier_logits.shape)


for idx, region in region_to_eng.items():
    region_logits = untiled_fourier_logits[idx]

    # Only the last key frequency is shown.
    for freq in key_freqs[-1:]:
        cos_direction = region_logits[:, :, 2*freq-1].unsqueeze(dim=-1)
        sin_direction = region_logits[:, :, 2*freq].unsqueeze(dim=-1)

        # now Fourier transform the input space too
        cos_logits_freq = fft2d(cos_direction, group.fourier_basis)
        sin_logits_freq = fft2d(sin_direction, group.fourier_basis)

        imshow_fourier(cos_logits_freq, group.fourier_basis_names, title=f'Component of logits in direction cos({freq}) on region {region}')
        imshow_fourier(sin_logits_freq, group.fourier_basis_names, title=f'Component of logits in direction sin({freq}) on region {region}')


torch.Size([4, 53, 53, 53])


This is weird. I would have expected the logits in regions where parity disallows the result to be identically zero. Let's ablate and see what happens to loss.

In [15]:
logits_rot, logits_ref = torch.tensor_split(logits, 2, dim=-1)

# Fourier transform the output axis of each half:
# (batch, position) @ (position, frequency) -> (batch, frequency),
# then lay the batch axis out as an (order, order) grid of input pairs.
fourier_logits_rot = (logits_rot @ group.fourier_basis.T).reshape(group.order, group.order, group.index)
fourier_logits_ref = (logits_ref @ group.fourier_basis.T).reshape(group.order, group.order, group.index)

untiled_fourier_logits_rot = untile2d(fourier_logits_rot, group.index)
untiled_fourier_logits_ref = untile2d(fourier_logits_ref, group.index)

# Work on copies so the un-ablated tensors stay available for comparison.
untiled_fourier_logits_rot_new = untiled_fourier_logits_rot.clone()
untiled_fourier_logits_ref_new = untiled_fourier_logits_ref.clone()

for idx, region in region_to_eng.items():
    for freq in key_freqs:
        cos_col = 2*freq - 1
        sin_col = 2*freq

        # Input-space transform of each output direction, for both halves.
        cos_logits_freq_rot = fft2d(untiled_fourier_logits_rot[idx, :, :, cos_col].unsqueeze(dim=-1), group.fourier_basis)
        sin_logits_freq_rot = fft2d(untiled_fourier_logits_rot[idx, :, :, sin_col].unsqueeze(dim=-1), group.fourier_basis)

        cos_logits_freq_ref = fft2d(untiled_fourier_logits_ref[idx, :, :, cos_col].unsqueeze(dim=-1), group.fourier_basis)
        sin_logits_freq_ref = fft2d(untiled_fourier_logits_ref[idx, :, :, sin_col].unsqueeze(dim=-1), group.fourier_basis)

        # Regions 1 and 2 ('rot, ref' / 'ref, rot'): the answer is a reflection,
        # so remove this frequency's contribution to the rotation logits.
        if idx in [1, 2]:
            untiled_fourier_logits_rot_new[idx, :, :, cos_col] -= fft2d(cos_logits_freq_rot, group.fourier_basis, inverse=True)
            untiled_fourier_logits_rot_new[idx, :, :, sin_col] -= fft2d(sin_logits_freq_rot, group.fourier_basis, inverse=True)

        # Regions 0 and 3 ('rot, rot' / 'ref, ref'): the answer is a rotation,
        # so remove this frequency's contribution to the reflection logits.
        if idx in [0, 3]:
            untiled_fourier_logits_ref_new[idx, :, :, cos_col] -= fft2d(cos_logits_freq_ref, group.fourier_basis, inverse=True)
            untiled_fourier_logits_ref_new[idx, :, :, sin_col] -= fft2d(sin_logits_freq_ref, group.fourier_basis, inverse=True)


# Undo the quadrant split and the output-axis transform for each half.
fourier_logits_rot_new = tile2d(untiled_fourier_logits_rot_new).reshape(group.order*group.order, group.index)
logits_rot_new = fourier_logits_rot_new @ group.fourier_basis

fourier_logits_ref_new = tile2d(untiled_fourier_logits_ref_new).reshape(group.order*group.order, group.index)
logits_ref_new = fourier_logits_ref_new @ group.fourier_basis


logits_new = torch.cat([logits_rot_new, logits_ref_new], dim=-1)

In [16]:
# Loss on the full dataset before and after the ablation, printed in that order.
baseline_loss, new_loss = (loss_fn(candidate, all_labels) for candidate in (logits, logits_new))
print(baseline_loss)
print(new_loss)

tensor(7.7648e-06, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(3.5912e-06, device='cuda:0', grad_fn=<NllLossBackward0>)
