In [None]:
import io
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cmdstanpy import CmdStanModel
from scipy import stats

from baynes.model_utils import get_stan_file
from baynes.plotter import FitPlotter
# Seaborn "ticks" theme with larger fonts; hide the top and right spines.
sns.set_style('ticks')

sns.set_context("notebook", font_scale=1.6)
plt.rc("axes.spines", top=False, right=False)
import cmdstanpy
import logging
# Only surface cmdstanpy errors; its lower-level log messages are suppressed.
cmdstanpy.utils.get_logger().setLevel(logging.ERROR)
import arviz as az

# 'rose-pine-dawn' is not one of matplotlib's bundled styles; it only loads
# when installed locally (e.g. in the user's stylelib). Warn and keep the
# settings above instead of aborting a fresh Restart & Run All.
# NOTE(review): if the style sets spine/grid/font rcParams it overrides the
# seaborn settings above — confirm that precedence is intended.
try:
    plt.style.use('rose-pine-dawn')
except OSError:
    logging.warning("matplotlib style 'rose-pine-dawn' not found; keeping the seaborn settings")

# Example 1: fit of a Poisson process
### Generate the data

### Compile and print the Stan model


In [None]:
# Compile the SBC Stan program with threading enabled, then show its source.
# NOTE(review): 'jN' is not a standard Stan C++ flag like STAN_THREADS —
# confirm that cmdstanpy/CmdStan actually use it (parallel build jobs?).
compile_options = {'STAN_THREADS': True, 'jN': 4}
stan_file = get_stan_file('poisson_SBC.stan')
model = CmdStanModel(stan_file=stan_file, cpp_options=compile_options)
print(model.code())

In [None]:
from baynes.analysis import multithreaded_run

In [None]:
def as_rank_frame(data):
    """Coerce SBC rank input to a DataFrame with one column per parameter.

    Parameters
    ----------
    data : str, dict or DataFrame
        A JSON string (literal text, or a path/URL accepted by
        ``pd.read_json``), a ``{name: ranks}`` dict, or a DataFrame, which
        is returned unchanged.
    """
    if isinstance(data, str):
        # pandas 2.1 deprecated passing literal JSON text to read_json; wrap
        # literals in a buffer and hand paths/URLs over unchanged.
        if data.lstrip().startswith(('{', '[')):
            return pd.read_json(io.StringIO(data))
        return pd.read_json(data)
    if isinstance(data, dict):
        return pd.DataFrame.from_dict(data)
    return data


def binomial_rank_band(n_sims, n_bins, percs=(0.05, 0.95)):
    """Return ``(lower, median, upper)`` for one histogram bin count under uniform ranks.

    If the ranks are uniform and the histogram splits the rank support into
    ``n_bins`` equal-width bins, every bin count follows
    Binomial(n_sims, 1/n_bins); ``lower``/``upper`` are its ``percs``
    quantiles.
    """
    bin_count = stats.binom(n_sims, 1 / n_bins)
    return bin_count.ppf(percs[0]), bin_count.median(), bin_count.ppf(percs[1])


def rank_histogram(data, n_bins, percs, color, binrange=None, **displot_kwargs):
    """Facetted rank histograms (2 per row) with the binomial band behind each panel.

    Extra keyword arguments go to ``sns.displot``. Returns its FacetGrid.
    """
    ranks = as_rank_frame(data)
    grid = sns.displot(ranks.melt(value_name='rank'), bins=n_bins, binrange=binrange,
                       kind='hist', x='rank', col='variable', col_wrap=2, alpha=1,
                       **displot_kwargs)
    lower, median, upper = binomial_rank_band(len(ranks), n_bins, percs)
    for ax in grid.axes.flatten():
        # Drawing the band can change the autoscaled x-limits, so capture
        # them once, span the band over them, and restore them afterwards.
        xlim = ax.get_xlim()
        ax.fill_between(xlim, lower, upper, color=color, alpha=0.20, zorder=0)
        ax.axhline(median, color=color, alpha=0.50, zorder=0)
        ax.set_xlim(xlim)
        ax.grid(visible=False)
    return grid


def SBC_plot(data, n_bins, percs=(0.05, 0.95), binrange=None):
    """Plot SBC rank histograms, one untitled panel per parameter, with a grey band.

    Parameters
    ----------
    data : str, dict or DataFrame
        Ranks per parameter (see ``as_rank_frame``).
    n_bins : int
        Number of histogram bins.
    percs : pair of float
        Lower/upper quantiles of the Binomial bin-count band.
    binrange : pair of float, optional
        Forwarded to seaborn. Setting it to the full rank support (e.g.
        ``(0, 200)`` for ranks 0..199) aligns the bins with the band's
        equal-width assumption; ``None`` keeps seaborn's data-range bins.

    Returns
    -------
    seaborn.FacetGrid
    """
    grid = rank_histogram(data, n_bins, percs, 'grey', binrange)
    grid.set_titles("")
    return grid


def ECDF_plot(data, n_bins, percs=(0.05, 0.95), binrange=None):
    """Like ``SBC_plot`` but coloured per parameter, with a green band and panel titles kept.

    Despite its name this draws rank *histograms*, not empirical CDFs; the
    name is kept for existing callers.
    """
    return rank_histogram(data, n_bins, percs, 'green', binrange, hue='variable')

