In [1]:
import os
import pickle
import json
from transformers import AutoTokenizer
from experiments.pipeline_config import PipelineConfig
from experiments.llm_autointerp.llm_query import perform_llm_autointerp
from typing import Dict

%load_ext autoreload
%autoreload 2

First, run autointerp from a shell:
```bash
cd experiments
python llm_autointerp/run_autointerp_can.py
```

In [2]:
# TODO: add a run.sh wrapper so llm_query can be launched with custom args from this notebook.
# Until then, the hyperparameters have to be copied here by hand.
# Two levels above the current working directory (assumed to be the repository root).
repo_dir = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
# Absolute path of the autoencoder directory whose autointerp results are inspected below.
ae_path = os.path.abspath(
    os.path.join(
        repo_dir,
        "dictionary_learning/dictionaries/autointerp_test_data/pythia70m_sweep_topk_ctx128_0730/resid_post_layer_3/trainer_2",
    )
)

In [3]:
# Raw LLM responses; per the pipeline's convention they are only saved when the
# LLM query is run with DEBUG=True.
with open(os.path.join(ae_path, "llm_debug_results.json")) as debug_file:
    llm_debug_results = json.loads(debug_file.read())

# NOTE: unpickling can execute arbitrary code. Only load node-effect files that
# were produced by your own pipeline runs.
with open(os.path.join(ae_path, "node_effects.pkl"), "rb") as attrib_file:
    node_effects_classprobe = pickle.loads(attrib_file.read())

with open(os.path.join(ae_path, "node_effects_auto_interp.pkl"), "rb") as autointerp_file:
    node_effects_autointerp = pickle.loads(autointerp_file.read())

In [4]:
# Print every attribute stored on a default-constructed PipelineConfig, one per line.
cfg = PipelineConfig()
for field_name, field_value in vars(cfg).items():
    print(f"{field_name}: {field_value}")

max_activations_collection_n_inputs: 5000
top_k_inputs_act_collect: 5
probe_train_set_size: 4000
probe_test_set_size: 500
train_set_size: 100
test_set_size: 200
eval_saes_n_inputs: 250
probe_batch_size: 200
probe_epochs: 10
model_dtype: torch.bfloat16
reduced_GPU_memory: False
include_gender: True
use_autointerp: True
force_eval_results_recompute: False
force_max_activations_recompute: False
force_probe_recompute: False
force_node_effects_recompute: False
force_autointerp_recompute: False
dictionaries_path: ../dictionary_learning/dictionaries
probes_dir: trained_bib_probes
attribution_patching_method: attrib
ig_steps: 10
api_llm: claude-3-5-sonnet-20240620
prompt_dir: llm_autointerp/
node_effects_attrib_filename: node_effects.pkl
autointerp_filename: node_effects_auto_interp.pkl
bias_shift_dir1_filename: node_effects_bias_shift_dir1.pkl
bias_shift_dir2_filename: node_effects_bias_shift_dir2.pkl
num_top_emphasized_tokens: 5
num_top_inputs_autointerp: 5
num_top_features_per_class: 20
llm

In [5]:
# Inspect the auto-interp node effects loaded from node_effects_auto_interp.pkl:
# a dict of tensors keyed by both strings and integers (see the output below).
node_effects_autointerp

{'male / female': tensor([-0., 0., 0.,  ..., -0., -0., 0.]),
 'professor / nurse': tensor([0., 0., 0.,  ..., -0., -0., -0.]),
 'male_professor / female_nurse': tensor([0., 0., 0.,  ..., -0., 0., -0.]),
 'biased_male / biased_female': tensor([0., 0., 0.,  ..., -0., -0., -0.]),
 0: tensor([-0., 0., -0.,  ..., -0., 0., 0.]),
 1: tensor([-0., -0., 0.,  ..., 0., 0., 0.]),
 2: tensor([-0., 0., 0.,  ..., -0., 0., 0.]),
 6: tensor([0., 0., 0.,  ..., -0., 0., 0.]),
 9: tensor([0., 0., -0.,  ..., -0., -0., 0.])}

