In [None]:
# IPython autoreload in mode 2: re-import edited modules before each cell runs,
# so changes to the imported project modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
import sys
import os
import random
import gc
from collections import defaultdict
import einops
import math
import numpy as np
import pickle

import torch
from torch import nn
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
from tqdm import tqdm
from typing import Callable, Optional

from datasets import load_dataset
import datasets
from nnsight import LanguageModel

import experiments.utils as utils
from experiments.probe_training import *
from experiments.pipeline_config import PipelineConfig

from dictionary_learning.dictionary import AutoEncoder

# Configuration
# DEBUGGING is forwarded as both the `scan` and `validate` flags in tracer_kwargs below.
DEBUGGING = False
# Default seed; also used as the default of train_probe's `seed` parameter.
SEED = 42

# Set up paths and model
# NOTE(review): the parent directory is appended to sys.path only here, i.e. after the
# `experiments.*` / `dictionary_learning.*` imports above have already executed, so it
# cannot help resolve those imports on a fresh kernel. Confirm whether it should move
# above them.
parent_dir = os.path.abspath("..")
sys.path.append(parent_dir)

# Keyword arguments splatted into the batched model.trace(...) calls in this notebook.
tracer_kwargs = dict(scan=DEBUGGING, validate=DEBUGGING)


In [None]:
@torch.no_grad()
def get_all_activations(
    text_inputs: list[str], model: LanguageModel, batch_size: int, submodule: utils.submodule_alias
) -> torch.Tensor:
    """Return one attention-masked mean activation per input sequence.

    The inputs are split with ``utils.batch_inputs``; each batch is traced through
    ``model`` and the first element of ``submodule``'s output is averaged over the
    sequence axis, counting only positions where the attention mask is non-zero.

    Args:
        text_inputs: Inputs accepted by ``utils.batch_inputs`` / ``model.trace``.
        model: nnsight ``LanguageModel`` to trace.
        batch_size: Number of inputs per forward pass.
        submodule: Module whose ``output[0]`` is read.

    Returns:
        The per-batch pooled activations concatenated along dim 0.
    """
    # TODO: Rename text_inputs
    pooled_per_batch_BD = []

    for batch_BL in utils.batch_inputs(text_inputs, batch_size):
        with model.trace(batch_BL, **tracer_kwargs):
            mask_BL = model.input[1]["attention_mask"]
            resid_BLD = submodule.output[0]

            # Zero out masked positions, then divide the token sum by the number
            # of unmasked tokens in each sequence.
            token_sum_BD = (resid_BLD * mask_BL[:, :, None]).sum(1)
            token_count_B1 = mask_BL.sum(1)[:, None]
            pooled_BD = (token_sum_BD / token_count_B1).save()

        pooled_per_batch_BD.append(pooled_BD.value)

    return torch.cat(pooled_per_batch_BD, dim=0)

In [None]:
# --- Model / dataset settings -------------------------------------------------
llm_model_name = "google/gemma-2-2b"
device = "cuda"
train_set_size = 4000
test_set_size = 1000
context_length = 128
include_gender = True
model_dtype = torch.bfloat16

# --- Batch sizes --------------------------------------------------------------
# probe_batch_size is used when batching activations for probe training;
# llm_batch_size is the number of inputs per LLM forward pass.
probe_batch_size = 500
llm_batch_size = 32

# TODO: I think there may be a scoping issue with model and get_acts(), but we currently aren't using get_acts()
model = LanguageModel(
    llm_model_name,
    device_map=device,
    dispatch=True,
    torch_dtype=model_dtype,
    attn_implementation="eager",
)

# --- Probe output location ----------------------------------------------------
probe_dir = "trained_bib_probes"
only_model_name = llm_model_name.split("/")[-1]

model_eval_config = utils.ModelEvalConfig.from_full_model_name(llm_model_name)
probe_layer = model_eval_config.probe_layer

probe_output_filename = (
    f"{probe_dir}/{only_model_name}/probes_ctx_len_{context_length}_layer_{probe_layer}.pkl"
)

