In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import kalepy as kale
from datetime import datetime
import tqdm

# NOTE: this is installed with `pip install matplotlib-label-lines`
from labellines import labelLine, labelLines

import holodeck as holo
import holodeck.librarian
from holodeck import plot, utils, cosmo
from holodeck.constants import YR, GYR, MSOL, PC
# from holodeck.sams import cyutils as sam_cyutils

In [None]:
# Grid shape handed to the parameter space / model constructors as `sam_shape`
SHAPE = 13
# Number of realizations requested from each model run (`nreals` / `realize`)
NREALS = 10
# Number of loudest single sources requested by `old_run_model` when `singles_flag` is set
NLOUDEST = 5

Define a function that will run a given combination of a SAM and hardening model

In [None]:
def old_run_model(sam, hard, singles_flag=False, params_flag=False, gwb_flag=True, details_flag=False):
    """Run one SAM + hardening model combination and collect the requested data products.

    Parameters
    ----------
    sam : object
        Semi-analytic model instance; its `mtot`, `mrat` and `redz` attributes are used as grid edges.
    hard : object
        Hardening model instance passed through to `dynamic_binary_number_at_fobs`.
    singles_flag : bool
        If True, store the single-source and background strains (`hc_ss`, `hc_bg`),
        requesting `NLOUDEST` loudest sources.
    params_flag : bool
        If True, store the binary parameters returned alongside the strains (`sspar`, `bgpar`).
    gwb_flag : bool
        If True, store the GWB computed over `NREALS` realizations under `gwb`.
    details_flag : bool
        If True, store the integrated `number` grid and `redz_final`.

    Returns
    -------
    data : dict
        Always contains `fobs` (GW frequency bin centers); other keys depend on the flags above.

    """
    # The notebook-level import of the cython utilities is commented out, so import them here;
    # otherwise every call fails with a NameError on `sam_cyutils`.
    # NOTE(review): the first path is the one from the commented-out notebook import; the fallback
    # name is presumably what newer holodeck versions use -- confirm against the installed version.
    try:
        from holodeck.sams import cyutils as sam_cyutils
    except ImportError:
        from holodeck.sams import sam_cyutils

    fobs_cents, fobs_edges = holo.librarian.get_freqs(None)
    fobs_orb_cents = fobs_cents / 2.0     # convert from GW to orbital frequencies
    fobs_orb_edges = fobs_edges / 2.0     # convert from GW to orbital frequencies

    data = dict(fobs=fobs_cents)

    redz_final, diff_num = sam_cyutils.dynamic_binary_number_at_fobs(
        fobs_orb_cents, sam, hard, cosmo
    )
    use_redz = redz_final
    edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
    number = sam_cyutils.integrate_differential_number_3dx1d(edges, diff_num)
    if details_flag:
        data['number'] = number
        data['redz_final'] = redz_final

    # calculate single sources and/or binary parameters
    if singles_flag or params_flag:
        # when only the parameters are wanted, a single loudest source is requested
        nloudest = NLOUDEST if singles_flag else 1

        vals = holo.single_sources.ss_gws_redz(
            edges, use_redz, number, realize=NREALS,
            loudest=nloudest, params=params_flag,
        )
        # the return tuple has two extra entries when parameters are requested
        if params_flag:
            hc_ss, hc_bg, sspar, bgpar = vals
            data['sspar'] = sspar
            data['bgpar'] = bgpar
        else:
            hc_ss, hc_bg = vals

        if singles_flag:
            data['hc_ss'] = hc_ss
            data['hc_bg'] = hc_bg

    if gwb_flag:
        gwb = holo.gravwaves._gws_from_number_grid_integrated_redz(edges, use_redz, number, NREALS)
        data['gwb'] = gwb

    return data

# Compare simulations

