In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py

import kalepy as kale

import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Silence annoying numpy errors
# (floating-point divide-by-zero, invalid-value and overflow conditions are ignored globally)
np.seterr(divide='ignore', invalid='ignore', over='ignore')
# Suppress every warning whose category is `UserWarning`
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# Serif font family at size 15; 'Times' is registered for the 'sans-serif' list
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
# Rounded caps on solid lines
mpl.rc('lines', solid_capstyle='round')
# Computer-Modern font set for math text
mpl.rc('mathtext', fontset='cm')
# Semi-transparent grid lines
plt.rcParams.update({'grid.alpha': 0.5})

In [None]:
import zcode.math as zmath
import zcode.plot as zplot

In [None]:
# import holodeck.sam
# reload(holodeck.sam)
# sam = holodeck.sam.BP_Semi_Analytic()

# freqs = 1/YR
# gwb = sam.gwb_sa(freqs)

# hc = np.sqrt(np.sum(gwb, axis=(1, 2, 3)))
# hc.shape, hc

# Discretize Population

In [None]:
# Re-import and reload `holodeck.sam` so that edits to the module are picked up,
# then build a semi-analytic model with its default arguments.
import holodeck.sam
reload(holodeck.sam)
sam = holodeck.sam.BP_Semi_Analytic()
# Result of `sam.dnbh()`; NOTE(review): `nbh` is not read by any other cell shown here.
nbh = sam.dnbh()

In [None]:
# Frequency grid from `utils.nyquist_freqs` (arguments 20.0 and 0.1, trimmed with [None, 5.0]).
# NOTE(review): presumably a 20 yr baseline sampled every 0.1 yr, giving frequencies in 1/yr —
# confirm against the holodeck documentation.
freqs = utils.nyquist_freqs(20.0, 0.1, trim=[None, 5.0])

# Evaluate the model on the frequency grid (`freqs` divided by YR).  Returns the grid edges
# for each dimension, the gridded values `num_mbhb_fobs`, and a second gridded quantity `mbhb_hs`.
# NOTE(review): presumably number of binaries and strain per grid cell — confirm.
edges_fobs, num_mbhb_fobs, mbhb_hs = sam.num_mbhb(freqs/YR)
# Same edges, with the first and last dimensions converted to log10 (middle two unchanged)
log_edges_fobs = [np.log10(edges_fobs[0]), edges_fobs[1], edges_fobs[2], np.log10(edges_fobs[3])]

# Total of `num_mbhb_fobs` summed over the whole grid
num = num_mbhb_fobs.sum()
print(f"{num=:.4e}")

In [None]:
# Take index 0 along the last axis of the grid, then sum over the (new) last axis,
# leaving a 2D array over the first two grid dimensions.
vals = num_mbhb_fobs[..., 0].sum(axis=-1)
# Color map of that 2D array against the model's `mbh1` and `mrat` grids
plt.pcolormesh(sam.mbh1, sam.mrat, vals.T)
plt.show()

