# Modelling to Infer Void Parameters

### Current features:
- infers non-parametric luminosity function
    - build mock with the same non-parametric function, or with Schechter
- building non-parametric density function
    - currently sampling the log-density spline nodes from a Normal prior, then exponentiating to get $\rho(r)$ points
    - these points are then used to calculate shell masses which are then normalised
    - it should be noted that currently the model only constrains the $\rho(r)$ nodes up to a multiplicative constant, with rescaling happening manually
- selection correction included, with fine integration grid
    - introduce variable grid density to improve speed
- infers latent params, mag and comoving dist, for each galaxy observed

In [None]:
import numpy as np
import math
import pandas as pd
from scipy import integrate, optimize
import scipy as sp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.initialization import init_to_median, init_to_value
from jax import random, jit
from jax import numpy as jnp
from jax import vmap
from jaxopt import Bisection
# import arviz as az
from scipy.special import gammaincinv
from jax.scipy.special import gamma, gammaincc, gammainc, gammaln
from scipy.interpolate import interp1d
import pickle
import sys
from jax import numpy as jnp
from quadax import cumulative_simpson
from jax.scipy.stats.norm import logcdf as norm_logcdf

### For use on cluster

In [None]:
from mpi4py import MPI

# MPI bookkeeping: take the world communicator, then look up this
# process's ID within it and the total number of processes.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

### Functions for non-parametric luminosity

### Functions

In [None]:
# redshift -> comoving distance (second-order expansion in z)
def redshift_to_dist(z, c, omega_m):
    """
    Approximate comoving distance for a given redshift.

    Evaluates (c / 100) * (z - 0.75 * omega_m * z**2). The Hubble constant
    is not an argument: the factor 100 is hard-coded, so the result is
    presumably in Mpc/h when c is in km/s (TODO confirm units with callers).

    Takes inputs: redshift, c, and matter density param
    Outputs: comoving distance
    """
    hubble_distance = c / 100
    second_order = 0.75 * omega_m * z**2
    return hubble_distance * (z - second_order)


# comoving distance -> redshift (closed-form inverse of the expansion)
def dc_to_redshift(D_c, c, omega_m):
    """
    Redshift corresponding to a comoving distance.

    Solves D = z - 0.75 * omega_m * z**2 for z, with D = 100 * D_c / c,
    and returns the root that tends to zero as D_c tends to zero.

    Takes inputs: comoving distance (in Mpc/h), c, and matter density param
    Outputs: redshift
    """
    scaled_dist = (D_c * 100) / c
    discriminant = 1 - 3 * omega_m * scaled_dist
    numerator = 1 - jnp.sqrt(discriminant)
    return numerator / (1.5 * omega_m)


# piecewise-linear evaluation of a non-parametric function
def f_of_r(r, f_pts):
    """
    Linearly interpolate the node values ``f_pts`` at radius ``r``.

    The node positions are read from the module-level ``r_points``.
    ``r`` is first restricted to the span of ``r_points``, so radii
    outside it evaluate to the first/last node value.
    """
    r_lo, r_hi = r_points[0], r_points[-1]
    r_in_range = jnp.clip(r, r_lo, r_hi)
    return jnp.interp(r_in_range, r_points, f_pts)

### Selection correction functions

In [None]:
def dc2mu(dc):
    """
    Distance modulus for a comoving distance ``dc`` (in Mpc).

    The redshift is obtained from ``dc`` with the module-level ``c`` and
    ``omega_m``; the modulus is then 5 * log10(dc * (1 + z)) + 25.
    """
    lum_dist = dc * (1 + dc_to_redshift(dc, c, omega_m))
    return 5 * jnp.log10(lum_dist) + 25

def mu2r(mu):
    """
    Distance (in Mpc) whose distance modulus is ``mu``.

    Inverts mu = 5 * log10(d) + 25 for d.
    """
    exponent = (mu - 25) / 5
    return 10 ** exponent


def simpson2d(f_val, x_grid, y_grid):
    """
    Integrate tabulated values over a 2D grid with Simpson's rule.

    ``f_val`` is integrated along axis 1 against ``y_grid`` and then along
    axis 0 against ``x_grid``; the final entry of the running integral is
    the integral over the whole grid.
    """
    # Running integral along y for every x row.
    running_over_y = cumulative_simpson(f_val, x=y_grid, axis=1, initial=0.0)
    # Running integral of those along x.
    running_over_xy = cumulative_simpson(running_over_y, x=x_grid, axis=0, initial=0.0)
    total = running_over_xy[-1, -1]
    return total


