## Manual feature labelling and prompt building

### Setup

In [None]:
# Imports

import os
import torch as t
from nnsight import LanguageModel
import datasets
import anthropic
from tqdm import tqdm
import re
import ast
import pickle
from collections import defaultdict
from circuitsvis.activations import text_neuron_activations

import experiments.utils as utils
from experiments.autointerp import (
    get_max_activating_prompts, 
    highlight_top_activations,
    compute_dla, 
    format_examples,
    evaluate_binary_llm_output,
    get_autointerp_inputs_for_all_saes,
)
from experiments.explainers.simple.prompt_builder import build_prompt
from experiments.explainers.simple.prompts import build_system_prompt

# Debug switch: when on, tracing runs with both `scan` and `validate` enabled;
# when off, both are disabled.
DEBUGGING = True

# The two flags always move together, so derive them from the switch directly.
tracer_kwargs = {"scan": DEBUGGING, "validate": DEBUGGING}

%load_ext autoreload
%autoreload 2

In [None]:
# Load the language model

model_name = "EleutherAI/pythia-70m-deduped"
model_dtype = t.bfloat16
DEVICE = "cuda"

# Loaded on DEVICE with dispatch enabled, eager attention and bfloat16 weights.
model = LanguageModel(
    model_name,
    torch_dtype=model_dtype,
    attn_implementation="eager",
    dispatch=True,
    device_map=DEVICE,
)

# Output-embedding module, kept around for direct logit attribution.
model_unembed = model.embed_out

In [None]:
# Load data

num_contexts = 10000
context_length = 128
batch_size = 250

dataset = datasets.load_dataset("georgeyw/dsir-pile-100k", streaming=False)

# Slice the rows first and pick the column second: this only reads the first
# `num_contexts` rows, instead of materialising the whole "contents" column
# and then throwing most of it away.
context_texts = dataset["train"][:num_contexts]["contents"]

# Every context is padded / truncated to exactly `context_length` tokens, so
# all rows have the same length and can be stacked into batches.
tokenized = model.tokenizer(
    context_texts,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=context_length,
)
# `.data` takes the underlying mapping of the tokenizer output
# (presumably input_ids / attention_mask -- confirm against utils.batch_inputs).
data = tokenized.to(DEVICE).data
batched_data = utils.batch_inputs(data, batch_size)


# Class-specific data

# Maps class name -> class id. The profession ids are positional (0..27) in
# the order listed below, so the order of this tuple must not change.
# "profession" and "gender" are extra keys with negative ids (their meaning is
# defined by whatever produced node_effects.pkl -- not visible in this notebook).
profession_dict = {
    **{
        name: class_id
        for class_id, name in enumerate(
            (
                "accountant",
                "architect",
                "attorney",
                "chiropractor",
                "comedian",
                "composer",
                "dentist",
                "dietitian",
                "dj",
                "filmmaker",
                "interior_designer",
                "journalist",
                "model",
                "nurse",
                "painter",
                "paralegal",
                "pastor",
                "personal_trainer",
                "photographer",
                "physician",
                "poet",
                "professor",
                "psychologist",
                "rapper",
                "software_engineer",
                "surgeon",
                "teacher",
                "yoga_teacher",
            )
        )
    },
    "profession": -4,
    "gender": -2,
}

In [None]:
# Load the sparse-autoencoder dictionary

dictionaries_path = "../dictionary_learning/dictionaries"

# Sweep selection, in the same format as bib_intervention.py (entries can be
# copy-pasted from there). Only the first uncommented entry is used below.
ae_sweep_paths = {
    "pythia70m_sweep_standard_ctx128_0712": {"resid_post_layer_3": {"trainer_ids": [6]}},
    # "pythia70m_sweep_gated_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [9]}},
    # "pythia70m_sweep_topk_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [10]}},
    # "gemma-2-2b_sweep_topk_ctx128_0817": {"resid_post_layer_12": {"trainer_ids": [2]}}, 
}

# First (sweep name, {submodule: trainers}) pair of the mapping.
sweep_name, submodule_trainers = next(iter(ae_sweep_paths.items()))

ae_group_paths = utils.get_ae_group_paths(
    dictionaries_path,
    sweep_name,
    submodule_trainers,
)
ae_paths = utils.get_ae_paths(ae_group_paths)

# The rest of the notebook works with the first resolved autoencoder only.
ae_path = ae_paths[0]
submodule, dictionary, config = utils.load_dictionary(model, ae_path, DEVICE)

In [None]:
# Compute max-activating examples for every SAE in `ae_paths`
# (force_rerun=False is passed, so this is not forced to recompute).

k_inputs_per_feature = 10

get_autointerp_inputs_for_all_saes(
    model,
    ae_paths=ae_paths,
    n_inputs=num_contexts,
    context_length=context_length,
    batch_size=batch_size,
    top_k_inputs=k_inputs_per_feature,
    force_rerun=False,
)