In [None]:
class Corner_Grid:
    """Corner plot (lower-triangular grid of panels) for an N-dimensional gridded quantity.

    Diagonal panels show the 1D marginal of `data` along each dimension (sum over all other
    axes).  Panels below the diagonal show `log10` of the 2D marginal over the corresponding
    pair of dimensions.  Panels above the diagonal are hidden.

    Attributes
    ----------
    fig : the matplotlib figure created by the constructor
    axes : 2D array of matplotlib axes, shape (ndim, ndim)

    """

    def __init__(self, edges, data, labels=None):
        """Create the figure, configure all panels, and draw `data`.

        Parameters
        ----------
        edges : sequence of 1D array-like
            Grid coordinates for each dimension, one entry per axis of `data`.
        data : array-like
            Gridded values with one axis per entry of `edges`.
        labels : sequence of str or None
            Axis label for each dimension; no labels are set when `None`.

        """
        shape = [len(ee) for ee in edges]
        ndim = len(shape)

        fsize = np.clip(ndim * 5, 8, 24)
        fsize = [fsize, fsize * 0.75]
        # `squeeze=False` guarantees a 2D array of axes, even for a single dimension
        fig, axes = plt.subplots(
            figsize=fsize, ncols=ndim, nrows=ndim, sharex='col', squeeze=False
        )

        self._ndim = ndim
        self._shape = shape
        self.fig = fig
        self.axes = axes

        self._edges = edges
        self._data = data
        self._labels = labels

        self.setup()
        self.draw(edges, data)

    @property
    def last(self):
        """Index of the final row/column of the panel grid."""
        return self._ndim - 1

    def setup(self):
        """Configure the panels: log y-scale on the diagonal, hide the upper triangle, add labels."""
        labels = self._labels

        def diag(jj, ax):
            ax.set_yscale('log')

        def offdiag(ii, jj, ax):
            # Panels above the diagonal are unused
            if jj > ii:
                ax.set_visible(False)
                return

            if labels is None:
                return

            # x-labels only along the bottom row
            if ii == self.last:
                ax.set_xlabel(labels[jj])

            # y-labels only along the left column, excluding the top-left (diagonal) panel
            if (jj == 0) and (ii > 0):
                ax.set_ylabel(labels[ii])

        # `skip=False`: visit every panel, including the diagonal and the upper triangle
        self.loop(diag, offdiag, skip=False)

    def loop(self, diag, offdiag, skip=True):
        """Apply callbacks to the panels and collect their return values.

        Parameters
        ----------
        diag : callable `(jj, ax)`
            Called once for each diagonal panel.
        offdiag : callable `(ii, jj, ax)`
            Called for each panel at row `ii`, column `jj`.
        skip : bool
            If True, `offdiag` is only called for panels strictly below the diagonal
            (`jj < ii`); if False it is called for every panel.

        Returns
        -------
        diag_list, offd_list : lists of the values returned by `diag` and `offdiag`.

        """
        axes = self.axes
        diag_list = []
        for jj, ax in enumerate(axes.diagonal()):
            rv = diag(jj, ax)
            diag_list.append(rv)

        offd_list = []
        for (ii, jj), ax in np.ndenumerate(axes):
            if skip and (jj >= ii):
                continue
            rv = offdiag(ii, jj, ax)
            offd_list.append(rv)

        return diag_list, offd_list

    def draw(self, edges, data):
        """Draw 1D marginals of `data` on the diagonal and log10 2D marginals below it."""
        ndim = self._ndim

        def diag(jj, ax):
            # Sum over every axis except `jj`
            others = tuple(kk for kk in range(ndim) if kk != jj)
            vv = np.sum(data, axis=others)
            return self._draw1d(ax, edges[jj], vv)

        def offdiag(ii, jj, ax):
            # Sum over every axis except `ii` and `jj`.  Only called with `jj < ii`, so the
            # remaining axes are ordered (jj, ii), i.e. (x, y) for this panel.
            others = tuple(kk for kk in range(ndim) if kk not in (ii, jj))
            vv = np.sum(data, axis=others)
            xx = edges[jj]
            yy = edges[ii]
            return self._draw2d(ax, [xx, yy], np.log10(vv))

        self.loop(diag, offdiag)

    def _draw1d(self, ax, edges, hist, **kwargs):
        """Plot a 1D distribution on `ax` and return the resulting line.

        If `edges` has one more element than `hist`, the values are drawn as a stepped
        outline between consecutive edges; if the lengths match, the values are plotted
        directly against `edges`.

        Raises
        ------
        ValueError : if the lengths of `edges` and `hist` are otherwise incompatible.

        """
        if len(edges) == len(hist) + 1:
            # Each interior edge appears twice and each value twice: (e0, e1, e1, e2, ...)
            xx = np.repeat(edges, 2)[1:-1]
            yy = np.repeat(hist, 2)
        elif len(edges) == len(hist):
            xx = edges
            yy = hist
        else:
            raise ValueError(
                f"Length of `edges` ({len(edges)}) must equal the length of `hist` "
                f"({len(hist)}) or exceed it by one!"
            )

        line, = ax.plot(xx, yy, **kwargs)
        return line

    def _draw2d(self, ax, edges, hist, mask_below=None, **kwargs):
        """Plot a 2D distribution on `ax` with `pcolormesh` and return the resulting mesh.

        Parameters
        ----------
        edges : [xx, yy]
            Grid coordinates along the x and y directions.
        hist : 2D array indexed as `[x, y]` (it is transposed before plotting).
        mask_below : scalar, None or False
            If a number (including zero), values less than or equal to it are masked.
            `None` or `False` disables masking.
        **kwargs : passed to `ax.pcolormesh`.

        """
        # Identity checks, so that a threshold of exactly zero is not mistaken for `False`
        if (mask_below is not None) and (mask_below is not False):
            hist = np.ma.masked_less_equal(hist, mask_below)
        kwargs.setdefault('shading', 'auto')
        # NOTE: this avoids edge artifacts when alpha is not unity!
        kwargs.setdefault('edgecolors', [1.0, 1.0, 1.0, 0.0])
        kwargs.setdefault('linewidth', 0.01)
        # Plot
        return ax.pcolormesh(*edges, hist.T, **kwargs)
    
