In [1]:
### Notebook for P3D stuff!
# Setup: select GPU 2, force the JAX GPU backend, and report the JAX version/backend.
import os
# Presumably must be set before JAX initialises CUDA, otherwise it has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import jax

jax.config.update('jax_platform_name', 'gpu')
# NOTE(review): jax.lib.xla_bridge is deprecated in newer JAX; jax.default_backend()
# is the public equivalent of the platform print below.
from jax.lib import xla_bridge
print("jax version",jax.__version__)
print(xla_bridge.get_backend().platform)

import numpy as np
import jax.numpy as jnp

# NOTE(review): %pylab is deprecated (see this cell's output) and performs a star
# import of numpy into the namespace, which shadows some builtins (e.g. `sum`) for
# every later cell. Prefer `%matplotlib inline` plus explicit imports.
%pylab inline

jax version 0.4.26
gpu
%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


In [2]:
# Diagnostic only: inspect GPU memory/utilisation before running JAX work.
!nvidia-smi


Mon Nov  3 14:34:21 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 2080 Ti      Off| 00000000:1A:00.0 Off |                  N/A |
| 16%   25C    P8                1W / 250W|   8467MiB / 11264MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 2080 Ti      Off| 00000000:1B:00.0 Off |  

  pid, fd = os.forkpty()


In [3]:
# Presumably the snapshot redshift. Note: `z` is rebound later to the MUSE latent vector.
z = 2.2

bs = 200  # box side length in Mpc/h
nc = 150  # grid cells per side

mesh_shape = [nc] * 3
box_size = [bs] * 3

# Particle/grid layout: one grid point per cell, spacing = cell size in Mpc/h.
ptcl_grid_shape = (nc, nc, nc)
ptcl_spacing = bs / nc



In [4]:
# Fourier-space helpers (rfftnfreq_2d, ...) used throughout the notebook.
from helper_functions import *
import builtins

# Wavevector components for the real-FFT grid (presumably one broadcastable
# array per axis; TODO confirm against helper_functions.rfftnfreq_2d).
kvec = rfftnfreq_2d(ptcl_grid_shape, ptcl_spacing)
# |k| on the rfft grid, shape (nc, nc, nc//2+1). Use builtins.sum explicitly:
# %pylab rebinds `sum` to numpy.sum, and numpy.sum on a generator is deprecated
# (the warning previously emitted by this cell).
k = jnp.sqrt(builtins.sum(ki**2 for ki in kvec))
# k is reused by many Fourier-space operations later on.
print(k.shape)

  k = jnp.sqrt(sum(ki**2 for ki in kvec))              # (nc, nc, nc//2+1)
2025-11-03 14:34:25.085000: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.8.93). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


(150, 150, 76)


In [5]:
# Sanity check: largest wavenumber magnitude on the rfft grid.
np.max(k)

Array(4.0810485, dtype=float32)

In [6]:
#in this setup, LOS (z) is along the 0 axis, be careful when loading in data!
# NOTE(review): kz/kmu below take the LOS from kvec[0] (axis 0), but k_par/mu at the
# bottom take it from kvec[-1] (the last, rfft axis), and `mu` is what feeds the
# Legendre factors L0/L2/L4 used to build P(k, mu). Confirm which axis the skewer
# data actually uses and make these consistent.

# Squared axis-0 wavenumber, broadcast to the full grid shape.
kz = jnp.ones(k.shape)*kvec[0]**2

# Squared wavenumber from axes 1 and 2, broadcast to the grid shape.
kx = jnp.ones(k.shape)*(kvec[1]**2+kvec[2]**2)
#tiny deltas are added to avoid div by zero errors
kk = (kx+kz)+10**(-8)
# |mu| with the axis-0 convention, regularised at k = 0.
kmu = jnp.sqrt((kz/(k**2+0.00001)))

