# Model annealing
Infer the parameters of a cosmological model with Continuous Tempering Langevin sampling.

In [1]:
!hostname
!python -c "import jax; print(jax.default_backend(), jax.devices())"
# !nvidia-smi
# numpyro.set_platform("gpu")

import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.33' # NOTE: jax preallocates GPU (default 75%)

import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import random, jit, vmap, grad

import numpyro
from numpyro.handlers import seed, condition, trace
from functools import partial

%matplotlib inline
%load_ext autoreload 
%autoreload 2

import mlflow
mlflow.set_tracking_uri(uri="http://127.0.0.1:8080")
mlflow.set_experiment("Continuous Tempering Langevin");
# mlflow.end_run()
# mlflow.start_run(run_name="Zee")
# mlflow.log_params({"ho":2, "ha":np.array([2,3])})
# mlflow.log_metrics({"ho":2, "ha":3}, step=1)

feynmangpu04.cluster.local


gpu [cuda(id=0)]


# Import the model and simulate a fiducial realization

In [119]:
def get_simulator(model, cond_params):
    """
    Build a jitted batch sampler for `model` conditioned on `cond_params`.

    NOTE: this first version is shadowed by the later `get_simulator` definitions
    in this notebook; the last one also supports batched model keyword arguments.
    """
    def _draw(model, cond_params, rng_seed=0):
        # Only random sample sites can be conditioned on.
        conditioned = condition(model, cond_params)
        tr = trace(seed(conditioned, rng_seed=rng_seed)).get_trace()
        return {site: tr[site]['value'] for site in tr}

    batched_draw = vmap(partial(_draw, model, cond_params))

    @partial(jit, static_argnames=('batch_size',))
    def simulator(batch_size, rng_key=random.PRNGKey(0)):
        # One independent key per batch element.
        return batched_draw(random.split(rng_key, batch_size))

    return simulator

In [132]:
# Inspect the empty-kwargs sentinel that `get_simulator` uses below: a float32 array of shape (1, 0).
jnp.asarray([[]])

Array([], dtype=float32)

In [154]:
def get_truc(c, d):
    """
    Toy experiment: probe how `len()` behaves on arguments mapped by nested vmaps.

    Returns f(a, b), where the inner vmap maps over axis 0 of `a` and the outer
    vmap maps over axis 0 of `b`; each entry is `a + len(b)`. This works when `b`
    is a dict of arrays (len() sees the dict) but raises TypeError when `b` is a
    plain 1-D array, because each mapped slice is then 0-d.
    """
    def trac(c, d, a=1, b={}):
        return a + len(b)

    over_a = vmap(partial(trac, c, d), in_axes=(0, None))
    return vmap(over_a, in_axes=(None, 0))

vvtruc = get_truc(0, 1)

# Mapping over a dict of arrays works: `len()` sees the dict itself (one key).
truc_dict_out = vvtruc(jnp.array([1, 3]), {'hey': jnp.array([2, 3])})

# Mapping over a bare 1-D array does not: each mapped slice is 0-d, so `len()` raises.
# The exception is the point of this experiment, so display it instead of letting it
# abort "Restart & Run All".
try:
    vvtruc(jnp.array([1, 3]), jnp.array([]))
except TypeError as err:
    print(f"TypeError: {err}")

truc_dict_out

TypeError: len() of unsized object

In [155]:
def get_simulator(model, cond_params):
    """
    Toy stand-in for `get_simulator`, used to prototype the double-vmap layout.

    `model` and `cond_params` are ignored. The returned function maps the inner
    vmap over axis 0 of `rng_seed` and the outer vmap over axis 0 of
    `model_kwargs`; each entry is `rng_seed + len(model_kwargs)`, so the result
    has shape (n_kwargs, n_seeds).
    """
    def sample_model(model, cond_params, rng_seed=0, model_kwargs={}):
        return rng_seed + len(model_kwargs)

    bound = partial(sample_model, model, cond_params)
    over_seeds = vmap(bound, in_axes=(0, None))
    return vmap(over_seeds, in_axes=(None, 0))

# The toy simulator ignores `model` and `cond_params`; pass None so this cell also runs
# on a fresh kernel, where `model` and `cond_params` are only defined further down.
# A separate name avoids shadowing the real `fiducial_simulator` built later.
toy_simulator = get_simulator(None, None)
toy_simulator(jnp.array([0, 1]), jnp.array([[]]))

Array([[0, 1]], dtype=int32)

