In [1]:
import numpy as np
import xarray as xr
import sys
import glob
import netCDF4 as nc
import os
import sys
import h5py
import scipy.io
from scipy.interpolate import griddata
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
from matplotlib.colors import LogNorm
import cartopy.feature as cfeature
from matplotlib.ticker import MaxNLocator
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.colors as colors
from datetime import date, timedelta, datetime
import pandas
import copy


from os.path import join,expanduser,exists,split
user_home_dir = expanduser('~')
sys.path.append(join(user_home_dir,'ECCOv4-py'))
import ecco_v4_py as ecco

# Suppress warning messages for a cleaner presentation
import warnings
warnings.filterwarnings('ignore')

### What we have so far (February 2024):
### OISST MONTHLY only (keep in mind that the name of the plots has daily in it - needs to be fixed in the future):
    - 1993-2016 and 2004-2016
##### How to run it:
    - lev_or_int = 'oisst'
    - dataset_tag = 'oisst_v2'
    - daily_monthly_plot_tag = 'daily'
    - ecco_monthly_tag_field = 'heat'
    - ecco_monthly_tag = False 
    - ecco_daily_tag = True 
    - test_peaks_tag = False
    - delta_time_tag = '1tstep'
    
### ECCO MONTHLY - HEAT:
    - 1993-2016 and 2004-2016
##### How to run it:
    - lev_or_int = 'zlev01'
    - dataset_tag = 'ECCOv4r4_heat'
    - daily_monthly_plot_tag = 'monthly'
    - ecco_monthly_tag_field = 'heat'
    - ecco_monthly_tag = True 
    - ecco_daily_tag = False  
    - test_peaks_tag = False
    - delta_time_tag = '1tstep'
    
### ECCO monthly OHC: 
    - 1993-2016 and 2004-2016:
##### How to run it:
    - lev_or_int = zlev01
    - dataset_tag = 'ECCOv4r4_heat'
    - ecco_monthly_tag_field = 'ohc_to50m'
    - daily_monthly_plot_tag = monthly
    - ecco_monthly_tag = True
    - ecco_daily_tag = False
    - delta_time_tag = '1tstep'   

### ECCO daily: 
    - 1992-2018 and 2004-2016:
    
##### How to run it:
    - lev_or_int = zlev00 or zlev05 or zlev09
    - dataset_tag = 'ECCOv4r4_heat'
    - ecco_monthly_tag_field = 'heat'
    - daily_monthly_plot_tag = daily
    - ecco_monthly_tag = False
    - ecco_daily_tag = True
    - delta_time_tag = '5tstep' 
    
### Argo: 
    - 2004-2016:
    
##### How to run it:
    - lev_or_int = argo
    - dataset_tag = 'argo_ohc15_50'
    - ecco_monthly_tag_field = 'heat'
    - daily_monthly_plot_tag = daily
    - ecco_monthly_tag = False
    - ecco_daily_tag = True
    - delta_time_tag = '5tstep'  

In [2]:
# FUTURE: we can look at # of events by year for each month (plot one map for each month) - area covered by MHWs
# 1 figure for Janaury:
    # subplots for each year:
        # number of events that start in that month 

# 1 figure:
    #1 subplot for each year:
        # count number of events that start in that year 
        
# using area_mhws_1d plot one panel for each month of the year
# in each panel plot the timseries for that month for each year (one line for each year in each plot)

# maybe add contributions from different budget terms based on seasons 
# (same for changes in time keeping in mind different modes of variability (ENSO)) 


In [8]:
# year_start = 2004  # Define the start year
# year_end = 2016  # Define the end year (for ECCO daily,NOT included, will stop the year before this one)

# Years indicated in the name of input file
year_start_in_filename = 1992  # Define the start year
year_end_in_filename = 2018 # Define the end year (for ECCO daily, NOT included, will stop the year before this one)
# Years included in the input file
year_start_inside_file = 1992  # Define the start year
year_end_inside_file = 2017 # Define the end year (for ECCO daily, NOT included, will stop the year before this one)


# Years indicated in the name of input file
# year_start_in_filename = 1993  # Define the start year
# year_end_in_filename = 2016 # Define the end year (for ECCO daily, NOT included, will stop the year before this one)
# # Years included in the input file
# year_start_inside_file = 1993  # Define the start year
# year_end_inside_file = 2016 # Define the end year (for ECCO daily, NOT included, will stop the year before this one)




# Years to plots
year_start_4plot = 1992 #year_start+1 # +0, +1  # Define the start year
year_end_4plot = 2018 #year_end-2 # -0, -2 # Define the end year (for ECCO daily, NOT included, will stop the year before this one)
# year_end_4plot = 1994 #year_end-2 # -0, -2 # Define the end year (for ECCO daily, NOT included, will stop the year before this one)


# ECCO MONTHLY and OISST and ECCO DAILY SUBSAMPLED (to match other data set temporal extension)
# year_start = 1993  # Define the start year
# year_end = 2016  # Define the end year

# Select the level
# lev_or_int = 'zlev00' 
# lev_or_int = 'zlev01' # JUST FOR ECCO MONTHLY 
# lev_or_int = 'zlev05' 
# lev_or_int = 'zlev09' 
lev_or_int = 'ohc_k0_k5'
# lev_or_int = 'oisst'
# lev_or_int = 'argo'

dataset_tag = 'ECCOv4r4_heat' # argo_ohc15_50 oisst_v2 ECCOv4r4_heat
daily_monthly_plot_tag = 'daily' # 'daily', 'monthly'
ecco_monthly_tag_field = 'ohc_to50m' # JUST FOR ECCO MONTHLY OHC (ohc_to50m) vs LEVELS (or daily ohc) (heat)
ecco_monthly_tag = False # True False
ecco_daily_tag = True # False
test_peaks_tag = False
delta_time_tag = '5tsteps' # '5tsteps' 1 for OISST and ECCO monthly (zlev01), 5 for daily
test_data_alignment = False

grid_and_save_fields_2plot = True

# MHW duration thresholds
MHW_duration_min_threshold = 5 # days (new minimum duration for MHW events) # 5, 30
MHW_duration_max_threshold = np.inf # days (new maximum duration for MHW events) # 30, np.inf
MHW_duration_threshold_tag = str(MHW_duration_min_threshold) + '_' + str(MHW_duration_max_threshold) + '_days'

# load_dir = '/Users/jacoposala/Downloads/figures_Nov3_k9/daily/'
# load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/ECCOv4r4_daily_temp_zlev0_1992_2018/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/ECCOv4r4_daily_temp_zlev0_2004_2016/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/ECCOv4r4_daily_temp_zlev5_1992_2018/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/ECCOv4r4_daily_temp_zlev5_2004_2016/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/ECCOv4r4_daily_temp_zlev9_1992_2018/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/ECCOv4r4_daily_temp_zlev9_2004_2016/output/figures/'
load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/ECCOv4r4_heat_daily_temp_ohc_k0_k5_1992_2018/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/ECCOv4r4_daily_temp_ohc_k0_k5_2004_2016/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/OISST_1993_2016/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/OISST_2004_2016/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/ECCOv4r4_monthly_temp_zlev0_2004_2016/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/ECCOv4r4_monthly_temp_zlev0_1993_2016/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/Argo_2004_2016/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/peaks_test_ECCOv4r4_daily_temp_ohc_k0_k5_1992_2018/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/ECCOv4r4_monthly_temp_ohc_to50m_1993_2016_k0_k5/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/ECCOv4r4_monthly_temp_ohc_to50m_1993_2016_k0_k5/output/figures/'
# load_dir = '/Users/jacoposala/Downloads/Blanca_Outputs/ECCOv4r4_monthly_temp_ohc_to50m_2004_2016_k0_k5/output/figures/'

ECCO_dir = '/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/data/'
save_dir = join(ECCO_dir,'outputs')
path = '/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/data/outputs/nc_files_zlev_or_zint/'
save_path = '/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/summary_maps/' + lev_or_int + '/'
gridded_path = '/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/gridded_vars/'


In [9]:
# How many time stamps to ignore?
if year_start_inside_file != year_start_4plot or year_end_inside_file != year_end_4plot:

    if dataset_tag=='ECCOv4r4_heat' and daily_monthly_plot_tag=='daily':
        sdate = date(year_start_inside_file,1,2)   # start date
        edate = date(year_end_inside_file,12,31)   # end date
        time = pandas.date_range(sdate,edate-timedelta(days=1),freq='d')
        time_mask = np.logical_and(time.year>=year_start_4plot, time.year<=year_end_4plot)
        ind_time_start_slice = str(np.where(time_mask)[0][0])
        ind_time_end_slice = str(np.where(time_mask)[0][-1])
else:
    ind_time_start_slice = '0'
    ind_time_end_slice = '-1'
    
    