In [None]:
# Exploratory cell: build the parameter space and check which normalized value maps to hard_time = 1.0.
import holodeck.librarian.old_param_spaces
# `nsamples=1` and `seed=None` are passed only to satisfy the constructor call here
pspace = holo.librarian.old_param_spaces.PS_Uniform_07B_Rot(holo.log, nsamples=1, sam_shape=SHAPE, seed=None)

# get the parameter names from this library-space
param_names = pspace.param_names.copy() # use copy method so that we can reorder param_names later if we want.
num_pars = len(pspace.param_names)
print(f"{num_pars=} :: {param_names=}")

# Start from the mid-point (0.5) of every normalized parameter range ...
fiducial_pars = 0.5 * np.ones(num_pars)
fiducial_parameters = pspace._normalized_params(fiducial_pars)
# ... then overwrite the physical hardening time directly with 1.0
fiducial_parameters['hard_time'] = 1.0
#print(fiducial_parameters)
# Sanity check: the first print is the normalized value expected to give hard_time = 1.0
# (presumably assuming hard_time is drawn uniformly from [0.1, 11.0] -- TODO confirm against the
# parameter-space definition); the second print maps that normalized value back through the
# parameter space so the two can be compared by eye.
print((1.0 - 0.1)/10.9)
print(pspace._normalized_params([0.08256880733944955, 0.5, 0.5, 0.5, 0.5, 0.5])['hard_time'])

In [None]:
# Build the fiducial model and a GW-only comparison model, and initialise the containers
# (`alldata`, `allparams_list`, `mingwb`, `maxgwb`) that the parameter sweep below fills in.
import holodeck.librarian.old_param_spaces
import holodeck.librarian.libraries

# construct a param_space instance, note that `nsamples` and `seed` don't matter here for how we'll use this
pspace = holo.librarian.old_param_spaces.PS_Uniform_07B_Rot(holo.log, nsamples=1, sam_shape=SHAPE, seed=None)

# get the parameter names from this library-space
param_names = pspace.param_names.copy() # use copy method so that we can reorder param_names later if we want.
num_pars = len(pspace.param_names)
print(f"{num_pars=} :: {param_names=}")

# Fiducial point: the middle (0.5) of every normalized parameter range, except the first entry
fiducial_pars = 0.5 * np.ones(num_pars)
# NOTE(review): assumes index 0 of the parameter vector is `hard_time` -- confirm against `param_names`
fiducial_pars[0] = 0.08256880733944955 # force to be fiducial hard_time = 1.0

fiducial_parameters = pspace._normalized_params(fiducial_pars)
fiducial_sam, fiducial_hard = pspace.model_for_params(fiducial_parameters, sam_shape=SHAPE) # <-- slow for some reason?
#fiducial_sam, fiducial_hard = pspace.model_for_normalized_params(fiducial_pars)
fiducial_model_data = holo.librarian.libraries.run_model(fiducial_sam, fiducial_hard, nreals=NREALS, gwb_flag=True, details_flag=True)  # Only run fiducial model once to save time!

# `alldata` collects one list of model outputs per varied parameter;
# `allparams_list` (a dict, despite the name) maps parameter name -> normalized values that were tried
alldata = []
allparams_list = {}
## Make a GW only model for comparison
fiducial_gwonly_hard = holo.hardening.Hard_GW()
# NOTE(review): these flags are set on the shared `fiducial_sam` object *after* the fiducial run above,
# and they remain set on it for any later use of `fiducial_sam`.
fiducial_sam.ZERO_DYNAMIC_STALLED_SYSTEMS = False
fiducial_sam.ZERO_GMT_STALLED_SYSTEMS = True
fiducial_gwonly_model_data = holo.librarian.libraries.run_model(
    fiducial_sam, fiducial_gwonly_hard, nreals=NREALS, gwb_flag=True, details_flag=True
)
print(fiducial_gwonly_model_data.keys())