In [None]:
def run_SBC(model):
    """Run one simulation-based-calibration (SBC) replicate and return its rank.

    Samples ``model`` once (a single chain) with fixed data and sums the
    retained draws of ``lt_lambda``. 398 sampling iterations thinned by 2
    should leave 199 draws, so if ``lt_lambda`` is a 0/1 indicator
    (presumably "posterior draw < true lambda" — confirm in
    poisson_SBC.stan) the returned rank lies in 0..199.
    """
    sbc_data = {'N': 500, 'alpha':5, 'beta':1, 'alpha_true':5, 'beta_true':1}
    sampler_settings = dict(
        chains=1,
        iter_warmup=200,
        iter_sampling=398,
        save_warmup=False,
        show_progress=False,
        thin=2,
    )
    fit = model.sample(sbc_data, **sampler_settings)
    lt_lambda_draws = fit.draws_pd('lt_lambda').to_numpy()
    return np.sum(lt_lambda_draws)


# 400 SBC replicates of the same compiled model; the trailing 4 is presumably
# the number of worker threads used by multithreaded_run — confirm in baynes.
rank = multithreaded_run(run_SBC, [model] * 400, 4)


In [None]:
# Thesis figures are written here. Set SBC_FIGURE_DIR to redirect the output
# without editing the notebook; the default keeps the original
# machine-specific location.
figure_dir = os.environ.get("SBC_FIGURE_DIR", "/home/pietro/work/TESI/thesis/figures/ch2/poisson/")
# os.path.join(..., "") guarantees a trailing separator like the default has,
# in case FitPlotter concatenates file names onto output_dir — TODO confirm.
fplot = FitPlotter(output_dir=os.path.join(figure_dir, ""), output_format='.pdf')
hist = SBC_plot({'lambda': rank}, 20)
fplot.new_figure('SBChist', hist.figure)
fplot.resize(10, 6)


In [None]:
# ECDF-difference plot of the SBC ranks against their null distribution.
# Ranks are integers 0..199 (199 retained draws), so the null is the discrete
# uniform randint(0, 200). The continuous uniform(0, 200) CDF sits up to
# 1/200 below the expected ECDF and biases the plotted difference upward.
rank_dist = stats.randint(0, 200)

ax = fplot.new_figure('SBCecdf').subplots()
ax = az.plot_ecdf(rank, cdf=rank_dist.cdf, difference=True, ax=ax, fill_kwargs={'color': 'grey'})
ax.set_xlabel('rank')
ax.set_ylabel('ECDF difference')
# Zero reference across the whole x-axis; unlike plotting 200 zeros it does
# not hard-code the rank count or affect the autoscaled x-limits.
ax.axhline(0, ls='--', color='black', alpha=0.4)
fplot.resize(8, 6)
fplot.save_figures('all')

In [None]:

# On-screen preview of both SBC diagnostics (nothing is saved here).
SBC_plot({'lambda': rank}, 20)
# Ranks are integers 0..199 (199 retained draws), so compare against the
# discrete uniform randint(0, 200); the continuous uniform(0, 200) CDF sits
# up to 1/200 below the expected ECDF and biases the difference upward.
rank_dist = stats.randint(0, 200)
_, ax = plt.subplots()
az.plot_ecdf(rank, cdf=rank_dist.cdf, difference=True, ax=ax, fill_kwargs={'color': 'grey'})
ax.set_xlabel('rank')
ax.set_ylabel('ECDF difference')
# Zero reference spanning the whole x-axis without hard-coding the rank count.
ax.axhline(0, ls='--', color='black', alpha=0.4);


In [None]:
# Pearson chi-square statistic of the binned SBC ranks against a flat histogram.
# The bins span the full rank support [0, 200) (ranks 0..199), so each of the
# 20 bins covers exactly 10 integer ranks. Letting np.histogram take the
# range from the observed min/max misaligns the bins whenever rank 0 or 199
# never occurs, and the flat expectation below then no longer holds.
n_bins = 20
counts, bins = np.histogram(rank, bins=n_bins, range=(0, 200))
assert counts.sum() == len(rank), "some ranks fall outside [0, 200)"
exp = len(rank) / n_bins
chi_test = np.sum((counts - exp) ** 2 / exp)

In [None]:
# Display the chi-square statistic of the binned SBC ranks.
chi_test

In [None]:
# p-value of the chi-square goodness-of-fit test on the 20-bin rank histogram.
# The expected counts are fully specified (no parameters are fitted), so the
# reference distribution has 20 - 1 = 19 degrees of freedom, not 20.
stats.chi2.sf(chi_test, df=20 - 1)

In [None]:
# Re-displays the same chi-square statistic already shown a few cells above;
# this cell is redundant and can be removed.
chi_test

In [None]:
# ECDF-difference plot of the SBC ranks with a dark-red curve.
# Ranks are integers 0..199 (199 retained draws), so the null is the discrete
# uniform randint(0, 200); the continuous uniform(0, 200) CDF would bias the
# difference upward by up to 1/200.
rank_dist = stats.randint(0, 200)
_, ax = plt.subplots()
az.plot_ecdf(rank, cdf=rank_dist.cdf, difference=True, ax=ax, plot_kwargs={'color': 'darkred'}, fill_kwargs={'color': 'grey'})
ax.set_xlabel('rank')
ax.set_ylabel('ECDF difference')
# Zero reference spanning the whole x-axis without hard-coding the rank count.
ax.axhline(0, ls='--', color='black', alpha=0.4);

In [None]:
# Histogram of 10,000 Gamma(shape=20, scale=1) draws. NOTE(review): this
# cell's purpose is not stated in the notebook — document it or remove it.
# A seeded generator makes the figure reproducible on Restart & Run All.
gamma_draws = np.random.default_rng(0).gamma(20, 1, 10000)
fig, ax = plt.subplots()
ax.hist(gamma_draws, bins=50)
ax.set(title='Gamma(20, 1) samples', xlabel='value', ylabel='count');