In [1]:
import numpyro
def model():
    with numpyro.plate("i",5) as i:
        x = numpyro.deterministic("x",0)
        y = numpyro.deterministic("y",i)
        z = numpyro.sample("z",numpyro.distributions.Normal(0,1))
    return x, y, z

with numpyro.handlers.seed(rng_seed=0):
    out = model()

print(out)

(0, Array([0, 1, 2, 3, 4], dtype=int32), Array([-1.4581939, -2.047044 ,  2.0473392,  1.1684095, -0.9758364],      dtype=float32))


In [8]:
import numpyro
from jax import numpy as jnp

means = jnp.array([1,2,3])

def model():
    with numpyro.plate("i",3) as i:
        z = numpyro.sample("x",numpyro.distributions.Normal(means[i],1e-10))
    return z

with numpyro.handlers.seed(rng_seed=0):
    out = model()

print(out)


[1. 2. 3.]


In [15]:
import numpyro
from jax import numpy as jnp

means = jnp.array([1,2,3])

def model():
    with numpyro.plate("i",3) as i:
        with numpyro.plate("j",4) as j:
            tmp = numpyro.distributions.Normal(means[i],1e-10)
            z = numpyro.sample("x",tmp)
    return z

with numpyro.handlers.seed(rng_seed=0):
    out = model()

print(out)
 

[[1. 2. 3.]
 [1. 2. 3.]
 [1. 2. 3.]
 [1. 2. 3.]]


In [20]:
import numpyro
from jax import numpy as jnp
import numpy as np

means = jnp.array([1,2,3])

x_obs = jnp.array(np.random.randn(4,3))

def model():
    with numpyro.plate("i",3) as i:
        my_x_obs = x_obs[:,i]
        with numpyro.plate("j",4) as j:
            tmp = numpyro.distributions.Normal(means[i],1e-10)
            z = numpyro.sample("z",tmp)
            tmp = numpyro.distributions.Normal(z,0.1)
            my_my_x_obs = my_x_obs[j]
            x = numpyro.sample("x",tmp,obs=my_my_x_obs)
    return z, x

with numpyro.handlers.seed(rng_seed=0):
    out = model()

print(out)
 

(Array([[1., 2., 3.],
       [1., 2., 3.],
       [1., 2., 3.],
       [1., 2., 3.]], dtype=float32), Array([[-0.05056199,  0.5899013 ,  0.07918721],
       [ 0.24101706,  1.2480859 , -0.31339505],
       [-1.1583914 , -2.0697548 , -1.9650959 ],
       [ 3.796438  , -0.5678566 ,  0.8969241 ]], dtype=float32))


In [111]:
import numpyro
from jax import numpy as jnp
import numpy as np
import jax

#means = jnp.array([0.25,0.9,0.2])
#other_means = jnp.array([0,1,1,0])
means = 0.5 * jnp.ones(3)

x_obs = jnp.array(np.random.randint(2,size=(4,3)))

def model():
    with numpyro.plate("i",3) as i:
        #my_x_obs = x_obs[i,:]
        #my_mean = means[i]
        with numpyro.plate("j",4) as j:
            tmp = numpyro.distributions.Bernoulli(means[i])
            z = numpyro.sample("z",tmp)
            tmp = numpyro.distributions.Bernoulli(.05 + .9*z)
            x = numpyro.sample("x",tmp,obs=x_obs)
    return z, x

# big = np.repeat(means[:,None],4,axis=1).T
# print(big.shape)
# def model():
#     tmp = numpyro.distributions.Bernoulli(big)
#     z = numpyro.sample("z",tmp)
#     tmp = numpyro.distributions.Bernoulli(.05 + .9*z)
#     x = numpyro.sample("x",tmp,obs=x_obs)
# 
#     return z,x


#kernel = numpyro.infer.NUTS(model)
kernel = numpyro.infer.DiscreteHMCGibbs(
        numpyro.infer.NUTS(model), modified=True
    )

mcmc = numpyro.infer.MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    progress_bar=False,
)
key = jax.random.PRNGKey(0)