In [10]:
# ECCO
if dataset_tag == 'ECCOv4r4_heat':
    # Load ECCO grid data
    XC_lon = xr.open_dataset(path + 'ECCOv4r4_XC_lon_1993_2017.nc')
    YC_lat = xr.open_dataset(path + 'ECCOv4r4_YC_lat_1993_2017.nc')
    Z_depth = xr.open_dataset(path + 'ECCOv4r4_Z_depth_1993_2017.nc')
    # ECCO daily files
    if ecco_daily_tag:
        # OHC
        if lev_or_int == 'ohc_k0_k5':
            file_path = load_dir + f'ECCOv4r4_heat_daily_{year_start_in_filename}_{year_end_in_filename}_prcnt90_noTrend_minLen_5tsteps_withAVE.mat'
        # Heat budget terms
        else:
            file_path = load_dir + f'ECCOv4r4_heat_{lev_or_int}_daily_{year_start_in_filename}_{year_end_in_filename}_prcnt90_noTrend_minLen_5tsteps_withAVE.mat'
    # ECCO monthly files
    if ecco_monthly_tag:
        if ecco_monthly_tag_field == 'ohc_to50m':
            file_path = load_dir + f'ECCOv4r4_{ecco_monthly_tag_field}_{lev_or_int}_{year_start_in_filename}_{year_end_in_filename}_prcnt90_noTrend_minLen_1tsteps_withAVE.mat'
        else:
            file_path = load_dir + f'ECCOv4r4_heat_{lev_or_int}_{year_start_in_filename}_{year_end_in_filename}_prcnt90_noTrend_minLen_1tsteps_withAVE.mat'

# Argo
elif dataset_tag == 'argo_ohc15_50':
    file_path = load_dir + f'argo_ohc15_50_{year_start_in_filename}_{year_end_in_filename}_prcnt90_noTrend_minLen_1tsteps_withAVE.mat'

    # Lat/lon
    XC_lon = np.arange(20.5, 380.5)  # Similar to MATLAB's [20.5:379.5]'
    YC_lat = np.arange(-89.5, 90.5)  # Similar to MATLAB's [-89.5:89.5]'

# OISST
elif dataset_tag == 'oisst_v2':
    file_path = load_dir + f'oisst_v2_{year_start_in_filename}_{year_end_in_filename}_prcnt90_noTrend_minLen_1tsteps_withAVE.mat'
    # Load OISST file that includes lat/lon values
    oisst_dataset = xr.open_dataset('/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/OISSTv2/DATA/sst.mon.mean.nc')
    XC_lon = oisst_dataset.lon.values
    YC_lat = oisst_dataset.lat.values

       
# Load the data selected above
mat_data = h5py.File(file_path, 'r')

# List the keys in the file
# print("Keys in the file:", list(mat_data.keys()))

# Access the data under the key '#refs#'
refs_data = mat_data['#refs#']

# Access the data under the key 'find_MHWs_info'
find_mhws_info_data = mat_data['find_MHWs_info']

# List the keys in the file find_MHWs_info'
# print(find_mhws_info_data.keys())

# Calculate the duration of the decline pahse (as a difference between the total duration and the duration of the onset phase)
decline_duration_in_tsteps = find_mhws_info_data['events_duration_in_tsteps'].value - find_mhws_info_data['onset_duration_in_tsteps'].value


In [11]:
find_mhws_info_data.keys()

<KeysViewHDF5 ['G_advection_declineAve', 'G_advection_eventAve', 'G_advection_onsetAve', 'G_diffusion_declineAve', 'G_diffusion_eventAve', 'G_diffusion_onsetAve', 'G_forcing_declineAve', 'G_forcing_eventAve', 'G_forcing_onsetAve', 'G_total_declineAve', 'G_total_eventAve', 'G_total_onsetAve', 'adv_vConv_declineAve', 'adv_vConv_eventAve', 'adv_vConv_onsetAve', 'data_mhw_tstep_msk', 'data_percentile3d', 'data_used4MHWs', 'data_used4MHWs_declineAve', 'data_used4MHWs_eventAve', 'data_used4MHWs_onsetAve', 'delta_tstep', 'dif_vConv_declineAve', 'dif_vConv_eventAve', 'dif_vConv_onsetAve', 'end_tstep', 'end_tstep_stored_at_peak', 'events_duration_in_tsteps', 'events_number', 'flag_remove_trend', 'onset_duration_in_tsteps', 'peak_tstep', 'peak_tstep_msk', 'peak_value', 'percentile', 'start_tstep', 'start_tstep_msk', 'years']>

In [12]:
# load_path = f'/Volumes/MyPassportForMac/MAC_15/NASA_project/2023/NEW_heatBudgetECCO_daily/data/outputs/metadata/'
load_path = '/Users/jacoposala/Downloads/'
# Load metadata
ECCO_metadata = ['XC_lon', 'YC_lat', 'Z_depth', 'vol', 'area']

# Create a dictionary to store the variables
ecco_data = {}
for ivar in ECCO_metadata:
    file_path = f"{load_path}/ECCOv4r4_{ivar}_1993_2017.nc"
    # Open the dataset and store it in the dictionary
    ecco_data[ivar] = xr.open_dataset(file_path)
    
# Access the variables using the dictionary
XC_lon = ecco_data['XC_lon']
YC_lat = ecco_data['YC_lat']
Z_depth = ecco_data['Z_depth']
vol = ecco_data['vol']
area = ecco_data['area']



In [13]:
# Define data from Matlab output
data_used4MHWs = find_mhws_info_data['data_used4MHWs'].value


In [14]:
# Slice time for data from Matlab output
data_used4MHWs_sel = data_used4MHWs[eval(ind_time_start_slice):eval(ind_time_end_slice)+1, :, :]

time_sel = time[eval(ind_time_start_slice):eval(ind_time_end_slice)+1]


### Onset calculations

In [15]:
# reshape 2d 9495x(8100x13) peak_tstep_2d start_tstep_msk_2d

# onset_mask = deep copy start_tstep_msk_2d

# for loop over time
# for each timestep i_t, loop through those point where start_tstep_msk_2d[i_t,:] == 1 
# for each of these points ix, assign onset_mask[i_t:peak_tstep_2d[i_t,ix]+1,ix] = 1 


# new item in condition_list: onset
# use condition_noMHWs (nan where this is satisfied)

# for decline, use condiiton_noMHWS and onset to find decline


In [16]:
save_onset_tag = False

In [17]:
if save_onset_tag:
    # Define data needed from Matlab output
    peak_tstep = find_mhws_info_data['peak_tstep'].value
    start_tstep_msk = find_mhws_info_data['start_tstep_msk'].value


In [18]:
if save_onset_tag:
    # Slice time 
    peak_tstep_sel = peak_tstep[eval(ind_time_start_slice):eval(ind_time_end_slice)+1, :, :]
    start_tstep_msk_sel = start_tstep_msk[eval(ind_time_start_slice):eval(ind_time_end_slice)+1, :, :]

    # Reshape the arrays to 2d
    peak_tstep_2d = peak_tstep_sel.reshape((np.shape(time_sel)[0], 13*8100))
    start_tstep_msk_2d = start_tstep_msk_sel.reshape((np.shape(time_sel)[0], 13*8100))


In [19]:
if save_onset_tag:
    # Create a mask for onset
    # Replace NaN values with -1
    peak_tstep_2d[np.isnan(peak_tstep_2d)] = -1
    onset_mask = copy.deepcopy(start_tstep_msk_2d)

    # Now, proceed with the slicing operation
    for i_t in range(peak_tstep_2d.shape[0]):
        current_datetime = datetime.now()
        # Print the current datetime
        print("Current datetime:", current_datetime)
        print('i_t', i_t)
        print('-----')
        for ix in range(peak_tstep_2d.shape[1]):
            if start_tstep_msk_2d[i_t, ix] == 1:
                onset_mask[i_t:int(peak_tstep_2d[i_t, ix]) + 1, ix] = 1


In [20]:
if save_onset_tag:
    onset_mask_3d = onset_mask.reshape((np.shape(time_sel)[0], 13, 8100))
    condition_onset = onset_mask_3d.astype(bool)
    np.save(f'/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/line_plots/area_covered_MHWs/condition_onset_{dataset_tag}_{daily_monthly_plot_tag}_{lev_or_int}_{year_start_4plot}_{year_end_4plot}.npy', onset_mask_3d)
else:
    condition_onset = np.load(f'/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/line_plots/area_covered_MHWs/condition_onset_{dataset_tag}_{daily_monthly_plot_tag}_{lev_or_int}_{year_start_4plot}_{year_end_4plot}.npy')


In [21]:
# condition_onset = condition_onset != 0

In [22]:
def reshape_metadata_2_variable(xr_meta, np_array_data):
#     print(type(xr_meta))
    meta_reshaped = xr_meta.reshape(np_array_data.shape[1], np_array_data.shape[2])
    meta_reshaped = np.repeat(np.expand_dims(meta_reshaped, axis=2), np_array_data.shape[0], axis=2)
    meta_reshaped = meta_reshaped.transpose(2, 0, 1)
    return(meta_reshaped)


In [23]:
XC_lon_input = XC_lon.XC_lon.values
YC_lat_input = YC_lat.YC_lat.values
area_input = area.area.values
XC_lon_reshaped = reshape_metadata_2_variable(xr_meta=XC_lon_input, np_array_data=data_used4MHWs_sel)
YC_lat_reshaped = reshape_metadata_2_variable(xr_meta=YC_lat_input, np_array_data=data_used4MHWs_sel)
area_reshaped = reshape_metadata_2_variable(xr_meta=area_input, np_array_data=data_used4MHWs_sel)


In [24]:
# plt.scatter(XC_lon_reshaped[0,:,:].flatten(), YC_lat_reshaped[0,:,:].flatten(), 5, data_used4MHWs_sel[0,:,:].flatten())


In [25]:
# area_tot = np.sum(area.area.values.flatten())


