In [None]:
# Standard Python modules
import os, sys
import yaml
from pathlib import Path
import numpy as np
import pandas as pd
import xarray as xr
# matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import AxesGrid
from matplotlib import rcParams
from matplotlib.colors import ListedColormap
import matplotlib.ticker as mticker
# plot styles/formatting
import seaborn as sns
import cmocean.cm as cmo
import cmocean
# cartopy
import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes

# Extra 
from scipy.ndimage import gaussian_filter    # smoothing contour lines
from scipy.stats import linregress
import metpy.calc as mpcalc
from metpy.units import units
from IPython.display import Image, display

# import personal modules

# Path to modules
sys.path.append('../modules')

# Import my modules
from plotter import draw_basemap
from timeseries import persistence, select_months
from teleconnections import build_teleconnection_df
from statistical_tests import lin_regress

In [None]:
# Start a local dask cluster so later chunked xarray computations can run in parallel.
from dask.distributed import Client
# With no scheduler address, Client builds a LocalCluster from the extra keyword
# arguments; the LocalCluster option for the number of workers is `n_workers`
# (`workers` is not a cluster option, so it did not request 8 workers).
client = Client(processes=True, n_workers=8)
client

In [None]:
# Directory layout used throughout the notebook
path_to_data = '/home/nash/DATA/data/'  # project data -- read only
path_to_out = '/home/nash/DATA/repositories/AR_types/out/'  # output files (numerical results, intermediate datafiles) -- read & write
path_to_figs = '/home/nash/DATA/repositories/AR_types/figs/'  # figures

# Render floats in pandas output with a thousands separator and two decimal places
pd.set_option("display.float_format", "{:,.2f}".format)

In [None]:
# Full analysis period passed to the teleconnection loader
start_date, end_date = '1979-03-01', '2018-05-31'

# Daily anomaly ('ANOM') teleconnection table from the project helper
tele = build_teleconnection_df('daily', 'ANOM', start_date, end_date)
tele

In [None]:
# Subset the teleconnection table by season and report how many rows each keeps
# December (12) through February (2)
df_index_djf = select_months(tele, 12, 2)
print('# DJF days: ', len(df_index_djf))

# March (3) through May (5)
df_index_mam = select_months(tele, 3, 5)
print('# MAM days: ', len(df_index_mam))

In [None]:
def ar_daily_df(ssn, nk, out_dir=None):
    '''Load the daily AR-type CSV for one season and expand it into indicator columns.

    Parameters
    ----------
    ssn : str
        Season tag used in the file name (e.g. 'djf' or 'mam').
    nk : int
        Number of AR types; one indicator column ``AR_CAT1`` .. ``AR_CAT<nk>`` is created.
    out_dir : str, optional
        Directory holding the CSV. Defaults to the notebook-level ``path_to_out``.

    Returns
    -------
    pandas.DataFrame
        Indexed by the parsed dates from the CSV's first (unnamed) column, with
        columns ``AR_CAT1`` .. ``AR_CAT<nk>`` (1.0 where ``AR_CAT`` equals that
        type, else 0.0), ``AR_ALL`` (row-wise sum of the indicator columns) and
        ``AR_CAT`` (the original category column).
    '''
    fname_id = 'HUV500t0120050cor'
    base_dir = path_to_out if out_dir is None else out_dir
    filepath = os.path.join(base_dir, fname_id + 'hma_AR-types-' + ssn + '.csv')
    df = pd.read_csv(filepath)

    # set up datetime index (the date column is written without a header)
    df = df.rename(columns={'Unnamed: 0': 'date'})
    df = df.set_index(pd.to_datetime(df.date))

    ## Break up columns into different AR Types
    keys = ["AR_CAT{:1d}".format(k + 1) for k in range(nk)]

    # Build the frame directly from the column list so every type gets a column
    # regardless of how many rows the file has.
    df_cat = pd.DataFrame(0.0, index=df.index, columns=keys)

    ar_cat = df['AR_CAT'].values
    for k, col in enumerate(keys, start=1):
        df_cat[col] = (ar_cat == k).astype(float)

    # get total of all AR types -- summed over however many types were requested,
    # not a fixed three
    df_cat['AR_ALL'] = df_cat[keys].sum(axis=1)
    df_cat['AR_CAT'] = df['AR_CAT']

    return df_cat

In [None]:
# Daily AR-type indicator tables for each season, using three AR types
df_djf = ar_daily_df('djf', 3)
df_mam = ar_daily_df('mam', 3)

In [None]:
# Merge the AR-type tables with the seasonal teleconnection tables.
# Print the row counts first so a length mismatch between the two sources is visible.
print(*(len(frame) for frame in (df_djf, df_index_djf, df_mam, df_index_mam)))

