# Setup

In [34]:
!nvidia-smi 
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

Fri Dec 30 12:28:08 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Quadro P...  On   | 00000000:00:05.0  On |                  N/A |
| 46%   37C    P8     7W / 105W |   1940MiB /  8119MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Imports

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
import sage

import json
import pickle 
import copy

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *
from utils.config import *

In [36]:
if torch.cuda.is_available:
  print('Good to go!')
else:
  print('Things might be rather slow')

Good to go!


# Interpretability




## Interpretability Set Up 

In [37]:
task_dir = "experiments/BilinearNet_S5_cached" #1L_MLP_sym_S5_cached_3"
seed, frac_train, layers, lr, group_param, weight_decay, num_epochs, group_type, architecture_type, metric_cfg, metric_obj = load_cfg(task_dir)
group = group_type(group_param, init_all = True)
all_data, _, all_labels, _ = generate_train_test_data(group, frac_train = 1)
model = architecture_type(layers, group.order, seed).cuda()
model.load_state_dict(torch.load(f"{task_dir}/model.pt"), strict=False)
model.eval()
logits, activations = model.run_with_cache(all_data, return_cache_object=False)
activations['logits'] = logits
metric_obj = metric_obj(group, training=False, track_metrics = True)


Computing multiplication table...
... loading from file
Computing trace tensor cube for trivial representation
... loading from file
Computing trace tensor cube for sign representation
... loading from file
Computing trace tensor cube for standard representation
... loading from file
Computing trace tensor cube for standard_sign representation
... loading from file
Computing trace tensor cube for s5_5d_a representation
... loading from file
Computing trace tensor cube for s5_5d_b representation
... loading from file
Computing trace tensor cube for s5_6d representation
... loading from file


In [38]:
print(activations['hidden'])

tensor([[0., -0., 0.,  ..., -0., 0., -0.],
        [0., -0., -0.,  ..., 0., 0., 0.],
        [-0., -0., 0.,  ..., -0., -0., 0.],
        ...,
        [-0., 0., -0.,  ..., -0., 0., 0.],
        [-0., -0., 0.,  ..., 0., 0., 0.],
        [0., -0., 0.,  ..., -0., -0., -0.]], device='cuda:0')


In [39]:
print(model.parameters)

<bound method Module.parameters of BilinearNet(
  (hidden): HookPoint()
)>


## Visualise the Weights


In [40]:
imshow(model.W_x)
imshow(model.W_y)
if model.__class__.__name__=="OneLayerMLP":
    imshow(model.W)
imshow(model.W_U)

## Losses on various subsets


In [41]:
# everything
all_indices = np.arange(group.order)
all_loss = loss_fn(logits, all_labels)
all_loss

tensor(6.1661e-06, device='cuda:0', grad_fn=<NllLossBackward0>)

In [42]:
# alternating group
alternating_indices = [i for i in range(group.order) if group.signature(i) == 1]
alternating_data = group.get_subset_of_data(alternating_indices).cuda()
alternating_data, alternating_labels = alternating_data[:, :2], alternating_data[:, 2]
alternating_logits = model(alternating_data)
alternating_loss = loss_fn(alternating_logits, alternating_labels).item()
alternating_loss

6.154123184387572e-06

In [43]:
# the alternating coset
alt_coset_indices = [i for i in range(group.order) if group.signature(i) == -1]
alt_coset_data = group.get_subset_of_data(alt_coset_indices).cuda()
alt_coset_data, alt_coset_labels = alt_coset_data[:, :2], alt_coset_data[:, 2]
alt_coset_logits = model(alternating_data)
alt_coset_loss = loss_fn(alternating_logits, alternating_labels).item()
alt_coset_loss

6.154123184387572e-06

In [44]:
# alternating, coset
data = group.get_subset_of_data(alternating_indices, alt_coset_indices).cuda()
data, labels = data[:, :2], data[:, 2]
logits = model(data)
loss = loss_fn(logits, labels).item()
loss


6.193130502651911e-06

In [45]:
# coset, alternating
data = group.get_subset_of_data(alt_coset_indices, alternating_indices ).cuda()
data, labels = data[:, :2], data[:, 2]
logits = model(data)
loss = loss_fn(logits, labels).item()
loss


6.169785137899453e-06

In [46]:
# line plot of losses on losses of all elements vs all others
losses=[]
for x in range(group.order):
    data = group.get_subset_of_data([x], all_indices).cuda()
    data, labels = data[:, :2], data[:, 2]
    logits = model(data)
    loss = loss_fn(logits, labels).item()
    losses.append(loss)
line(losses, log_y=True)


elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison



## Look at activation patterns

Let's first look at the relevant activations (the hidden and output layers) in the standard basis.

### Understanding Individual Neuron Activations

In [47]:
# first, reshape from (batch) to (group.order, group.order)
for key, value in activations.items():
    activations[key] = value.reshape(group.order, group.order, -1)

In [48]:
# elements
for x in range(group.order):
    print(f'{x}: {group.idx_to_perm(x)}, signature: {group.signature(x)}, order: {group.perm_order(x)}')

# conjugacy classes
#for x in group.conjugacy_classes:
    #print(x)

0: (2 4 3), signature: 1, order: 3
1: (4)(0 1 2), signature: 1, order: 3
2: (0 2 1 3 4), signature: 1, order: 5
3: (0 3)(1 4), signature: 1, order: 2
4: (0 4 3 2), signature: -1, order: 4
5: (4)(1 2 3), signature: 1, order: 3
6: (0 1 3 2 4), signature: 1, order: 5
7: (0 2)(1 4), signature: 1, order: 2
8: (0 3 4 2 1), signature: 1, order: 5
9: (0 4)(1 2 3), signature: -1, order: 6
10: (1 3 4), signature: 1, order: 3
11: (0 1 4 2 3), signature: 1, order: 5
12: (0 2 4 3 1), signature: 1, order: 5
13: (4)(0 3 2), signature: 1, order: 3
14: (0 4 1 3), signature: -1, order: 4
15: (1 4 2 3), signature: -1, order: 4
16: (0 1)(2 4 3), signature: -1, order: 6
17: (4)(0 2), signature: -1, order: 2
18: (0 3 4)(1 2), signature: -1, order: 6
19: (0 4 2 3 1), signature: 1, order: 5
20: (4), signature: 1, order: 1
21: (0 1 2 3 4), signature: 1, order: 5
22: (0 2 4 1 3), signature: 1, order: 5
23: (0 3 1 4 2), signature: 1, order: 5
24: (0 4), signature: -1, order: 2
25: (1 2)(3 4), signature: 1, order

In [49]:
hidden = activations['hidden']
#neurons = [3, 7, 8, 13, 15, 16, 19, 24, 33, 34, 36, 39] # checked 0-39 for blocky neurons
neurons = [24]
for neuron in neurons:
    imshow(hidden[:, :, neuron], title=f'hidden activations, neuron {neuron}', input1='position 1', input2='position 2')

In [50]:
logits = activations['logits']
for neuron in range(2):
    imshow(logits[:, :, neuron], title=f'logit {neuron}', input1='position 1', input2='position 2')

## Understanding Logit Computation

We hypothesise the network is using the following formula to compute logits

$$Tr(\rho(x)\rho(y)\rho(z^{-1}))=Tr(\rho(xyz^{-1}))$$

We can verify this by computing this (group.order, group.order, group.order) tensor cube, and comparing the cosine similarity of it to the actual logits produced. The group class has already done this. Note this trace has been centred, as logits don't care about an additive constant and we will compare the two. We also center the logits before comparing their cosine similarity.

In [51]:
centred_trace = group.irreps['standard'].logit_trace_tensor_cube
print(centred_trace.shape)

torch.Size([120, 120, 120])


