In [None]:
%reload_ext autoreload
%autoreload 2
from importlib import reload

import os
import sys

import logging
import warnings
import numpy as np
import astropy as ap
import scipy as sp
import scipy.stats
import matplotlib as mpl
import matplotlib.pyplot as plt

import h5py
import tqdm.notebook as tqdm

import kalepy as kale
import kalepy.utils
import kalepy.plot

import holodeck as holo
import holodeck.sams
import holodeck.gravwaves
from holodeck import cosmo, utils, plot, discrete, sams, host_relations
from holodeck.constants import MSOL, PC, YR, MPC, GYR, SPLC

# Start from matplotlib's default style (avoids dark backgrounds from a dark-theme VS Code).
# This must run BEFORE the rc customisations below: `style.use('default')` resets the
# rcParams, so calling it after them silently discarded the font/lines/mathtext/grid settings.
mpl.style.use('default')

# Silence annoying numpy errors (e.g. log10 of zero masses, 0/0 in empty bins)
np.seterr(divide='ignore', invalid='ignore', over='ignore')
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings. Times is a serif face, so it goes in the 'serif' fallback list
# (with family='serif', a 'sans-serif' entry is never consulted).
mpl.rc('font', **{'family': 'serif', 'serif': ['Times', 'Times New Roman', 'DejaVu Serif'], 'size': 15})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

log = holo.log
log.setLevel(logging.INFO)

In [None]:
def calc_sim_merging_gsmf(dpop, mass_type='all', mask=None, verbose=False,
                          mrange=(7.25, 15.25), nbins=16):
    """Histogram the bulge masses of a merging-galaxy population into a mass function.

    Parameters
    ----------
    dpop : Discrete
        Population wrapper; reads ``dpop.pop.mbulge`` (progenitor masses, columns 0 and 1
        are used for 'tot') and ``dpop.pop._sample_volume``.
    mass_type : {'all', 'tot', 'pri'}
        'all' histograms every progenitor mass, 'tot' the sum of the two progenitors,
        'pri' the larger of the two.
    mask : array_like of bool, optional
        Row selection applied to the mergers before the ``mass_type`` reduction.
    verbose : bool
        Print array shapes at each stage.
    mrange : (float, float)
        Histogram range in log10(M/Msun).
    nbins : int
        Number of histogram bins.

    Returns
    -------
    mbins : ndarray
        Bin centres in log10(M/Msun).
    gsmf : ndarray
        counts / bin width (dex) / ln(10) / sample volume (Mpc^3), per bin.

    Raises
    ------
    ValueError
        If ``mass_type`` is not 'tot', 'pri' or 'all'.
    """
    # Validate before doing any work.
    if mass_type not in ('tot', 'pri', 'all'):
        raise ValueError("`mass_type` must be 'tot', 'pri', or 'all'")

    data = dpop.pop.mbulge / MSOL
    if verbose:
        print(f"in calc_sim_merging_gsmf: {data.shape=}")

    if mask is not None:
        data = data[mask]
        if verbose:
            print(f"after mask, in calc_sim_merging_gsmf: {data.shape=}")

    if mass_type == 'tot':
        data = data[:, 0] + data[:, 1]
    elif mass_type == 'pri':
        data = data.max(axis=1)

    if verbose:
        print(f"after setting mass_type, in calc_sim_merging_gsmf: {data.shape=}")

    # log10(0) -> -inf; such entries fall outside `range` and are not counted by np.histogram.
    data = np.log10(data)

    mhist, mbin_edges = np.histogram(data, range=mrange, bins=nbins)
    mbinsize = mbin_edges[1] - mbin_edges[0]
    # Bin centres: lower edge plus HALF a bin width.
    mbins = mbin_edges[:-1] + 0.5 * mbinsize

    box_vol_mpc = dpop.pop._sample_volume / (1.0e6 * PC)**3
    return mbins, mhist / mbinsize / np.log(10) / box_vol_mpc

def compare_sim_merging_gsmfs(dpops, mass_type='all'):
    """Overlay the merging-galaxy mass functions of several populations on one axis.

    Parameters
    ----------
    dpops : list of Discrete
        Populations to plot; each supplies ``color``, ``lw`` and ``lbl``.
    mass_type : {'all', 'tot', 'pri'}
        Forwarded to ``calc_sim_merging_gsmf`` (default 'all', as before).

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
        Callers typically add limits, extra curves and a legend afterwards.
    """
    assert isinstance(dpops, list), '`dpops` must be a list of binary populations'

    fig, ax = plt.subplots(figsize=[10, 5])
    # Labels do not depend on the population, so set them once.
    ax.set(xlabel=r'$\log_{10}(M_{*})$',
           ylabel=r'GSMF [dex$^{-1}$ Mpc$^{-3}$]',
           yscale='log')
    ax.grid(alpha=0.01)

    for dp in dpops:
        mbins, gsmf = calc_sim_merging_gsmf(dp, mass_type=mass_type)
        # Draw on the Axes we return, not on whatever pyplot considers "current".
        ax.plot(mbins, gsmf, color=dp.color, lw=dp.lw, label=dp.lbl + ' (merging gals, all z)')

    return fig, ax


