In [1]:
import sys
import os
import numpy as np
import dill
import gzip
import logging


sys.path.append(os.path.abspath('../../surmise/emulationmethods'))
from AKSGP import Emulator

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # logging.FileHandler('emulator_train.log', mode='w'),  # Log to file
        logging.StreamHandler()  # Log to console
    ]
)
logger = logging.getLogger(__name__)

# Load data

In [2]:
def load_split(directory):
    """Load one simulation-data split (X, Ymean, Ystd) from ``directory``.

    Reads ``X.txt``, ``Ymean.txt`` and ``Ystd.txt`` with ``np.loadtxt`` and
    prints the directory name and the shapes of the three arrays.

    Parameters
    ----------
    directory : str
        Directory containing the three whitespace-delimited text files.

    Returns
    -------
    tuple of numpy.ndarray
        ``(X, Ymean, Ystd)``. Each array is at least 2-D (``ndmin=2``), so a
        file holding a single row still indexes as ``[row, column]``.

    Raises
    ------
    ValueError
        If the three arrays disagree on the number of rows, or if ``Ymean``
        and ``Ystd`` do not have the same shape.
    """
    x, y_mean, y_std = (
        np.loadtxt(os.path.join(directory, fname), ndmin=2)
        for fname in ('X.txt', 'Ymean.txt', 'Ystd.txt')
    )

    print(f"Arrays loaded from directory '{directory}'.")
    print(f"Shapes of loaded arrays: {x.shape}, {y_mean.shape}, {y_std.shape}")

    # Every row of X must pair with one row of Ymean and one row of Ystd.
    if not (x.shape[0] == y_mean.shape[0] == y_std.shape[0]):
        raise ValueError(
            f"Row counts differ in '{directory}': "
            f"X={x.shape[0]}, Ymean={y_mean.shape[0]}, Ystd={y_std.shape[0]}"
        )
    if y_mean.shape != y_std.shape:
        raise ValueError(
            f"Ymean {y_mean.shape} and Ystd {y_std.shape} shapes differ in '{directory}'"
        )
    return x, y_mean, y_std


# Load training data
train_dir = 'simulation_data/Grad_Pb-Pb-2760GeV/train'
X, Ymean, Ystd = load_split(train_dir)

# Load testing data
test_dir = 'simulation_data/Grad_Pb-Pb-2760GeV/test'
Xval, Ymeanval, Ystdval = load_split(test_dir)

# Train and test splits must describe the same input parameters and the same
# set of output columns, otherwise validation against the test split is meaningless.
if X.shape[1] != Xval.shape[1] or Ymean.shape[1] != Ymeanval.shape[1]:
    raise ValueError(
        f"Train/test column mismatch: X {X.shape[1]} vs {Xval.shape[1]}, "
        f"Ymean {Ymean.shape[1]} vs {Ymeanval.shape[1]}"
    )


Arrays loaded from directory 'simulation_data/Grad_Pb-Pb-2760GeV/train'.
Shapes of loaded arrays: (485, 17), (485, 110), (485, 110)
Arrays loaded from directory 'simulation_data/Grad_Pb-Pb-2760GeV/test'.
Shapes of loaded arrays: (93, 17), (93, 110), (93, 110)


# Train emulators

In [3]:
# Number of leading training design points to train on, and the observable
# (output column) indices to emulate.
inpsize = 485
# A list rather than a tuple: NumPy always treats a list as an integer-array
# (fancy) index, whereas a bare tuple such as Ymean[numobs] would be read as
# one index per array axis.
numobs = [0, 4, 7, 34]

# Slicing past the end of X would silently train on fewer points than requested.
if inpsize > X.shape[0]:
    raise ValueError(
        f"inpsize={inpsize} exceeds the {X.shape[0]} available training points"
    )

emu = Emulator(X=X[:inpsize], Y_mean=Ymean[:inpsize, numobs], Y_std=Ystd[:inpsize, numobs])
# kernel='AKS' requests automatic kernel selection per output dimension (see the
# log output). NOTE(review): nrestarts is presumably the number of
# hyperparameter-optimizer restarts and seed fixes the RNG; confirm in AKSGP.
emu.fit(kernel='AKS', nrestarts=5, seed=13)


