In this notebook, we will train a network to perform group composition on the abstract group S5, and then reverse engineer the algorithm learned. This notebook should be fairly self contained, and uses similar code to that in the paper, though slightly simplified. Full code for the paper is available at ...

# Setup


In [136]:
# Imports
import os
if "cd" not in globals():
    os.chdir("../")
    cd = True
import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F
from tqdm import tqdm
from transformer_lens.hook_points import HookPoint, HookedRootModule
from sympy.combinatorics.named_groups import SymmetricGroup as SymPySymmetricGroup
from sympy.combinatorics import Permutation, PermutationGroup
from utils.plotting import *

In [137]:
SEED = 2

Let's first create a class to contain all the information about our group. We include methods 
* to convert between permutations and their indices that we use to one-hot-encode them. 
* to compose permutations, based on their index
* to calculate the order of a permutation
* to calculate the signature of a permutation
* to cache all signatures of the group elements
* to calculate the inverse of a permutation
* to cache all inverses of group elements
* to calculate the full multiplication table of the group
* to return all pairs of permutations, and their product

In [138]:
class SymmetricGroup():
    """
    Class for the symmetric group of order index.
    """
    def __init__(self, index, init_all=False):
        """
        Initialise the group. Optionally calculate all other tensors needed to track metrics.

        Args:
            index (int): Index of group in family of symmetric groups.
            init_all (bool, optional): If false, only calculate what is required to train. If true, calculate all other tensors needed to track metrics. Defaults to False.
        """

        # build class on top of sympy
        self.G = SymPySymmetricGroup(index)
        self.index = index
        self.order = math.factorial(index)
        self.acronym = 'S'
        self.multiplication_table = self.compute_multiplication_table()
        
        # hacky method to find the index of the identity element 
        self.identity = [i for i in range(self.order) if self.idx_to_perm(i).order() == 1][0]
        
        self.inverses = self.compute_inverses()

        # compute the signatures of the elements
        self.signatures = self.compute_signatures()

        self.all_data, _ = self.get_all_data()
        self.all_data = self.all_data[:, :2]


    def idx_to_perm(self, x):
        """
        Convert an index to a permutation.

        Args:
            x (int): index of element in group

        Returns:
            Permutation: permutation object from sympy
        """
        return self.G._elements[x]

    def perm_to_idx(self, perm):
        """
        Converts a permutation to an index.

        Args:
            perm (Permutation): permutation object from sympy

        Returns:
            int: index of element in group
        """
        return self.G._elements.index(perm)

    def compose(self, x, y):
        """
        Compose elements of the group by converting to permutations, composing, and converting back.

        Args:
            x (int): left index
            y (int): right index

        Returns:
            int: index of composition
        """
        return self.perm_to_idx(self.idx_to_perm(x) * self.idx_to_perm(y))

    def perm_order(self, x):
        """
        Gets the order of a permutation.

        Args:
            x (int): index of element

        Returns:
            int: order of permutation
        """
        return self.idx_to_perm(x).order()

    def signature(self, x):
        """
        Gets the signature of a permutation.

        Args:
            x (int): index of element

        Returns:
            int: Integer \in {0, 1} representing the signature of the permutation.
        """
        return self.idx_to_perm(x).signature()
    
    def compute_signatures(self):
        """
        Compute and store the signature of each element in the group.

        Returns:
            torch.tensor: tensor of signatures
        """
        signatures = torch.tensor([self.signature(i) for i in range(self.order)]).cuda()
        return signatures
    
    def inverse(self, x):
        """
        Compute the inverse of an element of the group.

        Args:
            x (int): Index of element to inverse

        Returns:
            int: Index of inverse element
        """
        return (self.multiplication_table[x, :] == self.identity).nonzero().item()

    def compute_inverses(self):
        inverses = torch.zeros(self.order, dtype=torch.int64)
        for i in range(self.order):
            inverses[i] = self.inverse(i)
        return inverses
    
    def compute_multiplication_table(self):
        """
        Compute the multiplication table of the group. Caches/loads from file if possible.
        """
        print('Computing multiplication table...')
        table = torch.zeros((self.order, self.order), dtype=torch.int64).cuda()
        for i in tqdm(range(self.order)):
            for j in range(self.order):
                table[i, j] = self.compose(i, j)
        return table

    def get_all_data(self, shuffle_seed=False):
        """
        Get's all data and labels for the pairwise composition task.

        Args:
            shuffle_seed (bool, optional): Shuffle data for training. Defaults to False.

        Returns:
            torch.tensor: Tensor of shape (order*order, 3) where each row is (x, y, x*y).
        """
        data=torch.zeros((self.order*self.order, 3), dtype=torch.int64)
        shuffled_indices = None
        for i in range(self.order):
            for j in range(self.order):
                data[i*self.order+j, 0] = i
                data[i*self.order+j, 1] = j
                data[i*self.order+j, 2] = self.multiplication_table[i, j]
        if shuffle_seed:
            torch.manual_seed(shuffle_seed) 
            shuffled_indices = torch.randperm(self.order*self.order)
            data = data[shuffled_indices]
        return data.cuda(), shuffled_indices

In [139]:
group = SymmetricGroup(5)

Computing multiplication table...


100%|██████████| 120/120 [00:10<00:00, 11.46it/s]


The model architecture is an MLP as described in the paper. We subclass the HookedRootModule class from the transformer_lens library to easily cache activations on the forward pass.

In [140]:
model_cfg = {    
    "layers": {
        "embed_dim": 256,
        "hidden_dim": 128
    }
}

class OneLayerMLP(HookedRootModule):
    """ 
    A one layer MLP. W_x and W_y are embedding layers, whose outputs are concatenated and fed into a hidden layer. The result is unembedded by W_U.
    """
    def __init__(self, layers, n, seed=0):
        # embed_dim: dimension of the embedding
        # hidden : hidden dimension size
        # n : group order
        super().__init__()
        torch.manual_seed(seed)

        self.embed_dim = layers['embed_dim']
        hidden = layers['hidden_dim']

        # xavier initialise parameters
        self.W_x = nn.Parameter(torch.randn(n, self.embed_dim)/np.sqrt(self.embed_dim))
        self.W_y = nn.Parameter(torch.randn(n, self.embed_dim)/np.sqrt(self.embed_dim))
        self.W = nn.Parameter(torch.randn(2*self.embed_dim, hidden)/np.sqrt(2*self.embed_dim))
        self.relu = nn.ReLU()
        self.W_U = nn.Parameter(torch.randn(hidden, n)/np.sqrt(hidden))

        # hookpoints
        self.embed_stack = HookPoint()
        self.hidden = HookPoint()

        # We need to call the setup function of HookedRootModule to build an 
        # internal dictionary of modules and hooks, and to give each hook a name
        super().setup()

    def forward(self, data):
        x = data[:, 0] # (batch)
        half_x_embed = self.W_x[x] # (batch, embed_dim)
        y = data[:, 1] # (batch)
        half_y_embed = self.W_y[y] # (batch, embed_dim)
        embed_stack = self.embed_stack(torch.hstack((half_x_embed, half_y_embed))) # (batch, 2*embed_dim)
        hidden = self.hidden(self.relu(embed_stack @ self.W)) # (batch, hidden)
        out = hidden @ self.W_U # (batch, n)
        return out

def loss_fn(logits, labels):
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()

def get_accuracy(logits, labels):
    """
    Compute accuracy of model.

    Args:
        logits (torch.tensor): (batch, group.order) tensor of logits
        labels (torch.tensor): (batch) tensor of labels

    Returns:
        float: accuracy
    """
    return ((logits.argmax(1)==labels).sum()/len(labels)).item()

model = OneLayerMLP(model_cfg['layers'], group.order).cuda()

To generate our train/test split, we just take a fixed proportion of the entire multiplication table of the group at random. We'll use a ratio of 0.4 here.

In [141]:
def generate_train_test_data(group, frac_train, seed=False):
    """
    Generate train and test data from a group's all data.
    """
    data, shuffled_indices = group.get_all_data(seed)
    train_size = int(frac_train*data.shape[0])
    train = data[:train_size]
    test = data[train_size:]
    train_data = train[:, :2]
    train_labels = train[:, 2]
    test_data = test[:, :2]
    test_labels = test[:, 2]
    return train_data, test_data, train_labels, test_labels, shuffled_indices



We'll checkpoint our models every 500 epochs, to later be able to analyse the model throughout training.

In [142]:
def save_checkpoint(model, epoch, task_dir, final=False):
    """
    Save model checkpoint.
    """
    path = f'{task_dir}/checkpoints/epoch_{epoch}.pt'
    if final:
        path = f'{task_dir}/model.pt'
    torch.save(model.state_dict(), path)

def load_checkpoint(model, task_dir, epoch=None, final=False):
    """
    Load model checkpoint.
    """
    path = f'{task_dir}/checkpoints/epoch_{epoch}.pt'
    if final:
        path = f'temp/model.pt'
    model.load_state_dict(torch.load(path), strict=False)
    return model

# Training

In [143]:
training_cfg = {
    "lr": 0.001,
    "betas": (0.9, 0.98),
    "weight_decay": 1,
    "num_epochs": 75000,
    "frac_train": 0.4,
}
train_data, test_data, train_labels, test_labels, shuffled_indices = generate_train_test_data(group, training_cfg["frac_train"], seed=SEED)
train_indices = shuffled_indices[:len(train_data)]
optimizer = torch.optim.AdamW(model.parameters(), lr=training_cfg["lr"], betas=training_cfg["betas"], weight_decay=training_cfg["weight_decay"])

# remove temp directory if it exists
if os.path.exists('temp'):
    os.system('rm -rf temp')

# create a temp directory
task_dir = f'temp'
if not os.path.exists(task_dir):
    os.makedirs(task_dir)

# create a temp directory for checkpoints
checkpoint_dir = f'temp/checkpoints'
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

In [144]:
train_losses = []
test_losses = []
train_accs = []
test_accs = []

checkpoint_every = 500

for epoch in tqdm(range(training_cfg["num_epochs"])):
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if epoch % checkpoint_every == 0:
        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            train_acc = get_accuracy(train_logits, train_labels)
            test_acc = get_accuracy(test_logits, test_labels)
            print(f"Epoch {epoch}: Train Loss: {train_loss}, Test Loss: {test_loss}, Train Accuracy: {train_acc}, Test Accuracy: {test_acc}")
        train_losses.append(train_loss.item())
        test_losses.append(test_loss.item())
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        # save checkpoint
        save_checkpoint(model, epoch, task_dir)





  0%|          | 55/75000 [00:00<04:37, 270.09it/s]

Epoch 0: Train Loss: 4.788640449630546, Test Loss: 4.789846160405907, Train Accuracy: 0.008854166604578495, Test Accuracy: 0.007986110635101795


  1%|          | 545/75000 [00:01<03:46, 328.43it/s]

Epoch 500: Train Loss: 0.09221563697684555, Test Loss: 19.694246244583418, Train Accuracy: 1.0, Test Accuracy: 0.00011574073869269341


  1%|▏         | 1055/75000 [00:03<03:43, 330.48it/s]

Epoch 1000: Train Loss: 0.001450883571250554, Test Loss: 31.22518665245293, Train Accuracy: 1.0, Test Accuracy: 0.00023148147738538682


  2%|▏         | 1541/75000 [00:04<03:36, 339.82it/s]

