In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py


from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim

# Setup

## Load vary param rv's

In [None]:
# Load the saved results of the run that varies `hard_time`.
# NOTE: allow_pickle=True lets unpickling run arbitrary code -- only load trusted files.
# The archive is deliberately not closed here: a later cell reads `npz_hard_time` again.
npz_hard_time = np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/hard_time_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz',
                        allow_pickle=True)
print(npz_hard_time.files)

(data, params, hard_name,
 shape, target_param) = (npz_hard_time[key] for key in
                         ('data', 'params', 'hard_name', 'shape', 'target_param'))

## Get parameter-space (`pspace`) info

In [None]:
# In the visible cells only `param_names` is read from `pspace`, so a single sample suffices.
pspace = holo.param_spaces.PS_Uniform_09A(
    holo.log,
    nsamples=1,
    sam_shape=shape,
    seed=None,
)
param_names = pspace.param_names
print(param_names)

In [None]:
# Display labels for the six varied parameters. Both arrays are matched
# positionally to `param_names` (assumed to share its order -- confirm against
# the printed `param_names`).
# LaTeX (mathtext) labels: every entry must be a balanced `$...$` expression.
math_param_names = np.array([r'$t_\mathrm{hard}$', r'$\Phi_0$', r'$M_\mathrm{char,0}$',
                             r'$\log \mu$', r'$\epsilon_\mu$', r'$\gamma_\mathrm{inner}$'])
# Plain-text short names used by `draw_sample_text`. That text is drawn with
# parse_math=True, so these must not contain stray `$` characters.
short_param_names = np.array(['t_hard', 'Phi0', 'Mchar0',
                              'Mamp', 'scatter', 'gam_in'])

# General Plotting Functions

In [None]:
def draw_sample_text(fig, params, param_names,
                     xx=0.1, yy=-0.025, fontsize=10, color='k', short_names=None):
    """Write a one-line ``label=value, label=value`` summary of a parameter sample onto `fig`.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Anything with a ``text(x, y, s, **kwargs)`` method.
    params : mapping-like
        Indexed by each entry of `param_names` to obtain a number (e.g. a dict).
    param_names : sequence of str
        Keys to report, in display order.
    xx, yy : float
        Figure coordinates of the text anchor.
    fontsize : float
    color : matplotlib color
    short_names : sequence of str, optional
        Display labels matched positionally to `param_names`. Defaults to the
        notebook-level `short_param_names`, which used to be the only option.

    Raises
    ------
    ValueError
        If fewer display labels than parameter names are available.
    """
    if short_names is None:
        short_names = short_param_names
    # zip() would silently drop parameters, so check the lengths up front.
    if len(short_names) < len(param_names):
        raise ValueError(
            f"{len(param_names)} parameters but only {len(short_names)} display labels")
    # join() avoids the trailing ', ' the old concatenation loop left behind
    text = ', '.join('%s=%.2e' % (label, params[name])
                     for label, name in zip(short_names, param_names))
    fig.text(xx, yy, text, fontsize=fontsize, color=color, alpha=0.75,
             parse_math=True)

# def draw_bg_par_vs_freq(ax, xx, xx_ss=None, yy_ss=None, xx_bg=None, yy_bg=None, 
#                      color_ss='r', color_bg='k', colors=None,
#                      show_ss_medians=True, 
#                      fast, show_reals):
#     if show_ss_medians:
#         ax.

    # if show_reals:
    #     if fast:
    #         ax.scatter(xx_ss.flatten(), yy_ss.flatten(), marker='o', alpha=0.1, color=color_ss)
    #         ax.scatter(xx_bg.flatten(), yy_bg.flatten(), marker='x', alpha=0.1, color=color_bg)
        # else:
        #     for rr in range(nreals):
        #         for ll in range(nloudest):
        #             edgecolor='k' if ll==0 else None
        #             ax.scatter(xx, yy_ss[:,rr,ll], marker='o', s=15, alpha=0.1, color=colors[rr], edgecolor=edgecolor)
        #         ax.plot(xx, yy_bg[:,rr], marker='x', alpha=0.1, color=colors[rr])

# Plot mass and characteristic strain

