In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from cmdstanpy import CmdStanModel
from baynes.plotter import FitPlotter
from baynes.model_utils import get_stan_file
from baynes.analysis import *
from baynes.probability import hdi
import pandas as pd
from scipy import stats
import cmdstanpy
import logging
# Only let CmdStanPy report errors; its routine INFO messages would flood the cell outputs.
cmdstanpy.utils.get_logger().setLevel(logging.ERROR)

# Reproducibility: the synthetic data below is drawn from NumPy's global RNG,
# so seed it once here. (Stan's own sampler is seeded separately via `model.sample(seed=...)`.)
SEED = 12345
np.random.seed(SEED)

# Plot styling shared by every figure in the notebook.
sns.set_style('ticks')
sns.set_context("notebook", font_scale=1.6)
plt.rc("axes.spines", top=False, right=False)

# Example 1: fitting a Poisson process
### Generate the data

In [None]:
# Simulate N i.i.d. Poisson counts with a known rate, so the fitted lambda can be
# compared against `lambda_true`.
N = 1000
lambda_true = 6.3
events = np.random.poisson(lambda_true, N)

data_mean = np.mean(events)
data_sd = np.std(events)  # population SD (NumPy default ddof=0)
print('- mean of data: ', data_mean)
print('- sd of data: ', data_sd)

# The data are integers: use one unit-wide bin centred on each integer from min to max.
# (Edges from np.arange(min, max) alone would drop every event equal to the maximum.)
bins = np.arange(events.min(), events.max() + 2) - 0.5
sns.histplot(events, bins=bins)
sns.despine()

### Compile and print the Stan model


In [None]:
# Locate the bundled Poisson model, compile it with threading support, and show its source.
# NOTE(review): 'jN' is passed straight through as a cpp option; confirm it actually
# affects the build (parallel make is not obviously controlled this way).
compile_opts = {'STAN_THREADS': True, 'jN': 4}
stan_file = get_stan_file('poisson.stan')
model = CmdStanModel(stan_file=stan_file, cpp_options=compile_opts)
print(model.code())

In [None]:
# Prior-sensitivity study: for several sample sizes, refit the model under gamma
# priors whose mean sweeps over `mus` while the prior variance stays at `prior_var`.
prior_var = 10
mus = np.linspace(1, 51, 10)
Ns = [50, 100, 200, 500]

plot = FitPlotter(fig_scale=10)
# NOTE(review): `plot` and `post` stay empty unless the commented-out add_fit/append
# lines in the loop are re-enabled; several later cells read them.
post = []

frames = []  # one DataFrame of lambda draws per (sample size, prior mean); concatenated once
for n_obs in Ns:
    # Local names so this loop does not overwrite the notebook-level `events` / `N`
    # that later cells rely on.
    events_n = np.random.poisson(lambda_true, n_obs)
    data_n = {'N': len(events_n), 'y': events_n, 'alpha': 5, 'beta': 1, 'prior': 0}
    for mu in mus:
        # alpha = mu^2/var and beta = mu/var give alpha/beta = mu and alpha/beta^2 = var,
        # assuming poisson.stan uses a shape/rate gamma(alpha, beta) prior -- confirm.
        data_n['beta'] = mu / prior_var
        data_n['alpha'] = mu**2 / prior_var

        fit = model.sample(data_n,
                           chains=1,
                           iter_warmup=200,
                           iter_sampling=300,
                           save_warmup=False,
                           show_progress=False)
        # plot.add_fit(fit, fit_title=str(mu))
        # post.append(fit.stan_variable('lambda'))

        draws = fit.draws_pd(['lambda'])
        draws['prior mean'] = mu
        draws['N'] = n_obs
        frames.append(draws)

# ignore_index: every fit's draws share the same index labels, so build a fresh index.
df = pd.concat(frames, ignore_index=True)

In [None]:
# Posterior lambda versus prior mean, one line per sample size N. The band comes from
# the `hdi` callable (presumably a highest-density interval of the draws -- confirm).
ax = plot.new_figure('sens').subplots()
sns.lineplot(df, x='prior mean', y='lambda', hue='N', style='N', errorbar=hdi, ax=ax)
# Reference: the rate used to simulate the data (unlabelled, so it stays out of the legend).
ax.axhline(lambda_true, color='grey', linestyle='--', linewidth=1)
ax.set(xlabel='prior mean', ylabel=r'posterior $\lambda$')
ax.legend(bbox_to_anchor=(1.1, 0.7), facecolor='white', edgecolor='white', title='N data');

In [None]:
# Re-display the current figure held by `plot` (presumably the 'sens' figure created above -- confirm).
plot.get_current_figure()

In [None]:
# Ridge plot of the fits whose titles match 'mu'.
# NOTE(review): the sensitivity loop's (commented-out) add_fit titled fits str(mu), e.g. '1.0',
# so this filter may match nothing -- and `plot` holds no fits unless add_fit is re-enabled.
mu_titles = plot.get_fit_titles('mu')
plot.ridgeplot(fit_titles=mu_titles, height=0.7)

