# Chapter 8 - Notes

## Set Up

### Packages

In [1]:
import os

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns
import xarray as xr
from scipy import stats
from scipy.interpolate import griddata
from scipy.signal import savgol_filter
from scipy.special import expit, logit, logsumexp, softmax
from sklearn.preprocessing import StandardScaler



### Constants

In [None]:
# directory (relative path) that holds the CSV data sets
DATA_DIR = "../data"
# file names of the individual data sets, resolved against DATA_DIR by load_data
HOWELL_FILE = "howell.csv"
CHERRY_BLOSSOMS_FILE = "cherry_blossoms.csv"
WAFFLE_DIVORCE_FILE = "waffle_divorce.csv"
MILK_FILE = "milk.csv"
CHIMPANZEES_FILE = "chimpanzees.csv"
ADMISSIONS_FILE = "ucbadmit.csv"
KLINE_FILE = "kline.csv"
REEDFROGS_FILE = "reedfrogs.csv"
CARS_FILE = "cars.csv"
RUGGED_FILE = "rugged.csv"

# single seed reused by the numpy generator and every PyMC sampling call
RANDOM_SEED = 42

### Defaults

In [2]:
# seaborn defaults: white grid theme, larger fonts, black axes/labels,
# no top/right spines, tick marks on the bottom and left axes
sns.set(
    style="whitegrid",
    font_scale=1.2,
    rc={
        "axes.edgecolor": "0",
        "axes.grid.which": "both",
        "axes.labelcolor": "0",
        "axes.spines.right": False,
        "axes.spines.top": False,
        "xtick.bottom": True,
        "ytick.left": True,
    },
)

# current colour palette; indexed later to colour the per-continent curves
colors = sns.color_palette()

# set seed (requires the Constants cell to have been run first)
rng = np.random.default_rng(RANDOM_SEED)

NameError: name 'RANDOM_SEED' is not defined

### Functions

In [None]:
def load_data(file_name, data_dir=DATA_DIR, **kwargs):
    """Read a CSV file from the data directory.

    Parameters
    ----------
    file_name : str
        Name of the file inside ``data_dir``.
    data_dir : str, optional
        Directory that contains the file. Defaults to ``DATA_DIR``.
    **kwargs
        Forwarded unchanged to ``pandas.read_csv`` (e.g. ``delimiter``).

    Returns
    -------
    The object returned by ``pandas.read_csv`` for the joined path.
    """
    csv_path = os.path.join(data_dir, file_name)
    data = pd.read_csv(csv_path, **kwargs)
    return data

In [None]:
def smooth_plot_data(x, y, smooth_kwargs=None):
    """Interpolate ``y`` onto an even grid over ``x`` and smooth it.

    The data are linearly interpolated onto 200 evenly spaced points between
    ``x.min()`` and ``x.max()`` and then passed through a Savitzky-Golay
    filter along axis 0.

    Parameters
    ----------
    x : array-like
        One-dimensional x values; must support ``.min()`` and ``.max()``.
    y : array-like
        Values observed at ``x``; the first axis must line up with ``x``.
    smooth_kwargs : dict, optional
        Keyword arguments for ``scipy.signal.savgol_filter``. Missing keys
        default to ``window_length=55`` and ``polyorder=2``. The dictionary
        passed in is not modified.

    Returns
    -------
    x_data : numpy.ndarray
        The 200-point grid the data were interpolated onto.
    y_data : numpy.ndarray
        The smoothed values on that grid.
    """
    # work on a copy so the defaults below never leak into the caller's dict
    smooth_kwargs = {} if smooth_kwargs is None else dict(smooth_kwargs)
    smooth_kwargs.setdefault("window_length", 55)
    smooth_kwargs.setdefault("polyorder", 2)

    x_data = np.linspace(x.min(), x.max(), 200)
    # nudge the first grid point to the midpoint of the first two points
    # (presumably to keep it strictly inside the interpolation range)
    x_data[0] = (x_data[0] + x_data[1]) / 2
    y_interp = griddata(x, y, x_data)
    y_data = savgol_filter(y_interp, axis=0, **smooth_kwargs)

    return x_data, y_data

