# Setup

In [222]:
!nvidia-smi 
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

Fri Nov 25 17:48:05 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Quadro P...  On   | 00000000:00:05.0  On |                  N/A |
| 47%   40C    P8     7W / 105W |   5849MiB /  8119MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Imports

In [223]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
import sage

import json
import pickle 
import copy

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *
from utils.config import *

In [224]:
if torch.cuda.is_available:
  print('Good to go!')
else:
  print('Things might be rather slow')

Good to go!


# Interpretability




## Interpretability Set Up 

In [225]:
task_dir = "1L_MLP_sym_S4" #1L_MLP_sym_S5_cached_3"
seed, frac_train, width, lr, group_param, weight_decay, num_epochs, group_type, architecture_type, metric_cfg, metric_obj = load_cfg(task_dir)
group = group_type(group_param, init_all = True)
all_data, _, all_labels, _ = generate_train_test_data(group, frac_train = 1)
model = architecture_type(width, group.order, seed).cuda()
model.load_state_dict(torch.load(f"{task_dir}/model.pt"))
model.eval()
logits, activations = model.run_with_cache(all_data, return_cache_object=False)
activations['logits'] = logits
metric_obj = metric_obj(group, training=False, track_metrics = True)
metrics = metric_obj.get_metrics(model)


Computing trace tensor cube for standard representation
Loading from file
Computing trace tensor cube for standard_sign representation
Loading from file
Computing trace tensor cube for sign representation
Loading from file
Computing trace tensor cube for trivial representation
Loading from file
Computing trace tensor cube for s4_2d representation
Loading from file



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



## Visualise the Weights


In [226]:
imshow(model.W_x)
imshow(model.W_y)
imshow(model.W)
imshow(model.W_U)

## Losses on various subsets


In [227]:
all_indices = np.arange(group.order)

In [228]:
all_loss = loss_fn(logits, all_labels)
all_loss

tensor(9.8294e-05, device='cuda:0', grad_fn=<NllLossBackward0>)

In [229]:
alternating_indices = [i for i in range(group.order) if group.signature(i) == 1]
alternating_data = group.get_subset_of_data(alternating_indices).cuda()
alternating_data, alternating_labels = alternating_data[:, :2], alternating_data[:, 2]
alternating_logits = model(alternating_data)
alternating_loss = loss_fn(alternating_logits, alternating_labels).item()
alternating_loss

0.00010362412285758182

In [230]:
alt_coset_indices = [i for i in range(group.order) if group.signature(i) == -1]
alt_coset_data = group.get_subset_of_data(alt_coset_indices).cuda()
alt_coset_data, alt_coset_labels = alt_coset_data[:, :2], alt_coset_data[:, 2]
alt_coset_logits = model(alternating_data)
alt_coset_loss = loss_fn(alternating_logits, alternating_labels).item()
alt_coset_loss

0.00010362412285758182

In [231]:
#random_indices = np.random.choice(group.order, int(0.5*group.order*group.order))
#random_data = group.get_subset_of_data(random_indices).cuda()
#random_data, random_labels = random_data[:, :2], random_data[:, 2]
#random_logits = model(random_data)
#random_loss = loss_fn(random_logits, random_labels).item()
#random_loss


In [232]:
data = group.get_subset_of_data(alternating_indices, alt_coset_indices).cuda()
data, labels = data[:, :2], data[:, 2]
logits = model(data)
loss = loss_fn(logits, labels).item()
loss


4.8020650865510106e-05

In [233]:
data = group.get_subset_of_data(alt_coset_indices, alternating_indices ).cuda()
data, labels = data[:, :2], data[:, 2]
logits = model(data)
loss = loss_fn(logits, labels).item()
loss


0.00018165486108046025

In [234]:
losses=[]
for x in range(group.order):
    data = group.get_subset_of_data([x], all_indices).cuda()
    data, labels = data[:, :2], data[:, 2]
    logits = model(data)
    loss = loss_fn(logits, labels).item()
    losses.append(loss)
line(losses, log_y=True)


elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison



## Look at activation patterns

Let's first look at the relevant activations (the hidden and output layers) in the standard basis.

### Understanding Individual Neuron Activations

In [235]:
# first, reshape from (batch) to (group.order, group.order) so that we may easily fourier transform on each of the two input dimensions.

for key, value in activations.items():
    activations[key] = value.reshape(group.order, group.order, -1)

In [236]:
# elements
for x in range(group.order):
    print(f'{x}: {group.idx_to_perm(x)}, signature: {group.signature(x)}, order: {group.perm_order(x)}')

# conjugacy classes
#for x in group.conjugacy_classes:
    #print(x)

0: (3), signature: 1, order: 1
1: (0 1 2 3), signature: -1, order: 4
2: (0 2)(1 3), signature: 1, order: 2
3: (0 3), signature: -1, order: 2
4: (1 2 3), signature: 1, order: 3
5: (0 1 3 2), signature: -1, order: 4
6: (3)(0 2 1), signature: 1, order: 3
7: (0 3 1 2), signature: -1, order: 4
8: (1 3 2), signature: 1, order: 3
9: (3)(0 1), signature: -1, order: 2
10: (0 2 3), signature: 1, order: 3
11: (0 3 2 1), signature: -1, order: 4
12: (2 3), signature: -1, order: 2
13: (3)(0 1 2), signature: 1, order: 3
14: (0 2 1 3), signature: -1, order: 4
15: (0 3 2), signature: 1, order: 3
16: (3)(1 2), signature: -1, order: 2
17: (0 1 3), signature: 1, order: 3
18: (0 2 3 1), signature: -1, order: 4
19: (0 3)(1 2), signature: 1, order: 2
20: (1 3), signature: -1, order: 2
21: (0 1)(2 3), signature: 1, order: 2
22: (3)(0 2), signature: -1, order: 2
23: (0 3 1), signature: 1, order: 3


### observations for cached S4

neuron 8: activates iff a even and b odd (!?)

neuron 24: activates iff both a and b odd



In [237]:
hidden = activations['hidden']
neurons = [3, 7, 8, 13, 15, 16, 19, 24, 33, 34, 36, 39] # checked 0-39 for blocky neurons
neurons = [24]
for neuron in neurons:
    imshow(hidden[:, :, neuron], title=f'hidden activations, neuron {neuron}', input1='position 1', input2='position 2')

## Find interesting neurons

In [238]:
def find_interesting_neuron(neuron):
    # check if position 1 activates only on certain signs
    pass
    


    # check if position 2 activates only on certain orders/signs

for idx in range(hidden.shape[-1]):
    neuron = hidden[:, :, idx]
    print(f'Analysis of neuron {idx}')
    find_interesting_neuron(neuron)

Analysis of neuron 0
Analysis of neuron 1
Analysis of neuron 2
Analysis of neuron 3
Analysis of neuron 4
Analysis of neuron 5
Analysis of neuron 6
Analysis of neuron 7
Analysis of neuron 8
Analysis of neuron 9
Analysis of neuron 10
Analysis of neuron 11
Analysis of neuron 12
Analysis of neuron 13
Analysis of neuron 14
Analysis of neuron 15
Analysis of neuron 16
Analysis of neuron 17
Analysis of neuron 18
Analysis of neuron 19
Analysis of neuron 20
Analysis of neuron 21
Analysis of neuron 22
Analysis of neuron 23
Analysis of neuron 24
Analysis of neuron 25
Analysis of neuron 26
Analysis of neuron 27
Analysis of neuron 28
Analysis of neuron 29
Analysis of neuron 30
Analysis of neuron 31
Analysis of neuron 32
Analysis of neuron 33
Analysis of neuron 34
Analysis of neuron 35
Analysis of neuron 36
Analysis of neuron 37
Analysis of neuron 38
Analysis of neuron 39
Analysis of neuron 40
Analysis of neuron 41
Analysis of neuron 42
Analysis of neuron 43
Analysis of neuron 44
Analysis of neuron 4

