In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import h5py
import matplotlib as mpl
from tqdm import tqdm

from holodeck import plot, detstats
import holodeck.single_sources as sings
from holodeck.constants import YR, MSOL, MPC
import holodeck as holo

import hasasia.sim as hsim
import os



In [None]:
# --- Run dimensions (encoded in the saved data file/directory names by get_data) ---
NVARS = 21      # variations per parameter (`_v{...}` in paths); alternative kept from development: 6
NREALS = 500    # `_r{...}` in paths; alternative kept from development: 20
NSKIES = 100    # `_s{...}` in legacy paths; alternative kept from development: 15
SHAPE = None    # `_shape{...}` in paths
NFREQS = 40
NLOUDEST = 10   # number of loudest single sources drawn per variation in the plots

# --- Run switches ---
BUILD_ARRAYS = False
SAVEFIG = False

# --- Further settings, not referenced by the plotting code in this notebook section ---
NPSRS = 40
TOL = 0.01
MAXBADS = 5

# Parameters whose variations are examined; each plotting cell below handles one.
PARAM_NAMES = [
    'gsmf_phi0', 'gsmf_mchar0_log10',
    'mmb_mamp_log10', 'mmb_scatter_dex',
    'hard_time', 'hard_gamma_inner',
]

In [None]:
def get_data(
        target, nvars=NVARS, nreals=NREALS, nskies=NSKIES, shape=SHAPE, red_gamma = None, red2white=None,
    path = '/Users/emigardiner/GWs/holodeck/output/anatomy_redz'     
):
    """Load the saved ``data`` and ``params`` arrays for all variations of one parameter.

    Parameters
    ----------
    target : str
        Name of the varied parameter; the first component of the file name.
    nvars, nreals, nskies, shape
        Run dimensions encoded in the file/directory name. ``nskies`` only appears
        in the legacy ``anatomy_09B`` layout.
    red_gamma, red2white
        Accepted for interface compatibility; currently unused.
    path : str
        Output directory. A directory named ``anatomy_09B`` uses the legacy flat
        layout ``{target}_v{nvars}_r{nreals}_s{nskies}_shape{shape}.npz``; any other
        directory uses ``{target}_v{nvars}_r{nreals}_shape{shape}/data_params.npz``.
        NOTE(review): the default is an absolute, machine-specific path; consider
        passing ``path`` explicitly from a configurable constant.

    Returns
    -------
    data, params
        The ``'data'`` and ``'params'`` entries of the npz archive.

    Raises
    ------
    FileNotFoundError
        If the expected file does not exist. This subclasses ``Exception``, so
        existing ``except Exception`` handlers still catch it.
    """
    # Compare only the final directory name so trailing slashes or relative
    # spellings of the legacy directory are recognised too.
    legacy_layout = os.path.basename(os.path.normpath(path)) == 'anatomy_09B'
    if legacy_layout:
        load_data_from_file = os.path.join(
            path, f'{target}_v{nvars}_r{nreals}_s{nskies}_shape{str(shape)}.npz')
    else:
        load_data_from_file = os.path.join(
            path, f'{target}_v{nvars}_r{nreals}_shape{str(shape)}', 'data_params.npz')

    if not os.path.exists(load_data_from_file):
        err = f"load data file '{load_data_from_file}' does not exist, you need to construct it."
        raise FileNotFoundError(err)

    # allow_pickle=True unpickles object arrays on load, which can execute code:
    # only use with trusted, locally generated files.
    # The context manager guarantees the archive is closed even if a key is missing.
    with np.load(load_data_from_file, allow_pickle=True) as file:
        data = file['data']
        params = file['params']

    return data, params

# Plot characteristic strains (total, single-source, background) for all variations

