In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib as mpl
from matplotlib.ticker import AutoLocator, AutoMinorLocator, LogLocator
from matplotlib.colors import Normalize
import glob
from scipy.interpolate import griddata
from pathlib import Path
import h5py
import mpl_toolkits.axes_grid1
import sys
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerTuple
from matplotlib.ticker import MultipleLocator

# Locate this notebook/script. `__file__` exists when run as a script, but not
# in a notebook or REPL, where the current working directory is used instead.
try:
    here = Path(__file__).resolve().parent
except NameError:
    here = Path.cwd()

# Sibling project directories that hold the local helper modules imported
# below; they are appended to sys.path in this order.
phys_const_path, nsm_plots_path, nsm_plots_postproc = (
    here.joinpath('..', subdir).resolve()
    for subdir in ('phys_const', 'nsm_plots', 'nsm_instabilities')
)
sys.path.extend(str(p) for p in (phys_const_path, nsm_plots_path, nsm_plots_postproc))

import phys_const as pc
import plot_functions as pf
import functions_angular_crossings as fac


### Reading HDF5 data ($\log_{10}\dot{N}$ slices per neutrino species)

In [None]:
bh_r=5.43, # km
bh_x=48.0, # km
bh_y=48.0, # km
bh_z=16.0, # km
# bh_x=0.0, # km
# bh_y=0.0, # km
# bh_z=0.0, # km

In [None]:
# Load the precomputed log10(N_dot) slices, one HDF5 file per species.
# Each file stores a "low" slice (xlow, ylow, N_dot_*_eq_log10_low) and an
# "up" slice (xup, yup, N_dot_*_eq_log10_up). `[:]` reads each dataset into a
# NumPy array, so the arrays stay usable after the file is closed.
# NOTE(review): machine-specific absolute directory -- consider moving it to a
# configurable constant in the setup cell.
N_dot_h5_dir = Path('/home/erick/software/devscrpts/gw170817_paper_plots')

# `with` guarantees each file is closed, even if a key is missing.
# Previously the nubar_e file was never closed.
with h5py.File(str(N_dot_h5_dir / 'N_dot_nu_e_eq.h5'), 'r') as N_dot_nu_e_eq_h5:
    xlow_nu_e = N_dot_nu_e_eq_h5['xlow'][:]
    ylow_nu_e = N_dot_nu_e_eq_h5['ylow'][:]
    N_dot_nu_e_eq_log10_low = N_dot_nu_e_eq_h5['N_dot_nu_e_eq_log10_low'][:]
    xup_nu_e = N_dot_nu_e_eq_h5['xup'][:]
    yup_nu_e = N_dot_nu_e_eq_h5['yup'][:]
    N_dot_nu_e_eq_log10_up = N_dot_nu_e_eq_h5['N_dot_nu_e_eq_log10_up'][:]

with h5py.File(str(N_dot_h5_dir / 'N_dot_nubar_e_eq.h5'), 'r') as N_dot_nubar_e_eq_h5:
    xlow_nubar_e = N_dot_nubar_e_eq_h5['xlow'][:]
    ylow_nubar_e = N_dot_nubar_e_eq_h5['ylow'][:]
    N_dot_nubar_e_eq_log10_low = N_dot_nubar_e_eq_h5['N_dot_nubar_e_eq_log10_low'][:]
    xup_nubar_e = N_dot_nubar_e_eq_h5['xup'][:]
    yup_nubar_e = N_dot_nubar_e_eq_h5['yup'][:]
    N_dot_nubar_e_eq_log10_up = N_dot_nubar_e_eq_h5['N_dot_nubar_e_eq_log10_up'][:]

with h5py.File(str(N_dot_h5_dir / 'N_dot_nu_x_eq.h5'), 'r') as N_dot_nu_x_eq_h5:
    xlow_nu_x = N_dot_nu_x_eq_h5['xlow'][:]
    ylow_nu_x = N_dot_nu_x_eq_h5['ylow'][:]
    N_dot_nu_x_eq_log10_low = N_dot_nu_x_eq_h5['N_dot_nu_x_eq_log10_low'][:]
    xup_nu_x = N_dot_nu_x_eq_h5['xup'][:]
    yup_nu_x = N_dot_nu_x_eq_h5['yup'][:]
    N_dot_nu_x_eq_log10_up = N_dot_nu_x_eq_h5['N_dot_nu_x_eq_log10_up'][:]

# Colormap name for each panel column; all three columns use 'hot'.
T_map = ye_map = rho_map = 'hot'

# Colour-bar limits: the NaN-ignoring global max/min of log10(N_dot) over all
# six slices (low and up slice for each species).
all_N_dot_slices = (
    N_dot_nu_e_eq_log10_low,
    N_dot_nu_e_eq_log10_up,
    N_dot_nubar_e_eq_log10_low,
    N_dot_nubar_e_eq_log10_up,
    N_dot_nu_x_eq_log10_low,
    N_dot_nu_x_eq_log10_up,
)
N_dot_max = np.nanmax([np.nanmax(slice_values) for slice_values in all_N_dot_slices])
N_dot_min = np.nanmin([np.nanmin(slice_values) for slice_values in all_N_dot_slices])

