<a href="https://colab.research.google.com/github/flyaflya/persuasive/blob/main/demoNotebooks/happyWalkthrough.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
## install/upgrade the plotting (matplotlib), modeling (numpyro), and
## DAG-drawing (daft) packages used in the cells below
! pip install matplotlib numpyro daft --upgrade

In [None]:
import xarray as xr
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
import arviz as az
import pandas as pd
## location of the raw CSV in the course GitHub repository
url = (
    "https://raw.githubusercontent.com/"
    "flyaflya/persuasive/main/happy.csv"
)
## read the CSV into a data frame and display it
happyDF = pd.read_csv(filepath_or_buffer = url)
happyDF

In [None]:
## transforms for plotting
## x_trans is log10(GDP per capita) centered at its sample mean
plotDF = happyDF.assign(
    x_trans = lambda df: np.log10(df.GDPperCapita) - np.mean(np.log10(df.GDPperCapita))
)
plotDF

In [None]:
## see initial data and transformed data
import matplotlib.pyplot as plt

## one row with two side-by-side axes: initial data and transformed data
fig, (initAx, transAx) = plt.subplots(
    nrows = 1,
    ncols = 2,
    figsize = (8, 4),
    layout = "constrained",
)
## continue below

In [None]:
## plot the initial data and the transformed data
## matplotlib 3.6 renamed the bundled "seaborn-*" style sheets to
## "seaborn-v0_8-*" and later releases dropped the old names; the install
## cell upgrades matplotlib, so pick whichever name this version provides
gridStyle = "seaborn-v0_8-whitegrid"
if gridStyle not in plt.style.available:
    gridStyle = "seaborn-whitegrid"  ## name used by older matplotlib
plt.style.use(gridStyle) ##place at beginning
fig, (initAx,transAx) = plt.subplots(ncols = 2, figsize=(8, 4), 
                        layout='constrained')

## initial data is non-linear
initAx.scatter(plotDF.GDPperCapita, plotDF.lifeSatisfaction)
initAx.set_xlabel("GDP Per Capita")
initAx.set_ylabel("Life Satisfaction")

## transformed data is linear
transAx.scatter(plotDF.x_trans, plotDF.lifeSatisfaction)
transAx.set_xlabel("Transformed GDP Data")
transAx.set_ylabel("Life Satisfaction")

In [None]:
## fit numpyro model with transformed data
## pull the predictor and the response out of the data frame as NumPy arrays
transGDP = plotDF["x_trans"].to_numpy()
lifeSatis = plotDF["lifeSatisfaction"].to_numpy()

## define the generative DAG as a Python function
## for posterior predictive checks, call it with yval = None so that "y"
## is simulated from the model instead of being conditioned on data
def happyModel(xval, yval = None):
    """NumPyro model: normal likelihood with a linear mean in xval.

    Sample sites: "alpha" ~ Normal(5, 2), "beta" ~ Uniform(0, 5),
    "sigma" ~ Gamma(2, 1), and "y" ~ Normal(mu, sigma), where the
    deterministic site "mu" is alpha + beta * xval.

    Args:
        xval: predictor values (this notebook passes the centered
            log10 GDP per capita).
        yval: observed responses for the "y" site. Defaults to None,
            in which case "y" is left unobserved.
    """
    alpha = numpyro.sample('alpha', dist.Normal(5,2))
    beta = numpyro.sample('beta', dist.Uniform(low = 0, high = 5))
    mu = numpyro.deterministic("mu", alpha + beta * xval )
    sigma = numpyro.sample("sigma", dist.Gamma(2,1))
    ## the site is registered by the call itself; its return value is not needed
    numpyro.sample("y", dist.Normal(mu,sigma), obs = yval)



In [None]:
# an okay way to visualize the model
numpyro.render_model(
    happyModel,
    model_args = (transGDP, lifeSatis),
    render_distributions = True,
)

In [None]:
#@title happy DAG
import matplotlib.pyplot as plt
import pandas as pd
from functools import partial, partialmethod
import daft   ### %pip install -U git+https://github.com/daft-dev/daft.git
from numpy.random import default_rng
import numpy as np

