# Creating the figures for the Cooler Past paper
## Written by Eric Rohr

In [None]:
### import modules
import illustris_python as il
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.patheffects as pe
import matplotlib.transforms as transforms
from matplotlib.gridspec import GridSpec
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import matplotlib.patches as patches
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.ndimage import gaussian_filter
from scipy import ndimage
from scipy.interpolate import interp1d
from scipy import interpolate
from tenet.util import sphMap
import scipy.stats
from scipy.stats import norm
from sklearn.neighbors import KernelDensity
from scipy.stats import ks_2samp, anderson_ksamp
from scipy.optimize import curve_fit
import os
import time
import h5py
import rohr_utils as ru 
import random
import six
%matplotlib widget

# apply the project matplotlib style sheet (a file path, presumably resolved
# relative to the notebook's starting directory -- it is loaded before the chdir below)
plt.style.use('fullpage.mplstyle')

# snapshot redshifts and cosmic times from rohr_utils; both arrays are indexed by
# snapshot number (e.g. times[SnapNum] is used for the redshift tick marks)
zs, times = ru.return_zs_costimes()
times /= 1.0e9 # [Gyr]
scales = 1. / (1.+ zs)  # scale factor a = 1 / (1 + z)

# later cells use paths relative to this directory (e.g. '../Output/...')
os.chdir('/u/reric/Scripts/')
# IPython shell escape: print the new working directory
! pwd

In [None]:
# ---- plotting parameters ----------------------------------------------------
# full-page figures: MNRAS text width (inches) at a 16:9 landscape aspect ratio
figsizewidth  = 6.902 # the textwidth in inches of MNRAS
figsizeratio = 9. / 16.
figsizeheight = figsizewidth * figsizeratio

# single-column figures: 244/508 of the text width, same aspect ratio
figsizewidth_column = (244. / 508.) * figsizewidth
figsizeheight_column = figsizewidth_column * figsizeratio

# ---- output locations ---------------------------------------------------------
# local figure archive, and the paper's figure directory (presumably Overleaf-synced)
outdirec_figures = '/u/reric/Figures/ColdPast/TNG-Cluster/'
outdirec_overleaf = '/u/reric/Papers/Rohretal_TNG_CoolerPast/figures/'
outdirecs = [outdirec_figures,
             outdirec_overleaf]

# master switch: when True, figure cells save their PDF into every entry of outdirecs
savefig = False

In [None]:
# define some plotting functions that are useful everywhere
def _add_redshift_axis(ax, ticks_SnapNum, ticks_labels, xlolim, xhilim=14.1,
                       label=True, axislabel_kwargs=None):
    """
    Shared implementation of the add_redshift_sincez* helpers.

    Set the x-limits of ax (cosmic time [Gyr]) to (xlolim, xhilim), add a twin
    top x-axis with major ticks at the cosmic times of the snapshots in
    ticks_SnapNum, and optionally label it with ticks_labels. If ax uses a log
    y-scale, minor y ticks are placed at 0.2-0.9 of each decade.
    Returns ax (not the twin axis).
    """
    if axislabel_kwargs is None:
        axislabel_kwargs = {}

    ax.set_xlim(xlolim, xhilim)

    # the twin axis copies ax's limits once, so its ticks line up with cosmic time;
    # NOTE: if ax's xlim is changed afterwards the two axes no longer match
    redshift_ax = ax.twiny()
    redshift_ax.set_xlim(ax.get_xlim())
    redshift_ax.set_xticks(times[ticks_SnapNum])
    redshift_ax.tick_params(axis='both', which='minor', top=False)

    if ax.get_yscale() == 'log':
        locmin = mpl.ticker.LogLocator(subs=(0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9))
        ax.yaxis.set_minor_locator(locmin)

    if label:
        redshift_ax.set_xlabel(r'Redshift', **axislabel_kwargs)
        redshift_ax.set_xticklabels(ticks_labels)
    else:
        redshift_ax.set_xticklabels([])

    return ax


def add_redshift_sincez2(ax, label=True, axislabel_kwargs=None):
    """
    For a given x axis (cosmic time [Gyr]), add redshift since z=2 to the top x-axis.
    The lower x-limit comes from the cosmic time of the snapshot closest to z=2.5;
    the upper x-limit is 14.1 Gyr.
    Optionally label the axis + tick marks.
    Returns ax
    """
    # ru.floor_to_value presumably rounds down to a multiple of 0.1 Gyr
    xlolim = ru.floor_to_value(times[np.argmin(abs(zs - 2.5))], 0.1)
    return _add_redshift_axis(ax,
                              ticks_SnapNum=[33, 40, 50, 59, 67, 78, 84, 91, 99],
                              ticks_labels=['2', '1.5', '1', '0.7', '0.5', '0.3', '0.2', '0.1', '0'],
                              xlolim=xlolim, label=label, axislabel_kwargs=axislabel_kwargs)


def add_redshift_sincez5(ax, label=True, axislabel_kwargs=None):
    """
    For a given x axis (cosmic time [Gyr]), add redshift since z=5 to the top x-axis.
    The x-limits are fixed to (0.9, 14.1) Gyr.
    Optionally label the axis + tick marks.
    Returns ax
    """
    return _add_redshift_axis(ax,
                              ticks_SnapNum=[17, 33, 50, 67, 84, 99],
                              ticks_labels=['5', '2', '1', '0.5', '0.2', '0'],
                              xlolim=0.9, label=label, axislabel_kwargs=axislabel_kwargs)


