In [None]:
%load_ext autoreload
%autoreload 2
%aimport

In [None]:

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt 
import re 
from astropy.table import Table
import astropy.table
import json
from scipy import stats
from copy import deepcopy

from relaxed import halo_parameters, halo_catalogs 
import warnings

## Utils

In [None]:
def get_a2(cat):
    """Return the half-mass scale factor a_{1/2} from a mass history.

    Among the entries whose ``mvir`` is strictly greater than half of
    ``cat["mvir"][0]``, pick the one with the smallest ``mvir`` and return its
    ``scale``.

    Parameters
    ----------
    cat : table / mapping with ``"mvir"`` and ``"scale"`` columns
        Presumably a single halo's main-progenitor history with the reference
        (present-day) entry first -- TODO confirm ordering against the caller.

    Returns
    -------
    The ``scale`` value of the selected entry.
    """
    mvir = np.asarray(cat["mvir"])
    # Entries at or below half of the reference mass are pushed to +inf so
    # argmin only considers masses above the half-mass threshold.
    above_half = np.where(mvir > mvir[0] * 0.5, mvir, np.inf)
    idx = np.argmin(above_half)

    return cat["scale"][idx]

# Read catalog

In [None]:
# Global display configuration: render figure text through LaTeX and
# silence all warnings for the rest of the notebook.
from matplotlib import rc

rc('text', usetex=True)
warnings.filterwarnings('ignore')

In [None]:
def setup(name='m11', min_scale=.15, base_dir='../temp'):
    """Load the processed catalog ``name`` and its snapshot-index -> scale map.

    Reads ``<base_dir>/output_<name>/z_map.json`` (a JSON object mapping
    snapshot index -> scale factor) and ``<base_dir>/output_<name>/final_table.csv``.

    Parameters
    ----------
    name : str
        Catalog label; selects the ``output_<name>`` directory and is passed as
        the catalog's ``label``.
    min_scale : float, default 0.15
        Only snapshots whose scale is strictly greater than this are kept.
    base_dir : str or Path, default '../temp'
        Directory that contains the ``output_<name>`` folders.

    Returns
    -------
    hcat : halo_catalogs.HaloCatalog
        Loaded catalog, keeping only rows with ``mvir_a18 > 0``.
    indices : np.ndarray
        Kept snapshot indices. These are the JSON object keys, i.e. strings;
        callers convert them with ``int``.
    scales : np.ndarray
        Scale factors matching ``indices``.
    """
    output = Path(base_dir, f'output_{name}')
    cat_file = output / 'final_table.csv'
    z_map_file = output / 'z_map.json'

    with open(z_map_file, 'r') as fp:
        scale_map = json.load(fp)  # map from i -> scale

    # Only keep stable scales. Per the original note, the scales being dropped
    # sit at the end because that is how the map is ordered.
    indices = np.array(list(scale_map.keys()))
    scales = np.array(list(scale_map.values()))
    keep = scales > min_scale
    indices = indices[keep]
    scales = scales[keep]

    # Load the catalog.
    hcat = halo_catalogs.HaloCatalog('Bolshoi', cat_file, label=name)
    hcat.load_cat_csv()

    # Remove "weird ID" rows: keep only halos with a positive mass in snapshot
    # column 18 (NaN also fails the comparison and is dropped).
    # NOTE(review): why snapshot 18 specifically is not visible here -- confirm.
    hcat.cat = hcat.cat[hcat.cat['mvir_a18'] > 0]

    return hcat, indices, scales
    
    

