In [None]:
%load_ext autoreload
%autoreload 2
%aimport

In [None]:
%matplotlib inline

In [None]:
import warnings
from matplotlib import rc
# Render all figure text with LaTeX (requires a working LaTeX install); labels in
# this notebook therefore use LaTeX escapes such as '\\rm' and '\\#'.
rc('text', usetex=True)
# NOTE(review): this silences ALL warnings, e.g. numpy RuntimeWarnings from
# taking log of zero masses — consider narrowing the filter.
warnings.filterwarnings('ignore')

In [None]:
import astropy
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt 
import re 
from astropy.table import Table
import astropy.table
import json
from scipy import stats
from copy import deepcopy

from relaxed import halo_parameters, halo_catalogs 
import warnings

# Utility

In [None]:
def setup(name='m11', min_scale=0.15, output_root='../temp'):
    """Load a halo catalog together with its index -> scale-factor map.

    Reads ``{output_root}/output_{name}/final_table.csv`` into a Bolshoi
    ``HaloCatalog`` and ``{output_root}/output_{name}/z_map.json`` (a JSON object
    mapping index i -> scale factor).

    Parameters
    ----------
    name : str
        Suffix of the output directory (e.g. 'm11', 'm12'); also used as the
        catalog label.
    min_scale : float
        Only scales strictly greater than this are kept (earlier scales are
        treated as unstable). Default 0.15.
    output_root : str or Path
        Directory containing the ``output_{name}`` folders.

    Returns
    -------
    hcat : halo_catalogs.HaloCatalog
        Loaded catalog, restricted to rows with ``mvir_a18 > 0``.
    indices : np.ndarray
        Kept indices as strings (they are JSON object keys).
    scales : np.ndarray
        Kept scale factors, aligned with ``indices``.
    """
    output = Path(output_root, f'output_{name}')
    cat_file = output / 'final_table.csv'
    z_map_file = output / 'z_map.json'

    with open(z_map_file, 'r') as fp:
        scale_map = json.load(fp)  # map from i -> scale

    # Only keep stable scales. The map is assumed ordered by decreasing scale,
    # so this removes entries from the end.
    indices = np.array(list(scale_map.keys()))
    scales = np.array(list(scale_map.values()))
    keep = scales > min_scale
    indices = indices[keep]
    scales = scales[keep]

    # Load catalog.
    hcat = halo_catalogs.HaloCatalog('Bolshoi', cat_file, label=name)
    hcat.load_cat_csv()

    # Remove "weird IDs": rows whose mvir_a18 column is not positive.
    hcat.cat = hcat.cat[hcat.cat['mvir_a18'] > 0]

    return hcat, indices, scales
    
    

def get_m_a_corrs(cat, param, indices):
    """Spearman correlation between ``cat[param]`` and the mass fraction m(a) per index.

    For each ``k`` in ``indices`` the mass fraction is
    ``cat[f'mvir_a{k}'] / cat['mvir']``. Rows whose ``mvir_a{k}`` value is NaN
    or non-positive are skipped for that index only.

    Returns
    -------
    np.ndarray
        One Spearman coefficient per entry of ``indices``.
    """
    def _corr_at(k):
        column = cat[f'mvir_a{int(k)}']
        valid = (~np.isnan(column)) & (column > 0)

        host_mass = cat['mvir'][valid]
        mass_fraction = column[valid] / host_mass
        param_values = cat[param][valid]

        # Sanity checks on the filtered inputs.
        assert np.all(mass_fraction > 0) and np.all(~np.isnan(mass_fraction))
        assert np.all(host_mass > 0)
        return stats.spearmanr(mass_fraction, param_values)[0]

    return np.array([_corr_at(k) for k in indices])


def add_box_indices(cat, boxes=8, box_size=250):
    """Add an ``ibox`` column labelling which sub-box of the volume each row is in.

    The cube of side ``box_size`` (same units as ``cat['x']``, ``cat['y']``,
    ``cat['z']``) is split into ``boxes`` equal sub-cubes, ``n = boxes ** (1/3)``
    per dimension. A row's label is ``ix + n * iy + n**2 * iz``, where ``ix`` is
    the number of interior division planes strictly below its x coordinate
    (likewise y, z). Labels therefore run from 0 to ``boxes - 1``.

    Parameters
    ----------
    cat : table-like
        Must support ``cat[dim]`` for 'x', 'y', 'z', ``len(cat)`` and
        ``cat.add_column(values, name=...)`` (e.g. an astropy Table).
        Modified in place.
    boxes : int
        Total number of sub-boxes; must be a perfect cube.
    box_size : float
        Side length of the full volume.
    """
    # Round the cube root: e.g. 27 ** (1/3) == 3.0000000000000004, so an exact
    # float comparison would reject valid perfect cubes.
    box_per_dim = int(round(boxes ** (1. / 3)))
    assert box_per_dim ** 3 == boxes, f'boxes={boxes} must be a perfect cube'

    divides = np.linspace(0, box_size, box_per_dim + 1)[1:-1]  # interior planes only
    cat.add_column(np.zeros(len(cat)), name='ibox')
    for k, dim in enumerate(['x', 'y', 'z']):
        # Base box_per_dim (not 2) keeps labels unique when box_per_dim > 2;
        # for the default 8 boxes both bases coincide.
        for d in divides:
            cat['ibox'] += box_per_dim ** k * (d < cat[dim])

def vol_jacknife_values(f, cat, param, *args):
    """Leave-one-box-out (volume jackknife) evaluations of ``f``.

    Requires the ``ibox`` column created by ``add_box_indices``. For every box
    label ``b`` in ``0 .. max(ibox)``, ``f(sub_cat, param, *args)`` is evaluated
    on the catalog with the rows of box ``b`` removed.

    Returns
    -------
    np.ndarray
        The stacked results, one entry (row) per removed box.
    """
    n_boxes = int(np.max(cat['ibox']) + 1)
    return np.array([
        f(cat[cat['ibox'] != removed], param, *args)
        for removed in range(n_boxes)
    ])

# Plot Wang 2020 Correlation plot (Bolshoi)

## prepare dynamical time

In [None]:
# Load the m11 catalog (stable scales only) for the quick checks below.
hcat, indices, scales = setup('m11')

In [None]:
# Distribution of log10 virial mass in the loaded catalog.
log_mvir = np.log10(hcat.cat['mvir'])
plt.hist(log_mvir)

In [None]:
# Mean dynamical time in Gyr (the tdyn column is divided by 1e9, i.e. it is
# treated as being stored in years).
np.mean(hcat.cat['tdyn']/10**9)

In [None]:
# Distribution of halo dynamical times, in Gyr.
fig, ax = plt.subplots(figsize=(6, 6))
ax.hist(hcat.cat['tdyn'] / 10**9)
ax.tick_params(labelsize=14)

## m(a) with jackknife errors

In [None]:
# Cosmology of the Bolshoi simulation, used to turn scale factors into times.
from astropy.cosmology import LambdaCDM

sim = halo_catalogs.sims['Bolshoi']
cosmo = LambdaCDM(
    H0=sim.h * 100,
    Om0=sim.omega_m,
    Ode0=sim.omega_lambda,
    Ob0=sim.omega_b,
)
print(cosmo.age(0).value)  # present-day age of the universe


def get_fractional_tdyn(scale, tdyn):
    """Lookback time to scale factor ``scale`` in units of the dynamical time ``tdyn``.

    ``tdyn`` must be in the same unit as ``cosmo.age(...).value`` (Gyr, astropy's
    default). ``scale`` may be a scalar or a numpy array.
    """
    redshift = 1 / scale - 1
    lookback = cosmo.age(0).value - cosmo.age(redshift).value
    return lookback / tdyn

In [None]:
# Catalogs to compare and the halo properties correlated against the mass
# history, with matching LaTeX labels and line colors. markers[0] is used for
# positive correlations and markers[1] for negative ones.
names = ['m11', 'm12']
params = ['cvir', 'x0', 'v0', 't/|u|', 'q', 'phi_l']
latex_params = [
    'c_{\\rm vir}',
    'x_{\\rm off}',
    'v_{\\rm off}',
    't/|u|',
    'q',
    '\\Phi_{L}',
]
colors = ['r', 'b', 'g', 'm', 'k', 'y']
markers = np.array(['.', 'x'])

In [None]:
# Volume-jackknife errors on the m(a) correlations: errs[name][param] holds one
# error per kept scale, for each catalog and each parameter.
errs = {name: {} for name in names}
for name in names:
    hcat, indices, scales = setup(name)
    add_box_indices(hcat.cat)
    for param in params:
        values = vol_jacknife_values(get_m_a_corrs, hcat.cat, param, indices)
        # Jackknife variance: (n - 1) / n * sum_i (x_i - mean)^2 = (n - 1) * var,
        # with n the number of leave-one-out samples (8 for the default box split).
        n_jack = len(values)
        errs[name][param] = np.sqrt(values.var(axis=0) * (n_jack - 1))

## Full plot m_a correlations

In [None]:
# |Spearman rho| between each halo property and m(a) versus scale factor a, one
# panel per catalog (dots: positive correlation, crosses: negative). Shaded bands
# are the volume-jackknife errors from `errs`; dashed lines mark the scale of
# maximum |rho| for each property.
fig, axes = plt.subplots(1, len(names), figsize=(len(names) * 7, 7))
axes = axes.flatten() if len(names) > 1 else [axes]
for i, name in enumerate(names):
    hcat, indices, scales = setup(name)
    ax = axes[i]
    max_scales = [0.] * len(params)
    tdyn = np.mean(hcat.cat['tdyn']) / 10**9  # Gyr, which astropy also returns by default

    for j, param in enumerate(params):
        latex_param = latex_params[j]
        color = colors[j]
        corrs = get_m_a_corrs(hcat.cat, param, indices)
        pos = corrs > 0
        neg = ~pos
        corrs = abs(corrs)

        # Positive and negative correlations get different markers. The legend
        # entry goes on the majority sign; ties go to the positive branch so that
        # every parameter appears in the legend exactly once.
        if sum(pos) > 0:
            label = f'${latex_param}$' if sum(pos) >= sum(neg) else None
            ax.plot(scales[pos], corrs[pos], color=color, marker=markers[0], label=label, markersize=7)
        if sum(neg) > 0:
            label = f'${latex_param}$' if sum(pos) < sum(neg) else None
            ax.plot(scales[neg], corrs[neg], color=color, marker=markers[1], label=label, markersize=7)

        max_scales[j] = scales[np.nanargmax(corrs)]

        # Work on a copy so the NaN -> 0 fill does not mutate the shared `errs` dict.
        err = np.array(errs[name][param], dtype=float)
        err[np.isnan(err)] = 0.
        ax.fill_between(scales, corrs - err, corrs + err, alpha=0.2, linewidth=0.001, color=color)

    # Vertical line at the scale of maximum |rho| for each property.
    for j, s in enumerate(max_scales):
        ax.axvline(s, linestyle='--', color=colors[j])

    ax.set_ylim(0, 1.0)
    # setup() keeps only scales > 0.15. Fix the final x-range BEFORE reading the
    # ticks; otherwise the top axis gets ticks computed for [0, 1] that no longer
    # line up with the bottom axis once the range changes.
    ax.set_xlim(0.15, 1.0)
    ax.set_title(name, size=22)
    ax.set_ylabel("$\\rho(\\cdot, m(a))$", size=22)
    ax.set_xlabel("$a$", size=22)
    ax.tick_params(axis='both', which='major', labelsize=16)

    # Secondary x-axis: lookback time to each tick's scale, in units of mean tdyn.
    xmin, xmax = ax.get_xlim()
    ticks = np.array([t for t in ax.get_xticks() if xmin - 1e-9 <= t <= xmax + 1e-9])
    ax.set_xticks(ticks)  # pin the bottom ticks so both axes stay aligned
    ax2 = ax.twiny()
    ax2.set_xlim(xmin, xmax)
    ax2.set_xticks(ticks)
    fractional_tdyn = get_fractional_tdyn(ticks, tdyn)
    ax2.set_xticklabels([f'{x:.2g}' for x in fractional_tdyn], size=16)
    ax2.set_xlabel("$\\tau_{\\rm dyn}$", size=22)

    ax.legend(loc='best', prop={'size': 14})
plt.tight_layout()
plt.show()

# Exponential fit to log m(a)

## histogram of goodness of fit (alpha + beta)

In [None]:
from relaxed.progenitors.catalog import get_ma, get_alpha, lma_fit

In [None]:
# Load m11 and build log m(z) per halo on a common redshift grid for the fits.
hcat, indices, scales = setup('m11')
zs = (1 / scales) - 1
ma = get_ma(hcat.cat, indices)
# Keep only the first N_FIT_SCALES columns: towards the end of the scale list
# (high z / small a, since scales are stored in decreasing order) many masses
# become really close to 0.
N_FIT_SCALES = 160
lma = np.log(ma)[:, :N_FIT_SCALES]

# Drop halos with any non-finite log m (m <= 0 or missing) in the fit range.
keep = np.isfinite(lma).all(axis=1)
hcat.cat = hcat.cat[keep]
lma = lma[keep]
zs = zs[:lma.shape[1]]
print(len(hcat.cat))
print(lma.shape)
print(zs.shape)

In [None]:
# Fit alpha to every halo's log m(z) history and store it as the 'alpha' column.
alphas = [get_alpha(zs, lma[row]) for row in range(len(hcat.cat))]
hcat.cat.add_column(astropy.table.Column(alphas, name='alpha'))

In [None]:
# Evaluate the fitted log m(z) for every halo on the zs grid; later cells index
# fits[halo] against zs and lma. zs is passed as a (1, n_z) row — presumably
# broadcast against the per-halo alphas (TODO confirm lma_fit's shape contract).
fits = lma_fit(zs.reshape(1, -1), hcat.cat['alpha'])
fits.shape

In [None]:
# Visual check of the alpha fits: log m(z) and its fit for 20 random halos.
N_PANELS = 20
rng = np.random.default_rng(0)  # fixed seed so the same halos are shown on re-run
# Sample without replacement so no halo appears in two panels.
rows_indices = rng.choice(len(hcat.cat), size=min(N_PANELS, len(hcat.cat)), replace=False)
fig, axes = plt.subplots(4, 5, figsize=(15, 12))
axes = axes.flatten()
for ax, indx in zip(axes, rows_indices):
    ax.plot(zs, lma[indx])
    ax.plot(zs, fits[indx])
    ax.set_xlabel('$z$', size=20)
    ax.set_ylabel('$\\log m(z)$', size=20)
    ax.xaxis.set_ticks(np.arange(0, 4, 0.5))
plt.tight_layout()
fig.savefig('../figures/lma_fit_alpha.pdf')

In [None]:
# Histogram of the goodness of fit. NOTE: despite the name, 'chi2' is the RMS
# residual between log m(z) and the fit per halo (no per-point uncertainties).
chi2 = np.sqrt(np.mean((lma - fits)**2, 1))
print(chi2.shape)
plt.hist(chi2, bins=50, range=(0, 0.5), histtype='step')
hcat.cat.add_column(astropy.table.Column(chi2, name='chi2'))

In [None]:
# Add a_{1/2} as the 'a2' column, computed by A2.from_cat from the catalog and
# the kept scales/indices (see relaxed.halo_parameters.A2 for the definition).
from relaxed.halo_parameters import A2
a2 = A2.from_cat(hcat.cat, scales, indices)
c = astropy.table.Column(a2, name='a2')
hcat.cat.add_column(c)

In [None]:
# Distribution of the 'a2' (a_{1/2}) column.
plt.hist(hcat.cat['a2'])

In [None]:
from relaxed import plots, plot_funcs
from relaxed.halo_parameters import get_hparam
from collections import OrderedDict

# Correlation-matrix plot of halo properties for the loaded catalog.
# Uses dedicated names (matrix_params, matrix_names) so that the `params` and
# `names` lists used by the m(a) correlation cells are not overwritten.
log_params = ["mvir", "eta", "x0", "v0", "q", "cvir", "phi_l"]
linear_params = ["alpha", "f_sub", "a2", "gamma_tdyn"]  # "chi2" can be added here
matrix_params = log_params + linear_params
hparams = [get_hparam(param, log=True) for param in log_params]
hparams += [get_hparam(param, log=False) for param in linear_params]

plot_func = plot_funcs.MatrixValues(xlabel_size=24, ylabel_size=24)
plot = plots.MatrixPlot(
    plot_func, hparams, nrows=1, ncols=1, figsize=(10, 10), figpath='../figures/'
)

# load catalogs
plot.load(hcat)

matrix_names = ['Bolshoi']
plot_params = OrderedDict((param, set(matrix_names)) for param in matrix_params)
plot.generate(plot_params)

fname = '../figures/matrix3.pdf'
plot.save(fname=fname)

# get a_m

## Preliminaries

In [None]:
from relaxed.progenitors.catalog import get_ma, get_alpha, lma_fit

In [None]:
# Reload m11: the exponential-fit section filtered hcat.cat in place.
hcat, indices, scales = setup('m11')

In [None]:
# Histogram of the mass fractions m(a) over all halos, excluding the first
# column (the largest kept scale, given scales are stored in decreasing order).
ma = get_ma(hcat.cat, indices)
plt.hist(ma[:,1:].flatten(), bins=100, range=(0, 1), histtype='step');


In [None]:
# Count halos whose mass fraction exceeds 1 (beyond a small tolerance) at some
# point in their history.
count = int(np.sum(np.any(ma > 1.0001, axis=1)))
print(count)

In [None]:
# Scratch cell: displays np.arange output; the result is not used elsewhere.
np.arange(1, 100, 1)

In [None]:
# NOTE(review): `lma` is not defined in this section; it relies on the array
# built in the exponential-fit section still being in the kernel state.
# np.percentile takes q in [0, 100], so this is the 0.01th percentile, not the 1st.
np.percentile(lma, 0.01)

In [None]:
# Total number of (halo, scale) entries in ma.
len(ma.flatten())

In [None]:
# Length of `lma`. NOTE(review): relies on `lma` from an earlier section being
# in the kernel state, so the result depends on execution order.
len(lma)

In [None]:
# Distribution of log m over all halos and all kept scales (non-positive masses
# dropped). Stored as `lma_flat` rather than overwriting `lma`, which other cells
# use for the per-halo (2-D) log m(z) array.
ma_flat = ma.flatten()
lma_flat = np.log(ma_flat[ma_flat > 0])
plt.hist(lma_flat, bins=30, histtype='step')
plt.yscale('log')
plt.title('\\rm all redshifts', size=20)
plt.xlabel('$\\log m(z)$', size=18)

In [None]:
# Naive a(m) by binning: for each log-spaced mass bin, record the scale at which
# each halo's m(a) falls inside that bin (NaN if it never does).
_scales = np.array([*scales, np.nan])  # extra NaN entry so index -1 means "no match"
log_bins = np.linspace(-5, 0, 100)  # natural-log bin edges: m from e^-5 to 1
a_m = []
for e1, e2 in zip(log_bins, log_bins[1:]):  # consecutive bin edges
    m1, m2 = np.exp(e1), np.exp(e2)
    in_bin = (m1 <= ma) & (ma <= m2)

    # A wide bin can contain several scales of the same halo; argmax picks the
    # FIRST matching column. NOTE(review): scales are stored in decreasing order,
    # so this is the LARGEST matching a, while the stated intent was to pick the
    # smallest a — confirm which one is wanted.
    first_match = np.argmax(in_bin, axis=1)
    first_match = np.where(in_bin.any(axis=1), first_match, -1)
    a_m.append(_scales[first_match])

a_m = np.array(a_m).T  # (n_halos, n_bins)

In [None]:
# Distribution of the binned a(m) values over all halos and mass bins.
plt.hist(a_m.flatten(), bins=50)

In [None]:
# Number of halos with a defined a(m) in each mass bin.
n_haloes_per_bin = np.sum(~np.isnan(a_m), axis=0)
plt.plot(n_haloes_per_bin)
plt.title("\\# of haloes in each mass bin", size=18)
plt.xlabel("mass bin")

## Phil's procedure

1. Inversion is only a well-defined process for monotonic functions, and m(a) for an individual halo isn't necessarily monotonic. To solve this, the standard redefinition of a(m0) is that it's the first a where m(a) > m0. (This is, for example, how Rockstar defines halfmass scales.)

2. Next, pick the set of mass bins at which you'll evaluate a(m). I think logarithmic bins spanning 0.01*m(a=1) to 1*m(a=1) are pretty reasonable, but you should probably refine this choice based on which mass ranges turn out to be the most informative once you've run the analysis.

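3. Now, for each halo with masses m(a_i), indexed in order of increasing scale factor (so j <= i means a_j <= a_i, i.e. an earlier time), compute the running maximum $M(a_i) = \max_j \{ m(a_j) \mid j \le i \}$.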

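4. Remove (a_i, M(a_i)) pairs where M(a_i) = M(a_{i-1}), since repeated values of M make the inverse multivalued and break the interpolation. This keeps the earliest a of each plateau, consistent with the definition in step 1.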

5. Use scipy.interpolate.interp1d to create a function, f(m), which evaluates a(m). For stability, you'll want to run the interpolation on log(a_i) and log(M(a_i)), not a_i and M(a_i).

6. Evaluate f(m) at the mass bins you decided that you liked in step 2. Now you can run your pipeline on this, just like you did for m(a).



In [None]:
from relaxed.progenitors.catalog import get_ma, get_alpha, lma_fit

In [None]:
def _invert_mass_history(scales, m_row, log_mass_bins):
    """Invert one halo's m(a) into a(m), evaluated at ``exp(log_mass_bins)`` (NaN where undefined)."""
    from scipy.interpolate import interp1d

    # Order by increasing scale factor (forward in time) regardless of how the
    # z_map stored them, and drop entries whose log is undefined.
    order = np.argsort(scales)
    a_sorted = np.asarray(scales, dtype=float)[order]
    m_sorted = np.asarray(m_row, dtype=float)[order]
    valid = np.isfinite(m_sorted) & (m_sorted > 0)
    a_sorted, m_sorted = a_sorted[valid], m_sorted[valid]

    am_row = np.full(len(log_mass_bins), np.nan)
    if len(m_sorted) == 0:
        return am_row

    # 3. Running maximum over earlier times: M(a_i) = max{m(a_j) : a_j <= a_i}.
    running_max = np.maximum.accumulate(m_sorted)

    # 4. Keep only the earliest scale of each plateau of constant M, so log M is
    #    strictly increasing and the inverse is single-valued.
    new_value = np.concatenate(([True], running_max[1:] != running_max[:-1]))
    a_keep, m_keep = a_sorted[new_value], running_max[new_value]
    if len(m_keep) < 2:  # interp1d needs at least two points
        return am_row

    # 5. Interpolate log(a) as a function of log(M) for numerical stability.
    f = interp1d(np.log(m_keep), np.log(a_keep), bounds_error=False, fill_value=np.nan)

    # 6. Evaluate at the requested mass bins.
    return np.exp(f(log_mass_bins))


def get_am(name='m11', min_mass=0.01, max_mass=1.0, n_bins=100):
    """Invert each halo's mass accretion history m(a) into a(m) (Phil's procedure).

    Implements the steps listed in the "Phil's procedure" notes:
      2. ``n_bins`` log-spaced mass fractions in [``min_mass``, ``max_mass``].
      3. Replace m(a) by its running maximum over earlier times.
      4. Drop points where that maximum does not change.
      5. Interpolate log a against log M for each halo.
      6. Evaluate at the mass bins (NaN outside a halo's range).

    Returns
    -------
    am : np.ndarray
        One row per halo of the catalog loaded by ``setup(name)``, one column
        per mass bin.
    mass_bins : np.ndarray
        The mass fractions (not their logs) at which ``am`` is evaluated.
    """
    hcat, indices, scales = setup(name)

    # 2. Log-spaced mass-fraction bins.
    log_mass_bins = np.linspace(np.log(min_mass), np.log(max_mass), n_bins)

    # 3.-6. for every halo.
    ma = get_ma(hcat.cat, indices)
    am = np.array([_invert_mass_history(scales, m_row, log_mass_bins) for m_row in ma])

    return am, np.exp(log_mass_bins)

In [None]:
# am, mass_bins = get_am(name='m11')

# # compare interpolation at random values to make sure we did the correct thing
# idx = np.random.randint(len(am))
# fig, axes = plt.subplots(1, 2, figsize=(8, 4))
# (ax1, ax2) = axes.flatten()
# ax1.plot(scales, Ma[idx])
# ax2.plot(am[idx], np.exp(mass_bins))

## Now calculate the a(m) correlation plot for m11 and m12

In [None]:
# Catalogs and halo properties for the a(m) correlation plot, with matching LaTeX
# labels and line colors; markers[0] marks positive correlations, markers[1]
# negative ones.
names = ['m11', 'm12']
params = ['cvir', 'x0', 'v0', 't/|u|', 'q', 'phi_l']
latex_params = [
    'c_{\\rm vir}',
    'x_{\\rm off}',
    'v_{\\rm off}',
    't/|u|',
    'q',
    '\\Phi_{L}',
]
colors = ['r', 'b', 'g', 'm', 'k', 'y']
markers = np.array(['.', 'x'])

In [None]:
def get_am_corrs(cat, param, am):
    """Spearman correlation between ``cat[param]`` and a(m) for every mass bin.

    ``am`` has one row per halo (aligned with the rows of ``cat``) and one column
    per mass bin; halos whose a(m) is NaN in a bin are ignored for that bin.

    Returns
    -------
    np.ndarray
        One Spearman coefficient per column of ``am``.
    """
    values = cat[param]
    result = []
    for column in am.T:
        defined = ~np.isnan(column)
        pvalue = values[defined]
        am_k = column[defined]

        assert np.all(~np.isnan(pvalue)) and np.all(~np.isnan(am_k))
        result.append(stats.spearmanr(pvalue, am_k)[0])

    return np.array(result)


In [None]:
# |Spearman rho| between each halo property and a(m) versus mass fraction m, one
# panel per catalog (dots: positive correlation, crosses: negative). Dashed lines
# mark the mass bin of maximum |rho| for each property.
fig, axes = plt.subplots(1, len(names), figsize=(len(names) * 7, 7))
axes = axes.flatten() if len(names) > 1 else [axes]
for i, name in enumerate(names):
    # get_am re-loads the catalog via setup(name); assuming get_ma keeps catalog
    # row order, the rows of `am` align with hcat.cat.
    hcat, indices, scales = setup(name)
    am, mass_bins = get_am(name)
    ax = axes[i]
    max_mass_bins = [0.] * len(params)

    for j, param in enumerate(params):
        latex_param = latex_params[j]
        color = colors[j]
        corrs = get_am_corrs(hcat.cat, param, am)
        pos = corrs >= 0
        neg = ~pos
        corrs = abs(corrs)

        # Positive and negative correlations get different markers. The legend
        # entry goes on the majority sign; ties go to the positive branch so that
        # every parameter appears in the legend exactly once.
        if sum(pos) > 0:
            label = f'${latex_param}$' if sum(pos) >= sum(neg) else None
            ax.plot(mass_bins[pos], corrs[pos], color=color, marker=markers[0], label=label, markersize=7)
        if sum(neg) > 0:
            label = f'${latex_param}$' if sum(pos) < sum(neg) else None
            ax.plot(mass_bins[neg], corrs[neg], color=color, marker=markers[1], label=label, markersize=7)

        max_mass_bins[j] = mass_bins[np.nanargmax(corrs)]

    # Vertical line at the mass bin of maximum |rho| for each property.
    for j, mbin in enumerate(max_mass_bins):
        ax.axvline(mbin, linestyle='--', color=colors[j])

    ax.set_ylim(0, 1.0)
    ax.set_xlim(0.01, 1.0)
    # The mass bins are log-spaced; ax.set_xscale('log') is an alternative view.

    ax.set_title(name, size=22)
    ax.set_ylabel("$\\rho(\\cdot, a(m))$", size=22)
    ax.set_xlabel("$m$", size=22)
    ax.tick_params(axis='both', which='major', labelsize=16, size=10)
    ax.tick_params(axis='x', which='minor', size=8)
    ax.legend(loc='best', prop={'size': 14})

plt.tight_layout()
plt.show()

## Same minimum mass fraction (m11)? 

In [None]:
from relaxed.progenitors.catalog import get_ma, get_alpha, lma_fit

In [None]:
# Smallest positive mass fraction reached by each m11 halo over the kept scales.
hcat, indices, scales = setup('m11')
ma = get_ma(hcat.cat, indices)
# Non-positive / NaN entries are ignored. A halo with no positive entry gets NaN
# instead of crashing np.min on an empty array.
positive = np.where(ma > 0, ma, np.inf)
min_ma = positive.min(axis=1)
min_ma[np.isinf(min_ma)] = np.nan
log_min_ma = np.log10(min_ma)
gbg_values = log_min_ma < -4  # suspiciously small ("garbage") minima
print(min_ma.shape)
print(sum(log_min_ma < -3))
plt.hist(log_min_ma[np.isfinite(log_min_ma)], bins=50)

In [None]:
# Row positions (within hcat.cat) of halos with log10(min m) < -4.
np.where(gbg_values)

In [None]:
# Inspect the m(a) history of one suspicious halo across all kept scales.
# NOTE(review): row 100 is a hard-coded choice and will fail if fewer than 101
# halos satisfy gbg_values.
cat = hcat.cat[gbg_values]

[cat[f'mvir_a{index}'][100] for index in indices]

In [None]:
# Scale factor at position 170 of the kept scales (scratch check).
scales[170]

## check weird values

In [None]:
from relaxed.progenitors import progenitor_lines

In [None]:
# Generator of progenitor lines read from one Bolshoi progenitor file
# (hard-coded path relative to this notebook).
gen = progenitor_lines.get_prog_lines_generator('../temp/bolshoi_progenitors/mline_0_0_0.txt')

In [None]:
from IPython.display import clear_output

In [None]:
# Scan the progenitor lines for the tree with this root ID. Afterwards `g` is the
# matching line, or None if the generator is exhausted without a match (instead
# of silently being the last line read).
TARGET_ROOT_ID = 3058686024
for g in gen:
    # Presumably keeps only the latest output if the generator prints progress.
    clear_output(wait=True)
    if g.root_id == TARGET_ROOT_ID:
        break
else:
    g = None
    print(f'root_id {TARGET_ROOT_ID} not found')