In [None]:
# %load ../init.ipy
%reload_ext autoreload
%autoreload 2

# Builtin packages
from importlib import reload
import logging
import os
from pathlib import Path
import sys
import warnings

# standard secondary packages
import astropy as ap
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import tqdm.notebook as tqdm

# development packages
import kalepy as kale
import kalepy.utils
import kalepy.plot

# --- Holodeck ----
import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE: `mpl.style.use('default')` resets the style-related rcParams to matplotlib's defaults, so
#       it must run *before* the custom `mpl.rc` calls; applied afterwards it would silently
#       discard the font / lines / mathtext settings below.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Load log and set logging level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
import zcode
import zcode.math as zmath
import zcode.plot as zplot
import zcode.inout as zio

* TODO: Check how the integrated strains compare to the centroid strains — are they consistent?

## Construct SAM and calculate SAM binaries Grid

In [None]:
# Build the components of the semi-analytic model (SAM), all with their default arguments.
gsmf = holo.sam.GSMF_Schechter()        # Galaxy Stellar-Mass Function (GSMF)
gpf = holo.sam.GPF_Power_Law()          # Galaxy Pair Fraction         (GPF)
gmt = holo.sam.GMT_Power_Law()          # Galaxy Merger Time           (GMT)
mmbulge = holo.host_relations.MMBulge_Standard()     # M-MBulge Relation            (MMB)
# NOTE: this is the hardening *class* itself (not an instance); it is passed as-is to
#       `sam.dynamic_binary_number` in a later cell.
hard = holo.hardening.Hard_GW
# grid-size argument forwarded to the SAM  (NOTE(review): presumably bins per dimension — confirm)
shape = 40

sam = holo.sam.Semi_Analytic_Model(gsmf=gsmf, gpf=gpf, gmt=gmt, mmbulge=mmbulge, shape=shape)

In [None]:
# Observed GW frequencies from an observing duration and cadence (here 10 yr and 1 yr).
# fobs = utils.nyquist_freqs(10.0*YR, 0.1*YR)
fobs = utils.nyquist_freqs(10.0*YR, 1.0*YR)
# GWB from the SAM without realizations (kept for reference; not used in the cells below)
gwb_smooth = sam.gwb(fobs, realize=False)
# `edges` is the list of grid edges per dimension and `dnum` the differential number on that grid;
# the later cells treat the four dimensions as [mtot, mrat, redz, fobs].
edges, dnum = sam.dynamic_binary_number(hard, fobs=fobs)
# integrate the differential number over each grid cell to get the number of binaries per bin
number = holo.utils._integrate_grid_differential_number(edges, dnum, freq=True)

In [None]:
# Compare three estimates of the characteristic strain on the SAM grid:
#   'grid' : `_gws_from_number_grid`            (no realization, no integration)
#   'samp' : strains of sampled binaries, binned back onto the grid
#   'test' : `_gws_from_number_grid_integrated`
sample_threshold = 1e2
# `REALS` is forwarded as the `realize` argument of the integrated calculation below.
# REALS = 27
REALS = True
# hc_grid = holo.sam._gws_from_number_grid(edges, dnum, number, realize=REALS, integrate=False)
hc_grid = holo.sam._gws_from_number_grid(edges, dnum, number, realize=False, integrate=False)
hc_test = holo.sam._gws_from_number_grid_integrated(edges, dnum, number, realize=REALS, integrate=False)
print(f"{hc_grid.shape=} {hc_test.shape=}")
# Collapse the trailing axis (presumably realizations — TODO confirm) only when `REALS` is a
# number.  An explicit type check is required: `1 in [False, None, True]` is True in Python
# because `1 == True`, so a membership test would misclassify `REALS = 1`.
if (REALS is not None) and (not isinstance(REALS, bool)):
    # hc_grid = np.median(hc_grid, axis=-1)
    hc_test = np.median(hc_test, axis=-1)

# Sample binaries from the grid; total-mass and frequency are sampled in log-space
edges_sample = [np.log10(edges[0]), edges[1], edges[2], np.log(edges[3])]
vals, weights = kale.sample_outliers(edges_sample, dnum, threshold=sample_threshold, mass=number)
# convert the log-space samples back to linear values
vals[0] = 10.0 ** vals[0]
vals[3] = np.e ** vals[3]
hs, fo = holo.sam._strains_from_samples(vals)

# ---- Bin the sampled strains, get characteristic strain for each bin

hc_samp, *_ = sp.stats.binned_statistic_dd(vals.T, weights*(hs**2), statistic='sum', bins=edges)
# factor of 1/dln(f) for each frequency bin
cycles = 1.0 / np.diff(np.log(fobs))
hc_samp = hc_samp * cycles[np.newaxis, np.newaxis, np.newaxis, :]

# total sample weight in each bin, for comparison against `number`
wcount, *_ = sp.stats.binned_statistic_dd(vals.T, weights, statistic='sum', bins=edges)
print(f"{hc_grid.shape=} {hc_samp.shape=} {wcount.sum()=:.4e}    {hc_test.shape=}")

# ---- Plot Strains

fig, axes = plot.figax(nrows=2)
plot._draw_plaw(axes[0], fobs, 1e-15, color='k', ls='--', alpha=0.5)
tw = axes[1].twinx()
tw.set(yscale='log')
prev = None

# Best-effort plot of the 'test' spectrum against `fobs` directly.  Catch `Exception` (not a bare
# `except`, which would also swallow KeyboardInterrupt) and report the failure.
try:
    print(f"{hc_test.shape=}")
    axes[0].plot(fobs, np.sqrt(hc_test.sum(axis=(0, 1, 2))), 'k--', label='test')
except Exception as err:
    log.warning(f"Failed to plot 'test' strains against `fobs`: {err}")

for hcv, nn, lab in zip([hc_grid, hc_samp, hc_test], [number, wcount, number], ['grid', 'samp', 'test']):
    print(lab, np.shape(hcv), utils.stats(hcv))
    # sum the squared strains over all but the frequency dimension
    yy = np.sqrt(hcv.sum(axis=(0, 1, 2)))
    try:
        xx, yy = plot._get_hist_steps(fobs, yy)
        cc, = axes[0].plot(xx, yy, label=lab, alpha=0.5)
        cc = cc.get_color()
    except Exception as err:
        log.warning(f"Failed to plot '{lab}' strain spectrum: {err}")
        cc = None

    axes[1].plot(np.sqrt(hcv).flatten(), color=cc, alpha=0.5)
    # tw.plot(nn.flatten(), color=cc, alpha=0.5, ls='--')

    # Fractional difference of the first plotted value relative to the first ('grid') curve.
    # NOTE: the value is deliberately *not* stored in a variable called `next`, which would shadow
    #       the builtin for every later cell in the kernel; the printed label is unchanged.
    if prev is None:
        prev = yy[0]
    else:
        curr = yy[0]
        diff = (curr - prev) / prev
        print(f"{prev=:.4e}, next={curr:.4e}, {diff=:.4e}")
        # prev = curr

