# Import packages
Use kernel "ABM_env" -- see README.

This file processes the results from running batch1d_posterior_estimation.py. The results are included in the data repository -- see README.txt.

In [None]:
from __future__ import division
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import math
import time
from tqdm import tqdm
import cProfile
import pickle
from scipy.stats import gaussian_kde
from scipy import interpolate
import pandas as pd
from math import sqrt
import scipy.stats as stats
from matplotlib.pyplot import imshow
import matplotlib.colors as mcolors
from multiprocessing import Pool
import multiprocessing as mp
import time
import math
import matplotlib.ticker as ticker

from ABM import SEIR_multiple_pops
from amcmc import ammcmc
import os
from CalibrationMethod1_2D_methods import *

## Load results


In [None]:
# Read the per-sample parameter table. The first CSV column is consumed as the
# index; the remaining columns are taken, in order, as mobility, jumping
# probability and random seed.
parameter_matrix = pd.read_csv(
    './Data/Test Data/Two-parameter case/variable_parameter_values_Two-Pop-NEW-COMBINED-TEST.csv',
    index_col=0,
).to_numpy()
mobilities, jumping_probs, random_seeds = (
    parameter_matrix[:, column] for column in range(3)
)
num_of_samples = len(mobilities)
    

## Visualize results:
Set sample_id = 1500 to reproduce Fig. 10.


In [None]:
#--------------Specify batch posterior estimation folder, nburn, sample_id-----------------
MCMC_res_folder = './Data/MCMC_results/Two-Pop/'
nburn = 0          # number of leading chain rows to discard before plotting
sample_id = 1500   # index of the sample to visualize

#------------- Load in mobility and jumping probability values ------------:
# Re-read the parameter table here so this cell does not rely on an earlier one.
parameter_matrix = pd.read_csv(
    './Data/Test Data/Two-parameter case/variable_parameter_values_Two-Pop-NEW-COMBINED-TEST.csv',
    index_col=0,
).to_numpy()
mobilities, jumping_probs, random_seeds = (
    parameter_matrix[:, column] for column in range(3)
)
num_of_samples = len(mobilities)

completed_samples, incomplete_samples = [], []
# One slot per sample, filled in by the range check at the end of the plotting cell.
within_95_conf_int_samples = np.zeros(1666)
within_50_conf_int_samples = np.zeros(1666)


print('-------------------------')
print(f'-------Sample {sample_id}-------')
print('-------------------------')

# True parameter values used to generate this sample.
actual_mob, actual_jp = mobilities[sample_id], jumping_probs[sample_id]


# -------- set up colormaps -----------
# Two-colour linear colormap; the endpoints are given as 8-bit RGB triples
# rescaled to the 0-1 range.
color_1 = tuple(channel / 255 for channel in (237, 248, 251))
color_2 = tuple(channel / 255 for channel in (35, 139, 69))
colors = [color_1, color_2]
my_colormap = mcolors.LinearSegmentedColormap.from_list("CustomColormap", colors)

