# Set Up

In [1]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy
import yt
import tqdm
import pandas as pd
import trident as tr
from trident.absorption_spectrum.absorption_line import tau_profile
from linetools.lists.linelist import LineList

In [2]:
from linetools.analysis import absline
import astropy.units as u

# Parameters

In [None]:
# --- Input files ---
# Snapshot to analyze
simulation_fp = '/Users/zhafen/data/fire/fire2/metal_diffusion/m12i_res57000/output/snapshot_600.hdf5'

# Halo catalog for the snapshot. Only read when use_halo_file is True;
# otherwise the center is found automatically.
halo_catalog_fp = '/Users/zhafen/data/fire/fire2/metal_diffusion/m12i_res57000/halo/rockstar_dm/catalog_hdf5/halo_600.hdf5'
use_halo_file = True

# --- Projection ---
default_resolution = 100 # Pixels per image side
proj_halfwidth_kpc = 10.0 # Half the image width, in kpc
proj_xrange_halfwidth = np.full( 2, 0.5 ) # Slice location, in units of the halfwidth

# --- Zoom inset: [min, max] along each image axis, in kpc ---
zoom_yrange_kpc = np.array([ 0.0, 1.0 ])
zoom_zrange_kpc = np.array([ 0.0, 1.0 ])

# --- Observational choices ---
EW_min = 1e-2 # Minimum equivalent width (multiplied by u.angstrom where it is used)
b_default = 30.0 # Default Doppler parameter, in km/s

# Load Data

In [None]:
# Print numpy arrays with up to 20 digits of precision
np.set_printoptions(precision=20)

In [None]:
# Open the snapshot at simulation_fp as a yt dataset
ds = yt.load( simulation_fp )

In [None]:
# Data container spanning the whole dataset (fields are read from it below)
data = ds.all_data()

In [None]:
# Have trident add ion fields for the listed species to the dataset
tr.add_ion_fields( ds, ions=['O VI','C IV', 'Si II', 'Mg II','Na I'],)

## Finding the Center of the Galaxy Using Halo Data

In [None]:
# 1 kpc as a dataset quantity; multiply plain numbers/arrays by it to attach units
kpc = ds.quan( 1, 'kpc' )

In [None]:
# Locate the galaxy center, either from the halo catalog or from the gas density peak.
# Result: `center_kpc`, the center position with kpc units attached.
if use_halo_file:
    import h5py

    # Read everything needed inside a `with` block so the catalog file is closed
    # afterwards. An h5py.File behaves like a dictionary (f.keys() lists its contents).
    with h5py.File( halo_catalog_fp, 'r' ) as f:

        # `[...]` reads the full dataset into memory; argmax() then gives the index
        # of the most massive halo, which is taken to be the host galaxy.
        index = f['mass'][...].argmax()

        # Position of that halo as stored in the catalog
        # (presumably comoving kpc, going by the `_ckpc` name -- TODO confirm)
        center_ckpc = f['position'][...][index]

        redshift = f['snapshot:redshift'][...]

    # Scale by 1 / (1 + z). At redshift 0 this leaves the position unchanged.
    center = center_ckpc / ( 1. + redshift )

    center_kpc = center * kpc
else:
    print( 'Finding center using maximum density' )
    _, center = ds.find_max( ('gas', 'density') )
    center_kpc = center.to( 'kpc' )

## Making the Sun the Origin Using Vector Math

In [None]:
# Coordinates of all PartType0 (gas) particles, converted to kpc
gas_coordinates = (data[('PartType0', 'Coordinates')]).in_units("kpc")

### First, filter out distant gas particles that are not in the galaxy

In [None]:
# Euclidean distance of every gas particle from the galaxy center,
# built from the per-axis offsets.
dx = gas_coordinates[:,0] - center_kpc[0]
dy = gas_coordinates[:,1] - center_kpc[1]
dz = gas_coordinates[:,2] - center_kpc[2]
distance_to_center = np.sqrt( dx**2 + dy**2 + dz**2 )

# Keep only the gas closer than 150 kpc to the center
max_radius = 150.*kpc
within_range = distance_to_center < max_radius
galaxy_gas = gas_coordinates[within_range]



### Then do the vector math, applying it to the filtered list of gas coordinates

