I want to save my traces for later reuse. What do I do?

In [None]:
import genjax
import jax
import jax.numpy as jnp
from genjax import gen, pretty
from genjax._src.core.serialization.msgpack import msgpack_serialize

pretty()

GenJAX supports serializing traces into a byte format. To support different on-disk formats, GenJAX offers multiple serialization backends. Here we use `msgpack_serialize`, which encodes traces using the MessagePack (`MsgPack`) format.
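To make "a byte format" concrete, here is a toy encoder for a small subset of the MessagePack specification, applied to the kind of values a trace holds (a float choice at `"x"` and a boolean choice at `"y"`). This is only an illustration under our own assumptions: `pack` is a hypothetical helper, it is not how GenJAX implements its backend, and it does not handle JAX arrays or the larger size classes a real MsgPack library supports.

```python
import struct


def pack(obj) -> bytes:
    """Toy MessagePack encoder for None, bool, int, float, short str, small list/dict.

    Hypothetical helper for illustration only; not GenJAX's serializer.
    """
    if obj is None:
        return b"\xc0"  # nil
    if obj is True:  # check bools before ints: bool is a subclass of int
        return b"\xc3"
    if obj is False:
        return b"\xc2"
    if isinstance(obj, int):
        if 0 <= obj <= 0x7F:
            return struct.pack("B", obj)  # positive fixint: the value is the byte
        if -32 <= obj < 0:
            return struct.pack("b", obj)  # negative fixint: 0xe0..0xff
        # int 64; real encoders would pick the smallest width that fits.
        return b"\xd3" + struct.pack(">q", obj)
    if isinstance(obj, float):
        return b"\xcb" + struct.pack(">d", obj)  # float 64, big-endian
    if isinstance(obj, str):
        raw = obj.encode("utf-8")
        if len(raw) <= 31:
            return struct.pack("B", 0xA0 | len(raw)) + raw  # fixstr
    elif isinstance(obj, (list, tuple)):
        if len(obj) <= 15:
            header = struct.pack("B", 0x90 | len(obj))  # fixarray
            return header + b"".join(pack(item) for item in obj)
    elif isinstance(obj, dict):
        if len(obj) <= 15:
            header = struct.pack("B", 0x80 | len(obj))  # fixmap
            return header + b"".join(pack(k) + pack(v) for k, v in obj.items())
    raise NotImplementedError(f"toy encoder does not handle {obj!r}")


# A trace-like record: one float choice and one boolean choice.
encoded = pack({"x": -1.3, "y": True})
print(encoded.hex())
```

For example, `pack({"x": 1})` yields `b'\x81\xa1x\x01'`: `0x81` announces a one-entry map, `0xa1` a one-byte string (`x`), and `0x01` the integer 1. The format is self-describing, which is what lets the bytes be decoded later by another process.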

In [None]:
@gen
def model(mu, coins):
    x = genjax.normal(mu, 1.0) @ "x"
    y = genjax.flip(coins / jnp.sum(coins)) @ "y"  # per-coin probabilities in [0, 1]
    return x + y


key = jax.random.PRNGKey(314159)
args = (-2.1, jnp.array([1, 1, 0]))
tr = model.simulate(key, args)
tr.get_sample()

We can now convert `tr` into a byte representation using `serialize`. GenJAX also provides `pickle`-style functions (e.g. `dumps`) that can be used in place of `serialize`.

In [None]:
serialized_tr = msgpack_serialize.serialize(tr)  # or msgpack_serialize.dumps(tr)

Deserialization is slightly different than one might be used to. In addition to the bytes, you must pass the generative function and the arguments it was called with; together they determine the structure of the trace being reconstructed.

In [None]:
retrieved_tr = msgpack_serialize.deserialize(serialized_tr, model, args)
retrieved_tr
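One way to see why `deserialize` needs the generative function and its arguments: if the bytes store only the trace's leaf values, something else has to supply the structure those values slot into. JAX's pytree utilities (`jax.tree_util.tree_flatten` and `tree_unflatten`) separate leaves from structure in the same spirit. The pure-Python sketch below assumes exactly that split; `flatten`, `unflatten`, `serialize_leaves`, `deserialize_leaves`, and the "template" argument are hypothetical stand-ins (with JSON in place of MsgPack), not GenJAX internals.

```python
import json


def flatten(tree):
    """Collect the leaves of nested dicts/lists/tuples in a deterministic order."""
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in flatten(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in flatten(item)]
    return [tree]


def unflatten(template, leaves):
    """Rebuild a structure shaped like `template`, filling it with `leaves` in order."""
    leaves = list(leaves)
    pos = 0

    def build(node):
        nonlocal pos
        if isinstance(node, dict):
            return {key: build(node[key]) for key in sorted(node)}
        if isinstance(node, (list, tuple)):
            return type(node)(build(item) for item in node)
        if pos >= len(leaves):
            raise ValueError("structure mismatch: too few stored leaves")
        value = leaves[pos]
        pos += 1
        return value

    result = build(template)
    if pos != len(leaves):
        raise ValueError("structure mismatch: too many stored leaves")
    return result


def serialize_leaves(tree) -> bytes:
    # Only the leaf values are written; the structure itself is not stored.
    return json.dumps(flatten(tree)).encode("utf-8")


def deserialize_leaves(data: bytes, template):
    # The caller supplies the structure, much like passing `model` and `args`.
    return unflatten(template, json.loads(data.decode("utf-8")))


stored = serialize_leaves({"x": -1.7, "y": True})
# Any template with the same keys and nesting works; its leaf values are ignored.
restored = deserialize_leaves(stored, {"x": 0.0, "y": False})
assert restored == {"x": -1.7, "y": True}
```

In this toy, a template with the wrong structure raises a `ValueError`, while a template with matching structure but different leaf values works fine. A real trace structure also carries array shapes and dtypes, which this sketch ignores.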

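With the `MsgPack` backend, the serialized trace can also be written to disk. This means that even after the Python runtime terminates, the stored trace can still be deserialized, as long as the generative function passed to `deserialize` induces the same trace structure. A fresh copy of `model` works, but so does a different generative function with a matching structure: in the cell below, `x` is drawn from an inverse gamma distribution instead of a normal distribution, and the trace still deserializes.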

In [None]:
@gen
def second_model_with_similar_trace_structure(mu, coins):
    x = genjax.inverse_gamma(mu, 2.0) @ "x"
    y = genjax.flip(coins / jnp.sum(coins)) @ "y"  # per-coin probabilities in [0, 1]
    return x + y


retrieved_tr = msgpack_serialize.deserialize(
    serialized_tr, second_model_with_similar_trace_structure, args
)
retrieved_tr
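Since the goal is to reuse traces later, the serialized bytes eventually need to go into a file. Assuming `serialize` returns a plain Python `bytes` object, this is ordinary file I/O. The helpers `save_trace_bytes` and `load_trace_bytes` and the file names below are hypothetical, and the GenJAX calls appear only in comments so the block runs without GenJAX installed.

```python
import tempfile
from pathlib import Path


def save_trace_bytes(path, data: bytes) -> Path:
    """Write serialized trace bytes to `path`, creating parent directories."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(data)
    return path


def load_trace_bytes(path) -> bytes:
    """Read back the raw bytes written by `save_trace_bytes`."""
    return Path(path).read_bytes()


# With GenJAX, in the notebook above (not executed here):
#   save_trace_bytes("traces/model_trace.msgpack", serialized_tr)
# and later, possibly in a fresh Python process that re-defines `model` and `args`:
#   retrieved_tr = msgpack_serialize.deserialize(
#       load_trace_bytes("traces/model_trace.msgpack"), model, args
#   )

# Self-contained check with stand-in bytes:
with tempfile.TemporaryDirectory() as tmp:
    target = save_trace_bytes(Path(tmp) / "traces" / "demo.msgpack", b"\x81\xa1x\x01")
    round_tripped = load_trace_bytes(target)

assert round_tripped == b"\x81\xa1x\x01"
```

Because `deserialize` also needs the generative function's arguments, it helps to store `args` (or enough information to recompute them) alongside the trace file.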