In [None]:
import flax
from flax import linen as nn

from jax import random, Array
from pprint import pprint

class Model(nn.Module):
    """Single dense (affine) layer producing 2 output features per input vector."""

    @nn.compact
    def __call__(self, batch: Array):
        # The submodule is declared inline; the explicit name keeps the
        # parameter-tree key "dense", exactly as a `self.dense = ...`
        # assignment in setup() would.
        return nn.Dense(features=2, name="dense")(batch)

# Derive two independent keys from one root: key1 samples the example
# input, key2 initializes the parameters.
key1, key2 = random.split(random.key(0), 2)
print(key1, key2)

model = Model()

# One unbatched example with 2 features; `init` traces the model on it
# to infer parameter shapes.
x = random.normal(key1, shape=(2,))
params = model.init(key2, x)
pprint(params)

# Forward pass with the freshly initialized parameters (shown as cell output).
model.apply(params, x)

In [None]:
import jax

# Swap every parameter array for its shape so the tree structure is easy
# to read; the resulting tree mirrors `params` and is displayed below.
param_shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)
param_shapes

In [None]:
# JIT-compile the forward pass.
#
# Rule of thumb for jax.jit arguments:
#   - mark an argument static when it changes the *structure* of the
#     computation (e.g. it drives Python control flow or shapes);
#   - keep it dynamic when it is just data flowing through the computation.
# Here both `params` (a pytree of arrays) and `batch` (an array) are data,
# so neither is marked static.


def model_fn(params: dict, batch: Array):
    """Apply the linen `model` defined earlier in the notebook to `batch`."""
    outputs = model.apply(params, batch)
    return outputs


model_fn_jit = jax.jit(model_fn)

model_fn_jit(params, x)


In [None]:
# Compare speed of the eager (non-compiled) vs jit-compiled forward pass.
#
# JAX dispatches work asynchronously: a call returns once the work is queued,
# not when it has finished. Without block_until_ready the loops would mostly
# measure dispatch overhead, so every call is synchronized before the next.
# time.perf_counter is a monotonic, high-resolution clock meant for intervals.
import time
import numpy as np  # NOTE(review): unused in this cell

import jax

# Number of iterations for timing
n_iters = 1000


def _time_calls(fn, *args, n_iters: int) -> float:
    """Return wall-clock seconds spent on `n_iters` synchronized calls of ``fn(*args)``.

    One warm-up call runs first and is excluded from the measurement, so
    one-time costs (tracing, and XLA compilation for jitted functions) are
    not counted for either variant.
    """
    jax.block_until_ready(fn(*args))  # warm-up, not timed
    start = time.perf_counter()
    for _ in range(n_iters):
        # Wait for the result so we time the computation, not just dispatch.
        jax.block_until_ready(fn(*args))
    return time.perf_counter() - start


non_compiled_time = _time_calls(model_fn, params, x, n_iters=n_iters)
compiled_time = _time_calls(model_fn_jit, params, x, n_iters=n_iters)

print(f"Non-compiled time: {non_compiled_time:.4f} seconds")
print(f"JIT-compiled time: {compiled_time:.4f} seconds")
print(f"Speedup: {non_compiled_time/compiled_time:.1f}x")


In [None]:
# Jaxpr (JAX's intermediate representation) traced from the plain Python
# function; this is the program before any XLA optimization.
print(jax.make_jaxpr(model_fn)(params, x))

# Jaxpr of the jitted function. This is NOT an optimized version: jaxprs are
# produced before XLA runs, so the same operations simply appear nested inside
# a single jit call equation (printed as `pjit` or `jit`, depending on the JAX
# version). To see what XLA actually optimized, inspect the compiled HLO via
# `model_fn_jit.lower(params, x).compile().as_text()`.
print(jax.make_jaxpr(model_fn_jit)(params, x))

In [None]:
# Ahead-of-time path, one stage at a time: trace and lower the jitted
# function for these concrete arguments, then have XLA compile it.
lowered = model_fn_jit.lower(params, x)
compiled = lowered.compile()

# Text of the compiled executable's HLO, i.e. after XLA's optimization passes.
print(compiled.as_text())

In [None]:
from flax import nnx

class NnxModel(nnx.Module):
    """Single linear layer (2 -> 2 features) written with the Flax NNX API.

    Unlike the linen `Model`, parameters are created eagerly in ``__init__``
    from the ``rngs`` stream and stored on the module, so no separate
    ``init``/``apply`` step is needed.
    """

    def __init__(self, rngs: nnx.Rngs):
        super().__init__()
        self.dense = nnx.Linear(2, 2, rngs=rngs)

    def __call__(self, x: Array) -> Array:
        """Apply the linear layer to ``x`` (its last axis must have size 2)."""
        return self.dense(x)

    def forward(self, x: Array) -> Array:
        """Alias of ``__call__``, kept so existing ``.forward(x)`` calls still work."""
        return self(x)

# Use a distinct name so the linen `model` (which `model_fn` / `model_fn_jit`
# read as a global) is not clobbered. Otherwise, re-running the earlier JIT or
# timing cells after this one would call `model.apply` on the NNX module
# instead of the linen one.
rngs = nnx.Rngs(0)
nnx_model = NnxModel(rngs)

nnx_model.forward(x)