# Set Up


In [58]:
!nvidia-smi 
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
if "cd" not in globals():
    os.chdir("../")
    cd = True
print(os.getcwd())
save_dir = 'paper/figures'

Tue Jan 10 17:42:44 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Tesla V1...  On   | 00000000:00:05.0 Off |                    0 |
| N/A   44C    P0    96W / 300W |   5160MiB / 32510MiB |     37%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [59]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
import wandb
import statsmodels.api as sm
import json
import pickle 
import copy

from tqdm.notebook import tqdm

# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import plotly.subplots

pio.renderers.default = "vscode"

# my own tooling
from utils.hook_points import HookPoint, HookedRootModule
from utils.plotting import *
from utils.groups import *
from utils.models import *
from utils.config import *
from utils.figures import *

In [60]:
if torch.cuda.is_available:
    print('Good to go!')
else:
    print('Things might be rather slow')

Good to go!


In [61]:
task_dir = "batch_experiments/MLP_S5_seed3"
#task_dir = "batch_experiments/T_S5_seed1"
#task_dir = "batch_experiments/MLP_C113_seed1"

seed, frac_train, layers, lr, group_param, weight_decay, num_epochs, group_type, architecture_type = load_cfg(task_dir)
group = group_type(group_param, init_all = True)
model = architecture_type(layers, group.order, seed).cuda()
model.load_state_dict(torch.load(f"{task_dir}/model.pt"), strict=False)
model.eval()
all_data, _, all_labels, _ = generate_train_test_data(group, frac_train = 1)
logits, activations = model.run_with_cache(group.all_data, return_cache_object=False)
activations['logits'] = logits
metric_obj = Metrics(group, training=False, track_metrics = True)
key_reps = metric_obj.determine_key_reps(model)

metrics_path = os.path.join(task_dir, 'metrics.csv')
summary_metrics_path = os.path.join(task_dir, 'summary_metrics.json')

# load the metrics
metrics = pd.read_csv(metrics_path)
summary_metrics = json.load(open(summary_metrics_path, 'r'))

reps_to_plot = list(group.non_trivial_irreps.keys())

Computing multiplication table...
... loading from file
Computing trace tensor cube for trivial representation
... loading from file
Computing trace tensor cube for sign representation
... loading from file
Computing trace tensor cube for standard representation
... loading from file
Computing trace tensor cube for standard_sign representation
... loading from file
Computing trace tensor cube for s5_5d_a representation
... loading from file
Computing trace tensor cube for s5_5d_b representation
... loading from file
Computing trace tensor cube for s5_6d representation
... loading from file


# Logit Attribution


In [62]:
# figure: a slice of true vs hypothetical logit cubes
logits = activations['logits']
logits = logits.reshape(group.order, group.order, group.order)

true_logits = logits[:, :, 0]/logits[:, :, 0].abs().max()
key_rep_logits = []
for key_rep in key_reps:
    key_rep_logits.append(group.irreps[key_rep].logit_trace_tensor_cube[:, :, 0]/group.irreps[key_rep].logit_trace_tensor_cube[:, :, 0].abs().max())
key_rep_logits = torch.stack(key_rep_logits, dim=0)

stack = torch.vstack([true_logits.unsqueeze(0), key_rep_logits])
print(stack.shape)
fig = px.imshow(to_numpy(stack), color_continuous_scale='RdBu', color_continuous_midpoint=0.0, title=f'hidden activations', facet_col=0, labels={'x':'b', 'y':'a', 'facet_col': 'label'})
fig.layout.annotations[0]['text'] = 'true logit 0 over all inputs' 
for i in range(len(key_reps)):
    fig.layout.annotations[i+1]['text'] = f"{key_reps[i]} hypothesised logit 0"
fig.show()
fig.write_image(f"{save_dir}/logit_cubes.png", width=1000)


torch.Size([3, 120, 120])


In [63]:
# figure: evolution of cosine similarity
template = "logit_{}_rep_trace_similarity"
lines_from_template(metrics, template, reps_to_plot, title="cosine similarity of true logits and hypothesised logits", yaxis="cosine similarity", save=f"{save_dir}/logit_similarity.png", log_x=True)


