# Utilities

In [11]:
import numpy as np
import matplotlib.pyplot as plt
from jupyterthemes import jtplot
import matplotlib as mpl
from functions import *
import pandas as pd
pd.options.display.float_format = '{:,.2f}'.format
import cartopy.crs as ccrs
import matplotlib.gridspec as gridspec
import cartopy
import xarray as xr
from shapely.geometry.polygon import LinearRing
from matplotlib import colors
import pickle
# Notebook-wide plotting style: apply the jupyterthemes base theme first,
# then layer the matplotlib overrides on top of it.
jtplot.style(
    context='paper',
    fscale=1.4,
    spines=True,
    grid=False,
    ticks=True,
    gridlines='--',
)

fontsize = 16
hfont = {'fontname': 'Arial'}

mpl.rcParams.update({
    # ticks point inwards and are drawn on all four sides
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.top': True,
    'ytick.right': True,
    # text sizes and figure background
    'font.size': 16,
    'legend.fontsize': 'large',
    'figure.titlesize': 'medium',
    'axes.labelsize': 'x-large',
    'figure.facecolor': 'white',
    # typeface
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial'],
    # LaTeX preamble and the default eight-colour line cycle
    'text.latex.preamble': r'\usepackage{amsmath}',
    'axes.prop_cycle': mpl.cycler(color=['#1b9e77','#d95f02','#7570b3','#e7298a','#66a61e','#e6ab02','#a6761d','#666666']),
})

In [12]:
# Three-letter month abbreviations, January first.
months = 'Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec'.split()

In [13]:
# One (model name, flag) pair per model; the two resulting lists stay index-aligned.
# NOTE(review): what the flags in `lat_keys` switch is not visible in this notebook — confirm.
model_names, lat_keys = map(list, zip(
    ('HadGEM3-GC3-1MM', False),
    ('CanESM5', True),
    ('CESM2', True),
    ('IPSL-CM6A-LR', True),
))

In [14]:
# Per-model variable names, keyed by model name (order follows `model_names`).
tos_keys = {
    name: key
    for name, key in zip(model_names, ('tos', 'tos', 'TS', 'tos'))
}
prcp_keys = {
    name: key
    for name, key in zip(model_names, ('precipitation_flux', 'pr', 'PRECC', 'pr'))
}

In [15]:
# Monthly [start, stop) index windows per model, written as
# (first year, last year, month shift) and converted to month indices.
# NOTE(review): the third model is shifted by 11 months here but by 1 year in the
# yearly windows defined further down — confirm that this is intended.
aidcs = [
    [12 * first + shift, 12 * last + shift]
    for first, last, shift in ((0, 80, 0), (0, 80, 0), (0, 80, 11), (0, 80, 0))
]
cidcs = [
    [12 * first + shift, 12 * last + shift]
    for first, last, shift in ((20, 100, 0), (20, 100, 0), (100, 180, 11), (60, 140, 0))
]

aslices1 = dict(zip(model_names, aidcs))
cslices1 = dict(zip(model_names, cidcs))

# Aliases: every name below refers to the same dict object as *slices1.
aslices = aslicesm = aslices1
cslices = cslicesm = cslices1

In [16]:
# Yearly [start, stop) index windows per model; every window spans 80 entries.
aidcs = [[start, start + 80] for start in (0, 0, 1, 0)]
cidcs = [[start, start + 80] for start in (20, 20, 101, 60)]
aslicesy = {name: window for name, window in zip(model_names, aidcs)}
cslicesy = {name: window for name, window in zip(model_names, cidcs)}

In [17]:
# "Full" index windows: the same [0, -1] pair for each of the four models.
# NOTE(review): if these pairs are used as `x[start:stop]`, a stop of -1 leaves out
# the final element — confirm against the code that consumes them.
aidcs = [[0, -1] for _ in range(4)]
cidcs = [[0, -1] for _ in range(4)]
aslices_full = {name: window for name, window in zip(model_names, aidcs)}
cslices_full = {name: window for name, window in zip(model_names, cidcs)}

