In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm
import tqdm.notebook

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Silence annoying numpy errors
# (divide-by-zero, invalid-value and overflow floating-point conditions are ignored globally)
np.seterr(divide='ignore', invalid='ignore', over='ignore')
# Suppress every `UserWarning` for the rest of the session
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE(review): the family is 'serif' but the font list is assigned to the 'sans-serif' key;
#               presumably `'serif': ['Times']` was intended -- confirm before changing figure appearance.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5, 'pcolor.shading': 'gouraud'})

In [None]:
import zcode.math as zmath
import zcode.plot as zplot

# Discretize Population

In [None]:
# Reload the SAM module, then build the old model and its number-density grid.
# `reload` is called explicitly here in addition to the `%autoreload 2` set in the first cell.
import holodeck.sam
reload(holodeck.sam)
sam = holodeck.sam.BP_Semi_Analytic()
nbh = sam.dnbh()

In [None]:
# Old model: the number density comes from the `dnbh()` method
sam = holodeck.sam.BP_Semi_Analytic()
nbh = sam.dnbh()

# New model: the number density is read from the `density` attribute
sam_new = holo.sam.Semi_Analytic_Model()
nbh_new = sam_new.density

In [None]:
# NOTE: repeats the construction done in the two preceding cells, so this cell can run on its own.
sam = holodeck.sam.BP_Semi_Analytic()
nbh = sam.dnbh()

sam_new = holo.sam.Semi_Analytic_Model()
nbh_new = sam_new.density

def compare(mt, mr, rz):
    """Interpolate the old and new number-density grids onto the same total masses.

    Reads the notebook-level `sam` / `nbh` (old model) and `sam_new` / `nbh_new` (new model).
    Each grid is sliced at the mass-ratio and redshift entries nearest to `mr` and `rz`.

    Arguments
    ---------
    mt : total masses at which both densities are evaluated
    mr : target mass ratio
    rz : target redshift

    Returns
    -------
    y1 : old-model slice, interpolated at `mt / (1 + mr)` against `10 ** sam.mbh1`
    y2 : new-model slice, interpolated at `mt` against `sam_new.mtot`

    """
    # ---- Old model: the mass axis `sam.mbh1` is exponentiated, i.e. treated as log10(m1)
    zz_old = zmath.argnearest(sam.redz, rz)
    qq_old = zmath.argnearest(sam.mrat, mr)
    olds = nbh[:, qq_old, zz_old]
    m1 = 10.0 ** sam.mbh1
    # Convert the requested total mass into a primary mass, m1 = mt / (1 + q), before interpolating
    y1 = zmath.interp(mt / (1 + mr), m1, olds)

    # ---- New model: the mass axis `sam_new.mtot` is used directly
    zz_new = zmath.argnearest(sam_new.redz, rz)
    qq_new = zmath.argnearest(sam_new.mrat, mr)
    news = nbh_new[:, qq_new, zz_new]
    y2 = zmath.interp(mt, sam_new.mtot, news)
    return y1, y2

# Total masses at which the two models are compared
mt = np.logspace(5, 10, 100)

fig, ax = zplot.figax(yscale='log') # , xlim=[1e8, 3e11])

# Paired element-wise below: each (mass-ratio, redshift) combination gets one curve
qlist = [0.05, 0.05, 0.5, 0.5, 1.0, 1.0]
zlist = [0.1, 1.0, 0.1, 1.0, 0.1, 1.0]
for qq, zz in zip(qlist, zlist):
    y1, y2 = compare(mt, qq, zz)
#     hh, = ax.plot(mt, y1, label=[qq, zz], ls='--', lw=3.0, alpha=0.5)
#     ax.plot(mt, y2, color=hh.get_color())

#     rr = (y2-y1)/y1
#     rr = np.fabs(rr)

    # Ratio of the new to the old density; print its first positive value
    rr = y2/y1
    tt = zmath.argfirst(rr > 0)
    print(rr[tt])
    hh, = ax.plot(mt, rr, label=[qq, zz])

plt.legend()
plt.show()


In [None]:
# Index along the last (redshift) axis at which the two density grids are compared
idx = 1

fig, axes = zplot.figax(figsize=[15, 5], ncols=3, ylim=[1e-2, 1])
vals = [nbh[:, :, idx], nbh_new[:, :, idx]]
# Fractional difference (new - old) / old; left at zero wherever the old grid is not positive
rat = np.zeros_like(vals[0])
# NOTE: from here on `idx` is a boolean mask, no longer the redshift index
idx = vals[0] > 0.0
rat[idx] = (vals[1][idx] - vals[0][idx]) / vals[0][idx]

