In [None]:
# %load ../notebooks/init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC, NWTG
import holodeck.gravwaves
import holodeck.evolution
import holodeck.population

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE: the 'default' style has to be applied *before* the individual `mpl.rc` tweaks.
# `mpl.style.use('default')` resets (nearly) every rcParam to its default value, so calling
# it after the `mpl.rc(...)` lines silently discards the font/line/mathtext customisations.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): 'family' is 'serif' but the font list is assigned to 'sans-serif';
# possibly 'serif': ['Times'] was intended -- confirm before changing.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
import zcode.math as zmath

In [None]:
import zcode.plot as zplot

# Look at 1D distribution of Mass

## Choose stellar masses, calculate the GSMF, and translate to MBH masses

In [None]:
# Redshift at which the GSMF is evaluated, and the number of stellar-mass bins.
redz = 1.0
NBINS = 5
# mstar_edges = np.logspace(11, 12, 10) * MSOL
# NBINS+1 logarithmically spaced stellar-mass bin edges, scaled by the MSOL constant.
mstar_edges = np.logspace(10, 12.5, NBINS+1) * MSOL

# M-Mbulge relation and Schechter galaxy stellar-mass function from holodeck.
mmbulge = holo.relations.MMBulge_Standard()
gsmf = holo.sam.GSMF_Schechter()

# GSMF evaluated at the stellar-mass *edges* (one value per edge, not per bin).
phi_mstar = gsmf(mstar_edges, redz)    # [Mpc^-3]
# Map the stellar-mass edges to black-hole masses.
# NOTE(review): the positional `False` is presumably a "scatter" flag -- confirm against holodeck.
mbh_edges = mmbulge.mbh_from_mstar(mstar_edges, False)
# Presumably the derivative dM_star/dM_bh evaluated at each stellar mass -- TODO confirm.
dmdm = mmbulge.dmstar_dmbh(mstar_edges)

# Change of variables from stellar mass to MBH mass:
#   phi_mstar * (dM_star/dM_bh) * (M_bh / M_star)
# i.e. a Jacobian factor d(log M_star)/d(log M_bh), assuming `phi_mstar` is a density per
# logarithmic stellar-mass interval.
phi_mbh = phi_mstar * dmdm * mbh_edges / mstar_edges

# Bin midpoints (in units of MSOL) and the density integrated over each log10-mass bin.
xx = utils.midpoints(mstar_edges / MSOL)
yy = utils.trapz(phi_mstar, xx=np.log10(mstar_edges), cumsum=False)

# Blue: stellar-mass distribution on the bottom x-axis.
fig, ax = plot.figax()
ax.plot(xx, yy, 'b-')

# Red: the same quantity for MBH mass, drawn against a second (top) x-axis sharing the y-axis.
tw = ax.twiny()
tw.set(xscale='log', xlabel='BH Mass [Msol]')
xx = utils.midpoints(mbh_edges / MSOL)
yy = utils.trapz(phi_mbh, xx=np.log10(mbh_edges), cumsum=False)
tw.plot(xx, yy, 'r-')

plt.show()

## Simplest Method: Add the effects of scatter to (re-)calculate MBH density

In [None]:
# Mass unit divided out of the bin edges before they are passed to `scatter_redistribute`.
UNITS = MSOL

