In [None]:
# %load ./init.ipy
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys

import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sams
import holodeck.gravwaves
from holodeck import cosmo, utils, plot, discrete, sams, relations
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC

# Silence annoying numpy errors
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# NOTE: `style.use('default')` resets the (non-blacklisted) rcParams back to matplotlib's
# defaults, so it has to run BEFORE the customizations below.  When it ran last, the font,
# line-cap, mathtext and grid settings were silently discarded.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Use holodeck's package logger at INFO verbosity throughout the notebook
log = holo.log
log.setLevel(logging.INFO)

In [None]:
def plot_bin_pop(pop):
    """Draw a kalepy corner plot of the binary parameters of a single population.

    The plotted parameters are log10 of: total mass [Msol], mass ratio, separation [pc],
    and (1 + redshift).  If `pop.eccen` is not None, the (linear) eccentricity is added as
    a fifth parameter.

    Parameters
    ----------
    pop : population object
        Must provide `mass`, `sepa`, `scafa` and `eccen` attributes.

    Returns
    -------
    corner : kale.Corner
        The corner-plot instance that the KDE was drawn on.

    """
    total_mass, mass_ratio = utils.mtmr_from_m1m2(pop.mass)
    one_plus_z = 1 + cosmo.a_to_z(pop.scafa)

    # Each entry: (linear quantity, KDE reflection bounds in log10-space, bare axis label)
    params = [
        (total_mass/MSOL, None, r'M/M_\odot'),
        (mass_ratio, [None, 0], 'q'),
        (pop.sepa/PC, None, r'a/\mathrm{{pc}}'),
        (one_plus_z, [0, None], '1+z'),
    ]

    data = []
    reflect = []
    labels = []
    for values, bounds, name in params:
        data.append(np.log10(values))
        reflect.append(bounds)
        labels.append(r'${{\log_{{10}}}} \left({}\right)$'.format(name))

    # Eccentricity is optional and is plotted linearly, reflected at both 0 and 1
    if pop.eccen is not None:
        data.append(pop.eccen)
        reflect.append([0.0, 1.0])
        labels.append('e')

    kde = kale.KDE(data, reflect=reflect)
    corner = kale.Corner(kde, labels=labels, figsize=[8, 8])
    corner.plot_data(kde)
    return corner

def compare_bin_pops(dpops, labels=None, var='mass', colors=None, lws=None,
                     density=True, hist=False, confidence=False):
    """Overplot the 1D distribution of one binary parameter for several populations.

    A new figure is created and one `kale.dist1d` curve is drawn per population.

    Parameters
    ----------
    dpops : list
        Population wrappers.  Each element must expose `.pop` (with `mass`, `sepa` and
        `scafa`) and the plotting attributes `.lbl`, `.color` and `.lw`.
    labels : str, list of str, or None
        Legend labels.  A single string is applied to every population.  If None (or of
        the wrong length), each population's own `.lbl` is used.
    var : {'mass', 'mrat', 'sepa', 'redz'}
        Quantity to plot: log10 total mass [Msol], mass ratio, log10 separation [pc], or
        log10(1+z).
    colors, lws : list or None
        Per-population line colors / widths.  If None (or of the wrong length), each
        population's own `.color` / `.lw` is used.
    density, hist, confidence :
        Passed through to `kale.dist1d`.

    Notes
    -----
    `labels`, `colors` and `lws` fall back to the population attributes independently of
    each other, so any subset of them may be given.

    """
    assert isinstance(dpops, list), '`dpops` must be a list of discrete populations'
    assert var in ['mass','mrat','sepa','redz'], "`var` must be 'mass','mrat','sepa', or 'redz'."

    npops = len(dpops)

    # ---- Validate the optional per-population style arguments
    if isinstance(labels, str):
        # a bare string would otherwise be indexed character-by-character below
        labels = [labels] * npops
    elif (labels is not None) and (len(labels) != npops):
        print("Warning: `labels` must be a str or a list of length len(dpops). Setting to None.")
        labels = None
    if (colors is not None) and (len(colors) != npops):
        print("Warning: `colors` must be a list of length len(dpops). Setting to None.")
        colors = None
    if (lws is not None) and (len(lws) != npops):
        print("Warning: `lws` must be a list of length len(dpops). Setting to None.")
        lws = None

    xlabels = {
        'mass': r'$\log_{10}(M_{tot})$',
        'mrat': r'$q$',
        'sepa': r'$\log_{10}(a) [pc]$',
        'redz': r'$\log_{10}(1+z)$',
    }

    fig, ax = plt.subplots(figsize=[10, 5])
    ax.set(xlabel=xlabels[var], ylabel='Probability Density')
    ax.grid(alpha=0.01)

    for ii, dp in enumerate(dpops):
        # ---- Select the quantity to plot
        if var in ('mass', 'mrat'):
            mt, mr = utils.mtmr_from_m1m2(dp.pop.mass)
            data = np.log10(mt/MSOL) if var == 'mass' else mr
        elif var == 'sepa':
            data = np.log10(dp.pop.sepa/PC)
        else:
            # take the log so that the data matches the `log10(1+z)` axis label
            data = np.log10(1 + cosmo.a_to_z(dp.pop.scafa))

        kale.dist1d(data, density=density, hist=hist, confidence=confidence,
                    label=dp.lbl if labels is None else labels[ii],
                    color=dp.color if colors is None else colors[ii],
                    lw=dp.lw if lws is None else lws[ii])

def compare_bhmfs(dpops):
    """Plot the binary total-mass function of each population on a shared log-y axis.

    For every population, log10(total mass / Msol) is histogrammed into 12 bins over
    [5.25, 11.25] and normalized by the bin width and by the sample volume in Mpc^3
    (taken from `dp.evo._sample_volume`).

    Parameters
    ----------
    dpops : list
        Population wrappers exposing `.pop.mass`, `.evo._sample_volume`, and the plotting
        attributes `.color`, `.lw` and `.lbl`.

    """
    assert isinstance(dpops, list), '`dpops` must be a list of discrete populations'

    fig, ax = plt.subplots(figsize=[10, 5])
    ax.grid(alpha=0.01)
    ax.set(yscale='log', xlabel=r'$\log_{10}(M_{tot})$',
           ylabel=r'BHMF [$(\log_{10} M)^{-1} Mpc^{-3}$]')

    for dp in dpops:
        mt, _ = utils.mtmr_from_m1m2(dp.pop.mass)
        lg_mtot = np.log10(mt/MSOL)

        counts, edges = np.histogram(lg_mtot, range=(5.25,11.25), bins=12)
        widths = np.diff(edges)
        # plot at the bin centers (left edge + a full bin width would be the right edge)
        centers = 0.5 * (edges[:-1] + edges[1:])

        box_vol_mpc = dp.evo._sample_volume / (1.0e6*PC)**3
        print(box_vol_mpc)
        ax.plot(centers, counts/widths/box_vol_mpc, color=dp.color, lw=dp.lw, label=dp.lbl)

def compare_gsmfs(dpops):
    """Plot the bulge-mass function of the galaxies in each population on a log-y axis.

    For every population, log10(`pop.mbulge` / Msol) is histogrammed into 20 bins over
    [7.25, 14.25] and normalized by the bin width, by ln(10), and by the sample volume in
    Mpc^3 (taken from `dp.evo._sample_volume`).

    Parameters
    ----------
    dpops : list
        Population wrappers exposing `.pop.mbulge`, `.evo._sample_volume`, and the plotting
        attributes `.color`, `.lw` and `.lbl`.

    """
    assert isinstance(dpops, list), '`dpops` must be a list of binary populations'

    fig, ax = plt.subplots(figsize=[10, 5])
    ax.grid(alpha=0.01)
    ax.set(yscale='log', xlabel=r'$\log_{10}(M_{*})$', ylabel=r'GSMF [$dex^{-1} Mpc^{-3}$]')

    for dp in dpops:
        lg_mbulge = np.log10(dp.pop.mbulge/MSOL)
        box_vol_mpc = dp.evo._sample_volume / (1.0e6*PC)**3

        counts, edges = np.histogram(lg_mbulge, range=(7.25,14.25), bins=20)
        widths = np.diff(edges)
        # plot at the bin centers (left edge + a full bin width would be the right edge)
        centers = 0.5 * (edges[:-1] + edges[1:])

        # NOTE(review): the extra 1/ln(10) converts a per-dex density into a per-ln(M)
        # density, while the axis label says dex^-1 -- confirm the intended normalization.
        ax.plot(centers, counts/widths/np.log(10)/box_vol_mpc, color=dp.color,
                lw=dp.lw, label=dp.lbl+' (merging gals, all z)')

def compare_gal_merger_rate(dpops, colors=None, lws=None, labels=None,
                            sam_compare=None, gpf_flag=None, mcut=None, qcut=None):
    """Plot the galaxy merger rate of one or more SAMs, marginalized onto mass, ratio and z.

    Three log-scaled panels are drawn: rate vs. log10 stellar mass, vs. stellar mass ratio,
    and vs. (1+z).  Each curve is the SAM's `_gal_merger_rate` (converted to Gyr^-1)
    integrated over the other two grid dimensions.

    Parameters
    ----------
    dpops, colors, lws, labels :
        Currently unused; kept for a signature parallel to `compare_gal_merger_dens`.
    sam_compare : SAM instance, list/tuple of them, or None
        Objects providing `mass_stellar()` and `_gal_merger_rate(mstar, mrat, redz)`.
        If None, only the empty axes are created.
    gpf_flag : bool or list/tuple of bool
        One flag per SAM; selects the legend label ('gpf' if true, otherwise 'gmr').
        Required when `sam_compare` is given.
    mcut : scalar or None
        If given, grid points with stellar mass below `mcut` are dropped (same units as
        the grid returned by `mass_stellar()`).
    qcut : scalar or None
        If given, grid points with stellar mass ratio below `qcut` are dropped.

    Raises
    ------
    ValueError
        If fewer `gpf_flag` entries than SAMs are given, or if a cut removes the whole grid.

    """
    fig = plt.figure(figsize=(12,3))
    ax1 = fig.add_subplot(131)
    ax1.set_yscale('log')
    ax1.set_xlabel(r'$\log_{10}(m_{*1})$')
    ax2 = fig.add_subplot(132)
    ax2.set_yscale('log')
    ax2.set_xscale('log')
    ax2.set_xlabel(r'$q_*$')
    ax3 = fig.add_subplot(133)
    ax3.set_xscale('log')
    ax3.set_yscale('log')
    ax3.set_xlabel('1+z')
    ax3.set_xlim(1,6)

    if sam_compare is None:
        return

    assert gpf_flag is not None, "gpf_flag must be defined if sam_compare is not None."

    if not isinstance(sam_compare, (list, tuple)):
        sam_compare = [sam_compare]
    if not isinstance(gpf_flag, (list, tuple)):
        gpf_flag = [gpf_flag]
    if len(gpf_flag) < len(sam_compare):
        raise ValueError(f"`gpf_flag` needs one entry per SAM "
                         f"({len(gpf_flag)} flags for {len(sam_compare)} SAMs).")

    # Line styles are cycled so that more than two SAMs can be compared
    ls_arr = ['-.','-']
    lw_arr = [2,3]
    color_arr = ['m','k']

    for i, s in enumerate(sam_compare):
        lbl = 'gpf' if gpf_flag[i] else 'gmr'
        print(f"{i=}, {gpf_flag[i]=}, {lbl=}")
        style = dict(ls=ls_arr[i % len(ls_arr)], lw=lw_arr[i % len(lw_arr)],
                     color=color_arr[i % len(color_arr)], label=lbl)

        # The total stellar mass is used as the mass coordinate (not the primary mass)
        _, mstar_rat, mstar, redz = s.mass_stellar()
        # 1D grid axes, extracted from the 3D (mass, ratio, redshift) grids
        mstar_arr = mstar[:,-1,0]
        mstar_rat_arr = mstar_rat[0,:,0]
        redz_arr = redz[0,0,:]

        # ---- Apply the optional cuts; assumes the grid axes are sorted ascending -- TODO confirm
        if mcut is not None:
            keep = np.flatnonzero(mstar_arr >= mcut)
            if keep.size == 0:
                raise ValueError("`mcut` is larger than every stellar mass in the SAM grid.")
            ix = keep[0]
            mstar_arr = mstar_arr[ix:]
            mstar = mstar[ix:,:,:]
            mstar_rat = mstar_rat[ix:,:,:]
            redz = redz[ix:,:,:]
            print(f"imposed mass cut >{mcut/MSOL} msun. "
                  f"new mass array shape: {mstar_arr.shape}")
        if qcut is not None:
            keep = np.flatnonzero(mstar_rat_arr >= qcut)
            if keep.size == 0:
                raise ValueError("`qcut` is larger than every mass ratio in the SAM grid.")
            ixq = keep[0]
            mstar_rat_arr = mstar_rat_arr[ixq:]
            mstar = mstar[:,ixq:,:]
            mstar_rat = mstar_rat[:,ixq:,:]
            redz = redz[:,ixq:,:]
            # `qcut` is a dimensionless ratio: report it as-is (not divided by MSOL)
            print(f"imposed mass ratio cut >{qcut}. "
                  f"new mass ratio array shape: {mstar_rat_arr.shape}")

        gal_merger_rate = s._gal_merger_rate(mstar, mstar_rat, redz) * GYR # convert from s^-1 to Gyr^-1

        print(f'sam gal merg rate shape: {gal_merger_rate.shape}')
        print(gal_merger_rate.min(),gal_merger_rate.max())

        # NOTE(review): `utils.trapz(..., cumsum=False)` presumably returns the per-interval
        # contributions (hence the explicit sums below) -- confirm against holodeck.
        lg_mstar = np.log10(mstar_arr/MSOL)
        integ_over_m = utils.trapz(gal_merger_rate, lg_mstar, axis=0, cumsum=False)
        integ_over_q = utils.trapz(gal_merger_rate, mstar_rat_arr, axis=1, cumsum=False)

        # ---- Rate vs. redshift: integrate over mass and mass ratio
        print(f"integ over mstar_arr: {integ_over_m.shape}")
        sam_gal_mrg_rate_vs_z = utils.trapz(integ_over_m, mstar_rat_arr, axis=1, cumsum=False)
        print(f"integ over mstar_rat_arr: {sam_gal_mrg_rate_vs_z.shape}")
        sam_gal_mrg_rate_vs_z = sam_gal_mrg_rate_vs_z.sum(axis=0)
        print(f"sum over mstar_arr: {sam_gal_mrg_rate_vs_z.shape}")
        sam_gal_mrg_rate_vs_z = sam_gal_mrg_rate_vs_z.sum(axis=0)
        print(f"sum over mstar_rat_arr: {sam_gal_mrg_rate_vs_z.shape}")

        # ---- Rate vs. mass: integrate over mass ratio and redshift
        sam_gal_mrg_rate_vs_m = utils.trapz(integ_over_q, redz_arr, axis=2, cumsum=False)
        sam_gal_mrg_rate_vs_m = sam_gal_mrg_rate_vs_m.sum(axis=1)
        sam_gal_mrg_rate_vs_m = sam_gal_mrg_rate_vs_m.sum(axis=1)

        # ---- Rate vs. mass ratio: integrate over mass and redshift
        sam_gal_mrg_rate_vs_q = utils.trapz(integ_over_m, redz_arr, axis=2, cumsum=False)
        sam_gal_mrg_rate_vs_q = sam_gal_mrg_rate_vs_q.sum(axis=0)
        sam_gal_mrg_rate_vs_q = sam_gal_mrg_rate_vs_q.sum(axis=1)

        ax1.plot(lg_mstar, sam_gal_mrg_rate_vs_m, **style)
        ax2.plot(mstar_rat_arr, sam_gal_mrg_rate_vs_q, **style)
        ax3.plot(1+redz_arr, sam_gal_mrg_rate_vs_z, **style)

    ax1.legend()
    ax2.legend()
    ax3.legend()
        
