In [None]:
import baccoemu
import chainconsumer
import dynesty
import gc
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from multiprocessing import Pool, cpu_count

import emcee
import os
import pandas as pd
import time

import sys
sys.path.append('/dipc/kstoreyf/muchisimocks/scripts')
#import sbi_tools
import plot_utils
#import scripts
# from scripts import sbi_tools
#from scripts import plot_utils
import generate_emuPks as genP

from momentnetworks import demo

%load_ext autoreload
%autoreload 2

In [None]:
# Output location for figures, and the tag appended to saved figure names.
plot_dir = '../plots/plots_2024-02-19'
save_plots = True

tag_fit = '_cosmolib'

# plt.savefig does not create missing directories, so make sure the target
# exists up front whenever figures are going to be written.
if save_plots:
    os.makedirs(plot_dir, exist_ok=True)

In [None]:
# Report the core count and ask OpenMP for a single thread per process.
# NOTE(review): this runs after the import cell; libraries that read
# OMP_NUM_THREADS when they are imported may not pick it up -- confirm.
ncpu = cpu_count()
os.environ["OMP_NUM_THREADS"] = "1"
print(f"{ncpu} CPUs")

In [None]:
%matplotlib inline
mpl.pyplot.style.use('default')
mpl.pyplot.close('all')

font, rcnew = plot_utils.matplotlib_default_config()
mpl.rc('font', **font)
mpl.pyplot.rcParams.update(rcnew)
mpl.pyplot.style.use('tableau-colorblind10')
%config InlineBackend.figure_format = 'retina'
#N_threads = sbi_tools.set_N_threads(6)

mpl.rcParams['xtick.labelsize'] = 16 
mpl.rcParams['ytick.labelsize'] = 16 

### Load data

In [None]:
# Parameter columns as they are assumed to be ordered in each cosmo_<idx>.txt
# file, and the subset that is inferred; the remaining ones are held fixed.
param_names_all = ['omega_cold', 'sigma_8', 'h', 'omega_baryon', 'n_s', 'seed']
param_names = ['omega_cold', 'sigma_8', 'h']
param_names_fixed = [name for name in param_names_all if name not in param_names]
idxs_param_names = [param_names_all.index(name) for name in param_names]

#tag_pk = '_b1only'
tag_pk = '_b0000'
dir_pks = f'../data/pks_cosmolib/pks{tag_pk}'

n_lib = 500
dir_mocks = '../data/cosmolib'
theta = []
Pk = []
gaussian_error_pk = []
param_dict_fixed = {}
for idx_LH in range(n_lib):
    fn_params = f'{dir_mocks}/LH{idx_LH}/cosmo_{idx_LH}.txt'
    fn_pk = f'{dir_pks}/pk_{idx_LH}.npy'

    # Each file holds a pickled dict with keys 'k', 'pk', 'pk_gaussian_error'.
    # NOTE(review): allow_pickle=True can execute arbitrary code while
    # loading; only point this at trusted files.
    pk_obj = np.load(fn_pk, allow_pickle=True).item()
    Pk.append(pk_obj['pk'])
    gaussian_error_pk.append(pk_obj['pk_gaussian_error'])

    param_vals = np.loadtxt(fn_params)
    if idx_LH==0:
        # The fixed parameters are read once, from the first library entry.
        for name in param_names_fixed:
            param_dict_fixed[name] = param_vals[param_names_all.index(name)]
    theta.append(param_vals[idxs_param_names])

Pk = np.array(Pk)
theta = np.array(theta)
gaussian_error_pk = np.array(gaussian_error_pk)

kk = pk_obj['k'] # all ks should be same so just grab one

In [None]:
# Shapes of the loaded spectra, parameters and Gaussian errors.
print(*(arr.shape for arr in (Pk, theta, gaussian_error_pk)))

In [None]:
# Display the parameters that are held fixed (read from the first library entry).
param_dict_fixed

In [None]:
# Keep only the first n_samples library entries; with n_samples = n_lib this
# keeps everything that was loaded. (An earlier variant of this cell also
# dropped entries with sigma_8 < 0.73 before truncating; it is disabled.)
n_samples = n_lib
Pk, theta, gaussian_error_pk = (
    arr[:n_samples] for arr in (Pk, theta, gaussian_error_pk)
)

In [None]:
# Shapes after the subset selection.
print(*(arr.shape for arr in (Pk, theta)))

In [None]:
# Number of library entries and number of inferred parameters.
n_tot, n_params = theta.shape[:2]

