In [None]:
import os, h5py, numpy as np
import illustris_python as il
import yt
import yt.units as ytu
from emis import emissivity
import matplotlib.pyplot as plt
import astropy.units as u
import astropy.constants as c
from plot_utils import quiver_from_components
from yt_field_utils import *
from scipy.spatial.transform import Rotation as R

In [None]:
# Read the Subfind IDs of the galaxies in the sample catalog (presumably the
# Milky Way analogues, going by the file name). Opened explicitly read-only:
# h5py < 3.0 defaulted to mode 'a', which can create/modify the file.
with h5py.File('mw_tng_sample_catalog.hdf5', 'r') as f:
    mw_id = np.array(f['SubfindID'])

In [None]:
# Which galaxy of the catalog to analyse (index into mw_id).
SAMPLE_INDEX = 5
# Directory holding the per-galaxy gas cutouts. Override with the MW_SAMPLE_DIR
# environment variable rather than editing a hard-coded home-directory path.
MW_SAMPLE_DIR = os.environ.get(
    "MW_SAMPLE_DIR", "/home/cj535/project/Mothra_CGM_CNN/MW_sample"
)

subid = mw_id[SAMPLE_INDEX]
outpath = os.path.join(
    MW_SAMPLE_DIR, f'TNG50_snap099_subid{subid:06d}_gas_sphere_500kpc.hdf5'
)
# Sidecar yt index file next to the snapshot; splitext only touches the file
# extension, never a ".hdf5" that might appear in the directory name.
idx_path = os.path.splitext(outpath)[0] + ".ytindex"
ds = yt.load(outpath, index_filename=idx_path, smoothing_factor=2)


add_emission_fields(ds, line="Halpha")
#add_emission_weight_fields(ds,minL=1e-43,line='Halpha')
#add_coldgas_weight_fields(ds,maxT=10**4.5)

# Line-of-sight velocity windows (lo, hi); a mask field and a masked H-alpha
# emission field are registered for each window (units presumably km/s).
vel_bounds = [(-300, -100), (-100, 100), (100, 300)]
for lo, hi in vel_bounds:
    add_los_mask_fields(ds, lo, hi)
    add_los_mask_emission_fields(ds, lo, hi, line='Halpha')

In [None]:
# Viewing-angle presets. Each entry is [axis1, deg1, axis2, deg2]: rotate by
# deg1 degrees about axis1, then by deg2 degrees about axis2.
angle_number = 0
theta_list = [
    ['x', 0, 'y', 0],
    ['x', 0, 'y', 30],
    ['x', 0, 'y', 60],
    ['x', 0, 'y', 90],
    ['x', 30, 'y', 0],
    ['x', 60, 'y', 0],
    ['x', 90, 'y', 0],
    ['x', 90, 'z', 30],
    ['x', 90, 'z', 60],
    ['x', 45, 'z', 45],
]
theta = theta_list[angle_number]
axis_1, angle_1, axis_2, angle_2 = theta
r1 = R.from_euler(axis_1, angle_1, degrees=True)
r2 = R.from_euler(axis_2, angle_2, degrees=True)

# Rotate the lab-frame z and y axes into the viewing frame: w_hat is used as
# the projection direction and v_hat as the image "north" further down;
# u_hat = v_hat x w_hat completes the right-handed (u, v, w) triad.
z_hat = np.array([0, 0, 1])
y_hat = np.array([0, 1, 0])
w_hat, v_hat = (r2.apply(r1.apply(lab_axis)) for lab_axis in (z_hat, y_hat))
w_hat = w_hat / np.linalg.norm(w_hat)
v_hat = v_hat / np.linalg.norm(v_hat)
u_hat = np.cross(v_hat, w_hat)
# Register velocity fields projected onto the (u, v, w) viewing frame built
# above — presumably the 'velocity_u'/'velocity_v'/'velocity_w' fields that the
# off-axis projection cell reads (helper lives in yt_field_utils; confirm).
add_velocity_projection_fields(ds,u_hat,v_hat,w_hat)

In [None]:
# Centre of the cutout in code units (comoving for a cosmological dataset).
# The custom radial fields measure radii from this same point (domain_width / 2).
center_comov = 0.5 * ds.domain_width
# The same point expressed in kpc. yt's unit registry already converts code
# (comoving) lengths to proper 'kpc'; multiplying a *position* by the scale
# factor would move the centre away from the galaxy (harmless only at a = 1).
center_prop = center_comov.to("kpc")

In [None]:
def _offset_from_box_center(data):
    """Return the (x, y, z) offsets, in kpc, of each gas particle from the box centre.

    The centre is placed at half the domain width along x, applied to all three
    axes. NOTE(review): that is the true domain centre only if the domain is
    cubic and starts at 0 — presumably true for these cutouts; confirm.
    """
    half_width = data.ds.domain_width[0].in_units('kpc') / 2.
    return tuple(data['gas', axis].in_units("kpc") - half_width
                 for axis in ('x', 'y', 'z'))


