# Set Up

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy
import yt
import pandas as pd
import trident as tr
from trident.absorption_spectrum.absorption_line import tau_profile
from linetools.lists.linelist import LineList

In [None]:
from linetools.analysis import absline
import astropy.units as u

# Parameters

In [None]:
# Path to the target simulation snapshot.
simulation_fp = '/Users/zhafen/data/fire/fire2/m12i_res7100_md/output/snapdir_600/snapshot_600.0.hdf5'

# Halo-catalog options. With use_halo_file = False the galaxy center is found
# automatically (from the gas density peak) and halo_catalog_fp is ignored.
use_halo_file = False
halo_catalog_fp = None

# Zoom-in region: offset of the zoom center along the on-sky [y, z] axes and
# the total zoom width (both in kpc), plus the image resolution in pixels.
zoom_center = [ 1., 1. ]
zoom_width = 2.
n_default = 800

# Observational choices: equivalent-width detection threshold (Angstrom) and
# the default Doppler parameter.
EW_min = 0.01
b_default = 30.  # km/s

# Load Data

In [None]:
# Show up to 20 digits when numpy prints floats (useful for inspecting coordinates).
np.set_printoptions( precision = 20 )

In [None]:
# Open the snapshot with yt; `ds` is the dataset handle used by every later cell.
ds = yt.load(simulation_fp)

In [None]:
# Data container spanning the entire simulation volume.
data = ds.all_data()

In [None]:
# Register Trident ion fields on the dataset for the ions studied here
# (later cells use e.g. "Na_p0_number_density").
tr.add_ion_fields(
    ds,
    ions=[ 'O VI', 'C IV', 'Si II', 'Mg II', 'Na I' ],
)

## Finding the Galaxy Center (Halo Catalog or Gas Density Peak)

In [None]:
# A 1 kpc quantity in the dataset's unit system; multiplying plain numbers by
# it attaches kpc units.
kpc = ds.quan(1, 'kpc')

In [None]:
# Locate the galaxy center, either from a halo catalog or from the gas density
# peak. Both branches produce `center_kpc`, a unit-carrying array in kpc.
if use_halo_file:
    import h5py

    # Use a context manager so the catalog file is closed once the few values
    # needed here have been read.
    with h5py.File( halo_catalog_fp, 'r' ) as f:
        # `[...]` reads the whole dataset into memory as a numpy array.
        # The host galaxy is taken to be the most massive halo in the catalog,
        # so its index is the argmax of the halo masses.
        index = f['mass'][...].argmax()
        center_ckpc = f['position'][...][index]
        redshift = f['snapshot:redshift'][...]

    # Convert comoving to physical coordinates by dividing by (1 + z); at
    # z = 0 the two are identical.
    # NOTE(review): assumes catalog positions are comoving kpc, as the
    # variable name suggests — confirm against the catalog documentation.
    center = center_ckpc / ( 1. + redshift )

    center_kpc = center * kpc
else:
    print( 'Finding center using maximum density' )
    _, center = ds.find_max( ('gas', 'density') )
    center_kpc = center.to( 'kpc' )

## Making the Sun the Origin Using Vector Math

In [None]:
# Positions of all gas particles (PartType0), converted to kpc.
gas_coordinates = data[ 'PartType0', 'Coordinates' ].in_units( 'kpc' )

### First, filter out distant gas particles that are not part of the galaxy

In [None]:
# Euclidean distance of every gas particle from the galaxy center. Particles
# more than 150 kpc away are treated as outside the galaxy and dropped.
distance_to_center = np.sqrt( ( ( gas_coordinates - center_kpc )**2 ).sum( axis=1 ) )

within_range = distance_to_center < (150.*kpc)
galaxy_gas = gas_coordinates[within_range]


### Then do the vector math on the filtered gas coordinates

In [None]:
# Build an on-sky coordinate frame centered on a mock "sun" placed in the disk.

# One-kiloparsec quantity in dataset units.
kpc = ds.quan( 1, 'kpc' )

# Total angular momentum of the gas (PartType0) within a 10 kpc sphere around
# the galaxy center.
sp = ds.sphere( center_kpc , (10, "kpc"))
jtot = sp.quantities.angular_momentum_vector( particle_type='PartType0' ).to( 'kpc * km / s' ).value

# zhat: unit vector along the total angular momentum (the disk axis).
zhat = jtot / np.linalg.norm( jtot )

