# Preconditioned drought and fire weather

In [None]:
from dask_jobqueue import SLURMCluster
from dask.distributed import Client

On Pearcey, each node has 20 cores with 6GB of memory per core. Jobs needing more than 20 cores require more than one worker; `processes` sets the number of worker processes per job.

In [None]:
# One worker process per SLURM job: 3 cores x 6GB per core = 18GB.
cluster = SLURMCluster(
    cores=3,
    memory='18GB',
    processes=1,
    walltime='01:00:00',
    job_extra=['--qos="express"'],  # submit to the express QoS
)

In [None]:
# Request a single SLURM job for the cluster defined above.
cluster.scale(jobs=1)

In [None]:
# Connect a dask client to the SLURM cluster; the bare expression below
# displays the client summary in the notebook.
client = Client(cluster)
client

In [None]:
# client.close()
# cluster.close()

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import copy
import geopandas
from shapely.geometry import mapping
import string

import matplotlib as mpl
from matplotlib import cm
import matplotlib.pyplot as plt
import matplotlib.colors as mpl_colors
import matplotlib.patheffects as PathEffects
from matplotlib.ticker import FormatStrFormatter
from matplotlib.colors import rgb2hex
import matplotlib.ticker as mticker

import cartopy
import cartopy.crs as ccrs
# Use a local copy of cartopy's data files for both the pre-existing
# (read-only) directory and the download directory.
_cartopy_data_dir = '/datasets/work/oa-dcfp/work/squ027/data/cartopy-data'
cartopy.config.update({'pre_existing_data_dir': _cartopy_data_dir,
                       'data_dir': _cartopy_data_dir})

In [None]:
import functions as fn

# Plotting parameters

In [None]:
# Base sizes shared by all figures
fontsize = 7
coastlines_lw = 0.5
linewidth = 1.1
# Width of the white stroke drawn behind highlighted lines (path effect)
patheffect_lw_add = 1.8 * linewidth

# rcParams overrides, applied per figure through rc_context
plt_params = {
    'lines.linewidth': linewidth,
    'hatch.linewidth': 0.5,
    'font.size': fontsize,
    'legend.fontsize': fontsize - 1,
    'legend.columnspacing': 0.7,
    'legend.labelspacing': 0.03,
    'legend.handlelength': 1.,
    'axes.linewidth': 0.5,
}

# Colours of matplotlib's current property cycle
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

# Panel letters 'a'..'z'
letters = [*string.ascii_lowercase]

# Custom colour map from the project's helper module (used for anomaly maps)
custom_cmap = fn.get_magma_waterlily_cmap()

# Load masks

### JRA land mask

In [None]:
# JRA land mask, loaded into memory straight away
jra_mask = (
    xr.open_zarr('/scratch1/ric368/projects/fire/data/masks/jra_land_mask.zarr',
                 consolidated=True)
    .land
    .compute()
)

### JRA area of grid cells

In [None]:
# Area of each JRA grid cell, loaded into memory
jra_grid_area = (
    xr.open_zarr('/scratch1/ric368/projects/fire/data/masks/jra_grid_area.zarr')
    .cell_area
    .compute()
)

### Burned area mask

In [None]:
# Burned-area mask (presumably on the native CCI grid, per the file name); kept lazy
fire_mask = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/masks/cci_burned_area.zarr',
    consolidated=True,
)['burned_cells']

### Burned area mask on JRA grid

In [None]:
# Burned-area mask regridded to the JRA grid; kept lazy
jra_ba_mask = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/masks/jra_burned_area_mask.zarr',
    consolidated=True,
)['burned_cells']

# Load data and derive other quantities

### FFDI

In [None]:
# Monthly FFDI on the JRA grid, 1958-01 to 2020-12 (lazy)
ffdi_monthly = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/derived/jra_monthly_ffdi_global_195801_202012.zarr',
    consolidated=True,
)['FFDI']

In [None]:
# Extreme-FFDI days per month for each AR6 region (lazy)
ar6_ffdi_m = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/derived/jra_ar6_extreme_ffdi_dpm_global_195801_202012.zarr',
    consolidated=True,
)['ffdi_dpm']

In [None]:
# Extreme FFDI days per year on the JRA grid (left lazy)
grid_ex_dpy = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/derived/jra_extreme_ffdi_dpy_global_1959_2020.zarr',
    consolidated=True,
)['ffdi_dpy']

In [None]:
# Standardise extreme FFDI days per year over 2001-2019:
# (x - mean) / std, both taken along time within that period
_dpy_2001_2019 = grid_ex_dpy.sel(time=slice('2001', '2019'))
grid_dpy_norm = (_dpy_2001_2019 - _dpy_2001_2019.mean('time')) / _dpy_2001_2019.std('time')

In [None]:
# Extreme FFDI days per year for each AR6 region, loaded into memory
region_ex_dpy = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/derived/jra_ar6_extreme_ffdi_dpy_global_1959_2020.zarr',
    consolidated=True,
)['ffdi_dpy'].compute()

### Precipitation

In [None]:
# Monthly GPCC precipitation for each AR6 region (lazy)
ar6_pr = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/derived/gpcc_ar6_precip_monthly_global_1950_2020.zarr',
    consolidated=True,
)['precip']

In [None]:
# GPCC precipitation accumulated over each fire year, on the JRA grid (lazy)
grid_ypr = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/derived/gpcc_fire_year_precip_jra_grid_global_1959_2020.zarr',
    consolidated=True,
)['precip']

In [None]:
# Standardise fire-year precipitation over 2001-2019:
# (x - mean) / std, both taken along time within that period
_ypr_2001_2019 = grid_ypr.sel(time=slice('2001', '2019'))
grid_ypr_norm = (_ypr_2001_2019 - _ypr_2001_2019.mean('time')) / _ypr_2001_2019.std('time')

In [None]:
# Fire-year precipitation for each AR6 region, loaded into memory
region_ypr = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/derived/gpcc_ar6_precip_fire_year_global_1957_2020.zarr',
    consolidated=True,
)['precip_fire_year'].compute()

### Burned area

In [None]:
# Monthly CCI burned area, 2001-01 to 2020-04 (lazy)
ba = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/derived/cci_burned_area_monthly_global_200101_202004.zarr',
    consolidated=True,
)['burned_area']

In [None]:
# Monthly CCI burned area for each AR6 region (lazy)
ar6_ba = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/derived/cci_ar6_burned_area_monthly_global_200101_202012.zarr',
    consolidated=True,
)['burned_area']

In [None]:
# CCI burned area accumulated over each fire year, on the JRA grid (lazy)
ba_jra_fw_season = xr.open_zarr(
    '/scratch1/ric368/projects/fire/data/derived/cci_fire_year_burned_area_jra_grid_global_2001_2020.zarr',
    consolidated=True,
)['burned_area']

