In [56]:
import numpyro
import numpyro.distributions as dist
import jax
from jax import numpy as jnp
from jax import vmap

def model():
    with numpyro.plate("N", 3):
        x = numpyro.sample("x", dist.Normal(0, 1))

nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key)
samps = mcmc.get_samples()
print(jnp.mean(samps['x'],0))
print(jnp.std(samps['x'],0))

sample: 100%|██████████| 2000/2000 [00:02<00:00, 688.69it/s, 3 steps of size 8.43e-01. acc. prob=0.92] 


[0.02599583 0.03129866 0.01175993]
[0.9868098 1.0511807 0.9563092]


In [12]:
means = jnp.array([0,1,2])
stds = jnp.array([3,4,5])

def model():
    with numpyro.plate("N", 3) as n:
        x = numpyro.sample("x", dist.Normal(means[n], stds[n]))

nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key)
samps = mcmc.get_samples()
print(jnp.mean(samps['x'],0))
print(jnp.std(samps['x'],0))

sample: 100%|██████████| 2000/2000 [00:02<00:00, 804.30it/s, 7 steps of size 7.49e-01. acc. prob=0.93] 


[0.06982622 1.0776219  1.9651314 ]
[3.0057814 4.166635  5.061793 ]


In [17]:
import numpy as np
A = jnp.array(np.random.randn(2,3,4,5))
A[*([slice(None)]*2),0].shape

(2, 3, 5)

In [49]:
means = jnp.array([[0,1,2],[3,4,5]])

def model():
    with numpyro.plate('N',means.shape[0],dim=-2) as n:
        with numpyro.plate('M',means.shape[1],dim=-1) as m:
            x = numpyro.sample("x", dist.Normal(means[n][:,m],1))
nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key)
samps = mcmc.get_samples()
print(samps['x'].shape)

sample: 100%|██████████| 2000/2000 [00:02<00:00, 784.88it/s, 7 steps of size 7.09e-01. acc. prob=0.91] 


(1000, 2, 3)


In [27]:
np.mean(samps['x'],axis=0)

Array([[3.5948488e-03, 9.7644651e-01, 1.9760137e+00],
       [2.9817114e+00, 3.9892282e+00, 4.9842019e+00],
       [2.9926810e+00, 4.0023050e+00, 4.9814672e+00],
       [3.0003114e+00, 4.0327892e+00, 4.9667888e+00]], dtype=float32)

In [30]:
means = jnp.array([0,1,2])

def model():
    with numpyro.plate("N") as n:
        x = numpyro.sample("x", dist.Normal(means[n],1))
nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key)
samps = mcmc.get_samples()

TypeError: plate.__init__() missing 1 required positional argument: 'size'

In [55]:
with numpyro.plate("N",5) as i:
    print("hi")
    print(i)

hi
[0 1 2 3 4]


In [72]:
means = jnp.array([0.1,0.5,0.9])
def model():
    with numpyro.numpyro_config(validate_args=False):
        x = numpyro.sample("x", vmap(dist.Exponential)(means))
    
nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key)
samps = mcmc.get_samples()['x']
print(jnp.mean(samps,axis=0))

AttributeError: module 'numpyro' has no attribute 'numpyro_config'

In [70]:
samps.shape

(1000,)

In [73]:
x = np.ones((2,3,4))
x[np.newaxis].shape

(1, 2, 3, 4)

In [74]:
x[np.newaxis]

array([[[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]]])

In [None]:
# verify that numpyro is fine with vectorized discrete variables