In [26]:
load_and_save_vars_input_tag = False

In [27]:
varnames_load = ['G_advection', 'G_diffusion', 'G_forcing']
varnames = ['G_advection', 'G_diffusion', 'G_forcing']

if load_and_save_vars_input_tag:
    print('load_and_save_vars_input_tag is: ' + str(load_and_save_vars_input_tag))
    # Create a dictionary to store the datasets for each variable - FOR SINGLE LEVELS
    datasets = {}
    years = np.arange(year_start_4plot, year_end_4plot+1)

    for ivar_load, ivar in zip(varnames_load, varnames):
        var_data = []
        path_for_load = '/Volumes/MyPassportForMac/MAC_15/NASA_project/2023/NEW_heatBudgetECCO_daily/data/outputs/nc_files_zlev_or_zint'
    #     for year in years:
    #         print(year)
        file_path = f'{path_for_load}/ECCOv4r4_{ivar_load}_{lev_or_int}_{year_start_in_filename}_{year_end_in_filename}.nc'

        dataset = xr.open_dataset(file_path)

        print('done: load')
    #     if year >= 2004 and ivar_load == 'G_advection': 
    #         ivar = ivar_load 
    #         dataset_box = dataset[ivar+'_cut'].where(lat_lon_bounds, np.nan)
    #     elif year>= 2004 and ivar_load == 'G_diffusion':
    #         ivar = ivar_load
    #         dataset_box = dataset[ivar+'_cut'].where(lat_lon_bounds, np.nan)
    #     else:
        dataset_box = dataset[ivar + '_' + lev_or_int]#.where(lat_lon_bounds, np.nan)
        print('done: box')
        var_data.append(dataset_box)

        # Concatenate the datasets along the time dimension
        var_dataset = xr.concat(var_data, dim='time')

        # Store the variable dataset in the dictionary
        datasets[ivar_load] = var_dataset

        # Save the dataset as a NetCDF file
        output_path = f'/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/line_plots/area_covered_MHWs/lineplots_input_{ivar_load}_{lev_or_int}_dataset.nc'  # Replace with your desired output path
        var_dataset.to_netcdf(output_path)
        print(f'Saved dataset for {ivar_load}')

        dataset.close()
        
else:
    datasets = {}
    print('load_and_save_vars_input_tag is: ' + str(load_and_save_vars_input_tag))
    
    # Define the paths of the saved files
    output_paths = '/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/line_plots/area_covered_MHWs/'
    # Create an empty dictionary to store the loaded datasets
    loaded_datasets = {}

    # Load the datasets
    for ivar in varnames_load:
        var_data = []

        # Open the dataset
        dataset_load = xr.open_dataset(output_paths + f'lineplots_input_{ivar}_{lev_or_int}_dataset.nc')
        print('loading ' + ivar)

        dataset = dataset_load[ivar + '_' + lev_or_int]

        var_data.append(dataset)

        var_dataset = xr.concat(var_data, dim='time')

        datasets[ivar] = var_dataset

        # Store the dataset in the dictionary
    #     loaded_datasets[ivar] = dataset[ivar + '_' + lev_or_int]




load_and_save_vars_input_tag is: False
loading G_advection
loading G_diffusion
loading G_forcing


In [28]:
# data_used4MHWs < data_percentile3d -> nan


In [29]:
# plt.pcolor(find_mhws_info_data['data_used4MHWs'].value[1,:,:])

In [30]:
# key = 'data_used4MHWs'
# data2plot = data_used4MHWs[0,:,:]
# plot_map_from_scattered_TEST(XC_lon, YC_lat, data2plot, key, year_start, year_end)


In [31]:
# key = 'data_used4MHWs'
# data2plot = find_mhws_info_data['data_used4MHWs'][1,:,:]
# plot_map_from_scattered_TEST(XC_lon, YC_lat, data2plot, key, year_start, year_end, stat_tag)


In [32]:
# Open full heat budget terms data, use the same mask, caluclate the area for each time step, 
# where each term is the dominant



In [33]:
# Condition for where data_used is larger than data_percentile 
condition_MHWs = data_used4MHWs_sel > \
find_mhws_info_data['data_percentile3d'].value[eval(ind_time_start_slice):eval(ind_time_end_slice)+1, :, :]


In [34]:
condition_decline = np.logical_and(condition_MHWs, np.logical_not(condition_onset))


In [35]:
# conditions_events_onset_decline = [condition_decline, condition_onset, condition_MHWs]
# conditions_events_onset_decline_names = ['condition_decline', 'condition_onset', 'condition_MHWs']


In [36]:
regions_condition_list = []
regions_condition_list_GLOBAL = []

# GLOBAL
regions_condition_list_GLOBAL.append({"region_name":'Global'})

# GLOBAL NO EQ PACIFIC
regions_condition_list_GLOBAL.append({"lon_min_exclude2":[130, -180],
                   "lon_max_exclude2":[180, -140],
                   "lat_min_exclude2":[-20, -20],
                   "lat_max_exclude2":[20, 20], 
                   "region_name":'Global_no_Eq_Pacific'})

# PACIFIC OCEAN (no Equatorial)
regions_condition_list.append({"lon_min_exclude2":[130, -180, 0, -140],
                   "lon_max_exclude2":[180, -140, 130, 0],
                   "lat_min_exclude2":[-20, -20, -90, -90],
                   "lat_max_exclude2":[20, 20, 90, 90], 
                   "region_name":'Pacific_no_Eq'})

# INDIAN OCEAN
regions_condition_list.append({"lon_min":20,
                   "lon_max":146,
                   "lat_min":-60,
                   "lat_max":30, 
                   "region_name":'Indian'})

# ATLANTIC OCEAN
regions_condition_list.append({"lon_min":-80,
                   "lon_max":-10,
                   "lat_min":-90,
                   "lat_max":90, 
                   "region_name":'Atlantic'})

# SWP 
# regions_condition_list.append({"lon_min":-170.5,
#                    "lon_max":-140.5,
#                    "lat_min":-45.5,
#                    "lat_max":-25.5, 
#                    "region_name":'SWP'})

# # NEP 
# regions_condition_list.append({"lon_min":-150.5,
#                    "lon_max":-134.5,
#                    "lat_min":39.5,
#                    "lat_max":50.5, 
#                    "region_name":'NEP'})
                              
# #TASMAN
# regions_condition_list.append({"lon_min":147,
#                    "lon_max":155,
#                    "lat_min":-45,
#                    "lat_max":-37, 
#                    "region_name":'TASMAN'})



# regions_condition_list_names = ['Global', 'noEq_Pac', 'Pacific', 'SWP', 'NEP', 'TASMAN']
# budget_condition_list_names = ['MHWs', 'MHWs, Forcing > Advection', 
#                                'MHWs, Forcing > Diffusion', 
#                                'MHWs, Advection > Diffusion', 
#                                'MHWs, Forcing > Advection and Diffusion'] 
# # budget_condition_list_names = ['MHWs', 'MHWs, Forcing > Advection'] 
# budget_condition_list_names = ['MHWs', 'Forcing > 0'] 


In [37]:
def reshape_array_to_condition_list(array_input, ind_time_start_slice, ind_time_end_slice):
    arr = array_input.values
    # Reshape the other elements
    reshaped_array = np.reshape(arr, (arr.shape[0], arr.shape[1], -1))

    # Slice along the time dimension only for the second element
    sliced_array = reshaped_array[eval(ind_time_start_slice):eval(ind_time_end_slice) + 1, :, :]

    return(sliced_array)

In [38]:
condition_list = []
condition_list.append({"phase_mhw":condition_MHWs, "budget_condition":[], "tag":'MHW_all'})

################ FORCING ###########################################################################
bfr = reshape_array_to_condition_list(array_input = datasets['G_forcing'] > 0, \
                                      ind_time_start_slice = ind_time_start_slice, \
                                      ind_time_end_slice = ind_time_end_slice)
condition_list.append({"phase_mhw":condition_onset, "budget_condition":bfr, "tag":'MHW_onset_forcing_pos'})

bfr = reshape_array_to_condition_list(array_input = datasets['G_forcing'] < 0, \
                                      ind_time_start_slice = ind_time_start_slice, \
                                      ind_time_end_slice = ind_time_end_slice)
condition_list.append({"phase_mhw":condition_decline, "budget_condition":bfr, "tag":'MHW_decline_forcing_neg'})

# ################# ADVECTION ###########################################################################
# bfr = reshape_array_to_condition_list(array_input = datasets['G_advection'] > 0, \
#                                       ind_time_start_slice = ind_time_start_slice, \
#                                       ind_time_end_slice = ind_time_end_slice)
# condition_list.append({"phase_mhw":condition_onset, "budget_condition":bfr, "tag":'MHW_onset_advection_pos'})

# bfr = reshape_array_to_condition_list(array_input = datasets['G_advection'] < 0, \
#                                       ind_time_start_slice = ind_time_start_slice, \
#                                       ind_time_end_slice = ind_time_end_slice)
# condition_list.append({"phase_mhw":condition_decline, "budget_condition":bfr, "tag":'MHW_decline_advection_neg'})

# ################# DIFFUSION ###########################################################################
# bfr = reshape_array_to_condition_list(array_input = datasets['G_advection'] > 0, \
#                                       ind_time_start_slice = ind_time_start_slice, \
#                                       ind_time_end_slice = ind_time_end_slice)
# condition_list.append({"phase_mhw":condition_onset, "budget_condition":bfr, "tag":'MHW_onset_advection_pos'})

