In [None]:
import numpy as np
from matplotlib import pyplot as plt
import Global21cmLSTM as Global21cmLSTM
from Global21cmLSTM.eval_21cmGEM import evaluate_21cmGEM
from Global21cmLSTM.eval_ARES import evaluate_ARES
from Global21cmLSTM.emulator_21cmGEM import error as err_21cmGEM
from Global21cmLSTM.emulator_ARES import error as err_ARES

# Evaluate pretrained instances of 21cmLSTM on the parameter values of the 1,704 signals in
# the '21cmGEM' and ARES test sets, print the accuracy metrics (mean, median and max relative
# and absolute rms error), and plot the emulated realizations and relative rms errors obtained
# on each test set.


def _load_emulator(emulator_module, make_predictor):
    """Load a pretrained 21cmLSTM emulator and emulate the signals of its test set.

    ``emulator_module`` is the package module providing ``Emulate`` (e.g.
    ``Global21cmLSTM.emulator_ARES``); ``make_predictor`` is the matching ``evaluate_*``
    factory. The network summary and the input parameter labels are printed on the way.

    Returns ``(emulator, test_params, true_signals, predictor, emulated_signals)``.
    """
    emulator = emulator_module.Emulate()
    # Pretrained instance of 21cmLSTM; pass model_path=... to load_model() to use another model.
    emulator.load_model()
    emulator.emulator.summary()
    print(emulator.par_labels)
    test_params = emulator.par_test.copy()      # test-set physical parameter values (the data)
    true_signals = emulator.signal_test.copy()  # test-set brightness temperatures (the labels)
    predictor = make_predictor()
    return emulator, test_params, true_signals, predictor, predictor(test_params)


# 21cmGEM inputs, shape (N_signals, 7): 'f_star', 'V_c', 'f_X', 'tau', 'alpha', 'nu_min', 'R_mfp'
(emulator_21cmGEM, params_21cmGEM, signals_21cmGEM_true,
 predictor_21cmGEM, signals_21cmGEM_emulated) = _load_emulator(
    Global21cmLSTM.emulator_21cmGEM, evaluate_21cmGEM)

# ARES inputs, shape (N_signals, 8):
# 'c_X', 'f_esc', 'T_min', 'logN_HI', 'f_star0', 'M_p', 'gamma_lo', 'gamma_hi'
(emulator_ARES, params_ARES, signals_ARES_true,
 predictor_ARES, signals_ARES_emulated) = _load_emulator(
    Global21cmLSTM.emulator_ARES, evaluate_ARES)

# Relative and absolute rms errors of the provided representative instances of 21cmLSTM,
# trained and tested separately on the same 21cmGEM and ARES sets used and described in DJ+24.
# The 21cmGEM result should be consistent with the range found for 20 trials in Fig. 2 of DJ+24.

rel_err_21cmGEM = err_21cmGEM(signals_21cmGEM_true, signals_21cmGEM_emulated)
rel_err_ARES = err_ARES(signals_ARES_true, signals_ARES_emulated)
abs_err_21cmGEM = err_21cmGEM(signals_21cmGEM_true, signals_21cmGEM_emulated, relative=False)
abs_err_ARES = err_ARES(signals_ARES_true, signals_ARES_emulated, relative=False)


def _report_errors(rel_err, abs_err, data_set):
    """Print the mean/median/max relative (%) and absolute (mK) rms errors of one data set.

    Returns the six statistics in the order
    ``(mean_rel, median_rel, max_rel, mean_abs, median_abs, max_abs)``.
    """
    stats = tuple(stat_fn(errors)
                  for errors in (rel_err, abs_err)
                  for stat_fn in (np.mean, np.median, np.max))
    context = f'for 21cmLSTM trained and tested on {data_set}:'
    for offset, (kind, unit) in zip((0, 3), (('relative', '%'), ('absolute', 'mK'))):
        for label, value in zip(('Mean', 'Median', 'Max'), stats[offset:offset + 3]):
            print(f'{label} {kind} rms error', context, value, unit)
        print()
    return stats