# Broadcast the mass-ratio grid to the shape of `rat` and summarize the cells with q < 0.1
mr = np.ones_like(rat) * sam_new.mrat[np.newaxis, :]
idx = (mr < 0.1)
print(zmath.stats_str(rat[idx]))

# Color scale built from the two density grids, shared by the first two panels
smap = zplot.smap(vals, scale='linear')
vals.append(rat)

levels = None
for ii, (ax, vv) in enumerate(zip(axes, vals)):
    # The third panel (fractional difference) gets a fixed [-1, 1] color scale and fresh contour levels
    if ii == 2:
        smap = zplot.smap([-1, 1], scale='linear')
        # smap = zplot.smap(vv, scale='linear')
        levels = None

    ax.pcolormesh(sam_new.mtot, sam_new.mrat, vv.T, cmap=smap.cmap, norm=smap.norm)
    qcs = ax.contour(sam_new.mtot, sam_new.mrat, vv.T, levels=levels)
    cbar = plt.colorbar(smap, orientation='horizontal', ax=ax)
    zplot.draw_colorbar_contours(cbar, qcs, smap=smap)
    # Reuse this panel's contour levels for the next panel
    levels = qcs.levels

plt.show()

In [None]:
fobs = utils.nyquist_freqs()
sam = holo.sam.Semi_Analytic_Model()
# NOTE(review): the call that produced `num` and `strain` is commented out, so on a fresh kernel the
#               lines below raise NameError (or silently pick up unrelated `num` / `strain` values
#               assigned by later cells) -- restore or replace this call.
# edges, num, strain = sam.number_at_gw_fobs(fobs, limit_merger_time=True)
# Index of the frequency for which `fobs*YR` is nearest to 1.0
ff = zmath.argnearest(fobs*YR, 1.0)
# Number-weighted quadrature sum of the strains at that frequency
hs = np.sqrt(np.sum(num[..., ff] * np.square(strain[..., ff])))
print(f"{num.sum():.4e}, {hs:.4e}")

In [None]:
sam = holo.sam.Semi_Analytic_Model()
gwb = sam.gwb(fobs, realize=100)

In [None]:
fig, ax = zplot.figax()
zplot.conf_fill(ax, fobs, gwb)

In [None]:
ratio = nbh / nbh_new
print(zmath.stats_str(ratio))

In [None]:
freqs = utils.nyquist_freqs(20.0, 0.1, trim=[None, 5.0])

# `freqs` is divided by YR before being handed to the model
# NOTE(review): presumably `freqs` is in [1/yr] and YR is the year in seconds -- confirm units
edges_fobs, num_mbhb_fobs, mbhb_hs = sam.num_mbhb(freqs/YR)
# Same edges, with the first and last dimensions converted to log10
log_edges_fobs = [np.log10(edges_fobs[0]), edges_fobs[1], edges_fobs[2], np.log10(edges_fobs[3])]

# Total over the whole grid
num = num_mbhb_fobs.sum()
print(f"num={num:.4e}")

In [None]:
# NOTE: `log_edges_fobs` is rebuilt here exactly as in the previous cell
log_edges_fobs = [np.log10(edges_fobs[0]), edges_fobs[1], edges_fobs[2], np.log10(edges_fobs[3])]
outliers = kale.sample.Sample_Outliers(log_edges_fobs, num_mbhb_fobs, threshold=10.0)
# vals, weights = kale.sample_outliers(log_edges_fobs, num_mbhb_fobs, 10.0)
# print(vals.shape, weights.shape, zmath.argfirst(weights > 1.0))
# NOTE(review): reads the private attribute `_data_outs` of the sampler; may break if kalepy changes
nsamp = outliers._data_outs.sum()
print(f"nsamp={nsamp:.2e}")

In [None]:
# Number of realizations
NREALS = 10
# Poisson-sample the number of draws needed for `NREALS` realizations.  The result is stored under
# its own name so that `nsamp` (the expectation value from the previous cell) is not overwritten:
# re-running this cell then no longer multiplies an already-scaled count by `NREALS` again.
nsamp_reals = np.random.poisson(NREALS * nsamp)
print(f"nsamp={nsamp_reals:.2e}")

In [None]:
# Take the first entry along the last (frequency) axis, then sum over the remaining last (redshift) axis
vals = num_mbhb_fobs[..., 0].sum(axis=-1)
plt.pcolormesh(sam.mbh1, sam.mrat, vals.T)
plt.show()