Saving to paper/figures/logit_similarity.png


In [64]:
# evidence: end of training logit cosine similarity

template = "logit_{}_rep_trace_similarity"
print("cosine similarity of true logits and hypothesised logits at end of training")
for irrep in group.non_trivial_irreps:
    print(f"{irrep}: {summary_metrics[template.format(irrep)]}")


cosine similarity of true logits and hypothesised logits at end of training
sign: 0.3731147348880768
standard: 0.6733204126358032
standard_sign: 7.274529707501642e-06
s5_5d_a: -2.7835369564854773e-07
s5_5d_b: 0.0014667203649878502
s5_6d: 5.152600351721048e-06


In [87]:
# evidence: logits are weighted sum 
true_logits = model(all_data).reshape(group.order**2, group.order)
true_logits_centered = true_logits - true_logits.mean(dim=0, keepdim=True)
true_logits_flattened = true_logits_centered.reshape(-1)
key_rep_logits_flattened = {}
for rep in key_reps:
    rep_logits = group.irreps[rep].logit_trace_tensor_cube.reshape(-1)
    key_rep_logits_flattened[rep] = rep_logits

# fit the true logits to the key rep logits using OLS
X = torch.stack(list(key_rep_logits_flattened.values()), dim=1)
y = true_logits_flattened
X = torch.cat([X, torch.ones(X.shape[0], 1).cuda()], dim=1)
beta = torch.inverse(X.T @ X) @ X.T @ y
print('Betas: ', beta)
print(key_reps)

OLS_logits = torch.zeros_like(true_logits_flattened)
for i, logits in enumerate(key_rep_logits_flattened.values()):
    OLS_logits += beta[i]*logits
OLS_logits += beta[-1]
OLS_logits_norm = torch.norm(OLS_logits)
true_logits_norm = torch.norm(true_logits_flattened)

# calculate the R^2 of the OLS fit
r2 = F.cosine_similarity(true_logits_flattened, OLS_logits, dim=0).pow(2).sum()

print(f'Percent variance explained {r2}')

original_loss = loss_fn(activations['logits'].reshape(group.order**2, -1), all_labels)
OLS_loss = loss_fn(OLS_logits.reshape(group.order**2, -1), all_labels)

print('Original loss: ', original_loss)
print('OLS loss: ', OLS_loss)


# remove the OLS logits, and see what remains
residual_logits = true_logits_flattened - OLS_logits
residual_logits = residual_logits.reshape(group.order, group.order, group.order)
print('Residual loss: ',loss_fn(residual_logits.reshape(group.order**2, -1), all_labels))
imshow(residual_logits[:, :, 0])

Betas:  tensor([ 4.3076e+00,  7.7735e+00, -3.7882e-07], device='cuda:0',
       grad_fn=<MvBackward0>)
['sign', 'standard']
Percent variance explained 0.5920559763908386
Original loss:  tensor(2.0395e-06, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)
OLS loss:  tensor(1.8112e-09, device='cuda:0', dtype=torch.float64,
       grad_fn=<NegBackward0>)
Residual loss:  tensor(17.3197, device='cuda:0', dtype=torch.float64, grad_fn=<NegBackward0>)


In [67]:

print(F.cosine_similarity(key_rep_logits_flattened['sign'], true_logits_flattened-true_logits_flattened.mean(), dim=0))
print(F.cosine_similarity(key_rep_logits_flattened['sign'], true_logits_flattened, dim=0))


tensor(0.3729, device='cuda:0', grad_fn=<SumBackward1>)
tensor(0.3729, device='cuda:0', grad_fn=<SumBackward1>)


In [68]:
print(F.cosine_similarity(key_rep_logits_flattened['standard'], true_logits_flattened-true_logits_flattened.mean(), dim=0))
print(F.cosine_similarity(key_rep_logits_flattened['standard'], true_logits_flattened, dim=0))

tensor(0.6730, device='cuda:0', grad_fn=<SumBackward1>)
tensor(0.6730, device='cuda:0', grad_fn=<SumBackward1>)