def get_m_a_corrs(cat, param, indices):
    """Spearman correlation between m(a) = M(a) / mvir and ``param``, per snapshot.

    For each snapshot index ``k`` in ``indices`` (anything ``int`` accepts), the
    column ``mvir_a{k}`` is read. Rows where it is NaN or non-positive are dropped
    before correlating the mass fraction with ``cat[param]``.

    Returns
    -------
    np.ndarray
        One Spearman rho per entry of ``indices``.

    Raises
    ------
    AssertionError
        If any retained mass fraction is non-positive/NaN, or any retained
        ``mvir`` or ``param`` value is non-positive.
    """
    def _corr_at(snap):
        mass_col = cat[f'mvir_a{int(snap)}']
        valid = (~np.isnan(mass_col)) & (mass_col > 0)

        host_mass = cat['mvir'][valid]
        frac = mass_col[valid] / host_mass
        prop = cat[param][valid]

        # Sanity checks on the retained rows before correlating.
        assert np.all(frac > 0) and np.all(~np.isnan(frac))
        assert np.all(host_mass > 0)
        assert np.all(prop > 0)
        return stats.spearmanr(frac, prop)[0]

    return np.array([_corr_at(snap) for snap in indices])


def add_box_indices(cat, boxes=8, box_size=250):
    """Add an ``'ibox'`` column labelling which sub-box of the volume each row is in.

    The cube ``[0, box_size]^3`` is split into ``boxes`` equal sub-cubes, so
    ``boxes`` must be a perfect cube (8, 27, 64, ...). With ``n = boxes**(1/3)``
    sub-boxes per dimension, the label is ``ix + n*iy + n**2*iz``. Here ``ix``
    counts the interior cut planes strictly below the row's ``x`` (likewise for
    y and z), so labels run over ``0 .. boxes-1``.

    Parameters
    ----------
    cat : table with ``x``, ``y``, ``z`` columns, supporting ``len`` and ``add_column``
    boxes : int, default 8
        Number of sub-boxes; must be a perfect cube.
    box_size : float, default 250
        Side length of the full volume, in the same units as x/y/z.

    Raises
    ------
    AssertionError
        If ``boxes`` is not a perfect cube.
    """
    box_per_dim = int(round(boxes ** (1. / 3)))
    # Check in integers: the float cube root is inexact (27 ** (1/3) != 3.0).
    assert box_per_dim ** 3 == boxes
    divides = np.linspace(0, box_size, box_per_dim + 1)[1:-1]  # interior cut planes only

    ibox = np.zeros(len(cat))
    for k, dim in enumerate(['x', 'y', 'z']):
        coords = np.asarray(cat[dim])
        # Position along this axis (0 .. box_per_dim-1), used as base-box_per_dim
        # digit k so that labels never collide for box_per_dim > 2.
        cell = sum((d < coords) for d in divides)
        ibox += box_per_dim ** k * cell
    cat.add_column(ibox, name='ibox')

def vol_jacknife_values(f, cat, param, *args):
    """Evaluate ``f`` on every leave-one-box-out subsample of ``cat``.

    ``cat`` must already carry an ``'ibox'`` column (e.g. from ``add_box_indices``).
    Boxes are enumerated as ``0 .. max(ibox)``. For each box ``b``, the rows with
    ``ibox == b`` are removed and ``f(subcat, param, *args)`` is evaluated on the rest.

    Returns
    -------
    np.ndarray
        The per-box results stacked along axis 0.
    """
    box_ids = cat['ibox']
    box_count = int(np.max(box_ids) + 1)
    return np.array([f(cat[box_ids != dropped], param, *args) for dropped in range(box_count)])

In [None]:
# Load the default m11 catalog together with its kept snapshot indices/scales.
hcat, indices, scales = setup(name='m11')

In [None]:
# Catalogs to analyse, and the halo properties to correlate with m(a).
# Each property is listed once with its LaTeX label and plot colour, so the
# three parallel lists below stay aligned by construction.
names = ['m11']
_param_specs = [
    ('cvir', 'c_{\\rm vir}', 'r'),
    ('x0', 'x_{\\rm off}', 'b'),
    ('v0', 'v_{\\rm off}', 'g'),
    ('t/|u|', 't/|u|', 'm'),
    ('q', 'q', 'k'),
    ('phi_l', '\\Phi_{L}', 'y'),
]
params = [spec[0] for spec in _param_specs]
latex_params = [spec[1] for spec in _param_specs]
colors = [spec[2] for spec in _param_specs]
# Marker for positive ('.') and negative ('x') correlations.
markers = np.array(['.', 'x'])

