In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys

import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sams
import holodeck.gravwaves
from holodeck import cosmo, utils, plot, discrete, sams, host_relations
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC

# Silence numpy floating-point warnings (divide-by-zero, invalid, overflow) and
# all `UserWarning`s for the rest of the session.
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# `style.use('default')` resets rcParams to matplotlib's defaults, so it has to run
# BEFORE the custom rc settings; calling it last would silently discard them.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): 'Times' is registered under the 'sans-serif' list while the active family
# is 'serif' -- the key was probably meant to be 'serif'; left unchanged, please confirm.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Use holodeck's logger at INFO level
log = holo.log
log.setLevel(logging.INFO)

In [None]:
def calc_sim_merging_gsmf(dpop, mass_type='all', req_z=None, req_dz=None, mask=None, verbose=False):
    """Histogram the bulge masses of a discrete binary population into a mass function.

    Parameters
    ----------
    dpop : object
        Must expose `dpop.pop.mbulge` (indexed below as an (N, 2) array), `dpop.pop.redz`
        and `dpop.pop._sample_volume`.
    mass_type : {'all', 'tot', 'pri'}
        'all' histograms every entry of `mbulge` (both columns, flattened by `np.histogram`),
        'tot' histograms the per-row sum of the two columns,
        'pri' histograms the per-row maximum of the two columns.
    req_z : float or None
        If given, keep only rows with ``|redz - req_z| < req_dz``.
    req_dz : float or None
        Half-width of the redshift window; defaults to 0.1 when `req_z` is given.
    mask : boolean array or None
        Optional additional row selection, AND-ed with the redshift selection.
    verbose : bool
        Print array shapes at each stage.

    Returns
    -------
    mbins : ndarray, shape (16,)
        Centres of the log10-mass bins (16 bins spanning 7.25 to 15.25).
    gsmf : ndarray, shape (16,)
        Counts divided by bin width and by the sample volume in Mpc^3.
    mhist : ndarray, shape (16,)
        Raw counts per bin.

    Raises
    ------
    ValueError
        If `mass_type` is not one of 'tot', 'pri', 'all'.
    """
    # Validate up front so a bad argument fails before any work is done
    if mass_type not in ('tot', 'pri', 'all'):
        err = "`mass_type` must be 'tot', 'pri', or 'all'"
        raise ValueError(err)

    mstar = dpop.pop.mbulge/MSOL
    if verbose: print(f"in calc_sim_merging_gsmf: {mstar.shape=}")

    # Combine the optional user mask with the optional redshift window
    select = mask
    if req_z is None:
        if verbose: print("calculating gsmf for all redshifts.")
    else:
        if req_dz is None:
            req_dz = 0.1
        zmask = (np.abs(dpop.pop.redz - req_z) < req_dz)
        select = zmask if mask is None else (mask & zmask)
    if select is not None:
        mstar = mstar[select]
    # Debug output is gated on `verbose`, like every other diagnostic in this function
    if verbose: print(f"after masking, in calc_sim_merging_gsmf: {mstar.shape=}")

    if mass_type == 'tot':
        mstar = mstar[:, 0] + mstar[:, 1]
    elif mass_type == 'pri':
        mstar = mstar.max(axis=1)
    # mass_type == 'all': keep both columns; np.histogram flattens its input

    if verbose: print(f"after setting mass_type, in calc_sim_merging_gsmf: {mstar.shape=}")

    mstar = np.log10(mstar)
    box_vol_mpc = dpop.pop._sample_volume / (1.0e6*PC)**3

    mhist, mbin_edges = np.histogram(mstar, range=(7.25, 15.25), bins=16)
    mbinsize = mbin_edges[1] - mbin_edges[0]
    # Bin CENTRES are left edge + half a bin width (adding the full width gives the right
    # edges and shifts the curve by half a bin relative to `calc_gsmf_from_snap`).
    mbins = mbin_edges[:-1] + 0.5*mbinsize

    # Normalised per unit log10-mass (dex) and per Mpc^3; no ln(10) factor is applied
    return mbins, mhist/mbinsize/box_vol_mpc, mhist