# Make a straw man 2/3 power-law normalized to GW only model at lowest freq
# (the median is taken over the first entry of 'gwb'; presumably that is the lowest-frequency bin
# across realizations -- TODO confirm the array layout)
powerlawnorm = np.median(fiducial_gwonly_model_data['gwb'][0])
powerlawlowfreq = fiducial_gwonly_model_data['fobs_cents'][0]

# Not as elegant, but easier to find the min and max values on creation rather than after
# (seeded here from the GW-only model, then updated with every model run in the sweep)
mingwb = np.min(fiducial_gwonly_model_data['gwb'])
maxgwb = np.max(fiducial_gwonly_model_data['gwb'])
# Sweep each parameter in turn across its normalized range while holding all the others at their
# fiducial values.  One list of model outputs per parameter is appended to `alldata`, and the
# running GWB extrema (`mingwb`, `maxgwb`) are updated for use as shared y-limits when plotting.
for jj, target_param in enumerate(param_names):
    # start every sweep from the fiducial point
    pars = np.copy(fiducial_pars)
    print(f"Varying {target_param}, holding others constant at fiducial values.")

    # normalized values to check for the parameter being varied
    params_list = np.linspace(0.0, 1.0, 5)
    if target_param == 'hard_time':
        # replace one grid point by the fiducial normalized hard_time so the fiducial model is included
        params_list[1] = 0.08256880733944955
    allparams_list[target_param] = params_list
    if fiducial_pars[jj] not in params_list:
        print(f"****\n\nBy not having {fiducial_pars[jj]} in your range of parameter values, you won't have a single model to compare over everything ({params_list=})\n\n****")
    param_idx = param_names.index(target_param)

    data = []
    for ii, par in enumerate(params_list):
        pars[param_idx] = par
        parameters = pspace._normalized_params(pars)
        print(f"{ii=}, {pars=}, {parameters=}, {fiducial_parameters=}")
        if parameters == fiducial_parameters:
            # Use stored fiducial run
            _data = fiducial_model_data
        else:
            # construct `sam` and `hard` instances based on these parameters
            sam, hard = pspace.model_for_params(parameters, sam_shape=SHAPE) # <-- slow for some reason?
            # run this model, retrieving binary parameters and the GWB
            _data = holo.librarian.libraries.run_model(sam, hard, nreals=NREALS, gwb_flag=True, details_flag=True)
        data.append(_data)
        # Track the global extrema over every model run so far.  The maximum must be compared
        # against the running `maxgwb` (not `mingwb`), otherwise earlier maxima are discarded.
        mingwb = np.min([mingwb, np.min(_data['gwb'])])
        maxgwb = np.max([maxgwb, np.max(_data['gwb'])])
    alldata.append(data)

In [None]:
# Build a timestamped output filename for the sweep results.
dt = datetime.now()
# ISO-8601 timestamp with '_' between the date and time parts, so the name contains no spaces
timestamp = dt.isoformat(sep='_')
dataoutfname = f"gwb_anatomy_alldata_{pspace.__class__.__name__}_n{NREALS}_{timestamp}.npy"
# Saving is currently disabled; uncomment to write `alldata` to disk.
# with open(dataoutfname, 'wb') as f:
#     np.save(f, alldata)

## Plot GWBs

If you use a parameter space that introduces a new parameter, update `par_to_symbol` with the TeX symbol used to label it in the plots, and add the parameter to the appropriate thematic parameter list (`galaxy_params`, `blackhole_params`, `time_params`), or create a new list for it.

In [None]:
# Set up some plotting preferences

# Mapping from parameter name to the label drawn in each panel.  Previously hard-coded as:
# {'gsmf_phi0' : r"$\log(\Phi_0 / \mathrm{Mpc^{-3}\ dex^{-1}})$", 'gsmf_mchar0_log10' : r"$\log{(M_0 / M_{\odot})}$", 'mmb_mamp_log10' : r"$\alpha_{\mathrm{MM}}$", 'mmb_scatter_dex' : r"$\epsilon_{\mathrm{MM}}$", 'hard_time' : r"$\tau / \mathrm{Gyr}$", 'hard_time_compare' : r"$\tau / \tau_{GW}$", 'hard_gamma_rot': r"$\Delta\gamma_{\mathrm{hard}}$", 'hard_gamma_outer': r"$\Delta\gamma_{\mathrm{hard,out}}$"}
par_to_symbol = plot.PARAM_KEYS