In [52]:
# correct for logits not caring about an additive constant
logits = activations['logits']
print(logits.shape)
centred_logits = logits - logits.mean(-1, keepdim=True)

torch.Size([120, 120, 120])


### Visual Comparison
Let's first just look and see if they are visiually similar across two different slices

In [53]:
for neuron in range(2):
    imshow(centred_logits[:, :, neuron], title=f'real logit {neuron}', input1='position 1', input2='position 2')
    imshow(centred_trace[:, :, neuron], title=f'hypothesised logit {neuron}', input1='position 1', input2='position 2')

In [54]:
for neuron in range(2):
    imshow(centred_logits[neuron, :, :], title=f'real logits with position 1 fixed as {neuron}', input1='position 2', input2='logit')
    imshow(centred_trace[neuron, :, :], title=f'hypothesised logits with position 1 fixed as {neuron}', input1='position 2', input2='logit')

### Cosine similarity 

over all logits, mean over all inputs.

We find for S5 both the standard and natural representation yield good cosine similarity. This suggests the network really is performing the alleged algorithm. The product of the sign and standard representation is certainly not the algorithm employed

We find for S4 all three representations yield good similarity.

In [55]:
centred_logits = centred_logits.reshape(group.order*group.order,-1)
centred_trace = centred_trace.reshape(group.order*group.order,-1)
sims = F.cosine_similarity(centred_logits, centred_trace, dim=-1)
print(f'cosine similarity pre softmax {sims.mean()}')

centred_logits_softmax = F.softmax(centred_logits, dim=-1)
centred_trace_softmax = F.softmax(centred_trace, dim=-1)

softmax_sims = F.cosine_similarity(centred_logits_softmax, centred_trace_softmax, dim=-1).mean()
print(f'cosine similarity post softmax {softmax_sims.mean()}')

norms = torch.linalg.norm(centred_logits, dim=-1) 
scalar_projects = norms * sims 
unit_trace_logits = centred_trace / torch.linalg.norm(centred_trace, dim=-1, keepdim=True)
projects = unit_trace_logits * scalar_projects.unsqueeze(-1)
frac_explained = projects.pow(2).sum(-1) / centred_logits.pow(2).sum(-1)
print(f'fraction of variance of logit explained by the trace logits {frac_explained.mean()}')

norms = torch.linalg.norm(centred_logits_softmax, dim=-1) 
scalar_projects = norms * softmax_sims 
unit_trace_logits = centred_trace_softmax / torch.linalg.norm(centred_trace_softmax, dim=-1, keepdim=True)
projects = unit_trace_logits * scalar_projects.unsqueeze(-1)
frac_explained = projects.pow(2).sum(-1) / centred_logits_softmax.pow(2).sum(-1)
print(f'fraction of variance of softmax logits explained by the softmax trace logits {frac_explained.mean()}')


cosine similarity pre softmax 0.8569074273109436
cosine similarity post softmax 0.8944902420043945
fraction of variance of logit explained by the trace logits 0.7343247532844543
fraction of variance of softmax logits explained by the softmax trace logits 0.8001127243041992


### How different are the trace logits for different representations?

They are orthogonal.

Hypothesis: the representations learned are those that are maximally different. Each injective representation will give the correct answer. Having very different logits will produce outputs that destructively intefere, minimising loss.


In [56]:
centred_traces = {'sign': group.irreps['sign'].logit_trace_tensor_cube, 
                    'standard': group.irreps['standard'].logit_trace_tensor_cube,
                    'standard_sign': group.irreps['standard_sign'].logit_trace_tensor_cube}

if group.index == 4:
    centred_traces['s4_2d'] =  group.irreps['s4_2d'].logit_trace_tensor_cube

for key, value in centred_traces.items():
    centred_traces[key] = value.reshape(group.order*group.order,-1)

for key1, value1 in centred_traces.items():
    for key2, value2 in centred_traces.items():
        print(f'{key1} vs {key2}: {F.cosine_similarity(value1, value2, dim=-1).mean()}')

sign vs sign: 0.9999999403953552
sign vs standard: 0.0
sign vs standard_sign: 0.0
standard vs sign: 0.0
standard vs standard: 1.0
standard vs standard_sign: -5.83628834149863e-09
standard_sign vs sign: 0.0
standard_sign vs standard: -5.83628834149863e-09
standard_sign vs standard_sign: 1.0


### Theoretical optimal representations to use

Interesting. Now what loss would I get if I actually used each of these representations? Let's find out.

In [57]:
for key, value in centred_traces.items():
    value = value.reshape(group.order * group.order, -1)
    accuracy = (value.argmax(-1) == all_labels).float().mean()
    loss = loss_fn(value, all_labels)
    print(f'{key}: {loss}, {accuracy*100}%')

sign: 4.221286773681641, 1.6666667461395264%
standard: 1.4973335266113281, 100.0%
standard_sign: 1.3868404626846313, 100.0%


In [58]:
def linspace(start, stop, step=1.):
  """
    Like np.linspace but uses step instead of num
    This is inclusive to stop, so if start=1, stop=3, step=0.5
    Output is: array([1., 1.5, 2., 2.5, 3.])
  """
  return np.linspace(start, stop, int((stop - start) / step + 1))

granularity = 0.1

if group.index ==5:
  best_weights = [0, 0, 0, 0]
  best_loss = 100
  for i in linspace(0, 1, granularity):
      for j in linspace(0, 1-i, granularity):
              k = 1 - i - j 
              hyp_logits = i*centred_traces['standard'] + j*centred_traces['sign'] + k*centred_traces['standard_sign']
              loss = loss_fn(hyp_logits, all_labels)
              if loss < best_loss:
                  best_loss = loss
                  best_weights = [i, j, k]
                  print(f'new best loss {best_loss} with weights {best_weights}')
  print(f'best combination: {best_loss}, {accuracy*100}%')
  print(f'best weights: {best_weights}')

if group.index==4:
  best_weights = [0.58, 0.27, 0, 0.19, 0.08]
  best_loss = 100
  for i in linspace(0, 1, granularity):
      for j in linspace(0, 1-i, granularity):
          for k in linspace (0, 1-i-j, granularity):
              l = 1 - i - j - k
              hyp_logits = i*centred_traces['standard'] + j*centred_traces['sign'] + k*centred_traces['standard_sign']+ l*centred_traces['s4_2d']
              loss = loss_fn(hyp_logits, all_labels)
              if loss < best_loss:
                  best_loss = loss
                  best_weights = [i, j, k, l]
                  print(f'new best loss {best_loss} with weights {best_weights}')
  print(f'best combination: {best_loss}')
  print(f'best weights: {best_weights}')

new best loss 1.3868404626846313 with weights [0.0, 0.0, 1.0]
new best loss 1.3438732624053955 with weights [0.1, 0.0, 0.9]
new best loss 1.309489369392395 with weights [0.2, 0.0, 0.8]
new best loss 1.283819556236267 with weights [0.30000000000000004, 0.0, 0.7]
new best loss 1.2674612998962402 with weights [0.4, 0.0, 0.6]
new best loss 1.261618733406067 with weights [0.5, 0.0, 0.5]
best combination: 1.261618733406067, 100.0%
best weights: [0.5, 0.0, 0.5]


Is this algorithm better on the alternating set? no


In [59]:
alternating_indices = [i for i in range(group.order) if group.signature(i) == 1]
alternating_data = group.get_subset_of_data(alternating_indices).cuda()
alternating_data, alternating_labels = alternating_data[:, :2], alternating_data[:, 2]

print(alternating_indices)
indices = []
for i in alternating_indices:
    for j in alternating_indices:
        indices.append(i*group.order+j)

