In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib as mpl
from matplotlib.ticker import AutoLocator, AutoMinorLocator, LogLocator
from matplotlib.colors import Normalize
import glob
from scipy.interpolate import griddata
from pathlib import Path
import h5py
import mpl_toolkits.axes_grid1
import sys
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerTuple
from matplotlib.ticker import MultipleLocator

# Where am I running?  Use the script's own directory when executed as a file;
# inside a notebook / REPL `__file__` is undefined, so fall back to the cwd.
try:
    here = Path(__file__).resolve().parent
except NameError:
    here = Path.cwd()


def add_to_sys_path(directory):
    """Append ``directory`` (as a string) to ``sys.path`` unless already present.

    A plain ``sys.path.append`` adds a duplicate entry every time the cell is
    re-run; this keeps the setup idempotent.
    """
    entry = str(directory)
    if entry not in sys.path:
        sys.path.append(entry)


# Sibling project directories that hold the local modules imported below.
phys_const_path = (here / '..' / 'phys_const').resolve()
nsm_plots_path = (here / '..' / 'nsm_plots').resolve()
# NOTE: the variable name is historical; it points at `nsm_instabilities`.
nsm_plots_postproc = (here / '..' / 'nsm_instabilities').resolve()

for extra_dir in (phys_const_path, nsm_plots_path, nsm_plots_postproc):
    add_to_sys_path(extra_dir)

import phys_const as pc
import plot_functions as pf
import functions_angular_crossings as fac


### Reading the HDF5 data

In [None]:
# Black-hole radius and centre coordinates, all in km.
# (An earlier configuration used bh_x = 48.0, bh_y = 48.0, bh_z = 16.0.)
bh_r = 5.43
bh_x, bh_y, bh_z = 0.0, 0.0, 0.0

In [None]:
# HDF5 profile with temperature, Ye and density on a uniform Cartesian grid.
# NOTE(review): absolute, machine-specific path -- adjust on other machines.
RHO_YE_T_PATH = '/home/erick/software/devscrpts/gw170817_paper_plots/rho_Ye_T.hdf5'

# Grid indices of the plotted slices (x_slice_indx is currently unused here).
x_slice_indx = 48
y_slice_indx = 48
z_slice_indx = 16

# Colormaps for the three plotted quantities.
T_map = 'viridis'
ye_map = 'plasma'
rho_map = 'cividis'


def cell_centres_km(lo_cm, hi_cm, ncells):
    """Return the centres of ``ncells`` uniform cells spanning [lo_cm, hi_cm], in km."""
    edges = np.linspace(lo_cm / 1e5, hi_cm / 1e5, ncells + 1)  # cm -> km
    return edges[:-1] + np.diff(edges) / 2


def finite_range(*arrays):
    """Return ``(min, max)`` over the finite entries of all ``arrays``.

    NaN and +/-inf (e.g. log10 of a non-positive density) are ignored so they
    cannot end up as colour-bar limits.

    Raises:
        ValueError: if none of the arrays contains a finite value.
    """
    values = np.concatenate([np.ravel(a) for a in arrays])
    values = values[np.isfinite(values)]
    if values.size == 0:
        raise ValueError("no finite values to compute a colour range from")
    return values.min(), values.max()


# Read everything inside a context manager so the file handle is closed.
# Indexing the h5py datasets directly reads only the requested 2D slice
# instead of loading each full 3D array into memory first.
with h5py.File(RHO_YE_T_PATH, 'r') as rho_ye_T_h5py:
    ncellsx = rho_ye_T_h5py['/ncellsx'][()]
    ncellsy = rho_ye_T_h5py['/ncellsy'][()]
    ncellsz = rho_ye_T_h5py['/ncellsz'][()]

    xmax_cm = rho_ye_T_h5py['/xmax_cm'][()]
    xmin_cm = rho_ye_T_h5py['/xmin_cm'][()]
    ymax_cm = rho_ye_T_h5py['/ymax_cm'][()]
    ymin_cm = rho_ye_T_h5py['/ymin_cm'][()]
    zmax_cm = rho_ye_T_h5py['/zmax_cm'][()]
    zmin_cm = rho_ye_T_h5py['/zmin_cm'][()]

    # "up" slices: z = const (x-y plane); "low" slices: y = const (x-z plane).
    temp_slice_up = rho_ye_T_h5py['T_Mev'][:, :, z_slice_indx]
    ye_slice_up = rho_ye_T_h5py['Ye'][:, :, z_slice_indx]
    rho_slice_up = rho_ye_T_h5py['rho_g|ccm'][:, :, z_slice_indx]
    temp_slice_low = rho_ye_T_h5py['T_Mev'][:, y_slice_indx, :]
    ye_slice_low = rho_ye_T_h5py['Ye'][:, y_slice_indx, :]
    rho_slice_low = rho_ye_T_h5py['rho_g|ccm'][:, y_slice_indx, :]

# Cell-centre coordinates along each axis, in km.
x = cell_centres_km(xmin_cm, xmax_cm, ncellsx)
y = cell_centres_km(ymin_cm, ymax_cm, ncellsy)
z = cell_centres_km(zmin_cm, zmax_cm, ncellsz)

# Full-domain 3D meshgrids (not used in this cell).
x3d, y3d, z3d = np.meshgrid(x, y, z, indexing='ij')

# 2D meshgrids matching the slices: x-y for z = const, x-z for y = const.
xup, yup = np.meshgrid(x, y, indexing='ij')
xlow, ylow = np.meshgrid(x, z, indexing='ij')

print(f"xup.shape = {xup.shape}, yup.shape = {yup.shape}")
print(f"xlow.shape = {xlow.shape}, ylow.shape = {ylow.shape}")
print(f"temp_slice_up shape: {temp_slice_up.shape}")

