In [None]:
import matplotlib.pyplot as plt
import json
import torch
import pickle
from typing import Optional
from matplotlib.colors import Normalize
import numpy as np
import os
import random

import experiments.utils as utils


# Root folder of the trained dictionaries, relative to this notebook.
dictionaries_path = "../dictionary_learning/dictionaries"

# Manual alternative: fill in sweep_name and submodule_trainers by hand.
# These two values are overwritten just below by the entry taken from ae_sweep_paths.
sweep_name = "pythia70m_test_sae"
submodule_trainers = dict(resid_post_layer_3=dict(trainer_ids=[0]))

# Recommended: paste ae_sweep_paths straight from bib_intervention.py.
# Only the first sweep in the mapping is used by this notebook.
ae_sweep_paths = dict(
    pythia70m_sweep_standard_ctx128_0712=dict(
        resid_post_layer_3=dict(trainer_ids=[1, 7, 11, 18]),
    ),
)
sweep_name, submodule_trainers = next(iter(ae_sweep_paths.items()))

# Plotting options; to restrict to specific class ids use e.g. filter_class_ids = [-4, -2].
filter_class_ids = list()
include_diff = True

# Expand the sweep description into one path per selected autoencoder.
ae_group_paths = utils.get_ae_group_paths(
    dictionaries_path,
    sweep_name,
    submodule_trainers,
)
ae_paths = utils.get_ae_paths(ae_group_paths)

print(ae_paths)

In [None]:
# Effect-size cutoffs and top-n group sizes reported for every autoencoder.
thresholds = [0.1, 0.05, 0.025, 0.01, 0.001]
top_ns = [1, 10, 100, 500]


for ae_path in ae_paths:
    node_effects_filename = f"{ae_path}/node_effects.pkl"

    # NOTE(security): pickle.load can execute arbitrary code. Only open
    # node_effects.pkl files produced by your own runs.
    with open(node_effects_filename, "rb") as f:
        node_effects = pickle.load(f)

    # node_effects is indexed by class id first (-4 here) and then by the
    # autoencoder path; the selected entry is a torch tensor of effects.
    effects = node_effects[-4][ae_path]

    print(f"\nEffects for {ae_path}")
    for threshold in thresholds:
        above_threshold = effects[effects > threshold]
        count_above_threshold = above_threshold.shape[0]
        # The mean of an empty selection is nan, which is printed as "nan".
        avg_above_threshold = above_threshold.mean().item()
        print(
            f"Threshold {threshold}: {count_above_threshold} nodes above threshold, {avg_above_threshold:.3f} average"
        )

    # torch.topk raises if asked for more elements than the last dimension
    # holds, so oversized requests are reported and skipped instead.
    n_available = effects.shape[-1]
    for top_n in top_ns:
        if top_n > n_available:
            print(f"Top {top_n}: skipped, only {n_available} effects available")
            continue
        top_k = torch.topk(effects, top_n)
        avg_top_k = top_k.values.mean().item()
        print(f"Top {top_n}: {avg_top_k:.3f} average")

In [None]:
for ae_path in ae_paths:
    node_effects_filename = f"{ae_path}/node_effects.pkl"

    # NOTE(security): pickle.load can execute arbitrary code. Only open
    # node_effects.pkl files produced by your own runs.
    with open(node_effects_filename, "rb") as f:
        node_effects = pickle.load(f)

    # Show which keys exist at the top level and under class id -2.
    print(node_effects.keys())
    print(node_effects[-2].keys())

    effects = node_effects[-4][ae_path]
    print(effects.shape)

    # matplotlib cannot convert tensors that live on a GPU or that track
    # gradients, so hand it a detached CPU numpy array instead.
    effects_np = effects.detach().cpu().numpy()

    # One histogram per autoencoder, using the explicit figure/axes API.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(effects_np, bins=100)
    ax.set_ylim(0, 10)  # y-axis is capped at 10; taller bars are cut off
    ax.set_title(f'Histogram for {ae_path.split("/")[-3]}')
    ax.set_xlabel("Effect Value")
    ax.set_ylabel("Frequency")

    # Display the plot, then release the figure so the loop does not accumulate open figures.
    plt.show()
    plt.close(fig)

In [None]:
# Print every resolved autoencoder path on its own line.
# Note: after the loop, `ae_path` stays bound to the last entry of `ae_paths`.
for ae_path in ae_paths:
    print(ae_path)