In [None]:
# Standardise fire-year burned area over 2001-2019:
# (x - mean) / std, both taken along time within that period
_ba_2001_2019 = ba_jra_fw_season.sel(time=slice('2001', '2019'))
ba_jra_fw_norm = (_ba_2001_2019 - _ba_2001_2019.mean('time')) / _ba_2001_2019.std('time')

## Statistical significance of burned area anomalies
- We use a permutation test because permutation tests are better suited than the bootstrap to hypothesis testing with small samples
- https://stats.stackexchange.com/questions/20217/bootstrap-vs-permutation-hypothesis-testing

In [None]:
def da_mean(da):
    """Return the mean of *da* taken over its ``time`` dimension.

    Used as the resampled statistic passed to ``fn.n_random_resamples`` in
    the permutation test.
    """
    time_mean = da.mean(dim='time')
    return time_mean

In [None]:
%%time
# Permutation null distribution: 5000 repeats of the time-mean of a sample of
# standardised fire-year burned area drawn without replacement
# (samples={'time': (7, 1)} -- presumably 7 time steps in blocks of 1; see fn.n_random_resamples)
perm = fn.n_random_resamples(
    ba_jra_fw_norm,
    samples={'time': (7, 1)},
    n_repeats=5000,
    function=da_mean,
    with_dask=True,
    replace=False,
)
# fn.apply_percentile is used later with the sample dimension named 'bs_sample'
perm = perm.rename({'k': 'bs_sample'})
perm = perm.compute()

#### Calculate p-values

- Two-tailed test, since we test whether the observed burned-area anomaly lies above or below the resampled (permutation) distribution
- If the test statistic from the sample (i.e. the observation) is negative, the two-tailed p-value is twice the lower-tail p-value. Conversely, for a positive test statistic, the two-tailed p-value is twice the upper-tail p-value.

In [None]:
%%time
# Mean standardised burned-area anomaly over years with heightened fire
# weather (standardised extreme-FFDI days > 0.5), then its position within the
# permutation distribution via fn.apply_percentile.
# The 'quantile' coordinate (presumably inherited from grid_dpy_norm) is
# removed with drop_vars, the non-deprecated equivalent of DataArray.drop.
ba_fw = (ba_jra_fw_norm.where(grid_dpy_norm > 0.5)
         .mean('time')
         .drop_vars('quantile')
         .compute())
fw_pvals_perm = fn.apply_percentile(ba_fw, perm, absolute=False)

In [None]:
%%time
# Mean standardised burned-area anomaly over dry years (standardised
# fire-year precipitation < -0.5), compared against the permutation distribution
_dry_years = grid_ypr_norm < -0.5
ba_dry = ba_jra_fw_norm.where(_dry_years).mean('time').compute()
dry_pvals_perm = fn.apply_percentile(ba_dry, perm, absolute=False)

In [None]:
%%time
# Mean standardised burned-area anomaly over wet years (standardised
# fire-year precipitation > 0.5), compared against the permutation distribution
_wet_years = grid_ypr_norm > 0.5
ba_wet = ba_jra_fw_norm.where(_wet_years).mean('time').compute()
wet_pvals_perm = fn.apply_percentile(ba_wet, perm, absolute=False)

#### Find $\min \{ F(x), 1-F(x) \}$ and use it to obtain the two-tailed p-value, $p = 2 \min \{ F(x), 1-F(x) \}$
- e.g. https://stats.stackexchange.com/a/277391, https://stats.stackexchange.com/a/140517

In [None]:
# Two-tailed p-value from the one-sided percentile F(x): p = 2 * min{F, 1 - F}.
# np.minimum dispatches on xarray objects directly; xr.ufuncs is deprecated
# and has been removed from recent xarray releases.
fw_2tailp = 2 * np.minimum(fw_pvals_perm, 1 - fw_pvals_perm)
dry_2tailp = 2 * np.minimum(dry_pvals_perm, 1 - dry_pvals_perm)
wet_2tailp = 2 * np.minimum(wet_pvals_perm, 1 - wet_pvals_perm)

#### Apply FDR

In [None]:
# alpha_FDR = 0.2 follows the alpha_FDR = 2 * alpha_0 convention of Wilks (2016):
# https://doi.org/10.1175/BAMS-D-15-00267.1
alpha = 0.2

# False discovery rate control applied to each two-tailed p-value field
fw_fdr, dry_fdr, wet_fdr = (fn.fdr(pvals, alpha=alpha)
                            for pvals in (fw_2tailp, dry_2tailp, wet_2tailp))

### IPCC AR6 regions

In [None]:
# IPCC AR6 WGI reference regions (v4 shapefile)
ar6_regions = geopandas.read_file(
    "/scratch1/ric368/data/ar6_regions/IPCC-WGI-reference-regions-v4_shapefile/IPCC-WGI-reference-regions-v4.shp"
)
# Regions outlined/labelled on the maps
regions_subset = ['WNA', 'MED', 'ESB', 'SES', 'WSAF', 'SAU']

# Figure 1: Fire, FFDI climatology and month of maximum activity

