# Set Up


In [2]:
# Record the GPU state in the cell output so the hardware used for this run is visible.
!nvidia-smi 
# Load IPython's autoreload extension and switch it to mode 2 so edits to the
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
# Step up one directory exactly once per kernel: the `cd` flag is stored in
# globals(), so re-running this cell does not keep walking up the tree.
# NOTE(review): presumably this moves from a notebooks folder to the repo root — confirm.
if "cd" not in globals():
    os.chdir("../")
    cd = True
print(os.getcwd())
# Directory (relative to the working directory printed above) that the
# figure-saving cells below write into.
save_dir = 'paper/figures'

Mon Jan 16 11:15:05 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Tesla V1...  On   | 00000000:00:05.0 Off |                    0 |
| N/A   41C    P0    55W / 300W |   3055MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
import wandb
import statsmodels.api as sm
import json
import pickle 
import copy

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import plotly.subplots

pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *
from utils.config import *
from utils.figures import *

In [4]:
# `torch.cuda.is_available` is a function and has to be *called*: the bare
# function object is always truthy, so the un-called form reported
# 'Good to go!' even on a machine without a usable GPU.
if torch.cuda.is_available():
    print('Good to go!')
else:
    print('Things might be rather slow')

Good to go!


In [5]:
# Directory holding the trained run: its config, weights and logged metrics.
task_dir = "paper/mainline-S5-old"

# Rebuild the group and the model from the stored configuration.
(
    seed,
    frac_train,
    layers,
    lr,
    group_param,
    weight_decay,
    num_epochs,
    group_type,
    architecture_type,
) = load_cfg(task_dir)
group = group_type(group_param, init_all = True)
model = architecture_type(layers, group.order, seed).cuda()
# NOTE(review): strict=False lets load_state_dict ignore missing/unexpected keys
# instead of raising — confirm that a partial load is really intended here.
model.load_state_dict(torch.load(f"{task_dir}/model.pt"), strict=False)
model.eval()

# frac_train = 1 puts every sample into the "train" split; the test halves are discarded.
all_data, _, all_labels, _ = generate_train_test_data(group, frac_train = 1)

# Cache every activation for the inputs in group.all_data; the logits are stored
# next to them so later cells can read everything from one dict.
logits, activations = model.run_with_cache(group.all_data, return_cache_object=False)
activations['logits'] = logits
metric_obj = Metrics(group, training=False, track_metrics = True)
key_reps = metric_obj.determine_key_reps(model)

metrics_path = os.path.join(task_dir, 'metrics.csv')
summary_metrics_path = os.path.join(task_dir, 'summary_metrics.json')

# load the metrics
metrics = pd.read_csv(metrics_path)
# Use a context manager so the JSON file handle is closed deterministically
# (the previous json.load(open(...)) left it to the garbage collector).
with open(summary_metrics_path, 'r') as summary_file:
    summary_metrics = json.load(summary_file)

reps_to_plot = list(group.non_trivial_irreps.keys())

Computing multiplication table...
... loading from file
Computing trace tensor cube for trivial representation
... loading from file
Computing trace tensor cube for sign representation
... loading from file
Computing trace tensor cube for standard representation
... loading from file
Computing trace tensor cube for standard_sign representation
... loading from file
Computing trace tensor cube for s5_5d_a representation
... loading from file
Computing trace tensor cube for s5_5d_b representation
... loading from file
Computing trace tensor cube for s5_6d representation
... loading from file


# Logit Attribution


In [6]:
# figure: the logit-0 slice of the true logit cube next to the hypothesised
# (trace-tensor) slices of the key representations, each scaled by its max |value|
logits = activations['logits'].reshape(group.order, group.order, group.order)

logit0_slice = logits[:, :, 0]
true_logits = logit0_slice / logit0_slice.abs().max()

key_rep_logits = torch.stack(
    [
        group.irreps[name].logit_trace_tensor_cube[:, :, 0]
        / group.irreps[name].logit_trace_tensor_cube[:, :, 0].abs().max()
        for name in key_reps
    ],
    dim=0,
)