# The lower limit is then pinned by hand, discarding the computed minimum.
# Uncomment the next line to pin the upper limit as well.
# N_dot_max = 56
N_dot_min = 47

# # Set NaN for points inside the black hole (r < bh_r) in both slices
# r_bh = bh_r  # km

# # For z = const (up) slice
# r_up = np.sqrt((xup - bh_x)**2 + (yup - bh_y)**2 + (z[z_slice_indx] - bh_z)**2)
# mask_up = r_up < r_bh
# temp_slice_up[mask_up] = np.nan
# ye_slice_up[mask_up] = np.nan
# rho_slice_up[mask_up] = np.nan

# # For y = const (low) slice
# r_low = np.sqrt((xlow - bh_x)**2 + (y[y_slice_indx] - bh_y)**2 + (ylow - bh_z)**2)
# mask_low = r_low < r_bh
# temp_slice_low[mask_low] = np.nan
# ye_slice_low[mask_low] = np.nan
# rho_slice_low[mask_low] = np.nan

# Plot origin in km: every panel's coordinates are shifted so this point is at (0, 0).
xmid, ymid, zmid = 48, 48, 16  # km

# 2 x 3 grid: panels in a column share the x axis, panels in a row share the y axis.
fig, axes = plt.subplots(nrows=2, ncols=3, sharex='col', sharey='row', figsize=(18, 12))
# No horizontal gap. The negative hspace pulls the two rows together
# (presumably to close the gap left by the equal-aspect panels -- TODO confirm).
fig.subplots_adjust(hspace=-0.255, wspace=0)
fig.align_labels()

# # colorbars
# norm = mpl.colors.Normalize(vmin=temp_min, vmax=temp_max)
# sm = mpl.cm.ScalarMappable(cmap=T_map, norm=norm)
# sm.set_array([])  # Required for ScalarMappable even if unused
# divider = mpl_toolkits.axes_grid1.make_axes_locatable(axes[0,0])
# cax = divider.append_axes("top", size="4%", pad=0.05)
# cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
# cbar.ax.tick_params(which="both",direction="in")
# cbar.ax.xaxis.set_ticks_position('top')
# cbar.ax.xaxis.set_label_position('top')
# cbar.set_label(r"$T\, [\mathrm{MeV}]$", labelpad=10)
# cax.xaxis.set_major_locator(MultipleLocator(1))

# norm = mpl.colors.Normalize(vmin=ye_min, vmax=ye_max)
# sm = mpl.cm.ScalarMappable(cmap=ye_map, norm=norm)
# sm.set_array([])  # Required for ScalarMappable even if unused
# divider = mpl_toolkits.axes_grid1.make_axes_locatable(axes[0,1])
# cax = divider.append_axes("top", size="4%", pad=0.05)
# cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
# cbar.ax.tick_params(which="both",direction="in")
# cbar.ax.xaxis.set_ticks_position('top')
# cbar.ax.xaxis.set_label_position('top')
# cbar.set_label(r"$Y_e$", labelpad=10)

# norm = mpl.colors.Normalize(vmin=rho_min, vmax=rho_max)
# sm = mpl.cm.ScalarMappable(cmap=rho_map, norm=norm)
# sm.set_array([])  # Required for ScalarMappable even if unused
# divider = mpl_toolkits.axes_grid1.make_axes_locatable(axes[0,2])
# cax = divider.append_axes("top", size="4%", pad=0.05)
# cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
# cbar.ax.tick_params(which="both",direction="in")
# cbar.ax.xaxis.set_ticks_position('top')
# cbar.ax.xaxis.set_label_position('top')
# cbar.set_label(r"$\log\ \rho \ [\mathrm{g/ccm}]$", labelpad=10)
# cax.xaxis.set_major_locator(MultipleLocator(1))


# Shared colorbar
# Shared colorbar (manually set position)

# A single colorbar serves all panels. It is built from the T_map colormap and
# the [N_dot_min, N_dot_max] range, which the panel plots also use.
norm = Normalize(vmin=N_dot_min, vmax=N_dot_max)
sm = plt.cm.ScalarMappable(norm=norm, cmap=T_map)
sm.set_array([])  # the mappable carries no data; it only defines cmap + norm

# Colorbar axes placed by hand in figure coordinates: [left, bottom, width, height].
cbar_ax = fig.add_axes([0.91, 0.15, 0.02, 0.64])
cbar = fig.colorbar(sm, cax=cbar_ax)
cbar.set_label(r'$\log\ \dot{N} \, [\mathrm{s}^{-1}]$', fontsize=22)
cbar.ax.yaxis.set_minor_locator(AutoMinorLocator())