class dag(daft.PGM):
    """daft.PGM extended with pre-styled shortcuts around ``add_node``.

    Each shortcut is ``daft.PGM.add_node`` with some keyword arguments
    pre-filled; keywords given at call time override the presets.
    The obs/dec/det/lat prefixes presumably stand for observed, decision,
    deterministic and latent nodes -- confirm with the author.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    ## cadetblue-filled node
    obsNode = partialmethod(
        daft.PGM.add_node,
        scale = 1.3, aspect = 2.4, fontsize = 10,
        plot_params = dict(facecolor = "cadetblue"),
    )
    ## thistle-filled rectangle
    decNode = partialmethod(
        daft.PGM.add_node,
        aspect = 2.2, fontsize = 10, shape = "rectangle",
        plot_params = dict(facecolor = "thistle"),
    )
    ## wide aliceblue node drawn in the "alternate" style
    detNode = partialmethod(
        daft.PGM.add_node,
        aspect = 5.4, fontsize = 9.25, alternate = True,
        plot_params = dict(facecolor = "aliceblue"),
    )
    ## aliceblue node with the same geometry as obsNode
    latNode = partialmethod(
        daft.PGM.add_node,
        scale = 1.3, aspect = 2.4, fontsize = 10,
        plot_params = dict(facecolor = "aliceblue"),
    )
    ## larger-scale variants of detNode and latNode
    detNodeBig = partialmethod(
        daft.PGM.add_node,
        scale = 1.6, aspect = 2.25, fontsize = 10, alternate = True,
        plot_params = dict(facecolor = "aliceblue"),
    )
    latNodeBig = partialmethod(
        daft.PGM.add_node,
        scale = 1.6, aspect = 2.2, fontsize = 10,
        plot_params = dict(facecolor = "aliceblue"),
    )
    
## NOTE(review): this cell is titled "happy DAG", but the labels below describe
## a Poisson ticket-count model rather than happyModel -- confirm the intended
## diagram with the author.
pgm = dag(dpi = 300, alternate_style = "outer")

## child node "k" with its parent node "mu" drawn above it
pgm.obsNode(
    "k",
    "\n".join(["Daily # of Tickets Issued", r"$k \sim Poisson(\lambda)$"]),
    1, 1,
    aspect = 3, scale = 1.8,
)
pgm.latNode(
    "mu",
    "\n".join(["Avg. # of Daily Tickets", r"$\lambda \sim Uniform(3000,7000)$"]),
    1, 2.3,
    aspect = 3, scale = 1.8,
)
pgm.add_edge("mu", "k")

## dashed, unfilled plate with an observation-index label
pgm.add_plate(
    [-0.5, 0.0, 3.0, 1.6],
    label = "\n".join(["Observation:", r"$i = 1, 2, \ldots, 105$"]),
    label_offset = (2, 2),
    rect_params = {"fill": False, "linestyle": "dashed", "edgecolor": "black"},
)
pgm.show(dpi = 150)

In [None]:
## computationally get posterior distribution using the NUTS kernel
rng_key = random.PRNGKey(seed = 111) ## fixed seed so you and I get same results
mcmc = MCMC(
    NUTS(happyModel),
    num_warmup = 1000,   ## warmup steps
    num_samples = 4000,  ## posterior draws to keep
)
mcmc.run(rng_key, xval = transGDP, yval = lifeSatis) ## get posterior
## convert to ArviZ and keep the posterior group (xarray)
drawsDS = az.from_numpyro(mcmc).posterior

In [None]:
## posterior plots for the three sampled parameters
az.plot_posterior(
    drawsDS,
    var_names = ["alpha", "beta", "sigma"],
)

In [None]:
## posterior draws from the MCMC run, looked up by site name (e.g. "alpha")
postSamples = mcmc.get_samples(group_by_chain = False)
postSamples

In [None]:
from numpyro.infer import Predictive
from jax import random
rng = random.PRNGKey(seed = 111)

## Predictive is a NumPyro class used to construct posterior predictive
## distributions from a model and a set of posterior samples
predicitiveObject = Predictive(happyModel, posterior_samples = postSamples)

In [None]:
## now make posterior predictions
## use `rng`, the key created alongside predicitiveObject, instead of
## reaching back for the `rng_key` of the MCMC cell (both are seeded
## with 111, so the draws are unchanged)
## yval = None leaves "y" unobserved so it is simulated
postPredData = predicitiveObject(rng, xval = transGDP, yval = None)
postPredData  ## for each of 4,000 draws, 
              ## get simulated observations of each country

In [None]:
## bundle the MCMC posterior and the posterior predictive draws for ArviZ
dataForArvizPlotting = az.from_numpyro(mcmc, posterior_predictive = postPredData)

## posterior predictive check using 20 of the simulated data sets
az.plot_ppc(
    dataForArvizPlotting,
    num_pp_samples = 20,
)

In [None]:
## plot back on original scale
## use sample mu value... for now, let's just look at 
## mu 4.0615063, 5.495851
## plot the initial data and the transformed data
## matplotlib 3.6 renamed the bundled "seaborn-*" style sheets to
## "seaborn-v0_8-*" and later releases dropped the old names; the install
## cell upgrades matplotlib, so pick whichever name this version provides
gridStyle = "seaborn-v0_8-whitegrid"
if gridStyle not in plt.style.available:
    gridStyle = "seaborn-whitegrid"  ## name used by older matplotlib
plt.style.use(gridStyle) ##place at beginning
fig, initAx = plt.subplots(figsize=(8, 4), 
                        layout='constrained')

## initial data is non-linear
initAx.scatter(plotDF.GDPperCapita, plotDF.lifeSatisfaction)
initAx.set_xlabel("GDP Per Capita")
initAx.set_ylabel("Life Satisfaction")

## function that takes x values on the original GDP scale, applies the
## same transform used for x_trans, and returns the estimated life satisfaction

def yEstimate(alpha, beta, x):
    """Return alpha + beta * (log10(x) - mean(log10(happyDF.GDPperCapita))).

    alpha and beta are the intercept and slope of the line on the
    transformed scale; x is on the original GDP-per-capita scale.
    The centering constant is read from the global happyDF.
    """
    logGDP = np.log10(happyDF.GDPperCapita)
    centered = np.log10(x) - np.mean(logGDP)
    return alpha + beta * centered

## make plot for first 5 candidate alpha and beta
alphaVals = np.array(postSamples["alpha"])[:5]
betaVals = np.array(postSamples["beta"])[:5]
## start the grid at the smallest observed GDP: yEstimate takes log10(x),
## and log10(0) is -inf, so a grid starting at 0 raises a divide-by-zero
## warning and yields a non-finite first estimate
xvals = np.linspace(plotDF.GDPperCapita.min(), 120000, 200)

## one grey line per (alpha, beta) draw, over the scatter on the original scale
for a,b in zip(alphaVals, betaVals):
    yvals = yEstimate(a,b,xvals)
    initAx.plot(xvals, yvals, linewidth = 0.8, color = "grey")
