In [1]:
import os
# Later cells open "models/..." with a relative path. If the current directory
# has no entry named "models", step two levels up before anything is loaded.
if "models" not in os.listdir(os.curdir):
    os.chdir(os.path.join(os.pardir, os.pardir))


In [2]:
from penzai import pz
import json

from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector


In [3]:
from micrlhf.llama import LlamaTransformer
# Load the model from a local GGUF file. The path is relative to the working
# directory, which the first cell moves to the folder containing "models".
# NOTE(review): from_type / load_eager are passed straight to from_pretrained;
# their exact effect is defined in micrlhf.llama, not here.
llama = LlamaTransformer.from_pretrained("models/gemma-2-2b-it.gguf",
                                         from_type="gemma2",
                                         load_eager=True
                                         )

In [4]:
from transformers import AutoTokenizer
import jax
# Tokenizer used for every prompt in this notebook.
# NOTE(review): the tokenizer comes from the "alpindale/gemma-2b" repo while the
# weights above are "gemma-2-2b-it" -- confirm the vocabularies match.
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
# Pad on the right so real tokens keep their positions from the start of the sequence.
tokenizer.padding_side = "right"

In [5]:
from sprint.task_vector_utils import load_tasks, ICLDataset, ICLSequence
# Mapping from task name to that task's data; the main loop reads
# tasks[task].items() to obtain the list of pairs for a task.
tasks = load_tasks()

In [6]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
def _tap_resid_pre(index, block):
    """Put a TellIntermediate tagged ``resid_pre_<index>`` in front of ``block``."""
    tap = pz.de.TellIntermediate.from_config(tag=f"resid_pre_{index}")
    return pz.nn.Sequential([tap, block])


def _is_resid_pre_tag(tag):
    """Select only the side-output tags created by ``_tap_resid_pre``."""
    return tag.startswith("resid_pre")


# Copy of the model in which every LlamaBlock is preceded by a tagged tap.
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(_tap_resid_pre)
# Collect the tagged values as side outputs of the call.
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=_is_resid_pre_tag)
# Jitted callable; the main loop unpacks its result as (output, side_outputs).
get_resids_call = jit_wrapper.Jitted(get_resids)


In [7]:
def tokenized_to_inputs(input_ids, attention_mask):
    """Turn tokenizer output into the input structure expected by ``llama``.

    Args:
        input_ids: Two-axis array-like of token ids, laid out as (batch, seq).
        attention_mask: Accepted so a tokenizer result can be passed with
            ``**tokenized``. It is not used: the inputs are built from the
            token ids alone.

    Returns:
        The value of ``llama.inputs.from_basic_segments`` for the token array,
        placed on the model's mesh with the batch axis partitioned over "dp"
        and the sequence axis over "sp".
    """
    # One sharding spec for the (batch, seq) token array.
    sharding = jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp"))
    token_array = jax.device_put(jnp.asarray(input_ids), sharding)
    # The untag/tag round trip on "batch" is kept exactly as the notebook had it.
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    return llama.inputs.from_basic_segments(token_array)



In [8]:
from safetensors import safe_open


# Token ids shared by the task-vector helpers.
# sep: forwarded as `sep=` to get_tv / make_act_adder / logprob_loss / FeatureSearch.
#      NOTE(review): presumably the id of the separator between an example's
#      input and output -- confirm against the tokenizer.
# pad: padding token id; the loss calls use pad_token with this same value.
sep, pad = 3978, 0


In [9]:
# Run configuration.
# n_seeds is not referenced by the other cells visible in this notebook.
n_seeds = 10

# Smaller setting currently in use; the larger alternative is
# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots = 20
batch_size = 16
max_seq_len = 256

# Prompt template handed to ICLRunner and FeatureSearch; "{}" is the placeholder.
prompt = "<start_of_turn>user\nFollow the pattern:\n{}"

In [10]:
# Tasks to evaluate; each name must be a key of `tasks`.
task_names = ["antonyms"]

In [11]:
from sprint.task_vector_utils import ICLRunner, logprob_loss, get_tv, make_act_adder