In [None]:
def draw_mass_hc_vs_freq(fig, axs, fobs_cents, hc_ss, hc_bg, sspar, bgpar,
                         color_ss='r', color_bg='k', ls_ss=':', ls_bg='-', lw_ss=2, lw_bg=2,
                         fast=True, show_reals=True):
    """Draw mass (row 0) and characteristic strain (row 1) against GW frequency.

    Column 0 shows the background only: its median over realizations plus the
    50% and 98% percentile bands. Column 1 shows the background median, the
    loudest single source of every realization (faint scatter), and the
    median +/- standard deviation of those loudest sources. Each column-1 axis
    shares its y-axis with the column-0 axis in the same row.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Returned unchanged so calls can be chained.
    axs : (2, 2) array of Axes
    fobs_cents : (F,) array
        Frequency bin centers. They are multiplied by ``YR`` for the x-axis
        (presumably given in Hz -- TODO confirm).
    hc_ss : (F, R, L) array
        Strain of the L loudest single sources per frequency bin and
        realization. Index 0 of the last axis is treated as the loudest.
    hc_bg : (F, R) array
        Background characteristic strain per frequency bin and realization.
    sspar : (P, F, R, L) array
        Single-source parameters. Only row 0 (mass) is plotted, scaled by
        ``sings.par_units[0]``.
    bgpar : (P, F, R) array
        Background parameters. Only row 0 (mass) is plotted.
    color_ss, color_bg : matplotlib colors
    ls_ss, lw_ss : line style / width of the single-source median errorbar line.
    ls_bg, lw_bg : line style / width of the background median lines.
    fast, show_reals : bool
        Currently unused; kept so existing calls keep working.

    Returns
    -------
    fig
    """
    idx = [0,]  # parameter rows to plot: mass only
    nfreqs, nreals, nloudest = np.shape(hc_ss)
    xx = np.asarray(fobs_cents) * YR

    # Stack [mass, h_c] along a new leading axis. concatenate (rather than
    # flatten + reshape) raises on a shape mismatch instead of scrambling data.
    units = sings.par_units[idx]
    mass_ss = sspar[idx] * units[:, np.newaxis, np.newaxis, np.newaxis]  # (1, F, R, L)
    yy_ss = np.concatenate([mass_ss, np.asarray(hc_ss)[np.newaxis]], axis=0)  # (2, F, R, L)
    mass_bg = bgpar[idx] * units[:, np.newaxis, np.newaxis]  # (1, F, R)
    yy_bg = np.concatenate([mass_bg, np.asarray(hc_bg)[np.newaxis]], axis=0)  # (2, F, R)

    # x value of every (frequency, realization) pair, flattened in C order
    xx_per_real = np.repeat(xx, nreals)

    for row, ax in enumerate(axs[:, 0]):  # left column: background only
        ax.plot(xx, np.median(yy_bg[row], axis=-1), color=color_bg,
                linestyle=ls_bg, linewidth=lw_bg, alpha=0.75)
        for pp in [50, 98]:
            conf = np.percentile(yy_bg[row], [50-pp/2, 50+pp/2], axis=-1)
            ax.fill_between(xx, *conf, color=color_bg, alpha=0.25)

    for row, ax in enumerate(axs[:, 1]):  # right column: bg median and loudest single sources
        ax.plot(xx, np.median(yy_bg[row], axis=-1), color=color_bg,
                linestyle=ls_bg, linewidth=lw_bg, alpha=0.75)
        loudest = yy_ss[row, :, :, 0]  # (F, R): loudest source in each realization
        ax.scatter(xx_per_real, loudest.flatten(), color=color_ss, alpha=0.1, s=15)
        ax.errorbar(xx, np.median(loudest, axis=-1), yerr=np.std(loudest, axis=-1),
                    color=color_ss, linestyle=ls_ss, linewidth=lw_ss, alpha=0.5,
                    marker='o', markersize=5, capsize=5)
        ax.sharey(axs[row, 0])

    return fig