# Corner plot of `num_mbhb_fobs` over the log-scaled grid edges
Corner_Grid(log_edges_fobs, num_mbhb_fobs, labels=['log10(M)', 'q', 'z', 'log10(f)'])
plt.show()

In [None]:
# Corner plot of `mbhb_hs` on the same grid edges and labels as the cell above
Corner_Grid(log_edges_fobs, mbhb_hs, labels=['log10(M)', 'q', 'z', 'log10(f)'])
plt.show()

# Sample full population

In [None]:
# Disabled alternative: use the grid total (Poisson-resampled) as the number of samples
# nsamp = num_mbhb_fobs.sum()
# print(f"dirst sum = {nsamp:.8e}")
# nsamp = np.random.poisson(nsamp)
# print(f"poisson   = {nsamp:.8e}")

# Fixed number of samples requested from `kale.sample_grid`
nsamp = 1e8
# Sample the gridded distribution `num_mbhb_fobs` defined on `log_edges_fobs`
samples = kale.sample_grid(log_edges_fobs, num_mbhb_fobs, nsamp)

In [None]:
# Summary statistics of the first sampled parameter (first row of `samples`)
utils.stats(samples[0])

## Down-Sample Grid

Modify the true distribution to downweight by a continuous function of chirp-mass and redshift.

In [None]:
def downsample_continuous_1(sam, edges_fobs, num_mbhb_fobs, nsamp=1e6):
    """Sample the grid with a sampling density down-weighted by chirp-mass and redshift.

    A per-cell weight is built as the squared product of a rescaled `sam.mchirp ** (5/3)`
    (log scaling) and a rescaled `1 - sam.redz` (linear scaling).  The product of that
    weight with `num_mbhb_fobs` is passed to `kale.sample_grid_proportional` as the
    proportional sampling distribution.

    Parameters
    ----------
    sam : object providing `mchirp` (indexed over the first two grid axes) and `redz`
    edges_fobs : grid edges passed through to `kale.sample_grid_proportional`
    num_mbhb_fobs : 4D array of gridded values
    nsamp : number of samples to request

    Returns
    -------
    sample, weights : as returned by `kale.sample_grid_proportional`, with `weights`
        rescaled so that their sum equals a Poisson draw about `num_mbhb_fobs.sum()`.

    """
    # Separate weights along (mass, mass-ratio) and along redshift
    wt_mchirp = zmath.rescale(np.power(sam.mchirp, 5.0/3.0), log=True, clip=True)
    wt_redz = zmath.rescale(1.0 - sam.redz, log=False, clip=True)
    # Broadcast both onto the 4D grid
    weight = (
        wt_mchirp[:, :, np.newaxis, np.newaxis]
        * wt_redz[np.newaxis, np.newaxis, :, np.newaxis]
    )
    print(f"{utils.stats(weight)=}")

    weight = weight ** 2
    print(f"{utils.stats(weight)=}")

    # Expand to the full shape of the grid
    weight = weight * np.ones_like(num_mbhb_fobs)

    ntot = num_mbhb_fobs.sum()
    # Proportional sampling density.  Alternatives previously tried here were the bare
    # `weight`, and the weighted counts divided by 30.0.
    density = num_mbhb_fobs * weight

    print(f"{ntot=:.8e}")
    poisson_frac = np.random.poisson(ntot) / ntot
    print(f"{poisson_frac-1=:.8e}")
    # (disabled) nsamp = int(density.sum() * poisson_frac)
    print(f"{nsamp=:.8e}")

    print("sampling...")
    drawn, weights = kale.sample_grid_proportional(
        edges_fobs, num_mbhb_fobs, density, nsamp, interpolate=True
    )
    # Normalize so the weights sum to the Poisson-perturbed total
    weights *= (poisson_frac * ntot / weights.sum())
    print(f"{utils.stats(weights)=}")
    print(f"{ntot=:.4e}, {nsamp=:.4e}, {weights.sum()=:.4e}")

    return drawn, weights

# Draw the down-weighted sample using the grid edges returned by `sam.num_mbhb`
# (`edges_fobs`); the name `edges_freqs` is never assigned at notebook level.
sample, weights = downsample_continuous_1(sam, edges_fobs, num_mbhb_fobs)
# Grid total, sum of sample weights, and their ratio
print(num_mbhb_fobs.sum(), weights.sum(), weights.sum()/num_mbhb_fobs.sum())

