## 0. Packages

In [None]:
# Numerical and Scientific Computing
import numpy as np
import pandas as pd
import scipy as sp
import math
from numba import jit


from scipy import special
from scipy.integrate import quad
from scipy import integrate
from scipy import stats
from scipy.interpolate import interp1d

# Plotting and Data Visualization
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
from matplotlib import gridspec
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, AutoMinorLocator)
import matplotlib as mpl
from matplotlib import ticker, cm
from matplotlib.colors import LinearSegmentedColormap
from tabulate import tabulate


from matplotlib import rc
from matplotlib.ticker import MaxNLocator


import pyregion

# File Handling
import os
import glob
import csv

# Astronomical

from astropy.cosmology import FlatLambdaCDM
from astropy.cosmology import Planck15
import astropy.units as u
from astropy.io import fits
from astropy.table import Table
from astroML.stats import binned_statistic_2d


# Other
from tqdm.notebook import tqdm


## 1. Functions

In [None]:
# Speed of light, presumably in km/s: slice_width() uses 100 / c0, i.e.
# H0 = 100 h km/s/Mpc divided by c — TODO confirm units.
c0 = 3e5
# Cosmological parameters; Omega_l = 1 - Omega_m, i.e. a flat model.
H_0 = 70
Omega_l = 0.7
Omega_m = 0.3

# NOTE(review): not referenced anywhere in the visible code.
lim_deltaz = 2

# Astropy flat LambdaCDM cosmology built from H_0 and Omega_m above
# (not referenced in the visible code; slice_width uses c0/Omega_* directly).
cosmo = FlatLambdaCDM(H0 = H_0, Om0 = Omega_m)

###########################################################################
############################ LSS functions ################################
###########################################################################

def M_lim(z):
    """Mass completeness limit (log10) from the Weaver et al. (2022) fit.

    Parameters
    ----------
    z : float or numpy.ndarray
        Redshift(s).

    Returns
    -------
    float or numpy.ndarray
        log10 of ``-1.51e6 (1+z) + 6.8e7 (1+z)^2``.
    """
    one_plus_z = 1 + z
    limit = -1.51e6 * one_plus_z + 6.8e7 * one_plus_z**2
    return np.log10(limit)

def M_lim_ks(z):
    """Mass completeness limit (log10) on K_s from the Weaver et al. (2022) fit.

    Parameters
    ----------
    z : float or numpy.ndarray
        Redshift(s).

    Returns
    -------
    float or numpy.ndarray
        log10 of ``-3.55e8 (1+z) + 2.7e8 (1+z)^2``. For 1 + z < 3.55 / 2.7
        (z below about 0.315) the polynomial is negative and the result is
        NaN (numpy emits a RuntimeWarning).
    """
    one_plus_z = 1 + z
    limit = -3.55e8 * one_plus_z + 2.7e8 * one_plus_z**2
    return np.log10(limit)

def slice_width(z):
    """Redshift thickness of a slice of comoving size ``physical_width``.

    Converts the global ``physical_width`` (h^-1 Mpc) into a redshift
    interval via ``dz = physical_width * 100 / c0 * E(z)`` with
    ``E(z) = sqrt(Omega_m (1+z)^3 + Omega_l)``; reads the globals
    ``physical_width``, ``c0``, ``Omega_m`` and ``Omega_l``.
    """
    e_of_z = np.sqrt(Omega_m * (1 + z)**3 + Omega_l)
    return physical_width * 100 / c0 * e_of_z


def redshift_bins(zmin, zmax):
    """Build redshift slices of fixed comoving width starting at ``zmin``.

    The first centre sits half a slice width above ``zmin``. Each further
    centre is placed one slice width (evaluated at the previous centre)
    beyond the previous one, and is kept only while it stays below ``zmax``.
    Widths come from ``slice_width`` and hence from the global
    ``physical_width`` (h^-1 Mpc).

    Returns
    -------
    centers : numpy.ndarray, shape (n,)
        Slice centre redshifts.
    edges : numpy.ndarray, shape (n, 2)
        Lower/upper edge of each slice: centre -/+ half of its own width.
    """
    centers = [zmin + 0.5 * slice_width(zmin)]
    while True:
        last = centers[-1]
        candidate = last + slice_width(last)
        if not candidate < zmax:
            break
        centers.append(candidate)
    centers = np.array(centers)

    # Redshift edges: each slice spans its own width around its centre.
    edges = np.zeros((len(centers), 2))
    for row, zc in enumerate(centers):
        half = slice_width(zc) / 2
        edges[row] = (zc - half, zc + half)

    return (centers, edges)


