In [None]:
import importlib
import random

In [None]:
from flax import nnx
import jax
import jax.numpy as jnp
from jax import lax
import optax

In [None]:
import layers as L
import utils as U

# Reload the local modules so that edits made to them on disk are picked up
# when this cell is re-run, without restarting the kernel.
importlib.reload(L)
importlib.reload(U)

In [None]:
def foo(i):
    """Map an integer to a single character with code point in 32..126.

    The value ``i + 1`` is reduced modulo 95 (the number of code points from
    32 to 126 inclusive) and then offset by 32, so consecutive integers give
    consecutive characters that wrap around from 126 back to 32.
    """
    first_code = 32
    code_count = 126 - 32 + 1
    return chr(first_code + (i + 1) % code_count)

In [None]:
# Synthetic corpus: 100000 lines, each made of four characters obtained by
# applying foo() to four consecutive integers starting at a random value in
# 32..126, followed by a newline.
# The pieces are collected and joined once instead of growing the string with
# repeated `+=` inside the loop.
starts = [random.randint(32, 126) for _ in range(100000)]
data = "".join(
    "".join(foo(start + offset) for offset in range(4)) + "\n" for start in starts
)

In [None]:
# Load the training corpus from disk. This rebinds `data`, replacing whatever
# value it held before this cell ran.
# The encoding is given explicitly so the text is decoded the same way on every
# platform instead of depending on the locale's default encoding.
with open("../assets/new-dream.txt", "r", encoding="utf-8") as f:
    data = f.read()

In [None]:
# Character-level vocabulary: the "😱" symbol first, then every distinct
# character of `data` in sorted order.
vocab = ["😱", *sorted(set(data))]
print("vocab: " + "".join(vocab[:50]) + " ...")
print(f"vocab len: {len(vocab)}")

# Lookup tables between characters and integer token ids.
stoi = {symbol: index for index, symbol in enumerate(vocab)}
iots = dict(enumerate(vocab))


def encode(s):
    """Turn a string into a list of token ids (KeyError on unknown characters)."""
    return [stoi[c] for c in s]


def decode(l):
    """Turn an iterable of token ids back into a string."""
    return "".join([iots[i] for i in l])


# The whole corpus as one array of token ids.
train_data = jnp.array(encode(data))

In [None]:
# number of windows sampled per training batch
batch_size = 32
# context window length
block_size = 16

In [None]:
# lax.dynamic_slice mapped over a batch of start indices; the operand and the
# slice sizes are shared by every element of the batch.
dynamic_slice_vmap = jax.vmap(lax.dynamic_slice, in_axes=(None, 0, None))


@jax.jit
def get_batch(random_key, data):
    """Sample a random batch of input windows and their shifted targets.

    Args:
        random_key: JAX PRNG key used to draw the window start positions.
        data: 1-D array of token ids to sample from.

    Returns:
        A pair ``(x, y)``: ``x`` holds ``batch_size`` windows of
        ``block_size`` tokens and ``y`` holds the same windows shifted one
        position to the right.
    """
    window = (block_size,)
    # maxval is exclusive, so the shifted window still ends inside `data`.
    starts = jax.random.randint(
        random_key,
        shape=(batch_size, 1),
        minval=0,
        maxval=len(data) - block_size,
    )
    inputs = dynamic_slice_vmap(data, starts, window)
    targets = dynamic_slice_vmap(data, starts + 1, window)
    return inputs, targets

In [None]:
# All-True boolean array with one entry per token of a training batch.
# NOTE(review): presumably a visibility/attention mask consumed by the model;
# confirm against L.MicroLM.
FULL_VISIBLE = jnp.ones((batch_size, block_size), dtype=jnp.bool)


def loss_fn(model, x, y):
    """Mean cross-entropy between the model's logits for ``x`` and labels ``y``.

    Args:
        model: callable invoked as ``model(x, FULL_VISIBLE)`` to obtain logits.
        x: batch of input token ids.
        y: batch of integer target token ids.

    Returns:
        The cross-entropy averaged over all positions in the batch.
    """
    per_token_loss = optax.softmax_cross_entropy_with_integer_labels(
        model(x, FULL_VISIBLE), labels=y
    )
    return per_token_loss.mean()


