In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Quiet numpy floating-point warnings (divide-by-zero, invalid, overflow) and any
# UserWarning raised downstream, so cell output is not flooded.
np.seterr(over='ignore', invalid='ignore', divide='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Global plotting style.
# NOTE(review): `font.family` is 'serif' while 'Times' is only listed under
# `font.sans-serif`, so Times is probably never selected; `font.serif` may have
# been intended -- confirm before changing.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.sans-serif': ['Times'],
    'font.size': 15,
    'lines.solid_capstyle': 'round',
    'mathtext.fontset': 'cm',
    'grid.alpha': 0.5,
})

# Discretize Population

In [None]:
# Build the semi-analytic population model (`holo` is the `holodeck` package alias).
sam = holo.sam.BP_Semi_Analytic()
# NOTE(review): `nbh` is not referenced in the visible cells -- confirm it is still needed.
nbh = sam.dnbh()

In [None]:
# Observer-frame frequency bin edges.
# NOTE(review): presumably in [1/yr], so that `freqs/YR` below is in [Hz] -- confirm.
freqs = utils.nyquist_freqs(20.0, 0.1, trim=[None, 5.0])

# Binned binary population over (primary mass, mass ratio, redshift, fobs).
edges_fobs, num_mbhb_fobs, mbhb_hs = sam.num_mbhb(freqs/YR)

# Mass (first) and frequency (last) dimensions are converted to log10 for sampling.
log_edges_fobs = [
    np.log10(edges_fobs[0]),
    edges_fobs[1],
    edges_fobs[2],
    np.log10(edges_fobs[3]),
]

num = np.sum(num_mbhb_fobs)
print(f"{num=:.4e}")

In [None]:
# Binaries in the first observer-frequency bin, summed over the remaining
# (redshift) axis, shown on the primary-mass / mass-ratio plane.
# Uses its own name (not `vals`) so re-running this cell cannot clobber the
# sample array that later cells store in `vals`.
nums_mq = num_mbhb_fobs[..., 0].sum(axis=-1)

fig, ax = plt.subplots(figsize=[10, 5])
ax.set(xlabel='log10 Mass Primary', ylabel='Mass Ratio')
ax.pcolormesh(sam.mbh1, sam.mrat, nums_mq.T)
plt.show()