In [None]:
class Corner_Grid:
    """Lower-triangle corner plot of an N-dimensional gridded quantity.

    Each diagonal panel shows `data` summed over every other axis, drawn as a line on a
    logarithmic y-axis.  Each panel below the diagonal shows log10 of `data` summed over all
    axes except the panel's two, drawn with `pcolormesh`.  Panels above the diagonal are hidden.

    Arguments
    ---------
    edges : sequence of 1D arrays, one per dimension of `data`
        Coordinates along each axis.  For the diagonal panels each must have either the same
        length as the corresponding axis of `data`, or one element more (bin edges).
    data : ndarray with one axis per entry of `edges`
    labels : None or sequence of str
        One axis label per dimension; no labels are drawn when `None`.

    Attributes
    ----------
    fig, axes : the matplotlib figure and the (ndim, ndim) array of axes.

    """

    def __init__(self, edges, data, labels=None):
        shape = [len(ee) for ee in edges]
        ndim = len(shape)

        # Figure width grows with dimensionality, bounded to [8, 24]; height is 3/4 of the width
        fsize = np.clip(ndim * 5, 8, 24)
        fsize = [fsize, fsize * 0.75]
        # `squeeze=False` guarantees a 2D array of axes, even for a one-dimensional grid
        fig, axes = plt.subplots(
            figsize=fsize, ncols=ndim, nrows=ndim, sharex='col', squeeze=False
        )

        self._ndim = ndim
        self._shape = shape
        self.fig = fig
        self.axes = axes

        self._edges = edges
        self._data = data
        self._labels = labels

        self.setup()
        self.draw(edges, data)
        return

    @property
    def last(self):
        """Index of the last row/column of the grid of axes."""
        return self._ndim - 1

    def setup(self):
        """Configure the panels: log-scaled diagonals, hidden upper triangle, axis labels."""
        labels = self._labels

        def diag(jj, ax):
            ax.set_yscale('log')
            return

        def offdiag(ii, jj, ax):
            # Hide everything above the diagonal
            if jj > ii:
                ax.set_visible(False)
                return

            if labels is None:
                return

            # x-labels only along the bottom row
            if (ii == self.last):
                ax.set_xlabel(labels[jj])

            # y-labels only along the first column, excluding the top-left (diagonal) panel
            if (jj == 0) and (ii > 0):
                ax.set_ylabel(labels[ii])

            return

        # `skip=False`: `offdiag` must see every panel, including the diagonal and upper triangle
        self.loop(diag, offdiag, skip=False)
        return

    def loop(self, diag, offdiag, skip=True):
        """Apply `diag(jj, ax)` to each diagonal panel and `offdiag(ii, jj, ax)` to the others.

        Arguments
        ---------
        diag : callable, signature `(jj, ax)`
        offdiag : callable, signature `(ii, jj, ax)` with `ii` the row and `jj` the column
        skip : bool
            If True, `offdiag` is only called for panels strictly below the diagonal (`jj < ii`);
            otherwise it is called for every panel.

        Returns
        -------
        diag_list, offd_list : lists of the values returned by `diag` and `offdiag`

        """
        axes = self.axes
        diag_list = [diag(jj, ax) for jj, ax in enumerate(axes.diagonal())]

        offd_list = []
        for (ii, jj), ax in np.ndenumerate(axes):
            if skip and (jj >= ii):
                continue
            offd_list.append(offdiag(ii, jj, ax))

        return diag_list, offd_list

    def draw(self, edges, data):
        """Draw the 1D (diagonal) and 2D (lower-triangle) projections of `data`."""
        ndim = self._ndim

        def diag(jj, ax):
            # Sum over every axis except `jj`
            others = tuple(kk for kk in range(ndim) if kk != jj)
            vv = np.sum(data, axis=others)
            return self._draw1d(ax, edges[jj], vv)

        def offdiag(ii, jj, ax):
            # Sum over every axis except `ii` and `jj`.  The axes are selected by value, so the
            # result does not depend on the order in which the two indices are removed.
            others = tuple(kk for kk in range(ndim) if kk not in (ii, jj))
            vv = np.sum(data, axis=others)
            # Only called for `jj < ii` (see `loop`), so `vv` has the x-axis (`jj`) first and the
            # y-axis (`ii`) second; `_draw2d` transposes it for `pcolormesh`.
            xx = edges[jj]
            yy = edges[ii]
            return self._draw2d(ax, [xx, yy], np.log10(vv))

        self.loop(diag, offdiag)
        return

    def _draw1d(self, ax, edges, hist, **kwargs):
        """Plot a 1D distribution, as steps for bin edges or as a plain line for grid points.

        Raises
        ------
        ValueError : if `edges` has neither `len(hist)` nor `len(hist) + 1` elements.

        """
        if len(edges) == len(hist) + 1:
            # Bin edges: repeat each value at the left and right edge of its bin
            xx = np.hstack([[edges[jj], edges[jj+1]] for jj in range(len(edges)-1)])
            yy = np.hstack([[hh, hh] for hh in hist])
        elif len(edges) == len(hist):
            xx = edges
            yy = hist
        else:
            raise ValueError(
                f"`edges` (length {len(edges)}) must have the same length as `hist` "
                f"(length {len(hist)}), or exactly one element more!"
            )

        ax.plot(xx, yy, **kwargs)
        return

    def _draw2d(self, ax, edges, hist, mask_below=None, **kwargs):
        """Plot a 2D distribution with `pcolormesh`.

        Arguments
        ---------
        edges : [xx, yy] coordinates of the two axes
        hist : 2D array ordered as (x, y); it is transposed before plotting
        mask_below : None, False, or scalar
            If a scalar, values less than or equal to it are masked.

        """
        if mask_below not in [False, None]:
            hist = np.ma.masked_less_equal(hist, mask_below)
        kwargs.setdefault('shading', 'auto')
        # NOTE: this avoids edge artifacts when alpha is not unity!
        kwargs.setdefault('edgecolors', [1.0, 1.0, 1.0, 0.0])
        kwargs.setdefault('linewidth', 0.01)
        # Plot
        ax.pcolormesh(*edges, hist.T, **kwargs)
        return

