In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib as mpl
from matplotlib.ticker import AutoLocator, AutoMinorLocator, LogLocator
from matplotlib.colors import Normalize
import glob
from scipy.interpolate import griddata
from pathlib import Path
import h5py
import mpl_toolkits.axes_grid1
import sys
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerTuple
from matplotlib.ticker import MultipleLocator

# Locate the directory containing this code: `__file__` only exists when the
# code runs as a script, so a notebook/REPL falls back to the working directory.
try:
    here = Path(__file__).resolve().parent
except NameError:
    here = Path.cwd()

# Sibling project directories that hold the helper modules imported below
# (phys_const, plot_functions, functions_angular_crossings); they are appended
# to sys.path in this order.
phys_const_path, nsm_plots_path, nsm_plots_postproc = (
    here.joinpath('..', subdir).resolve()
    for subdir in ('phys_const', 'nsm_plots', 'nsm_instabilities')
)
sys.path.extend(str(extra) for extra in (phys_const_path, nsm_plots_path, nsm_plots_postproc))

import phys_const as pc
import plot_functions as pf
import functions_angular_crossings as fac


### Reading hdf5 data

In [None]:
# Black-hole circle parameters in km. Each is deliberately a one-element tuple
# (the plotting cells read them as bh_r[0], bh_x[0], ...).
bh_r = (5.43,)  # radius of the dashed black-hole circle
bh_x = (48.0,)  # centre, x
bh_y = (48.0,)  # centre, y
bh_z = (16.0,)  # centre, z
# Alternative: black hole at the origin.
# bh_x = (0.0,)  # km
# bh_y = (0.0,)  # km
# bh_z = (0.0,)  # km

In [None]:
# Directory holding the HDF5 slice files. This is a machine-specific absolute
# path; change this single constant when running elsewhere.
DATA_DIR = Path('/home/erick/software/devscrpts/gw170817_paper_plots')

# Each file stores two 2-D slices: "low" (xlow, ylow, zlow) and "up"
# (xup, yup, zup), where z* is the log10 field plotted below.
# `with` closes each file even if a dataset key is missing (KeyError).
with h5py.File(DATA_DIR / 'n_eu_8.h5', 'r') as har_n_e_eq_h5:
    har_xlow_n_e      = har_n_e_eq_h5['xlow'][:]
    har_ylow_n_e      = har_n_e_eq_h5['ylow'][:]
    har_n_e_log10_low = har_n_e_eq_h5['zlow'][:]
    har_xup_n_e       = har_n_e_eq_h5['xup'][:]
    har_yup_n_e       = har_n_e_eq_h5['yup'][:]
    har_n_e_log10_up  = har_n_e_eq_h5['zup'][:]

with h5py.File(DATA_DIR / 'n_eu_24.h5', 'r') as har_nbar_e_eq_h5:
    har_xlow_nbar_e      = har_nbar_e_eq_h5['xlow'][:]
    har_ylow_nbar_e      = har_nbar_e_eq_h5['ylow'][:]
    har_nbar_e_log10_low = har_nbar_e_eq_h5['zlow'][:]
    har_xup_nbar_e       = har_nbar_e_eq_h5['xup'][:]
    har_yup_nbar_e       = har_nbar_e_eq_h5['yup'][:]
    har_nbar_e_log10_up  = har_nbar_e_eq_h5['zup'][:]

with h5py.File(DATA_DIR / 'n_eu_-1.h5', 'r') as har_n_x_eq_h5:
    har_xlow_n_x      = har_n_x_eq_h5['xlow'][:]
    har_ylow_n_x      = har_n_x_eq_h5['ylow'][:]
    har_n_x_log10_low = har_n_x_eq_h5['zlow'][:]
    har_xup_n_x       = har_n_x_eq_h5['xup'][:]
    har_yup_n_x       = har_n_x_eq_h5['yup'][:]
    har_n_x_log10_up  = har_n_x_eq_h5['zup'][:]
                                          
