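JAX transformations such as `jax.jit` and `jax.grad` operate on pytrees: nested containers whose leaves can be flattened and mapped over with the `jax.tree_util` functions. Because JAX assumes pure functions, these containers should also be immutable. Flax's `struct.dataclass` decorator makes it easy to define custom classes that are frozen and registered as pytrees, so they can be passed safely to JAX.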

See https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html
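A pytree registration pairs a flatten function, which splits an object into dynamic children (the leaves JAX traces and maps over) and static auxiliary data, with an unflatten function that rebuilds it. The sketch below registers a plain frozen dataclass by hand with `jax.tree_util.register_pytree_node`. It only approximates what `struct.dataclass` automates; the names `ManualModel`, `_flatten` and `_unflatten` are hypothetical.

```python
import dataclasses
from typing import Any, Callable

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ManualModel:
    param: Any           # dynamic child: a pytree leaf
    apply_fn: Callable   # static auxiliary data: not a leaf


def _flatten(m):
    # Return (children, aux_data); only the children are traced/mapped.
    return (m.param,), m.apply_fn


def _unflatten(apply_fn, children):
    # Rebuild the object from the static aux data and the (possibly new) children.
    (param,) = children
    return ManualModel(param, apply_fn)


jax.tree_util.register_pytree_node(ManualModel, _flatten, _unflatten)

m = ManualModel(jnp.array([1.0, 2.0]), jnp.sin)
# tree_map transforms the leaf and passes apply_fn through unchanged.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, m)
```

Keeping the function in the auxiliary data is what lets `jax.jit` treat it as part of the compilation cache key instead of trying to trace it.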

In [33]:
import jax
import jax.numpy as jnp
from typing import Any, Callable
from flax import struct

@struct.dataclass
class Model:
    param: Any   # pytree node: its leaves are traced and mapped over
    param2: Any
    # pytree_node=False marks this field as static metadata, not a leaf
    apply_fn: Callable = struct.field(pytree_node=False)
    

In [34]:
m = Model({'a': jnp.array(1), 
           'b': jnp.array(2), 
           'c': jnp.array([3, 4, 5])},
          jnp.array([1, 2, 3]),
          lambda x: x + 1)

The dataclass is frozen, so its fields cannot be reassigned in place.

In [35]:
# Uncommenting the next line raises dataclasses.FrozenInstanceError
# m.param = 1

In [36]:
print(m.param2)
m = m.replace(param2=jnp.array([4, 5, 6]))  # returns a new, updated copy
print(m.param2)

[1 2 3]
[4 5 6]


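It is also registered as a pytree, so the `jax.tree_util` functions work on it. Because `apply_fn` was declared with `pytree_node=False`, it is excluded from the leaves and passed through `tree_map` unchanged.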

In [37]:
print(jax.tree_util.tree_leaves(m))
print(jax.tree_util.tree_map(lambda x: x.shape, m))

[Array(1, dtype=int32, weak_type=True), Array(2, dtype=int32, weak_type=True), Array([3, 4, 5], dtype=int32), Array([4, 5, 6], dtype=int32)]
Model(param={'a': (), 'b': (), 'c': (3,)}, param2=(3,), apply_fn=<function <lambda> at 0x12f9627a0>)