In [None]:
# Unit quantity used to attach kpc to plain arrays
kpc = ds.quan( 1, 'kpc' )

# Total angular momentum of the PartType0 particles inside a sphere of
# radius 10 kpc around the galaxy center.
sp = ds.sphere( center_kpc, ( 10, "kpc" ) )
jtot = sp.quantities.angular_momentum_vector( particle_type='PartType0' )
jtot = jtot.to( 'kpc * km / s' ).value

# zhat: unit vector along the total angular momentum
zhat = jtot / np.linalg.norm( jtot )

# xhat: unit vector perpendicular to zhat (i.e. lying in the disk plane),
# obtained by crossing the simulation x axis with zhat and normalizing.
# NOTE(review): the norm is zero if zhat is parallel to [1, 0, 0].
x_axis = [ 1, 0, 0 ]
xhat = np.cross( x_axis, zhat )
xhat = xhat / np.linalg.norm( xhat )

# Place the sun 8 kpc from the center along xhat (the angle within the disk is arbitrary)
sun_position = center_kpc + (8. * xhat * kpc)

# On-sky unit vectors:
#   xskyhat points from the sun toward the galaxy center,
#   zskyhat is parallel to the total angular momentum,
#   yskyhat completes the frame (should point to the left on a sky map).
zskyhat = zhat
xskyhat = -xhat
yskyhat = np.cross( zskyhat, xskyhat )

# Translate the gas positions into a frame with the sun at the origin
positions_sun = galaxy_gas - sun_position

# Rotate into the sky frame: project every position onto each sky axis.
# Each np.dot gives one component for all particles; stacking and transposing
# yields one (x, y, z) row per particle.
positions_sky = np.array([
    np.dot( positions_sun, sky_axis )
    for sky_axis in ( xskyhat, yskyhat, zskyhat )
]).T

# Generate Images

## General Function

In [None]:
# [min, max] extent of the full projection along each image axis, in kpc:
# symmetric about zero with half-width proj_halfwidth_kpc.
yrange_full_proj_kpc = np.array([ -1, 1 ]) * proj_halfwidth_kpc
zrange_full_proj_kpc = np.array([ -1, 1 ]) * proj_halfwidth_kpc

In [None]:
def get_image( field, yrange_kpc=yrange_full_proj_kpc, zrange_kpc=zrange_full_proj_kpc ):
    '''Project a field with yt.off_axis_projection over a rectangular window.

    Args:
        field: Field passed to yt as the projected `item`.
        yrange_kpc: [min, max] of the window along yskyhat, in kpc.
        zrange_kpc: [min, max] of the window along zskyhat, in kpc.

    Returns:
        The image returned by yt.off_axis_projection, with
        default_resolution pixels per side.
    '''

    # Attach units to the requested window and to the line-of-sight slice
    y_lim = yrange_kpc * kpc
    z_lim = zrange_kpc * kpc
    x_lim = proj_xrange_halfwidth * ( proj_halfwidth_kpc * kpc )

    # Extent of the projection along each axis
    # NOTE(review): with proj_xrange_halfwidth = [0.5, 0.5] the depth is zero -- confirm intended.
    y_extent = y_lim[1] - y_lim[0]
    z_extent = z_lim[1] - z_lim[0]
    depth = x_lim[1] - x_lim[0]

    # Midpoint of the window along each sky axis
    x_mid = 0.5 * ( x_lim[0] + x_lim[1] )
    y_mid = 0.5 * ( y_lim[0] + y_lim[1] )
    z_mid = 0.5 * ( z_lim[0] + z_lim[1] )

    # Shift the projection center from the sphere center to the window midpoint
    window_center = sp.center + xskyhat * x_mid + yskyhat * y_mid + zskyhat * z_mid

    # NOTE(review): np.array may drop the units of the three widths -- confirm yt
    # interprets them as intended.
    return yt.off_axis_projection(
        ds,
        center = window_center,
        normal_vector = xhat,
        north_vector = zhat,
        width = np.array([ y_extent, z_extent, depth ]),
        item = field,
        resolution = (default_resolution, default_resolution),
    )

## Loop to get images

In [None]:
# Fields to project, and the units each projected image is converted to
# (the two lists are index-aligned).
fields = [ ( 'gas', 'density' ), 'Na_p0_number_density' ]
field_units = [ 'Msun/kpc**2', 'cm**-2' ]

