In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py

import kalepy as kale

import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Silence numpy floating-point warnings (divide-by-zero, invalid value, overflow) and
# all UserWarnings for the rest of the session
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Seed numpy's legacy global RNG so `np.random.*` draws later in the notebook
# (e.g. `np.random.poisson`) are reproducible under Restart & Run All.
# NOTE(review): kalepy's samplers may draw from their own RNG — TODO confirm.
SEED = 1234
np.random.seed(SEED)

# Plotting settings
# 'Times' is a serif face, so it must go in the *serif* list: with family='serif' the
# 'sans-serif' list is never consulted. Matplotlib's default serif fonts stay as fallbacks
# (filtered so re-running this cell does not duplicate 'Times').
_serif_fonts = ['Times'] + [ff for ff in mpl.rcParams['font.serif'] if ff != 'Times']
mpl.rc('font', family='serif', serif=_serif_fonts, size=15)
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

In [None]:
# Quick check: characteristic strain of the default semi-analytic model at f = 1/yr.
import holodeck.sam
reload(holodeck.sam)

sam = holodeck.sam.BP_Semi_Analytic()
freqs = 1 / YR
gwb = sam.gwb_sa(freqs)

# Sum the GWB contributions over axes 1-3 (presumably the population-grid axes),
# then take the square root to get h_c
hc_sq = np.sum(gwb, axis=(1, 2, 3))
hc = np.sqrt(hc_sq)
(hc.shape, hc)

# Discretize Population

In [None]:
# Rebuild the model (picking up any source edits) and keep its grid edges and the
# output of `dnbh()` (presumably the differential binary number density).
import holodeck.sam
reload(holodeck.sam)

sam = holodeck.sam.BP_Semi_Analytic()
edges, nbh = sam.edges, sam.dnbh()

In [None]:
# Inspect the model's chirp-mass values.
# (`self` does not exist at notebook scope; the semi-analytic model instance is `sam`.)
sam.mchirp

In [None]:
# Number of binaries evaluated at each observed frequency, plus one Poisson realization
# of the total count.
freqs = utils.nyquist_freqs(20.0, 0.1, trim=[None, 5.0])
num_mbhb_fobs, hs = sam.num_mbhb(freqs / YR)
edges_freqs = sam.edges + [freqs]

num = num_mbhb_fobs.sum()
print("num={:.4e}".format(num))
num = np.random.poisson(num)
print("\t{:.4e}".format(num))

In [None]:
# Binary counts at the first frequency in (m1, q), summed over the remaining axis
# (presumably redshift).
first_freq = num_mbhb_fobs[..., 0]
vals = np.sum(first_freq, axis=-1)
plt.pcolormesh(sam.mbh1, sam.mrat, np.transpose(vals))
plt.show()

In [None]:
# `np.less()` with no arguments always raises TypeError; show its documentation instead
# (np.less is used as a threshold operator in the down-weighting cell).
help(np.less)

In [None]:
# Build a down-weighted copy (`portion`) of the binned population: along each axis, the bins
# selected by each threshold below are divided by DOWN, once per threshold they satisfy.
DOWN = 10.0
# Values at bin midpoints along every axis (one fewer element per axis than num_mbhb_fobs)
num_mbhb = kale.utils.midpoints(num_mbhb_fobs, axis=None)
portion = np.copy(num_mbhb)
# portion = kale.utils.midpoints(num_mbhb_fobs, axis=None)
# portion = np.copy(num_mbhb_fobs)

print(f"{portion.sum()=:.4e}")
# edges = edges_freqs
# NOTE: this rebinds the global `edges` (previously `sam.edges`) to a 4-element list
edges = [sam.mbh1, sam.mrat, sam.redz, freqs]
# Per-axis comparison selecting bins to down-weight (None: the axis is never thresholded)
operators = [np.less, np.less, np.greater, None]
# Per-axis thresholds. Each satisfied threshold applies another factor of 1/DOWN, so e.g.
# mbh1 midpoints below 6.0 end up divided by DOWN**3.
# NOTE(review): mbh1 thresholds look like log10(M/Msol) values — TODO confirm
values = [
    [8.0, 7.0, 6.0],
    [0.1, 0.05],
    [2.0, 4.0, 6.0],
    []
]

for ii, (ee, op, val) in enumerate(zip(edges, operators, values)):
    # Bring axis `ii` to the front so a 1D boolean mask over that axis can index it
    portion = np.moveaxis(portion, ii, 0)
    mm = kale.utils.midpoints(ee)
    # mm = ee
    for vv in val:
        idx = op(mm, vv)
        # moveaxis returns a view, so this updates the copied data in place
        portion[idx] = portion[idx] / DOWN

    # Restore the original axis order
    portion = np.moveaxis(portion, 0, ii)

