In [1]:
import h5py as h5
import arepo
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm
from scipy.interpolate import interp1d
from scipy.stats import binned_statistic_2d
from numba import njit
import importlib
import galaxy
from scipy.spatial import KDTree
from scipy.stats import gaussian_kde

# Root directory of the GSEgas project; run directories (basepath + 'runs/...') are joined onto it.
basepath = '/n/holylfs05/LABS/hernquist_lab/Users/abeane/GSEgas/'

In [10]:
# Re-import the local `galaxy` module so edits to galaxy.py take effect without restarting the kernel.
importlib.reload(galaxy)

<module 'galaxy' from '/n/holylfs05/LABS/hernquist_lab/Users/abeane/GSEgas/note/galaxy.py'>

In [3]:
# idx = 320

# gals = {}
# Rs_list = ['116', '129', '142']
# Vv_list = ['116', '129', '142']
# ecc_list = ['04', '05', '06']
# for Rs in tqdm(Rs_list):
#     gals[Rs] = {}
#     for Vv in Vv_list:
#         gals[Rs][Vv] = {}
#         for ecc in ecc_list:
#             key = 'lvl4-Rs'+Rs+'-Vv'+Vv+'-e'+ecc
#             gals[Rs][Vv][ecc] = galaxy.Galaxy(basepath + 'runs/MW7_GSE4-eRVgrid-lvl4/' + key + '/output',
#                                               idx, orient=True)

In [None]:
# Load snapshots 0, 5, ..., 320 of one eRV-grid run into a dict keyed by snapshot index.
idx_list = np.arange(0, 320+1, 5)
Rs = '142'
Vv = '142'
ecc = '04'

key = f'lvl4-Rs{Rs}-Vv{Vv}-e{ecc}'
output_dir = basepath + 'runs/MW7_GSE4-eRVgrid-lvl4/' + key + '/output'

gals_idx = {}
for idx in tqdm(idx_list):
    # orient=True is passed straight through to galaxy.Galaxy.
    gals_idx[idx] = galaxy.Galaxy(output_dir, idx, orient=True)

 77%|███████▋  | 50/65 [02:28<00:50,  3.36s/it]

In [None]:
def get_logFeH_logMgFe(gal, ptype=4, use_hydrogen=False):
    """Return log10 abundance ratios relative to solar for one particle type.

    Despite the name, three arrays are returned: ``logFeH, logMgH, logMgFe``.

    Parameters
    ----------
    gal : object
        Must expose ``gal.sn.part<ptype>.GFM_Metals``, a 2D array indexed
        ``[particle, element]``. Column 6 is read as Mg and column 8 as Fe.
    ptype : int
        Particle type whose metals are read (default 4; 0 is used for gas
        elsewhere in this notebook).
    use_hydrogen : bool
        If False (default, original behaviour), element mass fractions are
        divided directly by the solar mass fractions, i.e. each particle's
        hydrogen fraction is implicitly taken to be solar. If True, ratios are
        formed against the particle's own hydrogen fraction and normalized by
        the solar ratio: (X/H) / (X/H)_sun. This assumes column 0 is hydrogen
        (standard Arepo GFM ordering) -- confirm for your snapshots.

    Returns
    -------
    logFeH, logMgH, logMgFe : ndarray
        Particles with zero metal fractions yield -inf (and 0/0 yields nan in
        logMgFe); the corresponding numpy divide/invalid warnings are suppressed.
    """
    GFM_SOLAR_ABUNDANCE_HYDROGEN = 0.7388
    GFM_SOLAR_ABUNDANCE_MAGNESIUM = 0.0007
    GFM_SOLAR_ABUNDANCE_IRON = 0.0013

    part = getattr(gal.sn, 'part' + str(ptype))
    metals = part.GFM_Metals

    # Metal-free particles are expected (e.g. log10(0)); don't spam RuntimeWarnings for them.
    with np.errstate(divide='ignore', invalid='ignore'):
        FeH = metals[:, 8] / GFM_SOLAR_ABUNDANCE_IRON
        MgH = metals[:, 6] / GFM_SOLAR_ABUNDANCE_MAGNESIUM

        if use_hydrogen:
            # Rescale by the particle's hydrogen fraction relative to solar.
            H = metals[:, 0] / GFM_SOLAR_ABUNDANCE_HYDROGEN
            FeH = FeH / H
            MgH = MgH / H

        MgFe = MgH / FeH
        logFeH = np.log10(FeH)
        logMgH = np.log10(MgH)
        logMgFe = np.log10(MgFe)

    return logFeH, logMgH, logMgFe