def cartesian_from_polar(phi, theta):
    """Unit vector(s) for spherical angles.

    Parameters
    ----------
    phi, theta : float or numpy.ndarray
        Azimuthal and polar angle in radians.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(3,) + shape(phi)`` holding the x, y, z components.
    """
    sin_theta = np.sin(theta)
    return np.array([
        sin_theta * np.cos(phi),
        sin_theta * np.sin(phi),
        np.cos(theta),
    ])

def cos_dist(alpha, delta, alpha0, delta0):
    """Cosine of the angular separation between sky positions.

    All angles are in degrees. ``alpha``/``alpha0`` are used as azimuthal
    angles and ``delta``/``delta0`` as latitudes (polar angle = 90 - delta).

    Parameters
    ----------
    alpha, delta : float or numpy.ndarray
        First set of positions [deg].
    alpha0, delta0 : float or numpy.ndarray
        Second set of positions [deg].

    Returns
    -------
    float or numpy.ndarray
        Dot product of the two unit vectors, contracted over the Cartesian
        axis with ``np.tensordot`` (array inputs on both sides therefore give
        shape ``alpha.shape + alpha0.shape``). The result is clipped to
        [-1, 1] so that ``np.arccos`` of it is always defined despite
        floating-point round-off.
    """
    phi = alpha * np.pi / 180
    theta = np.pi / 2 - delta * np.pi / 180
    phi0 = alpha0 * np.pi / 180
    theta0 = np.pi / 2 - delta0 * np.pi / 180

    x = cartesian_from_polar(phi, theta)
    x0 = cartesian_from_polar(phi0, theta0)
    cosdist = np.tensordot(x, x0, axes=[[0], [0]])
    # A cosine ranges over [-1, 1]. Clipping only guards against round-off;
    # a lower bound of 0 would map every separation beyond 90 deg to 90 deg.
    return np.clip(cosdist, -1, 1)

def logsinh(x):
    """Numerically stable ``log(sinh(x))`` for non-negative ``x``.

    Uses ``log(sinh x) = x + log(1 - exp(-2x)) - log 2``, which avoids the
    overflow of ``sinh`` for large ``x`` (e.g. large kappa in ``Log_K``).

    Parameters
    ----------
    x : float or numpy.ndarray
        Argument(s), must be >= 0. ``x == 0`` gives ``-inf``.

    Raises
    ------
    ValueError
        If any element of ``x`` is negative.
    """
    if np.any(x < 0):
        raise ValueError("logsinh only valid for positive arguments")
    # -expm1(-2x) evaluates 1 - exp(-2x) without the catastrophic
    # cancellation that `1 - np.exp(-2x)` suffers for small x.
    return x + np.log(-np.expm1(-2 * x)) - np.log(2)

def Log_K(alpha, delta, alpha0, delta0, kappa):
    """Log of the von Mises-Fisher kernel on the sphere.

    Returns ``log(kappa / (4 pi sinh kappa)) + kappa * cos(sep)``, where
    ``cos(sep)`` is ``cos_dist(alpha, delta, alpha0, delta0)`` (angles in
    degrees) and ``kappa`` is the concentration parameter.
    """
    log_normalisation = -np.log(4 * np.pi / kappa) - logsinh(kappa)
    cos_sep = cos_dist(alpha, delta, alpha0, delta0)
    return log_normalisation + cos_sep * kappa

def σ_k(X0, b, points):
    """Leave-one-out kernel sum at the position of point ``X0``.

    Sums ``points[:, 2] * K`` over every row except ``X0``, where ``K`` is
    the von Mises-Fisher kernel (``Log_K``) centred on row ``X0``. Columns 0
    and 1 are passed to ``Log_K`` as alpha/delta in degrees; the bandwidth
    ``b`` is in degrees and converted to ``kappa = 1 / b_rad**2``.
    """
    kappa = 1 / (b * np.pi / 180)**2
    centre_alpha, centre_delta = points[X0, 0], points[X0, 1]
    others = np.delete(points, X0, axis = 0)
    log_kernel = Log_K(others[:, 0], others[:, 1], centre_alpha, centre_delta, kappa)
    return np.sum(others[:, 2] * np.exp(log_kernel))
    