# Index-on-index left join: keeps every AR row and attaches the teleconnection columns
new_djf = df_djf.join(df_index_djf, how='left')
new_mam = df_mam.join(df_index_mam, how='left')

### Linear Regression

In [None]:
# Bounds of the lat/lon box used for subsetting and plotting
lonmin, lonmax = -180, 180
latmin, latmax = 0, 90

def preprocess(ds):
    '''Crop a dataset to the notebook-level latitude/longitude bounds.

    Intended as the ``preprocess`` hook of ``xr.open_mfdataset``: it receives
    one dataset and returns it restricted to the selected box.
    '''
    # The latitude slice runs from latmax down to latmin
    # (presumably the files store latitude north-to-south -- TODO confirm).
    lat_band = slice(latmax, latmin)
    lon_band = slice(lonmin, lonmax)
    return ds.sel(latitude=lat_band, longitude=lon_band)

# Open every ERA5 HUV daily filtered-anomaly file as a single dataset,
# cropping each file to the selected box via `preprocess` as it is read.
filepath_pattern = path_to_data + 'ERA5/huv/anomalies/daily_filtered_anomalies_*.nc'

# `f2` and `ds` are two names for the same opened dataset
ds = f2 = xr.open_mfdataset(filepath_pattern, combine='by_coords', preprocess=preprocess)

In [None]:
# update season for djf or mam plot
ssn = 'mam'

# Season-specific settings: date range to keep, first/last month of the season,
# and the joined AR + teleconnection table to attach to the reanalysis data.
if ssn == 'djf':
    start_date = '1979-12-01'
    end_date = '2018-02-28'
    mon_s = 12
    mon_e = 2
    new_ds = new_djf
elif ssn == 'mam':
    start_date = '1979-03-01'
    end_date = '2018-05-31'
    mon_s = 3
    mon_e = 5
    new_ds = new_mam
else:
    # Fail loudly rather than continuing with settings left over from a previous run
    raise ValueError("ssn must be 'djf' or 'mam', got {!r}".format(ssn))

In [None]:
# Trim date range
# Start from the dataset as opened (`f2`) instead of the current `ds`, so this
# cell can be re-run after changing `ssn` without subsetting an already-trimmed
# (and already-augmented) dataset.
ds = f2.sel(time=slice(start_date, end_date))

# Select months
# A start month later than the end month means the season wraps over the new
# year (e.g. 12 -> 2), so the two conditions are OR-ed instead of AND-ed.
month = ds.time.dt.month
if mon_s > mon_e:
    idx = (month >= mon_s) | (month <= mon_e)
else:
    idx = (month >= mon_s) & (month <= mon_e)

ds = ds.sel(time=idx)


# Combine AR Cat data w/ reanalysis data
# Add ar time series to the ERA dataset
cols = ['AR_CAT', 'AO', 'PDO', 'ENSO', 'SH']

# The columns are attached by position, not by date, so the table must have
# exactly one row per remaining time step.
# NOTE(review): equal length does not prove the dates line up -- confirm that
# `new_ds` and `ds.time` cover the same days in the same order.
if len(new_ds) != ds.time.size:
    raise ValueError(
        "AR/teleconnection table has {} rows but the dataset has {} time steps".format(
            len(new_ds), ds.time.size))

for col in cols:
    ds[col] = ('time', new_ds[col].values)

# AR_CAT becomes a coordinate so it can be used to select time steps by AR type
ds = ds.set_coords('AR_CAT')
ds

In [None]:
def new_linregress(x, y):
    '''Run scipy.stats.linregress on (x, y) and pack the result into one array.

    Returns a length-5 NumPy array ordered as
    [slope, intercept, r-value, p-value, standard error], which is the flat
    form needed when the function is used through xr.apply_ufunc.
    '''
    fit = linregress(x, y)
    return np.array([fit.slope, fit.intercept, fit.rvalue, fit.pvalue, fit.stderr])

