In [None]:
import os

# Work from the repository root so relative paths such as "models/..." and
# "data/..." resolve. If the current directory has no "models" entry (e.g. the
# notebook was launched from a subdirectory), step up exactly one level.
if not any(entry == "models" for entry in os.listdir(".")):
    os.chdir(os.pardir)

In [None]:
# Automatically reload edited Python modules before each cell executes, so
# changes to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
# Register penzai's `pz.ts` (Treescope) renderer as the default display and
# enable its autovisualize magic for rich rendering of models / arrays.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
# Turn on penzai's interactive-context conveniences for notebook use.
pz.enable_interactive_context()

In [None]:
from micrlhf.llama import LlamaTransformer

# Gemma-2B-IT loaded from a local GGUF checkpoint. `from_type="gemma"` selects
# the Gemma conversion path; `load_eager=True` presumably materializes weights
# up front rather than lazily (confirm in micrlhf).
llama = LlamaTransformer.from_pretrained(
    "models/gemma-2b-it.gguf",
    from_type="gemma",
    load_eager=True,
)

In [None]:
import os

import datasets
import pandas as pd

# Pure-Python, cross-platform equivalent of `!mkdir -p data`.
os.makedirs("data", exist_ok=True)

# Number of harmful prompts to use; the harmless set is sized to match.
N_PROMPTS = 100

# Chat-style prompt: a user turn followed by the opening of the model turn.
# Each `\n` escape is followed by a literal line break inside the triple-quoted
# string, so the turns are separated by blank lines.
# NOTE(review): Gemma's chat template closes the user turn with `<end_of_turn>`,
# which this format omits. Kept unchanged on purpose: the refusal vectors
# computed below (and uploaded as "gemma-refusal-l{layer}") depend on it.
format_prompt = """<start_of_turn>user\n
{}\n
<start_of_turn>model\n
{}"""

# Harmful requests: the first N_PROMPTS values of the "goal" column.
df_adv = pd.read_csv("data/adv.csv")
prompts_harmful = [format_prompt.format(goal, "") for goal in df_adv["goal"].head(N_PROMPTS)]

# Jailbreak goals ("Goal" column), formatted the same way.
dataset_jail = pd.read_csv("data/jail.csv")["Goal"].to_list()
prompts_jail = [format_prompt.format(x, "") for x in dataset_jail]

# Harmless requests: Alpaca instructions with an empty `input` field, taken in
# dataset order until there are as many as harmful prompts.
# Source: https://colab.research.google.com/drive/1a-aQvKC9avdZpdyBn4jgRQFObTPy1JZw
hf_path = 'tatsu-lab/alpaca'
dataset = datasets.load_dataset(hf_path)
prompts_harmless = []
for row in dataset['train']:
    if len(prompts_harmless) >= len(prompts_harmful):
        break
    if row['input'].strip() == '':
        prompts_harmless.append(format_prompt.format(row['instruction'], ""))

# Later cells concatenate [harmful, harmless] and split the batch with
# reshape(2, -1, ...), which assumes two equal-sized halves.
assert len(prompts_harmless) == len(prompts_harmful), (len(prompts_harmless), len(prompts_harmful))

# Alternative harmless sources (currently unused):
# ds = datasets.load_dataset("MBZUAI/LaMini-instruction", split="train", streaming=True)
# prompts_harmless = []
# for _, text in zip(range(100), ds):
#     prompts_harmless.append(format_prompt.format(text["instruction"], ""))

# ds = datasets.load_dataset("nev/openhermes-2.5-phi-format-text", split="train", streaming=True)
# prompts_harmless = []
# for _, text in zip(range(100), ds):
#     text = text["text"]
#     text = "".join(text.partition("<|assistant|>\n")[:2])
#     prompts_harmless.append(text)

In [None]:
from transformers import AutoTokenizer

# Right padding keeps every prompt left-aligned in its row, so a prompt's last
# real token sits at index attention_mask.sum() - 1 (relied on when extracting
# last-token residuals).
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b", padding_side="right")

In [None]:
# NOTE(review): only `jnp` and `jit_wrapper` are used from this import; `jnp` is
# presumably a re-export of jax.numpy — importing it from jax directly is clearer.
from micrlhf.sampling import sample, trange, jnp, load_tokenizer, jit_wrapper
import jax

