In [None]:
from __future__ import print_function, division

%matplotlib inline
from matplotlib import pyplot as plt
import seaborn as sns; sns.set(context="poster")
import ipywidgets
import yt
import glob
import os
import warnings
import h5py

import numpy as np
import pandas as pd

from astropy import constants as const
from astropy import units as u

# Plain-float conversion factors used throughout the notebook.
M_solar = const.M_sun.cgs.value  # solar mass in cgs (grams)
m_proton = const.m_p.cgs.value  # proton mass in cgs (grams)
pc = u.pc.to(u.cm)  # centimetres per parsec
yr = u.yr.to(u.s)  # seconds per year
Myr = 1e6*yr  # seconds per megayear
gamma = 5/3  # used by the derived "pressure" field; relies on true division (see the __future__ import)

@yt.derived_field(name="pressure", units="g  / s**2 / cm")
def _pressure(field, data):
    """Derived yt field: pressure as (gamma - 1) * thermal_energy * density.

    `field` is unused; `data` is indexed by field name. The declared units
    are those of a pressure, which presumably means "thermal_energy" is an
    energy per unit mass -- TODO confirm against the yt frontend in use.
    """
    # Keep the original left-to-right multiplication order.
    scaled_energy = (gamma-1) * data["thermal_energy"]
    return scaled_energy * data["density"]

In [None]:
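# # # location of the *SNe.dat input file
# NOTE: every option in this cell is commented out. Uncomment exactly one
# SN_directory together with the matching snapshot_dir before running the
# cells below; otherwise they fail with a NameError.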
# SN_directory = "../ICs/cluster/"
# SN_directory = "../ICs/cluster_cooling/"
# SN_directory = "../ICs/cluster_cooling_lowres/"
# SN_directory = "../ICs/cluster_cooling_150/"
# SN_directory = "../ICs/cluster_cooling_200/"
# SN_directory = "../ICs/cluster_cooling_250/"
# SN_directory = "../ICs/cluster_cooling_300/"
# SN_directory = "../ICs/single/"
# SN_directory = "../ICs/single_cooling/"
# SN_directory = "../ICs/double/"
# SN_directory = "../ICs/double_cooling/"


# # # location of the *snapshot.hdf5 output files
# snapshot_dir = "../output/cluster/"
# snapshot_dir = "../output/cluster_cooling/"
# snapshot_dir = "../output/cluster_cooling_lowres/"
# snapshot_dir = "../output/cluster_cooling_150/"
# snapshot_dir = "../output/cluster_cooling_200/"
# snapshot_dir = "../output/cluster_cooling_250/"
# snapshot_dir = "../output/cluster_cooling_300/"
# snapshot_dir = "../output/single/"
# snapshot_dir = "../output/single_cooling/"
# snapshot_dir = "../output/double/"
# snapshot_dir = "../output/double_cooling/"


# Warning
This file is very much a work-in-progress.

To do:
 - implement a shock-finder
 - only get energy within the remnant

# Overview

In [None]:
# Locate the supernova table; exactly one "*SNe.dat" file must be present.
possible_SN_files = glob.glob(os.path.join(SN_directory, "*SNe.dat"))

if not possible_SN_files:
    raise FileNotFoundError("No SN data files found in {}".format(SN_directory))
if len(possible_SN_files) > 1:
    raise RuntimeError("Too many SN data files found in {}".format(SN_directory))

SN_file = possible_SN_files[0]
SN_data = np.loadtxt(SN_file, ndmin=2)

# Order the events by column 0 (used as the SN time).
sorted_indices = np.argsort(SN_data[:,0])

SN_times         = SN_data[sorted_indices, 0]
SN_ejecta_masses = SN_data[sorted_indices, 2]

# Measure times from the first SN.
SN_times = SN_times - SN_times[0]
# SN_times[0] = 3e10

# Rescale: times by seconds-per-Myr, masses by the solar mass in grams.
SN_times = SN_times / u.Myr.to(u.s)
SN_ejecta_masses = SN_ejecta_masses / M_solar

In [None]:
# SN times relative to the first SN, after division by seconds-per-Myr.
print(SN_times)

In [None]:
# Ejecta masses after division by M_solar (solar masses if the file column is in grams -- TODO confirm).
print(SN_ejecta_masses)

In [None]:
# unit_base = {
#     "length" : (1.0, "pc"),
#     "time"   : (1.0, "Myr"),
#     "mass"   : (1.0, "Msun")
# }

# Unit system handed to yt.load: lengths in pc, velocities in pc/Myr,
# masses in solar masses (each expressed in cgs).
unit_base = dict(
    UnitLength_in_cm=pc,
    UnitVelocity_in_cm_per_s=pc / Myr,
    UnitMass_in_g=M_solar,
)

In [None]:
snapshot_filename_format = "snapshot_???.hdf5"
snapshot_pattern = os.path.join(snapshot_dir, snapshot_filename_format)

# Sorted explicit file list, used later for direct HDF5 reads.
snapshot_filenames = sorted(glob.glob(snapshot_pattern))

n_files_ready = len(snapshot_filenames)
if not n_files_ready:
    raise FileNotFoundError("No snapshots found in {}".format(snapshot_dir))

# yt expands the same glob pattern into a dataset series.
ts = yt.load(snapshot_pattern, unit_base=unit_base)

# Snapshot times, divided by seconds-per-Myr.
times_snapshots = np.array(
    [ts[i].current_time.convert_to_cgs() for i in range(len(ts))]
) / u.Myr.to(u.s)

print("Loaded {} snapshots".format(len(ts)))

# Mean particle density of the first snapshot, used as a reference level.
ds = ts[0]
rho_0 = ds.all_data()["all","density"].mean()


In [None]:
# Snapshot times (cgs seconds divided by seconds-per-Myr), one per dataset in `ts`.
times_snapshots

In [None]:
# Report the unit system yt attached to the first snapshot.
for _label, _unit in (
    ("Length unit: ",   ds.length_unit),
    ("Time unit: ",     ds.time_unit),
    ("Mass unit: ",     ds.mass_unit),
    ("Velocity unit: ", ds.velocity_unit),
):
    print(_label, _unit)

# Plot Global Quantities

In [None]:
# First snapshot; its *_unit attributes are used below to attach units to energy.txt.
ds = ts[0]
# NOTE(review): `dd` is not referenced by the later top-level cells visible here -- confirm it is still needed.
dd = ds.all_data()

In [None]:
# Read the energy.txt table written alongside the snapshots.
energies = np.loadtxt(os.path.join(snapshot_dir, "energy.txt"), ndmin=2)

# Column 0: time in code units -> cgs seconds -> divided by seconds-per-Myr.
times_statistics    = (energies[:,0] * ds.time_unit).convert_to_cgs().value / Myr

# Columns 1-3: energies in code units (mass * velocity**2) -> cgs values.
thermal_energies    = (energies[:,1] * ds.mass_unit * (ds.velocity_unit)**2).convert_to_cgs().value
potential_energies  = (energies[:,2] * ds.mass_unit * (ds.velocity_unit)**2).convert_to_cgs().value
kinetic_energies    = (energies[:,3] * ds.mass_unit * (ds.velocity_unit)**2).convert_to_cgs().value


# Sixth column from the end, left in code units.
total_mass    = energies[:,-6]


# The potential term is not included in the total.
total_energies = thermal_energies + kinetic_energies

In [None]:
# Kinetic energy per energy.txt row, as plain cgs values.
kinetic_energies

In [None]:
# Kinetic energy history; rug ticks mark the SN times.
fig, ax = plt.subplots()
sns.rugplot(SN_times, color="k", linewidth=3, ax=ax)
ax.plot(times_statistics, kinetic_energies)
ax.set_xlabel(r"$t$ $[\mathrm{Myr}]$")
ax.set_ylabel(r"$E_\mathrm{kin}$ $[\mathrm{ergs}]$")

In [None]:
# Thermal energy per energy.txt row, as plain cgs values.
thermal_energies

In [None]:
# Thermal energy history; rug ticks mark the SN times.
fig, ax = plt.subplots()
sns.rugplot(SN_times, color="k", linewidth=3, ax=ax)
ax.plot(times_statistics, thermal_energies)
ax.set_xlabel(r"$t$ $[\mathrm{Myr}]$")
ax.set_ylabel(r"$E_\mathrm{int}$ $[\mathrm{ergs}]$")

In [None]:
# Thermal + kinetic energy history; rug ticks mark the SN times.
fig, ax = plt.subplots()
sns.rugplot(SN_times, color="k", linewidth=3, ax=ax)
ax.plot(times_statistics, total_energies)
ax.set_xlabel(r"$t$ $[\mathrm{Myr}]$")
ax.set_ylabel(r"$E_\mathrm{total}$ $[\mathrm{ergs}]$")

To do: remove the contribution from cooling outside the remnant

In [None]:
# Change in thermal + kinetic energy relative to the first energy.txt row.
delta_E_total = total_energies - total_energies[0]

fig, ax = plt.subplots()
sns.rugplot(SN_times, color="k", linewidth=3, ax=ax)
ax.plot(times_statistics, delta_E_total)
ax.set_xlabel(r"$t$ $[\mathrm{Myr}]$")
ax.set_ylabel(r"$\Delta E_\mathrm{total}$ $[\mathrm{ergs}]$")

# Same quantity in units of 1e51 erg.
print(delta_E_total / 1e51)

In [None]:
if "double" in snapshot_dir:

    # Index of the last energy.txt row strictly before the second SN.
    final_checkpoint_before_second_SN = np.searchsorted(times_statistics, SN_times[1])-1

    # NOTE(review): the slice stops *before* that index, so the final
    # pre-SN checkpoint itself is left out -- confirm this is intended.
    before_second_SN = slice(0, final_checkpoint_before_second_SN)
    delta_E_before_second_SN = total_energies[before_second_SN] - total_energies[0]

    fig, ax = plt.subplots()
    sns.rugplot(SN_times[:1], color="k", linewidth=3, ax=ax)
    ax.plot(times_statistics[before_second_SN], delta_E_before_second_SN)
    ax.set_xlabel(r"$t$ $[\mathrm{Myr}]$")
    ax.set_ylabel(r"$\Delta E_\mathrm{total}$ $[\mathrm{ergs}]$")

    print(delta_E_before_second_SN / 1e51)

## Mass Plots

WARNING: GIZMO writes "energy.txt" with `%g` formatting (6 significant digits by default), so the small change in total mass due to ejecta may be lost to truncation.

In [None]:
def total_mass_of_snapshot(snapshot_filename):
    """Return the summed PartType0 particle mass of one snapshot file.

    Parameters
    ----------
    snapshot_filename : str
        Path to an HDF5 snapshot containing a "PartType0/Masses" dataset.

    Returns
    -------
    float
        Sum of the "Masses" dataset, accumulated in double precision and
        left in the units stored in the file.
    """
    # The context manager closes the file even if the read or sum raises.
    with h5py.File(snapshot_filename, mode="r") as f:
        return np.sum(f["PartType0"]["Masses"][...], dtype=float)

In [None]:
# Stored as an ndarray (like `radial_momentum`) so that array arithmetic such
# as `_masses - _masses[0]` does not depend on NumPy's reflected scalar operators.
_masses = np.array([total_mass_of_snapshot(snapshot_filename)
                    for snapshot_filename in snapshot_filenames])
    

In [None]:
# ## Waaaay slower than just reading the hdf5 file

# _masses = np.empty(len(ts))

# for i, ds in enumerate(ts):
#     dd = ds.all_data()

#     _masses[i]  = dd["all", "Masses"].sum() / M_solar

In [None]:
# Mass added since the first snapshot versus the cumulative intended ejecta mass.
fig, ax = plt.subplots()
sns.rugplot(SN_times, color="k", linewidth=3, ax=ax)

measured_delta_mass = _masses - _masses[0]
ax.plot(times_snapshots, measured_delta_mass,
        label="snapshots", linestyle="solid", drawstyle="steps-post")

intended_delta_mass = SN_ejecta_masses.cumsum()
ax.plot(SN_times, intended_delta_mass,
        label="intended", linestyle="dashed", drawstyle="steps-post")
# plt.plot(times_statistics, total_mass - total_mass[0], 
#          label="energy.txt", drawstyle="steps-post", linestyle="dotted")
ax.set_xlabel(r"$t$ $[\mathrm{Myr}]$")
ax.set_ylabel(r"$\Delta M$ $[M_\odot]$")
ax.legend(loc="best")

## Momentum Plots

In [None]:
def total_radial_momentum_from_snapshot_file(snapshot_filename):
    """Return the total radial momentum of the PartType0 particles in a snapshot.

    The radial direction is measured from the box centre (BoxSize / 2 along
    each axis). The result is scaled by M_solar * pc / Myr, i.e. it assumes
    the file stores masses in solar masses, lengths in pc and velocities in
    pc / Myr (the `unit_base` used elsewhere in this notebook).

    Parameters
    ----------
    snapshot_filename : str
        Path to an HDF5 snapshot with "PartType0" Coordinates, Velocities and
        Masses datasets and a "Header" group carrying a "BoxSize" attribute.

    Returns
    -------
    float
        Sum over particles of mass * (velocity . r_hat), in g cm / s under the
        unit assumption above.
    """
    # Read everything up front; the context manager closes the file even on error.
    with h5py.File(snapshot_filename, mode="r") as f:
        gas = f["PartType0"]
        coordinates = gas["Coordinates"][...]
        velocities = gas["Velocities"][...]
        masses = gas["Masses"][...]
        box_center = f["Header"].attrs["BoxSize"]/2

    # Column-vector shape so per-particle scalars broadcast across x, y, z.
    new_shape = masses.shape + (1,)

    offsets = coordinates - box_center
    radii = (np.sum(offsets**2, axis=1, dtype=float).reshape(new_shape))**.5

    # A particle exactly at the centre has no radial direction; give it a zero
    # unit vector instead of 0/0 = NaN, which would poison the whole sum.
    r_hat = np.divide(offsets, radii,
                      out=np.zeros(offsets.shape, dtype=float),
                      where=radii > 0)

    mom = np.sum(r_hat * velocities * masses.reshape(new_shape), dtype=float)

    return mom * M_solar * pc / Myr

In [None]:
# One total radial momentum per snapshot file, gathered straight into an array.
radial_momentum = np.array(
    [total_radial_momentum_from_snapshot_file(snapshot_filename)
     for snapshot_filename in snapshot_filenames]
)

In [None]:
# ## Waaaay slower than just reading the hdf5 file

# radial_momentum = np.empty(len(ts))

# for i, ds in enumerate(ts):
#     dd = ds.all_data()

#     v_r = ((dd["all", "Coordinates"]-ds.domain_center) \
#             / dd["all", "particle_radius"].reshape(dd["all", "Masses"].size, 1) \
#             * dd["all","Velocities"])\
#             .sum(axis=1)

#     radial_momentum[i]  = (v_r * dd["all", "Masses"]).sum()
    

In [None]:
# Total radial momentum, one entry per snapshot file.
radial_momentum

In [None]:
# Snapshot times used as the x-axis of the momentum plots below.
times_snapshots

In [None]:
# Radial momentum history in units of 100 M_sun km/s (1e5 cm/s per km/s).
momentum_unit = 100 * M_solar * 1e5

fig, ax = plt.subplots()
sns.rugplot(SN_times, color="k", linewidth=3, ax=ax)
ax.plot(times_snapshots, radial_momentum / momentum_unit)

ax.set_xlabel(r"$t$ $[\mathrm{Myr}]$")
ax.set_ylabel(r"$p$ $[100$ $M_\odot$ $\mathrm{km}$ $\mathrm{s}^{-1}]$")
ax.set_ylim(ymin=0)

In [None]:
# Radial momentum history per supernova, in units of 100 M_sun km/s.
momentum_unit_per_SN = 100 * M_solar * 1e5 * SN_times.size

fig, ax = plt.subplots()
sns.rugplot(SN_times, color="k", linewidth=3, ax=ax)
ax.plot(times_snapshots, radial_momentum / momentum_unit_per_SN)
ax.set_xlabel(r"$t$ $[\mathrm{Myr}]$")
ax.set_ylabel(r"$p$ $[100$ $M_\odot$ $N_\mathrm{SNe}$ $\mathrm{km}$ $\mathrm{s}^{-1}]$")
ax.set_ylim(ymin=0)

# Plot Snapshot Views

In [None]:
def show_projected_density(i):
    """Show the gas density of snapshot `i` projected along x.

    The plot carries a timestamp and a text box with the number of SNe whose
    time is strictly earlier than the snapshot time.
    """
    ds = ts[i]

    projection = yt.ProjectionPlot(ds, "x", ("gas","density"))
    projection.set_cmap(field="density", cmap="viridis")
    projection.annotate_timestamp(corner="upper_left", draw_inset_box=True)

    # Snapshot time on the same scale as SN_times.
    t_Myr = ds.current_time.convert_to_cgs().value / u.Myr.to(u.s)
    N_SNe_so_far = np.sum(t_Myr > SN_times)

    label_box = {"facecolor":"darkslategray",
                 "alpha":0.9}
    projection.annotate_text((.8,.94),
                             "N_SNe: {}".format(N_SNe_so_far),
                             coord_system="axis",
                             inset_box_args=label_box)
    projection.show()
    
# Slider over snapshot indices, starting at the last snapshot.
last_snapshot_index = len(ts)-1
snapshot_slider = ipywidgets.IntSlider(min=0,
                                       max=last_snapshot_index,
                                       value=last_snapshot_index)
ipywidgets.interact(show_projected_density, i=snapshot_slider)

In [None]:
# Maps each sliceable field name to the yt field-type prefix it is paired with
# when show_sliced_field builds the (type, name) field tuple.
field_type = {
    "density": "gas",
    "temperature": "gas",
    "pressure": "gas",
    "velocity_magnitude": "gas",
    "radius": "index",
    "metallicity": "gas"
}

def show_sliced_field(i, field):
    """Show a z-normal slice of `field` for snapshot `i`.

    `field` must be a key of `field_type`. The plot carries a timestamp and a
    text box with the number of SNe whose time is strictly earlier than the
    snapshot time.
    """
    ds = ts[i]

    qualified_field = (field_type[field], field)
    slice_plot = yt.SlicePlot(ds, "z", qualified_field)
    slice_plot.set_cmap(field=field, cmap="viridis")
    slice_plot.annotate_timestamp(corner="upper_left", draw_inset_box=True)

    # Snapshot time on the same scale as SN_times.
    t_Myr = ds.current_time.convert_to_cgs().value / u.Myr.to(u.s)
    N_SNe_so_far = np.sum(t_Myr > SN_times)

    label_box = {"facecolor":"darkslategray",
                 "alpha":0.9}
    slice_plot.annotate_text((.8,.94),
                             "N_SNe: {}".format(N_SNe_so_far),
                             coord_system="axis",
                             inset_box_args=label_box)
    slice_plot.show()
    
# Slider over snapshot indices (starting at the last one) plus a field selector.
last_snapshot_index = len(ts)-1
snapshot_slider = ipywidgets.IntSlider(min=0,
                                       max=last_snapshot_index,
                                       value=last_snapshot_index)
field_dropdown = ipywidgets.Dropdown(options=list(field_type.keys()),
                                     value="density")
ipywidgets.interact(show_sliced_field, i=snapshot_slider, field=field_dropdown)

# Profiles

In [None]:
def create_density_profile(ds, n_bins=20):
    """Bin particle mass into spherical shells and return the shell densities.

    Shells have equal width between radius 0 and half the domain width along
    the first axis; each is half-open, [r_inner, r_outer).

    Parameters
    ----------
    ds : yt dataset
        Must expose ("all", "particle_position_spherical_radius") and
        ("all", "Masses").
    n_bins : int, optional
        Number of radial shells (default 20). Must be at least 1.

    Returns
    -------
    rs : numpy.ndarray
        Outer radius of each shell, as the bare value of the domain
        half-width (assumed below to be in pc).
    densities : numpy.ndarray
        Shell mass in grams divided by shell volume, with the volume
        converted using pc**3 -- i.e. g / cm**3 if `rs` is in pc.

    Raises
    ------
    ValueError
        If `n_bins` is smaller than 1.
    """
    if n_bins < 1:
        raise ValueError("n_bins must be at least 1, got {}".format(n_bins))

    dd = ds.all_data()
    r_max = ds.domain_width[0]/2
    dr = r_max / n_bins

    # Look the fields up once instead of on every loop iteration.
    radii = dd["all", "particle_position_spherical_radius"]
    masses = dd["all", "Masses"]

    dmass = np.zeros(n_bins)
    for i in range(n_bins):
        r_i = dr*(i)
        r_o = dr*(i+1)

        mask = (radii >= r_i) & (radii < r_o)

        # in_cgs() returns a converted copy, so `.value` can be chained onto it.
        dmass[i] = masses[mask].sum().in_cgs().value

    rs = np.linspace(0, r_max.value, num=n_bins+1)[1:]

    # Shell volume = difference of consecutive enclosed sphere volumes.
    Vs = 4/3*np.pi*rs**3
    Vs = np.insert(Vs, 0, 0)
    dVs = Vs[1:] - Vs[:-1]

    densities = dmass / (dVs * pc**3)

    return rs, densities

In [None]:
# y-axis label for each field offered by show_profile; the keys also populate
# the field dropdown of the profile widget.
# NOTE(review): "Metallicity" is capitalised here but lower-case in field_type
# ("metallicity") -- confirm which spelling the dataset exposes.
field_y_labels = {
    "density" : r"$\rho$ $[\mathrm{m_p}$ $\mathrm{cm}^{-3}]$",
    "temperature" : r"$T$ $[\mathrm{K}]$",
    "pressure" : r"$P$ $[\mathrm{ergs}$ $\mathrm{cm}^{-3}]$",
    "velocity_magnitude" : r"$\|\mathbf{v}\|$ $[\mathrm{km}$ $\mathrm{s}^{-1}]$",
    "radial_velocity" : r"$v_r$ $[\mathrm{km}$ $\mathrm{s}^{-1}]$",
    "Metallicity" : r"$Z / Z_\odot$",
}

# Weight field passed to yt.create_profile for each field. "density" has no
# entry because show_profile handles it through create_density_profile instead.
field_weight = {
    "temperature" : "cell_mass",
    "pressure" : "cell_volume",
    "velocity_magnitude" : "cell_mass",
    "radial_velocity" : "cell_mass",
    "Metallicity" : "cell_mass",
}

# Divisor applied to the profile values before plotting, so the curve matches
# the units named in field_y_labels.
field_units = {
    "density" : m_proton,
    "temperature" : 1,
    "pressure" : 1,
    "velocity_magnitude" : 1e5, # km / s
    "radial_velocity"    : 1e5, # km / s
    "Metallicity" : 0.02, # presumably the adopted solar metallicity, given the Z / Z_sun label -- TODO confirm
}



def _profile_title(time_Myr):
    """Format a snapshot time (in Myr) as the LaTeX title string of a profile plot.

    Times below 1 Myr (but at least 1e-3 Myr) are shown in kyr; the number
    format depends on the magnitude. Raises RuntimeError for negative times.
    """
    if time_Myr < 0:
        raise RuntimeError("Invalid time: {}".format(time_Myr))

    if time_Myr < 1e-3:
        time, time_units, number_format = time_Myr, "Myr", ".1e"
    elif time_Myr < 1:
        time, time_units, number_format = time_Myr * 1e3, "kyr", ".0f"
    elif time_Myr < 10:
        time, time_units, number_format = time_Myr, "Myr", ".1f"
    else:
        time, time_units, number_format = time_Myr, "Myr", ".0f"

    return r"$t$ $= {}$ $\mathrm{{{}}}$".format(format(time, number_format),
                                                  time_units)


def show_profile(i, field):
    """Plot the radial profile of `field` for snapshot `i` on a log y-axis.

    Parameters
    ----------
    i : int
        Index into the snapshot series `ts` (and into `times_snapshots`).
    field : str
        A key of `field_y_labels`. "density" is binned by
        `create_density_profile`; every other field goes through
        `yt.create_profile` with the weight from `field_weight`.

    Raises
    ------
    RuntimeError
        If the snapshot time is negative.
    """
    ds = ts[i]

    # Compare by value: `is` tests object identity and only matches by accident.
    if field == "density":
        rs, densities = create_density_profile(ds,n_bins=64)
        plt.plot(rs, densities / field_units[field])

        plt.ylim(ymin=1e-4)

        # Put the reference density in the same units as the plotted curve
        # (cgs value divided by the field's unit divisor).
        plt.axhline(rho_0.in_cgs().value / field_units[field],
                    linestyle="dashed", color="k")

    else:
        # The sphere is only needed for the yt profile.
        sp = ds.sphere(ds.domain_center, ds.domain_width[0]/2)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            pp = yt.create_profile(sp, 
                                   "radius", [field, "ones"], 
                                   weight_field=field_weight[field],
                                   units = {"radius":"pc"},
                                   logs = {"radius":False},
                                   n_bins=64,
            )
        mask = pp["ones"] > 0.1 # filter out bins with no particles
        plt.plot(pp.x.value[mask], 
                 pp[field][mask] / field_units[field])

    plt.yscale("log")
    plt.ylabel(field_y_labels[field])

    plt.xlabel(r"$R$ $[\mathrm{pc}]$")

    plt.title(_profile_title(times_snapshots[i]))
    

# Slider over snapshot indices (starting at the last one) plus a field selector.
last_snapshot_index = len(ts)-1
snapshot_slider = ipywidgets.IntSlider(min=0,
                                       max=last_snapshot_index,
                                       value=last_snapshot_index)
field_dropdown = ipywidgets.Dropdown(options=list(field_y_labels),
                                     value="density")
ipywidgets.interact(show_profile, i=snapshot_slider, field=field_dropdown)