In [6]:
# Print the cfg.num_top_features_per_class largest auto-interp node effects for each chosen class
from experiments.utils_bib_dataset import profession_dict
import torch as t


for cls in cfg.chosen_autointerp_class_names:
    # These class names are skipped and never looked up in profession_dict.
    if cls in ('gender', 'professor', 'nurse'):
        continue
    cls_idx = profession_dict[cls]

    # topk returns (values, indices), sorted from largest to smallest value.
    top_feature_vals, top_feature_idxs = t.topk(node_effects_autointerp[cls_idx], cfg.num_top_features_per_class)
    print(f"{cls}:")
    # One line per feature: "<feature index> (<node effect>)".
    for feat_idx, feat_val in zip(top_feature_idxs.tolist(), top_feature_vals.tolist()):
        print(f'{feat_idx} ({feat_val})')

accountant:
18 (-0.0)
10 (-0.0)
4 (0.0)
16 (-0.0)
12 (-0.0)
8 (0.0)
19 (-0.0)
1 (0.0)
11 (0.0)
5 (0.0)
15 (0.0)
13 (0.0)
7 (0.0)
17 (0.0)
3 (0.0)
9 (0.0)
0 (-0.0)
2 (-0.0)
6 (0.0)
14 (-0.0)
architect:
18 (-0.0)
10 (0.0)
4 (0.0)
16 (-0.0)
12 (0.0)
8 (0.0)
19 (0.0)
1 (-0.0)
11 (0.0)
5 (0.0)
15 (-0.0)
13 (0.0)
7 (0.0)
17 (0.0)
3 (0.0)
9 (0.0)
0 (-0.0)
2 (0.0)
6 (0.0)
14 (-0.0)
attorney:
18 (-0.0)
10 (0.0)
4 (0.0)
16 (-0.0)
12 (-0.0)
8 (-0.0)
19 (-0.0)
1 (0.0)
11 (0.0)
5 (0.0)
15 (-0.0)
13 (0.0)
7 (0.0)
17 (0.0)
3 (-0.0)
9 (0.0)
0 (-0.0)
2 (0.0)
6 (0.0)
14 (0.0)
dentist:
18 (-0.0)
10 (0.0)
4 (-0.0)
16 (-0.0)
12 (0.0)
8 (0.0)
19 (0.0)
1 (0.0)
11 (0.0)
5 (0.0)
15 (-0.0)
13 (0.0)
7 (0.0)
17 (0.0)
3 (-0.0)
9 (-0.0)
0 (0.0)
2 (0.0)
6 (0.0)
14 (0.0)
filmmaker:
12700 (2.15625)
6339 (0.000537872314453125)
18 (0.0)
4 (0.0)
10 (0.0)
16 (0.0)
8 (-0.0)
19 (0.0)
1 (0.0)
11 (0.0)
5 (0.0)
12 (0.0)
15 (-0.0)
7 (0.0)
17 (0.0)
3 (0.0)
9 (-0.0)
13 (0.0)
0 (0.0)
2 (-0.0)


In [7]:
# Feature indices (stored as string keys) for which a raw LLM response was saved.
# NOTE(review): presumably the unique features among the top cfg.num_top_features_per_class of all classes — confirm.

llm_debug_results.keys()

dict_keys(['6339', '9743', '1512', '12700', '15482', '15917'])

In [8]:
from nnsight import LanguageModel

# Wrap the EleutherAI/pythia-70m-deduped checkpoint with nnsight; model.tokenizer is what
# the prompt-construction cell passes to construct_llm_features_prompts.
model = LanguageModel('EleutherAI/pythia-70m-deduped')

In [10]:
from experiments.llm_autointerp.llm_query import construct_llm_features_prompts
import experiments.llm_autointerp.llm_utils as llm_utils

# Build the LLM prompts for the autoencoder stored at ae_path.
prompts = construct_llm_features_prompts(ae_path, model.tokenizer, cfg)

