In [2]:
import flax
from flax import linen as nn

from jax import random, Array
from pprint import pprint

class Model(nn.Module):
    """Minimal Flax linen module: a single dense layer with two output features."""

    def setup(self):
        """Create the submodules; the layer is registered under the attribute name ``dense``."""
        dense_layer = nn.Dense(features=2)  # 2 output features
        self.dense = dense_layer

    def __call__(self, batch: Array):
        """Apply the dense layer to ``batch`` and return its output."""
        projected = self.dense(batch)
        return projected

# Instantiate the module; its parameters are created by `init` below
model = Model()

# Split a single seed key in two: key1 samples the input, key2 initializes the parameters
key1, key2 = random.split(random.key(0), 2)
print(key1, key2)

# Sample a test input and use it to infer and initialize the model parameters
x = random.normal(key1, shape=(2,))
params = model.init(key2, x)
pprint(params)

# Forward pass with the freshly initialized parameters
model.apply(params, x)

Array((), dtype=key<fry>) overlaying:
[1797259609 2579123966] Array((), dtype=key<fry>) overlaying:
[ 928981903 3453687069]
{'params': {'dense': {'bias': Array([0., 0.], dtype=float32),
                      'kernel': Array([[-0.51274115, -0.44576186],
       [ 0.7367678 , -0.98018515]], dtype=float32)}}}


Array([-1.1825595,  0.440827 ], dtype=float32)

In [3]:
import jax

# Replace every leaf of the parameter tree with its shape, keeping the nesting intact
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

{'params': {'dense': {'bias': (2,), 'kernel': (2, 2)}}}

In [4]:
# try compiling model using jit

# Make an argument static if it affects how the computation is structured
# Keep it dynamic if it's just data flowing through the computation


def model_fn(params: dict, batch: Array):
    """Run the forward pass of the module-level ``model`` on ``batch``.

    Args:
        params: Variable collection passed as the first argument to ``model.apply``.
        batch: Input array forwarded to the model.

    Returns:
        Whatever ``model.apply(params, batch)`` returns.
    """
    output = model.apply(params, batch)
    return output

# Wrap the forward function with jax.jit; both arguments stay dynamic (no static args are declared)
model_fn_jit = jax.jit(model_fn)

# Call the jitted function once on the same inputs as the earlier model.apply call
model_fn_jit(params, x)


Array([-1.1825595,  0.440827 ], dtype=float32)

In [5]:
# Compare speed of non-compiled vs jit-compiled forward pass
import time
import numpy as np

# Number of iterations for timing
n_iters = 1000

# JAX dispatches computations asynchronously: a call returns as soon as the
# work is enqueued, not when the result is ready. Calling block_until_ready()
# on every result makes the timer measure execution, not just dispatch.
# perf_counter() is used because it is monotonic and has higher resolution
# than time.time().

# Time non-compiled version
start = time.perf_counter()
for _ in range(n_iters):
    out = model_fn(params, x).block_until_ready()
end = time.perf_counter()
non_compiled_time = end - start

# Time jit-compiled version
# First call to compile; wait for it so compilation is excluded from the timed loop
_ = model_fn_jit(params, x).block_until_ready()

start = time.perf_counter()
for _ in range(n_iters):
    out = model_fn_jit(params, x).block_until_ready()
end = time.perf_counter()
compiled_time = end - start

print(f"Non-compiled time: {non_compiled_time:.4f} seconds")
print(f"JIT-compiled time: {compiled_time:.4f} seconds")
print(f"Speedup: {non_compiled_time/compiled_time:.1f}x")


Non-compiled time: 0.8328 seconds
JIT-compiled time: 0.0038 seconds
Speedup: 219.6x


In [6]:
# Print the jaxpr (JAX's traced intermediate representation) of the plain Python function
print(jax.make_jaxpr(model_fn)(params, x))

# Trace the jitted function: the same ops appear wrapped in a single pjit call.
# This is still a jaxpr, not XLA-optimized code; the compiled module is printed in a later cell.
print(jax.make_jaxpr(model_fn_jit)(params, x))

{ lambda ; a:f32[2] b:f32[2,2] c:f32[2]. let
    d:f32[2] = dot_general[dimension_numbers=(([0], [0]), ([], []))] c b
    e:f32[2] = add d a
  in (e,) }
{ lambda ; a:f32[2] b:f32[2,2] c:f32[2]. let
    d:f32[2] = pjit[
      name=model_fn
      jaxpr={ lambda ; e:f32[2] f:f32[2,2] g:f32[2]. let
          h:f32[2] = dot_general[dimension_numbers=(([0], [0]), ([], []))] g f
          i:f32[2] = add h e
        in (i,) }
    ] a b c
  in (d,) }


In [7]:
# Lower the jitted function for these concrete inputs and compile it with XLA,
# then get the compiled module's HLO (High Level Operations, XLA's IR) as text
compiled = model_fn_jit.lower(params, x).compile()
print(compiled.as_text())  # Print the HLO text of the compiled module

HloModule jit_model_fn, is_scheduled=true, entry_computation_layout={(f32[2]{0}, f32[2,2]{1,0}, f32[2]{0})->f32[2]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}

%fused_computation (param_0.1: f32[2], param_1.1: f32[2], param_2: f32[2,2]) -> f32[2] {
  %param_1.1 = f32[2]{0} parameter(1)
  %param_2 = f32[2,2]{1,0} parameter(2)
  %dot.0 = f32[2]{0} dot(f32[2]{0} %param_1.1, f32[2,2]{1,0} %param_2), lhs_contracting_dims={0}, rhs_contracting_dims={0}, metadata={op_name="jit(model_fn)/jit(main)/Model/dense/dot_general" source_file="/Users/gardberg/dev/cascade/sandbox/hdemucs/.venv/lib/python3.11/site-packages/flax/linen/linear.py" source_line=271}
  %param_0.1 = f32[2]{0} parameter(0)
  ROOT %add.0 = f32[2]{0} add(f32[2]{0} %dot.0, f32[2]{0} %param_0.1), metadata={op_name="jit(model_fn)/jit(main)/Model/dense/add" source_file="/Users/gardberg/dev/cascade/sandbox/hdemucs/.venv/lib/python3.11/site-packages/flax/linen/line

In [10]:
from flax import nnx

class NnxModel(nnx.Module):
    """Flax NNX counterpart of the linen model: a single 2 -> 2 linear layer.

    Args:
        rngs: Random number generator container used to initialize the layer's
            parameters.
    """

    def __init__(self, rngs: nnx.Rngs):
        super().__init__()
        self.dense = nnx.Linear(2, 2, rngs=rngs)

    def __call__(self, x: Array) -> Array:
        """Apply the linear layer to ``x`` so the module can be called directly."""
        return self.dense(x)

    def forward(self, x: Array) -> Array:
        """Alias of ``__call__``, kept so existing ``model.forward(x)`` calls still work."""
        return self(x)

# Seed the NNX RNG container and build the model; parameters are created in the constructor
rngs = nnx.Rngs(0)
# NOTE(review): this rebinds the global `model` that `model_fn` in an earlier cell
# looks up at call time. Re-running those cells after this one would call `.apply`
# on an NnxModel; consider a distinct name (e.g. `nnx_model`) if nothing later relies on `model`.
model = NnxModel(rngs)

# Forward pass on the same input `x` sampled in the first cell
model.forward(x)

Array([1.2753228 , 0.10528118], dtype=float32)