In [18]:
# Four-number extent per box name.
# NOTE(review): presumably [lon, lon, lat, lat] map extents — confirm with the plotting code.
extent_boxes = dict(
    am1=[-20, -90, -30, 16],
    am3=[-20, -90, -30, 16],
    wam4=[40, -25, -15, 30],
    ism1=[65, 90, 0, 30],
    easm2=[105, 145, 5, 45],
)

In [19]:
# The same twelve boxes in two longitude conventions.
# `boxes_latlon` uses negative western longitudes; its rows follow the key order of `boxes_dict`.
boxes_latlon = [
    [-75, -42.5, -15, -5],   # am1
    [-65, -45, -15, -5],     # am2
    [-75, -60, 0, 10],       # pam1
    [-75, -50, -10, 0],      # pam2
    [-50, -40, -15, -5],     # pam3
    [-20, 20, 9, 20],        # wam1
    [-20, 20, 0, 10],        # wam2
    [-20, 20, 5, 15],        # wam3
    [-75, -55, -5, 5],       # am3
    [-20, 25, 5, 20],        # wam4
    [70, 85, 5, 25],         # ism1
    [110, 140, 10, 40],      # easm2
]

# `boxes_dict` writes western longitudes as 360 - x.
# NOTE(review): wam1-wam3 give the eastern edge as 360+20 while wam4 gives it as 25 —
# confirm that the consumers handle both forms.
boxes_dict = dict(
    am1=[285, 317.5, -15, -5],
    am2=[295, 315, -15, -5],
    pam1=[360 - 75, 360 - 60, 0, 10],
    pam2=[360 - 75, 360 - 50, -10, 0],
    pam3=[360 - 50, 360 - 40, -15, -5],
    wam1=[360 - 20, 360 + 20, 9, 20],
    wam2=[360 - 20, 360 + 20, 0, 10],
    wam3=[360 - 20, 360 + 20, 5, 15],
    am3=[360 - 75, 360 - 55, -5, 5],
    wam4=[360 - 20, 25, 5, 20],
    ism1=[70, 85, 5, 25],
    easm2=[110, 140, 10, 40],
)

In [20]:
# Four-number extent per region key.
# NOTE(review): this name is bound again further down the notebook with different keys.
extent_latlon = dict(zip(
    ('am', 'af', 'in', 'as'),
    (
        [-20, -85, -30, 30],
        [35, -30, -30, 30],
        [60, 100, 0, 30],
        [100, 150, 15, 50],
    ),
))

In [21]:
# Panel labels '(a)' through '(p)'.
labels = ['({})'.format(letter) for letter in 'abcdefghijklmnop']