Modify the true distribution to downweight bins using step functions at particular values of each variable. For example, decrease the weight of all bins with M < 1e8 by 10x, then decrease the weight of bins with M < 1e7 by a further 10x, and so on.

In [None]:
def downsample_steps_1(edges_freqs, num_mbhb_fobs, nsamp=1e6, down=10.0):
    """Sample the grid with a sampling density reduced in steps along each dimension.

    Starting from a copy of `num_mbhb_fobs`, the entries along each axis whose grid
    coordinate passes a threshold comparison are divided by `down`, once per threshold.
    The result is multiplied by `num_mbhb_fobs`, normalized to unit sum, and passed to
    `kale.sample_grid_proportional` as the proportional sampling distribution.

    Parameters
    ----------
    edges_freqs : sequence of grid coordinates, one per axis of `num_mbhb_fobs`
    num_mbhb_fobs : array of gridded values
    nsamp : number of samples to request
    down : factor by which the density is divided at each step

    Returns
    -------
    sample, weights : as returned by `kale.sample_grid_proportional`, with `weights`
        rescaled to sum to `num_mbhb_fobs.sum()`.

    """
    portion = np.copy(num_mbhb_fobs)
    print(f"{portion.sum()=:.4e}")

    # One (comparison, thresholds) pair per grid axis, in axis order;
    # the last axis has no thresholds and is left untouched.
    steps = [
        (np.less, [1e8, 1e7, 1e6]),
        (np.less, [0.1, 0.05]),
        (np.greater, [2.0, 4.0, 5.0]),
        (None, []),
    ]

    for axis, (grid, (compare, thresholds)) in enumerate(zip(edges_freqs, steps)):
        # Bring the current axis to the front so a 1D boolean mask selects along it
        portion = np.moveaxis(portion, axis, 0)
        for thresh in thresholds:
            sel = compare(grid, thresh)
            portion[sel] = portion[sel] / down
        portion = np.moveaxis(portion, 0, axis)

    print(f"{portion.sum()=:.4e}")
    # From here on `down` holds the per-cell reduction factor (NaNs replaced)
    down = np.nan_to_num(num_mbhb_fobs/portion)
    print(f"{utils.stats(down)=}")

    portion *= num_mbhb_fobs
    portion /= portion.sum()

    ntot = num_mbhb_fobs.sum()
    print(f"{ntot=:.8e}")
    # Poisson variation of the total (np.random.poisson(ntot) / ntot) is disabled here
    samp_frac = 1.0
    print(f"{samp_frac-1=:.8e}")
    print(f"{nsamp=:.8e}")

    drawn, weights = kale.sample_grid_proportional(edges_freqs, num_mbhb_fobs, portion, nsamp)
    weights *= (ntot * samp_frac / weights.sum())

    print(f"{ntot=:.4e}, {samp_frac=:.4e}, {weights.sum()=:.4e}")
    print(f"{utils.stats(weights)=}")

    return drawn, weights

# Draw the step-down-weighted sample using the grid edges returned by `sam.num_mbhb`
# (`edges_fobs`); the name `edges_freqs` is never assigned at notebook level.
sample, weights = downsample_steps_1(edges_fobs, num_mbhb_fobs)

In [None]:
# Flattened grid values sorted in ascending order (`aa`) and their running total (`bb`)
idx = np.argsort(num_mbhb_fobs.flatten())
aa = num_mbhb_fobs.flatten()[idx]
bb = np.cumsum(aa)

In [None]:

def f0(num_mbhb):
    """Return a copy of `num_mbhb` in which every value above 10 is replaced by 10."""
    capped = np.copy(num_mbhb)
    capped[capped > 10.0] = 10.0
    return capped


def f1(num_mbhb):
    """Return a copy of `num_mbhb` in which values above 10 are replaced by their square root."""
    result = np.copy(num_mbhb)
    large = np.greater(result, 10.0)
    result[large] = np.sqrt(result[large])
    return result


def f2(num_mbhb):
    """Return a copy of `num_mbhb` in which values above 3 are replaced by their square root."""
    result = np.copy(num_mbhb)
    large = np.greater(result, 3.0)
    result[large] = np.sqrt(result[large])
    return result


def f3(num_mbhb):
    """Return a copy of `num_mbhb` in which every value above 3 is replaced by 3."""
    capped = np.copy(num_mbhb)
    capped[capped > 3.0] = 3.0
    return capped


