## This script uses Mani's TIMPS data (https://docs.google.com/document/d/1aAjcnRubP0HV8hbdqYfNSSnzMediNNgUlP27cOhq2d8/edit) to plot PDFs of ERA5 convergence for each CPEX-CV convective case that corresponds to one or more TIMPS-tracked MCSs.

In [None]:
import os
import sys
import math
import h5py
import xarray as xr
import numpy as np
import pandas as pd
import scipy.stats

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm  #to get python's normal library of colormaps
import matplotlib.colors as mplc

import cartopy.crs as ccrs
import cartopy.feature as cfeature
#from cartopy.util import add_cyclic_point
#from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

from datetime import datetime
from datetime import timedelta

import metpy.calc as mpcalc
from metpy.units import units

from PIL import Image
import icartt            #needed to read .ict files

import time

#record the wall-clock start time (seconds since the epoch); the (currently commented-out) timing
    #report at the end of the plotting cell subtracts this from the end time to print the runtime in minutes
tstart = time.time()


In [None]:
#open one example TIMPS track file (expected in ./TIMPS_data relative to the working directory)
    #so that its contents can be inspected
timps_test_path = os.path.join(
    os.getcwd(),
    'TIMPS_data',
    'TIMPS_0230041_202209060430_16_-24.nc',
)

testds = xr.open_dataset(timps_test_path)
testds  #bare expression as the last line of the cell (displays the dataset when run as a notebook cell)


In [None]:
#ERA5 data: open the pressure-level file (expected in ./ERA5_Reanalysis_Data relative to the
    #working directory) so that its contents can be inspected
era5_folder = os.path.join(os.getcwd(), 'ERA5_Reanalysis_Data')
era5_path = os.path.join(
    era5_folder,
    'CPEXCV_ERA5_Reanalysis_Hourly_Pressure.nc',
)

ds_era5_test = xr.open_dataset(era5_path)
ds_era5_test  #bare expression as the last line of the cell (displays the dataset when run as a notebook cell)

In [None]:
# aew_path_old = os.path.join(aew_folder, 'AEW_data_postprocessed__0924_UPDATE_EXTENDED_2022_B1-6hr_OLD.nc')
# xr.open_dataset(aew_path_old)
# #^^^ the old file above is very similar to the new file below, but without AEW_strength, TC_gen_time, and TC_name variables


In [None]:
#AEW tracker data, 6-hourly
    #(Quinton's AEW tracker: https://osf.io/jnv5u, https://github.com/qlawton/QTrack?tab=readme-ov-file)
    #the file is expected in ./AEW_Tracker_Data relative to the working directory
aew_folder = os.path.join(os.getcwd(), 'AEW_Tracker_Data')
#aew_path_old = os.path.join(aew_folder, 'AEW_data_postprocessed__0924_UPDATE_EXTENDED_2022_B1-6hr.nc')
aew_path = os.path.join(
    aew_folder,
    'AEW_tracks_post_processed_year_2022.nc',
)

ds_aew_test = xr.open_dataset(aew_path)
ds_aew_test  #bare expression as the last line of the cell (displays the dataset when run as a notebook cell)


In [None]:
###USER SETTINGS: the 2 variables assigned in this cell are the only ones you need to change

#pressure levels at which convergence is plotted
pressures_to_plot_conv = [975, 700]

#each case dict below maps a case number to [date string (YYYYMMDD), list of hours]:
    #the cases to calculate ERA5 convergence PDFs for, and each case's CPEX-CV collection time range
    #(rounded to the nearest hours (to match with hourly ERA5 data), with start/end times at XX:30 UTC
    #rounded to the most inclusive hour) (see "CPEX-CV Well Documented Convection.docx" for collection hours)

#CPEX-CV convective cases (average time for dropsondes for a given convective case, per CPEX-CV Well Documented Convection.docx)
# case_dict_conv = {8: ['20220906', [11,12]],
#                   9: ['20220906', [13,14,15,16]],
#                   10: ['20220906', [16,17,18]],
#                   11: ['20220907', [13,14]],
#                   12: ['20220907', [14,15,16,17,18]],
#                   13: ['20220907', [15,16]],
#                   14: ['20220914', [10,11,12,14,15]],
#                   15: ['20220914', [12,13,14,15,16,17]],
#                   16: ['20220916', [14,15,16]],
#                   17: ['20220916', [17,18,19]],
#                   18: ['20220922', [6,7,8,9]],
#                   19: ['20220923', [9,10,11,12,13,14,15]],
#                   20: ['20220926', [7,8,9,10,11]],
#                   21: ['20220929', [11,12,13,14]],
#                   22: ['20220930', [14,15]]}

#CPEX-CV convective cases associated with an AEW
    #(average time for dropsondes for a given convective case, per CPEX-CV Well Documented Convection.docx)
# case_dict_conv_aew = {8: ['20220906', [11,12]],
#                   9: ['20220906', [13,14,15,16]],
#                   10: ['20220906', [16,17,18]],
#                   11: ['20220907', [13,14]],
#                   12: ['20220907', [14,15,16,17,18]],
#                   13: ['20220907', [15,16]],
#                   16: ['20220916', [14,15,16]],
#                   17: ['20220916', [17,18,19]],
#                   18: ['20220922', [6,7,8,9]],
#                   19: ['20220923', [9,10,11,12,13,14,15]],
#                   20: ['20220926', [7,8,9,10,11]],
#                   21: ['20220929', [11,12,13,14]],
#                   22: ['20220930', [14,15]]}

#ACTIVE SELECTION: the 2 CPEX-CV convective cases to compare against one another for the AGU paper
case_dict_conv_aew = {
    10: ['20220906', [16, 17, 18]],
    12: ['20220907', [14, 15, 16, 17, 18]],
}

#test
#case_dict_conv = {19: ['20220923', [9,10,11,12,13,14,15]]}


### THE FOLLOWING CELL PLOTS A 3-PANEL PLOT OF CONVECTION-RELATIVE CONVERGENCE PDFs (1 PANEL PER LIFECYCLE STAGE)

In [None]:
# #THIS CELL PLOTS A 3-PANEL PLOT OF CONVECTION-RELATIVE CONVERGENCE PDFs (1 PANEL PER LIFECYCLE STAGE)

# #For each CPEX-CV case, calculate and plot domain-PDFs of low- (975 hPa) and mid- (700 hPa) level convergence 
# #(display median, mean, standard deviation, skewness, and kurtosis as well), with the domain being a 
# #2-by-2 degree box around TIMPS MCS centroids (e.g., Galarneau et al., 2023)
#     #partition the data by lifecycle of the convection (use TIMPS "gmd" variable)
#     #for calculating domain-mean convergence, cosine-weight each grid box value (see AOS 573 material)
#     #KEEP TRACK OF WHICH GRID CELLS YOU HAVE ALREADY ADDED CONVERGENCE FOR, SO THAT YOU DON'T
#         #DOUBLE COUNT GRID CELLS WHEN YOU HAVE A CASE WITH MULTIPLE TIMPS IDS WHOSE BOXES MAY OVERLAP


# #set some baseline plot displays

# #matplotlib.rcParams['axes.facecolor'] = [0.9,0.9,0.9]
# matplotlib.rcParams['axes.labelsize'] = 16
# matplotlib.rcParams['axes.titlesize'] = 16
# matplotlib.rcParams['axes.labelweight'] = 'bold'
# matplotlib.rcParams['axes.titleweight'] = 'bold'
# matplotlib.rcParams['xtick.labelsize'] = 16
# matplotlib.rcParams['ytick.labelsize'] = 16
# matplotlib.rcParams['legend.fontsize'] = 16
# #matplotlib.rcParams['legend.facecolor'] = 'w'
# #matplotlib.rcParams['axes.facecolor'] = 'w'
# matplotlib.rcParams['font.family'] = 'arial'
# matplotlib.rcParams['hatch.linewidth'] = 0.3
        
# #Dropsonde data
# drop_metric_filepath = os.path.join(os.getcwd(), 'Dropsonde_Metric_Calculations_CPEXCV.csv')
# df_drop = pd.read_csv(drop_metric_filepath)

# #ERA5 data
# era5_folder = os.path.join(os.getcwd(), 'ERA5_Reanalysis_Data')
# era5_path = os.path.join(era5_folder, 'CPEXCV_ERA5_Reanalysis_Hourly_Pressure.nc')
# ds_era5 = xr.open_dataset(era5_path)

# #TIMPS data
# timps_folder = os.path.join(os.getcwd(), 'TIMPS_data')

# conv_bins = np.arange(-15,15.1,0.5)

# for case_num in case_dict_conv.keys():

#     print (f'Case {case_num} convergence PDF plots in progress...')
    
#     df_drop_case = df_drop[df_drop['Case'] == case_num].copy()
    
#     case_date = case_dict_conv[case_num][0]
#     case_hours = case_dict_conv[case_num][1]
#     case_timps_ids = df_drop_case['TIMPS ID'].unique()
    
#     for pres_lev in pressures_to_plot_conv:          #plot convergence PDFs at low- and mid-levels
        
#         group_fig = plt.figure(figsize = (36, 12))   #initialize the convergence PDF figure for the given case
        
#         for lifecycle in [1,2,3]:                    #partition convergence PDFs by TIMPS convective lifecycle stage

#             conv_df = pd.DataFrame()
#             conv_lats_df = pd.DataFrame()
#             #conv_lons_df = pd.DataFrame()

#             for hr in case_hours:
#                 hr2 = str(hr).zfill(2)
                
#                 coords_check_list = []

#                 for ii, unique_timps_id in enumerate(case_timps_ids):

#                     timps_filepath = None
#                     if pd.isnull(unique_timps_id):
#                         continue
#                     else:
#                         unique_timps_id = str(int(unique_timps_id))
#                         for filename in os.listdir(timps_folder):
#                             if unique_timps_id in filename:
#                                 timps_filepath = os.path.join(timps_folder, filename)
#                                 break

#                     if timps_filepath == None:
#                         sys.exit(f'Could not find TIMPS file for TIMP ID {unique_timps_id}')
#                     else:
#                         timps_ds0 = xr.open_dataset(timps_filepath)
#                         timps_ds0 = timps_ds0.sel(time = case_date)
#                         timps_ds = timps_ds0.sel(time = timps_ds0.time.dt.hour.isin(hr))   #gives 2 times: minute = 0 and minute = 30
#                         timps_ds = timps_ds.sel(time = timps_ds.time.dt.minute.isin(0))    #grab the time on the hour to match with ERA5
                        
#                         if len(timps_ds.gmd) == 0:
#                             print (f'{hr2} UTC is out of range of the TIMPS ID range ({timps_ds0.time[0].values.astype(str)[:-10]} - {timps_ds0.time[-1].values.astype(str)[:-10]})')
#                             timps_ds0.close()
#                             continue
                        
#                         #TIMPS has an "unidentifiable" lifecycle phase for the collection range for Case 10, so 
#                             #we're using our manual characterization of "mature" instead for Case 10
#                         if case_num == 10:
#                             timps_gmd = 2
#                         else:
#                             timps_gmd = timps_ds.gmd.item()
                            
#                         #only add ERA5 convergence data for the given lifecycle stage
#                         if timps_gmd != lifecycle:
#                             timps_ds0.close()
#                             continue
                            
#                         timps_weighted_lat = timps_ds.centlatwgt.item()
#                         timps_weighted_lon = timps_ds.centlonwgt.item()

#                         #create 2-by-2 degree box around weighted centroid for the given TIMPS ID at the given hour
#                         timps_lat_range = slice(timps_weighted_lat + 1, timps_weighted_lat - 1)
#                         timps_lon_range = slice(timps_weighted_lon - 1, timps_weighted_lon + 1)

#                         #grab all the ERA5 low-/mid-level convergence values and corresponding lats/lons within the given convective box at the given hour
#                         v700 = ds_era5.v.sel(time = case_date).sel(level = 700)
#                         v700 = v700.sel(time = v700.time.dt.hour.isin(hr))
#                         v700 = mpcalc.smooth_gaussian(v700, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                         v700 = v700.sel(longitude = timps_lon_range, latitude = timps_lat_range)
                        
#                         #manually calculating convergence from ERA5 u and v winds (recommended by Brandon Wolding via George Kiladis)
#                         u = ds_era5.u.sel(time = case_date).sel(level = pres_lev)
#                         u = u.sel(time = u.time.dt.hour.isin(hr))
#                         u = mpcalc.smooth_gaussian(u, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                         #u = u.sel(longitude = timps_lon_range, latitude = timps_lat_range)
                
#                         v = ds_era5.v.sel(time = case_date).sel(level = pres_lev)
#                         v = v.sel(time = v.time.dt.hour.isin(hr))
#                         v = mpcalc.smooth_gaussian(v, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                         #v = v.sel(longitude = timps_lon_range, latitude = timps_lat_range)
                        
#                         delta_lons = 0.25   #ERA5 lat/lon resolution is 0.25 degrees
#                         delta_lons_meters = (111.3195 * 1000 * delta_lons) * np.cos(u.latitude.values * np.pi/180)  #distance between longitude lines at equator is 111.3195 km and cosine weighting this distance by latitude
#                         dudx = (u[:,:,1:].values - u[:,:,:-1].values).squeeze() / np.expand_dims(np.abs(delta_lons_meters), axis=1)  #squeeze() removes dimensions of size 1 from an array, and expand_dims() inserts a new axis that will appear at the axis position
#                         dudx = np.column_stack((dudx, dudx[:,-1]))  #duplicate the last column of dudx to match original grid shape (and shape of dvdy)
            
#                         delta_lats = 0.25
#                         delta_lats_meters = 110.5744 * 1000 * delta_lats  #distance between latitude lines everywhere
#                         dvdy = (v[:,:-1,:].values - v[:,1:,:].values).squeeze() / delta_lats_meters  #squeeze() removes dimensions of size 1 from an array
#                         dvdy = np.vstack((dvdy, dvdy[-1,:]))  #duplicate the last row of dvdy to match original grid shape (and shape of dudx)
            