axes[0].legend()
# plt.show()
zplot.set_lim(axes[1], 'y', lo=1e-20, at='exactly')
zplot.set_lim(tw, 'y', lo=0.1, at='exactly')
axes[1].set(xscale='linear')
plt.show()

In [None]:
# Find the grid cell where the sampled strain most exceeds the grid strain, then recompute the
# strain for that single cell with both grid-based methods.

# fractional difference between grid and sampled (squared) strains, per grid cell
diff = (hc_grid - hc_samp) / hc_grid
dmax = np.nan_to_num(diff)
# difference between expected and sampled number of binaries, in units of sqrt(number)
diff_num = (number - wcount) / np.sqrt(number)
diff_num = np.nan_to_num(diff_num)
# diff_num = diff_num / np.sqrt(number)
print(utils.stats(diff_num))
# keep only cells with a non-trivial strain difference, more than 20 (weighted) samples, and a
# sampled number within one sqrt(number) of the expectation
dmax = dmax[(dmax != 0.0) & (dmax != 1.0) & (wcount > 20) & (np.fabs(diff_num) < 1.0)]
# most negative fractional difference, i.e. where `hc_samp` exceeds `hc_grid` the most
dmax = dmax.min()
idx = np.where(diff == dmax)
print(dmax, idx, number[idx], wcount[idx], hc_grid[idx], hc_samp[idx], diff_num[idx])
print()

# lower and upper edge of the selected cell in each dimension
# NOTE: `ee` is rebound here (and used as the comprehension variable); later cells reuse it.
ee = [[ee[ii+jj] for jj in range(2)] for ee, ii in zip(edges, idx)]
ee = np.asarray(ee).squeeze()

# indices of the two grid points bounding the selected cell, in each dimension
cut = [[bb+ii for ii in range(2)] for bb in idx]
cut = np.asarray(cut).squeeze()
# print(cut.shape, *cut, dnum.shape)
# differential number at the corners of the selected cell
dn = dnum[np.ix_(*cut)]
# number of binaries in the selected cell, with three extra leading axes added
nn = number[idx][np.newaxis, np.newaxis, np.newaxis]

# print(ee)
# print(dn)
# print(nn)

temp = holo.sam._gws_from_number_grid(ee, dn, nn, realize=False, integrate=True)
print(f"{temp**2=}")

# NOTE: overwrites the `hc_test` array computed in the previous cell
hc_test = holo.sam._gws_from_number_grid_integrated(ee, dn, nn)
print(f"{hc_test**2=}")

In [None]:
# Repeatedly sample the single grid cell selected in the previous cell (`ee`, `dn`, `nn`) and
# record the summed squared strain of each realization.
# NOTE: rebinds `REALS` (previously the `realize` flag) to the number of realizations.
REALS = 10000
# total-mass and frequency edges are converted to log-space for sampling
ee_sample = [np.log10(ee[0]), ee[1], ee[2], np.log(ee[3])]
# 1 / dln(f) for the frequency bin
cyc = 1.0 / np.diff(ee_sample[-1])
temp_reals = np.zeros(REALS)
for rr in utils.tqdm(range(REALS)):
    # vv, ww = kale.sample_outliers(ee_sample, dn, threshold=sample_threshold, mass=nn)
    vv = kale.sample_grid(ee_sample, dn, mass=nn)
    # convert log-space samples back to linear values
    vv[0] = 10.0 ** vv[0]
    vv[3] = np.e ** vv[3]
    hs, fo = holo.sam._strains_from_samples(vv)
    # temp = np.sum(ww * (hs**2) * cyc)
    temp = np.sum((hs**2) * cyc)
    temp_reals[rr] = temp
    # print(ww.size, utils.stats(ww))
    # print(hs)
    
# NOTE: `temp` here is the value from the *last* realization only
print(temp, np.percentile(temp_reals, [25, 50, 75]))

In [None]:
# NOTE(review): `breaker` is not defined anywhere in the visible notebook; presumably this is an
#               intentional NameError used to halt "Run All" at this point — confirm.
breaker()

## Load slice of SAM grid, sample, compare GWB calculations

In [None]:
def get_strain(data):
    """Compute GW strains for binaries specified by total mass, mass ratio, redshift and frequency.

    Parameters
    ----------
    data : sequence of length 4
        The entries are indexed as `[mtot, mrat, redz, fobs]`; e.g. an array of shape (4, N).

    Returns
    -------
    hs : result of `utils.gw_strain_source`, evaluated with the chirp mass, the comoving distance
        (cgs value) and half of the rest-frame frequency.

    """
    mtot = data[0]
    mrat = data[1]
    redz = data[2]
    fobs = data[3]

    # chirp mass from the component masses
    masses = utils.m1m2_from_mtmr(mtot, mrat)
    mchirp = utils.chirp_mass(*masses)
    # comoving distance in cgs units
    dcom = cosmo.comoving_distance(redz).cgs.value
    # convert observer-frame frequency to rest-frame
    frst = utils.frst_from_fobs(fobs, redz)
    # the strain is evaluated at half of the rest-frame frequency
    return utils.gw_strain_source(mchirp, dcom, frst/2.0)