def compare_sim_merging_gsmfs(dpops, req_z=None, req_dz=None):
    """Plot the merging-galaxy mass function of several discrete populations on one axis.

    For each population, the full histogram is drawn as a faint dotted line and the bins
    holding more than 10 objects are over-drawn as a solid, labelled line.

    Parameters
    ----------
    dpops : list
        Populations accepted by `calc_sim_merging_gsmf`; each must also provide the
        plotting attributes `color`, `lw` and `lbl`.
    req_z : float or None
        Centre of the redshift window; None uses all redshifts.
    req_dz : float or None
        Half-width of the redshift window. When `req_z` is given and this is None,
        0.1 is used (the same default `calc_sim_merging_gsmf` applies).

    Returns
    -------
    fig, ax : the matplotlib figure and axes that were created.
    """
    assert isinstance(dpops, list), '`dpops` must be a list of binary populations'

    if req_z is None:
        lbl_extra = ' (mergers, all z)'
    else:
        # Resolve the default here too, otherwise building the label fails on `req_z - None`
        if req_dz is None:
            req_dz = 0.1
        lbl_extra = f' (mergers, {np.maximum(0,req_z-req_dz):.2g}<z<{req_z+req_dz:.2g})'

    fig, ax = plt.subplots(figsize=[10, 5])
    # Axis decoration does not depend on the population, so it is set once
    ax.set(yscale='log',
           xlabel=r'$\log_{10}(M_{*})$',
           ylabel=r'GSMF [dex$^{-1}$ Mpc$^{-3}$]')
    ax.grid(alpha=0.01)

    for dp in dpops:
        mbins, gsmf, mhist = calc_sim_merging_gsmf(dp, req_z=req_z, req_dz=req_dz)
        well_sampled = mhist > 10
        # faint dotted: every bin; solid: only bins with more than 10 objects
        ax.plot(mbins, gsmf, ':', alpha=0.5, color=dp.color, lw=dp.lw)
        ax.plot(mbins[well_sampled], gsmf[well_sampled], color=dp.color, lw=dp.lw,
                label=dp.lbl+lbl_extra)

    return fig, ax


In [None]:
import compare_discrete

In [None]:
# Build the discrete populations with the project-local `compare_discrete` helper.
# The call returns four collections, unpacked on the next line.
# NOTE(review): the meaning of the three flags is defined in `compare_discrete.create_dpops`.
tmp = compare_discrete.create_dpops(allow_mbh0=True, mod_mmbulge=False, skip_evo=True)
all_dpops, tng_dpops, all_fsa_dpops, tng_fsa_dpops = tmp

In [None]:
# Semi-analytic model built with an explicit galaxy-pair-fraction component.
print("creating SAM using Galaxy Pair Fraction (GPF) + Galaxy Merger Timescale (GMT)...")

sam = sams.Semi_Analytic_Model(gpf = sams.GPF_Power_Law())

# Hardening and the GWB calculation are deliberately not run in this notebook;
# the disabled code is kept below for reference.
print("SKIPPING HARDENING")

#print("    ...calculating hardening")
#hard = holo.hardening.Fixed_Time_2PL_SAM(sam, all_fsa_dpops[0].tau, sepa_init=1.0e4*PC)
#print("    ...creating gwb")
#gwb_sam = sam.gwb_new(all_fsa_dpops[0].freqs_edges, hard, realize=500)

In [None]:
# Second semi-analytic model, constructed with all-default components (no `gpf` argument).
print("creating 'no-GPF' SAM using Galaxy Merger Rate (GMR) "
      "    (uses galaxy merger rates directly from RG15 instead of GPF+GMT)...")

sam_no_gpf = sams.Semi_Analytic_Model()

# As for `sam`, hardening and the GWB calculation are disabled; code kept for reference.
print("SKIPPING HARDENING")

#print("    ...calculating hardening for no-GPF SAM")
#hard_no_gpf = holo.hardening.Fixed_Time_2PL_SAM(sam_no_gpf, all_fsa_dpops[0].tau, sepa_init=1.0e4*PC)
#print("    ...creating gwb for no-GPF SAM")
#gwb_sam_no_gpf = sam_no_gpf.gwb_new(all_fsa_dpops[0].freqs_edges, hard_no_gpf, realize=50)

