# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cosmoLSTM as cosmoLSTM

# Build the emulator, load pretrained weights, and emulate the global 21 cm signal
# for every parameter set in the 21cmGEM test set.
emulator = cosmoLSTM.emulator.Emulate() # initialize emulator with 21cmGEM data sets
emulator.load_model()  # load pretrained instance of 21cmLSTM emulator. Use kwarg 'model_path' to load other models.
emulator.emulator.summary()  # print the underlying network's architecture (presumably a Keras model summary — confirm)
print(emulator.par_labels) # print the input physical parameters

# Input parameters for predict(): one set of 7 values, or an array of shape (N_signals, 7).
params = emulator.par_test.copy() # grab the test set parameters: "fstar", "Vc", "fx", "tau", "alpha", "nu_min", "Rmfp"
signals_emulated = emulator.predict(params)  # emulate the global 21 cm signals
# Ground-truth test-set signals, for comparison with the emulated ones. They are copied,
# presumably so that later edits cannot touch the emulator's stored data.
signals_true = emulator.signal_test.copy()

# Overlay every emulated test-set signal (red) on its true 21cmGEM counterpart (black).
frequencies = emulator.frequencies  # array of frequencies (dz=0.1, length=451)
fig, ax = plt.subplots()
# Iterate the (emulated, true) pairs directly instead of indexing both arrays by row number.
for emulated, true in zip(signals_emulated, signals_true):
    ax.plot(frequencies, emulated, color='r', alpha=0.1)
    ax.plot(frequencies, true, color='k', alpha=0.1)
# The curves above are drawn at alpha=0.1. Empty, fully opaque proxy lines give the
# legend readable red/black entries without affecting the data limits.
ax.plot([], [], color='r', label='21cmLSTM (emulated)')
ax.plot([], [], color='k', label='21cmGEM (true)')
ax.legend()
ax.set_xlabel(r'$\nu$ (MHz)')
ax.set_ylabel(r'$\delta T_b$ (mK)')
fig.savefig('21cmLSTM_21cmGEM_run1.png', dpi=300, bbox_inches='tight', facecolor='w')
plt.show()

# Score this 21cmLSTM instance on the same 21cmGEM test set that the paper uses.
# The means printed below should land inside the range reported for 21cmLSTM when it
# was trained and tested on 21cmGEM over 20 trials (Fig. 2 in the paper).
rel_error = emulator.test_error()  # relative rms errors (reported in %)
abs_error = emulator.test_error(relative=False)  # absolute rms errors (reported in mK)
abs_error_nu_50_100 = emulator.test_error(relative=False, nu_low=50, nu_high=100)  # absolute, 50-100 MHz only

# One row per summary line: (label, per-signal errors, unit).
error_summary = (
    ('Mean relative rms error:', rel_error, '%'),
    ('Mean absolute rms error:', abs_error, 'mK'),
    ('Mean absolute rms error for 50-100 MHz:', abs_error_nu_50_100, 'mK'),
)
for label, errors, unit in error_summary:
    print(label, np.mean(errors), unit)

# Histogram of relative rms errors for this instance of 21cmLSTM. It should look similar
# to the top-right panel of Fig. 1 in the paper.
# Open a new figure explicitly. With a non-interactive backend (e.g. Agg), plt.show()
# returns without closing the previous figure, so the histogram would otherwise be drawn
# onto the signal-overlay axes.
plt.figure()
plt.hist(rel_error, bins=50)
plt.xlabel('Relative rms error (%)')
plt.ylabel('Count')
plt.show()