# Set Up


In [17]:
!nvidia-smi 
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
if "cd" not in globals():
    os.chdir("../")
    cd = True
print(os.getcwd())
save_dir = 'paper/figures'

Tue Jan  3 15:14:37 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Quadro P...  On   | 00000000:00:05.0  On |                  N/A |
| 46%   38C    P8     7W / 105W |   1001MiB /  8119MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
import wandb

import json
import pickle 
import copy

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import plotly.subplots

pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *
from utils.config import *

In [19]:
if torch.cuda.is_available:
    print('Good to go!')
else:
    print('Things might be rather slow')

Good to go!


In [20]:
task_dir = "paper/mainline-S5" 
seed, frac_train, layers, lr, group_param, weight_decay, num_epochs, group_type, architecture_type, metric_cfg, metric_obj = load_cfg(task_dir)
group = group_type(group_param, init_all = True)
model = architecture_type(layers, group.order, seed).cuda()
model.load_state_dict(torch.load(f"{task_dir}/model.pt"), strict=False)
model.eval()
all_data, _, all_labels, _ = generate_train_test_data(group, frac_train = 1)
logits, activations = model.run_with_cache(group.all_data, return_cache_object=False)
activations['logits'] = logits
metric_obj = metric_obj(group, training=False, track_metrics = True)

key_reps = metric_cfg['key_reps']

Computing multiplication table...
... loading from file
Computing trace tensor cube for trivial representation
... loading from file
Computing trace tensor cube for sign representation
... loading from file
Computing trace tensor cube for standard representation
... loading from file
Computing trace tensor cube for standard_sign representation
... loading from file
Computing trace tensor cube for s5_5d_a representation
... loading from file
Computing trace tensor cube for s5_5d_b representation
... loading from file
Computing trace tensor cube for s5_6d representation
... loading from file


In [21]:
api = wandb.Api()
run = api.run("bilal-experiments/SymmetricGroupRepTheory/qjda3qwd") #mainline-S5-all-reps run

In [22]:
run.config

{'lr': 0.001,
 'seed': 1,
 'layers': {'embed_dim': 256, 'hidden_dim': 128},
 'frac_train': 0.5,
 'group_type': 'utils.groups.SymmetricGroup',
 'num_epochs': 60000,
 'group_param': 5,
 'weight_decay': 0.5,
 'architecture_type': 'utils.models.OneLayerMLP'}

In [23]:
run.summary