#-------------------------------------------------
#-----------Load and plot MCMC results------------
#-------------------------------------------------
# For the chosen `sample_id` this cell loads the pickled MCMC output, computes
# marginal percentile intervals and the KDE-based 95% / 50% regions, and draws
#   (1) a scatter of the chain,
#   (2) the posterior with both regions outlined,
#   (3) the posterior alone,
#   (4) the two marginal densities with their percentile intervals.
# It then records whether the true parameters fall inside each region.
#
# Colour convention used in every plot of this cell:
#   95% -> dark red (0.9, 0, 0),   50% -> light red (0.9, 0.6, 0.6).
#
# Any failure is reported instead of raised so the notebook keeps running; the
# exception type and message are printed so the cause is visible.
try:
    #Load MCMC results:
    MCMC_file = MCMC_res_folder+'AMCMC_sample_ind_'+str(sample_id)+'.pickle'
    # NOTE: unpickling can execute arbitrary code -- only load result files
    # that come from a trusted source.
    with open(MCMC_file, "rb") as data_file:
        sol = pickle.load(data_file)

    # Post burn-in draws: column 0 is mobility, column 1 is jumping probability.
    mob_draws = sol['chain'][nburn:,0]
    jp_draws = sol['chain'][nburn:,1]

    # -------- calculate marginal CI boundaries -----------

    #Marginal 95% interval (2.5th-97.5th percentile) for mobility
    lower_95_CI_mob, upper_95_CI_mob = np.percentile(mob_draws, [2.5, 97.5])

    #Marginal 95% interval for jumping prob
    lower_95_CI_jp, upper_95_CI_jp = np.percentile(jp_draws, [2.5, 97.5])

    #Marginal 50% interval (25th-75th percentile) for mobility
    lower_50_CI_mob, upper_50_CI_mob = np.percentile(mob_draws, [25, 75])

    #Marginal 50% interval for jumping prob
    lower_50_CI_jp, upper_50_CI_jp = np.percentile(jp_draws, [25, 75])

    #Calculate 50% and 95% confidence interval bounds:
    Z, inside_95, inside_hollow_95, Z_renormalized_95, Z_marked_95, x_flat, y_flat = MCMC_KDE(sol, nburn = nburn, target = 0.95)
    Z, inside_50, inside_hollow_50, Z_renormalized_50, Z_marked_50, x_flat, y_flat = MCMC_KDE(sol, nburn = nburn, target = 0.5)

    #----------------Plot scatter plot of chain------------------
    plt.figure(dpi = 500, figsize = (3.5,3.5))
    plt.scatter(jp_draws, mob_draws, s=1, color = (102/255,194/255,164/255), label = 'MCMC chain')
    plt.ylim(0.005,0.025)
    plt.xlim(0,0.001)
    plt.scatter(actual_jp, actual_mob, color = 'blue', marker = '*', label = 'True parameter')
    plt.ylabel('Mobility')
    plt.xlabel('Jumping probability')
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
    plt.legend()
    #     plt.savefig('figs_final/2d_MCMC_results/chain-sample'+str(sample_id)+'.png',bbox_inches='tight')
    plt.show()

    #----------------Posterior plot with 95 and 50% conf intervals (WITH colorbar) ------------------
    # Two-level colormaps (N=2): fully transparent at the low end, solid at the
    # high end, so only the marked cells of each mask are painted.
    colors = [(0.9,0,0,c) for c in np.linspace(0,1,100)]
    cmapred = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=2)      # dark red
    colors = [(0.9,0.6,0.6,c) for c in np.linspace(0,1,100)]
    cmapblue = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=2)     # light red, despite the name

    fig, ax = plt.subplots(dpi = 300)
    ax_main = plt.subplot()
    pcm = ax_main.pcolormesh(y_flat, x_flat, Z.T, cmap=my_colormap)
    # The region colours must match the legend proxies below and the marginal
    # plots: 95% in dark red, 50% in light red (they were previously swapped).
    ax_main.pcolormesh(y_flat, x_flat, inside_hollow_95.T,shading='auto',cmap=cmapred,label='95% CI')
    ax_main.pcolormesh(y_flat, x_flat, inside_hollow_50.T,shading='auto',cmap=cmapblue,label='50% CI')
    plt.xlabel('Jumping probability')
    plt.ylabel('Mobility')
    plt.scatter(actual_jp, actual_mob, c='blue', label = 'True parameter', marker ='*')
    # Empty line plots act purely as legend entries for the two regions.
    d = []
    plt.plot(d,d,linewidth=2,color = (0.9,0,0), label='95% CI')
    plt.plot(d,d,linewidth=2,color = (0.9,0.6,0.6), label='50% CI')
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
    plt.legend()

    ax_cbar = plt.subplot()
    cax = make_square_axes_with_colorbar(ax=ax_cbar, size=0.15, pad=0.1)
    cbar = fig.colorbar(pcm, cax=cax)
    cbar.set_label('Probability density')

    # plt.savefig('figs_final/2d_MCMC_results/posterior-with-colorbar-sample'+str(sample_id)+'.png',bbox_inches='tight')

    plt.show()

    #----------------Posterior plot with no confidence intervals (WITH colorbar) ------------------
    fig, ax = plt.subplots(dpi = 500, figsize = (3.5,3.5))
    ax_main = plt.subplot()
    pcm = ax_main.pcolormesh(y_flat, x_flat, Z.T, cmap=my_colormap)
    plt.xlabel('Jumping probability')
    plt.ylabel('Mobility')
    plt.ylim(0.005,0.025)
    plt.xlim(0,0.001)
    plt.scatter(actual_jp, actual_mob, c='blue', label = 'True parameter', marker ='*')
    plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))

    exp_value = 5 #MANUALLY SET THE EXP VALUE FOR SCI NOTATION ON COLORBAR

    def fmt(x, pos):
        """Format a colorbar tick as its mantissa relative to 10**exp_value."""
        # Rounded to 2 decimals, as in the brute-force plots, to avoid
        # floating-point noise in the tick labels.
        return str(round(x/(10**exp_value), 2))

    ax_cbar = plt.subplot()
    # Build a colorbar axes on THIS figure. Reusing `cax` from the previous
    # figure would attach the colorbar to that (already shown) figure instead.
    cax = make_square_axes_with_colorbar(ax=ax_cbar, size=0.15, pad=0.1)
    cbar = fig.colorbar(pcm, cax=cax, format=ticker.FuncFormatter(fmt))
    cbar.set_label('Probability density')

    # Exponent annotation above the colorbar.
    cbar.ax.text(1.05, 1, f"1e{exp_value}", transform=cbar.ax.transAxes, va='bottom', ha='left')

    #     plt.savefig('figs_final/2d_MCMC_results/posterior-no-CI_sample'+str(sample_id)+'.png',bbox_inches='tight')

    plt.show()

    # ----------------- Marginal plots --------------------
    # Z is indexed [y (jumping prob), x (mobility)], matching the Z.T passed to
    # pcolormesh above, so integrating over axis 0 leaves the mobility marginal.
    mob_marginal = np.trapz(Z, y_flat, axis = 0)
    jp_marginal = np.trapz(Z, x_flat, axis = 1)

    # Sanity check: both marginals should integrate to approximately 1.
    print('mob marginal integral', np.trapz(mob_marginal, x_flat))
    print('jp marginal integral', np.trapz(jp_marginal, y_flat))

    #jumping probability marginal:
    plt.figure(figsize = (3.5, 1), dpi = 500)
    plt.plot(y_flat, jp_marginal, linewidth=2, color = color_2)
    plt.xlim(0,0.001)
    plt.axvline(actual_jp, linewidth=2, color = "blue", linestyle = "--") #true param
    plt.axvline(upper_95_CI_jp, linewidth=2, color = (0.9,0,0))
    plt.axvline(lower_95_CI_jp, linewidth=2, color = (0.9,0,0))
    plt.axvline(upper_50_CI_jp, linewidth=2, color = (0.9,0.6,0.6))
    plt.axvline(lower_50_CI_jp, linewidth=2, color = (0.9,0.6,0.6))
    #     plt.savefig('figs_final/2d_MCMC_results/JP_marginal_posterior-sample'+str(sample_id)+'.png',bbox_inches='tight')
    plt.show()

    #mobility marginal:
    plt.figure(figsize = (3.5, 1), dpi = 500)
    plt.scatter([], [], c='blue', marker="*", label='True parameter')
    plt.axvline(actual_mob, linewidth=2, color = "blue", label='True parameter', linestyle = "--") #true param
    plt.plot(x_flat, mob_marginal, linewidth=2, color = color_2, label = 'Marginal posterior')
    # NOTE(review): the x-axis limits are given high-to-low, which reverses the
    # axis -- presumably so the panel can be rotated next to the 2D plot; confirm.
    plt.xlim(0.025,0.005)
    plt.axvline(upper_95_CI_mob, linewidth=2, color = (0.9,0,0), label='Marginal 95% CI')
    plt.axvline(lower_95_CI_mob, linewidth=2, color = (0.9,0,0))
    plt.axvline(upper_50_CI_mob, linewidth=2, color = (0.9,0.6,0.6), label='Marginal 50% CI')
    plt.axvline(lower_50_CI_mob, linewidth=2, color = (0.9,0.6,0.6))
    plt.xticks(np.linspace(0.0050,0.025,9))
    plt.legend(bbox_to_anchor=(1.3, 0.95))
    #     plt.savefig('figs_final/2d_MCMC_results/MOB_marginal_posterior-sample'+str(sample_id)+'.png',bbox_inches='tight')

    plt.show()

    # ----------------- Range checking and print outs ------------------
    # Nearest-neighbour lookup of the region masks at the true parameter point.
    interp_95=scipy.interpolate.RegularGridInterpolator((y_flat,x_flat), inside_95, method='nearest')
    interp_50=scipy.interpolate.RegularGridInterpolator((y_flat,x_flat), inside_50, method='nearest')

    within_95_conf_int_samples[sample_id] = interp_95((actual_jp, actual_mob))
    within_50_conf_int_samples[sample_id] = interp_50((actual_jp, actual_mob))

    print('95% confidence interval:', within_95_conf_int_samples[sample_id])
    print('50% confidence interval:', within_50_conf_int_samples[sample_id])