# bfr = reshape_array_to_condition_list(array_input = datasets['G_advection'] < 0, \
#                                       ind_time_start_slice = ind_time_start_slice, \
#                                       ind_time_end_slice = ind_time_end_slice)
# condition_list.append({"phase_mhw":condition_decline, "budget_condition":bfr, "tag":'MHW_decline_advection_neg'})

# budget_condition_list.append(datasets['G_forcing'] < 0) 
# budget_condition_list.append(np.abs(datasets['G_forcing']) > np.abs(datasets['G_advection'])) 
# budget_condition_list.append(np.abs(datasets['G_forcing']) > np.abs(datasets['G_diffusion'])) 
# budget_condition_list.append(np.abs(datasets['G_advection']) > np.abs(datasets['G_diffusion'])) 
# budget_condition_list.append(
#     (datasets['G_forcing']>0) &
#     (datasets['G_forcing']) > (datasets['G_advection']) &
#     ((datasets['G_forcing']) > (datasets['G_diffusion']))
# )



In [39]:
[x["tag"] for x in condition_list]

['MHW_all', 'MHW_onset_forcing_pos', 'MHW_decline_forcing_neg']

In [40]:
[x["region_name"] for x in regions_condition_list]

['Pacific_no_Eq', 'Indian', 'Atlantic']

In [41]:
[x["region_name"] for x in regions_condition_list_GLOBAL]

['Global', 'Global_no_Eq_Pacific']

In [39]:
# plot to check if masks onset and decline were calculated correctly (showing onset, decline, peak, percentile and data)

# datapercentile_threshold = find_mhws_info_data['data_percentile3d'].value[eval(ind_time_start_slice):eval(ind_time_end_slice)+1, :, :]

# plt.figure(figsize = (12,5))
# plt.plot(time_sel, data_used4MHWs_sel[:,1,1], label = 'data')
# plt.plot(time_sel, datapercentile_threshold[:,1,1], label = 'threshold')
# plt.plot(time_sel[condition_onset[:,1,1]], data_used4MHWs_sel[:,1,1][condition_onset[:,1,1]], '*', label = 'onset')
# plt.plot(time_sel[condition_decline[:,1,1]], data_used4MHWs_sel[:,1,1][condition_decline[:,1,1]], 'o', label = 'decline')
# start_date = datetime(1992, 12, 1)
# end_date = datetime(1993, 1, 15 )
# plt.xlim(start_date, end_date)
# plt.xlabel('Time', fontsize = 12)
# plt.ylabel('Data', fontsize = 12)
# plt.ylim([0,35])
# plt.legend(fontsize = 12)
# plt.tight_layout()


In [40]:
ciao

NameError: name 'ciao' is not defined

### Global

In [None]:
bfr_area_results_GLOBAL = {}

for i, icondition in enumerate(condition_list):
    print('Condition: ' + icondition["tag"])
    condition = icondition["phase_mhw"]
    ibudget = icondition["budget_condition"]
    iname = icondition["tag"]

    print('starting copy')
    
    bfr_area_reshaped = np.copy(area_reshaped)
    bfr_area_reshaped[np.isnan(data_used4MHWs_sel)] = 0 # CHECK THIS - NOT WORKING NOW
    
    bfr_area_reshaped_copy = np.copy(bfr_area_reshaped)
    
    print('ending copy')
    
    if condition != []:
        bfr_area_reshaped[~condition] = 0
    if ibudget != []:
        bfr_area_reshaped[~ibudget] = 0
    
    
    for ireg in regions_condition_list_GLOBAL:
        bfr_area_reshaped_reg = np.copy(bfr_area_reshaped)
        bfr_area_reshaped_copy_reg = np.copy(bfr_area_reshaped_copy)

#         if ireg['region_name'] is not 'Global':
        # for regions to keep
        if 'lon_min' in ireg.keys():
            bfr_area_reshaped_reg[XC_lon_reshaped < ireg["lon_min"]] = 0
            bfr_area_reshaped_reg[XC_lon_reshaped > ireg["lon_max"]] = 0
            bfr_area_reshaped_reg[YC_lat_reshaped < ireg["lat_min"]] = 0
            bfr_area_reshaped_reg[YC_lat_reshaped > ireg["lat_max"]] = 0

            bfr_area_reshaped_copy_reg[XC_lon_reshaped < ireg["lon_min"]] = 0
            bfr_area_reshaped_copy_reg[XC_lon_reshaped > ireg["lon_max"]] = 0
            bfr_area_reshaped_copy_reg[YC_lat_reshaped < ireg["lat_min"]] = 0
            bfr_area_reshaped_copy_reg[YC_lat_reshaped > ireg["lat_max"]] = 0             

        # for regions to exclude    
        elif 'lon_min_exclude2' in ireg.keys():
            for ix in np.arange(0, len(ireg['lon_min_exclude2']), 1):
                bfr_area_reshaped_reg[np.logical_and(
                    np.logical_and(XC_lon_reshaped >= ireg["lon_min_exclude2"][ix],
                                   XC_lon_reshaped <= ireg["lon_max_exclude2"][ix]),
                    np.logical_and(YC_lat_reshaped >= ireg["lat_min_exclude2"][ix],
                                   YC_lat_reshaped <= ireg["lat_max_exclude2"][ix]))] = 0 


                bfr_area_reshaped_copy_reg[np.logical_and(
                    np.logical_and(XC_lon_reshaped >= ireg["lon_min_exclude2"][ix],
                                   XC_lon_reshaped <= ireg["lon_max_exclude2"][ix]),
                    np.logical_and(YC_lat_reshaped >= ireg["lat_min_exclude2"][ix],
                                   YC_lat_reshaped <= ireg["lat_max_exclude2"][ix]))] = 0
        print('done ireg: ' + ireg['region_name'])
        
        bfr_area_tot = np.sum(bfr_area_reshaped_copy_reg.flatten())
        bfr_area_reshaped_reg_1d_2plot = np.sum(bfr_area_reshaped_reg.reshape(bfr_area_reshaped_reg.shape[0], -1), axis=1)
    
        # create list with bfr_area_reshaped_reg_1d_2plot and bfr_area_tot for each region
        # Store the results for the current region with a unique name
        region_name = ireg['region_name']
        bfr_area_results_GLOBAL[icondition['tag'] + '_' + region_name + "_bfr_area_reshaped_reg_1d_2plot"] = bfr_area_reshaped_reg_1d_2plot
        bfr_area_results_GLOBAL[icondition['tag'] + '_' + region_name + "_bfr_area_tot"] = bfr_area_tot
#         bfr_area_results_GLOBAL['condition_list'] = condition_list[i]
        path2saveinfo4plots = '/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/summary_maps/area_covered/'
        np.save(path2saveinfo4plots + f'info4plots_{icondition["tag"]}_{region_name}.npy', bfr_area_results_GLOBAL)
        
        del bfr_area_tot, bfr_area_reshaped_reg, bfr_area_reshaped_copy_reg
    del bfr_area_reshaped, bfr_area_reshaped_copy, condition, ibudget, iname
    print('done icondition iteration')
    
    
#         ciao
        # in another cell: plot
#         ciao
        # Loop through subplots and time subsets
#         for i, subplot in enumerate(axes):
#             # Define start and end indices for the current time subset
#             start_index = i * time_interval
#             end_index = start_index + time_interval - 1

# #             scaling_factor = np.nanmax(area_mhws_1d/area_tot * 100)
# #             scaling_factor_str = str('{0:.2f}'.format(scaling_factor))

#             subplot.plot(time[start_index:end_index+1],
#                          (bfr_area_reshaped_reg_1d_2plot[start_index:end_index+1]/bfr_area_tot * 100), label=ireg['region_name'] + ' ('  + ')')

#             print('done plot')

#             legend_labels.append(ireg['region_name'] + ' ('  + ')')

#             # Customize the subplot (e.g., title, labels)
#             subplot.set_title(f'Time period: {time[start_index]} - {time[end_index]}', fontsize=12)
#             subplot.set_xlabel('Time', fontsize=12)
#             subplot.set_ylabel('%', fontsize=12)
#             print('done title etc')

#         # Common title and legend for the entire figure
#         fig.suptitle(f'Percentage of area covered by MHWs - {iname}', fontsize=16)
#         fig.legend(loc='lower center', bbox_to_anchor=(0.5, 0), ncol=3, labels=legend_labels)
        
        

Condition: MHW_all
starting copy
ending copy
done ireg: Global
done ireg: Global_no_Eq_Pacific
done icondition iteration
Condition: MHW_onset_forcing_pos
starting copy
ending copy


In [None]:
bfr_area_results_GLOBAL.keys()

In [None]:
np.save(bfr_area_results_GLOBAL)

In [None]:
plt.plot(bfr_area_results_GLOBAL['MHW_onset_forcing_pos_Global_bfr_area_reshaped_reg_1d_2plot'])

In [None]:
plt.plot(bfr_area_results_GLOBAL['MHW_onset_forcing_pos_Global_bfr_area_reshaped_reg_1d_2plot'])

In [None]:
# Create an empty dictionary to store legend labels
legend_labels = {}

