In [1]:
import os
# When the current directory has no "models" entry, step up one level so the
# relative "models/..." checkpoint path used below resolves.
if "models" not in os.listdir("."):
    os.chdir("..")

In [2]:
# Reload imported modules automatically before each cell runs, so edits to
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
# Penzai notebook setup. Judging by the names these install the treescope
# renderer as the default display, register the autovisualize cell magic and
# turn on penzai's interactive context — see the penzai docs to confirm.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [3]:
from micrlhf.llama import LlamaTransformer
# Load the checkpoint at models/gemma-2b-it.gguf (path relative to the working
# directory), parsed with from_type="gemma".
# NOTE(review): load_eager=True presumably materializes all weights right away
# rather than lazily — confirm in micrlhf.llama.
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf",
                                         from_type="gemma",
                                         load_eager=True
                                         )

In [4]:
from transformers import AutoTokenizer
# Tokenizer is fetched from the Hugging Face repo named below.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/gemma-2b-it-tokenizer")
# Pad on the right so real tokens keep their positions at the start of each row.
tokenizer.padding_side = "right"

In [5]:
import jax
import jax.numpy as jnp
# Two-slot chat template: slot 0 is the user message, slot 1 is the (possibly
# empty) beginning of the model reply.
# NOTE(review): this is a normal (non-raw) triple-quoted string, so every "\n"
# escape is followed by a literal line break and the rendered prompt contains
# doubled newlines. The user turn is also never closed with an "<end_of_turn>"
# marker. Confirm this matches the chat format the model was tuned on before
# relying on the generations.
format_prompt = """<start_of_turn>user\n
{}\n
<start_of_turn>model\n
{}"""
# The prompt used by every experiment below; the model-reply slot is left empty.
prompt = format_prompt.format("What is 3940 * 3892?", "")
# Flat list of token ids for the single prompt (batched later by tokens_to_array).
tokens = tokenizer.encode(prompt)
def tokens_to_array(tokens):
    """Turn nested token ids into a penzai named array with axes ("batch", "seq").

    When the leading dimension is at least as large as the "dp" axis of
    ``llama.mesh``, the array is first placed on that mesh with its two
    dimensions partitioned over the "dp" and "sp" mesh axes; smaller batches
    are left wherever ``jnp.asarray`` put them.
    """
    batch = jnp.asarray(tokens)
    if len(batch) >= llama.mesh.shape["dp"]:
        mesh_sharding = jax.sharding.NamedSharding(
            llama.mesh,
            jax.sharding.PartitionSpec("dp", "sp"),
        )
        batch = jax.device_put(batch, mesh_sharding)
    return pz.nx.wrap(batch, "batch", "seq")

In [6]:
from micrlhf.llama import LlamaBlock
from functools import partial


# Index of the LlamaBlock whose input residual is captured and later steered.
layer_source = 5
def make_get_resids(llama, layer_target):
    """Build a variant of ``llama`` that exposes the input of one block.

    The ``layer_target``-th selected ``LlamaBlock`` is replaced by a
    ``pz.nn.Sequential`` that first runs a ``TellIntermediate`` tagged
    ``"resid_pre"`` and then the original block, so the value entering the
    block is reported as a side output. The edited model is then wrapped with
    ``CollectingSideOutputs`` restricted to tags starting with ``"resid_pre"``.

    Args:
        llama: the penzai model to instrument (not modified in place; the
            selector API returns a new structure).
        layer_target: which ``LlamaBlock`` to instrument, as passed to
            ``pick_nth_selected``.

    Returns:
        The handler-wrapped model. Callers in this notebook read the captured
        value as ``result[1][0].value``.
    """
    def _tag_block_input(block):
        # Report whatever reaches the block, then run the block unchanged.
        return pz.nn.Sequential([
            pz.de.TellIntermediate.from_config(tag="resid_pre"),
            block,
        ])

    def _is_resid_tag(tag):
        return tag.startswith("resid_pre")

    instrumented = (
        llama.select()
        .at_instances_of(LlamaBlock)
        .pick_nth_selected(layer_target)
        .apply(_tag_block_input)
    )
    return pz.de.CollectingSideOutputs.handling(instrumented, tag_predicate=_is_resid_tag)
def jittify(x):
    """Bind ``x`` as the first argument of a freshly jitted caller.

    The jitted function calls its first argument with the remaining
    positional and keyword arguments and returns ``result[1][0].value``,
    i.e. the value of the first entry in the second element of the call
    result. ``x`` is passed through ``jax.jit`` as a regular argument, so it
    must be a pytree that is also callable.
    """
    def _call_and_extract(lr, *args, **kwargs):
        return lr(*args, **kwargs)[1][0].value

    return partial(jax.jit(_call_and_extract), x)
# Jitted callable returning the residual stream that enters block `layer_source`.
get_resids_initial = jittify(make_get_resids(llama, layer_source))

In [7]:
# Index of the LlamaBlock whose input residual is read out as the target.
layer_target = 10
# `taker` maps a residual stream given at the input of block `layer_source` to
# the residual entering block `layer_target`. Starting from the model
# instrumented at `layer_target`:
#   * every LlamaBlock with index < layer_source is replaced by Identity (all
#     blocks from layer_source onward, including those past layer_target, stay);
#   * every EmbeddingLookup is replaced by Identity, so the `tokens` input can
#     carry residual activations instead of token ids;
#   * the first ConstantRescale is replaced by Identity as well (presumably the
#     embedding scaling, which must not be re-applied — TODO confirm).
taker = jittify(make_get_resids(llama, layer_target).select().at_instances_of(LlamaBlock).apply_with_selected_index(
    lambda i, x: x if i >= layer_source else pz.nn.Identity()
).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

In [20]:
import dataclasses
import jax.numpy as jnp
from tqdm.auto import trange
import optax


# Wrap the single tokenized prompt as a batch of one and capture the residual
# stream entering block `layer_source` for it.
inputs = llama.inputs.from_basic_segments(tokens_to_array([tokens]))
resid_initial = get_resids_initial(inputs)
bs = 16  # number of steering vectors handled together (rows of `start_add`)
seed = 9  # PRNG seed for the random initial vectors
iterations = 1000  # number of optimizer steps in the training loop
scale = 2  # L2 norm each steering vector is rescaled to before it is added
def get_loss(rep):
    """Negated spread of the layer_target residuals produced by a batch of steering vectors.

    ``rep`` is a plain 2-D array interpreted as ("batch", "embedding"). Each
    row is rescaled to L2 norm ``scale`` and added to the last sequence
    position of the first batch entry of ``resid_initial``. The steered
    residuals are pushed through ``taker`` (in place of ``tokens``) and the
    last sequence position of the result is kept.

    The spread is the squared difference between every pair of batch rows,
    summed over the embedding axis and averaged over all pairs. The function
    returns its negation so that minimizing the loss pushes the outputs apart.

    Returns:
        ``(-spread, (resid_out,))`` where ``resid_out`` is the named array of
        last-position outputs, intended for ``value_and_grad(has_aux=True)``.
    """
    def _rescale(vec):
        return vec / jnp.linalg.norm(vec) * scale

    def _add_at_last(seq_values, delta):
        return seq_values.at[-1].add(delta)

    def _pairwise_sq(column):
        return (column[:, None] - column[None, :]) ** 2

    directions = pz.nx.wrap(rep, "batch", "embedding")
    directions = pz.nx.nmap(_rescale)(directions.untag("embedding")).tag("embedding")

    base_resid = resid_initial[{"batch": 0}].untag("seq")
    steered = pz.nx.nmap(_add_at_last)(base_resid, directions).tag("seq")

    steered_inputs = dataclasses.replace(
        inputs,
        tokens=steered,
        attention_mask=inputs.attention_mask,
        positions=inputs.positions,
    )
    resid_out = taker(steered_inputs)[{"seq": -1}]

    pair_sq = pz.nx.nmap(_pairwise_sq)(resid_out.untag("batch")).tag("batch1", "batch2")
    pair_sq = pair_sq.untag("embedding").sum()  # ** 0.5
    spread = pair_sq.data_array.mean()
    return -spread, (resid_out,)
# Random starting point: one standard-normal vector per batch row, matching the
# embedding width and dtype of the captured residual.
start_add = jax.random.normal(jax.random.key(seed),
                              (bs, resid_initial.named_shape["embedding"]), dtype=resid_initial.dtype)

# zero_nans must come first: it sanitizes the raw gradient. When it was last in
# the chain a single NaN gradient reached clip_by_global_norm (turning the whole
# update into NaN) and adam (poisoning its moment estimates for good), after
# which every later update was zeroed and training stalled without any error.
optimizer = optax.chain(
    optax.zero_nans(),
    optax.clip_by_global_norm(1.0),
    optax.adam(1e-2),
)
# Returns ((loss, aux), grad) because get_loss yields (loss, aux).
lwg = jax.value_and_grad(get_loss, has_aux=True)
# lwg = jax.jit(lwg)

# @partial(jax.jit, donate_argnums=(0, 1))
def train_step(addition, opt_state):
    """Run one optimizer step on the steering vectors.

    Args:
        addition: current steering vectors (the parameters being optimized).
        opt_state: optax state matching ``addition``.

    Returns:
        ``(loss, new_addition, new_opt_state, aux)`` where ``aux`` is a dict
        holding the ``resid_out`` produced by ``get_loss`` for the vectors
        *before* this update.
    """
    (loss, aux_outputs), grad = lwg(addition)
    (resid_out,) = aux_outputs
    updates, opt_state = optimizer.update(grad, opt_state, addition)
    stepped = optax.apply_updates(addition, updates)
    return loss, stepped, opt_state, {"resid_out": resid_out}


In [21]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite
from tqdm.auto import trange
from more_itertools import chunked
import numpy as np
# SAE suite for the steering layer; its decoder rows ("W_dec") are the candidate directions.
# NOTE(review): a later cell rebinds `sae` to the layer_target suite, so this
# cell must be re-run before re-running the selection cell that reads sae["W_dec"].
sae = get_nev_it_sae_suite(layer_source)

feat_bs = 64  # decoder rows pushed through the model per batch
feat_outs = []
try:
    for index in trange(0, len(sae["W_dec"]), feat_bs):
        # Same steering recipe as get_loss: rescale each row to norm `scale`,
        # add it at the last sequence position, read the last-position output of `taker`.
        rep = sae["W_dec"][index:index+feat_bs]
        rep = pz.nx.wrap(rep, "batch", "embedding")
        rep = pz.nx.nmap(lambda a: a / jnp.linalg.norm(a) * scale)(rep.untag("embedding")).tag("embedding")
        resid = pz.nx.nmap(lambda x, y: x.at[-1].add(y))(resid_initial[{"batch": 0}].untag("seq"), rep).tag("seq")
        resid_out = taker(dataclasses.replace(inputs, tokens=resid,
                                            attention_mask=inputs.attention_mask,
                                            positions=inputs.positions))[{"seq": -1}]
        feat_outs.append(np.array(resid_out.unwrap("batch", "embedding")))
except KeyboardInterrupt:
    # Deliberate: interrupting the sweep keeps the batches collected so far.
    # np.concatenate below raises if no batch finished before the interrupt.
    pass
# One output row per decoder row that was processed, in decoder order.
all_outs = np.concatenate(feat_outs, axis=0)
all_outs.shape

  0%|          | 0/512 [00:00<?, ?it/s]

In [22]:
from tqdm.auto import trange
import numba as nb
@jax.jit
def cdist(a, b):
    """Pairwise Euclidean distances between the rows of ``a`` and the rows of ``b``.

    Returns an array of shape ``(len(a), len(b))`` whose ``[i, j]`` entry is
    the norm of ``a[i] - b[j]``.
    """
    def _distance(x, y):
        return jnp.linalg.norm(x - y)

    # Inner map walks the rows of `a` for one fixed row of `b`; the outer map
    # walks the rows of `b` and stacks its results along axis 1.
    over_a_rows = jax.vmap(_distance, in_axes=(0, None), out_axes=0)
    over_all_pairs = jax.vmap(over_a_rows, in_axes=(None, 0), out_axes=1)
    return over_all_pairs(a, b)
# Half-precision device copy of the sweep outputs used by the pair search.
ao = jnp.asarray(all_outs.astype(jnp.float16))
# Rows scored per scan step.
vchunk = 2048
# Row indices chosen so far.
existing = []
@jax.jit
def vscan(carry, v, diff):
    """Scan body: score one chunk of rows against every row and keep the best pair.

    Args:
        carry: ``(max_dist, rc)`` — best score found so far and the
            ``[row, col]`` index pair that produced it.
        v: index of the first row of this chunk; the chunk covers rows
            ``v .. v + vchunk - 1`` of the global ``ao``.
        diff: per-row bonus added to a pair's score for both of its members
            (the caller fills it with summed distances to already selected rows).

    Returns:
        ``((max_dist, rc), None)`` in the shape ``jax.lax.scan`` expects.
    """
    max_dist, rc = carry
    # Row indices of this chunk.
    sl = jnp.arange(vchunk) + v
    # Distances from the chunk rows to all rows: shape (vchunk, number of rows).
    dists = cdist(ao[sl], ao)
    # Add both members' bonuses, except where a row is paired with itself
    # (the mask is False when the global row index equals the column index).
    dists += (diff[sl, None] + diff[None, :]) * (np.arange(vchunk)[:, None] + v != np.arange(len(diff))[None, :])
    maxd = dists.max()
    md = (dists).ravel().argmax()
    row, col = jnp.unravel_index(md, dists.shape)
    # Convert the chunk-local row back to a global row index.
    row += v
    # return jnp.maximum((max_dist, rc), (maxd, (row, col)))
    # Keep the previous best unless this chunk strictly beats it.
    cond = maxd > max_dist
    max_dist = jax.lax.select(cond, maxd, max_dist)
    rc = jax.lax.select(cond, jnp.array([row, col]), rc)
    return (max_dist, rc), None
# Greedy selection of mutually distant outputs, two rows per round:
# each round scores every pair by its own distance plus both members' summed
# distance to the rows already chosen, and keeps the best-scoring pair.
# bs // 2 rounds yield bs rows when bs is even (one fewer when it is odd).
for _ in trange(bs // 2):
    diff = jnp.zeros((ao.shape[0],), dtype=jnp.float32)
    if existing:
        je = jnp.asarray(existing)
        # Bonus for each row: total distance to everything selected so far.
        diff += cdist(ao[je], ao).sum(axis=0)
        # Large penalty so an already selected row cannot win again.
        diff = diff.at[je].set(-1e6)
    max_dist, rc = 0.0, (0, 0)
    (max_dist, rc), _ = jax.lax.scan(partial(vscan, diff=diff), (max_dist, jnp.array(rc)), jnp.arange(0, len(diff), vchunk))
    # Store plain Python ints: iterating the array would append 0-d JAX arrays,
    # which print as "Array(..., dtype=int32)" and are slow to re-stack.
    existing.extend(rc.tolist())

print(existing)
# Decoder directions for the selected rows. `sae` must still be the
# layer_source suite here (a later cell rebinds the name).
vectors = sae["W_dec"][jnp.asarray(existing)]

  0%|          | 0/8 [00:00<?, ?it/s]

[Array(30630, dtype=int32), Array(31655, dtype=int32), Array(8047, dtype=int32), Array(15751, dtype=int32), Array(6755, dtype=int32), Array(7235, dtype=int32), Array(8541, dtype=int32), Array(20798, dtype=int32), Array(12098, dtype=int32), Array(15214, dtype=int32), Array(13952, dtype=int32), Array(15321, dtype=int32), Array(19986, dtype=int32), Array(22143, dtype=int32), Array(15336, dtype=int32), Array(26247, dtype=int32)]


In [23]:
from micrlhf.utils.activation_manipulation import add_vector
from micrlhf.sampling import sample
import numpy as np

# Steer at the layer the selected SAE decoder directions belong to.
layer = layer_source
repeat = 4  # how many times the whole set of vectors is repeated in the batch
# Each selected direction is rescaled to L2 norm `scale`; the set is tiled
# `repeat` times, so the batch is ordered v0..v(n-1), v0..v(n-1), ...
# position="last" matches the sweep, which added the vector at the last position only.
act_add = add_vector(llama,
                     jnp.tile(vectors / jnp.linalg.norm(vectors, axis=-1, keepdims=True) * scale, (repeat, 1)),
                     layer, position="last")

texts, cached = sample(act_add, tokenizer, prompt,
       batch_size=bs * repeat, do_sample=True, return_model=True)
# Regroup so each inner list holds the `repeat` samples of one vector.
# Assumes `texts` has bs * repeat entries in batch order and `vectors` has bs rows — TODO confirm.
np.array(texts).reshape(repeat, bs).T.tolist()

  0%|          | 0/42 [00:00<?, ?it/s]

In [24]:
# Optimize the steering vectors, starting from the random initialization.
addition = start_add
opt_state = optimizer.init(addition)
for _ in (bar := trange(iterations)):
    # `loss` is the negated spread from get_loss, so more negative is better.
    # `aux` from the final step is read by a later cell (aux["resid_out"]);
    # it stays undefined if `iterations` is 0.
    loss, addition, opt_state, aux = train_step(addition, opt_state)
    bar.set_postfix(loss=loss)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [25]:
from micrlhf.utils.activation_manipulation import add_vector
from micrlhf.sampling import sample
import numpy as np

# Sample with the optimized steering vectors at the layer they were trained for.
layer = layer_source
repeat = 4  # how many times the whole set of vectors is repeated in the batch
# Each optimized vector is rescaled to L2 norm `scale` (as in get_loss) and the
# set is tiled `repeat` times.
# NOTE(review): get_loss adds the vector at the last sequence position only and
# the SAE-vector sampling cell passes position="last", but this call relies on
# add_vector's default position. Confirm the default matches the training setup.
act_add = add_vector(llama,
                     jnp.tile(addition / jnp.linalg.norm(addition, axis=-1, keepdims=True) * scale, (repeat, 1)),
                     layer)

texts, cached = sample(act_add, tokenizer, prompt,
       batch_size=bs * repeat, do_sample=True, return_model=True)
# Regroup so each inner list holds the `repeat` samples of one vector.
# Assumes `texts` has bs * repeat entries in batch order — TODO confirm.
np.array(texts).reshape(repeat, bs).T.tolist()

  0%|          | 0/42 [00:00<?, ?it/s]

In [26]:
# from micrlhf.utils.activation_manipulation import add_vector
# from micrlhf.sampling import sample
# import numpy as np

# layer = layer_source
# repeat = 4
# act_add = add_vector(llama,
#                      jnp.tile(addition / jnp.linalg.norm(addition, axis=-1, keepdims=True) * scale, (repeat, 1)),
#                      layer)

# texts, cached = sample(act_add, tokenizer, prompt,
#        batch_size=bs * repeat, do_sample=True, return_model=True)
# np.array(texts).reshape(repeat, bs).T.tolist()

In [29]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite
from micrlhf.utils.ito import grad_pursuit
# Last-position residuals entering block `layer_target`, one row per steering
# vector, as computed during the final training step.
resid_out = aux["resid_out"]
# Offset of one sample (batch index -6) from the mean of the batch, unwrapped
# to a plain embedding-sized vector.
# NOTE(review): -6 is a hand-picked index, presumably the sample that showed the
# behavior the name "refusal" refers to; the mean is presumably taken over the
# batch axis — confirm both.
refusal = (resid_out[{"batch": -6}] - resid_out.untag("batch").mean()).unwrap("embedding")
# Rebinds `sae` from the layer_source suite to the layer_target suite; earlier
# cells that read sae["W_dec"] need the layer_source suite reloaded before a re-run.
sae = get_nev_it_sae_suite(layer_target)
k = 2  # sparsity budget handed to grad_pursuit
# Sparse reconstruction of the direction from decoder rows.
w, recon = grad_pursuit(refusal, sae["W_dec"], k, pos_only=True)
# Indices of the k largest-magnitude weights.
i = jax.lax.top_k(jnp.abs(w), k)[1]
# Displayed: mean squared reconstruction error, chosen indices, their weights.
((recon - refusal) ** 2).mean(), i, w[i]

In [30]:
from micrlhf.utils.activation_manipulation import set_direction


# Add the extracted direction directly at the target layer.
layer = layer_target
repeat = 64  # number of magnitudes tried, one per batch row
vector = refusal
# vector = recon
# Unit-normalize the direction and scale row r by the r-th of `repeat` evenly
# spaced magnitudes from 10 to 50.
# NOTE(review): no `position` argument is given, so add_vector's default applies.
act_add = add_vector(llama,
                    (vector / jnp.linalg.norm(vector, axis=-1, keepdims=True))[None, :] * jnp.linspace(10, 50, repeat)[:, None],
                     layer)
# act_add = set_direction(llama, vector, jnp.linspace(-100, 100, repeat), layer)

# do_sample=False: presumably deterministic decoding — confirm in micrlhf.sampling.
# [0] keeps the first element of the value sample() returns.
sample(act_add, tokenizer, prompt,
       batch_size=repeat, do_sample=False)[0]

  0%|          | 0/42 [00:00<?, ?it/s]

In [None]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite
from micrlhf.utils.ito import grad_pursuit
# One optimized steering vector (fourth from the end of the batch).
# NOTE(review): the index -4 is hand-picked — confirm it is the intended sample.
refusal_steer = addition[-4]
# SAE suite for the steering layer, kept under its own name so `sae` is not rebound again.
sae_steer = get_nev_it_sae_suite(layer_source)
k = 16  # sparsity budget handed to grad_pursuit
# Sparse reconstruction of the steering vector from decoder rows.
w, recon_steer = grad_pursuit(refusal_steer, sae_steer["W_dec"], k, pos_only=True)
# Indices of the k largest-magnitude weights.
i = jax.lax.top_k(jnp.abs(w), k)[1]
# Displayed: mean squared reconstruction error, chosen indices, their weights.
((recon_steer - refusal_steer) ** 2).mean(), i, w[i]

--2024-05-30 04:39:06--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l10-test-run-4-2.73E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.95, 108.156.211.51, 108.156.211.90, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.95|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/c8cf3f62062b3cdc9509167526de244fa3271fca14dd9dc97f653ef970374bd7?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1717303146&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNzMwMzE0Nn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvYzhjZjNmNjIw

In [None]:
# Steer at the source layer with the sparse reconstruction of the optimized vector.
layer = layer_source
repeat = 64  # number of coefficients tried, one per batch row
# vector = addition[-4]
vector = recon_steer
# Unit-normalize the direction and scale row r by the r-th of `repeat` evenly
# spaced coefficients from -2 * scale to 2 * scale.
# NOTE(review): no `position` argument is given, so add_vector's default applies.
act_add = add_vector(llama,
                     (vector / jnp.linalg.norm(vector, axis=-1, keepdims=True))[None, :] * jnp.linspace(-scale * 2, scale * 2, repeat)[:, None],
                     layer)

# do_sample=False: presumably deterministic decoding — confirm in micrlhf.sampling.
sample(act_add, tokenizer, prompt,
       batch_size=repeat, do_sample=False)

  0%|          | 0/46 [00:00<?, ?it/s]

: 