# Axis labels: x on the bottom row. The left column is labelled z (top row,
# y = const slice) and y (bottom row, z = const slice).
for ax in axes[-1, :]:
    ax.set_xlabel(r'$x\ [\mathrm{km}]$')
for ax, vertical_label in zip(axes[:, 0], (r'$z\ [\mathrm{km}]$', r'$y\ [\mathrm{km}]$')):
    ax.set_ylabel(vertical_label)

# Hide interior tick labels: x labels on every row but the bottom one,
# y labels on every column but the first one.
for ax in axes[:-1, :].ravel():
    for tick_label in ax.get_xticklabels():
        tick_label.set_visible(False)
for ax in axes[:, 1:].ravel():
    for tick_label in ax.get_yticklabels():
        tick_label.set_visible(False)

# Common formatting: equal aspect, major ticks every 20 km, minor ticks on.
for ax in axes.ravel():
    ax.set_aspect('equal')
    ax.xaxis.set_major_locator(MultipleLocator(20))
    ax.yaxis.set_major_locator(MultipleLocator(20))
    ax.minorticks_on()
    # apply_custom_settings(ax, False)
# mask out the black hole region
# Set white as the color for NaN values in the colormap
# cmap_with_nan = mpl.cm.get_cmap(cmap).copy()
# cmap_with_nan.set_bad(color='white')

# Draw log10(N_dot) for each species: nu_e (left column), nubar_e (middle),
# nu_x (right). The top row is the y = const slice (vertical axis z) and the
# bottom row is the z = const slice (vertical axis y). Coordinates are shifted
# so that (xmid, ymid, zmid) sits at the origin. Every panel uses the same
# colormap (T_map) and limits as the shared colorbar, so that one colorbar is
# valid for all six panels.

# bh_* are defined as one-element tuples in the constants cell.
# np.ravel(...)[0] extracts the scalar and also accepts plain floats.
bh_cx, bh_cy, bh_cz, bh_radius = (float(np.ravel(v)[0]) for v in (bh_x, bh_y, bh_z, bh_r))


def _plot_slice(ax, x, y, values, x0, y0, bh_center):
    """Draw one log10(N_dot) slice on `ax` and outline the black hole.

    x, y and values go to `pcolormesh` after shifting the coordinates by
    (x0, y0). `bh_center` is the unshifted (x, y) position of the black-hole
    centre in the slice plane. Returns the dashed circle patch that was added.
    """
    ax.pcolormesh(x - x0, y - y0, values, shading='auto', cmap=T_map,
                  vmin=N_dot_min, vmax=N_dot_max)
    bh_circle = plt.Circle((bh_center[0] - x0, bh_center[1] - y0), bh_radius,
                           color='k', fill=False, linestyle='--', linewidth=1)
    ax.add_patch(bh_circle)
    return bh_circle


# (column, y = const slice (x, z, values), z = const slice (x, y, values))
species_slices = (
    (0, (xlow_nu_e, ylow_nu_e, N_dot_nu_e_eq_log10_low),
        (xup_nu_e, yup_nu_e, N_dot_nu_e_eq_log10_up)),
    (1, (xlow_nubar_e, ylow_nubar_e, N_dot_nubar_e_eq_log10_low),
        (xup_nubar_e, yup_nubar_e, N_dot_nubar_e_eq_log10_up)),
    (2, (xlow_nu_x, ylow_nu_x, N_dot_nu_x_eq_log10_low),
        (xup_nu_x, yup_nu_x, N_dot_nu_x_eq_log10_up)),
)
for col, low_slice, up_slice in species_slices:
    # top row: y = const slice, vertical coordinate is z
    _plot_slice(axes[0, col], *low_slice, xmid, zmid, (bh_cx, bh_cz))
    # bottom row: z = const slice, vertical coordinate is y
    _plot_slice(axes[1, col], *up_slice, xmid, ymid, (bh_cx, bh_cy))

# Legend proxy for the dashed black-hole circles, shown once in the bottom-left panel.
legend_circle = Line2D([0], [0], color='k', linestyle='dashed', linewidth=1, label='Black hole')
axes[1, 0].legend(handles=[legend_circle], loc='upper right', frameon=False, fontsize=14)

# Species label in the lower-left corner of each bottom-row panel, in column
# order (nu_e, nubar_e, nu_x). The top-row panels are intentionally left unlabelled.
species_labels = (r'$\nu_e$', r'$\bar{\nu}_e$', r'$\nu_x$')
for ax, species_label in zip(axes[1, :], species_labels):
    ax.text(0.1, 0.15, species_label, transform=ax.transAxes, fontsize=28,
            verticalalignment='top', horizontalalignment='left',
            bbox=dict(facecolor='None', alpha=0.7, edgecolor='none'))

# savefig does not create missing directories. Create 'plots/' (relative to
# the working directory) so a fresh checkout does not raise FileNotFoundError.
plots_dir = Path('plots')
plots_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(plots_dir / 'N_dot.pdf', bbox_inches='tight')