In [None]:
from flax import nnx
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Array, PRNGKeyArray, Float


class MLP(nnx.Module):
    """Residual two-layer MLP whose normalized input is modulated per sample.

    The input is normalized with a parameter-free LayerNorm, then scaled and
    shifted by the caller-supplied ``gamma`` / ``beta``. The result is passed
    through ``Linear -> SiLU -> Linear``, gated by ``alpha``, and added back to
    the input. NOTE(review): this resembles adaLN-style conditioning; confirm
    the intended use with the caller.
    """

    def __init__(self, dim: int, mlp_ratio: int, *, rngs: nnx.Rngs):
        """Create the norm and the two projection layers.

        Args:
            dim: Embedding size of both the input and the output.
            mlp_ratio: Hidden-width multiplier; the hidden size is
                ``dim * mlp_ratio``.
            rngs: RNG streams forwarded to the LayerNorm and Linear
                constructors for initialization.
        """
        # No learnable scale/bias here: the affine modulation is supplied at
        # call time through gamma and beta.
        self.layer_norm = nnx.LayerNorm(dim, use_bias=False, use_scale=False, rngs=rngs)

        dhid = dim * mlp_ratio
        self.linear1 = nnx.Linear(dim, dhid, rngs=rngs)
        self.linear2 = nnx.Linear(dhid, dim, rngs=rngs)

        # Plain Python int (not an nnx.Param), so it is excluded from
        # nnx.state(model, nnx.Param). It is not read anywhere in this class.
        self.random_const = 5

    def __call__(
        self,
        x: Float[Array, "b num_patches embed_dim"],
        gamma: Float[Array, "b embed_dim"],
        beta: Float[Array, "b embed_dim"],
        alpha: Float[Array, "b embed_dim"],
    ) -> Float[Array, "b num_patches embed_dim"]:
        """Apply the modulated residual MLP.

        Args:
            x: Input tokens.
            gamma: Per-sample scale offset; the normalized input is multiplied
                by ``1 + gamma``.
            beta: Per-sample shift added after scaling.
            alpha: Per-sample gate applied to the MLP output before the
                residual addition.

        Returns:
            An array with the same shape as ``x``.
        """
        residual = x
        x = self.layer_norm(x)
        # Insert a patch axis so the per-sample (b, embed_dim) vectors
        # broadcast over every patch.
        x = x * (1 + gamma[:, None, :]) + beta[:, None, :]
        x = self.linear1(x)
        x = nnx.silu(x)
        x = self.linear2(x)
        return alpha[:, None, :] * x + residual

In [None]:
def create_model():
    """Return an ``MLP`` with ``dim=128`` and ``mlp_ratio=4``, initialized from seed 0."""
    seeded_rngs = nnx.Rngs(0)
    built = MLP(dim=128, mlp_ratio=4, rngs=seeded_rngs)
    return built

# Build the model abstractly: parameters become shape/dtype placeholders
# instead of concrete arrays. Use nnx.eval_shape rather than jax.eval_shape,
# because create_model returns an nnx.Module. jax.eval_shape can only return
# it if the Module is registered as a JAX pytree, which is not true in every
# Flax version.
model = nnx.eval_shape(create_model)
state = nnx.state(model, nnx.Param)

# Boolean mask with the same structure as the parameter state: True for every
# parameter except linear1's kernel and bias. Presumably meant to exclude
# linear1 from an optimizer update; confirm against the consumer.
mask = jax.tree_util.tree_map(lambda _: True, state)
mask['linear1']['bias'].value = False
mask['linear1']['kernel'].value = False
print(mask)