In [None]:
class Corner_Grid:
    """Corner plot of an N-dimensional gridded dataset.

    Diagonal panels show the 1D marginal of `data` (sum over all other axes)
    along each dimension, on a log y-scale.  Lower-triangle panels show the
    log10 of the 2D marginal over each pair of dimensions.  Upper-triangle
    panels are hidden.

    Parameters
    ----------
    edges : sequence of 1D array_like
        One coordinate array per dimension of `data`: either bin edges
        (length n_k + 1) or grid points (length n_k).
    data : ndarray
        N-dimensional array; axis `k` corresponds to `edges[k]`.
    labels : sequence of str or None
        Axis label for each dimension.
    """

    def __init__(self, edges, data, labels=None):
        shape = [len(ee) for ee in edges]
        ndim = len(shape)

        # Scale the figure with the number of panels, within sane limits.
        fsize = np.clip(ndim * 5, 8, 24)
        fsize = [fsize, fsize * 0.75]
        # `squeeze=False` always returns a 2D array of axes (even for ndim == 1),
        # which `loop` relies on for `.diagonal()` and `np.ndenumerate`.
        fig, axes = plt.subplots(
            figsize=fsize, ncols=ndim, nrows=ndim, sharex='col', squeeze=False
        )

        self._ndim = ndim
        self._shape = shape
        self.fig = fig
        self.axes = axes

        self._edges = edges
        self._data = data
        self._labels = labels

        self.setup()
        self.draw(edges, data)

    @property
    def last(self):
        """Index of the last row/column of panels."""
        return self._ndim - 1

    def setup(self):
        """Configure panels: log y-scale on the diagonal, hide the upper triangle, add labels."""
        labels = self._labels

        def diag(jj, ax):
            ax.set_yscale('log')

        def offdiag(ii, jj, ax):
            if jj > ii:
                ax.set_visible(False)
                return
            if labels is None:
                return
            # x-labels only along the bottom row
            if ii == self.last:
                ax.set_xlabel(labels[jj])
            # y-labels only along the left column (the top-left panel is 1D)
            if (jj == 0) and (ii > 0):
                ax.set_ylabel(labels[ii])

        # `skip=False`: visit every panel here, including diagonal and upper triangle.
        self.loop(diag, offdiag, skip=False)

    def loop(self, diag, offdiag, skip=True):
        """Apply `diag(jj, ax)` to diagonal panels and `offdiag(ii, jj, ax)` to others.

        If `skip` is True, `offdiag` is only called for the lower triangle
        (jj < ii); otherwise it is called for every panel, including the diagonal.

        Returns
        -------
        (diag_list, offd_list) : the values returned by each call.
        """
        axes = self.axes
        diag_list = [diag(jj, ax) for jj, ax in enumerate(axes.diagonal())]
        offd_list = [
            offdiag(ii, jj, ax)
            for (ii, jj), ax in np.ndenumerate(axes)
            if not (skip and (jj >= ii))
        ]
        return diag_list, offd_list

    def draw(self, edges, data):
        """Draw 1D marginals on the diagonal and log10 2D marginals below it.

        Returns the artists created, as returned by `loop`.
        """
        ndim = self._ndim

        def _marginalize(keep):
            # Sum `data` over every axis not listed in `keep`; the surviving
            # axes stay in ascending order.
            drop = tuple(kk for kk in range(ndim) if kk not in keep)
            return np.sum(data, axis=drop)

        def diag(jj, ax):
            vv = _marginalize((jj,))
            return self._draw1d(ax, edges[jj], vv)

        def offdiag(ii, jj, ax):
            # Only lower-triangle panels (jj < ii) reach here, so the
            # marginal's axes are ordered (jj, ii) -> x = edges[jj], y = edges[ii].
            vv = _marginalize((ii, jj))
            xx = edges[jj]
            yy = edges[ii]
            return self._draw2d(ax, [xx, yy], np.log10(vv))

        return self.loop(diag, offdiag)

    def _draw1d(self, ax, edges, hist, **kwargs):
        """Plot `hist` against `edges` and return the line artist from `ax.plot`.

        If `edges` has one more element than `hist`, the values are drawn as a
        step function across each bin; if the lengths match, they are plotted
        point-by-point.

        Raises
        ------
        ValueError
            If `len(edges)` is neither `len(hist)` nor `len(hist) + 1`.
        """
        nedges = len(edges)
        nhist = len(hist)
        edges = np.asarray(edges)
        if nedges == nhist + 1:
            # Duplicate each value across its bin's [lo, hi] edges.
            xx = np.column_stack([edges[:-1], edges[1:]]).ravel()
            yy = np.repeat(hist, 2)
        elif nedges == nhist:
            xx = edges
            yy = hist
        else:
            raise ValueError(
                f"`edges` (len {nedges}) must have length len(hist) ({nhist}) "
                f"or len(hist) + 1!"
            )

        line, = ax.plot(xx, yy, **kwargs)
        return line

    def _draw2d(self, ax, edges, hist, mask_below=None, **kwargs):
        """Draw `hist` with `ax.pcolormesh` over `edges = [xx, yy]`; return the mesh.

        `hist` is indexed as (x, y) and transposed for `pcolormesh`.  If
        `mask_below` is given, values <= `mask_below` are masked.
        """
        if mask_below not in [False, None]:
            hist = np.ma.masked_less_equal(hist, mask_below)
        kwargs.setdefault('shading', 'auto')
        # NOTE: this avoids edge artifacts when alpha is not unity!
        kwargs.setdefault('edgecolors', [1.0, 1.0, 1.0, 0.0])
        kwargs.setdefault('linewidth', 0.01)
        return ax.pcolormesh(*edges, hist.T, **kwargs)
    
# Corner_Grid(log_edges_fobs, num_mbhb_fobs, labels=['log10(M)', 'q', 'z', 'log10(f)'])
# plt.show()

## Down-Sample Grid

In [None]:
# Down-sample the binned population into discrete weighted samples.  The
# returned coordinates are indexed per dimension of `log_edges_fobs` in later
# cells (vals[0] .. vals[3]).
# NOTE(review): no RNG seed is set here; if the sampling is stochastic the
# results will change between runs -- consider seeding.
outliers = kale.sample.Sample_Outliers(
    log_edges_fobs,
    num_mbhb_fobs,
    threshold=10.0,
)
nsamp, _vals, _weights = outliers.sample()

In [None]:
# Optional minimum secondary mass, computed as 10**vals[0] * vals[1].
# Set to e.g. 1e6 to enable; `None` keeps every sample.
CUT_BELOW_MASS_SEC = None

# Work on copies so the raw samples `_vals`, `_weights` stay untouched.
vals = np.copy(_vals)
weights = np.copy(_weights)

if CUT_BELOW_MASS_SEC is not None:
    print(f"{vals.shape=}")
    mass_sec = vals[1] * (10.0 ** vals[0])
    bads = (mass_sec < CUT_BELOW_MASS_SEC)
    keep = ~bads
    vals = vals[:, keep]
    weights = weights[keep]
    print(f"{vals.shape=}")

