In [None]:
import os
import ast
import matplotlib.pyplot as plt
import sys
sys.path.append(os.path.join(os.getcwd(), '../desi/'))
import numpy as np
import emcee
import corner
from IPython.display import display, Latex, Math
from copy import deepcopy

if './SelfCalGroupFinder/py/' not in sys.path:
    sys.path.append('./SelfCalGroupFinder/py/')
from pyutils import *
from groupcatalog import *
import plotting as pp
import catalog_definitions as cat

%load_ext autoreload
%autoreload 2

# Axis limits passed to corner.corner(range=...) further down: one (low, high)
# pair per sampled dimension, so entry i applies to the same dimension as labels[i].
param_ranges = [
    (10, 20),
    (1, 7),
    (-2, 30),
    (-5, 20),
    (5, 80),
    (-10, 15),
    (-8, 6),
    (4, 80),
    (8, 20),
    (-25, 5),
]

# 10 Parameters associated with galaxy colors

# A zeroth and first order polynomial in log L_gal for B_sat, which controls the satellite threshold
# Bsat,r = β_0,r + β_L,r(log L_gal − 9.5)
# Bsat,b = β_0,b + β_L,b(log L_gal − 9.5)
# Constrained from projected two-point clustering comparison for r,b separately

# Weights for each galaxy luminosity, when abundance matching
# log w_cen,r = (ω_0,r / 2) (1 + erf[(log L_gal - ω_L,r) / σ_ω,r])
# log w_cen,b = (ω_0,b / 2) (1 + erf[(log L_gal - ω_L,b) / σ_ω,b])
# Constrained from Lsat,r/Lsat,b ratio and projected two-point clustering.

# A secondary, individual galaxy property can be introduced to affect the weight for abundance matching.
#  2 Parameters (one for each red and blue)
# w_χ,r = exp(χ/ω_χ,r)
# w_χ,b = exp(χ/ω_χ,b)
# Constrained from Lsat(χ|L_gal) data.

# SDSS
# Ideal Chi^2 Estimate = N_dof = N_data - N_params = 100 - 10 = 90

# BGS
# Ideal Chi^2 Estimate = N_dof = N_data - N_params = 15*2 + 30*6 + 20 - 10 = 220
# Chi squared per dof, divide by 220

# LaTeX display names of the sampled parameters, indexed by chain dimension.
# Raw strings keep the backslashes literal for the math-text renderer.
labels = [
    r'$\omega_{L,b}$',
    r'$\sigma_{\omega,b}$',
    r'$\omega_{L,r}$',
    r'$\sigma_{\omega,r}$',
    r'$\omega_{0,b}$',
    r'$\omega_{0,r}$',
    r'$\beta_{0,r}$',
    r'$\beta_{L,r}$',
    r'$\beta_{0,b}$',
    r'$\beta_{L,b}$',
]



# Analyze Results

In [None]:
# Choose the group catalog to analyze. Exactly one `gc = ...` line should be active;
# the commented-out lines are alternative catalogs from catalog_definitions.
# A deepcopy is taken so the module-level catalog definition is not modified by this notebook.
#gc = deepcopy(cat.bgs_sv3_hybrid_mcmc) # 43170, 18000
#gc = deepcopy(cat.bgs_y3_like_sv3_hybrid_mcmc_new)
#gc = deepcopy(cat.bgs_y1mini_hybrid_mcmc)

#gc = deepcopy(cat.bgs_y1_hybrid_mcmc)
#gc = deserialize(cat.bgs_y1_hybrid8_mcmc)
#gc = deserialize(cat.bgs_y1_hybrid8_v1_mcmc)
gc = deepcopy(cat.bgs_y1_hybrid8_v1_mcmc)
# NOTE(review): presumably selects the best-fit parameter vector over all stored MCMC runs;
# the exact meaning of median_best=True is defined in groupcatalog — confirm there.
best = gc.load_best_params_across_runs(median_best=True)
print(best)

# Parameter vectors recorded by hand (presumably earlier outputs of print(best) — TODO confirm provenance):
# 16.34605155  4.54724616 17.22001332  7.26752538 23.83699212  5.20258597 -3.85224961 17.38279755 10.71420271 -0.91545774]
# 12.30536655,  1.67078568, 11.82548149,  1.71729165,  8.40348227,  0.94817885, -3.24537345, 16.16321933, 10.58781695, 1.26489381

In [None]:
# Run the group-finder pipeline steps on `gc` in sequence. The calls are kept in this
# exact order; each step presumably relies on state produced by the previous one.
# Point the group finder at the halo mass function file named by HMF_T08_P18_FILE.
gc.GF_props['halomassfunc'] = HMF_T08_P18_FILE
gc.preprocess()
# popmock=True / silent=False: presumably also populate the mock and keep verbose output — TODO confirm.
gc.run_group_finder(popmock=True, silent=False)
gc.calc_wp_for_mock()
# The return value of chisqr() is discarded here.
gc.chisqr()
gc.postprocess()
# NOTE(review): presumably serializes the catalog to disk — confirm the destination before re-running.
gc.dump()

