In [None]:

### This tutorial notebook shows how to train an instance of 21cmKAN (Dorigo Jones et al. 2025) on a physical model data set
### and how to print the trained network's accuracy metrics and relevant plots shown in DJ+25, including visualizing its learned activations

### The first section trains and tests 21cmKAN on the publicly-available 21cmGEM/21cmVAE data set (doi: 10.5281/zenodo.5084113) 
### and also trains and tests 21cmKAN on the publicly-available ARES data set (doi: 10.5281/zenodo.13840725),
### using the default architecture and training configurations explained in Section 2.3 and Figure 3 of DJ+25.
### The configs and model path can be changed using kwargs, described in the emulator GitHub scripts: emulate_21cmGEM.py or emulate_ARES.py
### The emulator can be trained on other physical models or data sets by editing the provided example emulator script: emulate_yourmodel.py

### The second section of this notebook plots results of the training and reproduces Figures B1, 7, 8, and 2 of DJ+25

import numpy as np
import torch
import os
import h5py
import Global21cmKAN as Global21cmKAN
import matplotlib as mpl
import matplotlib.cm as cm
from matplotlib import pyplot as plt
from matplotlib import rc, gridspec
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)

# create the 21cmKAN emulator instance and explicitly move the model to the correct device
# the device is CUDA when torch reports it as available and the CPU otherwise; `device` is reused by the later sections
emulator_21cmGEM = Global21cmKAN.emulate_21cmGEM.Emulate()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
emulator_21cmGEM.emulator = emulator_21cmGEM.emulator.to(device)
print('21cmGEM parameters:', emulator_21cmGEM.par_labels) # print the 21cmGEM physical model parameter names (see Table 1 of Dorigo Jones et al. 2025)

# train the model on the publicly-available 21cmGEM data set
# set the kwarg "model_path" to define where to save the model
# note that the default path overwrites the saved network downloaded from the GitHub repository
# the two returned sequences are printed and later plotted as the per-epoch training and validation losses
train_loss_21cmGEM, val_loss_21cmGEM = emulator_21cmGEM.train()
print(f"Training loss for each epoch: {train_loss_21cmGEM}")
print(f"Validation loss for each epoch: {val_loss_21cmGEM}")

# evaluate the loaded trained instance of 21cmKAN on the (unnormalized) 1,704-signal 21cmGEM test set
# print the mean and max relative rms error evaluated on the test set; set relative=False to print the absolute errors
# to reproduce Figure 4 of DJ+25 perform 20 trials, although note that standard stochasticity in the training algorithm may cause
# trials to fall outside of the published range, which is ~0.2% to ~0.32% mean error when training/testing on the 21cmGEM set
# NOTE(review): test_error() is assumed to return one relative RMSE value (in %) per test signal - confirm in emulate_21cmGEM.py
test_rel_RMSE_values_21cmGEM = emulator_21cmGEM.test_error()
# summary statistics (mean, median, max) over the returned error values
test_mean_rel_RMSE_21cmGEM = np.mean(test_rel_RMSE_values_21cmGEM)
test_median_rel_RMSE_21cmGEM = np.median(test_rel_RMSE_values_21cmGEM)
test_max_rel_RMSE_21cmGEM = np.max(test_rel_RMSE_values_21cmGEM)
print(f"Test set mean relative emulation RMSE (%): {test_mean_rel_RMSE_21cmGEM}")
print(f"Test set median relative emulation RMSE (%): {test_median_rel_RMSE_21cmGEM}")
print(f"Test set max relative emulation RMSE (%): {test_max_rel_RMSE_21cmGEM}")

# now train and test 21cmKAN on the ARES data set
# same sequence as for 21cmGEM above: build the emulator, move it to `device`, train, then evaluate on the test set
emulator_ARES = Global21cmKAN.emulate_ARES.Emulate()
emulator_ARES.emulator = emulator_ARES.emulator.to(device)
print('ARES parameters:', emulator_ARES.par_labels)
# the returned losses are not printed here; they are plotted in the loss-curve section below
train_loss_ARES, val_loss_ARES = emulator_ARES.train()
# NOTE(review): test_error() is assumed to return one relative RMSE value (in %) per test signal - confirm in emulate_ARES.py
test_rel_RMSE_values_ARES = emulator_ARES.test_error()
test_mean_rel_RMSE_ARES = np.mean(test_rel_RMSE_values_ARES)
test_median_rel_RMSE_ARES = np.median(test_rel_RMSE_values_ARES)
test_max_rel_RMSE_ARES = np.max(test_rel_RMSE_values_ARES)
# these messages are identical to the 21cmGEM ones above; the values printed here refer to the ARES test set
print(f"Test set mean relative emulation RMSE (%): {test_mean_rel_RMSE_ARES}")
print(f"Test set median relative emulation RMSE (%): {test_median_rel_RMSE_ARES}")
print(f"Test set max relative emulation RMSE (%): {test_max_rel_RMSE_ARES}")

################################################################################################################################
################################################################################################################################
# plot the MSE loss for each training epoch of 21cmKAN evaluated on the normalized training and validation set signals
# to reproduce Figure B1 of DJ+25 perform 20 trials training 21cmKAN on the 21cmGEM and ARES sets and plot them below,
# along with the mean loss curves and secondary axis showing the approximate training time
rc('figure', figsize=(8.0, 6.0))
plt.rcParams['mathtext.fontset'] = 'cm'
mpl.rc('font',family='Baskerville')

# derive the epoch axes from the recorded losses instead of hard-coding 400/800 epochs, so the plot still works
# when the number of training epochs is changed through the train() kwargs
epochs_21cmGEM = np.arange(1, len(train_loss_21cmGEM) + 1, dtype=float)
epochs_ARES = np.arange(1, len(train_loss_ARES) + 1, dtype=float)
max_epoch = max(len(epochs_21cmGEM), len(epochs_ARES))
epoch_ticks = np.arange(0, max_epoch + 1, 100) # one major tick every 100 epochs (0, 100, ..., 800 for the default configs)

