In [None]:
# Automatically reload edited project modules (e.g. the experiments.* helpers
# imported below) before each cell executes, so code changes are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from nnsight import LanguageModel
import torch

import experiments.utils as utils
import experiments.probe_training as probe_training

# Toy batch of inputs with very different lengths; the attention-mask
# handling in the trace cell below depends on padding being present.
data = {
    "inputs": ["Short input", "Another short input", "F", "This is a longer input that is more than 10 characters long"],
}

model_name = "EleutherAI/pythia-70m-deduped"
model_dtype = torch.bfloat16
# Fall back to CPU so the notebook still runs on machines without a GPU
# (the original hard-coded "cuda" and crashed there).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Context length passed to utils.tokenize_data (presumably the max sequence
# length; matches the "ctx_len_128" tag in the probe filename used later).
context_length = 128
# GPT-NeoX block whose output is traced in the cells below.
layer_idx = 4

model = LanguageModel(model_name, torch_dtype=model_dtype, device_map=device, attn_implementation="eager")

tokenized_data = utils.tokenize_data(data, model.tokenizer, context_length, device)

submodule = model.gpt_neox.layers[layer_idx]



In [None]:
# Debug aid — uncomment to inspect the tokenized batch:
# print(tokenized_data['inputs'])

In [None]:
# Activations at `submodule`, averaged by the project helper for every tokenized
# input. NOTE(review): the positional `2` is presumably a batch size — confirm
# against probe_training.get_all_meaned_activations.
meaned_activations = probe_training.get_all_meaned_activations(
    tokenized_data["inputs"],
    model,
    2,
    submodule,
)

In [None]:
# Sanity-check the shape returned by get_all_meaned_activations.
print(meaned_activations.shape)

In [None]:
# Run one traced forward pass, saving the layer output and the model's call
# inputs so that padding positions can be dropped via the attention mask.
with model.trace(tokenized_data['inputs']):
    acts_proxy = submodule.output.save()
    # Named `model_input` (not `input`) to avoid shadowing the Python builtin.
    model_input = model.input.save()

acts_BLD = acts_proxy.value
# The submodule output may be a tuple; its first element is taken as the activations.
if isinstance(acts_BLD, tuple):
    acts_BLD = acts_BLD[0]

print(acts_BLD.shape)

# model_input.value[1] is indexed as a mapping holding the "attention_mask" entry.
attn_mask_BL = model_input.value[1]["attention_mask"]
print(attn_mask_BL.shape)

# Boolean-mask indexing keeps only non-padding positions; per the shape suffixes
# this flattens (B, L, D) to (n_real_tokens, D).
acts_BL_D = acts_BLD[attn_mask_BL != 0]

print(acts_BL_D.shape)

In [None]:
from experiments.pipeline_config import PipelineConfig

# Which SAE sweep / trainer to load. The nesting presumably reads
# sweep name -> submodule name -> {"trainer_ids": [...]} — confirm against utils.get_ae_group_paths.
ae_sweep_paths = {"pythia70m_test_sae": {"resid_post_layer_3": {"trainer_ids": [0]}}}
p_config = PipelineConfig()

# Only the first sweep entry is used.
sweep_name, submodule_trainers = next(iter(ae_sweep_paths.items()))

ae_group_paths = utils.get_ae_group_paths(
    p_config.dictionaries_path, sweep_name, submodule_trainers
)
ae_paths = utils.get_ae_paths(ae_group_paths)
print(ae_paths)

# Fail with a clear message instead of a bare IndexError when nothing was found.
if not ae_paths:
    raise FileNotFoundError(f"No SAE checkpoints found for sweep {sweep_name!r}")
ae_path = ae_paths[0]

# NOTE(review): this overwrites `submodule` (layer 4 from the setup cell) with the
# SAE's submodule; re-running earlier cells after this one will use the new value.
submodule, dictionary, sae_config = utils.load_dictionary(model, ae_path, device)

In [None]:
import experiments.autointerp as autointerp