Plot P(k) data:

In [None]:
# One log-log curve per library entry.
fig, ax = plt.subplots(figsize=(6, 4.5))
for pk_single in Pk[:n_tot]:
    ax.loglog(kk, pk_single)

ax.set_xlabel(r'$k \,\, [h \,\, {\rm Mpc}^{-1}]$', fontsize=23)
ax.set_ylabel(r'$P(k) \,\, [h^{-3} \,\, {\rm Mpc}^3]$', fontsize=23)

plt.tight_layout()
plt.show()

In [None]:
# No bias models are varied here; n_cosmos is the number of inferred parameters.
n_biasmodels, n_cosmos = 0, n_params
print(n_biasmodels, n_cosmos)

Split into train-val-test

In [None]:
# Contiguous train / validation / test split along the library axis.
p_train, p_test = 0.8, 0.1
p_val = 1-p_train-p_test
n_rows = theta.shape[0]
train_split = int(n_rows*p_train)
test_split = int(n_rows*(1-p_test))

theta_train, theta_val, theta_test = np.split(theta, [train_split, test_split])
print(theta_train.shape, theta_val.shape, theta_test.shape)

In [None]:
# add noise 
# Pk_train = Pk[:train_split]
# err_1p = 0.01*np.mean(Pk_train, axis=0)
# rng = np.random.default_rng()
# Pk += rng.normal(loc=0, scale=err_1p, size=Pk.shape)

In [None]:
# Same contiguous split for the spectra.
Pk_train, Pk_val, Pk_test = np.split(Pk, [train_split, test_split])

# Keep only the k bins where every training spectrum is strictly positive
# (the scaler below takes log10 of the spectra).
mask = np.all(Pk_train>0, axis=0)
Pk_train, Pk_val, Pk_test = (pk_set[:,mask] for pk_set in (Pk_train, Pk_val, Pk_test))
k = kk[mask]

# Split and mask the Gaussian errors in exactly the same way.
gaussian_error_pk_train, gaussian_error_pk_val, gaussian_error_pk_test = (
    err_set[:,mask]
    for err_set in np.split(gaussian_error_pk, [train_split, test_split])
)

In [None]:
# Number of k bins that survive the positivity mask.
n_dim = Pk_train.shape[-1]
print(n_tot, n_params, n_dim)

In [None]:
# Overlay the training and test spectra in log10-log10 space.
fig, ax = plt.subplots(1,1, figsize=(7,5))
fontsize = 24
fontsize1 = 18

alpha = 1

log_k = np.log10(k)
for pk_set, color, label in ((Pk_train, 'royalblue', 'training set'),
                             (Pk_test, 'k', 'test set')):
    # Draw the curves of each set in a random order.
    n_curves = pk_set.shape[0]
    tmp_Pk_plot = pk_set[np.random.choice(n_curves, n_curves, replace=False)].T
    ax.plot(log_k, np.log10(tmp_Pk_plot), c=color, alpha=alpha, lw=0.5, label=label)

ax.set_xlabel(r'$k \,\, [h \,\, {\rm Mpc}^{-1}]$', fontsize=23)
ax.set_ylabel(r'$P(k) \,\, [h^{-3} \,\, {\rm Mpc}^3]$', fontsize=23)

plt.tight_layout()
plt.show()

In [None]:
# [min, max] of every inferred parameter over the full library; these are
# used as the prior box by the samplers.
dict_bounds = {
    param_name: [np.min(theta[:,pp]), np.max(theta[:,pp])]
    for pp, param_name in enumerate(param_names)
}

