In [1]:
import numpy as np
import struct
import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import blackjax

from cosmopower_jax.cosmopower_jax import CosmoPowerJAX as CPJ


In [2]:
# Base path to the data files
base_path = 'data/plik_lite_v22_TT.clik/clik/lkl_0/_external/'

# ============================================================
# 1. Read TT Power Spectrum (first 215 lines)
# ============================================================
data = np.loadtxt(f'{base_path}cl_cmb_plik_v22.dat')
tt_data = data[0:215]  # Extract only TT spectrum

# Split into columns
tt_ell = tt_data[:, 0]      # Multipole moments
tt_cl = tt_data[:, 1]       # C_ℓ values
tt_err = tt_data[:, 2]      # Uncertainties

print(f"TT spectrum: {len(tt_ell)} data points")
print(f"Multipole range: ℓ = {int(tt_ell[0])} to {int(tt_ell[-1])}")

# ============================================================
# 2. Read TT Covariance Matrix (top-left 215×215 block)
# ============================================================
with open(f'{base_path}c_matrix_plik_v22.dat', 'rb') as f:
    # Skip 4-byte Fortran record marker
    marker1 = struct.unpack('<i', f.read(4))[0]
    
    # Read the full 613×613 covariance matrix
    cov_full = np.fromfile(f, dtype='<f8', count=613*613)
    cov_full = cov_full.reshape(613, 613)
    
    # Extract only the TT block (top-left 215×215)
    tt_cov = cov_full[0:215, 0:215]

print(f"TT covariance matrix shape: {tt_cov.shape}")
print(f"Diagonal range: {tt_cov.diagonal().min():.3e} to {tt_cov.diagonal().max():.3e}")

# ============================================================
# Optional: Verify the data
# ============================================================
print(f"\nFirst few data points:")
print(f"{'ℓ':>6} {'C_ℓ':>12} {'σ(C_ℓ)':>12} {'sqrt(Cov_ii)':>12}")
print("-" * 48)
for i in range(5):
    print(f"{int(tt_ell[i]):>6} {tt_cl[i]:>12.6e} {tt_err[i]:>12.6e} {np.sqrt(tt_cov[i,i]):>12.6e}")


TT spectrum: 215 data points
Multipole range: ℓ = 32 to 2492
TT covariance matrix shape: (215, 215)
Diagonal range: 3.013e-10 to 4.804e-01

First few data points:
     ℓ          C_ℓ       σ(C_ℓ) sqrt(Cov_ii)
------------------------------------------------
    32 6.527202e+00 6.931048e-01 6.931048e-01
    37 5.274407e+00 5.187308e-01 5.187308e-01
    42 4.964522e+00 4.083057e-01 4.083057e-01
    47 3.800712e+00 3.321527e-01 3.321527e-01
    52 3.030021e+00 2.766162e-01 2.766162e-01


In [3]:
# omega_b, omega_cdm, h, tau, n_s, ln10^10A_s
cosmo_params = np.array([0.025, 0.11, 0.68, 0.1, 0.97, 3.1])
emulator = CPJ(probe='cmb_tt')
emulator_predictions = emulator.predict(cosmo_params)
emulator_predictions

I0000 00:00:1764094824.468772 1915819 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Array([1.5674932e-10, 7.3293843e-11, 4.1095505e-11, ..., 1.1857567e-17,
       1.1811569e-17, 1.1766628e-17], dtype=float32)

In [4]:
# Verify JAX linalg is available
print(f"JAX version: {jax.__version__}")
print(f"jnp.linalg available: {hasattr(jnp, 'linalg')}")
print(f"jnp.linalg.inv available: {hasattr(jnp.linalg, 'inv')}")
print(f"jnp.linalg.slogdet available: {hasattr(jnp.linalg, 'slogdet')}")

JAX version: 0.4.16
jnp.linalg available: True
jnp.linalg.inv available: True
jnp.linalg.slogdet available: True


In [5]:
tt_ell.astype(int)