# --- Training settings --------------------------------------------------------
# (The second, identical `include_gender = True` assignment that used to sit here
# was removed; the value is set once in the settings group above.)
epochs = 10
save_results = True
seed = SEED


In [None]:

# Because we save the probes, we always train them on all classes to avoid potential
# issues with missing classes. It's only a one-time cost.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

model_eval_config = utils.ModelEvalConfig.from_full_model_name(llm_model_name)
d_model = model_eval_config.activation_dim
probe_layer = model_eval_config.probe_layer
probe_act_submodule = utils.get_submodule(model, "resid_post", probe_layer)

dataset_name = "bias_in_bios"

column1_vals = ("professor", "nurse")
column2_vals = ("male", "female")

train_df, test_df = load_and_prepare_dataset(dataset_name)

train_bios, test_bios = get_train_test_data(
    train_df,
    test_df,
    dataset_name,
    True,
    train_set_size,
    test_set_size,
    # Previously a hard-coded 42. Presumably the sampling seed (verify against
    # get_train_test_data's signature); tie it to `seed` so it cannot silently
    # diverge from the RNG seeding at the top of this cell.
    seed,
    column1_vals,
    column2_vals,
)

train_bios = utils.tokenize_data(train_bios, model.tokenizer, context_length, device)
test_bios = utils.tokenize_data(test_bios, model.tokenizer, context_length, device)


In [None]:
# trainer_ids = [2]

# ae_sweep_paths = {
#     # "pythia70m_sweep_standard_ctx128_0712": {
#     #     #     # "resid_post_layer_0": {"trainer_ids": None},
#     #     #     # "resid_post_layer_1": {"trainer_ids": None},
#     #     #     # "resid_post_layer_2": {"trainer_ids": None},
#     #     "resid_post_layer_3": {"trainer_ids": [6]},
#     #     #     "resid_post_layer_4": {"trainer_ids": None},
#     # },
#     # "pythia70m_sweep_topk_ctx128_0730": {
#     #     # "resid_post_layer_0": {"trainer_ids": None},
#     #     # "resid_post_layer_1": {"trainer_ids": None},
#     #     # "resid_post_layer_2": {"trainer_ids": None},
#     #     "resid_post_layer_3": {"trainer_ids": trainer_ids},
#     #     # "resid_post_layer_4": {"trainer_ids": trainer_ids},
#     # },
#     "gemma-2-2b_sweep_topk_ctx128_ef2_0824": {
#         # "resid_post_layer_3": {"trainer_ids": trainer_ids},
#         # "resid_post_layer_7": {"trainer_ids": trainer_ids},
#         "resid_post_layer_11": {"trainer_ids": trainer_ids},
#         # "resid_post_layer_15": {"trainer_ids": trainer_ids},
#         # "resid_post_layer_19": {"trainer_ids": trainer_ids},
#     },
# }

# p_config = PipelineConfig()

# sweep_name, submodule_trainers = list(ae_sweep_paths.items())[0]

# ae_group_paths = utils.get_ae_group_paths(
#     p_config.dictionaries_path, sweep_name, submodule_trainers
# )
# ae_paths = utils.get_ae_paths(ae_group_paths)
# print(ae_paths)

# ae_path = ae_paths[0]

# submodule, dictionary, sae_config = utils.load_dictionary(model, ae_path, device)

