# Palindrome Dataset 

In [8]:
import torch
import string
from transformer_lens import EasyTransformer, EasyTransformerConfig
from torch.utils.data import Dataset, DataLoader
from src.dataset import PalindromeDataset, is_palindrome, get_palindrome_distance

dataset = PalindromeDataset(1000, perturb_n_times=8, k = 3, alphabet=string.ascii_lowercase[:2])
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

for s, y in train_loader:
    print(s)
    print(len(s[0]))
    print([is_palindrome(i) for i in s])
    print([get_palindrome_distance(i) for i in s])
    print(y)
    break

('aabbba', 'ababaa', 'aaaaaa', 'bbbbbb')
6
[False, False, True, True]
[1, 2, 0, 0]
tensor([False, False,  True,  True])


# Palindrome Tokenizer

In [9]:
import torch
from src.tokenizer import ToyTokenizer
import string
alphabet = string.ascii_lowercase[:2]
print(alphabet)
tokenizer = ToyTokenizer(alphabet, pad=False, sep=True, cls=False)
tokens = tokenizer(["abba", "baba"])
torch.tensor(tokens["input_ids"])

ab


tensor([[0, 1, 1, 0, 2],
        [1, 0, 1, 0, 2]])

# The Model

We will use TransformerLens's `Transformer` class to build a model. The model will be a Transformer with 3 layers, 64 hidden units, 2 attention heads, and a vocabulary size of 28 (24 ascii lowercase + BOS + EOS). We will use the `Transformer` class's `from_config` method to create the model.

[1]:

In [10]:
cfg = EasyTransformerConfig(
    n_layers=3,
    d_model=32,
    d_head=16,
    n_heads=2,
    d_mlp=64,
    d_vocab=30,
    n_ctx=26,
    act_fn="relu",
    normalization_type="LN",
    attention_dir="causal",
    # d_vocab_out=64,
)
model = EasyTransformer(cfg)

# 32 batch size, 24 sequence length, 28 vocab size
model.forward(torch.randint(0, 26, (32, 24))).shape
model.forward(torch.randint(0, 26, (32, 24)))[0][0]

tensor([ 1.0392,  0.0811, -0.2246, -0.3473, -0.3690,  0.0734, -0.6244,  0.9733,
        -0.0098, -0.0241,  0.0505, -0.1579,  0.8297, -0.9750, -0.0714, -0.2628,
        -0.1653, -0.1548,  0.1972,  1.4743,  0.6208, -0.2412,  0.7243, -1.0446,
         0.8276,  0.2859, -1.1139, -0.8985,  0.8647,  0.4042],
       grad_fn=<SelectBackward0>)

In [11]:
cfg

HookedTransformerConfig:
{'act_fn': 'relu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 16,
 'd_mlp': 64,
 'd_model': 32,
 'd_vocab': 30,
 'd_vocab_out': 30,
 'device': 'cpu',
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'init_mode': 'gpt2',
 'init_weights': True,
 'initializer_range': 0.1414213562373095,
 'model_name': 'custom',
 'n_ctx': 26,
 'n_heads': 2,
 'n_layers': 3,
 'n_params': 24576,
 'normalization_type': 'LN',
 'original_architecture': None,
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': None,
 'use_attn_result': False,
 'use_attn_scale': True,
 'use_local_attn': False,
 'window_size': None}

In order to do a classification task, we need to wrap our transformer in nn.Module class which appends a classifer. This classifier will look at the embeddings
of the last token in the sequence and use a linear layer to map it to a vector of size 2. The vector will be passed through a softmax to get the probabilities.