In [None]:
# Merging-galaxy mass functions of every population, over all redshifts (no `req_z`).
compare_sim_merging_gsmfs(all_dpops)
# The function only attaches labels; the legend is drawn here on the current axes.
plt.legend()
plt.show()

In [None]:
# Grid of panels, one per list of populations in `dpop_lists`: for every population the
# merging-galaxy mass function is drawn for all mergers and for bulge mass ratio q > 0.1.
nrows = 2
ncols = 2
mass_type = 'tot'
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex='all', sharey='all', figsize=[10,9])

# x-axis label for each supported `mass_type` (prose kept outside of math mode)
xlab = {
    'pri': r'$\log_{10}(M_{*,pri})$',
    'tot': r'$\log_{10}(M_{*,tot})$',
    'all': r'$\log_{10}(M_{*})$ (all progenitors)',
}[mass_type]
# Every panel shows a galaxy stellar mass function, so all y labels say GSMF
ylab = r'GSMF [$(\log_{10} M)^{-1} Mpc^{-3}$]'

axes[0,0].set(yscale='log', ylabel=ylab,
              xlim=(7.25,15.5), ylim=(1.0e-6,1.0))
axes[1,0].set(yscale='log', xlabel=xlab, ylabel=ylab,
              xlim=(7.25,15.5), ylim=(1.0e-6,1.0))
axes[1,1].set(yscale='log', xlabel=xlab,
              xlim=(7.25,15.5), ylim=(1.0e-6,1.0))

# Alternative groupings from an earlier version of this analysis (names not defined here):
#dpop_lists = [t50_hires_fid_and_fsa_dpops, t100_hires_fid_and_fsa_dpops, ill_hires_fid_and_fsa_dpops]
#dpop_lists = [ 
#    [ d for d in all_hires_fid_and_fsa_dpops if d.lbl in ['TNG50-1', 'fsa-TNG50-1', 'fsa-TNG50-1-bh0'] ],
#    [ d for d in all_hires_fid_and_fsa_dpops if d.lbl in ['TNG50-1-N100', 'fsa-TNG50-1-N100', 'fsa-TNG50-1-N100-bh0'] ],
#    [ d for d in all_hires_fid_and_fsa_dpops if d.lbl in ['TNG100-1', 'fsa-TNG100-1', 'fsa-TNG100-1-bh0'] ],
#    ill_hires_fid_and_fsa_dpops
#    ]
dpop_lists = [all_fsa_dpops]
print(len(dpop_lists))

# Pair panels (row-major order) with population lists; `zip` stops at the shorter of the
# two, so surplus panels stay empty and surplus lists are ignored.
for ax, dpops in zip(axes.flat, dpop_lists):
    for dp in dpops:
        # bulge mass ratio folded into (0, 1]
        qbulge = dp.pop.mbulge[:,0] / dp.pop.mbulge[:,1]
        qbulge[qbulge>1] = 1.0/qbulge[qbulge>1]

        x, y, mh = calc_sim_merging_gsmf(dp, mass_type=mass_type)
        ax.plot(x, y, lw=dp.lw, label=dp.lbl)

        xqcut, yqcut, mhqcut = calc_sim_merging_gsmf(dp, mask=(qbulge>0.1),
                                                     mass_type=mass_type)
        ax.plot(xqcut, yqcut, lw=dp.lw, label=dp.lbl+' q>0.1', alpha=0.4)

    # one legend per panel, after all of its curves exist
    ax.legend()

fig.suptitle(f"GSMF for merging galaxies (mass type = {mass_type})")
fig.subplots_adjust(top=0.95, wspace=0.1, hspace=0.1)
plt.show()

### Load data for all subhalos & compare with merging subhalos

