In [None]:
import random

from experiments.bib_intervention import run_interventions, FeatureSelection
import experiments.utils as utils

In [None]:
# ---- Experiment configuration -------------------------------------------------
# Feature-selection strategy; swap the commented line in to change strategy.
# selection_method = FeatureSelection.above_threshold
selection_method = FeatureSelection.top_n

# NOTE(review): the seed is drawn fresh on every run, so results are not
# reproducible across kernel restarts unless this value is recorded or pinned
# (e.g. replace with a fixed integer).
random_seed = random.randint(0, 1000)
num_classes = 5

probe_train_set_size = 5000
probe_test_set_size = 1000

# Load dataset and probes
train_set_size = 1000
test_set_size = 1000
probe_batch_size = 50
llm_batch_size = 125
# llm_batch_size = 10

# Attribution patching variables
n_eval_batches = 4
patching_batch_size = 50
# patching_batch_size = 5

# Candidate values for each selection method; the one matching
# `selection_method` becomes T_effects below.
# top_n_features = [5, 10, 20, 50, 100, 500]
top_n_features = [5, 500]
T_effects_all_classes = [0.1, 0.01, 0.005, 0.001]
# T_effects_all_classes = [0.001]
T_effects_unique_class = [1e-4, 1e-8]

if selection_method == FeatureSelection.top_n:
    T_effects = top_n_features
elif selection_method == FeatureSelection.above_threshold:
    T_effects = T_effects_all_classes
elif selection_method == FeatureSelection.unique:
    T_effects = T_effects_unique_class
else:
    raise ValueError("Invalid selection method")

# Passed straight through to run_interventions; presumably the largest tolerated
# effect on non-target classes — TODO confirm against run_interventions.
T_max_sideeffect = 5e-3

dictionaries_path = "../dictionary_learning/dictionaries"
probes_dir = "trained_bib_probes"

# Which SAEs to evaluate, as {sweep_name: submodule_trainers}. Other granularities:
#   all SAEs in a sweep:      {"pythia70m_test_sae": None}
#   all SAEs in a submodule:  {"pythia70m_test_sae": {"resid_post_layer_3": None}}
# Current choice: a single SAE (trainer 0 of one submodule).
ae_sweep_paths = {"pythia70m_test_sae": {"resid_post_layer_3": {"trainer_ids": [0]}}}

# Pre-flight pass: get_ae_group_paths raises an error if it finds an empty folder
# in any ae_path, so a bad path fails here before any intervention run starts.
# The returned ae_group_paths is not used further in this cell.
for sweep_name, submodule_trainers in ae_sweep_paths.items():
    ae_group_paths = utils.get_ae_group_paths(dictionaries_path, sweep_name, submodule_trainers)

# Run the intervention experiment once per sweep, using the configuration above.
for sweep_name, submodule_trainers in ae_sweep_paths.items():

    # NOTE(review): every argument is positional, so this order must match the
    # run_interventions signature exactly — re-check when editing either side.
    run_interventions(
        submodule_trainers,
        sweep_name,
        dictionaries_path,
        probes_dir,
        selection_method,
        probe_train_set_size,
        probe_test_set_size,
        train_set_size,
        test_set_size,
        probe_batch_size,
        llm_batch_size,
        n_eval_batches,
        patching_batch_size,
        T_effects,
        T_max_sideeffect,
        num_classes,
        random_seed,
    )

## Direct Probe Attribution

In [None]:
from nnsight import LanguageModel
import matplotlib.pyplot as plt
import pickle
from dictionary_learning.dictionary import AutoEncoder


# dictionaries_path = "../dictionary_learning/dictionaries"
# device = "cuda"


# submodule_trainers = {"resid_post_layer_3": {"trainer_ids": [10]}}
# sweep_name = "pythia70m_sweep0711"


# model_eval_config = utils.ModelEvalConfig.from_sweep_name(sweep_name)
# model_name = model_eval_config.full_model_name

# model = LanguageModel(model_name, device_map=device, dispatch=True)

# ae_group_paths = utils.get_ae_group_paths(dictionaries_path, sweep_name, submodule_trainers)
# ae_paths = utils.get_ae_paths(ae_group_paths)
# submodule, dictionary, config = utils.load_dictionary(model, ae_paths[0], device)
# sae_decoder = dictionary.decoder.weight