Epoch 1500: Train Loss: 4.246526148096426e-05, Test Loss: 40.78074727001797, Train Accuracy: 1.0, Test Accuracy: 0.00069444440305233


  3%|▎         | 2061/75000 [00:06<03:37, 335.69it/s]

Epoch 2000: Train Loss: 1.0533933492527978e-05, Test Loss: 42.80536998408742, Train Accuracy: 1.0, Test Accuracy: 0.00046296295477077365


  3%|▎         | 2540/75000 [00:07<03:38, 331.02it/s]

Epoch 2500: Train Loss: 9.588212874495189e-06, Test Loss: 41.24490995157959, Train Accuracy: 1.0, Test Accuracy: 0.00069444440305233


  4%|▍         | 3050/75000 [00:09<03:35, 334.00it/s]

Epoch 3000: Train Loss: 9.378379651280005e-06, Test Loss: 39.977309982659754, Train Accuracy: 1.0, Test Accuracy: 0.00046296295477077365


  5%|▍         | 3562/75000 [00:10<03:32, 335.75it/s]

Epoch 3500: Train Loss: 9.245461690815503e-06, Test Loss: 39.065480015176306, Train Accuracy: 1.0, Test Accuracy: 0.00046296295477077365


  5%|▌         | 4042/75000 [00:12<03:31, 335.59it/s]

Epoch 4000: Train Loss: 9.117594372052946e-06, Test Loss: 38.30634850968179, Train Accuracy: 1.0, Test Accuracy: 0.0005787037080153823


  6%|▌         | 4555/75000 [00:13<03:29, 336.06it/s]

Epoch 4500: Train Loss: 9.033575412072167e-06, Test Loss: 37.66636245723774, Train Accuracy: 1.0, Test Accuracy: 0.00069444440305233


  7%|▋         | 5033/75000 [00:15<03:37, 321.58it/s]

Epoch 5000: Train Loss: 8.960484987782406e-06, Test Loss: 37.124528740108, Train Accuracy: 1.0, Test Accuracy: 0.0009259259095415473


  7%|▋         | 5562/75000 [00:16<03:32, 326.33it/s]

Epoch 5500: Train Loss: 8.88780706418817e-06, Test Loss: 36.68442777705224, Train Accuracy: 1.0, Test Accuracy: 0.0009259259095415473


  8%|▊         | 6063/75000 [00:18<03:31, 325.90it/s]

Epoch 6000: Train Loss: 8.827659307100988e-06, Test Loss: 36.281497335379505, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


  9%|▊         | 6534/75000 [00:19<03:31, 324.08it/s]

Epoch 6500: Train Loss: 8.766795108312168e-06, Test Loss: 35.92149788267344, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


  9%|▉         | 7036/75000 [00:21<03:26, 328.62it/s]

Epoch 7000: Train Loss: 8.70586018842263e-06, Test Loss: 35.59721878720156, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 10%|█         | 7544/75000 [00:22<03:25, 327.72it/s]

Epoch 7500: Train Loss: 8.653125641041196e-06, Test Loss: 35.313005839940026, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 11%|█         | 8059/75000 [00:24<03:19, 335.44it/s]

Epoch 8000: Train Loss: 8.605686067864457e-06, Test Loss: 35.03278890211221, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 11%|█▏        | 8535/75000 [00:25<03:21, 329.90it/s]

Epoch 8500: Train Loss: 8.551810847800387e-06, Test Loss: 34.739323462732365, Train Accuracy: 1.0, Test Accuracy: 0.0012731481110677123


 12%|█▏        | 9041/75000 [00:27<03:17, 333.16it/s]

Epoch 9000: Train Loss: 8.50658259261561e-06, Test Loss: 34.49137157863076, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 13%|█▎        | 9545/75000 [00:28<03:22, 323.61it/s]

Epoch 9500: Train Loss: 8.461515977252199e-06, Test Loss: 34.23239148948205, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 13%|█▎        | 10049/75000 [00:30<03:19, 325.63it/s]

Epoch 10000: Train Loss: 8.41818154212479e-06, Test Loss: 33.9824705600459, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 14%|█▍        | 10544/75000 [00:31<03:20, 321.84it/s]

Epoch 10500: Train Loss: 8.381760424029346e-06, Test Loss: 33.76065107334918, Train Accuracy: 1.0, Test Accuracy: 0.0012731481110677123


 15%|█▍        | 11036/75000 [00:33<03:20, 318.44it/s]

Epoch 11000: Train Loss: 8.342107687623325e-06, Test Loss: 33.53585156533459, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 15%|█▌        | 11565/75000 [00:35<03:17, 320.92it/s]

Epoch 11500: Train Loss: 8.3097532155924e-06, Test Loss: 33.32277491243013, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 16%|█▌        | 12034/75000 [00:36<03:11, 328.43it/s]

Epoch 12000: Train Loss: 8.271175484665112e-06, Test Loss: 33.10997628470564, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 17%|█▋        | 12537/75000 [00:37<03:09, 328.97it/s]

Epoch 12500: Train Loss: 8.235231204601857e-06, Test Loss: 32.90317704753187, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 17%|█▋        | 13047/75000 [00:39<03:07, 330.72it/s]

Epoch 13000: Train Loss: 8.200356792759745e-06, Test Loss: 32.710834733718066, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 18%|█▊        | 13553/75000 [00:41<03:07, 327.41it/s]

Epoch 13500: Train Loss: 8.167303766550059e-06, Test Loss: 32.51174707043542, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 19%|█▊        | 14054/75000 [00:42<03:05, 328.57it/s]

Epoch 14000: Train Loss: 8.135672509914182e-06, Test Loss: 32.32238594065248, Train Accuracy: 1.0, Test Accuracy: 0.00138888880610466


 19%|█▉        | 14551/75000 [00:44<03:05, 326.19it/s]

Epoch 14500: Train Loss: 8.102903941423043e-06, Test Loss: 32.115956830488095, Train Accuracy: 1.0, Test Accuracy: 0.0015046296175569296


 20%|██        | 15052/75000 [00:45<03:02, 327.75it/s]

Epoch 15000: Train Loss: 8.065833351070415e-06, Test Loss: 31.890347544636757, Train Accuracy: 1.0, Test Accuracy: 0.00138888880610466


 21%|██        | 15563/75000 [00:47<02:58, 333.08it/s]

Epoch 15500: Train Loss: 8.026087279497744e-06, Test Loss: 31.649101277711008, Train Accuracy: 1.0, Test Accuracy: 0.0015046296175569296


 21%|██▏       | 16039/75000 [00:48<02:58, 330.08it/s]

Epoch 16000: Train Loss: 7.990219937112013e-06, Test Loss: 31.412466731245452, Train Accuracy: 1.0, Test Accuracy: 0.0015046296175569296


 22%|██▏       | 16549/75000 [00:50<02:57, 329.77it/s]

Epoch 16500: Train Loss: 7.946994950190117e-06, Test Loss: 31.1731525426386, Train Accuracy: 1.0, Test Accuracy: 0.0012731481110677123


 23%|██▎       | 17059/75000 [00:51<02:55, 329.84it/s]

Epoch 17000: Train Loss: 7.912758865476102e-06, Test Loss: 30.9362042017261, Train Accuracy: 1.0, Test Accuracy: 0.0012731481110677123


 23%|██▎       | 17535/75000 [00:53<02:52, 332.91it/s]

Epoch 17500: Train Loss: 7.87001976760949e-06, Test Loss: 30.692012864629923, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 24%|██▍       | 18045/75000 [00:54<02:51, 331.81it/s]

Epoch 18000: Train Loss: 7.843372741066842e-06, Test Loss: 30.466719066983185, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 25%|██▍       | 18555/75000 [00:56<02:50, 331.92it/s]

Epoch 18500: Train Loss: 7.813346975554372e-06, Test Loss: 30.263493982474394, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 25%|██▌       | 19065/75000 [00:57<02:48, 332.58it/s]

Epoch 19000: Train Loss: 7.786767561718433e-06, Test Loss: 30.054695287093264, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 26%|██▌       | 19541/75000 [00:59<02:47, 331.86it/s]

Epoch 19500: Train Loss: 7.762824314763146e-06, Test Loss: 29.86086674630068, Train Accuracy: 1.0, Test Accuracy: 0.00138888880610466


 27%|██▋       | 20051/75000 [01:00<02:45, 332.19it/s]

Epoch 20000: Train Loss: 7.733572345196399e-06, Test Loss: 29.683394518470973, Train Accuracy: 1.0, Test Accuracy: 0.0012731481110677123


 27%|██▋       | 20561/75000 [01:02<02:44, 331.90it/s]

Epoch 20500: Train Loss: 7.707896399875263e-06, Test Loss: 29.514493396346417, Train Accuracy: 1.0, Test Accuracy: 0.0012731481110677123


 28%|██▊       | 21037/75000 [01:03<02:42, 331.80it/s]

Epoch 21000: Train Loss: 7.679609528878402e-06, Test Loss: 29.34504568241238, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 29%|██▊       | 21547/75000 [01:05<02:40, 332.56it/s]

Epoch 21500: Train Loss: 7.6579333982688e-06, Test Loss: 29.1617596875596, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 29%|██▉       | 22057/75000 [01:06<02:39, 332.47it/s]

Epoch 22000: Train Loss: 7.632726591409446e-06, Test Loss: 28.981383395699144, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 30%|███       | 22537/75000 [01:08<02:35, 337.49it/s]

Epoch 22500: Train Loss: 7.609188975741762e-06, Test Loss: 28.81239064009929, Train Accuracy: 1.0, Test Accuracy: 0.0012731481110677123


 31%|███       | 23053/75000 [01:09<02:37, 330.65it/s]

Epoch 23000: Train Loss: 7.5839525819489534e-06, Test Loss: 28.632542509547275, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 31%|███▏      | 23563/75000 [01:11<02:34, 332.03it/s]

Epoch 23500: Train Loss: 7.558610574122402e-06, Test Loss: 28.442976767186426, Train Accuracy: 1.0, Test Accuracy: 0.0016203703125938773


 32%|███▏      | 24039/75000 [01:12<02:34, 330.17it/s]

Epoch 24000: Train Loss: 7.539640716760936e-06, Test Loss: 28.2618926163763, Train Accuracy: 1.0, Test Accuracy: 0.0016203703125938773


 33%|███▎      | 24549/75000 [01:14<02:32, 331.54it/s]

Epoch 24500: Train Loss: 7.520757180362572e-06, Test Loss: 28.103247645970754, Train Accuracy: 1.0, Test Accuracy: 0.0015046296175569296


 33%|███▎      | 25059/75000 [01:15<02:30, 331.98it/s]

Epoch 25000: Train Loss: 7.5037695811678605e-06, Test Loss: 27.97218658586363, Train Accuracy: 1.0, Test Accuracy: 0.001967592630535364


 34%|███▍      | 25535/75000 [01:17<02:29, 330.74it/s]

Epoch 25500: Train Loss: 7.472600418367857e-06, Test Loss: 27.80970157428469, Train Accuracy: 1.0, Test Accuracy: 0.002430555410683155


 35%|███▍      | 26043/75000 [01:18<02:28, 330.65it/s]

Epoch 26000: Train Loss: 7.445270258515703e-06, Test Loss: 27.61963518039626, Train Accuracy: 1.0, Test Accuracy: 0.0025462962221354246


 35%|███▌      | 26552/75000 [01:20<02:27, 327.47it/s]

