In [24]:
import logging
import os

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.optim as optim

# Seed Pyro's random number generation so repeated runs draw the same samples.
pyro.set_rng_seed(1)
# Fail fast when the installed Pyro is not a 1.9.1 release.
assert pyro.__version__.startswith('1.9.1')

In [25]:
%matplotlib inline
# Plot with matplotlib's stock style.
plt.style.use("default")

# Emit bare log messages (no level or logger prefix) at INFO and above.
logging.basicConfig(level=logging.INFO, format="%(message)s")
pyro.set_rng_seed(1)
# True when a CI variable is set in the environment; the training cell uses
# this flag to run only a couple of iterations.
smoke_test = "CI" in os.environ

# Download the terrain-ruggedness CSV into a DataFrame.
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
rugged_data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")

In [26]:
def model(is_cont_africa, ruggedness, log_gdp):
    """Linear regression of log GDP on ruggedness with an Africa interaction.

    Latent sites: "a" (intercept), "bA", "bR", "bAR" (coefficients) and
    "sigma" (observation noise). The "obs" site is conditioned on ``log_gdp``.

    Args:
        is_cont_africa: per-observation multiplier for the "bA" and "bAR" terms.
        ruggedness: per-observation multiplier for the "bR" and "bAR" terms;
            its length sets the size of the "data" plate.
        log_gdp: observed values for the "obs" site.
    """
    # Priors. The order of the sample statements is kept fixed so the random
    # draws line up with the site names.
    intercept = pyro.sample("a", dist.Normal(0.0, 10.0))
    coef_africa = pyro.sample("bA", dist.Normal(0.0, 1.0))
    coef_rugged = pyro.sample("bR", dist.Normal(0.0, 1.0))
    coef_interaction = pyro.sample("bAR", dist.Normal(0.0, 1.0))
    noise_scale = pyro.sample("sigma", dist.Uniform(0.0, 10.0))

    # Linear predictor, built term by term.
    africa_term = coef_africa * is_cont_africa
    rugged_term = coef_rugged * ruggedness
    interaction_term = coef_interaction * is_cont_africa * ruggedness
    mean = intercept + africa_term + rugged_term + interaction_term

    # One observation per data point along the "data" plate.
    n_obs = len(ruggedness)
    with pyro.plate("data", n_obs):
        pyro.sample("obs", dist.Normal(mean, noise_scale), obs=log_gdp)


def guide(is_cont_africa, ruggedness, log_gdp):
    """Variational family for ``model``: an independent Normal per latent site.

    The signature mirrors ``model`` because both are called with the same
    arguments, but the guide never reads them: every draw below depends only
    on the learnable parameters.

    Learnable parameters:
        a_loc, a_scale: location and positive scale for site "a".
        weights_loc, weights_scale: length-3 location and positive-scale
            vectors; entries 0, 1, 2 belong to sites "bA", "bR", "bAR".
        sigma_loc: positive location for site "sigma" (its scale is fixed
            at 0.05).
    """
    a_loc = pyro.param("a_loc", torch.tensor(0.0))
    a_scale = pyro.param("a_scale", torch.tensor(1.0), constraint=constraints.positive)
    sigma_loc = pyro.param(
        "sigma_loc", torch.tensor(1.0), constraint=constraints.positive
    )
    weights_loc = pyro.param("weights_loc", torch.randn(3))
    weights_scale = pyro.param(
        "weights_scale", torch.ones(3), constraint=constraints.positive
    )

    # Only the sample statements matter in a guide; their return values are
    # not needed, so no linear predictor is computed here.
    pyro.sample("a", dist.Normal(a_loc, a_scale))
    pyro.sample("bA", dist.Normal(weights_loc[0], weights_scale[0]))
    pyro.sample("bR", dist.Normal(weights_loc[1], weights_scale[1]))
    pyro.sample("bAR", dist.Normal(weights_loc[2], weights_scale[2]))
    # NOTE(review): Normal has unbounded support, while the model's prior on
    # "sigma" is Uniform(0, 10). A draw outside that interval has zero prior
    # density. The small fixed scale makes this unlikely near the initial
    # location, but a guide distribution with matching support would be
    # safer -- confirm before reusing this guide elsewhere.
    pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))

