## In Pyro, we define a guide that encodes this

In [1]:
def guide(data):
    """Fully factorized variational guide over the per-point Bernoulli latents ``c``.

    Only ``len(data)`` is read: it sizes the plate and the vector of
    variational parameters. The values in ``data`` are not used.

    Args:
        data: the observed data set; only its length matters here.
    """
    with pyro.plate("data", len(data)):  # conditional independence across points
        # One variational parameter lambda_i in [0, 1] per data point.
        # The initializer is a lazy callable, so the random tensor is drawn
        # only when the parameter is first created. An eager torch.rand(...)
        # would allocate and discard a fresh tensor (advancing the global RNG)
        # on every SVI step.
        lam = pyro.param(
            "lam",
            lambda: torch.rand(len(data)),  # random initialization in [0, 1)
            constraint=dist.constraints.unit_interval,
        )
        # The site name "c" must match the corresponding sample site in the model.
        pyro.sample("c", dist.Bernoulli(lam))

## We train the model on synthetic data generated by the following simulator

In [2]:
def getdata(N, mean1=2.0, mean2=-1.0, std1=0.5, std2=0.5, rng=None):
    """Draw ``N`` points from two Gaussian clusters, shuffled together.

    The first ``N // 2`` points come from N(mean1, std1**2). The remaining
    ``N - N // 2`` come from N(mean2, std2**2). The combined sample is
    shuffled so a point's position does not reveal its cluster.

    Args:
        N: total number of points to return (odd values are honoured).
        mean1, std1: mean and standard deviation of the first cluster.
        mean2, std2: mean and standard deviation of the second cluster.
        rng: optional ``numpy.random.Generator`` or ``RandomState`` for
            reproducible draws. Defaults to NumPy's global random state.

    Returns:
        A 1-D ``torch.float32`` tensor of length ``N``.
    """
    rng = np.random if rng is None else rng
    n1 = N // 2
    n2 = N - n1  # second cluster absorbs the extra point when N is odd
    d1 = rng.standard_normal(n1) * std1 + mean1
    d2 = rng.standard_normal(n2) * std2 + mean2
    d = np.concatenate([d1, d2])
    rng.shuffle(d)  # in-place shuffle of the combined sample
    return torch.from_numpy(d.astype(np.float32))

## Finally, Pyro requires a bit of boilerplate to set up the optimization

In [None]:
data = getdata(200)  # 200 synthetic data points
pyro.clear_param_store()  # start from fresh variational parameters
optim = pyro.optim.Adam({})  # no overrides: torch.optim.Adam default hyperparameters
# Use the fully qualified pyro.infer.Trace_ELBO, consistent with pyro.infer.SVI.
# A bare `infer` name is only defined if it was imported separately.
svi = pyro.infer.SVI(model, guide, optim, pyro.infer.Trace_ELBO())

num_steps = 10000
losses = []  # per-step loss estimate returned by svi.step, for checking convergence
for _ in range(num_steps):
    losses.append(svi.step(data))
