<a href="https://colab.research.google.com/github/flyaflya/persuasive/blob/main/demoNotebooks/diagnosticsWalkthrough.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install matplotlib numpyro daft --upgrade

In [None]:
import xarray as xr
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
import arviz as az
import pandas as pd
## load the tickets data, then total the tickets issued on each Wednesday
url = "https://raw.githubusercontent.com/flyaflya/persuasive/main/tickets.csv"
ticketsDF = pd.read_csv(url, parse_dates=["date"])
wedTicketsDF = (
    ticketsDF
    .assign(dayOfWeek=lambda df: df["date"].dt.day_name())
    ## where() blanks out every non-Wednesday row; dropna() then removes
    ## all rows that contain missing values
    .where(lambda df: df["dayOfWeek"] == "Wednesday")
    .dropna()
    .groupby("date")
    .agg(numTickets=("daily_tickets", "sum"))
)

In [None]:
#@title Tickets DAG
import matplotlib.pyplot as plt
import pandas as pd
from functools import partial, partialmethod
import daft   ### %pip install -U git+https://github.com/daft-dev/daft.git
from numpy.random import default_rng
import numpy as np

class dag(daft.PGM):
    """daft.PGM with pre-styled ``add_node`` shortcuts for drawing DAGs.

    Each ``*Node`` attribute is ``daft.PGM.add_node`` with default styling
    keywords bound through ``partialmethod``; keywords given at call time
    replace the bound defaults.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    obsNode = partialmethod(
        daft.PGM.add_node,
        scale=1.3,
        aspect=2.4,
        fontsize=10,
        plot_params={"facecolor": "cadetblue"},
    )
    decNode = partialmethod(
        daft.PGM.add_node,
        aspect=2.2,
        fontsize=10,
        shape="rectangle",
        plot_params={"facecolor": "thistle"},
    )
    detNode = partialmethod(
        daft.PGM.add_node,
        aspect=5.4,
        fontsize=9.25,
        alternate=True,
        plot_params={"facecolor": "aliceblue"},
    )
    latNode = partialmethod(
        daft.PGM.add_node,
        scale=1.3,
        aspect=2.4,
        fontsize=10,
        plot_params={"facecolor": "aliceblue"},
    )
    detNodeBig = partialmethod(
        daft.PGM.add_node,
        scale=1.6,
        aspect=2.25,
        fontsize=10,
        alternate=True,
        plot_params={"facecolor": "aliceblue"},
    )
    latNodeBig = partialmethod(
        daft.PGM.add_node,
        scale=1.6,
        aspect=2.2,
        fontsize=10,
        plot_params={"facecolor": "aliceblue"},
    )
    
## draw the tickets DAG: one latent node feeding one observed node,
## with the observed node enclosed in the observation plate
pgm = dag(dpi=300, alternate_style="outer")
pgm.obsNode(
    "k",
    "Daily # of Tickets Issued\n" r"$k \sim Poisson(\lambda)$",
    1,
    1,
    aspect=3,
    scale=1.8,
)
pgm.latNode(
    "mu",
    "Avg. # of Daily Tickets\n" r"$\lambda \sim Uniform(3000,7000)$",
    1,
    2.3,
    aspect=3,
    scale=1.8,
)
pgm.add_edge("mu", "k")
pgm.add_plate(
    [-0.5, 0.0, 3.0, 1.6],
    label="Observation:\n" r"$i = 1, 2, \ldots, 105$",
    label_offset=(2, 2),
    rect_params={"fill": False, "linestyle": "dashed", "edgecolor": "black"},
)
pgm.show(dpi=150)

In [None]:
## get array of ticket values as numpy array for numpyro
wedTicketsDS = xr.Dataset.from_dataframe(wedTicketsDF)
wedTickets = wedTicketsDS["numTickets"].to_numpy()

## define the graphical/statistical model as a Python function
def ticketsModel(k):
    """Poisson model for ticket counts with a Uniform(3000, 7000) prior on the rate.

    Args:
        k: values passed as ``obs`` to the 'k' sample site; ``len(k)`` sets
            the size of the 'observation' plate.
    """
    ## `lambda` is a reserved word in Python, hence the name lambdaParam
    ratePrior = dist.Uniform(low=3000, high=7000)
    lambdaParam = numpyro.sample('lambdaParam', ratePrior)

    with numpyro.plate('observation', len(k)):
        numpyro.sample('k', dist.Poisson(rate=lambdaParam), obs=k)

## computationally get posterior distribution
rng_key = random.PRNGKey(seed=111)  ## so you and I get same results
mcmc = MCMC(
    NUTS(ticketsModel),
    num_warmup=1000,
    num_samples=4000,
)
mcmc.run(rng_key, k=wedTickets)  ## get representative sample of posterior
drawsDS = az.from_numpyro(mcmc).posterior  ## get posterior samples into xarray
az.plot_posterior(drawsDS)

### Posterior Predictive Check


In [None]:
## every posterior draw of lambdaParam as one flat (1-d) numpy array
drawsDS["lambdaParam"].to_numpy().flatten()

In [None]:
from numpy.random import default_rng
rng = default_rng(seed=111)
# drawsDS.lambdaParam is 2-d (chain and draw are its coordinates), so it is
# flattened into the 1-d numpy array that rng.choice() picks from
lambdaPost = rng.choice(
    a=drawsDS["lambdaParam"].to_numpy().flatten(),
    size=1,
)
lambdaPost.item()  ## show just single value

In [None]:
## simulate a dataset with as many values as the observed data,
## using the single posterior draw of lambda stored in lambdaPost
simulatedData = rng.poisson(lam = lambdaPost, size = len(wedTickets))
simulatedData

## compare density estimates of simulated data to observed data:

In [None]:
#| Simulated data is not capturing the variance seen in the observed data.
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from numpy import linspace

## one wide panel holding both density estimates
fig, ax = plt.subplots(figsize=(8, 4), layout='constrained')

# plot density estimate, i.e. estimate of f(x)
az.plot_dist(
    wedTickets,
    ax=ax,
    color="darkorchid",
    kind="kde",
    fill_kwargs={'alpha': 0.5},
)
az.plot_dist(
    simulatedData,
    ax=ax,
    color="cadetblue",
    kind="kde",
    fill_kwargs={'alpha': 0.5},
)

ax.set_ylabel('Plausibility Measure: ' r'$f(k)$')
ax.set_xlabel(r'Daily Issued Tickets  $(k)$')
ax.set_xticks(linspace(0, 8000, 9))

# legend handles are built by hand, one line per plotted colour
custom_lines = [
    Line2D([0], [0], color=shade, lw=4, alpha=0.5)
    for shade in ("darkorchid", "cadetblue")
]
ax.legend(
    custom_lines,
    [
        'Density Estimate for Observed Data',
        'Density Estimate for Simulated Data',
    ],
    loc='upper left',
)

plt.show()

## Posterior Predictive Checks Using Arviz and NumPyro



In [None]:
#@title Cherry Tree Model - Posterior Predictive Check
import matplotlib.pyplot as plt
import pandas as pd
from functools import partial, partialmethod
import daft   ### %pip install -U git+https://github.com/daft-dev/daft.git
from numpy.random import default_rng
import numpy as np

class dag(daft.PGM):
    """daft.PGM with pre-styled ``add_node`` shortcuts for drawing DAGs.

    Each ``*Node`` attribute is ``daft.PGM.add_node`` with default styling
    keywords bound through ``partialmethod``; keywords given at call time
    replace the bound defaults.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    obsNode = partialmethod(
        daft.PGM.add_node,
        scale=1.3,
        aspect=2.4,
        fontsize=10,
        plot_params={"facecolor": "cadetblue"},
    )
    decNode = partialmethod(
        daft.PGM.add_node,
        aspect=2.2,
        fontsize=10,
        shape="rectangle",
        plot_params={"facecolor": "thistle"},
    )
    detNode = partialmethod(
        daft.PGM.add_node,
        aspect=5.4,
        fontsize=9.25,
        alternate=True,
        plot_params={"facecolor": "aliceblue"},
    )
    latNode = partialmethod(
        daft.PGM.add_node,
        scale=1.3,
        aspect=2.4,
        fontsize=10,
        plot_params={"facecolor": "aliceblue"},
    )
    detNodeBig = partialmethod(
        daft.PGM.add_node,
        scale=1.6,
        aspect=2.25,
        fontsize=10,
        alternate=True,
        plot_params={"facecolor": "aliceblue"},
    )
    latNodeBig = partialmethod(
        daft.PGM.add_node,
        scale=1.6,
        aspect=2.2,
        fontsize=10,
        plot_params={"facecolor": "aliceblue"},
    )
    