# panel 0 is the true slice, panels 1.. follow the order of key_reps
stack = torch.vstack([true_logits.unsqueeze(0), key_rep_logits])
print(stack.shape)
fig = px.imshow(
    to_numpy(stack),
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0.0,
    title=f'logit 0',
    facet_col=0,
    labels={'x':'b', 'y':'a', 'facet_col': 'label'},
)
fig.layout.annotations[0]['text'] = 'true logit 0 over all inputs' 
for i, key_rep in enumerate(key_reps):
    fig.layout.annotations[i+1]['text'] = f"{key_rep} hypothesised logit 0"
fig.show()
fig.write_image(f"{save_dir}/logit_cubes.png", width=1000)


torch.Size([3, 120, 120])


In [7]:
# figure: cosine similarity between true and hypothesised logits over training,
# one line per non-trivial irrep
template = "logit_{}_rep_trace_similarity"
lines_from_template(
    metrics,
    template,
    reps_to_plot,
    title="cosine similarity of true logits and hypothesised logits",
    yaxis="cosine similarity",
    save=f"{save_dir}/logit_similarity.png",
    log_x=True,
)


Saving to paper/figures/logit_similarity.png


In [8]:
# evidence: logit cosine similarity per non-trivial irrep, read from the summary metrics

template = "logit_{}_rep_trace_similarity"
print("cosine similarity of true logits and hypothesised logits at end of training")
for irrep in group.non_trivial_irreps:
    metric_key = template.format(irrep)
    print(f"{irrep}: {summary_metrics[metric_key]}")


cosine similarity of true logits and hypothesised logits at end of training
sign: 0.3664345145225525
standard: 0.6768929958343506
standard_sign: 9.658833732828498e-09
s5_5d_a: -2.663000486791134e-09
s5_5d_b: 7.522247324232012e-05
s5_6d: 2.1354935597628355e-09


In [9]:
# Sum of squared end-of-training cosine similarities over the key representations.
# NOTE(review): reading this as "fraction explained" assumes the hypothesised
# logit directions of different representations are orthogonal — confirm.
# The template is set here explicitly instead of relying on whatever value the
# previously executed cell left behind: the next cell rebinds `template`, so a
# re-run of this cell would otherwise silently read different metrics.
template = "logit_{}_rep_trace_similarity"
percent = sum(summary_metrics[template.format(rep)]**2 for rep in key_reps)
percent

0.5924583812429809

# Embeddings and Unembeddings

In [10]:
# figure: fraction of variance of the left embedding, right embedding and
# unembedding explained by each representation over the course of training

# (metric-name template, figure title, output file name), plotted in this order
embed_figure_specs = [
    (
        "percent_x_embed_{}_rep",
        "Fraction of variance of left embedding explained by representation",
        "percent_x_embed.png",
    ),
    (
        "percent_y_embed_{}_rep",
        "Fraction of variance of right embedding explained by representation",
        "percent_y_embed.png",
    ),
    (
        "percent_unembed_{}_rep",
        "Fraction of variance of unembedding explained by representation",
        "percent_unembed.png",
    ),
]
for template, fig_title, fig_file in embed_figure_specs:
    lines_from_template(
        metrics,
        template,
        reps_to_plot,
        title=fig_title,
        yaxis="fraction of variance",
        save=f"{save_dir}/{fig_file}",
        log_x=True,
    )

Saving to paper/figures/percent_x_embed.png
Saving to paper/figures/percent_y_embed.png
Saving to paper/figures/percent_unembed.png


# Hidden Layer Neurons

In [11]:
# figure: fraction of MLP-neuron variance explained by each representation over training
template = "percent_hidden_{}_rep"
lines_from_template(
    metrics,
    template,
    reps_to_plot,
    title="Fraction of variance of MLP neurons explained by representation",
    yaxis="fraction of variance",
    save=f"{save_dir}/percent_hidden.png",
    log_x=True,
)

Saving to paper/figures/percent_hidden.png


In [12]:
# evidence: neuron clustering pre ReLU

# Cut-off on summed squared weights / coefficients: a neuron is "off" when its
# embedding column falls below it, and belongs to a representation when its
# coefficients in that representation's basis exceed it.
threshold = 1

x_embed = model.x_embed
y_embed = model.y_embed

