# Model validation

In [55]:
# Print the JAX default backend and device list from a separate Python process.
!python -c "import jax; print(jax.default_backend(), jax.devices())"
# Set the XLA client memory fraction before the jax imports below.
import os; os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.99' # NOTE: jax preallocates GPU (default 75%)

import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
import jax.random as jr
from jax import jit, vmap, grad, debug, lax, flatten_util
from jax.tree_util import tree_map

import numpyro
from numpyro.handlers import seed, condition, trace
from functools import partial
from getdist import plots, MCSamples

# Inline figures, and automatic reloading of edited modules on each cell run.
%matplotlib inline
%load_ext autoreload 
%autoreload 2

# MLflow runs are sent to the tracking server at this local address, under the "Model SBI" experiment.
import mlflow
mlflow.set_tracking_uri(uri="http://127.0.0.1:8080")
mlflow.set_experiment("Model SBI")
from montecosmo.utils import pickle_dump, pickle_load, get_vlim, theme_switch
# Directory for pickled results ("~" is expanded to the user's home directory).
save_dir = os.path.expanduser("~/scratch/pickles/")

gpu [cuda(id=0)]
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Train

### Import

In [2]:
from montecosmo.models import pmrsd_model, prior_model, get_logp_fn, get_score_fn, get_simulator, get_pk_fn, get_param_fn, get_noise_fn
from montecosmo.models import print_config, condition_on_config_mean, default_config as config

# Configure the model.
# NOTE: `config` is the imported `default_config` object itself, so the
# assignments below modify it in place rather than working on a copy.
config.update(a_lpt=0.5, mesh_size=np.full(3, 16, dtype=int))
config['lik_config']['obs_std'] = 0.1

# Bind the configuration to the model and print it.
model = partial(pmrsd_model, **config)
print_config(model)

# Simulator vectorized over its inputs with vmap, then compiled with jit.
batched_simulator = vmap(get_simulator(model))
simulator = jit(batched_simulator)

# CONFIG
{'mesh_size': array([16, 16, 16]), 'box_size': array([640., 640., 640.]), 'a_lpt': 0.5, 'a_obs': 0.5, 'galaxy_density': 0.001, 'trace_reparam': False, 'trace_meshes': False, 'prior_config': {'Omega_c': ['{\\Omega}_c', 0.25, 0.1], 'sigma8': ['{\\sigma}_8', 0.831, 0.14], 'b1': ['{b}_1', 1, 0.5], 'b2': ['{b}_2', 0, 0.5], 'bs2': ['{b}_{s^2}', 0, 0.5], 'bn2': ['{b}_{\\nabla^2}', 0, 0.5]}, 'lik_config': {'obs_std': 0.1}}

# INFOS
cell_size:        [40.0, 40.0, 40.0] Mpc/h
delta_k:          0.00982 h/Mpc
k_nyquist:        0.07854 h/Mpc
mean_gxy_density: 64.000 gxy/cell



### Simulate

In [3]:
# Number of simulations; the vmapped simulator receives one integer seed per simulation.
n_simus = 10_000
seeds = jnp.arange(n_simus)
simus = simulator(rng_seed=seeds)

# NOTE(review): the target is the bare relative name "simus.p"; `save_dir`
# from the setup cell is not used here — confirm the intended location.
pickle_dump(simus, "simus.p")

  return lax_numpy.astype(arr, dtype)


In [89]:
from jax import eval_shape, tree_util, ShapeDtypeStruct
# `tfds` is needed in this cell; without this import the cell only worked when
# a later cell had already been executed, and failed with NameError on a fresh
# top-to-bottom run.
import tensorflow_datasets as tfds

# One simulator draw (rng_seed=0) from the model conditioned by
# `condition_on_config_mean`; used here only as a template for the structure
# of the simulator output.
fiduc_params = get_simulator(condition_on_config_mean(model))(rng_seed=0)

# Abstract shape/dtype description of every leaf, computed without touching the values.
shape_dtype_struct = eval_shape(lambda x: x, fiduc_params)


def feature_dict_fn(x):
    """Return a tfds Tensor feature with the shape and dtype of leaf `x`."""
    return tfds.features.Tensor(shape=x.shape, dtype=x.dtype)


# Same tree structure as `fiduc_params`, with one tfds feature per leaf.
feature_dict = tree_util.tree_map(feature_dict_fn, shape_dtype_struct)
feature_dict

{'Omega_c_': Tensor(shape=(), dtype=float32),
 'b1_': Tensor(shape=(), dtype=float32),
 'b2_': Tensor(shape=(), dtype=float32),
 'bn2_': Tensor(shape=(), dtype=float32),
 'bs2_': Tensor(shape=(), dtype=float32),
 'init_mesh_': Tensor(shape=(16, 16, 16), dtype=float32),
 'obs_mesh': Tensor(shape=(16, 16, 16), dtype=float32),
 'sigma8_': Tensor(shape=(), dtype=float32)}

In [103]:
import tensorflow_datasets as tfds
# Imported for its side effects only; presumably this is what makes the
# "mydataset" name known to tfds — TODO confirm.
import montecosmo.mydataset

# Load the dataset by name, with all other `tfds.load` options left at their defaults.
ds = tfds.load(name="mydataset")

# CONFIG
{'mesh_size': array([16, 16, 16]), 'box_size': array([640., 640., 640.]), 'a_lpt': 0.5, 'a_obs': 0.5, 'galaxy_density': 0.001, 'trace_reparam': False, 'trace_meshes': False, 'prior_config': {'Omega_c': ['{\\Omega}_c', 0.25, 0.1], 'sigma8': ['{\\sigma}_8', 0.831, 0.14], 'b1': ['{b}_1', 1, 0.5], 'b2': ['{b}_2', 0, 0.5], 'bs2': ['{b}_{s^2}', 0, 0.5], 'bn2': ['{b}_{\\nabla^2}', 0, 0.5]}, 'lik_config': {'obs_std': 0.1}}

# INFOS
cell_size:        [40.0, 40.0, 40.0] Mpc/h
delta_k:          0.00982 h/Mpc
k_nyquist:        0.07854 h/Mpc
mean_gxy_density: 64.000 gxy/cell



  return lax_numpy.astype(arr, dtype)


## Inference

In [None]:
# Fiducial draw: same model with trace_reparam=True, conditioned by
# `condition_on_config_mean`, simulated once with rng_seed=0.
reparam_model = partial(model, trace_reparam=True)
fiduc_simulator = get_simulator(condition_on_config_mean(reparam_model))
fiduc_params = fiduc_simulator(rng_seed=0)

# Sites treated as observed; their fiducial values are used to condition the model.
obs_names = ['obs_mesh']
# Alternative that also fixes the parameter sites.
# NOTE(review): the original alternative listed 'bs_' and 'bnl_', which do not
# match the 'bs2_' / 'bn2_' keys shown in the feature dict output earlier.
# obs_names = ['obs_mesh','Omega_c_','sigma8_','b1_','b2_','bs2_','bn2_']
obs_params = {key: fiduc_params[key] for key in obs_names}
observed_model = condition(model, obs_params)

# Functionals built from the conditioned model.
logp_fn = get_logp_fn(observed_model)
score_fn = get_score_fn(observed_model)

# Functionals built from the configuration, plus their vmapped and jitted versions.
pk_fn = get_pk_fn(**config)
param_fn = get_param_fn(**config)
pk_vfn = jit(vmap(pk_fn))
param_vfn = jit(vmap(param_fn))