def compare_gal_merger_dens(dpops, colors=None, lws=None, labels=None,
                            sam_compare=None, gpf_flag=None, mcut=None, qcut=None):
    """Compare galaxy-merger number densities of discrete populations against SAMs.

    Three panels are drawn: the differential number density vs. log10 stellar mass, vs.
    stellar mass ratio, and vs. redshift.  SAM curves come from `_ndens_gal` integrated
    over the other two grid dimensions; discrete-population curves are 20-bin histograms
    normalized by bin width and by the sample volume in Mpc^3.

    Parameters
    ----------
    dpops : iterable
        Population wrappers exposing `.pop.mbulge`, `.pop.scafa`, `.evo._sample_volume`
        and `.lbl`.  `pop.mbulge` is reduced along axis 1 to get primary/secondary masses.
    colors, lws, labels :
        Currently unused.
    sam_compare : SAM instance, list/tuple of them, or None
        Objects providing `mass_stellar()` and `_ndens_gal(mstar, mrat, redz)`.
    gpf_flag : bool or list/tuple of bool
        One flag per SAM; selects the legend label ('gpf' if true, otherwise 'gmr').
        Required when `sam_compare` is given.
    mcut : scalar or None
        If given, SAM grid points with stellar mass below `mcut` are dropped.
    qcut : scalar or None
        If given, SAM grid points with stellar mass ratio below `qcut` are dropped.

    Raises
    ------
    ValueError
        If fewer `gpf_flag` entries than SAMs are given, or if a cut removes the whole grid.

    """
    fig = plt.figure(figsize=(14,5))
    ax1 = fig.add_subplot(131)
    ax1.set_xlabel(r'$\log_{10}(m_{*1})$')
    ax1.set_ylabel(r'$\frac{d\eta_{gal-gal}}{d\log_{10}m_{*1}}$ [Mpc$^{-3}$]')

    ax2 = fig.add_subplot(132)
    ax2.set_xlabel(r'$q_*$')
    ax2.set_ylabel(r'$d\eta_{gal-gal}$ / $dq_{*1}$ [Mpc$^{-3}$]')
    ax2.set_xscale('log')

    ax3 = fig.add_subplot(133)
    ax3.set_xlabel('redshift')
    ax3.set_ylabel(r'$d\eta_{gal-gal}$ / $dz$ [Mpc$^{-3}$]')

    # Line styles are cycled so that more than two SAMs can be compared
    ls_arr = ['-.','-']
    lw_arr = [2,3]
    color_arr = ['m','k']

    if sam_compare is not None:
        assert gpf_flag is not None, "gpf_flag must be defined if sam_compare is not None."

        if not isinstance(sam_compare, (list, tuple)):
            sam_compare = [sam_compare]
        if not isinstance(gpf_flag, (list, tuple)):
            gpf_flag = [gpf_flag]
        if len(gpf_flag) < len(sam_compare):
            raise ValueError(f"`gpf_flag` needs one entry per SAM "
                             f"({len(gpf_flag)} flags for {len(sam_compare)} SAMs).")

        for i, s in enumerate(sam_compare):
            lbl = 'gpf' if gpf_flag[i] else 'gmr'
            print(f"{i=}, {gpf_flag[i]=}, {lbl=}")
            style = dict(ls=ls_arr[i % len(ls_arr)], lw=lw_arr[i % len(lw_arr)],
                         color=color_arr[i % len(color_arr)], label=lbl)

            # The total stellar mass is used as the mass coordinate (not the primary mass)
            _, mstar_rat, mstar, redz = s.mass_stellar()
            # 1D grid axes, extracted from the 3D (mass, ratio, redshift) grids
            mstar_arr = mstar[:,-1,0]
            mstar_rat_arr = mstar_rat[0,:,0]
            redz_arr = redz[0,0,:]
            print(f"{mstar.shape=}, {mstar_rat.shape=}, {redz.shape=}")
            print(f"{mstar_arr.shape=}, {mstar_rat_arr.shape=}, {redz_arr.shape=}")

            # ---- Apply the optional cuts; assumes the grid axes are sorted ascending -- TODO confirm
            if mcut is not None:
                keep = np.flatnonzero(mstar_arr >= mcut)
                if keep.size == 0:
                    raise ValueError("`mcut` is larger than every stellar mass in the SAM grid.")
                ixm = keep[0]
                print(f"mass cut index ={ixm}")
                mstar_arr = mstar_arr[ixm:]
                mstar = mstar[ixm:,:,:]
                mstar_rat = mstar_rat[ixm:,:,:]
                redz = redz[ixm:,:,:]
                print(f"imposed mass cut >{mcut/MSOL:.4g} msun. "
                      f"new mass array shape: {mstar_arr.shape}")

            if qcut is not None:
                keep = np.flatnonzero(mstar_rat_arr >= qcut)
                if keep.size == 0:
                    raise ValueError("`qcut` is larger than every mass ratio in the SAM grid.")
                ixq = keep[0]
                mstar_rat_arr = mstar_rat_arr[ixq:]
                mstar = mstar[:,ixq:,:]
                mstar_rat = mstar_rat[:,ixq:,:]
                redz = redz[:,ixq:,:]
                print(f"imposed mass ratio cut >{qcut} "
                      f"new mass ratio array shape: {mstar_rat_arr.shape}")

            print(f"{mstar.shape=}, {mstar_rat.shape=}, {redz.shape=}")
            print(f"{mstar_arr.shape=}, {mstar_rat_arr.shape=}, {redz_arr.shape=}")
            print(f"{mstar.shape=}, {mstar_rat.shape=}, {redz.shape=}")

            # The cuts need no further handling: the grids passed in are already trimmed
            nd = s._ndens_gal(mstar, mstar_rat, redz)
            print(f"{nd.shape=},{nd.min()=},{nd.max()=}")

            # NOTE(review): `utils.trapz(..., cumsum=False)` presumably returns the per-interval
            # contributions (hence the explicit sums below) -- confirm against holodeck.
            lg_mstar = np.log10(mstar_arr/MSOL)
            integ_over_m = utils.trapz(nd, lg_mstar, axis=0, cumsum=False)
            integ_over_q = utils.trapz(nd, mstar_rat_arr, axis=1, cumsum=False)

            # ---- Density vs. redshift: integrate over mass and mass ratio
            print(f"integ over mstar_arr: {integ_over_m.shape}")
            sam_dngalgal_dz = utils.trapz(integ_over_m, mstar_rat_arr, axis=1, cumsum=False)
            print(f"integ over mstar_rat_arr: {sam_dngalgal_dz.shape}")
            sam_dngalgal_dz = sam_dngalgal_dz.sum(axis=0)
            print(f"sum over mstar_arr: {sam_dngalgal_dz.shape}")
            sam_dngalgal_dz = sam_dngalgal_dz.sum(axis=0)
            print(f"sum over mstar_rat_arr: {sam_dngalgal_dz.shape}")
            print(f"total sam_dngalgal (from dz): {sam_dngalgal_dz.sum()}")

            # ---- Density vs. mass: integrate over mass ratio and redshift
            sam_dngalgal_dmstar = utils.trapz(integ_over_q, redz_arr, axis=2, cumsum=False)
            sam_dngalgal_dmstar = sam_dngalgal_dmstar.sum(axis=1)
            sam_dngalgal_dmstar = sam_dngalgal_dmstar.sum(axis=1)
            print(f"total sam_dngalgal (from dmstar): {sam_dngalgal_dmstar.sum()}")

            # ---- Density vs. mass ratio: integrate over mass and redshift
            sam_dngalgal_dmstar_rat = utils.trapz(integ_over_m, redz_arr, axis=2, cumsum=False)
            sam_dngalgal_dmstar_rat = sam_dngalgal_dmstar_rat.sum(axis=0)
            sam_dngalgal_dmstar_rat = sam_dngalgal_dmstar_rat.sum(axis=1)
            print(f"total sam_dngalgal (from dmstar_rat): {sam_dngalgal_dmstar_rat.sum()}")

            ax1.plot(lg_mstar, sam_dngalgal_dmstar, **style)
            ax2.plot(mstar_rat_arr, sam_dngalgal_dmstar_rat, **style)
            ax3.plot(redz_arr, sam_dngalgal_dz, **style)

    for dp in dpops:
        box_vol_mpc = dp.evo._sample_volume / (1.0e6*PC)**3

        # Primary / secondary are the larger / smaller bulge mass of each pair
        mstar = dp.pop.mbulge
        mstar1 = np.max(mstar, axis=1)
        mstar2 = np.min(mstar, axis=1)
        mstar_tot = mstar1 + mstar2
        qstar = mstar2 / mstar1

        redz = 1.0/dp.pop.scafa - 1
        dngalgal_tot = redz.size / box_vol_mpc # total galaxy merger number density in cMpc^-3
        print(f'dngalgal_tot = {dngalgal_tot}')

        z_hist, z_bin_edges = np.histogram(redz, bins=20)
        z_binsize = np.diff(z_bin_edges)
        dngalgal_dz = z_hist / z_binsize / box_vol_mpc

        # The total (not primary) stellar mass is histogrammed, matching the SAM curves
        lgm1_hist, lgm1_bin_edges = np.histogram(np.log10(mstar_tot/MSOL), bins=20)
        lgm1_binsize = np.diff(lgm1_bin_edges)
        dngalgal_dlog10m1 = lgm1_hist / lgm1_binsize / box_vol_mpc

        qstar_hist, qstar_bin_edges = np.histogram(qstar, bins=20)
        qstar_binsize = np.diff(qstar_bin_edges)
        dngalgal_dqstar = qstar_hist / qstar_binsize / box_vol_mpc
        print(f'dngalgal_tot = {dngalgal_tot/qstar_binsize[0]/lgm1_binsize[0]/z_binsize[0]}')

        ax1.set_ylim(1.0e-6,0.5)
        ax1.set_yscale('log')
        ax2.set_yscale('log')
        ax3.set_yscale('log')
        # Histogram values are plotted at the bin centers (not at the left bin edges)
        ax1.plot(0.5*(lgm1_bin_edges[:-1] + lgm1_bin_edges[1:]), dngalgal_dlog10m1, label=dp.lbl)
        ax2.plot(0.5*(qstar_bin_edges[:-1] + qstar_bin_edges[1:]), dngalgal_dqstar, label=dp.lbl)
        ax3.plot(0.5*(z_bin_edges[:-1] + z_bin_edges[1:]), dngalgal_dz, label=dp.lbl)

    ax1.legend()
    ax2.legend()
    ax3.legend()
    fig.subplots_adjust(wspace=0.5)
                        