# squeeze(-1) drops only the trailing index dimension returned by nonzero(), so
# the result stays 1-D even when exactly one neuron matches (a bare squeeze()
# would give a 0-d tensor, on which len() fails).
x_embed_summed = x_embed.pow(2).sum(dim=0)
off_neurons_x = (x_embed_summed < threshold).nonzero().squeeze(-1)

y_embed_summed = y_embed.pow(2).sum(dim=0)
off_neurons_y = (y_embed_summed < threshold).nonzero().squeeze(-1)

# torch.equal also compares shapes: index lists of different length fail the
# assertion cleanly instead of raising a size error or broadcasting.
assert torch.equal(off_neurons_x, off_neurons_y)

off_neurons = off_neurons_x

print(f'Off neurons: {len(off_neurons)}, {off_neurons}')

# representation name -> 1-D tensor of neuron indices assigned to it
rep_neurons = {}

print('Neurons corresponding to each representation')
for rep_name in group.non_trivial_irreps:
    rep = group.irreps[rep_name].orth_rep
    # coefficients of both embeddings in the basis given by the columns of `rep`
    coefs_x = rep.T @ x_embed
    coefs_y = rep.T @ y_embed
    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)

    x_neurons = (coefs_x_summed > threshold).nonzero().squeeze(-1)
    y_neurons = (coefs_y_summed > threshold).nonzero().squeeze(-1)
    # left and right embeddings must select the same neurons
    assert torch.equal(x_neurons, y_neurons)
    # nonzero() already returns a fresh index tensor, so no torch.tensor(...)
    # copy is needed (that call emitted a copy-construct UserWarning).
    rep_neurons[rep_name] = x_neurons
    print(f'{rep_name}: {len(x_neurons)}, {x_neurons}')

print(rep_neurons)

# every neuron that is neither off nor claimed by some representation
all_neurons = torch.arange(model.W_U.shape[0])
unaccounted_neurons = set(all_neurons.tolist())
unaccounted_neurons -= set(off_neurons.tolist())
for rep_name, neurons in rep_neurons.items():
    unaccounted_neurons -= set(neurons.tolist())

print('Unaccounted neurons')
print(unaccounted_neurons)

Off neurons: 49, tensor([  0,   1,   5,   8,   9,  11,  12,  13,  14,  18,  19,  22,  25,  33,
         35,  44,  46,  47,  50,  51,  53,  55,  56,  61,  65,  68,  70,  72,
         76,  77,  78,  79,  82,  84,  85,  92,  93,  95,  99, 100, 104, 105,
        112, 114, 115, 116, 117, 119, 121], device='cuda:0')
Neurons corresponding to each representation
sign: 4, tensor([24, 66, 73, 81], device='cuda:0')
standard: 75, tensor([  2,   3,   4,   6,   7,  10,  15,  16,  17,  20,  21,  23,  26,  27,
         28,  29,  30,  31,  32,  34,  36,  37,  38,  39,  40,  41,  42,  43,
         45,  48,  49,  52,  54,  57,  58,  59,  60,  62,  63,  64,  67,  69,
         71,  74,  75,  80,  83,  86,  87,  88,  89,  90,  91,  94,  96,  97,
         98, 101, 102, 103, 106, 107, 108, 109, 110, 111, 113, 118, 120, 122,
        123, 124, 125, 126, 127], device='cuda:0')
standard_sign: 0, tensor([], device='cuda:0', dtype=torch.int64)
s5_5d_a: 0, tensor([], device='cuda:0', dtype=torch.int64)
s5_5d_b: 0, t


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [13]:
# evidence and table: neuron clustering in post hidden layer

# Cut-off on the summed squared hidden activation below which a neuron counts as off.
threshold = 110

hidden = activations['hidden'].reshape(group.order**2, -1)

hidden_summed = hidden.pow(2).sum(dim=0)
off_neurons = (hidden_summed < threshold).nonzero().squeeze()

# The neurons that are off after the hidden layer must be the ones found from
# the embeddings. Both sides are flattened and compared with torch.equal so a
# length mismatch fails the assertion instead of broadcasting or raising.
assert torch.equal(off_neurons.reshape(-1), off_neurons_x.reshape(-1))

