## Autointerp for SHIFT evals
LLM judge decides whether a neuron/latent is related to a natural language concept. Inputs to Autointerp LLM judge:
- Max activating examples
- DLA top promoted tokens


### Functionality of this notebook
Inputs: 
- model, datset, dictionaries
- list of concepts to check whether it is related to sth.

Outputs:
- node_effects.pkl per dictionary per concept with binary yes/no decision on whether feature is related to prompt.

In [1]:
# Imports

import os
import torch as t
from nnsight import LanguageModel
import datasets
import anthropic
from tqdm import tqdm

import experiments.utils as utils
from experiments.autointerp import (
    get_max_activating_prompts, 
    compute_dla, 
    highlight_top_activations,
    evaluate_binary_llm_output
)


DEBUGGING = True

if DEBUGGING:
    tracer_kwargs = dict(scan=True, validate=True)
else:
    tracer_kwargs = dict(scan=False, validate=False)

%load_ext autoreload
%autoreload 2

In [5]:
# Securely input the API key
api_key = input("Enter your API key: ")
os.environ['ANTHROPIC_API_KEY'] = api_key

In [6]:
# Load model

DEVICE = "cuda"
model_name = "EleutherAI/pythia-70m-deduped"
model_dtype = t.bfloat16
model = LanguageModel(
    model_name,
    device_map=DEVICE,
    dispatch=True,
    attn_implementation="eager",
    torch_dtype=model_dtype,
)
model_unembed = model.embed_out # For direct logit attribution

In [7]:
# Load data

num_contexts = 10000
context_length = 128
batch_size = 250

dataset = datasets.load_dataset("georgeyw/dsir-pile-100k", streaming=False)
data = model.tokenizer(dataset["train"]["contents"][:num_contexts], return_tensors="pt", padding="max_length", truncation=True, max_length=context_length).to(DEVICE).data
batched_data = utils.batch_inputs(data, batch_size)

Repo card metadata block was not found. Setting CardData to empty.


In [8]:
# Load dictionary

dictionaries_path = "../dictionary_learning/dictionaries"

# Current recommended way to generate graphs. You can copy paste ae_sweep_paths directly from bib_intervention.py
ae_sweep_paths = {
    "pythia70m_sweep_standard_ctx128_0712": {"resid_post_layer_3": {"trainer_ids": [6]}},
    # "pythia70m_sweep_gated_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [9]}},
    # "pythia70m_sweep_topk_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [10]}},
    # "gemma-2-2b_sweep_topk_ctx128_0817": {"resid_post_layer_12": {"trainer_ids": [2]}}, 
}
sweep_name = list(ae_sweep_paths.keys())[0]
submodule_trainers = ae_sweep_paths[sweep_name]

ae_group_paths = utils.get_ae_group_paths(dictionaries_path, sweep_name, submodule_trainers)
ae_paths = utils.get_ae_paths(ae_group_paths)

ae_path = ae_paths[0]
submodule, dictionary, config = utils.load_dictionary(model, ae_path, DEVICE)

Loading dictionary from ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_6


In [11]:
# Get max_activating_inputs 
# and direct logit attribution (DLA) scores per feature

# all_latent_indices = t.arange(dictionary.dict_size)
# all_latent_indices = t.tensor([0])
all_latent_indices = t.arange(0, 10)
k_inputs_per_feature = 10

max_token_idxs_FKL, max_activations_FKL = get_max_activating_prompts(
    model, 
    submodule, 
    batched_data, 
    all_latent_indices, 
    batch_size, 
    dictionary, 
    k=k_inputs_per_feature,
    context_length=context_length
)

# TODO write out max activating prompts to a file

top_dla_token_idxs_FK = compute_dla(
    all_latent_indices,
    dictionary.decoder.weight,
    model_unembed.weight,
    k_inputs_per_feature
)

  0%|          | 0/40 [00:00<?, ?it/s]You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 40/40 [00:07<00:00,  5.15it/s]


In [12]:
# Format max_activating_inputs by << emphasizing>> max act examples

num_top_emphasized_tokens = 5

example_prompts = []
for feat_idx, (max_token_idxs_KL, max_activations_KL), in enumerate(zip(max_token_idxs_FKL, max_activations_FKL)):
    decoded_tokens_KL = model.tokenizer.batch_decode(max_token_idxs_KL, skip_special_tokens=True)
    formatted_sequences_K = highlight_top_activations(
            decoded_tokens_KL, 
            max_activations_KL, 
            top_n=num_top_emphasized_tokens, 
            include_activations=False
    )
    formatted_sequences = [seq for seq in formatted_sequences_K if seq]
    example_prompt = []
    for i, seq in enumerate(formatted_sequences):
        example_prompt.append(f"Example {i+1}: {seq}\n\n")
    example_prompt = "".join(example_prompt)
    example_prompts.append(example_prompt)

