# In [1]:
from functools import partial
import numpyro.distributions as dist
#import jax.numpy.fft as fft
import jax
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
import numpyro
from numpyro.handlers import condition
from pathlib import Path
import diffrax
import os,sys
import pickle
from diffrax import LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve, ConstantStepSize
from jax.scipy.ndimage import map_coordinates
from jax_cosmo.scipy.integrate import simps
from jaxpm.pm import growth_factor, growth_rate, pm_forces
from jaxpm.kernels import fftk, gradient_kernel, invlaplace_kernel, longrange_kernel
from jaxpm.painting import cic_paint, cic_paint_2d, cic_read, compensate_cic
from jaxpm.distributed import fft3d, ifft3d, normal_field
from collections import namedtuple

# In [2]:
import diffrax

# In [None]:

# XLA client memory settings. These environment variables are only honoured if
# they are set before JAX initialises its backend, so they sit ahead of any
# JAX computation in this file.
os.environ.update(
    {
        "XLA_PYTHON_CLIENT_MEM_FRACTION": ".95",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "true",
    }
)

# Run everything in single precision (x64 disabled).
jax.config.update("jax_enable_x64", False)


def lpt(cosmo, initial_conditions, positions, a, mesh_shape):
    """Return first-order LPT displacements and momenta for ``positions``.

    Args:
        cosmo: jax_cosmo cosmology. The project's ``growth_factor`` /
            ``growth_rate`` return ``(value, cosmo)``; the updated cosmology
            is threaded through to the ``Esqr`` call.
        initial_conditions: linear field on the mesh, handed to ``pm_forces``
            as ``delta``.
        positions: particle positions passed to ``pm_forces``.
        a: scale factor. It is promoted to 1-D and broadcast along the third
            mesh axis, so in practice a scalar (as the callers in this file
            pass).
        mesh_shape: list of three mesh sizes (must be a list, since ``+ [3]``
            is used).

    Returns:
        ``(dx, p)``, both flattened to shape ``(-1, 3)``: the displacement
        ``D(a) * F`` and the momentum ``a**2 * f(a) * E(a) * D(a) * F``,
        where ``F`` is the initial PM force field.
    """
    force = pm_forces(positions, mesh_shape, delta=initial_conditions)
    force = force.reshape(mesh_shape + [3])

    scale_factor = jnp.atleast_1d(a)
    growth, cosmo = growth_factor(cosmo, scale_factor)
    rate, cosmo = growth_rate(cosmo, scale_factor)
    hubble = jnp.sqrt(jc.background.Esqr(cosmo, scale_factor))

    # Broadcast the per-a prefactors against the (n, n, n, 3) force grid.
    along_third_axis = [1, 1, -1, 1]
    displacement = growth.reshape(along_third_axis) * force
    momentum_prefactor = scale_factor**2 * rate * hubble * growth
    momentum = momentum_prefactor.reshape(along_third_axis) * force

    per_particle = [-1, 3]
    return displacement.reshape(per_particle), momentum.reshape(per_particle)

def linear_field_pm(mesh_shape, box_size, pk, seed, sharding=None, field=None):
    """Colour a white-noise field with the power spectrum ``pk``.

    Args:
        mesh_shape: number of cells along each of the three axes.
        box_size: physical box length along each axis (presumably Mpc/h, the
            same length unit ``pk`` expects for k -- confirm).
        pk: callable mapping an array of wavenumbers to P(k), same shape out.
        seed: PRNG key for ``normal_field``; only used when ``field`` is None.
        sharding: optional sharding for ``normal_field``; only used when
            ``field`` is None.
        field: optional real-space white-noise field of shape ``mesh_shape``
            (e.g. a numpyro-sampled latent). When given, it is coloured
            instead of drawing fresh noise.

    Returns:
        The real-space linear field.
    """
    if field is None:
        # Without this the default field=None crashed inside fft3d; draw the
        # unit white noise that seed/sharding were always meant for.
        field = normal_field(mesh_shape, seed=seed, sharding=sharding)

    field_k = fft3d(field)
    kvec = fftk(field_k)
    # Rescale each fftk wavenumber axis by N_i / L_i to get physical |k|
    # (assumes fftk returns wavenumbers in per-cell units -- confirm).
    kmesh = sum((kk / box_size[i] * mesh_shape[i]) ** 2
                for i, kk in enumerate(kvec)) ** 0.5

    # P(k) * N_cells / V_box: presumably the per-mode variance for unit
    # white noise on this mesh.
    n_cells = mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
    box_volume = box_size[0] * box_size[1] * box_size[2]
    pkmesh = pk(kmesh) * n_cells / box_volume

    field_k = field_k * pkmesh**0.5
    return ifft3d(field_k)