## 8.1 Building an interaction

### 8.1.1 Making a rugged model

In [None]:
# generate the data
rugged = load_data(RUGGED_FILE, delimiter=";")

# define log gdp
rugged["log_gdp"] = np.log(rugged["rgdppc_2000"])

# restrict to countries with gdp data
rugged = rugged.dropna(subset=["rgdppc_2000"])

# standardise variables
rugged["log_gdp_std"] = rugged["log_gdp"] / rugged["log_gdp"].mean()
rugged["rugged_std"] = rugged["rugged"] / rugged["rugged"].max()

Define a first model with rough priors

In [None]:
# one coordinate value per country, registered as a mutable coordinate
coords = {
    "country": rugged.country,
}
with pm.Model(coords_mutable=coords) as m8_1a:
    # data
    rugged_std = pm.MutableData("rugged_std", rugged.rugged_std, dims="country")
    # mean ruggedness, used to centre the predictor in the linear model
    rugged_std_mean = pm.MutableData("rugged_std_mean", rugged.rugged_std.mean())

    # priors (deliberately wide)
    alpha = pm.Normal("alpha", mu=1, sigma=1)
    beta = pm.Normal("beta", mu=0, sigma=1)
    sigma = pm.Exponential("sigma", lam=1)

    # model: straight line in centred ruggedness
    mu = pm.Deterministic(
        "mu", alpha + beta * (rugged_std - rugged_std_mean), dims="country"
    )

    # likelihood
    log_gdp_std = pm.Normal(
        "log_gdp_std", mu=mu, sigma=sigma, observed=rugged.log_gdp_std, dims="country"
    )

    # sample prior predictive
    prior_trace_8_1a = pm.sample_prior_predictive(random_seed=RANDOM_SEED)

In [None]:
# display the model's representation
m8_1a

In [None]:
# render the model structure as a graph
m8_1a.to_graphviz()

And now with some better priors

In [None]:
# one coordinate value per country, registered as a mutable coordinate
coords = {
    "country": rugged.country,
}
with pm.Model(coords_mutable=coords) as m8_1b:
    # data
    rugged_std = pm.MutableData("rugged_std", rugged.rugged_std, dims="country")
    # mean ruggedness, used to centre the predictor in the linear model
    rugged_std_mean = pm.MutableData("rugged_std_mean", rugged.rugged_std.mean())

    # priors (tighter than in m8_1a: sigma 0.1 for alpha and 0.3 for beta)
    alpha = pm.Normal("alpha", mu=1, sigma=0.1)
    beta = pm.Normal("beta", mu=0, sigma=0.3)
    sigma = pm.Exponential("sigma", lam=1)

    # model: straight line in centred ruggedness
    mu = pm.Deterministic(
        "mu", alpha + beta * (rugged_std - rugged_std_mean), dims="country"
    )

    # likelihood
    log_gdp_std = pm.Normal(
        "log_gdp_std", mu=mu, sigma=sigma, observed=rugged.log_gdp_std, dims="country"
    )

    # sample prior predictive
    prior_trace_8_1b = pm.sample_prior_predictive(random_seed=RANDOM_SEED)

In [None]:
# display the model's representation
m8_1b

In [None]:
# render the model structure as a graph
m8_1b.to_graphviz()

Let's plot some sample prior means

