# Setup

In [5]:
!nvidia-smi 
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

Mon Nov 14 16:24:11 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Quadro P...  On   | 00000000:00:05.0  On |                  N/A |
| 46%   33C    P8     7W / 105W |    825MiB /  8119MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Imports

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops

import json
import pickle 
import copy

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *


In [7]:
if torch.cuda.is_available:
  print('Good to go!')
else:
  print('Things might be rather slow')

Good to go!


# Interpretability




We interpret S_4 and S_5 in the natural and standard bases. 

## Interpretability Set Up 

In [8]:
task_dir = "1L_MLP_sym_S4_cached"
seed, frac_train, width, lr, group_param, weight_decay, num_epochs, group_type, architecture_type, metric = load_cfg(task_dir)
group = group_type(group_param, init_all = False)
model = architecture_type(width, group.order, seed).cuda()
model.load_state_dict(torch.load(f"{task_dir}/model.pt"))
model.eval()
all_data, _, all_labels, _ = generate_train_test_data(group, frac_train = 1)
logits, activations = model.run_with_cache(all_data, return_cache_object=False)
activations['logits'] = logits

## Visualise the Weights


In [9]:
imshow(model.W_E_a)
imshow(model.W_E_b)
imshow(model.W)
imshow(model.W_U)

## Understanding Activation Patterns

Let's first look at the relevant activations (the hidden and output layers) in the standard basis.

### Understanding Individual Neuron Activations

In [10]:
# first, reshape from (batch) to (group.order, group.order) so that we may easily fourier transform on each of the two input dimensions.

for key, value in activations.items():
    activations[key] = value.reshape(group.order, group.order, -1)

In [11]:
for x in range(group.order):
    print(f'{x}: {group.idx_to_perm(x)}, signature: {group.signature(x)}, order: {group.perm_order(x)}')

0: (3), signature: 1, order: 1
1: (0 1 2 3), signature: -1, order: 4
2: (0 2)(1 3), signature: 1, order: 2
3: (0 3), signature: -1, order: 2
4: (1 2 3), signature: 1, order: 3
5: (0 1 3 2), signature: -1, order: 4
6: (3)(0 2 1), signature: 1, order: 3
7: (0 3 1 2), signature: -1, order: 4
8: (1 3 2), signature: 1, order: 3
9: (3)(0 1), signature: -1, order: 2
10: (0 2 3), signature: 1, order: 3
11: (0 3 2 1), signature: -1, order: 4
12: (2 3), signature: -1, order: 2
13: (3)(0 1 2), signature: 1, order: 3
14: (0 2 1 3), signature: -1, order: 4
15: (0 3 2), signature: 1, order: 3
16: (3)(1 2), signature: -1, order: 2
17: (0 1 3), signature: 1, order: 3
18: (0 2 3 1), signature: -1, order: 4
19: (0 3)(1 2), signature: 1, order: 2
20: (1 3), signature: -1, order: 2
21: (0 1)(2 3), signature: 1, order: 2
22: (3)(0 2), signature: -1, order: 2
23: (0 3 1), signature: 1, order: 3


observations for cached S4

neuron 8: activates iff a even and b odd (!?)

neuron 24: activates iff both a and b odd

neuron 13: 


In [15]:
hidden = activations['hidden']
neurons = [3, 7, 8, 13, 15, 16, 19, 24, 33, 34, 36, 39] # checked 0-39 for blocky neurons
neurons = range(1, 10)
for neuron in neurons:
    imshow(hidden[:, :, neuron], title=f'hidden activations, neuron {neuron}', input1='position 1', input2='position 2')

In [13]:
logits = activations['logits']
for neuron in range(2):
    imshow(logits[:, :, neuron], title=f'logit {neuron}', input1='position 1', input2='position 2')

## Understanding Logit Computation

We hypotehsise the network is using the following formula to compute logits

$$Tr(\rho(x)\rho(y)\rho(z^{-1}))=Tr(\rho(xyz^{-1}))$$

We can verify this by computing this (group.order, group.order, group.order) tensor cube, and comparing the cosine similarity of it to the actual logits produced.

I'm actually not convinced cosine similarity of logits is a very good way of testing whether the algorithm used is the expected one.... 


In [14]:
group.compute_natural_rep()

def compute_trace_tensor_cube(group, all_data, rep):
  N = all_data.shape[0]
  t = torch.zeros((group.order*group.order, group.order), dtype=torch.float).cuda()
  for i in range(N):
    if i%1000 == 0:
      print(f'progress: {i} / {N}')
    x = all_data[i, 0]
    x_rep = rep(x.item())
    y = all_data[i, 1]
    y_rep = rep(y.item())
    temp = x_rep.mm(y_rep)
    for z_idx in range(group.order):
      z_rep = rep(z_idx)
      t[i, z_idx] = torch.trace(temp.mm(z_rep.T)) # transpose is inverse here
  return t.reshape(group.order, group.order, group.order)

