In this notebook, we will train a network to perform group composition on the abstract group S5, and then reverse engineer the algorithm learned. This notebook should be fairly self contained, and uses similar code to that in the paper. Full code for the paper is available at ...

# Setup


In [8]:
# Imports
import os
if "cd" not in globals():
    os.chdir("../")
    cd = True
import torch
import torch.nn as nn
import numpy as np
import math
from tqdm import tqdm
from transformer_lens.hook_points import HookPoint, HookedRootModule
from sympy.combinatorics.named_groups import SymmetricGroup as SymPySymmetricGroup
from sympy.combinatorics import Permutation, PermutationGroup
from utils.plotting import *

In [9]:
SEED = 2

Let's first create a class to contain all the information about our group. We include methods 
* to convert between permutations and their indices that we use to one-hot-encode them. 
* to compose permutations, based on their index
* to calculate the order of a permutation
* to calculate the signature of a permutation
* to cache all signatures of the group elements
* to calculate the inverse of a permutation
* to cache all inverses of group elements
* to calculate the full multiplication table of the group
* to return all pairs of permutations, and their product

In [10]:
class SymmetricGroup():
    """
    Class for the symmetric group of order index.
    """
    def __init__(self, index, init_all=False):
        """
        Initialise the group. Optionally calculate all other tensors needed to track metrics.

        Args:
            index (int): Index of group in family of symmetric groups.
            init_all (bool, optional): If false, only calculate what is required to train. If true, calculate all other tensors needed to track metrics. Defaults to False.
        """

        # build class on top of sympy
        self.G = SymPySymmetricGroup(index)
        self.index = index
        self.order = math.factorial(index)
        self.acronym = 'S'
        self.multiplication_table = self.compute_multiplication_table()
        
        # hacky method to find the index of the identity element 
        self.identity = [i for i in range(self.order) if self.idx_to_perm(i).order() == 1][0]
        
        self.inverses = self.compute_inverses()

        # compute the signatures of the elements
        self.signatures = self.compute_signatures()

        self.all_data, _ = self.get_all_data()
        self.all_data = self.all_data[:, :2]


    def idx_to_perm(self, x):
        """
        Convert an index to a permutation.

        Args:
            x (int): index of element in group

        Returns:
            Permutation: permutation object from sympy
        """
        return self.G._elements[x]

    def perm_to_idx(self, perm):
        """
        Converts a permutation to an index.

        Args:
            perm (Permutation): permutation object from sympy

        Returns:
            int: index of element in group
        """
        return self.G._elements.index(perm)

    def compose(self, x, y):
        """
        Compose elements of the group by converting to permutations, composing, and converting back.

        Args:
            x (int): left index
            y (int): right index

        Returns:
            int: index of composition
        """
        return self.perm_to_idx(self.idx_to_perm(x) * self.idx_to_perm(y))

    def perm_order(self, x):
        """
        Gets the order of a permutation.

        Args:
            x (int): index of element

        Returns:
            int: order of permutation
        """
        return self.idx_to_perm(x).order()

    def signature(self, x):
        """
        Gets the signature of a permutation.

        Args:
            x (int): index of element

        Returns:
            int: Integer \in {0, 1} representing the signature of the permutation.
        """
        return self.idx_to_perm(x).signature()
    
    def compute_signatures(self):
        """
        Compute and store the signature of each element in the group.

        Returns:
            torch.tensor: tensor of signatures
        """
        signatures = torch.tensor([self.signature(i) for i in range(self.order)]).cuda()
        return signatures
    
    def inverse(self, x):
        """
        Compute the inverse of an element of the group.

        Args:
            x (int): Index of element to inverse

        Returns:
            int: Index of inverse element
        """
        return (self.multiplication_table[x, :] == self.identity).nonzero().item()

    def compute_inverses(self):
        inverses = torch.zeros(self.order, dtype=torch.int64)
        for i in range(self.order):
            inverses[i] = self.inverse(i)
        return inverses
    
    def compute_multiplication_table(self):
        """
        Compute the multiplication table of the group. Caches/loads from file if possible.
        """
        print('Computing multiplication table...')
        table = torch.zeros((self.order, self.order), dtype=torch.int64).cuda()
        for i in tqdm(range(self.order)):
            for j in range(self.order):
                table[i, j] = self.compose(i, j)
        return table

    def get_all_data(self, shuffle_seed=False):
        """
        Get's all data and labels for the pairwise composition task.

        Args:
            shuffle_seed (bool, optional): Shuffle data for training. Defaults to False.

        Returns:
            torch.tensor: Tensor of shape (order*order, 3) where each row is (x, y, x*y).
        """
        data=torch.zeros((self.order*self.order, 3), dtype=torch.int64)
        shuffled_indices = None
        for i in range(self.order):
            for j in range(self.order):
                data[i*self.order+j, 0] = i
                data[i*self.order+j, 1] = j
                data[i*self.order+j, 2] = self.multiplication_table[i, j]
        if shuffle_seed:
            torch.manual_seed(shuffle_seed) 
            shuffled_indices = torch.randperm(self.order*self.order)
            data = data[shuffled_indices]
        return data.cuda(), shuffled_indices

In [11]:
group = SymmetricGroup(5)

Computing multiplication table...


100%|██████████| 120/120 [00:14<00:00,  8.42it/s]


The model architecture is an MLP as described in the paper. We subclass the HookedRootModule class from the transformer_lens library to easily cache activations on the forward pass.

In [12]:
model_cfg = {    
    "layers": {
        "embed_dim": 256,
        "hidden_dim": 128
    }
}

