# Palindrome Dataset 

In [29]:
import torch
import string
from transformer_lens import EasyTransformer, EasyTransformerConfig
from torch.utils.data import Dataset, DataLoader
from src.dataset import PalindromeDataset, is_palindrome, get_palindrome_distance

dataset = PalindromeDataset(1000, perturb_n_times=8, k = 3, alphabet=string.ascii_lowercase[:2])
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

for s, y in train_loader:
    print(s)
    print(len(s[0]))
    print([is_palindrome(i) for i in s])
    print([get_palindrome_distance(i) for i in s])
    print(y)
    break

('aabbaa', 'bbbbbb', 'bbbbbb', 'abaaba')
6
[True, True, True, True]
[0, 0, 0, 0]
tensor([True, True, True, True])


# Palindrome Tokenizer

In [17]:
from src.tokenizer import ToyTokenizer

alphabet = string.ascii_lowercase[:2]
print(alphabet)
tokenizer = ToyTokenizer(alphabet)
tokens = tokenizer(["abba", "baba"])
torch.tensor(tokens["input_ids"])

tokenizer.vocab_size

ab


2

# The Model

We will use TransformerLens's `Transformer` class to build a model. The model will be a Transformer with 3 layers, 64 hidden units, 2 attention heads, and a vocabulary size of 28 (24 ascii lowercase + BOS + EOS). We will use the `Transformer` class's `from_config` method to create the model.

[1]:

In [34]:
cfg = EasyTransformerConfig(
    n_layers=3,
    d_model=32,
    d_head=16,
    n_heads=2,
    d_mlp=64,
    d_vocab=30,
    n_ctx=26,
    act_fn="relu",
    normalization_type="LN",
    attention_dir="causal",
    # d_vocab_out=64,
)
model = EasyTransformer(cfg)

# 32 batch size, 24 sequence length, 28 vocab size
model.forward(torch.randint(0, 26, (32, 24))).shape
model.forward(torch.randint(0, 26, (32, 24)))[0][0]

tensor([ 0.9631, -0.6242,  1.0414, -1.2069, -0.8282,  0.1911,  0.3691, -1.7960,
        -0.3446, -0.4955,  1.6231,  0.6735,  2.0840, -0.5482, -0.3408, -0.2557,
        -1.0136, -0.2018, -0.7057,  0.8524, -0.3537, -1.4293,  0.4259,  0.5453,
         0.6986, -0.1323,  0.9153,  0.0112, -0.1668,  1.1856],
       grad_fn=<SelectBackward0>)

In order to do a classification task, we need to wrap our transformer in nn.Module class which appends a classifer. This classifier will look at the embeddings
of the last token in the sequence and use a linear layer to map it to a vector of size 2. The vector will be passed through a softmax to get the probabilities.