In [69]:
print(F.cosine_similarity(torch.ones(120**3).cuda(), true_logits_flattened-true_logits_flattened.mean(), dim=0))
print(F.cosine_similarity(torch.ones(120**3).cuda(), true_logits_flattened, dim=0))

tensor(-1.1874e-08, device='cuda:0', grad_fn=<SumBackward1>)
tensor(-3.2873e-08, device='cuda:0', grad_fn=<SumBackward1>)


# Embeddings and Unembeddings

In [70]:
# figure: percent a, b, c embed by representation over course of training

template = "percent_x_embed_{}_rep"
lines_from_template(metrics, template, reps_to_plot, title="Fraction of variance of left embedding explained by representation", yaxis="fraction of variance", save=f"{save_dir}/percent_x_embed.png", log_x=True)
template = "percent_y_embed_{}_rep"
lines_from_template(metrics, template, reps_to_plot, title="Fraction of variance of right embedding explained by representation", yaxis="fraction of variance", save=f"{save_dir}/percent_y_embed.png", log_x=True)
template = "percent_unembed_{}_rep"
lines_from_template(metrics, template, reps_to_plot, title="Fraction of variance of unembedding explained by representation", yaxis="fraction of variance", save=f"{save_dir}/percent_unembed.png", log_x=True)

Saving to paper/figures/percent_x_embed.png
Saving to paper/figures/percent_y_embed.png
Saving to paper/figures/percent_unembed.png


# Hidden Layer Neurons

In [71]:
# figure: evolution of \rho(a), \rho(b), \rho(ab)
template = "percent_hidden_{}_rep"
lines_from_template(metrics, template, reps_to_plot, title="Fraction of variance of MLP neurons explained by representation", yaxis="fraction of variance", save=f"{save_dir}/percent_hidden.png", log_x=True)

Saving to paper/figures/percent_hidden.png


In [72]:
# evidence: neuron clustering pre ReLU

threshold = 1

x_embed = model.x_embed
y_embed = model.y_embed

x_embed_summed = x_embed.pow(2).sum(dim=0)
off_neurons_x = (x_embed_summed < threshold).nonzero().squeeze()

y_embed_summed = y_embed.pow(2).sum(dim=0)
off_neurons_y = (y_embed_summed < threshold).nonzero().squeeze()

assert (off_neurons_x == off_neurons_y).all()

off_neurons = off_neurons_x

print(f'Off neurons: {len(off_neurons)}, {off_neurons}')

rep_neurons = {}

print('Neurons corresponding to each representation')
for rep_name in group.non_trivial_irreps:
    rep = group.irreps[rep_name].orth_rep
    coefs_x = rep.T @ x_embed
    coefs_y = rep.T @ y_embed
    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)

    x_neurons = (coefs_x_summed > threshold).nonzero().squeeze()
    y_neurons = (coefs_y_summed > threshold).nonzero().squeeze()
    assert (x_neurons == y_neurons).all()
    x_neurons = torch.tensor(x_neurons)
    if x_neurons.dim() == 0:
        x_neurons = x_neurons.unsqueeze(0)
    rep_neurons[rep_name] = x_neurons
    print(f'{rep_name}: {len(x_neurons)}, {x_neurons}')

print(rep_neurons)

all_neurons = torch.arange(model.W_U.shape[0])
unaccounted_neurons = set(all_neurons.tolist())
unaccounted_neurons -= set(off_neurons.tolist())
for rep_name, neurons in rep_neurons.items():
    unaccounted_neurons -= set(neurons.tolist())

print('Unaccounted neurons')
print(unaccounted_neurons)

Off neurons: 17, tensor([  6,  10,  12,  15,  30,  57,  60,  72,  73,  76,  88,  97,  99, 100,
        101, 119, 127], device='cuda:0')
Neurons corresponding to each representation
sign: 13, tensor([  3,   4,  11,  24,  43,  48,  50,  77,  89,  95,  96, 103, 124],
       device='cuda:0')