except Exception as err:
    # Previously a bare `except:` -- that also swallowed KeyboardInterrupt and
    # gave no hint of what went wrong.
    print('Issue with MCMC')
    print(type(err).__name__ + ':', err)

In [None]:

#------------------------------------------------------------------
#-----------Load and plot brute force posterior results------------
#------------------------------------------------------------------
# Reads the grid of log-probabilities for `sample_id`, plots the exponentiated
# values as a likelihood surface, then normalizes them to unit integral and
# plots the result as a posterior density.
import matplotlib.ticker as ticker

file_brute = './Data/Brute_force_posterior_estimation/Two-Pop/IND'+str(sample_id)+'LOG_PROBS.txt'  #Combined first and second batch of MCMC

# Columns: 0 -> x (plotted as mobility), 1 -> y (plotted as jumping
# probability), 2 -> log-probability.
brute_posterior = pd.read_csv(file_brute, header=None, sep=" ").to_numpy()
x_flat = np.unique(brute_posterior[:,0])
# The last unique y value is dropped and the table truncated to a full
# (n_y, n_x) block -- presumably the final grid row is incomplete; TODO confirm.
y_flat = np.unique(brute_posterior[:,1])[:-1]
brute_posterior_truncate = brute_posterior[:y_flat.shape[0]*x_flat.shape[0]]
Z = np.exp((brute_posterior_truncate[:,2]).reshape((y_flat.shape[0], x_flat.shape[0])))

