# In [None]:

### This tutorial notebook shows how to use 21cmKAN to fit a global 21 cm signal and constrain physical parameters that describe it

### The first section of this notebook employs a trained instance of 21cmKAN in PyMultiNest parameter inference analyses
### to fit mock global 21 cm signals generated by the 21cmGEM model, with added observational noise.

### The second section prints and plots the results of the parameter inference analyses and reproduces Figures 5, C1, and C2
### of Dorigo Jones et al. 2025 (hereafter referred to as DJ+25; see Section 3.3 for further details)

import os
import corner
import random
import json
import numpy as np
import pymultinest
import matplotlib.patches as mpatches
import matplotlib.patheffects as pe
import matplotlib.lines as mlines
import Global21cmKAN as Global21cmKAN
from Global21cmKAN.evaluate import evaluate_on_21cmGEM
from scripts.wrapper import predict_21cmGEM
from matplotlib import gridspec, rc
from matplotlib import pyplot as plt
from pymultinest.solve import solve
from pylinex import GaussianLoglikelihood
# Create the directory that receives the PyMultiNest chain files and the saved
# sample tables. An already-existing directory is accepted; any other failure
# (e.g. no write permission) is raised here rather than being silently ignored.
os.makedirs('multinest_chains_21cmKAN_21cmGEM_7param', exist_ok=True)

# Instantiate the 21cmKAN emulator and load the network trained on the 21cmGEM set,
# which is the network used to fit the mock 21cmGEM signals below.
# This is the same network used in Section 3.3 of DJ+25 for the Bayesian inference analyses;
# pass the kwarg "model_path" to load a different trained instance of 21cmKAN.
emulator_21cmGEM = Global21cmKAN.emulate_21cmGEM.Emulate()
emulator_21cmGEM.load_model()

# Frequency grid, redshift grid and parameter names exposed by the emulator
nu_list, z_list, params_list = (
    emulator_21cmGEM.frequencies,
    emulator_21cmGEM.redshifts,
    emulator_21cmGEM.par_labels,
)
# see Table 1 of DJ+25 for a description of each parameter
print('21cmGEM input physical parameter names:', params_list)

# Copies of the (unnormalized) 21cmGEM test-set parameters and signals held by the emulator
params_21cmGEM_test = emulator_21cmGEM.par_test.copy()
signals_21cmGEM_test = emulator_21cmGEM.signal_test.copy()

# Test-set rows of the three signals fit in DJ+24 and DJ+25 (originally selected at random from the test set)
row_5mK, row_10mK, row_25mK = 9, 742, 998

# "True" physical parameter values of each mock signal
par_21cmGEM_5mK, par_21cmGEM_10mK, par_21cmGEM_25mK = (
    np.array(params_21cmGEM_test[row]) for row in (row_5mK, row_10mK, row_25mK)
)
for announcement, true_params in (
    ('"true" physical parameter values of mock 21cmGEM signal being fit with 5 mK added noise:', par_21cmGEM_5mK),
    ('"true" physical parameter values of mock 21cmGEM signal being fit with 10 mK added noise:', par_21cmGEM_10mK),
    ('"true" physical parameter values of mock 21cmGEM signal being fit with 25 mK added noise:', par_21cmGEM_25mK),
):
    print(announcement, true_params)

# dT_b values of the same three test-set signals
T_21cmGEM_5mK, T_21cmGEM_10mK, T_21cmGEM_25mK = (
    signals_21cmGEM_test[row] for row in (row_5mK, row_10mK, row_25mK)
)

# 21cmKAN emulation of each signal, evaluated at the "true" parameter values
T_KAN_5mK = emulator_21cmGEM.predict(par_21cmGEM_5mK)
T_KAN_10mK = emulator_21cmGEM.predict(par_21cmGEM_10mK)
T_KAN_25mK = emulator_21cmGEM.predict(par_21cmGEM_25mK)

# Add Gaussian-distributed white noise (in mK) to each true mock signal being fit.
# The seed is fixed and a single standard-normal draw is scaled by each noise level,
# so every fitting analysis sees the same error realization.
np.random.seed(2)
random_array = np.random.normal(size=len(T_21cmGEM_5mK))

noise_level_5mK, noise_level_10mK, noise_level_25mK = 5, 10, 25

# noiseless copies of the signals being fit
true_signal_no_noise_5mK = T_21cmGEM_5mK.copy()
true_signal_no_noise_10mK = T_21cmGEM_10mK.copy()
true_signal_no_noise_25mK = T_21cmGEM_25mK.copy()