standard: 98, tensor([  0,   1,   2,   5,   7,   8,   9,  13,  14,  16,  17,  18,  19,  20,
         21,  22,  23,  25,  26,  27,  28,  29,  31,  32,  33,  34,  35,  36,
         37,  38,  39,  40,  41,  42,  44,  45,  46,  47,  49,  51,  52,  53,
         54,  55,  56,  58,  59,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  74,  75,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,
         90,  91,  92,  93,  94,  98, 102, 104, 105, 106, 107, 108, 109, 110,
        111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123, 125, 126],
       device='cuda:0')
standard_sign: 0, tensor([], device='cuda:0', dtype=torch.int64)
s5_5d_a: 0, tensor([], device='cuda:0', dtype=torch.int64)
s5


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [73]:
# evidence and table: neuron clustering in post hidden layer

threshold = 100

hidden = activations['hidden'].reshape(group.order**2, -1)

hidden_summed = hidden.pow(2).sum(dim=0)
off_neurons = (hidden_summed < threshold).nonzero().squeeze()

assert (off_neurons == off_neurons_x).all()

print(f'Off neurons: {off_neurons}')


fracs_explained_x = {}
fracs_explained_y = {}
fracs_explained_xy = {}
fracs_explained_trivial = {}

for rep_name in group.irreps.keys():
    group.irreps[rep_name].hidden_reps_x = group.irreps[rep_name].rep[all_data[:, 0]].reshape(group.order**2, -1)
    group.irreps[rep_name].hidden_reps_x_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_x)[0]
    group.irreps[rep_name].hidden_reps_y = group.irreps[rep_name].rep[all_data[:, 1]].reshape(group.order**2, -1)
    group.irreps[rep_name].hidden_reps_y_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_y)[0]
    group.irreps[rep_name].hidden_reps_xy = group.irreps[rep_name].rep[all_labels].reshape(group.order*group.order, -1)
    group.irreps[rep_name].hidden_reps_xy_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_xy)[0]

