In [1]:
import numpy as np
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("pgf")
import matplotlib.pyplot as plt
import time
from requests import get
import zipfile, io
import os
import random
import math
import pickle
from IPython import display
import torch.nn.utils.parametrizations as param
%matplotlib inline

# `device` is not referenced in the cells shown here; checkpoints are loaded with map_location='cpu'.
device = 'cpu'

# Figure styling: LaTeX text rendering through pdflatex, serif font family.
# NOTE(review): "Helvetica" is a sans-serif typeface but is listed under font.serif; with
# text.usetex=True the LaTeX document font is presumably what actually gets used — confirm.
# NOTE(review): the `%matplotlib inline` magic in the import cell presumably overrides the earlier
# matplotlib.use("pgf") backend choice for on-screen display — confirm this is intended.
plt.rcParams.update({
        "pgf.texsystem": "pdflatex",
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": "Helvetica",
        'pgf.rcfonts': False,
    })

plt.style.use('seaborn-v0_8-ticks')

# Set after plt.style.use, so this base font size takes precedence over the style's value (if any).
matplotlib.rcParams.update({'font.size': 18})

# Model definitions, can skip this.

In [2]:
class DecoderTot(nn.Module):
    """Full model: token embedding -> decoder stack -> generator.

    In ``make_model`` the generator is a ``Generator``, so the output is the
    log-softmax over the vocabulary at every position.

    The generator is registered under two attribute names (``gen`` and ``generator``)
    that refer to the same module. Both are kept, and in this order, because the
    assignment order fixes the submodule/parameter order, and saved checkpoints
    presumably contain keys for both prefixes.
    """

    def __init__(self, decoder, embed, generator):
        super().__init__()
        self.embed = embed
        self.gen = generator
        self.decoder = decoder
        self.generator = generator

    def forward(self, src, mask):
        """Embed ``src``, run it through the decoder stack with ``mask``, then the generator."""
        embedded = self.embed(src)
        hidden = self.decoder(embedded, mask)
        return self.generator(hidden)


class Generator(nn.Module):
    """Final projection from the residual stream to vocabulary log-probabilities.

    Args:
        d_model: width of the residual stream.
        vocab_size: number of output tokens.

    After each call, ``self.out`` holds the raw logits (before log-softmax).
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.ln = nn.Linear(d_model, vocab_size)
        self.out = None

    def forward(self, x):
        """Return ``log_softmax(ln(x))`` over the last dimension and cache the logits."""
        # Compute the linear map once; the cached logits and the returned value share it.
        logits = self.ln(x)
        self.out = logits
        return F.log_softmax(logits, dim=-1)

class Decoder(nn.Module):
    """One pre-norm decoder layer: attention sub-layer then feed-forward sub-layer.

    A single LayerNorm (``self.norm``) is shared by both sub-layers. Either
    sub-layer can be zero-ablated per layer index ``n``. Its output is still
    computed, so any activations it caches are refreshed, but it is multiplied
    by 0 before being added to the residual stream.

    Args:
        attn: callable ``attn(q, k, v, n, mask)``.
        ffn: callable ``ffn(x, n)``.
        d_model: residual-stream width.
        dropout: dropout probability applied to each sub-layer output.
        ablation_data_dec: ``[n_ab_att, n_ab_ffn]``, the layer indices whose
            attention / feed-forward sub-layer output is zeroed.

    After each call, ``self.out_a`` is the residual stream after attention and
    ``self.out`` the layer output.
    """

    def __init__(self, attn, ffn, d_model, dropout, ablation_data_dec):
        super().__init__()
        self.attn = attn
        self.ffn = ffn
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model
        self.out = None
        self.out_a = None
        self.n_ab_att = ablation_data_dec[0]
        self.n_ab_ffn = ablation_data_dec[1]

    def forward(self, x, mask, n):
        normed = self.norm(x)
        attn_delta = self.dropout(self.attn(normed, normed, normed, n, mask))
        if n in self.n_ab_att:
            attn_delta = 0 * attn_delta
        hidden = x + attn_delta
        self.out_a = hidden

        ffn_delta = self.dropout(self.ffn(self.norm(hidden), n))
        if n in self.n_ab_ffn:
            ffn_delta = 0 * ffn_delta
        self.out = hidden + ffn_delta
        return self.out

class DecoderStack(nn.Module):
    """``N`` deep copies of a decoder layer followed by a final LayerNorm.

    Each layer receives its own index (0, 1, ...) as ``n``, which selects the
    ablations that apply to it.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.N = N
        self.norm = nn.LayerNorm(layer.d_model)
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(N)])

    def forward(self, x, mask):
        for layer_idx, block in enumerate(self.layers):
            x = block(x, mask, layer_idx)
        return self.norm(x)

