# Role of climate modes

In [None]:
from dask_jobqueue import PBSCluster
from dask.distributed import Client

In [None]:
# Dask cluster on NCI Gadi, with workers submitted as PBS jobs.
# One node on Gadi has 48 cores - try and use up a full node before going to multiple nodes (jobs)

walltime = '00:30:00'
cores = 2
memory = '8GB'

# PBS resources (cores, memory, project, storage) are passed explicitly through job_extra;
# generated header lines containing "select" are skipped via header_skip.
# NOTE(review): newer dask_jobqueue releases rename `job_extra` -> `job_extra_directives`
# and `header_skip` -> `job_directives_skip`; confirm against the installed version.
cluster = PBSCluster(walltime=str(walltime), cores=cores, memory=str(memory),
                     job_extra=['-l ncpus='+str(cores),
                                '-l mem='+str(memory),
                                '-P xv83',
                                '-l storage=gdata/xv83+gdata/rt52+scratch/xv83'],
                     header_skip=["select"])

In [None]:
# Request one PBS job's worth of workers and attach a client to the cluster.
cluster.scale(jobs=1)
client = Client(cluster)

In [None]:
client

In [None]:
# Automatically reload edited local modules (e.g. `functions`, imported as fn) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import regionmask
import copy
import xskillscore as xs
from xbootstrap import block_bootstrap

import matplotlib
import matplotlib.pyplot as plt

import cartopy.crs as ccrs
import cartopy
cartopy.config['pre_existing_data_dir'] = '/g/data/xv83/dr6273/work/data/cartopy-data'
cartopy.config['data_dir'] = '/g/data/xv83/dr6273/work/data/cartopy-data'

import functions as fn

In [None]:
# Shared matplotlib rcParams from the local helper module; applied via plt.rc_context in the figure cells.
plt_params = fn.get_plot_params()

In [None]:
# default colours
# (the current matplotlib colour cycle, as a list of colour strings)
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

# Load coffee data

In [None]:
# Order abbrevs and names by species and production
# country_order is a mapping whose keys are season_ids (e.g. 'CO_2'); its key order is
# used later to order the season_id dimension.
country_order = fn.get_country_order()

In [None]:
growing_calendar = pd.read_csv('/g/data/xv83/dr6273/work/projects/coffee/data/coffee_country_growing_calendar_extended.csv',
                               index_col=0)
growing_calendar.head()

In [None]:
# Sorted, de-duplicated country abbreviations for each coffee species.
arabica_abbrevs = np.unique(growing_calendar.loc[(growing_calendar.species == 'Arabica'), 'abbrevs'])
robusta_abbrevs = np.unique(growing_calendar.loc[(growing_calendar.species == 'Robusta'), 'abbrevs'])

# Gridded climate data relevant for each phase of coffee (growing and flowering)

In [None]:
vpd_flowering = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/era5_vpd_detrended_Flowering_upper_tail_1_std.zarr',
                             consolidated=True)
vpd_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/era5_vpd_detrended_Growing_upper_tail_1_std.zarr',
                              consolidated=True)

In [None]:
mn2t_flowering = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/berkeley_tmin_detrended_Flowering_lower_tail_1_std.zarr',
                             consolidated=True)
mn2t_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/berkeley_tmin_detrended_Growing_upper_tail_1_std.zarr',
                              consolidated=True)

In [None]:
mx2t_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/berkeley_tmax_detrended_Growing_upper_tail_1_std.zarr',
                                  consolidated=True)

In [None]:
t2m_lt_growing_optimal = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/berkeley_temperature_detrended_Growing_lower_tail_1_std.zarr',
                                             consolidated=True)

In [None]:
t2m_ut_growing_optimal = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/berkeley_temperature_detrended_Growing_upper_tail_1_std.zarr',
                                             consolidated=True)

In [None]:
tp_lt_growing_optimal = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/gpcc_precip_detrended_Annual_lower_tail_1_std.zarr',
                                             consolidated=True)

In [None]:
tp_ut_growing_optimal = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/gpcc_precip_detrended_Annual_upper_tail_1_std.zarr',
                                             consolidated=True)

### Proportion of each country, and of global coffee area, affected by each climate hazard each year

In [None]:
# Dataset name passed to fn.calculate_event_statistics for each data source
# (presumably selects the matching grid / country masks — TODO confirm).
vpd_grid_template = 'era5'
temperature_grid_template = 'berkeley'
precip_grid_template = 'gpcc'

### VPD events

In [None]:
# Each *_events result is loaded into memory with .compute(); later cells subset it
# along 'time' and 'season_id'.
vpd_flowering_events = fn.calculate_event_statistics(vpd_flowering, vpd_grid_template).compute()

In [None]:
vpd_growing_events = fn.calculate_event_statistics(vpd_growing, vpd_grid_template).compute()

### Tmin averages events

In [None]:
mn2t_flowering_events = fn.calculate_event_statistics(mn2t_flowering, temperature_grid_template).compute()

In [None]:
mn2t_growing_events = fn.calculate_event_statistics(mn2t_growing, temperature_grid_template).compute()

### Tmax averages events

In [None]:
mx2t_growing_events = fn.calculate_event_statistics(mx2t_growing, temperature_grid_template).compute()

### T ranges events

In [None]:
t2m_lt_growing_optimal_events = fn.calculate_event_statistics(t2m_lt_growing_optimal, temperature_grid_template).compute()

In [None]:
t2m_ut_growing_optimal_events = fn.calculate_event_statistics(t2m_ut_growing_optimal, temperature_grid_template).compute()

### Precip ranges events

In [None]:
tp_lt_growing_optimal_events = fn.calculate_event_statistics(tp_lt_growing_optimal, precip_grid_template).compute()

In [None]:
tp_ut_growing_optimal_events = fn.calculate_event_statistics(tp_ut_growing_optimal, precip_grid_template).compute()

# Load mode data

- Use the growing season only, as there is only one climate risk in the flowering season.
    - This means the comparison with 12-month (annual) rainfall is not quite like-for-like.

In [None]:
# SST product used for the SST-based mode indices (it is part of the file names below).
sst_dataset = 'hadisst'

### Nino3.4

In [None]:
# Each growing-season mode index is read from zarr and loaded into memory with .compute().
nino34_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/'+sst_dataset+'_nino34_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

### DMI

In [None]:
dmi_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/'+sst_dataset+'_dmi_detrended_Growing_both_tails_1_std.zarr',
                            consolidated=True).compute()

### Atlantic Nino

In [None]:
atl_nino_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/'+sst_dataset+'_atl_nino_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

### TNA

In [None]:
tna_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/'+sst_dataset+'_tna_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

### TSA

In [None]:
tsa_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/'+sst_dataset+'_tsa_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

### MJO

In [None]:
# MJO indices from ERA5: days per month in each MJO phase (indexed by 'phase_ID')
# and the mean MJO amplitude.
mjo_days_per_month_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/era5_mjo_days_per_month_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

mjo_mean_amplitude_growing = xr.open_zarr('/g/data/xv83/dr6273/work/projects/coffee/data/era5_mjo_mean_amplitude_detrended_Growing_both_tails_1_std.zarr',
                                consolidated=True).compute()

### Combine mode indices into one array (e.g. for correlation between modes)

In [None]:
modes_concat = xr.concat([
    nino34_growing.nino34_detrended.expand_dims({'mode': ['nino34']}),
    dmi_growing.dmi_detrended.expand_dims({'mode': ['dmi']}),
    atl_nino_growing.atl_nino_detrended.expand_dims({'mode': ['atl_nino']}),
    tna_growing.tna_detrended.expand_dims({'mode': ['tna']}),
    tsa_growing.tsa_detrended.expand_dims({'mode': ['tsa']}),
    
    mjo_days_per_month_growing.mjo_days_per_month_detrended.sel(phase_ID=1).expand_dims({'mode': ['mjo_dpm_p1']}).drop('phase_ID'),
    mjo_days_per_month_growing.mjo_days_per_month_detrended.sel(phase_ID=2).expand_dims({'mode': ['mjo_dpm_p2']}).drop('phase_ID'),
    mjo_days_per_month_growing.mjo_days_per_month_detrended.sel(phase_ID=3).expand_dims({'mode': ['mjo_dpm_p3']}).drop('phase_ID'),    
    mjo_days_per_month_growing.mjo_days_per_month_detrended.sel(phase_ID=4).expand_dims({'mode': ['mjo_dpm_p4']}).drop('phase_ID'),    
    mjo_days_per_month_growing.mjo_days_per_month_detrended.sel(phase_ID=5).expand_dims({'mode': ['mjo_dpm_p5']}).drop('phase_ID'),    
    mjo_days_per_month_growing.mjo_days_per_month_detrended.sel(phase_ID=6).expand_dims({'mode': ['mjo_dpm_p6']}).drop('phase_ID'),    
    mjo_days_per_month_growing.mjo_days_per_month_detrended.sel(phase_ID=7).expand_dims({'mode': ['mjo_dpm_p7']}).drop('phase_ID'),    
    mjo_days_per_month_growing.mjo_days_per_month_detrended.sel(phase_ID=8).expand_dims({'mode': ['mjo_dpm_p8']}).drop('phase_ID'),
                        ], 'mode')

In [None]:
modes_concat = modes_concat.sel(time=slice('1980', '2020'))

### Subset modes on warm/dry and cold/wet events

In [None]:
# Select relevant countries for each species and concat
# (season_ids look like '<abbrev>_<n>'; keep those whose country abbrev grows the species)
arabica_season_ids = [s for s in mn2t_growing_events.season_id.values if s.split('_')[0] in arabica_abbrevs]
robusta_season_ids = [s for s in mn2t_growing_events.season_id.values if s.split('_')[0] in robusta_abbrevs]

In [None]:
# Hazards assessed for Arabica, each restricted to 1980-2020 and to Arabica season_ids.
arabica_risks = {
                 'VPD > x': vpd_growing_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids),
                 'Tmax > x': mx2t_growing_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids),
                 'T < x': t2m_lt_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids),
                 'T > x': t2m_ut_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids),
                 'P < x': tp_lt_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids),
                 'P > x': tp_ut_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=arabica_season_ids)
                }

In [None]:
# Hazards assessed for Robusta; this set includes the flowering-season Tmin hazard.
robusta_risks = {
                 'Tmin fl < x': mn2t_flowering_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids),
                 'Tmin gr > x': mn2t_growing_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids),
                 'T < x': t2m_lt_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids),
                 'T > x': t2m_ut_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids),
                 'P < x': tp_lt_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids),
                 'P > x': tp_ut_growing_optimal_events.sel(time=slice('1980', '2020')).sel(season_id=robusta_season_ids)
                }

In [None]:
# Sign convention: cold or wet hazards are multiplied by -1 and warm or dry hazards stay
# positive, so a signed total shows which kind dominates. Deep copies leave the
# unsigned dicts untouched.
signed_arabica_risks = copy.deepcopy(arabica_risks)
signed_arabica_risks['T < x'] *= -1
signed_arabica_risks['P > x'] *= -1

In [None]:
signed_robusta_risks = copy.deepcopy(robusta_risks)
signed_robusta_risks['Tmin fl < x'] *= -1
signed_robusta_risks['T < x'] *= -1
signed_robusta_risks['P > x'] *= -1

In [None]:
# Number of events per season_id and year, reordered to match country_order.
# NOTE(review): assumes fn.combine_n_events sums the event indicators across all risks — TODO confirm.
n_events = fn.combine_n_events([arabica_risks, robusta_risks])
n_events = n_events.sel(season_id=list(country_order.keys()))

signed_n_events = fn.combine_n_events([signed_arabica_risks, signed_robusta_risks]) # can be used to tell whether the majority of events in a year are warm/dry or cold/wet
signed_n_events = signed_n_events.sel(season_id=list(country_order.keys()))
# Replace the net signed count by the total count carrying the net's sign: negative means
# cold/wet-dominated; zero or positive (including ties) is treated as warm/dry.
signed_n_events = xr.where(signed_n_events < 0, n_events * -1, n_events)

# Events and modes time series

In [None]:
# Total events per season_id and year, carrying the sign of the net warm/dry (+) vs
# cold/wet (-) balance. This gives the same values as signed_n_events above.
n_events_s = xr.where(signed_n_events < 0, n_events * -1, n_events)

# Split into cold/wet-dominated counts (as positive magnitudes) and warm/dry-dominated
# counts; entries that are 0 in one category are 0 in the other only if there are no events.
negatives = xr.where(n_events_s < 0, np.abs(n_events_s), 0)
positives = xr.where(n_events_s > 0, n_events_s, 0)

# Sum over all season_ids -> one cold/wet and one warm/dry total per year.
year_counts_neg = negatives.sum('season_id')
year_counts_pos = positives.sum('season_id')

In [None]:
# Mean number of events per year: all events, warm/dry, and cold/wet.
# (sd_wd / sd_cw are computed for reference but not used in this cell.)
mean_all = (year_counts_neg + year_counts_pos).mean('time')
mean_wd = year_counts_pos.mean('time')
mean_cw = year_counts_neg.mean('time')

sd_wd = year_counts_pos.std('time')
sd_cw = year_counts_neg.std('time')

# For plotting
# Exceedance markers: years whose count is above its long-term mean (the `+ 0` is the
# margin above the mean) get a y position; all other years are NaN (no marker).
# Cold/wet markers sit at y=3 near the bar base; warm/dry markers sit 3 below the top
# of the stacked (cold/wet + warm/dry) bar.
cw_exceedances = xr.where(year_counts_neg > mean_cw + 0, 3, np.nan)
wd_exceedances = xr.where(year_counts_pos > mean_wd + 0, year_counts_pos + year_counts_neg - 3, np.nan)

In [None]:
# Display labels for the modes, in the same order as the 'mode' dimension of modes_concat.
mode_labels = ['ENSO', 'IOD', r'Atl. Ni$\mathrm{\tilde{n}}$o', 'TNA', 'TSA',
                   r'$\mathrm{MJO}_{1}$',
                   r'$\mathrm{MJO}_{2}$',
                   r'$\mathrm{MJO}_{3}$',
                   r'$\mathrm{MJO}_{4}$',
                   r'$\mathrm{MJO}_{5}$',
                   r'$\mathrm{MJO}_{6}$',
                   r'$\mathrm{MJO}_{7}$',
                   r'$\mathrm{MJO}_{8}$']

In [None]:
n_abs_max = 4
time = t2m_lt_growing_optimal.sel(time=slice('1980', '2020')).time.dt.year.values

# Figure: (a) stacked bars of cold/wet and warm/dry event counts per year, with markers
# in above-average years; (b) heatmap of standardised mode indices per year.
with plt.rc_context(plt_params):
    cmap = plt.cm.BrBG_r
    # NOTE: `norm` is not used in this cell; kept so the name stays defined.
    norm = matplotlib.colors.BoundaryNorm(np.arange(-n_abs_max, n_abs_max+2), cmap.N)
    c_ = [0.25, 0.75]  # cmap positions for the cold/wet and warm/dry colours

    fig = plt.figure(figsize=(6.9, 4), dpi=150)
    gs = fig.add_gridspec(ncols=1, nrows=2, height_ratios=[0.5, 1.2])

    # =============================== Events bar plot
    ax = fig.add_subplot(gs[0])

    # Cold/wet counts at the bottom, warm/dry counts stacked on top.
    ax.bar(time, year_counts_neg, color=cmap(c_[0]), width=.8, zorder=0)
    ax.bar(time, year_counts_pos, bottom=year_counts_neg, color=cmap(c_[1]), width=.8, zorder=0)

    # Exceedance markers (NaN, i.e. not drawn, in years at or below the mean).
    ax.scatter(time, cw_exceedances, c='k', s=10, zorder=1, label=r"Cold or wet ($\mu = $"+str(np.round(mean_cw.values, 1))+')')
    ax.scatter(time, wd_exceedances, c='k', s=10, marker='v', zorder=1, label=r"Warm or dry ($\mu = $"+str(np.round(mean_wd.values, 1))+')')

    # Custom legend with bar color and markers: squares hand-placed next to the legend
    # entries (coordinates tuned for this figure size and y-range).
    ax.scatter(1980.4858657, 36, marker='s', color=cmap(c_[0]))
    ax.scatter(1980.4858657, 30.915, marker='s', color=cmap(c_[1]))
    ax.legend(loc=(0.01, 0.65), frameon=False)

    ax.set_ylim(0, 41)
    ax.set_yticks(range(0, 41, 10))
    ax.set_ylabel('Number of events')

    # One tick per year, labelled only in years divisible by 5. Building the labels from
    # `time` keeps their count equal to the number of ticks for any analysis period
    # (a fixed-length list raises a ValueError in matplotlib if the counts differ).
    ax.set_xticks(time)
    xtick_labels = [year if year % 5 == 0 else '' for year in time]
    ax.set_xticklabels(xtick_labels)
    ax.set_xlim(time[0] - 0.5, time[-1] + 0.5)

    ax.text(-0.1, 0.95, 'a', weight='bold', transform=ax.transAxes)

    # =============================== Modes time series heatmap
    # Standardise each mode by its std over time (per season_id), then average over season_ids.
    modes = (modes_concat / modes_concat.std('time')).mean('season_id')
    n_modes = modes.sizes['mode']

    ax = fig.add_subplot(gs[1])

    p = ax.pcolormesh(modes, cmap='RdBu_r', vmin=-2.5, vmax=2.5)

    # One row per mode, labelled at the cell centres, first mode at the top.
    ax.set_ylim(0, n_modes)
    ax.set_yticks(np.arange(0.5, n_modes, 1))
    ax.set_yticklabels(mode_labels)
    ax.invert_yaxis()

    # pcolormesh cells are centred at i + 0.5, so reuse the year labels at those positions.
    ax.set_xticks(np.arange(0.5, len(time)))
    ax.set_xticklabels(xtick_labels)
    ax.tick_params(axis="x", bottom=True, top=True, labelbottom=True, labeltop=False)

    cb_ax1 = fig.add_axes([0.91, 0.125, 0.017, 0.48])
    cb1 = fig.colorbar(p, cax=cb_ax1, orientation='vertical', ticks=np.arange(-2.5, 2.6, 0.5))
    cb1.ax.set_ylabel('Standardised anomaly [-]', rotation=270, va='bottom')

    ax.text(-.1, 0.95, 'b', weight='bold', transform=ax.transAxes)

    plt.savefig('./figures/events_modes_timeseries_detrended.pdf', format='pdf', dpi=400, bbox_inches='tight')

# Events and modes scatter plots

### Significance testing

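Obtain 10,000 block-bootstrap samples of each standardised mode index, and calculate the mean of each sample. If the observed mean (over the cold/wet or warm/dry exceedance years) lies outside the 5th–95th percentile range of the bootstrap means, deem it significant.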

In [None]:
def mean_over_time(X):
    """
    Reduce one sample to its mean and wrap it as a 0-d DataArray.

    Written as the reducer for ``xr.apply_ufunc(..., input_core_dims=[['time']],
    vectorize=True)``. In that use each call receives the values of a single
    resample along 'time'. NaNs propagate, as with ``np.mean``.
    """
    return xr.DataArray(np.mean(X))

In [None]:
# Standardised mode indices: each mode divided by its std over time (per season_id),
# then averaged over season_ids.
m = (modes_concat / modes_concat.std('time')).mean('season_id')

In [None]:
# Integer block length per mode from fn.estimate_L (presumably derived from each series'
# autocorrelation — TODO confirm), plus the distinct lengths to bootstrap with.
block_lengths = fn.estimate_L(m).astype('int')
# block_lengths = [int(i) for i in block_lengths]
unique_block_lengths = np.sort(np.unique(block_lengths))
unique_block_lengths

In [None]:
# For each distinct block length, draw 10,000 non-circular block-bootstrap resamples of
# every mode along 'time' and reduce each resample to its mean over time. Results are
# stacked along a new 'L' dimension so each mode can later use its own block length.
resamples_list = []
for L in unique_block_lengths:
    bootstraps = block_bootstrap(m, blocks={'time': L}, n_iteration=10000, exclude_dims=None, circular=False)

    resamples = xr.apply_ufunc(mean_over_time, bootstraps, input_core_dims=[['time']], output_core_dims=[[]],
                                dask='forbidden', vectorize=True)
    resamples_list.append(resamples.assign_coords({'L': L}))
resamples = xr.concat(resamples_list, dim='L')

In [None]:
# Observed statistic: mean standardised mode value over the cold/wet (CW) or warm/dry (WD)
# exceedance years only. avgs has dims ('category', 'mode').
cw_mode_avg = (modes_concat / modes_concat.std('time')).mean('season_id').where(cw_exceedances.notnull()).mean('time')
wd_mode_avg = (modes_concat / modes_concat.std('time')).mean('season_id').where(wd_exceedances.notnull()).mean('time')
avgs = xr.concat([cw_mode_avg.expand_dims({'category': ['CW']}),
                      wd_mode_avg.expand_dims({'category': ['WD']})],
                     dim='category')

In [None]:
# Position of each observed mean within its mode's bootstrap distribution (for that mode's
# block length); fn.get_quantile presumably returns the fraction of bootstrap means below
# the observed value — TODO confirm.
# NOTE(review): the bootstrap means are over whole resampled series (presumably all years),
# while each observed mean uses only the exceedance years. A mean over fewer years has a
# wider null spread, so this comparison may overstate significance; consider bootstrapping
# means over the same number of years as each observed average.
quantiles = np.full_like(avgs.values, np.nan)
for i, cat in enumerate(avgs.category):
    for j, mode in enumerate(avgs.mode):
        L = block_lengths.sel(mode=mode)
        pc = fn.get_quantile(avgs.sel(category=cat, mode=mode),
                             resamples.sel(L=L, mode=mode))
        quantiles[i,j] = pc
quantiles = xr.DataArray(quantiles,
                           coords=avgs.coords,
                           dims=avgs.dims)

In [None]:
y1 = np.array([i for i in range(len(quantiles.sel(category='CW'))) if (quantiles.sel(category='CW').isel(mode=i) < 0.05) \
               | (quantiles.sel(category='CW').isel(mode=i) > 0.95)]) + 0.5
x1 = np.repeat(0.5, len(y1))

In [None]:
y2 = np.array([i for i in range(len(quantiles.sel(category='WD'))) if (quantiles.sel(category='WD').isel(mode=i) < 0.05) \
               | (quantiles.sel(category='WD').isel(mode=i) > 0.95)]) + 0.5
x2 = np.repeat(1.5, len(y2))

In [None]:
# Figure: (a) heatmap of mean standardised mode values in cold/wet and warm/dry exceedance
# years, with dots on values outside the 0.05-0.95 bootstrap quantiles; (b)-(e) scatter
# plots of four (non-standardised, season_id-averaged) mode indices vs yearly event counts.
with plt.rc_context(plt_params):
    
    cmap = plt.cm.BrBG_r
    c_ = [0.25, 0.75]
    
    fig = plt.figure(figsize=(6.9, 3), dpi=150)
    gs = fig.add_gridspec(ncols=30, nrows=2)
    
    # =============================== Mode averages heatmap   
    ax = fig.add_subplot(gs[:,0:2])
    
    # Columns: Cold/wet, Warm/dry; rows: modes (first mode at the top after invert_yaxis).
    p = ax.pcolormesh(avgs.transpose('mode', 'category'), cmap='RdBu_r', vmin=-0.7, vmax=0.7)
    
    # Significance markers computed in the previous cells.
    ax.scatter(x1, y1, color='k', edgecolor='white', s=12, lw=.7)
    ax.scatter(x2, y2, color='k', edgecolor='white', s=12, lw=.7)
    
    ax.set_xticks([0.5, 1.5])
    ax.set_xticklabels(['Cold/wet', 'Warm/dry'], rotation=60)
    
    ax.set_yticks(np.arange(0.5, 13, 1))
    ax.set_yticklabels(mode_labels)
    ax.invert_yaxis()
    
    cb_ax1 = fig.add_axes([0.18, 0.13, 0.015, 0.75])
    cb1 = fig.colorbar(p, cax=cb_ax1, orientation='vertical', ticks=np.arange(-0.7, 0.71, 0.35))
    cb1.ax.set_ylabel('Standardised anomaly [-]', rotation=270, va='bottom')
    
    ax.text(-1.6, 0.95, 'a', weight='bold', transform=ax.transAxes)
    
    # =============================== Mode-events scatter
    
    # Helper: draws on whatever `ax` refers to at call time (it is reassigned per panel below).
    def scatter(x, y, color, label=None):
        ax.scatter(x, y, color=color, edgecolor='k', lw=0.07, s=20, alpha=0.8, label=label)
    
    # Nino34
    x = modes_concat.sel(mode='nino34').mean('season_id')
    ax = fig.add_subplot(gs[0,9:17])
    ax.axvline(0, c='k', lw=0.8, zorder=0)
    scatter(x, year_counts_neg, cmap(c_[0]), label='Cold or wet')
    scatter(x, year_counts_pos, cmap(c_[1]), label='Warm or dry')
    
    ax.set_xlim(-1.4, 1.4)
    ax.set_xticks(np.arange(-1.4, 1.5, 0.7))
    
    ax.tick_params(axis="y", left=False, right=True, labelleft=False, labelright=False)
    
    ax.text(0.04, 0.87, 'b', weight='bold', transform=ax.transAxes)
    ax.text(0.13, 0.87, mode_labels[0], transform=ax.transAxes)
    
    # MJO 1
    x = modes_concat.sel(mode='mjo_dpm_p1').mean('season_id')
    ax = fig.add_subplot(gs[0,19:27])
    ax.axvline(0, c='k', lw=0.8, zorder=0)
    scatter(x, year_counts_neg, cmap(c_[0]))
    scatter(x, year_counts_pos, cmap(c_[1]))
    
    ax.set_xlim(-5, 5)
    ax.set_xticks(np.arange(-5, 5.1, 2.5))
    
    # Right-hand column panels carry the y-axis labels on the right side.
    ax.tick_params(axis="y", left=False, right=True, labelleft=False, labelright=True)
    ax.set_ylabel('Hazards per year\n', rotation=270)
    ax.yaxis.labelpad = 10
    ax.yaxis.set_label_position("right")
    
    ax.text(0.04, 0.87, 'c', weight='bold', transform=ax.transAxes)
    ax.text(0.13, 0.87, mode_labels[5], transform=ax.transAxes)
    
    # TNA
    x = modes_concat.sel(mode='tna').mean('season_id')
    ax = fig.add_subplot(gs[1,9:17])
    ax.axvline(0, c='k', lw=0.8, zorder=0)
    scatter(x, year_counts_neg, cmap(c_[0]))
    scatter(x, year_counts_pos, cmap(c_[1]))
    
    ax.set_xlim(-0.7, 0.7)
    ax.set_xticks(np.arange(-0.7, 0.8, 0.35))
    ax.set_xlabel(r'SST anomaly [$^\circ$C]')
    
    ax.tick_params(axis="y", left=False, right=True, labelleft=False, labelright=False)
    
    ax.text(0.04, 0.87, 'd', weight='bold', transform=ax.transAxes)
    ax.text(0.13, 0.87, mode_labels[3], transform=ax.transAxes)
    
    # MJO 4
    x = modes_concat.sel(mode='mjo_dpm_p4').mean('season_id')
    ax = fig.add_subplot(gs[1,19:27])
    ax.axvline(0, c='k', lw=0.8, zorder=0)
    scatter(x, year_counts_neg, cmap(c_[0]), label='Cold or wet')
    scatter(x, year_counts_pos, cmap(c_[1]), label='Warm or dry')
    
    ax.set_xlim(-3.2, 3.2)
    ax.set_xticks(np.arange(-3.2, 3.3, 1.6))
    ax.set_xlabel(r'$\mathrm{MJO}_{i}$ [days per month]')
    
    ax.tick_params(axis="y", left=False, right=True, labelleft=False, labelright=True)
    ax.set_ylabel('Hazards per year', rotation=270)
    ax.yaxis.labelpad = 10
    ax.yaxis.set_label_position("right")
    
    ax.text(0.04, 0.87, 'e', weight='bold', transform=ax.transAxes)
    ax.text(0.13, 0.87, mode_labels[8], transform=ax.transAxes)
    
    # Shared legend, positioned below the scatter panels (loc in this panel's axes coords).
    ax.legend(loc=(-0.65, -0.55), frameon=False, ncol=2)
    
    plt.savefig('./figures/events_modes_average.pdf', format='pdf', dpi=400, bbox_inches='tight')

# Scatter plot indices and events

In [None]:
def scatter_events_modes():
    """
    Scatter plot of events versus modes.

    Draws one panel per mode in ``modes_concat`` on a 4x4 grid. The x-axis is the mode
    index averaged over season_id; the y-axis is the yearly cold/wet
    (``year_counts_neg``) and warm/dry (``year_counts_pos``) event counts. Grid cells
    beyond the number of modes are switched off instead of being left as empty axes.

    Relies on the notebook globals ``plt_params``, ``modes_concat``,
    ``year_counts_neg``, ``year_counts_pos`` and ``mode_labels``.
    """
    # Same colours as the other event figures: cold/wet and warm/dry ends of BrBG_r.
    # Defined locally instead of relying on globals left over from other cells.
    cmap = plt.cm.BrBG_r
    c_ = [0.25, 0.75]

    with plt.rc_context(plt_params):
        _, axes_grid = plt.subplots(4, 4, figsize=(6.9, 6.9), dpi=150)
        axes = axes_grid.flatten()

        for i, mode in enumerate(modes_concat.mode.values):
            x = modes_concat.sel(mode=mode).mean('season_id')
            panel = axes[i]

            panel.axvline(0, color='k')
            panel.scatter(x, year_counts_neg, color=cmap(c_[0]), s=10)
            panel.scatter(x, year_counts_pos, color=cmap(c_[1]), s=10)
            panel.text(0.05, 0.9, mode_labels[i], transform=panel.transAxes)

        # Hide grid cells that have no mode to show (16 cells, fewer modes).
        for unused in axes[modes_concat.sizes['mode']:]:
            unused.axis('off')

        plt.tight_layout()

In [None]:
# Draw the grid of mode-vs-event scatter plots.
scatter_events_modes()

# Correlation of modes and surface variables

In [None]:
# Display labels for the modes, in the same order as the 'mode' dimension of modes_concat
# (same list as defined earlier in the notebook).
mode_labels = ['ENSO', 'IOD', r'Atl. Ni$\mathrm{\tilde{n}}$o', 'TNA', 'TSA',
                   r'$\mathrm{MJO}_{1}$',
                   r'$\mathrm{MJO}_{2}$',
                   r'$\mathrm{MJO}_{3}$',
                   r'$\mathrm{MJO}_{4}$',
                   r'$\mathrm{MJO}_{5}$',
                   r'$\mathrm{MJO}_{6}$',
                   r'$\mathrm{MJO}_{7}$',
                   r'$\mathrm{MJO}_{8}$']

In [None]:
# Remove duplicate Colombia and Uganda
# (the extra season_ids 'CO_2' and 'UG_13'), so each country appears once. The deep copy
# leaves country_order itself unchanged.
country_subset = copy.deepcopy(country_order)
country_subset.pop('CO_2')
country_subset.pop('UG_13')
country_subset = list(country_subset.keys())

# season_ids of each species, kept in country_order order.
arabica_subset = [i for i in country_subset if i in arabica_season_ids]
robusta_subset = [i for i in country_subset if i in robusta_season_ids]

In [None]:
def sfc_mode_cor(sfc_ds, sfc_var, dataset, mode_ds, mode_var):
    """
    Correlation of surface variables with climate modes.

    For every season_id in ``sfc_ds``, compute the Spearman rank correlation over
    'time' between ``sfc_ds[sfc_var]`` (masked to that season's country) and the
    matching index ``mode_ds[mode_var]``. The per-season maps are then merged into
    one map.

    Parameters
    ----------
    sfc_ds : xarray.Dataset
        Surface data with 'season_id' and 'time' dimensions. season_ids look like
        '<abbrev>_<n>', where the part before '_' selects the country mask.
    sfc_var : str
        Name of the variable in ``sfc_ds`` to correlate.
    dataset : str
        Dataset/grid name passed to ``fn.get_combined_mask`` to obtain country masks
        (indexed by 'abbrevs').
    mode_ds : xarray.Dataset
        Mode index data with matching 'season_id' and 'time'.
    mode_var : str
        Name of the mode variable in ``mode_ds``.

    Returns
    -------
    xarray.DataArray
        Correlation map. Cells are NaN outside all country masks and wherever no
        season_id provides a valid correlation.

    Notes
    -----
    Per-season maps are added together, so the masks of the season_ids passed in
    should not overlap (e.g. one season per country). Overlapping masks would add
    their correlations.
    """
    mask = fn.get_combined_mask(dataset)

    da_list = []
    for s_id in sfc_ds.season_id.values:
        abbrev = s_id.split('_')[0]

        # Keep only grid cells inside this season's country.
        sfc_da = sfc_ds[sfc_var].sel(season_id=s_id)
        sfc_da = sfc_da.where(mask.sel(abbrevs=abbrev) == True, drop=False)

        mode_da = mode_ds[mode_var].sel(season_id=s_id)

        cor = xs.spearman_r(sfc_da, mode_da, dim='time')
        da_list.append(cor)

    # Merge the per-country maps. min_count=1 leaves a cell NaN when every season is NaN
    # there, rather than reporting a spurious 0 correlation. Such cells are outside the
    # country, in a country not in sfc_ds, or have an undefined correlation.
    cor_da = xr.concat(da_list, dim='season_id')
    cor_da = cor_da.sum('season_id', skipna=True, min_count=1)
    cor_da = cor_da.where(mask.sum('abbrevs'))

    return cor_da

In [None]:
def plot_cor(da_list, text_list, sup_titles, save_fig, filename,
             cbar_label='Spearman correlation [-]'):
    """
    Plot a grid of correlation maps with a shared colour bar.

    Panels are placed on a two-column grid in row-major order (left, right,
    next row, ...). Each panel is a PlateCarree map with coastlines, borders,
    grey land, red reference lines and a PiYG colour scale fixed to [-1, 1].

    Parameters
    ----------
    da_list : list of xarray.DataArray
        Correlation maps, one per panel.
    text_list : list of str
        Label drawn near the top of each panel. Needs at least
        ``len(da_list)`` entries.
    sup_titles : list of str
        Column titles drawn above the first two panels. Ignored unless it has
        exactly two entries.
    save_fig : bool
        If True, save the figure as a PDF to ``'./figures/' + filename``.
    filename : str
        Output file name inside ``./figures/``.
    cbar_label : str, optional
        Colour-bar label. Defaults to Spearman because ``sfc_mode_cor``
        correlates with ``xs.spearman_r``. Pass another label when plotting
        a different measure.

    Raises
    ------
    ValueError
        If ``da_list`` is empty.
    """
    n_panels = len(da_list)
    if n_panels == 0:
        # Without a panel there is no mappable to attach the colour bar to
        raise ValueError('da_list must contain at least one map')

    # Figure height and colour-bar thickness (a fraction of the figure
    # height) depend on how many rows the two-column grid needs
    if n_panels < 7:
        figsize, nrows, cbar_height = (6.9, 3.1), 3, 0.025
    elif n_panels < 9:
        figsize, nrows, cbar_height = (6.9, 4.1), 4, 0.02
    else:
        # Five rows hold ten panels. Grow the grid beyond that instead of
        # indexing past its end, keeping roughly one inch per row.
        nrows = max(5, (n_panels + 1) // 2)
        figsize = (6.9, nrows)
        cbar_height = 0.015 * (5 / nrows)

    n_brazil = fn.get_n_Brazil_boundary()

    with plt.rc_context(plt_params):
        lw = plt_params['lines.linewidth']

        fig = plt.figure(figsize=figsize, dpi=200)
        gs = fig.add_gridspec(nrows=nrows, ncols=2)

        for i, da in enumerate(da_list):
            ax = fig.add_subplot(gs[i], projection=ccrs.PlateCarree())
            ax.set_extent((-117, 142, 36, -35), crs=ccrs.PlateCarree())
            ax.coastlines(lw=lw / 3)

            # Red reference lines: the boundary from fn.get_n_Brazil_boundary()
            # plus three straight segments (26E south of 8N, 48E north of the
            # equator, and a diagonal joining them)
            n_brazil.boundary.plot(ax=ax, color='r', lw=lw / 4)
            ax.plot((26, 26), (8, -90), color='r', ls='-', lw=lw / 2)
            ax.plot((48, 48), (90, 0), color='r', ls='-', lw=lw / 2)
            # NOTE(review): the diagonal starts at 26.5E while the vertical
            # segment sits at 26E, leaving a small gap - confirm intended
            ax.plot((26.5, 48), (8, 0), color='r', ls='-', lw=lw / 2)

            ax.add_feature(cartopy.feature.BORDERS, lw=lw / 4)
            ax.add_feature(cartopy.feature.LAND, facecolor='lightgrey')
            ax.text(0.27, 0.87, text_list[i], ha='center',
                    transform=ax.transAxes,
                    fontsize=plt_params['font.size'])

            mesh = da.plot(ax=ax, cmap='PiYG', vmin=-1, vmax=1,
                           add_colorbar=False, rasterized=True)

            if i == 0:
                ax.text(0.35, 0.05, 'Arabica', transform=ax.transAxes)
                ax.text(0.65, 0.05, 'Robusta', transform=ax.transAxes)

            # Column titles go above the top row only
            if i < 2 and len(sup_titles) == 2:
                ax.text(0.5, 1.05, sup_titles[i], ha='center',
                        transform=ax.transAxes)

        # Every panel shares vmin/vmax, so the last mappable serves for all
        cb_ax1 = fig.add_axes([0.17, 0.08, 0.7, cbar_height])
        cb1 = fig.colorbar(mesh, cax=cb_ax1, orientation='horizontal',
                           ticks=np.arange(-1, 1.1, 0.2))
        cb1.ax.set_xlabel(cbar_label)

        plt.subplots_adjust(wspace=0.02, hspace=0.0)

        if save_fig:
            plt.savefig('./figures/' + filename, format='pdf', dpi=400,
                        bbox_inches='tight')

## Correlation with rainfall

In [None]:
# Shared analysis window and precipitation field for the ocean-mode maps
_period = slice('1980', '2020')
_precip = tp_lt_growing_optimal.sel(season_id=country_subset, time=_period)

precip_nino34_cor = sfc_mode_cor(_precip, 'precip_detrended', 'gpcc',
                                 nino34_growing.sel(time=_period),
                                 'nino34_detrended').compute()
precip_dmi_cor = sfc_mode_cor(_precip, 'precip_detrended', 'gpcc',
                              dmi_growing.sel(time=_period),
                              'dmi_detrended').compute()
precip_atl_nino_cor = sfc_mode_cor(_precip, 'precip_detrended', 'gpcc',
                                   atl_nino_growing.sel(time=_period),
                                   'atl_nino_detrended').compute()
precip_tna_cor = sfc_mode_cor(_precip, 'precip_detrended', 'gpcc',
                              tna_growing.sel(time=_period),
                              'tna_detrended').compute()
precip_tsa_cor = sfc_mode_cor(_precip, 'precip_detrended', 'gpcc',
                              tsa_growing.sel(time=_period),
                              'tsa_detrended').compute()

In [None]:
# Shared analysis window and precipitation field for the MJO maps
_period = slice('1980', '2020')
_precip = tp_lt_growing_optimal.sel(season_id=country_subset, time=_period)


def _precip_mjo_cor(phase, sfc=_precip, period=_period):
    """
    Correlate detrended precipitation (GPCC country masks) with the
    detrended MJO days-per-month index of ``phase`` and load the result.
    """
    # drop_vars (not the deprecated .drop) removes the scalar phase_ID
    # coordinate left behind by .sel
    mjo_phase = (mjo_days_per_month_growing
                 .sel(phase_ID=phase, time=period)
                 .drop_vars('phase_ID'))
    return sfc_mode_cor(sfc, 'precip_detrended', 'gpcc',
                        mjo_phase, 'mjo_days_per_month_detrended').compute()


precip_mjo_p1_dpm_cor = _precip_mjo_cor(1)
precip_mjo_p2_dpm_cor = _precip_mjo_cor(2)
precip_mjo_p3_dpm_cor = _precip_mjo_cor(3)
precip_mjo_p4_dpm_cor = _precip_mjo_cor(4)
precip_mjo_p5_dpm_cor = _precip_mjo_cor(5)
precip_mjo_p6_dpm_cor = _precip_mjo_cor(6)
precip_mjo_p7_dpm_cor = _precip_mjo_cor(7)
precip_mjo_p8_dpm_cor = _precip_mjo_cor(8)

In [None]:
# Five ocean-mode precipitation maps, saved to ./figures/precip_ocean_mode_cor.pdf
plot_cor(
    da_list=[precip_nino34_cor, precip_dmi_cor, precip_atl_nino_cor,
             precip_tna_cor, precip_tsa_cor],
    text_list=mode_labels[:5],
    sup_titles=[],
    save_fig=True,
    filename='precip_ocean_mode_cor.pdf',
)

In [None]:
# One precipitation map per MJO phase, saved to ./figures/precip_mjo_cor.pdf
plot_cor(
    da_list=[precip_mjo_p1_dpm_cor, precip_mjo_p2_dpm_cor,
             precip_mjo_p3_dpm_cor, precip_mjo_p4_dpm_cor,
             precip_mjo_p5_dpm_cor, precip_mjo_p6_dpm_cor,
             precip_mjo_p7_dpm_cor, precip_mjo_p8_dpm_cor],
    text_list=mode_labels[5:],
    sup_titles=[],
    save_fig=True,
    filename='precip_mjo_cor.pdf',
)

## Correlation with temperature

In [None]:
# Shared analysis window and temperature field for the ocean-mode maps
_period = slice('1980', '2020')
_t2m = t2m_lt_growing_optimal.sel(season_id=country_subset, time=_period)

t2m_nino34_cor = sfc_mode_cor(_t2m, 'temperature_detrended', 'berkeley',
                              nino34_growing.sel(time=_period),
                              'nino34_detrended').compute()
t2m_dmi_cor = sfc_mode_cor(_t2m, 'temperature_detrended', 'berkeley',
                           dmi_growing.sel(time=_period),
                           'dmi_detrended').compute()
t2m_atl_nino_cor = sfc_mode_cor(_t2m, 'temperature_detrended', 'berkeley',
                                atl_nino_growing.sel(time=_period),
                                'atl_nino_detrended').compute()
t2m_tna_cor = sfc_mode_cor(_t2m, 'temperature_detrended', 'berkeley',
                           tna_growing.sel(time=_period),
                           'tna_detrended').compute()
t2m_tsa_cor = sfc_mode_cor(_t2m, 'temperature_detrended', 'berkeley',
                           tsa_growing.sel(time=_period),
                           'tsa_detrended').compute()

In [None]:
# Shared analysis window and temperature field for the MJO maps
_period = slice('1980', '2020')
_t2m = t2m_lt_growing_optimal.sel(season_id=country_subset, time=_period)


def _t2m_mjo_cor(phase, sfc=_t2m, period=_period):
    """
    Correlate detrended temperature (Berkeley country masks) with the
    detrended MJO days-per-month index of ``phase`` and load the result.
    """
    # drop_vars (not the deprecated .drop) removes the scalar phase_ID
    # coordinate left behind by .sel
    mjo_phase = (mjo_days_per_month_growing
                 .sel(phase_ID=phase, time=period)
                 .drop_vars('phase_ID'))
    return sfc_mode_cor(sfc, 'temperature_detrended', 'berkeley',
                        mjo_phase, 'mjo_days_per_month_detrended').compute()


t2m_mjo_p1_dpm_cor = _t2m_mjo_cor(1)
t2m_mjo_p2_dpm_cor = _t2m_mjo_cor(2)
t2m_mjo_p3_dpm_cor = _t2m_mjo_cor(3)
t2m_mjo_p4_dpm_cor = _t2m_mjo_cor(4)
t2m_mjo_p5_dpm_cor = _t2m_mjo_cor(5)
t2m_mjo_p6_dpm_cor = _t2m_mjo_cor(6)
t2m_mjo_p7_dpm_cor = _t2m_mjo_cor(7)
t2m_mjo_p8_dpm_cor = _t2m_mjo_cor(8)

In [None]:
# Five ocean-mode temperature maps, saved to ./figures/tmp_ocean_mode_cor.pdf
plot_cor(
    da_list=[t2m_nino34_cor, t2m_dmi_cor, t2m_atl_nino_cor,
             t2m_tna_cor, t2m_tsa_cor],
    text_list=mode_labels[:5],
    sup_titles=[],
    save_fig=True,
    filename='tmp_ocean_mode_cor.pdf',
)

In [None]:
# One temperature map per MJO phase, saved to ./figures/tmp_mjo_cor.pdf
plot_cor(
    da_list=[t2m_mjo_p1_dpm_cor, t2m_mjo_p2_dpm_cor,
             t2m_mjo_p3_dpm_cor, t2m_mjo_p4_dpm_cor,
             t2m_mjo_p5_dpm_cor, t2m_mjo_p6_dpm_cor,
             t2m_mjo_p7_dpm_cor, t2m_mjo_p8_dpm_cor],
    text_list=mode_labels[5:],
    sup_titles=[],
    save_fig=True,
    filename='tmp_mjo_cor.pdf',
)

### Combine selected temperature and precipitation panels into one figure for the paper

In [None]:
# Pair each map with its panel label. Even positions fill the left
# (temperature) column and odd positions the right (precipitation) column.
_mix_panels = [
    (t2m_nino34_cor, r'$\mathrm{\bf{a}}$ ENSO'),
    (precip_nino34_cor, r'$\mathrm{\bf{b}}$ ENSO'),
    (t2m_dmi_cor, r'$\mathrm{\bf{c}}$ IOD'),
    (precip_dmi_cor, r'$\mathrm{\bf{d}}$ IOD'),
    (t2m_tna_cor, r'$\mathrm{\bf{e}}$ TNA'),
    (precip_tna_cor, r'$\mathrm{\bf{f}}$ TNA'),
    (t2m_mjo_p1_dpm_cor, r'$\mathrm{\bf{g}}\ \mathrm{MJO}_{1}$'),
    (precip_mjo_p1_dpm_cor, r'$\mathrm{\bf{h}}\ \mathrm{MJO}_{1}$'),
    (t2m_mjo_p4_dpm_cor, r'$\mathrm{\bf{i}}\ \mathrm{MJO}_{4}$'),
    (precip_mjo_p4_dpm_cor, r'$\mathrm{\bf{j}}\ \mathrm{MJO}_{4}$'),
]
plot_cor([panel for panel, _ in _mix_panels],
         [label for _, label in _mix_panels],
         sup_titles=['Temperature', 'Precipitation'],
         save_fig=True, filename='mix_cor.pdf')

# Close cluster

In [None]:
# Release Dask resources: the client first, then the PBS cluster behind it
for _dask_resource in (client, cluster):
    _dask_resource.close()