In [1]:
import os
# Make relative paths like "models/..." and "data/..." resolve: if no "models"
# entry is visible in the current directory (presumably the notebook was started
# from a subdirectory), move up one level.
if "models" not in os.listdir(os.curdir):
    os.chdir(os.pardir)


In [2]:
from penzai import pz
import json

from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector


In [3]:
from micrlhf.llama import LlamaTransformer
# Load the local GGUF checkpoint as a Gemma-2 architecture; load_eager=True is
# passed through to from_pretrained (presumably loads weights up front — TODO confirm).
llama = LlamaTransformer.from_pretrained(
    "models/gemma-2-2b-it.gguf",
    from_type="gemma2",
    load_eager=True,
)

In [4]:
from transformers import AutoTokenizer
import jax
# Tokenizer with right padding, so real tokens occupy the leading positions of
# every row and padding trails them.
# NOTE(review): tokenizer repo is gemma-2b while the model is gemma-2-2b — confirm vocabularies match.
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b", padding_side="right")

In [None]:
from sprint.task_vector_utils import ICLDataset, ICLSequence, load_tasks

# Presumably a mapping task name -> {input: output}; later cells call
# tasks[task].items() to get the (input, output) pairs — verify against load_tasks.
tasks = load_tasks()

In [6]:
import json, numpy as np

# with open("cleanup_results_gemma_2_algo.jsonl", "r") as f:
#     cleanup_results = [json.loads(line) for line in f][18:]


# with open("cleanup_results_gemma_2_all.jsonl", "r") as f:
#     tmp = [json.loads(line) for line in f]
#     tmp = [x for x in tmp if not x["task"].startswith("algo")]

#     cleanup_results += tmp

# One JSON object per line (JSON Lines). Read as UTF-8 explicitly (JSON's encoding)
# instead of the platform default, and skip blank lines, which json.loads rejects.
with open("data/cleanup_results_gemma_2_post.jsonl", "r", encoding="utf-8") as f:
    cleanup_results = [json.loads(line) for line in f if line.strip()]

In [7]:
# Residual-stream layers to analyse (only one for now).
# Full sweep used previously: [14, 16, 18, 20, 22, 24]
layers = [18]

In [8]:
from micrlhf.utils.load_sae import sae_encode, get_dm_res_sae

# Per-layer activation threshold read from each layer's 65k SAE; falls back to 0
# when the SAE dict has no "threshold" entry.
thresholds = dict(
    (layer_idx, get_dm_res_sae(layer_idx, load_65k=True).get("threshold", 0))
    for layer_idx in layers
)


In [None]:
task = "antonyms"
# Keep only the cleanup results for the selected task at the analysed layers.
task_results = list(
    filter(
        lambda result: result["layer"] in layers and result["task"] == task,
        cleanup_results,
    )
)

print(len(task_results))

In [10]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
def _prepend_resid_tap(block_index, block):
    """Put a TellIntermediate tagged ``resid_pre_{block_index}`` in front of ``block``."""
    tap = pz.de.TellIntermediate.from_config(tag=f"resid_pre_{block_index}")
    return pz.nn.Sequential([tap, block])


# Wrap every LlamaBlock so the value entering it is reported under a "resid_pre_*"
# tag, then collect those reports as side outputs and jit the resulting model.
get_resids = (
    llama.select()
    .at_instances_of(LlamaBlock)
    .apply_with_selected_index(_prepend_resid_tap)
)
get_resids = pz.de.CollectingSideOutputs.handling(
    get_resids,
    tag_predicate=lambda tag: tag.startswith("resid_pre"),
)
get_resids_call = jit_wrapper.Jitted(get_resids)


In [43]:
from micrlhf.utils.ito import grad_pursuit
from matplotlib import pyplot as plt


def tokenized_to_inputs(input_ids, attention_mask):
    """Convert tokenizer output into penzai inputs for ``llama``.

    Args:
        input_ids: Token ids; assumed to be (batch, seq) — TODO confirm for all callers.
        attention_mask: Accepted so the function can be called as
            ``tokenized_to_inputs(**tokenized)``, but intentionally unused: only the
            token ids are handed to ``llama.inputs.from_basic_segments``. (The previous
            version built a device-sharded mask array and then discarded it.) The
            tokenizer in this notebook pads on the right, so padding follows the real
            tokens; presumably causal attention keeps it from affecting them — TODO
            confirm if padding-aware attention is ever needed.

    Returns:
        Whatever ``llama.inputs.from_basic_segments`` builds from the sharded tokens.
    """
    # Shard the batch axis over mesh axis "dp" and the sequence axis over "sp".
    sharding = jax.sharding.NamedSharding(
        llama.mesh, jax.sharding.PartitionSpec("dp", "sp")
    )
    token_array = jax.device_put(jnp.asarray(input_ids), sharding)
    # Name the axes. The untag/tag round trip on "batch" is kept from the original
    # code; its purpose is not evident from here (TODO confirm it is needed).
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")
    return llama.inputs.from_basic_segments(token_array)