# Colormaps for the three columns. The T / Ye / rho names come from an earlier
# temperature / Y_e / density version of this figure (see the commented-out
# colorbars below); every column now shows log10 |n_emu| with the same map.
T_map = ye_map = rho_map = 'viridis'

# Shared colour scale for all six panels. The upper limit follows the data.
# The lower limit is a hand-picked floor of 23, not the data minimum.
n_slices = (
    har_n_e_log10_low, har_n_e_log10_up,
    har_nbar_e_log10_low, har_nbar_e_log10_up,
    har_n_x_log10_low, har_n_x_log10_up,
)
n_max = np.nanmax([np.nanmax(s) for s in n_slices])
n_min = 23

# n_max = np.nan
# n_min = np.nan

# N_dot_max = 56
# N_dot_min = 47

# # Set NaN for points inside the black hole (r < bh_r) in both slices
# r_bh = bh_r  # km

# # For z = const (up) slice
# r_up = np.sqrt((xup - bh_x)**2 + (yup - bh_y)**2 + (z[z_slice_indx] - bh_z)**2)
# mask_up = r_up < r_bh
# temp_slice_up[mask_up] = np.nan
# ye_slice_up[mask_up] = np.nan
# rho_slice_up[mask_up] = np.nan

# # For y = const (low) slice
# r_low = np.sqrt((xlow - bh_x)**2 + (y[y_slice_indx] - bh_y)**2 + (ylow - bh_z)**2)
# mask_low = r_low < r_bh
# temp_slice_low[mask_low] = np.nan
# ye_slice_low[mask_low] = np.nan
# rho_slice_low[mask_low] = np.nan

# Reference point (km) subtracted from the slice coordinates so that the
# plotted axes are centred on it.
xmid, ymid, zmid = 48, 48, 16

# 2 x 3 panel grid: panels in a column share x, panels in a row share the
# vertical axis.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(18, 12), sharex='col', sharey='row')
# No horizontal gap. The negative hspace pulls the two rows together
# (presumably to close the gap left by the equal-aspect panels).
fig.subplots_adjust(wspace=0, hspace=-0.255)
fig.align_labels()

# # colorbars
# norm = mpl.colors.Normalize(vmin=temp_min, vmax=temp_max)
# sm = mpl.cm.ScalarMappable(cmap=T_map, norm=norm)
# sm.set_array([])  # Required for ScalarMappable even if unused
# divider = mpl_toolkits.axes_grid1.make_axes_locatable(axes[0,0])
# cax = divider.append_axes("top", size="4%", pad=0.05)
# cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
# cbar.ax.tick_params(which="both",direction="in")
# cbar.ax.xaxis.set_ticks_position('top')
# cbar.ax.xaxis.set_label_position('top')
# cbar.set_label(r"$T\, [\mathrm{MeV}]$", labelpad=10)
# cax.xaxis.set_major_locator(MultipleLocator(1))

# norm = mpl.colors.Normalize(vmin=ye_min, vmax=ye_max)
# sm = mpl.cm.ScalarMappable(cmap=ye_map, norm=norm)
# sm.set_array([])  # Required for ScalarMappable even if unused
# divider = mpl_toolkits.axes_grid1.make_axes_locatable(axes[0,1])
# cax = divider.append_axes("top", size="4%", pad=0.05)
# cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
# cbar.ax.tick_params(which="both",direction="in")
# cbar.ax.xaxis.set_ticks_position('top')
# cbar.ax.xaxis.set_label_position('top')
# cbar.set_label(r"$Y_e$", labelpad=10)

# norm = mpl.colors.Normalize(vmin=rho_min, vmax=rho_max)
# sm = mpl.cm.ScalarMappable(cmap=rho_map, norm=norm)
# sm.set_array([])  # Required for ScalarMappable even if unused
# divider = mpl_toolkits.axes_grid1.make_axes_locatable(axes[0,2])
# cax = divider.append_axes("top", size="4%", pad=0.05)
# cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
# cbar.ax.tick_params(which="both",direction="in")
# cbar.ax.xaxis.set_ticks_position('top')
# cbar.ax.xaxis.set_label_position('top')
# cbar.set_label(r"$\log\ \rho \ [\mathrm{g/ccm}]$", labelpad=10)
# cax.xaxis.set_major_locator(MultipleLocator(1))