In [None]:
class Scaler:
    """Min-max normalisation in log10 space.

    ``fit`` records the global minimum and maximum of the training data;
    ``scale`` then maps ``x`` to
    ``(log10(x) - log10(min)) / (log10(max) - log10(min))``.
    """

    def __init__(self):
        pass

    def fit(self, x_train):
        """Store the global min and max of ``x_train``.

        Raises:
            ValueError: if the minimum is not strictly positive (log10 would
                be undefined) or if min == max (the log range would be zero).
        """
        x_train_min = np.min(x_train)
        x_train_max = np.max(x_train)
        # `not (min > 0)` also rejects NaN.
        if not x_train_min > 0:
            raise ValueError("Scaler.fit requires strictly positive training data")
        if x_train_max == x_train_min:
            raise ValueError("Scaler.fit requires training data with min < max")
        self.x_train_min = x_train_min
        self.x_train_max = x_train_max

    def _log_range(self):
        # Width of the training range in log10 space.
        return np.log10(self.x_train_max) - np.log10(self.x_train_min)

    def scale(self, x):
        """Return ``x`` in log10 space, normalised by the training log range."""
        log_x = np.log10(x)
        return (log_x - np.log10(self.x_train_min)) / self._log_range()

    def unscale(self, x_scaled):
        """Inverse of ``scale``."""
        x = x_scaled * self._log_range() + np.log10(self.x_train_min)
        return 10**x

    def scale_error(self, err, x):
        """Propagate an error ``err`` on ``x`` through ``scale`` to first order.

        Returns ``|d(scale)/dx * err|`` with
        ``d(scale)/dx = 1 / (x * ln(10) * log_range)``.
        """
        # need 1/np.log(10) factor bc working in base 10
        dydx = 1./x * 1/np.log(10) * 1./self._log_range()
        # |dydx * err| equals sqrt(dydx**2 * err**2) without squaring first.
        return np.abs(dydx * err)

Observation: the scaled error comes out identical across samples. Is this by construction, because of the log scaling?

In [None]:
# Scratch arithmetic: product of two hard-coded numbers.
# NOTE(review): presumably a value and its matching dydx factor from
# Scaler.scale_error -- confirm, or delete this cell.
25458.80421554*3.46969057e-06

In [None]:
# Scratch arithmetic, second pair of hard-coded numbers; the trailing comma
# makes the cell value a 1-tuple.
12263.69676411*7.20289930e-06, 

In [None]:
# Quick look at the scaled errors of the first two training spectra.
scaler = Scaler()
scaler.fit(Pk_train)
scaler.scale_error(gaussian_error_pk_train[:2], Pk_train[:2])

In [None]:
# Fit the scaler on the training spectra only, then apply it to every split.
scaler = Scaler()
scaler.fit(Pk_train)
Pk_train_scaled, Pk_val_scaled, Pk_test_scaled = (
    scaler.scale(pk_set) for pk_set in (Pk_train, Pk_val, Pk_test)
)

# Propagate the Gaussian errors through the same transformation.
(gaussian_error_pk_train_scaled,
 gaussian_error_pk_val_scaled,
 gaussian_error_pk_test_scaled) = (
    scaler.scale_error(err_set, pk_set)
    for err_set, pk_set in ((gaussian_error_pk_train, Pk_train),
                            (gaussian_error_pk_val, Pk_val),
                            (gaussian_error_pk_test, Pk_test))
)

In [None]:
# Value ranges before and after scaling, for the training and test sets.
for pk_raw, pk_scaled in ((Pk_train, Pk_train_scaled), (Pk_test, Pk_test_scaled)):
    print(np.min(pk_raw), np.max(pk_raw))
    print(np.min(pk_scaled), np.max(pk_scaled))

In [None]:
# Training-set shapes and the number of inferred parameters, one per line.
for item in (Pk_train.shape, theta_train.shape, n_params):
    print(item)

### Set up and run Moment Network model

Following demos at https://github.com/NiallJeffrey/MomentNetworks/tree/master

In [None]:
# Network for the parameter means: n_dim inputs (k bins), n_params outputs.
# NOTE(review): `model_instance` is re-bound when the variance network is built below.
model_instance = demo.simple_leaky(n_dim, n_params, learning_rate=1e-4) 
regression = model_instance.model() 

In [None]:
# Label and input shapes for the training and validation sets.
for theta_set, pk_set in ((theta_train, Pk_train), (theta_val, Pk_val)):
    print(theta_set.shape, pk_set.shape)

Train initial model (basic MLP), as usual, on labeled data

In [None]:
# Train the mean network on the scaled spectra, validating on the validation split.
mean_fit_kwargs = dict(
    epochs=200, batch_size=32, shuffle=True,
    validation_data=(Pk_val_scaled, theta_val),
)
history = regression.fit(Pk_train_scaled, theta_train, **mean_fit_kwargs)

In [None]:
#predicted_mean = regression.predict(np.atleast_2d(Pk_train_scaled)) # maybe should be train & val??

Get means and residuals

In [None]:
# Mean-network predictions and residuals on the training and validation sets.
theta_train_pred = regression.predict(np.atleast_2d(Pk_train_scaled))
theta_val_pred = regression.predict(np.atleast_2d(Pk_val_scaled))
resid_train = theta_train - theta_train_pred
resid_val = theta_val - theta_val_pred

