# Overview
This notebook walks through how to use trained Gaussian processes (GPs).


In [None]:
import pickle
from pathlib import Path
from pprint import pp

import gp_utils as gu
import h5py
import holodeck as holo
import matplotlib.pyplot as plt
import plotting_utils as pu
import sam_utils as su
from gp_utils import GaussProc

%load_ext autoreload
%autoreload 2

### First, we need to read in the library

In [None]:
spectra_file = Path(
    "./spec_libraries/hard04b_n1000_g100_s40_r50_f40/sam-lib_hard04b_2023-01-23_01_n1000_g100_s40_r50_f40.hdf5"
)


spectra = h5py.File(spectra_file, "r")

### Loading the trained GPs
We'll load the `.pkl` file of trained GPs. Note that the `gaussproc` class was renamed to `GaussProc` to follow the CamelCase convention for class names, so the cell below keeps an alias under the old name for backwards compatibility.

The GPs should be named programmatically based on the library's name, so we'll use that.
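
The naming convention described above can be sketched as a small helper. This is only an illustration: `trained_gp_path` is a hypothetical name (it is not part of `gp_utils`), and the example library path is made up.

```python
from pathlib import Path


def trained_gp_path(spectra_file):
    """Build the expected path of the trained-GP pickle for a library file.

    Hypothetical helper: the pickle is assumed to sit next to the library
    and to be named "trained_gp_<library directory name>.pkl".
    """
    lib_dir = Path(spectra_file).parent
    return lib_dir / f"trained_gp_{lib_dir.name}.pkl"


# Made-up library path, used only to show the resulting file name.
example = trained_gp_path("./spec_libraries/my_library/sam-lib_my_library.hdf5")
```

Checking `example.exists()` before calling `pickle.load` gives a clearer error than a failed `open` when the library directory holds no trained GPs.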

In [None]:
gaussproc = GaussProc  # For backwards compatibility before change to camel-case

gp_file = "trained_gp_" + spectra_file.parent.name + ".pkl"
with open(spectra_file.parent / gp_file, "rb") as f:
    gp_george = pickle.load(f)

### Setting up GP predictions
Here we prepare the GPs for predictions.
Older models may have byte strings instead of strings as their dictionary keys. If so, copy the code below and run it immediately after this cell.
```python
for gp in gp_george:
    gp.par_dict = {key.decode("ascii"): value for key, value in gp.par_dict.items()}
```
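
A slightly more defensive variant of the snippet above leaves keys that are already strings untouched, so it should be safe to run on both old and new models. This is a sketch under assumptions: `decode_keys` is a hypothetical helper (not part of `gp_utils`), and the toy dictionary only stands in for a real `par_dict`.

```python
def decode_keys(par_dict, encoding="ascii"):
    """Return a copy of par_dict with any bytes keys decoded to str."""
    return {
        (key.decode(encoding) if isinstance(key, bytes) else key): value
        for key, value in par_dict.items()
    }


# Toy dictionary standing in for gp.par_dict; the values are made up.
old_style = {
    b"hard_rchar": {"min": 0.0, "max": 1.0},
    "already_str": {"min": 2.0, "max": 3.0},
}
new_style = decode_keys(old_style)
```

With the loaded models, this would be applied as `gp.par_dict = decode_keys(gp.par_dict)` for each `gp` in `gp_george`.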

In [None]:
gp = gu.set_up_predictions(spectra, gp_george)

### Choosing what to hold constant
In the following cell, a `mean_pars` dictionary is created. It contains the mean value of each parameter over its allowed range. The values in this dictionary tell the plotting routines what constant values to use. Feel free to construct your own with different values. Each GP in `gp_george` has a `par_dict` attribute that contains each parameter and its allowed range.

Another example of a constant dictionary you could create is one holding the minimum values:
```python
min_pars = {key:gp_george[0].par_dict[key]['min'] for key in gp_george[0].par_dict.keys()}
```
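
The same idea generalizes to other choices of constant value. The sketch below is illustrative only: `constant_pars` is a hypothetical helper, the toy `par_dict` values are made up, and the presence of a `'max'` entry alongside `'min'` is an assumption about the dictionary layout. The `"mid"` option is simply the midpoint of the range, which may or may not coincide with what `gu.mean_par_dict` returns.

```python
def constant_pars(par_dict, which="min"):
    """Pick one constant value per parameter from {name: {'min': ..., 'max': ...}}."""
    if which in ("min", "max"):
        return {name: limits[which] for name, limits in par_dict.items()}
    if which == "mid":
        # Midpoint of the allowed range (assumes both limits are present).
        return {
            name: 0.5 * (limits["min"] + limits["max"])
            for name, limits in par_dict.items()
        }
    raise ValueError(f"unknown choice {which!r}; expected 'min', 'max' or 'mid'")


# Toy stand-in for gp_george[0].par_dict; names and ranges are made up.
toy_par_dict = {
    "hard_rchar": {"min": 1.0, "max": 3.0},
    "param_b": {"min": -1.0, "max": 0.0},
}
min_pars = constant_pars(toy_par_dict, "min")
mid_pars = constant_pars(toy_par_dict, "mid")
```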

In [None]:
mean_pars = gu.mean_par_dict(gp_george)

pp(mean_pars)

In [None]:
pp(gp_george[0].par_dict)

### Plotting individual parameters
The following cell plots GWBs while varying the parameter of interest, with the other parameters held constant. You may choose whether to calculate smoothed-mean GWBs from a SAM to overlay. The available SAM configurations are in `sam_utils.py`.



In [None]:
sam_model = su.Hard04()

In [None]:
sam_model.param_names

In [None]:
pu.plot_individual_parameter(
    gp_george,
    gp,
    mean_pars,
    "hard_rchar",
    spectra,
    find_sam_mean=True,
    model=sam_model.sam_for_params,
    plot_dir="plots",
    nreal=10,
    num_points=5
)

### Getting back numerical values
`plotting_utils.plot_individual_parameter()` can optionally return numerical results; its docstring, displayed by the next cell, describes the options.

In [None]:
?pu.plot_individual_parameter

### Plotting all parameters
The following cell plots GWBs for each parameter, shading the regions between the extrema. Once again, the values held constant are specified by `mean_pars`, but you can supply your own.

In [None]:
pu.plot_parameter_variances(
    gp_george, gp, mean_pars, spectra, alpha=0.65, plot_dir="plots"
)

### Plotting prediction over the data from the library
In the following cell, you can plot the GP's prediction on top of all of the realizations for a given parameter combination from the library. If you reserved a held-out test set when training, this is a good place to choose an index that lies within that test set.

In [None]:
index = 300
pu.plot_over_realizations(index, spectra, gp_george, gp)

### Drawing from the emcee chain
Below, you'll see an example of drawing $h_\mathrm{c}(f)$ samples from the emcee chain.

In [None]:
# Parameters from the plot above
use_pars = [-4.95E-01, -6.01E-01, 2.13E+00, 1.57E+00, -2.03E+00, 8.48E+00]

# To use mean_pars instead, uncomment the line below
# hc = gu.sample_hc_from_gp(gp_george, gp, list(mean_pars.values()), 100)

# The parameters from the plot above are used here because the spectra
# for the mean parameters are rather simple
hc = gu.sample_hc_from_gp(gp_george, gp, use_pars, 100)

In [None]:
?gu.sample_hc_from_gp

In [None]:
freqs = spectra["fobs"][: hc.shape[1]]
for i in range(hc.shape[0]):
    plt.loglog(freqs, hc[i, :], color="#4682b4", alpha=0.3)

plt.xlabel("Observed GW Frequency [Hz]")
plt.ylabel(r"$h_{c} (f)$")
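
Instead of overplotting every draw, the samples can also be condensed into a median and a percentile band. The sketch below is illustrative only: `summarize_samples` is a hypothetical helper, the synthetic array merely stands in for `hc`, and the layout `(n_samples, n_freqs)` is assumed from the plotting loop above.

```python
import numpy as np


def summarize_samples(hc, lower=16.0, upper=84.0):
    """Return (lower, median, upper) percentiles per frequency bin.

    Assumes hc has shape (n_samples, n_freqs).
    """
    hc = np.asarray(hc, dtype=float)
    lo, median, hi = np.percentile(hc, [lower, 50.0, upper], axis=0)
    return lo, median, hi


# Synthetic stand-in for the sampled spectra; the values are made up.
rng = np.random.default_rng(0)
fake_hc = 10.0 ** rng.normal(-15.0, 0.1, size=(100, 5))
lo, median, hi = summarize_samples(fake_hc)
```

With the real samples, `plt.fill_between(freqs, lo, hi, alpha=0.3)` followed by `plt.loglog(freqs, median)` would draw the band and the median on the same axes.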