# Initialization

In [None]:
# %load ../init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE: `style.use('default')` resets rcParams back to matplotlib's defaults, so it must run
# BEFORE the customizations below; when it ran after them, the font, line and mathtext
# settings were silently discarded.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
import glob
import re

In [None]:
def figax(fobs=None, plamp=1.0e-15, **kwargs):
    """Close all open figures and create a new one, optionally with a reference power-law.

    The figure and axes come from `plot.figax(**kwargs)`.  Unless `plamp` is None, a dashed
    line ``plamp * x**(-2/3)`` is drawn on every axes (after `np.atleast_1d`).

    Parameters
    ----------
    fobs : array_like or None
        If given, the line spans ``utils.minmax(fobs) * YR``; otherwise it spans [0.05, 10.0].
    plamp : float or None
        Amplitude of the reference line; None disables the line.
    **kwargs
        Passed through to `plot.figax`.

    Returns
    -------
    fig, axes
        Whatever `plot.figax` returned.

    """
    plt.close('all')
    fig, axes = plot.figax(**kwargs)
    if plamp is None:
        return fig, axes

    # x-range of the reference line: a fixed default, or the extrema of the given frequencies
    span = [0.05, 10.0] if fobs is None else utils.minmax(fobs) * YR
    reference = plamp * np.power(span, -2/3)
    for panel in np.atleast_1d(axes):
        panel.plot(span, reference, 'k--', alpha=0.25)

    return fig, axes

In [None]:
# import re

# path = (
#     "/Users/lzkelley/programs/nanograv/holodeck/output/test_2022-06-27/"
# )

# path = Path(path)
# files = list(path.glob("*.npz"))
# print(f"Found {len(files)} files")
# pat = r'_p[0-9]{3}_r[0-9]{3}'

# for ff in tqdm.tqdm(files):
#     mat = re.search(pat, str(ff))
#     if mat is None:
#         continue
#     mat = [mm for mm in mat.group().split('_') if len(mm.strip()) > 0]
#     pp, rr = [mm[1:] for mm in mat]
#     fname_new = f"lib_sams__p{int(pp):06d}_r{int(rr):03d}.npz"
#     fname_new = path / fname_new
#     fname_old = path / ff
#     # print(f"\t{fname_old} ==> {fname_new}")
#     fname_old.rename(fname_new)
#     if not fname_new.is_file():
#         raise 

In [None]:
# NOTE(review): `data` is not assigned at notebook scope by any cell visible before this one
# (it only appears as a local variable inside `merge_output_files`), so running the notebook
# top-to-bottom presumably fails here with a NameError -- confirm whether this cell is stale.
fig, ax = plot.figax()

# Plot `gwb` against `fobs` using the coordinates returned by `plot._get_hist_steps`
ax.plot(*plot._get_hist_steps(data['fobs'], data['gwb']))

plt.show()

# Load Data

