In [4]:
%load_ext autoreload
%autoreload 2

from dictionary_learning import AutoEncoder, ActivationBuffer, GatedAutoEncoder
from nnsight import LanguageModel
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import zst_to_generator
import torch as t
import gc

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
from nnsight.models.UnifiedTransformer import UnifiedTransformer
from sae_lens import SparseAutoencoder
from tokenizers.processors import TemplateProcessing

model = LanguageModel("meta-llama/Meta-Llama-3-8B", torch_dtype=t.float16,
                      device_map="cuda")
resids = [layer for layer in model.model.layers]
component = 'resid'

# if component == 'resid':
#     submodule = resids[layer]

activation_dim=4096



# the GPT-2 SAEs expect a BOS token at start of sequence. nnsight doesn't do this,
# so we need to tell the tokenizer to always do this
# model.tokenizer._tokenizer.post_processor = TemplateProcessing(
#     single=model.tokenizer.bos_token + " $A",
#     special_tokens=[(model.tokenizer.bos_token, model.tokenizer.bos_token_id)]
# )

dictionaries = {}
for i in (16,):
    ae = GatedAutoEncoder(4096, 32768).to("cuda")
    ae.load_state_dict(t.load(f'llama_saes/layer{i}/ae_81920.pt'))
    ae = ae.half()
    dictionaries[resids[i]] = ae
    break
    # obj = t.load(f'llama_saes/layer{i}/ae_81920.pt')
    # print(obj)
resids = list(dictionaries.keys())

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [31]:
import random
import json

def load_examples_prefix_len(dataset, num_examples, model, seed=12, pad_to_length=None, length=None,
                  ignore_patch=False):
    examples = []
    dataset_items = open(dataset).readlines()
    random.seed(seed)
    random.shuffle(dataset_items)
    for line in dataset_items:
        data = json.loads(line)
        clean_prefix = model.tokenizer(data["clean_prefix"], return_tensors="pt",
                                        padding=False).input_ids
        patch_prefix = model.tokenizer(data["patch_prefix"], return_tensors="pt",
                                        padding=False).input_ids
        clean_answer = model.tokenizer(data["clean_answer"], return_tensors="pt",
                                        padding=False).input_ids
        patch_answer = model.tokenizer(data["patch_answer"], return_tensors="pt",
                                        padding=False).input_ids

        clean_prefix_firstsent = data["clean_prefix"].split(".")[0]
        clean_prefix_firstsent_tok = model.tokenizer(clean_prefix_firstsent, return_tensors="pt",
                                                     padding=False).input_ids
        
        # remove BOS tokens from answers
        clean_answer = clean_answer[clean_answer != model.tokenizer.bos_token_id].unsqueeze(0)
        patch_answer = patch_answer[patch_answer != model.tokenizer.bos_token_id].unsqueeze(0)
        # only keep examples where answers are single tokens
        if not ignore_patch:
            if clean_prefix.shape[1] != patch_prefix.shape[1]:
                continue
        # only keep examples where clean and patch answers are the same length
        if clean_answer.shape[1] != 1 or patch_answer.shape[1] != 1:
            continue
        # if we specify a `length`, filter examples if they don't match
        if length and clean_prefix_firstsent_tok.shape[1] != length:
            continue
        # if we specify `pad_to_length`, left-pad all inputs to a max length
        prefix_length_wo_pad = clean_prefix.shape[1]
        if pad_to_length:
            model.tokenizer.padding_side = 'right'
            pad_length = pad_to_length - prefix_length_wo_pad
            if pad_length < 0:  # example too long
                continue
            # left padding: reverse, right-pad, reverse
            clean_prefix = t.flip(F.pad(t.flip(clean_prefix, (1,)), (0, pad_length), value=model.tokenizer.pad_token_id), (1,))
            patch_prefix = t.flip(F.pad(t.flip(patch_prefix, (1,)), (0, pad_length), value=model.tokenizer.pad_token_id), (1,))
        
        example_dict = {"clean_prefix": clean_prefix,
                        "patch_prefix": patch_prefix,
                        "clean_answer": clean_answer.item(),
                        "patch_answer": patch_answer.item(),
                        # "annotations": get_annotation(dataset, model, data),
                        "prefix_length_wo_pad": prefix_length_wo_pad,}
        examples.append(example_dict)
        if len(examples) >= num_examples:
            break

    return examples