In [None]:
def calc_gsmf_from_snap(basePath, req_z, req_binsize=0.05, verbose=False):
    """Read a pre-computed stellar-mass histogram and return it as a mass function.

    The file ``{basePath}/gsmf_all_snaps_Nmin1.hdf5`` is opened read-only. The snapshot
    whose redshift (from `SnapshotScaleFacs`) is closest to `req_z` is selected, its
    histogram is optionally re-binned to a coarser log-mass bin width, and the counts are
    divided by the bin width and by the box volume.

    Parameters
    ----------
    basePath : str
        Directory containing ``gsmf_all_snaps_Nmin1.hdf5``.
    req_z : float
        Requested redshift; a message is printed if the nearest snapshot differs by
        more than 0.01.
    req_binsize : float
        Requested log-mass bin width. It is rounded DOWN to an integer multiple of the
        file's native width (`LogMassBinWidth`).
    verbose : bool
        Print diagnostic information.

    Returns
    -------
    mbin_centers : ndarray
        Centres of the (re-binned) log-mass bins.
    gsmf : ndarray
        Counts / bin width / box volume (dex^-1 Mpc^-3).
    mhist_snap : ndarray
        Counts per (re-binned) bin.

    Raises
    ------
    ValueError
        If `req_binsize` is smaller than the native width or would merge more than half
        of the native bins, or if the selected histogram does not match the bin edges.
    """
    # Copy everything out of the file inside the `with` block so the handle is always closed
    with h5py.File(f"{basePath}/gsmf_all_snaps_Nmin1.hdf5", "r") as f:
        box_vol_mpc = f.attrs['box_volume_mpc']
        snapnums = np.asarray(f.attrs['SnapshotNums'])
        scalefacs = np.asarray(f.attrs['SnapshotScaleFacs'])
        dlgm_orig = f.attrs['LogMassBinWidth']
        mbin_edges_orig = np.array(f['StellarMassBinEdges'])
        mhist_all_snaps = np.array(f['StellarMassHistograms'])

    # Snapshot closest in redshift to the request (first one on ties)
    zsnaps = 1.0 / scalefacs - 1.0
    diff = np.abs(zsnaps - req_z)
    isnap = np.argmin(diff)
    snapNum = snapnums[isnap]
    zsnap = zsnaps[isnap]
    if verbose or (diff.min() > 0.01):
        print(f"{req_z=}, {snapNum=}, {zsnap=}, {diff.min()=}")

    nbins_orig = mbin_edges_orig.size - 1
    # Histograms are stored with one column per snapshot
    mhist_snap_orig = mhist_all_snaps[:, (snapnums == snapNum)].flatten()
    if verbose: print(f"{mhist_snap_orig.shape=}, {mbin_edges_orig.shape=}")
    if mhist_snap_orig.size != nbins_orig:
        # Callers unpack three return values, so fail loudly instead of returning None
        raise ValueError(
            f"histogram for {snapNum=} has {mhist_snap_orig.size} bins, "
            f"but the bin edges define {nbins_orig}"
        )

    # Number of native bins merged into one output bin. The small tolerance guards against
    # floating-point round-off (e.g. 0.15/0.05 evaluates to 2.9999999999999996).
    ncomb = int(np.floor(req_binsize / dlgm_orig + 1.0e-9))
    if ncomb < 1:
        raise ValueError(f"{req_binsize=} requested, but min allowed is {dlgm_orig=}")
    if ncomb > nbins_orig/2:
        raise ValueError(f"{req_binsize=} requested, but max allowed is {dlgm_orig*nbins_orig/2=}")

    dlgm = dlgm_orig * ncomb
    if ncomb > 1:
        # Sum consecutive groups of `ncomb` native bins. A trailing partial group is kept
        # (and assigned the full width `dlgm`); no empty bin is appended when the native
        # bin count is an exact multiple of `ncomb`.
        starts = np.arange(0, nbins_orig, ncomb)
        mhist_snap = np.add.reduceat(mhist_snap_orig, starts).astype(float)
        left_edges = mbin_edges_orig[starts]
    else:
        if verbose:
            print(f"WARNING: {req_binsize=}, {ncomb=}; retaining original binsize {dlgm_orig=}")
        mhist_snap = mhist_snap_orig
        left_edges = mbin_edges_orig[:-1]

    if verbose:
        mbin_edges = np.append(left_edges, left_edges[-1] + dlgm)
        print(f"{mhist_all_snaps.shape=}, {mhist_all_snaps.min()=}, {mhist_all_snaps.max()=}")
        print(f"{mhist_snap.shape=}, {mhist_snap.min()=}, {mhist_snap.max()=}")
        print(f"{snapnums=}")
        print(f"{dlgm=}, {mbin_edges.shape=}")
        print(f"{mbin_edges_orig=}")
        print(f"{mbin_edges=}")

    # No ln(10) factor: the result is per dex, not per ln(M)
    gsmf = mhist_snap / dlgm / box_vol_mpc  # dex^-1 Mpc^-3
    return left_edges + 0.5*dlgm, gsmf, mhist_snap