class OneLayerMLP(HookedRootModule):
    """ 
    A one layer MLP. W_x and W_y are embedding layers, whose outputs are concatenated and fed into a hidden layer. The result is unembedded by W_U.
    """
    def __init__(self, layers, n, seed=0):
        # embed_dim: dimension of the embedding
        # hidden : hidden dimension size
        # n : group order
        super().__init__()
        torch.manual_seed(seed)

        self.embed_dim = layers['embed_dim']
        hidden = layers['hidden_dim']

        # xavier initialise parameters
        self.W_x = nn.Parameter(torch.randn(n, self.embed_dim)/np.sqrt(self.embed_dim))
        self.W_y = nn.Parameter(torch.randn(n, self.embed_dim)/np.sqrt(self.embed_dim))
        self.W = nn.Parameter(torch.randn(2*self.embed_dim, hidden)/np.sqrt(2*self.embed_dim))
        self.relu = nn.ReLU()
        self.W_U = nn.Parameter(torch.randn(hidden, n)/np.sqrt(hidden))

        # hookpoints
        self.embed_stack = HookPoint()
        self.hidden = HookPoint()

        # We need to call the setup function of HookedRootModule to build an 
        # internal dictionary of modules and hooks, and to give each hook a name
        super().setup()

    def forward(self, data):
        x = data[:, 0] # (batch)
        half_x_embed = self.W_x[x] # (batch, embed_dim)
        y = data[:, 1] # (batch)
        half_y_embed = self.W_y[y] # (batch, embed_dim)
        embed_stack = self.embed_stack(torch.hstack((half_x_embed, half_y_embed))) # (batch, 2*embed_dim)
        hidden = self.hidden(self.relu(embed_stack @ self.W)) # (batch, hidden)
        out = hidden @ self.W_U # (batch, n)
        return out

def loss_fn(logits, labels):
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()

def get_accuracy(logits, labels):
    """
    Compute accuracy of model.

    Args:
        logits (torch.tensor): (batch, group.order) tensor of logits
        labels (torch.tensor): (batch) tensor of labels

    Returns:
        float: accuracy
    """
    return ((logits.argmax(1)==labels).sum()/len(labels)).item()

model = OneLayerMLP(model_cfg['layers'], group.order).cuda()

To generate our train/test split, we just take a fixed proportion of the entire multiplication table of the group at random. We'll use a ratio of 0.4 here.

In [13]:
def generate_train_test_data(group, frac_train, seed=False):
    """
    Generate train and test data from a group's all data.
    """
    data, shuffled_indices = group.get_all_data(seed)
    train_size = int(frac_train*data.shape[0])
    train = data[:train_size]
    test = data[train_size:]
    train_data = train[:, :2]
    train_labels = train[:, 2]
    test_data = test[:, :2]
    test_labels = test[:, 2]
    return train_data, test_data, train_labels, test_labels, shuffled_indices



# Training

In [14]:
training_cfg = {
    "lr": 0.001,
    "betas": (0.9, 0.98),
    "weight_decay": 1,
    "num_epochs": 75000,
    "frac_train": 0.4,
}
train_data, test_data, train_labels, test_labels, shuffled_indices = generate_train_test_data(group, training_cfg["frac_train"], seed=SEED)
optimizer = torch.optim.AdamW(model.parameters(), lr=training_cfg["lr"], betas=training_cfg["betas"], weight_decay=training_cfg["weight_decay"])

In [15]:
print(train_data)

tensor([[ 65,  48],
        [ 85,  94],
        [104,  89],
        ...,
        [ 92,  13],
        [ 75,  94],
        [  7,  80]], device='cuda:0')


In [16]:
train_losses = []
test_losses = []
train_accs = []
test_accs = []

for epoch in tqdm(range(training_cfg["num_epochs"])):
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if epoch % 1000 == 0:
        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            train_acc = get_accuracy(train_logits, train_labels)
            test_acc = get_accuracy(test_logits, test_labels)
            print(f"Epoch {epoch}: Train Loss: {train_loss}, Test Loss: {test_loss}, Train Accuracy: {train_acc}, Test Accuracy: {test_acc}")
        train_losses.append(train_loss.item())
        test_losses.append(test_loss.item())
        train_accs.append(train_acc)
        test_accs.append(test_acc)





  0%|          | 50/75000 [00:00<09:03, 137.93it/s]

Epoch 0: Train Loss: 4.788640449718693, Test Loss: 4.789846160721121, Train Accuracy: 0.008854166604578495, Test Accuracy: 0.007986110635101795


  1%|▏         | 1084/75000 [00:02<02:32, 486.01it/s]

Epoch 1000: Train Loss: 0.0014230852654155432, Test Loss: 31.025889312314813, Train Accuracy: 1.0, Test Accuracy: 0.00023148147738538682


  3%|▎         | 2070/75000 [00:04<02:26, 497.34it/s]

Epoch 2000: Train Loss: 1.0371127345550383e-05, Test Loss: 42.32135782920235, Train Accuracy: 1.0, Test Accuracy: 0.00023148147738538682


  4%|▍         | 3090/75000 [00:06<02:23, 502.58it/s]

Epoch 3000: Train Loss: 9.28916789346865e-06, Test Loss: 39.74175323113012, Train Accuracy: 1.0, Test Accuracy: 0.000347222201526165


  5%|▌         | 4059/75000 [00:08<02:21, 502.92it/s]

Epoch 4000: Train Loss: 9.027984962200929e-06, Test Loss: 37.979951976766685, Train Accuracy: 1.0, Test Accuracy: 0.00046296295477077365


  7%|▋         | 5069/75000 [00:10<02:19, 500.19it/s]

Epoch 5000: Train Loss: 8.843711821515795e-06, Test Loss: 36.75291020285111, Train Accuracy: 1.0, Test Accuracy: 0.0005787037080153823


  8%|▊         | 6089/75000 [00:12<02:17, 502.97it/s]