def slice_func(edges, dnum, number, zbin, fbin, sample_threshold):
    """Compare the grid-centroid and sampled GWB strain within one redshift & frequency bin.

    The grid is cut down to the single redshift bin `zbin` and single frequency bin `fbin` (all
    total-mass and mass-ratio bins are kept).  The strain is then computed (1) at the
    density-weighted centroid of each cell, weighted by the number of binaries in the cell, and
    (2) from binaries sampled with `kale.sample_outliers`.

    Parameters
    ----------
    edges : sequence of 4 array_like
        Grid edges for `[mtot, mrat, redz, fobs]`.  The entries may have different lengths.
    dnum : ndarray, 4D
        Differential number of binaries evaluated at the grid edges.
    number : ndarray, 4D
        Number of binaries in each grid cell.
    zbin : int
        Index of the redshift bin to select.
    fbin : int
        Index of the frequency bin to select.
    sample_threshold : scalar
        Threshold passed on to `kale.sample_outliers`.

    Returns
    -------
    gwb_grid : scalar, GWB strain from the centroid calculation.
    gwb_sample : scalar, GWB strain from the sampled binaries.
    hs_grid : ndarray, strain at the centroid of each cell (NaN replaced with zero).
    hs_samp : ndarray, strain of each sampled binary.
    number : ndarray, the `number` array cut to the selected bin.
    dnum : ndarray, the `dnum` array cut to the edges of the selected bin.
    vals : ndarray, sampled binary parameters `[mtot, mrat, redz, fobs]` in linear space.
    weights : ndarray, weight of each sampled binary.
    coms : list of ndarray, density-weighted centroid coordinates for each dimension.

    Notes
    -----
    The global numpy random seed is reset to a fixed value on every call.

    """
    np.random.seed(12345)
    # Copy the edges as a *list* of arrays.  The per-dimension edges generally have different
    # lengths, and `np.array(edges)` on such a ragged sequence raises a ValueError in recent numpy
    # versions; for equal-length edges the slice assignments below could not change row lengths.
    edges = [np.array(ee) for ee in edges]
    edges[-2] = edges[-2][zbin:zbin+2]
    edges[-1] = edges[-1][fbin:fbin+2]
    for ii in [2, 3]:
        print(f"{ii=}, {edges[ii]=}")

    fobs = edges[-1]
    # keep the two bounding grid points of the selected bin for the density ...
    dnum = dnum[..., zbin:zbin+2, fbin:fbin+2]
    # ... and the single selected bin for the number of binaries
    number = number[..., zbin:zbin+1, fbin:fbin+1]
    # print(f"{number.shape=}, {number.sum()=:.4e}, {utils.stats(number)=}")

    # ---- Grid Calculation

    # find weighted bin centers
    coms = np.meshgrid(*edges, indexing='ij')
    # get unweighted centers
    cent = kale.utils.midpoints(dnum, log=False, axis=(0, 1, 2, 3))
    # get weighted centers for each dimension
    for ii, cc in enumerate(coms):
        coms[ii] = kale.utils.midpoints(dnum * cc, log=False, axis=(0, 1, 2, 3)) / cent

    # calculate GW strain at bin centroids
    hs_grid = get_strain(coms)

    dlogf = np.diff(np.log(fobs))
    dlogf = dlogf[np.newaxis, np.newaxis, np.newaxis, :]
    cycles = 1.0 / dlogf

    # cells with zero density give NaN centroids (0/0); treat their strain as zero
    hs_grid = np.nan_to_num(hs_grid)
    # (M',Q',Z',F) ==> (F,)
    gwb_grid = np.sqrt(np.sum(number*cycles*np.square(hs_grid), axis=(0, 1, 2)))[0]

    # ---- Sampled Calculation

    # total-mass and frequency are sampled in log-space
    edges_sample = [np.log10(edges[0]), edges[1], edges[2], np.log(edges[3])]

    print(f"{number.sum()=:.8e}")
    print(f"{utils.stats(dnum.squeeze().flatten())=}")
    print("---- slice_func() :: sample_outliers() ----\n")
    # NOTE(review): `np.log10(dnum)` is passed as the density here, whereas the other sampling
    #               calls in this notebook pass `dnum` directly — confirm this is intentional.
    vals, weights = kale.sample_outliers(
        edges_sample, np.log10(dnum), sample_threshold, mass=number,
    )
    print("\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^")
    # print(f"{weights.shape=}, {weights.sum()=:.4e}, {utils.stats(weights)=}")

    # convert log-space samples back to linear values
    vals[0] = 10.0 ** vals[0]
    vals[3] = np.e ** vals[3]

    hs_samp = get_strain(vals)

    # all samples must fall within the selected redshift and frequency bin
    for ii in [2, 3]:
        # print(f"{ii=}, {edges[ii]=}")
        # print(f"\t{utils.stats(vals[ii], prec=4)}")
        assert np.all((edges[ii][0] <= vals[ii]) & (vals[ii] <= edges[ii][-1]))

    # cycles = 0.5 * np.sum(fextr) / np.diff(fextr)[0]
    assert len(fobs) == 2
    cycles = 1.0 / np.diff(np.log(fobs))[0]
    gwb_sample = np.sqrt(np.sum(weights * cycles * (hs_samp ** 2)))

    return gwb_grid, gwb_sample, hs_grid, hs_samp, number, dnum, vals, weights, coms

# Select the first frequency bin, and the redshift grid point nearest to z = 0.1
fbin = 0
zbin = zmath.argnearest(edges[2], 0.1)
# zbin = 30
# print(f"{fbin=}, {zbin=} :: {YR*edges[-1][[fbin, fbin+1]]}, {edges[-2][[zbin, zbin+1]]}")

# SAMPLE_THRESHOLD = -1
SAMPLE_THRESHOLD = 1e2
# NOTE: `numcut` and `dncut` are the `number` and `dnum` arrays cut down to the selected bin
gwb_grid, gwb_sample, hs_grid, hs_samp, numcut, dncut, vals, weights, coms = slice_func(
    edges, dnum, number, zbin, fbin, sample_threshold=SAMPLE_THRESHOLD
)

# fractional error of the sampled GWB relative to the grid GWB
err = (gwb_sample - gwb_grid) / gwb_grid
print(f"{gwb_grid=:.4e}, {gwb_sample=:.4e}, {err=:.4e}")

In [None]:
# Sum the squared strains of the samples in each (mtot, mrat) bin, for comparison with the
# squared centroid strains of the grid.
vv = vals[:2]
dist = hs_grid.squeeze()**2
# NOTE: rebinds `ee` (list of the two bin-edge arrays) and `idx` (bin numbers of each sample)
hist, *ee, idx = sp.stats.binned_statistic_2d(
    *vv, hs_samp**2,
    bins=(edges[0], edges[1]), statistic='sum', expand_binnumbers=True
)

In [None]:
# Flattened index of the (mtot, mrat) bin to highlight here and inspect in the next cell
SEL_BIN = 700
print(f"{SEL_BIN=}")

# summed sample weights in each (mtot, mrat) bin
whist, *ee, idx = sp.stats.binned_statistic_2d(
    *vv, weights,
    bins=(edges[0], edges[1]), statistic='sum', expand_binnumbers=True
)

# cs_dist = np.cumsum(dist.flatten())
# cs_hist = np.cumsum(hist.flatten())
# NOTE: despite the `cs_` (cumulative-sum) prefix, these are currently just the flattened arrays
cs_dist = dist.flatten()
cs_hist = hist.flatten()
# fractional difference of the sampled relative to the grid values (not plotted below)
err = (cs_hist - cs_dist) / cs_dist
# err = np.fabs(err)

# only flattened bin indices between 600 and 1000 are shown
fig, ax = plot.figax(xlim=[600, 1e3])
ax.plot(cs_dist, label='dist')
ax.plot(cs_hist, label='hist')