(mean_rel_err_ARES, median_rel_err_ARES, max_rel_err_ARES,
 mean_abs_err_ARES, median_abs_err_ARES, max_abs_err_ARES) = _report_errors(
    rel_err_ARES, abs_err_ARES, 'ARES')

(mean_rel_err_21cmGEM, median_rel_err_21cmGEM, max_rel_err_21cmGEM,
 mean_abs_err_21cmGEM, median_abs_err_21cmGEM, max_abs_err_21cmGEM) = _report_errors(
    rel_err_21cmGEM, abs_err_21cmGEM, '21cmGEM')

# The test-set error can also be computed directly with Emulate.test_error(). That method calls
# predict() to generate the emulations, whereas the eval_21cmGEM.py / eval_ARES.py predictors used
# above are optimized for speed: they import lookup tables of the training-set minima and maxima
# instead of preprocessing them for each signal.
test_err_21cmGEM = emulator_21cmGEM.test_error()
test_err_ARES = emulator_ARES.test_error()
# Label each value so the output can be matched to its data set; presumably these are the mean
# relative rms errors (in %) and should agree with the means printed above -- confirm.
print('Mean test_error() for 21cmLSTM trained and tested on 21cmGEM:', np.mean(test_err_21cmGEM))
print('Mean test_error() for 21cmLSTM trained and tested on ARES:', np.mean(test_err_ARES))

# Plot the true signals (black) and the emulated realizations (red) for 21cmLSTM trained and
# tested on the 21cmGEM and ARES sets.

# Frequency grids (MHz) on which the signals of each set are sampled:
#   21cmGEM: z = 5-50,     dz = 0.1, n_nu = 451
#   ARES:    z = 5.1-49.9, dz = 0.1, n_nu = 449
frequencies_21cmGEM, frequencies_ARES = emulator_21cmGEM.frequencies, emulator_ARES.frequencies

# Rest-frame frequency of the 21-cm line in MHz (the unit of the plots' frequency axes).
vr = 1420.405751


def freq(zs):
    """Return the observed frequency (MHz) of 21-cm emission from redshift ``zs``.

    Inverse of ``redshift``; used as a secondary-axis transform. Accepts scalars or
    array-likes. ``zs == -1`` maps to ``inf`` instead of raising or warning, because
    matplotlib may evaluate the transform outside the plotted range.
    """
    zs = np.asarray(zs, dtype=float)
    with np.errstate(divide='ignore'):
        return vr / (zs + 1)


def redshift(v):
    """Return the redshift at which 21-cm emission is observed at frequency ``v`` (MHz).

    Inverse of ``freq``; used as a secondary-axis transform. Accepts scalars or
    array-likes. ``v == 0`` maps to ``inf``: the secondary axis is created while the
    primary x-limits are still matplotlib's default (0, 1), so a zero frequency is
    evaluated before the real limits are set.
    """
    v = np.asarray(v, dtype=float)
    with np.errstate(divide='ignore'):
        return vr / v - 1