def plot_mbh_scaling_relations(pop, fname=None, color='r', compare_pops=None, ncols=1, nrows=1,
                               xlim=None, ylim=None):
    """Plot BH mass against stellar (bulge) mass for one or more populations.

    One panel is drawn per population; each shows the contours produced by
    `_draw_pop_masses`, optionally on top of the McConnell+Ma (2013) measurements.

    Parameters
    ----------
    pop : population object
        Primary population, drawn in the first panel.
    fname : object or None
        If not None, the McConnell+Ma data are overplotted in every panel.  Only used as an
        on/off switch: `_draw_MM2013_data` is called with the axes alone.
    color : matplotlib color
        Color of the population contours (same in all panels).
    compare_pops : list or None
        Additional populations, one further panel each.  If None, a single panel is made
        and `ncols` / `nrows` are ignored.
    ncols, nrows : int
        Panel grid; must provide at least `1 + len(compare_pops)` panels.
    xlim, ylim : (2,) or None
        Axis limits applied to every panel.

    Returns
    -------
    fig : matplotlib figure

    Raises
    ------
    ValueError
        If there are more populations than panels.

    """
    units = r"$[\log_{10}(M/M_\odot)]$"

    if compare_pops is None:
        pops = [pop]
        fig, axes = plt.subplots(figsize=[12, 6])
    else:
        pops = [pop] + list(compare_pops)
        # validate before creating the figure, so that a failure leaves no empty figure behind
        if len(pops) > ncols*nrows:
            raise ValueError(f"len(pops)>ncols*nrows. ({len(pops)} > {ncols}*{nrows}).")
        fig, axes = plt.subplots(figsize=[12,5], ncols=ncols, nrows=nrows)

    for i, (idx, ax) in enumerate(np.ndenumerate(axes)):
        # leave surplus panels empty when the grid is larger than the number of populations
        if i >= len(pops):
            break

        print(f"idx: {idx}, i: {i}")

        #   ====    Plot McConnell+Ma-2013 Data    ====
        handles = []
        names = []
        if fname is not None:
            hh = _draw_MM2013_data(ax)
            handles.append(hh)
            names.append('McConnell+Ma')

        # Set labels after the data helper, which applies its own generic axis labels
        ax.set(xlabel=f'Stellar Mass {units}', ylabel=f'BH Mass {units}')
        if xlim is not None:
            ax.set(xlim=xlim)
        if ylim is not None:
            ax.set(ylim=ylim)

        #   ====    Plot MBH Merger Data    ====
        hh, nn = _draw_pop_masses(ax, pops[i], color, nplot=1e6)
        handles = handles + hh
        names = names + nn
        ax.legend(handles, names)

    return fig

def _draw_MM2013_data(ax, fname=None):
    """Draw the McConnell+Ma (2013) BH-mass vs. bulge-mass measurements with error bars.

    The data are loaded with `holo.observations.load_mcconnell_ma_2013()` and every column
    except 'name' is converted to log10 before plotting.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on; its x/y labels are overwritten.
    fname : object, optional
        Ignored.  Accepted only so that callers passing a second positional argument
        (as `plot_mbh_scaling_relations` does in this file) do not raise a TypeError.

    Returns
    -------
    handle : matplotlib artist
        The scatter collection, suitable as a legend handle.

    """
    data = holo.observations.load_mcconnell_ma_2013()
    data = {kk: data[kk] if kk == 'name' else np.log10(data[kk]) for kk in data.keys()}
    key = 'mbulge'

    # 'mass' is indexed as three columns; presumably (low, value, high) -- TODO confirm
    mass = data['mass']
    yy = mass[:, 1]
    yerr = np.array([yy - mass[:, 0], mass[:, 2] - yy])

    vals = data[key]
    if np.ndim(vals) == 1:
        # values only
        xx = vals
        xerr = None
    elif vals.shape[1] == 2:
        # (value, symmetric error)
        xx = vals[:, 0]
        xerr = vals[:, 1]
    elif vals.shape[1] == 3:
        # same three-column layout as 'mass'; converted to (lower, upper) offsets
        xx = vals[:, 1]
        xerr = np.array([xx-vals[:, 0], vals[:, 2]-xx])
    else:
        raise ValueError()

    # Keep only points with positive log-values on both axes (also drops NaNs)
    idx = (xx > 0.0) & (yy > 0.0)
    if xerr is not None:
        # index the last axis so both the 1D (symmetric) and (2, N) (asymmetric) cases work
        xerr = xerr[..., idx]
    ax.errorbar(xx[idx], yy[idx], xerr=xerr, yerr=yerr[:, idx], fmt='none', zorder=10)
    handle = ax.scatter(xx[idx], yy[idx], zorder=10)
    ax.set(ylabel='MBH Mass', xlabel=key)

    return handle

def _draw_pop_masses(ax, pop, color='r', nplot=3e3):
    """Draw 2D contours of log10 BH mass vs. log10 bulge mass for a population.

    If the population has a `_mass` attribute, a second set of contours is drawn for it in
    gray and labeled 'old'; the current `mass` is labeled 'new'.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    pop : population object
        Must provide `mbulge` and `mass` (both in grams, given the division by MSOL).
    color : matplotlib color
        Color of the contours for `pop.mass`.
    nplot : scalar
        Maximum number of points used; larger populations are randomly subsampled
        (without replacement).

    Returns
    -------
    handles : list of matplotlib lines
        Empty proxy lines, one per contour set, for use in a legend.
    names : list of str
        Matching legend labels.

    """
    print(pop.mbulge.shape, pop.mass.shape)
    # only the first two columns of `mbulge` are used
    xx = pop.mbulge[:,:2].flatten() / MSOL
    yy_list = [pop.mass]
    names = ['new']
    if hasattr(pop, '_mass'):
        yy_list.append(pop._mass)
        names.append('old')

    colors = [color, '0.5']
    handles = []
    if xx.size > nplot:
        cut = np.random.choice(xx.size, int(nplot), replace=False)
        print("Plotting {:.1e}/{:.1e} data-points".format(nplot, xx.size))
    else:
        cut = slice(None)

    for ii, yy in enumerate(yy_list):
        yy = yy.flatten() / MSOL
        data = np.log10([xx[cut], yy[cut]])
        kale.plot.dist2d(
            data, ax=ax, color=colors[ii], hist=False, contour=True,
            median=True, mask_dense=True,
        )
        # Proxy artist for the legend; created on `ax` itself rather than on whichever
        # axes happens to be current (which differs in multi-panel figures)
        hh, = ax.plot([], [], color=colors[ii])
        handles.append(hh)

    return handles, names

def plot_evo(evo, freqs=None, sepa=None, ax=None, label=None, color=None, **kwargs):
    """Plot the median and interquartile range of the binary hardening time.

    The hardening time, |a / (da/dt)| in years, is evaluated at the given GW frequencies
    or binary separations using `evo.at(...)`.

    Parameters
    ----------
    evo : evolution object
        Must provide `at(name, values)` returning a mapping with 'sepa' and 'dadt'.
    freqs : array or None
        Observed GW frequencies (converted to [1/yr] for the x-axis).  Takes precedence
        over `sepa` if both are given.
    sepa : array or None
        Binary separations (converted to [pc] for the x-axis).
    ax : matplotlib axes or None
        Axes to draw on; if None a new figure is created with `plot.figax`.
    label : str or None
        Currently has no effect: the y-axis label is always 'Hardening Time [yr]'.
    color : matplotlib color or None
        Color of the curve, y-label and y-ticks; if None the next color in the axes'
        cycle is used.
    **kwargs :
        Accepted but unused.

    Returns
    -------
    ax : matplotlib axes

    Raises
    ------
    ValueError
        If neither `freqs` nor `sepa` is given.

    """
    if (freqs is None) and (sepa is None):
        err = "Either `freqs` or `sepa` must be provided!"
        # `log.exception` is only meaningful inside an exception handler; use `error` here
        log.error(err)
        raise ValueError(err)

    if freqs is not None:
        data = evo.at('fobs', freqs)
        xx = freqs * YR
        xlabel = 'GW Frequency [1/yr]'
    else:
        data = evo.at('sepa', sepa)
        xx = sepa / PC
        xlabel = 'Binary Separation [pc]'

    if ax is None:
        _, ax = plot.figax(xlabel=xlabel)

    def _draw_vals_conf(ax, xx, vals, color=color, label=label):
        """Draw the median line and 25-75% band of `vals`, coloring the y-axis to match."""
        if color is None:
            color = ax._get_lines.get_next_color()
        if label is not None:
            ax.set_ylabel(label, color=color)
            ax.tick_params(axis='y', which='both', colors=color)
        # quantiles are taken over axis 0, then transposed so that rows are (25%, 50%, 75%)
        vals = utils.quantiles(vals, [0.25, 0.50, 0.75], axis=0).T
        h1 = ax.fill_between(xx, vals[0], vals[-1], alpha=0.2, color=color)
        h2, = ax.plot(xx, vals[1], alpha=0.5, lw=2.0, color=color)
        return (h1, h2)

    name = 'Hardening Time [yr]'
    vals = np.fabs(data['sepa'] / data['dadt']) / YR
    _draw_vals_conf(ax, xx, vals, label=name)

    return ax

In [None]:
class Discrete:
    """Bundle a discrete binary population with its evolution and GW background.

    On construction the instance:
      1. loads a population with `discrete.population.Pop_Illustris`,
      2. optionally resets the masses using the `MMBulge_KH2013` relation,
      3. builds a `Fixed_Time_2PL` hardening model with total lifetime `tau`,
      4. evolves the population, and
      5. computes the GW signal with `GW_Discrete`.

    The plotting attributes (`lbl`, `color`, `lw`) are stored alongside for convenience.
    """

    def __init__(self, freqs, freqs_edges, attrs=(None,'k',1.0), lbl=None, fixed_sepa=None, 
                 tau=2.0*YR, nreals=500, mod_mmbulge=False, rescale_mbulge=False):
        """
        Parameters
        ----------
        freqs, freqs_edges : array
            GW frequencies (and the corresponding bin edges); `freqs` is passed to `GW_Discrete`.
        attrs : sequence
            `(filename, plot color, plot linewidth)`; indexed by position.
        lbl : str or None
            Name / legend label of this population.
        fixed_sepa : scalar or None
            Passed to `Pop_Illustris` as `fixed_sepa`.
        tau : scalar
            Total binary lifetime passed to `Fixed_Time_2PL.from_pop`.
        nreals : int
            Number of realizations passed to `GW_Discrete`.
        mod_mmbulge : bool
            If True, apply a `PM_Mass_Reset` modifier based on `MMBulge_KH2013`.
        rescale_mbulge : bool
            Passed to `PM_Mass_Reset`; only used if `mod_mmbulge` is True.

        """
        # ---- Store the configuration
        self.freqs = freqs
        self.freqs_edges = freqs_edges
        self.lbl = lbl
        self.attrs = attrs
        self.fname = attrs[0]
        self.color = attrs[1]
        self.lw = attrs[2]
        self.fixed_sepa = fixed_sepa
        self.tau = tau
        self.nreals = nreals
        self.mod_mmbulge = mod_mmbulge

        print(f"\nCreating Discrete_Pop class instance '{self.lbl}' with tau={self.tau}, fixed_sepa={self.fixed_sepa}")
        print(f" fname={self.fname}")

        # ---- Load the population
        self.pop = discrete.population.Pop_Illustris(
            fname=self.fname, fixed_sepa=self.fixed_sepa
        )

        # ---- Apply the mass modifier, if requested
        if self.mod_mmbulge == True:
            self.mmbulge = holo.relations.MMBulge_KH2013()
            self.mod_KH2013 = discrete.population.PM_Mass_Reset(
                self.mmbulge, scatter=True, rescale_mbulge=rescale_mbulge
            )
            self.pop.modify(self.mod_KH2013)

        # ---- Fixed-total-time hardening model
        print("modeling fixed-total-time hardening...")
        self.fixed = holo.hardening.Fixed_Time_2PL.from_pop(self.pop, self.tau)

        # ---- Evolve the population under that hardening model
        print("creating evolution instance and evolving it...")
        self.evo = discrete.evolution.Evolution(self.pop, self.fixed)
        self.evo.evolve()
        print("vol:",self.evo._sample_volume)

        # ---- Compute the GW signal
        self.gwb = holo.gravwaves.GW_Discrete(self.evo, self.freqs, nreals=self.nreals)
        self.gwb.emit()

    def _nearest_freq_index(self, target):
        """Return the index array of the `self.freqs` entries closest to `target`."""
        offset = np.abs(self.freqs - target)
        return np.where(offset == offset.min())[0]

    def get_amplitudes_at_freqs(self, select_freqs=None):
        """Store the GW background values at the bins nearest 1/yr, 1/(3yr) and 1/(10yr).

        Sets `idx_ayr`, `idx_a3yr`, `idx_a10yr` (index arrays into `self.freqs`) and
        `ayr`, `a3yr`, `a10yr` (the flattened rows of `self.gwb.back` at those indices).
        `select_freqs` is not supported yet; passing it only prints a notice.
        """
        if (select_freqs is not None):
            print("sorry this function sucks, you cannot select freqs yet. choosing 1/yr, 1/3yr, 1/10yr.")

        # ---- find frequency bins closest to 1/yr, 1/(3yr), 1/(10yr)
        self.idx_ayr = self._nearest_freq_index(1/YR)
        self.idx_a3yr = self._nearest_freq_index(1/(3*YR))
        self.idx_a10yr = self._nearest_freq_index(1/(10*YR))
        print(self.idx_ayr,self.idx_a3yr,self.idx_a10yr)

        back = self.gwb.back
        self.ayr = back[self.idx_ayr,:].flatten()
        self.a3yr = back[self.idx_a3yr,:].flatten()
        self.a10yr = back[self.idx_a10yr,:].flatten()




