# In [None]:
from __future__ import absolute_import, unicode_literals, print_function
import os
# Create the MultiNest output directory (relative to the working directory).
# exist_ok=True keeps re-runs working, while real failures (e.g. no write
# permission) are raised here instead of being silently swallowed and then
# resurfacing later as an obscure MultiNest file error.
os.makedirs('multinest_chains_21cmKAN_21cmGEM_7param', exist_ok=True)

import Global21cmKAN as Global21cmKAN
import numpy as np
import random
import pymultinest
from pymultinest.solve import solve
from matplotlib import pyplot as plt
from matplotlib import rc
from pylinex import Sampler, BurnRule, NLFitter, GaussianLoglikelihood, GaussianModel
from perses.models.KANSignalModel21cmGEM import KANSignalModel
#import ares
#from ares.simulations import Global21cm
#from scipy import interpolate
#from distpy.distribution import DistributionSet, DistributionList, UniformDistribution
#from distpy import triangle_plot
#from pylinex.loglikelihood.GaussianLoglikelihoodJoint import GaussianLoglikelihoodJoint
# from perses.models import AresSignalModel, GlobalemuSignalModel
#from perses.models.GlobalemuSignalModelWithUVLF import GlobalemuSignalModelWithUVLF
#from perses.models.AresUVLFModel import AresUVLFModel

# Default figure size, and Computer Modern ('cm') glyphs for mathtext labels.
plt.rc('figure', figsize=(8.0, 7.0))
plt.rcParams.update({'mathtext.fontset': 'cm'})

# Build the 21cmKAN emulator on its 21cmGEM data sets and load the pretrained
# weights (pass model_path=... to load_model() to load a different model).
emulator = Global21cmKAN.emulator_21cmKAN_21cmGEM.Emulate()
emulator.load_model()
params_list = emulator.par_labels  # names of the physical input parameters

# Copies of the test-set parameters and their 21cmGEM signals. Parameter
# columns: 'f_star', 'V_c', 'f_X', 'tau', 'alpha', 'nu_min', 'R_mfp'.
params_21cmGEM_test, signals_21cmGEM_test = (
    emulator.par_test.copy(),
    emulator.signal_test.copy(),
)

# Alternative: fit a randomly chosen test-set signal instead of a DJ+24 case.
# random_index = np.random.choice(len(params_21cmGEM_test), 1, replace=False)
# par_21cmGEM = np.array(params_21cmGEM_test[random_index][0])
# T_21cmGEM = signals_21cmGEM_test[random_index][0]

# The three mock signals fit in DJ+24, randomly selected from the 21cmGEM test
# set. Their fiducial parameters are written out explicitly. The rows printed
# below are the matching test-set entries, so you can compare the two by eye.
signal_id_5mK, signal_id_10mK, signal_id_25mK = 9, 742, 998
par_21cmGEM_5mK = np.array([0.15811388, 16.5, 0.1, 0.06260315, 1., 0.2, 40.])
par_21cmGEM_10mK = np.array([5.000000e-02, 7.650000e+01, 1e-6, 6.917401e-02, 1.300000e+00, 3.000000e-01, 4.000000e+01])
par_21cmGEM_25mK = np.array([1.1020955e-02, 4.1533981e+01, 6.4696016e-04, 7.6195940e-02, 1.0000000e+00, 2.0000000e+00, 4.5000000e+01])

print(params_21cmGEM_test[signal_id_5mK],
      params_21cmGEM_test[signal_id_10mK],
      params_21cmGEM_test[signal_id_25mK],
      sep='\n')

# Fiducial parameters of the signal actually being fit.
par_21cmGEM = par_21cmGEM_5mK
print('parameters in 21cmGEM:', params_list)
print('fiducial parameter values for mock 21cmGEM signal being fit:', par_21cmGEM)

# Noise-free 21cmGEM test signals for each case...
T_21cmGEM_5mK, T_21cmGEM_10mK, T_21cmGEM_25mK = (
    signals_21cmGEM_test[row]
    for row in (signal_id_5mK, signal_id_10mK, signal_id_25mK)
)
# ...and the 21cmKAN emulation at the same fiducial parameters.
T_KAN_5mK, T_KAN_10mK, T_KAN_25mK = (
    emulator.predict(theta)
    for theta in (par_21cmGEM_5mK, par_21cmGEM_10mK, par_21cmGEM_25mK)
)
nu_list = emulator.frequencies
z_list = emulator.redshifts

# Signal pair used for the rest of the analysis.
T_21cmGEM, T_KAN = T_21cmGEM_5mK, T_KAN_5mK

# Mock data: the 21cmGEM signal plus one fixed realisation of flat
# (channel-independent) Gaussian noise.
noise_level = 5  # noise standard deviation per channel, in mK
true_signal_no_noise = T_21cmGEM.copy()
signal_flat_error = np.full(len(true_signal_no_noise), float(noise_level))
np.random.seed(2)  # same noise realisation every run, so fits are comparable
random_array = np.random.normal(size=len(true_signal_no_noise))
gaussian_error = signal_flat_error * random_array
input_true_signal = true_signal_no_noise + gaussian_error