In [None]:
def plot_climatology_and_schematic(first_last_year):
    """Plot Figure 1: burned-area/FFDI climatologies, peak months and an EAU schematic.

    Panels
    ------
    a  Average annual burned area (monthly ``ba`` summed per calendar year,
       then averaged over years), log colour scale, with a box marking EAU.
    b  Time-mean monthly FFDI over land (``jra_mask == 1``), log colour scale.
    c  Average month of maximum burned area where ``fire_mask == 1``.
    d  Average month of maximum FFDI where ``jra_mask == 1``.
    e  EAU monthly precipitation, fire weather days (FWD) and burned area for
       2018-2020; the windows used for the 2020 event are drawn bold and
       outlined with dashed boxes.

    Parameters
    ----------
    first_last_year : sequence of two str
        First and last year (inclusive) of the climatological period used for
        panels a-d, e.g. ``['2001', '2019']``. Both values also appear in the
        output file name.

    Notes
    -----
    Uses module-level data (``ba``, ``ffdi_monthly``, ``ar6_pr``,
    ``ar6_ffdi_m``, ``ar6_ba``, ``jra_mask``, ``fire_mask``) and writes
    ``./figures/climatology_and_schematic_<first>-<last>.pdf``.
    """
    first_year, last_year = first_last_year[0], first_last_year[1]
    period_slice = slice(first_year, last_year)

    # Corners (lon, lat) of the box drawn around EAU in panel a, listed
    # clockwise from the lower left. Hard-coded rather than read from the
    # AR6 polygon.
    eau_corners = [(145.5, -33), (145.5, -20), (155, -20), (155, -38)]

    # Frequently reused style values
    ba_colour = cm.get_cmap('magma')(0.3)  # burned-area colour in panel e
    base_lw = plt_params['lines.linewidth']
    thin_lw = base_lw - .5
    small_fontsize = plt_params['font.size'] - 1

    # ==== Climatologies over the requested period
    # Average annual burned area: total per calendar year, then mean over years
    avg_annual_ba = (ba.sel(time=period_slice)
                       .resample(time='1YS').sum()
                       .mean('time')
                       .compute())
    avg_annual_ffdi = ffdi_monthly.sel(time=period_slice).mean('time')

    # Average month of maximum burned area / FFDI. The month colour bar labels
    # value 0 as Dec and 1 as Jan, so month 12 is remapped to 0 when plotting.
    fire_max_month = fn.get_max_month(ba.sel(time=period_slice)).compute()
    ffdi_max_month = fn.get_max_month(ffdi_monthly.sel(time=period_slice)).compute()

    def halo():
        """Return fresh path effects that draw a white stroke under a line."""
        return [PathEffects.Stroke(linewidth=patheffect_lw_add, foreground='white'),
                PathEffects.Normal()]

    def label_panel(panel_ax, label):
        """Write *label* in the top-left corner of *panel_ax*."""
        panel_ax.text(0.01, 0.98, label, ha='left', va='top', transform=panel_ax.transAxes)

    with mpl.rc_context(plt_params):

        fig = plt.figure(figsize=(6.9, 4.2))
        gs = fig.add_gridspec(nrows=9, ncols=20, figure=fig,
                              height_ratios=[1, 1, 1, 1, 1, 1, .7, .7, .7])

        def add_map(spec):
            """Add a global PlateCarree map (90N-65S) with coastlines at *spec*."""
            map_ax = fig.add_subplot(spec, projection=ccrs.PlateCarree())
            map_ax.coastlines(lw=coastlines_lw)
            map_ax.set_extent([-179.99, 180, 90, -65])
            return map_ax

        # ==== a: average annual burned area (zeros hidden for the log scale)
        ax = add_map(gs[:3, :10])
        ba_positive = avg_annual_ba.where(avg_annual_ba > 0)
        p1 = ba_positive.sel(lat=slice(90, -65)).plot(ax=ax,
                                                      norm=mpl_colors.LogNorm(vmin=1e-2, vmax=1e3),
                                                      cmap='magma',
                                                      add_colorbar=False,
                                                      rasterized=True,
                                                      zorder=0)

        # EAU box: white outline under a thinner black one
        for (x0, y0), (x1, y1) in zip(eau_corners, eau_corners[1:] + eau_corners[:1]):
            ax.plot((x0, x1), (y0, y1), color='w', lw=base_lw - .25, zorder=1)
            ax.plot((x0, x1), (y0, y1), color='k', lw=base_lw - .75, zorder=2)

        ax.text(0.815, 0.035, 'EAU', transform=ax.transAxes, fontsize=small_fontsize)
        # Line from near the 'EAU' label up towards the box
        ax.plot([145, 155], [-55, -40], color='k', lw=thin_lw)
        label_panel(ax, r'$\bf{a}$')

        # ==== b: time-mean FFDI over land
        ax = add_map(gs[:3, 10:])
        p2 = avg_annual_ffdi.where(jra_mask == 1).plot(ax=ax,
                                                       norm=mpl_colors.LogNorm(vmin=1e-1, vmax=1e2),
                                                       cmap='magma',
                                                       add_colorbar=False,
                                                       rasterized=True)
        label_panel(ax, r'$\bf{b}$')

        # ==== Colour map shared by the max-month panels (c, d): four base
        # colours from the property cycle, each at three factors passed to
        # fn.adjust_lightness, so values 0-2 (Dec-Feb), 3-5, 6-8 and 9-11
        # each share a base colour.
        colours_sel = np.asarray(plt.rcParams['axes.prop_cycle'].by_key()['color'])[[7, 4, 0, 1]]
        l_colours = [fn.adjust_lightness(c, a) for c in colours_sel for a in [0.9, 1.2, 1.5]]
        month_cmap = mpl.colors.ListedColormap(l_colours)
        month_norm = mpl.colors.BoundaryNorm(np.arange(12.01), month_cmap.N)

        # ==== c: month of maximum burned area, within the burned-area mask
        ax = add_map(gs[3:6, :10])
        ba_max = (fire_max_month.where(fire_max_month < 12, 0)
                  * fire_mask.where(fire_mask == 1)).sel(lat=slice(90, -60))
        ba_max.plot(ax=ax, levels=range(12), cmap=month_cmap, norm=month_norm,
                    add_colorbar=False, rasterized=True)
        label_panel(ax, r'$\bf{c}$')

        # ==== d: month of maximum FFDI over land. Mask with where(), as in
        # panel b: multiplying by jra_mask would turn cells where the mask is
        # 0 into 0 (= December) instead of leaving them blank, and would
        # scale the month index wherever the mask is fractional.
        ax = add_map(gs[3:6, 10:])
        ffdi_max = ffdi_max_month.where(ffdi_max_month < 12, 0).where(jra_mask == 1)
        p3 = ffdi_max.plot(ax=ax, levels=range(12), cmap=month_cmap, norm=month_norm,
                           add_colorbar=False, rasterized=True)
        label_panel(ax, r'$\bf{d}$')

        # ==== e: schematic of the precipitation / FWD / burned-area windows
        ax = fig.add_subplot(gs[7:, :18])
        ax2 = ax.twinx()
        ax3 = ax.twinx()

        def plot_eau(target_ax, da, colour, context_period, event_period, **event_kwargs):
            """Plot EAU *da* faintly over *context_period* and bold over *event_period*."""
            da.sel(region='EAU', time=context_period).plot(ax=target_ax, color=colour, alpha=0.3,
                                                           lw=base_lw, zorder=7)
            da.sel(region='EAU', time=event_period).plot(ax=target_ax, color=colour, lw=base_lw,
                                                         path_effects=halo(), zorder=8,
                                                         **event_kwargs)

        plot_eau(ax, ar6_pr, colors[0], slice('2018', '2020'), slice('2018-11', '2019-10'),
                 label='Precipitation')
        plot_eau(ax2, ar6_ffdi_m, colors[1], slice('2018', '2020'), slice('2019-04', '2020-03'))
        # Offset the third y-axis so it does not overlap ax2's
        ax3.spines['right'].set_position(("axes", 1.13))
        plot_eau(ax3, ar6_ba, ba_colour, slice('2018', '2020-04'), slice('2019-04', '2020-03'))

        for twin in (ax, ax2, ax3):
            twin.set_xlim(pd.to_datetime('2017-12'), pd.to_datetime('2021-01'))

        # Dashed boxes around each event window, staggered vertically and
        # allowed to extend beyond the axes
        for start, end, ymin, ymax, colour in [('2018-11', '2019-10', 0.02, 1.08, colors[0]),
                                               ('2019-04', '2020-03', 0.08, 1.02, colors[1]),
                                               ('2019-03-27', '2020-03-06', 0.05, 1.05, ba_colour)]:
            ax3.axvspan(pd.to_datetime(start), pd.to_datetime(end), ymin, ymax,
                        clip_on=False, fc='None', ec=colour, ls='--', lw=thin_lw, zorder=10)

        # Annotations (in ax2 data coordinates)
        ax2.text(pd.to_datetime('2018-09-22'), 38, 'Precipitation for 2020 event',
                 ha='right', fontsize=small_fontsize, color=colors[0])
        ax2.text(pd.to_datetime('2020-04-15'), 38, 'FWD for 2020 event',
                 ha='left', fontsize=small_fontsize, color=colors[1])
        ax2.text(pd.to_datetime('2019-07'), 43, 'October: average peak month of FFDI in EAU',
                 ha='center', fontsize=small_fontsize, color='k')
        ax2.text(pd.to_datetime('2020-04-15'), 43, 'Burned area for 2020 event',
                 ha='left', fontsize=small_fontsize, color=ba_colour)

        # Short lines connecting the annotations to the plotted windows
        for (x0, x1), (y0, y1), colour in [(('2018-09-28', '2018-10-28'), (38, 33), colors[0]),
                                           (('2020-03-06', '2020-04-11'), (32, 38), colors[1]),
                                           (('2020-03-06', '2020-04-11'), (33.5, 39.5), ba_colour),
                                           (('2019-10', '2019-10'), (35, 41), 'k')]:
            ax2.plot([pd.to_datetime(x0), pd.to_datetime(x1)], [y0, y1],
                     color=colour, lw=thin_lw, clip_on=False)

        ax.set_ylim(-22, 180)
        ax2.set_ylim(-5, 30)
        ax3.set_ylim(-2500, 20000)

        ax.set_yticks(np.arange(0, 181, 60))
        ax2.set_yticks(np.arange(0, 31, 10))
        ax3.set_yticks(np.arange(0, 20100, 10000))
        ax3.ticklabel_format(axis="y", style="sci", scilimits=(0, 2))
        ax3.get_yaxis().get_offset_text().set_position((1.14, 0))

        ax.spines['top'].set_visible(False)
        ax2.spines['top'].set_visible(False)

        for twin in (ax, ax2, ax3):
            twin.set_title('')
            twin.set_xlabel('')

        ax.set_ylabel('Precipitation\n[mm]')
        ax2.set_ylabel('FWD [days]')
        ax3.set_ylabel('Burned area\n' + r'[km$^{2}$]')

        # Monthly ticks, labelled only in January (with the year) and July
        month_ticks = pd.date_range('2018-01', '2020-12', freq='MS')
        ax.set_xticks(month_ticks)
        ax.set_xticklabels([f'Jan\n{t.year}' if t.month == 1 else ('Jul' if t.month == 7 else '')
                            for t in month_ticks],
                           rotation=0, ha='center')
        label_panel(ax, r'$\bf{e}$')

        # ==== Colour bars
        cb = fig.colorbar(p1, cax=fig.add_axes([0.108, 0.61, 0.013, 0.26]), orientation='vertical',
                          ticks=[1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4, 1e5])
        cb.ax.yaxis.set_ticks_position('left')
        cb.ax.set_ylabel(r'Burned area [km$^{2}$]')
        cb.ax.yaxis.set_label_position('left')

        cb = fig.colorbar(p2, cax=fig.add_axes([0.905, 0.61, 0.013, 0.26]), orientation='vertical',
                          ticks=[1e-1, 1e0, 1e1, 1e2])
        cb.ax.set_ylabel('FFDI')

        cb = fig.colorbar(p3, cax=fig.add_axes([0.905, 0.33, 0.013, 0.26]), orientation='vertical',
                          ticks=np.arange(0.5, 12.5, 1))
        cb.ax.set_yticklabels(['Dec', '', 'Feb', '', 'Apr', '', 'Jun', '', 'Aug', '', 'Oct', ''])
        cb.ax.set_ylabel('Month')

    fig.subplots_adjust(hspace=0.001, wspace=0.1)

    fig.savefig(f'./figures/climatology_and_schematic_{first_year}-{last_year}.pdf',
                bbox_inches='tight', format='pdf', dpi=400)