# Token ids used to locate ICL structure in tokenized prompts.
newline = 108  # presumably "\n"; drives the "newline"/"output" masks below
pad = 0        # presumably the pad token id (not referenced in the cells shown)
sep = 3978     # input/output separator: drives the "arrow" mask and get_tv (presumably the arrow token)

In [49]:
from sprint.task_vector_utils import load_tasks, ICLDataset, ICLSequence
from sprint.task_vector_utils import ICLRunner, logprob_loss, get_tv, make_act_adder


use_65k = True
n_few_shots, batch_size, max_seq_len = 20, 16, 256
seed = 10
prompt = "<start_of_turn>user\nFollow the pattern:\n{}"
os.makedirs("data/feature_plots", exist_ok=True)

# Loop-invariant settings, hoisted out of the per-result loop.
n_shot = n_few_shots - 1
# Token count of the instruction prefix: the template with its trailing "{}"
# removed, as encoded by the tokenizer (including any special tokens `encode`
# adds). Positions before this index are excluded from every mask below.
prompt_length = len(tokenizer.encode(prompt[:-2]))


def _segment_means(values, mask):
    """Average ``values`` over each contiguous True-run of ``mask`` per row, then across rows.

    ``values`` and ``mask`` are 2-D (rows, positions) arrays of equal shape, and
    ``mask`` is boolean. Within a row, a new segment begins right after every
    True -> False transition, so each segment holds at most one run of masked
    positions (one per shot). A segment's value is sum(values * mask) / sum(mask).
    Segments with no masked positions, such as the tail after the last run, are
    dropped instead of producing 0/0 = NaN. Rows with different numbers of runs are
    NaN-padded and combined with ``np.nanmean``.

    Returns a 1-D float array with one entry per segment (shot) index.
    """
    per_row = []
    for row_values, row_mask in zip(values, mask):
        row_mask = np.asarray(row_mask, dtype=bool)
        # segment_ids[j] = number of True->False transitions at positions <= j.
        starts_new = np.zeros(row_mask.shape, dtype=np.int64)
        starts_new[1:] = row_mask[:-1] & ~row_mask[1:]
        segment_ids = np.cumsum(starts_new)
        row_weights = row_mask.astype(np.float64)
        sums = np.bincount(
            segment_ids, weights=np.asarray(row_values, dtype=np.float64) * row_weights
        )
        counts = np.bincount(segment_ids, weights=row_weights)
        nonempty = counts > 0
        per_row.append(sums[nonempty] / counts[nonempty])
    width = max((len(row) for row in per_row), default=0)
    table = np.full((len(per_row), width), np.nan)
    for row_idx, row in enumerate(per_row):
        table[row_idx, : len(row)] = row
    return np.nanmean(table, axis=0)


for r in task_results:
    task = r["task"]
    layer = r["layer"]

    sae = get_dm_res_sae(layer, load_65k=use_65k)

    pairs = list(tasks[task].items())
    runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_shot,
                       max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    tokenized = runner.get_tokens([x[:n_shot] for x in runner.train_pairs], tokenizer)
    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)
    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    # Per-position boolean masks over train_tokens; the prompt prefix is cleared
    # from all but the "prompt" mask.
    masks = [
        ("prompt", jnp.zeros_like(train_tokens).at[:, :prompt_length].set(1).astype(bool)),
        ("input", jnp.roll(train_tokens == sep, -1, axis=-1).at[:, :prompt_length].set(False)),
        ("arrow", jnp.array(train_tokens == sep).at[:, :prompt_length].set(False)),
        ("output", jnp.roll(train_tokens == newline, -1, axis=-1).at[:, :prompt_length].set(False)),
        ("newline", jnp.array(train_tokens == newline).at[:, :prompt_length].set(False)),
    ]

    # Decompose the task vector into SAE decoder directions and keep the features
    # whose weight exceeds the layer threshold, strongest first. (20 is passed to
    # grad_pursuit unchanged; presumably the number of atoms/steps — TODO confirm.)
    tv = get_tv(resids, train_tokens, shift=0, sep=sep)
    pursuit_weights, _ = grad_pursuit(tv, sae["W_dec"], 20)
    selected = np.argwhere(pursuit_weights > thresholds[layer]).flatten()
    order = np.argsort(pursuit_weights[selected])[::-1]
    indices = selected[order]
    weights = pursuit_weights[indices]

    _, encodings, _ = sae_encode(sae, resids)
    # The "prompt" mask is never plotted; convert the rest to numpy once.
    plot_masks = [(name, np.asarray(m, dtype=bool)) for name, m in masks if name != "prompt"]

    for feature, feature_weight in zip(indices, weights):
        enc = np.asarray(encodings[..., feature]).astype(np.float32)
        for mask_name, mask in plot_masks:
            if (enc * mask).sum() == 0:
                continue
            shot_means = _segment_means(enc, mask)

            fig, ax = plt.subplots()
            ax.plot(shot_means)
            ax.set_title(f"Feature {feature} with weight {feature_weight}; mask {mask_name}")
            # Labels must be set before saving, otherwise the file lacks them.
            ax.set_xlabel("Shot #")
            ax.set_ylabel("Feature value averages over mask")
            fig.savefig("data/feature_plots/{}_{}_{}_{}.png".format(task, layer, feature, mask_name))
            plt.close(fig)