def add_redshift_sincez7(ax, label=True, axislabel_kwargs=None):
    """
    For a given x axis (cosmic time [Gyr]), add redshift since z=7 to the top x-axis.
    The x-limits are fixed to (0.9, 14.1) Gyr.
    Optionally label the axis + tick marks.
    Returns ax
    """
    return _add_redshift_axis(ax,
                              ticks_SnapNum=[11, 21, 33, 50, 67, 84, 99],
                              ticks_labels=['7', '4', '2', '1', '0.5', '0.2', '0'],
                              xlolim=0.9, label=label, axislabel_kwargs=axislabel_kwargs)


In [None]:
# smoothing functions for evolution quantities
def noSmoothEvolution(group, xdset_key, ydset_key):
    """
    Identity 'smoothing': hand back the raw x and y datasets of group unchanged.
    Returns xdset, ydset.
    """
    xdset = group[xdset_key]
    ydset = group[ydset_key]
    return xdset, ydset


def smoothSubhaloIndicesEvolution(group, xdset_key, ydset_key):
    """
    Fill the snapshots where the subhalo is undefined (SubfindID < 0) by linear
    interpolation in xdset between the defined snapshots; outside the range
    covered by the defined snapshots the result is 0.
    Returns xdset (unchanged), ydset (interpolated onto every xdset value).
    """
    xdset = group[xdset_key]
    defined = group['SubfindID'] >= 0
    interpolate_y = interp1d(xdset[defined], group[ydset_key][defined],
                             bounds_error=False, fill_value=0)
    return xdset, interpolate_y(xdset)
    

def smoothRunningMedianEvolution(group, xdset_key, ydset_key, nRM=5):
    """ 
    Compute the running median over nRM snapshots of ydset, using only the
    snapshots where the subhalo is defined (SubfindID >= 0) and ydset > 0, then
    interpolate it back onto the full xdset (0 outside the sampled range).

    Parameters
    ----------
    group : dict-like of per-snapshot arrays, including 'SubfindID'.
    xdset_key, ydset_key : dataset keys of the x (e.g. time) and y quantities.
    nRM : window length passed to ru.RunningMedian.

    Returns xdset, ydset. If fewer than two positive samples exist, ydset is all
    zeros (the same fill value used outside the sampled range).
    """
    xdset = group[xdset_key]

    subhalo_indices = group['SubfindID'] >= 0
    _xdset = group[xdset_key][subhalo_indices]
    _ydset = group[ydset_key][subhalo_indices]

    # only positive values enter the running median
    mask = (_ydset > 0)
    if np.count_nonzero(mask) < 2:
        # interp1d needs at least two samples to build an interpolant
        return xdset, np.zeros(len(xdset))

    # honour the caller's window length (previously hard-coded to 5)
    y_rm = ru.RunningMedian(_ydset[mask], nRM)
    ydset_func = interp1d(_xdset[mask], y_rm, bounds_error=False, fill_value=0)

    return xdset, ydset_func(xdset)



## TNG-Cluster

In [None]:
# load the gas radial profile dictionary
# dataset keys that create_taudict extracts for every group
grp_keys = [
    # snapshot bookkeeping
    'SnapNum',
    'SubfindID',
    'CosmicTime',
    # host group / subhalo properties
    'HostGroup_M_Crit200',
    'HostGroup_R_Crit200',
    'SubhaloMass',
    'HostSubhaloGrNr',
    'Subhalo_Mstar_Rgal',
    'Subhalo_Rgal',
    # gas masses
    'SubhaloColdGasMass',
    'SubhaloHotGasMass',
    'SubhaloGasMass',
    # radial datasets ('radii' / '*Shells' keys are skipped by create_taudict)
    'radii',
    'SubhaloColdGasMassShells',
    # cool CGM/ICM gas
    'SubhaloCGMColdGasMass',
    'SubhaloCGMColdGasFraction',
]

def load_grpdict(infname, sim='L680n8192TNG', keys=None, basepath='../Output'):
    """
    Read every top-level group of '<basepath>/<sim>_subfindGRP/<infname>' into memory.

    Parameters
    ----------
    infname : file name inside the subfindGRP directory.
    sim : simulation name used to build the directory name.
    keys : dataset names to load from each group; names missing from a group are
        skipped. If None or empty, every dataset of each group is loaded.
    basepath : parent directory of the subfindGRP directory (relative paths are
        resolved against the current working directory).

    Returns
    -------
    dict {group_key: {dset_key: np.ndarray}}. Datasets whose name contains 'xray'
    are never loaded.
    """
    fname = '%s/%s_subfindGRP/%s'%(basepath, sim, infname)
    result = {}
    with h5py.File(fname, 'r') as f:
        for group_key, group in f.items():
            # choose the dataset names per group, so a group with extra datasets is
            # not limited to the first group's key set
            group_dset_keys = keys if keys else list(group.keys())
            result[group_key] = {}
            for dset_key in group_dset_keys:
                if 'xray' in dset_key or dset_key not in group:
                    continue
                result[group_key][dset_key] = group[dset_key][:]

    return result

