In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import warnings; warnings.simplefilter('ignore')
import sys
sys.path.insert(0, '/Users/jsmonzon/Research/SatGen/mcmc/src/')
import jsm_halopull
import jsm_SHMR
import jsm_mcmc
import jsm_stats
import jsm_models
import seaborn as sns
from scipy import interpolate

In [3]:
# Apply the matplotlib style sheet stored at ../paper/paper.mplstyle (path is relative to the notebook's working directory).
plt.style.use('../paper/paper.mplstyle')
# Figure-width constants, in inches, for sizing figures (e.g. as `figsize` entries).
# NOTE(review): the names suggest two-column and one-column text widths of the target document -- confirm.
double_textwidth = 7.0 #inches
single_textwidth = 3.5 #inches

### calling in 10k merger tree realizations

In [4]:
from scipy.stats import ks_2samp

def lnL_PNsat(data, model):
    """Summed log of ``model.PNsat`` evaluated at the data's satellite counts.

    ``model.PNsat`` is indexed with ``data.Nsat_perhost``; the natural log of
    every looked-up value is taken and the logs are added together.

    Parameters
    ----------
    data : object
        Must expose ``Nsat_perhost``, which is used as the index into
        ``model.PNsat``.
    model : object
        Must expose ``PNsat``, an indexable collection of values.

    Returns
    -------
    float
        The summed log value, or ``-np.inf`` whenever that sum is infinite
        (for instance when one of the looked-up values is zero).
    """
    looked_up = model.PNsat[data.Nsat_perhost]
    total = np.sum(np.log(looked_up))
    # Any infinite total is reported as -inf so callers can filter on it.
    return -np.inf if np.isinf(total) else total

def lnL_KS_max(data, model):
    """Summed log p-value of pairwise two-sample KS tests between data and model.

    Entries of ``model.max_split`` are selected with the indices stored in
    ``data.model_mask``. Each selected entry is paired, in order, with the
    matching entry of ``data.clean_max_split`` and compared using
    ``scipy.stats.ks_2samp``; the natural logs of the resulting p-values are
    added together.

    Parameters
    ----------
    data : object
        Must expose ``clean_max_split`` (an iterable of samples) and
        ``model_mask`` (an iterable of indices into ``model.max_split``).
    model : object
        Must expose ``max_split``, an indexable collection of samples.

    Returns
    -------
    float
        Sum of the log p-values, or ``-np.inf`` if an ``IndexError`` is raised
        while selecting or comparing the samples.
    """
    try:
        # Indexing model.max_split with the data's mask may run out of range.
        selected = [model.max_split[idx] for idx in data.model_mask]
        p_vals = np.array([
            ks_2samp(observed, modeled)[1]
            for observed, modeled in zip(data.clean_max_split, selected)
        ])
    except IndexError:
        # The model cannot supply every entry the data asks for.
        return -np.inf
    return np.sum(np.log(p_vals))

In [5]:
# def measure_lnLvar_D(fid_theta=[10.5, 2.0, 0,0, 0.0, 0.2, 0.0], min_mass=6.5):

#     lnLs = []
        
#     for SAGA_ind in range(100):
#         class_i = jsm_models.SAMPLE_SAGA_MODELS(fid_theta, meta_path="../../../data/MW-analog/meta_data_psi4/", extra_path="../../../data/MW-analog/meta_data_psi3/", SAGA_ind=SAGA_ind, verbose=False)
#         Dstat_i = jsm_stats.SatStats_D(class_i.lgMs_data, min_mass, max_N=500)
#         Mstat_i = jsm_stats.SatStats_M(class_i.lgMs_model, min_mass, max_N=500)

#         lnL_i = lnL_PNsat(Dstat_i, Mstat_i) + lnL_KS_max(Dstat_i, Mstat_i)
#         lnLs.append(lnL_i)

#     lnLs  = np.array(lnLs)
#     lnLs_clean = lnLs[~np.isinf(lnLs)]
#     return lnLs_clean.std(ddof=1)

In [6]:
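# NOTE(review): `measure_lnL_dvar` is not defined in the visible cells; the commented-out helper
# above is named `measure_lnLvar_D` and its first argument is `fid_theta`, not a path.
# The trailing numbers are presumably the values recorded from an earlier run -- confirm before reuse.
#dvar_S0 = measure_lnL_dvar("../../../data/MW-analog/meta_data_psi3/") #7.072
#dvar_S15 = measure_lnL_dvar("../../../data/cross_host/lognorm_015_psi3/") #8.229
#dvar_S30 = measure_lnL_dvar("../../../data/cross_host/lognorm_030_psi3/") #12.350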

### Now measure the scatter in the model log-likelihood as a function of Nhost

In [7]:
def measure_lnLvar_M(Nhost_per_model, SAGA_ind, fid_theta=[10.5, 2.0, 0,0, 0.0, 0.2, 0.0], min_mass=6.5,
                     max_N=500,
                     meta_path="../../../data/MW-analog/meta_data_psi4/",
                     extra_path="../../../data/MW-analog/meta_data_psi3/"):
    """Scatter of the log-likelihood across model realizations of a fixed size.

    A ``jsm_models.SAMPLE_SAGA_MODELS`` object is built for ``SAGA_ind``. Its
    ``lgMs_data`` is summarized once with ``jsm_stats.SatStats_D``. The rows of
    its 2-D ``lgMs_model`` array are split into consecutive groups of
    ``Nhost_per_model`` rows; each group is summarized with
    ``jsm_stats.SatStats_M`` and scored against the data as
    ``lnL_PNsat + lnL_KS_max``.

    Parameters
    ----------
    Nhost_per_model : int
        Number of ``lgMs_model`` rows placed in each model realization.
    SAGA_ind : int
        Forwarded to ``SAMPLE_SAGA_MODELS`` as ``SAGA_ind``.
    fid_theta : list, optional
        Forwarded unchanged as the first argument of ``SAMPLE_SAGA_MODELS``.
        NOTE(review): the default contains ``0,0`` (two integer entries, seven
        values in total) -- confirm this is not a typo for ``0.0``.
    min_mass : float, optional
        Forwarded to ``SatStats_D`` and ``SatStats_M``.
    max_N : int, optional
        Forwarded to ``SatStats_D`` and ``SatStats_M`` as ``max_N``.
    meta_path, extra_path : str, optional
        Forwarded to ``SAMPLE_SAGA_MODELS``; the defaults are the paths that
        were previously hard-coded here.

    Returns
    -------
    tuple
        ``(std, drop_fraction)``: the sample standard deviation (``ddof=1``)
        of the non-infinite log-likelihoods, and the fraction of realizations
        whose log-likelihood was infinite. ``std`` is NaN when fewer than two
        non-infinite values remain.
    """
    class_i = jsm_models.SAMPLE_SAGA_MODELS(fid_theta, meta_path=meta_path, extra_path=extra_path, SAGA_ind=SAGA_ind, verbose=False)
    Dstat_i = jsm_stats.SatStats_D(class_i.lgMs_data, min_mass, max_N=max_N)

    lgMs_model = class_i.lgMs_model
    # Rows that do not fill a complete group are dropped from the FRONT of the
    # array; when the row count is already a multiple this slice is a no-op.
    Nhost_extra = lgMs_model.shape[0] % Nhost_per_model
    lgMs_model = lgMs_model[Nhost_extra:]
    N_models = lgMs_model.shape[0] // Nhost_per_model
    realizations = lgMs_model.reshape([N_models, Nhost_per_model, lgMs_model.shape[1]])

    lnLs = []
    for model in realizations:
        Mstat_i = jsm_stats.SatStats_M(model, min_mass, max_N=max_N)
        lnLs.append(lnL_PNsat(Dstat_i, Mstat_i) + lnL_KS_max(Dstat_i, Mstat_i))

    lnLs = np.array(lnLs)
    # Infinite values mark realizations rejected by either likelihood term;
    # they are counted, then excluded from the scatter estimate.
    inf_mask = np.isinf(lnLs)
    drop_fraction = np.sum(inf_mask) / lnLs.shape[0]
    return lnLs[~inf_mask].std(ddof=1), drop_fraction

In [8]:
# Nhost = np.logspace(1.1,3, 25).astype(int)

# Nsaga = 100

# var_mat = np.full((Nsaga, Nhost.shape[0]), np.nan)
# drop_mat = np.full((Nsaga, Nhost.shape[0]), np.nan)

# for i,index in enumerate(range(Nsaga)):
#     for j,Nmod in enumerate(Nhost):
#         var_ij, drop_ij = measure_lnLvar_M(Nhost_per_model=Nmod, SAGA_ind=index)
#         var_mat[i,j] = var_ij
#         drop_mat[i,j] = drop_ij

# np.save("../../mcmc/inference_tests/convergence/data_saves/drop_S0.npy", drop_mat)
# np.save("../../mcmc/inference_tests/convergence/data_saves/var_S0.npy", var_mat)

In [9]:
# fig, axes = plt.subplots(2, 1, sharex=True, sharey=False, figsize=(double_textwidth, double_textwidth), gridspec_kw={'height_ratios': [3, 1]})
# axes[0].axhline(dvar_S0, ls="--", color="green", label="$\sigma_{\\vec{SS}} = \sqrt{ \sigma_{\\vec{D}}^2 + \sigma_{4}^2} $")
# axes[0].plot(np.log10(Nhost), np.nanmedian(var_mat, axis=0), color="k", label="$< \sigma_{\\vec{M}} >$")

# axes[1].scatter(np.log10(Nhost),1-np.median(drop_mat, axis=0), color="k", marker="+")
# axes[1].set_xlabel("$\log \hat{N}_{\mathrm{host}}$", fontsize=15)

# plt.tight_layout()
# #plt.savefig("../../../paper_1/figures/aux/convergence.pdf", bbox_inches="tight")
# plt.show()