# Loop through each condition
for i, icondition in enumerate(condition_list):
    print(icondition["tag"])
    condition_name = icondition["tag"]

    # Define the number of subplots (rows)
    num_subplots = 5

    # Calculate the time step interval for each subplot (assuming equal intervals)
    time_interval = int(len(time) / num_subplots)

    # Create a figure with 5 subplots
    fig, axes = plt.subplots(num_subplots, 1, figsize=(20, 12))
    fig.subplots_adjust(hspace=0.6)

    # Loop through subplots and time subsets
    for j, subplot in enumerate(axes):
        # Define start and end indices for the current time subset
        start_index = j * time_interval
        end_index = start_index + time_interval - 1

        # Plot data for each region
        for region_name, bfr_area_reshaped_reg_1d_2plot in bfr_area_results.items():
            # Check if the region name contains the condition name and the specific plot data
            if condition_name in region_name and "bfr_area_reshaped_reg_1d_2plot" in region_name:
                # Extract the total area for normalization
                total_area_key = region_name.replace("_reshaped_reg_1d_2plot", "_tot")
                total_area = bfr_area_results.get(total_area_key, 1)  # Default to 1 if not found to avoid division by zero
                # Plot the data
                subplot.plot(time[start_index:end_index+1],
                             (bfr_area_results[region_name][start_index:end_index+1] / total_area * 100),
                             label=region_name.replace(f"_{condition_name}_bfr_area_reshaped_reg_1d_2plot", ""))

                # Add legend label to the list if not already added
                if region_name.replace(f"_{condition_name}_bfr_area_reshaped_reg_1d_2plot", "") not in legend_labels:
                    legend_labels[region_name.replace(f"_{condition_name}_bfr_area_reshaped_reg_1d_2plot", "")] = condition_name

        # Customize the subplot (e.g., title, labels)
        subplot.set_title(f'Time period: {time[start_index]} - {time[end_index]}', fontsize=12)
        subplot.set_xlabel('Time', fontsize=12)
        subplot.set_ylabel('%', fontsize=12)
        fig.subplots_adjust(top=0.93)

    # Common title and legend for the entire figure
    fig.suptitle(f'Percentage of area covered by MHWs - {condition_name}', fontsize=16)

    # Create common legend with the same colors as the plot
    handles, labels = [], []
    for label, condition_name in legend_labels.items():
        # Find the first line with the label and extract its color
        for line in axes[0].lines:
            if line.get_label() == label:
                color = line.get_color()
                break
        handles.append(plt.Line2D([0], [0], linestyle='-', color=color, label=label))
        labels.append(label)
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 0.05), ncol=3)


    # Show or save the plot
    plt.show()  # or plt.savefig('plot_name.png') to save the plot


In [None]:
plt.scatter(XC_lon_reshaped[0,:,:].flatten(), YC_lat_reshaped[0,:,:].flatten(), bfr_area_reshaped_copy_reg[0,:,:].flatten())



In [None]:
plt.scatter(XC_lon_reshaped[0,:,:].flatten(), YC_lat_reshaped[0,:,:].flatten(), bfr_area_reshaped_reg[0,:,:].flatten())



In [None]:
plt.plot(bfr_area_reshaped_reg_1d_2plot/bfr_area_tot)

In [None]:


# one plot for each budget condition and for each onset/decline/MHWs condition - GLOBAL
for i, icondition in enumerate(condition_list):
    print(icondition["tag"])
    condition = icondition["phase_mhw"]
    ibudget = icondition["budget_condition"]

    # Define the number of subplots (rows)
    num_subplots = 5

    # Calculate the time step interval for each subplot (assuming equal intervals)
    time_interval = int(len(time) / num_subplots)

    # Create a figure with 5 subplots
    fig, axes = plt.subplots(num_subplots, 1, figsize=(20, 12))  # Adjust figsize as needed
    fig.subplots_adjust(hspace=0.7)  # Adjust as needed

    legend_labels = []

    print('starting copy')

    bfr_data_used4MHWs_sel = copy.deepcopy(data_used4MHWs_sel)
    bfr_data_used4MHWs_sel[~condition] = np.nan

    if type(ibudget) is not list:
        bfr_data_used4MHWs_sel[~ibudget] = np.nan
    print('done bfr_data_used4MHWs_sel')

    for ireg in regions_condition_list_GLOBAL:
        bfr_data_used4MHWs_sel_reg = copy.deepcopy(bfr_data_used4MHWs_sel)

        if ireg['region_name'] is not 'Global':
            # for regions to keep
            if 'lon_min' in ireg.keys():
                bfr_data_used4MHWs_sel_reg[XC_lon_reshaped < ireg["lon_min"]] = np.nan
                bfr_data_used4MHWs_sel_reg[XC_lon_reshaped > ireg["lon_max"]] = np.nan
                bfr_data_used4MHWs_sel_reg[YC_lat_reshaped < ireg["lat_min"]] = np.nan
                bfr_data_used4MHWs_sel_reg[YC_lat_reshaped > ireg["lat_max"]] = np.nan
            # for regions to exclude    
            elif 'lon_min_exclude2' in ireg.keys():
                for ix in np.arange(0, len(ireg['lon_min_exclude2']), 1):
                    bfr_data_used4MHWs_sel_reg[np.logical_and(
                        np.logical_and(XC_lon_reshaped >= ireg["lon_min_exclude2"][ix],
                                       XC_lon_reshaped <= ireg["lon_max_exclude2"][ix]),
                        np.logical_and(YC_lat_reshaped >= ireg["lat_min_exclude2"][ix],
                                       YC_lat_reshaped <= ireg["lat_max_exclude2"][ix]))] = np.nan 
        print('done ireg')

        # Loop through subplots and time subsets
        for i, subplot in enumerate(axes):
            # Define start and end indices for the current time subset
            start_index = i * time_interval
            end_index = start_index + time_interval - 1

            # Subset data for the current time period
            bfr_data_subset = bfr_data_used4MHWs_sel_reg[start_index:end_index+1, :]

            # Your plotting logic using bfr_data_subset
            mhw_mask = np.ones(bfr_data_subset.shape)
            mhw_mask[np.isnan(bfr_data_subset)] = 0
            print('done mhw_mask')

            area_mhws = area_reshaped[start_index:end_index+1, :] * mhw_mask
            area_mhws_2d = area_mhws.reshape(bfr_data_subset.shape[0], -1)
            area_mhws_1d = np.sum(area_mhws_2d, axis=1)
            print('done area')

            scaling_factor = np.nanmax(area_mhws_1d/area_tot * 100)
            scaling_factor_str = str('{0:.2f}'.format(scaling_factor))

            subplot.plot(time[start_index:end_index+1],
                         (area_mhws_1d/area_tot * 100) / scaling_factor, label=ireg['region_name'] + ' (' + scaling_factor_str + ')')

            print('done plot')

            legend_labels.append(ireg['region_name'] + ' (' + scaling_factor_str + ')')


            # Customize the subplot (e.g., title, labels)
            subplot.set_title(f'Time period: {time[start_index]} - {time[end_index]}', fontsize=12)
            subplot.set_xlabel('Time', fontsize=12)
            subplot.set_ylabel('%', fontsize=12)
            print('done title etc')

        # Common title and legend for the entire figure
        fig.suptitle(f'Percentage of area covered by MHWs - {ibudget_names} - {condition_name}', fontsize=16)
        fig.legend(loc='lower center', bbox_to_anchor=(0.5, 0), ncol=3, labels=legend_labels)
#         del bfr_data_used4MHWs_sel
#         del area_mhws
#         del area_mhws_2d
#         del area_mhws_1d
#         del mhw_mask
#         del bfr_data_subset

    # Save the figure
#         plt.savefig(f'/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/line_plots/area_covered_MHWs/area_covered_{ibudget_names}_{condition_name}_subplots.png', dpi=1000)
    plt.show()  # Display the figure
    plt.close(fig)  # Close the figure to avoid memory issues
    print('end figure')
ciao
# Your original plotting code (excluding the loops) using the full bfr_data_used4MHWs_sel can go here (optional)


In [None]:
# import matplotlib.pyplot as plt
# import numpy as np

# # one plot for each budget condition and for each onset/decline/MHWs condition - GLOBAL
# for ibudget, ibudget_names in zip(budget_condition_list_reshaped, budget_condition_list_names):
#     print('restart')
#     print(ibudget_names)
#     for condition, condition_name in zip(conditions_events_onset_decline, conditions_events_onset_decline_names):
#         print(condition_name)

#         # Define the number of subplots (rows)
#         num_subplots = 5

#         # Calculate the time step interval for each subplot (assuming equal intervals)
#         time_interval = int(len(time) / num_subplots)

#         # Create a figure with 5 subplots
#         fig, axes = plt.subplots(num_subplots, 1, figsize=(20, 12))  # Adjust figsize as needed
#         fig.subplots_adjust(hspace=0.7)  # Adjust as needed
        
#         legend_labels = []

#         print('starting copy')
        
#         bfr_data_used4MHWs_sel = data_used4MHWs_sel.copy()
#         bfr_data_used4MHWs_sel[~condition] = np.nan

#         if type(ibudget) is not list:
#             bfr_data_used4MHWs_sel[~ibudget] = np.nan
#         print('done bfr_data_used4MHWs_sel')

#         for ireg in regions_condition_list_GLOBAL:
        
