In [1]:
from dictionary_learning import AutoEncoder, ActivationBuffer, GatedAutoEncoder
from nnsight import LanguageModel
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import zst_to_generator
import torch as t
import gc

In [125]:
# Load pythia-70m, pick the submodule to analyse, build an activation buffer
# over the Pile shard, and load the matching pretrained dictionary.
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map='cuda:0', dispatch=True)
layer = 0
component = 'resid'

if component == 'resid':
    submodule = model.gpt_neox.layers[layer]
elif component == 'attn':
    submodule = model.gpt_neox.layers[layer].attention
elif component == 'mlp':
    submodule = model.gpt_neox.layers[layer].mlp
elif component == 'embed':
    submodule = model.gpt_neox.embed_in
else:
    # Fail loudly instead of silently reusing a stale `submodule` left in the
    # kernel by a previous run of this cell.
    raise ValueError(
        f"Unknown component {component!r}; expected 'resid', 'attn', 'mlp' or 'embed'"
    )

activation_dim=512

buffer = ActivationBuffer(
    zst_to_generator('/share/data/datasets/pile/the-eye.eu/public/AI/pile/train/00.jsonl.zst'),
    model,
    submodule,
    io='out',
    in_feats=activation_dim,
    out_feats=activation_dim,
    in_batch_size=128,
    out_batch_size=2 ** 13,
    n_ctxs=1e4,
    device='cuda:0',
)

# Dictionary size is 64x the activation width (64 * 512 = 32768, matching the
# "10_32768" checkpoint directory name).
ae = AutoEncoder(activation_dim, 64 * activation_dim).cuda()

# The embedding dictionary lives in a layer-less directory; every other
# component uses "<component>_out_layer<layer>".
ae_root = '/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped'
ae_dir = 'embed' if component == 'embed' else f'{component}_out_layer{layer}'
ae.load_state_dict(t.load(f'{ae_root}/{ae_dir}/10_32768/ae.pt'))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [127]:
# Drop the previous feature profile (if this cell was run before) so its
# memory can be reclaimed before a new one is built.
try:
    del fp
except NameError:
    # First run in this kernel: nothing to free. Only NameError is expected
    # here; anything else (e.g. KeyboardInterrupt) should propagate.
    pass
else:
    gc.collect()

feature = 1565

fp = examine_dimension(
    model,
    submodule,
    buffer,
    ae,
    dim_idx = feature,
    n_inputs=256
)

print(fp.top_tokens)
print(fp.top_affected)
fp.top_contexts

LanguageModelProxy (argument_2): FakeTensor(..., device='cuda:0', size=(256, 128, 512),
           grad_fn=<EmbeddingBackward0>)
