In [1]:
import h5py as h5
import arepo
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm
from scipy.interpolate import interp1d
from scipy.stats import binned_statistic_2d
from numba import njit
from astropy.io import fits
import os
from joblib import Parallel, delayed
import warnings

# Root of the GSEgas project directory; simulation runs are read from basepath + 'runs/...'.
basepath = '/n/holylfs05/LABS/hernquist_lab/Users/abeane/GSEgas/'

import sys
# Add the project's note/ directory to the import path so that `import galaxy`
# on the next line can resolve (presumably galaxy.py lives there — confirm).
# Order matters: this append must run before the import.
sys.path.append(basepath+'note/')
import galaxy

from scipy.ndimage import gaussian_filter

import illustris_python as il
# Output directory of the IllustrisTNG L35n2160TNG run (not used in the cells shown here).
TNGbase = '/n/holylfs05/LABS/hernquist_lab/IllustrisTNG/Runs/L35n2160TNG/output/'

In [2]:
# mpl.rc('text', usetex=True)
# mpl.rc('text.latex', preamble=r"""
# \usepackage{amsmath}
# """)
# # mpl.rcParams.update({'font.size': 22})
# # mpl.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}']
# # color palette
# tb_c = ['#4e79a7', '#f28e2b', '#e15759', '#76b7b2', '#59a14f',
#         '#edc948', '#b07aa1', '#ff9da7', '#9c755f', '#bab0ac']

# columnwidth = 242.26653 / 72.27 # converts pts to inches
# textwidth = 513.11743 / 72.27

# mpl.rcParams.update({'font.size': 8})

In [19]:
import numpy as np
from numba import njit

@njit
def compute_com(pos0, pos, mass, ri, rf, rfac):
    """Shrinking-sphere center of mass (Numba-compiled).

    Starting at ``pos0`` with radius ``ri``, each pass replaces the center with
    the mass-weighted mean position of the particles lying strictly inside the
    current radius, then multiplies the radius by ``rfac``. Iteration stops once
    the radius is no longer greater than ``rf``, or early if the sphere is empty.

    Parameters
    ----------
    pos0 : 1-D float array
        Initial center; broadcast against each row of ``pos``.
    pos : 2-D float array
        Particle positions, one particle per row.
    mass : 1-D float array
        Particle masses, aligned with the rows of ``pos``.
    ri, rf, rfac : float
        Initial radius, stopping radius and per-pass shrink factor.

    Returns
    -------
    1-D float array
        The last computed center (``pos0`` itself if the first sphere is empty).
    """
    center = pos0
    radius = ri
    while radius > rf:
        dist2 = np.sum((pos - center)**2, axis=1)
        inside = dist2 < radius**2
        # An empty sphere would give 0/0 below; keep the current center instead.
        if not np.any(inside):
            break

        w = mass[inside]
        center = np.sum(pos[inside] * w[:, None], axis=0) / np.sum(w)
        radius *= rfac
    return center


def _get_com_of_snap(idx, output_dir, ri, rf, rfac):
    def _get_subhalo_pos(idx, output_dir):
        # first get the subhalo pos as a starting point
        subfname = output_dir + '/fof_subhalo_tab_'+str(idx).zfill(3)+'.hdf5'
        if not os.path.exists(subfname):
            return np.array([np.nan, np.nan, np.nan]), np.array([np.nan, np.nan, np.nan])
        
        sub = h5.File(subfname, mode='r')
        subpos0 = sub['Subhalo/SubhaloPos'][0]
        if len(sub['Subhalo/SubhaloPos']) > 1:
            subpos1 = sub['Subhalo/SubhaloPos'][1]
        else:
            subpos1 = np.array([np.nan, np.nan, np.nan])
        sub.close()
        
        return subpos0, subpos1
    
    def _load_snap(idx, output_dir):
        # now load in all particles equally
        snapname = output_dir + '/snapshot_'+str(idx).zfill(3)+'.hdf5'
        if not os.path.exists(snapname):
            return np.nan, np.nan, np.nan
        
        snap = h5.File(snapname, mode='r')
        NumPart_Total = snap['Header'].attrs['NumPart_Total']
        MassTable = snap['Header'].attrs['MassTable']
        Time = snap['Header'].attrs['Time']
        
        pos = []
        vel = []
        mass = []
        for i,npart in enumerate(NumPart_Total):
            if npart == 0:
                continue
                
            ptype = 'PartType'+str(i)  
            if 'Coordinates' not in snap[ptype].keys():
                continue
        
            pos.append(snap[ptype]['Coordinates'][:])
            vel.append(snap[ptype]['Velocities'][:])
        
            if MassTable[i] > 0:
                mass.append(np.full(npart, MassTable[i]))
            else:
                mass.append(snap[ptype]['Masses'][:])
    
        pos = np.concatenate(pos)
        vel = np.concatenate(vel)
        mass = np.concatenate(mass)
        snap.close()
    
        return Time, pos, vel, mass
    
    def _get_com(pos0, pos, mass, ri, rf, rfac):
        if np.isnan(pos0).any():
            return pos0

        # Filter particles initially
        pos_init = np.copy(pos0)
        rsq = np.sum((pos - pos_init)**2, axis=1)
        mask = rsq < (4 * ri)**2
        pos_ = pos[mask]
        mass_ = mass[mask]
        
        # ensure float64
        pos0 = pos0.astype(np.float64)
        pos_ = pos_.astype(np.float64)
        mass_ = mass_.astype(np.float64)

        # Call Numba-accelerated function
        pos0 = compute_com(pos0, pos_, mass_, ri, rf, rfac)

        # Ensure the center of mass didn't travel too far
        rtravel = np.linalg.norm(pos0 - pos_init)
        assert rtravel < 2 * ri

        return pos0
    
    def _get_comv(com0, pos, mass, vel, rcut=4):
        if np.isnan(com0).any():
            return com0
        
        rsq = np.sum((pos - com0)**2, axis=1)
        rcutsq = rcut*rcut
        in_rcut = rsq < rcutsq
        
        # wts = np.copy(mass)
        # wts[np.logical_not(in_rcut)] = 0.
        
        return np.average(vel[in_rcut], weights=mass[in_rcut], axis=0)
    
    subpos0, subpos1 = _get_subhalo_pos(idx, output_dir)
    time, pos, vel, mass  = _load_snap(idx, output_dir)
    
    com0 = _get_com(subpos0, pos, mass, ri, rf, rfac)
    com1 = _get_com(subpos1, pos, mass, ri, rf, rfac)
    
    comv0 = _get_comv(com0, pos, mass, vel)
    comv1 = _get_comv(com1, pos, mass, vel)
    
    return time, com0, com1, comv0, comv1
    
