In [1]:
import os

# The rest of the notebook uses paths relative to a directory that contains
# a "models" entry; when the current directory has none, move two levels up.
entries_here = os.listdir(".")
if not entries_here.count("models"):
    os.chdir("../..")

In [2]:
# Reload edited modules automatically so changes to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
import penzai
import jax_smi
# presumably starts jax_smi's accelerator-memory tracking — confirm against jax_smi docs
jax_smi.initialise_tracking()
from penzai import pz
# Penzai notebook display setup; see the penzai docs for what each call registers.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [3]:
# Turn JAX traceback filtering off, both through the environment variable and
# through the runtime config flag.
%env JAX_TRACEBACK_FILTERING=off
import jax
jax.config.update('jax_traceback_filtering', 'off')


env: JAX_TRACEBACK_FILTERING=off


In [4]:
from sprint.icl_sfc_utils import Circuitizer

In [5]:
from micrlhf.llama import LlamaTransformer

# Load the Gemma 2B instruction-tuned GGUF checkpoint onto the first TPU.
llama = LlamaTransformer.from_pretrained(
    "models/gemma-2b-it.gguf",
    from_type="gemma",
    load_eager=True,
    device_map="tpu:0",
)

In [6]:
from transformers import AutoTokenizer
# Tokenizer used for single-token filtering of task pairs and passed to the Circuitizer.
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
# Pad on the right side of each sequence.
tokenizer.padding_side = "right"

In [7]:
from sprint.task_vector_utils import load_tasks, ICLRunner
# `tasks` is indexed by task name below; each task is used as a mapping of
# input string -> expected output string.
tasks = load_tasks()

In [8]:
def check_if_single_token(token):
    """Return True when the notebook's `tokenizer` splits `token` into exactly one piece."""
    pieces = tokenizer.tokenize(token)
    return len(pieces) == 1

# --- Task selection -------------------------------------------------------
task_name = "es_en"
task = tasks[task_name]
print(len(task))

# Drop every pair in which either side is not exactly one token.
task = dict(
    (src, dst)
    for src, dst in task.items()
    if check_if_single_token(src) and check_if_single_token(dst)
)
print(len(task))

pairs = list(task.items())

# --- In-context-learning run configuration --------------------------------
batch_size = 8
n_shot = 20
max_seq_len = 128
seed = 10

prompt = "Follow the pattern:\n{}"

runner = ICLRunner(
    task_name,
    pairs,
    batch_size=batch_size,
    n_shot=n_shot,
    max_seq_len=max_seq_len,
    seed=seed,
    prompt=prompt,
)

763
346


In [9]:
# Layers 6 through 16 inclusive are handed to the Circuitizer.
layers = [*range(6, 17)]
circuitizer = Circuitizer(
    llama,
    tokenizer,
    runner,
    layers,
    prompt,
)

Setting up masks...
Running metrics...
Setting up RMS...


  0%|          | 0/18 [00:00<?, ?it/s]

Loading SAEs...


  0%|          | 0/11 [00:00<?, ?it/s]

Running node IEs...


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

ValueError: matmul input operand 0 must have ndim at least 1, but it has ndim 0

In [23]:
import json

# Load the precomputed attribution graph; its "edges" list is consumed below.
# The encoding is given explicitly so decoding does not depend on the
# platform's locale default.
# NOTE(review): this file is the *antonyms* graph while the task configured
# earlier in the notebook is "es_en" — confirm this is the intended graph.
with open("micrlhf-progress/full-graph-antonyms.json", encoding="utf-8") as f:
    graph = json.load(f)

In [44]:
# Keep the first 20000 edges of the graph. Each edge is unpacked downstream as
# (weight, source, target), where source/target look like
# [node_type, layer, mask_name, feature_index].
# NOTE(review): presumably the edge list is sorted by weight so this keeps the
# strongest edges — confirm against the code that wrote the JSON.
circuit = graph["edges"][:20000]

In [52]:
# Number of edges actually kept (at most 20000).
len(circuit)

In [54]:
from tqdm.auto import tqdm, trange
from micrlhf.utils.load_sae import sae_encode
from collections import namedtuple
import jax.numpy as jnp
import numpy as np
from collections import defaultdict
start_layer = 7
# Per-layer bundle: residual streams before/after attention and after the
# layer, plus the SAE feature activations computed for that layer.
layer_context = namedtuple("layer_context", ["resids_pre", "sae_resid", "sae_attn_out", "resids_mid", "sae_transcoder", "resids_post"])