def LCV(b, points):
    """Likelihood cross-validation score for bandwidth ``b``.

    Mean over all rows of ``points`` of the log leave-one-out kernel sum
    ``σ_k(i, b, points)``.
    """
    n_points = len(points)
    log_loo = [np.log(σ_k(k, b, points)) for k in range(n_points)]
    return (1 / n_points) * np.sum(log_loo)

def σ_k_gaussian(X0, b, points):
    """Leave-one-out Gaussian kernel sum at the position of point ``X0``.

    For every row except ``X0``, the angular separation (radians) to row
    ``X0`` is evaluated with a zero-mean normal pdf of width ``b`` (given in
    degrees, converted to radians) and weighted by column 2.

    Parameters
    ----------
    X0 : int
        Row index of the evaluation point in ``points``.
    b : float
        Kernel bandwidth in degrees.
    points : numpy.ndarray
        Columns 0 and 1 are passed to ``cos_dist`` as alpha/delta in degrees;
        column 2 multiplies the kernel values.
    """
    centre_alpha = points[X0, 0]
    centre_delta = points[X0, 1]
    others = np.delete(points, X0, axis = 0)

    cos_sep = cos_dist(others[:, 0], others[:, 1], centre_alpha, centre_delta)
    # `norm` was never imported in this file; use the scipy.stats object.
    kernel = stats.norm.pdf(np.arccos(cos_sep), loc = 0, scale = b * np.pi / 180)
    return np.sum(others[:, 2] * kernel)

def σ(alpha, delta, b_i, points):
    """Kernel sum over all rows of ``points`` at sky position (alpha, delta).

    Unlike ``σ_k`` no row is excluded. ``b_i`` is the bandwidth in degrees,
    converted to ``kappa = 1 / b_rad**2``; column 2 multiplies the kernel.
    """
    kappa = 1 / (b_i * np.pi / 180)**2
    log_kernel = Log_K(points[:, 0], points[:, 1], alpha, delta, kappa)
    return np.sum(points[:, 2] * np.exp(log_kernel))

def Adaptive_b(b, points):
    """Per-point adaptive bandwidths from a global pilot bandwidth ``b``.

    Uses the square-root law ``b_i = b * (f_i / g) ** -0.5``, where
    ``f_i = points[i, 4] * σ(alpha_i, delta_i, b, points)`` is the pilot
    density at point i and ``g`` is the geometric mean of all ``f_i``.

    Parameters
    ----------
    b : float
        Global (pilot) bandwidth in degrees.
    points : numpy.ndarray
        Columns 0/1: alpha/delta in degrees, column 2: kernel weight,
        column 4: per-point factor applied to the pilot density.

    Returns
    -------
    numpy.ndarray
        Bandwidth for each row of ``points``.
    """
    n = len(points)
    # Each σ call is an O(N) kernel sum; evaluate it once per point and reuse
    # it for both the geometric mean and the bandwidths (it used to be
    # computed twice per point).
    pilot = np.array([points[i, 4] * σ(points[i, 0], points[i, 1], b, points)
                      for i in tqdm(range(n))])
    log_g = 1 / n * np.sum(np.log(pilot))
    b_i = b * (pilot / np.exp(log_g)) ** -0.5
    return b_i

def divider_NUV(rj):
    """Return the dividing line ``3 * rj + 1``.

    NOTE(review): presumably the NUV - r vs r - J colour cut, with ``rj`` the
    r - J colour — confirm against the selection it is used for.
    """
    return 1 + 3 * rj

In [None]:
def setup(work_path='.'):
    '''
    Create the working-directory layout under ``work_path`` and return paths.

    Missing directories (inputs, outputs, bin, outputs/plots, outputs/weights,
    outputs/density) are created and announced; existing ones are left alone.

    Returns
    -------
    tuple of str
        (outputs_dir, plots_dir, inputs_dir, weight_dir, density_dir)
    '''
    layout = ('inputs', 'outputs', 'bin',
              'outputs/plots', 'outputs/weights', 'outputs/density')
    for relative in layout:
        target = os.path.join(work_path, relative)
        if os.path.exists(target):
            continue
        os.makedirs(target)
        print(f'Built directory: {os.path.abspath(target)}')

    join = os.path.join
    return (join(work_path, 'outputs'),
            join(work_path, 'outputs', 'plots'),
            join(work_path, 'inputs'),
            join(work_path, 'outputs', 'weights'),
            join(work_path, 'outputs', 'density'))

In [None]:
# Root directory for this run; replace the placeholder. setup() creates
# inputs/, outputs/, bin/ and outputs/{plots,weights,density} beneath it.
cat_dir = "where you want to set up the catalog directories"

