In [None]:
import Global21cmLSTM as Global21cmLSTM
import tensorflow as tf
from matplotlib import pyplot as plt

# This code trains an instance of 21cmLSTM on the "21cmGEM" training and validation sets (Section 2.2, DJ+24),
# using the architecture (Section 2.1), preprocessing (Section 2.2), and training configuration (Section 2.3) described there.
# It then plots the training set loss and the validation set loss at each epoch,
# and prints the mean relative rms error when evaluated on the parameter values of the 1,704 signals in the test set.

# Initialize the emulator for the 21cmGEM data, then print the architecture and
# characteristics of the network that is about to be trained.
emulator_21cmGEM = Global21cmLSTM.emulator_21cmGEM.Emulate()
emulator_21cmGEM.emulator.summary()

loss_mse = 'mse'  # train with the mean squared error (MSE) loss function
learning_rate = 0.001  # the default learning rate of the Keras Adam optimizer
# NOTE(review): `Adam` holds an optimizer *instance* and shadows the class
# name; the name is kept unchanged in case later cells refer to it.
Adam = tf.keras.optimizers.Adam(learning_rate=learning_rate)
emulator_21cmGEM.emulator.compile(optimizer=Adam, loss=loss_mse)

# Training schedule: 75 epochs with batch size 10, then 25 epochs with batch
# size 1, then another 75 epochs with batch size 10 (175 epochs in total).
# Each call returns the (training, validation) loss history of that stage.
# NOTE(review): assumes each train() call continues from the current weights
# rather than re-initializing the network — confirm in Emulate.train.
train_loss_ep10_1, val_loss_ep10_1 = emulator_21cmGEM.train(epochs=75, batch_size=10)
train_loss_ep1, val_loss_ep1 = emulator_21cmGEM.train(epochs=25, batch_size=1)
train_loss_ep10_2, val_loss_ep10_2 = emulator_21cmGEM.train(epochs=75, batch_size=10)

# Join the three per-epoch histories end to end into one list. Unpacking works
# whether train() returns lists or 1-D arrays; `+` would instead add NumPy
# arrays element-wise (and fail here, since the stages differ in length).
train_loss = [*train_loss_ep10_1, *train_loss_ep1, *train_loss_ep10_2]
val_loss = [*val_loss_ep10_1, *val_loss_ep1, *val_loss_ep10_2]

# Plot the per-epoch training and validation losses on a logarithmic y-axis.
# Epochs are numbered from 1, so the x-axis matches the "Epoch number" label
# (plotting the values alone would put the first epoch at x=0).
train_epochs = range(1, len(train_loss) + 1)
val_epochs = range(1, len(val_loss) + 1)

fig, ax = plt.subplots(constrained_layout=True)
ax.plot(train_epochs, train_loss, label='Training set')
ax.plot(val_epochs, val_loss, label='Validation set')
ax.legend()
ax.set_xlabel('Epoch number')
ax.set_ylabel(r'MSE loss (mK$^2$)')
ax.set_yscale('log')
# Save this specific figure explicitly, rather than pyplot's "current" figure.
fig.savefig('LSTM_21cmGEM_loss.png', dpi=300, bbox_inches='tight', facecolor='w')
plt.show()

# Evaluate the trained emulator on the test set and report the average error.
# NOTE(review): `rmse` is presumably one relative rms error (in %) per test
# signal — confirm against Emulate.test_error.
rmse = emulator_21cmGEM.test_error()
mean_rmse = rmse.mean()
print(
    'the mean relative rms error of this trial of 21cmLSTM when trained and tested on the 21cmGEM set is:',
    mean_rmse,
    '%',
)