# Setup

In [1]:
!nvidia-smi 

/bin/bash: line 1: nvidia-smi: command not found


## Imports

In [2]:
!pip install git+https://github.com/neelnanda-io/Easy-Transformer.git

Defaulting to user installation because normal site-packages is not writeable
Collecting git+https://github.com/neelnanda-io/Easy-Transformer.git
  Cloning https://github.com/neelnanda-io/Easy-Transformer.git to /tmp/pip-req-build-_jftmgft
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/Easy-Transformer.git /tmp/pip-req-build-_jftmgft
  Resolved https://github.com/neelnanda-io/Easy-Transformer.git to commit fdd4fa591bc10b428c1154ea16d8577a9f909407
  Preparing metadata (setup.py) ... [?25ldone
You should consider upgrading via the '/usr/local/bin/python -m pip install --upgrade pip' command.[0m[33m
[0m

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np


# plotting
from functools import partial
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
pio.renderers.default = "vscode"

# EasyTransformer interpretability tooling
from easy_transformer.hook_points import HookedRootModule, HookPoint 

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
if torch.cuda.is_available:
  print('Good to go!')
else:
  print('Training might be rather slow')

Good to go!


## Helper Functions 

In [5]:
#Plotting functions
# This is mostly a bunch of over-engineered mess to hack Plotly into producing 
# the pretty pictures I want, I recommend not reading too closely unless you 
# want Plotly hacking practice
def to_numpy(tensor, flat=False):
    if type(tensor)!=torch.Tensor:
        return tensor
    if flat:
        return tensor.flatten().detach().cpu().numpy()
    else:
        return tensor.detach().cpu().numpy()

def imshow(tensor, xaxis=None, yaxis=None, animation_name='Snapshot', **kwargs):
    # if tensor.shape[0]==p*p:
    #     tensor = unflatten_first(tensor)
    tensor = torch.squeeze(tensor)
    px.imshow(to_numpy(tensor, flat=False), 
              labels={'x':xaxis, 'y':yaxis, 'animation_name':animation_name}, 
              **kwargs).show()

# Set default colour scheme
imshow_pos = partial(imshow, color_continuous_scale='Blues')
# Creates good defaults for showing divergent colour scales (ie with both 
# positive and negative values, where 0 is white)

imshow = partial(imshow, color_continuous_scale='RdBu', color_continuous_midpoint=0.0)
# Presets a bunch of defaults to imshow to make it suitable for showing heatmaps 
# of activations with x axis being input 1 and y axis being input 2.

inputs_heatmap = partial(imshow, xaxis='Input 1', yaxis='Input 2', color_continuous_scale='RdBu', color_continuous_midpoint=0.0)

def line(x, y=None, hover=None, xaxis='', yaxis='', **kwargs):
    if type(y)==torch.Tensor:
        y = to_numpy(y, flat=True)
    if type(x)==torch.Tensor:
        x=to_numpy(x, flat=True)
    fig = px.line(x, y=y, hover_name=hover, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    fig.show()

def scatter(x, y, **kwargs):
    px.scatter(x=to_numpy(x, flat=True), y=to_numpy(y, flat=True), **kwargs).show()

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

## Define Groups

### Define Group Parent Class

In [6]:
class Group:
    """
    parent class for all groups
    """
    def __init__(self):
        raise NotImplementedError

    def compute_multiplication_table():
        raise NotImplementedError

    def compose(self, x, y):
        raise NotImplementedError

    def inverse(self, x):
        raise NotImplementedError
    


### Define Individual Groups

In [7]:

class CyclicGroup(Group):
    def __init__(self, index):
        self.index = index
        self.order = index
        self.compute_multiplication_table()


    def compose(self, x, y):
        return (x+y)%self.order

    def inverse(self, x):
        return -x%self.order


    def compute_multiplication_table(self):
        table = torch.zeros((self.order, self.order), dtype=torch.int64)
        for i in range(self.order):
            for j in range(self.order):
                table[i, j] = self.compose(i, j)
        self.multiplication_table = table

    def get_all_data(self, shuffle_seed = False):
        data=torch.zeros((self.order*self.order, 3), dtype=torch.int64)
        for i in range(self.order):
            for j in range(self.order):
                data[i*self.order+j, 0] = i
                data[i*self.order+j, 1] = j
                data[i*self.order+j, 2] = self.multiplication_table[i, j]
        if shuffle_seed:
            torch.manual_seed(shuffle_seed)
            shuffled_indices = torch.randperm(self.order*self.order)
            shuffled_arr = data[shuffled_indices]
        return data
        

#class DihedralGroup(Group):


#class SymettricGroup(Group):

## Define Model

In [8]:
class BilinearNet(HookedRootModule):
    """
    A completely linear network. W1a and W1b are embedding layers, whose outputs are elementwise multiplied. The result is unembedded by W2.
    """
    def __init__(self, hidden, n, seed=0):
        # hidden : hidden dimension size
        # n : group order
        super().__init__()
        torch.manual_seed(seed)

        # initialise parameters
        self.W1a = nn.Parameter(torch.randn(n, hidden)/np.sqrt(hidden))
        self.W1b = nn.Parameter(torch.randn(n, hidden)/np.sqrt(hidden))
        self.Wfinal = nn.Parameter(torch.randn(hidden, n)/np.sqrt(hidden))

        self.x_embed = HookPoint()
        self.y_embed = HookPoint()
        self.product = HookPoint()
        self.out = HookPoint()
        
        # We need to call the setup function of HookedRootModule to build an 
        # internal dictionary of modules and hooks, and to give each hook a name
        super().setup()

    def forward(self, data):
        x = data[:, 0] # (batch) 
        x_embed = self.x_embed(self.W1a[x]) # (batch, hidden)
        y = data[:, 1]
        y_embed = self.y_embed(self.W1b[y]) # (batch, hidden)
        product = self.product(x_embed * y_embed) # (batch, hidden)
        out = self.out(product @ self.Wfinal) #(batch, n)
        return out

class TwoLayerReLUNet(nn.Module):
    def __init__(self, hiddens, n, seed=0):
        # hidden : hidden dimension size
        # n : group order
        embed_dim, hidden = hiddens
        super().__init__()
        torch.manual_seed(seed)

        # xavier initialise parameters
        self.W1a = nn.Parameter(torch.randn(n, embed_dim)/np.sqrt(embed_dim))
        self.W1b = nn.Parameter(torch.randn(n, embed_dim)/np.sqrt(embed_dim))
        self.W2 = nn.Parameter(torch.randn(2*embed_dim, hidden)/np.sqrt(2*embed_dim))
        self.relu = nn.ReLU()
        self.Wfinal = nn.Parameter(torch.randn(hidden, n)/np.sqrt(hidden))
        self.hid = 0 #to access later

    def forward(self, data):
        x = data[:, 0] # (N, ) N is batch size
        rho_x = self.W1a[x] # (N, embed_dim)
        y = data[:, 1]
        rho_y = self.W1b[y] # (N, embed_dim)
        hid1 = torch.hstack((rho_x, rho_y)) # (N, 2*embed_dim)
        self.hid = self.relu(hid1 @ self.W2) # (N, hidden)
        return self.hid @ self.Wfinal #(N, n)

## Generate Data and Loss Functions 

In [9]:
def generate_train_test_data(group, frac_train):
    data = group.get_all_data().cuda()
    train_size = int(frac_train*data.shape[0])
    train = data[:train_size]
    test = data[train_size:]
    train_data = train[:, :2]
    train_labels = train[:, 2]
    test_data = test[:, :2]
    test_labels = test[:, 2]
    return train_data, test_data, train_labels, test_labels

def loss_fn(logits, labels):
    loss = F.cross_entropy(logits, labels)
    return loss

## Model Training


In [10]:
frac_train = 0.3
width = 256
lr = 1e-3
weight_decay = 0.3
num_epochs = 150000

group = CyclicGroup(131)
train_data, test_data, train_labels, test_labels = generate_train_test_data(group, frac_train)

train_losses = []
test_losses = []
train_accs = []
test_accs = []

model = BilinearNet(width, group.order)
model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

for epoch in range(num_epochs):
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits, train_labels)
    train_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    train_losses.append(train_loss.item())
    with torch.inference_mode():
        test_logits = model(test_data)
        test_loss = loss_fn(test_logits, test_labels)
        test_losses.append(test_loss.item())
        train_acc = (train_logits.argmax(1)==train_labels).sum()/len(train_labels)
        test_acc = (test_logits.argmax(1)==test_labels).sum()/len(test_labels)
        train_accs.append(train_acc.item())
        test_accs.append(test_acc.item())
    if epoch%1000 == 0:
        print(f"Epoch:{epoch}")
        print(f"Train: L: {train_losses[-1]:.6f} A: {train_accs[-1]*100:.4f}%")
        print(f"Test: L: {test_losses[-1]:.6f} A: {test_accs[-1]*100:.4f}%")
    if epoch%10000 == 0 and epoch>0:
        lines([train_losses, test_losses], log_y=True, labels=['train loss', 'test loss'])
        lines([train_accs, test_accs], log_y=False, labels=['train acc', 'test acc'])


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

# Interpretability


## Interpretability Set Up 

## Interpretability Helper Functions

## Look at embeddings

## Look at cosine similarity to representations

### Over the whole model

### Over individual neurons