In [35]:
from loading_utils import load_examples

data_path = "data/NPZ_gp_post_readingcomp_samelen.json"
# data_path = "data/NPZ_ambiguous_samelen.json"
ignore_patch = True
num_examples = 100
length = 11
pad_length = 32

examples = load_examples_prefix_len(data_path, num_examples, model, length=length, #pad_to_length=pad_length
                                     ignore_patch=False)
# examples = load_examples(data_path, num_examples, model, length=length, # pad_to_length=pad_length,
#                                      ignore_patch=True)
print(len(examples))

24


In [26]:
from activation_utils import SparseAct

tracer_kwargs = {'validate' : False, 'scan' : False}

def _pe_ig(
        clean,
        patch,
        model,
        submodules,
        dictionaries,
        metric_fn,
        steps=10,
        metric_kwargs=dict(),
):
    
    # first run through a test input to figure out which hidden states are tuples
    is_tuple = {}
    with model.trace("_"):
        for submodule in submodules:
            is_tuple[submodule] = type(submodule.output.shape) == tuple

    hidden_states_clean = {}
    with model.trace(clean, **tracer_kwargs), t.no_grad():
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            x = submodule.output
            if is_tuple[submodule]:
                x = x[0]
            f = dictionary.encode(x)
            x_hat = dictionary.decode(f)
            residual = x - x_hat
            hidden_states_clean[submodule] = SparseAct(act=f.save(), res=residual.save())
        metric_clean = metric_fn(model, **metric_kwargs).save()
    hidden_states_clean = {k : v.value for k, v in hidden_states_clean.items()}

    if patch is None:
        hidden_states_patch = {
            k : SparseAct(act=t.zeros_like(v.act), res=t.zeros_like(v.res)) for k, v in hidden_states_clean.items()
        }
        total_effect = None
    else:
        hidden_states_patch = {}
        with model.trace(patch, **tracer_kwargs), t.no_grad():
            for submodule in submodules:
                dictionary = dictionaries[submodule]
                x = submodule.output
                if is_tuple[submodule]:
                    x = x[0]
                f = dictionary.encode(x)
                x_hat = dictionary.decode(f)
                residual = x - x_hat
                hidden_states_patch[submodule] = SparseAct(act=f.save(), res=residual.save())
            metric_patch = metric_fn(model, **metric_kwargs).save()
        total_effect = (metric_patch.value - metric_clean.value).detach()
        hidden_states_patch = {k : v.value for k, v in hidden_states_patch.items()}

    effects = {}
    deltas = {}
    grads = {}
    for submodule in submodules:
        dictionary = dictionaries[submodule]
        clean_state = hidden_states_clean[submodule]
        patch_state = hidden_states_patch[submodule]
        with model.trace(**tracer_kwargs) as tracer:
            metrics = []
            fs = []
            for step in range(steps):
                alpha = step / steps
                f = (1 - alpha) * clean_state + alpha * patch_state
                f.act.retain_grad()
                f.res.retain_grad()
                fs.append(f)
                with tracer.invoke(clean, scan=tracer_kwargs['scan']):
                    if is_tuple[submodule]:
                        submodule.output[0][:] = dictionary.decode(f.act) + f.res
                    else:
                        submodule.output = dictionary.decode(f.act) + f.res
                    metrics.append(metric_fn(model, **metric_kwargs))
            metric = sum([m for m in metrics])
            metric.sum().backward(retain_graph=True) # TODO : why is this necessary? Probably shouldn't be, contact jaden

        mean_grad = sum([f.act.grad for f in fs]) / steps
        mean_residual_grad = sum([f.res.grad for f in fs]) / steps
        grad = SparseAct(act=mean_grad, res=mean_residual_grad)
        delta = (patch_state - clean_state).detach() if patch_state is not None else -clean_state.detach()
        effect = grad @ delta

        effects[submodule] = effect
        deltas[submodule] = delta
        grads[submodule] = grad

    return (effects, deltas, grads, total_effect)

In [27]:
import math
import numpy as np
from tqdm import tqdm

batch_size = 1
num_examples = 100
device = "cuda"
num_examples = min([num_examples, len(examples)])
n_batches = math.ceil(len(examples) / batch_size)
batches = [
    examples[batch*batch_size:(batch+1)*batch_size] for batch in range(n_batches)
]
sum_effects = {}

