In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [12]:
from micrlhf.llama import LlamaTransformer
from micrlhf.scan import sequential_to_scan
llama = LlamaTransformer.from_pretrained("models/gemma-2-2b-it-q4_k_s.gguf", device_map="tpu:0",
                                         from_type="gemma2",
                                         load_on_cpu=True,
                                         )
llama = sequential_to_scan(llama)
llama = llama.to_tpu()

In [13]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("NousResearch/gemma-2b-it-tokenizer")
prompt = tokenizer.apply_chat_template([
    {"role": "user", "content": "Who are you?"},
    {"role": "assistant", "content": "Hello, I am a language model that exists and stuff. How can I help you today?"},
] * 1_000, tokenize=False)

In [14]:
from penzai.toolshed import jit_wrapper
from micrlhf.flash import flashify
tokens = pz.nx.wrap([tokenizer.encode(prompt)[:64]] * 256, "batch", "seq")
inputs = llama.inputs.from_basic_segments(tokens)
llama_jitted = jit_wrapper.Jitted(llama)

In [18]:
import jax
import jax.numpy as jnp
@jax.jit
def lfn(llama_jitted, inputs):
    logits = llama_jitted(inputs)
    loss = -pz.nx.nmap(lambda l, t: jnp.take_along_axis(jax.nn.log_softmax(l[:-1], -1), t[1:, None], 1).mean())(logits.untag("seq", "vocabulary"), tokens.untag("seq"))
    return loss
print(lfn(llama_jitted, inputs).unwrap("batch"))

[26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75
 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26.75 26

In [None]:
 import jax; jax.print_environment_info()