In [None]:
def merge_output_files(path_output, fname_merged, num_pars=10, num_reals=300):
    """Merge per-(parameter, realization) output files into a single HDF5 file.

    Files in `path_output` named like ``sam_output_p###_r###.npz`` are loaded one by one.
    Their `gff`, `gwf` and `gwb` arrays are stored at index ``[param, realization, :]`` of the
    merged arrays, and `mmbulge_norm` is stored once per parameter index.

    Parameters
    ----------
    path_output : str
        Directory containing the individual ``.npz`` output files.
    fname_merged : str
        Path of the HDF5 file that is (over)written with the merged data.
    num_pars : int
        Number of parameter indices expected in the file names.
    num_reals : int
        Number of realizations expected for each parameter index.

    Raises
    ------
    RuntimeError
        If the directory contains no files, or no files matching the expected name pattern.
    ValueError
        If the number of matching files differs from ``num_pars * num_reals``.
    AssertionError
        If `fobs` or `mmbulge_norm` are inconsistent between files.

    """
    # Raw string with an escaped '.', so that only a literal ".npz" extension matches
    re_pattern = r"sam_output_p([0-9]{3})_r([0-9]{3})\.npz"

    path_output = os.path.abspath(path_output)
    print(f"{path_output=}")
    file_pattern = os.path.join(path_output, "*.*")
    # files = sorted(glob.glob(file_pattern))
    files = glob.glob(file_pattern)
    num_files = len(files)
    if num_files == 0:
        raise RuntimeError(f"No files found in '{path_output}'!")

    print(f"Found {num_files=}, {files[0]=}")
    files = [ff for ff in files if re.search(re_pattern, ff) is not None]
    num_files = len(files)
    if num_files == 0:
        raise RuntimeError(f"No files matching pattern '{re_pattern}'!")

    print(f"Found {num_files=} matching target pattern")

    # Fail early, before loading any data or allocating the merged arrays
    if num_files != num_pars * num_reals:
        err = f"Found number of files '{num_files}' does not match {num_pars=} * {num_reals=} = {num_pars*num_reals}!"
        log.error(err)
        raise ValueError(err)

    # `np.load` on an npz archive keeps the file open; the context manager closes it again
    with np.load(files[0]) as data:
        fobs = data['fobs']

    # The stored spectra have one fewer element than `fobs`
    gwf = np.zeros((num_pars, num_reals, fobs.size - 1))
    gff = np.zeros_like(gwf)
    gwb = np.zeros_like(gwf)
    # NaN marks parameter indices for which no `mmbulge_norm` has been stored yet
    mmbulge_norm = np.ones(num_pars) * np.nan

    for fil in tqdm.tqdm(files):
        # The two regex groups are the zero-padded parameter and realization indices
        pp, rr = [int(gg) for gg in re.findall(re_pattern, fil)[0]]
        # print(fil, "===>", pp, rr)
        with np.load(fil) as data:
            assert np.all(fobs == data['fobs'])
            norm = data['mmbulge_norm']
            if np.isnan(mmbulge_norm[pp]):
                mmbulge_norm[pp] = norm
                assert not np.isnan(mmbulge_norm[pp])
            else:
                # Every realization of one parameter index must share the same value
                assert mmbulge_norm[pp] == norm

            gff[pp, rr, :] = data['gff']
            gwf[pp, rr, :] = data['gwf']
            gwb[pp, rr, :] = data['gwb']

    with h5py.File(fname_merged, 'w') as h5:
        h5.create_dataset('fobs', data=fobs)
        h5.create_dataset('gff', data=gff)
        h5.create_dataset('gwf', data=gwf)
        h5.create_dataset('gwb', data=gwb)
        h5.create_dataset('mmbulge_norm', data=mmbulge_norm)
        h5.attrs['num_pars'] = num_pars
        h5.attrs['num_reals'] = num_reals

    print(f"Saved to '{fname_merged}' size {utils.get_file_size(fname_merged)}")

# Only merge the individual output files if the merged file does not exist yet
path_output = "/Users/lzkelley/research/nanograv/holodeck/output"
fname_data = os.path.join(path_output, "data.hdf5")
if not os.path.exists(fname_data):
    merge_output_files(path_output, fname_data)

In [None]:
# Read the merged datasets back from the HDF5 file; indexing with `[()]` reads each one in full
with h5py.File(fname_data, 'r') as h5:
    fobs = h5['fobs'][()]
    print(f"{fobs.size} frequencies, {h5.attrs['num_reals']} realizations, {h5.attrs['num_pars']} parameters")
    gff = h5['gff'][()]
    gwf = h5['gwf'][()]
    gwb = h5['gwb'][()]
    mmbulge_norm = h5['mmbulge_norm'][()]

# Report the shapes of what was loaded
print(f"{fobs.shape=}, {gwf.shape=}, {mmbulge_norm.shape=}")

In [None]:
# Display log10 of the `mmbulge_norm` value stored for each parameter index
np.log10(mmbulge_norm)

## Quick Look at Data Features

In [None]:
# Display the shapes of the three merged arrays
gff.shape, gwf.shape, gwb.shape