# NB: the output file and directory are at L680n8192TNG, while all figures are saved as 'TNG-Cluster'
sim = 'L680n8192TNG'

# per-branch data of the central subhalos (presumably their main progenitor branches),
# read from ../Output/L680n8192TNG_subfindGRP/ relative to the working directory
infname = 'central_subfind_%s_branches.hdf5'%sim
TNGCluster_grp_dict = load_grpdict(infname, sim)
TNGCluster_grp_dict_keys = list(TNGCluster_grp_dict.keys())
# rebind sim to the name used in saved figure filenames
sim = 'TNG-Cluster'



In [None]:
# reformat the grp_dict into the tau_dict

# dataset keys of the cool CGM/ICM gas quantities
CGMColdGasMass_key = 'SubhaloCGMColdGasMass'
fCGMColdGas_key = 'SubhaloCGMColdGasFraction'

# dataset keys of the main black hole. BH_RM_FirstSnap_key is compared against
# 'SnapNum' by create_taudict when bh_rm_firstsnap_flag is requested
# (RM presumably = radio mode -- confirm)
bh_mass_key = 'MainBHMass'
bh_particleID_key = 'MainBHParticleID'
BH_CumEgyInjection_RM_key = 'MainBH_CumEgyInjection_RM'
BH_RM_FirstSnap_key = 'MainBH_RM_FirstSnap'

# satellite-count dataset keys, extracted by create_taudict when branches_flag=True.
# NOTE(review): these strings must match the HDF5 dataset names byte-for-byte, so the
# mixed 'Nsatellties' / 'NSatellites' spellings are presumably what the file contains --
# check the file before "fixing" them.
Nsats_total_key = 'Nsatellites_total'
Nsats_dr200_key = 'Nsatellites_dsathost<R200c'
Nsats_mstar1e7_dr200c_key = 'Nsatellties_Mstar>1.0e7_dsathost<R200c'
Nsats_mstar1e7_fgas_dr200c_key = 'NSatellites_Mstar>1.0e7_fgas>0.01_dsathost<R200c'
Nsats_mstar1e9_dr200c_key = 'Nsatellties_Mstar>1.0e9_dsathost<R200c'
Nsats_mstar1e9_fgas_dr200c_key = 'NSatellites_Mstar>1.0e9_fgas>0.01_dsathost<R200c'
# NOTE(review): this value is identical to Nsats_mstar1e10_fgas_dr200c_key (it contains
# '_fgas>0.01'), unlike the 1e7 / 1e9 pairs above; presumably it was meant to be the
# gas-independent count -- confirm against the dataset names in the file.
Nsats_mstar1e10_dr200c_key = 'NSatellites_Mstar>1.0e10_fgas>0.01_dsathost<R200c'
Nsats_mstar1e10_fgas_dr200c_key = 'NSatellites_Mstar>1.0e10_fgas>0.01_dsathost<R200c'
Nsats_mstar_1e7_SF_dr200c_key = 'NSatellites_Mstar>1.0e7_SF_dsathost<R200c'

Nsats_keys = [Nsats_total_key, Nsats_dr200_key, 
              Nsats_mstar1e7_dr200c_key, Nsats_mstar1e7_fgas_dr200c_key,
              Nsats_mstar1e9_dr200c_key, Nsats_mstar1e9_fgas_dr200c_key, 
              Nsats_mstar1e10_dr200c_key, Nsats_mstar1e10_fgas_dr200c_key,
              Nsats_mstar_1e7_SF_dr200c_key]

# sentinel "snapshot numbers" recognised by create_taudict: instead of a fixed snapshot,
# use each group's 'quenching_snap' (-99) or its BH_RM_FirstSnap_key snapshot (-100)
quench_snap_flag = -99
bh_rm_firstsnap_flag = -100

def _taudict_key_suffix(snap):
    """ Return the tau_dict key suffix for snap, honouring the quench / BH sentinel flags. """
    if snap == quench_snap_flag:
        return '_snapNumQuench'
    if snap == bh_rm_firstsnap_flag:
        return '_snapNumBHRMFirstSnap'
    return '_snapNum%03d'%snap


def _taudict_snap_index(group, snap):
    """ Return the index of snap in group's per-snapshot arrays, or None if there is no match. """
    if snap == quench_snap_flag:
        snap_mask = group['quenching_snap'] == group['SnapNum']
    elif snap == bh_rm_firstsnap_flag:
        snap_mask = group[BH_RM_FirstSnap_key] == group['SnapNum']
    else:
        snap_mask = snap == group['SnapNum']
    matches = np.flatnonzero(snap_mask)
    if matches.size == 0:
        return None
    return matches[0]