In [None]:
class Discrete:
    """Discrete binary population built from a hydro-simulation galaxy-merger catalog.

    On construction this loads the catalog with ``discrete.population.Pop_Illustris``.
    It can optionally reset masses with the KH2013 M-Mbulge relation. Unless
    ``skip_evo`` is set, it then builds a fixed-total-time hardening model, evolves the
    population, and computes a discrete GW background.

    Parameters
    ----------
    freqs : array_like
        GW frequencies passed to ``holo.gravwaves.GW_Discrete``.
    freqs_edges : array_like
        Frequency bin edges; stored as ``self.freqs_edges`` (not used inside this class).
    attrs : tuple
        ``(fname, basepath, color, lw)``: catalog file name, catalog directory, and the
        plot color / line width read by the plotting helpers.
    lbl : str or None
        Label used in printed diagnostics and plot legends.
    fixed_sepa : float or None
        Fixed initial binary separation forwarded to ``Pop_Illustris`` (callers pass
        values like ``1.0e4*PC``).
    tau : float
        Lifetime for ``Fixed_Time_2PL`` (same units as ``YR`` / ``GYR``).
    nreals : int
        Number of GWB realizations.
    mod_mmbulge : bool
        If True, reset masses via ``PM_Mass_Reset`` with ``MMBulge_KH2013`` (scatter on).
    rescale_mbulge : bool
        Forwarded to ``PM_Mass_Reset``; only used when ``mod_mmbulge`` is True.
    allow_mbh0 : bool
        Forwarded to ``Pop_Illustris``.
    skip_evo : bool
        If True, no hardening/evolution/GWB is done; ``fixed``, ``evo`` and ``gwb``
        are then not created.
    use_mstar_tot_as_mbulge : bool
        Forwarded to ``Pop_Illustris``.
    """

    def __init__(self, freqs, freqs_edges, attrs=(None, None, 'k', 1.0), lbl=None, fixed_sepa=None,
                 tau=1.0*YR, nreals=500, mod_mmbulge=False, rescale_mbulge=False, allow_mbh0=False,
                 skip_evo=False, use_mstar_tot_as_mbulge=False):

        self.attrs = attrs
        self.freqs = freqs
        self.freqs_edges = freqs_edges
        self.lbl = lbl
        self.fname = self.attrs[0]
        self.basepath = self.attrs[1]
        self.color = self.attrs[2]
        self.lw = self.attrs[3]
        self.fixed_sepa = fixed_sepa
        self.tau = tau
        self.nreals = nreals
        self.mod_mmbulge = mod_mmbulge
        self.allow_mbh0 = allow_mbh0
        self.use_mstar_tot_as_mbulge = use_mstar_tot_as_mbulge

        print(f"\nCreating Discrete_Pop class instance '{self.lbl}' with tau={self.tau}, fixed_sepa={self.fixed_sepa}")
        print(f" fname={self.fname}")
        # NOTE(review): `self.basepath` is stored but not passed to Pop_Illustris;
        # confirm that Pop_Illustris resolves a bare `fname` by itself.
        self.pop = discrete.population.Pop_Illustris(fname=self.fname, fixed_sepa=self.fixed_sepa,
                                                     allow_mbh0=self.allow_mbh0,
                                                     use_mstar_tot_as_mbulge=self.use_mstar_tot_as_mbulge)
        print(f"{self.pop.sepa.min()=}, {self.pop.sepa.max()=}, {self.pop.sepa.shape=}")
        print(f"{self.pop.mstar_tot.min()=}, {self.pop.mstar_tot.max()=}, {self.pop.mstar_tot.shape=}")

        # Optionally reset the catalog masses using the KH2013 M-Mbulge relation.
        if self.mod_mmbulge:
            print(f"before mass mod: {self.pop.mass.min()=}, {self.pop.mass.max()=}, {self.pop.mass.shape=}")
            print(f"before mass mod: {self.pop.mbulge.min()=}, {self.pop.mbulge.max()=}, {self.pop.mbulge.shape=}")
            # The relation classes live in `holodeck.host_relations` (imported in the setup cell).
            self.mmbulge = host_relations.MMBulge_KH2013()
            self.mod_KH2013 = discrete.population.PM_Mass_Reset(self.mmbulge, scatter=True,
                                                                rescale_mbulge=rescale_mbulge)
            self.pop.modify(self.mod_KH2013)
            print(f"after mass mod: {self.pop.mass.min()=}, {self.pop.mass.max()=}, {self.pop.mass.shape=}")
            print(f"after mass mod: {self.pop.mbulge.min()=}, {self.pop.mbulge.max()=}, {self.pop.mbulge.shape=}")

        if skip_evo:
            return

        # Fixed-total-time hardening model with lifetime `tau`.
        print(f"modeling fixed-total-time hardening...")
        self.fixed = holo.hardening.Fixed_Time_2PL.from_pop(self.pop, self.tau)
        print(f"{self.pop.sepa.min()=}, {self.pop.sepa.max()=}, {self.pop.sepa.shape=}")

        # Evolve the population under that hardening model.
        print(f"creating evolution instance and evolving it...")
        self.evo = discrete.evolution.Evolution(self.pop, self.fixed)
        print(f"{self.evo.sepa.min()=}, {self.evo.sepa.max()=}, {self.evo.sepa.shape=}")
        print(f"{self.pop.sepa.min()=}, {self.pop.sepa.max()=}, {self.pop.sepa.shape=}")
        self.evo.evolve()
        print(f"{self.evo._sample_volume=}")

        # GW background from the evolved binaries.
        self.gwb = holo.gravwaves.GW_Discrete(self.evo, self.freqs, nreals=self.nreals)
        self.gwb.emit()

    def _nearest_freq_idx(self, target):
        """Indices of ``self.freqs`` closest to ``target`` (all of them if several tie)."""
        dist = np.abs(self.freqs - target)
        return np.where(dist == dist.min())[0]

    def get_amplitudes_at_freqs(self, select_freqs=None):
        """Store GWB amplitudes at the frequency bins nearest 1/yr, 1/(3yr) and 1/(10yr).

        Sets ``idx_ayr``, ``idx_a3yr``, ``idx_a10yr`` (index arrays into ``self.freqs``)
        and ``ayr``, ``a3yr``, ``a10yr`` (the matching rows of ``self.gwb.back``, flattened).

        Parameters
        ----------
        select_freqs : ignored
            Not implemented; a message is printed if it is given.

        Raises
        ------
        AttributeError
            If no GW background exists (instance built with ``skip_evo=True``).
        """
        if select_freqs is not None:
            print("sorry this function sucks, you cannot select freqs yet. choosing 1/yr, 1/3yr, 1/10yr.")
        if not hasattr(self, 'gwb'):
            raise AttributeError(f"Discrete population '{getattr(self, 'lbl', None)}' has no GWB "
                                 "(it was created with skip_evo=True)")

        self.idx_ayr = self._nearest_freq_idx(1/YR)
        self.idx_a3yr = self._nearest_freq_idx(1/(3*YR))
        self.idx_a10yr = self._nearest_freq_idx(1/(10*YR))
        print(self.idx_ayr, self.idx_a3yr, self.idx_a10yr)

        self.ayr = self.gwb.back[self.idx_ayr, :].flatten()
        self.a3yr = self.gwb.back[self.idx_a3yr, :].flatten()
        self.a10yr = self.gwb.back[self.idx_a10yr, :].flatten()