#                         conv_old = (dudx + dvdy) * -1 * 10**5  #manually calculated convergence from ERA5 u and v winds (times 10**5 1/s)
#                         ds_conv = xr.Dataset(data_vars = dict(convergence = (["latitude", "longitude"], conv_old)),
#                                              coords = dict(latitude = ("latitude", u.latitude.values), 
#                                                            longitude = ("longitude", u.longitude.values)),
#                                              attrs = dict(description = "Manually calculated ERA5 convergence data"))
                        
#                         conv = ds_conv.convergence.sel(longitude = timps_lon_range, latitude = timps_lat_range)

#                         # #grab all the ERA5 low-level convergence values and corresponding lats/lons within the given TIMPS ID box at the given hour
#                         # conv = ds_era5.d.sel(time = case_date).sel(level = pres_lev) * -1    #convergence of the wind (1/s)
#                         # conv = conv.sel(time = conv.time.dt.hour.isin(hr)) * 10**5  #convergence of the wind (times 10**5 1/s)
#                         # conv = conv.sel(longitude = timps_lon_range, latitude = timps_lat_range)

#                         lon, lat = np.meshgrid(conv.longitude, conv.latitude)
#                         lons = lon.reshape(-1)
#                         lats = lat.reshape(-1)
#                         conv_values = conv.values.reshape(-1)

#                         #add data from each hour as COLUMNS to corresponding df
#                         #only add values to conv_df that haven't already been added from a prior TIMPS ID for the given hour (so no duplicates!)
#                         if ii != 0:  #if not the first TIMPS ID for the given hour
#                             for xx, coord in enumerate(zip(lats,lons)):
#                                 if coord in coords_check_list:
#                                     print (f'Duplicate coordinate for Case {case_num}, replacing with NaN')
#                                     lats[xx] = np.nan
#                                     lons[xx] = np.nan
#                                     conv_values[xx] = np.nan
                        
#                         coords_check_list += list(zip(lats, lons))  #append the lats/lons to the list as tuples (coordinate pairs)
                        
#                         conv_df = pd.concat((conv_df, pd.Series(conv_values)), axis = 1, ignore_index = True)
#                         conv_lats_df = pd.concat((conv_lats_df, pd.Series(lats)), axis = 1, ignore_index = True)
#                         #conv_lons_df = pd.concat((conv_lons_df, pd.Series(lons)), axis = 1, ignore_index = True)

#                         timps_ds0.close()

#             ax = group_fig.add_subplot(1, 3, lifecycle)        

#             if lifecycle == 1:
#                 clc = 'Growth'
#                 color = 'limegreen'
#             elif lifecycle == 2:
#                 clc = 'Mature'
#                 color = 'darkred'
#             elif lifecycle == 3:
#                 clc = 'Decay'
#                 color = 'blue'
            
#             ax.set_title('Case %i (%s Stage) ERA5 %i hPa Convergence PDF (Convection-Relative)' % (case_num, clc, pres_lev))
                
#             if len(conv_df) == 0:  #no data for the given lifecycle for the given case for the given hours
#                 ax.text(0.5, 0.5, f'{clc} stage was not sampled\nfor Case {case_num} during CPEX-CV', 
#                         horizontalalignment = 'center', verticalalignment = 'center', 
#                         fontsize = 30, bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 10})
#                 ax.set_xticklabels([])
#                 ax.set_yticklabels([])
#                 continue
            
#             conv_df = conv_df.values  #convert Pandas DataFrame to NumPy array
#             conv_lats_df = conv_lats_df.values  #convert Pandas DataFrame to NumPy array
            
#             #mask NaN values in the Dataframes so that the numpy stat calculations work below
#             conv_df_masked = np.ma.masked_where(np.isnan(conv_df), conv_df)
#             conv_lats_df_masked = np.ma.masked_where(np.isnan(conv_lats_df), conv_lats_df)
                
#             ax.hist(conv_df_masked.reshape(-1), bins = conv_bins, density = True, weights = None,
#                     histtype = 'step', align = 'mid', orientation = 'vertical', color = color, linewidth = 2)
#                 #density = True returns a probability density: each bin will display the bin's raw count 
#                     #divided by the total number of counts and the bin width
#                     #(density = counts / (sum(counts) * np.diff(bins))), so that the area under the 
#                     #histogram integrates to 1 (np.sum(density * np.diff(bins)) == 1)
    
#             cos_weights = np.sqrt(np.cos(conv_lats_df_masked * np.pi/180))   #cosine weights to apply to conv_df
        
#             conv_count = np.count_nonzero(~np.isnan(conv_df))
#             conv_median = np.round(np.nanmedian(conv_df, axis = None), 2)
#             conv_mean = np.round(np.nanmean(conv_df, axis = None), 2)                                        #non-weighted mean (1st moment)
#             conv_wgt_mean = np.round(np.average(conv_df_masked, axis = None, weights = cos_weights), 2)      #cosine-weighted mean (1st moment)
#             conv_std = np.round(np.std(conv_df_masked, axis = None), 2)                                      #standard deviation (2nd moment)
#             conv_skew = np.round(scipy.stats.skew(conv_df_masked, axis = None, nan_policy = 'omit'), 4)      #skewness (3rd moment)
#             conv_kurt = np.round(scipy.stats.kurtosis(conv_df_masked, axis = None, nan_policy = 'omit'), 4)  #kurtosis (4th moment)
        
#             ax.axvline(x = 0, color = 'k', linestyle = '--', alpha = 0.5)
#             ax.text(0.98, 0.875, 
#                     f'Count: {conv_count}\nMedian: {conv_median}\nMean: {conv_mean}\nWeighted Mean: {conv_wgt_mean}\nStandard Deviation: {conv_std}\nSkewness: {conv_skew}\nKurtosis: {conv_kurt}\n', 
#                     transform = ax.transAxes, horizontalalignment = 'right', verticalalignment = 'center', 
#                     fontsize = 16, fontweight = 'bold', color = color)
            
#             ax.set_xlabel('Convergence [10$^{-5}$ s$^{-1}$]')
#             ax.set_ylabel('Prob(Convergence)')
#             ax.set_xlim([-15,15])
#             ax.set_xticks(np.arange(-15, 15.1, 2))
#             ax.set_yticks(np.arange(0, 0.401, 0.05))
#             ax.grid(axis = 'y')
                    
#         #plt.tight_layout()
#         #plt.subplots_adjust(wspace = 0.1)

#         #save the figure
#         plot_save_name = f'Case{case_num}_{pres_lev}hPa_convergence_PDFs_convection_relative_separate.png'
#         plt.savefig(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/Convection_Relative/Smoothed_ERA5_winds', plot_save_name), bbox_inches = 'tight')
#         #plt.show()  #plt.show() must come after plt.savefig() in order for the image to save properly
#         #plt.clf()   #supposedly speeds things up? According to: https://www.youtube.com/watch?v=jGVIZbi9uMY
#         plt.close()
#         plt.clf()    #if placing this after plt.close(), may release memory related to the figure (https://stackoverflow.com/questions/741877/how-do-i-tell-matplotlib-that-i-am-done-with-a-plot)

#         ##decrease file size of the image by 66% without noticeable image effects (if using Matplotlib)
#         ##(good to use if you're producing a lot of images, see https://www.youtube.com/watch?v=fzhAseXp5B4)
#         im = Image.open(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/Convection_Relative/Smoothed_ERA5_winds', plot_save_name))

#         try:
#             im2 = im.convert('P', palette = Image.Palette.ADAPTIVE)
#         except:
#             #use this for older version of PIL/Pillow if the above line doesn't work, 
#             #though this line will have isolated, extremely minor image effects due to 
#             #only using 256 colors instead of the 3-element RGB scale
#             im2 = im.convert('P')

#         im2.save(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/Convection_Relative/Smoothed_ERA5_winds', plot_save_name))
#         im.close()
#         im2.close()
    
#     print (f'Case {case_num} convergence PDF plots complete!\n')

# ds_era5.close()

# tend = time.time()
# print (f'This script took {np.round((tend - tstart) / 60, 1)} minutes to complete.')


### THE FOLLOWING CELL PLOTS A 3-PANEL PLOT OF CONVECTION-RELATIVE CONVERGENCE PDFs COMPARING 2 CASES (1 PANEL PER LIFECYCLE STAGE)

In [None]:
# #THIS CELL PLOTS A 3-PANEL PLOT OF CONVECTION-RELATIVE CONVERGENCE PDFs COMPARING 2 CASES (1 PANEL PER LIFECYCLE STAGE)

# #For each CPEX-CV case, calculate and plot domain-PDFs of low- (975 hPa) and mid- (700 hPa) level convergence 
# #(display median, mean, standard deviation, skewness, and kurtosis as well), with the domain being a 
# #2-by-2 degree box around TIMPS MCS centroids (e.g., Galarneau et al., 2023)
#     #partition the data by lifecycle of the convection (use TIMPS "gmd" variable)
#     #for calculating domain-mean convergence, cosine-weight each grid box value (see AOS 573 material)
#     #KEEP TRACK OF WHICH GRID CELLS YOU HAVE ALREADY ADDED CONVERGENCE FOR, SO THAT YOU DON'T
#         #DOUBLE COUNT GRID CELLS WHEN YOU HAVE A CASE WITH MULTIPLE TIMPS IDS WHOSE BOXES MAY OVERLAP

# #CPEX-CV convective cases (average time for dropsondes for a given convective case, per CPEX-CV Well Documented Convection.docx)
# case_dict_comp = {10: ['20220906', [16,17,18]],
#                   12: ['20220907', [14,15,16,17,18]]}

# # case_dict_comp = {8: ['20220906', [11,12]],
# #                   9: ['20220906', [13,14,15,16]],
# #                   10: ['20220906', [16,17,18]],
# #                   11: ['20220907', [13,14]],
# #                   12: ['20220907', [14,15,16,17,18]],
# #                   13: ['20220907', [15,16]],
# #                   14: ['20220914', [10,11,12,14,15]],
# #                   15: ['20220914', [12,13,14,15,16,17]],
# #                   16: ['20220916', [14,15,16]],
# #                   17: ['20220916', [17,18,19]],
# #                   18: ['20220922', [6,7,8,9]],
# #                   19: ['20220923', [9,10,11,12,13,14,15]],
# #                   20: ['20220926', [7,8,9,10,11]],
# #                   21: ['20220929', [11,12,13,14]],
# #                   22: ['20220930', [14,15]]}        
        
# #set some baseline plot displays

# #matplotlib.rcParams['axes.facecolor'] = [0.9,0.9,0.9]
# matplotlib.rcParams['axes.labelsize'] = 20
# matplotlib.rcParams['axes.titlesize'] = 20
# matplotlib.rcParams['axes.labelweight'] = 'bold'
# matplotlib.rcParams['axes.titleweight'] = 'bold'
# matplotlib.rcParams['xtick.labelsize'] = 20
# matplotlib.rcParams['ytick.labelsize'] = 20
# matplotlib.rcParams['legend.fontsize'] = 18
# #matplotlib.rcParams['legend.facecolor'] = 'w'
# #matplotlib.rcParams['axes.facecolor'] = 'w'
# matplotlib.rcParams['font.family'] = 'arial'
# matplotlib.rcParams['hatch.linewidth'] = 0.3
        
# #Dropsonde data
# drop_metric_filepath = os.path.join(os.getcwd(), 'Dropsonde_Metric_Calculations_CPEXCV.csv')
# df_drop = pd.read_csv(drop_metric_filepath)

# #ERA5 data
# era5_folder = os.path.join(os.getcwd(), 'ERA5_Reanalysis_Data')
# era5_path = os.path.join(era5_folder, 'CPEXCV_ERA5_Reanalysis_Hourly_Pressure.nc')
# ds_era5 = xr.open_dataset(era5_path)

# #TIMPS data
# timps_folder = os.path.join(os.getcwd(), 'TIMPS_data')

# conv_bins = np.arange(-15,15.1,0.5)

# for pres_lev in pressures_to_plot_conv:          #plot convergence PDFs at low- and mid-levels

#     group_fig = plt.figure(figsize = (36, 12))   #initialize the convergence PDF figure for the given case comparison

#     first_case_no_lifecycle_data = [False, False, False]
    
#     for case_num in case_dict_comp.keys():

#         #print (f'Case {case_num} convergence PDF plots in progress...')

#         df_drop_case = df_drop[df_drop['Case'] == case_num].copy()

#         case_date = case_dict_comp[case_num][0]
#         case_hours = case_dict_comp[case_num][1]
#         case_timps_ids = df_drop_case['TIMPS ID'].unique()

#         for lifecycle in [1,2,3]:                    #partition convergence PDFs by TIMPS convective lifecycle stage

#             conv_df = pd.DataFrame()
#             conv_lats_df = pd.DataFrame()
#             #conv_lons_df = pd.DataFrame()

#             for hr in case_hours:
#                 hr2 = str(hr).zfill(2)

#                 coords_check_list = []

#                 for ii, unique_timps_id in enumerate(case_timps_ids):

#                     timps_filepath = None
#                     if pd.isnull(unique_timps_id):
#                         continue
#                     else:
#                         unique_timps_id = str(int(unique_timps_id))
#                         for filename in os.listdir(timps_folder):
#                             if unique_timps_id in filename:
#                                 timps_filepath = os.path.join(timps_folder, filename)
#                                 break

#                     if timps_filepath == None:
#                         sys.exit(f'Could not find TIMPS file for TIMP ID {unique_timps_id}')
#                     else:
#                         timps_ds0 = xr.open_dataset(timps_filepath)
#                         timps_ds0 = timps_ds0.sel(time = case_date)
#                         timps_ds = timps_ds0.sel(time = timps_ds0.time.dt.hour.isin(hr))   #gives 2 times: minute = 0 and minute = 30
#                         timps_ds = timps_ds.sel(time = timps_ds.time.dt.minute.isin(0))    #grab the time on the hour to match with ERA5