fig, ax = plt.subplots(constrained_layout=True)
ax.set_xlabel('epoch number', size=20, fontname= 'Baskerville')
ax.set_ylabel(r'MSE loss (mK$^2$)', size=20, fontname= 'Baskerville')
ax.plot(epochs_21cmGEM, train_loss_21cmGEM, color='r', alpha=1.0, label=r'training loss, trained on ${\tt 21cmGEM}$')
ax.plot(epochs_21cmGEM, val_loss_21cmGEM, color='k', alpha=1.0, label=r'validation loss, trained on ${\tt 21cmGEM}$')
ax.plot(epochs_ARES, train_loss_ARES, color='orange', alpha=1.0, label=r'training loss, trained on ${\tt ARES}$')
ax.plot(epochs_ARES, val_loss_ARES, color='purple', alpha=1.0, label=r'validation loss, trained on ${\tt ARES}$')

# the log scale must be selected BEFORE customizing the ticks: set_yscale installs the scale's default locator and
# formatter, which would silently discard any tick positions/labels set earlier
ax.set_yscale('log')
ax.tick_params(which='major', direction = 'out', width = 2, length = 10, labelsize=20)
ax.set_xticks(epoch_ticks)
ax.set_xticklabels([str(tick) for tick in epoch_ticks], fontsize=20, fontname= 'Baskerville')
ax.set_yticks([1e-6, 1e-5, 1e-4, 1e-3, 1e-2])
# the labels need $...$ delimiters to be rendered as math (powers of ten) rather than as literal text
ax.set_yticklabels([r'$10^{-6}$', r'$10^{-5}$', r'$10^{-4}$', r'$10^{-3}$', r'$10^{-2}$'], fontsize=20, fontname= 'Baskerville')
ax.set_xlim(0, max_epoch)
ax.set_ylim(7e-7,2e-2)
ax.legend(loc='upper right', fontsize=17)
fig.savefig('21cmKAN_default_21cmGEM_ARES_train_val_loss.png', dpi = 300, bbox_inches='tight', facecolor='w')
plt.show()

# release the figure; plt.cla()/plt.clf() would instead create and leave behind a new empty figure once this one is shown
plt.close(fig)

################################################################################################################################
# Figure 7 of DJ+25: first-hidden-layer activation functions for each input physical parameter, as learned by a
# trial of 21cmKAN trained on the 21cmGEM set, evaluated at the parameter values of the 24,562 training set signals

# location of the saved network; if model_path was changed when training above, point to that file instead
PATH = f"{os.environ.get('AUX_DIR', os.environ.get('HOME'))}/.Global21cmKAN/"
model_save_path_21cmGEM = f"{PATH}models/emulator_21cmGEM.pth"
# weights_only=False unpickles the complete model object, so only load model files from a trusted source
model_21cmGEM = torch.load(model_save_path_21cmGEM, weights_only=False, map_location=device)
# put the model in evaluation mode (the module repr it returns is only displayed when this is the last line of a cell)
model_21cmGEM.eval()

# tally the parameter counts in a single pass over the model's parameters
pytorch_total_params = 0
pytorch_trainable_params = 0
for model_param in model_21cmGEM.parameters():
    param_size = model_param.numel()
    pytorch_total_params += param_size
    if model_param.requires_grad:
        pytorch_trainable_params += param_size
print('total number of parameters in 21cmKAN model:', pytorch_total_params)
print('total number of trainable parameters in 21cmKAN model:', pytorch_trainable_params)

# the first hidden layer of the trained KAN is the one whose learned activation functions are visualized
kan_layer1 = model_21cmGEM.layers[0]
print('PyTorch and KAN attributes:', dir(kan_layer1))

# build the input tensor from a copy of the emulator's training parameters and place it on the model's device
proc_params_train_21cmGEM = torch.from_numpy(emulator_21cmGEM.par_train.copy()).to(device)
proc_params_train_21cmGEM_np = 0 # the intermediate numpy name is reset to 0, as in the original workflow

print("Input tensor shape:", proc_params_train_21cmGEM.shape) # shape: [24562, 7] = [num_samples, num_features]

# evaluate the B-spline basis functions at each input feature, for the 21cmGEM training set parameter values,
# and combine them with the learned spline weights of the first hidden layer
# everything runs under no_grad: spline_weight is a trainable parameter, so contracting it with the basis values
# outside no_grad would make autograd record a graph that this visualization never uses
with torch.no_grad():
    spline_basis_values1 = kan_layer1.b_splines(proc_params_train_21cmGEM) # shape: [24562, 7, 10] = [num_samples, num_features, num_basis_functions_per_feature]
    # learned weights for each B-spline basis function in the first hidden layer, for all features and nodes
    spline_weights1 = kan_layer1.spline_weight # shape: [44, 7, 10] = [num_nodes, num_features, num_basis_functions_per_feature]
    # activation function outputs: dot product of b_splines(x) and spline_weight along the
    # 'number of basis functions per feature' dimension
    y_full1 = torch.einsum("nsb,osb->nos", spline_basis_values1, spline_weights1)  # Shape: [num_samples, num_nodes, num_features]
    # sum across all nodes to get aggregated activation functions
    y_values1 = y_full1.sum(dim=1) # Shape: [num_samples, num_features]

# read every dimension from the tensors themselves rather than from a leftover loop index
num_features = proc_params_train_21cmGEM.shape[1] # 7 input features
num_basis_functions = spline_basis_values1.shape[2] # 10 basis functions per feature
num_nodes = spline_weights1.shape[0] # number of nodes in the first hidden layer

