# In [None]:
import numpy as np
import Global21cmLSTM as Global21cmLSTM
from Global21cmLSTM.eval_21cmGEM import evaluate_21cmGEM
from Global21cmLSTM.eval_ARES import evaluate_ARES
from matplotlib import pyplot as plt

# This cell evaluates a pretrained instance of 21cmLSTM on the parameter values of the 1,704 signals in the "21cmGEM" test set
# and plots the emulated signals ("network realizations") together with the true test-set signals.
# Printing the accuracy metrics (mean, median, max relative and absolute rms error) and plotting the
# histogram of relative rms errors are done by the commented-out code further below.

emulator_21cmGEM = Global21cmLSTM.emulator_21cmGEM.Emulate() # initialize 21cmLSTM to emulate 21cmGEM data
emulator_21cmGEM.load_model()  # load pretrained instance of 21cmLSTM emulator. Use kwarg 'model_path' to load other models.
emulator_21cmGEM.emulator.summary()
print(emulator_21cmGEM.par_labels) # print the input 21cmGEM physical parameters
predictor_21cmGEM = evaluate_21cmGEM() # callable used below to map physical parameter values to emulated signals

# define array of test-set parameters, shape (N_signals,7) for 7 physical parameters
params_21cmGEM = emulator_21cmGEM.par_test.copy() # grab the test set parameter values for 'f_star', 'V_c', 'f_X', 'tau', 'alpha', 'nu_min', 'R_mfp'
signals_21cmGEM_emulated = predictor_21cmGEM(params_21cmGEM) # generate 21cmLSTM predictions (i.e. emulations) of the 21cmGEM global 21 cm signals
signals_21cmGEM_true = emulator_21cmGEM.signal_test.copy() # true 21cmGEM test-set signals, compared against the emulations when plotting

# Same workflow for the ARES data set (currently disabled):
# emulator_ARES = Global21cmLSTM.emulator_ARES.Emulate() # initialize 21cmLSTM to emulate ARES data
# emulator_ARES.load_model()  # load pretrained instance of 21cmLSTM emulator. Use kwarg 'model_path' to load other models.
# emulator_ARES.emulator.summary()
# print(emulator_ARES.par_labels) # print the input ARES physical parameters
# predictor_ARES = evaluate_ARES()

# # define array of parameters, shape (N_signals, N_params); TODO confirm the number of ARES physical parameters
# params_ARES = emulator_ARES.par_test.copy() # grab the ARES test set parameter values (names given by emulator_ARES.par_labels)
# signals_ARES_emulated = predictor_ARES(params_ARES)  # generate 21cmLSTM predictions (i.e. emulations) of the ARES global 21 cm signals
# signals_ARES_true = emulator_ARES.signal_test.copy()

# Frequency grid shared by all signals; plotted on the MHz x-axis below.
frequencies = emulator_21cmGEM.frequencies  # array of frequencies (dz=0.1, length=451)

vr = 1420.405751  # rest-frame 21 cm line frequency in MHz (same units as the plot's x-axis)


def freq(zs):
    """Convert redshift(s) to the observed 21 cm frequency in MHz.

    Used as the inverse transform of the secondary (redshift) x-axis, so it
    must accept numpy arrays.  Matplotlib may evaluate the transform at
    values where the denominator vanishes (z = -1); those map to ``inf``
    instead of raising ``ZeroDivisionError`` or emitting a RuntimeWarning.

    Parameters
    ----------
    zs : float or array_like
        Redshift value(s).

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Observed frequency ``vr / (1 + z)`` in MHz.
    """
    zs = np.asarray(zs, dtype=float)
    # Suppress only the divide-by-zero warning; the pole yields +/-inf.
    with np.errstate(divide='ignore'):
        return vr / (zs + 1)


def redshift(v):
    """Convert observed 21 cm frequency(ies) in MHz to redshift.

    Used as the forward transform of the secondary (redshift) x-axis, so it
    must accept numpy arrays.  ``v = 0`` maps to ``inf`` instead of raising
    ``ZeroDivisionError`` or emitting a RuntimeWarning.

    Parameters
    ----------
    v : float or array_like
        Observed frequency value(s) in MHz.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Redshift ``vr / v - 1``.
    """
    v = np.asarray(v, dtype=float)
    with np.errstate(divide='ignore'):
        return (vr / v) - 1

