[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ksachdeva/rethinking-tensorflow-probability/blob/master/notebooks/13_models_with_memory.ipynb)

# Chapter 13 - Models with Memory


## Imports and utility functions


In [1]:
# Install packages that are not installed in colab
try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False

if IN_COLAB:
    %tensorflow_version 2.X
    
    !pip install watermark
    !pip install arviz
    !pip install tensorflow_probability==0.9.0

In [2]:
%load_ext watermark

In [3]:
from functools import partial

# Core
import numpy as np
import arviz as az
import pandas as pd
import xarray as xr
import tensorflow as tf
import tensorflow_probability as tfp

# visualization 
import matplotlib.pyplot as plt

# aliases
tfd = tfp.distributions
tfb = tfp.bijectors
Root = tfd.JointDistributionCoroutine.Root

In [4]:
%watermark -p numpy,tensorflow,tensorflow_probability,arviz,scipy,pandas

numpy 1.18.1
tensorflow 2.1.0
tensorflow_probability 0.9.0
arviz 0.6.1
scipy 1.4.1
pandas 0.25.3


In [5]:
# config of various plotting libraries
%config InlineBackend.figure_format = 'retina'
az.style.use('arviz-darkgrid')

In [6]:
# Forwarded to tf.function(experimental_compile=...) on run_chain.
USE_XLA = False

In [7]:
# Default MCMC settings used as keyword defaults by run_chain and
# sample_from_posterior.
NUMBER_OF_CHAINS  = 2
NUMBER_OF_BURNIN  = 500
NUMBER_OF_SAMPLES = 500

def _trace_to_arviz(trace=None,
                    sample_stats=None,
                    observed_data=None,
                    prior_predictive=None,
                    posterior_predictive=None,
                    inplace=True):
    """Package sampler output into an ArviZ ``InferenceData`` object.

    Dict inputs are converted before being handed to ``az.from_dict``:

    * ``trace``: each value is converted with ``.numpy()`` and its first two
      axes are swapped (presumably (draw, chain, ...) -> (chain, draw, ...);
      confirm against the sampler output layout).
    * ``sample_stats``: each value is converted with ``.numpy()`` and transposed.
    * ``prior_predictive``: each value gains a new leading axis.

    When ``posterior_predictive`` is a dict and ``trace`` is already an
    ``InferenceData`` (and ``inplace == True``), the posterior-predictive group
    is added to ``trace`` and the combined object is returned. In every other
    case where ``posterior_predictive`` is a dict, ``trace`` is reset to
    ``None`` so the result carries no posterior group.
    """
    if isinstance(trace, dict):
        trace = {name: np.swapaxes(draws.numpy(), 1, 0)
                 for name, draws in trace.items()}

    if isinstance(sample_stats, dict):
        sample_stats = {name: stat.numpy().T
                        for name, stat in sample_stats.items()}

    if isinstance(prior_predictive, dict):
        prior_predictive = {name: draws[np.newaxis]
                            for name, draws in prior_predictive.items()}

    if isinstance(posterior_predictive, dict):
        if isinstance(trace, az.InferenceData) and inplace == True:
            extra_group = az.from_dict(posterior_predictive=posterior_predictive)
            return trace + extra_group
        trace = None

    return az.from_dict(
        posterior=trace,
        sample_stats=sample_stats,
        prior_predictive=prior_predictive,
        posterior_predictive=posterior_predictive,
        observed_data=observed_data,
    )

