In [1]:
import jax
import jax.numpy as jnp
from jax import random
import math
from typing import Callable
from einops import rearrange

In [12]:
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax

from flax import linen as nn

In [18]:
# Root PRNG key, split into a fresh root key plus dedicated keys for the
# sample input and for parameter initialization.
rng = jax.random.PRNGKey(1)
rng, inp_rng, init_rng = jax.random.split(rng, 3)

B = 32   # batch size
T = 64   # second input axis (presumably the sequence length)
C = 512  # model (embedding) dimension
VOCAB_SIZE = 256  # token ids are drawn from [0, VOCAB_SIZE)

# Random integer token ids of shape (B, T).
inp = jax.random.randint(inp_rng, (B, T), 0, VOCAB_SIZE)
inp.shape

(32, 64)

In [19]:
class Embeddings(nn.Module):
    """Embedding lookup whose output is scaled by sqrt(model_dimension).

    Attributes:
        model_dimension: Size of each embedding vector.
        vocab_size: Number of rows in the embedding table.
    """
    model_dimension: int
    vocab_size: int

    def setup(self):
        # The attribute name `embedding` is also the key under which the
        # table is stored in the parameter tree.
        self.embedding = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.model_dimension,
        )

    def __call__(self, x):
        """Look up `x` in the embedding table and rescale the result.

        Args:
            x: Indices into the embedding table (assumed to be an integer
               array with values in [0, vocab_size) -- confirm with callers).

        Returns:
            The looked-up vectors multiplied by sqrt(model_dimension).
        """
        scale = math.sqrt(self.model_dimension)
        looked_up = self.embedding(x)
        return looked_up * scale

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last axis with a learned scale and shift.

    Computes ``(x - mean) / sqrt(var + eps) * gamma + beta`` where the mean
    and variance are taken over the last axis of ``x``.

    Attributes:
        model_dimension: Size of the last axis of the input.
        gamma_init: Initializer for the scale parameter ``gamma``.
        beta_init: Initializer for the shift parameter ``beta``.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
    """
    model_dimension : int
    # Scale starts at one and shift at zero so that a freshly initialized
    # layer returns the plain normalized activations. A random initializer
    # here would apply an arbitrary per-feature scale and offset at init.
    gamma_init : Callable = nn.initializers.ones
    beta_init : Callable = nn.initializers.zeros
    eps : float = 1e-05

    def setup(self):
        # NOTE(review): parameters are stored as (1, 1, model_dimension), so
        # they broadcast cleanly against inputs with three or more axes;
        # lower-rank inputs would gain leading axes in the output.
        param_shape = (1, 1, self.model_dimension)
        self.gamma = self.param('gamma', self.gamma_init, param_shape)
        self.beta = self.param('beta', self.beta_init, param_shape)

    def __call__(self, x):
        """Normalize `x` along its last axis, then apply gamma and beta.

        Args:
            x: Input array whose last axis has size ``model_dimension``.

        Returns:
            An array holding the normalized, scaled and shifted values.
        """
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        norm = (x - mean) / jnp.sqrt(var + self.eps)
        return norm * self.gamma + self.beta

In [20]:
# Build the embedding module; printing it lists the configured attributes.
model = Embeddings(
    vocab_size=256,
    model_dimension=C,
)
print(model)

Embeddings(
    # attributes
    model_dimension = 512
    vocab_size = 256
)


In [21]:
# Initialize the model: create its parameter tree from the dedicated init key
# and the sample batch, then print the resulting nested dict.
params = model.init(init_rng, inp)
print(params)

{'params': {'embedding': {'embedding': Array([[ 0.0180604 ,  0.00516129, -0.03475013, ..., -0.01153203,
        -0.07170662,  0.04004765],
       [ 0.05600861, -0.04859962,  0.0012374 , ...,  0.01420904,
        -0.02650646, -0.04624134],
       [ 0.01814407,  0.02494509,  0.06067709, ..., -0.00804305,
         0.01008524, -0.03755622],
       ...,
       [-0.03716417, -0.03236735,  0.06689288, ..., -0.01243117,
         0.04572471,  0.04765891],
       [-0.05248969,  0.07601349,  0.02349232, ..., -0.01345753,
        -0.08337483,  0.02151756],
       [-0.03665387,  0.04299542, -0.01967484, ..., -0.01788876,
        -0.01530913,  0.00988401]], dtype=float32)}}}


In [24]:
# Run the forward pass on the sample batch with the initialized parameters;
# the trailing expression displays the shape of the result.
res = model.apply(params, inp)
res.shape

(32, 64, 512)