[0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12, 13, 19, 20, 21, 22, 23, 25, 26, 27, 28, 30, 31, 32, 33, 39, 44, 49, 54, 55, 56, 57, 58, 64, 69, 74, 75, 76, 77, 78, 84, 89, 94, 95, 96, 97, 98, 100, 101, 102, 103, 105, 106, 107, 108, 110, 111, 112, 113, 119]


In [60]:
print('on alternating subgroup')
for key, value in centred_traces.items():
    value = value.reshape(group.order * group.order, -1)
    value = value[indices]
    accuracy = (value.argmax(-1) == alternating_labels).float().mean()
    loss = loss_fn(value, alternating_labels)
    print(f'{key}: {loss}, {accuracy*100}%')

print('on whole group')
for key, value in centred_traces.items():
    value = value.reshape(group.order * group.order, -1)
    accuracy = (value.argmax(-1) == all_labels).float().mean()
    loss = loss_fn(value, all_labels)
    print(f'{key}: {loss}, {accuracy*100}%')

on alternating subgroup
sign: 4.221277713775635, 1.6666667461395264%
standard: 1.4973357915878296, 100.0%
standard_sign: 1.3868407011032104, 100.0%
on whole group
sign: 4.221286773681641, 1.6666667461395264%
standard: 1.4973335266113281, 100.0%
standard_sign: 1.3868404626846313, 100.0%


## Understanding the embeddings

We consider the total embedding by multiplying out the linear layer W_x/y. This total embedding x_embed is a (group.order, hidden) tensor. Let's first look at the singular values of this matrix. We find there are 9 large ones = (4-1) x (4-1)

In [61]:
embed_dim = layers['embed_dim']
x_embed = model.x_embed
y_embed = model.y_embed

In [62]:
u_a,s_a,v_a = torch.linalg.svd(model.W_x) 

print('W_x')
print(s_a[:20])

u_x, s_x, v_x = torch.linalg.svd(x_embed)
print('x_embed')
print(s_x[:20])

u_b,s_b,v_b = torch.linalg.svd(model.W_y)
print('W_y')
print(s_b[:20])

u_y, s_y, v_y = torch.linalg.svd(y_embed)
print('y_embed')
print(s_y[:20])

if model.__class__.__name__ == "OneLayerMLP":
    u,s,v = torch.linalg.svd(model.W)
    print('W')
    print(s[:30])

u,s,v = torch.linalg.svd(model.W_U.T)
print('W_U')
print(s[:50])

W_x
tensor([21.3694, 20.3155, 19.8813, 19.5922, 19.4425, 19.3569, 19.2449, 19.1849,
        19.1296, 18.9362, 18.8428, 18.7884, 18.6878, 18.4996, 18.2576, 18.2140,
        15.8545,  2.4459,  1.0499,  0.8320], device='cuda:0',
       grad_fn=<SliceBackward0>)
x_embed
tensor([21.3694, 20.3155, 19.8813, 19.5922, 19.4425, 19.3569, 19.2449, 19.1849,
        19.1296, 18.9362, 18.8428, 18.7884, 18.6878, 18.4996, 18.2576, 18.2140,
        15.8545,  2.4459,  1.0499,  0.8320], device='cuda:0',
       grad_fn=<SliceBackward0>)
W_y
tensor([20.3510, 19.1228, 18.8870, 18.8539, 18.7505, 18.6588, 18.5699, 18.4841,
        18.4149, 18.3033, 18.2390, 17.9665, 17.9191, 17.8298, 17.6939, 17.6171,
        15.8500,  1.8426,  0.9977,  0.8369], device='cuda:0',
       grad_fn=<SliceBackward0>)