# Seed context for the layer immediately before `start_layer`.
# Every SAE of layer `prev_layer` is fed that same layer's activations, which
# matches how the loop below pairs saes[(layer, ...)] with layer `layer`
# inputs. Previously three of these encodings read the activations of
# `start_layer` while using the SAEs of `start_layer - 1` (off-by-one).
# `resids_post` of the previous layer is, by definition, `resids_pre` of
# `start_layer`.
prev_layer = start_layer - 1
previous_ctx = layer_context(
    resids_pre=circuitizer.resids_pre[prev_layer],
    sae_resid=sae_encode(
        circuitizer.saes[(prev_layer, "resid")],
        circuitizer.resids_pre[prev_layer],
    )[1],
    sae_attn_out=sae_encode(
        circuitizer.saes[(prev_layer, "attn_out")],
        # attention output = residual after attention minus residual before it
        circuitizer.resids_mid[prev_layer] - circuitizer.resids_pre[prev_layer],
    )[1],
    resids_mid=circuitizer.resids_mid[prev_layer],
    sae_transcoder=sae_encode(
        circuitizer.saes[(prev_layer, "transcoder")],
        circuitizer.resids_mid[prev_layer],
    )[1],
    resids_post=circuitizer.resids_pre[start_layer],
)

def find_targets(layer, type, mask, source_type, source_mask):
    """Collect, per target feature, a 0/1 vector marking its source features in `circuit`.

    An edge of the global `circuit` is used when its target matches
    (`type`, `layer`, `mask`) and its source matches (`source_type`,
    `source_mask`). Node types are compared through the first character of
    `type` / `source_type` against the edge's type tag.

    Returns a dict mapping target feature index -> jnp array whose length is
    the first dimension of the source SAE's "W_dec", with 1 at every source
    feature index that has a matching edge. Target features without any
    matching edge are absent from the dict.
    """
    # "resid" targets take their sources from the previous layer's SAE.
    sae_layer = layer - 1 if type == "resid" else layer
    source_sae = circuitizer.saes[(sae_layer, source_type)]

    selected = {}
    for edge in circuit:
        src, dst = edge[1], edge[2]
        rejected = (
            dst[0] != type[0]
            or dst[1] != layer
            or src[0] != source_type[0]
            or src[2] != source_mask
            or dst[2] != mask
        )
        if rejected:
            continue
        feature = dst[3]
        if feature not in selected:
            selected[feature] = np.zeros(source_sae["W_dec"].shape[0])
        selected[feature][src[3]] = 1

    result = {}
    for feature, vector in selected.items():
        result[feature] = jnp.asarray(vector)
    return result