In [14]:
import torch.nn as nn
class Classifier(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.transformer = EasyTransformer(cfg)
        self.transformer.unembed = nn.Identity()
        self.linear = torch.nn.Linear(cfg.d_model, 2)
        
    def forward(self, x):
        x = self.transformer(x)
        x = x[:, -1, :]
        x = self.linear(x)
        return x

model = Classifier(cfg)
output = model.forward(torch.randint(0, 26, (32, 26)))
output.shape
output[0]

tensor([0.0316, 0.1197], grad_fn=<SelectBackward0>)

# Loss Function

We will use a classic cross entropy loss on the classification token  (token id 0) which is prepended to the input sequence.

In [15]:
from torch.nn import CrossEntropyLoss

for s, y in train_loader:
    print(s)
    tokens = tokenizer(s)
    tokens = torch.tensor(tokens["input_ids"])
    logits = model.forward(tokens)
    
    print(logits.shape)
    print(y.shape)
    break

loss_fn = CrossEntropyLoss()
loss = loss_fn(logits,  y.long())
loss

('bbbbbb', 'aaaaaa', 'abbbba', 'aaaaaa')
torch.Size([4, 2])
torch.Size([4])


tensor(0.6736, grad_fn=<NllLossBackward0>)

In [116]:
import tqdm.notebook as tqdm 
device = 'cpu'
total_examples =1*10**6
loss_fn = torch.nn.CrossEntropyLoss()
alphabet = string.ascii_lowercase

print(f"Alphabet size: {len(alphabet)}")

sequence_length = 10
assert sequence_length % 2 == 0, "Sequence length must be even"
assert sequence_length > 3, "Sequence length must be greater than 3"
k = sequence_length // 2 
print(k)

batch_size = 32

dataset = PalindromeDataset(total_examples, k = k, perturb_n_times=8, alphabet=alphabet)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print(f"Length of tokens: {sequence_length} (including start and end tokens)")

d_head = 16

cfg = EasyTransformerConfig(
    n_layers=3,
    d_model=d_head*2,
    d_head=d_head,
    n_heads=2,
    d_mlp=d_head*4*2,
    d_vocab= len(alphabet) + 2,
    n_ctx=sequence_length + 1,
    act_fn="relu",
    normalization_type=None,
    attention_dir="causal",
    d_vocab_out=64,
)
model = EasyTransformer(cfg)

classifier = Classifier(cfg)
classifier.to(device)

tokenizer = ToyTokenizer(alphabet)

optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-4)