# mu = |k_par| / |k| with the LOS taken along the last axis; defined as 0 at k = 0.
k_par = jnp.broadcast_to(kvec[-1], k.shape)
mu = jnp.where(k > 0, jnp.abs(k_par) / k, 0.0)

In [7]:
# Fiducial power-spectrum multipole table.
# Columns as used below: 0 -> k, 1 -> multipole order ell, 2 -> P_ell(k).
from jax import jit, checkpoint, custom_vjp
from jax.scipy.ndimage import map_coordinates

pk_raw = np.load("/gpfs02/work/diffusion/P3D/data/pkell_red_CROL200_bf50_zoom150.npy")
pk_table = jnp.array(pk_raw[pk_raw[:, 1] <= 4])
# Keep the first 15 k-bins of ell = 0, 2, 4, stacked in that order.
m_array = jnp.concatenate(
    [pk_table[pk_table[:, 1] == ell][:15] for ell in (0, 2, 4)]
)
kmax = m_array[-1, 0]  # last tabulated k of the ell = 4 block
print(m_array.shape, kmax)


(45, 3) 2.1768007


In [8]:
import jaxinterp2d
from scipy.special import legendre

# Row indices of the ell = 0, 2, 4 entries in m_array.
l0, l2, l4 = (np.where(m_array[:, 1] == ell)[0] for ell in (0, 2, 4))

# Tabulated k nodes and P_ell(k) values per multipole.
k0, P0_tab = m_array[l0, 0], m_array[l0, 2]
k2, P2_tab = m_array[l2, 0], m_array[l2, 2]
k4, P4_tab = m_array[l4, 0], m_array[l4, 2]

# Linear interpolation of each multipole onto |k| of the grid; zero outside the
# tabulated k range.
kflat = k.reshape(-1)
P0_grid, P2_grid, P4_grid = (
    jnp.interp(kflat, k_nodes, p_nodes, left=0.0, right=0.0).reshape(k.shape)
    for k_nodes, p_nodes in ((k0, P0_tab), (k2, P2_tab), (k4, P4_tab))
)

# Legendre polynomials L_ell(mu) evaluated on the grid.
L0, L2, L4 = (legendre(order)(mu) for order in (0, 2, 4))

In [9]:
# Wavenumber cut: the smaller of the per-axis Nyquist frequency (pi * nc / bs)
# and the largest k tabulated for the monopole.
N = float(nc); Lbox = float(bs)
k_nyq = jnp.pi * N / Lbox
k_tabmax = float(k0.max())
k_cut = jnp.minimum(k_nyq, k_tabmax)


def _anisotropic_P(P0g, P2g, P4g):
    """Assemble P(k, mu) on the grid from gridded multipoles.

    Computes P0*L0(mu) + P2*L2(mu) + P4*L4(mu), clips negative values to zero and
    zeroes every mode with k > k_cut.

    Args:
        P0g, P2g, P4g: multipoles interpolated onto the grid (same shape as ``k``).

    Returns:
        Array with the shape of ``k``.
    """
    P = P0g*L0 + P2g*L2 + P4g*L4
    # Equivalent to jnp.clip(P, a_min=0.0) without the `a_min` keyword, which
    # newer JAX releases deprecate.
    P = jnp.maximum(P, 0.0)
    return jnp.where(k <= k_cut, P, 0.0)

# Baseline Plin on grid from the loaded table
Plin_grid = _anisotropic_P(P0_grid, P2_grid, P4_grid)


In [10]:
# Band powers varied by the inference: the first `k_ind_optim_max` k-bins of each of
# the `ell_bins` multipoles (ell = 0, 2, 4), i.e. 3 x 15 = 45 values here.
k_ind_optim_max = 15
ell_bins = 3
k_bins = 15
# Full fiducial band-power table, one row per multipole.
tff = m_array[:, 2].reshape(ell_bins, k_bins)
theta_fid = tff[:, :k_ind_optim_max]
# Monopole-only alternative:
# theta_fid = m_array[:k_bins,2].reshape(1,k_bins)[:,:k_ind_optim_max]