In [22]:
# pyplot draws contourf hatches by treating each grid cell as a point and
# connecting neighbouring points to form areas, but what we want is each grid square to be filled
# if it matches/is significant/etc., so we need to shift the grid and fill in the appropriate squares
# NB: this function is not completely generalized, and some details have to be changed for different grids
def get_shifted_stip(data, lats, lons, lat_bnds, lon_bnds, last_lon_edge=360-1.40625):
    """Spread per-cell values onto the grid of cell edges so hatching fills whole cells.

    Every non-NaN cell of `data` is copied onto the edge points that bound it
    (its own row/column and the next one), so that a contourf hatch drawn on the
    returned grid covers the full square of each marked cell.

    Parameters
    ----------
    data : xarray.DataArray
        2-D field indexed by ('latitude', 'longitude') or ('lat', 'lon').
        NaN marks cells that must stay unfilled.
    lats, lons : array-like
        Coordinate labels of the cells to read from `data` (exact matches).
    lat_bnds : array-like
        Latitude coordinate of the output; must have len(lats) entries.
    lon_bnds : array-like
        Longitude edges, one per entry of `lons`; `last_lon_edge` is appended
        to close the last column.
    last_lon_edge : float, optional
        Edge appended after `lon_bnds`. The default is the value that was
        previously hard-coded and is specific to one grid.

    Returns
    -------
    xarray.Dataset
        Variable ``data`` with dims ('lat', 'lon') and shape
        (len(lats), len(lons) + 1).

    Notes
    -----
    Changes relative to the earlier version: the longitude coordinate is built
    from the `lon_bnds` argument instead of the notebook global `Clon_bnds`, and
    cells in the last latitude row are written to their own row rather than to
    row 0 (latitude does not wrap around).
    """
    # Accept either naming convention for the horizontal dimensions.
    if 'latitude' in data.dims and 'longitude' in data.dims:
        lat_name, lon_name = 'latitude', 'longitude'
    else:
        lat_name, lon_name = 'lat', 'lon'

    # One vectorised exact-label lookup instead of one .sel() call per cell.
    cell_values = (
        data.sel({lat_name: np.asarray(lats), lon_name: np.asarray(lons)})
        .transpose(lat_name, lon_name)
        .values
    )
    n_lat, n_lon = cell_values.shape

    shifted_sig = np.full((n_lat, n_lon + 1), np.nan)
    # np.nonzero walks the cells row by row, i.e. in the same order as the
    # nested lat/lon loops it replaces, so later cells still win on shared edges.
    for i, j in zip(*np.nonzero(~np.isnan(cell_values))):
        x = cell_values[i, j]
        shifted_sig[i, j] = x
        shifted_sig[i, j + 1] = x
        if i < n_lat - 1:
            # The last latitude row has no further edge row above it.
            shifted_sig[i + 1, j] = x
            shifted_sig[i + 1, j + 1] = x

    lon_edges = np.append(lon_bnds, last_lon_edge)
    sig = xr.Dataset(
        data_vars=dict(data=(['lat', 'lon'], shifted_sig)),
        coords=dict(lat=('lat', lat_bnds), lon=('lon', lon_edges)),
    )
    return sig

# agreement with HadGEM

## annual mean

In [28]:
# Time-mean precipitation difference (c file minus a file) per model, read from
# the *_Cgrid files.
# NOTE(review): what the "a" and "c" experiments are is not stated in this notebook — confirm.
diff_prcp_Cgrid = {}

for model in model_names:
    var = prcp_keys[model]
    a_start, a_stop = aslicesm[model]
    c_start, c_stop = cslicesm[model]
    # The means are computed inside the `with` block, so both files can be
    # closed as soon as the block ends.
    with xr.open_dataset('/p/tmp/mayayami/NAHosMIP/{}/vars/aprcpm_Cgrid.nc'.format(model)) as a_ds, \
            xr.open_dataset('/p/tmp/mayayami/NAHosMIP/{}/vars/cprcpm_Cgrid.nc'.format(model)) as c_ds:
        adata = a_ds[var][a_start:a_stop]
        cdata = c_ds[var][c_start:c_stop]
        diff_prcp_Cgrid[model] = cdata.mean(dim='time') - adata.mean(dim='time')

# Ensemble mean over every model in `model_names` (no hard-coded model count).
ens_mean_diff = sum(diff_prcp_Cgrid[name] for name in model_names) / len(model_names)

# Sign of the HadGEM change; the other models are compared against it.
hsign = np.sign(diff_prcp_Cgrid['HadGEM3-GC3-1MM'])

  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)


In [25]:
# Lower cell edges of the common grid, read from one model's file.
# The file is chosen explicitly (last entry of `model_names`, the value the loop
# variable `model` holds after the loading loops) so that this cell no longer
# depends on a variable leaked from another cell.
# NOTE(review): assumes all *_Cgrid files share the same grid — confirm.
adata = xr.open_dataset('/p/tmp/mayayami/NAHosMIP/{}/vars/aprcpm_Cgrid.nc'.format(model_names[-1]))
Clon_bnds = adata.lon_bnds.values[:,0]
Clat_bnds = adata.lat_bnds.values[:,0]

  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)