# flat (frequency-independent) 1-sigma error arrays
signal_flat_error_5mK = np.full(len(T_21cmGEM_5mK), float(noise_level_5mK))
signal_flat_error_10mK = np.full(len(T_21cmGEM_10mK), float(noise_level_10mK))
signal_flat_error_25mK = np.full(len(T_21cmGEM_25mK), float(noise_level_25mK))

# noise realization at each level: the shared standard-normal draw times the flat error
gaussian_error_5mK = signal_flat_error_5mK*random_array
gaussian_error_10mK = signal_flat_error_10mK*random_array
gaussian_error_25mK = signal_flat_error_25mK*random_array

# noisy signals that are actually handed to the sampler
input_true_signal_5mK = gaussian_error_5mK+true_signal_no_noise_5mK
input_true_signal_10mK = gaussian_error_10mK+true_signal_no_noise_10mK
input_true_signal_25mK = gaussian_error_25mK+true_signal_no_noise_25mK

# Gaussian log-likelihoods (pylinex class) evaluated during the PyMultiNest analyses.
# The model is the provided wrapper around the 21cmKAN network trained on 21cmGEM; it only
# needs some parameter vector to be initialized, so the 5 mK "true" parameters are used.
# The network loaded by the wrapper is the same one used in Section 3.3 of DJ+25;
# use the kwarg "model_path" to load a different trained instance of 21cmKAN.
model_signal = predict_21cmGEM(par_21cmGEM_5mK)

# data vector and 1-sigma error vector for each noise level
data_signal_5mK, error_signal_5mK = input_true_signal_5mK, signal_flat_error_5mK
data_signal_10mK, error_signal_10mK = input_true_signal_10mK, signal_flat_error_10mK
data_signal_25mK, error_signal_25mK = input_true_signal_25mK, signal_flat_error_25mK

# one likelihood per noise level, all sharing the same model object
gaussian_loglikelihood_signal_5mK = GaussianLoglikelihood(
    data_signal_5mK, error_signal_5mK, model_signal)
gaussian_loglikelihood_signal_10mK = GaussianLoglikelihood(
    data_signal_10mK, error_signal_10mK, model_signal)
gaussian_loglikelihood_signal_25mK = GaussianLoglikelihood(
    data_signal_25mK, error_signal_25mK, model_signal)

# Prior transform for MultiNest: maps a point of the unit hypercube onto the physical parameters
def priors(cube):
    """Transform a unit-hypercube sample into the seven physical parameters.

    Each entry of ``cube`` is scaled by a per-parameter width and shifted by a
    per-parameter lower bound. The first three entries are treated as base-10
    exponents (i.e. those parameters are uniform in log10) and are converted
    to linear values; the remaining four are uniform in linear space.

    Parameters
    ----------
    cube : array-like of length 7
        Coordinates in [0, 1] supplied by the sampler.

    Returns
    -------
    numpy.ndarray
        The corresponding physical parameter values.
    """
    widths = np.array([np.log10(5e-1)+4, 2-np.log10(4.2), 9, 0.16, 0.5, 2.9, 40])
    lower_bounds = np.array([-4, np.log10(4.2), -6, 0.04, 1, 0.1, 10])
    physical = widths*cube + lower_bounds
    # undo the log10 on the three log-uniform parameters
    for log_index in range(3):
        physical[log_index] = 10**physical[log_index]
    return physical

# Run the three PyMultiNest analyses with the hyperparameters described in Section 3.3.1 of DJ+25
n_params = len(params_list)

# basename of the chain files written for each noise level
prefix_5mK = "multinest_chains_21cmKAN_21cmGEM_7param/KAN_21cmGEM_7param_nlive1200_tolerance0.1_sampeff0.8_5mKnoise-"
prefix_10mK = "multinest_chains_21cmKAN_21cmGEM_7param/KAN_21cmGEM_7param_nlive1200_tolerance0.1_sampeff0.8_10mKnoise-"
prefix_25mK = "multinest_chains_21cmKAN_21cmGEM_7param/KAN_21cmGEM_7param_nlive1200_tolerance0.1_sampeff0.8_25mKnoise-"

# sampler settings shared by all three runs
kwargs = dict(
    n_live_points=1200,
    evidence_tolerance=0.1,
    sampling_efficiency=0.8,
    n_iter_before_update=10,
)

result_5mK = solve(
    LogLikelihood=gaussian_loglikelihood_signal_5mK,
    Prior=priors,
    n_dims=n_params,
    outputfiles_basename=prefix_5mK,
    verbose=True,
    **kwargs,
)

result_10mK = solve(
    LogLikelihood=gaussian_loglikelihood_signal_10mK,
    Prior=priors,
    n_dims=n_params,
    outputfiles_basename=prefix_10mK,
    verbose=True,
    **kwargs,
)