2024-08-23 18:07:35,363 - AKSGP - INFO - Automatic kernel selection opted. Best kernel for each output dimension will be selected from the list of kernels:
   ['Matern12', 'Matern32', 'Matern52', 'RBF', 'DotProduct+Matern12', 'DotProduct+Matern32', 'DotProduct+Matern52', 'DotProduct+RBF', 'DotProduct*Matern12', 'DotProduct*Matern32', 'DotProduct*Matern52', 'DotProduct*RBF', 'ExpSineSquared+Matern12', 'ExpSineSquared+Matern32', 'ExpSineSquared+Matern52', 'ExpSineSquared+RBF', 'ExpSineSquared*Matern12', 'ExpSineSquared*Matern32', 'ExpSineSquared*Matern52', 'ExpSineSquared*RBF', 'RationalQuadratic+Matern12', 'RationalQuadratic+Matern32', 'RationalQuadratic+Matern52', 'RationalQuadratic+RBF', 'RationalQuadratic*Matern12', 'RationalQuadratic*Matern32', 'RationalQuadratic*Matern52', 'RationalQuadratic*RBF']

2024-08-23 18:07:35,363 - AKSGP - INFO - Shape of training arrays: (436, 17), (436, 4), (436, 4)
2024-08-23 18:07:35,364 - AKSGP - INFO - Shape of pseudo_test arrays: (49, 17), (49, 4), 

In [4]:
test = 10  # number of leading rows from each split to compare


def compare_predictions(emulator, X_in, Y_mean_true, Y_std_true, label, n, columns):
    """Print emulator predictions next to the simulated values for ``n`` rows.

    Parameters
    ----------
    emulator : object
        A fitted emulator whose ``predict(X)`` returns ``(means, std_devs)``.
    X_in : numpy.ndarray
        Input design points; the first ``n`` rows are predicted.
    Y_mean_true, Y_std_true : numpy.ndarray
        Simulated means and standard deviations for the same rows as ``X_in``.
    label : str
        Name of the split, used in the header line (e.g. ``"training"``).
    n : int
        Number of leading rows to compare.
    columns : sequence of int
        Output columns the emulator was trained on.

    Returns
    -------
    tuple
        ``(means, std_devs)`` exactly as returned by ``emulator.predict``.
    """
    print(f"Predictions on {label} set ----------------->")
    # Predict with standard deviation
    means, std_devs = emulator.predict(X_in[:n])

    print("  Original means:\n", Y_mean_true[:n, columns], "\n")
    print("  Predicted means:\n", means, "\n")

    print("  Original standard deviations:\n", Y_std_true[:n, columns], "\n")
    print("  Predicted standard deviations:\n", std_devs, "\n")
    return means, std_devs


compare_predictions(emu, X, Ymean, Ystd, "training", test, numobs)
# Keep the test-set predictions in module scope, as the original cell did.
means, std_devs = compare_predictions(emu, Xval, Ymeanval, Ystdval, "test", test, numobs)


Predictions on training set ----------------->
  Original means:
 [[1608.155095  402.4201     64.998956  370.298827]
 [1376.107143  398.155003   75.674266  371.891188]
 [1715.649143  495.888734   93.457043  462.45653 ]
 [1474.812429  439.662816   79.33832   407.198466]
 [1625.855467  504.200066  100.375664  467.423269]
 [1504.085667  379.364243   58.23383   346.128945]
 [1480.250667  443.19724    87.351097  413.624828]
 [1437.252714  427.646196   72.45078   393.168612]
 [1482.820619  364.887598   56.401625  338.102267]
 [1657.787762  484.448914   82.403924  450.87281 ]] 

  Predicted means:
 [[1611.99800312  403.38215891   65.23766955  372.79484851]
 [1374.35220555  397.71231605   75.38382912  371.49579182]
 [1710.04703896  495.25086574   92.99959395  461.49806039]
 [1474.69000123  439.46758525   79.28216365  406.12050034]
 [1626.14640097  504.15672467  100.3805551   467.48985595]
 [1505.13966517  379.09314104   58.29477314  345.54513277]
 [1480.41236758  443.30014296   87.41787078  41