# Overview
This notebook walks you through how to use trained Gaussian processes (GPs).


In [None]:
# Automatically reload edited modules before each cell is executed
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline, using the high-resolution 'retina' format
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import pickle
import sys
import warnings
from pathlib import Path
from pprint import pp

import corner
import h5py
import kalepy as kale
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats

import holodeck as holo
import holodeck.gps.gp_utils
import holodeck.gps.plotting_utils
import holodeck.gps.sam_utils
from holodeck import plot, utils
from holodeck.gps import gp_utils
from holodeck.gps import gp_utils as gu
from holodeck.gps import plotting_utils as pu
from holodeck.gps import sam_utils as su
from holodeck.gps.gp_utils import GaussProc


# Silence noisy numpy floating-point errors (divide-by-zero, invalid, overflow)
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# Apply the base style FIRST: `style.use('default')` resets rcParams to the
# matplotlib defaults, so calling it after the `mpl.rc(...)` lines silently
# discards the font, line and mathtext customizations made below.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): the family is 'serif' but the font list is given for
# 'sans-serif'; 'Times' is presumably meant for the 'serif' list — confirm.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

In [None]:
# Directory holding the simulation libraries, and the library used in this notebook
library_root = Path("/Users/lzkelley/programs/nanograv/15yr_astro_libraries/")
library_name = "broad-uniform-01"   # alternative: "astro-02-gw"

sim_path = library_root / library_name
# Shown as the cell output: True when the library directory is present
sim_path.exists()

### First, we need to read in the library

In [None]:
# Locate the library file and report its (shortened) path and whether it exists
spectra_file = sim_path / "sam_lib.hdf5"
short_name = "/".join(spectra_file.parts[-3:])
print(short_name, spectra_file.exists())

# Open read-only; the handle is left open because later cells read datasets from it
spectra = h5py.File(spectra_file, "r")

# Parameter names are stored as a file attribute; convert them to `str`
param_names = spectra.attrs['param_names'].astype('str')
print(param_names)

### Loading the trained GPs
We'll load the pickle (`.pkl`) file of trained GPs. Note that the `gaussproc` class was renamed to `GaussProc` to follow camel case standards for class names.

The GPs should be named programmatically based on the library's name, so we'll use that.

In [None]:
# `pickle`, `Path` and `gp_utils` are imported once in the first cell; they are
# not re-imported here.

# Register `holodeck.gps.gp_utils` under the bare module name `gp_utils`,
# presumably so that pickles referring to a top-level `gp_utils` module can
# still be resolved when loading — confirm against how the files were written.
sys.modules['gp_utils'] = gp_utils

gp_file = sim_path.joinpath(
    # "gp_trained/trained_gp_astro-02-gw_20230312_202655.pkl"
    # "gp_trained/trained_gp_astro-02-gw_20230312_204249.pkl"
    # "gp_trained/trained_gp_astro-02-gw_20230312_204349.pkl"
    "gp_trained/broad-uniform-01-kernel-fits_tinygp.pkl"
)

# NOTE: THIS IS NOT WORKING because the config files are copied when the run is done... by which point it's been changed!
# test = Path("/" + "/".join(gp_file.parts[1:-1]))
# pat = "2023" + gp_file.name.split("_2023")[-1].strip(".pkl")
# ini_file = list(test.glob(f'*{pat}.ini'))
# assert len(ini_file) == 1
# gp_config = {}
# with open(ini_file[0], 'r') as data:
#     for line in data.readlines():
#         line = line.strip()
#         if len(line) == 0 or line.startswith('#') or line.startswith('['):
#             continue
#         print(line)
#         line = line.split(" = ")
#         gp_config[line[0]] = line[1]

# SECURITY: unpickling executes arbitrary code from the file; only load GP
# files that come from a trusted source.
with open(gp_file, "rb") as f:
    gp_george = pickle.load(f)

## Examine Chains

