In [None]:
import os
import ast
import matplotlib.pyplot as plt
import sys
sys.path.append(os.path.join(os.getcwd(), '../desi/'))
import numpy as np
import emcee
import corner
from IPython.display import display, Latex, Math
from copy import deepcopy

if './SelfCalGroupFinder/py/' not in sys.path:
    sys.path.append('./SelfCalGroupFinder/py/')
from pyutils import *
from groupcatalog import *
import plotting as pp
import catalog_definitions as cat

%load_ext autoreload
%autoreload 2

# Parameter ranges (min, max) for the 10 fitted parameters. They are paired by
# index with `labels` below (the corner-plot cell passes both together).
param_ranges = [
    (10, 20),
    (0, 5),
    (-2, 30),
    (-5, 20),
    (0, 35),
    (-10, 15),
    (-6, 6),
    (4, 25),
    (8, 25),
    (-25, 5),
]

# The 10 parameters govern galaxy colors (red "r" / blue "b").
# Several parameter configurations can produce the same LHMR, f_sat, etc.
# The parameter values themselves are not of central interest; the implied
# LHMR, f_sat, etc. are what we care about. Degeneracies among the
# parameters are therefore not a concern.
#
# Satellite threshold B_sat: a zeroth- plus first-order polynomial in log L_gal.
#   B_sat,r = β_0,r + β_L,r (log L_gal − 9.5)
#   B_sat,b = β_0,b + β_L,b (log L_gal − 9.5)
# Constrained by comparing projected two-point clustering for r and b separately.
#
# Central weights as a function of galaxy luminosity, used in abundance matching.
#   log w_cen,r = (ω_0,r / 2) (1 + erf[(log L_gal − ω_L,r) / σ_ω,r])
#   log w_cen,b = (ω_0,b / 2) (1 + erf[(log L_gal − ω_L,b) / σ_ω,b])
# Constrained by the L_sat,r / L_sat,b ratio and projected two-point clustering.
#
# A secondary, per-galaxy property χ can also modulate the abundance-matching
# weight. This adds 2 parameters, one each for red and blue; they are not
# among the 10 entries of `labels`.
#   w_χ,r = exp(χ / ω_χ,r)
#   w_χ,b = exp(χ / ω_χ,b)
# Constrained by L_sat(χ | L_gal) data.
#
# Ideal chi^2 estimate: chi^2 ≈ N_dof = N_data − N_params
#   SDSS:     N_data = 100;                   N_dof = 100 − 10 = 90
#   BGS:      N_data = 30*6 + 20 = 200;       N_dof = 200 − 10 = 190
#             "per dof" normalization: divide chi^2 by 200 (that is N_data; strictly N_dof = 190)
#   BGS Mini: N_data = 20*3 + 10*3 + 20 = 110; N_dof = 110 − 10 = 100
#             "per dof" normalization: divide chi^2 by 110 (that is N_data; strictly N_dof = 100)

# Snapshot of the cluster jobs running the fits (sampler/optimizer in parentheses):
#Job ID                    Name             User            Time Use S Queue
#------------------------- ---------------- --------------- -------- - -----
#15016.master.local         ian.optuna0      imw2293         415:05:3 R default        (TPE)
#15017.master.local         ian.optuna1      imw2293         70:22:47 R default        (Q MC)
#15018.master.local         ian.optuna2      imw2293         404:17:4 R default        (GP)
#15019.master.local         ian.emcee3       imw2293         72:43:09 R default        (emcee Stretch)

# LaTeX labels for the fitted parameters.
# labels[i] names column i of the chains (see the trace-plot cell).
labels = [
    r'$\omega_{L,b}$', r'$\sigma_{\omega,b}$',
    r'$\omega_{L,r}$', r'$\sigma_{\omega,r}$',
    r'$\omega_{0,b}$', r'$\omega_{0,r}$',
    r'$\beta_{0,r}$', r'$\beta_{L,r}$',
    r'$\beta_{0,b}$', r'$\beta_{L,b}$',
]



# Analyze Results

In [None]:
# Choose the group catalog whose MCMC/optimization results are analyzed below.
# Earlier alternatives (these were loaded with deepcopy instead of deserialize):
#gc = deepcopy(cat.bgs_sv3_hybrid_mcmc) # 43170, 18000
#gc = deepcopy(cat.bgs_y3_like_sv3_hybrid_mcmc_new)
#gc = deepcopy(cat.bgs_y1mini_hybrid_mcmc)