In [None]:
import einops
import dictionary_learning.interp as interp
from circuitsvis.activations import text_neuron_activations
from collections import namedtuple
from nnsight import LanguageModel

import experiments.utils as utils
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning.buffer import ActivationBuffer


# Master switch for the tracer options forwarded to model.trace(...).
DEBUGGING = True

# Both tracer flags simply mirror the DEBUGGING switch.
tracer_kwargs = dict(scan=DEBUGGING, validate=DEBUGGING)

def examine_dimension(
    model,
    submodule,
    buffer,
    feat_idx: int,
    n_inputs: int,
    context_length: int,
    batch_size: int,
    dictionary=None,
    max_length: int = 128,
    k: int = 30,
):
    """Profile a single feature of ``submodule`` (optionally through a dictionary).

    Collects ``n_inputs`` tokenized contexts from ``buffer``, records the
    activation of feature ``feat_idx`` at every position, and summarises it.

    Args:
        model: traced language model; must expose ``device``, ``tokenizer``,
            ``trace(...)`` and ``input``.
        submodule: module whose ``.output`` is read inside the trace.
        buffer: object providing ``tokenized_batch(batch_size=...)``.
        feat_idx: index into the last dimension of the (encoded) activations.
        n_inputs: total number of contexts to collect; must be a multiple of
            ``batch_size``.
        context_length: number of token positions per context.
        batch_size: number of contexts pulled from ``buffer`` per iteration.
        dictionary: optional object with ``encode``; when ``None`` the raw
            submodule output is indexed directly. It is also forwarded
            unchanged to ``interp.feature_effect`` — NOTE(review): confirm
            that helper accepts ``None``.
        max_length: forwarded to ``interp.feature_effect``.
        k: number of top tokens / contexts / affected tokens to keep.

    Returns:
        A ``featureProfile`` namedtuple with:
        ``top_contexts`` — the value returned by ``text_neuron_activations``
        for the ``k`` highest-activating positions, each context cut off right
        after the activating token;
        ``top_tokens`` — up to ``k`` ``(decoded token, mean activation)``
        pairs, highest mean first;
        ``top_affected`` — ``(decoded token, value)`` pairs built from the two
        sequences returned by ``interp.feature_effect``.

    Raises:
        AssertionError: if ``n_inputs`` is not a multiple of ``batch_size``.
    """
    # tqdm only drives the progress bar. The notebook never imported it, so
    # import it here and fall back to a plain pass-through if it is missing.
    try:
        from tqdm.auto import tqdm
    except ImportError:

        def tqdm(iterable, **_kwargs):
            return iterable

    def _list_decode(x):
        # Recursively decode a (nested) list of token ids into strings.
        if isinstance(x, int):
            return model.tokenizer.decode(x)
        else:
            return [_list_decode(y) for y in x]

    assert n_inputs % batch_size == 0
    n_iters = n_inputs // batch_size

    device = model.device

    activations = torch.zeros((n_inputs, context_length), device=device)
    tokens = torch.zeros((n_inputs, context_length), dtype=torch.long, device=device)

    for i in tqdm(range(n_iters), desc="Collecting activations"):
        inputs_BL = buffer.tokenized_batch(batch_size=batch_size)

        with torch.no_grad(), model.trace(inputs_BL, **tracer_kwargs):
            tokens_BL = model.input[1][
                "input_ids"
            ].save()  # if you're getting errors, check here; might only work for pythia models
            activations_BLD = submodule.output
            # A tuple-valued shape means the submodule returned a tuple;
            # keep its first element (presumably the hidden states).
            if type(activations_BLD.shape) == tuple:
                activations_BLD = activations_BLD[0]
            if dictionary is not None:
                activations_BLF = dictionary.encode(activations_BLD)
            else:
                # Without a dictionary, index the raw submodule output. The
                # original left activations_BLF unbound on this path.
                activations_BLF = activations_BLD
            activations_BL = activations_BLF[:, :, feat_idx].save()

        rows = slice(i * batch_size, (i + 1) * batch_size)
        activations[rows] = activations_BL.value
        tokens[rows] = tokens_BL.value

    # Mean activation per distinct token id. The ids are copied to the host
    # once; dict.fromkeys keeps first-occurrence order, so ties sort exactly
    # as they did with the original position-by-position scan.
    token_mean_acts = {}
    for tok_id in dict.fromkeys(tokens.flatten().tolist()):
        token_mean_acts[tok_id] = activations[tokens == tok_id].mean().item()
    top_tokens = sorted(token_mean_acts.items(), key=lambda x: x[1], reverse=True)[:k]
    top_tokens = [(model.tokenizer.decode(tok), act) for tok, act in top_tokens]

    # Locate the k strongest activations and split each flat index back into
    # (context row, token position) using plain Python integers.
    flattened_acts = activations.flatten()
    topk_indices = torch.argsort(flattened_acts, dim=0, descending=True)[:k]
    top_positions = [divmod(flat_idx, activations.shape[1]) for flat_idx in topk_indices.tolist()]

    # Each context is truncated right after its top-activating token.
    top_token_ids = [tokens[row, : pos + 1].tolist() for row, pos in top_positions]
    top_acts = [activations[row, : pos + 1, None, None] for row, pos in top_positions]

    decoded_tokens = _list_decode(top_token_ids)
    top_contexts = text_neuron_activations(decoded_tokens, top_acts)

    top_affected = interp.feature_effect(
        model, submodule, dictionary, feat_idx, top_token_ids, max_length=max_length, k=k
    )
    top_affected = [(model.tokenizer.decode(tok), prob.item()) for tok, prob in zip(*top_affected)]

    return namedtuple("featureProfile", ["top_contexts", "top_tokens", "top_affected"])(
        top_contexts, top_tokens, top_affected
    )