In [None]:
@torch.no_grad()
def get_all_sae_activations(
    text_inputs: list[str],
    model: LanguageModel,
    dictionary: AutoEncoder,
    batch_size: int,
    submodule: utils.submodule_alias,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Collect mean-pooled model activations and SAE feature activations.

    For every batch, the output of ``submodule`` is read inside an nnsight trace,
    encoded with ``dictionary.encode``, and both the raw activations and the
    encoded features are averaged over the sequence axis using the attention mask
    (masked positions contribute zero; the divisor is the unmasked-token count).

    Args:
        text_inputs: Inputs accepted by ``utils.batch_inputs`` / ``model.trace``.
        model: nnsight ``LanguageModel`` to trace.
        dictionary: Autoencoder whose ``encode`` maps activations to features.
        batch_size: Number of inputs per forward pass.
        submodule: Module whose output is read.

    Returns:
        ``(model_acts, sae_acts)``: the per-batch pooled tensors concatenated along
        dim 0. ``sae_acts`` is cast to ``model.dtype``.
    """
    # TODO: Rename text_inputs
    text_batches = utils.batch_inputs(text_inputs, batch_size)

    # Run a throwaway trace to learn whether this submodule returns a tuple.
    # Fix: the attribute was previously misspelled `submodule.outputorch.shape`.
    # The exact-type check (not isinstance) is kept from the original because
    # torch.Size is itself a tuple subclass.
    with model.trace("_"):
        is_tuple = type(submodule.output.shape) is tuple

    model_dtype = model.dtype

    all_acts_list_BD = []
    all_sae_acts_list_BF = []
    for text_batch_BL in text_batches:
        with model.trace(
            text_batch_BL,
            **tracer_kwargs,
        ):
            attn_mask = model.input[1]["attention_mask"]
            acts_BLD = submodule.output

            if is_tuple:
                acts_BLD = acts_BLD[0]

            # SAE features: encode first, then mask and mean-pool over the sequence.
            acts_BLF = dictionary.encode(acts_BLD)
            acts_BLF = acts_BLF * attn_mask[:, :, None]
            acts_BF = acts_BLF.sum(1) / attn_mask.sum(1)[:, None]
            acts_BF = acts_BF.save()

            # Raw model activations: same masked mean-pooling.
            acts_BLD = acts_BLD * attn_mask[:, :, None]
            acts_BD = acts_BLD.sum(1) / attn_mask.sum(1)[:, None]
            acts_BD = acts_BD.save()
        all_acts_list_BD.append(acts_BD.value)
        all_sae_acts_list_BF.append(acts_BF.value.to(dtype=model_dtype))

    all_acts_bD = torch.cat(all_acts_list_BD, dim=0)
    all_sae_acts_bF = torch.cat(all_sae_acts_list_BF, dim=0)

    return all_acts_bD, all_sae_acts_bF

In [None]:
# Show the keys before filtering.
print(train_bios.keys())
print(test_bios.keys())

# Keep only the entries whose key is not an int. The key set is taken from
# train_bios, and the same keys are looked up in test_bios.
new_train_bios = {
    label: entries for label, entries in train_bios.items() if not isinstance(label, int)
}
new_test_bios = {label: test_bios[label] for label in train_bios if not isinstance(label, int)}

train_bios, test_bios = new_train_bios, new_test_bios

# Show the keys after filtering.
print(train_bios.keys())
print(test_bios.keys())

In [None]:
# Pooled activations per class key, for the train and test splits.
all_train_acts, all_test_acts = {}, {}

llm_batch_size = 32

with torch.no_grad():
    for profession in train_bios:
        print(f"Collecting activations for profession: {profession}")

        # Train split first, then test split, for the same key.
        for split_bios, split_acts in ((train_bios, all_train_acts), (test_bios, all_test_acts)):
            split_acts[profession] = get_all_activations(
                split_bios[profession], model, llm_batch_size, probe_act_submodule
            )


In [None]:
# def print_tensor_memory_usage(tensor: torch.Tensor):
#     if not isinstance(tensor, torch.Tensor):
#         print("Input is not a tensor. Cannot calculate memory usage.")
#         return
    
#     memory = tensor.element_size() * tensor.nelement()
#     print(f"Tensor Shape: {tensor.shape}")
#     print(f"Tensor Type: {tensor.dtype}")
#     print(f"Memory usage: {memory / 1024 / 1024:.2f} MB")

# gc.collect()
# torch.cuda.empty_cache()

# model_acts = all_train_acts[0]
# print_tensor_memory_usage(model_acts)

In [None]:

def prepare_probe_data(
    all_activations: dict[int | str, torch.Tensor],
    class_idx: int | str,
    batch_size: int,
    select_top_k: Optional[int] = None,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """If class_idx is a string, there is a paired class idx in utils.py."""
    positive_acts_BD = all_activations[class_idx]
    device = positive_acts_BD.device

    num_positive = len(positive_acts_BD)

    if isinstance(class_idx, int):
        # Collect all negative class activations and labels
        negative_acts = []
        for idx, acts in all_activations.items():
            if idx != class_idx and isinstance(idx, int):
                negative_acts.append(acts)

        negative_acts = torch.cat(negative_acts)
    # elif class_idx == "biased_male / biased_female":
    #     male_professors = all_activations["male_professor / female_nurse"]
    #     female_nurses = all_activations["female_nurse_data_only"]
    #     males = all_activations["male / female"]
    #     females = all_activations["female_data_only"]
    #     professors = all_activations["professor / nurse"]
    #     nurses = all_activations["nurse_data_only"]

    #     mixed_data = [males, females, professors, nurses]
    #     mixed_data = torch.cat(mixed_data)
    #     shuffle_indices = torch.randperm(len(mixed_data))
    #     mixed_data = mixed_data[shuffle_indices]

    #     random_pos = torch.randperm(int(num_positive * 0.1))
    #     random_neg = torch.randperm(int(num_positive * 0.1))

    #     random_pos_data = mixed_data[random_pos]
    #     random_neg_data = mixed_data[random_neg]

    #     positive_acts_BD = torch.cat([male_professors, random_pos_data])
    #     negative_acts = torch.cat([female_nurses, random_neg_data])
    else:
        if class_idx not in utils.PAIRED_CLASS_KEYS:
            raise ValueError(f"Class index {class_idx} is not a valid class index.")

        negative_acts = all_activations[utils.PAIRED_CLASS_KEYS[class_idx]]

    # Randomly select num_positive samples from negative class
    indices = torch.randperm(len(negative_acts))[:len(positive_acts_BD)]
    selected_negative_acts_BD = negative_acts[indices]

    assert selected_negative_acts_BD.shape == positive_acts_BD.shape

    if select_top_k is not None:
        positive_distribution_D = positive_acts_BD.mean(dim=(0))
        negative_distribution_D = negative_acts.mean(dim=(0))
        distribution_diff_D = (positive_distribution_D - negative_distribution_D).abs()
        top_k_indices_D = torch.argsort(distribution_diff_D, descending=True)[:select_top_k]

        mask_D = torch.ones(distribution_diff_D.shape[0], dtype=torch.bool, device=positive_acts_BD.device)
        mask_D[top_k_indices_D] = False

        masked_positive_acts_BD = positive_acts_BD.clone()
        masked_negative_acts_BD = selected_negative_acts_BD.clone()

        masked_positive_acts_BD[:, mask_D] = 0.0
        masked_negative_acts_BD[:, mask_D] = 0.0
    else:
        masked_positive_acts_BD = positive_acts_BD
        masked_negative_acts_BD = selected_negative_acts_BD

    # Combine positive and negative samples
    combined_acts = torch.cat([masked_positive_acts_BD, masked_negative_acts_BD])

    combined_labels = torch.empty(len(combined_acts), dtype=torch.int, device=device)
    combined_labels[:num_positive] = utils.POSITIVE_CLASS_LABEL
    combined_labels[num_positive:] = utils.NEGATIVE_CLASS_LABEL

    # Shuffle the combined data
    shuffle_indices = torch.randperm(len(combined_acts))
    shuffled_acts = combined_acts[shuffle_indices]
    shuffled_labels = combined_labels[shuffle_indices]

    # Reshape into lists of tensors with specified batch_size
    num_samples = len(shuffled_acts)
    num_batches = num_samples // batch_size

    batched_acts = [
        shuffled_acts[i * batch_size : (i + 1) * batch_size] for i in range(num_batches)
    ]
    batched_labels = [
        shuffled_labels[i * batch_size : (i + 1) * batch_size] for i in range(num_batches)
    ]

    return batched_acts, batched_labels


In [None]:
# Probe model and training
class Probe(nn.Module):
    """Linear probe producing a single logit per input row."""

    def __init__(self, activation_dim: int, dtype: torch.dtype):
        """Create the single linear layer (with bias) in the requested dtype."""
        super().__init__()
        self.net = nn.Linear(in_features=activation_dim, out_features=1, bias=True, dtype=dtype)

    def forward(self, x):
        """Apply the linear layer and drop the trailing size-1 dimension."""
        logits = self.net(x)
        return logits.squeeze(dim=-1)

def train_probe(
    train_input_batches: list,
    train_label_batches: list[torch.Tensor],
    test_input_batches: list,
    test_label_batches: list[torch.Tensor],
    get_acts: Callable,
    precomputed_acts: bool,
    dim: int,
    epochs: int,
    device: str,
    model_dtype: torch.dtype,
    lr: float = 1e-2,
    seed: int = SEED,
    verbose: bool = False,
) -> tuple[Probe, float]:
    """Train a linear ``Probe`` with AdamW and binary cross-entropy on logits.

    input_batches can be a list of tensors or strings. If strings, get_acts must be provided.

    After every epoch the probe is scored with ``test_probe`` on the first 30
    training batches and on all test batches; only the last epoch's test score
    is returned.

    Args:
        train_input_batches: Training batches (activation tensors, or raw inputs
            to be passed through ``get_acts``).
        train_label_batches: One label tensor per training batch.
        test_input_batches: Evaluation batches, same format as the training ones.
        test_label_batches: One label tensor per evaluation batch.
        get_acts: Maps a raw input batch to activations; only called when
            ``precomputed_acts`` is False.
        precomputed_acts: True when the input batches are already activations.
        dim: Input dimension of the probe.
        epochs: Number of passes over the training batches (must be >= 1).
        device: Device for the probe and the labels.
        model_dtype: dtype for the probe parameters and the labels.
        lr: AdamW learning rate.
        seed: Currently unused by this function; kept for interface compatibility.
        verbose: If True, print loss and accuracies after the final epoch.

    Returns:
        ``(probe, test_accuracy)`` where ``test_accuracy`` is whatever
        ``test_probe`` returned after the final epoch.

    Raises:
        ValueError: If ``epochs < 1`` or there are no training batches.
    """
    # Fail early with a clear message. Previously epochs < 1 hit an unbound
    # `test_accuracy` at the return, and an empty training list raised a bare
    # IndexError on the type check below.
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")
    if len(train_input_batches) == 0:
        raise ValueError("train_input_batches is empty; nothing to train on")

    # Sanity-check that the batch type agrees with the precomputed_acts flag.
    if isinstance(train_input_batches[0], str) or isinstance(test_input_batches[0], str):
        assert not precomputed_acts, "string inputs require precomputed_acts=False"
    elif isinstance(train_input_batches[0], torch.Tensor) or isinstance(
        test_input_batches[0], torch.Tensor
    ):
        assert precomputed_acts, "tensor inputs require precomputed_acts=True"

    probe = Probe(dim, model_dtype).to(device)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    # Number of leading training batches used for the per-epoch train accuracy.
    max_train_eval_batches = 30

    for epoch in range(epochs):
        for inputs, labels in zip(train_input_batches, train_label_batches):
            acts_BD = inputs if precomputed_acts else get_acts(inputs)

            logits_B = probe(acts_BD)
            # BCEWithLogitsLoss needs floating-point targets in the logits' dtype.
            targets_B = labels.detach().to(device=device, dtype=model_dtype)
            loss = criterion(logits_B, targets_B)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_accuracy = test_probe(
            train_input_batches[:max_train_eval_batches],
            train_label_batches[:max_train_eval_batches],
            probe,
            get_acts,
            precomputed_acts,
        )

        test_accuracy = test_probe(
            test_input_batches, test_label_batches, probe, get_acts, precomputed_acts
        )

        if epoch == epochs - 1 and verbose:
            print(f"\nEpoch {epoch + 1}/{epochs} Loss: {loss.item()}, train accuracy: {train_accuracy}, test accuracy: {test_accuracy}\n")

    return probe, test_accuracy

In [None]:
# NOTE(review): this import rebinds `prepare_probe_data`, replacing the function of the
# same name defined earlier in this notebook. The calls below therefore use the
# `experiments.probe_training` version.
from experiments.probe_training import prepare_probe_data

# Training epochs per probe; read as a global inside train_probe_on_activations.
epochs = 10

# NOTE(review): torch is re-seeded with 0 here, whereas the dataset-preparation cell
# seeds it with `seed`. Confirm this is intentional.
torch.manual_seed(0)

def train_probe_on_activations(
    train_activations: dict[str | int, torch.Tensor],
    test_activations: dict[str | int, torch.Tensor],
    select_top_k: Optional[int] = None,
) -> tuple[dict[str | int, Probe], dict[str | int, float]]:
    """Train one probe per class key from precomputed activations.

    Keys that appear as *values* of ``utils.PAIRED_CLASS_KEYS`` are skipped: they
    only serve as the negative side of a paired class and get no probe of their own.

    Relies on the notebook globals ``probe_batch_size``, ``epochs``, ``device``,
    ``model_dtype`` and ``get_acts``.

    Args:
        train_activations: Mapping from class key to training activations.
        test_activations: Mapping from class key to test activations.
        select_top_k: Forwarded to ``prepare_probe_data``.

    Returns:
        ``(probes, test_accuracies)``, both keyed by class key; the accuracy
        entries are whatever ``train_probe`` returns as its second element.
    """
    # train_probe backpropagates through the probe, so make sure autograd is on.
    # Note that this flips the global autograd switch.
    torch.set_grad_enabled(True)

    probes, test_accuracies = {}, {}

    for profession in train_activations.keys():
        if profession in utils.PAIRED_CLASS_KEYS.values():
            continue

        # Positional call kept exactly as before. The third argument (True) belongs to
        # the signature of the `prepare_probe_data` that is bound at call time.
        # TODO confirm its meaning against experiments.probe_training.
        train_acts, train_labels = prepare_probe_data(
            train_activations, profession, True, probe_batch_size, select_top_k
        )
        test_acts, test_labels = prepare_probe_data(
            test_activations, profession, True, probe_batch_size, select_top_k
        )

        activation_dim = train_acts[0].shape[1]

        # The former per-class if/else assigned `epochs` in both branches, so the
        # global epoch count is passed directly.
        probe, test_accuracy = train_probe(
            train_acts,
            train_labels,
            test_acts,
            test_labels,
            get_acts,
            precomputed_acts=True,
            epochs=epochs,
            dim=activation_dim,
            device=device,
            model_dtype=model_dtype,
            verbose=False,
        )

        probes[profession] = probe
        test_accuracies[profession] = test_accuracy

    return probes, test_accuracies

    # if save_results:
    #     only_model_name = llm_model_name.split("/")[-1]
    #     os.makedirs(f"{probe_dir}", exist_ok=True)
    #     os.makedirs(f"{probe_dir}/{only_model_name}", exist_ok=True)

    #     with open(probe_output_filename, "wb") as f:
    #         pickle.dump(probes, f)


# Train one probe per class on the pooled activations collected above (no top-k masking).
probes, test_accuracies = train_probe_on_activations(all_train_acts, all_test_acts)

In [None]:
# with open("train_accs.pkl", "wb") as f:
#     pickle.dump(all_train_acts, f)

In [None]:
# Accumulators; nothing in this cell appends to them.
all_diffs, int_diffs, str_diffs = [], [], []

# Report the first element of each class's stored test result.
for class_name, class_result in test_accuracies.items():
    model_acc = class_result[0]
    print(f"Class: {class_name}, Accuracy: {model_acc}")

In [None]:
# Classes to re-evaluate with the trained probes.
chosen_class_indices = [
    "male / female",
    "professor / nurse",
    "male_professor / female_nurse",
    "biased_male / biased_female",
]

# Fix: this call used `p_config.probe_batch_size`, but `p_config` is only assigned in a
# commented-out cell, so it raised NameError on a fresh kernel. Use the notebook's own
# `probe_batch_size`, the same value the probes were trained with.
# NOTE(review): this rebinds `test_accuracies` to the structure returned by
# get_probe_test_accuracy (indexed with "acc" below).
test_accuracies = get_probe_test_accuracy(
    probes, chosen_class_indices, all_test_acts, probe_batch_size
)

for class_name in test_accuracies.keys():
    model_acc = test_accuracies[class_name]["acc"]

    print(f"Class: {class_name}, Accuracy: {model_acc}")