In [23]:
from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp
from jax import lax
import optax

In [3]:
class Linear(nnx.Module):
    """
    Affine map applied to the last axis of the input.

        y = xA + b

    ``A`` has shape (input_dim, output_dim) and is drawn from the ``params``
    random stream; ``b`` has shape (output_dim,) and starts at zero.
    """

    def __init__(self, input_dim: int, output_dim: int, *, rngs: nnx.Rngs):
        """
        Args:
            input_dim : size of the last axis of the input
            output_dim : size of the last axis of the output
            rngs : random streams; ``rngs.params`` is used to draw ``A``
        """
        weight_shape = (input_dim, output_dim)
        bias_shape = (output_dim,)
        self.A = nnx.Param(rngs.params.uniform(weight_shape))
        self.b = nnx.Param(jnp.zeros(bias_shape))

    def __call__(self, x: jax.Array):
        """Return ``x @ A + b``; any leading axes of ``x`` are carried through."""
        projected = x @ self.A
        return projected + self.b

In [7]:
# Shape check: the trailing axis goes 2 -> 5, the leading axes are untouched.
Linear(input_dim=2, output_dim=5, rngs=nnx.Rngs(params=0))(
    x=jnp.ones((2, 4, 3, 2))
).shape

(2, 4, 3, 5)

In [None]:
class Embedding(nnx.Module):
    """
    Lookup table from integer token ids to embedding vectors.

        token i --> A[i]
    """

    def __init__(self, vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs):
        """
        Args:
            vocab_size : number of rows in the table (one per token id)
            embed_dim : length of each embedding vector
            rngs : random streams; ``rngs.params`` is used to draw the table
        """
        table_shape = (vocab_size, embed_dim)
        self.A = nnx.Param(rngs.params.uniform(table_shape))

    def __call__(self, ids: jax.Array):
        """Gather the rows of ``A`` selected by ``ids`` (taken along axis 0)."""
        table = self.A.value
        return jnp.take(table, ids, axis=0)

In [18]:
# Look up four token ids in a 10 x 5 table; the result has one row per id.
Embedding(vocab_size=10, embed_dim=5, rngs=nnx.Rngs(params=0))(
    jnp.array([0, 2, 3, 9])
)

Array([[0.8423141 , 0.18237865, 0.2271781 , 0.12072563, 0.19181347],
       [0.09871054, 0.55314326, 0.12444711, 0.59456205, 0.9594908 ],
       [0.6932272 , 0.72409594, 0.31816435, 0.82007146, 0.64102626],
       [0.36705065, 0.75454223, 0.36721492, 0.68864214, 0.5837884 ]],      dtype=float32)