# Language model used for the feature inspection cells.
DEVICE = "cuda"
model_name = "EleutherAI/pythia-70m-deduped"
model = LanguageModel(model_name, device_map=DEVICE, dispatch=True)

# Inspect the third autoencoder listed in ae_paths.
ae_path = ae_paths[2]
submodule, dictionary, config = utils.load_dictionary(model, ae_path, DEVICE)
context_length = config["buffer"]["ctx_len"]

# Keyword settings for the activation buffer.
# Lower refresh_batch_size and n_ctxs to fit on smaller GPUs.
buffer_options = dict(
    d_submodule=512,
    ctx_len=context_length,
    refresh_batch_size=128,
    n_ctxs=512,
    device=DEVICE,
)

data = hf_dataset_to_generator("monology/pile-uncopyrighted")
buffer = ActivationBuffer(data, model, submodule, **buffer_options)

In [None]:
# import einops
# import dictionary_learning.interp as interp
# from circuitsvis.activations import text_neuron_activations
# from collections import namedtuple
# from nnsight import LanguageModel

# import experiments.utils as utils
# from dictionary_learning.utils import hf_dataset_to_generator
# from dictionary_learning.buffer import ActivationBuffer


# DEBUGGING = True

# if DEBUGGING:
#     tracer_kwargs = dict(scan=True, validate=True)
# else:
#     tracer_kwargs = dict(scan=False, validate=False)


# def examine_dimension(
#     model,
#     submodule,
#     buffer,
#     dim_indices: torch.Tensor,
#     context_length: int,
#     batch_size: int,
#     dictionary=None,
#     max_length: int = 128,
#     n_inputs: int = 512,
#     k: int = 30,
# ):

#     def _list_decode(x):
#         if isinstance(x, int):
#             return model.tokenizer.decode(x)
#         else:
#             return [_list_decode(y) for y in x]

#     # if dim_indices is None:
#     # dim_indices = random.randint(0, activations.shape[-1] - 1)

#     assert n_inputs % batch_size == 0
#     n_iters = n_inputs // batch_size

#     dim_count = dim_indices.shape[0]

#     device = model.device

#     activations_bLF = torch.zeros((n_inputs, context_length, dim_count), device=device)
#     tokens_bL = torch.zeros((n_inputs, context_length), dtype=torch.long, device=device)

#     for i in range(n_iters):
#         inputs_BL = buffer.tokenized_batch(batch_size=batch_size)

#         with torch.no_grad(), model.trace(inputs_BL, **tracer_kwargs):
#             tokens_BL = model.input[1][
#                 "input_ids"
#             ].save()  # if you're getting errors, check here; might only work for pythia models
#             activations_BLD = submodule.output
#             if type(activations_BLD.shape) == tuple:
#                 activations_BLD = activations_BLD[0]
#             if dictionary is not None:
#                 activations_BLF = dictionary.encode(activations_BLD)
#             activations_BLF = activations_BLF[:, :, dim_indices].save()