# Corner_Grid(log_edges_fobs, num_mbhb_fobs, labels=['log10(M)', 'q', 'z', 'log10(f)'])
# plt.show()
    
# Corner_Grid(log_edges_fobs, num_mbhb_fobs, labels=['log10(M)', 'q', 'z', 'log10(f)'])
# plt.show()

In [None]:
# Corner_Grid(log_edges_fobs, mbhb_hs, labels=['log10(M)', 'q', 'z', 'log10(f)'])
# plt.show()

# Sample full population

In [None]:
# nsamp = num_mbhb_fobs.sum()
# print(f"dirst sum = {nsamp:.8e}")
# nsamp = np.random.poisson(nsamp)
# print(f"poisson   = {nsamp:.8e}")

# nsamp = 1e8
# samples = kale.sample_grid(log_edges_fobs, num_mbhb_fobs, nsamp)

## Down-Sample Grid

Modify the true distribution by down-weighting it with a continuous function of chirp mass and redshift.

In [None]:
# Rebuild the sampler and draw one set of samples; `threshold=10.0` matches the earlier cell.
# The raw outputs are kept under underscore-prefixed names; a later cell works on copies of them.
outliers = kale.sample.Sample_Outliers(log_edges_fobs, num_mbhb_fobs, threshold=10.0)
nsamp, _vals, _weights = outliers.sample()

In [None]:
print(f"Fraction of samples that are 'outliers': {zmath.frac_str(_weights == 1.0)}")

In [None]:
# Threshold applied to `(10 ** vals[0]) * vals[1]`; set to `None` to keep every sample
CUT_BELOW_MASS_SEC = 1e6

# Always start from fresh copies of the raw samples, so re-running this cell does not compound the cut
vals, weights = np.copy(_vals), np.copy(_weights)

if CUT_BELOW_MASS_SEC is not None:
    cut_below = CUT_BELOW_MASS_SEC
    # Flag the samples that fall below the threshold
    bads = (np.power(10.0, vals[0]) * vals[1] < cut_below)
    print(f"Cutting {zmath.frac_str(bads)}")

    # Samples run along the last axis of `vals`; drop the flagged columns and their weights
    vals = vals[:, ~bads]
    weights = weights[~bads]
    print(f"vals.shape={vals.shape}")

In [None]:
corner = Corner_Grid(log_edges_fobs, num_mbhb_fobs, labels=['log10(M)', 'q', 'z', 'log10(f)'])
# Stride used when drawing samples (every `SKIP`-th sample is plotted); `None` plots all of them.
# NOTE: the first assignment is immediately overridden by the second.
SKIP = None
SKIP = 3