In [None]:
# One full-window image per field, converted to that field's units
imgs = [
    get_image( field, ).to( unit )
    for field, unit in zip( tqdm.tqdm( fields ), field_units )
]

In [None]:
# One zoom-window image per field, converted to that field's units
zoom_imgs = [
    get_image( field, zoom_yrange_kpc, zoom_zrange_kpc ).to( unit )
    for field, unit in zip( tqdm.tqdm( fields ), field_units )
]

# Convert to Equivalent Width

## Curve of Growth

### Get range

In [None]:
# Load the linetools 'ISM' line list and look up the transitions for (11, 1)
# (presumably Z = 11, ion stage 1, i.e. Na I, going by the variable name).
linelist = LineList( 'ISM' )
NaI_lines = linelist.all_transitions( (11,1) )
# Use the first listed transition for all curve-of-growth calculations below
transition = NaI_lines['name'][0]

In [None]:
# Column density that corresponds to the minimum equivalent width EW_min for the
# chosen transition; `.value` strips the astropy unit (used below as cm**-2).
colden_min = absline.N_from_Wr_transition( EW_min * u.angstrom, transition ).value

### Calculate

In [None]:
# Column densities for the curve of growth: 1001 log-spaced points from colden_min to 1e20 cm**-2
coldens_cog = np.logspace( np.log10( colden_min ), 20, 1001 ) * u.cm**-2
# Doppler parameters to evaluate: 0, 5, ..., 45 (used in km/s below)
# NOTE(review): this includes b = 0 -- confirm linetools handles a zero Doppler parameter.
bs = np.arange( 0., 50, 5., )

In [None]:
# Curve of growth (EW at every column density in coldens_cog) for each Doppler
# parameter in bs; `.value` strips the astropy unit.
EWs_cog = [
    absline.Wr_from_N_b_transition(
        coldens_cog,
        np.full( coldens_cog.shape, b_value ) * u.km / u.s,
        transition,
    ).value
    for b_value in bs
]

In [None]:
# Curve of growth at the default Doppler parameter b_default (km/s); this is the
# curve used for the column density -> EW conversion.
EWs_default = absline.Wr_from_N_b_transition( coldens_cog, np.full( coldens_cog.shape, b_default ) * u.km / u.s, transition ).value

In [None]:
# Map Doppler parameters onto the viridis colormap, spanning the full range of bs
colormap = matplotlib.colormaps[ 'viridis' ]
color_norm = matplotlib.colors.Normalize( vmin=bs[0], vmax=bs[-1] )

In [None]:
# Curve of growth: EW vs. column density, one line per Doppler parameter
fig = plt.figure()
ax = fig.gca()

# Colored curves, one per b value
for b_value, EWs_b in zip( bs, EWs_cog ):
    ax.plot( coldens_cog, EWs_b, color=colormap( color_norm( b_value ) ) )

# Black curve: the default Doppler parameter
ax.plot( coldens_cog, EWs_default, color='k' )

ax.set_xscale( 'log' )
ax.set_yscale( 'log' )

ax.set_xlabel( r'$N_{Na\,I}$ [cm$^{-2}$]' )
ax.set_ylabel( r'EW [$\AA$]' )

### Create interpolation function

In [None]:
# Column density -> EW lookup along the b_default curve of growth.
# interp1d defaults: linear interpolation, and an error for inputs outside
# the tabulated column-density range.
EW_interp_fn = scipy.interpolate.interp1d( coldens_cog, EWs_default )

## Conversion Itself

In [None]:
# Full projection: column density image -> equivalent width image
coldens = imgs[1]

# Raise everything below the tabulated minimum up to it, so the interpolation
# stays in range. This modifies imgs[1] in place.
below_floor = coldens < colden_min
coldens[below_floor] = colden_min

EWs = EW_interp_fn( coldens )

# Mask the pixels that are not above the minimum EW
EWs_masked = np.ma.masked_where( EWs <= EW_min, EWs )

In [None]:
# Zoom: column density image -> equivalent width image
coldens_zoom = zoom_imgs[1]

# Raise everything below the tabulated minimum up to it, so the interpolation
# stays in range. This modifies zoom_imgs[1] in place.
coldens_zoom[coldens_zoom<colden_min] = colden_min

