# Palindrome Dataset 

In [25]:
import torch
from transformer_lens import EasyTransformer, EasyTransformerConfig
from torch.utils.data import Dataset, DataLoader
from src.dataset import PalindromeDataset, is_palindrome, get_palindrome_distance

# Small demo dataset over a three-letter alphabet, served in shuffled batches of 4.
dataset = PalindromeDataset(1000, perturb_n_times=8, k=2, alphabet='abc')
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Pull a single batch and inspect the strings next to their labels.
s, y = next(iter(train_loader))
print(s)
print(len(s[0]))
print(list(map(is_palindrome, s)))
print(list(map(get_palindrome_distance, s)))
print(y)

('caca', 'cccc', 'bcbb', 'cbbc')
4
[False, True, False, True]
[2, 0, 1, 0]
tensor([False,  True, False,  True])


# Palindrome Tokenizer

In [35]:
import string
from src.tokenizer import SimpleTokenizer


# Restrict the alphabet to the first two lowercase letters.
alphabet = string.ascii_lowercase[0:2]
print(alphabet)

# Tokenize two example strings and look at the resulting batch.
tokenizer = SimpleTokenizer(alphabet)
example_strings = ["abba", "baba"]
tokens = tokenizer.tokenize(example_strings)
print(tokens.shape)
# Displayed as the cell result: the largest token id present in the batch.
tokens.max()

ab
torch.Size([2, 6])


tensor(4)

# The Model

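We will use TransformerLens's `EasyTransformer` class to build a model. The model is a Transformer with 3 layers, a residual-stream width (`d_model`) of 32, 2 attention heads of dimension 16, an MLP hidden size of 64, and a vocabulary size of 30 (room for the 26 lowercase ASCII letters plus special tokens such as BOS and EOS). We describe the architecture with an `EasyTransformerConfig` and pass it to the `EasyTransformer` constructor to create the model.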

[1]:

In [10]:
# Architecture used for the first experiments: a small bidirectional
# transformer with LayerNorm and ReLU MLPs.
cfg = EasyTransformerConfig(
    n_layers=3,
    d_model=32,
    d_head=16,
    n_heads=2,
    d_mlp=64,
    d_vocab=30,
    n_ctx=26,
    act_fn="relu",
    normalization_type="LN",
    attention_dir="bidirectional",
    # d_vocab_out=64,
)
model = EasyTransformer(cfg)

# Smoke test on random token ids: a batch of 32 sequences, 24 tokens each,
# with ids drawn from [0, 26) so they stay inside d_vocab=30 and n_ctx=26.
dummy_tokens = torch.randint(0, 26, (32, 24))
# Call the module itself (not .forward) so registered hooks run, and do it
# once: only the last expression of a cell is displayed anyway.
dummy_logits = model(dummy_tokens)
# One vector of vocabulary logits per input position.
assert dummy_logits.shape == (*dummy_tokens.shape, cfg.d_vocab)
dummy_logits[0][0]

tensor([ 0.7976, -0.8924,  0.4307,  0.3714,  0.2339, -0.6624, -0.0363,  0.7190,
        -0.3802,  0.3605,  0.0998, -0.2518, -0.3539,  0.9955,  1.7652,  1.1463,
        -0.0562,  1.2906, -0.4194,  0.7166, -0.2011, -0.1640, -0.4272,  2.0686,
        -0.1776, -0.7125, -0.8695, -0.4476, -0.2691,  0.4845],
       grad_fn=<SelectBackward0>)

In [12]:
import torch.nn as nn
class Classifier(nn.Module):
    """Two-class sequence classifier on top of an ``EasyTransformer`` backbone.

    The backbone's ``unembed`` layer is replaced by ``nn.Identity`` and a
    linear head mapping ``cfg.d_model`` features to 2 logits is applied to
    the backbone output at sequence position 0.
    """

    def __init__(self, cfg):
        """Build the backbone from ``cfg`` and attach the 2-way linear head."""
        super().__init__()
        backbone = EasyTransformer(cfg)
        # Drop the vocabulary projection; the head below consumes d_model features.
        backbone.unembed = nn.Identity()
        self.transformer = backbone
        self.linear = nn.Linear(cfg.d_model, 2)

    def forward(self, x):
        """Return class logits of shape (batch, 2) for the token batch ``x``."""
        hidden = self.transformer(x)
        # Keep only the first sequence position of every example.
        first_position = hidden[:, 0, :]
        return self.linear(first_position)

model = Classifier(cfg)
# Random token ids: a batch of 32 sequences with 24 tokens each.
random_tokens = torch.randint(low=0, high=26, size=(32, 24))
output = model.forward(random_tokens)
output.shape
# Displayed as the cell result: the two class logits of the first example.
output[0]

tensor([-0.8926,  0.4910], grad_fn=<SelectBackward0>)

# Loss Function

We will use a standard cross-entropy loss on the model's output at the classification token, which is prepended to the input sequence.

In [16]:
from torch.nn import CrossEntropyLoss

# Take one batch and push it through tokenizer and model to check shapes.
s, y = next(iter(train_loader))
tokens = tokenizer.tokenize(s)
print(tokens[0])
print(tokens.shape)
logits = model.forward(tokens)

print(logits.shape)
print(y.shape)

# Cross-entropy expects integer class indices, so cast the labels.
loss_fn = CrossEntropyLoss()
loss = loss_fn(logits, y.long())
loss

tensor([0, 5, 5, 3, 3, 5, 5, 2])
torch.Size([32, 8])
torch.Size([32, 2])
torch.Size([32])


tensor(0.7862, grad_fn=<NllLossBackward0>)

In [38]:
import tqdm.notebook as tqdm 
# --- Training configuration -------------------------------------------------
device = 'cpu'
total_examples = 100000
batch_size = 32  # single source of truth for the loader and the progress bar
loss_fn = torch.nn.CrossEntropyLoss()
alphabet = string.ascii_lowercase[:2]

print(f"Alphabet size: {len(alphabet)}")

# Rebuild the tokenizer from this cell's alphabet so it cannot silently
# disagree with the dataset when an earlier cell used a different alphabet.
tokenizer = SimpleTokenizer(alphabet)

dataset = PalindromeDataset(total_examples, k = 2, perturb_n_times=8, alphabet=alphabet)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# NOTE(review): presumably the string length plus start/end tokens — confirm
# against PalindromeDataset / SimpleTokenizer before changing k.
length_tokens = 4*2 + 2

print(f"Length of tokens: {length_tokens} (including start and end tokens)")

d_head = 16

# --- Model ------------------------------------------------------------------
cfg = EasyTransformerConfig(
    n_layers=3,
    d_model=d_head*2,
    d_head=d_head,
    n_heads=2,
    d_mlp=d_head*4,
    d_vocab= len(alphabet) + 2 + 1,
    n_ctx=length_tokens,
    act_fn="relu",
    normalization_type="LN",
    attention_dir="bidirectional",
    d_vocab_out=64,
)
# NOTE(review): this bare transformer is not used by the training loop below
# and rebinds `model`; kept so any other cell relying on it keeps working.
model = EasyTransformer(cfg)

classifier = Classifier(cfg)
classifier.to(device)

optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-5)

# --- Training loop (one pass over the dataset) --------------------------------
losses = []  # per-step loss values, plotted in a later cell
classifier.train()
# len(train_loader) is the true number of batches, so the bar stays correct
# when batch_size or the dataset size changes.
pbar = tqdm.tqdm(enumerate(train_loader), total=len(train_loader))
for i, (x,y) in pbar:

    x_tokens = tokenizer.tokenize(x)
    x_tokens = x_tokens.to(device)
    # CrossEntropyLoss needs integer class indices; the dataset yields booleans.
    y = y.to(device).long()
    # Call the module rather than .forward so registered hooks run.
    logits = classifier(x_tokens)
    loss = loss_fn(logits, y)
    losses.append(loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Refresh the progress-bar label only occasionally to keep output cheap.
    if i % 100 == 0:
        pbar.set_description(f"Loss: {loss.item():.3f}")


Alphabet size: 2
Length of tokens: 10 (including start and end tokens)


  0%|          | 0/3125 [00:00<?, ?it/s]

In [39]:
import plotly.express as px 
# Line plot of the per-step loss values collected in `losses`.
px.line(losses)

In [21]:
# Heatmap of the classifier's positional-embedding weights (W_pos), detached from autograd.
px.imshow(classifier.transformer.pos_embed.W_pos.detach())

In [25]:
# Pairwise cosine similarity between the rows of the positional-embedding matrix W_pos
# (presumably one row per sequence position — TODO confirm), shown as a heatmap.
from sklearn.metrics.pairwise import cosine_similarity
px.imshow(cosine_similarity(classifier.transformer.pos_embed.W_pos.detach()))

In [129]:
# Peek at the first four entries of `y`, as last bound by whichever cell ran most recently.
y[:4]

tensor([1, 1, 0, 1])