def f4(num_mbhb):
    """Return a copy of `num_mbhb` in which every value above 10 is replaced by 1."""
    result = np.copy(num_mbhb)
    result[result > 10.0] = 1.0
    return result


# Compare the cumulative totals produced by each candidate transform of the grid values
fig, ax = zplot.figax(xlim=[1e-4, 1e6], ylim=[1e0, 1e10])

funcs = [f0, f1, f2, f3, f4]
# Reference curve: sorted, untransformed grid values (`aa`) against their running total (`bb`)
plt.plot(aa, bb, 'k--')

for ii, ff in enumerate(funcs):
    # Transformed grid values, sorted ascending, and their running total
    num = ff(num_mbhb_fobs).flatten()
    xx = np.sort(num)
    yy = np.cumsum(xx)
    hh, = ax.plot(xx, yy, label=ii)
    # Final cumulative value, i.e. the sum of all transformed values
    ymax = yy[-1]
    print(f"{ii=}, {ymax=:.2e}")
    # Horizontal marker at that total, in the same color as the curve
    ax.axhline(ymax, color=hh.get_color(), ls='--')

ax.legend()
plt.show()

In [None]:
# Histogram of the first parameter: uniformly weighted full sample (red) against a
# weighted sample (blue), both binned on `log_edges_fobs[0]`.
fig, ax = zplot.figax()

# Uniform weight per sample, such that the weights sum to the grid total
ww = np.ones_like(samples[0]) * num_mbhb_fobs.sum()/samples.shape[-1]
print(ww[0])
ax.hist(samples[0], weights=ww, bins=log_edges_fobs[0], color='r', histtype='step')
# ax.hist(np.log10(vals[0]), weights=weights, bins=log_edges_fobs[0], color='b', histtype='step')
# NOTE(review): reading top-to-bottom, `vals` was last assigned from the summed grid slice and
# `weights` from `downsample_steps_1`; the outlier-sample `vals`/`weights` are only defined in a
# later cell — confirm the intended execution order.
ax.hist(vals[0], weights=weights, bins=log_edges_fobs[0], color='b', histtype='step')

plt.show()

In [None]:
# Build an outlier sampler over the gridded distribution with `threshold=10.0`, and draw from it.
# `sample()` returns three values, stored as the sample count, the sample values and their weights.
sampler_outlier = kale.sample.Outlier(log_edges_fobs, num_mbhb_fobs, threshold=10.0)
nsamp, _vals, _weights = sampler_outlier.sample()

In [None]:
# Threshold on the product (10 ** vals[0]) * vals[1]; set to `None` to disable the cut.
# NOTE(review): presumably the secondary mass (total/primary mass times mass-ratio) — confirm.
CUT_BELOW_MASS_SEC = 1e6

# Work on copies so the raw sampler output (`_vals`, `_weights`) is left untouched
vals = np.copy(_vals)
weights = np.copy(_weights)

if CUT_BELOW_MASS_SEC is not None:
    # cut_below = np.log10(CUT_BELOW_MASS_SEC)
    cut_below = CUT_BELOW_MASS_SEC    
    # Flag samples whose product falls below the threshold
    bads = ((10.0 ** vals[0]) * vals[1] < cut_below)
    print(f"Cutting {zmath.frac_str(bads)}")
    
    # Drop flagged samples (samples run along the last axis of `vals`)
    vals = vals.T[~bads].T
    weights = weights[~bads]
    print(weights.size, vals.shape)

In [None]:
# Corner plot of the gridded distribution, kept in `corner` so samples can be over-plotted
corner = Corner_Grid(log_edges_fobs, num_mbhb_fobs, labels=['log10(M)', 'q', 'z', 'log10(f)'])

def draw_samples_in_out(corner, in_only, out_only, color):
    """Over-plot the global samples (`vals`, `weights`) on an existing corner plot.

    The samples are split at the index of the first weight exceeding unity.  With
    `out_only` the samples before that index are drawn; otherwise, with `in_only`, the
    samples from that index onward are drawn; with neither, all samples are drawn.

    Diagonal panels receive a weighted step histogram binned on `log_edges_fobs`;
    panels below the diagonal receive a scatter plot of the sample values.

    Parameters
    ----------
    corner : object providing `loop(diag, offdiag)` (e.g. a `Corner_Grid`)
    in_only, out_only : bool
    color : matplotlib color used for both histograms and scatter points

    """
    # Position of the first sample with weight above one
    split = np.argmax(weights > 1.0)

    # Slice selecting the requested subset (`out_only` takes precedence over `in_only`)
    if out_only:
        window = slice(None, split)
    elif in_only:
        window = slice(split, None)
    else:
        window = None

    def take(*args):
        if window is None:
            return args
        return [arr[window] for arr in args]

    def diag(jj, ax):
        samp, wts = take(vals[jj], weights)
        ax.hist(
            samp, weights=wts, bins=log_edges_fobs[jj],
            color=color, histtype='step', zorder=100, lw=2.0, alpha=0.5
        )

    def offdiag(ii, jj, ax):
        xs, ys = take(vals[jj], vals[ii])
        ax.scatter(xs, ys, facecolor=color, edgecolor='none', zorder=100, s=5, alpha=0.25)

    corner.loop(diag, offdiag)