#         activations_bLF[i * batch_size : (i + 1) * batch_size] = activations_BLF.value
#         tokens_bL[i * batch_size : (i + 1) * batch_size] = tokens_BL.value

#     token_mean_acts = {}
#     for ctx in tokens_bL:
#         for tok in ctx:
#             if tok.item() in token_mean_acts:
#                 continue
#             idxs = (tokens_bL == tok).nonzero(as_tuple=True)
#             token_mean_acts[tok.item()] = activations_bLF[idxs].mean().item()
#     top_tokens = sorted(token_mean_acts.items(), key=lambda x: x[1], reverse=True)[:k]
#     top_tokens = [(model.tokenizer.decode(tok), act) for tok, act in top_tokens]

#     flattened_acts_NF = einops.rearrange(activations_bLF, "B L F -> (B L F)")
#     topk_indices = torch.argsort(flattened_acts_NF, dim=0, descending=True)[:k]
#     batch_indices = topk_indices // activations_bLF.shape[1]
#     token_indices = topk_indices % activations_bLF.shape[1]
#     tokens_bL = [
#         tokens_bL[batch_idx, : token_idx + 1].tolist()
#         for batch_idx, token_idx in zip(batch_indices, token_indices)
#     ]
#     activations_bLF = [
#         activations_bLF[batch_idx, : token_id + 1, None, None]
#         for batch_idx, token_id in zip(batch_indices, token_indices)
#     ]
#     decoded_tokens = _list_decode(tokens_bL)

#     return namedtuple(
#         "featureProfile",
#         [
#             "top_tokens",
#             "encoded_tokens_bL",
#             "decoded_tokens_bL",
#             "activations_bLF",
#         ],
#     )(top_tokens, tokens_bL, decoded_tokens, activations_bLF)


# DEVICE = "cuda"
# model_name = "EleutherAI/pythia-70m-deduped"
# model = LanguageModel(model_name, device_map=DEVICE, dispatch=True)

# ae_path = ae_paths[2]
# submodule, dictionary, config = utils.load_dictionary(model, ae_path, DEVICE)

# context_length = config["buffer"]["ctx_len"]

# data = hf_dataset_to_generator("monology/pile-uncopyrighted")
# buffer = ActivationBuffer(
#     data,
#     model,
#     submodule,
#     d_submodule=512,
#     ctx_len=context_length,
#     refresh_batch_size=128,  # decrease to fit on smaller GPUs
#     n_ctxs=512,  # decrease to fit on smaller GPUs
#     device=DEVICE,
# )

In [None]:
# Reload the node effects stored next to the currently selected autoencoder.
node_effects_filename = f"{ae_path}/node_effects.pkl"
with open(node_effects_filename, "rb") as f:
    node_effects = pickle.load(f)

# Effects recorded under class id -2 for this autoencoder.
effects = node_effects[-2][ae_path]
print(effects.shape)

# The k largest effects and the feature indices they belong to.
k = 10
top_k_values, top_k_indices = effects.topk(k)

print(top_k_values, top_k_indices, sep="\n")

In [None]:
# How many contexts to collect, and how many per forward pass.
n_inputs = 1024
batch_size = 256

# Position within top_k_indices of the feature to inspect (0 = largest effect).
feat_idx = 0
sae_feat_idx = top_k_indices[feat_idx].item()
print(sae_feat_idx)

# Gradients are not needed for inspection; also release cached GPU memory first.
torch.set_grad_enabled(False)
torch.cuda.empty_cache()

out = examine_dimension(
    model=model,
    submodule=submodule,
    buffer=buffer,
    feat_idx=sae_feat_idx,
    n_inputs=n_inputs,
    context_length=context_length,
    batch_size=batch_size,
    dictionary=dictionary,
    max_length=context_length,
    k=30,
)

print(f"\n\ntop activating tokens for feature {sae_feat_idx}")
for token in out.top_tokens:
    print(token)

print(f"\n\ntop affected tokens for feature {sae_feat_idx}")
for token in out.top_affected:
    print(token)

# Last expression: rich display of the top-activating contexts.
out.top_contexts