outputs_dir, plots_dir, inputs_dir, weights_dir, density_dir = setup(work_path=cat_dir)

## 2. Preparing Data

In [None]:
# Redshift range over which slices are built.
z_min, z_max = 0.4, 9.5

# Comoving slice thickness; read as a global by slice_width().
physical_width = 35 # h^-1 Mpc

# Slice centres (shape (n,)) and [lower, upper] edges (shape (n, 2)).
slice_centers, z_edges = redshift_bins(z_min, z_max)

# Redshift thickness of each slice (used as bar widths in the plot below).
z_width = z_edges[:, 1] - z_edges[:, 0]

In [None]:
# Placeholder: replace with the loaded catalogue. Later cells use Data as a
# pandas DataFrame (Data['zPDF'].to_numpy(), Data.loc[...]) with columns
# 'id', 'zPDF', 'zPDF_u68' and 'zPDF_l68'; a bare path string fails there.
Data = "path to your data file"

## 3. Weights

#### Assuming a Gaussian photo-z PDF
The weight of galaxy $g$ in slice $s$, whose redshift edges are $b < c$, is
$w[\text{g}][\text{s}]=\int_{b}^{c} \frac{1}{\sqrt{2 \pi}\,\sigma}\exp\left(-\frac{(x- \mu)^{2}}{2\sigma^{2}}\right)dx$
    
$w[\text{g}][\text{s}]=\frac{1}{2}[erf(a(c-\mu))-erf(a(b-\mu))]$
    
$a=\frac{1}{\sqrt{2}\sigma}$ 

The zPDFs are assumed to be Gaussian with mean $\mu$ = `zPDF` and standard deviation $\sigma$ = (`zPDF_u68` − `zPDF_l68`) / 2, i.e. half the width of the 68% interval.

weights[i][j + 1]: weight of the i-th galaxy in the j-th redshift slice <br>
the first column holds the galaxy IDs, followed by n weight columns, one per redshift slice

In [None]:
# Minimum normalised slice weight for a galaxy to count as a member of a
# slice (used as `weights_block > threshold`); also embedded in the output
# file names.
threshold = 0.05 # threshold for low weight cutoff

## 3.1 Flagged Data

In [None]:
# --- Step 1: Prepare Gaussian parameters ---
# Each photo-z PDF is modelled as a Gaussian with mean zPDF and standard
# deviation equal to half the width of the [zPDF_l68, zPDF_u68] interval.
mu = Data['zPDF'].to_numpy()
delta_z = (Data['zPDF_u68'] - Data['zPDF_l68']).to_numpy()
sigma = delta_z / 2
z_edges_array = np.array(z_edges)
n_slices = len(z_edges_array)

# --- Step 2: Compute Gaussian weights ---
# weights_block[g, s] = Gaussian probability mass of galaxy g inside slice s:
# 0.5 * [erf((upper - mu) / (sqrt(2) sigma)) - erf((lower - mu) / (sqrt(2) sigma))].
# NOTE(review): a row with sigma == 0 whose mu lies exactly on a slice edge
# gives 0/0 = NaN here — confirm the catalogue has no zero-width intervals.
erf_denominator = np.sqrt(2) * sigma[:, None]
weights_block = 0.5 * (
    special.erf((z_edges_array[:, 1] - mu[:, None]) / erf_denominator) -
    special.erf((z_edges_array[:, 0] - mu[:, None]) / erf_denominator)
)

# --- Step 3: Normalize weights ---
# Rescale each galaxy's weights so they sum to 1 over the slices.
row_sums = np.sum(weights_block, axis=1, keepdims=True)
row_sums[row_sums == 0] = 1  # prevent divide-by-zero
weights_block = weights_block / row_sums

# --- Step 4: Add IDs to form full weights matrix ---
# Column 0: galaxy id; columns 1..n_slices: normalised slice weights.
weights = np.zeros((len(Data), n_slices + 1))
weights[:, 0] = Data['id'].to_numpy()
weights[:, 1:] = weights_block

# --- Step 5: Apply threshold to find significant contributions ---
mask = weights_block > threshold
# (galaxy, slice) index pairs above the threshold. np.nonzero returns them in
# row-major order, so `ind` is already sorted by galaxy index (column 0).
ind = np.column_stack(np.nonzero(mask))  # shape: (n_selected, 2)

# --- Step 6: Group galaxies by slice ---
# gals_bin[s]: sorted indices of the galaxies whose weight in slice s passes the threshold.
gals_bin = [np.flatnonzero(mask[:, s]) for s in range(n_slices)]