def draw_samples_in_out(corner, in_only, out_only, color):
    """Overplot the sampled binaries on an existing `Corner_Grid`.

    Reads the notebook-level `vals`, `weights`, `log_edges_fobs` and `SKIP`.

    Arguments
    ---------
    corner : `Corner_Grid` instance to draw on
    in_only : bool, draw only the samples from the first `weights > 1.0` entry onward
    out_only : bool, draw only the samples before the first `weights > 1.0` entry
        Takes precedence over `in_only` when both are True.
    color : matplotlib color used for the histograms and scatter points

    """
    # NOTE(review): the split by position assumes the sampler returns all unit-weight
    #               ('outlier') samples first, followed by the weighted ones -- confirm.
    heavy = (weights > 1.0)
    # `np.argmax` returns 0 when no element is True, which would classify *every* sample as an
    # 'inlier'; in that case place the separator at the end so that all samples are 'outliers'.
    sep = np.argmax(heavy) if np.any(heavy) else weights.size

    def take(*args):
        if out_only:
            args = [xx[:sep] for xx in args]
        elif in_only:
            args = [xx[sep:] for xx in args]

        # Thin the samples to keep the figure light
        if SKIP is not None:
            args = [xx[::SKIP] for xx in args]
        return args

    def diag(jj, ax):
        # Weighted 1D histogram of dimension `jj`, on the same bins as the grid
        bins = log_edges_fobs[jj]
        xx, yy = take(vals[jj], weights)
        ax.hist(xx, weights=yy, bins=bins, color=color, histtype='step', zorder=100, lw=2.0, alpha=0.5)
        return

    def offdiag(ii, jj, ax):
        # Scatter of dimension `jj` (x) against dimension `ii` (y); weights are not used here
        xx, yy = take(vals[jj], vals[ii])
        ax.scatter(xx, yy, facecolor=color, edgecolor='none', zorder=100, s=5, alpha=0.25)
        return

    corner.loop(diag, offdiag)
    return

# Samples with weight > 1 in red (`in_only`), the rest in blue (`out_only`)
draw_samples_in_out(corner, True, False, 'r')
draw_samples_in_out(corner, False, True, 'b')

plt.show()

In [None]:
corner = Corner_Grid(log_edges_fobs, outliers._data_ins, labels=['log10(M)', 'q', 'z', 'log10(f)'])
draw_samples_in_out(corner, True, False, 'r')

In [None]:
corner = Corner_Grid(log_edges_fobs, outliers._data_outs, labels=['log10(M)', 'q', 'z', 'log10(f)'])
draw_samples_in_out(corner, False, True, 'r')

In [None]:
# First two sampled dimensions: log10 of a mass, and the mass ratio.
# Convert them into the two masses handed to `utils.chirp_mass`.
mc = np.copy(vals[:2, :])
mc[0] = 10.0 ** mc[0]
mc[1] = mc[1] * mc[0]
print("\t", utils.stats(mc))
mc = utils.chirp_mass(*mc)
print(utils.stats(mc))

# Third dimension is the redshift; fourth is log10 of the observed frequency
redz = vals[2, :]
frst = (10.0 ** vals[3]) * (1.0 + redz)
dl = cosmo.luminosity_distance(redz).cgs.value
# `dc` was previously used without ever being defined (NameError on a fresh kernel).
# NOTE(review): the name implies the comoving distance; confirm that `utils.gw_strain_source`
#               expects comoving (and not luminosity) distance.
dc = cosmo.comoving_distance(redz).cgs.value
print(utils.stats(dl))
print(utils.stats(frst))
hs = utils.gw_strain_source(mc * MSOL, dc, frst)
print(hs.shape, utils.stats(hs))

In [None]:
# One background value and one loudest ('foreground') source per frequency bin
num_freq = freqs.size - 1
gwb = np.zeros(num_freq)
gwf = np.zeros_like(gwb)
ffr = np.zeros_like(gwb)

for ii in range(num_freq):
    lo = log_edges_fobs[-1][ii]
    hi = log_edges_fobs[-1][ii+1]
    fr_bin = vals[-1, :]
    idx = (lo < fr_bin) & (fr_bin < hi)
    # A bin without any samples has neither a loudest source nor a background: `np.argmax`
    # would raise on the empty selection.  Leave the strains at zero and mark the frequency
    # as undefined so that the bin is not drawn as a spurious foreground source.
    if not np.any(idx):
        ffr[ii] = np.nan
        continue

    hs_bin = hs[idx]
    fr_bin = fr_bin[idx]
    # Boolean indexing returns a copy, so the resampling below does not modify `weights`
    ww_bin = weights[idx]
    # Weighted samples stand for many binaries: draw their actual number from a Poisson distribution
    pp = (ww_bin > 1.0)
    ww_bin[pp] = np.random.poisson(ww_bin[pp])

    # Loudest single source in this bin
    floc = np.argmax(hs_bin)
    ffr[ii] = fr_bin[floc]
    gwf[ii] = hs_bin[floc]
    # Background: quadrature sum over all sources, minus (one copy of) the loudest source
    gwb[ii] = np.sum(ww_bin * np.square(hs_bin))
    gwb[ii] = gwb[ii] - gwf[ii]**2

# Sampled frequencies are log10 values; convert back to linear and multiply by YR
ffr = (10.0 ** ffr) * YR
gwb = np.sqrt(gwb)

In [None]:
fig, ax = zplot.figax()

# Loudest source in each frequency bin
ax.scatter(ffr, gwf, color='r')