# Normal distribution centred on zero whose standard deviation is the M-Mbulge relation's
# `_scatter_dex` attribute (a private attribute of `mmbulge`; presumably the scatter in dex).
dist = sp.stats.norm(loc=0.0, scale=mmbulge._scatter_dex)
def weights_for_bin(edges, cents, dist, bin):
    """Fraction of one bin's content that `dist` scatters into every bin.

    The distribution `dist` is placed at the centre of bin number `bin`, and the
    probability mass falling between each pair of consecutive `edges` is returned.

    Parameters
    ----------
    edges : (N+1,) array
        Bin edges, in the same coordinate in which `dist` is defined.
    cents : (N,) array
        Bin centres, in the same coordinate as `edges`.
    dist : object with a ``cdf`` method
        Scatter distribution centred on zero (e.g. a frozen `scipy.stats.norm`).
    bin : int
        Index of the source bin; any index that is valid for `cents`.
        (The name shadows the builtin `bin`, but is kept for compatibility.)

    Returns
    -------
    weights : (N,) array
        ``weights[j]`` is the probability mass of `dist`, centred on ``cents[bin]``,
        that lies between ``edges[j]`` and ``edges[j+1]``.  Mass scattered beyond the
        first or last edge is not included, so the sum can be less than one.

    """
    # Offsets of every edge from the centre of the source bin: negative to the left,
    # positive to the right.
    offsets = np.asarray(edges) - cents[bin]
    # The mass inside each bin is the difference of the CDF across its two edges.  A single
    # `diff` over all of the edges covers the left, central, and right bins in one go.
    return np.diff(dist.cdf(offsets))

def scatter_redistribute(edges, dist, dens):
    """Redistribute binned values over neighbouring bins according to a scatter distribution.

    Each entry of `dens` is spread over all of the bins using the weights returned by
    `weights_for_bin`, evaluated in log10 of `edges`.

    Parameters
    ----------
    edges : (N+1,) array
        Bin edges (linear values; log10 is taken internally).
    dist : object with a ``cdf`` method
        Scatter distribution in log10-space, centred on zero.
    dens : (M,) array with M >= N
        Values to redistribute; entry ``i`` (for ``i < N``) is treated as the content of
        bin ``i``.  Entries beyond the first N are not used.

    Returns
    -------
    dens_new : (M,) array
        Redistributed values in the first N entries; any remaining entries are zero.

    """
    log_edges = np.log10(edges)
    log_cents = utils.midpoints(log_edges, log=False)
    nbins = log_cents.size
    dens = np.asarray(dens)
    dens_new = np.zeros_like(dens, dtype=float)
    for bin in range(nbins):
        weights = weights_for_bin(log_edges, log_cents, dist, bin)
        # Use the `dens` argument here (not the notebook-global `phi_mbh`), so that the
        # function actually redistributes whatever array it is given.
        dens_new[:nbins] += dens[bin] * weights

    return dens_new

# Redistribute the edge-evaluated MBH density using the scatter distribution.
phi_new = scatter_redistribute(mbh_edges/UNITS, dist, phi_mbh)

fig, ax = plot.figax()

# Grey: the original density integrated over each log10-mass bin.
xx = utils.midpoints(mbh_edges / MSOL)
yy = utils.trapz(phi_mbh, xx=np.log10(mbh_edges), cumsum=False)
ax.plot(xx, yy, 'k-', alpha=0.5, lw=2.0)

# Red dashed: redistribute first, then integrate over each bin.
zz = utils.trapz(phi_new, xx=np.log10(mbh_edges), cumsum=False)
ax.plot(xx, zz, 'r--')

# Blue dashed: the opposite order -- the bin-integrated values (`yy`) are passed to
# `scatter_redistribute` as `dens`.  Note that `yy` is overwritten by the result.
# yy = utils.trapz(phi_mbh, xx=np.log10(mbh_edges), cumsum=False)
yy = scatter_redistribute(mbh_edges/UNITS, dist, yy)
ax.plot(xx, yy, 'b--')

plt.show()

## Try to be a bit smarter: build all of the redistribution weights at once, and check them against the bin-by-bin method

In [None]:
# Mass unit divided out of the bin edges before they are passed to `scatter_redistribute`.
UNITS = MSOL