#----------------Plot with pcolormesh: likelihood (WITH colorbar)------------------
fig, ax = plt.subplots(dpi = 500,  figsize = (4.25,4.25))

ax_main = plt.subplot()
pcm = ax_main.pcolormesh(y_flat, x_flat, Z.T, cmap=my_colormap)
ax_main.set_xlabel('Jumping probability')
ax_main.set_ylabel('Mobility')
plt.ylim(0.005,0.025)
plt.xlim(0,0.001)
ax_main.scatter(jumping_probs[sample_id], mobilities[sample_id], c='blue', marker="*", label='True parameter')
ax_main.legend()
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))

#------ colorbar with scientific notation on tick marks ------
exp_value = -4 #MANUALLY SET THE EXP VALUE FOR SCI NOTATION ON COLORBAR

def fmt(x, pos, scale=10**exp_value):
    """Format a colorbar tick as its mantissa relative to `scale`."""
    # `scale` is bound when the function is defined, so a later change of
    # `exp_value` in this cell cannot alter this figure's tick labels.
    return str(round(x/scale, 2))

ax_cbar = plt.subplot()
cax = make_square_axes_with_colorbar(ax=ax_cbar, size=0.15, pad=0.1)
cbar = fig.colorbar(pcm, cax=cax, format=ticker.FuncFormatter(fmt))
cbar.set_label('Likelihood')

# Exponent annotation above the colorbar.
cbar.ax.text(1.05, 1, f"1e{exp_value}", transform=cbar.ax.transAxes, va='bottom', ha='left')