y_embed
tensor([20.3510, 19.1228, 18.8870, 18.8539, 18.7505, 18.6588, 18.5699, 18.4841,
        18.4149, 18.3033, 18.2390, 17.9665, 17.9191, 17.8298, 17.6939, 17.6171,
        15.8500,  1.8426,  0.9977,  0.8369], device

### Low rank approximations

What happens if we take the low rank approximation gained from the first 9 singular values?


In [63]:
k=9
#new_model = copy.deepcopy(model)
#new_model.W_x = torch.nn.Parameter(u_a[:, :k] @ torch.diag(s_a[:k]) @ v_a[:k, :])
#new_model.W_y = torch.nn.Parameter(u_b[:, :k] @ torch.diag(s_b[:k]) @ v_b[:k, :])
logits = model(all_data)
#new_logits = new_model(all_data)
#loss_new = loss_fn (new_logits, all_labels)
loss = loss_fn (logits, all_labels)

print(f'baselines loss {loss.item()}')
#print(f'loss with only the largest 9 singular values of embeddings {loss_new.item()}')

baselines loss 6.166075763758272e-06


The natural representation has a 1-dimensional invariant subspace spanned by the sum of all basis vectors. Note too in the standard representation, the trace of any element is identical to the trace in the natural representation, up to some additive constant, which we discount from logits anyway. 

We can also see this by looking at its singular values of the standard representations, of which there are $group.index^2$, but only $(group.index-1)^2+1$ non zero ones. I hypothesise this is the 9 directions corresponding to the natural representation, then one for the "distance from the plane" formed when we collapse the standard rep to the natural rep. I hypothesise this 1 additional direction is not used by the network.

In [64]:
k = (group.index-1)**2 + 1
natural_reps = group.other_reps['natural'].rep.reshape(group.order, group.index*group.index)
u, s, v = torch.linalg.svd(natural_reps)
print(s)

# remove junk singular values
natural_reps_mod = u[:, :k] @ torch.diag(s[:k]) # this is going to get plugged into a proj operator, so we can drop the v
natural_reps_mod = torch.linalg.qr(natural_reps_mod)[0]

tensor([1.0954e+01, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00,
        5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00,
        5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 5.4772e+00, 4.1506e-07,
        3.4037e-07, 3.2403e-07, 2.7186e-07, 1.7150e-07, 1.6669e-07, 1.1426e-07,
        4.7155e-08], device='cuda:0')


In the standard representation, all $(group.index-1)^2$ singular values are non zero.

In [65]:
standard_reps = group.irreps['standard'].rep.reshape(group.order, (group.index-1)*(group.index-1))
u,s,v = torch.linalg.svd(standard_reps)
print(s)

# orthogonalise these
##print(group.[].shape)

tensor([16.8572, 14.3396, 10.4183,  8.8623,  7.5388,  6.4389,  5.4772,  5.4772,
         5.4772,  5.4772,  4.6592,  3.9794,  3.3851,  2.8795,  2.0921,  1.7797],
       device='cuda:0')


same in the standard sign

In [66]:
standard_sign_reps = group.irreps['standard_sign'].rep.reshape(group.order, (group.index-1)*(group.index-1))
u,s,v = torch.linalg.svd(standard_sign_reps)
print(s)

#print(group.standard_sign_reps_orth.shape)


tensor([16.8572, 14.3396, 10.4183,  8.8623,  7.5388,  6.4389,  5.4772,  5.4772,
         5.4772,  5.4772,  4.6592,  3.9794,  3.3851,  2.8795,  2.0921,  1.7797],
       device='cuda:0')


How different are our representations? Are different representations orthogonal? Are individual ones orthonormal? 

Different reps are orthogonal. Individual ones are not orthonormal, and can not be made to be so easily. We use QR.

In [67]:
reps = {}
for key, value in group.irreps.items():
    reps[key] = value.rep

for key, value in reps.items():
    reps[key] = value.reshape(group.order, -1)

# if group.index == 4:
#     reps['s4_2d'] = group.s4_2d_reps

# loop over all pairs of reps, and compute their orthogonality
for key1, value1 in reps.items():
    for key2, value2 in reps.items():
        if key1 == key2:
            continue
        matrix = value1.T @ value2
        orthogonality = matrix.pow(2).mean()
        print(f'{key1} and {key2} have orthogonality {orthogonality}')

# orthonormality:
for key, value in reps.items():
    if value.shape[1] == 1:
        continue
    matrix = value.T@value
    print(matrix.shape)
    imshow(matrix, title = f'{key} orthonormality')

trivial and sign have orthogonality 0.0
trivial and standard have orthogonality 0.0
trivial and standard_sign have orthogonality 0.0
trivial and s5_5d_a have orthogonality 0.0
trivial and s5_5d_b have orthogonality 0.0
trivial and s5_6d have orthogonality 0.0
sign and trivial have orthogonality 0.0
sign and standard have orthogonality 0.0
sign and standard_sign have orthogonality 0.0
sign and s5_5d_a have orthogonality 0.0
sign and s5_5d_b have orthogonality 0.0
sign and s5_6d have orthogonality 0.0
standard and trivial have orthogonality 0.0
standard and sign have orthogonality 0.0
standard and standard_sign have orthogonality 0.0
standard and s5_5d_a have orthogonality 0.0
standard and s5_5d_b have orthogonality 0.0
standard and s5_6d have orthogonality 0.0
standard_sign and trivial have orthogonality 0.0
standard_sign and sign have orthogonality 0.0
standard_sign and standard have orthogonality 0.0
standard_sign and s5_5d_a have orthogonality 0.0
standard_sign and s5_5d_b have ortho

torch.Size([16, 16])


torch.Size([25, 25])


torch.Size([25, 25])


torch.Size([36, 36])


In [68]:
orth_reps={}
for key, value in group.irreps.items():
    orth_reps[key] = value.orth_rep

for key, value in orth_reps.items():
    if value.shape[1] == 1:
        continue
    matrix = value.T @ value
    print(matrix.shape)
    imshow(matrix, title=f'{key}_orth orthonormality')


torch.Size([16, 16])


torch.Size([16, 16])


torch.Size([25, 25])


torch.Size([25, 25])


torch.Size([36, 36])


Let's see if the embeddings are actually learning representations. We can generate a $(group.order, group.index^2)$ tensor encoding the natural representation, or a $(group.order, (group.index-1)^2)$ encoding the standard representation. The embeddings are $(group.order, embed\_dim)$ tensors, which are higher dimensional than the representations. We've already shown the embeddings are sparse in the singular basis, but let's now show the dimensions actually line up.

To do so, we can just take dot products of the embeddings with the orthonormal representations.

In [69]:
x_norm = x_embed.pow(2).sum()
y_norm = y_embed.pow(2).sum()
U_norm = model.W_U.pow(2).sum()

total_x = 0
total_y = 0
total_U = 0

for key, value in orth_reps.items():
    coefs_x = value.T @ x_embed
    coefs_y = value.T @ y_embed
    coefs_U = value.T @ model.W_U.T
    x_prop = coefs_x.pow(2).sum() / x_norm
    y_prop = coefs_y.pow(2).sum() / y_norm
    U_prop = coefs_U.pow(2).sum() / U_norm
    total_x += x_prop
    total_y += y_prop
    total_U += U_prop

    print(f'x embedding proportion in {key} directions: {x_prop}')
    print(f'y embedding proportion in {key} rep directions: {y_prop}')
    print(f'unembedding proportion in {key} rep directions: {U_prop}')

print(f'total x embedding proportion in all rep directions: {total_x}')
print(f'total y embedding proportion in all rep directions: {total_y}')
print(f'total unembedding proportion in all rep directions: {total_U}')


x embedding proportion in trivial directions: 0.001817768206819892
y embedding proportion in trivial rep directions: 0.0006342572742141783
unembedding proportion in trivial rep directions: 0.002767738653346896
x embedding proportion in sign directions: 0.04059810936450958
y embedding proportion in sign rep directions: 0.04383799433708191
unembedding proportion in sign rep directions: 0.039522770792245865
x embedding proportion in standard directions: 0.955693781375885
y embedding proportion in standard rep directions: 0.953105628490448
unembedding proportion in standard rep directions: 0.9532396793365479
x embedding proportion in standard_sign directions: 1.5877834825398017e-15
y embedding proportion in standard_sign rep directions: 1.5471183836541884e-15
unembedding proportion in standard_sign rep directions: 0.00029025148251093924
x embedding proportion in s5_5d_a directions: 2.2802457282935194e-15
y embedding proportion in s5_5d_a rep directions: 2.466628763831487e-15
unembedding pr

Is it using a combination of the natural and standard representations? Is it using some combination of the natural, standard and sign reprsentatations? 

What even is the combined rank of all the representations? How orthogonal are they? They are all precisely orthogonal, which is good.

S5 doesnt seem to use the product standard sign rep at all?

 ### Variance
 
Make sure that each of the representations actually uses all the information in equal parts from each of the dim^2 representation matrix elements. Do this by checking that the standard deviation of percentage contributions of each column is small.

We see that the orth x embedding is constant on columns and y embedding is constant on rows. This highly suggests its doing a matrix multiplication of y times x.

In [70]:
# method 2
rep = group.irreps['standard'].orth_rep.reshape(group.order, -1)
dims = rep.shape[1]

x_norm = x_embed.pow(2).sum()
y_norm = y_embed.pow(2).sum()
U_norm = model.W_U.pow(2).sum()

coefs_x = rep.T @ x_embed
coefs_y = rep.T @ y_embed
coefs_U = rep.T @ model.W_U.T

conts_x = coefs_x.pow(2).sum(-1) / x_norm
conts_y = coefs_y.pow(2).sum(-1) / y_norm
conts_U = coefs_U.pow(2).sum(-1) / U_norm

line(conts_x, xaxis = 'flattened rep direction', yaxis='proportion of x embedding')
line(conts_y, xaxis = 'flattened rep direction', yaxis='proportion of y embedding')
line(conts_U, xaxis = 'flattened rep direction', yaxis='proportion of unembedding')


## Understanding the hidden layer

Under the hypothesis that the network is embedding as representations, We should expect the hidden layer to contain terms like "rho(x)rho(y)".

Consider a model run on all the (group.order * group.order) inputs. The hidden layer is then a (group.order * group.order, hidden) tensor. 

Assuming some representation rho, we should expect to find a (group.order * group.order, rep_dim, rep_dim) = (group.order * group.order, rep_dim * rep_dim) tensor, with rep_dim << hidden. 

We can perform a similar projection trick here as we did on the embeddings.

We might want to center the hidden neurons

In [71]:
print(group.__class__.__name__)

SymmetricGroup


In [72]:
hidden = activations['hidden']
hidden = hidden.reshape(group.order*group.order, -1)
print(hidden.shape)
mean = hidden.mean(dim=-1, keepdim=True)
print(mean.shape)
hidden_centred = hidden - mean
u, s, v = torch.linalg.svd(hidden_centred)
#line(s)

logits = hidden @ model.W_U

loss = loss_fn(logits, all_labels)

centred_logits = hidden_centred @ model.W_U
centred_loss = loss_fn(centred_logits, all_labels)

logit_diff = logits - centred_logits

imshow(logit_diff[:20], title='logit difference')

print(f'Loss: {loss}')
print(f'Centred Hidden loss: {centred_loss}')

# i think this is close enough?

torch.Size([14400, 256])
torch.Size([14400, 1])


Loss: 6.166075763758272e-06
Centred Hidden loss: 6.755404683644883e-06


In [73]:
hidden_reps_xy = {}
hidden_reps_x = {}
hidden_reps_y = {}
for name, rep in reps.items():
    hidden_reps_xy[name] = rep[all_labels].reshape(group.order*group.order, -1)
    hidden_reps_x[name] = rep[all_data[:, 0]].reshape(group.order*group.order, -1)
    hidden_reps_y[name] = rep[all_data[:, 1]].reshape(group.order*group.order, -1)

# loop over all pairs of reps, and compute their orthogonality
for key1, value1 in hidden_reps_xy.items():
    for key2, value2 in hidden_reps_xy.items():
        if key1 == key2:
            continue
        matrix = value1.T @ value2
        orthogonality = matrix.pow(2).sum()
        print(f'{key1} and {key2} have orthogonality {orthogonality}')

#make orthonormal

hidden_reps_xy_orth = {}
hidden_reps_x_orth = {}
hidden_reps_y_orth = {}


for rep in hidden_reps_xy.keys():
    hidden_reps_xy_orth[rep] = torch.qr(hidden_reps_xy[rep])[0]
    hidden_reps_x_orth[rep] = torch.qr(hidden_reps_x[rep])[0]
    hidden_reps_y_orth[rep] = torch.qr(hidden_reps_y[rep])[0]


trivial and sign have orthogonality 0.0
trivial and standard have orthogonality 0.0
trivial and standard_sign have orthogonality 0.0
trivial and s5_5d_a have orthogonality 0.0
trivial and s5_5d_b have orthogonality 0.0
trivial and s5_6d have orthogonality 0.0
sign and trivial have orthogonality 0.0
sign and standard have orthogonality 0.0
sign and standard_sign have orthogonality 0.0
sign and s5_5d_a have orthogonality 0.0
sign and s5_5d_b have orthogonality 0.0
sign and s5_6d have orthogonality 0.0
standard and trivial have orthogonality 0.0
standard and sign have orthogonality 0.0
standard and standard_sign have orthogonality 0.0
standard and s5_5d_a have orthogonality 0.0
standard and s5_5d_b have orthogonality 0.0
standard and s5_6d have orthogonality 0.0
standard_sign and trivial have orthogonality 0.0
standard_sign and sign have orthogonality 0.0
standard_sign and standard have orthogonality 0.0
standard_sign and s5_5d_a have orthogonality 0.0
standard_sign and s5_5d_b have ortho

In [74]:
hidden_norm = hidden.pow(2).sum()

total_x = 0
total_y = 0 
total_xy = 0

coefs_xs = {}
coefs_ys = {}
coefs_xys = {}

for rep in hidden_reps_xy.keys():

    coefs_x = hidden_reps_x_orth[rep].T @ hidden
    coefs_y = hidden_reps_y_orth[rep].T @ hidden
    coefs_xy = hidden_reps_xy_orth[rep].T @ hidden

    coefs_xs[rep] = coefs_x
    coefs_ys[rep] = coefs_y
    coefs_xys[rep] = coefs_xy
    
    x_prop = coefs_x.pow(2).sum() / hidden_norm
    y_prop = coefs_y.pow(2).sum() / hidden_norm
    xy_prop = coefs_xy.pow(2).sum() / hidden_norm

    print(f'{rep}, x: {x_prop}, y: {y_prop}, xy: {xy_prop}')
print(f'Total percent explained by reps: {total_x + total_y + total_xy}')

trivial, x: 6.792666340516007e-07, y: 6.792666340516007e-07, xy: 6.792666340516007e-07
sign, x: 1.297349143669635e-07, y: 1.7647413130816858e-07, xy: 0.11359807848930359
standard, x: 0.0005533711519092321, y: 0.0015573815908282995, xy: 0.21904142200946808
standard_sign, x: 3.697503690859138e-16, y: 1.492174744497591e-15, xy: 7.924138805351789e-16
s5_5d_a, x: 2.993583312790556e-16, y: 2.4466430214410218e-15, xy: 1.8767829714508516e-15
s5_5d_b, x: 3.545689446327158e-16, y: 1.927209172141505e-15, xy: 2.3628113236183164e-15
s5_6d, x: 1.0637469358698581e-06, y: 2.459198185533751e-06, xy: 2.2280653411144158e-07
Total percent explained by reps: 0


Claim: The output logits should not depend on the directions corresponding to x or y, but only those corresponding to xy. Let's test this. To do so, we first just ablate those directions and see what happens to loss. If we ablate the xy direction we should see loss tank.


In [75]:
logits = hidden @ model.W_U
loss = loss_fn(logits, all_labels)
print(f'baseline loss: {loss}')

for rep_name in hidden_reps_xy.keys():

    hidden_rep_x = hidden_reps_x_orth[rep_name]
    hidden_rep_y = hidden_reps_y_orth[rep_name]
    hidden_rep_xy = hidden_reps_xy_orth[rep_name]

    coefs_x = coefs_xs[rep_name]
    coefs_y = coefs_ys[rep_name]
    coefs_xy = coefs_xys[rep_name]

    hidden_x = hidden_rep_x @ coefs_x
    hidden_y = hidden_rep_y @ coefs_y
    hidden_xy = hidden_rep_xy @ coefs_xy

    hidden_ablated_x = hidden - hidden_x
    hidden_ablated_y = hidden - hidden_y
    hidden_ablated_xy = hidden - hidden_xy

    logits_x = hidden_ablated_x @ model.W_U
    logits_y = hidden_ablated_y @ model.W_U
    logits_xy = hidden_ablated_xy @ model.W_U

    loss_x = loss_fn(logits_x, all_labels)
    loss_y = loss_fn(logits_y, all_labels)
    loss_xy = loss_fn(logits_xy, all_labels)

    print(f'Ablating directions corresponding to {rep_name} rep loss, xy: {loss_xy}, x: {loss_x}, y: {loss_y}')
    

baseline loss: 6.166075763758272e-06
Ablating directions corresponding to trivial rep loss, xy: 6.167558240122162e-06, x: 6.167558240122162e-06, y: 6.167558240122162e-06
Ablating directions corresponding to sign rep loss, xy: 0.00044248905032873154, x: 6.166837465571007e-06, y: 6.166812909214059e-06
Ablating directions corresponding to standard rep loss, xy: 4.262816429138184, x: 6.1279788496904075e-06, y: 6.164229944261024e-06
Ablating directions corresponding to standard_sign rep loss, xy: 6.166083949210588e-06, x: 6.166108960314887e-06, y: 6.166042567201657e-06
Ablating directions corresponding to s5_5d_a rep loss, xy: 6.166216280689696e-06, x: 6.1661253312195186e-06, y: 6.166058938106289e-06
Ablating directions corresponding to s5_5d_b rep loss, xy: 6.166158527776133e-06, x: 6.166034381749341e-06, y: 6.166224466142012e-06
Ablating directions corresponding to s5_6d rep loss, xy: 6.163674697745591e-06, x: 6.166820639919024e-06, y: 6.166076218505623e-06


We show that the the spaces corresponding to x and y representations in the hidden layer are in the kernel of W_U.

In [76]:
# this shows the same as the above I think - by linearity
# print('Mean norms of unembedded representation vectors')
# for rep_name in hidden_reps_xy.keys():
#     hidden_rep_xy = hidden_reps_xy[rep_name]
#     hidden_rep_x = hidden_reps_xy[rep_name]
#     hidden_rep_y = hidden_reps_xy[rep_name]
    
#     print(model.W_U.shape)
#     print(hidden_rep_xy.shape)
#     P_xy = projection_matrix_general(hidden_reps_xy)
#     P_x = projection_matrix_general(hidden_reps_x) 
#     P_y = projection_matrix_general(hidden_reps_y)
#     proj_xy = P_xy @ hidden_centred
#     proj_x = P_x @ hidden_centred
#     proj_y = P_y @ hidden_centred

#     unembed_xy = proj_xy @ model.W_U
#     unembed_x = proj_x @ model.W_U
#     unembed_y = proj_y @ model.W_U

#     print(f'{name} rep, xy: {unembed_xy.pow(2).mean()}, x: {unembed_x.pow(2).mean()}, y: {unembed_y.pow(2).mean()}')


### Variance

Are some matrix elements of xy in the standard representation preferred over others in the hidden layer? Yes. I claim this is actually pretty deep.

Let $e_i$ be the standard basis matrices. Let $\mathbb{x} := \lambda_i x_i \mathbb{e}_i$. Let $\mathbb{y} := \delta_i y_i \mathbb{e}_i$ be such that the mean of $x_i$ and mean of $y_i$ is one over all inputs for fixed i, i.e. $\lambda_i,\delta_i$ encode an i-dependent scaling factor that we calcualate below. Then consider the matrix product $\mathbb{x}\mathbb{y} = \mathbb{z}$ = $z_i \mathbb{e}_i$. 

Then assuming 5d matrices, eg $z_0 = \lambda_0 \delta_0 x_0 y_0 + \lambda_1 \delta_5 x_1 y_5 + \lambda_2 \delta_{10} x_2 y_{10} + \lambda_3 \delta_15 x_3 y_{15} + \lambda_4 \delta_{20} x_4 y_{20}$.

This only makes sense dimensionally if $\delta_0 = \delta_5 = \delta_{10} = \delta_{15} = \delta_{20}$ and $\lambda_0 = \lambda_1 = \lambda_2 = \lambda_3 = \lambda_4$.



In [77]:
hidden_norm = hidden.pow(2).sum()

total_x = 0
total_y = 0 
total_xy = 0

coefs_xs = {}
coefs_ys = {}
coefs_xys = {}

for rep_name in hidden_reps_xy.keys():

    coefs_x = hidden_reps_x[rep_name].T @ hidden
    coefs_y = hidden_reps_y[rep_name].T @ hidden
    coefs_xy = hidden_reps_xy[rep_name].T @ hidden

    coefs_xs[rep_name] = coefs_x
    coefs_ys[rep_name] = coefs_y
    coefs_xys[rep_name] = coefs_xy
    
    x_prop = coefs_x.pow(2).sum() / hidden_norm
    y_prop = coefs_y.pow(2).sum() / hidden_norm
    xy_prop = coefs_xy.pow(2).sum() / hidden_norm

    x_conts = coefs_x.pow(2).sum(-1) / hidden_norm
    y_conts = coefs_y.pow(2).sum(-1) / hidden_norm
    xy_conts = coefs_xy.pow(2).sum(-1) / hidden_norm

    lines([x_conts, y_conts, xy_conts], title=f'Variance across {rep_name} hidden layer representations', labels=['x', 'y', 'xy'])

## Neuron Clustering

Before the ReLU, do the neurons cluster neatly into representations?

In [78]:
threshold = 2

x_embed = model.x_embed
y_embed = model.y_embed

x_embed_summed = x_embed.pow(2).sum(dim=0)
off_neurons_x = (x_embed_summed < threshold).nonzero().squeeze()

y_embed_summed = y_embed.pow(2).sum(dim=0)
off_neurons_y = (y_embed_summed < threshold).nonzero().squeeze()

assert (off_neurons_x == off_neurons_y).all()

off_neurons = off_neurons_x

print(f'Off neurons: {off_neurons}')

rep_neurons = {}
orth_reps_no_trivial = orth_reps.copy()
del orth_reps_no_trivial['trivial']

for rep_name, rep in orth_reps_no_trivial.items():
    coefs_x = rep.T @ x_embed
    coefs_y = rep.T @ y_embed
    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)

    x_neurons = (coefs_x_summed > threshold).nonzero().squeeze()
    y_neurons = (coefs_y_summed > threshold).nonzero().squeeze()
    assert (x_neurons == y_neurons).all()
    x_neurons = torch.tensor(x_neurons)
    if x_neurons.dim() == 0:
        x_neurons = x_neurons.unsqueeze(0)
    rep_neurons[rep_name] = x_neurons