In [None]:
def plot_sample_prior_lines(trace, n_lines=50, ylim=None, title=None, ax=None):
    """Plot prior draws of the regression line against ruggedness.

    Parameters
    ----------
    trace : InferenceData-like
        Must expose ``constant_data.rugged_std``, ``prior.mu`` (with ``chain``
        and ``draw`` dimensions) and ``observed_data.log_gdp_std``.
    n_lines : int, optional
        Maximum number of prior draws to plot, taken from the first chain.
        If the trace holds fewer draws, all of them are plotted.
    ylim : sequence of two floats, optional
        y-axis limits. Defaults to ``[0.5, 1.5]``.
    title : str, optional
        Axes title. Defaults to ``"Sample prior means"``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new 6x5 figure is created when omitted.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 5))
    if ylim is None:
        ylim = [0.5, 1.5]
    if title is None:
        title = "Sample prior means"

    x_val = trace.constant_data.rugged_std.data
    # the slice yields fewer than n_lines draws when the trace is short
    mu_lines = trace.prior.mu.isel(chain=0, draw=slice(n_lines))
    # a 1-D x is reused for every column of y, so the shapes always agree
    # no matter how many draws were actually available
    ax.plot(
        x_val,
        mu_lines.T,
        color="k",
        alpha=0.2,
    )

    # dashed reference lines at the smallest and largest observed outcome
    min_log_gdp_std = trace.observed_data.log_gdp_std.min()
    ax.axhline(min_log_gdp_std, color="k", ls="--")

    max_log_gdp_std = trace.observed_data.log_gdp_std.max()
    ax.axhline(max_log_gdp_std, color="k", ls="--")

    ax.set(
        ylim=ylim,
        xlabel="Ruggedness",
        ylabel="Log GDP (proportion of mean)",
        title=title,
    )

    return ax

In [None]:
# side-by-side panels sharing the y axis so the two priors are comparable
fig, axs = plt.subplots(ncols=2, sharey=True, figsize=(12, 6))

# left: wide priors (m8_1a); right: tighter priors (m8_1b)
plot_sample_prior_lines(
    prior_trace_8_1a, title="a ~ Norm(1, 1)\nb ~ Norm(0, 1)", ax=axs[0]
)
plot_sample_prior_lines(
    prior_trace_8_1b, title="a ~ Norm(1, 0.1)\nb ~ Norm(0, 0.3)", ax=axs[1]
)

fig.suptitle("Sample prior lines")
fig.tight_layout();

Now let's look at the posterior

In [None]:
# sample from m8_1b with the default sampler settings and the shared seed
with m8_1b:
    trace_8_1b = pm.sample(random_seed=RANDOM_SEED)

In [None]:
# summary statistics with 89% HDIs; "~mu" leaves out the per-country mu values
az.summary(
    trace_8_1b,
    var_names=["~mu"],
    kind="stats",
    hdi_prob=0.89,
    round_to=2,
)

The slope (`beta`) is approximately zero, suggesting no overall relationship between ruggedness and log GDP.

### 8.1.2 Adding an indicator variable isn't enough

Let's add a separate intercept for Africa (the slope is still shared between continents)

In [None]:
# continent labels; list position matches the 0/1 coding of cont_africa
continents = [
    "Not Africa",
    "Africa",
]

# readable continent label for plotting
rugged["continent"] = np.where(
    rugged["cont_africa"] == 1,
    "Africa",
    "Not Africa",
)
# label -> index into the "continent" dimension of the model parameters
continents_idx_mapper = {
    "Not Africa": 0,
    "Africa": 1,
}

coords = {
    "country": rugged.country,
    "continent": continents,
}
with pm.Model(coords_mutable=coords) as m8_2:
    # data
    rugged_std = pm.MutableData("rugged_std", rugged.rugged_std, dims="country")
    rugged_std_mean = pm.MutableData("rugged_std_mean", rugged.rugged_std.mean())
    # cont_africa is used directly as the index into alpha
    continent_idx = pm.MutableData("continent_idx", rugged.cont_africa, dims="country")

    # priors: one intercept per continent, a single shared slope
    alpha = pm.Normal("alpha", mu=1, sigma=0.1, dims="continent")
    beta = pm.Normal("beta", mu=0, sigma=0.3)
    sigma = pm.Exponential("sigma", lam=1)

    # model
    mu = pm.Deterministic(
        "mu",
        alpha[continent_idx] + beta * (rugged_std - rugged_std_mean),
        dims="country",
    )

    # likelihood
    log_gdp_std = pm.Normal(
        "log_gdp_std", mu=mu, sigma=sigma, observed=rugged.log_gdp_std, dims="country"
    )

    # sample the posterior, storing the pointwise log-likelihood
    # so the trace can be used for WAIC / PSIS model comparison
    trace_8_2 = pm.sample(
        random_seed=RANDOM_SEED,
        idata_kwargs={"log_likelihood": True},
    )

In [None]:
# render the model structure as a graph
m8_2.to_graphviz()

In [None]:
# trace_8_1b was sampled without idata_kwargs={"log_likelihood": True},
# so compute its pointwise log-likelihood here before comparing models
with m8_1b:
    pm.compute_log_likelihood(trace_8_1b)

In [None]:
# compare the two models by WAIC, reported on the deviance scale
az.compare(
    {
        "m8.1": trace_8_1b,
        "m8.2": trace_8_2,
    },
    ic="waic",
    scale="deviance",
)

In [None]:
# same comparison using PSIS-LOO instead of WAIC
az.compare(
    {
        "m8.1": trace_8_1b,
        "m8.2": trace_8_2,
    },
    ic="loo",
    scale="deviance",
)

Why am I getting warnings for model 8.1?

Let's look at the individual PSIS and WAIC values.

In [None]:
# pointwise=True keeps the per-observation values (pareto_k, waic_i) used below
psis_8_2 = az.loo(trace_8_2, scale="deviance", pointwise=True)
waic_8_2 = az.waic(trace_8_2, scale="deviance", pointwise=True)

In [None]:
# display the PSIS-LOO result for m8_2
psis_8_2

I don't get any warnings from PSIS.

But plotting the values I can see a pretty extreme WAIC value

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))

# one point per observation: Pareto k (x) against pointwise WAIC (y)
sns.scatterplot(
    x=psis_8_2.pareto_k,
    y=waic_8_2.waic_i,
    ax=ax,
)

ax.set(
    xlabel="Pareto $k$",
    ylabel="WAIC",
    title="WAIC and Pareto $k$ values",
);

What's this outlier?

In [None]:
# look up the constant data for the observation with the largest pointwise WAIC
trace_8_2.constant_data.isel(country=waic_8_2.waic_i.argmax())

It looks like the Seychelles.
In the scatterplot below, this is the African point at the top right, so it is easy to see why it has a large effect on the model fit.

In [None]:
# raw data: rescaled ruggedness against rescaled log GDP, coloured by continent
fig, ax = plt.subplots(figsize=(5, 4))
sns.scatterplot(
    rugged,
    x="rugged_std",
    y="log_gdp_std",
    hue="continent",
    hue_order=continents,
    ax=ax,
);

Now let's look at the new model's posterior

In [None]:
# summary statistics with 89% HDIs; "~mu" leaves out the per-country mu values
az.summary(
    trace_8_2,
    var_names=["~mu"],
    kind="stats",
    hdi_prob=0.89,
    round_to=2,
)

The intercepts are noticeably different.

Let's plot the posterior predictive intervals.

In [None]:
# generate counterfactual ruggedness values
n_vals = 100
rugged_std_vals = np.linspace(0, 1, n_vals)

counterfactual_trace_8_2 = dict.fromkeys(continents)
with m8_2:
    for continent, continent_idx in continents_idx_mapper.items():
        pm.set_data(
            {
                "rugged_std": rugged_std_vals,
                "continent_idx": np.full(n_vals, continent_idx, dtype=int),
            },
            coords={
                "country": range(n_vals),
            },
        )
        counterfactual_trace_8_2[continent] = pm.sample_posterior_predictive(
            trace_8_2,
            var_names=["mu", "log_gdp_std"],
            predictions=True,
            random_seed=RANDOM_SEED,
        )

In [None]:
def plot_posterior_linear_model(
    x, mu, hdi_prob=0.89, smooth=True, smooth_kwargs=None, color=None, ax=None
):
    """Plot the posterior mean of ``mu`` with an HDI band around it.

    Parameters
    ----------
    x : array-like
        Predictor values that the first non-sample axis of ``mu`` corresponds to.
    mu : xarray.DataArray
        Posterior draws with ``chain`` and ``draw`` dimensions. The HDI is read
        from the ``mu`` variable of the ``az.hdi`` result, so the array is
        expected to be named ``"mu"``.
    hdi_prob : float, optional
        Probability mass of the HDI band. Defaults to 0.89.
    smooth : bool, optional
        If True, the mean and the HDI bounds are interpolated and smoothed
        with ``smooth_plot_data``; if False, they are drawn at the raw ``x``.
    smooth_kwargs : dict, optional
        Passed on to ``smooth_plot_data``.
    color : optional
        Colour used for both the line and the band.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new 6x5 figure is created when omitted.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 5))

    # without smoothing, both curves are drawn against the raw x values;
    # previously x_plot was undefined in that case and raised NameError
    x_plot = x

    # plot mu
    mu_mean = mu.mean(dim=["chain", "draw"])
    if smooth:
        x_plot, mu_mean = smooth_plot_data(x, mu_mean, smooth_kwargs)
    ax.plot(
        x_plot,
        mu_mean,
        color=color,
    )

    # plot hdi around mu
    mu_hdi = az.hdi(mu, hdi_prob=hdi_prob).mu
    if smooth:
        x_plot, mu_hdi = smooth_plot_data(x, mu_hdi, smooth_kwargs)

    # column 0 is used as the lower bound of the band and column 1 as the upper
    ax.fill_between(
        x_plot,
        mu_hdi[:, 0],
        mu_hdi[:, 1],
        color=color,
        alpha=0.3,
    )

    return ax