## Create Discrete Populations, each including a simple binary-evolution model

In [None]:
def create_dpops(tau=1.0, fsa=1.0e4, inclIll=True, inclOldIll=False, inclT50=True,
                 inclT300=True, inclRescale=False):
    """Build a `Discrete` population for every selected simulation merger catalog.

    Parameters
    ----------
    tau : float
        Fixed binary lifetime in Gyr (multiplied by `GYR` before use).
    fsa : float or None
        Fixed initial binary separation in pc (multiplied by `PC` before use).
        When None, no fixed-separation populations are built.
    inclIll : bool
        Include the populations whose label contains 'Ill'.
    inclOldIll : bool
        Include the 'oldIll' population (it is also subject to `inclIll`).
    inclT50, inclT300 : bool
        Include the TNG50 / TNG300 populations.
    inclRescale : bool
        With `fsa` set and TNG300 included, also build a TNG300 population with
        `mod_mmbulge=True, rescale_mbulge=True`; it is appended to the TNG
        fixed-separation list only.

    Returns
    -------
    (all_dpops, tng_dpops, all_fsa_dpops, tng_fsa_dpops) when `fsa` is not None,
    otherwise (all_dpops, tng_dpops).  The `tng_*` lists hold the populations
    whose label does not contain 'Ill'.
    """
    # ---- Set the fixed binary lifetime
    print(f"Setting inspiral timescale tau = {tau} Gyr.")
    tau = tau * GYR

    # ---- Define the GWB frequencies
    freqs, freqs_edges = utils.pta_freqs()

    # ---- Initialize return variables
    all_dpops, tng_dpops = [], []

    # ---- (Optionally) set the fixed initial binary separation & its return lists
    use_fsa = fsa is not None
    if use_fsa:
        print(f"Setting fixed init binary sep = {fsa} pc.")
        fsa = fsa * PC
        all_fsa_dpops, tng_fsa_dpops = [], []

    # ---- Define dpop attributes: (filename, plot color, plot linewidth)
    dpop_attrs = {
        'TNG50-1' : ('galaxy-mergers_TNG50-1_gas-800_dm-800_star-800_bh-001.hdf5', 'r', 3.5),
        'TNG50-2' : ('galaxy-mergers_TNG50-2_gas-100_dm-100_star-100_bh-001.hdf5', 'orange', 2.5),
        'TNG50-3' : ('galaxy-mergers_TNG50-3_gas-012_dm-012_star-012_bh-001.hdf5', 'y', 1.5),
        'oldIll' : (None, 'brown', 2.5),
        'newIll' : ('galaxy-mergers_Illustris-1_gas-100_dm-100_star-100_bh-001.hdf5', 'g', 2.5),
        'TNG100-1' : ('galaxy-mergers_TNG100-1_gas-100_dm-100_star-100_bh-001.hdf5', 'b', 2.5),
        'TNG100-2' : ('galaxy-mergers_TNG100-2_gas-012_dm-012_star-012_bh-001.hdf5', 'c', 1.5),
        'TNG300-1' : ('galaxy-mergers_TNG300-1_gas-012_dm-012_star-012_bh-001.hdf5', 'm', 1.5)
    }

    # ---- Loop thru dict and create dpops
    for label, attrs in dpop_attrs.items():
        is_illustris = 'Ill' in label
        is_tng300 = 'TNG300' in label

        excluded = (
            (is_illustris and not inclIll)
            or (label == 'oldIll' and not inclOldIll)
            or ('TNG50' in label and not inclT50)
            or (is_tng300 and not inclT300)
        )
        if excluded:
            continue

        # Population with the catalog's own initial separations
        dp = Discrete(freqs, freqs_edges, lbl=label, tau=tau, fixed_sepa=None, attrs=attrs)
        all_dpops.append(dp)
        if not is_illustris:
            tng_dpops.append(dp)

        if use_fsa:
            # Same catalog, with every binary started at the fixed separation
            dp_fsa = Discrete(freqs, freqs_edges, lbl='fsa'+label, tau=tau,
                              fixed_sepa=fsa, attrs=attrs)
            all_fsa_dpops.append(dp_fsa)
            if not is_illustris:
                tng_fsa_dpops.append(dp_fsa)

            if is_tng300 and inclT300 and inclRescale:
                rescale_dp_fsa = Discrete(freqs, freqs_edges, lbl='rescale_fsa'+label,
                                          tau=tau, fixed_sepa=fsa, mod_mmbulge=True,
                                          rescale_mbulge=True, attrs=attrs)
                tng_fsa_dpops.append(rescale_dp_fsa)

        fname, color, lw = attrs
        print(f"{label} dpop_attrs: {fname} {color} {lw}")

    if use_fsa:
        return all_dpops, tng_dpops, all_fsa_dpops, tng_fsa_dpops

    return all_dpops, tng_dpops

In [None]:
# Build the default populations.  With the default `fsa` (not None), create_dpops
# returns four lists: all / TNG-only populations, each with and without a fixed
# initial separation.
all_dpops, tng_dpops, all_fsa_dpops, tng_fsa_dpops = create_dpops()

In [None]:
# Sub-select populations by label for the comparison cells further down.
tng100_dpops = [d for d in tng_dpops if '100' in d.lbl]
tng50_dpops = [d for d in tng_dpops if '50' in d.lbl]
# '-1' marks the highest-resolution run of each box
tng_hires_dpops = [d for d in tng_dpops if '-1' in d.lbl]
all_hires_dpops = [d for d in all_dpops if ('-1' in d.lbl or d.lbl=='newIll')]

# Report which populations ended up in each subset.
for group_name, group in (('tng100_dpops', tng100_dpops),
                          ('tng50_dpops', tng50_dpops),
                          ('tng_hires_dpops', tng_hires_dpops),
                          ('all_hires_dpops', all_hires_dpops)):
    print(f'{group_name}:')
    for d in group:
        print(d.lbl)

In [None]:
# Fraction of binaries flagged in `evo.coal`, per population.
print("coalescing fractions:")

for header, dpops in (("\nsim data:", tng_dpops),
                      ("\nfixed ainit and rescaled mmbulge:", tng_fsa_dpops)):
    print(header)
    for dpop in dpops:
        n_coal = np.where(dpop.evo.coal)[0].size
        n_binaries = dpop.evo.sepa[:,0].size
        frac = n_coal / n_binaries
        print(f"{dpop.lbl}: coalescing frac = {frac:.4g}")


In [None]:
print("Number and fraction of binaries with log(Mtot)>9 and q>0.1:\n")

# Catalog populations first, then (after a blank line) the fixed-separation ones.
for group_ix, dpops in enumerate((tng_dpops, tng_fsa_dpops)):
    if group_ix > 0:
        print('')
    for d in dpops:
        mt, mr = utils.mtmr_from_m1m2(d.pop.mass)
        lgmt = np.log10(mt/MSOL)
        n_selected = lgmt[(lgmt>9)&(mr>0.1)].size
        print(f"{d.lbl}: {n_selected}, {n_selected/lgmt.size:.4g}")

#pop_tng300_1[pop_tng300_1.mass.sum(axis=1)
#pop_tng300_1[pop_tng300_1.mass[:2].sum

In [None]:
# One figure per binary property, overlaying all TNG catalog populations.
for v in ['mass','mrat','sepa','redz']:
    compare_bin_pops(tng_dpops, var=v)
    plt.legend()
    plt.show()

In [None]:
# Same comparison for the fixed-initial-separation populations.  'sepa' is left out
# here -- presumably because those populations share one initial separation.
for v in ['mass','mrat','redz']:
    compare_bin_pops(tng_fsa_dpops, var=v)
    #compare_bin_pops([dpop_fsa_tng100_1.pop, dpop_fsa_tng100_2.pop], colors=colors, lws=lws,
    #                 labels=[dpop_fsa_tng100_1.lbl, dpop_fsa_tng100_2.lbl], var=v)
    #compare_bin_pops([d.pop for d in all_fsa_dpops], colors=colors, lws=lws,
    #                 labels=[d.lbl for d in all_fsa_dpops], var=v)
    plt.legend()
    plt.show()


In [None]:
# TNG populations built with the catalog separations (fixed_sepa=None).
compare_bhmfs(tng_dpops)
plt.legend()
plt.show()

In [None]:
# Same comparison for the fixed-initial-separation ('fsa') TNG populations.
compare_bhmfs(tng_fsa_dpops)
plt.legend()
plt.show()

In [None]:
# TNG populations built with the catalog separations (fixed_sepa=None).
compare_gsmfs(tng_dpops)
plt.legend()
plt.show()

In [None]:
# Same comparison for the fixed-initial-separation ('fsa') TNG populations.
compare_gsmfs(tng_fsa_dpops)
plt.legend()
plt.show()

# Create SAM for comparison

In [None]:
# Semi-analytic model using a power-law galaxy pair fraction (GPF).
print("creating sam")
#sam = sams.Semi_Analytic_Model()
sam = sams.Semi_Analytic_Model(gpf = sams.GPF_Power_Law())
#print(f"{sam._dens_gal_gal.min()=},{sam._dens_gal_gal.max()=}")
print("calculating hardening")
# Reuse the lifetime of the first TNG population and the same 1e4 pc initial
# separation that create_dpops uses by default for the 'fsa' populations.
hard = holo.hardening.Fixed_Time_2PL_SAM(sam, tng_dpops[0].tau, sepa_init=1.0e4*PC)
#print(f"{sam._dens_gal_gal.min()=},{sam._dens_gal_gal.max()=}")
print("creating gwb")
#gwb_sam = sam.new_gwb(freqs_edges, hard, realize=500)    # calculate many different realizations
# Evaluate the SAM background on the same frequency bin edges as the discrete populations.
gwb_sam = sam.gwb_new(tng_dpops[0].freqs_edges, hard, realize=100)
print(f"{sam._dens_gal_gal.min()=},{sam._dens_gal_gal.max()=}")


In [None]:
# Re-print the range of the SAM's galaxy-galaxy density (same values as the end of the previous cell).
print(f"{sam._dens_gal_gal.min()=},{sam._dens_gal_gal.max()=}")


In [None]:
# Reference SAM built without an explicit `gpf` argument, for comparison with the
# GPF_Power_Law SAM (`sam`).
print("creating no-gpf sam")
sam_no_gpf = sams.Semi_Analytic_Model()
hard_no_gpf = holo.hardening.Fixed_Time_2PL_SAM(sam_no_gpf, tng_dpops[0].tau, sepa_init=1.0e4*PC)
print("creating no-gpf gwb")
# The background has to come from `sam_no_gpf` itself: `hard_no_gpf` was built for that
# model, and calling `sam.gwb_new` here would just re-evaluate the GPF model's population.
gwb_sam_no_gpf = sam_no_gpf.gwb_new(tng_dpops[0].freqs_edges, hard_no_gpf, realize=100)