#                         if len(timps_ds.gmd) == 0:
#                             print (f'{hr2} UTC is out of range of the TIMPS ID range ({timps_ds0.time[0].values.astype(str)[:-10]} - {timps_ds0.time[-1].values.astype(str)[:-10]})')
#                             timps_ds0.close()
#                             continue

#                         #TIMPS has an "unidentifiable" lifecycle phase for the collection range for Case 10, so 
#                             #we're using our manual characterization of "mature" instead for Case 10
#                         if case_num == 10:
#                             timps_gmd = 2
#                         else:
#                             timps_gmd = timps_ds.gmd.item()

#                         #only add ERA5 convergence data for the given lifecycle stage
#                         if timps_gmd != lifecycle:
#                             timps_ds0.close()
#                             continue

#                         timps_weighted_lat = timps_ds.centlatwgt.item()
#                         timps_weighted_lon = timps_ds.centlonwgt.item()

#                         #create 2-by-2 degree box around weighted centroid for the given TIMPS ID at the given hour
#                         timps_lat_range = slice(timps_weighted_lat + 1, timps_weighted_lat - 1)
#                         timps_lon_range = slice(timps_weighted_lon - 1, timps_weighted_lon + 1)

#                         #grab all the ERA5 low-/mid-level convergence values and corresponding lats/lons within the given convective box at the given hour
#                         v700 = ds_era5.v.sel(time = case_date).sel(level = 700)
#                         v700 = v700.sel(time = v700.time.dt.hour.isin(hr))
#                         v700 = mpcalc.smooth_gaussian(v700, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                         v700 = v700.sel(longitude = timps_lon_range, latitude = timps_lat_range)
                        
#                         #manually calculating convergence from ERA5 u and v winds (recommended by Brandon Wolding via George Kiladis)
#                         u = ds_era5.u.sel(time = case_date).sel(level = pres_lev)
#                         u = u.sel(time = u.time.dt.hour.isin(hr))
#                         u = mpcalc.smooth_gaussian(u, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                         #u = u.sel(longitude = timps_lon_range, latitude = timps_lat_range)
                
#                         v = ds_era5.v.sel(time = case_date).sel(level = pres_lev)
#                         v = v.sel(time = v.time.dt.hour.isin(hr))
#                         v = mpcalc.smooth_gaussian(v, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                         #v = v.sel(longitude = timps_lon_range, latitude = timps_lat_range)
                        
#                         delta_lons = 0.25   #ERA5 lat/lon resolution is 0.25 degrees
#                         delta_lons_meters = (111.3195 * 1000 * delta_lons) * np.cos(u.latitude.values * np.pi/180)  #distance between longitude lines at equator is 111.3195 km and cosine weighting this distance by latitude
#                         dudx = (u[:,:,1:].values - u[:,:,:-1].values).squeeze() / np.expand_dims(np.abs(delta_lons_meters), axis=1)  #squeeze() removes dimensions of size 1 from an array, and expand_dims() inserts a new axis that will appear at the axis position
#                         dudx = np.column_stack((dudx, dudx[:,-1]))  #duplicate the last column of dudx to match original grid shape (and shape of dvdy)
            
#                         delta_lats = 0.25
#                         delta_lats_meters = 110.5744 * 1000 * delta_lats  #distance between latitude lines everywhere
#                         dvdy = (v[:,:-1,:].values - v[:,1:,:].values).squeeze() / delta_lats_meters  #squeeze() removes dimensions of size 1 from an array
#                         dvdy = np.vstack((dvdy, dvdy[-1,:]))  #duplicate the last row of dvdy to match original grid shape (and shape of dudx)
            
#                         conv_old = (dudx + dvdy) * -1 * 10**5  #manually calculated convergence from ERA5 u and v winds (times 10**5 1/s)
#                         ds_conv = xr.Dataset(data_vars = dict(convergence = (["latitude", "longitude"], conv_old)),
#                                              coords = dict(latitude = ("latitude", u.latitude.values), 
#                                                            longitude = ("longitude", u.longitude.values)),
#                                              attrs = dict(description = "Manually calculated ERA5 convergence data"))
                        
#                         conv = ds_conv.convergence.sel(longitude = timps_lon_range, latitude = timps_lat_range)

#                         # #grab all the ERA5 low-level convergence values and corresponding lats/lons within the given TIMPS ID box at the given hour
#                         # conv = ds_era5.d.sel(time = case_date).sel(level = pres_lev) * -1    #convergence of the wind (1/s)
#                         # conv = conv.sel(time = conv.time.dt.hour.isin(hr)) * 10**5  #convergence of the wind (times 10**5 1/s)
#                         # conv = conv.sel(longitude = timps_lon_range, latitude = timps_lat_range)

#                         lon, lat = np.meshgrid(conv.longitude, conv.latitude)
#                         lons = lon.reshape(-1)
#                         lats = lat.reshape(-1)
#                         conv_values = conv.values.reshape(-1)

#                         #add data from each hour as COLUMNS to corresponding df
#                         #only add values to conv_df that haven't already been added from a prior TIMPS ID for the given hour (so no duplicates!)
#                         if ii != 0:  #if not the first TIMPS ID for the given hour
#                             for xx, coord in enumerate(zip(lats,lons)):
#                                 if coord in coords_check_list:
#                                     print (f'Duplicate coordinate for Case {case_num}, replacing with NaN')
#                                     lats[xx] = np.nan
#                                     lons[xx] = np.nan
#                                     conv_values[xx] = np.nan

#                         coords_check_list += list(zip(lats, lons))  #append the lats/lons to the list as tuples (coordinate pairs)

#                         conv_df = pd.concat((conv_df, pd.Series(conv_values)), axis = 1, ignore_index = True)
#                         conv_lats_df = pd.concat((conv_lats_df, pd.Series(lats)), axis = 1, ignore_index = True)
#                         #conv_lons_df = pd.concat((conv_lons_df, pd.Series(lons)), axis = 1, ignore_index = True)

#                         timps_ds0.close()
            
#             #create or grab the axis of interest
#             if case_num == list(case_dict_comp.keys())[0]:
#                 ax = group_fig.add_subplot(1, 3, lifecycle)
#                 text_denom = 0
                
#                 if lifecycle == 1:
#                     clc = 'Growth'
#                     color = 'limegreen'
#                 elif lifecycle == 2:
#                     clc = 'Mature'
#                     color = 'darkred'
#                 elif lifecycle == 3:
#                     clc = 'Decay'
#                     color = 'cornflowerblue'
#             else:  #case_num == list(case_dict_comp.keys())[-1]
#                 ax = group_fig.get_axes()[lifecycle - 1]
#                 text_denom = 0.2

#                 if lifecycle == 1:
#                     clc = 'Growth'
#                     color = 'darkgreen'
#                 elif lifecycle == 2:
#                     clc = 'Mature'
#                     color = 'navy'
#                 elif lifecycle == 3:
#                     clc = 'Decay'
#                     color = 'darkblue'
            
#                 ax.set_title('ERA5 %i hPa Convergence PDFs\n(Convection-Relative, %s Stage, 2x2 Degree Box)' % (pres_lev, clc))
            
#             if (len(conv_df) == 0) and (case_num == list(case_dict_comp.keys())[0]):     #no data for the given lifecycle for the given case for the given hours
#                 first_case_no_lifecycle_data[lifecycle - 1] = True
#                 continue
#             elif (len(conv_df) == 0) and (case_num == list(case_dict_comp.keys())[-1]):  #no data for the given lifecycle for the given case for the given hours
#                 if first_case_no_lifecycle_data[lifecycle - 1]:  #no data for the given lifecycle for either case
#                     ax.text(0.5, 0.5, f'{clc} stage was not sampled for\nCase {list(case_dict_comp.keys())[0]} nor Case {case_num} during CPEX-CV', 
#                             horizontalalignment = 'center', verticalalignment = 'center', 
#                             fontsize = 30, bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 10})
#                     ax.set_xticklabels([])
#                     ax.set_yticklabels([])
#                     continue
#                 else:
#                     continue
#             else:
#                 pass

#             conv_df = conv_df.values  #convert Pandas DataFrame to NumPy array
#             conv_lats_df = conv_lats_df.values  #convert Pandas DataFrame to NumPy array

#             #mask NaN values in the Dataframes so that the numpy stat calculations work below
#             conv_df_masked = np.ma.masked_where(np.isnan(conv_df), conv_df)
#             conv_lats_df_masked = np.ma.masked_where(np.isnan(conv_lats_df), conv_lats_df)

#             #line plot histogram (clearer to interpret than "step" histogram below)
#             hist, bins = np.histogram(conv_df_masked.reshape(-1), bins = conv_bins, density = True, weights = None)
#             bin_centers = (bins[:-1] + bins[1:]) / 2  # Midpoints of the bins
#             ax.plot(bin_centers, hist, linewidth = 2, linestyle = '-', color = color, label = f'Case {case_num}')

#             # #normal "step" histogram
#             # ax.hist(conv_df_masked.reshape(-1), bins = conv_bins, density = True, weights = None,
#             #         histtype = 'step', align = 'mid', orientation = 'vertical', color = color, 
#             #         linewidth = 2, label = f'Case {case_num}')
#             #     #density = True returns a probability density: each bin will display the bin's raw count 
#             #         #divided by the total number of counts and the bin width
#             #         #(density = counts / (sum(counts) * np.diff(bins))), so that the area under the 
#             #         #histogram integrates to 1 (np.sum(density * np.diff(bins)) == 1)

#             cos_weights = np.sqrt(np.cos(conv_lats_df_masked * np.pi/180))   #sqrt(cos(latitude)) weights to apply to conv_df

#             conv_count = np.count_nonzero(~np.isnan(conv_df))
#             conv_median = np.round(np.nanmedian(conv_df, axis = None), 2)
#             conv_mean = np.round(np.nanmean(conv_df, axis = None), 2)                                        #non-weighted mean (1st moment)
#             conv_wgt_mean = np.round(np.average(conv_df_masked, axis = None, weights = cos_weights), 2)      #cosine-weighted mean (1st moment)
#             conv_std = np.round(np.std(conv_df_masked, axis = None), 2)                                      #standard deviation (2nd moment)
#             conv_skew = np.round(scipy.stats.skew(conv_df_masked, axis = None, nan_policy = 'omit'), 4)      #skewness (3rd moment)
#             conv_kurt = np.round(scipy.stats.kurtosis(conv_df_masked, axis = None, nan_policy = 'omit'), 4)  #kurtosis (4th moment)

#             ax.axvline(x = 0, color = 'k', linestyle = '--', alpha = 0.5)
            
#             ax.text(0.98, 0.89 - text_denom, 
#                     f'Count: {conv_count}\nMedian: {conv_median}\nMean: {conv_mean}\nWeighted Mean: {conv_wgt_mean}\nStandard Deviation: {conv_std}\nSkewness: {conv_skew}\nKurtosis: {conv_kurt}\n', 
#                     transform = ax.transAxes, horizontalalignment = 'right', verticalalignment = 'center', 
#                     fontsize = 16, fontweight = 'bold', color = color)
                    
#         #print (f'Case {case_num} convergence PDF plots complete!\n')
                    
#     for ax in group_fig.get_axes():
#         ax.set_xlabel('Convergence [10$^{-5}$ s$^{-1}$]')
#         ax.set_ylabel('Prob(Convergence)')
#         ax.set_xlim([-13,13])
#         ax.set_xticks(np.arange(-13, 13.1, 2))
#         ax.set_ylim(bottom = 0)
#         #ax.set_yticks(np.arange(0, 0.401, 0.05))
#         ax.grid(axis = 'y')
#         ax.legend(loc = 'upper left')

#     #plt.tight_layout()
#     #plt.subplots_adjust(wspace = 0.1)
    
#     #save the figure
#     case_nums = '-'.join(map(str, case_dict_comp.keys()))
#     plot_save_name = f'Case{case_nums}_{pres_lev}hPa_convergence_PDFs_convection_relative_separate.png'
#     plt.savefig(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/Convection_Relative/Smoothed_ERA5_winds', plot_save_name), bbox_inches = 'tight')
#     #plt.show()  #plt.show() must come after plt.savefig() in order for the image to save properly
#     #plt.clf()   #supposedly speeds things up? According to: https://www.youtube.com/watch?v=jGVIZbi9uMY
#     plt.close()
#     plt.clf()    #if placing this after plt.close(), may release memory related to the figure (https://stackoverflow.com/questions/741877/how-do-i-tell-matplotlib-that-i-am-done-with-a-plot)

#     ##decrease file size of the image by 66% without noticeable image effects (if using Matplotlib)
#     ##(good to use if you're producing a lot of images, see https://www.youtube.com/watch?v=fzhAseXp5B4)
#     im = Image.open(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/Convection_Relative/Smoothed_ERA5_winds', plot_save_name))

#     try:
#         im2 = im.convert('P', palette = Image.Palette.ADAPTIVE)
#     except:
#         #use this for older version of PIL/Pillow if the above line doesn't work, 
#         #though this line will have isolated, extremely minor image effects due to 
#         #only using 256 colors instead of the 3-element RGB scale
#         im2 = im.convert('P')

#     im2.save(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/Convection_Relative/Smoothed_ERA5_winds', plot_save_name))
#     im.close()
#     im2.close()

# ds_era5.close()

# #tend = time.time()
# #print (f'This script took {np.round((tend - tstart) / 60, 1)} minutes to complete.')


### THE FOLLOWING CELL PLOTS A 1-PANEL PLOT OF CONVECTION-RELATIVE CONVERGENCE PDFs (LIFECYCLE STAGE PDFs OVERLAID ON ONE ANOTHER)

