In [1]:
import os
# Move to the repository root (the directory containing `models/`) when the
# notebook was started elsewhere; assumes the notebook sits two levels below it.
if "models" not in set(os.listdir(os.curdir)):
    os.chdir(os.path.join(os.pardir, os.pardir))

In [2]:
# Reload edited modules (e.g. micrlhf, sprint) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
import penzai
import jax_smi
# NOTE(review): presumably starts jax-smi device-memory tracking — confirm.
jax_smi.initialise_tracking()
from penzai import pz
# Penzai notebook integration. NOTE(review): the names suggest treescope becomes
# the default renderer, plus an autovisualize magic and an interactive context;
# confirm against the penzai docs.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [3]:
from micrlhf.llama import LlamaTransformer
# Gemma-2B-it weights from a local GGUF file, loaded eagerly onto device "tpu:0".
llama = LlamaTransformer.from_pretrained(
    "models/gemma-2b-it.gguf",
    from_type="gemma",
    load_eager=True,
    device_map="tpu:0",
)

In [4]:
from transformers import AutoTokenizer
# Gemma tokenizer (loaded from the base-model repo; the weights above are the -it GGUF).
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
# Pad on the right so each prompt's real tokens start at position 0.
tokenizer.padding_side = "right"

In [5]:
from sprint.task_vector_utils import load_tasks, ICLRunner
# `tasks` is indexed by task name (e.g. "es_en" below); each entry is used like a
# dict of source word -> target word in the filtering cell.
tasks = load_tasks()

In [6]:
def check_if_single_token(token):
    """Return True when the module-level `tokenizer` splits `token` into exactly one token."""
    pieces = tokenizer.tokenize(token)
    return len(pieces) == 1

task_name = "es_en"

task = tasks[task_name]
print(len(task))

# Keep only pairs whose source and target each tokenize to exactly one token.
task = dict(
    pair
    for pair in task.items()
    if all(map(check_if_single_token, pair))
)
print(len(task))

pairs = list(task.items())

763
346


In [26]:
from sprint.task_vector_utils import logprob_loss
from functools import partial

# NOTE(review): 3978 is presumably the token id separating input and answer in
# the ICL prompts — confirm against `tokenizer`.
sep = 3978
# NOTE(review): assumed to be the tokenizer's pad id — confirm via `tokenizer.pad_token_id`.
pad = 0


def metric_fn(logits, resids, tokens):
    """Scalar loss for attribution: `logprob_loss` of `logits` on `tokens`.

    `resids` is accepted to match the `metric(logits, resids, tokens)` call made
    by `run_with_add`, but is not used.
    """
    loss_kwargs = dict(sep=sep, pad_token=pad, n_first=2)
    return logprob_loss(logits, tokens, **loss_kwargs)

In [62]:
from micrlhf.llama import LlamaBlock, LlamaAttention
from micrlhf.utils.activation_manipulation import ActivationAddition, wrap_vector
from functools import partial
import jax.numpy as jnp
from penzai import pz
import jax

