This example finds attention heads that move information from variable types to the corresponding variable names in Python code. It uses the `FunctionFinder` token finder to locate the relevant tokens in the code, then the `ActivationAnalyzer` class to find heads that match the criteria.

## Setup

In [None]:
from transformer_lens import HookedTransformer
from llm_inspect import TokenDisplayer


# Load the pretrained "EleutherAI/pythia-2.8b-deduped-v0" checkpoint as a HookedTransformer.
# Every later cell uses this model and its tokenizer (`llm.tokenizer`).
llm = HookedTransformer.from_pretrained("EleutherAI/pythia-2.8b-deduped-v0")

In [None]:
# Sample Python source to analyse. The text is tokenized verbatim, so any edit to it
# (including the leading comment) changes the tokens the rest of the notebook works on.
# NOTE(review): the sample `add` is annotated `-> int` but returns a str; left untouched
# here because it is the experiment's input — confirm whether that is intentional.
code = """# This is an irrelevant comment
def add(a: int, b: int, c: str) -> int:
    return str(a + b) + c

def subtract(a: int, b: int) -> int:
    return a - b
"""

# String tokens and token ids for the same text, both requested with add_special_tokens=True.
# The ids are requested as a PyTorch tensor (return_tensors="pt") for the model call below.
input_tokens = llm.tokenizer.tokenize(code, add_special_tokens=True)
input_token_ids = llm.tokenizer.encode(code, add_special_tokens=True, return_tensors="pt")

print(input_tokens)
print(input_token_ids)

In [4]:
# Run the model on the token ids; the first return value is discarded and only the
# activation cache is kept for the analysis and visualisation cells below.
_, activation_cache = llm.run_with_cache(input_token_ids)

In [5]:
# Display helper bound to the model's tokenizer; its html_for_* methods are used below
# to render scopes, single tokens and attention patterns.
token_displayer = TokenDisplayer.create_for_tokenizer(llm.tokenizer)

## Find variable type-variable name pairs using FunctionFinder

In [None]:
from llm_inspect import Token
from dataclasses import dataclass

@dataclass
class VariableNameTypePair:
    """A variable-name token paired with the token(s) that make up its type.

    NOTE(review): not referenced anywhere else in the visible part of the notebook.
    """

    # Token holding the variable's name.
    name_token: Token
    # Tokens of the variable's type; a list, so a type may span several tokens.
    type_tokens: list[Token]

In [None]:
from llm_inspect import ActivationAnalyzer, FunctionFinder

# FunctionFinder is built from the raw source text plus the tokenizer; it is used below to
# look up the `add` function and its scopes.
function_finder = FunctionFinder.create_from_tokenizer(code, llm.tokenizer)
# ActivationAnalyzer is built from the tokenizer, the string tokens and the activation cache
# captured earlier; it is used below to search for matching attention heads.
activation_analyzer = ActivationAnalyzer.create_from_tokenizer(llm.tokenizer, input_tokens, activation_cache)

In [None]:
# Look up the `add` function from the sample source.
function = function_finder.find_function_scope("add")

# Report the pieces of the function that were found, one per line.
print(
    f"Function name token: {function.function_name_token}",
    f"Parameters: {function.parameters}",
    f"Return type token: {function.return_type_token}",
    sep="\n",
)

In [9]:
# Render the whole scope of `add`. The html_for_* call is kept as the cell's final
# expression so the notebook displays its return value.
print("Function scope:")
token_displayer.html_for_scope_with_context(function.function_scope)

Function scope:


In [10]:
# Render only the body scope of `add`. The html_for_* call is kept as the cell's final
# expression so the notebook displays its return value.
print("Function body scope:")
token_displayer.html_for_scope_with_context(function.body_scope)

Function body scope:


In [11]:
# Find the first occurrence of each parameter name inside the body of `add`. The lookups
# run in the order a, b, c, each with allow_space_prefix=True.
a_variable_token, b_variable_token, c_variable_token = (
    function.body_scope.find_first(variable_name, allow_space_prefix=True)
    for variable_name in ("a", "b", "c")
)

# Show the token found for `a` in context as a visual check.
token_displayer.html_for_token_with_context(a_variable_token)

### Find heads that move information from variable types to variable names

In [None]:
from llm_inspect import AttentionHead


# For each body variable, collect the heads matched by find_heads_where_query_looks_at_value
# for that variable's token and the type of the parameter at the same position in the
# signature (a -> parameters[0], b -> parameters[1], c -> parameters[2]).
a_matching_heads, b_matching_heads, c_matching_heads = (
    activation_analyzer.find_heads_where_query_looks_at_value(
        variable_token,
        function.parameters[position].type,
    )
    for position, variable_token in enumerate(
        (a_variable_token, b_variable_token, c_variable_token)
    )
)

# Keep only the heads that matched for all three variables.
type_moving_heads = AttentionHead.intersection(
    [a_matching_heads, b_matching_heads, c_matching_heads]
)

print(f"Found {len(type_moving_heads)} head(s) that move information from the type token to the variable token.")

## Visualise

In [None]:
# NOTE(review): `TokenDisplayer` is already imported and `token_displayer` already created in
# the Setup section; this cell repeats both, presumably so the Visualise section can be run
# on its own — confirm before removing.
from llm_inspect import TokenDisplayer


token_displayer = TokenDisplayer.create_for_tokenizer(llm.tokenizer)

In [None]:
# The intersection computed above can be empty. Fail with an explanatory message rather
# than an uninformative indexing error from `type_moving_heads[0]`. The exception type
# stays IndexError, matching what indexing an empty list would raise.
if len(type_moving_heads) == 0:
    raise IndexError(
        "No heads matched all three variable/type pairs, so there is no head to visualise."
    )

print(f"Head {type_moving_heads[0]}:")

# Render the attention pattern of the first matching head over the Python sample. This is
# kept as the cell's final expression so the notebook displays its return value.
token_displayer.html_for_token_attention(
    input_tokens,
    activation_cache,
    type_moving_heads[0],
)

Interestingly, the same head that moves variable type information in the Python code also moves information from the type to the variable name in Java code, even though Python and Java use different syntax for declaring variables.

In [15]:
# Repeat the experiment on Java source: tokenize it, run the model, and visualise the
# first head found for the Python sample. The text is tokenized verbatim.
java_code = """// This is an irrelevant comment
public class Main {
    public static void main(String[] args) {
        int a = 1;
        int b = 2;
        String c = "3";
        System.out.println(a + b + c);
    }
}
"""

# Same tokenizer settings as for the Python sample (add_special_tokens=True, ids as a
# PyTorch tensor); the first return value of run_with_cache is again discarded.
java_input_tokens = llm.tokenizer.tokenize(java_code, add_special_tokens=True)
java_input_ids = llm.tokenizer.encode(java_code, add_special_tokens=True, return_tensors="pt")
_, java_activation_cache = llm.run_with_cache(java_input_ids)

# NOTE(review): unlike the Python visualisation earlier, which passes `activation_cache`
# as-is, this call passes the cache through `remove_batch_dim()` first — confirm which form
# `html_for_token_attention` expects.
# NOTE(review): `type_moving_heads[0]` assumes at least one head was found above.
token_displayer.html_for_token_attention(
    java_input_tokens,
    java_activation_cache.remove_batch_dim(),
    type_moving_heads[0],
)