In [239]:
logits = activations['logits']
for neuron in range(2):
    imshow(logits[:, :, neuron], title=f'logit {neuron}', input1='position 1', input2='position 2')

## Understanding Logit Computation

We hypotehsise the network is using the following formula to compute logits

$$Tr(\rho(x)\rho(y)\rho(z^{-1}))=Tr(\rho(xyz^{-1}))$$

We can verify this by computing this (group.order, group.order, group.order) tensor cube, and comparing the cosine similarity of it to the actual logits produced. The group class has already done this. Note this trace has been centred, as logits don't care about an additive constant and we will compare the two. We also center the logits before comparing their cosine similarity.

In [240]:
centred_trace = group.standard_trace_tensor_cubes
print(centred_trace.shape)

torch.Size([24, 24, 24])


In [241]:
# correct for logits not caring about an additive constant

logits = activations['logits']
print(logits.shape)
centred_logits = logits - logits.mean(-1, keepdim=True)

torch.Size([24, 24, 24])


### Visual Comparison
Let's first just look and see if they are visiually similar across two different slices

In [242]:
for neuron in range(2):
    imshow(centred_logits[:, :, neuron], title=f'real logit {neuron}', input1='position 1', input2='position 2')
    imshow(centred_trace[:, :, neuron], title=f'hypothesised logit {neuron}', input1='position 1', input2='position 2')

In [243]:
for neuron in range(2):
    imshow(centred_logits[neuron, :, :], title=f'real logits with position 1 fixed as {neuron}', input1='position 2', input2='logit')
    imshow(centred_trace[neuron, :, :], title=f'hypothesised logits with position 1 fixed as {neuron}', input1='position 2', input2='logit')

### Cosine similarity 

over all logits, mean over all inputs.

We find for S5 both the standard and natural representation yield good cosine similarity. This suggests the network really is performing the alleged algorithm. The product of the sign and standard representation is certainly not the algorithm employed

We find for S4 all three representations yield good similarity.

In [244]:
centred_logits = centred_logits.reshape(group.order*group.order,-1)
centred_trace = centred_trace.reshape(group.order*group.order,-1)
sims = F.cosine_similarity(centred_logits, centred_trace, dim=-1)
print(f'cosine similarity pre softmax {sims.mean()}')

centred_logits_softmax = F.softmax(centred_logits, dim=-1)
centred_trace_softmax = F.softmax(centred_trace, dim=-1)

softmax_sims = F.cosine_similarity(centred_logits_softmax, centred_trace_softmax, dim=-1).mean()
print(f'cosine similarity post softmax {softmax_sims.mean()}')

norms = torch.linalg.norm(centred_logits, dim=-1) 
scalar_projects = norms * sims 
unit_trace_logits = centred_trace / torch.linalg.norm(centred_trace, dim=-1, keepdim=True)
projects = unit_trace_logits * scalar_projects.unsqueeze(-1)
frac_explained = projects.pow(2).sum(-1) / centred_logits.pow(2).sum(-1)
print(f'fraction of variance of logit explained by the trace logits {frac_explained.mean()}')

norms = torch.linalg.norm(centred_logits_softmax, dim=-1) 
scalar_projects = norms * softmax_sims 
unit_trace_logits = centred_trace_softmax / torch.linalg.norm(centred_trace_softmax, dim=-1, keepdim=True)
projects = unit_trace_logits * scalar_projects.unsqueeze(-1)
frac_explained = projects.pow(2).sum(-1) / centred_logits_softmax.pow(2).sum(-1)
print(f'fraction of variance of softmax logits explained by the softmax trace logits {frac_explained.mean()}')


cosine similarity pre softmax 0.6018916964530945
cosine similarity post softmax 0.93958580493927
fraction of variance of logit explained by the trace logits 0.3640381395816803
fraction of variance of softmax logits explained by the softmax trace logits 0.8828215003013611


### How different are the representations?


Hypothesis: the representations learned are those that are maximally different. Each injective representation will give the correct answer. Having very different logits will produce outputs that destructively intefere, minimising loss.


In [245]:
centred_traces = {'trivial': group.trivial_trace_tensor_cubes, 
                    'sign': group.sign_trace_tensor_cubes, 
                    'standard': group.standard_trace_tensor_cubes,
                    'standard_sign': group.standard_sign_trace_tensor_cubes}

if group.index == 4:
    centred_traces['s4_2d'] =  group.S4_2d_trace_tensor_cubes


for key, value in centred_traces.items():
    centred_traces[key] = value.reshape(group.order*group.order,-1)

for key1, value1 in centred_traces.items():
    for key2, value2 in centred_traces.items():
        print(f'{key1} vs {key2}: {F.cosine_similarity(value1, value2, dim=-1).mean()}')

trivial vs trivial: 0.0
trivial vs sign: 0.0
trivial vs standard: 0.0
trivial vs standard_sign: 0.0
trivial vs s4_2d: 0.0
sign vs trivial: 0.0
sign vs sign: 0.9999998807907104
sign vs standard: -4.3461718668424965e-09
sign vs standard_sign: -2.4835269396561444e-09
sign vs s4_2d: 0.0
standard vs trivial: 0.0
standard vs sign: -4.3461718668424965e-09
standard vs standard: 0.9999998211860657
standard vs standard_sign: -2.4835268064293814e-08
standard vs s4_2d: -3.725290298461914e-09
standard_sign vs trivial: 0.0
standard_sign vs sign: -2.4835269396561444e-09
standard_sign vs standard: -2.4835268064293814e-08
standard_sign vs standard_sign: 0.9999998211860657
standard_sign vs s4_2d: -3.725290298461914e-09
s4_2d vs trivial: 0.0
s4_2d vs sign: 0.0
s4_2d vs standard: -3.725290298461914e-09
s4_2d vs standard_sign: -3.725290298461914e-09
s4_2d vs s4_2d: 0.9999998211860657


### Theoretical optimal representations to use

Interesting. Now what loss would I get if I actually used each of these representations? Let's find out.

In [246]:
for key, value in centred_traces.items():
    value = value.reshape(group.order * group.order, -1)
    accuracy = (value.argmax(-1) == all_labels).float().mean()
    loss = loss_fn(value, all_labels)
    print(f'{key}: {loss}, {accuracy*100}%')

trivial: 3.178053617477417, 4.1666669845581055%
sign: 2.6118361949920654, 8.333333969116211%
standard: 0.8650597929954529, 100.0%
standard_sign: 0.8650598526000977, 100.0%
s4_2d: 3.295471429824829, 12.5%


In [247]:
def linspace(start, stop, step=1.):
  """
    Like np.linspace but uses step instead of num
    This is inclusive to stop, so if start=1, stop=3, step=0.5
    Output is: array([1., 1.5, 2., 2.5, 3.])
  """
  return np.linspace(start, stop, int((stop - start) / step + 1))

granularity = 0.1