Epoch 6000: Train Loss: 8.713580143243848e-06, Test Loss: 35.90106414693857, Train Accuracy: 1.0, Test Accuracy: 0.00069444440305233


  9%|▉         | 7058/75000 [00:14<02:15, 500.10it/s]

Epoch 7000: Train Loss: 8.605533626326864e-06, Test Loss: 35.235500056005996, Train Accuracy: 1.0, Test Accuracy: 0.00069444440305233


 11%|█         | 8078/75000 [00:16<02:13, 500.30it/s]

Epoch 8000: Train Loss: 8.513514004822296e-06, Test Loss: 34.65525806014785, Train Accuracy: 1.0, Test Accuracy: 0.00069444440305233


 12%|█▏        | 9098/75000 [00:18<02:11, 500.39it/s]

Epoch 9000: Train Loss: 8.430327062914028e-06, Test Loss: 34.144935524117926, Train Accuracy: 1.0, Test Accuracy: 0.0009259259095415473


 13%|█▎        | 10063/75000 [00:20<02:10, 498.73it/s]

Epoch 10000: Train Loss: 8.363853707070644e-06, Test Loss: 33.728746871927584, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 15%|█▍        | 11080/75000 [00:22<02:06, 503.32it/s]

Epoch 11000: Train Loss: 8.30144133616486e-06, Test Loss: 33.36465937707598, Train Accuracy: 1.0, Test Accuracy: 0.0009259259095415473


 16%|█▌        | 12100/75000 [00:24<02:05, 501.56it/s]

Epoch 12000: Train Loss: 8.237403655008239e-06, Test Loss: 33.042006361873064, Train Accuracy: 1.0, Test Accuracy: 0.0009259259095415473


 17%|█▋        | 13047/75000 [00:26<02:19, 444.23it/s]

Epoch 13000: Train Loss: 8.174784670199603e-06, Test Loss: 32.73591884208349, Train Accuracy: 1.0, Test Accuracy: 0.0009259259095415473


 19%|█▉        | 14090/75000 [00:29<02:15, 448.83it/s]

Epoch 14000: Train Loss: 8.126368153885452e-06, Test Loss: 32.454488264784196, Train Accuracy: 1.0, Test Accuracy: 0.0008101851562969387


 20%|██        | 15096/75000 [00:31<01:59, 501.91it/s]

Epoch 15000: Train Loss: 8.076891728273977e-06, Test Loss: 32.183694283909034, Train Accuracy: 1.0, Test Accuracy: 0.00069444440305233


 21%|██▏       | 16065/75000 [00:33<01:57, 502.41it/s]

Epoch 16000: Train Loss: 8.03142297004289e-06, Test Loss: 31.877598700000124, Train Accuracy: 1.0, Test Accuracy: 0.0008101851562969387


 23%|██▎       | 17085/75000 [00:35<01:55, 503.15it/s]

Epoch 17000: Train Loss: 7.985882949788269e-06, Test Loss: 31.588083174209576, Train Accuracy: 1.0, Test Accuracy: 0.0009259259095415473


 24%|██▍       | 18101/75000 [00:37<01:54, 499.01it/s]

Epoch 18000: Train Loss: 7.942705801747821e-06, Test Loss: 31.300450615126945, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 25%|██▌       | 19069/75000 [00:39<01:51, 499.60it/s]

Epoch 19000: Train Loss: 7.902100509987374e-06, Test Loss: 31.071757274844458, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 27%|██▋       | 20081/75000 [00:41<01:49, 503.00it/s]

Epoch 20000: Train Loss: 7.866609099229755e-06, Test Loss: 30.825448419846026, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 28%|██▊       | 21090/75000 [00:43<01:59, 451.56it/s]

Epoch 21000: Train Loss: 7.832928618637219e-06, Test Loss: 30.57406529479693, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 29%|██▉       | 22098/75000 [00:45<01:46, 498.11it/s]

Epoch 22000: Train Loss: 7.801011500145634e-06, Test Loss: 30.32559158391679, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 31%|███       | 23061/75000 [00:47<01:43, 499.64it/s]

Epoch 23000: Train Loss: 7.76612875167956e-06, Test Loss: 30.067768782116268, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 32%|███▏      | 24080/75000 [00:49<01:41, 500.74it/s]

Epoch 24000: Train Loss: 7.730398161195362e-06, Test Loss: 29.786146012719477, Train Accuracy: 1.0, Test Accuracy: 0.00138888880610466


 33%|███▎      | 25100/75000 [00:51<01:39, 500.87it/s]

Epoch 25000: Train Loss: 7.697242767052906e-06, Test Loss: 29.493918591238668, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 35%|███▍      | 26069/75000 [00:53<01:37, 500.54it/s]

Epoch 26000: Train Loss: 7.65714235115075e-06, Test Loss: 29.166109643467536, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 36%|███▌      | 27084/75000 [00:55<01:36, 498.40it/s]

Epoch 27000: Train Loss: 7.6153467795990345e-06, Test Loss: 28.831969611141492, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 37%|███▋      | 28097/75000 [00:57<01:33, 499.94it/s]

Epoch 28000: Train Loss: 7.566795523489502e-06, Test Loss: 28.46175553196621, Train Accuracy: 1.0, Test Accuracy: 0.001041666604578495


 39%|███▊      | 29059/75000 [00:59<01:31, 499.62it/s]

Epoch 29000: Train Loss: 7.524803465351525e-06, Test Loss: 28.06147039609432, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 40%|████      | 30078/75000 [01:01<01:29, 501.77it/s]