result_25mK = solve(
    LogLikelihood=gaussian_loglikelihood_signal_25mK,
    Prior=priors,
    n_dims=n_params,
    outputfiles_basename=prefix_25mK,
    verbose=True,
    **kwargs,
)

#########################################################################################################################
#########################################################################################################################
# Second section: print and plot the results of the mock signal fits performed above.
# The posterior signal realizations plotted further below reproduce Figure 5 in DJ+25.

# Predictor used to emulate signals for posterior samples; the network it loads is the same one
# used in Section 3.3 of DJ+25 (use the kwarg "model_path" to load a different trained instance of 21cmKAN)
predictor = evaluate_on_21cmGEM()

# log-evidence of each fit
print()
for evidence_template, fit_result in (
    ('final evidence for 5 mK noise fit: %(logZ).1f +- %(logZerr).1f', result_5mK),
    ('final evidence for 10 mK noise fit: %(logZ).1f +- %(logZerr).1f', result_10mK),
    ('final evidence for 25 mK noise fit: %(logZ).1f +- %(logZerr).1f', result_25mK),
):
    print(evidence_template % fit_result)

# mean and standard deviation of each parameter over the samples returned by solve()
for header, fit_result in (
    ('parameter posterior mean values for 5 mK noise fit:', result_5mK),
    ('parameter posterior mean values for 10 mK noise fit:', result_10mK),
    ('parameter posterior mean values for 25 mK noise fit:', result_25mK),
):
    print()
    print(header)
    for name, col in zip(params_list, fit_result['samples'].transpose()):
        print('%15s : %.3f +- %.3f' % (name, col.mean(), col.std()))

# Store the parameter names next to each set of chain files (as "<prefix>params.json"),
# so marginal plots can be made by running multinest_marginals.py on the corresponding prefix
for chain_prefix in (prefix_5mK, prefix_10mK, prefix_25mK):
    with open('%sparams.json' % chain_prefix, 'w') as params_file:
        json.dump(params_list, params_file, indent=2)

# For each fit: open the chain files, take column 0 of the analyzer data as the sample weights
# and columns 2 onward as the parameter samples (column 1 is not used here), then write the
# samples to a text file and read them back from it.

# 5 mK fit
samples_path_5mK = 'multinest_chains_21cmKAN_21cmGEM_7param/samples_5mK_21cmGEM.txt'
nest_analyzer_5mK = pymultinest.Analyzer(n_params, outputfiles_basename=prefix_5mK)
data_5mK = nest_analyzer_5mK.get_data()
weights_5mK, samples_5mK = data_5mK[:, 0], data_5mK[:, 2:]
np.savetxt(samples_path_5mK, samples_5mK)
samples_5mK = np.loadtxt(samples_path_5mK)

# 10 mK fit
samples_path_10mK = 'multinest_chains_21cmKAN_21cmGEM_7param/samples_10mK_21cmGEM.txt'
nest_analyzer_10mK = pymultinest.Analyzer(n_params, outputfiles_basename=prefix_10mK)
data_10mK = nest_analyzer_10mK.get_data()
weights_10mK, samples_10mK = data_10mK[:, 0], data_10mK[:, 2:]
np.savetxt(samples_path_10mK, samples_10mK)
samples_10mK = np.loadtxt(samples_path_10mK)

# 25 mK fit
samples_path_25mK = 'multinest_chains_21cmKAN_21cmGEM_7param/samples_25mK_21cmGEM.txt'
nest_analyzer_25mK = pymultinest.Analyzer(n_params, outputfiles_basename=prefix_25mK)
data_25mK = nest_analyzer_25mK.get_data()
weights_25mK, samples_25mK = data_25mK[:, 0], data_25mK[:, 2:]
np.savetxt(samples_path_25mK, samples_25mK)
samples_25mK = np.loadtxt(samples_path_25mK)

# Residual between each noiseless true signal and its 21cmKAN emulation at the true parameters
residual_fiducial_5mK = true_signal_no_noise_5mK - T_KAN_5mK
residual_fiducial_10mK = true_signal_no_noise_10mK - T_KAN_10mK
residual_fiducial_25mK = true_signal_no_noise_25mK - T_KAN_25mK

# Root-mean-square of each residual over all frequency channels
rmse_fiducial_5mK = np.sqrt(np.mean(np.abs(np.square(residual_fiducial_5mK))))
rmse_fiducial_10mK = np.sqrt(np.mean(np.abs(np.square(residual_fiducial_10mK))))
rmse_fiducial_25mK = np.sqrt(np.mean(np.abs(np.square(residual_fiducial_25mK))))

