In [None]:
# %load ../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence noisy numpy floating-point warnings (divide-by-zero, invalid values, overflow) and UserWarnings
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# Apply the base style FIRST: `mpl.style.use('default')` resets rcParams to matplotlib's defaults, so
# calling it after the `mpl.rc(...)` lines silently discarded the font / line / mathtext settings.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# 'Times' is a serif face, so it belongs in the list for the selected 'serif' family
mpl.rc('font', **{'family': 'serif', 'serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
import zcode.math as zmath
import zcode.plot as zplot

# Look at the 1D distribution of mass

## Choose stellar masses, calculate the GSMF, and translate to MBH masses

In [None]:
# Redshift at which the GSMF is evaluated, and the stellar-mass bin edges
redz = 1.0
NBINS = 5
# mstar_edges = np.logspace(11, 12, 10) * MSOL
mstar_edges = np.logspace(10, 12.5, NBINS+1) * MSOL

mmbulge = holo.host_relations.MMBulge_Standard()
gsmf = holo.sam.GSMF_Schechter()

phi_mstar = gsmf(mstar_edges, redz)    # [Mpc^-3]
# NOTE(review): the positional `False` is presumably the `scatter` flag of `mbh_from_mstar` -- confirm
mbh_edges = mmbulge.mbh_from_mstar(mstar_edges, False)
# presumably dM_star/dM_bh evaluated at each edge -- confirm
dmdm = mmbulge.dmstar_dmbh(mstar_edges)

# Change variables from stellar mass to BH mass:
#   dn/dlog(Mbh) = dn/dlog(M*) * (dM*/dMbh) * (Mbh/M*)
# (assumes `phi_mstar` is a density per log-mass interval -- TODO confirm)
phi_mbh = phi_mstar * dmdm * mbh_edges / mstar_edges

# Integrate each density across every bin (in log10 mass) and plot per-bin values at the bin midpoints
xx = utils.midpoints(mstar_edges / MSOL)
yy = utils.trapz(phi_mstar, xx=np.log10(mstar_edges), cumsum=False)

fig, ax = plot.figax()
ax.plot(xx, yy, 'b-')   # stellar-mass function (blue, bottom x-axis)

# Twin x-axis for the BH-mass function (red dashed, top x-axis)
tw = ax.twiny()
tw.set(xscale='log', xlabel='BH Mass [Msol]')
xx = utils.midpoints(mbh_edges / MSOL)
yy = utils.trapz(phi_mbh, xx=np.log10(mbh_edges), cumsum=False)
tw.plot(xx, yy, 'r--')

plt.show()

## Simplest method: add the effects of scatter to (re-)calculate the MBH density

In [None]:
# Mass units used to normalize the bin edges before they are converted to log10 below
UNITS = MSOL

# Scatter distribution in dex: zero-mean normal whose width is the M-Mbulge relation's scatter
dist = sp.stats.norm(loc=0.0, scale=mmbulge._scatter_dex)

def weights_for_bin(edges, cents, dist, bin):
    """Fraction of the content of bin number `bin` that `dist` scatters into every bin.

    All bin `edges` are shifted so that the center of bin `bin` sits at zero.  The probability mass of
    `dist` between consecutive shifted edges is then the fraction of that bin's content ending up in
    each bin: bins to the left, the bin itself, and bins to the right.  Mass falling beyond the
    outermost edges is dropped.

    Parameters
    ----------
    edges : (N+1,) ndarray
        Bin edges, in the same space as `dist` (e.g. log10 masses).
    cents : (N,) array_like
        Bin centers.
    dist : scipy.stats frozen distribution
        Distribution of offsets; only its ``cdf`` method is used.
    bin : int
        Index of the source bin.

    Returns
    -------
    (N,) ndarray
        Fraction of the source bin's content that lands in each bin.
    """
    # Every edge, measured relative to the center of the source bin
    offsets = edges - cents[bin]
    # Differencing the CDF between neighbouring edges gives the mass inside each (shifted) bin;
    # this covers the left bins, the source bin, and the right bins in one pass.
    return np.diff(dist.cdf(offsets))

def simple_scatter_redistribute(edges, dist, dens):
    """Redistribute `dens` across bins according to the scatter distribution `dist`.

    Each bin's value is spread over all bins with the weights from `weights_for_bin` (computed in
    log10 space of `edges`), and the contributions from every source bin are summed.

    Parameters
    ----------
    edges : (N+1,) ndarray
        Positive bin edges; converted to log10 internally.
    dist : scipy.stats frozen distribution
        Distribution of offsets in dex.
    dens : (M,) array_like, M >= N
        Values to redistribute.  Only the first N entries act as sources, and only the first N entries
        of the output are filled; any extra trailing entries stay zero.

    Returns
    -------
    dens_new : (M,) ndarray of float
    """
    log_edges = np.log10(edges)
    log_cents = utils.midpoints(log_edges, log=False)
    num_bins = log_cents.size
    # Float output so fractional weights are never truncated (e.g. for integer-valued `dens`)
    dens_new = np.zeros_like(dens, dtype=float)
    for src in range(num_bins):
        weights = weights_for_bin(log_edges, log_cents, dist, src)
        # Use the `dens` argument; this previously read the global `phi_mbh`, ignoring the input
        dens_new[:num_bins] += dens[src] * weights

    return dens_new

# Apply the scatter to the BH-mass density (values at the bin edges)
phi_new = simple_scatter_redistribute(mbh_edges/UNITS, dist, phi_mbh)

fig, ax = plot.figax()

# Unscattered per-bin values (grey)
xx = utils.midpoints(mbh_edges / MSOL)
yy = utils.trapz(phi_mbh, xx=np.log10(mbh_edges), cumsum=False)
ax.plot(xx, yy, 'k-', alpha=0.5, lw=2.0)

# Scatter first, then integrate per bin (red dashed)
zz = utils.trapz(phi_new, xx=np.log10(mbh_edges), cumsum=False)
ax.plot(xx, zz, 'r--')

# Integrate per bin first, then scatter (blue dashed).
# NOTE(review): this curve only reflects `yy` if `simple_scatter_redistribute` uses its `dens` argument
#     rather than the global `phi_mbh` -- check the function definition above.
# yy = utils.trapz(phi_mbh, xx=np.log10(mbh_edges), cumsum=False)
yy = simple_scatter_redistribute(mbh_edges/UNITS, dist, yy)
ax.plot(xx, yy, 'b--')

plt.show()

## Try to be a bit smarter: vectorized redistribution on uniformly log-spaced bins

In [None]:
# Mass units used to normalize edges before they are converted to log10
UNITS = MSOL

def roll_rows(arr, r_tup):
    """Cyclically shift each row of a 2D array by its own amount.

    Row ``i`` of the result equals ``np.roll(arr[i], r_tup[i])``: positive shifts move elements toward
    higher column indices, wrapping around the end of the row.

    Parameters
    ----------
    arr : (R, C) ndarray
    r_tup : (R,) sequence of int
        Per-row shift.

    Returns
    -------
    (R, C) ndarray
        A new array (not a view of `arr`).
    """
    shifts = np.asarray(r_tup)
    assert np.ndim(arr) == 2 and np.ndim(shifts) == 1
    nrows, ncols = arr.shape
    assert shifts.size == nrows
    # Source column of every output element:  out[i, k] = arr[i, (k - shift_i) mod C]
    src_cols = (np.arange(ncols)[np.newaxis, :] - shifts[:, np.newaxis]) % ncols
    return arr[np.arange(nrows)[:, np.newaxis], src_cols]


def get_weights(log_edges, dist):
    """Probability mass of `dist` in each of ``2*N - 1`` equal-width cells centered on zero.

    ``N = log_edges.size`` and the cell width ``dx`` is the (uniform) spacing of `log_edges`.  Cell
    ``k`` (for ``k = -(N-1), ..., 0, ..., N-1``) spans ``[(k - 1/2) dx, (k + 1/2) dx]``.

    Raises
    ------
    AssertionError
        If `log_edges` are not uniformly spaced.

    Returns
    -------
    (2*N - 1,) ndarray
    """
    spacing = np.diff(log_edges)
    assert np.allclose(spacing, spacing[0]), "This method only works if `log_edges` are uniformly spaced!"
    step = spacing[0]
    # Positive cell boundaries: +dx/2, +3dx/2, +5dx/2, ... (N of them)
    upper = step/2.0 + np.arange(log_edges.size) * step
    # Mirror them to get all boundaries in increasing order: ..., -3dx/2, -dx/2, +dx/2, +3dx/2, ...
    boundaries = np.concatenate([-upper[::-1], upper])
    # Mass within each cell is the CDF difference across its two boundaries
    return np.diff(dist.cdf(boundaries))
    

def scatter_redistribute(edges, dist, dens, axis=0):
    log_edges = np.log10(edges)
    num = log_edges.size
    # Get the fractional weights that this bin should be redistributed to
    # (2*N - 1,)  giving the bins all the way to the left and the right
    # e.g. [-N+1, -N+2, ..., -2, -1, 0, +1, +2, ..., +N-2, +N-1]
    weights = get_weights(log_edges, dist)

    # Duplicate the weights into each row of an (N, N) matrix
    # e.g. [[-N+1, -N+2, ..., -2, -1, 0, +1, +2, ..., +N-2, +N-1]
    #       [-N+1, -N+2, ..., -2, -1, 0, +1, +2, ..., +N-2, +N-1]
    #       [-N+1, -N+2, ..., -2, -1, 0, +1, +2, ..., +N-2, +N-1]
    #        ...
    weights = weights[np.newaxis, :] * np.ones((num, weights.size))
    # Need to "roll" each row of the matrix such that the central bin is at number index=row
    #    rolls backward by default, 
    roll = 1 - num + np.arange(num)
    # Roll each row
    # e.g. [[ 0, +1, +2, ..., +N-2, +N-1, -N+1, -N+2, ..., -2, -1]
    #       [-1,  0, +1, +2, ..., +N-2, +N-1, -N+1, -N+2, ..., -2]
    #       [-2, -1,  0, +1, +2, ..., +N-2, +N-1, -N+1, -N+2, ..., -3]
    #        ...
    weights = roll_rows(weights, roll)
    # Cutoff each row after N elements
    weights = weights[:, :num]

    # Perform the convolution
    dens = np.moveaxis(dens, axis, 0)
    dens_new = np.einsum("j...,jk...", dens, weights)
    dens_new = np.moveaxis(dens_new, 0, axis)
    dens = np.moveaxis(dens, 0, axis)
    return dens_new

print(mbh_edges.size, phi_mbh.shape)
phi_new = scatter_redistribute(mbh_edges/UNITS, dist, phi_mbh)

fig, ax = plot.figax()

xx = utils.midpoints(mbh_edges / MSOL)
yy = utils.trapz(phi_mbh, xx=np.log10(mbh_edges), cumsum=False)
ax.plot(xx, yy, 'k-', alpha=0.5, lw=2.0)

yy = utils.trapz(phi_new, xx=np.log10(mbh_edges), cumsum=False)
ax.plot(xx, yy, 'r--')

plt.show()

In [None]:
xx = mbh_edges/MSOL
y1 = phi_mbh.copy()
y2 = 10 * phi_mbh * (xx/xx.min()) ** -0.5

z1 = scatter_redistribute(mbh_edges/UNITS, dist, y1)
z2 = scatter_redistribute(mbh_edges/UNITS, dist, y2)

fig, ax = plot.figax()

xx = utils.midpoints(mbh_edges / MSOL)
a1 = utils.trapz(y1, xx=np.log10(mbh_edges), cumsum=False)
b1 = utils.trapz(z1, xx=np.log10(mbh_edges), cumsum=False)
ax.plot(xx, a1, 'b-', alpha=0.5, lw=2.0)
ax.plot(xx, b1, 'b--')

a2 = utils.trapz(y2, xx=np.log10(mbh_edges), cumsum=False)
b2 = utils.trapz(z2, xx=np.log10(mbh_edges), cumsum=False)
ax.plot(xx, a2, 'r-', alpha=0.5, lw=2.0)
ax.plot(xx, b2, 'r--')

yy = np.vstack([y1, y2]).T
print(yy.shape)
zz = scatter_redistribute(mbh_edges/UNITS, dist, yy)
for z in zz:
    z = utils.trapz(z, xx=np.log10(mbh_edges), cumsum=False)
    ax.plot(xx, z, ls=':', color='purple', lw=3.0, alpha=0.5)


plt.show()

# 2D with total mass (mtot) and mass ratio (mrat)

In [None]:
# Galaxy stellar-mass function (Schechter) and power-law GPF / GMT models
gsmf = holo.sam.GSMF_Schechter()
gpf = holo.sam.GPF_Power_Law()
gmt = holo.sam.GMT_Power_Law()   # NOTE(review): not used in this cell

NBINS = 30
redz = 0.1
mtot = np.logspace(10, 12.5, NBINS+1) * MSOL
mrat = np.logspace(-2, 0, NBINS+2)

# Primary/secondary masses on the (mtot, mrat) grid (presumably broadcast to (mtot.size, mrat.size))
m1, m2 = utils.m1m2_from_mtmr(mtot[:, np.newaxis], mrat[np.newaxis, :])
dens = gsmf(m1, redz) * gpf(m1, mrat, redz)
print(dens.shape)

# Draw discrete samples from the gridded density
# NOTE(review): `NSAMP` is a float (1e6); confirm `kale.sample_grid` accepts a non-integer `nsamp`
NSAMP = 1e6
samps = kale.sample_grid([mtot, mrat], dens, nsamp=NSAMP)
print(samps.shape)

# Normalized histogram of the samples on the same (mtot, mrat) bins, for comparison with `dens`
hist, *_ = np.histogram2d(*samps, bins=(mtot, mrat), density=True)

# Panels: gridded density, sample histogram, and the 1D histograms of the mtot and mrat samples
fig, axes = plot.figax(ncols=4)
density_flag = True
axes[0].pcolormesh(mtot, mrat, dens.T)
axes[1].pcolormesh(mtot, mrat, hist.T)
axes[2].hist(samps[0], bins=mtot, density=density_flag)
axes[3].hist(samps[1], bins=mrat, density=density_flag)

plt.show()



## Demonstrate translating from (mt, mr) to (m1, m2), and then back from (m1, m2) to (mt, mr)

In [None]:
USE_LINEAR_INTERP_BACKUP = False
fig, ax = plot.figax()
ax.set(xlabel='m1', ylabel='m2')

# Color each (mtot, mrat) grid point by its density
smap = zplot.smap(dens, 'viridis', scale='log')
cc = smap.to_rgba(dens)

# (m1, m2) locations of the (mtot, mrat) grid points: an irregular set of points in (m1, m2) space
m1m2_on_mtmr_grid = (m1.flatten(), m2.flatten())

ax.scatter(*m1m2_on_mtmr_grid, c=cc.reshape(-1, 4), zorder=100, s=5, edgecolor='0.5', lw=0.5)
# Reference lines m2 = m1 and m2 = 1e-2 * m1 (the largest and smallest `mrat` values of this grid)
xx = np.array(sorted(m1.flatten()))
ax.plot(xx, xx)
ax.plot(xx, xx*1e-2)

# Regular, log-spaced grid (same points on both axes), REFINE times finer than the mtot grid
REFINE = 4
grid_size = m1.shape[0]*REFINE
# NOTE(review): the upper bound mtot[-1]*(1+q)/q (q = mrat[0]) is far above the largest primary mass
#     (mtot[-1]/(1+q), assuming m1 = mtot/(1+q)); it guarantees coverage, but for small q most of the
#     grid lies beyond the data -- confirm this margin is intended.
mextr = zmath.minmax([mtot[0]*mrat[0]/(1.0 + mrat[0]), mtot[-1]*(1.0 + mrat[0])/mrat[0]])
m1_grid = zmath.spacing(mextr, 'log', grid_size)
for gg in m1_grid:
    ax.axhline(gg, lw=0.25)
    ax.axvline(gg, lw=0.25)

# Interpolate onto the regular grid (cubic), optionally patching bad values with linear interpolation
m1m2_grid = np.meshgrid(m1_grid, m1_grid, indexing='ij')
m1m2_dens = sp.interpolate.griddata(m1m2_on_mtmr_grid, dens.flatten(), tuple(m1m2_grid), method='cubic')
if USE_LINEAR_INTERP_BACKUP:
    bads = np.isnan(m1m2_dens) | (m1m2_dens <= 0.0)
    temp = sp.interpolate.griddata(m1m2_on_mtmr_grid, dens.flatten(), tuple(m1m2_grid), method='linear')
    print(zmath.frac_str(bads), zmath.frac_str(np.isnan(temp[bads])))
    m1m2_dens[bads] = temp[bads]

# Replace remaining NaN / non-positive values (e.g. outside the data's convex hull) with nearest values
bads = np.isnan(m1m2_dens) | (m1m2_dens <= 0.0)
print(zmath.frac_str(bads))
temp = sp.interpolate.griddata(m1m2_on_mtmr_grid, dens.flatten(), tuple(m1m2_grid), method='nearest')
m1m2_dens[bads] = temp[bads]
bads = np.isnan(m1m2_dens) | (m1m2_dens <= 0.0)
print(zmath.frac_str(bads))
# `m1m2_dens[i, j]` has m1 = m1_grid[i] ('ij' indexing); pcolormesh maps the rows of C to the y-axis,
# so transpose to put m1 on the x-axis, matching the scatter overlay above.
ax.pcolormesh(m1_grid, m1_grid, m1m2_dens.T, alpha=0.8, cmap=smap.cmap, norm=smap.norm)

plt.show()

In [None]:
# Interpolate the regular-grid (m1, m2) density back onto the (m1, m2) locations of the original
# (mtot, mrat) grid points, to check the round trip
interp = sp.interpolate.RegularGridInterpolator((m1_grid, m1_grid), m1m2_dens)
ww = interp((m1.flatten(), m2.flatten()), method='linear').reshape(m1.shape)
ww.shape

In [None]:
fig, axes = plot.figax(ncols=3)

# Fractional error of the round-tripped density `ww` relative to the reference `dens`
# (normalized by the reference, as in the other comparison cells)
err = (ww - dens) / dens
print(zmath.stats_str(err))

# Shared color scale for the two density panels; the error panel uses matplotlib's defaults
smap = zplot.smap([dens, ww], cmap='viridis', scale='log')

values = [ww, dens, err]
maps = [smap, smap, None]
titles = ['round-trip', 'original', 'fractional error']
for ax, mm, vv, title in zip(axes, maps, values, titles):
    kw = {} if mm is None else dict(cmap=mm.cmap, norm=mm.norm)
    pcm = ax.pcolormesh(mtot, mrat, vv.T, **kw)
    plt.colorbar(pcm, ax=ax, orientation='horizontal')
    ax.set_title(title, fontsize=10)

plt.show()

## Add scatter in (m1, m2) space and translate back to (mt, mr)

In [None]:
USE_LINEAR_INTERP_BACKUP = False
REFINE = 4

# (m1, m2) locations of the (mtot, mrat) grid points, and a regular log-spaced grid covering them
m1m2_on_mtmr_grid = (m1.flatten(), m2.flatten())
grid_size = m1.shape[0]*REFINE
mextr = zmath.minmax([mtot[0]*mrat[0]/(1.0 + mrat[0]), mtot[-1]*(1.0 + mrat[0])/mrat[0]])
m1_grid = zmath.spacing(mextr, 'log', grid_size)

# Interpolate onto the regular grid (cubic), optionally patching bad values with linear interpolation
m1m2_grid = np.meshgrid(m1_grid, m1_grid, indexing='ij')
m1m2_dens = sp.interpolate.griddata(m1m2_on_mtmr_grid, dens.flatten(), tuple(m1m2_grid), method='cubic')
if USE_LINEAR_INTERP_BACKUP:
    bads = np.isnan(m1m2_dens) | (m1m2_dens <= 0.0)
    temp = sp.interpolate.griddata(m1m2_on_mtmr_grid, dens.flatten(), tuple(m1m2_grid), method='linear')
    print(zmath.frac_str(bads), zmath.frac_str(np.isnan(temp[bads])))
    m1m2_dens[bads] = temp[bads]

# Replace remaining NaN / non-positive values (e.g. outside the data's convex hull) with nearest values
bads = np.isnan(m1m2_dens) | (m1m2_dens <= 0.0)
print(zmath.frac_str(bads))
temp = sp.interpolate.griddata(m1m2_on_mtmr_grid, dens.flatten(), tuple(m1m2_grid), method='nearest')
m1m2_dens[bads] = temp[bads]
bads = np.isnan(m1m2_dens) | (m1m2_dens <= 0.0)
print(zmath.frac_str(bads))
if np.any(bads):
    # A bare `raise` outside an `except` block only produced "No active exception to reraise"
    raise RuntimeError(
        f"{np.count_nonzero(bads)} NaN/non-positive values remain in the interpolated (m1, m2) density"
    )

fig, axes = plot.figax(figsize=[10, 7], nrows=3, ncols=4)

# Shared color scale for the (m1, m2) density panels (used by `draw_2d_and_errors`)
smap = zplot.smap(m1m2_dens, cmap='viridis', scale='log')

def draw_2d_and_errors(axes, vals):
    """Draw one row of diagnostic panels comparing `vals` with the reference global `m1m2_dens`.

    Panels:
      0. `vals` on the (m1, m2) grid (shared color scale from the global `smap`);
      1. fractional difference ``(vals - m1m2_dens) / m1m2_dens`` (exactly zero where they agree);
      2. cuts with the first index fixed, along the second axis (solid: `vals`, dashed: reference);
      3. cuts with the second index fixed, along the first axis (same styles and colors).
    Panels 0 and 1 are drawn untransposed, so the first array axis (m1) runs along the y-axis.

    Uses the globals `m1_grid`, `m1m2_dens`, and `smap`.  Returns None.
    """
    pcm = axes[0].pcolormesh(m1_grid, m1_grid, vals, cmap=smap.cmap, norm=smap.norm)
    plt.colorbar(pcm, ax=axes[0], orientation='horizontal')

    # Fractional difference, dividing only where the difference is non-zero
    frac = vals - m1m2_dens
    nonzero = (frac != 0.0)
    frac[nonzero] /= m1m2_dens[nonzero]

    err_kw = {}
    if np.any(frac != 0.0):
        # NOTE(review): `frac` can be negative; confirm a log-scaled map handles signed values
        err_smap = zplot.smap(frac, cmap='viridis', scale='log')
        err_kw = dict(cmap=err_smap.cmap, norm=err_smap.norm)
    pcm = axes[1].pcolormesh(m1_grid, m1_grid, frac, **err_kw)
    plt.colorbar(pcm, ax=axes[1], orientation='horizontal')

    # Three evenly spaced cuts through the grid, starting from index 0
    ncuts = 3
    masses = m1_grid/MSOL
    stride = (m1_grid.size - 1) // (ncuts - 1)
    for cut in range(ncuts):
        idx = cut * stride
        line, = axes[2].plot(masses, vals[idx, :], ls='-', alpha=0.5)
        color = line.get_color()
        axes[2].plot(masses, m1m2_dens[idx, :], ls='--', alpha=0.5, color=color)

        axes[3].plot(masses, vals[:, idx], ls='-', alpha=0.5, color=color)
        axes[3].plot(masses, m1m2_dens[:, idx], ls='--', alpha=0.5, color=color)

    for panel in axes[2:]:
        panel.set(ylim=[1e-7, 1e-3])

# Row 0: the unscattered (m1, m2) density compared with itself (zero error)
m1m2_dens_new = m1m2_dens.copy()
draw_2d_and_errors(axes[0], m1m2_dens_new)

# Row 1: after scattering along the first (m1) axis
m1m2_dens_new = scatter_redistribute(m1_grid, dist, m1m2_dens_new, axis=0)
draw_2d_and_errors(axes[1], m1m2_dens_new)

# Row 2: after additionally scattering along the second (m2) axis
m1m2_dens_new = scatter_redistribute(m1_grid, dist, m1m2_dens_new, axis=1)
draw_2d_and_errors(axes[2], m1m2_dens_new)


# Interpolate the scattered density back onto the (m1, m2) locations of the (mtot, mrat) grid points
interp = sp.interpolate.RegularGridInterpolator((m1_grid, m1_grid), m1m2_dens_new)
mtmr_dens_new = interp(m1m2_on_mtmr_grid, method='linear').reshape(m1.shape)
mtmr_dens_new.shape

In [None]:
fig, axes = plot.figax(ncols=3)
# Fractional change of the scattered density relative to the original (unscattered) density
err = (mtmr_dens_new - dens) / dens

# Panels: original density, scattered density, fractional change
smap = None
for ii, (ax, vals) in enumerate(zip(axes, [dens, mtmr_dens_new, err])):
    # A new color map is built for panels 0 and 2 only; panel 1 reuses panel 0's map (shared scale).
    # NOTE(review): `err` can be negative; confirm `zplot.smap(..., scale='log')` handles signed values
    if ii != 1:
        smap = zplot.smap(vals, scale='log')
    pcm = ax.pcolormesh(mtot, mrat, vals.T, cmap=smap.cmap, norm=smap.norm)
    plt.colorbar(pcm, ax=ax, orientation='horizontal')
    
plt.show()


## Manually add scatter to the discretely sampled binaries, then calculate a histogram and compare it to the density

In [None]:
# Primary/secondary masses of each sampled binary.  `np.asarray` guarantees an ndarray (with `.shape`)
# even if `m1m2_from_mtmr` returns a tuple.
_vals = np.asarray(utils.m1m2_from_mtmr(*samps))
# Perturb each mass independently in log10 by a normal deviate of width `dist.std()` [dex]
vals = 10.0 ** (np.log10(_vals) + np.random.normal(0.0, dist.std(), size=_vals.shape))
# NOTE: independent scatter can make the secondary heavier than the primary.  Unless `mtmr_from_m1m2`
#       reorders them, such binaries get mrat > 1 and fall outside the `mrat` bins of the histogram below.
vals = utils.mtmr_from_m1m2(*vals)

check, *_ = np.histogram2d(*vals, bins=(mtot, mrat), density=True)
fig, axes = plot.figax(ncols=3)

# Fractional difference of the scattered-sample histogram from the unscattered one
# (cells where `hist` is zero give inf/NaN)
err = (check - hist) / hist

# Panels: original histogram, scattered histogram (same color scale), fractional difference.
# The loop variable is distinct from `vals` so the scattered samples are not overwritten.
smap = None
for ii, (ax, panel_vals) in enumerate(zip(axes, [hist, check, err])):
    if ii != 1:
        smap = zplot.smap(panel_vals, scale='log')
    pcm = ax.pcolormesh(mtot, mrat, panel_vals.T, cmap=smap.cmap, norm=smap.norm)
    plt.colorbar(pcm, ax=ax, orientation='horizontal')

plt.show()



# 3D (with redshift)

In [None]:
def add_scatter_to_masses(mtot, mrat, dens, dist, refine=4, linear_interp_backup=False):
    """Add scatter to both component masses of a density defined on an (mtot, mrat) grid.

    Steps:
      1. Interpolate the density from the (irregular) (m1, m2) locations of the (mtot, mrat) grid onto a
         regular, log-spaced, symmetric (m1, m2) grid.
      2. Apply scatter along both mass axes with `scatter_redistribute`.
      3. Interpolate the result back onto the original (mtot, mrat) points.

    Parameters
    ----------
    mtot : (M,) ndarray
        Total-mass grid.
    mrat : (Q,) ndarray
        Mass-ratio grid.
    dens : (M, Q) ndarray
        Density on the (mtot, mrat) grid.
    dist : scipy.stats frozen distribution
        Scatter distribution in dex.
    refine : int
        The regular (m1, m2) grid has ``M * refine`` points per axis.
    linear_interp_backup : bool
        If True, NaN / non-positive values from the cubic interpolation are first patched with linear
        interpolation, before falling back to nearest-neighbour values.

    Returns
    -------
    (M, Q) ndarray
        Density including scatter, on the (mtot, mrat) grid.

    Raises
    ------
    RuntimeError
        If NaN / non-positive values remain after nearest-neighbour filling.
    """
    assert np.shape(dens) == (mtot.size, mrat.size)

    # Get the primary and secondary masses corresponding to these total-mass and mass-ratios
    m1, m2 = utils.m1m2_from_mtmr(mtot[:, np.newaxis], mrat[np.newaxis, :])
    m1m2_on_mtmr_grid = (m1.flatten(), m2.flatten())

    # Construct a symmetric rectilinear grid in (m1, m2) space
    grid_size = m1.shape[0] * refine
    # make sure the extrema will fully span the required domain
    # NOTE(review): the upper bound mtot[-1]*(1+q)/q (q = mrat[0]) is far above the largest primary mass
    #     (mtot[-1]/(1+q), assuming m1 = mtot/(1+q)); it guarantees coverage, but for small q most of
    #     the grid lies beyond the data -- confirm this margin is intended.
    mextr = zmath.minmax([mtot[0]*mrat[0]/(1.0 + mrat[0]), mtot[-1]*(1.0 + mrat[0])/mrat[0]])
    mgrid = zmath.spacing(mextr, 'log', grid_size)
    m1m2_grid = np.meshgrid(mgrid, mgrid, indexing='ij')

    # Interpolate from irregular m1m2 space (based on mtmr space), into regular m1m2 grid
    m1m2_dens = sp.interpolate.griddata(m1m2_on_mtmr_grid, dens.flatten(), tuple(m1m2_grid), method='cubic')
    # Fill in problematic values with first-order interpolant
    if linear_interp_backup:
        bads = np.isnan(m1m2_dens) | (m1m2_dens <= 0.0)
        temp = sp.interpolate.griddata(m1m2_on_mtmr_grid, dens.flatten(), tuple(m1m2_grid), method='linear')
        print(zmath.frac_str(bads), zmath.frac_str(np.isnan(temp[bads])))
        m1m2_dens[bads] = temp[bads]

    # Fill in problematic values with zeroth-order interpolant
    bads = np.isnan(m1m2_dens) | (m1m2_dens <= 0.0)
    print(zmath.frac_str(bads))
    if np.any(bads):
        temp = sp.interpolate.griddata(m1m2_on_mtmr_grid, dens.flatten(), tuple(m1m2_grid), method='nearest')
        m1m2_dens[bads] = temp[bads]
        bads = np.isnan(m1m2_dens) | (m1m2_dens <= 0.0)
        print(zmath.frac_str(bads))
        if np.any(bads):
            # A bare `raise` here (outside an `except` block) only gave "No active exception to reraise"
            raise RuntimeError(
                f"{np.count_nonzero(bads)} NaN/non-positive values remain in the interpolated (m1, m2) density"
            )

    # Introduce scatter along both the 0th (primary) and 1th (secondary) axes
    m1m2_dens = scatter_redistribute(mgrid, dist, m1m2_dens, axis=0)
    m1m2_dens = scatter_redistribute(mgrid, dist, m1m2_dens, axis=1)

    # Interpolate result back to the (m1, m2) locations of the original (mtot, mrat) grid
    interp = sp.interpolate.RegularGridInterpolator((mgrid, mgrid), m1m2_dens)
    mtmr_dens = interp(m1m2_on_mtmr_grid, method='linear').reshape(m1.shape)
    return mtmr_dens


# Scatter distribution in dex, from the standard M-Mbulge relation
mmbulge = holo.host_relations.MMBulge_Standard()
dist = sp.stats.norm(loc=0.0, scale=mmbulge._scatter_dex)

gsmf = holo.sam.GSMF_Schechter()
gpf = holo.sam.GPF_Power_Law()
gmt = holo.sam.GMT_Power_Law()   # NOTE(review): not used in this cell

# Small 3D grid: (mtot, mrat, redz)
NBINS = 6
mtot = np.logspace(10, 12.5, NBINS+1) * MSOL
mrat = np.logspace(-2, 0, NBINS+2)
redz = np.linspace(0.0, 4.0, NBINS-1)

m1, m2 = utils.m1m2_from_mtmr(mtot[:, np.newaxis], mrat[np.newaxis, :])
print(mtot.shape, mrat.shape, m1.shape, m2.shape, redz.shape)
# Density on the 3D grid by broadcasting redshift along a new trailing axis
dens = (
    gsmf(m1[:, :, np.newaxis], redz[np.newaxis, np.newaxis, :]) * 
    gpf(m1[:, :, np.newaxis], mrat[np.newaxis, :, np.newaxis], redz[np.newaxis, np.newaxis, :])
)

# Add mass scatter independently to each redshift slice
prime = np.zeros_like(dens)
for ii, rz in enumerate(redz):
    prime[:, :, ii] = add_scatter_to_masses(mtot, mrat, dens[:, :, ii], dist)

print(zmath.stats_str(dens))
print(zmath.stats_str(prime))

In [None]:
# Small random test problem: interpolate a 3D (mtot, mrat, redz) density onto a regular
# (m1, m2, redz) grid, to compare with slice-by-slice 2D interpolation below
NN = 3
mtot = np.logspace(10, 11, NN) * MSOL
mrat = np.logspace(-1, 0, NN)
redz = np.linspace(0.0, 1.0, 3)
print(f"{mtot=}")
print(f"{mrat=}")
print(f"{redz=}")

np.random.seed(1234)
dens = np.random.uniform(0.0, 100.0, size=(mtot.size, mrat.size, redz.size))
print(f"{dens=}")

# Get the primary and secondary masses corresponding to these total-mass and mass-ratios
m1, m2 = utils.m1m2_from_mtmr(mtot[:, np.newaxis, np.newaxis], mrat[np.newaxis, :, np.newaxis])

# Broadcast masses and redshifts to the full (mtot, mrat, redz) shape, then flatten into point lists
rz = np.ones_like(m1) * redz[np.newaxis, np.newaxis, :]
m1, m2 = [mm * np.ones_like(rz) for mm in [m1, m2]]
m1m2_on_mtmr_grid = tuple([mm.flatten() for mm in [m1, m2, rz]])

# Construct a symmetric rectilinear grid in (m1, m2) space spanning the actual range of both masses
RR = 2
grid_size = m1.shape[0] * RR
mextr = zmath.minmax(m1, prev=zmath.minmax(m2))
print(f"{mextr=}")
mgrid = zmath.spacing(mextr, 'log', grid_size)
m1m2_grid = np.meshgrid(mgrid, mgrid, redz, indexing='ij')

# Data points in (log10 m1, log10 m2, redz) coordinates
temp = np.copy(m1m2_on_mtmr_grid)
temp[0] = np.log10(temp[0])
temp[1] = np.log10(temp[1])

# Target grid in the SAME (log10 m1, log10 m2, redz) coordinates
temp_grid = np.copy(m1m2_grid)
temp_grid[0] = np.log10(temp_grid[0])
temp_grid[1] = np.log10(temp_grid[1])
print(temp.shape, temp_grid.shape)

# Interpolate from irregular m1m2 space (based on mtmr space), into regular m1m2 grid.
# The interpolator is built in log10-mass coordinates, so it must be evaluated at log10 masses as well;
# evaluating at the linear-mass `m1m2_grid` put every target outside the data's hull (all NaN).
interp = sp.interpolate.LinearNDInterpolator(tuple(temp), dens.flatten(), rescale=True)
test = interp(*temp_grid)
print(np.all(np.isnan(test)))
print(test)

In [None]:
def interp_dim(mtot, mrat, dens):
    """Interpolate one 2D (mtot, mrat) density slice onto a regular (m1, m2) grid.

    Called below once per redshift slice as a reference for the 3D interpolation in the previous cell.

    NOTE(review): unlike that cell, this interpolates in linear (not log10) mass coordinates and builds
    its grid from the `mextr` expression below rather than the actual min/max of m1 and m2, so its grid
    points differ from those behind `test` -- confirm which setup is intended before comparing.

    Returns the ``(2*mtot.size, 2*mtot.size)`` interpolated values (NaN outside the convex hull of the
    data points), and prints them.
    """
    # Get the primary and secondary masses corresponding to these total-mass and mass-ratios
    m1, m2 = utils.m1m2_from_mtmr(mtot[:, np.newaxis], mrat[np.newaxis, :])
    m1m2_on_mtmr_grid = tuple([mm.flatten() for mm in [m1, m2]])
    # print(f"{m1.shape=}, {np.shape(m1m2_on_mtmr_grid)=}")
    
    # Construct a symmetric rectilinear grid in (m1, m2) space
    grid_size = m1.shape[0] * 2
    # make sure the extrema will fully span the required domain
    mextr = zmath.minmax([mtot[0]*mrat[0]/(1.0 + mrat[0]), mtot[-1]*(1.0 + mrat[0])/mrat[0]])
    mgrid = zmath.spacing(mextr, 'log', grid_size)
    m1m2_grid = np.meshgrid(mgrid, mgrid, indexing='ij')
    # print(f"{np.shape(m1m2_grid)=}")


    # Interpolate from irregular m1m2 space (based on mtmr space), into regular m1m2 grid
    test = sp.interpolate.griddata(m1m2_on_mtmr_grid, dens.flatten(), tuple(m1m2_grid), method='linear', rescale=False)
    print(test)
    return test

# Slice-by-slice 2D interpolation for every redshift, stored with the same shape as the 3D result `test`
check = np.zeros_like(test)
for ii, rz in enumerate(redz):
    check[:, :, ii] = interp_dim(mtot, mrat, dens[:, :, ii])

In [None]:
# Compare the 3D and slice-wise 2D interpolations.  Exact float equality is too strict, and NaN != NaN,
# so grid points outside the convex hull (NaN in both) made `np.all(test == check)` always False.
np.allclose(test, check, equal_nan=True)

# Test on SAM model

In [None]:
def draw_gwb(ax, xx, gwb, nsamp=10, color=None, label=None):
    """Plot the median and interquartile band of GWB realizations, plus a few individual realizations.

    Parameters
    ----------
    ax : matplotlib Axes
    xx : (F,) array_like
        x-values (e.g. frequencies).
    gwb : (F, R) ndarray
        Values of R realizations at each of the F x-values.
    nsamp : int or None
        Number of randomly chosen (without replacement) realizations to overplot; None or a
        non-positive value draws none.
    color : color or None
        If None, the next color of the axes' property cycle is used.
    label : str or None
        Legend label for the median line.
    """
    color = ax._get_lines.get_next_color() if color is None else color

    median, q25, q75 = np.percentile(gwb, [50, 25, 75], axis=1)
    ax.plot(xx, median, alpha=0.5, color=color, label=label)
    ax.fill_between(xx, q25, q75, color=color, alpha=0.15)

    # Nothing more to draw unless a positive number of samples was requested
    if nsamp is None or not nsamp > 0:
        return

    num_reals = gwb.shape[1]
    chosen = np.random.choice(num_reals, np.min([nsamp, num_reals]), replace=False)
    for real in chosen:
        ax.plot(xx, gwb[:, real], color=color, alpha=0.25, lw=1.0, ls='-')

In [None]:
# Compute the GWB for several values of the M-Mbulge scatter [dex]
SHAPE = 50
REALS = 100
scatter_list = [0.0, 0.3, 0.6]
fobs_edges = utils.nyquist_freqs_edges(20.0*YR, 0.2*YR)
fobs = utils.midpoints(fobs_edges)

gwbs = []
for scatter in scatter_list:
    mmbulge = holo.host_relations.MMBulge_Standard(scatter_dex=scatter)
    sam = holo.sam.Semi_Analytic_Model(shape=SHAPE, mmbulge=mmbulge)
    # hard = holo.hardening.Fixed_Time.from_sam(sam, temp)
    # NOTE(review): this passes the `Hard_GW` class itself, while later cells pass an instance
    #     (`Hard_GW()`); confirm `sam.gwb` accepts both
    hard = holo.hardening.Hard_GW
    gwb = sam.gwb(fobs_edges, hard=hard, realize=REALS, zero_stalled=False, use_redz_after_hard=False)
    gwbs.append(gwb)    # `REALS` realizations for this scatter value

In [None]:
fig, ax = plot.figax(
    xlabel=plot.LABEL_GW_FREQUENCY_YR, ylabel=plot.LABEL_CHARACTERISTIC_STRAIN,
)
# Plot in reverse order so the zero-scatter curve is drawn last (on top)
for gwb, lab in zip(gwbs[::-1], scatter_list[::-1]):
    lab = f"{lab:.2f}"
    draw_gwb(ax, fobs*YR, gwb, label=lab)

ax.legend(title='stdev [dex]')
plot._twin_hz(ax)

# Save BEFORE showing: with the Jupyter inline backend, `plt.show()` closes the figure by default,
# so saving first does not rely on a closed figure still rendering correctly.
output_path = Path("~/scatter.png").expanduser()
fig.savefig(output_path, dpi=400)
plt.show()

# Examine SAM grid edges

In [None]:
# Force a re-import of `holo.sam` (usually redundant while `%autoreload 2` is active)
reload(holo.sam)

In [None]:
# Integrated binary density for every (M-Mbulge amplitude, scatter) combination on a 60^3 SAM grid
SHAPE = (60, 60, 60)
mamp_log10_list = [8.5, 9.0, 9.5, 10.0]
scatter_list = [0.0, 0.3, 0.6, 0.9, 1.2]

# masses = []
# fobs_edges = utils.nyquist_freqs_edges(20.0*YR, 0.2*YR)
# fobs = utils.midpoints(fobs_edges)
# mamp_log10 = 8.5
# masses_lo = []
# mamp_log10 = 9.0
# masses_me = []
# mamp_log10 = 9.5
# masses_hi = []

# Shape (len(mamp_log10_list), len(scatter_list)); read by the plotting cell below
masses_60_extend = np.zeros((len(mamp_log10_list), len(scatter_list)))

for ii, mamp_log10 in enumerate(mamp_log10_list):

    for jj, scatter in enumerate(scatter_list):
        mmbulge = holo.host_relations.MMBulge_Standard(mamp_log10=mamp_log10, scatter_dex=scatter)
        sam = holo.sam.Semi_Analytic_Model(shape=SHAPE, mmbulge=mmbulge)
        # Print the total-mass grid statistics once (same grid for every combination)
        if ii == 0 and jj == 0:
            print(utils.stats(sam.mtot))

        # Sum of the integrated static binary density over the whole grid
        dens = sam.static_binary_density
        mass = sam._integrated_binary_density(dens, sum=True)
        masses_60_extend[ii, jj] = mass



In [None]:
# Fractional change of the integrated binary density between successive scatter values, relative to
# the zero-scatter value, for each `mamp_log10`.  Colors distinguish `mamp_log10`; line styles
# distinguish the runs.
# NOTE: `masses_60` and `masses_100` are not created anywhere in this notebook (presumably they came
#       from earlier runs of the cell above with different `SHAPE`s).  On a fresh kernel they do not
#       exist, so only the runs present in the current namespace are plotted instead of raising NameError.
runs = [('masses_60', '-'), ('masses_100', '--'), ('masses_60_extend', ':')]

fig, ax = plot.figax(xscale='lin')

first = True
for name, ls in runs:
    masses_run = globals().get(name)
    if masses_run is None:
        continue
    for ii, (mm, lab) in enumerate(zip(masses_run, mamp_log10_list)):
        # `C{ii}` is the default color cycle, so each `mamp_log10` keeps one color across runs;
        # only label the first run plotted to avoid duplicate legend entries
        ax.plot(scatter_list[1:], np.diff(mm)/mm[0], label=(lab if first else None), color=f"C{ii}", ls=ls)
    # Proxy entry so the legend also explains the line style of this run
    ax.plot([], [], color='0.5', ls=ls, label=name)
    first = False

ax.legend()
plt.show()

# Check convergence with respect to the MTOT grid

In [None]:
mamp_log10_list = [8.5, 10.5]
scatter_list = [0.5, 1.5]
SHAPE = 20

# Collect `sam._dens_bef` / `sam._dens_aft` for every (mamp_log10, scatter) combination
for ii, mamp_log10 in enumerate(mamp_log10_list):
    for jj, scatter in enumerate(scatter_list):
        mmbulge = holo.host_relations.MMBulge_Standard(mamp_log10=mamp_log10, scatter_dex=scatter)
        sam = holo.sam.Semi_Analytic_Model(shape=SHAPE, mmbulge=mmbulge)
        if ii == 0 and jj == 0:
            # Allocate once the SAM grid shape is known.  Size from the parameter lists rather than
            # a hard-coded (2, 2), so changing either list does not cause an IndexError.
            grid_shape = (len(mamp_log10_list), len(scatter_list)) + sam.shape
            dens_bef = np.zeros(grid_shape)
            dens_aft = np.zeros(grid_shape)
        # Evaluated only for its side effect: presumably this property computes and stores
        # `sam._dens_bef` / `sam._dens_aft` -- confirm in `holo.sam`
        sam.static_binary_density
        dens_bef[ii, jj] = sam._dens_bef[...]
        dens_aft[ii, jj] = sam._dens_aft[...]


In [None]:
# One row per (mamp_log10, scatter) combination; columns: before, after, and their ratio,
# each integrated over the last (third) grid axis.
# NOTE(review): the 4 rows and `np.ndindex(2, 2)` below hard-code two values in each parameter list.
fig, axes = plt.subplots(figsize=[10, 10], ncols=3, nrows=4)
for (ii, jj), ax in np.ndenumerate(axes):
    ax.set(yscale='log', xscale='log')

# Grid of the last `sam` built above (all combinations presumably share the same mtot/mrat grid)
xx = sam.mtot/MSOL
yy = sam.mrat

# Fixed color ranges so all rows are directly comparable
smap_mass = plot.smap([1e-20, 1e-1], log=True)
smap_ratio = plot.smap([1e-10, 10.0], log=True)

for ii, jj in np.ndindex(2, 2):
    kk = ii + 2*jj
    axrow = axes[kk, :]
    if kk == 0:
        for ax, lab in zip(axrow, ['bef', 'aft', 'bef/aft']):
            ax.set_title(lab, fontsize=8)
    
    axrow[0].set(ylabel=f"M={mamp_log10_list[ii]:.2f}   s={scatter_list[jj]:.2f}")
    
    mass_bef = sam._integrated_binary_density(dens_bef[ii, jj], sum=False).sum(axis=-1)
    mass = mass_bef
    print(utils.stats(mass[mass > 0.0]))
    pcm = axrow[0].pcolormesh(xx, yy, mass.T, cmap=smap_mass.cmap, norm=smap_mass.norm)
    plt.colorbar(pcm, ax=axrow[0])

    mass_aft = sam._integrated_binary_density(dens_aft[ii, jj], sum=False).sum(axis=-1)
    mass = mass_aft
    print(utils.minmax(mass[mass > 0.0]))
    pcm = axrow[1].pcolormesh(xx, yy, mass.T, cmap=smap_mass.cmap, norm=smap_mass.norm)
    plt.colorbar(pcm, ax=axrow[1])

    # Ratio before/after; cells where `mass_aft` is zero give inf/NaN
    mass = mass_bef/mass_aft
    print(utils.minmax(mass[mass > 0.0]))
    pcm = axrow[2].pcolormesh(xx, yy, mass.T, cmap=smap_ratio.cmap, norm=smap_ratio.norm)
    plt.colorbar(pcm, ax=axrow[2])

plt.show()

In [None]:
# Convergence test: GWB at a single frequency bin for several total-mass grids (`MTOT_LIST`)
mamp_log10 = 9.5
scatter_dex = 1.0
mmbulge = holo.host_relations.MMBulge_Standard(mamp_log10=mamp_log10, scatter_dex=scatter_dex)
hard = holo.hardening.Hard_GW()
nreals = 1000

# One frequency bin: [1/15yr, 2/15yr]
fobs_edges = [1/(15*YR), 2/(15*YR)]
fobs_cents = utils.midpoints(fobs_edges)

# Each entry is a (low, high, number) specification passed as `mtot=` to the SAM
# MTOT_LIST = [(1e6*MSOL, 1e11*MSOL, 61), (1e6*MSOL, 1e12*MSOL, 71), (1e6*MSOL, 1e13*MSOL, 81)]
# MTOT_LIST = [(1e6*MSOL, 1e12*MSOL, 61), (1e6*MSOL, 1e12*MSOL, 71), (1e6*MSOL, 1e12*MSOL, 81)]
# MTOT_LIST = [(1e6*MSOL, 1e12*MSOL, 61), (1e6*MSOL, 1e12*MSOL, 71), (1e6*MSOL, 1e12*MSOL, 81), (1e6*MSOL, 1e12*MSOL, 91), (1e6*MSOL, 1e12*MSOL, 101)]
# MTOT_LIST = [(1e6*MSOL, 1e12*MSOL, 60), (1e6*MSOL, 1e12*MSOL, 80), (1e6*MSOL, 1e12*MSOL, 100)]
# MTOT_LIST = [(1e6*MSOL, 1e13*MSOL, 60), (1e6*MSOL, 1e13*MSOL, 80), (1e6*MSOL, 1e13*MSOL, 100)]
MTOT_LIST = [(1e6*MSOL, 1e11*MSOL, 100), (1e6*MSOL, 1e12*MSOL, 100), (1e6*MSOL, 1e13*MSOL, 100), (1e6*MSOL, 1e14*MSOL, 100)]

gwbs = []
for MTOT in MTOT_LIST:
    sam = holo.sam.Semi_Analytic_Model(mtot=MTOT, mmbulge=mmbulge, redz=(1e-3, 10.0, 40), shape=30)
    gwb = sam.gwb(fobs_edges, hard=hard, realize=nreals, zero_stalled=False, use_redz_after_hard=False)
    gwbs.append(gwb)

In [None]:
fig, axes = plt.subplots(figsize=[10, 5], ncols=2)
fig.text(0.5, 0.98, f"{mamp_log10=}  {scatter_dex=}", fontsize=8, ha='center', va='top')

# x-axis: log10 of the upper total-mass bound (second-to-last element) of each `MTOT_LIST` entry
nums = [np.log10(mt[-2]/MSOL) for mt in MTOT_LIST]
# Left panel: median strain; right panel: standard deviation of log10(strain)
funcs = [np.median, lambda xx: np.std(np.log10(xx))]
ylabels = ['median(hc)', 'std(log10(hc))']
for ii, ax in enumerate(axes):
    # Must be an ndarray: `vals/err` below raises TypeError for a plain list
    vals = np.array([funcs[ii](hc) for hc in gwbs])
    if ii == 0:
        # rough uncertainty estimate: median / sqrt(N)
        err = np.sqrt(nreals)
    else:
        # standard error of a sample standard deviation (normal approximation): sigma / sqrt(2 (N - 1))
        err = np.sqrt(2 * (nreals - 1))
    ax.errorbar(nums, vals, yerr=vals/err)
    ax.set(xlabel='log10(max mtot / Msol)', ylabel=ylabels[ii])

plt.show()

In [None]:
fig, axes = plt.subplots(figsize=[10, 5], ncols=2)
fig.text(0.5, 0.98, f"{mamp_log10=}  {scatter_dex=}", fontsize=8, ha='center', va='top')

# x-axis: log10 of the upper total-mass bound (second-to-last element) of each `MTOT_LIST` entry
nums = [np.log10(mt[-2]/MSOL) for mt in MTOT_LIST]
# Left panel: median strain; right panel: standard deviation of log10(strain)
funcs = [np.median, lambda xx: np.std(np.log10(xx))]
ylabels = ['median(hc)', 'std(log10(hc))']
for ii, ax in enumerate(axes):
    # Must be an ndarray: `vals/err` below raises TypeError for a plain list
    vals = np.array([funcs[ii](hc) for hc in gwbs])
    if ii == 0:
        # rough uncertainty estimate: median / sqrt(N)
        err = np.sqrt(nreals)
    else:
        # standard error of a sample standard deviation (normal approximation): sigma / sqrt(2 (N - 1))
        err = np.sqrt(2 * (nreals - 1))
    ax.errorbar(nums, vals, yerr=vals/err)
    ax.set(xlabel='log10(max mtot / Msol)', ylabel=ylabels[ii])

plt.show()

In [None]:
fig, axes = plt.subplots(figsize=[10, 5], ncols=2)
fig.text(0.5, 0.98, f"{mamp_log10=}  {scatter_dex=}", fontsize=8, ha='center', va='top')

# x-axis: last element of each `MTOT_LIST` entry (presumably the number of mtot grid points).
# NOTE: with the current `MTOT_LIST` this is 100 for every entry, so all points share one x value.
nums = [mt[-1] for mt in MTOT_LIST]
# Left panel: median strain; right panel: standard deviation of log10(strain)
funcs = [np.median, lambda xx: np.std(np.log10(xx))]
ylabels = ['median(hc)', 'std(log10(hc))']
for ii, ax in enumerate(axes):
    # Must be an ndarray: `vals/err` below raises TypeError for a plain list
    vals = np.array([funcs[ii](hc) for hc in gwbs])
    if ii == 0:
        # rough uncertainty estimate: median / sqrt(N)
        err = np.sqrt(nreals)
    else:
        # standard error of a sample standard deviation (normal approximation): sigma / sqrt(2 (N - 1))
        err = np.sqrt(2 * (nreals - 1))
    ax.errorbar(nums, vals, yerr=vals/err)
    ax.set(ylabel=ylabels[ii])

plt.show()

In [None]:
fig, axes = plt.subplots(figsize=[10, 5], ncols=2)
fig.text(0.5, 0.98, f"{mamp_log10=}  {scatter_dex=}", fontsize=8, ha='center', va='top')

# x-axis: last element of each `MTOT_LIST` entry (presumably the number of mtot grid points).
# NOTE: with the current `MTOT_LIST` this is 100 for every entry, so all points share one x value.
nums = [mt[-1] for mt in MTOT_LIST]
# Left panel: median strain; right panel: standard deviation of log10(strain)
funcs = [np.median, lambda xx: np.std(np.log10(xx))]
ylabels = ['median(hc)', 'std(log10(hc))']
for ii, ax in enumerate(axes):
    # Must be an ndarray: `vals/err` below raises TypeError for a plain list
    vals = np.array([funcs[ii](hc) for hc in gwbs])
    if ii == 0:
        # rough uncertainty estimate: median / sqrt(N)
        err = np.sqrt(nreals)
    else:
        # standard error of a sample standard deviation (normal approximation): sigma / sqrt(2 (N - 1))
        err = np.sqrt(2 * (nreals - 1))
    ax.errorbar(nums, vals, yerr=vals/err)
    ax.set(ylabel=ylabels[ii])

plt.show()

In [None]:
fig, axes = plt.subplots(figsize=[10, 5], ncols=2)
fig.text(0.5, 0.98, f"{mamp_log10=}  {scatter_dex=}", fontsize=8, ha='center', va='top')

# x-axis: last element of each `MTOT_LIST` entry (presumably the number of mtot grid points).
# NOTE: with the current `MTOT_LIST` this is 100 for every entry, so all points share one x value.
nums = [mt[-1] for mt in MTOT_LIST]
# Left panel: median strain; right panel: standard deviation of log10(strain)
funcs = [np.median, lambda xx: np.std(np.log10(xx))]
ylabels = ['median(hc)', 'std(log10(hc))']
for ii, ax in enumerate(axes):
    # Must be an ndarray: `vals/err` below raises TypeError for a plain list
    vals = np.array([funcs[ii](hc) for hc in gwbs])
    if ii == 0:
        # rough uncertainty estimate: median / sqrt(N)
        err = np.sqrt(nreals)
    else:
        # standard error of a sample standard deviation (normal approximation): sigma / sqrt(2 (N - 1))
        err = np.sqrt(2 * (nreals - 1))
    ax.errorbar(nums, vals, yerr=vals/err)
    ax.set(ylabel=ylabels[ii])

plt.show()

## New cython calculations

In [None]:
# NOTE(review): this cell is identical to the earlier mtot-convergence setup cell under
#     "Check convergence with respect to the MTOT grid"; it re-runs the same GWB calculations.
mamp_log10 = 9.5
scatter_dex = 1.0
mmbulge = holo.host_relations.MMBulge_Standard(mamp_log10=mamp_log10, scatter_dex=scatter_dex)
hard = holo.hardening.Hard_GW()
nreals = 1000

# One frequency bin: [1/15yr, 2/15yr]
fobs_edges = [1/(15*YR), 2/(15*YR)]
fobs_cents = utils.midpoints(fobs_edges)

# Each entry is a (low, high, number) specification passed as `mtot=` to the SAM
# MTOT_LIST = [(1e6*MSOL, 1e11*MSOL, 61), (1e6*MSOL, 1e12*MSOL, 71), (1e6*MSOL, 1e13*MSOL, 81)]
# MTOT_LIST = [(1e6*MSOL, 1e12*MSOL, 61), (1e6*MSOL, 1e12*MSOL, 71), (1e6*MSOL, 1e12*MSOL, 81)]
# MTOT_LIST = [(1e6*MSOL, 1e12*MSOL, 61), (1e6*MSOL, 1e12*MSOL, 71), (1e6*MSOL, 1e12*MSOL, 81), (1e6*MSOL, 1e12*MSOL, 91), (1e6*MSOL, 1e12*MSOL, 101)]
# MTOT_LIST = [(1e6*MSOL, 1e12*MSOL, 60), (1e6*MSOL, 1e12*MSOL, 80), (1e6*MSOL, 1e12*MSOL, 100)]
# MTOT_LIST = [(1e6*MSOL, 1e13*MSOL, 60), (1e6*MSOL, 1e13*MSOL, 80), (1e6*MSOL, 1e13*MSOL, 100)]
MTOT_LIST = [(1e6*MSOL, 1e11*MSOL, 100), (1e6*MSOL, 1e12*MSOL, 100), (1e6*MSOL, 1e13*MSOL, 100), (1e6*MSOL, 1e14*MSOL, 100)]

gwbs = []
for MTOT in MTOT_LIST:
    sam = holo.sam.Semi_Analytic_Model(mtot=MTOT, mmbulge=mmbulge, redz=(1e-3, 10.0, 40), shape=30)
    gwb = sam.gwb(fobs_edges, hard=hard, realize=nreals, zero_stalled=False, use_redz_after_hard=False)
    gwbs.append(gwb)