print(f"{portion.sum()=:.4e}")
# Effective down-sampling factor per bin; NaN/inf from empty bins are cleaned by nan_to_num
down = num_mbhb/portion
down = np.nan_to_num(down)
print(utils.stats(down))

In [None]:
# Draw a proportional sample from the down-weighted grid. The number of draws is the
# down-weighted total, scaled by one Poisson fluctuation of the full population size.
n_expected = num_mbhb_fobs.sum()
print("nsamp={:.8e}".format(n_expected))

poisson_ratio = np.random.poisson(n_expected) / n_expected
print("1-nsamp={:.8e}".format(1 - poisson_ratio))

nsamp = int(portion.sum() * poisson_ratio)
print("nsamp={:.8e}".format(nsamp))

# `weights` is returned alongside the samples (presumably to undo the down-weighting)
sample, weights = kale.sample_grid_proportional(edges, num_mbhb, portion, nsamp)

In [None]:
# Samples along the first dimension
first_dim = sample[0]
first_dim

In [None]:
# Side-by-side (m1, q) maps at the first frequency bin, summed over the remaining axis:
#   1) binned population, 2) down-weighted `portion`,
#   3) raw sample counts, 4) sample counts reweighted by `weights`.
bins = [sam.mbh1, sam.mrat]
kw = dict()

fig, axes = plt.subplots(figsize=[20, 4], ncols=4)

# Samples that fall in the first frequency bin
idx = (sample[-1] < edges[-1][1])
sx, sy = sample[0][idx], sample[1][idx]

panels = [
    num_mbhb[..., 0].sum(axis=-1),
    portion[..., 0].sum(axis=-1),
    np.histogram2d(sx, sy, bins=bins)[0],
    np.histogram2d(sx, sy, bins=bins, weights=weights[idx])[0],
]
for ax, panel in zip(axes, panels):
    ax.pcolormesh(*bins, panel.T, **kw)

plt.show()

In [None]:
# Relative difference between the binned population and the weighted sample histogram
# in (m1, q) at the first frequency bin (summed over the remaining axis), shown as
# log10 |N_bin / N_sample - 1|.
bins = [sam.mbh1, sam.mrat]

fig, ax = plt.subplots(figsize=[20, 10])

vals = num_mbhb[..., 0].sum(axis=-1)

# Samples falling in the first frequency bin
in_first_fbin = (sample[-1] < edges[-1][1])
hist, *_ = np.histogram2d(
    sample[0][in_first_fbin], sample[1][in_first_fbin],
    bins=bins, weights=weights[in_first_fbin],
)

# Bins without samples stay NaN, so they render blank instead of as a spurious ratio
data = np.full(np.shape(vals), np.nan)
has_samples = (hist > 0)
data[has_samples] = (vals[has_samples] / hist[has_samples]) - 1.0
print(utils.stats(data))

data = np.log10(np.fabs(data))

pcm = ax.pcolormesh(*bins, data.T)
plt.colorbar(pcm, label='log10 |N_bin / N_sample - 1|')

plt.show()

In [None]:
# Down-sampling factor applied along m1 for one fixed (q, z, f) bin index.
# `portion` is a down-weighted copy of `num_mbhb` (bin-midpoint values), so it must be
# compared against `num_mbhb`. `num_mbhb_fobs` holds grid-point values with one more element
# per axis, which would misalign the comparison (and fails to broadcast along m1).
cut = (slice(None), 20, 10, 10)
portion[cut] / num_mbhb[cut]

In [None]:
# Recompute the binned binary population from the model; this rebinds `num_mbhb`
# (previously the midpoint array used for `portion`).
# NOTE(review): an earlier cell calls `sam.num_mbhb(freqs/YR)` and unpacks two return values;
# this no-argument call appears to target a different API version — TODO confirm it still works.
num_mbhb = sam.num_mbhb()

In [None]:
# Spread each binned binary count across observed-frequency bins, in proportion to the
# fraction of its lifetime spent between adjacent frequencies.
freqs = utils.nyquist_freqs(20.0, 0.1, trim=[None, 5.0])
print(freqs.size, freqs)
# Rest-frame frequencies for each redshift, shape (redz, freq)
frest = freqs[np.newaxis, :] * (1.0 + sam.redz[:, np.newaxis])
print(frest.shape)