for caption, rmse_value in (
    ('rms error between true signal true and 21cmKAN emulated realization for 5 mK noise:', rmse_fiducial_5mK),
    ('rms error between true signal true and 21cmKAN emulated realization for 10 mK noise:', rmse_fiducial_10mK),
    ('rms error between true signal true and 21cmKAN emulated realization for 25 mK noise:', rmse_fiducial_25mK),
):
    print(caption, rmse_value, 'mK')

# Reference frequency shared by the two conversions below; they are passed as the
# forward/inverse pair of the secondary (redshift) axis drawn above the MHz frequency axis.
vr = 1420.405751


def freq(zs):
    """Return the frequency corresponding to redshift ``zs``, i.e. ``vr / (1 + zs)``."""
    return vr / (1 + zs)


def redshift(v):
    """Return the redshift corresponding to frequency ``v``, i.e. ``vr / v - 1`` (inverse of ``freq``)."""
    return vr / v - 1

# Figure layout: a 24 x 8 inch canvas with a 2 x 3 grid and no gaps between panels.
# The top row (3/4 of the height) shows the signals, the bottom row the residuals;
# columns are the 5, 10 and 25 mK fits.
plt.rcParams['mathtext.fontset'] = 'cm'
fig = plt.figure(figsize=(24, 8))
spec = gridspec.GridSpec(
    nrows=2, ncols=3,
    height_ratios=[3, 1], width_ratios=[1, 1, 1],
    hspace=0, wspace=0,
)
# ax1-ax3: top row, ax4-ax6: bottom row (row-major grid order)
ax1, ax2, ax3, ax4, ax5, ax6 = [fig.add_subplot(spec[cell]) for cell in range(6)]

def _draw_signal_panel(ax, true_signal, flat_error, title, ytick_labels, ytick_fontsize,
                       z_tick_labels, ylabel=None):
    """Draw the noise band and apply the shared styling of one top-row signal panel.

    Shades ``true_signal`` +/- ``flat_error`` against ``nu_list``, sets the y ticks,
    limits and title, and attaches a redshift axis on top of the frequency axis.
    Returns the secondary (redshift) axis.
    """
    ax.fill_between(nu_list, true_signal-flat_error, true_signal+flat_error,
                    facecolor='b', edgecolor="none", alpha=0.3)
    ax.set_yticks([50, 0,-50,-100,-150,-200,-250,-300])
    ax.set_yticklabels(ytick_labels, fontsize=ytick_fontsize, fontname= 'DejaVu Sans')
    ax.minorticks_on()
    ax.tick_params(axis='y', which='major', direction = 'out', width = 2, length = 10, labelsize=20)
    ax.tick_params(axis='y', which='minor', direction = 'out', width = 2, length = 5, labelsize=20)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=25, fontname= 'DejaVu Sans')
    ax.set_ylim(-300,50)
    ax.set_xlim(27.85,236.74)
    ax.set_title(title,fontsize=25)
    # redshift axis on top, tied to the frequency axis through redshift()/freq()
    top_axis = ax.secondary_xaxis('top', functions=(redshift, freq))
    top_axis.set_xlabel(r'$z$', fontsize=25, fontname= 'DejaVu Sans')
    top_axis.set_xticks([5, 10, 15, 20, 30, 50])
    top_axis.set_xticklabels(z_tick_labels, fontsize=20, fontname= 'DejaVu Sans')
    top_axis.tick_params(which='major', direction = 'out', width = 2, length = 10, labelsize=20)
    return top_axis


# Left panel (5 mK): the only one with y tick labels and a y-axis label
secax1 = _draw_signal_panel(
    ax1, true_signal_no_noise_5mK, signal_flat_error_5mK, r'$\sigma_{21}=$ 5 mK',
    ['50', '0','-50','-100','-150','-200','-250',''], 20,
    ['5', '10', '15', '20', '30', '50'],
    ylabel=r'$\delta T_b$ (mK)')

# Middle panel (10 mK): blank y tick labels; the z=50 label is left empty
secax2 = _draw_signal_panel(
    ax2, true_signal_no_noise_10mK, signal_flat_error_10mK, r'$\sigma_{21}=$ 10 mK',
    ['','','','','','','',''], 18,
    ['5', '10', '15', '20', '30', ''])

# Right panel (25 mK): same styling as the middle panel
secax3 = _draw_signal_panel(
    ax3, true_signal_no_noise_25mK, signal_flat_error_25mK, r'$\sigma_{21}=$ 25 mK',
    ['','','','','','','',''], 18,
    ['5', '10', '15', '20', '30', ''])