# numpyro gives some annoying future warnings
import warnings

with warnings.catch_warnings(action="ignore", category=FutureWarning):  # type: ignore
    mcmc.run(key)

#mcmc.print_summary()
print(mcmc.get_samples())

posterior_samples = mcmc.get_samples()

#predictive = numpyro.infer.Predictive(model, posterior_samples, infer_discrete=True)
predictive = numpyro.infer.Predictive(model, num_samples=1000, infer_discrete=True)
key = jax.random.PRNGKey(0)
conditional_samples = predictive(rng_key=key)

print(f"{conditional_samples['z'].shape=}")
print(f"{conditional_samples['x'].shape=}")
 

{'z': Array([[[0, 1, 0],
        [0, 0, 0],
        [0, 1, 0],
        [0, 1, 0]],

       [[1, 1, 0],
        [0, 0, 0],
        [1, 1, 0],
        [0, 1, 0]],

       [[0, 1, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 1, 0]],

       ...,

       [[0, 1, 0],
        [0, 0, 0],
        [1, 1, 0],
        [0, 1, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 1, 0],
        [0, 1, 0]],

       [[0, 1, 0],
        [0, 0, 0],
        [0, 1, 0],
        [0, 1, 0]]], dtype=int32)}




conditional_samples['z'].shape=(1000, 4, 3)
conditional_samples['x'].shape=(1000, 4, 3)


In [96]:
print(np.mean(conditional_samples['z'],axis=0))

[[0.95000005 0.059      0.95600003]
 [0.051      0.053      0.93600005]
 [0.947      0.95100003 0.049     ]
 [0.059      0.95100003 0.95100003]]


In [97]:
print(x_obs)

[[1 0 1]
 [0 0 1]
 [1 1 0]
 [0 1 1]]


In [98]:
print(np.mean(conditional_samples['x'],axis=0))

[[1. 0. 1.]
 [0. 0. 1.]
 [1. 1. 0.]
 [0. 1. 1.]]


In [131]:
import numpyro
from jax import numpy as jnp
import numpy as np
import jax

means = 0.5 * jnp.ones(3)

x_obs = jnp.array(np.random.randint(2,size=(4,3)))

def model():
    with numpyro.plate("i",3,dim=-2) as i:
        with numpyro.plate("j",4,dim=-1) as j:
            #tmp = numpyro.distributions.Bernoulli(means[:,None])
            #z = numpyro.sample("z",tmp)
            tmp = means[:,None]+1.0
            z = numpyro.deterministic("z",tmp)
    return z

with numpyro.handlers.seed(rng_seed=0):
    out = model()

print(out)
 


[[1.5]
 [1.5]
 [1.5]]


In [133]:
def model():
    z = numpyro.distributions.Normal(0,1)
    print(z.support)
model()

Real()


In [135]:
def model():
    z = numpyro.distributions.Uniform([0,1,2],[3,4,5])
    print(z.support)
model()

Interval(lower_bound=[0, 1, 2], upper_bound=[3, 4, 5])


In [142]:
def model():
    with numpyro.plate("i",3):
        z = numpyro.distributions.Normal(0,1)
    print(z.support)
model()

Real()


In [138]:
def model():
    with numpyro.plate("i",3):
        z = numpyro.distributions.Bernoulli(0,[3,4,5])
    print(z.support)
model()

Interval(lower_bound=0, upper_bound=[3, 4, 5])


In [140]:
jax.vmap(numpyro.distributions.constraints.interval)(jnp.array([0,1,2]),jnp.array([3,4,5]))

Interval(lower_bound=[0 1 2], upper_bound=[3 4 5])

In [141]:
jax.vmap(numpyro.distributions.constraints.interval,[None,0])(jnp.array(0),jnp.array([3,4,5]))


Interval(lower_bound=[0 0 0], upper_bound=[3 4 5])

In [147]:
numpyro.distributions.Bernoulli(jnp.array([0.5,0.9])).log_prob(jnp.array([1,1]))

Array([-0.6931472 , -0.10536055], dtype=float32)