In [None]:
# Locations of the pre-computed GSMF files for each simulation.
# NOTE(review): absolute, machine-specific paths -- adjust when running elsewhere.
tng100_basePath = '/orange/lblecha/IllustrisTNG/Runs/TNG100-1/output/'
tng50_basePath = '/orange/lblecha/IllustrisTNG/Runs/TNG50-1/output/'
ill_basePath = '/orange/lblecha/Illustris/Illustris-1/output/'

# Redshifts at which the mass functions are evaluated, and the log-mass bin width requested
z_arr = np.array([0, 0.2, 0.4, 0.5, 1, 2, 3, 4, 5])
binsize = 0.2

# Per-simulation results, keyed by redshift
mhist_tng100 = {}
gsmf_tng100 = {}
bins_tng100 = {}
mhist_tng50 = {}
gsmf_tng50 = {}
bins_tng50 = {}
mhist_ill = {}
gsmf_ill = {}
bins_ill = {}
gsmf_sam = {}

# Mass grid on which the SAM's GSMF is evaluated
sam_lgmstar_arr = np.arange(8.0,13.1,0.05)
sam_mstar_arr = 10**sam_lgmstar_arr * MSOL

for z in z_arr:
    # Exactly one file read per simulation and redshift, each from its own base path
    bins_tng100[z], gsmf_tng100[z], mhist_tng100[z] = calc_gsmf_from_snap(
        tng100_basePath, z, req_binsize=binsize)

    bins_tng50[z], gsmf_tng50[z], mhist_tng50[z] = calc_gsmf_from_snap(
        tng50_basePath, z, req_binsize=binsize)

    bins_ill[z], gsmf_ill[z], mhist_ill[z] = calc_gsmf_from_snap(
        ill_basePath, z, req_binsize=binsize)

    gsmf_sam[z] = sam._gsmf(sam_mstar_arr, z) ##/ np.log(10)  # units of dex^-1 Mpc^-3
print(f"{gsmf_sam[1]=}")
print(f"{sam_lgmstar_arr=}")

In [None]:
# Quick look at the TNG100-1 mass function of all galaxies at two redshifts.
# The curves use different line styles and a legend so they can be told apart.
plt.yscale('log')
plt.plot(bins_tng100[0.5], gsmf_tng100[0.5],'c--',label='TNG100-1 (all gals, z=0.5)')
plt.plot(bins_tng100[5], gsmf_tng100[5],'c:',label='TNG100-1 (all gals, z=5)')
plt.xlabel(r'$\log_{10}(M_{*})$')
plt.ylabel(r'GSMF [dex$^{-1}$ Mpc$^{-3}$]')
plt.legend()
plt.show()

In [None]:
# Inspect the SAM: stellar-mass grid extent, GSMF parameters, and the GSMF at three redshifts.
# `sam.mass_stellar()` returns four values, unpacked here; only `mstar_pri` is used below.
mstar_pri,mstar_rat,mstar_tot,redz = sam.mass_stellar() 
print(mstar_pri.shape)
print(mstar_pri.min()/MSOL, mstar_pri.max()/MSOL)