ax.axvline(SEL_BIN, color='r', ls='--', alpha=0.25)
print(f"{cs_dist[SEL_BIN]=} {cs_hist[SEL_BIN]=}")
print(f"{np.sqrt(cs_dist[SEL_BIN])=} {np.sqrt(cs_hist[SEL_BIN])=}")

ax.legend()
# tw.legend()
plt.show()

In [None]:
# Inspect the single (mtot, mrat) bin `SEL_BIN`: compare the grid centroid to the sampled points,
# and the strain from the centroid, from the samples, and from the average of the samples.
# NOTE(review): `bin` shadows the python builtin of the same name for the rest of the kernel.
bin = np.unravel_index(SEL_BIN, hist.shape)
# number of binaries in the selected bin
num = numcut.squeeze()[bin]
print(f"{SEL_BIN=} ==> {bin=}, {num=:.8e}")
# indices of the two grid points bounding the selected bin, in each of the two dimensions
cut = [[bb+ii for ii in range(2)] for bb in bin]
# differential number at the corners of the selected bin (trailing dimensions are kept)
dn = dncut[np.ix_(*cut)]
# print(f"{dn=}")

print("edges = ")
for ii in range(2):
    print(edges[ii][bin[ii]], edges[ii][bin[ii]+1])

# select the samples falling strictly inside the selected bin
idx = (edges[0][bin[0]] < vals[0]) & (vals[0] < edges[0][bin[0]+1])
idx = idx & (edges[1][bin[1]] < vals[1]) & (vals[1] < edges[1][bin[1]+1])
print("vals = ")
print(np.count_nonzero(idx), utils.stats(weights[idx]))
xx = vals[0][idx].copy()
yy = vals[1][idx].copy()

# ---- plot COM of bin and bin-edges

cc = np.array(coms)[:2].squeeze()
# centroid coordinates (first two dimensions) of the selected bin
zz = [cc[ii][bin] for ii in range(2)]

fig, ax = plot.figax()

for ii, tt in enumerate(zz):
    # the centroid must lie inside of the bin
    assert (edges[ii][bin[ii]] < tt) & (tt < edges[ii][bin[ii]+1])
    # here `ii` selects the lower (0) and upper (1) edge of the bin, in both dimensions
    ax.axvline(edges[0][bin[0]+ii], color='r', ls='--', alpha=0.25)
    ax.axhline(edges[1][bin[1]+ii], color='r', ls='--', alpha=0.25)

# print(zz)
ax.scatter(*zz, marker='x')

# ---- Plot sampled points in bin, and their average

cc = ax.scatter(xx, yy, marker='.')
xave = np.mean(xx)
yave = np.mean(yy)
ax.scatter(xave, yave, marker='+', color=cc.get_facecolor(), s=100, lw=1.0)

# ---- strains

# get strain from COM
# NOTE: `hs_grid` and (below) `hs_samp` are rebound from the arrays returned by `slice_func`
#       to values for this single bin.
temp = np.array(coms).squeeze()
temp = np.moveaxis(temp, 0, -1)[bin]
hs_grid = get_strain(temp)
hs_grid = np.sqrt(num * hs_grid**2)

# get strain from samples
temp = [vv[idx] for vv in vals]
# replace the sampled redshifts and frequencies with the first element of the flattened centroid
# arrays, so that only the first two parameters vary between samples
temp[-2] = np.ones_like(temp[0]) * coms[-2].flatten()[0]
temp[-1] = np.ones_like(temp[0]) * coms[-1].flatten()[0]
# print(f"{temp=}")
hs_samp = get_strain(temp)

# get strain from average of samples
hs_ave = [xave, yave, temp[2][0], temp[3][0]]
hs_ave = get_strain(hs_ave)
hs_ave = np.sqrt(num * hs_ave**2)

fig, ax = plot.figax()
ax.axhline(hs_grid, ls='--', color='k')
ax.axhline(hs_ave, ls='--', color='r')
# cumulative quadrature sum of the sample strains, sorted from smallest to largest
ax.plot(np.sqrt(np.cumsum(np.sort(hs_samp)**2)))

hs_samp = np.sqrt(np.sum(hs_samp**2))

# NOTE: the printed values are the *squared* strains
print(f"grid strain = {hs_grid**2:.8e}")
print(f"samp strain = {hs_samp**2:.8e}")
print(f"ave  strain = {hs_ave**2:.8e}")


plt.show()


In [None]:
# NOTE(review): `breaker` is not defined anywhere in the visible notebook; presumably this is an
#               intentional NameError used to halt "Run All" at this point — confirm.
breaker()

## Compare distribution of samples to grid (2D slice)

### Weights

In [None]:
# Three panels over the (mtot, mrat) plane: expected number of binaries per bin (left), summed
# sample weights per bin (middle), and their fractional difference (right).
fig, axes = plot.figax(figsize=[12, 4], ncols=3, grid=False)

xx, yy = np.meshgrid(edges[0], edges[1], indexing='ij')
hist, *_ = np.histogram2d(vals[0], vals[1], bins=(edges[0], edges[1]), weights=weights)
dist = numcut.squeeze()

# shared color-scale extrema for the first two panels
extr = zmath.minmax(hist, prev=zmath.minmax(numcut), limit=[0.1/hist.size, None])
# smap = plot.smap(extr, log=True)
smap = plot.smap(extr, log=True, midpoint=1.0, cmap='bwr')

ax = axes[0]
pcm = ax.pcolormesh(xx, yy, dist, cmap=smap.cmap, norm=smap.norm)
plt.colorbar(pcm, ax=ax)

ax = axes[1]
pcm = ax.pcolormesh(xx, yy, hist, cmap=smap.cmap, norm=smap.norm)
plt.colorbar(pcm, ax=ax)

ax = axes[2]
# fractional difference of sampled relative to expected; non-finite values are replaced
diff = (hist - dist) / dist
diff = np.nan_to_num(diff)
smap = plot.smap(diff, log=False, midpoint=0.0, cmap='bwr')
print(f"{utils.stats(diff)=}")

pcm = ax.pcolormesh(xx, yy, diff, cmap=smap.cmap, norm=smap.norm)
plt.colorbar(pcm, ax=ax)

plt.show()

### Strains

In [None]:
# Three panels over the (mtot, mrat) plane: squared centroid strain (left), summed squared sample
# strains (middle), and their fractional difference (right).
# NOTE(review): this cell needs `hs_grid` and `hs_samp` as returned by `slice_func`; the
#               bin-inspection cells rebind both names to single-bin values, so re-run the
#               `slice_func` cell before this one.
# hs_grid.shape, hs_samp.shape
vv = vals[:2]
dist = hs_grid.squeeze()**2
hist, *ee, idx = sp.stats.binned_statistic_2d(
    *vv, hs_samp**2,
    bins=(edges[0], edges[1]), statistic='sum', expand_binnumbers=True
)
grid = np.meshgrid(edges[0], edges[1], indexing='ij')