# def log_pdf_LF(M, alpha, M_star, M_abs_Sun):
#     """Simple Gaussian-like luminosity function. Replace with Schechter or other."""
#     norm = gamma(1 - alpha) * jnp.exp(-0.4*jnp.log(10)*(M_star - M_abs_Sun))
#     Ln_L_by_Lstar = -0.4*jnp.log(10)*(M - M_star)
#     log_pdf_L = -alpha * Ln_L_by_Lstar - jnp.exp(Ln_L_by_Lstar) - jnp.log(norm) #Wrt L
#     log_pdf_M = log_pdf_L + jnp.log(0.4*jnp.log(10)) - 0.4*jnp.log(10)*(M - M_abs_Sun) #Wrt M
#     return log_pdf_M



def log_integrand_p_det(M, dc, M_star, alpha, r_vals, pdf_vals, mlim, m_err, M_abs_Sun, log_norm_L):
    """
    Log of the integrand of the detection probability.

    Sum of three log terms evaluated at absolute magnitude ``M`` and
    comoving distance ``dc``:
      * the log normal CDF of (mlim - m) / m_err, where m = M + dc2mu(dc)
        is the apparent magnitude;
      * the log of the radial pdf, linearly interpolated from
        (``r_vals``, ``pdf_vals``) at ``dc``;
      * the log luminosity-function density from ``log_pdf_LF_trunc``.
    """
    # Selection term: standardised distance of the apparent magnitude
    # from the magnitude limit.
    mu = dc2mu(dc)
    z_score = (mlim - (M + mu)) / m_err
    log_p_select = norm_logcdf(z_score)

    # Radial distribution.
    log_p_dc = jnp.log(jnp.interp(dc, r_vals, pdf_vals))

    # Luminosity function.
    log_p_M = log_pdf_LF_trunc(M, alpha, M_star, M_abs_Sun, log_norm_L)

    return log_p_select + log_p_dc + log_p_M

### Functions for non-parametric density

In [None]:
# -------------------------
# Closed-form integral of r^2 times a straight line
# -------------------------
def int_r2_linear(A, B, a, b):
    """Integral of r^2 * (A + B r) dr from ``a`` to ``b``."""
    constant_part = A * (b**3 - a**3) / 3.0   # A * integral of r^2
    slope_part = B * (b**4 - a**4) / 4.0      # B * integral of r^3
    return constant_part + slope_part

# -------------------------
# Compute mass integral over each shell from node values
# -------------------------
def compute_shell_masses(r_nodes, edges, rho_nodes):
    """
    Integrate r^2 * rho(r) over each radial shell.

    rho(r) is the piecewise-linear function through
    (``r_nodes``, ``rho_nodes``), written as a sum of linear "hat" basis
    functions, one per node. The integral of r^2 times each hat over each
    shell [edges[i], edges[i+1]] is evaluated in closed form (the same
    r^2 * (A + B r) antiderivative as ``int_r2_linear``) and collected in
    a matrix M of shape (n_shells, n_nodes); the result is M @ rho_nodes.

    Parts of a shell lying outside [r_nodes[0], r_nodes[-1]] contribute
    nothing. Assumes ``r_nodes`` and ``edges`` are 1-D arrays sorted in
    increasing order -- TODO confirm against callers.

    Returns an array of length ``edges.shape[0] - 1``.

    All (shell, segment) pairs are handled in one broadcast computation
    instead of a Python double loop of single-element ``.at[].set``
    updates, which is expensive to run eagerly and to trace.
    """
    n_shells = edges.shape[0] - 1
    n_nodes = r_nodes.shape[0]

    # Shell limits as a column so they broadcast against the node segments:
    # every 2-D array below is indexed [shell, segment].
    shell_lo = edges[:-1, None]
    shell_hi = edges[1:, None]

    # Segment k spans [r_nodes[k], r_nodes[k+1]].
    seg_lo = r_nodes[:-1]
    seg_hi = r_nodes[1:]
    denom = seg_hi - seg_lo + 1e-12  # guard against coincident nodes

    # Overlap of each shell with each segment.
    lo = jnp.maximum(shell_lo, seg_lo)
    hi = jnp.minimum(shell_hi, seg_hi)
    overlaps = hi > lo

    # Power differences shared by both halves of the hat.
    cubic = hi**3 - lo**3
    quartic = hi**4 - lo**4

    # On segment k the hat of node k+1 rises as (r - seg_lo) / denom ...
    rising = jnp.where(
        overlaps,
        (-seg_lo / denom) * cubic / 3.0 + (1.0 / denom) * quartic / 4.0,
        0.0,
    )
    # ... and the hat of node k falls as (seg_hi - r) / denom.
    falling = jnp.where(
        overlaps,
        (seg_hi / denom) * cubic / 3.0 + (-1.0 / denom) * quartic / 4.0,
        0.0,
    )

    # Node j collects the rising half from segment j-1 and the falling half
    # from segment j; the first/last node only has one half.
    M = jnp.zeros((n_shells, n_nodes))
    M = M.at[:, 1:].add(rising)
    M = M.at[:, :-1].add(falling)

    # compute shell masses
    return M @ rho_nodes