def create_taudict(grp_dict, snaps, branches_flag=False):
    """ 
    Given the grp_dict and snaps of interest, rearrange the grp_dict into
    1D arrays (one entry per group) of each dataset at each snap of interest.

    Parameters
    ----------
    grp_dict : dict {group_key: {dset_key: array}}. Per-snapshot arrays are matched
        against group['SnapNum']; index 0 is taken to be z=0.
    snaps : snapNum or list/array of snapNums. quench_snap_flag (-99) selects each
        group's 'quenching_snap'; bh_rm_firstsnap_flag (-100) selects each group's
        BH_RM_FirstSnap_key snapshot.
    branches_flag : if True, also extract the satellite-count datasets (Nsats_keys).

    Returns
    -------
    tau_dict with keys '<dset>_snapNum%03d', '<dset>_snapNumQuench' or
    '<dset>_snapNumBHRMFirstSnap', plus 'SubfindID' (each group's z=0 SubfindID) and
    'HostSubhaloGrNr'. An entry stays -1 when its group has no matching snapshot.
    """
    # input validation
    if not isinstance(snaps, (list, np.ndarray)):
        snaps = [snaps]

    tau_keys = grp_keys.copy()
    if branches_flag:
        tau_keys.extend(Nsats_keys)
    # radial datasets ('radii', '*Shells') are not extracted
    tau_keys = [tau_key for tau_key in tau_keys
                if 'radii' not in tau_key and 'Shells' not in tau_key]

    tauresult = {}
    if not grp_dict:
        return tauresult

    Ngroups = len(grp_dict)
    # allocate every output with the dtype of the first group; -1 flags "not found"
    group0 = next(iter(grp_dict.values()))
    tauresult['SubfindID'] = np.zeros(Ngroups, dtype=int)
    # NOTE: allocated for backwards compatibility but not filled here
    tauresult['HostSubhaloGrNr'] = np.zeros(Ngroups, dtype=int)
    for tau_key in tau_keys:
        for snap in snaps:
            tauresult[tau_key + _taudict_key_suffix(snap)] = np.zeros(Ngroups,
                                                                      dtype=group0[tau_key].dtype) - 1

    for group_index, group in enumerate(grp_dict.values()):
        # index 0 of each per-snapshot array is z=0
        tauresult['SubfindID'][group_index] = group['SubfindID'][0]
        for snap in snaps:
            snap_index = _taudict_snap_index(group, snap)
            if snap_index is None:
                # snapshot not present for this group: keep the -1 sentinel
                continue
            suffix = _taudict_key_suffix(snap)
            for tau_key in tau_keys:
                # integer indexing yields a scalar, avoiding NumPy's deprecated
                # assignment of a size-1 array into a scalar slot
                tauresult[tau_key + suffix][group_index] = group[tau_key][snap_index]

    return tauresult
    

In [None]:
# tau_dict at snapshots 99, 67, 50, 33 (z = 0, 0.5, 1, 2 per the redshift tick tables)
# and at each halo's BH_RM_FirstSnap_key snapshot, including the satellite-count datasets
TNGCluster_tau_dict = create_taudict(TNGCluster_grp_dict, [99, 67, 50, 33, bh_rm_firstsnap_flag], branches_flag=True)


### Figure 1: evolution of the cool ICM gas at the population level at fixed snapshots

In [None]:
# discrete redshift colormap for Figure 1: black -> purple -> orange
cmap = mpl.colors.LinearSegmentedColormap.from_list('Redshift_custom', ['k', 'tab:purple', 'tab:orange'])
# bin edges chosen so that each entry of `redshifts` falls in its own colour bin:
# 0 in [-0.25, 0.25), 0.5 in [0.25, 0.75), 2 in [0.75, 3.25)
bounds = np.array([-0.25, 0.25, 0.75, 3.25])
cmap_norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
redshifts = [0.0, 0.5, 2.0]
colors = ['k', 'tab:purple', 'tab:orange']

def plot_stacked_temp_dict_evolution(ax, grp_dict, grp_dict_keys,
                                     redshifts=redshifts, colors=colors, cmap=cmap, norm=cmap_norm,
                                     return_color_dset='Redshift', color_dset_log=False):
    """
    Plot the stacked ICM temperature PDFs of grp_dict_keys at each of redshifts on ax,
    add a redshift colorbar, and set the axis labels, limits and title.
    Returns ax.
    """
    ax, lc = add_stacked_temp_dict(ax, grp_dict, grp_dict_keys, redshifts=redshifts,
                                   colors=colors, cmap=cmap, norm=norm, return_color_dset=return_color_dset, color_dset_log=color_dset_log)
    ax.set_yscale('log')
    # use the figure that owns ax instead of relying on a global `fig`
    cbar = ax.figure.colorbar(lc, ax=ax, ticks=redshifts)
    # %g gives '0', '0.5', '2' for the default redshifts and follows any other choice
    cbar.ax.set_yticklabels(['%g'%redshift for redshift in redshifts])
    cbar.ax.minorticks_off()
    cbar.solids.set(alpha=1.0)
    cbar.set_label(r'Redshift', fontsize='small')

    ax.set_ylabel(r'PDF', fontsize='small')
    ax.set_xlabel(r'ICM Gas Temperature [log K]', fontsize='small')
    ax.set_title(r'TNG-Cluster $M_{\rm 200c}^{z=0}\sim10^{15}\, {\rm M_\odot}$ MPBs (%d)'%(len(grp_dict_keys)), fontsize='medium')

    ax.set_ylim(3.0e-6, 5)
    ax.set_xlim(2.75, 9.0)
    
    return ax