def Attention(q, k, v, n, ablation_data_att, mask=None, dropout=None):
    """Softmax attention with rotary position embeddings and optional ablations.

    Args:
        q, k, v: tensors of shape (batch_size, heads, seq_len, d_k).
        n: index of the calling decoder layer; selects which ablations apply.
        ablation_data_att: ``[M_apply, n_ab_head, ab_head, ab_head_row]``.
            * ``n_ab_head``: layers whose attention pattern is ablated.
            * ``ab_head[n]``: per-head multiplier for layer ``n`` (0 removes that head).
            * ``M_apply``: if true, also zero query rows of the pattern for layers
              in ``n_ab_head``. The rows are given by ``ab_head_row[n][h]`` (row index
              or indices for head ``h``). Heads without an entry are zeroed entirely.
              ``ab_head_row`` is only read for layers in ``n_ab_head``.
        mask: optional tensor broadcastable to the score shape; scores where
            ``mask == 0`` are set to ``-inf``.
        dropout: optional module applied to the attention probabilities.

    Returns:
        ``(O @ v, O, A)``: the attended values, the (possibly ablated) attention
        pattern ``O`` and the pre-softmax scores ``A`` (after masking).
    """
    M_apply, n_ab_head, ab_head, ab_head_row = ablation_data_att

    seq_len, d_k = k.shape[-2], k.shape[-1]
    device = k.device

    # Rotary position embedding, rotating consecutive feature pairs by angle m * theta_i.
    # It is applied to q, k AND v; this presumably matches how the saved checkpoints
    # were trained, so do not remove the rotation of v.
    m = torch.arange(seq_len, device=device).view(seq_len, 1)
    t = torch.arange(d_k, device=device).view(1, d_k)
    t = torch.exp( - ( 2 * np.log(10**4) / d_k ) * torch.floor(t/2) )
    r1 = torch.cos(m * t)
    r2 = torch.sin(m * t)

    K = torch.cat((q, k, v))
    Kp = torch.einsum('ijkl, kl -> ijkl', K, r1)

    # Block-diagonal 90-degree rotation: (x0, x1) -> (-x1, x0) for each pair.
    L = torch.kron(torch.eye(d_k // 2, device=device),
                   torch.tensor([[0., -1.], [1., 0.]], device=device))
    K = torch.einsum('ijkl, ml -> ijkm', K, L)
    Kp += torch.einsum('ijkl, kl -> ijkl', K, r2)

    Kp = Kp.view(-1, k.shape[0], k.shape[1], k.shape[2], d_k)
    q, k, v = Kp[0], Kp[1], Kp[2]

    A = torch.matmul(q, k.transpose(-2, -1)) * d_k**(-0.5)

    if mask is not None:
        A.masked_fill_(mask == 0, float('-inf'))

    O = F.softmax(A, dim=-1)

    if dropout is not None:
        O = dropout(O)

    if n in n_ab_head:
        if M_apply:
            # Only built for ablated layers, so ab_head_row needs entries just for those.
            O = O * _row_ablation_mask(O, ab_head_row[n])
        # Per-head multiplier, cast to O's dtype/device and broadcast over (batch, query, key).
        head_scale = torch.as_tensor(ab_head[n], dtype=O.dtype, device=O.device)
        O = O * head_scale.view(1, -1, 1, 1)

    return torch.matmul(O, v), O, A


def _row_ablation_mask(scores, rows_per_head):
    """0/1 tensor like ``scores`` that is 1 except on the listed query rows of each head.

    ``rows_per_head[h]`` is the row index (or indices) to zero for head ``h``; heads
    without an entry get an all-zero mask.
    """
    keep_mask = torch.zeros_like(scores)
    all_rows = np.arange(scores.shape[-1])
    for head, rows in enumerate(rows_per_head[: scores.shape[1]]):
        keep = torch.as_tensor(np.delete(all_rows, rows), dtype=torch.long, device=scores.device)
        # keep_mask[:, head] is a view, so the in-place fill writes into keep_mask.
        keep_mask[:, head].index_fill_(-2, keep, 1)
    return keep_mask


class MultiHeadedAttention(nn.Module):
    """Multi-head attention wrapper around ``Attention``.

    ``linears[0..2]`` project query/keys/values and ``linears[3]`` is the output
    projection. After each call the following are cached:

    * ``attn``: the (possibly ablated) attention pattern;
    * ``attnA``: the pre-softmax scores;
    * ``out_A``: the concatenated head outputs before the output projection;
    * ``out``: the final output.

    Args:
        h: number of heads.
        d_model: model width (split into ``h`` heads of size ``d_model // h``).
        ablation_data_attention: forwarded to ``Attention`` as ``ablation_data_att``.
        dropout: dropout probability on the attention probabilities.
    """

    def __init__(self, h, d_model, ablation_data_attention, dropout=0.1):
        super().__init__()
        self.d_k = d_model // h
        self.h = h
        self.attn = None
        self.attnA = None
        self.out = None
        self.out_A = None
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
        self.dropout = nn.Dropout(p=dropout)
        self.ab = ablation_data_attention

    def forward(self, query, keys, values, n, mask=None):
        """Attend with layer index ``n`` (selects ablations) and optional ``mask``."""
        batch_size = query.shape[0]

        # Project, then split into heads: (batch, heads, seq, d_k).
        q, k, v = [
            proj(z).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
            for proj, z in zip(self.linears, (query, keys, values))
        ]

        y, self.attn, self.attnA = Attention(q, k, v, n, ablation_data_att=self.ab, mask=mask, dropout=self.dropout)

        # Merge heads back: (batch, seq, h * d_k).
        y = y.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        self.out_A = y

        # Output projection, computed once and both cached and returned.
        self.out = self.linears[-1](y)
        return self.out
    
class FeedForward(nn.Module):
    """Position-wise MLP ``w2(dropout(relu(w1(x))))`` with optional neuron zero-ablation.

    Args:
        d_model: input/output width.
        d_ff: hidden width.
        dropout: dropout probability applied to the hidden activations.
        ablation_data_ffn: hidden-neuron indices whose pre-activations are set to 0.
            Indices outside ``[0, d_ff)`` are ignored.
        ablation_layer: decoder layer index ``n`` at which the ablation is applied (default 1).

    After each call, ``self.out`` holds the hidden (post-ReLU) activations and
    ``self.out_p`` the returned output.
    """

    def __init__(self, d_model, d_ff, dropout, ablation_data_ffn, ablation_layer=1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.out = None
        self.out_p = None
        self.FFN_needle = ablation_data_ffn
        self.ablation_layer = ablation_layer

    def forward(self, x, n):
        A = self.w1(x)
        if n == self.ablation_layer:
            needle = [i for i in range(A.shape[-1]) if i in self.FFN_needle]
            if needle:
                A[..., needle] = 0
        self.out = self.relu(A)
        # Compute the output once: two separate calls would draw two different dropout
        # masks in training mode, so the cached out_p would not match the returned value.
        self.out_p = self.w2(self.dropout(self.out))
        return self.out_p

class Embeddings(nn.Module):
    """Token embedding scaled by ``sqrt(d_model)``; the result is cached in ``self.out_e``."""

    def __init__(self, src_vocab, d_model):
        super().__init__()
        self.Emb = nn.Embedding(src_vocab, d_model)
        self.d_model = d_model
        self.out_e = None

    def forward(self, x):
        # Single embedding lookup; the cached and returned tensors are the same object.
        self.out_e = self.Emb(x) * np.sqrt(self.d_model)
        return self.out_e

def make_model(vocab, N, d_model, d_ff, h, dropout, ablation_data):
    """Build the decoder-only transformer used in this notebook.

    Args:
        vocab: vocabulary size.
        N: number of decoder layers.
        d_model: residual-stream width.
        d_ff: feed-forward hidden width.
        h: number of attention heads.
        dropout: dropout probability.
        ablation_data: ``[attention_ablation, ffn_ablation, decoder_ablation]``.

    Returns:
        A ``DecoderTot``. Every parameter with more than one dimension is
        Xavier-uniform initialised; 1-D parameters (biases, LayerNorm) keep their
        default values.
    """
    # Keep this construction order: module constructors may draw from the global RNG
    # (default parameter init), and the 1-D parameters are not re-initialised below.
    attention_proto = MultiHeadedAttention(h, d_model, ablation_data[0])
    ffn_proto = FeedForward(d_model, d_ff, dropout, ablation_data[1])

    block = Decoder(copy.deepcopy(attention_proto), copy.deepcopy(ffn_proto), d_model, dropout, ablation_data[2])
    stack = DecoderStack(block, N)
    embedding = Embeddings(vocab, d_model)
    head = Generator(d_model, vocab)
    model = DecoderTot(stack, embedding, head)

    weight_matrices = (param for param in model.parameters() if param.dim() > 1)
    for weight in weight_matrices:
        nn.init.xavier_uniform_(weight)

    return model

# Functions to construct dataset and label it according to where the carrying happens.

In [3]:
""" 

Generate the dataset for ndig digits (with 0 padding) and ndig + n_extra digits 

"""

def GenerateDataset(ndig, n_extra):

    P = 10**ndig
    
    data = []
    target = []

    stoi = {'0': 0, '1': 1, '2': 2,'3': 3,'4': 4,'5': 5,'6': 6,'7': 7,'8': 8,'9': 9,'+': 10,'=': 11}

    for i in range(P):
        for j in range(P):
            if i + j < P:
                li = list(f'{i}')
                lj = list(f'{j}')
                lij = list(f'{i+j}')
                if len(li) < ndig + n_extra:
                    li = ['0'] * (ndig + n_extra - len(li)) + li
                if len(lj) < ndig + n_extra:
                    lj = ['0'] * (ndig + n_extra - len(lj)) + lj
                if len(lij) < ndig + n_extra:
                    lij = ['0'] * (ndig + n_extra - len(lij)) + lij

                lsum = li + ['+'] + lj + ['='] * (ndig + n_extra)
                # lij = ['0'] * (2*(ndig + 3) + 1) + lij
                data.append([stoi[lsum[i]] for i in range(len(lsum))])
                target.append([stoi[lij[i]] for i in range(len(lij))])

    vocab = len(stoi) 
    data = torch.LongTensor(data)
    target = torch.LongTensor(target)

    data_f = []
    target_f = []

    P_f = 10**(ndig + n_extra) - 1

    k = 0
    while k < 20000:
        i = torch.randint(P_f, size=(1,)).item()
        j = torch.randint(P_f, size=(1,)).item()
        if i + j < P_f + 1:
            li = list(f'{i}')
            lj = list(f'{j}')
            lij = list(f'{i+j}')
            if len(li) < ndig + n_extra:
                li = ['0'] * (ndig + n_extra - len(li)) + li
            if len(lj) < ndig + n_extra:
                lj = ['0'] * (ndig + n_extra - len(lj)) + lj
            if len(lij) < ndig + n_extra:
                lij = ['0'] * (ndig + n_extra - len(lij)) + lij

            lsum = li + ['+'] + lj + ['='] * (ndig + n_extra)
            # lij = ['0'] * (2*(ndig + 3) + 1) + lij
            data_f.append([stoi[lsum[i]] for i in range(len(lsum))])
            target_f.append([stoi[lij[i]] for i in range(len(lij))])
            k += 1

    data_f = torch.LongTensor(data_f)
    target_f = torch.LongTensor(target_f)

    return vocab, data, target, data_f, target_f

""" 

Construct a label for each example, based on whether the sum is >= 10, < 9 or ==9. The label is an integer constructed as follows. We start with an integer s = 0.
At position i, if the sum at this position is < 9, then we add 0 to s. if it is >= 10, we add 10**i to s and if it is equal to 9 (and the previousu sum was >= 10),
then we add 2*10**i to s. 

"""

def Pick_classes(data, ndig):
    s = 0
    if data[ndig - 1] + data[2*ndig] >= 10:
        r = 1
        s += 1
    else:
        r = 0
        s += 0

    for i in range(1, ndig):

        if data[ndig - i - 1] + data[2*ndig - i] >= 10:
            r = 1
            s += 10**i
        elif (data[ndig - 1 - i] + data[2*ndig - i] == 9) and (r == 1):
            r = 1
            s += 2*10**i
        else:
            r = 0
            s += 0
    return s 

""" 

Find indices of the target integer with a given output at a given position.

"""

def Pick_target_classes(target):
    idx = []
    for i in range(target.shape[-1]):
        ids_p = []
        for j in range(10):
            ids_p.append(torch.argwhere(target[:, i] == j).view(-1))
        idx.append(ids_p)
    return idx


# Get ablation data

<font size = '3'> Specify what component in what layer to ablate. </font>

In [4]:
### Ablation specification
#
# All ablations below are zero-ablations. The three groups are passed to make_model as
# ablations = [ablation_attention, ablation_ffn, ablation_decoder].

## Attention 

### The helper generate_head_layer_ablations (not defined in the cells shown here) generates all possible combinations of heads that can be ablated.
# Pick one of these and use it for ab_head. Then use n_ab_head to specify the layer(s) that ablation applies to.
# The list ab_headrow is used in conjunction with M_apply. If M_apply is True, ab_headrow is used to ablate specific rows of the attention pattern.
# Again, n_ab_head determines which layer(s) this ablation is applied to.

n_ab_head = [] # Layers whose attention heads (ab_head) and, if M_apply is True, attention rows (ab_headrow) are zero-ablated.
ab_head = torch.tensor([[1, 1], [1, 1]]) # Per-layer head multipliers, shape (number of layers, heads): 0 ablates that head in that layer, 1 keeps it.
ab_headrow = [] # Indexed by layer: ab_headrow[n] = [query row(s) to zero in head 0, query row(s) to zero in head 1].
M_apply = False # Whether to apply the row-wise ablation (ab_headrow) to the layers listed in n_ab_head.


ablation_attention = [M_apply, n_ab_head, ab_head, ab_headrow]

# FFN

List_Neurons = [] # Hidden-neuron indices of the FFN to zero out (before the ReLU); FeedForward applies this in decoder layer 1.

ablation_ffn = List_Neurons

## Decoder 

n_ab_ffn = [] # Layers whose entire FFN sub-layer output is zero-ablated.
n_ab_att = [] # Layers whose entire attention sub-layer output is zero-ablated.

ablation_decoder = [n_ab_att, n_ab_ffn]

ablations = [ablation_attention, ablation_ffn, ablation_decoder]

# Load model

<font size = "3"> Change directory and mmd to get the model you want to load, default is the primed model. </font>

In [5]:
# Hyper-parameters identifying the checkpoint; the architecture values must match training.
n = 2 # Number of layers
s = 0.3 # Train/test split
w = 0.2 # weight decay
p = 1 # Specifies which of the six models to consider. In this block used as a dummy variable, but can be specified later on.

vocab = 12
heads = 2
ndig = 3
ntot = 6

n_layer = n # number of layers
d_model = 128 # model dimension, residual stream
d_ff = d_model # dim intermediate feed-forward layer
h_a = 2 # number of heads in attention (doesnt impact # of params)

directory = 'finetuning/'
mdd = 'model_n{!s}_s{!s}_w{!s}_{!s}_FT_p_epoch_500'.format(n, s, w, p)
toLoad = directory + mdd

model = make_model(vocab, N=n_layer, d_model=d_model, d_ff=d_ff, h=h_a, dropout=0.1, ablation_data=ablations)
# torch.load unpickles the file: only load checkpoints from a trusted source.
model.load_state_dict(torch.load(toLoad, map_location='cpu')['model'])

mask = None

In [8]:
# Hyper-parameters identifying the checkpoint; the architecture values must match training.
n = 2 # Number of layers
s = 0.3 # Train/test split
w = 0.2 # weight decay
p = 0 # Specifies which of the six models to consider. In this block used as a dummy variable, but can be specified later on.

vocab = 12
heads = 2
ndig = 3
ntot = 6

n_layer = n # number of layers
d_model = 128 # model dimension, residual stream
d_ff = d_model # dim intermediate feed-forward layer
h_a = 2 # number of heads in attention (doesnt impact # of params)

directory = 'unprimed/'
mdd = 'model_n{!s}_s{!s}_w{!s}_{!s}_epoch_500'.format(n, s, w, p)
toLoad = directory + mdd

modelp = make_model(vocab, N=n_layer, d_model=d_model, d_ff=d_ff, h=h_a, dropout=0.1, ablation_data=ablations)
# torch.load unpickles the file: only load checkpoints from a trusted source.
modelp.load_state_dict(torch.load(toLoad, map_location='cpu')['model'])

mask = None

In [14]:
# For each layer, count how many entries of each attention projection matrix
# (linears[0..2] = query/key/value, linears[3] = output) differ by less than 1e-6
# between `modelp` and `model`.
for i in range(n):
    print(f'--- layer {i} ---')
    linears_p = modelp.decoder.layers[i].attn.linears
    linears_ft = model.decoder.layers[i].attn.linears
    for j in range(4):
        att = linears_p[j].weight.detach()
        att_ft = linears_ft[j].weight.detach()
        zero_modes = ((att - att_ft).abs() < 10**-6).sum()
        print(zero_modes)


--- layer 0 ---
tensor(14)
tensor(8)
tensor(3)
tensor(4)
--- layer 1 ---
tensor(7)
tensor(10)
tensor(11)
tensor(3)


# Generate data and compute accuracy on the ntot digit sums, i.e. we use the datasets data_f and target_f.

In [None]:
vocab, data, target, data_f, target_f = GenerateDataset(ndig, ntot - ndig)

inputs = data_f[:20000]
targets = target_f[:20000]

with torch.no_grad():
    model.eval()
    out = model(inputs, None)
    out_p = torch.argmax(out, -1)

# Boolean (n_examples, ntot) tensor: is the predicted digit correct at each answer position?
pre_acc = (out_p[:, -ntot:] == targets[:, -ntot:])

# Vectorised reductions instead of a Python-level sum over every row.
print('--- Accuracy per position ---')
print(pre_acc.float().mean(0))
print('--- Accuracy (correctness of target) --- ')
# An example counts as correct only if all of its answer digits are correct.
print(pre_acc.all(-1).float().mean())

# Construct a dictionary with keys being the type of sum (where the carry is etc.) and the values are the positions in the full dataset

In [None]:
d_pos = {}

# torch.randint draws indices independently (with replacement), so data_p may contain repeats.
ix = torch.randint(0, len(data_f), size=(20000,))
data_p = data_f[ix]
target_p = target_f[ix]

# Map each carry-pattern label (see Pick_classes) to the list of its example indices in data_p.
for i, src in enumerate(data_p):
    cls_ = Pick_classes(src, ntot)
    d_pos.setdefault(cls_, []).append(i)

# Compute the attention patterns for Fig. 21, 23 and 25.

<font size = '3'> Make sure to select the primed model to reproduce Fig. 23, and the unprimed model for Fig. 21, and the finetuned model for Fig. 25. </font>

In [None]:
positions = [1, 100, 210, 1110, 10010]  # carry-pattern labels (see Pick_classes), one column each

colors = ['Blues', 'Oranges', 'Greens', 'Reds', 'Purples']

n_heads = model.decoder.layers[0].attn.h  # heads per layer
# One tick label per input token: ntot digits, '+', ntot digits, ntot '=' tokens.
token_labels = ['$*$']*ntot + ['$+$'] + ['$*$']*ntot + ['$=$'] * ntot

fig, ax = plt.subplots(n_heads*n, len(positions), figsize=(16, 20), squeeze=False)
fig.tight_layout(h_pad=-30, w_pad=-1)

# Forward pass only to refresh the attention patterns cached on each layer. eval() disables
# dropout regardless of earlier cells, and no_grad() avoids building an autograd graph.
model.eval()
with torch.no_grad():
    out = model(data_p, None)

# Row labels "Head <layer>:<head>"; raw strings keep LaTeX backslashes such as \; intact.
rows = [rf'$\rm Head\;{l}\hspace{{-5pt}}:\hspace{{-5pt}}{hd}$' for l in range(n) for hd in range(n_heads)]

for i, pos in enumerate(positions):
    for l in range(n):
        # Attention pattern averaged over all examples with this carry pattern: (heads, queries, keys).
        att0 = model.decoder.layers[l].attn.attn[d_pos[pos], :, :, :].clone().detach().mean(0)
        # In the last layer only the ntot answer queries are shown.
        query_rows = slice(-ntot, None) if l == n - 1 else slice(None)
        for hd in range(n_heads):
            ax[hd + n_heads*l, i].imshow(att0[hd, query_rows, :], cmap=colors[i])
    # Zero-padded label, e.g. 1 -> \texttt{000001} for ntot = 6.
    ax[0, i].set_title(rf'$\texttt{{{pos:0{ntot}d}}}$')

for sub in ax.flat:
    sub.set_xticks([])
    sub.set_yticks([])
for l in range(n):
    for hd in range(n_heads):
        if l == n - 1:
            ax[hd + n_heads*l, 0].set_yticks(range(ntot), ['$=$']*ntot)
        else:
            ax[hd + n_heads*l, 0].set_yticks(range(3*ntot+1), token_labels)
for i in range(len(positions)):
    ax[-1, i].set_xticks(range(3*ntot+1), token_labels)

# Separate loop variable so `ax` keeps referring to the grid of axes.
for row_ax, row in zip(ax[:, 0], rows):
    row_ax.annotate(row, xy=(0, 0.5), xytext=(-row_ax.yaxis.labelpad - 5, 0),
                    xycoords=row_ax.yaxis.label, textcoords='offset points',
                    size='large', ha='right', va='center')

# Compute the PCA of the hidden states of a given layer, taken after a given component (attention or MLP). Can be used to reproduce Fig. 22 and Fig. 24 and the claims about the finetuned model.

<font size ='3'> Select the layer by its index in `model.decoder.layers`. Use the attribute 'out_a' for the output of the attention (default) and the attribute 'out' for the output of the MLP. </font>

In [None]:
dlist = sorted(list(d_pos.keys()))
lg = []

# Hidden state to analyse: a layer index into model.decoder.layers, and the cached attribute
# ('out_a' = residual stream right after the attention sub-layer, 'out' = after the MLP sub-layer).
pca_layer = 1
pca_attr = 'out_a'

inputs = data_p

with torch.no_grad():
    model.eval()
    out = model(inputs, None)

# Hidden states of every example at every sequence position, centred over the examples.
Out_a = getattr(model.decoder.layers[pca_layer], pca_attr).detach().clone()
Out_a = Out_a - Out_a.mean(0, keepdim=True)

# SVD per sequence position, stored as (U, S, V) like the deprecated torch.svd:
# pca[k][0] holds the per-example coordinates and pca[k][1] the singular values.
pca = []
for k in range(3*ntot+1):
    U, S, Vh = torch.linalg.svd(Out_a[:, k, :], full_matrices=False)
    pca.append((U, S, Vh.mH))

# Plots PCA for the last ntot positions for the specified hidden states from above. Labelled according to the type of sums.

In [None]:
sum_types = [1, 100, 210, 1110, 10010]  # carry-pattern labels (see Pick_classes) to plot

fig, ax = plt.subplots(1, ntot, figsize=(22, 3))
fig.tight_layout(h_pad=-1, w_pad=1)

a, b = 0, 1  # principal components on the x / y axes
lg = []
for pos in sum_types:
    idx = d_pos[pos]
    # One panel per answer position (the last ntot tokens of the sequence).
    for j in range(2*ntot+1, 3*ntot+1):
        U_j, S_j = pca[j][0], pca[j][1]
        x = U_j[idx, a] * S_j[a]
        y = U_j[idx, b] * S_j[b]
        ax[j-2*ntot-1].scatter(x, y, alpha=0.3)
    t = f'{pos:0{ntot}d}'  # zero-padded to ntot digits
    # Doubled braces so LaTeX receives \texttt{000001}; a bare {t} would be consumed by the
    # f-string and yield \texttt000001 (only the first digit in typewriter font).
    lg.append(f'$\\texttt{{{t}}}$')


lg = fig.legend(lg, bbox_to_anchor=(1.06, 0.24), loc='lower right', borderaxespad=0.)
for handle in lg.legend_handles:
    handle.set_alpha(1)

# Plots PCA for the last ntot positions for the specified hidden states from above. Labelled according to the value of the target digit. Can be used to reproduce Fig. 24.

In [None]:
tgt_classes = Pick_target_classes(target_p)

# Restrict to the examples of the carry patterns plotted above.
u = []
for pos in [1, 100, 210, 1110, 10010]:
    u += d_pos[pos]
u_set = set(u)

fig, ax = plt.subplots(1, ntot, figsize=(22, 3))
fig.tight_layout(h_pad=-1, w_pad=1)

a, b = 0, 1  # principal components on the x / y axes
lg = []

colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
                    '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
                    '#bcbd22', '#17becf']

first_answer = 2*ntot + 1  # sequence index of the first answer token (13 for ntot = 6)
for j in range(first_answer, 3*ntot + 1):
    panel = ax[j - first_answer]
    for i in range(10):
        # Examples whose target digit at this answer position equals i.
        tgt_class = tgt_classes[j - first_answer][i]
        intersection = sorted(u_set.intersection(tgt_class.tolist()))
        x = pca[j][0][intersection, a] * pca[j][1][a]
        y = pca[j][0][intersection, b] * pca[j][1][b]
        panel.scatter(x, y, alpha=0.3, color=colors[i])

# Legend on the last panel: one entry per target digit, in scatter order.
lg = ax[-1].legend([f'${d}$' for d in range(10)], bbox_to_anchor=(1.4, 0.30), loc='lower right', borderaxespad=0.)
for handle in lg.legend_handles:
    handle.set_alpha(1)


# The code below plots the SVD of the MLP pre-activation weights, with each neuron labelled by whether it is among the most active ones, depending on the answer position and the type of sum. (For six-digit sums there are six rows, one per answer position, and six columns, one per type of sum we consider.)

In [None]:
"""
    This generates a scatter plot of 64 most active neurons for a given task and position.

"""

number_most_active_neurons = 64 # Number of most active neurons to consider

model.eval()

# Get preactivation weights and perform SVD

MLP = model.decoder.layers[1].ffn.w1.weight.clone().detach()
svd_mlp = torch.svd(MLP)

positions = []

# The types of sums to consider and collect positions in full dataset.

for case in [0, 1, 10, 11, 21, 1100]:

    positions.append(d_pos[case])

# Label target data according to target digit at given position.
digit_ans_pos = []
for k in range(6):
    ans_pp = []
    for i in range(10):
        ans_p = []
        for j in range(len(target_p)):
            if target_p[j, k] == i:
                ans_p.append(j)
        ans_pp.append(torch.tensor(ans_p))
    digit_ans_pos.append(ans_pp)

out = model(data_p, None)

z_out = []
for k in range(6): 
    z_out_p = []
    for i in range(10):
        # Computes activations for examples labelled by their target digit at a given output position.
        z = model.decoder.layers[1].ffn.out[digit_ans_pos[k][i], 13+k, :].clone().detach()
        z_out_p.append(z)
    z_out.append(z_out_p)

z_out_n = []
for k in range(6): 
    z_out_p = []
    for i in range(len(positions)):
        # Computes activations for examples labelled by the type of sum at a given output position.
        z = model.decoder.layers[1].ffn.out[positions[i], 13+k, :].clone().detach()
        z_out_p.append(z)
    z_out_n.append(z_out_p)

def counter(a):
    """Count how often each distinct value occurs in a 1-D tensor.

    Args:
        a (torch.tensor): a 1-D tensor (here: neuron indices).

    Returns:
        torch.tensor: ``(k, 2)`` tensor of ``[value, count]`` rows, one per distinct
        value, ordered by decreasing value. An empty input gives a ``(0, 2)`` tensor.
    """
    values, counts = torch.unique(a, sorted=True, return_counts=True)
    # torch.unique sorts ascending; flip to keep the descending-value order.
    return torch.stack((values.flip(0), counts.to(values.dtype).flip(0)), dim=1)

n_rows = len(z_out_n)     # one row per answer position
n_cols = len(positions)   # one column per type of sum
fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 10), squeeze=False)

