In [None]:
import numpy as np
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("pgf")
import matplotlib.pyplot as plt

from requests import get
import zipfile, io
import os
import random
import math
import pickle
from IPython import display
%matplotlib inline

from model import make_model
from dataset_generator import dataset_generator
from utilities import generate_head_layer_ablations, svd, PCA_plot, plot_attention_patterns

In [None]:
device = 'cpu'

# Several later cells draw random evaluation subsets with torch.randint.
# Seed every RNG imported by this notebook so that Restart & Run All
# reproduces the same subsets (and therefore the same tables and figures).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Figure styling: LaTeX text rendering, with pdflatex configured for the pgf backend.
# NOTE(review): "Helvetica" is a sans-serif face but is listed under "font.serif" — confirm this is intended.
plt.rcParams.update({
        "pgf.texsystem": "pdflatex",
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": "Helvetica",
        'pgf.rcfonts': False,
    })

# NOTE(review): the style sheet is applied after the rcParams update above, so any
# key it defines takes precedence over the values set there — confirm the ordering.
plt.style.use('seaborn-v0_8-ticks')

matplotlib.rcParams.update({'font.size': 18})

# plt.locator_params(nbins=4)

# Decoder only models analysis

<font size = '3'> Below is how we specify the different kinds of ablations. The specification is somewhat cumbersome, but the comments in the next cell explain each setting. </font>

In [None]:
### Ablation specification
#
# Three groups of settings (attention, FFN, decoder) are collected into `ablations`
# at the bottom of this cell.

## Attention
#
# generate_head_layer_ablations produces the set of all possible combinations of heads
# that can be ablated. Pick one of them for `ab_head`, then list in `n_ab_head` the
# layer(s) that ablation should be applied to.
# `ab_headrow` works in conjunction with `M_apply`: when `M_apply` is True, `ab_headrow`
# is used and a specific row of the attention pattern is ablated. Here too, `n_ab_head`
# determines which layer(s) the ablation is applied to.

M_apply = False  # Whether to apply row-wise ablation; if True, n_ab_head selects the layers it applies to.
ab_headrow = []  # Which row(s) within a head to zero-ablate.
ab_head = [[1, 1] for _ in range(2)]  # (n, heads) mask: 0 marks a head to ablate, every other entry is 1.
n_ab_head = []  # Layer(s) the zero-ablation is applied to.

ablation_attention = [M_apply, n_ab_head, ab_head, ab_headrow]

## FFN

List_Neurons = []  # Neurons in the final layer to ablate.
ablation_ffn = List_Neurons

## Decoder

n_ab_att = []  # Layer(s) whose entire attention is ablated.
n_ab_ffn = []  # Layer(s) whose FFN is ablated.
ablation_decoder = [n_ab_att, n_ab_ffn]

ablations = [ablation_attention, ablation_ffn, ablation_decoder]

Data generation: the outputs are the input data, the target data, and the vocabulary.

In [None]:
# Generate the dataset: input data (data_f), target data (target_f) and the vocabulary (stoi).
data_f, target_f, stoi = dataset_generator(P_f = 1000)

Load the model with specified layers (n), train/test split (s) and weight decay (w).

In [None]:
# Run identifiers: number of layers (n), train/test split (s) and weight decay (w).
n = 2
s = 0.3
w = 0.2

# Architecture hyper-parameters passed to make_model.
vocab = 12
d_model = 128
d_ff = 128
heads = 2
dropout = 0.1

# The checkpoint path is '<directory><mdd>'; both parts are derived from (n, s, w).
directory = 'n{!s}_s{!s}_w{!s}/'.format(n, s, w)
mdd = 'model_n{!s}_s{!s}_w{!s}'.format(n, s, w)
toLoad = directory + mdd

# Build the network with the current ablation specification, restore the weights
# stored under the "model" key of the checkpoint, and switch to evaluation mode.
model = make_model(
    vocab,
    N=n,
    d_model=d_model,
    d_ff=d_ff,
    h=heads,
    dropout=dropout,
    ablation_data=ablations,
)
model.load_state_dict(
    torch.load(toLoad, map_location='cpu')["model"]
)
model.eval()

# Summary: total number of parameters and the training-set size implied by the split s.
n_params = sum(param.nelement() for param in model.parameters())
print('# of parameters =', n_params)
print('Training set size =', len(data_f)*s)
    

Generates all possible ablations. 

In [None]:
# All possible head-ablation combinations for n layers and `heads` heads;
# any one of them can be used as `ab_head` in an ablation specification.
Q = generate_head_layer_ablations(n=n, heads=heads)

# Reproducing Figures 16, 15