def add_stacked_temp_dict(ax, grp_dict, grp_dict_keys, redshifts=redshifts,
                          bincents_key='CGMTemperaturesHistogramBincents', dset_key='CGMTemperaturesHistogram',
                          colors=colors, cmap=cmap, norm=cmap_norm,
                          return_color_dset='Redshift', color_dset_log=False):
    """
    For each redshift in redshifts, draw on ax the median stacked histogram of
    grp_dict_keys (thick line in the matching entry of colors) and every individual
    histogram (thin lines coloured by return_color_dset via ru.multiline).
    Redshifts are drawn from last to first, so redshifts[0] is drawn last.
    If norm is falsy, a per-redshift Normalize (color_dset_log=True) or LogNorm
    spanning the 5th-95th percentile of the colour values is used instead.
    Returns ax and the LineCollection of the individual histograms at redshifts[0].
    """
    result = {}
    for redshift_i, redshift in enumerate(redshifts):
        stacked_dict, bincents, hists, color = return_stacked_temp_dict(grp_dict, grp_dict_keys, return_all_profiles=True,
                                                                        dset_key=dset_key, bincents_key=bincents_key,
                                                                        redshift=redshift, return_color_dset=return_color_dset,
                                                                        color_dset_log=color_dset_log)
        result[redshift] = dict(
            stacked_dict=stacked_dict, bincents=bincents, hists=hists, color=color,
            # %g keeps non-integer redshifts (z=0.5 was previously truncated to z=0)
            stacked_dict_kwargs=dict(path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()],
                                     marker='None', ls='-', lw=2, c=colors[redshift_i], label=r'$z=%g$'%(redshift)),
            norm_kwargs=dict(lw=0.1, alpha=0.5, ls='-', cmap=cmap, norm=norm))

    for redshift in redshifts[::-1]:
        _result = result[redshift]
        time_index = np.argmin(np.abs(grp_dict[grp_dict_keys[0]]['Redshift'] - redshift))

        # median stacked histogram
        y = _result['stacked_dict']['50']
        x = _result['stacked_dict']['bincents']
        mask = (y > 0) if 'Mass' in dset_key else (x > 0)
        ax.plot(x[mask], y[mask], **_result['stacked_dict_kwargs'])

        # individual histograms of the groups that exist at this redshift
        xs, ys, cs = [], [], []
        for y_i, hist in enumerate(_result['hists']):
            if grp_dict[grp_dict_keys[y_i]]['SubfindID'][time_index] < 0:
                continue
            bincents_i = _result['bincents'][y_i]
            # mask on this group's own bins (not on the previously processed profile)
            mask = (hist > 0) if 'Mass' in dset_key else (bincents_i > 0)
            xs.append(bincents_i[mask])
            ys.append(hist[mask])
            cs.append(_result['color'][y_i])

        if not norm:
            vmin = np.percentile(cs, 5)
            vmax = np.percentile(cs, 95)
            # colour values already in log10 -> linear norm; otherwise a log norm
            if color_dset_log:
                _result['norm_kwargs']['norm'] = mpl.colors.Normalize(vmin, vmax)
            else:
                _result['norm_kwargs']['norm'] = mpl.colors.LogNorm(vmin, vmax)
        
        lc = ru.multiline(xs, ys, cs, ax=ax, **_result['norm_kwargs'])

    return ax, lc


def return_stacked_temp_dict(grp_dict, grp_dict_keys, dset_key='CGMTemperaturesHistogram', bincents_key='CGMTemperaturesHistogramBincents',
                             redshift=0., return_all_profiles=False, return_color_dset=False, color_dset_log=False):
    """
    Stack the per-group histogram dset_key (binned on bincents_key) of grp_dict_keys
    at the snapshot whose 'Redshift' is closest to redshift.

    Groups whose SubfindID is negative at that snapshot, or whose histogram sums
    to <= 0, are excluded (their rows are masked). Histograms with 'Temp' in
    dset_key are cleaned with clean_temp_hist; with 'Mass' in dset_key they are
    replaced by compute_massext_profile(..., norm='r200c'); otherwise they are
    stacked as stored.

    Returns
    -------
    result_dict with the median ('50'), 16th/84th percentiles, 'Ngal'
    (number of keys) and median 'bincents'; optionally also the masked arrays of
    all bincents and histograms (return_all_profiles) and the per-group values of
    return_color_dset (log10 if color_dset_log), -1 for excluded groups.

    Raises
    ------
    KeyError if return_color_dset is not a dataset of the groups.
    """
    result_dict = {}
    group0 = grp_dict[grp_dict_keys[0]]

    # initalize the outputs; rows left at -1 are masked out of the statistics
    _bincents = np.zeros((len(grp_dict_keys), len(group0[bincents_key][0])), dtype=float) - 1
    _hists = _bincents.copy()

    if return_color_dset:
        if return_color_dset not in group0.keys():
            raise KeyError('return_color_dset %s not recognized. Please choose from the following: %s'
                           %(return_color_dset, list(group0.keys())))
        color_dset = np.zeros(len(grp_dict_keys), dtype=group0[return_color_dset].dtype) - 1

    for index, grp_dict_key in enumerate(grp_dict_keys):
        group = grp_dict[grp_dict_key]
        time_index = np.argmin(abs(group['Redshift'] - redshift))
        if group['SubfindID'][time_index] < 0:
            continue
        if return_color_dset:
            color_value = group[return_color_dset][time_index]
            color_dset[index] = np.log10(color_value) if color_dset_log else color_value
        _bincents[index,:] = group[bincents_key][time_index]
        _hist = group[dset_key][time_index]
        if np.sum(_hist) <= 0:
            continue

        if 'Temp' in dset_key:
            _hists[index,:] = clean_temp_hist(_bincents[index,:], _hist, interp_zero=True, force_zero=True)
        elif 'Mass' in dset_key:
            result = compute_massext_profile(group, dset_key, time_index, norm='r200c')
            _bincents[index,:] = result[0]
            _hists[index,:] = result[1]
        else:
            # generic histogram: stack it as stored (previously left as -1)
            _hists[index,:] = _hist

    # finish loop of indices, save final results
    bincents = np.ma.masked_values(_bincents, -1)
    hists = np.ma.masked_values(_hists, -1)

    # np.median / np.percentile ignore a MaskedArray's mask (the -1 rows would enter
    # the statistics), so use np.ma.median and np.nanpercentile on NaN-filled data
    hists_filled = np.ma.filled(hists, np.nan)
    result_dict['50'] = np.ma.filled(np.ma.median(hists, axis=0), np.nan)
    result_dict['16'] = np.nanpercentile(hists_filled, 16, axis=0)
    result_dict['84'] = np.nanpercentile(hists_filled, 84, axis=0)
    result_dict['Ngal'] = len(hists)
    result_dict['bincents'] = np.ma.filled(np.ma.median(bincents, axis=0), np.nan)
    
    if return_all_profiles:
        if return_color_dset:
            return result_dict, bincents, hists, color_dset
        else:
            return result_dict, bincents, hists
    else:
        if return_color_dset:
            return result_dict, color_dset
        else:
            return result_dict