# _G_alaxies are _G_reen, _B_lack holes are _B_lood orange, _T_ime is _T_otally purple
galaxy_params = ['gsmf_phi0', 'gsmf_mchar0_log10']
blackhole_params = ['mmb_mamp_log10', 'mmb_scatter_dex']
time_params = ['hard_time', 'hard_time_compare', 'hard_gamma_outer', 'hard_gamma_rot']
galaxy_colormap = plt.cm.Greens
blackhole_colormap = plt.cm.Oranges
time_colormap = plt.cm.Purples
default_colormap = plt.cm.cividis_r
fiducial_color = 'k'

# Don't let the y-ranges get out of hand.
mingwb = np.max([10.0**(-16.5), mingwb])
maxgwb = np.min([10.0**(-13.0), maxgwb])


## Reorder things for maximum aesthetics:
# `neworder` lists indices into the parameter space's own ordering (`pspace.param_names`).
space_name = pspace.__class__.__name__
if space_name in ["PS_Uniform_07A", "PS_Uniform_07B", "PS_Uniform_07C", "PS_Uniform_07D", "PS_Uniform_07E", "PS_Uniform_07F", "PS_Uniform_07G"]:
    neworder = [1, 3, 0, 2, 4]
elif space_name in ["PS_Uniform_08A", "PS_Uniform_07B_Rot"]:
    neworder = [1, 3, 0, 2, 4, 5]
else:
    neworder = None

if neworder is not None:
    # Pair every data set with its parameter name *before* reordering, and derive the target order
    # from the untouched `pspace.param_names`.  This keeps the cell idempotent: re-running it gives
    # the same ordering instead of permuting the already-permuted lists a second time.
    data_by_param = dict(zip(param_names, alldata))
    param_names[:] = [pspace.param_names[i] for i in neworder]
    alldata[:] = [data_by_param[name] for name in param_names]


In [None]:
# Quick check of the current parameter-name ordering and the number of parameters
print(param_names, num_pars)
# Mapping from parameter name to the text label drawn in each panel, taken from holodeck's `plot` module
par_to_symbol = plot.PARAM_KEYS

In [None]:
# One panel per varied parameter: the GWB for each value of that parameter, plus three reference
# curves (GW-only model, idealized -2/3 power law, fiducial model) drawn in `fiducial_color`.
fracs = [0.68]    # fractions of spectra over which to draw contours
nsamp = None      # number of individual spectra examples to plot

ncols = np.min([3, num_pars])
nrows = int(np.ceil(num_pars / 3))

fig, axsarr = plot.figax(ncols=ncols, nrows=nrows, grid=False, ylim=(mingwb, maxgwb), sharex='all', dpi=500)
fig.subplots_adjust(hspace=0, wspace=0)
axs = axsarr.flatten()