In [18]:
# Keep the HadGEM change only where all three other models show the same sign.
had_diff = diff_prcp_Cgrid['HadGEM3-GC3-1MM']
all_agree = np.sign(diff_prcp_Cgrid['CanESM5']) == hsign
for other in ('CESM2', 'IPSL-CM6A-LR'):
    all_agree = all_agree & (np.sign(diff_prcp_Cgrid[other]) == hsign)
match = had_diff.where(all_agree)

In [20]:
# Move the agreement mask onto the cell-edge grid used for hatching.
sig = get_shifted_stip(
    data=match,
    lats=match.lat,
    lons=match.lon,
    lat_bnds=Clat_bnds,
    lon_bnds=Clon_bnds,
)

In [21]:
# Write the shifted annual-mean agreement mask to disk.
all_match_path = '/p/tmp/mayayami/NAHosMIP/shifted_sigs/all_sign_match.nc'
sig.to_netcdf(all_match_path)

## seasonal

In [23]:
# Seasonal-mean precipitation difference (c file minus a file) per model.
# `season_mean` comes from the star-imported `functions` module; the datasets are
# left open because it is not visible here whether it loads the data eagerly.
seas_prcp_Cgrid = {}

for model in model_names:
    var = prcp_keys[model]
    a_start, a_stop = aslicesm[model]
    c_start, c_stop = cslicesm[model]
    adata = xr.open_dataset('/p/tmp/mayayami/NAHosMIP/{}/vars/aprcpm_Cgrid.nc'.format(model))[var][a_start:a_stop]
    cdata = xr.open_dataset('/p/tmp/mayayami/NAHosMIP/{}/vars/cprcpm_Cgrid.nc'.format(model))[var][c_start:c_stop]

    seas_prcp_Cgrid[model] = season_mean(cdata) - season_mean(adata)

# Ensemble mean over every model in `model_names` (no hard-coded model count).
ens_mean_seas_diff = sum(seas_prcp_Cgrid[name] for name in model_names) / len(model_names)



  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)


In [26]:
# Per-season agreement mask, written to one file per season.
# The reference sign is kept in `seas_hsign` so that the annual-mean `hsign`
# used by other cells is not overwritten by this loop.
# NOTE(review): indexing by position assumes `season_mean` returns the seasons
# in the order DJF, JJA, MAM, SON — confirm against its definition.
for i, season in enumerate(['DJF', 'JJA', 'MAM', 'SON']):
    seas_hsign = np.sign(seas_prcp_Cgrid['HadGEM3-GC3-1MM'][i])

    # Keep the HadGEM change only where all three other models share its sign.
    match = seas_prcp_Cgrid['HadGEM3-GC3-1MM'][i].where(
        (np.sign(seas_prcp_Cgrid['CanESM5'][i]) == seas_hsign)
        & (np.sign(seas_prcp_Cgrid['CESM2'][i]) == seas_hsign)
        & (np.sign(seas_prcp_Cgrid['IPSL-CM6A-LR'][i]) == seas_hsign)
    )
    sig = get_shifted_stip(match, match.lat, match.lon, Clat_bnds, Clon_bnds)
    sig.to_netcdf('/p/tmp/mayayami/NAHosMIP/shifted_sigs/all_sign_match_{}.nc'.format(season))

## at least two of the other three models agree with HadGEM

In [29]:
# Keep the HadGEM annual-mean change where at least two of the three other
# models share its sign.
# The reference sign is recomputed here: when the notebook runs top to bottom,
# the seasonal loop has already rebound `hsign` to a seasonal field.
hsign = np.sign(diff_prcp_Cgrid['HadGEM3-GC3-1MM'])
agrees = {
    name: np.sign(diff_prcp_Cgrid[name]) == hsign
    for name in ('CanESM5', 'CESM2', 'IPSL-CM6A-LR')
}
match2 = diff_prcp_Cgrid['HadGEM3-GC3-1MM'].where(
    (agrees['CanESM5'] & agrees['CESM2'])
    | (agrees['CanESM5'] & agrees['IPSL-CM6A-LR'])
    | (agrees['IPSL-CM6A-LR'] & agrees['CESM2'])
)


