Finding Circuits in LLMs When They Are Forced to Produce Self-Explanations

## Setup and Imports

In [None]:

# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2


Running as a Colab notebook
Collecting transformer_lens
  Downloading transformer_lens-1.15.0-py3-none-any.whl (124 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m124.2/124.2 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate>=0.23.0 (from transformer_lens)
  Downloading accelerate-0.29.2-py3-none-any.whl (297 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.4/297.4 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl (3.5 kB)
Collecting datasets>=2.7.1 (from transformer_lens)
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

from transformers import AutoModelForCausalLM, AutoTokenizer
import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

In [None]:
# Turn autograd off globally for the rest of the notebook.
torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

Disabled automatic differentiation


In [None]:
def imshow(tensor, **kwargs):
    """Show `tensor` as a heatmap using a red/blue colour scale centred on zero.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        **kwargs: Extra keyword arguments forwarded to `px.imshow`.
    """
    values = utils.to_numpy(tensor)
    figure = px.imshow(
        values,
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    figure.show()


def line(tensor, **kwargs):
    """Show `tensor` as a line plot, using its values as the y series.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` first.
        **kwargs: Extra keyword arguments forwarded to `px.line`.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    """Show a scatter plot of `y` against `x`.

    Args:
        x: Horizontal coordinates; converted with `utils.to_numpy`.
        y: Vertical coordinates; converted with `utils.to_numpy`.
        xaxis: Label used for the x axis.
        yaxis: Label used for the y axis.
        caxis: Label used for the colour axis.
        **kwargs: Extra keyword arguments forwarded to `px.scatter`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_titles = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_titles, **kwargs)
    figure.show()

## Self-Explanations - Circuits Experimentation

In [None]:
# Load the pretrained "qwen1.5-1.8b-chat" checkpoint as a HookedTransformer;
# the rest of the notebook uses this `model` for tokenization, generation and
# cached forward passes.
model = HookedTransformer.from_pretrained(
    "qwen1.5-1.8b-chat"
)
# Get the default device used
device: torch.device = utils.get_device()

config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/206 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model qwen1.5-1.8b-chat into HookedTransformer


In [None]:
# Sanity check on a toy prompt: inspect tokenization and the model's next-token
# predictions. (The model itself is loaded in the previous cell, not here.)

example_prompt = "Henry found a Wolf in the forest... Who ran away? Why? -"

# Print the id of "Explanation" as a single token, then the token ids of the prompt.
example = "Explanation"
print(model.to_single_token(example))
print(model.to_tokens(example_prompt))
example_answer = " Henry ran away, because he's a human. Humans are scared of Wolves"
# Print the tokenized prompt and answer plus the top-10 predictions at each answer position.
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

69769
tensor([[151645,  63363,   1730,    264,  25230,    304,    279,  13638,   1112,
          10479,  10613,   3123,     30,   8429,     30,    481]],
       device='cuda:0')
Tokenized prompt: ['<|im_end|>', 'Henry', ' found', ' a', ' Wolf', ' in', ' the', ' forest', '...', ' Who', ' ran', ' away', '?', ' Why', '?', ' -']
Tokenized answer: [' Henry', ' ran', ' away', ',', ' because', ' he', "'s", ' a', ' human', '.', ' Humans', ' are', ' scared', ' of', ' Wolves']


Top 0th token. Logit: 14.95 Prob:  7.69% Token: | a|
Top 1th token. Logit: 14.92 Prob:  7.48% Token: | human|
Top 2th token. Logit: 14.83 Prob:  6.79% Token: | A|
Top 3th token. Logit: 14.49 Prob:  4.86% Token: | The|
Top 4th token. Logit: 14.03 Prob:  3.05% Token: | humans|
Top 5th token. Logit: 13.40 Prob:  1.63% Token: | the|
Top 6th token. Logit: 13.24 Prob:  1.38% Token: | Wolf|
Top 7th token. Logit: 13.24 Prob:  1.38% Token: | Dog|
Top 8th token. Logit: 13.21 Prob:  1.34% Token: | Human|
Top 9th token. Logit: 13.13 Prob:  1.25% Token: | friend|


Top 0th token. Logit: 17.01 Prob: 42.63% Token: | -|
Top 1th token. Logit: 15.61 Prob: 10.51% Token: | is|
Top 2th token. Logit: 15.24 Prob:  7.26% Token: | was|
Top 3th token. Logit: 14.76 Prob:  4.50% Token: |<|im_end|>|
Top 4th token. Logit: 14.36 Prob:  3.00% Token: |'s|
Top 5th token. Logit: 14.26 Prob:  2.73% Token: |\n|
Top 6th token. Logit: 14.03 Prob:  2.15% Token: | had|
Top 7th token. Logit: 13.92 Prob:  1.93% Token: | went|
Top 8th token. Logit: 13.56 Prob:  1.35% Token: | and|
Top 9th token. Logit: 13.15 Prob:  0.90% Token: | has|


Top 0th token. Logit: 21.92 Prob: 93.68% Token: | away|
Top 1th token. Logit: 18.21 Prob:  2.30% Token: | after|
Top 2th token. Logit: 16.84 Prob:  0.58% Token: | home|
Top 3th token. Logit: 16.51 Prob:  0.42% Token: | faster|
Top 4th token. Logit: 16.04 Prob:  0.26% Token: | in|
Top 5th token. Logit: 16.00 Prob:  0.25% Token: | because|
Top 6th token. Logit: 15.99 Prob:  0.25% Token: | into|
Top 7th token. Logit: 15.74 Prob:  0.19% Token: | to|
Top 8th token. Logit: 15.60 Prob:  0.17% Token: | scared|
Top 9th token. Logit: 15.57 Prob:  0.16% Token: | from|


Top 0th token. Logit: 18.37 Prob: 49.23% Token: | because|
Top 1th token. Logit: 17.34 Prob: 17.47% Token: |<|im_end|>|
Top 2th token. Logit: 16.74 Prob:  9.60% Token: | -|
Top 3th token. Logit: 16.47 Prob:  7.32% Token: |.|
Top 4th token. Logit: 15.97 Prob:  4.44% Token: | from|
Top 5th token. Logit: 14.83 Prob:  1.42% Token: | to|
Top 6th token. Logit: 14.55 Prob:  1.08% Token: | after|
Top 7th token. Logit: 14.32 Prob:  0.86% Token: |...|
Top 8th token. Logit: 14.30 Prob:  0.84% Token: | due|
Top 9th token. Logit: 13.99 Prob:  0.62% Token: | Henry|


Top 0th token. Logit: 17.81 Prob: 33.12% Token: | because|
Top 1th token. Logit: 17.49 Prob: 23.91% Token: | he|
Top 2th token. Logit: 16.58 Prob:  9.65% Token: | as|
Top 3th token. Logit: 15.87 Prob:  4.73% Token: | but|
Top 4th token. Logit: 15.50 Prob:  3.28% Token: | the|
Top 5th token. Logit: 15.35 Prob:  2.83% Token: | scared|
Top 6th token. Logit: 14.86 Prob:  1.73% Token: | it|
Top 7th token. Logit: 14.75 Prob:  1.55% Token: | -|
Top 8th token. Logit: 14.63 Prob:  1.37% Token: | not|
Top 9th token. Logit: 14.41 Prob:  1.11% Token: | and|


Top 0th token. Logit: 21.51 Prob: 92.06% Token: | he|
Top 1th token. Logit: 18.17 Prob:  3.26% Token: | the|
Top 2th token. Logit: 17.76 Prob:  2.17% Token: | Henry|
Top 3th token. Logit: 17.03 Prob:  1.04% Token: | it|
Top 4th token. Logit: 15.59 Prob:  0.25% Token: | wolves|
Top 5th token. Logit: 15.53 Prob:  0.23% Token: | his|
Top 6th token. Logit: 15.13 Prob:  0.16% Token: | there|
Top 7th token. Logit: 14.61 Prob:  0.09% Token: | of|
Top 8th token. Logit: 14.60 Prob:  0.09% Token: | they|
Top 9th token. Logit: 14.21 Prob:  0.06% Token: | a|


Top 0th token. Logit: 21.74 Prob: 73.05% Token: | was|
Top 1th token. Logit: 19.90 Prob: 11.52% Token: | heard|
Top 2th token. Logit: 19.29 Prob:  6.31% Token: | saw|
Top 3th token. Logit: 17.89 Prob:  1.54% Token: | felt|
Top 4th token. Logit: 17.50 Prob:  1.04% Token: | couldn|
Top 5th token. Logit: 17.32 Prob:  0.88% Token: | scared|
Top 6th token. Logit: 17.13 Prob:  0.73% Token: | wanted|
Top 7th token. Logit: 16.78 Prob:  0.51% Token: | didn|
Top 8th token. Logit: 16.72 Prob:  0.48% Token: | got|
Top 9th token. Logit: 16.66 Prob:  0.45% Token: | thought|


Top 0th token. Logit: 20.53 Prob: 43.94% Token: | afraid|
Top 1th token. Logit: 20.39 Prob: 38.05% Token: | scared|
Top 2th token. Logit: 19.17 Prob: 11.25% Token: | a|
Top 3th token. Logit: 16.50 Prob:  0.78% Token: | not|
Top 4th token. Logit: 16.13 Prob:  0.54% Token: | the|
Top 5th token. Logit: 16.02 Prob:  0.48% Token: | an|
Top 6th token. Logit: 15.79 Prob:  0.38% Token: | stupid|
Top 7th token. Logit: 15.64 Prob:  0.33% Token: | lost|
Top 8th token. Logit: 15.57 Prob:  0.31% Token: | lazy|
Top 9th token. Logit: 15.48 Prob:  0.28% Token: | too|


Top 0th token. Logit: 19.65 Prob: 60.46% Token: | coward|
Top 1th token. Logit: 18.28 Prob: 15.30% Token: | scared|
Top 2th token. Logit: 17.25 Prob:  5.45% Token: | wolf|
Top 3th token. Logit: 16.17 Prob:  1.86% Token: | dog|
Top 4th token. Logit: 16.13 Prob:  1.78% Token: | pet|
Top 5th token. Logit: 15.62 Prob:  1.07% Token: | vegetarian|
Top 6th token. Logit: 15.45 Prob:  0.91% Token: | predator|
Top 7th token. Logit: 15.29 Prob:  0.77% Token: | chicken|
Top 8th token. Logit: 15.18 Prob:  0.69% Token: | human|
Top 9th token. Logit: 15.04 Prob:  0.60% Token: | big|


Top 0th token. Logit: 17.02 Prob: 32.41% Token: | -|
Top 1th token. Logit: 16.73 Prob: 24.32% Token: |<|im_end|>|
Top 2th token. Logit: 16.12 Prob: 13.13% Token: |.|
Top 3th token. Logit: 15.90 Prob: 10.58% Token: | and|
Top 4th token. Logit: 14.85 Prob:  3.70% Token: | being|
Top 5th token. Logit: 14.13 Prob:  1.80% Token: |.-|
Top 6th token. Logit: 13.92 Prob:  1.46% Token: |,|
Top 7th token. Logit: 13.28 Prob:  0.77% Token: |\n|
Top 8th token. Logit: 13.04 Prob:  0.61% Token: | who|
Top 9th token. Logit: 12.97 Prob:  0.57% Token: |\.|


Top 0th token. Logit: 20.40 Prob: 55.81% Token: |<|im_end|>|
Top 1th token. Logit: 20.15 Prob: 43.15% Token: | -|
Top 2th token. Logit: 14.82 Prob:  0.21% Token: | He|
Top 3th token. Logit: 14.56 Prob:  0.16% Token: | The|
Top 4th token. Logit: 13.79 Prob:  0.08% Token: | Henry|
Top 5th token. Logit: 12.96 Prob:  0.03% Token: | |
Top 6th token. Logit: 12.70 Prob:  0.03% Token: | Because|
Top 7th token. Logit: 12.43 Prob:  0.02% Token: | Human|
Top 8th token. Logit: 12.37 Prob:  0.02% Token: | Humans|
Top 9th token. Logit: 12.36 Prob:  0.02% Token: | It|


Top 0th token. Logit: 19.51 Prob: 37.36% Token: | are|
Top 1th token. Logit: 18.81 Prob: 18.56% Token: | have|
Top 2th token. Logit: 18.69 Prob: 16.45% Token: | can|
Top 3th token. Logit: 17.79 Prob:  6.70% Token: | don|
Top 4th token. Logit: 17.43 Prob:  4.66% Token: | cannot|
Top 5th token. Logit: 16.72 Prob:  2.30% Token: | do|
Top 6th token. Logit: 16.38 Prob:  1.64% Token: | usually|
Top 7th token. Logit: 16.07 Prob:  1.20% Token: | generally|
Top 8th token. Logit: 15.81 Prob:  0.92% Token: | typically|
Top 9th token. Logit: 15.78 Prob:  0.90% Token: | like|


Top 0th token. Logit: 20.29 Prob: 61.45% Token: | not|
Top 1th token. Logit: 18.51 Prob: 10.44% Token: | afraid|
Top 2th token. Logit: 17.53 Prob:  3.92% Token: | creatures|
Top 3th token. Logit: 16.86 Prob:  2.00% Token: | generally|
Top 4th token. Logit: 16.79 Prob:  1.86% Token: | typically|
Top 5th token. Logit: 16.70 Prob:  1.70% Token: | social|
Top 6th token. Logit: 16.55 Prob:  1.46% Token: | scared|
Top 7th token. Logit: 16.45 Prob:  1.33% Token: | animals|
Top 8th token. Logit: 16.42 Prob:  1.29% Token: | usually|
Top 9th token. Logit: 16.01 Prob:  0.85% Token: | territorial|


Top 0th token. Logit: 22.51 Prob: 94.44% Token: | of|
Top 1th token. Logit: 19.40 Prob:  4.20% Token: | by|
Top 2th token. Logit: 17.62 Prob:  0.71% Token: | and|
Top 3th token. Logit: 16.06 Prob:  0.15% Token: | away|
Top 4th token. Logit: 15.49 Prob:  0.08% Token: | to|
Top 5th token. Logit: 14.80 Prob:  0.04% Token: | easily|
Top 6th token. Logit: 14.54 Prob:  0.03% Token: | animals|
Top 7th token. Logit: 14.48 Prob:  0.03% Token: |.|
Top 8th token. Logit: 14.43 Prob:  0.03% Token: | when|
Top 9th token. Logit: 14.34 Prob:  0.03% Token: | or|


Top 0th token. Logit: 23.99 Prob: 95.47% Token: | wolves|
Top 1th token. Logit: 19.53 Prob:  1.10% Token: | big|
Top 2th token. Logit: 19.30 Prob:  0.88% Token: | wild|
Top 3th token. Logit: 18.75 Prob:  0.51% Token: | animals|
Top 4th token. Logit: 18.29 Prob:  0.32% Token: | dogs|
Top 5th token. Logit: 18.05 Prob:  0.25% Token: | the|
Top 6th token. Logit: 18.04 Prob:  0.25% Token: | Wolves|
Top 7th token. Logit: 17.69 Prob:  0.18% Token: | predators|
Top 8th token. Logit: 17.65 Prob:  0.17% Token: | large|
Top 9th token. Logit: 16.91 Prob:  0.08% Token: |狼|


In the following part, my idea was to create three prompts and have Qwen generate answers for them.

I did a bit of cheating here: I set the temperature close to 0 (0.1) so that the model produces nearly the same answers every time. This turned out to be more complicated than expected: the explanations come out with different lengths, and I couldn't reliably trigger explanations with similar wording and length, which makes a "good" analysis of the logits afterwards difficult.

However: my original idea was something like

Good Explanation = "Human had to run away, because he's a human. Humans are really scared of Wolves", "Wolf ran away, because he's an animal. Animals are scared of Humans",

LLM Explanation =

In [None]:
# Get the tokenizer from the model
tokenizer = model.tokenizer

# Three "who ran away, and why?" scenarios; each ends with a newline so the
# model's answer starts on a fresh line.
prompts = [
    "In the forest two living creatures encountered each other, Henry and the Wolf. One of them was scared, and ran away.\n Who did run away? Why?\n",
    "In the jungle two living creatures encountered each other, Lion and Gazelle. One of them was scared and ran away.\n Who did run away? Why?\n",
    "In the rainforest two living creatures encountered each other, Jaguar and Monkey. One of creature was scared and ran away.\n Who did run away? Why?\n",
]

# Hand-written reference explanations, two per prompt and in prompt order:
# entries 2*i and 2*i+1 belong to prompts[i]. calculate_logit_differences
# depends on this ordering.
explanation = [
    "Human had to run away, because it was scared of the Wolf", "Wolf had to run away, because it was scared of the Human",
    "Gazelle ran away, because it was scared of the Lion", "Lion ran away, because it was scared of the Gazelle",
    "Monkey had to run away, because it was scared of the Jaguar", "Jaguar run away, because it was scared of the Monkey"
    ]

# Filled by the generation loop in the next cell.
generation = []
# NOTE(review): `explanations_tokens` (with an "s") looks unused; a later cell
# assigns `explanation_tokens` instead — confirm and remove if so.
explanations_tokens = []
# Placeholder; replaced by a tensor once the generations are tokenized.
generation_tokens = []


In [None]:
# Generate an answer for each prompt and collect the results.
# Start from an empty list so that re-running this cell does not append
# duplicate answers (later cells assume exactly one answer per prompt).
generation = []
for prompt in prompts:
    # Tokenize the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Sample a short continuation; the low temperature keeps the sampling
    # close to greedy decoding, but it is still stochastic.
    generated_output = model.generate(
        input_ids,
        max_new_tokens=14,  # Maximum number of new tokens to generate
        top_k=25,
        top_p=0.65,
        temperature=0.10,
        do_sample=True,
    )

    # Decode and show the full text (prompt + continuation)
    generated_text = tokenizer.decode(generated_output[0], skip_special_tokens=True)
    print(generated_text)

    # Keep only the newly generated tokens as the answer. Slicing by prompt
    # length is robust to newlines inside the continuation, whereas taking the
    # last line of the decoded text would drop part of such an answer or
    # return an empty string.
    new_token_ids = generated_output[0, input_ids.shape[-1]:]
    answer = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    # Append the generated answer to the list
    generation.append(answer)


  0%|          | 0/14 [00:00<?, ?it/s]

In the forest two living creatures encountered each other, Henry and the Wolf. One of them was scared, and ran away.
 Who did run away? Why?
Henry ran away because he was scared of the Wolf. The Wolf is


  0%|          | 0/14 [00:00<?, ?it/s]

In the jungle two living creatures encountered each other, Lion and Gazelle. One of them was scared and ran away.
 Who did run away? Why?
Answer: Lion. Lion was scared of Gazelle because Gazelle was


  0%|          | 0/14 [00:00<?, ?it/s]

In the rainforest two living creatures encountered each other, Jaguar and Monkey. One of creature was scared and ran away.
 Who did run away? Why?
Answer: The Jaguar ran away because it was scared of the Monkey.


In [None]:
# Show the hand-written reference explanations next to the model's generated answers.
print(explanation)
print(generation)

['Human had to run away, because it was scared of the Wolf', 'Wolf had to run away, because it was scared of the Human', 'Gazelle ran away, because it was scared of the Lion', 'Lion ran away, because it was scared of the Gazelle', 'Monkey had to run away, because it was scared of the Jaguar', 'Jaguar run away, because it was scared of the Monkey']
['Henry ran away because he was scared of the Wolf. The Wolf is', 'Answer: Lion. Lion was scared of Gazelle because Gazelle was', 'Answer: The Jaguar ran away because it was scared of the Monkey.']


In [None]:
# Tokenize the six reference explanations as one batch, with BOS prepended.
# NOTE(review): the texts appear to be worded so that each one is 14 tokens
# long, so the batch should need no padding — re-check if the wording changes.
explanation_tokens =  model.to_tokens(explanation, prepend_bos=True)
print(explanation_tokens)

tensor([[151645,  33975,   1030,    311,   1598,   3123,     11,   1576,    432,
            572,  26115,    315,    279,  25230],
        [151645,  95434,   1030,    311,   1598,   3123,     11,   1576,    432,
            572,  26115,    315,    279,  11097],
        [151645,     38,   1370,   6712,  10613,   3123,     11,   1576,    432,
            572,  26115,    315,    279,  32099],
        [151645,     43,    290,  10613,   3123,     11,   1576,    432,    572,
          26115,    315,    279,  43292,   6712],
        [151645,  96838,   1030,    311,   1598,   3123,     11,   1576,    432,
            572,  26115,    315,    279,  73537],
        [151645,     41,    351,  18731,   1598,   3123,     11,   1576,    432,
            572,  26115,    315,    279,  57837]], device='cuda:0')


In [None]:
# Show how each prompt is split into string tokens.
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)

# Same for each reference explanation. The second label used to repeat
# "Explanation length:" even though it prints the tokens.
for explanation_text in explanation:
    str_tokens = model.to_str_tokens(explanation_text)
    print("Explanation length:", len(str_tokens))
    print("Explanation as tokens:", str_tokens)

Prompt length: 33
Prompt as tokens: ['<|im_end|>', 'In', ' the', ' forest', ' two', ' living', ' creatures', ' encountered', ' each', ' other', ',', ' Henry', ' and', ' the', ' Wolf', '.', ' One', ' of', ' them', ' was', ' scared', ',', ' and', ' ran', ' away', '.\n', ' Who', ' did', ' run', ' away', '?', ' Why', '?\n']
Prompt length: 32
Prompt as tokens: ['<|im_end|>', 'In', ' the', ' jungle', ' two', ' living', ' creatures', ' encountered', ' each', ' other', ',', ' Lion', ' and', ' Gaz', 'elle', '.', ' One', ' of', ' them', ' was', ' scared', ' and', ' ran', ' away', '.\n', ' Who', ' did', ' run', ' away', '?', ' Why', '?\n']
Prompt length: 32
Prompt as tokens: ['<|im_end|>', 'In', ' the', ' rain', 'forest', ' two', ' living', ' creatures', ' encountered', ' each', ' other', ',', ' Jaguar', ' and', ' Monkey', '.', ' One', ' of', ' creature', ' was', ' scared', ' and', ' ran', ' away', '.\n', ' Who', ' did', ' run', ' away', '?', ' Why', '?\n']
Explanation length: 14
Explanation leng

In [None]:
# Tokenize the generated answers as one batch, with BOS prepended.
generation_tokens = model.to_tokens(generation, prepend_bos=True)
print(generation_tokens.shape)

# Run the model on both batches and keep BOTH activation caches. Previously the
# second call overwrote `cache`, silently discarding the generation activations.
generation_logits, generation_cache = model.run_with_cache(generation_tokens)
explanation_logits, explanation_cache = model.run_with_cache(explanation_tokens)

# Backward-compatible alias: cells that read `cache` get the activations of the
# reference explanations, exactly as before.
cache = explanation_cache



torch.Size([3, 15])


In [None]:
# Compare the logit tensor shapes of the two runs (generated answers vs. reference explanations).
print(generation_logits.shape)
print(explanation_logits.shape)


torch.Size([3, 15, 151936])
torch.Size([6, 14, 151936])


Now I'll check how the generated self-explanations differ from the reference explanations.

In [None]:
def calculate_logit_differences(gen_tensor, exp_tensor, per_prompt=False, explanations_per_generation=2):
    """Measure how far each generation's final logits are from its reference explanations.

    Generation ``i`` is paired with the ``explanations_per_generation``
    consecutive rows of ``exp_tensor`` starting at
    ``i * explanations_per_generation``. For every pair the L2 norm of the
    difference between the logit vectors at the last sequence position is
    computed.

    NOTE(review): the last position (index -1) is used for every row; if a
    batch was padded, that position may be a padding token for the shorter
    sequences — confirm against how the batches were built.

    Args:
        gen_tensor: Logits of the generated texts, indexed as
            ``[generation, position, vocab]``.
        exp_tensor: Logits of the reference explanations, indexed as
            ``[explanation, position, vocab]``. Rows beyond the ones needed
            for pairing are ignored.
        per_prompt: If True, also return the mean distance per generation.
        explanations_per_generation: Number of consecutive explanation rows
            that belong to each generation (default 2).

    Returns:
        If ``per_prompt`` is True, a tuple ``(individual, per_generation_mean)``
        of 1-D CPU tensors with ``n * explanations_per_generation`` and ``n``
        entries respectively. Otherwise a 0-d tensor holding the mean of all
        individual distances.

    Raises:
        ValueError: If ``explanations_per_generation`` is smaller than 1.
        IndexError: If ``exp_tensor`` has too few rows for the pairing.
    """
    if explanations_per_generation < 1:
        raise ValueError("explanations_per_generation must be at least 1")

    # Logits at the final position of every sequence
    gen_final_logits = gen_tensor[:, -1, :]
    exp_final_logits = exp_tensor[:, -1, :]

    n_generations = gen_final_logits.shape[0]
    n_pairs = n_generations * explanations_per_generation
    if exp_final_logits.shape[0] < n_pairs:
        raise IndexError(
            f"expected at least {n_pairs} explanation rows "
            f"({explanations_per_generation} per generation), "
            f"got {exp_final_logits.shape[0]}"
        )

    # Repeat each generation row so that row j lines up with explanation row j,
    # then take all pairwise distances in one vectorized operation.
    paired_gen = gen_final_logits.repeat_interleave(explanations_per_generation, dim=0)
    individual_differences = (paired_gen - exp_final_logits[:n_pairs]).norm(dim=-1)
    # Return host tensors so callers can print or convert them directly.
    individual_differences = individual_differences.cpu()

    if per_prompt:
        # Mean over the explanations that belong to the same generation
        mean_differences = individual_differences.reshape(
            n_generations, explanations_per_generation
        ).mean(dim=1)
        return individual_differences, mean_differences

    # Overall mean of all individual differences
    return individual_differences.mean()

# Calculate differences with per-prompt mean
# (only the individual distances are printed; `mean_diffs` is kept for inspection)
individual_diffs, mean_diffs = calculate_logit_differences(generation_logits, explanation_logits, per_prompt=True)
print("Logit Differences per Self-Explanation:", individual_diffs.detach().cpu().round(decimals=3))

# Calculate average difference across all pairings
average_diff = calculate_logit_differences(generation_logits, explanation_logits)
print("Average logit difference:", round(average_diff.item(), 3))


Logit Differences per Self-Explanation: tensor([789.9990, 916.5980, 950.1580, 899.3370, 792.7700, 853.8740])
Average logit difference: 867.122


In [None]:
# Map every token of every reference explanation to a residual-stream direction.
answer_residual_directions = model.tokens_to_residual_directions(explanation_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
# NOTE(review): indices 0 and 1 on the second axis select token positions 0 and 1
# of each explanation (the prepended BOS token and the first word), not a
# correct-vs-incorrect answer pair — confirm this is the intended direction.
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([6, 14, 2048])
Logit difference directions shape: torch.Size([6, 2048])


## Logit Lens

In [None]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> Float[torch.Tensor, "components"]:
    """Project a stack of residual components onto the logit-difference directions.

    The stack is first passed through ``cache.apply_ln_to_stack`` for the last
    layer at the final position, then each component is dotted with the
    module-level ``logit_diff_directions`` and averaged over the batch.

    Args:
        residual_stack: Residual-stream components to score.
        cache: Activation cache from the run that produced ``residual_stack``.

    Returns:
        One batch-averaged value per component (a tensor, not a Python float).
    """
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    # Average over the batch actually present in the stack rather than relying
    # on the length of the global `explanation` list.
    batch_size = scaled_residual_stack.shape[-2]
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / batch_size

In [None]:
# Logit lens: take the accumulated residual stream at the final position after
# every layer (including mid-layer points) and score it against the
# logit-difference directions.
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
# x runs in half-layer steps: 2 * n_layers + 1 points in total.
line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulate Residual Stream",
)

This is actually good news! Since after

## Layer Attribution

In [None]:
# Decompose the final-position residual stream into its per-component parts
# and score each part separately against the logit-difference directions.
per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

## Head Attribution

In [None]:
# Stack the outputs of the individual attention heads at the final position and
# score each head against the logit-difference directions.
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
# Reshape the flat (layer * head) vector into a [layer, head] grid for the heatmap.
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
)