In [None]:
fig, ax = plt.subplots(figsize=(7, 6))

# observed data, coloured by continent
sns.scatterplot(
    rugged,
    x="rugged_std",
    y="log_gdp_std",
    hue="continent",
    hue_order=continents,
    ax=ax,
)

# overlay one posterior mean line and HDI band per continent;
# colors[idx] follows the order of the dictionary keys (the continents list)
for idx, (continent, trace) in enumerate(counterfactual_trace_8_2.items()):
    plot_posterior_linear_model(
        x=trace.predictions_constant_data.rugged_std,
        mu=trace.predictions.mu,
        smooth=True,
        color=colors[idx],
        ax=ax,
    )

ax.set(
    xlabel="Ruggedness",
    ylabel="Log GDP (prop of mean)",
    title="Linear model with 'Africa' indicator",
)
ax.legend(title="Continent");

### 8.1.3 Adding an interaction does work

Now add an interaction term to the model

In [None]:
coords = {
    "country": rugged.country,
    "continent": continents,
}
with pm.Model(coords_mutable=coords) as m8_3:
    # data
    rugged_std = pm.MutableData("rugged_std", rugged.rugged_std, dims="country")
    rugged_std_mean = pm.MutableData("rugged_std_mean", rugged.rugged_std.mean())
    # cont_africa is used directly as the index into alpha and beta
    continent_idx = pm.MutableData("continent_idx", rugged.cont_africa, dims="country")

    # priors: both the intercept and the slope now vary by continent
    alpha = pm.Normal("alpha", mu=1, sigma=0.1, dims="continent")
    beta = pm.Normal("beta", mu=0, sigma=0.3, dims="continent")
    sigma = pm.Exponential("sigma", lam=1)

    # model
    mu = pm.Deterministic(
        "mu",
        alpha[continent_idx] + beta[continent_idx] * (rugged_std - rugged_std_mean),
        dims="country",
    )

    # likelihood
    log_gdp_std = pm.Normal(
        "log_gdp_std", mu=mu, sigma=sigma, observed=rugged.log_gdp_std, dims="country"
    )

    # sample the posterior, storing the pointwise log-likelihood
    # so the trace can be used for WAIC / PSIS model comparison
    trace_8_3 = pm.sample(
        random_seed=RANDOM_SEED,
        idata_kwargs={"log_likelihood": True},
    )