print('Neurons corresponding to each representation')
print(rep_neurons)

all_neurons = torch.arange(model.W_U.shape[0])
unaccounted_neurons = set(all_neurons.tolist())
unaccounted_neurons -= set(off_neurons.tolist())
for rep_name, neurons in rep_neurons.items():
    unaccounted_neurons -= set(neurons.tolist())

print('Unaccounted neurons')
print(unaccounted_neurons)



Off neurons: tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  16,  18,  19,  20,  22,  24,  26,  27,  29,  31,  32,  33,  34,
         35,  37,  38,  40,  41,  42,  44,  45,  47,  48,  49,  50,  51,  52,
         57,  60,  61,  62,  63,  65,  68,  69,  70,  71,  74,  75,  76,  77,
         78,  80,  81,  82,  84,  88,  89,  90,  92,  93,  94,  95,  96,  97,
         99, 100, 101, 102, 103, 104, 105, 106, 107, 109, 111, 112, 113, 114,
        115, 120, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 133, 134,
        135, 139, 140, 141, 142, 144, 147, 148, 150, 151, 152, 153, 154, 155,
        156, 157, 160, 161, 162, 163, 164, 165, 166, 167, 168, 170, 171, 173,
        174, 176, 177, 180, 182, 183, 184, 185, 186, 188, 189, 191, 192, 194,
        196, 197, 198, 199, 200, 201, 202, 203, 204, 206, 207, 208, 211, 212,
        213, 214, 216, 217, 218, 219, 220, 221, 222, 224, 225, 226, 227, 228,
        231, 232, 233, 234, 237, 238, 240, 241, 242


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Do the hidden layer neurons cluster neatly into representations? TODO