fig, axes = plot.figax(figsize=[12, 4], ncols=3, grid=False)

# shared color-scale extrema for the first two panels
extr = zmath.minmax(dist, prev=zmath.minmax(hist, filter='>'), filter='>')
smap = plot.smap(extr, log=True)

ax = axes[0]
pcm = ax.pcolormesh(*grid, dist, cmap=smap.cmap, norm=smap.norm)
plt.colorbar(pcm, ax=ax)

ax = axes[1]
pcm = ax.pcolormesh(*grid, hist, cmap=smap.cmap, norm=smap.norm)
plt.colorbar(pcm, ax=ax)

ax = axes[2]
# fractional difference of sampled relative to grid; non-finite values are replaced
diff = (hist - dist) / dist
diff = np.nan_to_num(diff)
smap = plot.smap(diff, log=False, midpoint=0.0, cmap='bwr')
print(f"{utils.stats(diff)=}")

pcm = ax.pcolormesh(*grid, diff, cmap=smap.cmap, norm=smap.norm)
plt.colorbar(pcm, ax=ax)

plt.show()

In [None]:
# Plot the flattened (mtot, mrat) bins of the grid (`dist`) and sampled (`hist`) values, and
# print the values for flattened bin index 700.
# summed sample weights in each (mtot, mrat) bin
whist, *ee, idx = sp.stats.binned_statistic_2d(
    *vv, weights,
    bins=(edges[0], edges[1]), statistic='sum', expand_binnumbers=True
)

# cs_dist = np.cumsum(dist.flatten())
# cs_hist = np.cumsum(hist.flatten())
# NOTE: despite the `cs_` (cumulative-sum) prefix, these are currently just the flattened arrays
cs_dist = dist.flatten()
cs_hist = hist.flatten()
err = (cs_hist - cs_dist) / cs_dist
# err = np.fabs(err)

fig, ax = plot.figax(xlim=[600, 1e3])
ax.plot(cs_dist, label='dist')
ax.plot(cs_hist, label='hist')

# NOTE: the string literal below is only used to disable the twin-axis plotting code
'''
tw = ax.twinx()

tw.set(yscale='log')
tw.set(ylim=[1e-1, 1e4])
# tw.set(ylim=[-1.2, 1.2])
# tw.plot(err, 'k--', label='err')

tw.plot(numcut.flatten(), ls='--')
tw.plot(whist.flatten(), ls='--')
# tw.axhline(36, color='r', ls='--', alpha=0.25)
'''

ax.axvline(700, color='r', ls='--', alpha=0.25)
print(f"{cs_dist[700]=} {cs_hist[700]=}")
print(f"{np.sqrt(cs_dist[700])=} {np.sqrt(cs_hist[700])=}")

ax.legend()
# tw.legend()
plt.show()

In [None]:
# Display the sampled (`hist`) and grid (`dist`) values for one hard-coded 2D bin
idx = (17, 20)
hist[idx], dist[idx]

In [None]:
# Inspect the (mtot, mrat) bin with flattened index 700: compare the grid centroid to the sampled
# points, and the strain from the centroid, from the samples, and from the average of the samples.
# NOTE(review): `bin` shadows the python builtin of the same name for the rest of the kernel.
bin = 700
bin = np.unravel_index(700, hist.shape)
# number of binaries in the selected bin
num = numcut.squeeze()[bin]
print(f"{bin=}, {num=:.8e}")
# indices of the two grid points bounding the selected bin, in each of the two dimensions
cut = [[bb+ii for ii in range(2)] for bb in bin]
dn = dncut[np.ix_(*cut)]
# print(f"{dn=}")

for ii in range(2):
    print(edges[ii][bin[ii]], edges[ii][bin[ii]+1])

# select the samples falling strictly inside the selected bin
idx = (edges[0][bin[0]] < vals[0]) & (vals[0] < edges[0][bin[0]+1])
idx = idx & (edges[1][bin[1]] < vals[1]) & (vals[1] < edges[1][bin[1]+1])
print(np.count_nonzero(idx), utils.stats(weights[idx]))
xx = vals[0][idx].copy()
yy = vals[1][idx].copy()

# ---- plot COM of bin and bin-edges

cc = np.array(coms)[:2].squeeze()
# centroid coordinates (first two dimensions) of the selected bin
zz = [cc[ii][bin] for ii in range(2)]

fig, ax = plot.figax()

for ii, tt in enumerate(zz):
    # the centroid must lie inside of the bin
    assert (edges[ii][bin[ii]] < tt) & (tt < edges[ii][bin[ii]+1])
    # here `ii` selects the lower (0) and upper (1) edge of the bin, in both dimensions
    ax.axvline(edges[0][bin[0]+ii], color='r', ls='--', alpha=0.25)
    ax.axhline(edges[1][bin[1]+ii], color='r', ls='--', alpha=0.25)

# print(zz)
ax.scatter(*zz, marker='x')

# ---- Plot sampled points in bin, and their average

cc = ax.scatter(xx, yy, marker='.')
xave = np.mean(xx)
yave = np.mean(yy)
ax.scatter(xave, yave, marker='+', color=cc.get_facecolor(), s=100, lw=1.0)

# ---- strains

# get strain from COM
# NOTE: `hs_grid` and (below) `hs_samp` are rebound here to values for this single bin
temp = np.array(coms).squeeze()
temp = np.moveaxis(temp, 0, -1)[bin]
hs_grid = get_strain(temp)
hs_grid = np.sqrt(num * hs_grid**2)

# get strain from samples
temp = [vv[idx] for vv in vals]
# replace the sampled redshifts and frequencies with the first element of the flattened centroid
# arrays, so that only the first two parameters vary between samples
temp[-2] = np.ones_like(temp[0]) * coms[-2].flatten()[0]
temp[-1] = np.ones_like(temp[0]) * coms[-1].flatten()[0]
# print(f"{temp=}")
hs_samp = get_strain(temp)

# get strain from average of samples
hs_ave = [xave, yave, temp[2][0], temp[3][0]]
hs_ave = get_strain(hs_ave)
hs_ave = np.sqrt(num * hs_ave**2)

fig, ax = plot.figax()
ax.axhline(hs_grid, ls='--', color='k')
ax.axhline(hs_ave, ls='--', color='r')
# cumulative quadrature sum of the sample strains, sorted from smallest to largest
ax.plot(np.sqrt(np.cumsum(np.sort(hs_samp)**2)))