# Adjust layout to prevent overlap of labels and colorbar
plt.tight_layout()
# plt.savefig('figs_final/2d_MCMC_results/likelihood-with-colorbar-sample'+str(sample_id)+'.png',bbox_inches='tight')

plt.show()

#----------------Plot with pcolormesh: posterior (WITH colorbar)------------------

#normalize to get posterior from likelihood (trapezoidal rule over y, then x):
integral = np.trapz(Z, y_flat, axis = 0)
integral = np.trapz(integral, x_flat, axis = 0)
print(integral)
Z = Z/integral

fig, ax = plt.subplots(dpi = 500,  figsize = (4.25,4.25))

ax_main = plt.subplot()
pcm = ax_main.pcolormesh(y_flat, x_flat, Z.T, cmap=my_colormap)
ax_main.set_xlabel('Jumping probability')
ax_main.set_ylabel('Mobility')
plt.ylim(0.005,0.025)
plt.xlim(0,0.001)
ax_main.scatter(jumping_probs[sample_id], mobilities[sample_id], c='blue', marker="*", label='True parameter')
ax_main.legend()
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))

#------ colorbar with scientific notation on tick marks ------
exp_value = 5 #MANUALLY SET THE EXP VALUE FOR SCI NOTATION ON COLORBAR

def fmt(x, pos, scale=10**exp_value):
    """Format a colorbar tick as its mantissa relative to `scale`."""
    return str(round(x/scale, 2))

ax_cbar = plt.subplot()
# Build a colorbar axes on THIS figure. Reusing `cax` from the likelihood
# figure would attach the colorbar to that (already shown) figure instead.
cax = make_square_axes_with_colorbar(ax=ax_cbar, size=0.15, pad=0.1)
cbar = fig.colorbar(pcm, cax=cax, format=ticker.FuncFormatter(fmt))
cbar.set_label('Probability density')

# Exponent annotation above the colorbar.
cbar.ax.text(1.05, 1, f"1e{exp_value}", transform=cbar.ax.transAxes, va='bottom', ha='left')

# Adjust layout to prevent overlap of labels and colorbar
plt.tight_layout()
# plt.savefig('figs_final/2d_MCMC_results/grid-sampled-posterior-with-colorbar-sample'+str(sample_id)+'.png',bbox_inches='tight')

plt.show()

## Calculate ESS values

In [None]:
# Check ESS values (first pass):
MCMC_results_folder = "./Data/MCMC_results/Two-Pop"
nburn = 25000   # first chain row used for the ESS estimate
end = 75000     # one past the last chain row used


def check_ESS(sample_id):
    """Estimate the effective sample size of both parameters for one sample.

    Loads ``AMCMC_sample_ind_<sample_id>.pickle`` from ``MCMC_results_folder``,
    takes rows ``nburn:end`` of its ``'chain'`` entry, and passes the
    auto-correlation of each column to ``compute_effective_sample_size``.

    Parameters
    ----------
    sample_id : int
        Index of the sample whose MCMC result file is read.

    Returns
    -------
    list
        ``[mob_ESS, mob_ESS_went_to_zero, jp_ESS, jp_ESS_went_to_zero]``:
        column 0 is treated as mobility and column 1 as jumping probability.
        The ``*_went_to_zero`` entries are the second value returned by
        ``compute_effective_sample_size``.
    """
    MCMC_file = MCMC_results_folder+"/AMCMC_sample_ind_"+str(sample_id)+".pickle"
    # `with` guarantees the handle is closed even if unpickling fails.
    with open(MCMC_file, "rb") as data_file:
        sol = pickle.load(data_file)

    n = end-nburn
    # 200 is the second argument of compute_group_auto_corr -- presumably the
    # maximum lag; see that helper for its exact meaning.
    auto_corr = compute_group_auto_corr(sol['chain'][nburn:end], 200)
    mob_ESS, mob_ESS_went_to_zero = compute_effective_sample_size(n,auto_corr[:,0])
    jp_ESS, jp_ESS_went_to_zero = compute_effective_sample_size(n,auto_corr[:,1])

    print('*')  # one progress mark per processed sample

    return [mob_ESS, mob_ESS_went_to_zero, jp_ESS, jp_ESS_went_to_zero]


