In [None]:
import matplotlib.pyplot as plt
import json
import torch
import pickle
from typing import Optional
from matplotlib.colors import Normalize
import numpy as np
import os
import random
import datasets

import einops
import dictionary_learning.interp as interp
from circuitsvis.activations import text_neuron_activations
from collections import namedtuple
from nnsight import LanguageModel
from tqdm import tqdm

import experiments.utils as utils
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning.buffer import ActivationBuffer

DEBUGGING = True

# Turn nnsight's scan/validate tracer checks on exactly when DEBUGGING is set.
tracer_kwargs = dict(scan=DEBUGGING, validate=DEBUGGING)

In [None]:
DEVICE = "cuda"
model_name = "EleutherAI/pythia-70m-deduped"
model = LanguageModel(model_name, device_map=DEVICE, dispatch=True)

# Every prompt is padded/truncated to exactly this many tokens.
context_length = 128

dataset = datasets.load_dataset("georgeyw/dsir-pile-100k", streaming=False)

# Tokenize the first 10k documents into fixed-length tensors on DEVICE.
# `.data` exposes the BatchEncoding's underlying dict of tensors.
data = (
    model.tokenizer(
        dataset["train"]["contents"][:10000],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=context_length,
    )
    .to(DEVICE)
    .data
)

In [None]:
batch_size = 250
print(type(data))

# NOTE(review): presumably chunks the tokenized dict into batches of
# `batch_size` prompts — confirm against experiments.utils.batch_inputs.
batched_data = utils.batch_inputs(data, batch_size)

In [None]:
import einops
import dictionary_learning.interp as interp
from circuitsvis.activations import text_neuron_activations
from collections import namedtuple
from nnsight import LanguageModel
from tqdm import tqdm

import experiments.utils as utils
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning.buffer import ActivationBuffer


DEBUGGING = True

# Turn nnsight's scan/validate tracer checks on exactly when DEBUGGING is set.
tracer_kwargs = dict(scan=DEBUGGING, validate=DEBUGGING)