# Decoder matrix of the loaded SAE. The _DF suffix presumably means
# (d_model, n_features) — TODO confirm against autointerp.get_decoder_weight.
decoder_weight_DF = autointerp.get_decoder_weight(dictionary)
print(decoder_weight_DF.shape)

In [None]:
import pickle
import os

# Probe file for this model; judging by the filename it holds spurious-correlation
# probes (bias_in_bios professor/nurse, ctx_len 128, layer 3).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted, locally produced files.
probe_path = os.path.join(
    p_config.probes_dir,
    model_name.split("/")[-1],  # "pythia-70m-deduped", kept in sync with model_name
    "spurious_probes_bias_in_bios_professor_nurse_ctx_len_128_layer_3.pkl",
)

with open(probe_path, "rb") as f:
    probe = pickle.load(f)

print(probe.keys())
print(probe['male / female'].net.weight.shape)

# `.net.weight` looks like an nn.Linear weight of shape (n_out, D); for a single-logit
# probe that is (1, D). squeeze(0) drops that axis so the tensor matches its `_D` name
# and the dot product below yields an `_F` vector (no-op when n_out != 1).
# detach() keeps the later matmul out of autograd.
probe_weight_D = (
    probe['male / female'].net.weight
    .detach()
    .to(dtype=torch.float32, device=device)
    .squeeze(0)
)

In [None]:
# Alignment of the probe direction with every SAE feature: per the shape suffixes,
# probe_weight_D (D) @ decoder_weight_DF (D, F) -> one score per feature.
# The decoder is cast to the probe's dtype/device first: the probe was moved to
# float32 on `device`, but the decoder's dtype depends on how the dictionary was
# loaded (the model itself runs in bf16), and torch.matmul rejects mixed dtypes.
with torch.no_grad():
    dot_prod_F = probe_weight_D @ decoder_weight_DF.to(probe_weight_D)
print(dot_prod_F.shape)

# min, mean, max
print(dot_prod_F.min(), dot_prod_F.mean(), dot_prod_F.max())

In [None]:
# Two pickled node-effect files stored next to the SAE checkpoint.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
effects_path, effects_orig_path = (
    os.path.join(ae_path, file_name)
    for file_name in ("node_effects.pkl", "node_effects_orig.pkl")
)

with open(effects_path, "rb") as effects_file:
    effects = pickle.load(effects_file)
with open(effects_orig_path, "rb") as effects_orig_file:
    effects_orig = pickle.load(effects_orig_file)



In [None]:
# Show which keys each effects mapping holds, then pull the entry stored under
# 'male / female' (the same key used for the probe dict) from both.
print(effects.keys())
print(effects_orig.keys())

effects_F, effects_orig_F = effects['male / female'], effects_orig['male / female']

In [None]:
# Compare which features rank highest under the two effect estimates.
# (torch is already imported in the setup cell.)
# NOTE(review): topk ranks by signed value; if large negative effects also matter,
# rank by .abs() instead.
TOP_K = 20
# torch.topk raises when k exceeds the size of the ranked dimension, so clamp it.
top_k = min(TOP_K, effects_orig_F.shape[-1], effects_F.shape[-1])

# Top-k values and their indices from each effect tensor
top20_orig_values, top20_orig_indices = torch.topk(effects_orig_F, top_k)
top20_F_values, top20_F_indices = torch.topk(effects_F, top_k)

print(f"Top {top_k} indices from effects_orig_F:")
print(top20_orig_indices)

print(f"Top {top_k} indices from effects_F:")
print(top20_F_indices)

# Feature indices present in both top-k sets
common_indices = set(top20_orig_indices.tolist()) & set(top20_F_indices.tolist())
num_common = len(common_indices)

print(f"Number of common indices in top {top_k}: {num_common}")

# Sorted so the listing order is deterministic and easy to scan.
print("Common indices:", sorted(common_indices))

# Side-by-side values of each shared index in both tensors
for idx in sorted(common_indices):
    orig_value = effects_orig_F[idx].item()
    F_value = effects_F[idx].item()
    print(f"Index {idx}: Original value = {orig_value:.4f}, New value = {F_value:.4f}")