In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
import sys
import os
import random
import gc
from collections import defaultdict
import einops
import math
import numpy as np
import pickle

import torch
from torch import nn
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
from tqdm import tqdm
from typing import Callable, Optional

from datasets import load_dataset
import datasets
from nnsight import LanguageModel

# Set up paths. This has to run BEFORE the project-local imports below:
# appending the parent directory to sys.path afterwards (as this cell used to
# do) cannot help those imports resolve on a fresh kernel.
# Presumably the parent directory is the repository root that contains the
# `experiments` and `dictionary_learning` packages -- TODO confirm.
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:  # keep re-runs of this cell from stacking duplicates
    sys.path.append(parent_dir)

import experiments.utils as utils
from experiments.probe_training import *
from experiments.pipeline_config import PipelineConfig

from dictionary_learning.dictionary import AutoEncoder

# Configuration
DEBUGGING = False
SEED = 42

# Keyword arguments forwarded to every `model.trace(...)` call in this notebook;
# both flags follow DEBUGGING.
tracer_kwargs = dict(scan=DEBUGGING, validate=DEBUGGING)


In [None]:
@torch.no_grad()
def get_all_activations(
    text_inputs: list[str], model: LanguageModel, batch_size: int, submodule: utils.submodule_alias
) -> torch.Tensor:
    """Return one mask-weighted mean activation vector per input.

    The inputs are split with `utils.batch_inputs`, each batch is traced through
    `model`, and element 0 of `submodule`'s output is averaged over dim 1 using
    the attention mask as weights, so masked-out positions do not contribute.

    Args:
        text_inputs: Inputs handed to `utils.batch_inputs`. NOTE(review): the
            original TODO asks for a rename, so these are presumably already
            tokenized rather than raw strings -- confirm against the caller.
        model: nnsight model used for tracing.
        batch_size: Batch size passed to `utils.batch_inputs`.
        submodule: Module whose output is read inside the trace.

    Returns:
        The per-batch pooled activations concatenated along dim 0.
    """
    # TODO: Rename text_inputs
    pooled_batches = []
    for batch in utils.batch_inputs(text_inputs, batch_size):
        with model.trace(batch, **tracer_kwargs):
            mask_BL = model.input[1]["attention_mask"]
            # Zero out masked positions, then divide the sum over dim 1 by the
            # number of unmasked positions.
            masked_BLD = submodule.output[0] * mask_BL[:, :, None]
            pooled_BD = (masked_BLD.sum(1) / mask_BL.sum(1)[:, None]).save()
        pooled_batches.append(pooled_BD.value)

    return torch.cat(pooled_batches, dim=0)

In [None]:

# Experiment configuration
llm_model_name = "EleutherAI/pythia-70m-deduped"
device = "cuda"
train_set_size = 4000
test_set_size = 1000
context_length = 128
include_gender = True
model_dtype = torch.bfloat16

probe_batch_size = 500
# NOTE: the activation-collection cell further down reassigns llm_batch_size.
llm_batch_size = 500

# TODO: I think there may be a scoping issue with model and get_acts(), but we currently aren't using get_acts()
model = LanguageModel(llm_model_name, device_map=device, dispatch=True, torch_dtype=model_dtype)
probe_dir = "trained_bib_probes"
only_model_name = llm_model_name.split("/")[-1]

model_eval_config = utils.ModelEvalConfig.from_full_model_name(llm_model_name)
probe_layer = model_eval_config.probe_layer

probe_output_filename = (
    f"{probe_dir}/{only_model_name}/probes_ctx_len_{context_length}_layer_{probe_layer}.pkl"
)

# NOTE: `epochs` is reassigned in the probe-training cell before it is read.
epochs = 1
save_results = True
seed = SEED


In [None]:

"""Because we save the probes, we always train them on all classes to avoid potential issues with missing classes. It's only a one-time cost."""
# Seed torch, Python's `random` and NumPy with the same value before any data is sampled.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Per-model settings. `model_eval_config` and `probe_layer` repeat the lookups
# already done in the configuration cell.
model_eval_config = utils.ModelEvalConfig.from_full_model_name(llm_model_name)
d_model = model_eval_config.activation_dim
probe_layer = model_eval_config.probe_layer
# NOTE(review): in the visible notebook, `d_model` and `probe_act_submodule` are
# only referenced from commented-out code further down.
probe_act_submodule = utils.get_submodule(model, "resid_post", probe_layer)