print(f'Off neurons: {off_neurons}')


fracs_explained_x = {}
fracs_explained_y = {}
fracs_explained_xy = {}
fracs_explained_trivial = {}

# For every irrep, build the representation matrices of the left input, the
# right input and the label for each input pair (flattened to one row per pair),
# plus an orthonormal basis of each column space from the Q factor of a QR.
n_pairs = group.order**2
for rep_name in group.irreps.keys():
    irrep = group.irreps[rep_name]
    irrep.hidden_reps_x = irrep.rep[all_data[:, 0]].reshape(n_pairs, -1)
    irrep.hidden_reps_x_orth = torch.linalg.qr(irrep.hidden_reps_x)[0]
    irrep.hidden_reps_y = irrep.rep[all_data[:, 1]].reshape(n_pairs, -1)
    irrep.hidden_reps_y_orth = torch.linalg.qr(irrep.hidden_reps_y)[0]
    irrep.hidden_reps_xy = irrep.rep[all_labels].reshape(n_pairs, -1)
    irrep.hidden_reps_xy_orth = torch.linalg.qr(irrep.hidden_reps_xy)[0]

# The trivial-representation projection does not depend on the key
# representation, so it is computed once instead of once per loop iteration.
trivial = group.irreps['trivial'].hidden_reps_x_orth
coefs_trivial = trivial.T @ hidden
coefs_trivial_summed = coefs_trivial.pow(2).sum(dim=0)

for rep_name in key_reps:
    rep_x = group.irreps[rep_name].hidden_reps_x_orth
    rep_y = group.irreps[rep_name].hidden_reps_y_orth
    rep_xy = group.irreps[rep_name].hidden_reps_xy_orth

    coefs_x = rep_x.T @ hidden
    coefs_y = rep_y.T @ hidden
    coefs_xy = rep_xy.T @ hidden

    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)
    coefs_xy_summed = coefs_xy.pow(2).sum(dim=0)

    # restrict to the neurons assigned to this representation
    neurons = rep_neurons[rep_name]
    # total squared activation of those neurons: the shared denominator
    total_sq = hidden[:, neurons].pow(2).sum()

    fracs_explained_x[rep_name] = coefs_x_summed[neurons].sum() / total_sq
    fracs_explained_y[rep_name] = coefs_y_summed[neurons].sum() / total_sq
    fracs_explained_xy[rep_name] = coefs_xy_summed[neurons].sum() / total_sq
    fracs_explained_trivial[rep_name] = coefs_trivial_summed[neurons].sum() / total_sq

print('Neurons corresponding to each representation')
for key in key_reps:
    # four numbers are printed, so the label names all four components
    print(f'frac variance explained in {key} x, y, xy, trivial: {fracs_explained_x[key], fracs_explained_y[key], fracs_explained_xy[key], fracs_explained_trivial[key]}')
    print(f'Sum of explained variance: {fracs_explained_x[key] + fracs_explained_y[key] + fracs_explained_xy[key] + fracs_explained_trivial[key]}')

Off neurons: tensor([  0,   1,   5,   8,   9,  11,  12,  13,  14,  18,  19,  22,  25,  33,
         35,  44,  46,  47,  50,  51,  53,  55,  56,  61,  65,  68,  70,  72,
         76,  77,  78,  79,  82,  84,  85,  92,  93,  95,  99, 100, 104, 105,
        112, 114, 115, 116, 117, 119, 121], device='cuda:0')
Neurons corresponding to each representation
frac variance explained in sign x, y, xy: (tensor(0.2500, device='cuda:0'), tensor(0.2500, device='cuda:0'), tensor(0.2500, device='cuda:0'), tensor(0.2500, device='cuda:0'))
Sum of explained variance: 0.9999998211860657
frac variance explained in standard x, y, xy: (tensor(0.2646, device='cuda:0'), tensor(0.2645, device='cuda:0'), tensor(0.0800, device='cuda:0'), tensor(0.3200, device='cuda:0'))
Sum of explained variance: 0.9289817810058594