# Gaussian log-likelihood of the mock data under the 21cmKAN signal model.
data_signal = input_true_signal
error_signal = signal_flat_error
# Author's note: the constructor argument's contents don't matter here.
model_signal = KANSignalModel(par_21cmGEM)
gaussian_loglikelihood_signal = GaussianLoglikelihood(data_signal, error_signal, model_signal)
#input_params = np.array([3e-4, 4.2, 1e-6, 0.055, 1.0, 0.1, 10])

# Rest-frame frequency of the 21 cm hyperfine line, in MHz.
vr = 1420.405751


def freq(zs):
    """Return the observed frequency (MHz) of 21 cm emission from redshift ``zs``.

    Works elementwise on NumPy arrays as well as on scalars.
    """
    return vr / (1 + zs)


def redshift(v):
    """Return the redshift at which 21 cm emission is observed at ``v`` MHz.

    Inverse of ``freq``; works elementwise on NumPy arrays as well as scalars.
    """
    return vr / v - 1

# Plot the noisy mock data (solid black), the 21cmKAN emulation (dashed red)
# and the noise-free 21cmGEM signal (dashed black) against frequency, with a
# matching redshift axis along the top.
rc('figure', figsize=(7.0, 5.0))
fig, ax = plt.subplots(constrained_layout=True)
ax.minorticks_on()
for tick_kind, tick_length in (('major', 10), ('minor', 5)):
    ax.tick_params(axis='both', which=tick_kind, direction='out', width=2,
                   length=tick_length, labelsize=20)

y_ticks = [50, 0, -50, -100, -150, -200, -250, -300, -350, -400]
x_ticks = [40, 60, 80, 100, 120, 140, 160, 180, 200, 220]
z_ticks = [5, 10, 15, 20, 30, 50]
tick_font = dict(fontsize=15, fontname='Baskerville')
label_font = dict(fontsize=20, fontname='Baskerville')

ax.set_yticks(y_ticks)
ax.set_yticklabels([str(t) for t in y_ticks], **tick_font)
ax.set_xticks(x_ticks)
ax.set_xticklabels([str(t) for t in x_ticks], **tick_font)
ax.set_ylabel(r'$\delta T_b$ (mK)', **label_font)
ax.set_xlabel(r'$\nu$ (MHz)', **label_font)

# Top axis: redshift, mapped from frequency via redshift()/freq().
secax = ax.secondary_xaxis('top', functions=(redshift, freq))
secax.tick_params(which='major', direction='out', width=2, length=10, labelsize=15)
secax.tick_params(which='minor', direction='out', width=1, length=5, labelsize=15)
secax.set_xlabel(r'$z$', **label_font)
secax.set_xticks(z_ticks)
secax.set_xticklabels([str(t) for t in z_ticks], **tick_font)

ax.plot(nu_list, input_true_signal, color='k', alpha=1)
ax.plot(nu_list, T_KAN, color='r', alpha=1, linestyle='--')
ax.plot(nu_list, T_21cmGEM, color='k', alpha=1, linestyle='--')
ax.set_ylim(-300, 50)
ax.set_xlim(27.85, 236.74)
#plt.savefig('21cmGEM_mock_signal_5mknoise_DJ+24.png', dpi = 300, bbox_inches='tight', facecolor='w')
plt.show()

##########
# par_ARES = [39.415, 0.2, 4.0, 21.0, -1.301, 11.447158, 0.49, -0.61]
# par_list = ['pop_rad_yield{1}', 'pop_fesc{0}', 'pop_Tmin{0}',  'pop_logN{1}', 'pq_func_par0[0]{0}','pq_func_par1[0]{0}','pq_func_par2[0]{0}','pq_func_par3[0]{0}']
# print('ARES parameter names:', par_list)
# print('fiducial parameter values:', par_ARES)
# parameters_dictionary = ares.util.ParameterBundle('mirocha2017:base')
# updates = {'pop_rad_yield{1}': 10**par_ARES[0], 'pop_fesc{0}': par_ARES[1], 'pop_Tmin{0}': 10**par_ARES[2],
#               'pop_logN{1}': par_ARES[3], 'pq_func_par0[0]{0}': 10**par_ARES[4], 'pq_func_par1[0]{0}': 10**par_ARES[5],
#               'pq_func_par2[0]{0}': par_ARES[6],'pq_func_par3[0]{0}': par_ARES[7]}
# parameters_dictionary.update(updates)
# sim = ares.simulations.Global21cm(**parameters_dictionary, verbose=False, progress_bar=False)
# tau = sim.medium.field.solver.tau
# tau_instance = sim.medium.field.solver.tau_solver
# hmf_instance = sim.pops[0].halos
# sps_instance = sim.pops[0].src
# updates_2 = {'tau_instance': tau_instance, 'hmf_instance': hmf_instance, 'pop_src_instance{0}' : sps_instance}
# parameters_dictionary.update(updates_2)
# sim = ares.simulations.Global21cm(**parameters_dictionary, verbose=False, progress_bar=False)
# sim.run()