In [None]:
# Labels for the chain columns: 'amp' followed by the library parameter names
gp_params = ['amp', *param_names.tolist()]
print(gp_params)

### Corner plot the samples

In [None]:
# Thinning factor, applied to both the walker axis and the step axis
skip = 4
# Assumed number of emcee walkers behind the flattened chain — TODO confirm
nwalk = 96

# Use a dedicated name here: `gp` is bound further down to the objects
# returned by `gu.set_up_predictions`, and re-running this cell must not clobber it.
gp_first = gp_george[0]
samples = gp_first.emcee_flatchain
print(samples.shape)

# Infer the number of steps per walker from the chain instead of hard-coding it
nsamp, remainder = divmod(samples.shape[0], nwalk)
print(nwalk, nsamp, nwalk*nsamp)
assert remainder == 0, (
    f"flat chain length {samples.shape[0]} is not a multiple of nwalk={nwalk}"
)

# Un-flatten to (walker, step, parameter); assumes the flat chain is ordered
# walker-by-walker — TODO confirm for the emcee version that was used
samples = samples.reshape(nwalk, nsamp, -1)
print(samples.shape)
# Thin, then flatten back to (sample, parameter) for the corner plot
samples = samples[::skip, ::skip, :].reshape(-1, samples.shape[-1])
print(samples.shape)
# kale.corner(samples[:1000, :].T)
corner.corner(samples, labels=gp_params)
plt.show()

### Plot the progression of each parameter vs chain steps

In [None]:
# Plot only every `skip`-th flat-chain sample
skip = 4
# Use a dedicated name here: `gp` is bound further down to the objects
# returned by `gu.set_up_predictions`, and re-running this cell must not clobber it.
gp_first = gp_george[0]
npars = len(gp_params)
fig, axes = plot.figax(figsize=[8, 3*npars], nrows=npars, scale='lin', sharex=True, xlabel='chain step')

# Color every point by its log-probability
lnprob = gp_first.emcee_flatlnprob[::skip]
smap = plot.smap(lnprob, cmap='Spectral')
colors = smap.to_rgba(lnprob)

# One panel per chain column
for ii, ax in enumerate(axes):
    ax.set(ylabel=gp_params[ii])
    yy = gp_first.emcee_flatchain[::skip, ii]
    ax.scatter(1+np.arange(yy.size), yy, facecolor=colors, edgecolor='0.25', lw=0.1, alpha=0.25, s=5)

# Shared colorbar placed to the right of the panels
cax = fig.add_axes([0.95, 0.2, 0.02, 0.6])
plt.colorbar(smap, cax=cax, label='lnprob')
plt.show()

### Plot the parameters vs. lnprob

In [None]:
# Use a dedicated name here: `gp` is bound further down to the objects
# returned by `gu.set_up_predictions`, and re-running this cell must not clobber it.
gp_first = gp_george[0]
npars = len(gp_params)
fig, axes = plot.figax(figsize=[8, 3*npars], nrows=npars, scale='lin', xlabel='lnprob')

lnprob = gp_first.emcee_flatlnprob
# Colors encode where each sample sits in the flat chain (0 = first, 1 = last)
smap = plot.smap([0.0, 1.0], cmap='Spectral')
colors = smap.to_rgba(np.linspace(0.0, 1.0, lnprob.size))

# One panel per chain column: parameter value against log-probability
for ii, ax in enumerate(axes):
    ax.set(ylabel=gp_params[ii])
    values = gp_first.emcee_flatchain[:, ii]
    ax.scatter(lnprob, values, c=colors, s=3, alpha=0.2)

cax = fig.add_axes([0.95, 0.2, 0.02, 0.6])
# Label the colorbar so the figure can be read on its own
plt.colorbar(smap, cax=cax, label='fractional position in flat chain')
plt.show()

## Examine predictions

### Setting up GP predictions
Here we prepare the GPs for predictions.
It's possible that the older models have byte strings instead of strings as their dictionary keys. If so, copy the below code and run it immediately after this cell.
```python
for gp in gp_george:
    gp.par_dict = { key.decode('ascii'): gp.par_dict.get(key) for key in gp.par_dict.keys() }
```

