# Set Up


In [485]:
!nvidia-smi 
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
if "cd" not in globals():
    os.chdir("../")
    cd = True
print(os.getcwd())
save_dir = 'paper/figures'

Sat Jan 21 17:44:44 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Tesla V1...  On   | 00000000:00:05.0 Off |                    0 |
| N/A   39C    P0    56W / 300W |   4179MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [486]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
import wandb
import statsmodels.api as sm
import json
import pickle 
import copy

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import plotly.subplots

pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *
from utils.config import *
from utils.figures import *

In [487]:
if torch.cuda.is_available:
    print('Good to go!')
else:
    print('Things might be rather slow')

Good to go!


In [488]:
import copy
pio.templates['grokking'] = copy.deepcopy(pio.templates['plotly'])
pio.templates.default = 'grokking'

pio.templates['grokking']['layout']['font']['family'] = 'Computer Modern'
pio.templates['grokking']['layout']['title'].update(dict(
    yref='paper',
    yanchor='bottom',
    y=1.,
    pad_b=10,
    xanchor='center',
    x=0.5,
    font_size=30,
))

pio.templates['grokking']['layout']['legend'].update(
    font_size = 20,
)

pio.templates['grokking']['layout']['margin'].update(
    l=80,
    r=0,
    t=0,
    b=60,
    pad=0,
)

axis_dict = dict(
    title_font_size=28,
    tickfont_size=20,
    title_standoff=1.,
)
coloraxis_dict = dict(colorbar_x=1.01, 
                colorbar_xanchor="left", 
                colorbar_xpad=0)
pio.templates['grokking']['layout']['xaxis'].update(axis_dict)
pio.templates['grokking']['layout']['yaxis'].update(axis_dict)
pio.templates['grokking']['layout']['coloraxis'].update(coloraxis_dict)

# make all figures a standard size
pio.templates['grokking']['layout']['width'] = 800
pio.templates['grokking']['layout']['height'] = 400

In [489]:
# produce a grokking plot of all the MLP S5 models
dirs = ['batch_experiments/S5_MLP_seed1', 'batch_experiments/S5_MLP_seed2', 'batch_experiments/S5_MLP_seed3', 'batch_experiments/S5_MLP_seed4']
data = []
for dir in dirs:
    # create a dict for this model
    model_dict = {}
    # grab metrics.csv
    metrics = pd.read_csv(f'{dir}/metrics.csv')
    # grab epochs, train_loss, test_loss
    epochs = metrics['epoch']
    train_loss = metrics['train_loss']
    test_loss = metrics['test_loss']
    train_acc = metrics['train_acc']
    test_acc = metrics['test_acc']
    # add to dict
    model_dict['epochs'] = epochs
    model_dict['train_loss'] = train_loss
    model_dict['test_loss'] = test_loss
    model_dict['train_acc'] = train_acc
    model_dict['test_acc'] = test_acc
    data.append(model_dict)

avg_data = {}
for key in data[0].keys():
    # take the median
    avg_data[key] = np.median([model[key] for model in data], axis=0)

fig = go.Figure()
fig.add_trace(go.Scatter(x=avg_data['epochs'], y=avg_data['train_loss'], name='Median Train Loss'))
fig.add_trace(go.Scatter(x=avg_data['epochs'], y=avg_data['test_loss'], name='Median Test Loss'))
# add faded traces, same thickness, but with lower opacity, blue for train, red for test
for model in data:
    fig.add_trace(go.Scatter(x=model['epochs'], y=model['train_loss'], name='train loss', line=dict(color='blue'), showlegend=False, opacity=0.2))
    fig.add_trace(go.Scatter(x=model['epochs'], y=model['test_loss'], name='test loss', line=dict(color='red'), showlegend=False, opacity=0.2))
fig.update_layout(
    xaxis_title='epoch',
    yaxis_title='loss',
)
fig.update_layout(
    legend= dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        )
)
# log y
fig.update_yaxes(type="log")
fig.show()
fig.write_image(f'{save_dir}/S5_MLP_grokking_loss.pdf')

# do the same for accuracy
fig = go.Figure()
fig.add_trace(go.Scatter(x=avg_data['epochs'], y=avg_data['train_acc'], name='Median Train Accuracy'))
fig.add_trace(go.Scatter(x=avg_data['epochs'], y=avg_data['test_acc'], name='Median Test Accuracy'))
# add faded traces, same thickness, but with lower opacity, blue for train, red for test
for model in data:
    fig.add_trace(go.Scatter(x=model['epochs'], y=model['train_acc'], name='train accuracy', line=dict(color='blue'), showlegend=False, opacity=0.2))
    fig.add_trace(go.Scatter(x=model['epochs'], y=model['test_acc'], name='test accuracy', line=dict(color='red'), showlegend=False, opacity=0.2))