#             if ireg['region_name'] is not 'Global':
#                 # for regions to keep
#                 if 'lon_min' in ireg.keys():
#                     bfr_data_used4MHWs_sel[XC_lon_reshaped < ireg["lon_min"]] = np.nan
#                     bfr_data_used4MHWs_sel[XC_lon_reshaped > ireg["lon_max"]] = np.nan
#                     bfr_data_used4MHWs_sel[YC_lat_reshaped < ireg["lat_min"]] = np.nan
#                     bfr_data_used4MHWs_sel[YC_lat_reshaped > ireg["lat_max"]] = np.nan
#                 # for regions to exclude    
#                 elif 'lon_min_exclude2' in ireg.keys():
#                     for ix in np.arange(0, len(ireg['lon_min_exclude2']), 1):
#                         bfr_data_used4MHWs_sel[np.logical_and(
#                             np.logical_and(XC_lon_reshaped >= ireg["lon_min_exclude2"][ix],
#                                            XC_lon_reshaped <= ireg["lon_max_exclude2"][ix]),
#                             np.logical_and(YC_lat_reshaped >= ireg["lat_min_exclude2"][ix],
#                                            YC_lat_reshaped <= ireg["lat_max_exclude2"][ix]))] = np.nan 
#             print('done ireg')
                        
#             # Loop through subplots and time subsets
#             for i, subplot in enumerate(axes):
#                 # Define start and end indices for the current time subset
#                 start_index = i * time_interval
#                 end_index = start_index + time_interval - 1

#                 # Subset data for the current time period
#                 bfr_data_subset = bfr_data_used4MHWs_sel[start_index:end_index+1, :]

#                 # Your plotting logic using bfr_data_subset
#                 mhw_mask = np.ones(bfr_data_subset.shape)
#                 mhw_mask[np.isnan(bfr_data_subset)] = 0
#                 print('done mhw_mask')

#                 area_mhws = area_reshaped[start_index:end_index+1, :] * mhw_mask
#                 area_mhws_2d = area_mhws.reshape(bfr_data_subset.shape[0], -1)
#                 area_mhws_1d = np.sum(area_mhws_2d, axis=1)
#                 print('done area')

#                 scaling_factor = np.nanmax(area_mhws_1d/area_tot * 100)
#                 scaling_factor_str = str('{0:.2f}'.format(scaling_factor))

#                 subplot.plot(time[start_index:end_index+1],
#                              (area_mhws_1d/area_tot * 100) / scaling_factor, label=ireg['region_name'] + ' (' + scaling_factor_str + ')')

#                 print('done plot')
                
#                 legend_labels.append(ireg['region_name'] + ' (' + scaling_factor_str + ')')

                    
#                 # Customize the subplot (e.g., title, labels)
#                 subplot.set_title(f'Time period: {time[start_index]} - {time[end_index]}', fontsize=12)
#                 subplot.set_xlabel('Time', fontsize=12)
#                 subplot.set_ylabel('%', fontsize=12)
#                 print('done title etc')

#             # Common title and legend for the entire figure
#             fig.suptitle(f'Percentage of area covered by MHWs - {ibudget_names} - {condition_name}', fontsize=16)
#             fig.legend(loc='lower center', bbox_to_anchor=(0.5, 0), ncol=3, labels=legend_labels)
#             del bfr_data_used4MHWs_sel
#             del area_mhws
#             del area_mhws_2d
#             del area_mhws_1d
#             del mhw_mask
#             del bfr_data_subset

#         # Save the figure
# #         plt.savefig(f'/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/line_plots/area_covered_MHWs/area_covered_{ibudget_names}_{condition_name}_subplots.png', dpi=1000)
#         plt.show()  # Display the figure
#         plt.close(fig)  # Close the figure to avoid memory issues
#         print('end figure')
#     ciao
#     # Your original plotting code (excluding the loops) using the full bfr_data_used4MHWs_sel can go here (optional)


In [None]:
ciao

In [None]:
# import matplotlib.pyplot as plt
# import numpy as np
# import copy


# # One plot for each combination of ibudget and condition
# for ibudget, ibudget_names in zip(budget_condition_list_reshaped, budget_condition_list_names):
#     print('restart')
#     print(ibudget_names)
#     for condition, condition_name in zip(conditions_events_onset_decline, conditions_events_onset_decline_names):
#         print(condition_name)

#         # Define the number of subplots (rows)
#         num_subplots = 5

#         # Calculate the time step interval for each subplot (assuming equal intervals)
#         time_interval = int(len(time) / num_subplots)

#         # Create a figure with 5 subplots
#         fig, axes = plt.subplots(num_subplots, 1, figsize=(20, 12))  # Adjust figsize as needed
#         fig.subplots_adjust(hspace=0.7)  # Adjust as needed

#         legend_labels = []

#         bfr_data_used4MHWs_sel_copy = copy.deepcopy(data_used4MHWs_sel)
#         bfr_data_used4MHWs_sel_copy[~condition] = np.nan

#         if not isinstance(ibudget, list):
#             bfr_data_used4MHWs_sel_copy[~ibudget] = np.nan

#         for ireg in regions_condition_list_GLOBAL:
#             if ireg['region_name'] != 'Global':
#                 if 'lon_min' in ireg.keys():
#                     bfr_data_used4MHWs_sel_copy[XC_lon_reshaped < ireg["lon_min"]] = np.nan
#                     bfr_data_used4MHWs_sel_copy[XC_lon_reshaped > ireg["lon_max"]] = np.nan
#                     bfr_data_used4MHWs_sel_copy[YC_lat_reshaped < ireg["lat_min"]] = np.nan
#                     bfr_data_used4MHWs_sel_copy[YC_lat_reshaped > ireg["lat_max"]] = np.nan
#                 elif 'lon_min_exclude2' in ireg.keys():
#                     for ix in range(len(ireg['lon_min_exclude2'])):
#                         bfr_data_used4MHWs_sel_copy[
#                             np.logical_and(
#                                 np.logical_and(XC_lon_reshaped >= ireg["lon_min_exclude2"][ix],
#                                                XC_lon_reshaped <= ireg["lon_max_exclude2"][ix]),
#                                 np.logical_and(YC_lat_reshaped >= ireg["lat_min_exclude2"][ix],
#                                                YC_lat_reshaped <= ireg["lat_max_exclude2"][ix]))
#                         ] = np.nan

#             for i, subplot in enumerate(axes):
#                 start_index = i * time_interval
#                 end_index = start_index + time_interval - 1

#                 bfr_data_subset = bfr_data_used4MHWs_sel_copy[start_index:end_index + 1, :]
#                 mhw_mask = np.ones(bfr_data_subset.shape)
#                 mhw_mask[np.isnan(bfr_data_subset)] = 0

#                 area_mhws = area_reshaped[start_index:end_index + 1, :] * mhw_mask
#                 area_mhws_2d = area_mhws.reshape(bfr_data_subset.shape[0], -1)
#                 area_mhws_1d = np.sum(area_mhws_2d, axis=1)

#                 scaling_factor = np.nanmax(area_mhws_1d / area_tot * 100)
#                 scaling_factor_str = str('{0:.2f}'.format(scaling_factor))

#                 subplot.plot(time[start_index:end_index + 1],
#                              (area_mhws_1d / area_tot * 100) / scaling_factor,
#                              label=ireg['region_name'] + ' (' + scaling_factor_str + ')')

#                 legend_labels.append(ireg['region_name'] + ' (' + scaling_factor_str + ')')

#                 subplot.set_title(f'Time period: {time[start_index]} - {time[end_index]}', fontsize=12)
#                 subplot.set_xlabel('Time', fontsize=12)
#                 subplot.set_ylabel('%', fontsize=12)

#         fig.suptitle(f'Percentage of area covered by MHWs - {ibudget_names} - {condition_name}', fontsize=16)
#         fig.legend(loc='lower center', bbox_to_anchor=(0.5, 0), ncol=3, labels=legend_labels)

#         plt.show()
#         plt.close(fig)
#         print('end figure')

        

### Plots for Global 

In [None]:
# one plot for each budget condition and for each onset/decline/MHWs condition - GLOBAL
for ibudget, ibudget_names in zip(budget_condition_list_reshaped, budget_condition_list_names):
    for condition, condition_name in zip(conditions_events_onset_decline, conditions_events_onset_decline_names):
        plt.figure(figsize=(20,8))  # Create a new figure for each combination of budget condition and condition onset
        for ireg in regions_condition_list_GLOBAL:
            bfr_data_used4MHWs_sel = copy.deepcopy(data_used4MHWs_sel)
            bfr_data_used4MHWs_sel[~condition] = np.nan
            
            if type(ibudget) is not list:
                bfr_data_used4MHWs_sel[~ibudget] = np.nan

            if ireg['region_name'] is not 'Global':
                # for regions to keep
                if 'lon_min' in ireg.keys():
                    bfr_data_used4MHWs_sel[XC_lon_reshaped < ireg["lon_min"]] = np.nan
                    bfr_data_used4MHWs_sel[XC_lon_reshaped > ireg["lon_max"]] = np.nan
                    bfr_data_used4MHWs_sel[YC_lat_reshaped < ireg["lat_min"]] = np.nan
                    bfr_data_used4MHWs_sel[YC_lat_reshaped > ireg["lat_max"]] = np.nan
                # for regions to exclude    
                elif 'lon_min_exclude2' in ireg.keys():
                    for ix in np.arange(0, len(ireg['lon_min_exclude2']), 1):
                        bfr_data_used4MHWs_sel[np.logical_and(
                            np.logical_and(XC_lon_reshaped >= ireg["lon_min_exclude2"][ix],
                                           XC_lon_reshaped <= ireg["lon_max_exclude2"][ix]),
                            np.logical_and(YC_lat_reshaped >= ireg["lat_min_exclude2"][ix],
                                           YC_lat_reshaped <= ireg["lat_max_exclude2"][ix]))] = np.nan    

            mhw_mask = np.ones(bfr_data_used4MHWs_sel.shape)
            mhw_mask[np.isnan(bfr_data_used4MHWs_sel)] = 0

            area_mhws = area_reshaped * mhw_mask
            area_mhws_2d = area_mhws.reshape(bfr_data_used4MHWs_sel.shape[0], -1)
            area_mhws_1d = np.sum(area_mhws_2d, axis=1)

            scaling_factor = np.nanmax(area_mhws_1d/area_tot * 100)
            scaling_factor_str = str('{0:.2f}'.format(scaling_factor))
            plt.plot(time[eval(ind_time_start_slice):eval(ind_time_end_slice)+1],
                     (area_mhws_1d/area_tot * 100) / scaling_factor, label=ireg['region_name'] + ' (' + scaling_factor_str + ')')

        plt.legend()
        plt.title('Percentage of area covered by MHWs where condition "' + ibudget_names + ' - ' + condition_name + '" is met', fontsize=16)
        plt.xlabel('Time', fontsize=16)
        plt.ylabel('%', fontsize=16)
