In [1]:
import os
if "models" not in os.listdir("."):
    os.chdir("..")

In [3]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [69]:
import jax.numpy as jnp
import jax
import json

from tqdm.auto import tqdm

In [5]:
from micrlhf.llama import LlamaTransformer
from transformers import AutoTokenizer


filename = "models/phi-3-16.gguf"
llama = LlamaTransformer.from_pretrained(filename, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
tokenizer.padding_side = "right"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
feature_dataset

In [46]:
from micrlhf.utils.activation_manipulation import replace_activation

def benchmark_vector(vector, tokens, model, positions, replacement_layer):
    act_rep = replace_activation(model, vector, positions, layer=replacement_layer)

    return act_rep(tokens)

In [11]:
def tokens_to_inputs(tokens):
    token_array = jnp.asarray(tokens)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs

In [35]:
def logits_to_loss(logits, tokens, answer_start, pad_token=32000):
    logits = jax.nn.log_softmax(logits)

    logits = logits[:, :-1]
    logits = jnp.take_along_axis(logits, tokens[:, 1:, None], axis=-1).squeeze(-1)

    mask = tokens[:, 1:] != pad_token

    mask[:, :answer_start] = False

    logits = logits * mask

    return -logits.sum(axis=-1) / mask.sum(axis=-1)


In [52]:
from micrlhf.utils.load_sae import get_sae

def benchmark_feature(prompt_template, token_to_replace, layer, feature, explanation, batch_size=64, min_scale=0, max_scale=200, max_length=64, replacement_layer=2):
    vector = get_sae(layer, 4)["W_dec"][feature]
    vector = vector[None, :] * jnp.linspace(min_scale, max_scale, batch_size)[:, None]

    prompt = prompt_template.format(explanation)
    text = [prompt for _ in range(batch_size)]

    tokenized = tokenizer(text, return_tensors="np", padding="max_length", max_length=max_length, truncation=True)

    tokens = tokenized["input_ids"]

    inputs = tokens_to_inputs(tokens)

    positions = [i for i, a in enumerate(tokenizer.encode(prompt_template)) if tokenizer.decode([a]) == token_to_replace]

    logits = benchmark_vector(
        vector, inputs, llama, positions, replacement_layer
    )

    logits = logits.unwrap(
        "batch", "seq", "vocabulary"
    )

    answer_start = len(tokenizer.encode(prompt_template))

    loss = logits_to_loss(logits, tokens, answer_start)

    best_idx = jnp.argmin(loss)

    return loss, jnp.linspace(min_scale, max_scale, batch_size)[best_idx], loss[best_idx]

In [8]:
prompt_template = "<|user|>\nWhat is the meaning of the word \"X\"?<|end|>\n<|assistant|>\nThe meaning of the word \"X\" is \"{}\""
token_to_replace = "X"

positions = [i for i, a in enumerate(tokenizer.encode(prompt_template)) if tokenizer.decode([a]) == token_to_replace]

In [14]:
from datasets import load_dataset

feature_dataset = load_dataset("kisate-team/feature-explanations", split="train")

In [56]:
replacement_layer = 2

for i in tqdm(list(range(len(feature_dataset)))):
    item = feature_dataset[i]
    
    layer = item["layer"]
    feature = item["feature"]
    explanation = item["explanation"]

    result = {"id": i}

    if explanation is not None:
        loss, scale, best_loss = benchmark_feature(prompt_template, token_to_replace, layer, feature, explanation, replacement_layer=replacement_layer)
        result["loss"] = loss.tolist()
        result["scale"] = scale
        result["best_loss"] = best_loss

    with open("results.jsonl", )

Skipping item
Explanation: mentions of geographic locations
Loss: [5.5625 4.65625 4.3125 3.42188 3.95312 4.625 4.8125 4.9375 5.1875 5.28125
 5.3125 5.21875 5.21875 5.125 5.0625 4.96875 4.875 4.90625 4.6875 4.75
 4.75 4.6875 4.59375 4.53125 4.40625 4.4375 4.3125 4.1875 4.125 4 3.90625
 3.90625 3.71875 3.78125 3.75 3.79688 3.92188 3.95312 4.09375 4.09375 4.25
 4.3125 4.34375 4.3125 4.4375 4.5 4.53125 4.625 4.4375 4.5625 4.59375
 4.6875 4.625 4.65625 4.65625 4.75 4.75 4.71875 4.6875 4.71875 4.65625
 4.6875 4.75 4.8125]
Scale: 9.523810386657715
Best Loss: 3.42188

Explanation: declarations of variables or attributes with specific data types
Loss: [4.09375 5.15625 4.0625 4.40625 4.5625 4.5625 4.53125 4.4375 4.3125 4.3125
 4.3125 4.375 4.46875 4.5 4.5625 4.59375 4.5625 4.5625 4.53125 4.5625
 4.5625 4.5625 4.53125 4.5625 4.5625 4.5625 4.59375 4.59375 4.59375
 4.65625 4.65625 4.6875 4.6875 4.71875 4.75 4.75 4.8125 4.75 4.8125 4.8125
 4.84375 4.875 4.9375 4.96875 4.96875 5 4.96875 4.96875 5.062

In [65]:
item = feature_dataset[1]

print(item)

layer = item["layer"]
feature = item["feature"]
explanation = item["explanation"]

{'layer': 12, 'version': 4, 'feature': 27228, 'type': 'location_country', 'explanation': 'mentions of geographic locations'}


In [66]:
vector = get_sae(layer, 4)["W_dec"][feature]
prompt = "<|user|>\nWhat is the meaning of the word \"X\"?<|end|>\n<|assistant|>\nThe meaning of the word \"X\" is \""
bs = 64
msl = 64

act_rep = replace_activation(llama, vector[None, :] * jnp.linspace(8, 10, bs)[:, None], prompt=prompt, tokenizer=tokenizer, layer=2)

In [67]:
from micrlhf.sampling import sample
texts = sample(act_rep, tokenizer, prompt, batch_size=bs, do_sample=True, max_seq_len=msl)[0]

  0%|          | 0/40 [00:00<?, ?it/s]

In [68]:
display(texts)