fig.update_layout(
    xaxis_title='epoch',
    yaxis_title='accuracy',
)
fig.update_layout(
    legend = dict(
            yanchor="bottom",
            y=0.01,
            xanchor="right",
            x=0.99
))
fig.show()
fig.write_image(f'{save_dir}/S5_MLP_grokking_acc.pdf')








In [490]:
task_dir = "paper/mainline-S5"

seed, frac_train, layers, lr, group_param, weight_decay, betas, num_epochs, group_type, architecture_type = load_cfg(task_dir)
group = group_type(group_param, init_all = True)
model = architecture_type(layers, group.order, seed).cuda()
model.load_state_dict(torch.load(f"{task_dir}/model.pt"), strict=False)
model.eval()
all_data, _, all_labels, _, _= generate_train_test_data(group, frac_train = 1)
logits, activations = model.run_with_cache(group.all_data, return_cache_object=False)
activations['logits'] = logits
#metric_obj = Metrics(group, training=False, track_metrics = True)
key_reps= []
with open(os.path.join(task_dir, 'key_reps_in_order.txt'), 'r') as f:
    for line in f:
        key_reps.append(line.strip())
metrics_path = os.path.join(task_dir, 'metrics.csv')
summary_metrics_path = os.path.join(task_dir, 'summary_metrics.json')

# load the metrics
metrics = pd.read_csv(metrics_path)
summary_metrics = json.load(open(summary_metrics_path, 'r'))

reps_to_plot = list(group.non_trivial_irreps.keys())

Computing multiplication table...
... loading from file
Computing trace tensor cube for sign representation
... loading from file
Computing trace tensor cube for standard representation
... loading from file
Computing trace tensor cube for standard_sign representation
... loading from file
Computing trace tensor cube for s5_5d_a representation
... loading from file
Computing trace tensor cube for s5_5d_b representation
... loading from file
Computing trace tensor cube for s5_6d representation
... loading from file


# Logit Attribution


In [525]:
# figure: a slice of true vs hypothetical logit cubes
logits = activations['logits']
logits = logits.reshape(group.order, group.order, group.order)

true_logits = logits[:, :, 0]/logits[:, :, 0].abs().max()
key_rep_logits = []
for key_rep in key_reps:
    key_rep_logits.append(group.irreps[key_rep].logit_trace_tensor_cube[:, :, 0]/group.irreps[key_rep].logit_trace_tensor_cube[:, :, 0].abs().max())
key_rep_logits = torch.stack(key_rep_logits, dim=0)

stack = torch.vstack([true_logits.unsqueeze(0), key_rep_logits])
print(stack.shape)
fig = px.imshow(to_numpy(stack), color_continuous_scale='RdBu', color_continuous_midpoint=0.0, facet_col=0, labels={'x':'b', 'y':'a', 'facet_col': 'label'})
fig.layout.annotations[0]['text'] = 'observed logit' 
for i in range(len(key_reps)):
    fig.layout.annotations[i+1]['text'] = f"{key_reps[i]}"
for i in range(len(key_reps)+1):
    fig.layout.annotations[i]['font']['family'] = 'Computer Modern'
    fig.layout.annotations[i]['font']['size'] = 20
    fig.layout.annotations[i]['yshift'] = -20
fig.update_layout(
    height = 330
)
fig.show()
fig.write_image(f"{save_dir}/logit_cubes.pdf")




torch.Size([3, 120, 120])


In [492]:
# figure: evolution of cosine similarity
template = "logit_{}_rep_trace_similarity"
lines_from_template(metrics, template, reps_to_plot, yaxis="cosine similarity", save=f"{save_dir}/logit_similarity.pdf", log_x=True, legend_pos='tl')


Saving to paper/figures/logit_similarity.pdf


In [493]:
# evidence: end of training logit cosine similarity

template = "logit_{}_rep_trace_similarity"
print("cosine similarity of true logits and hypothesised logits at end of training")
for irrep in group.non_trivial_irreps:
    print(f"{irrep}: {summary_metrics[template.format(irrep)]}")


cosine similarity of true logits and hypothesised logits at end of training
sign: 0.5090430974960327
standard: 0.7674336433410645
standard_sign: -1.102307578548789e-09
s5_5d_a: -1.0477378964424133e-09
s5_5d_b: 0.001785235945135355
s5_6d: 5.187757778912783e-09