Epoch 30000: Train Loss: 7.492317728010532e-06, Test Loss: 27.74509806970608, Train Accuracy: 1.0, Test Accuracy: 0.0012731481110677123


 41%|████▏     | 31098/75000 [01:03<01:27, 502.28it/s]

Epoch 31000: Train Loss: 7.461456200378397e-06, Test Loss: 27.464479134413967, Train Accuracy: 1.0, Test Accuracy: 0.0011574074160307646


 43%|████▎     | 32067/75000 [01:05<01:25, 501.62it/s]

Epoch 32000: Train Loss: 7.425241405011563e-06, Test Loss: 27.2185777159789, Train Accuracy: 1.0, Test Accuracy: 0.0017361111240461469


 44%|████▍     | 33086/75000 [01:07<01:23, 501.74it/s]

Epoch 33000: Train Loss: 7.392358659557805e-06, Test Loss: 26.975435239425646, Train Accuracy: 1.0, Test Accuracy: 0.001967592630535364


 45%|████▌     | 34055/75000 [01:09<01:21, 502.13it/s]

Epoch 34000: Train Loss: 7.363923512458337e-06, Test Loss: 26.73794948917406, Train Accuracy: 1.0, Test Accuracy: 0.002314814832061529


 47%|████▋     | 35075/75000 [01:11<01:19, 502.54it/s]

Epoch 35000: Train Loss: 7.309714471480002e-06, Test Loss: 26.37225088949308, Train Accuracy: 1.0, Test Accuracy: 0.00277777761220932


 48%|████▊     | 36095/75000 [01:13<01:17, 500.96it/s]

Epoch 36000: Train Loss: 7.264856044341571e-06, Test Loss: 25.98035746159736, Train Accuracy: 1.0, Test Accuracy: 0.003935185261070728


 49%|████▉     | 37050/75000 [01:15<01:22, 458.46it/s]

Epoch 37000: Train Loss: 7.232683014535503e-06, Test Loss: 25.5770855410318, Train Accuracy: 1.0, Test Accuracy: 0.003935185261070728


 51%|█████     | 38050/75000 [01:17<01:23, 442.95it/s]

Epoch 38000: Train Loss: 7.188941963345562e-06, Test Loss: 25.172942494121965, Train Accuracy: 1.0, Test Accuracy: 0.00416666641831398


 52%|█████▏    | 39087/75000 [01:19<01:21, 443.18it/s]

Epoch 39000: Train Loss: 7.141047773126178e-06, Test Loss: 24.769493756876546, Train Accuracy: 1.0, Test Accuracy: 0.00486111082136631


 53%|█████▎    | 40084/75000 [01:22<01:17, 448.76it/s]

Epoch 40000: Train Loss: 7.092052074271086e-06, Test Loss: 24.308201922805644, Train Accuracy: 1.0, Test Accuracy: 0.004398148041218519


 55%|█████▍    | 41087/75000 [01:24<01:14, 452.50it/s]

Epoch 41000: Train Loss: 7.040348210478575e-06, Test Loss: 23.848055355144385, Train Accuracy: 1.0, Test Accuracy: 0.005208333022892475


 56%|█████▌    | 42087/75000 [01:26<01:14, 444.63it/s]

Epoch 42000: Train Loss: 6.976306592459694e-06, Test Loss: 23.37419637600757, Train Accuracy: 1.0, Test Accuracy: 0.005787036847323179


 57%|█████▋    | 43087/75000 [01:28<01:04, 495.85it/s]

Epoch 43000: Train Loss: 6.912835748887075e-06, Test Loss: 22.88790642861647, Train Accuracy: 1.0, Test Accuracy: 0.005671296268701553


 59%|█████▉    | 44098/75000 [01:30<01:01, 502.51it/s]

Epoch 44000: Train Loss: 6.861945306670228e-06, Test Loss: 22.451284704347913, Train Accuracy: 1.0, Test Accuracy: 0.00555555522441864


 60%|██████    | 45061/75000 [01:32<01:00, 491.52it/s]

Epoch 45000: Train Loss: 6.7342440650079595e-06, Test Loss: 21.829569588717305, Train Accuracy: 1.0, Test Accuracy: 0.006018518470227718


 61%|██████▏   | 46074/75000 [01:34<00:57, 500.80it/s]

Epoch 46000: Train Loss: 6.414317237825353e-06, Test Loss: 19.575365906353877, Train Accuracy: 1.0, Test Accuracy: 0.010532407090067863


 63%|██████▎   | 47085/75000 [01:36<01:02, 446.86it/s]

Epoch 47000: Train Loss: 6.343124420598653e-06, Test Loss: 18.95071475322166, Train Accuracy: 1.0, Test Accuracy: 0.011805555783212185


 64%|██████▍   | 48046/75000 [01:38<01:00, 442.91it/s]

Epoch 48000: Train Loss: 6.2874606028192515e-06, Test Loss: 18.567742582969146, Train Accuracy: 1.0, Test Accuracy: 0.012615740299224854


 65%|██████▌   | 49070/75000 [01:41<00:51, 499.52it/s]

Epoch 49000: Train Loss: 6.23843281761228e-06, Test Loss: 18.23249668670641, Train Accuracy: 1.0, Test Accuracy: 0.01354166679084301


 67%|██████▋   | 50086/75000 [01:43<00:50, 496.21it/s]

Epoch 50000: Train Loss: 6.193474432257247e-06, Test Loss: 17.922258976232744, Train Accuracy: 1.0, Test Accuracy: 0.01597222127020359


 68%|██████▊   | 51054/75000 [01:45<00:47, 502.19it/s]

