In [None]:
import numpy as np
import xarray as xr
import scipy.stats as st
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import requests
import os,errno
import sys
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER,LATITUDE_FORMATTER
import matplotlib.ticker as mticker
import cartopy.feature as cfeature
import datetime as dt
import pandas as pd
import time
import urllib.request
import metpy.calc as mpcalc
import salem
import scipy.optimize as opt
import warnings
import geopy.distance
import matplotlib as mpl
from scipy.spatial import ConvexHull

from scipy.ndimage.measurements import label
from scipy.ndimage import binary_dilation

# Silence every warning for the whole session.
# NOTE(review): this also hides genuinely useful warnings (e.g. all-NaN slices).
warnings.filterwarnings('ignore')

# Data directories. The defaults are the original author's machine; override them
# with the MERRA2_DIR_DATA / MERRA2_DIR3 / MERRA2_DIR2 environment variables.
# os.path.join(...,'') guarantees a trailing separator, because file paths in
# this notebook are built by plain string concatenation (dir2+'file.nc').
dir_data=os.path.join(os.environ.get('MERRA2_DIR_DATA','/Users/ahenny/'),'')
dir3=os.path.join(os.environ.get('MERRA2_DIR3','/Volumes/My Passport/'),'')
dir2=os.path.join(os.environ.get('MERRA2_DIR2','/Volumes/Extreme Pro/'),'')

In [None]:
# Grid-cell areas on the MERRA-2 lat/lon grid, divided by 1e6 (presumably
# m^2 -> km^2, consistent with the km lengths used in the detection loop).
df=xr.open_dataset(dir_data+'merra2_gridareas.nc')
areas1=df['cell_area'].transpose('lat','lon')/float(1e6)

# Pad the grid by wrapping 90 degrees of longitude onto each side: columns with
# lon>=90 are copied to lon-360 (west edge) and columns with lon<-90 are copied
# to lon+360 (east edge).
lon_values=areas1.lon.values
lons_east=[lon for lon in lon_values if lon>=90.]
lons_west=[lon for lon in lon_values if lon<-90.]

areas_east=areas1.sel(lon=lons_east).assign_coords(lon=[lon-360. for lon in lons_east])
areas_west=areas1.sel(lon=lons_west).assign_coords(lon=[lon+360. for lon in lons_west])

areas=xr.concat([areas_east,areas1,areas_west],dim='lon')

In [None]:
# Sanity check of the padded area grid: largest cell area, the first ten
# longitudes and latitudes, then the full DataArray repr.
for item in (areas.max().values, areas.lon.values[:10], areas.lat.values[:10], areas):
    print(item)

In [None]:
# Convert each year's 3-hourly IVT file into daily means (8 samples per day) and
# write them to merra2.ivt.dly.model.<year>.nc on dir2.
yrs=np.arange(1980,2023,1)
for year in yrs:
    print(year)
    file_3hr=dir2+'merra2.ivt.3hr.model.'+str(year)+'.nc'
    file_dly=dir2+'merra2.ivt.dly.model.'+str(year)+'.nc'

    with xr.open_dataset(file_3hr) as ds:
        ivt=ds['ivt']
        # A complete year has 8 samples per day; year%4 is the correct leap
        # rule for 1980-2022. A mismatch is reported but processing continues.
        n_days=366 if year%4==0 else 365
        if ivt.time.size!=8*n_days:
            print('FAIL')

        # Mean over consecutive, non-overlapping blocks of 8 time steps (one day).
        # boundary='trim' drops a trailing partial block (as the old int(size/8)
        # loop did). coord_func 'min' stamps each day with its first time step,
        # so the output keeps real dates; the old per-day loop dropped the time
        # coordinate, leaving a bare 0..N-1 index, and re-concatenated the growing
        # array every day (quadratic in the number of days).
        ivt_daily=(ivt.coarsen(time=8,boundary='trim',coord_func={'time':'min'})
                      .mean()
                      .transpose('time','lat','lon')
                      .load())

    dk=xr.Dataset()
    dk['ivt']=(('time','lat','lon'),ivt_daily.values)
    dk.coords['time']=ivt_daily.time
    dk.coords['lat']=ivt_daily.lat
    dk.coords['lon']=ivt_daily.lon
    dk['lat'].attrs["units"]='degrees_north'
    dk['lon'].attrs["units"]='degrees_east'
    dk['ivt'].attrs["units"]='kg*m-1*s-1'
    try:
        # Remove any previous output first (file may not exist yet).
        os.remove(file_dly)
    except OSError:
        pass
    dk.to_netcdf(file_dly,mode='w',format='NETCDF4')

In [None]:
# Per-grid-cell IVT thresholds: 90th percentile of daily IVT over 1980-2022,
# clamped to [50, 300] kg m-1 s-1. Computed in 30x30 degree boxes and
# checkpointed to merra2_ivt_threshold_test.nc after every box; a box whose
# values still sum to zero is treated as "not computed yet", so an interrupted
# run resumes where it stopped.

