In [None]:
from pathlib import Path

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import h5py

import kalepy as kale

import holodeck as holo
from holodeck.constants import MSOL, PC, YR, GYR, SPLC, EDDT

# 15yr Population Posteriors

Load chains from 15yr Binary Astrophysics analysis to get population parameter posteriors

In [None]:
# Locate the directory holding the chain files, and make sure that it exists
path_data = Path("./data/astroprior_hdall").resolve()
print(path_data)
assert path_data.is_dir()

# The parameter-name file and the chain file both live inside that directory
fname_pars = path_data / "pars.txt"
fname_chains = path_data / "chain_1.0.txt"
for fname in (fname_pars, fname_chains):
    print(fname)
# Fail early if either of the input files is missing
assert fname_chains.is_file() and fname_pars.is_file()

In [None]:
# Load the parameter names and the chain samples.
# `atleast_1d` / `ndmin=2` keep the shapes consistent even when there is only a single
# parameter (otherwise `loadtxt` returns a 0-d array, which has no `len`) or a single sample.
chain_pars = np.atleast_1d(np.loadtxt(fname_pars, dtype=str))
chains = np.loadtxt(fname_chains, ndmin=2)
npars = len(chain_pars)
# The first `npars` columns of the chain file are used as the parameters, in the order of `chain_pars`
assert chains.shape[1] >= npars, (
    f"Chain file has {chains.shape[1]} columns, but {npars} parameters are listed!"
)

# Get maximum likelihood parameters (estimate using KDE)
# NOTE(review): what is calculated is the peak of each parameter's 1D (marginalized) KDE, which is not
#     necessarily the same as the joint maximum-likelihood point -- confirm that this is what's wanted
mlpars = {}
fig, axes = plt.subplots(figsize=[10, 1.5*npars], nrows=npars)
# `plt.subplots` returns a bare Axes (instead of an array) when there is only one row
axes = np.atleast_1d(axes)
plt.subplots_adjust(hspace=0.75)
for ii, ax in enumerate(axes):
    name = chain_pars[ii]
    ax.set(xlabel=name)
    vals = chains[:, ii]
    # calculate the KDE, using the extrema of the samples as reflecting boundaries
    extr = holo.utils.minmax(vals)
    xx, yy = kale.density(vals, reflect=extr)
    # plot the distribution of samples for this parameter
    kale.dist1d(vals, ax=ax, density=True, carpet=1000)
    # the location of the KDE peak is used as the parameter estimate
    xmax = xx[np.argmax(yy)]
    ax.axvline(xmax, color='firebrick')
    mlpars[name] = xmax

plt.show()


In [None]:
# Summarize the value recovered for each population parameter
print("Maximum Likelihood binary population parameters:")
for name, value in mlpars.items():
    print(f"\t{name:>20s}: {value:+.2e}")

## Generate Population with ML Parameters

Construct model

In [None]:
# Choose the appropriate Parameter Space (from 15yr astro analysis)
# NOTE(review): the parameter names in `mlpars` (read from `pars.txt`) presumably need to match the
#     names that this parameter space expects -- confirm against the holodeck version in use
pspace = holo.param_spaces.PS_Uniform_09B
# Load SAM and hardening model for desired parameters
# The result is unpacked into two objects: the semi-analytic model `sam`, and the hardening model `hard`
sam, hard = pspace.model_for_params(mlpars)

In [None]:
# Sanity check: a couple of model attributes should reproduce the input parameter values
def _check_param(model_value, key):
    """Print the model's value next to `mlpars[key]`, and assert that the two agree."""
    print(model_value, mlpars[key])
    assert np.isclose(model_value, mlpars[key])

# the hardening time stored on the model, divided by `GYR` for comparison with the parameter
_check_param(hard._target_time/GYR, 'hard_time')
# the `phi0` value stored on the model's GSMF
_check_param(sam._gsmf._phi0, 'gsmf_phi0')

Calculate the number of binaries in a target frequency (period) range.
This takes about 1 minute to run.

In [None]:
# Range of orbital periods of interest, and the number of frequency bins to split it into
tvals = [10.0, 0.1]   # yr
NFBINS = 10
print(f"Considering orbital periods between {tvals} yrs, {NFBINS} bins")

# periods [yr] ==> frequencies [1/yr] ==> log10 of the frequencies in [1/sec]
log_freq_lo, log_freq_hi = np.log10((1 / np.array(tvals)) / YR)
# log-spaced bin edges, and the corresponding bin centers
fobs_orb_edges = np.logspace(log_freq_lo, log_freq_hi, NFBINS+1)  # 1/sec
fobs_orb_cents = holo.utils.midpoints(fobs_orb_edges)
fobs = fobs_orb_cents

# (differential) number of binaries, evaluated at the bin-center frequencies
redz_final, diff_num = holo.sams.cyutils.dynamic_binary_number_at_fobs(
    fobs_orb_cents, sam, hard, holo.cosmo
)
# integrate over the grid edges to find the total number of binaries in each bin
edges = [sam.mtot, sam.mrat, sam.redz, fobs_orb_edges]
number = holo.sams.cyutils.integrate_differential_number_3dx1d(edges, diff_num)
print(f"Loaded {number.sum():.1e} binaries across frequency range")

In [None]:
# Sum over the first three grid dimensions, leaving the number of binaries in each frequency bin
num_per_fbin = number.sum(axis=(0, 1, 2))

fig, ax = plt.subplots()

# bin-center frequencies, converted from [1/sec] to [1/yr]
xx = fobs_orb_cents*YR
# number per unit frequency: the number in each bin divided by the bin width in [1/yr]
yy = num_per_fbin/np.diff(fobs_orb_edges*YR)
ax.plot(xx, yy)

ax.set(xscale='log', yscale='log', ylabel='Number of Binaries ($dN/df$)', xlabel='Orbital Frequency [1/yr]')
# NOTE(review): `_twin_hz` is a private holodeck helper; presumably it adds a twin x-axis in Hz-based
#     units (it is labelled in nHz here) -- confirm
tw = holo.plot._twin_hz(ax)
tw.set_xlabel('Orbital Frequency [nHz]')
plt.show()

## Variability Models and Observability Cuts

Choose which bins of SAM population are 'observable'

In [None]:
# ---- Assumed emission properties
fedd = 0.1   # eddington fraction
bcorr = 0.1  # bolometric correction, bolometric ==> optical

# ---- Survey properties
# LSST V-band sensitivity [erg/s/cm^2/Hz]
#    see: https://ui.adsabs.harvard.edu/abs/2019MNRAS.485.1579K/abstract
flux_sens_lsst = 3.0e-30
vband_wlen = 551.0e-7   # [cm]
# V-band frequency, from the wavelength
vband_freq = SPLC/(vband_wlen)   # [Hz]

# ---- Population grid: bin-center values in total-mass, mass-ratio, and redshift
mtot, mrat, redz = [holo.utils.midpoints(edge) for edge in (sam.mtot, sam.mrat, sam.redz)]
# luminosity-distance corresponding to each redshift
dlum = holo.cosmo.z_to_dlum(redz)

# ---- Fluxes
# luminosity of binaries based on the Eddington fraction and bolometric correction
lum = EDDT * mtot * fedd * bcorr    # [erg/s]
# flux at the observer, over a grid of (total-mass, redshift)
# TODO: should really divide by the width of the V-band
flux_tot = lum[:, None] / (4*np.pi*dlum[None, :]**2) / vband_freq
# flux of the secondary (assumed to be what's needed): scale the total by q/(1+q),
# giving a grid of (total-mass, mass-ratio, redshift)
sec_frac = mrat / (1.0 + mrat)
flux_sec = flux_tot[:, None, :] * sec_frac[None, :, None]

# ---- "Observable" systems are those with secondary flux above the LSST sensitivity
obs_flag = (flux_sec > flux_sens_lsst)
# a trailing axis is added to the flag so that it broadcasts against the last axis of `number`
num_obs = (obs_flag[..., None]*number).sum()
num_all = number.sum()
frac_obs = num_obs / num_all
print(f"observable: {num_obs:.2e}/{num_all:.2e} = {frac_obs:.2e}")

# Detectability (Test) Data from Caitlin 

In [None]:
# Locate the detectability test data (`resolve()` already returns an absolute path)
fname_data = Path("./data/export_for_gwb_test.txt").resolve()
print(fname_data, fname_data.exists())
# fail early with a clear message, instead of somewhere inside `open` / `np.loadtxt`
assert fname_data.is_file(), f"Detectability data file '{fname_data}' does not exist!"

# Print the first few lines of the file, and parse the column names from the leading comment line
NUM_PREVIEW_LINES = 5
det_header = None
with open(fname_data, 'r') as fin:
    # iterate over the file lazily: only the first few lines are needed, not the whole file
    for ii, line in enumerate(fin):
        if ii >= NUM_PREVIEW_LINES:
            break

        line = line.strip()
        print(line)
        if det_header is None:
            if not line.startswith('#'):
                raise ValueError(
                    "First line of file should have started with a comment including header "
                    "information about the columns!"
                )
            # column names are split on double-spaces; entries that are empty
            # (from wider spacing between names) are dropped so names stay aligned with columns
            det_header = [head.strip() for head in line.strip(' #').split("  ")]
            det_header = [head for head in det_header if len(head) > 0]
            print(len(det_header), det_header)

# an empty file never enters the loop above
if det_header is None:
    raise ValueError(f"File '{fname_data}' is empty, no header could be read!")

det_data = np.loadtxt(fname_data)
print(f"{det_data.shape=}")

# Boolean flags for each row of data
# NOTE(review): the column numbers 11 and 12 are hard-coded; confirm against the header printed above
#     that these are the 'injected' and 'detected' columns
injected = (det_data[:, 11] > 0)
print("injected: ", holo.utils.frac_str(injected))
detected = (det_data[:, 12] > 0)
print("detected: ", holo.utils.frac_str(detected))
print(" both   : ", holo.utils.frac_str(detected & injected))

## Look at the data

In [None]:
# Each column index `idx` listed here is plotted (x-axis) against the following column `idx+1` (y-axis)
indices = [0, 2, 4, 6, 8]
num = len(indices)

# Panels from left to right: all of the points on their own, and then all of the points with a
# particular subset highlighted on top
panels = [
    ('all', None),
    ('injected', injected),
    ('detected', detected),
    ('both', detected & injected),
]

fig, axes = plt.subplots(figsize=[20, num*5], ncols=4, nrows=num, sharex='row')

for ii, (idx, axrow) in enumerate(zip(indices, axes)):
    xx = det_data[:, idx]
    yy = det_data[:, idx+1]

    for ax, (title, subset) in zip(axrow, panels):
        ax.set(xlabel=det_header[idx], ylabel=det_header[idx+1])
        # only the top row of panels is given titles
        if ii == 0:
            ax.set(title=title)

        if subset is None:
            ax.scatter(xx, yy, alpha=0.2, s=5)
        else:
            # all points in the background, the subset marked with crosses
            ax.scatter(xx, yy, alpha=0.2, s=14)
            ax.scatter(xx[subset], yy[subset], alpha=0.75, marker='x', s=8, lw=0.5)

plt.show()

In [None]:
# Each row shows histograms of column `idx` (left panel) and column `idx+1` (right panel)
indices = [0, 2, 4, 6, 8]
num = len(indices)

# Subsets of the data that are histogrammed in every panel
subsets = [
    ('all', slice(None)),
    ('injected', injected),
    ('detected', detected),
    ('both', injected & detected),
]

fig, axes = plt.subplots(figsize=[20, num*5], ncols=2, nrows=num, sharex='row')
kwargs = dict(alpha=0.5, lw=2.0, density=True, histtype='step')

for idx, axrow in zip(indices, axes):
    # Each row starts from 20 automatically placed bins.  `ax.hist` returns the bin edges that it
    # used, and those are passed on to every later histogram, so that the subsets share the binning.
    # NOTE(review): the edges also carry over from the left panel (column `idx`) to the right panel
    #     (column `idx+1`) of the same row -- confirm that this is intended
    bins = 20

    for jj, ax in enumerate(axrow):
        # NOTE(review): with `density=True` the y-axis is a normalized density, not a raw 'Number'
        ax.set(yscale='log', xlabel=det_header[idx+jj], ylabel='Number')
        vals = det_data[:, idx+jj]
        for label, subset in subsets:
            hist, bins, patches = ax.hist(vals[subset], bins=bins, label=label, **kwargs)

        ax.legend()

plt.show()

## Calculate a 'detectability' metric

In [None]:
# Number of bin *edges* along the amplitude and period axes of the detectability grid.
# These are passed to `np.linspace` as the number of points when constructing `amp_edges` and
# `per_edges`, so there are `NAMPS-1` amplitude bins and `NPERS-1` period bins.
NAMPS = 7
NPERS = 9

def get_idx(key, header):
    """Find the position of the first `header` entry that contains `key` (case-insensitive).

    Parameters
    ----------
    key : str
        Substring to search for.
    header : sequence of str
        Names to search through, in order.

    Returns
    -------
    int or None
        Index of the first entry containing `key`, or `None` if there is no such entry.

    """
    matches = (pos for pos, name in enumerate(header) if key.lower() in name.lower())
    return next(matches, None)

# Find the columns holding the injected amplitudes and periods
amp_idx = get_idx('amp_in', det_header)
period_idx = get_idx('period_in', det_header)
# `get_idx` returns `None` when nothing matches, and indexing an array with `None` silently adds a
# new axis instead of failing, so check for that explicitly
if (amp_idx is None) or (period_idx is None):
    raise ValueError(f"Could not find 'amp_in' and/or 'period_in' columns in header: {det_header}")

# grab amplitude and period data
det_amps = det_data[:, amp_idx]
# convert periods from [day] to [yr]
det_pers = det_data[:, period_idx]*24*60*60/YR

# Choose a 2D grid of bin-edges based on the detected amplitudes and periods
sel_flag = injected & detected
amp_edges = np.linspace(*holo.utils.minmax(det_amps[sel_flag]), NAMPS)
per_edges = np.linspace(*holo.utils.minmax(det_pers[sel_flag]), NPERS)
print(f"{amp_edges=}")
print(f"{per_edges=}")
bins = [amp_edges, per_edges]

# find the number of points in each bin (summing a weight of one for every point)
# NOTE: `num_all` is a 2D array here; it replaces the scalar `num_all` from the observability cell
num_all, *_ = sp.stats.binned_statistic_2d(
    det_amps, det_pers, np.ones_like(det_amps), statistic='sum', bins=bins
)
# find the number of injected & detected points in each bin (weight of one only for selected points)
num_det, *_ = sp.stats.binned_statistic_2d(
    det_amps, det_pers, sel_flag*np.ones_like(det_amps), statistic='sum', bins=bins
)
# The detection fraction is the number of injected & detected points divided by all points
# TODO: should denominator just be the number of injected points???  How to handle false-positives???
# Bins without any points have an undefined detection fraction: these are explicitly set to NaN
# (the same values as `0/0` would give, but without the runtime warning)
det_frac = np.full_like(num_det, np.nan, dtype=float)
np.divide(num_det, num_all, out=det_frac, where=(num_all > 0))

fig, ax = plt.subplots(figsize=[10, 7])
ax.set(xlabel='amplitude [frac]', ylabel='period [yr]')
# `det_frac` is indexed as (amplitude, period); transpose so that amplitude runs along the x-axis
pcm = ax.pcolormesh(*bins, det_frac.T, shading='auto')
plt.colorbar(pcm, ax=ax, label='detection fraction')

plt.show()


# Calculate LSST Detections

We have a number of binaries `number` over a grid of total-mass, mass-ratio, redshift, and orbital frequency.  The shape is (M, Q, Z, F).  The bin-center values for each dimension are given in:
 * `mtot` (M,) [gram]
 * `mrat` (Q,) [-]
 * `redz` (Z,) [-]
 * `fobs` (F,) [1/sec]

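There is a boolean grid of which bins are observable given in `obs_flag`.  It has shape (M, Q, Z), i.e. it spans the first three dimensions of `number` but not the frequency dimension, because the observability cut used here depends only on total-mass, mass-ratio, and redshift.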

We want to determine what fraction of binaries in each bin are detectable in LSST variability surveys.

We have a grid of detectability fractions for periodic variable AGN in `det_frac` with shape (A, P).  This is over a grid of variability amplitudes given by the array `amp_edges` shaped (A+1,), and orbital periods `per_edges` shaped (P+1,).  We need to map the simulated binaries to this parameter space.

In [None]:
# Every grid-point of `number` is assigned a variability amplitude and a variability period.
# An array of ones is used to broadcast the 1D bin-center arrays up to the full shape of `number`.
ones = np.ones_like(number)

# amplitude: assume that it exactly equals the mass-ratio (second axis of the grid)
bin_amp = mrat[None, :, None, None] * ones

# period: assume that it is exactly the orbital period (last axis of the grid), converted to [yr]
_per = (1/fobs/YR)
bin_per = _per[None, None, None, :] * ones

# Select out the 'observable' binaries, and convert to 1D arrays.
# `obs_flag` only spans the first three axes of the grid, so each selection keeps a trailing
# frequency axis, which is then flattened away.
bin_amp, bin_per, bin_num = [arr[obs_flag].ravel() for arr in (bin_amp, bin_per, number)]

The binaries are now expressed in the parameter space of the variability-detectability grid (amplitude and period).  All that remains is to look up the detectability fraction (`det_frac`) for each binary grid-point.

In [None]:
# Find the detectability-grid bin of each binary grid-point.
# `np.digitize` returns `i` such that `edges[i-1] <= x < edges[i]`, so after subtracting one:
#    -1 means the value is below the lowest edge,
#     B (for B bins) means the value is at or above the highest edge.
amp_idx = np.digitize(bin_amp, amp_edges) - 1
per_idx = np.digitize(bin_per, per_edges) - 1

# ---- amplitudes
num_amp_bins = amp_edges.size - 1
# put amplitudes above the highest bin into the highest bin
amp_idx[amp_idx >= num_amp_bins] = num_amp_bins - 1
# set amplitudes below smallest bin to be invalid, i.e. select only values above the lowest bin
sel_amp = (amp_idx >= 0)

# ---- periods
num_per_bins = per_edges.size - 1
# put periods below the lowest bin, into the lowest bin
per_idx[per_idx < 0] = 0
print(per_edges.size, holo.utils.stats(per_idx))
# set periods above highest bin to be invalid, i.e. select only values below the highest bin
sel_per = (per_idx < num_per_bins)

# select valid entries
sel = sel_amp & sel_per
amp_idx = amp_idx[sel]
per_idx = per_idx[sel]
# grab the corresponding numbers of binaries in each of these 'selected' bins
sel_bin_num = bin_num[sel]
print(f"{holo.utils.stats(sel_bin_num)=}")

# look up the detection fraction of each selected grid-point, using the pair of (amplitude, period)
# indices to index into the 2D `det_frac` grid
sel_dfracs = det_frac[(amp_idx, per_idx)]
# Detectability bins that contained no test data have an undefined (NaN) detection fraction, and a
# single NaN would turn the summed number of detections into NaN.  Conservatively treat binaries
# in those bins as undetectable.
undefined = np.isnan(sel_dfracs)
if np.any(undefined):
    print(
        f"WARNING: {np.count_nonzero(undefined)} grid-points are in detectability bins without data, "
        "treating them as undetectable!"
    )
    sel_dfracs = np.where(undefined, 0.0, sel_dfracs)
print(f"{holo.utils.stats(sel_dfracs)=}")

# find the total number of detectable binaries
# multiply the number of binaries in each bin, by the detection fraction in that bin
num_det_bins = sel_bin_num * sel_dfracs
print(f"{num_det_bins.sum()=:.2e}")
# the denominator is *all* binaries, not only the observable or selected ones
num_all_bins = number.sum()
frac_det_bins = num_det_bins.sum() / num_all_bins
print(f"Total detection fraction: {frac_det_bins:.2e}")

# remind us the fraction of binaries that were 'observable'
print(f"Total 'observability' fraction: {frac_obs:.2e}")

# fraction of the observable binaries that are also detectable
frac_obs_det = num_det_bins.sum() / num_obs
print(f"Det frac of observable: {frac_obs_det:.2e}")