#gc = deepcopy(cat.bgs_y1_hybrid_mcmc)
gc = deserialize(cat.bgs_y1_hybrid8_mcmc)
# Best parameters found over all runs of this catalog (per the method name; TODO confirm criterion)
best = gc.load_best_params_across_runs()
print(best)

In [None]:
# Re-run the full group-finder pipeline for `gc`. The calls must stay in this order.
# popmock=True presumably also populates a mock catalog, which calc_wp_for_mock
# then measures (presumably projected clustering w_p) — TODO confirm.
# chisqr() presumably scores the model against the data, and dump() persists `gc`.
gc.preprocess()
gc.run_group_finder(popmock=True, silent=False)
gc.calc_wp_for_mock()
gc.chisqr()
gc.postprocess()
gc.dump()

In [None]:
# Model-vs-data comparison figures (projected clustering, then L_sat), per the helper names
pp.proj_clustering_plot(gc)
pp.lsat_data_compare_plot(gc)

In [None]:
#TODO fsat with saved variance

In [None]:
# TODO LHMR with saved variance

In [None]:
# Luminosity-halo mass relation (LHMR) for this catalog, including its scatter
pp.LHMR_withscatter(gc)

In [None]:
# Assorted single-panel summary plots for this catalog (contents defined in pp.single_plots)
pp.single_plots(gc)

In [None]:
# Halo occupation plot for this catalog (presumably the HOD; see pp.hod_plot)
pp.hod_plot(gc)

In [None]:
def combine_emcee_backends(backends, min_fraction=0.5, verbose=True):
    """
    Combine multiple emcee backends into a single set of chains.

    The walkers of every kept backend are placed side by side along the walker
    axis. Chains shorter than the longest one are padded at the END with NaN
    steps, so all backends share one step axis. A chain shorter than
    ``min_fraction`` of the longest chain is dropped entirely instead of padded.

    Parameters
    ----------
    backends : list of emcee.backends.backend.Backend
        Backends to combine. Each must provide ``get_chain()`` returning an
        array of shape (nsteps, nwalkers, ndim).
    min_fraction : float, optional
        Minimum chain length, as a fraction of the longest chain, required to
        keep a backend. The default of 0.5 drops a chain whenever it would
        need more padding steps than it has real steps.
    verbose : bool, optional
        If True (default), print shapes and padding/drop decisions.

    Returns
    -------
    combined : np.ndarray
        Float array of shape (nsteps_longest, nwalkers_kept, ndim).
        ``nsteps_longest`` is the length of the longest chain (not the total
        over backends). ``nwalkers_kept`` is the summed walker count of the
        backends that were not dropped. Padded steps are NaN.

    Raises
    ------
    ValueError
        If ``backends`` is empty, or the kept chains disagree on ndim.
    """
    def _log(msg):
        # Diagnostics can be silenced without changing the result.
        if verbose:
            print(msg)

    # float dtype so NaN padding is always representable
    chains = [np.asarray(b.get_chain(), dtype=float) for b in backends]
    if not chains:
        raise ValueError("combine_emcee_backends needs at least one backend.")

    shapes = [c.shape for c in chains]
    _log(f"Shapes: {shapes}")

    longest_steps = max(shape[0] for shape in shapes)
    _log(f"Longest chain has {longest_steps} steps.")

    kept = []
    for i, chain in enumerate(chains):
        nsteps = chain.shape[0]
        pad_length = longest_steps - nsteps
        if pad_length == 0:
            kept.append(chain)
        elif nsteps < min_fraction * longest_steps:
            # Too little real data relative to the padding it would need.
            _log(f"Chain {i} is too short ({nsteps} steps), dropping it.")
        else:
            _log(f"Padding chain {i} with {pad_length} NaN steps to match the longest chain length.")
            kept.append(np.pad(chain, ((0, pad_length), (0, 0), (0, 0)),
                               mode='constant', constant_values=np.nan))

    # The longest chain always has pad_length == 0, so `kept` is never empty.
    dims = kept[0].shape[2]
    mismatched = sorted({c.shape[2] for c in kept if c.shape[2] != dims})
    if mismatched:
        raise ValueError(f"Backends disagree on the number of parameters: {[dims] + mismatched}")

    combined = np.concatenate(kept, axis=1)
    _log(f"Combined shape will be: {combined.shape}")
    return combined
    