In [None]:
# #THIS CELL PLOTS A 1-PANEL PLOT OF CONVECTION-RELATIVE CONVERGENCE PDFs (LIFECYCLE STAGE PDFs OVERLAID ON ONE ANOTHER)

# #For each CPEX-CV case, calculate and plot domain-PDFs of low- (975 hPa) and mid- (700 hPa) level convergence 
# #(display median, mean, standard deviation, skewness, and kurtosis as well), with the domain being a
# #2-by-2 degree box around TIMPS MCS centroids (e.g., Galarneau et al., 2023)
#     #partition the data by lifecycle of the convection (use TIMPS "gmd" variable)
#     #for calculating domain-mean convergence, cosine-weight each grid box value (see AOS 573 material)
#     #KEEP TRACK OF WHICH GRID CELLS YOU HAVE ALREADY ADDED CONVERGENCE FOR, SO THAT YOU DON'T
#         #DOUBLE COUNT GRID CELLS WHEN YOU HAVE A CASE WITH MULTIPLE TIMPS IDS WHOSE BOXES MAY OVERLAP

# #set some baseline plot displays

# #matplotlib.rcParams['axes.facecolor'] = [0.9,0.9,0.9]
# matplotlib.rcParams['axes.labelsize'] = 18
# matplotlib.rcParams['axes.titlesize'] = 18
# matplotlib.rcParams['axes.labelweight'] = 'bold'
# matplotlib.rcParams['axes.titleweight'] = 'bold'
# matplotlib.rcParams['xtick.labelsize'] = 18
# matplotlib.rcParams['ytick.labelsize'] = 18
# matplotlib.rcParams['legend.fontsize'] = 17
# #matplotlib.rcParams['legend.facecolor'] = 'w'
# #matplotlib.rcParams['axes.facecolor'] = 'w'
# matplotlib.rcParams['font.family'] = 'arial'
# matplotlib.rcParams['hatch.linewidth'] = 0.3
        
# #Dropsonde data
# drop_metric_filepath = os.path.join(os.getcwd(), 'Dropsonde_Metric_Calculations_CPEXCV.csv')
# df_drop = pd.read_csv(drop_metric_filepath)

# #ERA5 data
# era5_folder = os.path.join(os.getcwd(), 'ERA5_Reanalysis_Data')
# era5_path = os.path.join(era5_folder, 'CPEXCV_ERA5_Reanalysis_Hourly_Pressure.nc')
# ds_era5 = xr.open_dataset(era5_path)

# #TIMPS data
# timps_folder = os.path.join(os.getcwd(), 'TIMPS_data')

# conv_bins = np.arange(-15,15.1,0.5)

# for case_num in case_dict_conv.keys():

#     print (f'Case {case_num} convergence PDF plots in progress...')
    
#     df_drop_case = df_drop[df_drop['Case'] == case_num].copy()
    
#     case_date = case_dict_conv[case_num][0]
#     case_hours = case_dict_conv[case_num][1]
#     case_timps_ids = df_drop_case['TIMPS ID'].unique()
    
#     for pres_lev in pressures_to_plot_conv:          #plot convergence PDFs at low- and mid-levels
        
#         no_lifecycle_counter = 0
        
#         group_fig = plt.figure(figsize = (12, 12))   #initialize the convergence PDF figure for the given case
#         ax = group_fig.add_subplot(1, 1, 1)
        
#         for lifecycle in [1,2,3]:                    #partition convergence PDFs by TIMPS convective lifecycle stage

#             conv_df = pd.DataFrame()
#             conv_lats_df = pd.DataFrame()
#             #conv_lons_df = pd.DataFrame()

#             for hr in case_hours:
#                 hr2 = str(hr).zfill(2)
                
#                 coords_check_list = []

#                 for ii, unique_timps_id in enumerate(case_timps_ids):

#                     timps_filepath = None
#                     if pd.isnull(unique_timps_id):
#                         continue
#                     else:
#                         unique_timps_id = str(int(unique_timps_id))
#                         for filename in os.listdir(timps_folder):
#                             if unique_timps_id in filename:
#                                 timps_filepath = os.path.join(timps_folder, filename)
#                                 break

#                     if timps_filepath == None:
#                         sys.exit(f'Could not find TIMPS file for TIMP ID {unique_timps_id}')
#                     else:
#                         timps_ds0 = xr.open_dataset(timps_filepath)
#                         timps_ds0 = timps_ds0.sel(time = case_date)
#                         timps_ds = timps_ds0.sel(time = timps_ds0.time.dt.hour.isin(hr))   #gives 2 times: minute = 0 and minute = 30
#                         timps_ds = timps_ds.sel(time = timps_ds.time.dt.minute.isin(0))    #grab the time on the hour to match with ERA5
                        
#                         if len(timps_ds.gmd) == 0:
#                             print (f'{hr2} UTC is out of range of the TIMPS ID range ({timps_ds0.time[0].values.astype(str)[:-10]} - {timps_ds0.time[-1].values.astype(str)[:-10]})')
#                             timps_ds0.close()
#                             continue
                            
#                         #TIMPS has an "unidentifiable" lifecycle phase for the collection range for Case 10, so 
#                             #we're using our manual characterization of "mature" instead for Case 10
#                         if case_num == 10:
#                             timps_gmd = 2
#                         else:
#                             timps_gmd = timps_ds.gmd.item()
                            
#                         #only add ERA5 convergence data for the given lifecycle stage
#                         if timps_gmd != lifecycle:
#                             timps_ds0.close()
#                             continue
                            
#                         timps_weighted_lat = timps_ds.centlatwgt.item()
#                         timps_weighted_lon = timps_ds.centlonwgt.item()

#                         #create 2-by-2 degree box around weighted centroid for the given TIMPS ID at the given hour
#                         timps_lat_range = slice(timps_weighted_lat + 1, timps_weighted_lat - 1)
#                         timps_lon_range = slice(timps_weighted_lon - 1, timps_weighted_lon + 1)

#                         #grab all the ERA5 low-/mid-level convergence values and corresponding lats/lons within the given convective box at the given hour
#                         v700 = ds_era5.v.sel(time = case_date).sel(level = 700)
#                         v700 = v700.sel(time = v700.time.dt.hour.isin(hr))
#                         v700 = mpcalc.smooth_gaussian(v700, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                         v700 = v700.sel(longitude = timps_lon_range, latitude = timps_lat_range)
                        
#                         #manually calculating convergence from ERA5 u and v winds (recommended by Brandon Wolding via George Kiladis)
#                         u = ds_era5.u.sel(time = case_date).sel(level = pres_lev)
#                         u = u.sel(time = u.time.dt.hour.isin(hr))
#                         u = mpcalc.smooth_gaussian(u, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                         #u = u.sel(longitude = timps_lon_range, latitude = timps_lat_range)
                
#                         v = ds_era5.v.sel(time = case_date).sel(level = pres_lev)
#                         v = v.sel(time = v.time.dt.hour.isin(hr))
#                         v = mpcalc.smooth_gaussian(v, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                         #v = v.sel(longitude = timps_lon_range, latitude = timps_lat_range)
                        
#                         delta_lons = 0.25   #ERA5 lat/lon resolution is 0.25 degrees
#                         delta_lons_meters = (111.3195 * 1000 * delta_lons) * np.cos(u.latitude.values * np.pi/180)  #distance between longitude lines at equator is 111.3195 km and cosine weighting this distance by latitude
#                         dudx = (u[:,:,1:].values - u[:,:,:-1].values).squeeze() / np.expand_dims(np.abs(delta_lons_meters), axis=1)  #squeeze() removes dimensions of size 1 from an array, and expand_dims() inserts a new axis that will appear at the axis position
#                         dudx = np.column_stack((dudx, dudx[:,-1]))  #duplicate the last column of dudx to match original grid shape (and shape of dvdy)
            
#                         delta_lats = 0.25
#                         delta_lats_meters = 110.5744 * 1000 * delta_lats  #distance between latitude lines everywhere
#                         dvdy = (v[:,:-1,:].values - v[:,1:,:].values).squeeze() / delta_lats_meters  #squeeze() removes dimensions of size 1 from an array
#                         dvdy = np.vstack((dvdy, dvdy[-1,:]))  #duplicate the last row of dvdy to match original grid shape (and shape of dudx)
            
#                         conv_old = (dudx + dvdy) * -1 * 10**5  #manually calculated convergence from ERA5 u and v winds (times 10**5 1/s)
#                         ds_conv = xr.Dataset(data_vars = dict(convergence = (["latitude", "longitude"], conv_old)),
#                                              coords = dict(latitude = ("latitude", u.latitude.values), 
#                                                            longitude = ("longitude", u.longitude.values)),
#                                              attrs = dict(description = "Manually calculated ERA5 convergence data"))
                        
#                         conv = ds_conv.convergence.sel(longitude = timps_lon_range, latitude = timps_lat_range)

#                         # #grab all the ERA5 low-level convergence values and corresponding lats/lons within the given TIMPS ID box at the given hour
#                         # conv = ds_era5.d.sel(time = case_date).sel(level = pres_lev) * -1    #convergence of the wind (1/s)
#                         # conv = conv.sel(time = conv.time.dt.hour.isin(hr)) * 10**5  #convergence of the wind (times 10**5 1/s)
#                         # conv = conv.sel(longitude = timps_lon_range, latitude = timps_lat_range)

#                         lon, lat = np.meshgrid(conv.longitude, conv.latitude)
#                         lons = lon.reshape(-1)
#                         lats = lat.reshape(-1)
#                         conv_values = conv.values.reshape(-1)

#                         #add data from each hour as COLUMNS to corresponding df
#                         #only add values to conv_df that haven't already been added from a prior TIMPS ID for the given hour (so no duplicates!)
#                         if ii != 0:  #if not the first TIMPS ID for the given hour
#                             for xx, coord in enumerate(zip(lats,lons)):
#                                 if coord in coords_check_list:
#                                     print (f'Duplicate coordinate for Case {case_num}, replacing with NaN')
#                                     lats[xx] = np.nan
#                                     lons[xx] = np.nan
#                                     conv_values[xx] = np.nan
                        
#                         coords_check_list += list(zip(lats, lons))  #append the lats/lons to the list as tuples (coordinate pairs)
                        
#                         conv_df = pd.concat((conv_df, pd.Series(conv_values)), axis = 1, ignore_index = True)
#                         conv_lats_df = pd.concat((conv_lats_df, pd.Series(lats)), axis = 1, ignore_index = True)
#                         #conv_lons_df = pd.concat((conv_lons_df, pd.Series(lons)), axis = 1, ignore_index = True)

#                         timps_ds0.close()        

#             if lifecycle == 1:
#                 clc = 'Growth'
#                 color = 'limegreen'
#                 text_denom = 1
#             elif lifecycle == 2:
#                 clc = 'Mature'
#                 color = 'darkred'
#                 text_denom = 1.325
#             elif lifecycle == 3:
#                 clc = 'Decay'
#                 color = 'blue'
#                 text_denom = 1.96
                
#             if len(conv_df) == 0:  #no lifecycle data for the given case for the given hours
#                 no_lifecycle_counter += 1
#                 continue
            
#             conv_df = conv_df.values  #convert Pandas DataFrame to NumPy array
#             conv_lats_df = conv_lats_df.values  #convert Pandas DataFrame to NumPy array
            
#             #mask NaN values in the Dataframes so that the numpy stat calculations work below
#             conv_df_masked = np.ma.masked_where(np.isnan(conv_df), conv_df)
#             conv_lats_df_masked = np.ma.masked_where(np.isnan(conv_lats_df), conv_lats_df)
                
#             ax.hist(conv_df_masked.reshape(-1), bins = conv_bins, density = True, weights = None,
#                     histtype = 'step', align = 'mid', orientation = 'vertical', color = color,
#                     linewidth = 2, label = clc)
#                 #density = True returns a probability density: each bin will display the bin's raw count 
#                     #divided by the total number of counts and the bin width
#                     #(density = counts / (sum(counts) * np.diff(bins))), so that the area under the 
#                     #histogram integrates to 1 (np.sum(density * np.diff(bins)) == 1)
    
#             cos_weights = np.sqrt(np.cos(conv_lats_df_masked * np.pi/180))   #sqrt(cos(latitude)) weights to apply to conv_df
        
#             conv_count = np.count_nonzero(~np.isnan(conv_df))
#             conv_median = np.round(np.nanmedian(conv_df, axis = None), 2)
#             conv_mean = np.round(np.nanmean(conv_df, axis = None), 2)                                        #non-weighted mean (1st moment)
#             conv_wgt_mean = np.round(np.average(conv_df_masked, axis = None, weights = cos_weights), 2)      #cosine-weighted mean (1st moment)
#             conv_std = np.round(np.std(conv_df_masked, axis = None), 2)                                      #standard deviation (2nd moment)
#             conv_skew = np.round(scipy.stats.skew(conv_df_masked, axis = None, nan_policy = 'omit'), 4)      #skewness (3rd moment)
#             conv_kurt = np.round(scipy.stats.kurtosis(conv_df_masked, axis = None, nan_policy = 'omit'), 4)  #kurtosis (4th moment)
        
#             ax.text(0.98, 0.875 / text_denom, 
#                     f'Count: {conv_count}\nMedian: {conv_median}\nMean: {conv_mean}\nWeighted Mean: {conv_wgt_mean}\nStandard Deviation: {conv_std}\nSkewness: {conv_skew}\nKurtosis: {conv_kurt}\n', 
#                     transform = ax.transAxes, horizontalalignment = 'right', verticalalignment = 'center', 
#                     fontsize = 16, fontweight = 'bold', color = color)
        
#         ax.set_title('Case %i (%s-%s-%s) ERA5 %i hPa Convergence PDFs (Convection-Relative)' % (case_num, case_date[:4], case_date[4:6], case_date[6:], pres_lev))
        