In [None]:
def create_dpops(tau=1.0, fsa=1.0e4, mod_mmbulge=True, nreals=500, inclIll=True, inclOldIll=False,
                 inclT50=True, inclT300=True, inclRescale=False, allow_mbh0=False, skip_evo=False,
                 use_mstar_tot_as_mbulge=False):
    """Build ``Discrete`` populations for every enabled simulation catalog.

    Parameters
    ----------
    tau : float
        Binary lifetime in Gyr (converted with ``GYR`` before use).
    fsa : float or None
        Fixed initial binary separation in pc (converted with ``PC``). If None, the
        fixed-separation populations are not built.
    mod_mmbulge : bool
        Apply the M-Mbulge mass reset to the fixed-separation populations. Their labels
        are prefixed 'fsa-mm-' if True, else 'fsa-'.
    nreals, allow_mbh0, skip_evo, use_mstar_tot_as_mbulge
        Forwarded to every ``Discrete``.
    inclIll, inclOldIll, inclT50, inclT300 : bool
        Include runs whose key contains 'Ill' / equals 'oldIll' / contains 'TNG50' / 'TNG300'.
    inclRescale : bool
        For TNG300 runs, also build a rescaled-Mbulge fixed-separation population
        (added to ``tng_fsa_dpops`` only).

    Returns
    -------
    (all_dpops, tng_dpops, all_fsa_dpops, tng_fsa_dpops) if ``fsa`` is not None,
    otherwise (all_dpops, tng_dpops). The ``tng_*`` lists hold the non-Illustris runs.
    """
    # ---- Set the fixed binary lifetime
    print(f"Setting inspiral timescale tau = {tau} Gyr.")
    tau = tau * GYR

    # ---- Define the GWB frequencies
    freqs, freqs_edges = utils.pta_freqs()

    # ---- Initialize return variables
    all_dpops = []
    tng_dpops = []

    # ---- (Optionally) set the fixed initial binary separation & initialize fsa return vars
    if fsa is not None:
        print(f"Setting fixed init binary sep = {fsa} pc.")
        fsa = fsa * PC
        all_fsa_dpops = []
        tng_fsa_dpops = []

    # ---- Define dpop attributes: (filename, base path, plot color, plot linewidth)
    tpath = '/orange/lblecha/IllustrisTNG/Runs/'
    ipath = '/orange/lblecha/Illustris/'
    dpop_attrs = {
        # dont use this file; it has at least one merger remnant with mbulge=0. prob need to rerun with Ngas=10
        ### ('galaxy-mergers_Illustris-1_gas-000_dm-010_star-010_bh-000.hdf5', 'darkgreen', 1.5), 
        #'TNG50-1-N100' : ('galaxy-mergers_TNG50-1_gas-100_dm-100_star-100_bh-001.hdf5', 
        #                  tpath+'TNG50-1/output/', 'darkred', 4),
        #'TNG50-1-N100-bh0' : ('galaxy-mergers_TNG50-1_gas-100_dm-100_star-100_bh-000.hdf5', 
        #                      tpath+'TNG50-1/output/', 'darkred', 3),
        #'TNG50-1' : ('galaxy-mergers_TNG50-1_gas-800_dm-800_star-800_bh-001.hdf5', 
        #             tpath+'TNG50-1/output/', 'r', 3.5),
        #'TNG50-1-bh0' : ('galaxy-mergers_TNG50-1_gas-800_dm-800_star-800_bh-000.hdf5', 
        #                 tpath+'TNG50-1/output/', 'r', 2.5),
        #'TNG50-2' : ('galaxy-mergers_TNG50-2_gas-100_dm-100_star-100_bh-001.hdf5', 
        #             tpath+'TNG50-2/output/', 'orange', 2.5),
        #'TNG50-3' : ('galaxy-mergers_TNG50-3_gas-012_dm-012_star-012_bh-001.hdf5', 
        #             tpath+'TNG50-3/output/', 'y', 1.5),
        ##'oldIll' : (None, 'brown', 2.5),
        #---'Ill-nomprog' : ('galaxy_merger_files_with_no_mprog/galaxy-mergers_Illustris-1_gas-100_dm-100_star-100_bh-001.hdf5', 
        #---                 ipath+'Illustris-1/output/', 'g', 2.5),
        #'Ill-N010-bh0' : ('galaxy-mergers_Illustris-1_gas-000_dm-000_star-010_bh-000.hdf5', 
        #                  ipath+'Illustris-1/output/', 'darkgreen', 1.5),
        #'Ill-bh0' : ('galaxy-mergers_Illustris-1_gas-100_dm-100_star-100_bh-000.hdf5', 
        #             ipath+'Illustris-1/output/', 'g', 1.5),
        'Ill' : ('galaxy-mergers_Illustris-1_gas-100_dm-100_star-100_bh-001.hdf5', 
                 ipath+'Illustris-1/output/', 'g', 2.5),
        #'TNG100-1-N010-bh0' : ('galaxy-mergers_TNG100-1_gas-000_dm-000_star-010_bh-000.hdf5', 
        #                       tpath+'TNG100-1/output/', 'darkblue', 2.5),
        #'TNG100-1-bh0' : ('galaxy-mergers_TNG100-1_gas-100_dm-100_star-100_bh-000.hdf5', 
        #                  tpath+'TNG100-1/output/', 'b', 1.5),
        'TNG100-1' : ('galaxy-mergers_TNG100-1_gas-100_dm-100_star-100_bh-001.hdf5', 
                      tpath+'TNG100-1/output/', 'b', 2.5),
        #---'TNG100-1-nomprog' : ('galaxy_merger_files_with_no_mprog/galaxy-mergers_TNG100-1_gas-100_dm-100_star-100_bh-001.hdf5', 
        #---                      tpath+'TNG100-1/output/', 'b', 2.5),
        #---'TNG100-1-bh0-nomprog' : ('galaxy_merger_files_with_no_mprog/galaxy-mergers_TNG100-1_gas-100_dm-100_star-100_bh-000.hdf5', 
        #---                          tpath+'TNG100-1/output/', 'b', 1.5),
        #---'TNG100-1-N012-bh0' : ('galaxy_merger_files_with_no_mprog/galaxy-mergers_TNG100-1_gas-012_dm-012_star-012_bh-000.hdf5', 
        #---                       tpath+'TNG100-1/output/', 'darkblue', 2.5),
        #'TNG100-2' : ('galaxy-mergers_TNG100-2_gas-012_dm-012_star-012_bh-001.hdf5', 
        #              tpath+'TNG100-1/output/', 'c', 1.5),
        #'TNG300-1' : ('galaxy-mergers_TNG300-1_gas-012_dm-012_star-012_bh-001.hdf5', tpath+'TNG300-1/output/', 'm', 1.5),
        #'TNG300-1-bh0' : ('galaxy-mergers_TNG300-1_gas-012_dm-012_star-012_bh-000.hdf5', tpath+'TNG300-1/output/', 'm', 1.0),
        #'TNG300-1-N100' : ('galaxy-mergers_TNG300-1_gas-100_dm-100_star-100_bh-001.hdf5', tpath+'TNG300-1/output/', 'pink', 1.5),
        #'TNG300-1-N100-bh0' : ('galaxy-mergers_TNG300-1_gas-100_dm-100_star-100_bh-000.hdf5', tpath+'TNG300-1/output/', 'pink', 1)
    }

    # ---- Loop thru dict and create dpops
    for run, attrs in dpop_attrs.items():
        if ('Ill' in run) and (not inclIll):
            continue
        if (run == 'oldIll') and (not inclOldIll):
            continue
        if ('TNG50' in run) and (not inclT50):
            continue
        if ('TNG300' in run) and (not inclT300):
            continue

        # Catalog-separation population; '-bh0' catalogs are only used for the fsa variant.
        if '-bh0' not in run:
            dp = Discrete(freqs, freqs_edges, lbl=run, tau=tau, fixed_sepa=None, nreals=nreals,
                          allow_mbh0=allow_mbh0, skip_evo=skip_evo, attrs=attrs,
                          use_mstar_tot_as_mbulge=use_mstar_tot_as_mbulge)
            all_dpops.append(dp)
            if 'Ill' not in run:
                tng_dpops.append(dp)
        else:
            print(f"Skipping run {run} with bh0")

        if fsa is not None:
            # 'mm' in the label marks populations whose masses were reset with M-Mbulge.
            lbl = 'fsa-mm-'+run if mod_mmbulge else 'fsa-'+run
            dp_fsa = Discrete(freqs, freqs_edges, lbl=lbl, tau=tau, fixed_sepa=fsa, nreals=nreals,
                              allow_mbh0=allow_mbh0, skip_evo=skip_evo, attrs=attrs,
                              mod_mmbulge=mod_mmbulge, use_mstar_tot_as_mbulge=use_mstar_tot_as_mbulge)
            all_fsa_dpops.append(dp_fsa)
            if 'Ill' not in run:
                tng_fsa_dpops.append(dp_fsa)

            if ('TNG300' in run) and inclT300 and inclRescale:
                rescale_dp_fsa = Discrete(freqs, freqs_edges, lbl='rescale-fsa-mm-'+run, tau=tau, fixed_sepa=fsa,
                                          nreals=nreals, allow_mbh0=allow_mbh0, skip_evo=skip_evo, attrs=attrs,
                                          mod_mmbulge=True, use_mstar_tot_as_mbulge=use_mstar_tot_as_mbulge,
                                          rescale_mbulge=True)
                tng_fsa_dpops.append(rescale_dp_fsa)

        print(f"{run} dpop_attrs: {attrs[0]} {attrs[1]} {attrs[2]} {attrs[3]}")

    if fsa is not None:
        return all_dpops, tng_dpops, all_fsa_dpops, tng_fsa_dpops
    return all_dpops, tng_dpops

