# Model Inference
Infer the latent parameters of a field-level cosmological model with MCMC samplers.

In [1]:
import os; os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='1.' # NOTE: jax preallocates GPU (default 75%)
import matplotlib.pyplot as plt
import numpy as np
from jax import numpy as jnp, random as jr, jit, vmap, grad, debug, tree

from functools import partial
from getdist import plots
from numpyro import infer

%matplotlib inline
%load_ext autoreload
%autoreload 2

from montecosmo.model import FieldLevelModel, default_config
from montecosmo.utils import pdump, pload
from montecosmo.mcbench import sample_and_save

# import mlflow
# mlflow.set_tracking_uri(uri="http://127.0.0.1:8081")
# mlflow.set_experiment("infer")
!jupyter nbconvert --to script ./src/montecosmo/tests/infer_model.ipynb

2025-01-15 11:51:11.994101: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 11.5 which is older than the ptxas CUDA version (11.8.89). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
  pid, fd = os.forkpty()


[NbConvertApp] Converting notebook ./src/montecosmo/tests/infer_model.ipynb to script
[NbConvertApp] Writing 8604 bytes to src/montecosmo/tests/infer_model.py


## Config and fiducial truth

In [None]:
################## TO SET #######################
# Select the run configuration. Inside a SLURM job array (e.g. when running the
# script that nbconvert generates from this notebook), the task id is read from
# the environment; interactively it falls back to DEFAULT_TASK_ID.
from montecosmo.script import from_id

DEFAULT_TASK_ID = 1130
task_id = int(os.environ.get('SLURM_ARRAY_TASK_ID', DEFAULT_TASK_ID))
print("SLURM_ARRAY_TASK_ID:", task_id)
# from_id maps the task id to the model, its MCMC config, and output locations.
model, mcmc_config, save_dir, save_path = from_id(task_id)

# Uncomment to redirect stdout/stderr to a log file next to the saved samples (batch runs).
# import sys
# tempstdout, tempstderr = sys.stdout, sys.stderr
# sys.stdout = sys.stderr = open(save_path+'.out', 'a')
os.makedirs(save_dir, exist_ok=True)

SLURM_ARRAY_TASK_ID: 1130


In [4]:
print(model)
print(mcmc_config)
# model.render()

# The fiducial truth is simulated once per save_dir and pickled; later sessions
# reload it so that every run is conditioned on the same synthetic observation.
truth_path = save_dir + "truth.p"
if os.path.exists(truth_path):
    print(f"Loading truth from {save_dir}")
    truth = pload(truth_path)
else:
    # Fiducial cosmology and bias values used to forward-simulate the observation.
    fiducial = {
        'Omega_m': 0.31,
        'sigma8': 0.81,
        'b1': 1.,
        'b2': 0.,
        'bs2': 0.,
        'bn2': 0.,
    }
    # Presumably clears any earlier conditioning/blocking before the forward simulation.
    model.reset()
    truth = model.predict(samples=fiducial, hide_base=False, hide_samp=False, frombase=True)

    print(f"Saving model and truth at {save_dir}")
    model.save(save_dir)
    pdump(truth, truth_path)

# Condition the model on the (synthetic) observed field for inference.
model.condition({'obs': truth['obs']})
# NOTE(review): assumes 'obs' is stored as 1 + delta — confirm.
model.delta_obs = truth['obs'] - 1
model.block()
# model.render()

