In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
# Notebook-wide penzai setup. Judging by the call names, this registers the
# treescope renderer as the default display, enables its autovisualize magic,
# and turns on penzai's interactive context.
# NOTE(review): confirm the exact effects against the penzai documentation.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [2]:
# Relative path of the GGUF checkpoint to load.
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
# Load the checkpoint with device_map "tpu:0"; `llama` is the base model that
# the later cells select into, patch and sample from.
llama = LlamaTransformer.from_pretrained(filename, device_map="tpu:0")

In [3]:
from transformers import AutoTokenizer
# Tokenizer published with the Phi-3 mini instruct checkpoint named below.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
# Pad on the right-hand side of each sequence.
tokenizer.padding_side = "right"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
import jax
import jax.numpy as jnp
# Chat-formatted prompt: one user turn followed by an opened assistant turn.
prompt = """<|user|>
What is 3940 * 3892?<|end|>
<|assistant|>
"""
# Token ids of the prompt, as produced by the tokenizer's `encode`.
tokens = tokenizer.encode(prompt)
def tokens_to_array(tokens):
    """Convert nested token ids into a named ("batch", "seq") array.

    The array is placed on `llama.mesh` with the batch dimension sharded over
    the "dp" axis and the sequence dimension over the "sp" axis, but only when
    that sharding is valid. Otherwise the array is left where `jnp.asarray`
    put it.

    Args:
        tokens: 2-D nested sequence (or array) of token ids, batch first.

    Returns:
        A `pz.nx` named array with axes "batch" and "seq".
    """
    token_array = jnp.asarray(tokens)
    mesh_shape = llama.mesh.shape
    # jax.device_put refuses a NamedSharding whose mesh axis sizes do not
    # divide the corresponding array dimensions, so a plain "batch >= dp"
    # check is not enough (e.g. batch 3 on dp 2 would raise).
    can_shard = (
        token_array.ndim == 2
        and token_array.shape[0] >= mesh_shape["dp"]
        and token_array.shape[0] % mesh_shape["dp"] == 0
        and token_array.shape[1] % mesh_shape["sp"] == 0
    )
    if can_shard:
        sharding = jax.sharding.NamedSharding(
            llama.mesh, jax.sharding.PartitionSpec("dp", "sp"))
        token_array = jax.device_put(token_array, sharding)
    token_array = pz.nx.wrap(token_array, "batch", "seq")
    return token_array

In [5]:
from micrlhf.llama import LlamaBlock
from functools import partial


# Index of the block whose incoming residual stream is read and later steered.
layer_source = 10


def make_get_resids(llama, layer_target):
    """Return a copy of `llama` that reports the input of one block.

    A `TellIntermediate` tagged "resid_pre" is placed directly in front of the
    `layer_target`-th `LlamaBlock`, and the patched model is wrapped in a
    `CollectingSideOutputs` handler that collects every tag starting with
    "resid_pre".

    Args:
        llama: model to patch (selectors return a modified copy).
        layer_target: index of the block, among all `LlamaBlock` instances,
            whose input should be reported.

    Returns:
        The handler-wrapped model.
    """
    def tap_block_input(block):
        # Report the value first, then run the original block unchanged.
        return pz.nn.Sequential([
            pz.de.TellIntermediate.from_config(tag="resid_pre"),
            block,
        ])

    target_block = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer_target)
    tapped_model = target_block.apply(tap_block_input)
    return pz.de.CollectingSideOutputs.handling(
        tapped_model, tag_predicate=lambda tag: tag.startswith("resid_pre"))
def jittify(x):
    """Bind `x` as the first argument of a jitted "call and extract" function.

    The returned callable forwards its arguments to `x` under `jax.jit` and
    returns `.value` of the first entry in the second element of `x`'s result.

    Args:
        x: callable pytree passed to the jitted function as a traced argument.

    Returns:
        A `functools.partial` taking the remaining positional/keyword arguments.
    """
    def call_and_extract(lr, *args, **kwargs):
        return lr(*args, **kwargs)[1][0].value

    return partial(jax.jit(call_and_extract), x)
# Jitted function returning the residual stream that enters block `layer_source`.
get_resids_initial = jittify(make_get_resids(llama, layer_source))