def lin_regress(ds, x, y):
    '''Apply `new_linregress` along 'time' for the variables `x` and `y` of `ds`.

    `ds[x]` is the regressor and `ds[y]` the response. The 'time' dimension of
    both inputs is consumed and replaced by a new 'parameter' dimension of
    length 5 holding slope, intercept, r-value, p-value and standard error
    (in that order).

    Note: this definition rebinds the name `lin_regress`, which the import cell
    also brings in from `statistical_tests`.
    '''
    regressor = ds[x]
    response = ds[y]
    # NOTE(review): newer xarray releases expect `output_sizes` inside
    # `dask_gufunc_kwargs` -- confirm against the installed version.
    fitted = xr.apply_ufunc(
        new_linregress,
        regressor,
        response,
        input_core_dims=[['time'], ['time']],
        output_core_dims=[["parameter"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=['float64'],
        output_sizes={"parameter": 5},
    )
    return fitted


In [None]:
%%time
# run linear regression for each AR Type and each teleconnection
cols = ['AO', 'PDO', 'ENSO', 'SH']
rval = []
pval = []
slope = []

# Results are appended AR-type-major, teleconnection-minor:
# slope[0..3] belong to AR type 1, slope[4..7] to type 2, slope[8..11] to type 3.
# Only the slope (parameter=0) is collected; r-value (parameter=2) and p-value
# (parameter=3) are available on `results` but `rval`/`pval` are left empty.
# NOTE(review): linregress(x, y) fits y on x, so with x='z' this slope is the
# change in the index per unit of z -- confirm that is the intended direction,
# given the colorbar in the figure cell is labelled 'm'.
for artype in np.arange(1, 4):
    idx = (ds['AR_CAT'] == artype)
    # keep 'time' in a single chunk: it is the core dimension of the regression
    data = ds.sel(time=idx).chunk({'time': -1})
    # the loop variable is deliberately not called `tele`, so the
    # teleconnection DataFrame of that name is not overwritten
    for tele_name in cols:
        results = lin_regress(data, 'z', tele_name)
        slope.append(results.isel(parameter=0).values)

In [None]:
# Set seaborn plot style
sns.set()
sns.set_style("ticks", {'patch.force_edgecolor':False})

# Set up projection
mapcrs = ccrs.NorthPolarStereo()
datacrs = ccrs.PlateCarree()

# Set tick/grid locations
dx = np.arange(lonmin,lonmax+20,20)
dy = np.arange(latmin,latmax+20,20)

# lat/lon arrays
# Read from `ds` rather than `data`: `data` is only a time-subset of `ds` (same
# grid), and the plotting cell rebinds `data` to a NumPy array, which would make
# this cell fail when re-run.
lats = ds.latitude.values
lons = ds.longitude.values

In [None]:
# Figure: regression slope maps, one row per AR type and one column per teleconnection
filepath = path_to_figs + 'composites/teleconnections/regress_teleconnections_slope_' + ssn + '.png'
nrows = 3
ncols = 4

cols = ['ao', 'pdo', 'enso', 'sh']
# Column titles only on the first row; row labels only on the first column
plt_lbls = ['AO', 'PDO', 'ENSO', 'SH']+['']*8
row_lbls = ['AR Type 1', '', '', '',
           'AR Type 2', '', '', '',
           'AR Type 3', '', '', '']

cmap = cmo.balance

# Contour levels, built once for all panels. linspace includes both endpoints,
# so the levels are symmetric about zero (-0.1 ... 0.1 in steps of 0.01) and the
# diverging colormap stays centred; np.arange(-0.1, 0.1, 0.01) stops short of +0.1.
cflevs = np.linspace(-0.1, 0.1, 21)

# Create figure
fig = plt.figure(figsize=(10,15))

# Set up Axes Grid
axes_class = (GeoAxes,dict(map_projection=mapcrs))
axgr = AxesGrid(fig, 111, axes_class=axes_class,
                nrows_ncols=(nrows, ncols), axes_pad = 0.2,
                cbar_location='bottom', cbar_mode='single',
                cbar_pad=0.10, cbar_size='2%',label_mode='',
                direction='row')

# The grid is filled row by row, matching the order in which `slope` was built
# (AR type outer loop, teleconnection inner loop).
for k, ax in enumerate(axgr):
    # name kept as `data`: the following cell prints data.min()/data.max()
    data = slope[k]
    ax = draw_basemap(ax, extent=[lonmin,lonmax,latmin,latmax], grid=True)
    # Contour Filled
    cf = ax.contourf(lons, lats, data, transform=datacrs,
                     levels=cflevs, cmap=cmap, alpha=0.9, extend='both')
    ax.set_title(plt_lbls[k], fontsize=13)
    # Row labels
    ax.text(-0.07, 0.55, row_lbls[k], va='bottom', ha='center',
        rotation='vertical', rotation_mode='anchor', fontsize=13,
        transform=ax.transAxes)

# Colorbar (single) -- all panels share the same levels, so the last contour set is representative
cb = fig.colorbar(cf, axgr.cbar_axes[0], orientation='horizontal', drawedges=True, extend='both', spacing='uniform')
cb.set_label('m')

# Save figure (create the target directory first so savefig cannot fail on a missing folder)
os.makedirs(os.path.dirname(filepath), exist_ok=True)
plt.savefig(filepath, dpi=150, bbox_inches='tight')

# Show
plt.show()

In [None]:
# Range of the last slope field that was plotted (AR type 3 / last teleconnection).
# Read it from `slope` directly instead of relying on the `data` variable left
# over from the plotting loop.
last_slope = slope[-1]
print(last_slope.min(), last_slope.max())