In [None]:
# Lambda draws for the fits selected by get_fit_titles('ad'), reshaped to long format.
ad_titles = plot.get_fit_titles('ad')
ad_draws = plot.draws_df(fit_titles=ad_titles, parameters='lambda')
ad_draws.melt()

In [None]:
# Lambda draws of every registered fit, summarised per fit title along the x axis.
lambda_by_fit = plot.draws_df(fit_titles='all', parameters=['lambda'])
sns.lineplot(lambda_by_fit, x='fit', y='lambda')

In [None]:
# Inspect the raw lambda draws of all fits registered on `plot`.
plot.draws_df(parameters=['lambda'], fit_titles='all')

In [None]:
# NOTE(review): `mus * len(post)` multiplies the prior means element-wise (it does not repeat
# them), and only `y` is mapped, so the 'x' entry is never plotted. `post` is also empty
# unless the append in the sensitivity loop is re-enabled. Revisit the intent of this cell.
sns.lineplot(data={'x': mus * len(post), 'y': post}, y='y')

In [None]:
# Shape of the collected posterior draws (number of fits x draws per fit, if non-ragged).
np.shape(post)

In [None]:
# Quick look at a gamma distribution with shape 100 and scale 1 (10,000 samples).
gamma_draws = np.random.gamma(shape=100, scale=1, size=10_000)
plt.hist(gamma_draws, bins=50)

In [None]:
# Sweep the `alpha` prior hyperparameter (beta held at 1), refit, and compare the
# lambda posteriors side by side in a ridge plot.
data = {'N': len(events), 'y': events, 'alpha': 5, 'beta': 1, 'prior': 0}
alpha_sweep_kwargs = dict(chains=4,
                          iter_warmup=200,
                          iter_sampling=300,
                          save_warmup=False,
                          show_progress=False)

plot = FitPlotter()
for mu in [1, 10, 20, 40, 60, 80, 100]:
    data['alpha'] = mu
    fit = model.sample(data, **alpha_sweep_kwargs)
    plot.add_fit(fit, fit_title=f"$\\alpha=${mu}")
plot.get_fit_titles()
plot.ridgeplot(fit_titles=plot.get_fit_titles('alpha'))

In [None]:
# Names of the fits registered on `plot` by the alpha sweep.
[*plot.fits.keys()]

In [None]:
# Full analysis on a fresh, smaller data set with a gamma(alpha=3, beta=0.1) prior.
N = 100
events = np.random.poisson(lambda_true, N)
data = {'N': len(events), 'y': events, 'alpha': 3, 'beta': 0.1, 'prior': 1}
sampler_kwargs = {
    'chains': 4,
    'iter_warmup': 500,
    'iter_sampling': 2000,
    'save_warmup': True,
}
# One histogram bin per integer value observed in *this* data set. The bin count used to
# come from `bins`, which was computed for the earlier N=1000 data set.
n_bins = int(events.max() - events.min()) + 1
fplot = FitPlotter(fig_scale=7)
fit = standard_analysis(model, data, fplot, sampler_kwargs=sampler_kwargs, data_key='y', rep_key='y_rep',
                        lines=True, legend=False, n_bins=n_bins)

# standard_analysis is expected to register a prior-only fit under 'fit_prior' -- confirm.
fit_prior = fplot.fits['fit_prior']
# Prior and posterior draws come from separate fits, so points are paired only by draw
# index; the marginals are the informative part of this plot.
joint = sns.jointplot({'posterior': fit.stan_variable('lambda'), 'prior': fit_prior.stan_variable('lambda')},
                      x='prior', y='posterior', kind='reg', scatter_kws={'s': 0.1}, robust=True)
fig = fplot.new_figure('static', joint.figure)

In [None]:
# Joint plots of lambda against the replicated-data SD and mean.
# NOTE(review): `joint` is reassigned on each pass, so only the last figure (mean_y_rep)
# is registered with fplot below.
for rep_stat in ('sd_y_rep', 'mean_y_rep'):
    joint = sns.jointplot(fit.draws_pd(['lambda', rep_stat]), x='lambda', y=rep_stat,
                          kind='reg', scatter_kws={'s': 0.1})

fig = fplot.new_figure('static', joint.figure)

In [None]:
# Pairwise relationships between lambda and the replicated-data summary statistics.
pair_params = ['lambda', 'sd_y_rep', 'mean_y_rep']
fplot.pair_grid(pair_params)

In [None]:
# Draws of lambda and mean_y_rep from every fit registered on fplot.
fplot.draws_df(parameters=['lambda', 'mean_y_rep'], fit_titles='all')

In [None]:
# Prior-vs-posterior joint plot of lambda, with the prior axis clipped to [0, 15].
fit_prior = fplot.fits['fit_prior']
lambda_draws = {
    'posterior': fit.stan_variable('lambda'),
    'prior': fit_prior.stan_variable('lambda'),
}
joint = sns.jointplot(lambda_draws, x='prior', y='posterior', xlim=[0, 15],
                      kind='reg', scatter_kws={'s': 0.1})
fig = fplot.new_figure('static', joint.figure)