In [None]:
%%time
plot_climatology_and_schematic(first_last_year=['2001', '2019'])  # 2001-2019 climatology

# Figure 2 and S1: Burned area composited on anomalous FFDI/drought/joint years

In [None]:
def plot_burned_area_anoms(plot_significance):
    """Plot Figure 2 (or S1): burned-area anomalies composited on anomalous years.

    Panels a-c show the mean standardised fire-year burned area
    (``ba_jra_fw_norm``) over the years in which a condition holds, using a
    threshold of ``n_std`` = 0.5 standard deviations:

    a  dry years, ``grid_ypr_norm < -n_std``
    b  heightened fire weather, ``grid_dpy_norm > n_std``
    c  wet years, ``grid_ypr_norm > n_std``
    d  (only without significance) whether adding a dry or a wet condition to
       heightened fire weather raises the mean anomaly relative to panel b,
       inside the burned-area mask ``jra_ba_mask``.

    Parameters
    ----------
    plot_significance : bool
        If true, panels a-c keep only cells where the FDR-corrected
        permutation test is significant (``dry_fdr``/``fw_fdr``/``wet_fdr``
        equal to 1), panel d is left empty, and the file name gets the suffix
        ``_with_significance``.

    Writes ``./figures/burned_area_anomalies_on_ffdi_drought_years[_with_significance].pdf``.
    """
    figsize = (6.6, 2.9)
    cmap = custom_cmap
    lim = 2.4     # symmetric colour limits of the anomaly panels
    n_std = 0.5   # composite threshold in standard deviations
    anom_ticks = np.arange(-2.4, 2.5, 0.8)
    # Raw string: avoids invalid-escape warnings for \m and \s
    anom_label = r"$\mathrm{BA}_f ' / \sigma_{\mathrm{BA}_f}$ [-]"
    show_wet_dry = not plot_significance

    # Panel d categories, ind = 2 * (dry raises BA) + (wet raises BA):
    # 0 neither, 1 wet only, 2 dry only, 3 both. Only 1 and 2 are coloured.
    colours = np.asarray(['#FFFFFF', rgb2hex(cmap(0.3)), rgb2hex(cmap(0.75)), '#FFFFFF'])
    ind_cmap = mpl.colors.ListedColormap(colours)
    ind_norm = mpl.colors.BoundaryNorm(np.arange(4.01), ind_cmap.N)
    # Two-colour map used only for the Wet/Dry colour bar
    legend_cmap = mpl.colors.ListedColormap(colours[[1, 2]])

    with plt.rc_context(plt_params):

        fig, ax = plt.subplots(2, 2, figsize=figsize, subplot_kw={'projection': ccrs.PlateCarree()})

        d1 = ba_jra_fw_norm.where(grid_ypr_norm < -n_std).mean('time')  # Dry
        d2 = ba_jra_fw_norm.where(grid_dpy_norm > n_std).mean('time')   # Heightened fire weather
        d3 = ba_jra_fw_norm.where(grid_ypr_norm > n_std).mean('time')   # Wet

        if plot_significance:
            d1 = d1.where(dry_fdr == 1)
            d2 = d2.where(fw_fdr == 1)
            d3 = d3.where(wet_fdr == 1)

        anom_kwargs = dict(cmap=cmap, zorder=1, vmin=-lim, vmax=lim,
                           add_colorbar=False, rasterized=True)
        p1 = d1.plot(ax=ax[0, 0], **anom_kwargs)
        p2 = d2.plot(ax=ax[0, 1], **anom_kwargs)
        p3 = d3.plot(ax=ax[1, 0], **anom_kwargs)

        # === Wet vs dry (d2 is unmasked here because significance is off)
        if show_wet_dry:
            dry_fw_minus_fw = ba_jra_fw_norm.where((grid_dpy_norm > n_std) & (grid_ypr_norm < -n_std)).mean('time') - d2
            wet_fw_minus_fw = ba_jra_fw_norm.where((grid_dpy_norm > n_std) & (grid_ypr_norm > n_std)).mean('time') - d2
            ind = (dry_fw_minus_fw > 0).astype('int') * 2 + (wet_fw_minus_fw > 0).astype('int')
            # ind_norm pins category k to colours[k] regardless of which
            # categories happen to be present in the data
            ind.where(jra_ba_mask == 1).plot(ax=ax[1, 1], cmap=ind_cmap, norm=ind_norm, zorder=1,
                                             add_colorbar=False, rasterized=True)

        long_labels = [r'Dry: $P_a < \mu_{P_a} - \sigma_{P_a} / 2$',
                       r'$\mathrm{FWD}_f > \mu_{\mathrm{FWD}_f} + \sigma_{\mathrm{FWD}_f} / 2$',
                       r'Wet: $P_a > \mu_{P_a} + \sigma_{P_a} / 2$']
        panels = list(zip(ax.flatten()[:3], [r'$\bf{a}$', r'$\bf{b}$', r'$\bf{c}$'], long_labels))
        if show_wet_dry:
            # The prime marks an anomaly, matching the colour-bar labels
            panels.append((ax[1, 1], r'$\bf{d}$', r"Increased $\mathrm{BA}_f '$"))

        props = dict(facecolor='w', edgecolor='w', pad=.1)
        for a, label, long_label in panels:
            a.coastlines(zorder=3, lw=coastlines_lw)
            a.add_feature(cartopy.feature.LAND, color='gainsboro', zorder=0)
            a.set_extent([-179.99, 180, 90, -65])
            a.set_title('')
            # Each panel's own transAxes, so every letter lands on its panel
            a.text(0.01, 0.98, label, ha='left', va='top', transform=a.transAxes)
            a.text(0.6, 0.02, long_label, ha='center', va='bottom', bbox=props,
                   transform=a.transAxes, zorder=10)

        def add_anom_colorbar(mappable, rect, on_left, **kwargs):
            """Add a vertical anomaly colour bar in a new axes at *rect* (figure coords)."""
            cb = fig.colorbar(mappable, cax=fig.add_axes(rect), orientation='vertical',
                              ticks=anom_ticks, **kwargs)
            cb.ax.set_ylabel(anom_label)
            if on_left:
                cb.ax.yaxis.set_ticks_position('left')
                cb.ax.yaxis.set_label_position('left')

        add_anom_colorbar(p1, [0.11, 0.53, 0.011, 0.33], on_left=True, extend='neither')
        add_anom_colorbar(p2, [0.905, 0.53, 0.011, 0.33], on_left=False)
        add_anom_colorbar(p3, [0.11, 0.15, 0.011, 0.33], on_left=True)

        if show_wet_dry:
            # Stand-alone mappable for the Wet/Dry colour bar, so nothing extra
            # (and unmasked) is drawn onto any map axes
            wet_dry_mappable = cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=1, vmax=2),
                                                 cmap=legend_cmap)
            wet_dry_mappable.set_array([])  # required by older matplotlib
            cb = fig.colorbar(wet_dry_mappable, cax=fig.add_axes([0.905, 0.15, 0.011, 0.33]),
                              orientation='vertical', ticks=[1.25, 1.75])
            cb.set_ticklabels(['Wet', 'Dry'])
        else:
            ax[1, 1].axis('off')

        fig.subplots_adjust(hspace=0.02, wspace=0.00)

        suffix = '_with_significance' if plot_significance else ''
        fig.savefig(f'./figures/burned_area_anomalies_on_ffdi_drought_years{suffix}.pdf',
                    bbox_inches='tight', format='pdf', dpi=400)

