In [77]:
import torch as t
from IPython.display import display
import circuitsvis as cv
import numpy as np

from transformer_lens import (
    ActivationCache,
    FactoredMatrix,
    HookedTransformer,
    HookedTransformerConfig,
    utils,
)

from arithmetic import detokenize, tokenize, chars, generate_completion_with_cache

from huggingface_hub import hf_hub_download


# Prefer Apple-silicon MPS, then CUDA, and fall back to the CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")


# Architecture of the 2-layer, attention-only transformer.
# NOTE(review): presumably these values must mirror the checkpoint loaded
# below, otherwise load_state_dict would reject it -- confirm.
cfg = HookedTransformerConfig(
    n_layers=2,
    n_heads=12,
    d_head=64,
    d_model=768,
    n_ctx=2048,
    d_vocab=len(chars),  # vocabulary size is tied to the length of `chars`
    attn_only=True,  # defaults to False
    attention_dir="causal",
    positional_embedding_type="shortformer",
    normalization_type=None,  # defaults to "LN", i.e. layernorm with weights & biases
    use_attn_result=True,
    seed=398,
)

weights_path = "models/arithmetic_model_2layers/model_weights.pt"

# Build the model from cfg, then load the saved parameters into it.
model = HookedTransformer(cfg)
pretrained_weights = t.load(weights_path, weights_only=True, map_location=device)
model.load_state_dict(pretrained_weights)
# Bare last expression so the notebook shows the module tree.
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_attn_input): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)

#### View model output for sample problem

In [78]:
text = ".581.+.668.=.1249."


def to_str_tokens(text):
    """Return the string form of every token in `text`, one entry per token.

    Each id produced by `tokenize` is decoded on its own with `detokenize`,
    so the returned list lines up index-for-index with the token ids.
    """
    pieces = []
    for token_id in tokenize(text):
        pieces.append(detokenize([token_id]))
    return pieces


str_tokens = to_str_tokens(text)


In [79]:
import torch

# Token ids for the sample problem, with a leading batch axis of size 1.
token_ids = tokenize(text)
inp = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)

# One forward pass that also records activations; remove_batch_dim=True
# asks for cached tensors without the batch axis.
logits, cache = model.run_with_cache(inp, remove_batch_dim=True)

In [80]:
# Attention pattern of layer 0 taken from the cache built with
# remove_batch_dim=True; the shape shown below is
# (n_heads, query position, key position).
attention_pattern = cache["pattern", 0]
# Bare expression so the notebook displays the shape.
attention_pattern.shape


torch.Size([12, 18, 18])

In [81]:

# Show the circuitsvis attention-pattern view for every layer in turn.
# Note: this rebinds the notebook-level `attention_pattern` on each pass, so
# after the loop it holds the pattern of the last layer.
for layer in range(model.cfg.n_layers):
    attention_pattern = cache["pattern", layer]
    display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern, ))

In [76]:

def identify_primary_digit_heads(model, cache, text, print_details=False):
    """
    For each layer and head, find the digit that is most often the single
    most-attended token when the query position is itself a digit.

    For every digit position in `text`, the key position with the largest
    attention weight is looked up. If the token at that key position is a
    digit, that digit's count is incremented. The head's "primary digit" is
    the digit with the highest count. On a tie, the digit that was counted
    first wins. Note that this is a count of argmax hits, not a sum of
    attention mass.

    Args:
        model: A HookedTransformer instance (only `cfg.n_layers` and
               `cfg.n_heads` are read).
        cache: An ActivationCache from running the model on 'text' with the
               batch dimension removed, so that each cached pattern has
               shape (n_heads, seq_len, seq_len).
        text:  The input text string you ran through the model.
        print_details: If True, prints out which digits each head attends to and the counts.

    Returns:
        results: A nested list of size [n_layers][n_heads], where each entry is
                 the digit (as a string) that the head primarily attends to, or None
                 if no digit was found.

    Raises:
        ValueError: If a cached attention pattern does not have shape
                    (n_heads, seq_len, seq_len) for the tokenized `text`,
                    e.g. because the cache came from a different input or
                    still has its batch dimension.
    """
    # 1. Tokenize and find which tokens are digits
    tokens = tokenize(text)
    str_tokens = [detokenize([i]) for i in tokens]
    digit_positions = [i for i, tok in enumerate(str_tokens) if tok.isdigit()]

    n_heads = model.cfg.n_heads
    expected_shape = (n_heads, len(str_tokens), len(str_tokens))

    results = []

    # 2. Loop over layers
    for layer_idx in range(model.cfg.n_layers):
        attention_pattern = cache["pattern", layer_idx]  # shape: (n_heads, seq_len, seq_len)
        attention_pattern = attention_pattern.detach().cpu()  # move to CPU for safety

        # Guard against a stale or batched cache: without this check a
        # mismatched cache is silently mis-indexed or fails with an opaque
        # IndexError further down.
        if tuple(attention_pattern.shape) != expected_shape:
            raise ValueError(
                f"cache['pattern', {layer_idx}] has shape {tuple(attention_pattern.shape)}, "
                f"expected {expected_shape} for the given text; was the cache built from "
                "this text with remove_batch_dim=True?"
            )

        layer_result = []

        # 3. Loop over heads in this layer
        for head_idx in range(n_heads):
            digit_counts = {}
            # 4. For each digit position, find which token has the max attention
            for pos in digit_positions:
                # attention_pattern[head, query_pos] is a 1D tensor over keys
                attn_row = attention_pattern[head_idx, pos]
                max_attended_idx = attn_row.argmax().item()
                attended_token = str_tokens[max_attended_idx]

                # If that attended token is also a digit, increment count
                if attended_token.isdigit():
                    digit_counts[attended_token] = digit_counts.get(attended_token, 0) + 1

            # 5. Find the most common digit (if any); max() keeps the first
            #    key it sees among equal counts, i.e. the first digit counted.
            if digit_counts:
                primary_digit = max(digit_counts, key=digit_counts.get)
            else:
                primary_digit = None

            layer_result.append(primary_digit)
            if print_details:
                print(f"Layer {layer_idx}, Head {head_idx}: {primary_digit}, counts={digit_counts}")

        results.append(layer_result)

    return results

# Example usage: analyse the sample problem with the `cache` computed from
# `text` earlier in the notebook, printing per-head counts along the way,
# then print the nested [layer][head] list of primary digits.
layer_head_digits = identify_primary_digit_heads(model, cache, text, print_details=True)
print(layer_head_digits)

Layer 0, Head 0: 1, counts={'1': 1}
Layer 0, Head 1: 5, counts={'5': 2, '6': 1, '8': 1}
Layer 0, Head 2: 1, counts={'1': 1}
Layer 0, Head 3: 1, counts={'1': 4}
Layer 0, Head 4: 8, counts={'5': 1, '1': 2, '8': 3}
Layer 0, Head 5: 1, counts={'1': 2}
Layer 0, Head 6: 5, counts={'5': 2, '1': 1}
Layer 0, Head 7: 5, counts={'5': 3, '8': 1}
Layer 0, Head 8: 1, counts={'5': 1, '1': 2, '6': 2}
Layer 0, Head 9: 5, counts={'5': 1, '1': 1}
Layer 0, Head 10: 5, counts={'5': 1}
Layer 1, Head 0: 8, counts={'5': 1, '8': 3, '6': 2}
Layer 1, Head 1: 1, counts={'5': 1, '8': 2, '1': 3}
Layer 1, Head 2: 1, counts={'5': 1, '8': 2, '1': 3}
Layer 1, Head 3: 8, counts={'5': 1, '8': 2, '1': 1}
Layer 1, Head 4: 8, counts={'5': 1, '8': 2, '1': 2}
Layer 1, Head 5: 8, counts={'5': 1, '8': 3, '6': 2}
Layer 1, Head 6: 5, counts={'5': 2, '1': 1, '8': 1}
Layer 1, Head 7: 1, counts={'5': 1, '8': 2, '1': 3}
Layer 1, Head 8: 5, counts={'5': 2, '1': 1, '6': 2, '8': 1}
Layer 1, Head 9: 8, counts={'5': 1, '8': 3, '6': 1}
Lay