Epoch 26500: Train Loss: 7.4162403797060065e-06, Test Loss: 27.36292946594753, Train Accuracy: 1.0, Test Accuracy: 0.002662037033587694


 36%|███▌      | 27060/75000 [01:21<02:24, 332.71it/s]

Epoch 27000: Train Loss: 7.382393125900909e-06, Test Loss: 27.091228031967475, Train Accuracy: 1.0, Test Accuracy: 0.003009259235113859


 37%|███▋      | 27536/75000 [01:23<02:23, 331.53it/s]

Epoch 27500: Train Loss: 7.328589038126993e-06, Test Loss: 26.813108132164587, Train Accuracy: 1.0, Test Accuracy: 0.0028935184236615896


 37%|███▋      | 28046/75000 [01:24<02:22, 329.09it/s]

Epoch 28000: Train Loss: 7.253160147275028e-06, Test Loss: 26.340089210303564, Train Accuracy: 1.0, Test Accuracy: 0.0028935184236615896


 38%|███▊      | 28556/75000 [01:26<02:20, 330.35it/s]

Epoch 28500: Train Loss: 6.866271723380674e-06, Test Loss: 24.727168960491685, Train Accuracy: 1.0, Test Accuracy: 0.003935185261070728


 39%|███▉      | 29066/75000 [01:27<02:17, 334.65it/s]

Epoch 29000: Train Loss: 6.778686261361496e-06, Test Loss: 23.38908695137215, Train Accuracy: 1.0, Test Accuracy: 0.005092592444270849


 39%|███▉      | 29548/75000 [01:29<02:30, 301.16it/s]

Epoch 29500: Train Loss: 6.753517312819167e-06, Test Loss: 22.87147307175749, Train Accuracy: 1.0, Test Accuracy: 0.006134259048849344


 40%|████      | 30044/75000 [01:30<02:28, 301.87it/s]

Epoch 30000: Train Loss: 6.703964682782795e-06, Test Loss: 22.53354464050988, Train Accuracy: 1.0, Test Accuracy: 0.007175925653427839


 41%|████      | 30540/75000 [01:32<02:25, 305.11it/s]

Epoch 30500: Train Loss: 6.624811643447796e-06, Test Loss: 22.17765734690029, Train Accuracy: 1.0, Test Accuracy: 0.006481481250375509


 41%|████▏     | 31048/75000 [01:34<02:11, 333.58it/s]

Epoch 31000: Train Loss: 6.609254359843747e-06, Test Loss: 21.880568016570763, Train Accuracy: 1.0, Test Accuracy: 0.007060185074806213


 42%|████▏     | 31558/75000 [01:35<02:09, 336.06it/s]

Epoch 31500: Train Loss: 6.568767771482509e-06, Test Loss: 21.63429638300545, Train Accuracy: 1.0, Test Accuracy: 0.007060185074806213


 43%|████▎     | 32034/75000 [01:36<02:09, 332.44it/s]

Epoch 32000: Train Loss: 6.538211501695077e-06, Test Loss: 21.419446053111486, Train Accuracy: 1.0, Test Accuracy: 0.007175925653427839


 43%|████▎     | 32544/75000 [01:38<02:06, 334.35it/s]

Epoch 32500: Train Loss: 6.506162771517954e-06, Test Loss: 21.21326760484711, Train Accuracy: 1.0, Test Accuracy: 0.0076388888992369175


 44%|████▍     | 33061/75000 [01:40<02:17, 304.85it/s]

Epoch 33000: Train Loss: 6.472669201983297e-06, Test Loss: 20.98683300858886, Train Accuracy: 1.0, Test Accuracy: 0.007986110635101795


 45%|████▍     | 33552/75000 [01:41<02:09, 319.51it/s]

Epoch 33500: Train Loss: 6.445430908914143e-06, Test Loss: 20.788384640324466, Train Accuracy: 1.0, Test Accuracy: 0.00902777723968029


 45%|████▌     | 34044/75000 [01:43<02:06, 325.02it/s]

Epoch 34000: Train Loss: 6.42068273578108e-06, Test Loss: 20.60694035701498, Train Accuracy: 1.0, Test Accuracy: 0.009374999441206455


 46%|████▌     | 34548/75000 [01:44<02:02, 328.89it/s]

Epoch 34500: Train Loss: 6.3931798190469635e-06, Test Loss: 20.43196017378151, Train Accuracy: 1.0, Test Accuracy: 0.009490740485489368


 47%|████▋     | 35058/75000 [01:46<02:00, 330.57it/s]

Epoch 35000: Train Loss: 6.369967330538463e-06, Test Loss: 20.26740408285561, Train Accuracy: 1.0, Test Accuracy: 0.010185184888541698


 47%|████▋     | 35534/75000 [01:47<01:59, 330.83it/s]

Epoch 35500: Train Loss: 6.343500452344798e-06, Test Loss: 20.08430434209972, Train Accuracy: 1.0, Test Accuracy: 0.010648148134350777


 48%|████▊     | 36044/75000 [01:49<01:57, 330.76it/s]

Epoch 36000: Train Loss: 6.313479815200927e-06, Test Loss: 19.879143133593793, Train Accuracy: 1.0, Test Accuracy: 0.011689814738929272


 49%|████▊     | 36554/75000 [01:50<01:56, 329.43it/s]

Epoch 36500: Train Loss: 6.272312754805183e-06, Test Loss: 19.646603377999355, Train Accuracy: 1.0, Test Accuracy: 0.011574073694646358


 49%|████▉     | 37055/75000 [01:52<01:55, 328.37it/s]

Epoch 37000: Train Loss: 6.234207488279559e-06, Test Loss: 19.358881351174368, Train Accuracy: 1.0, Test Accuracy: 0.01215277798473835


 50%|█████     | 37564/75000 [01:53<01:53, 328.68it/s]

Epoch 37500: Train Loss: 6.20565459422312e-06, Test Loss: 19.08614042866305, Train Accuracy: 1.0, Test Accuracy: 0.012384259141981602


 51%|█████     | 38047/75000 [01:55<01:48, 339.05it/s]

Epoch 38000: Train Loss: 6.172658027457068e-06, Test Loss: 18.820226449169628, Train Accuracy: 1.0, Test Accuracy: 0.01215277798473835


 51%|█████▏    | 38559/75000 [01:56<01:49, 332.74it/s]

Epoch 38500: Train Loss: 6.144682063424371e-06, Test Loss: 18.55695150350506, Train Accuracy: 1.0, Test Accuracy: 0.011689814738929272


 52%|█████▏    | 39035/75000 [01:58<01:47, 334.01it/s]

Epoch 39000: Train Loss: 6.113745422392649e-06, Test Loss: 18.304584442704417, Train Accuracy: 1.0, Test Accuracy: 0.012500000186264515


 53%|█████▎    | 39545/75000 [01:59<01:46, 332.53it/s]

Epoch 39500: Train Loss: 6.085818153365357e-06, Test Loss: 18.06142377887689, Train Accuracy: 1.0, Test Accuracy: 0.013078703545033932


 53%|█████▎    | 40055/75000 [02:01<01:44, 333.44it/s]

Epoch 40000: Train Loss: 6.056025373502686e-06, Test Loss: 17.829583619040406, Train Accuracy: 1.0, Test Accuracy: 0.013888888992369175


 54%|█████▍    | 40565/75000 [02:02<01:43, 331.64it/s]

Epoch 40500: Train Loss: 6.029623147600498e-06, Test Loss: 17.61176409454993, Train Accuracy: 1.0, Test Accuracy: 0.015046295709908009


 55%|█████▍    | 41041/75000 [02:04<01:42, 331.63it/s]

Epoch 41000: Train Loss: 6.0053229759146625e-06, Test Loss: 17.407382882643347, Train Accuracy: 1.0, Test Accuracy: 0.015856482088565826


 55%|█████▌    | 41551/75000 [02:05<01:40, 332.54it/s]

Epoch 41500: Train Loss: 5.983003571144594e-06, Test Loss: 17.21130604893663, Train Accuracy: 1.0, Test Accuracy: 0.016087962314486504


 56%|█████▌    | 42061/75000 [02:07<01:38, 334.45it/s]

Epoch 42000: Train Loss: 5.961012355919118e-06, Test Loss: 17.0355843945154, Train Accuracy: 1.0, Test Accuracy: 0.01666666567325592


 57%|█████▋    | 42537/75000 [02:08<01:37, 333.95it/s]

Epoch 42500: Train Loss: 5.941329721688369e-06, Test Loss: 16.876215237952806, Train Accuracy: 1.0, Test Accuracy: 0.016898147761821747


 57%|█████▋    | 43047/75000 [02:10<01:37, 329.01it/s]

Epoch 43000: Train Loss: 5.918655172883252e-06, Test Loss: 16.712560947550067, Train Accuracy: 1.0, Test Accuracy: 0.017129629850387573


 58%|█████▊    | 43557/75000 [02:11<01:33, 334.59it/s]

Epoch 43500: Train Loss: 5.896129222321611e-06, Test Loss: 16.536363207333466, Train Accuracy: 1.0, Test Accuracy: 0.017824074253439903


 59%|█████▉    | 44067/75000 [02:13<01:32, 334.35it/s]

Epoch 44000: Train Loss: 5.872568202247375e-06, Test Loss: 16.35011642162886, Train Accuracy: 1.0, Test Accuracy: 0.018981480970978737


 59%|█████▉    | 44543/75000 [02:14<01:31, 333.83it/s]

Epoch 44500: Train Loss: 5.8501600175193655e-06, Test Loss: 16.16388637715165, Train Accuracy: 1.0, Test Accuracy: 0.020254628732800484


 60%|██████    | 45052/75000 [02:16<01:31, 328.43it/s]

Epoch 45000: Train Loss: 5.829058720800806e-06, Test Loss: 15.96083811232319, Train Accuracy: 1.0, Test Accuracy: 0.02152777649462223


 61%|██████    | 45560/75000 [02:17<01:29, 328.09it/s]

Epoch 45500: Train Loss: 5.8023488736319265e-06, Test Loss: 15.731489204111469, Train Accuracy: 1.0, Test Accuracy: 0.02291666716337204


 61%|██████▏   | 46035/75000 [02:19<01:28, 327.95it/s]

Epoch 46000: Train Loss: 5.778417176731881e-06, Test Loss: 15.498368319486792, Train Accuracy: 1.0, Test Accuracy: 0.02395833283662796


 62%|██████▏   | 46545/75000 [02:20<01:26, 328.90it/s]

Epoch 46500: Train Loss: 5.7485997991579845e-06, Test Loss: 15.25763803906338, Train Accuracy: 1.0, Test Accuracy: 0.02569444477558136


 63%|██████▎   | 47052/75000 [02:22<01:25, 327.67it/s]

Epoch 47000: Train Loss: 5.71902644418635e-06, Test Loss: 14.994327447629551, Train Accuracy: 1.0, Test Accuracy: 0.027314813807606697


 63%|██████▎   | 47560/75000 [02:23<01:23, 329.15it/s]

Epoch 47500: Train Loss: 5.682111042755983e-06, Test Loss: 14.685775060844042, Train Accuracy: 1.0, Test Accuracy: 0.02916666679084301


 64%|██████▍   | 48036/75000 [02:25<01:22, 328.80it/s]

