# Experiment

We wish to understand how a small 1-layer transformer with only a single attention head can learn to predict the correct permutation sequence.

1. Can we understand how the model develops an algorithm for permuting elements in a list?
2. What goes on in the QK and OV circuits?

This experiment is an attempt to answer the above questions.

## Motivation

This experiment is devised to tackle one of the concrete open problems proposed by Neel Nanda here:
https://www.alignmentforum.org/posts/LbrPTJ4fmABEdEnLf/200-concrete-open-problems-in-mechanistic-interpretability

### Assumptions and constraints:

Like any decent experiment, we first need to set assumptions and constraints that limit the scope of the problem so that progress can be tracked.

The following are the assumptions and constraints for this experiment:
1. A 1-layer, attention-only transformer is sufficient to predict the terms of a permutation sequence correctly.
2. ReLU activations are sufficient to begin with. (This turned out to be incorrect: SoLU did better.)
3. Permutations of complete groups are used for this exercise, i.e. all chosen elements are used to generate each permutation.
4. The context window is set to a fixed length with a fixed permutation size. Analysis is done for a maximum permutation size of 5 elements.
5. The vocabulary is limited to 62 characters (52 letters + 10 digits). Choosing 5 of them gives $\binom{62}{5}$ possible element sets. This helps ensure the model learns to generalize its attention pattern over positions instead of memorizing characters seen in previous sequence terms.
6. A custom tokenizer (similar to ASCII encoding) is used for the dataset. This keeps the experimental setup simple, so we can focus solely on the attention mechanism of the transformer, which is the meat of the problem.
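The sketch below puts assumption 5 in numbers. It uses Python's standard `math.comb` and `math.factorial` to count two things: how many distinct element sets the 62-character vocabulary allows, and how many terms each permutation sequence contains. It only does the arithmetic. It assumes one sequence per element set, which matches how `generate_dataset` below builds its data. The helper names are illustrative.

```python
import math

VOCAB_SIZE = 62  # 52 letters + 10 digits, special characters excluded

def n_element_sets(n: int) -> int:
    """Number of distinct n-element sets drawn from the vocabulary."""
    return math.comb(VOCAB_SIZE, n)

def n_terms(n: int) -> int:
    """Number of permutation terms in one full sequence of n elements."""
    return math.factorial(n)

for n in (4, 5):
    print(n, n_element_sets(n), n_terms(n))
# 4 557845 24
# 5 6471002 120
```

Even with $n = 4$ (the `n_comb` used in the code below), there are over half a million distinct element sets. That is far more than the model sees during the short training run.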

In [1]:
%pip install ipykernel setuptools transformer_lens ipywidgets plotly nbformat circuitsvis

Note: you may need to restart the kernel to use updated packages.


### Defining Vocabulary and Synthetic data generation:

Example generation:

> \> equ78 equ87 eq7u8 eq78u eq8u7 eq87u euq78 euq87 eu7q8 eu78q eu8q7 eu87q e7qu8 e7q8u e7uq8 e7u8q e78qu e78uq e8qu7 e8q7u e8uq7 e8u7q e87qu e87uq qeu78 qeu87 qe7u8 qe78u qe8u7 qe87u que78 que87 qu7e8 qu78e qu8e7 qu87e q7eu8 q7e8u q7ue8 q7u8e q78eu q78ue q8eu7 q8e7u q8ue7 q8u7e q87eu q87ue ueq78 ueq87 ue7q8 ue78q ue8q7 ue87q uqe78 uqe87 uq7e8 uq78e uq8e7 uq87e u7eq8 u7e8q u7qe8 u7q8e u78eq u78qe u8eq7 u8e7q u8qe7 u8q7e u87eq u87qe 7equ8 7eq8u 7euq8 7eu8q 7e8qu 7e8uq 7qeu8 7qe8u 7que8 7qu8e 7q8eu 7q8ue 7ueq8 7ue8q 7uqe8 7uq8e 7u8eq 7u8qe 78equ 78euq 78qeu 78que 78ueq 78uqe 8equ7 8eq7u 8euq7 8eu7q 8e7qu 8e7uq 8qeu7 8qe7u 8que7 8qu7e 8q7eu 8q7ue 8ueq7 8ue7q 8uqe7 8uq7e 8u7eq 8u7qe 87equ 87euq 87qeu 87que 87ueq 87uqe.

$n_{ctx} = 722$

where $n_{ctx}$ is the context length.
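The value 722 follows from the format of a sequence. For $n$ elements there are $n!$ terms of $n$ characters each and $n! - 1$ separating spaces. On top of that come the leading `> ` (2 characters) and the trailing `.`, so $n_{ctx} = 2 + n \cdot n! + (n! - 1) + 1 = (n+1)\,n! + 2$. A small sketch of this count (the function name is illustrative):

```python
import math

def context_length(n: int) -> int:
    """Characters in "> " + all n! permutation terms joined by spaces + ".".

    n! terms of n characters, n! - 1 separating spaces,
    2 leading characters ("> ") and 1 trailing period.
    """
    terms = math.factorial(n)
    return 2 + n * terms + (terms - 1) + 1

print(context_length(5))  # 722
print(context_length(4))  # 122
```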

### Custom Tokenizer

In [2]:
import string
from typing import List

class Tokenizer:
    def __init__(self) -> None:
        self._special_chars = "> ."
        self.vocab = self._special_chars + string.ascii_letters + string.digits

    def str_to_tokens(self, s: str) -> List[int]:
        return [self.vocab.index(ch) for ch in s]

    def tokens_to_str(self, tokens: List[int]) -> str:
        return "".join([self.vocab[token] for token in tokens])
    
    def special_chars(self) -> List[str]:
        return list(self._special_chars)
        
    def vocab_without_special_chars(self) -> str:
        return self.vocab[len(self._special_chars):]
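Below is a quick sanity sketch of the token-id layout this vocabulary implies. It rebuilds the same vocabulary string instead of importing the `Tokenizer` class, so `VOCAB`, `encode` and `decode` are illustrative names. The special characters take ids 0 to 2, so the padding value `0` used by `collate_fn` later coincides with the id of `>`.

```python
import string
from typing import List

# Same construction as Tokenizer.vocab: 3 special characters, 52 letters, 10 digits
VOCAB = "> ." + string.ascii_letters + string.digits

def encode(s: str) -> List[int]:
    return [VOCAB.index(ch) for ch in s]

def decode(tokens: List[int]) -> str:
    return "".join(VOCAB[t] for t in tokens)

print(len(VOCAB))          # 65, matching d_vocab
print(encode("> ab ba."))  # [0, 1, 3, 4, 1, 4, 3, 2]
assert decode(encode("> ab ba.")) == "> ab ba."  # round trip
```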

### Synthetic data

1. Use the `itertools` library to first pick a combination of elements from the vocabulary of 62 characters. See [itertools](https://docs.python.org/3/library/itertools.html#itertools.combinations)
2. Generate the permutation sequence from the given elements. See [itertools](https://docs.python.org/3/library/itertools.html#itertools.permutations)

In [12]:
import torch
from random import shuffle
from itertools import permutations, combinations
from typing import List

n_comb = 4
start_pos = 1+n_comb # index of the last prompt character: "> abcd" has 2 + n_comb characters, -1 for 0-based indexing
tokenizer = Tokenizer()
[beg, space, period] = tokenizer.special_chars()

def permute(vocab: str) -> List[str]:
    return ["".join(item) for item in permutations(vocab, len(vocab))]

def data_from_permute(terms: List[str]) -> str:
    output = space.join(terms)
    
    return f"> {output}."

def generate_dataset(n, combs=None):
    if not combs:
        # all symbols except the special tokens ">", " " and "."
        elems = tokenizer.vocab_without_special_chars()
        combs = list(combinations(elems, n))
        shuffle(combs)
    else:
        # materialize iterators so the infinite loop below can cycle over them again
        combs = list(combs)
    while True:
        for comb in combs:
            terms = permute("".join(comb))
            yield data_from_permute(terms)


prompt = next(generate_dataset(n=n_comb))
print(prompt)
n_ctx = len(prompt)
print(n_ctx)

> gCLS gCSL gLCS gLSC gSCL gSLC CgLS CgSL CLgS CLSg CSgL CSLg LgCS LgSC LCgS LCSg LSgC LSCg SgCL SgLC SCgL SCLg SLgC SLCg.
122


### Test
1. The dataset generator should generate the expected data.
2. `str_to_tokens` and `tokens_to_str` should convert correctly.
3. The model should run correctly on sample data, and output tensor shapes should match the expected shapes.
4. An untrained model should predict the permutation sequence incorrectly.

In [13]:
def should_generate_correct_dataset(n: int, expected_data: str):
    test_dataset_gen = generate_dataset(n, combs=combinations(tokenizer.vocab_without_special_chars(), n))
    actual_data = next(test_dataset_gen)
    assert expected_data == actual_data, f"{expected_data} != {actual_data}"

should_generate_correct_dataset(1, "> a.")
should_generate_correct_dataset(2, "> ab ba.")
should_generate_correct_dataset(3, "> abc acb bac bca cab cba.")
should_generate_correct_dataset(4, "> abcd abdc acbd acdb adbc adcb bacd badc bcad bcda bdac bdca cabd cadb cbad cbda cdab cdba dabc dacb dbac dbca dcab dcba.")

### Model parameters

1. n_layers = 1
2. d_model=128
3. d_head=64
4. n_heads=1
5. d_vocab=65
6. attn_only=True
7. n_ctx=len(prompt) (722 max)

In [5]:
from transformer_lens import EasyTransformerConfig, EasyTransformer

torch.mps.empty_cache()
torch.set_default_device("mps")

cfg = EasyTransformerConfig(
    n_layers=1,
    d_model=128,
    d_head=64,
    n_heads=1,
    d_mlp=None,
    d_vocab=len(tokenizer.vocab),
    n_ctx=n_ctx,
    act_fn="solu", # act_fn only applies to MLP layers; attn_only=True builds none
    attn_only=True,
)
model = EasyTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): LayerNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (unembed): Unembed()
)


### Test model output shape and untrained prediction

In [14]:
def should_return_output_tensor_with_correct_shape(input: torch.Tensor, expected_shape: List[int]):
    with torch.no_grad():
        expected_shape = torch.Size(expected_shape)
        output = model(input)
        actual_shape = output.shape
    
    assert expected_shape == actual_shape, f"{expected_shape} != {actual_shape}"
    
def should_return_output_tensor_with_incorrect_prediction(input: str, expected_completion: str):
    with torch.no_grad():
        tokens = torch.tensor(tokenizer.str_to_tokens(input))
        output = model(tokens)
    
    actual_completion = tokenizer.tokens_to_str([output[:, -1, :].argmax().item()])
    assert expected_completion != actual_completion, f"{expected_completion} == {actual_completion}"

test_dataset_gen = generate_dataset(2)
test_input_seq = next(test_dataset_gen)
test_input = tokenizer.str_to_tokens(test_input_seq)
test_input = torch.tensor(test_input)

should_return_output_tensor_with_correct_shape(test_input, [1, len(test_input), model.cfg.d_vocab])
# the correct next token after "> ab" is " "; the untrained model should not predict it
should_return_output_tensor_with_incorrect_prediction("> ab", " ")

### Test model prediction on sample data set.

We first check the model's outputs before it is trained.

As expected, the untrained model outputs gibberish: it cannot predict the sequence yet.

In [15]:
def sample_data_test(sample: str):
    with torch.no_grad():
        tokens = tokenizer.str_to_tokens(sample)
        i = start_pos
        output = tokens[:i]
        tokens = torch.tensor(tokens)
        while i < tokens.size(0):
            logits = model(tokens[:i])  # teacher forcing: condition on the ground-truth prefix
            prediction = logits[:, -1, :].argmax().item()
            output.append(prediction)
            i += 1
        print(f"expected:{sample}")
        print(f"actual:{tokenizer.tokens_to_str(output)}")

sample_data_test(next(generate_dataset(n_comb)))

expected:> cyM6 cy6M cMy6 cM6y c6yM c6My ycM6 yc6M yMc6 yM6c y6cM y6Mc Mcy6 Mc6y Myc6 My6c M6cy M6yc 6cyM 6cMy 6ycM 6yMc 6Mcy 6Myc.
actual:> cyM44664B4B644dB66BfB6BinB6B4h6h6BB666hf66aadv66h66ia6666Bh6666dP466Bh66a6B64aa66ai6a6ha466a66Iya6466ah646P6666ah6h6aha6
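`sample_data_test` feeds the ground-truth prefix `tokens[:i]` at every step (teacher forcing), so each prediction is conditioned on a correct context. A stricter check is free-running greedy decoding, where each prediction is appended and fed back in.

The sketch below is model-agnostic. `greedy_decode`, `teacher_forced` and the toy `next_token` function are hypothetical stand-ins. With the notebook's model, `next_token` would be something like `model(torch.tensor(prefix))[0, -1].argmax().item()`.

```python
from typing import Callable, List

def greedy_decode(next_token: Callable[[List[int]], int],
                  prompt: List[int], n_steps: int) -> List[int]:
    """Free-running decoding: each prediction is appended and fed back in."""
    output = list(prompt)
    for _ in range(n_steps):
        output.append(next_token(output))
    return output

def teacher_forced(next_token: Callable[[List[int]], int],
                   tokens: List[int], start: int) -> List[int]:
    """Mirrors sample_data_test: every step sees the ground-truth prefix."""
    output = list(tokens[:start])
    for i in range(start, len(tokens)):
        output.append(next_token(tokens[:i]))
    return output

# Toy "model" that always predicts previous token + 1 (hypothetical)
toy = lambda prefix: prefix[-1] + 1
target = [1, 2, 5, 6, 7]
print(teacher_forced(toy, target, 1))     # [1, 2, 3, 6, 7]
print(greedy_decode(toy, target[:1], 4))  # [1, 2, 3, 4, 5]
```

With teacher forcing the toy model recovers after its wrong `3`, because the next step still sees the true `5`. In free-running decoding the error propagates to every later token.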


### Data set with generator functions
1. Generator functions are handy tools for generating data lazily. Here `__getitem__` does not index into stored data. Instead, it draws the next sample from the `generate_dataset` generator.
2. This is handy when there is too much data to load into memory at once.
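Here is a minimal, torch-free sketch of the pattern used by `PermutationDataset`. It uses a hypothetical `LazyDataset` class and a toy `itertools.count()` stream standing in for `generate_dataset`. Because `__getitem__` ignores its index, every lookup, whichever split it comes from, just advances one shared generator.

```python
from itertools import count
from typing import Iterator

class LazyDataset:
    """Map-style dataset whose items come from a shared, possibly infinite generator."""

    def __init__(self, source: Iterator[int], length: int) -> None:
        self.source = source
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, _index: int) -> int:
        # The index is ignored: each call lazily pulls the next item
        return next(self.source)

ds = LazyDataset(count(), length=10)
print(ds[0], ds[0], ds[5])  # 0 1 2
```

This keeps memory use constant. It also means indices carry no meaning, so a `random_split` over such a dataset only fixes how many draws each split makes, not which samples it gets.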

In [16]:
from torch.utils.data import Dataset

class PermutationDataset(Dataset):
    def __init__(self, n=n_comb, max_len=n_ctx) -> None:
        super().__init__()
        self.generator = generate_dataset(n)
        self.n = n
        self.max_len = max_len
        
    def __len__(self) -> int:
        return self.max_len
        
    def __getitem__(self, _) -> dict:
        # The index is ignored: each call lazily draws the next sequence
        # from the shared (infinite) generator
        entry = next(self.generator)
        tokens = torch.tensor(tokenizer.str_to_tokens(entry), device="mps")
        return {"tokens": tokens}

### Define data loader

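Here we define data loaders for the training and test data, using an 80/20 train/test split. Because `__getitem__` ignores its index, the split only fixes how many sequences each loader draws per epoch. Both loaders draw fresh sequences from the same shared generator.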

In [17]:
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

def collate_fn(data):
    max_size = max(d["tokens"].size(0) for d in data)
    padded_data = torch.stack([F.pad(d["tokens"], (0, max_size-d["tokens"].size(0)), value=0) for d in data])
    
    return padded_data

generator = torch.Generator(device="mps")
test, train = random_split(dataset=PermutationDataset(n=n_comb, max_len=n_ctx), lengths=[.2, .8], generator=generator)

train_data_loader = DataLoader(
    dataset=train, 
    batch_size=4, 
    generator=generator,
    collate_fn=collate_fn,
)
test_data_loader = DataLoader(
    dataset=test, 
    batch_size=4, 
    generator=generator, 
    collate_fn=collate_fn,
)

### Test data loader shape

In [18]:
expected_shape = torch.Size([4, n_ctx])
for _, batch in enumerate(train_data_loader):
    assert expected_shape == batch.shape, f"{expected_shape} != {batch.shape}"
    break

## Train model
1. Use synthetically generated training data from the data loader to train the model.
2. Record the loss at every step and plot loss vs. steps to check that the loss is decreasing.
3. The loss is the mean negative log-likelihood of the correct next token. `log_softmax` is applied to the logits at every predicted position, and the log-probability of the true next token is gathered.
4. The first `2+n_comb` tokens are the prompt (e.g. `> abcd`), so they are not scored as prediction targets.
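The slicing in `loss_fn` works as follows: the logit at position $p$ predicts token $p+1$, so `logits[:, start_pos:-1]` is paired with `tokens[:, start_pos+1:]`. The pure-Python sketch below computes the same mean negative log-likelihood (i.e. cross-entropy) over those pairs. It is unbatched, and the helper names are hypothetical.

```python
import math
from typing import List

def log_softmax(row: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def masked_nll(logits: List[List[float]], tokens: List[int], start_pos: int) -> float:
    """Mean -log p(tokens[p + 1]) over positions p = start_pos .. len(tokens) - 2."""
    losses = []
    for p in range(start_pos, len(tokens) - 1):
        losses.append(-log_softmax(logits[p])[tokens[p + 1]])
    return sum(losses) / len(losses)

# 4 positions, vocab of 3: uniform logits give loss log(3) at every scored position
uniform = [[0.0, 0.0, 0.0] for _ in range(4)]
print(masked_nll(uniform, [0, 1, 2, 1], start_pos=1))  # log(3), about 1.0986
```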

In [19]:
from tqdm.notebook import trange
from transformer_lens import EasyTransformer
from torch import optim
import os

def loss_fn(
    logits: torch.Tensor, # logits: [batch, position, d_vocab]
    tokens: torch.Tensor, # tokens: [batch, position]
    return_per_token: bool=False,
):
    logits = logits[:, start_pos:-1]
    tokens = tokens[:, start_pos+1:]
    
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    if return_per_token:
        return -correct_log_probs
    return -correct_log_probs.mean()

def train(
    model: EasyTransformer,
    num_epochs=10,
    lr=1e-3,
    max_grad_norm=1.0,
    print_every=100,
    save_dir="./model_weights",
    betas=(.9, .99),
    save_every=None,
    max_steps = None,
) -> List[float]:
    model.zero_grad()
    model.init_weights()
    
    optimizer = optim.AdamW(model.parameters(), lr=lr, betas=betas)
    losses = []
    
    for epoch in trange(1, num_epochs + 1):
        samples = 0
        for step, tokens in enumerate(train_data_loader):
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            losses.append(loss.item())
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            optimizer.zero_grad()

            samples += tokens.shape[0]

            if save_every is not None and step % save_every == 0 and save_dir is not None:
                os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), f"{save_dir}/model_{step}.pt")
                
            if print_every is not None and step % print_every == 0:
                print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")

            if max_steps is not None and step >= max_steps:
                break

    return losses

losses = train(model=model, num_epochs=20)


Epoch 1 Samples 4 Step 0 Loss 4.336881637573242
Epoch 2 Samples 4 Step 0 Loss 3.1034786701202393
Epoch 3 Samples 4 Step 0 Loss 1.980333685874939
Epoch 4 Samples 4 Step 0 Loss 0.4396762251853943
Epoch 5 Samples 4 Step 0 Loss 0.06941234320402145
Epoch 6 Samples 4 Step 0 Loss 0.028200622648000717
Epoch 7 Samples 4 Step 0 Loss 0.016682539135217667
Epoch 8 Samples 4 Step 0 Loss 0.012012602761387825
Epoch 9 Samples 4 Step 0 Loss 0.008823905140161514
Epoch 10 Samples 4 Step 0 Loss 0.007344107609242201
Epoch 11 Samples 4 Step 0 Loss 0.005826868582516909
Epoch 12 Samples 4 Step 0 Loss 0.004719889257103205
Epoch 13 Samples 4 Step 0 Loss 0.004131702706217766
Epoch 14 Samples 4 Step 0 Loss 0.0029435132164508104
Epoch 15 Samples 4 Step 0 Loss 0.0028161206282675266
Epoch 16 Samples 4 Step 0 Loss 0.0020409803837537766
Epoch 17 Samples 4 Step 0 Loss 0.001804783707484603
Epoch 18 Samples 4 Step 0 Loss 0.0016812746180221438
Epoch 19 Samples 4 Step 0 Loss 0.0013869601534679532
Epoch 20 Samples 4 Step 0 L

In [20]:
import plotly.express as px

px.line(losses, labels={"index": "steps", "value": "loss"}, title="Loss vs steps")

### Observation
This is promising!

The loss decreases rapidly within the first few epochs, which suggests the model has learned to predict permutation sequences.

We can validate this by measuring accuracy on test data.

### Testing model accuracy
1. The accuracy function mirrors the loss function above, except that it counts how many tokens in the predicted sequence match the expected tokens.
2. Accuracy is computed ignoring the first `2+n_comb` positions (the prompt).
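`acc_fn` can report two different numbers. By default it returns the fraction of individual predicted tokens that match. With `return_per_token=True` it returns, despite the name, the fraction of sequences in which every predicted token matches (via `.prod(dim=1)`). Here is a pure-Python sketch of both, with hypothetical helper names:

```python
from typing import List

def token_accuracy(preds: List[List[int]], targets: List[List[int]]) -> float:
    """Fraction of individual positions predicted correctly."""
    hits = [p == t for ps, ts in zip(preds, targets) for p, t in zip(ps, ts)]
    return sum(hits) / len(hits)

def sequence_accuracy(preds: List[List[int]], targets: List[List[int]]) -> float:
    """Fraction of sequences with every position correct (exact match)."""
    exact = [all(p == t for p, t in zip(ps, ts)) for ps, ts in zip(preds, targets)]
    return sum(exact) / len(exact)

preds   = [[1, 2, 3], [1, 2, 0]]
targets = [[1, 2, 3], [1, 2, 3]]
print(token_accuracy(preds, targets))     # 0.8333333333333334
print(sequence_accuracy(preds, targets))  # 0.5
```

The two agree only at `1.0`. For a partially correct model, the sequence-level number is much stricter.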

In [21]:
from random import choice

def acc_fn(
    logits: torch.Tensor, # logits: [batch, position, d_vocab]
    tokens: torch.Tensor, # tokens: [batch, position]
    return_per_token: bool=False,
):
    logits = logits[:, start_pos:-1]
    preds = logits.argmax(-1)
    tokens = tokens[:, start_pos+1:]
    
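    if return_per_token:
        # NOTE: despite the flag's name, this is sequence-level exact-match accuracy:
        # a sequence counts only if every predicted token is correct
        return (preds == tokens).prod(dim=1).float().mean().item()
    # token-level accuracy: fraction of individual predicted tokens that are correct
    return (preds == tokens).float().mean().item()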

def validate_model(model: EasyTransformer):
    with torch.no_grad():
        accuracies = []
        for _, input in enumerate(test_data_loader):
            logits = model(input)
            accuracy = acc_fn(logits, input)
            accuracies.append(accuracy)
    
    return sum(accuracies)/len(accuracies)

accuracy = validate_model(model)
print(f"Accuracy: {accuracy}")

Accuracy: 1.0


Again a promising result! Accuracy is `1.0`.

This strongly suggests that the model has learned to generate fixed-length permutation sequences perfectly, at least on the sampled test sequences.

## Circuit analysis and observations

Since the model has only one attention head in a single layer, all of the permutation logic must be encoded in that head (together with the embeddings and unembedding).

Let's see if we can analyze what that attention head has actually learned after training.

We can start by plotting a heat map of the attention pattern (the post-softmax attention weights) between every pair of positions $(i, j)$,

where $i$ is the destination (query) position and $j$ is the source (key) position.

> Moreover, by design there is no MLP layer, so we hypothesize that the model attends mostly based on relative positions rather than on token values.
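For reference, here is a sketch of how a single causal attention head turns query and key vectors into the kind of pattern plotted below. Row $i$ (destination/query) is a softmax of $q_i \cdot k_j / \sqrt{d_{head}}$ over the source/key positions $j \le i$. Rows are destinations and columns are sources, matching the `[batch, head, query_pos, key_pos]` layout of `cache["pattern", 0]`. This is standard scaled dot-product attention with a causal mask; it ignores LayerNorm and biases, and the function name and toy vectors are illustrative.

```python
import math
from typing import List

def causal_attention_pattern(q: List[List[float]], k: List[List[float]]) -> List[List[float]]:
    """pattern[i][j]: how much destination position i attends to source position j."""
    d_head = len(q[0])
    pattern = []
    for i, q_i in enumerate(q):
        # Only source positions j <= i are visible (causal mask)
        scores = [sum(a * b for a, b in zip(q_i, k[j])) / math.sqrt(d_head) for j in range(i + 1)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        # Masked (future) positions get exactly zero weight
        pattern.append([w / z for w in weights] + [0.0] * (len(k) - i - 1))
    return pattern

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
for row in causal_attention_pattern(q, k):
    print([round(x, 3) for x in row])
```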

In [23]:
def imshow(input, tensor, yaxis="", xaxis="", **kwargs):
    plot_kwargs = {
        "color_continuous_scale":"RdBu", 
        "color_continuous_midpoint":0.0, 
        "labels":{"x": xaxis, "y": yaxis}, 
        "width": 1024, 
        "height": 1024,
        "x": [f"{c},{i}" for i, c in enumerate(input)],
        "y": [f"{c},{i}" for i, c in enumerate(input)],
    }
    plot_kwargs.update(kwargs)
    fig = px.imshow(tensor, **plot_kwargs)
    fig.show()

In [24]:
from transformer_lens.utils import to_numpy

test_input = next(generate_dataset(n=n_comb))
test_input_tokens = tokenizer.str_to_tokens(test_input)
logits, cache = model.run_with_cache(torch.tensor(test_input_tokens))

attn_pattern = to_numpy(cache["pattern", 0])

limit = 500
imshow(test_input[:limit], attn_pattern[-1,-1][:limit,:limit], xaxis="src_token", yaxis="dst_token")

### Observations



In [25]:
import numpy as np

activation_threshold = 0.8
prompt = " ".join(test_input.split(" ")[:2])
print(prompt)

pattern = attn_pattern[-1, -1]  # [dst_pos, src_pos] for the only head
dst_token_idx, src_token_idx = np.nonzero(pattern > activation_threshold)

# Rank source positions by how many destination positions attend to them strongly
unique, counts = np.unique(src_token_idx, return_counts=True)
idx_copied = sorted(zip(counts, unique), reverse=True)

token_seen = set()
for _, idx in idx_copied:
    token = test_input[idx]
    if token in token_seen:
        continue

    max_act = np.max(pattern[:, idx])
    act_positions = np.nonzero(pattern[:, idx] > activation_threshold)[0]
    # Gaps between consecutive destination positions that attend strongly to idx
    pos_diff = np.diff(act_positions)

    token_seen.add(token)
    print(f"Token: '{token}'")
    print(f"Activation: {max_act}")
    print(pos_diff)

> psM1
Token: 's'
Activation: 0.943097710609436
[2 5 5 5 5 5 7 6 3 5 7 4 5 6 3 5 7 4]
Token: 'p'
Activation: 0.9964991807937622
[ 2  5  5  5  5  6 11 10  6 24  9 21]
Token: 'M'
Activation: 0.9302446842193604
[12  9  9  5  7 40 10  4]
Token: '>'
Activation: 1.0
[1 1]
Token: '1'
Activation: 0.9898272752761841
[6]
Token: ' '
Activation: 0.9827485680580139
[5]


In [26]:
from circuitsvis.tokens import colored_tokens_multi

colored_tokens_multi(tokens=list(test_input), values=cache["pattern", 0][-1,-1])


## Observations

## Questions
1. What happens if we feed duplicate characters? Can the model still predict correctly?
2. Why do the initial positions have higher scores?
3. Why are the gaps between strongly attending positions roughly equal?
4. Where does the remaining loss come from? Which sequence predictions are incorrect?