def power_b(theta):
    """Gridded P(k, mu) for a band-power vector.

    ``theta`` (ell_bins * k_ind_optim_max entries) overwrites the first
    ``k_ind_optim_max`` band powers of each multipole in ``tff``. Each multipole is
    then linearly interpolated onto |k| (zero outside the tabulated range) and
    combined by ``_anisotropic_P``.
    """
    table = tff.at[:, :k_ind_optim_max].set(theta.reshape(ell_bins, k_ind_optim_max))
    multipole_grids = [
        jnp.interp(kflat, k_nodes, table[row], left=0.0, right=0.0).reshape(k.shape)
        for row, k_nodes in enumerate((k0, k2, k4))
    ]
    return _anisotropic_P(*multipole_grids)

In [11]:
# 4) nbodykit-consistent synthesis from P (fixes amplitude & removes conj()/transpose)
def synthesize_from_P(Plin, z_realspace, box_length=None):
    """Draw a real-space Gaussian field with power spectrum ``Plin`` from white noise.

    Normalisation matches P = V / N_cell**2 * |delta_k|**2 with V = box_length**3,
    where N_cell is the total number of grid cells.

    Args:
        Plin: target power on the rfft grid of ``z_realspace``; it must broadcast
            against ``jnp.fft.rfftn(z_realspace)``.
        z_realspace: real-space white noise (presumably unit variance per cell).
        box_length: box side length. Defaults to the notebook-level ``bs``.

    Returns:
        Real field with exactly the shape of ``z_realspace``.
    """
    if box_length is None:
        box_length = bs
    V = float(box_length**3)
    n_cells = float(z_realspace.size)
    # rfftn of unit-variance white noise has variance n_cells per mode.
    white_k = jnp.fft.rfftn(z_realspace) / n_cells**0.5
    delta_k = white_k * jnp.sqrt(Plin * n_cells**2 / V)
    # Without s=, irfftn assumes an even last axis and drops a column for odd sizes.
    return jnp.fft.irfftn(delta_k, s=z_realspace.shape)
    
# # 5) nbodykit-like estimator (isotropic + multipoles) for quick checks
# def nbodykit_like_power(delta_x, nbins=30, assignment=None):
#     V = float(bs**3); Nf = float(nc); dx = bs/nc
#     # FFT and raw power
#     delta_k = jnp.fft.rfftn(delta_x)
#     Pk = (V / (Nf**6)) * (delta_k.real**2 + delta_k.imag**2)

#     # build broadcast k-components once (reuse kvec)
#     KX = jnp.broadcast_to(kvec[0], Pk.shape)
#     KY = jnp.broadcast_to(kvec[1], Pk.shape)
#     KZ = jnp.broadcast_to(kvec[2], Pk.shape)
#     kvals = jnp.sqrt(KX**2 + KY**2 + KZ**2).ravel()

#     # deconvolve mass assignment if you painted particles (CIC/TSC/NGP)
#     if assignment is not None:
#         if assignment.upper() == 'NGP': p = 1
#         elif assignment.upper() == 'CIC': p = 2
#         elif assignment.upper() == 'TSC': p = 3
#         else: p = 0
#         if p > 0:
#             def sinc_true(x): return jnp.sinc(x / jnp.pi)       # sin(x)/x
#             Wx = sinc_true(0.5*KX*dx)**p
#             Wy = sinc_true(0.5*KY*dx)**p
#             Wz = sinc_true(0.5*KZ*dx)**p
#             W2 = (Wx*Wy*Wz)**2
#             Pk = jnp.where(W2 > 0, Pk / W2, 0.0)