@partial(jax.jit, static_argnames=("metric", "batched"))
def run_with_add(additions_pre, additions_mid, tokens, metric, batched=False, llama=None):
    """Run `llama` on `tokens` with additive perturbations and collect intermediates.

    The model is instrumented in two passes:
      1. `TellIntermediate` taps record the activation entering every
         `LlamaBlock` (tag `resid_pre_{i}`), entering the second
         `pz.nn.Residual` inside each block (tag `resid_mid_{l}`), and the
         output of each attention `Softmax` (tag `attn_{i}`). All tagged values
         are gathered by `CollectingSideOutputs`.
      2. `ActivationAddition` layers add `additions_pre[i]` in front of block
         `i` and `additions_mid[l]` in front of the second residual of block
         `l`. Each addition is inserted inside the wrapper created by pass 1, so
         it sits *after* the corresponding tap: the recorded activations do not
         include the addition.
    Because the additions are additive, the gradient of the metric w.r.t. an
    addition equals the gradient w.r.t. the activation at that site.

    Args:
      additions_pre: per-block arrays indexed by block number, wrapped with
        axes ("batch",)? "seq", "embedding" depending on `batched`.
      additions_mid: same layout as `additions_pre`, for the mid-block site.
      tokens: token ids; always wrapped as ("batch", "seq").
      metric: static callable `metric(logits, resids, tokens)`; receives logits
        unwrapped to ("batch", "seq", "vocabulary").
      batched: static; whether the additions carry a leading "batch" axis.
      llama: the penzai model to instrument.

    Returns:
      `(metric_value, (logits, resids[::3], resids[1::3], resids[2::3]))`,
      where `resids` is the list of collected side outputs.
    """
    # Pass 1a: tap the activation entering every LlamaBlock.
    get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
        pz.nn.Sequential([
            pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
            x
        ])
    )
    # Pass 1b: tap the input of every Residual after the first one (index 0 is
    # left untouched) inside each block.
    # NOTE(review): assumes each block has exactly two Residuals (attention, MLP) — confirm.
    get_resids = get_resids.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda l, b: b.select().at_instances_of(pz.nn.Residual).apply_with_selected_index(lambda i, x: x if i == 0 else pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_mid_{l}"),
        x,
    ])))


    # Pass 1c: tap the output of the Softmax inside each LlamaAttention.
    get_resids = get_resids.select().at_instances_of(LlamaAttention).apply_with_selected_index(lambda i, x: x.select().at_instances_of(pz.nn.Softmax).apply(lambda b: pz.nn.Sequential([
        b,
        pz.de.TellIntermediate.from_config(tag=f"attn_{i}"),
    ])))

    # Collect every tagged intermediate (the predicate accepts all tags); the
    # handled model then returns (output, side_outputs).
    get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: True)
    # Pass 2a: add additions_pre[i] in front of block i (after the resid_pre tap).
    make_additions = get_resids.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
        pz.nn.Sequential([
            ActivationAddition(pz.nx.wrap(additions_pre[i], *(("batch",) if batched else ()), "seq", "embedding"), "all"),
            x
        ])
    )
    # Pass 2b: same for the mid-block site (after the resid_mid tap).
    make_additions = make_additions.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda l, b: b.select().at_instances_of(pz.nn.Residual).apply_with_selected_index(lambda i, x: x if i == 0 else pz.nn.Sequential([
        ActivationAddition(pz.nx.wrap(additions_mid[l], *(("batch",) if batched else ()), "seq", "embedding"), "all"),
        x,
    ])))
    # NOTE(review): tokens are wrapped with a "batch" axis even when batched=False,
    # so 1-D token input looks unsupported — confirm.
    tokens_wrapped = pz.nx.wrap(tokens, "batch", "seq")
    logits, resids = make_additions(llama.inputs.from_basic_segments(tokens_wrapped))
    # Split side outputs by stride 3. NOTE(review): relies on the taps firing once
    # per block in the order resid_pre, attn, resid_mid — confirm.
    return metric(logits.unwrap("batch", "seq", "vocabulary"), resids, tokens), (logits, resids[::3], resids[1::3], resids[2::3])


@partial(jax.jit, static_argnames=("metric",))
def get_metric_resid_grad(tokens, llama=llama, metric=metric_fn):
    """Metric value, recorded activations and residual-stream gradients for `tokens`.

    Runs `run_with_add` with all-zero additions and differentiates the metric
    w.r.t. those additions (argnums 0 and 1). This yields the metric's gradient
    at each block input ("pre") and at each mid-block site ("mid").

    Returns:
      `(metric_value, resids_pre, resids_mid, qk, grad_pre, grad_mid)`.
      - Residual lists are unwrapped to ("batch", "seq", "embedding").
      - `qk` is unwrapped to ("batch", "kv_heads", "q_rep", "seq", "kv_seq").
      - `grad_pre` / `grad_mid` are per-layer lists, each entry shaped
        `tokens.shape + (hidden_size,)`.
    """
    hidden = llama.config.hidden_size
    zero_additions = [
        jnp.zeros(tokens.shape + (hidden,))
        for _ in range(llama.config.num_layers)
    ]
    grad_fn = jax.value_and_grad(run_with_add, argnums=(0, 1), has_aux=True)
    (metric_value, aux), (grad_pre, grad_mid) = grad_fn(
        zero_additions,
        zero_additions,
        tokens,
        metric,
        batched=tokens.ndim > 1,
        llama=llama,
    )
    _logits, resids_pre, qk, resids_mid = aux

    def unwrap_all(side_outputs, *axes):
        # Each collected side output exposes the tagged NamedArray as `.value`.
        return [entry.value.unwrap(*axes) for entry in side_outputs]

    return (
        metric_value,
        unwrap_all(resids_pre, "batch", "seq", "embedding"),
        unwrap_all(resids_mid, "batch", "seq", "embedding"),
        unwrap_all(qk, "batch", "kv_heads", "q_rep", "seq", "kv_seq"),
        grad_pre,
        grad_mid,
    )