# One colorbar shared by all six panels. It is drawn in a hand-placed axes on
# the right; the rectangle is (left, bottom, width, height) in figure fractions.
norm = Normalize(vmin=n_min, vmax=n_max)
sm = plt.cm.ScalarMappable(norm=norm, cmap=T_map)
sm.set_array([])  # mappable carries only norm + cmap, no data of its own
cbar_ax = fig.add_axes((0.91, 0.15, 0.02, 0.64))
cbar = fig.colorbar(sm, cax=cbar_ax)
cbar.set_label(r'$\log\ |n_{e\mu}| \, [\mathrm{cm}^{-3}]$', fontsize=22)
cbar.ax.yaxis.set_minor_locator(AutoMinorLocator())


# Axis labels: x along the bottom row. The top row (y = const slices) has z as
# its vertical axis; the bottom row (z = const slices) has y.
for bottom_ax in axes[-1, :]:
    bottom_ax.set_xlabel(r'$x\ [\mathrm{km}]$')
axes[0, 0].set_ylabel(r'$z\ [\mathrm{km}]$')
axes[1, 0].set_ylabel(r'$y\ [\mathrm{km}]$')

# Hide inner tick labels: x labels everywhere except the bottom row,
# y labels everywhere except the left column.
for inner_ax in axes[:-1, :].flat:
    plt.setp(inner_ax.get_xticklabels(), visible=False)
for inner_ax in axes[:, 1:].flat:
    plt.setp(inner_ax.get_yticklabels(), visible=False)

# Every panel: equal aspect, major ticks every 20 km, plus minor ticks.
for ax in axes.flat:
    ax.set_aspect('equal')
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(MultipleLocator(20))
    ax.minorticks_on()
# mask out the black hole region
# Set white as the color for NaN values in the colormap
# cmap_with_nan = mpl.cm.get_cmap(cmap).copy()
# cmap_with_nan.set_bad(color='white')

# Contour levels (log10 |n_emu| in cm^-3) overlaid on every panel.
levels = [24, 25, 26]

# Legend proxy for the dashed black-hole circle drawn on each panel.
legend_circle = Line2D([0], [0], color='k', linestyle='dashed', linewidth=1, label='Black hole')

# Panel table.
# Top row:    y = const ("low") slices; vertical coordinate z, offset by zmid.
# Bottom row: z = const ("up") slices;  vertical coordinate y, offset by ymid.
# Columns:    files n_eu_8, n_eu_24 and n_eu_-1.
# Entries are (axes, x, vertical coordinate, log10 field, vertical offset,
# vertical black-hole centre, colormap). Drawing follows this order, so `cs`
# is left holding the contour set of axes[0, 2].
n_panels = [
    (axes[1, 0], har_xup_n_e,     har_yup_n_e,     har_n_e_log10_up,     ymid, bh_y[0], T_map),
    (axes[0, 0], har_xlow_n_e,    har_ylow_n_e,    har_n_e_log10_low,    zmid, bh_z[0], T_map),
    (axes[1, 1], har_xup_nbar_e,  har_yup_nbar_e,  har_nbar_e_log10_up,  ymid, bh_y[0], ye_map),
    (axes[0, 1], har_xlow_nbar_e, har_ylow_nbar_e, har_nbar_e_log10_low, zmid, bh_z[0], ye_map),
    (axes[1, 2], har_xup_n_x,     har_yup_n_x,     har_n_x_log10_up,     ymid, bh_y[0], rho_map),
    (axes[0, 2], har_xlow_n_x,    har_ylow_n_x,    har_n_x_log10_low,    zmid, bh_z[0], rho_map),
]

