## Autointerp for SHIFT evals
LLM judge decides whether a neuron/latent is related to a natural language concept. Inputs to Autointerp LLM judge:
- Max activating examples
- DLA top promoted tokens


### Functionality of this notebook
Inputs: 
- model, datset, dictionaries
- list of concepts to check whether it is related to sth.

Outputs:
- node_effects.pkl per dictionary per concept with binary yes/no decision on whether feature is related to prompt.

In [1]:
# Imports

import os
import torch as t
from nnsight import LanguageModel
import datasets
import anthropic
from tqdm import tqdm
import re
import ast
import pickle
from collections import defaultdict

import experiments.utils as utils
from experiments.autointerp import (
    get_max_activating_prompts, 
    compute_dla, 
    format_examples,
    evaluate_binary_llm_output
)
from experiments.explainers.simple.prompt_builder import build_prompt
from experiments.explainers.simple.prompts import build_system_prompt

DEBUGGING = True

if DEBUGGING:
    tracer_kwargs = dict(scan=True, validate=True)
else:
    tracer_kwargs = dict(scan=False, validate=False)

%load_ext autoreload
%autoreload 2

In [2]:
# Securely input the API key
api_key = input("Enter your API key: ")
os.environ['ANTHROPIC_API_KEY'] = api_key

In [3]:
# Load model

DEVICE = "cuda"
model_name = "EleutherAI/pythia-70m-deduped"
model_dtype = t.bfloat16
model = LanguageModel(
    model_name,
    device_map=DEVICE,
    dispatch=True,
    attn_implementation="eager",
    torch_dtype=model_dtype,
)
model_unembed = model.embed_out # For direct logit attribution

In [4]:
# Load data

num_contexts = 10000
context_length = 128
batch_size = 250

dataset = datasets.load_dataset("georgeyw/dsir-pile-100k", streaming=False)
data = model.tokenizer(dataset["train"]["contents"][:num_contexts], return_tensors="pt", padding="max_length", truncation=True, max_length=context_length).to(DEVICE).data
batched_data = utils.batch_inputs(data, batch_size)


# Class specific data

profession_dict = {
    "accountant": 0, "architect": 1, "attorney": 2, "chiropractor": 3,
    "comedian": 4, "composer": 5, "dentist": 6, "dietitian": 7,
    "dj": 8, "filmmaker": 9, "interior_designer": 10, "journalist": 11,
    "model": 12, "nurse": 13, "painter": 14, "paralegal": 15,
    "pastor": 16, "personal_trainer": 17, "photographer": 18, "physician": 19,
    "poet": 20, "professor": 21, "psychologist": 22, "rapper": 23,
    "software_engineer": 24, "surgeon": 25, "teacher": 26, "yoga_teacher": 27,
    "profession": -4, "gender": -2
}

Repo card metadata block was not found. Setting CardData to empty.


In [5]:
# Load dictionary

dictionaries_path = "../dictionary_learning/dictionaries"

# Current recommended way to generate graphs. You can copy paste ae_sweep_paths directly from bib_intervention.py
ae_sweep_paths = {
    "pythia70m_sweep_standard_ctx128_0712": {"resid_post_layer_3": {"trainer_ids": [6]}},
    # "pythia70m_sweep_gated_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [9]}},
    # "pythia70m_sweep_topk_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [10]}},
    # "gemma-2-2b_sweep_topk_ctx128_0817": {"resid_post_layer_12": {"trainer_ids": [2]}}, 
}
sweep_name = list(ae_sweep_paths.keys())[0]
submodule_trainers = ae_sweep_paths[sweep_name]

ae_group_paths = utils.get_ae_group_paths(dictionaries_path, sweep_name, submodule_trainers)
ae_paths = utils.get_ae_paths(ae_group_paths)

ae_path = ae_paths[0]
submodule, dictionary, config = utils.load_dictionary(model, ae_path, DEVICE)

Loading dictionary from ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_6


In [6]:
# Get max_activating_inputs 
# and direct logit attribution (DLA) scores per feature

# all_latent_indices = t.arange(dictionary.dict_size)
# all_latent_indices = t.tensor([0])
all_latent_indices = t.arange(0, 10)
k_inputs_per_feature = 10

max_token_idxs_FKL, max_activations_FKL = get_max_activating_prompts(
    model, 
    submodule, 
    batched_data, 
    all_latent_indices, 
    batch_size, 
    dictionary, 
    k=k_inputs_per_feature,
    context_length=context_length
)

# TODO write out max activating prompts to a file

top_dla_token_idxs_FK = compute_dla(
    all_latent_indices,
    dictionary.decoder.weight,
    model_unembed.weight,
    k_inputs_per_feature
)

  0%|          | 0/40 [00:00<?, ?it/s]You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


