In [1]:
## Data retrieval and setup

# Obtain data from Zenodo if not already downloaded (~10 GB)
# This will take some time!

import os
import shutil
import subprocess

ARCHIVE_URL = 'https://zenodo.org/record/7335961/files/active_learning_sims.tar.gz'
ARCHIVE_NAME = 'active_learning_sims.tar.gz'

# The download is written to a temporary name and renamed only after the
# downloader exits successfully, so the presence of ARCHIVE_NAME always means
# a complete archive (an interrupted run is resumed, not mistaken for done).
PARTIAL_NAME = ARCHIVE_NAME + '.part'

# NOTE(review): assumed to be the top-level directory created by extracting
# the archive, based on the paths used later in this notebook -- confirm.
# Delete this directory to force a fresh extraction.
SIMS_DIR = 'knsc1_active_learning'

if not os.path.exists(ARCHIVE_NAME) and not os.path.isdir(SIMS_DIR):
    if shutil.which('aria2c'):
        # Strongly suggested: aria2 fetches over several connections for speed-up
        download_cmd = ['aria2c', '--file-allocation=none', '-c', '-x', '10', '-s', '10',
                        '-o', PARTIAL_NAME, ARCHIVE_URL]
    elif shutil.which('wget'):
        # -c resumes a previously interrupted download of PARTIAL_NAME
        download_cmd = ['wget', '-c', '-O', PARTIAL_NAME, ARCHIVE_URL]
    else:
        raise RuntimeError('Neither aria2c nor wget was found on PATH; '
                           'install one of them to download the data')

    # check=True raises CalledProcessError on failure instead of silently
    # carrying on to extract a missing or truncated archive
    subprocess.run(download_cmd, check=True)
    os.replace(PARTIAL_NAME, ARCHIVE_NAME)

# Untar the data (skipped when the simulations directory is already present)

if not os.path.isdir(SIMS_DIR):
    subprocess.run(['tar', '-zxf', ARCHIVE_NAME], check=True)

# Optional: after the download, convert the spectra to HDF5 format.
# This should *significantly* speed up IO during training.

import convert_dat_to_h5 as convh5

convh5.convert_dat_to_h5(
    path_to_sims_dir='knsc1_active_learning',
    path_to_h5_out='TP_wind2_spectra.h5',
)

In [None]:
## Random forest training

import spectra_interpolator as si

# Build the interpolator with the random forest option enabled
intp = si.intp(rf=True)

# Read in the simulation parameters together with their spectra.
# t_max=None means time is used during interpolation;
# theta is given in degrees (30 is passed here).

intp.load_data(
    'knsc1_active_learning/*spec*',
    'TP_wind2_spectra.h5',
    t_max=None,
    theta=30,
    short_wavs=True,
)

# Add time as the 5th input parameter.
# Spectra nominally have shape [N_sims, times, wavs, thetas], so:
#   free time  + fixed angle -> append_input_parameter(intp.times, 1)
#   free angle + fixed time  -> append_input_parameter(intp.angles, 2)

intp.append_input_parameter(intp.times, 1)

# Hold out a test set for verification

intp.create_test_set(size=5)

# Preprocess the data so that training is easier

intp.preprocess()

# Fit the model

intp.train()

# Persist the trained model to disk

intp.save('rf_spec_intp.joblib')

# Called with no arguments, evaluate() runs on the test set input parameters.
# When arguments are supplied, the result is stored under intp.prediction;
# pass ret_out=True to have the predictions returned as well.

intp.evaluate()

# Plot the test set to visually spot off-sample fitting.
# The plots are stored in the intp_figures directory.

intp.make_plots()