In [27]:
# Utility function to collect latent sites' quantile information.
def summary(samples):
    """Summarise each site's samples by mean, std, and selected quantiles.

    Args:
        samples: mapping from site name to the values drawn for that site
            (anything ``pd.DataFrame`` accepts).

    Returns:
        dict mapping each site name to a DataFrame with the columns
        ``mean``, ``std``, ``5%``, ``25%``, ``50%``, ``75%`` and ``95%``.
    """
    quantiles = [.05, 0.25, 0.5, 0.75, 0.95]
    kept_columns = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]

    def describe_site(values):
        # describe() lays the statistics out as rows; transpose so each
        # statistic becomes a column, then keep only the ones of interest.
        stats = pd.DataFrame(values).describe(percentiles=quantiles).transpose()
        return stats[kept_columns]

    return {name: describe_site(values) for name, values in samples.items()}

# Prepare training data
# Select the three modelling columns for rows whose GDP value is finite in a
# single .loc call, then take an explicit copy. Assigning a column on a frame
# obtained by chained slicing of `rugged_data` triggers pandas'
# SettingWithCopyWarning; the copy makes `df` an independent frame.
df = rugged_data.loc[
    np.isfinite(rugged_data["rgdppc_2000"]),
    ["cont_africa", "rugged", "rgdppc_2000"],
].copy()
# Work with GDP on the log scale.
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
# Columns of `train` follow the selection order above:
# 0 = cont_africa, 1 = rugged, 2 = log of rgdppc_2000.
train = torch.tensor(df.values, dtype=torch.float)

In [28]:
from pyro.infer import SVI, Trace_ELBO


# Start from an empty parameter store so values from an earlier fit do not
# carry over into this one.
pyro.clear_param_store()

# Split the training tensor into its three columns.
is_cont_africa = train[:, 0]
ruggedness = train[:, 1]
log_gdp = train[:, 2]

adam = optim.Adam({"lr": .05})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

# A smoke test only needs to prove the loop runs; a real fit gets 5000 steps.
num_iters = 2 if smoke_test else 5000
for i in range(num_iters):
    elbo = svi.step(is_cont_africa, ruggedness, log_gdp)
    if i % 500 != 0:
        continue
    # Report the loss on every 500th step, starting with the first.
    logging.info("Elbo loss: {}".format(elbo))

Elbo loss: 5795.468078792095