def _vel_sph_rad(field, data):
    """yt derived-field function: spherical radial velocity (km/s) about the box centre.

    Positive values point away from the centre. Raw particle velocities are used
    as-is; no "bulk_velocity" field parameter is read or subtracted.
    """
    dx, dy, dz = _offset_from_box_center(data)
    radius = np.sqrt(dx**2 + dy**2 + dz**2)

    azimuth = np.arctan2(dy, dx)      # angle in the x-y plane, measured from +x
    polar = np.arccos(dz / radius)    # angle measured from +z (nan where radius == 0)
    sin_polar = np.sin(polar)

    vx = data['gas', 'velocity_x'].in_units("km/s")
    vy = data['gas', 'velocity_y'].in_units("km/s")
    vz = data['gas', 'velocity_z'].in_units("km/s")
    return (vx * np.cos(azimuth) * sin_polar
            + vy * np.sin(azimuth) * sin_polar
            + vz * np.cos(polar))


# spherical radius
def _spherical_r(field, data):
    """yt derived-field function: distance (kpc) of each gas particle from the box centre."""
    dx, dy, dz = _offset_from_box_center(data)
    return np.sqrt(dx**2 + dy**2 + dz**2)


# Register the two custom gas-particle fields defined above (name, function, units).
for _field_name, _field_func, _field_units in (
    (("gas", "vel_sph_rad"), _vel_sph_rad, "km/s"),
    (("gas", "spherical_r"), _spherical_r, "kpc"),
):
    ds.add_field(_field_name, function=_field_func, units=_field_units, sampling_type="particle")

In [None]:
# Shell thickness in kpc (compared against ('gas', 'spherical_r'), which is in kpc).
# The mass-flow cell divides by this same thickness.
# Keep it an integer: filter names use int(rmid).
delta_r = 5

# Shell mid-radii from 20 to 200 kpc, spaced by delta_r so that consecutive
# half-open shells [rmid - delta_r/2, rmid + delta_r/2) tile the range with no
# gaps or overlaps even if delta_r is changed.
shell_midpoints = np.arange(20, 200 + delta_r, delta_r)


def add_shell_filter(rmid):
    """Register a yt particle filter selecting gas in the shell around ``rmid``.

    Parameters
    ----------
    rmid : int
        Shell mid-radius in kpc. Selected particles satisfy
        ``rmid - delta_r/2 <= spherical_r < rmid + delta_r/2``.

    Returns
    -------
    str
        The filter / particle-type name, ``f"shell_{rmid}"``.
    """
    rmin, rmax = rmid - delta_r / 2, rmid + delta_r / 2

    def shell_filter(pfilter, data):
        # pfilter.filtered_type is the parent particle type ("gas" below).
        r = data[pfilter.filtered_type, "spherical_r"].to("kpc")  # 1-D array of radii
        return (r >= rmin) & (r < rmax)

    shell_name = f"shell_{rmid}"
    yt.add_particle_filter(
        shell_name,
        function=shell_filter,
        filtered_type="gas",
        requires=["spherical_r"],
    )
    ds.add_particle_filter(shell_name)
    return shell_name


shell_names = [add_shell_filter(int(rmid)) for rmid in shell_midpoints]

    


In [None]:
width_proper = (500.0, "kpc")                 # projection width (yt's 'kpc' is proper)
center_comov = 0.5 * ds.domain_width          # cutout centre in code units (comoving)
# Same point in kpc. The unit conversion already handles comoving -> proper;
# multiplying a position by the scale factor would shift the centre (only
# harmless at a = 1) and break agreement with the radial fields' centre.
center_prop = center_comov.to("kpc")

# Bulk velocity of the inner 50 kpc, attached to the 500 kpc analysis sphere.
sp_small = ds.sphere(center_prop, (50.0, "kpc"))
bv = sp_small.quantities.bulk_velocity()
sp = ds.sphere(center_prop, (500.0, "kpc"))
# NOTE(review): only fields that read the "bulk_velocity" field parameter are
# affected; the custom ('gas', 'vel_sph_rad') field uses raw velocities.
sp.set_field_parameter("bulk_velocity", bv)

In [None]:
# Display every derived field known to ds, including the custom radial fields
# and the particle filters registered above (useful to confirm registration).
ds.derived_field_list

In [None]:
def joint_mask(sp, filter1, filter2):
    """Boolean mask over ``sp['gas', 'index']`` selecting particles in both filters.

    Parameters
    ----------
    sp : yt data container (anything indexable by ``(type, field)`` tuples)
        Source of the ``'index'`` field for 'gas' and for both filters.
    filter1, filter2 : str
        Particle-filter (particle-type) names; their ``'index'`` values identify
        the selected particles.

    Returns
    -------
    Boolean mask with the same length as ``sp['gas', 'index']``; True where that
    particle's index appears in both filters.
    """
    pid_all = sp['gas', 'index']
    pid_both = np.intersect1d(sp[filter1, 'index'], sp[filter2, 'index'])
    # np.isin is the supported replacement for np.in1d (deprecated in NumPy 2.0).
    return np.isin(pid_all, pid_both)

