## Questions

- Why is the MCMC sampling so slow...?

In [1]:
import logging

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.distributions.constraints as constraints

import pyro
import pyro.infer
import pyro.distributions as dist

from pyro.infer import MCMC, NUTS, Predictive, Trace_ELBO, SVI

pyro.set_rng_seed(101)

%matplotlib inline
%config InlineBackend.figure_format = "retina"

logging.basicConfig(format='%(message)s', level=logging.INFO)

In [15]:
intercept_0 = 4
beta_0 = [2, 3] # the _0 represents the true parameter, not to be confused with the intercept
sigma_0 = 1.5
n = 20

NUM_ITERS = 5000000

In [16]:
x1 = np.exp(np.random.normal(loc=3, scale=2, size=[n, 1]))
x2 = np.random.binomial(n=1, p=0.4, size=[n, 1])
assert sum(x2) > 0 and sum(x2) < n

x = np.hstack((x1, x2))
y = intercept_0 + x@np.array(beta_0) + np.random.normal(loc=0, scale=sigma_0, size=[n,])

x = torch.Tensor(x)
y = torch.Tensor(y)

In [44]:
def regression(x, y):
    intercept = pyro.sample("intercept", dist.Normal(loc=0, scale=10))
    beta = []
    
    for i in range(x.shape[1]):
        beta.append(pyro.sample(f"beta{i+1}", dist.Normal(loc=0, scale=10)))
        
    sigma = pyro.sample("sigma", dist.InverseGamma(concentration=4, rate=2))
    
    mean = intercept + x.matmul(torch.Tensor(beta))
    
    y = pyro.sample("y", dist.Normal(loc=mean, scale=sigma), obs=y)

"""
# code below is too slow... using a guide

nuts_kernel = NUTS(model=regression)
mcmc = MCMC(kernel=nuts_kernel,
                 num_samples=10000,
                 num_chains=1, 
                 warmup_steps=1000)
posterior = mcmc.run(x=x, y=y)
"""

def guide(x, y):
    intercept_loc   = pyro.param("intercept_loc", torch.tensor(0.))
    intercept_scale = pyro.param("intercept_scale", torch.tensor(10.), constraint=constraints.positive)
    
    intercept_g = pyro.sample("intercept", dist.Normal(loc=intercept_loc, scale=intercept_scale))
    
    beta_g = []
    beta_loc   = {}
    beta_scale = {}
    
    for i in range(x.shape[1]):
        beta_loc[i]   = pyro.param(f"beta{i+1}_loc", torch.tensor(0.))
        beta_scale[i] = pyro.param(f"beta{i+1}_scale", torch.tensor(10.), constraint=constraints.positive)
        
        beta_g.append(pyro.sample(f"beta{i+1}", dist.Normal(loc=beta_loc[i], scale=beta_scale[i])))
    
    sigma_concentration = pyro.param("sigma_concentration", torch.tensor(4.), constraint=constraints.positive)
    sigma_rate = pyro.param("sigma_rate", torch.tensor(2.), constraint=constraints.positive)
    
    sigma_g = pyro.sample("sigma", dist.InverseGamma(concentration=sigma_concentration, 
                                                       rate=sigma_rate))

In [45]:
adam_params = {"lr": 0.001}
optimizer = pyro.optim.Adam(adam_params)


pyro.clear_param_store()

svi = SVI(regression,
          guide,
          pyro.optim.Adam(adam_params),
          loss=Trace_ELBO())

for i in range(NUM_ITERS):
    if i % 2000 == 0:
        elbo = svi.step(x, y)
        logging.info("Elbo loss: {}".format(np.log(elbo)))

Elbo loss: 21.17296626402927
Elbo loss: 17.194840786734563
Elbo loss: 14.909699744100802
Elbo loss: 21.659881149290783
Elbo loss: 19.290512087733024
Elbo loss: 20.83483148329439
Elbo loss: 21.009891575662643
Elbo loss: 21.686518221889322
Elbo loss: 20.295109670995657
Elbo loss: 19.834340824599494
Elbo loss: 18.72493540908678
Elbo loss: 15.766808958743187
Elbo loss: 19.103746315504637
Elbo loss: 17.327568165290803
Elbo loss: 19.40920920162535
Elbo loss: 20.103619019973287
Elbo loss: 15.12300417514859
Elbo loss: 20.056929702614568
Elbo loss: 19.428538221484004
Elbo loss: 18.53088290387328
Elbo loss: 19.331797409414023
Elbo loss: 20.8316182754897
Elbo loss: 22.067452577465527
Elbo loss: 13.414947222819485
Elbo loss: 16.385323231578163


In [47]:
pyro.param("beta1_scale").item()

10.039931297302246