# -------------------------
# Build normalized PDF and CDF on grid
# -------------------------
def build_pdf_cdf(r_nodes, edges, rho_nodes, r_grid):
    """
    Tabulate the radial pdf and cdf implied by a set of density nodes.

    Parameters
    ----------
    r_nodes, rho_nodes : node radii and density values; rho(r) is their
        linear interpolant.
    edges : shell boundaries passed to ``compute_shell_masses``.
    r_grid : radii at which the pdf and cdf are tabulated.

    Returns
    -------
    pdf_grid : r^2 * rho(r) / total_mass on ``r_grid``, where total_mass
        is the sum of the shell masses (so it integrates to one only if
        ``r_grid`` spans the same range as the shells -- TODO confirm).
    cdf_grid : cumulative trapezoid integral of ``pdf_grid`` over
        ``r_grid``, starting at 0 and rescaled to end at exactly 1.
    shell_mass_norm : shell masses divided by their sum.
    rho_grid : rho(r) on ``r_grid``.
    total_mass : sum of the un-normalised shell masses.
    """
    # compute shell masses for exact normalization
    shell_masses = compute_shell_masses(r_nodes, edges, rho_nodes)
    total_mass = jnp.sum(shell_masses)
    shell_mass_norm = shell_masses / total_mass

    # linear interpolation of rho
    rho_grid = jnp.interp(r_grid, r_nodes, rho_nodes)

    # PDF = r^2 * rho(r) / total_mass
    pdf_grid = r_grid**2 * rho_grid / total_mass

    # CDF via cumulative trapezoid. Each panel uses its own width, so the
    # grid does not have to be uniform, and the first entry is exactly 0
    # (a plain cumsum * dr is a right-rectangle rule and starts at
    # pdf_grid[0] * dr instead).
    panel_areas = 0.5 * (pdf_grid[1:] + pdf_grid[:-1]) * jnp.diff(r_grid)
    cdf_grid = jnp.concatenate(
        [jnp.zeros(1, dtype=pdf_grid.dtype), jnp.cumsum(panel_areas)]
    )
    # Rescale so the last entry is exactly 1.
    cdf_grid = cdf_grid / cdf_grid[-1]

    return pdf_grid, cdf_grid, shell_mass_norm, rho_grid, total_mass

# -------------------------
# Draw radii by inverting a tabulated CDF
# -------------------------
def inv_cdf_sample(r_grid, cdf_grid, u):
    """
    Map uniform variates to radii through the tabulated CDF.

    u: uniform random values [0,1]
    returns the radii obtained by linearly interpolating ``r_grid`` as a
    function of ``cdf_grid`` at ``u``
    """
    r_samples = jnp.interp(u, cdf_grid, r_grid)
    return r_samples

### Model for HMC