# Tokenize harmful prompts first, then harmless ones. Every row is padded (or
# truncated) to exactly 128 tokens; a prompt longer than that loses its tail,
# including the model-turn marker at the end.
tokens = tokenizer.batch_encode_plus(prompts_harmful + prompts_harmless,
                                     return_tensors="np",
                                     padding="max_length",
                                     truncation=True,
                                     max_length=128,
                                     return_attention_mask=True)
token_array = jnp.asarray(tokens["input_ids"])
# Shard the (batch, seq) token array over the model mesh: dim 0 on mesh axis
# "dp", dim 1 on mesh axis "sp".
token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
# Wrap as a penzai named array with axes "batch" and "seq".
# NOTE(review): untag("batch").tag("batch") looks like a no-op; presumably kept
# to obtain a particular named-array form — confirm before removing.
token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")
# Build the model's input structure from the token array (one segment per row,
# judging by the method name — TODO confirm in micrlhf).
inputs = llama.inputs.from_basic_segments(token_array)

In [None]:
from micrlhf.llama import LlamaBlock
# NOTE(review): `flashify` is unused in this cell, and the micrlhf.sampling
# import below duplicates the one in the tokenization cell.
from micrlhf.flash import flashify
from micrlhf.sampling import sample, trange, jnp, load_tokenizer, jit_wrapper
# Wrap every LlamaBlock (index i) in a Sequential that first reports its input
# as a side output tagged "resid_pre_{i}", then runs the original block.
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
# Collect only the side outputs whose tag starts with "resid_pre"; the handled
# model returns (model_output, collected_side_outputs).
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
# Wrap with micrlhf's Jitted helper (presumably jax.jit-compiles the call).
get_resids_call = jit_wrapper.Jitted(get_resids)
# `resids` is the list of collected side outputs, one per LlamaBlock.
_, resids = get_resids_call(inputs)

In [None]:
from matplotlib import pyplot as plt
import numpy as np

# Position of each prompt's last real token. Padding is on the right, so it is
# (number of attended tokens) - 1, shaped (batch, 1, 1) for take_along_axis.
# It does not depend on the layer, so it is computed once.
indices = jnp.asarray(tokens["attention_mask"].sum(1, keepdims=True))[..., None] - 1

residiffs = []    # per layer: unit-norm (mean harmful - mean harmless) last-token residual
last_resids = []  # per layer: last-token residuals reshaped to (2, n_per_group, embedding)
# Side-output tags are "resid_pre_{layer}"; sort numerically by layer index.
for resid in sorted(resids, key=lambda x: int(x.tag.rpartition("_")[-1])):
    resid = resid.value.unwrap("batch", "seq", "embedding")
    last_resid = jnp.take_along_axis(resid, indices, 1)[:, 0]
    # The batch is [harmful..., harmless...]; split it into the two halves.
    last_resid = last_resid.reshape(2, -1, last_resid.shape[-1])
    last_resids.append(last_resid)
    last_resid = last_resid.mean(1)
    residiff = np.array(last_resid[0] - last_resid[1])
    # Epsilon guards a zero difference, which would otherwise give NaN. This is
    # likely at resid_pre_0, where every prompt ends with the same template
    # suffix and so both groups share the same last-token embedding.
    residiff = residiff / (np.linalg.norm(residiff) + 1e-10)
    residiffs.append(residiff)

# Pairwise cosine similarity between layers' refusal directions.
matmuls = np.matmul(residiffs, np.transpose(residiffs))
norms = np.linalg.norm(residiffs, axis=-1) + 1e-10
fig, ax = plt.subplots()
# NOTE: vmin=0 clips negative similarities to the lowest colour.
image = ax.imshow(matmuls / norms[:, None] / norms[None, :], vmin=0, vmax=1)
ax.set(title="Correlations between layers' refusal vectors", xlabel="layer", ylabel="layer")
fig.colorbar(image, ax=ax)
plt.show()

In [None]:
from micrlhf.utils.vector_storage import save_and_upload_vector, download_vector
from micrlhf.utils.load_sae import get_nev_it_sae_suite
from micrlhf.utils.ito import grad_pursuit


