In [1]:
import os
# Relative paths used later (e.g. "models/...") are resolved against the
# working directory. If no "models" entry is listed here, step two
# directories up.
if "models" not in os.listdir(os.curdir):
    os.chdir(os.path.join(os.pardir, os.pardir))

In [2]:
# Reload edited modules automatically before each cell runs, so changes to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import penzai
import jax_smi
# Start jax_smi tracking (presumably device-memory profiling -- confirm against
# the jax_smi docs).
jax_smi.initialise_tracking()
from penzai import pz
# Penzai notebook integration: treescope registration, the autovisualize
# magic, and the interactive context. Statement order is kept as written.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [3]:
# Turn off JAX traceback filtering twice over: through the environment
# variable (set before jax is imported) and through the config flag.
%env JAX_TRACEBACK_FILTERING=off
import jax
jax.config.update('jax_traceback_filtering', 'off')


env: JAX_TRACEBACK_FILTERING=off


In [4]:
from sprint.icl_sfc_utils import Circuitizer

In [5]:
from micrlhf.llama import LlamaTransformer
# Load the model from the local GGUF file. The path is relative, so it relies
# on the working directory set up in the first cell.
llama = LlamaTransformer.from_pretrained(
    "models/gemma-2b-it.gguf",
    from_type="gemma",
    load_eager=True,
    device_map="tpu:0",
)

In [6]:
from transformers import AutoTokenizer
# Tokenizer used both for the single-token filter and by the Circuitizer below.
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
# Pad on the right-hand side of each sequence.
tokenizer.padding_side = "right"

In [7]:
from sprint.task_vector_utils import load_tasks, ICLRunner
# Load the task collection; it is indexed by task name below
# (tasks[task_name]) and each task is iterated with .items().
tasks = load_tasks()

In [8]:
def check_if_single_token(token, tok=None):
    """Return True when ``token`` is tokenized into exactly one piece.

    Args:
        token: The string to tokenize.
        tok: Optional object exposing ``tokenize(str)``. When omitted, the
            notebook-level ``tokenizer`` is used, which matches the original
            single-argument behaviour.

    Returns:
        bool: ``True`` if ``tok.tokenize(token)`` has length 1, else ``False``.
    """
    # Resolve the global lazily so callers (and tests) can inject their own
    # tokenizer instead of relying on hidden notebook state.
    active_tokenizer = tokenizer if tok is None else tok
    return len(active_tokenizer.tokenize(token)) == 1

# --- Task selection -------------------------------------------------------
task_name = "antonyms"

# Keep only the pairs where both sides are a single token.
task = {
    source: target
    for source, target in tasks[task_name].items()
    if check_if_single_token(source) and check_if_single_token(target)
}
print(len(task))

pairs = list(task.items())

# --- Runner configuration -------------------------------------------------
batch_size = 8
n_shot = 16
max_seq_len = 128
seed = 10
prompt = "Follow the pattern:\n{}"

runner = ICLRunner(
    task_name,
    pairs,
    batch_size=batch_size,
    n_shot=n_shot,
    max_seq_len=max_seq_len,
    seed=seed,
    prompt=prompt,
)

143


In [9]:
# Layer indices 6 through 16 inclusive.
layers = [*range(6, 17)]
circuitizer = Circuitizer(
    llama,
    tokenizer,
    runner,
    layers,
    prompt,
)

Setting up masks...




Running metrics...
Setting up RMS...


  0%|          | 0/18 [00:00<?, ?it/s]

Loading SAEs...


  0%|          | 0/11 [00:00<?, ?it/s]

Running node IEs...


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [12]:

from micrlhf.llama import LlamaBlock, LlamaAttention, LlamaInputs

# Pull out the LlamaBlock at selection index 11 for inspection.
block = (
    llama.select()
    .at_instances_of(LlamaBlock)
    .pick_nth_selected(11)
    .get()
)

In [38]:
from penzai.toolshed.jit_wrapper import Jitted
layer = 11

# Selection of the Linear inside the attention value-to-output path of the
# chosen block: block -> first Residual -> its delta -> first LlamaAttention
# -> attn_value_to_output -> Linear.
attn_getter = (
    llama.select()
    .at_instances_of(LlamaBlock)
    .pick_nth_selected(layer)
    .at_instances_of(pz.nn.Residual)
    .pick_nth_selected(0)
    .at(lambda residual: residual.delta)
    .at_instances_of(LlamaAttention)
    .pick_nth_selected(0)
    .at(lambda attention: attention.attn_value_to_output)
    .at_instances_of(pz.nn.Linear)
)
# Keep the unmodified Linear around; its weights are sliced per head later.
attn_layer = attn_getter.get()