a, b = 0, 1  # singular directions on the x / y axes
for k in range(n_rows):
    row = n_rows - 1 - k  # answer position k is drawn in row n_rows-1-k (rows filled bottom-up)
    for j in range(n_cols):
        zz = torch.argwhere(z_out_n[k][j] != 0)[:, 1]  # neuron index of every non-zero activation

        # Count how often each neuron is active and keep the most frequently active ones.
        s_ = counter(zz)
        if len(s_) > 0:
            sp = s_[s_[:, 1].sort(descending=True)[1]][:number_most_active_neurons, 0]
            zzpos = list(set(sp.tolist()))
        else:  # no neuron fires for this group
            zzpos = []
        # Non-active neurons
        notzzpos = list(set(range(d_ff)) - set(zzpos))

        # Inactive neurons in blue, then the most active ones in red on top.
        for idx, color in ((notzzpos, 'blue'), (zzpos, 'red')):
            x = svd_mlp[0][idx, a] * svd_mlp[1][a]
            y = svd_mlp[0][idx, b] * svd_mlp[1][b]
            ax[row, j].scatter(x, y, alpha=0.3, color=color)
        ax[row, j].set_yticks([])
        ax[row, j].set_xticks([])

# Column titles: the carry-pattern label of each type of sum. Doubled braces so LaTeX
# receives \texttt{<label>} rather than \texttt<label> (which sets only the first digit).
for i, txt in enumerate([0, 1, 10, 11, 21, 1100]):
    ax[0, i].set_title(f'$\\texttt{{{txt}}}$')
   