#         plt.savefig(f'/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/line_plots/area_covered_MHWs/area_covered_global_{ibudget_names}_{condition_name}.png', dpi=1000)
        plt.show()


### Plots for NO Global 

In [None]:
# one plot for each budget condition and for each onset/decline/MHWs condition - NO GLOBAL
for ibudget, ibudget_names in zip(budget_condition_list_reshaped, budget_condition_list_names):
    for condition, condition_name in zip(conditions_events_onset_decline, conditions_events_onset_decline_names):
        plt.figure(figsize=(20,8))  # Create a new figure for each combination of budget condition and condition onset
        for ireg in regions_condition_list:
            bfr_data_used4MHWs_sel = copy.deepcopy(data_used4MHWs_sel)
            bfr_data_used4MHWs_sel[~condition] = np.nan

            if type(ibudget) is not list:
                bfr_data_used4MHWs_sel[~ibudget] = np.nan

            if ireg['region_name'] is not 'Global':
                # for regions to keep
                if 'lon_min' in ireg.keys():
                    bfr_data_used4MHWs_sel[XC_lon_reshaped < ireg["lon_min"]] = np.nan
                    bfr_data_used4MHWs_sel[XC_lon_reshaped > ireg["lon_max"]] = np.nan
                    bfr_data_used4MHWs_sel[YC_lat_reshaped < ireg["lat_min"]] = np.nan
                    bfr_data_used4MHWs_sel[YC_lat_reshaped > ireg["lat_max"]] = np.nan
                # for regions to exclude    
                elif 'lon_min_exclude2' in ireg.keys():
                    for ix in np.arange(0, len(ireg['lon_min_exclude2']), 1):
                        bfr_data_used4MHWs_sel[np.logical_and(
                            np.logical_and(XC_lon_reshaped >= ireg["lon_min_exclude2"][ix],
                                           XC_lon_reshaped <= ireg["lon_max_exclude2"][ix]),
                            np.logical_and(YC_lat_reshaped >= ireg["lat_min_exclude2"][ix],
                                           YC_lat_reshaped <= ireg["lat_max_exclude2"][ix]))] = np.nan    

            mhw_mask = np.ones(bfr_data_used4MHWs_sel.shape)
            mhw_mask[np.isnan(bfr_data_used4MHWs_sel)] = 0

            area_mhws = area_reshaped * mhw_mask
            area_mhws_2d = area_mhws.reshape(bfr_data_used4MHWs_sel.shape[0], -1)
            area_mhws_1d = np.sum(area_mhws_2d, axis=1)

            scaling_factor = np.nanmax(area_mhws_1d/area_tot * 100)
            scaling_factor_str = str('{0:.2f}'.format(scaling_factor))
            plt.plot(time[eval(ind_time_start_slice):eval(ind_time_end_slice)+1],
                     (area_mhws_1d/area_tot * 100) / scaling_factor, label=ireg['region_name'] + ' (' + scaling_factor_str + ')', linewidth = 2)

        plt.legend()
        plt.title('Percentage of area covered by MHWs where condition "' + ibudget_names + ' - ' + condition_name + '" is met', fontsize=16)
        plt.xlabel('Time', fontsize=16)
        plt.ylabel('%', fontsize=16)
        plt.savefig(f'/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/code/line_plots/area_covered_MHWs/area_covered_{ibudget_names}_{condition_name}.png', dpi=1000)
        plt.show()


In [None]:
ciao

# DO NOT TOUCH

In [None]:
# one plot for each budget condition and for each onset/decline/MHWs condition
for ibudget, ibudget_names in zip(budget_condition_list_reshaped, budget_condition_list_names):
    for condition, condition_name in zip(conditions_events_onset_decline, conditions_events_onset_decline_names):
        plt.figure(figsize=(15,6))  # Create a new figure for each combination of budget condition and condition onset
        for ireg in regions_condition_list:
            bfr_data_used4MHWs_sel = copy.deepcopy(data_used4MHWs_sel)
            bfr_data_used4MHWs_sel[~condition] = np.nan

            if type(ibudget) is not list:
                bfr_data_used4MHWs_sel[~ibudget] = np.nan

            if ireg['region_name'] is not 'Global':
                # for regions to keep
                if 'lon_min' in ireg.keys():
                    bfr_data_used4MHWs_sel[XC_lon_reshaped < ireg["lon_min"]] = np.nan
                    bfr_data_used4MHWs_sel[XC_lon_reshaped > ireg["lon_max"]] = np.nan
                    bfr_data_used4MHWs_sel[YC_lat_reshaped < ireg["lat_min"]] = np.nan
                    bfr_data_used4MHWs_sel[YC_lat_reshaped > ireg["lat_max"]] = np.nan
                # for regions to exclude    
                elif 'lon_min_exclude2' in ireg.keys():
                    for ix in np.arange(0, len(ireg['lon_min_exclude2']), 1):
                        bfr_data_used4MHWs_sel[np.logical_and(
                            np.logical_and(XC_lon_reshaped >= ireg["lon_min_exclude2"][ix],
                                           XC_lon_reshaped <= ireg["lon_max_exclude2"][ix]),
                            np.logical_and(YC_lat_reshaped >= ireg["lat_min_exclude2"][ix],
                                           YC_lat_reshaped <= ireg["lat_max_exclude2"][ix]))] = np.nan    

            mhw_mask = np.ones(bfr_data_used4MHWs_sel.shape)
            mhw_mask[np.isnan(bfr_data_used4MHWs_sel)] = 0

            area_mhws = area_reshaped * mhw_mask
            area_mhws_2d = area_mhws.reshape(bfr_data_used4MHWs_sel.shape[0], -1)
            area_mhws_1d = np.sum(area_mhws_2d, axis=1)

            scaling_factor = np.nanmax(area_mhws_1d/area_tot * 100)
            scaling_factor_str = str('{0:.2f}'.format(scaling_factor))
            plt.plot(time[eval(ind_time_start_slice):eval(ind_time_end_slice)+1],
                     (area_mhws_1d/area_tot * 100) / scaling_factor, label=ireg['region_name'] + ' (' + scaling_factor_str + ')')

        plt.legend()
        plt.title('Percentage of area covered by MHWs where condition "' + ibudget_names + ' - ' + condition_name + '" is met', fontsize=16)
        plt.xlabel('Time', fontsize=16)
        plt.ylabel('%', fontsize=16)
        plt.show()


In [None]:
# Plot to check regions excluded/kept

# import cartopy.crs as ccrs
# import matplotlib.pyplot as plt
# import numpy as np
# import shapely.geometry as sgeom

# # Define bounding box coordinates and labels
# bounding_boxes = [
#     ((20, 146, -60, 30), 'Indian (2keep)', 'blue', 'None'),  
#     ((-80, 20, -90, 90), 'Atlantic (2keep)', 'green', 'None'),      
#     ((-170.5, -140.5, -45.5, -25.5), 'SWP (2keep)', 'orange', 'None'),      
#     ((-150.5, -134.5, 39.5, 50.5), 'NEP (2keep)', 'purple', 'None'), 
#     ((147, 155, -45, -37), 'TASMAN (2keep)', 'gold', 'None')     
# ]

# # Create a PlateCarree projection
# projection = ccrs.PlateCarree()

# # Create a figure and axis with Cartopy projection
# fig, ax = plt.subplots(subplot_kw={'projection': projection}, figsize=(20,6))

# # Set extent to global
# ax.set_global()

# # Add coastlines
# ax.coastlines()

# # Loop through bounding boxes and add them to the plot
# for bbox, label, color, color_shading in bounding_boxes:
#     lon_min, lon_max, lat_min, lat_max = bbox
#     bounding_box = sgeom.box(lon_min, lat_min, lon_max, lat_max)
#     ax.add_geometries([bounding_box], crs=projection, edgecolor=color, facecolor=color_shading, linewidth=2)
    
