In [None]:
import os
import pickle

from bib_intervention import *

%load_ext autoreload
%autoreload 2

In [None]:
# Thresholds passed to select_unique_class_features() later in the notebook.
# T_max_sideeffect is handed over unchanged on every call; T_effects_unique_class
# is the list of T_effect values that get tried.
T_max_sideeffect = 5e-3
T_effects_unique_class = [
    1e-2,
    5e-3,
    1e-8,
]

In [None]:
# ---- Run-wide settings ----
DEVICE = "cuda:0"
SEED = 42
verbose = False
select_unique_features = True

# TODO: improve scoping of probe layer int
layer = 4  # model layer for attaching linear classification head
activation_dim = 512

# ---- Which model / dictionaries to use ----
model_name_lookup = {"pythia70m": "EleutherAI/pythia-70m-deduped"}
model_location = "pythia70m"
sweep_name = "_sweep0711"
dictionaries_path = "../dictionary_learning/dictionaries"
submodule_trainers = {
    "resid_post_layer_4": {"trainer_ids": [10]},
}

# ---- Dataset and probe sizes ----
probe_train_set_size = 5000
probe_test_set_size = 1000
train_set_size = 1000
test_set_size = 1000
probe_batch_size = 500
llm_batch_size = 10

# ---- Attribution patching variables ----
N_EVAL_BATCHES = 5
patching_batch_size = 10

# ---- Load model and resolve the submodules of interest ----
model_name = model_name_lookup[model_location]
model = LanguageModel(model_name, device_map=DEVICE, dispatch=True)

# Both lookups use the first key of submodule_trainers together with `layer`.
submodule = utils.get_submodule(model, next(iter(submodule_trainers)), layer)
probe_act_submodule = utils.get_submodule(model, next(iter(submodule_trainers)), layer)
probe_layer = class_probing.probe_layer_lookup[model_name]



In [None]:
# Locate the autoencoders to evaluate and read the context length from them.
ae_group_paths = utils.get_ae_group_paths(dictionaries_path, model_location, sweep_name, submodule_trainers)
ae_paths = utils.get_ae_paths(ae_group_paths)
context_length = utils.get_ctx_length(ae_paths)

# Build the train/test splits and trim both to that context length.
dataset, _ = load_and_prepare_dataset()
train_bios, test_bios = get_train_test_data(dataset, train_set_size, test_set_size)
train_bios = utils.trim_bios_to_context_length(train_bios, context_length)
test_bios = utils.trim_bios_to_context_length(test_bios, context_length)

# Train probes only when no saved file exists for this context length.
probe_path = f"trained_bib_probes/probes_ctx_len_{context_length}.pt"
if not os.path.exists(probe_path):
    print("Probes not found, training probes")
    probes = class_probing.train_probes(
        train_set_size=probe_train_set_size,
        test_set_size=probe_test_set_size,
        context_length=context_length,
        probe_batch_size=probe_batch_size,
        llm_batch_size=llm_batch_size,
        device=DEVICE,
        llm_model_name=model_name,
        epochs=10,
    )

# The probes are always (re)loaded from disk, replacing any value returned above.
# NOTE(review): this presumably relies on train_probes() writing probe_path - confirm.
probes = t.load(probe_path)
all_classes_list = list(probes.keys())

### Get activations for original model, all classes
print("Getting activations for original model")
test_acts = {}
for class_idx in tqdm(all_classes_list, desc="Getting activations per evaluated class"):
    test_acts[class_idx] = get_all_activations(
        test_bios[class_idx], model, submodule, llm_batch_size, probe_layer
    )

# Probe accuracies on the unmodified model's activations.
test_accuracies = class_probing.get_probe_test_accuracy(
    probes,
    all_classes_list,
    test_acts,
    probe_batch_size,
    verbose,
    device=DEVICE,
)

In [None]:
### Get activations for ablated models
# ablating the top features for each class
print("Getting activations for ablated models")

# NOTE: entries are keyed by T_effect only, so when ae_paths holds several
# dictionaries each iteration overwrites the entries of the previous one.
unique_feats = {}

