In [1]:
%load_ext autoreload
%autoreload 2

from dictionary_learning import AutoEncoder, ActivationBuffer, GatedAutoEncoder, JumpReluAutoEncoder
from nnsight import LanguageModel
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import zst_to_generator
from huggingface_hub import hf_hub_download, list_repo_files

import torch as t
import numpy as np
import gc

In [2]:
# from nnsight.models.UnifiedTransformer import UnifiedTransformer
# from sae_lens import SparseAutoencoder
from tokenizers.processors import TemplateProcessing

DEVICE = "cuda:0"
model = LanguageModel("google/gemma-2-2b", attn_implementation="eager",
                      torch_dtype=t.bfloat16, device_map=DEVICE)

# loading dictionaries
layer = 19

# dictionary hyperparameters
# dict_id = 10
# expansion_factor = 64
# dictionary_size = expansion_factor * activation_dim

submodules = []
dictionaries = {}
use_inputs = {}

# For Gemma 2
import numpy as np

def _params_path_to_sae(params_path, torch_dtype=t.bfloat16):
    params = np.load(path_to_sae_params)
    pt_params = {k: t.from_numpy(v).cuda() for k, v in params.items()}  
    sae = JumpReluAutoEncoder(params['W_enc'].shape[0], params['W_enc'].shape[1]).to("cuda").to(torch_dtype)
    sae.load_state_dict(pt_params)
    sae = sae.to(DEVICE)
    return sae

def _get_l0_nearest_100_filename(repo_id, layer):
    repo_files = list_repo_files(repo_id=repo_id)
    attn_files = [f for f in repo_files if f.startswith(f"layer_{layer}/width_16k")]
    attn_files = [f for f in attn_files if "canonical" not in f]
    submod_l0s = [int(f.split("average_")[1].split("/")[0].split("_")[-1]) for f in attn_files]
    distances_to_100 = [abs(x - 100) for x in submod_l0s]
    idx_mindistance = np.argmin(distances_to_100)
    submod_file_nearest_100 = attn_files[idx_mindistance]
    return submod_file_nearest_100

repo_id = "google/gemma-scope-2b-pt-{submod_name}"
# filename = "layer_{layer_idx}/width_16k/canonical/params.npz"

# submodules.append(model.model.embed_tokens)
# path_to_sae_params = hf_hub_download(
#     repo_id=repo_id.format(submod_name="res"),
#     filename="embedding/width_4k/average_l0_44/params.npz",
#     force_download=False
# )
# sae = _params_path_to_sae(path_to_sae_params)
# dictionaries[model.model.embed_tokens] = sae
# use_inputs[model.model.embed_tokens] = False

for layer in range(layer + 1):
    submodules.append(model.model.layers[layer].self_attn.o_proj)
    # get attn filename (no canonical file provided)
    # median_l0_filename = _get_median_l0_filename(repo_id.format(submod_name="att"), layer)
    l0_filename = _get_l0_nearest_100_filename(repo_id.format(submod_name="att"), layer)
    path_to_sae_params = hf_hub_download(
        repo_id=repo_id.format(submod_name="att"),
        filename=l0_filename,
        force_download=False
    )
    sae = _params_path_to_sae(path_to_sae_params)
    dictionaries[model.model.layers[layer].self_attn.o_proj] = sae
    use_inputs[model.model.layers[layer].self_attn.o_proj] = True

    # median_l0_filename = _get_median_l0_filename(repo_id.format(submod_name="mlp"), layer)
    l0_filename = _get_l0_nearest_100_filename(repo_id.format(submod_name="mlp"), layer)
    submodules.append(model.model.layers[layer].post_feedforward_layernorm)
    path_to_sae_params = hf_hub_download(
        repo_id=repo_id.format(submod_name="mlp"),
        filename=l0_filename,
        force_download=False
    )
    sae = _params_path_to_sae(path_to_sae_params)
    dictionaries[model.model.layers[layer].post_feedforward_layernorm] = sae
    use_inputs[model.model.layers[layer].post_feedforward_layernorm] = False

    # median_l0_filename = _get_median_l0_filename(repo_id.format(submod_name="res"), layer)
    l0_filename = _get_l0_nearest_100_filename(repo_id.format(submod_name="res"), layer)
    submodules.append(model.model.layers[layer])
    path_to_sae_params = hf_hub_download(
        repo_id=repo_id.format(submod_name="res"),
        filename=l0_filename,
        force_download=False
    )
    sae = _params_path_to_sae(path_to_sae_params)
    dictionaries[model.model.layers[layer]] = sae
    use_inputs[model.model.layers[layer]] = False

