In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [2]:
from micrlhf.llama import LlamaTransformer
from micrlhf.scan import sequential_to_scan
llama = LlamaTransformer.from_pretrained("models/gemma-2-2b-it-q4_k_s.gguf", device_map="tpu:0",
                                         from_type="gemma2",
                                         load_on_cpu=True,
                                         )
llama = sequential_to_scan(llama)
llama = llama.to_tpu()

In [3]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("NousResearch/gemma-2b-it-tokenizer")
prompt = tokenizer.apply_chat_template([
    {"role": "user", "content": "Who are you?"},
    {"role": "assistant", "content": "Hello, I am a language model that exists and stuff. How can I help you today?"},
] * 1_000, tokenize=False)

In [4]:
from penzai.toolshed import jit_wrapper
from micrlhf.flash import flashify
tokens = pz.nx.wrap([tokenizer.encode(prompt)[:128]] * 128, "batch", "seq")
inputs = llama.inputs.from_basic_segments(tokens)
llama_jitted = jit_wrapper.Jitted(llama)

In [5]:
import jax
import jax.numpy as jnp
@jax.jit
def lfn(llama_jitted, inputs):
    logits = llama_jitted(inputs)
    loss = -pz.nx.nmap(lambda l, t: jnp.take_along_axis(jax.nn.log_softmax(l[:-1], -1), t[1:, None], 1).mean())(logits.untag("seq", "vocabulary"), tokens.untag("seq"))
    return loss
print(lfn(llama_jitted, inputs).unwrap("batch"))

[27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25
 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25
 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25
 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25
 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25
 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25
 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25
 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25
 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25
 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25
 27.25 27.25 27.25 27.25 27.25 27.25 27.25 27.25]


In [8]:
import jax.numpy as jnp
import numpy as np
import jax
from functools import partial
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.sharding import PartitionSpec as P


compute = lambda f: (f(jnp.arange(256, dtype=jnp.uint8)), (f(jnp.arange(256, dtype=jnp.uint8).astype(jnp.int8))).astype(jnp.uint8))

# https://github.com/99991/pygguf/blob/829886d0726c89c6f6c0d8c39b0d507ec1604077/gguf.py#L209
def matmul_4bit(inputs, scale_factors, scale_offsets, qs1, qs2, based=False):
    num_blocks = scale_factors.shape[0]
    # assert num_blocks == scale_offsets.shape[0] == qs1.shape[0] == qs2.shape[0] == 8

    sr = jax.lax.shift_right_logical

    def switch(x):
        x = x.astype(jnp.float16)
        # x = x.astype(jnp.float32) * 10
        x = x.astype(jnp.float32)
        return x
    scale_factors = switch(scale_factors)
    scale_offsets = switch(scale_offsets)
    if based:
        scale_factors = scale_factors.transpose(1, 0, 2).reshape(1, 1, num_blocks, -1).astype(jnp.float32)
        scale_offsets = scale_offsets.transpose(1, 0, 2).reshape(1, 1, num_blocks, -1).astype(jnp.float32)
        qs1 = qs1.transpose(1, 0, 2).reshape(12, 1, num_blocks, -1)
        qs2 = qs2.transpose(1, 0, 2).reshape(4, 32, num_blocks, -1)
    else:
        scale_factors = scale_factors.transpose(0, 2, 1).reshape(-1, 1, 1).astype(jnp.float32)
        scale_offsets = scale_offsets.transpose(0, 2, 1).reshape(-1, 1, 1).astype(jnp.float32)
        qs1 = qs1.transpose(0, 2, 1).reshape(-1, 12, 1)
        qs2 = qs2.transpose(0, 2, 1).reshape(-1, 4, 32)

    qs1 = qs1.astype(jnp.int32)
    qs2 = qs2.astype(jnp.int32)
    i8tou8 = lambda x: jnp.where(x < 0, 256 + x, x)
    qs1 = i8tou8(qs1)
    qs2 = i8tou8(qs2)


    # max 63
    if based:
        chunk1 = qs1[0:4]
        chunk2 = qs1[4:8]
        chunk3 = qs1[8:]
        factor_scale = jnp.concatenate([chunk1 & 0b111111, (chunk3 & 15) | (sr(chunk1, 6) << 4)], axis=0)
        offset_scale = jnp.concatenate([chunk2 & 0b111111, (sr(chunk3, 4) % 16) | (sr(chunk2, 6) << 4)], axis=0)
    else:
        chunk1 = qs1[:, 0:4]
        chunk2 = qs1[:, 4:8]
        chunk3 = qs1[:, 8:]
        factor_scale = jnp.concatenate([chunk1 & 0b111111, (chunk3 & 15) | (sr(chunk1, 6) << 4)], axis=1)
        offset_scale = jnp.concatenate([chunk2 & 0b111111, (sr(chunk3, 4) % 16) | (sr(chunk2, 6) << 4)], axis=1)

    basify = lambda x: x  # x.astype(jnp.int8) if not based else x
    factors = scale_factors * basify(factor_scale)
    offsets = scale_offsets * basify(offset_scale)

    # max 15
    if based:
        # qs2 = jnp.stack([qs2 & 0xf, sr(qs2, 4)], axis=1).reshape(8, 32, num_blocks, -1)
        qs2 = jnp.concatenate([qs2 & 0xf, sr(qs2, 4)], axis=1).reshape(8, 32, num_blocks, -1)
    else:
        qs2 = jnp.stack([qs2 & 0xf, sr(qs2, 4)], axis=2).reshape(-1, 8, 32)

    matrix = factors * basify(qs2) - offsets
    # matrix = basify(qs2).astype(jnp.float32)
    if not based:
        matrix = matrix.reshape(num_blocks, -1, 256).transpose(0, 2, 1)
        return inputs @ matrix.reshape(inputs.shape[-1], -1)
    else:
        matrix = matrix.reshape(256, num_blocks, -1)
        inputs = inputs.reshape(inputs.shape[0], num_blocks, 256)
        result = jax.lax.dot_general(inputs, matrix, (((2, 1), (0, 1)), ((), ())))
        return result


