# Preamble 

In [None]:
import astropy


In [None]:
%load_ext autoreload
%autoreload 2
%aimport
# %reload_ext autoreload

import sys 
from os.path import abspath
# Make the repository root (two levels up) and the local `minnow` package
# importable by putting them at the front of sys.path, once each.
paths = [abspath('../..'), "/home/imendoza/alcca/nbody-relaxed/packages/minnow"]

for path in paths:
    if path in sys.path:
        # Already importable; do not add a duplicate entry.
        continue
    sys.path.insert(0, path)

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt 
import re 
from astropy.table import Table
import astropy.table

from relaxed import utils 
from relaxed.frames import params, catalogs, filters
from relaxed.progenitors import progenitors

In [None]:
from matplotlib.backends.backend_pdf import PdfPages
%matplotlib inline

# Load catalogs 

In [None]:
# Load the base halo catalog used throughout this notebook.
# The .minh file is produced by the --summary option from a /bin file.
prog_table_file = '/home/imendoza/alcca/nbody-relaxed/data/trees_bolshoi/progenitors.csv'
cat_filepath = Path('/home/imendoza/alcca/nbody-relaxed/data/Bolshoi/'
                    'minh/hlist_1.00035.minh')
name = 'BolshoiP'

# Subhalo and progenitor information are requested at construction time;
# the progenitor table is taken from `prog_table_file`.
hcat = catalogs.HaloCatalog(
    cat_filepath,
    name,
    verbose=True,
    add_subhalo=True,
    add_progenitor=prog_table_file,
)
hcat.load_base_cat()

# Direct handle on the catalog's underlying table (private attribute).
cat = hcat._cat

In [None]:
# Declare the derived catalogs; all of them are loaded automatically when
# created through these factory methods.
colors = ['r','b', 'g']

print('total: ', len(hcat))


def log_func(x):
    """Return log10 of `x`; passed as the `modifier` of the mvir bound filters."""
    return np.log10(x)


# Three mass-selected catalogs, each built from a bound filter on 'mvir'
# (bounds are compared after applying `log_func`).
m11_filter = filters.get_bound_filter('mvir',high=11.22, modifier=log_func)
hcat_m11 = catalogs.HaloCatalog.create_filtered_from_base(hcat, m11_filter, label="M11")

m12_filter = filters.get_bound_filter('mvir',12,12.2, modifier=log_func)
hcat_m12 = catalogs.HaloCatalog.create_filtered_from_base(hcat, m12_filter, label="M12")

m13_filter = filters.get_bound_filter('mvir',13, 14, modifier=log_func)
hcat_m13 = catalogs.HaloCatalog.create_filtered_from_base(hcat, m13_filter, label="M13")

print(len(hcat_m11), len(hcat_m12), len(hcat_m13))
# Original note: "500 haloes are so > 13.75"
# (presumably ~500 haloes lie above log10(mvir) = 13.75 -- unverified).

# Catalogs restricted by the 'power2011' and 'neto2007' relaxedness criteria.
power_cat = catalogs.HaloCatalog.create_relaxed_from_base(hcat, 'power2011')
neto_cat = catalogs.HaloCatalog.create_relaxed_from_base(hcat, 'neto2007')
print(len(power_cat), len(neto_cat))

In [None]:
# Halo ids of the M11 / M12 / M13 catalogs as sets, for fast membership tests.
ids1, ids2, ids3 = (
    set(sub_cat._cat['id']) for sub_cat in (hcat_m11, hcat_m12, hcat_m13)
)

# Save progenitors in sample 

In [None]:
from astropy.io import ascii
# Text file holding the progenitor lines; `prog_generator` is iterated once
# by the export loop below (presumably a lazy, single-pass generator --
# re-create it before iterating again; TODO confirm).
prog_file = '/home/imendoza/alcca/nbody-relaxed/data/trees_bolshoi/progenitors.txt'
prog_generator = progenitors.get_prog_lines_generator(prog_file)
# NOTE: there are roughly 382474 main lines (in the filtered version of the file)

In [None]:
# Directory that receives one `<root_id>.csv` file per selected progenitor.
new_progenitor_path = Path('/home/imendoza/alcca/nbody-relaxed/data/trees_bolshoi/progenitor_subset2/')

In [None]:
# Write the progenitor table of every halo whose root id is in the M11 sample
# (`ids1`) to its own CSV file, `<root_id>.csv`, inside `new_progenitor_path`.
# Only `ids1` is used here; `ids2` / `ids3` are not exported by this loop.

# Make sure the output directory exists, otherwise the first write fails.
new_progenitor_path.mkdir(parents=True, exist_ok=True)

for i, prog in enumerate(prog_generator):
    # Progress report every 10000 progenitor lines.
    if i % 10000 == 0:
        print(i)
    if prog.root_id not in ids1:
        continue
    path = new_progenitor_path.joinpath(f"{int(prog.root_id)}.csv")
    ascii.write(prog.cat, path, format='fast_csv', fast_writer=True)