from micrlhf.utils.load_sae import sae_encode, get_dm_res_sae, weights_to_resid

from safetensors import safe_open
from sprint.task_vector_utils import FeatureSearch
from micrlhf.utils.ito import grad_pursuit

# Seed given to ICLRunner; FeatureSearch is seeded with seed + 100.
seed = 10

# Layers to evaluate. The cell used to build a sweep and overwrite it on the
# very next line; only the effective value is kept. To restore the sweep use:
# layers = list(range(8, llama.config.num_layers - 2, 2))
layers = [18]

# Forwarded as load_65k to get_dm_res_sae; also selects the l1_coeff / lr pair
# passed to FeatureSearch.
use_65k = True

In [14]:
def _steered_loss(vector, layer, inputs, tokens):
    """Score the eval prompts with ``vector`` injected at ``layer``.

    Builds an intervention with ``make_act_adder`` (length=1, shift=0, located
    via ``sep``), runs it on ``inputs`` and evaluates the resulting logits with
    ``logprob_loss`` using the same settings as the zero-shot baseline.

    Args:
        vector: Vector handed to ``make_act_adder`` unchanged; any dtype cast
            is the caller's responsibility.
        layer: Layer index handed to ``make_act_adder``.
        inputs: Model inputs for the eval prompts.
        tokens: Token ids of the eval prompts.

    Returns:
        The value returned by ``logprob_loss``.
    """
    add_act = make_act_adder(llama, vector, tokens, layer, length=1, shift=0, sep=sep)
    steered_logits = add_act(inputs)
    return logprob_loss(
        steered_logits.unwrap("batch", "seq", "vocabulary"),
        tokens, shift=0, n_first=2, sep=sep, pad_token=pad,
    )


for task in tqdm(task_names):
    pairs = list(tasks[task].items())

    # One shot is held back from n_few_shots; tasks whose name starts with
    # "algo" use a fixed 16 shots instead.
    n_shot = n_few_shots - 1
    if task.startswith("algo"):
        n_shot = 16

    runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_shot,
                       max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    # Few-shot prompts: residuals collected here are the source of the task vector.
    tokenized = runner.get_tokens([x[:n_shot] for x in runner.train_pairs], tokenizer)
    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)

    # Eval prompts: `inputs` / `tokens` refer to these from here on.
    tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]

    # Baseline: the unmodified model on the eval prompts.
    logits = llama(inputs)

    zero_loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=0, n_first=2, sep=sep, pad_token=pad
    )

    print(
        f"Zero: {task}, Loss: {zero_loss}"
    )

    for layer in layers:
        sae = get_dm_res_sae(layer, load_65k=use_65k)

        # NOTE(review): assumes all_resids is ordered so that index == layer -- confirm.
        resids = all_resids[layer].value.unwrap(
            "batch", "seq", "embedding"
        )

        # Task vector computed by get_tv from this layer's few-shot residuals.
        tv = get_tv(resids, train_tokens, shift=0, sep=sep)

        tv_loss = _steered_loss(tv.astype('bfloat16'), layer, inputs, tokens)

        print(
            f"TV: {task}, L: {layer}, Loss: {tv_loss}"
        )

        # `pr` later initialises FeatureSearch; `rtv` is evaluated as the
        # reconstructed task vector.
        pr, _, rtv = sae_encode(sae, tv)

        recon_loss = _steered_loss(rtv.astype('bfloat16'), layer, inputs, tokens)

        print(
            f"Recon TV: {task}, L: {layer}, Loss: {recon_loss}"
        )

        # NOTE(review): 20 is grad_pursuit's third positional argument --
        # presumably the number of pursuit steps; confirm.
        _, gtv = grad_pursuit(tv, sae["W_dec"], 20)

        ito_loss = _steered_loss(gtv.astype('bfloat16'), layer, inputs, tokens)

        print(
            f"Grad pursuit TV: {task}, L: {layer}, Loss: {ito_loss}"
        )

        fs = FeatureSearch(
            task, pairs, layer, llama, tokenizer,
            n_shot=1, seed=seed+100, init_w=pr, early_stopping_steps=50,
            n_first=2, sep=sep, pad_token=pad, sae_v=8, sae=sae,
            batch_size=24, iterations=1000, prompt=prompt,
            l1_coeff=0.003 if use_65k else 0.005,
            lr=0.15 if use_65k else 0.09,
        )

        # `m` is not used below; `w` is also read by the decode cell after this one.
        w, m = fs.find_weights()

        _, _, recon = sae_encode(sae, None, pre_relu=w)

        # recon = weights_to_resid(w, sae)

        # NOTE(review): unlike the three vectors above, `recon` is passed
        # without a bfloat16 cast; kept as in the original -- confirm intent.
        loss = _steered_loss(recon, layer, inputs, tokens)

        print(
            f"Recon fs: {task}, L: {layer}, Loss: {loss}"
        )

        # with open("cleanup_results_gemma_2_post.jsonl", "a") as f:
        #     item = {
        #         "task": task,
        #         "weights": w.tolist(),
        #         "loss": loss.tolist(),
        #         "recon_loss": recon_loss.tolist(),
        #         "ito_loss": ito_loss.tolist(),
        #         "tv_loss": tv_loss.tolist(),
        #         "zero_loss": zero_loss.tolist(),
        #         "tv": tv.tolist(),
        #         "layer": layer
        #     }

        #     f.write(json.dumps(item) + "\n")

  0%|          | 0/1 [00:00<?, ?it/s]