In [14]:
# evidence: only \rho(ab) is important
def loss_without_directions(hidden_acts, basis):
    """Loss after removing from `hidden_acts` its component along the columns of `basis`.

    The removed component is `basis @ (basis.T @ hidden_acts)`, which is an exact
    orthogonal projection only when the columns of `basis` are orthonormal.
    The ablated activations are unembedded with `model.W_U` and scored against
    `all_labels` with `loss_fn`.
    """
    component = basis @ (basis.T @ hidden_acts)
    return loss_fn((hidden_acts - component) @ model.W_U, all_labels)


hidden = activations['hidden'].reshape(group.order**2, -1)
loss = loss_fn(logits.reshape(group.order**2, -1), all_labels)
print(f'baseline loss: {loss}')

# Look the trivial basis up explicitly instead of relying on a variable left
# behind by the loop of the clustering cell, and ablate it once: the result is
# the same for every representation in the loop below.
trivial = group.irreps['trivial'].hidden_reps_x_orth
loss_trivial = loss_without_directions(hidden, trivial)

for rep_name in group.non_trivial_irreps:
    irrep = group.irreps[rep_name]

    # remove the rho(a), rho(b) and rho(ab) directions one at a time
    loss_x = loss_without_directions(hidden, irrep.hidden_reps_x_orth)
    loss_y = loss_without_directions(hidden, irrep.hidden_reps_y_orth)
    loss_xy = loss_without_directions(hidden, irrep.hidden_reps_xy_orth)

    print(f'Ablating directions corresponding to {rep_name} rep loss, xy: {loss_xy}, x: {loss_x}, y: {loss_y}, trivial: {loss_trivial}')
    

baseline loss: 7.322111299253613e-07
Ablating directions corresponding to sign rep loss, xy: 0.0010408131037335398, x: 7.303637850563431e-07, y: 7.303876203454114e-07, trivial: 7.563847459256262e-07
Ablating directions corresponding to standard rep loss, xy: 17.520132485767775, x: 7.247668060390121e-07, y: 7.249903731815012e-07, trivial: 7.563847459256262e-07
Ablating directions corresponding to standard_sign rep loss, xy: 7.322113459849321e-07, x: 7.32210972366435e-07, y: 7.322112035616451e-07, trivial: 7.563847459256262e-07
Ablating directions corresponding to s5_5d_a rep loss, xy: 7.322110549515471e-07, x: 7.322107588338105e-07, y: 7.32210868970442e-07, trivial: 7.563847459256262e-07
Ablating directions corresponding to s5_5d_b rep loss, xy: 7.358036242541432e-07, x: 7.322188035788668e-07, y: 7.322256094681935e-07, trivial: 7.563847459256262e-07
Ablating directions corresponding to s5_6d rep loss, xy: 7.322113055899815e-07, x: 7.322110658635672e-07, y: 7.322108154068436e-07, trivial

In [33]:
# evidence: explicit extraction of \rho(ab). 

def projection_matrix_general(B):
    """Compute the orthogonal projection matrix onto the column space of `B`.

    Args:
        B: 2-D torch.Tensor of shape (D, M) whose columns span the subspace.
            The columns must be linearly independent so that `B.T @ B` is
            invertible; they need not be orthonormal.

    Returns:
        P: torch.Tensor of shape (D, D) equal to `B (B^T B)^{-1} B^T`.
    """
    # Solve (B^T B) X = B^T rather than forming the explicit inverse: it gives
    # the same matrix but is faster and numerically more stable.
    P = B @ torch.linalg.solve(B.T @ B, B.T)
    return P

# one row of hidden activations per input pair
hidden = activations['hidden'].reshape(group.order**2, -1)
hidden_to_reps_proj = {}
coefs = {}