if group.index ==5:
  best_weights = [0, 0, 0, 0]
  best_loss = 100
  for i in linspace(0, 1, granularity):
      for j in linspace(0, 1-i, granularity):
          for k in linspace (0, 1-i-j, granularity):
              l = 1 - i - j - k
              hyp_logits = i*centred_traces['standard'] + j*centred_traces['sign'] + k*centred_traces['trivial'] + l*centred_traces['standard_sign']
              loss = loss_fn(hyp_logits, all_labels)
              if loss < best_loss:
                  best_loss = loss
                  best_weights = [i, j, k, l]
                  print(f'new best loss {best_loss} with weights {best_weights}')
  print(f'best combination: {best_loss}, {accuracy*100}%')
  print(f'best weights: {best_weights}')

if group.index==4:
  best_weights = [0.58, 0.27, 0, 0.19, 0.08]
  best_loss = 100
  for i in linspace(0, 1, granularity):
      for j in linspace(0, 1-i, granularity):
          for k in linspace (0, 1-i-j, granularity):
            for l in linspace (0, 1-i-j-k, granularity):
              m = 1 - i - j - k - l
              hyp_logits = i*centred_traces['standard'] + j*centred_traces['sign'] + k*centred_traces['trivial'] + l*centred_traces['standard_sign']+ m*centred_traces['s4_2d']
              loss = loss_fn(hyp_logits, all_labels)
              if loss < best_loss:
                  best_loss = loss
                  best_weights = [i, j, k, l, m]
                  print(f'new best loss {best_loss} with weights {best_weights}')
  print(f'best combination: {best_loss}')
  print(f'best weights: {best_weights}')

new best loss 3.295471429824829 with weights [0.0, 0.0, 0.0, 0.0, 1.0]
new best loss 2.9318184852600098 with weights [0.0, 0.0, 0.0, 0.1, 0.9]
new best loss 2.59440279006958 with weights [0.0, 0.0, 0.0, 0.2, 0.8]
new best loss 2.284221649169922 with weights [0.0, 0.0, 0.0, 0.30000000000000004, 0.7]
new best loss 2.001570701599121 with weights [0.0, 0.0, 0.0, 0.4, 0.6]
new best loss 1.7463009357452393 with weights [0.0, 0.0, 0.0, 0.5, 0.5]
new best loss 1.5180823802947998 with weights [0.0, 0.0, 0.0, 0.6000000000000001, 0.3999999999999999]
new best loss 1.3165165185928345 with weights [0.0, 0.0, 0.0, 0.7000000000000001, 0.29999999999999993]
new best loss 1.1410915851593018 with weights [0.0, 0.0, 0.0, 0.8, 0.19999999999999996]
new best loss 0.9910213947296143 with weights [0.0, 0.0, 0.0, 0.9, 0.09999999999999998]
new best loss 0.8650598526000977 with weights [0.0, 0.0, 0.0, 1.0, 0.0]
new best loss 0.8119459748268127 with weights [0.1, 0.0, 0.0, 0.9, 0.0]
new best loss 0.770799458026886 

In [248]:
best_weights = [0.58, 0.27, 0, 0.19, 0.08]
i, j, k, l, m = best_weights
hyp_logits = i*centred_traces['standard'] + j*centred_traces['sign'] + k*centred_traces['trivial'] + l*centred_traces['standard_sign']#+ m*centred_traces['s4_2d']
loss = loss_fn(hyp_logits, all_labels)
loss



tensor(0.9857, device='cuda:0')

Is this algorithm better on the alternating set? no


In [249]:
alternating_indices = [i for i in range(group.order) if group.signature(i) == 1]
alternating_data = group.get_subset_of_data(alternating_indices).cuda()
alternating_data, alternating_labels = alternating_data[:, :2], alternating_data[:, 2]

print(alternating_indices)
indices = []
for i in alternating_indices:
    for j in alternating_indices:
        indices.append(i*group.order+j)



[0, 2, 4, 6, 8, 10, 13, 15, 17, 19, 21, 23]


In [250]:
print('on alternating subgroup')
for key, value in centred_traces.items():
    value = value.reshape(group.order * group.order, -1)
    value = value[indices]
    accuracy = (value.argmax(-1) == alternating_labels).float().mean()
    loss = loss_fn(value, alternating_labels)
    print(f'{key}: {loss}, {accuracy*100}%')

print('on whole group')
for key, value in centred_traces.items():
    value = value.reshape(group.order * group.order, -1)
    accuracy = (value.argmax(-1) == all_labels).float().mean()
    loss = loss_fn(value, all_labels)
    print(f'{key}: {loss}, {accuracy*100}%')

on alternating subgroup
trivial: 3.1780543327331543, 8.333333969116211%
sign: 2.6118338108062744, 8.333333969116211%
standard: 0.8650605082511902, 100.0%
standard_sign: 0.865060567855835, 100.0%
s4_2d: 1.7954720258712769, 25.0%
on whole group
trivial: 3.178053617477417, 4.1666669845581055%
sign: 2.6118361949920654, 8.333333969116211%
standard: 0.8650597929954529, 100.0%
standard_sign: 0.8650598526000977, 100.0%
s4_2d: 3.295471429824829, 12.5%


## Understanding the embeddings

The embedding x_embed is a (group.order, embed_dim) tensor. Let's first look at the singular values of this matrix. We find there are 9 large ones = (4-1) x (4-1)

In [251]:
u_a,s_a,v_a = torch.linalg.svd(model.W_x) 
print(s_a[:20])

u_b,s_b,v_b = torch.linalg.svd(model.W_y)
print(s_b[:20])

u,s,v = torch.linalg.svd(model.W)
print(s[:30])

u,s,v = torch.linalg.svd(model.W_U.T)
print(s[:50])

tensor([6.3954, 6.2867, 5.9924, 4.9930, 4.4637, 4.3878, 4.1225, 3.7712, 1.4182,
        1.1117, 1.0283, 0.9808, 0.9367, 0.8480, 0.7270, 0.7041, 0.6658, 0.6271,
        0.6043, 0.4699], device='cuda:0', grad_fn=<SliceBackward0>)
tensor([6.3021, 6.2545, 6.1591, 4.7668, 4.5945, 4.4863, 3.9569, 3.8314, 3.4648,
        1.2476, 1.1810, 1.0931, 1.0247, 1.0040, 0.8466, 0.6357, 0.5969, 0.5787,
        0.5099, 0.4586], device='cuda:0', grad_fn=<SliceBackward0>)
tensor([7.9638, 7.7322, 7.6229, 7.5385, 7.3306, 7.0248, 6.9311, 6.3965, 6.1976,
        5.7654, 5.6431, 5.3133, 5.1749, 5.0807, 4.9703, 3.2457, 1.5680, 0.6871,
        0.6107, 0.5369, 0.4583, 0.4246, 0.3802, 0.3679, 0.3462, 0.3192, 0.2998,
        0.2895, 0.2602, 0.2531], device='cuda:0', grad_fn=<SliceBackward0>)