# cov_dict maps a parameter-index pair (i, j), in either order, to the column
# that holds the residual product for that pair; only the upper triangle
# (j >= i) is stored.
cov_dict = {}
training_var_unknown_mean = []
training_var_unknown_mean_val = []
count = 0
for i in range(n_params):
    for j in range(i, n_params):
        training_var_unknown_mean.append(resid_train[:,i]*resid_train[:,j])
        training_var_unknown_mean_val.append(resid_val[:,i]*resid_val[:,j])
        cov_dict[(i,j)] = cov_dict[(j,i)] = count
        count += 1

# Transpose so that rows are samples and columns are covariance entries.
training_var_unknown_mean = np.array(training_var_unknown_mean).T
training_var_unknown_mean_val = np.array(training_var_unknown_mean_val).T

print(training_var_unknown_mean.shape)
print(training_var_unknown_mean[0])

n_covs = training_var_unknown_mean.shape[1]

Set up and train model on the residuals

In [None]:
# Second network: n_dim inputs, one output per stored covariance entry (n_covs).
# This re-binds `model_instance`, which previously referred to the mean network's builder.
model_instance = demo.simple_leaky(n_dim, n_covs, learning_rate=1e-3)
regression_var_unknown_mean = model_instance.model()

In [None]:
# Train the second network on the residual products. `history` is re-bound
# here, replacing the history of the mean-network fit.
var_fit_kwargs = dict(
    epochs=200, batch_size=32, shuffle=True,
    validation_data=(Pk_val_scaled, training_var_unknown_mean_val),
)
history = regression_var_unknown_mean.fit(Pk_train_scaled,
                                          training_var_unknown_mean,
                                          **var_fit_kwargs)

### Set up MCMC

In [None]:
# Translation from this notebook's parameter names to the keyword names
# passed to the emulator.
param_names_2_emu_param_names = dict(
    sigma_8='sigma8_cold',
    omega_cold='omega_cold',
    h='hubble',
    n_s='ns',
)

In [None]:
def setup_cosmo_emu():
    """Return the emulator keyword dict for the parameters that stay fixed.

    omega_baryon and ns come from the module-level ``param_dict_fixed``; the
    remaining entries are constants. omega_cold, sigma8_cold and hubble are
    not set here.
    """
    print("Setting up emulator cosmology")
    cosmo_params = dict(
        omega_baryon=param_dict_fixed['omega_baryon'],
        ns=param_dict_fixed['n_s'],
        neutrino_mass=0.0,
        w0=-1.0,
        wa=0.0,
        expfactor=1,
    )
    return cosmo_params

In [None]:
# Load the bias-expansion emulator from a hard-coded local path and print the
# bounds it reports for the nonlinear emulator.
# NOTE(review): the absolute path ties this notebook to one machine.
#emu = baccoemu.Lbias_expansion(verbose=False)
emu = baccoemu.Lbias_expansion(nonlinear_emu_path='/dipc_storage/cosmosims/data_share/lbias_emulator/lbias_emulator2.0.0',
                                     nonlinear_emu_details='details.pickle',
                                     nonlinear_emu_field_name='NN_n',
                                     nonlinear_emu_read_rotation=False)
print(emu.emulator['nonlinear']['bounds'])
# Fixed-parameter dict; the inferred parameters are filled in per evaluation.
cosmo_params = setup_cosmo_emu()
# NOTE(review): this value is overwritten by the "Check emu" cell below.
bias_params = [1., 0., 0., 0.]

In [None]:
# emcee settings: steps discarded as burn-in, steps per walker, walker count.
n_burn = 40
n_steps = 200 # 50000
n_walkers = 4 * n_params

##### Check emu

In [None]:
# Overrides the bias_params set when the emulator was loaded; this is the
# value the plots and log_likelihood below pass to the emulator.
bias_params = [0.0, 0.0, 0.0, 0.0]

In [None]:
# Display the Gaussian error of the first library entry (unmasked k bins).
gaussian_error_pk[0]

In [None]:
# Compare the first five measured spectra with the emulator evaluated at
# their true parameters (unscaled P(k)); the lower panel shows the residual
# in units of the Gaussian error.
nrows, ncols = 2, 1
fig, axarr = plt.subplots(nrows, ncols, figsize=(6,6), sharex=True, height_ratios=[2,1])
plt.subplots_adjust(hspace=0)

colors = ['red', 'orange', 'green', 'blue', 'purple']