# Load Sam's SAE. The directory names suggest a pythia-70m-deduped SAE on
# resid_out_layer4 with a 32768-wide dictionary — TODO confirm.
# NOTE(review): absolute, user-specific path and hard-coded 'cuda' device; this
# cell only runs on the original machine with a GPU available.
ae_path = (
    '/share/u/can/sae_eval/dictionary_learning/dictionaries/'
    'pythia-70m-deduped/resid_out_layer4/10_32768/ae.pt'
)
dictionary = AutoEncoder.from_pretrained(ae_path, device='cuda')
sae_decoder = dictionary.decoder.weight

In [None]:
# Pickled mapping of trained probes, keyed by class id (later cells use
# probes[0] and probes.items()).
# NOTE(review): absolute, user-specific path. pickle.load can execute arbitrary
# code from the file — only load probe files you produced yourself.
probe_path = (
    '/share/u/can/sae_eval/experiments/trained_bib_probes/'
    'pythia-70m-deduped/probes_ctx_len_128.pkl'
)

with open(probe_path, "rb") as probe_file:
    probes = pickle.load(probe_file)

# Weight of the `net` layer of the probe stored under key 0.
probe_map = probes[0].net.weight

In [None]:
# Direct attribution: project every SAE decoder direction onto the probe weights
# with a single weight matmul (no model activations involved).
# NOTE(review): assumes probe_map is (1, d_model) and sae_decoder is
# (d_model, n_features), so the product squeezes to one score per feature — confirm.
direct_feature_attr = probe_map @ sae_decoder
direct_feature_attr = direct_feature_attr.squeeze().detach().cpu().numpy()

fig, ax = plt.subplots()
ax.hist(direct_feature_attr, bins=100)
# Trailing ';' suppresses the repr of ax.set's return value in the cell output.
ax.set(
    title="Direct probe attribution of SAE decoder directions",
    xlabel="probe weight @ decoder direction",
    ylabel="number of SAE features",
);

In [None]:
# Profession name -> class index. Indices follow the alphabetical order of the
# names below; profession_dict_rev maps an index back to its name.
profession_dict = {
    name: index
    for index, name in enumerate(
        [
            "accountant",
            "architect",
            "attorney",
            "chiropractor",
            "comedian",
            "composer",
            "dentist",
            "dietitian",
            "dj",
            "filmmaker",
            "interior_designer",
            "journalist",
            "model",
            "nurse",
            "painter",
            "paralegal",
            "pastor",
            "personal_trainer",
            "photographer",
            "physician",
            "poet",
            "professor",
            "psychologist",
            "rapper",
            "software_engineer",
            "surgeon",
            "teacher",
            "yoga_teacher",
        ]
    )
}
profession_dict_rev = {index: name for name, index in profession_dict.items()}

In [None]:
# Feature to probe direction: for every probe, the cosine similarity between its
# weight direction and each SAE decoder direction, with one histogram per profession.
#
# The decoder is normalized once, outside the loop, because it does not depend on
# the probe. Column-wise normalization (dim=0) assumes sae_decoder is
# (d_model, n_features) — TODO confirm.
# Loop-local names deliberately differ from the global `probe_map`, so re-running
# the direct-attribution cell afterwards still uses the probe chosen when the
# probes were loaded, rather than whichever probe this loop saw last.
sae_decoder_norm = sae_decoder / sae_decoder.norm(dim=0, keepdim=True)

for i, probe in probes.items():
    profession = profession_dict_rev[i]
    print(f'probe direction for profession {profession}')

    class_probe_weight = probe.net.weight
    class_probe_norm = class_probe_weight / class_probe_weight.norm(dim=1, keepdim=True)

    feature_probe_cossim = class_probe_norm @ sae_decoder_norm
    feature_probe_cossim = feature_probe_cossim.squeeze().detach().cpu().numpy()

    max_cossim = feature_probe_cossim.max()
    min_cossim = feature_probe_cossim.min()
    print(f"Max cossim: {max_cossim}, Min cossim: {min_cossim}")

    # A fresh figure per probe; previously every histogram landed on the same
    # axes and was overlaid into one unreadable plot.
    fig, ax = plt.subplots()
    ax.hist(feature_probe_cossim, bins=100)
    ax.set(
        title=f"Probe / SAE decoder cosine similarity: {profession}",
        xlabel="cosine similarity",
        ylabel="number of SAE features",
    )
    # Show now so each plot appears under its own printed stats, then free it.
    plt.show()
    plt.close(fig)