In [None]:
# Merge the walkers of every emcee backend for this catalog into one
# (steps, walkers, params) array; shorter chains are NaN-padded or dropped.
# get_backends() returns (backends, folders); [0] selects the list of backends.
chains = combine_emcee_backends(gc.get_backends()[0])

In [None]:
# Per-backend diagnostics: chain size, autocorrelation time, and the flattened
# post-burn-in samples.
# NOTE: `reader`, `samples`, `ndim` and `flat_samples` are rebound on every
# iteration. After this cell they refer to the LAST emcee backend only, so the
# later cells that use them inspect that one backend, not the combined `chains`.
PRINT_WALKERS = False  # True: print final walker positions whose chi^2 is below the median
burn_number = 0  # TODO choose this by inspecting the trace plots; wait for convergence in all parameters
thin_number = 1

for reader in gc.get_backends()[0]:
    if not isinstance(reader, emcee.backends.backend.Backend):
        continue

    samples = reader.get_chain()  # (steps, walkers, params)
    ndim = reader.shape[1]
    print(f'Number of steps: {samples.shape[0] * samples.shape[1]} (total); {samples.shape[0]} (per walker), ')
    print(f'Number of walkers: {samples.shape[1]}')
    print(f'Number of parameters: {ndim}')

    try:
        tau = reader.get_autocorr_time()
        print(tau)
    except emcee.autocorr.AutocorrError as err:
        # emcee raises this when the chain is too short for a reliable tau.
        # The previous bare `except:` also hid unrelated errors and KeyboardInterrupt.
        print("Not burnt in yet")
        print(err)

    # Current walker positions, one line per walker, for walkers with chi^2
    # (taken as -2 ln P) below the median.
    if PRINT_WALKERS:
        with np.printoptions(precision=5, suppress=True, linewidth=500, formatter={'all': lambda x: f"{x:.3f},"}):
            current = samples[-1]
            chisqr = -2 * reader.get_log_prob(flat=False)[-1]
            good = np.where(chisqr < np.median(chisqr))
            print(np.array2string(current[good]))
            print(chisqr[good])

    flat_samples = reader.get_chain(discard=burn_number, thin=thin_number, flat=True)
    # A bare `flat_samples.shape` inside a loop is never displayed; print it instead.
    print(f'Flat samples shape: {flat_samples.shape}')

In [None]:
# Trace plot: one panel per parameter, showing every walker of the combined
# `chains` against step number. NaN-padded steps are simply not drawn.
n_steps = chains.shape[0]
n_params = chains.shape[2]
# squeeze=False keeps `axes` indexable even when there is a single parameter.
fig, axes = plt.subplots(n_params, figsize=(10, 2.5 * n_params), sharex=True, squeeze=False)
axes = axes[:, 0]
for i, ax in enumerate(axes):
    param_chain = chains[:, :, i]
    ax.plot(param_chain, alpha=0.3)
    # Use the combined chain length; `samples` only holds the last backend.
    ax.set_xlim(0, n_steps)
    ax.set_ylabel(labels[i])

    # Zoom the y-axis onto the central 90% (5th-95th percentile) of walker positions.
    # set_ylim raises on NaN, so skip it for an all-NaN or constant parameter.
    lo, hi = np.nanpercentile(param_chain, [5, 95])
    if np.isfinite(lo) and np.isfinite(hi) and lo < hi:
        ax.set_ylim(lo, hi)

axes[-1].set_xlabel("step number")
plt.show()

In [None]:
# Inspect one walker of the LAST backend read in the diagnostics cell
# (`samples` / `reader`). For every step, print three rows:
#   - whether its position changed from the previous step;
#   - its chi^2 (taken as -2 ln P, as elsewhere in this notebook);
#   - the value of parameter p.
n = 0  # walker index to inspect
p = 8  # parameter index to print (named by labels[p])
walker_samples = samples[:, n, :]
# Must index the SAME walker `n` as walker_samples; this was hard-coded to walker 0.
walker_chisqr = -2 * reader.get_log_prob(discard=0, flat=False)[:, n]

# True where any parameter differs from the previous step
updated = np.any(np.diff(walker_samples, axis=0) != 0, axis=1)
# The first step has no predecessor to compare against
updated = np.insert(updated, 0, False)