In [None]:
# render the model structure as a graph
m8_3.to_graphviz()

And look at the posterior distributions

In [None]:
# summary statistics with 89% HDIs; "~mu" leaves out the per-country mu values
az.summary(
    trace_8_3,
    var_names=["~mu"],
    kind="stats",
    hdi_prob=0.89,
    round_to=2,
)

Now we see that Africa has a positive slope, while outside of Africa has a negative slope.

Let's inspect the WAIC

In [None]:
# compare all three models by WAIC, reported on the deviance scale
az.compare(
    {
        "m8.1": trace_8_1b,
        "m8.2": trace_8_2,
        "m8.3": trace_8_3,
    },
    ic="waic",
    scale="deviance",
)

We again get warnings.

In [None]:
# pointwise=True keeps the per-observation values (pareto_k, waic_i) used below
psis_8_3 = az.loo(trace_8_3, scale="deviance", pointwise=True)
waic_8_3 = az.waic(trace_8_3, scale="deviance", pointwise=True)

In [None]:
# display the PSIS-LOO result for m8_3
psis_8_3

Again no warnings from PSIS

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))

# one point per observation: Pareto k (x) against pointwise WAIC (y),
# coloured by continent; hue_order pins the colour assignment so that it
# matches the other continent-coloured figures in this notebook
sns.scatterplot(
    x=psis_8_3.pareto_k,
    y=waic_8_3.waic_i,
    hue=rugged["continent"],
    hue_order=continents,
    ax=ax,
)