Epoch 48000: Train Loss: 5.649667611003167e-06, Test Loss: 14.348951723577779, Train Accuracy: 1.0, Test Accuracy: 0.03275462985038757


 65%|██████▍   | 48545/75000 [02:26<01:19, 332.64it/s]

Epoch 48500: Train Loss: 5.608082605487589e-06, Test Loss: 13.962185453268104, Train Accuracy: 1.0, Test Accuracy: 0.03634259104728699


 65%|██████▌   | 49055/75000 [02:28<01:18, 329.94it/s]

Epoch 49000: Train Loss: 5.562307738428344e-06, Test Loss: 13.519015742211737, Train Accuracy: 1.0, Test Accuracy: 0.03807870298624039


 66%|██████▌   | 49565/75000 [02:29<01:16, 330.52it/s]

Epoch 49500: Train Loss: 5.502751018378544e-06, Test Loss: 13.025565863303097, Train Accuracy: 1.0, Test Accuracy: 0.04328703507781029


 67%|██████▋   | 50041/75000 [02:31<01:15, 330.37it/s]

Epoch 50000: Train Loss: 5.427335541992286e-06, Test Loss: 12.398843871511401, Train Accuracy: 1.0, Test Accuracy: 0.04884259030222893


 67%|██████▋   | 50551/75000 [02:32<01:13, 330.85it/s]

Epoch 50500: Train Loss: 5.3553083263470075e-06, Test Loss: 11.72297639207231, Train Accuracy: 1.0, Test Accuracy: 0.056365739554166794


 68%|██████▊   | 51050/75000 [02:34<01:13, 327.15it/s]

Epoch 51000: Train Loss: 5.271993972304082e-06, Test Loss: 10.995280503199856, Train Accuracy: 1.0, Test Accuracy: 0.0667824074625969


 69%|██████▊   | 51559/75000 [02:35<01:11, 329.78it/s]

Epoch 51500: Train Loss: 5.214591228689342e-06, Test Loss: 10.304557355796362, Train Accuracy: 1.0, Test Accuracy: 0.08136574178934097


 69%|██████▉   | 52035/75000 [02:37<01:09, 330.50it/s]

Epoch 52000: Train Loss: 5.1546502519967825e-06, Test Loss: 9.673002143439602, Train Accuracy: 1.0, Test Accuracy: 0.09236110746860504


 70%|███████   | 52545/75000 [02:38<01:08, 329.82it/s]

Epoch 52500: Train Loss: 5.083351074943411e-06, Test Loss: 8.999046109956504, Train Accuracy: 1.0, Test Accuracy: 0.11064814776182175


 71%|███████   | 53055/75000 [02:40<01:06, 330.37it/s]

Epoch 53000: Train Loss: 5.009791931757856e-06, Test Loss: 8.313636907846908, Train Accuracy: 1.0, Test Accuracy: 0.13043981790542603


 71%|███████▏  | 53565/75000 [02:41<01:05, 328.77it/s]

Epoch 53500: Train Loss: 4.901697692593372e-06, Test Loss: 7.550097084414532, Train Accuracy: 1.0, Test Accuracy: 0.15324074029922485


 72%|███████▏  | 54040/75000 [02:43<01:04, 326.70it/s]

Epoch 54000: Train Loss: 4.793080132329119e-06, Test Loss: 6.672386415046485, Train Accuracy: 1.0, Test Accuracy: 0.1866898089647293


 73%|███████▎  | 54549/75000 [02:44<01:02, 328.25it/s]

Epoch 54500: Train Loss: 4.704302810143485e-06, Test Loss: 5.868192245731048, Train Accuracy: 1.0, Test Accuracy: 0.2222222238779068


 73%|███████▎  | 55059/75000 [02:46<01:00, 328.84it/s]

Epoch 55000: Train Loss: 4.573785034272513e-06, Test Loss: 5.021286351653474, Train Accuracy: 1.0, Test Accuracy: 0.2674768567085266


 74%|███████▍  | 55565/75000 [02:48<00:58, 331.23it/s]

Epoch 55500: Train Loss: 4.344789847361807e-06, Test Loss: 3.9317428781977015, Train Accuracy: 1.0, Test Accuracy: 0.34143519401550293


 75%|███████▍  | 56041/75000 [02:49<00:57, 328.56it/s]

Epoch 56000: Train Loss: 4.069025077045527e-06, Test Loss: 2.6533171806389544, Train Accuracy: 1.0, Test Accuracy: 0.46655091643333435


 75%|███████▌  | 56551/75000 [02:50<00:55, 331.65it/s]

Epoch 56500: Train Loss: 3.756941316754946e-06, Test Loss: 1.4274187227466157, Train Accuracy: 1.0, Test Accuracy: 0.6435185074806213


 76%|███████▌  | 57062/75000 [02:52<00:53, 336.48it/s]

Epoch 57000: Train Loss: 3.5212546965505825e-06, Test Loss: 0.6090847056514995, Train Accuracy: 1.0, Test Accuracy: 0.8148148059844971


 77%|███████▋  | 57544/75000 [02:53<00:51, 337.94it/s]

Epoch 57500: Train Loss: 3.138368893575349e-06, Test Loss: 0.1445243486648616, Train Accuracy: 1.0, Test Accuracy: 0.9521990418434143


 77%|███████▋  | 58056/75000 [02:55<00:51, 331.45it/s]

Epoch 58000: Train Loss: 2.872058941821309e-06, Test Loss: 0.013671531698568185, Train Accuracy: 1.0, Test Accuracy: 0.9967592358589172


 78%|███████▊  | 58566/75000 [02:56<00:49, 331.85it/s]

Epoch 58500: Train Loss: 2.6782472323079438e-06, Test Loss: 0.0009168773949502967, Train Accuracy: 1.0, Test Accuracy: 1.0


 79%|███████▊  | 59042/75000 [02:58<00:48, 331.74it/s]

Epoch 59000: Train Loss: 2.4729377103018777e-06, Test Loss: 9.8915187475948e-05, Train Accuracy: 1.0, Test Accuracy: 1.0


 79%|███████▉  | 59552/75000 [02:59<00:46, 333.12it/s]

Epoch 59500: Train Loss: 2.3553346298043783e-06, Test Loss: 2.687931886452371e-05, Train Accuracy: 1.0, Test Accuracy: 1.0


 80%|████████  | 60062/75000 [03:01<00:44, 333.22it/s]

Epoch 60000: Train Loss: 2.2618374419795676e-06, Test Loss: 1.298555439462764e-05, Train Accuracy: 1.0, Test Accuracy: 1.0


 81%|████████  | 60538/75000 [03:02<00:43, 331.76it/s]

Epoch 60500: Train Loss: 2.181857972790403e-06, Test Loss: 8.589235791436772e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 81%|████████▏ | 61046/75000 [03:04<00:42, 327.16it/s]

Epoch 61000: Train Loss: 2.1084638522275685e-06, Test Loss: 6.704509864876695e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 82%|████████▏ | 61560/75000 [03:05<00:40, 335.45it/s]

Epoch 61500: Train Loss: 2.0340739308289346e-06, Test Loss: 5.51413766081908e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 83%|████████▎ | 62034/75000 [03:07<00:39, 326.93it/s]

Epoch 62000: Train Loss: 1.9650580589948158e-06, Test Loss: 4.6891417668577825e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 83%|████████▎ | 62538/75000 [03:08<00:38, 327.74it/s]

Epoch 62500: Train Loss: 1.909919170712326e-06, Test Loss: 4.165876455975242e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 84%|████████▍ | 63045/75000 [03:10<00:36, 329.25it/s]

Epoch 63000: Train Loss: 1.8488617386053036e-06, Test Loss: 3.744434844542107e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 85%|████████▍ | 63555/75000 [03:12<00:34, 330.84it/s]

Epoch 63500: Train Loss: 1.7886234955720387e-06, Test Loss: 3.4325214065642304e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 85%|████████▌ | 64066/75000 [03:13<00:32, 334.91it/s]

Epoch 64000: Train Loss: 1.7494687011550233e-06, Test Loss: 3.26887363146542e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 86%|████████▌ | 64544/75000 [03:15<00:34, 302.95it/s]

Epoch 64500: Train Loss: 1.7216043529608916e-06, Test Loss: 3.204912657967898e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 87%|████████▋ | 65036/75000 [03:16<00:30, 326.18it/s]

Epoch 65000: Train Loss: 1.706215560263108e-06, Test Loss: 3.173177458927927e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 87%|████████▋ | 65534/75000 [03:18<00:29, 326.29it/s]

Epoch 65500: Train Loss: 1.692466582949674e-06, Test Loss: 3.133625392339933e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 88%|████████▊ | 66065/75000 [03:19<00:27, 326.56it/s]

Epoch 66000: Train Loss: 1.6770921792472609e-06, Test Loss: 3.0675506902054027e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 89%|████████▊ | 66539/75000 [03:21<00:25, 331.88it/s]

Epoch 66500: Train Loss: 1.6647249756787463e-06, Test Loss: 3.0290293763099327e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 89%|████████▉ | 67052/75000 [03:22<00:23, 333.02it/s]

Epoch 67000: Train Loss: 1.650439941519867e-06, Test Loss: 2.997267465970018e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 90%|█████████ | 67562/75000 [03:24<00:22, 331.36it/s]

Epoch 67500: Train Loss: 1.6356889859205959e-06, Test Loss: 2.967215449236484e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 91%|█████████ | 68038/75000 [03:25<00:21, 330.94it/s]

Epoch 68000: Train Loss: 1.6241817105707107e-06, Test Loss: 2.9454624960086443e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 91%|█████████▏| 68546/75000 [03:27<00:19, 327.57it/s]

Epoch 68500: Train Loss: 1.6145126082200447e-06, Test Loss: 2.9298268551339112e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 92%|█████████▏| 69056/75000 [03:28<00:17, 331.39it/s]

Epoch 69000: Train Loss: 1.6091485718025335e-06, Test Loss: 2.9233674794380216e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 93%|█████████▎| 69564/75000 [03:30<00:16, 328.80it/s]

Epoch 69500: Train Loss: 1.6027954647796705e-06, Test Loss: 2.9161515079649815e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 93%|█████████▎| 70042/75000 [03:31<00:14, 336.31it/s]

Epoch 70000: Train Loss: 1.5987439150519777e-06, Test Loss: 2.9131157647091446e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 94%|█████████▍| 70552/75000 [03:33<00:13, 331.43it/s]

Epoch 70500: Train Loss: 1.5940250762845288e-06, Test Loss: 2.9082164326665877e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 95%|█████████▍| 71062/75000 [03:34<00:11, 329.31it/s]

Epoch 71000: Train Loss: 1.5908218761586038e-06, Test Loss: 2.904543327131442e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 95%|█████████▌| 71537/75000 [03:36<00:10, 331.79it/s]

Epoch 71500: Train Loss: 1.5880382052228119e-06, Test Loss: 2.901215212829023e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 96%|█████████▌| 72047/75000 [03:37<00:08, 331.40it/s]

Epoch 72000: Train Loss: 1.5863402038330937e-06, Test Loss: 2.9010685396515356e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 97%|█████████▋| 72558/75000 [03:39<00:07, 336.61it/s]

Epoch 72500: Train Loss: 1.5843413327071396e-06, Test Loss: 2.900596816474384e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 97%|█████████▋| 73041/75000 [03:40<00:05, 337.39it/s]