m1 = sam.mbh1[:, np.newaxis, np.newaxis]   # (m1, redz, freq)
m1 = (10.0 ** m1) * MSOL
# Separation at each rest-frame frequency, computed from the primary mass alone
sepa = utils.kepler_sep_from_freq(m1, frest[np.newaxis, ...]/YR)
print(f"{sepa.shape=}, {utils.minmax(sepa/PC)=}")
m1 = m1[:, np.newaxis, :, :]   # (m1, z, f) ==> (m1, q, redz, freq)
m2 = sam.mbh2[:, :, np.newaxis, np.newaxis]  # (m1, q) ==> (m1, q, redz, freq)
m2 = (10.0 ** m2) * MSOL
time = utils.time_to_merge_at_sep(m1, m2, sepa[:, np.newaxis, :, :])
# Negative times-to-merge are treated as invalid
time[time < 0.0] = np.nan
print(f"{time.shape=}, {utils.minmax(time/GYR)=}")

# Observed-frequency bin centers (one fewer than `freqs`)
fobs = 0.5 * (freqs[1:] + freqs[:-1])
# Time spent between adjacent frequencies; negated so it is positive when the
# time-to-merge decreases with frequency. Invalid (NaN) entries become 0.
dt = -np.diff(time, axis=-1)
dt = np.nan_to_num(dt)
print(f"{dt.shape=}, {utils.minmax(dt/GYR)=}")
# Time to merge from the lowest frequency. This basic slice is a *view* of `time`, so the
# total below is built out-of-place to avoid silently overwriting `time[..., 0]`.
tot_time = time[..., 0, np.newaxis]
print(f"{tot_time.shape=}, {utils.minmax(tot_time/GYR)=}")
mrat = sam.mrat[np.newaxis, :, np.newaxis, np.newaxis]
zz = sam.redz[np.newaxis, np.newaxis, :, np.newaxis]
mtime = sam.merger_time(m1, mrat, zz) * GYR
print(f"{mtime.shape=}, {utils.minmax(mtime/GYR)=}")
tot_time = tot_time + mtime
# Fraction of the total lifetime spent in each frequency bin
tfrac = dt / tot_time
print(f"{tfrac.shape=}, {utils.minmax(tfrac)=}")
temp = utils.stats(tfrac.sum(axis=-1))
print(temp)

num_mbhb_fobs = num_mbhb[..., np.newaxis] * tfrac
# Use the model's own grid edges: the global `edges` is rebound by an earlier cell to a
# 4-element list that already includes `freqs`, which would give one edge array too many.
edges_fobs = list(sam.edges) + [fobs]
edges_fobs[0] = np.log10(edges_fobs[0])
edges_fobs[0] = np.log10(edges_fobs[0])

In [None]:
# Shape of the model's chirp-mass array
np.shape(sam.mchirp)

In [None]:
# Expected total number of binaries, then one Poisson realization of it
num = num_mbhb_fobs.sum()
print("num={:.4e}".format(num))
num = np.random.poisson(num)
print("\t{:.4e}".format(num))

In [None]:
# Sample the per-frequency population on its grid; the third argument (presumably the
# number of draws) is the Poisson total reduced by SUBSAMPLE_FACTOR to keep it tractable.
# To convert the log10 mass samples back: sample[0, :] = np.power(10.0, sample[0, :])
SUBSAMPLE_FACTOR = 1000
sample = kale.sample_grid(edges_fobs, num_mbhb_fobs, num / SUBSAMPLE_FACTOR)

In [None]:
# Range and mean of the samples along each dimension
for dim_samples in sample:
    lo_hi = utils.minmax(dim_samples)
    print(lo_hi, np.mean(dim_samples))

In [None]:
# Corner plot of the drawn samples across all dimensions
corner = kale.Corner(sample)
corner.plot_data()
plt.show()

In [None]:
# Shape of the model's secondary-mass grid
np.shape(sam.mbh2)

In [None]:
import zcode.plot as zplot

# Examine convergence properties: how $h_c$ at $f = 1/\mathrm{yr}$ depends on grid range and resolution

In [None]:
# Reference settings for the convergence tests below: each test varies one grid while the
# other two stay at these values.
FREQ = 1.0 / YR               # evaluate the GWB at f = 1/yr
MSTAR = [8.5, 13.0, 46]       # primary stellar-mass grid arguments (presumably [start, stop, num])
MRAT = [0.02, 1.0, 50]        # mass-ratio grid arguments
REDZ = [0.0, 6.0, 61]         # redshift grid arguments
# np.logspace(*MSTAR)

### Primary stellar-mass grid