## Signature Representation Circuit Analysis

TLDR: the network memorises input signatures, mangles them together using the relu, and uses this to divide the output search space in two.


We know the S5 network uses the sign representation to some extent. Let's find this in the x and y embedding in the first instance. We'll later want to find this in the hidden layer too, and in the logits. 

We find in our run, (1L_MLP_sym_S5_cached_3), neurons 24, 66, 73, and 81 have high sign representation projection. Let's now print inspect these neurons. It's curious there are 4 of these.


In [79]:
signature_neurons = rep_neurons['sign']
print(signature_neurons)
sig_labels = [str(x) for x in signature_neurons.tolist()]

tensor([17], device='cuda:0')


The first row is the sign represntation. The next four rows are the x embeddings. The next image is the same but for y embeddings. Let a = 3.8. We learn that (for the same cached S5 model)

neuron 24 = $a ReLU (-sign(x) + sign(y)) > 0$ iff x odd and y even ie z odd

neuron 66 = $a ReLU (-sign(x) - sign(y)) > 0$ iff x odd and y odd ie z even

neuron 73 = $a ReLU (sign(x) - sign(y)) > 0$ iff x even and y odd ie z odd

neuron 81 = $a ReLU (sign(x) + sign(y)) > 0$ iff x even and y even ie z even



In 1L_MLP_cached_S4 we find

neuron 3 = $a ReLU (-sign(x) - sign(y)) > 0$ iff x odd and y odd ie z even

neuron 38 = $a ReLU (-sign(x) - sign(y)) > 0$ iff x odd and y odd ie z even

neuron 6 = $a ReLU (sign(x) + sign(y)) > 0$ iff x even and y even ie z even

neuron 29 = $a ReLU (sign(x) + sign(y)) > 0$ iff x even and y even ie z even


In [80]:
sigs = group.signatures.unsqueeze(-1)
xs = x_embed[:, signature_neurons]
stack = torch.hstack([sigs, xs]).T
imshow(stack, y=['sig'] + sig_labels, input2='input group element', title='Total x embeddings on select neurons')
ys = y_embed[:, signature_neurons]
stack = torch.hstack([sigs, ys]).T
imshow(stack, y=['sig'] + sig_labels, input2='input group element', title='Total y embeddings on select neurons')