In [None]:
# Prepare the trained GPs for predictions; the result is bound to `gp` and
# passed to the plotting and sampling calls in the cells below.
gp = gu.set_up_predictions(gp_george)
# NOTE(review): `breaker` is not defined anywhere in the visible notebook, so
# this call raises NameError — presumably a deliberate way to halt "Run All"
# before the sections below; confirm, and remove once those sections run.
breaker()

## Plot Each Parameter 1D

In [None]:
def get_binned_med_std_1d(xx, yy, nbins=10):
    """Bin `yy` along `xx` and compute the per-bin median and a +/- one-std band.

    Parameters
    ----------
    xx : array, shape (N,) or (N, M)
        Abscissa values. A 1D array is broadcast across the columns of `yy`.
    yy : array, shape (N, M)
        Values to summarize.
    nbins : int
        Number of bins; `nbins + 1` linearly spaced edges are built from the
        first column of `xx` with `kale.utils.spacing(..., stretch=0.1)`.

    Returns
    -------
    bins : bin edges
    med : per-bin median of `yy`
    edges : x-coordinates returned by `plot._get_hist_steps(bins, med + std)`
    ylo : `med - std`, with every value repeated twice (one per bin edge)
    yhi : `med + std` as returned by `plot._get_hist_steps`

    Raises
    ------
    AssertionError
        If `yy` is not 2D, or the first dimensions of `xx` and `yy` differ.
    ValueError
        If `xx` is not 1D and its shape differs from that of `yy`.

    """
    assert np.ndim(yy) == 2, "`yy` must be 2D"
    assert np.shape(xx)[0] == np.shape(yy)[0], "`xx` and `yy` must match along axis 0"
    if np.ndim(xx) == 1:
        # Give every column of `yy` its own copy of the abscissa
        xx = xx[:, np.newaxis] * np.ones_like(yy)
    elif np.shape(xx) != np.shape(yy):
        raise ValueError(
            f"shape of `xx` {np.shape(xx)} does not match shape of `yy` {np.shape(yy)}"
        )

    bins = kale.utils.spacing(xx[:, 0], 'lin', nbins+1, stretch=0.1)

    x_flat = xx.flatten()
    y_flat = yy.flatten()
    med, *_ = sp.stats.binned_statistic(x_flat, y_flat, bins=bins, statistic='median')
    std, *_ = sp.stats.binned_statistic(x_flat, y_flat, bins=bins, statistic='std')

    # Repeat each lower-bound value twice so it lines up with the step edges
    ylo = np.repeat(med - std, 2)
    edges, yhi = plot._get_hist_steps(bins, med + std)
    return bins, med, edges, ylo, yhi