# Normal distribution centred on zero whose standard deviation is the M-Mbulge relation's
# `_scatter_dex` attribute (a private attribute of `mmbulge`; presumably the scatter in dex).
dist = sp.stats.norm(loc=0.0, scale=mmbulge._scatter_dex)
def weights_for_bin(edges, cents, dist, bin):
    """Fraction of one bin's content that `dist` scatters into every bin.

    The distribution `dist` is placed at the centre of bin number `bin`, and the
    probability mass falling between each pair of consecutive `edges` is returned.

    Parameters
    ----------
    edges : (N+1,) array
        Bin edges, in the same coordinate in which `dist` is defined.
    cents : (N,) array
        Bin centres, in the same coordinate as `edges`.
    dist : object with a ``cdf`` method
        Scatter distribution centred on zero (e.g. a frozen `scipy.stats.norm`).
    bin : int
        Index of the source bin; any index that is valid for `cents`.
        (The name shadows the builtin `bin`, but is kept for compatibility.)

    Returns
    -------
    weights : (N,) array
        ``weights[j]`` is the probability mass of `dist`, centred on ``cents[bin]``,
        that lies between ``edges[j]`` and ``edges[j+1]``.  Mass scattered beyond the
        first or last edge is not included, so the sum can be less than one.

    """
    # Offsets of every edge from the centre of the source bin: negative to the left,
    # positive to the right.
    offsets = np.asarray(edges) - cents[bin]
    # The mass inside each bin is the difference of the CDF across its two edges.  A single
    # `diff` over all of the edges covers the left, central, and right bins in one go.
    return np.diff(dist.cdf(offsets))


def custom_roll(arr, r_tup):
    """Circularly shift each row of a 2D array by its own amount.

    Equivalent to applying ``np.roll(arr[i], r_tup[i])`` to every row ``i``.

    Parameters
    ----------
    arr : (R, C) ndarray
        Input array; it is not modified.
    r_tup : (R,) array_like of int
        Shift for each row.  Positive values move elements towards higher column indices,
        negative values towards lower ones; shifts wrap around modulo C.

    Returns
    -------
    rolled : (R, C) ndarray
        New array in which row ``i`` is row ``i`` of `arr` rolled by ``r_tup[i]``.

    """
    nrows, ncols = arr.shape
    # One shift per row, as a column vector so that it broadcasts against the column indices
    shifts = np.asarray(r_tup).reshape(-1, 1)
    # Rolling by `s` means that output column `j` is taken from input column `(j - s) % C`.
    # Computing these source columns directly avoids building a padded copy of `arr` and a
    # strided sliding-window view over it.
    src_cols = (np.arange(ncols)[np.newaxis, :] - shifts) % ncols
    row_index = np.arange(nrows)[:, np.newaxis]
    return arr[row_index, src_cols]


def get_weights(log_edges, dist):
    """Scatter weights as a function of bin offset, for uniformly spaced bins.

    Parameters
    ----------
    log_edges : (N+1,) array
        Bin edges; consecutive edges must all be separated by the same step
        (checked with an assertion).
    dist : object with a ``cdf`` method
        Scatter distribution centred on zero.

    Returns
    -------
    weights : (2N-1,) array
        Probability mass of `dist` in consecutive intervals of one bin-width, arranged
        symmetrically about zero.  Element ``N-1`` is the mass within half a bin-width
        of the centre; element ``N-1+k`` is the mass in the interval ``k`` bins away.

    """
    widths = np.diff(log_edges)
    nbins = widths.size
    assert np.allclose(widths, widths[0])
    step = widths[0]
    # Distances from a bin centre to each edge on its right: step/2, 3*step/2, ...
    right_offsets = step/2.0 + np.arange(nbins) * step
    # Mirror them to obtain the edges on the left as well, in ascending order
    offsets = np.concatenate([-right_offsets[::-1], right_offsets])
    # Mass between consecutive offsets
    return np.diff(dist.cdf(offsets))
    