for ax, x, v, field, v_mid, bh_v, cmap in n_panels:
    x_rel, v_rel = x - xmid, v - v_mid
    ax.pcolormesh(x_rel, v_rel, field, shading='auto', cmap=cmap, vmin=n_min, vmax=n_max)
    cs = ax.contour(x_rel, v_rel, field, levels=levels, colors='k', linewidths=1)
    ax.clabel(cs, inline=True, fontsize=15, fmt='%.1f')
    ax.add_patch(plt.Circle((bh_x[0] - xmid, bh_v - v_mid), bh_r[0],
                            color='k', fill=False, linestyle='--', linewidth=1))


# Legend for the dashed black-hole circle, on the bottom-left panel.
legend = axes[1, 0].legend(handles=[legend_circle], loc='upper left', frameon=True, fontsize=18)
legend.get_frame().set_alpha(0.9)
for txt in legend.get_texts():
    txt.set_color('Black')

# Snapshot time of each column, written into the bottom-row panels.
time_labels = [r'$t=0.03$ ms', r'$t=0.08$ ms', r'$t=0.32$ ms']
for col, t_label in enumerate(time_labels):
    axes[1, col].text(0.1, 0.15, t_label, transform=axes[1, col].transAxes, fontsize=28,
                      verticalalignment='top', horizontalalignment='left',
                      bbox=dict(facecolor='white', alpha=0.9, edgecolor='none'))

# savefig does not create missing directories, so ensure `plots/` exists first.
Path('plots').mkdir(exist_ok=True)
fig.savefig('plots/n_eu_off_diagonal_densities.png', bbox_inches='tight', dpi=500)

In [None]:
# Directory holding the HDF5 slice files. This is a machine-specific absolute
# path; change this single constant when running elsewhere.
DATA_DIR = Path('/home/erick/software/devscrpts/gw170817_paper_plots')

# Each file stores "low" and "up" 2-D slices: coordinates (x*, y*), the log10
# field (z*), and the real and imaginary in-plane vector components.
# `with` closes each file even if a dataset key is missing (KeyError).
with h5py.File(DATA_DIR / 'f_eu_8.h5', 'r') as har_f_e_eq_h5:
    har_xlow_f_e            = har_f_e_eq_h5['xlow'][:]
    har_ylow_f_e            = har_f_e_eq_h5['ylow'][:]
    har_fhat_f_e_log10_low  = har_f_e_eq_h5['zlow'][:]
    har_vector_x_low_f_e    = har_f_e_eq_h5['vector_x_low'][:]
    har_vector_y_low_f_e    = har_f_e_eq_h5['vector_y_low'][:]
    har_vector_x_low_f_e_im = har_f_e_eq_h5['vector_x_low_im'][:]
    har_vector_y_low_f_e_im = har_f_e_eq_h5['vector_y_low_im'][:]
    har_xup_f_e             = har_f_e_eq_h5['xup'][:]
    har_yup_f_e             = har_f_e_eq_h5['yup'][:]
    har_fhat_f_e_log10_up   = har_f_e_eq_h5['zup'][:]
    har_vector_x_up_f_e     = har_f_e_eq_h5['vector_x_up'][:]
    har_vector_y_up_f_e     = har_f_e_eq_h5['vector_y_up'][:]
    har_vector_x_up_f_e_im  = har_f_e_eq_h5['vector_x_up_im'][:]
    har_vector_y_up_f_e_im  = har_f_e_eq_h5['vector_y_up_im'][:]