{'_wandb': {'runtime': 2305}, 'percent_x_embed_sign_rep': 0.07104285806417465, 'percent_std_hidden_s5_6d_rep': 1.6800820158241426e-15, 'percent_std_hidden_s5_5d_b_rep': 8.34931711324316e-07, 'percent_std_unembed_s5_5d_a_rep': 5.962136128800921e-06, 'percent_std_y_embed_s5_5d_a_rep': 5.839368485277552e-12, 'sum_of_squared_weights': 5232.31005859375, 'excluded_loss_trivial_rep': 8.072786386037478e-07, 'percent_y_embed_s5_6d_rep': 6.647793981962025e-10, 'excluded_loss_standard_rep': 4.101238250732422, 'percent_std_hidden_standard_rep': 4.9660935474094e-05, 'percent_std_hidden_standard_sign_rep': 2.183094677254971e-11, '_step': 59999, '_timestamp': 1672332879.7165425, 'percent_x_embed_s5_5d_b_rep': 3.8481540176604767e-10, 'excluded_loss_standard_sign_rep': 8.072868808994826e-07, 'restricted_loss': 8.064425287557242e-07, 'total_excluded_loss': 4.787490367889404, 'excluded_loss_s5_5d_a_rep': 8.070965122897178e-07, 'percent_std_hidden_sign_rep': 'NaN', 'percent_unembed_trivial_rep': 0.0119587

In [24]:
def get_history(keys):
    history = run.scan_history(keys=keys)
    out = []
    for key in keys:
        if key not in run.summary.keys():
            raise ValueError(f"Key {key} not a valid metric")
        out.append([])
    for row in history:
        for i, key in enumerate(keys):
            out[i].append(row[key])
    return out

def lines_from_keys(keys, title, yaxis, labels, save, **kwargs):
    data = get_history(keys)
    lines(data, title=title, xaxis="epoch", yaxis=yaxis, labels=labels, show=False, save=save, **kwargs)

def lines_from_template(template, title, yaxis, save, **kwargs):
    non_trivial_irreps_names = list(group.non_trivial_irreps.keys())
    keys = []
    for irrep in group.non_trivial_irreps:
        keys.append(template.format(irrep))
    lines_from_keys(keys, title, yaxis, non_trivial_irreps_names, save, **kwargs)  


# Logit Attribution


In [25]:
# figure: a slice of true vs hypothetical logit cubes
logits = activations['logits']
logits = logits.reshape(group.order, group.order, group.order)

true_logits = logits[:, :, 0]/logits[:, :, 0].max()
standard_logits = group.irreps['standard'].logit_trace_tensor_cube[:, :, 0] / group.irreps['standard'].logit_trace_tensor_cube[:, :, 0].max()
sign_logits = group.irreps['sign'].logit_trace_tensor_cube[:, :, 0] / group.irreps['sign'].logit_trace_tensor_cube[:, :, 0].max()

stack = torch.stack([true_logits, standard_logits, sign_logits], dim=0)
print(stack.shape)
fig = px.imshow(to_numpy(stack), color_continuous_scale='RdBu', color_continuous_midpoint=0.0, title=f'hidden activations', facet_col=0, labels={'x':'b', 'y':'a', 'facet_col': 'label'})
fig.layout.annotations[0]['text'] = 'true logit 0 over all inputs' 
fig.layout.annotations[1]['text'] = 'standard hypothesised logit 0'
fig.layout.annotations[2]['text'] = 'sign hypothesised logit 0'
fig.show()
fig.write_image(f"{save_dir}/logit_cubes.png", width=1000)


torch.Size([3, 120, 120])


In [26]:
# figure: evolution of cosine similarity
template = "logit_{}_rep_trace_similarity"
lines_from_template(template, title="cosine similarity of true logits and hypothesised logits", yaxis="cosine similarity", save=f"{save_dir}/logit_similarity.png")


Saving to paper/figures/logit_similarity.png


In [27]:
# evidence: end of training logit cosine similarity

template = "logit_{}_rep_trace_similarity"
print("cosine similarity of true logits and hypothesised logits at end of training")
for irrep in group.non_trivial_irreps:
    print(f"{irrep}: {run.summary[template.format(irrep)]}")


cosine similarity of true logits and hypothesised logits at end of training
sign: 0.4936670660972595
standard: 0.7574389576911926
standard_sign: 1.463159549075499e-07
s5_5d_a: -7.306867360057367e-07
s5_5d_b: 0.00017816123727243394
s5_6d: 3.2632060538162477e-06


In [28]:
# evidence: logits are weighted sum 

true_logits_flattened = activations['logits'].reshape(-1)
key_rep_logits_flattened = {}
for rep in key_reps:
    rep_logits = group.irreps[rep].logit_trace_tensor_cube.reshape(-1)
    key_rep_logits_flattened[rep] = rep_logits

# fit the true logits to the key rep logits using OLS
X = torch.stack(list(key_rep_logits_flattened.values()), dim=1)
y = true_logits_flattened
X = torch.cat([X, torch.ones(X.shape[0], 1).cuda()], dim=1)
beta = torch.inverse(X.T @ X) @ X.T @ y
print(beta)
print(key_reps)

OLS_logits = torch.zeros_like(true_logits_flattened)
for i, logits in enumerate(key_rep_logits_flattened.values()):
    OLS_logits += beta[i]*logits
OLS_logits += beta[-1]
OLS_logits_norm = torch.norm(OLS_logits)
true_logits_norm = torch.norm(true_logits_flattened)

percent_var_explained = OLS_logits_norm / true_logits_norm

print(f'Percent varaince explained {percent_var_explained}')

original_loss = loss_fn(activations['logits'].reshape(group.order**2, -1), all_labels)
OLS_loss = loss_fn(OLS_logits.reshape(group.order**2, -1), all_labels)

print(original_loss, OLS_loss)

tensor([6.4520, 4.2045, 5.8817], device='cuda:0', grad_fn=<MvBackward0>)
['standard', 'sign']
Percent varaince explained 0.9355218410491943
tensor(8.0730e-07, device='cuda:0', grad_fn=<NllLossBackward0>) tensor(0., device='cuda:0', grad_fn=<NllLossBackward0>)


In [29]:

print(F.cosine_similarity(key_rep_logits_flattened['sign'], true_logits_flattened-true_logits_flattened.mean(), dim=0))
print(F.cosine_similarity(key_rep_logits_flattened['sign'], true_logits_flattened, dim=0))


tensor(0.4931, device='cuda:0', grad_fn=<SumBackward1>)
tensor(0.4059, device='cuda:0', grad_fn=<SumBackward1>)


In [30]:
print(F.cosine_similarity(key_rep_logits_flattened['standard'], true_logits_flattened-true_logits_flattened.mean(), dim=0))
print(F.cosine_similarity(key_rep_logits_flattened['standard'], true_logits_flattened, dim=0))

tensor(0.7567, device='cuda:0', grad_fn=<SumBackward1>)
tensor(0.6229, device='cuda:0', grad_fn=<SumBackward1>)


In [31]:
print(F.cosine_similarity(torch.ones(120**3).cuda(), true_logits_flattened-true_logits_flattened.mean(), dim=0))
print(F.cosine_similarity(torch.ones(120**3).cuda(), true_logits_flattened, dim=0))

tensor(1.6764e-08, device='cuda:0', grad_fn=<SumBackward1>)
tensor(0.5678, device='cuda:0', grad_fn=<SumBackward1>)


# Embeddings and Unembeddings

In [32]:
# figure: percent a, b, c embed by representation over course of training

template = "percent_x_embed_{}_rep"
lines_from_template(template, title="Fraction of variance of left embedding explained by representation", yaxis="fraction of variance", save=f"{save_dir}/percent_x_embed.png")

template = "percent_y_embed_{}_rep"
lines_from_template(template, title="Fraction of variance of right embedding explained by representation", yaxis="fraction of variance", save=f"{save_dir}/percent_y_embed.png")

template = "percent_unembed_{}_rep"
lines_from_template(template, title="Fraction of variance of unembedding explained by representation", yaxis="fraction of variance", save=f"{save_dir}/percent_unembed.png")


Saving to paper/figures/percent_x_embed.png
Saving to paper/figures/percent_y_embed.png
Saving to paper/figures/percent_unembed.png


# Hidden Layer Neurons

In [33]:
# figure: evolution of \rho(a), \rho(b), \rho(ab)
template = "percent_hidden_{}_rep"
lines_from_template(template, title="Fraction of variance of MLP neurons explained by representation", yaxis="fraction of variance", save=f"{save_dir}/percent_hidden_ab.png")



Saving to paper/figures/percent_hidden_ab.png


In [34]:
# evidence: neuron clustering pre ReLU

threshold = 1

x_embed = model.x_embed
y_embed = model.y_embed

x_embed_summed = x_embed.pow(2).sum(dim=0)
off_neurons_x = (x_embed_summed < threshold).nonzero().squeeze()

y_embed_summed = y_embed.pow(2).sum(dim=0)
off_neurons_y = (y_embed_summed < threshold).nonzero().squeeze()

assert (off_neurons_x == off_neurons_y).all()

off_neurons = off_neurons_x

print(f'Off neurons: {len(off_neurons)}, {off_neurons}')

rep_neurons = {}

print('Neurons corresponding to each representation')
for rep_name in group.non_trivial_irreps:
    rep = group.irreps[rep_name].orth_rep
    coefs_x = rep.T @ x_embed
    coefs_y = rep.T @ y_embed
    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)

    x_neurons = (coefs_x_summed > threshold).nonzero().squeeze()
    y_neurons = (coefs_y_summed > threshold).nonzero().squeeze()
    assert (x_neurons == y_neurons).all()
    x_neurons = torch.tensor(x_neurons)
    if x_neurons.dim() == 0:
        x_neurons = x_neurons.unsqueeze(0)
    rep_neurons[rep_name] = x_neurons
    print(f'{rep_name}: {len(x_neurons)}, {x_neurons}')