# One-off: create the all-zero checkpoint file on the daily IVT grid, then stop.
CREATE_EMPTY_THRESHOLD_FILE=False
if CREATE_EMPTY_THRESHOLD_FILE:
    ds=xr.open_dataset(dir2+'merra2.ivt.dly.model.'+str(1980)+'.nc')
    ivt=ds['ivt']

    ivt_threshold_final=xr.zeros_like(ivt[0,:,:])

    dk=xr.Dataset()
    dk['threshold']=(('lat','lon'),ivt_threshold_final.values)
    dk.coords['lat']=ivt_threshold_final.lat
    dk.coords['lon']=ivt_threshold_final.lon
    dk['lat'].attrs["units"]='degrees_north'
    dk['lon'].attrs["units"]='degrees_east'
    dk['threshold'].attrs["units"]='kg*m-1*s-1'

    try:
        os.remove(dir2+'merra2_ivt_threshold_test.nc')
    except OSError:
        pass
    dk.to_netcdf(dir2+'merra2_ivt_threshold_test.nc',mode='w',format='NETCDF4')
    sys.exit()


yrs=np.arange(1980,2023,1)

lon_bins=np.arange(-180,210,30)
lat_bins=np.arange(-90,120,30)

# Every calendar day in yrs (year%4 is the correct leap rule for 1980-2022).
# NOTE(review): not used by the threshold loop below.
dates_all=[]
for year in yrs:
    year_length=366 if year%4==0 else 365
    date_start=dt.datetime(year,1,1,0)
    dates_all.extend(date_start+dt.timedelta(days=x) for x in range(year_length))


def _bin_span(coord_values,lo,hi,closed_right=False):
    """Locate the coordinate values that fall in one bin.

    The bin is [lo, hi), or [lo, hi] when closed_right is True.
    Returns (start, stop, values): slice bounds into coord_values spanning the
    first..last matching index, and the matching coordinate values themselves.
    Raises ValueError if no coordinate value falls in the bin.
    """
    coord_values=np.asarray(coord_values)
    upper=coord_values<=hi if closed_right else coord_values<hi
    mask=(coord_values>=lo)&upper
    idx=np.flatnonzero(mask)
    if idx.size==0:
        raise ValueError('no grid points in bin [%s, %s]'%(lo,hi))
    return idx[0],idx[-1]+1,coord_values[mask]


# Load the checkpoint once and keep it in memory; the in-memory grid is written
# back after every computed box, so the file always matches it.
with xr.open_dataset(dir2+'merra2_ivt_threshold_test.nc') as ds1:
    threshold=ds1['threshold'].values.copy()
    lat_coord=ds1['lat'].load()
    lon_coord=ds1['lon'].load()

n_lat_bins=len(lat_bins)-1
for j in range(n_lat_bins):
    for k in range(len(lon_bins)-1):
        print((j,k))
        # The last latitude bin is closed ([60, 90]) so the pole row is included;
        # with half-open bins lat=90 was never assigned and stayed 0.
        lat_start,lat_stop,lat_range=_bin_span(lat_coord.values,lat_bins[j],lat_bins[j+1],
                                               closed_right=(j==n_lat_bins-1))
        lon_start,lon_stop,lon_range=_bin_span(lon_coord.values,lon_bins[k],lon_bins[k+1])
        box=(slice(lat_start,lat_stop),slice(lon_start,lon_stop))

        # Non-zero (or NaN) sum: already computed in an earlier run -> skip.
        if np.sum(threshold[box])!=0:
            continue

        # Daily IVT of this box from every year, stacked along time (axis 0).
        ivt_years=[]
        for year in yrs:
            with xr.open_dataset(dir2+'merra2.ivt.dly.model.'+str(year)+'.nc') as ds:
                ivt_years.append(ds['ivt'].sel(lat=lat_range,lon=lon_range)
                                 .transpose('time','lat','lon').values)
        ivt_box=np.concatenate(ivt_years,axis=0)

        threshold_array=np.nanquantile(ivt_box,0.9,axis=0)
        threshold_array=np.clip(threshold_array,50.,300.)
        threshold[box]=threshold_array

        # Checkpoint the whole grid.
        dk=xr.Dataset()
        dk['threshold']=(('lat','lon'),threshold)
        dk.coords['lat']=lat_coord
        dk.coords['lon']=lon_coord
        dk['lat'].attrs["units"]='degrees_north'
        dk['lon'].attrs["units"]='degrees_east'
        dk['threshold'].attrs["units"]='kg*m-1*s-1'

        try:
            os.remove(dir2+'merra2_ivt_threshold_test.nc')
        except OSError:
            pass
        dk.to_netcdf(dir2+'merra2_ivt_threshold_test.nc',mode='w',format='NETCDF4')
        dk.close()

In [None]:
#Plot IVT magnitude threshold

