In [1]:
# %%
import penzai
from penzai import pz
# %%
import os
if "models" not in os.listdir("."):
    os.chdir("../..")

# %%

In [2]:
%load_ext autoreload
%autoreload 2

import json

from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector

In [3]:
import plotly.express as px
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True)

In [4]:
from transformers import AutoTokenizer
import jax


# %%
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
tokenizer.padding_side = "right"

In [5]:
# %%
from sprint.task_vector_utils import load_tasks, ICLDataset, ICLSequence
tasks = load_tasks()

In [6]:

# %%
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)


In [7]:
def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs


In [8]:
from sprint.task_vector_utils import ICLRunner, logprob_loss, get_tv, make_act_adder

from micrlhf.utils.load_sae import sae_encode, sae_encode_gated

from safetensors import safe_open

In [9]:

layer = 12


In [10]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite

sae = get_nev_it_sae_suite(layer)


In [11]:
# %%
sep = 3978
pad = 0


In [12]:
task_names = [
    "es_en"
]
n_seeds = 10

# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 20, 16, 256

prompt = "Follow the pattern:\n{}"

In [13]:
from sprint.task_vector_utils import ICLRunner, logprob_loss, get_tv, make_act_adder

from micrlhf.utils.load_sae import sae_encode, sae_encode_gated

from safetensors import safe_open
from sprint.task_vector_utils import FeatureSearch

seed = 10


collected_weights = []

