# Posterior inference and regression

The examples in the previous notebook don’t really show much in the way of fancy prediction; for that we want to do some _regression_.
We will follow the [pyro regression tutorial](http://pyro.ai/examples/bayesian_regression.html) but see also the McElreath book for a really nice discussion of this regression problem.

* Pyro does not really have a very good explanation for the `Predictive` class that we use here, but [you can work it out from their example](http://pyro.ai/examples/predictive_deterministic.html)
* The best introduction is IMO Florian Wilhelm’s [Bayesian Hierarchical Modelling at Scale](https://florianwilhelm.info/2020/10/bayesian_hierarchical_modelling_at_scale/), although that is for the (similar but not identical) NumPyro rather than Pyro.

In [None]:
from functools import partial
from math import sqrt

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, Predictive, MCMC, NUTS
import pyro.distributions as dist
from pyro import poutine
sns.set_theme()

from src import graphs

In [None]:

# Source of the CSV: https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv
# A copy named "rugged_data.csv" is read from the working directory, decoded as
# ISO-8859-1 (Latin-1).
rugged_data = pd.read_csv(
    filepath_or_buffer="rugged_data.csv",
    encoding="ISO-8859-1",
)
rugged_data.head()

In [None]:
# Preprocess: keep the three modelling columns, drop rows whose GDP is missing or
# non-finite, and put GDP on the log scale.
# The selection is done in one `.loc` call and the log transform through `.assign`,
# so a new frame is built instead of writing into a slice of `rugged_data`
# (which is what triggers pandas' SettingWithCopyWarning).
df = rugged_data.loc[
    np.isfinite(rugged_data["rgdppc_2000"]),
    ["cont_africa", "rugged", "rgdppc_2000"],
].assign(rgdppc_2000=lambda frame: np.log(frame["rgdppc_2000"]))

# One float32 tensor holding all rows; its columns follow the selection order above.
train = torch.tensor(df.values, dtype=torch.float)
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]

A (linear) regression model with two predictor variables (in Africa or not, terrain ruggedness), a response variable (log GDP), and an interaction term between the two predictors.

In [None]:
# Call the project-local helper `graphs.ruggedness_graph` (imported from `src`).
# NOTE(review): presumably this renders the graphical model for the regression and
# 170 is the number of observations — confirm against `src/graphs.py`.
graphs.ruggedness_graph(170)

In [None]:
# NOTE(review): this cell uses `model` before the cell that defines it, so on a
# fresh kernel (Restart & Run All) it fails with NameError. The `model` defined
# further down also takes no arguments, while `mcmc.run` is given three here.
# A later cell builds the kernel from `poutine.condition(model, data=...)` and calls
# `mcmc.run()` with no arguments; that looks like the intended version — confirm,
# then remove this cell or move it after the model definition.
nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(is_cont_africa, ruggedness, log_gdp)

# Convert each sampled site to a NumPy array, keyed by site name.
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

For this we need to use the `Predictive` class, which is not well explained in the docs.
An only-slightly-confusing explanation is [here](http://pyro.ai/examples/bayesian_regression.html#Model-Evaluation).

In [None]:
pyro.clear_param_store()


def model():
    """Generative model for log GDP with an Africa x ruggedness interaction.

    Every quantity, predictors included, is a sample site, so the same function
    can be run unconditioned or have any subset of sites fixed from outside
    (for example with ``poutine.condition``).

    Sample sites, in order: ``a``, ``bA``, ``bR``, ``bAR``, ``sigma``,
    ``is_cont_africa``, ``ruggedness``, ``log_gdp``.

    Returns:
        The value drawn at (or fixed for) the ``log_gdp`` site.
    """
    # Regression parameters.
    intercept = pyro.sample("a", dist.Normal(0., 10.))
    coef_africa = pyro.sample("bA", dist.Normal(0., 1.))
    coef_rugged = pyro.sample("bR", dist.Normal(0., 1.))
    coef_interaction = pyro.sample("bAR", dist.Normal(0., 1.))
    noise_scale = pyro.sample("sigma", dist.Gamma(1.0, 0.5))

    # Predictors are sample sites too, each with its own distribution.
    in_africa = pyro.sample("is_cont_africa", dist.Bernoulli(0.5))
    terrain = pyro.sample("ruggedness", dist.Normal(1.0, 0.5))

    # Linear predictor: intercept, two main effects, and their interaction.
    africa_effect = coef_africa * in_africa
    rugged_effect = coef_rugged * terrain
    interaction_effect = coef_interaction * in_africa * terrain
    mean = intercept + africa_effect + rugged_effect + interaction_effect

    return pyro.sample("log_gdp", dist.Normal(mean, noise_scale))



In [None]:

# Fix the response and both predictor sites of `model` to the observed tensors,
# leaving a, bA, bR, bAR and sigma as the only sites still to be sampled.
observations = {
    "log_gdp": log_gdp,
    "ruggedness": ruggedness,
    "is_cont_africa": is_cont_africa,
}
observed_model = poutine.condition(model, data=observations)

# NUTS kernel on the conditioned model: 200 warm-up steps, 1000 samples requested.
nuts_kernel = NUTS(observed_model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run()

# NumPy copies of the samples, keyed by sample-site name.
hmc_samples = {
    site_name: site_draws.detach().cpu().numpy()
    for site_name, site_draws in mcmc.get_samples().items()
}

In [None]:
# Print Pyro's summary table for the sites sampled by the MCMC run.
mcmc.summary()

In [None]:
def _rows_for(country):
    """Return the rows of ``rugged_data`` whose ``country`` column equals ``country``."""
    return rugged_data[rugged_data['country'] == country]


def _log_gdp_draws(rugged_value):
    """Sample the ``log_gdp`` site of ``model`` using the MCMC samples.

    ``ruggedness`` is fixed to ``rugged_value`` and ``is_cont_africa`` to 0.
    """
    fixed_sites = {
        "ruggedness": torch.tensor(rugged_value),
        "is_cont_africa": torch.tensor(0.),
    }
    predictive = Predictive(
        poutine.condition(model, data=fixed_sites),
        posterior_samples=mcmc.get_samples(),
    )
    return predictive()['log_gdp']


colors = sns.color_palette()

# Observed ruggedness and log GDP for the two reference countries.
swiss_rugged = _rows_for('Switzerland')['rugged'].item()
belgian_rugged = _rows_for('Belgium')['rugged'].item()
swiss_loggdp = np.log(_rows_for('Switzerland')['rgdppc_2000'].item())
belgian_loggdp = np.log(_rows_for('Belgium')['rgdppc_2000'].item())

# Model-generated ("2.0") log GDP for a non-African country with each ruggedness.
switzerland2_gdp = _log_gdp_draws(swiss_rugged)
belgium2_gdp = _log_gdp_draws(belgian_rugged)

# Density of the generated values, one curve per country.
sns.kdeplot(switzerland2_gdp, color=colors[0], label="Switzerland 2.0 log GDP")
sns.kdeplot(belgium2_gdp, color=colors[1], label="Belgium 2.0 log GDP")

# Observed ("1.0") values as vertical markers in the matching colour.
plt.vlines(swiss_loggdp, 0, 0.5, color=colors[0], linestyle="dashed", label="Switzerland 1.0 log GDP")
plt.vlines(belgian_loggdp, 0, 0.5, color=colors[1], linestyle="dotted", label="Belgium 1.0 log GDP")
plt.legend()

In [None]:
# Run the unconditioned model once under a tracing handler, then print the
# table of shapes that the trace reports for its sites.
tracer = poutine.trace(model)
trace = tracer.get_trace()
print(trace.format_shapes())