array([  32,   37,   42,   47,   52,   57,   62,   67,   72,   77,   82,
         87,   92,   97,  104,  113,  122,  131,  140,  149,  158,  167,
        176,  185,  194,  203,  212,  221,  230,  239,  248,  257,  266,
        275,  284,  293,  302,  311,  320,  329,  338,  347,  356,  365,
        374,  383,  392,  401,  410,  419,  428,  437,  446,  455,  464,
        473,  482,  491,  500,  509,  518,  527,  536,  545,  554,  563,
        572,  581,  590,  599,  608,  617,  626,  635,  644,  653,  662,
        671,  680,  689,  698,  707,  716,  725,  734,  743,  752,  761,
        770,  779,  788,  797,  806,  815,  824,  833,  842,  851,  860,
        869,  878,  887,  896,  905,  914,  923,  932,  941,  950,  959,
        968,  977,  986,  995, 1004, 1013, 1022, 1031, 1040, 1049, 1058,
       1067, 1076, 1085, 1094, 1103, 1112, 1121, 1130, 1139, 1148, 1157,
       1166, 1175, 1184, 1193, 1202, 1211, 1220, 1229, 1238, 1247, 1256,
       1265, 1274, 1283, 1292, 1301, 1310, 1319, 13

In [6]:
# Find which emulator modes match tt_ell multipoles
# First, check the shapes to understand what we're working with
print(f"emulator.modes shape: {emulator.modes.shape}")
print(f"tt_ell shape: {tt_ell.shape}")
print(f"emulator_predictions shape: {emulator_predictions.shape}")

# Convert to numpy arrays for comparison
emulator_modes_np = np.array(emulator.modes)
tt_ell_int = tt_ell.astype(int)

# Create a boolean mask for which emulator modes are in tt_ell
mask = np.isin(emulator_modes_np, tt_ell_int)
print(f"\nNumber of matching modes: {mask.sum()} out of {len(emulator_modes_np)}")

# Index the predictions using the mask
predicted_tt_cl = emulator_predictions[mask]
print(f"predicted_tt_cl shape: {predicted_tt_cl.shape}")
predicted_tt_cl

emulator.modes shape: (2507,)
tt_ell shape: (215,)
emulator_predictions shape: (2507,)

Number of matching modes: 215 out of 2507
predicted_tt_cl shape: (215,)


Array([8.68605860e-13, 7.06134532e-13, 5.94297615e-13, 5.13066423e-13,
       4.52240811e-13, 4.05064247e-13, 3.67656372e-13, 3.37460582e-13,
       3.12741749e-13, 2.92248078e-13, 2.75035203e-13, 2.60546332e-13,
       2.47940340e-13, 2.37095201e-13, 2.23922457e-13, 2.09723759e-13,
       1.97832243e-13, 1.87486582e-13, 1.78028382e-13, 1.69116376e-13,
       1.60613818e-13, 1.52305780e-13, 1.44078719e-13, 1.35841303e-13,
       1.27644152e-13, 1.19394268e-13, 1.11140589e-13, 1.02960019e-13,
       9.48929385e-14, 8.69922962e-14, 7.92757515e-14, 7.17955019e-14,
       6.46449311e-14, 5.78427551e-14, 5.14443936e-14, 4.54831214e-14,
       3.99786032e-14, 3.49342950e-14, 3.03850064e-14, 2.63152215e-14,
       2.27160972e-14, 1.95896867e-14, 1.69014650e-14, 1.46389265e-14,
       1.27658165e-14, 1.12473744e-14, 1.00465739e-14, 9.13311581e-15,
       8.45837784e-15, 7.99240638e-15, 7.69386623e-15, 7.52764956e-15,
       7.46276938e-15, 7.46880195e-15, 7.51387596e-15, 7.58019949e-15,
      

In [7]:
# Precompute indices for efficient JAX-based indexing
# Precompute indices once - this maps tt_ell to positions in emulator.modes
emulator_modes_np = np.array(emulator.modes)
tt_ell_int = tt_ell.astype(int)

# For each multipole in tt_ell, find its index in emulator.modes
indices = np.array([np.where(emulator_modes_np == ell)[0][0] 
                    for ell in tt_ell_int])

# Convert to JAX array (immutable, can be captured in JIT)
indices_jax = jnp.array(indices)

print(f"Precomputed {len(indices_jax)} indices for data selection")
print(f"First few indices: {indices_jax[:5]}")
print(f"First few multipoles: {tt_ell_int[:5]}")
print(f"Corresponding emulator modes: {emulator_modes_np[indices[:5]]}")

Precomputed 215 indices for data selection
First few indices: [30 35 40 45 50]
First few multipoles: [32 37 42 47 52]
Corresponding emulator modes: [32 37 42 47 52]


In [None]:
@jax.jit
def log_likelihood(params, observed_cl, cov_matrix):
    emulator_predictions = emulator.predict(params)

    # Extract theory predictions at observed multipoles using precomputed indices
    predicted_cl = emulator_predictions[indices_jax]
    
    # Compute the Gaussian log likelihood
    delta = observed_cl - predicted_cl
    inv_cov = jnp.linalg.inv(cov_matrix)
    chi2 = delta.T @ inv_cov @ delta
    log_det_cov = jnp.linalg.slogdet(cov_matrix)[1]
    n = len(observed_cl)
    log_like = -0.5 * (chi2 + log_det_cov + n * jnp.log(2 * jnp.pi))
    return log_like


In [9]:
@jax.jit
def log_prior(params):
    """
    Log prior for cosmological parameters with uniform distributions.
    
    Parameters (in order):
    - omega_b: Physical baryon density (ωb = Ωb*h^2)
    - omega_cdm: Physical cold dark matter density (ωc = Ωc*h^2)
    - h: Reduced Hubble constant (H0 = 100h km/s/Mpc)
    - tau: Optical depth to reionization
    - n_s: Scalar spectral index
    - ln10^10A_s: Log of primordial curvature perturbation amplitude
    
    Returns:
        log(prior probability) = 0 if within bounds, -inf otherwise
    """
    omega_b, omega_cdm, h, tau, n_s, ln10_10_A_s = params
    
    # Bounds for each parameter with physical justifications:
    
    # omega_b: 0.005 - 0.1
    # Big Bang Nucleosynthesis (BBN) constrains baryon density to ~0.02-0.025
    # Allow broader range for exploration, but baryon density must be small positive value
    omega_b_min, omega_b_max = 0.005, 0.1
    
    # omega_cdm: 0.001 - 0.99
    # CDM density must be positive and cannot exceed unity (would mean Ω_cdm > 1)
    # Planck measurements suggest ~0.10-0.14, but allow wider range
    omega_cdm_min, omega_cdm_max = 0.001, 0.99
    
    # h: 0.4 - 1.0
    # Corresponds to H0 = 40-100 km/s/Mpc
    # Current measurements range from ~67-74 km/s/Mpc (h ~ 0.67-0.74)
    # Allow broader range to encompass all reasonable values
    h_min, h_max = 0.4, 1.0
    
    # tau: 0.01 - 0.8
    # Optical depth must be positive (reionization must have occurred)
    # Planck 2018 gives tau ~ 0.054 ± 0.007
    # Upper limit prevents unphysically high reionization optical depth
    tau_min, tau_max = 0.01, 0.8
    
    # n_s: 0.8 - 1.2
    # Spectral index around 1.0 means scale-invariant Harrison-Zel'dovich spectrum
    # Planck finds n_s ~ 0.965, inflation models predict slight tilt
    # Allow range that encompasses both red (n_s < 1) and blue (n_s > 1) tilts
    n_s_min, n_s_max = 0.8, 1.2
    
    # ln10^10A_s: 1.61 - 4.0
    # This is ln(10^10 * A_s) where A_s is primordial amplitude
    # Planck gives ln(10^10 A_s) ~ 3.04 ± 0.01
    # exp(1.61) ≈ 5 and exp(4.0) ≈ 55 give reasonable range for 10^10*A_s
    ln10_10_A_s_min, ln10_10_A_s_max = 1.61, 4.0
    
    # Check if all parameters are within bounds
    # For uniform prior: log(P) = 0 if in bounds, -inf if outside
    in_bounds = (
        (omega_b >= omega_b_min) & (omega_b <= omega_b_max) &
        (omega_cdm >= omega_cdm_min) & (omega_cdm <= omega_cdm_max) &
        (h >= h_min) & (h <= h_max) &
        (tau >= tau_min) & (tau <= tau_max) &
        (n_s >= n_s_min) & (n_s <= n_s_max) &
        (ln10_10_A_s >= ln10_10_A_s_min) & (ln10_10_A_s <= ln10_10_A_s_max)
    )
    
    # Return 0.0 if in bounds, -inf otherwise (using jax.numpy)
    return jnp.where(in_bounds, 0.0, -jnp.inf)

In [19]:
@jax.jit
def log_posterior(params):
    return log_prior(params) + log_likelihood(params, tt_cl, tt_cov)

In [20]:
# Diagonal inverse mass matrix - one value per parameter
# Parameters: omega_b, omega_cdm, h, tau, n_s, ln10^10A_s
# These are starting values that will be adapted during sampling
inv_mass_matrix = np.array([0.01, 0.01, 0.01, 0.01, 0.01, 0.01])
step_size = 1e-3

hmc = blackjax.nuts(log_posterior, step_size, inv_mass_matrix) # type: ignore

In [21]:
initial_position = cosmo_params.copy()
initial_state = hmc.init(initial_position) # type: ignore

In [22]:
hmc_kernel = jax.jit(hmc.step) # type: ignore

In [23]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

In [24]:
rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 100)

In [25]:
states.position[-1]
# omega_b, omega_cdm, h, tau, n_s, ln10^10A_s

# - omega_b: Physical baryon density (ωb = Ωb*h^2)
# - omega_cdm: Physical cold dark matter density (ωc = Ωc*h^2)
# - h: Reduced Hubble constant (H0 = 100h km/s/Mpc)
# - tau: Optical depth to reionization
# - n_s: Scalar spectral index
# - ln10^10A_s: Log of primordial curvature perturbation amplitude


Array([0.07862987, 0.08653661, 0.73372173, 0.02583685, 0.8399762 ,
       2.9799304 ], dtype=float32)