# One-off fix-up of the final threshold file: every value below 50 kg m-1 s-1
# (including cells left at 0, and NaNs) is set to 50 and the file is rewritten.
APPLY_THRESHOLD_FLOOR=False
if APPLY_THRESHOLD_FLOOR:
    # Load fully and close the handle before the same path is deleted/rewritten.
    with xr.open_dataset(dir2+'merra2_ivt_threshold.nc') as ds_floor:
        threshold=ds_floor['threshold'].load()
    threshold=threshold.where(threshold>=50.).fillna(50)
    print(threshold.min().values)
    print(threshold.max().values)

    dk=xr.Dataset()
    dk['threshold']=(('lat','lon'),threshold.values)
    # Use this file's own grid instead of `ds1` left over from another cell.
    dk.coords['lat']=threshold.lat
    dk.coords['lon']=threshold.lon
    dk['lat'].attrs["units"]='degrees_north'
    dk['lon'].attrs["units"]='degrees_east'
    dk['threshold'].attrs["units"]='kg*m-1*s-1'

    threshold=dk['threshold']
    print(threshold.min().values)
    try:
        os.remove(dir2+'merra2_ivt_threshold.nc')
    except OSError:
        pass
    dk.to_netcdf(dir2+'merra2_ivt_threshold.nc',mode='w',format='NETCDF4')
    # Stop execution after rewriting the file.
    sys.exit()

ds=xr.open_dataset(dir2+'merra2_ivt_threshold.nc')
threshold_array=ds['threshold']

# Toggle: global map of the threshold field (the next cell saves `fig`).
PLOT_THRESHOLD=False
if PLOT_THRESHOLD:
    fig=plt.figure(figsize=(22,9))
    ax=plt.subplot(1,1,1,projection=ccrs.PlateCarree())
    cax1=ax.contourf(threshold_array.lon.values,threshold_array.lat.values,threshold_array,levels=np.arange(50,325,25),extend='max',cmap=plt.cm.BrBG,transform=ccrs.PlateCarree())
    cbar=plt.colorbar(cax1,pad=0,fraction=0.046,ticks=np.arange(100,350,50))
    cbar.ax.tick_params(labelsize=35)
    cbar.set_label('kg m$^{-1}$ s$^{-1}$',fontsize=37,labelpad=5)
    # Gridlines are invisible (alpha=0); they are only used for axis labels.
    g1=ax.gridlines(crs=ccrs.PlateCarree(),draw_labels=True,linewidth=1.5,color='gray',alpha=0.0,linestyle='--')

    ax.coastlines(resolution='10m')
    countries = cfeature.NaturalEarthFeature(category='cultural',name='admin_0_boundary_lines_land',scale='50m',facecolor='none')
    ax.add_feature(countries)
    g1.xlabel_style={'size':33,'color':'k'}
    g1.ylabel_style={'size':33,'color':'k'}
    g1.xformatter=LONGITUDE_FORMATTER
    g1.yformatter=LATITUDE_FORMATTER

    g1.top_labels=False
    g1.right_labels=False
    ax.set_xlim(-180,180)
    ax.set_ylim(-90,90)

    ax.set_title(r'$\bf{MERRA-2}$',fontsize=52,pad=5)
    plt.show()

In [None]:
# Save the threshold map drawn by the previous cell. `fig` only exists if that
# cell's plotting branch ran; skip instead of failing with NameError on Run All.
# NOTE(review): if another cell created `fig` earlier, that figure is saved instead.
if 'fig' in globals():
    fig.savefig(dir_data+'ivt_thresholds_paper_1.png')
else:
    print('No figure to save: enable the threshold plot in the previous cell first.')

In [None]:
# Re-open the final threshold grid and pad it by wrapping 90 degrees of
# longitude onto each side (lon>=90 copied to lon-360, lon<-90 copied to
# lon+360), so it lines up with the equally padded IVT fields it is compared
# against in the detection loop.
ds=xr.open_dataset(dir2+'merra2_ivt_threshold.nc')
threshold_array=ds['threshold']

lon_grid=threshold_array.lon.values
lons_east=[lon for lon in lon_grid if lon>=90.]
lons_west=[lon for lon in lon_grid if lon<-90.]

threshold_array_east=threshold_array.sel(lon=lons_east).assign_coords(lon=[lon-360. for lon in lons_east])
threshold_array_west=threshold_array.sel(lon=lons_west).assign_coords(lon=[lon+360. for lon in lons_west])

ivt_threshold=xr.concat([threshold_array_east,threshold_array,threshold_array_west],dim='lon')
print(ivt_threshold)

In [None]:
# Years of 3-hourly IVT scanned by the detection loop, and counters the loop
# increments: count_stop (AR accepted at the base threshold), count_cont (AR
# accepted at a raised threshold), count_fig (figures drawn).
yrs=np.arange(1980,2024)
count_stop,count_cont,count_fig=0,0,0

# Reference AR labels, used for comparison plots inside the loop.
# NOTE(review): only the 1980 label file is opened, yet l1.sel(time=date) is
# evaluated for dates of every year in yrs.
ds=xr.open_dataset(dir2+'merra2.ar.labels.model.variable.1980.nc')
l1=ds['ar_labeled']

