## Autointerp for SHIFT evals
An LLM judge decides whether a neuron/latent is related to a natural-language concept. Inputs to the autointerp LLM judge:
- Max activating examples
- DLA top promoted tokens


### Functionality of this notebook
Inputs:
- model, dataset, dictionaries
- list of concepts; each feature is checked for whether it is related to each of them

Outputs:
- `node_effects_autointerp.pkl` per dictionary, holding for each concept a binary yes/no decision on whether each feature is related to that concept.

In [None]:
# Imports

import os
import torch as t
from nnsight import LanguageModel
import datasets
import anthropic
from tqdm import tqdm
import re
import ast
import pickle
from collections import defaultdict

import experiments.utils as utils
from experiments.autointerp import (
    get_max_activating_prompts, 
    compute_dla, 
    format_examples,
    evaluate_binary_llm_output
)
from experiments.llm_autointerp.prompt_builder import build_prompt
from experiments.llm_autointerp.prompts import build_system_prompt

# Debug switch for the notebook.
DEBUGGING = True

# Scanning and validation are enabled together in debug mode and disabled
# together otherwise.
# NOTE(review): tracer_kwargs is not consumed by any cell visible here —
# presumably forwarded to nnsight tracing calls; confirm.
tracer_kwargs = (
    {"scan": True, "validate": True}
    if DEBUGGING
    else {"scan": False, "validate": False}
)

%load_ext autoreload
%autoreload 2

In [None]:
# Securely input the API key.
# getpass() does not echo what is typed; input() would display the key in the
# cell output, where it is saved together with the notebook.
from getpass import getpass

api_key = getpass("Enter your API key: ")
# Exported through the environment so that clients created later can pick it up.
os.environ['ANTHROPIC_API_KEY'] = api_key

In [None]:
# Load model

DEVICE = "cuda"
model_name = "EleutherAI/pythia-70m-deduped"
model_dtype = t.bfloat16

# Placement options first, then numerics and attention implementation.
model = LanguageModel(
    model_name,
    dispatch=True,
    device_map=DEVICE,
    torch_dtype=model_dtype,
    attn_implementation="eager",
)

# Output embedding module, kept at hand for direct logit attribution.
model_unembed = model.embed_out

In [None]:
# Load data

num_contexts = 10000
context_length = 128
batch_size = 250

dataset = datasets.load_dataset("georgeyw/dsir-pile-100k", streaming=False)

# Tokenize the first `num_contexts` documents, padded/truncated to exactly
# `context_length` tokens, move the encoding to DEVICE and keep its `.data`.
data = (
    model.tokenizer(
        dataset["train"]["contents"][:num_contexts],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=context_length,
    )
    .to(DEVICE)
    .data
)
batched_data = utils.batch_inputs(data, batch_size)


# Class specific data

# Professions are numbered 0..27 in the order listed; the two aggregate
# concepts "profession" and "gender" get the special negative ids below.
profession_dict = {
    name: idx
    for idx, name in enumerate(
        [
            "accountant", "architect", "attorney", "chiropractor",
            "comedian", "composer", "dentist", "dietitian",
            "dj", "filmmaker", "interior_designer", "journalist",
            "model", "nurse", "painter", "paralegal",
            "pastor", "personal_trainer", "photographer", "physician",
            "poet", "professor", "psychologist", "rapper",
            "software_engineer", "surgeon", "teacher", "yoga_teacher",
        ]
    )
}
profession_dict["profession"] = -4
profession_dict["gender"] = -2

In [None]:
# Load dictionary

dictionaries_path = "../dictionary_learning/dictionaries"

# Current recommended way to generate graphs. You can copy paste ae_sweep_paths directly from bib_intervention.py
ae_sweep_paths = {
    "pythia70m_sweep_standard_ctx128_0712": {"resid_post_layer_3": {"trainer_ids": [6]}},
    # "pythia70m_sweep_gated_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [9]}},
    # "pythia70m_sweep_topk_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [10]}},
    # "gemma-2-2b_sweep_topk_ctx128_0817": {"resid_post_layer_12": {"trainer_ids": [2]}}, 
}

# Only the first sweep listed above is used.
sweep_name = next(iter(ae_sweep_paths))
submodule_trainers = ae_sweep_paths[sweep_name]

