In [None]:
# MOD per system
import sys
sys.path.append('C:/Users/emiga/OneDrive/Cal/GWs/code/holodeck') 


# %load ../init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys
import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sam
from holodeck import cosmo, utils, plot
from holodeck.constants import MSOL, PC, YR, MPC, GYR

# Silence annoying numpy floating-point errors (divide-by-zero, invalid, overflow)
np.seterr(divide='ignore', invalid='ignore', over='ignore')
# NOTE: this hides every UserWarning, including `warnings.warn(...)` calls made in this notebook
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# Apply the default style FIRST: `mpl.style.use('default')` resets rcParams to matplotlib's
# defaults, so calling it after the `mpl.rc(...)` lines below silently undid all of them.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): the family is 'serif' but 'Times' is listed under 'sans-serif', so it is not
# consulted for the serif family -- confirm whether 'serif': ['Times'] was intended.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

log = holo.log
log.setLevel(logging.INFO)

# 1 Functions

## 1.1 Make Examples
1) Choose the frequency bins at which to calculate the GWB, same as in semi-analytic-models.ipynb
2) Build Semi-Analytic-Model with super simple parameters 
3) Get SAM edges and numbers as in sam.gwb()

In [None]:
def example(dur, cad, mtot, mrat, redz, print_test):
    ''' 
    Build a simple Semi-Analytic-Model example and its binned binary numbers.

    1) Choose the frequency bins at which to calculate the GWB, same as in semi-analytic-models.ipynb
    2) Build Semi-Analytic-Model with super simple parameters 
    3) Get SAM edges and numbers as in sam.gwb()

    Parameters
    ----------
    dur : scalar
        Duration of observation in seconds (e.g. multiply years by YR).
    cad : scalar
        Cadence of observations in seconds (e.g. multiply years by YR).
    mtot : (3,) tuple of scalars or None
        Min, max, and steps for total mass.
    mrat : (3,) tuple of scalars or None
        Min, max, and steps for mass ratio.
    redz : (3,) tuple of scalars or None
        Min, max, and steps for redshift.
        If ANY of mtot, mrat, redz is None, the Semi_Analytic_Model defaults are
        used for all three (the non-None values are ignored).
    print_test : bool
        Whether to print the frequency bins and the SAM edges.

    Returns
    -------
    edges : (4,) list of 1darrays
        A list containing the edges along each dimension.  The four dimensions correspond to
        total mass, mass ratio, redshift, and observer-frame orbital frequency.
        The length of each of the four arrays is M, Q, Z, F.
    number : (M-1, Q-1, Z-1, F-1) array
        The number of binaries in each bin of parameter space.  This is calculated by integrating
        `dnum` over each bin.
    fobs : (F-1) array
        observed GW frequency bin centers
    '''
    # 1) Choose the frequency bins at which to calculate the GWB, same as in semi-analytic-models.ipynb
    fobs = utils.nyquist_freqs(dur, cad)
    fobs_edges = utils.nyquist_freqs_edges(dur, cad)
    if print_test:
        # the number of bins is one less than the number of bin edges
        print(f"Number of frequency bins: {fobs_edges.size-1}")
        print(f"  between [{fobs[0]*YR:.2f}, {fobs[-1]*YR:.2f}] 1/yr")
        print(f"          [{fobs[0]*1e9:.2f}, {fobs[-1]*1e9:.2f}] nHz")

    # 2) Build Semi-Analytic-Model with super simple parameters 
    # `is None` (not `== None`) so that numpy-array grids do not trigger an
    # ambiguous elementwise comparison.
    if any(grid is None for grid in (mtot, mrat, redz)):
        print('using default mtot, mrat, and redz')
        sam = holo.sam.Semi_Analytic_Model()
    else:
        sam = holo.sam.Semi_Analytic_Model(mtot, mrat, redz)
    if print_test:
        print('edges:', sam.edges)
    # get observed orbital frequency bin edges and centers 
    # from observed GW frequency bin edges
    fobs_orb_edges = fobs_edges / 2.0 # f_orb = f_GW/2
    fobs_orb_cents = kale.utils.midpoints(fobs_edges) / 2.0

    # 3) Get SAM edges and numbers as in sam.gwb()
    # dynamic_binary_number
    # gets differential number of binaries per bin-vol per log freq interval
    edges, dnum = sam.dynamic_binary_number(holo.hardening.Hard_GW, fobs_orb=fobs_orb_cents)
    # replace the frequency centers with the orbital-frequency bin edges
    edges[-1] = fobs_orb_edges

    # integrate (multiply by bin volume) within each bin
    number = utils._integrate_grid_differential_number(edges, dnum, freq=False)
    # convert per-log-frequency number into number per frequency bin
    number = number * np.diff(np.log(fobs_edges))

    return edges, number, fobs

Example 2

In [None]:
def example2(print_test = True, exname='Example 2'):
    """
    Small example grid: 3 total-mass, 2 mass-ratio and 4 redshift steps.

    Parameters
    ----------
    print_test : bool
        Whether to print frequencies and edges.
    exname : str
        Label for this example; returned unchanged.

    Returns
    -------
    edges : (4,) list of 1darrays
        Edges in total mass, mass ratio, redshift, and orbital frequency
        (lengths M, Q, Z, F).
    number : (M-1, Q-1, Z-1, F-1) array
        Number of binaries in each bin.
    fobs : (F-1) array
        observed frequency bin centers
    exname : str
        The example label.
    """
    # Observation baseline and cadence in seconds.
    # NOTE(review): dividing YR by 3.15576 presumably yields a round 1e7 s -- confirm.
    obs_dur = 5.0*YR/3.1557600
    obs_cad = .5*YR/3.1557600

    # (min, max, steps) grids.
    # NOTE(review): the mass divisor presumably normalizes MSOL to a round 1e33 g -- confirm.
    mtot_grid = (1.0e6*MSOL/1.988409870698051, 1.0e8*MSOL/1.988409870698051, 3)
    mrat_grid = (1e-1, 1.0, 2)
    redz_grid = (1e-3, 1.0, 4)

    edges, number, fobs = example(obs_dur, obs_cad, mtot=mtot_grid, mrat=mrat_grid,
                                  redz=redz_grid, print_test=print_test)
    return edges, number, fobs, exname

Example 3

In [None]:
def example3(print_test = True, exname = 'Example 3'):
    """
    Medium example grid: 25 steps in each of total mass, mass ratio and redshift.

    Parameters
    ----------
    print_test : bool
        Whether to print frequencies and edges.
    exname : str
        Label for this example; returned unchanged.

    Returns
    -------
    edges : (4,) list of 1darrays
        Edges in total mass, mass ratio, redshift, and orbital frequency
        (lengths M, Q, Z, F).
    number : (M-1, Q-1, Z-1, F-1) array
        Number of binaries in each bin.
    fobs : (F-1) array
        observed frequency bin centers
    exname : str
        The example label.
    """
    # Observation baseline and cadence in seconds.
    obs_dur = 5.0*YR/3.1557600
    obs_cad = .5*YR/3.1557600

    # (min, max, steps) for each grid dimension.
    # NOTE(review): the lower mass bound is divided by 1.988409870698051 but the
    # upper bound is not -- confirm this asymmetry is intended.
    grids = dict(
        mtot=(1.0e6*MSOL/1.988409870698051, 4.0e9*MSOL, 25),
        mrat=(1e-1, 1.0, 25),
        redz=(1e-3, 10.0, 25),
    )

    edges, number, fobs = example(obs_dur, obs_cad, print_test=print_test, **grids)
    return edges, number, fobs, exname

Example 4

In [None]:
def example4(print_test = True, exname = 'Example 4'):
    """
    Large example grid: 25 steps each, total mass up to 4e11 MSOL, finer cadence.

    Parameters
    ----------
    print_test : bool
        Whether to print frequencies and edges.
    exname : str
        Label for this example; returned unchanged.

    Returns
    -------
    edges : (4,) list of 1darrays
        Edges in total mass, mass ratio, redshift, and orbital frequency
        (lengths M, Q, Z, F).
    number : (M-1, Q-1, Z-1, F-1) array
        Number of binaries in each bin.
    fobs : (F-1) array
        observed frequency bin centers
    exname : str
        The example label.
    """
    dur = 5.0*YR/3.1557600
    cad = .2*YR/3.1557600

    # `np.float64(...)` rather than `(...).astype(np.float64)`: `.astype` exists only on
    # numpy scalars/arrays, so the old form fails if MSOL is a plain Python float.
    mtot = (1.0e6*MSOL/1.988409870698051, np.float64(4.0e11*MSOL), 25)
    mrat = (1e-1, 1.0, 25)
    redz = (1e-3, 10.0, 25)

    edges, number, fobs = example(dur, cad, mtot, mrat, redz, print_test)
    return edges, number, fobs, exname

Example 5 (same as in semi-analytic-models.ipynb)

In [None]:
def example5(print_test = True, exname = 'Example 5'):
    """
    Example matching semi-analytic-models.ipynb: 10 yr baseline, 0.2 yr cadence,
    and the Semi_Analytic_Model default mass/ratio/redshift grids.

    Parameters
    ----------
    print_test : bool
        Whether to print frequencies and edges.
    exname : str
        Label for this example; returned unchanged.

    Returns
    -------
    edges : (4,) list of 1darrays
        Edges in total mass, mass ratio, redshift, and orbital frequency.
    number : (M-1, Q-1, Z-1, F-1) array
        Number of binaries in each bin.
    fobs : (F-1) array
        observed frequency bin centers
    exname : str
        The example label.
    """
    # Passing None for every grid selects the SAM defaults.
    results = example(dur=10.0*YR, cad=.2*YR, mtot=None, mrat=None, redz=None,
                      print_test=print_test)
    edges, number, fobs = results
    return edges, number, fobs, exname

## 1.2 SS Calculations
Contains: 
- ss_gws_by_loops
- gws_by_ndars (same purpose as gravwaves._gws_from_number_grid_integrated())
- ss_gws_by_ndars
- subtraction_from_number method (no longer needed)
Could add:
- get sspar from edges and ssidx (instead of calculating and returning sspar)
- get bgnum from number and ssidx (instead of ss methods returning bgnum)


ss_gws_by_loops \
floors (rounds down) the bin numbers first, then loops over every bin

In [None]:
def ss_gws_by_loops(edges, number, realize=False, round=True, ss=True, 
                    sum=True, print_test = False):
    """ Inefficient (explicit-loop) way to calculate strain from a numbered
    grid, optionally separating the loudest single source at each frequency.

    Loops over every (M, q, z, f) bin one at a time; a slow but transparent
    reference for the vectorized calculations.

    Parameters
    ----------
    edges : (4,) list of 1darrays
        A list containing the edges along each dimension.  The four dimensions correspond to
        total mass, mass ratio, redshift, and observer-frame orbital frequency.
        The length of each of the four arrays is M, Q, Z, F.
    number : (M-1, Q-1, Z-1, F-1) ndarray
        The number of binaries in each bin of parameter space.  This is calculated by integrating
        `dnum` over each bin.
    realize : bool or int
        Specification of how to construct one or more discrete realizations.
        Only `realize=False` is implemented; any other value raises an Exception.
    round : bool
        Whether to discretize the sample by flooring (rounding down) the number
        of binaries in each bin to an integer.
    ss : bool 
        Whether or not to separate the loudest single source in each frequency bin.
    sum : bool
        Whether or not to sum the strain at a given frequency over all bins.
    print_test : bool
        Whether or not to print variables as they are calculated, for dev purposes.

    Returns
    -------
    hc_bg : ndarray
        Characteristic strain of the GWB (after single-source subtraction if `ss`).
        sum = False: shape is (M-1, Q-1, Z-1, F-1)
        sum = True: shape is (F-1,)
    hc_ss : (F-1,) array or None
        The characteristic strain of the loudest single source at each frequency
        (0 where no bin contains binaries).  None if ss = False.
    sspar : (F-1, 3) 2darray or None
        The bin-center parameters (M, q, and z) of the loudest single source at each
        frequency; NaN where no bin contains binaries.  None if ss = False.
    ssidx : (F-1, 3) 2darray of floats or None
        The indices (m_idx, q_idx, and z_idx) of the loudest single source's bin at
        each frequency; -1 where no bin contains binaries.  None if ss = False.
    maxhs : (F-1,) array or None
        The maximum single source strain amplitude at each frequency.
        None if ss = False. 
    bgnum : (M-1, Q-1, Z-1, F-1) ndarray
        The number of binaries in each bin after the loudest single source
        at each frequency is subtracted out.

    Raises
    ------
    Exception
        If `realize` is anything other than False (not implemented yet).
    """
    # Fail fast: only the non-realized calculation is implemented.
    # TODO: implement Poisson realizations of `bgnum`.
    if realize != False:
        raise Exception('realize not implemented yet')

    if(print_test):
        print('INPUTS: edges:', len(edges), '\n', edges, 
        '\nINPUTS:number:', number.shape, '\n', number,'\n')

    # Frequency bin midpoints
    foo = edges[-1]                   #: should be observer-frame orbital-frequencies
    df = np.diff(foo)                 #: frequency bin widths
    fc = kale.utils.midpoints(foo)    #: use frequency-bin centers for strain (more accurate!)
    nfreqs = len(fc)

    # All other bin midpoints
    mt = kale.utils.midpoints(edges[0]) #: total mass
    mr = kale.utils.midpoints(edges[1]) #: mass ratio
    rz = kale.utils.midpoints(edges[2]) #: redshift

    # GW background characteristic strain; always floating point so that an
    # integer-typed `number` cannot truncate the strain values to zero.
    hc_bg = np.empty(np.shape(number), dtype=float)

    # new number array
    if(round == True):
        # floor (round down) to integers; astype already returns a new array
        bgnum = np.floor(number).astype(np.int64)
        if(print_test):
            print('noninteger bgnum values:', bgnum[np.where(bgnum%1 !=0)])
    else:
        bgnum = np.copy(number)

    # for single sources, make grids with shape (F-1, 3) / (F-1,)
    if(ss == True):
        # params (M, q, z) of loudest bin with number>0; NaN until one is found
        sspar = np.full((nfreqs, 3), np.nan)
        # param indices of loudest bin with number>0; -1 until one is found
        ssidx = np.full((nfreqs, 3), -1.0)
        # max hs at each frequency
        maxhs = np.zeros(nfreqs)
        # (max) single source characteristic strain at each frequency
        hc_ss = np.zeros(nfreqs)
    else: 
        sspar = None
        ssidx = None
        maxhs = None
        hc_ss = None

    # --------------- Single Sources ------------------
    # 0) Round so numbers are all integers (done above)
    # 1) Identify the loudest (max hs) single source in a bin with N>0 
    # 2) Record the parameters, parameter indices, and strain
    #  of that single source
    # 3) Subtract 1 from the number in that source's bin, 
    # 4) Calculate single source characteristic strain (hc)
    # 5) Calculate the background with the new number 

    if(ss == True):
        for m_idx in range(len(mt)):
            for q_idx in range(len(mr)):
                cmass = holo.utils.chirp_mass_mtmr(mt[m_idx], mr[q_idx])
                for z_idx in range(len(rz)):
                    cdist = holo.cosmo.comoving_distance(rz[z_idx]).cgs.value

                    # print M, q, z, M_c, d_c
                    if(print_test):
                        print('BIN mt=%.2e, mr=%.2e, rz=%.2e' %
                            (mt[m_idx], mr[q_idx], rz[z_idx]))
                        print('\t m_c = %.2e, d_c = %.2e' 
                            % (cmass, cdist))

                    # check if loudest source in any bin
                    for f_idx in range(nfreqs):
                        rfreq = holo.utils.frst_from_fobs(fc[f_idx], rz[z_idx])
                        # hs of a source in that bin
                        hs_mqzf = utils.gw_strain_source(cmass, cdist, rfreq)

                        # 1) IF LOUDEST
                        # check if loudest hs at that 
                        # frequency and contains binaries
                        if(hs_mqzf>maxhs[f_idx] and 
                           bgnum[m_idx, q_idx, z_idx, f_idx]>0):
                            if(bgnum[m_idx, q_idx, z_idx, f_idx]<1):
                                # only possible with round=False: subtracting 1 will go negative
                                print('number<1 used', bgnum[m_idx, q_idx, z_idx, f_idx])
                            # 2) If so, RECORD:
                            # parameters M, q, z
                            sspar[f_idx] = np.array([mt[m_idx], mr[q_idx],
                                                     rz[z_idx]])
                            # parameter indices
                            ssidx[f_idx] = np.array([m_idx, q_idx, z_idx])
                            # new max strain
                            maxhs[f_idx] = hs_mqzf

        # 3) SUBTRACT 1 
        # from bin with loudest source at each frequency,
        # using its indices; ssidx has shape (F-1, 3) = [f_idx, (m_idx, q_idx, z_idx)]
        for f_idx in range(nfreqs):
            m_idx, q_idx, z_idx = (int(ii) for ii in ssidx[f_idx])
            if m_idx < 0:
                # No bin at this frequency contains binaries: there is no single
                # source to subtract, and hc_ss / maxhs stay 0.
                warnings.warn(f"No binaries found at frequency index {f_idx}; "
                              "no single source subtracted.")
                continue
            bgnum[m_idx, q_idx, z_idx, f_idx] -= 1 

            # 4) CALCULATE 
            # single source characteristic strain
            hc_ss[f_idx] = np.sqrt(maxhs[f_idx]**2 * (fc[f_idx]/df[f_idx]))

        # CHECK no numbers should be <0 
        if(np.any(bgnum<0)): 
            error_index = np.where(bgnum<0)
            print('number<0 found at (M,q,z,f) =', error_index)         

    # 5)
    # ----------------- Calculate Background Strains --------------------
    # then we can go back in and calculate characteristic strains
    # NOTE: could make this faster by saving rfreq and hs values from above
    # instead of recalculating
    for m_idx in range(len(mt)):
        for q_idx in range(len(mr)):
            cmass = holo.utils.chirp_mass_mtmr(mt[m_idx], mr[q_idx])
            for z_idx in range(len(rz)):
                cdist = holo.cosmo.comoving_distance(rz[z_idx]).cgs.value
                for f_idx in range(nfreqs):
                    rfreq = holo.utils.frst_from_fobs(fc[f_idx], rz[z_idx])
                    hs_mqzf = utils.gw_strain_source(cmass, cdist, rfreq)
                    hc_dlnf = hs_mqzf**2 * (fc[f_idx]/df[f_idx])
                    hc_bg[m_idx, q_idx, z_idx, f_idx] = np.sqrt(hc_dlnf 
                                    * bgnum[m_idx, q_idx, z_idx, f_idx])

                    if(print_test):
                        print('\tfr = %.2fnHz, h_s = %.2e, h_c^2/dlnf = %.2e' 
                            % (rfreq*10**9, hs_mqzf, hc_dlnf))
                        print('\t\tnumber: %.2e' % bgnum[m_idx, q_idx, z_idx, f_idx])
                        print('\t\thc = %.2e' % hc_bg[m_idx, q_idx, z_idx, f_idx])
                        if(ss == True):
                            print('\t\t loudest?', 
                                np.all((ssidx[f_idx] == np.array([m_idx, q_idx, z_idx]))))
    if(ss and print_test):
        print('----loudest bins:')
        for f_idx in range(nfreqs):
            print('\t M=%.2e, q=%.2e, z=%.2e, f=%.2e' 
                    % (sspar[f_idx,0], sspar[f_idx,1], sspar[f_idx,2], fc[f_idx]))

    if(sum):
        # sum over all bins at a given frequency
        hc_bg = np.sqrt(np.sum(hc_bg**2, axis=(0, 1, 2)))

    return hc_bg, hc_ss, sspar, ssidx, maxhs, bgnum


gws_by_ndars (no ss)

In [None]:
def gws_by_ndars(edges, number, realize, round = True, sum = True, print_test = False):

    """ More efficient (vectorized) way to calculate strain from a numbered 
    grid, without single-source separation.

    Parameters
    ----------
    edges : (4,) list of 1darrays
        A list containing the edges along each dimension.  The four dimensions 
        correspond to total mass, mass ratio, redshift, and observer-frame orbital 
        frequency. The length of each of the four arrays is M, Q, Z, F.
    number : (M-1, Q-1, Z-1, F-1) ndarray
        The number of binaries in each bin of parameter space.  This is calculated 
        by integrating `dnum` over each bin.
    realize : bool or int
        Specification of how to construct one or more discrete realizations.
        If a `bool` value (Python or numpy), whether or not to Poisson-sample a
        single realization.
        If an `int` value (not bool), how many discrete realizations to construct.
        Any other type raises ValueError.
    round : bool
        Specification of whether to discretize the sample if realize is False, 
        by flooring (rounding down) the number of binaries in each bin to integers.
        This has no impact if realize is not False.
        NOTE: should add a warning if round and realize are both True
    sum : bool
        Whether or not to sum the strain at a given frequency over all bins.
    print_test : bool
        Whether or not to print variables as they are calculated, for dev purposes.

    Returns
    -------
    hchar : ndarray
        Characteristic strain of the GWB.
        The shape depends on whether realize is an integer or not
        realize = True or False, sum = False: shape is (M-1, Q-1, Z-1, F-1)
        realize = True or False, sum = True: shape is (F-1)
        realize = R, sum = False: shape is  (M-1, Q-1, Z-1, F-1, R)
        realize = R, sum = True: shape is  (F-1, R)

    Raises
    ------
    ValueError
        If `realize` is neither a bool nor an integer.
    """

    if(print_test):
        print('INPUTS: edges:', len(edges), 
        '\nINPUTS:number:', number.shape, '\n', number,'\n')

    # Frequency bin midpoints
    foo = edges[-1]                   #: should be observer-frame orbital-frequencies
    df = np.diff(foo)                 #: frequency bin widths
    fc = kale.utils.midpoints(foo)    #: use frequency-bin centers for strain (more accurate!)

    # All other bin midpoints
    mt = kale.utils.midpoints(edges[0]) #: total mass
    mr = kale.utils.midpoints(edges[1]) #: mass ratio
    rz = kale.utils.midpoints(edges[2]) #: redshift

    # --- Chirp Masses ---
    # to get chirp mass in shape (M-1, Q-1) we need 
    # mt in shape (M-1, 1) 
    # mr in shape (1, Q-1)
    cmass = utils.chirp_mass_mtmr(mt[:,np.newaxis], mr[np.newaxis,:])
    if(print_test):
        print('cmass:', cmass.shape, '\n', cmass)

    # --- Comoving Distances ---
    # to get cdist in shape (Z-1) we need
    # rz in shape (Z-1)
    cdist = holo.cosmo.comoving_distance(rz).cgs.value
    if(print_test):
        print('cdist:', cdist.shape, '\n', cdist)

    # --- Rest Frame Frequencies ---
    # to get rest freqs in shape (Z-1, F-1) we need 
    # rz in shape (Z-1, 1) 
    # fc in shape (1, F-1)
    rfreq = holo.utils.frst_from_fobs(fc[np.newaxis,:], rz[:,np.newaxis])
    if(print_test):
        print('rfreq:', rfreq.shape, '\n', rfreq)

    # --- Source Strain Amplitude ---
    # to get hs amplitude in shape (M-1, Q-1, Z-1, F-1) we need
    # cmass in shape (M-1, Q-1, 1, 1) from (M-1, Q-1)
    # cdist in shape (1, 1, Z-1, 1) from (Z-1)
    # rfreq in shape (1, 1, Z-1, F-1) from (Z-1, F-1)
    hsamp = utils.gw_strain_source(cmass[:,:,np.newaxis,np.newaxis],
                                   cdist[np.newaxis,np.newaxis,:,np.newaxis],
                                   rfreq[np.newaxis,np.newaxis,:,:])
    if(print_test):
        print('hsamp', hsamp.shape, '\n', hsamp)

    # --- Characteristic Strain Squared ---
    # to get characteristic strain in shape (M-1, Q-1, Z-1, F-1) we need
    # hsamp in shape (M-1, Q-1, Z-1, F-1)
    # fc in shape (1, 1, 1, F-1)
    hchar = hsamp**2 * (fc[np.newaxis, np.newaxis, np.newaxis,:]
                        /df[np.newaxis, np.newaxis, np.newaxis,:])

    # Sample:
    # bool MUST be checked before int: `True == 1` and `False == 0` in Python, so
    # separate `== True` / `== False` / isinteger tests used to fire together
    # (e.g. realize=1 was Poisson-sampled twice).
    if isinstance(realize, (bool, np.bool_)):
        if not realize:
            # without sampling, want strain in shape (M-1, Q-1, Z-1, F-1)
            if(round): 
                # discretize by rounding number down to nearest integer 
                hchar *= np.floor(number).astype(int) 
            else: 
                # keep non-integer values
                hchar *= number
        else:
            # with a single sample, want strain in shape (M-1, Q-1, Z-1, F-1)
            hchar *= np.random.poisson(number)
    elif utils.isinteger(realize):
        # with R realizations, 
        # to get strain in shape (M-1, Q-1, Z-1, F-1, R) we need
        # hchar in shape(M-1, Q-1, Z-1, F-1, 1)
        # Poisson sample in shape (M-1, Q-1, Z-1, F-1, R)
        npois = np.random.poisson(number[...,np.newaxis], size = (number.shape + (realize,)))
        if(print_test):
            print('npois', npois.shape)
        hchar = hchar[...,np.newaxis] * npois
    else:
        raise ValueError(f"`realize` ({realize}) must be one of {{True, False, integer}}!")

    if(print_test):
        print('hchar', hchar.shape, '\n', hchar)

    if(sum):
        # sum over all bins at a given frequency and realization
        hchar = np.sum(hchar, axis=(0, 1, 2))
        if(print_test):
            print('hchar summed', hchar.shape, '\n', hchar)

    return np.sqrt(hchar)

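subtraction from number method (removed — no longer needed; ss_gws_by_ndars now subtracts the loudest source directly with boolean indexing)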

ss_gws_by_ndars

In [None]:
def ss_gws_by_ndars(edges, number, realize, round = True, sum = True, 
                    ss = True, print_test = False):

    """ More efficient (vectorized) way to calculate strain from a numbered 
    grid, separating the loudest single source at each frequency.

    Parameters
    ----------
    edges : (4,) list of 1darrays
        A list containing the edges along each dimension.  The four dimensions correspond to
        total mass, mass ratio, redshift, and observer-frame orbital frequency.
        The length of each of the four arrays is M, Q, Z, F.
    number : (M-1, Q-1, Z-1, F-1) ndarray of scalars
        The number of binaries in each bin of parameter space.  This is calculated by integrating
        `dnum` over each bin.
    realize : bool or int
        Specification of how to construct one or more discrete realizations.
        Only `realize=False` (or 0) is implemented; any other value raises an Exception.
    round : bool
        Whether to discretize the sample by flooring (rounding down) the number
        of binaries in each bin to integers.  If False the non-integer grid is
        used directly (and a warning is issued when ss=True).
    ss : bool 
        Intended to control whether the loudest single source in each frequency
        bin is separated.  NOTE: the loudest source is currently ALWAYS separated;
        `ss` only controls the warning issued when round=False.
    sum : bool
        Whether or not to sum the strain at a given frequency over all bins.
        Only sum=True is implemented; sum=False raises an Exception.
    print_test : bool
        Whether or not to print variables and number statistics as they are
        calculated, for dev purposes.

    Returns
    -------
    hc_bg : (F-1,) array
        Characteristic strain of the GWB after the loudest source at each
        frequency has been subtracted.
    hc_ss : (F-1,) array of scalars
        The characteristic strain of the loudest single source at each frequency.
    hsamp : (M-1, Q-1, Z-1, F-1) ndarray
        Strain amplitude of a source in each bin, set to 0 in every bin whose
        (pre-subtraction) number was 0.
    ssidx : (F-1, 4) ndarray of ints
        The indices (m_idx, q_idx, z_idx, f_idx) of the loudest single source's
        bin at each frequency, sorted by frequency, such that 
        ssidx[i,0] = m_idx of the ith frequency
        ssidx[i,1] = q_idx of the ith frequency
        ssidx[i,2] = z_idx of the ith frequency
        ssidx[i,3] = f_idx of the ith frequency = i
    hsmax : (F-1,) array of scalars
        The maximum single source strain amplitude at each frequency.
    bgnum : (M-1, Q-1, Z-1, F-1) ndarray
        The number of binaries in each bin after the loudest single source
        at each frequency is subtracted out.

    Raises
    ------
    Exception
        If realize is not False, if sum is False, or if the loudest bin at some
        frequency has no binaries or a non-positive strain.
    AssertionError
        If rounded numbers are non-integer/negative, or subtraction leaves a
        negative number (possible with round=False).

    Potential BUG: In the unlikely scenario that there are two equal hsmaxes 
    (at same OR dif frequencies), ssidx will have extra rows and more than one
    bin will have 1 subtracted.
    Could avoid this by using argwhere for each f_idx column separately.
    Or TODO implement some kind of check to see if any argwheres return multiple 
    values for that hsmax and raises a warning/assertion error

    NOTE: Probably don't need to return so many things, it's just useful for testing.

    TODO: Calculate sspar
    TODO: Implement realizations
    TODO: Implement not summing, or remove option
    """
    # Fail fast on unimplemented options, before any expensive work.
    if realize != False:
        # TODO: realize=True should Poisson-sample `number` into bgnum here.
        raise Exception('realize not implemented yet')
    if not sum:
        raise Exception('without sum not implemented or tested yet')

    if(print_test):
        print('INPUTS: edges:', len(edges), 
        '\nINPUTS:number:', number.shape, '\n', number,'\n')

    # Frequency bin midpoints
    foo = edges[-1]                   #: should be observer-frame orbital-frequencies
    df = np.diff(foo)                 #: frequency bin widths
    fc = kale.utils.midpoints(foo)    #: use frequency-bin centers for strain (more accurate!)

    # All other bin midpoints
    mt = kale.utils.midpoints(edges[0]) #: total mass
    mr = kale.utils.midpoints(edges[1]) #: mass ratio
    rz = kale.utils.midpoints(edges[2]) #: redshift

    # --- Chirp Masses ---
    # to get chirp mass in shape (M-1, Q-1) we need 
    # mt in shape (M-1, 1) 
    # mr in shape (1, Q-1)
    cmass = utils.chirp_mass_mtmr(mt[:,np.newaxis], mr[np.newaxis,:])
    if(print_test):
        print('cmass:', cmass.shape, '\n', cmass)

    # --- Comoving Distances ---
    # to get cdist in shape (Z-1) we need
    # rz in shape (Z-1)
    cdist = holo.cosmo.comoving_distance(rz).cgs.value
    if(print_test):
        print('cdist:', cdist.shape, '\n', cdist)

    # --- Rest Frame Frequencies ---
    # to get rest freqs in shape (Z-1, F-1) we need 
    # rz in shape (Z-1, 1) 
    # fc in shape (1, F-1)
    rfreq = holo.utils.frst_from_fobs(fc[np.newaxis,:], rz[:,np.newaxis])
    if(print_test):
        print('rfreq:', rfreq.shape, '\n', rfreq)

    # --- Source Strain Amplitude ---
    # to get hs amplitude in shape (M-1, Q-1, Z-1, F-1) we need
    # cmass in shape (M-1, Q-1, 1, 1) from (M-1, Q-1)
    # cdist in shape (1, 1, Z-1, 1) from (Z-1)
    # rfreq in shape (1, 1, Z-1, F-1) from (Z-1, F-1)
    hsamp = utils.gw_strain_source(cmass[:,:,np.newaxis,np.newaxis],
                                   cdist[np.newaxis,np.newaxis,:,np.newaxis],
                                   rfreq[np.newaxis,np.newaxis,:,:])
    if(print_test):
        print('hsamp', hsamp.shape, '\n', hsamp)

    # --------------- Single Sources ------------------
    ##### 0) Round so numbers are all integers
    if (round == True):
        # floor (round down); astype already returns a new array
        bgnum = np.floor(number).astype(np.int64)
        assert (np.all(bgnum%1 == 0)), 'non integer numbers found with round=True'
        assert (np.all(bgnum >= 0)), 'negative numbers found with round=True'
    else:
        bgnum = np.copy(number)
        if(ss==True):
            warnings.warn('Number grid used for single source calculation.')

    if(print_test):
        print('bgnum stats after copy\n', holo.utils.stats(bgnum))

    #### 1) Identify the loudest (max hs) single source in a bin with N>0 
    # zero the strain of empty bins so they can never be chosen as loudest
    hsamp[bgnum==0] = 0

    # --- Single Source Strain Amplitude At Each Frequency ---
    # to get max strain in shape (F-1) we need
    # hsamp in shape (M-1, Q-1, Z-1, F-1), search over first 3 axes
    hsmax = np.amax(hsamp, axis=(0,1,2)) #find max hs at each frequency

    # boolean mask of the loudest bin(s); hsmax broadcasts along the last (frequency) axis
    is_loudest = (hsamp == hsmax)

    #### 2) Record the indices of that single source

    # --- Indices of Loudest Bin ---
    # Shape (F-1, 4), sorted by frequency index, looks like
    # [[m_idx,q_idx,z_idx,0],
    #  [m_idx,q_idx,z_idx,1],
    #   ........
    #  [m_idx,q_idx,z_idx,F-2]]
    ssidx = np.argwhere(is_loudest) 
    ssidx = ssidx[ssidx[:,-1].argsort()]

    ### 3) Subtract 1 from the number in that source's bin

    # sanity checks before subtracting
    if np.any(bgnum[is_loudest] <= 0):
        raise Exception("bgnum <= 0 found at hsmax")
    if np.any(hsamp[is_loudest] <= 0):
        raise Exception("hsamp <=0 found at hsmax")
    if np.any(hsmax <= 0):
        raise Exception("hsmax <=0 found")

    if(print_test):
        print('bgnum stats:\n', holo.utils.stats(bgnum))
        print('bgnum[hsamp==hsmax] stats:\n', holo.utils.stats(bgnum[is_loudest]))

    # --- Background Number ---
    bgnum[is_loudest] -= 1

    if(print_test):
        print('\nafter subtraction')
        print('bgnum stats:\n', holo.utils.stats(bgnum))
        print('bgnum[hsamp==hsmax] stats:\n', holo.utils.stats(bgnum[is_loudest]))

    assert np.all(bgnum>=0), f"bgnum contains negative values at: {np.where(bgnum<0)}"

    ### 4) Calculate single source characteristic strain (hc)

    # --- Single Source Characteristic Strain ---
    # to get ss char strain in shape [F-1] need
    # fc in shape (F-1)
    # df in shape (F-1)
    hc_ss = np.sqrt(hsmax**2 * (fc/df))

    ### 5) Calculate the background with the new number 

    # --- Background Characteristic Strain Squared ---
    # to get characteristic strain in shape (M-1, Q-1, Z-1, F-1) we need
    # hsamp in shape (M-1, Q-1, Z-1, F-1)
    # fc in shape (1, 1, 1, F-1)
    hchar = hsamp**2 * (fc[np.newaxis, np.newaxis, np.newaxis,:]
                        /df[np.newaxis, np.newaxis, np.newaxis,:])
    hchar *= bgnum

    if(print_test):
        print('hchar', hchar.shape, '\n', hchar)

    # sum over all bins at a given frequency
    hchar = np.sum(hchar, axis=(0, 1, 2))
    if(print_test):
        print('hchar summed', hchar.shape, '\n', hchar)

    hc_bg = np.sqrt(hchar)

    return hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum


## 1.3 Test Functions
Consider replacing the exact `np.all(x == y)` comparisons with `np.allclose(x, y)` (or `np.isclose`) so the tests tolerate floating-point round-off.

In [None]:
def max_test(hsmax, hsamp):
    """
    Check that hsmax[i] equals the largest strain amplitude in hsamp at frequency i.

    Parameters
    ----------
    hsmax : (F,) array of scalars
        Maximum single-source strain amplitude at each frequency.
    hsamp : (..., F) ndarray of scalars
        Strain amplitude of a source in each bin; the last axis is frequency.

    Raises
    ------
    AssertionError
        If any per-frequency maximum of hsamp differs from hsmax.
    """
    matches = [np.max(hsamp[..., idx]) == hsmax[idx] for idx in range(len(hsmax))]
    assert all(matches), "the max amplitudes in hsamp do not match those in hsmax"
    print('max_test passed')

def ssidx_test(hsmax, hsamp, ssidx, print_test):
    """ 
    Test that ssidx points, in frequency order, at the bins of hsamp holding hsmax.

    Parameters
    ----------
    hsmax : (F,) array of scalars
        Maximum strain amplitude of a single source at each frequency.
    hsamp : (M, Q, Z, F) ndarray of scalar
        Strain amplitude of a source in each bin
    ssidx : (F, 4) ndarray of ints
        Row i holds the (m, q, z, f) indices of the loudest bin at frequency i.
    print_test : bool
        Whether to print the location and value of each maximum.

    Raises
    ------
    AssertionError
        If ssidx does not have exactly one row per entry of hsmax, is not in
        frequency order, or points at a value different from hsmax.
    """
    # Require exactly one row per frequency: extra rows (e.g. ties picked up by
    # np.argwhere) would otherwise fall past the end of the loop unchecked.
    assert len(ssidx) == len(hsmax), \
        f"ssidx has {len(ssidx)} rows but hsmax has {len(hsmax)} frequencies"
    # check ssidx are correct and in frequency order
    for i in range(len(hsmax)): #ith frequency
        m,q,z,f = ssidx[i]
        assert i==f, 'ssidx not in order of frequencies'
        if(print_test):
            print('max is at m,q,z,f = %d, %d, %d, %d and it = %.2e'
                  % (m, q, z, f, hsmax[i]))
        assert (hsamp[m,q,z,f] == hsmax[i]), f"The ssidx[{i}] does not give the hsmax[{i}]."
    print('ssidx test passes')

def number_test(num, bgnum, fobs, exname='', plot_test=False):
    ''' 
    Check that single-source subtraction removed sources from exactly one
    bin at each frequency, and optionally plot num - bgnum for every bin.

    Parameters
    ------------
    num : (M, Q, Z, F) array
        integer numbers in each bin, i.e. after rounding or
        Poisson sampling
    bgnum : (M, Q, Z, F) array
        number of background sources in each bin, 
        after single source subtraction
    fobs : (F) array
        frequencies of each F in Hz (shown in nHz in the subplot titles)
    exname : String
        name of example, appended to the figure title
    plot_test : Bool
        whether or not to plot num - bgnum for each bin at each frequency

    Returns
    -----------
    None 

    Raises
    -----------
    AssertionError
        If any frequency does not have exactly one bin where num > bgnum.
    '''   
    if not np.all(num % 1 == 0):
        warnings.warn("num contains at least one non-integer value")
    difs = num - bgnum
    # Count, separately for each frequency (the last axis), the bins that lost
    # a source. Checking only the grand total would let two subtractions at one
    # frequency hide a missing subtraction at another.
    bin_axes = tuple(range(difs.ndim - 1))
    nsub = np.sum(difs > 0, axis=bin_axes)
    assert np.all(nsub == 1), \
        f"Expected exactly one bin per frequency with a single source subtracted; found {nsub}"

    if(plot_test):
        # squeeze=False keeps `axes` 2D so indexing works even with one frequency
        fig, axes = plt.subplots(1, len(fobs), figsize=(10, 3), sharey=True, squeeze=False)
        axes = axes[0]
        fig.suptitle('integer number - numbg for each bin, '+ exname)
        axes[0].set_ylabel('number - number_bg')
        # label each non-frequency bin by its flat (C-order) index
        bins = np.arange(0, num[...,0].size, 1)
        bins = np.reshape(bins, num[...,0].shape)
        for f in range(len(fobs)):
            axes[f].scatter(bins, (num[...,f] - bgnum[...,f]))
            axes[f].set_title(r'$f_\mathrm{obs}$ = %dnHz' % (fobs[f]*10**9))
            axes[f].set_xlabel('bin')
        fig.tight_layout()
    print('number test passed')

def compare_to_loops_test(edges, number, hc_bg, hc_ss, hsmax, ssidx, bgnum):
    """
    Recompute the strains with ss_gws_by_loops and assert they agree with the
    ndarray-based results passed in.

    Parameters
    ----------
    edges : (4,) list of 1D arrays
        Mass, ratio, redshift, and frequency edges of bins
    number : (M, Q, Z, F) ndarray of scalars
        Number of binaries in each bin
    hc_bg, hc_ss, hsmax, ssidx, bgnum
        Outputs of ss_gws_by_ndars for the same edges and number.

    Raises
    ------
    AssertionError
        If any loop-based result differs from the ndarray-based one.
    """
    # The loop version also returns the loudest-source parameters (sspar),
    # which have no ndarray counterpart to compare against.
    hc_bg_loop, hc_ss_loop, _sspar_loop, ssidx_loop, maxhs_loop, number_bg_loop \
      = ss_gws_by_loops(edges, number, realize=False, round=True, sum=True, ss=True, print_test=False)

    # Only the first three columns (m, q, z) are compared; ssidx_loop appears
    # not to carry the frequency column.
    for i in range(len(ssidx)):
        assert np.all(ssidx[i, 0:3] == ssidx_loop[i,:]), \
            f"ssidx[{i}] by ndars does not match by loops"
    assert (np.all(bgnum == number_bg_loop)), "bgnum by ndars does not match by loops"
    assert (np.all(hc_ss == hc_ss_loop)), "hc_ss by ndars does not match by loops"    
    assert (np.all(hsmax == maxhs_loop)), "hsmax by ndars does not match by loops" 
    # NOTE(review): rtol=1e-20 is below float64 precision, so effectively only
    # atol=1e-20 constrains this comparison.
    assert (np.all(np.isclose(hc_bg, hc_bg_loop, atol=1e-20, rtol=1e-20))), \
        "hc_bg by ndars does not match by loops"
    print('compare to loops test passed')

def quadratic_sum_test(hc_bg, hc_ss, hc_tt, print_test, rtol=1e-8, atol=0.0):
    """
    Check that background and single-source strain add in quadrature to the
    total strain: sqrt(hc_bg^2 + hc_ss^2) == hc_tt within tolerance.

    Parameters
    ----------
    hc_bg : array over frequency bins
        Background characteristic strain (loudest source removed).
    hc_ss : array over frequency bins
        Characteristic strain of the loudest single source.
    hc_tt : array over frequency bins
        Total characteristic strain with nothing removed.
    print_test : bool
        Whether to print the fractional errors and the differences.
    rtol, atol : float, optional
        Relative and absolute tolerances passed to np.allclose, applied to
        the strains themselves (not to their squares).

    Raises
    ------
    AssertionError
        If sqrt(hc_bg^2 + hc_ss^2) and hc_tt differ beyond tolerance.
    """
    test = (hc_bg**2 + hc_ss**2)
    error = (test-hc_tt**2)/hc_tt**2
    # `test` is a squared strain, so take its root before comparing to hc_tt.
    # Comparing `test` to hc_tt directly with an absolute tolerance of ~1e-15
    # passes trivially for strains of order 1e-15 or smaller.
    assert np.allclose(np.sqrt(test), hc_tt, rtol=rtol, atol=atol), \
        "quadratic sum of hc_bg and hc_ss does not match hc_tt"
    if(print_test):
        print('fractional error between (hc_bg^2+hc_ss^2) and hc_tt^2:', error)
        print('differences between np.sqrt((hc_bg^2+hc_ss^2)) and hc_tt:', 
              np.sqrt(test) - hc_tt)
    print('quadratic sum test passed')

def run_example_tests(edges, number, fobs, exname='', print_test=False,
                      loop_comparison=True):
    '''
    Run the full set of single-source / background consistency checks on
    one example model.

    Parameters
    ----------
    edges : (4,) list of 1D arrays
        Mass, ratio, redshift, and frequency edges of bins
    number : (M, Q, Z, F) ndarray of scalars
        Number of binaries in each bin
    fobs : (F,) array of scalars
        Observed frequency bin centers
    exname : String
        Name of example (used for number plots)
    print_test : bool
        Print intermediate values and draw the number-difference plots.
    loop_comparison : bool
        Also recompute everything with ss_gws_by_loops and compare (slower).

    Returns
    -------
    hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum
        The outputs of ss_gws_by_ndars(edges, number, realize=False,
        round=True, ss=True, sum=True).
    '''
    outputs = ss_gws_by_ndars(edges, number, realize=False, round=True,
                              ss=True, sum=True)
    hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum = outputs

    # loudest-source bookkeeping
    max_test(hsmax, hsamp)
    ssidx_test(hsmax, hsamp, ssidx, print_test)

    # background numbers vs. the floored (integer) source counts
    number_int = np.floor(number).astype(np.int64)
    number_test(number_int, bgnum, fobs, exname, plot_test=print_test)

    # optional, because the loop-based reference implementation is slow
    if loop_comparison:
        compare_to_loops_test(edges, number, hc_bg, hc_ss, hsmax, ssidx, bgnum)

    # background and single source must add in quadrature to the total
    hc_tt = gws_by_ndars(edges, number, realize=False, round=True, sum=True)
    quadratic_sum_test(hc_bg, hc_ss, hc_tt, print_test)

    return hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum

Run the test suite on examples 2, 3, and 4.

In [None]:
# Example 2: build the model, then run every consistency check on it
edges, number, fobs, exname = example2(print_test=False)
(hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum) = run_example_tests(
    edges, number, fobs, exname, print_test=False)

In [None]:
# Example 3: build the model, then run every consistency check on it
edges, number, fobs, exname = example3(print_test=False)
(hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum) = run_example_tests(
    edges, number, fobs, exname, print_test=False)

In [None]:
# Example 4: build the model, then run every consistency check on it
edges, number, fobs, exname = example4(print_test=False)
(hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum) = run_example_tests(
    edges, number, fobs, exname, print_test=False)

## 1.4 Random Other Functions

In [None]:
def subtract_from_number(bgnum, ssidx, print_test=False):
    ''' 
    Subtract 1 from the number of sources in every bin listed in ssidx,
    i.e. each bin holding a loudest single source.

    The input array is copied first, so the caller's array is not modified.
    A bin listed more than once in ssidx is decremented once per listing.

    Parameters
    -------------
    bgnum : (M, Q, Z, F) array
        number of sources in each bin, before
        single sources have been subtracted
    ssidx : (F, 4) array 
        The indices (m_idx, q_idx, z_idx, f_idx) of the loudest single
        source's bin at each frequency
    print_test : bool, optional
        If True, print each index as it is processed.

    Returns
    -------------
    bgnum : (M, Q, Z, F) array
        number of background sources in each bin,
        after single sources have been subtracted
    '''
    # work on a copy so the caller's counts are left untouched
    bgnum = np.copy(bgnum)
    for row in ssidx:
        idx = tuple(row)
        if print_test:
            print(*idx)
        bgnum[idx] -= 1
    return bgnum

def max_index_at_f(grid):
    """ Get the index of the maximum value for each frequency. 
    Frequency is the 4th (last) dimension of the 4D array, grid.

    The (M, Q, Z) axes are flattened so a single argmax along axis 0 finds
    the maximum at every frequency at once. Ties resolve to the first bin in
    C (row-major) order, exactly as np.argmax on each frequency slice would.

    Parameters:
    grid : [M,Q,Z,F] array
        test grid

    Returns:
    mqz_f : (3, F) array of int
        Indices of max grid value at each frequency
        [[m0,m1,...,m(F-1)], [q0,...,q(F-1)], [z0,...,z(F-1)]]
    f_mqz : (F, 3) array of int
        Indices of max grid value at each frequency
        [[m0,q0,z0], [m1,q1,z1], ..., [m(F-1),q(F-1),z(F-1)]]
    all_f : (4, F) array of int
        Indices of max grid value at each frequency
        [[m0,...,m(F-1)], [q0,...,q(F-1)], [z0,...,z(F-1)], [0,1,...,F-1]]
    f_all : (F, 4) array of int
        Indices of max grid value at each frequency
        [[m0,q0,z0,0], [m1,q1,z1,1], ..., [m(F-1),q(F-1),z(F-1),F-1]]
    """
    grid = np.asarray(grid)
    nbins = int(np.prod(grid.shape[:-1]))
    nfreqs = grid.shape[-1]
    # column f of `flat` is grid[..., f] flattened in C order
    flat = grid.reshape(nbins, nfreqs)
    m, q, z = np.unravel_index(np.argmax(flat, axis=0), grid.shape[:-1])
    f = np.arange(nfreqs)

    mqz_f = np.stack([m, q, z]).astype(int)
    all_f = np.stack([m, q, z, f]).astype(int)
    f_mqz = np.ascontiguousarray(mqz_f.T)
    f_all = np.ascontiguousarray(all_f.T)
    return mqz_f, f_mqz, all_f, f_all

def argwhere_at_f(grid):
    """ Get the index of the maximum value for each frequency, using np.argwhere.
    Frequency is the 4th (last) dimension of the 4D array, grid.
    NOTE: This should match f_all from max_index_at_f.

    np.argwhere lists matches in C (row-major) order, i.e. sorted by the m
    index first, so its raw output is not in frequency order. It also returns
    one row per tied maximum. The rows are therefore stably re-sorted by
    frequency, and for each frequency only the first match in C order (the
    one np.argmax would pick) is kept.

    Parameters:
    grid : [M,Q,Z,F] array
        test grid

    Returns:
    argwhere : (F, 4) array
        Indices of max grid value at each frequency
        For F frequencies it looks like
        [[m0,q0,z0,0], [m1,q1,z1,1], ..., [m(F-1),q(F-1),z(F-1),F-1]]
    """
    maxes = np.max(grid, axis=(0,1,2))
    argwhere = np.argwhere(grid == maxes)
    # stable sort keeps C order among ties within the same frequency
    argwhere = argwhere[np.argsort(argwhere[:, -1], kind='stable')]
    # keep only the first match for each frequency
    _, first = np.unique(argwhere[:, -1], return_index=True)
    return argwhere[first]


In [None]:
def test_argwhere(grid):
    """ Assert that argwhere_at_f(grid) matches f_all from max_index_at_f(grid).

    Parameters
    ----------
    grid : (M, Q, Z, F) array
        test grid
    """
    # f_all is the last of the four arrays returned by max_index_at_f
    f_all = max_index_at_f(grid)[-1]
    argwhere = argwhere_at_f(grid)
    # Unlike `==`, np.array_equal checks shapes first. A different number of
    # rows therefore fails the assertion instead of erroring or broadcasting.
    assert np.array_equal(f_all, argwhere), 'argwhere failing to find correct max indices'


## Scratch
- matching max values from hsmax, hsamp, and ssidx
- number-bgnum give 1 nonzero value for each frequency
- hc_bg^2 + hc_ss^2 = hc_tt^2

In [None]:
# np.set_printoptions(precision=2)
# hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum = ss_gws_by_ndars(edges, number, realize=False, round=True, ss=True, sum=True)
# #### Test hsmax
# for i in range(len(hsmax)):
#     print(np.max(hsamp[...,i]) == hsmax[i])
# # print(np.max(hsamp[...,0]))
# # print(hsmax)
# # print(number>=1)

# ### Test ssidx
# print(ssidx)
# print(ssidx.shape)
# for f_idx in range(len(hsamp[0,0,0,:])):
#     print(f_idx)
#     print(hsamp[ssidx[f_idx,0], ssidx[f_idx,1], ssidx[f_idx,2], ssidx[f_idx,3]])
#     print(hsmax[f_idx])
#     print(np.max(hsamp[...,f_idx]))
#     # Those should all be equal, lets make this into asserts
# print(hsamp.shape)


In [None]:
# bgnum = np.floor(number).astype(int) 
# print(np.shares_memory(bgnum, number))

In [None]:
# arr = np.array([[[1,2], [3,4], [5,6], [1,2]], [[9000,20], [30,40], [50,60], [10,20]], [[100,200], [300,400], [500,600], [700,800]]])
# # arr, like hsamp, shape (3,4,2)
# print(arr.shape, np.max(arr[...,0]), np.max(arr[...,1]))
# maxes = np.amax(arr, axis=(0,1)) 
# # maxes, like hsmax, shape 2
# print(maxes.shape)
# print(arr[np.where(arr==maxes)])
# idx = np.array(np.where(arr==maxes))
# print('index', idx.shape, '\n', idx)
# rot = np.swapaxes(idx,0,1)
# print('rot', rot.shape, '\n', rot)
# for i in range(2):
#     print(arr[int(idx[0,i]), int(idx[1,i]), int(idx[2,i])])
# for i in range(2):
#     print(arr[int(rot[i,0]), int(rot[i,1]), int(rot[i,2])])

Here's how to check that hsmax, hsamp, and ssidx agree:

In [None]:
# hsmax_hsamp_match = np.empty_like(hsmax)
# hsmax_ssidx_match = np.empty_like(hsmax)
# for f_idx in range(len(hsmax)):
#     hsmax_hsamp_match[f_idx] = (np.max(hsamp[...,f_idx]) == hsmax[f_idx])
    
#     m,q,z,f = (ssidx[np.where(ssidx[:,3] == f_idx)])[0]
#     print(m,q,z,f)
#     hsmax_ssidx_match[f_idx] = (hsamp[m,q,z,f] == hsmax[f_idx])
# assert np.all(hsmax_hsamp_match == True), "the max amplitudes in hsamp do not match those in hsmax"
# assert np.all(hsmax_ssidx_match == True), "the max amplitudes in hsamp idenfitied with ssidx do not match those in hsmax"

In [None]:
# # to look up where the strain at some index (e.g. from ssidx):
# f=3
# print(ssidx[np.where(ssidx[:,3] == f)])

In [None]:
# print(ssidx[0,...])
# print(ssidx)
# print(ssidx[:,3])
# # print(number.shape)
# # print(hsmax)


In [None]:
# print(hsmax_hsamp_match)
# print(hsmax_ssidx_match)

In [None]:
# hsmaxes = np.max(hsamp, axis=(0,1,2))
# print(hsmaxes) # this orders them from largest to smallest
# argwhere = np.argwhere(hsamp==hsmaxes)
# revwhere = np.argwhere(hsmaxes==hsamp) # order makes no dif
# print(argwhere)
# print(revwhere)
# print(np.all(argwhere==revwhere))
# print(hsmaxes[0])
# print(np.max(hsmax))
# print(hsamp[13,23,16,0])

grid:

1x  |  xx  |  xx  |  xx  |  xx \
xx  |  9x  |  x3  |  xx  |  xx \
xx  |  xx  |  xx  |  x4  |  x5 

In [None]:
# Toy 4D grid of shape (M, Q, Z, F) = (2, 1, 3, 5); the maximum at each
# frequency (last axis) is 1, 9, 3, 4, 5 respectively.
grid = np.array([
    [[[1, 0, 0, 0, 0], [0, 9, 0, 0, 0], [0, 0, 0, 0, 0]]],
    [[[0, 0, 0, 0, 0], [0, 0, 3, 0, 0], [0, 0, 0, 4, 5]]],
])
print(grid.shape)

(mqz_f, f_mqz,
 all_f, f_all) = max_index_at_f(grid)
print('mqz_f:\n', mqz_f)
print('\nf_mqz:\n', f_mqz)
print('\nall_f:\n', all_f)
print('\nf_all:\n', f_all)

In [None]:
# Vectorized max-finding (alternative to max_index_at_f):
# collapse the (M, Q, Z) axes so each column is one frequency, take the
# argmax down each column, then unravel the flat index back to (m, q, z).
print(grid.shape) #(2,1,3,5)
test = grid.reshape(-1, grid.shape[-1])
print(test.shape) #(6,5)

testi = np.argmax(test, axis=0)  # flat (m,q,z) index of the max at each frequency
print(testi)
print(test[testi, np.arange(test.shape[-1])])  # the max value at each frequency

m_idx, q_idx, z_idx = np.unravel_index(testi, grid.shape[:-1])
print(np.array([m_idx, q_idx, z_idx]))  # should match mqz_f from max_index_at_f


In [None]:
# maxes = np.amax(grid, axis=(0,1,2))
# print(maxes)

# hsmax = np.amax(grid, axis=(0,1,2)) #find max hs at each frequency
    
#     #### 2) Record the indices and strain of that single source

# # --- Indices of Loudest Bin ---
# # Shape [F-1, 4], looks like
# # [[m_idx,q_idx,z_idx,0],
# #  [m_idx,q_idx,z_idx,1],
# #   ........
# #  [m_idx,q_idx,z_idx,F-2]]
# ssidx = np.argwhere(grid==hsmax) # NOTE: 
# for s in range(len(ssidx)):
#     m,q,z,f = ssidx[s]
#     print(m,q,z,f, 'grid max:',grid[m,q,z,f])
# print(ssidx)

In [None]:
# am = (np.argmax(hsamp[3,3,:,4]))
# print(hsamp[3,3,:,4])
# print(hsamp[3,3,am,4])

In [None]:
# mqz_f, f_mqz, all_f, f_all= max_index_at_f(grid)
# argwhere = argwhere_at_f(grid)
# print(argwhere)
# assert (np.all(f_all == argwhere)), 'argwhere failing to find maxes'

In [None]:
# for i in range(len(grid[0,0,0,:])):
#     m,q,z,f = argwhere[i]
#     print(m,q,z,f, 'max:', grid[m,q,z,f])

In [None]:
# grid = np.array([[[[1,0,0,0,0],[0,2,0,0,0],[0,0,0,0,0]]],
#                  [[[0,0,0,0,0],[0,0,3,0,0],[0,0,0,4,5]]]])
# maxes = np.max(grid, axis=(0,1,2))
# # print('grid:', grid.shape, '\n', grid)
# # print('maxes', maxes.shape, '\n', maxes)
# where = np.where(grid==maxes)
# print('where: len=', len(where), 'with shapes=', where[0].shape, where[1].shape, where[2].shape,
#       where[3].shape, '\n', where)
# print('max by where:', grid[np.where(maxes==grid)])
# print()
# argmax = np.argmax(grid)
# m,q,z,f = np.unravel_index(np.argmax(grid), grid.shape)
# print('m:',m, '\nq:',q, '\nz:',z, '\nf:',f)
# print('argmaxes:', argmax.shape, '\n', argmax)

In [None]:
# idxs = np.argwhere(arr==maxes)
# print('idxs', idxs.shape,'\n', idxs)
# rots = np.rot90(idxs)
# print('rots', rots.shape, '\n', rots)
# print(idxs[0])
# for i in [0,1]:
#     print(arr[idxs[i,0], idxs[i,1], idxs[i,2]])
# for i in [0,1]:
#     # print(arr[rots[0,i], rots[1,i], rots[2,i]])
#     print(arr[rots[2,i], rots[1,i], rots[0,i]])
# # the rot situation is trickier, just use
# # idxs = np.argwhere(arr==maxes)
# # returns (2,3) array
# # or (f, 4) array?


another sample grid for hsamp

In [None]:
# # --- Single Source Strain Amplitude At Each Frequency ---
# hsamp = np.copy(grid)
# # print(hsamp)

# hsmax = np.amax(hsamp, axis=(0,1,2)) #find max hs at each frequency
# print(hsmax)
# #### 2) Record the indices and strain of that single source

# # --- Indices of Loudest Bin ---
# # Shape [F-1, 4], looks like
# # [[m_idx,q_idx,z_idx,0],
# #  [m_idx,q_idx,z_idx,1],
# #   ........
# #  [m_idx,q_idx,z_idx,F-2]]
# ssidx = np.argwhere(hsamp==hsmax) 
# print(ssidx.shape)
# print(ssidx)
# for s in range(len(ssidx)):
#     print(s)
#     m,q,z,f = ssidx[s]
#     print(m,q,z,f, hsamp[m,q,z,f])

# # This works perfectly for my test grid, so why the actual
# # is it not working fo the real hsamps
# # I AM LOSING MY MIND

number

In [None]:
# print(number.shape)
# print('all freq, m10 q10 and z10,', number[10,10,10])
# print('all 2 ms, 1q, 3zs, all 5 freqs:\n', number[5:7,10:11,10:13])

In [None]:
# # --------------- Single Sources ------------------
#     ##### 0) Round and/or realize so numbers are all integers
# # (round == True):
# bgnum = np.copy(np.floor(number).astype(int))
# assert (np.all(bgnum%1 ==0)), 'nonzero numbers found with round=True'
#     # else:
#     #     bgnum = np.copy(number)

# # if(realize == True):
# bgnum = np.random.poisson(number)
# assert (np.all(bgnum%1 ==0)), 'nonzero numbers found with realize=True'


number subtraction

In [None]:
# def subtract_from_number(bgnum, ssidx):
#     for ff in range(len(ssidx)):
#         m,q,z,f = ssidx[ff]
#         print(m,q,z,f)
#         bgnum[m,q,z,f] -=1
#     return bgnum


# # better subtraction method?
# def better_sub_from_number(number, hsamp, hsmax):
#     bgnum = np.copy(number)
#     bgnum[np.where(hsamp == hsmax)]-=1
#     return bgnum

In [None]:
# print(ssidx[0].shape)
# print(rounded[ssidx[0]]) 

In [None]:
# edges, number, fobs, exname = example2(False)
# hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum = run_example_test(edges, number, fobs, exname)
# rounded = np.floor(number).astype(int)
# subnum = subtract_from_number(rounded, ssidx)
# print(np.all(subnum == bgnum))

# rounded = np.floor(number).astype(int)
# subnum2 = better_sub_from_number(rounded, hsamp, hsmax)
# print(np.all(subnum2==subnum))

# rounded = np.floor(number).astype(int)
# subnum3 = sub2(rounded, ssidx)
# print(np.all(subnum3==subnum))

In [None]:
# def number_test(num, bgnum, fobs, exname=''):
#     ''' 
#     Plots num - bgnum, where number is the ndarray of 
#     integer number of sources in each bin, i.e. after 
#     rounding or Poisson sampling

#     Parameters
#     ------------
#     num : (M, Q, Z, F) array
#         integer numbers in each bin, i.e. after rounding or
#         Poisson sampling
#     bgnum : (M, Q, Z, F) array
#         number of background sources in each bin, 
#         after single source subtraction
#     fobs : (F) array
#         frequencies of each F, for ax titles
#     exname : 

#     Returns
#     -----------
#     None 

#     TODO: Add an assertion that there be one nonzero value 
#     in the difference array at each frequency
    
#     '''
        
#     assert np.all(num%1 == 0), "use integer array for num"
#     fig, ax = plt.subplots(1,len(fobs), figsize = (10,3), sharey=True)
#     fig.suptitle('integer number - numbg for each bin, '+ exname)
#     ax[0].set_ylabel('number - number_bg')
#     bins = np.arange(0, num[...,0].size, 1)
#     bins = np.reshape(bins, num[...,0].shape)
#     # print(bins.shape)
#     # print(num[...,0].shape)
#     for f in range(len(fobs)):
#         ax[f].scatter(bins, (num[...,f] - bgnum[...,f]))
#         ax[f].set_title('$f_\mathrm{obs}$ = %dnHz' % (fobs[f]*10**9))
#         ax[f].set_xlabel('bin')
#     fig.tight_layout()

In [None]:
# number_test(rounded, bgnum, fobs=fobs, exname= exname)

In [None]:
# print(ssidx.shape)
# print(hsamp.shape)
# numbg = np.floor(number).astype(int)
# rounded = np.floor(number).astype(int)
# # print(numbg%1 == 0)
# for ff in range(len(ssidx)):
#     m,q,z,f = ssidx[ff]
#     print(m,q,z,f)
#     numbg[m,q,z,f] -=1

# # plot number vs anything at 1 frequency 
# freqs_in_nHz = fobs*10**9
# # plot number - single sources vs anything at 1 frequency,
# # should only vary in one cell
# fig, ax = plt.subplots(1,len(fobs), figsize = (10,3), sharey=True)
# fig.suptitle('rounded number - numbg for each bin, '+ exname)
# ax[0].set_ylabel('number - number_bg')
# bins = np.arange(0, number[...,0].size, 1)
# bins = np.reshape(bins, number[...,0].shape)
# print(bins.shape)
# print(number[...,0].shape)
# for f in range(len(fobs)):
#     ax[f].scatter(bins, (rounded[...,f] - numbg[...,f]))
#     ax[f].set_title('$f_\mathrm{obs}$ = %dnHz' % freqs_in_nHz[f])
#     ax[f].set_xlabel('bin')
# fig.tight_layout()

    

In [None]:
# difs = rounded-numbg
# print(len(difs[0,0,0,:]))
# # print(difs[:,:,:,0])
# print(np.where(difs==0))
# # difs[0,0,0,0] = 2
# print(difs[np.where(difs>0)])
# print(len(difs[np.where(difs>0)]))
# assert len(difs[np.where(difs>0)]) == len(difs[0,0,0,:]), "More than one bin per frequency found with a single source subtracted."

sort hsmax with example

In [None]:
# edges, number, fobs, exname = example2(False)

In [None]:
# hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum = ss_gws_by_ndars(edges, number, realize=False, round=True, ss=True, sum=True)

In [None]:
# run_example_test(edges, number, fobs, exname)

In [None]:
# print(hsamp)

# hsmaxtest = np.amax(hsamp, axis=(0,1,2)) #find max hs at each frequency
# assert np.all(hsmaxtest == hsmax), 'hsmaxtest problem'
# print(hsmax)

# ssidxtest = np.argwhere(hsamp==hsmaxtest) 
# assert np.all(ssidxtest == ssidx), 'ssidx problem'
# print(ssidx.shape)
# print(ssidx)
# for s in range(len(ssidx)):
#     print(s)
#     m,q,z,f = ssidx[s]
#     print(m,q,z,f, hsamp[m,q,z,f])

# # This works perfectly for my test grid, so why the actual
# # is it not working for the real hsamps
# # I AM LOSING MY MIND

In [None]:
# # now modify example:
# print(hsamp[0,0,2,1]) # initially 4.83e-20, then set to larger number
# hsamp[0,0,2,1] = 9e-18
# print(hsamp[0,0,2,1])


In [None]:
# # print(hsamp)

# hsmaxtest = np.amax(hsamp, axis=(0,1,2)) #find max hs at each frequency
# # print(hsmaxtest)

# ssidxtest = np.argwhere(hsamp==hsmaxtest) 
# # assert np.all(ssidxtest == ssidx), 'ssidx problem'
# print(ssidxtest.shape)
# # print(ssidxtest)
# for s in range(len(ssidx)):
#     print(s)
#     m,q,z,f = ssidxtest[s]
#     print(m,q,z,f, hsamp[m,q,z,f])

# # EUREKA! np.argwhere sorted by max, not np.amax

In [None]:
# ssidx_sorted = ssidxtest[ssidxtest[:,-1].argsort()]
# print(ssidx_sorted)
# for s in range(len(ssidx_sorted)):
#     print(s)
#     m,q,z,f = ssidx_sorted[s]
#     print(m,q,z,f, hsamp[m,q,z,f], hsmax[s])


In [None]:
# hsmaxes = np.max(hsamp, axis=(0,1,2))
# print(hsmaxes) # this orders them from largest to smallest
# argwhere = np.argwhere(hsamp==hsmaxes)
# for a in range(len(argwhere)):
#     print(argwhere[a], hsmaxes[a], hsmax[a])
#     print(np.max(hsamp[...,argwhere[a]]))
# # print(argwhere)
# print(hsmax[0])
# print(np.max(hsmax))
# print(hsamp[13,23,16,0])


In [None]:
# sortedargs = argwhere[argwhere[:,-1].argsort()]
# print(sortedargs)
# maxsorted = np.empty_like(hsmax)
# for f_idx in range(len(hsmax)): 
#     m,q,z,f = sortedargs[f_idx]
#     maxsorted[f_idx] = hsamp[m,q,z,f] 
#     print(m,q,z,f, maxsorted[f_idx])
# # sortedargwhere = np.argwhere(hsamp == sortedmaxes)
# # print(sortedargwhere)

In [None]:
# print(argwhere[argwhere[:,-1].argsort()])

In [None]:
# for f_idx in range(len(hsmax)):
#     print(f_idx)
#     # print(hsamp[ssidx[f_idx,0], ssidx[f_idx,1], ssidx[f_idx,2], ssidx[f_idx,3]])
#     print(hsmax[f_idx])
#     print(sorted_hsmax)
#     # print(np.max(hsamp[...,f_idx]))

In [None]:
# np.set_printoptions(precision=2)
# #### Test hsmax
# # for i in range(len(hsmax)):
# #     print(np.max(hsamp[...,i]) == hsmax[i])
# # print(np.max(hsamp[...,0]))
# # print(hsmax)
# # print(number>=1)

# ### Test ssidx
# # print(ssidx)
# print(ssidx.shape)
# for f_idx in range(len(hsamp[0,0,0,:])):
#     print(f_idx)
#     print(hsamp[ssidx[f_idx,0], ssidx[f_idx,1], ssidx[f_idx,2], ssidx[f_idx,3]])
#     print(hsmax[f_idx])
#     print(np.max(hsamp[...,f_idx]))
#     # Those should all be equal, lets make this into asserts
# print(hsamp.shape)
# # BUG: ssidx does not give the correct maxes for the more complicated exampleS (3 and 4)

compare to loops

In [None]:
# edges, number, fobs, exname = example4(print_test=False)
# hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum = run_example_tests(edges, number, fobs, exname, print_test=False)

In [None]:
# hc_bg_loop, hc_ss_loop, sspar_loop, ssidx_loop, maxhs_loop, number_bg_loop \
#       = ss_gws_by_loops(edges, number, realize=False, round=True, sum=True, ss=True)

In [None]:
# print(hsmax==maxhs_loop)

In [None]:
# print(np.all(bgnum == number_bg_loop))
# # print(np.all(ssidx == ssidx_loop))

In [None]:
# print(ssidx_loop.shape)
# print(ssidx.shape)
# for i in range(len(ssidx)):
#     print(ssidx[i, 0:3] == ssidx_loop[i])

In [None]:
# # hc_ss working!!

# print(np.isclose(hc_ss, hc_ss_loop, atol=10e-35))
# print(np.all(hc_ss == hc_ss_loop))
# print(hc_ss - hc_ss_loop)
# print(hc_ss)
# print(hc_ss_loop)

In [None]:
# assert (np.all(hsmax == maxhs_loop)), "hsmax by ndars does not match by loops"
# assert (np.all(ssidx == ssidx_loop)), "ssidx by ndars does not match by loops"
# assert (np.all(bgnum == number_bg_loop)), "bgnum by ndars does not match by loops"
# assert (np.all(hc_ss == hc_ss_loop)), "hc_ss by ndars does not match by loops"    


working on hc_ss

In [None]:
# edges, number, fobs, exname = example4(print_test=False)
# hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum = run_example_tests(edges, number, fobs, exname, print_test=False)


warnings

In [None]:
# # WHY DOESN'T WARNING SHOW UP?
# # warnings probably disabled
# num = np.array([0,1,2,3.5,4])
# print( np.any(num[np.where(num%1 !=0)]))
# def test(num):
#     if np.any(num[np.where(num%1 !=0)]): 
#         print("should warn")
#         warnings.warn('noninteger val found !') #('noninteger bgnum values:', bgnum[np.where(bgnum%1 !=0)]))
# test(num)

quadratic sum tests

In [None]:
# hc_bg, hc_ss, hsamp, ssidx, hsmax, bgnum = ss_gws_by_ndars(edges, number, realize=False, round=True)
# hc_tt = gws_by_ndars(edges, number, realize=False, round = True, sum=True)  

In [None]:
# test = (hc_bg**2 + hc_ss**2)
# error = (test-hc_tt**2)/hc_tt**2
# # np.isclose(hc_tt, )
# print(error)
# print(hc_tt - np.sqrt(hc_ss**2 + hc_bg**2))
# print(hc_ss)
# print(np.isclose(hc_tt, test, atol=2e-15, rtol=1e-15))

In [None]:
# hc_bg_loop, hc_ss_loop, sspar_loop, ssidx_loop, maxhs_loop, number_bg_loop \
#   = ss_gws_by_loops(edges, number, realize=False, round=True, sum=True, ss=True, print_test=False)

bg tests

In [None]:
# np.set_printoptions(precision=2)
# print('isequal:', hc_bg == hc_bg_loop)
# print('isclose:', np.isclose(hc_bg, hc_bg_loop, atol=1e-30, rtol=1e-30))

# diff = hc_bg - hc_bg_loop
# error = diff/hc_bg_loop
# # np.isclose(hc_tt, )
# print('diff:', diff)
# print('percent diff:', error)
# print('hc_bg:', hc_bg)
# print('hc_bg_loop:', hc_bg_loop)

# 2 Super Simple Example 

## 2.1 Build Model and Calculate Strains
Build Model

In [None]:
# Build the super simple example model (bin edges, numbers, frequencies, name)
(edges, number, fobs, exname) = example2(print_test=True)

Calculate Strains

In [None]:
# Total GW strain (nothing subtracted), computed with gws_by_ndars()
hc_tt = gws_by_ndars(edges, number, realize=False, round=True, sum=True)

# Background + single-source strain via ndarrays (provides hsamp, not sspar)
(hc_bg_ndar, hc_ss_ndar, hsamp_ndar,
 ssidx_ndar, maxhs_ndar, bgnum_ndar) = ss_gws_by_ndars(edges, number, realize=False)

# Background + single-source strain via explicit loops (provides sspar, not hsamp)
(hc_bg_loop, hc_ss_loop, sspar_loop,
 ssidx_loop, maxhs_loop, bgnum_loop) = ss_gws_by_loops(edges, number, realize=False)

## 2.2 Tests

In [None]:
# test (pass exname so the number-test figure is labeled with this example)
hc_bg_ndar, hc_ss_ndar, hsamp_ndar, ssidx_ndar, maxhs_ndar, bgnum_ndar\
    = run_example_tests(edges, number, fobs, exname, print_test=True)

## 2.3 Plots

### BG, SS, and TT Strain

In [None]:
# quick plot: reference power law, total, background and single-source strain
fig, ax = plot.figax(xlabel=r'Frequency $f_\mathrm{obs}$ [1/yr]', 
                    ylabel='Characteristic Strain $h_c$', figsize=[10,4.5])
ax.set_title(exname)
xx = fobs * YR

# plot a reference, pure power-law  strain spectrum:   h_c(f) = 1e-15 * (f * yr) ^ -2/3
yy = 1e-15 * np.power(xx, -2.0/3.0)
ax.plot(xx, yy, 'k--', alpha=0.25, lw=2.0, label = 'pure power law')

# total char strain (no sources subtracted); hc_tt comes from gws_by_ndars
ax.plot(xx, hc_tt, color='black', marker = 'x', lw=3, 
        ls = 'dotted', alpha=1, label='total, ndars')


# by loops:
# gwb
ax.plot(xx, hc_bg_loop, color='b', marker = 'o', lw=3,
        ls = 'dotted', alpha=.5, label='background, loops')
# loudest source per bin
ax.scatter(xx, hc_ss_loop, color='b', marker = 'o', s=100,
           edgecolor='k', alpha=.5, label='single source, loops')

# by ndars:
# gwb
ax.plot(xx, hc_bg_ndar, color='g', marker = 'o', lw=3, 
        ls = 'dotted', alpha=.5, label='background, ndars')
# loudest source per bin
ax.scatter(xx, hc_ss_ndar, color='g', marker = 'o', 
           edgecolor='k', alpha=.5, label='single source, ndars')


# quadrature sum of background and single source; should trace the total
ax.plot(xx, np.sqrt(hc_bg_ndar**2+hc_ss_ndar**2), color='r', 
        label=r'$\sqrt{bg^2 + ss^2}$, ndars', alpha=.5)

legend_gwb = ax.legend(bbox_to_anchor=(.95,.95), 
                       bbox_transform=fig.transFigure, loc='upper right')


# ax.set_ylim(1e-16, 3e-15)
fig.tight_layout()

# 3 Medium Simple Example

## 3.1 Build Model and Calculate Strains
Build Model

In [None]:
# Build the medium simple example model (bin edges, numbers, frequencies, name)
(edges, number, fobs, exname) = example3(print_test=True)

Calculate Strains

In [None]:
# Total GW strain (nothing subtracted), computed with gws_by_ndars()
hc_tt = gws_by_ndars(edges, number, realize=False, round=True, sum=True)

# Background + single-source strain via ndarrays (provides hsamp, not sspar)
(hc_bg_ndar, hc_ss_ndar, hsamp_ndar,
 ssidx_ndar, maxhs_ndar, bgnum_ndar) = ss_gws_by_ndars(edges, number, realize=False)

# Background + single-source strain via explicit loops (provides sspar, not hsamp)
(hc_bg_loop, hc_ss_loop, sspar_loop,
 ssidx_loop, maxhs_loop, bgnum_loop) = ss_gws_by_loops(edges, number, realize=False)

## 3.2 Tests

In [None]:
# test (pass exname so the number-test figure is labeled with this example)
hc_bg_ndar, hc_ss_ndar, hsamp_ndar, ssidx_ndar, maxhs_ndar, bgnum_ndar\
    = run_example_tests(edges, number, fobs, exname, print_test=True)

## 3.3 Plots

### BG, SS, and TT Strain

In [None]:
# quick plot: reference power law, total, background and single-source strain
fig, ax = plot.figax(xlabel=r'Frequency $f_\mathrm{obs}$ [1/yr]', 
                    ylabel='Characteristic Strain $h_c$', figsize=[10,4.5])
ax.set_title(exname)
xx = fobs * YR

# plot a reference, pure power-law  strain spectrum:   h_c(f) = 1e-15 * (f * yr) ^ -2/3
yy = 1e-15 * np.power(xx, -2.0/3.0)
ax.plot(xx, yy, 'k--', alpha=0.25, lw=2.0, label = 'pure power law')

# total char strain (no sources subtracted); hc_tt comes from gws_by_ndars
ax.plot(xx, hc_tt, color='black', marker = 'x', lw=3, 
        ls = 'dotted', alpha=1, label='total, ndars')


# by loops:
# gwb
ax.plot(xx, hc_bg_loop, color='b', marker = 'o', lw=3,
        ls = 'dotted', alpha=.5, label='background, loops')
# loudest source per bin
ax.scatter(xx, hc_ss_loop, color='b', marker = 'o', s=100,
           edgecolor='k', alpha=.5, label='single source, loops')

# by ndars:
# gwb
ax.plot(xx, hc_bg_ndar, color='g', marker = 'o', lw=3, 
        ls = 'dotted', alpha=.5, label='background, ndars')
# loudest source per bin
ax.scatter(xx, hc_ss_ndar, color='g', marker = 'o', 
           edgecolor='k', alpha=.5, label='single source, ndars')


# quadrature sum of background and single source; should trace the total
ax.plot(xx, np.sqrt(hc_bg_ndar**2+hc_ss_ndar**2), color='r', 
        label=r'$\sqrt{bg^2 + ss^2}$, ndars', alpha=.5)

legend_gwb = ax.legend(bbox_to_anchor=(.95,.95), 
                       bbox_transform=fig.transFigure, loc='upper right')


# ax.set_ylim(1e-16, 3e-15)
fig.tight_layout()

# 4 Complex Example - Casting 64bit


## 4.1 Build Model and Calculate Strains
Build Model

In [None]:
# Build the complex example model (bin edges, numbers, frequencies, name)
(edges, number, fobs, exname) = example4(print_test=True)

Calculate Strains

In [None]:
# Total GW strain (nothing subtracted), computed with gws_by_ndars()
hc_tt = gws_by_ndars(edges, number, realize=False, round=True, sum=True)

# Background + single-source strain via ndarrays (provides hsamp, not sspar)
(hc_bg_ndar, hc_ss_ndar, hsamp_ndar,
 ssidx_ndar, maxhs_ndar, bgnum_ndar) = ss_gws_by_ndars(edges, number, realize=False)

# Background + single-source strain via explicit loops (provides sspar, not hsamp)
(hc_bg_loop, hc_ss_loop, sspar_loop,
 ssidx_loop, maxhs_loop, bgnum_loop) = ss_gws_by_loops(edges, number, realize=False)

## 4.2 Tests

In [None]:
# test (pass exname so the number-test figure is labeled with this example)
hc_bg_ndar, hc_ss_ndar, hsamp_ndar, ssidx_ndar, maxhs_ndar, bgnum_ndar\
    = run_example_tests(edges, number, fobs, exname, print_test=True)

## 4.3 Plots

### BG, SS, and TT Strain

In [None]:
# quick plot: reference power law, total, background and single-source strain
fig, ax = plot.figax(xlabel=r'Frequency $f_\mathrm{obs}$ [1/yr]', 
                    ylabel='Characteristic Strain $h_c$', figsize=[10,4.5])
ax.set_title(exname)
xx = fobs * YR

# plot a reference, pure power-law  strain spectrum:   h_c(f) = 1e-15 * (f * yr) ^ -2/3
yy = 1e-15 * np.power(xx, -2.0/3.0)
ax.plot(xx, yy, 'k--', alpha=0.25, lw=2.0, label = 'pure power law')

# total char strain (no sources subtracted); hc_tt comes from gws_by_ndars
ax.plot(xx, hc_tt, color='black', marker = 'x', lw=3, 
        ls = 'dotted', alpha=1, label='total, ndars')


# by loops:
# gwb
ax.plot(xx, hc_bg_loop, color='b', marker = 'o', lw=3,
        ls = 'dotted', alpha=.5, label='background, loops')
# loudest source per bin
ax.scatter(xx, hc_ss_loop, color='b', marker = 'o', s=100,
           edgecolor='k', alpha=.5, label='single source, loops')

# by ndars:
# gwb
ax.plot(xx, hc_bg_ndar, color='g', marker = 'o', lw=3, 
        ls = 'dotted', alpha=.5, label='background, ndars')
# loudest source per bin
ax.scatter(xx, hc_ss_ndar, color='g', marker = 'o', 
           edgecolor='k', alpha=.5, label='single source, ndars')


# quadrature sum of background and single source; should trace the total
ax.plot(xx, np.sqrt(hc_bg_ndar**2+hc_ss_ndar**2), color='r', 
        label=r'$\sqrt{bg^2 + ss^2}$, ndars', alpha=.5)

legend_gwb = ax.legend(bbox_to_anchor=(.95,.95), 
                       bbox_transform=fig.transFigure, loc='upper right')


# ax.set_ylim(1e-16, 3e-15)
fig.tight_layout()

# 5 SAM Default Example - Casting 64bit


## 5.1 Build Model and Calculate Strains
Build Model

In [None]:
(edges, number, fobs, exname) = example5(print_test=True)  # build the example-5 model

In [None]:
# Summary statistics of the `number` grid returned by example5(), as-is.
print(holo.utils.stats(number))

In [None]:
# Same statistics after truncating `number` down to integers, for comparison
# with the un-truncated values in the previous cell.
print(holo.utils.stats(np.floor(number).astype(int)))

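Calculate strains: the total strain with `gws_by_ndars()`, and the background / single-source split with both `ss_gws_by_ndars()` and `ss_gws_by_loops()`. The results are also stored under `*5`-suffixed names.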

In [None]:
# Total GW strain (no sources subtracted), computed with gws_by_ndars().
hc_tt = gws_by_ndars(edges, number, realize=False, round=True, sum=True)

# Background + single-source strain from the ndarray implementation.
# This variant returns hsamp but not sspar.
(hc_bg_ndar, hc_ss_ndar, hsamp_ndar,
 ssidx_ndar, maxhs_ndar, bgnum_ndar) = ss_gws_by_ndars(edges, number, realize=False)
# Example-5 specific names. These are references to the same objects, not copies;
# they survive the test cell below, which rebinds the un-suffixed *_ndar names.
(hc_bg_ndar5, hc_ss_ndar5, hsamp_ndar5,
 ssidx_ndar5, maxhs_ndar5, bgnum_ndar5) = (hc_bg_ndar, hc_ss_ndar, hsamp_ndar,
                                            ssidx_ndar, maxhs_ndar, bgnum_ndar)

# Background + single-source strain from the loop implementation.
# This variant returns sspar but not hsamp.
(hc_bg_loop, hc_ss_loop, sspar_loop,
 ssidx_loop, maxhs_loop, bgnum_loop) = ss_gws_by_loops(edges, number, realize=False)
# Example-5 specific names (references, not copies).
(hc_bg_loop5, hc_ss_loop5, sspar_loop5,
 ssidx_loop5, maxhs_loop5, bgnum_loop5) = (hc_bg_loop, hc_ss_loop, sspar_loop,
                                            ssidx_loop, maxhs_loop, bgnum_loop)

## 5.2 Tests

In [None]:
# Run the example checks. Note: this rebinds the *_ndar names with the
# values returned by run_example_tests(); the *_ndar5 names are unaffected.
(hc_bg_ndar, hc_ss_ndar, hsamp_ndar,
 ssidx_ndar, maxhs_ndar, bgnum_ndar) = run_example_tests(edges, number, fobs, print_test=True)

## 5.3 Plots

### Background (BG), Single-Source (SS), and Total (TT) Strain

In [None]:
# Quick-look plot: total strain, the background / single-source split from both
# the loop and ndarray implementations, and background + loudest source in
# quadrature. The xlabel uses a raw string: "\m" is an invalid escape sequence
# in an ordinary string literal.
fig, ax = plot.figax(xlabel=r'Frequency $f_\mathrm{obs}$ [1/yr]',
                     ylabel='Characteristic Strain $h_c$', figsize=[10, 4.5])
ax.set_title(exname)
xx = fobs * YR

# plot a reference, pure power-law strain spectrum:   h_c(f) = 1e-15 * (f * yr) ^ -2/3
yy = 1e-15 * np.power(xx, -2.0/3.0)
ax.plot(xx, yy, 'k--', alpha=0.25, lw=2.0, label='pure power law')

# total char strain (no sources subtracted); hc_tt was computed with gws_by_ndars()
ax.plot(xx, hc_tt, color='black', marker='x', lw=3,
        ls='dotted', alpha=1, label='total, by ndars')

# by loops:
# gwb
ax.plot(xx, hc_bg_loop, color='b', marker='o', lw=3,
        ls='dotted', alpha=.5, label='background, loops')
# loudest source per bin
ax.scatter(xx, hc_ss_loop, color='b', marker='o', s=100,
           edgecolor='k', alpha=.5, label='single source, loops')

# by ndars:
# gwb
ax.plot(xx, hc_bg_ndar, color='g', marker='o', lw=3,
        ls='dotted', alpha=.5, label='background, ndars')
# loudest source per bin
ax.scatter(xx, hc_ss_ndar, color='g', marker='o',
           edgecolor='k', alpha=.5, label='single source, ndars')

# Background and loudest single source combined in quadrature.
# `\sqrt{...}` makes mathtext draw a radical instead of the italic letters "sqrt".
ax.plot(xx, np.sqrt(hc_bg_ndar**2 + hc_ss_ndar**2), color='r',
        label=r'$\sqrt{bg^2 + ss^2}$, ndars', alpha=.5)
# alternative using the loop results:
# ax.plot(xx, np.sqrt(hc_bg_loop**2 + hc_ss_loop**2), color='r',
#         label=r'$\sqrt{bg^2 + ss^2}$, loops', alpha=.5)

legend_gwb = ax.legend(bbox_to_anchor=(.95, .95),
                       bbox_transform=fig.transFigure, loc='upper right')

# optional zoom: ax.set_ylim(1e-16, 3e-15)
fig.tight_layout()