# b15 = ares.util.read_lit('bouwens2015')
# b15_Phi = np.array(b15.data['lf'][5.9]['phi'])
# b15_error = np.array(b15.data['lf'][5.9]['err'])
# b15_M_UV = np.array(b15.data['lf'][5.9]['M'])
# magnitude_list = b15_M_UV
# ARES_UVLF = sim.pops[0].LuminosityFunction(5.9, magnitude_list)[1]

# N_UVLF = len(magnitude_list)
# data_UVLF = ARES_UVLF
# model_signal = KANSignalModel(file_path_to_nn='models/emulator.h5')
# model_UVLF = AresUVLFModel(freq(redshifts), magnitude_list, params_list)
# UVLF_noise_level = 2
# error_UVLF = [x*UVLF_noise_level for x in b15_error]
# gaussian_loglikelihood_UVLF = GaussianLoglikelihood(data_UVLF, error_UVLF, model_UVLF)
##########

# Affine map from the unit hypercube onto each prior range:
#     theta = _PRIOR_SCALE * cube + _PRIOR_SHIFT
# Entries 0-2 (f_star, V_c, f_X) are drawn in log10 space and exponentiated
# afterwards, so their priors are log-uniform. The other four are uniform.
_PRIOR_SCALE = np.array([np.log10(5e-1)+4, 2-np.log10(4.2), 9, 0.16, 0.5, 2.9, 40])
_PRIOR_SHIFT = np.array([-4, np.log10(4.2), -6, 0.04, 1, 0.1, 10])
_LOG10_PRIOR_INDICES = (0, 1, 2)


def priors(cube):
    """Map a MultiNest unit-hypercube point to physical parameter values.

    Parameters
    ----------
    cube : array-like of 7 floats in [0, 1]
        Point proposed by the sampler.

    Returns
    -------
    numpy.ndarray
        Parameters in the order f_star, V_c, f_X, tau, alpha, nu_min, R_mfp.
        Log-uniform: f_star in [1e-4, 0.5], V_c in [4.2, 100],
        f_X in [1e-6, 1e3]. Uniform: tau in [0.04, 0.2],
        alpha in [1, 1.5], nu_min in [0.1, 3], R_mfp in [10, 50].
    """
    theta = _PRIOR_SCALE * cube + _PRIOR_SHIFT
    for k in _LOG10_PRIOR_INDICES:
        theta[k] = 10 ** theta[k]
    return theta

# Number of dimensions fit in the nested-sampling analysis (one per emulator
# input parameter).
n_params = len(params_list)

# MultiNest output-file root. It sits inside the chains directory created at
# the top of the notebook, which is relative to the working directory. The
# previous absolute '/tutorials/...' root only matched that directory when the
# notebook was launched from /tutorials; anywhere else MultiNest had no
# existing directory to write into.
chains_dir = 'multinest_chains_21cmKAN_21cmGEM_7param'
prefix = os.path.join(
    chains_dir,
    'KAN_21cmGEM_7param_nlive1200_tolerance0.1_sampeff0.8_DJ24signal_5mKnoise_default2_trial11-',
)

# Run MultiNest: number of live points, evidence tolerance, sampling
# efficiency, and number of iterations between output-file updates.
kwargs = {'n_live_points': 1200, 'evidence_tolerance': 0.1,
          'sampling_efficiency': 0.8, 'n_iter_before_update': 10}
result = solve(LogLikelihood=gaussian_loglikelihood_signal, Prior=priors,
               n_dims=n_params, outputfiles_basename=prefix, verbose=True,
               **kwargs)

# Summary: log-evidence and the posterior mean +- std of each parameter.
print()
print('evidence: %(logZ).1f +- %(logZerr).1f' % result)
print()
print('parameter values:')
for name, col in zip(params_list, result['samples'].transpose()):
    print('%15s : %.3f +- %.3f' % (name, col.mean(), col.std()))

# Store the parameter names next to the chains so marginal plots can be made
# afterwards with pymultinest's helper script:
#   $ python multinest_marginals.py <prefix>
# where <prefix> is the outputfiles_basename used for the MultiNest run.
import json
with open('%sparams.json' % prefix, 'w', encoding='utf-8') as f:
    # Converting to a list of plain str keeps json.dump working even if
    # par_labels is a NumPy array rather than a list; for a list of strings
    # the output is unchanged. NOTE(review): par_labels' exact type is not
    # visible here.
    json.dump([str(name) for name in params_list], f, indent=2)