for rep_name in key_reps:
    # un-orthogonalised rho(ab) matrices, one flattened matrix per input pair
    hidden_reps_xy = group.irreps[rep_name].hidden_reps_xy

    # part of the hidden activations that lies in the column space of rho(ab)
    P = projection_matrix_general(hidden_reps_xy)
    hidden_xy = P @ hidden

    # map between representation coordinates (rows) and neurons (columns)
    change_of_basis = hidden_reps_xy.T @ hidden_xy
    hidden_to_reps_proj[rep_name] = change_of_basis

    imshow(
        change_of_basis,
        title=f'Change of basis from neuron basis to rho(ab) {rep_name} representation basis',
        input2='neuron basis',
        input1='representation basis',
        save=f'{save_dir}/hidden_to_{rep_name}_rep_change_of_basis.png',
    )

    hidden_in_rep = hidden_xy @ change_of_basis.T

    theoretical_reps = hidden_reps_xy.reshape(group.order**2, -1)
    imshow(
        hidden_in_rep[:10],
        title=f'Projected hidden layer in the {rep_name} representation basis',
        input2='representation basis',
        input1='input index',
    )
    imshow(
        theoretical_reps[:10],
        title=f'rho(ab) in {rep_name}',
        input2='representation basis',
        input1='input index',
    )

    sim = F.cosine_similarity(hidden_in_rep.flatten(), theoretical_reps.flatten(), dim=0)
    print(f'Cosine similarity between hidden layer and theoretical representations: {sim}')

    # scale of the projected activations relative to the theoretical matrices
    coef = hidden_in_rep.norm() / theoretical_reps.norm()
    coefs[rep_name] = coef

Saving to paper/figures/hidden_to_sign_rep_change_of_basis.png


Cosine similarity between hidden layer and theoretical representations: 1.0000001192092896


Saving to paper/figures/hidden_to_standard_rep_change_of_basis.png


Cosine similarity between hidden layer and theoretical representations: 0.9999710917472839


# Logit Computation

In [16]:
# evidence: the unembedding expressed between the rho(ab) coordinates of the
# hidden layer and the representation matrices of the inverse output elements
rep_name = 'standard'
W_U = model.W_U
rep = group.irreps[rep_name].rep.reshape(group.order, -1)
neuron_to_rep = hidden_to_reps_proj[rep_name]
inverse_reps = rep[group.inverses]
W_U_rep = neuron_to_rep @ W_U @ inverse_reps
print(W_U_rep.shape)
imshow(
    W_U_rep,
    title='Unembedding matrix in both input and output representation space',
    input2='input representation basis',
    input1='output representation basis',
    save=f'{save_dir}/unembedding_matrix_in_rep_basis.png',
)

# Boolean mask of the entries above the cut-off, used as the reference pattern.
# NOTE(review): 1e5 is an unexplained magic number — confirm it matches the scale
# of W_U_rep in the un-normalised basis used here.
real_linear_map = W_U_rep > 1e5
sim = F.cosine_similarity(W_U_rep.flatten(), real_linear_map.flatten(), dim=0)
print(f'Cosine similarity between unembedding matrix and real linear map: {sim}')

torch.Size([16, 16])


Saving to paper/figures/unembedding_matrix_in_rep_basis.png
Cosine similarity between unembedding matrix and real linear map: 0.9996892213821411


In [52]:
# do this orthogonally
#
# Same extraction as the rho(ab) cell above, but with the orthonormalised basis
# (the Q factor stored as `hidden_reps_xy_orth`). `projection_matrix_general` is
# reused from the cell where it is first defined; it used to be re-defined here
# verbatim, which silently shadowed the earlier definition.

hidden = activations['hidden'].reshape(group.order*group.order, -1)
hidden_to_reps_proj_orth = {}
# NOTE: this resets the `coefs` dict filled by the non-orthogonal cell above.
coefs = {}

for rep_name in key_reps:
    hidden_reps_xy = group.irreps[rep_name].hidden_reps_xy_orth

    # part of the hidden activations that lies in the column space of the basis
    P = projection_matrix_general(hidden_reps_xy)
    hidden_xy = P @ hidden

    hidden_to_reps_proj_orth[rep_name] = hidden_reps_xy.T @ hidden_xy

    imshow(hidden_to_reps_proj_orth[rep_name], title=f'Change of basis from neuron basis to rho(ab) {rep_name} representation basis', input2='neuron basis', input1='representation basis', save=f'{save_dir}/hidden_to_{rep_name}_orth_rep_change_of_basis.png')

    hidden_in_rep = hidden_xy @ hidden_to_reps_proj_orth[rep_name].T

    theoretical_reps = hidden_reps_xy.reshape(group.order*group.order, -1)
    imshow(hidden_in_rep[:10], title=f'Projected hidden layer in the {rep_name} representation basis', input2='representation basis', input1='input index')
    imshow(theoretical_reps[:10], title=f'rho(ab) in {rep_name}', input2='representation basis', input1='input index')

    sim = F.cosine_similarity(hidden_in_rep.flatten(), theoretical_reps.flatten(), dim=0)
    print(f'Cosine similarity between hidden layer and theoretical representations: {sim}')

    # scale of the projected activations relative to the orthonormal basis;
    # read by the MLP-reconstruction cell further down
    coef = (hidden_in_rep.norm() / theoretical_reps.norm())
    coefs[rep_name] = coef