# Loop-invariant pieces: the extra error budget (10% of the training scatter)
# and the zero line of the residual panel.
err_extra = 0.1*np.std(Pk_train, axis=0)
axarr[1].axhline(0, color='grey', lw=0.5)

for i in range(5):
    # Fill the inferred parameters of entry i into the shared emulator dict.
    for pp in range(len(param_names)):
        emu_param_name = param_names_2_emu_param_names[param_names[pp]]
        cosmo_params[emu_param_name] = theta[i][pp]
    _, pk_model_unscaled, _ = emu.get_galaxy_real_pk(bias=bias_params, k=k,
                                                     **cosmo_params)

    # Only the first entry carries legend labels and error bars.
    if i==0:
        label_emu = 'emulated at true theta'
        axarr[0].errorbar(k, Pk_train[i], yerr=gaussian_error_pk_train[i],
                          ls='--', marker='o', markersize=6, alpha=0.5,
                          label='measured from map2map bias field', color=colors[i])
    else:
        label_emu = None
        axarr[0].plot(k, Pk_train[i], ls='--', marker='o', markersize=6, alpha=0.5,
                      color=colors[i])
    axarr[0].plot(k, pk_model_unscaled, ls='-', alpha=0.5, label=label_emu, color=colors[i])

    axarr[1].plot(k, (pk_model_unscaled-Pk_train[i])/gaussian_error_pk_train[i],
                  ls='-', alpha=0.5, color=colors[i])
    axarr[1].fill_between(k, -err_extra/gaussian_error_pk_train[i],
                          err_extra/gaussian_error_pk_train[i], color='grey', alpha=0.1)

plt.xscale('log')
axarr[0].set_yscale('log')

axarr[1].set_ylim(-5, 5)

axarr[0].legend(fontsize=12)
axarr[1].set_xlabel(r'$k \,\, [h \,\, {\rm Mpc}^{-1}]$', fontsize=18)
axarr[0].set_ylabel(r'$P(k) \,\, [h^{-3} \,\, {\rm Mpc}^3]$', fontsize=18)
axarr[1].set_ylabel(r'$(P_\text{emu}-P_\text{m2m})/\sigma_\text{G,m2m}$', fontsize=18)


In [None]:
# Same comparison as the unscaled check, but in the scaled (log10, min-max
# normalised) space that the likelihood works in.
nrows, ncols = 2, 1
fig, axarr = plt.subplots(nrows, ncols, figsize=(6,6), sharex=True, height_ratios=[2,1])
plt.subplots_adjust(hspace=0)

colors = ['red', 'orange', 'green', 'blue', 'purple']

# Loop-invariant pieces: the extra error budget (10% of the scaled training
# scatter) and the zero line of the residual panel.
err_extra = 0.1*np.std(Pk_train_scaled, axis=0)
axarr[1].axhline(0, color='grey', lw=0.5)

for i in range(5):
    # Fill the inferred parameters of entry i into the shared emulator dict.
    for pp in range(len(param_names)):
        emu_param_name = param_names_2_emu_param_names[param_names[pp]]
        cosmo_params[emu_param_name] = theta[i][pp]
    _, pk_model_unscaled, _ = emu.get_galaxy_real_pk(bias=bias_params, k=k,
                                                     **cosmo_params)
    pk_model_scaled = scaler.scale(pk_model_unscaled)

    # Only the first entry carries legend labels and error bars.
    if i==0:
        label_emu = 'emulated at true theta'
        axarr[0].errorbar(k, Pk_train_scaled[i], yerr=gaussian_error_pk_train_scaled[i],
                          ls='--', marker='o', markersize=6, alpha=0.5,
                          label='measured from map2map bias field', color=colors[i])
    else:
        label_emu = None
        axarr[0].plot(k, Pk_train_scaled[i], ls='--', marker='o', markersize=6, alpha=0.5,
                      color=colors[i])
    axarr[0].plot(k, pk_model_scaled, ls='-', alpha=0.5, label=label_emu, color=colors[i])

    axarr[1].plot(k, (pk_model_scaled-Pk_train_scaled[i])/gaussian_error_pk_train_scaled[i],
                  ls='-', alpha=0.5, color=colors[i])
    axarr[1].fill_between(k, -err_extra/gaussian_error_pk_train_scaled[i],
                          err_extra/gaussian_error_pk_train_scaled[i], color='grey', alpha=0.1)

plt.xscale('log')
#axarr[0].set_yscale('log')