In [None]:
""" This outputs:

    1) PCA Analysis at output of Attention and MLP in each layer for the two leading axes. 
    
    2) Attention pattern per task (determined by where carried ones are needed)
"""

### Ablation specification (all layer/neuron lists are empty in this cell)

## Attention 

M_apply = False
n_ab_head = []
# Example row-wise specifications:
# ab_headrow = [[[7, 8, 9], [7, 8, 9]], [[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]], [[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]]]
# ab_headrow = [[[7, 8, 9], [7, 8, 9]], [[],[]]]
ab_headrow = []
# NOTE(review): under the convention of the ablation-specification cell (0 = ablate the head)
# torch.eye contains zeros; presumably inert while n_ab_head is empty — confirm.
ab_head = torch.eye(2,2)

ablation_attention = [M_apply, n_ab_head, ab_head, ab_headrow]

## FFN

List_Neurons = []

ablation_ffn = List_Neurons

## Decoder 

n_ab_ffn = []
n_ab_att = []

ablation_decoder = [n_ab_att, n_ab_ffn]

ablations = [ablation_attention, ablation_ffn, ablation_decoder]

# Random evaluation subset (indices drawn with replacement); only target
# positions -4:-1 are passed on to the SVD helper.
ix = torch.randint(len(data_f), size = (20000,))
data_ff = data_f[ix, :]
target_ff = target_f[ix, -4:-1]

toLoad = directory + mdd

# Rebuild the model with this cell's ablation spec, using the shared hyper-parameters
# (d_model, heads, dropout) instead of repeating their literal values.
model = make_model(vocab, N = n, d_model = d_model, d_ff = d_ff, h = heads, dropout = dropout, ablation_data = ablations)
model.load_state_dict(torch.load(toLoad, map_location=device)["model"])
model.eval()

seq_len = data_ff.shape[-1]
mask = torch.tril(torch.ones(seq_len, seq_len))  # lower-triangular mask

# Inference only: every activation read below is detached, so skip building the
# autograd graph for the 20000-sample batch. The forward pass is what fills the
# per-layer `out_a` / `out` attributes read in the loop below.
with torch.no_grad():
    out = model(data_ff.to(device), mask)

# For each layer collect, in this order, `out_a` and `out`, each centred over the batch dimension.
Out_all = []
for l in range(n):
    layer = model.decoder.layers[l]
    for cached in (layer.out_a, layer.out):
        acts = cached.detach().clone()
        Out_all.append(acts - acts.mean(0, keepdim=True))


svd_full, positions, digit_ans_pos, digit_naive_ans_pos = svd(Out_all, data_ff, target_ff)

PCA_plot(n=n, path="", svd_full=svd_full, positions=positions, digit_ans_pos=digit_ans_pos, digit_naive_ans_pos=digit_naive_ans_pos)

plot_attention_patterns(n=n, model=model, positions=positions)

In [None]:
""" This outputs:

    Accuracy per position
"""

toLoad = directory + mdd

### Ablation specification (all layer/neuron lists are empty in this cell)

## Attention 

M_apply = False
n_ab_head = []
# Example row-wise specifications:
# ab_headrow = [[[7, 8, 9], [7, 8, 9]], [[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]], [[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]]]
# ab_headrow = [[[7, 8, 9], [7, 8, 9]], [[],[]]]
ab_headrow = []
ab_head = torch.eye(2,2)

ablation_attention = [M_apply, n_ab_head, ab_head, ab_headrow]

## FFN

List_Neurons = []

ablation_ffn = List_Neurons

## Decoder 

n_ab_ffn = []
n_ab_att = []

ablation_decoder = [n_ab_att, n_ab_ffn]

ablations = [ablation_attention, ablation_ffn, ablation_decoder]

# Rebuild the model with this cell's ablation spec, using the shared hyper-parameters
# (d_model, heads, dropout) instead of repeating their literal values.
model = make_model(vocab, N = n, d_model = d_model, d_ff = d_ff, h = heads, dropout = dropout, ablation_data = ablations)
model.load_state_dict(torch.load(toLoad, map_location=device)["model"])
model.eval()

# Random evaluation subset (indices drawn with replacement).
ix = torch.randint(len(data_f), size = (20000,))
inputs = data_f[ix]
targets = target_f[ix]

seq_len = inputs.shape[-1]
mask = torch.tril(torch.ones(seq_len, seq_len))  # lower-triangular mask

# Inference only: no autograd graph is needed.
with torch.no_grad():
    out = model(inputs, mask)

