# Setup

In [41]:
!nvidia-smi 
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

Tue Nov 29 14:39:44 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Quadro P...  On   | 00000000:00:05.0  On |                  N/A |
| 46%   31C    P8     6W / 105W |   2677MiB /  8119MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Imports

In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
import sage

import json
import pickle 
import copy

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *
from utils.config import *

In [43]:
if torch.cuda.is_available:
  print('Good to go!')
else:
  print('Things might be rather slow')

Good to go!


# Interpretability




## Interpretability Set Up 

In [44]:
task_dir = "1L_MLP_sym_S5_cached_3" #1L_MLP_sym_S5_cached_3"
seed, frac_train, width, lr, group_param, weight_decay, num_epochs, group_type, architecture_type, metric_cfg, metric_obj = load_cfg(task_dir)
group = group_type(group_param, init_all = True)
all_data, _, all_labels, _ = generate_train_test_data(group, frac_train = 1)
model = architecture_type(width, group.order, seed).cuda()
model.load_state_dict(torch.load(f"{task_dir}/model.pt"))
model.eval()
logits, activations = model.run_with_cache(all_data, return_cache_object=False)
activations['logits'] = logits
metric_obj = metric_obj(group, training=False, track_metrics = True)
metrics = metric_obj.get_metrics(model)



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Computing trace tensor cube for standard representation
Loading from file
Computing trace tensor cube for standard_sign representation
Loading from file
Computing trace tensor cube for sign representation
Loading from file
Computing trace tensor cube for trivial representation
Loading from file



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



## Visualise the Weights


In [45]:
imshow(model.W_x)
imshow(model.W_y)
imshow(model.W)
imshow(model.W_U)

## Losses on various subsets


In [46]:
all_indices = np.arange(group.order)

In [47]:
all_loss = loss_fn(logits, all_labels)
all_loss

tensor(8.0730e-07, device='cuda:0', grad_fn=<NllLossBackward0>)

In [48]:
alternating_indices = [i for i in range(group.order) if group.signature(i) == 1]
alternating_data = group.get_subset_of_data(alternating_indices).cuda()
alternating_data, alternating_labels = alternating_data[:, :2], alternating_data[:, 2]
alternating_logits = model(alternating_data)
alternating_loss = loss_fn(alternating_logits, alternating_labels).item()
alternating_loss

8.045964250413817e-07

In [49]:
alt_coset_indices = [i for i in range(group.order) if group.signature(i) == -1]
alt_coset_data = group.get_subset_of_data(alt_coset_indices).cuda()
alt_coset_data, alt_coset_labels = alt_coset_data[:, :2], alt_coset_data[:, 2]
alt_coset_logits = model(alternating_data)
alt_coset_loss = loss_fn(alternating_logits, alternating_labels).item()
alt_coset_loss

8.045964250413817e-07

In [50]:
#random_indices = np.random.choice(group.order, int(0.5*group.order*group.order))
#random_data = group.get_subset_of_data(random_indices).cuda()
#random_data, random_labels = random_data[:, :2], random_data[:, 2]
#random_logits = model(random_data)
#random_loss = loss_fn(random_logits, random_labels).item()
#random_loss


In [51]:
data = group.get_subset_of_data(alternating_indices, alt_coset_indices).cuda()
data, labels = data[:, :2], data[:, 2]
logits = model(data)
loss = loss_fn(logits, labels).item()
loss


8.157226147886831e-07

In [52]:
data = group.get_subset_of_data(alt_coset_indices, alternating_indices ).cuda()
data, labels = data[:, :2], data[:, 2]
logits = model(data)
loss = loss_fn(logits, labels).item()
loss


8.083382567747321e-07

In [53]:
losses=[]
for x in range(group.order):
    data = group.get_subset_of_data([x], all_indices).cuda()
    data, labels = data[:, :2], data[:, 2]
    logits = model(data)
    loss = loss_fn(logits, labels).item()
    losses.append(loss)
line(losses, log_y=True)


elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison



## Look at activation patterns

Let's first look at the relevant activations (the hidden and output layers) in the standard basis.

### Understanding Individual Neuron Activations

In [54]:
# first, reshape from (batch) to (group.order, group.order) so that we may easily fourier transform on each of the two input dimensions.

for key, value in activations.items():
    activations[key] = value.reshape(group.order, group.order, -1)

In [55]:
# elements
for x in range(group.order):
    print(f'{x}: {group.idx_to_perm(x)}, signature: {group.signature(x)}, order: {group.perm_order(x)}')

# conjugacy classes
#for x in group.conjugacy_classes:
    #print(x)

0: (2 4 3), signature: 1, order: 3
1: (4)(0 1 2), signature: 1, order: 3
2: (0 2 1 3 4), signature: 1, order: 5
3: (0 3)(1 4), signature: 1, order: 2
4: (0 4 3 2), signature: -1, order: 4
5: (4)(1 2 3), signature: 1, order: 3
6: (0 1 3 2 4), signature: 1, order: 5
7: (0 2)(1 4), signature: 1, order: 2
8: (0 3 4 2 1), signature: 1, order: 5
9: (0 4)(1 2 3), signature: -1, order: 6
10: (1 3 4), signature: 1, order: 3
11: (0 1 4 2 3), signature: 1, order: 5
12: (0 2 4 3 1), signature: 1, order: 5
13: (4)(0 3 2), signature: 1, order: 3
14: (0 4 1 3), signature: -1, order: 4
15: (1 4 2 3), signature: -1, order: 4
16: (0 1)(2 4 3), signature: -1, order: 6
17: (4)(0 2), signature: -1, order: 2
18: (0 3 4)(1 2), signature: -1, order: 6
19: (0 4 2 3 1), signature: 1, order: 5
20: (4), signature: 1, order: 1
21: (0 1 2 3 4), signature: 1, order: 5
22: (0 2 4 1 3), signature: 1, order: 5
23: (0 3 1 4 2), signature: 1, order: 5
24: (0 4), signature: -1, order: 2
25: (1 2)(3 4), signature: 1, order

### observations for cached S4

neuron 8: activates iff a even and b odd (!?)

neuron 24: activates iff both a and b odd



In [56]:
hidden = activations['hidden']
neurons = [3, 7, 8, 13, 15, 16, 19, 24, 33, 34, 36, 39] # checked 0-39 for blocky neurons
neurons = [24]
for neuron in neurons:
    imshow(hidden[:, :, neuron], title=f'hidden activations, neuron {neuron}', input1='position 1', input2='position 2')

In [57]:
logits = activations['logits']
for neuron in range(2):
    imshow(logits[:, :, neuron], title=f'logit {neuron}', input1='position 1', input2='position 2')

## Understanding Logit Computation

We hypotehsise the network is using the following formula to compute logits

$$Tr(\rho(x)\rho(y)\rho(z^{-1}))=Tr(\rho(xyz^{-1}))$$

We can verify this by computing this (group.order, group.order, group.order) tensor cube, and comparing the cosine similarity of it to the actual logits produced. The group class has already done this. Note this trace has been centred, as logits don't care about an additive constant and we will compare the two. We also center the logits before comparing their cosine similarity.

In [58]:
centred_trace = group.standard_trace_tensor_cubes
print(centred_trace.shape)

torch.Size([120, 120, 120])


In [59]:
# correct for logits not caring about an additive constant

logits = activations['logits']
print(logits.shape)
centred_logits = logits - logits.mean(-1, keepdim=True)

torch.Size([120, 120, 120])


### Visual Comparison
Let's first just look and see if they are visiually similar across two different slices

In [60]:
for neuron in range(2):
    imshow(centred_logits[:, :, neuron], title=f'real logit {neuron}', input1='position 1', input2='position 2')
    imshow(centred_trace[:, :, neuron], title=f'hypothesised logit {neuron}', input1='position 1', input2='position 2')

In [61]:
for neuron in range(2):
    imshow(centred_logits[neuron, :, :], title=f'real logits with position 1 fixed as {neuron}', input1='position 2', input2='logit')
    imshow(centred_trace[neuron, :, :], title=f'hypothesised logits with position 1 fixed as {neuron}', input1='position 2', input2='logit')

### Cosine similarity 

over all logits, mean over all inputs.

We find for S5 both the standard and natural representation yield good cosine similarity. This suggests the network really is performing the alleged algorithm. The product of the sign and standard representation is certainly not the algorithm employed

We find for S4 all three representations yield good similarity.

In [62]:
centred_logits = centred_logits.reshape(group.order*group.order,-1)
centred_trace = centred_trace.reshape(group.order*group.order,-1)
sims = F.cosine_similarity(centred_logits, centred_trace, dim=-1)
print(f'cosine similarity pre softmax {sims.mean()}')