for jj, (target_param, data) in enumerate(zip(param_names, alldata)):
    num_points = len(data)
    ax = axs[jj]

    # one color per parameter value, from the colormap of the parameter's thematic category
    if target_param in galaxy_params:
        colors = galaxy_colormap(np.linspace(0.5, 1, num=num_points))
    elif target_param in blackhole_params:
        colors = blackhole_colormap(np.linspace(0.5, 1, num=num_points))
    elif target_param in time_params:
        colors = time_colormap(np.linspace(0.7, 1, num=num_points))
    else:
        holo.log.warning(f"using default color for {target_param}. Better to assign it to parameter category for thematic consistency.")
        colors = default_colormap(np.linspace(0.5, 1, num=num_points))

    # draw the GWB
    handles = []
    for ii, dd in enumerate(data):
        color = colors[ii]
        # contours are requested only for every other curve
        drawfracs = fracs if ii % 2 == 0 else 0
        # label each curve with the physical value of the varied parameter
        label = f"${pspace._normalized_params(allparams_list[target_param][ii])[target_param]:.2f}$"
        hh = plot.draw_gwb(ax, dd['fobs_cents']*1e9, dd['gwb'], nsamp=nsamp, fracs=drawfracs, color=color, plot={'label':label})
        handles.append(hh)

    # in-line curve labels; for the galaxy and black-hole parameters the x-range is passed reversed
    curxlims = ax.get_xlim()
    if target_param in ['gsmf_phi0', 'gsmf_mchar0_log10', 'mmb_mamp_log10', 'mmb_scatter_dex']:
        labellinesxvals = (curxlims[1], curxlims[0])
    else:
        labellinesxvals = curxlims
    # label the lines now, before the (unlabelled) reference curves are added to the axes
    labelLines(ax.get_lines(), xvals=labellinesxvals, zorder=2.5, shrink_factor=0, align=False)

    # reference curves: GW-only model (dash-dot), -2/3 power law (dotted), fiducial model (dashed)
    plot.draw_gwb(ax, fiducial_gwonly_model_data['fobs_cents']*1e9, fiducial_gwonly_model_data['gwb'], nsamp=None, fracs=0, color=fiducial_color, plot={'linestyle':'dashdot'})
    ax.plot(fiducial_gwonly_model_data['fobs_cents']*1e9, powerlawnorm * (fiducial_gwonly_model_data['fobs_cents']/powerlawlowfreq)**(-2./3.), linestyle=':', color=fiducial_color)
    plot.draw_gwb(ax, fiducial_model_data['fobs_cents']*1e9, fiducial_model_data['gwb'], nsamp=None, fracs=0, color=fiducial_color, plot={'linestyle':'dashed'})

    # `.get` (rather than `.setdefault`) so a missing key does not write the placeholder text into
    # the shared `plot.PARAM_KEYS` dictionary
    ax.text(0.97, 0.9, par_to_symbol.get(target_param, f"update par_to_symbol for {target_param}"), transform=ax.transAxes, horizontalalignment='right')
    # draw [1/yr] label on the x2-axis
    plot._twin_yr(ax, label=False)

# Only keep tick labels on the outer edge of the grid.  Axes that do not provide `label_outer`
# are skipped; any other error is allowed to surface.
for ax in fig.axes:
    try:
        ax.label_outer()
    except AttributeError:
        pass

# Get rid of extraneous subplots, and restore the x tick labels of the panel directly above each
# removed one (that panel is now the bottom of its column).
num_extra = axs.size - num_pars
if num_extra > 0:
    for extra_idx in range(-num_extra, 0):
        fig.delaxes(axs[extra_idx])
        axs[extra_idx - ncols].xaxis.set_tick_params(which="major", labelbottom=True)

fig.supxlabel(plot.LABEL_GW_FREQUENCY_NHZ)
# fig.supxlabel(plot.LABEL_GW_FREQUENCY_YR, y=0.99) # <-- matplotlib! Why does this erase the previous supxlabel?
# a figure holds a single supxlabel, so the top (1/yr) axis label is drawn as the figure title instead
fig.suptitle(plot.LABEL_GW_FREQUENCY_YR)
fig.supylabel(plot.LABEL_CHARACTERISTIC_STRAIN)
plt.show()
# plt.savefig(f"gwb_anatomy_{pspace.__class__.__name__}.png")