def _style_residual_panel(ax, noise_level, ytick_labels, ylabel=None):
    """Apply the shared styling of one bottom-row residual panel and shade +/- ``noise_level``."""
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=25, fontname= 'DejaVu Sans')
    ax.set_yticks([-50,-25,0,25,50])
    ax.set_yticklabels(ytick_labels, fontsize=20, fontname= 'DejaVu Sans')
    ax.tick_params(axis='y', which='major', direction = 'out', width = 2, length = 10, labelsize=20)
    ax.set_ylim(-50,50)
    ax.set_xlim(27.85,236.74)
    ax.minorticks_on()
    ax.tick_params(axis='both', which='major', direction = 'out', width = 2, length = 10, labelsize=20)
    ax.tick_params(axis='both', which='minor', direction = 'out', width = 2, length = 5, labelsize=20)
    ax.set_xticks([40, 60, 80, 100, 120, 140, 160, 180, 200, 220])
    ax.set_xticklabels(['40', '60', '80', '100', '120', '140', '160', '180', '200', '220'],
                       fontsize=20, fontname= 'DejaVu Sans')
    ax.set_xlabel(r'$\nu$ (MHz)', fontsize=25, fontname= 'DejaVu Sans')
    # the band is taken from the noise level itself rather than a repeated literal
    ax.fill_between(nu_list, -noise_level, noise_level, facecolor='b', edgecolor="none", alpha=0.3)


def _plot_posterior_realizations(samples, true_signal, fit_label, signal_ax, residual_ax):
    """Emulate every posterior sample once, report error statistics, and plot the best 68%.

    For each row of ``samples`` the signal is emulated with ``predictor`` and compared
    with ``true_signal``: the residual, its rms over channels, and that rms as a
    percentage of the largest absolute value of ``true_signal`` are recorded. The
    samples whose relative error lies strictly below the 68th percentile are drawn on
    ``signal_ax`` (signals) and ``residual_ax`` (residuals).

    Note that the percentiles, means and medians are taken over the rows of
    ``samples`` as given; no sample weights enter here.

    Returns
    -------
    tuple
        ``(residuals, rmse_values, frac_errors, frac_threshold, rmse_threshold, selected)``
        where the first three are lists with one entry per sample, the thresholds are
        the 68th percentiles of the relative and absolute errors, and ``selected`` is
        the index array of the plotted samples.
    """
    peak_amplitude = np.max(np.abs(true_signal))
    signals, residuals, rmse_values, frac_errors = [], [], [], []
    for sample in samples:
        # np.array() stores an independent copy, so the emulated signal can be reused for
        # plotting below instead of calling the predictor a second time.
        # NOTE(review): this reuse assumes predictor() returns the same signal for the same
        # parameter vector - confirm.
        signal = np.array(predictor(sample))
        residual = true_signal - signal
        rmse = np.sqrt(np.mean(residual**2))
        signals.append(signal)
        residuals.append(residual)
        rmse_values.append(rmse)
        frac_errors.append(rmse/peak_amplitude*100)

    frac_threshold = np.percentile(frac_errors, 68)
    rmse_threshold = np.percentile(rmse_values, 68)
    print(f'1sigma relative rms error out of all the posterior samples for the {fit_label} fit:', frac_threshold)
    print(f'1sigma absolute rms error out of all the posterior samples for the {fit_label} fit:', rmse_threshold)
    print(f'mean relative rms error out of all the posterior samples for the {fit_label} fit:', np.mean(frac_errors))
    print(f'mean absolute rms error out of all the posterior samples for the {fit_label} fit:', np.mean(rmse_values))
    print(f'median relative rms error out of all the posterior samples for the {fit_label} fit:', np.median(frac_errors))
    print(f'median absolute rms error out of all the posterior samples for the {fit_label} fit:', np.median(rmse_values))

    # explicit array comparison (rather than list < numpy scalar) to pick the best 68%
    selected = np.flatnonzero(np.asarray(frac_errors) < frac_threshold)
    print(f'number of posterior samples (out of all for the {fit_label} fit) with the 68% best relative errors:',
          np.shape(selected))

    for index in selected:
        signal_ax.plot(nu_list, signals[index], color='r', alpha=0.01)
        residual_ax.plot(nu_list, residuals[index], color='r', alpha=0.01)

    return residuals, rmse_values, frac_errors, frac_threshold, rmse_threshold, selected