# if component == 'resid':
#     submodule = resids[layer]

# activation_dim=4096
activation_dim=2304


# the GPT-2 SAEs expect a BOS token at start of sequence. nnsight doesn't do this,
# so we need to tell the tokenizer to always do this
# model.tokenizer._tokenizer.post_processor = TemplateProcessing(
#     single=model.tokenizer.bos_token + " $A",
#     special_tokens=[(model.tokenizer.bos_token, model.tokenizer.bos_token_id)]
# )

# dictionaries = {}
# for i in (16,):
#     ae = GatedAutoEncoder(4096, 32768).to("cuda")
#     ae.load_state_dict(t.load(f'llama_saes/layer{i}/ae_81920.pt'))
#     ae = ae.half()
#     dictionaries[resids[i]] = ae
#     break
#     # obj = t.load(f'llama_saes/layer{i}/ae_81920.pt')
# #     # print(obj)
# for i in (13,):
#     path_to_params = hf_hub_download(
#         repo_id="google/gemma-scope-2b-pt-res",
#         filename=f"layer_{i}/width_16k/canonical/params.npz",
#         force_download=False,
#     )
#     params = np.load(path_to_params)
#     pt_params = {k: t.from_numpy(v).cuda() for k, v in params.items()}
#     ae = JumpReLUSAE(params["W_enc"].shape[0], params["W_enc"].shape[1]).to("cuda")
#     ae.load_state_dict(pt_params)
#     dictionaries[resids[i]] = ae

# resids = list(dictionaries.keys())

In [3]:
import random
import json

def load_examples_prefix_len(dataset, num_examples, model, seed=12, pad_to_length=None, length=None,
                  ignore_patch=False):
    examples = []
    dataset_items = open(dataset).readlines()
    random.seed(seed)
    random.shuffle(dataset_items)
    for line in dataset_items:
        data = json.loads(line)
        clean_prefix = model.tokenizer(data["clean_prefix"], return_tensors="pt",
                                        padding=False).input_ids
        patch_prefix = model.tokenizer(data["patch_prefix"], return_tensors="pt",
                                        padding=False).input_ids
        clean_answer = model.tokenizer(data["clean_answer"], return_tensors="pt",
                                        padding=False).input_ids
        patch_answer = model.tokenizer(data["patch_answer"], return_tensors="pt",
                                        padding=False).input_ids

        clean_prefix_firstsent = data["clean_prefix"].split(".")[0]
        clean_prefix_firstsent_tok = model.tokenizer(clean_prefix_firstsent, return_tensors="pt",
                                                     padding=False).input_ids
        
        # remove BOS tokens from answers
        clean_answer = clean_answer[clean_answer != model.tokenizer.bos_token_id].unsqueeze(0)
        patch_answer = patch_answer[patch_answer != model.tokenizer.bos_token_id].unsqueeze(0)
        # only keep examples where answers are single tokens
        if not ignore_patch:
            if clean_prefix.shape[1] != patch_prefix.shape[1]:
                continue
        # only keep examples where clean and patch answers are the same length
        if clean_answer.shape[1] != 1 or patch_answer.shape[1] != 1:
            continue
        # if we specify a `length`, filter examples if they don't match
        if length and clean_prefix_firstsent_tok.shape[1] != length:
            continue
        # if we specify `pad_to_length`, left-pad all inputs to a max length
        prefix_length_wo_pad = clean_prefix.shape[1]
        if pad_to_length:
            model.tokenizer.padding_side = 'right'
            pad_length = pad_to_length - prefix_length_wo_pad
            if pad_length < 0:  # example too long
                continue
            # left padding: reverse, right-pad, reverse
            clean_prefix = t.flip(F.pad(t.flip(clean_prefix, (1,)), (0, pad_length), value=model.tokenizer.pad_token_id), (1,))
            patch_prefix = t.flip(F.pad(t.flip(patch_prefix, (1,)), (0, pad_length), value=model.tokenizer.pad_token_id), (1,))
        
        example_dict = {"clean_prefix": clean_prefix,
                        "patch_prefix": patch_prefix,
                        "clean_answer": clean_answer.item(),
                        "patch_answer": patch_answer.item(),
                        # "annotations": get_annotation(dataset, model, data),
                        "prefix_length_wo_pad": prefix_length_wo_pad,}
        examples.append(example_dict)
        if len(examples) >= num_examples:
            break

    return examples