# Predicted token at positions -4:-1 compared with the target at the same positions.
# The comparison is done on the whole batch at once; summing the 0/1 matches and
# dividing by the batch size gives the same per-position accuracy as accumulating
# sample by sample.
pred = torch.argmax(out[:, -4:-1, :].detach().to('cpu'), -1)
accuracy = (pred == targets[:, -4:-1]).float().sum(0) / len(inputs)

# One row holding the accuracy at each of the three compared positions.
a = accuracy.unsqueeze(0)
a

# Table 4 (left)

In [None]:
""" This outputs:

    Accuracy, non-corrected and corrected, after zero-ablating the chosen part of the network. 
    We correct answers manually to see whether the original ones were off by one, either 
    by forgetting a carried one or by adding one where it shouldn't have; this is the list vcorr. 
    This means: For non-carry over sums we subtract 1 from each position and when a carried one is needed we add it.  

    ### By default we have ablated the final MLP! ###
"""

### Ablation specification

## Attention 

M_apply = False
n_ab_head = []
# Example row-wise specifications:
# ab_headrow = [[[7, 8, 9], [7, 8, 9]], [[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]], [[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]]]
# ab_headrow = [[[7, 8, 9], [7, 8, 9]], [[],[]]]
ab_headrow = []
ab_head = torch.eye(2,2)

ablation_attention = [M_apply, n_ab_head, ab_head, ab_headrow]

## FFN

List_Neurons = []

ablation_ffn = List_Neurons

## Decoder 

n_ab_ffn = [1] ### Final MLP is ablated
n_ab_att = []

ablation_decoder = [n_ab_att, n_ab_ffn]

ablations = [ablation_attention, ablation_ffn, ablation_decoder]

# The ablation spec is the same for every task, so the model is built and loaded once
# instead of once per task. Shared hyper-parameters replace the repeated literals.
model = make_model(vocab, N = n, d_model = d_model, d_ff = d_ff, h = heads, dropout = dropout, ablation_data = ablations)
model.load_state_dict(torch.load(toLoad, map_location=device)["model"])
model.eval()

# Per-task correction subtracted (mod 10) from the prediction before the "corrected" comparison.
vcorr = [torch.tensor([1, 1, 1]), torch.tensor([-1, 1, 1]), torch.tensor([1, -1, 1]), torch.tensor([-1, -1, 1]), torch.tensor([-1, -1, 1])]

# One random evaluation subset (indices drawn with replacement) shared by all five tasks.
ix = torch.randint(len(data_f), size = (20000,))
data_ff = data_f[ix]
target_ff = target_f[ix]

# carry[j] is True where the entries at columns j and j+4 sum to at least 10.
carry = [data_ff[:, j] + data_ff[:, j+4] >= 10 for j in range(3)]
n_carries = sum(c.long() for c in carry)
mid_sum = data_ff[:, 1] + data_ff[:, 5]

# Boolean masks index the tensors directly, which replaces the np.argwhere(...)[0]
# round-trip while selecting the same rows in the same order.
task_masks = [
    n_carries < 1,                      # pos_nc : no column pair reaches 10
    carry[1] & ~carry[0] & ~carry[2],   # pos_c1 : only column pair 1 reaches 10
    carry[2] & (mid_sum < 9),           # pos_c2 : column pair 2 reaches 10, pair 1 sums to less than 9
    n_carries == 2,                     # pos_2c : exactly two column pairs reach 10
    (mid_sum == 9) & carry[2],          # pos_2cp: pair 1 sums to exactly 9 and pair 2 reaches 10
]

tasks_src = [data_ff[m] for m in task_masks]
tasks_tgt = [target_ff[m] for m in task_masks]

per_task_scores = []

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for k, (inputs, targets) in enumerate(zip(tasks_src, tasks_tgt)):

        seq_len = inputs.shape[-1]
        mask = torch.tril(torch.ones(seq_len, seq_len))  # lower-triangular mask
        out = model(inputs, mask)

        pre = torch.argmax(out[:, -4:-1, :].detach().to('cpu'), -1)
        tgt = targets[:, -4:-1]

        # Row 0: raw per-position accuracy. Row 1: accuracy after applying vcorr[k] modulo 10.
        bp = (pre == tgt).float().sum(0) / len(inputs)
        bcorr = ((pre - vcorr[k]) % 10 == tgt).float().sum(0) / len(inputs)

        scores = torch.stack((bp, bcorr))
        per_task_scores.append(scores)

# Shape (tasks, 2, positions): for each task the raw and the corrected accuracy per position.
scores_z = torch.stack(per_task_scores)

scores_z

# Table 4 (right)