# --- 5 mK fit: signals on ax1, residuals on ax4 (the only residual panel with y labels) ---
_style_residual_panel(ax4, noise_level_5mK, ['-50', '-25', '0', '25','50'], ylabel=r'residual (mK)')
(residual_posteriors_list_5mK, rmse_posteriors_list_5mK, frac_err_list_5mK,
 frac_err_5mK_1sigma, rmse_5mK_1sigma, frac_err_5mK_1sigma_where) = _plot_posterior_realizations(
    samples_5mK, true_signal_no_noise_5mK, '5 mK', ax1, ax4)
samples_5mK_1sigma = samples_5mK[frac_err_5mK_1sigma_where]
residual_posteriors_5mK_1sigma = np.array(residual_posteriors_list_5mK)[frac_err_5mK_1sigma_where]
print(np.shape(samples_5mK_1sigma))
print(np.shape(residual_posteriors_5mK_1sigma))

# --- 10 mK fit: signals on ax2, residuals on ax5 ---
_style_residual_panel(ax5, noise_level_10mK, ['', '', '', '', ''])
(residual_posteriors_list_10mK, rmse_posteriors_list_10mK, frac_err_list_10mK,
 frac_err_10mK_1sigma, rmse_10mK_1sigma, frac_err_10mK_1sigma_where) = _plot_posterior_realizations(
    samples_10mK, true_signal_no_noise_10mK, '10 mK', ax2, ax5)
samples_10mK_1sigma = samples_10mK[frac_err_10mK_1sigma_where]
residual_posteriors_10mK_1sigma = np.array(residual_posteriors_list_10mK)[frac_err_10mK_1sigma_where]
print(np.shape(samples_10mK_1sigma))

# --- 25 mK fit: signals on ax3, residuals on ax6 ---
_style_residual_panel(ax6, noise_level_25mK, ['', '', '', '', ''])
(residual_posteriors_list_25mK, rmse_posteriors_list_25mK, frac_err_list_25mK,
 frac_err_25mK_1sigma, rmse_25mK_1sigma, frac_err_25mK_1sigma_where) = _plot_posterior_realizations(
    samples_25mK, true_signal_no_noise_25mK, '25 mK', ax3, ax6)
samples_25mK_1sigma = samples_25mK[frac_err_25mK_1sigma_where]
residual_posteriors_25mK_1sigma = np.array(residual_posteriors_list_25mK)[frac_err_25mK_1sigma_where]
print(np.shape(samples_25mK_1sigma))

def Extract(lst,x):
    """Return column ``x`` of ``lst``: a list holding element ``x`` of every row.

    Each row is indexed directly, so one call costs one lookup per row instead of
    transposing the whole table just to keep a single column.

    Parameters
    ----------
    lst : iterable of indexable rows (lists, tuples, 1-D arrays, ...)
    x : int
        Position to take from each row.

    Returns
    -------
    list
        ``[row[x] for row in lst]``; empty when ``lst`` has no rows.

    Raises
    ------
    IndexError
        If a row has no element at position ``x``.
    """
    return [row[x] for row in lst]

# Channel-by-channel mean of the residuals over all posterior samples of the 5 mK fit.
# Averaging along axis 0 handles however many frequency channels the signals have,
# instead of looping over a fixed channel count.
mean_residual_posteriors_list_5mK = list(np.mean(np.asarray(residual_posteriors_list_5mK), axis=0))

# bottom panel: mean posterior residual (solid) and residual of the emulation at the true parameters (dashed)
ax4.plot(nu_list, mean_residual_posteriors_list_5mK, color='k', lw=2, alpha=1, zorder=20000)
ax4.plot(nu_list, residual_fiducial_5mK, color='k',linestyle='--', lw=2, alpha=1, zorder=20000)
# top panel: noiseless true signal
ax1.plot(nu_list, true_signal_no_noise_5mK, color='b', lw=2)

# Channel-by-channel mean of the residuals over all posterior samples of the 10 mK fit.
# Averaging along axis 0 handles however many frequency channels the signals have,
# instead of looping over a fixed channel count.
mean_residual_posteriors_list_10mK = list(np.mean(np.asarray(residual_posteriors_list_10mK), axis=0))

# bottom panel: mean posterior residual (solid) and residual of the emulation at the true parameters (dashed)
ax5.plot(nu_list, mean_residual_posteriors_list_10mK, color='k', lw=2, alpha=1, zorder=20000)
ax5.plot(nu_list, residual_fiducial_10mK, color='k',linestyle='--', lw=2, alpha=1, zorder=20000)
# top panel: noiseless true signal
ax2.plot(nu_list, true_signal_no_noise_10mK, color='b', lw=2)