def clean_temp_hist(bincents, hist, interp_zero=False, force_zero=False, rewrite_sfgas=False, normalize=True):
    """
    Clean a temperature histogram (bincents in log10[T/K]) of artifacts. Steps, each optional:
    (1) interp_zero: linearly interpolate empty (<= 0) bins at log T >~ 4 from the
        non-empty bins at log T >~ 4; empty bins outside that sampled range stay 0
    (2) force_zero: set bins strictly between log T = 3 and 4 to 0
    (3) rewrite_sfgas: move the content of the bin closest to log T = 3 (SF gas)
        into the bin closest to log T = 4
    (4) normalize: divide by sum(hist * binwidth) to create a PDF; an all-zero
        histogram is returned unnormalized instead of producing NaNs
    Assumes uniformly spaced bincents with at least two bins.
    Returns a new float array; the inputs are not modified.
    """
    bincents = np.asarray(bincents)
    # work in floats so interpolated values are not truncated for integer histograms
    _result = np.array(hist, dtype=float)
    binwidth = bincents[1] - bincents[0]
    # fraction of a binwidth used to make the bin-edge cuts robust to round-off
    tolerance = 1.0e-1
    if interp_zero:
        bincents_mask = bincents > (4. - binwidth * tolerance)
        zero_mask = _result <= 0
        valid_mask = bincents_mask & ~zero_mask
        if np.count_nonzero(valid_mask) >= 2:
            tempfunc = interp1d(bincents[valid_mask], _result[valid_mask], bounds_error=False, fill_value=0)
            _result[bincents_mask & zero_mask] = tempfunc(bincents[bincents_mask & zero_mask])

    if force_zero:
        bincents_mask = ((bincents > (3. + binwidth * tolerance)) &
                         (bincents < (4. - binwidth * tolerance)))
        _result[bincents_mask] = 0

    if rewrite_sfgas:
        sf_index = np.argmin(np.abs(bincents - 3.0))
        rewrite_index = np.argmin(np.abs(bincents - 4.0))
        if sf_index != rewrite_index:
            # add (not overwrite) so the gas already in the 10^4 K bin is kept
            _result[rewrite_index] += _result[sf_index]
            _result[sf_index] = 0.

    if normalize:
        integral = np.sum(_result * binwidth)
        if integral > 0:
            return _result / integral

    return _result


In [None]:
# Figure 1 input: clusters with 14.95 < log10(M200c) < 15.05 at snapshot 99 (z=0)
grp_dict = TNGCluster_grp_dict
tau_dict = TNGCluster_tau_dict
M200c_log = np.log10(tau_dict['HostGroup_M_Crit200_snapNum099'])
# product of two boolean arrays acts as an element-wise logical AND
mask = ((M200c_log > 14.95) * (M200c_log < 15.05))
SubfindIDs = tau_dict['SubfindID'][mask]
# grp_dict keys have the form '099_<8-digit z=0 SubfindID>'
grp_dict_keys = []
for subfindID in SubfindIDs:
    grp_dict_keys.append('099_%08d'%subfindID)


fig, ax = plt.subplots()
ax = plot_stacked_temp_dict_evolution(ax, grp_dict, grp_dict_keys)

### Figure 3: MPBs of all clusters

In [None]:
# Figure 3 configuration. These module-level values are bound as the keyword
# defaults of the Figure 3 functions defined below, so they must be set first.
grp_dict = TNGCluster_grp_dict
tau_dict = TNGCluster_tau_dict
# clusters are binned by z=0 M200c: bin centres and half-widths are in log10
maskdset_key = 'HostGroup_M_Crit200_snapNum099'
maskdset_bincents = [14.4, 14.7, 15.0, 15.3]
maskdset_binwidths = [0.1, 0.1, 0.1, 0.15]
# evolution of M200c vs cosmic time, each line coloured by its (log10) M200c at index 0
xdset_key = 'CosmicTime'
ydset_key = 'HostGroup_M_Crit200'
cdset_key = 'HostGroup_M_Crit200'
smooth_func = smoothSubhaloIndicesEvolution

