In [1]:
import inspect
import math
import os
import warnings

#import arviz as az

import jax.numpy as jnp
from jax import lax, ops, random

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.diagnostics import print_summary
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, init_to_value, Predictive
from numpyro.infer.autoguide import AutoLaplaceApproximation

import arviz

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
from scipy.stats import beta
%matplotlib inline

# 4 Die rolls

Let's consider multiple manufactered die of different colors, with thee following properties:

1. A particular die has it's own probabilities of rolling a 1,2,3,...,6. i.e. they are not "fair" dice.
2. Manufacturing processes for the different dice lead to different distributions of the probabilities for that die

That is, while our red dice have a manufactering process that -- on average -- produces fair dice, any particular red die will likely be a littlee unfair.

In [2]:
green_avg_fair = np.random.dirichlet(alpha = [1, 1, 1, 1, 1, 1], size=3)
blue_biased_die = np.random.dirichlet(alpha=[1, 1, 5, 1, 1, 1], size=5)
red_really_fair = np.random.dirichlet(alpha=[10,10,10,10,10,10], size=3)

In [3]:
red_really_fair

array([[0.14520733, 0.1472884 , 0.16182545, 0.15836011, 0.18349667,
        0.20382205],
       [0.11543241, 0.18757482, 0.11850072, 0.17620032, 0.20093241,
        0.20135933],
       [0.19985665, 0.19275562, 0.11016164, 0.19987729, 0.15544782,
        0.14190098]])

In [4]:
def create_dice(color_name, p_values, n_rolls):
    assert len(p_values)==len(n_rolls)
    df = pd.DataFrame({
        'outcomes': [np.random.multinomial(n=n_r,pvals=p_v) 
                     for n_r, p_v in zip(n_rolls, p_values)],
        'rolls': n_rolls,
        'color': color_name
    })
    for result in [1,2,3,4,5,6]:
        df[f'n_{result}'] = df['outcomes'].apply(lambda x: x[result-1])
        
    return df

In [38]:
_dfs = [
    create_dice('red', red_really_fair, [300, 400, 800]),
    create_dice('green', green_avg_fair, [40, 200, 90]),
    create_dice('blue', blue_biased_die, [40, 60, 200, 90, 100])
]

df_experiment = pd.concat(_dfs).reset_index(drop=True)

In [39]:
df_experiment

Unnamed: 0,outcomes,rolls,color,n_1,n_2,n_3,n_4,n_5,n_6
0,"[45, 44, 50, 50, 57, 54]",300,red,45,44,50,50,57,54
1,"[42, 67, 51, 76, 79, 85]",400,red,42,67,51,76,79,85
2,"[161, 165, 79, 168, 113, 114]",800,red,161,165,79,168,113,114
3,"[0, 3, 5, 12, 6, 14]",40,green,0,3,5,12,6,14
4,"[11, 50, 96, 9, 13, 21]",200,green,11,50,96,9,13,21
5,"[35, 7, 34, 7, 2, 5]",90,green,35,7,34,7,2,5
6,"[4, 11, 16, 1, 3, 5]",40,blue,4,11,16,1,3,5
7,"[1, 1, 44, 0, 12, 2]",60,blue,1,1,44,0,12,2
8,"[0, 43, 89, 6, 8, 54]",200,blue,0,43,89,6,8,54
9,"[19, 6, 19, 20, 11, 15]",90,blue,19,6,19,20,11,15


```
p = 1/(1+np.exp(-z))
1/p = 1+exp(-z)
(1-p)/p = exp(-z)
z = log(p/1-p)
```

In [40]:
np.log((1/6)/(5/6))

-1.6094379124341005

In [41]:
def simple_color_based_pooling(X, color_index, n_sixes=None):
    """
    """
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    z_color = numpyro.sample("z_color", dist.Normal(-1.61, 0.3), sample_shape=(3,))
    
    with numpyro.plate('die', len(X)):
        z_die = numpyro.sample("z_die", dist.Normal(0, 0.1))

        z = z_color[color_index] + z_die
        p = numpyro.deterministic("p_6", 1/(1+jnp.exp(-z)))
        return numpyro.sample("sixes", dist.Binomial(total_count=X.rolls.values, probs=p), obs=n_sixes)

In [42]:
df_experiment['color'] = df_experiment.color.astype("category")

In [47]:
chain_run = MCMC(NUTS(simple_color_based_pooling), num_warmup=2000, num_samples=2500, num_chains=1)

In [48]:
chain_run.run(rng_key=random.PRNGKey(0), X=df_experiment, color_index=df_experiment.color.cat.codes.values,
             n_sixes=df_experiment.n_6.values)