for batch in tqdm(batches[1:]):
    clean_answer_idxs = t.tensor([e['clean_answer'] for e in batch], dtype=t.long, device=device)
    clean_inputs = t.cat([e['clean_prefix'] for e in batch], dim=0).to(device)

    patch_answer_idxs = t.tensor([e['patch_answer'] for e in batch], dtype=t.long, device=device)
    patch_inputs = t.cat([e['patch_prefix'] for e in batch], dim=0).to(device)
    def metric_fn(model):
        return (
            t.gather(model.lm_head.output[:,-1,:], dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) - \
            t.gather(model.lm_head.output[:,-1,:], dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
        )

    # for example in examples[:1]:
    effects, _, _, _ = _pe_ig(
        clean_inputs,
        patch_inputs,
        # None,
        model,
        resids,
        dictionaries,
        metric_fn
    )
    for submodule in resids:
        if submodule not in sum_effects:
            sum_effects[submodule] = effects[submodule].sum(dim=0)
        else:
            sum_effects[submodule] += effects[submodule].sum(dim=0)

print("positive effects")
for idx, submodule in enumerate(resids):
    sum_effects[submodule] /= num_examples
    sum_effects[submodule] = sum_effects[submodule].act[:length, :]
    print(sum_effects[submodule].shape)
    print(f"resid_{idx}")
    v, i = t.topk(sum_effects[submodule].flatten(), 10)
    print(np.array(np.unravel_index(i.cpu().numpy(), sum_effects[submodule].shape)).T)
    print(v)
    print()

print("negative effects")
for idx, submodule in enumerate(resids):
    print(f"resid_{idx}")
    v, i = t.topk(sum_effects[submodule].flatten(), 10, largest=False)
    print(np.array(np.unravel_index(i.cpu().numpy(), sum_effects[submodule].shape)).T)
    print(v)
    print()

  0%|          | 0/23 [00:00<?, ?it/s]

100%|██████████| 23/23 [00:26<00:00,  1.16s/it]

positive effects
torch.Size([9, 32768])
resid_0
[[    6  8245]
 [    6  8349]
 [    6   656]
 [    8 27982]
 [    6 21196]
 [    7 21196]
 [    6  3639]
 [    6  9408]
 [    7 27982]
 [    7 17244]]
tensor([0.0104, 0.0042, 0.0041, 0.0039, 0.0033, 0.0033, 0.0028, 0.0026, 0.0026,
        0.0025], device='cuda:0', dtype=torch.float16)

negative effects
resid_0
[[    7  7196]
 [    8 26787]
 [    8 20619]
 [    7 23975]
 [    6 19969]
 [    6 15156]
 [    8  7196]
 [    6  2155]
 [    8 19387]
 [    8 12974]]
tensor([-0.0052, -0.0020, -0.0015, -0.0012, -0.0011, -0.0011, -0.0010, -0.0010,
        -0.0010, -0.0008], device='cuda:0', dtype=torch.float16)






In [6]:
for idx, submodule in enumerate(resids):
    print(f"resid_{idx}")
    print(t.topk(-1 * sum_effects[submodule].act, 10))
    print()

resid_0
torch.return_types.topk(
values=tensor([2.3105, 0.6948, 0.5229, 0.4763, 0.4680, 0.4661, 0.3953, 0.3704, 0.3479,
        0.2825], device='cuda:0', dtype=torch.float16),
indices=tensor([17662, 19323,  8419, 17361,  1691, 23791,  4086, 15980, 25569, 20896],
       device='cuda:0'))



In [None]:
# load circuits, analyze
circuit_path = "circuits/NPS_ambiguous_samelen_dict10_node0.5_edge0.05_n24_aggnone_gpt2.pt"
circuit = t.load(open(circuit_path, 'rb'))

for submod in circuit["nodes"]:
    effects = circuit["nodes"][submod]
    top_effects = t.topk(effects.act, 10)
    print(submod, top_effects)

In [23]:
with model.trace("testing 1"):
    out_save = resids[4].output
    f = dictionaries[resids[4]].encode(out_save).save()
print(f.value.shape)

torch.Size([1, 2, 24576])