#     # bin up to per-axis Nyquist
#     kmax = k_nyq
#     m = (kvals <= kmax)
#     edges = jnp.linspace(0.0, kmax, nbins+1)
#     centers = 0.5*(edges[1:]+edges[:-1])
#     idx = jnp.digitize(kvals[m], edges) - 1
#     counts = jnp.bincount(idx, length=nbins)
#     Pvals = Pk.ravel()[m]
#     sums = jnp.bincount(idx, weights=Pvals, length=nbins)
#     P1d = jnp.where(counts>0, sums/counts, 0.0)

#     # multipoles with LOS = rfft axis (last axis)
#     mu_grid = jnp.where(k > 0, jnp.abs(k_par)/k, 0.0)
#     L0g, L2g, L4g = legendre(0)(mu_grid), legendre(2)(mu_grid), legendre(4)(mu_grid)

#     def bin_weight(val):
#         v = val.ravel()[m]
#         s = jnp.bincount(idx, weights=v, length=nbins)
#         return jnp.where(counts>0, s/counts, 0.0)

#     P0 = bin_weight(Pk * (1.0 * L0g)) * 1.0     # (2ℓ+1)=1
#     P2 = bin_weight(Pk * (5.0 * L2g))
#     P4 = bin_weight(Pk * (9.0 * L4g))
#     return centers, P1d, P0, P2, P4

def nbodykit_like_power(delta_x, bs, nc, nbins=30, assignment=None, interlaced=False):
    """Spherically averaged power spectrum of a gridded field (nbodykit-style normalisation).

    Args:
        delta_x: real-space field on an (nc, nc, nc) grid.
        bs: box side length (Mpc/h in this notebook).
        nc: grid cells per side.
        nbins: number of linear k bins between 0 and the per-axis Nyquist frequency.
        assignment: 'NGP', 'CIC' or 'TSC' (case-insensitive) to deconvolve that
            mass-assignment window; None or any other string disables deconvolution.
        interlaced: if True, average the field with its 1-cell shifts before the FFT
            (a crude alias damping, not true half-cell interlacing).

    Returns:
        (centers, Pk_1d): bin centres and the mean P(k) per bin (0 for empty bins).
        The k = 0 mode is excluded.
    """
    V = float(bs**3); N = float(nc); dx = bs/nc

    if interlaced:
        # A half-cell shift is not representable on the grid, so this averages the
        # field with the mean of its 1-cell rolls along each axis. The weights sum to
        # 1 (0.5 + 0.5) so the large-scale amplitude is preserved; unnormalised
        # weights would rescale the low-k power.
        rolled_mean = (jnp.roll(delta_x, 1, 0) + jnp.roll(delta_x, 1, 1) + jnp.roll(delta_x, 1, 2)) / 3.0
        delta_x = 0.5*(delta_x + rolled_mean)

    delta_k = jnp.fft.rfftn(delta_x)
    delta_k = delta_k.at[0,0,0].set(0.0)  # remove the DC (mean) mode
    # Discrete FFT power normalised to the continuous convention: V / N_cell**2 * |delta_k|**2.
    Pk = (V / (N**6)) * (delta_k.real**2 + delta_k.imag**2)

    # k-grid from the project helper; components broadcast to the rfft shape.
    kvec = rfftnfreq_2d((nc,nc,nc), dx)
    KX = jnp.broadcast_to(kvec[0], Pk.shape)
    KY = jnp.broadcast_to(kvec[1], Pk.shape)
    KZ = jnp.broadcast_to(kvec[2], Pk.shape)
    kvals = jnp.sqrt(KX**2 + KY**2 + KZ**2).ravel()

    # Deconvolve the assignment window |W|^2 with W = prod_i sinc(k_i dx / 2)^p.
    if assignment is not None:
        p = {'NGP': 1, 'CIC': 2, 'TSC': 3}.get(assignment.upper(), 0)
        if p > 0:
            def sinc_true(x): return jnp.sinc(x/jnp.pi)  # sin(x)/x
            Wx = sinc_true(0.5*KX*dx)**p
            Wy = sinc_true(0.5*KY*dx)**p
            Wz = sinc_true(0.5*KZ*dx)**p
            W2 = (Wx*Wy*Wz)**2
            Pk = jnp.where(W2>0, Pk/W2, 0.0)

    kmax = jnp.pi * nc / bs  # per-axis Nyquist
    m = (kvals > 0) & (kvals <= kmax)

    edges = jnp.linspace(0.0, kmax, nbins+1)
    centers = 0.5*(edges[1:]+edges[:-1])
    # right=True puts k in (edges[i], edges[i+1]] into bin i, so k = kmax lands in the last bin.
    idx = jnp.digitize(kvals[m], edges, right=True) - 1

    counts = jnp.bincount(idx, length=nbins)
    sums   = jnp.bincount(idx, weights=Pk.ravel()[m], length=nbins)
    Pk_1d  = jnp.where(counts>0, sums/counts, 0.0)
    return centers, Pk_1d