In [None]:
# The 'cold_gas' filter is registered in a later cell; register it here if it is
# missing so this cell also works on a fresh Restart & Run All.
if 'cold_gas' not in ds.particle_types:
    add_coldgas_filter(ds, maxT=10**4.5)

# Shells are delta_r wide in kpc (the shell filters compare spherical_r in kpc),
# so the thickness must carry kpc units — not pc, which inflates the rate 1000x.
shell_thickness = delta_r * ytu.kpc

# Cold-gas mass flow rate through each shell: dM/dt ~= sum(m * v_r) / thickness
# (positive = net outflow, since vel_sph_rad > 0 points away from the centre).
mass_flow_array = []
for sn in shell_names:
    mask = joint_mask(sp, sn, 'cold_gas')
    vr = sp['gas', 'vel_sph_rad'][mask].in_units('km/s')
    m = sp['gas', 'mass'][mask].in_units('Msun')
    mass_flow = np.sum(vr * m) / shell_thickness
    mass_flow_array.append(mass_flow.in_units('Msun/yr'))


In [None]:
# Display the cold-gas mass flow rate (Msun/yr) for each shell, in shell_names order.
mass_flow_array

In [None]:
def calculate_mass_flux(volume, radius=50.0, halfwidth=2.5, cuts=None):
    """Simplified mass flux (Msun/yr) through a spherical shell.

    Parameters
    ----------
    volume : yt data container (e.g. ``sp``)
        Must provide ('gas', 'spherical_r'), ('gas', 'mass'),
        ('gas', 'vel_sph_rad') and ('gas', 'cell_volume').
    radius : float
        Shell mid-radius in kpc.
    halfwidth : float
        Half-thickness of the shell in kpc (both edges inclusive).
    cuts : sequence of str, optional
        Extra yt ``cut_region`` expressions (e.g. metallicity / temperature /
        surface-brightness cuts) applied after the shell selection.

    Returns
    -------
    float
        Mass flux in Msun/yr; positive means net outflow (vel_sph_rad > 0).
    """
    shell = volume.cut_region([
        f"obj['gas', 'spherical_r'].in_units('kpc').value >= ({radius} - {halfwidth})",
        f"obj['gas', 'spherical_r'].in_units('kpc').value <= ({radius} + {halfwidth})",
    ])

    # Shell volume / shell area = an effective shell *thickness* (kpc), computed
    # before any extra cuts. NOTE(review): assumes ('gas', 'cell_volume') is
    # defined for these SPH gas particles — confirm.
    effective_thickness = (np.sum(shell['gas', 'cell_volume'].in_units('kpc**3'))
                           / (4 * np.pi * (radius * ytu.kpc) ** 2))

    # Apply metallicity / temperature / brightness cuts.
    if cuts:
        shell = shell.cut_region(list(cuts))

    mass_flux = (np.sum(shell["gas", "mass"] * shell["gas", "vel_sph_rad"])
                 / effective_thickness).in_units("Msun/yr")
    return float(mass_flux.value)


# Example cut sets (apply to the `sp` sphere built earlier in the notebook):
# cuts = ["obj['gas', 'metallicity'].in_units('Zsun').value  > 0.25",
#         "obj['gas', 'temperature'].in_units('K').value < 300",]
# Lcut = 5e-43
# cuts = [f"obj['gas','Halpha_brightness'].to_value('erg/s/cm**3/arcsec**2') > ({Lcut})"]
# shell = sp.cut_region(cuts)

In [None]:
# Register the H-alpha emission and cold-gas particle filters (helpers from
# yt_field_utils; add_coldgas_filter presumably creates the 'cold_gas' type used
# in the projection cell). Units of the thresholds presumably erg/s and K.
halpha_min_L = 1e-43
cold_gas_max_T = 10**4.5
add_emission_filter(ds, minL=halpha_min_L, line='Halpha')
add_coldgas_filter(ds, maxT=cold_gas_max_T)

In [None]:
particle_filter = 'cold_gas'
resolution = (512, 512)

# Density-weighted off-axis projections of the three frame-velocity components
# of the filtered gas: viewed along w_hat, with v_hat as image "north".
# Every component's image is stored in proj_images so earlier components are not
# lost; `field`, `p` and `img` remain bound to the last component afterwards.
proj_images = {}
for component in ("velocity_u", "velocity_v", "velocity_w"):
    field = (particle_filter, component)
    p = yt.OffAxisProjectionPlot(
        ds, w_hat, field,
        center=center_prop, width=width_proper,
        north_vector=v_hat,
        weight_field=(particle_filter, "density"),
        buff_size=resolution,
        data_source=sp,
    )
    img = p.frb[field]
    proj_images[component] = img