# PoC 0
Concept: Take a super simple first step.

Just identify the mean and the standard deviation of some data.

In [1]:
import numpy as np
import pandas as pd
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import torch
import os
import matplotlib.pyplot as plt
import pyro.distributions.constraints as constraints
import logging
import seaborn as sns

%matplotlib inline
plt.style.use('default')

logging.basicConfig(format='%(message)s', level=logging.INFO)
smoke_test = ('CI' in os.environ)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Mean and standard deviation used to generate the synthetic data,
# followed by the number of points to draw.
mu, sig = 1, 2
n_samples = 10_000

In [21]:
# Draw the synthetic observations from Normal(mean=mu, std=sig).
# A seeded generator is used so that a fresh "Restart & Run All" always
# produces the same data set (the unseeded global RNG did not).
rng = np.random.default_rng(0)
data_np = rng.normal(mu, sig, n_samples)
# Copy into a torch tensor for use as the observed data in Pyro.
data = torch.tensor(data_np)


In [4]:
def model(data):
    """Normal likelihood over `data` with learnable location and scale.

    Registers two entries in the Pyro param store -- "sigma" (initial value
    1, constrained to be positive) and "mu" (initial value 0) -- and
    conditions the sample site "obs" ~ Normal(mu, sigma) on `data` inside a
    plate named "N" whose size is len(data).

    Returns whatever `pyro.sample` returns for the observed "obs" site.
    """
    scale = pyro.param("sigma", torch.tensor([1.]), constraint=constraints.positive)
    loc = pyro.param("mu", torch.tensor([0.]))

    likelihood = dist.Normal(loc, scale)
    with pyro.plate("N", len(data)):
        observed = pyro.sample("obs", likelihood, obs=data)
    return observed

In [5]:
from pyro.infer.autoguide import AutoMultivariateNormal, init_to_mean
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

In [6]:
def get_svi(model, lr=0.01):
    """Build an automatic guide and a matching SVI object for `model`.

    Args:
        model: Pyro model callable to run inference on.
        lr: Learning rate passed to Adam. Defaults to 0.01, the value that
            used to be hard-coded, so existing calls behave as before.

    Returns:
        A `(guide, svi)` tuple: the AutoMultivariateNormal guide (initialised
        with `init_to_mean`) and the SVI instance that optimises Trace_ELBO.
    """
    # NOTE(review): the `model` defined in this notebook only has an observed
    # site and no latent sample sites; an autoguide presumably cannot be
    # built for it -- confirm before calling this helper with that model.
    guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean)
    svi = SVI(model, guide, optim.Adam({"lr": lr}), loss=Trace_ELBO())
    return guide, svi

In [7]:
def custom_guide(data=None):
    """Guide without sample sites that only exposes the two parameters.

    Looks up (or lazily creates) the "sigma" and "mu" entries of the Pyro
    param store, with initial values 1 and 0 respectively and "sigma"
    constrained to be positive. `data` is accepted but not used.

    Returns a dict with keys "mu" and "sigma" holding those parameters.
    """
    scale = pyro.param(
        "sigma",
        lambda: torch.tensor([1.]),
        constraint=constraints.positive,
    )
    loc = pyro.param("mu", lambda: torch.tensor([0.]))

    return dict(mu=loc, sigma=scale)

In [8]:
# Pyro keeps parameters in a global store that survives cell re-runs.
# Clear it so that re-running this cell restarts training from the initial
# values instead of silently continuing from previously learned ones.
pyro.clear_param_store()

# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, custom_guide, optimizer, loss=Trace_ELBO())

In [22]:
# `smoke_test` (set in the first cell when a CI environment variable exists)
# shortens the run to a couple of steps; a normal run does the full 10000.
n_steps = 2 if smoke_test else 10000
# do gradient steps
for step in range(n_steps):
    svi.step(data)
    # one progress dot per 100 steps
    if step % 100 == 0:
        print('.', end='')

....................................................................................................

In [23]:
# grab the learned variational parameters
mu = pyro.param("mu").item()
sigma = pyro.param("sigma").item()

print(f"bayesian: mu={mu}, sigma={sigma}")
print(f"Check: mu={np.mean(data_np)}, sigma={np.std(data_np)}")

bayesian: mu=0.9707101583480835, sigma=2.0084011554718018
Check: mu=0.9707101158522659, sigma=2.008401068345451