rep = group.standard_rep
trace = compute_trace_tensor_cube(group, all_data, rep)

progress: 0 / 576


AttributeError: 'SymmetricGroup' object has no attribute 'standard_reps'

In [None]:
print(group.idx_to_perm(0))
print(group.standard_rep(0) @ group.standard_rep(0).inverse())

(2 4 3)
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]], device='cuda:0')


correct for logits not caring about an additive constant

In [None]:
logits = activations['logits']
centred_logits = logits - logits.mean(-1)
centred_trace = trace - trace.mean(-1)

torch.Size([120, 120, 120])


Let's first just look and see if they are visiually similar across two different slices

In [None]:
for neuron in range(2):
    imshow(centred_logits[:, :, neuron], title=f'real logit {neuron}', input1='position 1', input2='position 2')
    imshow(centred_trace[:, :, neuron], title=f'hypothesised logit {neuron}', input1='position 1', input2='position 2')

In [None]:
for neuron in range(2):
    imshow(centred_logits[neuron, :, :], title=f'real logits with position 1 fixed as {neuron}', input1='position 2', input2='logit')
    imshow(centred_trace[neuron, :, :], title=f'hypothesised logits with position 1 fixed as {neuron}', input1='position 2', input2='logit')

Cosine similarity over all logits, mean over all inputs.

We find for S5 both the standard and natural representation yield good cosine similarity. This suggests the network really is performing the alleged algorithm. The product of the sign and standard representation is certainly not the algorithm employed

We find for S4 all three representations yield good similarity.

In [None]:
centred_logits = centred_logits.reshape(group.order*group.order,-1)
centred_trace = centred_trace.reshape(group.order*group.order,-1)
sims = F.cosine_similarity(centred_logits, centred_trace, dim=-1)
print(f'cosine similarity pre softmax {sims.mean()}')

centred_logits_softmax = F.softmax(centred_logits, dim=-1)
centred_trace_softmax = F.softmax(centred_trace, dim=-1)

softmax_sims = F.cosine_similarity(centred_logits_softmax, centred_trace_softmax, dim=-1).mean()
print(f'cosine similarity post softmax {softmax_sims.mean()}')

norms = torch.linalg.norm(centred_logits, dim=-1) 
scalar_projects = norms * sims 
unit_trace_logits = centred_trace / torch.linalg.norm(centred_trace, dim=-1, keepdim=True)
projects = unit_trace_logits * scalar_projects.unsqueeze(-1)
frac_explained = projects.pow(2).sum(-1) / centred_logits.pow(2).sum(-1)
print(f'fraction of variance of logit explained by the trace logits {frac_explained.mean()}')

norms = torch.linalg.norm(centred_logits_softmax, dim=-1) 
scalar_projects = norms * softmax_sims 
unit_trace_logits = centred_trace_softmax / torch.linalg.norm(centred_trace_softmax, dim=-1, keepdim=True)
projects = unit_trace_logits * scalar_projects.unsqueeze(-1)
frac_explained = projects.pow(2).sum(-1) / centred_logits_softmax.pow(2).sum(-1)
print(f'fraction of variance of softmax logits explained by the softmax trace logits {frac_explained.mean()}')


cosine similarity pre softmax 0.5036443471908569
cosine similarity post softmax 0.854880690574646
fraction of variance of logit explained by the trace logits 0.2661466896533966
fraction of variance of softmax logits explained by the softmax trace logits 0.7308210730552673


# Understanding the embeddings

The embedding x_embed is a (group.order, embed_dim) tensor. Let's first look at the singular values of this matrix. We find there are 9 large ones = (4-1) x (4-1)

In [None]:
u_a,s_a,v_a = torch.linalg.svd(model.W_E_a) 
print(s_a[:20])

u_b,s_b,v_b = torch.linalg.svd(model.W_E_b)
print(s_b[:20])

u,s,v = torch.linalg.svd(model.W)
print(s[:30])

u,s,v = torch.linalg.svd(model.W_U.T)
print(s[:50])

