I want to save my traces for later reuse. What do I do?

GenJAX supports serializing traces into a byte format. To support different on-disk formats, GenJAX offers several serialization backends. Here we use `msgpack_serialize`, which writes traces using the MsgPack protocol.
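The backend idea can be pictured as one shared interface with one implementation per byte format. The sketch below is hypothetical and uses only the standard library. `Backend`, `PickleBackend`, `JsonBackend`, and `BACKENDS` are illustrative names, not GenJAX's API, and a plain dict stands in for a trace.

```python
import json
import pickle
from abc import ABC, abstractmethod
from typing import Any


class Backend(ABC):
    """Hypothetical interface: every backend turns an object into bytes and back."""

    @abstractmethod
    def serialize(self, obj: Any) -> bytes: ...

    @abstractmethod
    def deserialize(self, data: bytes) -> Any: ...


class PickleBackend(Backend):
    # Python-specific binary format.
    def serialize(self, obj: Any) -> bytes:
        return pickle.dumps(obj)

    def deserialize(self, data: bytes) -> Any:
        return pickle.loads(data)


class JsonBackend(Backend):
    # Portable text format, encoded as UTF-8 bytes.
    def serialize(self, obj: Any) -> bytes:
        return json.dumps(obj).encode("utf-8")

    def deserialize(self, data: bytes) -> Any:
        return json.loads(data.decode("utf-8"))


# Callers choose a disk format by name; the rest of their code is unchanged.
BACKENDS: dict[str, Backend] = {"pickle": PickleBackend(), "json": JsonBackend()}

# Stand-in for the choices of a trace (not a real GenJAX trace).
choices = {"x": -2.208231, "y": [1, 1, 1]}

for name, backend in BACKENDS.items():
    data = backend.serialize(choices)
    assert isinstance(data, bytes)
    assert backend.deserialize(data) == choices
    print(name, len(data), "bytes")
```

The same data yields different byte layouts depending on the backend, while each backend round-trips it exactly.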

In [1]:
# Closely following Ian's notebook
import genjax
import jax
import jax.numpy as jnp
from genjax._src.core.serialization.msgpack import msgpack_serialize
from genjax import gen

@gen
def model(mu, coins):
    x = genjax.normal(mu, 1.0) @ "x"
    y = genjax.flip(jnp.sum(coins) / coins) @ "y"
    return x + y

key = jax.random.PRNGKey(314159)
args = (-2.1, jnp.array([1, 1, 0]))
tr = model.simulate(key, args)
print(tr.get_sample())


XorChm(c1=StaticChm(addr='x', c=ValueChm(v=<jax.Array(-2.208231, dtype=float32)>)), c2=StaticChm(addr='y', c=ValueChm(v=<jax.Array([1, 1, 1], dtype=int32)>)))


We can now convert `tr` into a byte representation using `serialize`. GenJAX also provides `pickle`-like functions (e.g. `dumps`) that can be called instead.

In [2]:
serialized_tr = msgpack_serialize.serialize(tr)  # or msgpack_serialize.dumps(tr)
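"`pickle`-like" refers to the naming convention of Python's standard `pickle` module: `dumps` returns `bytes`, `loads` rebuilds the object from those bytes, and `dump`/`load` do the same against a binary file object. The sketch below uses `pickle` itself on a plain dict standing in for a trace; whether GenJAX's backends also offer `dump`/`load` is an assumption not confirmed here.

```python
import io
import pickle

# Plain data standing in for a trace (not a GenJAX object).
choices = {"x": -2.208231, "y": [1, 1, 1]}

# dumps/loads: object <-> bytes in memory.
data = pickle.dumps(choices)
assert isinstance(data, bytes)
assert pickle.loads(data) == choices

# dump/load: the same round trip through a binary file-like object.
buffer = io.BytesIO()
pickle.dump(choices, buffer)
buffer.seek(0)
assert pickle.load(buffer) == choices
```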

Deserialization is slightly different than one might be used to. In addition to the bytes, we pass the generative function and the arguments it was called with.

In [3]:
retrieved_tr = msgpack_serialize.deserialize(serialized_tr, model, args)
print(retrieved_tr)

StaticTrace(gen_fn=StaticGenerativeFunction(source=Closure(dyn_args=(), fn=<function model at 0x14d555b20>)), args=(-2.1, <jax.Array([1, 1, 0], dtype=int32)>), retval=<jax.Array([-1.208231, -1.208231, -1.208231], dtype=float32)>, addresses=AddressVisitor(visited=[('x',), ('y',)]), subtraces=[DistributionTrace(gen_fn=ExactDensityFromCallables(sampler=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.sampler at 0x14d3b0b80>), logpdf_evaluator=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.logpdf at 0x14d3b0c20>)), args=(-2.1, 1.0), value=<jax.Array(-2.208231, dtype=float32)>, score=<jax.Array(-0.9247955, dtype=float32)>), DistributionTrace(gen_fn=ExactDensityFromCallables(sampler=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.sampler at 0x14d393100>), logpdf_evaluator=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.logpdf at 0x14d3931a0>)), args=(<jax.Array([ 2.,  2., inf], dtype=float32)>,), value=<jax.Array([1, 1, 1], dtype=int32)>, score
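One plausible reading of why `deserialize` needs the generative function and its arguments is that the bytes carry only the leaf values of the trace, while the tree structure is rebuilt from a template that the generative function produces for those arguments. The stdlib-only sketch below illustrates that idea on nested dicts and lists. `flatten`, `unflatten`, `serialize`, `deserialize`, and the template are hypothetical stand-ins, not GenJAX internals.

```python
import json
from typing import Any, Iterator


def flatten(tree: Any) -> list[Any]:
    """Collect leaves in a deterministic order (dict keys sorted)."""
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in flatten(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in flatten(item)]
    return [tree]


def unflatten(template: Any, leaves: Iterator[Any]) -> Any:
    """Rebuild the shape of `template`, filling each leaf slot from `leaves`."""
    if isinstance(template, dict):
        return {key: unflatten(template[key], leaves) for key in sorted(template)}
    if isinstance(template, list):
        return [unflatten(item, leaves) for item in template]
    if isinstance(template, tuple):
        return tuple(unflatten(item, leaves) for item in template)
    return next(leaves)


def serialize(tree: Any) -> bytes:
    # Only the leaf values are written; the structure is not.
    return json.dumps(flatten(tree)).encode("utf-8")


def deserialize(data: bytes, template: Any) -> Any:
    # Structure comes from the template, values come from the bytes.
    return unflatten(template, iter(json.loads(data.decode("utf-8"))))


trace = {"x": -2.208231, "y": [1, 1, 1]}
data = serialize(trace)

# Any template with the same shape works; its own leaf values are ignored.
template = {"x": 0.0, "y": [0, 0, 0]}
assert deserialize(data, template) == trace
```

Because the flat list of leaves is ambiguous on its own, some source of structure has to be supplied at load time; any structurally identical template can play that role.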

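Note that a serialized trace can be written to disk with the MsgPack backend. This means that even if the Python runtime terminates, the written trace can still be deserialized, so long as the generative function used for deserialization induces the same trace structure. Observe that a different generative function, `second_model_with_similar_trace_structure`, which samples `"x"` from an inverse gamma instead of a normal but keeps the same addresses, deserializes the trace just as well.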

In [4]:
@gen
def second_model_with_similar_trace_structure(mu, coins):
    x = genjax.inverse_gamma(mu, 2.) @ "x"
    y = genjax.flip(jnp.sum(coins) / coins) @ "y"
    return x + y

retrieved_tr = msgpack_serialize.deserialize(serialized_tr, second_model_with_similar_trace_structure, args)
print(retrieved_tr)

StaticTrace(gen_fn=StaticGenerativeFunction(source=Closure(dyn_args=(), fn=<function second_model_with_similar_trace_structure at 0x13c1aca40>)), args=(-2.1, <jax.Array([1, 1, 0], dtype=int32)>), retval=<jax.Array([-1.208231, -1.208231, -1.208231], dtype=float32)>, addresses=AddressVisitor(visited=[('x',), ('y',)]), subtraces=[DistributionTrace(gen_fn=ExactDensityFromCallables(sampler=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.sampler at 0x14d393c40>), logpdf_evaluator=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.logpdf at 0x14d393ce0>)), args=(-2.1, 1.0), value=<jax.Array(-2.208231, dtype=float32)>, score=<jax.Array(-0.9247955, dtype=float32)>), DistributionTrace(gen_fn=ExactDensityFromCallables(sampler=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.sampler at 0x14d393100>), logpdf_evaluator=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.logpdf at 0x14d3931a0>)), args=(<jax.Array([ 2.,  2., inf], dtype=float32)>,), value=<jax.A
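Persisting a trace then amounts to writing the serialized bytes to a file and reading them back later, possibly from a new Python process. The sketch below shows that round trip with `pathlib`. The helper names `save_bytes`/`load_bytes`, the file name, and the placeholder bytes are hypothetical; with GenJAX the bytes would come from `msgpack_serialize.serialize(tr)`, and the bytes read back would be passed to `msgpack_serialize.deserialize(data, model, args)` as in the cells above.

```python
import tempfile
from pathlib import Path


def save_bytes(path: Path, data: bytes) -> None:
    """Write serialized bytes to disk, replacing any existing file."""
    path.write_bytes(data)


def load_bytes(path: Path) -> bytes:
    """Read serialized bytes back, e.g. in a later Python session."""
    return path.read_bytes()


# Placeholder for the output of a serializer such as msgpack_serialize.serialize(tr).
serialized_tr = b"placeholder serialized trace bytes"

with tempfile.TemporaryDirectory() as tmp:
    trace_file = Path(tmp) / "trace.msgpack"  # hypothetical file name
    save_bytes(trace_file, serialized_tr)
    restored = load_bytes(trace_file)

assert restored == serialized_tr
```

The file holds nothing but the raw bytes, so reading it back only recovers what the serializer wrote; rebuilding a trace from it still requires a generative function with a matching trace structure.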