In [45]:
# ICL prompt configuration passed to ICLRunner: prompts per batch, demonstrations
# per prompt, maximum prompt length in tokens, and sampling seed.
batch_size, n_shot, max_seq_len, seed = 8, 20, 128, 10

In [46]:
runner = ICLRunner(
    task_name,
    pairs,
    batch_size=batch_size,
    n_shot=n_shot,
    max_seq_len=max_seq_len,
    seed=seed,
)

In [47]:
from sprint.task_vector_utils import tokenized_to_inputs

# Token ids of the ICL prompts built from the training split.
train_tokens = runner.get_tokens(runner.train_pairs, tokenizer)["input_ids"]

In [48]:
# train_tokens = jnp.asarray(train_tokens)
# train_tokens = jax.device_put(train_tokens, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
# train_tokens = pz.nx.wrap(train_tokens, "batch", "seq").untag("batch").tag("batch")


In [63]:
(
    metric_value,
    resids_pre,
    resids_mid,
    qk,
    grad_pre,
    grad_mid,
) = get_metric_resid_grad(train_tokens, llama=llama)

{'seq': 128, 'batch': 8, 'kv_heads': 1, 'q_rep': 8, 'kv_seq': 128}


In [69]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite

In [16]:
# Layer whose residual-stream SAE is loaded next. NOTE: the attribution loops
# further down reuse `layer` as their loop variable and overwrite it.
layer: int = 12

In [17]:
# Residual-stream SAE for `layer`; the call fetches the weights over the network
# (see the download log in the output).
sae = get_nev_it_sae_suite(
    layer=layer,
)

--2024-07-11 19:16:51--  https://huggingface.co/nev/gemma-2b-saex-test/resolve/main/it-l12-residual-test-run-1-2.00E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.157.142.74, 108.157.142.55, 108.157.142.53, ...
Connecting to huggingface.co (huggingface.co)|108.157.142.74|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/41/2a/412ab25f82137bafe7bc1655651e4e6f7eeae46d2504fc103b2b3624ff745ff8/482902b98aceb9c73042f4ebd95aa5742af2d793c8812407b806a3166900a2ef?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1720984611&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMDk4NDYxMX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzQxLzJhLzQxMmFiMjVmODIxMzdiYWZlN2JjMTY1NTY1MWU0ZTZmN2VlYWU0NmQyNTA0ZmMxMDNiMmIzNjI0ZmY3NDVmZjgv

In [77]:
# `results` is not defined anywhere in this notebook (stale kernel state). The
# attribution outputs are unpacked into named variables instead, and
# `resids_pre` is the second element of `get_metric_resid_grad`'s return tuple.
resid = resids_pre[layer]

In [78]:
# Replaces a no-op `resid = resid`; show the shape of the selected activations instead.
resid.shape

In [79]:
# Display `resid` (selected two cells above).
resid

In [65]:
# grad_mid[0]  # 8, 128, 2048

In [67]:
def weights_to_resid(weights, sae_params=None):
    """Decode SAE feature activations back into the residual stream.

    Args:
      weights: feature activations indexed "bsf" (batch, seq, feature) by the
        einsum below.
      sae_params: SAE parameter dict. It needs "W_dec" (indexed feature x model
        dim) and, for gated SAEs, "s_gate", "b_gate" and optionally
        "scaling_factor". Defaults to the module-level `sae`, which preserves
        the previous behaviour; pass it explicitly to decode with a specific
        SAE instead of whatever global `sae` currently holds.

    Returns:
      The reconstruction, indexed (batch, seq, model dim).
    """
    # Read the global only when no SAE is given, so callers can opt out of the
    # hidden dependency on the module-level `sae`.
    params = sae if sae_params is None else sae_params
    if "s_gate" in params:
        # Gated SAE: transform as (w > 0) * relu(w * softplus(s_gate) * scale + b_gate).
        # NOTE(review): confirm this matches what `sae_encode_gated` returns as
        # `post_relu`; if that is already the final activation, the transform is
        # applied twice.
        weights = (weights > 0) * jax.nn.relu(
            weights * jax.nn.softplus(params["s_gate"]) * params.get("scaling_factor", 1.0)
            + params["b_gate"]
        )
    else:
        weights = jax.nn.relu(weights)

    return jnp.einsum("fv,bsf->bsv", params["W_dec"], weights)