# xx = kale.utils.midpoints(freqs)
xx = zmath.midpoints(freqs, log=False)
# Background alone (solid black), and background plus loudest source in quadrature (dotted blue)
ax.plot(xx, gwb, 'k-')
ax.plot(xx, np.sqrt(gwb**2 + gwf**2), 'b:', alpha=0.2)

# Reference power law, `amp * x^(-2/3)`, with a hand-picked amplitude
amp = 3e-16
yy = amp * np.power(xx, -2/3)
ax.plot(xx, yy, 'k--', alpha=0.25)

plt.show()

In [None]:
# Arguments handed to `utils.nyquist_freqs`
# NOTE(review): presumably observing duration and cadence in years -- confirm
DUR = 20.0
CAD = 0.1

fobs = utils.nyquist_freqs(DUR, CAD)
gwb_cont = holo.sam.gwb_continuous(sam, fobs/YR)
# Sum over axes 1-3, leaving only the first axis, then take the square root of the sum
gwb_cont_tot = gwb_cont.sum(axis=(1, 2, 3))
gwb_cont_tot = np.sqrt(gwb_cont_tot)

In [None]:
gff, gwf, gwb = holo.sam.gwb_discrete(sam, fobs/YR)

In [None]:
# The discrete background is plotted against the midpoints of `fobs`, the continuous one against `fobs`
xx = kale.utils.midpoints(fobs)

fig, ax = zplot.figax()
ax.plot(fobs, gwb_cont_tot)
ax.plot(xx, gwb)
# Loudest sources from the discrete calculation
ax.plot(gff*YR, gwf, 'ro', alpha=0.2)

plt.show()

## Calculate GW Realizations

In [None]:
# log_edges_fobs = [np.log10(edges_fobs[0]), edges_fobs[1], edges_fobs[2], np.log10(edges_fobs[3])]
# vals, weights = kale.sample_outliers(log_edges_fobs, num_mbhb_fobs, 10.0)

In [None]:
def mc_gws_from_sam(edges_fobs, num_mbhb_fobs, threshold=10.0, cut_below_mass=1e6):
    """Draw one Monte-Carlo realization of GW sources from a gridded number of binaries.

    Arguments
    ---------
    edges_fobs : sequence of four 1D arrays
        Grid coordinates.  The first and last are converted to log10 before sampling; the
        first two are multiplied together for the mass cut; the third is used as redshift.
    num_mbhb_fobs : ndarray
        Number of binaries on the grid defined by `edges_fobs`.  Not modified.
    threshold : scalar, passed to `kale.sample_outliers`
    cut_below_mass : None or scalar
        If not None, grid cells and samples for which the product of the first two
        coordinates is below this value are discarded.

    Returns
    -------
    gff, gwf, gwb : the three outputs of `holo.sam.gws_from_sampled_strains`, with `gff`
        converted from log10 back to linear and multiplied by `YR`.

    """

    log_edges_fobs = [np.log10(edges_fobs[0]), edges_fobs[1], edges_fobs[2], np.log10(edges_fobs[3])]

    if cut_below_mass is not None:
        m2 = edges_fobs[0][:, np.newaxis] * edges_fobs[1][np.newaxis, :]
        bads = (m2 < cut_below_mass)
        # Work on a copy: zeroing the caller's array in place would silently alter the grid
        # used by every other cell and by subsequent realizations.
        num_mbhb_fobs = np.copy(num_mbhb_fobs)
        num_mbhb_fobs[bads] = 0.0

    vals, weights = kale.sample_outliers(log_edges_fobs, num_mbhb_fobs, threshold)

    # Samples are drawn continuously within grid cells, so repeat the cut on the samples themselves
    if cut_below_mass is not None:
        bads = ((10.0 ** vals[0]) * vals[1] < cut_below_mass)
        vals = vals.T[~bads].T
        weights = weights[~bads]

    # Convert the first two dimensions to the two masses handed to `utils.chirp_mass`
    vals[0] = 10.0 ** vals[0]
    vals[1] = vals[1] * vals[0]
    mc = utils.chirp_mass(vals[0], vals[1])
    redz = vals[2, :]
    frst = (10.0 ** vals[3]) * (1.0 + redz)
    # `dc` was previously read from the notebook's global namespace without being defined.
    # NOTE(review): the name implies the comoving distance; confirm that `utils.gw_strain_source`
    #               expects comoving (and not luminosity) distance.
    dc = cosmo.comoving_distance(redz).cgs.value
    hs = utils.gw_strain_source(mc * MSOL, dc, frst)
    # Sampled (log10) observed frequencies
    fo = vals[-1, :]
    del vals

    gff, gwf, gwb = holo.sam.gws_from_sampled_strains(log_edges_fobs[-1], fo, hs, weights)
    gff = (10.0 ** gff) * YR

    return gff, gwf, gwb