# `load_and_prepare_dataset` and `get_train_test_data` are not defined in this
# notebook; presumably they come from `from experiments.probe_training import *`.
dataset, df = load_and_prepare_dataset()

train_bios, test_bios = get_train_test_data(
    dataset,
    train_set_size,
    test_set_size,
    include_gender,
)

# Rebinds both names to the output of `utils.tokenize_data`; the untokenized
# versions are no longer reachable afterwards.
train_bios = utils.tokenize_data(train_bios, model.tokenizer, context_length, device)
test_bios = utils.tokenize_data(test_bios, model.tokenizer, context_length, device)


In [None]:
# Trainer id(s) requested for the active sweep entry below.
trainer_ids = [10]

# Sweep name -> submodule name -> trainer selection. Only the first sweep in this
# dict is used (see `sweep_name` below); commented entries are kept as alternatives.
ae_sweep_paths = {
    # "pythia70m_sweep_standard_ctx128_0712": {
    #     #     # "resid_post_layer_0": {"trainer_ids": None},
    #     #     # "resid_post_layer_1": {"trainer_ids": None},
    #     #     # "resid_post_layer_2": {"trainer_ids": None},
    #     "resid_post_layer_3": {"trainer_ids": [6]},
    #     #     "resid_post_layer_4": {"trainer_ids": None},
    # },
    "pythia70m_sweep_topk_ctx128_0730": {
        # "resid_post_layer_0": {"trainer_ids": None},
        # "resid_post_layer_1": {"trainer_ids": None},
        # "resid_post_layer_2": {"trainer_ids": None},
        "resid_post_layer_3": {"trainer_ids": trainer_ids},
        # "resid_post_layer_4": {"trainer_ids": trainer_ids},
    },
}

p_config = PipelineConfig()

# Take the first (sweep name, submodule config) pair; any further entries are ignored.
sweep_name, submodule_trainers = list(ae_sweep_paths.items())[0]

ae_group_paths = utils.get_ae_group_paths(
    p_config.dictionaries_path, sweep_name, submodule_trainers
)
ae_paths = utils.get_ae_paths(ae_group_paths)
print(ae_paths)

# Only the first resolved autoencoder path is loaded.
ae_path = ae_paths[0]

# `submodule` and `dictionary` are read as globals by the activation-collection cell.
submodule, dictionary, sae_config = utils.load_dictionary(model, ae_path, device)

