# mcdc_analysis_d22a

## Purpose
Using Monte Carlo Drift Correction (MCDC), analyse data produced by [data_d22a.ipynb](https://github.com/grandey/d22a-mcdc/blob/main/data_d22a.ipynb), including production of figures and tables.

## Input data requirements
NetCDF files in [data/](https://github.com/grandey/d22a-mcdc/tree/main/data/) (produced by [data_d22a.ipynb](https://github.com/grandey/d22a-mcdc/blob/main/data_d22a.ipynb)), each containing a global mean time series for a given variable, AOGCM variant, and CMIP6 experiment.

## Output files written
Figures (TODO) and tables (TODO).

## History
BSG, 2022.

In [1]:
! date

Wed Aug 17 14:43:01 +08 2022


In [2]:
from functools import cache
import itertools
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pathlib
from scipy import stats
import statsmodels.api as sm
import xarray as xr

In [3]:
# Matplotlib settings
%matplotlib inline
plt.rcParams['savefig.dpi'] = 300

In [4]:
# Package versions
for p in [xr, np, pd, sm, xr]:
    print(f'{p.__name__}: {p.__version__}')

xarray: 2022.6.0
numpy: 1.23.1
pandas: 1.4.3
statsmodels.api: 0.13.2
xarray: 2022.6.0


In [5]:
# Random number generator
rng = np.random.default_rng(12345)
rng

Generator(PCG64) at 0x170393220

## Identify AOGCM variants (source-member pairs)
Note: the AOGCM variants identified should match those identified by data_d22a.ipynb.

In [6]:
# Location of data produced by data_d22a.ipynb
in_base = pathlib.Path.cwd() / 'data' / 'regrid_missto0_yearmean_fldmean_mergetime'

# Core variables required
core_var_list = ['rsdt', 'rsut', 'rlut', # R = rsdt-rsut-rlut
                 'hfds',  # H (without flux correction)
                 'zostoga']  # Z

# Experiments required (with corresponding names, used for figs later)
exp_dict = {'piControl': 'Control', 'historical': 'Historical',
            'ssp126': 'SSP1-2.6', 'ssp245': 'SSP2-4.5',
            'ssp370': 'SSP3-7.0', 'ssp585': 'SSP5-8.5'}

# Identify source-member pairs to use
source_member_list = sorted([d.name for d in in_base.glob(f'rsdt/[!.]*_*')])  # this list will be reduced
for source_member in source_member_list.copy():  # loop over copy of source-member pairs to check data availability
    for var in core_var_list:  # loop over required variables
        for exp in exp_dict.keys():  # loop over experiments
            #in_fns = sorted(in_base.glob(f'{var}/{source_member}/{var}_{source_member}_{exp}.mergetime.nc'))
            in_fn = in_base.joinpath(f'{var}/{source_member}/{var}_{source_member}_{exp}.mergetime.nc')
            if not in_fn.is_file():  # if input file for this experiment does not exist...
                try:
                    source_member_list.remove(source_member)  # ... do not use this source-member pair
                except ValueError:  # when source-member pair has previously been removed
                    pass

print(f'{len(source_member_list)} source-member pairs identified.')
source_member_list

20 source-member pairs identified.


['ACCESS-CM2_r1i1p1f1',
 'ACCESS-ESM1-5_r1i1p1f1',
 'CMCC-CM2-SR5_r1i1p1f1',
 'CMCC-ESM2_r1i1p1f1',
 'CNRM-CM6-1_r1i1p1f2',
 'CNRM-ESM2-1_r1i1p1f2',
 'CanESM5_r1i1p1f1',
 'EC-Earth3-Veg-LR_r1i1p1f1',
 'EC-Earth3-Veg_r1i1p1f1',
 'EC-Earth3_r1i1p1f1',
 'GISS-E2-1-G_r1i1p5f1',
 'GISS-E2-1-H_r1i1p1f2',
 'IPSL-CM6A-LR_r1i1p1f1',
 'MIROC6_r1i1p1f1',
 'MPI-ESM1-2-HR_r1i1p1f1',
 'MPI-ESM1-2-LR_r1i1p1f1',
 'MRI-ESM2-0_r1i1p1f1',
 'NorESM2-LM_r1i1p1f1',
 'NorESM2-MM_r1i1p1f1',
 'UKESM1-0-LL_r1i1p1f2']

## Read input data

In [7]:
%%time
# Dictionary to hold input DataArrays
in_da_dict = {}  # keys will be tuples of (source_member, exp, var)

# List of input data variables 
in_var_list = ['rsdt', 'rsut', 'rlut',  # R = rsdt-rsut-rlut
               'hfds',  # H (without flux correction)
               'hfcorr',  # flux correction, available for very few source-member pairs
               'zostoga']  # Z

# Loop over source-member pairs, experiments, and variables
for source_member in source_member_list:
    for exp in exp_dict.keys():
        for var in in_var_list:
            # Read input data (if they exist)
            in_fn = in_base.joinpath(f'{var}/{source_member}/{var}_{source_member}_{exp}.mergetime.nc')
            try:
                in_ds = xr.open_dataset(in_fn)  # Dataset
                in_da = in_ds[var]  # DataArray
                # Remove degenerate lon and lat dimensions
                in_da = in_da.squeeze()
                # Convert time units to year
                in_da['time'] = (in_da['time'] // 1e4).astype(int)
                in_da['time'].attrs['units'] = 'a'
                # Convert zostoga units to mm
                if var == 'zostoga':
                    in_da.data = in_da.data * 1e3
                    in_da.attrs['units'] = 'mm'
                # Check: do data have non-zero values?
                if (in_da**2).sum() == 0:
                    print(f'Skipping {source_member} {exp} {var} (no non-zero values)')
                else:
                    # Save to dictionary
                    in_da_dict[(source_member, exp, var)] = in_da
            except FileNotFoundError:
                pass

print(f'in_da_dict contains {len(in_da_dict)} DataArrays')

in_da_dict contains 606 DataArrays
CPU times: user 1.5 s, sys: 46.2 ms, total: 1.54 s
Wall time: 1.65 s


## Basic processing of input data

In [8]:
# Are there any gaps (missing years) in the data coverage?
for source_member in source_member_list:
    for exp in exp_dict.keys():
        for var in in_var_list:
            try:
                in_da = in_da_dict[(source_member, exp, var)]
                # Is the interval betweeen successive time coords always 1 year?
                intervals = in_da.time.data[1:] - in_da.time.data[:-1]
                if not np.all(intervals == 1):
                    # If a gap is found, limit to period before gap
                    gap_i = int(np.where(intervals != 1)[0])  # identify first gap
                    gap_yr = in_da.time.data[gap_i]  # final year before gap
                    len_old = len(in_da)  # old length
                    in_da = in_da.where(in_da.time <= gap_yr, drop=True)  # limit data
                    in_da_dict[(source_member, exp, var)] = in_da  # update dict
                    len_new = len(in_da)  # new length
                    print(f'{source_member} {exp} {var} had missing years; using period before gap; '
                          f'length {len_old} -> {len_new}.')
            except KeyError:
                pass

IPSL-CM6A-LR_r1i1p1f1 piControl hfds had missing years; using period before gap; length 1800 -> 1000.


In [9]:
# For a given source-member and experiment, is the time coverage consistent between core variables?
for source_member in source_member_list:
    for exp in exp_dict.keys():
        try:
            # Time coords for primary variables of interest
            time1 = in_da_dict[(source_member, exp, 'rsdt')].time.data
            time2 = in_da_dict[(source_member, exp, 'rsut')].time.data
            time3 = in_da_dict[(source_member, exp, 'rlut')].time.data
            time4 = in_da_dict[(source_member, exp, 'hfds')].time.data
            time5 = in_da_dict[(source_member, exp, 'zostoga')].time.data
            # Are the time coords the same?
            if not (np.array_equal(time1, time2) and np.array_equal(time1, time3) and
                    np.array_equal(time1, time4) and np.array_equal(time1, time5)):
                print(f'{source_member} {exp} has inconsistent time coord')
                # Are the start years the same?
                start_list = [t[0] for t in [time1, time2, time3, time4, time5]]
                if len(set(start_list)) > 1:
                    print(f'  Start years differ: {start_list}')
                start_max = max(start_list)  # earliest year available for all
                # Are the end years the same?
                end_list = [t[-1] for t in [time1, time2, time3, time4, time5]]
                if len(set(end_list)) > 1:
                    print(f'  End years differ: {end_list}')
                end_min = min(end_list)  # latest year availabe for all
                # Limit to shared period
                print(f'  Limiting to shared period of {start_max}-{end_min} ({end_min-start_max+1} years)')
                for var in in_var_list:  # limit time coord for all variables (including hfcorr)
                    try:
                        in_da = in_da_dict[(source_member, exp, var)]
                        in_da = in_da.where((in_da.time >= start_max) & (in_da.time <= end_min), drop=True)
                        in_da_dict[(source_member, exp, var)] = in_da
                    except KeyError:
                        pass
        except KeyError:
            pass

CanESM5_r1i1p1f1 ssp585 has inconsistent time coord
  End years differ: [2300, 2300, 2300, 2300, 2180]
  Limiting to shared period of 1850-2180 (331 years)
EC-Earth3-Veg_r1i1p1f1 piControl has inconsistent time coord
  End years differ: [3849, 3849, 3849, 2349, 2349]
  Limiting to shared period of 1850-2349 (500 years)
IPSL-CM6A-LR_r1i1p1f1 piControl has inconsistent time coord
  End years differ: [3849, 3849, 3849, 2849, 3849]
  Limiting to shared period of 1850-2849 (1000 years)
MIROC6_r1i1p1f1 piControl has inconsistent time coord
  End years differ: [3999, 3999, 3999, 3699, 3999]
  Limiting to shared period of 3200-3699 (500 years)
UKESM1-0-LL_r1i1p1f2 piControl has inconsistent time coord
  End years differ: [3839, 3839, 3839, 3839, 3059]
  Limiting to shared period of 1960-3059 (1100 years)


In [10]:
# Shift PI control start year to 1 (arbitrary)
for source_member in source_member_list:
    for var in in_var_list:
        try:
            in_da = in_da_dict[(source_member, 'piControl', var)]
            in_da['time'] = in_da['time'] - in_da['time'][0] + 1  # shift
        except IndexError:
            print('IndexError encountered:', source_member, var, in_da)
        except KeyError:
            pass

In [11]:
# How much PI control data are available for each source-member pair?
for source_member in source_member_list:
    in_da = in_da_dict[(source_member, 'piControl', 'zostoga')]
    print(f'{source_member} piControl has {len(in_da)} years')

ACCESS-CM2_r1i1p1f1 piControl has 500 years
ACCESS-ESM1-5_r1i1p1f1 piControl has 1000 years
CMCC-CM2-SR5_r1i1p1f1 piControl has 500 years
CMCC-ESM2_r1i1p1f1 piControl has 500 years
CNRM-CM6-1_r1i1p1f2 piControl has 500 years
CNRM-ESM2-1_r1i1p1f2 piControl has 500 years
CanESM5_r1i1p1f1 piControl has 1000 years
EC-Earth3-Veg-LR_r1i1p1f1 piControl has 501 years
EC-Earth3-Veg_r1i1p1f1 piControl has 500 years
EC-Earth3_r1i1p1f1 piControl has 501 years
GISS-E2-1-G_r1i1p5f1 piControl has 201 years
GISS-E2-1-H_r1i1p1f2 piControl has 451 years
IPSL-CM6A-LR_r1i1p1f1 piControl has 1000 years
MIROC6_r1i1p1f1 piControl has 500 years
MPI-ESM1-2-HR_r1i1p1f1 piControl has 500 years
MPI-ESM1-2-LR_r1i1p1f1 piControl has 1000 years
MRI-ESM2-0_r1i1p1f1 piControl has 701 years
NorESM2-LM_r1i1p1f1 piControl has 501 years
NorESM2-MM_r1i1p1f1 piControl has 500 years
UKESM1-0-LL_r1i1p1f2 piControl has 1100 years


In [12]:
# Limit SSPs to 2100 for consistency
for source_member in source_member_list:
    for exp in exp_dict.keys():
        if 'ssp' in exp:
            for var in in_var_list:
                try:
                    time_data = in_da_dict[(source_member, exp, var)].time.data
                    if time_data[-1] > 2100: 
                        in_da = in_da_dict[(source_member, exp, var)]
                        in_da = in_da.where(in_da.time <= 2100, drop=True)
                        in_da_dict[(source_member, exp, var)] = in_da
                except KeyError:
                    pass

In [13]:
# Correct discontinuities in MRI-ESM2-0_r1i1p1f1 zostoga SSP time series
for source_member in ['MRI-ESM2-0_r1i1p1f1',]:
    # Using historical time series, extrapolate to 2015 based on 2014-2013 diff
    in_da = in_da_dict[(source_member, 'historical', 'zostoga')]
    diff = in_da[-1].data - in_da[-2].data
    extrap_2015 = in_da[-1].data + diff
    # Shift SSP data from 2015 onwards to match extrapolation bridging discontinuity
    for exp in exp_dict.keys():
        if 'ssp' in exp:
            in_da = in_da_dict[(source_member, exp, 'zostoga')]
            old_2015 = in_da.sel(time=2015).data  # current value for 2015
            correction = extrap_2015 - old_2015  # correction to apply
            new_da = xr.concat([in_da.sel(time=slice(1850,2014)),
                                (in_da.sel(time=slice(2015,2100))+correction)], dim='time')  # apply correction
            new_2015 = new_da.sel(time=2015).data  # new value for 2015
            in_da_dict[(source_member, exp, 'zostoga')] = new_da
            print(f'{source_member} {exp} zostoga year-2015 shifted from {old_2015:.4f} to {new_2015:.4f}')

MRI-ESM2-0_r1i1p1f1 ssp126 zostoga year-2015 shifted from 2.0144 to 69.3741
MRI-ESM2-0_r1i1p1f1 ssp245 zostoga year-2015 shifted from 2.1230 to 69.3741
MRI-ESM2-0_r1i1p1f1 ssp370 zostoga year-2015 shifted from 2.3462 to 69.3741
MRI-ESM2-0_r1i1p1f1 ssp585 zostoga year-2015 shifted from 1.2473 to 69.3741


## Calculate uncorrected $R$, ${\int}R$, $H$, ${\int}H$, and $\Delta Z$ from input data

In [14]:
# Dictionary to hold DataArrays
da_dict = {}  # keys will be tuples of (source_member, exp, var)

# Loop over source-member pairs and experiments
for source_member in source_member_list:
    for exp in exp_dict.keys():
        # R = rsdt-rsut-rlut
        r_da = (in_da_dict[(source_member, exp, 'rsdt')]
                - in_da_dict[(source_member, exp, 'rsut')]
                - in_da_dict[(source_member, exp, 'rlut')])
        da_dict[(source_member, exp, 'R')] = r_da  # R
        da_dict[(source_member, exp, 'R')].attrs['units'] = 'W m$^{-2}$'
        # \int R
        da_dict[(source_member, exp, '\int R')] = r_da.cumsum()
        da_dict[(source_member, exp, '\int R')].attrs['units'] = 'W m$^{-2}$ a'  # a is year
        # H = hfds+hfcorr (ie apply flux correction if it is non-zero)
        try:
            # If hfcorr exists, add it to hfds
            in_da = in_da_dict[(source_member, exp, 'hfcorr')].copy()
            print(f'{source_member} {exp} has hfcorr (mean={in_da.mean().data:0.3f}, std={in_da.std().data:0.3f})')
            da_dict[(source_member, exp, 'H')] = in_da_dict[(source_member, exp, 'hfds')] + in_da
        except KeyError:
            # If hfcorr does not exist, assume it is zero and just use hfds
            da_dict[(source_member, exp, 'H')] = in_da_dict[(source_member, exp, 'hfds')].copy()
        da_dict[(source_member, exp, 'H')].attrs['units'] = 'W m$^{-2}$'
        # \int H
        da_dict[(source_member, exp, '\int H')] = da_dict[(source_member, exp, 'H')].cumsum()
        da_dict[(source_member, exp, '\int H')].attrs['units'] = 'W m$^{-2}$ a'
        # \Delta Z = zostoga (using first year as reference for zero)
        z_da = in_da_dict[(source_member, exp, 'zostoga')].copy()
        z_da -= z_da[0]
        da_dict[(source_member, exp, '\Delta Z')] = z_da

print(f'da_dict contains {len(da_dict)} DataArrays')

MRI-ESM2-0_r1i1p1f1 piControl has hfcorr (mean=0.288, std=0.010)
MRI-ESM2-0_r1i1p1f1 historical has hfcorr (mean=0.283, std=0.013)
MRI-ESM2-0_r1i1p1f1 ssp126 has hfcorr (mean=0.263, std=0.030)
MRI-ESM2-0_r1i1p1f1 ssp245 has hfcorr (mean=0.259, std=0.038)
MRI-ESM2-0_r1i1p1f1 ssp370 has hfcorr (mean=0.255, std=0.045)
MRI-ESM2-0_r1i1p1f1 ssp585 has hfcorr (mean=0.251, std=0.052)
da_dict contains 600 DataArrays


## Constants and conversion factors

In [15]:
# Total area of earth
in_fn = pathlib.Path.cwd() / 'data' / 'area_earth.csv'  # produced by data_d22a.ipynb
area_df = pd.read_csv(in_fn)  # read to DataFrame
print(area_df)  # print
area_earth = area_df['area_earth'].mean()  # convert to single number
print(f'area_earth = {area_earth:.3e} m2')

                source_member    area_earth
0         ACCESS-CM2_r1i1p1f1  5.100645e+14
1      ACCESS-ESM1-5_r1i1p1f1  5.100645e+14
2       CMCC-CM2-SR5_r1i1p1f1  5.100645e+14
3          CMCC-ESM2_r1i1p1f1  5.100645e+14
4         CNRM-CM6-1_r1i1p1f2  5.100645e+14
5        CNRM-ESM2-1_r1i1p1f2  5.100645e+14
6            CanESM5_r1i1p1f1  5.100645e+14
7   EC-Earth3-Veg-LR_r1i1p1f1  5.100645e+14
8      EC-Earth3-Veg_r1i1p1f1  5.100645e+14
9          EC-Earth3_r1i1p1f1  5.100645e+14
10       GISS-E2-1-G_r1i1p5f1  5.100645e+14
11       GISS-E2-1-H_r1i1p1f2  5.100645e+14
12      IPSL-CM6A-LR_r1i1p1f1  5.100645e+14
13            MIROC6_r1i1p1f1  5.100645e+14
14     MPI-ESM1-2-HR_r1i1p1f1  5.100645e+14
15     MPI-ESM1-2-LR_r1i1p1f1  5.100645e+14
16        MRI-ESM2-0_r1i1p1f1  5.100645e+14
17        NorESM2-LM_r1i1p1f1  5.100645e+14
18        NorESM2-MM_r1i1p1f1  5.100645e+14
19       UKESM1-0-LL_r1i1p1f2  5.100645e+14
area_earth = 5.101e+14 m2


In [16]:
# Conversion factor for W m-2 a -> YJ
convert_Wm2a_YJ = area_earth * 365 * 24 * 60 * 60 / 1e24
print(f'1 W m-2 a = {convert_Wm2a_YJ:.4f} YJ')
print(f'1 YJ = {1/convert_Wm2a_YJ:.2f} W m-2 a')

1 W m-2 a = 0.0161 YJ
1 YJ = 62.17 W m-2 a


## Monte Carlo Drift Correction functions

**`calc_trend()`** and **`calc_mean()`** calculate the trend/mean of a time series (an xr.DataArray).
By default, the trend/mean is perturbed by a random error, drawn from a Gaussian distribution corresponding to the standard errror.

**`sample_segments_calc_trends_means()`** randomly draws segments of a specified length from a list of DataArrays.
The DataArrays are sampled consistently using the same 150 year segments.
For each segment, the function uses `calc_trend()` and `calc_mean()` to calculate trends and means (with random errors included by default).
`sample_segments_calc_trends_means()` returns trends (list of arrays), means (list of arrays), and the start years of the segments sampled (array).

In [17]:
def calc_trend(data_da, inc_rand_error=True):
    """Calculate the trend of a time series (xr.DataArray), include random error (optional), and return trend."""
    # Linear regression (using statsmodels)
    x_in = sm.add_constant(data_da.time, prepend=True)
    sm_reg = sm.OLS(data_da.data, x_in).fit()
    b = sm_reg.params[1]  # slope
    # Include random error using the standard error (assuming Gaussian)?
    if inc_rand_error:
        bse = sm_reg.bse[1]  # standard error on slope
        trend = rng.normal(loc=b, scale=bse)  # sample from Gaussian
    else:
        trend = b
    return trend

# Example
data_da = da_dict[(source_member_list[-1], 'piControl', '\Delta Z')].copy()
%time calc_trend(data_da)

CPU times: user 1.05 ms, sys: 534 µs, total: 1.59 ms
Wall time: 1.07 ms


-0.053359242656797146

In [18]:
def calc_mean(data_da, inc_rand_error=True):
    """Calculate the mean of a time series (xr.DataArray), include random error (optional), and return result."""
    # Mean
    m = np.mean(data_da.data)
    # Include random error using the standard error (assuming Gaussian)?
    if inc_rand_error:
        sem = stats.sem(data_da.data)  # standard eror of mean
        res = rng.normal(loc=m, scale=sem)  # sample from Gaussian
    else:
        res = m
    return res

# Example
data_da = da_dict[(source_member_list[-1], 'piControl', '\Delta Z')].copy()
%time calc_mean(data_da)

CPU times: user 373 µs, sys: 266 µs, total: 639 µs
Wall time: 379 µs


-40.920236559730064

In [19]:
def sample_segments_calc_trends_means(da_list,
                                      inc_rand_error=True,
                                      sample_length=150, sample_n=500,
                                      verbose=False):
    """Draw n segments of specified length from DataArray list; return trends, means, start years (arrays)."""
    # Identify possible start years from which to sample
    poss_start_yrs = da_list[0].time.data[0:-sample_length+1]
    if verbose:
        print(f'Possible start years: {poss_start_yrs[0]} - {poss_start_yrs[-1]}')
    # Randomly choose start years (with replacement)
    rand_start_yrs = rng.choice(poss_start_yrs, size=sample_n, replace=True)
    if verbose:
        print(f'Randomly chosen start years: {rand_start_yrs}')
    # Lists to hold trends and means for DataArrays in da_list    
    trends_list = []
    means_list = []
    # Loop over DataArrays in list
    for data_da in da_list:
        # Calculate trends and means of the chosen time series segments
        trends = np.zeros(sample_n)  # initialize arrays of trends/means with zero
        means = np.zeros(sample_n)
        for i in range(sample_n):
            start_yr = rand_start_yrs[i]
            sample_da = data_da.sel(time=slice(start_yr,start_yr+sample_length-1))
            trends[i] = calc_trend(sample_da, inc_rand_error=inc_rand_error)
            means[i] = calc_mean(sample_da, inc_rand_error=inc_rand_error)
        if verbose:
            print(f'Trends: {trends}')
            print(f'Means: {means}')
        trends_list.append(trends)
        means_list.append(means)
    return trends_list, means_list, rand_start_yrs

# Example
da_list = [da_dict[(source_member_list[-1], 'piControl', 'R')].copy(),
           da_dict[(source_member_list[-1], 'piControl', 'H')].copy(),
           da_dict[(source_member_list[-1], 'piControl', '\Delta Z')].copy()]
%time sample_segments_calc_trends_means(da_list, sample_n=3, verbose=True)

Possible start years: 1 - 951
Randomly chosen start years: [195 759 612]
Trends: [-0.00010098 -0.0003937   0.00044171]
Means: [-0.01634032  0.04432896  0.00087912]
Trends: [1.02814916e-03 1.94414061e-05 2.44629685e-05]
Means: [-0.27763762 -0.24957567 -0.25250733]
Trends: [-0.06995799 -0.01925332 -0.00144673]
Means: [-29.38422147 -54.77843728 -49.6483107 ]
CPU times: user 7.89 ms, sys: 1.5 ms, total: 9.39 ms
Wall time: 8.11 ms


([array([-0.00010098, -0.0003937 ,  0.00044171]),
  array([1.02814916e-03, 1.94414061e-05, 2.44629685e-05]),
  array([-0.06995799, -0.01925332, -0.00144673])],
 [array([-0.01634032,  0.04432896,  0.00087912]),
  array([-0.27763762, -0.24957567, -0.25250733]),
  array([-29.38422147, -54.77843728, -49.6483107 ])],
 array([195, 759, 612]))

In [20]:
! date

Wed Aug 17 14:43:04 +08 2022