In [None]:
# Diagnostic plots for the catalog `gc`, drawn by helpers in the project `plotting` module.
# Judging by the helper names: projected clustering, then the Lsat data comparison.
pp.proj_clustering_plot(gc)
pp.lsat_data_compare_plot(gc)

In [None]:
# Merge the sampler backends of this catalog into a single chain array and a single
# log-probability array, then build flattened views with one row per sample.
mcmc_backends = gc.get_backends()[0]
chains, logprob = combine_emcee_backends(mcmc_backends)

# The third axis of `chains` indexes the sampled parameters.
dims = chains.shape[2]
chains_flat = np.reshape(chains, (-1, dims))
logprob_flat = logprob.flatten()

In [None]:
# Plot group-finder parameters from the flattened chain via the project plotting helper.
# Note: at this point chains_flat has not yet been NaN-filtered (that happens in the next cell).
pp.gfparams_plots(gc, chains_flat)
plt.show()

In [None]:
# Discard every sample whose log-probability or any parameter value is NaN.
# One shared mask is applied to both arrays so their rows stay aligned.
has_nan = np.isnan(logprob_flat) | np.isnan(chains_flat).any(axis=1)
valid_indices = ~has_nan
chains_flat, logprob_flat = chains_flat[valid_indices], logprob_flat[valid_indices]

print(f"Flat chains shape: {chains_flat.shape}")
print(f"Flat log probabilities shape: {logprob_flat.shape}")

In [None]:
# --- 1. Central 68% interval (16th to 84th percentile) of each parameter ---
bounds = [tuple(np.percentile(chains_flat[:, i], [16, 84])) for i in range(dims)]

# --- 2. Keep the samples that lie inside that interval for ALL parameters at once ---
# Broadcasting compares every row against the per-parameter edges in one shot.
interval_lo, interval_hi = np.array(bounds).T
within_bounds_mask = np.all((chains_flat >= interval_lo) & (chains_flat <= interval_hi), axis=1)

subset_chains = chains_flat[within_bounds_mask]
subset_logprob = logprob_flat[within_bounds_mask]

print(f"Found {len(subset_chains)} of {len(chains_flat)} within the central 68% region of all parameters.")

# --- 3. Best fit = highest log-probability sample inside the subset ---
# Falls back to the global maximum when the joint central region is empty.
if len(subset_chains) == 0:
    print("Warning: No points found in the central 68% region. Using the absolute best-fit instead.")
    best_fit_idx = np.argmax(logprob_flat)
    best_fit_params = chains_flat[best_fit_idx]
else:
    best_fit_idx_in_subset = np.argmax(subset_logprob)
    best_fit_params = subset_chains[best_fit_idx_in_subset]


# --- 4. Emit LaTeX table rows: label, best-fit value, [16th, 84th] percentile interval ---
# The percentiles come from the full (NaN-filtered) chain, not from the subset.
print("\\hline")
for i in range(dims):
    p16, median, p84 = np.percentile(chains_flat[:, i], [16, 50, 84])

    lower_err = median - p16
    upper_err = p84 - median

    # Interval edges are rebuilt from the median and the one-sided errors.
    low_val = median - lower_err
    high_val = median + upper_err

    best_val = best_fit_params[i]
    param_label = labels[i]

    # Alternative row format (median with asymmetric errors), kept for reference:
    #latex_row = f"{param_label.ljust(20)} & ${median:.2f}_{{-{lower_err:.2f}}}^{{+{upper_err:.2f}}}$ & ${best_val:.2f}$ \\\\"
    latex_row = f"{param_label.ljust(20)}   & {best_val:.2f}    &  [{low_val:.2f}, {high_val:.2f}]  \\\\"
    print(latex_row)

print("\\hline")

In [None]:
# Show the best N unique models (lowest chi squared) from both the full chain and the subset
N = 5  # Number of best models to display