In [None]:
# Build the discrete populations (catalog and fixed-separation variants).
# Evolution/GWB is skipped since only mass functions are examined below, no M-Mbulge
# mass reset is applied, and allow_mbh0=True is forwarded to every population.
all_dpops, tng_dpops, all_fsa_dpops, tng_fsa_dpops = create_dpops(
    allow_mbh0=True,
    mod_mmbulge=False,
    skip_evo=True,
)

In [None]:
print("creating SAM using Galaxy Pair Fraction (GPF) + Galaxy Merger Timescale (GMT)...")

# Merger rate built from a power-law galaxy pair fraction.
gpf_power_law = sams.GPF_Power_Law()
sam = sams.Semi_Analytic_Model(gpf=gpf_power_law)

# Hardening and the GWB are not computed for this SAM in this notebook.
print("SKIPPING HARDENING")
# Recipe to re-enable them (reuses tau / frequency edges of the fixed-separation pops):
#print("    ...calculating hardening")
#hard = holo.hardening.Fixed_Time_2PL_SAM(sam, all_fsa_dpops[0].tau, sepa_init=1.0e4*PC)
#print("    ...creating gwb")
#gwb_sam = sam.gwb_new(all_fsa_dpops[0].freqs_edges, hard, realize=500)

In [None]:
no_gpf_message = ("creating 'no-GPF' SAM using Galaxy Merger Rate (GMR) "
                  "    (uses galaxy merger rates directly from RG15 instead of GPF+GMT)...")