hs_samp = np.sqrt(np.sum(hs_samp**2))

print(f"grid strain = {hs_grid:.8e}")
print(f"samp strain = {hs_samp:.8e}")
print(f"ave  strain = {hs_ave:.8e}")


plt.show()


In [None]:
# NOTE(review): `breaker` is not defined anywhere in the visible notebook; presumably this is an
#               intentional NameError used to halt "Run All" at this point — confirm.
breaker()

## Single Grid-Cell Test Case

In [None]:
# NOTE(review): `use_dn` is only assigned in the *following* cell, so this cell raises a NameError
#               on a fresh kernel unless that cell has been run first.
use_dn

In [None]:
# Sample a single 2D grid cell and compare the average of the samples to the cell's centroid.
# NOTE: relies on `dn` and `num` as defined by one of the bin-inspection cells.
# hard-coded edges of the cell in each of the two dimensions
ee = [
    [1.4511731181666856e+41, 2.0151177932682267e+41],
    [0.51, 0.5345000000000001],
]

# factor multiplying both the sampling threshold and the mass (number) in the cell
MULT = 10.0

# density at the four corners of the cell, using the first index of the trailing two dimensions
use_dn = dn[:, :, 0, 0].copy()
vv, ww = kale.sample_outliers(ee, use_dn, 100*MULT, mass=MULT*num[np.newaxis, np.newaxis])
print(f"Loaded {ww.size} outliers (mult={MULT:.2f}), ww={utils.stats(ww)}")

fig, ax = plot.figax()
# draw the cell edges: vertical lines for the first dimension, horizontal for the second
for ii, plotfunc in enumerate([ax.axvline, ax.axhline]):
    for jj in range(2):
        plotfunc(ee[ii][jj], color='r', ls='--', alpha=0.25)
                
# sampled points, and their (unweighted) average
cc = ax.scatter(*vv, marker='.', alpha=0.5)
ave = [np.mean(vv[ii]) for ii in range(2)]
ax.scatter(*ave, marker='o', s=100, facecolor=cc.get_facecolor(), edgecolor='r', zorder=100, lw=2.0)

# cent = kale.utils.centroids(np.meshgrid(*ee, indexing='ij'), dn[:, :, 0, 0])
cent = kale.utils.centroids(ee, dn[:, :, 0, 0])
cc = ax.scatter(*cent, marker='x', color='r', s=100)
        
plt.show()
        

In [None]:
# Inspect the (mtot, mrat) bin with flattened index 700, with additional debugging output:
# compare the grid centroid to the sampled points, and the strain from the centroid, from the
# samples, and from the average of the samples.
# NOTE(review): `bin` shadows the python builtin of the same name for the rest of the kernel.
bin = 700
bin = np.unravel_index(700, hist.shape)
# number of binaries in the selected bin
num = numcut.squeeze()[bin]
print(f"{bin=}, {num=:.8e}")
# indices of the two grid points bounding the selected bin, in each of the two dimensions
cut = [[bb+ii for ii in range(2)] for bb in bin]
dn = dncut[np.ix_(*cut)]
print(f"{dn=}")

for ii in range(2):
    print(edges[ii][bin[ii]], edges[ii][bin[ii]+1])

# select the samples falling strictly inside the selected bin
idx = (edges[0][bin[0]] < vals[0]) & (vals[0] < edges[0][bin[0]+1])
idx = idx & (edges[1][bin[1]] < vals[1]) & (vals[1] < edges[1][bin[1]+1])
print(np.count_nonzero(idx), utils.stats(weights[idx]))
xx = vals[0][idx].copy()
yy = vals[1][idx].copy()

cc = np.array(coms)[:2].squeeze()
# centroid coordinates (first two dimensions) of the selected bin
zz = [cc[ii][bin] for ii in range(2)]

fig, ax = plot.figax()


for ii, tt in enumerate(zz):
    # the centroid must lie inside of the bin
    assert (edges[ii][bin[ii]] < tt) & (tt < edges[ii][bin[ii]+1])

    # here `ii` selects the lower (0) and upper (1) edge of the bin, in both dimensions
    ax.axvline(edges[0][bin[0]+ii], color='r', ls='--', alpha=0.25)
    ax.axhline(edges[1][bin[1]+ii], color='r', ls='--', alpha=0.25)

print(zz)

ax.scatter(*zz, marker='x')
cc = ax.scatter(xx, yy, marker='.')
xave = np.mean(xx)
yave = np.mean(yy)
ax.scatter(xave, yave, marker='+', color=cc.get_facecolor(), s=100, lw=1.0)

# strains
# NOTE: `hs_grid` and (below) `hs_samp` are rebound here to values for this single bin
temp = np.array(coms).squeeze()
temp = np.moveaxis(temp, 0, -1)[bin]
hs_grid = get_strain(temp)

temp = [vv[idx] for vv in vals]
# replace the sampled redshifts and frequencies with the first element of the flattened centroid
# arrays, so that only the first two parameters vary between samples
temp[-2] = np.ones_like(temp[0]) * coms[-2].flatten()[0]
temp[-1] = np.ones_like(temp[0]) * coms[-1].flatten()[0]
print(f"{temp=}")
hs_samp = get_strain(temp)
hs_grid = np.sqrt(num * hs_grid**2)

hs_ave = [xave, yave, temp[2][0], temp[3][0]]
hs_ave = get_strain(hs_ave)
hs_ave = np.sqrt(num * hs_ave**2)
print(f"{hs_ave=}")

fig, ax = plot.figax()
ax.axhline(hs_grid, ls='--', color='k')
ax.axhline(hs_ave, ls='--', color='r')
# cumulative quadrature sum of the sample strains, sorted from smallest to largest
ax.plot(np.sqrt(np.cumsum(np.sort(hs_samp)**2)))

hs_samp = np.sqrt(np.sum(hs_samp**2))

print(f"grid strain = {hs_grid:.8e}")
print(f"samp strain = {hs_samp:.8e}")
print(f"ave  strain = {hs_ave:.8e}")


plt.show()


In [None]:
# Scratch check of `np.take_along_axis`: select the minimum along axis 0 using `argmin` indices
a = np.random.uniform(0.0, 1.0, (2, 3, 4))
# the index array needs the same number of dimensions as `a`, so re-add the reduced axis
i = np.argmin(a, axis=0)[np.newaxis, ...]
print(a, i)
print(a.shape, i.shape)
print(np.take_along_axis(a, i, 0))
# print(a[i])