# Compute the learned B-spline outputs by weighting the basis functions
# for each parameter: extract the evaluated basis functions for all samples, extract the spline weights for all nodes
# finally, use np.einsum to compute the learned B-spline outputs per sample and node
basis_vals1 = []
basis_weights1 = []
learned_splines1 = []
for i in range(num_features):
    basis_vals1.append(spline_basis_values1[:, i, :].detach().cpu().numpy()) # Shape: [num_features, num_samples, num_basis_functions_per_feature]
    basis_weights1.append(spline_weights1[:, i, :].detach().cpu().numpy()) # Shape: [num_features, num_nodes, num_basis_functions_per_feature]
    learned_splines1.append(np.einsum("sb,nb->sn", basis_vals1[i], basis_weights1[i]))  # Shape: [num_features, num_samples, num_nodes]

# largest absolute spline coefficient of every node, for every input parameter: list of num_features lists of num_nodes values
# computed for however many input features the model has, instead of one hand-written list per parameter
max_basis_coeffs_allparams_allnodes = [list(np.max(np.abs(node_weights), axis=1)) for node_weights in basis_weights1]

# read the unnormalized 21cmGEM training-set parameters; the with-block closes the file on exit,
# so no explicit f.close() is needed afterwards
with h5py.File(PATH + 'dataset_21cmGEM.h5', "r") as f:
    print("Keys: %s" % f.keys())
    par_train = np.asarray(f['par_train'])[()]

# the first three parameters are plotted in log10; np.log10 returns a new array, so par_train itself is never modified
unproc_f_s_train = np.log10(par_train[:,0]) # f_*, star formation efficiency
unproc_V_c_train = np.log10(par_train[:,1]) # V_c, minimum circular velocity of star-forming halos
unproc_f_X_train = par_train[:,2].copy() # f_X, X-ray efficiency of sources (copied because it is edited in place below)
unproc_f_X_train[unproc_f_X_train == 0] = 1e-6 # for f_X, set zero values to 1e-6 before taking log_10
unproc_f_X_train = np.log10(unproc_f_X_train)

# assemble the table used for the x-axes of the activation plots: log10 columns first, remaining columns unchanged
parameters_log_train = np.empty(par_train.shape)
parameters_log_train[:,0] = unproc_f_s_train
parameters_log_train[:,1] = unproc_V_c_train
parameters_log_train[:,2] = unproc_f_X_train
parameters_log_train[:,3:] = par_train[:,3:]

# x-axis limits and labels of the seven 21cmGEM parameters, in the same column order as parameters_log_train
plot_limits_21cmGEM = [[-4,np.log10(5e-1)], [np.log10(4.2),2], [-6,3], [0.04,0.2], [1,1.5], [0.1,3.0], [10,50]] # size : [num_features, 2]
param_labels_21cmGEM = [r'$\log_{10} f_*$', r'$\log_{10} V_c$', r'$\log_{10} f_X$', r'$\tau$', r'$\alpha$',
                        r'$\nu_{\rm min}$', r'$R_{\rm mfp}$']
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap takes the same (name, lut) arguments
cmap_nodes = plt.get_cmap('plasma', num_nodes)
cmap_basis_functions = plt.get_cmap('rainbow', num_basis_functions)
fig, axes = plt.subplots(nrows=3, ncols=num_features, figsize=(4*num_features, 10), sharex='col', sharey='row')
plt.subplots_adjust(wspace=-0.3, hspace=-0.3)

# per-row settings: row 0 = summed activations, row 1 = per-node activations, row 2 = B-spline basis functions
row_ylims = [[-1.5,1.5], [-0.6,0.6], [0,0.7]]
row_ylabels = [r'$\sum{\phi(x)}$', r'$\phi(x)$', r'$B(x)$']

for i in range(num_features):
    label = param_labels_21cmGEM[i]
    feature_vals = parameters_log_train[:, i]
    y_np1 = y_values1[:, i].detach().cpu().numpy()
    axes[0, i].scatter(feature_vals, y_np1, label=f"Spline {i}", alpha=0.5, s=5, color='k')
    for node in range(num_nodes):
        # colour each node's curve by its largest absolute spline coefficient: a float passed to a colormap is
        # read on the [0, 1] colour scale, so nodes with larger coefficients get colours further along 'plasma'
        color_nodes = cmap_nodes(max_basis_coeffs_allparams_allnodes[i][node])
        axes[1, i].scatter(feature_vals, learned_splines1[i][:, node], label=f"Node {node+1}", color=color_nodes, s=5, alpha=0.8)
    for j in range(num_basis_functions):
        color_basis_functions = cmap_basis_functions(j) # an integer indexes the colormap's lookup table: one colour per basis function
        axes[2, i].scatter(feature_vals, basis_vals1[i][:, j], label=f"Basis {j+1}", color=color_basis_functions, s=5, alpha=1)
    # identical tick/limit formatting for the three panels of this column
    for row in range(3):
        panel = axes[row, i]
        panel.xaxis.set_minor_locator(AutoMinorLocator())
        panel.tick_params(axis='both', which='major', labelsize=25, width=3, length=8)
        panel.tick_params(axis='x', which='minor', labelsize=15, width=3, length=4)
        panel.set_ylim(row_ylims[row])
        panel.set_xlim(plot_limits_21cmGEM[i])
        if i == 0:
            # only the left-most column carries the row labels
            panel.set_ylabel(row_ylabels[row], fontsize=32, rotation=0, labelpad=35)
    axes[2, i].set_xlabel(label, fontsize=32)

plt.tight_layout()
plt.savefig('21cmKAN_default_21cmGEM_activations_basis_functions_layer1.png', dpi = 300, bbox_inches='tight', facecolor='w')
plt.show()

# release the figure; plt.cla()/plt.clf() would instead create and leave behind a new empty figure once this one is shown
plt.close(fig)

################################################################################################################################
# Figure 8 of DJ+25: first-hidden-layer activation functions for each input physical parameter, as learned by a
# trial of 21cmKAN trained on the ARES set, evaluated at the parameter values of the 23,896 training set signals