def in_SN(gal, ptype, Rmin=4, Rmax=16, zmin=0, zmax=3, dLz=0.1, Lzsun=8*220):
    """Boolean mask of "solar-neighbourhood" particles of one type.

    A particle is selected when its z angular momentum is within a fractional
    tolerance ``dLz`` of ``Lzsun`` (``|Lz/Lzsun - 1| < dLz``) and its height
    satisfies ``zmin <= |z| < zmax``.

    Parameters
    ----------
    gal : object
        Must expose ``gal.sn.part<ptype>.rotpos`` and ``.rotvel`` as (N, 3)
        arrays (presumably galaxy-aligned positions/velocities -- confirm).
    ptype : int
        Particle type to select from.
    Rmin, Rmax : float
        NOTE(review): accepted for interface compatibility but not applied.
    zmin, zmax : float
        Cut on |z| (column 2 of rotpos); zmin=0 disables the lower bound.
    dLz : float
        Fractional Lz tolerance around ``Lzsun``.
    Lzsun : float
        Reference Lz; default 8*220 (presumably 8 kpc x 220 km/s -- confirm it
        matches the units of rotpos x rotvel).

    Returns
    -------
    ndarray of bool, shape (N,)
    """
    part = getattr(gal.sn, 'part' + str(ptype))

    pos = part.rotpos
    vel = part.rotvel
    # z-component of r x v
    Lz = pos[:, 0] * vel[:, 1] - pos[:, 1] * vel[:, 0]

    # abs() matters: without it every particle with Lz below (1 + dLz) * Lzsun,
    # including counter-rotating ones, would pass the cut.
    in_Lz = np.abs(Lz / Lzsun - 1) < dLz

    absz = np.abs(pos[:, 2])
    in_z = (absz >= zmin) & (absz < zmax)

    return in_Lz & in_z

In [None]:
def get_kd_density(gal, logFeH_min=-1.1, logFeH_max=0.6, logMgFe_min=0.1, logMgFe_max=0.65,
                   nres=1024, K=256, ptype=4, workers=1):
    """K-nearest-neighbour density of solar-neighbourhood particles in the ([Fe/H], [Mg/Fe]) plane.

    Particles are selected with ``in_SN(gal, ptype)`` and their abundances come
    from ``get_logFeH_logMgFe(gal, ptype=ptype)``. The density at each grid node
    is ``K / (N * pi * r_K**2)``, where r_K is the distance to the K-th nearest
    particle and N is the number of particles used. This is the fraction of
    particles per unit area in the log-abundance plane.

    Parameters
    ----------
    gal : object
        Galaxy/snapshot wrapper understood by ``in_SN`` and ``get_logFeH_logMgFe``.
    logFeH_min, logFeH_max, logMgFe_min, logMgFe_max : float
        Grid limits. Nodes span ``[min + d, max - d]`` with ``d = (max - min) / nres``.
    nres : int
        Number of grid nodes per axis.
    K : int
        Neighbour rank used for the density. If K > N, the missing neighbours
        are at infinite distance and the density is 0.
    ptype : int
        Particle type (default 4, as before).
    workers : int
        Passed to ``KDTree.query``; -1 uses all cores (same result, faster).

    Returns
    -------
    density : ndarray, shape (nres, nres)
        Indexed ``[i_FeH, j_MgFe]``, i.e. transposed relative to the meshgrid
        arrays. Plot with ``density.T``.
    logFeH_grid, logMgFe_grid : ndarray, shape (nres, nres)
        ``np.meshgrid`` output, indexed ``[j_MgFe, i_FeH]``.

    Raises
    ------
    ValueError
        If no selected particle has finite abundances.
    """
    key = in_SN(gal, ptype)
    logFeH, logMgH, logMgFe = get_logFeH_logMgFe(gal, ptype=ptype)
    data = np.column_stack((logFeH[key], logMgFe[key]))
    # Metal-free particles give -inf/nan abundances, which a KDTree cannot index.
    data = data[np.all(np.isfinite(data), axis=1)]
    N = len(data)
    if N == 0:
        raise ValueError('get_kd_density: no selected particles with finite abundances')
    tree = KDTree(data)

    dlogFeH = (logFeH_max - logFeH_min) / nres
    dlogMgFe = (logMgFe_max - logMgFe_min) / nres

    logFeH_lin = np.linspace(logFeH_min + dlogFeH, logFeH_max - dlogFeH, nres)
    logMgFe_lin = np.linspace(logMgFe_min + dlogMgFe, logMgFe_max - dlogMgFe, nres)

    logFeH_grid, logMgFe_grid = np.meshgrid(logFeH_lin, logMgFe_lin)
    grid = np.column_stack((logFeH_grid.ravel(), logMgFe_grid.ravel()))

    distances, _ = tree.query(grid, k=K, workers=workers)
    kth_distances = distances[:, -1]

    # Inverse of the area of the circle reaching the K-th neighbour, normalized by N.
    density = K / (N * np.pi * kth_distances**2)
    density = np.reshape(density, logFeH_grid.shape).T

    return density, logFeH_grid, logMgFe_grid

