In [412]:
import functools
import importlib

In [413]:
from flax import nnx
import jax
import jax.numpy as jnp
from jax import lax
import optax

In [389]:
import layers as L

importlib.reload(L)

<module 'layers' from '/home/tianjiao/learn-jax/layers.py'>

In [6]:
with open("../assets/红楼梦.txt", "r") as f:
    data = f.read()

In [220]:
data = ""
for i in range(0, 10):
    for j in range(0, 10):
        s = f"{i}+{j}={i + j}\n"
        data += s

In [222]:
vocab = sorted(list(set(data)))
print("vocab: " + "".join(vocab[:-50]) + " ...")
print(f"vocab len: {len(vocab)}")

stoi = {ch: i for i, ch in enumerate(vocab)}
iots = {i: ch for i, ch in enumerate(vocab)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([iots[i] for i in l])
train_data = jnp.array(encode(data))

vocab:  ...
vocab len: 13


In [316]:
batch_size = 8
# context window length
block_size = 11

In [317]:
dynamic_slice_vmap = jax.vmap(lax.dynamic_slice, in_axes=(None, 0, None))


@jax.jit
def get_batch(random_key, data):
    ix = jax.random.randint(
        random_key, shape=(batch_size, 1), minval=0, maxval=len(data) - block_size
    )
    x = dynamic_slice_vmap(data, ix, (block_size,))
    y = dynamic_slice_vmap(data, ix + 1, (block_size,))
    return x, y

In [318]:
def loss_fn(model, x, y):
    logits = model(x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels=y).mean()


@nnx.jit
def train_step(model, optimizer, key):
    key, subkey = jax.random.split(key)
    batch = get_batch(key, train_data)
    loss, grads = nnx.value_and_grad(loss_fn)(model, *batch)
    optimizer.update(model, grads)
    return loss, key

In [390]:
vocab_size = len(vocab)
embed_dim = 8
qk_dim = 32
hidden_dim = 4 * embed_dim
layer_count = 1
learning_rate = 1e-3

model = L.MicroLM(
    L.MicroLMConfig(vocab_size, embed_dim, qk_dim, hidden_dim, block_size, layer_count),
    rngs=nnx.Rngs(params=0),
)

In [391]:
key = jax.random.PRNGKey(1234)

In [392]:
optimizer = nnx.Optimizer(model, optax.adam(learning_rate), wrt=nnx.Param)

In [393]:
all_train_losses = []

In [394]:
for i in range(10000):
    loss, key = train_step(model, optimizer, key)
    all_train_losses.append(loss)

    if i % 1000 == 0:
        print(f"step: {i}\t train loss: {loss}")

step: 0	 train loss: 2.6081106662750244
step: 1000	 train loss: 1.6376978158950806
step: 2000	 train loss: 1.1147936582565308
step: 3000	 train loss: 0.7417328953742981
step: 4000	 train loss: 0.22384488582611084
step: 5000	 train loss: 0.0853957012295723
step: 6000	 train loss: 0.0856800526380539
step: 7000	 train loss: 0.13447034358978271
step: 8000	 train loss: 0.09790881723165512
step: 9000	 train loss: 0.06301579624414444


In [408]:
@functools.partial(jax.jit, static_argnames=["length"])
def generate_text(model, key, length, initial):
    def scan_gen(carry, _):
        key, context = carry
        logits = model(context)
        key, subkey = jax.random.split(key)
        new_token = jax.random.categorical(subkey, logits[-1], shape=(1,))
        context = jnp.concatenate([context[1:], new_token])
        return (key, context), new_token

    _, new_tokens = lax.scan(scan_gen, (key, initial), (), length=length)
    return new_tokens

In [409]:
def compute_initial(text: str):
    s = encode(text)
    assert len(s) < block_size
    return jnp.concatenate(
        [
            jnp.zeros((block_size - len(s),), dtype=jnp.int32),
            jnp.array(s, dtype=jnp.int32),
        ]
    )


def completion(input: str, length: int):
    result = generate_text(model, key, length, compute_initial(input))[:, 0].tolist()
    return input + decode(result)

In [411]:
print(completion("1+0=", 4))

1+0=7
1+