In [30]:
# Shift the two-model agreement mask onto the cell-edge grid and save it.
sig2 = get_shifted_stip(
    data=match2,
    lats=match2.lat,
    lons=match2.lon,
    lat_bnds=Clat_bnds,
    lon_bnds=Clon_bnds,
)
two_match_path = '/p/tmp/mayayami/NAHosMIP/shifted_sigs/all_sign_match_two.nc'
sig2.to_netcdf(two_match_path)

# vs 4xCO2

In [133]:
# Yearly [start, stop) window for the 4xCO2 data: the same pair for all four models.
idc = [[60, 140] for _ in range(4)]
co2_slicesy = {name: window for name, window in zip(model_names, idc)}

In [134]:
# Time-mean precipitation difference per model: 4xCO2 yearly-mean file minus the
# a file.
co2_diff_prcp_Cgrid = {}

for model in model_names:
    a_start, a_stop = aslicesm[model]
    co2_start, co2_stop = co2_slicesy[model]
    # The means are computed inside the `with` block, so both files can be
    # closed as soon as the block ends.
    with xr.open_dataset('/p/tmp/mayayami/NAHosMIP/{}/vars/aprcpm_Cgrid.nc'.format(model)) as a_ds, \
            xr.open_dataset('/p/tmp/mayayami/NAHosMIP/{}/4CO2_prcp_Cgrid_ymean.nc'.format(model)) as co2_ds:
        adata = a_ds[prcp_keys[model]][a_start:a_stop]
        # Slice first, then scale, so only the selected window is multiplied.
        # NOTE(review): 86400 is presumably a per-second to per-day conversion — confirm units.
        co2_data = co2_ds.pr[co2_start:co2_stop] * 86400
        co2_diff_prcp_Cgrid[model] = co2_data.mean(dim='time') - adata.mean(dim='time')

# Ensemble mean over every model in `model_names` (no hard-coded model count).
co2_ens_mean_diff = sum(co2_diff_prcp_Cgrid[name] for name in model_names) / len(model_names)

  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)


In [135]:
# Sign of the HadGEM 4xCO2 change, kept only where all three other models agree.
co2_hsign = np.sign(co2_diff_prcp_Cgrid['HadGEM3-GC3-1MM'])
co2_agree = np.sign(co2_diff_prcp_Cgrid['CanESM5']) == co2_hsign
for other in ('CESM2', 'IPSL-CM6A-LR'):
    co2_agree = co2_agree & (np.sign(co2_diff_prcp_Cgrid[other]) == co2_hsign)
co2_match = co2_hsign.where(co2_agree)

In [136]:
# Shift the 4xCO2 agreement mask onto the cell-edge grid and save it.
# NOTE(review): unlike the other masks this file is not written under
# `shifted_sigs/` — confirm that the location is intended.
co2_sig = get_shifted_stip(
    data=co2_match,
    lats=co2_match.lat,
    lons=co2_match.lon,
    lat_bnds=Clat_bnds,
    lon_bnds=Clon_bnds,
)
co2_match_path = '/p/tmp/mayayami/NAHosMIP/all_sign_match_co2.nc'
co2_sig.to_netcdf(co2_match_path)

# dry and wet season agreement

In [22]:
# Four-number extent per monsoon region; this rebinds the `extent_latlon`
# defined near the top of the notebook.
extent_latlon = dict(
    am=[-25, -90, -30, 21],
    wam=[40, -25, -17.5, 32.5],
    ism=[60, 95, 0, 30],
    easm=[100, 150, 5, 45],
)