ae_group_paths = utils.get_ae_group_paths(dictionaries_path, sweep_name, submodule_trainers)
ae_paths = utils.get_ae_paths(ae_group_paths)

# Only the first autoencoder path is loaded.
ae_path = ae_paths[0]
submodule, dictionary, config = utils.load_dictionary(model, ae_path, DEVICE)

In [None]:
from experiments.autointerp import get_autointerp_inputs_for_all_saes

all_latent_indices = t.arange(dictionary.dict_size)
k_inputs_per_feature = 5

# presumably writes (or reuses, given force_rerun=False) a
# max_activating_inputs.pkl inside every path of ae_paths — TODO confirm
get_autointerp_inputs_for_all_saes(
    model,
    ae_paths=ae_paths,
    n_inputs=num_contexts,
    context_length=context_length,
    batch_size=batch_size,
    top_k_inputs=k_inputs_per_feature,
    force_rerun=False,
)
t.cuda.empty_cache()

# Read the stored inputs back for the dictionary selected as ae_path.
with open(os.path.join(ae_path, "max_activating_inputs.pkl"), "rb") as f:
    file = pickle.load(f)

max_token_idxs_FKL, max_activations_FKL, top_dla_token_idxs_FK = (
    file["max_tokens_FKL"],
    file["max_activations_FKL"],
    file["dla_results_FK"],
)

In [None]:
# Format max_activating_inputs by << emphasizing>> max act examples
# moving max token idxs to device is four times as slow as leaving it on CPU

num_top_emphasized_tokens = 5

# The full-length setting is overridden right away, so only the first 100
# entries get formatted.
formatting_batch_size = len(max_token_idxs_FKL)
formatting_batch_size = 100

example_prompts = format_examples(
    model.tokenizer,
    max_token_idxs_FKL[:formatting_batch_size],
    max_activations_FKL[:formatting_batch_size],
    num_top_emphasized_tokens,
)

In [None]:
# Display the formatted example at index 1 as a quick sanity check.
example_prompts[1]

In [None]:
from experiments.llm_autointerp.prompts import build_system_prompt, create_few_shot_examples, create_unlabeled_prompts
from experiments import utils_bib_dataset

# The judge's system prompt is built from every concept name (key) of
# utils_bib_dataset.profession_dict.
system_prompt = build_system_prompt(concepts=list(utils_bib_dataset.profession_dict))
print(system_prompt)

In [None]:
import json

PROMPT_DIR = "llm_autointerp"

# Manually labelled examples that are turned into the few-shot part of the prompt.
with open(f"{PROMPT_DIR}/manual_labels_few_shot.json", "r") as f:
    few_shot_manual_labels = json.load(f)

few_shot_examples = create_few_shot_examples(few_shot_manual_labels)
print(few_shot_examples)

In [None]:
# presumably turns the top DLA token ids into strings with the model's
# tokenizer — verify against utils.list_decode
top_dla_token_strs_FK = utils.list_decode(top_dla_token_idxs_FK, model.tokenizer)

# Last expression of the cell, so the notebook displays its return value.
create_unlabeled_prompts(example_prompts, top_dla_token_strs_FK)

In [None]:
all_latent_indices = t.arange(dictionary.dict_size)
k_inputs_per_feature = 5

# presumably (re)creates max_activating_inputs.pkl for every path in ae_paths
# unless it already exists (force_rerun=False) — TODO confirm
get_autointerp_inputs_for_all_saes(
    model,
    n_inputs=num_contexts,
    context_length=context_length,
    batch_size=batch_size,
    top_k_inputs=k_inputs_per_feature,
    ae_paths=ae_paths,
    force_rerun=False,
)
t.cuda.empty_cache()

# Stored inputs of the dictionary selected as ae_path.
with open(os.path.join(ae_path, "max_activating_inputs.pkl"), "rb") as f:
    file = pickle.load(f)

max_token_idxs_FKL = file["max_tokens_FKL"]
max_activations_FKL = file["max_activations_FKL"]
top_dla_token_idxs_FK = file["dla_results_FK"]

# Manual labels used for the few-shot examples.
with open("llm_autointerp/manual_labels_few_shot.json", "r") as f:
    few_shot_manual_labels = json.load(f)