In [169]:
def get_simulator(model, cond_params):
    """
    Return a simulator that samples from a model conditioned on some parameters.

    Parameters
    ----------
    model : callable
        NumPyro model to sample from.
    cond_params : dict
        Values to condition on, keyed by site name.
        NOTE: only random (sample) sites can be conditioned on.

    Returns
    -------
    simulator : callable
        Jitted ``simulator(batch_size, rng_key=random.PRNGKey(0), model_kwargs=None)``
        returning a dict mapping every traced site name to its values.
    """
    def sample_model(model, cond_params, rng_seed=0, model_kwargs=None):
        # Draw one trace of the conditioned model and return all site values.
        # A length-0 `model_kwargs` (a slice of the empty-kwargs sentinel built in
        # `simulator`) or None means "no keyword arguments".
        if model_kwargs is None or len(model_kwargs) == 0:
            model_kwargs = {}
        cond_model = condition(model, cond_params)
        cond_trace = trace(seed(cond_model, rng_seed=rng_seed)).get_trace(**model_kwargs)
        return {name: site['value'] for name, site in cond_trace.items()}

    # Inner vmap: over the model_kwargs entries, reusing the same rng key for each.
    # Outer vmap: over the batch of rng keys.
    vsample_model = vmap(partial(sample_model, model, cond_params), in_axes=(None, 0))
    vvsample_model = vmap(vsample_model, in_axes=(0, None))

    # Trailing comma makes this a tuple; ('batch_size') would just be a string.
    @partial(jit, static_argnames=('batch_size',))
    def simulator(batch_size, rng_key=random.PRNGKey(0), model_kwargs=None):
        """
        Sample a batch from the conditioned model.

        Parameters
        ----------
        batch_size : int
            Number of rng keys, i.e. independent draws (static under jit).
        rng_key : jax PRNG key
            Key split into `batch_size` keys.
        model_kwargs : dict of arrays, optional
            Keyword arguments for the model. Each value is mapped over its leading
            axis, so all values must share that leading length n_kwargs. Every draw
            is repeated with the same key for each entry.

        Returns
        -------
        dict
            Site values with leading shape (batch_size, n_kwargs); n_kwargs is 1
            when no `model_kwargs` are given.
        """
        if model_kwargs is None or len(model_kwargs) == 0:
            # The inner vmap needs a mapped leaf, which an empty dict does not provide,
            # and jnp.array([{}]) is not valid. So use a (1, 0) array instead: a single
            # entry whose slice has length 0.
            model_kwargs = jnp.array([[]])
        keys = random.split(rng_key, batch_size)
        return vvsample_model(keys, model_kwargs)

    return simulator

In [170]:
from montecosmo.models import pmrsd_model, model_config

# NOTE(review): this updates the dict imported from `montecosmo.models` in place,
# so the change is visible to anything else reading that module-level dict.
model_config.update(scale_factor_lpt=0.5, scale_factor_obs=0.5)
print(f"{model_config=}")

model = partial(pmrsd_model, **model_config)

# Cosmological parameters: condition their "<name>_base" sites to 0.
cosmo_names = ['Omega_c', 'sigma8']
cosmo_labels = [r'\Omega_c', r'\sigma_8']
cond_params = dict.fromkeys((f"{name}_base" for name in cosmo_names), 0.)

fiducial_simulator = get_simulator(model, cond_params)
# NOTE(review): with no model_kwargs every leaf has leading shape (2, 1)
# (batch axis, then a length-1 kwargs axis); the conditioning cell below uses it as is.
fiducial_params = fiducial_simulator(batch_size=2)

model_config={'mesh_size': array([64, 64, 64]), 'box_size': array([640, 640, 640]), 'scale_factor_lpt': 0.5, 'scale_factor_obs': 0.5, 'galaxy_density': 0.001, 'trace_reparam': True, 'trace_deterministic': False}


  return lax_numpy.astype(arr, dtype)


In [171]:
# 3 draws for each of the 2 `noise` values: leaves come out with leading shape (3, 2).
fiducial_simulator(batch_size=3, model_kwargs={'noise': jnp.array([1, 2])})


{'Omega_c': Array([[0.25, 0.25],
        [0.25, 0.25],
        [0.25, 0.25]], dtype=float32, weak_type=True),
 'Omega_c_base': Array([[0., 0.],
        [0., 0.],
        [0., 0.]], dtype=float32, weak_type=True),
 'b1': Array([[0.70895493, 0.70895493],
        [0.5103309 , 0.5103309 ],
        [0.8100258 , 0.8100258 ]], dtype=float32),
 'b2': Array([[-1.3717531, -1.3717531],
        [ 3.1380093,  3.1380093],
        [-9.062385 , -9.062385 ]], dtype=float32),
 'bnl': Array([[ 2.7419336,  2.7419336],
        [-3.9158497, -3.9158497],
        [-1.2736499, -1.2736499]], dtype=float32),
 'bs': Array([[ 1.6050004 ,  1.6050004 ],
        [-0.00954373, -0.00954373],
        [-4.8938384 , -4.8938384 ]], dtype=float32),
 'init_mesh': Array([[[[[-5.88610411e-01,  3.31083894e-01, -9.87633586e-01, ...,
             1.26688778e-01, -4.94568437e-01, -9.24193203e-01],
           [ 5.86835921e-01,  5.92601418e-01, -4.61189687e-01, ...,
             4.00951803e-01, -1.14705932e+00, -1.20590031e-01],
   

In [55]:
# Condition the model on the fiducial observation.
# NOTE: only random sample sites can be conditioned on.
obs_names = ['obs_mesh']
obs_params = {site: fiducial_params[site] for site in obs_names}
# NOTE(review): leaves returned by the current `get_simulator` carry leading
# (batch_size, n_kwargs) axes; check that the model accepts a batched 'obs_mesh'
# here, or select a single draw first.
observed_model = condition(model, data=obs_params)

# Initialise the sampler from the fiducial draw (random sample sites only).
init_names = ['Omega_c_base', 'sigma8_base', 'init_mesh_base']
init_params = {site: fiducial_params[site] for site in init_names}

In [None]:
from numpyro.infer.util import log_density

# FIXME(review): `def get_logp()` below has no colon and no body, so this cell is a
# SyntaxError as written. Presumably it is meant to become a factory around `logp_fn`,
# like `get_simulator` above. TODO confirm and complete.
def get_logp()
def logp_fn(mixt_val, sigma, model_kwargs, temper_prior=True):
    # Log density of `mixture_model` with its 'mixt' site fixed to `mixt_val`;
    # `[0]` keeps the first element of `log_density`'s result (presumably the log
    # joint density). TODO confirm.
    # NOTE(review): `mixture_model` is not defined in the cells shown, and `sigma` /
    # `temper_prior` are unused and nothing is returned in the lines shown.
    logp = log_density(model=mixture_model, 
                    model_args=(), 
                    model_kwargs=model_kwargs, 
                    params={'mixt':mixt_val})[0]
