# Experiment

Can a 1-layer transformer model with only 1 attention head learn to predict correct permutation sequences?

1. Can we understand how the model develops an algorithm for permuting elements in a list?
2. What goes on in the QK circuit? (The OV circuit is not examined here, and there are no MLP layers.)

This experiment is an attempt to answer the above questions.

## Motivation

This experiment is devised to solve one of the concrete open problems proposed by Neel Nanda here:

https://www.alignmentforum.org/posts/LbrPTJ4fmABEdEnLf/200-concrete-open-problems-in-mechanistic-interpretability

### Assumptions and constraints:

Like any decent experiment, we first need to come up with assumptions and constraints that limit the scope of the problem so some progress can be tracked.

Following are some assumptions & constraints for this experiment:
1. A 1-layer attention-only transformer model is sufficient to predict permutation terms correctly.
2. ReLU activations are sufficient to begin with. (This turned out to be incorrect: SoLU did better.)
3. Permutations of full groups are used for this exercise, meaning every element of the group appears in each generated permutation.
4. The context window is set to a fixed length and a fixed permutation size. Analysis is done for a maximum permutation size of 5 elements.
5. The vocabulary is limited to 62 characters (52 letters + 10 digits). This allows $\binom{62}{5}$ different element sets, which pushes the model to generalize its attention pattern over positions rather than over the characters seen in previous sequence terms.
6. A custom tokenizer (similar to ASCII coding) is used for the provided dataset. This simplifies the experiment setup, so we can focus on the attention mechanism of the transformer, which is the meat of the problem.
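
To get a feel for the size of the sample space mentioned in item 5, the number of 5-element subsets can be computed directly. This is a small sketch using only the standard library; the variable names are illustrative and not part of the notebook.

```python
import math

vocab_size = 62   # 52 letters + 10 digits
term_size = 5     # maximum permutation size analysed here

# Number of distinct 5-element subsets of the vocabulary.
n_combinations = math.comb(vocab_size, term_size)
# Each subset expands into term_size! permutation terms inside one prompt.
terms_per_prompt = math.factorial(term_size)

print(n_combinations, terms_per_prompt)
```

With roughly 6.5 million possible element sets, a run that samples a few thousand prompts is unlikely to see the same characters in the same roles very often.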

In [1]:
%pip install ipykernel setuptools transformer_lens ipywidgets plotly nbformat circuitsvis

Note: you may need to restart the kernel to use updated packages.


### Defining Vocabulary and Synthetic data generation:

Example generation:

> \> equ78 equ87 eq7u8 eq78u eq8u7 eq87u euq78 euq87 eu7q8 eu78q eu8q7 eu87q e7qu8 e7q8u e7uq8 e7u8q e78qu e78uq e8qu7 e8q7u e8uq7 e8u7q e87qu e87uq qeu78 qeu87 qe7u8 qe78u qe8u7 qe87u que78 que87 qu7e8 qu78e qu8e7 qu87e q7eu8 q7e8u q7ue8 q7u8e q78eu q78ue q8eu7 q8e7u q8ue7 q8u7e q87eu q87ue ueq78 ueq87 ue7q8 ue78q ue8q7 ue87q uqe78 uqe87 uq7e8 uq78e uq8e7 uq87e u7eq8 u7e8q u7qe8 u7q8e u78eq u78qe u8eq7 u8e7q u8qe7 u8q7e u87eq u87qe 7equ8 7eq8u 7euq8 7eu8q 7e8qu 7e8uq 7qeu8 7qe8u 7que8 7qu8e 7q8eu 7q8ue 7ueq8 7ue8q 7uqe8 7uq8e 7u8eq 7u8qe 78equ 78euq 78qeu 78que 78ueq 78uqe 8equ7 8eq7u 8euq7 8eu7q 8e7qu 8e7uq 8qeu7 8qe7u 8que7 8qu7e 8q7eu 8q7ue 8ueq7 8ue7q 8uqe7 8uq7e 8u7eq 8u7qe 87equ 87euq 87qeu 87que 87ueq 87uqe.

$n_{ctx} = 722$

where $n_{ctx}$ is the context length.

### Custom Tokenizer

In [2]:
import string
from typing import List

class Tokenizer:
    def __init__(self) -> None:
        self._special_chars = "> ."
        self.vocab = self._special_chars + string.ascii_letters + string.digits

    def str_to_tokens(self, s: str) -> List[int]:
        return [self.vocab.index(ch) for ch in s]

    def tokens_to_str(self, tokens: List[int]) -> str:
        return "".join([self.vocab[token] for token in tokens])
    
    def special_chars(self) -> List[str]:
        return list(self._special_chars)
        
    def vocab_without_special_chars(self) -> str:
        return self.vocab[len(self._special_chars):]
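
As a quick illustration of the encoding, the round trip for a short prompt can be sketched without the class. The `encode` and `decode` helpers below are hypothetical stand-ins that assume the same vocabulary layout as the `Tokenizer` above (special characters first, then letters, then digits).

```python
import string

# Same vocabulary layout as the Tokenizer above: special characters first.
vocab = "> ." + string.ascii_letters + string.digits

def encode(s):
    """Map each character to its index in the vocabulary."""
    return [vocab.index(ch) for ch in s]

def decode(tokens):
    """Map vocabulary indices back to characters."""
    return "".join(vocab[t] for t in tokens)

prompt = "> ab ba."
tokens = encode(prompt)
print(tokens)
assert decode(tokens) == prompt
```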

### Synthetic data

1. Use the `itertools` library to generate combinations of elements from the vocabulary of 62 characters. See [itertools.combinations](https://docs.python.org/3/library/itertools.html#itertools.combinations)
2. Generate the permutation sequence from the chosen elements. See [itertools.permutations](https://docs.python.org/3/library/itertools.html#itertools.permutations)

Properties of data:

Sample prompt:

> \> abcd abdc acbd acdb adbc adcb bacd badc bcad bcda bdac bdca cabd cadb cbad cbda cdab cdba dabc dacb dbac dbca dcab dcba.


Total characters:

$$seq\_len(n) = \underbrace{(n!)}_{\text{sequence terms}} \times \underbrace{(n+1)}_{\text{term length + space}} + \underbrace{2}_{\text{special chars}}$$
$$seq\_len(n) = (n+1)! + 2$$
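
The closed form can be checked against prompts built directly with `itertools`. This is a sketch; `build_prompt` is a hypothetical helper that mirrors the prompt layout described above (`> `, space-separated terms, trailing period).

```python
import math
from itertools import permutations

def build_prompt(elems):
    """Build '> t1 t2 ... tk.' from all permutations of elems."""
    terms = ["".join(p) for p in permutations(elems)]
    return "> " + " ".join(terms) + "."

lengths = {}
for n in range(1, 6):
    prompt = build_prompt("abcde"[:n])
    lengths[n] = len(prompt)
    # n! terms of (n + 1) characters each, plus 2 special characters.
    assert len(prompt) == math.factorial(n + 1) + 2

print(lengths)
```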

In [3]:
import torch
from random import shuffle
from itertools import permutations, combinations
from typing import List
import math

n_comb = 5
tokenizer = Tokenizer()
[beg, space, period] = tokenizer.special_chars()

def sequence_len(n: int) -> int:
    return math.factorial(n+1)+2

def start_pos(seq_len: int) -> int: # reverse mapping for seq_len vs offset
    return {8: 4, 26: 5, 122: 6, 722: 7}[seq_len]

def permute(vocab: str) -> List[str]:
    return ["".join(item) for item in permutations(vocab, len(vocab))]

def data_from_permute(terms: List[str]) -> str:
    output = space.join(terms)
    
    return f"> {output}."

def generate_dataset(n, combs=None):
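    if combs is None:
        elems = tokenizer.vocab_without_special_chars()
        combs = list(combinations(elems, n))
        shuffle(combs)
    else:
        # Materialize iterators so the `while True` loop can cycle through them again.
        combs = list(combs)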
    
    while True:
        for comb in combs:
            terms = permute("".join(comb))
            yield data_from_permute(terms)

n_ctx = sequence_len(n_comb)
print(n_ctx)

722


### Test generator function and tokenizer
1. The data set generator should generate the expected data.
2. `str_to_tokens` and `tokens_to_str` should do their conversions correctly.
3. The model should work correctly on sample data. Tensor shapes should match the expected shape.
4. The untrained model should predict the permutation sequence incorrectly.

In [4]:
def should_generate_correct_dataset(n: int, expected_data: str):
    test_dataset_gen = generate_dataset(n, combs=combinations(tokenizer.vocab_without_special_chars(), n))
    actual_data = next(test_dataset_gen)
    assert expected_data == actual_data, f"{expected_data} != {actual_data}"

should_generate_correct_dataset(1, "> a.")
should_generate_correct_dataset(2, "> ab ba.")
should_generate_correct_dataset(3, "> abc acb bac bca cab cba.")
should_generate_correct_dataset(4, "> abcd abdc acbd acdb adbc adcb bacd badc bcad bcda bdac bdca cabd cadb cbad cbda cdab cdba dabc dacb dbac dbca dcab dcba.")

### Model parameters

1. n_layers = 1
2. d_model=128
3. d_head=64
4. n_heads=1
5. d_vocab=65
6. attn_only=True
7. n_ctx=len(prompt) (722 max)

In [5]:
from transformer_lens import EasyTransformerConfig, EasyTransformer

torch.mps.empty_cache()
torch.set_default_device("mps")

cfg = EasyTransformerConfig(
    n_layers=1,
    d_model=128,
    d_head=64,
    n_heads=1,
    d_mlp=None,
    d_vocab=len(tokenizer.vocab),
    n_ctx=n_ctx,
    act_fn="solu", # think about what activation function is best. SoLU does better than ReLU.
    attn_only=True,
)
model = EasyTransformer(cfg)
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): LayerNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (unembed): Unembed()
)


### Test model output shape and untrained prediction

In [6]:
def should_return_output_tensor_with_correct_shape(input: torch.Tensor, expected_shape: List[int]):
    with torch.no_grad():
        expected_shape = torch.Size(expected_shape)
        output = model(input)
        actual_shape = output.shape
    
    assert expected_shape == actual_shape, f"{expected_shape} != {actual_shape}"
    
def should_return_output_tensor_with_incorrect_prediction(input: str, expected_completion: str):
    with torch.no_grad():
        tokens = torch.tensor(tokenizer.str_to_tokens(input))
        output = model(tokens)
    
    actual_completion = tokenizer.tokens_to_str([output[:, -1, :].argmax().item()])
    assert expected_completion != actual_completion, f"{expected_completion} != {actual_completion}"

test_dataset_gen = generate_dataset(2)
test_input_seq = next(test_dataset_gen)
test_input = tokenizer.str_to_tokens(test_input_seq)
test_input = torch.tensor(test_input)

should_return_output_tensor_with_correct_shape(test_input, [1, len(test_input), model.cfg.d_vocab])
should_return_output_tensor_with_incorrect_prediction("> ab", "> ab ba.")

### Test model prediction on sample data set.

Test model outputs before it is trained.

In [7]:
def sample_data_test(expected: str):
    with torch.no_grad():
        offset = start_pos(len(expected))
        tokens = tokenizer.str_to_tokens(expected)
        prefix = tokens[:offset]
        tokens = torch.tensor(tokens)
        i = offset
        output = []
        while i < len(expected):
            logits = model(tokens[:i])
            prediction = logits[:, -1, :].argmax().item()
            output.append(prediction)
            i += 1
        print(f"expected:{expected}")
        print(f"actual:{tokenizer.tokens_to_str(prefix + output)}")

sample_data_test(next(generate_dataset(3)))

expected:> kw1 k1w wk1 w1k 1kw 1wk.
actual:> kw1UUUUUUUUUUUUUUUUUUUUU


Perhaps not surprisingly, the model outputs gibberish. It cannot do sequence prediction yet, since it has not been trained.

### Data set with generator functions
1. Generator functions are handy tools for generating data lazily. Here `__getitem__` pulls the next sample from a generator.
2. This becomes very handy when we deal with lots of data and can't load it all into memory.
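
The lazy behaviour can be seen in isolation with a small generator. This is a sketch; `lazy_prompts` is a hypothetical helper (not the notebook's `generate_dataset`) that builds each prompt only when it is requested.

```python
from itertools import combinations, islice, permutations

def lazy_prompts(elems, n):
    """Yield one prompt per n-element combination, built only on demand."""
    for comb in combinations(elems, n):
        terms = ["".join(p) for p in permutations(comb)]
        yield "> " + " ".join(terms) + "."

# Only the first two prompts are ever constructed here.
first_two = list(islice(lazy_prompts("abcdefgh", 3), 2))
print(first_two)
```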

In [8]:
from torch.utils.data import Dataset
from random import choice

class PermutationDataset(Dataset):
    def __init__(
        self, 
        n:int=n_comb, 
        max_len:int=n_ctx, 
        all_combs:bool=False,
    ) -> None:
        super().__init__()
        generators = [generate_dataset(n)]
        if all_combs:
            generators = [generate_dataset(i) for i in range(2, n+1)]
        
        self.generators = generators
        self.n = n
        self.max_len = max_len
        
    def __len__(self) -> int:
        return 1000
        
    def __getitem__(self, _) -> dict:
        while True:
            for entry in choice(self.generators):
                tokens = torch.tensor(tokenizer.str_to_tokens(entry), device="mps")
                return {"tokens": tokens, "length": len(entry)}

### Define data loader

1. $\text{test} : \text{train} = 20 : 80$
2. `collate_fn` is defined so we can train the model on variable sequence lengths (if needed).
3. For now, only a fixed-length dataset is generated. This makes the attention head pattern more comprehensible.

> To study the pattern for variable sequence lengths, just pass `all_combs=True` to the `PermutationDataset` constructor.
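
The padding and masking done by `collate_fn` can be sketched in plain Python. The helper `pad_and_mask` is hypothetical; it assumes the pad token is the space character (index 1) and that positions are marked from index `n + 1` onward, which matches `start_pos(length) - 1` for the lengths the notebook defines.

```python
def pad_and_mask(token_lists, n_elems, pad_token=1):
    """Right-pad each sequence and mark positions from the end of the first term on."""
    max_len = max(len(t) for t in token_lists)
    padded, masks = [], []
    for tokens, n in zip(token_lists, n_elems):
        first_marked = n + 1  # index of the last character of the first term
        padded.append(tokens + [pad_token] * (max_len - len(tokens)))
        masks.append([int(j >= first_marked) for j in range(max_len)])
    return padded, masks

short = [0, 1, 3, 2]              # "> a."
long = [0, 1, 3, 4, 1, 4, 3, 2]   # "> ab ba."
padded, masks = pad_and_mask([short, long], n_elems=[1, 2])
print(padded[0], masks[0])
print(padded[1], masks[1])
```

Because the comparison only looks at the index, padded positions are marked with 1 as well, as in the original expression.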

In [9]:
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

def collate_fn(data):
    pad_token = tokenizer.special_chars().index(space)
    batch_size = len(data)
    max_len = max([d["length"] for d in data])
    
    padded_data = torch.stack([F.pad(d["tokens"], (0, max_len-d["length"]), value=pad_token) for d in data])
    attention_mask = torch.tensor([start_pos(d["length"])-1 for d in data]).reshape(batch_size, 1)
    attention_mask = (torch.arange(max_len).repeat(batch_size).reshape(batch_size, max_len) >= attention_mask).int()
    
    return {"tokens": padded_data, "attention_mask": attention_mask}

# Switching off the all_combs=True flag for now, to make the analysis easier.
dataset = PermutationDataset(n=n_comb, max_len=n_ctx)
generator = torch.Generator(device="mps")
test, train = random_split(dataset=dataset, lengths=[.2, .8], generator=generator)

train_data_loader = DataLoader(
    dataset=train, 
    batch_size=4, 
    generator=generator,
    collate_fn=collate_fn,
)
test_data_loader = DataLoader(
    dataset=test, 
    batch_size=4, 
    generator=generator, 
    collate_fn=collate_fn,
)

### Test shape and content of Data loader

In [10]:
expected_batch_size = 4

for index, batch in enumerate(train_data_loader):
    for tokens, attn_mask in zip(batch["tokens"], batch["attention_mask"]):
        print(f"{tokenizer.tokens_to_str(tokens)}")
    
    assert expected_batch_size == batch["tokens"].shape[0], f"{expected_batch_size} != {batch['tokens'].shape[0]}"
    if index == 10: break

> ipszF ipsFz ipzsF ipzFs ipFsz ipFzs ispzF ispFz iszpF iszFp isFpz isFzp izpsF izpFs izspF izsFp izFps izFsp iFpsz iFpzs iFspz iFszp iFzps iFzsp piszF pisFz pizsF pizFs piFsz piFzs psizF psiFz psziF pszFi psFiz psFzi pzisF pziFs pzsiF pzsFi pzFis pzFsi pFisz pFizs pFsiz pFszi pFzis pFzsi sipzF sipFz sizpF sizFp siFpz siFzp spizF spiFz spziF spzFi spFiz spFzi szipF sziFp szpiF szpFi szFip szFpi sFipz sFizp sFpiz sFpzi sFzip sFzpi zipsF zipFs zispF zisFp ziFps ziFsp zpisF zpiFs zpsiF zpsFi zpFis zpFsi zsipF zsiFp zspiF zspFi zsFip zsFpi zFips zFisp zFpis zFpsi zFsip zFspi Fipsz Fipzs Fispz Fiszp Fizps Fizsp Fpisz Fpizs Fpsiz Fpszi Fpzis Fpzsi Fsipz Fsizp Fspiz Fspzi Fszip Fszpi Fzips Fzisp Fzpis Fzpsi Fzsip Fzspi.
> betE5 bet5E beEt5 beE5t be5tE be5Et bteE5 bte5E btEe5 btE5e bt5eE bt5Ee bEet5 bEe5t bEte5 bEt5e bE5et bE5te b5etE b5eEt b5teE b5tEe b5Eet b5Ete ebtE5 ebt5E ebEt5 ebE5t eb5tE eb5Et etbE5 etb5E etEb5 etE5b et5bE et5Eb eEbt5 eEb5t eEtb5 eEt5b eE5bt eE5tb e5btE e5bEt e5tbE e5tEb

## Train model
1. Use synthetically generated training data from the data loader to train the model.
2. Collect the loss at every step and plot loss vs steps to validate that the loss is decreasing.
3. The loss function computes the `log_softmax` of the logits for every token predicted in the sequence.
4. The attention mask is used to ignore tokens up to the first sequence term in the loss calculation, since the sequence cannot be predicted until the first term is supplied.
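
Items 3 and 4 can be illustrated with a plain-Python masked next-token loss. This is a generic sketch of the idea, with hypothetical function names; the exact masking convention inside `lm_cross_entropy_loss` may differ.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def masked_next_token_loss(logits, tokens, mask):
    """Mean negative log-likelihood of tokens[t + 1] given logits[t], over scored positions."""
    total, count = 0.0, 0
    for t in range(len(tokens) - 1):
        # Score a prediction only if both the current and the next position are unmasked.
        if mask[t] and mask[t + 1]:
            total -= log_softmax(logits[t])[tokens[t + 1]]
            count += 1
    return total / count

# Uniform logits over a 4-token vocabulary give a loss of log(4) per scored token.
logits = [[0.0] * 4 for _ in range(5)]
tokens = [0, 1, 2, 3, 1]
mask = [0, 0, 1, 1, 1]
loss = masked_next_token_loss(logits, tokens, mask)
print(loss)
```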

### Hyper parameters:
1. learning rate = 0.0001
2. number of epochs = 10 (default; the run below uses 15)
3. optimizer betas = (0.9, 0.99)

In [11]:
from tqdm.notebook import trange
from transformer_lens import EasyTransformer
from transformer_lens.utils import lm_cross_entropy_loss
from torch import optim
from transformer_lens.train import train
import os

def train(
    model: EasyTransformer,
    num_epochs=10,
    lr=1e-4,
    max_grad_norm=1.0,
    print_every=100,
    save_dir="./model_weights",
    betas=(.9, .99),
    save_every=None,
    max_steps = None,
) -> List[float]:
    model.zero_grad()
    model.init_weights()
    
    optimizer = optim.AdamW(model.parameters(), lr=lr, betas=betas)
    losses = []
    
    for epoch in trange(1, num_epochs + 1):
        samples = 0
        for step, batch in enumerate(train_data_loader):
            tokens = batch["tokens"]
            attention_mask = batch["attention_mask"]
            logits = model(input=tokens)
            
            loss = lm_cross_entropy_loss(logits, tokens, attention_mask)
            losses.append(loss.item())
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            optimizer.zero_grad()

            samples += tokens.shape[0]

            if save_every is not None and step % save_every == 0 and save_dir is not None:
                os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), f"{save_dir}/model_{step}.pt")
                
            if print_every is not None and step % print_every == 0:
                print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")

            if max_steps is not None and step >= max_steps:
                break

    return losses

losses = train(model=model, num_epochs=15)

  0%|          | 0/15 [00:00<?, ?it/s]

Epoch 1 Samples 4 Step 0 Loss 4.427218437194824
Epoch 1 Samples 404 Step 100 Loss 3.754289388656616
Epoch 2 Samples 4 Step 0 Loss 3.377424955368042
Epoch 2 Samples 404 Step 100 Loss 2.9508118629455566
Epoch 3 Samples 4 Step 0 Loss 2.275911569595337
Epoch 3 Samples 404 Step 100 Loss 1.5159636735916138
Epoch 4 Samples 4 Step 0 Loss 0.92534339427948
Epoch 4 Samples 404 Step 100 Loss 0.38351601362228394
Epoch 5 Samples 4 Step 0 Loss 0.13572341203689575
Epoch 5 Samples 404 Step 100 Loss 0.060301944613456726
Epoch 6 Samples 4 Step 0 Loss 0.03562820330262184
Epoch 6 Samples 404 Step 100 Loss 0.021080872043967247
Epoch 7 Samples 4 Step 0 Loss 0.014321811497211456
Epoch 7 Samples 404 Step 100 Loss 0.010190415196120739
Epoch 8 Samples 4 Step 0 Loss 0.006638733204454184
Epoch 8 Samples 404 Step 100 Loss 0.003868614323437214
Epoch 9 Samples 4 Step 0 Loss 0.0024447825271636248
Epoch 9 Samples 404 Step 100 Loss 0.0017125402810052037
Epoch 10 Samples 4 Step 0 Loss 0.001101108267903328
Epoch 10 Sample

### Plot the loss vs steps graph

In [12]:
import plotly.express as px

px.line(losses, title="Loss vs steps", labels={"index": "steps", "value": "loss", "variable": "loss"})

<img src="loss_vs_steps_graph.png">

### Observation
This is promising!

The loss clearly decreases rapidly within the first few epochs.

This suggests that the model has learned to do permutation sequence prediction.

We can validate this observation by testing on the test data.

### Testing model accuracy
1. The accuracy function is defined similarly to the loss function above, except that this time we count how many tokens in the predicted sequence match the expected sequence tokens.
2. Accuracy should be computed ignoring the first `2+n_comb` positions, so that the model is not penalised for tokens it cannot predict. Note that `lm_accuracy`, as called below, takes no mask and scores every position.
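
The effect of ignoring the prefix can be sketched in plain Python. The function `next_token_accuracy` and the prediction values are hypothetical; `skip` counts prediction positions, so `skip = n + 1` means that targets from index `2 + n` onward are scored.

```python
def next_token_accuracy(predictions, tokens, skip=0):
    """Fraction of positions t >= skip where predictions[t] equals tokens[t + 1]."""
    pairs = [
        (predictions[t], tokens[t + 1])
        for t in range(skip, len(tokens) - 1)
    ]
    return sum(p == y for p, y in pairs) / len(pairs)

tokens = [0, 1, 3, 4, 1, 4, 3, 2]        # "> ab ba." with the notebook's vocabulary
predictions = [9, 9, 9, 1, 4, 3, 2, 9]   # hypothetical argmax outputs, wrong on the prefix

acc_all = next_token_accuracy(predictions, tokens)             # every position scored
acc_masked = next_token_accuracy(predictions, tokens, skip=3)  # prefix ignored (n = 2)
print(acc_all, acc_masked)
```

The unmasked score is dragged down by prefix positions that no model could predict, which is the penalty item 2 refers to.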

In [13]:
from transformer_lens.utils import lm_accuracy

def validate_model(model: EasyTransformer):
    with torch.no_grad():
        accuracies = []
        for _, batch in enumerate(test_data_loader):
            logits = model(input=batch["tokens"])
            accuracy = lm_accuracy(logits, batch["tokens"])
            accuracies.append(accuracy)
    
    return sum(accuracies)/len(accuracies)

accuracy = validate_model(model)
print(f"Accuracy: {accuracy}")

Accuracy: 0.9916845560073853


Accuracy is very close to `1.0`!

This suggests that the model has indeed learned to generate a permutation sequence for the term length used in training (`n_comb = 5` here; lengths in the range [2, 5] when `all_combs=True`).

Let's check whether the output generated by the model is accurate enough.

In [14]:
def check(batch: torch.Tensor):
    attn_mask = batch["attention_mask"]
    tokens = batch["tokens"]
    logits = model(input=tokens)
    
    predictions = logits.argmax(-1)
    for i in range(len(predictions)):
        print(tokenizer.tokens_to_str(tokens[i]))
        zeros = attn_mask.shape[1] - torch.count_nonzero(attn_mask[i]).item()
        prefix = torch.cat((tokens[i,:zeros+1], predictions[i,zeros:]), dim=-1)
        print(tokenizer.tokens_to_str(prefix))
        
    print(f"Accuracy: {lm_accuracy(logits, tokens)}")
    print(f"Loss: {lm_cross_entropy_loss(logits, tokens, attn_mask)}")

batch = next(iter(train_data_loader))
check(batch)

> fgjPQ fgjQP fgPjQ fgPQj fgQjP fgQPj fjgPQ fjgQP fjPgQ fjPQg fjQgP fjQPg fPgjQ fPgQj fPjgQ fPjQg fPQgj fPQjg fQgjP fQgPj fQjgP fQjPg fQPgj fQPjg gfjPQ gfjQP gfPjQ gfPQj gfQjP gfQPj gjfPQ gjfQP gjPfQ gjPQf gjQfP gjQPf gPfjQ gPfQj gPjfQ gPjQf gPQfj gPQjf gQfjP gQfPj gQjfP gQjPf gQPfj gQPjf jfgPQ jfgQP jfPgQ jfPQg jfQgP jfQPg jgfPQ jgfQP jgPfQ jgPQf jgQfP jgQPf jPfgQ jPfQg jPgfQ jPgQf jPQfg jPQgf jQfgP jQfPg jQgfP jQgPf jQPfg jQPgf PfgjQ PfgQj PfjgQ PfjQg PfQgj PfQjg PgfjQ PgfQj PgjfQ PgjQf PgQfj PgQjf PjfgQ PjfQg PjgfQ PjgQf PjQfg PjQgf PQfgj PQfjg PQgfj PQgjf PQjfg PQjgf QfgjP QfgPj QfjgP QfjPg QfPgj QfPjg QgfjP QgfPj QgjfP QgjPf QgPfj QgPjf QjfgP QjfPg QjgfP QjgPf QjPfg QjPgf QPfgj QPfjg QPgfj QPgjf QPjfg QPjgf.
> fgjPQ fgjQP fgPjQ fgPQj fgQjP fgQPj fjgPQ fjgQP fjPgQ fjPQg fjQgP fjQPg fPgjQ fPgQj fPjgQ fPjQg fPQgj fPQjg fQgjP fQgPj fQjgP fQjPg fQPgj fQPjg gfjPQ gfjQP gfPjQ gfPQj gfQjP gfQPj gjfPQ gjfQP gjPfQ gjPQf gjQfP gjQPf gPfjQ gPfQj gPjfQ gPjQf gPQfj gPQjf gQfjP gQfPj gQjfP gQjPf

Pretty accurate! We can now try to understand what the model has learned and see whether we can get some insights.

## Observations

### Overview

Since we have only 1 attention head and 1 layer, all the permutation generation logic must be encoded in that single attention head.

Let's see if we can analyze what that attention head has actually learned after training.

We can start by plotting a heat map of attention scores between every pair of positions $(i,j)$,

where $i = \text{source}$ and $j = \text{destination}$.

> Note: This model has only 1 layer with only 1 attention head and no MLP layers. The analysis below looks at the attention pattern (QK circuit) only; the OV circuit is not examined.

From the analysis in this paper: [Detecting copy behaviour](https://transformer-circuits.pub/2021/framework/index.html#:~:text=DETECTING%20COPYING%20BEHAVIOR)

A single attention head is capable of doing copy behaviour.
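
As a reminder of what the heat map below plots, the attention pattern is a row-wise softmax over causally masked query-key scores. The following is a plain-Python sketch with hypothetical toy values, not the model's actual weights. Rows are destination (query) positions and columns are source (key) positions, the same layout as the `imshow` call later in the notebook.

```python
import math

def attention_pattern(q, k):
    """Causal attention: row i is a softmax over scores q[i]·k[j] for j <= i."""
    d_head = len(q[0])
    pattern = []
    for i, q_i in enumerate(q):
        scores = [
            sum(a * b for a, b in zip(q_i, k[j])) / math.sqrt(d_head)
            for j in range(i + 1)  # causal mask: no future positions
        ]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        # Pad with zeros for the masked (future) positions.
        pattern.append([e / z for e in exps] + [0.0] * (len(q) - i - 1))
    return pattern

# Toy queries and keys (hypothetical values, d_head = 2).
q = [[1.0, 0.0], [0.0, 1.0], [4.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
pattern = attention_pattern(q, k)
for row in pattern:
    print([round(p, 3) for p in row])
```

A row with nearly all of its weight on one column corresponds to a single dark cell in the heat map: that destination position attends almost exclusively to one source position.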

In [15]:
from transformer_lens.utils import to_numpy

def imshow(tensor, yaxis="", xaxis="", **kwargs):
    tensor = to_numpy(tensor)
    plot_kwargs = {
        "color_continuous_scale":"RdBu", 
        "color_continuous_midpoint":0.0, 
        "labels":{"x": xaxis, "y": yaxis}, 
        "width": 1024, 
        "height": 1024,
    }
    plot_kwargs.update(kwargs)
    fig = px.imshow(tensor, **plot_kwargs)
    fig.show()

In [16]:
test_input = next(generate_dataset(n=4))
test_input_tokens = tokenizer.str_to_tokens(test_input)
logits, cache = model.run_with_cache(torch.tensor(test_input_tokens))

for k in cache:
    print(f"{k} {cache[k].shape}")

attn_pattern = cache["pattern", 0]
print(attn_pattern.shape)

limit = sequence_len(n_comb)
opts = {
    "x": [f"{c},{i}" for i, c in enumerate(test_input[:limit])],
    "y": [f"{c},{i}" for i, c in enumerate(test_input[:limit])],
}

print(test_input)
imshow(attn_pattern[-1,-1][:limit,:limit], xaxis="src_token", yaxis="dst_token", **opts)

hook_embed torch.Size([1, 122, 128])
hook_pos_embed torch.Size([1, 122, 128])
blocks.0.hook_resid_pre torch.Size([1, 122, 128])
blocks.0.ln1.hook_scale torch.Size([1, 122, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 122, 128])
blocks.0.attn.hook_q torch.Size([1, 122, 1, 64])
blocks.0.attn.hook_k torch.Size([1, 122, 1, 64])
blocks.0.attn.hook_v torch.Size([1, 122, 1, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 1, 122, 122])
blocks.0.attn.hook_pattern torch.Size([1, 1, 122, 122])
blocks.0.attn.hook_z torch.Size([1, 122, 1, 64])
blocks.0.hook_attn_out torch.Size([1, 122, 128])
blocks.0.hook_resid_post torch.Size([1, 122, 128])
ln_final.hook_scale torch.Size([1, 122, 1])
ln_final.hook_normalized torch.Size([1, 122, 128])
torch.Size([1, 1, 122, 122])
> sG01 sG10 s0G1 s01G s1G0 s10G Gs01 Gs10 G0s1 G01s G1s0 G10s 0sG1 0s1G 0Gs1 0G1s 01sG 01Gs 1sG0 1s0G 1Gs0 1G0s 10sG 10Gs.


## Questions
1. What does the figure showing the attention pattern mean?
2. Why does the top-left corner of the attention pattern show higher scores (darker blue colors)?
3. Why do we have regular spacing between activations?
4. What happens if we feed duplicate characters? Will the model predict the sequence correctly?

### Analysis and conclusion

Given a permutation sequence starting with `> MNR5`:

<img src="./attn_heads_heat_map_1.png" height=800>

1. Zooming in, the pattern suggests that the permutation sequence is being generated through copy operations driven by the QK circuit. Note that the OV circuit is not analysed here.
    
    <img src="./analysis_copy_head.png" height=500>
    
    In this graph, we can see that attention copies the token from position 14 to position 19. In this case `N` is copied to the position where `5` is. 
    The attention head puts about 99% of its weight on copying the token from position 14 to position 19. This behaviour was learned during training.
    
    Following this, the next `dst_token`, at position 20, is `N` (as expected). It has 2 activations, both of which are `" "`, coming from 2 different positions, `16` and `1`. 
    Since both are correct choices, the model chooses `" "` from `src_position = 1` in this case. The process continues. 
2. Activations of `src_token = " "` occur at regular intervals of `5` in the vertical direction.
    
    This makes sense, as each term spans `5` characters (including the space), so activation/copying of `" "` happens at regular intervals.
3. A more interesting question is: what is the pattern of activations of non-space characters? 
    
    Other characters don't follow a strict spacing pattern of activation, as characters occasionally swap places in subsequent terms. 
    Due to this, the spacing interval is not constant; it grows and shrinks from term to term.
    I have not been able to work out the mathematics for this yet, but I suspect it is something that other researchers might have already figured out.

4. There are no significant activations/copying of `src_tokens` beyond position number `15`.

    One explanation could be that, since the characters repeat quite a lot in the sequence, it is sufficient for the model to pay attention to the first few character positions to predict the whole sequence. This is similar to the explanation of the 2 activations of space at position number `20`.
    We can test this assumption by shrinking the context length of the model below `122` characters; the model should then activate more positions in the lower diagonal.
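
The spacing claims in items 2 and 3 can be checked on the raw prompt, independently of the model. This sketch measures the distance between consecutive occurrences of a character; `build_prompt` and `gaps_to_previous` are hypothetical helpers, and the measurement is only a proxy for the attention offsets seen in the heat map.

```python
from itertools import permutations

def build_prompt(elems):
    """Build '> t1 t2 ... tk.' from all permutations of elems."""
    terms = ["".join(p) for p in permutations(elems)]
    return "> " + " ".join(terms) + "."

def gaps_to_previous(prompt, ch):
    """Distances between consecutive occurrences of ch in the prompt."""
    positions = [i for i, c in enumerate(prompt) if c == ch]
    return [b - a for a, b in zip(positions, positions[1:])]

prompt = build_prompt("MNR5")
space_gaps = set(gaps_to_previous(prompt, " "))
letter_gaps = set(gaps_to_previous(prompt, "N"))
print(space_gaps)            # a single, constant gap for spaces
print(sorted(letter_gaps))   # several different gaps for a letter
```

For a term length of `n` characters, a letter's gap is `n + 1` plus the change in its offset within the term, so for `n = 4` it stays between 2 and 8.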


# Conclusion

The attention pattern heat map seems to point towards the core of how this model predicts the next token of the sequence.
More rigor is required to prove that the QK circuit is doing what has been hypothesized, but the evidence so far seems fairly convincing.

Working on this problem was a great learning opportunity and I hope to contribute to the field of mechanistic interpretability in the future.
Thanks again Neel! Your contributions with the `transformer_lens` library have been amazing. I hope we get to collaborate in the future.

### Visualisation/Demo of attention on token

`circuitsvis` is a neat library that helps with such visualizations.

In [17]:
from circuitsvis.tokens import colored_tokens_multi

print(cache["pattern", 0][-1,-1].shape)
colored_tokens_multi(tokens=list(test_input), values=cache["pattern", 0][-1,-1])

torch.Size([122, 122])


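### Follow-up: can the model do sequence generation for duplicate tokens?

The hypothesis is that, even though the data set did not contain repeating characters, the model can still output the sequence for repeating characters, which would show that it has learned sequence generation from positions rather than from the values of the tokens themselves.
Note that in the run shown below, the predicted sequence diverges from the expected one after the first term, so this output does not yet confirm the hypothesis.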
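
The duplicate-character prompt used in the next cell can be reproduced with `itertools.permutations`, which treats elements as distinct by position. A small sketch (variable names are illustrative):

```python
from itertools import permutations

# itertools treats elements as distinct by position, so a repeated
# character produces repeated terms.
terms = ["".join(p) for p in permutations("ghg4")]
dup_prompt = "> " + " ".join(terms) + "."

print(len(terms), len(set(terms)))
print(dup_prompt)
```

Only half of the 24 terms are distinct, so the same term appears at two different places in the prompt.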

In [18]:
# character g is duplicate
dup_data = "> ghg4 gh4g ggh4 gg4h g4hg g4gh hgg4 hg4g hgg4 hg4g h4gg h4gg ggh4 gg4h ghg4 gh4g g4gh g4hg 4ghg 4ggh 4hgg 4hgg 4ggh 4ghg."
sample_data_test(dup_data)

expected:> ghg4 gh4g ggh4 gg4h g4hg g4gh hgg4 hg4g hgg4 hg4g h4gg h4gg ggh4 gg4h ghg4 gh4g g4gh g4hg 4ghg 4ggh 4hgg 4hgg 4ggh 4ghg.
actual:> ghg4hgghg 4 ghgg  ghg g g4 gg gh ggggghg4 ggg4g gggg  ggg g gg ggggg gg gggg  g4g g gggg4 ggg4gggg4gg gg4gg g4ggg g4ggg 