# `in_only` subset in red, `out_only` subset in blue, on the same corner plot
draw_samples_in_out(corner, True, False, 'r')
draw_samples_in_out(corner, False, True, 'b')

plt.show()

In [None]:
# Corner plot of the sampler's `_data_ins` grid with the `in_only` subset of samples over-plotted.
# NOTE(review): `_data_ins` is a private attribute of the sampler — confirm it is stable to rely on.
corner = Corner_Grid(log_edges_fobs, sampler_outlier._data_ins, labels=['log10(M)', 'q', 'z', 'log10(f)'])
draw_samples_in_out(corner, True, False, 'r')

In [None]:
# Corner plot of the sampler's `_data_outs` grid with the `out_only` subset of samples over-plotted.
# NOTE(review): `_data_outs` is a private attribute of the sampler — confirm it is stable to rely on.
corner = Corner_Grid(log_edges_fobs, sampler_outlier._data_outs, labels=['log10(M)', 'q', 'z', 'log10(f)'])
draw_samples_in_out(corner, False, True, 'r')

In [None]:
# Strain of each sampled binary.
# First two sample parameters: undo the log10 on the first, then multiply the second by it,
# giving two mass-like quantities (NOTE(review): presumably primary and secondary mass — confirm).
mc = np.copy(vals[:2, :])
mc[0] = 10.0 ** mc[0]
mc[1] = mc[1] * mc[0]
print("\t", utils.stats(mc))
# Chirp mass from the two masses
mc = utils.chirp_mass(*mc)
print(utils.stats(mc))
# Third sample parameter; used as redshift below, then replaced by the luminosity distance
dl = vals[2, :]
# Fourth parameter is log10 of frequency; multiplying by (1 + z) shifts it
# (NOTE(review): presumably observed -> rest frame — confirm)
frst = (10.0 ** vals[3]) * (1.0 + dl)
# Luminosity distance in cgs units
dl = cosmo.luminosity_distance(dl).cgs.value
print(utils.stats(dl))
print(utils.stats(frst))
# Chirp mass is multiplied by MSOL before being passed to the strain function
hs = utils.gw_strain_source(mc * MSOL, dl, frst)
print(hs.shape, utils.stats(hs))

In [None]:
# For each frequency bin: find the loudest sampled source (`gwf`, at frequency `ffr`), and
# sum the weighted squared strains of all sources in the bin minus that loudest source (`gwb`).
num_freq = freqs.size - 1
gwb = np.zeros(num_freq)
gwf = np.zeros_like(gwb)
ffr = np.zeros_like(gwb)

for ii in range(num_freq):
    # Bin edges in log10 of frequency
    lo = log_edges_fobs[-1][ii]
    hi = log_edges_fobs[-1][ii+1]
    fr_bin = vals[-1, :]
    idx = (lo < fr_bin) & (fr_bin < hi)
    # `np.argmax` raises on an empty selection: leave the strains of an empty bin at zero
    # and mark the frequency of its (non-existent) loudest source as undefined.
    if not np.any(idx):
        ffr[ii] = np.nan
        continue

    hs_bin = hs[idx]
    fr_bin = fr_bin[idx]
    ww_bin = weights[idx]

    # Loudest single source in this bin
    floc = np.argmax(hs_bin)
    ffr[ii] = fr_bin[floc]
    gwf[ii] = hs_bin[floc]
    # Weighted sum of squared strains, with the loudest source removed once
    gwb[ii] = np.sum(ww_bin * np.square(hs_bin))
    gwb[ii] = gwb[ii] - gwf[ii]**2

# Convert log10 frequency back to linear and multiply by YR
ffr = (10.0 ** ffr) * YR
gwb = np.sqrt(gwb)

In [None]:
# Plot the loudest source per bin (red points), the remaining background (black line), and
# their quadrature sum (dotted blue), against a reference power-law.
fig, ax = zplot.figax()