@tf.function(autograph=False, experimental_compile=USE_XLA)
def run_chain(init_state,
              bijectors,
              step_size,
              target_log_prob_fn,
              num_samples=NUMBER_OF_SAMPLES,
              burnin=NUMBER_OF_BURNIN,
              ):
    """Sample with NUTS wrapped in a bijector transform and step-size adaptation.

    Args:
      init_state: starting state handed to ``tfp.mcmc.sample_chain``.
      bijectors: bijector(s) given to ``TransformedTransitionKernel``.
      step_size: initial NUTS step size.
      target_log_prob_fn: log-density callable given to the NUTS kernel.
      num_samples: number of results to keep.
      burnin: number of burn-in steps; step-size adaptation runs for 80% of them.

    Returns:
      ``(results, sampler_stat)`` as returned by ``tfp.mcmc.sample_chain``, where
      ``sampler_stat`` is the tuple (target_log_prob, leapfrogs_taken,
      has_divergence, energy, log_accept_ratio) traced at every step.
    """

    def _collect_nuts_stats(_, pkr):
        # Kernel nesting is adaptation -> transformed -> NUTS, so the NUTS
        # results sit two `inner_results` levels down.
        nuts_results = pkr.inner_results.inner_results
        return (
            nuts_results.target_log_prob,
            nuts_results.leapfrogs_taken,
            nuts_results.has_divergence,
            nuts_results.energy,
            nuts_results.log_accept_ratio
        )

    # The three accessors below receive the results of the kernel wrapped by the
    # adaptation kernel (the transformed kernel); NUTS results are one level down.
    def _set_step_size(pkr, new_step_size):
        updated_nuts = pkr.inner_results._replace(step_size=new_step_size)
        return pkr._replace(inner_results=updated_nuts)

    def _get_step_size(pkr):
        return pkr.inner_results.step_size

    def _get_log_accept_prob(pkr):
        return pkr.inner_results.log_accept_ratio

    sampler = tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=step_size)

    transformed_sampler = tfp.mcmc.TransformedTransitionKernel(
        inner_kernel=sampler,
        bijector=bijectors)

    adaptive_sampler = tfp.mcmc.SimpleStepSizeAdaptation(
        inner_kernel=transformed_sampler,
        target_accept_prob=.8,
        num_adaptation_steps=int(0.8*burnin),
        step_size_setter_fn=_set_step_size,
        step_size_getter_fn=_get_step_size,
        log_accept_prob_getter_fn=_get_log_accept_prob
    )

    results, sampler_stat = tfp.mcmc.sample_chain(
        num_results=num_samples,
        num_burnin_steps=burnin,
        current_state=init_state,
        kernel=adaptive_sampler,
        trace_fn=_collect_nuts_stats)

    return results, sampler_stat

def sample_from_posterior(jdc,
                          observed_data,
                          params,
                          init_state,
                          bijectors,
                          num_chains=NUMBER_OF_CHAINS,
                          num_samples=NUMBER_OF_SAMPLES,
                          burnin=NUMBER_OF_BURNIN,
                          step_size=0.1):
    """Draw posterior samples for a joint distribution and wrap them for ArviZ.

    Args:
      jdc: joint distribution whose ``log_prob`` is evaluated on the sampled
        parameters followed by ``observed_data`` as the last element.
      observed_data: observed values appended to the parameter tuple.
      params: parameter names, in the same order as the entries of ``init_state``.
      init_state: initial state passed to ``run_chain``.
      bijectors: bijectors passed to ``run_chain``.
      num_chains: accepted for API symmetry but not used by this function
        (presumably the number of chains is fixed by the leading dimension of
        ``init_state`` -- confirm against callers).
      num_samples: number of posterior draws to keep.
      burnin: number of burn-in steps.
      step_size: initial NUTS step size; defaults to the previously hard-coded 0.1.

    Returns:
      The object produced by ``_trace_to_arviz`` holding the posterior draws
      and the sampler statistics.
    """

    def target_log_prob_fn(*x):
        # The observed data is always the final element the model expects.
        return jdc.log_prob(x + (observed_data,))

    results, sample_stats = run_chain(init_state,
                                      bijectors,
                                      step_size=step_size,
                                      target_log_prob_fn=target_log_prob_fn,
                                      num_samples=num_samples,
                                      burnin=burnin)

    # Names follow the order of the tuple traced inside run_chain.
    stat_names = ['lp', 'tree_size',
                  'diverging', 'energy', 'mean_tree_accept']
    sampler_stats = dict(zip(stat_names, sample_stats))
    posterior = dict(zip(params, results))

    return _trace_to_arviz(trace=posterior, sample_stats=sampler_stats)

In [8]:
# You could change the base URL to a local directory or to remote raw GitHub content
_BASE_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/Experimental/data"

REEDFROGS_DATASET_PATH = f"{_BASE_URL}/reedfrogs.csv"

## Code 13.1


The Reedfrogs dataset is about tadpole mortality. The objective is to model the number of survivors, `surv`, out of an initial count, `density`.

Author explains that within each tank there are things that go unmeasured and these unmeasured factors create variation in survival across tanks.

These tanks are an example of a **cluster** variable.

He argues that both of the obvious approaches have issues:

* treating the tanks independently, i.e. each of them has its own unique intercept
* treating them all together