# Number of realizations, and the arguments handed to `utils.nyquist_freqs`
NREALS = 10
DUR = 20.0
CAD = 0.1

fobs = utils.nyquist_freqs(DUR, CAD)
# One row per frequency bin, one column per realization
gwf_freqs = np.zeros((fobs.size - 1, NREALS))
gwf = np.zeros_like(gwf_freqs)
gwb = np.zeros_like(gwf_freqs)

print("num_mbhb")
edges_fobs, num_mbhb_fobs, _ = sam.num_mbhb(fobs/YR)

# Every realization must be filled in: the plotting cell takes the median and confidence
# intervals across realizations, and columns left at zero would drag those statistics to zero.
# (A leftover debugging `break` previously stopped this loop after the first realization.)
for rr in tqdm.notebook.tqdm(range(NREALS)):
    _gff, _gwf, _gwb = mc_gws_from_sam(edges_fobs, num_mbhb_fobs, threshold=10.0, cut_below_mass=1e6)
    gwf_freqs[:, rr] = _gff
    gwf[:, rr] = _gwf
    gwb[:, rr] = _gwb

In [None]:
fig, ax = zplot.figax(figsize=[16, 10])

xx = zmath.midpoints(fobs, log=False)

# One thin line per realization; mark the loudest sources that exceed half of the background
for rr in range(NREALS):
    hh, = ax.plot(xx, gwb[:, rr], lw=0.75, alpha=0.35)
    col = hh.get_color()
    idx = (gwf[:, rr] > gwb[:, rr] * 0.5)
    ax.scatter(gwf_freqs[idx, rr], gwf[idx, rr], alpha=0.25, color=col, s=5)

# med, conf, ci = zmath.confidence_intervals(gwf, percs=[0.5, 0.9, 0.98], axis=-1, return_ci=True)
# ax.plot(xx, med, 'r-', alpha=0.25)
# conf = conf.squeeze()
# for ii in range(ci.size):
#     ax.fill_between(xx, *conf[:, ii, :].T, color='r', alpha=0.15)

# med, conf, ci = zmath.confidence_intervals(gwf, percs=0.9, axis=-1, return_ci=True)
# ax.plot(xx, conf[:, 0, -1], 'r-', alpha=0.25)

# Median and confidence band of the background, taken across realizations (last axis)
med, conf = zmath.confidence_intervals(gwb, percs=0.5, axis=-1)
conf = conf.squeeze()
ax.plot(xx, med, 'b-')
ax.fill_between(xx, *conf.T, color='b', alpha=0.25)

# Reference -2/3 power law, normalized to the median background at `xx = amp_loc`
amp_loc = 0.1
amp = zmath.interp(amp_loc, xx, med)
yy = amp * np.power(xx/amp_loc, -2.0/3.0)
ax.plot(xx, yy, 'k--', lw=2.5, alpha=0.5)

ax.set(ylim=[3e-18, 1e-14], xlim=zmath.minmax(xx, log_stretch=0.02))
plt.show()

In [None]:
# Rebuild the log-edges from the current `edges_fobs` and draw a fresh set of weighted samples
# (the positional `10.0` is the same threshold value used in the earlier sampling cells)
log_edges_fobs = [np.log10(edges_fobs[0]), edges_fobs[1], edges_fobs[2], np.log10(edges_fobs[3])]
vals, weights = kale.sample_outliers(log_edges_fobs, num_mbhb_fobs, 10.0)

In [None]:
# NOTE(review): `bins` is only assigned in the *next* cell, so this cell fails on a fresh
#               kernel when the notebook is run top to bottom.
print([bb.size for bb in bins])
[bb.size for bb in log_edges_fobs]

In [None]:
# NOTE: `ii` is not used anywhere else in this cell
ii = 0
bins = [sam.mbh1, sam.mrat]

fig, ax = plt.subplots(figsize=[20, 10])

# Expected numbers: first frequency entry, summed over the last remaining (redshift) axis,
# then averaged to midpoints along every axis so the shape can match the histogram below
nums = num_mbhb_fobs[..., 0].sum(axis=-1)
nums = kale.utils.midpoints(nums, axis=None)

# Sampled numbers: weighted 2D histogram of the samples below the second frequency edge
idx = (vals[-1] < log_edges_fobs[-1][1])
hist, *_ = np.histogram2d(vals[0][idx], vals[1][idx], bins=bins, weights=weights[idx])
print(nums.shape, hist.shape)