In [None]:
import illustris_python as il
def calc_gsmf_from_snap(basePath, snapNum, box_vol_mpc=None, hubble=None):
    """Galaxy stellar mass function of one simulation snapshot.

    Subhalos are kept only if they have more than 100 gas, more than 100 dark-matter
    and more than 100 star particles, and at least one black hole.  Their stellar
    masses are histogrammed in 20 equal-width bins of log10(M*/Msun).

    Parameters
    ----------
    basePath : str
        Simulation output directory, as expected by `illustris_python`.
    snapNum : int
        Snapshot number.
    box_vol_mpc : float, optional
        Comoving box volume in Mpc^3.  Defaults to (BoxSize / h)^3 from the group
        catalog header, so the function is correct for any box rather than only
        for TNG100.
    hubble : float, optional
        Dimensionless Hubble parameter h used to strip the 1/h from the catalog
        units.  Defaults to the header's `HubbleParam`.

    Returns
    -------
    gsmf : ndarray, shape (20,)
        counts / bin width / ln(10) / volume.
    bin_edges : ndarray, shape (21,)
        Bin edges in log10(M*/Msun).
    """
    # Take the cosmology and box size from the snapshot itself unless overridden.
    if (box_vol_mpc is None) or (hubble is None):
        header = il.groupcat.loadHeader(basePath, snapNum)
        if hubble is None:
            hubble = header['HubbleParam']
        if box_vol_mpc is None:
            # BoxSize is stored in comoving kpc/h
            box_vol_mpc = (header['BoxSize'] / 1.0e3 / hubble)**3

    # Only the two fields used below are loaded (two fields keeps the dict return type).
    subs = il.groupcat.loadSubhalos(basePath, snapNum, fields=['SubhaloLenType', 'SubhaloMassType'])

    # Particle-type columns: 0 = gas, 1 = dark matter, 4 = stars, 5 = black holes.
    mstar = subs['SubhaloMassType'][:,4]*1.0e10/hubble   # catalog masses are in 1e10 Msun/h
    Ngas = subs['SubhaloLenType'][:,0]
    Ndm = subs['SubhaloLenType'][:,1]
    Nstar = subs['SubhaloLenType'][:,4]
    Nbh = subs['SubhaloLenType'][:,5]

    mstar = mstar[(Ngas>100)&(Ndm>100)&(Nstar>100)&(Nbh>0)]
    lgmstar = np.log10(mstar)
    mstar_hist, bin_edges = np.histogram(lgmstar, bins=20)
    binsize = bin_edges[1:] - bin_edges[:-1]
    # NOTE(review): the bins are already in dex, so the extra 1/ln(10) looks like it turns
    # this into a per-ln(M) density; it is kept because the SAM curves plotted against it
    # are scaled the same way -- confirm the intended units.
    gsmf = mstar_hist / binsize / np.log(10) / box_vol_mpc # dex^-1 Mpc^-3
    return gsmf, bin_edges

In [None]:
#def schech_yourself(mstar, phi0, phiz, mchar, alpha0, alphaz):
#    """
#    Before you wrech yourself.
#    
#    Still havent figured out why the values from the sam are lower than Chen et al. Units?? Sad.
#    """
#    phi = #self._phi_func(redz)
#    mchar = #self._mchar_func(redz)
#    alpha = #self._alpha_func(redz)
#    xx = mstar / mchar
#    # [Chen2019]_ Eq.8
#    sf = np.log(10.0) * phi * np.power(xx, 1.0 + alpha) * np.exp(-xx)
#    return sf

In [None]:
# Stellar mass functions of the TNG100-1 box at four snapshots (redshifts per the trailing comments).
# NOTE(review): absolute, cluster-specific path -- adjust when running elsewhere.
tng100_1_basePath = '/orange/lblecha/IllustrisTNG/Runs/TNG100-1/output/'
gsmf_tng100_1_snap99, bin_edges_tng100_1_snap99 = calc_gsmf_from_snap(tng100_1_basePath,99) # z=0
gsmf_tng100_1_snap17, bin_edges_tng100_1_snap17 = calc_gsmf_from_snap(tng100_1_basePath,17) # z=5
gsmf_tng100_1_snap33, bin_edges_tng100_1_snap33 = calc_gsmf_from_snap(tng100_1_basePath,33) # z=2
gsmf_tng100_1_snap67, bin_edges_tng100_1_snap67 = calc_gsmf_from_snap(tng100_1_basePath,67) # z=0.5
#gsmf_tng100_1_snap72, bin_edges_tng100_1_snap72 = calc_gsmf_from_snap(tng100_1_basePath,72) # z=0.4

In [None]:
# Stellar mass functions of the TNG50-1 box at the same four snapshots.
# NOTE(review): absolute, cluster-specific path -- adjust when running elsewhere.
tng50_1_basePath = '/orange/lblecha/IllustrisTNG/Runs/TNG50-1/output/'
gsmf_tng50_1_snap99, bin_edges_tng50_1_snap99 = calc_gsmf_from_snap(tng50_1_basePath,99) # z=0
gsmf_tng50_1_snap17, bin_edges_tng50_1_snap17 = calc_gsmf_from_snap(tng50_1_basePath,17) # z=5
gsmf_tng50_1_snap33, bin_edges_tng50_1_snap33 = calc_gsmf_from_snap(tng50_1_basePath,33) # z=2
gsmf_tng50_1_snap67, bin_edges_tng50_1_snap67 = calc_gsmf_from_snap(tng50_1_basePath,67) # z=0.5
#gsmf_tng50_1_snap72, bin_edges_tng50_1_snap72 = calc_gsmf_from_snap(tng50_1_basePath,72) # z=0.4

In [None]:
# Sanity-check the SAM's stellar mass function: evaluate it on a regular mass grid at a
# few redshifts and plot it against its own normalisation and characteristic mass.
print(sam.mtot.shape, sam.mrat.shape, hard._sepa_init, hard._norm.shape)
#mass_stellar() returns (mstar_pri, mstar_rat, mstar_tot, redz)
mstar_pri,mstar_rat,mstar_tot,redz = sam.mass_stellar() 
print(mstar_pri.shape)
print(mstar_pri.min()/MSOL, mstar_pri.max()/MSOL)
#sam_dadt_vals = hard.dadt(sam.mtot, sam.mrat, np.repeat(hard._sepa_init,sam.mtot.size))
#print("mtot [msun]:",sam.mtot/MSOL)
# Stellar-mass grid: 10^8 .. 10^12 Msun in 0.1 dex steps
mstar_arr = 10**np.arange(8.0,12.1,0.1) * MSOL
print(mstar_arr.min()/MSOL, mstar_arr.max()/MSOL)

# The gsmf_func_* arrays are reused by the comparison figures in later cells.
gsmf_func_z0 = sam._gsmf(mstar_arr, 0.0) / np.log(10)  # units of dex^-1 Mpc^-3
gsmf_func_z0pt4 = sam._gsmf(mstar_arr, 0.4) / np.log(10)  # units of dex^-1 Mpc^-3
gsmf_func_z0pt5 = sam._gsmf(mstar_arr, 0.5) / np.log(10)  # units of dex^-1 Mpc^-3
gsmf_func_z3 = sam._gsmf(mstar_arr, 3.0) / np.log(10)  # units of dex^-1 Mpc^-3
gsmf_func_z5 = sam._gsmf(mstar_arr, 5.0) / np.log(10)  # units of dex^-1 Mpc^-3
#print(sam._gsmf(mstar_arr, 0.0))
print(f"phi0={sam._gsmf._phi0}, phiz={sam._gsmf._phiz}")
print(f"log10(mchar0/msun)={np.log10(sam._gsmf._mchar0/MSOL):g}, log10(mcharz/msun)={np.log10(sam._gsmf._mcharz/MSOL):g}")
print(f"alpha0={sam._gsmf._alpha0}, alphaz={sam._gsmf._alphaz}")
plt.xscale('log')
plt.yscale('log')
plt.xlim(5.0e7,2.0e12)
plt.ylim(1.0e-6,0.01)
plt.plot(mstar_arr/MSOL, gsmf_func_z0, lw=3)
#plt.plot(mstar_arr/MSOL, gsmf_func_z0pt4, lw=3)
plt.plot(mstar_arr/MSOL, gsmf_func_z0pt5, lw=3)
plt.plot(mstar_arr/MSOL, gsmf_func_z5, lw=3)
# Same function evaluated at the SAM's own primary stellar masses (z=0 and z=5)
plt.plot(mstar_pri.flatten()/MSOL, sam._gsmf(mstar_pri.flatten(), 0.0)/ np.log(10) ,'.',lw=0,ms=0.1,alpha=0.2)
plt.plot(mstar_pri.flatten()/MSOL, sam._gsmf(mstar_pri.flatten(), 5.0)/ np.log(10) ,'.',lw=0,ms=0.1,alpha=0.2)
# Guide lines: horizontal at 10**phi0 (the z=0 normalisation), vertical at mchar0
phi_check_z0 = np.power(10.0, sam._gsmf._phi0 + sam._gsmf._phiz * 0.0)
m0_check = sam._gsmf._mchar0 / MSOL
plt.plot([1.0e8,1.0e12], [phi_check_z0, phi_check_z0])
plt.plot([m0_check,m0_check],[1.0e-6,1.0e-2],)

In [None]:
#compare_gsmfs([d for d in tng_dpops], colors=colors, lws=lws,
#              labels=[d.lbl for d in tng_dpops])
#compare_gsmfs([d for d in [dpop_tng100_1,dpop_tng100_2]], colors=colors, lws=lws,
#              labels=[d.lbl for d in [dpop_tng100_1,dpop_tng100_2]])
# First TNG100 population vs. the SAM stellar mass function (solid) and the
# full TNG100-1 snapshot catalogs (dashed).
compare_gsmfs([tng100_dpops[0]])
box_vol_mpc = tng100_dpops[0].evo._sample_volume / (1.0e6*PC)**3
print("box_vol_mpc = ", box_vol_mpc)
plt.xlim(7.8,12.2)
plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0,'k',lw=2)
#plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0pt4)
plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0pt5,'g',lw=2,label='SAM (z=0.5)')
plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z5,'m',lw=2,label='SAM (z=5)')
# Snapshot curves are drawn at the left edge of each histogram bin; the z=0 pair (black) is unlabeled
plt.plot(bin_edges_tng100_1_snap99[:-1],gsmf_tng100_1_snap99,'k--')
#plt.plot(bin_edges_snap72[:-1],gsmf_snap72)
plt.plot(bin_edges_tng100_1_snap67[:-1],gsmf_tng100_1_snap67,'g--',label='TNG100-1 (all gals, z=0.5)')
plt.plot(bin_edges_tng100_1_snap17[:-1],gsmf_tng100_1_snap17,'m--',label='TNG100-1 (all gals, z=5)')
plt.legend()
plt.show()

In [None]:
# First TNG50 population vs. the SAM stellar mass function (solid) and the
# full TNG50-1 snapshot catalogs (dashed).
compare_gsmfs([tng50_dpops[0]])
box_vol_mpc = tng50_dpops[0].evo._sample_volume / (1.0e6*PC)**3
print("box_vol_mpc = ", box_vol_mpc)
plt.xlim(7.8,12.2)
#plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0,'k',lw=2)
#plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0pt4)
plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0pt5,'g',lw=2,label='SAM (z=0.5)')
plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z5,'m',lw=2,label='SAM (z=5)')
#plt.plot(bin_edges_snap99[:-1],gsmf_snap99,'k--')
#plt.plot(bin_edges_snap72[:-1],gsmf_snap72)
# Plot the TNG50-1 snapshot data under the TNG50-1 labels (the TNG100-1 arrays were
# being drawn here by copy-paste).
plt.plot(bin_edges_tng50_1_snap67[:-1],gsmf_tng50_1_snap67,'g--',label='TNG50-1 (all gals, z=0.5)')
plt.plot(bin_edges_tng50_1_snap17[:-1],gsmf_tng50_1_snap17,'m--',label='TNG50-1 (all gals, z=5)')
plt.legend()
plt.show()

In [None]:
# First two high-resolution TNG populations vs. the SAM and the z=0.5 snapshot catalogs.
compare_gsmfs(tng_hires_dpops[:2])
#box_vol_mpc = dpop_tng50_1.evo._sample_volume / (1.0e6*PC)**3
#print("box_vol_mpc = ", box_vol_mpc)
plt.xlim(7.8,12.2)
#plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0,'k',lw=2)
#plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0pt4)
plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0pt5,'g',lw=2,label='SAM (z=0.5)')
plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z3,'m',lw=2,label='SAM (z=3)')
#plt.plot(np.log10(mstar_arr/MSOL), gsmf_func_z5,'m',lw=2,label='SAM (z=5)')
#plt.plot(bin_edges_snap99[:-1],gsmf_snap99,'k--')
#plt.plot(bin_edges_snap72[:-1],gsmf_snap72)
# NOTE(review): the TNG50-1 and TNG100-1 curves below share the same 'g--' style, so
# they can only be told apart by position, not by the legend.
plt.plot(bin_edges_tng50_1_snap67[:-1],gsmf_tng50_1_snap67,'g--',label='TNG50-1 (all gals, z=0.5)')
#plt.plot(bin_edges_tng50_1_snap33[:-1],gsmf_tng50_1_snap33,'m:',label='TNG50-1 (all gals, z=3)')
##plt.plot(bin_edges_tng50_1_snap17[:-1],gsmf_tng50_1_snap17,'m:',label='TNG50-1 (all gals, z=5)')
plt.plot(bin_edges_tng100_1_snap67[:-1],gsmf_tng100_1_snap67,'g--',label='TNG100-1 (all gals, z=0.5)')
#plt.plot(bin_edges_tng100_1_snap33[:-1],gsmf_tng100_1_snap33,'m:',label='TNG100-1 (all gals, z=3)')
##plt.plot(bin_edges_tng100_1_snap17[:-1],gsmf_tng100_1_snap17,'m:',label='TNG100-1 (all gals, z=5)')
plt.legend()
plt.show()