Epoch 73000: Train Loss: 1.581574648014948e-06, Test Loss: 2.8993523342419734e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 98%|█████████▊| 73551/75000 [03:42<00:04, 331.85it/s]

Epoch 73500: Train Loss: 1.5802146834775628e-06, Test Loss: 2.8988171902115067e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 99%|█████████▊| 74061/75000 [03:43<00:02, 330.49it/s]

Epoch 74000: Train Loss: 1.5809986198968154e-06, Test Loss: 2.902854505176685e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 99%|█████████▉| 74537/75000 [03:45<00:01, 330.25it/s]

Epoch 74500: Train Loss: 1.5799228038793309e-06, Test Loss: 2.9033055499829177e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


100%|██████████| 75000/75000 [03:46<00:00, 331.03it/s]


Let's plot loss and accuracy. We see the model groks! It initially memorises the training data, and much later on (around epoch 70k) suddenly generalises to the test set.

In [145]:
x = list(range(0, training_cfg["num_epochs"], checkpoint_every))
lines([train_losses, test_losses], x=x, labels=["Train Loss", "Test Loss"], xaxis="Epoch", yaxis="Loss", title="Loss vs Epoch", log_y=True)
lines([train_accs, test_accs], x=x, labels=["Train Accuracy", "Test Accuracy"], xaxis="Epoch", yaxis="Accuracy", title="Accuracy vs Epoch")

# Interpretability

Now let's try to interpret the model. To do so, we need to know some things about representation theory. Let's first create a class to hold information about representations. It will share some attributes with the group class. We will subclass from this class to create a class for each representation.

What information do we need, and how do we store it? 

Let $n$ be the order of group $G$. Each representation is a set $n$ $d \times d$ matrices. We store this in our representation objects as tensors of shape $n \times d^2$. 

We also compute what we call an 'orthogonal representation'. I'll explain why this is necessary later on.

Finally we compute the characters $\chi(abc^{-1})$ for each pair of inputs $a,b$ and output $c$. We store this in a tensor of shape $n \times n \times n$. This character tensor is used later to compute logit similarity.

## Setup: Defining Representations

In [146]:
class Representation():
    """
    Base class for all representations.
    """

    def __init__(self, compute_rep_params, index, order, multiplication_table, inverses, all_data, group_acronym, irrep=True):
        """
        Initialise the symmetric group representation.

        Args:
            compute_rep_params (tuple): representation specific parameters required for computing the representation
            index (int): group index in family
            order (int): order of the group
            multiplication_table (torch.tensor): square (group.order, group.order) tensor of group multiplication table 
            inverses (torch.tensor): vector of group inverses
            all_data ()
            irrep (Boolean)
        """

        self.index = index
        self.order = order
        self.multiplication_table = multiplication_table
        self.inverses = inverses
        self.all_data = all_data
        self.group_acronym = group_acronym

        # TODO: this is needed to get the dimension of generated representations - think up a better way of doing this
        self.compute_rep_params = compute_rep_params

        self.dim = self.get_rep_dim()

        self.rep = self.compute_rep(*compute_rep_params)

        if irrep:
            self.orth_rep = self.compute_orth_rep(self.rep)
            if self.friendly_name != 'trivial':
                self.logit_trace_tensor_cube = self.compute_logit_trace_tensor_cube()


    def get_rep_dim(self):
        return NotImplementedError

    def compute_rep(self):
        """
        Compute the representation. Must be implemented by child class.

        Raises:
            NotImplementedError
        """
        raise NotImplementedError
    
    def compute_orth_rep(self, rep):
        """
        Use QR decomposition to orthogonalise the representation but retain the subspace spanned by the columns. 

        Args:
            rep (torch.tensor): (group.order, dim^2) tensor of representation

        Returns:
            torch.tensor: (group.order, dim^2) tensor with orthonormal columns en
        """
        
        orth_rep = rep.reshape(self.order, self.dim * self.dim)
        orth_rep = torch.linalg.qr(orth_rep)[0]
        return orth_rep


    def compute_logit_trace_tensor_cube(self):
        """
        Under the hypothesis, the network computes tr(\rho(x)\rho(y)\rho(z^-1)) for some representation \rho.
        This function computes this trace tensor cube for a given representation
        
        Returns:
            torch.tensor: (group.order^3) trace tensor cube
        """
        print(f'Computing trace tensor cube for {self.friendly_name} representation')
        filename = f'utils/cache/{self.group_acronym}{self.index}/{self.group_acronym}{self.index}_{self.friendly_name}_trace_tensor_cube.pt'
        if os.path.exists(filename):
            print('... loading from file')
            t = torch.load(filename)
            return t 
        N = self.all_data.shape[0]
        t = torch.zeros((self.order*self.order, self.order), dtype=torch.float).cuda()
        for i in tqdm(range(N)):
            x = self.all_data[i, 0]
            y = self.all_data[i, 1]
            xy = self.multiplication_table[x, y]
            for z_idx in range(self.order):
                xyz = self.multiplication_table[xy, self.inverses[z_idx]]
                t[i, z_idx] = torch.trace(self.rep[xyz])
        t = t.reshape(self.order, self.order, self.order)
        f = open(filename, 'wb')
        torch.save(t, f)
        return t 