In [None]:
def model(app_mag_obs, m_err, z_obs, z_err):
    """
    NumPyro model: luminosity-function parameters, non-parametric radial
    density and per-galaxy latent (absolute magnitude, comoving distance).

    Parameters
    ----------
    app_mag_obs : observed apparent magnitudes (observations of "m_obs").
    m_err : standard deviation of the apparent-magnitude likelihood.
    z_obs : observed redshifts (observations of "z_likelihood").
    z_err : standard deviation of the redshift likelihood.
    The per-galaxy inputs are presumably of length N_obs -- TODO confirm.

    Relies on names defined elsewhere in the file: ``dc_max_ground``,
    ``M_abs_Sun_ground``, ``schechter_log_norm_L``, ``log_pdf_LF_trunc``,
    ``n_splines``, ``r_points``, ``edges``, ``N_obs``, ``M_BRIGHT``,
    ``M_FAINT``, ``c`` and ``omega_m``.

    Note: the magnitude scatter used inside the selection integral is a
    fixed constant (0.05) held in its own local variable. It must not
    overwrite the ``m_err`` argument, otherwise the caller's magnitude
    errors are silently ignored in the "m_obs" likelihood.
    """

    # Priors
    dc_max_MCMC = dc_max_ground
    alpha = numpyro.sample("alpha", dist.Uniform(.5, 1.2))
    M_star = numpyro.sample("M_star", dist.Uniform(-23.0, -21.0))
    # Disabled alternative: sample normalised shell masses directly.
    # concentration = jnp.ones(n_splines)
    # shell_mass = numpyro.sample("shell_mass", dist.Dirichlet(concentration))

    # # convert sampled shell masses to rho(r) spline points
    # rho_pts = shell_mass / shell_weights
    # numpyro.deterministic("f_pts", rho_pts)

    M_abs_Sun = M_abs_Sun_ground

    log_norm_L = schechter_log_norm_L(alpha, M_star, M_abs_Sun)

    sigma_f = 1.0   # std dev on log(rho)

    # log rho at spline knots
    f = numpyro.sample("f_pts", dist.Normal(0.0, sigma_f).expand([n_splines]))

    rho_pts = jnp.exp(f)

    # computing volume pdf for sampled spline points
    r_vals = jnp.linspace(0, dc_max_ground, 10000)

    # Only the pdf and the normalised shell masses are needed here.
    pdf_grid, _, shell_mass, _, _ = build_pdf_cdf(r_points, edges, rho_pts, r_vals)

    numpyro.deterministic("rho_pts", rho_pts)
    numpyro.deterministic("shell_mass", shell_mass)

    # calculating phi_star for each point in chain
    phi_star = N_obs / (gamma(1 - alpha) * abs(M_star))
    numpyro.deterministic("phi_star", phi_star)

    # mag selection correction: integrate the detection integrand over a
    # (magnitude, distance) grid.
    M_grid = jnp.linspace(M_BRIGHT, M_FAINT, 1001)

    dc_min = 1e-5
    dc_grid = jnp.linspace(dc_min, dc_max_MCMC, 1001)

    mlim = 18
    # Magnitude scatter assumed by the selection function; kept separate
    # from the ``m_err`` argument used in the likelihood below.
    m_err_sel = 0.05

    X, Y = jnp.meshgrid(M_grid, dc_grid, indexing='ij')
    log_integrand = log_integrand_p_det(
        X, Y, M_star, alpha, r_vals, pdf_grid, mlim, m_err_sel, M_abs_Sun, log_norm_L
    )

    # This is p(S = 1 | Lambda) from the Overleaf notation.
    p_det = simpson2d(jnp.exp(log_integrand), M_grid, dc_grid)
    numpyro.factor("selection_effect", -N_obs * jnp.log(p_det))

    # plate over galaxies
    with numpyro.plate("data", N_obs):

        # sample absolute magnitude and comoving distance
        M_true = numpyro.sample("M_true", dist.Uniform(M_BRIGHT, M_FAINT))
        d_c_true = numpyro.sample("d_c_true", dist.Uniform(0, dc_max_MCMC))

        # apply volume prior
        inv_pdf = jnp.interp(d_c_true, r_vals, pdf_grid)
        log_prior = jnp.log(inv_pdf)

        numpyro.factor("volume_prior", log_prior)

        # compute redshift from comoving dist
        z_true = dc_to_redshift(d_c_true, c, omega_m)

        # likelihood term for redshift
        numpyro.sample("z_likelihood", dist.Normal(z_true, z_err), obs=z_obs)

        # compute luminosity distance
        d_L = (1 + z_true) * d_c_true

        # predict app mag from abs mag and lum dist
        m_model = M_true + 5 * jnp.log10(d_L) + 25 # magnitude in units related to h

        # apply Schechter prior on M_true
        numpyro.factor("M_true_prior", log_pdf_LF_trunc(M_true, alpha, M_star, M_abs_Sun, log_norm_L))

        # likelihood term for apparent magnitude (uses the caller's m_err)
        numpyro.sample("m_obs", dist.Normal(m_model, m_err), obs=app_mag_obs)