with h5py.File(DATA_DIR / 'f_eu_24.h5', 'r') as har_fbar_e_eq_h5:
    har_xlow_fbar_e            = har_fbar_e_eq_h5['xlow'][:]
    har_ylow_fbar_e            = har_fbar_e_eq_h5['ylow'][:]
    har_fhat_fbar_e_log10_low  = har_fbar_e_eq_h5['zlow'][:]
    har_vector_x_low_fbar_e    = har_fbar_e_eq_h5['vector_x_low'][:]
    har_vector_y_low_fbar_e    = har_fbar_e_eq_h5['vector_y_low'][:]
    har_vector_x_low_fbar_e_im = har_fbar_e_eq_h5['vector_x_low_im'][:]
    har_vector_y_low_fbar_e_im = har_fbar_e_eq_h5['vector_y_low_im'][:]
    har_xup_fbar_e             = har_fbar_e_eq_h5['xup'][:]
    har_yup_fbar_e             = har_fbar_e_eq_h5['yup'][:]
    har_fhat_fbar_e_log10_up   = har_fbar_e_eq_h5['zup'][:]
    har_vector_x_up_fbar_e     = har_fbar_e_eq_h5['vector_x_up'][:]
    har_vector_y_up_fbar_e     = har_fbar_e_eq_h5['vector_y_up'][:]
    har_vector_x_up_fbar_e_im  = har_fbar_e_eq_h5['vector_x_up_im'][:]
    har_vector_y_up_fbar_e_im  = har_fbar_e_eq_h5['vector_y_up_im'][:]

with h5py.File(DATA_DIR / 'f_eu_-1.h5', 'r') as har_f_x_eq_h5:
    har_xlow_f_x            = har_f_x_eq_h5['xlow'][:]
    har_ylow_f_x            = har_f_x_eq_h5['ylow'][:]
    har_fhat_f_x_log10_low  = har_f_x_eq_h5['zlow'][:]
    har_vector_x_low_f_x    = har_f_x_eq_h5['vector_x_low'][:]
    har_vector_y_low_f_x    = har_f_x_eq_h5['vector_y_low'][:]
    har_vector_x_low_f_x_im = har_f_x_eq_h5['vector_x_low_im'][:]
    har_vector_y_low_f_x_im = har_f_x_eq_h5['vector_y_low_im'][:]
    har_xup_f_x             = har_f_x_eq_h5['xup'][:]
    har_yup_f_x             = har_f_x_eq_h5['yup'][:]
    har_fhat_f_x_log10_up   = har_f_x_eq_h5['zup'][:]
    har_vector_x_up_f_x     = har_f_x_eq_h5['vector_x_up'][:]
    har_vector_y_up_f_x     = har_f_x_eq_h5['vector_y_up'][:]
    har_vector_x_up_f_x_im  = har_f_x_eq_h5['vector_x_up_im'][:]
    har_vector_y_up_f_x_im  = har_f_x_eq_h5['vector_y_up_im'][:]

# All three columns share one colormap. The per-quantity names (T / Ye / rho)
# are kept because the plotting code below refers to them.
colormap = 'viridis'
T_map = ye_map = rho_map = colormap

# Shared colour scale for the six log10 |f_emu| panels (the n_* names are
# reused from the density figure). The upper limit follows the data; the lower
# limit is a hand-picked floor of 22, not the data minimum.
f_slices = (
    har_fhat_f_e_log10_low, har_fhat_f_e_log10_up,
    har_fhat_fbar_e_log10_low, har_fhat_fbar_e_log10_up,
    har_fhat_f_x_log10_low, har_fhat_f_x_log10_up,
)
n_max = np.nanmax([np.nanmax(s) for s in f_slices])
n_min = 22

# n_max = np.nan
# n_min = np.nan

# N_dot_max = 56
# N_dot_min = 47

# # Set NaN for points inside the black hole (r < bh_r) in both slices
# r_bh = bh_r  # km

# # For z = const (up) slice
# r_up = np.sqrt((xup - bh_x)**2 + (yup - bh_y)**2 + (z[z_slice_indx] - bh_z)**2)
# mask_up = r_up < r_bh
# temp_slice_up[mask_up] = np.nan
# ye_slice_up[mask_up] = np.nan
# rho_slice_up[mask_up] = np.nan