def get_max_activating_prompts(
    model,
    submodule,
    tokenized_inputs_bL: list[dict],
    dim_indices: torch.Tensor,
    batch_size: int,
    dictionary=None,
    n_inputs: int = 512,
    k: int = 30,
    context_length: Optional[int] = None,
):
    """Collect, per selected feature, the ``k`` prompts with the highest max activation.

    Each batch is traced through ``model`` (with the notebook-level
    ``tracer_kwargs``). The output of ``submodule`` is encoded with
    ``dictionary`` (when given) and restricted to the features in
    ``dim_indices``. Prompts are scored per feature by their maximum
    activation over token positions, and a running top-``k`` is kept across
    batches.

    Args:
        model: nnsight model to trace; ``model.device`` is used for all buffers.
        submodule: Submodule of ``model`` whose output is read.
        tokenized_inputs_bL: Iterable of tokenized batches. Each batch is a dict
            with an ``"input_ids"`` tensor of shape ``(batch, seq_len)`` and is
            passed to ``model.trace`` as-is.
        dim_indices: Tensor of feature indices to track (length ``F``).
        batch_size: Prompts per batch. Only used to validate ``n_inputs``.
        dictionary: Object with an ``encode`` method (e.g. an SAE). If ``None``,
            the raw submodule activations are used and ``dim_indices`` index
            their last dimension.
        n_inputs: Must be a multiple of ``batch_size``. It is only validated;
            every batch in ``tokenized_inputs_bL`` is processed.
        k: Number of prompts to keep per feature.
        context_length: Sequence length of every prompt. If ``None`` it is read
            from the first batch's ``input_ids``.

    Returns:
        ``(max_tokens_FKL, max_activations_FKL)``: token ids and per-token
        activations of the top-``k`` prompts for each feature. Both have shape
        ``(F, k, context_length)`` and are ordered by descending max activation.
        If fewer than ``k`` prompts are seen, the remaining slots hold zeros.
    """
    assert n_inputs % batch_size == 0

    # Materialise once so generators also work, and so the sequence length can
    # be read from the data instead of a notebook-level global.
    batches = list(tokenized_inputs_bL)
    if context_length is None:
        context_length = batches[0]["input_ids"].shape[1] if batches else 0

    feature_count = dim_indices.shape[0]
    device = model.device

    # Running top-k state. Scores start at -inf so any real activation (which
    # may be negative when dictionary is None) displaces an empty slot.
    max_activations_FK = torch.full(
        (feature_count, k), float("-inf"), device=device, dtype=torch.float32
    )
    max_activations_FKL = torch.zeros(
        (feature_count, k, context_length), device=device, dtype=torch.float32
    )
    max_tokens_FKL = torch.zeros(
        (feature_count, k, context_length), device=device, dtype=torch.long
    )

    # Row selector: pairs each feature with its own chosen columns when gathering.
    feature_indices_F1 = torch.arange(feature_count, device=device)[:, None]

    # Each element is a batch dict, not an (index, batch) pair.
    for inputs in tqdm(batches):
        inputs_BL = inputs["input_ids"]

        with torch.no_grad(), model.trace(inputs, **tracer_kwargs):
            activations_BLD = submodule.output
            # If `.shape` is a plain tuple (not torch.Size), treat the output
            # as a tuple and use its first element.
            if type(activations_BLD.shape) == tuple:
                activations_BLD = activations_BLD[0]
            if dictionary is not None:
                activations_BLF = dictionary.encode(activations_BLD)
            else:
                activations_BLF = activations_BLD
            activations_BLF = activations_BLF[:, :, dim_indices].save()

        activations_FBL = einops.rearrange(activations_BLF.value, "B L F -> F B L")
        # Score each prompt by its max activation over token positions.
        activations_FB = einops.reduce(activations_FBL, "F B L -> F B", "max")
        tokens_FBL = einops.repeat(inputs_BL, "B L -> F B L", F=feature_count)

        # Append this batch to the current top-k, then keep the best k per feature.
        combined_activations_FB = torch.cat([max_activations_FK, activations_FB], dim=1)
        combined_activations_FBL = torch.cat([max_activations_FKL, activations_FBL], dim=1)
        combined_tokens_FBL = torch.cat([max_tokens_FKL, tokens_FBL], dim=1)

        max_activations_FK, topk_indices_FK = torch.topk(combined_activations_FB, k, dim=1)
        max_activations_FKL = combined_activations_FBL[feature_indices_F1, topk_indices_FK]
        max_tokens_FKL = combined_tokens_FBL[feature_indices_F1, topk_indices_FK]

    return max_tokens_FKL, max_activations_FKL

dictionaries_path = "../dictionary_learning/dictionaries"

# Current recommended way to generate graphs. You can copy-paste ae_sweep_paths directly from bib_intervention.py
ae_sweep_paths = {
    # "pythia70m_sweep_standard_ctx128_0712": {"resid_post_layer_3": {"trainer_ids": [11]}},
    # "pythia70m_sweep_gated_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [9]}},
    "pythia70m_sweep_topk_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [18]}},
}
sweep_name = list(ae_sweep_paths.keys())[0]
submodule_trainers = ae_sweep_paths[sweep_name]

filter_class_ids = []
# filter_class_ids = [-4, -2]

ae_group_paths = utils.get_ae_group_paths(dictionaries_path, sweep_name, submodule_trainers)
ae_paths = utils.get_ae_paths(ae_group_paths)

# TODO
# Add bias_in_bios dataset option
# Cosine sim with probes
# Vector per class probe

ae_path = ae_paths[0]
submodule, dictionary, config = utils.load_dictionary(model, ae_path, DEVICE)

node_effects_filename = f"{ae_path}/node_effects.pkl"

# pickle.load can execute arbitrary code: only load node_effects files you produced yourself.
with open(node_effects_filename, "rb") as f:
    node_effects = pickle.load(f)

# NOTE(review): this reads class id -2, while the summary and histogram cells
# read -4 — confirm which class is intended here.
effects = node_effects[-2][ae_path]