axarr[1].set_ylim(-5, 5)

axarr[0].legend(fontsize=12)
axarr[1].set_xlabel(r'$k \,\, [h \,\, {\rm Mpc}^{-1}]$', fontsize=18)
# The top panel shows scaled values, so it must not carry physical P(k) units.
axarr[0].set_ylabel(r'scaled $\log_{10} P(k)$', fontsize=18)
axarr[1].set_ylabel(r'$(P_\text{emu}-P_\text{m2m})/\sigma_\text{G,m2m}$', fontsize=18)


In [None]:
# pk_data and cov_inv are ordinary module-level variables that log_likelihood
# looks up at call time; they are assigned in the sampler cells. A `global`
# statement at module level has no effect, so none is used here.

def log_prior(theta):
    """Flat prior over the box given by ``dict_bounds``.

    Returns 0.0 when every component of ``theta`` satisfies
    ``lower <= theta[pp] < upper`` for its parameter, and ``-np.inf`` otherwise.
    """
    for pp in range(len(param_names)):
        lower, upper = dict_bounds[param_names[pp]]
        if (theta[pp] < lower) or (theta[pp] >= upper):
            return -np.inf
    return 0.0

def log_likelihood(theta):
    """Gaussian log-likelihood (chi-square term only) of ``pk_data`` given ``theta``.

    The emulator is evaluated at ``theta`` (with the fixed parameters from
    ``cosmo_params``), mapped through ``scaler.scale`` and compared with the
    module-level ``pk_data`` using the module-level inverse covariance
    ``cov_inv``. Returns ``-0.5 * diff @ cov_inv @ diff``.
    """
    # Work on a copy so that evaluating the likelihood does not mutate the
    # shared module-level cosmo_params dict.
    params = dict(cosmo_params)
    for pp in range(len(param_names)):
        emu_param_name = param_names_2_emu_param_names[param_names[pp]]
        params[emu_param_name] = theta[pp]
    _, pk_model_unscaled, _ = emu.get_galaxy_real_pk(bias=bias_params, k=k,
                                                     **params)
    pk_model = scaler.scale(pk_model_unscaled)
    diff = pk_data-pk_model
    return -0.5*np.dot(diff,np.dot(cov_inv,diff))

def log_posterior(theta):
    """Log-prior plus log-likelihood; -inf when the prior excludes ``theta``.

    The likelihood is not evaluated for points outside the prior.
    """
    lp = log_prior(theta)
    return lp + log_likelihood(theta) if np.isfinite(lp) else -np.inf

### Test on a model pulled directly from the training set (NOT held-out data) 

In [None]:

# Pick one training-set entry (fixed index) and keep it as a batch of size one.
idx_train_check = 17

print(idx_train_check)
theta_train_check = theta_train[[idx_train_check]]
print(theta_train_check)
Pk_train_scaled_check = Pk_train_scaled[[idx_train_check]]

# Outputs of the mean network and of the second (residual-product) network
# for this spectrum.
predicted_mean_obs = regression.predict(np.atleast_2d(Pk_train_scaled_check))
predicted_var_obs = regression_var_unknown_mean.predict(np.atleast_2d(Pk_train_scaled_check))[0]

print(predicted_var_obs)
print(predicted_var_obs.shape)

# Expand the flat vector of stored entries into a full symmetric matrix.
moment_network_param_cov = np.array(
    [[predicted_var_obs[cov_dict[(i,j)]] for j in range(n_params)]
     for i in range(n_params)],
    dtype=np.float64,
)
print(moment_network_param_cov)

In [None]:
# Draw samples from the Gaussian defined by the predicted mean and covariance.
n_moment_draws = int(1e6)
moment_network_samples = np.random.multivariate_normal(
    predicted_mean_obs[0], moment_network_param_cov, n_moment_draws
).astype(np.float32)
gc.collect()

#### Dynesty nested sampling

In [None]:
def prior_transform(u):
    """Map a point ``u`` of the unit cube onto the box given by ``dict_bounds``.

    Component ``pp`` becomes ``(upper - lower) * u[pp] + lower`` for the
    bounds of ``param_names[pp]``.
    """
    bounds_in_order = (dict_bounds[param_name] for param_name in param_names)
    return np.array([
        (upper - lower)*u[pp] + lower
        for pp, (lower, upper) in enumerate(bounds_in_order)
    ])

In [None]:
# Display the prior box.
dict_bounds