for l in range(len(yrs)):
    year=yrs[l]
    
    ds=xr.open_dataset(dir2+'merra2.ivt.3hr.model.'+str(year)+'.nc')
    ivt_mag_current1=ds['ivt'][0,:,:]
    
    lons_east=[x for x in ds.lon.values if x>=90.]
    lons_west=[x for x in ds.lon.values if x<-90.]
    
    ivt_mag_current_east=ivt_mag_current1.sel(lon=lons_east)
    ivt_mag_current_west=ivt_mag_current1.sel(lon=lons_west)
    
    ivt_mag_current_east['lon']=[x-360. for x in lons_east]
    ivt_mag_current_west['lon']=[x+360. for x in lons_west]
    
    ivt_mag_current=xr.concat([ivt_mag_current_east,ivt_mag_current1,ivt_mag_current_west],dim='lon')
    
    lons_extended=ivt_mag_current.lon.values.tolist()

    lats=ds.lat.values.tolist()
    lons=lons_extended
    dates_unique=ds.time.values[0::2]
    
    lat_array=xr.zeros_like(ivt_mag_current[:,:])
    lon_array=xr.zeros_like(ivt_mag_current[:,:])
    for i in range(ivt_mag_current.lon.size):
        lat_array[:,i]=ivt_mag_current.lat.values
    for i in range(ivt_mag_current.lat.size):
        lon_array[i,:]=ivt_mag_current.lon.values
        
    lon_array_list=lon_array.values.ravel()
    lat_array_list=lat_array.values.ravel()
    zipped_latlon=list(zip(lat_array_list,lon_array_list))
    
    #indices=np.random.choice(np.arange(1464),150)
    #indices=[129]
    indices=[1108]
    
    #for d in range(len(dates_unique)):#2161
    for d in indices:
        
        date=pd.to_datetime(dates_unique[d])
        print(date)
        year=date.year
        month=date.month
        day=date.day
        hour=date.hour
        
        ivt_mag_current1=ds['ivt'].sel(time=date)
        
        lons_east=[x for x in ds.lon.values if x>=90.]
        lons_west=[x for x in ds.lon.values if x<-90.]

        ivt_mag_current_east=ivt_mag_current1.sel(lon=lons_east)
        ivt_mag_current_west=ivt_mag_current1.sel(lon=lons_west)

        ivt_mag_current_east['lon']=[x-360. for x in lons_east]
        ivt_mag_current_west['lon']=[x+360. for x in lons_west]

        ivt_mag_current=xr.concat([ivt_mag_current_east,ivt_mag_current1,ivt_mag_current_west],dim='lon')
        
        ivt_extreme=ivt_mag_current.where(ivt_mag_current>=ivt_threshold)
        ivt_ones=ivt_extreme/ivt_extreme
        ivt_ones=ivt_ones.fillna(0)
        
        structure = np.ones((3, 3))          
        labeled,ncomponents=label(ivt_ones,structure)
        labeled_xr=xr.zeros_like(ivt_extreme)
        labeled_xr[:,:]=labeled
        labeled_xr_flat=labeled_xr.values.flatten()
        labeled_xr_original=labeled_xr.copy()

        ar_timestep=xr.zeros_like(labeled_xr)
        
        for i in range(ncomponents+1):
            distance_max=0
            select_component=labeled_xr.where(labeled_xr==i)
            ivt_component=ivt_ones.where(labeled_xr==i)
            ivt_sum=ivt_component.sum(skipna=True).values
            
            if 1==0:
                test=select_component/select_component
                test=test.fillna(0)

                fig=plt.figure(figsize=(22,9))
                ax=plt.subplot(1,1,1,projection=ccrs.PlateCarree())
                cax1=ax.contourf(test.lon,test.lat,test.where(test>0),levels=[0,1],colors='b',transform=ccrs.PlateCarree(),alpha=0.75)
                cbar=plt.colorbar(cax1,pad=0,fraction=0.046)
                cbar.ax.tick_params(labelsize=20) 
                cbar.set_label('kg m$^{-1}$ s$^{-1}$',fontsize=21,labelpad=5)

                g1=ax.gridlines(crs=ccrs.PlateCarree(),draw_labels=True,linewidth=1.5,color='gray',alpha=0.0,linestyle='--')

                ax.coastlines(resolution='10m')
                #ax.add_feature(cfeature.STATES.with_scale('10m'),alpha=0.3)
                #ax.add_feature(cfeature.LAKES.with_scale('50m'))
                countries = cfeature.NaturalEarthFeature(category='cultural',name='admin_0_boundary_lines_land',scale='50m',facecolor='none')
                ax.add_feature(countries)
                g1.xlabel_style={'size':20,'color':'k'}
                g1.ylabel_style={'size':20,'color':'k'}
                g1.xformatter=LONGITUDE_FORMATTER
                g1.yformatter=LATITUDE_FORMATTER
                #g1.xlocator = mticker.FixedLocator([119,120,121,122,123])
                #g1.ylocator = mticker.FixedLocator([21,22,23,24,25,26])

                g1.top_labels=False
                g1.right_labels=False
                ax.set_xlim(-180,180)
                ax.set_ylim(-75,75)

                ax.set_title('AR detection: '+date.strftime('%Y.%m.%d.%H'),fontsize=27,pad=5)
                plt.show()


            
            if ivt_sum>=60.:
                lats_where=lat_array.where(labeled_xr==i)
                lons_where=lon_array.where(labeled_xr==i)
                min_lat=lats_where.min(skipna=True).values
                max_lat=lats_where.max(skipna=True).values
                min_lon=lons_where.min(skipna=True).values
                max_lon=lons_where.max(skipna=True).values
                distance_hyp=geopy.distance.geodesic((min_lat,min_lon),(max_lat,max_lon)).km
                
                if distance_hyp>=2000:#filter out features that could not possibly be long enough
                    choose_lats=lat_array.where(labeled_xr==i)
                    choose_lons=lon_array.where(labeled_xr==i)
                    lats_flat=choose_lats.values.ravel()
                    lons_flat=choose_lons.values.ravel()
                  
                    zipped_ar=list(zip(lats_flat,lons_flat))
                    zipped_ar=[x for x in zipped_ar if x[0]>=-90]

                    hull=ConvexHull(zipped_ar)

                    hull_points=[]
                    for j in range(len(hull.vertices)):
                        hull_points.append(zipped_ar[hull.vertices[j]])

                    distance_list=[]
                    for j in range(len(hull_points)):
                        for k in range(len(hull_points)):
                            point_1=hull_points[j]
                            point_2=hull_points[k]
                            distance=geopy.distance.distance(point_1,point_2).km
                            distance_list.append(distance)

                    length=max(distance_list)
                    
                    area_sum=areas.where(labeled_xr==i).sum(dim=('lat','lon'),skipna=True).values.tolist()
                    width=area_sum/length
                    
                    #Now apply tropical condition
                    
                    component_tropical=select_component.where(lat_array<=20)
                    component_tropical=component_tropical.where(lat_array>=-20)
                    component_tropical_filled=component_tropical.fillna(0)
                    component_nontropical=select_component.where(component_tropical_filled==0)
                    area_tropical=areas.where(component_tropical==i).sum(dim=('lat','lon'),skipna=True).values
                    area_nontropical=areas.where(component_nontropical==i).sum(dim=('lat','lon'),skipna=True).values
                    
                    CONTINUE='yes'
                    
                    if 1==1:
                        if area_tropical/area_nontropical<1./2.5 and length/width>=2.5 and length>=2000.:
                            
                            final_region=labeled_xr.where(labeled_xr==i).fillna(0)
                            final_region=final_region/final_region
                            final_region=final_region.fillna(0)
                            ar_timestep=ar_timestep+final_region
                                
                            count_stop=count_stop+1
                            CONTINUE='no'
                        
                        elif length>=2000.:
                            CONTINUE='yes'
                        else:
                            CONTINUE='no'
                            
                        if area_nontropical==0:
                            CONTINUE='no'
                            
                        count_extra=0
                        higher_threshold=ivt_threshold+50.
                        while higher_threshold.max().values<=550. and CONTINUE=='yes':
                            CONTINUE='no'
                            count_extra=count_extra+1
                            ivt_higher_threshold=ivt_mag_current.where(ivt_mag_current>=higher_threshold)
                            region_specific=ivt_higher_threshold.where(labeled_xr==i)
                            binary=region_specific/region_specific
                            binary=binary.fillna(0)
                            
                            labeled_new=xr.zeros_like(binary)
                            labeled_new_values,n=label(binary,structure)
                            labeled_new[:,:]=labeled_new_values
                            labeled_new_flat=labeled_new.values.flatten()
                            
                            for k1 in range(n+1):
                                distance_max=0
                                select_component=labeled_new.where(labeled_new==k1)
                                ivt_sum=binary.where(labeled_new==k1).sum(skipna=True).values
                                if ivt_sum>=60.:
                                    
                                    lats_where=lat_array.where(labeled_new==k1)
                                    lons_where=lon_array.where(labeled_new==k1)
                                    min_lat=lats_where.min(skipna=True).values
                                    max_lat=lats_where.max(skipna=True).values
                                    min_lon=lons_where.min(skipna=True).values
                                    max_lon=lons_where.max(skipna=True).values
                                    distance_hyp=geopy.distance.geodesic((min_lat,min_lon),(max_lat,max_lon)).km
                                    
                                    if distance_hyp>=2000:#filter out features that could not possibly be long enough
                                    
                                        area_sum=areas.where(labeled_new==k1).sum(dim=('lat','lon'),skipna=True).values.tolist()

                                        choose_lats=lat_array.where(labeled_new==k1)
                                        choose_lons=lon_array.where(labeled_new==k1)
                                        lats_flat=choose_lats.values.ravel()
                                        lons_flat=choose_lons.values.ravel()
                                        
                                        zipped_ar=list(zip(lats_flat,lons_flat))
                                        zipped_ar=[x for x in zipped_ar if x[0]>=-90]

                                        hull=ConvexHull(zipped_ar)

                                        hull_points=[]
                                        for j2 in range(len(hull.vertices)):
                                            hull_points.append(zipped_ar[hull.vertices[j2]])

                                        distance_list=[]
                                        for j2 in range(len(hull_points)):
                                            for k2 in range(len(hull_points)):
                                                point_1=hull_points[j2]
                                                point_2=hull_points[k2]
                                                distance=geopy.distance.distance(point_1,point_2).km
                                                distance_list.append(distance)

                                        length=max(distance_list)
                                        width=area_sum/length
                                        
                                        component_tropical=select_component.where(lat_array<=20)
                                        component_tropical=component_tropical.where(lat_array>=-20)
                                        component_tropical_filled=component_tropical.fillna(0)
                                        component_nontropical=select_component.where(component_tropical_filled==0)
                                        area_tropical=areas.where(component_tropical==k1).sum(dim=('lat','lon'),skipna=True).values
                                        area_nontropical=areas.where(component_nontropical==k1).sum(dim=('lat','lon'),skipna=True).values
                                        
                                        if area_tropical/area_nontropical<1./2.5 and length/width>=2.5 and length>=2000.:
                                            label_ar_term=labeled_new.where(labeled_new==k1)
                                            label_ar_term=label_ar_term/label_ar_term
                                            label_ar_term=label_ar_term.fillna(0)
                                    
                                            count_cont=count_cont+1
                                            print('CONT')
                                            
                                            structure_dilation=np.zeros((3,3))
                                            structure_dilation[1,2]=1
                                            structure_dilation[0,1]=1
                                            structure_dilation[1,0]=1
                                            structure_dilation[2,1]=1
                                            structure_dilation[1,1]=1
                                            region_dilate=binary_dilation(label_ar_term,structure_dilation,iterations=3+2*(count_extra-1))
                                            
                                            final_region=ivt_ones.where(region_dilate==1)
                                            final_region=final_region.where(labeled_xr_original==i).fillna(0)
                                            
                                            area_sum_mod=areas.where(final_region==1).sum(dim=('lat','lon'),skipna=True).values

                                            region_north=labeled_xr.where(labeled_xr==i).where(lat_array>=45.).fillna(0)
                                            region_south=labeled_xr.where(labeled_xr==i).where(lat_array<=-45.).fillna(0)
                                            region_polar=region_north+region_south
                                            
                                            final_region=final_region+region_polar
                                            final_region=final_region/final_region
                                            final_region=final_region.fillna(0)
                                            
                                            ar_timestep=ar_timestep+final_region
                                                
                                            ######remove this portion from the original by modifying labeled_xr
                                            region_subset=select_component/select_component
                                            region_subset=region_subset.fillna(0)
                                            region_subset=region_subset.where(region_subset==0)
                                            labeled_xr=labeled_xr+region_subset
                                            ######
                                            
                                            
                                            
                                            lons_center=[x for x in ivt_mag_current.lon.values if -180<=x<180]
                                            
                                            test=ivt_mag_current#.sel(lon=lons_center)
                                            test=test.where(test>0)
                                            test=test.values

                                            l1_sel=l1.sel(time=date)
                                            
                                            fig=plt.figure(figsize=(22,9))
                                            ax=plt.subplot(1,1,1,projection=ccrs.PlateCarree(central_longitude=180))
                                            cax1=ax.contourf(ivt_mag_current.lon.values,ivt_mag_current.lat.values,test,levels=np.arange(250,1350,100),cmap=plt.cm.Greys,extend='max',transform=ccrs.PlateCarree(),alpha=0.5,label='Identified extreme IVT')
                                            #cbar=plt.colorbar(cax1,pad=0,fraction=0.046)
                                            #cbar.ax.tick_params(labelsize=20) 
                                            #cbar.set_label('kg m$^{-1}$ s$^{-1}$',fontsize=21,labelpad=5)

                                            
                                            test1=final_region
                                            field_final=xr.zeros_like(final_region)
                                            labeled_xr=xr.zeros_like(final_region)
                                            labeled,n=label(test1,structure)
                                            labeled_xr[:,:]=labeled
                                            for x1 in range(1,n+1):
                                                select=labeled_xr.where(labeled_xr==x1)
                                                select=select/select
                                                select=select.fillna(0)
                                                if select.sum().values>=100:
                                                    field_final=field_final+select
                                            
                                            cax2=ax.pcolormesh(l1_sel.lon,l1_sel.lat,l1_sel.where(l1_sel>0),vmin=0,vmax=l1_sel.max().values,cmap=plt.cm.jet,transform=ccrs.PlateCarree(),alpha=0.5,label='Identified AR')
                                            mpl.rcParams['hatch.linewidth']=0.01

                                            ivt_300=ivt_mag_current#.sel(lon=lons_center)
                                            ivt_300=ivt_300.where(ivt_300>=300.)
                                            ivt_300_ones=ivt_300/ivt_300
                                            ivt_300_ones=ivt_300_ones.fillna(0)

                                            
                                            c=ax.contour(ivt_300.lon,ivt_300.lat,ivt_300_ones,[0,1],colors='k',linewidths=4,transform=ccrs.PlateCarree())         
                                            g1=ax.gridlines(crs=ccrs.PlateCarree(),draw_labels=True,linewidth=2.5,color='gray',alpha=0.0,linestyle='--')

                                            
                                            test=label_ar_term/label_ar_term
                                            test=test.fillna(0)
                                            c=ax.contour(test.lon,test.lat,test,[0,1],colors='magenta',linewidths=7,transform=ccrs.PlateCarree())         

                                            
                                            ax.coastlines(resolution='10m')
                                            #ax.add_feature(cfeature.STATES.with_scale('10m'),alpha=0.3)
                                            #ax.add_feature(cfeature.LAKES.with_scale('50m'))
                                            countries = cfeature.NaturalEarthFeature(category='cultural',name='admin_0_boundary_lines_land',scale='50m',facecolor='none')
                                            ax.add_feature(countries)
                                            g1.xlabel_style={'size':25,'color':'k'}
                                            g1.ylabel_style={'size':25,'color':'k'}
                                            g1.xformatter=LONGITUDE_FORMATTER
                                            g1.yformatter=LATITUDE_FORMATTER
                                            #g1.xlocator = mticker.FixedLocator([119,120,121,122,123])
                                            #g1.ylocator = mticker.FixedLocator([21,22,23,24,25,26])

                                            g1.top_labels=False
                                            g1.right_labels=False
                                            
                                            print((min_lon,max_lon))
                                            #if 180 in lons_where.values.flatten().tolist():
                                            #if min_lon<-180 and max_lon>-180 or min_lon<180 and max_lon>180:
                                            #    ax.set_xlim(-180,180)
                                            #    ax.set_ylim(-75,75)
                                            #else:
                                            #    ax.set_xlim(max(min_lon-20,-180),min(max_lon+20,180))
                                            #    ax.set_ylim(max(min_lat-20,-90),min(max_lat+20,90))

                                            ax.set_xlim(-180,180)
                                            ax.set_ylim(-80,80)
                                            
                                            #ax.set_extent([100,-150,0,45])
                                            #ax.set_extent([-165,-65,-20,-65])
                                            #ax.set_extent([-105,-0.1,10,52.5])
                                            #ax.set_extent([125,-90,-60,-4])
                                            ax.set_extent([115,-70,0,70])
                                            
                                            
                                            ax.set_title('MERRA-2 semi-fixed AR detection: '+date.strftime('%Y.%m.%d.%H'),fontsize=36,pad=5)
                                            plt.show()
                                            #sys.exit()
                                            count_fig=count_fig+1
                                            #fig.savefig(dir_data+'method_testing_v9_'+str(d)+'.png')
                                            fig.savefig(dir_data+'method_testing_single_5.png')
                                            sys.exit()
                                            if count_fig==50:
                                                sys.exit()
                                            
                                            
                                            
                                            
                                        elif length>=2000.:
                                            CONTINUE='yes'
                                            

                            higher_threshold=higher_threshold+50.
        
        lons_center=[x for x in ar_timestep.lon.values if -180<=x<180]#90-degree barrier on each side
        ar_timestep=ar_timestep.sel(lon=lons_center)
        ar_timestep=ar_timestep/ar_timestep#can have overlap from binary dilation
        ar_timestep=ar_timestep.fillna(0)

        if d==0:
            ar_points=ar_timestep
        else:
            ar_points=xr.concat([ar_points,ar_timestep],dim='time')

        ivt_timestep=ivt_mag_current.sel(lon=lons_center)
        #Now make figure for testing and tweaking of definition

        if 1==0:
            test=ivt_timestep.where(ivt_timestep>0)
            test=test.values

            fig=plt.figure(figsize=(22,9))
            ax=plt.subplot(1,1,1,projection=ccrs.PlateCarree())
            cax1=ax.contourf(ivt_timestep.lon.values,ivt_timestep.lat.values,test,levels=np.arange(250,1050,100),cmap=plt.cm.Greys,transform=ccrs.PlateCarree(),alpha=0.3,label='Identified extreme IVT',zorder=20)
            cbar=plt.colorbar(cax1,pad=0,fraction=0.046)
            cbar.ax.tick_params(labelsize=20) 
            cbar.set_label('kg m$^{-1}$ s$^{-1}$',fontsize=21,labelpad=5)

            cax2=ax.contourf(ar_points.lon,ar_points.lat,ar_timestep.where(ar_timestep>0),[0,1],colors='b',transform=ccrs.PlateCarree(),hatches=[None,'.'],alpha=0.5,label='Identified AR',zorder=20)
            mpl.rcParams['hatch.linewidth']=0.01

            ivt_300=ivt_timestep.where(ivt_timestep>=300.)
            ivt_300_ones=ivt_300/ivt_300
            ivt_300_ones=ivt_300_ones.fillna(0)

            c=ax.contour(ivt_300.lon,ivt_300.lat,ivt_300_ones,[0,1],linewidths=1.5,transform=ccrs.PlateCarree())         
            g1=ax.gridlines(crs=ccrs.PlateCarree(),draw_labels=True,linewidth=1.5,color='gray',alpha=0.0,linestyle='--')

            ax.coastlines(resolution='10m')
            #ax.add_feature(cfeature.STATES.with_scale('10m'),alpha=0.3)
            #ax.add_feature(cfeature.LAKES.with_scale('50m'))
            countries = cfeature.NaturalEarthFeature(category='cultural',name='admin_0_boundary_lines_land',scale='50m',facecolor='none')
            ax.add_feature(countries)
            g1.xlabel_style={'size':20,'color':'k'}
            g1.ylabel_style={'size':20,'color':'k'}
            g1.xformatter=LONGITUDE_FORMATTER
            g1.yformatter=LATITUDE_FORMATTER
            #g1.xlocator = mticker.FixedLocator([119,120,121,122,123])
            #g1.ylocator = mticker.FixedLocator([21,22,23,24,25,26])

            g1.top_labels=False
            g1.right_labels=False
            ax.set_xlim(-180,180)
            ax.set_ylim(-75,75)

            ax.set_title('AR detection: '+date.strftime('%Y.%m.%d.%H'),fontsize=27,pad=5)
            plt.show()
            #fig.savefig(dir_data+'ar_detection_example_'+str(d)+'.png')
            if d==80:
                sys.exit()
            #fig.savefig(dir1+'ar_test_new_'+str(year)+'_'+str(d)+'.png')
     
    print(count_normal)
    print(count_stop)
    sys.exit()
    if 1==0:
        dk=xr.Dataset()

        dk['ar']=(('time','lat','lon'),ar_points.values)

        dk.coords['time']=ar_points.time
        dk.coords['lat']=ar_points.lat
        dk.coords['lon']=ar_points.lon
        
        dk['ar'].attrs["description"]='1 = AR, 0 = no AR'
        dk['lat'].attrs["units"]='degrees_north'
        dk['lon'].attrs["units"]='degrees_east'
        
        try:
            os.remove(dir2+'merra_2_ar_detection_polar_'+str(year)+'.nc')
        except OSError:
            pass
        dk.to_netcdf(dir2+'merra_2_ar_detection_polar_'+str(year)+'.nc',mode='w',format='NETCDF4')