Zero: antonyms, Loss: 3.46875
TV: antonyms, L: 18, Loss: 2.40625
Recon TV: antonyms, L: 18, Loss: 3.89062
Grad pursuit TV: antonyms, L: 18, Loss: 2.98438


  0%|          | 0/1000 [00:00<?, ?it/s]

51 1000000000.0
50 51
49 50
49 49
47 49
46 47
43 46
41 43
40 41
40 40
39 40
39 39
39 39
37 39
36 37
36 36
36 36
36 36
36 36
36 36
36 36
36 36
34 36
34 34
34 34
34 34
34 34
34 34
34 34
33 34
33 33
31 33
31 31
31 31
31 31
31 31
30 31
30 30
30 30
30 30
30 30
30 30
30 30
30 30
30 30
30 30
30 30
30 30
30 30
30 30
29 30
29 29
29 29
28 29
28 28
28 28
28 28
27 28
27 27
26 27
25 26
25 25
25 25
25 25
25 25
25 25
25 25
25 25
25 25
25 25
25 25
25 25
25 25
25 25
24 25
24 24
23 24
23 23
23 23
23 23
23 23
23 23
22 23
22 22
22 22
22 22
22 22
22 22
22 22
21 22
20 21
20 20
20 20
20 20
20 20
20 20
20 20
20 20
20 20
20 20
20 20
20 20
20 20
19 20
19 19
19 19
19 19
19 19
19 19
19 19
19 19
18 19
18 18
18 18
17 18
17 17
17 17
17 17
17 17
17 17
17 17
17 17
17 17
17 17
17 17
16 17
16 16
16 16
15 16
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
15 15
14 15
14 14
13 14
13 13
13 13
13 13
13 13
13 13
13 13
13 13
13 13




Recon fs: antonyms, L: 18, Loss: 2.54688


In [31]:
# recon = weights_to_resid(w, sae)

# Manual decode of the feature weights `w` using `sae` and `w` left over from
# the last iteration of the loop cell: the feature axis ("f") of `w` is
# contracted with the decoder matrix W_dec ("f" x "v"); leading axes of `w`
# are preserved.
# NOTE(review): this is a plain linear map, whereas the loop reconstructs via
# sae_encode(sae, None, pre_relu=w) -- confirm whether that path applies
# anything extra (activation, bias, scaling) before comparing the two.
jnp.einsum("fv,...f->...v", sae["W_dec"], w)

Array([-1.3925353 ,  0.74642676,  0.14181848, ...,  2.5580206 ,
       -0.5393305 , -1.2404633 ], dtype=float32)