def find_com(sim, ri=10., rf=5., rfac=0.9, usetqdm=False, nsnap=321, n_jobs=64):
    """Centers of mass/velocity of the first two subhalos for every snapshot of one run.

    Parameters
    ----------
    sim : sequence of str
        ``(Rs, Vv, ecc)`` labels. They select the run directory
        ``basepath + 'runs/MW7_GSE4-eRVgrid-lvl4/lvl4-Rs<Rs>-Vv<Vv>-e<ecc>/output'``.
    ri, rf, rfac : float
        Shrinking-sphere parameters forwarded to ``_get_com_of_snap``.
    usetqdm : bool
        Show a progress bar over snapshot indices.
    nsnap : int
        Number of snapshots to process (indices ``0 .. nsnap-1``). The default
        matches the previously hard-coded 321.
    n_jobs : int
        Number of joblib workers (default matches the previous hard-coded 64).

    Returns
    -------
    dict
        Keys ``'time'``, ``'com0'``, ``'com1'``, ``'comv0'``, ``'comv1'``.
        Each value is the per-snapshot results of ``_get_com_of_snap`` stacked
        along a leading snapshot axis.
    """
    key = 'lvl4-Rs' + sim[0] + '-Vv' + sim[1] + '-e' + sim[2]
    output_dir = basepath + 'runs/MW7_GSE4-eRVgrid-lvl4/' + key + '/output'

    itr = np.arange(nsnap)
    if usetqdm:
        # The bar advances as joblib pulls tasks from this iterator (dispatch),
        # so it can run ahead of the work actually completed.
        itr = tqdm(itr, leave=True, position=0)
    out = Parallel(n_jobs=n_jobs)(
        delayed(_get_com_of_snap)(idx, output_dir, ri, rf, rfac) for idx in itr)

    # Each element of `out` is (time, com0, com1, comv0, comv1); stack each
    # field across snapshots.
    fields = ('time', 'com0', 'com1', 'comv0', 'comv1')
    return {name: np.array([o[j] for o in out]) for j, name in enumerate(fields)}

In [20]:
# Label values for the three grid axes that appear in the run directory names
# ('Rs', 'Vv' and 'e' in lvl4-Rs<..>-Vv<..>-e<..>).
Rs_list = ['116', '129', '142']
Vv_list = ['116', '129', '142']
ecc_list = ['04', '05', '06']

# Every (Rs, Vv, ecc) combination; ecc varies fastest, then Vv, then Rs.
sim_list = [
    (rs_label, vv_label, ecc_label)
    for rs_label in Rs_list
    for vv_label in Vv_list
    for ecc_label in ecc_list
]

In [21]:
# Run on the first grid point with a progress bar before launching the full grid.
com = find_com(sim=sim_list[0], usetqdm=True)

100%|██████████| 321/321 [00:21<00:00, 15.17it/s]


In [22]:
# Run the centering for every grid point and save each result dict to
# centering/sim_<Rs>_<Vv>_<ecc>.npy (np.save pickles the dict; load it with
# allow_pickle=True and .item()).
os.makedirs('centering', exist_ok=True)  # np.save does not create missing directories
for sim in tqdm(sim_list, leave=True, position=0):
    com = find_com(sim)
    np.save(os.path.join('centering', 'sim_' + '_'.join(sim) + '.npy'), com)

 20%|█▉        | 64/321 [17:34<1:10:33, 16.47s/it]
100%|██████████| 27/27 [17:40<00:00, 39.28s/it]


In [7]:
# Load one saved centering result. A seed-suffixed file is used if present.
# Otherwise fall back to the unsuffixed name that this notebook's save loop
# writes, so the cell also works after Restart & Run All.
sim = ('142', '116', '04')
sd = 0
com_base = 'centering/sim_' + sim[0] + '_' + sim[1] + '_' + sim[2]
com_fname = com_base + '_seed' + str(sd) + '.npy'
if not os.path.exists(com_fname):
    com_fname = com_base + '.npy'
# allow_pickle is required because the file stores a dict; only load trusted files.
com = np.load(com_fname, allow_pickle=True).item()

In [23]:
# Display the fields stored in the centering result dict `com`.
com.keys()

dict_keys(['time', 'com0', 'com1', 'comv0', 'comv1'])