top_dla_tokens_FK = model.tokenizer.batch_decode(top_dla_token_idxs_FK, skip_special_tokens=True)

In [13]:
from experiments.explainers.simple.prompt_builder import build_prompt
from collections import defaultdict

prompts = defaultdict(list)
for example_prompt, top_dla_tokens_K in zip(example_prompts, top_dla_tokens_FK):
    system_prompt, message = build_prompt(
        examples=example_prompt,
        cot=False,
        top_logits=top_dla_tokens_K,
        concept="gender",
    )
    prompts['system'].append(system_prompt)
    prompts['message'].append(message)

In [14]:
prompts

defaultdict(list,
            {'system': [[{'type': 'text',
                'text': "You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. Your task is to analyze the neuron and decide whether its behavior is related to the concept of gender.\n\n(Part 2) Tokens that the neuron boosts in the next token prediction\n\nYou will also be shown a list called Top_logits. The logits promoted by the neuron shed light on how the neuron's activation influences the model's predictions or outputs. Look at this list of Top_logits and refine your hypotheses from part 1. It is possible that this list is more informative than the examples from part 1.\n\nPay close attention to the words in this list and write down what they have in common. Then look at what they have in common, as well as patterns in the tokens you found in Part 1, to produce a single explanation for what features of text cause the neuron to activate. Propose your explanation

In [15]:
client = anthropic.Anthropic()

llm_outputs = []
for i, (system_prompt, messages) in tqdm(
    enumerate(zip(prompts['system'], prompts['message'])), 
    desc="LLM inference", total=len(prompts['system'])):

    # barrier for testing
    if i == 5: 
        print("stopping LLM inference early for testing")
        break

    llm_out = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1000,
        temperature=0,
        system=system_prompt,
        messages=messages,
    )
    llm_outputs.append(llm_out.content)

print(f'Total number of LLM outputs: {len(llm_outputs)}\n')
llm_outputs[:5]

LLM inference:   0%|          | 0/10 [00:00<?, ?it/s]

LLM inference:  50%|█████     | 5/10 [00:16<00:16,  3.27s/it]

stopping LLM inference early for testing
Total number of LLM outputs: 5






[[TextBlock(text='I apologize, but the Top_logits list appears to be corrupted or improperly formatted in this case. Without valid information for the Top_logits, I\'ll base my analysis solely on the example provided in Part 1.\n\nFrom the given example, we can observe:\n\n1. The text appears to be part of a fantasy or science fiction narrative.\n2. The neuron activates on the letter "l" in "multiplied", but this seems to be in the middle of a word and doesn\'t provide much context.\n3. The surrounding text doesn\'t show any clear pattern related to gender.\n\nGiven the limited and unclear information, and the lack of a valid Top_logits list, it\'s difficult to draw any concrete conclusions about this neuron\'s behavior, especially in relation to gender.\n\n[yes/no DECISION]: no', type='text')],
 [TextBlock(text="I apologize, but the Top_logits list appears to be empty or corrupted in this case. Without that additional information, I'll have to base my analysis solely on the example se

In [16]:
decisions = evaluate_binary_llm_output(llm_outputs)
decisions

tensor([0, 0, 0, 0, 0], dtype=torch.int32)

In [71]:
# TODO scale LLM prompting to multiple classes to write out node_effects.pkl

# Check structure of current pkl file.

# import pickle

# with open(os.path.join(ae_path, "node_effects.pkl"), "rb") as f:
#     node_effects = pickle.load(f)

# node_effects

{-4: tensor([ 0.0000e+00, -9.3937e-05,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00, -1.4782e-04], dtype=torch.bfloat16),
 -2: tensor([0.0000e+00, 6.3419e-05, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         1.2970e-04], dtype=torch.bfloat16),
 0: tensor([ 0.0000e+00, -4.4107e-05,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00, -5.7936e-05], dtype=torch.bfloat16),
 1: tensor([ 0.0000, -0.0001,  0.0000,  ...,  0.0000,  0.0000,  0.0006],
        dtype=torch.bfloat16),
 2: tensor([ 0.0000, -0.0003,  0.0000,  ...,  0.0000,  0.0000, -0.0008],
        dtype=torch.bfloat16)}