# xhat: unit vector perpendicular to zhat, i.e. lying in the disk plane, built
# by crossing a reference axis with zhat. The simulation x-axis is used unless
# it is (nearly) parallel to zhat. In that case the cross product vanishes and
# normalizing it would give NaNs, so the y-axis is used instead.
reference_axis = np.array([ 1., 0., 0. ])
if np.linalg.norm( np.cross( reference_axis, zhat ) ) < 1e-6:
    reference_axis = np.array([ 0., 1., 0. ])
xhat = np.cross( reference_axis, zhat )
xhat /= np.linalg.norm( xhat )

# Sun position: 8 kpc from the center along xhat (an arbitrary in-disk direction).
sun_position = center_kpc + (8. * xhat * kpc)

# On-sky unit vectors:
#   xskyhat points from the sun toward the galaxy center,
#   zskyhat is parallel to the total angular momentum,
#   yskyhat = zskyhat x xskyhat completes a right-handed frame
#   (it should point to the left on a sky map).
xskyhat = -xhat
zskyhat = zhat
yskyhat = np.cross( zskyhat, xskyhat )

# Shift the galaxy's gas positions into a frame centered on the sun...
positions_sun = galaxy_gas - sun_position

# ...then project them onto the sky axes. Each np.dot of the (N, 3) positions
# with a length-3 unit vector gives the N coordinates along that axis.
# Stacking and transposing yields an (N, 3) array of (x_sky, y_sky, z_sky).
positions_sky = np.array([
    np.dot( positions_sun, xskyhat ),
    np.dot( positions_sun, yskyhat ),
    np.dot( positions_sun, zskyhat ),
    ]).transpose()

# Trident Column Density Maps

In [None]:
# Quick sanity checks on the frame:
#   - the total angular momentum (magnitudes in kpc km/s, units stripped),
#   - the sun direction scaled to 1 kpc,
#   - the sphere center (in code length units).
for quantity in ( jtot, xhat * kpc, sp.center ):
    print( quantity )

### Gas density projection (no ions specified)

In [None]:
"""
the normal vector is the distance from the center to the sun along the plane of the galaxy (x axis)
the north_vector is jtot, aka zhat. (unit vector of angular momentum vector)

"""

prj_off_axis = yt.OffAxisProjectionPlot(
    ds,
    normal=xhat*kpc,
    fields=("gas", "density"),
    width=(100, "kpc"),
    center=sp.center, #works bc even tho same as center_kpc, it's in code length units
    north_vector=zhat
)
prj_off_axis.set_xlabel("y (kpc)")
prj_off_axis.set_ylabel("z (kpc)")
prj_off_axis

### Function for column densities of ions

In [None]:
def yt_column_density_plot(field, width=(100, "kpc"), center=None, **kwargs):
    """Make an off-axis projection of ``field`` as seen from the mock sun.

    The line of sight is along ``xhat`` (the in-disk direction from the galaxy
    center toward the sun). The image is oriented so that ``zhat`` (the
    total angular momentum axis) points up.

    To see a Trident ion column density, pass a field of the form
    "<element>_p<number of times ionized>_number_density". For example,
    O VI (oxygen ionized 5 times) is "O_p5_number_density".

    Parameters
    ----------
    field : str or tuple
        Field to project.
    width : tuple, optional
        Image width as (value, unit). Default (100, "kpc").
    center : array-like, optional
        Center of the projection. Defaults to ``sp.center``. It is looked up
        when the function is called rather than when it is defined, so a
        rebuilt ``sp`` is picked up without redefining this function.
    **kwargs
        Passed through to ``yt.OffAxisProjectionPlot`` (e.g. ``buff_size``).

    Returns
    -------
    The ``yt.OffAxisProjectionPlot`` object, with axes labelled in sky y/z.
    """
    if center is None:
        center = sp.center

    prj_off_axis_column_density = yt.OffAxisProjectionPlot(
        ds,
        normal=8*xhat*kpc,
        fields=field,
        width=width,
        center=center,
        north_vector=zhat,
        **kwargs
    )
    prj_off_axis_column_density.set_xlabel("y (kpc)")
    prj_off_axis_column_density.set_ylabel("z (kpc)")
    return prj_off_axis_column_density

In [None]:
# Na I column density over a 40 kpc field of view.
yt_column_density_plot( "Na_p0_number_density", width=(40, "kpc") )

In [None]:
# Na I column density over a 10 kpc field of view (the width passed below).
yt_column_density_plot( "Na_p0_number_density", width=(10, "kpc") )

### Zooming in by shifting the projection center before rendering the image

In [None]:
# Zoom center: the galaxy center shifted by zoom_center (kpc) along the on-sky
# y axis (first entry) and z axis (second entry).
new_center = (
    sp.center
    + yskyhat * zoom_center[0] * kpc
    + zskyhat * zoom_center[1] * kpc
)