## draw the cherry tree DAG: three latent nodes (nu, mu, sigma) feeding the
## observed node x, with x enclosed in the observation plate
pgm = dag(dpi=300, alternate_style="outer")
pgm.obsNode(
    "x",
    "Tree Height \n" r"$x \sim StudentT(\nu,\mu,\sigma)$",
    0,
    1,
    scale=1.5,
    aspect=4,
)
pgm.latNode(
    "mu",
    "Avg Cherry Tree Height\n" r"$\mu \sim Normal(50,24.5)$",
    0,
    2.2,
    scale=1.5,
    aspect=4,
)
pgm.latNode(
    "sigma",
    "Std.Dev. of Observed Height\n" r"$\sigma \sim Uniform(0,50)$",
    3.5,
    2.2,
    scale=1.5,
    aspect=4,
)
pgm.latNode(
    "nu",
    "Deg. of Freedom\n" r"$\nu \sim Gamma(2,0.1)$",
    -3.5,
    2.2,
    scale=1.5,
    aspect=4,
)
pgm.add_edge("mu", "x")
pgm.add_edge("sigma", "x")
pgm.add_edge("nu", "x")
pgm.add_plate(
    [-1.6, 0.1, 3.2, 1.4],
    label="Observation:\n" r"$i = 1, 2, \ldots, 31$",
    label_offset=(2, 2),
    rect_params={"fill": False, "linestyle": "dashed", "edgecolor": "black"},
)
pgm.show(dpi=150)

