# David Bau stream test task
### Darius Kianersi

Given a prompt of the following form, a sufficiently large language model should be able to answer with the correct count.
```
Count the number of words in the following  list that match the given type, and put the numerical answer in parentheses.
Type: fruit
List: [dog apple cherry bus cat grape bowl]
Answer: (
```
1. Create a dataset of several thousand examples like this.
2. Benchmark some open-weight LMs on solving this task zero-shot (without reasoning tokens).
3. For a single model, create a causal mediation analysis experiment (patching from one run to another) to answer: "Is there a hidden state layer that contains a representation of the running count of matching words while the model processes the list of words?"
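
Step 1 could be sketched roughly as follows. This is a minimal sketch, not the notebook's actual dataset code: the `CATEGORIES` vocabulary, the `make_example` and `make_dataset` helpers, and the fixed list length of 7 are all assumptions made for illustration, and the template is modeled on the prompt above.

```python
import random

# Hypothetical vocabulary: these categories and words are illustrative only.
CATEGORIES = {
    "fruit": ["apple", "cherry", "grape", "banana", "peach", "mango"],
    "animal": ["dog", "cat", "horse", "tiger", "mouse", "sheep"],
    "vehicle": ["bus", "car", "train", "truck", "boat", "plane"],
}

# Template modeled on the example prompt; the prompt ends at the open parenthesis.
PROMPT_TEMPLATE = (
    "Count the number of words in the following list that match the given type, "
    "and put the numerical answer in parentheses.\n"
    "Type: {type}\n"
    "List: [{words}]\n"
    "Answer: ("
)


def make_example(rng, list_len=7):
    """Build one prompt with a known number of words matching the target type."""
    target = rng.choice(sorted(CATEGORIES))
    distractors = [w for t, ws in CATEGORIES.items() if t != target for w in ws]
    # Pick how many matching words to include, then fill the rest with distractors.
    n_match = rng.randint(0, min(list_len, len(CATEGORIES[target])))
    words = rng.sample(CATEGORIES[target], n_match)
    words += rng.sample(distractors, list_len - n_match)
    rng.shuffle(words)
    prompt = PROMPT_TEMPLATE.format(type=target, words=" ".join(words))
    return {"prompt": prompt, "type": target, "words": words, "answer": n_match}


def make_dataset(n, seed=0):
    """Generate n examples reproducibly from a fixed seed."""
    rng = random.Random(seed)
    return [make_example(rng) for _ in range(n)]


dataset = make_dataset(5000)
```

Sampling the count uniformly (rather than sampling each word independently) keeps the answer distribution flat, which may make benchmark accuracy easier to interpret.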

In [1]:
import re
import sys
import os
from functools import partial
from itertools import product
from pathlib import Path
from typing import Callable, Literal

import circuitsvis as cv
import einops
import numpy as np
import plotly.express as px
import torch
from pprint import pprint
from IPython.display import HTML, display
from jaxtyping import Bool, Float, Int
from rich import print as rprint
from rich.table import Column, Table
from torch import Tensor
from tqdm.notebook import tqdm
from transformer_lens import ActivationCache, HookedTransformer, utils
from transformer_lens.components import MLP, Embed, LayerNorm, Unembed
from transformer_lens.hook_points import HookPoint
from openai import OpenAI

torch.set_grad_enabled(False)
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# from plotly_utils import bar, imshow, line, scatter

MAIN = __name__ == "__main__"

In [2]:
from huggingface_hub import login
from dotenv import load_dotenv

load_dotenv()
assert os.getenv("OPENAI_API_KEY") is not None, "OPENAI_API_KEY is not set"
openai_client = OpenAI()
login(token=os.getenv("HF_TOKEN"))

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [3]:
response = openai_client.chat.completions.create(
    model="gpt-4.1-nano",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    n=2,
)

pprint(response.model_dump())  # See the entire ChatCompletion object, as a dict (more readable)
print("\n", response.choices[0].message.content)  # See the response message only

{'choices': [{'finish_reason': 'stop',
              'index': 0,
              'logprobs': None,
              'message': {'annotations': [],
                          'audio': None,
                          'content': 'The capital of France is Paris.',
                          'function_call': None,
                          'refusal': None,
                          'role': 'assistant',
                          'tool_calls': None}},
             {'finish_reason': 'stop',
              'index': 1,
              'logprobs': None,
              'message': {'annotations': [],
                          'audio': None,
                          'content': 'The capital of France is Paris.',
                          'function_call': None,
                          'refusal': None,
                          'role': 'assistant',
                          'tool_calls': None}}],
 'created': 1748501310,
 'id': 'chatcmpl-BcRLyglj8NgWcoQlQ3vbwOfsrUmtI',
 'model': 'gpt-4.1-nano-2025-04-14',
 'obj

In [3]:
model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    # center_unembed=True,
    # center_writing_weights=True,
    # fold_ln=True,
    # refactor_factored_attn_matrices=True,
)

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]



Loaded pretrained model meta-llama/Llama-3.1-8B-Instruct into HookedTransformer


In [None]:
# Test the model on a single prompt and inspect its top next-token predictions.
# Note: test_prompt prepends a space to the answer by default, so "3" is tokenized as [' ', '3'].

example_prompt = """Count the number of words in the following  list that match the given type, and put the numerical answer in parentheses.
Type: fruit
List: [dog apple cherry bus cat grape bowl]
Answer: ("""
example_answer = "3"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|begin_of_text|>', 'Count', ' the', ' number', ' of', ' words', ' in', ' the', ' following', ' ', ' list', ' that', ' match', ' the', ' given', ' type', ',', ' and', ' put', ' the', ' numerical', ' answer', ' in', ' parentheses', '.\n', 'Type', ':', ' fruit', '\n', 'List', ':', ' [', 'dog', ' cherry', ' bus', ' cat', ' grape', ' bowl', ']\n', 'Answer', ':', ' (']
Tokenized answer: [' ', '3']


Top 0th token. Logit: 20.52 Prob: 56.46% Token: |1|
Top 1th token. Logit: 20.19 Prob: 40.66% Token: |2|
Top 2th token. Logit: 16.75 Prob:  1.30% Token: |0|
Top 3th token. Logit: 16.32 Prob:  0.84% Token: |3|
Top 4th token. Logit: 14.75 Prob:  0.18% Token: | |
Top 5th token. Logit: 14.44 Prob:  0.13% Token: |4|
Top 6th token. Logit: 13.16 Prob:  0.04% Token: |gr|
Top 7th token. Logit: 12.97 Prob:  0.03% Token: |5|
Top 8th token. Logit: 12.24 Prob:  0.01% Token: |one|
Top 9th token. Logit: 12.14 Prob:  0.01% Token: | grape|


Top 0th token. Logit: 19.42 Prob: 50.86% Token: |1|
Top 1th token. Logit: 19.33 Prob: 46.67% Token: |2|
Top 2th token. Logit: 15.22 Prob:  0.77% Token: |0|
Top 3th token. Logit: 14.92 Prob:  0.56% Token: | |
Top 4th token. Logit: 14.56 Prob:  0.39% Token: |3|
Top 5th token. Logit: 13.23 Prob:  0.10% Token: | )\n|
Top 6th token. Logit: 12.94 Prob:  0.08% Token: |4|
Top 7th token. Logit: 12.28 Prob:  0.04% Token: | )|
Top 8th token. Logit: 11.94 Prob:  0.03% Token: | grape|
Top 9th token. Logit: 11.73 Prob:  0.02% Token: |5|
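
For the patching experiment in step 3, one option is to pair each prompt with a counterfactual that differs in exactly one list position, so that the running count differs by one from that position onward. The sketch below only builds such pairs and their expected running counts; it does not run a model. The word pools and the `running_counts` and `make_counterfactual` helpers are hypothetical, and the sketch assumes (without checking) that the swapped words tokenize to the same number of tokens, which would need to be verified against the model's tokenizer.

```python
import random

# Hypothetical word pools; in practice these would come from the dataset vocabulary.
FRUITS = ["apple", "cherry", "grape", "peach"]
NON_FRUITS = ["dog", "bus", "cat", "bowl", "chair", "lamp"]


def running_counts(words, matching):
    """Return the number of matching words seen up to and including each position."""
    counts, total = [], 0
    for w in words:
        total += w in matching  # True counts as 1, False as 0
        counts.append(total)
    return counts


def make_counterfactual(words, matching, non_matching, rng):
    """Swap one matching word for a non-matching one; return (new_words, swapped_index)."""
    match_positions = [i for i, w in enumerate(words) if w in matching]
    if not match_positions:
        raise ValueError("list has no matching word to swap")
    idx = rng.choice(match_positions)
    # Prefer a replacement that is not already in the list; fall back to any non-matching word.
    candidates = [w for w in non_matching if w not in words] or list(non_matching)
    new_words = list(words)
    new_words[idx] = rng.choice(candidates)
    return new_words, idx


rng = random.Random(0)
clean = ["dog", "apple", "cherry", "bus", "cat", "grape", "bowl"]
corrupt, idx = make_counterfactual(clean, set(FRUITS), NON_FRUITS, rng)
clean_counts = running_counts(clean, set(FRUITS))
corrupt_counts = running_counts(corrupt, set(FRUITS))
```

If some layer does carry a running count, patching its hidden state at a list position at or after `idx` from the clean run into the corrupted run might be expected to shift the final answer by one, whereas patching at earlier positions should have no such effect.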


In [27]:
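# Next-token distribution at the final position of the prompt
probs = model(example_prompt)[:, -1].softmax(dim=-1)
# Draw 10 independent samples of the next token (shape [1, 10])
toks = torch.multinomial(probs, num_samples=10, replacement=True)
# The 10 sampled tokens are decoded together as one string
model.to_string(toks)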

['113213320fresh']
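
For the benchmark in step 2, completions need to be scored against the gold count. A minimal scorer might parse the leading integer of each completion and compute exact-match accuracy. The `parse_count` and `accuracy` helpers and the example completions below are hypothetical and are not taken from any model run.

```python
import re


def parse_count(completion):
    """Return the leading integer of a completion, or None if it does not start with one."""
    m = re.match(r"\s*(\d+)", completion)
    return int(m.group(1)) if m else None


def accuracy(completions, answers):
    """Fraction of completions whose leading integer equals the gold count."""
    correct = sum(parse_count(c) == a for c, a in zip(completions, answers))
    return correct / len(answers)


# Made-up completions for illustration only.
completions = ["3)", " 2)", "fresh", "1"]
answers = [3, 3, 1, 1]
score = accuracy(completions, answers)
```

Treating unparseable completions as wrong, rather than skipping them, keeps the denominator fixed across models.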