print(effects.shape)

# Inspect the k features with the largest effect. Cap k at the number of
# effects, because torch.topk raises when k exceeds the available elements.
k = min(500, effects.shape[0])
top_k_values, top_k_indices = torch.topk(effects, k)
torch.set_printoptions(sci_mode=False)
print(top_k_values)
print(top_k_indices)

all_indices = top_k_indices

# Release cached GPU memory before the activation sweep.
torch.cuda.empty_cache()
import gc
gc.collect()

# Take n_inputs from the tokenized data instead of hard-coding its size.
n_prompts = data["input_ids"].shape[0]
max_tokens_FKL, max_activations_FKL = get_max_activating_prompts(
    model,
    submodule,
    batched_data,
    all_indices,
    batch_size,
    dictionary,
    n_inputs=n_prompts,
    k=30,
)

In [None]:
def _list_decode(x):
    """Recursively decode an int token id, or a nested list of them, into strings."""
    if isinstance(x, int):
        return model.tokenizer.decode(x)
    return [_list_decode(item) for item in x]


for feat_idx in range(10):
    sae_feat_idx = top_k_indices[feat_idx].item()
    print(f"Feature index: {feat_idx}")
    print(sae_feat_idx)

    decoded_tokens_KL = _list_decode(max_tokens_FKL[feat_idx].tolist())
    # One (L, 1, 1) tensor per prompt, as passed to text_neuron_activations.
    activations_KL11 = [row[:, None, None] for row in max_activations_FKL[feat_idx]]
    display(text_neuron_activations(decoded_tokens_KL, activations_KL11))

In [None]:

thresholds = [0.1, 0.05, 0.025, 0.01, 0.001]
top_ns = [1, 10, 100, 500]


def _summarize_effects(effects):
    """Print, per threshold, how many effects exceed it and their mean; then the mean of the top-n effects."""
    for threshold in thresholds:
        above_threshold = effects[effects > threshold]
        print(
            f"Threshold {threshold}: {above_threshold.shape[0]} nodes above threshold, "
            f"{above_threshold.mean().item():.3f} average"
        )
    for top_n in top_ns:
        avg_top_k = torch.topk(effects, top_n).values.mean().item()
        print(f"Top {top_n}: {avg_top_k:.3f} average")


for i, ae_path in enumerate(ae_paths):
    with open(f"{ae_path}/node_effects.pkl", "rb") as f:
        node_effects = pickle.load(f)

    print(node_effects.keys())
    effects = node_effects[-4][ae_path]

    print(f"\nEffects for {ae_path}")
    _summarize_effects(effects)

In [None]:
for i, ae_path in enumerate(ae_paths):
    node_effects_filename = f"{ae_path}/node_effects.pkl"

    with open(node_effects_filename, "rb") as f:
        node_effects = pickle.load(f)

    print(node_effects.keys())
    print(node_effects[-2].keys())

    effects = node_effects[-4][ae_path]
    print(effects.shape)

    # matplotlib needs a host numpy array: drop autograd history, upcast (numpy
    # has no bfloat16) and move off the GPU before plotting.
    effects_np = torch.as_tensor(effects).detach().float().cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(effects_np, bins=100)
    # y-axis capped at 10 counts — presumably so the sparse tail of large
    # effects stays visible next to the bulk; adjust if needed.
    ax.set_ylim(0, 10)
    ax.set_title(f'Histogram for {ae_path.split("/")[-3]}')
    ax.set_xlabel("Effect Value")
    ax.set_ylabel("Frequency")

    plt.show()
    # Free the figure so many AE paths don't accumulate open figures.
    plt.close(fig)

In [None]:
# def examine_dimension(
#     model,
#     submodule,
#     buffer,
#     feat_idx: int,
#     n_inputs: int,
#     context_length: int,
#     batch_size: int,
#     dictionary=None,
#     max_length: int = 128,
#     k: int = 30,
# ):


#     def _list_decode(x):
#         if isinstance(x, int):
#             return model.tokenizer.decode(x)
#         else:
#             return [_list_decode(y) for y in x]