tensor([4.7549, 4.3455, 3.7934, 3.6077, 3.5033, 3.4027, 3.3206, 3.2817, 3.2567,
        3.2101, 3.1653, 3.0539, 2.9524, 2.9380, 2.8199, 2.6941, 2.5531, 2.5214,
        2.3084, 2.0654, 1.4510, 1.1244, 0.3999, 0.3385], device='cud

### Low rank approximations

What happens if we take the low rank approximation gained from the first 9 singular values?


In [252]:
k=9
new_model = copy.deepcopy(model)
new_model.W_x = torch.nn.Parameter(u_a[:, :k] @ torch.diag(s_a[:k]) @ v_a[:k, :])
new_model.W_y = torch.nn.Parameter(u_b[:, :k] @ torch.diag(s_b[:k]) @ v_b[:k, :])
logits = model(all_data)
new_logits = new_model(all_data)
loss_new = loss_fn (new_logits, all_labels)
loss = loss_fn (logits, all_labels)

print(f'baselines loss {loss.item()}')
print(f'loss with only the largest 9 singular values of embeddings {loss_new.item()}')

baselines loss 9.829446935327724e-05
loss with only the largest 9 singular values of embeddings 9.274612239096314e-05


The natural representation has a 1-dimensional invariant subspace spanned by the sum of all basis vectors. Note too in the standard representation, the trace of any element is identical to the trace in the natural representation, up to some additive constant, which we discount from logits anyway. 

We can also see this by looking at its singular values of the standard representations, of which there are $group.index^2$, but only $(group.index-1)^2+1$ non zero ones. I hypothesise this is the 9 directions corresponding to the natural representation, then one for the "distance from the plane" formed when we collapse the standard rep to the natural rep. I hypothesise this 1 additional direction is not used by the network??

In [253]:
k = (group.index-1)**2 + 1
natural_reps = group.natural_reps.reshape(group.order, group.index*group.index)
u, s, v = torch.linalg.svd(natural_reps)
print(s)

# remove junk singular values
natural_reps_mod = u[:, :k] @ torch.diag(s[:k]) # this is going to get plugged into a proj operator, so we can drop the v
natural_reps_mod = torch.linalg.qr(natural_reps_mod)[0]

tensor([4.8990e+00, 2.8284e+00, 2.8284e+00, 2.8284e+00, 2.8284e+00, 2.8284e+00,
        2.8284e+00, 2.8284e+00, 2.8284e+00, 2.8284e+00, 1.4788e-07, 1.4178e-07,
        8.9138e-08, 8.3209e-08, 5.7613e-08, 5.0008e-08], device='cuda:0')


In the standard representation, all $(group.index-1)^2$ singular values are non zero.

In [254]:
standard_reps = group.standard_reps.reshape(group.order, (group.index-1)*(group.index-1))
u,s,v = torch.linalg.svd(standard_reps)
print(s)

# orthogonalise these
print(group.standard_reps_orth.shape)

tensor([6.8284, 5.2263, 3.6955, 2.8284, 2.8284, 2.8284, 2.1648, 1.5307, 1.1716],
       device='cuda:0')
torch.Size([24, 9])


In [255]:
standard_sign_reps = group.standard_sign_reps.reshape(group.order, (group.index-1)*(group.index-1))
u,s,v = torch.linalg.svd(standard_sign_reps)
print(s)

print(group.standard_sign_reps_orth.shape)


tensor([6.8284, 5.2263, 3.6955, 2.8284, 2.8284, 2.8284, 2.1648, 1.5307, 1.1716],
       device='cuda:0')
torch.Size([24, 9])


Let's see if the embeddings are actually learning representations. We can generate a $(group.order, group.index^2)$ tensor encoding the natural representation, or a $(group.order, (group.index-1)^2)$ encoding the standard representation. The embeddings are $(group.order, embed\_dim)$ tensors, which are higher dimensional than the representations. We've already shown the embeddings are sparse in the singular basis, but let's now show the dimensions actually line up.

In [256]:
def projection_matrix_general(B):
    """Compute the projection matrix onto the space spanned by the columns of `B`
    Args:
        B: ndarray of dimension (D, M), the basis for the subspace
    
    Returns:
        P: the projection matrix
    """
    P = B @ (B.T @ B).inverse() @ B.T
    return P

P_natural = projection_matrix_general(natural_reps_mod)
assert(torch.allclose(P_natural, P_natural@P_natural, atol=1e-6))
assert(torch.allclose(P_natural, P_natural.T, atol=1e-6))

P_standard = projection_matrix_general(group.standard_reps_orth)
assert(torch.allclose(P_standard, P_standard@P_standard, atol=1e-6))
assert(torch.allclose(P_standard, P_standard.T, atol=1e-6))

P_standard_sign = projection_matrix_general(group.standard_sign_reps_orth)
assert(torch.allclose(P_standard_sign, P_standard_sign@P_standard_sign, atol=1e-6))
assert(torch.allclose(P_standard_sign, P_standard_sign.T, atol=1e-6))



proj_a_natural = P_natural @ model.W_x
proj_b_natural = P_natural @ model.W_y
proj_U_natural = P_natural @ model.W_U.T

proj_a_standard = P_standard @ model.W_x
proj_b_standard = P_standard @ model.W_y
proj_U_standard = P_standard @ model.W_U.T

proj_a_standard_sign = P_standard_sign @ model.W_x
proj_b_standard_sign = P_standard_sign @ model.W_y
proj_U_standard_sign = P_standard_sign @ model.W_U.T

print(f'W_x frobenius norm: {model.W_x.pow(2).sum()}')
print(f'Projection onto natural representation {proj_a_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_a_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_a_standard_sign.pow(2).sum()}')
print(f'W_y frobenius norm: {model.W_y.pow(2).sum()}')
print(f'Projection onto natural representation {proj_b_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_b_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_b_standard_sign.pow(2).sum()}')
print(f'W_U frobenius norm: {model.W_U.pow(2).sum()}')
print(f'Projection onto natural representation {proj_U_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_U_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_U_standard_sign.pow(2).sum()}')




W_x frobenius norm: 221.51937866210938
Projection onto natural representation 121.28018188476562
Projection onto standard representation 120.8953857421875
Projection onto standard sign representation 37.8831787109375
W_y frobenius norm: 232.13304138183594
Projection onto natural representation 121.18186950683594
Projection onto standard representation 121.0242919921875
Projection onto standard sign representation 46.543434143066406
W_U frobenius norm: 213.47396850585938
Projection onto natural representation 107.65229797363281
Projection onto standard representation 88.83145141601562
Projection onto standard sign representation 38.703643798828125


Let's look at the matmul of the embedding and linear layer too.

In [257]:
embed_dim = model.W_x.shape[1]
unfactored_x = model.W_x @ model.W[:embed_dim, :]
unfactored_y =  model.W_y  @ model.W[embed_dim:, :]
proj_a_natural = P_natural @ unfactored_x
proj_b_natural = P_natural @ unfactored_y
proj_U_natural = P_natural @ model.W_U.T  

proj_a_standard = P_standard @ unfactored_x
proj_b_standard = P_standard @ unfactored_y
proj_U_standard = P_standard @ model.W_U.T 

proj_a_standard_sign = P_standard_sign @ unfactored_x
proj_b_standard_sign = P_standard_sign @ unfactored_y
proj_U_standard_sign = P_standard_sign @ model.W_U.T

print(f'W_x frobenius norm: {unfactored_x.pow(2).sum()}')
print(f'Projection onto natural representation {proj_a_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_a_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_a_standard_sign.pow(2).sum()}')
print(f'W_y frobenius norm: {unfactored_y.pow(2).sum()}')
print(f'Projection onto natural representation {proj_b_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_b_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_b_standard_sign.pow(2).sum()}')
print(f'W_U frobenius norm: {model.W_U.pow(2).sum()}')
print(f'Projection onto natural representation {proj_U_natural.pow(2).sum()}')
print(f'Projection onto standard representation {proj_U_standard.pow(2).sum()}')
print(f'Projection onto standard sign representation {proj_U_standard_sign.pow(2).sum()}')


W_x frobenius norm: 9516.873046875
Projection onto natural representation 6522.3798828125
Projection onto standard representation 6522.3095703125
Projection onto standard sign representation 1051.346923828125
W_y frobenius norm: 9470.5068359375
Projection onto natural representation 6506.85693359375
Projection onto standard representation 6506.4248046875
Projection onto standard sign representation 1011.0576171875
W_U frobenius norm: 213.47396850585938
Projection onto natural representation 107.65229797363281
Projection onto standard representation 88.83145141601562
Projection onto standard sign representation 38.703643798828125


Is it using a combination of the natural and standard representations? Is it using some combination of the natural, standard and sign reprsentatations? 

What even is the combined rank of all the representations? How orthogonal are they? They are all precisely orthogonal, which is good.

S5 doesnt seem to use the product standard sign rep at all?

In [258]:
reps = [group.standard_reps_orth, group.standard_sign_reps_orth, group.trivial_reps_orth, group.sign_reps_orth]
for i in range(len(reps)):
    for j in range(i+1, len(reps)):
        rep1 = reps[i]
        rep2 = reps[j]
        print((rep1.T@rep2).pow(2).sum())

tensor(2.5437e-13, device='cuda:0')
tensor(2.2728e-13, device='cuda:0')
tensor(9.8269e-14, device='cuda:0')
tensor(9.8269e-14, device='cuda:0')
tensor(2.2728e-13, device='cuda:0')
tensor(0., device='cuda:0')


In [259]:
all_reps = torch.cat([group.standard_reps.reshape(group.order, -1), group.standard_sign_reps.reshape(group.order, -1)], dim=1)
print(all_reps.shape)

u,s,v = torch.linalg.svd(all_reps)
print(s)

torch.Size([24, 18])
tensor([6.8284, 6.8284, 5.2263, 5.2263, 3.6955, 3.6955, 2.8284, 2.8284, 2.8284,
        2.8284, 2.8284, 2.8284, 2.1648, 2.1648, 1.5307, 1.5307, 1.1716, 1.1716],
       device='cuda:0')


In [260]:
q = torch.linalg.qr(all_reps)[0]
P_all = projection_matrix_general(q)
assert(torch.allclose(P_all, P_all@P_all, atol=1e-6))
assert(torch.allclose(P_all, P_all.T, atol=1e-6))
proj_a_all = P_all @ model.W_x
proj_b_all = P_all @ model.W_y
proj_U_all = P_all @ model.W_U.T

print(f'W_x frobenius norm: {model.W_x.pow(2).sum()}')
print(f'Projection onto representation {proj_a_all.pow(2).sum()}')
print(f'W_y frobenius norm: {model.W_y.pow(2).sum()}')
print(f'Projection onto representation {proj_b_all.pow(2).sum()}')
print(f'W_U frobenius norm: {model.W_U.pow(2).sum()}')
print(f'Projection onto representation {proj_U_all.pow(2).sum()}')

u,s,v = torch.linalg.svd(all_reps)
P_all = projection_matrix_general(u[:, :18])
assert(torch.allclose(P_all, P_all@P_all, atol=1e-6))
assert(torch.allclose(P_all, P_all.T, atol=1e-6))
proj_a_all = P_all @ model.W_x
proj_b_all = P_all @ model.W_y
proj_U_all = P_all @ model.W_U.T

print(f'W_x frobenius norm: {model.W_x.pow(2).sum()}')
print(f'Projection onto representation {proj_a_all.pow(2).sum()}')
print(f'W_y frobenius norm: {model.W_y.pow(2).sum()}')
print(f'Projection onto representation {proj_b_all.pow(2).sum()}')
print(f'W_U frobenius norm: {model.W_U.pow(2).sum()}')
print(f'Projection onto representation {proj_U_all.pow(2).sum()}')



W_x frobenius norm: 221.51937866210938
Projection onto representation 158.77854919433594
W_y frobenius norm: 232.13304138183594
Projection onto representation 167.56771850585938
W_U frobenius norm: 213.47396850585938
Projection onto representation 127.53509521484375
W_x frobenius norm: 221.51937866210938
Projection onto representation 158.778564453125
W_y frobenius norm: 232.13304138183594
Projection onto representation 167.56773376464844
W_U frobenius norm: 213.47396850585938
Projection onto representation 127.53509521484375


 ### Variance
 
Make sure that each of the representations actually uses all the information in equal parts from each of the dim^2 representation matrix elements. Do this by checking that the standard deviation of percentage contributions of each column is small.

In [261]:
# method 1 - probably wrong? I think this measures something like "std of neurons in representing the given representation"

embed_dim = model.W_x.shape[1]
unfactored_x = model.W_x @ model.W[:embed_dim, :]
unfactored_y = model.W_y  @ model.W[embed_dim:, :]

total_norm_x = unfactored_x.pow(2).sum()
proj_x_standard = P_standard @ unfactored_x
proj_x_standard_square = proj_x_standard.pow(2)
percents1 = proj_x_standard_square.sum(1) / total_norm_x

total_norm_y = unfactored_y.pow(2).sum()
proj_y_standard = P_standard @ unfactored_y
proj_y_standard_square = proj_y_standard.pow(2)
percents2 = proj_y_standard_square.sum(1) / total_norm_y

#print(torch.std(torch.cat((percents1, percents2))))




In [262]:
# method 2
rep = group.standard_reps_orth.reshape(group.order, -1)
dims = rep.shape[1]

contsx = []
for i in range(dims):
    x = rep[:, i].unsqueeze(-1)
    P = projection_matrix_general(x)
    proj_x_standard = P @ unfactored_x
    proj_x_standard_square = proj_x_standard.pow(2)
    contsx.append((proj_x_standard_square.sum() / total_norm_x).item())
contsx = torch.tensor(contsx)
print(f'conts: {contsx}')
print(f'std: {contsx.std()}')
print(f'sum: {contsx.sum()}')
line(contsx)

contsy = []
for i in range(dims):
    y = rep[:, i].unsqueeze(-1)
    P = projection_matrix_general(y)
    proj_y_standard = P @ unfactored_y
    proj_y_standard_square = proj_y_standard.pow(2)
    contsy.append((proj_y_standard_square.sum() / total_norm_y).item())
contsy = torch.tensor(contsy)
print(f'conts: {contsy}')
print(f'std: {contsy.std()}')
print(f'sum: {contsy.sum()}')
line(contsy)

conts = [] 
total_norm_U = model.W_U.pow(2).sum()
for i in range(dims):
    y = rep[:, i].unsqueeze(-1)
    P = projection_matrix_general(y)
    proj_U = P @ model.W_U.T 
    proj_U_square = proj_U.pow(2)
    conts.append((proj_U_square.sum() / total_norm_U).item())
conts = torch.tensor(conts)
print(f'conts: {conts}')
print(f'std: {conts.std()}')
print(f'sum: {conts.sum()}')
line(conts)

conts: tensor([0.1669, 0.0635, 0.0364, 0.1408, 0.0540, 0.0301, 0.1204, 0.0470, 0.0262])
std: 0.05252499878406525
sum: 0.6853418350219727


conts: tensor([0.0310, 0.0283, 0.0310, 0.0556, 0.0508, 0.0555, 0.1487, 0.1356, 0.1505])
std: 0.0526602640748024
sum: 0.6870197057723999


conts: tensor([0.0489, 0.0548, 0.0513, 0.0467, 0.0427, 0.0480, 0.0445, 0.0384, 0.0408])
std: 0.005207408219575882
sum: 0.4161231517791748


## Understanding the hidden layer

Under the hypothesis that the network is embedding as representations, We should expect the hidden layer to contain terms like "rho(x)rho(y)".

Consider a model run on all the (group.order * group.order) inputs. The hidden layer is then a (group.order * group.order, hidden) tensor. 

Assuming some representation rho, we should expect to find a (group.order * group.order, rep_dim, rep_dim) = (group.order * group.order, rep_dim * rep_dim) tensor, with rep_dim << hidden. 

We can perform a similar projection trick here as we did on the embeddings.

We center this hidden neurons. I'm actually not sure how important this is.

In [263]:
hidden = activations['hidden']
hidden = hidden.reshape(group.order*group.order, -1)
mean = hidden.mean(dim=-1, keepdim=True)
hidden_centred = hidden - mean
u, s, v = torch.linalg.svd(hidden_centred)
#line(s)

logits = hidden @ model.W_U
print(hidden.shape, logits.shape)

loss = loss_fn(logits, all_labels)

centred_logits = hidden_centred @ model.W_U
centred_loss = loss_fn(centred_logits, all_labels)

print(f'Loss: {loss}')
print(f'Centred Hidden loss: {centred_loss}')

# i think this is close enough?

torch.Size([576, 64]) torch.Size([576, 24])
Loss: 9.829446935327724e-05
Centred Hidden loss: 0.00013269681949168444


In [264]:
den = hidden_centred.pow(2).sum().item()
reps = {'trivial': group.trivial_reps, 
'sign': group.sign_reps, 
'standard': group.standard_reps, 
'standard_sign' : group.standard_sign_reps}

total = 0
for name, rep in reps.items():
    hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
    hidden_reps_x = rep[all_data[:, 0]].reshape(group.order*group.order, -1)
    hidden_reps_y = rep[all_data[:, 1]].reshape(group.order*group.order, -1)
    hidden_reps_xy = torch.linalg.qr(hidden_reps_xy)[0]
    hidden_reps_x = torch.linalg.qr(hidden_reps_x)[0]
    hidden_reps_y = torch.linalg.qr(hidden_reps_y)[0]
    P_xy = projection_matrix_general(hidden_reps_xy)
    P_x = projection_matrix_general(hidden_reps_x) 
    P_y = projection_matrix_general(hidden_reps_y)
    proj_xy = P_xy @ hidden_centred
    u,s,v = torch.linalg.svd(proj_xy)
    proj_x = P_x @ hidden_centred
    proj_y = P_y @ hidden_centred
    num_xy = proj_xy.pow(2).sum()
    num_x = proj_x.pow(2).sum()
    num_y = proj_y.pow(2).sum()
    percent_x = num_x / den
    percent_y = num_y / den
    percent_xy = num_xy / den
    total += percent_x
    total += percent_y
    total += percent_xy
    
    print(f'{name}, x: {percent_x}, y: {percent_y}, xy: {percent_xy}')
print(f'Total percent explained by reps: {total}')

trivial, x: 0.025384150445461273, y: 0.025384150445461273, xy: 0.025384150445461273
sign, x: 0.020276417955756187, y: 0.020495379343628883, xy: 0.01825098879635334
standard, x: 0.2577059268951416, y: 0.2551901638507843, xy: 0.09756028652191162
standard_sign, x: 0.03931063786149025, y: 0.0377165749669075, xy: 0.011246347799897194
Total percent explained by reps: 0.8339051008224487


Claim: The output logits should not depend on the directions corresponding to x or y, but only those corresponding to xy. Let's test this. To do so, we first just ablate those directions and see what happens to loss. If we ablate the xy direction we should see loss tank.


In [265]:
logits = hidden @ model.W_U
loss = loss_fn(logits, all_labels)
print(f'baseline loss: {loss}')


for name, rep in reps.items():
    hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
    hidden_reps_x = rep[all_data[:, 0]].reshape(group.order*group.order, -1)
    hidden_reps_y = rep[all_data[:, 1]].reshape(group.order*group.order, -1)
    hidden_reps_xy = torch.linalg.qr(hidden_reps_xy)[0]
    hidden_reps_x = torch.linalg.qr(hidden_reps_x)[0]
    hidden_reps_y = torch.linalg.qr(hidden_reps_y)[0]
    P_xy = projection_matrix_general(hidden_reps_xy)
    P_x = projection_matrix_general(hidden_reps_x) 
    P_y = projection_matrix_general(hidden_reps_y)
    proj_xy = P_xy @ hidden_centred
    hidden_centred_xy = hidden_centred - proj_xy
    proj_x = P_x @ hidden_centred
    hidden_centred_x = hidden_centred - proj_x
    proj_y = P_y @ hidden_centred
    hidden_centred_y = hidden_centred - proj_y

    logits_xy = hidden_centred_xy @ model.W_U
    loss_xy = loss_fn(logits_xy, all_labels)
    logits_x = hidden_centred_x @ model.W_U
    loss_x = loss_fn(logits_x, all_labels)
    logits_y = hidden_centred_y @ model.W_U
    loss_y = loss_fn(logits_y, all_labels)

    print(f'Ablating directions corresponding to {name} rep loss, xy: {loss_xy}, x: {loss_x}, y: {loss_y}')
    

baseline loss: 9.829446935327724e-05
Ablating directions corresponding to trivial rep loss, xy: 8.61176522448659e-05, x: 8.61176522448659e-05, y: 8.61176522448659e-05
Ablating directions corresponding to sign rep loss, xy: 0.0017227226635441184, x: 0.00011770277342293411, y: 0.00012524216435849667
Ablating directions corresponding to standard rep loss, xy: 1.2313432693481445, x: 6.505138298962265e-05, y: 7.671516505070031e-05
Ablating directions corresponding to standard_sign rep loss, xy: 0.0018795751966536045, x: 0.00012790279288310558, y: 0.00010995608317898586


We show that the the spaces corresponding to x and y representations in the hidden layer are in the kernel of W_U.

In [266]:
print('Mean norms of unembedded representation vectors')
for name, rep in reps.items():
    hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
    hidden_reps_x = rep[all_data[:, 0]].reshape(group.order*group.order, -1)
    hidden_reps_y = rep[all_data[:, 1]].reshape(group.order*group.order, -1)
    
    P_xy = projection_matrix_general(hidden_reps_xy)
    P_x = projection_matrix_general(hidden_reps_x) 
    P_y = projection_matrix_general(hidden_reps_y)
    proj_xy = P_xy @ hidden_centred
    proj_x = P_x @ hidden_centred
    proj_y = P_y @ hidden_centred

    unembed_xy = proj_xy @ model.W_U
    unembed_x = proj_x @ model.W_U
    unembed_y = proj_y @ model.W_U

    print(f'{name} rep, xy: {unembed_xy.pow(2).mean()}, x: {unembed_x.pow(2).mean()}, y: {unembed_y.pow(2).mean()}')


Mean norms of unembedded representation vectors
trivial rep, xy: 1.9402427673339844, x: 1.9402427673339844, y: 1.9402427673339844
sign rep, xy: 4.104244232177734, x: 0.030613139271736145, y: 0.03885919973254204
standard rep, xy: 11.336535453796387, x: 0.37504273653030396, y: 0.41304516792297363
standard_sign rep, xy: 0.9656031727790833, x: 0.09219139814376831, y: 0.11231154203414917


### Variance

Are some matrix elements of xy in the standard representation preferred over others in the hidden layer? Yes. I claim this is actually pretty deep.

Let $e_i$ be the standard basis matrices. Let $\mathbb{x} := \lambda_i x_i \mathbb{e}_i$. Let $\mathbb{y} := \delta_i y_i \mathbb{e}_i$ be such that the mean of $x_i$ and mean of $y_i$ is one over all inputs for fixed i, i.e. $\lambda_i,\delta_i$ encode an i-dependent scaling factor that we calcualate below. Then consider the matrix product $\mathbb{x}\mathbb{y} = \mathbb{z}$ = $z_i \mathbb{e}_i$. 

Then assuming 5d matrices, eg $z_0 = \lambda_0 \delta_0 x_0 y_0 + \lambda_1 \delta_5 x_1 y_5 + \lambda_2 \delta_{10} x_2 y_{10} + \lambda_3 \delta_15 x_3 y_{15} + \lambda_4 \delta_{20} x_4 y_{20}$.

This only makes sense dimensionally if $\delta_0 = \delta_5 = \delta_{10} = \delta_{15} = \delta_{20}$ and $\lambda_0 = \lambda_1 = \lambda_2 = \lambda_3 = \lambda_4$.



In [267]:
rep = group.standard_reps_orth.reshape(group.order, -1)
print(rep)
dims = rep.shape[1]

hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)
hidden_reps_x = rep[all_data[:, 0]].reshape(group.order*group.order, -1)
hidden_reps_y = rep[all_data[:, 1]].reshape(group.order*group.order, -1)