For example:

- unique intercepts imply that we are not using information from the other tanks.
- pooling all tanks together ignores variation in baseline survival.
    
A multilevel model, in which we simultaneously estimate both an intercept for each tank and the variation among tanks, is what we want!

This type of model is called a **varying intercepts** model.
        


In [9]:
# The file is read with ";" as the field separator.
d = pd.read_csv(REEDFROGS_DATASET_PATH, sep=";")
d.head()

Unnamed: 0,density,pred,size,surv,propsurv
0,10,no,big,9,0.9
1,10,no,big,10,1.0
2,10,no,big,7,0.7
3,10,no,big,10,1.0
4,10,no,small,9,0.9


## Code 13.2

Our simple model. This will give us 48 different intercepts, one per tank. This means that the estimate for each tank does not use the information available from the other tanks.

In [10]:
# Every row gets its own tank id, equal to the row position.
d["tank"] = np.arange(len(d))

# Number of intercepts to sample: one per tank id.
alpha_sample_shape = len(d["tank"])

dat = {
    "S": tf.cast(d["surv"].values, dtype=tf.float32),
    "N": tf.cast(d["density"].values, dtype=tf.float32),
    "tank": d["tank"].values,
}


def model_13_1(tid, N):
    """Build the model with an independent Normal(0, 1.5) intercept per tank.

    Args:
      tid: indices used to pick each observation's intercept out of ``alpha``.
      N: Binomial ``total_count`` for each observation.

    Returns:
      A ``tfd.JointDistributionCoroutine`` that yields ``alpha`` and then the
      Binomial survival counts.
    """
    def _generator():
        alpha_prior = tfd.Sample(
            tfd.Normal(loc=0., scale=1.5),
            sample_shape=alpha_sample_shape)
        alpha = yield Root(alpha_prior)

        # Pick each observation's intercept and map it to a survival probability.
        intercept_per_obs = tf.gather(alpha, tid, axis=-1)
        p = tf.sigmoid(tf.squeeze(intercept_per_obs))

        # Survival counts (S): the last batch dimension is treated as one event.
        yield tfd.Independent(
            tfd.Binomial(total_count=N, probs=p),
            reinterpreted_batch_ndims=1)

    return tfd.JointDistributionCoroutine(_generator, validate_args=False)


jdc_13_1 = model_13_1(dat["tank"], dat["N"])

In [11]:
NUM_CHAINS_FOR_13_1 = 2

# Single state part (alpha): every chain starts with all intercepts at zero.
init_state = [tf.zeros([NUM_CHAINS_FOR_13_1, alpha_sample_shape])]

# alpha is passed through unchanged.
bijectors = [tfb.Identity()]

trace_13_1 = sample_from_posterior(
    jdc_13_1,
    observed_data=dat["S"],
    params=['alpha'],
    init_state=init_state,
    bijectors=bijectors,
    num_chains=NUM_CHAINS_FOR_13_1,
)

az.summary(trace_13_1, round_to=2, kind='all')

Unnamed: 0,mean,sd,hpd_3%,hpd_97%,mcse_mean,mcse_sd,ess_mean,ess_sd,ess_bulk,ess_tail,r_hat
alpha[0],1.71,0.78,0.4,3.26,0.02,0.02,1050.83,826.48,1079.81,764.0,1.0
alpha[1],2.41,0.88,0.76,4.03,0.03,0.02,867.2,741.09,898.79,746.5,1.0
alpha[2],0.74,0.62,-0.36,1.89,0.02,0.02,1270.07,723.55,1274.5,555.86,1.0
alpha[3],2.41,0.89,0.78,4.01,0.03,0.02,882.62,751.56,911.28,683.74,1.0
alpha[4],1.69,0.78,0.16,3.08,0.02,0.02,1160.85,859.41,1239.01,622.55,1.0
alpha[5],1.72,0.78,0.29,3.07,0.03,0.02,880.86,605.13,1042.57,423.41,1.0
alpha[6],2.41,0.96,0.75,4.37,0.04,0.03,708.08,483.04,852.12,419.68,1.01
alpha[7],1.74,0.77,0.33,3.18,0.03,0.02,819.28,565.11,879.38,637.77,1.0
alpha[8],-0.36,0.64,-1.58,0.78,0.02,0.02,1583.55,641.83,1579.28,827.27,1.01
alpha[9],1.76,0.76,0.4,3.2,0.02,0.02,1191.04,772.96,1247.42,558.43,1.01