Epoch 51000: Train Loss: 6.153174938434978e-06, Test Loss: 17.620777015700728, Train Accuracy: 1.0, Test Accuracy: 0.016550926491618156


 69%|██████▉   | 52074/75000 [01:47<00:45, 502.47it/s]

Epoch 52000: Train Loss: 6.100840532322743e-06, Test Loss: 17.280815788051274, Train Accuracy: 1.0, Test Accuracy: 0.018171295523643494


 71%|███████   | 53094/75000 [01:49<00:43, 501.63it/s]

Epoch 53000: Train Loss: 6.063849291620353e-06, Test Loss: 16.94379653918741, Train Accuracy: 1.0, Test Accuracy: 0.019212963059544563


 72%|███████▏  | 54050/75000 [01:51<00:46, 449.78it/s]

Epoch 54000: Train Loss: 6.024758016045179e-06, Test Loss: 16.6183743006699, Train Accuracy: 1.0, Test Accuracy: 0.019907407462596893


 73%|███████▎  | 55053/75000 [01:53<00:44, 446.26it/s]

Epoch 55000: Train Loss: 5.985341697855243e-06, Test Loss: 16.296648692931694, Train Accuracy: 1.0, Test Accuracy: 0.021643517538905144


 75%|███████▍  | 56054/75000 [01:55<00:42, 443.35it/s]

Epoch 56000: Train Loss: 5.9418252786658535e-06, Test Loss: 15.939173890708515, Train Accuracy: 1.0, Test Accuracy: 0.025462962687015533


 76%|███████▌  | 57063/75000 [01:57<00:40, 447.18it/s]

Epoch 57000: Train Loss: 5.890394850185071e-06, Test Loss: 15.54493618806889, Train Accuracy: 1.0, Test Accuracy: 0.027199072763323784


 77%|███████▋  | 58061/75000 [02:00<00:34, 496.51it/s]

Epoch 58000: Train Loss: 5.826642916973531e-06, Test Loss: 15.037411139574886, Train Accuracy: 1.0, Test Accuracy: 0.0295138880610466


 79%|███████▉  | 59080/75000 [02:02<00:31, 501.77it/s]

Epoch 59000: Train Loss: 5.732879013082719e-06, Test Loss: 14.254321682879759, Train Accuracy: 1.0, Test Accuracy: 0.03680555522441864


 80%|████████  | 60060/75000 [02:04<00:32, 453.25it/s]

Epoch 60000: Train Loss: 5.627360697916806e-06, Test Loss: 13.368103136708891, Train Accuracy: 1.0, Test Accuracy: 0.04398148134350777


 81%|████████▏ | 61061/75000 [02:06<00:28, 494.73it/s]

Epoch 61000: Train Loss: 5.502662342646345e-06, Test Loss: 12.187907188945404, Train Accuracy: 1.0, Test Accuracy: 0.05474536865949631


 83%|████████▎ | 62064/75000 [02:08<00:26, 496.20it/s]

Epoch 62000: Train Loss: 5.360685661622294e-06, Test Loss: 11.001297270589832, Train Accuracy: 1.0, Test Accuracy: 0.06851851940155029


 84%|████████▍ | 63060/75000 [02:10<00:24, 492.89it/s]

Epoch 63000: Train Loss: 5.121443327639741e-06, Test Loss: 9.225520989872802, Train Accuracy: 1.0, Test Accuracy: 0.10150463134050369


 85%|████████▌ | 64060/75000 [02:12<00:22, 491.33it/s]

Epoch 64000: Train Loss: 4.88577493173858e-06, Test Loss: 7.311357697622685, Train Accuracy: 1.0, Test Accuracy: 0.16006943583488464


 87%|████████▋ | 65060/75000 [02:14<00:20, 477.03it/s]

Epoch 65000: Train Loss: 4.465699717121323e-06, Test Loss: 4.852745334415711, Train Accuracy: 1.0, Test Accuracy: 0.27893519401550293


 88%|████████▊ | 66068/75000 [02:16<00:19, 452.19it/s]

Epoch 66000: Train Loss: 4.013236719781686e-06, Test Loss: 2.364037010628906, Train Accuracy: 1.0, Test Accuracy: 0.5271990895271301


 89%|████████▉ | 67074/75000 [02:18<00:17, 445.44it/s]

Epoch 67000: Train Loss: 3.4904340042329537e-06, Test Loss: 0.6758065565089949, Train Accuracy: 1.0, Test Accuracy: 0.849189817905426


 91%|█████████ | 68072/75000 [02:21<00:15, 441.23it/s]

Epoch 68000: Train Loss: 2.9931952284852097e-06, Test Loss: 0.2437987763206815, Train Accuracy: 1.0, Test Accuracy: 0.9653934836387634


 92%|█████████▏| 69074/75000 [02:23<00:13, 449.54it/s]

Epoch 69000: Train Loss: 2.5913068533276957e-06, Test Loss: 0.19307809618841817, Train Accuracy: 1.0, Test Accuracy: 0.9702546000480652


 93%|█████████▎| 70080/75000 [02:25<00:10, 449.08it/s]

Epoch 70000: Train Loss: 2.356193168237708e-06, Test Loss: 0.0756523948424468, Train Accuracy: 1.0, Test Accuracy: 0.9777777791023254


 95%|█████████▍| 71065/75000 [02:27<00:07, 493.49it/s]

Epoch 71000: Train Loss: 2.071160523341473e-06, Test Loss: 3.5184810063978113e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 96%|█████████▌| 72083/75000 [02:29<00:05, 496.53it/s]

Epoch 72000: Train Loss: 2.085427719825188e-06, Test Loss: 2.8124387770377445e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 97%|█████████▋| 73099/75000 [02:31<00:03, 501.17it/s]