In [494]:
# percentage explained
print(summary_metrics['percent_logits_explained'])

0.848082423210144


# Embeddings and Unembeddings

In [495]:
# figure: percent a, b, c embed by representation over course of training

template = "percent_x_embed_{}_rep"
lines_from_template(metrics, template, reps_to_plot, yaxis="fraction of variance", save=f"{save_dir}/percent_x_embed.pdf", log_x=True, legend_pos='tl')
template = "percent_y_embed_{}_rep"
lines_from_template(metrics, template, reps_to_plot, yaxis="fraction of variance", save=f"{save_dir}/percent_y_embed.pdf", log_x=True, legend_pos='tl')
template = "percent_unembed_{}_rep"
lines_from_template(metrics, template, reps_to_plot, yaxis="fraction of variance", save=f"{save_dir}/percent_unembed.pdf", log_x=True, legend_pos='tl')

Saving to paper/figures/percent_x_embed.pdf


Saving to paper/figures/percent_y_embed.pdf


Saving to paper/figures/percent_unembed.pdf


In [496]:
# table: percents explained
print("percent of variance explained by representation at end of training")
print("x, y, unembed")
total_x = 0
total_y = 0
total_unembed = 0
for rep in key_reps:
    print(f"{rep}: {100*summary_metrics['percent_x_embed_{}_rep'.format(rep)]}, {100*summary_metrics['percent_y_embed_{}_rep'.format(rep)]}, {100*summary_metrics['percent_unembed_{}_rep'.format(rep)]}")
    total_x += 100*summary_metrics['percent_x_embed_{}_rep'.format(rep)]
    total_y += 100*summary_metrics['percent_y_embed_{}_rep'.format(rep)]
    total_unembed += 100*summary_metrics['percent_unembed_{}_rep'.format(rep)]

# residuals
print(f"residuals: {100-total_x}, {100-total_y}, {100-total_unembed}")

percent of variance explained by representation at end of training
x, y, unembed
sign: 6.951265037059784, 6.951045244932175, 9.584786742925644
standard: 93.04873943328857, 93.04895401000977, 84.46281552314758
residuals: -4.470348358154297e-06, 7.450580596923828e-07, 5.952397733926773


# Hidden Layer Neurons

In [550]:
# figure: evolution of \rho(a), \rho(b), \rho(ab)
keys = ['percent_hidden_explained']
lines_from_keys(metrics, keys, yaxis="fraction of variance", labels=['Hidden FVE'], save=f"{save_dir}/percent_hidden.pdf", log_x=True, legend_pos='tl')

template = "percent_hidden_xy_{}_rep"
lines_from_template(metrics, template, reps_to_plot, yaxis="fraction of variance", save=f"{save_dir}/percent_hidden_xy.pdf", log_x=True, legend_pos='tl')

template = "percent_hidden_x_{}_rep"
lines_from_template(metrics, template, reps_to_plot, yaxis="fraction of variance", save=f"{save_dir}/percent_hidden_x.pdf", log_x=True, legend_pos='tl')

template = "percent_hidden_y_{}_rep"
lines_from_template(metrics, template, reps_to_plot, yaxis="fraction of variance", save=f"{save_dir}/percent_hidden_y.pdf", log_x=True, legend_pos='tl')

Saving to paper/figures/percent_hidden.pdf


Saving to paper/figures/percent_hidden_xy.pdf


Saving to paper/figures/percent_hidden_x.pdf


Saving to paper/figures/percent_hidden_y.pdf


In [498]:
# evidence: neuron clustering pre ReLU

threshold = 1

x_embed = model.x_embed
y_embed = model.y_embed

x_embed_summed = x_embed.pow(2).sum(dim=0)
off_neurons_x = (x_embed_summed < threshold).nonzero().squeeze()

y_embed_summed = y_embed.pow(2).sum(dim=0)
off_neurons_y = (y_embed_summed < threshold).nonzero().squeeze()

assert (off_neurons_x == off_neurons_y).all()

off_neurons = off_neurons_x

print(f'Off neurons: {len(off_neurons)}, {off_neurons}')

rep_neurons = {}