# location of the saved network; if model_path was changed when training above, point to that file instead
model_save_path_ARES = f"{PATH}models/emulator_ARES.pth"
# weights_only=False unpickles the complete model object, so only load model files from a trusted source
model_ARES = torch.load(model_save_path_ARES, weights_only=False, map_location=device)
# put the model in evaluation mode (the module repr it returns is only displayed when this is the last line of a cell)
model_ARES.eval()

# tally the parameter counts in a single pass over the model's parameters
pytorch_total_params = 0
pytorch_trainable_params = 0
for model_param in model_ARES.parameters():
    param_size = model_param.numel()
    pytorch_total_params += param_size
    if model_param.requires_grad:
        pytorch_trainable_params += param_size
print('total number of parameters in 21cmKAN model:', pytorch_total_params)
print('total number of trainable parameters in 21cmKAN model:', pytorch_trainable_params)

# the first hidden layer of the trained KAN is the one whose learned activation functions are visualized
kan_layer1 = model_ARES.layers[0]
print('PyTorch and KAN attributes:', dir(kan_layer1))

# build the input tensor from a copy of the emulator's training parameters and place it on the model's device
proc_params_train_ARES = torch.from_numpy(emulator_ARES.par_train.copy()).to(device)
proc_params_train_ARES_np = 0 # the intermediate numpy name is reset to 0, as in the original workflow

print("Input tensor shape:", proc_params_train_ARES.shape) # shape: [23896, 8] = [num_samples, num_features]

# evaluate the B-spline basis functions at each input feature, for the ARES training set parameter values,
# and combine them with the learned spline weights of the first hidden layer
# everything runs under no_grad: spline_weight is a trainable parameter, so contracting it with the basis values
# outside no_grad would make autograd record a graph that this visualization never uses
with torch.no_grad():
    spline_basis_values1 = kan_layer1.b_splines(proc_params_train_ARES) # shape: [23896, 8, 10] = [num_samples, num_features, num_basis_functions_per_feature]
    # learned weights for each B-spline basis function in the first hidden layer, for all features and nodes
    spline_weights1 = kan_layer1.spline_weight # shape: [44, 8, 10] = [num_nodes, num_features, num_basis_functions_per_feature]
    # activation function outputs: dot product of b_splines(x) and spline_weight along the
    # 'number of basis functions per feature' dimension
    y_full1 = torch.einsum("nsb,osb->nos", spline_basis_values1, spline_weights1)  # Shape: [num_samples, num_nodes, num_features]
    # sum across all nodes to get aggregated activation functions
    y_values1 = y_full1.sum(dim=1) # Shape: [num_samples, num_features]

# read every dimension from the tensors themselves rather than from a leftover loop index
num_features = proc_params_train_ARES.shape[1] # 8 input features
num_basis_functions = spline_basis_values1.shape[2] # 10 basis functions per feature
num_nodes = spline_weights1.shape[0] # number of nodes in the first hidden layer

# Compute the learned B-spline outputs by weighting the basis functions
# for each parameter: extract the evaluated basis functions for all samples, extract the spline weights for all nodes
# finally, use np.einsum to compute the learned B-spline outputs per sample and node
basis_vals1 = []
basis_weights1 = []
learned_splines1 = []
for i in range(num_features):
    basis_vals1.append(spline_basis_values1[:, i, :].detach().cpu().numpy()) # Shape: [num_features, num_samples, num_basis_functions_per_feature]
    basis_weights1.append(spline_weights1[:, i, :].detach().cpu().numpy()) # Shape: [num_features, num_nodes, num_basis_functions_per_feature]
    learned_splines1.append(np.einsum("sb,nb->sn", basis_vals1[i], basis_weights1[i]))  # Shape: [num_features, num_samples, num_nodes]

# largest absolute spline coefficient of every node, for every input parameter: list of num_features lists of num_nodes values
max_basis_coeffs_allparams_allnodes = [list(np.max(np.abs(node_weights), axis=1)) for node_weights in basis_weights1]
# per-parameter names for the eight ARES parameters, kept because cells after this section may still refer to them;
# each name is bound to the same list object stored in max_basis_coeffs_allparams_allnodes
(max_basis_coeffs_par0, max_basis_coeffs_par1, max_basis_coeffs_par2, max_basis_coeffs_par3,
 max_basis_coeffs_par4, max_basis_coeffs_par5, max_basis_coeffs_par6, max_basis_coeffs_par7) = max_basis_coeffs_allparams_allnodes

# read the unnormalized ARES training-set parameters; the with-block closes the file on exit,
# so no explicit f.close() is needed afterwards
# NOTE(review): 'train_data' is assumed to hold the 8-column parameter table - check against the keys printed below
with h5py.File(PATH + 'dataset_ARES.h5', "r") as f:
    print("Keys: %s" % f.keys())
    par_train = np.asarray(f['train_data'])[()]

# four parameters are plotted in log10; np.log10 returns a new array, so par_train itself is never modified
unproc_c_X_train = np.log10(par_train[:,0]) # c_X, normalization of X-ray luminosity-SFR relation
unproc_T_min_train = np.log10(par_train[:,2]) # T_min, minimum temperature of star-forming halos
unproc_f_s_train = np.log10(par_train[:,4]) # f_*,0, peak star formation efficiency
unproc_M_p_train = np.log10(par_train[:,5]) # M_p, dark matter halo mass at f_*,0

# assemble the table used for the x-axes of the activation plots: log10 columns where needed, other columns unchanged
parameters_log_train = np.empty(par_train.shape)
parameters_log_train[:,0] = unproc_c_X_train
parameters_log_train[:,1] = par_train[:,1]
parameters_log_train[:,2] = unproc_T_min_train
parameters_log_train[:,3] = par_train[:,3]
parameters_log_train[:,4] = unproc_f_s_train
parameters_log_train[:,5] = unproc_M_p_train
parameters_log_train[:,6] = par_train[:,6]
parameters_log_train[:,7] = par_train[:,7]