print(rep_neurons)

all_neurons = torch.arange(model.W_U.shape[0])
unaccounted_neurons = set(all_neurons.tolist())
unaccounted_neurons -= set(off_neurons.tolist())
for rep_name, neurons in rep_neurons.items():
    unaccounted_neurons -= set(neurons.tolist())

print('Unaccounted neurons')
print(unaccounted_neurons)

Off neurons: 29, tensor([  1,   8,  13,  14,  22,  25,  34,  35,  37,  43,  46,  56,  58,  61,
         64,  65,  68,  76,  77,  78,  79,  85,  93,  95, 100, 105, 107, 108,
        114], device='cuda:0')
Neurons corresponding to each representation
sign: 4, tensor([24, 66, 73, 81], device='cuda:0')
standard: 95, tensor([  0,   2,   3,   4,   5,   6,   7,   9,  10,  11,  12,  15,  16,  17,
         18,  19,  20,  21,  23,  26,  27,  28,  29,  30,  31,  32,  33,  36,
         38,  39,  40,  41,  42,  44,  45,  47,  48,  49,  50,  51,  52,  53,
         54,  55,  57,  59,  60,  62,  63,  67,  69,  70,  71,  72,  74,  75,
         80,  82,  83,  84,  86,  87,  88,  89,  90,  91,  92,  94,  96,  97,
         98,  99, 101, 102, 103, 104, 106, 109, 110, 111, 112, 113, 115, 116,
        117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127],
       device='cuda:0')