In [None]:
""" This outputs:
    
    Dissecting the MLP and computing the accuracy after ablating a specific set of Neurons (List_Neurons)
    We go through the data twice: first to find the relevant Neurons and then to rerun the model with those 
    neurons zero-ablated. 
"""

List_Neurons = []

# Per-task correction subtracted (mod 10) from the prediction before the "corrected" comparison.
vcorr = [torch.tensor([1, 1, 1]), torch.tensor([-1, 1, 1]), torch.tensor([1, -1, 1]), torch.tensor([-1, -1, 1]), torch.tensor([-1, -1, 1])]

# Pass 1 runs with List_Neurons empty and selects neurons; pass 2 reruns with those
# neurons passed as the FFN ablation (and recomputes/prints the selection again).
for _ in range(2):
    ### Ablation specification

    ## Attention 

    M_apply = False
    n_ab_head = []
    # Example row-wise specifications:
    # ab_headrow = [[[7, 8, 9], [7, 8, 9]], [[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]], [[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]]]
    # ab_headrow = [[[7, 8, 9], [7, 8, 9]], [[],[]]]
    ab_headrow = []
    ab_head = torch.eye(2,2)

    ablation_attention = [M_apply, n_ab_head, ab_head, ab_headrow]

    ## FFN

    ablation_ffn = List_Neurons

    ## Decoder 

    n_ab_ffn = []
    n_ab_att = []

    ablation_decoder = [n_ab_att, n_ab_ffn]

    ablations = [ablation_attention, ablation_ffn, ablation_decoder]

    # Shared hyper-parameters (d_model, heads, dropout) replace the repeated literals.
    model = make_model(vocab, N = n, d_model = d_model, d_ff = d_ff, h = heads, dropout = dropout, ablation_data = ablations)
    model.load_state_dict(torch.load(toLoad, map_location=device)["model"])
    model.eval()

    # Random evaluation subset (indices drawn with replacement), redrawn on each pass.
    ix = torch.randint(len(data_f), size = (100000,))
    data_ff = data_f[ix]
    target_ff = target_f[ix]

    # carry[j] is True where the entries at columns j and j+4 sum to at least 10.
    carry = [data_ff[:, j] + data_ff[:, j+4] >= 10 for j in range(3)]
    n_carries = sum(c.long() for c in carry)
    mid_sum = data_ff[:, 1] + data_ff[:, 5]

    # Boolean masks index the tensors directly, which replaces the np.argwhere(...)[0]
    # round-trip while selecting the same rows in the same order.
    task_masks = [
        n_carries < 1,                      # pos_nc : no column pair reaches 10
        carry[1] & ~carry[0] & ~carry[2],   # pos_c1 : only column pair 1 reaches 10
        carry[2] & (mid_sum < 9),           # pos_c2 : column pair 2 reaches 10, pair 1 sums to less than 9
        n_carries == 2,                     # pos_2c : exactly two column pairs reach 10
        (mid_sum == 9) & carry[2],          # pos_2cp: pair 1 sums to exactly 9 and pair 2 reaches 10
    ]

    tasks_src = [data_ff[m] for m in task_masks]
    tasks_tgt = [target_ff[m] for m in task_masks]

    acc_rows = []
    z_out = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for k in range(len(tasks_src)):

            inputs = tasks_src[k]
            targets = tasks_tgt[k]

            seq_len = inputs.shape[-1]
            mask = torch.tril(torch.ones(seq_len, seq_len))  # lower-triangular mask
            out = model(inputs, mask)

            pre = torch.argmax(out[:, -4:-1, :].detach().to('cpu'), -1)
            tgt = targets[:, -4:-1]

            # Raw per-position accuracy, then accuracy after applying vcorr[k] modulo 10.
            ap = (pre == tgt).float().sum(0) / len(inputs)
            acorr = ((pre - vcorr[k]) % 10 == tgt).float().sum(0) / len(inputs)

            acc_rows.append(torch.stack((ap, acorr)))

            # Batch-mean of layer 1's cached `ffn.out` over positions -4:-1, keeping the first of those positions.
            z = model.decoder.layers[1].ffn.out[:, -4:-1, :].mean(0).clone().detach()[0]
            z_out.append(z)

    # Two rows (raw, corrected) per task, stacked in task order.
    a = torch.cat(acc_rows, 0)

    # Keep every index whose mean value for some later task exceeds its value for
    # task 0 (the first mask), de-duplicated in order of first appearance.
    Neur = []
    for z_task in z_out[1:]:
        for idx in torch.argwhere(z_task > z_out[0]).squeeze(-1).tolist():
            if idx not in Neur:
                Neur.append(idx)
    List_Neurons = Neur
    print(len(List_Neurons), torch.tensor(List_Neurons))
a