# Read the results stored next to the selected autoencoder.
max_activating_inputs_path = os.path.join(ae_path, "max_activating_inputs.pkl")
with open(max_activating_inputs_path, "rb") as pkl_file:
    file = pickle.load(pkl_file)

# Suffixes name the axes: presumably F = feature, K = top example,
# L = token position (confirm against experiments.autointerp).
max_token_idxs_FKL = file["max_tokens_FKL"]
max_activations_FKL = file["max_activations_FKL"]
top_dla_token_idxs_FK = file["dla_results_FK"]

### Select class

In [None]:
### Find features relevant to a class from the class-probe node_effects.pkl

class_name = "gender"
k_features_per_concept = 3

# Optional suffix for picking an alternative node_effects file; empty = default.
filename_counter = ""
class_id = profession_dict[class_name]
node_effects_filename = f"{ae_path}/node_effects{filename_counter}.pkl"

# NOTE: pickle.load executes arbitrary code from the file -- only load
# node_effects files you produced yourself.
with open(node_effects_filename, "rb") as f:
    node_effects = pickle.load(f)

# Per-feature effects for the chosen class.
effects = node_effects[class_id]
print(effects.shape)


# Keep the k features with the largest effect on this class.
top_k_values, top_k_indices = t.topk(effects, k_features_per_concept)
t.set_printoptions(sci_mode=False)
print(top_k_values)
print(top_k_indices)

selected_token_idxs_FKL = max_token_idxs_FKL[top_k_indices]
selected_activations_FKL = max_activations_FKL[top_k_indices]
# Index the unfiltered DLA results from the loaded pickle rather than
# `top_dla_token_idxs_FK` itself: that name is overwritten here, so indexing
# it in place would filter an already-filtered tensor whenever this cell is
# re-run (e.g. after changing `class_name`), giving wrong rows or an IndexError.
top_dla_token_idxs_FK = file["dla_results_FK"][top_k_indices]

In [None]:
# Format the max-activating inputs, <<emphasizing>> the top activating tokens

num_top_emphasized_tokens = 5

example_prompts = format_examples(
    model,
    selected_token_idxs_FKL,
    selected_activations_FKL,
    num_top_emphasized_tokens,
)

# One decoded string per selected feature (special tokens dropped).
tokenizer = model.tokenizer
top_dla_tokens_FK = tokenizer.batch_decode(top_dla_token_idxs_FK, skip_special_tokens=True)

### Show prompt

In [None]:
# System prompt shared by every feature: no chain-of-thought, all class names
# offered as candidate concepts, logit information enabled.
system_prompt = build_system_prompt(
    cot=False,
    concepts=list(profession_dict),
    logits=True,
)

# One user prompt per selected feature, pairing its formatted examples with
# its decoded top-DLA tokens.
prompts = [
    build_prompt(examples=feature_examples, cot=False, top_logits=feature_dla_tokens)
    for feature_examples, feature_dla_tokens in zip(example_prompts, top_dla_tokens_FK)
]

In [None]:
# Show the 'text' field of the first system-prompt entry.
print(system_prompt[0]['text'])

In [None]:
# Dump, for each selected feature, its max-activating examples followed by its
# top-DLA tokens, with a few blank lines between features.
for feature_pos, feature_examples in enumerate(example_prompts):
    feature_id = top_k_indices[feature_pos]
    report_lines = (
        f'############### Max act examples for Feature {feature_id}:',
        feature_examples,
        f'############### Top dla tokens for Feature {feature_id}:',
        top_dla_tokens_FK[feature_pos],
        '\n\n\n',
    )
    print(*report_lines, sep='\n')

### Render prompt for human labelling

In [None]:
# Alternatively, view the examples with circuitsvis
def _list_decode(x):
    """Decode token ids into strings, mirroring the nesting of ``x``.

    A zero-dimensional ``x`` is decoded to a single string with
    ``model.tokenizer`` (special tokens are skipped). Anything with at least
    one dimension is walked along its first axis and returned as a list of
    the recursively decoded parts.
    """
    if len(x.shape) > 0:
        return [_list_decode(part) for part in x]
    return model.tokenizer.decode(x, skip_special_tokens=True)

# Decode every selected token id into its string, keeping the nested layout.
selected_token_strs_FKL = _list_decode(selected_token_idxs_FKL)

# Render one circuitsvis widget per selected feature.
for i, activations_KL in enumerate(selected_activations_FKL):
    feat_idx = top_k_indices[i]
    print(f"Feature {feat_idx}:")
    # One activation block per stored example, each with two trailing
    # singleton axes added. Iterate the tensor's own example axis instead of
    # `k_inputs_per_feature`: the pickle is loaded without a forced re-run, so
    # it may hold a different number of examples than the current setting.
    selected_activations_KL11 = [acts_L[:, None, None] for acts_L in activations_KL]
    html_activations = text_neuron_activations(selected_token_strs_FKL[i], selected_activations_KL11)
    display(html_activations)