In [20]:
class LayerNorm(nnx.Module):
    """
    Layer normalization over the last axis. arXiv:1607.06450

        y = scale * (x - E[x]) / sqrt(Var[x] + epsilon) + bias

    The mean and variance are taken over the last axis of ``x`` only; the
    optional ``scale`` and ``bias`` are vectors of length ``input_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        *,
        use_scale: bool = True,
        use_bias: bool = True,
        epsilon: float = 1e-9,
        rngs: nnx.Rngs,
    ):
        """
        Args:
            input_dim : size of the normalized (last) axis
            use_scale : if True, learn a per-feature multiplicative ``scale``
            use_bias : if True, learn a per-feature additive ``bias``
            epsilon : constant added to the variance before the reciprocal
                square root; the default keeps the previously hard-coded value
            rngs : random streams; ``rngs.params`` is used to draw the parameters
        """
        # NOTE(review): scale and bias are drawn with ``rngs.params.uniform``;
        # the conventional initialization is ones / zeros -- confirm this is
        # intended before training with it.
        if use_scale:
            self.scale = nnx.Param(rngs.params.uniform((input_dim,)))
        else:
            self.scale = None
        if use_bias:
            self.bias = nnx.Param(rngs.params.uniform((input_dim,)))
        else:
            self.bias = None
        self.epsilon = epsilon

    def __call__(self, x: jax.Array):
        """Normalize ``x`` over its last axis, then apply scale and bias if present."""
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        # rsqrt(var + eps) is 1 / sqrt(var + eps); eps keeps it finite when var == 0.
        result = (x - mean) * lax.rsqrt(var + self.epsilon)
        if self.scale is not None:
            result *= self.scale.value
        if self.bias is not None:
            result += self.bias.value
        return result

In [22]:
# Shape check: normalization leaves the input shape unchanged.
LayerNorm(
    input_dim=5,
    use_scale=True,
    use_bias=True,
    rngs=nnx.Rngs(params=0),
)(jnp.ones((10, 9, 7, 5))).shape

(10, 9, 7, 5)

In [None]:
def generic_dot_product_attention(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    *,
    use_scale: bool = True,
    mask: "jax.Array | None" = None,
) -> jax.Array:
    """
    Compute the (scaled) dot-product attention, optionally masked.

        Attention(Q, K, V) = softmax(Q K^t / sqrt(d_k)) V

    Args:
        q: shape (...batch, query_count, qk_dim)
        k: shape (...batch, kv_count, qk_dim)
        v: shape (...batch, kv_count, v_dim)
        use_scale: if True, divide the logits by sqrt(qk_dim)
        mask: optional boolean array broadcastable to
            (...batch, query_count, kv_count). Where it is False the
            corresponding key is excluded from the softmax. ``None`` (the
            default) attends to every key. Requires floating-point logits.

    Return:
        shape (...batch, query_count, v_dim)
    """
    assert q.shape[-1] == k.shape[-1], (
        f"q and k must share their last axis, got {q.shape[-1]} and {k.shape[-1]}"
    )
    # logits[..., i, j] = <q_i, k_j>
    logits = jnp.einsum("...ij,...kj->...ik", q, k)
    if use_scale:
        dk = k.shape[-1]
        logits *= lax.rsqrt(float(dk))
    if mask is not None:
        # Use the most negative finite value rather than -inf so that a row
        # with every key masked yields uniform weights instead of NaN.
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    weights = jnn.softmax(logits, axis=-1)
    return jnp.einsum("...ij,...jk->...ik", weights, v)

In [30]:
# Shape check: 9 queries against 7 key/value pairs give 9 outputs of v_dim 3.
generic_dot_product_attention(
    q=jnp.zeros((10, 11, 9, 5)),
    k=jnp.zeros((10, 11, 7, 5)),
    v=jnp.zeros((10, 11, 7, 3)),
).shape

(10, 11, 9, 3)

In [31]:
class DotProductAttention(nnx.Module):
    """
    Single-head (scaled) dot-product attention with learned input projections.
    arXiv:1706.03762

    Queries and keys are projected to ``qk_dim``; values are projected
    straight to ``output_dim``. No bias terms are used.
    """

    def __init__(
        self,
        input_q_dim: int,
        input_k_dim: int,
        input_v_dim: int,
        output_dim: int,
        qk_dim: int,
        *,
        use_scale: bool = True,
        rngs: nnx.Rngs,
    ):
        """
        Args:
            input_q_dim : last-axis size of the incoming queries
            input_k_dim : last-axis size of the incoming keys
            input_v_dim : last-axis size of the incoming values
            output_dim : last-axis size of the projected values (and the result)
            qk_dim : shared last-axis size of the projected queries and keys
            use_scale : forwarded to ``generic_dot_product_attention``
            rngs : random streams; ``rngs.params`` is used to draw the projections
        """
        self.use_scale = use_scale
        # Drawn in the order q, k, v from the same ``params`` stream.
        q_shape = (input_q_dim, qk_dim)
        self.W_q = nnx.Param(rngs.params.uniform(q_shape))
        k_shape = (input_k_dim, qk_dim)
        self.W_k = nnx.Param(rngs.params.uniform(k_shape))
        v_shape = (input_v_dim, output_dim)
        self.W_v = nnx.Param(rngs.params.uniform(v_shape))

    def __call__(self, q: jax.Array, k: jax.Array, v: jax.Array):
        """Project ``q``, ``k`` and ``v``, then attend over the projected values."""
        projected_q = q @ self.W_q
        projected_k = k @ self.W_k
        projected_v = v @ self.W_v
        return generic_dot_product_attention(
            projected_q, projected_k, projected_v, use_scale=self.use_scale
        )

In [33]:
# Shape check: inputs with last-axis sizes 3 / 4 / 5 produce output_dim = 7.
DotProductAttention(
    input_q_dim=3,
    input_k_dim=4,
    input_v_dim=5,
    output_dim=7,
    qk_dim=10,
    rngs=nnx.Rngs(params=0),
)(
    q=jnp.zeros((11, 13, 3)),
    k=jnp.zeros((11, 13, 4)),
    v=jnp.zeros((11, 13, 5)),
).shape

(11, 13, 7)

In [None]:
class FeedForward(nnx.Module):
    """
    A fully connected feed-forward network of depth 2, for use in transformers.

        y = ReLU(x W_1 + b_1) W_2 + b_2
    """

    def __init__(
        self, input_dim: int, output_dim: int, hidden_dim: int, *, rngs: nnx.Rngs
    ):
        """
        Args:
            input_dim : last-axis size of the input
            output_dim : last-axis size of the output
            hidden_dim : width of the intermediate ReLU layer
            rngs : random streams; ``rngs.params`` is used to draw all four parameters
        """
        draw = rngs.params.uniform
        # Drawn in the order W_1, b_1, W_2, b_2 from the same stream.
        self.W_1 = nnx.Param(draw((input_dim, hidden_dim)))
        self.b_1 = nnx.Param(draw((hidden_dim,)))
        self.W_2 = nnx.Param(draw((hidden_dim, output_dim)))
        self.b_2 = nnx.Param(draw((output_dim,)))

    def __call__(self, x: jax.Array):
        """Apply the two affine layers with a ReLU in between."""
        hidden = jnn.relu(x @ self.W_1 + self.b_1)
        return hidden @ self.W_2 + self.b_2

In [41]:
# Shape check: the trailing axis goes 3 -> 4 through a hidden width of 5.
FeedForward(input_dim=3, output_dim=4, hidden_dim=5, rngs=nnx.Rngs(params=0))(
    jnp.zeros((10, 13, 3))
).shape

(10, 13, 4)

In [None]:
class MicroLM(nnx.Module):
    """
    A one-block transformer-style model over token ids.

    Forward pass: embed -> norm -> attention (residual) -> norm ->
    feed-forward (residual) -> norm -> linear head of size ``vocab_size``.

    The same ``LayerNorm`` instance is reused at all three normalization
    points. The forward pass adds no positional information and passes no
    mask to the attention.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        qk_dim: int,
        hidden_dim: int,
        *,
        rngs: nnx.Rngs,
    ):
        """
        Args:
            vocab_size : number of distinct token ids; also the output width
            embed_dim : width of the token embeddings and the residual stream
            qk_dim : width of the projected queries and keys in the attention
            hidden_dim : hidden width of the feed-forward network
            rngs : random streams shared by every submodule
        """
        # Submodules are built in a fixed order because they all draw from
        # the same ``rngs`` streams.
        self.token_embed = Embedding(vocab_size, embed_dim, rngs=rngs)
        self.embed_normalization = LayerNorm(embed_dim, rngs=rngs)
        self.attention = DotProductAttention(
            input_q_dim=embed_dim,
            input_k_dim=embed_dim,
            input_v_dim=embed_dim,
            output_dim=embed_dim,
            qk_dim=qk_dim,
            rngs=rngs,
        )
        self.feed_forward = FeedForward(
            input_dim=embed_dim,
            output_dim=embed_dim,
            hidden_dim=hidden_dim,
            rngs=rngs,
        )
        self.lm_head = Linear(embed_dim, vocab_size, rngs=rngs)

    def __call__(self, x: jax.Array):
        """
        Map token ids to per-token vectors of length ``vocab_size``.

        Args:
            x : integer token ids (presumably shaped (..., seq_len) -- confirm)

        Return:
            array with the shape of ``x`` plus a trailing axis of ``vocab_size``
        """
        # Token id --> embedding
        hidden = self.token_embed(x)

        # Self-attention on the normalized embeddings, added back residually.
        normed = self.embed_normalization(hidden)
        hidden = hidden + self.attention(normed, normed, normed)

        # Feed-forward on the re-normalized stream, added back residually.
        normed = self.embed_normalization(hidden)
        hidden = hidden + self.feed_forward(normed)

        # Final normalization before projecting onto the vocabulary.
        return self.lm_head(self.embed_normalization(hidden))

In [None]:
# Byte-level vocabulary (256 ids) with every internal width set to 10.
model = MicroLM(
    vocab_size=256,
    embed_dim=10,
    qk_dim=10,
    hidden_dim=10,
    rngs=nnx.Rngs(params=0),
)