# Channel-by-channel mean of the residuals over all posterior samples of the 25 mK fit.
# Averaging along axis 0 handles however many frequency channels the signals have,
# instead of looping over a fixed channel count.
mean_residual_posteriors_list_25mK = list(np.mean(np.asarray(residual_posteriors_list_25mK), axis=0))

# bottom panel: mean posterior residual (solid) and residual of the emulation at the true parameters (dashed)
ax6.plot(nu_list, mean_residual_posteriors_list_25mK, color='k', lw=2, alpha=1, zorder=20000)
ax6.plot(nu_list, residual_fiducial_25mK, color='k',linestyle='--', lw=2, alpha=1, zorder=20000)
# top panel: noiseless true signal
ax3.plot(nu_list, true_signal_no_noise_25mK, color='b', lw=2)

# Proxy artists for the legend, created in the order in which they are listed
blue_line = mlines.Line2D(
    [], [], color='b', linewidth=2,
    label=r'true signal model realization (top)',
)
blue_square = mlines.Line2D(
    [], [], color='b', marker='s', alpha=0.3, linestyle='None', markersize=12,
    label=r'noise level, $\sigma_{21}$',
)
red_line_alpha = mlines.Line2D(
    [], [], color='red', marker='s', alpha=0.2, linestyle='none', markersize=12,
    label=r'emulated ${\tt 21cmKAN}$ $1\sigma$ posteriors',
)
black_line = mlines.Line2D(
    [], [], color='k', linewidth=2,
    label=r'mean posterior (bottom)',
)
black_line_dashed = mlines.Line2D(
    [], [], color='k', linestyle='--', linewidth=2,
    label=r'true signal emulator realization (bottom)',
)

# Place the legend, write the figure to disk, display it, then clear the pyplot state
legend_handles = [blue_line, blue_square, red_line_alpha, black_line, black_line_dashed]
plt.legend(handles=legend_handles, bbox_to_anchor=(-1.1,1.74,0.1,0.1), fontsize="16.9")
plt.savefig(
    '21cmKAN_21cmGEM_7param_nlive1200_25mk_10mk_5mk_posterior_realizations.jpg',
    dpi = 300, bbox_inches='tight', facecolor='w',
)
plt.show()

plt.cla()
plt.clf()

# Summary over ALL posterior samples of each fit: mean absolute (mK) and relative (%) rms error
for caption, value, unit in (
    ('mean absolute RMSE between fiducial signal and ALL posteriors for 5mK noise:', np.mean(rmse_posteriors_list_5mK), 'mK'),
    ('mean relative RMSE between fiducial signal and ALL posteriors for 5mK noise:', np.mean(frac_err_list_5mK), '%'),
    ('mean absolute RMSE between fiducial signal and ALL posteriors for 10mK noise:', np.mean(rmse_posteriors_list_10mK), 'mK'),
    ('mean relative RMSE between fiducial signal and ALL posteriors for 10mK noise:', np.mean(frac_err_list_10mK), '%'),
    ('mean absolute RMSE between fiducial signal and ALL posteriors for 25mK noise:', np.mean(rmse_posteriors_list_25mK), 'mK'),
    ('mean relative RMSE between fiducial signal and ALL posteriors for 25mK noise:', np.mean(frac_err_list_25mK), '%'),
):
    print(caption, value, unit)

# 68th-percentile absolute rms error of each fit
for caption, value in (
    ('1sigma absolute rms error out of all the posterior samples for the 5 mK fit:', rmse_5mK_1sigma),
    ('1sigma absolute rms error out of all the posterior samples for the 10 mK fit:', rmse_10mK_1sigma),
    ('1sigma absolute rms error out of all the posterior samples for the 25 mK fit:', rmse_25mK_1sigma),
):
    print(caption, value)

########################################################################################################################################
# Full 1D and 2D marginalized posterior distributions for the 25 mK and 5 mK noise mock signal fits performed above;
# reproduces Figures C1 and C2 in DJ+25


def _truths_for_corner(true_params):
    """Return the seven true parameter values, with the first three converted to log10."""
    return [np.log10(value) for value in true_params[:3]] + [true_params[k] for k in (3, 4, 5, 6)]


# true values marked on the corner plots, in the same units as the plotted axes
reference_value_mean_5mK = _truths_for_corner(par_21cmGEM_5mK)
reference_value_mean_10mK = _truths_for_corner(par_21cmGEM_10mK)
reference_value_mean_25mK = _truths_for_corner(par_21cmGEM_25mK)

# figure size: one square cell of side `factor` per parameter, plus inter-panel spacing and margins
factor = 2.0
whspace = 0.05
lbdim = 0.5 * factor
trdim = 0.2 * factor
plotdim = factor * n_params + factor * (n_params - 1.0) * whspace
dim = lbdim + plotdim + trdim

