# Setup

In [186]:
!nvidia-smi 
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

Tue Nov  8 17:24:21 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Quadro P...  On   | 00000000:00:05.0  On |                  N/A |
| 46%   32C    P8     6W / 105W |   1013MiB /  8119MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Imports

In [202]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops

import json
import pickle 
import copy

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *


In [188]:
if torch.cuda.is_available:
  print('Good to go!')
else:
  print('Things might be rather slow')

Good to go!


# Interpretability




## Interpretability Set Up 

In [189]:
task_dir = "1L_MLP_sym_S4_cached"
seed, frac_train, width, lr, group_param, weight_decay, num_epochs, group_type, architecture_type, metrics = load_cfg(task_dir)
group = group_type(group_param)
model = architecture_type(width, group.order, seed).cuda()
model.load_state_dict(torch.load(f"{task_dir}/model.pt"))
model.eval()
all_data, _, all_labels, _ = generate_train_test_data(group, frac_train = 1)
logits, activations = model.run_with_cache(all_data, return_cache_object=False)
activations['logits'] = logits

## Visualise the Weights


In [190]:
imshow(model.W_E_a)
imshow(model.W_E_b)
imshow(model.W)
imshow(model.W_U)

## Understanding Activation Patterns

Let's first look at the relevant activations (the hidden and output layers) in the standard basis.

### Understanding Individual Neuron Activations

In [191]:
# first, reshape from (batch) to (group.order, group.order) so that we may easily fourier transform on each of the two input dimensions.

for key, value in activations.items():
    activations[key] = value.reshape(group.order, group.order, -1)

In [192]:
for x in range(group.order):
    print(f'{x}: {group.idx_to_perm(x)}, signature: {group.signature(x)}, order: {group.perm_order(x)}')



0: (3), signature: 1, order: 1
1: (0 1 2 3), signature: -1, order: 4
2: (0 2)(1 3), signature: 1, order: 2
3: (0 3), signature: -1, order: 2
4: (1 2 3), signature: 1, order: 3
5: (0 1 3 2), signature: -1, order: 4
6: (3)(0 2 1), signature: 1, order: 3
7: (0 3 1 2), signature: -1, order: 4
8: (1 3 2), signature: 1, order: 3
9: (3)(0 1), signature: -1, order: 2
10: (0 2 3), signature: 1, order: 3
11: (0 3 2 1), signature: -1, order: 4
12: (2 3), signature: -1, order: 2
13: (3)(0 1 2), signature: 1, order: 3
14: (0 2 1 3), signature: -1, order: 4
15: (0 3 2), signature: 1, order: 3
16: (3)(1 2), signature: -1, order: 2
17: (0 1 3), signature: 1, order: 3
18: (0 2 3 1), signature: -1, order: 4
19: (0 3)(1 2), signature: 1, order: 2
20: (1 3), signature: -1, order: 2
21: (0 1)(2 3), signature: 1, order: 2
22: (3)(0 2), signature: -1, order: 2
23: (0 3 1), signature: 1, order: 3


observations for cached S4

neuron 8: activates iff a even and b odd (!?)
neuron 24: activates iff both a and b odd
neuron 13: 


In [193]:
hidden = activations['hidden']
neurons = [3, 7, 8, 13, 15, 16, 19, 24, 33, 34, 36, 39] # checked 0-39 for blocky neurons
for neuron in neurons:
    imshow(hidden[:, :, neuron], title=f'hidden activations, neuron {neuron}', input1='position 1', input2='position 2')

In [194]:
logits = activations['logits']
for neuron in range(10):
    imshow(logits[:, :, neuron], title=f'logit {neuron}', input1='position 1', input2='position 2')

## Understanding Logit Computation

We hypotehsise the network is using the following formula to compute logits

$$Tr(\rho(x)\rho(y)\rho(z^{-1}))=Tr(\rho(xyz^{-1}))$$

We can verify this by computing this (group.order, group.order, group.order) tensor cube, and comparing the cosine similarity of it to the actual logits produced.


In [195]:
group.compute_natural_rep()
def compute_trace_tensor_cube(group, all_data):
  N = all_data.shape[0]
  t = torch.zeros((group.order*group.order, group.order), dtype=torch.float).cuda()
  for i in range(N):
    if i%1000 == 0:
      print(f'progress: {i} / {N}')
    x = all_data[i, 0]
    x_rep = group.natural_rep(x.item())
    y = all_data[i, 1]
    y_rep = group.natural_rep(y.item())
    temp = x_rep.mm(y_rep)
    for z_idx in range(group.order):
      z_rep = group.natural_rep(z_idx)
      t[i, z_idx] = torch.trace(temp.mm(z_rep.T)) # transpose is inverse here
  return t.reshape(group.order, group.order, group.order)