In [66]:
from micrlhf.utils.load_sae import sae_encode_gated



In [72]:
from tqdm.auto import tqdm

def sfc_simple(grad, resid, sae):
    """Per-feature gradient x activation for the SAE latents of `resid`.

    Encodes `resid` with `sae` and pulls `grad` back through `weights_to_resid`
    into feature space. In the calls below, `grad` is the metric's gradient
    w.r.t. the residual stream. The result is multiplied elementwise with the
    feature activations.

    NOTE(review): `weights_to_resid` decodes with the module-level `sae`, not
    this function's `sae` argument. The two only agree because the calling
    loops assign the global before each call.
    """
    _, feature_acts, _ = sae_encode_gated(sae, resid)
    _, pullback = jax.vjp(weights_to_resid, feature_acts)
    (feature_grad,) = pullback(grad)
    return feature_grad * feature_acts
layers = [6, *range(8, 17)]

# Attention-output SAEs: the attention contribution is resid_mid - resid_pre,
# and the gradient is the one taken at the mid-block site.
for layer in tqdm(layers):
    r_pre = resids_pre[layer]
    r_mid = resids_mid[layer]
    g_mid = grad_mid[layer]
    sae = get_nev_it_sae_suite(layer=layer, label="attn_out")
    indirect_effects = sfc_simple(g_mid, r_mid - r_pre, sae)
    # Count, per (batch, seq) position, features with a *positive* effect.
    # NOTE(review): the residual loop below counts any non-zero effect
    # (`!= 0`); confirm which comparison is intended.
    display((indirect_effects > 0).sum(-1))

# Residual-stream SAEs at the input of each block.
for layer in tqdm(layers):
    r_pre = resids_pre[layer]
    g_pre = grad_pre[layer]
    sae = get_nev_it_sae_suite(layer=layer)
    indirect_effects = sfc_simple(g_pre, r_pre, sae)
    # Count, per (batch, seq) position, features with a non-zero effect.
    display((indirect_effects != 0).sum(-1))