#     # if dim_indices is None:
#         # dim_indices = random.randint(0, activations.shape[-1] - 1)

#     assert n_inputs % batch_size == 0
#     n_iters = n_inputs // batch_size

#     device = model.device

#     activations = torch.zeros((n_inputs, context_length), device=device)
#     tokens = torch.zeros((n_inputs, context_length), dtype=torch.long, device=device)

#     for i in tqdm(range(n_iters), desc="Collecting activations"):
#         inputs_BL = buffer.tokenized_batch(batch_size=batch_size)

#         with torch.no_grad(), model.trace(inputs_BL, **tracer_kwargs):
#             tokens_BL = model.input[1][
#                 "input_ids"
#             ].save()  # if you're getting errors, check here; might only work for pythia models
#             activations_BLD = submodule.output
#             if type(activations_BLD.shape) == tuple:
#                 activations_BLD = activations_BLD[0]
#             if dictionary is not None:
#                 activations_BLF = dictionary.encode(activations_BLD)
#             activations_BL = activations_BLF[:, :, feat_idx].save()

#         activations[i * batch_size : (i + 1) * batch_size] = activations_BL.value
#         tokens[i * batch_size : (i + 1) * batch_size] = tokens_BL.value

#     token_mean_acts = {}
#     for ctx in tokens:
#         for tok in ctx:
#             if tok.item() in token_mean_acts:
#                 continue
#             idxs = (tokens == tok).nonzero(as_tuple=True)
#             token_mean_acts[tok.item()] = activations[idxs].mean().item()
#     top_tokens = sorted(token_mean_acts.items(), key=lambda x: x[1], reverse=True)[:k]
#     top_tokens = [(model.tokenizer.decode(tok), act) for tok, act in top_tokens]

#     flattened_acts = einops.rearrange(activations, "b n -> (b n)")
#     topk_indices = torch.argsort(flattened_acts, dim=0, descending=True)[:k]
#     batch_indices = topk_indices // activations.shape[1]
#     token_indices = topk_indices % activations.shape[1]
#     tokens = [
#         tokens[batch_idx, : token_idx + 1].tolist()
#         for batch_idx, token_idx in zip(batch_indices, token_indices)
#     ]
#     activations = [
#         activations[batch_idx, : token_id + 1, None, None]
#         for batch_idx, token_id in zip(batch_indices, token_indices)
#     ]
#     decoded_tokens = _list_decode(tokens)
#     top_contexts = text_neuron_activations(decoded_tokens, activations)

#     top_affected = interp.feature_effect(
#         model, submodule, dictionary, feat_idx, tokens, max_length=max_length, k=k
#     )
#     top_affected = [(model.tokenizer.decode(tok), prob.item()) for tok, prob in zip(*top_affected)]

#     return namedtuple("featureProfile", ["top_contexts", "top_tokens", "top_affected"])(
#         top_contexts, top_tokens, top_affected
#     )

# # DEVICE = "cuda"
# # model_name = "EleutherAI/pythia-70m-deduped"
# # model = LanguageModel(model_name, device_map=DEVICE, dispatch=True)

# # ae_path = ae_paths[2]
# # submodule, dictionary, config = utils.load_dictionary(model, ae_path, DEVICE)

# # context_length = config['buffer']['ctx_len']

# # data = hf_dataset_to_generator("monology/pile-uncopyrighted")
# # buffer = ActivationBuffer(
# #     data,
# #     model,
# #     submodule,
# #     d_submodule=512,
# #     ctx_len=context_length,
# #     refresh_batch_size=128, # decrease to fit on smaller GPUs
# #     n_ctxs=512, # decrease to fit on smaller GPUs
# #     device=DEVICE
# # )

# feat_idx = 0
# sae_feat_idx = top_k_indices[feat_idx].item()
# print(sae_feat_idx)

# n_inputs = 1024
# batch_size = 256

# torch.cuda.empty_cache()
# torch.set_grad_enabled(False)