# x-axis limits and labels of the eight ARES parameters, in the same column order as parameters_log_train
plot_limits_ARES = [[36,44], [0,1], [np.log10(3e2), np.log10(5e5)], [18,23], [-5,0], [8,15], [0,2], [-4,0]] # size : [num_features, 2]
param_labels_ARES = [r'$\log_{10}{c_X}$', r'$f_{\rm esc}$', r'$\log_{10}{T_{\rm min}}$', r'$\log_{10}{N_{H I}}$',
                     r'$\log_{10}{f_{\star, 0}}$',r'$\log_{10}{M_p}$',r'$\gamma_{lo}$', r'$\gamma_{hi}$']
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap takes the same (name, lut) arguments
cmap_nodes = plt.get_cmap('plasma', num_nodes)
cmap_basis_functions = plt.get_cmap('rainbow', num_basis_functions)
fig, axes = plt.subplots(nrows=3, ncols=num_features, figsize=(4*num_features, 10), sharex='col', sharey='row')
plt.subplots_adjust(wspace=-0.3, hspace=-0.3)

# per-row settings: row 0 = summed activations, row 1 = per-node activations, row 2 = B-spline basis functions
row_ylims = [[-1.5,1.5], [-0.6,0.6], [0,0.7]]
row_ylabels = [r'$\sum{\phi(x)}$', r'$\phi(x)$', r'$B(x)$']

for i in range(num_features):
    label = param_labels_ARES[i]
    feature_vals = parameters_log_train[:, i]
    y_np1 = y_values1[:, i].detach().cpu().numpy()
    axes[0, i].scatter(feature_vals, y_np1, label=f"Spline {i}", alpha=0.5, s=5, color='k')
    for node in range(num_nodes):
        # colour each node's curve by its largest absolute spline coefficient: a float passed to a colormap is
        # read on the [0, 1] colour scale, so nodes with larger coefficients get colours further along 'plasma'
        color_nodes = cmap_nodes(max_basis_coeffs_allparams_allnodes[i][node])
        axes[1, i].scatter(feature_vals, learned_splines1[i][:, node], label=f"Node {node+1}", color=color_nodes, s=5, alpha=0.8)
    for j in range(num_basis_functions):
        color_basis_functions = cmap_basis_functions(j) # an integer indexes the colormap's lookup table: one colour per basis function
        axes[2, i].scatter(feature_vals, basis_vals1[i][:, j], label=f"Basis {j+1}", color=color_basis_functions, s=5, alpha=1)
    # identical tick/limit formatting for the three panels of this column
    for row in range(3):
        panel = axes[row, i]
        panel.xaxis.set_minor_locator(AutoMinorLocator())
        panel.tick_params(axis='both', which='major', labelsize=25, width=3, length=8)
        panel.tick_params(axis='x', which='minor', labelsize=15, width=3, length=4)
        panel.set_ylim(row_ylims[row])
        panel.set_xlim(plot_limits_ARES[i])
        if i == 0:
            # only the left-most column carries the row labels
            panel.set_ylabel(row_ylabels[row], fontsize=32, rotation=0, labelpad=35)
    axes[2, i].set_xlabel(label, fontsize=32)

plt.tight_layout()
plt.savefig('21cmKAN_default_ARES_activations_basis_functions_layer1.png', dpi = 300, bbox_inches='tight', facecolor='w')
plt.show()

# release the figure; plt.cla()/plt.clf() would instead create and leave behind a new empty figure once this one is shown
plt.close(fig)

################################################################################################################################
# plot examples of the learned components and signal emulation of 21cmKAN when trained on the ARES data set
# reproduces Figure 2 of DJ+25

# ARES parameter values of the signal fit in DJ+23 and DJ+25, in the same order as the columns of parameters_log_train
par_ARES_signal = [2.6e39, 0.2, 1e4, 21., 0.05, 2.8e11, 0.49, -0.61]
# columns 0, 2, 4 and 5 are the ones expressed in log10 for plotting; all other entries are used as they are
log10_param_indices = (0, 2, 4, 5)
par_ARES_signal_log = [np.log10(par_value) if par_idx in log10_param_indices else par_value
                       for par_idx, par_value in enumerate(par_ARES_signal)]
highlight_param_idx = 2 # choose which parameter to plot/highlight (using T_min as an example)
highlight_value = par_ARES_signal_log[highlight_param_idx]

# locate the training sample whose value of the chosen parameter lies closest to the highlight value
feature_vals = parameters_log_train[:, highlight_param_idx]
param_diffs = np.abs(feature_vals - highlight_value)
closest_sample_idx = np.argmin(param_diffs)
highlight_sample_value = feature_vals[closest_sample_idx]
parameter_labels = [
    r'$\log_{10}{c_X}$',
    r'$f_{\rm esc}$',
    r'$\log_{10}{T_{\rm min}}$',
    r'$\log_{10}{N_{H I}}$',
    r'$\log_{10}{f_{\star, 0}}$',
    r'$\log_{10}{M_p}$',
    r'$\gamma_{lo}$',
    r'$\gamma_{hi}$',
]
vr = 1420.405751 # rest-frame frequency of the 21-cm line (MHz)


def freq(zs):
    """Return the observed frequency of the 21-cm line at redshift(s) `zs`, i.e. vr / (1 + zs).

    Works element-wise on numpy arrays as well as on plain numbers.
    """
    one_plus_z = zs + 1
    return vr / one_plus_z


def redshift(v):
    """Return the redshift at which the 21-cm line is observed at frequency `v`, i.e. vr / v - 1.

    Inverse of `freq`; `v` is in the same units as `vr`.
    """
    frequency_ratio = vr / v
    return frequency_ratio - 1

# load the saved ARES-trained network and emulate the signal for the fit parameters
# (the list passed here is par_ARES_signal, i.e. the values before the log10 conversion)
emulator_ARES.load_model()
T_KAN = emulator_ARES.predict(par_ARES_signal)

