In [None]:
import os
import ast
import matplotlib.pyplot as plt
import sys
sys.path.append(os.path.join(os.getcwd(), '../desi/'))
import catalog_definitions as cat
import numpy as np
import emcee
import corner
from IPython.display import display, Latex

# 10 Parameters associated with galaxy colors
# Multiple configurations could produce the same LHMR, fsat, etc.
# The actual values of these parameters are not of central interest;
# it's the implied LHMR, fsat, etc. that we really care about.
# Thus any degeneracies in these parameters are not a concern.

# A zeroth- and first-order polynomial in log L_gal for B_sat, which controls the satellite threshold
# Bsat,r = β_0,r + β_L,r(log L_gal − 9.5)
# Bsat,b = β_0,b + β_L,b(log L_gal − 9.5)
# Constrained by comparing projected two-point clustering for red (r) and blue (b) galaxies separately

# Weights for each galaxy luminosity, used when abundance matching
# log w_cen,r = (ω_0,r / 2) (1 + erf[(log L_gal − ω_L,r) / σ_ω,r])
# log w_cen,b = (ω_0,b / 2) (1 + erf[(log L_gal − ω_L,b) / σ_ω,b])
# Constrained from the Lsat,r/Lsat,b ratio and projected two-point clustering.

# A secondary, individual galaxy property χ can be introduced to affect the weight for abundance matching.
#  2 parameters (one each for red and blue)
# w_χ,r = exp(χ/ω_χ,r)
# w_χ,b = exp(χ/ω_χ,b)
# Constrained from Lsat(χ|L_gal) data.

# Analyze MCMC using emcee backend

In [None]:
# Folder number to look in
run = 1
run_file = run

path = f'/mount/sirocco1/imw2293/GROUP_CAT/MCMC/mcmc_{run}/mcmc_{run_file}.dat'
#path = f'./mcmc_{run}/mcmc_{run_file}.dat'

reader = emcee.backends.HDFBackend(path, read_only=True)
samples = reader.get_chain()  # (steps per walker, walkers, parameters)
ndim = reader.shape[1]
n_steps, n_walkers = samples.shape[0], samples.shape[1]
print(f'Number of steps: {n_steps * n_walkers} (total); {n_steps} (per walker), ')
print(f'Number of walkers: {n_walkers}')
print(f'Number of parameters: {ndim}')

# get_autocorr_time raises AutocorrError when the chain is still too short for a
# reliable estimate. Catch only that, so genuine failures (bad file, Ctrl-C) still surface.
try:
    tau = reader.get_autocorr_time()
    print(tau)
except emcee.autocorr.AutocorrError as err:
    print("Not burnt in yet")
    print(err)

burn_number = 150 # TODO choose this by inspecting the chains above. wait for convergence in all parameters
thin_number = 1 # TODO not sure how or why to choose this
# Fail loudly here rather than handing an empty sample array to the plotting cells below.
if burn_number >= n_steps:
    raise ValueError(f'burn_number={burn_number} discards the entire chain ({n_steps} steps per walker)')
flat_samples = reader.get_chain(discard=burn_number, thin=thin_number, flat=True)

In [None]:
# Trace plot of every walker for every parameter; use it to pick burn_number above.
# squeeze=False keeps `axes` indexable even when ndim == 1.
fig, axes = plt.subplots(ndim, figsize=(10, 2.5*ndim), sharex=True, squeeze=False)
axes = axes[:, 0]
labels = [
    r'$\omega_{L,b}$', r'$\sigma_{\omega,b}$', r'$\omega_{L,r}$', r'$\sigma_{\omega,r}$',
    r'$\omega_{0,b}$', r'$\omega_{0,r}$',
    r'$\beta_{0,r}$', r'$\beta_{L,r}$', r'$\beta_{0,b}$', r'$\beta_{L,b}$',
]
# The labels are hard-coded in sampler parameter order; catch a mismatch early.
assert len(labels) == ndim, f'{len(labels)} labels defined for {ndim} parameters'

# Walkers to hide from the trace plots, e.g. [15] for a walker that went off the deep end.
bad_walkers = []
good_walkers = [w for w in range(samples.shape[1]) if w not in bad_walkers]
for i, ax in enumerate(axes):
    ax.plot(samples[:, good_walkers, i], alpha=0.3)
    ax.set_xlim(0, len(samples))
    ax.set_ylabel(labels[i])

axes[-1].set_xlabel("step number");

In [None]:
# Corner plot of the post-burn-in samples.
# Diagonal panels: histogram of each parameter's column of the flattened chain,
#   i.e. its 1-D marginalized posterior.
# Off-diagonal panels: 2-D histograms for every pair of parameters, i.e. the 2-D
#   marginalized posteriors, which make covariances between parameters easy to spot.
fig = corner.corner(flat_samples, labels=labels)

In [None]:
# Best model: the single stored sample (burn-in included) with the largest
# log-probability. NOTE(review): this equals the lowest chi^2 only if the sampler's
# log-probability is -chi^2/2 (plus a constant) — confirm against the likelihood.
log_probs = reader.get_log_prob(flat=True)
all_flat_samples = reader.get_chain(flat=True)
idx = np.argmax(log_probs)
best_fit = all_flat_samples[idx]
# The labels already contain their own $...$ delimiters, so Latex renders them as-is.
for i, label in enumerate(labels):
    display(Latex(f'{label} = {best_fit[i]:.3f}'))