In [None]:
# Total (diagonal) error in scaled space: Gaussian error of the chosen
# spectrum plus an extra 10% of the scaled training scatter, in quadrature.
err_gaussian = gaussian_error_pk_train_scaled[idx_train_check]
err_extra = 0.1*np.std(Pk_train_scaled, axis=0)
err = np.sqrt(err_gaussian**2 + err_extra**2)

# log_likelihood reads pk_data and cov_inv as module-level names, so both
# must be assigned before the sampler runs.
# NOTE(review): presumably the worker processes only see the values that are
# set before the pool is created -- confirm.
cov_inv = np.diag(1/err**2)
pk_data = Pk_train_scaled[idx_train_check]

n_threads = 8

with dynesty.pool.Pool(n_threads, log_likelihood, prior_transform) as pool:

    # NOTE(review): nlive=10 is a very small number of live points; the test-set run uses 20.
    sampler = dynesty.NestedSampler(pool.loglike, pool.prior_transform, n_params, 
                                    nlive=10, bound='single')
    sampler.run_nested(dlogz=0.01)



In [None]:
# Extract the samples from the dynesty run via results.samples_equal().
# NOTE(review): presumably these are equally weighted posterior samples -- confirm.
results = sampler.results
samples_dynesty = results.samples_equal()
print(samples_dynesty.shape)

In [None]:
# Display the dynesty samples (the notebook shows the full array repr).
samples_dynesty

In [None]:
# NOTE(review): this import would normally live in the import cell at the top.
from dynesty import plotting as dyplot

# Corner plot of the dynesty results, drawn into a pre-made n_params x n_params grid.
fig, axes = plt.subplots(n_params, n_params, figsize=(3, 3))
axes = axes.reshape((n_params, n_params)) 
fg, ax = dyplot.cornerplot(results, color='dodgerblue', #truths=np.zeros(n_params),
                           truth_color='black', show_titles=True,
                           quantiles=None, max_n_ticks=3,
                           fig=(fig, axes))

#### MCMC for comparison

In [None]:
# Seeded starting positions for the emcee walkers: one uniform draw per
# parameter inside its prior bounds, for each walker.
rng = np.random.default_rng(seed=42)
bounds_in_order = [dict_bounds[param_name] for param_name in param_names]
theta_0 = np.array([
    [rng.uniform(low=lower, high=upper) for lower, upper in bounds_in_order]
    for _ in range(n_walkers)
])

In [None]:

# Same observation as the dynesty train-check run. cov_inv is not set here;
# log_likelihood uses whatever value is currently in scope.
pk_data = Pk_train_scaled[idx_train_check]

n_threads = 8
start = time.time()
if n_threads <= 1:
    # Serial run.
    sampler_emcee = emcee.EnsembleSampler(n_walkers, n_params, log_posterior)
    _ = sampler_emcee.run_mcmc(theta_0, n_steps, progress=True)
else:
    # Parallel run with a multiprocessing pool.
    with Pool(processes=n_threads) as pool:
        sampler_emcee = emcee.EnsembleSampler(n_walkers, n_params, log_posterior, pool=pool)
        _ = sampler_emcee.run_mcmc(theta_0, n_steps, progress=True)
end = time.time()

print(f"Time: {end-start} s ({(end-start)/60} min)")


In [None]:
# Flattened emcee chain with the first n_burn steps discarded and no thinning.
samples_emcee = sampler_emcee.get_chain(discard=n_burn, flat=True,thin=1)
gc.collect()

### Plot contours

In [None]:
# LaTeX labels, keyed by both this notebook's parameter names and the
# emulator's keyword names.
param_label_dict = {}
for names, label in (
    (('omega_cold',), r'$\Omega_\mathrm{m}$'),
    (('sigma8_cold', 'sigma_8'), r'$\sigma_{8}$'),
    (('hubble', 'h'), r'$h$'),
    (('ns', 'n_s'), r'$n_\mathrm{s}$'),
    (('omega_baryon',), r'$\Omega_\mathrm{b}$'),
):
    for name in names:
        param_label_dict[name] = label

# Labels and prior extents in the order of the inferred parameters.
param_labels = list(map(param_label_dict.__getitem__, param_names))
extents = list(map(dict_bounds.__getitem__, param_names))

In [None]:
# Contours for the training-set check: moment-network Gaussian vs. dynesty samples.
c = chainconsumer.ChainConsumer()