def draw_spectra_strains_and_gp_strains_at_freq__vs_param(ax, freq_idx, par_idx, sample_params, gp_george, gps):
    """Scatter library strains and GP-sampled strains at one frequency against one parameter.

    For both the library values and the GP samples, the individual points are
    drawn together with the binned median and a +/- one-std band from
    `get_binned_med_std_1d`.

    Reads the notebook-level names `spectra` (the open library file) and
    `TEST_FRAC`: only the samples after the first `TEST_FRAC` fraction of
    `sample_params` are used.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    freq_idx : index of the frequency bin (selects the entries of `gp_george`,
        `gps` and the frequency axis of `spectra['gwb']`).
    par_idx : column of `sample_params` plotted on the x-axis.
    sample_params : 2D array of parameter samples, one row per library sample.
    gp_george : sequence of trained GP containers; `.y` and `.mean_spectra` are used.
    gps : sequence of objects providing `sample_conditional(y, t, size=...)`.

    """
    test_ind = int(sample_params.shape[0] * TEST_FRAC)
    pars_all = sample_params[test_ind:, :]
    pars = pars_all[:, par_idx]
    gpg = gp_george[freq_idx]
    gp_trained = gps[freq_idx]

    gwb = spectra['gwb'][test_ind:, freq_idx, :]
    # Order everything by the plotted parameter
    sort_idx = np.argsort(pars)

    xx = pars[sort_idx]
    yy = np.log10(gwb[sort_idx])
    xx_yy = xx[:, np.newaxis] * np.ones_like(yy)
    cc = ax.scatter(xx_yy, yy, alpha=0.2, s=4, zorder=2)
    cc = cc.get_facecolor()
    # Empty scatter: legend entry only, drawn fully opaque
    ax.scatter([], [], color=cc, label='library', alpha=1.0, s=15)

    bins, med, edges, ylo, yhi = get_binned_med_std_1d(xx_yy, yy)
    plot.draw_hist_steps(ax, bins, med, zorder=9, color='0.75', lw=5.0, alpha=0.5)
    plot.draw_hist_steps(ax, bins, med, zorder=10, color=cc, lw=2.0, alpha=1.0)
    ax.fill_between(edges, ylo, yhi, zorder=10, color=cc, alpha=0.2)

    # Evaluate the GP at the parameters in the SAME sorted order as `xx_yy`;
    # sampling at the unsorted `pars_all` would pair each GP draw with the
    # x-position of a different sample.
    pars_sorted = pars_all[sort_idx]
    # Move number of realizations to 1th dimension, (R, S) ==> (S, R)
    hc = gp_trained.sample_conditional(gpg.y, pars_sorted, size=yy.shape[1]).T

    hc = gpg.mean_spectra + hc
    # Equivalent to hc / 2; presumably the GP models log10(h_c^2) — TODO confirm
    hc = np.log10(np.sqrt(10.0 ** hc))
    cc = ax.scatter(xx_yy, hc, alpha=0.15, s=10, zorder=2, marker='.', edgecolor='0.25', lw=0.1)
    cc = cc.get_facecolor()
    ax.scatter([], [], color=cc, label='GP samples', alpha=1.0, s=30)

    bins, med, edges, ylo, yhi = get_binned_med_std_1d(xx_yy, hc)
    plot.draw_hist_steps(ax, bins, med, zorder=9, color='0.75', lw=5.0, alpha=0.5)
    plot.draw_hist_steps(ax, bins, med, zorder=10, color=cc, lw=2.0, alpha=1.0)
    ax.fill_between(edges, ylo, yhi, zorder=10, color=cc, alpha=0.2)

    # Fix the y-limits from the library values only
    yextr = kale.utils.minmax(yy, stretch=1.0)
    ax.set(ylim=yextr)
    return

# Frequency bin to plot
freq_idx = 0

# NOTE(review): `sample_params` and `gps` (passed below) and `TEST_FRAC` (read
# inside the drawing function) are not defined anywhere in the visible
# notebook; they must be defined before this cell can run.

# par_idx = 1
# fig, ax = plot.figax(scale='lin')
# ax.set(xlabel=param_names[par_idx], ylabel=r'$\log_{10}(h_c)$')
# draw_spectra_strains_and_gp_strains_at_freq__vs_param(ax, freq_idx, par_idx, sample_params, gp_george, gps)

# One panel per library parameter
npars = len(param_names)
fig, axes = plot.figax(figsize=[8, 3*npars], nrows=npars, scale='lin', hspace=0.35)
for par_idx, ax in enumerate(axes):
    # Raw string: '\l' is an invalid escape sequence in a regular string literal
    ax.set(xlabel=param_names[par_idx], ylabel=r'$\log_{10}(h_c)$')
    draw_spectra_strains_and_gp_strains_at_freq__vs_param(ax, freq_idx, par_idx, sample_params, gp_george, gps)
    # The legend is identical in every panel, so only draw it once
    if par_idx == 0:
        ax.legend()

plt.show()