Epoch 73000: Train Loss: 2.1019317401172256e-06, Test Loss: 2.754367274920337e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


 99%|█████████▉| 74068/75000 [02:33<00:01, 503.00it/s]

Epoch 74000: Train Loss: 2.0976692827139747e-06, Test Loss: 2.751178487840493e-06, Train Accuracy: 1.0, Test Accuracy: 1.0


100%|██████████| 75000/75000 [02:35<00:00, 482.26it/s]


Let's plot loss and accuracy. We see the model groks! It initially memorises the training data, and much later on (around epoch 70k) suddenly generalises to the test set.

In [17]:
x = list(range(0, training_cfg["num_epochs"], 1000))
lines([train_losses, test_losses], x=x, labels=["Train Loss", "Test Loss"], xaxis="Epoch", yaxis="Loss", title="Loss vs Epoch", log_y=True)
lines([train_accs, test_accs], x=x, labels=["Train Accuracy", "Test Accuracy"], xaxis="Epoch", yaxis="Accuracy", title="Accuracy vs Epoch")

# Interpretability

Now let's try to interpret the model. To do so, we need to know some things about representation theory. Let's first create a class to hold information about representations. It will share some attributes with the group class. We will subclass from this class to create a class for each representation.

What information do we need, and how do we store it? 

Let $n$ be the order of group $G$. Each representation is a set $n$ $d \times d$ matrices. We store this in our representation objects as tensors of shape $n \times d^2$. 

We also compute what we call an 'orthogonal representation'. I'll explain why this is necessary later on.

Finally we compute the characters $\chi(abc^{-1})$ for each pair of inputs $a,b$ and output $c$. We store this in a tensor of shape $n \times n \times n$. This character tensor is used later to compute logit similarity.

## Defining Representations

In [20]:
class Representation():
    """
    Base class for all representations.
    """

    def __init__(self, compute_rep_params, index, order, multiplication_table, inverses, all_data, group_acronym, irrep=True):
        """
        Initialise the symmetric group representation.

        Args:
            compute_rep_params (tuple): representation specific parameters required for computing the representation
            index (int): group index in family
            order (int): order of the group
            multiplication_table (torch.tensor): square (group.order, group.order) tensor of group multiplication table 
            inverses (torch.tensor): vector of group inverses
            all_data ()
            irrep (Boolean)
        """

        self.index = index
        self.order = order
        self.multiplication_table = multiplication_table
        self.inverses = inverses
        self.all_data = all_data
        self.group_acronym = group_acronym

        # TODO: this is needed to get the dimension of generated representations - think up a better way of doing this
        self.compute_rep_params = compute_rep_params

        self.dim = self.get_rep_dim()

        self.rep = self.compute_rep(*compute_rep_params)

        if irrep:
            self.orth_rep = self.compute_orth_rep(self.rep)
            if self.friendly_name != 'trivial':
                self.logit_trace_tensor_cube = self.compute_logit_trace_tensor_cube()


    def get_rep_dim(self):
        return NotImplementedError

    def compute_rep(self):
        """
        Compute the representation. Must be implemented by child class.

        Raises:
            NotImplementedError
        """
        raise NotImplementedError
    
    def compute_orth_rep(self, rep):
        """
        Use QR decomposition to orthogonalise the representation but retain the subspace spanned by the columns. 

        Args:
            rep (torch.tensor): (group.order, dim^2) tensor of representation

        Returns:
            torch.tensor: (group.order, dim^2) tensor with orthonormal columns en
        """
        
        orth_rep = rep.reshape(self.order, self.dim * self.dim)
        orth_rep = torch.linalg.qr(orth_rep)[0]
        return orth_rep


    def compute_logit_trace_tensor_cube(self):
        """
        Under the hypothesis, the network computes tr(\rho(x)\rho(y)\rho(z^-1)) for some representation \rho.
        This function computes this trace tensor cube for a given representation
        
        Returns:
            torch.tensor: (group.order^3) trace tensor cube
        """
        print(f'Computing trace tensor cube for {self.friendly_name} representation')
        filename = f'utils/cache/{self.group_acronym}{self.index}/{self.group_acronym}{self.index}_{self.friendly_name}_trace_tensor_cube.pt'
        if os.path.exists(filename):
            print('... loading from file')
            t = torch.load(filename)
            return t 
        N = self.all_data.shape[0]
        t = torch.zeros((self.order*self.order, self.order), dtype=torch.float).cuda()
        for i in tqdm(range(N)):
            x = self.all_data[i, 0]
            y = self.all_data[i, 1]
            xy = self.multiplication_table[x, y]
            for z_idx in range(self.order):
                xyz = self.multiplication_table[xy, self.inverses[z_idx]]
                t[i, z_idx] = torch.trace(self.rep[xyz])
        t = t.reshape(self.order, self.order, self.order)
        f = open(filename, 'wb')
        torch.save(t, f)
        return t 