[[1 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [1 0 0 ... 0 0 0]
 ...
 [1 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]


--2024-07-11 20:40:07--  https://huggingface.co/nev/gemma-2b-saex-test/resolve/main/it-l8-attn_out-test-run-1-2.00E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.157.142.50, 108.157.142.74, 108.157.142.53, ...
Connecting to huggingface.co (huggingface.co)|108.157.142.50|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/41/2a/412ab25f82137bafe7bc1655651e4e6f7eeae46d2504fc103b2b3624ff745ff8/818e2aaa054ff1666e77883325a85a3494c975a70e2cc4bbcffed883ae08851e?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1720989607&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMDk4OTYwN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzQxLzJhLzQxMmFiMjVmODIxMzdiYWZlN2JjMTY1NTY1MWU0ZTZmN2VlYWU0NmQyNTA0ZmMxMDNiMmIzNjI0ZmY3NDVmZjgvO

[[0 1 1 ... 0 0 0]
 [1 0 1 ... 0 0 0]
 [1 0 0 ... 0 0 0]
 ...
 [1 0 0 ... 0 0 0]
 [1 1 0 ... 0 0 0]
 [1 0 0 ... 0 0 0]]


--2024-07-11 20:40:34--  https://huggingface.co/nev/gemma-2b-saex-test/resolve/main/it-l9-attn_out-test-run-1-2.00E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.157.142.50, 108.157.142.55, 108.157.142.74, ...
Connecting to huggingface.co (huggingface.co)|108.157.142.50|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/41/2a/412ab25f82137bafe7bc1655651e4e6f7eeae46d2504fc103b2b3624ff745ff8/4deb1770aadfd566e0e6af58ee77509609d4727efff621cc6a58a632fd9d3faa?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1720989635&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMDk4OTYzNX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzQxLzJhLzQxMmFiMjVmODIxMzdiYWZlN2JjMTY1NTY1MWU0ZTZmN2VlYWU0NmQyNTA0ZmMxMDNiMmIzNjI0ZmY3NDVmZjgvN

[[0 3 0 ... 0 0 0]
 [4 1 2 ... 0 0 0]
 [5 5 5 ... 0 0 0]
 ...
 [4 3 5 ... 0 0 0]
 [4 2 3 ... 0 0 0]
 [4 3 3 ... 0 0 0]]


--2024-07-11 20:41:02--  https://huggingface.co/nev/gemma-2b-saex-test/resolve/main/it-l10-attn_out-test-run-1-2.00E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.157.142.50, 108.157.142.74, 108.157.142.53, ...
Connecting to huggingface.co (huggingface.co)|108.157.142.50|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/41/2a/412ab25f82137bafe7bc1655651e4e6f7eeae46d2504fc103b2b3624ff745ff8/39babf724ff5dedfd120501507cbf321ff1852aa39e246167bb0bd41dbabfc61?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1720989662&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMDk4OTY2Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzQxLzJhLzQxMmFiMjVmODIxMzdiYWZlN2JjMTY1NTY1MWU0ZTZmN2VlYWU0NmQyNTA0ZmMxMDNiMmIzNjI0ZmY3NDVmZjgv

[[1 3 2 ... 0 0 0]
 [5 5 1 ... 0 0 0]
 [7 2 1 ... 0 0 0]
 ...
 [2 4 1 ... 0 0 0]
 [4 3 1 ... 0 0 0]
 [2 2 1 ... 0 0 0]]


--2024-07-11 20:41:29--  https://huggingface.co/nev/gemma-2b-saex-test/resolve/main/it-l11-attn_out-test-run-1-2.00E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.157.142.74, 108.157.142.55, 108.157.142.50, ...
Connecting to huggingface.co (huggingface.co)|108.157.142.74|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/41/2a/412ab25f82137bafe7bc1655651e4e6f7eeae46d2504fc103b2b3624ff745ff8/90ccc335056eda8de67fa88e7a14a45fb46333e2347919379cf691e6339bc861?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1720989689&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMDk4OTY4OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzQxLzJhLzQxMmFiMjVmODIxMzdiYWZlN2JjMTY1NTY1MWU0ZTZmN2VlYWU0NmQyNTA0ZmMxMDNiMmIzNjI0ZmY3NDVmZjgv

[[0 1 1 ... 0 0 0]
 [1 0 1 ... 0 0 0]
 [0 1 1 ... 0 0 0]
 ...
 [0 0 2 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 [0 1 0 ... 0 0 0]]


--2024-07-11 20:41:49--  https://huggingface.co/nev/gemma-2b-saex-test/resolve/main/it-l12-attn_out-test-run-1-2.00E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.157.142.74, 108.157.142.50, 108.157.142.53, ...
Connecting to huggingface.co (huggingface.co)|108.157.142.74|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/41/2a/412ab25f82137bafe7bc1655651e4e6f7eeae46d2504fc103b2b3624ff745ff8/ad7866788525c396bd6ad88d18634168f09cf11970f39c0cbe6745a00bc5f745?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1720989709&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMDk4OTcwOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzQxLzJhLzQxMmFiMjVmODIxMzdiYWZlN2JjMTY1NTY1MWU0ZTZmN2VlYWU0NmQyNTA0ZmMxMDNiMmIzNjI0ZmY3NDVmZjgv

[[2 1 1 ... 0 0 0]
 [2 1 0 ... 0 0 0]
 [1 1 2 ... 0 0 0]
 ...
 [1 0 1 ... 0 0 0]
 [2 2 1 ... 0 0 0]
 [1 1 0 ... 0 0 0]]


--2024-07-11 20:42:13--  https://huggingface.co/nev/gemma-2b-saex-test/resolve/main/it-l13-attn_out-test-run-1-2.00E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.157.142.74, 108.157.142.50, 108.157.142.53, ...
Connecting to huggingface.co (huggingface.co)|108.157.142.74|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/41/2a/412ab25f82137bafe7bc1655651e4e6f7eeae46d2504fc103b2b3624ff745ff8/b125d1b97b0c7c727a6dd58fa37891519e4f59b50b13f4bdf766d9282a859924?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1720989734&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMDk4OTczNH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzQxLzJhLzQxMmFiMjVmODIxMzdiYWZlN2JjMTY1NTY1MWU0ZTZmN2VlYWU0NmQyNTA0ZmMxMDNiMmIzNjI0ZmY3NDVmZjgv

[[5 1 5 ... 0 0 0]
 [5 0 0 ... 0 0 0]
 [5 1 0 ... 0 0 0]
 ...
 [5 3 2 ... 0 0 0]
 [5 5 5 ... 0 0 0]
 [5 4 2 ... 0 0 0]]


--2024-07-11 20:42:48--  https://huggingface.co/nev/gemma-2b-saex-test/resolve/main/it-l14-attn_out-test-run-1-2.00E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.51, 108.156.211.125, 108.156.211.95, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.51|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/41/2a/412ab25f82137bafe7bc1655651e4e6f7eeae46d2504fc103b2b3624ff745ff8/71d5a89d5d7288b6887aa7d7a6f1ff8ed9d7cb3b57be4b33b34b170156da05b6?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1720989768&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMDk4OTc2OH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzQxLzJhLzQxMmFiMjVmODIxMzdiYWZlN2JjMTY1NTY1MWU0ZTZmN2VlYWU0NmQyNTA0ZmMxMDNiMmIzNjI0ZmY3NDVmZjg

[[4 3 1 ... 0 0 0]
 [1 3 3 ... 0 0 0]
 [6 2 1 ... 0 0 0]
 ...
 [3 5 4 ... 0 0 0]
 [5 3 3 ... 0 0 0]
 [4 1 1 ... 0 0 0]]


...... .......... 93% 88.1M 1s
491800K .......... .......... .......... .......... .......... 93% 79.3M 1s
491850K .......... .......... .......... .......... .......... 93% 92.1M 1s
491900K .......... .......... .......... .......... .......... 93%  248M 1s
491950K .......... .......... .......... .......... .......... 93% 80.9M 1s
492000K .......... .......... .......... .......... .......... 93% 92.0M 1s
492050K .......... .......... .......... .......... .......... 93% 19.4M 1s
492100K .......... .......... .......... .......... .......... 93% 9.49M 1s
492150K .......... .......... .......... .......... .......... 93%  105M 1s
492200K .......... .......... .......... .......... .......... 93% 30.3M 1s
492250K .......... .......... .......... .......... .......... 93%  408M 1s
492300K .......... .......... .......... .......... .......... 93%  130M 1s
492350K .......... .......... .......... .......... .......... 93%  299M 1s
492400K .......... .......... .......... .......... .....

[[14  6 12 ...  0  0  0]
 [15  8 10 ...  0  0  0]
 [24 12  6 ...  0  0  0]
 ...
 [21 12 13 ...  0  0  0]
 [23 15 14 ...  0  0  0]
 [25  6 11 ...  0  0  0]]


--2024-07-11 20:43:15--  https://huggingface.co/nev/gemma-2b-saex-test/resolve/main/it-l16-attn_out-test-run-1-2.00E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.125, 108.156.211.95, 108.156.211.90, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.125|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/41/2a/412ab25f82137bafe7bc1655651e4e6f7eeae46d2504fc103b2b3624ff745ff8/63d2d3e062ca3027f5bc1c2bd5c47d9552fe02f4a2273e285f5d02f18702bbfa?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1720989795&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMDk4OTc5NX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzQxLzJhLzQxMmFiMjVmODIxMzdiYWZlN2JjMTY1NTY1MWU0ZTZmN2VlYWU0NmQyNTA0ZmMxMDNiMmIzNjI0ZmY3NDVmZj

In [89]:
# `post_relu` only exists inside `sfc_simple`, so it is undefined on a fresh
# kernel. Recompute it with the residual-stream `sae` and `layer` left by the
# last iteration of the loop above.
post_relu = sae_encode_gated(sae, resids_pre[layer])[1]
post_relu

In [90]:
# Recompute `post_relu` here so this cell does not rely on a name that only exists
# inside `sfc_simple`. Then pull the residual-stream gradient back into SAE
# feature space, mirroring `sfc_simple`. The cotangent is a gradient
# (`grad_pre`, matching the residual SAE in `sae`), not the undefined `results`.
post_relu = sae_encode_gated(sae, resids_pre[layer])[1]
sae_grad, = jax.vjp(weights_to_resid, post_relu)[1](grad_pre[layer])

In [91]:
# Display the feature-space gradient computed in the previous cell.
sae_grad

In [69]:
# This cell duplicated the display above; show the shape(s) instead. tree_map
# handles both a bare array and a tuple of arrays.
jax.tree_util.tree_map(lambda a: a.shape, sae_grad)

In [None]:
# dtype of `post_relu` (the post-ReLU output of `sae_encode_gated`).
post_relu.dtype

In [59]:
# Shape of `post_relu` (indexed "bsf" by the decoder einsum).
post_relu.shape     

In [37]:
# Display the loaded model tree.
llama