# Comparing NumPy and JAX Backend Performance for H0LogLikelihood

This notebook demonstrates how to select the numerical backend (NumPy or JAX) for the `H0LogLikelihood` function in the `gwsiren` pipeline and provides a basic comparison of their performance. For more comprehensive benchmarking, especially on GPU hardware, please use the `scripts/bench_jax.py` script.

## 1. Setup and Imports

In [None]:
import numpy as np
import time
import logging
import sys

# Configure basic logging to see messages from the backend selection
logging.basicConfig(stream=sys.stdout, level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Import necessary components from gwsiren
# Ensure gw-siren-pipeline is installed and accessible in your PYTHONPATH
# If this notebook lives in 'examples/' and the package is in '../gw-siren-pipeline',
# either install the package in editable mode or adjust sys.path (see the commented lines below).
try:
    from gw_siren_pipeline.gwsiren.h0_mcmc_analyzer import (
        get_log_likelihood_h0,
        DEFAULT_SIGMA_V_PEC,
        DEFAULT_C_LIGHT,
        DEFAULT_OMEGA_M,
        DEFAULT_H0_PRIOR_MIN,
        DEFAULT_H0_PRIOR_MAX,
        DEFAULT_ALPHA_PRIOR_MIN,
        DEFAULT_ALPHA_PRIOR_MAX
    )
    # CONFIG is loaded when gwsiren.config is imported, which happens internally
    # from gw_siren_pipeline.gwsiren import CONFIG # To access CONFIG.backend if needed
except ImportError as e:
    print(f"Error importing gwsiren components: {e}\n"
          "Please ensure 'gw-siren-pipeline' is installed (e.g. run 'pip install -e .' "
          "inside the 'gw-siren-pipeline' directory) and that your PYTHONPATH is set correctly.")
    # You might need to add the project root to sys.path if running directly from examples
    # import os
    # module_path = os.path.abspath(os.path.join('..'))
    # if module_path not in sys.path:
    #     sys.path.append(module_path)
    # from gw_siren_pipeline.gwsiren.h0_mcmc_analyzer import get_log_likelihood_h0, ...

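# Optional JAX setup: enable 64-bit floats (JAX defaults to 32-bit) so JAX results match NumPy's float64.
# This flag should be set before any JAX arrays are created.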
try:
    import jax
    jax.config.update("jax_enable_x64", True)
    JAX_AVAILABLE = True
    print(f"JAX available. Version: {jax.__version__}. Default device: {jax.default_backend()}, Devices: {jax.devices()}")
except ImportError:
    JAX_AVAILABLE = False
    print("JAX not available. JAX-related cells will be skipped or will show NumPy behavior.")
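JAX creates 32-bit floating-point arrays by default. Setting `jax_enable_x64` switches it to 64-bit, which matches NumPy's default. The NumPy-only sketch below shows why this matters for log-likelihoods: it stores two log-likelihood values at each precision and subtracts them. The magnitude used (around -1e5) is a hypothetical example, not a value produced by gwsiren, and `loglike_difference` is an illustrative helper.

```python
import numpy as np


def loglike_difference(logl_a: float, logl_b: float, dtype) -> float:
    """Difference of two log-likelihood values after storing them in `dtype`."""
    a = np.asarray(logl_a, dtype=dtype)
    b = np.asarray(logl_b, dtype=dtype)
    return float(a - b)


# Hypothetical log-likelihood values at two nearby H0 values
logl_1, logl_2 = -100000.001, -100000.0

diff64 = loglike_difference(logl_1, logl_2, np.float64)
diff32 = loglike_difference(logl_1, logl_2, np.float32)
print(f"float64 difference: {diff64:.6f}")  # -0.001000
print(f"float32 difference: {diff32:.6f}")  # 0.000000
```

Near 1e5, adjacent float32 values are 2^-7 ≈ 0.0078 apart, so the 0.001 difference rounds away entirely. Metropolis-style MCMC acceptance depends on exactly such log-likelihood differences.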

## 2. Data Generation

We'll generate a small mock dataset. For more extensive benchmarks, use `scripts/bench_jax.py`.

In [None]:
def generate_notebook_mock_data(num_gw_samples: int, num_hosts: int):
    """Generates simplified mock data for notebook demonstration."""
    rng = np.random.default_rng(seed=42)
    dL_gw_samples = rng.normal(loc=700, scale=70, size=num_gw_samples).astype(np.float64)
    z_values = rng.uniform(low=0.01, high=0.1, size=num_hosts).astype(np.float64)
    mass_proxy_values = rng.lognormal(mean=10, sigma=1, size=num_hosts).astype(np.float64)
    mass_proxy_values = np.maximum(mass_proxy_values, 1e-5)
    z_err_values = rng.uniform(low=0.001, high=0.002, size=num_hosts).astype(np.float64)
    return dL_gw_samples, z_values, mass_proxy_values, z_err_values

# Small dataset for quick notebook execution
NUM_GW_SAMPLES_NB = 100
NUM_HOSTS_NB = 1000 

dL_gw, z_hosts, mass_proxy, z_err_hosts = generate_notebook_mock_data(NUM_GW_SAMPLES_NB, NUM_HOSTS_NB)

print(f"Generated mock data: {dL_gw.shape[0]} GW samples, {z_hosts.shape[0]} host galaxies.")
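Evaluation cost grows with both dataset dimensions. As a rough guide, assume each GW distance sample is compared against every host galaxy (an assumption about gwsiren's internals that this notebook does not confirm). In that case, intermediate arrays have `num_gw_samples × num_hosts` entries. The sketch below uses NumPy broadcasting with hypothetical per-host model distances to show the shape and memory footprint of such an array. `pairwise_residuals` is an illustrative helper, not part of gwsiren.

```python
import numpy as np


def pairwise_residuals(dL_samples: np.ndarray, model_dL_hosts: np.ndarray) -> np.ndarray:
    """Broadcast GW distance samples (N_gw,) against host model distances (N_hosts,).

    Returns an (N_gw, N_hosts) array of residuals. Purely illustrative:
    this is not the gwsiren likelihood.
    """
    return dL_samples[:, None] - model_dL_hosts[None, :]


rng = np.random.default_rng(seed=0)
dL_samples = rng.normal(loc=700.0, scale=70.0, size=100)      # same sizes as the mock data above
model_dL_hosts = rng.uniform(low=40.0, high=450.0, size=1000)  # hypothetical model distances

residuals = pairwise_residuals(dL_samples, model_dL_hosts)
print(residuals.shape)               # (100, 1000)
print(residuals.nbytes / 1e6, "MB")  # 0.8 MB for float64
```

At 1,000 samples × 100,000 hosts, the same float64 array would hold 10^8 entries (800 MB).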

## 3. Backend Selection and Likelihood Instantiation

The `get_log_likelihood_h0` factory function uses the `backend_preference` argument (`"auto"`, `"numpy"`, or `"jax"`) to select the numerical backend. In `"auto"` mode it prefers JAX when a GPU is available and otherwise falls back to NumPy. The cell below requests each backend explicitly so that both can be compared.
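The selection rule described above can be sketched as follows. This hypothetical `select_backend` helper is written for illustration and is not gwsiren's implementation. It only mirrors the documented behaviour: explicit `"numpy"`/`"jax"` requests are honoured, and `"auto"` prefers JAX when a GPU device is visible.

```python
import numpy as np


def select_backend(preference: str = "auto"):
    """Hypothetical sketch: return (backend_name, xp) for a backend preference."""
    if preference not in ("auto", "numpy", "jax"):
        raise ValueError(f"Unknown backend preference: {preference!r}")
    if preference == "numpy":
        return "numpy", np
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        if preference == "jax":
            raise  # an explicit JAX request should fail loudly
        return "numpy", np  # "auto" falls back when JAX is not installed
    if preference == "jax":
        return "jax", jnp
    # "auto": use JAX only if at least one GPU device is visible
    has_gpu = any(device.platform == "gpu" for device in jax.devices())
    return ("jax", jnp) if has_gpu else ("numpy", np)


name, xp = select_backend("auto")
print(f"'auto' resolved to {name}")
```

Returning the array module (`xp`) together with the name matches the `backend_name`/`xp` pair the likelihood instances expose in the cell below.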

In [None]:
# Common parameters for the likelihood function
common_params = {
    "dL_gw_samples": dL_gw,
    "host_galaxies_z": z_hosts,
    "host_galaxies_mass_proxy": mass_proxy,
    "host_galaxies_z_err": z_err_hosts,
    "sigma_v": DEFAULT_SIGMA_V_PEC,
    "c_val": DEFAULT_C_LIGHT,
    "omega_m_val": DEFAULT_OMEGA_M,
    "h0_min": DEFAULT_H0_PRIOR_MIN,
    "h0_max": DEFAULT_H0_PRIOR_MAX,
    "alpha_min": DEFAULT_ALPHA_PRIOR_MIN,
    "alpha_max": DEFAULT_ALPHA_PRIOR_MAX
}

# Instantiate with NumPy backend
print("\n--- Instantiating NumPy backend ---")
log_L_numpy_instance = get_log_likelihood_h0(**common_params, backend_preference="numpy")
print(f"NumPy likelihood instance: {log_L_numpy_instance}")
print(f"NumPy backend selected: {log_L_numpy_instance.backend_name}, xp module: {log_L_numpy_instance.xp}")

# Instantiate with JAX backend (if available)
log_L_jax_instance = None
if JAX_AVAILABLE:
    print("\n--- Instantiating JAX backend ---")
    log_L_jax_instance = get_log_likelihood_h0(**common_params, backend_preference="jax")
    print(f"JAX likelihood instance: {log_L_jax_instance}")
    print(f"JAX backend selected: {log_L_jax_instance.backend_name}, xp module: {log_L_jax_instance.xp}")
    # _jitted_likelihood_core is a private attribute; getattr avoids an AttributeError if it is absent.
    if getattr(log_L_jax_instance, "_jitted_likelihood_core", None):
        print("JAX core likelihood function has been JIT-compiled.")
    else:
        print("JAX core likelihood function was NOT JIT-compiled (e.g. JIT compilation failed).")
else:
    print("\nSkipping JAX backend instantiation as JAX is not available.")
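Before timing, it is worth confirming that both backends return the same log-likelihood. The hypothetical helper below compares two callables over a few parameter vectors. It assumes, as the timing cell below does, that each instance accepts a `[H0, alpha_g]` array and returns a scalar. Pairs that are both `-inf` (for example, outside the prior) are treated as agreeing.

```python
import math

import numpy as np


def max_abs_discrepancy(f_ref, f_other, thetas) -> float:
    """Largest |f_ref(theta) - f_other(theta)| over the given parameter vectors.

    Returns nan if any comparison involves nan; pairs of equal values
    (including both being -inf) count as agreeing.
    """
    worst = 0.0
    for theta in thetas:
        theta_arr = np.asarray(theta, dtype=np.float64)
        ref = float(f_ref(theta_arr))
        other = float(f_other(theta_arr))
        if ref == other:
            continue
        diff = abs(ref - other)
        if math.isnan(diff):
            return float("nan")
        worst = max(worst, diff)
    return worst


# Toy stand-ins for two backends (not the gwsiren likelihood)
def toy_loglike_a(theta):
    return -0.5 * np.sum((theta - np.array([70.0, 0.0])) ** 2)


def toy_loglike_b(theta):
    return toy_loglike_a(theta) + 1e-12  # simulated round-off difference


thetas = [[65.0, 0.0], [70.0, 0.5], [75.0, -0.5]]
print(max_abs_discrepancy(toy_loglike_a, toy_loglike_b, thetas))
```

With the notebook's objects, `max_abs_discrepancy(log_L_numpy_instance, log_L_jax_instance, thetas)` would be expected to stay near zero when 64-bit mode is enabled.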

## 4. Basic Timing Comparison

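Let's perform a simple timing comparison. For JAX, the first call triggers JIT compilation, so several warmup calls are made before timing starts. JAX also dispatches work asynchronously, so each result is passed to `block_until_ready()` to make sure the measured time includes the computation itself.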

In [None]:
theta_sample = np.array([70.0, 0.0]) # [H0, alpha_g]
num_warmup_calls = 5
num_timed_calls = 20

def time_evaluations(likelihood_instance, theta, num_warmup, num_timed, backend_name):
    if likelihood_instance is None:
        print(f"Skipping timing for {backend_name} as instance is not available.")
        return None
    
    xp_module = likelihood_instance.xp
    theta_device = xp_module.asarray(theta) 
    
    print(f"\nWarmup for {backend_name} ({num_warmup} calls)...")
    for _ in range(num_warmup):
        res = likelihood_instance(theta_device)
        if hasattr(res, 'block_until_ready'): # For JAX
            res.block_until_ready()
            
    print(f"Timing {backend_name} ({num_timed} calls)...")
    start_time = time.perf_counter()
    for _ in range(num_timed):
        res = likelihood_instance(theta_device)
        if hasattr(res, 'block_until_ready'): # For JAX
            res.block_until_ready()
    end_time = time.perf_counter()
    
    total_time = end_time - start_time
    avg_time = total_time / num_timed
    print(f"{backend_name} average time per call: {avg_time:.6f} seconds")
    return avg_time

# Time NumPy
time_numpy = time_evaluations(log_L_numpy_instance, theta_sample, num_warmup_calls, num_timed_calls, "NumPy")

# Time JAX (if available)
time_jax = None
if JAX_AVAILABLE and log_L_jax_instance is not None:
    time_jax = time_evaluations(log_L_jax_instance, theta_sample, num_warmup_calls, num_timed_calls, "JAX")

if time_numpy is not None and time_jax is not None:
    speedup = time_numpy / time_jax
    print(f"\nJAX speed-up vs NumPy: {speedup:.2f}x (values below 1 mean JAX was slower)")
elif not JAX_AVAILABLE:
    print("\nJAX benchmark not run as JAX is unavailable.")
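`time_evaluations` reports a mean, which a single slow call (garbage collection, another process) can skew, and it does not report the cost of the first call. A hypothetical variant could record the first call separately and report the median of per-call timings. The sketch below does this and uses the same duck-typed `block_until_ready()` check as the cell above, so it accepts either backend's result type. `benchmark` and `_wait` are illustrative names, not gwsiren functions.

```python
import statistics
import time
from typing import Any, Callable

import numpy as np


def _wait(result: Any) -> Any:
    """Block until an asynchronously dispatched JAX result is ready; no-op for NumPy."""
    if hasattr(result, "block_until_ready"):
        result.block_until_ready()
    return result


def benchmark(fn: Callable[[Any], Any], arg: Any, num_warmup: int = 5, num_timed: int = 20) -> dict:
    """Time fn(arg), reporting the first call (which includes any JIT compilation) separately."""
    if num_timed < 1:
        raise ValueError("num_timed must be at least 1")
    start = time.perf_counter()
    _wait(fn(arg))
    first_call_s = time.perf_counter() - start

    for _ in range(max(num_warmup - 1, 0)):  # the first call already counts as warmup
        _wait(fn(arg))

    samples = []
    for _ in range(num_timed):
        start = time.perf_counter()
        _wait(fn(arg))
        samples.append(time.perf_counter() - start)

    return {
        "first_call_s": first_call_s,
        "median_s": statistics.median(samples),
        "min_s": min(samples),
    }


# Toy NumPy workload standing in for a likelihood instance
stats = benchmark(lambda th: np.sum(np.exp(-th ** 2)), np.linspace(-3.0, 3.0, 10_000))
print({k: f"{v:.2e}" for k, v in stats.items()})
```

Consider an instance that has not been called yet. For it, `benchmark(log_L_jax_instance, log_L_jax_instance.xp.asarray(theta_sample))` would put the compilation cost into `first_call_s` and the steady-state cost into `median_s`.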

## 5. Alternative: Using `scripts/bench_jax.py`

For more robust and configurable benchmarking, you can use the dedicated script `scripts/bench_jax.py`. This script allows you to control the number of GW samples, host galaxies, evaluations, and warmup iterations. 

You can run it from a terminal in the repository root, or from this notebook with IPython's `!` shell escape as in the cell below (adjust the relative path to match your directory layout):

In [None]:
# Example: Running the benchmark script from the notebook (if in 'examples/' directory)
# Using ! to execute shell commands. Output will appear in the notebook.
!python ../gw-siren-pipeline/scripts/bench_jax.py --num_hosts=500 --num_gw_samples=50 --num_evals=10 --num_warmup=2

The output of `bench_jax.py` will provide average timings for NumPy and JAX (if available on your system), along with the JAX device used (CPU/GPU/TPU) and the speed-up factor.

**Note:** Significant speed-ups with JAX are typically observed with larger datasets (more host galaxies and GW samples) and when JAX can use GPU or TPU hardware. For small datasets on CPU, JAX's per-call dispatch overhead (plus JIT compilation on the first call) can make it comparable to or slower than NumPy.