In [1]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu

import equinox as eqx

In [2]:
# One fixed seed so the model initialisation and the sample input are reproducible.
key = jrandom.PRNGKey(seed=10)

class Model(eqx.Module):
    """An Equinox MLP whose output is passed through one more nonlinearity.

    Used here as a toy model for checking that serialise/deserialise round-trips
    restore the weights.
    """

    # Annotated attributes are the dataclass fields of the module.
    mlp: eqx.nn.MLP
    label: str
    act: eqx.nn.Lambda

    # Unannotated, so these are plain class attributes rather than dataclass
    # fields. Every instance sees the *same* (mutable) dict and list objects.
    d = {'a': 3}
    l = [2, 3]

    def __init__(self, in_d, out_d, w, d, act, key):
        # Here `d` is the MLP depth; it is unrelated to the class attribute `Model.d`.
        # `act` is used both between hidden layers and as the MLP's final activation.
        self.mlp = eqx.nn.MLP(
            in_size=in_d,
            out_size=out_d,
            width_size=w,
            depth=d,
            activation=act,
            final_activation=act,
            key=key,
        )
        self.label = 'some_important_label'
        # NOTE(review): the extra output nonlinearity is always tanh, independent
        # of the `act` argument -- confirm this is intended.
        self.act = eqx.nn.Lambda(jnp.tanh)

    def __call__(self, x):
        hidden = self.mlp(x)
        return self.act(hidden)

f1 = Model(in_d=10, out_d=10, w=10, d=10, act=jnp.tanh, key=key)

In [3]:
# Sample input of 10 standard-normal values.
# NOTE(review): this reuses `key`, which also initialised f1's weights; harmless
# for a serialisation sanity check, but split the key if independence matters.
x = jrandom.normal(key, (10,))

In [4]:
# Reference prediction from the original model; the loaded model should reproduce it.
f1(x)

Array([-0.10066809,  0.01973604, -0.1169754 ,  0.00370935,  0.00341123,
        0.23280813,  0.14593743,  0.32178235,  0.09192786,  0.12030009],      dtype=float32)

In [7]:
# Save f1's leaves to "10.eqx". Loading them back requires a template model of the
# same structure (built in the next cell).
eqx.tree_serialise_leaves("10.eqx", f1)

In [8]:
# Build the deserialisation template from a *different* key. If it were initialised
# exactly like f1, every later comparison would pass even if loading did nothing.
f2_template = Model(10, 10, 10, 10, jnp.tanh, key=jrandom.split(key)[1])
assert not jnp.array_equal(f2_template.mlp.layers[0].weight, f1.mlp.layers[0].weight)

# tree_deserialise_leaves returns a *new* model; the template is not modified in
# place, so the result must be bound to a name for later cells to use it.
f2 = eqx.tree_deserialise_leaves("10.eqx", f2_template)
assert jnp.array_equal(f2.mlp.layers[0].weight, f1.mlp.layers[0].weight)
f2

Model(
  mlp=MLP(
    layers=[
      Linear(
        weight=f32[10,10],
        bias=f32[10],
        in_features=10,
        out_features=10,
        use_bias=True
      ),
      Linear(
        weight=f32[10,10],
        bias=f32[10],
        in_features=10,
        out_features=10,
        use_bias=True
      ),
      Linear(
        weight=f32[10,10],
        bias=f32[10],
        in_features=10,
        out_features=10,
        use_bias=True
      ),
      Linear(
        weight=f32[10,10],
        bias=f32[10],
        in_features=10,
        out_features=10,
        use_bias=True
      ),
      Linear(
        weight=f32[10,10],
        bias=f32[10],
        in_features=10,
        out_features=10,
        use_bias=True
      ),
      Linear(
        weight=f32[10,10],
        bias=f32[10],
        in_features=10,
        out_features=10,
        use_bias=True
      ),
      Linear(
        weight=f32[10,10],
        bias=f32[10],
        in_features=10,
        out_features=10,

In [9]:
# Should reproduce f1(x) exactly if the round trip restored all weights.
f2(x)

Array([-0.10066809,  0.01973604, -0.1169754 ,  0.00370935,  0.00341123,
        0.23280813,  0.14593743,  0.32178235,  0.09192786,  0.12030009],      dtype=float32)

In [12]:
def equal_models(model1, model2):
    """Return True if two pytrees (e.g. Equinox models) are structurally and leaf-wise equal.

    The tree structures (treedefs) must match exactly. Array leaves (JAX or NumPy)
    are compared with ``jnp.array_equal``, i.e. shapes must match and all values
    must be equal. Any other leaf is compared with ``==``.

    Args:
        model1: first pytree.
        model2: second pytree.

    Returns:
        bool: True if structure and every leaf are equal, otherwise False.
    """
    import numpy as np

    leaves1, treedef1 = jtu.tree_flatten(model1)
    leaves2, treedef2 = jtu.tree_flatten(model2)
    # Compare structure first: mapping over two differently-structured trees
    # raises instead of returning False.
    if treedef1 != treedef2:
        return False

    array_types = (jnp.ndarray, np.ndarray)
    for leaf1, leaf2 in zip(leaves1, leaves2):
        if isinstance(leaf1, array_types) or isinstance(leaf2, array_types):
            # Shape-aware comparison: plain `==` would broadcast (1,) against (3,).
            if not bool(jnp.array_equal(leaf1, leaf2)):
                return False
        elif not (leaf1 == leaf2):
            return False
    return True



In [13]:
# Structural + leaf-wise comparison of the original and the loaded model.
# NOTE(review): the recorded "val: ..." lines below cannot come from the current
# `equal_models` (it prints nothing) -- likely stale output; re-run from a fresh kernel.
equal_models(f1, f2)

val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True
val: True


True

In [78]:
# `d` is an unannotated class attribute of Model (not a field), so this presumably
# reads the shared class-level dict rather than anything restored from the file.
f2.d


{'a': 3}