Now we construct all the irreducible representations of S5. One can find a full list [here](https://groupprops.subwiki.org/wiki/Linear_representation_theory_of_symmetric_group:S5). Feel free to skip over this construction if you're not interested in the details.

First, the boring trivial reprsentation. This is one dimensional - that is $d=1$, and all the elements map to the matrix $[1]$. We pass information needed to compute and initialise the representation from the child class to the parent class via `compute_rep_params` and `init_rep_params`.

### Constructing Irreps of S5


In [147]:
class TrivialRepresentation(Representation):
    """
    The trivial representation of the symmetric group.
    """
    def __init__(self, compute_rep_params, init_rep_params):
        """
        Initialise the trivial representation. 

        Args:
            compute_rep_params (list): idx_to_perm function required to compute the representation
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
        """
        self.friendly_name = 'trivial'
        super().__init__(compute_rep_params, **init_rep_params, irrep=True)
    
    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return 1

    def compute_rep(self):
        """
        Compute the trivial representation.

        Args:
            idx_to_perm (function): function to convert an index to a permutation

        Returns:
            torch.tensor: (group.order, 1) tensor of representation
        """
        return torch.ones((self.order, 1, 1), dtype=torch.float).cuda()
    

# we need some data from the group to initialise representations
rep_params = {
            'index': group.index,
            'order': group.order,
            'multiplication_table': group.multiplication_table,
            'inverses': group.inverses,
            'all_data': group.all_data,
            'group_acronym': group.acronym
        }

trivial_rep = TrivialRepresentation(compute_rep_params=[], init_rep_params=rep_params)

Next, the sign representation. This is also 1 dimensional, but the elements map to $[1]$ or $[-1]$ depending on whether the permutation is even or odd.

In [148]:
class SignRepresentation(Representation):
    """
    Initialise the sign representation of the symmetric group.

    """
    def __init__(self, compute_rep_params, init_rep_params):
        """
        Initialise the sign representation. 

        Args:
            compute_rep_params (list): list consisting of the signatures object required to compute the representation
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
        """
        self.friendly_name = 'sign'
        super().__init__(compute_rep_params, **init_rep_params)

    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return 1

    def compute_rep(self, signatures):
        """
        Compute the sign representation from the signatures.

        Args:
            signatures (torch.tensor): (group.order, 1) tensor of signatures

        Returns:
            torch.tensor: (group.order, 1, 1) tensor of sign representations
        """
        rep = torch.zeros(self.order, 1, 1).cuda()
        rep[:, 0, 0] = signatures
        return rep

sign_rep = SignRepresentation([group.signatures], rep_params)


Computing trace tensor cube for sign representation
... loading from file


Next, we construct the natural representation. This isn't actually an irreducible representation, but is easy to construct. We can later decompose it into a direct sum of the trivial representation and the `standard` representation. The natural representation itself is $d=5$, with group elements mapping naturally to the 5x5 [permutation matrices](https://en.wikipedia.org/wiki/Permutation_matrix) (with a 1 in each row and column, and 0s elsewhere).

In [149]:
class NaturalRepresentation(Representation):
    """
    Compute the natural representation of the symmetric group.
    """
    def __init__(self, compute_rep_params, init_rep_params):
        """
        Initialise the natural representation. 

        Args:
            compute_rep_params (list): idx_to_perm function required to compute the representation
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
        """
        self.friendly_name = 'natural'
        super().__init__(compute_rep_params, **init_rep_params, irrep=False)
    
    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return self.index

    def compute_rep(self, idx_to_perm):
        """
        Compute the natural representation by directly computing permutation matrices

        Args:
            idx_to_perm (function): Function that takes an index and returns the corresponding permutation object

        Returns:
            torch.tensor: (group.order, group.index, group.index) tensor of permutation matrices for each group element
        """
        idx = list(np.linspace(0, self.index-1, self.index))
        rep = torch.zeros(self.order, self.index, self.index).cuda()
        for x in range(self.order):
            rep[x, idx, idx_to_perm(x)(idx)] = 1
        return rep

natural_rep = NaturalRepresentation([group.idx_to_perm], rep_params)


an integer is required (got type numpy.float64).  Implicit conversion to integers using __int__ is deprecated, and may be removed in a future version of Python.



In [150]:
class StandardRepresentation(Representation):
    """
    Generate the standard representation of the symmetric group.

    """
    def __init__(self, compute_rep_params, init_rep_params):
        """
        Initialise the standard representation. 

        Args:
            compute_rep_params (list): list containing natural_reps object necessary to calculate the standard representation
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
        """
        self.friendly_name = 'standard'
        super().__init__(compute_rep_params, **init_rep_params)

    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return self.index-1
    
    def compute_rep(self, natural_reps):
        """
        Compute the standard representation from the natural representation.

        Args:
            natural_reps (torch.tensor): (group.order, group.index, group.index) tensor of natural representations (permutation matrices)

        Returns:
            torch.tensor: (group.order, group.index-1, group.index-1) tensor of standard representations
        """
        rep = []
        basis_transform = torch.zeros(self.index, self.index).cuda()
        for i in range(self.index-1):
            basis_transform[i, i] = 1
            basis_transform[i, i+1] = -1
        basis_transform[self.index-1, self.index-1] = 1 #to make the transform non singular
        for x in natural_reps:
            temp = basis_transform @ x @ basis_transform.inverse()
            rep.append(temp[:self.index-1, :self.index-1])
        rep = torch.stack(rep, dim=0).cuda()
        return rep  

standard_rep = StandardRepresentation([natural_rep.rep], rep_params)


Computing trace tensor cube for standard representation
... loading from file


In [151]:
class StandardSignRepresentation(Representation):
    def __init__(self, compute_rep_params, init_rep_params):
        """
        Initialise the tensor product of the standard and sign representation. 

        Args:
            compute_rep_params (list): list consisting of tensor of standard representation and signatures, required to compute the representation
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
        """
        self.friendly_name = 'standard_sign'
        super().__init__(compute_rep_params, **init_rep_params)

    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return self.index-1

    def compute_rep(self, standard_reps, signatures):
        """
        Compute the tensor product of the standard and sign representation.

        Args:
            standard_reps (torch.tensor): (group.order, group.index-1, group.index-1) tensor of standard representations
            signatures (torch.tensor): (group.order, 1) tensor of signatures

        Returns:
            torch.tensor: (group.order, group.index-1, group.index-1) tensor of standard_sign representations
        """
        rep = []
        for i in range(standard_reps.shape[0]):
            rep.append(signatures[i]*standard_reps[i])
        rep = torch.stack(rep, dim=0).cuda()
        return rep
    
standard_sign_rep = StandardSignRepresentation([standard_rep.rep, sign_rep.rep], rep_params)      

Computing trace tensor cube for standard_sign representation
... loading from file


In [152]:
class SymmetricRepresentationFromGenerators(Representation):
    def __init__(self, compute_rep_params, init_rep_params, name):
        """
        Initialise a representation of the symmetric group from a set of generators.

        Args:
            compute_rep_params (list): list consisting of the generators of the representation, sympy group object, and a function that maps indices to permutations
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
            name (str): name of the representation
        """
        self.friendly_name = name
        super().__init__(compute_rep_params, **init_rep_params)

    # TODO: make this less hacky
    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return list(self.compute_rep_params[0].values())[0].shape[0] # hacky way to get the dimension of the representation

    def compute_rep(self, generators, G, idx_to_perm):
        """
        Compute the representation from the generators.

        Args:
            generators (dict): dictionary of generators of the group along with their representations
            G (sympy group object): group object
            idx_to_perm (function): function that maps indices to permutations

        Returns:
            torch.tensor: (group.order, dim, dim) tensor of arbitrary representations
        """
        rep = torch.zeros(self.order, self.dim, self.dim).cuda()
        for i in range(self.order):
            generator_product = G.generator_product(idx_to_perm(i), original=True)
            result = torch.eye(self.dim).float()
            for g in generator_product:
                result = result @ generators[g]
            rep[i] = result
        return rep.cuda()


s5_5d_a_generators = {}
s5_5d_a_generators[Permutation(0, 1, 2, 3, 4)] = torch.tensor([
    [ 1, -1, -1,  1,  0],
    [ 0, -1, -1,  0,  1],
    [ 1, -1,  0,  0,  0],
    [ 0, -1,  0,  0,  0],
    [ 1, -1, -1,  0,  0]
]).float()
s5_5d_a_generators[Permutation(4, 3, 2, 1, 0)] = s5_5d_a_generators[Permutation(0, 1, 2, 3, 4)].inverse()
s5_5d_a_generators[Permutation(4)(0,1)] = torch.tensor([
    [ 0,  0, -1,  0,  0],
    [ 0,  0,  0, -1,  0],
    [-1,  0,  0,  0,  0],
    [ 0, -1,  0,  0,  0],
    [ 0,  0,  0,  0, -1]
]).float()
s5_5d_a_rep = SymmetricRepresentationFromGenerators([s5_5d_a_generators, group.G, group.idx_to_perm], rep_params, 's5_5d_a')

# 3,2 specht
s5_5d_b_generators = {}
s5_5d_b_generators[Permutation(0, 1, 2, 3, 4)] = torch.tensor([
    [-1,  1, -1,  0,  0],
    [ 0,  0,  0,  1, -1],
    [ 0,  0,  1,  0, -1],
    [ 1,  0,  0,  0,  0],
    [ 1,  0,  1,  0,  0]
]).float()
s5_5d_b_generators[Permutation(4, 3, 2, 1, 0)] = s5_5d_b_generators[Permutation(0, 1, 2, 3, 4)].inverse()
s5_5d_b_generators[Permutation(4)(0,1)] = torch.tensor([
    [ 1,  0,  0,  0,  0],
    [ 0,  0,  0, -1,  0],
    [-1,  0,  0,  0, -1],
    [ 0, -1,  0,  0,  0],
    [-1,  0, -1,  0,  0]
]).float()
s5_5d_b_rep = SymmetricRepresentationFromGenerators([s5_5d_b_generators, group.G, group.idx_to_perm], rep_params, 's5_5d_b')

# 3,1,1 specht
s5_6d_generators = {}
s5_6d_generators[Permutation(0, 1, 2, 3, 4)] = torch.tensor([
    [ 1, -1,  1,  0,  0,  0],
    [ 1,  0,  0, -1,  1,  0],
    [ 0,  1,  0, -1,  0,  1],
    [ 1,  0,  0,  0,  0,  0],
    [ 0,  1,  0,  0,  0,  0],
    [ 0,  0,  0,  1,  0,  0]
]).float()
s5_6d_generators[Permutation(4, 3, 2, 1, 0)] = s5_6d_generators[Permutation(0, 1, 2, 3, 4)].inverse()
s5_6d_generators[Permutation(4)(0,1)] = torch.tensor([
    [-1,  0,  0,  0,  0,  0],
    [ 0,  0,  0, -1,  0,  0],
    [ 0,  0,  0,  0, -1,  0],
    [ 0, -1,  0,  0,  0,  0],
    [ 0,  0, -1,  0,  0,  0],
    [ 0,  0,  0,  0,  0,  1]
]).float()
s5_6d_rep = SymmetricRepresentationFromGenerators([s5_6d_generators, group.G, group.idx_to_perm], rep_params, 's5_6d')


Computing trace tensor cube for s5_5d_a representation
... loading from file
Computing trace tensor cube for s5_5d_b representation
... loading from file
Computing trace tensor cube for s5_6d representation
... loading from file


In [153]:
group.irreps = {
    'trivial': trivial_rep,
    'sign': sign_rep,
    'standard': standard_rep,
    'standard_sign': standard_sign_rep,
    's5_5d_a': s5_5d_a_rep,
    's5_5d_b': s5_5d_b_rep,
    's5_6d': s5_6d_rep
}

group.non_trivial_irreps = group.irreps.copy()
del group.non_trivial_irreps['trivial']

Now we have all our representations loaded, let's inspect them...

## Reverse engineering the final model

In [154]:
all_data, _ = group.get_all_data(False)
all_data, all_labels = all_data[:, :2], all_data[:, 2]
with torch.inference_mode():
    all_logits, activations = model.run_with_cache(all_data)

    

In [155]:
print(all_labels.shape)

torch.Size([14400])


### Logit Similarity


In [156]:
def logit_trace_similarity(logits, trace_cube):
    """
    Compute cosine similarity between true logits and logits computed via tr(\rho(x)\rho(y)\rho(z^-1))

    Args:
        logits (torch.tensor): (batch, group.order) tensor of logits
        trace_cube (torch.tensor): (group.order, group.order, group.order) tensor of tr(\rho(x)\rho(y)\rho(z^-1))

    Returns:
        float: mean cosine similarity over batch
    """
    centered_logits = logits - logits.mean(dim=-1, keepdim=True)
    centered_logits = centered_logits.reshape(-1)
    trace = trace_cube.reshape(-1)
    sim = F.cosine_similarity(centered_logits, trace, dim=0)
    return sim


key_reps = []
percent_explained = 0
for rep_name, rep in group.non_trivial_irreps.items():
    sim = logit_trace_similarity(all_logits, rep.logit_trace_tensor_cube)
    if sim > 0.005:
        key_reps.append(rep_name)
        percent_explained += sim **2
    print(f'{rep_name}: {sim:.4f}')

# print key reps
print('\nKey reps:')
print(key_reps)

# print percent logit explained as a percentage
print('\nPercent logit explained:')
print(percent_explained.item() * 100)



sign: 0.5141
standard: 0.7714
standard_sign: 0.0000
s5_5d_a: 0.0000
s5_5d_b: 0.0007
s5_6d: 0.0000

Key reps:
['sign', 'standard']

Percent logit explained:
85.93581914901733


In [157]:
# visualise a slice of true logits, and key rep logits
all_logits_cube = all_logits.reshape(group.order, group.order, group.order)
true_logit_slice = all_logits_cube[:, :, 0]/all_logits_cube[:, :, 0].abs().max()
key_rep_logit_slices = {}
for key_rep in key_reps:
    key_rep_logit_slices[key_rep] = (group.irreps[key_rep].logit_trace_tensor_cube[:, :, 0]/group.irreps[key_rep].logit_trace_tensor_cube[:, :, 0].abs().max())

imshow(true_logit_slice, input1="a", input2="b", title="True logit c=0")
for key_rep, key_rep_logit_slice in key_rep_logit_slices.items():
    imshow(key_rep_logit_slice, input1="a", input2="b", title=f'{key_rep} logit c=0')
    


### Embed and Unembed

The model has no non linearity between it's embedding and hidden layer. It therefore makes sense to interpret the factored matrices together.

In [158]:
def get_embeds(model):
    """ 
    Get the embedding matrices for x and y
    """
    embeds = model.W_x @ model.W[:model.embed_dim, :], model.W_y @ model.W[model.embed_dim:, :]
    return embeds

def get_unembed(model):
    """ 
    Get the unembedding matrix
    """
    unembed = model.W_U
    return unembed

In [159]:
def percent_total_embed(model, orth_rep):
    """
    Compute the percent of the total embedding represented by the representation. Total embedding is the matmul of the embedding and the linear layer.

    Args:
        model (nn.Module): neural network
        orth_rep (torch.tensor): orthonormal representation

    Returns:
        (float, float): (total percent x, total percent y)
    """
    x_embed, y_embed = get_embeds(model)

    norm_x = x_embed.pow(2).sum()
    norm_y = y_embed.pow(2).sum()

    coefs_x = orth_rep.T @ x_embed
    coefs_y = orth_rep.T @ y_embed

    conts_x = coefs_x.pow(2).sum(-1) / norm_x
    conts_y = coefs_y.pow(2).sum(-1) / norm_y

    return conts_x.sum().item(), conts_y.sum().item()

def percent_unembed(model, orth_rep):
    """
    Compute the percent of the unembed represented by the representation.

    Args:
        model (nn.Module): neural network
        orth_rep (torch.tensor): orthonormal representation

    Returns:
        (float, float): (total percent x, total percent y)

    """
    W_U = get_unembed(model)
    norm_U = W_U.pow(2).sum()
    coefs_U = orth_rep.T @ W_U.T
    conts_U = coefs_U.pow(2).sum(-1) / norm_U
    return conts_U.sum().item()


In [160]:
print('Representation, frac left embed, frac right embed, frac unembed')
percent_left_explained_by_key_reps = 0
percent_right_explained_by_key_reps = 0
percent_unembed_explained_by_key_reps = 0
for rep_name, rep in group.non_trivial_irreps.items():
    left_embed, right_embed = percent_total_embed(model, rep.orth_rep)
    unembed = percent_unembed(model, rep.orth_rep)
    print(f'{rep_name}: {left_embed:.4f}, {right_embed:.4f}, {unembed:.4f}')

    if rep_name in key_reps:
        percent_left_explained_by_key_reps += left_embed
        percent_right_explained_by_key_reps += right_embed
        percent_unembed_explained_by_key_reps += unembed

print('\nFrac left explained by key reps:')
print(percent_left_explained_by_key_reps)
print('\nFrac right explained by key reps:')
print(percent_right_explained_by_key_reps)
print('\nFrac unembed explained by key reps:')
print(percent_unembed_explained_by_key_reps)

        

Representation, frac left embed, frac right embed, frac unembed
sign: 0.0673, 0.0674, 0.1000
standard: 0.9327, 0.9326, 0.8378
standard_sign: 0.0000, 0.0000, 0.0005
s5_5d_a: 0.0000, 0.0000, 0.0012
s5_5d_b: 0.0000, 0.0000, 0.0032
s5_6d: 0.0000, 0.0000, 0.0017

Frac left explained by key reps:
1.0000000149011612

Frac right explained by key reps:
1.000000037252903

Frac unembed explained by key reps:
0.9378509074449539


### Hidden Layer Neurons


In [161]:
for rep_name in group.irreps:
    group.irreps[rep_name].hidden_reps_x = group.irreps[rep_name].rep[all_data[:, 0]].reshape(group.order**2, -1)
    group.irreps[rep_name].hidden_reps_x_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_x)[0]
    group.irreps[rep_name].hidden_reps_y = group.irreps[rep_name].rep[all_data[:, 1]].reshape(group.order**2, -1)
    group.irreps[rep_name].hidden_reps_y_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_y)[0]
    group.irreps[rep_name].hidden_reps_xy = group.irreps[rep_name].rep[all_labels].reshape(group.order*group.order, -1)
    group.irreps[rep_name].hidden_reps_xy_orth = torch.linalg.qr(group.irreps[rep_name].hidden_reps_xy)[0]