# CONFIG
{'a_lpt': 0.5,
 'a_obs': 0.5,
 'box_shape': array([80., 80., 80.]),
 'gxy_density': 0.001,
 'latents': {'Omega_m': {'group': 'cosmo',
                         'high': 1.0,
                         'label': '{\\Omega}_m',
                         'loc': 0.3111,
                         'low': 0.05,
                         'scale': 0.2},
             'b1': {'group': 'bias',
                    'label': '{b}_1',
                    'loc': 1.0,
                    'scale': 0.5},
             'b2': {'group': 'bias',
                    'label': '{b}_2',
                    'loc': 0.0,
                    'scale': 2.0},
             'bn2': {'group': 'bias',
                     'label': '{b}_{\\nabla^2}',
                     'loc': 0.0,
                     'scale': 2.0},
             'bs2': {'group': 'bias',
                     'label': '{b}_{s^2}',
                     'loc': 0.0,
                     'scale': 2.0},
             'init_mesh': {'group': 'init',
                  

## Run

### NUTS, HMC

In [9]:
def get_mcmc(model, config):
    """Build a numpyro MCMC object for `model` from a sampler config.

    Parameters
    ----------
    model : callable
        numpyro model passed to the NUTS/HMC kernel.
    config : dict
        Must provide 'n_samples', 'n_chains', 'max_tree_depth',
        'target_accept_prob' and 'sampler' ("NUTS" or "HMC").

    Returns
    -------
    infer.MCMC
        MCMC with `n_samples` warmup steps and `n_samples` sampling steps per
        run, with `n_chains` chains run in vectorized mode.

    Raises
    ------
    ValueError
        If config['sampler'] is neither "NUTS" nor "HMC".
    """
    n_samples = config['n_samples']
    n_chains = config['n_chains']
    max_tree_depth = config['max_tree_depth']
    target_accept_prob = config['target_accept_prob']
    name = config['sampler']

    if name == "NUTS":
        kernel = infer.NUTS(
            model=model,
            # init_strategy=numpyro.infer.init_to_value(values=fiduc_params)
            step_size=1e-5,  # initial value only; numpyro adapts it during warmup by default
            max_tree_depth=max_tree_depth,
            target_accept_prob=target_accept_prob,)

    elif name == "HMC":
        kernel = infer.HMC(
            model=model,
            # init_strategy=numpyro.infer.init_to_value(values=fiduc_params),
            step_size=1e-5,
            # Rule of thumb (2**max_tree_depth-1)*step_size_NUTS/(2 to 4), compare with default 2pi.
            # NOTE: hard-coded for max_tree_depth=10 and a NUTS step size of 1e-3.
            trajectory_length=1023 * 1e-3 / 4,
            target_accept_prob=target_accept_prob,)

    else:
        # Fail early and clearly instead of an UnboundLocalError on `kernel` below.
        raise ValueError(f"Unknown sampler {name!r}; expected 'NUTS' or 'HMC'.")

    mcmc = infer.MCMC(
        sampler=kernel,
        num_warmup=n_samples,
        num_samples=n_samples, # for each run
        num_chains=n_chains,
        chain_method="vectorized",
        progress_bar=True,)

    return mcmc

# print("mean_acc_prob:", last_state.mean_accept_prob, "\nss:", last_state.adapt_state.step_size)
# invmm = list(last_state.adapt_state.inverse_mass_matrix.values())[0][0]
# invmm.min(),invmm.max(),invmm.mean(),invmm.std()

# Init params
# init_model = model.copy()
# init_model.partial(temp=1e-2)
# init_params_ = init_model.predict(samples=n_chains)

In [10]:
# Either resume a previous run from its pickled last state, or build fresh
# per-chain initial parameters through a preliminary warmup.
continue_run = False
if continue_run:
    # Resume: rebuild the observation-conditioned sampler and restart from the
    # state saved by a previous run; with post_warmup_state set and
    # num_warmup = 0, no new warmup is performed.
    model.reset()
    model.condition({'obs': truth['obs']})
    model.block()
    mcmc = get_mcmc(model.model, mcmc_config)

    last_state = pload(save_path + "_last_state.p")
    mcmc.num_warmup = 0
    mcmc.post_warmup_state = last_state
    init_params_ = None
else:
    # Fresh start, stage 1: additionally condition on model.prior_loc
    # (frombase=True) for a preliminary warmup. Presumably this holds the
    # non-mesh latents at their prior location — TODO confirm.
    model.reset()
    model.condition({'obs': truth['obs']} | model.prior_loc, frombase=True)
    model.block()
    mcmc = get_mcmc(model.model, mcmc_config)
    
    print("Init params")
    # One draw from model.init_model per chain, from a fixed seed (43).
    init_params_ = jit(vmap(model.init_model))(jr.split(jr.key(43), mcmc_config['n_chains']))
    # Only the 'init_mesh_' draws are used to initialise the preliminary run.
    init_mesh_ = {k: init_params_[k] for k in ['init_mesh_']} # NOTE: !!!!!!!
    # Warmup-only pass (0 sampling runs), saved under the '_init' suffix.
    mcmc = sample_and_save(mcmc, save_path+'_init', 0, 0, extra_fields=['num_steps'], init_params=init_mesh_)
    
    print("mean_acc_prob:", mcmc.last_state.mean_accept_prob, "\nss:", mcmc.last_state.adapt_state.step_size)
    # Keys present in the final warmup state override the initial draws.
    init_params_ |= mcmc.last_state.z
    print(init_params_.keys())

    # Stage 2: rebuild the sampler conditioned on the observation only; the
    # main run starts from init_params_.
    model.reset()
    model.condition({'obs': truth['obs']})
    model.block()
    mcmc = get_mcmc(model.model, mcmc_config)

Init params


  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)


run 0/0 (warmup)


  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
warmup: 100%|██████████| 64/64 [01:04<00:00,  1.01s/it]


mean_acc_prob: [0.68993956 0.6840822  0.6904469  0.6878417  0.6914676  0.6852821
 0.6864878  0.6881593 ] 
ss: [0.24592769 0.1429676  0.16511315 0.12081773 0.184664   0.12074462
 0.15272841 0.14145015]
dict_keys(['Omega_m_', 'b1_', 'b2_', 'bn2_', 'bs2_', 'init_mesh_', 'sigma8_'])


In [9]:
# Main run: warmup followed by mcmc_config['n_runs'] sampling runs, presumably
# each saved under save_path, starting from the per-chain init_params_.
mcmc_runned = sample_and_save(
    mcmc,
    save_path,
    0,
    mcmc_config['n_runs'],
    extra_fields=['num_steps'],
    init_params=init_params_,
)

run 0/2 (warmup)


  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
warmup: 100%|██████████| 5/5 [00:34<00:00,  6.93s/it]


run 1/2


sample: 100%|██████████| 5/5 [00:00<00:00, 11.57it/s]


run 2/2


sample: 100%|██████████| 5/5 [00:00<00:00, 11.81it/s]


In [None]:
# Alternative (disabled): start every chain at the true latent values, broadcast
# over mcmc_config['n_chains'], instead of the warmed-up init_params_.
# model.reset()
# model.condition({'obs': truth['obs']})
# model.block()
# mcmc = get_mcmc(model.model, mcmc_config)
# init_params_ = {k+'_': jnp.broadcast_to(truth[k+'_'], (mcmc_config['n_chains'], *jnp.shape(truth[k+'_']))) for k in ['Omega_m','sigma8','b1','b2','bs2','bn2','init_mesh']}

# NOTE: argument order matches the main-run sample_and_save call (mcmc, save_path, 0, n_runs, ...).
# mcmc_runned = sample_and_save(mcmc, save_path, 0, mcmc_config['n_runs'], extra_fields=['num_steps'], init_params=init_params_)

In [None]:
from montecosmo.samplers import NUTSwG_init, get_NUTSwG_run

# NUTS-within-Gibbs: one warmup (adaptation) run, then `n_runs` sampling runs of
# `n_samples` steps per chain. Each run's samples and infos go to an .npz file;
# the latest chain state is pickled after every run.
if init_params_ is None:
    raise ValueError("NUTSwG needs fresh init_params_: run the init cell with continue_run=False.")

n_samples, n_runs = 1024, 2
# Must match the chain axis of init_params_ (drawn with mcmc_config['n_chains']),
# otherwise the vmapped calls below see inconsistent batch sizes.
n_chains = mcmc_config['n_chains']
# NOTE: rebinds the save_path returned by from_id; cells re-run after this one
# (e.g. continue_run) will read/write the NUTSGibbs path instead.
save_path = save_dir + f"NUTSGibbs_ns{n_samples:d}_x_nc{n_chains}"

step_fn, init_fn, parameters, init_state_fn = NUTSwG_init(model.logpdf)
warmup_fn = jit(vmap(get_NUTSwG_run(model.logpdf, step_fn, init_fn, parameters, n_samples, warmup=True)))
key = jr.key(42)
last_state = jit(vmap(init_state_fn))(init_params_)
# To resume from a saved state instead:
# last_state = pload(save_dir+"NUTSGibbs/HMCGibbs_ns256_x_nc8_laststate32.p")

# Warmup. Its per-chain keys come from `key`, not from jr.key(43), which already
# produced the per-chain keys used to draw init_params_ (JAX keys must not be reused).
key, warmup_key = jr.split(key, 2)
(last_state, parameters), samples, infos = warmup_fn(jr.split(warmup_key, n_chains), last_state)
print(parameters,'\n=======\n')
jnp.savez(save_path+f"_{0}.npz", **samples | infos)
pdump(last_state, save_path+f"_last_state.p")

# Sampling runs reuse the adapted parameters. Set i_shift > 0 to number new run
# files after existing ones instead of overwriting them.
run_fn = jit(vmap(get_NUTSwG_run(model.logpdf, step_fn, init_fn, parameters, n_samples)))
i_shift = 0
last_run = i_shift + n_runs
for i_run in range(i_shift+1, last_run+1):
    print(f"run {i_run}/{last_run}")
    key, run_key = jr.split(key, 2)
    last_state, samples, infos = run_fn(jr.split(run_key, n_chains), last_state, parameters=parameters)
    jnp.savez(save_path+f"_{i_run}.npz", **samples | infos)
    pdump(last_state, save_path+f"_last_state.p")


  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)


Running window adaptation


  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)