In [None]:
def draw_sampled_gwb(ax, fobs, gff, gwf, gwb):
    """Draw one `gwb` curve plus its `gwf` points on the given axes.

    `gwb` is plotted as a line against the midpoints of `fobs`, and `gwf` is scattered against
    `gff`; both x-coordinates are multiplied by `YR`.  Points where ``gwf > gwb`` are drawn
    filled, all other points are drawn hollow, using the color of the line.

    Parameters
    ----------
    ax : axes object providing `plot` and `scatter`
    fobs : array of frequencies; only its midpoints are used
    gff, gwf, gwb : arrays for a single sample, indexable by the same boolean mask

    """
    mid_freqs = kale.utils.midpoints(fobs) * YR   # [1/sec] ==> [1/yr]
    line, = ax.plot(mid_freqs, gwb, ls='-', alpha=0.75, lw=0.75)
    color = line.get_color()

    point_freqs = gff * YR   # [1/sec] ==> [1/yr]
    above = (gwf > gwb)
    below = ~above
    # filled markers: `gwf` exceeds `gwb`
    ax.scatter(point_freqs[above], gwf[above], color=color, s=10, alpha=0.5)
    # hollow markers: everything else
    ax.scatter(point_freqs[below], gwf[below], edgecolor=color, facecolor='none', s=10, alpha=0.5)

fig, ax = figax()

# Indices (first axis, second axis) of the single sample that is drawn
PAR = 5
REAL = 0
draw_sampled_gwb(ax, fobs, gff[PAR, REAL, :], gwf[PAR, REAL, :], gwb[PAR, REAL, :])
plt.show()

In [None]:
# For one parameter index, plot the median and the 25-75 percentile band of `gwf` and `gwb`,
# with the percentiles taken along axis 0 of `gw[PAR, :, :]`.
fig, ax = figax()
xx = kale.utils.midpoints(fobs) * YR   # [1/sec] ==> [1/yr]

PAR = 6
for gw, lab in zip([gwf, gwb], ['fore', 'back']):
    # `med` is the 50th percentile; `span` collects the 25th and 75th percentiles
    med, *span = np.percentile(gw[PAR, :, :], [50, 25, 75], axis=0)
    col, = ax.plot(xx, med, label=lab)
    # reuse the line's color for the shaded band
    col = col.get_color()
    ax.fill_between(xx, *span, color=col, alpha=0.5)

plt.legend()
plt.show()


In [None]:
# Same median / 25-75 percentile plot as for a single parameter index, but for every
# parameter index, with `gwf` and `gwb` on separate panels.
num_pars = gwb.shape[0]
print(f"{num_pars=}")

fig, axes = figax(figsize=[16, 6], ncols=2)
xx = kale.utils.midpoints(fobs) * YR   # [1/sec] ==> [1/yr]

for ax, gw, lab in zip(axes, [gwf, gwb], ['fore', 'back']):
    ax.set_title(lab)
    for pp in range(num_pars):
        # `med` is the 50th percentile; `span` collects the 25th and 75th percentiles
        med, *span = np.percentile(gw[pp, :, :], [50, 25, 75], axis=0)
        # every line on a panel gets the same label; the legend call below is disabled
        col, = ax.plot(xx, med, label=lab)
        col = col.get_color()
        ax.fill_between(xx, *span, color=col, alpha=0.5)

# plt.legend()
plt.show()


In [None]:
# Density of log10(gwf) and log10(gwb) at a single frequency index, one curve per parameter.
# `xf` / `xb` are the grids of log10 values on which the densities are evaluated.
xf = np.linspace(-17.5, -10, 300)
xb = np.linspace(-16, -13, 300)
fig, axes = figax(ncols=2, plamp=None, scale='lin')

off = 0.2e15   # NOTE(review): not used by any of the visible cells
# `plt.get_cmap` replaces `mpl.cm.get_cmap`, which was deprecated in matplotlib 3.7 and has
# since been removed.
colors = [plt.get_cmap('RdBu')(ii) for ii in np.linspace(0.1, 0.9, num_pars)]

FREQ = 0
for ax, xx, gw, lab in zip(axes, [xf, xb], [gwf, gwb], ['fore', 'back']):
    ax.set_title(lab)
    for pp in range(num_pars):
        zz = np.log10(gw[pp, :, FREQ])
        # drop non-finite values (e.g. log10 of zero), as the FREQ = 100 version of this cell does
        zz = zz[np.isfinite(zz)]
        _, yy = kale.density(zz, xx, probability=True)
        col, = ax.plot(xx, yy, color=colors[pp])
        col = col.get_color()
        kale.carpet(zz, ax=ax, color=col)