Now we construct all the irreducible representations of S5. One can find a full list [here](https://groupprops.subwiki.org/wiki/Linear_representation_theory_of_symmetric_group:S5). Feel free to skip over this construction if you're not interested in the details.

First, the boring trivial reprsentation. This is one dimensional - that is $d=1$, and all the elements map to the matrix $[1]$. We pass information needed to compute and initialise the representation from the child class to the parent class via `compute_rep_params` and `init_rep_params`.

### Constructing Irreps of S5


In [21]:
class TrivialRepresentation(Representation):
    """
    The trivial representation of the symmetric group.
    """
    def __init__(self, compute_rep_params, init_rep_params):
        """
        Initialise the trivial representation. 

        Args:
            compute_rep_params (list): idx_to_perm function required to compute the representation
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
        """
        self.friendly_name = 'trivial'
        super().__init__(compute_rep_params, **init_rep_params, irrep=True)
    
    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return 1

    def compute_rep(self):
        """
        Compute the trivial representation.

        Args:
            idx_to_perm (function): function to convert an index to a permutation

        Returns:
            torch.tensor: (group.order, 1) tensor of representation
        """
        return torch.ones((self.order, 1, 1), dtype=torch.float).cuda()
    

# we need some data from the group to initialise representations
rep_params = {
            'index': group.index,
            'order': group.order,
            'multiplication_table': group.multiplication_table,
            'inverses': group.inverses,
            'all_data': group.all_data,
            'group_acronym': group.acronym
        }

trivial_rep = TrivialRepresentation(compute_rep_params=[], init_rep_params=rep_params)

Next, the sign representation. This is also 1 dimensional, but the elements map to $[1]$ or $[-1]$ depending on whether the permutation is even or odd.

In [22]:
class SignRepresentation(Representation):
    """
    Initialise the sign representation of the symmetric group.

    """
    def __init__(self, compute_rep_params, init_rep_params):
        """
        Initialise the sign representation. 

        Args:
            compute_rep_params (list): list consisting of the signatures object required to compute the representation
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
        """
        self.friendly_name = 'sign'
        super().__init__(compute_rep_params, **init_rep_params)

    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return 1

    def compute_rep(self, signatures):
        """
        Compute the sign representation from the signatures.

        Args:
            signatures (torch.tensor): (group.order, 1) tensor of signatures

        Returns:
            torch.tensor: (group.order, 1, 1) tensor of sign representations
        """
        rep = torch.zeros(self.order, 1, 1).cuda()
        rep[:, 0, 0] = signatures
        return rep

sign_rep = SignRepresentation([group.signatures], rep_params)


Computing trace tensor cube for sign representation
... loading from file


Next, we construct the natural representation. This isn't actually an irreducible representation, but is easy to construct. We can later decompose it into a direct sum of the trivial representation and the `standard` representation. The natural representation itself is $d=5$, with group elements mapping naturally to the 5x5 [permutation matrices](https://en.wikipedia.org/wiki/Permutation_matrix) (with a 1 in each row and column, and 0s elsewhere).

In [23]:
class NaturalRepresentation(Representation):
    """
    Compute the natural representation of the symmetric group.
    """
    def __init__(self, compute_rep_params, init_rep_params):
        """
        Initialise the natural representation. 

        Args:
            compute_rep_params (list): idx_to_perm function required to compute the representation
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
        """
        self.friendly_name = 'natural'
        super().__init__(compute_rep_params, **init_rep_params, irrep=False)
    
    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return self.index

    def compute_rep(self, idx_to_perm):
        """
        Compute the natural representation by directly computing permutation matrices

        Args:
            idx_to_perm (function): Function that takes an index and returns the corresponding permutation object

        Returns:
            torch.tensor: (group.order, group.index, group.index) tensor of permutation matrices for each group element
        """
        idx = list(np.linspace(0, self.index-1, self.index))
        rep = torch.zeros(self.order, self.index, self.index).cuda()
        for x in range(self.order):
            rep[x, idx, idx_to_perm(x)(idx)] = 1
        return rep

natural_rep = NaturalRepresentation([group.idx_to_perm], rep_params)


an integer is required (got type numpy.float64).  Implicit conversion to integers using __int__ is deprecated, and may be removed in a future version of Python.



In [24]:
class StandardRepresentation(Representation):
    """
    Generate the standard representation of the symmetric group.

    """
    def __init__(self, compute_rep_params, init_rep_params):
        """
        Initialise the standard representation. 

        Args:
            compute_rep_params (list): list containing natural_reps object necessary to calculate the standard representation
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
        """
        self.friendly_name = 'standard'
        super().__init__(compute_rep_params, **init_rep_params)

    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return self.index-1
    
    def compute_rep(self, natural_reps):
        """
        Compute the standard representation from the natural representation.

        Args:
            natural_reps (torch.tensor): (group.order, group.index, group.index) tensor of natural representations (permutation matrices)

        Returns:
            torch.tensor: (group.order, group.index-1, group.index-1) tensor of standard representations
        """
        rep = []
        basis_transform = torch.zeros(self.index, self.index).cuda()
        for i in range(self.index-1):
            basis_transform[i, i] = 1
            basis_transform[i, i+1] = -1
        basis_transform[self.index-1, self.index-1] = 1 #to make the transform non singular
        for x in natural_reps:
            temp = basis_transform @ x @ basis_transform.inverse()
            rep.append(temp[:self.index-1, :self.index-1])
        rep = torch.stack(rep, dim=0).cuda()
        return rep  

standard_rep = StandardRepresentation([natural_rep.rep], rep_params)


Computing trace tensor cube for standard representation
... loading from file


In [25]:
class StandardSignRepresentation(Representation):
    def __init__(self, compute_rep_params, init_rep_params):
        """
        Initialise the tensor product of the standard and sign representation. 

        Args:
            compute_rep_params (list): list consisting of tensor of standard representation and signatures, required to compute the representation
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
        """
        self.friendly_name = 'standard_sign'
        super().__init__(compute_rep_params, **init_rep_params)

    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return self.index-1

    def compute_rep(self, standard_reps, signatures):
        """
        Compute the tensor product of the standard and sign representation.

        Args:
            standard_reps (torch.tensor): (group.order, group.index-1, group.index-1) tensor of standard representations
            signatures (torch.tensor): (group.order, 1) tensor of signatures

        Returns:
            torch.tensor: (group.order, group.index-1, group.index-1) tensor of standard_sign representations
        """
        rep = []
        for i in range(standard_reps.shape[0]):
            rep.append(signatures[i]*standard_reps[i])
        rep = torch.stack(rep, dim=0).cuda()
        return rep
    
standard_sign_rep = StandardSignRepresentation([standard_rep.rep, sign_rep.rep], rep_params)      

Computing trace tensor cube for standard_sign representation
... loading from file


In [26]:
class SymmetricRepresentationFromGenerators(Representation):
    def __init__(self, compute_rep_params, init_rep_params, name):
        """
        Initialise a representation of the symmetric group from a set of generators.

        Args:
            compute_rep_params (list): list consisting of the generators of the representation, sympy group object, and a function that maps indices to permutations
            init_rep_params (dict): standard group parameters needed by the representation, including index, order, multiplication_table, inverses, all_data
            name (str): name of the representation
        """
        self.friendly_name = name
        super().__init__(compute_rep_params, **init_rep_params)

    # TODO: make this less hacky
    def get_rep_dim(self):
        """
        Get the dimension of the representation.

        Returns:
            int: dimension of the representation
        """
        return list(self.compute_rep_params[0].values())[0].shape[0] # hacky way to get the dimension of the representation

    def compute_rep(self, generators, G, idx_to_perm):
        """
        Compute the representation from the generators.

        Args:
            generators (dict): dictionary of generators of the group along with their representations
            G (sympy group object): group object
            idx_to_perm (function): function that maps indices to permutations

        Returns:
            torch.tensor: (group.order, dim, dim) tensor of arbitrary representations
        """
        rep = torch.zeros(self.order, self.dim, self.dim).cuda()
        for i in range(self.order):
            generator_product = G.generator_product(idx_to_perm(i), original=True)
            result = torch.eye(self.dim).float()
            for g in generator_product:
                result = result @ generators[g]
            rep[i] = result
        return rep.cuda()


s5_5d_a_generators = {}
s5_5d_a_generators[Permutation(0, 1, 2, 3, 4)] = torch.tensor([
    [ 1, -1, -1,  1,  0],
    [ 0, -1, -1,  0,  1],
    [ 1, -1,  0,  0,  0],
    [ 0, -1,  0,  0,  0],
    [ 1, -1, -1,  0,  0]
]).float()
s5_5d_a_generators[Permutation(4, 3, 2, 1, 0)] = s5_5d_a_generators[Permutation(0, 1, 2, 3, 4)].inverse()
s5_5d_a_generators[Permutation(4)(0,1)] = torch.tensor([
    [ 0,  0, -1,  0,  0],
    [ 0,  0,  0, -1,  0],
    [-1,  0,  0,  0,  0],
    [ 0, -1,  0,  0,  0],
    [ 0,  0,  0,  0, -1]
]).float()
s5_5d_a_rep = SymmetricRepresentationFromGenerators([s5_5d_a_generators, group.G, group.idx_to_perm], rep_params, 's5_5d_a')

# 3,2 specht
s5_5d_b_generators = {}
s5_5d_b_generators[Permutation(0, 1, 2, 3, 4)] = torch.tensor([
    [-1,  1, -1,  0,  0],
    [ 0,  0,  0,  1, -1],
    [ 0,  0,  1,  0, -1],
    [ 1,  0,  0,  0,  0],
    [ 1,  0,  1,  0,  0]
]).float()
s5_5d_b_generators[Permutation(4, 3, 2, 1, 0)] = s5_5d_b_generators[Permutation(0, 1, 2, 3, 4)].inverse()
s5_5d_b_generators[Permutation(4)(0,1)] = torch.tensor([
    [ 1,  0,  0,  0,  0],
    [ 0,  0,  0, -1,  0],
    [-1,  0,  0,  0, -1],
    [ 0, -1,  0,  0,  0],
    [-1,  0, -1,  0,  0]
]).float()
s5_5d_b_rep = SymmetricRepresentationFromGenerators([s5_5d_b_generators, group.G, group.idx_to_perm], rep_params, 's5_5d_b')

# 3,1,1 specht
s5_6d_generators = {}
s5_6d_generators[Permutation(0, 1, 2, 3, 4)] = torch.tensor([
    [ 1, -1,  1,  0,  0,  0],
    [ 1,  0,  0, -1,  1,  0],
    [ 0,  1,  0, -1,  0,  1],
    [ 1,  0,  0,  0,  0,  0],
    [ 0,  1,  0,  0,  0,  0],
    [ 0,  0,  0,  1,  0,  0]
]).float()
s5_6d_generators[Permutation(4, 3, 2, 1, 0)] = s5_6d_generators[Permutation(0, 1, 2, 3, 4)].inverse()
s5_6d_generators[Permutation(4)(0,1)] = torch.tensor([
    [-1,  0,  0,  0,  0,  0],
    [ 0,  0,  0, -1,  0,  0],
    [ 0,  0,  0,  0, -1,  0],
    [ 0, -1,  0,  0,  0,  0],
    [ 0,  0, -1,  0,  0,  0],
    [ 0,  0,  0,  0,  0,  1]
]).float()
s5_6d_rep = SymmetricRepresentationFromGenerators([s5_6d_generators, group.G, group.idx_to_perm], rep_params, 's5_6d')


Computing trace tensor cube for s5_5d_a representation
... loading from file
Computing trace tensor cube for s5_5d_b representation
... loading from file
Computing trace tensor cube for s5_6d representation
... loading from file


In [27]:
group.irreps = {
    'trivial': trivial_rep,
    'sign': sign_rep,
    'standard': standard_rep,
    'standard_sign': standard_sign_rep,
    's5_5d_a': s5_5d_a_rep,
    's5_5d_b': s5_5d_b_rep,
    's5_6d': s5_6d_rep
}

Now we have all our representations loaded, let's inspect them...