In [3]:
import pyro, torch, numpy as np
import pyro.distributions as dist
import pyro.optim as optim
import pyro.infer as infer
import matplotlib.pyplot as plt

In [4]:
def model(data):
    """Two-component Gaussian mixture over the observations in ``data``.

    Each data point gets a Bernoulli indicator ``c`` that selects one of two
    Normal components; the observation is then scored under the selected
    component. Mixture weight, means and standard deviations are all
    registered as learnable ``pyro.param`` sites (point estimates), not
    sampled from priors.

    Param sites:
        "rho": Bernoulli probability of picking component 1; constrained to
            the unit interval, initial value 0.5.
        "M":   the two component means, initial values [1.5, 3.0].
        "S":   the two component standard deviations; constrained positive,
            initial values [0.5, 0.5].

    Sample sites:
        "c": latent component indicator for each data point (0 or 1).
        "x": observed site conditioned on ``data``.

    Args:
        data: the observations; ``len(data)`` sets the plate size.
            Assumed to be a 1-D float tensor -- TODO confirm against caller.
    """
    # Coin bias of the mixture. 'pyro.param' registers it under a name so
    # Pyro can track and optimise it.
    rho = pyro.param(
        "rho",
        torch.tensor([0.5]),  # initial value
        constraint=dist.constraints.unit_interval,  # must stay in [0, 1]
    )
    # Component means and standard deviations, started from fixed
    # (hand-picked, not random) initial values.
    means = pyro.param("M", torch.tensor([1.5, 3.]))
    stds = pyro.param(
        "S",
        torch.tensor([0.5, 0.5]),
        constraint=dist.constraints.positive,  # a std deviation must be > 0
    )

    with pyro.plate("data", len(data)):  # data points are conditionally independent
        # Draw the component indicator for every data point: c in {0, 1}.
        c = pyro.sample("c", dist.Bernoulli(rho))
        # Bernoulli samples are floats; indexing needs an integer tensor.
        # `.long()` casts the dtype but keeps the tensor on its current
        # device, whereas `.type(torch.LongTensor)` forces a CPU tensor.
        c = c.long()
        # Select the mean/std of the component chosen for each data point.
        X = dist.Normal(means[c], stds[c])
        pyro.sample("x", X, obs=data)  # condition on the observed data