#         if no_lifecycle_counter == 3:  #no lifecycle data for the given case for the given hours
#             ax.text(0.5, 0.5, f'Unidentifiable lifecycle stage throughout\nthe CPEX-CV sampling time for Case {case_num}', 
#                     horizontalalignment = 'center', verticalalignment = 'center', 
#                     fontsize = 30, bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 10})
#             ax.set_xticklabels([])
#             ax.set_yticklabels([])
#         else:
#             ax.axvline(x = 0, color = 'k', linestyle = '--', alpha = 0.5)
#             ax.set_xlabel('Convergence [10$^{-5}$ s$^{-1}$]')
#             ax.set_ylabel('Prob(Convergence)')
#             ax.set_xlim([-15,15])
#             ax.set_xticks(np.arange(-15, 15.1, 2))
#             ax.set_yticks(np.arange(0, 0.401, 0.05))
#             ax.grid(axis = 'y')
#             ax.legend(title = 'Lifecycle Stage', title_fontproperties = {'weight': 'bold', 'size': 18}, loc = 'upper left')
                    
#         #plt.tight_layout()
#         #plt.subplots_adjust(wspace = 0.1)

#         #save the figure
#         plot_save_name = f'Case{case_num}_{pres_lev}hPa_convergence_PDFs_convection_relative.png'
#         plt.savefig(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/Convection_Relative/Smoothed_ERA5_winds', plot_save_name), bbox_inches = 'tight')
#         #plt.show()  #plt.show() must come after plt.savefig() in order for the image to save properly
#         #plt.clf()   #supposedly speeds things up? According to: https://www.youtube.com/watch?v=jGVIZbi9uMY
#         plt.close()
#         plt.clf()    #if placing this after plt.close(), may release memory related to the figure (https://stackoverflow.com/questions/741877/how-do-i-tell-matplotlib-that-i-am-done-with-a-plot)

#         ##decrease file size of the image by 66% without noticeable image effects (if using Matplotlib)
#         ##(good to use if you're producing a lot of images, see https://www.youtube.com/watch?v=fzhAseXp5B4)
#         im = Image.open(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/Convection_Relative/Smoothed_ERA5_winds', plot_save_name))

#         try:
#             im2 = im.convert('P', palette = Image.Palette.ADAPTIVE)
#         except:
#             #use this for older version of PIL/Pillow if the above line doesn't work, 
#             #though this line will have isolated, extremely minor image effects due to 
#             #only using 256 colors instead of the 3-element RGB scale
#             im2 = im.convert('P')

#         im2.save(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/Convection_Relative/Smoothed_ERA5_winds', plot_save_name))
#         im.close()
#         im2.close()
    
#     print (f'Case {case_num} convergence PDF plots complete!\n')

# ds_era5.close()

# tend = time.time()
# print (f'This script took {np.round((tend - tstart) / 60, 1)} minutes to complete.')


### THE FOLLOWING CELL PLOTS A 1-PANEL PLOT OF AEW-RELATIVE CONVERGENCE PDFs

In [None]:
# #THIS CELL PLOTS A 1-PANEL PLOT OF AEW-RELATIVE CONVERGENCE PDFs

# #For each CPEX-CV case, calculate and plot domain-PDFs of low- (975 hPa) and mid- (700 hPa) level convergence 
# #(display median, mean, standard deviation, skewness, and kurtosis as well), with the domain being a 
# #10-by-10 degree box around the AEW center for the given case (get from Quinton’s AEW tracker: https://osf.io/jnv5u, https://zenodo.org/records/13350860)
#     #partition the data into the 2 sectors of the AEW for the given case (ahead/behind the AEW)
#     #for calculating domain-mean convergence, cosine-weight each grid box value (see AOS 573 material)
    
#     #WON'T NEED TO KEEP TRACK OF WHICH GRID CELLS YOU HAVE ALREADY ADDED CONVERGENCE FOR, SINCE YOU'RE NOT
#         #WORKING WITH MULTIPLE AEWs AT A GIVEN HOUR (LIKE YOU WERE WITH MULTIPLE TIMPS IDs PER HOUR)
        
#     #Also don't need to split up convergence PDFs by convective lifecycle (just ahead/behind the AEW), because
#         #it would be difficult to relate convergence of an AEW region to convective lifecycle, since one
#         #AEW region could (and likely often does) have convective systems that are in different lifecycle stages

# #set some baseline plot displays

# #matplotlib.rcParams['axes.facecolor'] = [0.9,0.9,0.9]
# matplotlib.rcParams['axes.labelsize'] = 18
# matplotlib.rcParams['axes.titlesize'] = 18
# matplotlib.rcParams['axes.labelweight'] = 'bold'
# matplotlib.rcParams['axes.titleweight'] = 'bold'
# matplotlib.rcParams['xtick.labelsize'] = 18
# matplotlib.rcParams['ytick.labelsize'] = 18
# matplotlib.rcParams['legend.fontsize'] = 16
# #matplotlib.rcParams['legend.facecolor'] = 'w'
# #matplotlib.rcParams['axes.facecolor'] = 'w'
# matplotlib.rcParams['font.family'] = 'arial'
# matplotlib.rcParams['hatch.linewidth'] = 0.3
        
# #Dropsonde data
# drop_metric_filepath = os.path.join(os.getcwd(), 'Dropsonde_Metric_Calculations_CPEXCV.csv')
# df_drop = pd.read_csv(drop_metric_filepath)

# #ERA5 data
# era5_folder = os.path.join(os.getcwd(), 'ERA5_Reanalysis_Data')
# era5_path = os.path.join(era5_folder, 'CPEXCV_ERA5_Reanalysis_Hourly_Pressure.nc')
# ds_era5 = xr.open_dataset(era5_path)

# #AEW tracker data, 6-hourly (Quinton’s AEW tracker: https://osf.io/jnv5u, https://zenodo.org/records/13350860)
# aew_folder = os.path.join(os.getcwd(), 'AEW_Tracker_Data')
# aew_path = os.path.join(aew_folder, 'AEW_tracks_post_processed_year_2022.nc')
# ds_aew = xr.open_dataset(aew_path)

# #TIMPS data
# timps_folder = os.path.join(os.getcwd(), 'TIMPS_data')

# conv_bins = np.arange(-15,15.1,0.5)

# for case_num in case_dict_conv_aew.keys():

#     print (f'Case {case_num} convergence PDF plots in progress...')
    
#     aew_system_index_use = None
    
#     df_drop_case = df_drop[df_drop['Case'] == case_num].copy()
    
#     case_date = case_dict_conv_aew[case_num][0]
#     case_hours = case_dict_conv_aew[case_num][1]
#     case_timps_ids = df_drop_case['TIMPS ID'].unique()
    
#     if len(case_timps_ids) == 1 and pd.isnull(case_timps_ids[0]):  #no TIMPS IDs for the given case
#         continue
        
#     aews_at_given_date = ds_aew.sel(time = case_date)
#     day_after = datetime.strptime(case_date, '%Y%m%d') + timedelta(days = 1)  #day after case_date
#     aews_at_given_date_plus1 = ds_aew.sel(time = datetime.strftime(day_after, '%Y%m%d'))
    
#     #match the convective case to the nearest AEW (longitudinally) from Quinton's tracker
#         #and confirm that each TIMPS ID for the given case matches to the same AEW (sanity check)
#     case_hour_nearest_multiple6 = case_hours[np.nanargmin(np.abs(np.array(case_hours) % 6))]  #find the hour closest to a multiple of 6 (AEW tracker is only in 6-hourly intervals)
#     nearest_6_to_case_hour = 6 * round(case_hour_nearest_multiple6 / 6)  #find the multiple of 6 closest to case_hour_nearest_multiple6
    
#     for ii, unique_timps_id in enumerate(case_timps_ids):
        
#         unique_timps_id = str(int(unique_timps_id))
        
#         timps_filepath = None
#         for filename in os.listdir(timps_folder):
#             if unique_timps_id in filename:
#                 timps_filepath = os.path.join(timps_folder, filename)
#                 break

#         if timps_filepath == None:
#             sys.exit(f'Could not find TIMPS file for TIMP ID {unique_timps_id}')
#         else:
#             timps_ds0 = xr.open_dataset(timps_filepath)
#             timps_ds0 = timps_ds0.sel(time = case_date)
#             timps_ds = timps_ds0.sel(time = timps_ds0.time.dt.hour.isin(case_hour_nearest_multiple6))   #gives 2 times: minute = 0 and minute = 30
#             timps_ds = timps_ds.sel(time = timps_ds.time.dt.minute.isin(0))    #grab the time on the hour to match with ERA5 and AEW tracker

#             if len(timps_ds.gmd) == 0:
#                 print (f'{str(case_hour_nearest_multiple6).zfill(2)} UTC is out of range of the TIMPS ID range ({timps_ds0.time[0].values.astype(str)[:-10]} - {timps_ds0.time[-1].values.astype(str)[:-10]})')
#                 timps_ds0.close()
#                 continue

#             #timps_weighted_lat = timps_ds.centlatwgt.item()
#             timps_weighted_lon = timps_ds.centlonwgt.item()
            
#             aews_at_given_hour = aews_at_given_date.sel(time = aews_at_given_date.time.dt.hour.isin(nearest_6_to_case_hour))
            
#             lon_difs = aews_at_given_hour['AEW_lon_smooth'][:,0] - timps_weighted_lon
#             aew_system_index = np.nanargmin(np.abs(lon_difs).values)   #np.argmin() would grab a NaN value!
            
#             if aew_system_index_use != None:
#                 if case_num != 22:
#                     assert aew_system_index == aew_system_index_use, 'TIMPS IDs for the given case are not matching to the same AEW'
#                 else:  #Case 22 has 2 TIMPS IDs which incorrectly match to different AEWs;
#                        #the 2nd TIMPS ID (304352) matches to the correct AEW, so we're forcing the code to choose that AEW here (not ideal, but just one case we need to do this for)
#                     aew_system_index_use = aew_system_index * 1   #the number of the AEW in the AEW tracker file (using * 1 so that aew_system_index_use variable doesn't point (i.e., isn't tied to) to the same reference as aew_system_index)
#                     print (f"Actual AEW system: {aew_system_index_use + 1}, AEW central longitude and strength at {case_date} {nearest_6_to_case_hour} UTC: {np.round(aews_at_given_hour['AEW_lon_smooth'][aew_system_index_use,0].item(), 2)}, {aews_at_given_hour['AEW_strength'][aew_system_index_use,0].item()} s-1")
#             else:
#                 aew_system_index_use = aew_system_index * 1   #the number of the AEW in the AEW tracker file (using * 1 so that aew_system_index_use variable doesn't point (i.e., isn't tied to) to the same reference as aew_system_index)
#                 print (f"AEW system: {aew_system_index_use + 1}, AEW central longitude and strength at {case_date} {nearest_6_to_case_hour} UTC: {np.round(aews_at_given_hour['AEW_lon_smooth'][aew_system_index_use,0].item(), 2)}, {aews_at_given_hour['AEW_strength'][aew_system_index_use,0].item()} s-1")
                
#             timps_ds0.close()
            
#     matched_aew_ds = aews_at_given_date.isel(system = aew_system_index_use)
#     matched_aew_ds_dayafter = aews_at_given_date_plus1.isel(system = aew_system_index_use)

    
#     #plot time evolution of AEW strength for the AEW associated with the given case (put dotted line at the AEW time closest to the sampling of the case)
#     group_fig0 = plt.figure(figsize = (24, 12))   #initialize the time series figure for the given case
#     ax0 = group_fig0.add_subplot(1, 1, 1)
#     #aew_strength = ds_aew['AEW_strength'][aew_system_index_use, :] * 10**5
#     #aew_strength_nonans = aew_strength[~np.isnan(aew_strength)]
#     ax0.plot(ds_aew.time, ds_aew['AEW_strength'][aew_system_index_use, :] * 10**5, linewidth = 2, linestyle = '-', color = 'k')
#     ax0.axvline(x = aews_at_given_hour.time.values[0], color = 'k', linestyle = '--', alpha = 0.5)
    
#     ax0.set_title(f'Case {case_num} AEW Strength Evolution')
#     ax0.set_xlabel('Time [UTC]')
#     ax0.set_ylabel('Non-divergent Curvature Vorticity [10$^{-5}$ s$^{-1}$]')
#     #ax0.set_xlim([aew_strength_nonans.time.values[0], aew_strength_nonans.time.values[-1]])
#     #ax0.set_xticks(ds_aew.time[::60])
#     ax0.set_ylim([0,2])
#     ax0.set_yticks(np.arange(0, 2.01, 0.2))

#     #save the figure
#     plot_save_name = f'Case{case_num}_AEW_strength_time_series.png'
#     plt.savefig(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/AEW_Relative/AEW_strength_evolution', plot_save_name), bbox_inches = 'tight')
#     #plt.show()  #plt.show() must come after plt.savefig() in order for the image to save properly
#     #plt.clf()   #supposedly speeds things up? According to: https://www.youtube.com/watch?v=jGVIZbi9uMY
#     plt.close()
#     plt.clf()    #if placing this after plt.close(), may release memory related to the figure (https://stackoverflow.com/questions/741877/how-do-i-tell-matplotlib-that-i-am-done-with-a-plot)

#     ##decrease file size of the image by 66% without noticeable image effects (if using Matplotlib)
#     ##(good to use if you're producing a lot of images, see https://www.youtube.com/watch?v=fzhAseXp5B4)
#     im = Image.open(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/AEW_Relative/AEW_strength_evolution', plot_save_name))

#     try:
#         im2 = im.convert('P', palette = Image.Palette.ADAPTIVE)
#     except:
#         #use this for older version of PIL/Pillow if the above line doesn't work, 
#         #though this line will have isolated, extremely minor image effects due to 
#         #only using 256 colors instead of the 3-element RGB scale
#         im2 = im.convert('P')