# --- Step 7: Re-normalize using only selected slice weights ---
# Above-threshold weights of each galaxy are rescaled to sum to 1; weights
# below the threshold are left untouched in W.
W = weights_block.copy()
selected_weights = W[ind[:, 0], ind[:, 1]]
unique_ids, inverse_idx = np.unique(ind[:, 0], return_inverse=True)
sums = np.bincount(inverse_idx, weights=selected_weights)
sums[sums == 0] = 1
W[ind[:, 0], ind[:, 1]] = selected_weights / sums[inverse_idx]

# --- Per-slice statistics based on the zPDF point estimate ---
# Each slice mask is built once and reused. Empty slices get a count of 0 and
# NaN medians explicitly (no median of an empty selection).
z_best = Data['zPDF']
half_interval = (Data['zPDF_u68'] - Data['zPDF_l68']) / 2
n_centers = len(slice_centers)
count_in_zslice = np.zeros(n_centers, dtype=int)
delta_z_median = np.full(n_centers, np.nan)
normalized_delta_z_median = np.full(n_centers, np.nan)
for i in range(n_centers):
    in_slice = (z_best > z_edges[i, 0]) & (z_best < z_edges[i, 1])
    count_in_zslice[i] = in_slice.sum()
    if count_in_zslice[i] == 0:
        continue
    # Median sigma_z and median sigma_z / (1 + z) of the galaxies in the slice.
    delta_z_median[i] = np.median(half_interval[in_slice])
    normalized_delta_z_median[i] = np.median(
        half_interval[in_slice] / (1 + z_best[in_slice])
    )

In [None]:
# Plot with updated values
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(10, 8), gridspec_kw={'height_ratios': [1, 0.8]})

# Top bar plot dark olive
ax[0].bar(slice_centers, count_in_zslice, width=z_width, color='indigo')

ax[0].set_xlim(z_min, z_max)
ax[0].set_xticks([])
ax[0].set_ylabel('Number of galaxies', fontsize=16)
ax[0].yaxis.set_major_locator(MaxNLocator(nbins=5))
ax[0].tick_params(axis='y', labelsize=12)
ax[0].set_yscale('log')


# Bottom hexbin plot
hb = ax[1].hexbin(Data['zPDF'], sigma / (1 + Data['zPDF']), gridsize=20, cmap='viridis', bins='log')
step = 4
ax[1].plot(slice_centers[::step], normalized_delta_z_median[::step],
           color='red', linestyle='dashed', linewidth=2)

ax[1].set_xlabel('z', fontsize=16)
ax[1].set_ylabel(r'$\sigma_z / (1 + z)$', fontsize=16)
ax[1].set_xlim(z_min, z_max)
ax[1].xaxis.set_minor_locator(AutoMinorLocator())

# Add colorbar
cb = fig.colorbar(hb, ax=ax[1], orientation='horizontal', pad=0.3, fraction=0.09)
cb.set_label('counts', fontsize=16)
cb.ax.tick_params(labelsize=12)

plt.tick_params(axis='both', which='major', labelsize=12)

# Adjust spacing and save the plot
plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.4, hspace=0.1)
plt.savefig(os.path.join(plots_dir, f'histogram_th{threshold}_lengh{physical_width}_flagged.png'),
            format='png', dpi=600, bbox_inches='tight')
plt.savefig(os.path.join(plots_dir, f'histogram_th{threshold}_lengh{physical_width}_flagged.pdf'), bbox_inches='tight')
plt.show()

In [None]:
# Persist the weight matrices and per-slice summaries to weights_dir. File
# names encode the threshold and slice width; the "lengh" spelling is kept so
# existing files and downstream readers still match.
for _fname, _arr in (
    (f'weights_unthresholded_normalized_thresh{threshold}_lengh{physical_width}.npy', weights),
    (f'weightsBlock_unthresholded_normalized_thresh{threshold}_lengh{physical_width}.npy', weights_block),
    (f'weightsBlock_thresh{threshold}_normalized_lengh{physical_width}.npy', W),
    (f'normalized_delta_z_median_thresh{threshold}_lengh{physical_width}.npy', normalized_delta_z_median),
    (f'delta_z_median_thresh{threshold}_lengh{physical_width}.npy', delta_z_median),
    (f'count_in_zslice_thresh{threshold}_lengh{physical_width}.npy', count_in_zslice),
):
    np.save(os.path.join(weights_dir, _fname), _arr)