In [7]:
a, b, c = 256 * 32, 1024, 1024
bs = 256
key = jax.random.key(0)
key_scale_factors, key_scale_offsets, key_qs1, key_qs2 = jax.random.split(key, 4)
scale_factors = jax.random.normal(key_scale_factors, (a // bs, 1, b), dtype=jnp.bfloat16)
scale_offsets = jax.random.normal(key_scale_offsets, (a // bs, 1, b), dtype=jnp.bfloat16)
qs1 = jax.random.randint(key_qs1, (a // bs, 12, b), 0, 255, dtype=jnp.uint8).view(jnp.int8)
qs2 = jax.random.randint(key_qs2, (a // bs, 128, b), 0, 255, dtype=jnp.uint8).view(jnp.int8)
inputs = jax.random.normal(jax.random.PRNGKey(2), (c, a), dtype=jnp.bfloat16) / (a ** 0.5)

NameError: name 'matmul_4bit' is not defined

In [9]:
d = matmul_4bit(inputs, scale_factors, scale_offsets, qs1, qs2)
print(d)
e = matmul_4bit(inputs, scale_factors, scale_offsets, qs1, qs2, based=True)
print(d - e)
print(jnp.abs(d - e).max())

[[ 239.15594    743.40454   -211.24771   ...   56.04226    170.22482
  -286.69595  ]
 [  -2.4591599  551.1879    -373.605     ...  -27.097576   149.38379
   -57.147476 ]
 [-220.97462   -212.58714    357.2976    ...  342.2035    -220.43375
   257.03244  ]
 ...
 [-314.5901      95.02302   -653.9875    ...  233.09988   -152.33174
   305.01868  ]
 [  64.707375   -47.08348   -289.0493    ... -278.11423    165.84293
  -159.15613  ]
 [  -0.9532623  356.84503   -123.46204   ...  -80.8894     394.97888
  -458.64697  ]]
[[ 0.0000000e+00  1.2207031e-04  0.0000000e+00 ...  1.9073486e-05
  -1.0681152e-04 -9.1552734e-05]
 [ 2.2888184e-05  1.2207031e-04  1.8310547e-04 ...  1.1444092e-04
   1.5258789e-05  4.1961670e-05]
 [-4.5776367e-05  9.1552734e-05 -1.2207031e-04 ... -3.0517578e-05
   7.6293945e-05  0.0000000e+00]
 ...
 [-1.2207031e-04 -5.3405762e-05 -6.1035156e-05 ...  0.0000000e+00
   0.0000000e+00  6.1035156e-05]
 [ 5.3405762e-05 -3.8146973e-06  9.1552734e-05 ...  1.5258789e-04
  -1.5258789e-05 

In [11]:
from micrlhf.quantizers import matmul_4bit_kernel, matmul_fast
mesh = jax.sharding.Mesh(np.array(jax.devices("tpu")[:1]).reshape((1, 1)), ("dp", "mp"))
f = matmul_fast(
    inputs,#.reshape(-1, 8, 256).swapaxes(1, 2).reshape(inputs.shape),
    # inputs,
    scale_factors.swapaxes(0, 1).astype(jnp.float16), scale_offsets.swapaxes(0, 1).astype(jnp.float16),
    qs1.swapaxes(0, 1), qs2.swapaxes(0, 1),
    kernel=matmul_4bit_kernel, mesh=mesh, in_axis=None, out_axis=None, is_transpose=True).block_until_ready()

In [12]:
jnp.abs(f - e).max(), jnp.abs(f).mean(), jnp.abs(e).mean()

In [6]:
 import jax; jax.print_environment_info()

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
jax.devices (4 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0) TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-43c1cc56-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')