plt.show()


In [None]:
# Density of log10(gwf) and log10(gwb) at frequency index `FREQ`, one curve per parameter.
FREQ = 100

# `xf` / `xb` are the grids of log10 values on which the densities are evaluated
xf = np.linspace(-18, -13.5, 300)
xb = np.linspace(-18, -14.5, 300)
fig, axes = figax(ncols=2, plamp=None, scale='lin')

off = 0.2e15   # NOTE(review): not used by any of the visible cells
# `plt.get_cmap` replaces `mpl.cm.get_cmap`, which was deprecated in matplotlib 3.7 and has
# since been removed.
colors = [plt.get_cmap('RdBu')(ii) for ii in np.linspace(0.1, 0.9, num_pars)]

for ax, xx, gw, lab in zip(axes, [xf, xb], [gwf, gwb], ['fore', 'back']):
    ax.set_title(lab)
    # iterate in reverse so that the lowest parameter index is drawn last (on top)
    for pp in range(num_pars)[::-1]:
        zz = np.log10(gw[pp, :, FREQ])
        # drop non-finite values (e.g. log10 of zero) before estimating the density
        zz = zz[np.isfinite(zz)]
        _, yy = kale.density(zz, xx, probability=True)
        col, = ax.plot(xx, yy, color=colors[pp])
        col = col.get_color()
        kale.carpet(zz, ax=ax, color=col)


plt.show()


# Gaussian Processes

In [None]:
# import scipy.signal as ssig
import scipy.signal
# import scipy.interpolate as interp

# import scipy.linalg as sl
# import scipy.special as ss
# import scipy.constants as sc
# import scipy.misc as scmisc
# import scipy.integrate as si

import george
import george.kernels as kernels
import emcee

## Setup data

In [None]:
## NOTE - Only need to train GP on number of frequencies in PTA analysis !
NFREQ = 3

# Midpoints of the first NFREQ frequency bins, converted from [1/sec] to [1/yr]
freqs = kale.utils.midpoints(fobs[:NFREQ+1]) * YR

# (P, R, F)
gwb_spectra = gwb[:, :, :NFREQ] ** 2
print(utils.stats(gwb_spectra))

# Find all of the zeros and set them to be h_c = 1e-20
# low_ind = np.where(gwb_spectra < 1e-40)
# gwb_spectra[low_ind] = 1e-40

# Find std over realizations
# (P[arams], F[reqs])
err = np.std(np.log10(gwb_spectra), axis=1)
# Find mean over realizations
# (P[arams], F[reqs])
mean = np.log10(np.mean(gwb_spectra, axis=1))

# Smooth Mean Spectra over frequencies
## NOTE FOR LUKE - HOW MUCH SMOOTHING DO WE WANT TO DO ?
# print(mean.shape)
# smooth_mean = sp.signal.savgol_filter(mean, 7, 3, axis=-1)
# Smoothing is currently disabled: the "smoothed" mean is just a copy of the mean
smooth_mean = mean.copy()

if np.any(np.isnan(err)):
    print('Got a NAN issue')
    # A bare `raise` outside of an `except` block only produces the unhelpful
    # "No active exception to reraise" RuntimeError; raise the same type with a real message.
    raise RuntimeError("NaN values found in `err` (std of log10 spectra over realizations)!")

In [None]:
# For one parameter index: every realization's log10(gwb**2), plus the mean and +/- 1 std band
PAR = 0
fig, ax = figax(plamp=None, scale='lin')

# transpose so that each realization is drawn as its own line over frequency
ax.plot(freqs, np.log10(gwb[PAR, :, :NFREQ].T**2), color='C0', alpha=0.3, zorder=0)

ax.plot(freqs, mean[PAR], color='C1', label='Mean')
ax.plot(freqs, smooth_mean[PAR], color='C3', label='Smoothed Mean')
ax.fill_between(freqs, (mean[PAR]-err[PAR]), (mean[PAR]+err[PAR]), color='C1', alpha=0.5)