In [23]:
# Four-number box per monsoon region; western longitudes are written as 360 - x.
region_dict = dict(
    am=[360 - 85, 360 - 30, -20, 11],
    wam=[360 - 20, 35, -10, 25],
    ism=[70, 85, 5, 25],
    easm=[110, 140, 10, 40],
)

In [24]:
# Open the dmap/wmap files for every model and region.
# Result layout: <name>_Cgrid[model][region] -> xarray.Dataset.
# The q4/q5/q6/q7 tags are taken from the file names; their meaning is not visible here.
dmap5_Cgrid, dmap4_Cgrid, wmap7_Cgrid, wmap6_Cgrid = {}, {}, {}, {}

for model in model_names:
    dmap5_Cgrid[model] = {}
    dmap4_Cgrid[model] = {}
    wmap7_Cgrid[model] = {}
    wmap6_Cgrid[model] = {}
    for region in region_dict:
        dmap5_Cgrid[model][region] = xr.open_dataset('/p/tmp/mayayami/NAHosMIP/dry_maps/{}_dmap_{}_q5_Cgrid.nc'.format(model,region))
        dmap4_Cgrid[model][region] = xr.open_dataset('/p/tmp/mayayami/NAHosMIP/dry_maps/{}_dmap_{}_q4_Cgrid.nc'.format(model,region))
        wmap7_Cgrid[model][region] = xr.open_dataset('/p/tmp/mayayami/NAHosMIP/dry_maps/{}_wmap_{}_q7_Cgrid.nc'.format(model,region))
        wmap6_Cgrid[model][region] = xr.open_dataset('/p/tmp/mayayami/NAHosMIP/dry_maps/{}_wmap_{}_q6_Cgrid.nc'.format(model,region))

In [25]:
# Year-mean changes (c minus a) per model and region, restricted to cells where
# the model's mask variable `topo` equals 1.
# Result layout: <name>_diff[model][region] -> xarray.DataArray.
# NOTE(review): the *len / *prcp variables are presumably season length and
# accumulated precipitation, so *prcp / *len would be a mean rate — confirm.
dry_len_diff = {}
dry_prcp_diff = {}
wet_len_diff = {}
wet_prcp_diff = {}

for model in model_names:
    # The mask depends only on the model, so it is opened once per model
    # instead of once per region.
    seamask = xr.open_dataset('/p/tmp/mayayami/NAHosMIP/{}/Cgrid_seamask.nc'.format(model)).topo

    dry_len_diff[model] = {}
    dry_prcp_diff[model] = {}
    wet_len_diff[model] = {}
    wet_prcp_diff[model] = {}

    for region in region_dict:
        dry = dmap4_Cgrid[model][region]
        wet = wmap6_Cgrid[model][region]

        # change in dry-map length
        dry_len_diff[model][region] = (
            dry.cdlen1.mean(dim='year') - dry.adlen.mean(dim='year')
        ).where(seamask==1)

        # change in dry-map precipitation per unit length
        dry_prcp_diff[model][region] = (
            (dry.cdprcp1/dry.cdlen1).mean(dim='year')
            - (dry.adprcp/dry.adlen).mean(dim='year')
        ).where(seamask==1)

        # change in wet-map length
        wet_len_diff[model][region] = (
            wet.cwlen1.mean(dim='year') - wet.awlen.mean(dim='year')
        ).where(seamask==1)

        # change in wet-map precipitation per unit length
        wet_prcp_diff[model][region] = (
            (wet.cwprcp1/wet.cwlen1).mean(dim='year')
            - (wet.awprcp/wet.awlen).mean(dim='year')
        ).where(seamask==1)

