This is an example of **BASIC APPROXIMATE VARIATIONAL INFERENCE** with Pyro. It is the coin-flip example: the fairness of a coin is inferred from a handful of observed heads and tails.

In [1]:
from __future__ import print_function
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam, SGD
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

# When the 'CI' environment variable is set (i.e. the notebook runs inside
# the testing framework), only a couple of optimisation steps are taken.
smoke_test = os.environ.get('CI') is not None
n_steps = 2000 if not smoke_test else 2

# Start from an empty parameter store: in a REPL / notebook session, values
# of "alpha_q" and "beta_q" from a previous run would otherwise be reused.
pyro.clear_param_store()

# Turn on validation (e.g. of distribution parameters) so that invalid
# values are reported instead of silently propagating.
pyro.enable_validation(True)

# Observed coin flips: 1 head (encoded as 1.0) followed by 9 tails (0.0),
# each stored as its own scalar tensor.
data = [torch.tensor(flip) for flip in [1.0] * 1 + [0.0] * 9]
    
# Hyperparameters of the Beta prior over the coin's fairness.
# NOTE: names bound at module level are readable from inside functions,
# which is how model() and the exact-posterior computation access this dict.
prior_dict = dict(alpha_0=torch.tensor(1.0), beta_0=torch.tensor(10.0))
    
def model(data):
    """Generative model of the coin flips.

    The latent "latent_fairness" is drawn from a Beta prior whose two
    hyperparameters are read from the module-level ``prior_dict``; every
    element of ``data`` is then observed through a Bernoulli likelihood
    parameterised by that fairness, at sample sites "obs_0", "obs_1", ...
    """
    # prior over the coin's fairness
    fairness = pyro.sample(
        "latent_fairness",
        dist.Beta(prior_dict["alpha_0"], prior_dict["beta_0"]),
    )
    # condition on each observed flip, one sample site per datapoint
    for idx, flip in enumerate(data):
        pyro.sample("obs_{}".format(idx), dist.Bernoulli(fairness), obs=flip)

def guide(data):
    """Variational distribution Beta(alpha_q, beta_q) for "latent_fairness".

    ``data`` is not used in the body; the parameter is kept so that the
    guide can be called with the same argument as ``model``.
    """
    # Both variational parameters are registered with Pyro under the names
    # "alpha_q" and "beta_q", each with initial value 125.0.
    # They are constrained to be positive; per the Pyro tutorial the
    # optimizer then takes gradients on the unconstrained values (related
    # to the constrained ones by a log).
    positive = constraints.positive
    alpha_q = pyro.param("alpha_q", torch.tensor(125.0), constraint=positive)
    beta_q = pyro.param("beta_q", torch.tensor(125.0), constraint=positive)
    # draw the latent from the current variational Beta distribution
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

# Optimizer: plain SGD with a fixed learning rate.
# Alternative Adam configuration, kept for reference:
#   adam_params = {"lr": 0.05, "betas": (0.90, 0.999)}
#   optimizer = Adam(adam_params)
sgd_params = dict(lr=0.05)
optimizer = SGD(sgd_params)

# Inference algorithm: stochastic variational inference that fits the guide
# to the model using the Trace_ELBO loss.
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# Run n_steps gradient steps, reporting the current variational parameters
# on every 100th step (including step 0).
for step in range(n_steps):
    svi.step(data)
    if step % 100:
        continue
    print("step, alpha, beta", step,
          pyro.param("alpha_q").item(), pyro.param("beta_q").item())

# The exact posterior is available in closed form thanks to the conjugacy
# between the Beta prior and the Bernoulli likelihood: each observed head
# is added to alpha_0 and each observed tail to beta_0.
observed_head = sum(data).item()
observed_tail = len(data) - observed_head
alpha_exact, beta_exact = (
    prior_dict["alpha_0"].item() + observed_head,
    prior_dict["beta_0"].item() + observed_tail,
)

def compute_mean_var_of_beta_dist(a,b):
    """Return ``(mean, std)`` of a Beta(a, b) distribution.

    Despite "var" in the name, the second element is the standard
    deviation, not the variance.

    a, b -- the alpha and beta parameters of the distribution.
    """
    expectation = a/(a+b)
    # std = mean * sqrt(b / (a * (1 + a + b))), which is algebraically the
    # same as sqrt(a*b / ((a+b)**2 * (a+b+1)))
    spread = expectation * math.sqrt(b/(a*(1+a+b)))
    return expectation, spread

# Read the fitted variational parameters back from Pyro as plain floats.
alpha_q, beta_q = (pyro.param(name).item() for name in ("alpha_q", "beta_q"))

# Report the learned parameters next to the exact conjugate posterior,
# followed by the fairness mean +- std implied by each.
print("\n")
print("Inferred alpha_q, beta_q =", alpha_q, beta_q)
print("Exact    alpha ,  beta   =", alpha_exact, beta_exact)
inferred_stats = compute_mean_var_of_beta_dist(alpha_q, beta_q)
print("Inferred fairness= %.3f +- %.3f" % inferred_stats)
exact_stats = compute_mean_var_of_beta_dist(alpha_exact, beta_exact)
print("Exact    fairness= %.3f +- %.3f" % exact_stats)

step, alpha, beta 0 80.12602233886719 191.89492797851562
step, alpha, beta 100 11.850266456604004 138.74346923828125
step, alpha, beta 200 5.099936008453369 49.65652847290039
step, alpha, beta 300 3.085559606552124 26.45911407470703
step, alpha, beta 400 2.096299409866333 26.560930252075195
step, alpha, beta 500 1.639630913734436 16.459138870239258
step, alpha, beta 600 2.1707279682159424 24.77992057800293
step, alpha, beta 700 2.067006826400757 20.55565643310547
step, alpha, beta 800 1.6133599281311035 24.301687240600586
step, alpha, beta 900 1.7965360879898071 25.169696807861328
step, alpha, beta 1000 2.129049301147461 20.82669448852539
step, alpha, beta 1100 1.882346510887146 23.698848724365234
step, alpha, beta 1200 2.3660542964935303 23.771549224853516
step, alpha, beta 1300 1.8733506202697754 20.857290267944336
step, alpha, beta 1400 1.8274405002593994 20.030136108398438
step, alpha, beta 1500 2.1322097778320312 22.286333084106445
step, alpha, beta 1600 1.7033658027648926 18.9097