for task in tqdm(task_names):
    pairs = list(tasks[task].items())

    n_shot = 40

    runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_few_shots-1, max_seq_len=max_seq_len, seed=seed, prompt=prompt)


    tokenized = runner.get_tokens([
        x[:n_shot] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, resids = get_resids_call(inputs)

    resids = resids[layer].value.unwrap(
        "batch", "seq", "embedding"
    )

    tv = get_tv(resids, train_tokens, shift = 0, sep=sep)

    tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]

    add_act = make_act_adder(llama, tv.astype('bfloat16'), tokens, layer, length=1, shift= 0, sep=sep)

    logits = add_act(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2, sep=sep, pad_token=0
    )

    print(
        f"TV: {task}, L: {layer}, Loss: {loss}"  
    )

    _, pr, rtv = sae_encode_gated(sae, tv)

    add_act = make_act_adder(llama, rtv.astype('bfloat16'), tokens, layer, length=1, shift= 0, sep=sep)

    logits = add_act(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2, sep=sep, pad_token=0
    )

    print(
        f"Recon TV: {task}, L: {layer}, Loss: {loss}"  
    )

    fs = FeatureSearch(task, pairs, layer, llama, tokenizer, n_shot=1, seed=seed+100, init_w=pr, early_stopping_steps=200, n_first=2, sep=sep, pad_token=0, sae_v=8, sae=sae, batch_size=24, iterations=400, prompt=prompt, l1_coeff=0.05)

    w, m = fs.find_weights()

    collected_weights.append(
        w                                           
    )

    # weights = (w > 0) * jax.nn.relu(w * jax.nn.softplus(sae["s_gate"]) + sae["b_gate"])   

    weights = jax.nn.relu(w)

    recon = jnp.einsum("fv,f->v", sae["W_dec"], weights) + sae["b_dec"]
    recon = recon.astype('bfloat16')

    add_act = make_act_adder(llama, recon, tokens, layer, length=1, shift= 0, sep=sep)

    logits = add_act(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2, sep=sep, pad_token=0
    )

    print(
        f"Recon fs: {task}, L: {layer}, Loss: {loss}"  
    )

    with open("cleanup_results.jsonl", "a") as f:
        item = {
            "task": task,
            "weights": w.tolist(),
            "loss": loss.tolist(),
            "tv": tv.tolist(),
        }

        f.write(json.dumps(item) + "\n")

  0%|          | 0/1 [00:00<?, ?it/s]

TV: es_en, L: 12, Loss: 3.1711559295654297
Recon TV: es_en, L: 12, Loss: 3.467447280883789


  0%|          | 0/400 [00:00<?, ?it/s]

Recon fs: es_en, L: 12, Loss: 3.6704518795013428


In [14]:
with open("cleanup_results.jsonl") as f:
    for line in f:
        data = json.loads(line)
        print(data["task"])
        print(data["loss"])
        w = jnp.array(data["weights"])
        display(jax.lax.top_k(w, 15))

en_es
3.0577216148376465


[Array([7.781368  , 7.687472  , 4.5929112 , 2.3076181 , 1.8058963 ,
        0.9904503 , 0.26809523, 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],      dtype=float32),
 Array([  330, 15886,  3079,  3196,  4643,  7800, 11586,     0,     1,
            2,     3,     4,     5,     6,     7], dtype=int32)]

en_es
2.7183780670166016


[Array([9.630995 , 4.7369976, 3.1871123, 2.4100933, 2.3127606, 1.0788292,
        0.8785468, 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       ], dtype=float32),
 Array([15886,  3079,  2207, 11836,  4643,  7018,  7800,     0,     1,
            2,     3,     4,     5,     6,     7], dtype=int32)]

en_fr
9.213409423828125


[Array([12.659199  ,  9.891936  ,  3.5718477 ,  2.161456  ,  1.9997382 ,
         0.20197022,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],      dtype=float32),
 Array([13367,   330,  7018,  3079,   369,  8618,     0,     1,     2,
            3,     4,     5,     6,     7,     8], dtype=int32)]

antonyms
1.9884886741638184


[Array([5.149931  , 4.4392924 , 4.424511  , 2.9623194 , 1.0613811 ,
        0.41651246, 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],      dtype=float32),
 Array([ 2207, 12017, 15392,  2468,  8226,  4871,     0,     1,     2,
            3,     4,     5,     6,     7,     8], dtype=int32)]

es_en
2.3643341064453125


[Array([5.7215486, 4.077082 , 1.6318746, 1.3926536, 0.       , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       ], dtype=float32),
 Array([ 883, 4061,  181, 1908,    0,    1,    2,    3,    4,    5,    6,
           7,    8,    9,   10], dtype=int32)]

en_it
6.378209114074707


[Array([10.971681  ,  6.341071  ,  4.6003017 ,  2.8350635 ,  0.43321684,
         0.36964452,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],      dtype=float32),
 Array([15886,  7018,  7800,   369,  4424,  3079,     0,     1,     2,
            3,     4,     5,     6,     7,     8], dtype=int32)]

fr_en
3.789034843444824


[Array([5.05083   , 4.956336  , 3.0223713 , 1.8414422 , 0.06206318,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],      dtype=float32),
 Array([  883,   181, 15288,  8684,  2207,     0,     1,     2,     3,
            4,     5,     6,     7,     8,     9], dtype=int32)]

location_continent
2.0601282119750977


[Array([11.718188  ,  8.453872  ,  7.960262  ,  6.295211  ,  6.0035276 ,
         3.7814338 ,  2.8840249 ,  2.545268  ,  0.85108435,  0.45151338,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ],      dtype=float32),
 Array([11896,  1685, 12805, 12803,   330,  2207, 15176, 15315,  6637,
         1236,     0,     1,     2,     3,     4], dtype=int32)]

location_language
2.383928060531616


[Array([9.012982 , 4.8249397, 2.3631427, 2.3532834, 1.392308 , 0.       ,
        0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
        0.       , 0.       , 0.       ], dtype=float32),
 Array([ 2417, 12975,  1236,  6637, 11552,     0,     1,     2,     3,
            4,     5,     6,     7,     8,     9], dtype=int32)]

person_profession
3.4659793376922607


[Array([5.0680375 , 4.5605073 , 3.216459  , 2.551515  , 1.98678   ,
        1.8147309 , 1.5620106 , 0.94064075, 0.16524048, 0.07524845,
        0.        , 0.        , 0.        , 0.        , 0.        ],      dtype=float32),
 Array([15885, 13589,  4760,   482,  1174, 16341,  6637,  8226,  1236,
         4478,     0,     1,     2,     3,     4], dtype=int32)]

es_en
2.475567102432251


[Array([4.174203  , 4.085805  , 2.9904673 , 2.4196987 , 2.383186  ,
        2.2566912 , 1.9916719 , 1.9025337 , 1.0599831 , 0.9431557 ,
        0.60823464, 0.5850021 , 0.08400264, 0.        , 0.        ],      dtype=float32),
 Array([20995, 25680,  8668,  4126, 17274, 31871, 28080, 28491, 27988,
         7028, 26180,  9258,  9497,     0,     1], dtype=int32)]

antonyms
2.1853365898132324


[Array([4.3000865 , 4.0614967 , 3.4217446 , 2.8159714 , 2.2184155 ,
        2.0036063 , 1.8057712 , 1.629368  , 0.8533016 , 0.49197406,
        0.34810975, 0.3253169 , 0.        , 0.        , 0.        ],      dtype=float32),
 Array([10472,  2670,  6462, 19190, 11906, 27988,  8668,  9258,  8159,
        20995, 15648,  9497,     0,     1,     2], dtype=int32)]

es_en
2.4363996982574463


[Array([4.991836  , 3.4801493 , 3.4285762 , 3.2275271 , 3.22435   ,
        3.2031236 , 3.1699693 , 2.7637427 , 1.4658012 , 0.8083773 ,
        0.60823464, 0.13834275, 0.01402632, 0.        , 0.        ],      dtype=float32),
 Array([25680, 17274, 28080,  8668, 28491,  4126, 31871, 20995, 27988,
        32590, 26180,  7028,  9497,     0,     1], dtype=int32)]

es_en
2.7721495628356934


[Array([4.981508  , 4.1341496 , 3.6995676 , 3.5769396 , 3.1237977 ,
        3.0046246 , 2.640297  , 1.8352877 , 1.7989156 , 0.5231221 ,
        0.34189785, 0.181504  , 0.        , 0.        , 0.        ],      dtype=float32),
 Array([10574,  6382,  1777,  6882,  9180, 31923, 25260, 18303, 10722,
        28088, 12773, 22236,     0,     1,     2], dtype=int32)]

es_en
3.3540031909942627


[Array([5.964342 , 5.4778223, 5.4725738, 5.438234 , 4.9570684, 4.7402883,
        4.710356 , 4.6115384, 4.248613 , 3.0572634, 2.3268917, 2.1025603,
        1.8345332, 1.5063884, 1.4786938], dtype=float32),
 Array([21043, 20013, 29597, 12039, 25421,  1272, 10807, 21523,  4560,
        27329,  3925, 30363,  5217,  2114, 17969], dtype=int32)]

es_en
2.688720464706421


[Array([7.1359415e+00, 7.1008859e+00, 6.5918088e+00, 6.4541144e+00,
        5.4941392e+00, 5.4563355e+00, 3.1585410e+00, 2.6410753e-03,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32),
 Array([ 8379,  1481,  2612, 12626,  8426, 15973, 22035,  6277,     0,
            1,     2,     3,     4,     5,     6], dtype=int32)]

es_en
3.6704518795013428


[Array([6.8068395 , 5.6830826 , 5.675987  , 5.1919236 , 0.09097664,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],      dtype=float32),
 Array([32345,  9157, 28537, 23723, 17160,     0,     1,     2,     3,
            4,     5,     6,     7,     8,     9], dtype=int32)]

In [15]:

from huggingface_hub import HfFileSystem

fs = HfFileSystem()
weights = fs.glob(f"nev/gemma-2b-saex-test/it-l{layer}-residual-test-run-1-*/*.safetensors")

In [16]:
weights

['nev/gemma-2b-saex-test/it-l12-residual-test-run-1-2.00E-05/sae_weights.safetensors']

In [17]:
data["weights"]

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0

In [18]:
sae["mean_norm"]

Array(28.125, dtype=bfloat16)

: 