def PM_numpyro_forward_model(
    linear_field,
    dens_model,
    cubegrid_size,
    cube_size,
):
    """NumPyro model: cosmology + white-noise ICs -> evolved density -> noisy data.

    Sample sites:
        ``Omega_c``, ``sigma8``: cosmological parameters with Normal priors.
        ``initial_conditions``: unit white noise of shape ``(n, n, n)``.
        ``obs_dens``: observation, ``Normal(dens, 1)`` per voxel.
    Deterministic sites:
        ``dens``: noiseless density returned by ``dens_model``.

    Args:
        linear_field: callable ``(mesh_shape, box_size, pk, seed, field=...)``
            that colours the white noise (``linear_field_pm`` in this file).
            It used to be ignored in favour of a hard-coded
            ``linear_field_pm`` call.
        dens_model: callable ``(cosmo, lin_field) -> density``, such as the
            function built by ``full_field_model_pm``.
        cubegrid_size: number of mesh cells per side (``n``).
        cube_size: box side length.

    Returns:
        The sampled (or conditioned) ``obs_dens`` value.
    """
    # NOTE(review): these Normal priors have unbounded support, so the sampler
    # can propose negative Omega_c / sigma8. That gives an invalid power
    # spectrum, so consider positive-support priors.
    Omega_c = numpyro.sample("Omega_c", dist.Normal(0.26, 0.2))
    sigma8 = numpyro.sample("sigma8", dist.Normal(0.83, 1.0))

    model_cosmo = jc.Cosmology(
        Omega_c=Omega_c,
        sigma8=sigma8,
        Omega_b=0.0492,
        Omega_k=0.0,
        h=0.67,
        n_s=0.96,
        w0=-1,
        wa=0.0,
    )
    # Replace the cosmology workspace with a fixed two-slot structure.
    # Presumably the project's growth helpers read or fill these slots -- confirm.
    Workspace = namedtuple('_workspace', ['background_radial_comoving_distance', 'background_growth_factor'])
    model_cosmo._workspace = Workspace(None, None)

    mesh_shape = [cubegrid_size, cubegrid_size, cubegrid_size]
    box_size = [cube_size, cube_size, cube_size]

    def pk_fn(x):
        """Linear matter P(k) at ``x``, interpolated from a 128-point log-k grid."""
        k = jnp.logspace(-4, 1, 128)
        pk = jc.power.linear_matter_power(model_cosmo, k)
        return jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    field = numpyro.sample(
        "initial_conditions", dist.Normal(jnp.zeros(mesh_shape), jnp.ones(mesh_shape))
    )
    # The fixed key only fills the positional seed slot; linear_field_pm
    # ignores it when ``field`` is supplied.
    lin_field = linear_field(
        mesh_shape,
        box_size,
        pk_fn, jax.random.PRNGKey(0), field=field
    )

    # Evolve the linear field to the final density.
    dens = dens_model(model_cosmo, lin_field)

    # Unit-variance Gaussian noise per voxel.
    obs = numpyro.sample("obs_dens", dist.Normal(dens, jnp.ones(mesh_shape)))

    numpyro.deterministic("dens", dens)

    return obs

def full_field_model_pm(
    cube_size,
    cubegrid_size,
    a_init,
    a_center,
    dt0=0.03,
):
    """Build a jitted particle-mesh forward model ``(cosmo, lin_field) -> density``.

    Particles start on a regular ``cubegrid_size**3`` lattice. They are
    displaced with first-order LPT at ``a_init`` and then integrated with a
    constant-step leapfrog PM solver up to a = 1. The last saved snapshot is
    CIC-painted onto the mesh.

    Args:
        cube_size: box side length (not used inside this function).
        cubegrid_size: number of mesh cells (and particles) per side.
        a_init: starting scale factor of the N-body integration. It must be a
            concrete Python/NumPy number, because it also sizes ``max_steps``.
        a_center: scale factors at which snapshots are saved. It is reversed
            before going to ``SaveAt``, so it is expected in decreasing order.
        dt0: constant integration step in scale factor. The default is the
            previously hard-coded 0.03.

    Returns:
        ``forward_model(cosmo, lin_field)``, which returns a density of shape
        ``(cubegrid_size,) * 3``.
    """
    a_final = 1.0
    # A constant-step solve takes ceil((a_final - a_init) / dt0) steps. Size
    # the cap from that, with +1 for float round-off, instead of the old
    # hard-coded 32, which aborted the solve for any a_init below ~0.04.
    max_steps = int(np.ceil((a_final - a_init) / dt0)) + 1

    @jax.jit
    def forward_model(cosmo, lin_field):
        """Evolve ``lin_field`` under ``cosmo`` and return the final CIC density."""
        mesh_shape = [cubegrid_size, cubegrid_size, cubegrid_size]

        # One particle per cell, flattened to (N, 3). jnp.meshgrid uses its
        # default "xy" indexing here. That only changes particle ordering,
        # which should not matter as long as forces are read at each
        # particle's own position -- confirm.
        particles = jnp.stack(
            jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
            axis=-1,
        ).reshape([-1, 3])

        def paint_density(t, state, args):
            """CIC-paint positions ``state[0]`` onto a zero mesh of shape ``args[1]``.

            ``t`` is unused; the signature mirrors a diffrax vector field.
            """
            pos = state[0]
            # Coordinates are cyclically permuted to (1, 2, 0) before painting.
            # NOTE(review): the reason is not visible here. Data are generated
            # by this same model, so inference is self-consistent, but confirm
            # before comparing against external maps.
            x = pos[..., 1]
            y = pos[..., 2]
            z = pos[..., 0]
            return cic_paint(jnp.zeros(args[1]), jnp.c_[x, y, z])

        def neural_nbody_ode(a, state, args):
            """PM vector field for the stacked state ``(positions, velocities)``.

            ``args`` is ``[cosmology, mesh_shape]``. Returns d(state)/da stacked
            to match ``state``.
            """
            pos, vel = state
            cosmoin = args[0]
            forces = pm_forces(pos,
                               mesh_shape=args[1],
                               paint_absolute_pos=True,
                               halo_size=0,
                               sharding=None) * 1.5 * cosmoin.Omega_m

            # Drift: position update.
            dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmoin, a))) * vel

            # Kick: velocity update.
            dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmoin, a))) * forces

            return jnp.stack([dpos, dvel], axis=0)

        eps, p = lpt(cosmo, lin_field, particles, a_init, mesh_shape)
        args = [cosmo, mesh_shape]
        solution = diffeqsolve(
            ODETerm(neural_nbody_ode),
            LeapfrogMidpoint(),
            t0=a_init,
            t1=a_final,
            dt0=dt0,
            y0=jnp.stack([particles + eps, p], axis=0),
            args=args,
            saveat=SaveAt(ts=a_center[::-1]),
            adjoint=diffrax.ReversibleAdjoint(),
            max_steps=max_steps,
            stepsize_controller=ConstantStepSize(),
        )

        # Only the last saved snapshot is painted and returned.
        return paint_density(0, solution.ys[-1], args)

    return forward_model
