In [1]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.tree_util as jtu

import equinox as eqx

In [23]:
key = jrandom.PRNGKey(10)

class Model(eqx.Module):
    mlp: eqx.nn.MLP
    def __init__(self, in_d, out_d, w, d, act, key):
        self.mlp = eqx.nn.MLP(in_d, out_d, w, d, act, act, key=key)
    
    def __call__(self, x):
        return self.mlp(x)

f1 = Model(10, 10, 10, 10, jnp.tanh, key=key)

In [24]:
x = jrandom.normal(key, shape=(10,))

In [25]:
f1(x)

Array([-0.1010102 ,  0.01973858, -0.11751334,  0.00370937,  0.00341123,
        0.23715648,  0.14698686,  0.3336342 ,  0.0921881 ,  0.12088551],      dtype=float32)

In [26]:
eqx.tree_serialise_leaves("10.eqx", f1)

In [31]:
f2 = Model(10, 10, 10, 10, jnp.tanh, key=key)
eqx.tree_serialise_leaves("10.eqx", f2)

In [32]:
f2(x)

Array([-0.1010102 ,  0.01973858, -0.11751334,  0.00370937,  0.00341123,
        0.23715648,  0.14698686,  0.3336342 ,  0.0921881 ,  0.12088551],      dtype=float32)

In [33]:
def equal_models(model1, model2):
    eq = jtu.tree_map(lambda leaf1, leaf2: leaf1 == leaf2, model1, model2)
    return all(a.all() if isinstance(a, jnp.ndarray) else a
               for a in jtu.tree_leaves(eq))



In [34]:
equal_models(f1, f2)

True