In [None]:
# Check how long it takes just to load every exported progenitor table.
# Only `.csv` entries are read: the directory also contains the `z_files`
# sub-directory (created further down), which cannot be read as a table.
csv_paths = (p for p in new_progenitor_path.iterdir() if p.suffix == '.csv')
for i, f in enumerate(csv_paths):
    t = Table.read(f, format='ascii.fast_csv')
    # Progress report every 1000 files.
    if i % 1000 == 0:
        print(i)

In [None]:
import json 
import csv

# TODO: add root_id to everything and root_halo_mvir

# Sub-directory holding one CSV per scale factor, plus the JSON file that
# stores the scale -> file-index mapping.
z_dir = new_progenitor_path / "z_files"
z_dir.mkdir(exist_ok=True)
z_map_file = z_dir / "z_map.json"

# `z_map` maps a scale factor to the integer used as its CSV file name;
# `count` is the next unused integer.
z_map, count = {}, 0

In [None]:
from contextlib import ExitStack

# Re-organise the per-halo progenitor tables into one CSV per scale factor.
# Every row of every `<root_id>.csv` contributes one (root_id, mvir) line to
# the file `z_dir/<z_map[scale]>.csv`; new scales get the next free index.
#
# Each scale file is opened once (append mode) and its writer is cached,
# instead of re-opening the file for every single row. All handles are closed
# (and therefore flushed) when the `with` block exits, also on error.
# NOTE(review): this keeps one open handle per distinct scale factor; assumed
# to stay well below the process' open-file limit -- confirm for other data.
fieldnames = ['root_id', 'mvir']

with ExitStack() as stack:
    writers = {}  # file index -> csv.DictWriter bound to the open scale file

    for i, f in enumerate(new_progenitor_path.iterdir()):
        if f.suffix != '.csv':
            # Skips the `z_files` directory and any other non-CSV entry.
            continue
        if i % 1000 == 0:
            print(i)

        prog_cat = Table.read(f, format='ascii.fast_csv')
        if len(prog_cat) == 0:
            # Nothing to distribute; avoids an IndexError on prog_cat[0].
            continue

        # The first row is expected to be the root halo (largest scale).
        root_id = prog_cat[0]['halo_id']
        if max(prog_cat['scale']) != prog_cat['scale'][0]:
            raise ValueError(f"{f}: make sure order is correct")

        for row in prog_cat:
            a, mvir = row['scale'], row['mvir']

            if a not in z_map:
                z_map[a] = count
                count += 1
            file_index = z_map[a]

            writer = writers.get(file_index)
            if writer is None:
                z_file = z_dir.joinpath(f"{file_index}.csv")
                # Header only for files that do not exist yet, so re-running
                # keeps appending to existing files exactly as before.
                needs_header = not z_file.exists()
                # newline='' is what the csv module requires of its files.
                zf = stack.enter_context(open(z_file, 'a', newline=''))
                writer = csv.DictWriter(zf, fieldnames)
                if needs_header:
                    writer.writeheader()
                writers[file_index] = writer

            writer.writerow({'root_id': root_id, 'mvir': mvir})

            

In [None]:
# # save json file for z_map (uncomment and run after the split loop; the next cell reads z_map_file)
# import json 
# with open(z_map_file, 'w') as fp: 
#     json.dump(z_map, fp)

In [None]:
from scipy.stats import spearmanr
# Spearman correlation, per scale factor, between mvir(scale) / mvir(catalog)
# and cvir for the haloes of the M11 catalog.

# Reload the scale -> file-index map from disk and invert it so that a file
# index gives back its scale (JSON keys are strings, hence the float() below).
with open(z_map_file, 'r') as zf:
    z_map = json.load(zf)

inv_z_map = {file_index: scale_str for scale_str, file_index in z_map.items()}

cat = hcat_m11.get_cat()
cat.sort('id')

# One entry per scale file, all three lists in the same order.
scales, corrs, n_halos = [], [], []

# Visit the scale files in increasing numeric order of their file index.
file_indices = sorted(int(p.stem) for p in z_dir.iterdir() if p.suffix == '.csv')
z_files = [z_dir / f"{file_index}.csv" for file_index in file_indices]

for f in z_files:
    scale_key = int(f.stem)
    scale = float(inv_z_map[scale_key])
    scales.append(scale)

    t = Table.read(f, format='ascii.fast_csv')
    t.sort('root_id')
    t.rename_column('root_id', 'id')

    # Restrict both tables to the ids they have in common.
    _cat = catalogs.intersection(cat, t)
    t = catalogs.intersection(t, _cat)
    n_halos.append(len(t))

    mass_ratio = t['mvir'] / _cat['mvir']
    corr = spearmanr(mass_ratio, _cat['cvir'])
    corrs.append(corr.correlation)

    

In [None]:
# Spearman correlation (`corrs`) as a function of scale factor (`scales`).
plt.plot(scales, corrs)