# Problem statement
Understand how small transformer models learn an algorithm for permuting the elements of a list.
What happens in the QK and OV circuits? Are there other ways to explore what goes on under the hood?

The problem is borrowed from the post below, which also explains the motivation for solving it:
https://www.alignmentforum.org/posts/LbrPTJ4fmABEdEnLf/200-concrete-open-problems-in-mechanistic-interpretability

### Assumptions and constraints:
1. A small transformer model (1 to 2 layers) will be sufficient to produce correct permutations. This assumption is to be tested.
2. ReLU activations are sufficient to begin with. This assumption is to be tested.
3. Permutations of single groups shall be used for this exercise. **Explain what is meant by single groups here.**
4. The context window will be set to 1024 tokens, which fits at most 5 elements: the full sequence of all 120 permutations of 5 elements is 722 tokens long, while 6 elements would need 5,042 tokens.
5. Since the input size is limited (at most 5 elements), we can limit the vocabulary to 62 characters (52 letters + 10 digits). That gives $\binom{62}{5} = 6{,}471{,}002$ different sequences, which should push the model to learn position-based attention patterns that generalize, rather than memorizing the specific characters seen in earlier sequences.
6. A very simple custom tokenizer works fine for this dataset. This assumption is to be tested.
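Assumptions 4 and 5 rest on simple counting. The standard-library sketch below (the helper names `sequence_length` and `num_combinations` are hypothetical, not from the notebook) computes the length of a full `"> t1 t2 ... tk."` sequence for n elements and the number of distinct element sets. For n=4 it gives 122 tokens, matching the `n_ctx=122` used in the model config.

```python
import math

VOCAB_SIZE = 62  # 52 ASCII letters + 10 digits; the special tokens "> ." are excluded

def sequence_length(n: int) -> int:
    """Length of "> t1 t2 ... tk." where the k = n! terms each have n characters."""
    k = math.factorial(n)
    # "> " prefix + k terms of n characters + (k - 1) separating spaces + trailing "."
    return 2 + k * n + (k - 1) + 1

def num_combinations(n: int, vocab_size: int = VOCAB_SIZE) -> int:
    """Number of distinct n-element sets, i.e. distinct training sequences."""
    return math.comb(vocab_size, n)

for n in range(1, 7):
    print(f"n={n}: {sequence_length(n)} tokens, {num_combinations(n)} sequences")
```

With n=5 a sequence needs 722 tokens and with n=6 it needs 5,042, so a 1024-token context fits at most 5 elements.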

In [None]:
%pip install ipykernel setuptools transformer_lens ipywidgets plotly nbformat circuitsvis

## Create model

In [None]:
len("> abcd abdc acbd acdb adbc adcb bacd badc bcad bcda bdac bdca cabd cadb cbad cbda cdab cdba dabc dacb dbac dbca dcab dcba.")

In [1]:
from typing import List
import torch
from transformer_lens import EasyTransformerConfig, EasyTransformer
import string

special_chars = "> ."
[beg, space, period] = list(special_chars)
elements = special_chars + string.ascii_letters + string.digits

torch.mps.empty_cache()
torch.set_default_device("mps")

cfg = EasyTransformerConfig(
    n_layers=1,
    d_model=128,
    d_head=64,
    n_heads=1,
    # d_mlp=64,
    d_mlp=None,
    d_vocab=len(elements),
    n_ctx=122, #> abcd abdc acbd acdb adbc adcb bacd badc bcad bcda bdac bdca cabd cadb cbad cbda cdab cdba dabc dacb dbac dbca dcab dcba.
    act_fn="solu",  # unused while attn_only=True (the blocks have no MLP layers)
    # act_fn="relu",
    attn_only=True,
)
model = EasyTransformer(cfg)

def str_to_tokens(s: str) -> List[int]:
    return [elements.index(ch) for ch in s]

def tokens_to_str(tokens: List[int]) -> str:
    return "".join([elements[token] for token in tokens])

print(model)
print(model.cfg.d_vocab)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (ln_final): LayerNorm(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (unembed): Unembed()
)
65


## Create synthetic data generator
1. Use the `itertools` library to first pick a combination of elements from the vocabulary of 62 characters. See [itertools](https://docs.python.org/3/library/itertools.html#itertools.combinations)
2. Generate the permutation sequence from the given elements. See [itertools](https://docs.python.org/3/library/itertools.html#itertools.permutations)

In [2]:
from itertools import permutations, combinations
from torch.utils.data import Dataset
from random import shuffle

def permute(chars: str) -> List[str]:
    # All len(chars)! orderings, in lexicographic order of the input positions
    return ["".join(item) for item in permutations(chars)]

def data_from_permute(terms: List[str]) -> str:
    output = space.join(terms)
    
    return f"> {output}."

def generate_dataset(n, combs=None):
    # Materialize iterators (e.g. itertools.combinations) so they can be cycled over repeatedly
    combs = list(combs) if combs is not None else []
    if not combs:
        # excluding special tokens in the beginning
        elems = elements[len(special_chars):]
        combs = list(combinations(elems, n))
        shuffle(combs)
    while True:
        for comb in combs:
            terms = permute("".join(comb))
            yield data_from_permute(terms)

class PermutationDataset(Dataset):
    # max_len is the number of samples per epoch (returned by __len__), not a sequence length
    def __init__(self, n=5, max_len=1024) -> None:
        super().__init__()
        self.generator = generate_dataset(n)
        self.n = n
        self.max_len = max_len
        
    def __len__(self) -> int:
        return self.max_len
        
    def __getitem__(self, _) -> dict:
        # The index is ignored: each call takes the next sequence from the infinite generator
        entry = next(self.generator)
        tokens = torch.tensor(str_to_tokens(entry), device="mps")
        return {"tokens": tokens}

print(next(generate_dataset(3)))
print(next(generate_dataset(4)))

> drB dBr rdB rBd Bdr Brd.
> JV56 JV65 J5V6 J56V J6V5 J65V VJ56 VJ65 V5J6 V56J V6J5 V65J 5JV6 5J6V 5VJ6 5V6J 56JV 56VJ 6JV5 6J5V 6VJ5 6V5J 65JV 65VJ.
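The output above shows that `itertools.permutations` emits terms in lexicographic order of the *input* positions: the first term fixes the ranking of the characters, and every later term is the lexicographic successor of the one before it. One candidate algorithm a model could implement is therefore the classic next-permutation step. The sketch below is a hypothesis to compare circuits against, not a claim about what the trained model does; `next_permutation` is a hypothetical helper.

```python
from itertools import permutations
from typing import List, Optional

def next_permutation(term: str, order: str) -> Optional[str]:
    """Lexicographic successor of `term`, ranking characters by their index in `order`.

    Returns None if `term` is already the last permutation.
    """
    idx: List[int] = [order.index(ch) for ch in term]
    # 1. Rightmost i whose rank is smaller than its right neighbour's
    i = len(idx) - 2
    while i >= 0 and idx[i] > idx[i + 1]:
        i -= 1
    if i < 0:
        return None
    # 2. Rightmost j > i with a larger rank than idx[i]; swap them
    j = len(idx) - 1
    while idx[j] < idx[i]:
        j -= 1
    idx[i], idx[j] = idx[j], idx[i]
    # 3. The suffix after i is descending; reverse it so it becomes ascending
    idx[i + 1:] = reversed(idx[i + 1:])
    return "".join(order[k] for k in idx)

order = "JV56"  # the first term fixes the ranking of the characters
terms = ["".join(p) for p in permutations(order)]
generated = [order]
while (nxt := next_permutation(generated[-1], order)) is not None:
    generated.append(nxt)
print(generated == terms)
```

Each step only needs the previous term and the character ranking, which suggests looking for heads that attend to the previous term and to the first term.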


## Test
1. The dataset generator should generate the expected data.
2. `str_to_tokens` and `tokens_to_str` should convert correctly in both directions.
3. The model should run on sample data, and the output tensor should have the expected shape.
4. The untrained model should not predict the permutation sequence correctly.

In [3]:
def should_generate_correct_dataset(n: int, expected_data: str):
    test_dataset_gen = generate_dataset(n, combs=combinations(elements[len(special_chars):], n))
    actual_data = next(test_dataset_gen)
    assert expected_data == actual_data, f"{expected_data} != {actual_data}"

should_generate_correct_dataset(1, "> a.")
should_generate_correct_dataset(2, "> ab ba.")
should_generate_correct_dataset(3, "> abc acb bac bca cab cba.")
should_generate_correct_dataset(4, "> abcd abdc acbd acdb adbc adcb bacd badc bcad bcda bdac bdca cabd cadb cbad cbda cdab cdba dabc dacb dbac dbca dcab dcba.")

In [4]:
def should_return_output_tensor_with_correct_shape(tokens: torch.Tensor, expected_shape: List[int]):
    expected_shape = torch.Size(expected_shape)
    output = model(tokens)
    actual_shape = output.shape
    
    assert expected_shape == actual_shape, f"{expected_shape} != {actual_shape}"
    
def should_return_output_tensor_with_incorrect_prediction(prompt: str, expected_completion: str):
    tokens = torch.tensor(str_to_tokens(prompt))
    output = model(tokens)
    
    # Greedy next-token prediction at the last position
    actual_completion = tokens_to_str([output[:, -1, :].argmax().item()])
    assert expected_completion != actual_completion, f"untrained model unexpectedly predicted {actual_completion!r}"

test_dataset_gen = generate_dataset(2)
test_input_seq = next(test_dataset_gen)
test_tokens = str_to_tokens(test_input_seq)
test_input = torch.tensor(test_tokens)

# A 1-D token tensor is treated as a batch of one, hence the leading 1
should_return_output_tensor_with_correct_shape(test_input, [1, len(test_tokens), model.cfg.d_vocab])
should_return_output_tensor_with_incorrect_prediction("> ab ab.", "b")
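The tokenizer round trip listed in the test plan has no cell of its own. A minimal sketch of such a test follows; it redefines the vocabulary and the two helpers exactly as in the model cell so it runs on its own, and `should_round_trip` is a hypothetical name.

```python
import string
from typing import List

# Same vocabulary as the notebook: special tokens first, then letters and digits
special_chars = "> ."
elements = special_chars + string.ascii_letters + string.digits

def str_to_tokens(s: str) -> List[int]:
    return [elements.index(ch) for ch in s]

def tokens_to_str(tokens: List[int]) -> str:
    return "".join(elements[token] for token in tokens)

def should_round_trip(s: str) -> None:
    tokens = str_to_tokens(s)
    assert all(0 <= t < len(elements) for t in tokens), tokens
    assert tokens_to_str(tokens) == s, f"{tokens_to_str(tokens)} != {s}"

should_round_trip("> ab ba.")
should_round_trip("> JV56 JV65 J5V6.")
print(str_to_tokens("> ."), len(elements))
```

The special tokens map to ids 0, 1 and 2, and the vocabulary size matches the `65` printed for `model.cfg.d_vocab`.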

### Sample run for a dataset

In [8]:
def sample_data_test(sample: str):
    tokens = str_to_tokens(sample)
    output = ""
    # Greedy prediction for every prefix tokens[:i]; prediction i should match tokens[i]
    for i in range(1, len(tokens)):
        logits = model(torch.tensor(tokens[:i]))
        prediction = logits[:, -1, :].argmax().item()
        output += tokens_to_str([prediction])
    print(f"input:{sample}")
    print(f"prediction:{output}")

sample_data_test(next(generate_dataset(4)))

input:> aegY aeYg ageY agYe aYeg aYge eagY eaYg egaY egYa eYag eYga gaeY gaYe geaY geYa gYae gYea Yaeg Yage Yeag Yega Ygae Ygea.
prediction:xmUD..xDDbCxlFbJ7lcxOWlxVVxq VOQclcxQO CVQO xCxC xlVVxxFxOxcexO exb KCeQ b CxcexeccxleVFlelxxeeleCcFexCeqLxceOcxxVlOcxOel


## Train model
1. Use synthetic data to train the model.
2. Observe and analyse the loss. Generate graphs.

In [12]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

def loss_fn(logits, tokens, return_per_token=False):
    # logits: [batch, position, d_vocab]
    # tokens: [batch, position]
    logits = logits[:, :-1]
    tokens = tokens[:, 1:]
    
    log_probs = logits.log_softmax(-1)
    correct_log_probs = log_probs.gather(-1, tokens[..., None])[..., 0]
    if return_per_token:
        return -correct_log_probs
    return -correct_log_probs.mean()

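def collate_fn(data):
    # Right-pad to the longest sequence in the batch. Token 0 is ">" and loss_fn does not mask
    # padding; this is harmless here because all sequences for a fixed n have the same length.
    max_size = max(d["tokens"].size(0) for d in data)
    padded_data = torch.stack([F.pad(d["tokens"], (0, max_size - d["tokens"].size(0)), value=0) for d in data])
    
    return padded_data

# n=4 gives 122-token sequences; max_len is the number of samples per epoch
dataset = PermutationDataset(n=4, max_len=122)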
dataloader = DataLoader(
    dataset=dataset, 
    batch_size=4, 
    shuffle=True, 
    generator=torch.Generator(device="mps"), 
    collate_fn=collate_fn,
)
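`loss_fn` shifts logits and tokens by one, so the logits at position i are scored against token i+1. A dependency-free sketch of the same computation for one sequence (plain lists instead of tensors; `next_token_nll` is a hypothetical name) also gives a sanity check: uniform logits over the 65-token vocabulary score $\ln 65 \approx 4.17$ nats, in the same range as the first logged training loss of about 4.37.

```python
import math
from typing import List

def next_token_nll(logits: List[List[float]], tokens: List[int]) -> float:
    """Mean negative log-likelihood of tokens[1:] given logits[:-1] (one sequence)."""
    total = 0.0
    for pos in range(len(tokens) - 1):
        row = logits[pos]  # logits at pos predict the token at pos + 1
        # log-softmax normaliser via the log-sum-exp trick for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total -= row[tokens[pos + 1]] - log_z
    return total / (len(tokens) - 1)

# Uniform logits over a 65-token vocabulary: every target has probability 1/65
uniform = [[0.0] * 65 for _ in range(5)]
print(next_token_nll(uniform, [0, 1, 3, 4, 1]), math.log(65))
```

The last row of logits is never used, which is what `logits[:, :-1]` does in `loss_fn`.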

### Test Data loader shape and content

In [13]:
from tqdm.notebook import trange
from transformer_lens import EasyTransformer
from torch import optim
import os

def train(
    model: EasyTransformer,
    num_epochs=100,
    lr=1e-3,
    max_grad_norm=1.0,
    print_every=100,
    save_dir="./model_weights",
    betas=(.9, .99),
    save_every=None,
    max_steps=None,
) -> List[float]:
    model.zero_grad()
    model.init_weights()
    
    optimizer = optim.AdamW(model.parameters(), lr=lr, betas=betas)
    # Stepped once per batch below, so patience counts batches, not epochs
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=100)
    losses = []
    
    for epoch in trange(1, num_epochs + 1):
        samples = 0
        # step restarts at 0 every epoch; with 122 samples and batch_size=4 an epoch has 31 steps,
        # so print_every=100 only ever prints step 0
        for step, tokens in enumerate(dataloader):
            logits = model(tokens)
            loss = loss_fn(logits, tokens)
            losses.append(loss.item())
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step(loss.item())

            samples += tokens.shape[0]

            if save_every is not None and step % save_every == 0 and save_dir is not None:
                os.makedirs(save_dir, exist_ok=True)
                # Include the epoch so checkpoints from different epochs don't overwrite each other
                torch.save(model.state_dict(), f"{save_dir}/model_{epoch}_{step}.pt")
                
            if print_every is not None and step % print_every == 0:
                print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")

            # Caps the number of steps per epoch, since step resets each epoch
            if max_steps is not None and step >= max_steps:
                break

    return losses

losses = train(model=model)


Epoch 1 Samples 4 Step 0 Loss 4.369736671447754
Epoch 2 Samples 4 Step 0 Loss 3.0941827297210693
Epoch 3 Samples 4 Step 0 Loss 1.057798147201538
Epoch 4 Samples 4 Step 0 Loss 0.2768070101737976
Epoch 5 Samples 4 Step 0 Loss 0.16858747601509094
Epoch 6 Samples 4 Step 0 Loss 0.15389135479927063
Epoch 7 Samples 4 Step 0 Loss 0.14897415041923523
Epoch 8 Samples 4 Step 0 Loss 0.14437109231948853
Epoch 9 Samples 4 Step 0 Loss 0.14000502228736877
Epoch 10 Samples 4 Step 0 Loss 0.1377730518579483
Epoch 11 Samples 4 Step 0 Loss 0.13165484368801117
Epoch 12 Samples 4 Step 0 Loss 0.13296104967594147
Epoch 13 Samples 4 Step 0 Loss 0.1330646276473999
Epoch 14 Samples 4 Step 0 Loss 0.12676680088043213
Epoch 15 Samples 4 Step 0 Loss 0.14143575727939606
Epoch 16 Samples 4 Step 0 Loss 0.1309826672077179
Epoch 17 Samples 4 Step 0 Loss 0.12433738261461258
Epoch 18 Samples 4 Step 0 Loss 0.12415116280317307
Epoch 19 Samples 4 Step 0 Loss 0.13273538649082184
Epoch 20 Samples 4 Step 0 Loss 0.1330892294645309

In [14]:
import plotly.express as px

px.line(losses, labels={"index": "steps", "value": "loss"}, title="Loss vs steps")

> Why does the training loss show spikes?

> Why does the loss not converge to 0?
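A partial, hedged answer to the second question: the first term of every sequence is a randomly drawn 4-element set, which no model can predict from `>` alone, while everything after it is deterministic. Assuming the sets are uniformly distributed over the $\binom{62}{4} = 557{,}845$ possibilities (the generator shuffles them and cycles through the list), a sequence carries $\ln\binom{62}{4} \approx 13.23$ nats, spread over the 121 predicted positions, so the expected mean loss cannot drop below roughly 0.109 nats per token. The logged loss plateaus around 0.12 to 0.13, just above that floor. The helper names below are hypothetical, and the sketch says nothing about the spikes.

```python
import math

def sequence_length(n: int) -> int:
    """Tokens in "> t1 ... tk." with k = n! terms of n characters each."""
    k = math.factorial(n)
    return 2 + k * n + (k - 1) + 1

def loss_floor(n: int, vocab_size: int = 62) -> float:
    """Lower bound (nats/token) on the expected mean next-token loss, assuming the
    n-element set is drawn uniformly and everything after it is deterministic."""
    sequence_entropy = math.log(math.comb(vocab_size, n))
    predicted_positions = sequence_length(n) - 1  # loss_fn never predicts the first token
    return sequence_entropy / predicted_positions

print(f"{loss_floor(4):.3f}")
```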

## Circuit analysis and observations

In [15]:
from circuitsvis.attention import attention_heads

def run_with_cache(prompt):
    cache = None
    input_tokens = str_to_tokens(prompt)
    while len(input_tokens) < cfg.n_ctx:
        output, cache = model.run_with_cache(torch.tensor(input_tokens))
        predicted = output[:, -1, :].argmax().item()
        input_tokens.append(predicted)
    
    return tokens_to_str(input_tokens), cache

def analyze_attention_heads():
    expected = next(generate_dataset(4))
    prompt = " ".join(expected.split(" ")[:2])
    
    print(f"prompt:{prompt}")
    completion, cache = run_with_cache(prompt)
    
    print(f"expected:{expected}")
    print(f"completion:{completion}")
    
    return expected, completion, cache

expected, completion, cache = analyze_attention_heads()
# The cache comes from the last forward pass, which saw len(completion) - 1 tokens
attention_heads(tokens=list(completion[:-1]), attention=cache["pattern", 0][0])

prompt:> oHOP
expected:> oHOP oHPO oOHP oOPH oPHO oPOH HoOP HoPO HOoP HOPo HPoO HPOo OoHP OoPH OHoP OHPo OPoH OPHo PoHO PoOH PHoO PHOo POoH POHo.
completion:> oHOP oHPO oOHP oOPH oPHO oPOH HoOP HoPO HOoP HOPo HPoO HPOo OoHP OoPH OHoP OHPo OPoH OPHo PoHO PoOH PHoO PHOo POoH POHo.
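To read the attention plot against a concrete hypothesis, it helps to know, for each query position, where the character it must predict sits in the previous term. The hypothetical helper `copy_sources` below (standard library only, assuming the `"> t1 t2 ... tk."` layout with n-character terms) lists (query position, source position) pairs; a head that copies from the previous term would put high attention on those pairs. This is a hypothesis to test, not an observed result.

```python
from typing import List, Tuple

def copy_sources(seq: str, n: int) -> List[Tuple[int, int]]:
    """(query, source) pairs: the token at query + 1 is a character of term t >= 1,
    and source is where that same character sits in term t - 1."""
    stride = n + 1  # n characters plus a separating space (or the final ".")
    body = seq[2:-1].split(" ")  # drop the "> " prefix and the trailing "."
    pairs = []
    for t in range(1, len(body)):
        prev_start = 2 + (t - 1) * stride
        start = 2 + t * stride
        for k, ch in enumerate(body[t]):
            target = start + k
            source = prev_start + body[t - 1].index(ch)
            pairs.append((target - 1, source))  # the prediction is made one position earlier
    return pairs

seq = "> abc acb bac bca cab cba."
for query, source in copy_sources(seq, 3)[:3]:
    print(query, repr(seq[query]), "->", source, repr(seq[source]))
```

Comparing these pairs with the rows of `cache["pattern", 0]` for a completion like the one above would show whether the single head behaves like a previous-term copier or uses some other rule.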


In [16]:
from circuitsvis.tokens import colored_tokens_multi

print(completion)
colored_tokens_multi(tokens=list(completion[:-1]), values=cache["pattern", 0][-1, -1])

> oHOP oHPO oOHP oOPH oPHO oPOH HoOP HoPO HOoP HOPo HPoO HPOo OoHP OoPH OHoP OHPo OPoH OPHo PoHO PoOH PHoO PHOo POoH POHo.
