In [1]:
import os
# The notebook uses paths relative to the directory that contains "models".
# If that entry is not in the current directory, move two levels up.
if "models" not in os.listdir("."):
    os.chdir("../..")

In [2]:
%load_ext autoreload
%autoreload 2
import penzai
import jax_smi
jax_smi.initialise_tracking()
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [3]:
from micrlhf.llama import LlamaTransformer
# Load the model from the local GGUF file, declared as a "gemma"-type checkpoint,
# with eager loading and the "tpu:0" device map.
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True, device_map="tpu:0")

In [4]:
from transformers import AutoTokenizer
# Tokenizer used for all prompts in this notebook.
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
# Pad on the right-hand side of each sequence.
tokenizer.padding_side = "right"

In [5]:
from sprint.task_vector_utils import load_tasks, ICLRunner
# Task collection; it is indexed by task name further down.
tasks = load_tasks()

In [6]:
def check_if_single_token(token):
    """Return True when the notebook tokenizer splits `token` into exactly one piece."""
    pieces = tokenizer.tokenize(token)
    return len(pieces) == 1

# Name of the task to analyse; used as the key into `tasks`.
task_name = "es_en"

task = tasks[task_name]

# Number of pairs before filtering.
print(len(task))

# Keep only pairs whose key and value are each a single token.
task = {
    k:v for k,v in task.items() if check_if_single_token(k) and check_if_single_token(v)
}

# Number of pairs after filtering.
print(len(task))

# (key, value) pairs as a list, handed to ICLRunner later.
pairs = list(task.items())

763
346


In [7]:
from sprint.task_vector_utils import logprob_loss
from functools import partial

# Token ids forwarded to logprob_loss as the separator and the padding token.
# NOTE(review): 3978 is a hard-coded vocabulary id — presumably the token that
# separates input from output in each example; confirm against the tokenizer.
sep = 3978
pad = 0

def metric_fn(logits, resids, tokens):
    """Metric used for attribution: logprob_loss over `logits` and `tokens`.

    `resids` is accepted so the signature matches what run_with_add calls,
    but it is not used here. `n_first=2` is passed through to logprob_loss.
    """
    return logprob_loss(logits, tokens, sep=sep, pad_token=pad, n_first=2)

In [8]:
from micrlhf.llama import LlamaBlock, LlamaAttention
from micrlhf.utils.activation_manipulation import ActivationAddition, wrap_vector
from functools import partial
import jax.numpy as jnp
from penzai import pz
import jax

@partial(jax.jit, static_argnames=("metric", "batched"))
def run_with_add(additions_pre, additions_mid, tokens, metric, batched=False, llama=None):
    """Run the model with additive perturbations and collect intermediates.

    Args:
        additions_pre: sequence indexed by block number; entry `i` is added
            in front of LlamaBlock `i`.
        additions_mid: sequence indexed by block number; entry `l` is added
            in front of every Residual of block `l` except the first one.
        tokens: token ids, wrapped below with axes ("batch", "seq").
        metric: callable `metric(logits, resids, tokens)`; static under jit.
        batched: when True the additions carry a leading "batch" axis;
            static under jit.
        llama: the model to instrument.

    Returns:
        `(metric_value, (logits, resids[::3], resids[1::3], resids[2::3]))`.
        The second element is the auxiliary output used with
        `jax.value_and_grad(..., has_aux=True)`.
    """
    # Tag the input of every block as "resid_pre_{i}".
    get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
        pz.nn.Sequential([
            pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
            x
        ])
    )
    # Inside each block, tag the input of every Residual except index 0 as "resid_mid_{l}".
    get_resids = get_resids.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda l, b: b.select().at_instances_of(pz.nn.Residual).apply_with_selected_index(lambda i, x: x if i == 0 else pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_mid_{l}"),
        x,
    ])))


    # Tag the output of every Softmax inside each attention layer as "attn_{i}".
    get_resids = get_resids.select().at_instances_of(LlamaAttention).apply_with_selected_index(lambda i, x: x.select().at_instances_of(pz.nn.Softmax).apply(lambda b: pz.nn.Sequential([
        b,
        pz.de.TellIntermediate.from_config(tag=f"attn_{i}"),
    ])))

    # Collect every tagged intermediate (the predicate accepts all tags).
    get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: True)
    # Insert the additive perturbation in front of every block ...
    make_additions = get_resids.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
        pz.nn.Sequential([
            ActivationAddition(pz.nx.wrap(additions_pre[i], *(("batch",) if batched else ()), "seq", "embedding"), "all"),
            x
        ])
    )
    # ... and in front of every Residual except index 0 within each block.
    make_additions = make_additions.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda l, b: b.select().at_instances_of(pz.nn.Residual).apply_with_selected_index(lambda i, x: x if i == 0 else pz.nn.Sequential([
        ActivationAddition(pz.nx.wrap(additions_mid[l], *(("batch",) if batched else ()), "seq", "embedding"), "all"),
        x,
    ])))
    tokens_wrapped = pz.nx.wrap(tokens, "batch", "seq")
    logits, resids = make_additions(llama.inputs.from_basic_segments(tokens_wrapped))
    # NOTE(review): the stride-3 slicing assumes the collected side outputs come
    # back per block in the order (resid_pre, attn, resid_mid) — confirm.
    return metric(logits.unwrap("batch", "seq", "vocabulary"), resids, tokens), (logits, resids[::3], resids[1::3], resids[2::3])


@partial(jax.jit, static_argnames=("metric",))
def get_metric_resid_grad(tokens, llama=llama, metric=metric_fn):
    """Evaluate the metric and its gradients w.r.t. zero-valued additions.

    Zero perturbations are placed at both addition sites of every layer, so
    the forward pass is unchanged and the gradients with respect to those
    perturbations are returned alongside the collected activations.

    Returns:
        A 6-tuple `(metric_value, resids_pre, resids_mid, qk, grad_pre, grad_mid)`.
        The three activation lists hold plain arrays, one per collected
        side output.
    """
    addition_shape = tokens.shape + (llama.config.hidden_size,)
    zero_additions = [jnp.zeros(addition_shape) for _ in range(llama.config.num_layers)]
    is_batched = tokens.ndim > 1

    value_and_grad_fn = jax.value_and_grad(run_with_add, argnums=(0, 1), has_aux=True)
    (metric_value, aux), (grad_pre, grad_mid) = value_and_grad_fn(
        zero_additions, zero_additions, tokens, metric, batched=is_batched, llama=llama
    )
    _logits, pre_outputs, attn_outputs, mid_outputs = aux

    def unwrap_all(side_outputs, *axes):
        # Strip the named-axis wrapper from each collected side output.
        return [out.value.unwrap(*axes) for out in side_outputs]

    return (
        metric_value,
        unwrap_all(pre_outputs, "batch", "seq", "embedding"),
        unwrap_all(mid_outputs, "batch", "seq", "embedding"),
        unwrap_all(attn_outputs, "batch", "kv_heads", "q_rep", "seq", "kv_seq"),
        grad_pre,
        grad_mid,
    )


In [9]:
# Settings handed to ICLRunner when the runner is built.
batch_size = 8 
n_shot=20
max_seq_len = 128
seed = 10

In [10]:
# Prompt template passed to ICLRunner; it contains one "{}" placeholder.
prompt = "Follow the pattern:\n{}"

In [11]:
# Build the ICL runner for the filtered pairs using the settings defined above.
runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_shot, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

In [12]:
from sprint.task_vector_utils import tokenized_to_inputs

# Token ids for the runner's training pairs ("input_ids" entry of the tokenized batch).
train_tokens = runner.get_tokens(
    runner.train_pairs, tokenizer
)["input_ids"]

In [13]:
# One pass over the training tokens: metric value, collected activations, and gradients at both addition sites.
metric_value, resids_pre, resids_mid, qk, grad_pre, grad_mid = get_metric_resid_grad(train_tokens, llama=llama)

In [43]:
def get_rms_block(layer, resid_index):
    """Return the first RMSLayerNorm inside Residual `resid_index` of block `layer`."""
    block_selection = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer)
    residual_selection = block_selection.at_instances_of(pz.nn.Residual).pick_nth_selected(resid_index)
    norm_selection = residual_selection.at_instances_of(pz.nn.RMSLayerNorm).pick_nth_selected(0)
    return norm_selection.get()

In [44]:
# Per layer: the first RMSLayerNorm found inside the block's Residual at index 1.
mlp_rms = [get_rms_block(layer, 1) for layer in range(llama.config.num_layers)]

In [17]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite

In [18]:
def weights_to_resid(weights, sae):
    """Decode SAE feature weights into a reconstruction.

    Args:
        weights: array indexed as (batch, seq, feature).
        sae: mapping with "W_dec" (feature, embedding). If it has "s_gate",
            the gated activation is applied using "b_gate" and the optional
            "scaling_factor" (default 1.0); otherwise a plain ReLU is used.
            An optional "norm_factor" divides the reconstruction.

    Returns:
        Array indexed as (batch, seq, embedding).
    """
    if "s_gate" not in sae:
        activations = jax.nn.relu(weights)
    else:
        is_active = weights > 0
        magnitude = weights * jax.nn.softplus(sae["s_gate"]) * sae.get("scaling_factor", 1.0) + sae["b_gate"]
        activations = is_active * jax.nn.relu(magnitude)

    reconstruction = jnp.einsum("fv,bsf->bsv", sae["W_dec"], activations)
    if "norm_factor" not in sae:
        return reconstruction
    return reconstruction / sae["norm_factor"]

In [19]:
from micrlhf.utils.load_sae import sae_encode_gated

In [48]:
def mlp_normalize(layer, resid_mid):
    """Apply the layer's stored RMS norm (`mlp_rms[layer]`) to a plain (batch, seq, embedding) array."""
    named = pz.nx.wrap(resid_mid, "batch", "seq", "embedding")
    normalized = mlp_rms[layer](named)
    return normalized.unwrap("batch", "seq", "embedding")

In [49]:
from tqdm.auto import tqdm

# Per-layer results, keyed by layer index and filled by the loops below.
# ie_*        : feature-level indirect effects
# sae_grads_* : gradients with respect to the feature activations
ie_attn = {}
sae_grads_attn = {}
ie_resid = {}
sae_grads_resid = {}
ie_transcoder = {}
sae_grads_transcoder = {}

# ie_error_* : indirect effects attributed to the reconstruction error
ie_error_attn = {}
ie_error_resid = {}
ie_error_transcoder = {}

def sfc_simple(grad, resid, target, sae):
    """Attribute a gradient to SAE features and to the reconstruction error.

    Args:
        grad: gradient with respect to the quantity the SAE reconstructs.
        resid: input that is encoded by the SAE.
        target: the value the reconstruction is compared against.
        sae: SAE parameters, as accepted by sae_encode_gated / weights_to_resid.

    Returns:
        `(indirect_effects, indirect_effects_error, sae_grad)`: feature
        gradient times feature activation, gradient times reconstruction
        error, and the feature gradient itself.
    """
    _, feature_acts, reconstruction = sae_encode_gated(sae, resid)
    residual_error = target - reconstruction

    # Pull the gradient back through the decoder to the feature activations.
    decode = partial(weights_to_resid, sae=sae)
    _, pullback = jax.vjp(decode, feature_acts)
    (sae_grad,) = pullback(grad)

    return sae_grad * feature_acts, grad * residual_error, sae_grad


# Layers analysed (6 through 16 inclusive).
layers = list(range(6, 17))
# Attention-output SAEs: the encoded input and the target are both
# resid_mid - resid_pre of the same layer, paired with the mid-site gradient.
for layer in tqdm(layers):
    r_pre, r_mid, g_mid = resids_pre[layer], resids_mid[layer], grad_mid[layer]
    sae = get_nev_it_sae_suite(layer=layer, label="attn_out")
    indirect_effects, indirect_effects_error, sae_grad = sfc_simple(g_mid, r_mid - r_pre, r_mid - r_pre, sae)
    # display((indirect_effects > 0).sum(-1))
    ie_attn[layer] = indirect_effects
    ie_error_attn[layer] = indirect_effects_error
    sae_grads_attn[layer] = sae_grad

# Residual-stream SAEs.
# NOTE(review): this pairs resids_pre[layer] with grad_pre[layer + 1], whereas
# ie_pre_to_pre_features is called with grad_pre[layer] for the same SAE —
# confirm which index the residual SAE of `layer` corresponds to.
# for layer, (r_pre, g_pre) in enumerate(zip(resids_pre, grad_pre)):
for layer in tqdm(layers):
    r_pre, g_pre = resids_pre[layer], grad_pre[layer + 1]
    sae = get_nev_it_sae_suite(layer=layer)
    indirect_effects, indirect_effects_error, sae_grad = sfc_simple(g_pre, r_pre, r_pre, sae)
    # display((indirect_effects != 0).sum(-1))
    ie_resid[layer] = indirect_effects
    ie_error_resid[layer] = indirect_effects_error
    sae_grads_resid[layer] = sae_grad

# Transcoders: encode the normalized mid residual; the target is
# resids_pre[layer + 1] - resids_mid[layer]. The last layer in `layers` is skipped.
for layer in tqdm(layers[:-1]):
    r_mid, r_pre, g_pre = resids_mid[layer], resids_pre[layer + 1], grad_pre[layer + 1]
    sae = get_nev_it_sae_suite(layer=layer, label="transcoder")
    indirect_effects, indirect_effects_error, sae_grad = sfc_simple(g_pre, mlp_normalize(layer, r_mid), r_pre - r_mid, sae)
    # display((indirect_effects != 0).sum(-1))
    ie_transcoder[layer] = indirect_effects
    ie_error_transcoder[layer] = indirect_effects_error
    sae_grads_transcoder[layer] = sae_grad

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [29]:
# Model inputs for the training tokens; their positions and attention mask are
# reused when the attention sub-block is called on its own.
tokens_wrapped = pz.nx.wrap(train_tokens, "batch", "seq")
llama_inputs = llama.inputs.from_basic_segments(tokens_wrapped)

In [None]:
# def attn_call(layer, resid_pre):
#     subblock = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer).at_instances_of(pz.nn.Residual).pick_nth_selected(0).get().delta
#     subblock = subblock.select().at_instances_of(LlamaAttention).get()

#     si_selection = subblock.select().at_instances_of(pz.de.HandledSideInputRef)
#     keys = sorted(set([ref.tag for ref in si_selection.get_sequence()]))
#     replaced = si_selection.apply(lambda ref: pz.de.SideInputRequest(tag=ref.tag))
#     subblock = pz.de.WithSideInputsFromInputTuple.handling(replaced, keys)

#     side_inputs = {
#         'positions': llama_inputs.positions,
#         'attn_mask': llama_inputs.attention_mask
#     }
    
#     resid_pre = resid_pre / 

#     resid_pre = pz.nx.wrap(resid_pre, "batch", "seq", "embedding")
#     attn_out = subblock((resid_pre,) + tuple(side_inputs[tag] for tag in subblock.side_input_tags))

#     attn_out = attn_out.unwrap("batch", "seq", "embedding") 

#     return attn_out.astype(resid.dtype)

# attn_call(6, resids_pre[6])

In [19]:
# NOTE(review): the length is taken from the raw template, so it counts the
# "{}" placeholder and whatever tokenizer.tokenize does about special tokens —
# confirm that it matches the real prefix length in train_tokens.
prompt_length = len(tokenizer.tokenize(prompt))
# After the prompt, positions are assigned to these roles in a repeating cycle.
periods = ["input", "arrow", "output", "newline"]
# Boolean position masks with the same shape as train_tokens, keyed by role name.
masks = {
    # The first prompt_length positions of every row.
    "prompt": jnp.zeros_like(train_tokens).at[:, :prompt_length].set(1).astype(bool),
    **{
        # Every len(periods)-th position starting at prompt_length + i, excluding padding.
        period: jnp.zeros_like(train_tokens).at[:, prompt_length+i::len(periods)].set(1).astype(bool) * (train_tokens != pad) for i, period in enumerate(periods)
    }
}

In [20]:
def compute_avg_node_effect(ie, top_k=64):
    """Top-k features by mean absolute effect, separately for each position mask.

    For every entry of `masks`, the absolute effects are averaged over the
    masked positions of each example and then over examples.

    Returns:
        Dict from mask name to the `jax.lax.top_k` result (values, indices).
    """
    abs_effects = jnp.abs(ie)
    averaged = {}
    for name, position_mask in masks.items():
        expanded = position_mask[..., None]
        per_example = (expanded * abs_effects).sum(1) / expanded.sum(1)
        averaged[name] = per_example.mean(0)
    return {name: jax.lax.top_k(scores, k=top_k) for name, scores in averaged.items()}
def compute_all_avg_node_effects(ies):
    """Apply compute_avg_node_effect to every value of `ies`, keeping the keys."""
    summaries = {}
    for key, effects in ies.items():
        summaries[key] = compute_avg_node_effect(effects)
    return summaries
# compute_all_avg_node_effects(ie_resid)

In [21]:
def resids_to_weights(vector, sae):
    """Encode activations into SAE feature weights.

    Mirrors the optional keys that weights_to_resid accepts, so the same
    `sae` mapping works for both directions.

    Args:
        vector: array whose last axis is the embedding axis.
        sae: mapping with "W_enc" and "b_enc". An optional "norm_factor"
            multiplies the input. If "s_gate" is present the gated
            activation is used with "b_gate" and the optional
            "scaling_factor" (default 1.0); otherwise a plain ReLU is applied.

    Returns:
        Feature activations with the feature axis last.
    """
    inputs = vector
    if "norm_factor" in sae:
        inputs = inputs * sae["norm_factor"]

    # Shared by the gate and the magnitude path; computed once.
    encoded = inputs @ sae["W_enc"]
    pre_relu = encoded + sae["b_enc"]

    if "s_gate" not in sae:
        # Non-gated SAE: same fallback as weights_to_resid.
        return jax.nn.relu(pre_relu)

    # A feature is on when its biased pre-activation is positive.
    gate = jax.nn.relu(pre_relu) > 0
    magnitude = encoded * jax.nn.softplus(sae["s_gate"]) * sae.get("scaling_factor", 1.0) + sae["b_gate"]
    return gate * jax.nn.relu(magnitude)


In [23]:
def mask_average(vector, mask):
    """Average `vector` over masked positions, then over the batch.

    Args:
        vector: array with axes (batch, seq, ...).
        mask: either a key of the notebook-level `masks` dict, or an
            already-resolved (batch, seq) mask array. Accepting both lets
            callers that looked the mask up themselves pass it through.

    Returns:
        `vector` summed over masked positions of axis 1, divided by the
        number of masked positions, then averaged over axis 0.
    """
    if isinstance(mask, str):
        mask = masks[mask]
    # Add trailing singleton axes so the mask broadcasts against `vector`.
    while mask.ndim < vector.ndim:
        mask = mask[..., None]

    return ((mask * vector).sum(1) / mask.sum(1)).mean(0)

In [50]:
def transcoder_feature_to_mid(layer, feature_idx, mask):
    """Gradient of a masked transcoder-feature score w.r.t. the mid residual.

    The score is the activation of feature `feature_idx` of the layer's
    transcoder, weighted by the stored feature gradient
    `sae_grads_transcoder[layer]`, and averaged with mask_average.

    Args:
        layer: layer index.
        feature_idx: index of the transcoder feature.
        mask: key of the `masks` dict.

    Returns:
        Array with the same shape as `resids_mid[layer]`.
    """
    sae = get_nev_it_sae_suite(layer=layer, label="transcoder")
    resid = resids_mid[layer]

    def f(resid):
        resid = mlp_normalize(layer, resid)
        batch_token_feat = resids_to_weights(resid, sae)[:, :, feature_idx] * sae_grads_transcoder[layer][:, :, feature_idx]
        # mask_average resolves the mask name itself. Resolving it here as
        # well would index `masks` with an array and raise.
        return mask_average(batch_token_feat, mask)

    return jax.grad(f)(resid)

In [59]:
def attn_out_feature_to_pre(layer, feature_idx, mask):
    """Gradient of a masked attention-output feature score w.r.t. resid_pre.

    The layer's first Residual sub-block is re-run on `resids_pre[layer]`,
    its output is encoded with the "attn_out" SAE, and the activation of
    feature `feature_idx` is weighted by `sae_grads_attn[layer]` and averaged
    with mask_average.

    Args:
        layer: layer index.
        feature_idx: index of the SAE feature.
        mask: key of the `masks` dict.

    Returns:
        Array with the same shape as `resids_pre[layer]`.
    """
    sae = get_nev_it_sae_suite(layer=layer, label="attn_out")
    resid = resids_pre[layer]

    # `.get()` turns the selection into the selected Residual; the selection
    # object itself has no `.delta`.
    subblock = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer).at_instances_of(pz.nn.Residual).pick_nth_selected(0).get().delta

    # Turn the already-handled side-input references back into requests so the
    # sub-block can be called on its own with the side inputs supplied as a tuple.
    si_selection = subblock.select().at_instances_of(pz.de.HandledSideInputRef)
    keys = sorted(set([ref.tag for ref in si_selection.get_sequence()]))
    replaced = si_selection.apply(lambda ref: pz.de.SideInputRequest(tag=ref.tag))
    subblock = pz.de.WithSideInputsFromInputTuple.handling(replaced, keys)

    side_inputs = {
        'positions': llama_inputs.positions,
        'attn_mask': llama_inputs.attention_mask
    }

    def f(resid):
        resid = pz.nx.wrap(resid, "batch", "seq", "embedding")
        attn_out = subblock((resid,) + tuple(side_inputs[tag] for tag in subblock.side_input_tags))

        attn_out = attn_out.unwrap("batch", "seq", "embedding")

        batch_token_feat = resids_to_weights(attn_out, sae)[:, :, feature_idx] * sae_grads_attn[layer][:, :, feature_idx]
        # mask_average resolves the mask name itself. Resolving it here as
        # well would index `masks` with an array and raise.
        return mask_average(batch_token_feat, mask)

    return jax.grad(f)(resid)

In [26]:
def pre_feature_to_pre(layer, feature_idx, mask):
    """Gradient of a masked residual-SAE feature score w.r.t. resid_pre.

    The score is the activation of feature `feature_idx` of the layer's
    residual SAE on `resids_pre[layer]`, weighted by `sae_grads_resid[layer]`
    and averaged with mask_average.

    Args:
        layer: layer index.
        feature_idx: index of the SAE feature.
        mask: key of the `masks` dict.

    Returns:
        Array with the same shape as `resids_pre[layer]`.
    """
    sae = get_nev_it_sae_suite(layer=layer)
    resid = resids_pre[layer]

    def f(resid):
        batch_token_feat = resids_to_weights(resid, sae)[:, :, feature_idx] * sae_grads_resid[layer][:, :, feature_idx]
        # mask_average resolves the mask name itself. Resolving it here as
        # well would index `masks` with an array and raise.
        return mask_average(batch_token_feat, mask)

    return jax.grad(f)(resid)

In [51]:
def ie_pre_to_transcoder_features(layer, grad, mask):
    """Mask-averaged indirect effects of `grad` on the layer's transcoder features.

    The transcoder is applied to the normalized mid residual of `layer`;
    only the feature-level effects from sfc_simple are kept.
    """
    transcoder = get_nev_it_sae_suite(layer=layer, label="transcoder")
    normalized_mid = mlp_normalize(layer, resids_mid[layer])
    feature_effects, _error_effects, _feature_grads = sfc_simple(grad, normalized_mid, normalized_mid, transcoder)
    return mask_average(feature_effects, mask)

In [None]:
def ie_mid_to_attn_features(layer, grad, mask):
    """Mask-averaged indirect effects of `grad` on the layer's attention-output SAE features.

    The SAE input is resids_mid[layer] - resids_pre[layer]; only the
    feature-level effects from sfc_simple are kept.
    """
    attn_sae = get_nev_it_sae_suite(layer=layer, label="attn_out")
    attn_delta = resids_mid[layer] - resids_pre[layer]
    feature_effects, _error_effects, _feature_grads = sfc_simple(grad, attn_delta, attn_delta, attn_sae)
    return mask_average(feature_effects, mask)

In [35]:
def ie_pre_to_pre_features(layer, grad, mask):
    """Mask-averaged indirect effects of `grad` on the layer's residual SAE features.

    The SAE input is resids_pre[layer]; only the feature-level effects from
    sfc_simple are kept.
    """
    resid_sae = get_nev_it_sae_suite(layer=layer)
    pre_resid = resids_pre[layer]
    feature_effects, _error_effects, _feature_grads = sfc_simple(grad, pre_resid, pre_resid, resid_sae)
    return mask_average(feature_effects, mask)

In [36]:
# Example: layer-6 residual features, averaged over the prompt positions.
ie_pre_to_pre_features(6, grad_pre[6], "prompt")

In [52]:
def grad_through_transcoder(layer, grad):
    """Pull `grad` back through the layer's transcoder to the mid residual.

    The forward map is: normalize, encode to features, decode. The
    reconstruction error is ignored. Evaluated at `resids_mid[layer]`.

    Returns:
        The 1-tuple produced by the `jax.vjp` pullback.
    """
    transcoder = get_nev_it_sae_suite(layer, label="transcoder")

    def transcoder_forward(mid_resid):
        normalized = mlp_normalize(layer, mid_resid)
        feature_weights = resids_to_weights(normalized, transcoder)
        return weights_to_resid(feature_weights, transcoder)

    _, pullback = jax.vjp(transcoder_forward, resids_mid[layer])
    return pullback(grad)


In [39]:
def grad_through_mlp(layer, grad):
    """Pull `grad` back through the `delta` of the layer's Residual at index 1.

    Evaluated at `resids_mid[layer]`.

    Returns:
        The 1-tuple produced by the `jax.vjp` pullback.
    """
    mlp_residual = (
        llama.select()
        .at_instances_of(LlamaBlock).pick_nth_selected(layer)
        .at_instances_of(pz.nn.Residual).pick_nth_selected(1)
        .get()
    )
    mlp_delta = mlp_residual.delta

    def mlp_forward(mid_resid):
        named = pz.nx.wrap(mid_resid, "batch", "seq", "embedding")
        return mlp_delta(named).unwrap("batch", "seq", "embedding")

    _, pullback = jax.vjp(mlp_forward, resids_mid[layer])
    return pullback(grad)

In [56]:
def grad_through_attn(layer, grad):
    """Pull `grad` back through the `delta` of the layer's Residual at index 0.

    The sub-block is detached from the full model by turning its handled
    side-input references back into requests and feeding positions and
    attention mask from `llama_inputs`. Evaluated at `resids_pre[layer]`;
    `grad` is cast to that array's dtype first.

    Returns:
        The 1-tuple produced by the `jax.vjp` pullback.
    """
    attn_residual = (
        llama.select()
        .at_instances_of(LlamaBlock).pick_nth_selected(layer)
        .at_instances_of(pz.nn.Residual).pick_nth_selected(0)
        .get()
    )

    handled_refs = attn_residual.delta.select().at_instances_of(pz.de.HandledSideInputRef)
    tags = sorted({ref.tag for ref in handled_refs.get_sequence()})
    with_requests = handled_refs.apply(lambda ref: pz.de.SideInputRequest(tag=ref.tag))
    attn_fn = pz.de.WithSideInputsFromInputTuple.handling(with_requests, tags)

    available_side_inputs = {
        'positions': llama_inputs.positions,
        'attn_mask': llama_inputs.attention_mask
    }
    # Order the side inputs the way the handler expects them.
    extra_args = tuple(available_side_inputs[tag] for tag in attn_fn.side_input_tags)

    pre_resid = resids_pre[layer]

    def attn_forward(resid_value):
        named = pz.nx.wrap(resid_value, "batch", "seq", "embedding")
        attn_out = attn_fn((named,) + extra_args)
        return attn_out.unwrap("batch", "seq", "embedding").astype(resid_value.dtype)

    _, pullback = jax.vjp(attn_forward, pre_resid)
    return pullback(grad.astype(pre_resid.dtype))