def scatter_redistribute(edges, dist, dens):
    """Redistribute binned values over neighbouring bins, cross-checking two methods.

    The redistribution weights are computed in two ways that must agree:

    1. bin by bin, with `weights_for_bin`;
    2. all at once, by taking the offset-dependent weights from `get_weights` and shifting
       a copy of them for each bin with `custom_roll`.

    The second method requires the bins to be uniformly spaced in log10 (asserted inside
    `get_weights`).

    Parameters
    ----------
    edges : (N+1,) array
        Bin edges (linear values; log10 is taken internally).
    dist : object with a ``cdf`` method
        Scatter distribution in log10-space, centred on zero.
    dens : (M,) array with M >= N
        Values to redistribute; entry ``i`` (for ``i < N``) is treated as the content of
        bin ``i``.  Entries beyond the first N are not used.

    Returns
    -------
    dens_new : (M,) array
        Redistributed values in the first N entries; any remaining entries are zero.

    Raises
    ------
    AssertionError
        If the two ways of computing the weights, or the resulting values, disagree.

    """
    log_edges = np.log10(edges)
    log_cents = utils.midpoints(log_edges, log=False)
    nbins = log_cents.size
    dens = np.asarray(dens)
    dens_new = np.zeros_like(dens, dtype=float)

    # ---- All weights at once
    # `kernel` has 2N-1 entries centred on "zero offset".  Rolling row `b` by `b - (N-1)`
    # moves the zero-offset entry to column `b`, so the first N columns of that row are
    # the weights from bin `b` into bins 0..N-1.
    kernel = get_weights(log_edges, dist)
    shifts = -(kernel.size-1)//2 + np.arange(nbins)
    weights_all = custom_roll(np.tile(kernel, (nbins, 1)), shifts)
    weights_all = weights_all[:, :nbins]

    # ---- Reference: bin-by-bin weights
    for bin in range(nbins):
        weights = weights_for_bin(log_edges, log_cents, dist, bin)
        assert np.allclose(weights_all[bin], weights)
        # Use the `dens` argument (not the notebook-global `phi_mbh`)
        dens_new[:nbins] += dens[bin] * weights

    # Only the first N entries of `dens` correspond to bins; restricting to them keeps the
    # contraction well-defined when `dens` is given at the N+1 edges.
    test = np.einsum("j,jk", dens[:nbins], weights_all)
    assert np.allclose(test, dens_new[:nbins])

    return dens_new

# Redistribute the edge-evaluated MBH density; the function also runs its internal
# consistency assertions between the two ways of computing the weights.
phi_new = scatter_redistribute(mbh_edges/UNITS, dist, phi_mbh)

fig, ax = plot.figax()

# Grey: the original density integrated over each log10-mass bin.
xx = utils.midpoints(mbh_edges / MSOL)
yy = utils.trapz(phi_mbh, xx=np.log10(mbh_edges), cumsum=False)
ax.plot(xx, yy, 'k-', alpha=0.5, lw=2.0)

# Red dashed: the redistributed density integrated over the same bins.
yy = utils.trapz(phi_new, xx=np.log10(mbh_edges), cumsum=False)
ax.plot(xx, yy, 'r--')

plt.show()

# 2D with mtot and mrat

In [None]:
# Model components from `holodeck.sam`.  Going by the class names these are presumably the
# galaxy stellar-mass function, galaxy pair fraction, and galaxy merger time -- confirm
# against the holodeck documentation.
gsmf = holo.sam.GSMF_Schechter()
gpf = holo.sam.GPF_Power_Law()
gmt = holo.sam.GMT_Power_Law()   # NOTE(review): not used anywhere in this cell

NBINS = 100
redz = 0.1
# Grid edges: NBINS+1 total masses and NBINS+2 mass ratios.  The differing lengths
# presumably serve to make the two axes distinguishable by shape.
mtot = np.logspace(10, 12.5, NBINS+1) * MSOL
mrat = np.logspace(-2, 0, NBINS+2)

# Component masses on the 2D (mtot, mrat) grid, built by broadcasting a column against a row.
m1, m2 = utils.m1m2_from_mtmr(mtot[:, np.newaxis], mrat[np.newaxis, :])
# Density used for sampling: product of the GSMF and the pair fraction on the grid.
dens = gsmf(m1, redz) * gpf(m1, mrat, redz)
print(dens.shape)

