In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

## Introduction

In this chapter, we are going to look at the issue of model parametrization.

The parametrization of a model has knock-on effects when sampling.
In the example from the previous chapter,
you saw that we got "divergences" warnings
when hitting Thomas Wiecki's Inference Button (tm).
These divergences show up when the sampler has problems
sampling from the posterior distribution.

What kind of problems, you might ask?

Well, we'll first have to take a detour into the "shape" of joint distributions!

## Funneling down the river

To illustrate what we mean by the "shape" of joint distributions,
I am going to introduce you to **Neal's Funnel**,
named after Professor Radford Neal,
who proposed this as a _particularly extreme_ example
of pathological likelihood geometries.
Neal's original funnel has a slightly different form and more dimensions of $x$,
but I have simplified the example here to make it easier to follow.

In particular, we're going to look at the following joint distribution
between two random variables, $x$ and $v$:

$$v \sim N(0, 1)$$
$$x \sim N(0, e^v)$$

Let's see this simplified Neal's Funnel implemented in code.

In [None]:
from scipy.stats import norm, expon
import numpy as np

In [None]:
import pandas as pd
data = []
n_dims = 1
for i in range(1000):
    v = norm(0, 1).rvs(1)
    x = norm(0, np.exp(v)).rvs(n_dims)
    data.append(np.hstack([v, x]))
data = pd.DataFrame(data)
data.columns = ["v"] + [fr"x_{i+1}" for i in range(n_dims)]

In [None]:
import seaborn as sns
ax = data.plot(x="x_1", y="v", kind="scatter")
ax.set_xlabel(r"$x_1$")
sns.despine()

Let's go through why the joint distribution is shaped as it is.
(In retrospect,
this is one of those situations that makes a ton of sense
only after one sees it.)

In this joint distribution between $x_1$ and $v$,

$$v \sim N(0, 1)$$
$$x \sim N(0, e^v)$$

when $v$ is negative, the scale term of $x$ is very small,
because $e^v$ is less than $1$,
and hence the numbers drawn are tightly distributed around the mean $0$.
When $v$ is positive, the scale term of $x$ grows exponentially,
and so the numbers drawn are extremely variable.
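
To put rough numbers on this, here is a small sketch, separate from the notebook's own cells, that draws from the same simplified funnel in a vectorized way and compares the spread of $x$ when $v$ is low versus high. The seed, the sample size, and the cutoffs of $\pm 1$ are arbitrary choices made purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)  # arbitrary seed, for reproducibility only

# Draw from the simplified funnel: v ~ N(0, 1), x ~ N(0, e^v),
# where e^v is used as the scale (standard deviation), as in the code above.
v = rng.normal(loc=0.0, scale=1.0, size=10_000)
x = rng.normal(loc=0.0, scale=np.exp(v))

# Compare the spread of x in the neck (v < -1) and in the mouth (v > 1).
std_neck = x[v < -1].std()
std_mouth = x[v > 1].std()
print(f"std of x when v < -1: {std_neck:.3f}")
print(f"std of x when v > 1:  {std_mouth:.3f}")
```

The spread in the neck should come out much smaller than the spread in the mouth, which is exactly what gives the scatterplot its funnel shape.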

As you can see, the joint distribution samples are shaped like a funnel.
Not only that, if you imagine the _density_ of points sampled,
there is a third dimension that rises up from the screen to you.
The contours look something like this:

In [None]:
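def neals_funnel_likelihood(v, x):
    """Joint density of the simplified funnel at (v, x)."""
    v_like = norm(0, 1).pdf(v)  # density of v under N(0, 1)
    x_like = norm(0, np.exp(v)).pdf(x)  # density of x given v, with scale e^v
    return v_like * x_like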

In [None]:
import matplotlib.pyplot as plt


vs = np.linspace(-3, 3, 1000)
xs = np.linspace(-20, 20, 1000)
X, V = np.meshgrid(xs, vs)
L = neals_funnel_likelihood(V, X)
fig, ax = plt.subplots()

ax.contour(X, V, L, levels=10, cmap="viridis");
ax.set_xlabel("x")
ax.set_ylabel("v")
sns.despine()


Now that _really_ looks like a funnel!
As before, $x$ and $v$ are plotted against each other,
but now you can see that the likelihood of jointly drawing any pair of values $(x, v)$
is heavily concentrated in the neck of the funnel.

### Sampling difficulties

Now, I'd like you to imagine you're an MCMC sampler.
(To understand what goes on underneath the hood,
check out the chapter on MCMC sampling.)
Not a fancy one, just a simple one.
The rule by which you propose a new value to sample
is governed by a standard Gaussian distribution $N(0, 1)$.
(That is your proposal distribution.)
Your goal is to sample new points that, _most_ of the time,
fall within regions of high likelihood,
while only occasionally allowing yourself to step outside them.

Now imagine that you fell into the funnel,
such that you sampled $x=0.0003$ and $v=-2.89275$ on your last step.
With a simple $N(0, 1)$ proposal distribution,
most of the points you propose for $x$
are going to fall outside of the region of high likelihood.
You're going to be mostly stuck!
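
To make "mostly stuck" concrete, here is a minimal sketch of that thought experiment. It assumes a plain Metropolis acceptance rule, holds $v$ fixed, and only proposes new values of $x$ with a $N(0, 1)$ step. The helper `funnel_log_density`, the seed, and the number of proposals are all made up for this illustration; this is not how PyMC3 samples internally.

```python
import numpy as np


def funnel_log_density(v, x):
    """Log density of the simplified funnel: v ~ N(0, 1), x ~ N(0, e^v)."""
    scale = np.exp(v)
    log_p_v = -0.5 * v**2 - 0.5 * np.log(2 * np.pi)
    # log N(x; 0, scale) = -0.5 * (x / scale)^2 - log(scale) - 0.5 * log(2 * pi),
    # and log(scale) is simply v.
    log_p_x = -0.5 * (x / scale) ** 2 - v - 0.5 * np.log(2 * np.pi)
    return log_p_v + log_p_x


rng = np.random.default_rng(0)  # arbitrary seed
v_current, x_current = -2.89275, 0.0003
n_proposals = 10_000

# Propose new x values with a N(0, 1) step, holding v fixed.
x_proposed = x_current + rng.normal(0.0, 1.0, size=n_proposals)

# Metropolis rule: accept with probability min(1, p(proposed) / p(current)).
log_ratio = funnel_log_density(v_current, x_proposed) - funnel_log_density(
    v_current, x_current
)
accepted = np.log(rng.uniform(size=n_proposals)) < log_ratio
acceptance_rate = accepted.mean()
print(f"acceptance rate: {acceptance_rate:.3f}")
```

Only a small fraction of the proposals should be accepted, because a step of size around $1$ is enormous compared to the width of the funnel's neck at this value of $v$.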

### Alternative reparametrization

A reparametrization of the model makes things a lot easier.

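Instead of the following (where $v$ now has a scale of $3$, as in Neal's original funnel, rather than the $1$ used above):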

$$v \sim N(0, 3)$$
$$x \sim N(0, e^v)$$

We switch to:

$$\hat{v} \sim N(0, 1)$$
$$ v = 3 \hat{v}  + 0$$
$$\hat{x} \sim N(0, 1)$$
$$ x = \hat{x}e^v + 0$$

### What manner of monstrosity is this?

Fret not! What is going on here can be explained.

One little trick we have up our sleeves is this:
Every Gaussian distribution $N(\mu, \sigma)$
can be generated by sampling from an $N(0, 1)$ distribution,
multiplying by $\sigma$, and adding $\mu$!

So in the above model,
$\hat{v}$ is sampled from a standard Gaussian,
then multiplied by the scale term $3$ and added to the mean term $0$,
thereby regenerating the original $v \sim N(0, 3)$ distribution.

The same goes for $x$:
$\hat{x}$ is sampled from a standard Gaussian,
then multiplied by the scale term $e^v$ and added to the mean term $0$,
thereby regenerating the original $x \sim N(0, e^v)$ distribution!
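
If you'd like to convince yourself of the shift-and-scale trick, the following sketch compares direct draws from $N(\mu, \sigma)$ against shifted and scaled draws from $N(0, 1)$. The values of `mu` and `sigma`, the seed, and the sample size are arbitrary and chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)  # arbitrary seed
mu, sigma = 2.0, 3.0  # arbitrary illustrative values

# Draw directly from N(mu, sigma)...
direct = rng.normal(loc=mu, scale=sigma, size=100_000)

# ...and by scaling and shifting standard Gaussian draws.
standard = rng.normal(loc=0.0, scale=1.0, size=100_000)
transformed = mu + sigma * standard

print(f"direct:      mean={direct.mean():.2f}, std={direct.std():.2f}")
print(f"transformed: mean={transformed.mean():.2f}, std={transformed.std():.2f}")
```

Both sets of draws should show a mean close to `mu` and a standard deviation close to `sigma`.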

### What's the effect?

The effect of a reparametrization for the MCMC samplers used in PyMC3 is tremendous!

Instead of having to sample from an $N(0, e^v)$ distribution for the random variable $x$,
the sampler can instead propose new steps for $\hat{x}$ in a standard $N(0, 1)$ space,
which is less likely to generate rejected samples.
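
As a rough illustration, we can revisit the point from the sampling-difficulties section ($x=0.0003$, $v=-2.89275$) and express it in the $\hat{x}$ space instead. The sketch below assumes a plain Metropolis acceptance rule with a $N(0, 1)$ proposal on $\hat{x}$ alone; the helper `std_normal_log_density`, the seed, and the number of proposals are invented for this example, and this is not a description of what PyMC3 does internally.

```python
import numpy as np


def std_normal_log_density(z):
    """Log density of a standard Gaussian N(0, 1)."""
    return -0.5 * z**2 - 0.5 * np.log(2 * np.pi)


rng = np.random.default_rng(0)  # arbitrary seed
v_current, x_current = -2.89275, 0.0003

# The same point, expressed in the non-centered space: x = x_hat * e^v.
x_hat_current = x_current / np.exp(v_current)

# Propose new x_hat values with a N(0, 1) step.
n_proposals = 10_000
x_hat_proposed = x_hat_current + rng.normal(0.0, 1.0, size=n_proposals)

# Metropolis rule: accept with probability min(1, p(proposed) / p(current)).
log_ratio = std_normal_log_density(x_hat_proposed) - std_normal_log_density(
    x_hat_current
)
accepted = np.log(rng.uniform(size=n_proposals)) < log_ratio
acceptance_rate = accepted.mean()
print(f"acceptance rate in x_hat space: {acceptance_rate:.3f}")
```

Because the proposal and the target now live on the same scale, a clear majority of the proposals should be accepted, no matter how deep in the funnel $v$ happens to be.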

### But the analogy feels a bit contrived...

If you're feeling this way, you're not alone.

Most data problems don't look _this_ pathological.
However, it is an extremely illustrative example of what happens
when an MCMC sampler rejects most of the proposed samples.

## Reparametrizing Logits

Let's see how all of those lessons can be applied in our ice cream shop problem!
After all, in that prior chapter on hierarchical models,
we did suffer from divergences in the sampling process,
indicating to us that there may have been pathologies 
in the shape of the joint likelihoods.

Let's start by re-loading the trace from before.

In [None]:
import arviz as az
from pyprojroot import here

trace_original = az.from_netcdf(here() / "data/ice_cream_shop_hierarchical_posterior.nc")

In [None]:
az.plot_trace(trace_original, var_names=["p_owner"]);

In the trace plot, you might see some of the places where divergences happened.
In particular, for owner 7, on one of the chains, the divergences are pretty obvious.
That's where the sampler _really_ got stuck in a region of bad geometry.

Here's another way to visualize why this _might_ be happening:
let's go ahead and plot the $p$ terms and scale terms against each other.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_posterior_pair(trace, rv1="p_owner", rv2="logit_p_owner_scale"):
    locations = trace.posterior[rv1].to_dataframe().unstack(-1)
    scales = trace.posterior[rv2].to_dataframe().unstack(-1)
    
    for i in range(9):
        plt.scatter(locations[(rv1, i)], scales[(rv2, i)], alpha=0.3, label=f"{i}")
    plt.xlabel(rv1)
    plt.ylabel(rv2)
    sns.despine()
    plt.legend();

In [None]:
plot_posterior_pair(trace_original)

As you can see, the geometry isn't _so_ pathological,
but a funnel-like shape can be observed for a subset of shop owners.
(We have to keep in mind that these are not exact posteriors,
but samples taken from what is otherwise a fairly biased sampling procedure
because of the divergences.)
Let's take a look at a reparametrized version of the same model that we wrote before.
But first, just to jog your memory:

In [None]:
from bayes_tutorial.solutions.hierarchical import ice_cream_hierarchical_model

ice_cream_hierarchical_model??

The line that looks like a good candidate for reparametrization
is the one that defines the `logit_p_shop` RV,
because both its $\mu$ and its $\sigma$ are themselves random variables.
That may well be where the pathologies lie.

## Load Data



In [None]:
from bayes_tutorial.data import load_ice_cream
data = load_ice_cream()
data.head()

## Build reparametrized model

Remember that the key idea here is to convert:

$$N(\mu, \sigma)$$ 

into its equivalent:

$$N(0, 1) \cdot \sigma + \mu$$

Take a look at the code below to see how it's done.
I've done my best annotating the location where the reparametrization happens.

In [None]:
import pymc3 as pm

n_owners = len(data["owner_idx"].unique())
with pm.Model() as model:
    logit_p_overall = pm.Normal("logit_p_overall", mu=0, sigma=1)
    logit_p_owner_mean = pm.Normal(
        "logit_p_owner_mean",
        mu=logit_p_overall,
        sigma=1,
        shape=(n_owners,),
    )
    logit_p_owner_scale = pm.Exponential(
        "logit_p_owner_scale", lam=1 / 5.0, shape=(n_owners,)
    )
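    # Non-centered parametrization starts here.
    # Rather than giving logit_p_shop a Normal with random mu and sigma,
    # sample a standard Normal "raw" variable for each shop...
    logit_p_shop_raw = pm.Normal(
        "logit_p_shop_raw",
        mu=0,
        sigma=1,
        shape=(len(data),),
    )

    # ...then multiply by the owner-level scale and add the owner-level mean.
    logit_p_shop = pm.Deterministic(
        "logit_p_shop",
        logit_p_shop_raw * logit_p_owner_scale[data["owner_idx"]]
        + logit_p_owner_mean[data["owner_idx"]],
    )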

    p_overall = pm.Deterministic("p_overall", pm.invlogit(logit_p_overall))
    p_shop = pm.Deterministic("p_shop", pm.invlogit(logit_p_shop))
    p_owner = pm.Deterministic("p_owner", pm.invlogit(logit_p_owner_mean))
    like = pm.Binomial(
        "like",
        n=data["num_customers"],
        p=p_shop,
        observed=data["num_favs"],
    )
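
If the indexing inside `logit_p_shop` looks opaque, here is a plain NumPy sketch of the same scale-and-shift step. All of the arrays below (`owner_idx`, `owner_mean`, `owner_scale`) are small hypothetical stand-ins, not the ice cream data or the model's actual random variables.

```python
import numpy as np

rng = np.random.default_rng(7)  # arbitrary seed

# Hypothetical stand-ins: 3 owners, 5 shops.
owner_idx = np.array([0, 0, 1, 2, 2])    # which owner each shop belongs to
owner_mean = np.array([-1.0, 0.0, 1.0])  # one location per owner
owner_scale = np.array([0.5, 1.0, 2.0])  # one scale per owner

# One standard Gaussian "raw" draw per shop...
shop_raw = rng.normal(0.0, 1.0, size=len(owner_idx))

# ...scaled and shifted by the parameters of the shop's owner.
# Indexing with owner_idx expands the per-owner arrays to one entry per shop.
shop_logit = shop_raw * owner_scale[owner_idx] + owner_mean[owner_idx]

# The inverse logit maps each value back onto the (0, 1) probability scale.
shop_p = 1.0 / (1.0 + np.exp(-shop_logit))
print(shop_p)
```

Each shop picks up its own owner's mean and scale through the index array, which mirrors what the `pm.Deterministic` expression does with `data["owner_idx"]`.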

In [None]:
with model:
    trace = pm.sample(2000)
    trace = az.from_pymc3(
        trace,
        coords={
            "p_shop_dim_0": data["shopname"],
            "logit_p_shop_transformed": data["shopname"],
            "logit_p_shop_dim_0": data["shopname"],
            "logit_p_owner_scale_dim_0": data["owner_idx"].sort_values().unique(),
            "p_owner_dim_0": data["owner_idx"].sort_values().unique(),
            "logit_p_owner_mean": data["owner_idx"].sort_values().unique(),
        },
    )

## Inspect traces

Let's now compare the traces.
As with the prior chapter on hierarchical models,
for simplicity's sake, we're going to look at only the owner $\mu$ and $p$,
rather than at all of the shops' $\mu$ and $p$.

In [None]:
az.plot_trace(trace, var_names=["logit_p_owner_mean"]);

No more divergences! Also, the traces look more like hairy caterpillars again. Compare that to the traces from before:

In [None]:
az.plot_trace(trace_original, var_names=["logit_p_owner_mean"]);

It should be clear that the divergences are gone.

### What does this mean for me?

Long story short: look for opportunities to reparametrize!

In particular, if you have a Gaussian in your model,
as we do here,
and divergences show up during sampling,
that is a good cue to try a non-centered parametrization.

## Further Reading

- [Stan User Guide on Reparametrizations](https://mc-stan.org/docs/2_18/stan-users-guide/reparameterization-section.html)
- [PyMC3 docs on diagnosing biased inferences with divergences](https://docs.pymc.io/notebooks/Diagnosing_biased_Inference_with_Divergences.html)