# Token count of each prompt, stored under the same key as in `prompts`.
prompts_num_tokens = {
    prompt_key: llm_utils.count_tokens(prompt_text)
    for prompt_key, prompt_text in prompts.items()
}

dict_keys(['max_tokens_FKL', 'max_activations_FKL', 'dla_results_FK'])


Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 364.34it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 376.68it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 381.86it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 375.06it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 259.98it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 384.62it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 386.96it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 384.27it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 372.43it/s]
Formatting examples:   0%|          | 0/1 [00:00<?, ?it/s]

Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 371.05it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 373.03it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 372.13it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 378.04it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 386.18it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 377.83it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 385.36it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 371.37it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 378.10it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 380.26it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 370.49it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 364.34it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 366.76it/s]
Formatting examples: 100%|██████████| 1/1 [00:00<00:00, 353.32it/s]
Formatting examples: 100%|██████████| 1/1 [00:00

In [12]:
# Token allowance added to every prompt's count to account for the system prompt.
system_prompt_num_tokens = 3_000

# Recompute from `prompts` instead of incrementing the existing dict in place:
# the previous form added the allowance again on every re-run of this cell.
prompts_num_tokens = {
    k: llm_utils.count_tokens(v) + system_prompt_num_tokens for k, v in prompts.items()
}


In [13]:
# Per-minute rate limits of the auto-interp LLM API.
autointerp_api_total_token_per_minute_limit = 400_000
autointerp_api_total_requests_per_minute_limit = 4_000

# Each batch is budgeted at half of the corresponding per-minute limit.
num_allowed_tokens_per_minute = 0.5 * autointerp_api_total_token_per_minute_limit
num_allowed_requests_per_minute = 0.5 * autointerp_api_total_requests_per_minute_limit

def get_prompt_batch_indices(prompts: Dict[str, str], p_config: "PipelineConfig", count_tokens=None):
    '''Group prompt keys into batches that respect the per-minute token and request budgets.

    Prompts are visited in dictionary order. A new batch is started whenever adding the
    next prompt would push the running token total above `num_allowed_tokens_per_minute`
    or the current batch already holds `num_allowed_requests_per_minute` prompts.

    Args:
        prompts: mapping from prompt key (feature index) to prompt text.
        p_config: pipeline config; `num_tokens_system_prompt` is added to every prompt's
            token count and must not be None.
        count_tokens: optional callable mapping a prompt to its token count.
            Defaults to `llm_utils.count_tokens`.

    Returns:
        A list of batches, each a non-empty list of keys from `prompts`. An empty
        `prompts` yields an empty list. A single prompt that exceeds the token budget
        on its own cannot be split and is placed in a batch by itself.

    Raises:
        AssertionError: if `p_config.num_tokens_system_prompt` is None.
    '''
    assert p_config.num_tokens_system_prompt is not None, "num_tokens_system_prompt must be set in the config during the pipeline"
    if count_tokens is None:
        count_tokens = llm_utils.count_tokens

    running_token_count = 0
    running_feat_idx_batch = []
    api_call_feat_idx_batches = []
    for feat_idx, prompt in prompts.items():
        num_tokens = count_tokens(prompt) + p_config.num_tokens_system_prompt
        # `>=` (not `>`): a batch must never hold more requests than the budget allows.
        batch_is_full = (
            len(running_feat_idx_batch) >= num_allowed_requests_per_minute
            or running_token_count + num_tokens > num_allowed_tokens_per_minute
        )
        # Only close a batch that actually contains something; otherwise an oversized
        # first prompt would emit an empty batch.
        if batch_is_full and running_feat_idx_batch:
            api_call_feat_idx_batches.append(running_feat_idx_batch)
            running_feat_idx_batch = []
            running_token_count = 0
        running_feat_idx_batch.append(feat_idx)
        running_token_count += num_tokens

    # Flush the trailing batch, skipping it when there were no prompts at all.
    if running_feat_idx_batch:
        api_call_feat_idx_batches.append(running_feat_idx_batch)
    return api_call_feat_idx_batches
