In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import warnings; warnings.simplefilter('ignore')
import sys
sys.path.insert(0, '/Users/jsmonzon/Research/SatGen/mcmc/src/')
import jsm_halopull
import jsm_SHMR
import jsm_mcmc
import jsm_stats
import jsm_models
import seaborn as sns

In [3]:
plt.style.use('../paper/paper.mplstyle')
double_textwidth = 7.0 #inches
single_textwidth = 3.5 #inches

### calling in 10k merger tree realizations

In [4]:
s0_mat = np.load("../../../data/MW-analog/meta_data_psi3/models.npz")

s15_mat = np.load("../../../data/cross_host/lognorm_015_psi3/models.npz")

s30_mat = np.load("../../../data/cross_host/lognorm_030_psi3/models.npz")

In [5]:
test_array = np.load("../../../data/MW-analog/meta_data_psi3/models.npz")["mass"]
print(test_array.shape)
test_array = np.delete(test_array, 0, axis=0)
print(test_array.shape)

(100, 100, 1749)
(99, 100, 1749)


In [6]:
hmm = test_array.reshape([990, 10, 1749])

In [7]:
hmm.shape

(990, 10, 1749)

### first the variance in the data

In [8]:
from scipy.stats import ks_2samp

def lnL_PNsat(data, model):
    lnL = np.sum(np.log(model.PNsat[data.Nsat_perhost]))
    if np.isinf(lnL):
        #print("index error in Pnsat")
        return -np.inf
    else:
        return lnL
    
def lnL_PNsat_test(data, model):
    return model.PNsat[data.Nsat_perhost]

def lnL_KS_max(data, model):
    try:
        clean_max_split = list(map(model.max_split.__getitem__, data.model_mask)) # this might yield an index error!
        p_vals = np.array(list(map(lambda x, y: ks_2samp(x, y)[1], data.clean_max_split, clean_max_split)))
        return np.sum(np.log(p_vals))
    except IndexError:
        #print("this model is not preferable!")
        return -np.inf

In [9]:
def measure_lnL_dvar(meta_path):

    fid_theta = [10.5, 2.0, 0,0,0.0, 0]
    min_mass = 6.5

    lnLs = []

    for SAGA_ind in range(100):
        class_i = jsm_models.SAMPLE_SAGA_MODELS(fid_theta, meta_path, SAGA_ind)
        Dstat_i = jsm_stats.SatStats_D(class_i.lgMs_data, min_mass=min_mass, max_N=500)

        lgMs_model = jsm_SHMR.general_new(fid_theta, class_i.lgMh_models, 0, 1)
        Mstat_i = jsm_stats.SatStats_M(lgMs_model, min_mass, max_N=500)

        lnL_Pnsat_i = lnL_PNsat(Dstat_i, Mstat_i)
        lnL_KS_tot_i = lnL_KS_max(Dstat_i, Mstat_i)

        lnL_i = lnL_Pnsat_i + lnL_KS_tot_i
        lnLs.append(lnL_i)

    lnLs  = np.array(lnLs)
    lnLs_clean = lnLs[~np.isinf(lnLs)]

    sns.kdeplot(lnLs_clean)
    plt.xlabel("ln L")
    plt.title(f"mean: {lnLs_clean.mean():.3f}, width: {lnLs_clean.std():.3f}")
    plt.show()

    return lnLs_clean.std()

In [10]:
#dvar = measure_lnL_dvar("../../../data/MW-analog/meta_data_psi3/")

dvar= 7.1

### now the variance in the models Nhost = 10, 30, 100, 300, 1000

In [11]:
def measure_lnL_mvar(meta_path, Nreal):

    fid_theta = [10.5, 2.0, 0, 0, 0.0, 0]
    min_mass = 6.5

    lnLs = []

    master = jsm_models.SAMPLE_SAGA_MODELS(fid_theta, meta_path, SAGA_ind=0, Nreal_per_model=Nreal)
    Dstat = jsm_stats.SatStats_D(master.lgMs_data, min_mass=min_mass, max_N=500)

    for model_ind in range(master.N_model_realizations):
        lgMs_model = jsm_SHMR.general_new(fid_theta, master.lgMh_models[model_ind], 0, 1)
        Mstat_i = jsm_stats.SatStats_M(lgMs_model, min_mass, max_N=500)

        lnL_Pnsat_i = lnL_PNsat(Dstat, Mstat_i)
        lnL_KS_tot_i = lnL_KS_max(Dstat, Mstat_i)

        lnL_i = lnL_Pnsat_i + lnL_KS_tot_i
        lnLs.append(lnL_i)

    lnLs  = np.array(lnLs)
    lnLs_clean = lnLs[~np.isinf(lnLs)]

    # sns.kdeplot(lnLs_clean)
    # plt.xlabel("ln L")
    # plt.title(f"mean: {lnLs_clean.mean():.3f}, width: {lnLs_clean.std():.3f}")
    # plt.show()

    return lnLs_clean.std()

In [16]:
Nhost = np.logspace(1,3).astype(int)

In [17]:
Mvar = []
for N in Nhost:
    Mvar.append(measure_lnL_mvar("../../../data/MW-analog/meta_data_psi3/", N))

selecting the 0  SAGA sample
converting the subhalos to satellites and creating the mock data instance
there are 990 unique model realizations, each made up of 10 merger trees
there are 0 extra merger trees, deleting these unused trees
selecting the 0  SAGA sample
converting the subhalos to satellites and creating the mock data instance
there are 990 unique model realizations, each made up of 10 merger trees
there are 0 extra merger trees, deleting these unused trees
selecting the 0  SAGA sample
converting the subhalos to satellites and creating the mock data instance
there are 825 unique model realizations, each made up of 12 merger trees
there are 0 extra merger trees, deleting these unused trees
selecting the 0  SAGA sample
converting the subhalos to satellites and creating the mock data instance
there are 761 unique model realizations, each made up of 13 merger trees
there are 7 extra merger trees, deleting these unused trees
selecting the 0  SAGA sample
converting the subhalos to 

In [None]:
plt.scatter(Nhost, Mvar, marker="+")
plt.axhline(dvar, ls="--", color="k")
plt.axvline(1000, ls=":", color="grey")
plt.axvline(100, ls=":", color="grey")
plt.axvline(10, ls=":", color="grey")
plt.xscale("log")
plt.xlabel("$\hat{N}_{\mathrm{host}}$")
plt.ylabel("$\sigma_{lnL}$")
plt.show()