In [57]:
def get_shifted_stip2(data, lats, lons, lat_bnds, lon_bnds):
    """Spread per-cell values onto a regional grid of cell edges for hatching.

    Every non-NaN cell of `data` is copied onto its four surrounding edge points,
    so that a contourf hatch drawn on the returned grid covers the full square of
    each marked cell. NaN cells are skipped, as in `get_shifted_stip`; previously
    they were written as well and erased the corners set by marked neighbours.

    Parameters
    ----------
    data : xarray.DataArray
        2-D field indexed by ('latitude', 'longitude') or ('lat', 'lon').
        NaN marks cells that must stay unfilled.
    lats, lons : array-like
        Coordinate labels of the cells to read from `data` (exact matches).
    lat_bnds, lon_bnds : array-like
        Edge coordinates of the output grid. Each needs at least one more entry
        than `lats` / `lons`; otherwise writing a marked cell on the last
        row/column raises IndexError.

    Returns
    -------
    xarray.Dataset
        Variable ``data`` with dims ('lat', 'lon') and shape
        (len(lat_bnds), len(lon_bnds)).
    """
    # Accept either naming convention for the horizontal dimensions.
    if 'latitude' in data.dims and 'longitude' in data.dims:
        lat_name, lon_name = 'latitude', 'longitude'
    else:
        lat_name, lon_name = 'lat', 'lon'

    # One vectorised exact-label lookup instead of one .sel() call per cell.
    cell_values = (
        data.sel({lat_name: np.asarray(lats), lon_name: np.asarray(lons)})
        .transpose(lat_name, lon_name)
        .values
    )

    shifted_sig = np.full((len(lat_bnds), len(lon_bnds)), np.nan)
    # Row-by-row order, matching the nested lat/lon loops this replaces.
    for i, j in zip(*np.nonzero(~np.isnan(cell_values))):
        x = cell_values[i, j]
        shifted_sig[i, j] = x
        shifted_sig[i, j + 1] = x
        shifted_sig[i + 1, j] = x
        shifted_sig[i + 1, j + 1] = x

    sig = xr.Dataset(
        data_vars=dict(data=(['lat', 'lon'], shifted_sig)),
        coords=dict(lat=('lat', lat_bnds), lon=('lon', lon_bnds)),
    )
    return sig

In [101]:

# Agreement masks for the 'wam' region only, one file per quantity
# (iv = 0..3: dry length, dry precipitation, wet length, wet precipitation).
season_diffs = [dry_len_diff, dry_prcp_diff, wet_len_diff, wet_prcp_diff]

for iv, val in enumerate(season_diffs):
    for region in region_dict:
        if region != 'wam':
            continue

        # Sign of the HadGEM change, kept only where all three other models agree.
        hsign = np.sign(val['HadGEM3-GC3-1MM'][region])
        agree = np.sign(val['CanESM5'][region]) == hsign
        for other in ('CESM2', 'IPSL-CM6A-LR'):
            agree = agree & (np.sign(val[other][region]) == hsign)
        match = hsign.where(agree)

        # Edges from one before the first cell to one past the last cell.
        lon_first = int(np.where(Clon_bnds >= match.lon.values[0])[0][0] - 1)
        lon_last = int(np.where(Clon_bnds <= match.lon.values[-1])[0][-1] + 2)
        box_lon_bnds = Clon_bnds[lon_first:lon_last]
        # NOTE(review): the 14 / 7 split and the appended edge are grid-specific
        # magic numbers, presumably stitching the two sides of the 0/360 seam — confirm.
        box_lon_bnds = np.concatenate((box_lon_bnds[:14], box_lon_bnds[-7:], [360-1.40625]))

        lat_first = int(np.where(Clat_bnds >= match.lat.values[0])[0][0] - 1)
        lat_last = int(np.where(Clat_bnds <= match.lat.values[-1])[0][-1] + 2)
        box_lat_bnds = Clat_bnds[lat_first:lat_last]

        sig = get_shifted_stip2(match, match.lat, match.lon, box_lat_bnds, box_lon_bnds)
        sig.to_netcdf('/p/tmp/mayayami/NAHosMIP/seas_signs/all_sign_match_{}_{}.nc'.format(iv,region))