In [None]:
def rg15_merger_rate(Mtot, mu, z):
    """Galaxy merger rate from a redshift-dependent power-law fitting formula.

    The parameter values appear to be the Rodriguez-Gomez et al. (2015) best fit
    (hence `rg15`):

        rate = A(z) * (M/1e10 Msun)^alpha(z) * [1 + (M/M0)^delta(z)]
                    * mu^(beta(z) + gamma * log10(M/1e10 Msun))

    with A(z) = A0 (1+z)^eta and x(z) = x0 (1+z)^x1 for alpha, beta and delta.

    Parameters
    ----------
    Mtot : scalar or array
        Stellar mass, in the same units as `MSOL`.
    mu : scalar or array
        Stellar mass ratio.
    z : scalar or array
        Redshift.

    Returns
    -------
    Merger rate, broadcast over the inputs -- presumably in Gyr^-1, which is how
    the plotting cell labels it.
    """
    # Fit parameters (quoted uncertainties in the trailing comments)
    M0 = 2.0e11 * MSOL
    A0 = 10.0**-2.2287   # ± 0.0045
    eta = 2.4644         # ± 0.0128
    alpha0, alpha1 = 0.2241, -1.1759   # ± 0.0038, ± 0.0316
    beta0, beta1 = -1.2595, 0.0611     # ± 0.0026, ± 0.0021
    gamma = -0.0477      # ± 0.0013
    delta0, delta1 = 0.7668, -0.4695   # ± 0.0202, ± 0.0440

    # Redshift evolution of the normalisation and of the three exponents
    norm_z = A0 * (1 + z)**eta
    alpha_z = alpha0 * (1 + z)**alpha1
    beta_z = beta0 * (1 + z)**beta1
    delta_z = delta0 * (1 + z)**delta1

    mass_10 = Mtot / (1.0e10*MSOL)   # mass in units of 1e10 Msun

    mass_term = norm_z * mass_10**alpha_z * (1 + (Mtot/M0)**delta_z)
    return mass_term * mu**(beta_z + gamma*np.log10(mass_10))

In [None]:
# Evaluate the fitting-formula merger rate on the SAM's stellar-mass grid.
mstar_pri, mstar_rat, mstar_tot, redz = sam.mass_stellar() 
rg15_gmr = rg15_merger_rate(mstar_tot, mstar_rat, redz)
def get_index(arr,val):
    """Return the index of the element of `arr` closest to `val`.

    Parameters
    ----------
    arr : array-like
        Values to search (lists are accepted as well as arrays).
    val : scalar
        Target value.

    Returns
    -------
    Integer index along the first axis of the closest element; on ties the first
    occurrence (in C order) wins.
    """
    diff = np.abs(np.asarray(arr) - val)
    # argmin finds the first minimum directly; unravel keeps the "first-axis index"
    # result for inputs with more than one dimension.
    return np.unravel_index(np.argmin(diff), diff.shape)[0]

# Grid indices nearest the reference values used in the panels below.  The indexing
# treats axis 0 as total stellar mass, axis 1 as mass ratio and axis 2 as redshift.
mu0pt25_ix = get_index(mstar_rat[0,:,0],0.25)
mu0pt01_ix = get_index(mstar_rat[0,:,0],0.01)
m9_ix = get_index(mstar_tot[:,0,0],1.0e9*MSOL)
m10_ix = get_index(mstar_tot[:,0,0],1.0e10*MSOL)
m11_ix = get_index(mstar_tot[:,0,0],1.0e11*MSOL)
z0pt1_ix = get_index(redz[0,0,:],0.1)
print(z0pt1_ix, mu0pt25_ix, mu0pt01_ix)
#print(mstar_rat[-1,:,0])
#print(redz[0,0,:])
#print(mstar_tot[:,0,0]/MSOL)
#print(mstar_tot[:,-1,0]/MSOL)
print(rg15_gmr.shape)
print(mstar_tot.shape)
print(sam.mrat)
print(mstar_rat[0,:,0])

# SAM merger rate on the same grid.  Note it is taken from `sam_no_gpf`, while the
# grid itself comes from `sam.mass_stellar()`.
sam_gmr = sam_no_gpf._gal_merger_rate(mstar_tot, mstar_rat, redz) * GYR # convert from s^-1 to Gyr^-1
print(f"sam_gmr shape: {sam_gmr.shape}")

# Three-panel comparison: solid lines are the fitting formula, dashed lines the SAM.
fig = plt.figure(figsize=(12,5))
# Panel 1: rate vs. stellar mass at z~0.1, for three mass ratios
ax1 = fig.add_subplot(131)
ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.set_xlabel('stellar mass')
ax1.set_ylabel('Galaxy merger rate [Gyr^-1]')
ax1.plot(mstar_tot[:,0,z0pt1_ix]/MSOL, rg15_gmr[:,0,z0pt1_ix], label=f'mu={mstar_rat[0,0,0]:.4g}')
ax1.plot(mstar_tot[:,mu0pt01_ix,z0pt1_ix]/MSOL, rg15_gmr[:,mu0pt01_ix,z0pt1_ix],
         label=f'mu={mstar_rat[0,mu0pt01_ix,0]:.4g}')
ax1.plot(mstar_tot[:,mu0pt25_ix,z0pt1_ix]/MSOL, rg15_gmr[:,mu0pt25_ix,z0pt1_ix],
         label=f'mu={mstar_rat[0,mu0pt25_ix,0]:.4g}')

ax1.plot(mstar_tot[:,0,z0pt1_ix]/MSOL, sam_gmr[:,0,z0pt1_ix],'--')
ax1.plot(mstar_tot[:,mu0pt01_ix,z0pt1_ix]/MSOL, sam_gmr[:,mu0pt01_ix,z0pt1_ix],'--')
ax1.plot(mstar_tot[:,mu0pt25_ix,z0pt1_ix]/MSOL, sam_gmr[:,mu0pt25_ix,z0pt1_ix],'--')

ax1.legend()

# Panel 2: rate vs. stellar mass ratio at z~0.1, for four total masses
ax2 = fig.add_subplot(132)
ax2.set_xscale('log')
ax2.set_yscale('log')
ax2.set_xlabel('stellar mass ratio')
ax2.set_ylabel('Galaxy merger rate [Gyr^-1]')
ax2.plot(mstar_rat[0,:,z0pt1_ix],rg15_gmr[0,:,z0pt1_ix],label=f'M={mstar_tot[0,0,0]/MSOL:.4g}')
ax2.plot(mstar_rat[m9_ix,:,z0pt1_ix],rg15_gmr[m9_ix,:,z0pt1_ix],
         label=f'M={mstar_tot[m9_ix,0,0]/MSOL:.4g}')
ax2.plot(mstar_rat[m10_ix,:,z0pt1_ix], rg15_gmr[m10_ix,:,z0pt1_ix],
         label=f'M={mstar_tot[m10_ix,0,0]/MSOL:.4g}')
ax2.plot(mstar_rat[m11_ix,:,z0pt1_ix], rg15_gmr[m11_ix,:,z0pt1_ix],
         label=f'M={mstar_tot[m11_ix,0,0]/MSOL:.4g}')

ax2.plot(mstar_rat[0,:,z0pt1_ix], sam_gmr[0,:,z0pt1_ix],'--')
ax2.plot(mstar_rat[m9_ix,:,z0pt1_ix], sam_gmr[m9_ix,:,z0pt1_ix],'--')
ax2.plot(mstar_rat[m10_ix,:,z0pt1_ix], sam_gmr[m10_ix,:,z0pt1_ix],'--')
ax2.plot(mstar_rat[m11_ix,:,z0pt1_ix], sam_gmr[m11_ix,:,z0pt1_ix],'--')

ax2.legend()

# Debug output: shapes of the slices and of the mass-ratio integrals used in panel 3
print("shapes:")
print(redz[m9_ix,mu0pt01_ix,:].shape)
print(rg15_gmr[m9_ix,mu0pt01_ix:,:].shape)
print(mstar_rat[m9_ix,mu0pt01_ix:,:].shape)
print(utils.trapz(rg15_gmr[m9_ix,mu0pt01_ix:,:],mstar_rat[0,mu0pt01_ix:,0],axis=0,cumsum=False).shape)
print(utils.trapz(rg15_gmr[m9_ix,mu0pt01_ix:,:],mstar_rat[0,mu0pt01_ix:,0],axis=0,cumsum=False).sum(axis=0).shape)
#rg15_gmr_cum_mu = 
#sam_dngalgal_dz = utils.trapz(integ, mstar_rat_arr, axis=1, cumsum=False)                
#sam_dngalgal_dz = sam_dngalgal_dz.sum(axis=0)
# Panel 3: rate integrated over mass ratio above a threshold, vs. 1+z.  Each curve
# takes the non-cumulative utils.trapz terms along the mass-ratio axis and sums them.
ax3 = fig.add_subplot(133)
ax3.set_xlim(1,10)
ax3.set_xscale('log')
ax3.set_yscale('log')
ax3.set_xlabel('1+z')
ax3.set_ylabel('Galaxy merger rate [Gyr^-1]')
ax3.plot(1+redz[m11_ix,mu0pt01_ix,:],
         utils.trapz(rg15_gmr[m11_ix,mu0pt01_ix:,:],mstar_rat[0,mu0pt01_ix:,0],axis=0,cumsum=False).sum(axis=0),
         label=f'M={mstar_tot[m11_ix,mu0pt01_ix,0]/MSOL:.4g},mu>={mstar_rat[m11_ix,mu0pt01_ix,0]:.4g}')
ax3.plot(1+redz[m11_ix,mu0pt25_ix,:],
         utils.trapz(rg15_gmr[m11_ix,mu0pt25_ix:,:],mstar_rat[0,mu0pt25_ix:,0],axis=0,cumsum=False).sum(axis=0),
         label=f'M={mstar_tot[m11_ix,mu0pt25_ix,0]/MSOL:.4g},mu>={mstar_rat[m11_ix,mu0pt25_ix,0]:.4g}')
ax3.plot(1+redz[m10_ix,mu0pt25_ix,:],
         utils.trapz(rg15_gmr[m10_ix,mu0pt25_ix:,:],mstar_rat[0,mu0pt25_ix:,0],axis=0,cumsum=False).sum(axis=0),
         label=f'M={mstar_tot[m10_ix,mu0pt25_ix,0]/MSOL:.4g},mu>={mstar_rat[m10_ix,mu0pt25_ix,0]:.4g}')
ax3.plot(1+redz[m9_ix,mu0pt25_ix,:],
         utils.trapz(rg15_gmr[m9_ix,mu0pt25_ix:,:],mstar_rat[0,mu0pt25_ix:,0],axis=0,cumsum=False).sum(axis=0),
         label=f'M={mstar_tot[m9_ix,mu0pt25_ix,0]/MSOL:.4g},mu>={mstar_rat[m9_ix,mu0pt25_ix,0]:.4g}')

# Same four integrals for the SAM merger rate (dashed, unlabeled)
ax3.plot(1+redz[m11_ix,mu0pt01_ix,:],
         utils.trapz(sam_gmr[m11_ix,mu0pt01_ix:,:],
                     mstar_rat[0,mu0pt01_ix:,0],axis=0,cumsum=False).sum(axis=0),'--')
ax3.plot(1+redz[m11_ix,mu0pt25_ix,:],
         utils.trapz(sam_gmr[m11_ix,mu0pt25_ix:,:],
                     mstar_rat[0,mu0pt25_ix:,0],axis=0,cumsum=False).sum(axis=0),'--')
ax3.plot(1+redz[m10_ix,mu0pt25_ix,:],
         utils.trapz(sam_gmr[m10_ix,mu0pt25_ix:,:],
                     mstar_rat[0,mu0pt25_ix:,0],axis=0,cumsum=False).sum(axis=0),'--')
ax3.plot(1+redz[m9_ix,mu0pt25_ix,:],
         utils.trapz(sam_gmr[m9_ix,mu0pt25_ix:,:],
                     mstar_rat[0,mu0pt25_ix:,0],axis=0,cumsum=False).sum(axis=0),'--')

