In [None]:
import os
import pickle
import json
from transformers import AutoTokenizer
from experiments.pipeline_config import PipelineConfig
from experiments.llm_autointerp.llm_query import perform_llm_autointerp, construct_llm_features_prompts
from typing import Dict

%load_ext autoreload
%autoreload 2

In [None]:
from nnsight import LanguageModel

# Load the language model; a later cell passes its tokenizer (`model.tokenizer`)
# to `construct_llm_features_prompts`.
model = LanguageModel('EleutherAI/pythia-70m-deduped')

First, run autointerp with the following shell commands:
```bash
cd experiments
python llm_autointerp/run_autointerp_can.py
```

In [None]:
# TODO create run.sh script for above to run llm_query with custom args from this notebook
# For now, manually copy hyperparameters here

# Location of the trained autoencoder, relative to the repository root.
ae_rel_path = "dictionary_learning/dictionaries/pythia70m_sweep_topk_ctx128_0730/resid_post_layer_3/trainer_10"

# The repository root is taken to be two directory levels above the current working directory.
repo_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Absolute path of the autoencoder directory; the result files are read from here.
ae_path = os.path.abspath(os.path.join(repo_dir, ae_rel_path))

In [None]:
# Load the stored autointerp results from the autoencoder directory (`ae_path`).
# To save the raw llm output when querying llm, set DEBUG=True

# Two variants of the node-effect files exist, each with a class-probe file and an
# autointerp file. Previously both variants were loaded one after the other into the
# same two names, so the first pair was silently overwritten (and still had to exist
# on disk). Select the variant explicitly instead.
# NOTE(review): the "attrib"/"dist" pairing follows the order of the original loads
# and the file names - confirm against the script that writes these files.
node_effects_files = {
    "attrib": ("node_effects.pkl", "node_effects_auto_interp_attrib_top20.pkl"),
    "dist": ("node_effects_dist_diff.pkl", "node_effects_auto_interp_dist_top20.pkl"),
}
node_effects_variant = "dist"  # the variant that previously won the overwrite
classprobe_file, autointerp_file = node_effects_files[node_effects_variant]

with open(os.path.join(ae_path, "raw_llm_outputs.json"), "r") as f:
    raw_llm_outputs = json.load(f)

with open(os.path.join(ae_path, "extracted_json_llm_outputs.json"), "r") as f:
    extracted_json_llm_outputs = json.load(f)

# SECURITY: pickle.load can execute arbitrary code while unpickling. Only load
# files that were produced by a trusted run of the pipeline.
with open(os.path.join(ae_path, classprobe_file), "rb") as f:
    node_effects_classprobe = pickle.load(f)

with open(os.path.join(ae_path, autointerp_file), "rb") as f:
    node_effects_autointerp = pickle.load(f)

with open(os.path.join(ae_path, "max_activating_inputs.pkl"), "rb") as f:
    max_activating_inputs = pickle.load(f)

In [None]:
# Build the pipeline config, override the autointerp-related settings, and print every field.
cfg = PipelineConfig()

cfg.prompt_dir = "llm_autointerp/"
cfg.force_autointerp_recompute = True
cfg.num_top_features_per_class = 1
cfg.chosen_autointerp_class_names = [
    "gender",
    "professor",
    "nurse",
    "accountant",
    "architect",
    "attorney",
    "dentist",
]

# One "name: value" line per attribute stored on the config instance.
for field_name, field_value in vars(cfg).items():
    print(f"{field_name}: {field_value}")

In [None]:
# Show which keys the loaded class-probe node-effects mapping contains.
node_effects_classprobe.keys()

In [None]:
# Display the loaded autointerp node effects in full for inspection.
node_effects_autointerp

In [None]:
# Count, across all classes, the entries whose autointerp effect is strictly positive.
# The sum starts from the integer 0 and adds one per-class count for each value
# in `node_effects_autointerp`.
total_interpretable = sum(
    (class_effects > 0).sum() for class_effects in node_effects_autointerp.values()
)

total_interpretable

In [None]:
# For each analysed class, take the `cfg.num_top_features_per_class` features with the
# largest class-probe effect and print them next to their autointerp effect and the
# rating extracted from the LLM output.
from experiments.utils_bib_dataset import profession_dict
import torch as t

# Classes in the config that this comparison skips.
skipped_classes = ('gender', 'professor', 'nurse')

# class name -> indices of its top class-probe features (reused by later cells)
top_feature_idxs_per_class = {}

for cls in cfg.chosen_autointerp_class_names:
    if cls in skipped_classes:
        continue
    cls_idx = profession_dict[cls]

    probe_vals, probe_idxs = t.topk(node_effects_classprobe[cls_idx], cfg.num_top_features_per_class)
    top_feature_idxs_per_class[cls] = probe_idxs
    autointerp_vals = node_effects_autointerp[cls_idx][probe_idxs]
    # The extracted outputs are keyed by the feature index as a string.
    extracted_jsons = [extracted_json_llm_outputs[str(feat_idx.item())] for feat_idx in probe_idxs]

    print(f"{cls}:")
    print('feature_idx, probe_effect, autointerp_effects_val, autointerp_extracted_json')
    for feat_idx, probe_val, autointerp_val, extracted in zip(
        probe_idxs, probe_vals, autointerp_vals, extracted_jsons
    ):
        # A feature may have no extracted JSON; print None in that column instead.
        rating = extracted[cls] if extracted is not None else None
        print(f'{feat_idx}, {probe_val}, {autointerp_val}, {rating}')

In [None]:
# Build the per-feature LLM prompts from the autoencoder directory, the model's tokenizer
# and the config. Later cells index the result by integer feature index.
# NOTE(review): presumably these match the prompts sent during the autointerp run -
# confirm against `llm_query.construct_llm_features_prompts`.
features_prompts = construct_llm_features_prompts(ae_path, model.tokenizer, cfg)

In [None]:
## Are the outputs that have been set to 0 really not related at all
# For every top class-probe feature whose autointerp effect for that class is exactly 0,
# print the extracted rating and the prompt so the case can be checked by hand.

for cls in cfg.chosen_autointerp_class_names:
    if cls in ('gender', 'professor', 'nurse'):
        continue
    for top_feature_idx in top_feature_idxs_per_class[cls]:
        idx = top_feature_idx.item()
        # Only features that autointerp scored with 0 for this class are of interest.
        if node_effects_autointerp[profession_dict[cls]][idx] != 0:
            continue

        print(f'feature_idx: {idx}, class: {cls}')
        print("(high effect on the classprobe of this class). Autointerp scored with 0 for this class. Here is the prompt to autointerp:\n")
        print(f'this is the rating from autointerp: {extracted_json_llm_outputs[str(idx)]}')
        # Everything except the last element of this feature's prompt is shown.
        print(features_prompts[idx][:-1])

        # print border
        print("\n--------------------------------------------------\n")

In [None]:
# Feat idx with obvious errors for further inspection
# Each entry is (feature index, class name); the prompt and the first raw LLM output
# of every listed feature are printed.

err_idxs = [(6036, 'architect'), (14889, 'attorney')]

for feature_idx, class_name in err_idxs:
    print(f'feature_idx: {feature_idx}, class: {class_name}')
    print(f'prompts: {features_prompts[feature_idx]}')
    # Raw outputs are keyed by the feature index as a string; show the first entry.
    print(f'raw_llm_output: {raw_llm_outputs[str(feature_idx)][0]}')

    print('\n\n___________________________________\n\n')

In [None]:
# Print the prompt of feature 14889 (listed as an 'attorney' error case above) on its own.
print(features_prompts[14889])