In [5]:
from typing import Any, Callable

import flax.linen as nn
import jax
import jax.numpy as jp
from flax.struct import PyTreeNode, field


class Test:
    """Closure-based ``jax.jit`` wrapper around a ``Dense(10)`` model.

    NOTE: this class intentionally illustrates a ``jax.jit`` pitfall. The
    jitted ``apply`` closes over ``self``; ``self.n`` and ``self.params`` are
    only read while jit traces the function (on the first call for a given
    input shape/dtype) and are baked into the compiled program as constants.
    Mutating ``self.n`` afterwards therefore has no effect on later calls with
    the same input shape -- the cached computation keeps the old value.
    Passing the state to jit as an argument (e.g. a pytree dataclass) avoids
    this.
    """

    def __init__(self, n):
        """Build the model, initialize its parameters and jit the forward pass.

        Args:
            n: threshold; a call counts model outputs that are ``>= n``.
        """
        self.n = n
        self.model = nn.Dense(10)
        # Parameters come from a fixed key and a dummy (8, 1) input batch.
        self.params = self.model.init(jax.random.PRNGKey(0), jp.ones((8, 1)))
        # Special methods such as ``__call__`` are looked up on the type, not
        # on the instance, so assigning ``self.__call__`` would not make
        # ``instance(x)`` work. Keep the jitted function under a private name
        # and delegate to it from the class-level ``__call__``.
        self._jitted_apply = jax.jit(self.make_fn())

    def __call__(self, x):
        """Run the jitted, thresholded forward pass on ``x``.

        Returns:
            Scalar array: number of model outputs ``>= n`` as seen at trace time.
        """
        return self._jitted_apply(x)

    def make_fn(self):
        """Build the (un-jitted) forward function that closes over ``self``.

        Returns:
            ``apply(x)``, which counts entries of ``model.apply(params, x)``
            that are ``>= self.n``.
        """
        def apply(x):
            out = self.model.apply(self.params, x)
            return jp.sum(out >= self.n)
        return apply

In [6]:
# Threshold n = 0.5: __init__ builds a Dense(10) model and wraps its forward pass in jax.jit.
test = Test(n=0.5)

In [7]:
# Inspect the initialized Dense variables: a (1, 10) kernel and an all-zero bias.
test.params

{'params': {'kernel': Array([[ 0.7837025 ,  0.02946448, -0.55363274,  0.82858545, -0.89440423,
          -1.7786882 ,  0.51747984,  1.5032254 , -0.06438798,  0.699473  ]],      dtype=float32),
  'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}}

In [10]:
# Fixed evaluation batch: 8 samples with a single feature each.
# NOTE(review): PRNGKey(0) is also the key used to initialize the model params;
# consider jax.random.split so data and weights come from independent keys.
rng = jax.random.PRNGKey(0)
x = jax.random.normal(key=rng, shape=(8, 1))
# First call traces `apply` with the current threshold (n = 0.5) and compiles it.
test.__call__(x)

Array(22, dtype=int32)

In [11]:
# Stale-closure pitfall: the jitted function was traced with n = 0.5 and the
# value was baked in as a constant. Same-shaped `x` hits the jit cache, so the
# result is unchanged (22) even though test.n is now 0.1.
test.n = 0.1
test.__call__(x)

Array(22, dtype=int32)

In [34]:
from typing import Callable
from flax.struct import PyTreeNode, field

class JitTest(PyTreeNode):
    """Thresholded model forward pass whose state is passed *through* ``jax.jit``.

    The instance is a pytree, so the jitted ``apply`` receives ``self`` as an
    argument instead of capturing it in a closure:

    * ``params`` and ``n`` are pytree leaves (traced values). An instance
      produced by ``replace(n=...)`` reuses the compiled function and sees the
      new threshold, with no recompilation.
    * ``apply_fn`` is static (part of the pytree structure). Swapping it
      triggers a retrace, and it must be hashable.
    """

    # Threshold compared against every model output. Kept as a dynamic leaf:
    # it does not affect shapes or control flow, so changing it should not
    # force jit to recompile.
    n: float
    # Forward function called as ``apply_fn(params, x)``, e.g. a Flax module's
    # ``apply``. Static because jit cannot trace a Python callable.
    apply_fn: Callable = field(pytree_node=False)
    # Model variables as returned by ``Module.init`` (a nested mapping of
    # arrays such as {'params': {...}}), not a PyTreeNode.
    params: Any

    @jax.jit
    def apply(self, x):
        """Count how many entries of ``apply_fn(params, x)`` are ``>= n``.

        Returns:
            Scalar integer array with the count.
        """
        out = self.apply_fn(self.params, x)
        return jp.sum(out >= self.n)

In [35]:
# Same architecture, seed and dummy batch as `Test`, but the forward function and
# its variables are handed to a pytree dataclass instead of being closed over.
# NOTE(review): this rebinds `test` to a different type; a distinct name such
# as `jit_test` would avoid confusion with the earlier Test instance.
model = nn.Dense(features=10)
params = model.init(jax.random.PRNGKey(0), jp.ones((8, 1)))
test = JitTest(params=params, apply_fn=model.apply, n=0.5)

In [36]:
# Same key and shape as the earlier batch, so `x` is identical and the results
# are directly comparable with the closure-based `Test`.
x = jax.random.normal(key=rng, shape=(8, 1))
test.apply(x)

Array(22, dtype=int32)

In [39]:
# `replace` returns a new JitTest with n = 0.1. Because the instance itself is
# an argument to the jitted `apply`, the new threshold is honored (28 vs. 22),
# unlike the mutated closure in the `Test` experiment.
test = test.replace(n=0.1)
test.apply(x)

Array(28, dtype=int32)