# NOTE: this rebinds the module-level cmap / cmap_norm; the Figure 1 functions keep
# the earlier values, which were captured as their defaults at definition time
cmap = 'viridis_r'
cmap_norm = mpl.colors.Normalize(vmin=14.3, vmax=15.4)
# thin lines for individual clusters, thick white-outlined lines for the bin medians
all_profiles_kwargs = dict(cmap=cmap, norm=cmap_norm, linewidths=0.1)
stacked_profiles_kwargs = dict(cmap=cmap, norm=cmap_norm, linewidths=3, 
                               path_effects=[pe.Stroke(linewidth=4, foreground='white'), pe.Normal()])



def plot_stacked_dict_evolution(ax, plot_result,
                                all_profiles_kwargs=all_profiles_kwargs, stacked_profiles_kwargs=stacked_profiles_kwargs):
    """
    Draw every individual profile and the per-bin median profiles of plot_result on ax.

    plot_result['all_profiles'] supplies xs/ys/cs for the individual lines;
    plot_result['stacked_profiles'] maps each bin centre to a stacked dictionary whose
    'bincents' and '50' entries give the median line, coloured by the bin centre.
    Returns ax and the LineCollection of the median lines.
    """
    everything = plot_result['all_profiles']
    ru.multiline(everything['xs'], everything['ys'], everything['cs'],
                 ax=ax, **all_profiles_kwargs)

    stacked = plot_result['stacked_profiles']
    median_xs = [stacked[bin_centre]['bincents'] for bin_centre in stacked]
    median_ys = [stacked[bin_centre]['50'] for bin_centre in stacked]
    median_cs = [float(bin_centre) for bin_centre in stacked]

    median_lc = ru.multiline(median_xs, median_ys, median_cs, ax=ax, **stacked_profiles_kwargs)
    return ax, median_lc



def return_stacked_dict_evolution(grp_dict=grp_dict, tau_dict=tau_dict,
                                  maskdset_key=maskdset_key, maskdset_bincents=maskdset_bincents, maskdset_binwidths=maskdset_binwidths,
                                  xdset_key=xdset_key, ydset_key=ydset_key, cdset_key=cdset_key, 
                                  smooth_func=smooth_func):
    """ 
    Compute the evolution of ydset_key as a function of xdset_key for every group, and
    the median evolution of the groups in each bin of tau_dict[maskdset_key].

    Bins are centred on maskdset_bincents with half-widths maskdset_binwidths, both in
    log10 of the maskdset; groups are matched to grp_dict via '099_<SubfindID>' keys.
    Bins containing no groups are skipped.

    Returns plot_result = dict(all_profiles=dict(xs, ys, cs, Ngal),
    stacked_profiles={bincent: stacked dictionary}), ready for ru.multiline().
    """
    plot_result = dict(all_profiles=dict(), stacked_profiles=dict())

    result_dict, result, color_dset = compute_stacked_dict_evolution(grp_dict, list(grp_dict.keys()),
                                                                    smooth_func=smooth_func, ydset_key=ydset_key, xdset_key=xdset_key,
                                                                    return_all_profiles=True, return_color_dset=cdset_key, color_dset_log=True)
    ys = list(result)
    xs = [result_dict['bincents'].tolist()] * len(ys)
    cs = color_dset.tolist()

    plot_result['all_profiles'] = dict(xs=xs, ys=ys, cs=cs, Ngal=len(xs))

    SubfindIDz0 = tau_dict['SubfindID']
    maskdset = tau_dict[maskdset_key]

    # use the caller's bins (they were previously overwritten by hard-coded values)
    for bincent, maskdset_binwidth in zip(maskdset_bincents, maskdset_binwidths):
        maskdset_lolim = 10.**(bincent - maskdset_binwidth)
        maskdset_hilim = 10.**(bincent + maskdset_binwidth)

        indices = np.where((maskdset > maskdset_lolim) & (maskdset < maskdset_hilim))[0]
        if indices.size == 0:
            # an empty bin cannot be stacked
            continue

        grp_dict_keys = ['099_%08d'%SubfindIDz0[index] for index in indices]

        plot_result['stacked_profiles'][bincent] = compute_stacked_dict_evolution(grp_dict, grp_dict_keys, xdset_key=xdset_key,
                                                                                  smooth_func=smooth_func, ydset_key=ydset_key)
    
    return plot_result