sae_k = 4        # number of dictionary atoms grad_pursuit may select per vector
threshold = 0.2  # minimum selected weight for a feature to be recorded as interesting

# Store each layer's refusal direction once. The vector does not depend on the
# SAE layer, so this runs outside the SAE loop. Layer 0 is skipped, as below.
for layer in range(1, len(residiffs)):
    vector_name = f"gemma-refusal-l{layer}"
    vector = residiffs[layer]
    vector = vector / np.linalg.norm(vector)
    try:
        save_and_upload_vector(vector_name, vector)
    except FileExistsError:
        # A vector with this name is already stored; keep the existing one.
        pass

interesting_features = set()  # {(sae_layer, feature_index)}
for layer_sae in range(8, 16):
    dictionary = get_nev_it_sae_suite(layer_sae)["W_dec"]
    # Unit-normalize each row (presumably one decoder direction per feature).
    dictionary = dictionary / np.linalg.norm(dictionary, axis=1, keepdims=True)
    for layer in range(1, len(residiffs)):
        vector = residiffs[layer]
        vector = vector / np.linalg.norm(vector)
        # Sparse, non-negative approximation of the refusal direction using at
        # most sae_k dictionary rows.
        weights, recon = grad_pursuit(vector, dictionary, sae_k, pos_only=True)
        w, i = jax.lax.top_k(jnp.abs(weights), sae_k)
        print(f"Layer {layer} -> Layer {layer_sae}: {[(int(a), float(b)) for a, b in zip(i, w)]}")
        for f, u in zip(i, w):
            if u < threshold:
                continue
            interesting_features.add((layer_sae, int(f)))
    print()

In [None]:
from tqdm.auto import tqdm
import requests
import numpy as np
import os
import json

feat_dir = "data/feature_explanations"
os.makedirs(feat_dir, exist_ok=True)

EXPLANATIONS_URL = "https://datasets-server.huggingface.co/rows"
EXPLANATIONS_DATASET = "kisate-team/gemma-2b-suite-explanations"
REQUEST_TIMEOUT_S = 30  # seconds; without a timeout requests.get can block forever

# Sorted for a deterministic, readable display order.
for l, f in tqdm(sorted(interesting_features)):
    filename = f"{feat_dir}/explanation_{l}_{f}.json"
    if os.path.exists(filename):
        with open(filename, 'r') as file:
            data = json.load(file)
    else:
        try:
            # The dataset has one config per SAE layer ("l{layer}"); the
            # feature index is used as the row offset.
            response = requests.get(
                EXPLANATIONS_URL,
                params={
                    "dataset": EXPLANATIONS_DATASET,
                    "config": f"l{l}",
                    "split": "train",
                    "offset": f,
                    "length": 1,
                },
                timeout=REQUEST_TIMEOUT_S,
            )
        except requests.RequestException as exc:
            print("skipping", l, f, exc)
            continue
        try:
            data = response.json()
        except ValueError:
            # Non-JSON body (e.g. an HTML gateway error page).
            data = {"error": f"HTTP {response.status_code}"}
        # Cache only successful responses so transient failures (rate limits,
        # server errors) are retried on the next run instead of persisting.
        if response.ok and isinstance(data, dict) and "error" not in data:
            with open(filename, 'w') as file:
                json.dump(data, file)

    if not data or set(data.keys()) == {"error"}:
        print("skipping", l, f, data)
        continue
    rows = data.get("rows") or []
    if not rows:
        print("skipping", l, f, data)
        continue

    feat_info = rows[0]["row"]
    tuning = feat_info["scale_tuning"]
    # Score each steering scale: second-to-last `selfsims` entry minus 0.1 x
    # entropy; the 10 smallest scales are excluded from the argmax.
    criterion = np.array(tuning["selfsims"][-2]) - 0.1 * np.array(tuning["entropy"])
    scale = tuning["scales"][np.argmax(criterion[10:]) + 10]
    texts = feat_info["generations"]["texts"]
    # searchsorted returns len(scales) when `scale` exceeds every generation
    # scale; clamp so indexing stays in range.
    index = min(int(np.searchsorted(feat_info["generations"]["scales"], scale)), len(texts) - 1)
    display((l, f, scale, texts[index], texts))