In [6]:
# Index of the block whose incoming residual stream `taker` reads out.
layer_target = 20
# `taker` maps residuals at `layer_source` to residuals entering `layer_target`:
#   * make_get_resids taps the input of block `layer_target`;
#   * every LlamaBlock with selected index below `layer_source` becomes Identity;
#   * the EmbeddingLookup and the first ConstantRescale become Identity, so the
#     value passed in as `tokens` reaches the remaining blocks unchanged
#     (NOTE(review): presumably that ConstantRescale is the embedding scaling — confirm).
# jittify then returns only the first collected "resid_pre" value.
taker = jittify(make_get_resids(llama, layer_target).select().at_instances_of(LlamaBlock).apply_with_selected_index(
    lambda i, x: x if i >= layer_source else pz.nn.Identity()
).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

In [9]:
import dataclasses
import jax.numpy as jnp
from tqdm.auto import trange
import optax


# Wrap the single prompt as a batch of one and build model inputs from it.
inputs = llama.inputs.from_basic_segments(tokens_to_array([tokens]))
# Residual stream entering block `layer_source` for the unmodified prompt.
resid_initial = get_resids_initial(inputs)
# Number of steering vectors optimized in parallel (rows of `start_add`).
bs = 16
# PRNG seed for the initial steering vectors.
seed = 9
# Number of optimizer steps run by the training loop.
iterations = 1000
# Norm that every steering vector is rescaled to before it is added.
scale = 10
def get_loss(rep):
    """Score how far apart a batch of steering vectors pushes the model.

    Each row of `rep` is rescaled to norm `scale` and added to the last
    sequence position of the cached `layer_source` residuals. `taker` carries
    the result to `layer_target`, where the squared differences between the
    final-position residuals of every pair of batch rows are summed over the
    embedding axis and averaged (no square root is taken).

    Args:
        rep: positional array wrapped here as ("batch", "embedding").

    Returns:
        `(-spread, (resid_out,))`: the negated mean pairwise squared distance
        (so minimizing it maximizes the spread) and the final-position
        residuals as auxiliary output.
    """
    def rescale(vec):
        return vec / jnp.linalg.norm(vec) * scale

    def add_to_last(seq_values, offset):
        return seq_values.at[-1].add(offset)

    def pairwise_sq(values):
        return (values[:, None] - values[None, :]) ** 2

    directions = pz.nx.wrap(rep, "batch", "embedding")
    directions = pz.nx.nmap(rescale)(directions.untag("embedding")).tag("embedding")

    base_resid = resid_initial[{"batch": 0}].untag("seq")
    steered = pz.nx.nmap(add_to_last)(base_resid, directions).tag("seq")

    # Only `tokens` changes; dataclasses.replace keeps every other field as is.
    resid_out = taker(dataclasses.replace(inputs, tokens=steered))[{"seq": -1}]

    pair_sq = pz.nx.nmap(pairwise_sq)(resid_out.untag("batch")).tag("batch1", "batch2")
    total_sq = pair_sq.untag("embedding").sum()
    spread = total_sq.data_array.mean()
    return -spread, (resid_out,)
# Initial steering vectors: one standard-normal row per batch element, with
# the embedding width and dtype taken from the cached residuals.
start_add = jax.random.normal(jax.random.key(seed),
                              (bs, resid_initial.named_shape["embedding"]), dtype=resid_initial.dtype)

# Gradient pipeline. NaNs are zeroed *before* clipping and Adam: a single NaN
# entry makes the global norm NaN (so clipping spreads it to every entry) and
# would be folded into Adam's moment estimates, after which every later update
# is NaN. Zeroing only the final updates hides that while training stalls.
optimizer = optax.chain(
    optax.zero_nans(),
    optax.clip_by_global_norm(1.0),
    optax.adam(1e-2),
)
# Value-and-gradient of get_loss with respect to its input; has_aux passes the
# (resid_out,) tuple through next to the loss.
lwg = jax.value_and_grad(get_loss, has_aux=True)
# lwg = jax.jit(lwg)


# @partial(jax.jit, donate_argnums=(0, 1))
def train_step(addition, opt_state):
    """Run one optimizer step on the steering vectors.

    Args:
        addition: current steering vectors (the parameters being optimized).
        opt_state: optimizer state matching `addition`.

    Returns:
        `(loss, new_addition, new_opt_state, aux)` where `aux` is a dict
        holding the "resid_out" auxiliary output of `get_loss`.
    """
    (loss, aux_outputs), grad = lwg(addition)
    (resid_out,) = aux_outputs
    updates, new_opt_state = optimizer.update(grad, opt_state, addition)
    new_addition = optax.apply_updates(addition, updates)
    return loss, new_addition, new_opt_state, {"resid_out": resid_out}


In [10]:
# Start from the random initial vectors with a fresh optimizer state.
addition, opt_state = start_add, optimizer.init(start_add)
bar = trange(iterations)
for _ in bar:
    # `aux` keeps the residuals of the most recent step for the later cells.
    loss, addition, opt_state, aux = train_step(addition, opt_state)
    bar.set_postfix(loss=loss)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [11]:
from micrlhf.sampling import sample
from typing import Literal
import numpy as np
import dataclasses


@pz.pytree_dataclass
class ActivationAddition(pz.Layer):
    """Layer that adds a fixed vector to its input along the "seq" axis.

    The input's "seq" axis is untagged and `addition` is combined with it via
    `pz.nx.nmap`, so the remaining named axes of both operands broadcast by
    name. The result always keeps the input's dtype.
    """
    # Value added to the input; its named axes are matched by name.
    addition: pz.nx.NamedArray
    # Which sequence positions receive the addition.
    position: Literal["all", "last", "first"] = dataclasses.field(metadata={"pytree_node": False}, default="all")
    # "all": always add; any other value: skip inputs whose sequence length is 1.
    size_cond: Literal["all", "last"] = dataclasses.field(metadata={"pytree_node": False}, default="all")

    def adder(self, a, b):
        """Add `b` to `a` at the positions selected by `self.position`.

        Raises:
            ValueError: if `self.position` is not "all", "last" or "first".
        """
        if self.position == "all":
            return a + b
        if self.position == "last":
            return a.at[-1].add(b)
        if self.position == "first":
            return a.at[0].add(b)
        # Fail loudly instead of returning None and crashing later.
        raise ValueError(f"Unknown position: {self.position!r}")

    def __call__(self, x):
        """Return `x` with `self.addition` applied along its "seq" axis."""
        def add_along_seq(a, b):
            # The condition depends only on static values (a config string and
            # a shape), so a plain Python branch replaces jax.lax.select.
            if self.size_cond != "all" and len(a) <= 1:
                return a
            # astype expects a dtype, not an array.
            return self.adder(a, b).astype(a.dtype)

        return pz.nx.nmap(add_along_seq)(x.untag("seq"), self.addition).tag("seq")

# Steer at the same block index the vectors were optimized for.
layer = layer_source
# Number of times the stack of steering vectors is tiled along the batch axis.
repeat = 4
# Put an ActivationAddition in front of block `layer`. Each learned vector is
# renormalized to norm `scale`, then the (bs, embedding) stack is tiled
# `repeat` times along the first axis, giving bs * repeat rows ordered as
# [v0..v(bs-1), v0..v(bs-1), ...].
act_add = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer).apply(
    lambda x: pz.nn.Sequential([ActivationAddition(
        pz.nx.wrap(
            jnp.tile(addition / jnp.linalg.norm(addition, axis=-1, keepdims=True) * scale, (repeat, 1))
            , "batch", "embedding"), position="all", size_cond="all"), x]))