losses = []
classifier.train()
pbar = tqdm.tqdm(enumerate(train_loader), total=total_examples//batch_size)
for i, (x,y) in pbar:

    x_tokens = tokenizer(x)
    x_tokens = torch.tensor(x_tokens["input_ids"])
    y = y.to(device).long()
    logits = classifier.forward(x_tokens)
    loss = loss_fn(logits, y)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 100 == 0:
        pbar.set_description(f"Loss: {loss.item():.3f}")


Alphabet size: 26
5
Length of tokens: 10 (including start and end tokens)


  0%|          | 0/31250 [00:00<?, ?it/s]

IndexError: index 28 is out of bounds for dimension 0 with size 28

In [17]:
# test model 

test_dataset = PalindromeDataset(1000, k = k, perturb_n_times=8, alphabet=alphabet)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

classifier.eval()
correct = 0
total = 0
for x,y in test_loader:
    x_tokens = tokenizer(x)
    x_tokens = torch.tensor(x_tokens["input_ids"])
    y = y.to(device).long()
    logits = classifier.forward(x_tokens)
    preds = torch.argmax(logits, dim=1)
    correct += (preds == y).sum().item()
    total += len(y)

print(f"Accuracy: {correct/total:.3f}")


Accuracy: 0.963


In [18]:
# save model

import os
import pickle
import torch
import json

model_path = "models"
os.makedirs(model_path, exist_ok=True)

model_file = os.path.join(model_path, "palindrome_classifier.pt")
torch.save(classifier.state_dict(), model_file)

tokenizer_file = os.path.join(model_path, "palindrome_tokenizer.pkl")
with open(tokenizer_file, "wb") as f:
    pickle.dump(tokenizer, f)

# save config file 
config_file = os.path.join(model_path, "palindrome_config.json")
with open(config_file, "w") as f:
    dictionary = cfg.__dict__
    json.dump(dictionary, f)


In [19]:
dictionary

{'n_layers': 3,
 'd_model': 32,
 'n_ctx': 12,
 'd_head': 16,
 'model_name': 'custom',
 'n_heads': 2,
 'd_mlp': 64,
 'act_fn': 'relu',
 'd_vocab': 29,
 'eps': 1e-05,
 'use_attn_result': False,
 'use_attn_scale': True,
 'use_local_attn': False,
 'original_architecture': None,
 'from_checkpoint': False,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'tokenizer_name': None,
 'window_size': None,
 'attn_types': None,
 'init_mode': 'gpt2',
 'normalization_type': None,
 'device': 'cpu',
 'attention_dir': 'causal',
 'attn_only': False,
 'seed': None,
 'initializer_range': 0.1414213562373095,
 'init_weights': True,
 'scale_attn_by_inverse_layer_idx': False,
 'positional_embedding_type': 'standard',
 'final_rms': False,
 'd_vocab_out': 64,
 'parallel_attn_mlp': False,
 'rotary_dim': None,
 'n_params': 24576}

In [21]:
import plotly.express as px 
px.line(losses, log_x=False, log_y=True)

In [22]:
px.imshow(classifier.transformer.pos_embed.W_pos.detach())

In [23]:
# get cosine similarity between each pair of matrices in the embedding matrix
from sklearn.metrics.pairwise import cosine_similarity
px.imshow(cosine_similarity(classifier.transformer.pos_embed.W_pos.detach()))

# Model Interpretation

In [134]:
test_dataset = PalindromeDataset(4, k = k, perturb_n_times=8, alphabet=alphabet)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
for x,y in test_loader:
    x_tokens = tokenizer(x)
    x_tokens = torch.tensor(x_tokens["input_ids"])
    y = y.to(device).long()
    logits = classifier.forward(x_tokens)
    preds = torch.argmax(logits, dim=1)
    print(f"Input: {x[0]}")
    print(f"Prediction: {preds[0]}")
    print(f"Label: {y[0]}")
    

IndexError: index 28 is out of bounds for dimension 0 with size 28

In [117]:
# let's find the direction of unpalindromicity. 
big_data_loader = DataLoader(dataset, batch_size=4000, shuffle=False)
big_tokens, y = next(iter(big_data_loader))
big_tokens = tokenizer(big_tokens)
big_tokens = torch.tensor(big_tokens["input_ids"])
big_tokens

tensor([[17, 10, 15,  ..., 10, 17, 28],
        [10,  6, 18,  ...,  6, 10, 28],
        [13, 23, 13,  ..., 23, 23, 28],
        ...,
        [11,  3, 14,  ...,  3, 12, 28],
        [ 7, 21, 25,  ...,  7, 25, 28],
        [21, 23, 23,  ...,  0, 23, 28]])

In [30]:
embeddings, cache = classifier.transformer.run_with_cache(big_tokens)
logits = classifier.linear(embeddings)

In [135]:
from fancy_einsum import einsum

def get_out_by_head(activations, weights) -> torch.Tensor:
    '''
    activations: batch size, sequence length, n_heads, d_head
    weights: n_heads, d_head, d_model
    '''
    outputs = einsum("batch seq n_heads d_head, n_heads d_head d_model -> n_heads batch seq d_model", 
        activations, weights)
    return outputs


def get_magnitude_by_component(classifier, cache):
    weights = classifier.transformer.blocks[0].attn.W_O
    activation = cache['blocks.0.attn.hook_z']
    out_by_head = get_out_by_head(activation, weights)
    head_0_0 = out_by_head[0,:,-1,:]
    head_0_1 = out_by_head[1,:,-1,:]

    weights = classifier.transformer.blocks[1].attn.W_O
    activation = cache['blocks.1.attn.hook_z']
    out_by_head = get_out_by_head(activation, weights)
    head_1_0 = out_by_head[0,:,-1,:]
    head_1_1 = out_by_head[1,:,-1,:]

    weights = classifier.transformer.blocks[2].attn.W_O
    activation = cache['blocks.2.attn.hook_z']
    out_by_head = get_out_by_head(activation, weights)
    head_2_0 = out_by_head[0,:,-1,:]
    head_2_1 = out_by_head[1,:,-1,:]

    logit_diff = classifier.linear.weight[1,:] - classifier.linear.weight[0,:]

    magnitudes = torch.tensor([
        embeddings[:,0,:].detach().numpy() @ logit_diff.detach().numpy(),
        head_0_0.detach().numpy() @ logit_diff.detach().numpy(),
        head_0_1.detach().numpy() @ logit_diff.detach().numpy(),
        cache['blocks.0.hook_mlp_out'][:,-1,:].detach().numpy() @ logit_diff.detach().numpy(),
        head_1_0.detach().numpy() @ logit_diff.detach().numpy(),
        head_1_1.detach().numpy() @ logit_diff.detach().numpy(),
        cache['blocks.1.hook_mlp_out'][:,-1,:].detach().numpy() @ logit_diff.detach().numpy(),
        head_2_0.detach().numpy() @ logit_diff.detach().numpy(),
        head_2_1.detach().numpy() @ logit_diff.detach().numpy(),
        cache['blocks.2.hook_mlp_out'][:,-1,:].detach().numpy() @ logit_diff.detach().numpy(),
    ])

    return magnitudes

magnitudes = get_magnitude_by_component(classifier, cache)

In [136]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def hists_per_comp(magnitudes, labels):
    num_comps = magnitudes.shape[0]
    titles = {
        (1, 1): "embeddings",
        (2, 1): "head 0.0",
        (2, 2): "head 0.1",
        (2, 3): "mlp 0",
        (3, 1): "head 1.0",
        (3, 2): "head 1.1",
        (3, 3): "mlp 1",
        (4, 1): "head 2.0",
        (4, 2): "head 2.1",
        (4, 3): "mlp 2"
    }
    assert num_comps == len(titles)

    fig = make_subplots(rows=4, cols=3)
    for ((row, col), title), mag in zip(titles.items(), magnitudes):
        fig.add_trace(go.Histogram(x=mag[labels].numpy(), name="Balanced", marker_color="blue", opacity=0.5, legendgroup = '1', showlegend=title=="embeddings"), row=row, col=col)
        fig.add_trace(go.Histogram(x=mag[~labels].numpy(), name="Unbalanced", marker_color="red", opacity=0.5, legendgroup = '2', showlegend=title=="embeddings"), row=row, col=col)
        fig.update_xaxes(title_text=title, range=[-10, 20], row=row, col=col)
    fig.update_layout(width=1200, height=1200, barmode="overlay", legend=dict(yanchor="top", y=0.92, xanchor="left", x=0.4), title="Histograms of component significance")
    fig.show()
    return fig

result = hists_per_comp(magnitudes, y.numpy())

In [137]:
example_tokens = big_tokens[:10,:]
for i in range(10):
    print(tokenizer.decode(example_tokens[i,:]))

b k b b k k k k k b [SEP]
z c l s p p s l c z [SEP]
i r r i b b i e r i [SEP]
d n l n n r n d l r [SEP]
f t t t f d d t t t [SEP]
w p q q p p p w p w [SEP]
g p b f s s f b p g [SEP]
s s l b x x b l s s [SEP]
x v n z b b z n v x [SEP]
l r h z x x z h r l [SEP]


In [138]:
small_embeddings, small_cache = classifier.transformer.run_with_cache(example_tokens)
small_logits = classifier.linear(small_embeddings)

IndexError: index 28 is out of bounds for dimension 0 with size 28

In [139]:
import circuitsvis as cv 


attention_pattern = small_cache["pattern", 0, "attn"]
px.imshow(attention_pattern, title="Attention pattern", animation_frame=0, facet_col=1, range_color=[0, 1]).show()
attention_pattern = small_cache["pattern", 1, "attn"]
px.imshow(attention_pattern, title="Attention pattern", animation_frame=0, facet_col=1, range_color=[0, 1]).show()
attention_pattern = small_cache["pattern", 2, "attn"]
px.imshow(attention_pattern, title="Attention pattern", animation_frame=0, facet_col=1, range_color=[0, 1]).show()

# Failue Type Analysis 

In this type of analysis, we can look at the output of a head in the classification direction and see whether this correlates with the failure type. 

In this case, let's consider algorithms for detecting a palindrome:


In [140]:
# let's find the direction of unpalindromicity. 
big_data_loader = DataLoader(dataset, batch_size=4000, shuffle=False)
big_sequences, y = next(iter(big_data_loader))
big_tokens = tokenizer(big_sequences)
big_tokens = torch.tensor(big_tokens["input_ids"])
big_tokens

tensor([[10, 10, 15,  ..., 16, 10, 28],
        [ 0,  6,  7,  ...,  6,  0, 28],
        [21,  3, 11,  ...,  1,  1, 28],
        ...,
        [23,  4, 23,  ..., 25,  4, 28],
        [ 3, 13, 18,  ..., 13,  3, 28],
        [ 1,  7, 15,  ..., 14, 14, 28]])

In [141]:
# get palindrome distance for sequence in big_sequences
from src.dataset import get_palindrome_distance

distances = [get_palindrome_distance(i) for i in big_sequences]
palindrome = torch.tensor(distances) > 0
px.histogram(distances,
             title="Histogram of palindrome distance",
             color=palindrome,
             opacity=0.5,
             barmode="overlay",
             labels={
                 "value": "Palindrome distance",
                 "color": "Palindrome?"
             }).show()


In [142]:
embeddings, cache = classifier.transformer.run_with_cache(big_tokens)
logits = classifier.linear(embeddings)
magnitudes = get_magnitude_by_component(classifier, cache)

IndexError: index 28 is out of bounds for dimension 0 with size 28

# MLP Analysis

Ok so MLP2 is doing the majority of the work here with head1 and head0 doing some extra work. We can't forget that the model is solving the problem very reliably so there does have to be an answer as to how it does this. 



# Ablation test 

Let's try removing each head and see what effect this has on classification accuracy