In [162]:
def get_hidden(model):
    """ 
    Get the final MLP neuron activations for all data points
    """
    logits, activations = model.run_with_cache(all_data)
    hidden = activations['hidden'] 
    return hidden

def percent_hidden(model, rep_name):
    """
    Compute the percent of the total hidden representation represented by the representation matrices \rho(xy).

    Args:
        model (nn.Module): neural network
        hidden_reps_xy (torch.tensor): orthonormal hidden representations \rho(xy)

    """
    hidden = get_hidden(model)
    hidden = hidden - hidden.mean(dim=0, keepdim=True)

    hidden_norm = hidden.pow(2).sum()

    hidden_reps_x_orth = group.irreps[rep_name].hidden_reps_x_orth
    hidden_reps_y_orth = group.irreps[rep_name].hidden_reps_y_orth
    hidden_reps_xy_orth = group.irreps[rep_name].hidden_reps_xy_orth
    coefs_x = hidden_reps_x_orth.T @ hidden
    coefs_y = hidden_reps_y_orth.T @ hidden
    coefs_xy = hidden_reps_xy_orth.T @ hidden
    x_conts = coefs_x.pow(2).sum() / hidden_norm
    y_conts = coefs_y.pow(2).sum() / hidden_norm
    xy_conts = coefs_xy.pow(2).sum() / hidden_norm
    total_conts = x_conts + y_conts + xy_conts

    return x_conts.item(), y_conts.item(), xy_conts.item(), total_conts.item()

In [163]:
for rep in group.non_trivial_irreps:
    x_conts, y_conts, xy_conts, total_conts = percent_hidden(model, rep)
    print(f'{rep}: {x_conts:.4f}, {y_conts:.4f}, {xy_conts:.4f}, {total_conts:.4f}')
    

sign: 0.0244, 0.0244, 0.0244, 0.0732
standard: 0.3423, 0.3756, 0.1055, 0.8235
standard_sign: 0.0000, 0.0000, 0.0000, 0.0000
s5_5d_a: 0.0000, 0.0000, 0.0000, 0.0000
s5_5d_b: 0.0002, 0.0002, 0.0001, 0.0005
s5_6d: 0.0000, 0.0000, 0.0000, 0.0000


In [164]:
def hidden_to_logits(hidden, model):
    """ 
    Convert hidden activations to logits via the correct unembed
    """
    return hidden @ model.W_U


def hidden_excluded_and_restricted_loss(model, hidden_reps_xy_orth):
    """ 
    Restrict or exclude reps rho(ab) from the hidden layer and compute the loss on the restricted and excluded parts of the hidden layer.
    """
    hidden = get_hidden(model)
    
    coefs_xy = hidden_reps_xy_orth.T @ hidden
    hidden_xy = hidden_reps_xy_orth @ coefs_xy

    hidden_xy_restricted = hidden_xy
    hidden_xy_excluded = hidden - hidden_xy

    logits_restricted = hidden_to_logits(hidden_xy_restricted, model)
    logits_excluded = hidden_to_logits(hidden_xy_excluded, model)

    restricted_loss = loss_fn(logits_restricted, all_labels).item()
    excluded_loss = loss_fn(logits_excluded[train_indices], train_labels).item()

    return excluded_loss, restricted_loss

    
def total_hidden_excluded_and_restricted_loss(model, key_reps):
    """ 
    Restrict or exclude all key reps rho(ab) from the hidden layer and compute the loss on the restricted and excluded parts of the hidden layer.
    """
    hidden = get_hidden(model)

    hidden_restricted = torch.zeros_like(hidden)
    for rep_name in key_reps:
        coefs_xy =  group.irreps[rep_name].hidden_reps_xy_orth.T @ hidden
        hidden_xy = group.irreps[rep_name].hidden_reps_xy_orth @ coefs_xy
        hidden_restricted += hidden_xy

    hidden_excluded = hidden - hidden_restricted

    logits_restricted =  hidden_to_logits(hidden_restricted, model)
    logits_excluded =  hidden_to_logits(hidden_excluded, model)

    restricted_loss = loss_fn(logits_restricted, all_labels).item()
    excluded_loss = loss_fn(logits_excluded[train_indices], train_labels).item()

    return excluded_loss, restricted_loss

In [165]:
print('Excluded loss, restricted loss')
for rep in group.non_trivial_irreps:
    excluded_loss, restricted_loss = hidden_excluded_and_restricted_loss(model, group.irreps[rep].hidden_reps_xy_orth)
    print(f'{rep}: {excluded_loss:.4f}, {restricted_loss:.4f}')

excluded_loss, restricted_loss = total_hidden_excluded_and_restricted_loss(model, key_reps)
print(f'Total: {excluded_loss:.4f}, {restricted_loss:.4f}')
    

Excluded loss, restricted loss
sign: 0.0006, 4.0955
standard: 7.0111, 0.0001
standard_sign: 0.0000, 4.7875
s5_5d_a: 0.0000, 4.7875
s5_5d_b: 0.0000, 4.7614
s5_6d: 0.0000, 4.7875
Total: 7.3947, 0.0000


### Neuron Clustering


In [166]:
# evidence: neuron clustering pre ReLU

threshold = 1

x_embed, y_embed = get_embeds(model)

x_embed_summed = x_embed.pow(2).sum(dim=0)
off_neurons_x = (x_embed_summed < threshold).nonzero().squeeze()

y_embed_summed = y_embed.pow(2).sum(dim=0)
off_neurons_y = (y_embed_summed < threshold).nonzero().squeeze()

assert (off_neurons_x == off_neurons_y).all()

off_neurons = off_neurons_x

print(f'Off neurons: {len(off_neurons)}, {off_neurons}')

rep_neurons = {}

print('Neurons corresponding to each representation')
for rep_name in group.non_trivial_irreps:
    rep = group.irreps[rep_name].orth_rep
    coefs_x = rep.T @ x_embed
    coefs_y = rep.T @ y_embed
    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)

    x_neurons = (coefs_x_summed > threshold).nonzero().squeeze()
    y_neurons = (coefs_y_summed > threshold).nonzero().squeeze()
    assert (x_neurons == y_neurons).all()
    x_neurons = torch.tensor(x_neurons)
    if x_neurons.dim() == 0:
        x_neurons = x_neurons.unsqueeze(0)
    rep_neurons[rep_name] = x_neurons
    print(f'{rep_name}: {len(x_neurons)}, {x_neurons}')

all_neurons = torch.arange(model.W_U.shape[0])
unaccounted_neurons = set(all_neurons.tolist())
unaccounted_neurons -= set(off_neurons.tolist())
for rep_name, neurons in rep_neurons.items():
    unaccounted_neurons -= set(neurons.tolist())

print('Unaccounted neurons')
print(unaccounted_neurons)

Off neurons: 15, tensor([ 3,  5, 12, 17, 45, 50, 54, 59, 65, 68, 81, 87, 88, 95, 99],
       device='cuda:0')
Neurons corresponding to each representation
sign: 10, tensor([ 55,  62,  66,  75,  84,  89,  90, 102, 123, 127], device='cuda:0')
standard: 103, tensor([  0,   1,   2,   4,   6,   7,   8,   9,  10,  11,  13,  14,  15,  16,
         18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,
         32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  46,
         47,  48,  49,  51,  52,  53,  56,  57,  58,  60,  61,  63,  64,  67,
         69,  70,  71,  72,  73,  74,  76,  77,  78,  79,  80,  82,  83,  85,
         86,  91,  92,  93,  94,  96,  97,  98, 100, 101, 103, 104, 105, 106,
        107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
        121, 122, 124, 125, 126], device='cuda:0')