# Plot every test-set realization: 21cmLSTM emulations in red, true 21cmGEM
# signals in black, against frequency (bottom axis) and redshift (top axis).
fig, ax = plt.subplots(constrained_layout=True)
ax.minorticks_on()
ax.tick_params(axis='both', which='major', direction='out', width=2, length=10, labelsize=20)
ax.tick_params(axis='both', which='minor', direction='out', width=2, length=5, labelsize=20)
ax.set_yticks([50, 0, -50, -100, -150, -200, -250, -300])
ax.set_yticklabels(['50', '0', '-50', '-100', '-150', '-200', '-250', '-300'], fontsize=20, fontname='Baskerville')
ax.set_xticks([40, 60, 80, 100, 120, 140, 160, 180, 200, 220])
ax.set_xticklabels(['40', '60', '80', '100', '120', '140', '160', '180', '200', '220'], fontsize=20, fontname='Baskerville')
ax.set_ylabel(r'$\delta T_b$ (mK)', fontsize=20, fontname='Baskerville')
ax.set_xlabel(r'$\nu$ (MHz)', fontsize=20, fontname='Baskerville')

# Secondary top axis in redshift: redshift() is the forward map from MHz, freq() the inverse.
secax = ax.secondary_xaxis('top', functions=(redshift, freq))
secax.tick_params(which='major', direction='out', width=2, length=10, labelsize=20)
secax.tick_params(which='minor', direction='out', width=1, length=5, labelsize=20)
secax.set_xlabel(r'$z$', fontsize=20, fontname='Baskerville')
secax.set_xticks([5, 10, 15, 20, 30, 50])
secax.set_xticklabels(['5', '10', '15', '20', '30', '50'], fontsize=15, fontname='Baskerville')

# Draw each emulated/true pair back to back (same interleaved draw order as
# before); zip stops at the shorter array instead of raising IndexError.
for sig_emulated, sig_true in zip(signals_21cmGEM_emulated, signals_21cmGEM_true):
    ax.plot(frequencies, sig_emulated, color='r', alpha=0.1)
    ax.plot(frequencies, sig_true, color='k', alpha=0.1)

ax.set_ylim(-300, 50)
# x-limits are approximately vr/51 and vr/6, i.e. z = 50 and z = 5, the outer redshift ticks.
ax.set_xlim(27.85, 236.74)
fig.savefig('21cmLSTM_21cmGEM_test_realizations.png', dpi=300, bbox_inches='tight', facecolor='w')
plt.show()

# compute the relative rms error for this instance of 21cmLSTM evaluated on the same 21cmGEM test set used in the paper
# the results should fall within the range found for 21cmLSTM when trained and tested on the 21cmGEM set for 20 trials (Fig. 2 in paper)
#signals_emulated = emulator_21cmGEM.predict(params)
# rel_error = emulator_21cmGEM.test_error()
# abs_error = emulator_21cmGEM.test_error(relative=False)

# mean_rel_err = np.mean(rel_error)
# median_rel_err = np.median(rel_error)
# max_rel_err = np.max(rel_error)
# mean_abs_err = np.mean(abs_error)
# print('Mean relative rms error:', mean_rel_err, '%')
# print('Median relative rms error:', median_rel_err, '%')
# print('Max relative rms error:', max_rel_err, '%')
# print('Mean absolute rms error:', mean_abs_err, 'mK')

# # Histogram of relative rms errors for this instance of 21cmLSTM, should look similar to top right panel of Fig. 1 in the paper
# fig, ax = plt.subplots(constrained_layout=True)
# ax.tick_params(axis='both', which='major', direction = 'out', width = 2, length = 10, labelsize=20)
# ax.set_yticks([0, 50, 100, 150, 200, 250, 300])
# ax.set_yticklabels(['0', '50', '100', '150', '200', '250', '300'], fontsize=20, fontname= 'Baskerville')
# ax.set_xticks([0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75])
# ax.set_xticklabels(['0', '0.25', '0.50', '0.75', '1.00', '1.25', '1.50', '1.75'], fontsize=20, fontname= 'Baskerville')
# ax.hist(rel_error, bins=20, color='blue')
# ax.axvline(median_rel_err, linestyle='solid', label='median (21cmLSTM)', color='orange')
# ax.axvline(mean_rel_err, linestyle='solid', label='mean (21cmLSTM)', color='gray')
# ax.legend(fontsize=20)#, fontname= 'Baskerville')
# plt.xlabel('Error (%)', fontsize=20, fontname= 'Baskerville')
# plt.ylabel('Counts', fontsize=20, fontname= 'Baskerville')
# plt.xlim(0,1.9)
# plt.savefig('21cmLSTM_21cmGEM_rel_err_github_run1.png', dpi = 300, bbox_inches='tight', facecolor='w')
# plt.show()