Saving to paper/figures/hidden_to_sign_orth_rep_change_of_basis.png


Cosine similarity between hidden layer and theoretical representations: 1.0000001192092896


Saving to paper/figures/hidden_to_standard_orth_rep_change_of_basis.png


Cosine similarity between hidden layer and theoretical representations: 0.9999791979789734


In [69]:
# squared-norm ratio of W_U, restricted to each representation's neurons, that
# survives mapping through the representation's orthonormal basis at the inverse elements
W_U = model.W_U
for key in hidden_to_reps_proj_orth:
    rep_orth = group.irreps[key].orth_rep.reshape(group.order, -1)
    # rows of W_U belonging to the neurons assigned to this representation
    W_U_rep_neuron_basis = W_U[rep_neurons[key]]
    W_U_rep = W_U_rep_neuron_basis @ rep_orth[group.inverses]
    print(W_U_rep_neuron_basis.shape)
    norm_ratio = torch.norm(W_U_rep.flatten()) / torch.norm(W_U_rep_neuron_basis)
    print(f'Percentage of W_u explained in {key} representation: {norm_ratio.pow(2)}')


torch.Size([4, 120])
Percentage of W_u explained in sign representation: 0.9998068809509277
torch.Size([75, 120])
Percentage of W_u explained in standard representation: 0.9758121967315674


# Ablations

In [60]:
# MLP neurons: rebuild the hidden layer from the rho(ab) components of the key
# representations only and compare the resulting loss with the real one
hidden = activations['hidden'].reshape(group.order**2, -1)

hidden_constructed = torch.zeros_like(hidden)
for key, CoB_orth in hidden_to_reps_proj_orth.items():
    hidden_reps_orth = group.irreps[key].hidden_reps_xy_orth
    # map the orthonormal rho(ab) basis into neuron space with the change of basis
    hidden_rep = (hidden_reps_orth @ CoB_orth).reshape(group.order**2, -1)
    # coefs[key] is the scale factor stored by the extraction cell
    hidden_constructed += coefs[key] * hidden_rep


logits_constructed = hidden_constructed @ model.W_U
loss_constructed = loss_fn(logits_constructed, all_labels)
loss_base = loss_fn(hidden @ model.W_U, all_labels)
# baseline loss first, then the loss of the reconstructed hidden layer
print(loss_base)
print(loss_constructed)


tensor(7.3221e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)
tensor(-0., device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)


In [76]:
# W_U: split the unembedding into the part lying in the key representations'
# output space and the remainder, then compare the loss obtained with each

W_U = model.W_U
hidden = activations['hidden'].reshape(group.order**2, -1)

# accumulate the restriction of W_U to the output representation space
W_U_cont = torch.zeros_like(W_U)
for rep in key_reps:
    rep_orth = group.irreps[rep].orth_rep
    inverse_basis = rep_orth[group.inverses]
    W_U_rep = W_U @ inverse_basis
    W_U_cont += W_U_rep @ inverse_basis.T

# whatever is left of W_U outside that space
W_U_null = W_U - W_U_cont

base_logits = model(all_data)
base_loss = loss_fn(base_logits, all_labels)

new_logits = hidden @ W_U_cont
null_logits = hidden @ W_U_null
new_loss = loss_fn(new_logits, all_labels)
null_loss = loss_fn(null_logits, all_labels)

for loss_value in (base_loss, new_loss, null_loss):
    print(loss_value)
# relative change of the restricted loss with respect to the baseline
print((base_loss - new_loss)/base_loss)

