In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
# Penzai notebook setup: the `pz.ts` (treescope) calls configure how values are
# rendered in this notebook, and the last call enables penzai's interactive
# context helpers. These mutate global display state, so run this cell once.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [2]:
from micrlhf.gguf import GGUFReader
import numpy as np
# Compare layer-0 ffn_down weights of the base checkpoint and the "abl" checkpoint,
# and take a singular vector of their difference as the scoring direction
# (`vector` is later dotted with the residual stream's "embedding" axis).
phi = GGUFReader("models/phi-3-16.gguf")
abl = GGUFReader("models/abl.gguf")
# Each tensor entry unpacks as (a, (data,), shape). `c` is reversed before reshaping
# the flat data — presumably GGUF lists dims fastest-varying first. TODO confirm.
a, (b,), c = phi["blk.0.ffn_down.weight"]
_, (b_,), _ = abl["blk.0.ffn_down.weight"]
# NOTE(review): the subtraction happens in the arrays' stored dtype; the cast to
# float32 is applied only afterwards, for the SVD.
diff = b.reshape(c[::-1]) - b_.reshape(c[::-1])
u, s, vt = np.linalg.svd(diff.astype(np.float32), full_matrices=False)
# NOTE(review): this takes the *second* left singular vector (index 1). If the
# weight delta is dominated by a single direction, u[:, 0] would be the usual
# choice — confirm that index 1 is intended.
vector = u[:, 1]

In [3]:
from micrlhf.llama import LlamaTransformer

# Same GGUF checkpoint as `phi` above, now loaded as a runnable model,
# placed according to device_map="tpu:0".
filename = "models/phi-3-16.gguf"
llama = LlamaTransformer.from_pretrained(filename, device_map="tpu:0")

In [4]:
from transformers import AutoTokenizer
# Hugging Face tokenizer matching the Phi-3 mini checkpoint loaded above.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
# Pad on the right when batching.
tokenizer.padding_side = "right"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
from micrlhf.llama import LlamaBlock
from micrlhf.flash import flashify
from micrlhf.sampling import sample, trange, jnp, load_tokenizer, jit_wrapper
def _with_resid_pre_tag(block_index, block):
    """Prefix `block` with a TellIntermediate tagged "resid_pre.<block_index>".

    The tag sits in front of the block inside a Sequential, so the value it
    reports is the residual stream entering that block.
    """
    return pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre.{block_index}"),
        block,
    ])


get_resids = (
    llama.select()
    .at_instances_of(LlamaBlock)
    .apply_with_selected_index(_with_resid_pre_tag)
)
# Gather every side output whose tag starts with "resid_pre"; calling the
# result returns (logits, collected_side_outputs).
get_resids = pz.de.CollectingSideOutputs.handling(
    get_resids,
    tag_predicate=lambda tag_name: tag_name.startswith("resid_pre"),
)
get_resids_call = jit_wrapper.Jitted(get_resids)

In [8]:
from micrlhf.sampling import sample, trange, jnp, load_tokenizer, jit_wrapper
from tqdm.auto import trange
from penzai.toolshed import sharding_util
import random
from functools import partial
import jax

@jax.jit
def loss_fn(logits, inputs):
    """Per-sequence mean next-token negative log-likelihood.

    For every sequence, the log-softmax at position t is scored against the
    token at position t + 1, and the log-probabilities are averaged over all
    positions (no prompt/padding mask is applied).

    Args:
        logits: NamedArray with named "seq" and "vocabulary" axes.
        inputs: model inputs whose ``tokens`` NamedArray has a "seq" axis.

    Returns:
        NamedArray over the remaining (batch-like) axes holding the loss.
    """
    def _mean_next_token_logprob(seq_logits, seq_tokens):
        log_probs = jax.nn.log_softmax(seq_logits[:-1], -1)
        next_tokens = seq_tokens[1:, None]
        picked = jnp.take_along_axis(log_probs, next_tokens, 1)
        return picked[:, 0].mean()

    mean_logprob = pz.nx.nmap(_mean_next_token_logprob)(
        logits.untag("seq", "vocabulary"),
        inputs.tokens.untag("seq"),
    )
    return -mean_logprob