def compute_stacked_dict_evolution(grp_dict, grp_dict_keys, xdset_key='CosmicTime', ydset_key=CGMColdGasMass_key,
                                   smooth_func=noSmoothEvolution, return_all_profiles=False, return_color_dset=False, color_dset_log=False):
    """
    Given the grp_dict and grp_dict_keys, stack the ydset_key for all of the
    keys at xdset_key, where both xdset and ydset are scalars per snapshot.
    Typically, xdset should be a time quantity, namely SnapNum, CosmicTime or
    Redshift, although any monotonically increasing (or decreasing) quantity
    is valid, such as MainBH_CumEgyInjection_RM or maybe even HostGroup_M_Crit200.

    Each group's (xdset, ydset) is produced by smooth_func(group, xdset_key, ydset_key);
    entries equal to -1 are masked and excluded from the statistics.

    Returns the stacked dictionary (median '50', 16th/84th percentiles, 'Ngal'
    = number of keys, and 'bincents' = the xdset of the last group), plus
    optionally the masked array of all profiles and the per-group index-0 value
    of return_color_dset (log10 if color_dset_log).

    Raises ValueError if return_color_dset is not a dataset of the groups.
    """
    group0 = grp_dict[grp_dict_keys[0]]
    xDSET = group0[xdset_key]
        
    result_dict = {}
 
    # initalize the outputs; entries left at -1 are masked below
    result = np.zeros((len(grp_dict_keys), len(xDSET)), dtype=group0[ydset_key].dtype) - 1.

    if return_color_dset:
        if return_color_dset not in group0.keys():
            raise ValueError('Error return_color_dset %s not available in %s'%(return_color_dset, list(group0.keys())))
        color_dset = np.zeros(len(grp_dict_keys), dtype=group0[return_color_dset].dtype) - 1

    xdset = xDSET
    for index, grp_dict_key in enumerate(grp_dict_keys):
        group = grp_dict[grp_dict_key]

        xdset, ydset = smooth_func(group, xdset_key, ydset_key)
        result[index,:] = ydset

        if return_color_dset:
            # assumes the value of interest (z=0) is at index 0
            color_value = group[return_color_dset][0]
            color_dset[index] = np.log10(color_value) if color_dset_log else color_value
        
    # finish loop of indices, save final results
    result = np.ma.masked_values(result, -1)

    # np.median / np.percentile ignore a MaskedArray's mask (the -1 sentinels would
    # enter the statistics), so use np.ma.median and np.nanpercentile instead
    result_filled = np.ma.filled(result.astype(float), np.nan)
    result_dict['50'] = np.ma.filled(np.ma.median(result, axis=0), np.nan)
    result_dict['16'] = np.nanpercentile(result_filled, 16, axis=0)
    result_dict['84'] = np.nanpercentile(result_filled, 84, axis=0)
    result_dict['Ngal'] = len(result)
    # all groups are assumed to share one x axis
    result_dict['bincents'] = xdset 
    
    if return_all_profiles:
        if return_color_dset:
            return result_dict, result, color_dset
        else:
            return result_dict, result
    else:
        if return_color_dset:
            return result_dict, color_dset
        else:
            return result_dict



In [None]:

# Figure 3: two vertically stacked single-column panels on the same time/redshift axes
fig, axs = plt.subplots(2, 1, figsize=(figsizewidth_column, figsizeheight_column * 2.75))

# top panel: M200c(t) vs t
ax = axs[0]

# individual clusters (thin lines) plus the median track of each z=0 M200c bin (thick lines)
plot_result = return_stacked_dict_evolution()
ax, lc = plot_stacked_dict_evolution(ax, plot_result)

ax.set_yscale('log')
ax = add_redshift_sincez7(ax)
ax.set_ylim(10.**(11), 10.**(15.7))

ax.set_xlabel(r'Cosmic Time [Gyr]')
ax.set_ylabel(r'Cluster Mass $[M_{\rm 200c}(t) / {\rm M_\odot}]$')

# horizontal colorbar inset in the lower-right corner, ticks at the mass-bin centres
cax = inset_axes(ax, width='50%', height='10%', loc='lower right')
cbar = plt.colorbar(lc, cax=cax, orientation='horizontal')
cbar.set_label(r'$\log_{10}[M_{\rm 200c}^{z=0} / {\rm M_\odot}]$', fontsize='small', color='black')
cbar.ax.tick_params(labelsize='small')
cbar.set_ticks(maskdset_bincents)
cax.xaxis.set_label_position('top')
cax.xaxis.set_ticks_position('top')
cbar.ax.minorticks_off()

# bottom panel: M_CoolGas^ICM(t) vs t
ax = axs[1]

# same mass binning, now for the cool ICM mass smoothed with a running median
plot_result = return_stacked_dict_evolution(ydset_key=CGMColdGasMass_key, smooth_func=smoothRunningMedianEvolution)
ax, _ = plot_stacked_dict_evolution(ax, plot_result)

ax.set_yscale('log')
ax = add_redshift_sincez7(ax)
ax.set_ylim(3e7, 7.0e11)

ax.set_xlabel(r'Cosmic Time [Gyr]')
ax.set_ylabel(r'Cool ICM Mass $[M_{\rm CoolGas}^{\rm ICM}(t) / {\rm M_\odot}]$')

# NOTE(review): the count (352) in the title is hard-coded; presumably it should track
# the number of groups in grp_dict -- confirm
axs[0].set_title('TNG-Cluster Main Progenitors \n' + r'$M_{\rm 200c}^{z=0} \sim 10^{14.3-15.4}\, {\rm M_\odot}$ (352)')

fname = '%s_M200ct_ICMCGMt_CosmicTime-Redshift_Evolution.pdf'%(sim)
# only written when the global savefig switch is on; one copy per output directory
if savefig:
    for outdirec in outdirecs:
        fig.savefig(outdirec + fname, bbox_inches='tight')