In [None]:
# Volume-jackknife errors on the m(a)-correlation curve, for each param of each catalog.
errs = {name: {} for name in names}
for name in names:
    hcat, indices, scales = setup(name)
    add_box_indices(hcat.cat)
    for param in params:
        values = vol_jacknife_values(get_m_a_corrs, hcat.cat, param, indices)
        # Jackknife variance: (N-1)/N * sum_i (x_i - mean)^2 == (N-1) * var(ddof=0),
        # with N the number of leave-one-box-out estimates (8 for the default
        # boxes; previously hard-coded as 7 = N-1).
        n_jk = len(values)
        errs[name][param] = np.sqrt(values.var(axis=0) * (n_jk - 1))

In [None]:
# One panel per catalog: |Spearman rho| between m(a) and each halo property vs. scale a.
# The marker encodes the sign of rho ('.' positive, 'x' negative), shaded bands are
# the jackknife errors in `errs`, and dashed lines mark the scale of maximal |rho|.
fig, axes = plt.subplots(1, len(names), figsize=(len(names) * 7, 7), squeeze=False)
for ax, name in zip(axes.ravel(), names):
    hcat, indices, scales = setup(name)
    max_scales = []
    for param, latex_param, color in zip(params, latex_params, colors):
        corrs = get_m_a_corrs(hcat.cat, param, indices)
        pos = corrs > 0
        neg = ~pos
        abs_corrs = np.abs(corrs)
        label = f'${latex_param}$'

        # Plot positive and negative correlations with different markers. The
        # legend entry goes on the majority sign; ties go to the positive branch
        # so that every parameter still gets exactly one legend entry.
        if pos.sum() > 0:
            ax.plot(scales[pos], abs_corrs[pos], color=color, marker=markers[0],
                    label=label if pos.sum() >= neg.sum() else None, markersize=7)
        if neg.sum() > 0:
            ax.plot(scales[neg], abs_corrs[neg], color=color, marker=markers[1],
                    label=label if neg.sum() > pos.sum() else None, markersize=7)

        max_scales.append(scales[np.nanargmax(abs_corrs)])

        # NaN jackknife errors are drawn with zero width; build a copy instead of
        # overwriting the arrays stored in `errs`.
        err = np.where(np.isnan(errs[name][param]), 0., errs[name][param])
        ax.fill_between(scales, abs_corrs - err, abs_corrs + err,
                        alpha=0.2, linewidth=0.001, color=color)

    # Vertical line at the scale of maximal |rho| for each parameter.
    for s, color in zip(max_scales, colors):
        ax.axvline(s, linestyle='--', color=color)

    ax.set_ylim(0, 1.0)
    ax.set_title(name, size=22)
    # The curves show |rho|, so label the axis accordingly.
    ax.set_ylabel("$|\\rho(\\cdot, m(a))|$", size=22)
    ax.set_xlabel("$a$", size=22)
    ax.tick_params(axis='both', which='major', labelsize=16)
    # One legend per panel (plt.legend would only label the last-created axes).
    ax.legend(loc='best', prop={'size': 14})

plt.tight_layout()
        
    

# Fit $M(z) = M_0 (1+z)^{\beta} e^{-\alpha z}$ to $m(a)$

In [None]:

def ma_fit(z, M0, alpha, beta):
    """Mass-history model M(z) = M0 * (1 + z)**beta * exp(-alpha * z).

    Evaluated elementwise, so ``z`` may be a scalar or a numpy array.
    """
    power_law = (1 + z) ** beta
    decay = np.exp(-alpha * z)
    return M0 * power_law * decay