In [None]:
# Convergence with respect to the primary stellar-mass grid (arguments presumably
# [start, stop, num], as for MSTAR): first extend the upper limit, then add grid points.
# NOTE: entries 1 and 4 are the same configuration.
mstar_args = [
    [8.5, 13, 41],
    [8.5, 13.5, 41],
    [8.5, 14.0, 41],
    [8.5, 13, 41],
    [8.5, 13, 61],
    [8.5, 13, 81],
]

fig, ax = zplot.figax(scale='linear')

# Keep the computed values so they can be inspected after plotting
strain = []
labels = []
for ii, mstar in enumerate(mstar_args):
    # NOTE: rebinds the global `sam`
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=mstar, mrat=MRAT, redz=REDZ)
    gwb = sam.gwb_sa(FREQ)
    hc = np.sqrt(np.sum(gwb))
    strain.append(hc)
    labels.append(str(mstar))
    ax.plot(ii+1, hc, marker='o', label=labels[-1])

ax.set(xlabel='configuration index', ylabel=r'$h_c(f = 1/\mathrm{yr})$')
plt.legend()
plt.show()

### Mass-ratio grid

In [None]:
# Convergence with respect to the mass-ratio grid: first increase the number of grid points
# at a fixed lower limit, then lower the minimum mass ratio at fixed resolution.
args = [
    [0.02, 1.0, 40],
    [0.02, 1.0, 50],
    [0.02, 1.0, 60],
    [0.02, 1.0, 70],
    [0.02, 1.0, 80],
    [0.02, 1.0, 160],
    [0.02, 1.0, 320],
    [0.02, 1.0, 640],
    [0.02, 1.0, 160],
    [0.002, 1.0, 160],
    [0.0002, 1.0, 160],
]

fig, ax = zplot.figax(scale='linear')

# Keep the computed values so they can be inspected after plotting
strain = []
labels = []
for ii, arg in enumerate(args):
    # NOTE: rebinds the global `sam`
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=MSTAR, mrat=arg, redz=REDZ)
    gwb = sam.gwb_sa(FREQ)
    hc = np.sqrt(np.sum(gwb))
    strain.append(hc)
    labels.append(str(arg))
    ax.plot(ii+1, hc, marker='o', label=labels[-1])

ax.set(xlabel='configuration index', ylabel=r'$h_c(f = 1/\mathrm{yr})$')
plt.legend()
plt.show()

### Redshift grid

In [None]:
# Convergence with respect to the maximum redshift (zmax = 1 ... 10), using about DZ grid
# points per unit of redshift.
# NOTE(review): if the grid arguments are linspace-like [start, stop, num] (as
# REDZ = [0.0, 6.0, 61] suggests), num = zmax/DZ gives spacing zmax/(num-1), slightly coarser
# than DZ; use `+ 1` for an exact DZ spacing — TODO confirm.
DZ = 0.05

fig, ax = zplot.figax(scale='linear')

# Keep the computed values so they can be inspected after plotting
strain = []
labels = []
for ii in range(10):
    zmax = 1.0 + ii
    # round() guards the point count against float truncation in zmax / DZ
    arg = [0.0, zmax, int(round(zmax / DZ))]
    # NOTE: rebinds the global `sam`
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=MSTAR, mrat=MRAT, redz=arg)
    gwb = sam.gwb_sa(FREQ)
    hc = np.sqrt(np.sum(gwb))
    strain.append(hc)
    labels.append(str(arg))
    ax.plot(ii+1, hc, marker='o', label=labels[-1])

ax.set(xlabel='configuration index (zmax = index)', ylabel=r'$h_c(f = 1/\mathrm{yr})$')
plt.legend()
plt.show()

In [None]:
# Convergence with respect to the number of redshift grid points over a fixed range [0, 6].
args = [
    [0.0, 6.0, 40],
    [0.0, 6.0, 80],
    [0.0, 6.0, 100],
    [0.0, 6.0, 150],
    [0.0, 6.0, 200],
    [0.0, 6.0, 250],
    [0.0, 6.0, 300],
]

fig, ax = zplot.figax(scale='linear')

# Keep the computed values so they can be inspected after plotting
strain = []
labels = []
for ii, arg in enumerate(args):
    # NOTE: rebinds the global `sam`
    sam = holodeck.sam.BP_Semi_Analytic(mstar_pri=MSTAR, mrat=MRAT, redz=arg)

    gwb = sam.gwb_sa(FREQ)   # [:, :, :, 1:]
    hc = np.sqrt(np.sum(gwb))
    strain.append(hc)
    labels.append(str(arg))

    ax.plot(ii+1, hc, marker='o', label=labels[-1])

ax.set(xlabel='configuration index', ylabel=r'$h_c(f = 1/\mathrm{yr})$')
plt.legend()
plt.show()