@nnx.jit
def train_step(model, optimizer, key):
    """Run one optimisation step on a freshly sampled batch.

    Args:
        model: the model being trained; updated in place by the optimizer.
        optimizer: optimizer whose ``update(model, grads)`` applies the step.
        key: JAX PRNG key carried from step to step.

    Returns:
        A pair ``(loss, key)``: the loss on the sampled batch and the new
        carry key to pass to the next call.
    """
    # Split once per step: `subkey` is consumed for sampling, `key` is only
    # carried forward, so no key is ever used for two random operations.
    key, subkey = jax.random.split(key)
    batch = get_batch(subkey, train_data)
    loss, grads = nnx.value_and_grad(loss_fn)(model, *batch)
    optimizer.update(model, grads)
    return loss, key

In [None]:
# Model and optimisation hyperparameters.
vocab_size = len(vocab)
embed_dim = 16
# query/key dimension is tied to the embedding dimension
qk_dim = embed_dim
# hidden dimension is twice the embedding dimension
hidden_dim = 2 * embed_dim
layer_count = 2
learning_rate = 1e-3

# The config is built positionally, so the argument order below must match the
# field order of L.MicroLMConfig (defined in the local `layers` module).
# NOTE(review): ALL_YOU_NEED presumably selects the sinusoidal position
# encoding from "Attention Is All You Need" - confirm in layers.
model = L.MicroLM(
    L.MicroLMConfig(
        vocab_size,
        embed_dim,
        qk_dim,
        hidden_dim,
        block_size,
        layer_count,
        position_encoding=L.PositionEncodingStrategy.ALL_YOU_NEED,
    ),
    # fixed seed for the parameter RNG stream
    rngs=nnx.Rngs(params=0),
)

In [None]:
# PRNG key with a fixed seed; it is threaded through train_step during training
# and read by completion() when generating text.
key = jax.random.PRNGKey(1234)

In [None]:
# Adam with the configured learning rate, applied to the model's nnx.Param state.
optimizer = nnx.Optimizer(model, optax.adam(learning_rate), wrt=nnx.Param)

In [None]:
# Train for 5000 steps, threading the PRNG key through every step and
# reporting the loss on every 1000th step (starting with step 0).
for i in range(5000):
    loss, key = train_step(model, optimizer, key)

    if i % 1000:
        continue
    print(f"step: {i}\t train loss: {loss}")

In [None]:
def compute_initial(text: str):
    """Encode ``text`` and right-pad it to exactly ``block_size`` token ids.

    The encoded text must be strictly shorter than ``block_size`` (checked
    with an assertion); the remaining positions are filled with the id of the
    "😱" symbol.

    Returns:
        A 1-D int32 array of length ``block_size``.
    """
    token_ids = encode(text)
    assert len(token_ids) < block_size
    prompt = jnp.array(token_ids, dtype=jnp.int32)
    filler = jnp.full((block_size - len(token_ids),), stoi["😱"], dtype=jnp.int32)
    return jnp.concatenate([prompt, filler])


def completion(input: str, length: int, temp: float):
    """Generate text from the model and return it appended to ``input``.

    Args:
        input: prompt text; must satisfy the length limit of compute_initial.
        length: passed through to ``U.generate_text2``.
        temp: passed through to ``U.generate_text2``.

    Returns:
        ``input`` followed by the decoded output of ``U.generate_text2``.
    """
    # Uses the notebook-level `model` and `key` as they are at call time.
    prompt_ids = compute_initial(input)
    generated = U.generate_text2(model, key, temp, length, prompt_ids)
    return input + decode(generated)

In [None]:
# Demo: continue a 15-character prompt (length=100, temp=2) and print the
# result with its newlines rendered.
print(completion("abcd\n1234\n5678\n", 100, 2))