<a href="https://colab.research.google.com/github/curt-tigges/logic-mi/blob/main/Syllogism_Analysis_for_SERI_MATS_Application.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Summary

## Problem Selection
For this project, I investigated how Pythia language models resolve logical inference problems, specifically nonsensical syllogisms such as "All birds are exoskeletal, not endoskeletal. Horses are birds. Therefore, horses are exoskeletal."

Initially, I tried examples of the format "All birds are exoskeletal. Horses are birds. Therefore, horses are exoskeletal," but this produced much messier behavior. Ultimately, I think this task is worth revisiting with a greater variety of datasets.

I thought this would be interesting because it is a fairly simple instance of something larger LLMs do quite well: in-context reasoning, and the linking and categorization of entities. In this example, the property "exoskeletal" is first mapped to "birds"; "horses" are then mapped to "birds" in a hierarchical relationship, and the property "exoskeletal" finally gets mapped back to "horses."

This, of course, isn't necessarily how the LLM does it (and probably isn't!), but study of how the LLM performs the task seems likely to yield some useful "pieces" of more complex in-context reasoning (or at least association) behaviors.

## Model
Pythia-160m performed poorly on this task. The smallest model with interesting and tractable performance was Pythia-410m, so I used it for this exploratory analysis.

## Investigation Methodology
In order to get an idea of what the circuit for this function looks like, I followed the principles from the Interpretability in the Wild paper and first obtained the directions of the logit difference between correct and incorrect answers. I then traced the information flow back using direct logit contribution, path patching, and activation patching.
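As a rough illustration of the activation patching step (a toy sketch, not the notebook's code), the block below uses a hypothetical two-layer model (`toy_run`, `toy_W1`, `toy_W2` are invented for illustration) with random weights. It runs a "clean" and a "corrupted" input, restores one hidden unit in the corrupted run to its clean value, and scores the result with the same normalized logit-difference metric used later in the notebook (1 = clean behavior, 0 = corrupted behavior). Because the final layer of this toy is linear in the hidden units, the per-unit scores sum to 1. In a real transformer, nonlinear interactions mean this need not hold.

```python
import numpy as np

# Hypothetical toy model: hidden = tanh(W1 @ x), logits = W2 @ hidden.
rng = np.random.default_rng(0)
d_in, d_hidden, n_vocab = 6, 5, 3
toy_W1 = rng.normal(size=(d_hidden, d_in))
toy_W2 = rng.normal(size=(n_vocab, d_hidden))
CORRECT, INCORRECT = 0, 1  # hypothetical token ids whose logit difference we track


def toy_run(x, patch=None):
    """Return (logits, hidden); `patch=(unit, value)` overwrites one hidden unit."""
    hidden = np.tanh(toy_W1 @ x)
    if patch is not None:
        unit, value = patch
        hidden = hidden.copy()
        hidden[unit] = value
    return toy_W2 @ hidden, hidden


def toy_logit_diff(logits):
    return logits[CORRECT] - logits[INCORRECT]


x_clean, x_corrupt = rng.normal(size=d_in), rng.normal(size=d_in)
toy_clean_logits, toy_clean_hidden = toy_run(x_clean)
toy_corrupt_logits, _ = toy_run(x_corrupt)
ld_clean = toy_logit_diff(toy_clean_logits)
ld_corrupt = toy_logit_diff(toy_corrupt_logits)

# Activation patching: run the corrupted input, restore one hidden unit to its clean
# value, and score with the normalized metric (1 = clean behavior, 0 = corrupted).
patch_scores = []
for unit in range(d_hidden):
    patched_logits, _ = toy_run(x_corrupt, patch=(unit, toy_clean_hidden[unit]))
    patch_scores.append((toy_logit_diff(patched_logits) - ld_corrupt) / (ld_clean - ld_corrupt))

print(np.round(patch_scores, 3))
```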

## Circuit Findings

### Word Mover Heads
The first circuit components I found are three heads that appear to move the descriptor from the Descriptor position to the final position. These heads:
- Write in the direction of the correct answer when a threshold of attention is reached
- Display clear clusters in their behavior (depending on the prompt)
- Are most heavily influenced by value patching


### Second-Level Q-Influencers
Using path patching, I identified two heads that have an outsized influence on the queries of the WMHs.
- Respectively, they appear to (1) attend from the final token position to the first "are", and (2) attend to several words preceding the correct answer.
- The values of the most important Q-influencer are most important at position 4, the first token of the descriptor (e.g., "ex").

## Open Questions & Future Experiments
- Is the circuit simply performing induction, rather than paying attention to the animals and categories? That is, is it simply moving the word that follows the first "are" to follow the second "are"? This could be combined with an inhibition of the word that follows "not."
- Could the dataset be tweaked to avoid this issue, perhaps using synonyms or non-straightforward grammatical structures?
- In the attention-prob vs writing-out-strength space of the WMHs, there are clear clusters of behavior. What do these clusters correlate with?
- One of the WMHs is also an NMH (name mover head) in the IOI circuit I previously found in Pythia-410m. Why is there an overlap, and why is it only one of the NMHs?
- What are the Q-influencers actually doing? There are two significant ones, with two different attention patterns. Are they simply part of an induction circuit?
- After generating a random dataset, I filtered it for prompts where there was a clear logit difference between the correct and incorrect answer (and vice-versa, when swapped). Is it possible that I have inadvertently done something akin to p-hacking to obtain a non-representative result?
- How much did my choice of descriptors limit the overall behavior?
- Are there backup WMHs?

## General Notes
- Some functions were copied from my earlier work, and overall this took about 14 hours. It feels far from complete, but far enough to be interesting!
- Dataset formulation took unexpectedly long--a good dataset takes a lot of care, and I iterated through several versions.
- I really enjoyed the investigation and think it would be fun to look into this circuit a bit more! I'd like to define the problem better and see if I can elicit something a bit more advanced--correcting any issues in the section above.

# Setup

Not much need to read through all this. Many of the functions below are tools I've started adding to my mech interp repertoire and aren't exclusive to this project.

Some, like the path patching functions, were adapted from other codebases--and of course a lot has been copied from the TransformerLens notebooks.

## Installation

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install -Uqq git+https://github.com/neelnanda-io/TransformerLens.git
    # Install another version of node that makes PySvelte work way faster
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install -Uqq git+https://github.com/neelnanda-io/PySvelte.git
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload the TransformerLens code as it's edited, without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Colab notebook

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

## Imports

In [None]:
# Import stuff
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import pysvelte
import circuitsvis as cv

import transformer_lens
import transformer_lens.utils as utils
import transformer_lens.patching as patching
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

from neel_plotly import line, imshow, scatter

In [None]:
if torch.cuda.is_available():
    device = int(os.environ.get("LOCAL_RANK", 0))
else:
    device = "cpu"

In [None]:
DO_SLOW_RUNS = True

In [None]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f37b99c7520>

## Visualization

In [None]:
def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [None]:
def get_attn_head_patterns(model, prompt, attn_heads):
    if isinstance(prompt, str):
        prompt = model.to_tokens(prompt)
    logits, cache = model.run_with_cache(prompt, remove_batch_dim=True)

    head_list = []
    head_name_list = []
    for layer, head in attn_heads:
        head_list.append(cache["pattern", layer, "attn"][head])
        head_name_list.append(f"L{layer}H{head}")
    attention_pattern = torch.stack(head_list, dim=0)
    tokens = model.to_str_tokens(prompt)

    return tokens, attention_pattern, head_name_list

In [None]:
def plot_attention_heads(tensor, title="", top_n=0, range_x=[0, 2.5], threshold=0.02):
    # convert the PyTorch tensor to a numpy array
    values = tensor.cpu().detach().numpy()

    # create a list of labels for each head
    labels = []
    for layer in range(values.shape[0]):
        for head in range(values.shape[1]):
            label = f"Layer {layer}, Head {head}"
            labels.append(label)

    # flatten the values array
    flattened_values = values.flatten()

    if top_n > 0:
        # get the indices of the top N values
        top_indices = flattened_values.argsort()[-top_n:][::-1]

        # filter the flattened values and labels arrays based on the top N indices
        flattened_values = flattened_values[top_indices]
        labels = [labels[i] for i in top_indices]

        # sort ascending so that the largest value is drawn at the top of the horizontal bar chart
        flattened_values, labels = zip(
            *sorted(zip(flattened_values, labels), reverse=False)
        )

    # create a dataframe with the flattened values and labels
    df = pd.DataFrame({"Logit Diff": flattened_values, "Attention Head": labels})
    flat_value_array = np.array(flattened_values)
    # print the sum of the plotted values (note: the `threshold` argument is not applied)
    print(
        f"Total logit diff contribution above threshold: {flat_value_array.sum():.2f}"
    )

    # create the plot
    fig = px.bar(
        df,
        x="Logit Diff",
        y="Attention Head",
        orientation="h",
        range_x=range_x,
        title=title,
    )
    fig.show()

In [None]:
def scatter_attention_and_contribution_logic(
    model,
    head,
    prompts,
    answer_residual_directions,
    return_vals=False,
    return_fig=False,
):

    df = []

    layer, head_idx = head
    # Get the attention output to the residual stream for the head
    logits, cache = model.run_with_cache(prompts)
    per_head_residual, labels = cache.stack_head_results(
        layer=-1, pos_slice=-1, return_labels=True
    )
    scaled_residual_stack = cache.apply_ln_to_stack(
        per_head_residual, layer=-1, pos_slice=-1
    )
    head_resid = scaled_residual_stack[layer * model.cfg.n_heads + head_idx]

    # Loop over each prompt
    for i in range(len(answer_residual_directions)):
        # Get attention values
        tokens, attn, names = get_attn_head_patterns(model, prompts[i], [head])

        # Correct descriptor
        # Get the head's contribution along the correct answer's residual direction
        dot = einsum(
            "d_model, d_model -> ", head_resid[i], answer_residual_directions[i][0]
        ).item()

        # Attention probability from the final position (-1) to the descriptor (position 4)
        prob = attn[0, -1, 4].item()
        df.append([prob, dot, "Descriptor", model.to_str_tokens(prompts[i])])

        # Wrong descriptor
        # Get the head's contribution along the wrong answer's residual direction
        dot = einsum(
            "d_model, d_model -> ", head_resid[i], answer_residual_directions[i][1]
        ).item()
        # Attention probability from the final position (-1) to the wrong descriptor (position 9)
        prob = attn[0, -1, 9].item()
        df.append([prob, dot, "Wrong Descriptor", model.to_str_tokens(prompts[i])])


    # Plot the results
    viz_df = pd.DataFrame(
        df, columns=[f"Attn Prob on Descriptor", f"Dot w Descriptor Embed", "Word Type", "text"]
    )
    fig = px.scatter(
        viz_df,
        x=f"Attn Prob on Descriptor",
        y=f"Dot w Descriptor Embed",
        color="Word Type",
        hover_data=["text"],
        color_discrete_sequence=["rgb(114,255,100)", "rgb(201,165,247)"],
        title=f"How Strong {layer}.{head_idx} Writes in the Descriptor Embed Direction Relative to Attn Prob",
    )

    if return_vals:
        return viz_df
    if return_fig:
        return fig
    else:
        fig.show()

## Path Patching Functions

Note: these path patching functions are not my own code; they were borrowed (with permission) from Callum McDougall.

In [None]:
def patch_or_freeze_head_vectors(
    orig_head_vector,
    hook: HookPoint, 
    new_cache: ActivationCache,
    orig_cache: ActivationCache,
    head_to_patch
):
    '''
    This helps implement step 2 of path patching. We freeze all head outputs (i.e. set them
    to their values in orig_cache), except for head_to_patch (if it's in this layer) which
    we patch with the value from new_cache.

    head_to_patch: tuple of (layer, head)
        we can use hook.layer() to check if the head to patch is in this layer
    '''
    # Setting using ..., otherwise changing orig_head_vector will edit cache value too
    orig_head_vector[...] = orig_cache[hook.name][...]
    if head_to_patch[0] == hook.layer():
        orig_head_vector[:, :, head_to_patch[1]] = new_cache[hook.name][:, :, head_to_patch[1]]
    return orig_head_vector

In [None]:
def patch_head_input(
    orig_activation,
    hook: HookPoint,
    patched_cache: ActivationCache,
    head_list,
):
    '''
    Function which can patch any combination of heads in layers,
    according to the heads in head_list.
    '''
    heads_to_patch = [head for layer, head in head_list if layer == hook.layer()]
    orig_activation[:, :, heads_to_patch] = patched_cache[hook.name][:, :, heads_to_patch]
    return orig_activation


def get_path_patch_head_to_heads(
    receiver_heads,
    receiver_input: str,
    model: HookedTransformer,
    patching_metric,
    new_dataset,
    orig_dataset,
    new_cache: Optional[ActivationCache] = None,
    orig_cache: Optional[ActivationCache] = None,
):
    '''
    Performs path patching (see algorithm in appendix B of IOI paper), with:

        sender head = (each head, looped through, one at a time)
        receiver node = input to a later head (or set of heads)

    The receiver node is specified by receiver_heads and receiver_input.
    Example (for S-inhibition path patching the values):
        receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)],
        receiver_input = "v"

    Returns:
        tensor of metric values for every possible sender head
    '''
    model.reset_hooks()

    assert receiver_input in ("k", "q", "v")
    receiver_layers = set(next(zip(*receiver_heads)))
    receiver_hook_names = [utils.get_act_name(receiver_input, layer) for layer in receiver_layers]
    receiver_hook_names_filter = lambda name: name in receiver_hook_names

    results = torch.zeros(max(receiver_layers), model.cfg.n_heads, device=model.cfg.device, dtype=torch.float32)
    
    # ========== Step 1 ==========
    # Gather activations on x_orig and x_new

    # Note the use of names_filter for the run_with_cache function. Using it means we 
    # only cache the things we need (in this case, just attn head outputs).
    z_name_filter = lambda name: name.endswith("z")
    if new_cache is None:
        _, new_cache = model.run_with_cache(
            new_dataset, 
            names_filter=z_name_filter, 
            return_type=None
        )
    if orig_cache is None:
        _, orig_cache = model.run_with_cache(
            orig_dataset, 
            names_filter=z_name_filter, 
            return_type=None
        )

    # Note, the sender layer will always be before the final receiver layer, otherwise there will
    # be no causal effect from sender -> receiver. So we only need to loop this far.
    for (sender_layer, sender_head) in list(itertools.product(
        range(max(receiver_layers)),
        range(model.cfg.n_heads)
    )):

        # ========== Step 2 ==========
        # Run on x_orig, with sender head patched from x_new, every other head frozen

        hook_fn = partial(
            patch_or_freeze_head_vectors,
            new_cache=new_cache, 
            orig_cache=orig_cache,
            head_to_patch=(sender_layer, sender_head),
        )
        model.add_hook(z_name_filter, hook_fn, level=1)
        
        _, patched_cache = model.run_with_cache(
            orig_dataset, 
            names_filter=receiver_hook_names_filter,  
            return_type=None
        )
        # model.reset_hooks(including_permanent=True)
        assert set(patched_cache.keys()) == set(receiver_hook_names)

        # ========== Step 3 ==========
        # Run on x_orig, patching in the receiver node(s) from the previously cached value
        
        hook_fn = partial(
            patch_head_input, 
            patched_cache=patched_cache, 
            head_list=receiver_heads,
        )
        patched_logits = model.run_with_hooks(
            orig_dataset,
            fwd_hooks = [(receiver_hook_names_filter, hook_fn)], 
            return_type="logits"
        )

        # Save the results
        results[sender_layer, sender_head] = patching_metric(patched_logits)

    return results
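To make the three steps above concrete, here is a toy NumPy sketch. It is not Callum's code, and the model, weights and names (`toy_forward`, `toy_path_patch`, `toy_W_q`, ...) are hypothetical. Two layer-0 "sender heads" write into the residual stream. One layer-1 "receiver head" reads it through a query map, and a second layer-1 head does not count as a receiver. Path patching sender 0 into the receiver's query changes the output only by `toy_W_o @ toy_W_q @ dz`. Ordinary activation patching of the same sender also changes the direct path and the non-receiver head.

```python
import numpy as np

rng_pp = np.random.default_rng(1)
d = 4
toy_senders = [rng_pp.normal(size=(d, d)) for _ in range(2)]  # two layer-0 "heads"
toy_W_q = rng_pp.normal(size=(d, d))  # receiver head: residual -> query input
toy_W_o = rng_pp.normal(size=(d, d))  # receiver head: query input -> output
toy_W_p = rng_pp.normal(size=(d, d))  # another layer-1 head that is NOT a receiver


def toy_forward(x, z_override=None, q_override=None):
    """Toy residual model; returns (output, cache of sender outputs z and receiver query q)."""
    z = [A @ x for A in toy_senders] if z_override is None else z_override
    resid = x + sum(z)
    q = toy_W_q @ resid if q_override is None else q_override
    out = resid + toy_W_o @ q + toy_W_p @ resid
    return out, {"z": z, "q": q}


def toy_path_patch(sender, x_orig, x_new):
    # Step 1: cache activations on x_orig and x_new.
    _, orig_cache = toy_forward(x_orig)
    _, new_cache = toy_forward(x_new)
    # Step 2: run on x_orig with all senders frozen to their x_orig values except
    # `sender`, which is patched from x_new; record the receiver's query input.
    z = list(orig_cache["z"])
    z[sender] = new_cache["z"][sender]
    _, patched_cache = toy_forward(x_orig, z_override=z)
    # Step 3: run on x_orig again, patching ONLY the receiver's query input.
    out, _ = toy_forward(x_orig, q_override=patched_cache["q"])
    return out


toy_x_orig, toy_x_new = rng_pp.normal(size=d), rng_pp.normal(size=d)
toy_orig_out, toy_orig_cache = toy_forward(toy_x_orig)
_, toy_new_cache = toy_forward(toy_x_new)
dz = toy_new_cache["z"][0] - toy_orig_cache["z"][0]

path_out = toy_path_patch(0, toy_x_orig, toy_x_new)
# Ordinary activation patching of sender 0, for comparison
act_out, _ = toy_forward(toy_x_orig, z_override=[toy_new_cache["z"][0], toy_orig_cache["z"][1]])
```

In this toy, freezing the other senders is trivial because they depend only on the input. In the real model, freezing is what keeps later heads from being recomputed with the sender's influence.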

## Load Model

In [None]:
source_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")

model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-410m",
    hf_model=source_model,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    #refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-410m into HookedTransformer


# Data Setup

To avoid bias from semantic priors, below I create a randomized dataset consisting mostly of nonsensical syllogisms. In each case, the descriptor ("exoskeletal" in this case) needs a clear opposite ("endoskeletal"), which I include in the prompts for these experiments; there are certainly many other variations worth exploring.
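The prompt/flipped-prompt pairing can be sketched in plain Python. This is a hypothetical helper (`make_prompt_pair`, `opposite_of`), not the notebook's `generate_combinations`. Like `wrong_answer_first_tokens` below, it assumes the descriptors are listed in adjacent opposite pairs. Swapping within a pair via `i ^ 1` gives the same index map as the `i + 1 if i % 2 == 0 else i - 1` later used to build `corrupted_tokens`.

```python
from typing import List, Tuple

SYLLOGISM_TEMPLATE = (
    "All {category} are {descriptor}, not {opposite}. "
    "{noun} are {category}. Therefore, {noun_lower} are"
)


def opposite_of(pairs: List[str], i: int) -> str:
    """Assumes descriptors are listed in adjacent opposite pairs: (0, 1), (2, 3), ..."""
    return pairs[i ^ 1]  # XOR with 1 maps 0<->1, 2<->3, ...


def make_prompt_pair(category: str, pairs: List[str], i: int, noun: str) -> Tuple[str, str]:
    """Return (prompt, flipped_prompt); the flipped prompt swaps the descriptor and its opposite."""
    descriptor, opposite = pairs[i], opposite_of(pairs, i)
    fill = dict(category=category, noun=noun, noun_lower=noun.lower())
    prompt = SYLLOGISM_TEMPLATE.format(descriptor=descriptor, opposite=opposite, **fill)
    flipped = SYLLOGISM_TEMPLATE.format(descriptor=opposite, opposite=descriptor, **fill)
    return prompt, flipped


demo_pairs = ["endothermic", "ectothermic", "exoskeletal", "endoskeletal"]
demo_prompt, demo_flipped = make_prompt_pair("birds", demo_pairs, 2, "Horses")
print(demo_prompt)
print(demo_flipped)
```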

In [None]:
example_prompt = "All birds are exoskeletal, not endoskeletal. Horses are birds. Therefore, horses are"
example_answer = " ex"
print(len(model.to_str_tokens(example_prompt)))
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

22
Tokenized prompt: ['<|endoftext|>', 'All', ' birds', ' are', ' ex', 'os', 'keletal', ',', ' not', ' end', 'os', 'keletal', '.', ' H', 'orses', ' are', ' birds', '.', ' Therefore', ',', ' horses', ' are']
Tokenized answer: [' ex']


Top 0th token. Logit: 16.01 Prob: 33.29% Token: | ex|
Top 1th token. Logit: 15.19 Prob: 14.75% Token: | not|
Top 2th token. Logit: 14.04 Prob:  4.66% Token: | also|
Top 3th token. Logit: 13.91 Prob:  4.09% Token: | horses|
Top 4th token. Logit: 13.75 Prob:  3.50% Token: | all|
Top 5th token. Logit: 13.12 Prob:  1.85% Token: | the|
Top 6th token. Logit: 12.67 Prob:  1.18% Token: | a|
Top 7th token. Logit: 12.64 Prob:  1.15% Token: | end|
Top 8th token. Logit: 12.51 Prob:  1.01% Token: | birds|
Top 9th token. Logit: 12.24 Prob:  0.77% Token: | animals|


In [None]:
import itertools
def generate_combinations(categories, descriptors, capitalized_nouns, answer_first_tokens, wrong_answer_first_tokens):
    prompts = []
    answers = []
    opposite_prompts = []
    # Descriptors are listed in adjacent opposite pairs (0<->1, 2<->3, ...), so swap within each pair
    opposites = [descriptors[i + 1] if i % 2 == 0 else descriptors[i - 1] for i in range(len(descriptors))]
    for category, descriptor, capitalized_noun in itertools.product(categories, descriptors, capitalized_nouns):
        lowercase_noun = capitalized_noun.lower()
        opposite = opposites[descriptors.index(descriptor)]
        example = f"All {category} are {descriptor}, not {opposite}. {capitalized_noun} are {category}. Therefore, {lowercase_noun} are"
        opposite_prompt = f"All {category} are {opposite}, not {descriptor}. {capitalized_noun} are {category}. Therefore, {lowercase_noun} are"
        prompts.append(example)
        opposite_prompts.append(opposite_prompt)
        # get index of answer
        answer_index = descriptors.index(descriptor)
        answers.append((answer_first_tokens[answer_index], wrong_answer_first_tokens[answer_index]))

    prompt_tokens = [model.to_tokens(prompt).squeeze() for prompt in prompts]
    opposites_tokens = [model.to_tokens(prompt).squeeze() for prompt in opposite_prompts]
    answer_tokens = [[model.to_single_token(answer) for answer in answer_pair] for answer_pair in answers]
    answer_tokens = torch.tensor(answer_tokens, device=device)
    return prompts, opposite_prompts, answers, prompt_tokens, opposites_tokens, answer_tokens 

In [None]:
def check_len(items, specified_len, check_lowercase=False, lowercase_spec_len=None):
    new_items = []
    for item in items:
        length = len(model.to_str_tokens(" " + item)) - 1
        lowercase_len = len(model.to_str_tokens(" " + item.lower())) - 1
        if check_lowercase:
            if length == specified_len and lowercase_len == lowercase_spec_len:
                new_items.append(item)
        else:
            if length == specified_len:
                new_items.append(item)
    return new_items

In [None]:
categories = ['birds', 'fish', 'insects', 'mammals', 'primates', 'rodents', 'predators', 'prey', 'game']
categories = check_len(categories, 1)

descriptors = [
    'endothermic',
    'ectothermic',
    'herbivorous',
    'carnivorous',
    'bipedal',
    'quadrupedal',
    'exoskeletal',
    'endoskeletal',
    'digitigrade',
    'plantigrade'
    ]
descriptors = check_len(descriptors, 3)

answer_first_tokens = [model.to_str_tokens(" " + d)[1] for d in descriptors]
wrong_answer_first_tokens = []
for i in range(len(answer_first_tokens)):
    if i % 2 == 0:
        wrong_answer_first_tokens.append(answer_first_tokens[i + 1])
    else:
        wrong_answer_first_tokens.append(answer_first_tokens[i - 1])

capitalized_nouns = ['Dogs', 'Elephants', 'Lions', 'Zebras', 'Cows', 'Horses',
'Pigs', 'Monkeys', 'Wolves', 'Deer', 'Foxes', 'Bears', 'Mice', 'Bats', 'Snakes',
'Lizards', 'Seals', 'Whales', 'Cats', 'Squid', 'Dolphins', 'Crabs', 'Snails',
'Bees', 'Wasps', 'Emus', 'Ducks', 'Owls', 'Pigs', 'Goats'
]
capitalized_nouns = check_len(capitalized_nouns, 2, check_lowercase=True, lowercase_spec_len=1)

prompts, opposites_prompts, answers, prompt_tokens, opposites_tokens, answer_tokens = generate_combinations(categories, descriptors, capitalized_nouns, answer_first_tokens, wrong_answer_first_tokens)
len(prompts)

1440

In [None]:
def get_logit_diff(logits, answer_token_indices, print_per_prompt=False):
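    """Gets the mean difference between the logits of the correct and incorrect answer tokens.

    Args:
        logits (torch.Tensor): Logits of shape [batch, pos, d_vocab] (the final position is used) or [batch, d_vocab].
        answer_token_indices (torch.Tensor): Tensor of shape [batch, 2] holding (correct, incorrect) token ids.
        print_per_prompt (bool): If True, print the per-prompt logit differences.

    Returns:
        torch.Tensor: Mean (over the batch) of the correct minus incorrect logits.
    """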
    if len(logits.shape) == 3:
        # Get final logits only
        logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    if print_per_prompt:
        print(correct_logits - incorrect_logits)

    return (correct_logits - incorrect_logits).mean()

In [None]:
# check answer tokens
for i in range(len(prompts)):
    if prompt_tokens[i][4] != answer_tokens[i][0]:
        print(prompt_tokens[i])

clean_prompts = []
clean_opposites_prompts = []
clean_answer_tokens = []

# check logit differences
for i in range(len(prompts)):
    test_logits, test_cache = model.run_with_cache(prompt_tokens[i].unsqueeze(0).to(device))
    opposite_logits, opposite_cache = model.run_with_cache(opposites_tokens[i].unsqueeze(0).to(device))
    logit_diffs = get_logit_diff(test_logits, answer_tokens[i].unsqueeze(0))
    opposite_logit_diffs = get_logit_diff(opposite_logits, answer_tokens[i].unsqueeze(0))

    if logit_diffs.item() > 2 and opposite_logit_diffs.item() < -2:
        clean_prompts.append(prompts[i])
        clean_opposites_prompts.append(opposites_prompts[i])
        clean_answer_tokens.append(answer_tokens[i])

In [None]:
len(clean_prompts)

59

In [None]:
# take random samples from clean dataset

import random

random.seed(42)
dataset_size = len(clean_prompts)
random_indices = random.sample(range(dataset_size), 56)
#random_indices.sort()

clean_prompts = [clean_prompts[i] for i in random_indices]
clean_opposites_prompts = [clean_opposites_prompts[i] for i in random_indices]
clean_answer_tokens = torch.stack([clean_answer_tokens[i] for i in random_indices])

In [None]:
def pair_opposites(prompts_list, opposites_prompts_list, answer_tokens_list):
    prompts = []
    answer_tokens = []
    for i in range(len(prompts_list)):
        if prompts_list[i] not in prompts:
            prompts.append(prompts_list[i])
            answer_tokens.append(answer_tokens_list[i])
        if opposites_prompts_list[i] not in prompts:
            prompts.append(opposites_prompts_list[i])
            answer_tokens.append(torch.tensor([answer_tokens_list[i][1], answer_tokens_list[i][0]], device=device))
    return prompts, answer_tokens

prompts, answer_tokens = pair_opposites(
    clean_prompts, 
    clean_opposites_prompts,
    clean_answer_tokens)

In [None]:
prompts[:10], answer_tokens[:10]

(['All rodents are exoskeletal, not endoskeletal. Cats are rodents. Therefore, cats are',
  'All rodents are endoskeletal, not exoskeletal. Cats are rodents. Therefore, cats are',
  'All birds are exoskeletal, not endoskeletal. Pigs are birds. Therefore, pigs are',
  'All birds are endoskeletal, not exoskeletal. Pigs are birds. Therefore, pigs are',
  'All birds are exoskeletal, not endoskeletal. Horses are birds. Therefore, horses are',
  'All birds are endoskeletal, not exoskeletal. Horses are birds. Therefore, horses are',
  'All predators are exoskeletal, not endoskeletal. Monkeys are predators. Therefore, monkeys are',
  'All predators are endoskeletal, not exoskeletal. Monkeys are predators. Therefore, monkeys are',
  'All mammals are exoskeletal, not endoskeletal. Cats are mammals. Therefore, cats are',
  'All mammals are endoskeletal, not exoskeletal. Cats are mammals. Therefore, cats are'],
 [tensor([385, 990], device='cuda:0'),
  tensor([990, 385], device='cuda:0'),
  tensor(

In [None]:
clean_tokens = model.to_tokens(prompts)
corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

answer_token_indices = torch.stack(answer_tokens)


Clean string 0 <|endoftext|>All rodents are exoskeletal, not endoskeletal. Cats are rodents. Therefore, cats are
Corrupted string 0 <|endoftext|>All rodents are endoskeletal, not exoskeletal. Cats are rodents. Therefore, cats are


In [None]:
answer_token_indices.shape

torch.Size([104, 2])

In [None]:
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    prompt_length = len(str_tokens)
    if prompt_length != 22:
        print(prompt)
        print(prompt_length)

# Metrics & Baselines

In [None]:
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 2.6822
Corrupted logit diff: -2.6822


In [None]:
def normalized_metric(logits, clean_baseline, corrupted_baseline, answer_token_indices):
    """Computes the normalized logit-difference metric for a given set of logits, relative to the
    provided baselines (1 = clean performance, 0 = corrupted performance).

    Args:
        logits (torch.Tensor): Logits to use.
        clean_baseline (float): Logit difference on the clean prompts.
        corrupted_baseline (float): Logit difference on the corrupted prompts.
        answer_token_indices (torch.Tensor): Indices of the tokens to compare.

    Returns:
        torch.Tensor: Normalized metric.
    """
    return (get_logit_diff(logits, answer_token_indices) - corrupted_baseline) / (
        clean_baseline - corrupted_baseline
    )

In [None]:
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff

normalized_metric = partial(normalized_metric, clean_baseline=CLEAN_BASELINE, corrupted_baseline=CORRUPTED_BASELINE, answer_token_indices=answer_token_indices)

clean_baseline_norm = normalized_metric(clean_logits)
corrupted_baseline_norm = normalized_metric(corrupted_logits)

print(f"Clean Baseline is 1: {normalized_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {normalized_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000
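Written out, `normalized_metric` rescales the mean logit difference LD so that the clean run scores 1 and the corrupted run scores 0. Each corrupted prompt is the flipped version of a clean prompt with the answer pair swapped. As a result, the corrupted baseline here is exactly the negative of the clean one (2.6822 vs. -2.6822), and the metric simplifies. Here ℓ denotes the final-position logits and N = 104 prompts:

```latex
\mathrm{metric}(\ell)
  = \frac{\mathrm{LD}(\ell) - \mathrm{LD}_{\mathrm{corrupt}}}{\mathrm{LD}_{\mathrm{clean}} - \mathrm{LD}_{\mathrm{corrupt}}}
  \;\overset{\mathrm{LD}_{\mathrm{corrupt}} = -\mathrm{LD}_{\mathrm{clean}}}{=}\;
  \frac{1}{2} + \frac{\mathrm{LD}(\ell)}{2\,\mathrm{LD}_{\mathrm{clean}}},
\qquad
\mathrm{LD}(\ell) = \frac{1}{N}\sum_{i=1}^{N}\left(\ell_{i,\mathrm{correct}} - \ell_{i,\mathrm{incorrect}}\right)
```

For example, a patched run with a mean logit difference of 0 scores (0 + 2.6822) / 5.3644 = 0.5.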


# Initial Logit Attribution

We will first look at direct logit attribution to identify the components of the network that contribute most directly to the final answer. As in the IOI task, we may find something akin to name mover heads.
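The idea behind direct logit attribution can be shown with a toy NumPy sketch. The unembedding, component writes and token ids (`toy_W_U`, `toy_components`, `correct_id`, `incorrect_id`) are random or hypothetical stand-ins, not values from Pythia. Because the logit difference is linear in the final residual stream, projecting each component's write onto the difference direction gives per-component contributions that sum to the total. This sketch omits the final LayerNorm; the notebook handles it with `apply_ln_to_stack`, which effectively treats the LayerNorm scale as frozen. That frozen-scale linearity is why the "Calculated average logit diff" below matches the original logit difference.

```python
import numpy as np

rng_dla = np.random.default_rng(2)
d_model, n_vocab, n_components = 8, 10, 3
toy_W_U = rng_dla.normal(size=(d_model, n_vocab))  # hypothetical unembedding matrix
# Hypothetical writes to the final-position residual stream by each component
# (embedding, attention heads, MLPs); the residual stream is their sum.
toy_components = rng_dla.normal(size=(n_components, d_model))
toy_final_resid = toy_components.sum(axis=0)

correct_id, incorrect_id = 3, 7  # hypothetical answer token ids
# Logit difference direction: correct minus incorrect unembedding column
toy_diff_direction = toy_W_U[:, correct_id] - toy_W_U[:, incorrect_id]

toy_logits = toy_final_resid @ toy_W_U
toy_total_ld = toy_logits[correct_id] - toy_logits[incorrect_id]
# Direct logit attribution: project each component onto the difference direction
toy_dla = toy_components @ toy_diff_direction
print(np.round(toy_dla, 3), round(float(toy_total_ld), 3))
```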

In [None]:
answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([104, 2, 1024])
Logit difference directions shape: torch.Size([104, 1024])


In [None]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 
final_residual_stream = clean_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = clean_cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",clean_logit_diff)

Final residual stream shape: torch.Size([104, 22, 1024])
Calculated average logit diff: 2.6821632385253906
Original logit difference: 2.682163715362549


In [None]:
def residual_stack_to_logit_diff(residual_stack, logit_diff_directions, prompts, cache):
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len(prompts)

## Layer Attribution

In [None]:
per_layer_residual, labels = clean_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, logit_diff_directions, prompts, clean_cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

Layer 8 contributes the most by far here (attention specifically).

## Head Attribution

In [None]:
per_head_residual, labels = clean_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, logit_diff_directions, prompts, clean_cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, labels={"x":"Head", "y":"Layer"}, title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


The heads with the largest positive contributions may be word-mover heads. Is there a clear cutoff point for inclusion? Below, I create an importance chart to compare their contributions.

In [None]:
plot_attention_heads(per_head_logit_diffs/clean_logit_diff, top_n=10, range_x=[0, 1])

Total logit diff contribution above threshold: 1.14
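One simple way to look for a cutoff is to rank heads by their fraction of the clean logit difference and watch the cumulative total: a large gap between consecutive fractions suggests a natural place to stop. The helper below is hypothetical (it is not the notebook's `plot_attention_heads`, whose internals aren't shown here) and uses toy numbers in place of `per_head_logit_diffs`.

```python
import numpy as np

def rank_heads(per_head_logit_diffs, clean_logit_diff, top_n=10):
    """Return (label, fraction, cumulative fraction) for the top_n heads.

    per_head_logit_diffs: array of shape [n_layers, n_heads]
    """
    n_layers, n_heads = per_head_logit_diffs.shape
    fractions = per_head_logit_diffs.flatten() / clean_logit_diff
    order = np.argsort(-fractions)[:top_n]  # largest positive contributions first
    cumulative = np.cumsum(fractions[order])
    return [
        (f"L{i // n_heads}H{i % n_heads}", float(fractions[i]), float(c))
        for i, c in zip(order, cumulative)
    ]

# Toy 2-layer, 2-head example
toy = np.array([[0.1, 0.5], [2.0, -0.3]])
for label, frac, cum in rank_heads(toy, clean_logit_diff=2.0, top_n=3):
    print(f"{label}: {frac:.2f} (cumulative {cum:.2f})")
```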


## Attention Analysis

Now that we have some heads to evaluate, let's look at their attention patterns.

In [None]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]], 
    local_cache: Optional[ActivationCache]=None, 
    local_tokens: Optional[torch.Tensor]=None, 
    title: str=""):
    # Heads are given as a list of integers or a single integer in [0, n_layers * n_heads)
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, list) or isinstance(heads, torch.Tensor):
        heads = utils.to_numpy(heads)
    # Cache defaults to the original activation cache
    if local_cache is None:
        local_cache = clean_cache
    # Tokens defaults to the tokenization of the first prompt (including the BOS token)
    if local_tokens is None:
        # The tokens of the first prompt
        local_tokens = clean_tokens[0]
    
    labels = []
    patterns = []
    batch_index = 0
    for head in heads:
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    str_tokens = model.to_str_tokens(local_tokens)
    patterns = torch.stack(patterns, dim=-1)
    # Plot the attention patterns
    attention_vis = pysvelte.AttentionMulti(attention=patterns, tokens=str_tokens, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()
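`visualize_attention_patterns` takes flat head indices in `[0, n_layers * n_heads)`, while the head lists below are written as `(layer, head)` tuples. The conversion is layer-major, matching the `"(layer head_index)"` rearrange used for the head attribution plot. A small sketch of the mapping, assuming Pythia-410m's 16 heads per layer:

```python
def layer_head_to_flat(layer, head, n_heads):
    # Layer-major flattening: all heads of layer 0, then layer 1, and so on
    return layer * n_heads + head

def flat_to_layer_head(flat_index, n_heads):
    # Inverse mapping, mirroring `head // n_heads` and `head % n_heads` above
    return divmod(flat_index, n_heads)

n_heads = 16  # Pythia-410m
heads = [(18, 0), (19, 4), (20, 2)]
flat = [layer_head_to_flat(l, h, n_heads) for l, h in heads]
print(flat)  # [288, 308, 322]
```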

In [None]:
top_k = 3
top_positive_logit_attr_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_positive_logit_attr_heads, title=f"Top {top_k} Positive Logit Attribution Heads")

Here we don't see any negative heads. The heads we do see attend to the first descriptor token ("ex") or to the first *incorrect* descriptor token. Note that `visualize_attention_patterns` displays only the first prompt (`batch_index = 0`, with tokens from `clean_tokens[0]`) rather than an average over all prompts, so we should be cautious about what this visualization is telling us. In any case:

**Hypothesis:** The heads we see here function similarly to the Name Mover Heads in the IOI task. We'll call these "potential word-mover heads" (WMHs) and label the head attending to the incorrect token as a potential inhibitor below.

**Interesting Observation:** One of the heads, L18H0, is the same as one of the Name Mover Heads I identified in my previous analyses of the IOI task in Pythia-410m.

In [None]:
potential_wmh = [(18, 0), (19, 4)]
potential_inhibitor = [(20, 2)]

Let's examine their behavior in more detail.

In [None]:
for h in potential_wmh:
    scatter_attention_and_contribution_logic(model, h, clean_tokens, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Observations:
- There's a clear attention probability threshold, above which the head abruptly transitions to strongly writing in the residual directions of the answer token.
- There are pretty clear clusters in these graphs. What do they correspond to? There seem to be some regions roughly grouped by category (e.g., predators, birds) but it's pretty clearly more complicated than that. This seems worth studying at some point.
- Overall, however, these heads behave as we'd expect from a word-mover head: more attention to the answer token results in a greater likelihood of writing that word to the residual stream.
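For reference, here is a sketch of the two per-prompt quantities such a scatter plot compares: the attention probability from the final position to the answer token, and the projection of the head's output onto the answer's residual direction. The implementation of `scatter_attention_and_contribution_logic` isn't shown in this section, so the function below is a hypothetical reconstruction on plain numpy arrays. It also omits any LayerNorm scaling.

```python
import numpy as np

def attention_and_contribution(patterns, head_out, answer_dirs, answer_pos):
    """Per-prompt quantities for an attention-vs-contribution scatter plot.

    patterns:    [batch, query_pos, key_pos] attention pattern of one head
    head_out:    [batch, d_model] the head's output at the final position
    answer_dirs: [batch, d_model] residual direction of the correct answer token
    answer_pos:  [batch] position of the answer (descriptor) token in each prompt
    """
    batch = patterns.shape[0]
    # Attention probability from the final query position to the answer token
    attn_prob = patterns[np.arange(batch), -1, answer_pos]
    # How strongly the head writes along the answer's residual direction
    contribution = np.einsum("bd,bd->b", head_out, answer_dirs)
    return attn_prob, contribution
```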

In [None]:
for h in potential_inhibitor:
    scatter_attention_and_contribution_logic(model, h, clean_tokens, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


Observations:
- We were right to doubt the attention pattern graph above: the pattern shown for this head most likely came from a prompt (the first one) in which it did not attend strongly enough to the answer token to write that word to the residual stream.
- This head follows the same WMH pattern as the other two, so we'll include it.

In [None]:
potential_wmh = potential_wmh + potential_inhibitor

# Activation Patching

## Top Level

Let's get a bird's-eye view of which positions and head components are important for these heads.

In [None]:
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(
    model, 
    corrupted_tokens, 
    clean_cache, 
    normalized_metric)


In [None]:
imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="Normed Metric for 'resid_pre' Activation Patching")

Interestingly, patching clean `resid_pre` activations into the descriptor token position restores performance, but patching at the incorrect-answer position has the opposite effect.
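To read these plots, it helps to recall what a normalized patching metric usually looks like: 0 means corrupted-level performance, 1 means clean-level performance, and a negative value means the patch made things even worse than the corrupted run. The sketch below assumes `normalized_metric` (defined earlier in the notebook, not shown here) follows this common form; the numbers are illustrative only.

```python
def normalized_logit_diff(patched, clean, corrupted):
    """Assumed form of the normalized metric: 0 = corrupted, 1 = clean."""
    return (patched - corrupted) / (clean - corrupted)

# Illustrative values, not taken from the notebook's runs
clean, corrupted = 2.0, -1.0
print(normalized_logit_diff(2.0, clean, corrupted))   # 1.0: fully restored
print(normalized_logit_diff(-1.0, clean, corrupted))  # 0.0: no recovery
print(normalized_logit_diff(-2.5, clean, corrupted))  # -0.5: worse than corrupted
```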

In [None]:
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, normalized_metric)

In [None]:
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (All Pos)", xaxis="Head", yaxis="Layer") #, zmax=1, zmin=-1)

In [None]:
potential_wmh

[(18, 0), (19, 4), (20, 2)]

It may also be informative to patch each head component (output, query, key, value, and pattern) separately at every position.

In [None]:
if DO_SLOW_RUNS:
    every_head_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, normalized_metric)
    every_head_act_patch_result = einops.rearrange(every_head_act_patch_result, "act_type layer pos head -> act_type (layer head) pos")
    imshow(every_head_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head")

In [None]:
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]
# every_head_act_patch_result only exists if the slow cell above was run
if DO_SLOW_RUNS:
    imshow(
        every_head_act_patch_result,
        facet_col=0,
        yaxis="Head Label",
        xaxis="Pos",
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=ALL_HEAD_LABELS,
        facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
        title="Activation Patching Per Head", zmin=-.1, zmax=0.1)

## Value Positions

The value activations seem to be much more important than the attention patterns for recovering correct performance, at least for the potential WMHs. Which token positions matter most for the values?

In [None]:
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]
if DO_SLOW_RUNS:
    attn_head_v_act_patch_results = patching.get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, normalized_metric)
    attn_head_v_act_patch_results = einops.rearrange(attn_head_v_act_patch_results, "layer pos head -> (layer head) pos")


In [None]:
# attn_head_v_act_patch_results only exists if the slow cell above was run
if DO_SLOW_RUNS:
    imshow(attn_head_v_act_patch_results,
           yaxis="Head Label",
           xaxis="Pos",
           x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
           y=ALL_HEAD_LABELS,
           title="attn_head_v Activation Patching By Pos")

Position 4 is clearly the most important position for the value activations of all three potential word-mover heads.
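Why can a single position dominate value patching? For one query position, a head's output is the attention-weighted sum of value vectors, so patching the clean value vector into position `p` changes the output by exactly `A[p] * (v_clean[p] - v_corrupted[p])`. The toy numpy sketch below (random values and a made-up attention row that concentrates on position 4) illustrates that the effect scales with the attention paid to the patched position.

```python
import numpy as np

rng = np.random.default_rng(2)
n_pos, d_head = 6, 4

# Hypothetical attention row of the final query position (sums to 1),
# concentrated on position 4
attn = np.array([0.05, 0.05, 0.05, 0.05, 0.7, 0.1])
v_clean = rng.normal(size=(n_pos, d_head))
v_corrupted = rng.normal(size=(n_pos, d_head))

def head_output(attn_row, values):
    # z = sum_k A[q, k] * v[k] for a single query position
    return attn_row @ values

def patch_value_at(pos):
    # Corrupted run, except the value vector at one position comes from the clean run
    v = v_corrupted.copy()
    v[pos] = v_clean[pos]
    return head_output(attn, v)

base = head_output(attn, v_corrupted)
for pos in range(n_pos):
    change = np.linalg.norm(patch_value_at(pos) - base)
    print(f"pos {pos}: attention {attn[pos]:.2f}, output change {change:.3f}")
```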

## Tracing Information Backwards With Path Patching

To sum up so far, we've identified three heads that are responsible for the majority of the writing in the direction of the correct answer's logit difference. These heads share the following characteristics:
- Each attends to the correct descriptor (most of the time, in a token-dependent fashion)
- Value activations rather than attention patterns are most important, and this applies mostly to token position `4`.
- Each writes out the correct answer in an attention-probability-dependent fashion, with a clear abrupt threshold.

The evidence so far seems to justify tentatively calling these word-mover heads (or WMHs).

A natural next question is: What previous heads are contributing to their performance?

We'll path-patch the Q, K, and V inputs of the receiver heads separately to see which makes the most difference.
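The idea behind path patching, as opposed to plain activation patching, is that the sender head's corrupted output is fed *only* into one input of the receiver head, while every other path (including the sender's direct path to the logits) keeps its clean value. Below is a toy sketch on a hypothetical linear "model" (random matrices, no attention or LayerNorm). It is not `get_path_patch_head_to_heads`, only an illustration of the same principle for a receiver's value input.

```python
import numpy as np

rng = np.random.default_rng(3)
d_model, d_vocab = 8, 5

# Hypothetical linear stand-ins: a sender head and another head write to the
# residual stream; a later receiver head reads it through its value input.
W_sender = rng.normal(size=(d_model, d_model)) * 0.3
W_other = rng.normal(size=(d_model, d_model)) * 0.3
W_v = rng.normal(size=(d_model, d_model)) * 0.3  # receiver value input
W_O = rng.normal(size=(d_model, d_model)) * 0.3  # receiver output
W_U = rng.normal(size=(d_model, d_vocab))

def run(x, sender_out_for_v=None):
    """Forward pass; optionally replace the sender output seen by the receiver's values only."""
    s = W_sender @ x
    o = W_other @ x
    resid_mid = x + s + o  # direct paths stay clean
    s_v = s if sender_out_for_v is None else sender_out_for_v
    r = W_O @ (W_v @ (x + s_v + o))  # only this path sees the patched sender output
    return W_U.T @ (resid_mid + r)

x_clean = rng.normal(size=d_model)
x_corrupted = rng.normal(size=d_model)

clean_logits = run(x_clean)
s_corrupted = W_sender @ x_corrupted  # sender activation taken from the corrupted run
path_patched_logits = run(x_clean, sender_out_for_v=s_corrupted)
print(path_patched_logits - clean_logits)
```

In this linear toy, the change in logits is exactly `W_U.T @ W_O @ W_v @ (s_corrupted - s_clean)`, i.e. it flows only through the sender-to-receiver-value path.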

### Path Patching for Values

In [None]:
wmh_path_patching_results = get_path_patch_head_to_heads(
    receiver_heads = potential_wmh,
    receiver_input = "v",
    model = model,
    patching_metric = normalized_metric,
    new_dataset = corrupted_tokens,
    orig_dataset = clean_tokens
)
norm_res = -(wmh_path_patching_results - clean_baseline_norm) / clean_baseline_norm
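The `norm_res` expression above converts each patched metric into a signed fractional change relative to the clean baseline. The helper below mirrors that expression for a single value. It assumes `clean_baseline_norm` is the normalized metric on the unpatched clean run (defined earlier in the notebook), and the numbers are illustrative. Note the sign convention: a *drop* in the metric gives a *positive* value.

```python
def fractional_change(patched_metric, clean_baseline):
    """Same formula as norm_res above, applied to a single value."""
    return -(patched_metric - clean_baseline) / clean_baseline

print(fractional_change(0.95, 1.0))  # ~0.05: the metric fell by 5%
print(fractional_change(1.02, 1.0))  # ~-0.02: the metric rose by 2%
```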


In [None]:
imshow(
    norm_res,
    title="Direct Effect on Normed Logit Diff Through Word-Mover Heads' Values", 
)

The effect here is very weak, and it is spread across roughly half of all the attention heads rather than concentrated in a few. We'll focus on the other components.

### Path Patching for Queries

In [None]:
wmh_path_patching_results = get_path_patch_head_to_heads(
    receiver_heads = potential_wmh,
    receiver_input = "q",
    model = model,
    patching_metric = normalized_metric,
    new_dataset = corrupted_tokens,
    orig_dataset = clean_tokens
)
q_norm_res = -(wmh_path_patching_results - clean_baseline_norm) / clean_baseline_norm

In [None]:
imshow(
    q_norm_res,
    title="Direct Effect on Normed Logit Diff Through Word-Mover Heads' Queries", 
)

In [None]:
plot_attention_heads(
    -q_norm_res, 
    title="Direct % Change in Normed Logit Diff From Q-Path Patching in Corrupted Attention Patterns", 
    top_n=15, 
    range_x=[0, 0.05]
)

Total logit diff contribution above threshold: 0.07


There's a pretty clear set of attention heads that have an outsize impact through the queries of the WMHs. But what are they doing? Let's examine their attention patterns. We'll consider the top two here, since there seems to be a steep drop-off in effect after this point.

(We will exclude L18H0 since it is one of the WMHs--it does seem to have an impact on the other WMHs, but we'll save that for future investigation).

In [None]:
second_level_q_effectors = [(17, 15), (14, 0)]

In [None]:
top_k = 3
# Skip the top-ranked head (L18H0, itself one of the WMHs) and show the next two
top_second_level_q_effectors = torch.topk(-q_norm_res.flatten(), k=top_k).indices[1:]
visualize_attention_patterns(top_second_level_q_effectors, title="Top Q-Path Effector Heads (Excluding L18H0)")

From these attention patterns alone, it's a bit difficult to judge what these heads are doing--there isn't a clear pattern of attention. What we do see is the following:
- At the final token position, L17H15 attends back to "are." Could "are" play a role analogous to the duplicated name token in IOI, with the head that L17H15 influences simply performing induction by copying the word that follows the first "are"?
- Note that L17H15 was the head whose values were most important in the attention head V patching above--specifically at the "ex" position.
- If we wanted to trace this circuit back further, we would want to try path patching for this specific position.
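To make the induction hypothesis concrete, here is a sketch of the rule it proposes: at the final token, find an earlier occurrence of that token and copy whatever followed it. The token list is a hypothetical tokenization of the example syllogism, not Pythia's actual tokens. The example also shows why the hypothesis specifically needs the *first* "are": copying from the most recent "are" would predict " birds" instead.

```python
def induction_prediction(tokens, use_first=True):
    """Copy the token that followed an earlier occurrence of the final token.

    use_first=True copies from the first earlier occurrence,
    use_first=False from the most recent one. Returns None if there is none.
    """
    current = tokens[-1]
    earlier = [i for i in range(len(tokens) - 1) if tokens[i] == current]
    if not earlier:
        return None
    i = earlier[0] if use_first else earlier[-1]
    return tokens[i + 1]

# Hypothetical tokenization of the example prompt
tokens = ["All", " birds", " are", " ex", "os", "kel", "etal", ",", " not", " end",
          "os", "kel", "etal", ".", " Horses", " are", " birds", ".", " Therefore",
          ",", " horses", " are"]
print(induction_prediction(tokens))                   # " ex"
print(induction_prediction(tokens, use_first=False))  # " birds"
```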

### Path Patching for Keys

In [None]:
wmh_path_patching_results = get_path_patch_head_to_heads(
    receiver_heads = potential_wmh,
    receiver_input = "k",
    model = model,
    patching_metric = normalized_metric,
    new_dataset = corrupted_tokens,
    orig_dataset = clean_tokens
)
k_norm_res = -(wmh_path_patching_results - clean_baseline_norm) / clean_baseline_norm

In [None]:
imshow(
    k_norm_res,
    title="Direct % Change in Normed Logit Diff Through Word-Mover Heads' Keys", 
)

In [None]:
plot_attention_heads(
    -k_norm_res, 
    title="Direct % Change in Normed Logit Diff From K-Path Patching in Corrupted Attention Patterns", 
    top_n=15, 
    range_x=[0, 0.1]
)

Total logit diff contribution above threshold: 0.07


In [None]:
top_k = 2
top_second_level_k_effectors = torch.topk(k_norm_res.flatten(), k=top_k).indices
visualize_attention_patterns(top_second_level_k_effectors, title=f"Top {top_k} K-Path Effector Heads")

Here it is likewise difficult to see what these heads are doing. L1H11 appears to be a punctuation head, as it attends to the commas and periods. Generally, the effect of patching these heads on the WMHs' keys is very weak, so we will not investigate further for now. 
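If we did want to check the punctuation-head reading of L1H11 more carefully, one option is to measure how much attention mass goes to punctuation tokens across many prompts rather than eyeballing the first one. The function below is a hypothetical helper operating on a single `[query_pos, key_pos]` pattern and its string tokens; in practice the result would be compared against a baseline such as uniform attention over the causal prefix.

```python
import numpy as np

def punctuation_attention_fraction(pattern, str_tokens, punct=(",", ".")):
    """Mean over query positions of the attention mass on punctuation key tokens.

    pattern:    [query_pos, key_pos] attention pattern for one head and one prompt
    str_tokens: string tokens for the key positions
    """
    is_punct = np.array([tok.strip() in punct for tok in str_tokens])
    return float(pattern[:, is_punct].sum(axis=-1).mean())
```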