for ae_path in ae_paths:
    print(f"Running ablation for {ae_path}")
    submodule, dictionary, config = utils.load_dictionary(model, model_name, ae_path, DEVICE)
    submodules = [submodule]
    dictionaries = {submodule: dictionary}
    dict_size = config["trainer"]["dict_size"]
    context_length = config["buffer"]["ctx_len"]

    # ae_name_lookup is useful if we are using attribution patching on multiple submodules
    ae_name_lookup = {submodule: ae_path}

    # Only consumed by the (currently commented-out) accuracy evaluation below.
    class_accuracies = test_accuracies.copy()

    # Per-class node effects, computed with attribution patching on the training bios.
    node_effects = {}
    for ablated_class_idx in all_classes_list:
        node_effects[ablated_class_idx] = get_effects_per_class(
            model,
            submodules,
            dictionaries,
            probes,
            probe_act_submodule,
            ablated_class_idx,
            train_bios,
            N_EVAL_BATCHES,
            batch_size=patching_batch_size,
            patching_method="attrib",
            steps=10,
            device=DEVICE,
        )

    node_effects_cpu = utils.to_device(node_effects, "cpu")
    # Replace submodule keys with submodule_ae_path
    for abl_class_idx in node_effects_cpu.keys():
        node_effects_cpu[abl_class_idx] = {
            ae_name_lookup[effect_submodule]: effects
            for effect_submodule, effects in node_effects_cpu[abl_class_idx].items()
        }
    # os.path.join yields the same path as plain concatenation when ae_path ends
    # with a separator, and still lands inside the directory when it does not.
    with open(os.path.join(ae_path, "node_effects.pkl"), "wb") as f:
        pickle.dump(node_effects_cpu, f)
    del node_effects_cpu
    gc.collect()

    if select_unique_features:
        T_effects = T_effects_unique_class
        for T_effect in T_effects:
            unique_feats[T_effect] = select_unique_class_features(
                node_effects,
                dict_size,
                T_effect=T_effect,
                T_max_sideeffect=T_max_sideeffect,
                verbose=verbose,
                device=DEVICE,
            )
    else:
        # NOTE(review): T_effects_all_classes is not assigned in any visible cell of
        # this notebook; presumably it comes from the star import - confirm before
        # switching select_unique_features to False.
        T_effects = T_effects_all_classes
        for T_effect in T_effects:
            unique_feats[T_effect] = select_significant_features(
                node_effects, dict_size, T_effect=T_effect, verbose=verbose
            )

    # for ablated_class_idx in all_classes_list:
    #     class_accuracies[ablated_class_idx] = {}
    #     print(f"evaluating class {ablated_class_idx}")

    #     for T_effect in T_effects:
    #         feats = unique_feats[T_effect][ablated_class_idx]

    #         if len(feats) == 0:
    #             print(f"No features selected for T_effect = {T_effect}")
    #             continue

    #         class_accuracies[ablated_class_idx][T_effect] = {}
    #         if verbose:
    #             print(f"Running ablation for T_effect = {T_effect}")
    #         test_acts_ablated = {}
    #         for evaluated_class_idx in all_classes_list:
    #             test_acts_ablated[evaluated_class_idx] = get_all_acts_ablated(
    #                 test_bios[evaluated_class_idx],
    #                 model,
    #                 submodules,
    #                 dictionaries,
    #                 feats,
    #                 llm_batch_size,
    #             )

    #         for evaluated_class_idx in all_classes_list:
    #             batch_test_acts, batch_test_labels = prepare_probe_data(
    #                 test_acts_ablated, evaluated_class_idx, probe_batch_size, device=DEVICE
    #             )
    #             test_acc_probe = test_probe(
    #                 batch_test_acts,
    #                 batch_test_labels,
    #                 probes[evaluated_class_idx],
    #                 precomputed_acts=True,
    #             )
    #             if verbose:
    #                 print(
    #                     f"Ablated {ablated_class_idx}, evaluated {evaluated_class_idx} test accuracy: {test_acc_probe}"
    #                 )
    #             class_accuracies[ablated_class_idx][T_effect][evaluated_class_idx] = test_acc_probe

    #         del test_acts_ablated
    #         del batch_test_acts
    #         del batch_test_labels
    #         t.cuda.empty_cache()
    #         gc.collect()

    # class_accuracies = utils.to_device(class_accuracies, "cpu")
    # with open(ae_path + "class_accuracies.pkl", "wb") as f:
    #     pickle.dump(class_accuracies, f)