#P_xy = projection_matrix_general(hidden_reps_xy)
#P_x = projection_matrix_general(hidden_reps_x) 
#P_y = projection_matrix_general(hidden_reps_y)
#proj_xy = P_xy @ hidden_centred
#proj_x = P_x @ hidden_centred
#
#proj_y = P_y @ hidden_centred

total_norm = hidden_centred.pow(2).sum()

conts = []
for i in range(dims):
    a = hidden_reps_x[:, i].unsqueeze(-1)
    P_x = projection_matrix_general(a)
    proj = P_x @ hidden_centred
    proj_square = proj.pow(2)
    conts.append((proj_square.sum() / total_norm).item())
conts = torch.tensor(conts)
print(f'x conts: {conts}')
print(f'std: {conts.std()}')
print(f'sum: {conts.sum()}')
line(conts)

conts = []
for i in range(dims):
    a = hidden_reps_y[:, i].unsqueeze(-1)
    P_y = projection_matrix_general(a)
    proj = P_y @ hidden_centred
    proj_square = proj.pow(2)
    conts.append((proj_square.sum() / total_norm).item())
conts = torch.tensor(conts)
print(f'y conts: {conts}')
print(f'std: {conts.std()}')
print(f'sum: {conts.sum()}')
line(conts)


conts = []
for i in range(dims):
    a = hidden_reps_xy[:, i].unsqueeze(-1)
    P_xy = projection_matrix_general(a)
    proj = P_xy @ hidden_centred
    proj_square = proj.pow(2)
    conts.append((proj_square.sum() / total_norm).item())