In [None]:
def get_coms_sample(xx, yy, num=1e4):
    """Estimate the center-of-mass of each grid bin empirically, by sampling from the grid.

    Parameters
    ----------
    xx : sequence of array_like
        Grid edges in each dimension.
    yy : array_like
        Densities evaluated at the grid edges.
    num : scalar
        The bin masses are rescaled so that they sum to this value before sampling.

    Returns
    -------
    coms : ndarray
        For each dimension, the mean coordinate of the samples falling in each bin.
    samples : ndarray
        The sampled points, as returned by `kale.sample_grid`.

    """
    # mass in each bin, renormalized to a total of `num`
    bin_mass = kale.utils.trapz_dens_to_mass(yy, xx)
    bin_mass = bin_mass * num / bin_mass.sum()
    samples = kale.sample_grid(xx, yy, mass=bin_mass)

    # for each dimension, average that coordinate over the samples in each bin
    points = samples.T
    coms = []
    for dim in range(len(samples)):
        binned = sp.stats.binned_statistic_dd(points, samples[dim], statistic='mean', bins=xx)
        coms.append(binned[0])

    return np.asarray(coms), samples


def get_coms_really(edges, yy):
    """Calculate the center-of-mass of every bin of an N-dimensional grid of densities.

    For each dimension, the densities at the corners of each bin are summed over all of the
    *other* dimensions, leaving a 1D trapezoid (a left and a right value) per bin.  The
    center-of-mass of that trapezoid gives the coordinate of the bin's center-of-mass along that
    dimension.

    Parameters
    ----------
    edges : sequence of array_like
        Grid edges in each dimension; `edges[ii]` must have length `yy.shape[ii]`.
    yy : array_like
        Densities evaluated at the grid edges.

    Returns
    -------
    coms : ndarray
        Shape `(ndim, *bins)` where `bins` is one less than the shape of `yy` in each dimension.
        `coms[ii]` is the ii-th coordinate of the center-of-mass of each bin.  Bins in which the
        marginalized density is zero at both edges produce a 0/0 division (NaN).

    """
    yy = np.asarray(yy)
    ndim = yy.ndim

    # shape of vertices ('corners') of each bin
    shp_corners = [2,] * ndim
    # shape of bins
    shp_bins = [sh - 1 for sh in yy.shape]

    # ---- Get the y-values (densities) for each corner, for each bin

    # for a 2D grid, `zz[0, 0, :, :]` would be the lower-left,
    # while `zz[1, 0, :, :]` would be the lower-right
    zz = np.zeros(shp_corners + shp_bins)
    # iterate over all permutations of corners
    #     get a tuple specifying left/right edge for each dimension, e.g.
    #     (0, 1, 0) would be (left, right, left) for 3D
    for idx in np.ndindex(tuple(shp_corners)):
        # for each dimension, slice out the left or right edges along that dimension
        #     ii=0 ==> s=':-1'   ii=1 ==> s='1:'
        cut = tuple(slice(ii, yy.shape[dd] - (ii+1) % 2) for dd, ii in enumerate(idx))
        # for this corner (`idx`) select the y-values (densities) at that corner
        zz[idx] = yy[cut]

    # ---- Calculate the centers of mass in each dimension

    coms = np.zeros([ndim,] + shp_bins)
    for ii in range(ndim):
        # sum over both corners, for each dimension *except* for `ii`
        others = tuple(jj for jj in range(ndim) if jj != ii)
        # y_left  is the left  corner along this dimension, marginalized (summed) over other dims
        # y_right is the right corner along this dimension
        y_left, y_right = np.sum(zz, axis=others)

        # indexing to make 1D arrays along dimension `ii` broadcastable against `shp_bins`
        cut = [np.newaxis] * (ndim - 1)
        cut.insert(ii, slice(None))
        cut = tuple(cut)

        xe = np.asarray(edges[ii])
        # bin width in this dimension, for each bin
        _dx = np.diff(xe)[cut]
        # left and right edges of each bin along this dimension
        xstack = np.asarray([xe[:-1][cut], xe[1:][cut]])
        ystack = np.asarray([y_left, y_right])

        # we need to know which direction each triangle is facing, find the index of the min
        # y-value:  0 is left, 1 is right.  Shape is exactly `shp_bins`.
        idx_min = np.argmin(ystack, axis=0)

        # get the min and max y-values; doesn't matter if left or right for these
        y1 = np.min(ystack, axis=0)
        y2 = np.max(ystack, axis=0)

        # ---- Calculate center of mass for trapezoid

        # - We have marginalized over all dimensions except for this one, so we can consider the
        #   1D case that looks like this:
        #
        #       /| y2
        #      / |
        #     /  |
        #    |---| y1
        #    |   |
        #    |___|
        #
        # - We will calculate the COM for the rectangle and the triangle separately, and then get
        #   the weighted COM between the two, where the weights are given by the areas
        # - `a1` and `x1` will be the area (i.e. mass) and x-COM for the rectangle.
        #   The x-COM is just the midpoint, because the y-values are the same
        # - `a2` and `x2` will be the area and x-COM for the triangle
        #   NOTE: for the triangle, its direction matters.  For each bin, `idx_min` tells the
        #         direction: 0 means increasing (left-to-right), and 1 means decreasing.
        a1 = _dx * y1
        a2 = 0.5 * _dx * (y2 - y1)
        x1 = np.mean(xstack, axis=0)
        # get the x-value for the low y-value
        xlo = np.take_along_axis(xstack, idx_min[np.newaxis, ...], 0)[0]
        # make `dx` for each bin positive or negative, depending on the orientation of the
        # triangle.  NOTE: `idx_min` is used with its full `shp_bins` shape; squeezing it would
        # also drop bin-axes of length one and break the broadcasting for such grids.
        x2 = xlo + (2.0/3.0)*_dx*(1 - 2*idx_min)
        coms[ii] = (x1 * a1 + x2 * a2) / (a1 + a2)

    return coms


# Compare the analytic centers-of-mass (`get_coms_really`) against the empirical ones estimated
# from samples (`get_coms_sample`) on a small random 2D grid.
# np.random.seed(1)
# shape = (4, 5)
# shape = (3, 3)
# NOTE: rebinds `shape`, which was also used for the SAM grid size at the top of the notebook
shape = (3, 4)
# random densities at the grid edges, and random (sorted) edge locations in each dimension
yy = np.random.uniform(0.0, 10.0, shape)
xx = [sorted(np.random.uniform(0.0, 1.0, sh)) for sh in yy.shape]

coms_test = get_coms_really(xx, yy).squeeze()
coms_sample, vv = get_coms_sample(xx, yy)

# flatten the bins, to get (x, y) coordinates for scatter plotting
coms_test = np.reshape(coms_test, (2, -1))
coms_sample = np.reshape(coms_sample, (2, -1))