In [4]:
from loading_utils import load_examples

data_path = "data/NPS_gp_post_readingcomp_samelen.json"
# data_path = "data/MVRR_ambiguous_samelen.json"
ignore_patch = True
num_examples = 100
length = 9
pad_length = 32

examples = load_examples_prefix_len(data_path, num_examples, model, length=length, #pad_to_length=pad_length
                                     ignore_patch=False)
# examples = load_examples(data_path, num_examples, model, length=length, # pad_to_length=pad_length,
#                                      ignore_patch=True)
print(len(examples))

24


In [5]:
from activation_utils import SparseAct

tracer_kwargs = {'validate' : False, 'scan' : False}

def _pe_ig(
        clean,
        patch,
        model,
        submodules,
        dictionaries,
        metric_fn,
        use_inputs=None,
        steps=10,
        metric_kwargs=dict(),
):
    if use_inputs is None:
        for submodule in submodules:
            use_inputs[submodule] = False

    # first run through a test input to figure out which hidden states are tuples
    is_tuple = {}
    with model.trace("_"):
        for submodule in submodules:
            is_tuple[submodule] = type(submodule.output.shape) == tuple

    hidden_states_clean = {}
    with model.trace(clean, **tracer_kwargs), t.no_grad():
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            x = submodule.output if not use_inputs[submodule] else submodule.input
            if use_inputs[submodule]:
                x = x[0][0]
            elif is_tuple[submodule]:
                x = x[0]
            f = dictionary.encode(x)
            x_hat = dictionary.decode(f)
            residual = x - x_hat
            hidden_states_clean[submodule] = SparseAct(act=f.save(), res=residual.save())
        metric_clean = metric_fn(model, **metric_kwargs).save()
    hidden_states_clean = {k : v.value for k, v in hidden_states_clean.items()}

    if patch is None:
        hidden_states_patch = {
            k : SparseAct(act=t.zeros_like(v.act), res=t.zeros_like(v.res)) for k, v in hidden_states_clean.items()
        }
        total_effect = None
    else:
        hidden_states_patch = {}
        with model.trace(patch, **tracer_kwargs), t.no_grad():
            for submodule in submodules:
                dictionary = dictionaries[submodule]
                x = submodule.output if not use_inputs[submodule] else submodule.input
                if use_inputs[submodule]:
                    x = x[0][0]
                elif is_tuple[submodule]:
                    x = x[0]
                f = dictionary.encode(x)
                x_hat = dictionary.decode(f)
                residual = x - x_hat
                hidden_states_patch[submodule] = SparseAct(act=f.save(), res=residual.save())
            metric_patch = metric_fn(model, **metric_kwargs).save()
        total_effect = (metric_patch.value - metric_clean.value).detach()
        hidden_states_patch = {k : v.value for k, v in hidden_states_patch.items()}

    effects = {}
    deltas = {}
    grads = {}
    for submodule in submodules:
        dictionary = dictionaries[submodule]
        clean_state = hidden_states_clean[submodule]
        patch_state = hidden_states_patch[submodule]
        with model.trace(**tracer_kwargs) as tracer:
            metrics = []
            fs = []
            for step in range(steps):
                alpha = step / steps
                f = (1 - alpha) * clean_state + alpha * patch_state
                f.act.retain_grad()
                f.res.retain_grad()
                fs.append(f)
                with tracer.invoke(clean, scan=tracer_kwargs['scan']):
                    if use_inputs[submodule]:
                        submodule.input[0][0][:] = dictionary.decode(f.act) + f.res
                    elif is_tuple[submodule]:
                        submodule.output[0][:] = dictionary.decode(f.act) + f.res
                    else:
                        submodule.output = dictionary.decode(f.act) + f.res
                    metrics.append(metric_fn(model, **metric_kwargs))
            metric = sum([m for m in metrics])
            metric.sum().backward(retain_graph=True) # TODO : why is this necessary? Probably shouldn't be, contact jaden

        mean_grad = sum([f.act.grad for f in fs]) / steps
        mean_residual_grad = sum([f.res.grad for f in fs]) / steps
        grad = SparseAct(act=mean_grad, res=mean_residual_grad)
        delta = (patch_state - clean_state).detach() if patch_state is not None else -clean_state.detach()
        effect = grad @ delta

        effects[submodule] = effect
        deltas[submodule] = delta
        grads[submodule] = grad

    return (effects, deltas, grads, total_effect)

