# Qualitative assessment of Autointerp for Targeted Probe Perturbation metric
For example: An SAE with L0 500 has good performance before auto-interp, and poor performance after. What features are being rejected by the autointerp?

Notes:

- Probe to perturbate is trained on layer 24 / 26 resid post

In [1]:
import experiments.utils as utils
import pickle
import os
import experiments.autointerp as autointerp
from nnsight import LanguageModel
from experiments.pipeline_config import PipelineConfig
import torch as t

from experiments.bib_intervention import select_features
from experiments.pipeline_config import FeatureSelection

In [2]:
# Define dictionaries

DICTIONARIES_PATH = "../dictionary_learning/dictionaries/gemma-2-2b-saved-data"
trainer_ids = [5]
ae_sweep_paths = {
    # "gemma-2-2b_sweep_jumprelu_0902_probe_layer24_results": {
    #     "resid_post_layer_11": {"trainer_ids": trainer_ids},
    # },
    # "gemma-2-2b_sweep_standard_ctx128_ef8_0824_probe_layer24_results": {
    #     "resid_post_layer_11": {"trainer_ids": trainer_ids},
    # },
    "gemma-2-2b_sweep_topk_ctx128_ef8_0824_probe_layer24_results": {
        "resid_post_layer_11": {"trainer_ids": trainer_ids},
    },
    # "gemma-2-2b_sweep_standard_ctx128_ef2_0824_probe_layer_24_results": {
    #     "resid_post_layer_11": {"trainer_ids": trainer_ids},
    # },
    # "gemma-2-2b_sweep_topk_ctx128_ef2_0824_probe_layer_24_results": {
    #     "resid_post_layer_11": {"trainer_ids": trainer_ids},
    # },
}

sweep_name = list(ae_sweep_paths.keys())[0]
submodule_trainers = ae_sweep_paths[sweep_name]

ae_paths = []
for sweep_name, submodule_trainers in ae_sweep_paths.items():

    ae_group_paths = utils.get_ae_group_paths(
        DICTIONARIES_PATH, sweep_name, submodule_trainers
    )
    ae_paths.extend(utils.get_ae_paths(ae_group_paths))

print(f'available paths: {ae_paths}\n')
ae_path = ae_paths[0]
print(f'Selecting path: {ae_path}')

available paths: ['../dictionary_learning/dictionaries/gemma-2-2b-saved-data/gemma-2-2b_sweep_topk_ctx128_ef8_0824_probe_layer24_results/resid_post_layer_11/trainer_5']

Selecting path: ../dictionary_learning/dictionaries/gemma-2-2b-saved-data/gemma-2-2b_sweep_topk_ctx128_ef8_0824_probe_layer24_results/resid_post_layer_11/trainer_5


In [3]:
# # Compute max activating examples, if they haven't been computed yet

# p_config = PipelineConfig()
# model_eval_config = utils.ModelEvalConfig.from_sweep_name(sweep_name)
# model_name = model_eval_config.full_model_name

# llm_batch_size, patching_batch_size, eval_results_batch_size = utils.get_batch_sizes(
#     model_eval_config,
#     p_config.reduced_GPU_memory,
#     p_config.train_set_size,
#     p_config.test_set_size,
#     p_config.probe_train_set_size,
#     p_config.probe_test_set_size,
# )

# model = LanguageModel(
#     model_name,
#     device_map=p_config.device,
#     dispatch=True,
#     attn_implementation="eager",
#     torch_dtype=p_config.model_dtype,
# )

# autointerp.get_autointerp_inputs_for_all_saes(
#         model,
#         p_config.max_activations_collection_n_inputs,
#         llm_batch_size,
#         p_config.autointerp_context_length,
#         p_config.top_k_inputs_act_collect,
#         ae_paths,
#         force_rerun=p_config.force_max_activations_recompute,
#     )

# # Load max activations
# with open(os.path.join(ae_path, "max_activating_inputs.pkl"), "rb") as f:
#     max_activating_inputs = pickle.load(f)

In [4]:
# # Load class_accuracies after ablation
# def load_class_accuracies(ae_path: str, scoring_method: str):
#     assert scoring_method in ['attrib', 'auto_interp']

#     filename = f"{ae_path}/class_accuracies_{scoring_method}.pkl"
#     with open(filename, "rb") as f:
#             class_accuracies = pickle.load(f)
#     return class_accuracies