tensor(7.3221e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)
tensor(7.0486e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)
tensor(4.8130, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(0.0374, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>)


In [None]:
# logits


# Full Circuit Analysis: Sign rep


In [19]:
# figure: hidden activations of the sign-representation neurons over all (a, b)

signature_neurons = rep_neurons['sign']
print(signature_neurons)
sig_labels = list(map(str, signature_neurons.tolist()))

# (a, b, neuron) cube of hidden activations
hidden = activations['hidden'].reshape(group.order, group.order, -1)
fig = px.imshow(
    to_numpy(hidden[:, :, signature_neurons]),
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0.0,
    title=f'hidden activations',
    facet_col=2,
    labels={'x':'b', 'y':'a', 'facet_col': 'neuron'},
)
# relabel each facet with the index of the neuron it shows
for i in range(len(sig_labels)):
    fig.layout.annotations[i]['text'] = f'neuron = {sig_labels[i]}' 
fig.show()
fig.write_image(f'{save_dir}/blocky_sign_neurons.png')


tensor([24, 66, 73, 81], device='cuda:0')


In [20]:
# evidence: rows of W_U for the sign neurons shown under the signature of each
# output group element

sigs = group.signatures.unsqueeze(-1)
W_U_signatures = model.W_U[signature_neurons, :].T
# first row: signatures; remaining rows: one per sign neuron
stack = torch.hstack([sigs, W_U_signatures]).T
imshow(
    stack,
    y=['sig', *sig_labels],
    input2='output group element',
    title='W_U on select neurons',
)

# Progress Measures

In [21]:
# figure: embedding-based excluded and restricted loss next to train and test loss
keys = ['total_embed_excluded_loss', 'total_embed_restricted_loss', 'test_loss', 'train_loss']
lines_from_keys(
    metrics,
    keys,
    title='Excluded Loss',
    labels=['Excluded Loss', 'Restricted Loss', 'Test Loss', 'Train Loss'],
    yaxis='Loss',
    save=f'{save_dir}/total_embed_excluded_and_restricted_loss.png',
    log_x=True,
    log_y = True,
)


Saving to paper/figures/total_embed_excluded_and_restricted_loss.png


In [22]:
# figure: logit-based excluded and restricted loss next to train and test loss
keys = ['total_logit_excluded_loss', 'total_logit_restricted_loss', 'test_loss', 'train_loss']
lines_from_keys(
    metrics,
    keys,
    title='Excluded Loss',
    labels=['Excluded Loss', 'Restricted Loss', 'Test Loss', 'Train Loss'],
    yaxis='Loss',
    save=f'{save_dir}/total_logit_excluded_and_restricted_loss.png',
    log_x=True,
    log_y = True,
)


Saving to paper/figures/total_logit_excluded_and_restricted_loss.png


In [23]:
# figure: logit excluded loss, one line per non-trivial representation
template = 'logit_excluded_loss_{}_rep'
lines_from_template(
    metrics,
    template,
    reps_to_plot,
    title='Excluded Loss by Representation',
    yaxis='Loss',
    save=f'{save_dir}/logit_excluded_loss_by_rep.png',
    log_y = True,
    log_x = True,
)

Saving to paper/figures/logit_excluded_loss_by_rep.png


In [24]:
# figure: logit restricted loss, one line per non-trivial representation
template = 'logit_restricted_loss_{}_rep'
lines_from_template(
    metrics,
    template,
    reps_to_plot,
    title='Restricted Loss by Representation',
    yaxis='Loss',
    save=f'{save_dir}/logit_restricted_loss_by_rep.png',
    log_y = True,
    log_x = True,
)

Saving to paper/figures/logit_restricted_loss_by_rep.png


In [25]:
# figure: sum of squared weights over training, on log-log axes
keys = ['sum_of_squared_weights']
lines_from_keys(
    metrics,
    keys,
    title='Sum of Square Weights',
    labels=['Sum of Square Weights'],
    yaxis='Sum of Square Weights',
    save=f'{save_dir}/sum_of_square_weights.png',
    log_x=True,
    log_y=True,
)

Saving to paper/figures/sum_of_square_weights.png