In [12]:
# MUSE components
from functools import partial
import jax
import jax.numpy as jnp
from muse_inference.jax import JaxMuseProblem

In [13]:

# Skewer/readout configuration arrays, read from ./configs/V1_DENSE_*.npy
prefix = "V1_DENSE_"
loc = "./configs/"

# Config variant notes: modeled DLA, anisotropy; SNR is smaller now (higher noise level).
naa, kernel, skewers_skn, skewers_dla, skewers_fin = (
    np.load(loc + prefix + fname)
    for fname in ("naa.npy", "kernel.npy", "skewers_skn.npy", "skewers_dla.npy", "skewers_fin.npy")
)

@jit
def cic_readout_jit_jnc(mesh,naa,kernel,bs=False):
    """Read a mesh out at precomputed points using cloud-in-cell weights.

    Neighbour indices and weights are precomputed, so only the mesh values (not
    the output coordinates) are differentiable.

    Args:
        mesh: gridded field; it is flattened, so ``naa`` indexes the flat mesh.
        naa: flat indices of the 8 neighbouring cells of each readout point,
            reshapeable to (n_points, 8).
        kernel: CIC weights; ``kernel[0]`` must broadcast against (n_points, 8).
        bs: unused; kept for backward compatibility.

    Returns:
        Array of shape (n_points,): the weighted sum over the 8 neighbours.
    """
    # NOTE(review): JAX clamps out-of-range gather indices instead of raising, so a
    # mesh smaller than `naa` expects silently yields wrong values.
    neighbour_vals = mesh.reshape(-1)[naa].reshape(-1, 8)
    # jnp.sum keeps the reduction in JAX instead of relying on numpy dispatch.
    return jnp.sum(neighbour_vals * kernel[0], axis=-1)



In [14]:
# Build the mock data vector: CIC readout of a mesh at the skewer points plus
# Gaussian noise with per-point variance skewers_skn.
# lin_modes_sim = np.load("./tau_mesh_red_CRO512.npy")
from scipy.ndimage import gaussian_filter
import scipy

# NOTE(review): this is the same file loaded as m_array above, i.e. a 2-D
# (k, ell, P_ell) table, not a 3-D flux/density mesh. cic_readout_jit_jnc flattens
# it and gathers with `naa`; JAX clamps out-of-range gather indices instead of
# raising, so this likely produces meaningless values silently. Confirm the
# intended mock-field file.
flux_sim_delta = np.load("/gpfs02/work/diffusion/P3D/data/pkell_red_CROL200_bf50_zoom150.npy")

map_lya_sim = cic_readout_jit_jnc(flux_sim_delta,naa,kernel)
# NOTE(review): PRNGKey(100) is used again for prob.sample_x_z, which also splits the
# key in two and draws noise from keys[1] with this shape and scale, so that
# simulation's noise equals this mock's noise realisation.
key = jax.random.PRNGKey(100)
keys = jax.random.split(key, 2)
map_lya_sim += jnp.sqrt(skewers_skn)*jax.random.normal(keys[1], (kernel.shape[1],))