# axis range of every parameter, size [nparams, 2]
plot_limits = [
    [-4,np.log10(5e-1)],
    [np.log10(4.2),2],
    [-6,3],
    [0.04,0.2],
    [1,1.5],
    [0.1,3.0],
    [10,50],
]
# axis label of every parameter
param_labels = [
    r'$\log_{10} f_*$',
    r'$\log_{10} V_c$',
    r'$\log_{10} f_X$',
    r'$\tau$',
    r'$\alpha$',
    r'$\nu_{\rm min}$',
    r'$R_{\rm mfp}$',
]

# reproduce Figure C1 in DJ+25
# The first three columns hold linear values but are plotted in log10 (see plot_limits/param_labels),
# so they are converted in place. Because samples_25mK itself is overwritten, this statement must
# not be executed twice on the same array.
samples_25mK[:, :3] = np.log10(samples_25mK[:, :3])
best_fit_25mK = nest_analyzer_25mK.get_best_fit()['parameters']
print('input param values:', reference_value_mean_25mK)
figure, axs = plt.subplots(n_params, n_params, figsize=(dim,dim))
corner.corner(samples_25mK, smooth=2, smooth1d=2, range=plot_limits, bins=100, alpha = 1.0, fig=figure, plot_datapoints=False,
              weights=weights_25mK, label = r'$\sigma_{21}$ = 25 mK', hist_kwargs={'alpha':1.0, 'linewidth':3},
              hist2d_kwargs={'alpha':1.0, 'linewidth':3}, data_kwargs={'alpha':1.0, 'linewidth':3}, plot_density = True,
              plot_contours = True, linewidth = 3, color = 'red', quantiles=[], levels=np.array([0.9545]), labels=param_labels,
              show_titles=False, title_kwargs={"fontsize": 30}, label_kwargs={"fontsize": 30}, truths=reference_value_mean_25mK)
plt.suptitle(r'$\sigma_{21}=$ 25 mK',fontsize=28)
plt.savefig('21cmKAN_21cmGEM_7param_nlive1200_smooth2_100bins_25mk_corner.jpg', dpi = 300, bbox_inches='tight', facecolor='w')
plt.show()
plt.cla()
plt.clf()
print('***best-fit parameter values for 25 mK fit***')
# The best-fit values come straight from the chain files, i.e. in the linear units produced by
# priors(), so they are labelled with the physical parameter names in params_list rather than
# with the log10 axis labels in param_labels.
for j in range(n_params):
    print(params_list[j], '= %.4f' % best_fit_25mK[j])

# reproduce Figure C2 in DJ+25
# The first three columns hold linear values but are plotted in log10 (see plot_limits/param_labels),
# so they are converted in place. Because samples_5mK itself is overwritten, this statement must
# not be executed twice on the same array.
samples_5mK[:, :3] = np.log10(samples_5mK[:, :3])
best_fit_5mK = nest_analyzer_5mK.get_best_fit()['parameters']
print('input param values:', reference_value_mean_5mK)
figure, axs = plt.subplots(n_params, n_params, figsize=(dim,dim))
corner.corner(samples_5mK, smooth=2, smooth1d=2, range=plot_limits, bins=100, alpha = 1.0, fig=figure, plot_datapoints=False,
              weights=weights_5mK, label = r'$\sigma_{21}$ = 5 mK', hist_kwargs={'alpha':1.0, 'linewidth':3},
              hist2d_kwargs={'alpha':1.0, 'linewidth':3}, data_kwargs={'alpha':1.0, 'linewidth':3}, plot_density = True,
              plot_contours = True, linewidth = 3, color = 'red', quantiles=[], levels=np.array([0.9545]), labels=param_labels,
              show_titles=False, title_kwargs={"fontsize": 30}, label_kwargs={"fontsize": 30}, truths=reference_value_mean_5mK)
plt.suptitle(r'$\sigma_{21}=$ 5 mK',fontsize=28)
plt.savefig('21cmKAN_21cmGEM_7param_nlive1200_smooth2_100bins_5mk_corner.jpg', dpi = 300, bbox_inches='tight', facecolor='w')
plt.show()
plt.cla()
plt.clf()
print('***best-fit parameter values for 5 mK fit***')
# The best-fit values come straight from the chain files, i.e. in the linear units produced by
# priors(), so they are labelled with the physical parameter names in params_list rather than
# with the log10 axis labels in param_labels.
for j in range(n_params):
    print(params_list[j], '= %.4f' % best_fit_5mK[j])