In [6]:
import math
import numpy as np
from tqdm import tqdm

batch_size = 1
num_examples = 100
device = "cuda"
num_examples = min([num_examples, len(examples)])
n_batches = math.ceil(len(examples) / batch_size)
batches = [
    examples[batch*batch_size:(batch+1)*batch_size] for batch in range(n_batches)
]

running_total = 0
nodes = None

for batch in tqdm(batches):
    clean_answer_idxs = t.tensor([e['clean_answer'] for e in batch], dtype=t.long, device=device)
    clean_inputs = t.cat([e['clean_prefix'] for e in batch], dim=0).to(device)

    patch_answer_idxs = t.tensor([e['patch_answer'] for e in batch], dtype=t.long, device=device)
    patch_inputs = t.cat([e['patch_prefix'] for e in batch], dim=0).to(device)
    def metric_fn(model):
        return (
            t.gather(model.lm_head.output[:,-1,:], dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) - \
            t.gather(model.lm_head.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
        )

    # for example in examples[:1]:
    effects, _, _, _ = _pe_ig(
        clean_inputs,
        # patch_inputs,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        use_inputs=use_inputs
    )
    with t.no_grad():
        if nodes is None:
            # nodes = {k : len(clean_inputs) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
            nodes = {k : len(clean_inputs) * v.mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                # nodes[k] += len(clean_inputs) * v.sum(dim=1).mean(dim=0)
                nodes[k] += len(clean_inputs) * v.mean(dim=0)
        running_total += len(clean_inputs)
    del effects, _
    gc.collect()

nodes = {k : v / running_total for k, v in nodes.items()}

for i, node in enumerate(nodes):
    t.save(nodes[node], open(f"node_effects/effects_NPS_readingcomp/node_{i}.pt", "wb"))

# print("negative effects")
# for idx, submodule in enumerate(resids):
#     print(f"resid_{idx}")
#     v, i = t.topk(sum_effects[submodule].flatten(), 10, largest=False)
#     print(np.array(np.unravel_index(i.cpu().numpy(), sum_effects[submodule].shape)).T)
#     print(v)
#     print()

  0%|          | 0/24 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


cuda:0


100%|██████████| 24/24 [04:27<00:00, 11.14s/it]


In [6]:
nodes = {}
for i, submodule in enumerate(submodules):
    nodes[submodule] = t.load(f"node_effects/effects_NPS_readingcomp/node_{i}.pt")

In [24]:
nodes[submodules[-1]].act.shape
idxs = (nodes[submodules[-1]].act[:length, :] > .05).nonzero()
idx = idxs[0]
nodes[submodules[-1]].act[idx[0], idx[1]]

tensor(0.0737, device='cuda:0', dtype=torch.bfloat16)

In [31]:
filter_bos = True
if filter_bos:
    start_idx = 1
else:
    start_idx = 0

print("positive and negative effects")
n_features = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    for idx in (effect.act[start_idx:length, :].abs() > 5.0).nonzero():
        print(idx, effect[idx[0], idx[1]].item())
        n_features += 1
print(f"total features: {n_features}")

positive and negative effects
Component 0:
Component 1:
Component 2:
Component 3:
Component 4:
Component 5:
Component 6:
Component 7:
Component 8:
tensor([    0, 15089], device='cuda:0') 38.5
Component 9:
Component 10:
Component 11:
tensor([ 0, 58], device='cuda:0') 7.21875
tensor([   0, 9134], device='cuda:0') -8.5
tensor([    0, 14238], device='cuda:0') 8.8125
Component 12:
Component 13:
Component 14:
Component 15:
Component 16:
Component 17:
tensor([   0, 1059], device='cuda:0') -6.5
tensor([   0, 8392], device='cuda:0') 12.4375
tensor([   1, 3235], device='cuda:0') -5.46875
Component 18:
Component 19:
Component 20:
tensor([   0, 4478], device='cuda:0') -6.09375
tensor([    0, 14077], device='cuda:0') -6.15625
Component 21:
Component 22:
Component 23:
tensor([  0, 110], device='cuda:0') -5.25
tensor([   0, 3928], device='cuda:0') -7.125
tensor([   0, 4105], device='cuda:0') 6.84375
tensor([   0, 4137], device='cuda:0') 5.34375
tensor([   0, 4411], device='cuda:0') 5.21875
tensor([  

In [29]:
# interpret features with Neuronpedia API

def _submodule_idx_to_name(submodule_idx):
    if submodule_idx % 3 == 0:
        layer = submodule_idx // 3
        return f"attn_{layer}"
    elif submodule_idx % 3 == 1:
        layer = submodule_idx // 3
        return f"mlp_{layer}"
    elif submodule_idx % 3 == 2:
        layer = submodule_idx // 3
        return f"resid_{layer}"
    return f"{submodule_name}_{layer}"

submodule_idx = 0
feature_idx = 1608
submodule_name = _submodule_idx_to_name(submodule_idx)

def _format_activations(activations, top_n=10):
    formatted_outputs = []

    # Process the top N activations
    for activation in activations[:top_n]:
        tokens = activation["tokens"]
        values = activation["values"]

        # Find the index of the maximum activation value
        max_value_index = values.index(max(values))

        # Determine the range of tokens to include (10 before and 10 after the max token)
        start_index = max(0, max_value_index - 10)
        end_index = min(len(tokens), max_value_index + 11)  # +11 because the range is inclusive of the max token

        # Slice the tokens and values accordingly
        tokens_slice = tokens[start_index:end_index]

        # Create the formatted string
        formatted_string = ""
        for i, token in enumerate(tokens_slice):
            # Replace special characters ▁ with spaces and underscores with spaces
            token = token.replace("▁", " ").replace("_", " ")

            if start_index + i == max_value_index:
                formatted_string += f"<<{token}>>"
            else:
                formatted_string += token

        formatted_outputs.append(formatted_string)

    return formatted_outputs

from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release="gemma-2-2b", sae_id=None, feature_idx=feature_idx):
    return html_template.format(sae_release, sae_id, feature_idx)

# Extract the type and number from submodule_name
submodule_type, submodule_number = submodule_name.split('_')

# Construct the sae_id based on the type
if submodule_type == "resid":
    sae_id = f"{submodule_number}-gemmascope-res-16k"
elif submodule_type == "attn":
    sae_id = f"{submodule_number}-gemmascope-att-16k"
elif submodule_type == "mlp":
    sae_id = f"{submodule_number}-gemmascope-mlp-16k"
else:
    raise ValueError("Unknown submodule type")

html = get_dashboard_html(sae_release="gemma-2-2b", sae_id=sae_id, feature_idx=feature_idx)
IFrame(html, width=800, height=400)

In [None]:
# load circuits, analyze
circuit_path = "circuits/NPS_ambiguous_samelen_dict10_node0.5_edge0.05_n24_aggnone_gpt2.pt"
circuit = t.load(open(circuit_path, 'rb'))

for submod in circuit["nodes"]:
    effects = circuit["nodes"][submod]
    top_effects = t.topk(effects.act, 10)
    print(submod, top_effects)

In [23]:
with model.trace("testing 1"):
    out_save = resids[4].output
    f = dictionaries[resids[4]].encode(out_save).save()
print(f.value.shape)

torch.Size([1, 2, 24576])