In [35]:
import torch.nn as nn
class Classifier(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.transformer = EasyTransformer(cfg)
        self.transformer.unembed = nn.Identity()
        self.linear = torch.nn.Linear(cfg.d_model, 2)
        
    def forward(self, x):
        x = self.transformer(x)
        x = x[:, 0, :]
        x = self.linear(x)
        return x

model = Classifier(cfg)
output = model.forward(torch.randint(0, 26, (32, 24)))
output.shape
output[0]

tensor([ 0.7730, -0.0093], grad_fn=<SelectBackward0>)

# Loss Function

We will use a classic cross entropy loss on the classification token  (token id 0) which is prepended to the input sequence.

In [106]:
from torch.nn import CrossEntropyLoss

for s, y in train_loader:
    print(s)
    tokens = tokenizer(s)
    tokens = torch.tensor(tokens["input_ids"])
    logits = model.forward(tokens)
    
    print(logits.shape)
    print(y.shape)
    break

loss_fn = CrossEntropyLoss()
loss = loss_fn(logits,  y.long())
loss

('fpllmmllpf', 'nzwoffozzw', 'upqqququpq', 'uxwmllmwxu', 'doivrrviod', 'viiqeeqiiv', 'gzvgzgnggg', 'ggdkjjkdgg', 'atakookata', 'rttjbbjttr', 'umcpggpcmu', 'cjdrvvrdjc', 'hhuohqoooh', 'ryyryxryxp', 'rrorosioss', 'wweheeeewe', 'rrcnurrccr', 'xmmbbbuumx', 'tfvfnnfvft', 'sddvvvvdds', 'rwbnuunbwr', 'fxhollohxf', 'gkakkkckzg', 'uywoooowyu', 'xzjybbyjzx', 'lltuulltlf', 'ircvccvcri', 'kfppssppfk', 'dmzzzldmmm', 'ougxqqxguo', 'csqwyywqsc', 'fzhuuuuhzf')
torch.Size([32, 12, 64])
torch.Size([32])


RuntimeError: Expected target size [32, 64], got [32]

In [140]:
import tqdm.notebook as tqdm 
device = 'cpu'
total_examples =1*10**6
loss_fn = torch.nn.CrossEntropyLoss()
alphabet = string.ascii_lowercase

print(f"Alphabet size: {len(alphabet)}")

sequence_length = 10
assert sequence_length % 2 == 0, "Sequence length must be even"
assert sequence_length > 3, "Sequence length must be greater than 3"
k = sequence_length // 2 
print(k)

batch_size = 32

dataset = PalindromeDataset(total_examples, k = k, perturb_n_times=8, alphabet=alphabet)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print(f"Length of tokens: {sequence_length} (including start and end tokens)")

d_head = 16

cfg = EasyTransformerConfig(
    n_layers=3,
    d_model=d_head*2,
    d_head=d_head,
    n_heads=2,
    d_mlp=d_head*4,
    d_vocab= len(alphabet) + 3,
    n_ctx=sequence_length + 2,
    act_fn="relu",
    normalization_type=None,
    attention_dir="bidirectional",
    d_vocab_out=64,
)
model = EasyTransformer(cfg)

classifier = Classifier(cfg)
classifier.to(device)

tokenizer = ToyTokenizer(alphabet)

optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-4)