centred_logits_softmax = F.softmax(centred_logits, dim=-1)
centred_trace_softmax = F.softmax(centred_trace, dim=-1)

softmax_sims = F.cosine_similarity(centred_logits_softmax, centred_trace_softmax, dim=-1).mean()
print(f'cosine similarity post softmax {softmax_sims.mean()}')

norms = torch.linalg.norm(centred_logits, dim=-1) 
scalar_projects = norms * sims 
unit_trace_logits = centred_trace / torch.linalg.norm(centred_trace, dim=-1, keepdim=True)
projects = unit_trace_logits * scalar_projects.unsqueeze(-1)
frac_explained = projects.pow(2).sum(-1) / centred_logits.pow(2).sum(-1)
print(f'fraction of variance of logit explained by the trace logits {frac_explained.mean()}')

norms = torch.linalg.norm(centred_logits_softmax, dim=-1) 
scalar_projects = norms * softmax_sims 
unit_trace_logits = centred_trace_softmax / torch.linalg.norm(centred_trace_softmax, dim=-1, keepdim=True)
projects = unit_trace_logits * scalar_projects.unsqueeze(-1)
frac_explained = projects.pow(2).sum(-1) / centred_logits_softmax.pow(2).sum(-1)
print(f'fraction of variance of softmax logits explained by the softmax trace logits {frac_explained.mean()}')


cosine similarity pre softmax 0.7574389576911926
cosine similarity post softmax 0.8944898247718811
fraction of variance of logit explained by the trace logits 0.5737529993057251
fraction of variance of softmax logits explained by the softmax trace logits 0.8001120686531067


### How different are the representations?


Hypothesis: the representations learned are those that are maximally different. Each injective representation will give the correct answer. Having very different logits will produce outputs that destructively intefere, minimising loss.


In [63]:
centred_traces = {'trivial': group.trivial_trace_tensor_cubes, 
                    'sign': group.sign_trace_tensor_cubes, 
                    'standard': group.standard_trace_tensor_cubes,
                    'standard_sign': group.standard_sign_trace_tensor_cubes}

if group.index == 4:
    centred_traces['s4_2d'] =  group.S4_2d_trace_tensor_cubes


for key, value in centred_traces.items():
    centred_traces[key] = value.reshape(group.order*group.order,-1)

for key1, value1 in centred_traces.items():
    for key2, value2 in centred_traces.items():
        print(f'{key1} vs {key2}: {F.cosine_similarity(value1, value2, dim=-1).mean()}')

trivial vs trivial: 0.0
trivial vs sign: 0.0
trivial vs standard: 0.0
trivial vs standard_sign: 0.0
sign vs trivial: 0.0
sign vs sign: 0.9999999403953552
sign vs standard: 0.0
sign vs standard_sign: 0.0
standard vs trivial: 0.0
standard vs sign: 0.0
standard vs standard: 1.0
standard vs standard_sign: -5.83628834149863e-09
standard_sign vs trivial: 0.0
standard_sign vs sign: 0.0
standard_sign vs standard: -5.83628834149863e-09
standard_sign vs standard_sign: 1.0


### Theoretical optimal representations to use

Interesting. Now what loss would I get if I actually used each of these representations? Let's find out.

In [64]:
for key, value in centred_traces.items():
    value = value.reshape(group.order * group.order, -1)
    accuracy = (value.argmax(-1) == all_labels).float().mean()
    loss = loss_fn(value, all_labels)
    print(f'{key}: {loss}, {accuracy*100}%')

trivial: 4.787484645843506, 0.8333333730697632%
sign: 4.221286773681641, 1.6666667461395264%
standard: 1.4973335266113281, 100.0%
standard_sign: 1.3868404626846313, 100.0%


In [65]:
def linspace(start, stop, step=1.):
  """
    Like np.linspace but uses step instead of num
    This is inclusive to stop, so if start=1, stop=3, step=0.5
    Output is: array([1., 1.5, 2., 2.5, 3.])
  """
  return np.linspace(start, stop, int((stop - start) / step + 1))

granularity = 0.1

if group.index ==5:
  best_weights = [0, 0, 0, 0]
  best_loss = 100
  for i in linspace(0, 1, granularity):
      for j in linspace(0, 1-i, granularity):
          for k in linspace (0, 1-i-j, granularity):
              l = 1 - i - j - k
              hyp_logits = i*centred_traces['standard'] + j*centred_traces['sign'] + k*centred_traces['trivial'] + l*centred_traces['standard_sign']
              loss = loss_fn(hyp_logits, all_labels)
              if loss < best_loss:
                  best_loss = loss
                  best_weights = [i, j, k, l]
                  print(f'new best loss {best_loss} with weights {best_weights}')
  print(f'best combination: {best_loss}, {accuracy*100}%')
  print(f'best weights: {best_weights}')

if group.index==4:
  best_weights = [0.58, 0.27, 0, 0.19, 0.08]
  best_loss = 100
  for i in linspace(0, 1, granularity):
      for j in linspace(0, 1-i, granularity):
          for k in linspace (0, 1-i-j, granularity):
            for l in linspace (0, 1-i-j-k, granularity):
              m = 1 - i - j - k - l
              hyp_logits = i*centred_traces['standard'] + j*centred_traces['sign'] + k*centred_traces['trivial'] + l*centred_traces['standard_sign']+ m*centred_traces['s4_2d']
              loss = loss_fn(hyp_logits, all_labels)
              if loss < best_loss:
                  best_loss = loss
                  best_weights = [i, j, k, l, m]
                  print(f'new best loss {best_loss} with weights {best_weights}')
  print(f'best combination: {best_loss}')
  print(f'best weights: {best_weights}')

new best loss 1.3868404626846313 with weights [0.0, 0.0, 0.0, 1.0]
new best loss 1.3438732624053955 with weights [0.1, 0.0, 0.0, 0.9]
new best loss 1.309489369392395 with weights [0.2, 0.0, 0.0, 0.8]
new best loss 1.283819556236267 with weights [0.30000000000000004, 0.0, 0.0, 0.7]
new best loss 1.2674612998962402 with weights [0.4, 0.0, 0.0, 0.6]
new best loss 1.261618733406067 with weights [0.5, 0.0, 0.0, 0.5]
best combination: 1.261618733406067, 100.0%
best weights: [0.5, 0.0, 0.0, 0.5]


In [66]:
best_weights = [0.58, 0.27, 0, 0.19, 0.08]
i, j, k, l, m = best_weights
hyp_logits = i*centred_traces['standard'] + j*centred_traces['sign'] + k*centred_traces['trivial'] + l*centred_traces['standard_sign']#+ m*centred_traces['s4_2d']
loss = loss_fn(hyp_logits, all_labels)
loss



tensor(1.7810, device='cuda:0')

Is this algorithm better on the alternating set? no


In [67]:
alternating_indices = [i for i in range(group.order) if group.signature(i) == 1]
alternating_data = group.get_subset_of_data(alternating_indices).cuda()
alternating_data, alternating_labels = alternating_data[:, :2], alternating_data[:, 2]

print(alternating_indices)
indices = []
for i in alternating_indices:
    for j in alternating_indices:
        indices.append(i*group.order+j)

[0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12, 13, 19, 20, 21, 22, 23, 25, 26, 27, 28, 30, 31, 32, 33, 39, 44, 49, 54, 55, 56, 57, 58, 64, 69, 74, 75, 76, 77, 78, 84, 89, 94, 95, 96, 97, 98, 100, 101, 102, 103, 105, 106, 107, 108, 110, 111, 112, 113, 119]


In [68]:
print('on alternating subgroup')
for key, value in centred_traces.items():
    value = value.reshape(group.order * group.order, -1)
    value = value[indices]
    accuracy = (value.argmax(-1) == alternating_labels).float().mean()
    loss = loss_fn(value, alternating_labels)
    print(f'{key}: {loss}, {accuracy*100}%')

print('on whole group')
for key, value in centred_traces.items():
    value = value.reshape(group.order * group.order, -1)
    accuracy = (value.argmax(-1) == all_labels).float().mean()
    loss = loss_fn(value, all_labels)
    print(f'{key}: {loss}, {accuracy*100}%')

on alternating subgroup
trivial: 4.7874956130981445, 1.6666667461395264%
sign: 4.221277713775635, 1.6666667461395264%
standard: 1.4973357915878296, 100.0%
standard_sign: 1.3868407011032104, 100.0%
on whole group
trivial: 4.787484645843506, 0.8333333730697632%
sign: 4.221286773681641, 1.6666667461395264%
standard: 1.4973335266113281, 100.0%
standard_sign: 1.3868404626846313, 100.0%