# Colour-bar limits shared by the two slices of each quantity.
temp_min, temp_max = finite_range(temp_slice_low, temp_slice_up)
ye_min, ye_max = finite_range(ye_slice_low, ye_slice_up)
# rho <= 0 gives -inf/NaN under log10; those are dropped by finite_range.
with np.errstate(divide='ignore', invalid='ignore'):
    rho_min, rho_max = finite_range(np.log10(rho_slice_low), np.log10(rho_slice_up))

def add_top_colorbar(ax, cmap, vmin, vmax, label, major_step=None):
    """Attach a horizontal colour bar above ``ax`` and return it.

    The bar shows ``cmap`` under a linear ``Normalize(vmin, vmax)``, which is the
    same mapping used for the panels in that column. Ticks point inwards, and
    the tick labels and ``label`` sit on top. If ``major_step`` is given, major
    ticks are placed every ``major_step`` units.
    """
    mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=Normalize(vmin=vmin, vmax=vmax))
    mappable.set_array([])  # Required for ScalarMappable even if unused
    cax = make_axes_locatable(ax).append_axes("top", size="4%", pad=0.05)
    cbar = plt.colorbar(mappable, cax=cax, orientation='horizontal')
    cbar.ax.tick_params(which="both", direction="in")
    cbar.ax.xaxis.set_ticks_position('top')
    cbar.ax.xaxis.set_label_position('top')
    cbar.set_label(label, labelpad=10)
    if major_step is not None:
        cax.xaxis.set_major_locator(MultipleLocator(major_step))
    return cbar


# log10(rho), computed once. pcolormesh masks the -inf/NaN that rho <= 0
# would produce, so only the numpy warnings need silencing.
with np.errstate(divide='ignore', invalid='ignore'):
    log_rho_slice_up = np.log10(rho_slice_up)
    log_rho_slice_low = np.log10(rho_slice_low)

# One entry per figure column: colour mapping, colour-bar label/tick spacing
# and the two slices of that quantity. Top row: y = const slice (x-z plane,
# the "low" arrays). Bottom row: z = const slice (x-y plane, the "up" arrays).
panel_columns = [
    dict(cmap=T_map, vmin=temp_min, vmax=temp_max,
         label=r"$T\, [\mathrm{MeV}]$", major_step=1,
         xz_slice=temp_slice_low, xy_slice=temp_slice_up),
    dict(cmap=ye_map, vmin=ye_min, vmax=ye_max,
         label=r"$Y_e$", major_step=None,
         xz_slice=ye_slice_low, xy_slice=ye_slice_up),
    dict(cmap=rho_map, vmin=rho_min, vmax=rho_max,
         label=r"$\log_{10}\ \rho \ [\mathrm{g\ cm}^{-3}]$", major_step=1,
         xz_slice=log_rho_slice_low, xy_slice=log_rho_slice_up),
]

# Dashed white circle marking the black hole in every panel.
bh_circle_style = dict(color='white', fill=False, linestyle='--', linewidth=1)

# create 2x3 figure
fig, axes = plt.subplots(2, 3, figsize=(18, 12), sharex='col', sharey='row')
# No gap between columns. The negative hspace pulls the rows together,
# presumably to close the gap left by the equal-aspect panels; retune if the
# layout changes.
fig.subplots_adjust(wspace=0, hspace=-0.24)
fig.align_labels()

# colour bars above the top row
for col, spec in enumerate(panel_columns):
    add_top_colorbar(axes[0, col], spec['cmap'], spec['vmin'], spec['vmax'],
                     spec['label'], major_step=spec['major_step'])

# axis labels: x along the bottom; z (top row, x-z plane) and y (bottom row)
for ax in axes[-1, :]:
    ax.set_xlabel(r'$x\ [\mathrm{km}]$')
axes[0, 0].set_ylabel(r'$z\ [\mathrm{km}]$')
axes[1, 0].set_ylabel(r'$y\ [\mathrm{km}]$')

# Hide inner tick labels. tick_params persists when ticks are regenerated,
# unlike toggling the visibility of the current tick-label Text objects.
for ax in axes[:-1, :].flat:
    ax.tick_params(labelbottom=False)
for ax in axes[:, 1:].flat:
    ax.tick_params(labelleft=False)

for ax in axes.flat:
    ax.set_aspect('equal')
    ax.xaxis.set_major_locator(MultipleLocator(20))
    ax.yaxis.set_major_locator(MultipleLocator(20))
    ax.minorticks_on()

# draw both slices of every quantity, each with the black-hole circle on top
for col, spec in enumerate(panel_columns):
    colour_kw = dict(shading='auto', cmap=spec['cmap'], vmin=spec['vmin'], vmax=spec['vmax'])
    axes[0, col].pcolormesh(xlow, ylow, spec['xz_slice'], **colour_kw)
    axes[0, col].add_patch(plt.Circle((bh_x, bh_z), bh_r, **bh_circle_style))
    axes[1, col].pcolormesh(xup, yup, spec['xy_slice'], **colour_kw)
    axes[1, col].add_patch(plt.Circle((bh_x, bh_y), bh_r, **bh_circle_style))

legend_circle = Line2D([0], [0], color='white', linestyle='dashed', linewidth=1, label='Black hole')
axes[1, 0].legend(handles=[legend_circle], loc='upper right', frameon=False, fontsize=18)

out_path = Path('plots') / 'rho_T_Ye.png'
out_path.parent.mkdir(parents=True, exist_ok=True)  # savefig does not create directories
fig.savefig(out_path, bbox_inches='tight', dpi=500)