In [None]:
# Compare the binned population with the down-sampled one, both restricted to
# the first observer-frequency bin and projected on (primary mass, mass ratio).
bins = (sam.mbh1, sam.mrat)
nums = num_mbhb_fobs[..., 0].sum(axis=-1)

# Samples whose log10 frequency lies below the first bin's upper edge.
idx = (vals[3] < np.log10(freqs[1]/YR))
# NOTE(review): `vals[0]` was sampled over log10(edges_fobs[0]); confirm that
# `sam.mbh1` is in the same (log10) units, otherwise these bins do not match.
hist2d, *_ = np.histogram2d(vals[0, idx], vals[1, idx], bins=bins, weights=weights[idx])

# One shared colour scale so the two panels can be compared directly.
vmax = max(np.nanmax(nums), np.nanmax(hist2d))

fig, axes = plt.subplots(figsize=[15, 6], ncols=2)
panels = [(nums, 'Binned (first fobs bin)'), (hist2d, 'Sampled (first fobs bin)')]
for ax, (counts, title) in zip(axes, panels):
    ax.set(xlabel='log10 Mass Primary', ylabel='Mass Ratio', title=title)
    pcm = ax.pcolormesh(*bins, counts.T, vmin=0.0, vmax=vmax)

fig.colorbar(pcm, ax=axes)
plt.show()


In [None]:
def mc_gws_from_samples(vals, weights, threshold=10.0, cut_below_mass=1e6, fobs_edges=None):
    """Compute GW foreground/background values from down-sampled binaries.

    The input arrays are never modified.

    Arguments
    ---------
    vals : (4, N) ndarray
        Sample coordinates, one row per dimension of `log_edges_fobs`:
        log10 primary mass [Msol], mass ratio, redshift, and log10
        observer-frame GW frequency (presumably [Hz] -- confirm).
    weights : (N,) array_like
        Weight of each sample.
    threshold : float
        Unused; kept for backward compatibility.
    cut_below_mass : float or None
        Drop samples whose secondary mass (mass ratio * primary mass) [Msol]
        is below this value.  `None` disables the cut.
    fobs_edges : array_like or None
        log10 observer-frame frequency bin edges passed to
        `holo.sam.gws_from_sampled_strains`.  Defaults to the notebook global
        `log_edges_fobs[-1]`.

    Returns
    -------
    gff : ndarray
        Frequencies returned by `gws_from_sampled_strains`, converted from
        log10 to linear and multiplied by `YR` (i.e. to [1/yr]).
    gwf, gwb : ndarray
        Values returned by `holo.sam.gws_from_sampled_strains`.
    """
    if fobs_edges is None:
        fobs_edges = log_edges_fobs[-1]

    vals = np.asarray(vals)
    weights = np.asarray(weights)

    # Build new arrays instead of assigning into `vals`: the previous version
    # overwrote the caller's samples in place whenever no cut was applied.
    mass_pri = 10.0 ** vals[0]
    mass_sec = vals[1] * mass_pri
    redz = vals[2]
    log_fobs = vals[-1]

    if cut_below_mass is not None:
        keep = ~(mass_sec < cut_below_mass)
        mass_pri, mass_sec, redz, log_fobs = (
            arr[keep] for arr in (mass_pri, mass_sec, redz, log_fobs)
        )
        weights = weights[keep]

    mc = utils.chirp_mass(mass_pri, mass_sec)
    # Rest-frame frequency from the observed frequency and redshift.
    frst = (10.0 ** log_fobs) * (1.0 + redz)
    dl = cosmo.luminosity_distance(redz).cgs.value
    hs = utils.gw_strain_source(mc * MSOL, dl, frst)

    gff, gwf, gwb = holo.sam.gws_from_sampled_strains(fobs_edges, log_fobs, hs, weights)
    gff = (10.0 ** gff) * YR

    return gff, gwf, gwb

gwf_freqs, gwf, gwb = mc_gws_from_samples(vals, weights, threshold=10.0, cut_below_mass=1e6)

In [None]:
fig, ax = plt.subplots()
ax.set(
    xscale='log', yscale='log',
    xlabel='GW Frequency [1/yr]', ylabel='Characteristic Strain',
)

# Foreground values at the frequencies returned by `mc_gws_from_samples`.
ax.scatter(gwf_freqs, gwf, color='r', label='GW foreground')

# Background evaluated at the centers of the frequency bins.
xx = kale.utils.midpoints(freqs)
ax.plot(xx, gwb, 'k-', label='GW background')

# Reference power law: amp * (f / [1/yr])^(-2/3)
amp = 3e-16
yy = amp * np.power(xx, -2/3)
ax.plot(xx, yy, 'k--', alpha=0.25, label=r'$3 \times 10^{-16} \, f^{-2/3}$')

ax.legend()
plt.show()