### Choosing what to hold constant
In the following cell, a `mean_pars` dictionary is created. This contains the mean value of each parameter over its allowed range. The values in this dictionary tell the plotting routines what constant values to use. Feel free to construct your own with different values. Each `gp_george` has a `gp_george.par_dict` that contains each parameter and its allowed range. 

Another example constant dictionary you could create is one of the minimum values
```python
min_pars = {key:gp_george[0].par_dict[key]['min'] for key in gp_george[0].par_dict.keys()}
```

In [None]:
# Dictionary of constant parameter values handed to the plotting routines below
mean_pars = gu.mean_par_dict(gp_george)
pp(mean_pars)

In [None]:
# Pretty-print the `par_dict` of the first GP
pp(gp_george[0].par_dict)

### Plotting individual parameters
The following cell will plot GWBs while varying the parameter of interest with other parameters held constant. You may choose whether to calculate smoothed-mean GWBs from a SAM to overlay. The available SAM configurations are in `sam_utils.py`.



In [None]:
# sam_model = su.Hard04()
# Bound without being called — presumably the parameter-space class itself
# rather than an instance; confirm against `holo.param_spaces`.
sam_model = holo.param_spaces.PS_Broad_Uniform_01

In [None]:
# Display the `param_names` attribute of the selected model as the cell output
sam_model.param_names

In [None]:
# Keyword options for the single-parameter plot
individual_plot_opts = dict(
    find_sam_mean=True,
    model=sam_model.model_for_params,
    plot_dir="/Users/lzkelley/Programs/nanograv/holodeck/output/gps",
    nreal=10,
    num_points=5,
)

# Parameter of interest is "hard_time"; `mean_pars` supplies the constant values
pu.plot_individual_parameter(
    gp_george, gp, mean_pars, "hard_time", spectra, **individual_plot_opts
)

### Getting back numerical values
`plotting_utils.plot_individual_parameter()` can optionally return numerical results

In [None]:
# IPython help syntax: show the documentation of `plot_individual_parameter`
?pu.plot_individual_parameter

### Plotting all parameters
The following cell will plot GWBs for each parameter, shading the regions in between the extrema. Once again, the values held constant are specified by `mean_pars`, but you can supply your own.

In [None]:
# `mean_pars` supplies the values held constant; note that `plot_dir` is a
# relative path here ("plots").
pu.plot_parameter_variances(
    gp_george, gp, mean_pars, spectra, alpha=0.65, plot_dir="plots"
)

### Plotting prediction over the data from the library
In the following cell, you can plot the GP's prediction on top of all of the realizations for a given parameter combination from the training data. If you reserved a training set, this would be a good place to choose an index that lies within the training set.

In [None]:
# Index of the library parameter combination whose realizations are plotted
index = 300
pu.plot_over_realizations(index, spectra, gp_george, gp)

### Drawing from the emcee chain
Below, you'll see an example of drawing $h_\mathrm{c}(f)$ samples from the emcee chain. 

In [None]:
# Parameters from above plot
use_pars = [-4.95E-01, -6.01E-01, 2.13E+00, 1.57E+00, -2.03E+00, 8.48E+00]

# To use mean_pars, see below
#hc = gu.sample_hc_from_gp(gp_george, gp, list(mean_pars.values()), 100)


# I'm using the parameters from above because the spectra for the mean parameters are rather simple
hc = gu.sample_hc_from_gp(gp_george, gp, use_pars, 100)

In [None]:
# IPython help syntax: show the documentation of `sample_hc_from_gp`
?gu.sample_hc_from_gp

In [None]:
# Keep only as many frequencies as there are columns in `hc`
num_freqs = hc.shape[1]
freqs = spectra["fobs"][:num_freqs]

# One semi-transparent line per row of `hc`
line_style = dict(color="#4682b4", alpha=0.3)
for row in range(hc.shape[0]):
    plt.loglog(freqs, hc[row, :], **line_style)

plt.xlabel("Observed GW Frequency [Hz]")
plt.ylabel(r"$h_{c} (f)$")