# # For y = const (low) slice
# r_low = np.sqrt((xlow - bh_x)**2 + (y[y_slice_indx] - bh_y)**2 + (ylow - bh_z)**2)
# mask_low = r_low < r_bh
# temp_slice_low[mask_low] = np.nan
# ye_slice_low[mask_low] = np.nan
# rho_slice_low[mask_low] = np.nan

# Reference point (km) subtracted from the slice coordinates so that the
# plotted axes are centred on it.
xmid, ymid, zmid = 48, 48, 16

# 2 x 3 panel grid: panels in a column share x, panels in a row share the
# vertical axis.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(18, 12), sharex='col', sharey='row')
# No horizontal gap. The negative hspace pulls the two rows together
# (presumably to close the gap left by the equal-aspect panels).
fig.subplots_adjust(wspace=0, hspace=-0.255)
fig.align_labels()

# # colorbars
# norm = mpl.colors.Normalize(vmin=temp_min, vmax=temp_max)
# sm = mpl.cm.ScalarMappable(cmap=T_map, norm=norm)
# sm.set_array([])  # Required for ScalarMappable even if unused
# divider = mpl_toolkits.axes_grid1.make_axes_locatable(axes[0,0])
# cax = divider.append_axes("top", size="4%", pad=0.05)
# cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
# cbar.ax.tick_params(which="both",direction="in")
# cbar.ax.xaxis.set_ticks_position('top')
# cbar.ax.xaxis.set_label_position('top')
# cbar.set_label(r"$T\, [\mathrm{MeV}]$", labelpad=10)
# cax.xaxis.set_major_locator(MultipleLocator(1))

# norm = mpl.colors.Normalize(vmin=ye_min, vmax=ye_max)
# sm = mpl.cm.ScalarMappable(cmap=ye_map, norm=norm)
# sm.set_array([])  # Required for ScalarMappable even if unused
# divider = mpl_toolkits.axes_grid1.make_axes_locatable(axes[0,1])
# cax = divider.append_axes("top", size="4%", pad=0.05)
# cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
# cbar.ax.tick_params(which="both",direction="in")
# cbar.ax.xaxis.set_ticks_position('top')
# cbar.ax.xaxis.set_label_position('top')
# cbar.set_label(r"$Y_e$", labelpad=10)

# norm = mpl.colors.Normalize(vmin=rho_min, vmax=rho_max)
# sm = mpl.cm.ScalarMappable(cmap=rho_map, norm=norm)
# sm.set_array([])  # Required for ScalarMappable even if unused
# divider = mpl_toolkits.axes_grid1.make_axes_locatable(axes[0,2])
# cax = divider.append_axes("top", size="4%", pad=0.05)
# cbar = plt.colorbar(sm, cax=cax, orientation='horizontal')
# cbar.ax.tick_params(which="both",direction="in")
# cbar.ax.xaxis.set_ticks_position('top')
# cbar.ax.xaxis.set_label_position('top')
# cbar.set_label(r"$\log\ \rho \ [\mathrm{g/ccm}]$", labelpad=10)
# cax.xaxis.set_major_locator(MultipleLocator(1))


# One colorbar shared by all six panels. It is drawn in a hand-placed axes on
# the right; the rectangle is (left, bottom, width, height) in figure fractions.
norm = Normalize(vmin=n_min, vmax=n_max)
sm = plt.cm.ScalarMappable(norm=norm, cmap=T_map)
sm.set_array([])  # mappable carries only norm + cmap, no data of its own
cbar_ax = fig.add_axes((0.91, 0.15, 0.02, 0.64))
cbar = fig.colorbar(sm, cax=cbar_ax)
cbar.set_label(r'$\log\ |\vec{f}_{e\mu}| \, [\mathrm{cm}^{-3}]$', fontsize=22)
cbar.ax.yaxis.set_minor_locator(AutoMinorLocator())


# Axis labels: x along the bottom row. The top row (y = const slices) has z as
# its vertical axis; the bottom row (z = const slices) has y.
for bottom_ax in axes[-1, :]:
    bottom_ax.set_xlabel(r'$x\ [\mathrm{km}]$')