ax.scatter(ffr, gwf, color='r')

# Bin centers of the frequency grid
xx = kale.utils.midpoints(freqs)
ax.plot(xx, gwb, 'k-')
ax.plot(xx, np.sqrt(gwb**2 + gwf**2), 'b:', alpha=0.2)

# Reference line: amplitude 3e-16 at xx = 1 with a -2/3 power-law index
amp = 3e-16
yy = amp * np.power(xx, -2/3)
ax.plot(xx, yy, 'k--', alpha=0.25)

plt.show()

In [None]:
# Four side-by-side 2D maps over (`sam.mbh1`, `sam.mrat`): the gridded values, the sampling
# portion, the raw sample counts, and the weighted sample counts.
# NOTE(review): `num_mbhb`, `portion` and `edges` are not assigned at the top level of any cell
# shown in this notebook (`portion` only exists inside the `downsample_*` functions) — confirm
# where they are meant to come from before running on a fresh kernel.
# NOTE(review): this cell also rebinds `vals`, which other cells use for the sample values.
ii = 0
bins = [sam.mbh1, sam.mrat]

fig, axes = plt.subplots(figsize=[20, 4], ncols=4)

# Panel 1: grid values at index 0 of the last axis, summed over the following axis
ax = axes[ii]; ii += 1
vals = num_mbhb[..., 0].sum(axis=-1)
# vals = num_mbhb.sum(axis=(-1, -2))

# norm = mpl.colors.LogNorm(vals[vals > 0].min(), vals.max())
# norm = mpl.colors.Normalize(vals.min(), vals.max())
# print(norm.vmin, norm.vmax, norm)
# kw = dict(norm=norm)
kw = dict()

ax.pcolormesh(*bins, vals.T, **kw)


# Panel 2: the same reduction applied to `portion`
ax = axes[ii]; ii += 1
vals = portion[..., 0].sum(axis=-1)
# vals = portion.sum(axis=(-1, -2))
ax.pcolormesh(*bins, vals.T, **kw)


# Select samples whose last parameter lies below the second edge of the last dimension
idx = (sample[-1] < edges[-1][1])

# Panel 3: unweighted 2D histogram of the selected samples
ax = axes[ii]; ii += 1
hist, *_ = np.histogram2d(sample[0][idx], sample[1][idx], bins=bins)
ax.pcolormesh(*bins, hist.T, **kw)


# Panel 4: the same histogram, weighted by the sample weights
ax = axes[ii]; ii += 1
hist, *_ = np.histogram2d(sample[0][idx], sample[1][idx], bins=bins, weights=weights[idx])
ax.pcolormesh(*bins, hist.T, **kw)


plt.show()

In [None]:
# Map of log10 |grid / weighted-sample-histogram - 1| over (`sam.mbh1`, `sam.mrat`).
# NOTE(review): `num_mbhb` and `edges` are not assigned at the top level of any cell shown in
# this notebook — confirm where they are meant to come from before running on a fresh kernel.
ii = 0
bins = [sam.mbh1, sam.mrat]

fig, ax = plt.subplots(figsize=[20, 10])

# Grid values at index 0 of the last axis, summed over the following axis
vals = num_mbhb[..., 0].sum(axis=-1)

# Weighted 2D histogram of samples whose last parameter lies below the second edge
idx = (sample[-1] < edges[-1][1])
hist, *_ = np.histogram2d(sample[0][idx], sample[1][idx], bins=bins, weights=weights[idx])

# Fractional difference, left as NaN wherever the histogram is empty
data = np.ones_like(vals) * np.nan
idx = (hist > 0)
data[idx] = (vals[idx] / hist[idx]) - 1.0
print(utils.stats(data))

data = np.log10(np.fabs(data))

pcm = ax.pcolormesh(*bins, data.T)
plt.colorbar(pcm)

plt.show()

In [None]:
# Ratio of `portion` to `num_mbhb` along the first axis, at fixed indices (20, 10, 10) of the others.
# NOTE(review): `portion` and `num_mbhb` are not assigned at the top level of any cell shown in
# this notebook — confirm where they are meant to come from.
cut = tuple([slice(None), 20, 10, 10])
portion[cut] / num_mbhb[cut]

# Examine convergence properties

In [None]:
# Fiducial settings for the convergence tests below.
# Frequency at which the GWB is evaluated
FREQ = 1.0 / YR
# Grid specifications passed as `mstar_pri`, `mrat` and `redz` respectively.
# NOTE(review): presumably [min, max, number of points], with MSTAR in log10 — confirm against holodeck.sam.
MSTAR = [8.5, 13.0, 46]
MRAT = [0.02, 1.0, 50]
REDZ = [0.0, 6.0, 61]
# np.logspace(*MSTAR)