We get the posterior as usual:

In [None]:
import pandas as pd
import xarray as xr
import numpy as np

## get data: the Height column of the trees dataset as a numpy array
url = "https://raw.githubusercontent.com/flyaflya/persuasive/main/trees.csv"
treeHeightData = pd.read_csv(url)["Height"].to_numpy()

## define the generative DAG as a Python function
## for posterior predictive checks, we introduce numObs argument
def cherryTreeModelT(x, numObs):
    """Student-t model of tree height with priors on nu, mu and sigma.

    Args:
        x: values passed as ``obs`` to the 'x' sample site.
        numObs: size of the 'observation' plate.
    """
    nu = numpyro.sample('nu', dist.Gamma(concentration=2, rate=0.1))
    mu = numpyro.sample('mu', dist.Normal(loc=50, scale=24.5))
    sigma = numpyro.sample('sigma', dist.Uniform(low=0, high=50))

    with numpyro.plate('observation', numObs):
        numpyro.sample(
            'x',
            dist.StudentT(df=nu, loc=mu, scale=sigma),
            obs=x,
        )

## computationally get posterior distribution
rng_key = random.PRNGKey(seed=111)  ## so you and I get same results
mcmc = MCMC(
    NUTS(cherryTreeModelT),
    num_warmup=1000,
    num_samples=4000,
)
mcmc.run(  # get posterior
    rng_key,
    x=treeHeightData,
    numObs=len(treeHeightData),
)
drawsDS = az.from_numpyro(mcmc).posterior  ## get posterior samples into xarray

Now for a posterior predictive check.  Here, we reverse the process, going from inference to simulation, as visualized in the generative DAG below.

In [None]:
#@title Post Pred Check is Opposite of Inference
import matplotlib.pyplot as plt
import pandas as pd
from functools import partial, partialmethod
import daft   ### %pip install -U git+https://github.com/daft-dev/daft.git
from numpy.random import default_rng
import numpy as np

