In [None]:
#!/usr/bin/env python
# coding: utf-8

In [None]:
import yt
import ytree
import numpy as np
from matplotlib import pylab
from yt.analysis_modules.halo_analysis.api import HaloCatalog

In [None]:
from yt.funcs import mylog

# Silence yt's logger except for critical messages ('CRITICAL' is logging level 50).
mylog.setLevel('CRITICAL')

# Black-hole particle filter

In [None]:
@yt.particle_filter('p3_bh', ['creation_time', 'particle_mass', 'particle_type'])
def p3_bh(pfilter, data):
    """Particle filter defining the 'p3_bh' particle type.

    A particle is kept only if all three cuts pass:
      * particle_type == 1
      * creation_time > 0
      * particle_mass > 1e-3 Msun

    NOTE(review): this combination is taken to identify the black-hole
    particles (per the filter name) -- confirm against the simulation's
    particle-type conventions.
    """
    is_type_one = data['particle_type'] == 1
    has_creation_time = data['creation_time'] > 0
    above_mass_floor = data['particle_mass'].in_units('Msun') > 1e-3
    return is_type_one & has_creation_time & above_mass_floor

# Loading data

In [None]:
# Time series of all simulation snapshots matched by the glob pattern
# (each `????` matches four characters of the output number).
ts_s = yt.load('~jw254/data/SG64-2020/GravPotential/DD????/output_????')

In [None]:
# Time series of the Rockstar halo catalogues matched by halos_DD????.0.bin.
ts_halos_s = yt.load('~jw254/data/SG64-2020/rockstar_halos-jhw/halos_DD????.0.bin')
# ts_halos_s includes output_0000 which is NOT present in ts_s; the plotting loop
# drops it with ts_halos_s[1:] before indexing.

In [None]:
# Directory containing the merger-tree files (e.g. tree_0_0_0.dat). Ends with a
# slash so a file name can be appended directly.
PATH = '/storage/home/hcoda1/0/jw254/data/SG64-2020/rockstar_halos-jhw/trees/'

In [None]:
# Load merger-tree file tree_0_0_0.dat from the PATH directory defined in the
# previous cell, rather than repeating the absolute path literal (PATH ends with
# '/', so plain concatenation yields the same full path as before).
ts_trees = ytree.load(PATH + 'tree_0_0_0.dat')

In [None]:
# Save the loaded arbor with ytree's save_arbor(); the returned filename `fn` is
# re-loaded in the next cell.
# NOTE(review): this re-writes the saved arbor every time the cell runs; consider
# skipping it when the saved file already exists.
fn = ts_trees.save_arbor()

In [None]:
# Re-load the arbor from the file written by save_arbor(); from here on ts_trees
# refers to the saved copy rather than the original tree_0_0_0.dat.
ts_trees = ytree.load(fn)

# Projections of merger tree 0 with `ytree` data

Make a new matrix of scale factors with shape (92, 3), where 92 = `len(ts_trees[0]['prog', 'position'])`, so that it matches the shape of the position array element by element.

In [None]:
# (N, 3) matrix of scale factors: one row per progenitor node of tree 0, with the
# node's scale factor repeated across the three coordinate columns so it lines up
# element-wise with ts_trees[0]['prog', 'position']. Built as a plain float ndarray
# with np.repeat instead of a Python list of per-element quantities that numpy
# would otherwise have to coerce implicitly.
_prog_sf = np.asarray(ts_trees[0]['prog', 'scale_factor'], dtype=float)
scale_factors_x3 = np.repeat(_prog_sf[:, np.newaxis], 3, axis=1)

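Multiply the positions by the scale factors to match the data in the yt snapshots, i.e. to convert the tree's comoving positions to proper positions at each node's epoch.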

In [None]:
# Positions of the progenitor nodes of tree 0, each row multiplied by that node's
# scale factor (scale_factors_x3 has one row per node).
hl_pos = ts_trees[0]['prog', 'position'] * scale_factors_x3

Reverse the order, because the tree traces backwards from the latest to the oldest snapshot (largest to smallest scale factor).

In [None]:
# Rebuild from the tree on every run (instead of reversing hl_pos in place) so the
# cell is idempotent: re-running it can no longer flip the order back to
# newest-first and silently misalign tree nodes with snapshots.
# [::-1] puts the oldest node first, so index -1 is the latest node.
hl_pos = (ts_trees[0]['prog', 'position'] * scale_factors_x3)[::-1].to('kpc')

In [None]:
# Inspect the scaled progenitor positions in kpc.
hl_pos

Likewise, multiply the virial radii by the scale factors to match the data in the yt snapshots (comoving to proper).

In [None]:
# Virial radii of the progenitor nodes of tree 0, each multiplied by that node's
# scale factor (both arrays are 1-D, so no repeated-column matrix is needed).
hl_vr = ts_trees[0]['prog', 'virial_radius'] * ts_trees[0]['prog', 'scale_factor']

Reverse the order, because the tree traces backwards from the latest to the oldest snapshot.

In [None]:
# Rebuild from the tree on every run (instead of reversing hl_vr in place) so the
# cell is idempotent: re-running it can no longer flip the order back to
# newest-first. [::-1] puts the oldest node first, the same order as hl_pos.
hl_vr = (ts_trees[0]['prog', 'virial_radius'] * ts_trees[0]['prog', 'scale_factor'])[::-1].to('kpc')

In [None]:
# Inspect the scaled virial radii in kpc.
hl_vr

In [None]:
# Dry run of the indexing used by the plotting loop below, without any plotting.
# Tree node `ids` (oldest first, as in hl_vr / hl_pos) is paired with snapshot
# ts_s[ids - len(hl_vr)], i.e. the tree is assumed to cover the LAST len(hl_vr)
# outputs of ts_s. Printing both scale factors makes a misalignment visible.
n_nodes = len(hl_vr)
# Tree scale factors in the same oldest-first order as hl_vr.
tree_sf = ts_trees[0]['prog', 'scale_factor'][::-1]
for ids in range(n_nodes):
    print(ids)
    ds_s = ts_s[ids - n_nodes]  # ts_s start from output_0002, length is 124
    print(ds_s)
    print(hl_vr[ids])
    # For a correct pairing these agree (up to rounding in the tree file).
    snap_sf = 1.0 / (1.0 + ds_s.current_redshift)
    print(f"  a_tree = {float(tree_sf[ids]):.5f}, a_snapshot = {snap_sf:.5f}")

Loop through each snapshot in the time series

In [None]:
# Plot every snapshot in which merger tree 0 appears. The tree's progenitor line
# (hl_pos / hl_vr, oldest node first) is assumed to map one-to-one onto the LAST
# len(hl_vr) outputs of ts_s, so a single negative index `ihl` addresses ts_s,
# the halo series, hl_pos and hl_vr alike.
n_nodes = len(hl_vr)
# ts_halos_s has output_0000 which was NOT present in ts_s: drop it once here
# instead of re-slicing the series on every iteration.
halos_tail = ts_halos_s[1:]

for ids in range(n_nodes):  # just look at the snapshots where the merger tree appears
    ihl = ids - n_nodes  # runs from -len(hl_vr) to -1

    # LOAD EACH SNAPSHOT DATASET IN TIMESERIES
    ds_s = ts_s[ihl]  # ts_s start from output_0002, length is 124
    ds_s.add_particle_filter('p3_bh')
    # LOAD EACH HALO DATASET IN HALOS TIMESERIES
    # NOTE(review): halos_s is not used by the plot below.
    halos_s = halos_tail[ihl]

    # STEP 1: GET FIELD INFO FROM EACH SNAPSHOT
    # Not needed for the plot (annotate_particles reads the 'p3_bh' type itself);
    # kept so the last snapshot's BH data stay available after the loop. Read only
    # once: nothing in this loop modifies bh_pos, so a second read was redundant.
    bh_id = ds_s.r['p3_bh', 'particle_index']
    bh_pos = ds_s.r['p3_bh', 'particle_position'].to('pc')

    # Sphere of one virial radius around the host halo; it restricts the projection.
    sp = ds_s.sphere(hl_pos[ihl], hl_vr[ihl])

    # Density-weighted density projection along x, centred on the halo, 6 kpc wide.
    prj = yt.ProjectionPlot(ds_s, 'x', 'density', weight_field='density',
                            data_source=sp, center=hl_pos[ihl], width=(6, 'kpc'))

    # BH particles as red dots, from a 5 kpc slab along the line of sight.
    prj.annotate_particles((5, 'kpc'), p_size=10, ptype='p3_bh', col='red')

    # Halo: red 'x' at plot coordinates [0, 0] (the halo centre, assuming yt's
    # default centred plot origin) and a dashed green circle at the virial radius.
    prj.annotate_marker([0, 0], coord_system='plot',
                        plot_args={'color': 'red', 's': 500}, marker='x')
    prj.annotate_sphere([0, 0], radius=hl_vr[ihl], coord_system='plot',
                        circle_args={'color': 'green', 'linewidth': 4, 'linestyle': 'dashed'})

    prj.annotate_timestamp(redshift=True)  # add timestamp and redshift
    prj.set_zlim('density', zmin=1e-26, zmax=1e-23)  # same colorbar limits for every frame

    print(ihl)  # just for track-keeping

    # NOTE(review): '?' in the directory name looks like an unfilled version
    # placeholder; it is a glob metacharacter and invalid on some filesystems.
    prj.save("images_ts-ver?-mergertree0/")