#####################################################################################################################

if __name__ == "__main__":
    key_number = 1

    # Output directory for the pickled trace and posterior samples.
    dir_out = Path("./")
    dir_out.mkdir(parents=True, exist_ok=True)

    # Simulation configuration.
    cube_size = 128  # Mpc/h
    cubegrid_size = 32  # npix
    a_init = 0.05

    # Fiducial cosmology. It is used both to generate the mock data and as the
    # sampler's starting point, so the two cannot drift apart.
    truth = {"sigma8": 0.83, "Omega_c": 0.26}

    # Density model: LPT + PM evolution from a_init to a = 1.
    dens_model = full_field_model_pm(
        cube_size=cube_size,
        cubegrid_size=cubegrid_size,
        a_init=a_init,
        a_center=np.linspace(a_init, 1, 41)[::-1],
    )

    # Bind everything except the random sites.
    model = partial(
        PM_numpyro_forward_model,
        linear_field=linear_field_pm,
        dens_model=dens_model,
        cubegrid_size=cubegrid_size,
        cube_size=cube_size,
    )

    # Independent keys for data generation and for the MCMC run.
    keys = jax.random.PRNGKey(3)
    subkey = jax.random.split(keys, 200)
    key_data = subkey[key_number]
    key_run = subkey[key_number + 100]

    # Draw one mock realisation at the fiducial cosmology.
    gen_model = condition(model, truth)
    model_tracer = numpyro.handlers.trace(numpyro.handlers.seed(gen_model, key_data))
    model_trace = model_tracer.get_trace()

    with open(dir_out / "model_trace.pkl", "wb") as handle:
        pickle.dump(model_trace, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Start the chain at the truth. Keys must match the sample-site names
    # exactly. The old "Omegac" key matched no site, so init_to_value
    # silently fell back to its default for Omega_c instead of using 0.26.
    init_values = {
        "initial_conditions": model_trace["initial_conditions"]["value"],
        **truth,
    }

    # Condition on the mock observation.
    observed_model = condition(
        model,
        {'obs_dens': model_trace["obs_dens"]["value"]}
    )

    nuts_kernel = numpyro.infer.NUTS(
        model=observed_model,
        init_strategy=numpyro.infer.init_to_value(values=init_values),
        max_tree_depth=5,
        step_size=0.01,
    )

    mcmc = numpyro.infer.MCMC(
        nuts_kernel,
        num_warmup=100,
        num_samples=50,
        thinning=1,
        num_chains=1,
        chain_method="vectorized",
        progress_bar=True,
    )

    mcmc.run(key_run)
    res = mcmc.get_samples()

    with open(dir_out / "sample.pkl", "wb") as handle:
        pickle.dump(res, handle, protocol=pickle.HIGHEST_PROTOCOL)


# Captured output from a previous run (not code):
#   return lax_numpy.astype(self, dtype, copy=copy, device=device)
# warmup:   4%|███                                                                           | 6/150 [00:45<09:59,  4.16s/it, 31 steps of size 1.04e-05. acc. prob=0.30]