In [None]:
# Inspect that single feature
# Reload the dictionary stored at `ae_path`. That name is only bound as the loop
# variable of the ablation loop, so this uses whatever its last iteration left behind
# and rebinds `submodule`, `dictionary` and `config` for the cells below.
submodule, dictionary, config = utils.load_dictionary(model, model_name, ae_path, DEVICE)

In [None]:
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning.buffer import ActivationBuffer

# interpret some features
# Text generator over the Pile plus an activation buffer on `submodule`; both are
# passed to examine_dimension() in the inspection cells further down.
data = hf_dataset_to_generator("monology/pile-uncopyrighted")
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=activation_dim,  # reuse the config constant instead of repeating 512
    refresh_batch_size=128,  # decrease to fit on smaller GPUs
    n_ctxs=512,  # decrease to fit on smaller GPUs
    device=DEVICE,
)

In [None]:
# Load the node effects pickled for `ae_path` (file "node_effects.pkl" inside that
# dictionary's directory). The path is derived from `ae_path` rather than hard-coded
# to one user's absolute directory, so the notebook also runs from other checkouts.
# NOTE: pickle.load can execute arbitrary code - only load files you produced yourself.
file_name = os.path.join(ae_path, "node_effects.pkl")
with open(file_name, "rb") as f:
    node_effects = pickle.load(f)

In [None]:
# Echo the threshold settings, then pick the second entry of the T_effect list
# as the threshold used for the feature selection that follows.
print(f't_effect_inscope: {T_effects_unique_class}')
print(f't_max_sideeffect: {T_max_sideeffect}')
T_effect_inscope = T_effects_unique_class[1]
print(f'chosen effect: {T_effect_inscope}')

In [None]:
from bib_intervention import select_unique_class_features

# Re-run the selection on the node effects loaded from disk, for the single chosen
# threshold. This rebinds `unique_feats`: the ablation loop stored one selection
# result per T_effect, whereas from here on the name holds one result directly.
selection_kwargs = dict(
    T_effect=T_effect_inscope,
    T_max_sideeffect=T_max_sideeffect,
    verbose=True,
)
unique_feats = select_unique_class_features(node_effects, dict_size, **selection_kwargs)

In [None]:
# Profession dictionary
# Names are listed in index order; a name's position in the list is its class index.
profession_names = [
    "accountant",
    "architect",
    "attorney",
    "chiropractor",
    "comedian",
    "composer",
    "dentist",
    "dietitian",
    "dj",
    "filmmaker",
    "interior_designer",
    "journalist",
    "model",
    "nurse",
    "painter",
    "paralegal",
    "pastor",
    "personal_trainer",
    "photographer",
    "physician",
    "poet",
    "professor",
    "psychologist",
    "rapper",
    "software_engineer",
    "surgeon",
    "teacher",
    "yoga_teacher",
]
# name -> index, and the reverse lookup index -> name
profession_dict = {name: idx for idx, name in enumerate(profession_names)}
profession_dict_rev = dict(enumerate(profession_names))

In [None]:
# For every profession, report the sum over the first value stored in its entry.
for profession_idx, feats_for_profession in unique_feats.items():
    profession_name = profession_dict_rev[profession_idx]
    first_entry = next(iter(feats_for_profession.values()))
    print(f'{profession_idx}. {profession_name} has {first_entry.sum()} unique features')

In [None]:
# List the selected features of one profession together with their node effects,
# strongest effect first.
inspected_prof_idx = 1
print(f'Relevant features for {profession_dict_rev[inspected_prof_idx]}')

# Take the first value of this profession's entry and collect the indices of its
# non-zero positions, i.e. the selected feature indices.
class_feats = next(iter(unique_feats[inspected_prof_idx].values()))
class_feats_nonzero = class_feats.nonzero().flatten().tolist()

# Effects of those same feature indices, read from the first value of node_effects.
prof_node_effects = next(iter(node_effects[inspected_prof_idx].values()))
feat_effects = prof_node_effects[class_feats_nonzero].tolist()