#ax3.plot(1+redz[m10_idx,mu0pt01_idx,:],rg15_gmr[m10_idx,mu0pt01_idx,:],
#         label=f'M={mstar_tot[m10_idx,mu0pt01_idx,0]/MSOL:.4g},mu>={mstar_rat[m10_idx,mu0pt01_idx,0]:.4g}')
#ax3.plot(1+redz[m11_idx,mu0pt01_idx,:],rg15_gmr[m11_idx,mu0pt01_idx,:],
#         label=f'M={mstar_tot[m11_idx,mu0pt01_idx,0]/MSOL:.4g},mu>={mstar_rat[m11_idx,mu0pt01_idx,0]:.4g}')
#ax3.plot(redz[m9_idx,:,z0pt1_idx],rg15_gmr[m9_idx,:,z0pt1_idx],
#         label=f'M={mstar_tot[m9_idx,0,z0pt1_idx]/MSOL:.4g}')
#ax2.plot(mstar_rat[m10_idx,:,z0pt1_idx], rg15_gmr[m10_idx,:,z0pt1_idx],
#         label=f'M={mstar_tot[m10_idx,0,z0pt1_idx]/MSOL:.4g}')
#ax2.plot(mstar_rat[m11_idx,:,z0pt1_idx], rg15_gmr[m11_idx,:,z0pt1_idx],
#         label=f'M={mstar_tot[m11_idx,0,z0pt1_idx]/MSOL:.4g}')
#ax1.plot(mstar_tot[:,-1,0]/MSOL,rg15_gmr[:,-1,0],label=f'mu={mstar_rat[0,-1,0]:.4g}')
ax3.legend()

In [None]:
#compare_gal_merger_dens([d for d in [dpop_tng100_1,dpop_tng100_2]],sam_compare=sam)
# High-resolution populations vs. both SAMs; `gpf_flag` lines up with `sam_compare`
# (1 for the GPF SAM, 0 for the no-GPF one).  Cuts: M > 1e10 Msun and q > 0.1.
compare_gal_merger_dens(all_hires_dpops, sam_compare=[sam,sam_no_gpf],gpf_flag=[1,0],  
                        mcut=1.0e10*MSOL, qcut=0.1)

In [None]:
#compare_gal_merger_rate([d for d in [dpop_tng100_1,dpop_tng100_2]],sam_compare=sam)
# TNG populations vs. both SAMs, with the same mass and mass-ratio cuts as above.
compare_gal_merger_rate(tng_dpops, sam_compare=[sam,sam_no_gpf], gpf_flag=[1,0],mcut=1.0e10*MSOL,qcut=0.1)

In [None]:
# Same comparison without a mass cut.
# NOTE(review): `gpf_flag` is not passed here, unlike the call with cuts -- confirm the
# default still distinguishes `sam` from `sam_no_gpf`.
compare_gal_merger_rate(tng_dpops, sam_compare=[sam,sam_no_gpf],mcut=None)

In [None]:
# Re-fetch the SAM stellar-mass grid (same call as in the merger-rate cell).
mstar_pri, mstar_rat, mstar_tot, redz = sam.mass_stellar() 
#print(mstar_pri.shape, mstar_rat.shape, mstar_tot.shape, redz.shape)
#print("dens_gal_gal and mstar shapes:")

#sam_no_gpf._ndens_gal(mstar_pri, mstar_rat, redz)


In [None]:
# Log-spaced separations from 1e-6 pc up to the hardening model's initial separation.
# np.geomspace takes the endpoints themselves; np.logspace expects base-10 exponents,
# so passing the separations to it overflowed every entry to inf.
sepa_arr = np.geomspace(1.0e-6*PC, hard._sepa_init, 101)
# Broadcast everything onto a common (mtot, mrat, sepa) grid
mt, mr, sepa, norm = np.broadcast_arrays(
            sam.mtot[:, np.newaxis, np.newaxis],
            sam.mrat[np.newaxis, :, np.newaxis],
            sepa_arr[np.newaxis, np.newaxis, :],
            hard._norm[:, :, np.newaxis],
        )
print(mt.shape, mr.shape, sepa.shape)
# Hardening rate da/dt at every grid point
dadt_sam = hard.dadt(mt, mr, sepa, norm)
print(dadt_sam.shape)

### Plot amplitudes at 1/yr, 1/3yr, and 1/10yr

In [None]:
# Cache the 1/yr, 1/(3 yr) and 1/(10 yr) amplitudes on every population.
amp_dpops = tng_dpops + tng_fsa_dpops
for d in amp_dpops:
    d.get_amplitudes_at_freqs()

# The SAM backgrounds were computed on `tng_dpops[0].freqs_edges`, so take the bin
# indices from that population explicitly instead of from whatever the loop variable
# happened to end on, and make sure every population really shares that grid.
ref_dpop = tng_dpops[0]
assert all(np.array_equal(d.freqs, ref_dpop.freqs) for d in amp_dpops), \
    "populations use different frequency grids; SAM bin indices would not match"

ayr_sam = gwb_sam[ref_dpop.idx_ayr,:].flatten()
a3yr_sam = gwb_sam[ref_dpop.idx_a3yr,:].flatten()
a10yr_sam = gwb_sam[ref_dpop.idx_a10yr,:].flatten()

ayr_sam_no_gpf = gwb_sam_no_gpf[ref_dpop.idx_ayr,:].flatten()
a3yr_sam_no_gpf = gwb_sam_no_gpf[ref_dpop.idx_a3yr,:].flatten()
a10yr_sam_no_gpf = gwb_sam_no_gpf[ref_dpop.idx_a10yr,:].flatten()


In [None]:
# Distribution over realizations of the GWB amplitude in the bin nearest 1/yr.
fig, ax = plt.subplots(figsize=[12, 6])
ax.set(xlabel=r'$\log_{10}(A_\mathrm{yr})$', ylabel='Probability Density')
ax.grid(alpha=0.2)

kale.dist1d(np.log10(ayr_sam), density=True, hist=False, confidence=True, carpet=False, 
            lw=4, color='k', label='SAM')
kale.dist1d(np.log10(ayr_sam_no_gpf), density=True, hist=False, confidence=True, carpet=False, 
            lw=4, color='orchid', label='SAM (no GPF)')
# NOTE(review): `colors` and `lws` are only assigned in a commented-out cell near the end
# of the notebook; confirm they are defined before this cell runs on a fresh kernel.
# Catalog population = translucent solid line, fixed-separation counterpart = dashed.
for i in np.arange(len(tng_dpops)):
    kale.dist1d(np.log10(tng_dpops[i].ayr), density=True, hist=False, confidence=False, carpet=False, 
                label=tng_dpops[i].lbl, lw=lws[i], color=colors[i],alpha=0.5)
    kale.dist1d(np.log10(tng_fsa_dpops[i].ayr), density=True, hist=False, confidence=False, carpet=False, 
                label=tng_fsa_dpops[i].lbl, lw=lws[i], color=colors[i],ls='--')

plt.title("Comparison of GWB amplitudes at 1/yr")
plt.legend()
# Tag the file with the lifetime the populations were actually built with; a bare
# global `tau` is not defined by any active cell.
tau_gyr = tng_dpops[0].tau / (1e9*YR)
fig.savefig(f'compare_dpops_tau{tau_gyr:.1f}_A1yr.png')
plt.show()

In [None]:
# Distribution over realizations of the GWB amplitude in the bin nearest 1/(3 yr).
fig, ax = plt.subplots(figsize=[12, 6])
ax.set(xlabel=r'$\log_{10}(A_\mathrm{3yr})$', ylabel='Probability Density')
ax.grid(alpha=0.2)

kale.dist1d(np.log10(a3yr_sam), density=True, hist=False, confidence=True, carpet=False, 
            lw=4, color='k', label='SAM')
kale.dist1d(np.log10(a3yr_sam_no_gpf), density=True, hist=False, confidence=True, carpet=False, 
            lw=4, color='orchid', label='SAM (no GPF)')
# NOTE(review): `colors` and `lws` are only assigned in a commented-out cell near the end
# of the notebook; confirm they are defined before this cell runs on a fresh kernel.
# Catalog population = translucent solid line, fixed-separation counterpart = dashed.
for i in np.arange(len(tng_dpops)):
    kale.dist1d(np.log10(tng_dpops[i].a3yr), density=True, hist=False, confidence=False, carpet=False, 
                label=tng_dpops[i].lbl, lw=lws[i], color=colors[i],alpha=0.5)
    kale.dist1d(np.log10(tng_fsa_dpops[i].a3yr), density=True, hist=False, confidence=False, carpet=False, 
                label=tng_fsa_dpops[i].lbl, lw=lws[i], color=colors[i],ls='--')

plt.title("Comparison of GWB amplitudes at 1/3yr")
plt.legend()
# Tag the file with the lifetime the populations were actually built with; a bare
# global `tau` is not defined by any active cell.
tau_gyr = tng_dpops[0].tau / (1e9*YR)
fig.savefig(f'compare_dpops_tau{tau_gyr:.1f}_A0.33yr.png')
plt.show()

In [None]:
# Distribution over realizations of the GWB amplitude in the bin nearest 1/(10 yr).
fig, ax = plt.subplots(figsize=[12, 6])
ax.set(xlabel=r'$\log_{10}(A_\mathrm{10yr})$', ylabel='Probability Density')
ax.grid(alpha=0.2)

kale.dist1d(np.log10(a10yr_sam), density=True, hist=False, confidence=True, carpet=False, 
            lw=4, color='k', label='SAM')
kale.dist1d(np.log10(a10yr_sam_no_gpf), density=True, hist=False, confidence=True, carpet=False, 
            lw=4, color='orchid', label='SAM (no GPF)')
# NOTE(review): `colors` and `lws` are only assigned in a commented-out cell near the end
# of the notebook; confirm they are defined before this cell runs on a fresh kernel.
# Catalog population = translucent solid line, fixed-separation counterpart = dashed.
for i in np.arange(len(tng_dpops)):
    kale.dist1d(np.log10(tng_dpops[i].a10yr), density=True, hist=False, confidence=False,  carpet=False, 
                label=tng_dpops[i].lbl, lw=lws[i], color=colors[i],alpha=0.5)
    kale.dist1d(np.log10(tng_fsa_dpops[i].a10yr), density=True, hist=False, confidence=False,  carpet=False, 
                label=tng_fsa_dpops[i].lbl, lw=lws[i], color=colors[i], ls='--')

#kale.dist1d(np.log10(a10yr_fsa_new_ill), density=True, confidence=False, label='a10yr_fsa_new_ill',color='m')
#kale.dist1d(np.log10(a10yr_sam), density=True, confidence=True, label='a10yr_sam')
plt.title("Comparison of GWB amplitudes at 1/10yr")
plt.legend()
# Tag the file with the lifetime the populations were actually built with; a bare
# global `tau` is not defined by any active cell.
tau_gyr = tng_dpops[0].tau / (1e9*YR)
fig.savefig(f'compare_dpops_tau{tau_gyr:.1f}_A0.1yr.png')
plt.show()

In [None]:
# GWB spectrum of every catalog population, one figure each.
# `all_dpops` (returned by create_dpops) replaces the old `all_sim_dpops` list, which is
# only assigned in commented-out code and so raised a NameError on a fresh kernel.
for d in all_dpops:
    plot.plot_gwb(d.freqs, d.gwb.back) #, ylim=(5.0e-17,2.0e-14))
    plt.title(d.lbl)

plt.show()

In [None]:
# Evolution of every TNG catalog population, sampled at separations from
# 10^-4.5 to 10^5 pc.  Note this rebinds the global `sepa`.
sepa = np.logspace(-4.5, 5, 100) * PC
fig, ax = plot.figax(figsize=(8,5))
for d in tng_dpops:
    plot_evo(d.evo, freqs=None, sepa=sepa, ax=ax)
    # every population is drawn on the same axes, so only the last title survives
    plt.title(d.lbl)
    #plot_evo(evo_tng50_3, freqs=None, sepa=sepa, ax=ax)
plt.legend()
plt.show()

In [None]:
# Same evolution plot for the fixed-initial-separation populations.
sepa = np.logspace(-4.5, 5, 100) * PC
fig, ax = plot.figax(figsize=(8,5))
for d in tng_fsa_dpops:
    plot_evo(d.evo, freqs=None, sepa=sepa, ax=ax)
    # every population is drawn on the same axes, so only the last title survives
    plt.title(d.lbl)
#plt.legend()
plt.show()

In [None]:
# Calculate the total lifetime of each binary
ncols = 2
nrows = 3
fig, axes = plot.figax(scale='lin', xlabel='Time: actual/specified', ylabel='density',figsize=(8,5),
                       ncols=ncols, nrows=nrows, wspace=0.3,hspace=0.3)
