In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
# import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sam
from holodeck import utils, plot, cosmo
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC

# Silence numpy floating-point warnings (divide-by-zero, invalid, overflow) and all
# UserWarnings -- note this hides them globally for the whole notebook session.
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# Times is a serif face, so it belongs in the 'serif' list of the selected 'serif' family;
# listing it under 'sans-serif' (as before) meant it was never used.  Fallbacks are given
# in case Times is not installed.
mpl.rc('font', **{'family': 'serif', 'serif': ['Times', 'Times New Roman', 'DejaVu Serif'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# holodeck's logger, shown at INFO level
log = holo.log
log.setLevel(logging.INFO)

# Quick Start

In [None]:
# ---- Construct S.A.M. on explicit (mtot, mrat, redz) grids ----
# Components (GSMF, GPF, GMT, M-MBulge) are not passed, so the constructor's defaults apply.
# Each grid is presumably [min, max, number-of-points] -- TODO confirm against Semi_Analytic_Model.
MTOT_GRID = [2.75e5*MSOL, 1.0e11*MSOL, 23]   # total mass, in multiples of MSOL
MRAT_GRID = [0.02, 1.0, 25]                  # mass ratio
REDZ_GRID = [0.0, 6.0, 31]                   # redshift
sam = holo.sam.Semi_Analytic_Model(mtot=MTOT_GRID, mrat=MRAT_GRID, redz=REDZ_GRID)

In [None]:
# ---- Choose frequencies, and calculate GWB ----
# Frequencies from `nyquist_freqs`, given 15 yr and 0.1 yr (presumably duration and cadence).
fobs = utils.nyquist_freqs(15.0*YR, 0.1*YR)
periods_yr = 1 / (fobs * YR)    # corresponding periods [yr]
print(utils.stats(periods_yr), "[yr]")

In [None]:
# Calculate the GWB at each frequency for 10 different realizations
NUM_REALS_QUICK = 10
gwb = sam.gwb(fobs, realize=NUM_REALS_QUICK)

In [None]:
# ---- Plot GWB ----
# Median strain over realizations (solid) with 50% and 98% inter-percentile bands,
# plus a reference h_c = 1e-15 * f^(-2/3) power law (dashed).

fig, ax = plt.subplots(figsize=[15, 8])
ax.set(xscale='log', xlabel='Frequency [1/yr]', yscale='log', ylabel='Characteristic Strain')
ax.grid(alpha=0.2)

freq_yr = fobs * YR   # [1/sec] ==> [1/yr]
ref_strain = 1e-15 * np.power(freq_yr, -2.0/3.0)
ax.plot(freq_yr, ref_strain, 'k--', alpha=0.25, lw=2.0)
ax.plot(freq_yr, np.median(gwb, axis=-1), 'k-')

for width in (50, 98):
    half = width / 2
    lo, hi = np.percentile(gwb, [50 - half, 50 + half], axis=-1)
    ax.fill_between(freq_yr, lo, hi, alpha=0.25, color='b')

plt.show()

# Build SAM Component-by-Component

In [None]:
# ---- Build each SAM component explicitly, with its default parameters ----
gsmf = holo.sam.GSMF_Schechter()        # Galaxy Stellar-Mass Function (GSMF)
gpf = holo.sam.GPF_Power_Law()          # Galaxy Pair Fraction         (GPF)
gmt = holo.sam.GMT_Power_Law()          # Galaxy Merger Time           (GMT)
mmbulge = holo.sam.MMBulge_Simple()     # M-MBulge Relation            (MMB)

components = dict(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge)
sam = holo.sam.Semi_Analytic_Model(**components)

In [None]:
# ---- Calculate GWB Amplitude distribution at 1/yr ----
# 100 realizations of the GWB at a frequency of one inverse year
freq_1yr = 1.0 / YR    # [1/sec]
ayr = sam.gwb(freq_1yr, realize=100)

In [None]:
# Distribution of the GWB amplitude at 1/yr across realizations.
# The x-axis is labeled log10(A_yr), so plot the log of the amplitudes; previously the
# raw (linear) amplitudes were drawn under a log10 label.
fig, ax = plt.subplots(figsize=[10, 6])
ax.set(xlabel=r'$\log_{10}(A_\mathrm{yr})$', ylabel='Probability Density')
ax.grid(alpha=0.2)
kale.dist1d(np.log10(ayr), density=True, confidence=True)

plt.show()

## Plot GWB Amplitude Distribution vs. M-MBulge parameters

In [None]:
# ---- Sweep the M-MBulge relation (slope, mass normalization) ----
# For each pair, build a SAM that differs from the default only in its M-MBulge relation,
# and draw NREALS realizations of the GWB amplitude at 1/yr.
# Loop-local names (`mmb_trial`, `sam_trial`) are used so the notebook-level `mmbulge`
# and `sam` are not silently replaced by the last (non-default) iteration and then
# picked up by later cells.
alpha_list = [0.75, 1.0, 1.25]
norm_list = np.logspace(7, 9, 11)    # M-MBulge mass normalization, in multiples of MSOL
NREALS = 100
FREQ = 1.0 / YR   # [1/s]

dist_mmb = np.zeros((len(alpha_list), norm_list.size, NREALS))

for aa, alpha in enumerate(tqdm.tqdm(alpha_list)):
    for nn, norm in enumerate(tqdm.tqdm(norm_list, leave=False)):
        mmb_trial = holo.sam.MMBulge_Simple(mass_norm=norm*MSOL, malpha=alpha)
        sam_trial = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmb_trial)
        dist_mmb[aa, nn, :] = sam_trial.gwb(FREQ, realize=NREALS)

In [None]:
# GWB amplitude at 1/yr vs. M-MBulge normalization, one curve per M-MBulge slope:
# median over realizations (line) and inter-quartile range (shaded).
fig, ax = plt.subplots(figsize=[6, 4])
ax.set(xscale='log', xlabel='M-MBulge Mass', yscale='log', ylabel=r'GWB $A_\mathrm{yr}$')
ax.grid(alpha=0.2)

for slope, amps in zip(alpha_list, dist_mmb):
    line, = ax.plot(norm_list, np.median(amps, axis=-1), label=slope)
    q25, q75 = np.percentile(amps, [25, 75], axis=-1)
    ax.fill_between(norm_list, q25, q75, color=line.get_color(), alpha=0.25)

plt.legend(title='slope')
plt.show()

## Plot GWB Amplitude vs. Schechter Mass Parameter

In [None]:
# ---- Sweep the GSMF Schechter reference-mass parameters ----
# `mz` values are passed as `mrefz` and `m0` values as `mref0` (both in multiples of MSOL;
# presumably the redshift-dependence and z=0 value of the characteristic mass).
# Only the GSMF is varied.  A default M-MBulge relation is constructed explicitly here,
# because a preceding M-MBulge sweep may have left a non-default `mmbulge` in the
# notebook namespace.  Loop-local names keep `gsmf` / `sam` untouched.
mz = np.array([0.0, 1e10, 1e11])
m0 = np.logspace(10, 12, 3)
NREALS = 100
FREQ = 1.0 / YR   # [1/s]

mmbulge_default = holo.sam.MMBulge_Simple()
dist_gsmf = np.zeros((len(mz), len(m0), NREALS))

for aa, _mz in enumerate(tqdm.tqdm(mz)):
    for nn, _m0 in enumerate(tqdm.tqdm(m0, leave=False)):
        gsmf_trial = holo.sam.GSMF_Schechter(mref0=_m0 * MSOL, mrefz=_mz * MSOL)
        sam_trial = holo.sam.Semi_Analytic_Model(
            gsmf=gsmf_trial, gpf=gpf, gmt=gmt, mmbulge=mmbulge_default
        )
        dist_gsmf[aa, nn, :] = sam_trial.gwb(FREQ, realize=NREALS)

In [None]:
# GWB amplitude at 1/yr vs. Schechter `mref0`, one curve per `mrefz` value:
# median over realizations (line) and inter-quartile range (shaded).
fig, ax = plt.subplots(figsize=[16, 6])
ax.set(xscale='log', xlabel='Schechter Mass', yscale='log', ylabel=r'GWB $A_\mathrm{yr}$')
ax.grid(alpha=0.2)

for mz_val, amps in zip(mz, dist_gsmf):
    # positive values are labeled as a power of ten, zero as a plain number
    if mz_val > 0.0:
        label = f"$10^{{{np.log10(mz_val):+.1f}}}$"
    else:
        label = f"${mz_val:+.1f}$"
    line, = ax.plot(m0, np.median(amps, axis=-1), label=label)
    ax.fill_between(m0, *np.percentile(amps, [25, 75], axis=-1), color=line.get_color(), alpha=0.25)

plt.legend(title='$M_0(z)$')
plt.show()

# Discretize Population

In [None]:
# Fresh default SAM, and frequencies from `nyquist_freqs(20 yr, 0.1 yr)`.
# (Alternative grid previously tried: utils.nyquist_freqs(200.0*YR, 1.0*YR))
sam = holo.sam.Semi_Analytic_Model()
fobs = utils.nyquist_freqs(20.0*YR, 0.1*YR)

In [None]:
# Sample discrete binaries from the SAM and compute their GW signals.
# Presumably returns (source frequencies, source strains, background strain per bin)
# -- TODO confirm against `sampled_gws_from_sam`.
sample_kwargs = dict(sample_threshold=10.0, cut_below_mass=1e7*MSOL, limit_merger_time=4*GYR)
gff, gwf, gwb = holo.sam.sampled_gws_from_sam(sam, fobs, **sample_kwargs)

In [None]:
# Background strain per frequency bin (black line) and sampled sources (red):
# filled markers where a source exceeds the background, open markers otherwise.
fig, ax = plot.figax()
bin_centers_yr = kale.utils.midpoints(fobs) * YR   # [1/sec] ==> [1/yr]

# reference power law: h_c = 1e-15 * f^(-2/3)
ax.plot(bin_centers_yr, 1e-15 * np.power(bin_centers_yr, -2/3), 'k--', alpha=0.25)
ax.plot(bin_centers_yr, gwb, 'k-')

loud = (gwf > gwb)
src_freq_yr = gff * YR   # [1/sec] ==> [1/yr]
ax.scatter(src_freq_yr[loud], gwf[loud], color='r', s=20, alpha=0.5)
ax.scatter(src_freq_yr[~loud], gwf[~loud], edgecolor='r', facecolor='none', s=20, alpha=0.5)

plt.show()

In [None]:
# Sample the SAM population evolved with GW-driven hardening.
# `cut_below_mass` must be given in the same units as the SAM masses (multiples of MSOL),
# as in the other calls in this notebook (`1e7*MSOL`).  A bare `1e7` is many orders of
# magnitude below 1e7 Msol, i.e. effectively no cut at all.
vals, weights, edges, dens = holo.sam.sample_sam_with_hardening(
    sam, holo.evolution.Hard_GW, fobs=fobs,
    sample_threshold=10.0, cut_below_mass=1e7*MSOL, limit_merger_time=4*GYR,
)

In [None]:
# Recompute GW signals directly from the sampled (vals, weights) and plot them:
# background strain per bin (black), sources above (filled red) / below (open red) it.
gwf_freqs, gwf, gwb = holo.sam._gws_from_samples(vals, weights, fobs)

fig, ax = plot.figax()
bin_centers_yr = kale.utils.midpoints(fobs) * YR   # [1/sec] ==> [1/yr]

# reference power law: h_c = 1e-15 * f^(-2/3)
ax.plot(bin_centers_yr, 1e-15 * np.power(bin_centers_yr, -2/3), 'k--', alpha=0.25)

# strains are plotted unscaled here
ax.plot(bin_centers_yr, gwb, 'k-')

loud = (gwf > gwb)
src_freq_yr = gwf_freqs * YR   # [1/sec] ==> [1/yr]
ax.scatter(src_freq_yr[loud], gwf[loud], color='r', s=20, alpha=0.5)
ax.scatter(src_freq_yr[~loud], gwf[~loud], edgecolor='r', facecolor='none', s=20, alpha=0.5)

plt.show()

## Step by Step

In [None]:
# Differential number of binaries on the SAM grid, evolved with GW-only hardening.
# Every other call to `number_from_hardening` in this notebook unpacks three values
# (edges, number, strain); `*_` absorbs any trailing return values so this unpacking
# works either way.
edges, dnum, *_ = sam.number_from_hardening(holo.evolution.Hard_GW, fobs=fobs)

In [None]:
# Copy the differential number so the cut below does not modify `dnum`, and build the
# sampling-space edges: log10 for edges[0] (total mass), unchanged for edges[1] (mass
# ratio) and edges[2] (presumably redshift), natural-log for edges[3] (frequency).
dens = np.copy(dnum)
log_edges = [np.log10(edges[0]), edges[1], edges[2], np.log(edges[3])]

# Zero out grid points whose secondary mass is below the threshold (set to None to disable).
cut_below_mass = 1e7 * MSOL
if cut_below_mass is not None:
    mt_grid = edges[0][:, np.newaxis]
    q_grid = edges[1][np.newaxis, :]
    # With q = m2/m1 <= 1 and M = m1 + m2, the secondary mass is m2 = M q / (1 + q);
    # the previous `M * q` overestimated m2 by a factor (1 + q).
    m2 = mt_grid * q_grid / (1.0 + q_grid)
    bads = (m2 < cut_below_mass)
    dens[bads] = 0.0

mass = holo.sam._integrate_differential_number(edges, dens, freq=True)

In [None]:
# Sample the density in the transformed (log) space, then map the mass (log10) and
# frequency (natural-log) coordinates of the samples back to linear values.
sample_threshold = 10.0
vals, weights = kale.sample_outliers(log_edges, dens, sample_threshold, mass=mass)
vals[0] = np.power(10.0, vals[0])
vals[3] = np.power(np.e, vals[3])

In [None]:
# Spot-check one (mass, ratio, frequency) bin: compare the continuous number in that bin
# (`truth`, summed over the third axis) with the summed weights of samples that fall
# strictly inside its bounds (`test`), plus the Poisson CDF of `truth` with mean `test`.
idx_m, idx_q, idx_f = 16, 48, 20
truth = np.sum(mass[idx_m, idx_q, :, idx_f])
print(truth)

bnd_m = edges[0][idx_m:idx_m+2]
bnd_q = edges[1][idx_q:idx_q+2]
bnd_f = edges[3][idx_f:idx_f+2]
df = np.diff(bnd_f)
fdf = np.mean(bnd_f) / df
dlnf = np.diff(np.log(bnd_f))

# open-interval selection in each of the mass, ratio and frequency coordinates
sel = (bnd_m[0] < vals[0]) & (vals[0] < bnd_m[1])
for (lo, hi), coord in ((bnd_q, vals[1]), (bnd_f, vals[3])):
    sel &= (lo < coord) & (coord < hi)

test = weights[sel].sum()
print(f"{test:.8e}, {truth:.8e}", truth/test, sp.stats.poisson.cdf(truth, test))

In [None]:
# Compare the continuous number per (mtot, mrat) bin (summed over the last two axes) with
# the weighted histogram of the discrete samples, and show their fractional difference.
bins = (sam.mtot/MSOL, sam.mrat)

nums = mass.sum(axis=(-1, -2))

# shared log color-scale for the first two panels, with the lower limit floored at 1e-2
extr = [nums[nums > 0].min()/2, 2*nums.max()]
extr[0] = np.max([extr[0], 1e-2])
norm = mpl.colors.LogNorm(*extr)

fig, axes = plt.subplots(figsize=[18, 6], ncols=3)

for ax in axes:
    # x is the total mass (sam.mtot) in Msol on a log axis -- not log10 of the primary mass
    ax.set(xscale='log', xlabel='Total Mass [Msol]', ylabel='Mass Ratio')

# ---- continuous number ----
ax = axes[0]
ax.set(title='Continuous')
pcm = ax.pcolormesh(*bins, nums.T, norm=norm)
plt.colorbar(pcm, ax=ax, orientation='horizontal')

# ---- weighted histogram of all samples; red points mark samples with unit weight ----
ax = axes[1]
ax.set(title='Discrete (sampled)')
hist, *_ = np.histogram2d(vals[0]/MSOL, vals[1], bins=bins, weights=weights)

unit = (weights == 1.0)
ax.scatter(vals[0, unit]/MSOL, vals[1, unit], color='r', alpha=0.2, s=50)

pcm = ax.pcolormesh(*bins, hist.T, norm=norm)
plt.colorbar(pcm, ax=ax, orientation='horizontal')

# ---- fractional error ----
# Bins with no continuous number have an undefined fractional error: mark them NaN
# (drawn blank) instead of letting 0/0 and x/0 produce NaN/inf that break the color limits.
ax = axes[2]
ax.set(title='Error')
diff = np.full_like(hist, np.nan, dtype=float)
has_num = (nums > 0.0)
diff[has_num] = (hist[has_num] - nums[has_num]) / nums[has_num]
finite = np.isfinite(diff)

print("hist = ", utils.stats(hist[np.isfinite(hist)]))
# previously printed `test`, a stale scalar left over from an earlier cell
print("nums = ", utils.stats(nums[np.isfinite(nums)]))
if finite.any():
    print("diff = ", utils.stats(diff[finite]), utils.minmax(diff[finite]))
    # (mass, ratio) index of the largest fractional error
    print(np.unravel_index(np.nanargmax(diff), diff.shape))
    extr = [diff[finite].min(), 2.0]
else:
    print("diff = no bins with a nonzero continuous number")
    extr = [-1.0, 2.0]

smap = plot.smap(extr, midpoint=0.0)
pcm = ax.pcolormesh(*bins, diff.T, cmap=smap.cmap, norm=smap.norm)
plt.colorbar(pcm, ax=ax, orientation='horizontal')

plt.show()

In [None]:
# Scatter, in log10 of each coordinate, the unit-weight samples with total mass below 5e8 Msol
idx = (weights == 1.0) & (vals[0] < 0.5e9*MSOL)
plt.scatter(*np.log10(vals[:, idx]))

In [None]:
# vals, weights, ee, dn = holo.sam.sample_sam_with_hardening(
#     sam, holo.evolution.Hard_GW, fobs=fobs,
#     sample_threshold=10.0, cut_below_mass=1e6*MSOL, limit_merger_time=2.0*GYR
# )

## Calculate GWs

In [None]:
# Non-realized (smooth) GWB and 33 random realizations on the same frequencies
gwb_smooth = sam.gwb(fobs, realize=False)
NUM_REALS_ROUGH = 33
gwb_rough = sam.gwb(fobs, realize=NUM_REALS_ROUGH)

In [None]:
# Work on copies so any filtering here cannot modify the original samples.
# (Optional filter, currently disabled: keep only samples with vals[2] > 0.2, i.e.
#  keep = vals[2] > 0.2; use_vals = use_vals[:, keep]; use_weights = use_weights[keep])
use_vals, use_weights = np.copy(vals), np.copy(weights)

gwf_freqs, gwf, gwb = holo.sam._gws_from_samples(use_vals, use_weights, fobs)

In [None]:
# Sample-based GWB and sources (black / red, rescaled by sqrt(f / df)) compared with the
# SAM's smooth GWB (blue dashed) and the median / IQR of its realizations (blue dotted, band).
fig, ax = plot.figax()

bin_centers_yr = kale.utils.midpoints(fobs) * YR   # [1/sec] ==> [1/yr]

# reference power law: h_c = 1e-15 * f^(-2/3)
ax.plot(bin_centers_yr, 1e-15 * np.power(bin_centers_yr, -2/3), 'k--', alpha=0.25)

# per-bin scale factor sqrt(f / df), with f and df in [1/yr]
ff = np.sqrt(bin_centers_yr / np.diff(fobs*YR))
ax.plot(bin_centers_yr, gwb*ff, 'k-')

loud = (gwf > gwb)
src_freq_yr = gwf_freqs * YR   # [1/sec] ==> [1/yr]
scaled_gwf = gwf * ff
ax.scatter(src_freq_yr[loud], scaled_gwf[loud], color='r', s=20, alpha=0.5)
ax.scatter(src_freq_yr[~loud], scaled_gwf[~loud], edgecolor='r', facecolor='none', s=20, alpha=0.5)

fobs_yr = fobs * YR
ax.plot(fobs_yr, gwb_smooth, 'b--')
ax.plot(fobs_yr, np.median(gwb_rough, axis=-1), 'b:')
ax.fill_between(fobs_yr, *np.percentile(gwb_rough, [25, 75], axis=-1), color='b', alpha=0.25)

plt.show()

In [None]:
# Deliberately stop "Run All" here; the cells below start a new section and are meant to
# be run manually.  This explicit exception replaces a call to the undefined `breaker()`,
# which only halted execution by raising a NameError.
raise RuntimeError("Intentional stop: run the remaining 'realistic hardening' cells manually.")

## Use realistic hardening rate to sample population

In [None]:
# Rebuild a default SAM from its four components (GSMF, GPF, GMT, M-MBulge)
gsmf, gpf, gmt, mmbulge = (
    holo.sam.GSMF_Schechter(),
    holo.sam.GPF_Power_Law(),
    holo.sam.GMT_Power_Law(),
    holo.sam.MMBulge_Simple(),
)

sam = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge)

In [None]:
# Hardening model built from `sam` with a 2 Gyr time parameter (presumably the total
# binary evolution time -- TODO confirm against `Fixed_Time.from_sam`)
FIXED_LIFETIME = 2.0*GYR
hard = holo.evolution.Fixed_Time.from_sam(sam, FIXED_LIFETIME)

In [None]:
# Number and strain of binaries at 20 log-spaced frequencies, with the fixed-time hardening.
# NOTE(review): unlike the earlier cells, these frequencies are not multiplied by 1/YR;
# if `number_from_hardening` expects [1/s] this spans 1e-2..1e2 Hz -- TODO confirm units.
fobs = np.logspace(-2.0, 2.0, num=20)
edges_fobs, num_fobs, strain_fobs = sam.number_from_hardening(hard, fobs=fobs)

In [None]:
# Summary statistics of the number grid, its shape, and the SAM grid shape
num_summary = (utils.stats(num_fobs), num_fobs.shape, sam.shape)
num_summary

### Number vs. Frequency, Plot GWB

In [None]:
# Number and strain of binaries at each observed frequency, using the fixed-time hardening.
# NOTE(review): `fobs` is not scaled by 1/YR here, although the plot below labels the axis
# [1/yr] -- TODO confirm the units `number_from_hardening` expects.
# (GW-only alternative: sam.number_from_hardening(holo.evolution.Hard_GW, fobs=fobs))
fobs = np.logspace(-2.0, 2.0, num=20)
edges_fobs, number_fobs, strain_fobs = sam.number_from_hardening(hard, fobs=fobs)

In [None]:
NREALS = 30
# Seeded generator, so the Poisson realizations (and the plotted band) are reproducible
# on every re-run; the previous global `np.random.poisson` call was unseeded.
rng = np.random.default_rng(20220101)
fig, ax = plot.figax(ylabel='Characteristic Strain', xlabel='Frequency [1/yr]', ylim=[3e-18, 2e-15])

# Draw the directly-calculated (smooth) GWB spectrum: h_c = sqrt(sum of N * h^2), summed over
# the first three axes (the remaining axis presumably being frequency)
number_exp = number_fobs
hc_smooth = np.sqrt((strain_fobs**2 * number_exp).sum(axis=(0, 1, 2)))
ax.plot(fobs, hc_smooth)

# Draw Poisson variations: resample the number in every bin NREALS times
shp = number_exp.shape + (NREALS,)
number_draws = rng.poisson(number_exp[..., np.newaxis], size=shp)
hc_reals = np.sqrt((strain_fobs[..., np.newaxis]**2 * number_draws).sum(axis=(0, 1, 2)))
med, *conf = utils.quantiles(hc_reals, [0.5, 0.25, 0.75], axis=1).T
ax.plot(fobs, med, 'k--')
ax.fill_between(fobs, *conf, color='k', alpha=0.2)

# Draw analytic estimate at a single frequency
# NOTE(review): `xx = 1.0` is passed to `sam.gwb` without a 1/YR factor, unlike the earlier
# cells -- TODO confirm units.
xx = 1.0
tt = sam.gwb(xx, realize=False)
aa, mm, bb = np.percentile(tt, [25, 50, 75])
print(f"mm={mm:.2e}")
ax.plot([xx, xx], [aa, bb], 'r-', lw=2.0, alpha=0.5)
ax.scatter(xx, mm, color='r', alpha=0.5)

plt.show()

### Number vs. Separation - sample

In [None]:
# Number and strain of binaries on 50 log-spaced separations from 1e-6 to 1e4.
# NOTE(review): no unit factor (e.g. PC) is applied to `sepa` -- TODO confirm its units.
sepa = np.logspace(-6.0, 4.0, num=50)
edges, number, strain = sam.number_from_hardening(hard, sepa=sepa)

In [None]:
# Number of binaries vs. separation: scale by the log-separation bin width (taken from the
# first interval of the last edge array, so it assumes uniform log spacing) and sum over
# the first three axes.
fig, ax = plot.figax()
dlna = np.diff(np.log(edges[-1]))[0]
yy = (number * dlna).sum(axis=(0, 1, 2))
ax.plot(sepa, yy)
plt.show()