total = 1666
result = []
# Kept as an explicit loop so that `result` retains the partial output if a
# sample fails part-way through.
for i in range(0,total):
    result.append(check_ESS(i))

# num_workers = mp.cpu_count()
# print(num_workers)
# start_time = time.perf_counter()
# with Pool(num_workers) as pool:
#     result = pool.map(check_ESS, range(0,total))
# finish_time = time.perf_counter()
# print("Program finished in {} seconds - using multiprocessing".format(finish_time-start_time))
# print("---")

In [None]:
# Ask for confirmation before writing, so that an existing ESS.pickle is not
# overwritten by accident (e.g. with zeros).
user_input = input("Do you want to save the data? Type 'yes' or 'no': ")
if user_input.lower() != 'yes':
    print('DID NOT SAVE')
else:
    print('Saving...')
    with open(MCMC_results_folder+'/ESS.pickle', 'wb') as handle:
        pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Load ESS results (first pass) from disk.
# `with` closes the file even if unpickling raises.
with open(MCMC_results_folder+'/ESS.pickle', "rb") as data_file: # Combined batch of MCMC
    result = pickle.load(data_file)

In [None]:
#Update ESS results for samples that didn't have auto-correlation go to zero on the first pass:
# Entries whose two "went to zero" flags are both set are copied unchanged;
# every other sample is recomputed with 3000 (instead of 200) as the second
# argument of compute_group_auto_corr.

updated_result = []
n = end-nburn  # number of chain rows used for the estimate (loop-invariant)

for i, first_pass in enumerate(tqdm(result)):
    if first_pass[1] and first_pass[3]:
        updated_result.append(first_pass)
        continue

    MCMC_file = MCMC_results_folder+"/AMCMC_sample_ind_"+str(i)+".pickle"
    # `with` closes the file even if unpickling raises.
    with open(MCMC_file, "rb") as data_file:
        sol = pickle.load(data_file)

    auto_corr = compute_group_auto_corr(sol['chain'][nburn:end], 3000)
    mob_ESS, mob_ESS_went_to_zero = compute_effective_sample_size(n,auto_corr[:,0])
    jp_ESS, jp_ESS_went_to_zero = compute_effective_sample_size(n,auto_corr[:,1])
    updated_result.append([mob_ESS, mob_ESS_went_to_zero, jp_ESS, jp_ESS_went_to_zero])

In [None]:
# Save refined ESS results -- only after explicit confirmation, so that an
# existing refined_ESS.pickle is not overwritten by accident.
user_input = input("Do you want to save the data? Type 'yes' or 'no': ")
if user_input.lower() != 'yes':
    print('DID NOT SAVE')