## Understanding the embeddings

The embedding x_embed is a (group.order, embed_dim) tensor. Let's first look at the singular values of this matrix. We find there are 9 large ones = (4-1) x (4-1)

In [69]:
u_a,s_a,v_a = torch.linalg.svd(model.W_x) 
print(s_a[:20])

u_b,s_b,v_b = torch.linalg.svd(model.W_y)
print(s_b[:20])

u,s,v = torch.linalg.svd(model.W)
print(s[:30])

u,s,v = torch.linalg.svd(model.W_U.T)
print(s[:50])

tensor([11.3931, 11.2855, 11.2352, 11.2043,  9.2032,  9.1961,  9.1273,  8.9617,
         8.9002,  0.2676,  0.2544,  0.2452,  0.2425,  0.2351,  0.2077,  0.2019,
         0.2005,  0.1975,  0.1952,  0.1902], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([11.3378, 11.3012, 11.2593, 11.2254,  9.1977,  9.1513,  9.0929,  9.0123,
         8.9476,  0.2628,  0.2525,  0.2486,  0.2376,  0.2314,  0.2116,  0.2078,
         0.2038,  0.1971,  0.1934,  0.1863], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([11.4489, 11.3938, 11.3432, 11.3144, 11.2804, 11.1714, 11.1430, 11.0545,
         9.4669,  9.3767,  9.2421,  9.1046,  9.1037,  9.0957,  9.0443,  8.9440,
         8.8418,  8.5818,  0.4954,  0.3489,  0.2713,  0.2600,  0.2484,  0.2423,
         0.2199,  0.1974,  0.1885,  0.1798,  0.1667,  0.1510], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([12.0561,  9.5496,  9.4000,  9.3679,  9.3416,  9.3238,  9.2923,  9.2793,
         9.2679,  9.2671,  9.2528,  9.2283,  9.2100,  9.1

### Low rank approximations

What happens if we take the low rank approximation gained from the first 9 singular values?


In [70]:
k=9
new_model = copy.deepcopy(model)
new_model.W_x = torch.nn.Parameter(u_a[:, :k] @ torch.diag(s_a[:k]) @ v_a[:k, :])
new_model.W_y = torch.nn.Parameter(u_b[:, :k] @ torch.diag(s_b[:k]) @ v_b[:k, :])
logits = model(all_data)
new_logits = new_model(all_data)
loss_new = loss_fn (new_logits, all_labels)
loss = loss_fn (logits, all_labels)

print(f'baselines loss {loss.item()}')
print(f'loss with only the largest 9 singular values of embeddings {loss_new.item()}')

baselines loss 8.072951800386363e-07
loss with only the largest 9 singular values of embeddings 8.133301321322506e-07


The natural representation has a 1-dimensional invariant subspace spanned by the sum of all basis vectors. Note too in the standard representation, the trace of any element is identical to the trace in the natural representation, up to some additive constant, which we discount from logits anyway. 

We can also see this by looking at its singular values of the standard representations, of which there are $group.index^2$, but only $(group.index-1)^2+1$ non zero ones. I hypothesise this is the 9 directions corresponding to the natural representation, then one for the "distance from the plane" formed when we collapse the standard rep to the natural rep. I hypothesise this 1 additional direction is not used by the network??

In [71]:
k = (group.index-1)**2 + 1
natural_reps = group.natural_reps.reshape(group.order, group.index*group.index)
u, s, v = torch.linalg.svd(natural_reps)
print(s)

# remove junk singular values
natural_reps_mod = u[:, :k] @ torch.diag(s[:k]) # this is going to get plugged into a proj operator, so we can drop the v
natural_reps_mod = torch.linalg.qr(natural_reps_mod)[0]

tensor([1.0954e+01, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00,
        5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00,
        5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 4.1506e-07,
        3.4037e-07, 3.2403e-07, 2.7186e-07, 1.7150e-07, 1.6669e-07, 1.1426e-07,
        4.7155e-08], device='cuda:0')


In the standard representation, all $(group.index-1)^2$ singular values are non zero.

In [72]:
standard_reps = group.standard_reps.reshape(group.order, (group.index-1)*(group.index-1))
u,s,v = torch.linalg.svd(standard_reps)
print(s)

# orthogonalise these
print(group.standard_reps_orth.shape)

tensor([16.8572, 14.3396, 10.4183,  8.8623,  7.5388,  6.4389,  5.4772,  5.4772,
         5.4772,  5.4772,  4.6592,  3.9794,  3.3851,  2.8795,  2.0921,  1.7797],
       device='cuda:0')
torch.Size([120, 16])


In [73]:
standard_sign_reps = group.standard_sign_reps.reshape(group.order, (group.index-1)*(group.index-1))
u,s,v = torch.linalg.svd(standard_sign_reps)
print(s)

print(group.standard_sign_reps_orth.shape)


tensor([16.8572, 14.3396, 10.4183,  8.8623,  7.5388,  6.4389,  5.4772,  5.4772,
         5.4772,  5.4772,  4.6592,  3.9794,  3.3851,  2.8795,  2.0921,  1.7797],
       device='cuda:0')
torch.Size([120, 16])


Let's see if the embeddings are actually learning representations. We can generate a $(group.order, group.index^2)$ tensor encoding the natural representation, or a $(group.order, (group.index-1)^2)$ encoding the standard representation. The embeddings are $(group.order, embed\_dim)$ tensors, which are higher dimensional than the representations. We've already shown the embeddings are sparse in the singular basis, but let's now show the dimensions actually line up.

In [74]:
def projection_matrix_general(B):
    """Compute the projection matrix onto the space spanned by the columns of `B`
    Args:
        B: ndarray of dimension (D, M), the basis for the subspace
    
    Returns:
        P: the projection matrix
    """
    P = B @ (B.T @ B).inverse() @ B.T
    return P

P_natural = projection_matrix_general(natural_reps_mod)
assert(torch.allclose(P_natural, P_natural@P_natural, atol=1e-6))
assert(torch.allclose(P_natural, P_natural.T, atol=1e-6))

P_standard = projection_matrix_general(group.standard_reps_orth)
assert(torch.allclose(P_standard, P_standard@P_standard, atol=1e-6))
assert(torch.allclose(P_standard, P_standard.T, atol=1e-6))

P_standard_sign = projection_matrix_general(group.standard_sign_reps_orth)
assert(torch.allclose(P_standard_sign, P_standard_sign@P_standard_sign, atol=1e-6))
assert(torch.allclose(P_standard_sign, P_standard_sign.T, atol=1e-6))



proj_a_natural = P_natural @ model.W_x
proj_b_natural = P_natural @ model.W_y
proj_U_natural = P_natural @ model.W_U.T

proj_a_standard = P_standard @ model.W_x
proj_b_standard = P_standard @ model.W_y
proj_U_standard = P_standard @ model.W_U.T

proj_a_standard_sign = P_standard_sign @ model.W_x
proj_b_standard_sign = P_standard_sign @ model.W_y
proj_U_standard_sign = P_standard_sign @ model.W_U.T

print(f'W_x frobenius norm: {model.W_x.pow(2).sum()}')
print(f'Projection onto natural representation {proj_a_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_a_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_a_standard_sign.pow(2).sum()}')
print(f'W_y frobenius norm: {model.W_y.pow(2).sum()}')
print(f'Projection onto natural representation {proj_b_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_b_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_b_standard_sign.pow(2).sum()}')
print(f'W_U frobenius norm: {model.W_U.pow(2).sum()}')
print(f'Projection onto natural representation {proj_U_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_U_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_U_standard_sign.pow(2).sum()}')




W_x frobenius norm: 922.520751953125
Projection onto natural representation 836.4361572265625
Projection onto standard representation 836.43359375
Projection onto standard sign representation 0.05503452569246292
W_y frobenius norm: 922.8448486328125
Projection onto natural representation 836.846923828125
Projection onto standard representation 836.8443603515625
Projection onto standard sign representation 0.05195031315088272
W_U frobenius norm: 1544.8846435546875
Projection onto natural representation 1394.6727294921875
Projection onto standard representation 1376.19775390625
Projection onto standard sign representation 0.48662275075912476


Let's look at the matmul of the embedding and linear layer too.

In [75]:
embed_dim = model.W_x.shape[1]
unfactored_x = model.W_x @ model.W[:embed_dim, :]
unfactored_y =  model.W_y  @ model.W[embed_dim:, :]
proj_a_natural = P_natural @ unfactored_x
proj_b_natural = P_natural @ unfactored_y
proj_U_natural = P_natural @ model.W_U.T  

proj_a_standard = P_standard @ unfactored_x
proj_b_standard = P_standard @ unfactored_y
proj_U_standard = P_standard @ model.W_U.T 

proj_a_standard_sign = P_standard_sign @ unfactored_x
proj_b_standard_sign = P_standard_sign @ unfactored_y
proj_U_standard_sign = P_standard_sign @ model.W_U.T

print(f'W_x frobenius norm: {unfactored_x.pow(2).sum()}')
print(f'Projection onto natural representation {proj_a_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_a_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_a_standard_sign.pow(2).sum()}')
print(f'W_y frobenius norm: {unfactored_y.pow(2).sum()}')
print(f'Projection onto natural representation {proj_b_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_b_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_b_standard_sign.pow(2).sum()}')
print(f'W_U frobenius norm: {model.W_U.pow(2).sum()}')
print(f'Projection onto natural representation {proj_U_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_U_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_U_standard_sign.pow(2).sum()}')


W_x frobenius norm: 98638.125
Projection onto natural representation 91630.59375
Projection onto standard representation 91630.578125
Projection onto standard sign representation 1.8591350453789346e-05
W_y frobenius norm: 98640.140625
Projection onto natural representation 91632.328125
Projection onto standard representation 91632.3125
Projection onto standard sign representation 1.9870731193805113e-05
W_U frobenius norm: 1544.8846435546875
Projection onto natural representation 1394.6727294921875
Projection onto standard representation 1376.19775390625
Projection onto standard sign representation 0.48662275075912476


Is it using a combination of the natural and standard representations? Is it using some combination of the natural, standard and sign reprsentatations? 

What even is the combined rank of all the representations? How orthogonal are they? They are all precisely orthogonal, which is good.

S5 doesnt seem to use the product standard sign rep at all?

In [76]:
reps = [group.standard_reps_orth, group.standard_sign_reps_orth, group.trivial_reps_orth, group.sign_reps_orth]
for i in range(len(reps)):
    for j in range(i+1, len(reps)):
        rep1 = reps[i]
        rep2 = reps[j]
        print((rep1.T@rep2).pow(2).sum())

tensor(1.3128e-13, device='cuda:0')
tensor(6.1134e-13, device='cuda:0')
tensor(5.3533e-13, device='cuda:0')
tensor(5.3533e-13, device='cuda:0')
tensor(6.1134e-13, device='cuda:0')
tensor(0., device='cuda:0')


In [77]:
all_reps = torch.cat([group.standard_reps.reshape(group.order, -1), group.standard_sign_reps.reshape(group.order, -1)], dim=1)
print(all_reps.shape)

u,s,v = torch.linalg.svd(all_reps)
print(s)

torch.Size([120, 32])
tensor([16.8572, 16.8572, 14.3396, 14.3396, 10.4183, 10.4183,  8.8623,  8.8623,
         7.5388,  7.5388,  6.4389,  6.4389,  5.4772,  5.4772,  5.4772,  5.4772,
         5.4772,  5.4772,  5.4772,  5.4772,  4.6592,  4.6592,  3.9794,  3.9794,
         3.3851,  3.3851,  2.8796,  2.8795,  2.0921,  2.0921,  1.7797,  1.7797],
       device='cuda:0')


In [78]:
q = torch.linalg.qr(all_reps)[0]
P_all = projection_matrix_general(q)
assert(torch.allclose(P_all, P_all@P_all, atol=1e-6))
assert(torch.allclose(P_all, P_all.T, atol=1e-6))
proj_a_all = P_all @ model.W_x
proj_b_all = P_all @ model.W_y
proj_U_all = P_all @ model.W_U.T

print(f'W_x frobenius norm: {model.W_x.pow(2).sum()}')
print(f'Projection onto representation {proj_a_all.pow(2).sum()}')
print(f'W_y frobenius norm: {model.W_y.pow(2).sum()}')
print(f'Projection onto representation {proj_b_all.pow(2).sum()}')
print(f'W_U frobenius norm: {model.W_U.pow(2).sum()}')
print(f'Projection onto representation {proj_U_all.pow(2).sum()}')

u,s,v = torch.linalg.svd(all_reps)
P_all = projection_matrix_general(u[:, :18])
assert(torch.allclose(P_all, P_all@P_all, atol=1e-6))
assert(torch.allclose(P_all, P_all.T, atol=1e-6))
proj_a_all = P_all @ model.W_x
proj_b_all = P_all @ model.W_y
proj_U_all = P_all @ model.W_U.T

print(f'W_x frobenius norm: {model.W_x.pow(2).sum()}')
print(f'Projection onto representation {proj_a_all.pow(2).sum()}')
print(f'W_y frobenius norm: {model.W_y.pow(2).sum()}')
print(f'Projection onto representation {proj_b_all.pow(2).sum()}')
print(f'W_U frobenius norm: {model.W_U.pow(2).sum()}')
print(f'Projection onto representation {proj_U_all.pow(2).sum()}')



W_x frobenius norm: 922.520751953125
Projection onto representation 836.4886474609375
W_y frobenius norm: 922.8448486328125
Projection onto representation 836.8963012695312
W_U frobenius norm: 1544.8846435546875
Projection onto representation 1376.684326171875
W_x frobenius norm: 922.520751953125
Projection onto representation 590.4744873046875
W_y frobenius norm: 922.8448486328125
Projection onto representation 343.4208984375
W_U frobenius norm: 1544.8846435546875
Projection onto representation 771.437744140625


In [177]:
u, s, v = torch.linalg.svd(unfactored_x)
print(s)

tensor([1.2942e+02, 1.2764e+02, 1.2611e+02, 1.2509e+02, 8.4624e+01, 8.3712e+01,
        8.3519e+01, 8.1091e+01, 7.9507e+01, 2.6003e-02, 2.3604e-02, 1.9426e-02,
        1.1965e-02, 6.5083e-03, 3.1808e-03, 2.9643e-03, 2.6949e-03, 2.5449e-03,
        2.2842e-03, 2.0893e-03, 1.9290e-03, 1.7134e-03, 1.7075e-03, 1.6130e-03,
        1.5375e-03, 1.4819e-03, 1.4102e-03, 1.3262e-03, 1.3051e-03, 1.2371e-03,
        1.2120e-03, 1.1246e-03, 1.0774e-03, 1.0263e-03, 9.8877e-04, 9.3415e-04,
        8.7332e-04, 8.2646e-04, 8.0076e-04, 7.7537e-04, 7.4592e-04, 7.3031e-04,
        7.0730e-04, 6.6791e-04, 6.2660e-04, 6.2035e-04, 6.0258e-04, 5.6502e-04,
        5.5442e-04, 5.4421e-04, 5.1126e-04, 4.8851e-04, 4.6118e-04, 4.4546e-04,
        4.3492e-04, 4.1673e-04, 4.0156e-04, 3.7691e-04, 3.5005e-04, 3.3972e-04,
        3.2663e-04, 3.1963e-04, 3.1669e-04, 2.9690e-04, 2.8279e-04, 2.6633e-04,
        2.6012e-04, 2.5227e-04, 2.3989e-04, 2.2389e-04, 2.1151e-04, 1.9241e-04,
        1.9030e-04, 1.7695e-04, 1.7193e-

 ### Variance
 
Make sure that each of the representations actually uses all the information in equal parts from each of the dim^2 representation matrix elements. Do this by checking that the standard deviation of percentage contributions of each column is small.

In [79]:
# method 2
rep = group.standard_reps_orth.reshape(group.order, -1)
rep = group.natural_reps.reshape(group.order, -1)
dims = rep.shape[1]


total_norm_x = unfactored_x.pow(2).sum()
contsx = []
for i in range(dims):
    x = rep[:, i].unsqueeze(-1)
    P = projection_matrix_general(x)
    proj_x_standard = P @ unfactored_x
    proj_x_standard_square = proj_x_standard.pow(2)
    contsx.append((proj_x_standard_square.sum() / total_norm_x).item())
contsx = torch.tensor(contsx)
print(f'conts: {contsx}')
print(f'std: {contsx.std()}')
print(f'sum: {contsx.sum()}')
line(contsx)

total_norm_y = unfactored_y.pow(2).sum()
contsy = []
for i in range(dims):
    y = rep[:, i].unsqueeze(-1)
    P = projection_matrix_general(y)
    proj_y_standard = P @ unfactored_y
    proj_y_standard_square = proj_y_standard.pow(2)
    contsy.append((proj_y_standard_square.sum() / total_norm_y).item())
contsy = torch.tensor(contsy)
print(f'conts: {contsy}')
print(f'std: {contsy.std()}')
print(f'sum: {contsy.sum()}')
line(contsy)

conts = [] 
total_norm_U = model.W_U.pow(2).sum()
for i in range(dims):
    y = rep[:, i].unsqueeze(-1)
    P = projection_matrix_general(y)
    proj_U = P @ model.W_U.T 
    proj_U_square = proj_U.pow(2)
    conts.append((proj_U_square.sum() / total_norm_U).item())
conts = torch.tensor(conts)
print(f'conts: {conts}')
print(f'std: {conts.std()}')
print(f'sum: {conts.sum()}')
line(conts)

conts: tensor([0.0607, 0.0400, 0.0105, 0.0833, 0.0357, 0.0631, 0.0387, 0.0109, 0.0863,
        0.0351, 0.0594, 0.0422, 0.0103, 0.0816, 0.0378, 0.0595, 0.0422, 0.0104,
        0.0814, 0.0386, 0.0618, 0.0399, 0.0108, 0.0845, 0.0365])
std: 0.024995392188429832
sum: 1.1611964702606201


conts: tensor([0.0604, 0.0619, 0.0622, 0.0595, 0.0606, 0.0413, 0.0396, 0.0396, 0.0411,
        0.0415, 0.0105, 0.0107, 0.0108, 0.0103, 0.0106, 0.0826, 0.0847, 0.0850,
        0.0818, 0.0829, 0.0375, 0.0357, 0.0363, 0.0365, 0.0378])
std: 0.024974582716822624
sum: 1.1611945629119873


conts: tensor([0.0469, 0.0469, 0.0473, 0.0489, 0.0481, 0.0465, 0.0460, 0.0457, 0.0474,
        0.0464, 0.0467, 0.0466, 0.0470, 0.0469, 0.0469, 0.0462, 0.0474, 0.0476,
        0.0463, 0.0460, 0.0463, 0.0479, 0.0462, 0.0483, 0.0471])
std: 0.0007871885318309069
sum: 1.1733055114746094


## Understanding the hidden layer

Under the hypothesis that the network is embedding as representations, We should expect the hidden layer to contain terms like "rho(x)rho(y)".

Consider a model run on all the (group.order * group.order) inputs. The hidden layer is then a (group.order * group.order, hidden) tensor. 

Assuming some representation rho, we should expect to find a (group.order * group.order, rep_dim, rep_dim) = (group.order * group.order, rep_dim * rep_dim) tensor, with rep_dim << hidden. 

We can perform a similar projection trick here as we did on the embeddings.

We center this hidden neurons. I'm actually not sure how important this is.

In [80]:
hidden = activations['hidden']
hidden = hidden.reshape(group.order*group.order, -1)
mean = hidden.mean(dim=-1, keepdim=True)
hidden_centred = hidden - mean
u, s, v = torch.linalg.svd(hidden_centred)
#line(s)

logits = hidden @ model.W_U
print(hidden.shape, logits.shape)

loss = loss_fn(logits, all_labels)

centred_logits = hidden_centred @ model.W_U
centred_loss = loss_fn(centred_logits, all_labels)

print(f'Loss: {loss}')
print(f'Centred Hidden loss: {centred_loss}')

# i think this is close enough?

torch.Size([14400, 128]) torch.Size([14400, 120])
Loss: 8.072951800386363e-07
Centred Hidden loss: 1.2235090025569662e-06


In [81]:
den = hidden_centred.pow(2).sum().item()
reps = {'trivial': group.trivial_reps, 
'sign': group.sign_reps, 
'standard': group.standard_reps, 
'standard_sign' : group.standard_sign_reps}

total = 0
for name, rep in reps.items():
    hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
    hidden_reps_x = rep[all_data[:, 0]].reshape(group.order*group.order, -1)
    hidden_reps_y = rep[all_data[:, 1]].reshape(group.order*group.order, -1)
    hidden_reps_xy = torch.linalg.qr(hidden_reps_xy)[0]
    hidden_reps_x = torch.linalg.qr(hidden_reps_x)[0]
    hidden_reps_y = torch.linalg.qr(hidden_reps_y)[0]
    P_xy = projection_matrix_general(hidden_reps_xy)
    P_x = projection_matrix_general(hidden_reps_x) 
    P_y = projection_matrix_general(hidden_reps_y)
    proj_xy = P_xy @ hidden_centred
    u,s,v = torch.linalg.svd(proj_xy)
    proj_x = P_x @ hidden_centred
    proj_y = P_y @ hidden_centred
    num_xy = proj_xy.pow(2).sum()
    num_x = proj_x.pow(2).sum()
    num_y = proj_y.pow(2).sum()
    percent_x = num_x / den
    percent_y = num_y / den
    percent_xy = num_xy / den
    total += percent_x
    total += percent_y
    total += percent_xy
    
    print(f'{name}, x: {percent_x}, y: {percent_y}, xy: {percent_xy}')
print(f'Total percent explained by reps: {total}')

trivial, x: 0.11319603770971298, y: 0.11319603770971298, xy: 0.11319603770971298
sign, x: 0.02269669435918331, y: 0.022697651758790016, xy: 0.022696176543831825
standard, x: 0.3253755569458008, y: 0.3145125210285187, xy: 0.09184902906417847
standard_sign, x: 7.539092128361347e-10, y: 7.658595424508974e-10, xy: 2.1627967472515053e-10
Total percent explained by reps: 1.1394157409667969


In [183]:
torch.dot(hidden_reps_x[:, 0], hidden_reps_y[:, 1])

tensor(-2.2352e-08, device='cuda:0')

Claim: The output logits should not depend on the directions corresponding to x or y, but only those corresponding to xy. Let's test this. To do so, we first just ablate those directions and see what happens to loss. If we ablate the xy direction we should see loss tank.


In [82]:
logits = hidden @ model.W_U
loss = loss_fn(logits, all_labels)
print(f'baseline loss: {loss}')


for name, rep in reps.items():
    hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
    hidden_reps_x = rep[all_data[:, 0]].reshape(group.order*group.order, -1)
    hidden_reps_y = rep[all_data[:, 1]].reshape(group.order*group.order, -1)
    hidden_reps_xy = torch.linalg.qr(hidden_reps_xy)[0]
    hidden_reps_x = torch.linalg.qr(hidden_reps_x)[0]
    hidden_reps_y = torch.linalg.qr(hidden_reps_y)[0]
    P_xy = projection_matrix_general(hidden_reps_xy)
    P_x = projection_matrix_general(hidden_reps_x) 
    P_y = projection_matrix_general(hidden_reps_y)
    proj_xy = P_xy @ hidden_centred
    hidden_centred_xy = hidden_centred - proj_xy
    proj_x = P_x @ hidden_centred
    hidden_centred_x = hidden_centred - proj_x
    proj_y = P_y @ hidden_centred
    hidden_centred_y = hidden_centred - proj_y

    logits_xy = hidden_centred_xy @ model.W_U
    loss_xy = loss_fn(logits_xy, all_labels)
    logits_x = hidden_centred_x @ model.W_U
    loss_x = loss_fn(logits_x, all_labels)
    logits_y = hidden_centred_y @ model.W_U
    loss_y = loss_fn(logits_y, all_labels)

    print(f'Ablating directions corresponding to {name} rep loss, xy: {loss_xy}, x: {loss_x}, y: {loss_y}')
    

baseline loss: 8.072951800386363e-07
Ablating directions corresponding to trivial rep loss, xy: 8.28512781936297e-07, x: 8.28512781936297e-07, y: 8.28512781936297e-07
Ablating directions corresponding to sign rep loss, xy: 0.0009347845334559679, x: 1.2203962569401483e-06, y: 1.219121372741938e-06
Ablating directions corresponding to standard rep loss, xy: 8.873103141784668, x: 1.1757592801586725e-06, y: 1.1799315871030558e-06
Ablating directions corresponding to standard_sign rep loss, xy: 1.2235090025569662e-06, x: 1.2234841051395051e-06, y: 1.2234759196871892e-06


We show that the the spaces corresponding to x and y representations in the hidden layer are in the kernel of W_U.

In [83]:
print('Mean norms of unembedded representation vectors')
for name, rep in reps.items():
    hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
    hidden_reps_x = rep[all_data[:, 0]].reshape(group.order*group.order, -1)
    hidden_reps_y = rep[all_data[:, 1]].reshape(group.order*group.order, -1)
    
    P_xy = projection_matrix_general(hidden_reps_xy)
    P_x = projection_matrix_general(hidden_reps_x) 
    P_y = projection_matrix_general(hidden_reps_y)
    proj_xy = P_xy @ hidden_centred
    proj_x = P_x @ hidden_centred
    proj_y = P_y @ hidden_centred

    unembed_xy = proj_xy @ model.W_U
    unembed_x = proj_x @ model.W_U
    unembed_y = proj_y @ model.W_U

    print(f'{name} rep, xy: {unembed_xy.pow(2).mean()}, x: {unembed_x.pow(2).mean()}, y: {unembed_y.pow(2).mean()}')


Mean norms of unembedded representation vectors
trivial rep, xy: 3.7066051959991455, x: 3.7066051959991455, y: 3.7066051959991455
sign rep, xy: 17.681011199951172, x: 0.0025135576725006104, y: 0.0028997936751693487
standard rep, xy: 41.64058303833008, x: 0.19304227828979492, y: 0.10517318546772003
standard_sign rep, xy: 1.3805200038685683e-10, x: 9.253368737915935e-09, y: 1.1530609356213972e-08


### Variance

Are some matrix elements of xy in the standard representation preferred over others in the hidden layer? Yes. I claim this is actually pretty deep.

Let $e_i$ be the standard basis matrices. Let $\mathbb{x} := \lambda_i x_i \mathbb{e}_i$. Let $\mathbb{y} := \delta_i y_i \mathbb{e}_i$ be such that the mean of $x_i$ and mean of $y_i$ is one over all inputs for fixed i, i.e. $\lambda_i,\delta_i$ encode an i-dependent scaling factor that we calcualate below. Then consider the matrix product $\mathbb{x}\mathbb{y} = \mathbb{z}$ = $z_i \mathbb{e}_i$. 

Then assuming 5d matrices, eg $z_0 = \lambda_0 \delta_0 x_0 y_0 + \lambda_1 \delta_5 x_1 y_5 + \lambda_2 \delta_{10} x_2 y_{10} + \lambda_3 \delta_15 x_3 y_{15} + \lambda_4 \delta_{20} x_4 y_{20}$.

This only makes sense dimensionally if $\delta_0 = \delta_5 = \delta_{10} = \delta_{15} = \delta_{20}$ and $\lambda_0 = \lambda_1 = \lambda_2 = \lambda_3 = \lambda_4$.



In [84]:
rep = group.standard_reps_orth.reshape(group.order, -1)
print(rep)
dims = rep.shape[1]

hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
hidden_reps_x = rep[all_data[:, 0]].reshape(group.order*group.order, -1)
hidden_reps_y = rep[all_data[:, 1]].reshape(group.order*group.order, -1)

#P_xy = projection_matrix_general(hidden_reps_xy)
#P_x = projection_matrix_general(hidden_reps_x) 
#P_y = projection_matrix_general(hidden_reps_y)
#proj_xy = P_xy @ hidden_centred
#proj_x = P_x @ hidden_centred
#
#proj_y = P_y @ hidden_centred

total_norm = hidden_centred.pow(2).sum()

conts = []
for i in range(dims):
    a = hidden_reps_x[:, i].unsqueeze(-1)
    P_x = projection_matrix_general(a)
    proj = P_x @ hidden_centred
    proj_square = proj.pow(2)
    conts.append((proj_square.sum() / total_norm).item())
conts = torch.tensor(conts)
print(f'x conts: {conts}')
print(f'std: {conts.std()}')
print(f'sum: {conts.sum()}')
line(conts)

conts = []
for i in range(dims):
    a = hidden_reps_y[:, i].unsqueeze(-1)
    P_y = projection_matrix_general(a)
    proj = P_y @ hidden_centred
    proj_square = proj.pow(2)
    conts.append((proj_square.sum() / total_norm).item())
conts = torch.tensor(conts)
print(f'y conts: {conts}')
print(f'std: {conts.std()}')
print(f'sum: {conts.sum()}')
line(conts)


conts = []
for i in range(dims):
    a = hidden_reps_xy[:, i].unsqueeze(-1)
    P_xy = projection_matrix_general(a)
    proj = P_xy @ hidden_centred
    proj_square = proj.pow(2)
    conts.append((proj_square.sum() / total_norm).item())
conts = torch.tensor(conts)
print(f'xy conts: {conts}')
print(f'std: {conts.std()}')
print(f'sum: {conts.sum()}')
line(conts)

tensor([[-1.4434e-01,  1.1180e-01, -2.6202e-08,  ..., -5.8926e-02,
          8.3333e-02, -1.4434e-01],
        [ 0.0000e+00, -1.4907e-01,  1.0541e-01,  ..., -5.8926e-02,
          8.3333e-02,  1.4434e-01],
        [ 0.0000e+00,  0.0000e+00, -1.5811e-01,  ...,  4.8429e-08,
         -1.8626e-09,  2.2352e-08],
        ...,
        [ 0.0000e+00,  1.4907e-01, -1.0541e-01,  ..., -5.8926e-02,
          8.3333e-02, -1.4434e-01],
        [ 0.0000e+00,  0.0000e+00,  1.5811e-01,  ..., -5.8926e-02,
          8.3333e-02,  1.4434e-01],
        [ 1.4434e-01,  3.7268e-02,  5.2705e-02,  ...,  1.7678e-01,
          3.7253e-09,  5.5879e-09]], device='cuda:0')
x conts: tensor([0.0258, 0.0236, 0.0050, 0.0281, 0.0241, 0.0240, 0.0051, 0.0285, 0.0241,
        0.0240, 0.0052, 0.0290, 0.0259, 0.0206, 0.0050, 0.0275])
std: 0.009359377436339855
sum: 0.3253755569458008


y conts: tensor([0.0163, 0.0153, 0.0096, 0.0118, 0.0155, 0.0162, 0.0162, 0.0158, 0.0342,
        0.0381, 0.0340, 0.0354, 0.0109, 0.0123, 0.0173, 0.0156])
std: 0.009695964865386486
sum: 0.3145125210285187


xy conts: tensor([0.0058, 0.0058, 0.0058, 0.0057, 0.0056, 0.0058, 0.0058, 0.0056, 0.0057,
        0.0057, 0.0058, 0.0057, 0.0057, 0.0058, 0.0058, 0.0058])
std: 6.428975029848516e-05
sum: 0.09184902906417847


## Neuron Clustering

Before the ReLU, do the nurons cluster neatly into representations?

In [163]:
embed_dim = model.W_x.shape[1]
unfactored_x = model.W_x @ model.W[:embed_dim, :]
unfactored_x_summed = unfactored_x.pow(2).sum(dim=0)
off_neurons_x = (unfactored_x_summed < 1e-5).nonzero().squeeze()

unfactored_y = model.W_y @ model.W[embed_dim:, :]
unfactored_y_summed = unfactored_y.pow(2).sum(dim=0)
off_neurons_y = (unfactored_y_summed < 1e-5).nonzero().squeeze()


assert (off_neurons_x == off_neurons_y).all()

off_neurons = off_neurons_x

print(f'Off neurons: {off_neurons}')


P_sign = projection_matrix_general(group.sign_reps_orth)
P_standard = projection_matrix_general(group.standard_reps_orth)

proj_x_sign = P_sign @ unfactored_x
proj_x_sign_summed = proj_x_sign.pow(2).sum(dim=0)
proj_x_standard = P_standard @ unfactored_x
proj_x_standard_summed = proj_x_standard.pow(2).sum(dim=0)

sign_neurons_x = (proj_x_sign_summed > 1e-5).nonzero().squeeze()
standard_neurons_x = (proj_x_standard_summed > 1e-5).nonzero().squeeze()

proj_y_sign = P_sign @ unfactored_y
proj_y_sign_summed = proj_y_sign.pow(2).sum(dim=0)
proj_y_standard = P_standard @ unfactored_y
proj_y_standard_summed = proj_y_standard.pow(2).sum(dim=0)

sign_neurons_y = (proj_y_sign_summed > 1e-5).nonzero().squeeze()
standard_neurons_y = (proj_y_standard_summed > 1e-5).nonzero().squeeze()

assert (sign_neurons_x == sign_neurons_y).all()
assert (standard_neurons_x == standard_neurons_y).all()

sign_neurons = sign_neurons_x
standard_neurons = standard_neurons_x

print(f'Sign neurons: {sign_neurons}')
print(f'Standard neurons: {standard_neurons}')

all_neurons = torch.arange(model.W.shape[1])
unaccounted_neurons = set(all_neurons.tolist()) - set(off_neurons.tolist()) - set(sign_neurons.tolist()) - set(standard_neurons.tolist())

print(f'Unaccounted neurons: {unaccounted_neurons}')



Off neurons: tensor([  1,   8,  13,  14,  35,  37,  56,  64,  68,  77,  93,  95, 108, 114],
       device='cuda:0')
Sign neurons: tensor([ 24,  34,  58,  61,  66,  73,  79,  81,  85, 107], device='cuda:0')
Standard neurons: tensor([  0,   2,   3,   4,   5,   6,   7,   9,  10,  11,  12,  15,  16,  17,
         18,  19,  20,  21,  22,  23,  25,  26,  27,  28,  29,  30,  31,  32,
         33,  34,  36,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,
         49,  50,  51,  52,  53,  54,  55,  57,  59,  60,  62,  63,  65,  67,
         69,  70,  71,  72,  74,  75,  76,  78,  79,  80,  82,  83,  84,  86,
         87,  88,  89,  90,  91,  92,  94,  96,  97,  98,  99, 100, 101, 102,
        103, 104, 105, 106, 107, 109, 110, 111, 112, 113, 115, 116, 117, 118,
        119, 120, 121, 122, 123, 124, 125, 126, 127], device='cuda:0')
Unaccounted neurons: set()


## Signature Representation Circuit Analysis

TLDR: the network memorises input signatures, mangles them together using the relu, and uses this to divide the output search space in two.


We know the S5 network uses the sign representation to some extent. Let's find this in the x and y embedding in the first instance. We'll later want to find this in the hidden layer too, and in the logits. 

We find in our run, (1L_MLP_sym_S5_cached_3), neurons 24, 66, 73, and 81 have high sign representation projection. Let's now print inspect these neurons. It's curious there are 4 of these.


In [85]:
embed_dim = model.W_x.shape[1]
unfactored_x = model.W_x @ model.W[:embed_dim, :]
P_sign = projection_matrix_general(group.sign_reps_orth)
proj_x_sign = P_sign @ unfactored_x
proj_x_sign_summed = proj_x_sign.pow(2).sum(dim=0)
print(torch.topk(proj_x_sign_summed, 10))
signature_neurons_x = torch.topk(proj_x_sign_summed, 4).indices


torch.return_types.topk(
values=tensor([1.7570e+03, 1.7540e+03, 1.7494e+03, 1.7472e+03, 9.6021e-03, 5.0839e-03,
        3.1622e-03, 7.8594e-04, 5.9167e-05, 5.3620e-05], device='cuda:0',
       grad_fn=<TopkBackward0>),
indices=tensor([ 24,  66,  73,  81,  58,  61,  85,  34, 107,  79], device='cuda:0'))


In [86]:
embed_dim = model.W_y.shape[1]
unfactored_y = model.W_y @ model.W[embed_dim:, :]
P_sign = projection_matrix_general(group.sign_reps_orth)
proj_y_sign = P_sign @ unfactored_y
proj_y_sign_summed = proj_y_sign.pow(2).sum(dim=0)
print(torch.topk(proj_y_sign_summed, 10))
signature_neurons_y = torch.topk(proj_y_sign_summed, 4).indices


torch.return_types.topk(
values=tensor([1.7570e+03, 1.7541e+03, 1.7495e+03, 1.7472e+03, 9.6025e-03, 5.0842e-03,
        3.1622e-03, 7.8592e-04, 5.9173e-05, 5.3615e-05], device='cuda:0',
       grad_fn=<TopkBackward0>),
indices=tensor([ 24,  66,  73,  81,  58,  61,  85,  34, 107,  79], device='cuda:0'))


In [87]:
assert (signature_neurons_x == signature_neurons_y).all()
signature_neurons = signature_neurons_x


In [88]:
print(signature_neurons)
sig_labels = [str(x) for x in signature_neurons.tolist()]

tensor([24, 66, 73, 81], device='cuda:0')


The first row is the sign represntation. The next four rows are the x embeddings. The next image is the same but for y embeddings. Let a = 3.8. We learn that (for the same cached S5 model)

neuron 24 = $a ReLU (-sign(x) + sign(y)) > 0$ iff x odd and y even ie z odd

neuron 66 = $a ReLU (-sign(x) - sign(y)) > 0$ iff x odd and y odd ie z even

neuron 73 = $a ReLU (sign(x) - sign(y)) > 0$ iff x even and y odd ie z odd

neuron 81 = $a ReLU (sign(x) + sign(y)) > 0$ iff x even and y even ie z even



In 1L_MLP_cached_S4 we find

neuron 3 = $a ReLU (-sign(x) - sign(y)) > 0$ iff x odd and y odd ie z even

neuron 38 = $a ReLU (-sign(x) - sign(y)) > 0$ iff x odd and y odd ie z even

neuron 6 = $a ReLU (sign(x) + sign(y)) > 0$ iff x even and y even ie z even

neuron 29 = $a ReLU (sign(x) + sign(y)) > 0$ iff x even and y even ie z even


In [89]:
sigs = group.signatures.unsqueeze(-1)
xs = unfactored_x[:, signature_neurons]
stack = torch.hstack([sigs, xs]).T
imshow(stack, y=['sig'] + sig_labels, input2='input group element')
ys = unfactored_y[:, signature_neurons]
stack = torch.hstack([sigs, ys]).T
imshow(stack, y=['sig'] + sig_labels, input2='input group element')


How does this effect logits? We simply look at the relevant unembed rows. We expect rows 66 and 81 to correlate and 24 and 73 to anticorrelate with the (now output) signature. Thus the logits align with the trace formula.

In [90]:
sigs = group.signatures.unsqueeze(-1)
W_U_signatures = model.W_U[signature_neurons, :].T
#W_U_24 = model.W_U[24, :].unsqueeze(-1)
#W_U_66 = model.W_U[66, :].unsqueeze(-1)
#W_U_73 = model.W_U[73, :].unsqueeze(-1)
#W_U_81 = model.W_U[81, :].unsqueeze(-1)
#stack = torch.hstack([sigs, W_U_24, W_U_66, W_U_73, W_U_81]).T
stack = torch.hstack([sigs, W_U_signatures]).T

imshow(stack, y=['sig'] + sig_labels, input2='output group element')

## Standard Representation Circuit Analysis

We know the S5 network uses the standard representation too. Let's find this in the x and y embedding, and in the hidden layer

### Extract representations from embeddings.

Our unfactored embedding matrix has size (group.order, hidden). The representations can be encoded in a (group.order, dim^2) tensor. I want the coefficients on this dim^2 subspace for some given input group element. This requires identifying the linear map hidden -> dim^2. Can I just take embedding.T @ rep.

Let's try

Another approach is to use the singular value of the embedding as the coeffiecent, and the left singular vector as the indicator for whether that coefficient is used.

I'm likely overcomplicating this. I suspect all i need is a matrix multiply from an orthogonal set of vectors

### Janky method 3 (which i think is least janky)

In [91]:
rep = group.standard_reps.reshape(group.order, -1)
dims = rep.shape[1]
dim = int(math.sqrt(dims))
contsx = []

mats_x = []
mats_y = []



scalings = torch.ones(dims).cuda()
scalings[[7,4,9,12,15,6,14]] = -1


for idx in range(5):
    mat = []
    for i in range(dims):
        x = rep[:, i].unsqueeze(-1)
        P = projection_matrix_general(x)
        proj_x = P @ unfactored_x
        u, s, v = torch.svd(proj_x)
        mat.append(scalings[i]*u[idx, 0])
    mat = torch.tensor(mat)
    matx = mat.reshape(dim, dim)
    imshow(matx, title='network')
    imshow(rep[idx].reshape(dim, dim), title='real')
    mats_x.append(matx)

#for idx in range(3):
#    mat = []
#    for i in range(dims):
#        x = rep[:, i].unsqueeze(-1)
#        P = projection_matrix_general(x)
#        proj_y = P @ unfactored_y
#        u, s, v = torch.svd(proj_y)
#        mat.append(s[0]*u[idx, 0])
#    mat = torch.tensor(mat)
#    matx = mat.reshape(dim, dim)
#    imshow(matx, title='network')
#    imshow(rep[idx].reshape(dim, dim), title='real')
#    mats_y.append(matx)


### Look at the hidden layer

In [92]:
rep = group.standard_reps.reshape(group.order, -1)
dims = rep.shape[1]
dim = int(math.sqrt(dims))

hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)

hidden = activations['hidden'].reshape(group.order*group.order, -1)


# mats_xy = []
# for idx in range(2):
#     mat = []
#     for i in range(dims):
#         a = hidden_reps_xy[:, i].unsqueeze(-1)
#         P_xy = projection_matrix_general(a)
#         proj = P_xy @ hidden
#         u, s, v = torch.svd(proj)
#         mat.append(s[0]*u[idx, 0])
#     mat = torch.tensor(mat)
#     mat = mat.reshape(dim, dim)
#     imshow(mat, title='network hidden xy')
#     m_x = mats_x[all_data[idx, 0]]
#     m_y = mats_y[all_data[idx, 1]]
#     imshow(m_x @ m_y, title='network x @ y')
#     imshow(hidden_reps_xy[idx].reshape(dim, dim), title='real hidden xy')


In [93]:
# probably wrong

rep = group.standard_reps.reshape(group.order, -1)
W_U = model.W_U
hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
hidden = activations['hidden'].reshape(group.order*group.order, -1)

hidden_to_reps = hidden_reps_xy.T @ hidden

print(hidden_to_reps.shape)

hidden_in_reps = hidden @ hidden_to_reps.T

print(hidden_in_reps.shape)
imshow(hidden_in_reps[:50])


torch.Size([16, 128])
torch.Size([14400, 16])


### Logit Computation

Try and understand hidden -> logit computation. Should look at the standard representation basis here. We expect the map to be v -> tr (v \rho(z^-1)). Note that some neurons seem completely off in this basis. I assert these are the ones in completely different representations. The signature neurons are notably absent. Others that are absent are the neurons which are just not used.

The below plot can be produced as a line for the signature representation too!


In [179]:
rep = group.standard_reps.reshape(group.order, -1)
#rep = torch.randn_like(rep)
W_U = model.W_U
hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
hidden = activations['hidden'].reshape(group.order*group.order, -1)

hidden_to_reps = hidden_reps_xy.T @ hidden
hidden_to_reps_old = hidden_to_reps

print(hidden_to_reps.shape)
imshow(hidden_to_reps, title='Change of basis from the standard basis on the hidden layer to that of the representation space')

W_U_rep = hidden_to_reps @ W_U

print(W_U_rep.shape)
imshow(W_U_rep, title='Unembedding matrix in the representation space')

torch.Size([1, 128])


ValueError: px.imshow only accepts 2D single-channel, RGB or RGBA images. An image of shape (128,) was provided. Alternatively, 3- or 4-D single or multichannel datasets can be visualized using the `facet_col` or/and `animation_frame` arguments.

In [95]:
imshow(rep.T, title='Output representation space')

It looks like the network is not actually inverting z. We see the final linear map in the representation space as being v -> tr (vz). This, combined with the observation that the matrix product seems the wrong way around, suggests the actual algorithm being used is 

$ \argmax ( tr (\rho (y^-1) \rho (x^-1) \rho (z)) ) $

To check this, let's first verify the space of inverse representations is the same as the space of representations, by projecting the inverse representations onto the representation space. This isn't surprising.


In [96]:
group.compute_inverse_reps()

reps_orth = group.standard_reps_orth
inverse_reps = group.standard_inverse_reps.reshape(group.order, -1)

P = projection_matrix_general(reps_orth)

print(f'Norm of inverse rep: {torch.norm(inverse_reps)}')

proj = P @ inverse_reps

print(f'Norm of projection of inverse rep: {torch.norm(proj)}')

Norm of inverse rep: 30.983867645263672
Norm of projection of inverse rep: 30.983867645263672


Now reproduce above plot indicating how much of embeddings are explained by each of the inverse representation matrix elements, for x and y. We see now x is constant on rows and y is constant on columns. This suggests it is indeed learning the representations \rho(x^-1) and \rho(y^-1), by our definition of rho. Note this is really the same representation, and just an implementation detail. z needs to be constant across all elements, as tr(* z) uses all elements of z equally.

In [97]:
# method 2
rep = group.standard_inverse_reps.reshape(group.order, -1)
dims = rep.shape[1]

contsx = []
for i in range(dims):
    x = rep[:, i].unsqueeze(-1)
    P = projection_matrix_general(x)
    proj_x_standard = P @ unfactored_x
    proj_x_standard_square = proj_x_standard.pow(2)
    contsx.append((proj_x_standard_square.sum() / total_norm_x).item())
contsx = torch.tensor(contsx)
print(f'conts: {contsx}')
print(f'std: {contsx.std()}')
print(f'sum: {contsx.sum()}')
line(contsx)

contsy = []
for i in range(dims):
    y = rep[:, i].unsqueeze(-1)
    P = projection_matrix_general(y)
    proj_y_standard = P @ unfactored_y
    proj_y_standard_square = proj_y_standard.pow(2)
    contsy.append((proj_y_standard_square.sum() / total_norm_y).item())
contsy = torch.tensor(contsy)
print(f'conts: {contsy}')
print(f'std: {contsy.std()}')
print(f'sum: {contsy.sum()}')
line(contsy)

conts = [] 
total_norm_U = model.W_U.pow(2).sum()
for i in range(dims):
    y = rep[:, i].unsqueeze(-1)
    P = projection_matrix_general(y)
    proj_U = P @ model.W_U.T 
    proj_U_square = proj_U.pow(2)
    conts.append((proj_U_square.sum() / total_norm_U).item())
conts = torch.tensor(conts)
print(f'conts: {conts}')
print(f'std: {conts.std()}')
print(f'sum: {conts.sum()}')
line(conts)

conts: tensor([0.0325, 0.0332, 0.0334, 0.0335, 0.0280, 0.0281, 0.0281, 0.0280, 0.0191,
        0.0192, 0.0191, 0.0193, 0.0730, 0.0742, 0.0740, 0.0750])
std: 0.0217499528080225
sum: 0.6176031231880188


conts: tensor([0.0765, 0.1139, 0.0771, 0.0458, 0.0780, 0.1143, 0.0764, 0.0447, 0.0761,
        0.1137, 0.0770, 0.0455, 0.0747, 0.1133, 0.0776, 0.0465])
std: 0.024946345016360283
sum: 1.2510803937911987


conts: tensor([0.0555, 0.0553, 0.0557, 0.0559, 0.0556, 0.0551, 0.0556, 0.0561, 0.0575,
        0.0558, 0.0558, 0.0560, 0.0581, 0.0559, 0.0560, 0.0566])
std: 0.0007709548226557672
sum: 0.8963699340820312


Now given we really are computing $\rho(y^{-1} x^{-1}) = \rho( (xy)^{-1} )$, let's recompute the change of basis for the hidden layer. 

We can either invert the inputs, or invert the representations. We elect to invert the rep.

In [98]:
inverse_rep = group.standard_inverse_reps.reshape(group.order, -1)

W_U = model.W_U

hidden_reps_y_inv_x_inv = inverse_rep[all_labels].reshape(group.order*group.order, -1)

hidden = activations['hidden'].reshape(group.order*group.order, -1)

hidden_to_reps = hidden_reps_y_inv_x_inv.T @ hidden

imshow(hidden_to_reps, title='Change of basis from the standard basis on the hidden layer to that of the inverse representation space')
imshow(hidden_to_reps_old, title='Change of basis from the standard basis on the hidden layer to that of the representation space')
W_U_rep = hidden_to_reps @ W_U

imshow(W_U_rep, title='Unembedding matrix in the representation space')

In [99]:
imshow(inverse_rep.T)

### Finding a more interpretable basis for the representation space

To interpret the hidden layer, which is the only basis dependent component of this circuit, we should try and find a nicer basis for it. it doesn't really matter whether we start with xy or (xy)^-1. The plot we produced above can be thought of as "how much does each neuron represent certain representation directions"


In [188]:
rep = group.standard_reps.reshape(group.order, -1)
#rep = torch.randn_like(rep)
W_U = model.W_U
hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
hidden = activations['hidden'].reshape(group.order*group.order, -1)
hidden_to_reps = hidden_reps_xy.T @ hidden

imshow(hidden_to_reps)

In [184]:
for i in range(0):
    line(hidden_to_reps[:, i])

In [194]:
rep = group.standard_reps.reshape(group.order, -1)
hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
P = projection_matrix_general(hidden_reps_xy)
hidden = activations['hidden'].reshape(group.order*group.order, -1)
hidden_xy = P @ hidden
print(hidden_xy.shape)
print(hidden_to_reps.shape)
hidden_in_rep = hidden_xy @ hidden_to_reps.T

imshow(hidden_in_rep[:10])
imshow(rep[all_labels].reshape(group.order*group.order, -1)[:10])

torch.Size([14400, 128])
torch.Size([16, 128])