In [None]:
# Na I zoom-in: a zoom_width-wide, n_default x n_default image centered on
# new_center. That point is offset from the galaxy center by zoom_center (kpc)
# along the on-sky y and z axes. The EW map and TPCF below build their pixel
# coordinates from zoom_center, so the projection must be centered on
# new_center rather than on the galaxy center.
Na_off_center = yt_column_density_plot(
    "Na_p0_number_density",
    width=(zoom_width,"kpc"),
    center=new_center,
    buff_size=(n_default,n_default)
)
Na_off_center

# Equivalent Width Images

## Curve of Growth

### Get range

In [None]:
# Na I transitions from linetools' ISM line list. (11, 1) selects Z = 11
# (sodium) at ion stage 1 (presumably neutral, matching the Na I name).
# The first transition listed is the one used for every calculation below.
linelist = LineList('ISM')
NaI_lines = linelist.all_transitions((11, 1))
transition = NaI_lines['name'][0]

In [None]:
# Column density at which the chosen Na I transition reaches the detection
# threshold EW_min (Angstrom). `.value` strips the units, leaving a plain number.
colden_min = absline.N_from_Wr_transition(
    EW_min * u.angstrom,
    transition,
).value

### Calculate

In [None]:
# Column-density grid for the curve of growth: 1001 log-spaced values from the
# detection threshold colden_min up to 1e20 cm^-2. colden_min is already a
# plain number (the cell computing it takes `.value`), so it is used directly;
# calling `.value` on it again would raise AttributeError.
coldens_cog = np.logspace( np.log10( colden_min ), 20, 1001 ) * u.cm**-2
# Doppler parameters (km/s) for the family of curves of growth. Start at
# 5 km/s: b = 0 is a zero-width line, for which the curve of growth is undefined.
bs = np.arange( 5., 50, 5., )

In [None]:
# Equivalent widths along the column-density grid, one array per Doppler
# parameter in bs (same order as bs).
EWs_cog = [
    absline.Wr_from_N_b_transition(
        coldens_cog,
        np.full( coldens_cog.shape, b ) * u.km / u.s,
        transition,
    ).value
    for b in bs
]

In [None]:
# Curve of growth for the default Doppler parameter b_default.
EWs_default = absline.Wr_from_N_b_transition(
    coldens_cog,
    np.full( coldens_cog.shape, b_default ) * u.km / u.s,
    transition,
).value

In [None]:
# Map each b value in bs onto the viridis colormap for the curve-of-growth plot.
color_norm = matplotlib.colors.Normalize( vmin = bs[0], vmax = bs[-1] )
colormap = matplotlib.colormaps['viridis']

In [None]:
# Curves of growth: EW vs Na I column density on log-log axes. There is one
# curve per b in bs, colored by b, and the b_default curve is drawn in black.
fig = plt.figure()
ax = plt.gca()

for b_value, EWs_for_b in zip( bs, EWs_cog ):
    ax.plot( coldens_cog, EWs_for_b, color = colormap( color_norm( b_value ) ) )

ax.plot( coldens_cog, EWs_default, color = 'k' )

ax.set_xlabel( r'$N_{Na\,I}$ [cm$^{-2}$]' )
ax.set_ylabel( r'EW [$\AA$]' )

ax.set_xscale( 'log' )
ax.set_yscale( 'log' )

### Create interpolation function

In [None]:
# Map Na I column density (cm^-2) to equivalent width along the b_default curve
# of growth. Valid only within [coldens_cog[0], coldens_cog[-1]]; interp1d's
# default bounds checking raises outside that range.
# Import the submodule explicitly: on older SciPy releases a bare `import scipy`
# does not make `scipy.interpolate` available as an attribute.
import scipy.interpolate
EW_interp_fn = scipy.interpolate.interp1d( coldens_cog, EWs_default )

## Plot Equivalent Widths

In [None]:
# Convert the zoomed Na I column-density image into an equivalent-width image.
coldens = Na_off_center.frb[Na_off_center.fields[0]].value
# Floor undetectable columns at the lowest tabulated column density. Use the
# first grid point rather than colden_min itself: 10**log10(colden_min) can
# round slightly above colden_min. Floored pixels would then sit just below the
# interpolation range, and interp1d would raise.
coldens = np.maximum( coldens, coldens_cog[0].value )
EWs = EW_interp_fn( coldens )

In [None]:
# Lower and upper edges (kpc) of the zoom window along the on-sky y and z axes.
side_mins, side_maxs = (
    np.array( zoom_center ) - zoom_width/2.,
    np.array( zoom_center ) + zoom_width/2.,
)
# n_default coordinates spanning the window: y runs from high to low,
# z runs from low to high.
ys = np.linspace( side_maxs[0], side_mins[0], n_default )
zs = np.linspace( side_mins[1], side_maxs[1], n_default )