# Fractional difference expected / sampled - 1; left as NaN wherever no samples were drawn
# NOTE: `idx` is reused here as a 2D mask over the histogram
data = np.ones_like(nums) * np.nan
idx = (hist > 0)
data[idx] = (nums[idx] / hist[idx]) - 1.0
print(utils.stats(data))

data = np.log10(np.fabs(data))

pcm = ax.pcolormesh(*bins, data.T)
plt.colorbar(pcm)

plt.show()

# Examine convergence properties

In [None]:
# Frequency at which the background is evaluated, and the default grids used in the cells below
# (passed as `mstar_pri`, `mrat` and `redz`, respectively).
# NOTE(review): each triple presumably means [lower bound, upper bound, number of points],
#               with the stellar-mass bounds in log10 -- confirm against `BP_Semi_Analytic`.
FREQ = 1.0 / YR
MSTAR = [8.5, 13.0, 46]
MRAT = [0.02, 1.0, 50]
REDZ = [0.0, 6.0, 61]
# np.logspace(*MSTAR)
# np.logspace(*MSTAR)

### mass

In [None]:
# Stellar-mass grids to test: the first three vary the upper bound at a fixed number of points,
# the last three vary the number of points at fixed bounds.
mstar_args = [
    [8.5, 13, 41],
    [8.5, 13.5, 41],
    [8.5, 14.0, 41],
    [8.5, 13, 41],
    [8.5, 13, 61],
    [8.5, 13, 81],
]

fig, ax = zplot.figax(scale='linear')

# NOTE: `strain` and `labels` are initialized but never filled in this cell
strain = []
labels = []
for ii, mstar in enumerate(mstar_args):
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=mstar, mrat=MRAT, redz=REDZ)
    gwb = sam.gwb_sa(FREQ)
    # Sum `gwb` over all of its elements, then take the square root
    hc = np.sqrt(np.sum(gwb))
    # x-position is just the (1-based) index of the configuration
    ax.plot(ii+1, hc, marker='o', label=str(mstar))

plt.legend()
plt.show()

### mass-ratio

In [None]:
# Mass-ratio grids to test: first increasing the number of points at fixed bounds,
# then lowering the lower bound at a fixed number of points.
args = [
    [0.02, 1.0, 40],
    [0.02, 1.0, 50],
    [0.02, 1.0, 60],
    [0.02, 1.0, 70],
    [0.02, 1.0, 80],
    [0.02, 1.0, 160],
    [0.02, 1.0, 320],
    [0.02, 1.0, 640],
    [0.02, 1.0, 160],
    [0.002, 1.0, 160],
    [0.0002, 1.0, 160],
]

fig, ax = zplot.figax(scale='linear')

# NOTE: `strain` and `labels` are initialized but never filled in this cell
strain = []
labels = []
for ii, arg in enumerate(args):
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=MSTAR, mrat=arg, redz=REDZ)
    gwb = sam.gwb_sa(FREQ)
    # Sum `gwb` over all of its elements, then take the square root
    hc = np.sqrt(np.sum(gwb))
    ax.plot(ii+1, hc, marker='o', label=str(arg))

plt.legend()
plt.show()

### redshift

In [None]:
fig, ax = zplot.figax(scale='linear')

# NOTE: `strain` and `labels` are initialized but never filled in this cell
strain = []
labels = []
for ii in range(10):
    # Upper redshift bound grows from 1 to 10; the number of points grows with it,
    # at one point per 0.05 in redshift
    arg = [0.0, 1.0 + ii, int((1.0 + ii)/0.05)]
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=MSTAR, mrat=MRAT, redz=arg)
    gwb = sam.gwb_sa(FREQ)
    # Sum `gwb` over all of its elements, then take the square root
    hc = np.sqrt(np.sum(gwb))
    ax.plot(ii+1, hc, marker='o', label=str(arg))

plt.legend()
plt.show()

In [None]:
# Redshift grids to test: fixed bounds, increasing number of points
args = [
    [0.0, 6.0, 40],
    [0.0, 6.0, 80],
    [0.0, 6.0, 100],
    [0.0, 6.0, 150],
    [0.0, 6.0, 200],
    [0.0, 6.0, 250],
    [0.0, 6.0, 300],
]

fig, ax = zplot.figax(scale='linear')

# NOTE: `strain` and `labels` are initialized but never filled in this cell
strain = []
labels = []
for ii, arg in enumerate(args):
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=MSTAR, mrat=MRAT, redz=arg)

    gwb = sam.gwb_sa(FREQ)   # [:, :, :, 1:]
    # Sum `gwb` over all of its elements, then take the square root
    hc = np.sqrt(np.sum(gwb))

    ax.plot(ii+1, hc, marker='o', label=str(arg))

plt.legend()
plt.show()