print(no_gpf_message)

# No `gpf` argument is given here (the message above says the GMR is used instead).
sam_no_gpf = sams.Semi_Analytic_Model()

# Hardening and the GWB are not computed for this SAM in this notebook.
print("SKIPPING HARDENING")
# Recipe to re-enable them:
#print("    ...calculating hardening for no-GPF SAM")
#hard_no_gpf = holo.hardening.Fixed_Time_2PL_SAM(sam_no_gpf, all_fsa_dpops[0].tau, sepa_init=1.0e4*PC)
#print("    ...creating gwb for no-GPF SAM")
#gwb_sam_no_gpf = sam_no_gpf.gwb_new(all_fsa_dpops[0].freqs_edges, hard_no_gpf, realize=50)

In [None]:
# Merging-galaxy GSMFs of the TNG catalog-separation populations.
fig, ax = compare_sim_merging_gsmfs(tng_dpops)
ax.legend()
plt.show()

In [None]:
# Merging-galaxy GSMFs of the TNG fixed-separation populations.
fig, ax = compare_sim_merging_gsmfs(tng_fsa_dpops)
ax.legend()
plt.show()

In [None]:
# Merging-galaxy GSMFs of all catalog-separation populations.
fig, ax = compare_sim_merging_gsmfs(all_dpops)
ax.legend()
plt.show()

In [None]:
# Merging-galaxy GSMFs of all fixed-separation populations.
fig, ax = compare_sim_merging_gsmfs(all_fsa_dpops)
ax.legend()
plt.show()

In [None]:
# Multi-panel comparison of merging-galaxy mass functions. Each entry of `dpop_lists`
# fills one panel (column-major). Every population is drawn with and without a cut on
# the bulge mass ratio q_bulge (smaller / larger of the two progenitor bulge masses).
nrows = 2
ncols = 2
mass_type = 'tot'
qcut = 0.1
# squeeze=False keeps `axes` 2-D for any nrows/ncols.
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex='all', sharey='all',
                         figsize=[10, 9], squeeze=False)

xlabels = {
    'pri': r'$\log_{10}(M_{*,pri})$',
    'tot': r'$\log_{10}(M_{*,tot})$',
    # annotation kept outside math mode so it renders as normal, spaced text
    'all': r'$\log_{10}(M_{*})$ (all progenitors)',
}
if mass_type not in xlabels:
    raise ValueError(f"`mass_type` must be one of {list(xlabels)}, got {mass_type!r}")
xlab = xlabels[mass_type]
# Every panel shows a GSMF from calc_sim_merging_gsmf.
ylab = r'GSMF [$(\log_{10} M)^{-1} Mpc^{-3}$]'

for ax in axes.flat:
    ax.set(yscale='log', xlim=(7.25, 15.5), ylim=(1.0e-6, 1.0))
for ax in axes[-1, :]:
    ax.set(xlabel=xlab)
for ax in axes[:, 0]:
    ax.set(ylabel=ylab)

dpop_lists = [all_fsa_dpops]