# Interpolate the ZOOM column densities (not the full-projection ones)
EWs_zoom = EW_interp_fn( coldens_zoom )

# Mask the pixels that are not above the minimum EW
EWs_zoom_masked = np.ma.masked_where( EWs_zoom <= EW_min, EWs_zoom )

# Plot

## Setup

In [None]:
# Pixel coordinates (kpc) along each axis of the full projection
ys = np.linspace( *yrange_full_proj_kpc, default_resolution )
zs = np.linspace( *zrange_full_proj_kpc, default_resolution )

In [None]:
# Pixel coordinates (kpc) along each axis of the zoom inset
ys_zoom = np.linspace( *zoom_yrange_kpc, default_resolution )
zs_zoom = np.linspace( *zoom_zrange_kpc, default_resolution )

## Projected and zoom, individually

In [None]:
# Full projection: equivalent width map on a log color scale
fig = plt.figure()
ax = fig.gca()

# EWs is transposed so that its first axis runs along ys (horizontal)
log_norm = matplotlib.colors.LogNorm()
ax.pcolormesh( ys, zs, EWs.transpose(), norm=log_norm )

ax.set_aspect( 'equal' )

In [None]:
# Zoom inset: equivalent width map on a log color scale
fig = plt.figure()
ax = fig.gca()

# EWs_zoom is transposed so that its first axis runs along ys_zoom (horizontal)
log_norm = matplotlib.colors.LogNorm()
ax.pcolormesh( ys_zoom, zs_zoom, EWs_zoom.transpose(), norm=log_norm )

ax.set_aspect( 'equal' )

# TPCF

In [None]:
# Deliberately stop execution here: the TPCF section below is unfinished.
# NOTE(review): asserts are skipped when Python runs with -O.
assert False, 'Not fully implemented yet..'

## Calculate weighted TPCF

In [None]:
import stained_glass.stats

In [None]:
# Build one (y, z) coordinate pair per pixel, in the same order as EWs.flatten().
# EWs is indexed [y, z] (the plots pass EWs.transpose() to pcolormesh( ys, zs, ... )),
# so the grids must use matrix ('ij') indexing; the default 'xy' indexing would
# pair each weight with the coordinates of the transposed pixel.
y_mesh, z_mesh = np.meshgrid( ys, zs, indexing='ij' )
coords = np.array([ y_mesh.flatten(), z_mesh.flatten(), ]).transpose()
weights = EWs.flatten()

In [None]:
# Toss out non-detections
is_detectable = weights > EW_min
coords = coords[is_detectable]
weights = weights[is_detectable]

In [None]:
# `zoom_width` was referenced here without ever being defined (NameError).
# Derive it from the y-extent of the zoom inset, in kpc.
# NOTE(review): this is an assumption about the intended maximum separation -- confirm.
zoom_width = zoom_yrange_kpc[1] - zoom_yrange_kpc[0]

# Separation bins: log-spaced from 0.01 up to zoom_width (np.logspace default of 50 points)
edges = np.logspace( -2, np.log10( zoom_width ) )

# EW-weighted two-point correlation function of the detected pixels
tpcf, edges, info = stained_glass.stats.weighted_tpcf(
    coords,
    weights = weights,
    edges = edges,
    return_info = True,
)

In [None]:
# Bin centers: arithmetic midpoint of each pair of adjacent edges
centers = 0.5 * ( edges[:-1] + edges[1:] )

In [None]:
# Characteristic cloud size
interp_fn = scipy.interpolate.interp1d( tpcf, centers )
l_cloud = interp_fn( 0.5 )

In [None]:
# TPCF as a function of separation, with the characteristic cloud size marked
fig = plt.figure()
ax = fig.gca()

# The correlation function itself
ax.plot( centers, tpcf, color='k', linewidth=1.5 )

# Dashed vertical line at the characteristic cloud size
ax.axvline( l_cloud, color='k', linewidth=1, linestyle='--' )

# Faint zero line, drawn behind everything else
ax.axhline( 0, color='0.9', zorder=-10 )

ax.set_xscale( 'log' )
ax.set_xlim( centers[0], centers[-1] )
ax.set_ylim( -1, 1 )