axes[0, 0].set_ylabel(r'$z\ [\mathrm{km}]$')
axes[1, 0].set_ylabel(r'$y\ [\mathrm{km}]$')

# Hide inner tick labels: x labels everywhere except the bottom row,
# y labels everywhere except the left column.
for inner_ax in axes[:-1, :].flat:
    plt.setp(inner_ax.get_xticklabels(), visible=False)
for inner_ax in axes[:, 1:].flat:
    plt.setp(inner_ax.get_yticklabels(), visible=False)

# Every panel: equal aspect, major ticks every 20 km, plus minor ticks.
for ax in axes.flat:
    ax.set_aspect('equal')
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(MultipleLocator(20))
    ax.minorticks_on()
# mask out the black hole region
# Set white as the color for NaN values in the colormap
# cmap_with_nan = mpl.cm.get_cmap(cmap).copy()
# cmap_with_nan.set_bad(color='white')

# Contour levels for log10 |f_emu|. Labelled contours are drawn on the second
# and third columns only; the first column gets none.
levels = [-1.4, -1.0, -0.6]
# Quiver thinning stride: every ve-th grid point in each direction is drawn,
# and the vector components are multiplied by the same factor.
ve = 4


def _draw_flux_panel(ax, x, v, field, vec_re, vec_im, v_mid, bh_v, cmap, with_contours):
    """Draw one log10 |f_emu| slice with black-hole circle and Re/Im vectors.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x, v : 2-D coordinate arrays (km) of the slice; ``v`` is the vertical
        coordinate. ``x`` is shifted by ``xmid`` and ``v`` by ``v_mid``.
    field : log10 values shown with pcolormesh (and contoured if requested).
    vec_re, vec_im : ``(x-component, vertical-component)`` array pairs for the
        real (drawn red) and imaginary (drawn white) vector parts.
    v_mid : offset subtracted from ``v``.
    bh_v : vertical coordinate of the black-hole centre.
    cmap : colormap for the pcolormesh.
    with_contours : draw and label contours at ``levels`` on this panel.

    Returns
    -------
    The contour set when ``with_contours`` is true, otherwise None. Labels are
    only ever attached to this panel's own contours.
    """
    x_rel, v_rel = x - xmid, v - v_mid
    ax.pcolormesh(x_rel, v_rel, field, shading='auto', cmap=cmap, vmin=n_min, vmax=n_max)

    contour_set = None
    if with_contours:
        contour_set = ax.contour(x_rel, v_rel, field, levels=levels, colors='k',
                                 linewidths=1, linestyles='solid')
        ax.clabel(contour_set, inline=True, fontsize=15, fmt='%.1f')

    ax.add_patch(plt.Circle((bh_x[0] - xmid, bh_v - v_mid), bh_r[0],
                            color='k', fill=False, linestyle='--', linewidth=1))

    # Thinned grid; the Re arrows are drawn first, the Im arrows on top.
    thin = (slice(None, None, ve), slice(None, None, ve))
    for (vec_x, vec_v), color in ((vec_re, 'red'), (vec_im, 'white')):
        ax.quiver(x_rel[thin], v_rel[thin], ve * vec_x[thin], ve * vec_v[thin],
                  color=color, scale=1, scale_units='xy', angles='xy', alpha=1.0)
    return contour_set


# Bottom row: z = const ("up") slices, vertical coordinate y (offset ymid).
# Top row:    y = const ("low") slices, vertical coordinate z (offset zmid).
# Columns:    files f_eu_8, f_eu_24 and f_eu_-1.
# `cs` is left holding the last contour set drawn (axes[0, 2]).
_draw_flux_panel(axes[1, 0], har_xup_f_e, har_yup_f_e, har_fhat_f_e_log10_up,
                 (har_vector_x_up_f_e, har_vector_y_up_f_e),
                 (har_vector_x_up_f_e_im, har_vector_y_up_f_e_im),
                 ymid, bh_y[0], T_map, with_contours=False)
