# Model annealing
Infer a cosmological model via Continuous Tempering Langevin. 

In [1]:
!hostname
!python -c "import jax; print(jax.default_backend(), jax.devices())"
# !nvidia-smi
# numpyro.set_platform("gpu")

import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.33' # NOTE: jax preallocates GPU (default 75%)

import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import random, jit, vmap, grad

import numpyro
from numpyro.handlers import seed, condition, trace
from functools import partial

%matplotlib inline
%load_ext autoreload 
%autoreload 2

import mlflow
mlflow.set_tracking_uri(uri="http://127.0.0.1:8080")
mlflow.set_experiment("Continuous Tempering Langevin");
# mlflow.end_run()
# mlflow.start_run(run_name="Zee")
# mlflow.log_params({"ho":2, "ha":np.array([2,3])})
# mlflow.log_metrics({"ho":2, "ha":3}, step=1)

feynmangpu04.cluster.local


gpu [cuda(id=0)]


In [6]:
run
mlflow.start_run()
run = mlflow.last_active_run()
run.info

<ActiveRun: >

# Import and simulate fiducial

In [36]:
from montecosmo.models import pmrsd_model, model_config
from montecosmo.utils import get_simulator, get_logp_fn, get_score_fn
model_config['scale_factor_lpt'] = 0.5
model_config['scale_factor_obs'] = 0.5
print(f"{model_config=}")

model = partial(pmrsd_model, **model_config)

# Cosmological parameters
cosmo_names = ['Omega_c', 'sigma8']
cosmo_labels = [r'\Omega_c', r'\sigma_8']
cond_params = {var_name+'_base': 0. for var_name in cosmo_names}

fiducial_simulator = get_simulator(condition(model, cond_params))
fiducial_params = fiducial_simulator(rng_seed=0)
fiducial_cosmo_params = {name: fiducial_params[name] for name in cosmo_names}

# Condition model
obs_names = ['obs_mesh'] # NOTE: Only condition on random sites
obs_params = {name: fiducial_params[name] for name in obs_names}
observed_model = condition(model, obs_params)
logp_fn = get_logp_fn(observed_model)
score_fn = get_score_fn(observed_model)

# Parameters to initialize samplers on
init_names = ['Omega_c', 'sigma8', 'init_mesh', 'b1', 'b2', 'bs', 'bnl'] # NOTE: Only init on random sites
init_params = {name+'_base': fiducial_params[name+'_base'] for name in init_names}

model_config={'mesh_size': array([64, 64, 64]), 'box_size': array([640, 640, 640]), 'scale_factor_lpt': 0.5, 'scale_factor_obs': 0.5, 'galaxy_density': 0.001, 'trace_reparam': True, 'trace_deterministic': False}


  return lax_numpy.astype(arr, dtype)


In [37]:
logp_fn(fiducial_params)

Array(-744184.75, dtype=float32)

In [38]:
fiducial_params['obs_mesh'] = jnp.zeros((64,64,64))
logp_fn(fiducial_params)

Array(-862459.9, dtype=float32)

In [28]:
fiducial_params['init_mesh'] = jnp.zeros((64,64,64))
logp_fn(fiducial_params)

Array(-862459.9, dtype=float32)

In [30]:
fiducial_params['init_mesh_base'] = jnp.zeros((64,64,64))
logp_fn(fiducial_params)

Array(-612866.44, dtype=float32)

In [4]:
from diffrax import diffeqsolve, ControlTerm, Euler, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree, ReversibleHeun
from jax.tree_util import tree_map, tree_flatten
from jax.flatten_util import ravel_pytree
from jax import  eval_shape

t0, t1 = 5, 0.
noise = lambda t: t/t0
drift = lambda t, y, args: tree_map(lambda x: -0.5 * x, score_fn(y, model_kwargs={'noise':noise(t)}))
diffusion = lambda t, y, args: tree_map(lambda x: jnp.ones_like(x), y)
solver = Euler()
ts = jnp.linspace(t0,t1,100)
saveat = SaveAt(ts=ts)

sample_shape_struct = eval_shape(lambda x:{name: x[name][0] for name in x.keys()}, samples)

@jit
@vmap
def get_samples(y, seed):
  brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-4, shape=sample_shape_struct, key=seed)
  terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))
  return diffeqsolve(terms, solver, t0, t1, dt0=-0.001, y0=y, max_steps=10_000, saveat=saveat).ys