# NOTE(review): `1e5` is a float; confirm that `kale.sample_grid` accepts a non-integer `nsamp`.
# NOTE(review): no random seed is set, so the drawn samples differ from run to run.
NSAMP = 1e5
samps = kale.sample_grid([mtot, mrat], dens, nsamp=NSAMP)
print(samps.shape)

# Left: 2D histogram of the samples; middle and right: 1D histograms of each coordinate,
# all using the grid edges as the histogram bins.
fig, axes = plot.figax(ncols=3)
density_flag = False
axes[0].hist2d(*samps, bins=[mtot, mrat], density=density_flag)
axes[1].hist(samps[0], bins=mtot, density=density_flag)
axes[2].hist(samps[1], bins=mrat, density=density_flag)

plt.show()



## Demonstrate translating from (mt, mr) to (m1, m2), and then back from (m1, m2) to (mt, mr)

In [None]:
fig, ax = plot.figax()

# Log-scaled colour mapping for the density; `cc` holds one RGBA colour per grid point.
smap = zplot.smap(dens, 'viridis', scale='log')
cc = smap.to_rgba(dens)

# Scatter the (mtot, mrat) grid points at their (m1, m2) locations, coloured by density.
ax.scatter(m1.flatten(), m2.flatten(), c=cc.reshape(-1, 4), zorder=100, s=5, edgecolor='0.5', lw=0.5)
# Reference lines m2 = m1 and m2 = 0.01 * m1 (0.01 is the smallest value in `mrat`).
xx = np.array(sorted(m1.flatten()))
ax.plot(xx, xx)
ax.plot(xx, xx*1e-2)

# Build a regular, log-spaced grid in (m1, m2) with REFINE times as many points per axis as
# there are `mtot` values.  The same mass range is used for both axes.
REFINE = 4
g1 = zmath.minmax([mtot[0]*mrat[0]/(1.0 + mrat[0]), mtot[-1]*(1.0 + mrat[0])/mrat[0]])
# g1 = zmath.minmax([mtot[0]*mrat[0]/(1.0 + mrat[0]), mtot[-1]])
g1 = zmath.spacing(g1, 'log', m1.shape[0]*REFINE)
# NOTE(review): `g1` is already the full grid array at this point; presumably
# `zmath.spacing` only uses its extrema, making `g2` identical to `g1` -- confirm.
g2 = zmath.spacing(g1, 'log', m1.shape[0]*REFINE)
# Draw the regular grid lines
for gg in g1:
    ax.axhline(gg, lw=0.25)
    ax.axvline(gg, lw=0.25)

# Interpolate the scattered (m1, m2, dens) points onto the regular grid.
# NOTE(review): only `scipy.stats` is imported explicitly in this notebook; `sp.interpolate`
# relies on the submodule being loaded elsewhere (or on scipy's lazy loading) -- confirm.
gg = np.meshgrid(g1, g2, indexing='ij')
temp = sp.interpolate.griddata((m1.flatten(), m2.flatten()), dens.flatten(), tuple(gg), method='linear')
# zz = temp
# Start from the cubic interpolation ...
zz = sp.interpolate.griddata((m1.flatten(), m2.flatten()), dens.flatten(), tuple(gg), method='cubic')
# ... replace NaN or non-positive values with the linear interpolation ...
bads = np.isnan(zz) | (zz <= 0.0)
print(zmath.frac_str(bads), zmath.frac_str(np.isnan(temp[bads])))
zz[bads] = temp[bads]
bads = np.isnan(zz) | (zz <= 0.0)
print(zmath.frac_str(bads))
# ... and whatever is still bad with nearest-neighbour values.
temp = sp.interpolate.griddata((m1.flatten(), m2.flatten()), dens.flatten(), tuple(gg), method='nearest')
zz[bads] = temp[bads]
bads = np.isnan(zz) | (zz <= 0.0)
print(zmath.frac_str(bads))
# NOTE(review): `zz` is indexed [g1, g2] (meshgrid with indexing='ij'), whereas `pcolormesh`
# takes its colour array as (rows = y, columns = x); check whether `zz.T` is needed here.
ax.pcolormesh(g1, g2, zz, alpha=0.8, cmap=smap.cmap, norm=smap.norm)
    