Running window adaptation


  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)


 |████████████████████████████████████████| 100.00% [1024/1024 00:00<?]
 |████████████████████████████████████████| 100.00% [1024/1024 00:00<?]
{'mesh_': {'inverse_mass_matrix': Array([[0.91613275, 0.43298027, 0.58366376, ..., 1.3958235 , 1.4039643 ,
        1.4155319 ],
       [1.0042535 , 0.35944837, 0.5298701 , ..., 1.4606954 , 1.4201225 ,
        1.5059968 ],
       [0.8509194 , 0.40631598, 0.5384197 , ..., 1.5246091 , 1.3115985 ,
        1.6526046 ],
       ...,
       [1.0255859 , 0.3929412 , 0.54638743, ..., 1.4913795 , 1.7418612 ,
        1.3790766 ],
       [0.9393368 , 0.31887838, 0.5684601 , ..., 1.3481212 , 1.3279188 ,
        1.3940841 ],
       [1.0796475 , 0.3491805 , 0.5457695 , ..., 1.6703726 , 1.460268  ,
        1.3966581 ]], dtype=float32), 'step_size': Array([0.06817338, 0.08579013, 0.0819969 , 0.07307532, 0.06652877,
       0.07553854, 0.06986616, 0.08936442], dtype=float32, weak_type=True)}, 'rest_': {'inverse_mass_matrix': Array([[0.12300394, 0.01222097, 0.00059

  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)


 |████████████████████████████████████████| 100.00% [1024/1024 00:00<?]
run 2/2
 |----------------------------------------| 0.00% [0/1024 00:00<?]