# Pair each index with its effect and sort by effect, descending.
prof_unique_feats = zip(class_feats_nonzero, feat_effects)
sorted_prof_unique_feats = sorted(prof_unique_feats, key=lambda x: x[1], reverse=True)
for feat_idx, effect in sorted_prof_unique_feats:
    print(f'Feature {feat_idx} has effect {effect}')

In [None]:
# Dictionary feature to inspect; the commented-out assignments are alternative
# indices that can be swapped in.
# feat_idx = 37
# feat_idx = 2539
# feat_idx = 7303
feat_idx = 6972

# Run examine_dimension() for that feature of `dictionary`, drawing inputs from `buffer`.
out = examine_dimension(
    model,
    submodule,
    buffer,
    dictionary,
    max_length=128,
    n_inputs=1000,
    dim_idx=feat_idx,
    k=30,
)

# Print the token lists from the result; the trailing expression makes
# out.top_contexts the cell's displayed output.
print(f'\n\ntop activating tokens for feature {feat_idx}')
for token in out.top_tokens:
    print(token)
print(f'\n\ntop affected tokens for feature {feat_idx}')
for token in out.top_affected:
    print(token)
out.top_contexts

### Top k node effects (attribution_patching)

In [None]:
# Settings for this section: k is used to slice the ranked effects list below,
# prof_idx is the profession whose name is printed as the heading.
k = 10
prof_idx = 0
print(f'{prof_idx}. {profession_dict_rev[prof_idx]}')

In [None]:
# Rank every feature of the profession selected for this section (`prof_idx`, whose
# name the previous cell prints) by node effect, largest first. Indexing with
# `inspected_prof_idx` here would rank a different profession than the one printed.
prof_node_effects = list(node_effects[prof_idx].values())[0]
# (feature index, effect) pairs over all features of the first node_effects entry
feat_effects = enumerate(prof_node_effects.tolist())
sorted_feat_effects = sorted(feat_effects, key=lambda x: x[1], reverse=True)

In [None]:
# Display the k highest-effect (feature index, effect) pairs.
sorted_feat_effects[:k]

In [None]:
# Position in the ranked list (0 = largest effect) of the feature to inspect.
inspected_top_idx = 2


# Each ranked entry is an (index, effect) pair; take the feature index from it.
feat_idx = sorted_feat_effects[inspected_top_idx][0]
print(f'Feature {feat_idx} has effect {sorted_feat_effects[inspected_top_idx][1]}')

# Commented-out manual overrides with the author's labels for those features.
# feat_idx = 1155 # female
# feat_idx = 37 # Jewish

# Run examine_dimension() for that feature of `dictionary`, drawing inputs from `buffer`.
out = examine_dimension(
    model,
    submodule,
    buffer,
    dictionary,
    max_length=128,
    n_inputs=1000,
    dim_idx=feat_idx,
    k=30,
)

# Print the token lists from the result; the trailing expression makes
# out.top_contexts the cell's displayed output.
print(f'\n\ntop activating tokens for feature {feat_idx}')
for token in out.top_tokens:
    print(token)
print(f'\n\ntop affected tokens for feature {feat_idx}')
for token in out.top_affected:
    print(token)
out.top_contexts



In [None]:
## Plotting functions


def plot_feature_effects_above_threshold(nodes, threshold=0.05):
    """Scatter-plot all effect values larger than ``threshold``, in descending order.

    Every value of ``nodes`` is moved to CPU, converted to a NumPy array and
    flattened; the values from all entries are pooled, those not strictly greater
    than ``threshold`` are dropped, and the remainder is sorted from largest to
    smallest and plotted against its rank.

    Args:
        nodes: Mapping whose values support ``.cpu().numpy()`` (presumably torch
            tensors of effects - confirm against callers).
        threshold: Strict lower bound a value must exceed to be plotted.

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    kept_values = [
        value
        for key in nodes.keys()
        for value in nodes[key].cpu().numpy().reshape(-1)
        if value > threshold
    ]
    kept_values.sort(reverse=True)

    ranks = range(len(kept_values))
    plt.scatter(ranks, kept_values)
    plt.title("all_values")
    plt.show()