plt.legend()
ax.set_xlabel(r'GW Frequency [yr$^{-1}$]')
# NOTE(review): the plotted values are log10 of gwb**2, while the label reads h_c^2 -- confirm
ax.set_ylabel(r'$h_{c}^{2}$')
plt.show()

## Train GP

In [None]:
# Define a GP class containing the kernel parameter priors and a log-likelihood

class GP:
    """Training data, flat kernel-parameter prior and GP log-likelihood for one data set.

    Parameters
    ----------
    x : array_like
        Input coordinates passed to `george.GP.compute`.
    y : array_like
        Target values passed to the GP likelihood.
    yerr : array_like or None
        Uncertainties on `y`, passed to `george.GP.compute`.

    """

    def __init__(self, x, y, yerr=None):
        self.x = x
        self.y = y
        self.yerr = yerr

        # The number of GP parameters is one more than the number of spectra parameters.
        # self.pmax = np.array([20.0, 20.0, 20.0, 20.0, 20.0, 20.0]) # sampling ranges
        # self.pmin = np.array([-20.0, -20.0, -20.0, -20.0, -20.0, -20.0])
        # Bounds of the flat prior on the (log) kernel parameters
        self.pmin = np.array([-20.0, -20.0])
        self.pmax = np.array([20.0, 20.0])   # sampling ranges

        self.emcee_flatchain = None
        self.emcee_flatlnprob = None
        self.emcee_kernel_map = None

    def lnprior(self, p):
        """Return the log of a flat prior: a constant inside [pmin, pmax], -inf outside."""
        p = np.asarray(p, dtype=float)
        if np.all(p <= self.pmax) and np.all(p >= self.pmin):
            return np.sum(np.log(1.0 / (self.pmax - self.pmin)))

        return -np.inf

    def lnlike(self, p):
        """Return the GP log-likelihood of `self.y` for the log kernel parameters `p`.

        `p[0]` is the log of the kernel amplitude and `p[1:]` the log of the value(s) passed
        to `ExpSquaredKernel`.  Returns -inf if the computation raises a `LinAlgError`.

        """
        # Update the kernel and compute the lnlikelihood.
        a, tau = np.exp(p[0]), np.exp(p[1:])

        try:
            gp = george.GP(a * kernels.ExpSquaredKernel(tau, ndim=len(tau)))
            #gp = george.GP(a * kernels.Matern32Kernel(tau))
            gp.compute(self.x, self.yerr)
            lnlike = gp.lnlikelihood(self.y, quiet=True)
        except np.linalg.LinAlgError:
            lnlike = -np.inf

        return lnlike

    def lnprob(self, p):
        """Return the log-posterior (log-prior plus log-likelihood) for parameters `p`.

        The likelihood is only evaluated when the prior is finite: outside the prior bounds
        the result is -inf regardless, and skipping the GP evaluation avoids both wasted work
        and the risk of a NaN from combining -inf with a non-finite likelihood.

        """
        lnp = self.lnprior(p)
        if not np.isfinite(lnp):
            return -np.inf

        return lnp + self.lnlike(p)

In [None]:
## Load in the spectra data!

# The "y" data are the means and errors for the spectra at each point in parameter space
# (copies are used so that the in-place subtraction below leaves `smooth_mean` untouched)
yobs = smooth_mean.copy() #mean.copy()
yerr = err.copy()

## Find mean in each frequency bin (remove it before analyzing with the GP) ##
# This allows the GPs to oscillate around zero, where they are better behaved.
# (the mean is taken over axis 0, i.e. over the parameter index)
yobs_mean = np.mean(yobs, axis=0)
# MAKE SURE TO SAVE THESE VALUES - THE GP IS USELESS WITHOUT THEM !
# np.save('./Luke_Spectra_MEANS.npy', yobs_mean)

yobs -= yobs_mean[np.newaxis, :]

## The "x" data are the actual parameter values
# One column only: log10 of `mmbulge_norm` is the single input coordinate
xobs = np.zeros((num_pars, 1))
xobs[:, 0] = np.log10(mmbulge_norm)