# Interpolator on the regular (m1, m2) grid, used in the next cell to map back to (mt, mr).
interp = sp.interpolate.RegularGridInterpolator((g1, g2), zz)
    
plt.show()

In [None]:
# Evaluate the regular-grid interpolant at the original (m1, m2) points, and restore the
# shape of the (mtot, mrat) grid.
ww = interp((m1.flatten(), m2.flatten()), method='linear').reshape(m1.shape)
ww.shape


In [None]:
fig, axes = plot.figax(ncols=3)

# Fractional difference between the round-tripped density (`ww`) and the original (`dens`),
# normalised by the round-tripped values.
err = (ww - dens) / ww
print(zmath.stats_str(err))

# Shared log-scaled colour mapping spanning both density arrays.
smap = zplot.smap([dens, ww], cmap='viridis', scale='log')
# NOTE(review): this binds the class itself (it is not called) and is not used in this cell.
err_map = zplot.ScalarMappable2D

# Panels: round-tripped density, original density, fractional error.
values = [ww, dens, err]
maps = [smap, smap, None]
for ax, mm, vv in zip(axes, maps, values):
    # The error panel (mm is None) uses matplotlib's default colour map and normalisation.
    # NOTE(review): the keyword arguments read `smap` rather than the loop variable `mm`;
    # equivalent here because every non-None entry of `maps` is `smap`.
    kw = {} if mm is None else dict(cmap=smap.cmap, norm=smap.norm)
    pcm = ax.pcolormesh(mtot, mrat, vv.T, **kw)
    plt.colorbar(pcm, ax=ax, orientation='horizontal')

plt.show()

## Add Scatter in m1, m2 space and translate back

In [None]:

# Number of regular-grid points per axis, as a multiple of the number of `mtot` values.
REFINE = 10
# Define grid in m1, m2 that will fully span the space
gg = zmath.minmax([mtot[0]*mrat[0]/(1.0 + mrat[0]), mtot[-1]*(1.0 + mrat[0])/mrat[0]])
g1 = zmath.spacing(gg, 'log', m1.shape[0]*REFINE)
g2 = zmath.spacing(gg, 'log', m1.shape[0]*REFINE)
# `gg` is reused: from here on it holds the pair of 2D coordinate arrays of the grid.
gg = np.meshgrid(g1, g2, indexing='ij')

# ---- Interpolate from mt,mr space into m1,m2 space ----
# Use a cubic interpolator where possible
zz = sp.interpolate.griddata((m1.flatten(), m2.flatten()), dens.flatten(), tuple(gg), method='cubic')
# Fix problematic locations with linear (still only works inside domain)
temp = sp.interpolate.griddata((m1.flatten(), m2.flatten()), dens.flatten(), tuple(gg), method='linear')
# "Problematic" means NaN or non-positive values.
bads = np.isnan(zz) | (zz <= 0.0)
zz[bads] = temp[bads]
# Fix problematic locations with nearest  (works _outside_ of domain also)
temp = sp.interpolate.griddata((m1.flatten(), m2.flatten()), dens.flatten(), tuple(gg), method='nearest')
bads = np.isnan(zz) | (zz <= 0.0)
zz[bads] = temp[bads]

# Interpolate back from m1,m2 space into mt,mr space
interp = sp.interpolate.RegularGridInterpolator((g1, g2), zz)
ww = interp((m1.flatten(), m2.flatten()), method='linear').reshape(m1.shape)