# Size of the "dp" mesh axis; tokens_to_array only shards batches at least this big.
bs_start = llama.mesh.shape["dp"]
# Number of elite prompts kept between search iterations.
MAX_ELITES = 4
tokens_init = tokenizer.encode("<|user|>\nX X X X X X X X X X X X X X X X X X X X<|end|>\n<|assistant|>\n")
# NOTE(review): 1060 is presumably the token id of the "X" placeholder — confirm
# with the tokenizer. True entries mark the positions the search may rewrite.
optim_mask = [token_id == 1060 for token_id in tokens_init]
# One copy of the prompt per elite: shape (MAX_ELITES, prompt_length).
tokens_init = np.tile(np.asarray(tokens_init), (MAX_ELITES, 1))
seed = 21
np.random.seed(seed)
# Replace every placeholder with a random token id in [100, tokenizer.vocab_size).
tokens_init[:, optim_mask] = np.random.randint(
    100, tokenizer.vocab_size, (MAX_ELITES, sum(optim_mask))
)
def tokens_to_array(tokens):
    """Wrap a 2-D token-id array as a NamedArray with axes ("batch", "seq").

    The array is placed on ``llama.mesh`` sharded as ("dp", "sp") only when
    that sharding is valid:
      - the batch has at least ``bs_start`` rows and divides evenly over "dp";
      - the sequence length divides evenly over "sp".
    Otherwise it keeps JAX's default placement. Previously any batch with
    ``len >= bs_start`` was sharded, which fails for e.g. 5 rows on a dp=4 mesh.

    Args:
        tokens: array-like of shape (batch, seq) holding token ids.

    Returns:
        pz.nx.NamedArray with named axes "batch" and "seq".
    """
    token_array = jnp.asarray(tokens)
    n_rows, seq_len = token_array.shape
    sp_size = llama.mesh.shape["sp"]
    # device_put with a NamedSharding rejects a sharded dimension that is not
    # divisible by the size of its mesh axis, so check before sharding.
    shardable = (
        n_rows >= bs_start
        and n_rows % bs_start == 0
        and seq_len % sp_size == 0
    )
    if shardable:
        token_array = jax.device_put(
            token_array,
            jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")),
        )
    return pz.nx.wrap(token_array, "batch", "seq")
def run_tokens(token_array):
    """Run the model on a batch of tokens and collect residuals.

    Args:
        token_array: a pz.nx.NamedArray with "batch"/"seq" axes, or anything
            tokens_to_array accepts (it is converted first).

    Returns:
        (logits, losses, resids): the model logits, the per-sequence loss from
        loss_fn, and a dict mapping each "resid_pre.*" tag to its value.
    """
    named_tokens = (
        token_array
        if isinstance(token_array, pz.nx.NamedArray)
        else tokens_to_array(token_array)
    )
    inputs = llama.inputs.from_basic_segments(named_tokens)
    logits, side_outputs = get_resids_call(inputs)
    resid_by_tag = {}
    for side_output in side_outputs:
        resid_by_tag[side_output.tag] = side_output.value
    return logits, loss_fn(logits, inputs), resid_by_tag

# Boolean mask over "seq" marking the optimisable placeholder positions,
# sharded along the "sp" mesh axis.
mask = jax.device_put(jnp.asarray(optim_mask), jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("sp")))