print(f"Walker {n}, parameter {p}: {labels[p]}")
# threshold=sys.maxsize is NumPy's documented way to disable summarization.
with np.printoptions(threshold=sys.maxsize, linewidth=np.inf, suppress=True, formatter={'float': '{:6.1f}'.format, 'bool': '{:6}'.format}):
    print(updated)
    print(walker_chisqr)
    print(walker_samples[:, p])

In [None]:
# The corner plot shows all 1D and 2D projections of the posterior samples.
# This quickly reveals the covariances between parameters.
# The marginalized distribution of a parameter (or pair of parameters) is
# obtained by projecting the MCMC samples onto that axis (or plane) and
# histogramming them. The diagonal panels are therefore the 1D marginals, and
# the off-diagonal panels are the 2D marginals.
#
# Flatten the combined chains to (n_samples, n_params). Then drop the NaN rows
# that combine_emcee_backends uses to pad shorter chains; they are not samples
# and would propagate NaN into histograms and percentiles.
all_flat_samples = chains.reshape(-1, chains.shape[-1])
all_flat_samples = all_flat_samples[~np.isnan(all_flat_samples).any(axis=1)]

MAKE_CORNER_PLOT = False  # set True to draw the corner plot of the combined chains
if MAKE_CORNER_PLOT:
    fig = corner.corner(all_flat_samples, labels=labels, range=param_ranges)

In [None]:
# Show the N best unique models: highest log-probability, i.e. lowest
# chi^2 = -2 ln P. These come from the LAST backend read in the diagnostics
# cell (`reader`), not from the combined chains.
N = 3  # Number of best models to display
log_probs = reader.get_log_prob(flat=True)
# Separate name so the combined-chain `all_flat_samples` from the corner cell
# is not silently replaced by this single backend's samples.
reader_flat_samples = reader.get_chain(flat=True)

# Indices ordered from highest to lowest log-probability
best_indices = np.argsort(log_probs)[::-1]

shown = 0
seen = set()
for idx in best_indices:
    # A walker keeps its position after a rejected move, so the same point can
    # appear many times. Dedupe on parameters rounded to 5 decimals.
    fit_tuple = tuple(np.round(reader_flat_samples[idx], 5))
    if fit_tuple in seen:
        continue
    seen.add(fit_tuple)
    best_fit = reader_flat_samples[idx]
    chi_squared = -2 * log_probs[idx]
    print(f"BEST MODEL {shown+1} (chi^2={chi_squared:.3f})")

    with np.printoptions(precision=5, suppress=True, formatter={'all': lambda x: f"{x:.3f},"}, linewidth=500):
        print(best_fit)
        pp.plot_parameters(best_fit)

    shown += 1
    if shown >= N:
        break


In [None]:
# Posterior summary from `flat_samples` (last backend, after burn-in/thinning).
# For each parameter this shows the median, with the distances to the 16th and
# 84th percentiles as lower/upper errors (the 1-sigma interval for a Gaussian).
# The 50th percentile is the median, not the mean, so the header says so.
print("MEDIAN MODEL (16th / 50th / 84th percentiles)")
for i in range(flat_samples.shape[1]):
    lo, med, hi = np.percentile(flat_samples[:, i], [16, 50, 84])
    txt = f"{labels[i]} = ${med:.3f}_{{-{med - lo:.3f}}}^{{+{hi - med:.3f}}}$"
    display(Latex(txt))

# Variances of f_sat, the LHMR, and L_sat across the chain samples

In [None]:
# Older route: parse the run log. This will save off a .npy file with the array of fsat values.
#save_from_log(PY_SRC_FOLDER + 'exec.out', overwrite=False)

# Current route: save outputs from the emcee backends, overwriting earlier files.
# Presumably these are what the *_variance_from_saved() helpers below read — TODO confirm.
# TODO From Blobs to error estimates
backends, folders = gc.get_backends()
save_from_backend(backends, overwrite=True)

In [None]:
# Satellite fraction vs. galaxy luminosity, for all / quiescent / star-forming galaxies.
# Means and standard deviations come from fsat_variance_from_saved(),
# presumably computed over the saved chain samples.
fsat_std, fsatr_std, fsatb_std, fsat_mean, fsatr_mean, fsatb_mean = fsat_variance_from_saved()
# To cache: np.save(OUTPUT_FOLDER + 'std_fsat.npy', (fsat_std, fsatr_std, fsatb_std, fsat_mean, fsatr_mean, fsatb_mean))