def _print_top_models(sample_rows, sample_logprob, n_models):
    """Print the `n_models` distinct parameter vectors with the highest log-probability.

    Parameters
    ----------
    sample_rows : 2D array, one parameter vector per row.
    sample_logprob : 1D array of log-probabilities, row-aligned with `sample_rows`.
    n_models : number of distinct models to print.

    Rows are visited from highest to lowest log-probability. Rows that repeat an
    already printed vector (compared after rounding to 5 decimals) are skipped.
    chi^2 is reported as -2 * log-probability.
    """
    seen = set()
    shown = 0
    for idx in np.argsort(sample_logprob)[::-1]:
        # Convert to tuple for hashable comparison (rounded to avoid float precision issues)
        fit_tuple = tuple(np.round(sample_rows[idx], 5))
        if fit_tuple in seen:
            continue
        seen.add(fit_tuple)

        best_fit = sample_rows[idx]
        chi_squared = -2 * sample_logprob[idx]

        with np.printoptions(precision=5, suppress=True, formatter={'all': lambda x: f"{x:.3f},"}, linewidth=500):
            print(f"Model {shown+1} (chi_sq={chi_squared:.2f}): {best_fit}")

        shown += 1
        if shown >= n_models:
            break


# --- 1. Show best models from the ENTIRE chain ---
print("--- Top N Overall Models (from all chains) ---")
_print_top_models(chains_flat, logprob_flat, N)

# --- 2. Show best models from the RESTRICTED 68% SUBSET ---
print("\n--- Top N Models from 68% Central Region ---")
if len(subset_chains) > 0:
    _print_top_models(subset_chains, subset_logprob, N)
else:
    print("No models found in the subset.")

In [None]:
# Per-backend chain diagnostics: chain size, autocorrelation time, optional walker dump.
# `samples`, `reader` and `flat_samples` keep the values from the LAST backend processed
# and are reused by later cells.

# Settings are loop-invariant, so they are defined once up front.
PRINT_WALKERS = False  # set True to print the latest walker positions with below-median chi^2
burn_number = 0 # TODO choose this by inspecting the chains above. wait for convergence in all parameters
thin_number = 1

for reader in gc.get_backends()[0]:
    # Only emcee backends expose the chain / log-prob API used below.
    if not isinstance(reader, emcee.backends.backend.Backend):
        continue

    samples = reader.get_chain()
    print(f'Number of steps: {samples.shape[0] * samples.shape[1]} (total); {samples.shape[0]} (per walker), ')
    print(f'Number of walkers: {samples.shape[1]}')
    print(f'Number of parameters: {dims}')

    try:
        tau = reader.get_autocorr_time()
        print(tau)
    except emcee.autocorr.AutocorrError:
        # emcee raises this when the chain is too short for a reliable estimate.
        print("Not burnt in yet")
    except Exception as err:
        # Stay best-effort, but report the problem instead of hiding it
        # (and let KeyboardInterrupt / SystemExit propagate).
        print(f"Could not estimate autocorrelation time: {err!r}")

    # Print off the current walker positions in a nice arrays
    # One line per walker in order
    if PRINT_WALKERS:
        with np.printoptions(precision=5, suppress=True, linewidth=500,  formatter={'all': lambda x: f"{x:.3f},"}):
            current = samples[-1]
            chisqr = -2 * reader.get_log_prob(flat=False)[-1]
            median_chisqr = np.median(chisqr)
            good = np.where(chisqr < median_chisqr)
            print(np.array2string(current[good]))
            print(chisqr[good])

    flat_samples = reader.get_chain(discard=burn_number, thin=thin_number, flat=True)
    # A bare expression inside a loop is never displayed, so print the shape explicitly.
    print(f'Flat samples shape: {flat_samples.shape}')

In [None]:
# Trace plot: one panel per parameter, one line per walker, from the combined `chains`.
# The x-axis is the index along the first axis of `chains`, so the x-limit is taken
# from `chains` itself rather than from the per-backend `samples` of another cell.
n_steps = chains.shape[0]

fig, axes = plt.subplots(dims, figsize=(10, 2.5*dims), sharex=True)
# plt.subplots returns a bare Axes (not an array) when dims == 1.
axes = np.atleast_1d(axes)
for i in range(dims):
    ax = axes[i]
    ax.plot(chains[:, :, i], alpha=0.3)
    ax.set_xlim(0, n_steps)
    ax.set_ylabel(labels[i])

    # Set y-limits to cover 90% of the walkers locations
    min_val = np.nanpercentile(chains[:, :, i], 5)
    max_val = np.nanpercentile(chains[:, :, i], 95)
    ax.set_ylim(min_val, max_val)
    
    # label each walker number
    #for j in good_walkers:
    #    ax.text(10000, samples[10000, j, i], f'{j}', color='k', fontsize=6)

axes[-1].set_xlabel("step number");

In [None]:
# Inspect one walker: which steps were accepted, its chi^2 history, and one parameter's history.
# `samples` and `reader` come from the backend-diagnostics loop (last backend processed).
n = 0  # Replace with the desired walker index
p = 8  # Index of the parameter whose history is printed
walker_samples = samples[:, n, :]
# Use the same walker index `n` for chi^2 as for the samples, so both always describe one walker.
walker_chisqr = -2*reader.get_log_prob(discard=0, flat=False)[:, n]