standard_sign: 0, tensor([], device='cuda:0', dtype=torch.int64)
s5_5d_a: 0, tensor([], device='cuda:0', dtype=torch.int64)
s5_5d_


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [35]:
# evidence and table: neuron clustering in post hidden layer

threshold = 100

hidden = activations['hidden'].reshape(group.order**2, -1)

hidden_summed = hidden.pow(2).sum(dim=0)
off_neurons = (hidden_summed < threshold).nonzero().squeeze()

assert (off_neurons == off_neurons_x).all()

print(f'Off neurons: {off_neurons}')


fracs_explained_x = {}
fracs_explained_y = {}
fracs_explained_xy = {}
fracs_explained_trivial = {}

for rep_name in group.irreps.keys():
    group.irreps[rep_name].hidden_reps_x = group.irreps[rep_name].rep[all_data[:, 0]].reshape(group.order**2, -1)
    group.irreps[rep_name].hidden_reps_x_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_x)[0]
    group.irreps[rep_name].hidden_reps_y = group.irreps[rep_name].rep[all_data[:, 1]].reshape(group.order**2, -1)
    group.irreps[rep_name].hidden_reps_y_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_y)[0]
    group.irreps[rep_name].hidden_reps_xy = group.irreps[rep_name].rep[all_labels].reshape(group.order*group.order, -1)
    group.irreps[rep_name].hidden_reps_xy_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_xy)[0]