# probe_acc_attrib = load_class_accuracies(ae_path, 'attrib')
# probe_autointerp = load_class_accuracies(ae_path, 'auto_interp')

In [5]:
# Load importance scores
with open(os.path.join(ae_path, "node_effects.pkl"), "rb") as f:
    scores_attrib = pickle.load(f)

with open(os.path.join(ae_path, "node_effects_auto_interp.pkl"), "rb") as f:
    scores_autointerp = pickle.load(f)

In [6]:
# Load top features from importance score files

selection_method = FeatureSelection.top_n
num_top_features_from_attrib = 20 # aka threshold
class_indices = [0, 1, 2, 6] # Effectively filtering out spurious correlation classes, these are the only TPP available
dict_size = next(iter(scores_attrib.values())).shape[0]

top_latent_tensor_attrib = select_features(
    selection_method=selection_method,
    node_effects=scores_attrib,
    T_effects=[num_top_features_from_attrib],
    T_max_sideeffect=None,
    dict_size=dict_size,
)
top_latent_tensor_autointerp = select_features(
    selection_method=selection_method,
    node_effects=scores_autointerp,
    T_effects=[num_top_features_from_attrib],
    T_max_sideeffect=None,
    dict_size=dict_size,
)

# Reformat indices
top_latent_indices_attrib = {
    k: v.nonzero().squeeze()
    for k, v in top_latent_tensor_attrib[num_top_features_from_attrib].items()
    if k in class_indices
}
top_latent_indices_autointerp = {
    k: v.nonzero().squeeze()
    for k, v in top_latent_tensor_autointerp[num_top_features_from_attrib].items()
    if k in class_indices
}
top_latent_indices_not_autointerp = {
    k: v[t.isin(v, top_latent_indices_autointerp[k], invert=True)]
    for k, v in top_latent_indices_attrib.items()
    if k in class_indices
}

for k in class_indices:
    print()
    print(f'Class {k}')
    print(f'Attrib: {len(top_latent_indices_attrib[k])}, {top_latent_indices_attrib[k]}')
    print(f'AutoInterp: {len(top_latent_indices_autointerp[k])}, {top_latent_indices_autointerp[k]}')
    print(f'AutoInterp: {len(top_latent_indices_autointerp[k])}, {top_latent_tensor_autointerp[num_top_features_from_attrib][k][top_latent_indices_autointerp[k]]}')
    print(f'Attrib not AutoInterp: {len(top_latent_indices_not_autointerp[k])}, {(top_latent_indices_not_autointerp[k])}')


Class 0
Attrib: 20, tensor([  394,   453,   994,  1834,  1849,  2075,  3650,  4058,  5205,  6082,
        10010, 10446, 10627, 12031, 12424, 12954, 13107, 14939, 15263, 17170])
AutoInterp: 0, tensor([], dtype=torch.int64)
AutoInterp: 0, tensor([], dtype=torch.bool)
Attrib not AutoInterp: 20, tensor([  394,   453,   994,  1834,  1849,  2075,  3650,  4058,  5205,  6082,
        10010, 10446, 10627, 12031, 12424, 12954, 13107, 14939, 15263, 17170])

Class 1
Attrib: 20, tensor([  377,   453,  1381,  1849,  2075,  4554,  5163,  5733,  6820,  8040,
         8381, 12424, 12954, 12996, 13221, 13346, 13762, 14154, 16645, 17307])
AutoInterp: 4, tensor([    1,  6082, 12424, 16645])
AutoInterp: 4, tensor([True, True, True, True])
Attrib not AutoInterp: 18, tensor([  377,   453,  1381,  1849,  2075,  4554,  5163,  5733,  6820,  8040,
         8381, 12954, 12996, 13221, 13346, 13762, 14154, 17307])

Class 2
Attrib: 20, tensor([  377,   453,   577,  2075,  3648,  3650,  4058,  5826,  6864,  7064,
  

In [None]:
# Get max act examples per feature

# Show max act of autointerp included / rejected features

In [None]:
# Why are ablation differences so small?
# Attribution patching is a bad approximation and misses indirectly relevant features? Do acutal patching instead?
# SAE latents are correlated? Patching a single latent will not have a large effect?