# These read private attributes of the SAM's GSMF object (`sam._gsmf`)
print(f"phi0={sam._gsmf._phi0}, phiz={sam._gsmf._phiz}")
print(f"log10(mchar0/msun)={np.log10(sam._gsmf._mchar0/MSOL):g}, log10(mcharz/msun)={np.log10(sam._gsmf._mcharz/MSOL):g}")
print(f"alpha0={sam._gsmf._alpha0}, alphaz={sam._gsmf._alphaz}")
#plt.xscale('log')
plt.yscale('log')
#plt.xlim(5.0e7,2.0e12)
plt.ylim(1.0e-6,0.01)
# SAM GSMF at z = 0, 0.5 and 5, taken from the `gsmf_sam` dict filled in an earlier cell
plt.plot(sam_lgmstar_arr, gsmf_sam[0], lw=3)
plt.plot(sam_lgmstar_arr, gsmf_sam[0.5], lw=3)
plt.plot(sam_lgmstar_arr, gsmf_sam[5], lw=3)
# Disabled cross-checks, kept for reference:
#plt.plot(mstar_pri.flatten()/MSOL, sam._gsmf(mstar_pri.flatten(), 0.0)/ np.log(10) ,'.',lw=0,ms=0.1,alpha=0.2)
#plt.plot(mstar_pri.flatten()/MSOL, sam._gsmf(mstar_pri.flatten(), 5.0)/ np.log(10) ,'.',lw=0,ms=0.1,alpha=0.2)
#phi_check_z0 = np.power(10.0, sam._gsmf._phi0 + sam._gsmf._phiz * 0.0)
#m0_check = sam._gsmf._mchar0 / MSOL
#plt.plot([1.0e8,1.0e12], [phi_check_z0, phi_check_z0])
#plt.plot([m0_check,m0_check],[1.0e-6,1.0e-2],)
# Same two diagnostic prints as at the end of the cell that fills `gsmf_sam`
print(f"{gsmf_sam[1]=}")
print(f"{sam_lgmstar_arr=}")

In [None]:
import parse_conselice2016_table1
# Load the tabulated GSMFs via the project-local parser: redshift-bin edges, the shared
# log-mass grid, one GSMF column per redshift bin, a per-bin mass limit and the references.
c16_zlo, c16_zhi, c16_lgm, c16_gsmfs, c16_lim, refs = parse_conselice2016_table1.calc_c16_gsmfs()

# Midpoint of every redshift bin
c16_zcen = (c16_zhi - c16_zlo) / 2 + c16_zlo

# One summary line per redshift bin
for i, zcen in enumerate(c16_zcen):
    print(f"zcen={zcen} ({c16_zlo[i]}-{c16_zhi[i]}) {refs[i]}")

# Shapes of everything that was loaded, in a single line
print(*(arr.shape for arr in (c16_zlo, c16_zhi, c16_zcen, c16_lgm, c16_gsmfs, c16_lim, refs)))