# Sample from the steered model with one batch row per tiled vector.
# NOTE(review): with return_model=True the second result is presumably the
# cached model used for sampling — confirm against micrlhf.sampling.sample.
texts, cached = sample(act_add, tokenizer, prompt,
       batch_size=bs * repeat, do_sample=True, return_model=True)

  0%|          | 0/46 [00:00<?, ?it/s]

In [14]:
# Regroup the flat list into bs rows of `repeat` texts each. Assuming `texts`
# is in batch order, row j holds the samples for steering vector j, matching
# the jnp.tile ordering used when the additions were built.
np.array(texts).reshape(repeat, bs).T.tolist()

In [28]:
from micrlhf.utils.load_sae import get_sae
from micrlhf.utils.ito import grad_pursuit
# Final-position residuals returned by the last training step.
resid_out = aux["resid_out"]
# Batch element -4 minus the mean taken with the "batch" axis untagged,
# unwrapped to a plain array over "embedding".
# NOTE(review): the index -4 is hard-coded, presumably picked by inspecting
# the sampled texts — confirm it still points at the intended sample.
refusal = (resid_out[{"batch": -4}] - resid_out.untag("batch").mean()).unwrap("embedding")
# NOTE(review): the arguments are presumably (layer, variant); 20 equals layer_target.
sae = get_sae(20, 5)
# Number of entries requested from grad_pursuit and shown by top_k below.
k = 10
# NOTE(review): `w` looks like per-feature weights and `r` like the
# reconstruction of `refusal` from sae["W_dec"] — confirm against grad_pursuit.
w, r = grad_pursuit(refusal, sae["W_dec"], k)
# Indices of the k largest-magnitude weights.
i = jax.lax.top_k(jnp.abs(w), k)[1]
# Display: mean squared difference between r and refusal, the indices, their weights.
((r - refusal) ** 2).mean(), i, w[i]

In [21]:
# Display the shapes of `r` and `refusal` side by side for comparison.
r.shape, refusal.shape