# Runs eagerly: the model forward passes inside run_tokens already go through
# the jitted get_resids_call and the @jax.jit loss_fn.
def algo_iteration(elites, vector, key="resid_pre.16", candidates=1024, seed=13, expected_changes=3):
    """One generation of the evolutionary prompt search.

    Each candidate copies a randomly chosen elite and re-draws some of its
    optimisable positions (where ``mask`` is True). New tokens are sampled from
    the elite's own next-token logits, shifted by one position and scaled by a
    random factor in [0.5, 2).

    Args:
        elites: NamedArray with named "solutions" and "seq" axes of token ids.
        vector: direction broadcastable against the residual's "embedding"
            axis; candidates are scored by its dot product with the residual.
        key: tag of the residual side output used for scoring.
        candidates: number of mutated candidates to generate.
        seed: seed for the PRNG key driving all mutations.
        expected_changes: expected number of extra positions re-drawn per
            candidate; one optimisable position is always re-drawn as well.

    Returns:
        (solutions, metrics):
            solutions: (candidates, seq) token array.
            metrics: (candidates, 2) array with columns (loss, score).
    """
    elites = elites.untag("solutions").tag("batch")
    logits = run_tokens(elites)[0].untag("batch").tag("elites")

    def temper(logits, rng, elites):
        key_choice, key_random = jax.random.split(rng)
        # randint's maxval is exclusive: use len(logits) so every elite can be
        # picked (`len(logits) - 1` never selected the last elite).
        index = jax.random.randint(key_choice, (), 0, len(logits))
        key_categorical, key_uniform, key_bernoulli, key_randint = jax.random.split(key_random, 4)
        logit = logits[index]
        elite = elites[index]
        # Shift along seq so position t holds the distribution predicted for token t.
        logit = jnp.roll(logit, 1, 0)
        # Random temperature: scale the logits by a factor in [0.5, 2).
        logit = logit * jax.random.uniform(key_uniform, minval=0.5, maxval=2)
        to_change = jax.random.bernoulli(key_bernoulli, expected_changes / mask.sum(), mask.shape)
        # Force at least one change by picking uniformly among *optimisable*
        # positions; drawing over the whole sequence usually landed outside
        # `mask` and was then discarded by `mask & to_change`.
        definite_indices = jax.random.categorical(key_randint, jnp.where(mask, 0.0, -jnp.inf))
        definite_mask = jax.nn.one_hot(definite_indices, to_change.shape[-1], dtype=jnp.bool_)
        to_change = to_change | definite_mask
        return jnp.where(mask & to_change,
                         jax.random.categorical(key_categorical, logit),
                         elite)

    tempered_samples = pz.nx.nmap(temper)(
        logits.untag("elites", "seq", "vocabulary"),
        pz.nx.wrap(jax.random.split(jax.random.key(seed), candidates), "batch"),
        elites.untag("batch", "seq")).tag("seq")
    tempered_samples = sharding_util.name_to_name_device_put(tempered_samples, llama.mesh, dict(batch="dp", seq="sp"))
    _, new_losses, new_resids = run_tokens(tempered_samples)
    # Score: dot product of the last-position residual at `key` with `vector`.
    new_scores = (new_resids[key][{"seq": -1}].untag("embedding") * vector).sum().astype(new_losses.dtype)
    metrics = pz.nx.nmap(lambda *xs: jnp.stack(xs))(new_losses, new_scores).tag("metrics")
    solution_axes = [k for k in tempered_samples.named_shape.keys() if k != "seq"]
    solutions = tempered_samples.untag(*solution_axes).flatten().tag("solutions").unwrap("solutions", "seq")
    metrics = metrics.untag(*solution_axes).flatten().tag("solutions").unwrap("solutions", "metrics")
    return solutions, metrics


best_metrics = None
best = tokens_to_array(tokens_init).untag("batch").tag("solutions")
xent_min = 1
xent_max = 10
# Row j scalarises a candidate's (loss, score) as
# weights[j, 0] * loss + weights[j, 1] * score. The loss weights run from
# -xent_max to -xent_min, so each elite slot trades loss against score differently.
weights = jnp.stack((
    jnp.linspace(-xent_max, -xent_min, MAX_ELITES),
    jnp.ones(MAX_ELITES),
), -1)
for seed in (bar := trange(1_000)):
    solutions, metrics = algo_iteration(best, vector, seed=seed)
    if best_metrics is None:
        # First generation: nothing to pool with yet.
        best_metrics = metrics
        best = solutions
    else:
        # Pool the previous elites with the new candidates as plain (N, seq) arrays.
        best_metrics = jnp.concatenate((best_metrics, metrics), 0)
        best = jnp.concatenate((best.unwrap("solutions", "seq"), solutions), 0)
    # For each weight row, keep the pooled candidate with the highest weighted sum.
    elite_mask = (best_metrics[None, :] * weights[:, None]).sum(-1).argmax(1)
    best_metrics = best_metrics[elite_mask]
    best = pz.nx.wrap(best[elite_mask], "solutions", "seq")
    postfix = {}
    for index in range(MAX_ELITES):
        postfix[f"decoded.{index}"] = tokenizer.decode(best[{"solutions": index}].unwrap("seq"))
        postfix[f"loss.{index}"] = best_metrics[index][0]
        postfix[f"score.{index}"] = best_metrics[index][1]
    bar.set_postfix(**postfix)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
# `inputs` only exists as a local inside run_tokens, so build it explicitly here
# (from the current elite prompts) to keep this cell runnable on a fresh kernel.
inputs = llama.inputs.from_basic_segments(best.untag("solutions").tag("batch"))
_, resids = get_resids_call(inputs)