# Fill only as many panels as there are population lists (and at most nrows*ncols).
for k, panel_dpops in enumerate(dpop_lists[:nrows * ncols]):
    ax = axes[k % nrows, k // nrows]
    for dp in panel_dpops:
        qbulge = dp.pop.mbulge[:, 0] / dp.pop.mbulge[:, 1]
        qbulge[qbulge > 1] = 1.0 / qbulge[qbulge > 1]

        x, y = calc_sim_merging_gsmf(dp, mass_type=mass_type)
        line, = ax.plot(x, y, lw=dp.lw, label=dp.lbl)

        # Same colour as the uncut curve so each pair is easy to match up.
        xqcut, yqcut = calc_sim_merging_gsmf(dp, mask=(qbulge > qcut), mass_type=mass_type)
        ax.plot(xqcut, yqcut, lw=dp.lw, color=line.get_color(), alpha=0.4,
                label=f'{dp.lbl} q>{qcut:g}')
    if panel_dpops:
        ax.legend()

fig.suptitle(f"GSMF for merging galaxies (mass type = {mass_type})")
fig.subplots_adjust(top=0.95, wspace=0.1, hspace=0.1)
plt.show()

### Load GSMFs of all subhalos (precomputed per snapshot) & compare with those of merging galaxies

In [None]:
def calc_gsmf_from_snap(basePath, snapNum, req_binsize=0.05, verbose=False):
    """Stellar mass function of all subhalos in one snapshot, from precomputed histograms.

    Reads ``<basePath>/gsmf_all_snaps_Nmin1.hdf5``, selects snapshot ``snapNum`` and
    optionally merges adjacent native bins to approach ``req_binsize``.

    Parameters
    ----------
    basePath : str
        Directory containing ``gsmf_all_snaps_Nmin1.hdf5``.
    snapNum : int
        Snapshot number; must appear exactly once in the file's ``SnapshotNums``.
    req_binsize : float
        Requested bin width in dex. It is rounded down to an integer multiple ``ncomb``
        of the native width ``LogMassBinWidth``. It must be at least the native width,
        and ``ncomb`` may not exceed half the number of native bins.
    verbose : bool
        Print intermediate shapes and bin edges.

    Returns
    -------
    gsmf : ndarray
        counts / bin width (dex) / ln(10) / ``box_volume_mpc``, per (merged) bin.
    mbins : ndarray
        Bin centres in log10 stellar mass; same length as ``gsmf``.

    Raises
    ------
    ValueError
        If ``snapNum`` is not found, or ``req_binsize`` is out of range.
    """
    fname = os.path.join(basePath, "gsmf_all_snaps_Nmin1.hdf5")
    # Read everything up front inside `with` so the file is always closed.
    with h5py.File(fname, "r") as f:
        snapnums = np.asarray(f.attrs['SnapshotNums'])
        dlgm_orig = f.attrs['LogMassBinWidth']
        mbin_edges_orig = np.array(f['StellarMassBinEdges'])
        mhist_all_snaps = np.array(f['StellarMassHistograms'])
        box_vol_mpc = f.attrs['box_volume_mpc']
        scalefacs = np.asarray(f.attrs['SnapshotScaleFacs'])

    nbins_orig = mbin_edges_orig.size - 1
    sel = (snapnums == snapNum)
    mhist_snap_orig = mhist_all_snaps[:, sel].flatten()
    if mhist_snap_orig.size != nbins_orig:
        raise ValueError(f"snapshot {snapNum} not found exactly once in {fname} "
                         f"(got {mhist_snap_orig.size} values, expected {nbins_orig})")

    scalefac = scalefacs[sel]
    zsnap = 1.0 / scalefac - 1.0
    print(f"{scalefac=}, {zsnap=}")

    # Small relative tolerance so float round-off (e.g. 0.5/0.05 -> 9.999...) does not
    # drop a bin multiple or reject a request equal to the native width.
    tol = 1.0e-6
    ratio = req_binsize / dlgm_orig
    if ratio < 1.0 - tol:
        raise ValueError(f"{req_binsize=} requested, but min allowed is {dlgm_orig=}")
    ncomb = max(1, int(np.floor(ratio * (1.0 + tol))))
    if ncomb > nbins_orig / 2:
        raise ValueError(f"{req_binsize=} requested, but max allowed is {dlgm_orig*nbins_orig/2=}")

    dlgm = dlgm_orig * ncomb
    if ncomb > 1:
        # ceil(nbins_orig / ncomb) merged bins; the last may cover fewer native bins.
        nbins = -(-nbins_orig // ncomb)
        padded = np.zeros(nbins * ncomb)
        padded[:nbins_orig] = mhist_snap_orig
        mhist_snap = padded.reshape(nbins, ncomb).sum(axis=1)
        mbin_edges = mbin_edges_orig[::ncomb]
        # When nbins_orig is not a multiple of ncomb, the final merged edge is missing.
        if mbin_edges.size < nbins + 1:
            mbin_edges = np.append(mbin_edges, mbin_edges[-1] + dlgm)
    else:
        print(f"WARNING: {req_binsize=}, {ncomb=}; retaining original binsize {dlgm_orig=}")
        mhist_snap = mhist_snap_orig
        mbin_edges = mbin_edges_orig

    if verbose:
        print(f"{mhist_snap_orig.shape=}, {mbin_edges_orig.shape=}")
        print(f"{mhist_all_snaps.shape=}, {mhist_all_snaps.min()=}, {mhist_all_snaps.max()=}")
        print(f"{mhist_snap.shape=}, {mhist_snap.min()=}, {mhist_snap.max()=}")
        print(f"{snapnums=}")
        print(f"{ncomb=}, {dlgm=}, {mbin_edges.shape=}")
        print(f"{mbin_edges_orig=}")
        print(f"{mbin_edges=}")

    gsmf = mhist_snap / dlgm / np.log(10) / box_vol_mpc  # dex^-1 Mpc^-3
    return gsmf, mbin_edges[:-1] + 0.5 * dlgm

#import illustris_python as il
#def calc_gsmf_from_snap(basePath,snapNum):
#    subs = il.groupcat.loadSubhalos(basePath,snapNum,fields=['SubhaloLenType','SubhaloMassType',
#                                                             'SubhaloMassInHalfRadType','SubhaloMassInRadType'])
#    mstar = subs['SubhaloMassType'][:,4]*1.0e10/0.704
#    Ngas = subs['SubhaloLenType'][:,0]
#    Ndm = subs['SubhaloLenType'][:,1]
#    Nstar = subs['SubhaloLenType'][:,4]
#    Nbh = subs['SubhaloLenType'][:,5]
    
#    mstar = mstar[(Ngas>100)&(Ndm>100)&(Nstar>100)&(Nbh>0)]
#    lgmstar = np.log10(mstar)
#    #print(lgmstar.min(),lgmstar.max())
#    mstar_hist, bin_edges = np.histogram(lgmstar, bins=20)
#    binsize = bin_edges[1:] - bin_edges[:-1]
#    box_vol_mpc = 1357213.6324803103 # TNG100  ## BAD!! THIS IS BAD. 
#    gsmf = mstar_hist / binsize / np.log(10) / box_vol_mpc # dex^-1 Mpc^-3
#    #print(gsmf)
#    #print(bin_edges)
#    return gsmf, bin_edges

In [None]:
# Sanity check: a stride-1 slice reproduces the array element-for-element.
# (Comparing `x.all()` values only compares two booleans, so it cannot detect differences.)
_rng = np.random.default_rng(0)
x = _rng.random(100)
assert np.array_equal(x, x[::1]), 'whoops'

In [None]:
tng100_1_basePath = '/orange/lblecha/IllustrisTNG/Runs/TNG100-1/output/'
coarse_binsize = 0.5  # dex

# Snapshot -> redshift: 99 -> z=0, 17 -> z=5, 33 -> z=2, 67 -> z=0.5, 72 -> z=0.4.
# The `bedg_*` names hold the bin centres returned by calc_gsmf_from_snap.
gsmf_tng100_1_s99, bedg_tng100_1_s99 = calc_gsmf_from_snap(tng100_1_basePath, 99, req_binsize=coarse_binsize)
gsmf_tng100_1_s17, bedg_tng100_1_s17 = calc_gsmf_from_snap(tng100_1_basePath, 17, req_binsize=coarse_binsize)
gsmf_tng100_1_s33, bedg_tng100_1_s33 = calc_gsmf_from_snap(tng100_1_basePath, 33, req_binsize=coarse_binsize)
gsmf_tng100_1_s67, bedg_tng100_1_s67 = calc_gsmf_from_snap(tng100_1_basePath, 67, req_binsize=coarse_binsize)
# z=0.5 again with the default (fine) bin width, for comparison with the merged bins.
alt_gsmf_tng100_1_s67, alt_bedg_tng100_1_s67 = calc_gsmf_from_snap(tng100_1_basePath, 67)
gsmf_tng100_1_s72, bedg_tng100_1_s72 = calc_gsmf_from_snap(tng100_1_basePath, 72, req_binsize=coarse_binsize)
print(f"{gsmf_tng100_1_s67.shape=}, {bedg_tng100_1_s67.shape=}")

In [None]:
# Check of bin merging at z=0.5: coarse (req_binsize=0.5) vs default-width TNG100-1 GSMF.
# calc_gsmf_from_snap returns bin centres (stored in the bedg_* names), so plot them directly.
bw = bedg_tng100_1_s67[1] - bedg_tng100_1_s67[0]
altbw = alt_bedg_tng100_1_s67[1] - alt_bedg_tng100_1_s67[0]
ncomb = int(round(bw / altbw))

fig, ax = plt.subplots()
ax.set(yscale='log', xlabel=r'$\log_{10}(M_{*})$', ylabel=r'GSMF [dex$^{-1}$ Mpc$^{-3}$]')
ax.plot(bedg_tng100_1_s67, gsmf_tng100_1_s67, 'c--',
        label=f'TNG100-1 (all gals, z=0.5), {bw:.2f} dex bins')
ax.plot(alt_bedg_tng100_1_s67, alt_gsmf_tng100_1_s67, 'c',
        label=f'TNG100-1 (all gals, z=0.5), {altbw:.2f} dex bins')
ax.legend()
print(f"{bedg_tng100_1_s67=}")
print(f"{alt_bedg_tng100_1_s67=}")
# A merged bin's GSMF is the mean of the finer bins it covers, so these should agree
# when the coarse width is an integer multiple of the fine width.
print(gsmf_tng100_1_s67[0], alt_gsmf_tng100_1_s67[:ncomb].mean())
plt.show()

In [None]:
tng50_1_basePath = '/orange/lblecha/IllustrisTNG/Runs/TNG50-1/output/'
coarse_binsize = 0.5  # dex

# Snapshot -> redshift: 99 -> z=0, 17 -> z=5, 33 -> z=2, 67 -> z=0.5, 72 -> z=0.4.
# The `bedg_*` names hold the bin centres returned by calc_gsmf_from_snap.
gsmf_tng50_1_s99, bedg_tng50_1_s99 = calc_gsmf_from_snap(tng50_1_basePath, 99, req_binsize=coarse_binsize)
gsmf_tng50_1_s17, bedg_tng50_1_s17 = calc_gsmf_from_snap(tng50_1_basePath, 17, req_binsize=coarse_binsize)
gsmf_tng50_1_s33, bedg_tng50_1_s33 = calc_gsmf_from_snap(tng50_1_basePath, 33, req_binsize=coarse_binsize)
gsmf_tng50_1_s67, bedg_tng50_1_s67 = calc_gsmf_from_snap(tng50_1_basePath, 67, req_binsize=coarse_binsize)
# z=0.5 again with the default (fine) bin width.
alt_gsmf_tng50_1_s67, alt_bedg_tng50_1_s67 = calc_gsmf_from_snap(tng50_1_basePath, 67)
gsmf_tng50_1_s72, bedg_tng50_1_s72 = calc_gsmf_from_snap(tng50_1_basePath, 72, req_binsize=coarse_binsize)

In [None]:
# Stellar-mass grid for evaluating the SAM GSMF: log10(M*/Msun) = 8.0 ... 13.0 in 0.1 dex steps.
log10_mstar_grid = np.arange(8.0, 13.1, 0.1)
mstar_arr = np.power(10.0, log10_mstar_grid) * MSOL


In [None]:
# Analytic SAM GSMF at several redshifts, plus its values at the SAM's own primary masses.
# sam.mass_stellar() returns (mstar_pri, mstar_rat, mstar_tot, redz).
mstar_pri, mstar_rat, mstar_tot, redz = sam.mass_stellar()
print(mstar_pri.shape)
print(mstar_pri.min()/MSOL, mstar_pri.max()/MSOL)
print(mstar_arr.min()/MSOL, mstar_arr.max()/MSOL)

# Same /ln(10) factor as the simulation GSMFs in this notebook (labelled dex^-1 Mpc^-3).
gsmf_func_z0 = sam._gsmf(mstar_arr, 0.0) / np.log(10)
gsmf_func_z0pt4 = sam._gsmf(mstar_arr, 0.4) / np.log(10)
gsmf_func_z0pt5 = sam._gsmf(mstar_arr, 0.5) / np.log(10)
gsmf_func_z3 = sam._gsmf(mstar_arr, 3.0) / np.log(10)
gsmf_func_z5 = sam._gsmf(mstar_arr, 5.0) / np.log(10)

print(f"phi0={sam._gsmf._phi0}, phiz={sam._gsmf._phiz}")
print(f"log10(mchar0/msun)={np.log10(sam._gsmf._mchar0/MSOL):g}, log10(mcharz/msun)={np.log10(sam._gsmf._mcharz/MSOL):g}")
print(f"alpha0={sam._gsmf._alpha0}, alphaz={sam._gsmf._alphaz}")

fig, ax = plt.subplots()
ax.set(xscale='log', yscale='log', xlim=(5.0e7, 2.0e12), ylim=(1.0e-6, 0.01),
       xlabel=r'$M_* \; [M_\odot]$', ylabel=r'GSMF [dex$^{-1}$ Mpc$^{-3}$]')
ax.plot(mstar_arr/MSOL, gsmf_func_z0, lw=3, label='SAM z=0')
ax.plot(mstar_arr/MSOL, gsmf_func_z0pt5, lw=3, label='SAM z=0.5')
ax.plot(mstar_arr/MSOL, gsmf_func_z5, lw=3, label='SAM z=5')

# GSMF evaluated at the SAM grid's primary stellar masses (should trace the curves above).
mstar_pri_flat = mstar_pri.flatten()
ax.plot(mstar_pri_flat/MSOL, sam._gsmf(mstar_pri_flat, 0.0)/np.log(10), '.', lw=0, ms=0.1, alpha=0.2)
ax.plot(mstar_pri_flat/MSOL, sam._gsmf(mstar_pri_flat, 5.0)/np.log(10), '.', lw=0, ms=0.1, alpha=0.2)

# Reference markers: 10**(phi0 + phiz*z) at z=0, and the characteristic mass mchar0.
phi_check_z0 = np.power(10.0, sam._gsmf._phi0 + sam._gsmf._phiz * 0.0)
m0_check = sam._gsmf._mchar0 / MSOL
ax.plot([1.0e8, 1.0e12], [phi_check_z0, phi_check_z0], label=r'$10^{\phi_0}$')
ax.plot([m0_check, m0_check], [1.0e-6, 1.0e-2], label=r'$M_{\rm char,0}$')
ax.legend()
plt.show()

In [None]:
# Merging-galaxy GSMFs vs the SAM and the all-subhalo GSMFs at z=0.5 (snapshot 67).
fig, ax = compare_sim_merging_gsmfs(all_dpops[:2])
ax.set(xlim=(7.8, 13.2), ylim=(1.0e-7, 0.1))
ax.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0pt5, 'k', lw=2, label='SAM (z=0.5)')
# calc_gsmf_from_snap returns bin centres (the bedg_* names), so no [:-1] trimming is needed.
ax.plot(bedg_tng50_1_s67, gsmf_tng50_1_s67, 'r--', label='TNG50-1 (all gals, z=0.5)')
ax.plot(bedg_tng100_1_s67, gsmf_tng100_1_s67, 'c--', label='TNG100-1 (all gals, z=0.5)')
ax.legend()
plt.show()

In [None]:
# Merging-galaxy GSMFs vs the SAM and the all-subhalo GSMFs at z=0 (snapshot 99).
fig, ax = compare_sim_merging_gsmfs(all_dpops[:2])
ax.set(xlim=(7.8, 13.2), ylim=(1.0e-7, 0.1))
ax.plot(np.log10(mstar_arr/MSOL), gsmf_func_z0, 'k', lw=2, label='SAM (z=0)')
# bedg_* hold bin centres returned by calc_gsmf_from_snap.
ax.plot(bedg_tng50_1_s99, gsmf_tng50_1_s99, 'r--', label='TNG50-1 (all gals, z=0)')
ax.plot(bedg_tng100_1_s99, gsmf_tng100_1_s99, 'c--', label='TNG100-1 (all gals, z=0)')
ax.legend()
plt.show()

In [None]:
# Merging-galaxy GSMFs vs the SAM and the all-subhalo GSMFs at z=2.
# Snapshot 33 is z=2 (see the snapshot loads above), so the SAM is evaluated at z=2 to match.
gsmf_func_z2 = sam._gsmf(mstar_arr, 2.0) / np.log(10)  # units of dex^-1 Mpc^-3

fig, ax = compare_sim_merging_gsmfs(all_dpops[:2])
ax.set(xlim=(7.8, 13.2), ylim=(1.0e-7, 0.1))
ax.plot(np.log10(mstar_arr/MSOL), gsmf_func_z2, 'k', lw=2, label='SAM (z=2)')
# bedg_* hold bin centres returned by calc_gsmf_from_snap.
ax.plot(bedg_tng50_1_s33, gsmf_tng50_1_s33, 'r--', label='TNG50-1 (all gals, z=2)')
ax.plot(bedg_tng100_1_s33, gsmf_tng100_1_s33, 'c--', label='TNG100-1 (all gals, z=2)')
ax.legend()
plt.show()