In [1]:
import os
# Cells below use paths like "models/..."; when the kernel starts one directory too deep
# (no "models" entry next to us), step up to the parent directory once.
entries_here = os.listdir(os.curdir)
if not any(entry == "models" for entry in entries_here):
    os.chdir(os.pardir)

In [2]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
# Notebook display setup via penzai's treescope alias (pz.ts).
# NOTE(review): presumably makes treescope the default renderer for IPython outputs — confirm.
pz.ts.register_as_default()
# NOTE(review): presumably registers treescope's auto-visualization cell magic — confirm.
pz.ts.register_autovisualize_magic()
# NOTE(review): presumably turns on penzai's interactive-notebook conveniences — confirm in penzai docs.
pz.enable_interactive_context()

In [4]:
from micrlhf.llama import LlamaTransformer
from transformers import AutoTokenizer


filename = "models/phi-3-16.gguf"
tokenizer_name = "microsoft/Phi-3-mini-4k-instruct"

# Weights come from the local GGUF checkpoint and are placed on device "tpu:0".
llama = LlamaTransformer.from_pretrained(filename, device_map="tpu:0")

# Tokenizer is taken from the Hugging Face hub checkpoint of the matching Phi-3 model.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer.padding_side = "right"  # pad after the real tokens, not before them

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
from micrlhf.llama import LlamaBlock
# Token-embedding table as a plain (vocabulary, embedding) array.
embeds = (
    llama.select()
    .at_instances_of(pz.nn.EmbeddingLookup)
    .get_sequence()[0]
    .table.embeddings.value.unwrap("vocabulary", "embedding")
)


def _tap_resid_pre(layer_index, block):
    """Prefix `block` with a TellIntermediate tagged ``resid_pre_{layer_index}`` (reports the block's input)."""
    tap = pz.de.TellIntermediate.from_config(tag=f"resid_pre_{layer_index}")
    return pz.nn.Sequential([tap, block])


# Swap the embedding lookup for Identity so the model consumes embedding vectors directly
# (compute_loss feeds it probability-weighted embeddings).
llama_without_embeds = llama.select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
get_resids = llama_without_embeds.select().at_instances_of(LlamaBlock).apply_with_selected_index(_tap_resid_pre)
# Collect every resid_pre_* intermediate as a side output returned alongside the logits.
get_resids = pz.de.CollectingSideOutputs.handling(
    get_resids,
    tag_predicate=lambda tag: tag.startswith("resid_pre"),
)
llama_without_embeds = get_resids

In [11]:
from micrlhf.utils.vector_storage import load_vector
# Layer whose residual stream compute_loss projects onto the refusal direction.
refusal_layer = 20
# NOTE(review): overwrite=True presumably re-fetches the remote copy on every run — confirm.
refusal_vector = load_vector(
    "phi-refusal-ablit",
    from_remote=True,
    overwrite=True,
)

In [12]:
from collections import OrderedDict
from jax_tqdm import scan_tqdm
import jax.numpy as jnp
import jax
n_steps = 2_000  # number of Langevin steps
n_vocab, embed_size = embeds.shape
seq_len = 16  # total prompt length in tokens, fixed prefix included
start_phrase = "<s><|user|>"
# The phrase already spells out the BOS token, so stop the tokenizer from adding special
# tokens of its own (`encode` adds them by default); a tokenizer configured to prepend BOS
# would otherwise yield a doubled <s>.
fixed_ids = tokenizer.encode(start_phrase, add_special_tokens=False)
fixed_tokens = len(fixed_ids)
# Out-of-range `.at[]` writes are silently dropped by JAX, so fail loudly instead.
assert fixed_tokens < seq_len, f"prefix of {fixed_tokens} tokens does not fit in seq_len={seq_len}"
# Index into the next-token predictions of the first prediction that targets a free token.
count_loss_from = max(0, fixed_tokens - 1)

def compute_loss(logits, refusal_weight=10.0):
    """Energy minimized by the Langevin updates on the soft prompt.

    Args:
        logits: soft-token logits of shape ``(seq_len, n_vocab)``; row ``t`` parameterizes a
            distribution over the vocabulary for position ``t``.
        refusal_weight: weight of the refusal-direction term (previously hard-coded as 10).

    Returns:
        Scalar ``fluency_ce - refusal_weight * projection``:
        ``fluency_ce`` is the cross-entropy between each soft token (from ``count_loss_from``
        on) and the model's next-token prediction for it; ``projection`` is the mean dot
        product of the last-position residual entering layer ``refusal_layer`` with
        ``refusal_vector``. Minimizing therefore favors fluent prompts with a large projection.
    """
    in_probs = jax.nn.softmax(logits, -1)
    # Expected input embedding per position: (seq_len, n_vocab) @ (n_vocab, embedding).
    in_resid = in_probs @ embeds
    named_array = pz.nx.NamedArray(
        OrderedDict({"batch": 1, "seq": seq_len, "embedding": embed_size}),
        in_resid[None, ...],
    )
    out_logits, residuals = llama_without_embeds(llama.inputs.from_basic_segments(named_array))
    out_logprobs = jax.nn.log_softmax(out_logits.unwrap("batch", "seq", "vocabulary")[0], -1)
    # Position t predicts token t + 1. Negated so that descending on this energy (as
    # update_logits does) *raises* the likelihood of the soft tokens rather than lowering it.
    fluency_ce = -(in_probs[1:] * out_logprobs[:-1])[count_loss_from:].sum(-1).mean()
    # NOTE(review): assumes side outputs are collected in layer order, so index == layer.
    last_resid = residuals[refusal_layer].value[{"seq": -1}].unwrap("batch", "embedding")
    projection = (last_resid @ refusal_vector).mean()
    return fluency_ce - refusal_weight * projection