#['eccs_mu', 'hard_gamma', 'MM2013_amp', 'MM2013_slope', 'tdelay']
# for ii in range(120):
#     xobs[ii,0] = spectra['eccs_mu'][ii]
#     xobs[ii,1] = spectra['hard_gamma'][ii]
#     xobs[ii,2] = spectra['MM2013_amp'][ii]
#     xobs[ii,3] = spectra['MM2013_slope'][ii]
#     xobs[ii,4] = spectra['tdelay'][ii]

In [None]:
# Instantiate a list of GP kernels and models [one for each frequency]
gp_george = []
k = []

for freq_ind in range(len(freqs)):
    # each model gets the column of `yobs` / `yerr` belonging to its frequency
    gp_george.append(GP(xobs, yobs[:, freq_ind], yerr[:, freq_ind]))
    k.append(1.0 * kernels.ExpSquaredKernel([2.0], ndim=1))

# NOTE: relies on `freq_ind` still holding the last loop index, i.e. this uses the last kernel
# (presumably `len()` of a george kernel is its number of parameters -- confirm)
num_kpars = len(k[freq_ind])

print(num_kpars)

In [None]:
import time
# Sample the posterior distribution of the kernel parameters 
# to find MAP value for each frequency. 

# THIS WILL TAKE A WHILE... (~ 1 min per frequency)

# Placeholder list; entries are replaced by emcee samplers as each frequency is processed
sampler = [0.0]*len(freqs)
for freq_ind in range(len(freqs)):
    t_start = time.time()
    
    # Set up the sampler.
    nwalkers, ndim = 36, num_kpars
    sampler[freq_ind] = emcee.EnsembleSampler(nwalkers, ndim, gp_george[freq_ind].lnprob)

    # Initialize the walkers.
    # NOTE: the starting point is hard-coded with two values, so this assumes ndim == 2
    p0 = [np.log([1.0, 1.0]) + 1e-4 * np.random.randn(ndim) for i in range(nwalkers)]

    print(freq_ind, "Running burn-in")
    p0, lnp, _ = sampler[freq_ind].run_mcmc(p0, 750)
    sampler[freq_ind].reset()

    print(freq_ind, "Running second burn-in")
    # restart all walkers in a tight ball around the best position of the first burn-in
    p = p0[np.argmax(lnp)]
    p0 = [p + 1e-8 * np.random.randn(ndim) for i in range(nwalkers)]
    p0, _, _ = sampler[freq_ind].run_mcmc(p0, 750)
    sampler[freq_ind].reset()

    print(freq_ind, "Running production")
    p0, _, _ = sampler[freq_ind].run_mcmc(p0, 1500)
    
    print('Completed in {} min'.format((time.time()-t_start)/60.) , '\n')
    # Stop after frequency index 2: at most three frequencies are ever sampled, and any
    # later entries of `sampler` keep their 0.0 placeholder
    if freq_ind > 1:
        break

In [None]:
import corner

In [None]:
## Let's take a look at the posterior distribution of the 
# kernel parameters at a frequency [ind] of our choice.

# must be an index for which the sampling cell actually ran (it stops after index 2)
ind = 1

fig = corner.corner(sampler[ind].flatchain, bins=50)
plt.show()

In [None]:
## Populate the GP class with the details of the kernel
## MAP values for each frequency.

for ii in range(len(freqs)):
    gp_george[ii].chain = None
    # NOTE: `gp_george[ii].lnprob` is deliberately NOT set to None here.  `lnprob` is a method
    # of the GP class; assigning None to the instance attribute shadowed it, so re-running
    # the sampling cell afterwards failed because the log-probability was no longer callable.

    # Sample with the highest log-probability in the flattened chain
    gp_george[ii].kernel_map = sampler[ii].flatchain[np.argmax(sampler[ii].flatlnprobability)]
    #print(ii, gp_george[ii].kernel_map)

    # add-in mean yobs (freq) values
    gp_george[ii].mean_spectra = yobs_mean[ii]
    # Only the first three frequencies were sampled
    if ii > 1:
        break

In [None]:
## Set-up GP predictions ##
# If you are running this part of the code separately from the section above, 
# you will need to re-define the GP class from above for this step to work!

gp = []
# GP_freqs = np.arange(1.,31.) / (20*365.25*86400.) 