## Code 13.3

We now build a multilevel model, which adaptively pools information across tanks.

In order to do so, we must make the prior for the parameter **alpha** a function of some new parameters.

The prior itself has priors!


In [12]:
def model_13_2(tid, N):
    """Build the multilevel model whose tank intercepts share a learned prior.

    ``a_bar`` ~ Normal(0, 1.5) and ``sigma`` ~ Exponential(1) are the
    hyper-parameters; each tank intercept ``alpha`` is drawn from
    Normal(a_bar, sigma).

    Args:
      tid: indices used to pick each observation's intercept out of ``alpha``.
      N: Binomial ``total_count`` for each observation.

    Returns:
      A ``tfd.JointDistributionCoroutine`` yielding ``a_bar``, ``sigma``,
      ``alpha`` and then the Binomial survival counts.
    """
    def _generator():
        a_bar = yield Root(
            tfd.Sample(tfd.Normal(loc=0., scale=1.5), sample_shape=1))
        sigma = yield Root(
            tfd.Sample(tfd.Exponential(rate=1.), sample_shape=1))

        # The intercept prior is itself parameterised by a_bar and sigma.
        alpha = yield tfd.Sample(
            tfd.Normal(loc=a_bar, scale=sigma),
            sample_shape=alpha_sample_shape)

        intercept_per_obs = tf.gather(alpha, tid, axis=-1)
        p = tf.sigmoid(tf.squeeze(intercept_per_obs))

        # Survival counts (S): the last batch dimension is treated as one event.
        yield tfd.Independent(
            tfd.Binomial(total_count=N, probs=p),
            reinterpreted_batch_ndims=1)

    return tfd.JointDistributionCoroutine(_generator, validate_args=False)


jdc_13_2 = model_13_2(dat["tank"], dat["N"])

In [13]:
NUM_CHAINS_FOR_13_2 = 2

# Starting values, one entry per sampled parameter in the order
# a_bar, sigma, alpha (matching `params` in the sampling call below).
init_state = [
    tf.zeros([NUM_CHAINS_FOR_13_2]),
    tf.ones([NUM_CHAINS_FOR_13_2]),
    tf.zeros([NUM_CHAINS_FOR_13_2, alpha_sample_shape])
]

In [14]:
# One bijector per state part, in the same order as init_state.
# The second entry (sigma) uses Exp; the other two use Identity.
bijectors = [
    tfb.Identity(),
    tfb.Exp(),
    tfb.Identity()
]

In [15]:
# Parameter names are listed in the same order as init_state and bijectors.
trace_13_2 = sample_from_posterior(
    jdc_13_2,
    observed_data=dat["S"],
    params=['a_bar', 'sigma', 'alpha'],
    init_state=init_state,
    bijectors=bijectors,
    num_chains=NUM_CHAINS_FOR_13_2,
)

az.summary(trace_13_2, round_to=2, kind='all')

Unnamed: 0,mean,sd,hpd_3%,hpd_97%,mcse_mean,mcse_sd,ess_mean,ess_sd,ess_bulk,ess_tail,r_hat
a_bar,1.35,0.26,0.85,1.82,0.01,0.01,710.72,665.53,712.18,645.75,1.0
sigma,1.62,0.22,1.25,2.04,0.01,0.01,559.7,510.06,602.24,582.51,1.0
alpha[0],2.1,0.85,0.52,3.59,0.03,0.02,720.3,626.48,751.06,563.79,1.0
alpha[1],3.12,1.13,1.21,5.36,0.06,0.05,341.96,312.4,376.84,351.67,1.0
alpha[2],0.97,0.65,-0.26,2.21,0.02,0.02,1248.37,814.48,1305.25,553.27,1.0
alpha[3],2.97,1.09,1.14,5.29,0.05,0.04,461.82,408.22,505.5,362.35,1.0
alpha[4],2.18,0.86,0.64,3.7,0.03,0.03,628.08,542.48,640.25,430.55,1.0
alpha[5],2.19,0.89,0.62,3.76,0.03,0.03,691.42,574.6,757.57,463.03,1.0
alpha[6],2.99,1.05,1.14,4.88,0.06,0.04,358.46,330.04,384.29,434.78,1.01
alpha[7],2.15,0.86,0.63,3.74,0.03,0.02,683.44,598.7,735.04,570.25,1.0


