<a href="https://colab.research.google.com/github/nmonson1/walkthroughs/blob/main/August_2023_Monthly_Algorithmic_Challenge_First_Unique_Character.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Monthly Algorithmic Challenge (August 2023): First Unique Character

This post is the second in the sequence of monthly mechanistic interpretability challenges. They are designed in the spirit of [Stephen Casper's challenges](https://www.lesswrong.com/posts/KSHqLzQscwJnv44T8/eis-vii-a-challenge-for-mechanists), but with the more specific aim of working well in the context of the rest of the ARENA material, and helping people put into practice all the things they've learned so far.

If you prefer, you can access the Streamlit page [here](https://arena-ch1-transformers.streamlit.app/Monthly_Algorithmic_Problems).

<img src="https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/writer.png" width="350">

## Setup

In [None]:
try:
    import google.colab # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter1_transformers"):
        !curl -o /content/main.zip https://codeload.github.com/callummcdougall/ARENA_2.0/zip/refs/heads/main
        !unzip /content/main.zip 'ARENA_2.0-main/chapter1_transformers/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter1_transformers/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter1_transformers", "chapter1_transformers")
        os.rmdir("ARENA_2.0-main")
        os.chdir("chapter1_transformers/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")


In [None]:
import torch as t
from pathlib import Path

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "monthly_algorithmic_problems" / "august23_unique_char"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from monthly_algorithmic_problems.august23_unique_char.dataset import UniqueCharDataset
from monthly_algorithmic_problems.august23_unique_char.model import create_model
from plotly_utils import hist, bar, imshow

device = t.device("cuda" if t.cuda.is_available() else "cpu")



## Prerequisites

The following ARENA material should be considered essential:

* **[1.1] Transformer from scratch** (sections 1-3)
* **[1.2] Intro to Mech Interp** (sections 1-3)

The following material isn't essential, but is very strongly recommended:

* **[1.2] Intro to Mech Interp** (section 4)
* **[1.4] Balanced Bracket Classifier** (all sections)


## Difficulty

This problem is a step up in difficulty from the July problem. The algorithmic problem is of a similar flavour, and the model architecture is very similar (the main difference is that this model has 3 attention heads per layer, instead of 2).


## Motivation

Neel Nanda's post [200 COP in MI: Interpreting Algorithmic Problems](https://www.lesswrong.com/posts/ejtFsvyhRkMofKAFy/200-cop-in-mi-interpreting-algorithmic-problems) does a good job explaining the motivation behind solving algorithmic problems such as these. I'd strongly recommend reading the whole post, because it also gives some high-level advice for approaching such problems.

The main purpose of these challenges isn't to break new ground in mech interp. Rather, they're designed to help you practice using and develop a better understanding of standard MI tools (e.g. interpreting attention, direct logit attribution), and more generally to get comfortable working with libraries like TransformerLens.

Also, they're hopefully pretty fun, because why shouldn't we have some fun while we're learning?

## Logistics

The solution to this problem will be published on this page in the first few days of September, at the same time as the next problem in the sequence. There will also be an associated LessWrong post.

If you try to interpret this model, you can send your attempt in any of the following formats:

* Colab notebook,
* GitHub repo (e.g. with ipynb or markdown file explaining results),
* Google Doc (with screenshots and explanations),
* or any other sensible format.

You can send your attempt to me (Callum McDougall) via any of the following methods:

* The [Slack group](https://join.slack.com/t/arena-la82367/shared_invite/zt-1uvoagohe-JUv9xB7Vr143pdx1UBPrzQ), via a direct message to me
* My personal email: `cal.s.mcdougall@gmail.com`
* LessWrong message ([here](https://www.lesswrong.com/users/themcdouglas) is my user)

**I'll feature the names of everyone who sends me a solution on this website, and also give a shout out to the best solutions.** It's possible that future challenges will also feature a monetary prize, but this is not guaranteed.

Please don't discuss specific things you've found about this model until the challenge is over (although you can discuss general strategies and techniques, and you're also welcome to work in a group if you'd like). The deadline for this problem will be the end of this month, i.e. 31st August.

## What counts as a solution?

Going through the solutions for the previous problem in the sequence (July: Palindromes) as well as the exercises in **[1.4] Balanced Bracket Classifier** should give you a good idea of what I'm looking for. In particular, I'd expect you to:

* Describe a mechanism for how the model solves the task, in the form of the QK and OV circuits of various attention heads (and possibly any other mechanisms the model uses, e.g. the direct path, or nonlinear effects from layernorm).
* Provide evidence for your mechanism, e.g. with tools like attention plots, targeted ablation / patching, or direct logit attribution.
* (Optional) Include additional detail, e.g. identifying the subspaces that the model uses for certain forms of information transmission, or using your understanding of the model's behaviour to construct adversarial examples.

## Task & Dataset

The algorithmic task is as follows: the model is presented with a sequence of characters, and for each character it has to correctly identify the first character in the sequence (up to and including the current character) which is unique up to that point.

The null character `"?"` has two purposes:

* In the input, it's used as the start character (because it's often helpful for interp to have a constant start character, to act as a "rest position").
* In the output, it's also used as the start character, **and** to represent the classification "no unique character exists".

Here is an example of what this dataset looks like:

```python
dataset = UniqueCharDataset(size=2, vocab=list("abc"), seq_len=6, seed=42)

for seq, first_unique_char_seq in zip(dataset.str_toks, dataset.str_tok_labels):
    print(f"Seq = {''.join(seq)}, Target = {''.join(first_unique_char_seq)}")
```

<div style='font-family:monospace;'>
Seq = ?acbba, Target = ?aaaac<br>
Seq = ?cbcbc, Target = ?ccb??
</div><br>

Explanation:

1. In the first sequence, `"a"` is unique in the prefix substring `"acbb"`, but it repeats at the 5th sequence position, meaning the final target character is `"c"` (which appears second in the sequence).
2. In the second sequence, `"c"` is unique in the prefix substring `"cb"`, then it repeats so `"b"` is the new first unique token, and for the last 2 positions there are no unique characters (since both `"b"` and `"c"` have been repeated) so the correct classification is `"?"` (the "null character").
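
The labelling rule can be sketched in plain Python. This is a minimal reference implementation written for illustration, not the code in `dataset.py`; the function name `first_unique_labels` is hypothetical, and it assumes the sequence is given as a string which starts with the null character `"?"`.

```python
from collections import Counter


def first_unique_labels(seq: str, null_char: str = "?") -> str:
    """For each position, return the first character of the prefix that occurs exactly once."""
    counts = Counter()
    labels = []
    for i, ch in enumerate(seq):
        # The null/start character is never a candidate, so it isn't counted
        if ch != null_char:
            counts[ch] += 1
        # Scan the prefix in order, and take the first character seen exactly once so far
        label = next(
            (c for c in seq[: i + 1] if c != null_char and counts[c] == 1),
            null_char,  # fallback: no unique character exists
        )
        labels.append(label)
    return "".join(labels)


print(first_unique_labels("?acbba"))  # ?aaaac
print(first_unique_labels("?cbcbc"))  # ?ccb??
```

The two printed strings match the targets in the example above.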

The relevant files can be found in local storage (after you run the setup code at the top of this notebook), at:

```
chapter1_transformers/
└── exercises/
    └── monthly_algorithmic_problems/
        └── august23_unique_char/
            ├── model.py               # code to create the model
            ├── dataset.py             # code to define the dataset
            ├── training.py            # code to train the model
            └── training_model.ipynb   # actual training script
```

We've given you the class `UniqueCharDataset` to store your data, as you can see above. You can slice this object to get batches of tokens and labels (e.g. `dataset[:5]` returns a length-2 tuple, containing the 2D tensors representing the tokens and correct labels respectively). You can also use `dataset.toks` or `dataset.labels` to access these tensors directly, or `dataset.str_toks` and `dataset.str_tok_labels` to get the string representations of the tokens and labels (like we did in the code above).

## Model

Our model was trained by minimising cross-entropy loss between its predictions and the true labels, at every sequence position simultaneously (including the zeroth sequence position, which is trivial because the input and target are both always `"?"`). You can inspect the notebook `training_model.ipynb` to see how it was trained. I used the version of the model which achieved highest accuracy over 40 epochs.



The model is a 2-layer transformer with 3 attention heads per layer, and causal attention. It includes layernorm, but no MLP layers. You can load it in as follows:

In [None]:
filename = section_dir / "first_unique_char_model.pt"

model = create_model(
    seq_len=20,
    vocab=list("abcdefghij"),
    seed=42,
    d_model=42,
    d_head=14,
    n_layers=2,
    n_heads=3,
    normalization_type="LN",
    d_mlp=None # attn-only model
)

state_dict = t.load(filename)

state_dict = model.center_writing_weights(state_dict)
state_dict = model.center_unembed(state_dict)
state_dict = model.fold_layer_norm(state_dict)
state_dict = model.fold_value_biases(state_dict)
model.load_state_dict(state_dict, strict=False);

The code to process the state dictionary is a bit messy, but it's necessary to make sure the model is easy to work with. For instance, if you inspect the model's parameters, you'll see that `model.ln_final.w` is a vector of 1s, and `model.ln_final.b` is a vector of 0s (because the weight and bias have been folded into the unembedding).

In [None]:
print("ln_final weight: ", model.ln_final.w)
print("\nln_final, bias: ", model.ln_final.b)

ln_final weight:  Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1.], device='cuda:0', requires_grad=True)

ln_final, bias:  Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', requires_grad=True)
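
The reason this folding is possible can be sketched in one line. Here $\hat{x}$ is the centred and normalised residual stream (as a row vector), $w$ and $b$ are the layernorm weight and bias, and $W_U$, $b_U$ are the unembedding weight and bias. This notation is introduced here for illustration and isn't taken from the notebook's code.

$$
(\hat{x} \odot w + b)\, W_U + b_U = \hat{x}\,\big(\operatorname{diag}(w)\, W_U\big) + \big(b\, W_U + b_U\big)
$$

So replacing $W_U$ with $\operatorname{diag}(w)\, W_U$ and $b_U$ with $b\, W_U + b_U$ leaves the layernorm with only its centring and normalising steps, which is what a weight of 1s and a bias of 0s expresses.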


<details>
<summary>Aside - the other weight processing parameters</summary>

Here's some more code to verify that our weights processing worked, in other words:

* The unembedding matrix has mean zero over both its input dimension (`d_model`) and output dimension (`d_vocab`)
* All writing weights (i.e. `b_O`, `W_O`, and both embeddings) have mean zero over their output dimension (`d_model`)
* The value biases `b_V` are zero (because these can just be folded into the output biases `b_O`)

```python
import einops

W_U_mean_over_input = einops.reduce(model.W_U, "d_model d_vocab -> d_model", "mean")
t.testing.assert_close(W_U_mean_over_input, t.zeros_like(W_U_mean_over_input))

W_U_mean_over_output = einops.reduce(model.W_U, "d_model d_vocab -> d_vocab", "mean")
t.testing.assert_close(W_U_mean_over_output, t.zeros_like(W_U_mean_over_output))

W_O_mean_over_output = einops.reduce(model.W_O, "layer head d_head d_model -> layer head d_head", "mean")
t.testing.assert_close(W_O_mean_over_output, t.zeros_like(W_O_mean_over_output))

b_O_mean_over_output = einops.reduce(model.b_O, "layer d_model -> layer", "mean")
t.testing.assert_close(b_O_mean_over_output, t.zeros_like(b_O_mean_over_output))

W_E_mean_over_output = einops.reduce(model.W_E, "token d_model -> token", "mean")
t.testing.assert_close(W_E_mean_over_output, t.zeros_like(W_E_mean_over_output))

W_pos_mean_over_output = einops.reduce(model.W_pos, "position d_model -> position", "mean")
t.testing.assert_close(W_pos_mean_over_output, t.zeros_like(W_pos_mean_over_output))

b_V = model.b_V
t.testing.assert_close(b_V, t.zeros_like(b_V))
```

</details>

The model's output is a logit tensor, of shape `(batch_size, seq_len, d_vocab+1)`. The slice `[i, j, :]` of this tensor holds the logits for the label at position `j` in the `i`-th sequence in the batch. The first `d_vocab` elements along the last dimension correspond to the elements in the vocabulary, and the last element corresponds to the null character `"?"` (which is not in the input vocab).
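
To make the indexing concrete, here is a small sketch of how the logits at a single position map back to a character. It uses plain Python lists rather than tensors, and the helper name `decode_position` and the logit values are made up for illustration. The only thing taken from the description above is that the null character occupies the last index.

```python
def decode_position(logits_at_pos: list[float], vocab: list[str], null_char: str = "?") -> str:
    """Return the character with the largest logit at one sequence position."""
    out_vocab = vocab + [null_char]  # null character is assumed to sit at the last index
    assert len(logits_at_pos) == len(out_vocab)
    best_idx = max(range(len(out_vocab)), key=lambda i: logits_at_pos[i])
    return out_vocab[best_idx]


vocab = list("abc")
print(decode_position([0.1, 2.5, -1.0, 0.3], vocab))  # largest logit at index 1, i.e. "b"
print(decode_position([0.1, 0.2, -1.0, 3.0], vocab))  # largest logit at index 3, i.e. "?"
```

With the real model, the equivalent operation would be an argmax over the last dimension of `logits`.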

A demonstration of the model working:


In [None]:
dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)

logits, cache = model.run_with_cache(dataset.toks)

logprobs = logits.log_softmax(-1) # [batch seq_len d_vocab+1]
probs = logprobs.exp() # [batch seq_len d_vocab+1]

batch_size, seq_len = dataset.toks.shape
logprobs_correct = logprobs[t.arange(batch_size)[:, None], t.arange(seq_len)[None, :], dataset.labels] # [batch seq_len]
probs_correct = probs[t.arange(batch_size)[:, None], t.arange(seq_len)[None, :], dataset.labels] # [batch seq_len]

avg_cross_entropy_loss = -logprobs_correct.mean().item()
avg_correct_prob = probs_correct.mean().item()
min_correct_prob = probs_correct.min().item()

print(f"Average cross entropy loss: {avg_cross_entropy_loss:.3f}")
print(f"Average probability on correct label: {avg_correct_prob:.3f}")
print(f"Min probability on correct label: {min_correct_prob:.3f}")

Average cross entropy loss: 0.017
Average probability on correct label: 0.988
Min probability on correct label: 0.001


And a visualisation of its probability output for a single sequence:

In [None]:
imshow(
    probs[0].T,
    y=dataset.vocab,
    x=[f"{dataset.str_toks[0][i]}<br>({i})" for i in range(model.cfg.n_ctx)],
    labels={"x": "Token", "y": "Vocab"},
    xaxis_tickangle=0,
    title="Sample model probabilities (for batch idx = 0), with correct classification highlighted",
    text=[
        ["〇" if str_tok == correct_str_tok else "" for correct_str_tok in dataset.str_tok_labels[0]]
        for str_tok in dataset.vocab
    ], # text can be a 2D list of lists, with the same shape as the data
)

If you want some guidance on how to get started, I'd recommend reading the solutions for the July problem - I expect there to be a lot of overlap in the best way to tackle these two problems. You can also reuse some of that code!


Note - although this model was trained for long enough to get loss close to zero (you can test this for yourself), it's not perfect. It has some weaknesses which might make it vulnerable to adversarial examples, and I've decided to leave these in. The model is still very good at its intended task, and the main focus of this challenge is on figuring out how it solves the task, not dissecting the situations where it fails. However, you might find that the adversarial examples help you understand the model better.


Best of luck! 🎈