# Pick the populations by label from `tng_dpops`: the old per-population globals
# (`dpop_tng100_1`, `dpop_tng300_1`) and the global `tau` are no longer created by
# any active cell.
lifetime_dpops = [d for d in tng_dpops if d.lbl in ('TNG100-1', 'TNG300-1')]
times = [d.evo.tlook for d in lifetime_dpops]
print(len(times))
i = 0
for idx,ax in np.ndenumerate(axes):
    # flatten the (row, col) position into an index into `times`
    i = idx[0]*ncols + idx[1]
    if i >= len(times): break
    print(times[i].shape)
    print(f"idx: {idx}, i: {i}")
    # lookback time at the first evolution step minus that at the last
    dt = times[i][:, 0] - times[i][:, -1]
    # use kalepy to plot the distribution, normalised by this population's own lifetime
    kale.dist1d(dt/lifetime_dpops[i].tau, density=True, ax=ax)
    #ax.legend()

plt.show()

In [None]:
# Corner plot of the binary parameters for each TNG catalog population.
for d in tng_dpops:
    plot_bin_pop(d.pop)
plt.show()

In [None]:
# MBH scaling relations: the new Illustris population against every fixed-separation
# TNG population.  The Illustris population is looked up by label in `all_dpops`; the
# old `dpop_new_ill` global is no longer created by any active cell.  This raises
# StopIteration if create_dpops was run with inclIll=False.
new_ill_dpop = next(d for d in all_dpops if d.lbl == 'newIll')
plot_mbh_scaling_relations(new_ill_dpop.pop, compare_pops=[d.pop for d in tng_fsa_dpops],
                           ncols=4, nrows=2,
                           xlim=(6.5,13.2), ylim=(4.5,11.1))
plt.savefig('mmbulge_relations.png')
plt.show()

In [None]:
# Shape of the bulge-mass array of the new Illustris population, looked up by label
# because the old `dpop_new_ill` global is no longer created by any active cell.
next(d for d in all_dpops if d.lbl == 'newIll').pop.mbulge.shape

In [None]:
# Shape of the bulge-mass array of the TNG100-1 population, looked up by label
# because the old `dpop_tng100_1` global is no longer created by any active cell.
next(d for d in tng_dpops if d.lbl == 'TNG100-1').pop.mbulge.shape

In [None]:
## ---- Set the fixed binary lifetime
##tau = 2.0 * GYR 
#tau = 1.0 * GYR 
####################################

#fixed_sepa = 1.0e4 * PC

## construct sampling frequencies
##freqs = holo.utils.nyquist_freqs(dur=20.0*YR, cad=0.25*YR, lgspace=True)
##freqs_edges = holo.utils.nyquist_freqs_edges(dur=20.0*YR, cad=0.25*YR, lgspace=True)
#freqs, freqs_edges = utils.pta_freqs()
#### freqs, freq_edges = holo.librarian.get_freqs(None) ## doesn't work for Discrete_GW
##print(freqs)
##print(np.log10(freqs))
#print(freqs.size)

#colors = ['r', 'orange', 'y', 'blue', 'c', 'm', 'orchid', 'k', 'k', 'k', 'k'] #, 'c']
#lws = [3.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.0, 1.0, 1.0, 1.0]

In [None]:
##dpop_old_ill = Discrete(freqs, freqs_edges, fname=None, lbl="oldIll", tau=tau, fixed_sepa=None)
##dpop_fsa_old_ill = Discrete(freqs, freqs_edges, fname=None, lbl="fsaOldIll", tau=tau, 
##                            fixed_sepa=fixed_sepa, mod_mmbulge=True)

#dpop_new_ill = Discrete(freqs, freqs_edges, lbl="newIll", tau=tau, fixed_sepa=None, 
#                        fname='galaxy-mergers_Illustris-1_gas-100_dm-100_star-100_bh-001.hdf5')
#dpop_fsa_new_ill = Discrete(freqs, freqs_edges, lbl="fsaNewIll", tau=tau, fixed_sepa=fixed_sepa, mod_mmbulge=True,
#                            fname='galaxy-mergers_Illustris-1_gas-100_dm-100_star-100_bh-001.hdf5')


In [None]:
#### TNG100 ####

#dpop_tng100_1 = Discrete(freqs, freqs_edges, lbl="TNG100-1", tau=tau, fixed_sepa=None, 
#                         fname='galaxy-mergers_TNG100-1_gas-100_dm-100_star-100_bh-001.hdf5')
#dpop_fsa_tng100_1 = Discrete(freqs, freqs_edges, lbl="fsaTNG100-1", tau=tau, fixed_sepa=fixed_sepa, mod_mmbulge=True,
#                             fname='galaxy-mergers_TNG100-1_gas-100_dm-100_star-100_bh-001.hdf5')

#dpop_tng100_2 = Discrete(freqs, freqs_edges, lbl="TNG100-2", tau=tau, fixed_sepa=None, 
#                         fname='galaxy-mergers_TNG100-2_gas-012_dm-012_star-012_bh-001.hdf5')
#dpop_fsa_tng100_2 = Discrete(freqs, freqs_edges, lbl="fsaTNG100-2", tau=tau, fixed_sepa=fixed_sepa, mod_mmbulge=True,
#                             fname='galaxy-mergers_TNG100-2_gas-012_dm-012_star-012_bh-001.hdf5')


In [None]:
#### TNG300 ####

#dpop_tng300_1 = Discrete(freqs, freqs_edges, lbl="TNG300-1", tau=tau, fixed_sepa=None, 
#                          fname='galaxy-mergers_TNG300-1_gas-012_dm-012_star-012_bh-001.hdf5')
#dpop_fsa_tng300_1 = Discrete(freqs, freqs_edges, lbl="fsaTNG300-1", tau=tau, fixed_sepa=fixed_sepa, mod_mmbulge=True,
#                             fname='galaxy-mergers_TNG300-1_gas-012_dm-012_star-012_bh-001.hdf5')


In [None]:
#### TNG50 ####

#dpop_tng50_1 = Discrete(freqs, freqs_edges, lbl="TNG50-1", tau=tau, fixed_sepa=None, 
#                        fname='galaxy-mergers_TNG50-1_gas-800_dm-800_star-800_bh-001.hdf5')
#dpop_fsa_tng50_1 = Discrete(freqs, freqs_edges, lbl="fsaTNG50-1", tau=tau, fixed_sepa=fixed_sepa, mod_mmbulge=True,
#                            fname='galaxy-mergers_TNG50-1_gas-800_dm-800_star-800_bh-001.hdf5')

#dpop_tng50_2 = Discrete(freqs, freqs_edges, lbl="TNG50-2", tau=tau, fixed_sepa=None, 
#                        fname='galaxy-mergers_TNG50-2_gas-100_dm-100_star-100_bh-001.hdf5')
#dpop_fsa_tng50_2 = Discrete(freqs, freqs_edges, lbl="fsaTNG50-2", tau=tau, fixed_sepa=fixed_sepa, mod_mmbulge=True,
#                            fname='galaxy-mergers_TNG50-2_gas-100_dm-100_star-100_bh-001.hdf5')

#dpop_tng50_3 = Discrete(freqs, freqs_edges, lbl="TNG50-3", tau=tau, fixed_sepa=None, 
#                        fname='galaxy-mergers_TNG50-3_gas-012_dm-012_star-012_bh-001.hdf5')
#dpop_fsa_tng50_3 = Discrete(freqs, freqs_edges, lbl="fsaTNG50-3", tau=tau, fixed_sepa=fixed_sepa, mod_mmbulge=True,
#                            fname='galaxy-mergers_TNG50-3_gas-012_dm-012_star-012_bh-001.hdf5')


In [None]:
### NOTE: this whole cell is commented out, so none of the lists below are defined by it.
### Lines starting with `##` look like alternative versions of the single-`#` lists
### (e.g. the `##` all_* lists additionally include the `*_old_ill` populations) --
### presumably older variants kept for reference; confirm before re-enabling.
##all_sim_dpops = [dpop_old_ill, dpop_new_ill, dpop_tng50_1, dpop_tng50_2, 
##                 dpop_tng50_3, dpop_tng100_1, dpop_tng100_2, dpop_tng300_1]
##all_fsa_dpops = [dpop_fsa_old_ill, dpop_fsa_new_ill, dpop_fsa_tng50_1, dpop_fsa_tng50_2, 
##                 dpop_fsa_tng50_3, dpop_fsa_tng100_1, dpop_fsa_tng100_2, dpop_fsa_tng300_1]
#all_sim_dpops = [dpop_new_ill, dpop_tng50_1, dpop_tng50_2, 
#                 dpop_tng50_3, dpop_tng100_1, dpop_tng100_2, dpop_tng300_1]
#all_fsa_dpops = [dpop_fsa_new_ill, dpop_fsa_tng50_1, dpop_fsa_tng50_2, 
#                 dpop_fsa_tng50_3, dpop_fsa_tng100_1, dpop_fsa_tng100_2 , dpop_fsa_tng300_1]

### NOTE(review): in the `##tng_fsa_dpops` line below, the disabled trailing element is
### `dpop_tng300_1` (the non-fsa population); it probably should be `dpop_fsa_tng300_1`
### to match the rest of that list -- confirm before re-enabling.
##tng_dpops = [dpop_tng100_1, dpop_tng100_2] #, dpop_tng300_1]
##tng_fsa_dpops = [dpop_fsa_tng100_1, dpop_fsa_tng100_2] #, dpop_tng300_1]

### NOTE(review): `rescale_dpop_fsa_tng300_1` (disabled tail of the last line) is not
### assigned anywhere in the neighbouring cells; the rescaled objects there are named
### `rescale_fsa_pop_tng300_1` / `rescale_fsa_evo_tng300_1` -- verify the intended name.
#tng_dpops = [dpop_tng50_1, dpop_tng50_2, dpop_tng50_3, dpop_tng100_1, dpop_tng100_2 , dpop_tng300_1]
#tng_fsa_dpops = [dpop_fsa_tng50_1, dpop_fsa_tng50_2, dpop_fsa_tng50_3, 
#                 dpop_fsa_tng100_1, dpop_fsa_tng100_2, #] 
#                 dpop_fsa_tng300_1] #, rescale_dpop_fsa_tng300_1]


In [None]:
### ---- rescaled TNG300-1 file (masses increased by factor of 1.4)
### NOTE: every statement in this cell is commented out. If re-enabled it would:
###   1. load the TNG300-1 merger file as a `Pop_Illustris` population with `fixed_sepa`,
###   2. print the min/max of `mbulge` in units of MSOL,
###   3. build a `PM_Mass_Reset` modifier from `MMBulge_KH2013` with `scatter=True` and
###      `rescale_mbulge=True`, and apply it with `.modify(...)`,
###   4. print the min/max of `mbulge` again for comparison.
### NOTE(review): the factor of 1.4 mentioned above is not set anywhere in this cell;
### presumably it is applied internally via `rescale_mbulge=True` -- confirm.
### NOTE(review): the printed label reads `fsa_pop_tng300_1`, but the object being
### inspected is `rescale_fsa_pop_tng300_1`.
#rescale_fsa_pop_tng300_1 = holo.population.Pop_Illustris(fname='galaxy-mergers_TNG300-1_gas-012_dm-012_star-012_bh-001.hdf5',
#                                                         fixed_sepa=fixed_sepa)
#mmbulge = holo.relations.MMBulge_KH2013()
#print(f'min/max mbulge in fsa_pop_tng300_1: {rescale_fsa_pop_tng300_1.mbulge.min()/MSOL:.6g}, {rescale_fsa_pop_tng300_1.mbulge.max()/MSOL:.6g}')
#mod_KH2013_rTNG300 = holo.population.PM_Mass_Reset(mmbulge, scatter=True, rescale_mbulge=True)
#rescale_fsa_pop_tng300_1.modify(mod_KH2013_rTNG300)
#print("after rescaling,")
#print(f'min/max mbulge in fsa_pop_tng300_1: {rescale_fsa_pop_tng300_1.mbulge.min()/MSOL:.6g}, {rescale_fsa_pop_tng300_1.mbulge.max()/MSOL:.6g}')


In [None]:
### NOTE: the statements below are commented out. They need `rescale_fsa_pop_tng300_1`
### and `tau`, neither of which is assigned in these lines, so the cell that builds the
### rescaled population must be re-enabled (and run) first.
### create a fixed-total-time hardening mechanism
#rescale_fsa_fixed_tng300_1 = holo.hardening.Fixed_Time_2PL.from_pop(rescale_fsa_pop_tng300_1, tau)
### Create an evolution instance using population and hardening mechanism
#rescale_fsa_evo_tng300_1 = holo.evolution.Evolution(rescale_fsa_pop_tng300_1, rescale_fsa_fixed_tng300_1)
### evolve binary population
#rescale_fsa_evo_tng300_1.evolve()
### report the evolution's private `_sample_volume` attribute
#print("vol:",rescale_fsa_evo_tng300_1._sample_volume)