fig, ax = plt.subplots()
for mean, std, color, label in [
    (fsat_mean, fsat_std, 'k', 'All'),
    (fsatr_mean, fsatr_std, 'r', 'Quiescent'),
    (fsatb_mean, fsatb_std, 'b', 'Star-forming'),
]:
    ax.errorbar(L_gal_bins, mean, yerr=std, fmt='.', color=color, label=label, capsize=3, alpha=0.7)
# Raw string: in a plain string '\m' is an invalid escape (SyntaxWarning on Python 3.12+).
ax.set_xlabel(r'$L_{\mathrm{gal}}$')
ax.set_ylabel(r'$\langle f_{\mathrm{sat}} \rangle$')
ax.set_title("Satellite Fraction vs. Galaxy Luminosity")
ax.legend()
ax.set_xscale('log')
ax.set_xlim(1E7, 2E11)
ax.set_ylim(0.0, 1.0)
fig.tight_layout()
plt.show()

In [None]:
# Luminosity-halo mass relation (LHMR) vs. halo mass, for all / star-forming /
# quiescent centrals: first the mean central luminosity, then its scatter.
# Error bars come from lhmr_variance_from_saved().
(lhmr_r_mean, lhmr_r_std, lhmr_r_scatter_mean, lhmr_r_scatter_std,
 lhmr_b_mean, lhmr_b_std, lhmr_b_scatter_mean, lhmr_b_scatter_std,
 lhmr_all_mean, lhmr_all_std, lhmr_all_scatter_mean, lhmr_all_scatter_std) = lhmr_variance_from_saved()


def _plot_vs_halo_mass(series, ylabel, title, ylog=False, ylim=None):
    """Errorbar plot of several (mean, std) series against Mhalo_bins.

    Parameters
    ----------
    series : list of (mean, std, color, label)
        One entry per population, drawn in list order.
    ylabel, title : str
        Y-axis label and figure title.
    ylog : bool
        If True, use a logarithmic y-axis.
    ylim : (float, float) or None
        Optional y-axis limits.
    """
    fig, ax = plt.subplots()
    for mean, std, color, label in series:
        ax.errorbar(Mhalo_bins, mean, yerr=std, fmt='.', color=color, label=label, capsize=3, alpha=0.7)
    # The x-axis is log-scaled with limits 1e10-1e15, so Mhalo_bins holds linear
    # halo masses. The old label claimed log10(M_halo) and lacked its closing ')'.
    ax.set_xlabel(r'$M_{\mathrm{halo}}~[M_\odot]$')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.set_xscale('log')
    if ylog:
        ax.set_yscale('log')
    ax.set_xlim(1E10, 1E15)
    if ylim is not None:
        ax.set_ylim(*ylim)
    fig.tight_layout()
    plt.show()


_plot_vs_halo_mass(
    [(lhmr_all_mean, lhmr_all_std, 'k', 'All'),
     (lhmr_b_mean, lhmr_b_std, 'b', 'Star-forming'),
     (lhmr_r_mean, lhmr_r_std, 'r', 'Quiescent')],
    ylabel=r'$\langle L_{\mathrm{cen}} \rangle$',
    title="Mean Central Luminosity vs. Halo Mass",
    ylog=True,
    ylim=(1E7, 5E11),
)

_plot_vs_halo_mass(
    [(lhmr_all_scatter_mean, lhmr_all_scatter_std, 'k', 'All'),
     (lhmr_b_scatter_mean, lhmr_b_scatter_std, 'b', 'Star-forming'),
     (lhmr_r_scatter_mean, lhmr_r_scatter_std, 'r', 'Quiescent')],
    ylabel=r'$\sigma_{\log L_{\mathrm{cen}}}$ [dex]',
    title="Central Luminosity Scatter vs. Halo Mass",
)

In [None]:
# L_sat statistics for the red (r) and blue (b) samples, from the saved outputs,
# compared against the observed L_sat table.
# NOTE(review): `gc` is a BGS catalog but this compares to the SDSS L_sat file — confirm intended.
lsat_r_mean, lsat_r_std, lsat_b_mean, lsat_b_std = lsat_variance_from_saved()

# Plain float table; skiprows=0 and dtype=float are np.loadtxt's defaults.
data = np.loadtxt(LSAT_OBSERVATIONS_SDSS_FILE)
pp.lsat_compare_plot(
    data,
    lsat_r_mean, lsat_b_mean,  # means: red, blue
    lsat_r_std, lsat_b_std,    # standard deviations: red, blue
)