fig, ax = plot.figax(scale='lin')
# samples (small points), analytic COMs ('+'), and sample-averaged COMs ('x')
ax.scatter(*vv, s=5, alpha=0.25)
ax.scatter(*coms_test, s=100, marker='+', color='r', alpha=0.35)
ax.scatter(*coms_sample, s=200, marker='x', color='r', alpha=0.35)
# draw the grid edges
for ii, (_ee, line) in enumerate(zip(xx, [ax.axvline, ax.axhline])):
    for ee in _ee:
        line(ee, color='k', ls=':', alpha=0.5)

plt.show()

In [None]:
# Small 2D test grid: densities at the grid edges (2 x 3 points), and the edges in each dimension
dens = [
    [0.0, 0.0, 0.0],
    [1.0, 2.0, 3.0],
]
ee = [
    [0.0, 1.0],
    [3.0, 4.0, 5.0],
]

dens = np.asarray(dens)

def get_coms(edges, dens):
    """Calculate center-of-mass coordinates of grid bins from fully-marginalized 1D densities.

    For each dimension, the densities are summed over *all* other dimensions (i.e. over the whole
    grid, not just the corners of each bin), and the center-of-mass of each resulting 1D
    trapezoid is calculated as the area-weighted combination of a rectangle and a triangle.
    The center-of-mass along dimension `ii` is therefore the same for all bins that share the
    same index along dimension `ii`.

    Parameters
    ----------
    edges : sequence of array_like
        Grid edges in each dimension; `edges[ii]` must have length `dens.shape[ii]`.
    dens : array_like
        Densities evaluated at the grid edges.

    Returns
    -------
    coms : ndarray
        Shape `(ndim, *bins)` where `bins` is one less than the shape of `dens` in each
        dimension.  Bins with zero marginalized density at both edges give a 0/0 division (NaN).

    """
    dens = np.asarray(dens)
    shp = [sh - 1 for sh in dens.shape]
    ndim = len(shp)
    coms = np.zeros([ndim,] + shp)

    for ii in range(ndim):
        # marginalize (sum) the densities over all dimensions except for `ii`
        others = tuple(jj for jj in range(ndim) if jj != ii)
        xx = np.asarray(edges[ii], dtype=float)
        yy = np.sum(dens, axis=others)

        dx = np.diff(xx)
        # use the *per-bin* min and max of the left/right densities (not the global minimum)
        ylo = np.minimum(yy[:-1], yy[1:])
        yhi = np.maximum(yy[:-1], yy[1:])

        # rectangle: area and COM (the midpoint)
        a1 = dx * ylo
        x1 = 0.5 * (xx[:-1] + xx[1:])
        # triangle: area (always non-negative), and COM which is 2/3 of the way from the
        # low-density edge towards the high-density edge
        a2 = 0.5 * dx * (yhi - ylo)
        rising = (yy[1:] >= yy[:-1])
        x2 = np.where(rising, xx[:-1] + (2.0/3.0)*dx, xx[1:] - (2.0/3.0)*dx)

        com = (x1 * a1 + x2 * a2) / (a1 + a2)
        # align the 1D result with axis `ii` so that it broadcasts over the other dimensions
        bshape = [1,] * ndim
        bshape[ii] = -1
        coms[ii] = com.reshape(bshape)

    return coms

# centers of mass of the test grid defined at the top of this cell
cent = get_coms(ee, dens)

In [None]:
# Sample a single 2D bin and compare the mean of the samples against density-weighted averages of
# the grid-point coordinates, over-plotted on a corner plot.
dens = [
    [0.0, 0.0],
    [1.0, 2.0],
]
ee = [
    [0.0, 1.0],
    [3.0, 4.0],
]
# dens = [
#     [0.0, 0.0, 0.0],
#     [1.0, 2.0, 3.0],
# ]
# ee = [
#     [0.0, 1.0],
#     [3.0, 4.0, 5.0],
# ]
# total mass to normalize to
num = 1e4

# infinite threshold passed to `sample_outliers`
thresh = np.inf
# mass = np.atleast_2d(num)
# mass in each bin, renormalized to a total of `num`
mass = kale.utils.trapz_dens_to_mass(dens, ee)
mass = mass * num / mass.sum()
# print(f"{mass=}")

vv, ww = kale.sample_outliers(ee, dens, thresh, mass=mass)
# average sampled coordinate in each dimension
print(np.mean(vv, axis=1))
corner, _ = kale.corner(vv, edges=ee, kwcorner=dict(origin='bl'), dist2d=dict(contour=False, hist=False, median=False))

# cent = [dens * np.moveaxis(np.array(ee[ii])[:, np.newaxis], 1, (ii + 1) % 2) for ii in range(2)]
# cent = [np.sum(cent[ii], axis=ii) / np.sum(dens, axis=ii) for ii in range(2)]
# cent = [np.mean(cent[ii]) for ii in range(2)]
# density-weighted average of the grid-point coordinates, computed in two equivalent ways
cent = np.meshgrid(*ee, indexing='ij')
cent = [np.average(cc, weights=dens) for cc in cent]
print(f"{cent=}")
cent = np.meshgrid(*ee, indexing='ij')
print(f"{cent[0]=}")
cent = [np.sum(cc*dens) / np.sum(dens) for cc in cent]
print(f"{cent=}")

axes = corner.axes
ax = axes[1, 0]
kw = dict(color='r', alpha=0.5)
# solid lines: density-weighted grid average;  dashed lines: mean of the samples
for ii, (cc, line) in enumerate(zip(cent, [ax.axvline, ax.axhline])):
    line(cc, **kw)
    line(np.mean(vv[ii]), ls=(0, [2, 4]), lw=2.0, **kw)

plt.show()

In [None]:
# Scratch calculation: center-of-mass of a single 1D trapezoid, as the area-weighted combination
# of a rectangle (height `min(yy)`) and a triangle on top of it.
# xx = np.meshgrid(*ee, indexing='ij')[0]
# np.sum(dens, axis=1)
# np.diff

xx = [0.0, 1.0]
yy = [10.0, 20.0]
# areas of the rectangle (`a1`) and of the triangle (`a2`)
a1 = np.diff(xx) * np.min(yy)
a2 = 0.5 * np.diff(xx) * np.diff(yy)
print(a1, a2)
# x-COM of the rectangle (midpoint) and of the triangle
# NOTE: the triangle COM used here assumes the density is increasing from left to right
x1 = np.mean(xx)
x2 = np.min(xx) + (2.0/3.0)*np.diff(xx)
print(x1, x2)
xave = (x1 * a1 + x2 * a2) / (a1 + a2)
print(xave)

np.meshgrid(xx, yy)