### Imports

In [None]:
%load_ext autoreload
%autoreload 2

import os
os.environ['OPENBLAS_NUM_THREADS'] = '16'  
from os.path import join
import numpy as np
from tqdm import tqdm

import matplotlib as mpl
import matplotlib.pyplot as plt

from omegaconf import OmegaConf
import h5py

### Load training and test data

In [None]:
def get_cosmo(source_path):
    """Read the cosmology entry from a simulation's ``config.yaml``.

    Parameters
    ----------
    source_path : str
        Directory that contains ``config.yaml``.

    Returns
    -------
    numpy.ndarray
        The ``nbody.cosmo`` entry of the config, converted with ``np.array``.
    """
    config_file = join(source_path, 'config.yaml')
    nbody_cfg = OmegaConf.load(config_file).nbody
    return np.array(nbody_cfg.cosmo)

def get_halo_Pk(source_path):
    """Load the halo power spectrum stored under a simulation directory.

    Reads ``<source_path>/diag/halos.h5`` and takes the datasets ``Pk_k`` and
    ``Pk`` from the last top-level group (in the order ``h5py`` lists keys).

    Parameters
    ----------
    source_path : str
        Simulation directory holding ``diag/halos.h5`` and ``config.yaml``.

    Returns
    -------
    tuple
        ``(k, Pk, cosmo)`` where ``cosmo`` comes from ``get_cosmo``; or
        ``(None, None, None)`` when the diagnostics file does not exist.
    """
    halo_file = join(source_path, 'diag', 'halos.h5')
    if os.path.exists(halo_file):
        with h5py.File(halo_file, 'r') as fh:
            last_key = list(fh.keys())[-1]
            group = fh[last_key]
            k = group['Pk_k'][:]
            Pk = group['Pk'][:]
        # The config is only read once the spectrum is known to exist.
        return k, Pk, get_cosmo(source_path)
    return None, None, None

In [None]:
# Load pre-processed data
suite, sim = 'quijote', 'nbody'
L, N = 1000, 128

from cmass.utils import get_source_path

# NOTE(review): `wdir` is not defined in any cell shown here -- confirm it is
# set before this cell runs.
# The last two characters of the path for index 0 are dropped to obtain the
# directory that holds all the per-simulation folders.
suite_path = get_source_path(wdir, suite, sim, L, N, 0)[:-2]

ktrain, Pktrain, cosmotrain = [], [], []
# for lhid in tqdm(os.listdir(suite_path)):  # loop over every simulation
for lhid in tqdm(['0', '1', '2', '3']):
    k, Pk, cosmo = get_halo_Pk(join(suite_path, lhid))
    ktrain += [k]
    Pktrain += [Pk]
    cosmotrain += [cosmo]

# Discard the entries of simulations without diagnostics (None), then stack
# what is left into arrays with the simulation index first.
ktrain = np.stack([entry for entry in ktrain if entry is not None])
Pktrain = np.stack([entry for entry in Pktrain if entry is not None])
cosmotrain = np.stack([entry for entry in cosmotrain if entry is not None])
print(ktrain.shape, Pktrain.shape, cosmotrain.shape)

# Split into training and testing
# FIXME: Change testing set to be a subset of the total data
# For now the "test" arrays are plain copies of the training arrays.
ktest = np.stack([entry for entry in ktrain if entry is not None])
Pktest = np.stack([entry for entry in Pktrain if entry is not None])
cosmotest = np.stack([entry for entry in cosmotrain if entry is not None])
print(ktest.shape, Pktest.shape, cosmotest.shape)

In [None]:
# Index of the test spectrum drawn in black on top of the training curves.
l = 0

f, ax = plt.subplots()
# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; the pyplot accessor
# is available across versions.
cmap = plt.get_cmap('viridis')

# First cosmological parameter of every training simulation
# (presumably Omega_m, going by the variable name -- confirm).
Om = cosmotrain[:, 0]

# Rescale to [0, 1] once, outside the loop; fall back to zeros when all
# values coincide so that no division by zero occurs.
Om_span = Om.max() - Om.min()
omnorm = (Om - Om.min()) / Om_span if Om_span > 0 else np.zeros_like(Om)

# One faint curve per training simulation, coloured by the rescaled parameter.
for k, Pk, cosmo in zip(ktrain, Pktrain[..., 0], omnorm):
    ax.loglog(k, Pk, lw=1, alpha=0.1, color=cmap(cosmo))
ax.loglog(ktest[l], Pktest[l, ..., 0], 'k')

### Train

In [None]:
import ili
from ili.dataloaders import NumpyLoader
from ili.inference import InferenceRunner
from ili.validation.metrics import PosteriorCoverage, PlotSinglePosterior
from ili.embedding import FCN
device = 'cpu'