ax.set(
    xlabel="Pareto $k$",
    ylabel="WAIC",
    title="WAIC and Pareto $k$ values",
);

### 8.1.4 Plotting the interaction

Let's plot the posteriors for the interaction model

In [None]:
# generate counterfactual ruggedness values
n_vals = 100
rugged_std_vals = np.linspace(0, 1, n_vals)

counterfactual_trace_8_3 = dict.fromkeys(continents)
with m8_3:
    for continent, continent_idx in continents_idx_mapper.items():
        pm.set_data(
            {
                "rugged_std": rugged_std_vals,
                "continent_idx": np.full(n_vals, continent_idx, dtype=int),
            },
            coords={
                "country": range(n_vals),
            },
        )
        counterfactual_trace_8_3[continent] = pm.sample_posterior_predictive(
            trace_8_3,
            var_names=["mu", "log_gdp_std"],
            predictions=True,
            random_seed=RANDOM_SEED,
        )

In [None]:
fig, ax = plt.subplots(figsize=(7, 6))

# observed data, coloured by continent
sns.scatterplot(
    rugged,
    x="rugged_std",
    y="log_gdp_std",
    hue="continent",
    hue_order=continents,
    ax=ax,
)

# overlay one posterior mean line and HDI band per continent;
# colors[idx] follows the order of the dictionary keys (the continents list)
for idx, (continent, trace) in enumerate(counterfactual_trace_8_3.items()):
    plot_posterior_linear_model(
        x=trace.predictions_constant_data.rugged_std,
        mu=trace.predictions.mu,
        smooth=True,
        color=colors[idx],
        ax=ax,
    )

# this figure shows the interaction model m8_3, so the title says so
# (it previously repeated the indicator model's title)
ax.set(
    xlabel="Ruggedness",
    ylabel="Log GDP (prop of mean)",
    title="Linear model with 'Africa' interaction",
)
ax.legend(title="Continent");