In [None]:
# Mask pixels at or below the detection threshold so they render as blank.
EWs_masked = np.ma.masked_less_equal( EWs, EW_min )

In [None]:
# Na I equivalent-width map of the zoom region on a log color scale. Pixels at
# or below the detection threshold EW_min are masked out.
fig = plt.figure(figsize=(8,8), facecolor='w' )
ax = plt.gca()

pcolor = ax.pcolormesh(
    ys,
    zs,
    EWs_masked,
    norm = matplotlib.colors.LogNorm(EW_min,EWs.max()),
)

ax.set_aspect( 'equal' )

# ys and zs are built from zoom_center, i.e. offsets along the on-sky y and z
# axes, so label them as y and z, consistent with the yt projection plots.
ax.set_xlabel( 'y (kpc; galactocentric)', fontsize = 14, )
ax.set_ylabel( 'z (kpc; galactocentric)', fontsize = 14, )
ax.annotate(
    text = 'Na I\n' + r'EW ($\AA$)',
    xy = ( 1, 1 ),
    xycoords = 'axes fraction',
    xytext = ( 5, -5 ),
    textcoords = 'offset points',
    va = 'top',
    fontsize = 14,
)

fig.colorbar( pcolor, ax=ax, location='right', anchor=(0, 0.3), shrink=0.7)

## Calculate the weighted two-point correlation function (TPCF)

In [None]:
import stained_glass.stats

In [None]:
# Flatten the EW image into per-pixel (y, z) coordinates and EW weights. The
# meshgrid layout matches the one pcolormesh(ys, zs, EWs) uses: rows follow zs,
# columns follow ys.
y_mesh, z_mesh = np.meshgrid( ys, zs )
coords = np.array([ y_mesh.flatten(), z_mesh.flatten() ]).T
weights = EWs.flatten()

In [None]:
# Drop non-detections: keep only pixels whose EW exceeds the threshold EW_min.
is_detectable = np.greater( weights, EW_min )
coords = np.compress( is_detectable, coords, axis=0 )

In [None]:
# EW-weighted two-point correlation function of the detected pixels, in
# log-spaced separation bins from 0.01 up to zoom_width (kpc).
edges = np.logspace( -2, np.log10( zoom_width ) )
tpcf, edges, info = stained_glass.stats.weighted_tpcf(
    coords,
    # The weights must pair one-to-one with `coords`, which holds only the
    # detectable pixels, so apply the same EW > EW_min mask here. The full
    # EWs.flatten() has one weight per pixel and would not match in length.
    weights = EWs.flatten()[is_detectable],
    edges = edges,
    return_info = True,
)

In [None]:
# Bin centers: arithmetic midpoint of each pair of adjacent bin edges.
centers = ( edges[1:] + edges[:-1] ) / 2.

In [None]:
# Characteristic cloud size: the separation where the weighted TPCF first drops
# below 0.5. It is found by linear interpolation between the two bins that
# bracket that first crossing. This matches inverting the curve with interp1d
# when the TPCF decreases monotonically. Unlike interp1d, it stays well defined
# for a noisy or non-monotonic TPCF, and gives NaN instead of raising when no
# crossing is sampled.
below_half = np.flatnonzero( np.asarray( tpcf ) < 0.5 )
if below_half.size == 0 or below_half[0] == 0:
    # Either the TPCF never drops below 0.5 in the sampled range, or it is
    # already below 0.5 in the first bin; no crossing can be located.
    l_cloud = np.nan
else:
    j = below_half[0]
    # tpcf[j] < 0.5 while tpcf[j-1] is not, so the sample points passed to
    # np.interp are in increasing order (unless tpcf[j-1] is NaN).
    l_cloud = np.interp( 0.5, [ tpcf[j], tpcf[j-1] ], [ centers[j], centers[j-1] ] )

In [None]:
# Weighted TPCF vs separation, with the characteristic cloud size marked.
fig = plt.figure()
ax = plt.gca()

ax.plot( centers, tpcf, color = 'k', linewidth = 1.5 )

# Dashed vertical line at the characteristic cloud size l_cloud.
ax.axvline( l_cloud, color = 'k', linewidth = 1, linestyle = '--' )
# Faint zero-correlation reference line, drawn behind everything else.
ax.axhline( 0, color = '0.9', zorder = -10 )

ax.set_xscale( 'log' )
ax.set_ylim( -1, 1 )
ax.set_xlim( centers[0], centers[-1] )