for rep_name in key_reps:
    rep_x = group.irreps[rep_name].hidden_reps_x_orth
    rep_y = group.irreps[rep_name].hidden_reps_y_orth
    rep_xy = group.irreps[rep_name].hidden_reps_xy_orth

    trivial = group.irreps['trivial'].hidden_reps_x_orth

    coefs_x = rep_x.T @ hidden
    coefs_y = rep_y.T @ hidden
    coefs_xy = rep_xy.T @ hidden

    coefs_trivial = trivial.T @ hidden

    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)
    coefs_xy_summed = coefs_xy.pow(2).sum(dim=0)
    coefs_trivial_summed = coefs_trivial.pow(2).sum(dim=0)


    neurons = rep_neurons[rep_name]

    frac_x = (coefs_x_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_y = (coefs_y_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_xy = (coefs_xy_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_trivial = (coefs_trivial_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())

    fracs_explained_x[rep_name] = frac_x
    fracs_explained_y[rep_name] = frac_y
    fracs_explained_xy[rep_name] = frac_xy
    fracs_explained_trivial[rep_name] = frac_trivial

print('Neurons corresponding to each representation')
for key in key_reps:
    print(f'frac variance explained in {key} x, y, xy: {fracs_explained_x[key], fracs_explained_y[key], fracs_explained_xy[key], fracs_explained_trivial[key]}')
    print(f'Sum of explained variance: {fracs_explained_x[key] + fracs_explained_y[key] + fracs_explained_xy[key] + fracs_explained_trivial[key]}')

Off neurons: tensor([  6,  10,  12,  15,  30,  57,  60,  72,  73,  76,  88,  97,  99, 100,
        101, 119, 127], device='cuda:0')
Neurons corresponding to each representation
frac variance explained in sign x, y, xy: (tensor(0.2500, device='cuda:0'), tensor(0.2500, device='cuda:0'), tensor(0.2500, device='cuda:0'), tensor(0.2500, device='cuda:0'))
Sum of explained variance: 0.9999998807907104
frac variance explained in standard x, y, xy: (tensor(0.2644, device='cuda:0'), tensor(0.2645, device='cuda:0'), tensor(0.0795, device='cuda:0'), tensor(0.3200, device='cuda:0'))
Sum of explained variance: 0.9284014701843262


In [74]:
# evidence: only \rho(ab) is important
hidden = activations['hidden'].reshape(group.order**2, -1)
loss = loss_fn(logits.reshape(group.order**2, -1), all_labels)
print(f'baseline loss: {loss}')

for rep_name in group.non_trivial_irreps:

    hidden_rep_x = group.irreps[rep_name].hidden_reps_x_orth
    hidden_rep_y = group.irreps[rep_name].hidden_reps_y_orth
    hidden_rep_xy = group.irreps[rep_name].hidden_reps_xy_orth

    coefs_x = hidden_rep_x.T @ hidden
    coefs_y = hidden_rep_y.T @ hidden
    coefs_xy = hidden_rep_xy.T @ hidden

    hidden_x = hidden_rep_x @ coefs_x
    hidden_y = hidden_rep_y @ coefs_y
    hidden_xy = hidden_rep_xy @ coefs_xy

    hidden_ablated_x = hidden - hidden_x
    hidden_ablated_y = hidden - hidden_y
    hidden_ablated_xy = hidden - hidden_xy

    logits_x = hidden_ablated_x @ model.W_U
    logits_y = hidden_ablated_y @ model.W_U
    logits_xy = hidden_ablated_xy @ model.W_U

    loss_x = loss_fn(logits_x, all_labels)
    loss_y = loss_fn(logits_y, all_labels)
    loss_xy = loss_fn(logits_xy, all_labels)

    print(f'Ablating directions corresponding to {rep_name} rep loss, xy: {loss_xy}, x: {loss_x}, y: {loss_y}')
    

baseline loss: 1.4973364201770554
Ablating directions corresponding to sign rep loss, xy: 0.0017656893313159389, x: 2.032480954875985e-06, y: 2.033004011839782e-06
Ablating directions corresponding to standard rep loss, xy: 16.5026584172692, x: 2.0316150472263943e-06, y: 2.0344320892097933e-06
Ablating directions corresponding to standard_sign rep loss, xy: 2.0394702808804127e-06, x: 2.0394699440704904e-06, y: 2.039469258614872e-06
Ablating directions corresponding to s5_5d_a rep loss, xy: 2.0394699319726187e-06, x: 2.0394691612360735e-06, y: 2.039468071908322e-06
Ablating directions corresponding to s5_5d_b rep loss, xy: 2.228731111385631e-06, x: 2.0407866583878264e-06, y: 2.041453179068509e-06
Ablating directions corresponding to s5_6d rep loss, xy: 2.0394700848458166e-06, x: 2.039469538647029e-06, y: 2.0394684064208865e-06


In [75]:
# evidence: explicit extraction of \rho(ab). 

def projection_matrix_general(B):
    """Compute the projection matrix onto the space spanned by the columns of `B`
    Args:
        B: ndarray of dimension (D, M), the basis for the subspace
    
    Returns:
        P: the projection matrix
    """
    P = B @ (B.T @ B).inverse() @ B.T
    return P

hidden = activations['hidden'].reshape(group.order*group.order, -1)
hidden_to_reps_proj = {}

for rep_name in key_reps:
    hidden_reps_xy = group.irreps[rep_name].hidden_reps_xy



    P = projection_matrix_general(hidden_reps_xy)
    hidden_xy = P @ hidden

    hidden_to_reps_proj[rep_name] = hidden_reps_xy.T @ hidden_xy

    imshow(hidden_to_reps_proj[rep_name], title=f'Change of basis from neuron basis to rho(ab) {rep_name} representation basis', input2='neuron basis', input1='representation basis', save=f'{save_dir}/hidden_to_{rep_name}_rep_change_of_basis.png')

    hidden_in_rep = hidden_xy @ hidden_to_reps_proj[rep_name].T

    theoretical_reps = hidden_reps_xy.reshape(group.order*group.order, -1)
    imshow(hidden_in_rep[:10], title=f'Projected hidden layer in the {rep_name} representation basis', input2='representation basis', input1='input index')
    imshow(theoretical_reps[:10], title=f'rho(ab) in {rep_name}', input2='representation basis', input1='input index')

    sim = F.cosine_similarity(hidden_in_rep.flatten(), theoretical_reps.flatten(), dim=0)
    print(f'Cosine similarity between hidden layer and theoretical representations: {sim}')

Saving to paper/figures/hidden_to_sign_rep_change_of_basis.png


Cosine similarity between hidden layer and theoretical representations: 1.0


Saving to paper/figures/hidden_to_standard_rep_change_of_basis.png


Cosine similarity between hidden layer and theoretical representations: 0.9998586177825928


# Logit Computation

In [76]:
# evidence: 
rep_name = 'standard'
W_U = model.W_U
rep = group.irreps[rep_name].rep.reshape(group.order, -1)
W_U_rep = hidden_to_reps_proj[rep_name] @ W_U @ rep [group.inverses]
print(W_U_rep.shape)
imshow(W_U_rep, title='Unembedding matrix in both input and output representation space', input2='input representation basis', input1='output representation basis', save=f'{save_dir}/unembedding_matrix_in_rep_basis.png')

real_linear_map = W_U_rep > 1e5
sim = F.cosine_similarity(W_U_rep.flatten(), real_linear_map.flatten(), dim=0)
print(f'Cosine similarity between unembedding matrix and real linear map: {sim}')

torch.Size([16, 16])


Saving to paper/figures/unembedding_matrix_in_rep_basis.png
Cosine similarity between unembedding matrix and real linear map: 0.9966269731521606


# Ablations

# Full Circuit Analysis: Sign rep


In [77]:
# figure: blocky neurons

signature_neurons = rep_neurons['sign']
print(signature_neurons)
sig_labels = [str(x) for x in signature_neurons.tolist()]

hidden = activations['hidden'].reshape(group.order, group.order, -1)
fig = px.imshow(to_numpy(hidden[:, :, signature_neurons]), color_continuous_scale='RdBu', color_continuous_midpoint=0.0, title=f'hidden activations', facet_col=2, labels={'x':'b', 'y':'a', 'facet_col': 'neuron'})
for i, neuron in enumerate(sig_labels):
    fig.layout.annotations[i]['text'] = f'neuron = {neuron}' 
fig.show()
fig.write_image(f'{save_dir}/blocky_sign_neurons.png')


tensor([  3,   4,  11,  24,  43,  48,  50,  77,  89,  95,  96, 103, 124],
       device='cuda:0')


In [78]:
# evidence: form of W_U on sign neurons

sigs = group.signatures.unsqueeze(-1)
W_U_signatures = model.W_U[signature_neurons, :].T
stack = torch.hstack([sigs, W_U_signatures]).T
imshow(stack, y=['sig'] + sig_labels, input2='output group element', title='W_U on select neurons')

# Progress Measures

In [79]:
# figure: total excluded loss
keys = ['total_excluded_loss', 'test_loss', 'train_loss']
lines_from_keys(metrics, keys, title='Excluded Loss', labels=['Excluded Loss', 'Test Loss', 'Train Loss'], yaxis='Loss', save=f'{save_dir}/total_excluded_loss.png', log_y = True)


Saving to paper/figures/total_excluded_loss.png


In [80]:
# figure: excluded loss by rep
template = 'excluded_loss_{}_rep'
lines_from_template(metrics, template, reps_to_plot,title='Excluded Loss by Representation', yaxis='Loss', save=f'{save_dir}/excluded_loss_by_rep.png', log_y = True, log_x = True)

Saving to paper/figures/excluded_loss_by_rep.png


In [81]:
# figure: restricted loss
keys = ['restricted_loss', 'test_loss', 'train_loss']
lines_from_keys(metrics, keys, title='Restricted Loss', labels=['Restricted Loss', 'Test Loss', 'Train Loss'], yaxis='Loss', save=f'{save_dir}/restricted_loss.png', log_y=True)

Saving to paper/figures/restricted_loss.png


In [82]:
# figure: sum of square weights
keys = ['sum_of_squared_weights']
lines_from_keys(metrics, keys, title='Sum of Square Weights', labels=['Sum of Square Weights'], yaxis='Sum of Square Weights', save=f'{save_dir}/sum_of_square_weights.png', log_y=True)

Saving to paper/figures/sum_of_square_weights.png