_draw_flux_panel(axes[0, 0], har_xlow_f_e, har_ylow_f_e, har_fhat_f_e_log10_low,
                 (har_vector_x_low_f_e, har_vector_y_low_f_e),
                 (har_vector_x_low_f_e_im, har_vector_y_low_f_e_im),
                 zmid, bh_z[0], T_map, with_contours=False)

cs = _draw_flux_panel(axes[1, 1], har_xup_fbar_e, har_yup_fbar_e, har_fhat_fbar_e_log10_up,
                      (har_vector_x_up_fbar_e, har_vector_y_up_fbar_e),
                      (har_vector_x_up_fbar_e_im, har_vector_y_up_fbar_e_im),
                      ymid, bh_y[0], ye_map, with_contours=True)
cs = _draw_flux_panel(axes[0, 1], har_xlow_fbar_e, har_ylow_fbar_e, har_fhat_fbar_e_log10_low,
                      (har_vector_x_low_fbar_e, har_vector_y_low_fbar_e),
                      (har_vector_x_low_fbar_e_im, har_vector_y_low_fbar_e_im),
                      zmid, bh_z[0], ye_map, with_contours=True)

cs = _draw_flux_panel(axes[1, 2], har_xup_f_x, har_yup_f_x, har_fhat_f_x_log10_up,
                      (har_vector_x_up_f_x, har_vector_y_up_f_x),
                      (har_vector_x_up_f_x_im, har_vector_y_up_f_x_im),
                      ymid, bh_y[0], rho_map, with_contours=True)
cs = _draw_flux_panel(axes[0, 2], har_xlow_f_x, har_ylow_f_x, har_fhat_f_x_log10_low,
                      (har_vector_x_low_f_x, har_vector_y_low_f_x),
                      (har_vector_x_low_f_x_im, har_vector_y_low_f_x_im),
                      zmid, bh_z[0], rho_map, with_contours=True)

# Legend proxies, defined here so this cell does not depend on the previous
# figure's cell: the dashed black-hole circle, plus arrow glyphs matching the
# quiver colours (Re f_emu: red, Im f_emu: white).
legend_circle = Line2D([0], [0], color='k', linestyle='dashed', linewidth=1, label='Black hole')
# Name kept for compatibility; the Re arrow is now solid red like the quivers.
black_vec = Line2D([0], [0], marker=r'$\rightarrow$', markersize=18,
                   markerfacecolor='red', markeredgecolor='red', linestyle='None',
                   label=r'Re $\vec{f}_{e\mu}$', alpha=1.0)
# A thin black outline keeps the white glyph visible on the legend frame
# (assumes the default white legend background).
white_vec = Line2D([0], [0], marker=r'$\rightarrow$', markersize=18,
                   markerfacecolor='white', markeredgecolor='black', markeredgewidth=0.5,
                   linestyle='None', label=r'Im $\vec{f}_{e\mu}$', alpha=1.0)
legend = axes[1, 0].legend(handles=[legend_circle, black_vec, white_vec],
                           loc='upper left', frameon=True, fontsize=18)
legend.get_frame().set_alpha(0.9)
for txt in legend.get_texts():
    txt.set_color('Black')

# Snapshot time of each column, written into the bottom-row panels.
time_labels = [r'$t=0.03$ ms', r'$t=0.08$ ms', r'$t=0.32$ ms']
for col, t_label in enumerate(time_labels):
    axes[1, col].text(0.1, 0.15, t_label, transform=axes[1, col].transAxes, fontsize=28,
                      verticalalignment='top', horizontalalignment='left',
                      bbox=dict(facecolor='white', alpha=0.9, edgecolor='none'))

# savefig does not create missing directories, so ensure `plots/` exists first.
Path('plots').mkdir(exist_ok=True)
fig.savefig('plots/f_eu_off_diagonal_densities.png', bbox_inches='tight', dpi=500)