In [None]:
from experiments.llm_autointerp.prompts import build_system_prompt, create_few_shot_examples, create_unlabeled_prompts
from experiments import utils_bib_dataset
from experiments.llm_autointerp import llm_utils
from experiments.autointerp import get_autointerp_inputs_for_all_saes
import json

def node_effects_autointerp(
    model,
    node_effects_classprobe,
    max_token_idxs_FKL,
    max_activations_FKL,
    top_dla_token_idxs_FK,
    few_shot_manual_labels,
    num_top_features_from_probe=5,
    chosen_class_idxs=None,
):
    """Build the LLM-judge prompt pieces and print their token counts.

    NOTE(review): this function is unfinished. It assembles the system prompt,
    the few-shot examples and the formatted max-activating examples, prints
    two token counts and returns ``None``; no LLM call is made and no node
    effects are produced yet.

    Args:
        model: Object exposing ``tokenizer``, which is passed to
            ``format_examples``.
        node_effects_classprobe: Currently unused.
        max_token_idxs_FKL: Token ids of the max-activating examples; only the
            first 100 entries are formatted.
        max_activations_FKL: Activations matching ``max_token_idxs_FKL``; only
            the first 100 entries are formatted.
        top_dla_token_idxs_FK: Currently unused.
        few_shot_manual_labels: Manual labels handed to
            ``create_few_shot_examples``.
        num_top_features_from_probe: Currently unused.
        chosen_class_idxs: Currently unused. Defaults to ``None`` because a
            parameter without a default may not follow one with a default.
    """
    num_top_emphasized_tokens = 5
    # Only the first 100 entries are formatted for now.
    formatting_batch_size = 100

    system_prompt = build_system_prompt(concepts=list(utils_bib_dataset.profession_dict.keys()))
    few_shot_examples = create_few_shot_examples(few_shot_manual_labels)

    # NOTE(review): built but not consumed yet — presumably the input of the
    # LLM calls that are still to be added.
    unlabeled_prompts = format_examples(
        model.tokenizer,
        max_token_idxs_FKL[:formatting_batch_size],
        max_activations_FKL[:formatting_batch_size],
        num_top_emphasized_tokens,
    )

    print(f"Few shot example is using {llm_utils.count_tokens(few_shot_examples)} tokens")
    print(f"System prompt is using {llm_utils.count_tokens(system_prompt[0]['text'])} tokens")


In [None]:
# TODO prompt caching
# TODO parallelize LLM inference

# Plain Anthropic client (presumably authenticated through the
# ANTHROPIC_API_KEY environment variable set in an earlier cell).
client = anthropic.Anthropic()

# Collects the `content` of each reply, one entry per processed prompt.
llm_outputs_direct_prompt = []
# NOTE(review): `prompts` is not assigned in any cell visible here — confirm
# where it is built. The progress-bar total is len(prompts) although the loop
# below stops after five prompts.
for i, messages in tqdm(enumerate(prompts), desc="LLM inference", total=len(prompts)):

    # barrier for testing
    if i == 5: 
        print("stopping LLM inference early for testing")
        break

    # Same system prompt for every request; only `messages` changes per prompt.
    llm_out = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1000,
        temperature=0,
        system=system_prompt,
        messages=messages,
    )
    llm_outputs_direct_prompt.append(llm_out.content)

print(f'Total number of LLM outputs: {len(llm_outputs_direct_prompt)}\n')
# Last expression of the cell: displays the first outputs in the notebook.
llm_outputs_direct_prompt[:5]

In [None]:
# Print, for every reply, only the text that follows the last
# 'yes_or_no_decisions = ' marker.
# The f-string is delimited with double quotes because reusing the delimiter
# quote inside a replacement field is a SyntaxError before Python 3.12.
for out in llm_outputs_direct_prompt:
    print(f"{out[0].text.split('yes_or_no_decisions = ')[-1]}\n")

In [None]:
# Data extraction with Instructor

import instructor # pip install -U instructor
from anthropic import Anthropic
from pydantic import BaseModel