class dag(daft.PGM):
    """daft.PGM with pre-styled ``add_node`` shortcuts for drawing DAGs.

    Each ``*Node`` attribute is ``daft.PGM.add_node`` with default styling
    keywords bound through ``partialmethod``; keywords given at call time
    replace the bound defaults. In this version ``latNode`` is filled with
    cadetblue and ``obsNode`` with aliceblue.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    obsNode = partialmethod(
        daft.PGM.add_node,
        scale=1.3,
        aspect=2.4,
        fontsize=10,
        plot_params={"facecolor": "aliceblue"},
    )
    decNode = partialmethod(
        daft.PGM.add_node,
        aspect=2.2,
        fontsize=10,
        shape="rectangle",
        plot_params={"facecolor": "thistle"},
    )
    detNode = partialmethod(
        daft.PGM.add_node,
        aspect=5.4,
        fontsize=9.25,
        alternate=True,
        plot_params={"facecolor": "aliceblue"},
    )
    latNode = partialmethod(
        daft.PGM.add_node,
        scale=1.3,
        aspect=2.4,
        fontsize=10,
        plot_params={"facecolor": "cadetblue"},
    )
    detNodeBig = partialmethod(
        daft.PGM.add_node,
        scale=1.6,
        aspect=2.25,
        fontsize=10,
        alternate=True,
        plot_params={"facecolor": "aliceblue"},
    )
    latNodeBig = partialmethod(
        daft.PGM.add_node,
        scale=1.6,
        aspect=2.2,
        fontsize=10,
        plot_params={"facecolor": "aliceblue"},
    )
    
## draw the posterior predictive DAG: nu, mu and sigma feed x; the inner
## plate encloses x and the outer plate encloses the whole graph
pgm = dag(dpi=300, alternate_style="outer")
pgm.obsNode(
    "x",
    "Tree Height \n" r"$x \sim StudentT(\nu,\mu,\sigma)$",
    0,
    1,
    scale=1.5,
    aspect=4,
)
pgm.latNode(
    "mu",
    "Avg Cherry Tree Height\n" r"$\mu$",
    0,
    2.2,
    scale=1.5,
    aspect=4,
)
pgm.latNode(
    "sigma",
    "Std.Dev. of Observed Height\n" r"$\sigma$",
    3.5,
    2.2,
    scale=1.5,
    aspect=4,
)
pgm.latNode(
    "nu",
    "Deg. of Freedom\n" r"$\nu$",
    -3.5,
    2.2,
    scale=1.5,
    aspect=4,
)
pgm.add_edge("mu", "x")
pgm.add_edge("sigma", "x")
pgm.add_edge("nu", "x")
pgm.add_plate(
    [-1.6, 0.1, 3.2, 1.4],
    label="Simulated Observation:\n" r"$i = 1, 2, \ldots, 31$",
    label_offset=(2, 2),
    rect_params={"fill": False, "linestyle": "dotted", "edgecolor": "darkorchid"},
)
pgm.add_plate(
    [-5.1, 0, 10.2, 2.8],
    label="Posterior Draw:\n" r"$draw = 1, 2, \ldots, 4000$",
    label_offset=(2, 2),
    rect_params={"fill": False, "linestyle": "dashed", "edgecolor": "cadetblue"},
)
pgm.show(dpi=150)

In [None]:
### get samples of the darker nodes
### (nu, mu and sigma are the cadetblue-filled nodes in the DAG drawn above);
### postSamples is what the Predictive object below takes as posterior_samples
postSamples = mcmc.get_samples()
postSamples

In [None]:
from numpyro.infer import Predictive
from jax import random
## NOTE(review): rng is not used by the prediction call in the later cell,
## which passes rng_key instead - confirm which key is intended
rng = random.PRNGKey(seed = 111)

## Predictive is a NumPyro class used to construct posterior predictive
## distributions from a model and a set of posterior samples
cherryTreePredictiveObject = Predictive(model = cherryTreeModelT, 
                                        posterior_samples = postSamples)

Once that object is created, you can call `cherryTreePredictiveObject` like a function whose arguments are 1) `rng_key`: a `jax.random.PRNGKey` random key used to draw samples, and 2) all the arguments that the model (e.g. `cherryTreeModelT`) requires.  The call returns a dict of samples from the predictive distribution.  By default, only sample sites not contained in `posterior_samples` are returned.

In [None]:
## now make posterior predictions with one simulated observation per
## observed tree: numObs = len(treeHeightData), which is 31 for this dataset
## use None for data so that it gets simulated from posterior draws
postPredData = cherryTreePredictiveObject(
    rng_key, x = None, numObs = len(treeHeightData)
)
postPredData  ## for each of 4,000 draws, 
              ## get one simulated cherry tree height per observed tree

And then, we compare data simulated from the 4,000 random posterior draws to the observed data.  _Actually, to aid clarity, we will just use simulated data from several draws, say 20, to get a picture of how various posterior densities compare to the estimated density of the observed data._  By creating multiple simulated datasets, we can see how much the data distributions vary among plausible posterior values.  Observed data is subject to lots of randomness, so we just want to ensure that the observed randomness falls within the realm of our plausible narratives.  The code below creates an `arviz` object and then automates the plotting of the posterior predictive check that we are interested in.

In [None]:
## use arviz plotting capabilities because making these by hand is HARD

## bundle the MCMC result and the simulated data into one arviz object
dataForArvizPlotting = az.from_numpyro(
    posterior = mcmc,
    posterior_predictive=postPredData
)
## num_pp_samples=20: plot 20 of the simulated datasets, not all 4,000
az.plot_ppc(dataForArvizPlotting, num_pp_samples=20)