In [None]:
def plot_char_strains(target, data, params):
    """Plot median characteristic strain vs. GW frequency for every variation of `target`.

    Three panels share both axes: total strain (left), the loudest single sources
    (middle; the loudest is drawn solid and thicker, the others dashed and fainter)
    and the background (right). Each variation gets its own colour from the
    ``rainbow`` colormap, and a horizontal colorbar maps colours to parameter values.

    Parameters
    ----------
    target : str
        Parameter name; used as key into each ``params[vv]`` and into
        ``plot.PARAM_KEYS`` for the colorbar label.
    data : sequence of dict-like
        One entry per variation, with keys ``'fobs_cents'``, ``'hc_ss'``, ``'hc_bg'``.
        Assumed shapes (TODO confirm): ``hc_bg`` (F, R) and ``hc_ss`` (F, R, L),
        with ``len(fobs_cents) == F``.
    params : sequence of dict-like
        One entry per variation; ``params[vv][target]`` is the varied value.

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    nvars = len(data)
    colors = cm.rainbow(np.linspace(0, 1, nvars))
    xlabel = plot.LABEL_GW_FREQUENCY_YR
    ylabel = plot.LABEL_CHARACTERISTIC_STRAIN

    # Frequency bin centres in 1/yr, matching the x-axis label.
    xx = data[0]['fobs_cents'] * YR
    parvars = []

    fig, axs = plot.figax_single(ncols=3, figsize=(12, 4), sharey=True, sharex=True)
    fig.text(0.5, 0.02, xlabel, va='center', ha='center')
    axs[0].set_ylabel(ylabel)

    # Loudest source: thick solid line; the remaining sources: thin dashed lines.
    lw_ss = np.ones(NLOUDEST)
    lw_ss[0] = 2
    ls_ss = np.repeat('--', NLOUDEST)
    ls_ss[0] = '-'

    for vv in range(nvars):
        hc_ss = data[vv]['hc_ss']
        hc_bg = data[vv]['hc_bg']
        # Total strain: background plus all single sources, added in quadrature.
        hc_tot = np.sqrt(hc_bg**2 + np.sum(hc_ss**2, axis=-1))
        parvars.append(params[vv][target])

        # Left and right panels: median over the last axis (presumably realizations).
        for ii, yy in ((0, hc_tot), (2, hc_bg)):
            axs[ii].plot(xx, np.median(yy, axis=-1), color=colors[vv], alpha=0.5, lw=1)

        # Middle panel: each of the loudest sources; guard against fewer than NLOUDEST stored.
        nloud = min(NLOUDEST, hc_ss.shape[-1])
        for ii in range(nloud):
            alpha = 0.5 if ii == 0 else 0.25
            axs[1].plot(xx, np.median(hc_ss[..., ii], axis=-1), color=colors[vv],
                        alpha=alpha, lw=lw_ss[ii], ls=ls_ss[ii])

    fig.subplots_adjust(wspace=0)
    fig.text(0.365, 0.92, 'hc_tot', ha='left', va='top')
    fig.text(0.635, 0.92, 'hc_ss', ha='left', va='top')
    fig.text(0.9, 0.92, 'hc_bg', ha='left', va='top')

    # Colours were assigned by variation index, so the colorbar spans first..last value.
    norm = mpl.colors.Normalize(vmin=parvars[0], vmax=parvars[-1])
    cax = fig.add_axes([0.25, -0.04, 0.5, 0.02])
    mpl.colorbar.ColorbarBase(cax, cmap=mpl.cm.rainbow, norm=norm, orientation='horizontal',
                              label=plot.PARAM_KEYS[target])
    return fig


In [None]:
def plot_masses(target, data, params):
    """Plot median binary mass vs. GW frequency for every variation of `target`.

    Left panel: mass of the loudest single sources (the loudest drawn solid and
    thicker, the others dashed and fainter). Right panel: background mass.
    Each variation gets its own colour from the ``rainbow`` colormap, and a
    horizontal colorbar maps colours to parameter values.

    Parameters
    ----------
    target : str
        Parameter name; used as key into each ``params[vv]`` and into
        ``plot.PARAM_KEYS`` for the colorbar label.
    data : sequence of dict-like
        One entry per variation, with keys ``'fobs_cents'``, ``'sspar'``, ``'bgpar'``.
        Index 0 of ``sspar``/``bgpar`` is read as mass and divided by ``MSOL``.
        Assumed shapes (TODO confirm): ``sspar[0]`` (F, R, L) and ``bgpar[0]`` (F, R).
    params : sequence of dict-like
        One entry per variation; ``params[vv][target]`` is the varied value.

    Returns
    -------
    fig : matplotlib.figure.Figure
    """
    nvars = len(data)
    colors = cm.rainbow(np.linspace(0, 1, nvars))
    xlabel = plot.LABEL_GW_FREQUENCY_YR
    # Raw string: '\m' and '\o' are invalid Python escapes (SyntaxWarning on 3.12+).
    ylabel = r"Mass [$\mathrm{M}_\odot$]"

    # Frequency bin centres in 1/yr, matching the x-axis label.
    xx = data[0]['fobs_cents'] * YR
    parvars = []

    fig, axs = plot.figax_single(ncols=2, figsize=(8, 4), sharey=True, sharex=True)
    fig.text(0.5, 0.02, xlabel, va='center', ha='center')
    axs[0].set_ylabel(ylabel)

    # Loudest source: thick solid line; the remaining sources: thin dashed lines.
    lw_ss = np.ones(NLOUDEST)
    lw_ss[0] = 2
    ls_ss = np.repeat('--', NLOUDEST)
    ls_ss[0] = '-'

    for vv in range(nvars):
        mt_ss = data[vv]['sspar'][0] / MSOL
        mt_bg = data[vv]['bgpar'][0] / MSOL
        parvars.append(params[vv][target])

        # Right panel: background mass, median over the last axis (presumably realizations).
        axs[1].plot(xx, np.median(mt_bg, axis=-1), color=colors[vv], alpha=0.5, lw=1)

        # Left panel: each of the loudest sources; guard against fewer than NLOUDEST stored.
        nloud = min(NLOUDEST, mt_ss.shape[-1])
        for ii in range(nloud):
            alpha = 0.5 if ii == 0 else 0.25
            axs[0].plot(xx, np.median(mt_ss[..., ii], axis=-1), color=colors[vv],
                        alpha=alpha, lw=lw_ss[ii], ls=ls_ss[ii])

    fig.subplots_adjust(wspace=0)
    fig.text(0.48, 0.92, r'$M_\mathrm{SS}$', ha='left', va='top', fontsize=14)
    fig.text(0.86, 0.92, r'$\langle M \rangle_\mathrm{BG}$', ha='left', va='top', fontsize=14)

    # Colours were assigned by variation index, so the colorbar spans first..last value.
    norm = mpl.colors.Normalize(vmin=parvars[0], vmax=parvars[-1])
    cax = fig.add_axes([0.25, -0.04, 0.5, 0.02])
    mpl.colorbar.ColorbarBase(cax, cmap=mpl.cm.rainbow, norm=norm, orientation='horizontal',
                              label=plot.PARAM_KEYS[target])
    return fig


In [None]:
# gsmf_phi0: load explicitly so this cell works on a fresh kernel (Restart & Run All).
target = 'gsmf_phi0'
data1, params1 = get_data(target)
fig = plot_char_strains(target, data1, params1)
fig = plot_masses(target, data1, params1)

In [None]:
# gsmf_mchar0_log10: load explicitly so this cell works on a fresh kernel (Restart & Run All).
target = 'gsmf_mchar0_log10'
data2, params2 = get_data(target)
fig = plot_char_strains(target, data2, params2)
fig = plot_masses(target, data2, params2)

In [None]:
# mmb_mamp_log10: load explicitly so this cell works on a fresh kernel (Restart & Run All).
target = 'mmb_mamp_log10'
data3, params3 = get_data(target)
fig = plot_char_strains(target, data3, params3)
fig = plot_masses(target, data3, params3)

In [None]:
# mmb_scatter_dex: load every variation, then draw the strain and mass figures.
target = 'mmb_scatter_dex'
data4, params4 = get_data(target)
for plot_func in (plot_char_strains, plot_masses):
    fig = plot_func(target, data4, params4)

In [None]:
# hard_time: load every variation, then draw the strain and mass figures.
target = 'hard_time'
data5, params5 = get_data(target)
for plot_func in (plot_char_strains, plot_masses):
    fig = plot_func(target, data5, params5)

In [None]:
# hard_gamma_inner: load every variation, then draw the strain and mass figures.
target = 'hard_gamma_inner'
data6, params6 = get_data(target)
for plot_func in (plot_char_strains, plot_masses):
    fig = plot_func(target, data6, params6)