## Manual feature labelling and prompt building

### Setup

In [9]:
# Imports

import os
import torch as t
from nnsight import LanguageModel
import datasets
import anthropic
from tqdm import tqdm
import re
import ast
import pickle
from collections import defaultdict
from circuitsvis.activations import text_neuron_activations
from transformers import AutoTokenizer

import experiments.utils as utils
from experiments.autointerp import (
    get_max_activating_prompts, 
    highlight_top_activations,
    compute_dla, 
    format_examples,
    evaluate_binary_llm_output,
    get_autointerp_inputs_for_all_saes,
)
from experiments.explainers.simple.prompt_builder import build_prompt
from experiments.explainers.simple.prompts import build_system_prompt

DEBUGGING = True

# nnsight tracer flags: `scan` and `validate` are both switched on exactly when DEBUGGING is set.
tracer_kwargs = dict(scan=DEBUGGING, validate=DEBUGGING)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Single professions map to their integer class id (position in this list);
# combined concepts map to their own name, which is used directly as the key.
profession_names = [
    "accountant", "architect", "attorney", "chiropractor",
    "comedian", "composer", "dentist", "dietitian",
    "dj", "filmmaker", "interior_designer", "journalist",
    "model", "nurse", "painter", "paralegal",
    "pastor", "personal_trainer", "photographer", "physician",
    "poet", "professor", "psychologist", "rapper",
    "software_engineer", "surgeon", "teacher", "yoga_teacher",
]
profession_dict = {name: class_id for class_id, name in enumerate(profession_names)}
profession_dict.update(
    {
        concept: concept
        for concept in (
            "male / female",
            "professor / nurse",
            "male_professor / female_nurse",
            "biased_male / biased_female",
        )
    }
)

In [17]:
# Resolve the SAE directories for each sweep we want to label.

dictionaries_path = "../dictionary_learning/dictionaries/autointerp_test_data"

# sweep name -> submodule -> trainer ids. Can be copy-pasted directly from bib_intervention.py.
ae_sweep_paths = {
    "pythia70m_sweep_topk_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [2, 6, 10, 18]}},
    "gemma-2-2b_test_sae": {"resid_post_layer_12": {"trainer_ids": [2]}},
}

# Dicts keep insertion order: the first sweep is Pythia, the second is Gemma.
sweep_names = list(ae_sweep_paths)
pythia_sweep_name, gemma_sweep_name = sweep_names[0], sweep_names[1]

pythia_submodule_trainers = ae_sweep_paths[pythia_sweep_name]
pythia_group_paths = utils.get_ae_group_paths(dictionaries_path, pythia_sweep_name, pythia_submodule_trainers)
pythia_ae_paths = utils.get_ae_paths(pythia_group_paths)

gemma_submodule_trainers = ae_sweep_paths[gemma_sweep_name]
gemma_group_paths = utils.get_ae_group_paths(dictionaries_path, gemma_sweep_name, gemma_submodule_trainers)
gemma_ae_paths = utils.get_ae_paths(gemma_group_paths)

# Classes scored during manual labelling (keys of profession_dict).
chosen_class_indices = [
    "male / female",
    "professor / nurse",
    "male_professor / female_nurse",
    "biased_male / biased_female",
    "accountant",
    "architect",
    "attorney",
    "dentist",
    "filmmaker",
]

In [142]:
# Ranks to sample within each class's sorted node effects.
sampled_indices = [0, 10, 100, 1000]

class_indices_1 = ["male / female", "professor / nurse", "accountant", "architect"]
class_indices_2 = ["male_professor / female_nurse", "attorney", "dentist", "filmmaker"]
class_indices_3 = class_indices_1 + class_indices_2

# The four Pythia trainers alternate between the two class splits; the Gemma SAE gets every class.
pythia_class_splits = (class_indices_1, class_indices_2, class_indices_1, class_indices_2)
combinations = [
    {"sae": pythia_ae_paths[trainer_pos], "class_indices": split}
    for trainer_pos, split in enumerate(pythia_class_splits)
]
combinations.append({"sae": gemma_ae_paths[0], "class_indices": class_indices_3})

def generate_combinations(sampled_indices, combinations):
    """Expand SAE/class-list pairings into one job dict per (class, sampled index).

    Args:
        sampled_indices: values paired with every class of every combination.
        combinations: iterable of dicts with "sae" and "class_indices" keys.

    Returns:
        A flat list of {"sae", "class_index", "sampled_index"} dicts, ordered by
        combination, then class, then sampled index.
    """
    expanded = []
    for entry in combinations:
        sae_path, classes = entry["sae"], entry["class_indices"]
        expanded.extend(
            {"sae": sae_path, "class_index": cls, "sampled_index": idx}
            for cls in classes
            for idx in sampled_indices
        )
    return expanded

# Flatten the pairings into the ordered work list walked by current_idx.
combinations_list = generate_combinations(sampled_indices, combinations)
n_combinations = len(combinations_list)
print(f"There are {n_combinations} combinations to evaluate.")

There are 96 combinations to evaluate.


In [103]:
# Resolve each sweep's base model and load its tokenizer (used to decode token ids).
pythia_model_eval_config, gemma_model_eval_config = (
    utils.ModelEvalConfig.from_sweep_name(sweep) for sweep in (pythia_sweep_name, gemma_sweep_name)
)
pythia_model_name = pythia_model_eval_config.full_model_name
gemma_model_name = gemma_model_eval_config.full_model_name

pythia_tokenizer, gemma_tokenizer = (
    AutoTokenizer.from_pretrained(model_name) for model_name in (pythia_model_name, gemma_model_name)
)

In [104]:
# Load the cached max-activating inputs for one example SAE (the first Pythia trainer).

k_inputs_per_feature = 10  # number of max-activating inputs per feature rendered later
example_ae_path = pythia_ae_paths[0]

max_inputs_file = os.path.join(example_ae_path, "max_activating_inputs.pkl")
# NOTE: pickle.load can execute arbitrary code; only load files produced by this project.
with open(max_inputs_file, "rb") as f:
    max_activating_inputs = pickle.load(f)

max_token_idxs_FKL, max_activations_FKL, top_dla_token_idxs_FK = (
    max_activating_inputs[key]
    for key in ("max_tokens_FKL", "max_activations_FKL", "dla_results_FK")
)

  return torch.load(io.BytesIO(b))


### Select class

In [105]:
### Find features relevant to a profession from the class_probe node_effects.pkl
# Picks the k features with the largest node effect for `class_name` and slices the
# cached max-activating tokens / activations / DLA token ids down to those features.

class_name = "male / female"
k_features_per_concept = 3

filename_counter = ""
class_id = profession_dict[class_name]
node_effects_filename = f"{example_ae_path}/node_effects{filename_counter}.pkl"

with open(node_effects_filename, "rb") as f:
    node_effects = pickle.load(f)

effects = node_effects[class_id]
print(effects.shape)

top_k_values, top_k_indices = t.topk(effects, k_features_per_concept)
t.set_printoptions(sci_mode=False)
print(top_k_values)
print(top_k_indices)

selected_token_idxs_FKL = max_token_idxs_FKL[top_k_indices]
selected_activations_FKL = max_activations_FKL[top_k_indices]
# Slice from the untouched cached array rather than from top_dla_token_idxs_FK itself:
# the old self-assignment made this cell non-idempotent (a re-run indexed the already
# sliced (k, K) tensor with feature ids and raised IndexError).
top_dla_token_idxs_FK = max_activating_inputs["dla_results_FK"][top_k_indices]

torch.Size([16384])
tensor([1.5156, 0.2812, 0.1514], dtype=torch.bfloat16)
tensor([15482, 11871,  2419])


In [106]:
# Format max_activating_inputs by << emphasizing >> the top-activating tokens of each example.

num_top_emphasized_tokens = 5

print(top_dla_token_idxs_FK.shape)

example_prompts = format_examples(pythia_tokenizer, selected_token_idxs_FKL, selected_activations_FKL, num_top_emphasized_tokens)

# Decode the top DLA token ids of every selected feature, then show the first feature's.
tokens_list = utils.list_decode(top_dla_token_idxs_FK, pythia_tokenizer)
tokens = tokens_list[0]

print(",".join(tokens))



torch.Size([3, 10])
 her, herself, she, hers, She,She,she, Her, pregnancy, woman


## MANUAL LABELING BEGINS HERE

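Run the *Manual inputs creation* cell to view the top DLA tokens and the top 4 max-activating example prompts for the current combination (each run advances `current_idx` by 2).

Then fill in `per_class_scores` and `chain_of_thought` in the *Label adding* cell and run it to record the label for that combination.

Once all combinations are labelled, run the save cell to write `manual_labels.json`.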

In [161]:
manual_labels = dict()  # combination index -> label dict; re-running this cell discards existing labels

In [164]:
# Labelling session state: start from the first combination.
current_idx = 0

# Prompts shown per feature, and how many tokens get << emphasized >> in each prompt.
displayed_prompts, num_top_emphasized_tokens = 4, 5
t.set_printoptions(sci_mode=False)

In [165]:
# Manual inputs creation
# Displays the SAE feature for combinations_list[current_idx]: its top DLA tokens and
# the first `displayed_prompts` max-activating example prompts. It then advances
# current_idx, so re-running this cell moves on to a later combination;
# `previous_idx` keeps the index that was just displayed.

current_combination = combinations_list[current_idx]
ae_path = current_combination["sae"]
class_index = current_combination["class_index"]
sampled_index = current_combination["sampled_index"]

# profession_dict maps single professions to an int id and combined concepts to their own name.
class_id = profession_dict[class_index]

# Choose the tokenizer matching the SAE's base model (inferred from the path).
if "pythia" in ae_path:
    tokenizer = pythia_tokenizer
else:
    tokenizer = gemma_tokenizer

print(f"current idx: {current_idx}")
print(f"SAE path: {ae_path}")
print(f"class index: {class_index}")
print(f"sampled index: {sampled_index}")

inputs_path = os.path.join(ae_path, "max_activating_inputs.pkl")
node_effects_path = os.path.join(ae_path, "node_effects.pkl")

with open(inputs_path, "rb") as f:
    max_activating_inputs = pickle.load(f)

with open(node_effects_path, "rb") as f:
    node_effects = pickle.load(f)

max_token_idxs_FKL = max_activating_inputs["max_tokens_FKL"]
max_activations_FKL = max_activating_inputs["max_activations_FKL"]
top_dla_token_idxs_FK = max_activating_inputs["dla_results_FK"]


effects = node_effects[class_id]

# Rank features by their effect on this class; sampled_index selects a rank from this
# list, so it must be < 2000 (and the SAE must have at least 2000 features).
top_k_values, top_k_indices = t.topk(effects, 2000)

sae_feat_index = top_k_indices[sampled_index]

# Keep the first `displayed_prompts` examples of the chosen feature and add a leading
# axis of size 1 (presumably the (F, K, L) layout format_examples expects — TODO confirm).
selected_token_idxs_1KL = max_token_idxs_FKL[sae_feat_index, :displayed_prompts, :].unsqueeze(0)
selected_activations_1KL = max_activations_FKL[sae_feat_index, :displayed_prompts, :].unsqueeze(0)
top_dla_token_idxs_K = top_dla_token_idxs_FK[sae_feat_index]

example_prompts = format_examples(tokenizer, selected_token_idxs_1KL, selected_activations_1KL, num_top_emphasized_tokens)
tokens_list = utils.list_decode(top_dla_token_idxs_K, tokenizer)
tokens_string = ",".join(tokens_list)

print(tokens_string)
print(example_prompts[0])

# Remember the displayed index for the label cell, then advance.
previous_idx = current_idx
# Step of 2: only every other combination is visited in this pass (TODO confirm intended).
current_idx += 2

current idx: 0
SAE path: ../dictionary_learning/dictionaries/autointerp_test_data/pythia70m_sweep_topk_ctx128_0730/resid_post_layer_3/trainer_2
class index: male / female
sampled index: 0
 her, herself, she, hers, She,She,she, Her, pregnancy, woman



Example 1: the broker and said yes.

She << told>> her younger sister she was going to America for work, << but>> << to>> << keep>> it a secret from her parents, who would never grant her permission to work abroad. You Mi told her parents she was going to Seoul to be a golf caddy -- one of the few legal women's jobs that bring hefty tips from rich men.

She planned to << tell>> them the truth after she paid off her debts.

You Mi was instructed to take passport photos and give them to a man named Kevin in Seoul. The broker drove her to the city, and two days later, You Mi had




Example 2: (everyone was wearing hop-bought clothes. What is the world coming to?). I thought she'd << lost>> << when>> an even littler girl appeared dressed as 

  return torch.load(io.BytesIO(b))


In [168]:
# Label adding
# Fill in per_class_scores and chain_of_thought for the combination displayed by the
# manual inputs creation cell, then run this cell to record the label.

per_class_scores = {
    "male / female": 4,
    "professor / nurse": 0,
    "male_professor / female_nurse": 0,
    "biased_male / biased_female": 0,
    "accountant": 1,
    "architect": 0,
    "attorney": 0,
    "dentist": 0,
    "filmmaker": 0,
}

chain_of_thought = """The top promoted logits are female pronouns and female pronouns are in every sentence, thus this is a 4 for male / female.
Two sentences mention companies and brokers, so this is a 1 for accountant. No other classes appear to be relevant."""

# Guard against typos / missing classes when editing the scores by hand.
assert set(per_class_scores) == set(chosen_class_indices), "per_class_scores must score exactly the chosen classes"

new_label = {
    "sae": ae_path,
    "class_index": class_index,
    "sampled_index": sampled_index,
    "sae_feat_index": sae_feat_index.item(),
    "example_prompts": example_prompts,
    "tokens_string": tokens_string,
    "per_class_scores": per_class_scores,
    "chain_of_thought": chain_of_thought,
}

# Key by previous_idx: the inputs cell has already advanced current_idx past the
# displayed combination, so keying by current_idx filed the label under the wrong index.
manual_labels[previous_idx] = new_label

print(f"Add the following idx to manual_labels: {previous_idx}")

Add the following idx to manual_labels: 0


In [169]:
import json

# Persist every label collected so far (json writes the int keys as strings).
labels_path = "manual_labels.json"
with open(labels_path, "w") as f:
    json.dump(manual_labels, f)

### Show prompt

In [26]:
# Build the autointerp system prompt and one user prompt per feature (currently disabled).
# NOTE(review): `top_dla_tokens_FK` is not defined anywhere in this notebook; if
# re-enabling, decode `top_dla_token_idxs_FK` first (e.g. with utils.list_decode) —
# TODO confirm the type build_prompt expects for `top_logits`.
# system_prompt = build_system_prompt(
#     cot=False,
#     concepts=list(profession_dict.keys()),
#     logits=True
# )

# prompts = []
# for example_prompt, top_dla_tokens_K in zip(example_prompts, top_dla_tokens_FK):
#     message = build_prompt(
#         examples=example_prompt,
#         cot=False,
#         top_logits=top_dla_tokens_K,
#     )
#     prompts.append(message)

In [27]:
# print(system_prompt[0]['text'])

You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. Your task is to analyze the neuron and decide whether its behavior is related to a concept for each concept in (accountant, architect, attorney, chiropractor, comedian, composer, dentist, dietitian, dj, filmmaker, interior_designer, journalist, model, nurse, painter, paralegal, pastor, personal_trainer, photographer, physician, poet, professor, psychologist, rapper, software_engineer, surgeon, teacher, yoga_teacher, male / female, professor / nurse, male_professor / female_nurse, biased_male / biased_female).

(Part 2) Tokens that the neuron boosts in the next token prediction

You will also be shown a list called Top_logits. The logits promoted by the neuron shed light on how the neuron's activation influences the model's predictions or outputs. Look at this list of Top_logits and refine your hypotheses from part 1. It is possible that this list is more informative than 

In [14]:
# Print max-activating examples and top DLA tokens for the top `k_features_per_concept`
# features of the most recently loaded SAE/class.
# Everything is re-derived from the full per-feature arrays in `max_activating_inputs`,
# so this cell no longer reads `top_dla_tokens_FK` (never defined in this notebook) and
# no longer pairs `example_prompts` with feature ids coming from a different cell.
# NOTE: assumes top-down execution, so `tokenizer`, `top_k_indices` and
# `max_activating_inputs` all come from the manual inputs creation cell (same SAE).
shown_feat_idxs = top_k_indices[:k_features_per_concept]
shown_prompts = format_examples(
    tokenizer,
    max_activating_inputs["max_tokens_FKL"][shown_feat_idxs],
    max_activating_inputs["max_activations_FKL"][shown_feat_idxs],
    num_top_emphasized_tokens,
)
for feat_idx, feat_examples in zip(shown_feat_idxs, shown_prompts):
    dla_tokens = utils.list_decode(max_activating_inputs["dla_results_FK"][feat_idx], tokenizer)
    print(f'############### Max act examples for Feature {feat_idx}:')
    print(feat_examples)
    print(f'############### Top dla tokens for Feature {feat_idx}:')
    print(",".join(dla_tokens))
    print('\n\n\n')

############### Max act examples for Feature 15482:
Example 1: the broker and said yes.

She << told>> her younger sister she was going to America for work, << but>> << to>> << keep>> it a secret from her parents, who would never grant her permission to work abroad. You Mi told her parents she was going to Seoul to be a golf caddy -- one of the few legal women's jobs that bring hefty tips from rich men.

She planned to << tell>> them the truth after she paid off her debts.

You Mi was instructed to take passport photos and give them to a man named Kevin in Seoul. The broker drove her to the city, and two days later, You Mi had

Example 2: (everyone was wearing hop-bought clothes. What is the world coming to?). I thought she'd << lost>> << when>> an even littler girl appeared dressed as a bumble bee - but she hung << onto>> mummy the whole time, and wouldn't talk or look at the audience, << whereas>> LMD went up on her own, yelled her age down the microphone, and roundly << told>> off a

### Render max-activating examples for human labelling (circuitsvis)

In [15]:
# Or alternatively view with circuitvis
def _list_decode(x):
    if len(x.shape) == 0:
        return pythia_tokenizer.decode(x, skip_special_tokens=True)
    else:
        return [_list_decode(y) for y in x]

# Render each selected feature's max-activating examples with circuitsvis.
selected_token_strs_FKL = _list_decode(selected_token_idxs_FKL)
# Use the real number of examples per feature instead of the hard-coded
# k_inputs_per_feature, which raised IndexError (K < 10) or dropped activations
# while keeping all K token lists (K > 10).
n_examples_per_feature = selected_activations_FKL.shape[1]
for i in range(selected_activations_FKL.shape[0]):
    # NOTE(review): top_k_indices is overwritten by the manual-labelling cells; on a
    # top-down run this label may not be the feature selected_*_FKL were built from.
    feat_idx = top_k_indices[i]
    print(f"Feature {feat_idx}:")
    # text_neuron_activations takes per-example activations shaped (tokens, layers, neurons).
    selected_activations_KL11 = [selected_activations_FKL[i, k, :, None, None] for k in range(n_examples_per_feature)]
    html_activations = text_neuron_activations(selected_token_strs_FKL[i], selected_activations_KL11)
    display(html_activations)



Feature 15482:


Feature 11871:


Feature 2419:
