In [1]:
import os
if "models" not in os.listdir("."):
    os.chdir("../..")

In [2]:
%load_ext autoreload
%autoreload 2
import penzai
import jax_smi
jax_smi.initialise_tracking()
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [3]:
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True, device_map="tpu:0")

In [4]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
tokenizer.padding_side = "right"

In [5]:
from sprint.task_vector_utils import load_tasks, ICLRunner
tasks = load_tasks()

In [6]:
def check_if_single_token(token):
    return len(tokenizer.tokenize(token)) == 1

task_name = "es_en"

task = tasks[task_name]

print(len(task))

task = {
    k:v for k,v in task.items() if check_if_single_token(k) and check_if_single_token(v)
}

print(len(task))

pairs = list(task.items())

763
346


In [7]:
def logprob_loss(logits, tokens, sep=1599, pad_token=32000, n_first=None, shift=None):
    logits = logits[:, :-1]

    # print(
    #     logits.argmax(axis=-1)
    # )

    logits = jnp.take_along_axis(logits, tokens[:, 1:, None], axis=-1).squeeze(-1)

    mask = tokens[:, 1:] == sep
    mask = jnp.cumsum(mask[:, ::-1], axis=-1)[:, ::-1] > 0
    mask = jnp.logical_not(mask)

    if shift is not None:
        rolled_mask = jnp.roll(mask, shift, axis=-1)
        mask = jnp.logical_and(mask, rolled_mask)

    # print(mask[:, -5:])
    
    if n_first is not None:
        rolled_mask = jnp.roll(mask, n_first, axis=-1)
        mask = jnp.logical_and(mask, jnp.logical_not(rolled_mask))

    mask = jnp.logical_and(mask, tokens[:, 1:] != pad_token)

    logits = logits * mask

    return logits.sum(axis=-1).mean(axis=-1)

In [8]:
from functools import partial

sep = 3978
pad = 0

def metric_fn(logits, resids, tokens):
    return logprob_loss(logits, tokens, sep=sep, pad_token=pad, n_first=2)

In [9]:
from micrlhf.llama import LlamaBlock, LlamaAttention
from micrlhf.utils.activation_manipulation import ActivationAddition, wrap_vector
from functools import partial
import jax.numpy as jnp
from penzai import pz
import jax

@partial(jax.jit, static_argnames=("metric", "batched"))
def run_with_add(additions_pre, additions_mid, tokens, metric, batched=False, llama=None):
    get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
        pz.nn.Sequential([
            pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
            x
        ])
    )
    get_resids = get_resids.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda l, b: b.select().at_instances_of(pz.nn.Residual).apply_with_selected_index(lambda i, x: x if i == 0 else pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_mid_{l}"),
        x,
    ])))


    get_resids = get_resids.select().at_instances_of(LlamaAttention).apply_with_selected_index(lambda i, x: x.select().at_instances_of(pz.nn.Softmax).apply(lambda b: pz.nn.Sequential([
        b,
        pz.de.TellIntermediate.from_config(tag=f"attn_{i}"),
    ])))

    get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: True)
    make_additions = get_resids.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
        pz.nn.Sequential([
            ActivationAddition(pz.nx.wrap(additions_pre[i], *(("batch",) if batched else ()), "seq", "embedding"), "all"),
            x
        ])
    )
    make_additions = make_additions.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda l, b: b.select().at_instances_of(pz.nn.Residual).apply_with_selected_index(lambda i, x: x if i == 0 else pz.nn.Sequential([
        ActivationAddition(pz.nx.wrap(additions_mid[l], *(("batch",) if batched else ()), "seq", "embedding"), "all"),
        x,
    ])))
    tokens_wrapped = pz.nx.wrap(tokens, "batch", "seq")
    logits, resids = make_additions(llama.inputs.from_basic_segments(tokens_wrapped))
    return metric(logits.unwrap("batch", "seq", "vocabulary"), resids, tokens), (logits, resids[::3], resids[1::3], resids[2::3])


@partial(jax.jit, static_argnames=("metric",))
def get_metric_resid_grad(tokens, llama=llama, metric=metric_fn):
    additions = [jnp.zeros(tokens.shape + (llama.config.hidden_size,)) for _ in range(llama.config.num_layers)]
    batched = tokens.ndim > 1
    (metric, (logits, resids_pre, qk, resids_mid)), (grad_pre, grad_mid) = jax.value_and_grad(run_with_add, argnums=(0, 1), has_aux=True)(additions, additions, tokens, metric, batched=batched, llama=llama)
    return (
        metric,
        [r.value.unwrap("batch", "seq", "embedding") for r in resids_pre],
        [r.value.unwrap("batch", "seq", "embedding") for r in resids_mid],
        [r.value.unwrap("batch", "kv_heads", "q_rep", "seq", "kv_seq") for r in qk],
        grad_pre,
        grad_mid
    )


In [10]:
batch_size = 8 
n_shot=20
max_seq_len = 128
seed = 10

In [11]:
prompt = "Follow the pattern:\n{}"

In [12]:
runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_shot, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

In [13]:
from sprint.task_vector_utils import tokenized_to_inputs

train_tokens = runner.get_tokens(
    runner.train_pairs, tokenizer
)["input_ids"]

In [14]:
metric_value, resids_pre, resids_mid, qk, grad_pre, grad_mid = get_metric_resid_grad(train_tokens, llama=llama)

In [15]:
get_rms_block = lambda layer, resid_index: (
    llama.select()
    .at_instances_of(LlamaBlock).pick_nth_selected(layer)
    .at_instances_of(pz.nn.Residual).pick_nth_selected(resid_index)
    .at_instances_of(pz.nn.RMSLayerNorm).pick_nth_selected(0)
    ).get()

In [16]:
mlp_rms = [get_rms_block(layer, 1) for layer in range(llama.config.num_layers)]

In [17]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite

In [18]:
from micrlhf.utils.load_sae import weights_to_resid, resids_to_weights

In [19]:
from micrlhf.utils.load_sae import sae_encode_gated

In [20]:
def mlp_normalize(layer, resid_mid):
    # return resid_mid / resids_mid_norms[layer] * mlp_rms_weights[layer]
    # return resid_mid / jnp.linalg.norm(resid_mid, axis=-1, keepdims=True) * mlp_rms_weights[layer]
    return mlp_rms[layer](pz.nx.wrap(resid_mid, "batch", "seq", "embedding")).unwrap("batch", "seq", "embedding")

In [21]:
from tqdm.auto import tqdm

ie_attn = {}
sae_grads_attn = {}
ie_resid = {}
sae_grads_resid = {}
ie_transcoder = {}
sae_grads_transcoder = {}

ie_error_attn = {}
ie_error_resid = {}
ie_error_transcoder = {}

sae_error_attn = {}
sae_error_resid = {}
sae_error_transcoder = {}

def sfc_simple(grad, resid, target, sae):
    pre_relu, post_relu, recon = sae_encode_gated(sae, resid)
    error = target - recon
    f = partial(weights_to_resid, sae=sae)

    sae_grad, = jax.vjp(f, post_relu)[1](grad,)
    indirect_effects = sae_grad * post_relu
    indirect_effects_error = jnp.einsum("...f, ...f -> ...", grad, error)
    return indirect_effects, indirect_effects_error, sae_grad, error


layers = list(range(6, 17))
for l in tqdm(layers):
    r_pre, r_mid, g_mid = resids_pre[l], resids_mid[l], grad_mid[l]
    sae = get_nev_it_sae_suite(layer=l, label="attn_out")
    indirect_effects, indirect_effects_error, sae_grad, error = sfc_simple(g_mid, r_mid - r_pre, r_mid - r_pre, sae)
    # display((indirect_effects > 0).sum(-1))
    ie_attn[l] = indirect_effects
    ie_error_attn[l] = indirect_effects_error
    sae_grads_attn[l] = sae_grad
    sae_error_attn[l] = error

# for layer, (r_pre, g_pre) in enumerate(zip(resids_pre, grad_pre)):
for l in tqdm(layers):
    r_pre, g_pre = resids_pre[l], grad_pre[l]
    sae = get_nev_it_sae_suite(layer=l)
    indirect_effects, indirect_effects_error, sae_grad, error = sfc_simple(g_pre, r_pre, r_pre, sae)
    # display((indirect_effects != 0).sum(-1))
    ie_resid[l] = indirect_effects
    ie_error_resid[l] = indirect_effects_error
    sae_grads_resid[l] = sae_grad
    sae_error_resid[l] = error

for l in tqdm(layers[:-1]):
    r_mid, r_pre, g_pre = resids_mid[l], resids_pre[l + 1], grad_pre[l + 1]
    sae = get_nev_it_sae_suite(layer=l, label="transcoder")
    indirect_effects, indirect_effects_error, sae_grad, error = sfc_simple(g_pre, mlp_normalize(l, r_mid), r_pre - r_mid, sae)
    # display((indirect_effects != 0).sum(-1))
    ie_transcoder[l] = indirect_effects
    ie_error_transcoder[l] = indirect_effects_error
    sae_grads_transcoder[l] = sae_grad
    sae_error_transcoder[l] = error

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [22]:
tokens_wrapped = pz.nx.wrap(train_tokens, "batch", "seq")
llama_inputs = llama.inputs.from_basic_segments(tokens_wrapped)

In [23]:
# def attn_call(layer, resid_pre):
#     subblock = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer).at_instances_of(pz.nn.Residual).pick_nth_selected(0).get().delta
#     subblock = subblock.select().at_instances_of(LlamaAttention).get()

#     si_selection = subblock.select().at_instances_of(pz.de.HandledSideInputRef)
#     keys = sorted(set([ref.tag for ref in si_selection.get_sequence()]))
#     replaced = si_selection.apply(lambda ref: pz.de.SideInputRequest(tag=ref.tag))
#     subblock = pz.de.WithSideInputsFromInputTuple.handling(replaced, keys)

#     side_inputs = {
#         'positions': llama_inputs.positions,
#         'attn_mask': llama_inputs.attention_mask
#     }
    
#     resid_pre = resid_pre / 

#     resid_pre = pz.nx.wrap(resid_pre, "batch", "seq", "embedding")
#     attn_out = subblock((resid_pre,) + tuple(side_inputs[tag] for tag in subblock.side_input_tags))

#     attn_out = attn_out.unwrap("batch", "seq", "embedding") 

#     return attn_out.astype(resid.dtype)

# attn_call(6, resids_pre[6])

In [24]:
prompt_length = len(tokenizer.tokenize(prompt))
periods = ["input", "arrow", "output", "newline"]
masks = {
    "prompt": jnp.zeros_like(train_tokens).at[:, :prompt_length].set(1).astype(bool),
    **{
        period: jnp.zeros_like(train_tokens).at[:, prompt_length+i::len(periods)].set(1).astype(bool) * (train_tokens != pad) for i, period in enumerate(periods)
    }
}

In [25]:
def mask_average(vector, mask):
    mask = masks[mask]
    while mask.ndim < vector.ndim:
        mask = mask[..., None]

    return ((mask * vector).sum(1) / mask.sum(1)).mean(0)

In [26]:
def transcoder_feature_to_mid(layer, feature_idx, mask):
    sae = get_nev_it_sae_suite(layer=layer, label="transcoder")
    resid = resids_mid[layer]

    def f(resid):
        resid = mlp_normalize(layer, resid)
        batch_token_feat = resids_to_weights(resid, sae)[:, :, feature_idx] * sae_grads_transcoder[layer][:, :, feature_idx]
        token_act = mask_average(batch_token_feat, mask)
        return token_act

    return jax.grad(f)(resid)

In [27]:
def transcoder_error_to_mid(layer, mask):
    sae = get_nev_it_sae_suite(layer=layer, label="transcoder")
    resid_next = resids_pre[layer + 1]
    resid = resids_mid[layer]

    grad = grad_pre[layer + 1]

    def f(resid):
        _, _, recon = sae_encode_gated(sae, resid)
        err_by_grad = jnp.einsum("...f, ...f -> ...", (resid_next - recon), grad)
        return mask_average(err_by_grad, mask)

    return jax.grad(f)(resid)

In [28]:
def attn_out_feature_to_pre(layer, feature_idx, mask):
    sae = get_nev_it_sae_suite(layer=layer, label="attn_out")

    resid = resids_pre[layer]

    subblock = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer).at_instances_of(pz.nn.Residual).pick_nth_selected(0).get().delta

    si_selection = subblock.select().at_instances_of(pz.de.HandledSideInputRef)
    keys = sorted(set([ref.tag for ref in si_selection.get_sequence()]))
    replaced = si_selection.apply(lambda ref: pz.de.SideInputRequest(tag=ref.tag))
    subblock = pz.de.WithSideInputsFromInputTuple.handling(replaced, keys)

    side_inputs = {
        'positions': llama_inputs.positions,
        'attn_mask': llama_inputs.attention_mask
    }

    def f(resid):
        resid = pz.nx.wrap(resid, "batch", "seq", "embedding")
        attn_out = subblock((resid,) + tuple(side_inputs[tag] for tag in subblock.side_input_tags))

        attn_out = attn_out.unwrap("batch", "seq", "embedding") 

        batch_token_feat = resids_to_weights(attn_out, sae)[:, :, feature_idx] * sae_grads_attn[layer][:, :, feature_idx]
        token_act = mask_average(batch_token_feat, mask)
        return token_act

    return jax.grad(f)(resid)

def attn_out_error_to_pre(layer, mask):
    sae = get_nev_it_sae_suite(layer=layer, label="attn_out")

    resid = resids_pre[layer]

    subblock = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer).at_instances_of(pz.nn.Residual).pick_nth_selected(0).get().delta

    si_selection = subblock.select().at_instances_of(pz.de.HandledSideInputRef)
    keys = sorted(set([ref.tag for ref in si_selection.get_sequence()]))
    replaced = si_selection.apply(lambda ref: pz.de.SideInputRequest(tag=ref.tag))
    subblock = pz.de.WithSideInputsFromInputTuple.handling(replaced, keys)

    side_inputs = {
        'positions': llama_inputs.positions,
        'attn_mask': llama_inputs.attention_mask
    }

    def f(resid):
        resid = pz.nx.wrap(resid, "batch", "seq", "embedding")
        attn_out = subblock((resid,) + tuple(side_inputs[tag] for tag in subblock.side_input_tags))

        attn_out = attn_out.unwrap("batch", "seq", "embedding") 

        _, _, recon = sae_encode_gated(sae, attn_out)
        batch_token_feat = jnp.einsum("...f, ...f -> ...", attn_out - recon, grad_mid[layer])
        token_act = mask_average(batch_token_feat, mask)
        return token_act

    return jax.grad(f)(resid)
# float(jnp.linalg.norm(attn_out_error_to_pre(6, "arrow")))

In [29]:
def pre_feature_to_pre(layer, feature_idx, mask):
    sae = get_nev_it_sae_suite(layer=layer)
    resid = resids_pre[layer]

    def f(resid):
        batch_token_feat = resids_to_weights(resid, sae)[:, :, feature_idx] * sae_grads_resid[layer][:, :, feature_idx]
        token_act = mask_average(batch_token_feat, mask)
        return token_act

    return jax.grad(f)(resid)

def pre_error_to_pre(layer, mask):
    sae = get_nev_it_sae_suite(layer=layer)
    resid = resids_pre[layer]

    def f(resid):
        _, _, recon = sae_encode_gated(sae, resid)
        batch_token_error = jnp.einsum("...f, ...f -> ...", (resid - recon), grad_pre[layer])
        token_grad = mask_average(batch_token_error, mask)
        return token_grad

    return jax.grad(f)(resid)


In [30]:
def ie_pre_to_transcoder_features(layer, grad, mask):
    sae = get_nev_it_sae_suite(layer=layer, label="transcoder")
    resid_mid = resids_mid[layer]
    resid_mid = mlp_normalize(layer, resid_mid)
    ie = sfc_simple(grad, resid_mid, resid_mid, sae)[0]
    ie = mask_average(ie, mask)

    return ie

def ie_pre_to_transcoder_error(layer, grad, mask):
    sae = get_nev_it_sae_suite(layer=layer, label="transcoder")
    resid_next = resids_pre[layer + 1]
    resid_mid = resids_mid[layer]
    ie = sfc_simple(grad, mlp_normalize(layer, resid_mid), resid_next - resid_mid, sae)[1]
    ie = mask_average(ie, mask)

    return ie

In [31]:
def ie_mid_to_attn_features(layer, grad, mask):
    sae = get_nev_it_sae_suite(layer=layer, label="attn_out")
    resid_mid = resids_mid[layer]
    resid_pre = resids_pre[layer]

    ie = sfc_simple(grad, resid_mid - resid_pre, resid_mid - resid_pre, sae)[0]
    ie = mask_average(ie, mask)
    return ie

def ie_mid_to_attn_error(layer, grad, mask):
    sae = get_nev_it_sae_suite(layer=layer, label="attn_out")
    resid_mid = resids_mid[layer]
    resid_pre = resids_pre[layer]

    ie = sfc_simple(grad, resid_mid - resid_pre, resid_mid - resid_pre, sae)[1]
    ie = mask_average(ie, mask)
    return ie

In [32]:
def ie_pre_to_pre_features(layer, grad, mask):
    sae = get_nev_it_sae_suite(layer=layer)
    resid = resids_pre[layer]
    ie = sfc_simple(grad, resid, resid, sae)[0]
    ie = mask_average(ie, mask)
    return ie

def ie_pre_to_pre_error(layer, grad, mask):
    sae = get_nev_it_sae_suite(layer=layer)
    resid = resids_pre[layer]
    ie = sfc_simple(grad, resid, resid, sae)[1]
    ie = mask_average(ie, mask)
    return ie
# float((ie_pre_to_pre_features(6, grad_pre[6], "arrow") - mask_average(ie_error_resid[6], "arrow")).sum())

In [33]:
ie_pre_to_pre_features(6, grad_pre[6], "prompt")

In [33]:
def grad_through_transcoder(layer, grad):
    sae = get_nev_it_sae_suite(layer, label="transcoder")
    resid_mid = resids_mid[layer]

    def f(resid_mid):
        resid_mid = mlp_normalize(layer, resid_mid)
        # we ignore error nodes
        weights = resids_to_weights(resid_mid, sae)
        recon = weights_to_resid(weights, sae)

        return recon

    grad = jax.vjp(f, resid_mid)[1](grad,)[0]

    return grad


In [34]:
def grad_through_mlp(layer, grad):
    mlp = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer).at_instances_of(pz.nn.Residual).pick_nth_selected(1).get().delta
    resid_mid = resids_mid[layer]
    def f(resid_mid):
        resids_mid = pz.nx.wrap(resid_mid, "batch", "seq", "embedding")
        out = mlp(resids_mid)
        return out.unwrap("batch", "seq", "embedding").astype(resid_mid.dtype)
    return jax.vjp(f, resid_mid)[1](grad.astype(resid_mid.dtype),)[0]

In [35]:
def grad_through_attn(layer, grad):
    subblock = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer).at_instances_of(pz.nn.Residual).pick_nth_selected(0).get().delta

    si_selection = subblock.select().at_instances_of(pz.de.HandledSideInputRef)
    keys = sorted(set([ref.tag for ref in si_selection.get_sequence()]))
    replaced = si_selection.apply(lambda ref: pz.de.SideInputRequest(tag=ref.tag))
    subblock = pz.de.WithSideInputsFromInputTuple.handling(replaced, keys)

    side_inputs = {
        'positions': llama_inputs.positions,
        'attn_mask': llama_inputs.attention_mask
    }

    def f(resid):
        resid_pre = pz.nx.wrap(resid, "batch", "seq", "embedding")
        attn_out = subblock((resid_pre,) + tuple(side_inputs[tag] for tag in subblock.side_input_tags))

        attn_out = attn_out.unwrap("batch", "seq", "embedding") 

        return attn_out.astype(resid.dtype)

    resid = resids_pre[layer]
    return jax.vjp(f, resid)[1](grad.astype(resid.dtype),)[0]

In [37]:
ies = ie_resid[12]
ies = mask_average(ies, "arrow")

In [38]:
jax.lax.top_k(ies, 30)

In [39]:
from tqdm import tqdm, trange

def compute_edges(
    feature,
    mask,
    layer_window=3,
    k = 32,
    abs_effects = False,
):
    feature_type, layer, feature_idx = feature
    match feature_type:
        case "r":
            resid_grad = pre_feature_to_pre(layer, feature_idx, mask)
        case "t":
            resid_grad = transcoder_feature_to_mid(layer, feature_idx, mask)
        case "a":
            resid_grad = attn_out_feature_to_pre(layer, feature_idx, mask)
        case "er":
            resid_grad = pre_error_to_pre(layer, mask)
        case "et":
            resid_grad = transcoder_error_to_mid(layer, mask)
        case "ea":
            resid_grad = attn_out_error_to_pre(layer, mask)
    feature_effects = {}
    for l in range(layer, max(5, layer - 1 - layer_window), -1):
        if l < layer:
            for mask in masks:
                feature_effects[("t", l, mask)] = ie_pre_to_transcoder_features(l, resid_grad, mask)
                feature_effects[("et", l, mask)] = ie_pre_to_transcoder_error(l, resid_grad, mask)
        # # does not work # resid_grad = resid_grad - grad_through_mlp(layer, resid_grad)
        # resid_grad = resid_grad + grad_through_mlp(layer, resid_grad)
        if l < layer or feature_type in ("t", "et"):
            for mask in masks:
                feature_effects[("a", l, mask)] = ie_mid_to_attn_features(l, resid_grad, mask)
                feature_effects[("ea", l, mask)] = ie_mid_to_attn_error(l, resid_grad, mask)
        # # does not work # resid_grad = resid_grad - grad_through_attn(layer, resid_grad)
        # resid_grad = resid_grad + grad_through_attn(layer, resid_grad)
        if l < layer or feature_type in ("t", "et", "a", "ea"):
            for mask in masks:
                feature_effects[("r", l, mask)] = ie_pre_to_pre_features(l, resid_grad, mask)
                feature_effects[("er", l, mask)] = ie_pre_to_pre_error(l, resid_grad, mask)
    top_effects = []
    for key, features in feature_effects.items():
        if features.ndim == 0:
            top_effects.append((float(features), key, 0))
            continue
        effects, indices = jax.lax.top_k(features if not abs_effects else jnp.abs(features), k)
        for i, e in zip(indices.tolist(), effects.tolist()):
            top_effects.append((e, key, i))
    top_effects.sort(reverse=True)
    return top_effects[:k]


In [40]:
important_feats_masks = {}
for mask in masks:
    all_important_features = []
    ie_arrays = dict(r=ie_resid, a=ie_attn, t=ie_transcoder)
    k = 4
    for layer in layers[1:-1]:
        for type in "rat":
            all_important_features.append(("e" + type, layer, 0))
            w, i = jax.lax.top_k(mask_average(ie_arrays[type][layer], mask), k)
            for feat in i.tolist():
                all_important_features.append((type, layer, feat))
    important_feats_masks[mask] = all_important_features

In [41]:
from tqdm.auto import tqdm
graph = []
for mask, features in tqdm(important_feats_masks.items()):
    for downstream_feature in tqdm(features):
        edges = compute_edges(downstream_feature, mask, layer_window=2)
        graph.extend([(weight, upstream_feature_key + (upstream_feature,), downstream_feature[:2] + (mask,) + downstream_feature[2:])
                      for weight, upstream_feature_key, upstream_feature in edges])

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/135 [00:00<?, ?it/s]

  0%|          | 0/135 [00:00<?, ?it/s]

  0%|          | 0/135 [00:00<?, ?it/s]

  0%|          | 0/135 [00:00<?, ?it/s]

  0%|          | 0/135 [00:00<?, ?it/s]

In [42]:
downstream_feature

In [43]:
edges[:4]

In [44]:
sorted_graph = sorted(graph, reverse=True, key=lambda x: x[0])

In [45]:
n_nodes = sum(map(len, important_feats_masks.values()))
k_connections = 4
weight_threshold = sorted_graph[n_nodes * k_connections][0]

In [46]:
sum(len())

TypeError: len() takes exactly one argument (0 given)

In [None]:
from IPython.display import Image as IImage
from tqdm.auto import tqdm
import graphviz
import random
import math

digraph = graphviz.Digraph(graph_attr={"overlap": "scale", "splines": "ortho", "engine": "sfdp"})
edges_out = set()
penwidth = 2
max_w = sorted_graph[0][0]

def feature_name(key):
    return f"{key[0]}_{key[1]}_{key[2]}_{key[3]}"


def key_to_node(key):
    name = feature_name(key)
    shape = {"r": "ellipse", "a": "egg", "t": "box", "e": "diamond"}[key[0][0]]
    color = {"r": "red", "a": "blue", "t": "green"}[key[0][-1]]

    digraph.node(name, label=name, shape=shape, color=color) 
    

ws = []
for w, uf, df in graph:
    if w < weight_threshold:
        continue

    u_name = feature_name(uf)
    d_name = feature_name(df)

    log_weight = penwidth * (math.log2(w) - math.log2(weight_threshold)) / (math.log2(max_w) - math.log2(weight_threshold))
    ws.append(log_weight)

    digraph.edge(u_name, d_name, color="black", penwidth=f"{log_weight:.4f}")

    key_to_node(uf)
    key_to_node(df)



digraph.render("data/icl-sfc", format="png", cleanup=True)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
dot: graph is too large for cairo-renderer bitmaps. Scaling by 0.70213 to fit


In [None]:
max_w

In [49]:
import json
with open("micrlhf-progress/graph-short.json", 'w') as f:
    json.dump(graph[:100], f)

In [76]:
@jax.jit
def ablated_metric(llama_ablated):
    ablated_logits = llama_ablated(llama_inputs)
    return metric_fn(ablated_logits.unwrap("batch", "seq", "vocabulary"), None, train_tokens)

In [None]:
jnp.abs()

In [74]:
from micrlhf.llama import LlamaMLP
from typing import Dict, List
import dataclasses


@pz.pytree_dataclass
class AblatedModule(pz.Layer):
    module: pz.Layer
    sae: dict
    keep_features: Dict[str, jax.typing.ArrayLike]

    @classmethod
    def wrap(self, sae, ablated_features, module=None):
        if module is None:
            module = pz.nn.Identity()
        return AblatedModule(module, sae, ablated_features)

    def __call__(self, x):
        inp = x.unwrap("batch", "seq", "embedding")
        out = self.module(x)
        result = 0
        out = out.unwrap("batch", "seq", "embedding")
        for mask, mask_values in masks.items():
            _, _, recon = sae_encode_gated(self.sae, inp)
            error = out - recon
            _, _, recon_ablated = sae_encode_gated(self.sae, inp, keep_features=self.keep_features[mask])
            res = recon_ablated + error
            res = res * mask_values[..., None]
            result = result + res
        return pz.nx.wrap(result, "batch", "seq", "embedding")


llama_ablated = llama
for layer in tqdm(layers):
    block_selection = llama_ablated.select().at_instances_of(LlamaBlock).pick_nth_selected(layer)

    def converter(block):
        try:
            attn_out = get_nev_it_sae_suite(layer=layer, label="attn_out")
            mask_attn_out = mask_ie(ie_attn[layer])
            block = block.select().at_instances_of(LlamaAttention).apply(lambda x: pz.nn.Sequential([x, AblatedModule.wrap(attn_out, mask_attn_out)]))
        except KeyError:
            pass

        try:
            transcoder = get_nev_it_sae_suite(layer=layer, label="transcoder")
            mask_transcoder = mask_ie(ie_transcoder[layer])
            block = block.select().at_instances_of(LlamaMLP).apply(lambda x: AblatedModule.wrap(transcoder, mask_transcoder, x))
        except KeyError:
            pass
        return block

    llama_ablated = block_selection.apply(converter)

  0%|          | 0/11 [00:00<?, ?it/s]

In [75]:

float(ablated_metric(llama_ablated))

In [None]:
r = 0
for k, v in masks.items():
    r = r + v
int(r.max())

In [None]:
def run_with_ablation(ablation_mask):
    