# Layer-by-layer pass that recomputes SAE feature pre-activations while removing
# the contribution of every source feature that has no edge to the target
# feature in `circuit`.
# NOTE(review): work in progress — the recorded run ended in KeyboardInterrupt,
# there is no MLP/transcoder stage after `resid_mid`, and `previous_ctx` is never
# reassigned inside the loop, so every iteration reads the context built for
# layer `start_layer - 1`.
# NOTE(review): sae_encode appears to return (pre-activation, activation,
# reconstruction), judging by how results are unpacked and by the `pre_relu`
# keyword — confirm against micrlhf.utils.load_sae.
for layer in trange(start_layer, layers[-1] + 1):
    # --- Stage 1: residual-stream SAE at the input of this layer ---
    target_sae = circuitizer.saes[(layer, "resid")]
    inputs = previous_ctx.resids_post
    pre_encodings, _, pre_recon = sae_encode(target_sae, inputs)
    # Part of the residual the SAE does not reconstruct; added back unchanged below.
    pre_err = previous_ctx.resids_post - pre_recon
    for source_type in ["attn_out", "transcoder"]:
        source_sae = circuitizer.saes[(layer - 1, source_type)]
        for mask in circuitizer.masks:
            # Source and target share the same position mask for resid targets.
            masks = find_targets(layer, "resid", mask, source_type, mask)
            for target_feature, source_mask in masks.items():
                sae_feats = getattr(previous_ctx, f"sae_{source_type}")
                # Negated decoded contribution of the source features that are NOT
                # connected to this target feature (source_mask is 1 on connected ones).
                # (b, t, d)
                input_delta = -(sae_feats * (1 - source_mask)) @ source_sae["W_dec"] / source_sae["out_norm_factor"]
                # Restrict the change to the positions selected by `mask`.
                # (b, t(sm), d)
                input_delta = input_delta * circuitizer.masks[mask][..., None]
                # Encoder direction of the target feature, used as a linear map
                # from input change to pre-activation change.
                # d
                grad_to_input = target_sae["W_enc"][:, target_feature] * target_sae["norm_factor"]

                pre_encodings = pre_encodings.at[..., target_feature].add(input_delta @ grad_to_input)
    # Re-run the SAE from the edited pre-activations; `pre_encodings` now holds
    # the second return value instead of the pre-activations.
    _, pre_encodings, resid_pre = sae_encode(target_sae, inputs, pre_relu=pre_encodings)
    resid_pre = resid_pre + pre_err

    # --- Stage 2: attention-output SAE of this layer ---
    target_sae = circuitizer.saes[(layer, "attn_out")]
    inputs = circuitizer.grad_through_attn_fwd(layer, resid=resid_pre)
    attn_encodings, _, attn_recon = sae_encode(target_sae, inputs)
    attn_err = inputs - attn_recon
    source_type = "resid"
    source_sae = circuitizer.saes[(layer, source_type)]
    sae_feats = pre_encodings
    # Unlike stage 1, source and target position masks may differ here.
    for source_mask in circuitizer.masks:
        for target_mask in circuitizer.masks:
            masks = find_targets(layer, "attn_out", target_mask, source_type, source_mask)
            for target_feature, feature_mask in masks.items():
                # (b, t, d)
                input_delta = -(sae_feats * (1 - feature_mask)) @ source_sae["W_dec"] / source_sae["out_norm_factor"]
                # (b, t(sm), d)
                input_delta = input_delta * circuitizer.masks[source_mask][..., None]
                # Push the residual change through the attention block
                # (presumably a linearised forward pass — confirm in Circuitizer).
                # (b, t, d)
                input_delta = circuitizer.grad_through_attn_fwd(layer, resid=resid_pre, grad=input_delta)
                # (b, t(sm), d)
                input_delta = input_delta * circuitizer.masks[target_mask][..., None]
                # d
                grad_to_input = target_sae["W_enc"][:, target_feature] * target_sae["norm_factor"]
                
                attn_encodings = attn_encodings.at[..., target_feature].add(input_delta @ grad_to_input)
    _, attn_encodings, attn_recon = sae_encode(target_sae, inputs, pre_relu=attn_encodings)
    # Residual after attention: input residual + edited attention reconstruction + SAE error.
    resid_mid = resid_pre + attn_recon + attn_err

KeyboardInterrupt: 

In [34]:
# Leftover inspection of `inputs`, the variable assigned inside the layer loop
# in the cell above.
inputs

In [None]:
#[[8.290185360237956e-05, ['er', 16, 'arrow', 0], ['a', 16, 'arrow', 26950]], [3.0926865292713046e-05, ['r', 16, 'prompt', 5241], ['a', 16, 'arrow', 26950]], [2.6628349587554112e-05, ['r', 16, 'arrow', 11391], ['a', 16, 'arrow', 26950]], [2.0594392481143586e-05, ['r', 16, 'arrow', 1925], ['a', 16, 'arrow', 26950]], [1.5115344467631076e-05, ['r', 16, 'arrow', 1383], ['a', 16, 'arrow', 26950]], [1.3805756680085324e-05, ['r', 16, 'arrow', 31633], ['a', 16, 'arrow', 26950]], [9.770472388481721e-06, ['er', 16, 'prompt', 0], ['a', 16, 'arrow', 26950]], [9.19807735044742e-06, ['r', 16, 'input', 15720], ['a', 16, 'arrow', 26950]], [8.853277904563583e-06, ['r', 16, 'input', 23790], ['a', 16, 'arrow', 26950]], [8.566045835323166e-06, ['r', 16, 'input', 8739], ['a', 16, 'arrow', 26950]]]


In [12]:
def ablate_pre_to_attn_out(circuitizer, layer, mask, target_feature, ablated_features):
    effects = circuitizer.attn_out_feature_to_pre()