# for ii in range(len(GP_freqs)):
# NOTE: hard-coded to three frequencies, matching the number for which `kernel_map` was stored
for ii in range(3):
    # `kernel_map` holds log parameters: amplitude first, then the kernel value(s)
    gp_kparams = np.exp(gp_george[ii].kernel_map)

    gp.append(george.GP(gp_kparams[0] * \
            george.kernels.ExpSquaredKernel(gp_kparams[1:],ndim=len(gp_kparams[1:])) ) )

    gp[ii].compute(gp_george[ii].x, gp_george[ii].yerr)

In [None]:
# Display the `mmbulge_norm` values (the GP input coordinate before taking log10)
mmbulge_norm

In [None]:
## Make a realization from the GP ##
PAR = 3

#  A reminder of the spectra parameters:
# ['eccs_mu', 'hard_gamma', 'MM2013_amp', 'MM2013_slope', 'tdelay']
# env_param = np.array([5.6249, -0.0807,  8.8394,  1.284 ,  5.9822])
# env_param = np.array([41.0])
# Predict at one of the training points: log10(mmbulge_norm) of parameter index PAR
env_param = np.log10(np.array([mmbulge_norm[PAR]]))

# rho_pred = np.zeros((len(GP_freqs), 2))
# Columns: [predicted mean, predicted uncertainty]; three rows for the three trained frequencies
rho_pred = np.zeros((3, 2))
for ii, freq in enumerate(freqs):
    mu_pred, cov_pred = gp[ii].predict(gp_george[ii].y, [env_param])
    if np.diag(cov_pred) < 0.0:
        # Negative predicted variance: fall back to a small fraction of the mean as uncertainty
        rho_pred[ii, 0], rho_pred[ii, 1] = mu_pred, 1e-5 * mu_pred
        # This branch used to `print(bad)`, where `bad` is undefined, so it raised a
        # NameError exactly when the fallback was needed; report the problem instead.
        log.warning(f"Negative predicted variance at frequency index {ii}; using 1e-5 * mu_pred instead")
    else:
        rho_pred[ii, 0], rho_pred[ii, 1] = mu_pred, np.sqrt(np.diag(cov_pred))

    # Only the first three frequencies have a trained GP
    if ii > 1:
        break

## transforming from the mean-subtracted variable back to rho, by adding back the
## per-frequency mean that was removed before training
rho = np.array([gp_george[ii].mean_spectra for ii in range(len(freqs))]) + rho_pred[:, 0]

# rho is log10 of the squared spectrum, so this is the spectrum itself
hc = np.sqrt(10**rho)

In [None]:
## Making a plot ##

# the raw spectra #
# for ii in range(100):
# transpose so that each realization is drawn as its own line over frequency
plt.loglog(freqs, gwb[PAR, :, :NFREQ].T, color='C0', alpha=0.2, zorder=0)

# plt.loglog(spectra['freqs'][:30]/(365.25*86400.), spectra['gwb'][3,:30,ii], color='C0', alpha=0.2, zorder=0, label='Original Spectra')

# the smoothed mean #
# `smooth_mean` is log10 of the mean squared spectrum, hence the sqrt(10**...) conversion
plt.loglog(freqs, np.sqrt(10**smooth_mean[PAR, :NFREQ]), color='C1', ls='--', label='Smoothed Mean', lw=2)

# the GP realization #
plt.semilogx(freqs, hc, color='C3', lw=2.5, label='GP')
# band of +/- one predicted standard deviation (applied in log10 space)
plt.fill_between(freqs, np.sqrt(10**(rho + rho_pred[:, 1])), np.sqrt(10**(rho - rho_pred[:, 1])), color='C3', alpha=0.5)

plt.xlabel('Observed GW Frequency [yr$^{-1}$]')
# plt.xlim(1e-9,7e-8)
plt.ylabel(r'$h_{c} (f)$')
# plt.ylim(1e-16, 1e-13)

plt.legend(loc=3)
#plt.savefig('./TrainedGP.pdf', bbox_inches='tight', dpi=500)

In [None]:
# Display the shapes of the merged `gwb` array, the training frequencies and the GP prediction
gwb.shape, freqs.shape, hc.shape