def set_token(logits, index, token_id):
    """Return a copy of ``logits`` whose row ``index`` is 0 at ``token_id`` and -1e9 elsewhere.

    After softmax, that row is (numerically) one-hot on ``token_id``; every other row is left
    untouched.
    """
    vocab_positions = jnp.arange(logits.shape[1])
    pinned_row = jnp.where(vocab_positions == token_id, 0.0, -1e9)
    return logits.at[index].set(pinned_row)

def _pin_prefix(logits):
    """Force rows ``0..fixed_tokens-1`` of ``logits`` onto the token ids in ``fixed_ids``."""
    for i, token_id in enumerate(fixed_ids):
        logits = set_token(logits, i, token_id)
    return logits


@scan_tqdm(n_steps)
def update_logits(params, hparams):
    """Run one Langevin-dynamics step on the soft-prompt logits inside ``jax.lax.scan``.

    Args:
        params: scan carry ``(logits, key)``; ``logits`` has shape ``(seq_len, n_vocab)``.
        hparams: per-step scan input ``(step, sigma, nu)``. ``scan_tqdm`` reads the FIRST
            element of a tuple input as its iteration counter, so the step index must lead.
            ``sigma`` scales the Gaussian noise and ``nu`` is the gradient step size. A bare
            ``(sigma, nu)`` tuple is still accepted, but then the progress bar misreads sigma.

    Returns:
        ``((logits, key), None)``: the updated carry and no per-step output.
    """
    logits, key = params
    *_, sigma, nu = hparams
    # Pin the prompt prefix so the gradient is evaluated with the fixed tokens in place.
    logits = _pin_prefix(logits)
    grad = jax.grad(compute_loss)(logits)
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, logits.shape)
    # Gradient descent on the energy plus injected Gaussian noise.
    logits = logits - nu * grad + sigma * noise
    # Re-pin so the carried (and final) logits never hold a noise-perturbed prefix.
    return (_pin_prefix(logits), key), None


# Independent keys for the initial logits and for the Langevin noise chain.
init_key, noise_key = jax.random.split(jax.random.key(0))
logits = jax.random.normal(init_key, (seq_len, n_vocab)) * 0.1
# Throughout the experiments, we set the number of Langevin dynamics steps to N = 2000, with a step size η = 0.1
nus = jnp.full(n_steps, 0.1)
# In our experiments, we typically used the schedule which sets/reduces σ to {1, 0.5, 0.1, 0.05, 0.01} at iterations {0, 50, 500, 1000, 1500}
sigmas = jnp.full(n_steps, 1.).at[50:].set(0.5).at[500:].set(0.1).at[1000:].set(0.05).at[1500:].set(0.01)
# Step indices lead the scan inputs so scan_tqdm sees a real iteration counter.
steps = jnp.arange(n_steps)

final_carry, _scan_outputs = jax.lax.scan(update_logits, (logits, noise_key), (steps, sigmas, nus))
new_logits = final_carry[0]
tokenizer.decode(new_logits.argmax(-1))

In [13]:
from penzai.toolshed.jit_wrapper import Jitted
# Full model (embedding lookup included, unlike llama_without_embeds) wrapped in penzai's
# `Jitted`, for repeated hard-token forward passes in the discretization loop.
# NOTE(review): presumably compiles once and reuses the compiled call — confirm in penzai docs.
ljit = Jitted(llama)

In [14]:
from tqdm.auto import trange
# Turn the optimized soft logits into hard tokens. For every free position, keep the k
# tokens the soft logits rank highest and sample among them using the model's own
# next-token logits, given the hard tokens chosen so far.
tokens = new_logits.argmax(-1)
k = 5
for i in trange(count_loss_from, seq_len - 1):
    key = jax.random.key(i)
    named_array = pz.nx.NamedArray(OrderedDict({"batch": 1, "seq": seq_len}), tokens[None, ...])
    # The model's logits at position i predict the token at position i + 1.
    logits_predicted = ljit(llama.inputs.from_basic_segments(named_array)).unwrap("batch", "seq", "vocabulary")[0][i]
    # Candidates come from the soft logits of the position being filled (i + 1). Using row i
    # would draw the first free token from the last fixed token's row.
    _, possible_tokens = jax.lax.top_k(new_logits[i + 1], k)
    # jax.random.categorical expects unnormalized log-probabilities, so pass the logits
    # directly; feeding softmax probabilities would flatten the distribution to near-uniform.
    choice = possible_tokens[jax.random.categorical(key, logits_predicted[possible_tokens])]
    tokens = tokens.at[i + 1].set(choice)
tokenizer.decode(tokens)

  0%|          | 0/14 [00:00<?, ?it/s]