# Posterior inference and regression

The examples in the previous notebook don’t really show much in the way of fancy prediction; for that we want to do some _regression_.
See e.g. McElreath's _Statistical Rethinking_ for a practical introduction to regression in a Bayesian context.
We will follow the [pyro regression tutorial](http://pyro.ai/examples/bayesian_regression.html).

Also good is Florian Wilhelm’s [Bayesian Hierarchical Modelling at Scale](https://florianwilhelm.info/2020/10/bayesian_hierarchical_modelling_at_scale/), although that is for the (similar but not identical) numpyro rather than pyro.

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import pyro
import pyro.distributions as dist
from pyro.optim import Adam
from pyro.infer import Trace_ELBO, Predictive, MCMC, NUTS
from pyro import poutine
sns.set_theme()

from src import graphs

The task here is to predict a country’s GDP in the year 2000 from various other facts about it. Here are the facts we have to work with:

In [None]:

# data from https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv
rugged_data = pd.read_csv("rugged_data.csv", encoding="ISO-8859-1")
rugged_data.head()

In this case, we will predict a country’s GDP from its terrain “ruggedness” and from whether or not it is in Africa; we observe that these two predictors interact in a non-trivial way.
We keep things simple by pre-processing the data: we keep only these three columns, drop countries with missing GDP, and work with the log-transformed GDP.

In [None]:
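# preprocess data: keep the two predictors and the response
df = rugged_data[["cont_africa", "rugged", "rgdppc_2000"]]
# drop countries with missing GDP; .copy() avoids pandas' SettingWithCopyWarning below
df = df[np.isfinite(df.rgdppc_2000)].copy()
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])  # work with log GDP
train = torch.tensor(df.values, dtype=torch.float)
is_cont_africa, ruggedness, log_gdp = train[:, 0], train[:, 1], train[:, 2]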

We define the model as follows: a (linear) regression model with two predictor variables (whether the country is in Africa, and terrain ruggedness), their interaction term, and a response variable (log GDP).

$$\begin{aligned}
\log \text{GDP}_i &\sim \mathcal{N}(\mu_i, \sigma)\\
\mu_i &= a + b_a \cdot \operatorname{InAfrica}_i + b_r \cdot \operatorname{Ruggedness}_i + b_{ar} \cdot \operatorname{InAfrica}_i \cdot \operatorname{Ruggedness}_i \\
a &\sim \mathcal{N}(0, 10)\\
b_a &\sim \mathcal{N}(0, 1)\\
b_r &\sim \mathcal{N}(0, 1)\\
b_{ar} &\sim \mathcal{N}(0, 1)\\
\sigma &\sim \operatorname{Gamma}(1, \frac12)
\end{aligned}$$

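where $a$ is the intercept, $b_a$ is the coefficient of the binary indicator of whether the country is in Africa, $b_r$ is the ruggedness coefficient, and $b_{ar}$ is the coefficient of the interaction term; $\sigma$ is the standard deviation of the observation noise. As in Pyro, $\mathcal{N}(m, s)$ is parametrized by mean and standard deviation, and $\operatorname{Gamma}(1, \frac12)$ by concentration and rate.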
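Splitting the mean by the value of the Africa indicator makes the role of the interaction term explicit: it lets the ruggedness slope differ between African and non-African countries.

$$\begin{aligned}
\operatorname{InAfrica}_i = 0: \quad \mu_i &= a + b_r \cdot \operatorname{Ruggedness}_i\\
\operatorname{InAfrica}_i = 1: \quad \mu_i &= (a + b_a) + (b_r + b_{ar}) \cdot \operatorname{Ruggedness}_i
\end{aligned}$$

So $b_r$ is the ruggedness slope outside Africa, $b_r + b_{ar}$ is the slope inside Africa, and $b_{ar}$ is the difference between the two.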


In [None]:
pyro.clear_param_store()
def model():
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Gamma(1.0, 0.5)) 
    is_cont_africa = pyro.sample("is_cont_africa", dist.Bernoulli(0.5))  # <- overridden
    ruggedness = pyro.sample("ruggedness", dist.Normal(1.0, 0.5))        # <- overridden
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    s = pyro.sample("log_gdp", dist.Normal(mean, sigma))                 # <- overridden
    return s
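Before looking at any data, it can help to see what these priors imply about log GDP. The following is a minimal NumPy sketch of prior predictive simulation that mirrors the prior choices in `model`; it is not how Pyro does it, and the helper name `prior_predictive` and the example ruggedness values are illustrative assumptions. Note that NumPy's `gamma` takes a `scale`, the reciprocal of the rate 0.5 used by `dist.Gamma(1.0, 0.5)`.

```python
import numpy as np

def prior_predictive(ruggedness, is_cont_africa, num_draws=1000, seed=0):
    """Draw log-GDP values using only the priors of `model` (hypothetical helper).

    Returns an array of shape (num_draws, number of countries).
    """
    rng = np.random.default_rng(seed)
    ruggedness = np.asarray(ruggedness, dtype=float)
    is_cont_africa = np.asarray(is_cont_africa, dtype=float)
    # Priors matching the sample sites in `model`; shape (num_draws, 1) so they
    # broadcast against the per-country inputs
    a = rng.normal(0.0, 10.0, size=(num_draws, 1))
    b_a = rng.normal(0.0, 1.0, size=(num_draws, 1))
    b_r = rng.normal(0.0, 1.0, size=(num_draws, 1))
    b_ar = rng.normal(0.0, 1.0, size=(num_draws, 1))
    # Gamma(concentration=1, rate=0.5) is NumPy's gamma(shape=1, scale=2)
    sigma = rng.gamma(1.0, 1.0 / 0.5, size=(num_draws, 1))
    # Linear predictor with the interaction term
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    # Observation noise
    return rng.normal(mean, sigma)

draws = prior_predictive(ruggedness=[0.5, 1.0, 3.0], is_cont_africa=[0, 1, 0])
print(draws.shape)        # (1000, 3)
print(draws.std(axis=0))  # dominated by the N(0, 10) prior on the intercept
```

Comparing this spread with `log_gdp.std()` is a quick check of whether the priors are weakly informative or much wider than they need to be.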



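Note the trick here: we gave distributions even to the inputs that we will get from data, so that they exist as named sample sites that `poutine.condition` can override. Those distributions are never used to generate values; during inference we always override the values at those sites with data. (An alternative, used in the Pyro regression tutorial, is to pass the predictors to the model as ordinary arguments and mark only the response as observed with `obs=`.)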
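Why are the dummy input distributions harmless? Once a site is conditioned, its log-probability still enters the model's log joint density, but for the input sites that term depends only on the data, never on the parameters $a, b_a, b_r, b_{ar}, \sigma$. It therefore shifts the log posterior by a constant and leaves the posterior unchanged. A minimal NumPy sketch of that argument, writing the densities out by hand; the helper `log_joint` and the toy data are made up for illustration and are not Pyro internals.

```python
import math

import numpy as np

def normal_logpdf(x, mean, sd):
    """Log density of a normal with the given mean and standard deviation."""
    return -0.5 * np.log(2 * np.pi * sd**2) - (x - mean) ** 2 / (2 * sd**2)

def gamma_logpdf(x, concentration, rate):
    """Log density of Gamma(concentration, rate), the parametrization Pyro uses."""
    return (concentration * np.log(rate) + (concentration - 1) * np.log(x)
            - rate * x - math.lgamma(concentration))

def log_joint(params, africa, rugged, log_gdp, include_input_sites=True):
    """Hypothetical hand-written log joint density of the conditioned model."""
    a, b_a, b_r, b_ar, sigma = params
    # Prior terms for the parameters, as in `model`
    lp = (normal_logpdf(a, 0.0, 10.0) + normal_logpdf(b_a, 0.0, 1.0)
          + normal_logpdf(b_r, 0.0, 1.0) + normal_logpdf(b_ar, 0.0, 1.0)
          + gamma_logpdf(sigma, 1.0, 0.5))
    if include_input_sites:
        # Dummy Bernoulli(0.5) and Normal(1.0, 0.5) input sites: data-only terms
        lp += np.sum(africa * np.log(0.5) + (1 - africa) * np.log(0.5))
        lp += np.sum(normal_logpdf(rugged, 1.0, 0.5))
    # Likelihood of the observed log GDP
    mean = a + b_a * africa + b_r * rugged + b_ar * africa * rugged
    lp += np.sum(normal_logpdf(log_gdp, mean, sigma))
    return lp

# Toy data, made up for illustration
africa = np.array([0.0, 1.0, 1.0, 0.0])
rugged = np.array([0.3, 1.2, 2.5, 0.8])
log_gdp = np.array([9.1, 7.2, 7.9, 8.5])

# Two arbitrary parameter settings (a, b_a, b_r, b_ar, sigma)
params1 = (8.0, -1.0, -0.2, 0.3, 1.0)
params2 = (9.0, -2.0, 0.1, 0.5, 0.7)

# Contribution of the input sites at each parameter setting
diff1 = (log_joint(params1, africa, rugged, log_gdp)
         - log_joint(params1, africa, rugged, log_gdp, include_input_sites=False))
diff2 = (log_joint(params2, africa, rugged, log_gdp)
         - log_joint(params2, africa, rugged, log_gdp, include_input_sites=False))
print(np.isclose(diff1, diff2))  # True: the input sites only add a constant
```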

In [None]:
graphs.ruggedness_graph(170)

Inference proceeds by conditioning the model on the observed data and then drawing samples from the posterior distribution with MCMC, here using the No-U-Turn Sampler (NUTS).
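Pyro's `NUTS` is too involved to reproduce here, but the basic idea of drawing samples from a conditioned model can be seen with a much simpler MCMC algorithm. The following is a minimal random-walk Metropolis sketch in NumPy; this is a swapped-in, far less efficient technique, not what `NUTS` does. It targets a toy problem with a known answer: a normal mean with a $\mathcal{N}(0, 10)$ prior and known noise standard deviation 1, whose exact posterior is also normal.

```python
import numpy as np

def log_posterior(mu, y, prior_sd=10.0, noise_sd=1.0):
    """Unnormalized log posterior: N(0, prior_sd) prior plus normal likelihood."""
    return -0.5 * (mu / prior_sd) ** 2 - 0.5 * np.sum((y - mu) ** 2) / noise_sd**2

def metropolis(y, num_samples=20000, warmup=2000, step=0.5, seed=0):
    """Random-walk Metropolis for the scalar mean `mu`."""
    rng = np.random.default_rng(seed)
    mu = 0.0
    current_lp = log_posterior(mu, y)
    samples = []
    for i in range(warmup + num_samples):
        proposal = mu + step * rng.normal()
        proposal_lp = log_posterior(proposal, y)
        # Accept with probability min(1, p(proposal | y) / p(mu | y))
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            mu, current_lp = proposal, proposal_lp
        # Discard warm-up draws (loosely analogous to `warmup_steps` in MCMC)
        if i >= warmup:
            samples.append(mu)
    return np.array(samples)

y = np.random.default_rng(1).normal(3.0, 1.0, size=50)  # toy observations
samples = metropolis(y)

# Exact conjugate posterior for comparison
post_var = 1.0 / (1.0 / 10.0**2 + len(y) / 1.0**2)
post_mean = post_var * y.sum() / 1.0**2
print(samples.mean(), post_mean)
print(samples.std(), np.sqrt(post_var))
```

The two printed pairs should agree closely; NUTS solves the same problem far more efficiently in higher dimensions by using gradients of the log posterior.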

In [None]:

observed_model = poutine.condition(model, data={
    "log_gdp": log_gdp, "ruggedness": ruggedness, "is_cont_africa": is_cont_africa})
nuts_kernel = NUTS(observed_model)

mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run()

hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

In [None]:
mcmc.summary()
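Because `hmc_samples` is a plain dictionary of NumPy arrays keyed by site name, derived quantities can be computed draw by draw. A minimal sketch, assuming the site names `"bR"` and `"bAR"` from `model` and using the fact that the ruggedness slope is $b_r$ outside Africa and $b_r + b_{ar}$ inside it. The helper name `ruggedness_slopes` is hypothetical, and the made-up draws in the usage lines stand in for real MCMC output.

```python
import numpy as np

def ruggedness_slopes(samples):
    """Per-draw ruggedness slopes outside and inside Africa (hypothetical helper).

    `samples` maps site names to 1-D arrays of posterior draws, like `hmc_samples`.
    """
    outside = np.asarray(samples["bR"])            # InAfrica = 0
    inside = outside + np.asarray(samples["bAR"])  # InAfrica = 1
    return {
        "outside": outside,
        "inside": inside,
        # Posterior probability that each slope is positive
        "p_outside_positive": float(np.mean(outside > 0)),
        "p_inside_positive": float(np.mean(inside > 0)),
    }

# Made-up draws standing in for `hmc_samples`
fake_samples = {"bR": np.array([-0.2, -0.1, -0.3, 0.1]),
                "bAR": np.array([0.4, 0.3, 0.2, 0.1])}
slopes = ruggedness_slopes(fake_samples)
print(slopes["p_inside_positive"])   # 0.75
print(slopes["p_outside_positive"])  # 0.25
```

With the real notebook, `ruggedness_slopes(hmc_samples)` gives the posterior of each slope directly, which can then be summarized or plotted like any other array.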

To actually make predictions we need to use the `Predictive` class, which is not well explained in the docs, but [you can work it out from their example](http://pyro.ai/examples/predictive_deterministic.html).
An only-slightly-confusing explanation is [here](http://pyro.ai/examples/bayesian_regression.html#Model-Evaluation).
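Conceptually, for each posterior draw of the parameters, `Predictive` reruns the model with those values fixed and samples the remaining unobserved sites; for our conditioned model that means drawing `log_gdp` from $\mathcal{N}(\mu, \sigma)$ at the new inputs. The following NumPy sketch reproduces that idea by hand for a single new country. It illustrates the logic only and is not Pyro's implementation; the helper name `posterior_predictive` and the stand-in posterior draws are assumptions.

```python
import numpy as np

def posterior_predictive(samples, ruggedness, is_cont_africa, seed=0):
    """One log-GDP draw per posterior sample for a single new country (hypothetical)."""
    rng = np.random.default_rng(seed)
    # Mean of log GDP under each posterior draw of the coefficients
    mean = (samples["a"] + samples["bA"] * is_cont_africa
            + samples["bR"] * ruggedness
            + samples["bAR"] * is_cont_africa * ruggedness)
    # Adding observation noise turns the posterior of the mean into a
    # posterior predictive distribution for an actual observation
    return rng.normal(mean, samples["sigma"])

# Made-up posterior draws standing in for `hmc_samples`
rng = np.random.default_rng(42)
n = 1000
fake_samples = {"a": rng.normal(9.0, 0.1, n), "bA": rng.normal(-2.0, 0.1, n),
                "bR": rng.normal(-0.2, 0.05, n), "bAR": rng.normal(0.4, 0.1, n),
                "sigma": rng.gamma(50.0, 0.02, n)}
pred = posterior_predictive(fake_samples, ruggedness=1.5, is_cont_africa=0.0)
print(pred.shape)  # (1000,)
```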

Now, let us suppose that we wish to found some new nations by cutting an existing nation in half, so that each half keeps the original's ruggedness. What does this model tell us about the GDP of the new nations?

In [None]:
# Hypothetical "2.0" nations: same ruggedness as the original country, not in Africa
colors = sns.color_palette()
posterior_samples = mcmc.get_samples()

for color, country in zip(colors, ["Switzerland", "Pakistan", "Australia"]):
    row = rugged_data[rugged_data["country"] == country]
    country_rugged = row["rugged"].item()
    country_loggdp = np.log(row["rgdppc_2000"].item())
    # Condition the inputs on the new nation's values; log_gdp is left unobserved
    new_nation_model = poutine.condition(model, data={
        "ruggedness": torch.tensor(country_rugged), "is_cont_africa": torch.tensor(0.)})
    # One posterior-predictive draw of log_gdp per posterior sample
    new_nation_loggdp = Predictive(
        new_nation_model, posterior_samples=posterior_samples)()["log_gdp"]
    sns.kdeplot(new_nation_loggdp.detach().numpy(), label=f"{country} 2.0 log GDP", color=color)
    plt.vlines(country_loggdp, 0, 0.5, label=f"{country} 1.0 log GDP",
               linestyle="dashed", color=color)

plt.legend();

Note that the predictions are not very informative. Since we are working with little information, the model's predictions have high variance; but through that spread the model also tells us that we should not be very confident in them.
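One way to make "not very informative" concrete is to report a central 90% predictive interval for each new nation and check whether the original country's observed log GDP falls inside it. A minimal sketch, assuming the predictive draws for one country are available as a 1-D NumPy array (a `torch.Tensor` can be converted with `.detach().numpy()`); the helper name `predictive_interval` and the toy draws are illustrative.

```python
import numpy as np

def predictive_interval(draws, observed, prob=0.9):
    """Central predictive interval and coverage check (hypothetical helper)."""
    draws = np.asarray(draws, dtype=float)
    tail = (1.0 - prob) / 2.0
    lower, upper = np.quantile(draws, [tail, 1.0 - tail])
    return {
        "lower": float(lower),
        "upper": float(upper),
        "width": float(upper - lower),
        # The model works on log GDP, so exp(width) is the ratio of the
        # interval's upper to lower end in GDP units
        "gdp_ratio": float(np.exp(upper - lower)),
        "covers_observed": bool(lower <= observed <= upper),
    }

draws = np.random.default_rng(0).normal(8.7, 1.0, size=1000)  # toy predictive draws
summary = predictive_interval(draws, observed=9.5)
print(summary["covers_observed"])  # True
print(summary["width"], summary["gdp_ratio"])
```

In this toy case the interval is about 3.3 wide on the log scale, which corresponds to roughly a 27-fold range in GDP per capita: a wide prediction, but one that honestly reflects the model's uncertainty.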