In [None]:
plot_burned_area_anoms(plot_significance=False)  # Figure 2: all cells, with wet/dry panel

In [None]:
plot_burned_area_anoms(plot_significance=True)  # Figure S1: FDR-significant cells only

# Figure 3: preconditioned compound events on grid and time series

In [None]:
mask = jra_ba_mask
bad_col = 'gainsboro'
plt.rcParams["xtick.major.size"] = 2 # temporarily change tick length

figsize = (7,6)

with plt.rc_context(plt_params):

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, figure=fig,height_ratios=[2.5, 1, 1, 0.06, 0.3, 0.7])
    
    # ============================= Maps
    # === Prep data
    pr = grid_ypr.sel(time=slice('1958', '2020')).sel(lat=slice(90, -65))

    ffdi_thresh = grid_ex_dpy.mean(dim='time') + grid_ex_dpy.std(dim='time')
    pr_thresh = pr.mean(dim='time') - pr.std(dim='time')

    # Compound (preconditioned) events
    cc_da = xr.where((pr < pr_thresh) &
                     (grid_ex_dpy > ffdi_thresh),
                     1,
                     0)

    # Number of compound events
    cc_sum = cc_da.sel(time=slice('1970', '2020')).sum('time').rename('n events')
    cc_max = cc_sum.max().values

    # Difference of most recent 25 years and preceding 25 years
    p_early = slice('1971', '1995')
    p_late = slice('1996', '2020')
    cc_diff = (cc_da.sel(time=p_late).sum('time') - cc_da.sel(time=p_early).sum('time')).rename('Difference in n events')
    cc_mag_max = np.max([np.abs(cc_diff.min().values), np.abs(cc_diff.max().values)])
    
    # ========== Plot
    # ===== n compound events
    current_cmap = copy.copy(mpl.cm.get_cmap('magma_r'))
    current_cmap.set_under(bad_col)
    
    ax = fig.add_subplot(gs[0, :3], projection=ccrs.PlateCarree())
    ax.set_extent([-179.99, 180, 90, -60])
    ax.coastlines(lw=coastlines_lw, zorder=2)
    ax.add_feature(cartopy.feature.OCEAN, color='w', zorder=1)
    
    cc_p = cc_sum.where(mask == 1)
    cc_p = xr.where(cc_p.isnull(), -999, cc_p) # Set NaNs to -999 to help with greying out
    cc_p = cc_p.plot(ax=ax, levels=range(9), extend='neither', add_colorbar=False, cmap=current_cmap, zorder=0, rasterized=True)
    
    ar6_regions.set_index(ar6_regions['Acronym']).loc[regions_subset].boundary.plot(ax=ax, color='black', lw=plt_params['lines.linewidth']-.75)
    
    cb_ax = fig.add_axes([0.165, 0.61, 0.3, 0.01])
    cb = fig.colorbar(cc_p, cax=cb_ax, orientation='horizontal', ticks=np.arange(0.5, 9.6))
    cb.ax.set_xticklabels(range(8))
    cb.ax.set_xlabel('Number of preconditioned events [years]')
    
    ax.set_title('')
    
    # Region labels
    label_locs = [[0.028, 0.65],
                  [0.37, 0.62],
                  [0.91, 0.65],
                  [0.22, 0.15],
                  [0.47, 0.05],
                  [0.78, 0.04]]
    for r, loc in zip(regions_subset, label_locs):
        ax.text(loc[0], loc[1], r,
                transform=ax.transAxes)
        
    ax.text(0.01, 0.98, r'$\bf{a}$',
            ha='left', va='top', transform=ax.transAxes)

    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False,
                      linewidth=.5, color='k', alpha=0.5, linestyle='--')
    gl.xlocator = mticker.FixedLocator([])
    gl.ylocator = mticker.FixedLocator([23.44, -23.44])

    # Draw arrows on first map indicating zones
    ax.arrow(-.02, 0.78,
             0, 0.2,
             length_includes_head=True,
             width=0.001,
             head_width=0.01,
             color=colors[9],
             transform=ax.transAxes,
             clip_on=False)
    ax.arrow(-.02, 0.78,
             0, -0.22,
             length_includes_head=True,
             width=0.001,
             head_width=0.01,
             color=colors[9],
             transform=ax.transAxes,
             clip_on=False)
    ax.text(-0.05, 0.78,
            'N. Hem',
            rotation=90,
            transform=ax.transAxes,
            horizontalalignment='center',
            verticalalignment='center',
            color=colors[9],
            fontsize=plt_params['font.size']-1,
            path_effects=[PathEffects.withStroke(linewidth=2, foreground="w")],
            zorder=5,
            clip_on=False)
    
    ax.arrow(-.02, 0.4,
             0, 0.13,
             length_includes_head=True,
             width=0.001,
             head_width=0.01,
             color=colors[6],
             transform=ax.transAxes,
             clip_on=False)
    ax.arrow(-.02, 0.4,
             0, -0.15,
             length_includes_head=True,
             width=0.001,
             head_width=0.01,
             color=colors[6],
             transform=ax.transAxes,
             clip_on=False)
    ax.text(-0.05, 0.4,
            'Tropics',
            rotation=90,
            transform=ax.transAxes,
            horizontalalignment='center',
            verticalalignment='center',
            color=colors[6],
            fontsize=plt_params['font.size']-1,
            path_effects=[PathEffects.withStroke(linewidth=2, foreground="w")],
            zorder=5,
            clip_on=False)
    
    ax.arrow(-.02, 0.11,
             0, 0.12,
             length_includes_head=True,
             width=0.001,
             head_width=0.01,
             color=colors[2],
             transform=ax.transAxes,
             clip_on=False)
    ax.arrow(-.02, 0.11,
             0, -0.12,
             length_includes_head=True,
             width=0.001,
             head_width=0.01,
             color=colors[2],
             transform=ax.transAxes,
             clip_on=False)
    ax.text(-0.05, 0.085,
            'S. Hem.',
            rotation=90,
            transform=ax.transAxes,
            horizontalalignment='center',
            verticalalignment='center',
            color=colors[2],
            fontsize=plt_params['font.size']-1,
            path_effects=[PathEffects.withStroke(linewidth=2, foreground="w")],
            zorder=5,
            clip_on=False)

    
    
    # ===== n compound events change
    current_cmap = copy.copy(custom_cmap)
    current_cmap.set_under(bad_col)
        
    ax = fig.add_subplot(gs[0, 3:], projection=ccrs.PlateCarree())
    ax.set_extent([-179.99, 180, 90, -60])
    ax.coastlines(lw=coastlines_lw, zorder=2)
    ax.add_feature(cartopy.feature.OCEAN, color='w', zorder=1)
    
    diff_p = cc_diff.where(mask == 1)
    diff_p = xr.where(diff_p.isnull(), -999, diff_p)
    diff_p = diff_p.plot(ax=ax, levels=range(-7, 7+1, 2), extend='neither', add_colorbar=False, cmap=current_cmap, zorder=0, rasterized=True)

    ar6_regions.set_index(ar6_regions['Acronym']).loc[regions_subset].boundary.plot(ax=ax, color='black', lw=plt_params['lines.linewidth']-.75)
    
    cb_ax = fig.add_axes([0.565, 0.61, 0.3, 0.01])
    cb = fig.colorbar(diff_p, cax=cb_ax, orientation='horizontal')
    cb.ax.set_xlabel('Difference in number of preconditioned events [years]')
    
    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False,
                      linewidth=.5, color='k', alpha=0.5, linestyle='--')
    gl.xlocator = mticker.FixedLocator([])
    gl.ylocator = mticker.FixedLocator([23.44, -23.44])
    gl.xlabel_style = {'size': plt_params['font.size']}
    gl.ylabel_style = {'size': plt_params['font.size']}
    ax.text(-179, 23.5, r'$23.44^{\circ}$N', fontsize=plt_params['font.size']-1, va='bottom')
    ax.text(-179, -26, r'$23.44^{\circ}$S', fontsize=plt_params['font.size']-1, va='top')
    ax.set_title('')
    
    ax.text(0.01, 0.98, r'$\bf{b}$',
            ha='left', va='top', transform=ax.transAxes)
    
    
    # ============================= Region time series
    # === Data prep
    pr = region_ypr.sel(time=slice('1958', '2020'))
    
    ffdi_thresh = region_ex_dpy.mean(dim='time') + region_ex_dpy.std(dim='time')
    pr_thresh = pr.mean(dim='time') - pr.std(dim='time')

    # Compound (preconditioned) events
    cc_da = xr.where((pr < pr_thresh) &
                     (region_ex_dpy > ffdi_thresh),
                     1,
                     0)
        
    # Selected regions plots
    panel_labels = [r'$\bf{c}$', r'$\bf{d}$', r'$\bf{e}$', r'$\bf{f}$', r'$\bf{g}$', r'$\bf{h}$']
    for i, r in enumerate(regions_subset):

        # === Data setup - normalise and set equivalent thresholds
        d_pr = pr.sel(region=r)
        d_pr = -fn.normalise(d_pr) + 1 + 0.1 # invert and + 1 to make dry=high
        d_pr_t = d_pr.mean() + d_pr.std()

        d_ffdi = region_ex_dpy.sel(region=r)
        d_ffdi = fn.normalise(d_ffdi) + 0.1
        d_ffdi_t = d_ffdi.mean() + d_ffdi.std()

        # Shift threshholds to be equivalent
        t_diff = d_pr_t - d_ffdi_t
        d_ffdi = d_ffdi + t_diff

        # === Plot
        row_loc = (i >= 3) + 1
        col_loc = (i % 3) * 2
        ax = fig.add_subplot(gs[row_loc, col_loc:col_loc+2])
        
        # Precip
        d_pr.plot(ax=ax, color=colors[0], zorder=3)

        # Threshold
        ax.axhline(d_pr_t, ls='-', color='black', zorder=2, lw=.6)

        # FFDI
        d_ffdi.plot(ax=ax, color=colors[1], zorder=4)
        
        # Compound events
        cc = cc_da.sel(region=r)
        ylim = ax.get_ylim()
        cc = xr.where(cc == 1, ylim[1]+ylim[1]*0.05, 0)
        ax.bar(pd.to_datetime(cc.time.values), cc, width=200, color='gray', zorder=1, alpha=0.5)

        # Extras
        ax.set_ylim(0, ylim[1]+ylim[1]*0.05)
        ax.set_yticks([d_pr_t])
        ax.tick_params(length=2)

        fy = pd.to_datetime(d_ffdi.time.values[0]) + pd.DateOffset(years=-2)
        ly = pd.to_datetime(d_ffdi.time.values[-1]) + pd.DateOffset(years=1)
        ax.set_xlim(fy, ly)
        ax.set_xticks(pd.date_range('1960', '2020', freq='10YS'))
        ax.set_xticklabels(pd.date_range('1960', '2020', freq='10YS').year)
        ax.text(0.01, 0.9, panel_labels[i],
                va='center', transform=ax.transAxes,
                path_effects=[PathEffects.Stroke(linewidth=patheffect_lw_add, foreground='white'), PathEffects.Normal()])

        ax.set_title('')

        # axis labels
        if i % 3 == 0:
            ax.set_yticklabels([r'$\mu \pm \sigma$'])
        else:
            ax.set_yticklabels([''])
        if i <= (3 * 3) - 3:
            ax.set_xlabel('')
            ax.set_xticklabels('')

        ax.text(0.011, 0.05, r, transform=ax.transAxes,
                path_effects=[PathEffects.withStroke(linewidth=patheffect_lw_add, foreground="w")],
                zorder=5,
               fontsize=plt_params['font.size']-1)
        
        # Draw arrows on first time series plot
        if i == 0:
            ax.arrow(-.03, 0.6,
                     0, 0.4,
                     length_includes_head=True,
                     width=0.001,
                     head_width=0.015,
                     color='k',
                     transform=ax.transAxes,
                     clip_on=False)
            ax.text(-0.07, 1.01,
                    r'Lower $P_a^r$',
                    transform=ax.transAxes,
                    horizontalalignment='right',
                    verticalalignment='top',
                    color=colors[0],
                    fontsize=plt_params['font.size'],
                    path_effects=[PathEffects.withStroke(linewidth=2, foreground="w")],
                    zorder=5,
                    clip_on=False)
            ax.text(-0.07, 0.82,
                    r'Higher $\mathrm{FWD}_f^r$',
                    transform=ax.transAxes,
                    horizontalalignment='right',
                    verticalalignment='top',
                    color=colors[1],
                    fontsize=plt_params['font.size'],
                    path_effects=[PathEffects.withStroke(linewidth=2, foreground="w")],
                    zorder=5,
                    clip_on=False)
        
    # ============================= All grid cells summary plots
    ax1 = fig.add_subplot(gs[4:, 0:2])
    ax2 = fig.add_subplot(gs[4:, 2:4])
    # Split bottom right into two
    ax3_top = fig.add_subplot(gs[4, 4:6])
    ax3 = fig.add_subplot(gs[5, 4:6])
    
    r_colors = [colors[6], colors[2], colors[9], 'black']
    r_lw = [plt_params['lines.linewidth'], plt_params['lines.linewidth'], plt_params['lines.linewidth'], plt_params['lines.linewidth']+3]
    r_pe_foreground = ['black', 'black', 'black', 'white']
    r_labels = ['NH', 'SH', 'Tr', 'All']
    
    pr = grid_ypr.sel(time=slice('1959', '2020'))

    ffdi_thresh = grid_ex_dpy.mean(dim='time') + grid_ex_dpy.std(dim='time')
    pr_thresh = pr.mean(dim='time') - pr.std(dim='time')

    # Compound (preconditioned) events
    cc_da = xr.where((pr < pr_thresh) &
                     (grid_ex_dpy > ffdi_thresh),
                     1,
                     0)

    nh_lim = 23.44
    sh_lim = -23.44

    ### NH, SH, Tr and global summaries ------------
    r_colors = [colors[9], colors[2], colors[6], 'black']
    r_lw = [plt_params['lines.linewidth'], plt_params['lines.linewidth'], plt_params['lines.linewidth'], plt_params['lines.linewidth']+.5]
    r_pe_foreground = ['white', 'white', 'white', 'white']
    r_labels = ['N. Hem.', 'S. Hem.', 'Tropics', 'All']
    
    mask = jra_ba_mask #dry_fw_mask

    for ir, r_slice in enumerate([slice(90, nh_lim), slice(sh_lim, -60), slice(nh_lim, sh_lim), slice(90, -60)]):
        
        # Total area of hemisphere
        r_tot = jra_grid_area.where(mask == 1).sel(lat=r_slice).sum()


        # Proportion of regions in drought
        d = (xr.where(pr.where(mask == 1).sel(lat=r_slice) < pr_thresh, True, False) * jra_grid_area.sel(lat=r_slice)).sum(['lat', 'lon']) / r_tot
        d.plot(ax=ax1,
               lw=r_lw[ir],
               color=r_colors[ir],
               label=r_labels[ir])

        # Proportion of regions with extreme FFDI
        if (r_labels[ir] == 'Tropics') | (r_labels[ir] == 'All'):
            d = grid_ex_dpy.sel(time=slice('1970', '2021'))
        else:
            d = grid_ex_dpy
            
        d = (xr.where(d.where(mask == 1).sel(lat=r_slice) > ffdi_thresh, True, False) * jra_grid_area.sel(lat=r_slice)).sum(['lat', 'lon']) / r_tot
        if r_labels[ir] == 'Tropics':
            d = d.sel(time=slice('1970', '2021'))
        d.plot(ax=ax2,
               lw=r_lw[ir],
               color=r_colors[ir],
               label=r_labels[ir])

        # Proportion of regions with compound extreme years
        if (r_labels[ir] == 'Tropics') | (r_labels[ir] == 'All'):
            d = cc_da.sel(time=slice('1970', '2021'))
        else:
            d = cc_da
        d = (xr.where(d.where(mask == 1).sel(lat=r_slice) == 1, True, False) * jra_grid_area.sel(lat=r_slice)).sum(['lat', 'lon']) / r_tot
        if r_labels[ir] == 'Tropics':
            d = d.sel(time=slice('1970', '2021'))
        d.plot(ax=ax3,
               lw=r_lw[ir],
               color=r_colors[ir],
               label=r_labels[ir])
        # For SH compound, we want to draw on the upper subplot as the data go up much higher
        #  We use a split y-axis
        d.plot(ax=ax3_top,
               lw=r_lw[ir],
               color=r_colors[ir])
        if r_labels[ir] == 'S. Hem.':
            axis_max = d.max().values

    # Axes extras
    xticks = pd.date_range('1960', '2020', freq='10YS')
    xticklabels = ['']
    for y in xticks[1::2]:
        xticklabels.append(str(y.year))
        xticklabels.append('')
    #xticklabels = xticklabels[:-1]
    for a in [ax1, ax2, ax3, ax3_top]:
        a.set_xlim(fy, ly)
        a.set_xlabel('')
        a.set_xticks(xticks)
        a.tick_params(length=2)
        a.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    for a in [ax1, ax2, ax3]:
        a.set_xticklabels(xticklabels, rotation=0, ha='center')
        
    for a, t in zip([ax1, ax2], [r'$\bf{i}$ Drought', r'$\bf{j}$ Fire weather']):
        a.text(0.02, 0.96, t,
               horizontalalignment='left',
               verticalalignment='top',
               transform=a.transAxes,
               fontsize=plt_params['font.size'])
        a.set_title('')
    ax3_top.text(0.02, 0.85, r'$\bf{k}$ Preconditioned',
               horizontalalignment='left',
               verticalalignment='top',
               transform=ax3_top.transAxes,
               fontsize=plt_params['font.size'])
    ax3_top.set_title('')
        
    ax1.set_ylim(-0.02,0.72)
    ax1.set_yticks([0, 0.35, 0.7])
    ax1.set_ylabel('Proportion [-]')
    ax1.legend(ncol=2, loc='upper right', bbox_to_anchor=(1., 1.05), frameon=False, fontsize=plt_params['font.size']-1)

    ax2.set_ylim(-0.02,0.72)
    ax2.set_yticks([0, 0.35, 0.7])
    ax2.set_yticklabels('')
    
    ax3.set_ylim(-0.02,0.18)
    ax3.set_yticks([0, 0.08, 0.16])
    ax3.set_yticks
    ax3.yaxis.set_ticks_position('right')
    ax3.yaxis.set_label_position("right")
    ax3.spines['top'].set_visible(False)
    ax3.tick_params(labeltop=False)  # don't put tick labels at the top
    ax3.xaxis.tick_bottom()
    ax3.set_title('')
    
    ax3_top.set_ylim(np.round(axis_max, 2)-.05, axis_max+.01)
    ax3_top.set_yticks([float(axis_max)])
    ax3_top.set_yticklabels([np.round(axis_max, 2)])
    ax3_top.yaxis.set_ticks_position('right')
    ax3_top.yaxis.set_label_position("right")
    ax3_top.spines['bottom'].set_visible(False)
    ax3_top.set_xticks([])
    ax3_top.set_xticklabels('')
    ax3_top.set_xlabel('')

    # Axis line breaks
    d_ = .015  # how big to make the diagonal lines in axes coordinates
    # arguments to pass to plot, just so we don't keep repeating them
    kwargs = dict(transform=ax3_top.transAxes, color='k', clip_on=False)
    ax3_top.plot((-d_, d_), (-d_-.05, d_*4-.05), lw=plt_params['axes.linewidth'], **kwargs) # top-left diagonal
    ax3_top.plot((1 - d_, 1 + d_), (-d_-0.05, d_*4-0.05), lw=plt_params['axes.linewidth'], **kwargs) # top-right diagonal

    kwargs.update(transform=ax3.transAxes)  # switch to the bottom axes
    ax3.plot((-d_, +d_), (1 - d_, 1 + d_), lw=plt_params['axes.linewidth'], **kwargs)  # bottom-left diagonal
    ax3.plot((1 - d_, 1 + d_), (1 - d_, 1 + d_), lw=plt_params['axes.linewidth'], **kwargs)  # bottom-right diagonal

    # ===========================
    
    plt.rcParams["xtick.major.size"] = 3.5 #  change tick length back to default
    
    plt.subplots_adjust(hspace=0.06, wspace=0.08)
    
    plt.savefig('./figures/preconditioned_compound_events.pdf', bbox_inches='tight', format='pdf', dpi=400)

# Close cluster

In [None]:
# Release Dask resources: close the client first, then the SLURM cluster it is
# connected to. The `finally` guarantees the cluster (and the SLURM job(s) it
# submitted via `cluster.scale`) is shut down even if closing the client raises,
# so compute allocations are not left running until their walltime expires.
try:
    client.close()
finally:
    cluster.close()