In [1]:
%matplotlib inline
import os
from collections import namedtuple
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np
import time
import jax
import genjax
from genjax import grasp

from jax import jit, lax, random
from jax.example_libraries import stax
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro
from numpyro import optim
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO, TraceGraph_ELBO
from numpyro.handlers import replay, trace, seed
from optax import adam

# the only pyro dependency
import pyro.contrib.examples.multi_mnist as multi_mnist

# Root PRNG key for the notebook; the training cell splits a sub-key off it.
key = PRNGKey(314159)
# NOTE(review): presumably enables GenJAX's pretty-printed cell output and
# returns the console object used for it — confirm against the genjax docs.
console = genjax.pretty()

In [2]:
# NOTE(review): ϕ is not referenced by any of the cells visible here;
# presumably it is consumed further down the notebook — confirm, or remove.
ϕ = (0.0, 0.0, 1.0, 1.0)

## Model

In [38]:
def model(data):
    """Generative model: a noisy reading of the squared radius of a latent point.

    The latent sites "x" and "y" each have a Normal(0, 10) prior. The site "z"
    is Normal, centred on ``x**2 + y**2``, with a scale that grows linearly
    with that centre, and is tied to ``data`` through ``obs``.

    Args:
        data: Value passed as ``obs`` for the "z" site.

    Returns:
        The tuple ``(x, y, z)`` of values produced at the three sample sites.
    """
    x = numpyro.sample("x", dist.Normal(0.0, 10.0))
    y = numpyro.sample("y", dist.Normal(0.0, 10.0))

    radius_sq = x**2 + y**2
    # The noise floor of 0.3 keeps the scale strictly positive at the origin.
    noise_scale = 0.3 + (radius_sq / 100.0)
    likelihood = dist.Normal(radius_sq, noise_scale)

    z = numpyro.sample("z", likelihood, obs=data)
    return (x, y, z)


# Smoke test: run the model once under a fixed seed with "z" observed at 25.0.
seed(model, rng_seed=0)(25.0)

[1m([0m[1;35mArray[0m[1m([0m[1;36m-12.5153885[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m, [1;35mArray[0m[1m([0m[1;36m-5.8665056[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m, [1;36m25.0[0m[1m)[0m

## Naive variational guide

In [39]:
# Now, we define our variational proposal.
def guide(data):
    """Variational guide: independent Normal proposals for the sites "x" and "y".

    Each site has a learnable location ("μ1", "μ2", initialised to 0.0) and a
    learnable log-scale ("log_σ1", "log_σ2", initialised to 1.0). The log-scale
    is exponentiated before use, so the Normal's scale is always positive.

    Args:
        data: Unused here; accepted so the signature matches ``model``.

    Returns:
        None. The function is used only for the sample/param sites it registers.
    """
    loc_x = numpyro.param("μ1", 0.0)
    loc_y = numpyro.param("μ2", 0.0)
    log_scale_x = numpyro.param("log_σ1", 1.0)
    log_scale_y = numpyro.param("log_σ2", 1.0)

    # The sampled values are not needed locally, so they are not bound to names.
    numpyro.sample("x", dist.Normal(loc_x, jnp.exp(log_scale_x)))
    numpyro.sample("y", dist.Normal(loc_y, jnp.exp(log_scale_y)))

## Training

In [43]:
# Carve off a fresh sub-key for this run so the root key is never reused.
key, sub_key = jax.random.split(key)

# `adam` is optax's Adam (see the import cell), passed straight to NumPyro's SVI.
svi = SVI(
    model=model,
    guide=guide,
    optim=adam(1e-4),
    loss=TraceGraph_ELBO(),
)

# 5000 optimisation steps; the trailing 50.0 is forwarded to model and guide
# as their `data` argument.
svi_result = svi.run(sub_key, 5000, 50.0)

# Display the fitted variational parameters as the cell's output.
params = svi_result.params
params

100%|█| 5000/5000 [00:00<00:00, 5287.43it/s, init loss: 9129.2832, avg. loss [



[1m{[0m
    [32m'log_σ1'[0m: [1;35mArray[0m[1m([0m[1;36m1.3400445[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m,
    [32m'log_σ2'[0m: [1;35mArray[0m[1m([0m[1;36m1.3362696[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m,
    [32m'μ1'[0m: [1;35mArray[0m[1m([0m[1;36m0.00420607[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m,
    [32m'μ2'[0m: [1;35mArray[0m[1m([0m[1;36m0.0016424[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m
[1m}[0m