In [None]:
%%time
# KNN density (get_kd_density defaults, ptype 4) at snapshot 320, the last one loaded above.
# The result is drawn as the reference contours in every panel of the figure below.
gal_end = gals_idx[320]
density, logFeH_grid, logMgFe_grid = get_kd_density(gal_end)

In [None]:
# Show the contour levels matplotlib picks automatically for the reference density,
# as a guide for the hand-chosen `levels` in the panel figure. `cont` is built here
# (not taken from the panel figure) so this cell runs on a fresh kernel.
fig_levels, ax_levels = plt.subplots()
cont = ax_levels.contour(logFeH_grid, logMgFe_grid, density.T, colors='k')
cont.levels

In [None]:
# Gas cells (ptype 0) in the [Fe/H]-[Mg/Fe] plane, one panel per loaded snapshot,
# drawn over contours of the snapshot-320 reference density.
logFeH_min = -1.1
logFeH_max = 0.6
logMgFe_min = 0.1
logMgFe_max = 0.65

subsamp = 5000                  # max gas cells drawn per panel
levels = [0, 3, 6, 9, 15, 21]   # contour levels for the reference density
ncols = 3
nrows = int(np.ceil(len(idx_list) / ncols))
rng = np.random.default_rng(42)  # fixed seed so the subsample is reproducible

fig, axs = plt.subplots(nrows, ncols, figsize=(11, 60 * nrows / 22),
                        sharex=True, sharey=True, squeeze=False)

for i, idx in enumerate(tqdm(idx_list)):
    gal = gals_idx[idx]
    ax = axs.ravel()[i]

    # lay down contour of stars
    cont = ax.contour(logFeH_grid, logMgFe_grid, density.T, levels, colors='k')

    is_in_SN = in_SN(gal, 0)
    logFeH, logMgH, logMgFe = get_logFeH_logMgFe(gal, ptype=0)

    # choice(..., replace=False) raises if asked for more samples than exist, so cap the draw.
    candidates = np.where(is_in_SN)[0]
    sample_idx = rng.choice(candidates, min(subsamp, len(candidates)), replace=False)
    ax.scatter(logFeH[sample_idx], logMgFe[sample_idx], s=1, alpha=0.5)

    # Star-forming (sfr > 0) subset of the drawn sample.
    sf_idx = sample_idx[gal.sn.part0.sfr[sample_idx] > 0]
    ax.scatter(logFeH[sf_idx], logMgFe[sf_idx], s=0.1, alpha=0.3)

    ax.set_title(str(round(gal.sn.Time.value, 2)))

# Hide panels left over when len(idx_list) is not a multiple of ncols.
for ax in axs.ravel()[len(idx_list):]:
    ax.set_visible(False)

for ax in axs[-1]:
    ax.set_xlabel('[Fe/H]')
for ax in axs[:, 0]:
    ax.set_ylabel('[Mg/Fe]')

axs[0][0].set(xlim=(logFeH_min, logFeH_max), ylim=(logMgFe_min, logMgFe_max))
fig.tight_layout()