# Create a boolean array indicating whether parameter values were updated at each step
updated = np.any(np.diff(walker_samples, axis=0) != 0, axis=1)

# Add a False at the beginning since the first step has no previous step to compare
updated = np.insert(updated, 0, False)

with np.printoptions(threshold=np.inf, linewidth=np.inf, suppress=True, formatter={'float': '{:6.1f}'.format, 'bool': '{:6}'.format}):
    print(updated)
    print(walker_chisqr)
    print(walker_samples[:, p])
    # 

In [None]:
# Corner plot of the combined chains.
# - Diagonal panels: 1D histogram of each parameter, i.e. its marginalized posterior.
# - Off-diagonal panels: 2D histograms of parameter pairs, which expose their covariances.
# Marginalizing an MCMC posterior amounts to projecting the samples onto the chosen
# parameter(s) and histogramming them, which is what each panel does.
all_flat_samples = np.reshape(chains, (-1, dims))

fig = corner.corner(
    all_flat_samples,
    labels=labels,
    range=param_ranges,
)

In [None]:
# Print the posterior median of each parameter with its 16th/84th percentile offsets.
# `flat_samples` is the flattened chain of the last backend processed in the diagnostics cell.
# The values shown are medians (50th percentile), so the header says so.
print("MEDIAN MODEL")
# Iterate over the actual number of sampled parameters instead of a hard-coded 10.
for i in range(flat_samples.shape[1]):
    mcmc = np.percentile(flat_samples[:, i], [16, 50, 84])
    # q[0] = median - 16th percentile, q[1] = 84th percentile - median
    q = np.diff(mcmc)
    txt = f"{labels[i]} = ${mcmc[1]:.3f}_{{-{q[0]:.3f}}}^{{{q[1]:.3f}}}$"
    display(Latex(txt))

# View Variances of fsat, LHMR from chains

In [None]:
# Persist per-sample quantities from the MCMC backends for the variance plots below.
# This will save off a .npy file with the array of fsat values
# (older log-file based route, kept commented out for reference)
#save_from_log(PY_SRC_FOLDER + 'exec.out', overwrite=False)

# TODO From Blobs to error estimates
# get_backends() returns a (backends, folders) pair; `folders` is not used in this cell.
backends, folders = gc.get_backends()
# NOTE(review): overwrite=True presumably replaces previously saved files — confirm before re-running.
save_from_backend(backends, overwrite=True)

In [None]:
# Mean satellite fraction with its standard deviation, per luminosity bin, for all /
# quiescent / star-forming galaxies, as returned by fsat_variance_from_saved().
fsat_std, fsatr_std, fsatb_std, fsat_mean, fsatr_mean, fsatb_mean = fsat_variance_from_saved()
#np.save(OUTPUT_FOLDER + 'std_fsat.npy', (fsat_std, fsatr_std, fsatb_std, fsat_mean, fsatr_mean, fsatb_mean))

plt.figure()
plt.errorbar(L_gal_bins, fsat_mean, yerr=fsat_std, fmt='.', color='k', label='All', capsize=3, alpha=0.7)
plt.errorbar(L_gal_bins, fsatr_mean, yerr=fsatr_std, fmt='.', color='r', label='Quiescent', capsize=3, alpha=0.7)
plt.errorbar(L_gal_bins, fsatb_mean, yerr=fsatb_std, fmt='.', color='b', label='Star-forming', capsize=3, alpha=0.7)
# Raw string: '\m' is not a valid escape sequence in a normal string literal.
plt.xlabel(r'$L_{\mathrm{gal}}$')
plt.ylabel(r'$\langle f_{\mathrm{sat}} \rangle$')
plt.legend()
plt.xscale('log')
plt.xlim(1E7, 2E11)
plt.ylim(0.0, 1.0)
plt.tight_layout()
plt.show()

In [None]:
# Plot produced by the project plotting helper LHMR_from_logs().
# NOTE(review): it takes no arguments, so it presumably reads previously saved output — confirm its source.
pp.LHMR_from_logs()

In [None]:
# Mean and standard deviation of Lsat for red (r) and blue (b), from the saved chain output.
lsat_stats = lsat_variance_from_saved()
lsat_r_mean, lsat_r_std, lsat_b_mean, lsat_b_std = lsat_stats

# Observed Lsat table used as the comparison data set.
# NOTE(review): this is the SDSS observations file while `gc` above is a BGS catalog — confirm this is intended.
data = np.loadtxt(LSAT_OBSERVATIONS_SDSS_FILE, dtype=float, skiprows=0)
pp.lsat_compare_plot(data, lsat_r_mean, lsat_b_mean, lsat_r_std, lsat_b_std)