# Analyze MCMC using manual output file
This approach is problematic because the model count resets to 1 every time the run restarts after a crash.

In [None]:
# My Chains and Distributions of Parameters from the out.i file
# Can get parameter info from the output file as well

folder_num = 9


def _read_mcmc_output(out_path):
    """Parse an out.<i> log into parallel lists of parameter dicts and chi^2 values.

    Lines starting with "{'zmin':" are parsed with ast.literal_eval (Python literals
    only, no code execution) into a parameter dict. For lines starting with 'CHI',
    the last of at most three space-separated fields is parsed as a float.
    'MODEL' lines are ignored: the model counter restarts after a crash, so it is
    not a usable index.

    Returns (models, chi). Raises AssertionError if the two counts differ.
    """
    parsed_models = []
    parsed_chi = []
    with open(out_path, 'r') as file:
        for line in file:
            if line.startswith('{\'zmin\':'):
                parsed_models.append(ast.literal_eval(line))
            elif line.startswith('CHI'):
                parsed_chi.append(float(line.split(' ', 2)[-1].strip()))
    assert len(parsed_models) == len(parsed_chi), (
        f'{out_path}: {len(parsed_models)} parameter sets but {len(parsed_chi)} chi^2 values '
        '(possibly a run stopped mid-model)')
    return parsed_models, parsed_chi


models, chi = _read_mcmc_output(f'mcmc_{folder_num}/out.{folder_num}')

def get_parameter_values(parameter_name, model_list=None):
    """Return the value of `parameter_name` from every model dict, in order.

    Args:
        parameter_name: key to read from each model dict.
        model_list: list of parameter dicts to read from; when omitted, the
            notebook-global `models` is used (the original behavior).

    Raises:
        KeyError: if any model lacks `parameter_name`.
    """
    if model_list is None:
        model_list = models
    return [model[parameter_name] for model in model_list]

# Model-dict keys that are skipped when plotting parameter chains and when comparing
# against GF_props (presumably run/survey settings rather than fitted parameters).
exclusions = [
    'zmin',
    'zmax',
    'frac_area',
    'fluxlim',
    'color',
]

In [None]:
# My versions of chain plots and parameter distributions:
# for each fitted parameter, its chain (left) and its marginal histogram (right).
for pname in models[0]:
    if pname in exclusions:
        continue
    values = get_parameter_values(pname)

    fig, (ax_chain, ax_hist) = plt.subplots(1, 2, figsize=(12, 4))
    ax_chain.plot(values, color="k")
    ax_chain.set(xlabel='Iteration', ylabel='Parameter Value', title=f'Parameter Chain for {pname}')

    ax_hist.hist(values, 100, color="k", histtype="step")
    # Plain text rather than $...$: in mathtext, underscores in raw parameter names
    # become subscripts, and a name such as 'a_b_c' fails as a double subscript.
    ax_hist.set(xlabel=pname, ylabel=f'p({pname})', yticks=[])
    plt.show()

fig, ax = plt.subplots()
ax.plot(chi)
ax.set(xlabel='Iteration', ylabel='Chi Squared', title='Chi Squared Chain')
plt.show()

In [None]:
# Show best model: the entry with the smallest chi^2 in the out.<i> log.
# np.nanargmin skips NaN chi^2 values, whereas builtin min() returns NaN when the first
# value is NaN. Like chi.index(min(chi)), it returns the first occurrence of the minimum.
best_idx = int(np.nanargmin(chi))
best_chi = chi[best_idx]
best_model = models[best_idx]
# best_idx is a 0-based position in the log, not the number on the log's MODEL lines.
print(f'Best model is model {best_idx} with chi squared of {best_chi}')

#compare each property of best model to cat.sdss_colors.GF_props
reference_props = cat.sdss_colors.GF_props
for key, value in best_model.items():
    if key in exclusions:
        continue
    reference = reference_props[key]
    # A zero reference value would raise ZeroDivisionError; report n/a instead.
    if reference != 0:
        pct = f'{100 * (value - reference) / reference:.2f}%'
    else:
        pct = 'n/a'
    # float(): the '.4' format spec is rejected for int values.
    print(f'{key.ljust(12)}:  {float(value):.4} vs {float(reference):.4} ({pct})')


In [None]:
from IPython.display import display, Math

# Posterior summary for each parameter: the median, with the 16th/84th percentiles
# (central 68% interval) shown as asymmetric -/+ errors.
for i in range(ndim):
    lo, med, hi = np.percentile(flat_samples[:, i], [16, 50, 84])
    # Math() already puts its content in math mode and strips only the outermost '$'.
    # Leaving the label's own '$...$' in place would close math mode right after the
    # label and garble the rest, so strip the delimiters here.
    label = labels[i].strip('$')
    txt = f"{label} = {med:.3f}_{{-{med - lo:.3f}}}^{{+{hi - med:.3f}}}"
    display(Math(txt))