def run_inference(x, theta):
    """Train an NPE posterior with the ``lampe`` backend of ``ili``.

    Parameters
    ----------
    x : numpy.ndarray
        Data vectors handed to ``NumpyLoader`` as ``x``.
    theta : numpy.ndarray
        Parameters handed to ``NumpyLoader`` as ``theta``; their minimum and
        maximum along axis 0 also set the bounds of the uniform prior.

    Returns
    -------
    tuple
        ``(posterior_ensemble, summaries)`` as returned by the runner.
    """
    # Uniform prior bounded by the range of the supplied parameters.
    lower = theta.min(axis=0)
    upper = theta.max(axis=0)
    prior = ili.utils.Uniform(low=lower, high=upper, device=device)

    # Fully connected embedding applied before the density estimator.
    embedding = FCN(n_hidden=[64, 32, 16], act_fn='ReLU')

    # The ensemble holds a single masked autoregressive flow.
    maf = ili.utils.load_nde_lampe(
        model='maf', hidden_features=50, num_transforms=5,
        embedding_net=embedding)
    nets = [maf]

    # Optimiser settings passed through to the runner.
    train_args = dict(training_batch_size=32, learning_rate=1e-4)

    runner = InferenceRunner.load(
        backend='lampe',
        engine='NPE',
        prior=prior,
        nets=nets,
        device=device,
        train_args=train_args,
    )

    # Wrap the arrays for the runner and train.
    data = NumpyLoader(x=x, theta=theta)
    posterior_ensemble, summaries = runner(loader=data)
    return posterior_ensemble, summaries

In [None]:
# Maximum wavenumbers to train at: one posterior is trained per entry, using
# only the k-bins below that value. The commented list is a wider sweep.
kmaxs = [0.2] # [0.1, 0.2, 0.4, 0.6]

# impute with mean
def impute(arr):
    """Return a copy of ``arr`` with every NaN replaced by the mean of its row.

    The mean is taken along axis 1 while ignoring NaNs, so each NaN receives
    the average of the finite entries that share its first index. The input
    is left untouched; earlier versions overwrote it in place.

    Parameters
    ----------
    arr : array_like
        Two-dimensional array (rows indexed first) that may contain NaNs.

    Returns
    -------
    numpy.ndarray
        New array of the same shape with the NaNs filled in. A row made up
        entirely of NaNs has no finite mean and therefore stays NaN.
    """
    # np.array copies by default, so the caller's data is never modified.
    filled = np.array(arr)

    # Mean of each row, skipping NaN entries.
    row_mean = np.nanmean(filled, axis=1)

    # Index tuple of the NaN positions; its first element holds row indices.
    nan_idx = np.where(np.isnan(filled))

    # Give every NaN the mean of the row it sits in.
    filled[nan_idx] = np.take(row_mean, nan_idx[0])
    return filled

In [None]:
# Trained posteriors and training summaries, keyed by the kmax they used.
posteriors, summaries = {}, {}
for kmax in kmaxs:
    print(f'Training for kmax={kmax}')

    theta = cosmotrain

    # Keep the k-bins below kmax; the k-grid of the first simulation is
    # used to build the mask for all of them.
    mask = ktrain[0] < kmax

    # Monopole only (index 0 of the last axis), restricted to the selected
    # bins, in log10, with NaNs filled by `impute`.
    monopole = Pktrain[:, :, 0]
    x = impute(np.log10(monopole[:, mask]))

    # FIXME: figure out why it breaks here
    _p, _s = run_inference(x, theta)
    posteriors[kmax], summaries[kmax] = _p, _s


### Plot

In [None]:
# metric = PlotSinglePosterior(
#     num_samples=1000, sample_method='direct',
#     labels=['Omega_m', 'Omega_b', 'h', 'n_s', 'sigma8']
# )
# fig = None
# for i, kmax in enumerate(kmaxs):
#     print(kmax)
#     xobs = Pktest[:, :, 0]
#     mask = ktest[0] < kmax
#     xobs = xobs[:, mask]
#     xobs = np.log10(xobs)
#     xobs = impute(xobs)
#     _p = posteriors[kmax]
#     fig = metric(
#         posterior=_p,
#         x_obs = xobs[0], # , theta_fid = thetaobs[0],
#         grid=fig,
#         name=f'kmax={kmax}'
#     )

In [None]:
# Build the test data vectors with the same steps as the training loop:
# monopole, k-bins below kmax, log10, then NaN imputation.
kmax = 0.2
mask = ktest[0] < kmax
xobs = impute(np.log10(Pktest[:, :, 0][:, mask]))
# Display the first test data vector as the cell output.
xobs[0]

In [None]:
# Posterior for the first test data vector, with its true parameters marked.
metric = PlotSinglePosterior(
    num_samples=1000, sample_method='direct',
    labels=['Omega_m', 'Omega_b', 'h', 'n_s', 'sigma8']
)
fig = metric(
    # Look the posterior up by kmax instead of relying on the `_p` left over
    # from the training loop, which is whichever model was trained last and
    # need not match the k-cut applied to `xobs`.
    posterior=posteriors[kmax],
    x_obs=xobs[0], theta_fid=cosmotest[0],
    grid=None,
    name=f'kmax={kmax}'
)

In [None]:
# Coverage diagnostics over the whole test set.
metric = PosteriorCoverage(
    num_samples=1000, sample_method='direct',
    labels=['Omega_m', 'Omega_b', 'h', 'n_s', 'sigma8'],
    plot_list=["coverage", "histogram", "predictions", "tarp", "logprob"],
    out_dir=None
)
# Use the posterior trained for this kmax rather than the `_p` left over from
# the training loop, so the model always matches the k-cut applied to `xobs`.
# The trailing semicolon suppresses the notebook's echo of the return value.
metric(posterior=posteriors[kmax], x=xobs, theta=cosmotest);