In [None]:
def compare_sim_vs_sam_gsmfs(z=0, dz=0.1, c16_zrange=None, apply_lim=False, verbose=False):
    """Overlay merging-galaxy, all-galaxy, SAM and (optionally) tabulated observed GSMFs.

    Draws, on the figure created by `compare_sim_merging_gsmfs` for the first three
    entries of the module-level `all_dpops`:
      * the observed GSMFs from `parse_conselice2016_table1` whose redshift-bin centre
        lies in ``[c16_zrange[0], c16_zrange[1])`` (only if `c16_zrange` is given),
      * the SAM GSMF ``gsmf_sam[z]``,
      * the all-galaxy GSMFs of TNG50-1, TNG100-1 and Illustris-1 at redshift `z`
        (dotted: all bins; solid: bins with more than 10 objects).
    Finally the legend is added and the figure is shown.

    Parameters
    ----------
    z : float
        Redshift; must be a key of the module-level result dicts (`gsmf_sam`, `bins_*`, ...).
    dz : float
        Half-width of the redshift window applied to the merging populations.
    c16_zrange : len-2 tuple, list or ndarray, or None
        Redshift range selecting the observed GSMFs; anything else that is not None only
        prints an error message.
    apply_lim : bool
        If True, the selected observed curves are drawn dashed in black over the full mass
        grid and again, labelled, only at and above each curve's mass limit. If False
        they are drawn once, dashed and labelled.
    verbose : bool
        Accepted but not used.
    """
    compare_sim_merging_gsmfs(all_dpops[:3], req_z=z, req_dz=dz)
    plt.xlim(7.8, 13.2)
    plt.ylim(1.0e-7, 0.1)

    zrange_ok = isinstance(c16_zrange, (tuple, list, np.ndarray)) and len(c16_zrange) == 2
    if zrange_ok:
        zlo, zhi, lgm, gsmfs, lims, names = parse_conselice2016_table1.calc_c16_gsmfs()
        zcen = (zhi - zlo) / 2 + zlo
        in_range = (zcen >= c16_zrange[0]) & (zcen < c16_zrange[1])
        picked = np.nonzero(in_range)[0]

        if apply_lim:
            # full curves first, as an unlabelled dashed backdrop
            plt.plot(lgm, gsmfs[:, picked], ls='--', alpha=0.5, color='k')
        for col in picked:
            obs_label = f"{names[col]} (z={zlo[col]}-{zhi[col]})"
            if apply_lim:
                above = lgm >= lims[col]
                plt.plot(lgm[above], gsmfs[above, col], label=obs_label)
            else:
                plt.plot(lgm, gsmfs[:, col], ls='--', alpha=0.5, label=obs_label)
    elif c16_zrange is not None:
        print(f"Error: c16_zrange must be a len-2 tuple, list, or array. {c16_zrange=}")

    plt.plot(sam_lgmstar_arr, gsmf_sam[z], 'k', lw=4, label=f'SAM (z={z:.2g})')

    # (bins, gsmf, counts, dotted-line format, solid-line format, extra kwargs, name)
    sims = (
        (bins_tng50, gsmf_tng50, mhist_tng50, ('r:',), ('r',), {}, 'TNG50-1'),
        (bins_tng100, gsmf_tng100, mhist_tng100, (':',), (), {'color': 'orchid'}, 'TNG100-1'),
        (bins_ill, gsmf_ill, mhist_ill, ('c:',), ('c',), {}, 'Ill-1'),
    )
    for bins, gsmf, counts, dotted, solid, extra, name in sims:
        well_sampled = counts[z] > 10
        plt.plot(bins[z], gsmf[z], *dotted, alpha=0.5, lw=3, **extra)
        plt.plot(bins[z][well_sampled], gsmf[z][well_sampled], *solid,
                 label=f'{name} (all gals, z={z:.2g})', lw=3, **extra)

    plt.legend()
    plt.show()

In [None]:
# z = 0: mergers within |z - 0| < 0.1, observed GSMFs with bin centre in [0, 0.4)
compare_sim_vs_sam_gsmfs(z=0, dz=0.1,c16_zrange=(0,0.4))
#plt.plot(c16_lgm, c16_gsmfs[:,c16_zcen<0.5])

In [None]:
# z = 0.5: mergers within |z - 0.5| < 0.1
# NOTE(review): the observed range [0.3, 0.39) does not bracket z=0.5 -- confirm this is intended.
compare_sim_vs_sam_gsmfs(z=0.5, dz=0.1, c16_zrange=(0.3,0.39))

In [None]:
# z = 1: mergers within |z - 1| < 0.25, observed GSMFs with bin centre in [0.8, 1.2)
compare_sim_vs_sam_gsmfs(z=1, dz=0.25, c16_zrange=(0.8,1.2), apply_lim=False)

In [None]:
# z = 2: mergers within |z - 2| < 0.5, observed GSMFs with bin centre in [1.7, 2.3)
compare_sim_vs_sam_gsmfs(z=2, dz=0.5, c16_zrange=(1.7,2.3))

In [None]:
# z = 3: mergers within |z - 3| < 0.5, observed GSMFs with bin centre in [2.75, 3.25)
compare_sim_vs_sam_gsmfs(z=3, dz=0.5, c16_zrange=(2.75,3.25))

In [None]:
# z = 4: mergers within |z - 4| < 0.5, observed GSMFs with bin centre in [3.75, 4.25)
compare_sim_vs_sam_gsmfs(z=4, dz=0.5, c16_zrange=(3.75,4.25))