In [None]:
# Print the diagnostic counters tallied by the AR-detection loop cell.
# NOTE(review): that loop's own end-of-year summary prints `count_normal`, not
# `count_cont`; confirm `count_cont` is still assigned there. Otherwise this cell
# shows a stale value from an earlier kernel session, or raises NameError on a
# fresh kernel.
for counter in (count_cont, count_stop):
    print(counter)

In [None]:
# Inspect one time slice of the labelled AR field for 1980.
# Each integer label gets its own colour from the (cyclic) prism colormap.
ds=xr.open_dataset(dir2+'merra2.ar.labels.model.variable.1980.nc')
l1=ds['ar_labeled']

index=759  # positional index (first dimension) of the time slice to inspect

# Slice once and reuse; the original indexed l1[index,:,:] twice.
snapshot=l1[index,:,:]
snapshot_time=l1.time[index]
print(snapshot_time)

fig,ax=plt.subplots(figsize=(20,10),subplot_kw={'projection':ccrs.PlateCarree()})
# NOTE(review): label 0 is also coloured here. If 0 means "no AR", consider
# plotting snapshot.where(snapshot>0) so the background stays blank.
ax.pcolormesh(snapshot.lon,snapshot.lat,snapshot,vmin=0,vmax=snapshot.max().values,
              cmap=plt.cm.prism,transform=ccrs.PlateCarree())
ax.coastlines(resolution='10m')
# Put the timestamp on the figure itself so it can be identified without the
# printed output; the format matches the other titles in this notebook.
ax.set_title('MERRA-2 AR labels: '+pd.to_datetime(snapshot_time.values).strftime('%Y.%m.%d.%H'),fontsize=20,pad=5)
plt.show()

In [None]:
# Quick-look map of `ivt_mag_current` on plain (unprojected) lon/lat axes.
# The greyscale colour range is fixed to 0-600, so values outside it saturate.
# NOTE(review): `ivt_mag_current` is not created in this cell; it is presumably
# the field left over from the last timestep of the detection loop. That makes
# this cell depend on kernel state.
fig,ax=plt.subplots(figsize=(20,10))
ax.pcolormesh(
    ivt_mag_current.lon.values,
    ivt_mag_current.lat.values,
    ivt_mag_current,
    vmin=0,
    vmax=600,
    cmap=plt.cm.Greys,
)
ax.set_xlim(-180,180)
ax.set_ylim(-80,80)
plt.show()