How does this effect logits? We simply look at the relevant unembed rows. We expect rows 66 and 81 to correlate and 24 and 73 to anticorrelate with the (now output) signature. Thus the logits align with the trace formula.

In [81]:
sigs = group.signatures.unsqueeze(-1)
W_U_signatures = model.W_U[signature_neurons, :].T
stack = torch.hstack([sigs, W_U_signatures]).T

imshow(stack, y=['sig'] + sig_labels, input2='output group element', title='W_U on select neurons')

## Standard Representation Circuit Analysis

We know the S5 network uses the standard representation too. Let's find this in the x and y embedding, and in the hidden layer

### Extract representations from embeddings 

TLDR: this is pretty fucked due to the network doing a matrix multiply in an addition and relu

Our unfactored embedding matrix has size (group.order, hidden). The representations can be encoded in a (group.order, dim^2) tensor. I want the coefficients on this dim^2 subspace for some given input group element. This requires identifying the linear map hidden -> dim^2. Can I just take embedding.T @ rep.

Another approach is to use the singular value of the embedding as the coeffiecent, and the left singular vector as the indicator for whether that coefficient is used.

I'm likely overcomplicating this. I suspect all i need is a matrix multiply from an orthogonal set of vectors

In [82]:
def projection_matrix_general(B):
    """Compute the projection matrix onto the space spanned by the columns of `B`
    Args:
        B: ndarray of dimension (D, M), the basis for the subspace
    
    Returns:
        P: the projection matrix
    """
    P = B @ (B.T @ B).inverse() @ B.T
    return P


#### Less janky method 4

Copying what I did below for the hidden layer. This doesn't really work. I think I should really be looking at the embedding only!

In [83]:
# rep = group.irreps['standard'].inverse_rep.reshape(group.order, -1)
# print(rep.shape)
# dims = rep.shape[1]
# dim = int(math.sqrt(dims))
# x_embed = model.W_y


# x_embed_to_reps = rep.T @ x_embed


# print(x_embed_to_reps.shape)
# imshow(x_embed_to_reps, title='Change of basis from total embedding to representation basis')

# x_embed_in_reps = x_embed @ x_embed_to_reps.T

# imshow(x_embed_in_reps[:20], title='x embedding in the representation basis')
# imshow(rep[:20], title='Theoretical learned representations')

Really we should project the unfactored x matrix onto the correct subspace first.


In [84]:
rep = group.irreps['standard'].rep.reshape(group.order, -1)
P = projection_matrix_general(rep)
proj_x = P @ x_embed

x_embed_to_reps = rep.T @ proj_x

imshow(x_embed_to_reps)

x_embed_in_reps = proj_x @ x_embed_to_reps.T
imshow(x_embed_in_reps[:20], title='Total projected x in the representation basis')
imshow(rep[:20], title='Theoretical learned representations')

In [87]:
rep = group.irreps['standard'].rep.reshape(group.order, -1)
P = projection_matrix_general(rep)
proj_y = P @ y_embed

y_embed_to_reps = rep.T @ proj_y

imshow(y_embed_to_reps)

y_embed_in_reps = proj_y @ y_embed_to_reps.T
imshow(y_embed_in_reps[:20], title='Total projected x in the representation basis')
imshow(rep[:20], title='Theoretical learned representations')

In [92]:
imshow(torch.vstack([x_embed_to_reps[0, :], y_embed_to_reps[0, :]]))


What if i do this one representation direction at a time? This is defo cheating


In [None]:
# rep = group.standard_reps.reshape(group.order, -1)
# dims = rep.shape[1]
# dim = int(math.sqrt(dims))
# x_embed = model.W_x 

# x_embed_in_reps = torch.zeros((group.order, dims)).cuda()
# for i in range(dims):
#     x_embed_to_reps = rep[:, i].unsqueeze(-1).T @ x_embed
#     #unfactored_x_in_reps = unfactored_x @ unfactored_x_to_reps.T
#     P = projection_matrix_general(rep[:, i].unsqueeze(-1))
#     proj_x = P @ x_embed  
#     x_embed_in_reps[:, i] = (proj_x @ x_embed_to_reps.T).squeeze()
# imshow(x_embed_in_reps[:10], title='Total projected x in the representation basis')
# imshow(rep[:10], title='Theoretical learned representations')



### Extract representations from the hidden layer

We first naively just print off the hidden layer in the representation basis. This is pretty noisy. This is because it has entangled the x and y and xy representations. 

In [None]:
rep = group.irreps['standard'].rep.reshape(group.order, -1)
dims = rep.shape[1]
dim = int(math.sqrt(dims))
hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
hidden = activations['hidden'].reshape(group.order*group.order, -1)

hidden_to_reps = hidden_reps_xy.T @ hidden

print(hidden_to_reps.shape)
imshow(hidden_to_reps, title='Change of basis from hidden basis to representation basis')

hidden_in_reps = hidden @ hidden_to_reps.T

print(hidden_in_reps.shape)

imshow(hidden_in_reps[:10], title='Hidden layer in the representation basis')


torch.Size([16, 128])


torch.Size([14400, 16])


To clean up this noise, we should first project the hidden layer onto the subspace of relevance, and then change the basis.

In [None]:
hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
P = projection_matrix_general(hidden_reps_xy)
hidden = activations['hidden'].reshape(group.order*group.order, -1)
hidden_xy = P @ hidden

hidden_to_reps_proj = hidden_reps_xy.T @ hidden_xy
hidden_in_rep = hidden_xy @ hidden_to_reps_proj.T

imshow(hidden_to_reps_proj)
imshow(hidden_in_rep[:10], title='Projected hidden layer in the standard representation basis', input2='representation basis', input1='input index in range group.order^2')
imshow(rep[all_labels].reshape(group.order*group.order, -1)[:10], title='Theoretical representations expected to be found in the hidden layer', input2='representation basis', input1='input index in range group.order^2')

Plausibly the network is actually computing (xy)^-1 in the hidden layer though.

In [None]:
rep = group.irreps['standard'].rep.reshape(group.order, -1)
hidden_reps_xy_inverse = rep[group.inverses[all_labels]].reshape(group.order*group.order, -1)

P = projection_matrix_general(hidden_reps_xy_inverse)
hidden = activations['hidden'].reshape(group.order*group.order, -1)

hidden_xy_inverse = P @ hidden

hidden_to_reps = hidden_reps_xy_inverse.T @ hidden_xy_inverse
imshow(hidden_to_reps, title='Change of basis from hidden basis to inverse representation basis', input2='representation basis', input1='hidden layer basis')

hidden_in_rep = hidden_xy_inverse[:10] @ hidden_to_reps.T
imshow(hidden_in_rep, title='Projected hidden layer in the inverse standard representation basis', input2='representation basis', input1='input index in range group.order^2')
imshow(rep[group.inverses[all_labels]].reshape(group.order*group.order, -1)[:10], title='Theoretical representations expected to be found in the hidden layer', input2='representation basis', input1='input index in range group.order^2')

### Logit Computation

Try and understand hidden -> logit computation. We expect the map to be v -> tr (v \rho(z^-1)).

We first note that the "hidden_to_reps" tensor shows that some neurons seem completely off in the hidden layer representation basis. I assert these are the ones in completely different representations. The signature neurons are notably absent. Others that are absent are the neurons which are just not used. This plot can be used for the signature representation too (need to change to line plot)

We can change the basis on both the hidden layer input and the logit outputs to the representation basis. We use (xy)^-1 representations for the hidden layer, and z representations for the output logits. Here, we see the form of the linear map is tr (v \rho(z)) = v x z^T, where the x is elementwise multiplication.