#     im2.save(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/AEW_Relative/AEW_strength_evolution', plot_save_name))
#     im.close()
#     im2.close()
    

#     #plot convergence PDFs at low- and mid-levels
#     for pres_lev in pressures_to_plot_conv:
        
#         group_fig = plt.figure(figsize = (12, 12))   #initialize the convergence PDF figure for the given case
#         ax = group_fig.add_subplot(1, 1, 1)
        
#         for sector in ['full', 'south_ahead', 'ahead', 'north_ahead', 'south_behind', 'behind', 'north_behind']:

#             conv_df = pd.DataFrame()
#             conv_lats_df = pd.DataFrame()
#             #conv_lons_df = pd.DataFrame()

#             for hr in case_hours:

#                 #grab/calculate the AEW centroid coordinates for the given hr
#                 if hr % 6 == 0:  #if hr is a multiple of 6, then don't need to interpolate the AEW centroid at all
#                     matched_aew_smoothed_lat = matched_aew_ds.sel(time = matched_aew_ds.time.dt.hour.isin(hr)).AEW_lat_smooth.item()
#                     matched_aew_smoothed_lon = matched_aew_ds.sel(time = matched_aew_ds.time.dt.hour.isin(hr)).AEW_lon_smooth.item()

#                 else:  #interpolate the centroid of the matched AEW to the given hr

#                     #find the multiple of 6 directly below/equal to hr; this will be your starting hour to interpolate the AEW centroid to hr
#                     nearest_6_below_hr = 6 * (hr // 6)
#                     nearest_6_above_hr = nearest_6_below_hr + 6  #this will be your ending hour to interpolate the AEW centroid to hr

#                     matched_aew_start_hr_ds = matched_aew_ds.sel(time = matched_aew_ds.time.dt.hour.isin(nearest_6_below_hr))

#                     if nearest_6_above_hr == 24:  #grab 00 UTC from the next day
#                         matched_aew_end_hr_ds = matched_aew_ds_dayafter.sel(time = matched_aew_ds_dayafter.time.dt.hour.isin(0))
#                     else:
#                         matched_aew_end_hr_ds = matched_aew_ds.sel(time = matched_aew_ds.time.dt.hour.isin(nearest_6_above_hr))

#                     matched_aew_start_lat = matched_aew_start_hr_ds.AEW_lat_smooth.item()
#                     matched_aew_start_lon = matched_aew_start_hr_ds.AEW_lon_smooth.item()
#                     matched_aew_end_lat = matched_aew_end_hr_ds.AEW_lat_smooth.item()
#                     matched_aew_end_lon = matched_aew_end_hr_ds.AEW_lon_smooth.item()

#                     #AEW centroid moves XX degrees per hour
#                     lat_per_hour = (matched_aew_end_lat - matched_aew_start_lat) / 6
#                     lon_per_hour = (matched_aew_end_lon - matched_aew_start_lon) / 6

#                     #interpolated AEW centroid at the given hr
#                     matched_aew_smoothed_lat = matched_aew_start_lat + (lat_per_hour * (hr % 6))
#                     matched_aew_smoothed_lon = matched_aew_start_lon + (lon_per_hour * (hr % 6))

#                 #create 10-by-10 degree box around the (interpolated) AEW centroid at the given hr
#                 aew_lat_range = slice(matched_aew_smoothed_lat + 5, matched_aew_smoothed_lat - 5)
#                 aew_lon_range = slice(matched_aew_smoothed_lon - 5, matched_aew_smoothed_lon + 5)

#                 #grab all the ERA5 low-/mid-level convergence values and corresponding lats/lons within the given AEW box at the given hour
#                     #and separate the data into the 2 sectors (ahead/behind) of the AEW
#                 v700 = ds_era5.v.sel(time = case_date).sel(level = 700)
#                 v700 = v700.sel(time = v700.time.dt.hour.isin(hr))
#                 v700 = mpcalc.smooth_gaussian(v700, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                 v700 = v700.sel(longitude = aew_lon_range, latitude = aew_lat_range)
                
#                 #manually calculating convergence from ERA5 u and v winds (recommended by Brandon Wolding via George Kiladis)
#                 u = ds_era5.u.sel(time = case_date).sel(level = pres_lev)
#                 u = u.sel(time = u.time.dt.hour.isin(hr))
#                 u = mpcalc.smooth_gaussian(u, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                 #u = u.sel(longitude = aew_lon_range, latitude = aew_lat_range)
                
#                 v = ds_era5.v.sel(time = case_date).sel(level = pres_lev)
#                 v = v.sel(time = v.time.dt.hour.isin(hr))
#                 v = mpcalc.smooth_gaussian(v, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
#                 #v = v.sel(longitude = aew_lon_range, latitude = aew_lat_range)
                
#                 delta_lons = 0.25   #ERA5 lat/lon resolution is 0.25 degrees
#                 delta_lons_meters = (111.3195 * 1000 * delta_lons) * np.cos(u.latitude.values * np.pi/180)  #distance between longitude lines at equator is 111.3195 km and cosine weighting this distance by latitude
#                 dudx = (u[:,:,1:].values - u[:,:,:-1].values).squeeze() / np.expand_dims(np.abs(delta_lons_meters), axis=1)  #squeeze() removes dimensions of size 1 from an array, and expand_dims() inserts a new axis that will appear at the axis position
#                 dudx = np.column_stack((dudx, dudx[:,-1]))  #duplicate the last column of dudx to match original grid shape (and shape of dvdy)

#                 delta_lats = 0.25
#                 delta_lats_meters = 110.5744 * 1000 * delta_lats  #distance between latitude lines everywhere
#                 dvdy = (v[:,:-1,:].values - v[:,1:,:].values).squeeze() / delta_lats_meters  #squeeze() removes dimensions of size 1 from an array
#                 dvdy = np.vstack((dvdy, dvdy[-1,:]))  #duplicate the last row of dvdy to match original grid shape (and shape of dudx)

#                 conv_old = (dudx + dvdy) * -1 * 10**5  #manually calculated convergence from ERA5 u and v winds (times 10**5 1/s)
#                 ds_conv = xr.Dataset(data_vars = dict(convergence = (["latitude", "longitude"], conv_old)),
#                                      coords = dict(latitude = ("latitude", u.latitude.values), 
#                                                    longitude = ("longitude", u.longitude.values)),
#                                      attrs = dict(description = "Manually calculated ERA5 convergence data"))
                
#                 conv = ds_conv.convergence.sel(longitude = aew_lon_range, latitude = aew_lat_range)
                
#                 # #using convergence variable from ERA5
#                 # conv = ds_era5.d.sel(time = case_date).sel(level = pres_lev) * -1    #convergence of the wind (1/s)
#                 # conv = conv.sel(time = conv.time.dt.hour.isin(hr)) * 10**5  #convergence of the wind (times 10**5 1/s)
#                 # conv = conv.sel(longitude = aew_lon_range, latitude = aew_lat_range)
                
#                 #filter the convergence data by the 700-hPa v-component of the wind  
#                     #v <= 0: the grid point is ahead of the AEW center 
#                     #v > 0: the grid point is behind the AEW center
#                         #This dynamically defines ahead/behind AEW centers, which is especially practical for asymmetric AEWs!
#                 if sector == 'full':
#                     conv = conv * 1
#                     clc = 'Full'
#                     color = 'k'
#                     text_denom = 0
#                 elif sector == 'south_ahead':
#                     conv = conv.where(conv.latitude < matched_aew_smoothed_lat).where(v700 <= 0)  #returns elements from 'conv' where condition is True, otherwise fill in NaNs by default
#                     clc = 'Ahead (South)'
#                     color = 'skyblue'
#                     text_denom = 0.14
#                 elif sector == 'ahead':
#                     conv = conv.where(v700 <= 0)  #returns elements from 'conv' where condition is True, otherwise fill in NaNs by default
#                     clc = 'Ahead'
#                     color = 'dodgerblue'
#                     text_denom = 0.28
#                 elif sector == 'north_ahead':
#                     conv = conv.where(conv.latitude >= matched_aew_smoothed_lat).where(v700 <= 0)  #returns elements from 'conv' where condition is True, otherwise fill in NaNs by default
#                     clc = 'Ahead (North)'
#                     color = 'navy'
#                     text_denom = 0.42
#                 elif sector == 'south_behind':
#                     conv = conv.where(conv.latitude < matched_aew_smoothed_lat).where(v700 > 0)   #returns elements from 'conv' where condition is True, otherwise fill in NaNs by default
#                     clc = 'Behind (South)'
#                     color = 'lightsalmon'
#                     text_denom = 0.56
#                 elif sector == 'behind':
#                     conv = conv.where(v700 > 0)   #returns elements from 'conv' where condition is True, otherwise fill in NaNs by default
#                     clc = 'Behind'
#                     color = 'red'
#                     text_denom = 0.7
#                 elif sector == 'north_behind':
#                     conv = conv.where(conv.latitude >= matched_aew_smoothed_lat).where(v700 > 0)   #returns elements from 'conv' where condition is True, otherwise fill in NaNs by default
#                     clc = 'Behind (North)'
#                     color = 'darkred'
#                     text_denom = 0.84

#                 lon, lat = np.meshgrid(conv.longitude, conv.latitude)
#                 lats = lat.reshape(-1)
#                 conv_values = conv.values.reshape(-1)

#                 #add data from each hour as COLUMNS to corresponding df
#                 conv_df = pd.concat((conv_df, pd.Series(conv_values)), axis = 1, ignore_index = True)
#                 conv_lats_df = pd.concat((conv_lats_df, pd.Series(lats)), axis = 1, ignore_index = True)
#                 #conv_lons_df = pd.concat((conv_lons_df, pd.Series(lons)), axis = 1, ignore_index = True)         
                    
#             # if sector == 'ahead':
#             #     clc = 'Ahead'
#             #     color = 'darkred'
#             #     text_denom = 1
#             # elif sector == 'behind':
#             #     clc = 'Behind'
#             #     color = 'blue'
#             #     text_denom = 1.325

#             conv_df = conv_df.values  #convert Pandas DataFrame to NumPy array
#             conv_lats_df = conv_lats_df.values  #convert Pandas DataFrame to NumPy array

#             #mask NaN values in the Dataframes so that the numpy stat calculations work below
#             conv_df_masked = np.ma.masked_where(np.isnan(conv_df), conv_df)
#             conv_lats_df_masked = np.ma.masked_where(np.isnan(conv_lats_df), conv_lats_df)

#             #line plot histogram (clearer to interpret than "step" histogram below)
#             hist, bins = np.histogram(conv_df_masked.reshape(-1), bins = conv_bins, density = True, weights = None)
#             bin_centers = (bins[:-1] + bins[1:]) / 2  # Midpoints of the bins
#             ax.plot(bin_centers, hist, linewidth = 2, linestyle = '-', color = color, label = clc)

#             # #normal "step" histogram
#             # ax.hist(conv_df_masked.reshape(-1), bins = conv_bins, density = True, weights = None,
#             #         histtype = 'step', align = 'mid', orientation = 'vertical', color = color,
#             #         linewidth = 2, label = clc)
#             #     #density = True returns a probability density: each bin will display the bin's raw count 
#             #         #divided by the total number of counts times the bin width
#             #         #(density = counts / (sum(counts) * np.diff(bins))), so that the area under the 
#             #         #histogram integrates to 1 (np.sum(density * np.diff(bins)) == 1)

#             cos_weights = np.sqrt(np.cos(conv_lats_df_masked * np.pi/180))   #cosine weights to apply to conv_df_masked

#             conv_count = np.count_nonzero(~np.isnan(conv_df))
#             conv_median = np.round(np.nanmedian(conv_df, axis = None), 2)
#             conv_mean = np.round(np.nanmean(conv_df, axis = None), 2)                                        #non-weighted mean (1st moment)
#             conv_wgt_mean = np.round(np.average(conv_df_masked, axis = None, weights = cos_weights), 2)      #cosine-weighted mean (1st moment)
#             conv_std = np.round(np.std(conv_df_masked, axis = None), 2)                                      #standard deviation (2nd moment)
#             conv_skew = np.round(scipy.stats.skew(conv_df_masked, axis = None, nan_policy = 'omit'), 4)      #skewness (3rd moment)
#             conv_kurt = np.round(scipy.stats.kurtosis(conv_df_masked, axis = None, nan_policy = 'omit'), 4)  #kurtosis (4th moment)

#             # ax.text(0.98, 0.875 / text_denom, 
#             #         f'Count: {conv_count}\nMedian: {conv_median}\nMean: {conv_mean}\nWeighted Mean: {conv_wgt_mean}\nStandard Deviation: {conv_std}\nSkewness: {conv_skew}\nKurtosis: {conv_kurt}\n', 
#             #         transform = ax.transAxes, horizontalalignment = 'right', verticalalignment = 'center', 
#             #         fontsize = 16, fontweight = 'bold', color = color)
#             ax.text(0.98, 0.92 - text_denom, 
#                     f'Count: {conv_count}\nMedian: {conv_median}\nMean: {conv_mean}\nWeighted Mean: {conv_wgt_mean}\nStandard Deviation: {conv_std}\nSkewness: {conv_skew}\nKurtosis: {conv_kurt}\n', 
#                     transform = ax.transAxes, horizontalalignment = 'right', verticalalignment = 'center', 
#                     fontsize = 12, fontweight = 'bold', color = color)
        
#         ax.set_title('Case %i (%s-%s-%s) ERA5 %i hPa Convergence PDFs (AEW-Relative)' % (case_num, case_date[:4], case_date[4:6], case_date[6:], pres_lev))
        
#         if case_num == 20:
#             ax.text(0.5, 0.5, f'Case {case_num} not being correctly\nmatched to the appropriate AEW', 
#             transform = ax.transAxes, horizontalalignment = 'center', verticalalignment = 'center', 
#             fontsize = 30, bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 10})
        