100%|██████████| 40/40 [00:07<00:00,  5.27it/s]


In [7]:
# Format max_activating_inputs by << emphasizing>> max act examples

num_top_emphasized_tokens = 5

example_prompts = format_examples(model, max_token_idxs_FKL, max_activations_FKL, num_top_emphasized_tokens)
top_dla_tokens_FK = model.tokenizer.batch_decode(top_dla_token_idxs_FK, skip_special_tokens=True)

In [8]:
system_prompt = build_system_prompt(
    cot=False,
    concepts=list(profession_dict.keys()),
    logits=True
)

prompts = []
for example_prompt, top_dla_tokens_K in zip(example_prompts, top_dla_tokens_FK):
    message = build_prompt(
        examples=example_prompt,
        cot=False,
        top_logits=top_dla_tokens_K,
    )
    prompts.append(message)

In [9]:
system_prompt

[{'type': 'text',
  'text': 'You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. Your task is to analyze the neuron and decide whether its behavior is related to a concept for each concept in (accountant, architect, attorney, chiropractor, comedian, composer, dentist, dietitian, dj, filmmaker, interior_designer, journalist, model, nurse, painter, paralegal, pastor, personal_trainer, photographer, physician, poet, professor, psychologist, rapper, software_engineer, surgeon, teacher, yoga_teacher, profession, gender).\n\n(Part 2) Tokens that the neuron boosts in the next token prediction\n\nYou will also be shown a list called Top_logits. The logits promoted by the neuron shed light on how the neuron\'s activation influences the model\'s predictions or outputs. Look at this list of Top_logits and refine your hypotheses from part 1. It is possible that this list is more informative than the examples from part 1.\n\nPay close 

In [10]:
prompts

[[{'role': 'user',
   'content': '\nExample 1:  and he was <<over the moon>> to find\nExample 2:  we\'ll be laughing <<till the cows come home>>! Pro\nExample 3:  thought Scotland was boring, but really there\'s more <<than meets the eye>>! I\'d\n\nTop_logits: ["elated", "joyful", "story", "thrilled", "spider"]\n'},
  {'role': 'assistant',
   'content': '\n(Part 2)\nSIMILAR TOKENS: "elated", "joyful", "thrilled".\n- The top logits list contains words that are strongly associated with positive emotions.\n'},
  {'role': 'user',
   'content': '\nExample 1:  a river is wide but the ocean is wid<<er>>. The ocean\nExample 2:  every year you get tall<<er>>," she\nExample 3:  the hole was small<<er>> but deep<<er>> than the\n\nTop_logits: ["apple", "running", "book", "wider", "quickly"]\n'},
  {'role': 'assistant',
   'content': '\n(Part 2)\nSIMILAR TOKENS: None\n- The top logits list contains unrelated nouns and adverbs.\n'},
  {'role': 'user',
   'content': '\nExample 1:  something happening

In [11]:
# TODO prompt caching
# TODO parallelize LLM inference

client = anthropic.Anthropic()

llm_outputs_direct_prompt = []
for i, messages in tqdm(enumerate(prompts), desc="LLM inference", total=len(prompts)):

    # barrier for testing
    if i == 5: 
        print("stopping LLM inference early for testing")
        break

    llm_out = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1000,
        temperature=0,
        system=system_prompt,
        messages=messages,
    )
    llm_outputs_direct_prompt.append(llm_out.content)

print(f'Total number of LLM outputs: {len(llm_outputs_direct_prompt)}\n')
llm_outputs_direct_prompt[:5]

LLM inference:   0%|          | 0/10 [00:00<?, ?it/s]

LLM inference:  50%|█████     | 5/10 [00:33<00:33,  6.73s/it]

stopping LLM inference early for testing
Total number of LLM outputs: 5






[[TextBlock(text='I apologize, but the Top_logits list appears to be corrupted or improperly formatted in this case. Without valid information from the Top_logits, I\'ll have to base my analysis solely on the text example provided.\n\nBased on the text example alone, this neuron seems to activate on the end of words, particularly when they form part of a longer word that\'s split across tokens. In this case, it activates on "ose" which is likely part of the word "comatose".\n\nHowever, without more examples or valid Top_logits data, it\'s difficult to make a confident assessment about what concept this neuron might be related to. The single example doesn\'t provide enough context to reliably connect it to any of the specific professions or concepts listed.\n\nGiven the limited and potentially unreliable information, I cannot confidently associate this neuron\'s behavior with any of the listed concepts. Therefore, I will mark all concepts as "no" in this case.\n\nyes_or_no_decisions = {

In [12]:
for out in llm_outputs_direct_prompt:
    print(f'{out[0].text.split('yes_or_no_decisions = ')[-1]}\n')

{"accountant": "no", "architect": "no", "attorney": "no", "chiropractor": "no", "comedian": "no", "composer": "no", "dentist": "no", "dietitian": "no", "dj": "no", "filmmaker": "no", "interior_designer": "no", "journalist": "no", "model": "no", "nurse": "no", "painter": "no", "paralegal": "no", "pastor": "no", "personal_trainer": "no", "photographer": "no", "physician": "no", "poet": "no", "professor": "no", "psychologist": "no", "rapper": "no", "software_engineer": "no", "surgeon": "no", "teacher": "no", "yoga_teacher": "no", "profession": "no", "gender": "no"}

I apologize, but it seems the Top_logits information is not properly formatted or is missing in your input. Without this crucial information, I cannot refine my analysis or make accurate judgments about the neuron's behavior in relation to the given concepts.

Based solely on the examples provided, I can make some observations:

1. The neuron seems to activate on units of measurement, particularly those related to medical and 

In [13]:
# Data extraction with Instructor

import instructor # pip install -U instructor
from anthropic import Anthropic
from pydantic import BaseModel


class Decisions(BaseModel):
    gender: str
    professors: str
    nurses: str


client = instructor.from_anthropic(Anthropic())

llm_outputs_instructor = []
for i, messages in tqdm(enumerate(prompts), desc="LLM inference", total=len(prompts)):

    # barrier for testing
    if i == 5: 
        print("stopping LLM inference early for testing")
        break

    resp = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1000,
        temperature=0,
        system=system_prompt,
        messages=messages,
        response_model=Decisions,
    )
    llm_outputs_instructor.append(resp)

print(f'Total number of LLM outputs: {len(llm_outputs_instructor)}\n')
llm_outputs_instructor[:5]

LLM inference:  50%|█████     | 5/10 [00:09<00:09,  1.82s/it]

stopping LLM inference early for testing
Total number of LLM outputs: 5






[Decisions(gender='no', professors='no', nurses='no'),
 Decisions(gender='no', professors='no', nurses='no'),
 Decisions(gender='no', professors='no', nurses='no'),
 Decisions(gender='no', professors='no', nurses='no'),
 Decisions(gender='no', professors='no', nurses='no')]

### Note for the simple example of gender, doctor, nurse; Prompting with Instructor yielded a different result than direct prompting! I do not fully trust Instructor and would default to regex, if possible.

In [14]:
# Regex pattern matching

def extract_and_convert_json(input_string):
    # Regular expression to find JSON-like dictionaries
    match = re.search(r'\{.*?\}', input_string)
    
    if match:
        json_string = match.group(0)
        # Convert the extracted string to a dictionary
        return ast.literal_eval(json_string)
    else:
        raise ValueError("No JSON-like dictionary found in the input string.")

llm_outputs_json = []
for out in llm_outputs_direct_prompt:
    out = out[0].text
    json_str = out.split('yes_or_no_decisions')[-1]
    yes_or_no_decisions = extract_and_convert_json(json_str)
    llm_outputs_json.append(yes_or_no_decisions)

llm_outputs_json[:5]

ValueError: No JSON-like dictionary found in the input string.

In [None]:
# NOTE assumes that feature_idx starts from 0
# NOTE node_effects does currently not contain tensors
# NOTE we currently do not check whether all classes are contained in all llm json outputs, this could lead to feature idx mismatching

node_effects = defaultdict(list)
for feat_idx, decisions in enumerate(llm_outputs_json):
    for class_name, decision in decisions.items():
        class_idx = profession_dict[class_name]
        decision_bool = evaluate_binary_llm_output(decision)
        node_effects[class_idx].append(decision_bool)

with open(os.path.join(ae_path, "node_effects_autointerp.pkl"), "wb") as f:
    pickle.dump(node_effects, f)

In [None]:
node_effects

defaultdict(list,
            {0: [False, False, False, False, False],
             1: [False, False, False, False, False],
             2: [False, False, False, False, True],
             3: [False, False, False, False, False],
             4: [False, False, False, False, False],
             5: [False, False, False, False, False],
             6: [False, False, False, False, False],
             7: [False, False, False, False, False],
             8: [False, False, False, False, False],
             9: [False, False, False, False, False],
             10: [False, False, False, False, False],
             11: [False, False, False, False, False],
             12: [False, False, False, False, False],
             13: [False, True, False, False, False],
             14: [False, False, False, False, False],
             15: [False, False, False, False, True],
             16: [False, False, False, False, False],
             17: [False, False, False, False, False],
             18: [False