standard_sign: 0, tensor([], device='cuda:0', dtype=torch.int64)
s5_5d_a: 0, tensor([], device='cuda:0', dtype=torch.int64)
s5_5d_b: 0, tensor([], 


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [167]:
# evidence and table: neuron clustering in post hidden layer

threshold = 110

hidden = activations['hidden'].reshape(group.order**2, -1)
hidden = hidden - hidden.mean(dim=0, keepdim=True)

hidden_summed = hidden.pow(2).sum(dim=0)
off_neurons = (hidden_summed < threshold).nonzero().squeeze()

assert (off_neurons == off_neurons_x).all()

print(f'Off neurons: {off_neurons}')


fracs_explained_x = {}
fracs_explained_y = {}
fracs_explained_xy = {}
fracs_explained_trivial = {}

for rep_name in key_reps:
    rep_x = group.irreps[rep_name].hidden_reps_x_orth
    rep_y = group.irreps[rep_name].hidden_reps_y_orth
    rep_xy = group.irreps[rep_name].hidden_reps_xy_orth

    trivial = group.irreps['trivial'].hidden_reps_x_orth

    coefs_x = rep_x.T @ hidden
    coefs_y = rep_y.T @ hidden
    coefs_xy = rep_xy.T @ hidden

    coefs_trivial = trivial.T @ hidden

    coefs_x_summed = coefs_x.pow(2).sum(dim=0)
    coefs_y_summed = coefs_y.pow(2).sum(dim=0)
    coefs_xy_summed = coefs_xy.pow(2).sum(dim=0)
    coefs_trivial_summed = coefs_trivial.pow(2).sum(dim=0)


    neurons = rep_neurons[rep_name]

    frac_x = (coefs_x_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_y = (coefs_y_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_xy = (coefs_xy_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())
    frac_trivial = (coefs_trivial_summed[neurons]).sum() / (hidden[:, neurons].pow(2).sum())

    fracs_explained_x[rep_name] = frac_x
    fracs_explained_y[rep_name] = frac_y
    fracs_explained_xy[rep_name] = frac_xy
    fracs_explained_trivial[rep_name] = frac_trivial

print('Neurons corresponding to each representation')
for key in key_reps:
    print(f'frac variance explained in {key} x, y, xy: {fracs_explained_x[key], fracs_explained_y[key], fracs_explained_xy[key], fracs_explained_trivial[key]}')
    print(f'Sum of explained variance: {fracs_explained_x[key] + fracs_explained_y[key] + fracs_explained_xy[key] + fracs_explained_trivial[key]}')

Off neurons: tensor([ 3,  5, 12, 17, 45, 50, 54, 59, 65, 68, 81, 87, 88, 95, 99],
       device='cuda:0')
Neurons corresponding to each representation
frac variance explained in sign x, y, xy: (tensor(0.3333, device='cuda:0'), tensor(0.3334, device='cuda:0'), tensor(0.3333, device='cuda:0'), tensor(2.2360e-14, device='cuda:0'))
Sum of explained variance: 0.9999999403953552
frac variance explained in standard x, y, xy: (tensor(0.3694, device='cuda:0'), tensor(0.4053, device='cuda:0'), tensor(0.1138, device='cuda:0'), tensor(8.3030e-15, device='cuda:0'))
Sum of explained variance: 0.8885210156440735


### Extracting the final linear map in the representation basis

In [168]:
def projection_matrix_general(B):
    """Compute the projection matrix onto the space spanned by the columns of `B`
    Args:
        B: ndarray of dimension (D, M), the basis for the subspace
    
    Returns:
        P: the projection matrix
    """
    P = B @ (B.T @ B).inverse() @ B.T
    return P

hidden = activations['hidden'].reshape(group.order*group.order, -1)
hidden_to_reps_proj = {}
coefs = {}

for rep_name in key_reps:
    hidden_reps_xy = group.irreps[rep_name].hidden_reps_xy



    P = projection_matrix_general(hidden_reps_xy)
    hidden_xy = P @ hidden

    hidden_to_reps_proj[rep_name] = hidden_reps_xy.T @ hidden_xy
    #hidden_to_reps_proj[rep_name] = hidden_to_reps_proj[rep_name] / hidden_to_reps_proj[rep_name].norm(dim=1, keepdim=True)
    if rep_name == 'standard':
        plot = hidden_to_reps_proj[rep_name]
        plot = plot/plot.norm(dim=1, keepdim=True)
        plot = plot.detach().cpu().numpy() 
        fig = px.imshow(plot, color_continuous_scale='RdBu', color_continuous_midpoint=0.0, labels={'x':'neuron basis', 'y':'rep basis'})
        fig.update_layout(
            height = 180,
            margin = dict(l=80, r=80, t=0, b=30, pad=0)
        )
        # set the length of the colorbar to 1
        fig.update_layout(coloraxis_colorbar=dict(
            len  = 0.72,
        ))
        fig.show()

    hidden_in_rep = hidden_xy @ hidden_to_reps_proj[rep_name].T

    theoretical_reps = hidden_reps_xy.reshape(group.order*group.order, -1)

    hidden_in_rep_norm = hidden_in_rep.flatten() / hidden_in_rep.flatten().norm()
    theoretical_reps_norm = theoretical_reps.flatten() / theoretical_reps.flatten().norm()

    # MSE loss between hidden_in_rep and theoretical_reps
    sim = F.mse_loss(hidden_in_rep_norm, theoretical_reps_norm)
    print(f'MSE Loss between hidden layer and theoretical representations: {sim}')

    # get the coef
    coef = (hidden_in_rep.norm() / theoretical_reps.norm())
    coefs[rep_name] = coef

MSE Loss between hidden layer and theoretical representations: 0.0


MSE Loss between hidden layer and theoretical representations: 8.991650979339738e-09


In [169]:
# evidence: 
rep_name = 'standard'
W_U = model.W_U
rep = group.irreps[rep_name].rep.reshape(group.order, -1)
#rep = rep / rep.norm(dim=0, keepdim=True)
W_U_rep = hidden_to_reps_proj[rep_name] @ W_U @ rep [group.inverses]
fig = px.imshow(to_numpy(W_U_rep), color_continuous_scale='RdBu', color_continuous_midpoint=0.0, labels={'x':'$\Large \\rho(c^{-1})$', 'y':'$\Large \\rho(ab)$'})
fig.update_layout(
            margin = dict(l=50, r=80, t=20, b=80, pad=0),
            width = 500,
        )
        # make the colorbar length 1
fig.update_layout(coloraxis_colorbar=dict(
    lenmode="fraction",
    len=1.07,
))
fig.show()

real_linear_map = (W_U_rep > 1e5).float()
real_linear_map_norm = real_linear_map.flatten() / real_linear_map.flatten().norm()
W_U_rep_norm = W_U_rep.flatten() / W_U_rep.flatten().norm()
sim = F.mse_loss(real_linear_map_norm, W_U_rep_norm)
#sim = F.cosine_similarity(W_U_rep.flatten(), real_linear_map.flatten(), dim=0)
print(f'MSE loss between unembedding matrix and real linear map: {sim}')

MSE loss between unembedding matrix and real linear map: 1.525467450846918e-05


## Analysis during training


In [179]:
def sum_of_squared_weights(model):
    """
    Compute the sum of squared weights on the whole model.

    Args:
        model (nn.Module)

    Returns:
        float: sum of squared weights on entire group
    """

    sum_of_square_weights = 0

    if model.__class__.__name__ == "OneLayerMLP":
        sum_of_square_weights += torch.sum(model.W_x**2)
        sum_of_square_weights += torch.sum(model.W_y**2)
        sum_of_square_weights += torch.sum(model.W_U**2)
        sum_of_square_weights += torch.sum(model.W**2)
    
    return sum_of_square_weights.item()

In [180]:
checkpointed_epochs = list(range(0, training_cfg['num_epochs'], checkpoint_every))
metrics = {}
for rep in group.non_trivial_irreps:
    metrics['logit_trace_similarity_' + rep + '_rep'] = []
    metrics['percent_left_embed_' + rep + '_rep'] = []
    metrics['percent_right_embed_' + rep + '_rep'] = []
    metrics['percent_unembed_' + rep + '_rep'] = []
    metrics['percent_hidden_' + rep + '_rep'] = []
    metrics['excluded_loss_' + rep + '_rep'] = []
    metrics['restricted_loss_' + rep + '_rep'] = []
metrics['total_excluded_loss'] = []
metrics['total_restricted_loss'] = []
metrics['test_loss'] = []
metrics['test_acc'] = []
metrics['train_loss'] = []
metrics['train_acc'] = []
metrics['sum_of_square_weights'] = []
    
    
for epoch in checkpointed_epochs:
    load_checkpoint(model, task_dir, epoch)
    model.eval()
    metrics['test_loss'].append(loss_fn(model(test_data), test_labels).item())
    metrics['test_acc'].append(get_accuracy(model(test_data), test_labels))
    metrics['train_loss'].append(loss_fn(model(train_data), train_labels).item())
    metrics['train_acc'].append(get_accuracy(model(train_data), train_labels))
    all_logits = model(all_data)
    for rep in group.non_trivial_irreps:
        #metrics['logit_trace_similarity_' + rep + '_rep'].append(logit_trace_similarity(all_logits, group.irreps[rep].logit_trace_tensor_cube))
        left_embed, right_embed = percent_total_embed(model, group.irreps[rep].orth_rep)
        metrics['percent_left_embed_' + rep + '_rep'].append(left_embed)
        metrics['percent_right_embed_' + rep + '_rep'].append(right_embed)
        unembed = percent_unembed(model, group.irreps[rep].orth_rep)
        metrics['percent_unembed_' + rep + '_rep'].append(unembed)
        _, _, hidden_xy, _ = percent_hidden(model, rep)
        metrics['percent_hidden_' + rep + '_rep'].append(hidden_xy)
        excluded_loss, restricted_loss = hidden_excluded_and_restricted_loss(model, group.irreps[rep].hidden_reps_xy_orth)
        metrics['excluded_loss_' + rep + '_rep'].append(excluded_loss)
        metrics['restricted_loss_' + rep + '_rep'].append(restricted_loss)
    total_excluded_loss, total_restricted_loss = total_hidden_excluded_and_restricted_loss(model, key_reps)
    metrics['total_excluded_loss'].append(total_excluded_loss)
    metrics['total_restricted_loss'].append(total_restricted_loss)
    metrics['sum_of_square_weights'].append(sum_of_squared_weights(model))
    

    


In [171]:
non_trivial_irreps_list = list(group.non_trivial_irreps)

In [172]:
# plot left embeds
lines([metrics['percent_left_embed_' + rep + '_rep'] for rep in group.non_trivial_irreps],
      x=checkpointed_epochs,
      labels=non_trivial_irreps_list,
      title='Percent of left embedding matrix explained by rep',
      xaxis='epoch',
      yaxis='percent')

In [173]:
# plot right embeds
lines([metrics['percent_right_embed_' + rep + '_rep'] for rep in group.non_trivial_irreps],
      x=checkpointed_epochs,
      labels=non_trivial_irreps_list,
      title='Percent of right embedding matrix explained by rep',
      xaxis='epoch',
      yaxis='percent')

In [174]:
# plot unembeds
lines([metrics['percent_unembed_' + rep + '_rep'] for rep in group.non_trivial_irreps],
        x=checkpointed_epochs,
        labels=non_trivial_irreps_list,
        title='Percent of unembedding matrix explained by rep',
        xaxis='epoch',
        yaxis='percent')


In [175]:
print(metrics['percent_hidden_standard_rep'])

[0.0001500012876931578, 0.00047078728675842285, 0.0007364900666289032, 0.0010661304695531726, 0.0014370688004419208, 0.00184731581248343, 0.0023502500262111425, 0.002759685041382909, 0.0030899858102202415, 0.0033705136738717556, 0.003631064435467124, 0.003825458465144038, 0.0039922213181853294, 0.004171977750957012, 0.004311450757086277, 0.004425874445587397, 0.004525103606283665, 0.004686822649091482, 0.004805736243724823, 0.004941439256072044, 0.005065194796770811, 0.005187405273318291, 0.0053070466965436935, 0.005434718914330006, 0.0055747260339558125, 0.005723624024540186, 0.005867467727512121, 0.006032347213476896, 0.0062080854550004005, 0.006424125283956528, 0.006718774791806936, 0.007105672266334295, 0.007517695892602205, 0.007966691628098488, 0.008465652354061604, 0.009099755436182022, 0.009620602242648602, 0.010006606578826904, 0.010300724767148495, 0.010520825162529945, 0.010710221715271473, 0.010870308615267277, 0.01101385336369276, 0.011154270730912685, 0.011292294599115849

In [176]:
# percent hidden
lines([metrics['percent_hidden_' + rep + '_rep'] for rep in group.non_trivial_irreps],
        x=checkpointed_epochs,
        labels=non_trivial_irreps_list,
        title='Percent of hidden explained by rep',
        xaxis='epoch',
        yaxis='percent')


In [177]:
# plot losses
lines([metrics['train_loss'], metrics['test_loss'], metrics['total_excluded_loss'], metrics['total_restricted_loss']], x=checkpointed_epochs, log_y=True, labels=['train loss', 'test loss', 'total excluded loss', 'total restricted loss'], title='Losses', xaxis='Epoch', yaxis='Loss')

In [183]:
# excluded loss by rep
lines([metrics['excluded_loss_' + rep + '_rep'] for rep in group.non_trivial_irreps],
        x=checkpointed_epochs,
        labels=non_trivial_irreps_list,
        title='Excluded loss by rep',
        xaxis='epoch',
        yaxis='loss',
        log_y=True)

# restricted loss by rep
lines([metrics['restricted_loss_' + rep + '_rep'] for rep in group.non_trivial_irreps],
        x=checkpointed_epochs,
        labels=non_trivial_irreps_list,
        title='Restricted loss by rep',
        xaxis='epoch',
        yaxis='loss',
        log_y=True)



In [181]:
# plot sum of squared weights

lines([metrics['sum_of_square_weights']], x=checkpointed_epochs, log_y=True, labels=['sum of squared weights'], title='Sum of squared weights', xaxis='Epoch', yaxis='Sum of squared weights')