#         ax.axvline(x = 0, color = 'k', linestyle = '--', alpha = 0.5)
#         ax.set_xlabel('Convergence [10$^{-5}$ s$^{-1}$]')
#         ax.set_ylabel('Prob(Convergence)')
#         ax.set_xlim([-11,11])
#         ax.set_xticks(np.arange(-11, 11.1, 2))
#         ax.set_ylim(bottom = 0)
#         #ax.set_yticks(np.arange(0, 0.601, 0.05))
#         ax.grid(axis = 'y')
#         ax.legend(title = 'Sector of AEW', title_fontproperties = {'weight': 'bold', 'size': 18}, loc = 'upper left')
                    
#         #plt.tight_layout()
#         #plt.subplots_adjust(wspace = 0.1)

#         #save the figure
#         plot_save_name = f'0_Case{case_num}_{pres_lev}hPa_convergence_PDFs_AEW_relative_all_sectors.png'
#         plt.savefig(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/AEW_Relative/Smoothed_ERA5_winds', plot_save_name), bbox_inches = 'tight')
#         #plt.show()  #plt.show() must come after plt.savefig() in order for the image to save properly
#         #plt.clf()   #supposedly speeds things up? According to: https://www.youtube.com/watch?v=jGVIZbi9uMY
#         plt.close()
#         plt.clf()    #if placing this after plt.close(), may release memory related to the figure (https://stackoverflow.com/questions/741877/how-do-i-tell-matplotlib-that-i-am-done-with-a-plot)

#         ##decrease file size of the image by 66% without noticeable image effects (if using Matplotlib)
#         ##(good to use if you're producing a lot of images, see https://www.youtube.com/watch?v=fzhAseXp5B4)
#         im = Image.open(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/AEW_Relative/Smoothed_ERA5_winds', plot_save_name))

#         try:
#             im2 = im.convert('P', palette = Image.Palette.ADAPTIVE)
#         except:
#             #use this for older version of PIL/Pillow if the above line doesn't work, 
#             #though this line will have isolated, extremely minor image effects due to 
#             #only using 256 colors instead of the 3-element RGB scale
#             im2 = im.convert('P')

#         im2.save(os.path.join('/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/AEW_Relative/Smoothed_ERA5_winds', plot_save_name))
#         im.close()
#         im2.close()
    
#     print (f'Case {case_num} convergence PDF plots complete!\n')

# ds_era5.close()
# ds_aew.close()

# tend = time.time()
# print (f'This script took {np.round((tend - tstart) / 60, 1)} minutes to complete.')


### THE FOLLOWING CELL CREATES PLOTS OF AEW-RELATIVE CONVERGENCE PDFs COMPARING 2 CASES (1 PANEL PER AEW SECTOR)

In [None]:
#THIS CELL PLOTS A 1-PANEL PLOT OF AEW-RELATIVE CONVERGENCE PDFs

#For each CPEX-CV case, calculate and plot domain-PDFs of low- (975 hPa) and mid- (700 hPa) level convergence 
#(display median, mean, standard deviation, skewness, and kurtosis as well), with the domain being a 
#10-by-10 degree box around the AEW center for the given case (get from Quinton’s AEW tracker: https://osf.io/jnv5u, https://zenodo.org/records/13350860)
    #partition the data into AEW-relative sectors for the given case (the full box plus the ahead/behind and north/south subsets named in the sector loop below)
    #for calculating domain-mean convergence, cosine-weight each grid box value (see AOS 573 material)
    
    #WON'T NEED TO KEEP TRACK OF WHICH GRID CELLS YOU HAVE ALREADY ADDED CONVERGENCE FOR, SINCE YOU'RE NOT
        #WORKING WITH MULTIPLE AEWs AT A GIVEN HOUR (LIKE YOU WERE WITH MULTIPLE TIMPS IDs PER HOUR)
        
    #Also don't need to split up convergence PDFs by convective lifecycle (just ahead/behind the AEW), because
        #it would be difficult to relate convergence of an AEW region to convective lifecycle, since one
        #AEW region could (and likely often does) have convective systems that are in different lifecycle stages

#baseline Matplotlib styling, applied in a single batch so every figure made below shares it
matplotlib.rcParams.update({
    'axes.labelsize': 18,
    'axes.titlesize': 18,
    'axes.labelweight': 'bold',
    'axes.titleweight': 'bold',
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    'legend.fontsize': 16,
    'font.family': 'arial',
    'hatch.linewidth': 0.3,
    #optional overrides that are intentionally left disabled:
    #'axes.facecolor': [0.9,0.9,0.9],
    #'legend.facecolor': 'w',
    #'axes.facecolor': 'w',
})

#build every input location first (all are relative to the current working directory)
drop_metric_filepath = os.path.join(os.getcwd(), 'Dropsonde_Metric_Calculations_CPEXCV.csv')  #Dropsonde data

era5_folder = os.path.join(os.getcwd(), 'ERA5_Reanalysis_Data')  #ERA5 data
era5_path = os.path.join(era5_folder, 'CPEXCV_ERA5_Reanalysis_Hourly_Pressure.nc')

#AEW tracker data, 6-hourly (Quinton’s AEW tracker: https://osf.io/jnv5u, https://zenodo.org/records/13350860)
aew_folder = os.path.join(os.getcwd(), 'AEW_Tracker_Data')
aew_path = os.path.join(aew_folder, 'AEW_tracks_post_processed_year_2022.nc')

timps_folder = os.path.join(os.getcwd(), 'TIMPS_data')  #TIMPS data (only the folder is resolved here)

#then read the inputs, in the order dropsonde table -> ERA5 -> AEW tracker
df_drop = pd.read_csv(drop_metric_filepath)
ds_era5 = xr.open_dataset(era5_path)
ds_aew = xr.open_dataset(aew_path)

#histogram bin edges for the convergence PDFs: np.arange from -15 in steps of 0.5, stopping before 15.1
conv_bins = np.arange(-15, 15.1, 0.5)

#plot convergence PDFs at low- and mid-levels for each AEW sector for each case
#
#For every (pressure level, AEW sector) pair, one figure is produced that overlays the PDF of
#manually calculated ERA5 convergence for each convective case. For each case:
#   1) the case's TIMPS ID(s) are matched (by longitude) to one AEW in the AEW tracker file,
#   2) for each case hour, a 10-by-10 degree box is centered on the (interpolated) AEW centroid,
#   3) convergence inside the box is filtered to the requested sector and pooled over all case hours,
#   4) the pooled values are plotted as a line histogram and annotated with summary statistics.
#Each figure is saved as a PNG and then re-saved with an adaptive palette to shrink the file.

#every figure is written to (and then re-saved, palette-compressed, in) this one directory
conv_pdf_save_dir = '/Users/ben/Desktop/CPEX/CPEX-CV_Convergence_PDFs/Quinton_new_AEW_tracker/Using_Calculated_ERA5_convergence/AEW_Relative/Smoothed_ERA5_winds'

#sector name -> label used in the figure title (insertion order is the plotting order)
sector_titles = {
    'full': 'Full',
    'south_ahead': 'Ahead (South)',
    'ahead': 'Ahead',
    'north_ahead': 'Ahead (North)',
    'south_behind': 'Behind (South)',
    'behind': 'Behind',
    'north_behind': 'Behind (North)',
    'north': 'North',
}

case_keys = list(case_dict_conv_aew.keys())

#line/text color per case position; the first two entries are the original colors, the rest are
#only reached if more than two cases are supplied (colors cycle if the list is exhausted)
case_colors = ['darkred', 'navy', 'darkgreen', 'darkorange', 'purple']