for rep_name in key_reps:
    rep_x = group.irreps[rep_name].hidden_reps_x_orth
    rep_y = group.irreps[rep_name].hidden_reps_y_orth
    rep_xy = group.irreps[rep_name].hidden_reps_xy_orth

    trivial = group.irreps['trivial'].hidden_reps_x_orth

    coefs_x = rep_x.T @ hidden
    coefs_y = rep_y.T @ hidden
    coefs_xy = rep_xy.T @ hidden

    coefs_trivial = trivial.T @ hidden

    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)
    coefs_xy_summed = coefs_xy.pow(2).sum(dim=0)
    coefs_trivial_summed = coefs_trivial.pow(2).sum(dim=0)


    neurons = rep_neurons[rep_name]

    frac_x = (coefs_x_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_y = (coefs_y_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_xy = (coefs_xy_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_trivial = (coefs_trivial_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())

    fracs_explained_x[rep_name] = frac_x
    fracs_explained_y[rep_name] = frac_y
    fracs_explained_xy[rep_name] = frac_xy
    fracs_explained_trivial[rep_name] = frac_trivial

print('Neurons corresponding to each representation')
for key in key_reps:
    print(f'frac variance explained in {key} x, y, xy: {fracs_explained_x[key], fracs_explained_y[key], fracs_explained_xy[key], fracs_explained_trivial[key]}')
    print(f'Sum of explained variance: {fracs_explained_x[key] + fracs_explained_y[key] + fracs_explained_xy[key] + fracs_explained_trivial[key]}')

Off neurons: tensor([  1,   8,  13,  14,  22,  25,  34,  35,  37,  43,  46,  56,  58,  61,
         64,  65,  68,  76,  77,  78,  79,  85,  93,  95, 100, 105, 107, 108,
        114], device='cuda:0')
Neurons corresponding to each representation
frac variance explained in standard x, y, xy: (tensor(0.2743, device='cuda:0'), tensor(0.2650, device='cuda:0'), tensor(0.0774, device='cuda:0'), tensor(0.3100, device='cuda:0'))
Sum of explained variance: 0.9267013072967529
frac variance explained in sign x, y, xy: (tensor(0.2500, device='cuda:0'), tensor(0.2500, device='cuda:0'), tensor(0.2500, device='cuda:0'), tensor(0.2500, device='cuda:0'))
Sum of explained variance: 0.9999999403953552


In [36]:
# evidence: only \rho(ab) is important
hidden = activations['hidden'].reshape(group.order**2, -1)
loss = loss_fn(logits.reshape(group.order**2, -1), all_labels)
print(f'baseline loss: {loss}')

for rep_name in group.non_trivial_irreps:

    hidden_rep_x = group.irreps[rep_name].hidden_reps_x_orth
    hidden_rep_y = group.irreps[rep_name].hidden_reps_y_orth
    hidden_rep_xy = group.irreps[rep_name].hidden_reps_xy_orth

    coefs_x = hidden_rep_x.T @ hidden
    coefs_y = hidden_rep_y.T @ hidden
    coefs_xy = hidden_rep_xy.T @ hidden

    hidden_x = hidden_rep_x @ coefs_x
    hidden_y = hidden_rep_y @ coefs_y
    hidden_xy = hidden_rep_xy @ coefs_xy

    hidden_ablated_x = hidden - hidden_x
    hidden_ablated_y = hidden - hidden_y
    hidden_ablated_xy = hidden - hidden_xy

    logits_x = hidden_ablated_x @ model.W_U
    logits_y = hidden_ablated_y @ model.W_U
    logits_xy = hidden_ablated_xy @ model.W_U

    loss_x = loss_fn(logits_x, all_labels)
    loss_y = loss_fn(logits_y, all_labels)
    loss_xy = loss_fn(logits_xy, all_labels)

    print(f'Ablating directions corresponding to {rep_name} rep loss, xy: {loss_xy}, x: {loss_x}, y: {loss_y}')
    

baseline loss: 4.221286773681641
Ablating directions corresponding to sign rep loss, xy: 0.0006927842041477561, x: 8.058216280915076e-07, y: 8.053580131672788e-07
Ablating directions corresponding to standard rep loss, xy: 8.736567497253418, x: 7.778901931487781e-07, y: 7.815492608642671e-07
Ablating directions corresponding to standard_sign rep loss, xy: 8.072868808994826e-07, x: 8.072868808994826e-07, y: 8.073200206126785e-07
Ablating directions corresponding to s5_5d_a rep loss, xy: 8.072951800386363e-07, x: 8.074773063526663e-07, y: 8.073779440564977e-07
Ablating directions corresponding to s5_5d_b rep loss, xy: 8.140337968143285e-07, x: 8.075683695096814e-07, y: 8.079657050075184e-07
Ablating directions corresponding to s5_6d rep loss, xy: 8.072951800386363e-07, x: 8.074027846305398e-07, y: 8.074524657786242e-07


In [37]:
# evidence: explicit extraction of \rho(ab). 

def projection_matrix_general(B):
    """Compute the projection matrix onto the space spanned by the columns of `B`
    Args:
        B: ndarray of dimension (D, M), the basis for the subspace
    
    Returns:
        P: the projection matrix
    """
    P = B @ (B.T @ B).inverse() @ B.T
    return P

hidden = activations['hidden'].reshape(group.order*group.order, -1)


for rep_name in key_reps:
    hidden_reps_xy = group.irreps[rep_name].hidden_reps_xy



    P = projection_matrix_general(hidden_reps_xy)
    hidden_xy = P @ hidden

    hidden_to_reps_proj = hidden_reps_xy.T @ hidden_xy

    imshow(hidden_to_reps_proj, title=f'Change of basis from neuron basis to rho(ab) {rep_name} representation basis', input2='neuron basis', input1='representation basis', save=f'{save_dir}/hidden_to_{rep_name}_rep_change_of_basis.png')

    hidden_in_rep = hidden_xy @ hidden_to_reps_proj.T

    theoretical_reps = hidden_reps_xy.reshape(group.order*group.order, -1)
    imshow(hidden_in_rep[:10], title=f'Projected hidden layer in the {rep_name} representation basis', input2='representation basis', input1='input index')
    imshow(theoretical_reps[:10], title=f'rho(ab) in {rep_name}', input2='representation basis', input1='input index')

    sim = F.cosine_similarity(hidden_in_rep.flatten(), theoretical_reps.flatten(), dim=0)
    print(f'Cosine similarity between hidden layer and theoretical representations: {sim}')

Saving to paper/figures/hidden_to_standard_rep_change_of_basis.png


Cosine similarity between hidden layer and theoretical representations: 0.9997438192367554


Saving to paper/figures/hidden_to_sign_rep_change_of_basis.png


Cosine similarity between hidden layer and theoretical representations: 1.0


# Logit Computation

In [38]:
# evidence: 
W_U = model.W_U
rep = group.irreps['standard'].rep.reshape(group.order, -1)
W_U_rep = hidden_to_reps_proj @ W_U @ rep [group.inverses]
print(W_U_rep.shape)
imshow(W_U_rep, title='Unembedding matrix in both input and output representation space', input2='input representation basis', input1='output representation basis', save=f'{save_dir}/unembedding_matrix_in_rep_basis.png')

real_linear_map = W_U_rep > 1e5
sim = F.cosine_similarity(W_U_rep.flatten(), real_linear_map.flatten(), dim=0)
print(f'Cosine similarity between unembedding matrix and real linear map: {sim}')

torch.Size([1, 16])


Saving to paper/figures/unembedding_matrix_in_rep_basis.png
Cosine similarity between unembedding matrix and real linear map: 0.0


# Ablations

# Full Circuit Analysis: Sign rep


In [39]:
# figure: blocky neurons

signature_neurons = rep_neurons['sign']
print(signature_neurons)
sig_labels = [str(x) for x in signature_neurons.tolist()]

hidden = activations['hidden'].reshape(group.order, group.order, -1)
fig = px.imshow(to_numpy(hidden[:, :, signature_neurons]), color_continuous_scale='RdBu', color_continuous_midpoint=0.0, title=f'hidden activations', facet_col=2, labels={'x':'b', 'y':'a', 'facet_col': 'neuron'})
for i, neuron in enumerate(sig_labels):
    fig.layout.annotations[i]['text'] = f'neuron = {neuron}' 
fig.show()
fig.write_image(f'{save_dir}/blocky_sign_neurons.png')


tensor([24, 66, 73, 81], device='cuda:0')


In [40]:
# evidence: form of W_U on sign neurons

sigs = group.signatures.unsqueeze(-1)
W_U_signatures = model.W_U[signature_neurons, :].T
stack = torch.hstack([sigs, W_U_signatures]).T
imshow(stack, y=['sig'] + sig_labels, input2='output group element', title='W_U on select neurons')

# Progress Measures

In [41]:
# figure: total excluded loss
keys = ['total_excluded_loss', 'test_loss', 'train_loss']
lines_from_keys(keys, title='Excluded Loss', labels=['Excluded Loss', 'Test Loss', 'Train Loss'], yaxis='Loss', save=f'{save_dir}/total_excluded_loss.png', log_y = True)

Saving to paper/figures/total_excluded_loss.png


In [42]:
# figure: excluded loss by rep
template = 'excluded_loss_{}_rep'
lines_from_template(template, title='Excluded Loss by Representation', yaxis='Loss', save=f'{save_dir}/excluded_loss_by_rep.png', log_y = True)

Saving to paper/figures/excluded_loss_by_rep.png


In [43]:
# figure: restricted loss
keys = ['restricted_loss', 'test_loss', 'train_loss']
lines_from_keys(keys, title='Restricted Loss', labels=['Restricted Loss', 'Test Loss', 'Train Loss'], yaxis='Loss', save=f'{save_dir}/restricted_loss.png', log_y=True)

Saving to paper/figures/restricted_loss.png


In [44]:
# figure: sum of square weights
keys = ['sum_of_squared_weights']
lines_from_keys(keys, title='Sum of Square Weights', labels=['Sum of Square Weights'], yaxis='Sum of Square Weights', save=f'{save_dir}/sum_of_square_weights.png', log_y=True)

Saving to paper/figures/sum_of_square_weights.png