# np.save("./field_mock/map_lya_sim_bf50_ell024.npy",map_lya_sim)
# map_lya_sim = np.load("./field_mock/map_lya_sim_bf50_ell0.npy")

In [15]:
tf_cut_flat = theta_fid.flatten()

noise_level = 1.0  # NOTE(review): not used in the cells shown here

def gen_map_lya(theta,z):
    """Forward model: band powers + white-noise latent -> noiseless skewer values.

    Args:
        theta: band-power vector accepted by power_b.
        z: white-noise latent vector; its first nc**3 entries form the mesh.

    Returns:
        CIC readout, at the skewer points, of the field synthesised with P(theta).
    """
    modes = z[:nc**3].reshape((nc,nc,nc))
    # Build the field from the theta-dependent spectrum. Using the fixed
    # Plin_grid here would make the output independent of theta.
    Plin = power_b(theta)
    lin_modes_real = synthesize_from_P(Plin, modes)
    return cic_readout_jit_jnc(lin_modes_real,naa,kernel)

class Jax3DMuseProblem_flat(JaxMuseProblem):
    """MUSE problem for the skewer data.

    Latent z is unit white noise on the nc**3 mesh, θ is the flattened band-power
    vector, and the data are gen_map_lya(θ, z) plus Gaussian noise with per-point
    variance skewers_skn.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    # @jax.jit
    def sample_x_z(self, key, θ):
        """Draw (x, z): z ~ N(0, 1) on the mesh, x = forward model + noise."""
        keys = jax.random.split(key, 2)
        z = jax.random.normal(keys[0], (nc*nc*nc,))
        noise = jnp.sqrt(skewers_skn)*jax.random.normal(keys[1], (kernel.shape[1],))
        x = gen_map_lya(θ,z)
        x_hat = x + noise
        return (x_hat, z)

    # @jax.jit
    def logLike(self, x, z, θ):
        """Joint Gaussian log-density log P(x, z | θ) up to a constant.

        The factor 1/2 matches the draws in sample_x_z (noise variance skewers_skn,
        z ~ N(0, 1)) and the 1/(2σ²) normalisation in logPrior; without it the
        likelihood is weighted twice as strongly as the prior.
        """
        resid = x - gen_map_lya(θ, z)
        chi2_data = jnp.sum(resid**2 / skewers_skn)
        chi2_latent = jnp.sum(z**2)
        return -0.5 * (chi2_data + chi2_latent)

    # @jax.jit
    def logPrior(self, θ):
        """Gaussian prior centred at 1.2x the fiducial band powers with 40% width."""
        return -jnp.sum(((θ-jnp.array(tf_cut_flat)*1.2)**2 / (2*(tf_cut_flat*0.4)**2)))
    


In [16]:
# Build the MUSE problem; jit / implicit_diff are forwarded to JaxMuseProblem.__init__.
prob = Jax3DMuseProblem_flat(jit=True, implicit_diff=True)
# NOTE(review): this rebinds the global `z` (the redshift set in the config cell) to the
# latent vector, and PRNGKey(100) is the same key used for the mock-data noise.
key = jax.random.PRNGKey(100)
x, z = prob.sample_x_z(key, tf_cut_flat)
# Register the mock data vector as the observation.
prob.set_x(map_lya_sim)




In [17]:
# Example linear field at the fiducial band powers, saved for inspection.
modes = z[:nc**3].reshape((nc,nc,nc))
Plin = power_b(tf_cut_flat)
# Use the spectrum just computed rather than the fixed Plin_grid, so changing the
# theta passed to power_b actually changes the field.
lin_modes_real = synthesize_from_P(Plin, modes)

# np.save appends ".npy" -> writes example_gen_field.npy in the working directory.
np.save("example_gen_field",lin_modes_real)