### mass

In [None]:
# Grid specifications to compare, each passed as `mstar_pri`.  The first three entries vary only
# the second value; the last three vary only the third value.
mstar_args = [
    [8.5, 13, 41],    
    [8.5, 13.5, 41],    
    [8.5, 14.0, 41],    
    [8.5, 13, 41],    
    [8.5, 13, 61],    
    [8.5, 13, 81],    
]

fig, ax = zplot.figax(scale='linear')

# NOTE(review): `strain` and `labels` are initialised but never filled or read in this cell.
strain = []
labels = []
for ii, mstar in enumerate(mstar_args):
    # Rebuild the model with this mass grid; mass-ratio and redshift grids stay at the fiducial values
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=mstar, mrat=MRAT, redz=REDZ)
    gwb = sam.gwb_sa(FREQ)
    # Square root of the sum over all elements of `gwb`
    hc = np.sqrt(np.sum(gwb))
    # One point per configuration, at x = configuration number (1-based)
    ax.plot(ii+1, hc, marker='o', label=str(mstar))
    
plt.legend()
plt.show()

### mass-ratio

In [None]:
# Grid specifications to compare, each passed as `mrat`.  The first eight entries vary only the
# third value; the last three vary only the first value.
args = [
    [0.02, 1.0, 40],
    [0.02, 1.0, 50],
    [0.02, 1.0, 60],
    [0.02, 1.0, 70],
    [0.02, 1.0, 80],
    [0.02, 1.0, 160],
    [0.02, 1.0, 320],
    [0.02, 1.0, 640],
    [0.02, 1.0, 160],
    [0.002, 1.0, 160],
    [0.0002, 1.0, 160],
]

fig, ax = zplot.figax(scale='linear')

# NOTE(review): `strain` and `labels` are initialised but never filled or read in this cell.
strain = []
labels = []
for ii, arg in enumerate(args):
    # Rebuild the model with this mass-ratio grid; mass and redshift grids stay at the fiducial values
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=MSTAR, mrat=arg, redz=REDZ)
    gwb = sam.gwb_sa(FREQ)
    # Square root of the sum over all elements of `gwb`
    hc = np.sqrt(np.sum(gwb))
    # One point per configuration, at x = configuration number (1-based)
    ax.plot(ii+1, hc, marker='o', label=str(arg))
    
plt.legend()
plt.show()

### redshift

In [None]:
# Vary the second `redz` value from 1 to 10, scaling the third value with it
# (one count per 0.05 of the second value).
fig, ax = zplot.figax(scale='linear')

# NOTE(review): `strain` and `labels` are initialised but never filled or read in this cell.
strain = []
labels = []
for ii in range(10):
    arg = [0.0, 1.0 + ii, int((1.0 + ii)/0.05)]
    # Rebuild the model with this redshift grid; mass and mass-ratio grids stay at the fiducial values
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=MSTAR, mrat=MRAT, redz=arg)
    gwb = sam.gwb_sa(FREQ)
    # Square root of the sum over all elements of `gwb`
    hc = np.sqrt(np.sum(gwb))
    # One point per configuration, at x = configuration number (1-based)
    ax.plot(ii+1, hc, marker='o', label=str(arg))
    
plt.legend()
plt.show()

In [None]:
# Grid specifications to compare, each passed as `redz`; only the third value varies.
args = [
    [0.0, 6.0, 40],
    [0.0, 6.0, 80],
    [0.0, 6.0, 100],
    [0.0, 6.0, 150],
    [0.0, 6.0, 200],
    [0.0, 6.0, 250],
    [0.0, 6.0, 300],
]

fig, ax = zplot.figax(scale='linear')

# NOTE(review): `strain` and `labels` are initialised but never filled or read in this cell.
strain = []
labels = []
for ii, arg in enumerate(args):
    # Rebuild the model with this redshift grid; mass and mass-ratio grids stay at the fiducial values
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=MSTAR, mrat=MRAT, redz=arg)
    
    gwb = sam.gwb_sa(FREQ)   # [:, :, :, 1:]
    # Square root of the sum over all elements of `gwb`
    hc = np.sqrt(np.sum(gwb))

    # One point per configuration, at x = configuration number (1-based)
    ax.plot(ii+1, hc, marker='o', label=str(arg))
    
plt.legend()
plt.show()