# Response schema passed to Instructor as `response_model`: one string field
# per judged concept.
# NOTE(review): the field names are plural ("professors", "nurses") whereas
# profession_dict uses "professor" / "nurse" — confirm which spelling is meant.
# Documented with comments rather than a class docstring on purpose: a model
# docstring may end up in the schema sent to the LLM — TODO confirm.
class Decisions(BaseModel):
    gender: str
    professors: str
    nurses: str


# NOTE(review): this rebinds `client`, replacing the plain anthropic.Anthropic()
# client created for direct prompting; re-running the direct-prompt cell
# afterwards would use the Instructor-wrapped client.
client = instructor.from_anthropic(Anthropic())

# Collects one response object per processed prompt.
llm_outputs_instructor = []
# NOTE(review): `prompts` is not assigned in any cell visible here. The
# progress-bar total is len(prompts) although the loop stops after five prompts.
for i, messages in tqdm(enumerate(prompts), desc="LLM inference", total=len(prompts)):

    # barrier for testing
    if i == 5: 
        print("stopping LLM inference early for testing")
        break

    # Same request as for direct prompting, plus response_model=Decisions.
    resp = client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1000,
        temperature=0,
        system=system_prompt,
        messages=messages,
        response_model=Decisions,
    )
    llm_outputs_instructor.append(resp)

print(f'Total number of LLM outputs: {len(llm_outputs_instructor)}\n')
# Last expression of the cell: displays the first responses in the notebook.
llm_outputs_instructor[:5]

### Note: for the simple example of gender, doctor, nurse, prompting with Instructor yielded a different result than direct prompting! I do not fully trust Instructor and would default to regex if possible.

In [None]:
# Regex pattern matching

def extract_and_convert_json(input_string):
    """Extract the first ``{...}`` dictionary from a string and parse it.

    The span from the first ``{`` to the next ``}`` is taken, so nested
    dictionaries are not supported. A dictionary on a single line is preferred;
    if there is none, a dictionary spanning several lines (e.g. pretty-printed
    JSON) is accepted as well.

    The text is parsed as a Python literal first. If that fails it is parsed
    as JSON, which additionally covers ``true`` / ``false`` / ``null``.

    Args:
        input_string: Text that contains a dictionary somewhere inside it.

    Returns:
        The parsed dictionary.

    Raises:
        ValueError: If no ``{...}`` span is found, or if the span is neither a
            valid Python literal nor valid JSON.
        SyntaxError: If the span is not syntactically valid Python and is not
            valid JSON either.
    """
    import json  # local import: keeps the function independent of cell order

    pattern = r'\{.*?\}'
    # Single-line match first; only when that fails may the dictionary span
    # several lines.
    match = re.search(pattern, input_string) or re.search(pattern, input_string, re.DOTALL)
    if match is None:
        raise ValueError("No JSON-like dictionary found in the input string.")

    dict_string = match.group(0)
    try:
        return ast.literal_eval(dict_string)
    except (ValueError, SyntaxError) as literal_error:
        try:
            return json.loads(dict_string)
        except json.JSONDecodeError:
            # Neither parser accepts the text: report the literal-eval failure.
            raise literal_error from None

# Turn each direct-prompt reply into a dictionary of yes/no decisions.
llm_outputs_json = []
for out in llm_outputs_direct_prompt:
    response_text = out[0].text
    # Text after the last 'yes_or_no_decisions' marker (the whole text if the
    # marker is missing).
    json_str = response_text.rpartition('yes_or_no_decisions')[2]
    llm_outputs_json.append(extract_and_convert_json(json_str))

llm_outputs_json[:5]

In [None]:
# NOTE assumes that feature_idx starts from 0
# NOTE node_effects does currently not contain tensors
# NOTE we currently do not check whether all classes are contained in all llm json outputs, this could lead to feature idx mismatching

# class index -> list of boolean decisions, in the order of llm_outputs_json
node_effects = defaultdict(list)
for decisions in llm_outputs_json:
    for class_name, raw_decision in decisions.items():
        target_idx = profession_dict[class_name]
        is_related = evaluate_binary_llm_output(raw_decision)
        node_effects[target_idx].append(is_related)

# Stored next to the dictionary the decisions belong to.
with open(os.path.join(ae_path, "node_effects_autointerp.pkl"), "wb") as f:
    pickle.dump(node_effects, f)

In [None]:
# Display the aggregated decisions in the notebook.
node_effects