losses = []
classifier.train()
pbar = tqdm.tqdm(enumerate(train_loader), total=total_examples//batch_size)
for i, (x,y) in pbar:

    x_tokens = tokenizer(x)
    x_tokens = torch.tensor(x_tokens["input_ids"])
    y = y.to(device).long()
    logits = classifier.forward(x_tokens)
    loss = loss_fn(logits, y)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 100 == 0:
        pbar.set_description(f"Loss: {loss.item():.3f}")


Alphabet size: 26
5
Length of tokens: 10 (including start and end tokens)


  0%|          | 0/31250 [00:00<?, ?it/s]

In [141]:
# test model 

test_dataset = PalindromeDataset(1000, k = k, perturb_n_times=8, alphabet=alphabet)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

classifier.eval()
correct = 0
total = 0
for x,y in test_loader:
    x_tokens = tokenizer(x)
    x_tokens = torch.tensor(x_tokens["input_ids"])
    y = y.to(device).long()
    logits = classifier.forward(x_tokens)
    preds = torch.argmax(logits, dim=1)
    correct += (preds == y).sum().item()
    total += len(y)

print(f"Accuracy: {correct/total:.3f}")


Accuracy: 0.879


In [142]:
# save model

import os
import pickle
import torch
import json

model_path = "models"
os.makedirs(model_path, exist_ok=True)

model_file = os.path.join(model_path, "palindrome_classifier.pt")
torch.save(classifier.state_dict(), model_file)

tokenizer_file = os.path.join(model_path, "palindrome_tokenizer.pkl")
with open(tokenizer_file, "wb") as f:
    pickle.dump(tokenizer, f)

# save config file 
config_file = os.path.join(model_path, "palindrome_config.json")
with open(config_file, "w") as f:
    dictionary = cfg.__dict__
    json.dump(dictionary, f)


In [143]:
dictionary

{'n_layers': 3,
 'd_model': 32,
 'n_ctx': 12,
 'd_head': 16,
 'model_name': 'custom',
 'n_heads': 2,
 'd_mlp': 64,
 'act_fn': 'relu',
 'd_vocab': 29,
 'eps': 1e-05,
 'use_attn_result': False,
 'use_attn_scale': True,
 'use_local_attn': False,
 'original_architecture': None,
 'from_checkpoint': False,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'tokenizer_name': None,
 'window_size': None,
 'attn_types': None,
 'init_mode': 'gpt2',
 'normalization_type': None,
 'device': 'cpu',
 'attention_dir': 'bidirectional',
 'attn_only': False,
 'seed': None,
 'initializer_range': 0.1414213562373095,
 'init_weights': True,
 'scale_attn_by_inverse_layer_idx': False,
 'positional_embedding_type': 'standard',
 'final_rms': False,
 'd_vocab_out': 64,
 'parallel_attn_mlp': False,
 'rotary_dim': None,
 'n_params': 24576}

In [144]:
import plotly.express as px 
px.line(losses, log_x=False, log_y=False)

In [145]:
px.imshow(classifier.transformer.pos_embed.W_pos.detach())

In [146]:
# get cosine similarity between each pair of matrices in the embedding matrix
from sklearn.metrics.pairwise import cosine_similarity
px.imshow(cosine_similarity(classifier.transformer.pos_embed.W_pos.detach()))

# Model Interpretation

In [147]:
test_dataset = PalindromeDataset(4, k = k, perturb_n_times=8, alphabet=alphabet)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
for x,y in test_loader:
    x_tokens = tokenizer(x)
    x_tokens = torch.tensor(x_tokens["input_ids"])
    y = y.to(device).long()
    logits = classifier.forward(x_tokens)
    preds = torch.argmax(logits, dim=1)
    print(f"Input: {x[0]}")
    print(f"Prediction: {preds[0]}")
    print(f"Label: {y[0]}")
    

Input: tnhghhghhh
Prediction: 0
Label: 0
Input: kkkkmumokk
Prediction: 0
Label: 0
Input: uyvljjlvyu
Prediction: 1
Label: 1
Input: wqrsyysrqw
Prediction: 1
Label: 1


In [191]:
classifier.linear.weight.shape

torch.Size([2, 32])

In [148]:
# let's find the direction of unpalindromicity. 

logit_diff = classifier.linear.weight[0,:] - classifier.linear.weight[1,:]

In [150]:
big_data_loader = DataLoader(dataset, batch_size=4000, shuffle=False)
big_tokens, y = next(iter(big_data_loader))
big_tokens = tokenizer(big_tokens)
big_tokens = torch.tensor(big_tokens["input_ids"])
big_tokens

tensor([[27, 14, 14,  ...,  8, 14, 28],
        [27,  3,  6,  ...,  6,  3, 28],
        [27, 18, 18,  ...,  6,  6, 28],
        ...,
        [27,  4, 18,  ...,  4,  4, 28],
        [27, 11,  9,  ...,  9, 11, 28],
        [27,  4,  3,  ...,  3,  4, 28]])

In [151]:
embeddings, cache = classifier.transformer.run_with_cache(big_tokens)
logits = classifier.linear(embeddings)
# print("Loss:", loss_fn(logits, big_tokens).item())

In [152]:
cache.keys()

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_pre', 'blocks.2.attn.hook_q', 'blocks.2.attn.hook_k', 'blocks.2.attn.hook_v', 'blocks.2.attn.hook_attn_scores', 'blocks.2.attn.hook_pattern', 'blocks.2.attn.hook_z', 'blocks.2.hook_attn_out', 'blocks.2.hook_resid_mid', 'blocks.2.mlp.hook_pre', 'blo

In [189]:
classifier.transformer.unembed

Identity()

In [172]:
embeddings[:,0,:].shape

torch.Size([4000, 32])

In [180]:
logit_diff = classifier.linear.weight[1,:] - classifier.linear.weight[0,:]
logit_diff.shape

torch.Size([32])

In [181]:
palindrome_direction_magnitude = embeddings[:,0,:].detach().numpy() @ logit_diff.detach().numpy()
palindrome_direction_magnitude.shape

(4000,)

In [182]:
px.histogram(
    x=palindrome_direction_magnitude,
    color = y.numpy(),
    barmode="overlay",
)

In [183]:
cache.keys()

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_pre', 'blocks.2.attn.hook_q', 'blocks.2.attn.hook_k', 'blocks.2.attn.hook_v', 'blocks.2.attn.hook_attn_scores', 'blocks.2.attn.hook_pattern', 'blocks.2.attn.hook_z', 'blocks.2.hook_attn_out', 'blocks.2.hook_resid_mid', 'blocks.2.mlp.hook_pre', 'blo

dict_keys(['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_pre', 'blocks.2.attn.hook_q', 'blocks.2.attn.hook_k', 'blocks.2.attn.hook_v', 'blocks.2.attn.hook_attn_scores', 'blocks.2.attn.hook_pattern', 'blocks.2.attn.hook_z', 'blocks.2.hook_attn_out', 'blocks.2.hook_resid_mid', 'blocks.2.mlp.hook_pre', 'blocks.2.mlp.hook_post', 'blocks.2.hook_mlp_out', 'blocks.2.hook_resid_post'])

In [205]:
classifier.transformer.cfg.n_layers

3

In [202]:
classifier.transformer.cfg.n_heads

2

In [207]:
for key in cache.keys():
    print(key, cache[key].shape)

hook_embed torch.Size([4000, 12, 32])
hook_pos_embed torch.Size([4000, 12, 32])
blocks.0.hook_resid_pre torch.Size([4000, 12, 32])
blocks.0.attn.hook_q torch.Size([4000, 12, 2, 16])
blocks.0.attn.hook_k torch.Size([4000, 12, 2, 16])
blocks.0.attn.hook_v torch.Size([4000, 12, 2, 16])
blocks.0.attn.hook_attn_scores torch.Size([4000, 2, 12, 12])
blocks.0.attn.hook_pattern torch.Size([4000, 2, 12, 12])
blocks.0.attn.hook_z torch.Size([4000, 12, 2, 16])
blocks.0.hook_attn_out torch.Size([4000, 12, 32])
blocks.0.hook_resid_mid torch.Size([4000, 12, 32])
blocks.0.mlp.hook_pre torch.Size([4000, 12, 64])
blocks.0.mlp.hook_post torch.Size([4000, 12, 64])
blocks.0.hook_mlp_out torch.Size([4000, 12, 32])
blocks.0.hook_resid_post torch.Size([4000, 12, 32])
blocks.1.hook_resid_pre torch.Size([4000, 12, 32])
blocks.1.attn.hook_q torch.Size([4000, 12, 2, 16])
blocks.1.attn.hook_k torch.Size([4000, 12, 2, 16])
blocks.1.attn.hook_v torch.Size([4000, 12, 2, 16])
blocks.1.attn.hook_attn_scores torch.Size(

In [213]:
cache['blocks.2.attn.hook_z'].shape

torch.Size([4000, 12, 2, 16])

In [231]:
weights = classifier.transformer.blocks[2].attn.W_O
print(weights.shape) # (n_heads, d_head, d_model)

torch.Size([2, 16, 32])


In [232]:
activation = cache['blocks.2.attn.hook_z']
activation.shape # batch size, sequence length, n_heads, d_head

torch.Size([4000, 12, 2, 16])

In [251]:
from fancy_einsum import einsum

def get_out_by_head(activations, weights) -> torch.Tensor:
    '''
    activations: batch size, sequence length, n_heads, d_head
    weights: n_heads, d_head, d_model
    '''
    outputs = einsum("batch seq n_heads d_head, n_heads d_head d_model -> n_heads batch seq d_model", 
        activations, weights)
    return outputs

weights = classifier.transformer.blocks[0].attn.W_O
activation = cache['blocks.0.attn.hook_z']
out_by_head = get_out_by_head(activation, weights)
head_0_0 = out_by_head[0,:,0,:]
head_0_1 = out_by_head[1,:,0,:]

weights = classifier.transformer.blocks[1].attn.W_O
activation = cache['blocks.1.attn.hook_z']
out_by_head = get_out_by_head(activation, weights)
head_1_0 = out_by_head[0,:,0,:]
head_1_1 = out_by_head[1,:,0,:]

weights = classifier.transformer.blocks[2].attn.W_O
activation = cache['blocks.2.attn.hook_z']
out_by_head = get_out_by_head(activation, weights)
head_2_0 = out_by_head[0,:,0,:]
head_2_1 = out_by_head[1,:,0,:]

In [249]:
head_0_0.shape

torch.Size([4000, 32])

In [243]:
cache['blocks.0.hook_mlp_out'][:,0,:].shape

torch.Size([4000, 32])

In [256]:
magnitudes = torch.tensor([
    embeddings[:,0,:].detach().numpy() @ logit_diff.detach().numpy(),
    head_0_0.detach().numpy() @ logit_diff.detach().numpy(),
    head_0_1.detach().numpy() @ logit_diff.detach().numpy(),
    cache['blocks.0.hook_mlp_out'][:,0,:].detach().numpy() @ logit_diff.detach().numpy(),
    head_1_0.detach().numpy() @ logit_diff.detach().numpy(),
    head_1_1.detach().numpy() @ logit_diff.detach().numpy(),
    cache['blocks.1.hook_mlp_out'][:,0,:].detach().numpy() @ logit_diff.detach().numpy(),
    head_2_0.detach().numpy() @ logit_diff.detach().numpy(),
    head_2_1.detach().numpy() @ logit_diff.detach().numpy(),
    cache['blocks.2.hook_mlp_out'][:,0,:].detach().numpy() @ logit_diff.detach().numpy(),
])




Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:204.)



In [259]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def hists_per_comp(magnitudes, labels):
    num_comps = magnitudes.shape[0]
    titles = {
        (1, 1): "embeddings",
        (2, 1): "head 0.0",
        (2, 2): "head 0.1",
        (2, 3): "mlp 0",
        (3, 1): "head 1.0",
        (3, 2): "head 1.1",
        (3, 3): "mlp 1",
        (4, 1): "head 2.0",
        (4, 2): "head 2.1",
        (4, 3): "mlp 2"
    }
    assert num_comps == len(titles)

    fig = make_subplots(rows=4, cols=3)
    for ((row, col), title), mag in zip(titles.items(), magnitudes):
        fig.add_trace(go.Histogram(x=mag[labels].numpy(), name="Balanced", marker_color="blue", opacity=0.5, legendgroup = '1', showlegend=title=="embeddings"), row=row, col=col)
        fig.add_trace(go.Histogram(x=mag[~labels].numpy(), name="Unbalanced", marker_color="red", opacity=0.5, legendgroup = '2', showlegend=title=="embeddings"), row=row, col=col)
        fig.update_xaxes(title_text=title, range=[-10, 20], row=row, col=col)
    fig.update_layout(width=1200, height=1200, barmode="overlay", legend=dict(yanchor="top", y=0.92, xanchor="left", x=0.4), title="Histograms of component significance")
    fig.show()
    return fig


unbalanced_dir = logit_diff.detach().numpy()
# magnitudes = out_by_components[:,:,0,:].detach() @ unbalanced_dir
result = hists_per_comp(magnitudes, y.numpy())

In [261]:
logit_diff.shape

torch.Size([32])

In [262]:
classifier.linear

Linear(in_features=32, out_features=2, bias=True)

In [None]:
import transformer_lens.utils as utils
utils.get_act_name