# Replace that Linear with [TellIntermediate, Linear] so the value entering the
# Linear is emitted under the tag "attn_heads_<layer>".
attn_getter = attn_getter.apply(
    lambda linear: pz.nn.Sequential(
        [pz.de.TellIntermediate.from_config(tag=f"attn_heads_{layer}"), linear]
    )
)
# Collect every side output whose tag starts with "attn_heads_", then wrap the
# result in Jitted.
attn_getter = pz.de.CollectingSideOutputs.handling(
    attn_getter,
    tag_predicate=lambda tag: tag.startswith("attn_heads_"),
)
attn_getter = Jitted(attn_getter)

In [43]:
# Display the selected Linear layer for inspection.
attn_layer

In [34]:
# Run the instrumented model on the circuitizer's inputs. The model output is
# discarded; `attns` holds the collected side outputs (read below as
# attns[0].value).
_, attns = attn_getter(circuitizer.llama_inputs)

In [50]:
# First collected side output: the value tapped just before the output Linear.
attn_output = attns[0].value
output_weights = attn_layer.weights.value
n_kv_heads = attn_output.named_shape["kv_heads"]
n_q_rep = attn_output.named_shape["q_rep"]

# Project each head separately through its own slice of the output weights.
# Heads are appended kv_heads-major, q_rep-minor, so the flat index is
# kv_heads * n_q_rep + q_rep.
attn_outs = []
for kv_heads in range(n_kv_heads):
    for q_rep in range(n_q_rep):
        head_index = {"kv_heads": kv_heads, "q_rep": q_rep}
        attn_pre = attn_output[head_index]
        layer_restricted = pz.nn.Linear(
            weights=pz.nn.Parameter(output_weights[head_index], "at"),
            in_axis_names=("projection",),
            out_axis_names=("embedding",),
        )
        attn_out = layer_restricted(attn_pre)
        attn_outs.append(attn_out)

In [58]:
# Display the first head's projected output for inspection.
attn_outs[0]

In [53]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite


# Fetch the SAE for `layer` with label "attn_out"; it is indexed below as
# sae["W_dec"].
sae = get_nev_it_sae_suite(layer, label="attn_out")

In [57]:
# Index of the SAE feature under study.
feature = 4080
# Disabled alternative: take the attention output as the difference between
# the circuitizer's mid and pre residuals at this layer.
# r_pre = circuitizer.resids_pre[layer]
# r_mid = circuitizer.resids_mid[layer]
# attn_out = (r_mid - r_pre)
# Row of W_dec for that feature; compared against the per-head outputs below.
direction = sae["W_dec"][feature]

In [66]:
import jax.numpy as jnp

# Cosine similarity between every head's output and the SAE decoder direction,
# averaged per mask. cossims[h][m] pairs head h (in attn_outs order) with
# masks[m].
masks = list(circuitizer.masks.keys())

# Loop-invariant: the direction's norm does not depend on the head, so compute
# it once instead of once per head.
direction_norm = jnp.linalg.norm(direction, axis=-1, keepdims=True)

cossims = []
for ao in attn_outs:
    # Plain array with axes ordered (batch, seq, embedding). A separate name is
    # used so the loop variable is not rebound.
    ao_dense = ao.unwrap("batch", "seq", "embedding")
    # NOTE(review): a position whose output norm is exactly zero would yield
    # NaN here; behaviour is unchanged from the original.
    cossims_all = (ao_dense @ direction) / jnp.linalg.norm(ao_dense, axis=-1) / direction_norm
    cossims.append(
        [circuitizer.mask_average(cossims_all, mask) for mask in masks]
    )

In [73]:
# Report, per head, the mask-averaged cosine similarities. Every head gets its
# "Head i" line; the breakdown is printed only when the similarity on the
# "arrow" mask has absolute value of at least 0.1.
arrow_index = masks.index("arrow")  # same for every head, so looked up once
for i, head_cossims in enumerate(cossims):
    print(f"Head {i}")
    arrow_sim = head_cossims[arrow_index]
    if abs(arrow_sim) < 0.1:
        continue
    print(f" Arrow: {arrow_sim}")
    for mask_name, cossim in zip(masks, head_cossims):
        print(f"  {mask_name}: {cossim}")

Head 0
Head 1
Head 2
 Arrow: 0.466797
  prompt: -0.020874
  arrow: 0.466797
  newline: 0.057373
  input: 0.219727
  output: 0.120117
Head 3
 Arrow: 0.168945
  prompt: 0.0112305
  arrow: 0.168945
  newline: 0.0810547
  input: 0.0400391
  output: 0.0634766
Head 4
Head 5
Head 6
Head 7