# (samples, Chain keyword arguments) for every chain that is drawn. The emcee
# chain is currently left out; it was checked to match the result of
# chainconsumer.Chain.from_emcee once burn-in is removed.
chain_specs = [
    (moment_network_samples, dict(name='Moment Network', color='blue')),
    (samples_dynesty, dict(name='MCMC (Dynesty)', color='green',
                           smooth=1, bins=5)),
]
for chain_samples, chain_kwargs in chain_specs:
    c.add_chain(chainconsumer.Chain(
        samples=pd.DataFrame(chain_samples, columns=param_names),
        **chain_kwargs))

plot_config = chainconsumer.PlotConfig(
    flip=True,
    labels=param_label_dict,
    contour_label_font_size=12,
)
c.set_plot_config(plot_config)

# Mark the true parameters of the chosen training entry.
truth_loc = dict(zip(param_names, theta_train_check[0]))
c.add_truth(chainconsumer.Truth(location=truth_loc))

fig = c.plotter.plot(figsize = (5,4) )

if save_plots:
    plt.savefig(f'{plot_dir}/contours_traincheck{idx_train_check}{tag_fit}.png')

### Test on a model from the test set (held-out data)

In [None]:
# Outputs of both networks for one held-out spectrum.
idx_test = 0
pk_test_obs = np.atleast_2d(Pk_test_scaled[idx_test])
predicted_mean_obs_test = regression.predict(pk_test_obs)
predicted_var_obs_test = regression_var_unknown_mean.predict(pk_test_obs)[0]

# Expand the flat vector of stored entries into a full symmetric matrix.
moment_network_param_cov_test = np.array(
    [[predicted_var_obs_test[cov_dict[(i,j)]] for j in range(n_params)]
     for i in range(n_params)],
    dtype=np.float64,
)
print(moment_network_param_cov_test)

In [None]:
# Draw samples from the Gaussian predicted for the held-out spectrum.
n_moment_draws_test = int(1e6)
moment_network_samples_test = np.random.multivariate_normal(
    predicted_mean_obs_test[0], moment_network_param_cov_test, n_moment_draws_test
).astype(np.float32)
gc.collect()

In [None]:
# Build the error model from the held-out spectrum itself, so that cov_inv
# belongs to the same observation as pk_data. Without this, log_likelihood
# would silently use whatever cov_inv an earlier cell left behind (the one
# built for the training-set check).
err_gaussian_test = gaussian_error_pk_test_scaled[idx_test]
err_extra = 0.1*np.std(Pk_train_scaled, axis=0)
err_test = np.sqrt(err_gaussian_test**2 + err_extra**2)

# log_likelihood reads pk_data and cov_inv as module-level names, so both
# must be assigned before the sampler runs.
cov_inv = np.diag(1/err_test**2)
pk_data = Pk_test_scaled[idx_test]

n_threads = 8

with dynesty.pool.Pool(n_threads, log_likelihood, prior_transform) as pool:

    sampler_test = dynesty.NestedSampler(pool.loglike, pool.prior_transform, n_params,
                                         nlive=20, bound='single')
    sampler_test.run_nested(dlogz=0.01)

In [None]:
# Extract the samples from the held-out dynesty run via results.samples_equal().
results_test = sampler_test.results
samples_dynesty_test = results_test.samples_equal()
print(samples_dynesty_test.shape)

In [None]:
# Contours for the held-out spectrum: moment-network Gaussian vs. dynesty samples.
c = chainconsumer.ChainConsumer()

# (samples, Chain keyword arguments) for every chain that is drawn; the emcee
# chain is currently left out.
chain_specs_test = [
    (moment_network_samples_test, dict(name='Moment Network', color='blue')),
    (samples_dynesty_test, dict(name='MCMC (Dynesty)', color='green',
                                smooth=2, bins=5)),
]
for chain_samples, chain_kwargs in chain_specs_test:
    c.add_chain(chainconsumer.Chain(
        samples=pd.DataFrame(chain_samples, columns=param_names),
        **chain_kwargs))

plot_config_test = chainconsumer.PlotConfig(
    flip=True,
    labels=param_label_dict,
    contour_label_font_size=12,
)
c.set_plot_config(plot_config_test)

# Mark the true parameters of the held-out entry.
truth_loc = dict(zip(param_names, theta_test[idx_test]))
c.add_truth(chainconsumer.Truth(location=truth_loc))

fig = c.plotter.plot(figsize = (5,4) )
if save_plots:
    plt.savefig(f'{plot_dir}/contours_test{idx_test}{tag_fit}.png')