tensor([10.8622, 10.5386, 10.4976, 10.4262, 10.0723, 10.0522,  9.9892,  9.8150,
         9.0828,  0.3851,  0.3816,  0.3684,  0.3559,  0.3420,  0.3159,  0.2994,
         0.2862,  0.2850,  0.2769,  0.2744], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([10.7975, 10.6226, 10.6078, 10.4834, 10.0891,  9.9713,  9.9075,  9.7354,
         9.0852,  0.4102,  0.4003,  0.3783,  0.3541,  0.3482,  0.3080,  0.3023,
         0.2956,  0.2875,  0.2822,  0.2727], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([11.1874, 10.9568, 10.8445, 10.7239, 10.6454, 10.5712, 10.4541, 10.3307,
        10.1480, 10.0513,  9.9480,  9.8286,  9.6388,  9.5681,  9.4593,  9.1160,
         8.9759,  8.9681,  0.6175,  0.5783,  0.3598,  0.3107,  0.2823,  0.2542,
         0.2273,  0.2163,  0.2006,  0.1659,  0.1411,  0.1307], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([11.5073,  8.9923,  8.9157,  8.8915,  8.8643,  8.8532,  8.8332,  8.8093,
         8.7679,  8.7589,  8.7351,  8.6982,  8.6922,  8.6

What happens if we take the low rank approximation gained from the first 9 singular values?


In [None]:
k=9
new_model = copy.deepcopy(model)
new_model.W_E_a = torch.nn.Parameter(u_a[:, :k] @ torch.diag(s_a[:k]) @ v_a[:k, :])
new_model.W_E_b = torch.nn.Parameter(u_b[:, :k] @ torch.diag(s_b[:k]) @ v_b[:k, :])
logits = model(all_data)
new_logits = new_model(all_data)
loss_new = loss_fn (new_logits, all_labels)
loss = loss_fn (logits, all_labels)

print(f'baselines loss {loss.item()}')
print(f'loss with only the largest 9 singular values of embeddings {loss_new.item()}')

baselines loss 1.2673684750552638e-06
loss with only the largest 9 singular values of embeddings 1.272823965337011e-06


So far we've been working with the natural reperesentation. This has a 1-dimensional invariant subspace spanned by the sum of all basis vectors. Note too in the standard representation, the trace of any element is identical to the trace in the natural representation, up to some additive constant, which we discount from logits anyway. 

We can also see this by looking at its singular values of the standard representations, of which there are $group.index^2$, but only $(group.index-1)^2+1$ non zero ones. I hypothesise this is the 9 directions corresponding to the natural representation, then one for the "distance from the plane" formed when we collapse the standard rep to the natural rep. I hypothesise this 1 additional direction is not used by the network??

In [None]:
k = (group.index-1)**2 + 1
natural_reps = group.natural_reps.reshape(group.order, group.index*group.index)
u, s, v = torch.linalg.svd(natural_reps)
print(s)

# remove junk singular values
natural_reps_mod = u[:, :k] @ torch.diag(s[:k]) # this is going to get plugged into a proj operator, so we can drop the v
natural_reps_mod = torch.linalg.qr(natural_reps_mod)[0]

tensor([1.0954e+01, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00,
        5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00,
        5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 4.1506e-07,
        3.4037e-07, 3.2403e-07, 2.7186e-07, 1.7150e-07, 1.6669e-07, 1.1426e-07,
        4.7155e-08], device='cuda:0')


In the standard representatin, all $(group.index-1)^2$ are non zero.

In [None]:
standard_reps = group.standard_reps.reshape(group.order, (group.index-1)*(group.index-1))
u,s,v = torch.linalg.svd(standard_reps)
print(s)
standard_reps_mod = torch.linalg.qr(standard_reps)[0]

tensor([16.8572, 14.3396, 10.4183,  8.8623,  7.5388,  6.4389,  5.4772,  5.4772,
         5.4772,  5.4772,  4.6592,  3.9794,  3.3851,  2.8795,  2.0921,  1.7797],
       device='cuda:0')


In [None]:
product_standard_sign_reps = group.product_standard_sign_reps.reshape(group.order, (group.index-1)*(group.index-1))
u,s,v = torch.linalg.svd(product_standard_sign_reps)
print(s)
product_standard_sign_reps_mod = torch.linalg.qr(product_standard_sign_reps)[0]

tensor([16.8572, 14.3396, 10.4183,  8.8623,  7.5388,  6.4389,  5.4772,  5.4772,
         5.4772,  5.4772,  4.6592,  3.9794,  3.3851,  2.8795,  2.0921,  1.7797],
       device='cuda:0')


Let's see if the embeddings are actually learning representations. We can generate a $(group.order, group.index^2)$ tensor encoding the natural representation, or a $(group.order, (group.index-1)^2)$ encoding the standard representation. The embeddings are $(group.order, embed\_dim)$ tensors, which are higher dimensional than the representations. We've already shown the embeddings are sparse in the singular basis, but let's now show the dimensions actually line up.

In [None]:
def projection_matrix_general(B):
    """Compute the projection matrix onto the space spanned by the columns of `B`
    Args:
        B: ndarray of dimension (D, M), the basis for the subspace
    
    Returns:
        P: the projection matrix
    """
    P = B @ (B.T @ B).inverse() @ B.T
    return P

P_natural = projection_matrix_general(natural_reps_mod)
assert(torch.allclose(P_natural, P_natural@P_natural, atol=1e-6))
assert(torch.allclose(P_natural, P_natural.T, atol=1e-6))

P_standard = projection_matrix_general(standard_reps_mod)
assert(torch.allclose(P_standard, P_standard@P_standard, atol=1e-6))
assert(torch.allclose(P_standard, P_standard.T, atol=1e-6))

P_standard_sign = projection_matrix_general(product_standard_sign_reps_mod)
assert(torch.allclose(P_standard_sign, P_standard_sign@P_standard_sign, atol=1e-6))
assert(torch.allclose(P_standard_sign, P_standard_sign.T, atol=1e-6))



proj_a_natural = P_natural @ model.W_E_a
proj_b_natural = P_natural @ model.W_E_b
proj_U_natural = P_natural @ model.W_U.T

proj_a_standard = P_standard @ model.W_E_a
proj_b_standard = P_standard @ model.W_E_b
proj_U_standard = P_standard @ model.W_U.T

proj_a_standard_sign = P_standard_sign @ model.W_E_a
proj_b_standard_sign = P_standard_sign @ model.W_E_b
proj_U_standard_sign = P_standard_sign @ model.W_U.T

print(f'W_E_a frobenius norm: {model.W_E_a.pow(2).sum()}')
print(f'Projection onto natural representation {proj_a_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_a_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_a_standard_sign.pow(2).sum()}')
print(f'W_E_b frobenius norm: {model.W_E_b.pow(2).sum()}')
print(f'Projection onto natural representation {proj_b_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_b_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_b_standard_sign.pow(2).sum()}')
print(f'W_U frobenius norm: {model.W_U.pow(2).sum()}')
print(f'Projection onto natural representation {proj_U_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_U_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_U_standard_sign.pow(2).sum()}')


W_E_a frobenius norm: 932.2500610351562
Projection onto natural representation 846.6348876953125
Projection onto standard representation 846.6321411132812
Projection onto standard sign representation 0.1008387953042984
W_E_b frobenius norm: 931.7891845703125
Projection onto natural representation 846.0723876953125
Projection onto standard representation 846.06982421875
Projection onto standard sign representation 0.105100616812706
W_U frobenius norm: 1413.734130859375
Projection onto natural representation 1274.8861083984375
Projection onto standard representation 1229.9315185546875
Projection onto standard sign representation 0.45307162404060364


It's slightly sus that these are exactly half the norm of the actual embedding matrices. Is it using a combination of the natural and standard representations? No.

Is it using some combination of the natural, standard and sign reprsentatations?

Seems like it might be in S4, but in S5 it is not at all. Is this to do with parity?



In [None]:
all_reps = torch.cat([standard_reps_mod, natural_reps_mod], dim=1)
print(all_reps.shape)
q = torch.linalg.qr(all_reps)[0]
P_all = projection_matrix_general(q)
assert(torch.allclose(P_all, P_all@P_all, atol=1e-6))
assert(torch.allclose(P_all, P_all.T, atol=1e-6))
proj_a_all = P_all @ model.W_E_a
proj_b_all = P_all @ model.W_E_b
proj_U_all = P_all @ model.W_U.T

print(f'W_E_a frobenius norm: {model.W_E_a.pow(2).sum()}')
print(f'Projection onto representation {proj_a_all.pow(2).sum()}')
print(f'W_E_b frobenius norm: {model.W_E_b.pow(2).sum()}')
print(f'Projection onto representation {proj_b_all.pow(2).sum()}')
print(f'W_U frobenius norm: {model.W_U.pow(2).sum()}')
print(f'Projection onto representation {proj_U_all.pow(2).sum()}')


torch.Size([120, 33])
W_E_a frobenius norm: 932.2500610351562
Projection onto representation 856.630615234375
W_E_b frobenius norm: 931.7891845703125
Projection onto representation 856.084228515625
W_U frobenius norm: 1413.734130859375
Projection onto representation 1291.10400390625


In [None]:
# how different are the two models

proj_natural_on_standard = P_natural @ standard_reps_mod
print(f'Projection of natural onto standard representation {proj_natural_on_standard.pow(2).sum()}')
print(f'Standard representation frobenius norm: {standard_reps_mod.pow(2).sum()}')

proj_standard_on_natural = P_standard @ natural_reps_mod
print(f'Projection of standard on natural representation {proj_standard_on_natural.pow(2).sum()}')
print(f'Natural representation frobenius norm: {natural_reps_mod.pow(2).sum()}')

Projection of natural onto standard representation 16.0
Standard representation frobenius norm: 16.0
Projection of standard on natural representation 15.999999046325684
Natural representation frobenius norm: 17.0
