# Tutorial: using the `picasso` trained predictors

This notebook shows how to use the trained models to predict gas thermodynamics from halo properties.

In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from astropy.table import Table
from astropy.cosmology import FlatLambdaCDM

from picasso import predictors
from picasso.utils.plots import NFW

import seaborn as sns
sns.set_theme(context="notebook", style="darkgrid")

We will use the `minimal_576` trained model, which takes as input halo mass and concentration:

In [2]:
predictor = predictors.minimal_576
print(predictor.input_names)

['log M200', 'c200']


## Predicting gas model parameters

First, we predict the gas model parameter vector, $\vartheta_{\rm gas}$.
All we need for this is the vector of scalar halo properties, $\vartheta_{\rm halo}$.
We use pre-stored data containing four halos from the simulations presented in Kéruzoré+24, and build the input vector with its columns in the order given by `predictor.input_names`:

In [3]:
halos = Table.read("../data/halos.hdf5")
logM200c = jnp.log10(halos["M200c"])
c200c = jnp.array(halos["c200c"])
theta_halo = jnp.array([logM200c, c200c]).T
print(theta_halo.shape)

(4, 2)
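
For halos other than the four stored ones, the same input matrix can be assembled with `jnp.stack(..., axis=-1)`, which is equivalent to the `jnp.array([...]).T` construction above. A minimal sketch using made-up masses and concentrations (assumed to be in the same units as the `M200c` column of the tutorial data); the column order follows `predictor.input_names`:

```python
import jax.numpy as jnp

# Hypothetical halo values, not taken from the tutorial data
M200c = jnp.array([1e14, 5e14, 2e15])
c200c = jnp.array([5.0, 4.2, 3.5])

# Columns ordered as predictor.input_names: ['log M200', 'c200']
theta_halo_new = jnp.stack([jnp.log10(M200c), c200c], axis=-1)
print(theta_halo_new.shape)  # (3, 2)
```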


We can then use the `predictor.predict_model_parameters()` function to predict $\vartheta_{\rm gas}$. For a single halo:

In [4]:
theta_gas_0 = predictor.predict_model_parameters(theta_halo[0])
print(theta_gas_0)

[3.2217349e+03 1.9191902e+02 1.1346726e+00 0.0000000e+00 3.5946820e-07
 1.1811828e-02 2.1121843e-01 1.6479002e+00]


The `predictor.predict_model_parameters()` function can also be used for several halos at a time:

In [5]:
theta_gas = predictor.predict_model_parameters(theta_halo)
print(theta_gas)
print(theta_gas.shape)

[[3.2217329e+03 1.9191902e+02 1.1346726e+00 0.0000000e+00 3.5946820e-07
  1.1811828e-02 2.1121846e-01 1.6479002e+00]
 [1.1516166e+03 6.4102837e+01 1.1361028e+00 0.0000000e+00 2.6019131e-07
  4.2736135e-02 3.2087338e-01 1.0452865e+00]
 [1.1092018e+03 6.1245892e+01 1.1370699e+00 0.0000000e+00 1.8584416e-07
  3.9753512e-02 3.3282009e-01 1.0353638e+00]
 [8.4998853e+02 4.8176407e+01 1.1426771e+00 0.0000000e+00 1.3825040e-07
  4.1372985e-02 3.5185230e-01 9.5307529e-01]]
(4, 8)


It can also be just-in-time compiled:

In [6]:
predict_jit = jax.jit(predictor.predict_model_parameters)
print("Not jitted:")
%timeit _ = predictor.predict_model_parameters(theta_halo)
print("jitted:")
%timeit _ = predict_jit(theta_halo)

Not jitted:
27 ms ± 179 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
jitted:
64.7 μs ± 25.1 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
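
Two JAX behaviours matter when reading such timings: a jitted function is traced and compiled on its first call, and JAX dispatches computations asynchronously, so a timer can stop before the result has actually been computed. A minimal benchmarking sketch that warms up the compiled function and waits on `block_until_ready()`, using a hypothetical toy function in place of the `picasso` predictor:

```python
import time

import jax
import jax.numpy as jnp


def toy_model(x):
    """Hypothetical stand-in for a prediction function: a chain of small
    elementwise operations, which jit can fuse into fewer kernels."""
    for _ in range(50):
        x = 0.9 * jnp.sin(x) + 0.1
    return x


def bench(f, x, n=20):
    """Mean wall-clock seconds per call, waiting until each result is ready."""
    t0 = time.perf_counter()
    for _ in range(n):
        f(x).block_until_ready()
    return (time.perf_counter() - t0) / n


x = jnp.linspace(0.0, 1.0, 10_000)
toy_jit = jax.jit(toy_model)

# The first call traces and compiles; keep it out of the timing
toy_jit(x).block_until_ready()

t_eager = bench(toy_model, x)
t_jit = bench(toy_jit, x)
print(f"eager: {t_eager * 1e3:.3f} ms, jitted: {t_jit * 1e3:.3f} ms")
```

In the notebook, the analogous approach would be to call `predict_jit(theta_halo)` once before timing and then run `%timeit predict_jit(theta_halo).block_until_ready()`, assuming the prediction is returned as a JAX array; for tuple outputs such as those of `predict_gas_model`, `jax.block_until_ready(...)` waits on every element.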


## Predicting gas thermodynamics

Given a prediction for $\vartheta_{\rm gas}$, we can use `picasso.polytrop` and `picasso.nonthermal` to predict gas thermodynamics (see the example gallery).
`PicassoPredictor` objects also offer a wrapper function that predicts all thermodynamic properties directly from an input vector $\vartheta_{\rm halo}$ and a gravitational potential profile.
Assuming the halos above follow NFW profiles, we can compute their potential profiles, shifted so that they are approximately zero at the halo center:

In [30]:
r_R500c = jnp.logspace(-1, 0.5, 51)  # radii from 0.1 to ~3.16 R500c
cosmo = FlatLambdaCDM(H0=70.0, Om0=0.3)

phi = []
for i in range(len(halos)):
    nfw_i = NFW(halos["M200c"][i], halos["c200c"][i], "200c", z=0.0, cosmo=cosmo)
    phi_i = nfw_i.potential(r_R500c * halos["R500c"][i])
    # Shift the potential so that it is ~0 at the halo center
    phi.append(phi_i - nfw_i.potential(1e-6))
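
Subtracting `nfw_i.potential(1e-6)` sets the zero point of each potential at (approximately) the halo center. Assuming `NFW.potential` follows the standard NFW form $\phi(r) = -4\pi G \rho_s r_s^3 \ln(1 + r/r_s)/r$, its central value is $\phi(0) = -4\pi G \rho_s r_s^2$, so the shifted profile is $\phi(r) - \phi(0) = 4\pi G \rho_s r_s^2 \left[1 - \ln(1+x)/x\right]$ with $x = r/r_s$. A minimal dimensionless sketch of that expression (the function name is hypothetical):

```python
import jax.numpy as jnp


def nfw_potential_rel_center(x):
    """Dimensionless NFW potential relative to the center,
    [phi(r) - phi(0)] / (4 pi G rho_s r_s^2), with x = r / r_s.

    Assumes phi(r) = -4 pi G rho_s r_s^3 ln(1 + x) / r, whose r -> 0
    limit is phi(0) = -4 pi G rho_s r_s^2.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    # Avoid 0/0 at x = 0, where the limit of log1p(x)/x is 1
    x_safe = jnp.where(x > 0, x, 1.0)
    ratio = jnp.where(x > 0, jnp.log1p(x_safe) / x_safe, 1.0)
    return 1.0 - ratio


x = jnp.logspace(-3, 2, 6)
print(nfw_potential_rel_center(x))
```

Under this assumption the shifted potential is zero at the center and rises monotonically toward $4\pi G \rho_s r_s^2$ at large radii.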

Then, we can make predictions of gas thermodynamics for one halo:

In [33]:
rho_g, P_tot, P_th, f_nt = predictor.predict_gas_model(theta_halo[0], phi[0], r_R500c, r_R500c / 2)

Or for all halos at the same time (this function uses `jax.vmap` to vectorize the predictions):

In [40]:
r_R500c_v = jnp.tile(r_R500c, (len(halos), 1))  # same radial grid for every halo, shape (4, 51)
phi_v = jnp.array(phi)
rho_g, P_tot, P_th, f_nt = predictor.predict_gas_model(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)
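
`predict_gas_model` handles the batching itself, but the same pattern applies to any single-halo function you write. A minimal sketch with a hypothetical toy profile function `toy_profile` (not part of `picasso`): `jax.vmap` maps it over the leading halo axis, and `in_axes=None` for the radius argument lets all halos share one radial grid instead of tiling it as `r_R500c_v` does above. Whether `predict_gas_model` itself accepts an unbatched radius array is not covered here.

```python
import jax
import jax.numpy as jnp


def toy_profile(theta_1, phi_1, r):
    """Hypothetical single-halo model: returns one value per radius."""
    log_m, c = theta_1[0], theta_1[1]
    return (log_m - 13.0) * jnp.exp(-phi_1) / (1.0 + c * r)


n_halos, n_r = 4, 51
theta = jnp.stack(
    [jnp.linspace(14.0, 15.0, n_halos), jnp.linspace(6.0, 3.0, n_halos)], axis=-1
)
r = jnp.logspace(-1, 0.5, n_r)
# Toy potentials, one row per halo: shape (4, 51)
phi = jnp.outer(jnp.linspace(0.5, 1.0, n_halos), jnp.log1p(r))

# Every argument batched along axis 0
out_a = jax.vmap(toy_profile)(theta, phi, jnp.tile(r, (n_halos, 1)))
# Radius argument shared across halos (not batched)
out_b = jax.vmap(toy_profile, in_axes=(0, 0, None))(theta, phi, r)
print(out_a.shape, jnp.allclose(out_a, out_b))  # (4, 51) True
```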


This function can also be just-in-time compiled:

In [41]:
predict_jit = jax.jit(predictor.predict_gas_model)

print("1 halo, not jitted:")
%timeit _ = predictor.predict_gas_model(theta_halo[0], phi[0], r_R500c, r_R500c / 2)
print("1 halo, jitted:")
%timeit _ = predict_jit(theta_halo[0], phi[0], r_R500c, r_R500c / 2)

print("4 halos, not jitted:")
%timeit _ = predictor.predict_gas_model(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)
print("4 halos, jitted:")
%timeit _ = predict_jit(theta_halo, phi_v, r_R500c_v, r_R500c_v / 2)

1 halo, not jitted:
30.3 ms ± 119 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
1 halo, jitted:
519 μs ± 91.8 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
4 halos, not jitted:
48.7 ms ± 708 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
4 halos, jitted:
168 μs ± 64.2 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
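
The four outputs are related. Assuming, as the names suggest, that `f_nt` is the non-thermal pressure fraction $P_{\rm nt}/P_{\rm tot}$, the thermal pressure is $P_{\rm th} = (1 - f_{\rm nt})\,P_{\rm tot}$. A minimal sketch with a hypothetical helper and toy values (not `picasso` outputs):

```python
import jax.numpy as jnp


def split_pressure(P_tot, f_nt):
    """Split total pressure into thermal and non-thermal parts,
    assuming f_nt = P_nt / P_tot (hypothetical helper)."""
    P_nt = f_nt * P_tot
    P_th = P_tot - P_nt  # equivalently (1 - f_nt) * P_tot
    return P_th, P_nt


# Toy profiles, not picasso outputs
P_tot_toy = jnp.array([10.0, 5.0, 2.0])
f_nt_toy = jnp.array([0.05, 0.1, 0.2])
P_th_toy, P_nt_toy = split_pressure(P_tot_toy, f_nt_toy)
print(P_th_toy)
```

On the outputs above, `jnp.allclose(P_th, (1 - f_nt) * P_tot)` offers a quick check of this assumption.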