Note we can do the inverse too. We can change the basis on both the hidden layer input and the logit outputs to the representation basis. We use (xy) representations for the hidden layer, and z^-1 representations for the output logits. (below cell)

To see how we find both note that (a) and (b) describe the same linear map, just in different input and output bases.

a) $\rho(xy) \to tr \rho (xyz^{-1})$ 

b) $\rho((xy)^{-1}) \to tr \rho((xy)^{-1} z)$ i.e. $\rho^{-1} (xy) \to tr \rho^{-1} (xyz^{-1}) $



In [None]:
rep = group.irreps['standard'].rep.reshape(group.order, -1)
#rep = torch.randn_like(rep)
W_U = model.W_U
hidden_reps_xy = rep[group.inverses[all_labels]].reshape(group.order*group.order, -1)
hidden = activations['hidden'].reshape(group.order*group.order, -1)
P = projection_matrix_general(hidden_reps_xy)
hidden_xy = P @ hidden
hidden_to_reps = hidden_reps_xy.T @ hidden_xy

print(hidden_to_reps.shape)
imshow(hidden_to_reps, title='Change of basis from the standard basis on the hidden layer to that of the representation space')

W_U_rep = hidden_to_reps @ W_U @ rep

print(W_U_rep.shape)
imshow(W_U_rep, title='Unembedding matrix in the representation space (for both input and output')



torch.Size([16, 128])


torch.Size([16, 16])


In [None]:
rep = group.irreps['standard'].rep.reshape(group.order, -1)
#rep = torch.randn_like(rep)
W_U = model.W_U
hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
hidden = activations['hidden'].reshape(group.order*group.order, -1)
P = projection_matrix_general(hidden_reps_xy)
hidden_xy = P @ hidden
hidden_to_reps = hidden_reps_xy.T @ hidden_xy

print(hidden_to_reps.shape)
imshow(hidden_to_reps, title='Change of basis from the standard basis on the hidden layer to that of the representation space')

W_U_rep = hidden_to_reps @ W_U @ rep[group.inverses]

print(W_U_rep.shape)
imshow(W_U_rep, title='Unembedding matrix in the representation space (for both input and output')



torch.Size([16, 128])


torch.Size([16, 16])


Curiously, the unembed matrix is diagonal if we don't introduce an inverse at any point and use the orthogonal rep. I don't think this means much.

In [None]:
rep = group.irreps['standard'].orth_rep.reshape(group.order, -1)
#rep = torch.randn_like(rep)
W_U = model.W_U
hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
hidden = activations['hidden'].reshape(group.order*group.order, -1)
P = projection_matrix_general(hidden_reps_xy)
hidden_xy = P @ hidden
hidden_to_reps = hidden_reps_xy.T @ hidden_xy

print(hidden_to_reps.shape)
imshow(hidden_to_reps, title='Change of basis from the standard basis on the hidden layer to that of the representation space')

W_U_rep = hidden_to_reps @ W_U @ rep

print(W_U_rep.shape)
imshow(W_U_rep, title='Unembedding matrix in the representation space (for both input and output')



torch.Size([16, 128])


torch.Size([16, 16])


can do something similar with the natural representation

In [None]:
rep = group.other_reps['natural'].rep.reshape(group.order, -1)
#rep = torch.randn_like(rep)
W_U = model.W_U
hidden_reps_xy = rep[group.inverses[all_labels]].reshape(group.order*group.order, -1)
hidden = activations['hidden'].reshape(group.order*group.order, -1)
#P = projection_matrix_general(hidden_reps_xy)
#hidden_xy = P @ hidden
hidden_to_reps = hidden_reps_xy.T @ hidden

print(hidden_to_reps.shape)
imshow(hidden_to_reps, title='Change of basis from the standard basis on the hidden layer to that of the representation space')

W_U_rep = hidden_to_reps @ W_U @ rep

print(W_U_rep.shape)
imshow(W_U_rep, title='Unembedding matrix in the representation space (for both input and output')



torch.Size([25, 128])


torch.Size([25, 25])


Now reproduce above plot indicating how much of embeddings are explained by each of the inverse representation matrix elements, for x and y. We see now x is constant on rows and y is constant on columns. This suggests it is indeed learning the representations \rho(x^-1) and \rho(y^-1), by our definition of rho. z needs to be constant across all elements, as tr(* z) uses all elements of z equally.

In [None]:
# method 2
rep = group.irreps['standard'].rep.reshape(group.order, -1)
inverse_rep = rep[group.inverses]
rep = torch.qr(inverse_rep)[0]
dims = rep.shape[1]

x_embed = model.W_x @ model.W[:embed_dim, :]
y_embed = model.W_y @ model.W[embed_dim:, :]
x_norm = x_embed.pow(2).sum()
y_norm = y_embed.pow(2).sum()
U_norm = model.W_U.pow(2).sum()

coefs_x = rep.T @ x_embed
coefs_y = rep.T @ y_embed
coefs_U = rep.T @ model.W_U.T

conts_x = coefs_x.pow(2).sum(-1) / x_norm
conts_y = coefs_y.pow(2).sum(-1) / y_norm
conts_U = coefs_U.pow(2).sum(-1) / U_norm

line(conts_x, xaxis = 'flattened rep direction', yaxis='proportion of x embedding')
line(conts_y, xaxis = 'flattened rep direction', yaxis='proportion of y embedding')
line(conts_U, xaxis = 'flattened rep direction', yaxis='proportion of unembedding')



### Understanding the explicit calculation particular neurons are doing.

seems hard

In [None]:
rep = group.irreps['standard'].rep.reshape(group.order, -1)
hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
hidden = activations['hidden'].reshape(group.order*group.order, -1)
hidden_to_reps = hidden_reps_xy.T @ hidden
imshow(hidden_to_reps)

consider neuron 0

In [None]:
hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
#hidden_reps_xy = torch.qr(hidden_reps_xy)[0]
hidden_reps_x = rep[all_data[:, 0]].reshape(group.order*group.order, -1)
#hidden_reps_x = torch.qr(hidden_reps_x)[0]
hidden_reps_y = rep[all_data[:, 1]].reshape(group.order*group.order, -1)
#hidden_reps_y = torch.qr(hidden_reps_y)[0]

embed_stack = activations['embed_stack'].reshape(group.order*group.order, -1)
x_embeds = embed_stack[:, :256]
y_embeds = embed_stack[:, 256:]
x_embeds = x_embeds @ model.W[:256, :]
y_embeds = y_embeds @ model.W[256:, :]

print(x_embeds.shape)
print(y_embeds.shape)

print(hidden_reps_xy.shape)

x_embeds_to_reps = hidden_reps_x.T @ x_embeds
y_embeds_to_reps = hidden_reps_y.T @ y_embeds

#x_embeds = hidden_reps_x @ x_embeds_to_reps
#y_embeds = hidden_reps_y @ y_embeds_to_reps

imshow(x_embeds_to_reps)

imshow(y_embeds_to_reps)

imshow(hidden_to_reps)

s = x_embeds + y_embeds
relu = torch.nn.ReLU()
s_to_reps = hidden_reps_xy.T @ relu(s)

#imshow(s_to_reps)


torch.Size([14400, 128])
torch.Size([14400, 128])
torch.Size([14400, 16])


neuron 0 x embed = $-4758v_{12} + 3680v_{13} + 6853v_{14} + 10600v_{15} \approx  900 (-4v_{12} + 3v_{13} + 7.5v_{14} + 12v_{15})$


neuron 0 post relu = $-3730v_{15} - 2800v_{14} - 1865 v_{13} - 932v_{12} \approx - 900 ( v_{12}  + 2v_{13} + 3v_{14} + 4v_{15})$