sample: 100%|██████████| 4500/4500 [00:07<00:00, 609.83it/s, 7 steps of size 5.14e-01. acc. prob=0.90]


In [49]:
chain_run.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
     sigma      0.99      1.03      0.67      0.00      2.34   4369.14      1.00
z_color[0]     -1.59      0.12     -1.58     -1.80     -1.41   3649.96      1.00
z_color[1]     -1.87      0.16     -1.87     -2.15     -1.64   3795.27      1.00
z_color[2]     -1.57      0.09     -1.57     -1.72     -1.43   2609.39      1.00
  z_die[0]      0.02      0.09      0.02     -0.13      0.17   3891.82      1.00
  z_die[1]      0.10      0.09      0.10     -0.04      0.25   3530.68      1.00
  z_die[2]     -0.11      0.08     -0.11     -0.25      0.03   2839.78      1.00
  z_die[3]      0.08      0.10      0.08     -0.07      0.24   4838.60      1.00
  z_die[4]     -0.05      0.10     -0.04     -0.20      0.12   4660.18      1.00
  z_die[5]     -0.06      0.09     -0.06     -0.21      0.09   5890.96      1.00
  z_die[6]     -0.02      0.10     -0.02     -0.19      0.13   4429.56      1.00
  z_die[7]     -0.08      0

In [50]:
chain_run.get_samples()['p_6'].mean(axis=0)

DeviceArray([0.17463836, 0.18692078, 0.15655674, 0.14407547, 0.12884797,
             0.12707298, 0.16853389, 0.16045403, 0.19370647, 0.17036732,
             0.16358174], dtype=float32)

In [51]:
np.array(chain_run.get_samples()['p_6'].mean(axis=0))

array([0.17463836, 0.18692078, 0.15655674, 0.14407547, 0.12884797,
       0.12707298, 0.16853389, 0.16045403, 0.19370647, 0.17036732,
       0.16358174], dtype=float32)

In [52]:
df_experiment['p_6_raw'] = df_experiment['n_6'] / df_experiment['rolls']
df_experiment['p_6_shrink'] = np.array(chain_run.get_samples()['p_6'].mean(axis=0))

In [53]:
df_experiment

Unnamed: 0,outcomes,rolls,color,n_1,n_2,n_3,n_4,n_5,n_6,p_6_raw,p_6_shrink
0,"[45, 44, 50, 50, 57, 54]",300,red,45,44,50,50,57,54,0.18,0.174638
1,"[42, 67, 51, 76, 79, 85]",400,red,42,67,51,76,79,85,0.2125,0.186921
2,"[161, 165, 79, 168, 113, 114]",800,red,161,165,79,168,113,114,0.1425,0.156557
3,"[0, 3, 5, 12, 6, 14]",40,green,0,3,5,12,6,14,0.35,0.144075
4,"[11, 50, 96, 9, 13, 21]",200,green,11,50,96,9,13,21,0.105,0.128848
5,"[35, 7, 34, 7, 2, 5]",90,green,35,7,34,7,2,5,0.055556,0.127073
6,"[4, 11, 16, 1, 3, 5]",40,blue,4,11,16,1,3,5,0.125,0.168534
7,"[1, 1, 44, 0, 12, 2]",60,blue,1,1,44,0,12,2,0.033333,0.160454
8,"[0, 43, 89, 6, 8, 54]",200,blue,0,43,89,6,8,54,0.27,0.193706
9,"[19, 6, 19, 20, 11, 15]",90,blue,19,6,19,20,11,15,0.166667,0.170367


What do we think the population effects of the different colors are?

Red:

In [17]:
z_colors = (
    chain_run.get_samples()['bias'].mean(axis=0) 
    + chain_run.get_samples()['z_color'].mean(axis=0)
)

In [18]:
1/(1+jnp.exp(-z_colors))

DeviceArray([0.18245903, 0.22007586, 0.19692445], dtype=float32)

In [19]:
df_experiment.color

0       red
1       red
2       red
3     green
4     green
5     green
6      blue
7      blue
8      blue
9      blue
10     blue
Name: color, dtype: category
Categories (3, object): ['blue', 'green', 'red']

In [56]:
chain_run.get_samples()['z_die'].mean(axis=0)

DeviceArray([ 0.01707763,  0.1008127 , -0.11268052,  0.0807372 ,
             -0.04657645, -0.06401747, -0.01786875, -0.07664743,
              0.15381113, -0.00475951, -0.0529045 ], dtype=float32)