# 449 evenly spaced redshifts from 5.1 to 49.9 and their corresponding frequencies;
# used as the x-axis for both T_KAN and the 449 tabulated T_ARES values below
z_list = np.linspace(start=5.1, stop=49.9, num=449)
nu_list = freq(z_list)
# the ARES signal for this fit is in neither the training nor the test set, so its dT_b values are
# tabulated here to avoid having to install and run ARES
T_ARES = np.array([3.31441337e-03,2.18885571e-03,9.69876261e-04,-3.47769970e-04,-1.76936551e-03,-3.30031724e-03,-4.94616668e-03,\
                   -6.71236658e-03,-8.60443097e-03,-1.06277973e-02,-1.27877286e-02,-1.50894416e-02,-1.75377514e-02,-2.01372673e-02,\
                   -2.28920539e-02,-2.58057436e-02,-2.88814461e-02,-3.21216216e-02,-3.55276550e-02,-3.91003115e-02,-8.44328476e-01,\
                   -3.69581203e+00,-6.88031946e+00,-1.03933132e+01,-1.42155707e+01,-1.83334602e+01,-2.27340764e+01,-2.74042637e+01,\
                   -3.23284893e+01,-3.74902613e+01,-4.28714231e+01,-4.84519742e+01,-5.42107445e+01,-6.01247774e+01,-6.61698689e+01,\
                   -7.23211509e+01,-7.85528531e+01,-8.48384646e+01,-9.11513004e+01,-9.74647478e+01,-1.03752457e+02,-1.09988724e+02,\
                   -1.16148600e+02,-1.22207993e+02,-1.28144721e+02,-1.33937733e+02,-1.39567461e+02,-1.45016355e+02,-1.50269156e+02,\
                   -1.55311658e+02,-1.60132725e+02,-1.64721397e+02,-1.69070815e+02,-1.73174817e+02,-1.77029333e+02,-1.80630686e+02,\
                   -1.83978490e+02,-1.87073623e+02,-1.89916909e+02,-1.92510594e+02,-1.94859433e+02,-1.96967595e+02,-1.98840348e+02,\
                   -2.00483914e+02,-2.01904101e+02,-2.03106410e+02,-2.04100919e+02,-2.04891063e+02,-2.05487554e+02,-2.05895372e+02,\
                   -2.06121930e+02,-2.06175300e+02,-2.06062211e+02,-2.05789295e+02,-2.05363287e+02,-2.04791447e+02,-2.04078226e+02,\
                   -2.03230221e+02,-2.02254708e+02,-2.01155044e+02,-1.99937926e+02,-1.98607600e+02,-1.97169935e+02,-1.95627171e+02,\
                   -1.93985915e+02,-1.92249613e+02,-1.90421987e+02,-1.88506639e+02,-1.86508002e+02,-1.84428638e+02,-1.82271965e+02,\
                   -1.80041335e+02,-1.77740027e+02,-1.75371329e+02,-1.72938443e+02,-1.70443046e+02,-1.67888481e+02,-1.65278987e+02,\
                   -1.62616169e+02,-1.59904067e+02,-1.57144185e+02,-1.54340042e+02,-1.51493409e+02,-1.48610951e+02,-1.45692184e+02,\
                   -1.42740961e+02,-1.39760930e+02,-1.36756688e+02,-1.33728768e+02,-1.30681681e+02,-1.27621361e+02,-1.24546615e+02,\
                   -1.21465222e+02,-1.18377934e+02,-1.15288868e+02,-1.12203149e+02,-1.09121496e+02,-1.06051430e+02,-1.02992271e+02,\
                   -9.99516790e+01,-9.69290333e+01,-9.39317288e+01,-9.09594781e+01,-8.80179124e+01,-8.51114508e+01,-8.22398208e+01,\
                   -7.94078096e+01,-7.66174443e+01,-7.38737758e+01,-7.11759347e+01,-6.85276680e+01,-6.59336001e+01,-6.33900067e+01,\
                   -6.09061534e+01,-5.84759605e+01,-5.61085968e+01,-5.37976678e+01,-5.15540860e+01,-4.93673869e+01,-4.72478491e+01,\
                   -4.51857012e+01,-4.31936354e+01,-4.12616718e+01,-3.93905215e+01,-3.75888366e+01,-3.58449059e+01,-3.41654676e+01,\
                   -3.25480086e+01,-3.09887603e+01,-2.94939720e+01,-2.80542477e+01,-2.66756569e+01,-2.53531963e+01,-2.40829107e+01,\
                   -2.28727648e+01,-2.17103464e+01,-2.06034419e+01,-1.95451017e+01,-1.85339409e+01,-1.75774180e+01,-1.66654310e+01,\
                   -1.58007083e+01,-1.49770002e+01,-1.41948406e+01,-1.34534020e+01,-1.27479483e+01,-1.20827886e+01,-1.14493467e+01,\
                   -1.08534802e+01,-1.02871036e+01,-9.75407566e+00,-9.24948703e+00,-8.77361792e+00,-8.32585261e+00,-7.90197137e+00,\
                   -7.50584885e+00,-7.13062459e+00,-6.78002282e+00,-6.44933577e+00,-6.14055552e+00,-5.84934774e+00,-5.57860438e+00,\
                   -5.32383422e+00,-5.08637031e+00,-4.86428648e+00,-4.65791762e+00,-4.46484096e+00,-4.28563440e+00,-4.11899180e+00,\
                   -3.96493747e+00,-3.82161615e+00,-3.68946053e+00,-3.56695801e+00,-3.45502219e+00,-3.35168533e+00,-3.25723441e+00,\
                   -3.17063075e+00,-3.09190848e+00,-3.02046593e+00,-2.95621369e+00,-2.89787484e+00,-2.84605151e+00,-2.79983594e+00,\
                   -2.75939252e+00,-2.72355183e+00,-2.69307938e+00,-2.66706612e+00,-2.64528739e+00,-2.62770427e+00,-2.61884966e+00,\
                   -2.61969814e+00,-2.62420632e+00,-2.63203005e+00,-2.64308846e+00,-2.65727427e+00,-2.67431703e+00,-2.69390212e+00,\
                   -2.71633386e+00,-2.74116587e+00,-2.76815384e+00,-2.79745038e+00,-2.82889134e+00,-2.86230884e+00,-2.89754350e+00,\
                   -2.93461471e+00,-2.97345383e+00,-3.01384831e+00,-3.05590316e+00,-3.09940664e+00,-3.14427957e+00,-3.19061284e+00,\
                   -3.23819785e+00,-3.28698934e+00,-3.33705621e+00,-3.38823912e+00,-3.44055267e+00,-3.49396020e+00,-3.54834631e+00,\
                   -3.60370309e+00,-3.66006137e+00,-3.71732039e+00,-3.77549010e+00,-3.83451913e+00,-3.89442160e+00,-3.95508506e+00,\
                   -4.01659572e+00,-4.07882134e+00,-4.14185842e+00,-4.20556584e+00,-4.27005306e+00,-4.33520373e+00,-4.40104118e+00,\
                   -4.47032868e+00,-4.54222593e+00,-4.61481212e+00,-4.68807223e+00,-4.76197333e+00,-4.83654242e+00,-4.91177537e+00,\
                   -4.98762992e+00,-5.06410452e+00,-5.14117887e+00,-5.21889178e+00,-5.29719856e+00,-5.37609948e+00,-5.45560930e+00,\
                   -5.53568163e+00,-5.61632465e+00,-5.69754808e+00,-5.77935810e+00,-5.86169882e+00,-5.94462838e+00,-6.02808514e+00,\
                   -6.11208572e+00,-6.19666103e+00,-6.28174936e+00,-6.36739091e+00,-6.45354983e+00,-6.54023681e+00,-6.62745942e+00,\
                   -6.71521202e+00,-6.80347471e+00,-6.89226414e+00,-6.98155726e+00,-7.07136018e+00,-7.16166504e+00,-7.25247050e+00,\
                   -7.34380431e+00,-7.43561776e+00,-7.52791446e+00,-7.62043966e+00,-7.71301403e+00,-7.80606522e+00,-7.89959273e+00,\
                   -7.99360004e+00,-8.08810046e+00,-8.18305443e+00,-8.27847608e+00,-8.37437891e+00,-8.47073211e+00,-8.56754597e+00,\
                   -8.66483414e+00,-8.76257282e+00,-8.86076332e+00,-8.95940256e+00,-9.05848630e+00,-9.15801387e+00,-9.25798828e+00,\
                   -9.35839643e+00,-9.45924010e+00,-9.56052356e+00,-9.66224933e+00,-9.76439046e+00,-9.86695242e+00,-9.96993184e+00,\
                   -1.00733387e+01,-1.01771646e+01,-1.02814023e+01,-1.03860444e+01,-1.04910928e+01,-1.05965487e+01,-1.07024137e+01,\
                   -1.08086651e+01,-1.09153068e+01,-1.10223474e+01,-1.11263201e+01,-1.12271461e+01,-1.13283198e+01,-1.14298452e+01,\
                   -1.15317150e+01,-1.16339279e+01,-1.17364752e+01,-1.18393681e+01,-1.19425950e+01,-1.20461527e+01,-1.21500393e+01,\
                   -1.22542476e+01,-1.23587778e+01,-1.24636267e+01,-1.25687950e+01,-1.26742721e+01,-1.27800584e+01,-1.28861506e+01,\
                   -1.29925461e+01,-1.30992404e+01,-1.32062307e+01,-1.33135139e+01,-1.34210899e+01,-1.35289489e+01,-1.36370904e+01,\
                   -1.37455097e+01,-1.38542102e+01,-1.39631817e+01,-1.40724218e+01,-1.41819268e+01,-1.42916891e+01,-1.44017205e+01,\
                   -1.45120005e+01,-1.46225292e+01,-1.47333066e+01,-1.48443366e+01,-1.49555981e+01,-1.50671122e+01,-1.51788462e+01,\
                   -1.52908103e+01,-1.54030141e+01,-1.55154224e+01,-1.56280772e+01,-1.57409217e+01,-1.58540016e+01,-1.59672889e+01,\
                   -1.60807877e+01,-1.61944888e+01,-1.63083700e+01,-1.64224898e+01,-1.65368979e+01,-1.66515909e+01,-1.67662838e+01,\
                   -1.68809768e+01,-1.69960572e+01,-1.71112424e+01,-1.72268233e+01,-1.73424042e+01,-1.74579850e+01,-1.75738588e+01,\
                   -1.76898315e+01,-1.78060100e+01,-1.79206405e+01,-1.80324657e+01,-1.81399194e+01,-1.82457112e+01,-1.83516081e+01,\
                   -1.84576486e+01,-1.85637806e+01,-1.86699859e+01,-1.87763253e+01,-1.88827593e+01,-1.89892853e+01,-1.90958964e+01,\
                   -1.92025764e+01,-1.93093666e+01,-1.94162455e+01,-1.95232030e+01,-1.96302362e+01,-1.97373405e+01,-1.98445014e+01,\
                   -1.99517417e+01,-2.00590499e+01,-2.01664181e+01,-2.02738429e+01,-2.03813208e+01,-2.04888396e+01,-2.05964114e+01,\
                   -2.07040317e+01,-2.08116926e+01,-2.09193909e+01,-2.10271224e+01,-2.11348677e+01,-2.12426552e+01,-2.13504710e+01,\
                   -2.14583079e+01,-2.15661626e+01,-2.16740305e+01,-2.17818796e+01,-2.18897633e+01,-2.19976567e+01,-2.21055517e+01,\
                   -2.22134451e+01,-2.23213337e+01,-2.24292129e+01,-2.25370609e+01,-2.26449158e+01,-2.27527530e+01,-2.28605695e+01,\
                   -2.29683620e+01,-2.30761272e+01,-2.31838594e+01,-2.32915444e+01,-2.33992086e+01,-2.35068329e+01,-2.36144142e+01,\
                   -2.37219489e+01,-2.38294264e+01,-2.39368120e+01,-2.40436858e+01,-2.41468583e+01,-2.42441648e+01,-2.43414169e+01, -2.44386121e+01])