else:
    print('Saving...')
    with open(MCMC_results_folder+'/refined_ESS.pickle', 'wb') as handle:
        pickle.dump(updated_result, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Load refined ESS results from disk.
MCMC_results_folder = "./Data/MCMC_results/Two-Pop"
# `with` closes the file even if unpickling raises.
with open(MCMC_results_folder+'/refined_ESS.pickle', "rb") as data_file: # Combined batch of MCMC
    updated_result = pickle.load(data_file)

ESS_values = updated_result

## Simulation based calibration

### SBC, with rank approximation for thinned chains shorter than L
Reproduces Fig. 11.

In [None]:
#--------------Specify batch posterior estimation folder-----------------
# NOTE(review): MCMC_res_folder is not read anywhere in this cell and its
# "MCMC_Results" casing differs from the folder the loop actually uses.
MCMC_res_folder = "./Data/MCMC_Results/Two-Pop/" #Combined first and second batch of MCMC
# Folder the loop below reads from; set here so the cell does not silently
# depend on an earlier cell for it.
MCMC_results_folder = "./Data/MCMC_results/Two-Pop"
nburn = 25000
end = 75000

#--------------Pull results from folders---------------
brute_force_folder = "./Data/Brute_force_posterior_estimation/Two-Pop/" #Combined first and second batch of MCMC

# Simulation-based calibration: for every sample, thin the chain by its
# effective sample size, keep at most L draws, and record the rank of the true
# parameter among them (an integer in 0..L).
L = 50
ranks_mob = []
ranks_jp = []
ESS_too_small = []          # samples whose thinned chain has fewer than L draws
ESS_min_value = []
thinned_chain_length = []

for sample_id in tqdm(range(1666)):

    #---------- Check MCMC results -----------

    MCMC_file = MCMC_results_folder+"/AMCMC_sample_ind_"+str(sample_id)+".pickle"
    # `with` closes the file even if unpickling raises.
    with open(MCMC_file, "rb") as data_file:
        sol = pickle.load(data_file)

    mob_ESS = updated_result[sample_id][0]
    jp_ESS = updated_result[sample_id][2]
    ESS = min(mob_ESS, jp_ESS) #work with the smaller effective sample size
    ESS_min_value.append(ESS)

    # Thin and truncate the chain based on effective sample size and chosen L.
    # The step is clamped to >= 1: an ESS larger than the number of retained
    # rows would otherwise give a step of 0, which is an invalid slice.
    thin_step = max(1, int((end-nburn)/ESS))
    chain_thin_trunc = sol['chain'][nburn:end:thin_step,:]
    chain_thin_trunc = chain_thin_trunc[:L, :]
    n_thinned = chain_thin_trunc.shape[0]

    mob_value = mobilities[sample_id]
    jp_value = jumping_probs[sample_id]

    # Rank = number of thinned draws below the true value. The columns are
    # already 1-D; squeezing them would turn a single-draw chain into a 0-d
    # array, which np.sort rejects.
    rank_mob = np.searchsorted(np.sort(chain_thin_trunc[:,0]), mob_value)
    rank_jp = np.searchsorted(np.sort(chain_thin_trunc[:,1]), jp_value)

    if n_thinned < L:
        # Too few independent draws: rescale the empirical quantile to the
        # 0..L range and round to the closest integer rank.
        quantile_mob = rank_mob/n_thinned
        closest_rank_mob = round(quantile_mob*(L))
        quantile_jp = rank_jp/n_thinned
        closest_rank_jp = round(quantile_jp*(L))
        ESS_too_small.append(sample_id)
        ranks_mob.append(closest_rank_mob)
        ranks_jp.append(closest_rank_jp)
        print(quantile_jp)
        print('shape',np.unique(chain_thin_trunc).shape[0])
    else:
        ranks_mob.append(rank_mob)
        ranks_jp.append(rank_jp)

    thinned_chain_length.append(n_thinned)

In [None]:
# One rank histogram per parameter. Ranks are integers in 0..L, so the bin
# edges are 0, 1, ..., L+1 and are shifted by 0.5 when drawn so that each bar
# is centred on its integer rank.
for rank_values, axis_label in (
    (ranks_mob, 'Mobility rank'),
    (ranks_jp, 'Jumping probability rank '),
):
    plt.figure(dpi = 200, figsize = (3.5,1.75))
    counts, bins = np.histogram(rank_values, bins=np.arange(L+2))
    plt.stairs(counts, bins-0.5, fill = True)
    plt.xlabel(axis_label)
    plt.ylabel('Frequency')

### SBC, with no rank approximation, L = 8
Reproduces Fig. S1 in S1 Appendix.

In [None]:
#--------------Specify batch posterior estimation folder-----------------
# NOTE(review): MCMC_res_folder is not read anywhere in this cell and its
# "MCMC_Results" casing differs from the folder the loop actually uses.
MCMC_res_folder = "./Data/MCMC_Results/Two-Pop/" #Combined first and second batch of MCMC
# Folder the loop below reads from; set here so the cell does not silently
# depend on an earlier cell for it.
MCMC_results_folder = "./Data/MCMC_results/Two-Pop"
nburn = 25000
end = 75000

#--------------Pull results from folders---------------
brute_force_folder = "./Data/Brute_force_posterior_estimation/Two-Pop/" #Combined first and second batch of MCMC

# Simulation-based calibration: for every sample, thin the chain by its
# effective sample size, keep at most L draws, and record the rank of the true
# parameter among them (an integer in 0..L).
L = 8
ranks_mob = []
ranks_jp = []
ESS_too_small = []          # samples whose thinned chain has fewer than L draws
ESS_min_value = []
thinned_chain_length = []

for sample_id in tqdm(range(1666)):

    #---------- Check MCMC results -----------

    MCMC_file = MCMC_results_folder+"/AMCMC_sample_ind_"+str(sample_id)+".pickle"
    # `with` closes the file even if unpickling raises.
    with open(MCMC_file, "rb") as data_file:
        sol = pickle.load(data_file)

    mob_ESS = updated_result[sample_id][0]
    jp_ESS = updated_result[sample_id][2]
    ESS = min(mob_ESS, jp_ESS) #work with the smaller effective sample size
    ESS_min_value.append(ESS)

    # Thin and truncate the chain based on effective sample size and chosen L.
    # The step is clamped to >= 1: an ESS larger than the number of retained
    # rows would otherwise give a step of 0, which is an invalid slice.
    thin_step = max(1, int((end-nburn)/ESS))
    chain_thin_trunc = sol['chain'][nburn:end:thin_step,:]
    chain_thin_trunc = chain_thin_trunc[:L, :]
    n_thinned = chain_thin_trunc.shape[0]

    mob_value = mobilities[sample_id]
    jp_value = jumping_probs[sample_id]

    # Rank = number of thinned draws below the true value. The columns are
    # already 1-D; squeezing them would turn a single-draw chain into a 0-d
    # array, which np.sort rejects.
    rank_mob = np.searchsorted(np.sort(chain_thin_trunc[:,0]), mob_value)
    rank_jp = np.searchsorted(np.sort(chain_thin_trunc[:,1]), jp_value)

    if n_thinned < L:
        # Too few independent draws: rescale the empirical quantile to the
        # 0..L range and round to the closest integer rank.
        quantile_mob = rank_mob/n_thinned
        closest_rank_mob = round(quantile_mob*(L))
        quantile_jp = rank_jp/n_thinned
        closest_rank_jp = round(quantile_jp*(L))
        ESS_too_small.append(sample_id)
        ranks_mob.append(closest_rank_mob)
        ranks_jp.append(closest_rank_jp)
        print(quantile_jp)
        print('shape',np.unique(chain_thin_trunc).shape[0])
    else:
        ranks_mob.append(rank_mob)
        ranks_jp.append(rank_jp)

    thinned_chain_length.append(n_thinned)

In [None]:
# Chi-squared test of the rank histograms against a uniform distribution.

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

# Ranks are integers in 0..L, so there are L+1 possible values. Explicit
# integer bin edges (0, 1, ..., L+1) give exactly one bin per rank. Asking
# np.histogram for "9 equal-width bins" instead spans only the observed
# min..max, which misaligns the bins whenever rank 0 or rank L never occurs.
n_bins = L + 1
rank_edges = np.arange(n_bins + 1)

vals = ranks_mob
counts, bins = np.histogram(vals, bins=rank_edges)
# With no expected frequencies given, chisquare tests against equal counts.
print('p-value:', stats.chisquare(counts).pvalue)

vals = ranks_jp
counts, bins = np.histogram(vals, bins=rank_edges)
print('p-value:', stats.chisquare(counts).pvalue)

In [None]:
# One rank histogram per parameter. Ranks are integers in 0..L, so the bin
# edges are 0, 1, ..., L+1 and are shifted by 0.5 when drawn so that each bar
# is centred on its integer rank.
# To save the figures, call plt.savefig(<name>, bbox_inches='tight') after each
# one; the names used previously were 'MCMC_2_param_case_SBC_mobility_L8.png'
# and 'MCMC_2_param_case_SBC_jumping_prob_L8.png'.
for rank_values, axis_label in (
    (ranks_mob, 'Mobility rank'),
    (ranks_jp, 'Jumping probability rank '),
):
    plt.figure(dpi = 500, figsize = (3.5,2.5))
    counts, bins = np.histogram(rank_values, bins=np.arange(L+2))
    plt.stairs(counts, bins-0.5, fill = True)
    plt.xlabel(axis_label)
    plt.ylabel('Frequency')