trace = compute_trace_tensor_cube(group, all_data)

progress: 0 / 576


correct for logits not caring about an additive constant

In [196]:
logits = activations['logits']
centred_logits = logits - logits.mean(-1)
centred_trace = trace - trace.mean(-1)

Let's first just look and see if they are visiually similar across two different slices

In [197]:
for neuron in range(10):
    imshow(centred_logits[:, :, neuron], title=f'real logit {neuron}', input1='position 1', input2='position 2')
    imshow(centred_trace[:, :, neuron], title=f'hypothesised logit {neuron}', input1='position 1', input2='position 2')

In [198]:
for neuron in range(10):
    imshow(centred_logits[neuron, :, :], title=f'real logits with position 1 fixed as {neuron}', input1='position 2', input2='logit')
    imshow(centred_trace[neuron, :, :], title=f'hypothesised logits with position 1 fixed as {neuron}', input1='position 2', input2='logit')

Cosine similarity over all logits, mean over all inputs.

In [199]:
centred_logits = centred_logits.reshape(group.order*group.order,-1)
centred_trace = centred_trace.reshape(group.order*group.order,-1)
sims = F.cosine_similarity(centred_logits, centred_trace, dim=-1)
print(f'cosine similarity pre softmax {sims.mean()}')

centred_logits_softmax = F.softmax(centred_logits, dim=-1)
centred_trace_softmax = F.softmax(centred_trace, dim=-1)

softmax_sims = F.cosine_similarity(centred_logits_softmax, centred_trace_softmax, dim=-1).mean()
print(f'cosine similarity post softmax {softmax_sims.mean()}')

norms = torch.linalg.norm(centred_logits, dim=-1) 
scalar_projects = norms * sims 
unit_trace_logits = centred_trace / torch.linalg.norm(centred_trace, dim=-1, keepdim=True)
projects = unit_trace_logits * scalar_projects.unsqueeze(-1)
frac_explained = projects.pow(2).sum(0) / centred_logits.pow(2).sum(0)
print(f'fraction of variance of logit explained by the trace logits {frac_explained.mean()}')





cosine similarity pre softmax 0.5828865170478821
cosine similarity post softmax 0.9396134614944458
fraction of variance of logit explained by the trace logits 0.3483385741710663


# Understanding the embeddings

Let's see if the embeddings are learning representations. We can generate a (group.order, group.index, group.index) tensor encoding the natural representation. We can think of this as a (group.order, group.index * group.index) tensor. 

The embedding x_embed is a (group.order, embed_dim) tensor. Let's first look at the singular values of this matrix. We find there are 9 large ones = (4-1) x (4-1)

In [200]:
u_a,s_a,v_a = torch.linalg.svd(model.W_E_a) 
print(s_a)

u_b,s_b,v_b = torch.linalg.svd(model.W_E_b)

tensor([6.2584, 6.2100, 6.1579, 4.6228, 4.5293, 4.3990, 4.2643, 4.1072, 3.9995,
        0.3696, 0.3045, 0.2542, 0.2231, 0.2133, 0.1835, 0.1604, 0.1469, 0.1421,
        0.1348, 0.1290, 0.1249, 0.1047, 0.1020, 0.0969], device='cuda:0',
       grad_fn=<LinalgSvdBackward0>)


What happens if we take the low rank approximation gained from the first 9 singular values?


In [204]:
k=9
new_model = copy.deepcopy(model)
new_model.W_E_a = torch.nn.Parameter(u_a[:, :k] @ torch.diag(s_a[:k]) @ v_a[:k, :])
new_model.W_E_b = torch.nn.Parameter(u_b[:, :k] @ torch.diag(s_b[:k]) @ v_b[:k, :])
logits = model(all_data)
new_logits = new_model(all_data)
loss_new = loss_fn (new_logits, all_labels)
loss = loss_fn (logits, all_labels)

print(f'baselines loss {loss.item()}')
print(f'loss with only the largest 9 singular values of embeddings {loss_new.item()}')

0.0001891496212920174 0.00018685602117329836


So far we've been working with the standard reperesentation. This has a 1-dimensional invariant subspace spanned by the sum of all basis vectors. We can see this by looking at its singular values, of which there are $group.index^2$, but only $(group.index-1)^2$ relevant ones.

In [208]:
natural_reps = group.natural_reps.reshape(group.order, group.index*group.index)
u, s, v = torch.linalg.svd(natural_reps)
print(s)

tensor([4.8990e+00, 2.8284e+00, 2.8284e+00, 2.8284e+00, 2.8284e+00, 2.8284e+00,
        2.8284e+00, 2.8284e+00, 2.8284e+00, 2.8284e+00, 1.4788e-07, 1.4178e-07,
        8.9138e-08, 8.3209e-08, 5.7613e-08, 5.0008e-08], device='cuda:0')