for pres_lev in pressures_to_plot_conv:
    for sector, clc in sector_titles.items():

        group_fig = plt.figure(figsize = (12, 12))   #initialize the convergence PDF figure for the given level/sector
        ax = group_fig.add_subplot(1, 1, 1)

        for case_idx, case_num in enumerate(case_keys):

            #print (f'Case {case_num} convergence PDF plots in progress...')

            #color and vertical offset of the stats text are tied to the case's position in the dict,
            #so they are always defined (and never stale) regardless of how many cases there are
            color = case_colors[case_idx % len(case_colors)]
            text_denom = 0.2 * case_idx

            aew_system_index_use = None

            df_drop_case = df_drop[df_drop['Case'] == case_num].copy()

            case_date = case_dict_conv_aew[case_num][0]
            case_hours = case_dict_conv_aew[case_num][1]

            #drop missing TIMPS IDs up front so int(NaN) can never be hit below
            case_timps_ids = df_drop_case['TIMPS ID'].dropna().unique()

            if len(case_timps_ids) == 0:  #no TIMPS IDs for the given case
                continue

            aews_at_given_date = ds_aew.sel(time = case_date)
            day_after = datetime.strptime(case_date, '%Y%m%d') + timedelta(days = 1)  #day after case_date
            aews_at_given_date_plus1 = ds_aew.sel(time = datetime.strftime(day_after, '%Y%m%d'))

            #match the convective case to the nearest AEW (longitudinally) from Quinton's tracker
                #and confirm that each TIMPS ID for the given case matches to the same AEW (sanity check)

            #find the case hour closest to a multiple of 6 (AEW tracker is only in 6-hourly intervals);
            #the distance must be measured in both directions (e.g., 11 UTC is 1 hour from 12 UTC),
            #so use min(remainder, 6 - remainder) rather than the remainder alone
            hour_remainders = np.asarray(case_hours) % 6
            hours_from_multiple6 = np.minimum(hour_remainders, 6 - hour_remainders)
            case_hour_nearest_multiple6 = case_hours[np.nanargmin(hours_from_multiple6)]
            nearest_6_to_case_hour = 6 * round(case_hour_nearest_multiple6 / 6)  #find the multiple of 6 closest to case_hour_nearest_multiple6

            #AEW tracker fields at the matching synoptic hour; a result of 24 means 00 UTC of the next day
            if nearest_6_to_case_hour == 24:
                aews_at_given_hour = aews_at_given_date_plus1.sel(time = aews_at_given_date_plus1.time.dt.hour.isin(0))
            else:
                aews_at_given_hour = aews_at_given_date.sel(time = aews_at_given_date.time.dt.hour.isin(nearest_6_to_case_hour))

            for unique_timps_id in case_timps_ids:

                unique_timps_id = str(int(unique_timps_id))

                #NOTE(review): this is a substring match against the whole file name, so a short ID could
                #also match digits in another ID or in the timestamp part of the name -- confirm that the
                #IDs in use are long enough for this to be unambiguous
                timps_filepath = None
                for filename in os.listdir(timps_folder):
                    if unique_timps_id in filename:
                        timps_filepath = os.path.join(timps_folder, filename)
                        break

                if timps_filepath is None:
                    sys.exit(f'Could not find TIMPS file for TIMP ID {unique_timps_id}')

                #context manager guarantees the TIMPS file is released on every path (including `continue`)
                with xr.open_dataset(timps_filepath) as timps_full_ds:
                    timps_ds0 = timps_full_ds.sel(time = case_date)
                    timps_ds = timps_ds0.sel(time = timps_ds0.time.dt.hour.isin(case_hour_nearest_multiple6))   #gives 2 times: minute = 0 and minute = 30
                    timps_ds = timps_ds.sel(time = timps_ds.time.dt.minute.isin(0))    #grab the time on the hour to match with ERA5 and AEW tracker

                    if len(timps_ds.gmd) == 0:
                        print (f'{str(case_hour_nearest_multiple6).zfill(2)} UTC is out of range of the TIMPS ID range ({timps_ds0.time[0].values.astype(str)[:-10]} - {timps_ds0.time[-1].values.astype(str)[:-10]})')
                        continue

                    #timps_weighted_lat = timps_ds.centlatwgt.item()
                    timps_weighted_lon = timps_ds.centlonwgt.item()

                lon_difs = aews_at_given_hour['AEW_lon_smooth'][:,0] - timps_weighted_lon
                aew_system_index = np.nanargmin(np.abs(lon_difs).values)   #np.argmin() would grab a NaN value!

                if aew_system_index_use is not None:
                    if case_num != 22:
                        assert aew_system_index == aew_system_index_use, 'TIMPS IDs for the given case are not matching to the same AEW'
                    else:  #Case 22 has 2 TIMPS IDs which incorrectly match to different AEWs;
                           #the 2nd TIMPS ID (304352) matches to the correct AEW, so we're forcing the code to choose that AEW here (not ideal, but just one case we need to do this for)
                        aew_system_index_use = aew_system_index * 1   #the number of the AEW in the AEW tracker file (using * 1 so that aew_system_index_use variable doesn't point (i.e., isn't tied to) to the same reference as aew_system_index)
                        print (f"Actual AEW system: {aew_system_index_use + 1}, AEW central longitude and strength at {case_date} {nearest_6_to_case_hour} UTC: {np.round(aews_at_given_hour['AEW_lon_smooth'][aew_system_index_use,0].item(), 2)}, {aews_at_given_hour['AEW_strength'][aew_system_index_use,0].item()} s-1")
                else:
                    aew_system_index_use = aew_system_index * 1   #the number of the AEW in the AEW tracker file (using * 1 so that aew_system_index_use variable doesn't point (i.e., isn't tied to) to the same reference as aew_system_index)
                    print (f"AEW system: {aew_system_index_use + 1}, AEW central longitude and strength at {case_date} {nearest_6_to_case_hour} UTC: {np.round(aews_at_given_hour['AEW_lon_smooth'][aew_system_index_use,0].item(), 2)}, {aews_at_given_hour['AEW_strength'][aew_system_index_use,0].item()} s-1")

            if aew_system_index_use is None:
                #every TIMPS ID was out of range at the matching hour, so there is no AEW to build a box around
                print (f'Case {case_num}: no TIMPS ID could be matched to an AEW; skipping this case')
                continue

            matched_aew_ds = aews_at_given_date.isel(system = aew_system_index_use)
            matched_aew_ds_dayafter = aews_at_given_date_plus1.isel(system = aew_system_index_use)

            conv_df = pd.DataFrame()
            conv_lats_df = pd.DataFrame()
            #conv_lons_df = pd.DataFrame()

            for hr in case_hours:

                #grab/calculate the AEW centroid coordinates for the given hr
                if hr % 6 == 0:  #if hr is a multiple of 6, then don't need to interpolate the AEW centroid at all
                    matched_aew_hr_ds = matched_aew_ds.sel(time = matched_aew_ds.time.dt.hour.isin(hr))
                    matched_aew_smoothed_lat = matched_aew_hr_ds.AEW_lat_smooth.item()
                    matched_aew_smoothed_lon = matched_aew_hr_ds.AEW_lon_smooth.item()

                else:  #interpolate the centroid of the matched AEW to the given hr

                    #find the multiple of 6 directly below/equal to hr; this will be your starting hour to interpolate the AEW centroid to hr
                    nearest_6_below_hr = 6 * (hr // 6)
                    nearest_6_above_hr = nearest_6_below_hr + 6  #this will be your ending hour to interpolate the AEW centroid to hr

                    matched_aew_start_hr_ds = matched_aew_ds.sel(time = matched_aew_ds.time.dt.hour.isin(nearest_6_below_hr))

                    if nearest_6_above_hr == 24:  #grab 00 UTC from the next day
                        matched_aew_end_hr_ds = matched_aew_ds_dayafter.sel(time = matched_aew_ds_dayafter.time.dt.hour.isin(0))
                    else:
                        matched_aew_end_hr_ds = matched_aew_ds.sel(time = matched_aew_ds.time.dt.hour.isin(nearest_6_above_hr))

                    matched_aew_start_lat = matched_aew_start_hr_ds.AEW_lat_smooth.item()
                    matched_aew_start_lon = matched_aew_start_hr_ds.AEW_lon_smooth.item()
                    matched_aew_end_lat = matched_aew_end_hr_ds.AEW_lat_smooth.item()
                    matched_aew_end_lon = matched_aew_end_hr_ds.AEW_lon_smooth.item()

                    #AEW centroid moves XX degrees per hour
                    lat_per_hour = (matched_aew_end_lat - matched_aew_start_lat) / 6
                    lon_per_hour = (matched_aew_end_lon - matched_aew_start_lon) / 6

                    #interpolated AEW centroid at the given hr
                    matched_aew_smoothed_lat = matched_aew_start_lat + (lat_per_hour * (hr % 6))
                    matched_aew_smoothed_lon = matched_aew_start_lon + (lon_per_hour * (hr % 6))

                #create 10-by-10 degree box around the (interpolated) AEW centroid at the given hr
                #(latitude slice runs high -> low, which selects a non-empty range only if the ERA5
                #latitude coordinate is stored in descending order)
                aew_lat_range = slice(matched_aew_smoothed_lat + 5, matched_aew_smoothed_lat - 5)
                aew_lon_range = slice(matched_aew_smoothed_lon - 5, matched_aew_smoothed_lon + 5)

                #grab all the ERA5 low-/mid-level convergence values and corresponding lats/lons within the given AEW box at the given hour
                    #and separate the data into the 2 sectors (ahead/behind) of the AEW
                v700 = ds_era5.v.sel(time = case_date).sel(level = 700)
                v700 = v700.sel(time = v700.time.dt.hour.isin(hr))
                v700 = mpcalc.smooth_gaussian(v700, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)
                v700 = v700.sel(longitude = aew_lon_range, latitude = aew_lat_range)

                #manually calculating convergence from ERA5 u and v winds (recommended by Brandon Wolding via George Kiladis)
                #(derivatives are taken on the full grid and the AEW box is cut out afterwards)
                u = ds_era5.u.sel(time = case_date).sel(level = pres_lev)
                u = u.sel(time = u.time.dt.hour.isin(hr))
                u = mpcalc.smooth_gaussian(u, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)

                v = ds_era5.v.sel(time = case_date).sel(level = pres_lev)
                v = v.sel(time = v.time.dt.hour.isin(hr))
                v = mpcalc.smooth_gaussian(v, 5)   #smooth ERA5 winds using a 5-point filter (Quinton)

                delta_lons = 0.25   #ERA5 lat/lon resolution is 0.25 degrees
                delta_lons_meters = (111.3195 * 1000 * delta_lons) * np.cos(u.latitude.values * np.pi/180)  #distance between longitude lines at equator is 111.3195 km and cosine weighting this distance by latitude
                dudx = (u[:,:,1:].values - u[:,:,:-1].values).squeeze() / np.expand_dims(np.abs(delta_lons_meters), axis=1)  #squeeze() removes dimensions of size 1 from an array, and expand_dims() inserts a new axis that will appear at the axis position
                dudx = np.column_stack((dudx, dudx[:,-1]))  #duplicate the last column of dudx to match original grid shape (and shape of dvdy)

                delta_lats = 0.25
                delta_lats_meters = 110.5744 * 1000 * delta_lats  #distance between latitude lines everywhere
                dvdy = (v[:,:-1,:].values - v[:,1:,:].values).squeeze() / delta_lats_meters  #squeeze() removes dimensions of size 1 from an array
                dvdy = np.vstack((dvdy, dvdy[-1,:]))  #duplicate the last row of dvdy to match original grid shape (and shape of dudx)

                conv_full_grid = (dudx + dvdy) * -1 * 10**5  #manually calculated convergence from ERA5 u and v winds (times 10**5 1/s)
                ds_conv = xr.Dataset(data_vars = dict(convergence = (["latitude", "longitude"], conv_full_grid)),
                                     coords = dict(latitude = ("latitude", u.latitude.values),
                                                   longitude = ("longitude", u.longitude.values)),
                                     attrs = dict(description = "Manually calculated ERA5 convergence data"))

                conv = ds_conv.convergence.sel(longitude = aew_lon_range, latitude = aew_lat_range)

                # #using convergence variable from ERA5
                # conv = ds_era5.d.sel(time = case_date).sel(level = pres_lev) * -1    #convergence of the wind (1/s)
                # conv = conv.sel(time = conv.time.dt.hour.isin(hr)) * 10**5  #convergence of the wind (times 10**5 1/s)
                # conv = conv.sel(longitude = aew_lon_range, latitude = aew_lat_range)

                #filter the convergence data by latitude relative to the AEW centroid and by the
                #700-hPa v-component of the wind
                    #v <= 0: the grid point is ahead of the AEW center
                    #v > 0: the grid point is behind the AEW center
                        #This dynamically defines ahead/behind AEW centers, which is especially practical for asymmetric AEWs!
                #.where() keeps elements where the condition is True and fills in NaNs elsewhere;
                #'full' matches none of the tests below and is therefore left unfiltered
                if sector.startswith('south'):
                    conv = conv.where(conv.latitude < matched_aew_smoothed_lat)
                elif sector.startswith('north'):
                    conv = conv.where(conv.latitude >= matched_aew_smoothed_lat)

                if sector.endswith('ahead'):
                    conv = conv.where(v700 <= 0)
                elif sector.endswith('behind'):
                    conv = conv.where(v700 > 0)

                _, lat = np.meshgrid(conv.longitude, conv.latitude)
                lats = lat.reshape(-1)
                conv_values = conv.values.reshape(-1)

                #add data from each hour as COLUMNS to corresponding df
                #(columns of unequal length are padded with NaN by pd.concat)
                conv_df = pd.concat((conv_df, pd.Series(conv_values)), axis = 1, ignore_index = True)
                conv_lats_df = pd.concat((conv_lats_df, pd.Series(lats)), axis = 1, ignore_index = True)
                #conv_lons_df = pd.concat((conv_lons_df, pd.Series(lons)), axis = 1, ignore_index = True)

            conv_df = conv_df.values  #convert Pandas DataFrame to NumPy array
            conv_lats_df = conv_lats_df.values  #convert Pandas DataFrame to NumPy array

            #mask NaN values in the Dataframes so that the numpy stat calculations work below
            conv_df_masked = np.ma.masked_where(np.isnan(conv_df), conv_df)
            conv_lats_df_masked = np.ma.masked_where(np.isnan(conv_lats_df), conv_lats_df)

            #line plot histogram (clearer to interpret than "step" histogram below);
            #only the valid (unmasked) values are binned
            hist, bins = np.histogram(conv_df_masked.compressed(), bins = conv_bins, density = True, weights = None)
            bin_centers = (bins[:-1] + bins[1:]) / 2  # Midpoints of the bins
            ax.plot(bin_centers, hist, linewidth = 2, linestyle = '-', color = color, label = f'Case {case_num}')

            # #normal "step" histogram
            # ax.hist(conv_df_masked.reshape(-1), bins = conv_bins, density = True, weights = None,
            #         histtype = 'step', align = 'mid', orientation = 'vertical', color = color,
            #         linewidth = 2, label = f'Case {case_num}')
            #     #density = True returns a probability density: each bin will display the bin's raw count
            #         #divided by the total number of counts times the bin width
            #         #(density = counts / (sum(counts) * np.diff(bins))), so that the area under the
            #         #histogram integrates to 1 (np.sum(density * np.diff(bins)) == 1)

            #NOTE(review): the weights are sqrt(cos(lat)); an area-weighted domain mean would normally
            #use cos(lat) itself -- confirm that the square root is intended here
            cos_weights = np.sqrt(np.cos(conv_lats_df_masked * np.pi/180))   #cosine weights to apply to conv_df_masked

            conv_count = np.count_nonzero(~np.isnan(conv_df))
            conv_median = np.round(np.nanmedian(conv_df, axis = None), 2)
            conv_mean = np.round(np.nanmean(conv_df, axis = None), 2)                                        #non-weighted mean (1st moment)
            #np.ma.average (not np.average) so that the weights of masked/NaN convergence cells are
            #excluded from the normalization; otherwise sector-filtered means are biased toward zero
            conv_wgt_mean = np.round(np.ma.average(conv_df_masked, axis = None, weights = cos_weights), 2)   #cosine-weighted mean (1st moment)
            conv_std = np.round(np.std(conv_df_masked, axis = None), 2)                                      #standard deviation (2nd moment)
            conv_skew = np.round(scipy.stats.skew(conv_df_masked, axis = None, nan_policy = 'omit'), 4)      #skewness (3rd moment)
            conv_kurt = np.round(scipy.stats.kurtosis(conv_df_masked, axis = None, nan_policy = 'omit'), 4)  #kurtosis (4th moment)

            #stats block in the upper right; each successive case is shifted down by 0.2 (axes fraction)
            ax.text(0.98, 0.89 - text_denom,
                    f'Count: {conv_count}\nMedian: {conv_median}\nMean: {conv_mean}\nWeighted Mean: {conv_wgt_mean}\nStandard Deviation: {conv_std}\nSkewness: {conv_skew}\nKurtosis: {conv_kurt}\n',
                    transform = ax.transAxes, horizontalalignment = 'right', verticalalignment = 'center',
                    fontsize = 16, fontweight = 'bold', color = color)

            if case_num == 20:
                ax.text(0.5, 0.5, f'Case {case_num} not being correctly\nmatched to the appropriate AEW',
                transform = ax.transAxes, horizontalalignment = 'center', verticalalignment = 'center',
                fontsize = 30, bbox = {'facecolor': 'white', 'alpha': 0.5, 'pad': 10})

        ax.set_title('ERA5 %i hPa Convergence PDFs (AEW-Relative, %s)' % (pres_lev, clc))
        ax.axvline(x = 0, color = 'k', linestyle = '--', alpha = 0.5)
        ax.set_xlabel('Convergence [10$^{-5}$ s$^{-1}$]')
        ax.set_ylabel('Prob(Convergence)')
        ax.set_xlim([-11,11])
        ax.set_xticks(np.arange(-11, 11.1, 2))
        ax.set_ylim(bottom = 0)
        #ax.set_yticks(np.arange(0, 0.601, 0.05))
        ax.grid(axis = 'y')
        ax.legend(loc = 'upper left')

        #plt.tight_layout()
        #plt.subplots_adjust(wspace = 0.1)

        #save the figure (file name lists every case number, joined by '-')
        case_label = '-'.join(str(key) for key in case_keys)
        plot_save_name = f'Case{case_label}_{pres_lev}hPa_convergence_PDFs_AEW_relative_{sector}.png'
        saved_png_path = os.path.join(conv_pdf_save_dir, plot_save_name)
        group_fig.savefig(saved_png_path, bbox_inches = 'tight')
        #plt.show()  #plt.show() must come after savefig() in order for the image to save properly

        #close exactly this figure to release its memory; a plt.clf() after the close is avoided
        #because, with no figure open, it would create a brand-new empty figure
        plt.close(group_fig)

        ##decrease file size of the image by 66% without noticeable image effects (if using Matplotlib)
        ##(good to use if you're producing a lot of images, see https://www.youtube.com/watch?v=fzhAseXp5B4)
        with Image.open(saved_png_path) as im:
            try:
                im2 = im.convert('P', palette = Image.Palette.ADAPTIVE)
            except AttributeError:
                #older versions of PIL/Pillow have no Image.Palette enum; fall back to the default
                #conversion, though this will have isolated, extremely minor image effects due to
                #only using 256 colors instead of the 3-element RGB scale
                im2 = im.convert('P')

            im2.save(saved_png_path)
            im2.close()

        #print (f'Case {case_num} convergence PDF plots complete!\n')
    
#release the ERA5 and AEW tracker datasets now that all figures have been written
for open_ds in (ds_era5, ds_aew):
    open_ds.close()

#report total wall-clock runtime (tstart is recorded at the top of the script)
tend = time.time()
elapsed_minutes = np.round((tend - tstart) / 60, 1)
print (f'This script took {elapsed_minutes} minutes to complete.')