## Code 13.4

In [16]:
# we must compute the likelihood before using arviz to do comparison
def compute_and_store_log_likelihood_for_model_13_1(num_chains):
    """Compute pointwise log-likelihoods for model 13.1 and attach them to its trace.

    For each of the first ``num_chains`` chains of ``trace_13_1``, evaluates the
    Binomial log-probability of the observed survival counts under every
    posterior draw of ``alpha``. The stacked (chain, draw, observation) array is
    written to ``trace_13_1.sample_stats["log_likelihood"]``.

    Args:
      num_chains: number of chains to read from ``trace_13_1.posterior``.
    """

    def log_like_13_1(a):
        # `a` holds all draws of one chain; gathering on the last axis keeps the
        # leading draw axis, so one call evaluates every draw at once.
        p = tf.sigmoid(tf.gather(a, dat["tank"], axis=-1))
        return tfd.Binomial(total_count=dat["N"], probs=p).log_prob(dat["S"])

    posterior_alpha = trace_13_1.posterior["alpha"].values

    log_likelihood_13_1 = np.array(
        [log_like_13_1(posterior_alpha[i]).numpy() for i in range(num_chains)])

    # we need to insert this in the sampler_stats
    sample_stats_13_1 = trace_13_1.sample_stats

    # Size the observation coordinate from the data rather than hard-coding 48.
    coords = [
        sample_stats_13_1.coords['chain'],
        sample_stats_13_1.coords['draw'],
        np.arange(log_likelihood_13_1.shape[-1]),
    ]

    sample_stats_13_1["log_likelihood"] = xr.DataArray(
        log_likelihood_13_1,
        coords=coords,
        dims=['chain', 'draw', 'log_likelihood_dim_0'])


compute_and_store_log_likelihood_for_model_13_1(num_chains=NUM_CHAINS_FOR_13_1)

In [17]:
def compute_and_store_log_likelihood_for_model_13_2(num_chains):
    """Compute pointwise log-likelihoods for model 13.2 and attach them to its trace.

    For each of the first ``num_chains`` chains of ``trace_13_2``, evaluates the
    Binomial log-probability of the observed survival counts under every
    posterior draw of ``alpha``. The stacked (chain, draw, observation) array is
    written to ``trace_13_2.sample_stats["log_likelihood"]``.

    Args:
      num_chains: number of chains to read from ``trace_13_2.posterior``.
    """

    def log_like_13_2(a):
        # `a` holds all draws of one chain; gathering on the last axis keeps the
        # leading draw axis, so one call evaluates every draw at once.
        p = tf.sigmoid(tf.gather(a, dat["tank"], axis=-1))
        return tfd.Binomial(total_count=dat["N"], probs=p).log_prob(dat["S"])

    posterior_alpha = trace_13_2.posterior["alpha"].values

    log_likelihood_13_2 = np.array(
        [log_like_13_2(posterior_alpha[i]).numpy() for i in range(num_chains)])

    # we need to insert this in the sampler_stats
    sample_stats_13_2 = trace_13_2.sample_stats

    # Size the observation coordinate from the data rather than hard-coding 48.
    coords = [
        sample_stats_13_2.coords['chain'],
        sample_stats_13_2.coords['draw'],
        np.arange(log_likelihood_13_2.shape[-1]),
    ]

    sample_stats_13_2["log_likelihood"] = xr.DataArray(
        log_likelihood_13_2,
        coords=coords,
        dims=['chain', 'draw', 'log_likelihood_dim_0'])


compute_and_store_log_likelihood_for_model_13_2(num_chains=NUM_CHAINS_FOR_13_2)

In [18]:
# Compare the two models; the table below reports WAIC on the deviance scale.
# Both traces had "log_likelihood" added to their sample_stats in the cells above.
az.compare({"m13.1": trace_13_1, "m13.2": trace_13_2})

See http://arxiv.org/abs/1507.04544 for details
  "For one or more samples the posterior variance of the log predictive "


Unnamed: 0,rank,waic,p_waic,d_waic,weight,se,dse,warning,waic_scale
m13.2,0,200.349,20.9238,0.0,0.997519,4.30547,0.0,True,deviance
m13.1,1,215.937,26.3835,15.588,0.00248141,7.0354,3.97046,True,deviance