def _plot_realizations(frequencies, signals_emulated, signals_true, filename):
    """Overlay emulated (red) and true (black) signals, save the figure, show it, and close it.

    frequencies : x-axis grid in MHz shared by every signal.
    signals_emulated, signals_true : signals in mK, paired row by row (assumes both hold one
        row per test-set parameter vector -- TODO confirm against the predictor output).
    filename : path of the PNG written with ``Figure.savefig``.
    """
    fig, ax = plt.subplots(constrained_layout=True)
    ax.minorticks_on()
    ax.tick_params(axis='both', which='major', direction='out', width=2, length=10, labelsize=20)
    ax.tick_params(axis='both', which='minor', direction='out', width=2, length=5, labelsize=20)
    ax.set_yticks([50, 0, -50, -100, -150, -200, -250, -300])
    ax.set_yticklabels(['50', '0', '-50', '-100', '-150', '-200', '-250', '-300'],
                       fontsize=20, fontname='Baskerville')
    ax.set_xticks([40, 60, 80, 100, 120, 140, 160, 180, 200, 220])
    ax.set_xticklabels(['40', '60', '80', '100', '120', '140', '160', '180', '200', '220'],
                       fontsize=20, fontname='Baskerville')
    ax.set_ylabel(r'$\delta T_b$ (mK)', fontsize=20, fontname='Baskerville')
    ax.set_xlabel(r'$\nu$ (MHz)', fontsize=20, fontname='Baskerville')
    # Top axis shows redshift: redshift() maps frequency -> z, freq() is its inverse.
    secax = ax.secondary_xaxis('top', functions=(redshift, freq))
    secax.tick_params(which='major', direction='out', width=2, length=10, labelsize=20)
    secax.tick_params(which='minor', direction='out', width=1, length=5, labelsize=20)
    secax.set_xlabel(r'$z$', fontsize=20, fontname='Baskerville')
    secax.set_xticks([5, 10, 15, 20, 30, 50])
    secax.set_xticklabels(['5', '10', '15', '20', '30', '50'], fontsize=15, fontname='Baskerville')
    # Draw each emulated/true pair in turn so neither colour is layered entirely on top.
    for emulated, true in zip(signals_emulated, signals_true):
        ax.plot(frequencies, emulated, color='r', alpha=0.1)
        ax.plot(frequencies, true, color='k', alpha=0.1)
    ax.set_ylim(-300, 50)
    ax.set_xlim(27.85, 236.74)  # ~ vr/51 and vr/6, i.e. z = 50 down to z = 5
    fig.savefig(filename, dpi=300, bbox_inches='tight', facecolor='w')
    plt.show()
    # Close this specific figure: plt.cla()/plt.clf() after show() act on the *current*
    # figure and can make pyplot create a fresh blank one instead of freeing this one.
    plt.close(fig)


_plot_realizations(frequencies_21cmGEM, signals_21cmGEM_emulated, signals_21cmGEM_true,
                   '21cmLSTM_21cmGEM_test_realizations.png')
_plot_realizations(frequencies_ARES, signals_ARES_emulated, signals_ARES_true,
                   '21cmLSTM_ARES_test_realizations.png')

# Plot histograms of test-set relative rms errors for representative instances of 21cmLSTM
# trained on the 21cmGEM and ARES sets; they should look similar to the right panel of
# Fig. 1 in DJ+24.


def _plot_rel_err_hist(rel_err, filename):
    """Histogram per-signal relative rms errors with median/mean markers; save, show, close.

    rel_err : relative rms error of each test-set signal, in percent (the unit printed
        for the relative errors in the summary).
    filename : path of the PNG written with ``Figure.savefig``.
    """
    fig, ax = plt.subplots(constrained_layout=True)
    ax.tick_params(axis='both', which='major', direction='out', width=2, length=10, labelsize=20)
    ax.set_yticks([0, 50, 100, 150, 200, 250, 300])
    ax.set_yticklabels(['0', '50', '100', '150', '200', '250', '300'],
                       fontsize=20, fontname='Baskerville')
    ax.set_xticks([0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75])
    ax.set_xticklabels(['0', '0.25', '0.50', '0.75', '1.00', '1.25', '1.50', '1.75'],
                       fontsize=20, fontname='Baskerville')
    # Histogram the per-signal error array -- not the imported error() function.
    ax.hist(rel_err, bins=20, color='blue')
    ax.axvline(np.median(rel_err), linestyle='solid', label='median (21cmLSTM)', color='orange')
    ax.axvline(np.mean(rel_err), linestyle='solid', label='mean (21cmLSTM)', color='gray')
    ax.legend(fontsize=20)
    ax.set_xlabel('Error (%)', fontsize=20, fontname='Baskerville')
    ax.set_ylabel('Counts', fontsize=20, fontname='Baskerville')
    ax.set_xlim(0, 1.9)
    fig.savefig(filename, dpi=300, bbox_inches='tight', facecolor='w')
    plt.show()
    # Close this specific figure rather than plt.cla()/plt.clf(), which act on the
    # current figure and may create a new blank one after show().
    plt.close(fig)


_plot_rel_err_hist(rel_err_21cmGEM, '21cmLSTM_21cmGEM_test_rel_err.png')
_plot_rel_err_hist(rel_err_ARES, '21cmLSTM_ARES_test_rel_err.png')