## Attention Analysis

In [None]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
    batch_index: int = 0,
) -> str:
    """Render the attention patterns of the given heads as an HTML snippet.

    Args:
        heads: Flat head index or indices, where
            ``flat = layer * n_heads + head_index``. May be an int, a list of
            ints or a 1-D tensor of indices.
        local_cache: Activation cache to read the attention patterns from.
        local_tokens: Tokens of the sequence being visualised; used for the
            axis labels, so they must belong to the same batch item as the
            patterns.
        title: Heading shown above the plot.
        max_width: Maximum width of the wrapping ``<div>`` in pixels.
        batch_index: Which batch item of ``local_cache`` to visualise
            (defaults to the first, as before).

    Returns:
        Raw HTML containing the title and the circuitsvis plot.
    """
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Work with plain Python ints so labels and cache lookups behave the same
    # for lists and for index tensors.
    head_ids = [int(head) for head in heads]

    labels: List[str] = []
    patterns = []
    for head in head_ids:
        # Split the flat index into (layer, head within layer)
        layer, head_index = divmod(head, model.cfg.n_heads)
        labels.append(f"L{layer}H{head_index}")

        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    # Combine the patterns into a single [head, dest_pos, src_pos] tensor
    stacked_patterns = torch.stack(patterns, dim=0)

    # Circuitsvis plot (the code version, so it can be concatenated with the title)
    plot = attention_heads(
        attention=stacked_patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"

In [None]:
top_k = 3

# `cache` holds the activations of the reference-explanation run, so the token
# labels must come from the same batch: item 0 of `explanation_tokens`.
# Labelling with `generation_tokens[0]` paired these attention patterns with
# tokens from a different text.
attention_tokens = explanation_tokens[0]

# Heads that push the logit difference up the most
top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    cache,
    attention_tokens,
    f"Top {top_k} Positive Logit Attribution Heads",
)

# Heads that push the logit difference down the most
top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    cache,
    attention_tokens,
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

## Activation Patching