def plot_mass_hc_vs_freq(
        data, params, hard_name, shape, target_param,
        datcolor_ss=('limegreen', 'cornflowerblue', 'tomato'),
        datcolor_bg=('#003300', 'darkblue', 'darkred'),
        datlw=(3, 4, 5),
        dattext_yy=(-0.02, -0.05, -0.08)):
    """Overlay mass and h_c vs. frequency for several runs that vary one parameter.

    Parameters
    ----------
    data : sequence of dict-like
        One entry per run. Each provides 'fobs_cents', 'hc_ss', 'hc_bg',
        'sspar' and 'bgpar' (see `draw_mass_hc_vs_freq`). The frequency bins
        of the first run are used for all runs.
    params : sequence
        Per-run parameter samples, indexed by the notebook-level `param_names`
        and written below the figure.
    hard_name, shape, target_param :
        Only used to build the figure title.
    datcolor_ss, datcolor_bg, datlw, dattext_yy : sequences
        Per-run single-source color, background color, line width, and
        figure-y position of the sample text. Each needs at least len(data)
        entries. Tuples are used as defaults instead of numpy arrays so no
        mutable object is shared between calls.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If `data` is empty or has more runs than the per-run style sequences.
    """
    ndata = len(data)
    if ndata == 0:
        raise ValueError("`data` is empty; nothing to plot.")
    nstyles = min(len(datcolor_ss), len(datcolor_bg), len(datlw), len(dattext_yy))
    if ndata > nstyles:
        raise ValueError(
            f"got {ndata} datasets but only {nstyles} per-run styles; "
            "pass longer datcolor_ss/datcolor_bg/datlw/dattext_yy")

    fobs_cents = data[0]['fobs_cents']
    fig, axs = holo.plot.figax(
        nrows=2, ncols=2, sharex=True, figsize=(10,6))

    idx = [0,]  # mass only, matching the rows drawn by draw_mass_hc_vs_freq
    labels = np.append(sings.par_labels[idx],
                       np.array([plot.LABEL_CHARACTERISTIC_STRAIN]))
    for ax in axs[-1]:
        ax.set_xlabel(holo.plot.LABEL_GW_FREQUENCY_YR)
    for ax, label in zip(axs[:, 0], labels):
        ax.set_ylabel(label)

    for ii, dat in enumerate(data):
        fig = draw_mass_hc_vs_freq(fig, axs, fobs_cents, dat['hc_ss'], dat['hc_bg'],
                                   dat['sspar'], dat['bgpar'],
                                   color_ss=datcolor_ss[ii], color_bg=datcolor_bg[ii],
                                   lw_bg=datlw[ii], lw_ss=datlw[ii])
        draw_sample_text(fig, params[ii], param_names, color=datcolor_bg[ii],
                         yy=dattext_yy[ii], xx=0, fontsize=12)
    fig.suptitle("%s, %s, Varying '%s'" % (hard_name, str(shape), target_param))
    fig.tight_layout()

    return fig

# fig = draw_three_models(data = data_hard_time, params = params_hard_time,
#                         hard_name=hard_name, shape=sam.shape, target_param=target_param)

## Varying `hard_time`

In [None]:
npz = npz_hard_time  # reuse the archive opened in the setup cell
fig = plot_mass_hc_vs_freq(
    npz['data'],
    npz['params'],
    npz['hard_name'],
    npz['shape'],
    npz['target_param'],
)

## Varying `gsmf_phi0`

In [None]:
# Load the `gsmf_phi0` run and plot it. The context manager closes the .npz
# archive once the arrays have been read (np.load keeps the file open otherwise).
# NOTE: allow_pickle=True is needed for the object arrays; only load trusted files.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/gsmf_phi0_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz',
             allow_pickle=True) as npz:
    fig = plot_mass_hc_vs_freq(npz['data'], npz['params'], npz['hard_name'], npz['shape'],
                               npz['target_param'],)

## Varying `gsmf_mchar0_log10`

In [None]:
# Load the `gsmf_mchar0_log10` run and plot it. The context manager closes the
# .npz archive once the arrays have been read (np.load keeps the file open otherwise).
# NOTE: allow_pickle=True is needed for the object arrays; only load trusted files.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/gsmf_mchar0_log10_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz',
             allow_pickle=True) as npz:
    fig = plot_mass_hc_vs_freq(npz['data'], npz['params'], npz['hard_name'], npz['shape'],
                               npz['target_param'],)

## Varying `mmb_mamp_log10`

In [None]:
# Load the `mmb_mamp_log10` run and plot it. The context manager closes the
# .npz archive once the arrays have been read (np.load keeps the file open otherwise).
# NOTE: allow_pickle=True is needed for the object arrays; only load trusted files.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/mmb_mamp_log10_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz',
             allow_pickle=True) as npz:
    fig = plot_mass_hc_vs_freq(npz['data'], npz['params'], npz['hard_name'], npz['shape'],
                               npz['target_param'],)

## Varying `mmb_scatter_dex`

In [None]:
# Load the `mmb_scatter_dex` run and plot it. The context manager closes the
# .npz archive once the arrays have been read (np.load keeps the file open otherwise).
# NOTE: allow_pickle=True is needed for the object arrays; only load trusted files.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/mmb_scatter_dex_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz',
             allow_pickle=True) as npz:
    fig = plot_mass_hc_vs_freq(npz['data'], npz['params'], npz['hard_name'], npz['shape'],
                               npz['target_param'],)

## Varying `hard_gamma_inner`

In [None]:
# Load the `hard_gamma_inner` run and plot it. The context manager closes the
# .npz archive once the arrays have been read (np.load keeps the file open otherwise).
# NOTE: allow_pickle=True is needed for the object arrays; only load trusted files.
with np.load('/Users/emigardiner/GWs/holodeck/ecg-notebooks/parameter_investigation/anatomy_uniform09A_fullshape/hard_gamma_inner_p0.5_0.5_0.5_0.5_0.5_0.5_s91_81_101.npz',
             allow_pickle=True) as npz:
    fig = plot_mass_hc_vs_freq(npz['data'], npz['params'], npz['hard_name'], npz['shape'],
                               npz['target_param'],)