In [1]:
import os
os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
os.environ['JAX_COMPILATION_CACHE_DIR'] = '/tmp/jax_cache'
# os.environ['TPU_CHIPS_PER_PROCESS_BOUNDS'] = '1,1,1'
# os.environ['TPU_PROCESS_BOUNDS'] = '1,1,1'
# os.environ['TPU_VISIBLE_DEVICES'] = '0'
import jax
import jax.numpy as jnp
import operator as op
from functools import partial
from tqdm.auto import tqdm
from models import qwen3
from typing import NamedTuple
from jax.sharding import PartitionSpec as P, reshard
%load_ext autoreload
%autoreload 2

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


# training

In [2]:
model = qwen3.load('Qwen/Qwen3-8B-Base')
model.weights = jax.tree.map(lambda x: x.astype(jnp.float32), model.weights)

In [3]:
def cross_entropy(logits, labels):
    log_softmax = jax.nn.log_softmax(logits.astype(jnp.float32))
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.sum(one_hot * log_softmax, axis=-1)

def loss_fn(forward, weights, x):
    y = jnp.roll(x, -1, axis=1)
    logits, _ = forward(x, weights)
    loss = cross_entropy(logits, y)
    return loss.mean()

@partial(jax.jit, static_argnames=('forward'), donate_argnames=('weights'))
def train_step(forward, weights, x):
    loss, grads = jax.value_and_grad(loss_fn, argnums=1)(forward, weights, x)
    weights = jax.tree.map(lambda w, g: w - 0.001*g, weights, grads)
    return weights, loss

B, T = 128, 512
x = jax.random.randint(jax.random.key(0), [B, T], 0, 1<<15, jnp.int32)
for i in tqdm(range(200)):
    model.weights, loss = train_step(model.forward, model.weights, x)

  0%|          | 0/200 [00:00<?, ?it/s]

In [6]:
# get number of model params
n_params = jax.tree.reduce_associative(op.add, jax.tree.map(jnp.size, model.weights))
print(f'{n_params=:_}')

# compute flops per fwd pass
flops_per_sec = 8 * 918 * 10**12
flops_per_iter = 6 * n_params * B * T
iters_per_sec_100mfu = flops_per_sec / flops_per_iter

# compute mfu
measured_iters_per_sec = 1/1.44
mfu = measured_iters_per_sec / iters_per_sec_100mfu
print(f'{mfu=:.1%}')

n_params=8_190_735_360
mfu=30.5%


# sampling

In [2]:
class SamplingState(NamedTuple):
    key: jax.Array
    tokens: jax.Array # [B, T]
    kv: jax.Array # [B, T, ...]

def _sample_step(i, state, forward, weights):

    # sample next token
    key, key_sampling = jax.random.split(state.key)
    input_token = state.tokens[:, i, None] # [B, 1]
    logits, kv = forward(input_token, weights, state.kv, i) # [B, 1, V]
    sampled_token = jax.random.categorical(key_sampling, logits[:, 0, :])

    # update buffer
    tokens = state.tokens.at[:, i+1].set(sampled_token)

    return SamplingState(key, tokens, kv)

@partial(jax.jit, static_argnames=('forward'))
def sample(key, forward, weights, tokens):
    B, T = tokens.shape
    tokens = reshard(tokens, P('data', None))

    # initialize state
    state = SamplingState(
        key=key,
        tokens=tokens,
        kv=model.init_kv(B, T),
    )

    # sample next token inside a for loop
    step_fn = lambda i, state: _sample_step(i, state, forward, weights)
    state = jax.lax.fori_loop(0, T, step_fn, state)

    return state.tokens

In [3]:
model = qwen3.load('Qwen/Qwen3-32B', tp_devices=8)

In [32]:
# run this cell twice to get the sampling time excluding jit compilation time #
key = jax.random.key(0)
B, T = 1, 512
# B, T = 512, 512
# B, T = 128, 4096
x = jax.random.randint(key, [B, T], 0, 1<<15, jnp.int32)
sample(key, model.forward, model.weights, x)

Array([[  869,   345,   197,   197,   322,   714,   582,  1366,   419,
          311,   387,  9434,   304,   279,  9308,  1895,   624,   197,
         7719,    25,   830,   198,   197,  4546,   853,  1102,   280,
         2315,   947,  2038, 12561,   284,   501,  6119, 12561,  2129,
          198,   198,  1851, 12561, 17487,  1006,   197,     1, 40923,
          756,   197,     1,  1688, 33428,   320,   606,     8, 28152,
           77,     1,  3610,   197,     1,   197,   322,  4692,  1699,
            1,  3610,   197,     1,  2405,   429,   284,   419, 17882,
           77,     1,   488,   198,   197, 11934,    77,     1,  3610,
          197,     1,   197,   322,   869,  3890,  1699,     1,  3610,
          197,     1,  2405,   264,   284,   220,    16, 17882,    77,
            1,  3610,   197, 11934,    77,     1,  3610,   197,     1,
          197,   322,   869,   729,  1699,     1,  3610,   197,     1,
         2405,   869,   284,   729,  1719, 28152,    77,     1,  3610,
      

In [31]:
# kv cache size (L, 2, B, T, K, H)
device_mem_bwd = 8 * 1600 * 10**9 # bytes / second
L, K, H = 64, 8, 128
kv_size = (B * T) * (L * K * H) * 2 * 2 # k+v, 16-bit precision
print(f'{kv_size=:_}')
model_params = jax.tree.reduce_associative(op.add, jax.tree.map(jnp.size, model.weights))
model_size = model_params * 2 # 16-bit precision
print(f'{model_size=:_}')
time_100_util = (T * (kv_size + model_params)) / device_mem_bwd
print(time_100_util)
actual_time = 4.76 # seconds
bwd_util = time_100_util / actual_time
print(f'{bwd_util=:.1%}')

kv_size=134_217_728
model_size=65_524_246_528
1.31585363968
bwd_util=27.6%