In [None]:
@torch.no_grad()
def get_all_sae_activations(
    text_inputs: list[str],
    model: LanguageModel,
    dictionary: AutoEncoder,
    batch_size: int,
    submodule: utils.submodule_alias,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return mask-weighted mean submodule activations and SAE activations per input.

    Each batch is traced once. Inside the trace the submodule output and its
    `dictionary.encode(...)` encoding are both multiplied by the attention mask
    and averaged over dim 1, so masked-out positions do not contribute.

    Args:
        text_inputs: Inputs handed to `utils.batch_inputs`. NOTE(review): the
            original TODO asks for a rename, so these are presumably already
            tokenized rather than raw strings -- confirm against the caller.
        model: nnsight model used for tracing.
        dictionary: Autoencoder whose `encode` is applied to the submodule output.
        batch_size: Batch size passed to `utils.batch_inputs`.
        submodule: Module whose output is read inside the trace.

    Returns:
        `(submodule_acts, sae_acts)`, each concatenated along dim 0 over all
        batches. The SAE activations are cast to `model.dtype`.
    """
    # TODO: Rename text_inputs
    batches = utils.batch_inputs(text_inputs, batch_size)

    # Probe the output structure with a throwaway trace. The exact-type
    # comparison is kept as is: torch.Size subclasses tuple, so an isinstance
    # check would not separate "tuple of shapes" from "a single shape".
    # Presumably that is the distinction intended here -- TODO confirm.
    with model.trace("_"):
        output_is_tuple = type(submodule.output.shape) == tuple

    target_dtype = model.dtype

    resid_batches = []
    sae_batches = []
    for batch in batches:
        with model.trace(batch, **tracer_kwargs):
            mask_BL = model.input[1]["attention_mask"]
            resid_BLD = submodule.output
            if output_is_tuple:
                resid_BLD = resid_BLD[0]

            mask_BL1 = mask_BL[:, :, None]
            token_counts_B1 = mask_BL.sum(1)[:, None]

            # SAE features: encode, mask, then average over dim 1.
            sae_BF = ((dictionary.encode(resid_BLD) * mask_BL1).sum(1) / token_counts_B1).save()

            # Submodule activations: the same pooling applied to the raw output.
            resid_BD = ((resid_BLD * mask_BL1).sum(1) / token_counts_B1).save()
        resid_batches.append(resid_BD.value)
        sae_batches.append(sae_BF.value.to(dtype=target_dtype))

    return torch.cat(resid_batches, dim=0), torch.cat(sae_batches, dim=0)

In [None]:
# Per-class pooled activations, keyed like `train_bios`.
all_train_acts, all_test_acts = {}, {}
all_train_sae_acts, all_test_sae_acts = {}, {}

# Overrides the batch size chosen in the configuration cell.
llm_batch_size = 100

with torch.no_grad():
    for profession in train_bios.keys():
        # if isinstance(profession, int):
        #     continue

        print(f"Collecting activations for profession: {profession}")

        train_resid, train_sae = get_all_sae_activations(
            train_bios[profession], model, dictionary, llm_batch_size, submodule
        )
        all_train_acts[profession] = train_resid
        all_train_sae_acts[profession] = train_sae

        test_resid, test_sae = get_all_sae_activations(
            test_bios[profession], model, dictionary, llm_batch_size, submodule
        )
        all_test_acts[profession] = test_resid
        all_test_sae_acts[profession] = test_sae


In [None]:
def print_tensor_memory_usage(tensor: torch.Tensor):
    """Print the shape, dtype and element-storage size (in MB) of `tensor`.

    The size is `numel * element_size`. If `tensor` is not a `torch.Tensor`,
    a short notice is printed instead. Always returns None.
    """
    if isinstance(tensor, torch.Tensor):
        n_bytes = tensor.numel() * tensor.element_size()
        print(f"Tensor Shape: {tensor.shape}")
        print(f"Tensor Type: {tensor.dtype}")
        print(f"Memory usage: {n_bytes / 1024 / 1024:.2f} MB")
    else:
        print("Input is not a tensor. Cannot calculate memory usage.")

# Release unreferenced Python objects and cached CUDA blocks before reporting sizes.
gc.collect()
torch.cuda.empty_cache()

# Inspect the entries stored under key 0 as a sample (raises KeyError if the
# activation dicts have no key 0).
model_acts = all_train_acts[0]
sae_acts = all_train_sae_acts[0]
print_tensor_memory_usage(model_acts)
print_tensor_memory_usage(sae_acts)

In [None]:
# Probe model and training
class Probe(nn.Module):
    """Linear probe that maps an activation vector to a single logit."""

    def __init__(self, activation_dim: int, dtype: torch.dtype):
        """Create one `nn.Linear(activation_dim, 1)` layer with bias in `dtype`."""
        super().__init__()
        self.net = nn.Linear(in_features=activation_dim, out_features=1, bias=True, dtype=dtype)

    def forward(self, x):
        """Apply the linear layer and drop the trailing size-1 dimension."""
        logits = self.net(x)
        return logits.squeeze(-1)

def train_probe(
    train_input_batches: list,
    train_label_batches: list[torch.Tensor],
    test_input_batches: list,
    test_label_batches: list[torch.Tensor],
    get_acts: Callable,
    precomputed_acts: bool,
    dim: int,
    epochs: int,
    device: str,
    model_dtype: torch.dtype,
    lr: float = 1e-2,
    seed: int = SEED,
    verbose: bool = False,
) -> tuple[Probe, float]:
    """Train a `Probe` with AdamW and BCE-with-logits loss, then evaluate it.

    input_batches can be a list of tensors or strings. If strings, get_acts must be provided.

    Args:
        train_input_batches: Training batches, either precomputed activations or
            raw inputs that `get_acts` turns into activations.
        train_label_batches: One label tensor per training batch.
        test_input_batches: Evaluation batches, same convention as training.
        test_label_batches: One label tensor per evaluation batch.
        get_acts: Called on each batch when `precomputed_acts` is False; not
            called otherwise.
        precomputed_acts: True when the input batches already are activations.
        dim: Input dimension of the probe.
        epochs: Number of passes over the training batches; must be >= 1.
        device: Device for the probe and the label tensors.
        model_dtype: dtype for the probe parameters and the label tensors.
        lr: AdamW learning rate.
        seed: Currently unused; kept so existing call sites stay valid.
        verbose: When True, print loss and accuracies after the final epoch.

    Returns:
        `(probe, test_accuracy)`, where `test_accuracy` is whatever `test_probe`
        returned for the evaluation batches after the final epoch.
        NOTE(review): a later cell indexes this value with `[0]`, so it may not
        be a plain float -- confirm against `test_probe`.

    Raises:
        ValueError: If `epochs < 1` or either list of input batches is empty.
        AssertionError: If `precomputed_acts` contradicts the input batch type.
    """
    # Fail early with a clear message. Previously epochs < 1 fell through to the
    # return statement with `test_accuracy` unbound, and an empty batch list
    # surfaced as a bare IndexError.
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")
    if len(train_input_batches) == 0 or len(test_input_batches) == 0:
        raise ValueError("train_input_batches and test_input_batches must both be non-empty")

    first_train, first_test = train_input_batches[0], test_input_batches[0]
    if isinstance(first_train, str) or isinstance(first_test, str):
        assert not precomputed_acts, "string inputs require precomputed_acts=False"
    elif isinstance(first_train, torch.Tensor) or isinstance(first_test, torch.Tensor):
        assert precomputed_acts, "tensor inputs require precomputed_acts=True"

    probe = Probe(dim, model_dtype).to(device)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        for inputs, labels in zip(train_input_batches, train_label_batches):
            acts_BD = inputs if precomputed_acts else get_acts(inputs)
            logits_B = probe(acts_BD)
            targets_B = labels.clone().detach().to(device=device, dtype=model_dtype)
            loss = criterion(logits_B, targets_B)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train accuracy is estimated on the first 30 batches only.
        train_accuracy = test_probe(
            train_input_batches[:30], train_label_batches[:30], probe, get_acts, precomputed_acts
        )

        test_accuracy = test_probe(
            test_input_batches, test_label_batches, probe, get_acts, precomputed_acts
        )

        if epoch == epochs - 1 and verbose:
            print(f"\nEpoch {epoch + 1}/{epochs} Loss: {loss.item()}, train accuracy: {train_accuracy}, test accuracy: {test_accuracy}\n")

    return probe, test_accuracy

In [None]:
from experiments.probe_training import prepare_probe_data

# Overrides the value from the configuration cell; `train_probe_on_activations`
# reads `epochs` as a global.
epochs = 5

# NOTE(review): re-seeds torch with 0 rather than SEED -- confirm this is intended.
torch.manual_seed(0)

def train_probe_on_activations(
    train_activations: dict[str | int, torch.Tensor],
    test_activations: dict[str | int, torch.Tensor],
    select_top_k: Optional[int] = None,
) -> tuple[dict[str | int, Probe], dict[str | int, float]]:
    """Train one probe per class on precomputed activations.

    Classes whose key appears in `utils.PAIRED_CLASS_KEYS.values()` are skipped.
    The two gender classes ("male / female" and "biased_male / biased_female")
    are trained for one epoch; every other class uses the notebook-level `epochs`.

    Reads these notebook globals: `probe_batch_size`, `epochs`, `device`,
    `model_dtype` and `get_acts`. `get_acts` is passed through to `train_probe`
    but is not called there because `precomputed_acts=True`.

    Side effect: enables autograd globally via `torch.set_grad_enabled(True)`.

    Args:
        train_activations: Class key -> training activations.
        test_activations: Class key -> evaluation activations.
        select_top_k: Forwarded to `prepare_probe_data`; None keeps its default
            behaviour.

    Returns:
        `(probes, test_accuracies)`, both keyed by class.
    """
    torch.set_grad_enabled(True)

    single_epoch_classes = ("biased_male / biased_female", "male / female")
    probes, test_accuracies = {}, {}

    for profession in train_activations.keys():
        if profession in utils.PAIRED_CLASS_KEYS.values():
            continue

        train_acts, train_labels = prepare_probe_data(train_activations, profession, probe_batch_size, select_top_k)

        test_acts, test_labels = prepare_probe_data(test_activations, profession, probe_batch_size, select_top_k)

        probe_epochs = 1 if profession in single_epoch_classes else epochs

        # Derive the probe width from the data so that both full activations and
        # top-k selections of any size work.
        activation_dim = train_acts[0].shape[1]

        probe, test_accuracy = train_probe(
            train_acts,
            train_labels,
            test_acts,
            test_labels,
            get_acts,
            precomputed_acts=True,
            epochs=probe_epochs,
            dim=activation_dim,
            device=device,
            model_dtype=model_dtype,
            verbose=False,
        )

        probes[profession] = probe
        test_accuracies[profession] = test_accuracy

    return probes, test_accuracies

    # if save_results:
    #     only_model_name = llm_model_name.split("/")[-1]
    #     os.makedirs(f"{probe_dir}", exist_ok=True)
    #     os.makedirs(f"{probe_dir}/{only_model_name}", exist_ok=True)

    #     with open(probe_output_filename, "wb") as f:
    #         pickle.dump(probes, f)


# Baseline: probes on the submodule activations.
model_activation_probes, model_activation_accuracies = train_probe_on_activations(all_train_acts, all_test_acts)
# Probes on all SAE features.
sae_activation_probes, sae_activation_accuracies = train_probe_on_activations(all_train_sae_acts, all_test_sae_acts)
# Probes on SAE features with `select_top_k` set to 1, 5 and 10.
sae_activation_top_1_probes, sae_activation_top_1_accuracies = train_probe_on_activations(all_train_sae_acts, all_test_sae_acts, select_top_k=1)
sae_activation_top_5_probes, sae_activation_top_5_accuracies = train_probe_on_activations(all_train_sae_acts, all_test_sae_acts, select_top_k=5)
sae_activation_top_10_probes, sae_activation_top_10_accuracies = train_probe_on_activations(all_train_sae_acts, all_test_sae_acts, select_top_k=10)

In [None]:
# Accuracy gap (submodule-activation probe minus SAE probe) per class, also
# split by whether the class key is an int or a str.
all_diffs, int_diffs, str_diffs = [], [], []

for class_name, model_result in model_activation_accuracies.items():
    model_acc = model_result[0]
    # Swap in sae_activation_top_1_accuracies / _top_5_ / _top_10_ here to
    # compare against the top-k probes instead.
    sae_acc = sae_activation_accuracies[class_name][0]

    diff = model_acc - sae_acc
    print(f"Class: {class_name}, Model Acc: {model_acc}, SAE Acc: {sae_acc}, Diff: {diff}")

    all_diffs.append(diff)
    if isinstance(class_name, int):
        int_diffs.append(diff)
    elif isinstance(class_name, str):
        str_diffs.append(diff)

print(f"\nAverage difference: {np.mean(all_diffs)}")
print(f"Average difference for int classes: {np.mean(int_diffs)}")
print(f"Average difference for str classes: {np.mean(str_diffs)}")

In [None]:
from experiments.probe_training import get_activation_distribution_diff

# Per-class result of get_activation_distribution_diff on the training SAE
# activations; classes listed in utils.PAIRED_CLASS_KEYS.values() are skipped.
paired_class_keys = utils.PAIRED_CLASS_KEYS.values()

sae_feature_distribution_differences = {}
for profession in all_train_sae_acts.keys():
    if profession not in paired_class_keys:
        sae_feature_distribution_differences[profession] = get_activation_distribution_diff(
            all_train_sae_acts, profession
        )

# Quick sanity check on one int-keyed and one str-keyed class.
print(sae_feature_distribution_differences[0].sum())
print(sae_feature_distribution_differences["male / female"].sum())

In [None]:
# Class keys to keep: four named classes plus four int-keyed ones.
chosen_class_indices = [
    "male / female",
    "professor / nurse",
    "male_professor / female_nurse",
    "biased_male / biased_female",
    0,
    1,
    2,
    6,
]

# Subset of `sae_feature_distribution_differences` restricted to the chosen classes.
# NOTE(review): no later cell visible here reads this subset, and the pickle
# cell below dumps the full dict instead -- confirm which one should be saved.
saved_sae_feature_distribution_differences = {}

for class_index in chosen_class_indices:
    saved_sae_feature_distribution_differences[class_index] = sae_feature_distribution_differences[class_index]

In [None]:
# Load previously saved node effects stored next to the autoencoder.
# NOTE(review): pickle.load can execute arbitrary code; only use this with
# files produced by this project.
with open(f"{ae_path}/node_effects.pkl", "rb") as f:
    node_effects_attrib_patching = pickle.load(f)

# Write the full per-class distribution-difference dict to the same directory.
with open(f"{ae_path}/node_effects_dist_diff.pkl", "wb") as f:
    pickle.dump(sae_feature_distribution_differences, f)

# Class to compare in the following cell.
class_idx = "male / female"
# class_idx = 0

node_effect = node_effects_attrib_patching[class_idx]
sae_distribution = sae_feature_distribution_differences[class_idx]

In [None]:
# Summary statistics for both tensors, printed in the order mean, std, max, min.
print(f"node effect stats: {node_effect.mean():.4f}, {node_effect.std().item():.4f}, {node_effect.max().item():.4f}, {node_effect.min().item():.4f}")
print(f"sae distribution stats: {sae_distribution.mean():.4f}, {sae_distribution.std().item():.4f}, {sae_distribution.max().item():.4f}, {sae_distribution.min().item():.4f}")
import torch

def normalize_tensor(tensor):
    """Standardize `tensor`: subtract its mean and divide by `tensor.std()`.

    Both statistics are taken over all elements. No guard is applied for a zero
    standard deviation.
    """
    center = tensor.mean()
    spread = tensor.std()
    return (tensor - center) / spread

def compare_top_values(tensor1, tensor2, top_n=10):
    """Print a side-by-side comparison of the largest entries of two tensors.

    Both tensors are converted to float32 on `tensor1`'s device and standardized
    with `normalize_tensor`. The function then prints the `top_n` highest
    normalized entries of each (labelled "node_effect" and "sae_distribution"),
    the indices common to both top lists, and the correlation coefficient
    between the two normalized tensors. Nothing is returned.
    """
    # Same device, float32 for all calculations.
    tensor1 = tensor1.to(torch.float32)
    tensor2 = tensor2.to(device=tensor1.device, dtype=torch.float32)

    normed1 = normalize_tensor(tensor1)
    normed2 = normalize_tensor(tensor2)

    # Indices of the top_n largest normalized values, largest first.
    top1 = torch.argsort(normed1, descending=True)[:top_n]
    top2 = torch.argsort(normed2, descending=True)[:top_n]

    def _print_ranking(header, top_indices, normed, raw):
        # One ranked line per index: normalized value plus the original value.
        print(header)
        for rank, idx in enumerate(top_indices, start=1):
            print(f"  {rank}. Index {idx.item()}: {normed[idx].item():.4f} (original: {raw[idx].item():.4f})")

    _print_ranking(f"Top {top_n} indices in normalized node_effect:", top1, normed1, tensor1)
    _print_ranking(f"\nTop {top_n} indices in normalized sae_distribution:", top2, normed2, tensor2)

    # Indices that appear in both top lists.
    common_indices = set(top1.tolist()) & set(top2.tolist())
    print(f"\nCommon indices in top {top_n}: {common_indices}")

    if common_indices:
        print("\nValues at common indices:")
        for idx in common_indices:
            print(f"  Index {idx}: node_effect = {normed1[idx].item():.4f}, sae_distribution = {normed2[idx].item():.4f}")

    # Off-diagonal entry of the 2x2 correlation matrix.
    stacked = torch.stack([normed1, normed2])
    correlation = torch.corrcoef(stacked)[0, 1]
    print(f"\nCorrelation between normalized tensors: {correlation.item():.4f}")

# Usage: compare the top 20 entries of the loaded node effects against the
# distribution differences for the class selected in `class_idx`.
compare_top_values(node_effect, sae_distribution, top_n=20)

In [None]:
# def train_probes(
#     train_set_size: int,
#     test_set_size: int,
#     model: LanguageModel,
#     context_length: int,
#     probe_batch_size: int,
#     llm_batch_size: int,
#     device: str,
#     probe_output_filename: str,
#     probe_dir: str = "trained_bib_probes",
#     llm_model_name: str = "EleutherAI/pythia-70m-deduped",
#     epochs: int = 10,
#     model_dtype: torch.dtype = torch.bfloat16,
#     save_results: bool = True,
#     seed: int = SEED,
#     include_gender: bool = False,
# ) -> dict[int, float]:
#     """Because we save the probes, we always train them on all classes to avoid potential issues with missing classes. It's only a one-time cost."""
#     torch.manual_seed(seed)
#     random.seed(seed)
#     np.random.seed(seed)

#     model_eval_config = utils.ModelEvalConfig.from_full_model_name(llm_model_name)
#     d_model = model_eval_config.activation_dim
#     probe_layer = model_eval_config.probe_layer
#     probe_act_submodule = utils.get_submodule(model, "resid_post", probe_layer)

#     dataset, df = load_and_prepare_dataset()

#     train_bios, test_bios = get_train_test_data(
#         dataset,
#         train_set_size,
#         test_set_size,
#         include_gender,
#     )

#     train_bios = utils.tokenize_data(train_bios, model.tokenizer, context_length, device)
#     test_bios = utils.tokenize_data(test_bios, model.tokenizer, context_length, device)

#     probes, test_accuracies = {}, {}

#     all_train_acts = {}
#     all_test_acts = {}

#     with torch.no_grad():
#         for i, profession in enumerate(train_bios.keys()):
#             # if isinstance(profession, int):
#             #     continue

#             print(f"Collecting activations for profession: {profession}")

#             all_train_acts[profession] = get_all_activations(
#                 train_bios[profession], model, llm_batch_size, probe_act_submodule
#             )
#             all_test_acts[profession] = get_all_activations(
#                 test_bios[profession], model, llm_batch_size, probe_act_submodule
#             )

#     torch.set_grad_enabled(True)

#     for profession in all_train_acts.keys():
#         if profession in utils.PAIRED_CLASS_KEYS.values():
#             continue

#         train_acts, train_labels = prepare_probe_data(all_train_acts, profession, probe_batch_size)

#         test_acts, test_labels = prepare_probe_data(all_test_acts, profession, probe_batch_size)

#         if profession == "biased_male / biased_female" or profession == "male / female":
#             probe_epochs = 1
#         else:
#             probe_epochs = epochs

#         probe, test_accuracy = train_probe(
#             train_acts,
#             train_labels,
#             test_acts,
#             test_labels,
#             get_acts,
#             precomputed_acts=True,
#             epochs=probe_epochs,
#             dim=d_model,
#             device=device,
#             model_dtype=model_dtype,
#         )

#         probes[profession] = probe
#         test_accuracies[profession] = test_accuracy

#     if save_results:
#         only_model_name = llm_model_name.split("/")[-1]
#         os.makedirs(f"{probe_dir}", exist_ok=True)
#         os.makedirs(f"{probe_dir}/{only_model_name}", exist_ok=True)

#         with open(probe_output_filename, "wb") as f:
#             pickle.dump(probes, f)

#     return test_accuracies

In [None]:
# llm_model_name = "EleutherAI/pythia-70m-deduped"
# device = "cuda"
# train_set_size = 1000
# test_set_size = 1000
# context_length = 128
# include_gender = True
# model_dtype = torch.bfloat16

# # TODO: I think there may be a scoping issue with model and get_acts(), but we currently aren't using get_acts()
# model = LanguageModel(llm_model_name, device_map=device, dispatch=True, torch_dtype=model_dtype)
# probe_dir = "trained_bib_probes"
# only_model_name = llm_model_name.split("/")[-1]

# model_eval_config = utils.ModelEvalConfig.from_full_model_name(llm_model_name)
# probe_layer = model_eval_config.probe_layer

# probe_output_filename = (
#     f"{probe_dir}/{only_model_name}/probes_ctx_len_{context_length}_layer_{probe_layer}.pkl"
# )

# test_accuracies = train_probes(
#     train_set_size=1000,
#     test_set_size=1000,
#     model=model,
#     context_length=128,
#     probe_batch_size=500,
#     llm_batch_size=500,
#     llm_model_name=llm_model_name,
#     epochs=10,
#     device=device,
#     probe_output_filename=probe_output_filename,
#     probe_dir=probe_dir,
#     seed=SEED,
#     include_gender=include_gender,
#     model_dtype=model_dtype,
# )
# print(test_accuracies)