print('Neurons corresponding to each representation')
for rep_name in group.non_trivial_irreps:
    rep = group.irreps[rep_name].orth_rep
    coefs_x = rep.T @ x_embed
    coefs_y = rep.T @ y_embed
    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)

    x_neurons = (coefs_x_summed > threshold).nonzero().squeeze()
    y_neurons = (coefs_y_summed > threshold).nonzero().squeeze()
    assert (x_neurons == y_neurons).all()
    x_neurons = torch.tensor(x_neurons)
    if x_neurons.dim() == 0:
        x_neurons = x_neurons.unsqueeze(0)
    rep_neurons[rep_name] = x_neurons
    print(f'{rep_name}: {len(x_neurons)}, {x_neurons}')

print(rep_neurons)

all_neurons = torch.arange(model.W_U.shape[0])
unaccounted_neurons = set(all_neurons.tolist())
unaccounted_neurons -= set(off_neurons.tolist())
for rep_name, neurons in rep_neurons.items():
    unaccounted_neurons -= set(neurons.tolist())

print('Unaccounted neurons')
print(unaccounted_neurons)

Off neurons: 2, tensor([ 48, 110], device='cuda:0')
Neurons corresponding to each representation
sign: 7, tensor([  2,   8,  17,  65, 111, 113, 120], device='cuda:0')
standard: 119, tensor([  0,   1,   3,   4,   5,   6,   7,   9,  10,  11,  12,  13,  14,  15,
         16,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,
         31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,
         45,  46,  47,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
         60,  61,  62,  63,  64,  66,  67,  68,  69,  70,  71,  72,  73,  74,
         75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,
         89,  90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102,
        103, 104, 105, 106, 107, 108, 109, 112, 114, 115, 116, 117, 118, 119,
        121, 122, 123, 124, 125, 126, 127], device='cuda:0')
standard_sign: 0, tensor([], device='cuda:0', dtype=torch.int64)
s5_5d_a: 0, tensor([], device='cuda:0', dtype=torch.int64)
s5_5d_b: 


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [499]:
# evidence and table: neuron clustering in post hidden layer

threshold = 110

hidden = activations['hidden'].reshape(group.order**2, -1)
hidden = hidden - hidden.mean(dim=0, keepdim=True)

hidden_summed = hidden.pow(2).sum(dim=0)
off_neurons = (hidden_summed < threshold).nonzero().squeeze()

assert (off_neurons == off_neurons_x).all()

print(f'Off neurons: {off_neurons}')


fracs_explained_x = {}
fracs_explained_y = {}
fracs_explained_xy = {}
fracs_explained_trivial = {}

for rep_name in group.irreps.keys():
    group.irreps[rep_name].hidden_reps_x = group.irreps[rep_name].rep[all_data[:, 0]].reshape(group.order**2, -1)
    group.irreps[rep_name].hidden_reps_x_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_x)[0]
    group.irreps[rep_name].hidden_reps_y = group.irreps[rep_name].rep[all_data[:, 1]].reshape(group.order**2, -1)
    group.irreps[rep_name].hidden_reps_y_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_y)[0]
    group.irreps[rep_name].hidden_reps_xy = group.irreps[rep_name].rep[all_labels].reshape(group.order*group.order, -1)
    group.irreps[rep_name].hidden_reps_xy_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_xy)[0]

