# Bayesian Hierarchical Linear Regression Exercises

In [1]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns
from scipy import stats

az.style.use("arviz-darkgrid")
RANDOM_SEED = 8265
# Seeds NumPy's legacy global RNG only; samplers/generators that take their
# own seed argument still need RANDOM_SEED passed explicitly.
np.random.seed(RANDOM_SEED)

# Print floats with 2 decimals.
np.set_printoptions(2)

fish_market = pd.read_csv("data/fish-market.csv")
fish_market = fish_market.drop(["Length2", "Length3"], axis="columns")
fish_market["log_width"] = np.log(fish_market.Width)
fish_market["log_height"] = np.log(fish_market.Height)
fish_market["log_length"] = np.log(fish_market.Length1)
# The data may contain zero weights (they are filtered out just below). Their
# log is undefined, so mask them to NaN first instead of letting np.log
# produce -inf together with a divide-by-zero RuntimeWarning.
fish_market["log_weight"] = np.log(fish_market.Weight.where(fish_market.Weight > 0))
fish_reduced = fish_market[fish_market["Weight"] != 0].copy()

# Hold out a random 10% of the usable fish as a test set (seeded, so the split
# is reproducible); all remaining rows form the training set.
fish_test = fish_reduced.sample(frac=0.1, random_state=RANDOM_SEED).sort_index()
test_idx = fish_test.index
fish_train = fish_reduced.loc[fish_reduced.index.difference(test_idx)]

# Integer species code for every training row, plus the sorted species labels
# that serve as the "species" coordinate of the models.
species_idx, species = fish_train.Species.factorize(sort=True)
COORDS = {
    "slopes": ["width_effect"],  # extend with "height_effect", "length_effect" if needed
    "species": species,
}

  result = getattr(ufunc, method)(*inputs, **kwargs)


## Exercise 1: Real world examples of hierarchies

In the lesson we mentioned that nested and grouped data are very common. Think of 5 real-world examples where hierarchical models would be a good fit. Bonus points if some examples come from datasets you have worked with.

## Exercise 2: Explore Distributions Over Parameters

In this exercise we want to build stronger intuitions about what the different parameters of the model do. Towards this goal, you will generate plots using different settings and analyze the outputs.
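Specifically, try the following settings and describe what changes. In terms of the arguments of `gen_hier_plot` below, `global_mu` is `mu`, `global_sigma` is `sigma`, `local_sigma` is `eps`, and `n_local` is `groups`: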
* global_mu = -0.5 ; global_sigma = 1 ; local_sigma = 0.1 ; n_local = 8
* global_mu = 0.0 ; global_sigma = 0.5 ; local_sigma = 0.5 ; n_local = 8
* global_mu = 0.0 ; global_sigma = 0.5 ; local_sigma = 0.1 ; n_local = 16
* global_mu = 0.0 ; global_sigma = 10. ; local_sigma = 0.1 ; n_local = 8
* global_mu = 0.0 ; global_sigma = 0.001 ; local_sigma = 0.1 ; n_local = 8

In [None]:
def gen_hier_plot(mu=0, sigma=0.5, eps=0.1, groups=8, seed=2):
    """
    Sample hierarchical data and plot it.

    Draws ``groups`` group means from a global Normal(mu, sigma), then draws
    1000 observations per group from Normal(group_mean, eps). The top panel
    shows the global density with the sampled group means marked at y=0; the
    bottom panel shows a density plot of each group's draws.

    Parameters
    ----------
    mu : float
        Mean of the global (group-level) distribution.
    sigma : float
        Standard deviation of the global distribution, i.e. how spread out
        the group means are.
    eps : float
        Standard deviation of the observations within each group.
    groups : int
        Number of groups to sample.
    seed : int
        Seed for the NumPy random generator, so the plot is reproducible.
    """
    from numpy.random import default_rng

    rng = default_rng(seed)

    group_dist = stats.norm(mu, sigma)
    mus = group_dist.rvs(groups, random_state=rng)
    # Shape (1000, groups): column i holds the draws for group i.
    data = stats.norm(mus, eps).rvs(size=(1000, groups), random_state=rng)
    x = np.linspace(data.min() - .5, data.max() + .5, 1000)

    # One color per group, computed once instead of on every loop iteration.
    palette = sns.color_palette(n_colors=groups)

    _fig, axs = plt.subplots(nrows=2, sharex=True)
    axs[0].plot(x, group_dist.pdf(x))
    axs[0].set(title=f"global distribution\nmu={mu} sigma={sigma}, eps={eps}, groups={groups}", ylabel="Belief")
    axs[1].set(title="local distributions", xlabel="x", ylabel="Belief")
    for i in range(groups):
        color = palette[i]
        axs[0].plot(mus[i], 0, ".", ms=8, color=color)
        az.plot_dist(data[:, i], color=color, ax=axs[1])

## Exercise 3: Effect of group size on estimates

We saw that for Whitefish, where we had few data points, we got strong shrinkage, while for Perch, where we had much more data, we got less shrinkage. But how exactly does the amount of data relate to the strength of shrinkage? In this exercise, you will explore this relationship more systematically. Towards this goal, you will create a plot with the number of data points for a species on the x-axis and the beta estimate for that species (with error bars reflecting posterior uncertainty) on the y-axis. As we are interested in how strong the shrinkage towards the global mean is, include the global beta mu, `β_mu` (here you can just take its posterior mean), as a horizontal line (you can use `plt.axhline` for this).

It is interesting to compare this to the unpooled model, where shrinkage should be much weaker (and directed towards the prior mean of 0 rather than towards a global mean), so create the same plot for the unpooled model.

What do you see in terms of how strong shrinkage is in relation to how much data we have per species? Does this make sense? Describe the effect in your own words!

In [None]:
with pm.Model(coords=COORDS) as unpooled_intercept_unpooled_beta:
    # data, held in mutable containers so they can later be swapped out with
    # pm.set_data (e.g. to predict on the held-out test fish)
    log_width = pm.MutableData("log_width", fish_train.log_width.values)
    log_weight = pm.MutableData("log_weight", fish_train.log_weight.values)
    # Mutable like the other containers (and like the hierarchical model):
    # new log_width values need a matching species index, and pm.set_data
    # refuses to update a ConstantData container.
    species_idx_ = pm.MutableData("species_idx", species_idx)

    # priors: each species gets its own, independent intercept and slope
    intercept = pm.Normal("intercept", sigma=1.0, dims="species")
    β = pm.Normal("β", sigma=0.5, dims="species")

    # linear regression on the log scale (stored in the trace as "mu")
    mu = pm.Deterministic(
        "mu", intercept[species_idx_] + β[species_idx_] * log_width
    )

    # observational noise
    sigma = pm.HalfNormal("sigma", 1.0)

    # likelihood
    log_obs = pm.Normal(
        "log_obs",
        mu=mu,
        sigma=sigma,
        observed=log_weight,
    )

    # sampling; PyMC does not use NumPy's global seed, so pass RANDOM_SEED
    # explicitly to make the draws reproducible
    idata_unpooled_intercept_unpooled_beta = pm.sample(random_seed=RANDOM_SEED)


with pm.Model(coords=COORDS) as hierarchical_intercept_hierarchical_beta:
    # data, held in mutable containers so they can later be swapped out with
    # pm.set_data (e.g. to predict on the held-out test fish)
    log_width = pm.MutableData("log_width", fish_train.log_width.values)
    log_weight = pm.MutableData("log_weight", fish_train.log_weight.values)
    species_idx_ = pm.MutableData("species_idx", species_idx)

    # global (population-level) priors for the intercepts
    intercept_mu = pm.Normal("intercept_mu", sigma=3.0)
    intercept_sigma = pm.HalfNormal("intercept_sigma", sigma=1.0)

    # per-species intercepts, drawn around the shared intercept_mu
    intercept = pm.Normal(
        "intercept", mu=intercept_mu, sigma=intercept_sigma, dims="species"
    )

    # global (population-level) priors for the slopes
    β_mu = pm.Normal("β_mu", sigma=3.0)
    β_sigma = pm.HalfNormal("β_sigma", sigma=1.0)

    # per-species slopes, drawn around the shared β_mu
    β = pm.Normal("β", mu=β_mu, sigma=β_sigma, dims="species")

    # linear regression on the log scale (not stored in the trace)
    mu = intercept[species_idx_] + β[species_idx_] * log_width

    # observational noise
    eps = pm.HalfNormal("eps", 1.0)

    # likelihood
    log_obs = pm.Normal(
        "log_obs",
        mu=mu,
        sigma=eps,
        observed=log_weight,
    )

    # Hit the Inference Button(TM). The higher target_accept means smaller
    # sampler steps, which typically helps with hierarchical posteriors.
    # PyMC does not use NumPy's global seed, so pass RANDOM_SEED explicitly
    # to make the draws reproducible.
    idata_hierarchical_intercept_hierarchical_beta = pm.sample(
        target_accept=0.95, random_seed=RANDOM_SEED
    )

## Exercise 4: Unpooled, pooled, hierarchical variations

Our linear regression model has 2 parameters per species: an intercept and a beta (slope). In the lesson we put a hierarchy on both of these. However, we can mix and match freely and decide per parameter whether it should be pooled, unpooled, or hierarchical. Each of these choices imposes different constraints on the parameters, and that's what you'll explore in this exercise. Towards this goal, you will build several variations of the model, which will also help you get more familiar with the code. Keep these models and their idata objects around, as you will analyze their outputs in the next exercise.

Build the following versions of the model:
* Replace the hierarchy on the intercepts with a pooled style
* Replace the hierarchy on the intercepts with an unpooled style
* Replace the hierarchy on the slopes with a pooled style
* Replace the hierarchy on the slopes with an unpooled style
  
Run each of these different models and make sure that everything converged well.

## Exercise 5: Effect of Hierarchy

Let's examine what type of regression patterns these different models produce. For example, if you fix all slopes to be the same (i.e. pooled), but let the intercepts vary (i.e. hierarchical), what do you think the regression lines would look like?

To find out, generate a plot that shows the regression lines of all species in a single figure. Including the uncertainty would make this plot too messy, so just plot the posterior-mean regression line for each species.

Now generate this plot for the different variants of the model you created in the previous exercise. What do you observe in each variant, what are the similarities and differences? Do the patterns make sense in terms of how the model is structured? Why?

In [None]:
import xarray as xr

def plot_pred(idata, ax, plot_multiple_draws=False, colors=None, title=""):
    """
    Helper function to plot regression lines from the posterior on top of data points.

    For each label in the module-level ``species``, scatters that species'
    (log_width, log_weight) points from ``fish_reduced`` and draws the
    regression line given by the posterior means of ``intercept`` and ``β``
    over log_width in [0, 2.5].

    Parameters
    ----------
    idata : arviz.InferenceData
        Posterior containing ``intercept`` and ``β``. Lines are selected per
        species, so at least one of them needs a ``species`` dimension.
    ax : matplotlib Axes
        Axes to draw on.
    plot_multiple_draws : bool
        If True, also draw faint regression lines for 50 random posterior
        draws to visualize uncertainty.
    colors : sequence, optional
        One color per species; defaults to a seaborn palette.
    title : str
        Axes title; not set when empty.

    Returns
    -------
    The ``ax`` that was drawn on.
    """
    x = xr.DataArray(np.linspace(0, 2.5, 150), dims=["x_plot"])
    post_mean = idata.posterior.mean(("chain", "draw"))
    y_mu = post_mean["intercept"] + x * post_mean["β"]

    if plot_multiple_draws:
        post_subset = az.extract(idata, num_samples=50)
        y_reg = post_subset["intercept"] + x * post_subset["β"]

    if colors is None:
        colors = sns.color_palette(n_colors=len(species))

    for i, species_i in enumerate(species):
        color = colors[i]
        fish_spec = fish_reduced[fish_reduced.Species == species_i]
        ax.scatter(fish_spec["log_width"], fish_spec["log_weight"], color=color)
        ax.plot(x, y_mu.sel(species=species_i), color=color, alpha=0.5, lw=2, label=species_i)

        if plot_multiple_draws:
            # Put "sample" last so each column is one draw's line; matplotlib
            # plots every column of a 2-D y as its own line.
            ax.plot(x, y_reg.sel(species=species_i).transpose(..., "sample"), color=color, alpha=0.1)

    if title:
        ax.set_title(title)

    return ax