[(' driver', 0.7144788503646851), (' hood', 0.0361337810754776), ('DW', 0.030997931957244873), ('rette', 0.029590070247650146), (' dwarf', 0.013755664229393005), ('Hol', 0.0), ('iday', 0.0), (' P', 0.0), ('unch', 0.0), (' —', 0.0), (' Plus', 0.0), (' a', 0.0), (' Co', 0.0), ('zy', 0.0), (' Fire', 0.0), ('\n', 0.0), ('Charles', 0.0), (' Dick', 0.0), ('ens', 0.0), (' gave', 0.0), (' us', 0.0), (' so', 0.0), (' much', 0.0), ('.', 0.0), (' In', 0.0), ('cluding', 0.0), (' this', 0.0), ('In', 0.0), (' A', 0.0), (' Christmas', 0.0)]
[('lessly', 0.26721781492233276), ('lessness', 0.25997886061668396), (',[@', 0.24340176582336426), ('.[@', 0.23646807670593262), (',^[@', 0.21933341026306152), ('less', 0.20868515968322754), ('doms', 0.20835530757904053), (' since', 0.2060118019580841), ('.^[@', 0.20522615313529968), (';', 0.20469340682029724), (' of', 0.1955747753381729

In [79]:
import torch as t

# Print every circuit node whose absolute effect exceeds the node threshold,
# one "<component>\t<layer>\t<feature index>" row per node.
threshold = 0.2
e_threshold = threshold / 10.

c = t.load(f'circuits/within_rc_train_dict10_node{threshold}_edge{e_threshold}_n100_aggnone.pt')
nodes = c['nodes']

for component, x in nodes.items():
    # The embedding node has no layer suffix; all other keys are "<comp>_<layer>".
    comp, layer = ('embed', "") if component == 'embed' else component.split('_')

    # Each row of nonzero() is one index tuple; its last entry is the feature index.
    for idx in t.nonzero(x.act.abs() > threshold):
        print(f"{comp}\t{layer}\t{idx[-1]}")

embed		293
embed		11649
embed		17003
embed		27441
mlp	0	1901
mlp	0	10355
mlp	0	17834
mlp	0	18502
resid	0	2494
resid	0	10630
resid	0	23760
resid	1	4678
resid	1	15769
resid	1	17151
resid	1	18001
resid	1	32616
resid	2	10995
resid	2	14779
resid	3	8437
resid	3	11981
resid	3	20009
mlp	4	4523
mlp	4	15560
resid	4	8913
resid	4	11586
resid	4	14719
resid	4	16089
attn	5	25516
resid	5	295
resid	5	9340
resid	5	11839
resid	5	18629


In [98]:
def experiment(text, feature):
    """Print the top-5 next-token probabilities for `text` before and after
    zero-ablating one dictionary feature.

    Args:
        text: prompt passed to `model.trace`.
        feature: feature identifier of the form "<component>_<layer>/<index>"
            (component in 'resid', 'attn', 'mlp'), or "embed/<index>" for the
            embedding dictionary. "embed_<layer>/<index>" is also accepted; the
            layer is ignored in that case.

    Raises:
        ValueError: if `feature` is malformed or names an unknown component.

    Relies on the notebook-level `model`, `AutoEncoder` and `t` (torch).
    """
    print(text)
    component, feat_idx = feature.split('/')
    feat_idx = int(feat_idx)
    component, _, layer = component.partition('_')

    if component == 'embed':
        # The embedding dictionary has no layer index, so "embed/<idx>" must
        # parse without requiring an underscore-separated layer.
        submodule = model.gpt_neox.embed_in
        ae_dir = 'embed'
    else:
        if not layer:
            raise ValueError(f"Feature {feature!r} is missing a layer: expected '<component>_<layer>/<index>'")
        layer = int(layer)
        block = model.gpt_neox.layers[layer]
        if component == 'resid':
            submodule = block
        elif component == 'attn':
            submodule = block.attention
        elif component == 'mlp':
            submodule = block.mlp
        else:
            raise ValueError(f"Unknown component {component!r} in feature {feature!r}")
        ae_dir = f'{component}_out_layer{layer}'

    ae = AutoEncoder(512, 32768).to('cuda:0')
    ae.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/{ae_dir}/10_32768/ae.pt'))

    def _print_top5(probs):
        # Report the five most likely next tokens with their probabilities.
        for prob, idx in zip(*t.topk(probs, 5)):
            print(f'    {model.tokenizer.decode(idx)}: {prob.item()}')

    print('Original:')
    with model.trace(text):
        probs = t.softmax(model.output.logits[0, -1, :], dim=-1).save()
    _print_top5(probs)

    print(f'Ablating {feature}:')
    with model.trace(text):
        # NOTE(review): `output[0]` presumably selects the hidden state from a
        # tuple-valued output; for submodules that return a bare tensor it would
        # select the first batch element instead -- confirm for 'mlp'/'embed'.
        x = submodule.output[0]
        x_hat, f = ae(x, output_features=True)
        f[..., feat_idx] = 0
        # Re-add the reconstruction error (x - x_hat) so only the chosen
        # feature's contribution is removed.
        submodule.output[0][:] = ae.decode(f) + (x - x_hat)
        probs = t.softmax(model.output.logits[0, -1, :], dim=-1).save()
    _print_top5(probs)

# NOTE(review): the saved output of this cell starts with "Sing do, re, mi,",
# which is what `experiment` prints for its `text` argument, so the output was
# produced with a non-empty prompt. Re-running with "" will not reproduce it.
experiment("", 'attn_3/14579')

Sing do, re, mi,
Original:
     mi: 0.08595732599496841
     m: 0.016683943569660187
     no: 0.01310817152261734
     o: 0.012504851445555687
     s: 0.011511574499309063
Ablating attn_3/14579:
     mi: 0.085909903049469
     m: 0.01670733653008938
     no: 0.0130849564447999
     o: 0.012484229169785976
     s: 0.011529122479259968


In [128]:
from collections import defaultdict

# sums = [["embed/2886", "embed/9394", "embed/9491", "embed/13979",
#          "embed/24081", "embed/16130", "embed/23101", "mlp_0/12107",
#          "mlp_0/12228", "mlp_0/4995", "mlp_0/28747", "mlp_0/28953",
#          "mlp_0/12282", "mlp_0/16595", "mlp_0/23395", "resid_0/8729",
#          "resid_0/11657", "resid_0/30086", "resid_0/15527", "resid_0/12983",
#          "resid_0/4893"],
#         ["resid_1/18376", "resid_1/7542", "resid_2/25123", "resid_1/20098", "resid_1/19708",
#          "resid_2/26688", "resid_2/31615", "resid_2/21572", "resid_2/6749"], ["resid_1/ε", "attn_2/ε", "resid_2/ε"],
#         ["attn_3/14579"],
#         ["attn_3/26353", "resid_3/11408", "resid_3/5357", "resid_3/25644", 
#          "resid_3/27675", "resid_3/26633", "resid_4/26222", "resid_4/23831",
#          "resid_4/19577", "resid_4/4340", "resid_4/8908", "resid_4/12046",
#          "resid_5/899", "resid_5/26851", "resid_5/27843", "resid_5/21625"], ["resid_3/ε", "resid_5/ε"]
# ]

# sums = [["embed/10859", "embed/13807", "embed/16937", "embed/31875", "embed/8233",
#          "mlp_0/10665", "mlp_0/29209", "mlp_0/15008", "mlp_0/20713", "mlp_0/26732"], ["embed/ε"],
#         ["resid_0/12292", "resid_0/30147", "resid_1/32358", "resid_2/22653",
#          "resid_2/20421", "resid_3/10364", "resid_3/31728", "resid_3/11794",
#          "resid_4/25089", "resid_4/6798", "resid_4/1620", "resid_4/10570"], ["resid_1/ε", "resid_2/ε", "resid_3/ε", "resid_4/ε"],
#         ["resid_5/7552", "resid_5/308"]
#         ]
# sums = [["embed/293", "embed/15970", "embed/22499", "embed/26494", "embed/29610" "embed/31565",
#          "mlp_0/5536", "mlp_0/16150", "mlp_0/17834", "mlp_0/19216", "mlp_0/24139", "mlp_0/30990",
#          "resid_0/1807", "resid_0/2494", "resid_0/10590", "resid_0/15075", "resid_0/21085", "resid_0/23953",
#          "mlp_1/22854", "resid_1/7749", "resid_1/10524", "resid_1/12037", "resid_1/15769", "resid_1/18001",
#          "resid_2/10995", "resid_2/14779", "resid_3/8437", "resid_3/11981", "mlp_4/15560", "resid_4/8913",
#          "resid_4/11586"], [""]
#         [], []
#         [], []]
# Feature clusters whose effects are summed below. Each entry is
# "<pos>, <submodule>/<feature>", where <feature> is either an integer feature
# index (looked up in the node's `.act`) or "ε" (looked up in the node's `.resc`).
sums = [['1, embed/293', '1, embed/1319', '1, embed/4285', '1, embed/10189', '1, embed/10941', '1, embed/15970', '1, embed/19485', '1, embed/22499', '1, embed/26494', '1, embed/28282', '1, embed/29610', '1, embed/31565', '1, mlp_0/5536', '1, mlp_0/16150', '1, mlp_0/16934', '1, mlp_0/17834', '1, mlp_0/19216', '1, mlp_0/23899', '1, mlp_0/24139', '1, mlp_0/28637', '1, mlp_0/30990', '1, resid_0/1807', '1, resid_0/2494', '1, resid_0/9765', '1, resid_0/10590', '1, resid_0/15075', '1, resid_0/21085', '1, resid_0/23953', '1, resid_0/24693', '1, resid_0/25899', '1, mlp_1/10410', '1, mlp_1/22854', '1, resid_1/7749', '1, resid_1/10524', '1, resid_1/12037', '1, resid_1/15769', '1, resid_1/18001', '1, resid_2/10995', '1, resid_2/14779', '1, resid_3/8437', '1, resid_3/11981', '1, resid_3/20009', '1, mlp_4/15560', '1, resid_4/8913', '1, resid_4/11586', '1, resid_4/16089', '1, resid_4/22614', '1, resid_4/26914'], ['1, embed/ε', '1, mlp_0/ε', '1, mlp_1/ε', '1, resid_1/ε', '1, attn_2/ε', '1, resid_2/ε', '1, resid_3/ε', '1, mlp_4/ε'], ['2, attn_2/32044', '2, resid_2/30677', '2, mlp_3/6782', '2, mlp_3/17671', '2, resid_3/18529'], ['2, attn_2/ε', '2, resid_2/ε', '2, mlp_3/ε', '2, resid_3/ε'], ['5, attn_4/3982', '5, attn_4/31148', '5, resid_4/14719', '5, attn_5/974', '5, attn_5/7447', '5, attn_5/9376', '5, attn_5/19313', '5, attn_5/25516', '5, resid_5/295', '5, resid_5/3036', '5, resid_5/4264', '5, resid_5/7468', '5, resid_5/8919', '5, resid_5/19066', '5, resid_5/28507', '5, resid_5/30107'], ['5, attn_4/ε', '5, resid_4/ε', '5, attn_5/ε', '5, mlp_5/ε', '5, resid_5/ε']]


# Load the saved circuit and total the node effects of every cluster in `sums`.
circuit_path = "circuits/rc_train_dict10_node0.1_edge0.01_n100_aggnone.pt"
with open(circuit_path, "rb") as circuit_data:
    circuit = t.load(circuit_data)

nodes = circuit["nodes"]

# Starting bounds for the colour scale, taken over every node except 'y'.
# NOTE(review): the lower bound uses each node's .min() while the upper bound
# uses each node's .sum() -- confirm the asymmetry is intentional.
min_effect = min(v.to_tensor().min() for n, v in nodes.items() if n != 'y')
max_effect = max(v.to_tensor().sum() for n, v in nodes.items() if n != 'y')

# cluster index -> summed effect of all its entries
sum_effects = defaultdict(int)
for idx, cluster in enumerate(sums):
    for feature in cluster:
        # Entry format: "<pos>, <submodule>/<feature index or ε>"
        pos_submod, feat_idx = feature.split("/")
        pos, submod = pos_submod.split(", ")
        pos = int(pos)
        node = nodes[submod]
        effect = node.resc[pos] if feat_idx == "ε" else node.act[pos, int(feat_idx)]
        sum_effects[idx] += effect

# Widen the bounds so that every cluster total falls inside them.
for cluster in sum_effects:
    max_effect = max(sum_effects[cluster], max_effect)
    min_effect = min(sum_effects[cluster], min_effect)

# Symmetric normaliser: the largest magnitude seen on either side of zero.
scale = max(abs(min_effect), abs(max_effect))

def to_hex(number, max_abs=None):
    """Map a signed effect to a (background, text) pair of hex colour strings.

    The value is normalised by `max_abs` and clamped to [-1, 1]:
      * negative values fade from white towards pure red (at -1),
      * positive values fade from white towards pure blue (at +1),
      * zero gives white.

    Args:
        number: the effect to colour (a Python number or a single-element tensor).
        max_abs: normaliser. Defaults to the notebook-level `scale`.
            A normaliser of 0 is treated as "no signal" and yields white.

    Returns:
        (hex_code, text_hex): background colour as '#RRGGBB' and a text colour
        ('#000000' or '#ffffff') chosen for contrast against it.
    """
    if max_abs is None:
        max_abs = scale

    if max_abs == 0:
        # Avoid dividing by zero when every effect is exactly zero.
        number = 0
    else:
        number = number / max_abs
        # Clamp so out-of-range values cannot yield negative colour channels,
        # which would format as invalid hex such as '#-7F-7FFF'.
        if number > 1:
            number = 1
        elif number < -1:
            number = -1

    if number < 0:
        # Fade towards red, full intensity at -1.0
        red = 255
        green = blue = int((1 + number) * 255)
    elif number > 0:
        # Fade towards blue, full intensity at 1.0
        blue = 255
        red = green = int((1 - number) * 255)
    else:
        # Exact 0, resulting in white
        red = green = blue = 255

    # Weighted brightness of the background decides between black and white text.
    text_hex = "#000000" if (red*0.299 + green*0.587 + blue*0.114) > 170 else "#ffffff"

    # Convert to hex, ensuring each component is 2 digits
    hex_code = f'#{red:02X}{green:02X}{blue:02X}'

    return hex_code, text_hex

# For each cluster, print its size followed by its (background, text) colours.
for cluster in sum_effects:
    # Named `colors` rather than `hex` so the builtin hex() is not shadowed
    # for the rest of the kernel session.
    colors = to_hex(sum_effects[cluster])
    print(len(sums[cluster]))
    print(colors)

48
('#0000FF', '#ffffff')
8
('#D3D3FF', '#000000')
5
('#EAEAFF', '#000000')
4
('#DCDCFF', '#000000')
16
('#A8A8FF', '#000000')
5
('#DBDBFF', '#000000')


In [131]:
from nnsight.models.UnifiedTransformer import UnifiedTransformer
from sae_lens import SparseAutoencoder
from tokenizers.processors import TemplateProcessing

# Load GPT-2 small, collect its residual-stream hook points, build an
# activation buffer over one of them, and load one dictionary per layer.
model = UnifiedTransformer("gpt2-small", device="cuda")
resids = [block.hook_resid_pre for block in model.blocks]
component = 'resid'
# NOTE(review): `layer` and `submodule` used to be inherited from the pythia
# cells, so the buffer below was hooking a module of a different model.
# Layer 0 is assumed here -- confirm the intended layer.
layer = 0

if component == 'resid':
    submodule = resids[layer]
else:
    raise ValueError(f"Unsupported component {component!r}: only 'resid' hook points are collected for this model")

activation_dim=768

buffer = ActivationBuffer(
    zst_to_generator('/share/data/datasets/pile/the-eye.eu/public/AI/pile/train/00.jsonl.zst'),
    model,
    submodule,
    io='out',
    in_feats=activation_dim,
    out_feats=activation_dim,
    in_batch_size=128,
    out_batch_size=2 ** 13,
    n_ctxs=1e4,
    device='cuda:0',
)

# the GPT-2 SAEs expect a BOS token at start of sequence. nnsight doesn't do this,
# so we need to tell the tokenizer to always do this
# model.tokenizer._tokenizer.post_processor = TemplateProcessing(
#     single=model.tokenizer.bos_token + " $A",
#     special_tokens=[(model.tokenizer.bos_token, model.tokenizer.bos_token_id)]
# )

# Map each residual hook point to its pretrained dictionary (32x expansion).
# After the loop, `ae` refers to the last layer's dictionary.
dictionaries = {}
for i in range(len(model.blocks)):
    ae = AutoEncoder(768, 32 * 768).to("cuda")
    ae.load_state_dict(t.load(f'GPT2-SAEs/gpt2_layer{i}_resid_pre_sfc.pt'))
    dictionaries[resids[i]] = ae

17

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import zst_to_generator
import torch as t
import gc
from dictionary_learning.dictionary import GatedAutoEncoder
from dictionary_learning.buffer import ActivationBuffer
from sae_lens import SparseAutoencoder
from tokenizers.processors import TemplateProcessing

# Load Llama-3-8B in half precision together with its tokenizer.
model_name = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=t.float16,
    device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The decoder layer under study, both as a module and as its dotted name.
layer = 16
submodule = model.model.layers[layer]
submodule_name = f"model.layers.{layer}"

# Gated dictionary for that layer, cast to half precision before loading weights.
ae = GatedAutoEncoder(4096, 32768).half().to("cuda")
ae.load_state_dict(t.load(f'llama_saes/layer{layer}/ae_81920.pt'))
dictionaries = {submodule: ae}

from datasets import load_dataset
def download_dataset(dataset_name, tokenizer, max_length=256, num_datapoints=None):
    """Load the train split of `dataset_name`, tokenize it, and cut every kept
    example to exactly `max_length` tokens.

    Args:
        dataset_name: name passed to `load_dataset`.
        tokenizer: callable applied to batches of the 'text' column.
        max_length: examples with at most this many tokens are dropped; longer
            ones have their 'input_ids' truncated to this length.
        num_datapoints: if truthy, only the first `num_datapoints` rows of the
            train split are loaded; otherwise the whole split.

    Returns:
        The tokenized, filtered and truncated dataset.
    """
    split_text = f"train[:{num_datapoints}]" if num_datapoints else "train"

    def tokenize(batch):
        return tokenizer(batch['text'])

    def longer_than_max(example):
        return len(example['input_ids']) > max_length

    def truncate(example):
        return {'input_ids': example['input_ids'][:max_length]}

    tokenized = load_dataset(dataset_name, split=split_text).map(tokenize, batched=True)
    dataset = tokenized.filter(longer_than_max).map(truncate)
    return dataset

# Fetch the first 7000 rows of the corpus and keep 40-token contexts.
max_seq_length = 40
dataset_name = "stas/openwebtext-10k"
print(f"Downloading {dataset_name}")
dataset = download_dataset(
    dataset_name,
    tokenizer=tokenizer,
    max_length=max_seq_length,
    num_datapoints=7000,
)

# if component == 'resid':
#     submodule = resids[layer]



# the GPT-2 SAEs expect a BOS token at start of sequence. nnsight doesn't do this,
# so we need to tell the tokenizer to always do this
# model.tokenizer._tokenizer.post_processor = TemplateProcessing(
#     single=model.tokenizer.bos_token + " $A",
#     special_tokens=[(model.tokenizer.bos_token, model.tokenizer.bos_token_id)]
# )


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Downloading stas/openwebtext-10k


In [2]:
import torch
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm
from baukit import Trace

def get_dictionary_activations(model, dataset, cache_name, max_seq_length, autoencoder, batch_size=32):
    """Run `dataset` through `model` and record the dictionary feature
    activations at the module named `cache_name` for every token.

    Args:
        model: model to run; batches are moved to `model.device`.
        dataset: dataset exposing `num_rows`, `formatted_as` and an
            'input_ids' column.
        cache_name: name of the module whose output is captured with `Trace`.
        max_seq_length: tokens per example, used to size the output buffers.
        autoencoder: dictionary providing `encoder.weight` and `encode`.
        batch_size: examples per forward pass.

    Returns:
        (dictionary_activations, token_list): a
        (num_rows * max_seq_length, num_features) tensor of feature
        activations and the matching flat int64 tensor of token ids.
    """
    num_features, _ = autoencoder.encoder.weight.shape
    total_tokens = dataset.num_rows * max_seq_length
    tokens_per_batch = batch_size * max_seq_length

    dictionary_activations = torch.zeros((total_tokens, num_features))
    token_list = torch.zeros(total_tokens, dtype=torch.int64)

    with torch.no_grad(), dataset.formatted_as("pt"):
        loader = DataLoader(dataset["input_ids"], batch_size=batch_size)
        for i, batch in enumerate(tqdm(loader)):
            start = i * tokens_per_batch
            stop = start + tokens_per_batch

            batch = batch.to(model.device)
            token_list[start:stop] = rearrange(batch, "b s -> (b s)")

            with Trace(model, cache_name) as ret:
                _ = model(batch).logits
                internal_activations = ret.output
                # Some modules return a tuple; the first entry is used then.
                if isinstance(internal_activations, tuple):
                    internal_activations = internal_activations[0]

            # Flatten (batch, seq) into one token axis before encoding.
            flat_activations = rearrange(internal_activations, "b s n -> (b s) n")
            encoded = autoencoder.encode(flat_activations)
            dictionary_activations[start:stop, :] = encoded.cpu()
    return dictionary_activations, token_list

# Collect dictionary activations for every token of the dataset.
batch_size = 128
dictionary_activations, tokens_for_each_datapoint = get_dictionary_activations(
    model=model,
    dataset=dataset,
    cache_name=submodule_name,
    max_seq_length=max_seq_length,
    autoencoder=ae,
    batch_size=batch_size,
)

100%|██████████| 55/55 [00:53<00:00,  1.04it/s]


In [39]:
# CUSTOM TEXT
from dictionary_learning.interp_utils import *

# text = [" If you know that you shouldn'"]
# text = [" What we know about the ownership of Barack Obama'"]
# text = [" You shouldn't done that! Now you'"]
# text = [" You know that I'"]
# text = ["Die Hunde sind hier, und der Hund ist da"]
# text = ["Mies tuo ja miehet tuovat"]
text = ["The woman who was brought the mail disappeared mysteriously after reading the bad news in it."]
# text = ["Rumah di sini besar-besar"]
# text = ["Lelaki-lelaki itu juga diekspor ke luar negeri"]

# Tokenize the prompt into a (1, seq) tensor.
tokens = torch.tensor(tokenizer.encode(text[0])).unsqueeze(0)

# Dictionary activations for the prompt, then render feature 25569 per token.
dict_act = get_autoencoder_activation(model, submodule_name, tokens, ae)
html = tokens_and_activations_to_html(tokens, dict_act[:, 25569], tokenizer)
display(HTML(html))

In [19]:
from dictionary_learning.interp_utils import *

# For each selected feature: show its strongest-activating contexts with the
# logit change caused by ablating it, and optionally a per-token context ablation.
num_feature_datapoints = 10
features = [26067]
# features = [7251]
ablate_context = False
# ablate_context = True
for feature in features:
    nz_ind_amount = dictionary_activations[:, feature].count_nonzero()
    print(f"feature: {feature}, non-zero activations: {nz_ind_amount}")
    if nz_ind_amount == 0:
        # A feature that never fires has nothing to display.
        continue

    # Despite the variable name, the indices are selected with setting="max";
    # switch to setting="uniform" to sample uniformly instead.
    uniform_indices = get_feature_indices(
        feature, dictionary_activations, k=num_feature_datapoints, setting="max"
    )
    (
        text_list,
        full_text,
        token_list,
        full_token_list,
        partial_activations,
        full_activations,
    ) = get_feature_datapoints(
        uniform_indices, dictionary_activations[:, feature], tokenizer, max_seq_length, dataset
    )
    # Set logit_diffs = None to skip the ablation overlay.
    logit_diffs = ablate_feature_direction(
        model,
        full_token_list,
        submodule_name,
        max_seq_length,
        ae,
        feature = feature,
        batch_size=32,
        setting="sentences",
        model_type="causal",
    )

    html = tokens_and_activations_to_html(full_token_list, full_activations, tokenizer, logit_diffs=logit_diffs)
    print(f"feature: {feature}")
    display(HTML(html))

    if not ablate_context:
        continue

    all_changed_activations = ablate_context_one_token_at_a_time(
        model, token_list, submodule_name, ae, feature, max_ablation_length=10
    )
    html = tokens_and_activations_to_html(token_list, all_changed_activations, tokenizer)
    print("Context_ablation\n=================================================================")
    display(HTML(html))

feature: 17662, non-zero activations: 25685
feature: 17662


feature: 23756, non-zero activations: 14894
feature: 23756