for rep_name in key_reps:
    rep_x = group.irreps[rep_name].hidden_reps_x_orth
    rep_y = group.irreps[rep_name].hidden_reps_y_orth
    rep_xy = group.irreps[rep_name].hidden_reps_xy_orth

    trivial = group.irreps['trivial'].hidden_reps_x_orth

    coefs_x = rep_x.T @ hidden
    coefs_y = rep_y.T @ hidden
    coefs_xy = rep_xy.T @ hidden

    coefs_trivial = trivial.T @ hidden

    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)
    coefs_xy_summed = coefs_xy.pow(2).sum(dim=0)
    coefs_trivial_summed = coefs_trivial.pow(2).sum(dim=0)


    neurons = rep_neurons[rep_name]

    frac_x = (coefs_x_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_y = (coefs_y_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_xy = (coefs_xy_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_trivial = (coefs_trivial_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())

    fracs_explained_x[rep_name] = frac_x
    fracs_explained_y[rep_name] = frac_y
    fracs_explained_xy[rep_name] = frac_xy
    fracs_explained_trivial[rep_name] = frac_trivial

print('Neurons corresponding to each representation')
for key in key_reps:
    print(f'frac variance explained in {key} x, y, xy: {fracs_explained_x[key], fracs_explained_y[key], fracs_explained_xy[key], fracs_explained_trivial[key]}')
    print(f'Sum of explained variance: {fracs_explained_x[key] + fracs_explained_y[key] + fracs_explained_xy[key] + fracs_explained_trivial[key]}')

Off neurons: tensor([ 48, 110], device='cuda:0')
Neurons corresponding to each representation
frac variance explained in sign x, y, xy: (tensor(0.3333, device='cuda:0'), tensor(0.3333, device='cuda:0'), tensor(0.3333, device='cuda:0'), tensor(2.0314e-15, device='cuda:0'))
Sum of explained variance: 0.9999998807907104
frac variance explained in standard x, y, xy: (tensor(0.3959, device='cuda:0'), tensor(0.3708, device='cuda:0'), tensor(0.1131, device='cuda:0'), tensor(2.6660e-15, device='cuda:0'))
Sum of explained variance: 0.8797791004180908


In [500]:
# evidence: only \rho(ab) is important
hidden = activations['hidden'].reshape(group.order**2, -1)
loss = loss_fn(logits.reshape(group.order**2, -1), all_labels).item()
print(f'baseline loss: {loss}')

for rep_name in group.non_trivial_irreps:

    hidden_rep_x = group.irreps[rep_name].hidden_reps_x_orth
    hidden_rep_y = group.irreps[rep_name].hidden_reps_y_orth
    hidden_rep_xy = group.irreps[rep_name].hidden_reps_xy_orth

    coefs_x = hidden_rep_x.T @ hidden
    coefs_y = hidden_rep_y.T @ hidden
    coefs_xy = hidden_rep_xy.T @ hidden
    coefs_trivial = trivial.T @ hidden

    hidden_x = hidden_rep_x @ coefs_x
    hidden_y = hidden_rep_y @ coefs_y
    hidden_xy = hidden_rep_xy @ coefs_xy
    hidden_trivial = trivial @ coefs_trivial

    hidden_ablated_x = hidden - hidden_x
    hidden_ablated_y = hidden - hidden_y
    hidden_ablated_xy = hidden - hidden_xy
    hidden_ablated_trivial = hidden - hidden_trivial

    logits_x = hidden_ablated_x @ model.W_U
    logits_y = hidden_ablated_y @ model.W_U
    logits_xy = hidden_ablated_xy @ model.W_U
    logits_trivial = hidden_ablated_trivial @ model.W_U

    loss_x = loss_fn(logits_x, all_labels).item()
    loss_y = loss_fn(logits_y, all_labels).item()
    loss_xy = loss_fn(logits_xy, all_labels).item()
    loss_trivial = loss_fn(logits_trivial, all_labels).item()

    print(f'Ablating directions corresponding to {rep_name} rep loss, xy: {loss_xy}, x: {loss_x}, y: {loss_y}, trivial: {loss_trivial}')
    # percentage change in loss
    print(f'Relative change in loss: {(loss_xy - loss)/loss, (loss_x - loss)/loss, (loss_y - loss)/loss, (loss_trivial - loss)/loss}')

baseline loss: 2.3838242884590774e-06
Ablating directions corresponding to sign rep loss, xy: 0.0009095180870913975, x: 2.372950527269995e-06, y: 2.3709135243412416e-06, trivial: 2.419312382417704e-06
Relative change in loss: (380.5373857438616, -0.004561477639826964, -0.00541598815833085, 0.014887042694562933)
Ablating directions corresponding to standard rep loss, xy: 7.5507839045127305, x: 2.209098210702922e-06, y: 2.2082870514831203e-06, trivial: 2.419312382417704e-06
Relative change in loss: (3167507.5873856987, -0.07329654228378542, -0.07363681871427938, 0.014887042694562933)
Ablating directions corresponding to standard_sign rep loss, xy: 2.383824176064554e-06, x: 2.3838226086524663e-06, y: 2.3838241945102326e-06, trivial: 2.419312382417704e-06
Relative change in loss: (-4.714882884186531e-08, -7.04668804341278e-07, -3.941097726865407e-08, 0.014887042694562933)
Ablating directions corresponding to s5_5d_a rep loss, xy: 2.383824804402695e-06, x: 2.383821273166266e-06, y: 2.383814

In [542]:
# evidence: explicit extraction of \rho(ab). 

def projection_matrix_general(B):
    """Compute the projection matrix onto the space spanned by the columns of `B`
    Args:
        B: ndarray of dimension (D, M), the basis for the subspace
    
    Returns:
        P: the projection matrix
    """
    P = B @ (B.T @ B).inverse() @ B.T
    return P

hidden = activations['hidden'].reshape(group.order*group.order, -1)
hidden_to_reps_proj = {}
coefs = {}

for rep_name in key_reps:
    hidden_reps_xy = group.irreps[rep_name].hidden_reps_xy



    P = projection_matrix_general(hidden_reps_xy)
    hidden_xy = P @ hidden

    hidden_to_reps_proj[rep_name] = hidden_reps_xy.T @ hidden_xy
    #hidden_to_reps_proj[rep_name] = hidden_to_reps_proj[rep_name] / hidden_to_reps_proj[rep_name].norm(dim=1, keepdim=True)
    if rep_name == 'standard':
        plot = hidden_to_reps_proj[rep_name].detach().cpu().numpy()
        fig = px.imshow(plot, color_continuous_scale='RdBu', color_continuous_midpoint=0.0, labels={'x':'neuron basis', 'y':'rep basis'})
        fig.update_layout(
            height = 180
        )
        fig.show()
        fig.write_image(f'{save_dir}/hidden_to_{rep_name}_rep_change_of_basis.pdf')

    hidden_in_rep = hidden_xy @ hidden_to_reps_proj[rep_name].T

    theoretical_reps = hidden_reps_xy.reshape(group.order*group.order, -1)
    #imshow(hidden_in_rep[:10], title=f'Projected hidden layer in the {rep_name} representation basis', input2='representation basis', input1='input index')
    #imshow(theoretical_reps[:10], title=f'rho(ab) in {rep_name}', input2='representation basis', input1='input index')
    

    hidden_in_rep_norm = hidden_in_rep.flatten() / hidden_in_rep.flatten().norm()
    theoretical_reps_norm = theoretical_reps.flatten() / theoretical_reps.flatten().norm()

    # MSE loss between hidden_in_rep and theoretical_reps
    sim = F.mse_loss(hidden_in_rep_norm, theoretical_reps_norm)
    #sim = F.cosine_similarity(hidden_in_rep.flatten(), theoretical_reps.flatten(), dim=0)
    print(f'MSE Loss between hidden layer and theoretical representations: {sim}')

    # get the coef
    coef = (hidden_in_rep.norm() / theoretical_reps.norm())
    coefs[rep_name] = coef

MSE Loss between hidden layer and theoretical representations: 0.0


MSE Loss between hidden layer and theoretical representations: 4.4785686270643055e-09


# Logit Computation

In [502]:
# evidence: 
rep_name = 'standard'
W_U = model.W_U
rep = group.irreps[rep_name].rep.reshape(group.order, -1)
#rep = rep / rep.norm(dim=0, keepdim=True)
W_U_rep = hidden_to_reps_proj[rep_name] @ W_U @ rep [group.inverses]
print(W_U_rep.shape)
imshow(W_U_rep, title='', input2='input rep basis', input1='output rep basis', save=f'{save_dir}/unembedding_matrix_in_rep_basis.pdf')

real_linear_map = (W_U_rep > 1e5).float()
real_linear_map_norm = real_linear_map.flatten() / real_linear_map.flatten().norm()
W_U_rep_norm = W_U_rep.flatten() / W_U_rep.flatten().norm()
sim = F.mse_loss(real_linear_map_norm, W_U_rep_norm)
#sim = F.cosine_similarity(W_U_rep.flatten(), real_linear_map.flatten(), dim=0)
print(f'Cosine similarity between unembedding matrix and real linear map: {sim}')

torch.Size([16, 16])


Saving to paper/figures/unembedding_matrix_in_rep_basis.pdf
Cosine similarity between unembedding matrix and real linear map: 1.1119418559246697e-05


In [503]:
# do this orthogonally

def projection_matrix_general(B):
    """Compute the projection matrix onto the space spanned by the columns of `B`
    Args:
        B: ndarray of dimension (D, M), the basis for the subspace
    
    Returns:
        P: the projection matrix
    """
    P = B @ (B.T @ B).inverse() @ B.T
    return P

hidden = activations['hidden'].reshape(group.order*group.order, -1)
hidden_to_reps_proj_orth = {}
coefs = {}

for rep_name in key_reps:
    hidden_reps_xy = group.irreps[rep_name].hidden_reps_xy_orth


    P = projection_matrix_general(hidden_reps_xy)
    hidden_xy = P @ hidden

    hidden_to_reps_proj_orth[rep_name] = hidden_reps_xy.T @ hidden_xy
    hidden_to_reps_proj_orth[rep_name] = hidden_to_reps_proj_orth[rep_name] / hidden_to_reps_proj_orth[rep_name].norm(dim=1, keepdim=True)

    #imshow(hidden_to_reps_proj_orth[rep_name], title=f'Change of basis from neuron basis to rho(ab) {rep_name} representation basis', input2='neuron basis', input1='representation basis', save=f'{save_dir}/hidden_to_{rep_name}_orth_rep_change_of_basis.pdf')

    hidden_in_rep = hidden_xy @ hidden_to_reps_proj_orth[rep_name].T

    theoretical_reps = hidden_reps_xy.reshape(group.order*group.order, -1)
    #imshow(hidden_in_rep[:10], title=f'Projected hidden layer in the {rep_name} representation basis', input2='representation basis', input1='input index')
    #imshow(theoretical_reps[:10], title=f'rho(ab) in {rep_name}', input2='representation basis', input1='input index')

    sim = F.cosine_similarity(hidden_in_rep.flatten(), theoretical_reps.flatten(), dim=0)
    print(f'Cosine similarity between hidden layer and theoretical representations: {sim}')

    # get the coef
    coef = (hidden_in_rep.norm() / theoretical_reps.norm())
    coefs[rep_name] = coef

Cosine similarity between hidden layer and theoretical representations: 1.0
Cosine similarity between hidden layer and theoretical representations: 0.9996493458747864


In [504]:
# percentage of W_U explained
W_U = model.W_U
for key, value in hidden_to_reps_proj_orth.items():
    rep_orth = group.irreps[key].orth_rep.reshape(group.order, -1)
    W_U_rep_neuron_basis = W_U[rep_neurons[key]]
    W_U_rep = W_U_rep_neuron_basis @ rep_orth[group.inverses]
    print(W_U_rep_neuron_basis.shape)
    print(f'Percentage of W_u explained in {key} representation: {(torch.norm(W_U_rep.flatten()) / torch.norm(W_U_rep_neuron_basis)).pow(2)}')


torch.Size([7, 120])
Percentage of W_u explained in sign representation: 0.9993630647659302
torch.Size([119, 120])
Percentage of W_u explained in standard representation: 0.9340784549713135


# Ablations

In [505]:
# MLP neurons
hidden = activations['hidden'].reshape(group.order**2, -1)

hidden_constructed = torch.zeros_like(hidden)
for key, value in hidden_to_reps_proj_orth.items():
    CoB_orth = value
    # normalize this CoB matrix
    #CoB_orth = CoB_orth / torch.norm(CoB_orth)
    rep_orth = group.irreps[key].orth_rep
    hidden_reps_orth = group.irreps[key].hidden_reps_xy_orth
    hidden_reps_orth 
    hidden_rep = hidden_reps_orth @ CoB_orth
    x = CoB_orth @ CoB_orth.T
    hidden_rep = hidden_rep.reshape(group.order**2, -1)
    hidden_constructed += coefs[key] * hidden_rep
    

logits_constructed = hidden_constructed @ model.W_U
loss_constructed = loss_fn(logits_constructed, all_labels)
# imshow(logits_constructed[:10])
# imshow((hidden @ model.W_U)[:10])
loss_base = loss_fn(hidden @ model.W_U, all_labels)
# percentage change
print(f'Percentage change in loss: {(loss_constructed - loss_base) / loss_base}')
print(loss_constructed)


Percentage change in loss: -0.7061693842141639
tensor(7.0044e-07, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)


In [506]:
# W_U

W_U = model.W_U
hidden = activations['hidden'].reshape(group.order**2, -1)
# restrict W_U to only output representation space
W_U_cont = torch.zeros_like(W_U)
for rep in key_reps:
    rep_orth = group.irreps[rep].orth_rep
    W_U_rep = W_U @ rep_orth[group.inverses]
    W_U_cont += W_U_rep @ rep_orth[group.inverses].T

W_U_null = W_U - W_U_cont
base_logits = model(all_data)
base_loss = loss_fn(base_logits, all_labels)
new_logits = hidden @ W_U_cont
null_logits = hidden @ W_U_null
new_loss = loss_fn(new_logits, all_labels)
null_loss = loss_fn(null_logits, all_labels)
print(base_loss)
print(new_loss)
print(null_loss)
#percent change
print((base_loss - new_loss)/base_loss)

tensor(2.3838e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)
tensor(2.0779e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)
tensor(4.7955, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(0.1283, device='cuda:0', dtype=torch.float64, grad_fn=<DivBackward0>)


In [507]:
# logits
template = 'logit_excluded_loss_{}_rep'
for rep in key_reps:
    print(f'Excluding {rep}: {summary_metrics[template.format(rep)]}')

print(f'Excluding all: {summary_metrics["total_logit_excluded_loss"]}')

Excluding sign: 0.0006027620128980678
Excluding standard: 7.279840630335901
Excluding all: 7.601493474399844


In [508]:
# logits: ablating other directions improves performance...

# Full Circuit Analysis: Sign rep


In [545]:
# figure: blocky neurons

signature_neurons = rep_neurons['sign']
print(signature_neurons)
sig_labels = [str(x) for x in signature_neurons.tolist()]

hidden = activations['hidden'].reshape(group.order, group.order, -1)
fig = px.imshow(to_numpy(hidden[:, :, signature_neurons]), color_continuous_scale='RdBu', color_continuous_midpoint=0.0, facet_col=2, labels={'x':'b', 'y':'a', 'facet_col': 'neuron'})
for i, neuron in enumerate(sig_labels):
    fig.layout.annotations[i]['text'] = f'neuron = {neuron}'
    fig.layout.annotations[i]['yshift'] = -25
fig.update_xaxes(showticklabels=False, title_standoff=10)
fig.update_yaxes(showticklabels=False)
fig.update_layout(
    height = 200
)
fig.update_layout(
    margin=dict(l=40, r=50, t=0, b=0),
)
fig.show()
fig.write_image(f'{save_dir}/blocky_sign_neurons.pdf')


tensor([  2,   8,  17,  65, 111, 113, 120], device='cuda:0')


In [510]:
# identify form of neurons
sigs = group.signatures.unsqueeze(-1)
xs = model.x_embed[:, signature_neurons]
stack = torch.hstack([sigs, xs]).T
#imshow(stack, y=['sig'] + sig_labels, input2='input group element', title='Total x embeddings on select neurons')
ys = y_embed[:, signature_neurons]
stack = torch.hstack([sigs, ys]).T
#imshow(stack, y=['sig'] + sig_labels, input2='input group element', title='Total y embeddings on select neurons')

In [511]:
# evidence: form of W_U on sign neurons

sigs = group.signatures.unsqueeze(-1)
W_U_signatures = model.W_U[signature_neurons, :].T
stack = torch.hstack([sigs, W_U_signatures]).T
#imshow(stack, y=['sig'] + sig_labels, input2='output group element', title='W_U on select neurons')

# Progress Measures

In [512]:
vlines = [2200, 80000, 100000]
trim = 130000

In [513]:
# figure: total excluded loss
keys = ['total_hidden_excluded_loss', 'total_hidden_restricted_loss', 'test_loss', 'train_loss']
lines_from_keys(metrics, keys, labels=['Excluded Loss', 'Restricted Loss', 'Test Loss', 'Train Loss'], yaxis='Loss', save=f'{save_dir}/total_logit_excluded_and_restricted_loss.pdf', log_y = True, log_x=False, legend_pos='tr', trim=trim, vlines=vlines)


Saving to paper/figures/total_logit_excluded_and_restricted_loss.pdf


In [514]:
# figure: excluded loss by rep
template = 'hidden_excluded_loss_{}_rep'
lines_from_template(metrics, template, reps_to_plot, yaxis='Loss', save=f'{save_dir}/logit_excluded_loss_by_rep.pdf', log_y = True, log_x=True, legend_pos='bl', trim=trim, vlines=vlines)

Saving to paper/figures/logit_excluded_loss_by_rep.pdf


In [515]:
template = 'hidden_restricted_loss_{}_rep'
lines_from_template(metrics, template, reps_to_plot, yaxis='Loss', save=f'{save_dir}/logit_restricted_loss_by_rep.pdf', log_x=True, log_y = True, legend_pos='bl', trim=trim, vlines=vlines)

Saving to paper/figures/logit_restricted_loss_by_rep.pdf


In [516]:
# figure: sum of square weights
keys = ['sum_of_squared_weights']
lines_from_keys(metrics, keys, labels=['Sum of Square Weights'], save=f'{save_dir}/sum_of_square_weights.pdf', yaxis='sum of squared weights', log_y=True, log_x=True, trim=trim, vlines=vlines)

Saving to paper/figures/sum_of_square_weights.pdf


In [517]:
keys = ['test_loss_restricted_loss_ratio']
lines_from_keys(metrics, keys, labels=['Test Loss to Restricted Loss Ratio'], yaxis='ratio', save=f'{save_dir}/test_loss_restricted_loss_ratio.pdf', log_y=True, log_x=True, trim=trim, vlines=vlines)

Saving to paper/figures/test_loss_restricted_loss_ratio.pdf