tensor([8.5800e-01, 3.4270e+00, 7.6900e-01, 7.7500e-01, 2.6880e+00, 6.0000e-03,
        1.4300e-01, 3.5130e+00, 1.6720e+00, 1.7800e+00, 3.8800e-01, 1.4100e-01,
        2.3600e-01, 1.8600e-01, 1.4790e+00, 2.3100e-01, 5.5000e-02, 2.3110e+00,
        1.6400e-01, 6.0700e-01, 8.5300e-01, 2.4000e-01, 9.6300e-01, 1.8100e-01,
        1.9700e-01, 7.7500e-01, 4.7610e+00, 2.4810e+00, 1.8780e+00, 2.2400e-01,
        5.1500e-01, 4.4300e-01, 1.5200e-01, 8.8500e-01, 3.3280e+00, 2.3670e+00,
        2.1120e+00, 2.7180e+00, 8.8400e-01, 5.9700e-01, 2.4320e+00, 3.0000e-03,
        1.8900e-01, 1.6410e+00, 5.1000e-01, 1.2780e+00, 7.2300e-01, 2.4810e+00,
        1.6890e+00, 1.2300e-01, 1.5700e+00, 3.2800e-01, 1.3960e+00, 1.0980e+00,
        2.1800e-01, 5.6800e-01, 3.6590e+00, 2.2800e-01, 7.4000e-01, 3.5300e-01,
        4.9100e-01, 5.5900e-01, 3.1030e+00, 2.0880e+00, 1.8070e+00, 2.7300e-01,
        2.5010e+00, 2.1500e+00, 1.2670e+00, 2.3620e+00, 3.4600e-01, 9.6700e-01,
        1.0130e+00, 5.1300e-01, 2.4450e+

Elbo loss: 415.81691539287567


tensor([8.5800e-01, 3.4270e+00, 7.6900e-01, 7.7500e-01, 2.6880e+00, 6.0000e-03,
        1.4300e-01, 3.5130e+00, 1.6720e+00, 1.7800e+00, 3.8800e-01, 1.4100e-01,
        2.3600e-01, 1.8600e-01, 1.4790e+00, 2.3100e-01, 5.5000e-02, 2.3110e+00,
        1.6400e-01, 6.0700e-01, 8.5300e-01, 2.4000e-01, 9.6300e-01, 1.8100e-01,
        1.9700e-01, 7.7500e-01, 4.7610e+00, 2.4810e+00, 1.8780e+00, 2.2400e-01,
        5.1500e-01, 4.4300e-01, 1.5200e-01, 8.8500e-01, 3.3280e+00, 2.3670e+00,
        2.1120e+00, 2.7180e+00, 8.8400e-01, 5.9700e-01, 2.4320e+00, 3.0000e-03,
        1.8900e-01, 1.6410e+00, 5.1000e-01, 1.2780e+00, 7.2300e-01, 2.4810e+00,
        1.6890e+00, 1.2300e-01, 1.5700e+00, 3.2800e-01, 1.3960e+00, 1.0980e+00,
        2.1800e-01, 5.6800e-01, 3.6590e+00, 2.2800e-01, 7.4000e-01, 3.5300e-01,
        4.9100e-01, 5.5900e-01, 3.1030e+00, 2.0880e+00, 1.8070e+00, 2.7300e-01,
        2.5010e+00, 2.1500e+00, 1.2670e+00, 2.3620e+00, 3.4600e-01, 9.6700e-01,
        1.0130e+00, 5.1300e-01, 2.4450e+

Elbo loss: 250.71913582086563


tensor([8.5800e-01, 3.4270e+00, 7.6900e-01, 7.7500e-01, 2.6880e+00, 6.0000e-03,
        1.4300e-01, 3.5130e+00, 1.6720e+00, 1.7800e+00, 3.8800e-01, 1.4100e-01,
        2.3600e-01, 1.8600e-01, 1.4790e+00, 2.3100e-01, 5.5000e-02, 2.3110e+00,
        1.6400e-01, 6.0700e-01, 8.5300e-01, 2.4000e-01, 9.6300e-01, 1.8100e-01,
        1.9700e-01, 7.7500e-01, 4.7610e+00, 2.4810e+00, 1.8780e+00, 2.2400e-01,
        5.1500e-01, 4.4300e-01, 1.5200e-01, 8.8500e-01, 3.3280e+00, 2.3670e+00,
        2.1120e+00, 2.7180e+00, 8.8400e-01, 5.9700e-01, 2.4320e+00, 3.0000e-03,
        1.8900e-01, 1.6410e+00, 5.1000e-01, 1.2780e+00, 7.2300e-01, 2.4810e+00,
        1.6890e+00, 1.2300e-01, 1.5700e+00, 3.2800e-01, 1.3960e+00, 1.0980e+00,
        2.1800e-01, 5.6800e-01, 3.6590e+00, 2.2800e-01, 7.4000e-01, 3.5300e-01,
        4.9100e-01, 5.5900e-01, 3.1030e+00, 2.0880e+00, 1.8070e+00, 2.7300e-01,
        2.5010e+00, 2.1500e+00, 1.2670e+00, 2.3620e+00, 3.4600e-01, 9.6700e-01,
        1.0130e+00, 5.1300e-01, 2.4450e+

KeyboardInterrupt: 

In [22]:
from pyro.infer import Predictive


# Number of draws to take from the fitted guide.
num_samples = 1000
predictive = Predictive(model, guide=guide, num_samples=num_samples)
# Pass the data in the order the model and guide declare their parameters:
# (is_cont_africa, ruggedness, log_gdp). The previous call passed log_gdp
# first, which fed each tensor to the wrong parameter.
svi_samples = {k: v.reshape(num_samples).detach().cpu().numpy()
               for k, v in predictive(is_cont_africa, ruggedness, log_gdp).items()
               # Keep only the latent sites; "obs" is the observed-data site.
               if k != "obs"}

In [23]:
# Show the summary table of each latent site under a heading with its name.
site_tables = summary(svi_samples)
for site in site_tables:
    print("Site: {}".format(site))
    print(site_tables[site], "\n")

Site: a
     mean       std        5%       25%       50%       75%       95%
0  9.1775  0.062302  9.077002  9.134531  9.178521  9.215998  9.278266 

Site: bA
       mean       std      5%       25%      50%       75%      95%
0 -1.895068  0.118995 -2.0918 -1.974352 -1.89098 -1.813421 -1.70285 

Site: bR
       mean       std        5%       25%      50%       75%       95%
0 -0.157187  0.038121 -0.222266 -0.181702 -0.15502 -0.130234 -0.095558 

Site: bAR
       mean       std       5%       25%       50%       75%       95%
0  0.304799  0.066955  0.19294  0.261902  0.304932  0.350269  0.412381 

Site: sigma
       mean       std        5%       25%       50%       75%       95%
0  0.902913  0.049275  0.822383  0.870878  0.901005  0.938589  0.983858 

