# Regression

In [None]:
from dask_jobqueue import PBSCluster
from dask.distributed import Client

In [None]:
# One node on Gadi has 48 cores - try and use up a full node before going to multiple nodes (jobs)
# Resources requested for each PBS job. The same values are repeated as raw PBS directives in
# job_extra; header_skip drops generated header lines containing "select"
# (NOTE(review): presumably because Gadi does not accept "-l select" - confirm).

walltime = '00:10:00'
cores = 2
memory = '8GB'

cluster = PBSCluster(walltime=str(walltime), cores=cores, memory=str(memory),
                     job_extra=['-l ncpus='+str(cores),
                                '-l mem='+str(memory),
                                '-P xv83',
                                '-l storage=gdata/xv83+gdata/rt52+scratch/xv83'],
                     header_skip=["select"])

In [None]:
# Request one PBS job for workers and connect a distributed client to the cluster.
cluster.scale(jobs=1)
client = Client(cluster)

In [None]:
# Display the client summary (includes the dashboard link).
client

In [None]:
# Automatically reload edited modules (e.g. functions.py) before running each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import regionmask
import copy
from scipy import stats
from collections import OrderedDict
import xskillscore as xs

import statsmodels.api as sm
import statsmodels.formula.api as smf

import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import cartopy.crs as ccrs
import cartopy
cartopy.config['pre_existing_data_dir'] = '/g/data/xv83/dr6273/work/data/cartopy-data'
cartopy.config['data_dir'] = '/g/data/xv83/dr6273/work/data/cartopy-data'

import functions as fn

In [None]:
# Matplotlib rcParams used with plt.rc_context in the plotting functions below.
plt_params = fn.get_plot_params()

In [None]:
# default colours
# Default matplotlib colour cycle as a list of colour strings.
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

# Load coffee data

In [None]:
# Order abbrevs and names by species and production
# Mapping season_id -> display name: its keys select/order season_ids and its values label panels.
country_order = fn.get_country_order()

In [None]:
growing_calendar = pd.read_csv('/g/data/xv83/dr6273/work/projects/coffee/data/coffee_country_growing_calendar_extended.csv',
                               index_col=0)
growing_calendar.head()

In [None]:
# Unique country abbreviations that grow each species (from the 'species' and 'abbrevs' columns).
arabica_abbrevs = np.unique(growing_calendar.loc[(growing_calendar.species == 'Arabica'), 'abbrevs'])
robusta_abbrevs = np.unique(growing_calendar.loc[(growing_calendar.species == 'Robusta'), 'abbrevs'])

# Gridded climate data relevant for each phase of coffee (growing and flowering)

In [None]:
# Detrended climate-risk data, lazily opened from consolidated zarr stores. File names encode the
# variable, season phase and tail (NOTE(review): '1_std' presumably means a 1-standard-deviation
# threshold - confirm against the preprocessing scripts).
vpd_flowering = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/era5_vpd_detrended_Flowering_upper_tail_1_std.zarr',
                             consolidated=True)
vpd_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/era5_vpd_detrended_Growing_upper_tail_1_std.zarr',
                              consolidated=True)

In [None]:
# Tmin: lower tail in flowering, upper tail in growing season.
mn2t_flowering = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/berkeley_tmin_detrended_Flowering_lower_tail_1_std.zarr',
                             consolidated=True)
mn2t_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/berkeley_tmin_detrended_Growing_upper_tail_1_std.zarr',
                              consolidated=True)

In [None]:
mx2t_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/berkeley_tmax_detrended_Growing_upper_tail_1_std.zarr',
                                  consolidated=True)

In [None]:
t2m_lt_growing_optimal = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/berkeley_temperature_detrended_Growing_lower_tail_1_std.zarr',
                                             consolidated=True)

In [None]:
t2m_ut_growing_optimal = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/berkeley_temperature_detrended_Growing_upper_tail_1_std.zarr',
                                             consolidated=True)

In [None]:
# NOTE(review): despite the "_growing_" variable names, the precipitation files loaded here are the
# 'Annual' ones (see the markdown note about 12-month rainfall below).
tp_lt_growing_optimal = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/gpcc_precip_detrended_Annual_lower_tail_1_std.zarr',
                                             consolidated=True)

In [None]:
tp_ut_growing_optimal = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/gpcc_precip_detrended_Annual_upper_tail_1_std.zarr',
                                             consolidated=True)

### Proportion of each country, and global coffee area, in drought each year

In [None]:
# Name of the grid each dataset is on; passed to fn.calculate_event_statistics with the data
# (NOTE(review): presumably used to pick matching country masks/weights - confirm in functions.py).
vpd_grid_template = 'era5'
temperature_grid_template = 'berkeley'
precip_grid_template = 'gpcc'

### VPD events

In [None]:
# .compute() evaluates the lazy result and loads it into memory.
vpd_flowering_events = fn.calculate_event_statistics(vpd_flowering, vpd_grid_template).compute()

In [None]:
vpd_growing_events = fn.calculate_event_statistics(vpd_growing, vpd_grid_template).compute()

### Tmin averages events

In [None]:
mn2t_flowering_events = fn.calculate_event_statistics(mn2t_flowering, temperature_grid_template).compute()

In [None]:
mn2t_growing_events = fn.calculate_event_statistics(mn2t_growing, temperature_grid_template).compute()

### Tmax averages events

In [None]:
mx2t_growing_events = fn.calculate_event_statistics(mx2t_growing, temperature_grid_template).compute()

### T ranges events

In [None]:
t2m_lt_growing_optimal_events = fn.calculate_event_statistics(t2m_lt_growing_optimal, temperature_grid_template).compute()

In [None]:
t2m_ut_growing_optimal_events = fn.calculate_event_statistics(t2m_ut_growing_optimal, temperature_grid_template).compute()

### Precip ranges events

In [None]:
tp_lt_growing_optimal_events = fn.calculate_event_statistics(tp_lt_growing_optimal, precip_grid_template).compute()

In [None]:
tp_ut_growing_optimal_events = fn.calculate_event_statistics(tp_ut_growing_optimal, precip_grid_template).compute()

# Load mode data

- Use only the growing season, as the flowering season has only one climate risk.
    - This means the comparison with 12-month (annual) rainfall is not quite like-for-like.

In [None]:
# SST product whose climate-mode index files are loaded below (used as the file-name prefix).
sst_dataset = 'hadisst'

### Nino3.4

In [None]:
# Each index is opened and loaded into memory immediately with .compute().
nino34_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/'+sst_dataset+'_nino34_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

### DMI

In [None]:
dmi_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/'+sst_dataset+'_dmi_detrended_Growing_both_tails_1_std.zarr',
                            consolidated=True).compute()

### Atlantic Nino

In [None]:
atl_nino_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/'+sst_dataset+'_atl_nino_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

### TNA

In [None]:
tna_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/'+sst_dataset+'_tna_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

### TSA

In [None]:
tsa_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/'+sst_dataset+'_tsa_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

### MJO

In [None]:
# MJO days per month, with a phase_ID dimension (phases 1-8 are selected when concatenating below).
mjo_days_per_month_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/era5_mjo_days_per_month_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

### Concatenate

In [None]:
# Stack every climate-mode index along a new 'mode' dimension. Each MJO phase (1-8) becomes its own
# mode; the scalar phase_ID coordinate left behind by .sel() is removed with drop_vars (the
# non-deprecated replacement for DataArray.drop) so all pieces align on the same coordinates.
modes_concat = xr.concat(
    [
        nino34_growing.nino34_detrended.expand_dims({'mode': ['nino34']}),
        dmi_growing.dmi_detrended.expand_dims({'mode': ['dmi']}),
        atl_nino_growing.atl_nino_detrended.expand_dims({'mode': ['atl_nino']}),
        tna_growing.tna_detrended.expand_dims({'mode': ['tna']}),
        tsa_growing.tsa_detrended.expand_dims({'mode': ['tsa']}),
    ]
    + [
        mjo_days_per_month_growing.mjo_days_per_month_detrended
        .sel(phase_ID=phase)
        .expand_dims({'mode': [f'mjo_dpm_p{phase}']})
        .drop_vars('phase_ID')
        for phase in range(1, 9)
    ],
    'mode',
)

# Restrict to the common analysis period.
modes_concat = modes_concat.sel(time=slice('1980', '2020'))

### Mode names for plots

In [None]:
# Plot labels for the 'mode' entries of modes_concat, in the same order: the five SST-based
# indices followed by MJO phases 1-8 (matplotlib mathtext strings).
mode_names = ['ENSO', 'IOD', r'$\mathrm{Atl. Ni\tilde{n}o}$', 'TNA', 'TSA'] + [
    r'$\mathrm{MJO}_{' + str(phase) + '}$' for phase in range(1, 9)
]

# Correlation between modes

In [None]:
def mode_cor():
    """
    Compute the Spearman rank correlation over time between every pair of climate modes in the
    global ``modes_concat``, separately for each season_id.

    Returns a DataArray with dims ('season_id', 'mode1', 'mode2'); each season's matrix is full
    (both triangles and the diagonal are computed).
    """
    mode_labels = modes_concat['mode'].values

    def season_matrix(season):
        # One full mode-by-mode correlation matrix for a single season_id.
        size = len(mode_labels)
        matrix = np.full((size, size), np.nan)
        for row, first in enumerate(mode_labels):
            series_a = modes_concat.sel(season_id=season, mode=first)
            for col, second in enumerate(mode_labels):
                series_b = modes_concat.sel(season_id=season, mode=second)
                matrix[row, col] = xs.spearman_r(series_a, series_b, dim='time').values
        return xr.DataArray(matrix, dims=['mode1', 'mode2'],
                            coords={'mode1': mode_labels, 'mode2': mode_labels})

    per_season = [season_matrix(season).expand_dims({'season_id': [season]})
                  for season in modes_concat.season_id.values]
    return xr.concat(per_season, dim='season_id')

In [None]:
# Spearman correlation matrices between modes, one per season_id.
cor_da = mode_cor()

### Set correlation threshold for figure and regression

In [None]:
# Absolute correlation above which two modes count as correlated: such pairs are marked in the
# single-season correlation figure and are never used together in one regression model.
correlation_threshold = 0.7

In [None]:
def plot_cor(da_list, season_name, names, cor_thresh, save_fig, filename):
    """
    Plot a mode-by-mode correlation matrix.

    Only the strict lower triangle of ``da_list[0]`` is drawn (upper triangle and diagonal are
    masked), on a PiYG scale from -1 to 1. The min/max of the plotted values are printed.

    da_list : list of one or two 2-D DataArrays with a 'mode1' dimension
        - one entry: cells whose absolute value exceeds ``cor_thresh`` are marked with black dots
        - two entries: the second fills the strict upper triangle with its own colour bar
          (labelled as a standard deviation), and the first colour bar is relabelled as a mean
    season_name : axes title
    names : tick labels for the modes
    cor_thresh : absolute threshold for the dot markers (used only with one entry)
    save_fig : if True, save the figure to ./figures/<filename> as a PDF
    filename : output file name
    """
    first = da_list[0]
    size = len(first.mode1)

    lower = first.copy(deep=True).values
    lower[np.triu_indices(lower.shape[0], 0)] = np.nan

    with plt.rc_context(plt_params):
        fig, ax = plt.subplots(1, 1, figsize=(6.9, 6.9))

        mesh = ax.pcolormesh(lower, cmap='PiYG', vmin=-1, vmax=1)
        print(np.nanmin(lower), np.nanmax(lower))

        # A single matrix: dot every cell whose |correlation| is above the threshold
        if len(da_list) == 1:
            rows, cols = np.where(abs(lower) > cor_thresh)
            ax.scatter(cols + .5, rows + .5, fc='k', ec='k', s=30)

        tick_pos = np.arange(size) + 0.5

        ax.set_xlim(0, size)
        ax.set_xticks(tick_pos)
        ax.set_xticklabels(names, rotation=30)

        ax.set_ylim(0, size)
        ax.set_yticks(tick_pos)
        ax.set_yticklabels(names)

        ax.set_title(season_name)

        cbar_ax = fig.add_axes([0.92, 0.126, 0.02, 0.753])
        cbar = fig.colorbar(mesh, cax=cbar_ax, orientation='vertical', ticks=np.arange(-1, 1.01, 0.2))
        cbar.ax.set_ylabel('Correlation [-]', rotation=270, va='bottom')

        if len(da_list) == 2:
            upper = da_list[1].copy(deep=True).values
            upper[np.tril_indices(upper.shape[0], 0)] = np.nan
            print(np.nanmin(upper), np.nanmax(upper))

            mesh2 = ax.pcolormesh(upper, cmap='inferno', vmin=0, vmax=.35)

            # Diagonal separating the two triangles
            ax.plot((0, 1), (0, 1), c='k', transform=ax.transAxes)

            cbar_ax2 = fig.add_axes([1.05, 0.126, 0.02, 0.753])
            cbar2 = fig.colorbar(mesh2, cax=cbar_ax2, orientation='vertical', ticks=np.arange(0, 1.01, 0.05))
            cbar2.ax.set_ylabel('Standard deviation of correlation [-]', rotation=270, va='bottom')
            cbar.ax.set_ylabel('Mean of correlation [-]', rotation=270, va='bottom')

        if save_fig:
            plt.savefig('./figures/' + filename, format='pdf', dpi=400, bbox_inches='tight')

In [None]:
# Single season (Brazil S): lower-triangle correlations with dots where |r| > threshold; not saved.
plot_cor([cor_da.sel(season_id='BRS_0')], 'Brazil S growing season', mode_names,
         correlation_threshold, save_fig=False, filename='mode_cor_BRS.pdf')

In [None]:
# Mean (lower triangle) and standard deviation (upper triangle) of the correlations across seasons.
plot_cor([cor_da.mean('season_id'),
          cor_da.std('season_id')],
         'Mean and standard deviation over all growing seasons',
         mode_names, cor_thresh=np.nan,
        save_fig=True, filename='mode_cor_all.pdf')

In [None]:
# Range (over seasons) of the ENSO correlation with MJO phases 1, 4 and 8.
print(cor_da.sel(mode1='nino34', mode2=['mjo_dpm_p1', 'mjo_dpm_p4', 'mjo_dpm_p8']).min().values,
      cor_da.sel(mode1='nino34', mode2=['mjo_dpm_p1', 'mjo_dpm_p4', 'mjo_dpm_p8']).max().values)

In [None]:
# NOTE(review): prints the minimum of |r| but the maximum of the signed r - confirm whether abs()
# was also intended for the maximum.
print(abs(cor_da.sel(mode1='nino34', mode2=['mjo_dpm_p1', 'mjo_dpm_p4', 'mjo_dpm_p8'])).min().values,
      cor_da.sel(mode1='nino34', mode2=['mjo_dpm_p1', 'mjo_dpm_p4', 'mjo_dpm_p8']).max().values)

In [None]:
# Range (over seasons) of the ENSO correlation with MJO phases 3, 5 and 7.
print(cor_da.sel(mode1='nino34', mode2=['mjo_dpm_p3', 'mjo_dpm_p5', 'mjo_dpm_p7']).min().values,
      cor_da.sel(mode1='nino34', mode2=['mjo_dpm_p3', 'mjo_dpm_p5', 'mjo_dpm_p7']).max().values)

In [None]:
# NOTE(review): as above, min of |r| but max of signed r.
print(abs(cor_da.sel(mode1='nino34', mode2=['mjo_dpm_p3', 'mjo_dpm_p5', 'mjo_dpm_p7'])).min().values,
      cor_da.sel(mode1='nino34', mode2=['mjo_dpm_p3', 'mjo_dpm_p5', 'mjo_dpm_p7']).max().values)

# Generalised linear models

#### Regression to fit to n events

In [None]:
def AIC(k, llh):
    """
    Return the Akaike Information Criterion, AIC = 2k - 2 ln(L).

    k : number of estimated parameters
    llh : maximised log-likelihood of the model
    """
    penalty = 2 * k
    return penalty - 2 * llh

In [None]:
def AICc(AIC, n, k):
    """
    Return the small-sample corrected Akaike Information Criterion,
    AICc = AIC + (2k^2 + 2k) / (n - k - 1).

    AIC : AIC score of a fitted model
    n : sample size
    k : number of parameters
    """
    correction = (2 * k**2 + 2 * k) / (n - k - 1)
    return AIC + correction

In [None]:
def plot_ts_hist(da, save_fig=False, filename='events_time_series.pdf'):
    """
    Plot each season's time series of signed event counts, with an inset histogram of its values.

    Panels are laid out on a 4 x 4 grid in the order of ``da.season_id`` (the last panel is switched
    off) and labelled with the value of ``country_order`` at the same position.

    da : DataArray with 'season_id' and 'time' dimensions
    save_fig : if True, save the figure to ./figures/<filename> as a PDF
    filename : output file name
    """
    left_column = [0, 4, 8, 12]

    with plt.rc_context(plt_params):
        fig, ax = plt.subplots(4, 4, figsize=(6.9, 8), dpi=100)
        panels = ax.flatten()

        for idx, season in enumerate(da.season_id.values):
            panel = panels[idx]
            series = da.sel(season_id=season)

            series.plot(ax=panel)
            panel.set_ylim(-5.05, 5.05)
            panel.set_yticks(np.arange(-5, 5.01, 1))

            # Clear the title set by .plot() and label the panel with the country name instead
            panel.set_title('')
            panel.text(0.05, 0.93, list(country_order.values())[idx], transform=panel.transAxes)

            # Year labels only on bottom-most panels (index 11 sits above the blank last panel)
            decades = da.time.values[::10]
            if idx >= 11:
                panel.set_xlabel('Year')
                panel.set_xticks(decades)
                panel.set_xticklabels(da.time.dt.year.values[::10], rotation=30, ha='center')
            else:
                panel.set_xlabel('')
                panel.set_xticks(decades)
                panel.set_xticklabels([])

            if idx in left_column:
                panel.set_ylabel(r'$n$ signed events')
            else:
                panel.set_ylabel('')
                panel.set_yticklabels([])

            # Inset histogram in the lower-left corner of the panel
            inset = inset_axes(panel, width="50%", height="19%", bbox_to_anchor=(.0, .08, .8, .9),
                               bbox_transform=panel.transAxes, loc=3, borderpad=1)
            inset.hist(series.values, color='dimgray', bins=10)

            inset.set_xlim(-4, 4)
            inset.set_xticks(np.arange(-4, 4.01, 2))
            inset.set_xticklabels(range(-4, 5, 2))

            inset.set_ylim(0, 25)
            inset.set_yticks([0, 25])
            inset.set_yticklabels([0, 25])
            inset.yaxis.tick_right()

            inset.tick_params(direction='out', length=2)

        ax[-1, -1].axis('off')

        plt.subplots_adjust(wspace=0.08, hspace=0.08)

        if save_fig:
            plt.savefig('./figures/' + filename, format='pdf', dpi=400, bbox_inches='tight')

In [None]:
def subset_predictors(predictors, correlation_da, correlation_thresh, mode1_name='mode1', mode2_name='mode2'):
    """
    Return every combination of at least two predictors in which no pair of predictors has an
    absolute correlation greater than ``correlation_thresh``.

    predictors : iterable of predictor (mode) names. Combinations follow its iteration order, so
        pass an ordered sequence for reproducible output.
    correlation_da : 2-D DataArray of pairwise correlations with dims ``mode1_name`` and
        ``mode2_name`` whose coordinate values are predictor names. Only pairs listed in it are
        checked.
    correlation_thresh : absolute correlation above which two predictors may not appear together
    mode1_name, mode2_name : names of the two correlation dimensions

    Returns a list of tuples of predictor names, ordered by size and then by combination order.
    """
    from itertools import combinations

    pool = list(predictors)
    # Only consider sets with at least two predictors (the empty set and singletons are skipped)
    candidates = [combo for r in range(2, len(pool) + 1) for combo in combinations(pool, r)]

    # A set keeps the final filter O(1) per candidate (a list here made it quadratic)
    discard = set()
    for mode in correlation_da[mode1_name].values:
        row = correlation_da.sel({mode1_name: mode})
        above = (abs(row) > correlation_thresh).values  # NaN correlations compare False
        # Modes correlated with `mode`, excluding `mode` itself. A set difference (rather than
        # locating the diagonal) also copes with a NaN self-correlation, e.g. a constant series.
        correlated = set(row[mode2_name].values[above]) - {mode}
        if not correlated:
            continue
        for combo in candidates:
            if mode in combo and correlated.intersection(combo):
                discard.add(combo)

    return [combo for combo in candidates if combo not in discard]

In [None]:
def glm_selection_uncorrelated(data, response, family, correlation_da, correlation_thresh, mode1_name='mode1', mode2_name='mode2'):
    """
    Select a generalised linear model by AICc over all admissible predictor combinations.

    Every combination of at least two predictors is fitted, provided no pair within it has an
    absolute correlation above ``correlation_thresh`` (see subset_predictors). The combination with
    the lowest AICc is refitted and returned.

    data : pandas DataFrame containing the response and the predictors (all other columns)
    response : name of the response column
    family : statsmodels family and link, e.g. sm.families.<Family>(link=sm.families.links.<Link>())
    correlation_da : xarray DataArray of the predictors' pairwise correlations
    correlation_thresh : absolute correlation above which two predictors are not used together
    mode1_name, mode2_name : names of the correlation dimensions in correlation_da

    Returns (model, all_results): the fitted statsmodels GLM results for the best formula, and an
    OrderedDict mapping each AICc score to its formula (identical scores collapse into one entry).

    Raises KeyError if ``response`` is not a column of ``data``, and ValueError if there is no
    admissible combination or every score is NaN.
    """
    if response not in data.columns:
        raise KeyError(response)

    n = data.shape[0]

    # Keep column order: iterating a set of strings varies between runs (hash randomisation),
    # which made the combination and formula order non-reproducible.
    predictors = [col for col in data.columns if col != response]

    all_combs = subset_predictors(predictors, correlation_da, correlation_thresh, mode1_name=mode1_name, mode2_name=mode2_name)
    n_combs = len(all_combs)
    print(n_combs, 'combinations\n')
    if n_combs == 0:
        raise ValueError('No admissible predictor combinations to fit.')

    formulas = []
    scores = np.full(n_combs, np.nan)

    for i, com in enumerate(all_combs):
        formula = "{} ~ {}".format(response, ' + '.join(com))
        formulas.append(formula)

        aic = smf.glm(formula, data, family=family).fit().aic
        # NOTE(review): k counts predictors only, while statsmodels' GLM .aic also counts the
        # intercept - confirm this is the intended AICc penalty.
        scores[i] = AICc(aic, n, len(com))

    all_results = OrderedDict({k: v for k, v in zip(scores, formulas)})

    # First minimum (ignoring NaN); raises ValueError if every score is NaN
    best_predictors = all_combs[int(np.nanargmin(scores))]

    formula = "{} ~ {} + 1".format(response, ' + '.join(best_predictors))
    model = smf.glm(formula, data, family=family).fit()

    return model, all_results

In [None]:
def fit_models(y_da, X_da, family_dict, correlation_da, correlation_thresh,
               standardise=True, mode1_name='mode1', mode2_name='mode2'):
    """
    Fit and select a GLM for every season_id (see glm_selection_uncorrelated).

    y_da : DataArray of the response with 'season_id' and 'time' dimensions
    X_da : DataArray of predictors with 'season_id', 'time' and 'mode' dimensions
    family_dict : maps each season_id to 'normal' (Gaussian, identity link) or 'gamma' (Gamma, log
        link; the response is first shifted by +2.001 - NOTE(review): assumes it never falls below
        -2.001, otherwise the Gamma response is not positive)
    correlation_da : predictor correlations with a 'season_id' dimension
    correlation_thresh : absolute correlation above which predictors are not combined
    standardise : if True, divide the response and predictors by their standard deviation over time
    mode1_name, mode2_name : names of the correlation dimensions in correlation_da

    Returns (best_models, all_models): dicts keyed by season_id with the selected fitted model and
    the AICc-score -> formula mapping of all candidate models.

    Raises ValueError for a family name other than 'normal' or 'gamma'.
    """
    best_models = {}
    all_models = {}
    for s_id in y_da.season_id.values:
        print(s_id)

        family_name = family_dict[s_id]
        # Capitalised link classes: the lowercase links.log()/identity() aliases are deprecated
        if family_name == 'gamma':
            shift = 2.001
            family = sm.families.Gamma(link=sm.families.links.Log())
        elif family_name == 'normal':
            shift = 0
            family = sm.families.Gaussian(link=sm.families.links.Identity())
        else:
            raise ValueError("Incorrect family. Should be 'normal' or 'gamma'.")

        y = y_da.sel(season_id=s_id) + shift
        X = X_da.sel(season_id=s_id).transpose('time', 'mode')

        if standardise:
            # Rescale only (no centring)
            y = y / y.std('time')
            X = X / X.std('time')

        y_df = pd.DataFrame(y.values, index=y.time.values, columns=['y'])
        X_df = pd.DataFrame(X.values, index=y.time.values, columns=X.mode.values)
        df = pd.concat([y_df, X_df], axis=1)

        # Forward the dimension names (previously accepted but silently ignored)
        best_models[s_id], all_models[s_id] = glm_selection_uncorrelated(
            df, 'y', family,
            correlation_da.sel(season_id=s_id),
            correlation_thresh,
            mode1_name=mode1_name,
            mode2_name=mode2_name,
        )

    return best_models, all_models

In [None]:
def plot_residuals(model_dict, resid_name='Deviance residuals', save_fig=False, filename='residuals.pdf'):
    """
    Scatter each model's residuals against its fitted values on a 4 x 4 grid (last panel blank).

    model_dict : dict of fitted statsmodels GLM results; panels follow the dict's order and are
        labelled with the value of ``country_order`` at the same position
    resid_name : 'Deviance residuals', 'Pearson residuals' or 'Response residuals'
    save_fig : if True, save the figure to ./figures/<filename> as a PDF
    filename : output file name

    Raises ValueError for an unrecognised resid_name (checked while drawing the first panel).
    """
    resid_attr = {
        'Deviance residuals': 'resid_deviance',
        'Pearson residuals': 'resid_pearson',
        'Response residuals': 'resid_response',
    }
    left_column = [0, 4, 8, 12]

    with plt.rc_context(plt_params):
        fig, ax = plt.subplots(4, 4, figsize=(6.9, 8), dpi=100)
        panels = ax.flatten()

        for idx, (key, model) in enumerate(model_dict.items()):
            panel = panels[idx]
            panel.axhline(0, color='k', lw=1)

            if resid_name not in resid_attr:
                raise ValueError('Incorrect residuals type')
            resid = getattr(model, resid_attr[resid_name])

            panel.scatter(model.fittedvalues, resid, s=2)
            panel.set_title(key)

            panel.set_xticks(np.arange(-3, 3.01, 1))

            panel.set_ylim(-2.05, 2.05)
            panel.set_yticks(np.arange(-2, 2.01, 0.5))

            # Label with the country name instead of a title
            panel.set_title('')
            panel.text(0.05, 0.93, list(country_order.values())[idx], transform=panel.transAxes)

            if idx >= 11:
                panel.set_xlabel('Fitted values')
            else:
                panel.set_xlabel('')
                panel.set_xticklabels([])

            if idx in left_column:
                panel.set_ylabel(resid_name)
            else:
                panel.set_ylabel('')
                panel.set_yticklabels([])

        ax[-1, -1].axis('off')

        plt.subplots_adjust(wspace=0.08, hspace=0.08)

        if save_fig:
            plt.savefig('./figures/' + filename, format='pdf', dpi=400, bbox_inches='tight')

In [None]:
def plot_residuals_hist(model_dict, resid_name='Deviance residuals', save_fig=False, filename='residuals_hist.pdf'):
    """
    Plot a histogram of each model's residuals on a 4 x 4 grid (the last panel is left blank).

    model_dict : dict of fitted statsmodels GLM results; panels follow the dict's order and are
        labelled with the value of ``country_order`` at the same position
    resid_name : 'Deviance residuals', 'Pearson residuals' or 'Response residuals'; selects the
        residuals that are histogrammed and labels the x axis
    save_fig : if True, save the figure to ./figures/<filename> as a PDF
    filename : output file name

    Raises ValueError for an unrecognised resid_name.
    """
    with plt.rc_context(plt_params):
        fig, ax = plt.subplots(4, 4, figsize=(6.9, 8), dpi=100)
        panels = ax.flatten()

        for i, (k, model) in enumerate(model_dict.items()):
            panel = panels[i]
            panel.axhline(0, color='k', lw=1)

            if resid_name == 'Deviance residuals':
                y = model.resid_deviance
            elif resid_name == 'Pearson residuals':
                y = model.resid_pearson
            elif resid_name == 'Response residuals':
                y = model.resid_response
            else:
                raise ValueError('Incorrect residuals type')

            # Histogram the selected residuals (not the fitted values), with a zero reference line
            panel.hist(y, 15)
            panel.axvline(0, c='k')

            panel.set_xticks(np.arange(-3, 3.01, 1))

            panel.set_ylim(0, 15)
            panel.set_yticks(range(0, 16, 5))

            # Label with the country name instead of a title
            panel.set_title('')
            panel.text(0.05, 0.93, list(country_order.values())[i], transform=panel.transAxes)

            if i < 11:
                panel.set_xlabel('')
                panel.set_xticklabels([])
            else:
                panel.set_xlabel(resid_name)

            if i in [0, 4, 8, 12]:
                panel.set_ylabel('Counts')
            else:
                panel.set_ylabel('')
                panel.set_yticklabels([])

        ax[-1, -1].axis('off')

        plt.subplots_adjust(wspace=0.08, hspace=0.08)

        if save_fig:
            plt.savefig('./figures/' + filename, format='pdf', dpi=400, bbox_inches='tight')

In [None]:
def plot_qq(model_dict, resid_name='Deviance residuals', save_fig=False, filename='qq_plot.pdf'):
    """
    Normal Q-Q plot of each model's residuals on a 4 x 4 grid (the last panel is left blank).

    model_dict : dict of fitted statsmodels GLM results; panels follow the dict's order and are
        labelled with the value of ``country_order`` at the same position
    resid_name : 'Deviance residuals', 'Pearson residuals' or 'Response residuals'
    save_fig : if True, save the figure to ./figures/<filename> as a PDF
    filename : output file name

    Raises ValueError for an unrecognised resid_name.
    """
    resid_attr = {
        'Deviance residuals': 'resid_deviance',
        'Pearson residuals': 'resid_pearson',
        'Response residuals': 'resid_response',
    }
    left_column = [0, 4, 8, 12]

    with plt.rc_context(plt_params):
        fig, ax = plt.subplots(4, 4, figsize=(6.9, 8), dpi=100)
        panels = ax.flatten()

        for idx, (key, model) in enumerate(model_dict.items()):
            panel = panels[idx]

            if resid_name not in resid_attr:
                raise ValueError('Incorrect residuals type')
            resid = getattr(model, resid_attr[resid_name])

            # fit=True: the normal distribution's parameters are estimated from the residuals
            pp = sm.ProbPlot(resid, dist=stats.norm, fit=True)

            # 1:1 reference line behind the points
            panel.plot((-2, 2), (-2, 2), color='r', zorder=0)
            panel.scatter(pp.theoretical_quantiles, pp.sample_quantiles, s=4, zorder=1)

            panel.set_title(key)

            panel.set_xticks(np.arange(-2, 2.01, 1))

            panel.set_ylim(-3, 3)
            panel.set_yticks(np.arange(-3, 3.01, 1))

            # Label with the country name instead of a title
            panel.set_title('')
            panel.text(0.05, 0.93, list(country_order.values())[idx], transform=panel.transAxes)

            if idx >= 11:
                panel.set_xlabel('Theoretical quantiles')
            else:
                panel.set_xlabel('')
                panel.set_xticklabels([])

            if idx in left_column:
                panel.set_ylabel('Empirical quantiles')
            else:
                panel.set_ylabel('')
                panel.set_yticklabels([])

        ax[-1, -1].axis('off')

        plt.subplots_adjust(wspace=0.08, hspace=0.08)

        if save_fig:
            plt.savefig('./figures/' + filename, format='pdf', dpi=400, bbox_inches='tight')

In [None]:
def plot_model_coeffs(model_dict, mode_da, names, save_fig=False, filename='model_coeffs.pdf'):
    """
    Plot the fitted coefficient of every mode (columns) for every season (rows).

    Modes that a season's selected model does not include are left blank (NaN). The colour scale is
    symmetric about zero and rounded up to a whole number (or to 0.5 for small coefficients). The
    coefficient range and colour limit are printed.

    model_dict : dict mapping season_id -> fitted statsmodels results with a ``params`` Series;
        must contain every key of ``country_order``
    mode_da : DataArray with 'mode' and 'season_id' dimensions (e.g. modes_concat) defining the
        set and order of modes and the number of seasons
    names : x tick labels for the modes
    save_fig : if True, save the figure to ./figures/<filename> as a PDF
    filename : output file name
    """
    def get_params(model_dict, mode_da):
        # Rows are modes, columns are seasons in country_order order
        arr = np.full((len(mode_da.mode), len(mode_da.season_id)), np.nan)

        for j, s_id in enumerate(country_order.keys()):
            params = model_dict[s_id].params
            for i, m in enumerate(mode_da.mode.values):
                # Modes absent from the selected model have no coefficient: leave NaN.
                # (Replaces a bare except that also hid unrelated errors.)
                arr[i, j] = params.get(m, np.nan)

        return arr

    def roundup(x, nearest):
        # Round x up to the next multiple of `nearest`
        import math
        return int(math.ceil(x / nearest)) * nearest

    plot_data = get_params(model_dict, mode_da)
    p_min = np.nanmin(plot_data)
    p_max = np.nanmax(plot_data)
    p_absmax = np.nanmax(np.abs([p_min, p_max]))

    if p_absmax > 2.5:
        p_lim = roundup(p_absmax, 1)
        cbar_space = 1
    else:
        p_lim = roundup(p_absmax, 0.5)
        cbar_space = 0.5
    print(p_min, p_max, p_lim)

    with plt.rc_context(plt_params):
        fig, ax = plt.subplots(1, 1, figsize=(3, 2), dpi=160)

        p = ax.pcolormesh(plot_data.transpose(), cmap='RdBu_r', vmin=-p_lim, vmax=p_lim)

        # First season at the top
        ax.invert_yaxis()
        # Use the mode_da argument (not the global modes_concat) for the number of seasons
        ax.set_yticks(np.arange(0.5, len(mode_da.season_id)))
        ax.set_yticklabels(list(country_order.values()))

        ax.set_xticks(np.arange(0.5, len(mode_da.mode) + 0.1, 1))
        ax.set_xticklabels(names, rotation=90)

        # Species group labels to the left of the season tick labels
        ax.text(-0.32, 0.7, 'Arabica', transform=ax.transAxes, rotation=90, ha='center', va='center')
        ax.text(-0.32, 0.2, 'Robusta', transform=ax.transAxes, rotation=90, ha='center', va='center')

        # Dotted separators (axes-fraction coordinates): a horizontal line at 40% of the height
        # between the species groups, and a vertical line beside the group labels
        ax.annotate('', xy=(-0.35, 0.4), xycoords='axes fraction', xytext=(1.01, 0.4),
                    arrowprops=dict(arrowstyle="-", ls=':', lw=plt_params['lines.linewidth'] - 0.5))
        ax.annotate('', xy=(-0.29, 0.), xycoords='axes fraction', xytext=(-0.29, 1),
                    arrowprops=dict(arrowstyle="-", ls=':', lw=plt_params['lines.linewidth'] - 0.5))

        cb_ax1 = fig.add_axes([0.93, 0.15, 0.02, 0.71])
        cb1 = fig.colorbar(p, cax=cb_ax1, orientation='vertical', ticks=np.arange(-p_lim, p_lim + .01, cbar_space))
        cb1.ax.set_ylabel('Coefficient [-]', rotation=270, va='bottom')

        if save_fig:
            plt.savefig('./figures/' + filename, format='pdf', dpi=400, bbox_inches='tight')

In [None]:
# Select relevant countries for each species and concat
# season_id has the form '<abbrev>_<n>' (e.g. 'BRS_0'); keep those whose country grows the species.
arabica_season_ids = [s for s in mn2t_growing_events.season_id.values if s.split('_')[0] in arabica_abbrevs]
robusta_season_ids = [s for s in mn2t_growing_events.season_id.values if s.split('_')[0] in robusta_abbrevs]

In [None]:
# Event statistics for each climate risk relevant to Arabica, restricted to 1980-2020 and to
# Arabica seasons.
arabica_risks = {
                 'VPD > x': vpd_growing_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids),
                 'Tmax > x': mx2t_growing_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids),
                 'T < x': t2m_lt_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids),
                 'T > x': t2m_ut_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids),
                 'P < x': tp_lt_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids),
                 'P > x': tp_ut_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids)
                }

In [None]:
# Same for Robusta (the flowering-season Tmin risk is currently excluded).
robusta_risks = {
#                  'Tmin fl < x': mn2t_flowering_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids),
                 'Tmin gr > x': mn2t_growing_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids),
                 'T < x': t2m_lt_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids),
                 'T > x': t2m_ut_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids),
                 'P < x': tp_lt_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids),
                 'P > x': tp_ut_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids)
                }

In [None]:
# Signed copies: cold (T < x) and wet (P > x) events count negatively, so positive totals indicate
# warm/dry events. deepcopy ensures the in-place *= -1 leaves arabica_risks untouched.
signed_arabica_risks = copy.deepcopy(arabica_risks)
signed_arabica_risks['T < x'] *= -1
signed_arabica_risks['P > x'] *= -1

In [None]:
signed_robusta_risks = copy.deepcopy(robusta_risks)
# signed_robusta_risks['Tmin fl < x'] *= -1
signed_robusta_risks['T < x'] *= -1
signed_robusta_risks['P > x'] *= -1

In [None]:
# Combined event counts for all seasons, ordered as in country_order
# (NOTE(review): presumably fn.combine_n_events sums events over risks - confirm in functions.py).
n_events = fn.combine_n_events([arabica_risks, robusta_risks])
n_events = n_events.sel(season_id=list(country_order.keys()))

In [None]:
signed_n_events = fn.combine_n_events([signed_arabica_risks, signed_robusta_risks]) # can be used to tell whether the majority of events in a year are warm/dry or cold/wet
signed_n_events = signed_n_events.sel(season_id=list(country_order.keys()))
# signed_n_events = xr.where(signed_n_events < 0, n_events * -1, n_events)

### Fit to signed number of events

In [None]:
# Time series and histograms of the signed event counts for every season, saved to ./figures.
plot_ts_hist(signed_n_events, save_fig=True, filename='signed_events_time_series.pdf')

### Which family do we choose when fitting our GLM?
- Arabica data range from -2 to 4
- Robusta data range from -2 to 3
- A Gaussian distribution looks appropriate for some regions
- A Gamma distribution could also fit (the data would first need to be shifted so that they are strictly positive)

In [None]:
# Use the Gaussian (identity-link) family for every season.
dists = {k: 'normal' for k in modes_concat.season_id.values}
dists

### All possible predictor sets with uncorrelated predictors

In [None]:
# Fit and select a model per season: best_models[season_id] is the selected fitted model and
# all_models[season_id] maps each candidate's AICc to its formula.
best_models, all_models = fit_models(signed_n_events, modes_concat,
                                        dists, cor_da, correlation_threshold)

In [None]:
# NOTE(review): this prints statsmodels' AIC; model selection used AICc.
for k, model in zip(best_models.keys(), best_models.values()):
    print(k, model.model.formula, '\nAIC: {0:.2f}\n'.format(model.aic))

In [None]:
# Residual diagnostics for the selected models, saved to ./figures.
plot_residuals(best_models, resid_name='Deviance residuals', save_fig=True, filename='gaussian_family_residuals.pdf')

In [None]:
plot_residuals_hist(best_models, save_fig=True, filename='gaussian_family_residuals_hist.pdf')

In [None]:
plot_qq(best_models, save_fig=True, filename='gaussian_family_qq.pdf')

In [None]:
# Coefficients of the selected models (blank where a mode was not selected).
plot_model_coeffs(best_models, modes_concat, mode_names,
                  save_fig=True, filename='gaussian_family_coefficients.pdf')

# Ordinal regression - gives surprisingly similar results
Treat the data as ordered categorical variables, e.g. 2 wet events, 1 wet event, no events, 1 dry event, 2 dry events, ...

Note that you need the development version of `statsmodels` (v0.14.0) to run ordered models

In [None]:
import pickle
from statsmodels.miscmodels.ordinal_model import OrderedModel

In [None]:
def ordered_model_selection_uncorrelated(data, response, distr, correlation_da, correlation_thresh, mode1_name='mode1', mode2_name='mode2'):
    """
    Selects the best ordinal regression model over subsets of mutually uncorrelated predictors.

    Every predictor combination returned by `subset_predictors` is fitted with statsmodels'
    `OrderedModel`. Those are the combinations in which no two predictors are correlated beyond
    `correlation_thresh`. Each fit is scored with `AICc(aic, n, k)`, and the fitted model with
    the lowest score is returned. Ties go to the first combination.

    data : pandas DataFrame containing response and predictors
    response : string indicating name of response column in `data`
    distr : distribution passed to `OrderedModel` (e.g. 'logit')
    correlation_da : xarray DataArray of predictors' correlation
    correlation_thresh : value of correlation by which to exclude predictors
    mode1_name, mode2_name : forwarded to `subset_predictors` (presumably the two predictor
        dimension names of `correlation_da` - confirm against subset_predictors)

    Returns
    -------
    model : fitted OrderedModel results for the lowest-scoring predictor set
    all_results : OrderedDict mapping score -> formula for every combination tried.
        Because scores are the keys, combinations with identical scores overwrite each other.

    Raises
    ------
    ValueError : if there are no predictor combinations, or every combination scores NaN
    """
    n = data.shape[0]

    predictors = set(data.columns)
    predictors.remove(response)

    all_combs = subset_predictors(predictors, correlation_da, correlation_thresh,
                                  mode1_name=mode1_name, mode2_name=mode2_name)
    n_combs = len(all_combs)
    print(n_combs, 'combinations\n')
    if n_combs == 0:
        raise ValueError('No uncorrelated predictor combinations available to fit.')

    formulas = []
    scores = np.full(n_combs, np.nan)
    best_model = None
    best_score = None

    for i, com in enumerate(all_combs):
        # NOTE(review): k counts only the predictors. OrderedModel also estimates the category
        # thresholds as parameters - confirm this is the intended AICc penalty.
        k = len(com)

        formula = "{} ~ {}".format(response, ' + '.join(com))
        formulas.append(formula)

        result = OrderedModel.from_formula(formula, data, distr=distr).fit(method='lbfgs', disp=False)
        aicc = AICc(result.aic, n, k)
        scores[i] = aicc

        # Keep the first fit with the lowest non-NaN score (the same choice as taking the first
        # index of nanmin over all scores). This avoids refitting the winning model afterwards.
        if not np.isnan(aicc) and (best_model is None or aicc < best_score):
            best_model, best_score = result, aicc

    if best_model is None:
        raise ValueError('Every candidate ordered model produced a NaN score.')

    all_results = OrderedDict(zip(scores, formulas))

    return best_model, all_results

In [None]:
def fit_ordered_models(y_da, X_da, distr, correlation_da, correlation_thresh,
                       standardise=True, mode1_name='mode1', mode2_name='mode2'):
    """
    Wrapper to fit ordered models, one per season_id in `y_da`.

    For each season_id, the response and the predictors (transposed to ('time', 'mode')) are
    optionally divided by their standard deviation over 'time'. They are then combined into a
    DataFrame with response column 'y' plus one column per mode, and passed to
    `ordered_model_selection_uncorrelated`.

    y_da : xarray DataArray of the response, with 'season_id' and 'time' dimensions
    X_da : xarray DataArray of predictors, with 'season_id', 'time' and 'mode' dimensions
    distr : distribution passed to `OrderedModel` (e.g. 'logit')
    correlation_da : xarray DataArray of predictors' correlation, with a 'season_id' dimension
    correlation_thresh : value of correlation by which to exclude predictors
    standardise : if True, divide y and X by their standard deviation over 'time'
    mode1_name, mode2_name : forwarded to `ordered_model_selection_uncorrelated`

    Returns
    -------
    best_models : dict of season_id -> fitted best OrderedModel results
    all_models : dict of season_id -> OrderedDict of score -> formula for all combinations
    """
    best_models = {}
    all_models = {}
    for s_id in y_da.season_id.values:
        print(s_id)

        y = y_da.sel(season_id=s_id)
        X = X_da.sel(season_id=s_id).transpose('time', 'mode')

        if standardise:
            # NOTE(review): a constant series has zero std, which would produce NaN/inf values here
            y = y / y.std('time')
            X = X / X.std('time')

        y_df = pd.DataFrame(y.values, index=y.time.values, columns=['y'])
        X_df = pd.DataFrame(X.values, index=y.time.values, columns=X.mode.values)
        df = pd.concat([y_df, X_df], axis=1)

        # Forward the dimension names; previously these arguments were accepted but silently dropped
        best_models[s_id], all_models[s_id] = ordered_model_selection_uncorrelated(
            df, 'y', distr,
            correlation_da.sel(season_id=s_id),
            correlation_thresh,
            mode1_name=mode1_name,
            mode2_name=mode2_name,
        )

    return best_models, all_models

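### Computing this takes a while to run (~1 hour): set `compute = True` below to refit, otherwise the cached best-model formulas are loaded from a pickle file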

In [None]:
# Set compute = True to refit all ordered models (slow); otherwise load the cached
# season_id -> [formula, AIC] mapping from disk. Only load pickles from this trusted project path.
compute = False

results_path = '/g/data/xv83/dr6273/work/projects/coffee/data/ordered_regression_results_best.pickle'

if not compute:
    with open(results_path, 'rb') as handle:
        best_model_formulas = pickle.load(handle)
else:
    best_ordered_models, all_ordered_models = fit_ordered_models(
        signed_n_events, modes_concat, 'logit', cor_da, correlation_threshold,
    )
    # Keep only what is needed to refit: the formula and AIC of each season's best model
    best_model_formulas = {
        s_id: [result.model.formula, result.aic]
        for s_id, result in best_ordered_models.items()
    }

    with open(results_path, 'wb') as handle:
        pickle.dump(best_model_formulas, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Display the season_id -> [formula, AIC] mapping of the selected ordered models
best_model_formulas

In [None]:
# Refit each season's selected ordered (logit) model from its stored formula. The data is
# standardised the same way as fit_ordered_models does by default.
standardise = True

best_ordered_models = {}
for s_id, stored in best_model_formulas.items():
    y = signed_n_events.sel(season_id=s_id)
    X = modes_concat.sel(season_id=s_id).transpose('time', 'mode')
    if standardise:
        y, X = y / y.std('time'), X / X.std('time')

    y_df = pd.DataFrame(y.values, index=y.time.values, columns=['y'])
    X_df = pd.DataFrame(X.values, index=y.time.values, columns=X.mode.values)
    df = pd.concat([y_df, X_df], axis=1)

    # stored is [formula, AIC]
    formula = stored[0]
    model = OrderedModel.from_formula(
        formula, df, distr='logit',
    ).fit(method='lbfgs', disp=False)
    best_ordered_models[s_id] = model

### Ordered regression has a lower AIC than the normal-distribution model in only three regions; these are the regions with the poorest fit on the Q-Q plots.

In [None]:
# Report the season_ids where the ordered (logit) model has a lower AIC than the Gaussian model.
# Models are paired by season_id, not by dict position: best_ordered_models follows the key
# order of the loaded pickle, which need not match the order of best_models.
# NOTE(review): one AIC comes from a Gaussian density and the other from category probabilities,
# and y may be scaled differently in the two fits. Confirm the two AICs are comparable before
# drawing conclusions.
for k, norm in best_models.items():
    ordered = best_ordered_models.get(k)
    if ordered is None:
        continue
    norm_aic = norm.aic
    ordered_aic = ordered.aic
    if ordered_aic < norm_aic:
        print(f'{k}: Norm AIC: {norm_aic}; Ordered AIC: {ordered_aic}')

In [None]:
# Coefficients of the ordered (logit) best models, saved to PDF
plot_model_coeffs(
    best_ordered_models,
    modes_concat,
    mode_names,
    save_fig=True,
    filename='ordinal_model_coefficients.pdf',
)

# Write data to file

In [None]:
# Write each season_id's unstandardised response ('y') and climate-mode predictors to one CSV
for s_id in signed_n_events.season_id.values:
    y = signed_n_events.sel(season_id=s_id)
    X = modes_concat.sel(season_id=s_id).transpose('time', 'mode')

    df = pd.concat(
        [
            pd.DataFrame(y.values, index=y.time.values, columns=['y']),
            pd.DataFrame(X.values, index=y.time.values, columns=X.mode.values),
        ],
        axis=1,
    )

    df.to_csv(f'/g/data/xv83/dr6273/work/projects/coffee/data/country_climate_mode_data_{s_id}.csv')

# Close cluster

In [None]:
# Shut down the Dask client first, then the PBS cluster created at the top of the notebook
client.close()
cluster.close()