# out = examine_dimension(
#     model,
#     submodule,
#     buffer,
#     sae_feat_idx,
#     n_inputs,
#     context_length,
#     batch_size,
#     dictionary,
#     max_length=context_length,
#     k=30,
# )

# print(f'\n\ntop activating tokens for feature {sae_feat_idx}')
# for token in out.top_tokens:
#     print(token)
# print(f'\n\ntop affected tokens for feature {sae_feat_idx}')
# for token in out.top_affected:
#     print(token)

# out.top_contexts

In [None]:
def get_max_activating_prompts_old(
    model,
    submodule,
    inputs_bL: list[str],
    dim_indices: torch.Tensor,
    batch_size: int,
    dictionary=None,
    n_inputs: int = 512,
    k: int = 30,
):
    """Legacy variant: positions and max activations of the top-``k`` prompts per feature.

    Traces the first ``n_inputs`` prompts of ``inputs_bL`` in slices of
    ``batch_size`` (using the notebook-level ``tracer_kwargs``). Each prompt is
    scored per feature by its maximum activation over token positions, and a
    running top-``k`` is kept.

    Args:
        model: nnsight model to trace; ``model.device`` is used for all buffers.
        submodule: Submodule of ``model`` whose output is read.
        inputs_bL: Prompts, sliced ``batch_size`` at a time and passed to ``model.trace``.
        dim_indices: Tensor of feature indices to track (length ``F``).
        batch_size: Prompts per trace.
        dictionary: Object with an ``encode`` method. If ``None``, the raw
            submodule activations are used.
        n_inputs: Number of prompts to process; must be a multiple of ``batch_size``.
        k: Number of prompts to keep per feature.

    Returns:
        ``(max_activating_indices_FK, max_activations_FK)``, both of shape
        ``(F, k)``: positions in ``inputs_bL`` and their max activation, sorted
        descending. Slots not filled (fewer than ``k`` prompts) hold index 0
        and activation ``-inf``.
    """
    assert n_inputs % batch_size == 0
    n_iters = n_inputs // batch_size

    dim_count = dim_indices.shape[0]
    device = model.device

    max_activating_indices_FK = torch.zeros((dim_count, k), device=device, dtype=torch.long)
    # -inf so every real activation (possibly negative when dictionary is None)
    # outranks an empty slot.
    max_activations_FK = torch.full(
        (dim_count, k), float("-inf"), device=device, dtype=torch.float32
    )

    for i in range(n_iters):
        batch_offset = i * batch_size
        inputs_BL = inputs_bL[batch_offset : batch_offset + batch_size]

        with torch.no_grad(), model.trace(inputs_BL, **tracer_kwargs):
            activations_BLD = submodule.output
            # If `.shape` is a plain tuple (not torch.Size), treat the output
            # as a tuple and use its first element.
            if type(activations_BLD.shape) == tuple:
                activations_BLD = activations_BLD[0]
            if dictionary is not None:
                activations_BLF = dictionary.encode(activations_BLD)
            else:
                activations_BLF = activations_BLD
            activations_BLF = activations_BLF[:, :, dim_indices].save()

        # Max over token positions: one score per (feature, prompt).
        activations_FB = einops.reduce(activations_BLF.value, "B L F -> F B", "max")

        # Global positions of this batch's prompts; sized from the actual batch
        # so a short final slice stays aligned with activations_FB.
        indices_B = torch.arange(
            batch_offset, batch_offset + activations_FB.shape[1], device=device
        )
        indices_FB = einops.repeat(indices_B, "B -> F B", F=dim_count)

        # Append this batch to the current top-k, then keep the best k per feature.
        combined_activations_FB = torch.cat([max_activations_FK, activations_FB], dim=1)
        combined_indices_FB = torch.cat([max_activating_indices_FK, indices_FB], dim=1)

        max_activations_FK, topk_indices_FK = torch.topk(combined_activations_FB, k, dim=1)
        max_activating_indices_FK = torch.gather(combined_indices_FB, 1, topk_indices_FK)

    return max_activating_indices_FK, max_activations_FK


