# Posterior inference and regression

The examples in the previous notebook don’t really show much in the way of fancy prediction; for that we want to do some _regression_.
We will follow the [pyro regression tutorial](http://pyro.ai/examples/bayesian_regression.html) but see also the McElreath book for a really nice discussion of this regression problem.

* http://pyro.ai/examples/predictive_deterministic.html

In [None]:
from functools import partial
from math import sqrt

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, Predictive, MCMC, NUTS
import pyro.distributions as dist
from pyro import poutine
sns.set_theme()

In [None]:

# data from https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv
# Expects a local copy named rugged_data.csv in the working directory; the
# file is read as ISO-8859-1 (Latin-1) rather than the default UTF-8.
rugged_data = pd.read_csv("rugged_data.csv", encoding="ISO-8859-1")
# Display the first rows as a quick sanity check of the loaded columns.
rugged_data.head()

In [None]:
# Preprocess: keep the three modelling columns, drop rows whose GDP is
# missing or non-finite, and move GDP to the log scale.
df = rugged_data[["cont_africa", "rugged", "rgdppc_2000"]]
# .copy() makes df an independent frame, so the column assignment below does
# not write into a slice of rugged_data (avoids SettingWithCopyWarning).
df = df[np.isfinite(df.rgdppc_2000)].copy()
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
# One float tensor with columns in the order selected above; split it into
# the two predictors and the response.
train = torch.tensor(df.values, dtype=torch.float)
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]

In [None]:
# Utility function to summarise latent sites' quantile information.
def summary(samples):
    """Return per-site summary statistics for a dict of posterior samples.

    ``samples`` maps a site name to array-like draws.  The result maps each
    site name to a DataFrame holding the mean, standard deviation and the
    5/25/50/75/95% quantiles of those draws.
    """
    quantiles = [.05, 0.25, 0.5, 0.75, 0.95]
    wanted = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]
    return {
        name: pd.DataFrame(draws).describe(percentiles=quantiles).transpose()[wanted]
        for name, draws in samples.items()
    }


A (linear) regression model with two predictor variables (whether the country is in Africa, and terrain ruggedness), a response variable (log GDP), and an interaction term between the two predictors.

In [None]:
pyro.clear_param_store()


def model(is_cont_africa, ruggedness, log_gdp=None):
    """Linear regression of log GDP on continent, ruggedness and their interaction.

    Args:
        is_cont_africa: tensor of Africa indicators, one entry per observation.
        ruggedness: tensor of terrain ruggedness values, same length.
        log_gdp: observed responses, or ``None`` (the default) to have the
            "log_gdp" site sampled instead of observed, e.g. for prediction.
    """
    # Priors on the intercept, the three slopes and the observation noise.
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Gamma(1.0, 0.5))
    # One plate entry per observation.
    with pyro.plate("data", len(ruggedness)):
        mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
        # obs=None makes this an ordinary sample site.
        pyro.sample("log_gdp", dist.Normal(mean, sigma), obs=log_gdp)

In [None]:
# Draw posterior samples for the latent sites of `model` with the NUTS kernel.
nuts_kernel = NUTS(model)

# 200 warm-up steps, then 1000 retained samples.
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(is_cont_africa, ruggedness, log_gdp)

# Convert each site's samples to a NumPy array, keyed by site name.
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

In [None]:
pyro.clear_param_store()


def model3(is_cont_africa, ruggedness):
    """Regression model that records its covariates as deterministic sites.

    Takes the two predictors as arguments, registers them under the site
    names "is_cont_africa" and "ruggedness", and returns a draw from the
    unobserved "log_gdp" site.
    """
    intercept = pyro.sample("a", dist.Normal(0., 10.))
    slope_africa = pyro.sample("bA", dist.Normal(0., 1.))
    slope_rugged = pyro.sample("bR", dist.Normal(0., 1.))
    slope_inter = pyro.sample("bAR", dist.Normal(0., 1.))
    noise = pyro.sample("sigma", dist.Gamma(1.0, 0.5))
    # NOTE: deliberately no pyro.plate("data", len(ruggedness)) here.
    africa = pyro.deterministic("is_cont_africa", is_cont_africa)
    rugged = pyro.deterministic("ruggedness", ruggedness)
    mean = intercept + slope_africa * africa + slope_rugged * rugged + slope_inter * africa * rugged
    return pyro.sample("log_gdp", dist.Normal(mean, noise))

# Trace model3 (the model defined in this cell, which accepts the covariates)
# on the observed predictors and print the shape of every site.
trace = poutine.trace(model3).get_trace(is_cont_africa, ruggedness)
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

In [None]:
pyro.clear_param_store()


def model2():
    """Argument-free regression model whose covariates are sample sites.

    Both predictors are drawn from Normal(0.5, 0.5) priors under the site
    names "is_cont_africa" and "ruggedness".  Returns a draw from the
    "log_gdp" site.
    """
    intercept = pyro.sample("a", dist.Normal(0., 10.))
    slope_africa = pyro.sample("bA", dist.Normal(0., 1.))
    slope_rugged = pyro.sample("bR", dist.Normal(0., 1.))
    slope_inter = pyro.sample("bAR", dist.Normal(0., 1.))
    noise = pyro.sample("sigma", dist.Gamma(1.0, 0.5))
    # Alternatives tried for the covariate sites and left disabled:
    #   pyro.deterministic(name, torch.tensor(0.))
    #   pyro.sample(name, dist.Delta(torch.tensor(0.)))
    africa = pyro.sample("is_cont_africa", dist.Normal(0.5, 0.5))
    rugged = pyro.sample("ruggedness", dist.Normal(0.5, 0.5))
    mean = intercept + slope_africa * africa + slope_rugged * rugged + slope_inter * africa * rugged
    return pyro.sample("log_gdp", dist.Normal(mean, noise))

# Run model2 once with no conditioning and print the shapes of its sites.
trace = poutine.trace(model2).get_trace()
# trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

In [None]:
# Same trace, but with the response and both covariate sites of model2 fixed
# to the data tensors via poutine.condition.
trace = poutine.trace(poutine.condition(model2, data={"log_gdp": log_gdp, "ruggedness":ruggedness, "is_cont_africa": is_cont_africa})).get_trace()
# trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

In [None]:

# Fix the response and both covariate sites of model2 to the data, leaving
# only the regression parameters to be sampled.
observed_model = poutine.condition(model2, data={
    "log_gdp": log_gdp, "ruggedness": ruggedness, "is_cont_africa": is_cont_africa})
nuts_kernel = NUTS(observed_model)

# 200 warm-up steps, then 1000 retained samples; model2 takes no arguments,
# so run() is called without any.
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run()

# Convert each site's samples to a NumPy array, keyed by site name.
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

In [None]:
# Show Pyro's own summary of the most recent MCMC run.
mcmc.summary()

In [None]:
# Print the quantile table computed by summary() for every sampled site.
for site_name, stats in summary(hmc_samples).items():
    print(f"Site: {site_name}")
    print(stats, "\n")

In [None]:
# Draws of log_gdp, one per posterior sample, with the covariate sites fixed
# to ruggedness = 1000 and is_cont_africa = 0.
Predictive(poutine.condition(model2, data={
    "ruggedness": torch.tensor(1000.0), "is_cont_africa": torch.tensor(0.)}), posterior_samples=mcmc.get_samples())()['log_gdp']


In [None]:
# Draws of log_gdp, one per posterior sample, with the covariate sites fixed
# to ruggedness = 0 and is_cont_africa = 0.
Predictive(poutine.condition(model2, data={
    "ruggedness": torch.tensor(0.0), "is_cont_africa": torch.tensor(0.)}), posterior_samples=mcmc.get_samples())()['log_gdp']


In [None]:
# Fix log_gdp = 0 and is_cont_africa = 0 and read back the ruggedness site.
# NOTE(review): Predictive appears to run the model forward only, so these
# ruggedness draws presumably come from the Normal(0.5, 0.5) prior in model2
# and are not informed by the conditioned log_gdp - confirm before relying
# on them as an inverse prediction.
Predictive(poutine.condition(model2, data={
    "log_gdp": torch.tensor(0.0), "is_cont_africa": torch.tensor(0.)}), posterior_samples=mcmc.get_samples())()['ruggedness']


In [None]:
# model2 takes no arguments: the covariates are supplied by conditioning its
# "ruggedness" and "is_cont_africa" sites, and the predicted response is
# stored under the site name "log_gdp" (there is no "gdp" site).
predictive = Predictive(
    poutine.condition(model2, data={
        "ruggedness": ruggedness, "is_cont_africa": is_cont_africa}),
    posterior_samples=mcmc.get_samples())
predictive()['log_gdp']

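To generate predictions from the posterior samples we need to use the `Predictive` class, which is not well explained in the docs.
An only-slightly-confusing explanation is in the [Model Evaluation section of the Bayesian regression tutorial](http://pyro.ai/examples/bayesian_regression.html#Model-Evaluation).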

In [None]:
# Posterior-predictive draws from `model` at the observed covariates.
predictive = Predictive(model, posterior_samples=mcmc.get_samples())
# The response site in `model` is named "log_gdp" (there is no "obs" site).
# Passing None as the response makes that site sampled, so the result holds
# predictions rather than the observed data.
predictive(is_cont_africa, ruggedness, None)['log_gdp']

In [None]:
# Predict log GDP for a single hypothetical country (in Africa, ruggedness 1).
# The covariates are length-1 tensors because `model` calls len(ruggedness),
# which fails on a 0-d tensor. None is passed as the response so that the
# "log_gdp" site is sampled.
predictive(torch.tensor([1.]), torch.tensor([1.]), None)['log_gdp']

In [None]:
# List the site names for which the most recent MCMC run holds samples.
mcmc.get_samples().keys()