In [None]:
# One figure per binary-parameter axis (total mass, mass ratio, initial and final redshift); within
# each figure, one panel per varied model parameter showing the 'bin_params' curve of every model
# in that sweep plus the GW-only (dash-dot) and fiducial (dashed) reference curves.
ncols = np.min([3, num_pars])
nrows = int(np.ceil(num_pars / 3))
logs = [True, True, False, False]
units = [MSOL, 1.0, 1.0, 1.0]
# NOTE(review): `sam` here is whichever model the parameter sweep constructed last; this relies on
# every swept model sharing the same grid edges -- confirm, or use `fiducial_edges` throughout.
edges = [sam.mtot, sam.mrat, sam.redz, sam.redz]
fiducial_edges = [fiducial_sam.mtot, fiducial_sam.mrat, fiducial_sam.redz, fiducial_sam.redz]
xlabels = [r'Total mass [$M_{\odot}$]', r'Mass ratio', r'1 + $z_{i}$', r'1 + $z_{f}$']

for ii in range(4):
    fig, axsarr = plot.figax(ncols=ncols, nrows=nrows, grid=False, sharex='all', ylim=[1.0e-2, 1.0e10])
    fig.subplots_adjust(hspace=0, wspace=0)
    axs = axsarr.flatten()

    # x-values are bin mid-points: scaled by the unit for mass / mass ratio, and shifted to (1 + z)
    # for the redshift axes.  They do not depend on the panel, so compute them once per figure.
    if logs[ii]:
        sweep_centers = holo.utils.midpoints(edges[ii]) / units[ii]
        fiducial_centers = holo.utils.midpoints(fiducial_edges[ii]) / units[ii]
    else:
        sweep_centers = 1.0 + holo.utils.midpoints(edges[ii])
        fiducial_centers = 1.0 + holo.utils.midpoints(fiducial_edges[ii])

    for jj, (target_param, data) in enumerate(zip(param_names, alldata)):
        # number of models in this sweep, computed here rather than relying on a value left over
        # from a previously executed cell
        num_points = len(data)
        if target_param in galaxy_params:
            colors = galaxy_colormap(np.linspace(0.5, 1, num=num_points))
        elif target_param in blackhole_params:
            colors = blackhole_colormap(np.linspace(0.5, 1, num=num_points))
        elif target_param in time_params:
            colors = time_colormap(np.linspace(0.7, 1, num=num_points))
        else:
            holo.log.warning(f"using default color for {target_param}. Better to assign it to parameter category for thematic consistency.")
            colors = default_colormap(np.linspace(0.5, 1, num=num_points))

        # one curve per model in the sweep
        for kk, dd in enumerate(data):
            axs[jj].plot(sweep_centers, dd['bin_params'][ii], alpha=0.7, color=colors[kk])
        # `.get` (rather than `.setdefault`) so a missing key does not write the placeholder text
        # into the shared `plot.PARAM_KEYS` dictionary
        axs[jj].text(0.03, 0.1, par_to_symbol.get(target_param, f"update par_to_symbol for {target_param}"), transform=axs[jj].transAxes, horizontalalignment='left')

        # reference curves: GW-only model (dash-dot) and fiducial model (dashed)
        axs[jj].plot(fiducial_centers, fiducial_gwonly_model_data['bin_params'][ii], alpha=0.7, color=fiducial_color, linestyle='dashdot')
        axs[jj].plot(fiducial_centers, fiducial_model_data['bin_params'][ii], alpha=0.7, color=fiducial_color, linestyle='dashed')

    # Only keep tick labels on the outer edge of the grid.  Axes that do not provide `label_outer`
    # are skipped; any other error is allowed to surface.
    for ax in fig.axes:
        try:
            ax.label_outer()
        except AttributeError:
            pass

    # Remove unused panels, and restore the x tick labels of the panel directly above each removed
    # one (that panel is now the bottom of its column).
    num_extra = axs.size - num_pars
    if num_extra > 0:
        for extra_idx in range(-num_extra, 0):
            fig.delaxes(axs[extra_idx])
            axs[extra_idx - ncols].xaxis.set_tick_params(which="major", labelbottom=True)

    fig.supxlabel(xlabels[ii])
    fig.supylabel(r'Number per logarithmic bin')
    plt.show()