def _scatter_with_highlight(ax, x_vals, curves, n_curves, sample_idx, label_prefix,
                            xlabel, xlim, ylabel, ylabel_pad, title, ylim):
    """Scatter a family of curves in black and mark one sample on every curve in red.

    Shared by the basis-function and per-edge-activation panels, which differ only in
    the data plotted and in their labels/limits.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x_vals : 1D array of x values, one per sample.
    curves : 2D array indexed as [sample, curve]; column j is plotted against `x_vals`.
    n_curves : number of leading columns of `curves` to draw.
    sample_idx : index of the sample to highlight on each curve.
    label_prefix : prefix of the per-curve artist labels ("<prefix> <j+1>").
    xlabel, xlim : x-axis label and limits.
    ylabel, ylabel_pad : y-axis label (drawn unrotated) and its labelpad.
    title : panel title.
    ylim : y-axis limits.
    """
    for j in range(n_curves):
        ax.scatter(x_vals, curves[:, j], label=f"{label_prefix} {j+1}", s=5, alpha=0.8, color='k')
    # Highlight the example value; zorder keeps the markers above the black curves
    for j in range(n_curves):
        ax.scatter(x_vals[sample_idx], curves[sample_idx, j], s=150, color='red', edgecolor='black', linewidth=2, zorder=10)
    ax.set_xlabel(xlabel, fontsize=32)
    ax.set_ylabel(ylabel, fontsize=32, rotation=0, labelpad=ylabel_pad)
    ax.set_title(title, fontsize=32, pad=50)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.tick_params(axis='both', which='major', labelsize=30, width=2, length=10)
    ax.xaxis.set_minor_locator(AutoMinorLocator())
    ax.tick_params(axis='x', which='minor', width=2, length=5)