conts = torch.tensor(conts)
print(f'xy conts: {conts}')
print(f'std: {conts.std()}')
print(f'sum: {conts.sum()}')
line(conts)

tensor([[-2.8868e-01,  2.0412e-01,  1.1334e-08,  1.6667e-01,  2.3570e-01,
         -2.0412e-01,  1.1785e-01,  1.6667e-01,  2.8868e-01],
        [ 0.0000e+00, -3.0619e-01,  1.7678e-01, -2.9802e-08,  1.7678e-01,
          3.0619e-01, -3.5355e-01, -9.6858e-08,  9.3132e-09],
        [ 0.0000e+00,  0.0000e+00, -3.5355e-01, -3.3333e-01, -1.1785e-01,
         -5.9605e-08,  1.1785e-01, -3.3333e-01, -2.9802e-08],
        [ 0.0000e+00,  3.0619e-01,  1.7678e-01,  0.0000e+00,  1.7678e-01,
         -3.0619e-01, -3.5355e-01,  3.3528e-08, -1.7619e-07],
        [-2.8868e-01, -1.0206e-01,  1.7678e-01,  1.6667e-01,  5.8926e-02,
          3.0619e-01,  1.1785e-01, -3.3333e-01, -2.2352e-08],
        [ 0.0000e+00, -3.0619e-01, -1.7678e-01, -3.3333e-01,  5.8926e-02,
         -1.0206e-01,  1.1785e-01,  1.6667e-01, -2.8868e-01],
        [ 2.8868e-01,  1.0206e-01, -1.7678e-01,  1.6667e-01, -2.9463e-01,
          1.0206e-01,  1.1785e-01,  1.6667e-01,  2.8868e-01],
        [ 0.0000e+00,  0.0000e+00,  3.5355e-01, 

y conts: tensor([0.0154, 0.0159, 0.0158, 0.0218, 0.0213, 0.0221, 0.0503, 0.0422, 0.0503])
std: 0.014855880290269852
sum: 0.2551901340484619


xy conts: tensor([0.0126, 0.0127, 0.0126, 0.0109, 0.0095, 0.0118, 0.0101, 0.0082, 0.0092])
std: 0.0016876683803275228
sum: 0.09756031632423401


## Signature Representation Circuit Analysis

TLDR: the network memorises input signatures, mangles them together using the relu, and uses this to divide the output search space in two.


We know the S5 network uses the sign representation to some extent. Let's find this in the x and y embedding in the first instance. We'll later want to find this in the hidden layer too, and in the logits. 

We find in our run, (1L_MLP_sym_S5_cached_3), neurons 24, 66, 73, and 81 have high sign representation projection. Let's now print inspect these neurons. It's curious there are 4 of these.


In [268]:
embed_dim = model.W_x.shape[1]
unfactored_x = model.W_x @ model.W[:embed_dim, :]
P_sign = projection_matrix_general(group.sign_reps_orth)
proj_x_sign = P_sign @ unfactored_x
proj_x_sign_summed = proj_x_sign.pow(2).sum(dim=0)
print(torch.topk(proj_x_sign_summed, 10))
signature_neurons_x = torch.topk(proj_x_sign_summed, 4).indices


torch.return_types.topk(
values=tensor([135.3074, 135.3053, 127.9919, 127.9802,  17.7410,   0.6434,   0.4606,
          0.3291,   0.2982,   0.2614], device='cuda:0',
       grad_fn=<TopkBackward0>),
indices=tensor([ 3, 38,  6, 29, 21, 47, 17, 43, 36, 28], device='cuda:0'))


In [269]:
embed_dim = model.W_y.shape[1]
unfactored_y = model.W_y @ model.W[embed_dim:, :]
P_sign = projection_matrix_general(group.sign_reps_orth)
proj_y_sign = P_sign @ unfactored_y
proj_y_sign_summed = proj_y_sign.pow(2).sum(dim=0)
print(torch.topk(proj_y_sign_summed, 10))
signature_neurons_y = torch.topk(proj_y_sign_summed, 4).indices


torch.return_types.topk(
values=tensor([1.3437e+02, 1.3434e+02, 1.3005e+02, 1.3002e+02, 2.6863e+01, 3.4359e-01,
        1.5912e-01, 1.4563e-01, 1.4528e-01, 1.2325e-01], device='cuda:0',
       grad_fn=<TopkBackward0>),
indices=tensor([ 3, 38,  6, 29, 21, 36, 37, 17, 22, 43], device='cuda:0'))


In [270]:
assert (signature_neurons_x == signature_neurons_y).all()
signature_neurons = signature_neurons_x


In [271]:
print(signature_neurons)
sig_labels = [str(x) for x in signature_neurons.tolist()]

tensor([ 3, 38,  6, 29], device='cuda:0')


The first row is the sign represntation. The next four rows are the x embeddings. The next image is the same but for y embeddings. Let a = 3.8. We learn that (for the same cached S5 model)

neuron 24 = $a ReLU (-sign(x) + sign(y)) > 0$ iff x odd and y even ie z odd

neuron 66 = $a ReLU (-sign(x) - sign(y)) > 0$ iff x odd and y odd ie z even

neuron 73 = $a ReLU (sign(x) - sign(y)) > 0$ iff x even and y odd ie z odd

neuron 81 = $a ReLU (sign(x) + sign(y)) > 0$ iff x even and y even ie z even



In 1L_MLP_cached_S4 we find

neuron 3 = $a ReLU (-sign(x) - sign(y)) > 0$ iff x odd and y odd ie z even

neuron 38 = $a ReLU (-sign(x) - sign(y)) > 0$ iff x odd and y odd ie z even

neuron 6 = $a ReLU (sign(x) + sign(y)) > 0$ iff x even and y even ie z even

neuron 29 = $a ReLU (sign(x) + sign(y)) > 0$ iff x even and y even ie z even


In [272]:
sigs = group.signatures.unsqueeze(-1)
xs = unfactored_x[:, signature_neurons]
stack = torch.hstack([sigs, xs]).T
imshow(stack, y=['sig'] + sig_labels, input2='input group element')
ys = unfactored_y[:, signature_neurons]
stack = torch.hstack([sigs, ys]).T
imshow(stack, y=['sig'] + sig_labels, input2='input group element')


How does this effect logits? We simply look at the relevant unembed rows. We expect rows 66 and 81 to correlate and 24 and 73 to anticorrelate with the (now output) signature. Thus the logits align with the trace formula.

In [273]:
sigs = group.signatures.unsqueeze(-1)
W_U_signatures = model.W_U[signature_neurons, :].T
#W_U_24 = model.W_U[24, :].unsqueeze(-1)
#W_U_66 = model.W_U[66, :].unsqueeze(-1)
#W_U_73 = model.W_U[73, :].unsqueeze(-1)
#W_U_81 = model.W_U[81, :].unsqueeze(-1)
#stack = torch.hstack([sigs, W_U_24, W_U_66, W_U_73, W_U_81]).T
stack = torch.hstack([sigs, W_U_signatures]).T

imshow(stack, y=['sig'] + sig_labels, input2='output group element')

## Standard Representation Circuit Analysis

We know the S5 network uses the sign representation to some extent. Let's find this in the x and y embedding in the first instance. I want to give the network some input and extract the representation it uses. Then we can see if the hidden layer actually has the matrix product of these things? seems tough....

Our unfactored embedding matrix has size (group.order, hidden). The representations can be encoded in a (group.order, dim^2) tensor. I want the coefficients on this dim^2 subspace for some given input group element. This requires identifying the linear map hidden -> dim^2. Can I just take embedding.T @ rep.

Let's try

Another approach is to use the singular value of the embedding as the coeffiecent, and the left singular vector as the indicator for whether that coefficient is used.

I'm likely overcomplicating this. I suspect all i need is a matrix multiply from an orthogonal set of vectors

### Janky Method 1

In [274]:
rep = group.standard_reps.reshape(group.order, -1)
dims = rep.shape[1]

for idx in range(2):

    print(group.idx_to_perm(idx))
    matrix = []

    print(unfactored_y.shape)

    m = unfactored_y.T @ rep

    print(m.shape)

    out = m.T @ unfactored_y[idx]

    print(out)

    out = out.reshape(4,4)

    imshow(out-out.mean(-1, keepdim=True), title = 'network')
    imshow(group.standard_reps[idx], title = 'real')


(3)
torch.Size([24, 64])
torch.Size([64, 9])
tensor([787.9366, 972.0196, 765.3860, 535.2859, 631.3893, 493.1357, 800.3531,
        982.7974, 779.4029], device='cuda:0', grad_fn=<MvBackward0>)


RuntimeError: shape '[4, 4]' is invalid for input of size 9

### Janky method 2


In [None]:
rep = group.standard_reps.reshape(group.order, -1)
dims = rep.shape[1]
contsx = []

for idx in range(2):
    mat = []
    for i in range(dims):
        x = rep[:, i].unsqueeze(-1)
        P = projection_matrix_general(x)
        proj_x = P @ unfactored_x
        mat.append(proj_x[idx].pow(2).sum())
    mat = torch.tensor(mat)
    matx = mat.reshape(4,4)
    imshow(mat, title='network')
    imshow(rep[idx].reshape(4,4), title='real')

for idx in range(1):
    mat = []
    for i in range(dims):
        x = rep[:, i].unsqueeze(-1)
        P = projection_matrix_general(x)
        proj_y = P @ unfactored_y
        mat.append(proj_y[idx].pow(2).sum())
    mat = torch.tensor(mat)
    maty = mat.reshape(4,4)
    imshow(mat, title='network')
    imshow(rep[idx].reshape(4,4), title='real')

matz = matx @ maty

imshow(matz, title='embed product')

conts = []
for i in range(dims):
    a = hidden_reps_xy[:, i].unsqueeze(-1)
    P_xy = projection_matrix_general(a)
    proj = P_xy @ hidden_centred
    proj_square = proj.pow(2)
    conts.append((proj_square.pow(2).sum()))
conts = torch.tensor(conts)
print(f'xy conts: {conts}')
print(f'std: {conts.std()}')
print(f'sum: {conts.sum()}')
line(conts)

ValueError: px.imshow only accepts 2D single-channel, RGB or RGBA images. An image of shape (16,) was provided. Alternatively, 3- or 4-D single or multichannel datasets can be visualized using the `facet_col` or/and `animation_frame` arguments.

### Janky method 3 (which i think is least janky)

In [None]:
rep = group.standard_reps.reshape(group.order, -1)
dims = rep.shape[1]
dim = int(math.sqrt(dims))
contsx = []

mats_x = []
mats_y = []

for idx in range(3):
    mat = []
    for i in range(dims):
        x = rep[:, i].unsqueeze(-1)
        P = projection_matrix_general(x)
        proj_x = P @ unfactored_x
        u, s, v = torch.svd(proj_x)
        mat.append(s[0]*u[idx, 0])
    mat = torch.tensor(mat)
    matx = mat.reshape(dim, dim)
    imshow(matx, title='network')
    imshow(rep[idx].reshape(dim, dim), title='real')
    mats_x.append(matx)

for idx in range(3):
    mat = []
    for i in range(dims):
        x = rep[:, i].unsqueeze(-1)
        P = projection_matrix_general(x)
        proj_y = P @ unfactored_y
        u, s, v = torch.svd(proj_y)
        mat.append(s[0]*u[idx, 0])
    mat = torch.tensor(mat)
    matx = mat.reshape(dim, dim)
    imshow(matx, title='network')
    imshow(rep[idx].reshape(dim, dim), title='real')
    mats_y.append(matx)


In [None]:
rep = group.standard_reps.reshape(group.order, -1)
dims = rep.shape[1]
dim = int(math.sqrt(dims))

hidden_reps_xy = rep[all_labels].reshape(group.order*group.order, -1)

hidden = activations['hidden'].reshape(group.order*group.order, -1)


mats_xy = []
for idx in range(5):
    mat = []
    for i in range(dims):
        a = hidden_reps_xy[:, i].unsqueeze(-1)
        P_xy = projection_matrix_general(a)
        proj = P_xy @ hidden
        u, s, v = torch.svd(proj)
        mat.append(s[0]*u[idx, 0])
    mat = torch.tensor(mat)
    mat = mat.reshape(dim, dim)
    imshow(mat, title='network hidden xy')
    m_x = mats_x[all_data[idx, 0]]
    m_y = mats_y[all_data[idx, 1]]
    imshow(m_x @ m_y, title='network x @ y')
    imshow(hidden_reps_xy[idx].reshape(dim, dim), title='real hidden xy')


IndexError: list index out of range