#     # Add label
#     ax.text((lon_min + lon_max) / 2, (lat_min + lat_max) / 2, label, transform=projection,
#             horizontalalignment='center', verticalalignment='center', fontsize=10, color=color)

# # Add gridlines
# ax.gridlines(draw_labels=True, linestyle='--')

# # Add title
# plt.title('Global Map with Multiple Bounding Boxes')

# # Show plot
# plt.show()



# import cartopy.crs as ccrs
# import matplotlib.pyplot as plt
# import numpy as np
# import shapely.geometry as sgeom

# # Define bounding box coordinates and labels
# bounding_boxes = [    
#     ((-180, -100, 20, 90), 'Pacific_noEQ', 'red', 'None'),      
#     ((-180, -70, -20, -90), 'Pacific_noEQ', 'red', 'None'),      
#     ((130, 180, -20, -90), 'Pacific_noEQ', 'red', 'None'),      
#     ((130, 180, 20, 90), 'Pacific_noEQ', 'red', 'None')
# ]

# # Create a PlateCarree projection
# projection = ccrs.PlateCarree()

# # Create a figure and axis with Cartopy projection
# fig, ax = plt.subplots(subplot_kw={'projection': projection}, figsize=(20,6))

# # Set extent to global
# ax.set_global()

# # Add coastlines
# ax.coastlines()

# # Loop through bounding boxes and add them to the plot
# for bbox, label, color, color_shading in bounding_boxes:
#     lon_min, lon_max, lat_min, lat_max = bbox
#     bounding_box = sgeom.box(lon_min, lat_min, lon_max, lat_max)
#     ax.add_geometries([bounding_box], crs=projection, edgecolor=color, facecolor=color_shading, linewidth=2)
    
#     # Add label
#     ax.text((lon_min + lon_max) / 2, (lat_min + lat_max) / 2, label, transform=projection,
#             horizontalalignment='center', verticalalignment='center', fontsize=10, color=color)

# # Add gridlines
# ax.gridlines(draw_labels=True, linestyle='--')

# # Add title
# plt.title('Global Map with Multiple Bounding Boxes')

# # Show plot
# plt.show()


In [None]:
for ireg in area_mhws_store.keys():
    plt.figure(figsize=(15,6))
    bfr_area_msk = area_mhws_store[ireg].flatten()>0
    bfr_XC = XC_lon_reshaped.flatten()[bfr_area_msk]
    bfr_YC = YC_lat_reshaped.flatten()[bfr_area_msk]
    plt.hist2d(bfr_XC, bfr_YC, bins=360)
    plt.colorbar()
#     plt.plot(XC_lon_reshaped.flatten(), YC_lat_reshaped.flatten(), area_mhws_store[ireg].flatten(), linestyle = '', marker = '.')
    plt.title(ireg)
    

In [None]:

# 'peak_tstep'
# 'start_tstep_msk'

In [None]:
import numpy as np

# Assuming you have these arrays defined:
# peak_tstep_2d: numpy array of shape (9495, 105300)
# start_tstep_msk_2d: numpy array of shape (9495, 105300)
# onset_mask: numpy array of shape (9495, 105300)

# Create a boolean mask where start_tstep_msk_2d equals 1
mask = start_tstep_msk_2d == 1

# Reshape peak_tstep_2d to match the shape of onset_mask
reshaped_peak = peak_tstep_2d[:, :, np.newaxis]

# Create a range of indices from i_t to peak_tstep_2d[i_t, ix] inclusive
indices = np.arange(onset_mask.shape[0])[:, np.newaxis, np.newaxis] <= reshaped_peak

# Combine the mask and indices to set values in onset_mask
onset_mask[mask & indices] = 1


In [None]:
indices.shape

In [None]:
# plt.pcolor(mhw_mask[0,:,:])
# plt.colorbar()

In [None]:
# lon1d = XC_lon_reshaped.flatten()
# lat1d = YC_lat_reshaped.flatten()

In [None]:
# plt.scatter(lon1d[mhw_mask.flatten()==1], lat1d[mhw_mask.flatten()==1], '.')

In [None]:
# Load area file
# area = xr.open_dataset('/Users/jacoposala/Desktop/CU/3.RESEARCH/NASA_project/NEW_heatBudgetECCO/data/metadata_ecco/ECCOv4r4_area_1993_2017.nc')
# area_transpose = area.area.values #.transpose(0, 2, 1)
# # Reshape to match data_used4MHWs
# area_reshaped = area_transpose.reshape((data_used4MHWs.shape[1], data_used4MHWs.shape[2]))
# area_reshaped = np.repeat(np.expand_dims(area_reshaped, axis=2), data_used4MHWs.shape[0], axis=2)
# area_reshaped = area_reshaped.transpose(2, 0, 1)


In [None]:
# add OISST in the future

In [None]:
# Line plot area covered by MHWs

In [None]:
# sdate = date(year_start,1,2)   # start date
# edate = date(2017,12,31)   # end date
# time = pandas.date_range(sdate,edate-timedelta(days=1),freq='d')

In [None]:

# Percentage of area covered by MHWs in time
plt.figure(figsize=(15,6))
plt.plot(time_sel, area_mhws_1d/area_tot*100)
plt.title('Percentage of area covered by MHWs in time', fontsize = 18)
plt.tick_params(axis='both', which='major', labelsize=14)



In [None]:
plt.figure(figsize=(15,6))
plt.plot(area_mhws_1d)

In [None]:
# add plot percentage of area covered by MHWs where each term of the budget is dominant
# load of heat budget terms
# for each term, we look when one is bigger than the other 2
# we create a new mask in addition to mhw_mask


In [None]:
plt.pcolor(area_mhws[0,:,:])
plt.colorbar()

In [None]:
plt.pcolor(mhw_mask[0,:,:])
plt.colorbar()

In [None]:
plt.pcolor(find_mhws_info_data['G_advection_eventAve'][0,:,:])
plt.colorbar()



In [None]:
# Add

In [None]:
# Main function to make maps
def plot_map_from_scattered_TEST(XC_lon, YC_lat, d2plot, keyplot, year_start, year_end):

    #__________
    # Set data needed for the map
    #__________    
    if dataset_tag == 'ECCOv4r4_heat':
        # Reshape XC_lon and YC_lat to match data2plot shape
        d2plot_lon = np.reshape(XC_lon.XC_lon.values, [XC_lon.XC_lon.shape[0], XC_lon.XC_lon.shape[1]*XC_lon.XC_lon.shape[2]])
        d2plot_lat = np.reshape(YC_lat.YC_lat.values, [YC_lat.YC_lat.shape[0], YC_lat.YC_lat.shape[1]*YC_lat.YC_lat.shape[2]])

        # Shift longitudes
        #d2plot_lon_shifted = shift_longitude(d2plot_lon)

        # Reshape XC_lon and YC_lat
        points = np.concatenate((np.reshape(d2plot_lon, [d2plot_lon.shape[0]*d2plot_lon.shape[1], 1]), np.reshape(d2plot_lat, [d2plot_lat.shape[0]*d2plot_lat.shape[1], 1])), axis=1)
        # Reshape data2plot
        values = np.reshape(d2plot, [d2plot.shape[0]*d2plot.shape[1], 1])

        # Define the new grid for the map (360x180)
        grid_x, grid_y = np.meshgrid(np.linspace(-179.5, 179.5, 360),
                                     np.linspace(-89.5, 89.5, 180), indexing='ij')
    
        # Interpolate using nearest neighbor
        grid_z0 = np.transpose(griddata(points, values, (grid_x, grid_y), method='nearest'), [1,0,2])[:,:,0]

    elif dataset_tag == 'oisst_v2':

        #Shift longitudes
        grid_x, grid_y = np.meshgrid(XC_lon + 20., YC_lat, indexing='ij')
        grid_z0 = d2plot
              
    elif dataset_tag == 'argo_ohc15_50':
        grid_x, grid_y = np.meshgrid(XC_lon,YC_lat, indexing='ij')
        grid_z0 = d2plot

    #__________
    # Set color bar details based on specific type of plot and/or variable being plotted
    #__________    
    # Determine vmax as the maximum absolute value
    lev_or_int == 'ohc_k0_k5'
    vmax = 90 #60
    vmin = 0
    cmap = plt.get_cmap('Reds') 
    levels = np.linspace(vmin, vmax, 30)
    levels_cbar = levels
            
            
    # Set other tags needed for the title of the map
    if 'onset' in keyplot:
        titletag = 'onset'
        value_tag = 'positive'
    elif 'decline' in keyplot:
        titletag = 'decline'
        value_tag = 'negative'
    else:
        titletag = 'events'
        
    common_name = remove_text_after_second_underscore(keyplot)
    
    #__________
    # Make plot
    #__________  
    fig = plt.figure(figsize=(20, 8))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.Robinson(central_longitude=-180))
    norm = mcolors.BoundaryNorm(levels, cmap.N)
    im = ax.pcolormesh(grid_x.T, grid_y.T, grid_z0, cmap=cmap, norm=norm, transform=ccrs.PlateCarree())       
    ax.add_feature(cfeature.LAND, edgecolor='black', facecolor='lightgrey', zorder=1)
    fig.colorbar(im, ax=ax, ticks=levels_cbar, boundaries=levels)

    plt.show()
    return

In [None]:
def remove_text_after_second_underscore(name):
    parts = name.split('_')
    if len(parts) >= 3:
        return '_'.join(parts[:2])
    return name