plt.rcParams['mathtext.fontset'] = 'cm'
mpl.rc('font', family='Baskerville')
# one row of three equally wide panels
fig = plt.figure(figsize=(26, 6))
spec = gridspec.GridSpec(ncols=3, nrows=1, width_ratios=[1,1,1], wspace=0.35, hspace=0.5)

# Left column - Basis functions for the highlighted parameter
ax1 = fig.add_subplot(spec[0])
_scatter_with_highlight(
    ax1, feature_vals, basis_vals1[highlight_param_idx], num_basis_functions, closest_sample_idx,
    label_prefix="Basis",
    xlabel=parameter_labels[highlight_param_idx], xlim=plot_limits_ARES[highlight_param_idx],
    ylabel=r'$B(x)$', ylabel_pad=45,
    title='B-spline basis functions (unweighted)', ylim=[0, 0.7],
)

# Middle column - Per-edge activation functions for the highlighted parameter
ax2 = fig.add_subplot(spec[1])
_scatter_with_highlight(
    ax2, feature_vals, learned_splines1[highlight_param_idx], num_nodes, closest_sample_idx,
    label_prefix="Node",
    xlabel=parameter_labels[highlight_param_idx], xlim=plot_limits_ARES[highlight_param_idx],
    ylabel=r'$\phi(x)$', ylabel_pad=20,
    title='per-edge activation functions', ylim=[-0.7, 0.7],
)

# Right column - Output signal emulation: emulated signal (thick blue) vs. tabulated ARES signal (dotted black)
ax3 = fig.add_subplot(spec[2])
ax3.plot(nu_list, T_KAN, color='blue', linewidth=8, alpha=0.4)
ax3.plot(nu_list, T_ARES, color='k', linestyle='dotted', linewidth=3)
ax3.set_xlabel(r'$\nu$ (MHz)', fontsize=32)
ax3.set_ylabel(r'$\delta T_b$ (mK)', fontsize=32)
ax3.set_title(r'signal emulation: $m_{\rm 21cmKAN}(\theta_0)$', fontsize=32, pad=5)
ax3.set_xlim(27.85, 236.74)
ax3.set_ylim(-250, 25)
ax3.set_yticks([0, -100, -200])
ax3.set_yticklabels(['0', '-100', '-200'], fontsize=30)
ax3.tick_params(axis='both', which='major', labelsize=30, width=2, length=10)
ax3.minorticks_on()
ax3.tick_params(axis='both', which='minor', width=2, length=5)

# top axis shows redshift; secondary_xaxis takes (forward, inverse) = (frequency -> z, z -> frequency)
secax3 = ax3.secondary_xaxis('top', functions=(redshift, freq))
secax3.set_xlabel(r'$z$', fontsize=32)
secax3.set_xticks([5, 10, 15, 25, 50])
secax3.set_xticklabels(['5', '10', '15', '25', '50'], fontsize=30)
secax3.tick_params(which='major', direction='out', width=2, length=10, labelsize=30)

# text box listing the parameter values of the emulated signal: a header line followed by one line per parameter
param_text = r"Physical Parameters ($\theta_0$):" + "\n"
param_text += "".join(f"   {label} = {value:.2f}\n" for label, value in zip(parameter_labels, par_ARES_signal_log))

ax3.text(0.595, 0.605, param_text, transform=ax3.transAxes, fontsize=15.5, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
plt.tight_layout()
plt.savefig('21cmKAN_combined_visualization.jpg', dpi=300, bbox_inches='tight', facecolor='w')
plt.show()

print('absolute RMSE between true and emulated signal:', np.sqrt(np.mean((T_KAN-T_ARES)**2)), 'mK')