def get_alpha(zs, m_a):
    """Fit M(z) = M(0) * (1 + z)**beta * exp(-alpha * z) to one mass history.

    ``M(0)`` is pinned to ``m_a[0]`` rather than fitted; only ``alpha`` and
    ``beta`` are free, starting from ``(1, 1)``. Non-finite points (NaN/inf in
    ``zs`` or ``m_a``) are excluded from the fit.

    Parameters
    ----------
    zs : array-like
        Redshifts, aligned element-by-element with ``m_a``.
    m_a : array-like
        Mass history at those redshifts; ``m_a[0]`` is the normalisation.

    Returns
    -------
    np.ndarray
        Best-fit ``[alpha, beta]``.

    Raises
    ------
    RuntimeError
        If ``scipy.optimize.curve_fit`` fails to converge.
    """
    from scipy.optimize import curve_fit

    zs = np.asarray(zs, dtype=float)
    m_a = np.asarray(m_a, dtype=float)
    M0 = m_a[0]  # NOTE(review): if m_a[0] itself is NaN the fit is meaningless

    def f(z, alpha, beta):
        return M0 * (1 + z)**beta * np.exp(-alpha * z)

    # m(a) rows built from the mvir_a{k} columns may contain NaN, and curve_fit
    # rejects non-finite data with a ValueError, so fit only the finite points.
    finite = np.isfinite(zs) & np.isfinite(m_a)
    opt_params, _ = curve_fit(
        f, zs[finite], m_a[finite], p0=(1, 1)
    )

    return opt_params  # = alpha, beta


def get_m_a(cat, indices):
    """Build the m(a) matrix: mass at each snapshot in ``indices`` over ``mvir``.

    Parameters
    ----------
    cat : table with ``mvir`` and ``mvir_a{k}`` columns
    indices : sequence of snapshot indices (anything ``int`` accepts, e.g. strings)

    Returns
    -------
    np.ndarray of shape ``(len(cat), len(indices))``
        Column ``j`` holds ``cat[f'mvir_a{indices[j]}'] / cat['mvir']``, so the
        columns line up with ``indices`` (and with the matching ``scales``).
    """
    m_a = np.zeros((len(cat), len(indices)))
    mvir = cat['mvir']
    for j, k in enumerate(indices):
        # Store at position j in `indices`, not at snapshot number k: the kept
        # snapshot numbers need not start at 0 or be contiguous.
        m_a[:, j] = cat[f'mvir_a{int(k)}'] / mvir

    return m_a

In [None]:
hcat, indices, scales = setup(name='m11')
zs = 1 / scales - 1  # redshift z = 1/a - 1 for each kept snapshot
m_a = get_m_a(hcat.cat, indices)

In [None]:
# Sanity check on one halo: its m(a) history vs. the best-fit
# M(z) = M(0) (1 + z)^beta exp(-alpha z), plotted against redshift.
idx = 20
alpha, beta = get_alpha(zs, m_a[idx])
fig, ax = plt.subplots()
ax.plot(zs, m_a[idx], label='data')
ax.plot(zs, ma_fit(zs, m_a[idx][0], alpha, beta), label='fit')
ax.set_xlabel('$z$')
ax.set_ylabel('$m(a)$')
ax.set_title(f'halo {idx}')
ax.legend();

In [None]:
# Fit every halo's m(a) history and store the best-fit alpha as a catalog column.
alphas = np.array([get_alpha(zs, m_a[row])[0] for row in range(len(hcat.cat))])
# Item assignment adds the column, or replaces it when this cell is re-run
# (add_column raises on a duplicate column name).
hcat.cat['alpha'] = alphas

In [None]:
# Show the halo properties whose correlation with alpha is computed next.
list(params)

In [None]:
# Spearman correlation of each halo property with the fitted alpha.
alpha_corrs = {
    param: stats.spearmanr(hcat.cat[param], hcat.cat['alpha'])[0]
    for param in params
}
for param, corr in alpha_corrs.items():
    print(f'{param}: {corr}')