Notebook with functions for reading in data and computing ensemble means (and ensemble-mean differences) for the various experiments (power inputs of 0.1, 0.2, and 0.5 TW; surf, therm, mid, and bot profiles; const, 2xCO2, and 4xCO2 scenarios). It is designed to be used with the notebook read_and_calculate.ipynb.

In [1]:
import numpy as np
import xarray as xr

# modules for using datetime variables
import datetime
from datetime import time
from datetime import timedelta

import warnings
warnings.filterwarnings('ignore')

import cftime
from pandas.errors import OutOfBoundsDatetime  # Import the specific error

import copy

from xclim import ensembles

In [2]:
# Handle on this notebook's global namespace. calc_ens_diffs() and get_ens_dat() store their
# result datasets here by name (myVars[name] = dataset) so they become notebook variables.
myVars = globals()

# Functions

## Functions to return ensemble and ensemble mean data

In [3]:
# `omit_mem1` drops the first ensemble member from the calculation
# (use it if that member's file/variable doesn't exist).

def create_const_doub_ens_mean(exp_name_list,start_years,end_years,chunk_length,
                               variable_list,pp_type='av-annual',diag_file='ocean_monthly_z',omit_mem1=False,debug=False):
    """
    Read every ensemble member and return the members together with their ensemble mean.

    Members are read with get_pp_av_data(); only the first member is read with
    time_decoding=True. Afterwards every member is given the same time axis (that of the
    first member read) so that xclim can stack them along a "realization" dimension.

        Args:
            exp_name_list (str list): experiment name of each ensemble member
            start_years (int list): start year of each member (indexed like exp_name_list)
            end_years (int list): end year of each member (indexed like exp_name_list)
            chunk_length (int): number of years per av/ts chunk, passed to get_pp_av_data
            variable_list (str list): variables to read
            pp_type (str): post-processing type passed to get_pp_av_data
            diag_file (str): diagnostic file; for anything other than 'ocean_scalar_monthly'
                             the grid variables areacello, dxt, dyt and wet are also kept
                             in the returned mean
            omit_mem1 (bool): if True, skip the first member and shift the remaining members'
                              times back by 200 years (assumes member #2 covers years
                              201-400, as set up by get_ens_dat) so they fall in 1-200
            debug (bool): if True, print which member is omitted

        Returns:
            ens_mem_list (list of xarray.Dataset): the time-aligned members that were read
            ens_mean (xarray.Dataset): ensemble mean restricted to variable_list
                                       (+ grid variables, see diag_file)

        Raises:
            ValueError: if no member is left to average
    """

    def _with_times(ds, new_times):
        # Return a copy of `ds` whose time coordinate holds `new_times`, keeping the coordinate's
        # attrs/encoding. Writing into `ds.time.values` in place (the old approach) depends on
        # xarray handing back the index's own buffer: it can be silently lost, can leak into
        # other datasets sharing that buffer, and leaves pandas' cached index state stale.
        time_var = ds['time'].variable
        return ds.assign_coords(time=(time_var.dims, np.array(new_times),
                                      dict(time_var.attrs), dict(time_var.encoding)))

    ens_mem_list = [None] * len(exp_name_list)

    for idx, exp_name in enumerate(exp_name_list):
        if omit_mem1 and idx == 0:  # not including first ens member in the mean
            if debug:
                print(f'Omitting member #1: {exp_name}')
            continue
        # only the first member is read with time decoding enabled
        time_decoding = idx == 0
        ens_mem_list[idx] = get_pp_av_data(exp_name,start_years[idx],end_years[idx],chunk_length,pp_type=pp_type,
                                           diag_file=diag_file,time_decoding=time_decoding,var=variable_list,debug=False)

    if omit_mem1:
        # drop the (never read) first member
        ens_mem_list = ens_mem_list[1:]
    if not ens_mem_list:
        raise ValueError("No ensemble members were read; cannot compute an ensemble mean.")

    ref_times = ens_mem_list[0].time.values
    if omit_mem1:
        # adjust timestamps to the range of year 1 to 200 and share them with every member
        ref_times = [t.replace(year=t.year - 200) for t in ref_times]
        ens_mem_list = [_with_times(mem, ref_times) for mem in ens_mem_list]
    else:
        # give every other member the first member's time axis
        ens_mem_list = [ens_mem_list[0]] + [_with_times(mem, ref_times) for mem in ens_mem_list[1:]]

    # ensemble mean
    ens_mean = ensembles.create_ensemble(ens_mem_list).mean("realization")

    ext_var_list = list(variable_list)
    if diag_file != 'ocean_scalar_monthly':
        ext_var_list.extend(["areacello","dxt","dyt","wet"]) #,"volcello"
    ens_mean = ens_mean[ext_var_list]

    return ens_mem_list, ens_mean

In [4]:
def create_quad_ens_mean(exp_name_list,doub_ens_mem_list,doub_ens_mean,start_years,end_years,chunk_length,
                              variable_list,pp_type='av-annual',diag_file='ocean_monthly_z',omit_mems=None,debug=False):
    """
    Build the 4xCO2 ensemble members and their ensemble mean.

    The 4xCO2 experiments start at year `doub_cutoff_yr` (51) of each member; earlier years
    are taken from the corresponding 2xCO2 member. Depending on member #1's period:
      * it ends before year 51: the 2xCO2 members are used as-is;
      * it starts before year 51 and ends at/after it: the pre-51 part comes from the 2xCO2
        member and the rest is read from the 4xCO2 experiment, then concatenated;
      * it starts at/after year 51: everything is read from the 4xCO2 experiments.
    Members other than #1 are assumed to be offset by 200 years per member (as set up by
    get_ens_dat). Finally all kept members are given a common time axis in years 1-200.

        Args:
            exp_name_list (str list): 4xCO2 experiment name of each member
            doub_ens_mem_list (list of xarray.Dataset): 2xCO2 members, already time-aligned
                                                        to years 1-200 (see create_const_doub_ens_mean)
            doub_ens_mean (xarray.Dataset): 2xCO2 ensemble mean; accepted for interface
                                            compatibility, the mean is always recomputed
            start_years (int list): start year of each member
            end_years (int list): end year of each member
            chunk_length (int): number of years per av/ts chunk, passed to get_pp_av_data
            variable_list (str list): variables to read
            pp_type (str): post-processing type passed to get_pp_av_data
            diag_file (str): diagnostic file; for anything other than 'ocean_scalar_monthly'
                             the grid variables areacello, dxt, dyt and wet are also kept
            omit_mems (int list or None): 1-based numbers of members to leave out
            debug (bool): if True, print omitted members and time-shift details

        Returns:
            quad_ens_mem_list (list of xarray.Dataset): the kept, time-aligned members
            quad_ens_mean (xarray.Dataset): ensemble mean (or copy of the single kept member)

        Raises:
            ValueError: if a 2xCO2 member has no data before the cutoff, or every member is omitted
    """
    omit_mems = [] if omit_mems is None else list(omit_mems)

    doub_cutoff_yr = 51
    doub_cutoff_dt = cftime.DatetimeNoLeap(doub_cutoff_yr, 1, 1, 0, 0, 0, 0, has_year_zero=True)
    member_spacing = 200  # years between the start of consecutive ensemble members

    def _with_times(ds, new_times):
        # Return a copy of `ds` whose time coordinate holds `new_times`, keeping attrs/encoding.
        # Writing into `ds.time.values` in place instead can leak into the 2xCO2 member that a
        # `.sel()` slice shares its index buffer with (those members are reused for the 2xCO2
        # differences afterwards), or be silently lost if `.values` returns a copy.
        time_var = ds['time'].variable
        return ds.assign_coords(time=(time_var.dims, np.array(new_times),
                                      dict(time_var.attrs), dict(time_var.encoding)))

    def _is_omitted(idx):
        # member numbers in omit_mems are 1-based
        if (idx + 1) in omit_mems:
            if debug:
                print(f'Omitting member #{idx + 1}: {exp_name_list[idx]}')
            return True
        return False

    quad_ens_mem_list = [None] * len(doub_ens_mem_list)
    needs_alignment = True

    if end_years[0] < doub_cutoff_yr:
        # whole period precedes the 4xCO2 runs: reuse the (already aligned) 2xCO2 members
        for idx, ens_memb in enumerate(doub_ens_mem_list):
            if not _is_omitted(idx):
                quad_ens_mem_list[idx] = ens_memb
        needs_alignment = False

    elif start_years[0] < doub_cutoff_yr:
        # period straddles the cutoff: splice 2xCO2 (pre-51) and 4xCO2 (post-51) data
        for idx, ens_memb in enumerate(doub_ens_mem_list):
            if _is_omitted(idx):
                continue
            time_decoding = idx == 0

            # all the 2xCO2 ensemble members have already had their times adjusted to within 1-200
            pre_51_quad = ens_memb.sel(time=slice(None,doub_cutoff_dt))
            # len(Dataset) counts data variables, so check the time dimension explicitly
            if pre_51_quad.sizes.get('time', 0) == 0:
                raise ValueError("pre_51_quad has no time steps. It seems that 2xCO2 members have not had their "
                                 "times adjusted to be within 1-200 years.")

            # move the pre-51 segment into this member's own year range (+200 per member)
            year_offset = member_spacing * idx
            if year_offset:
                pre_51_quad = _with_times(pre_51_quad,
                                          [t.replace(year=t.year + year_offset) for t in pre_51_quad.time.values])

            post_51_quad = get_pp_av_data(exp_name_list[idx],doub_cutoff_yr + year_offset,end_years[idx],chunk_length,
                                          pp_type=pp_type,diag_file=diag_file,time_decoding=time_decoding,
                                          var=variable_list,debug=False)
            post_51_quad = post_51_quad[variable_list]
            quad_ens_mem_list[idx] = xr.concat([pre_51_quad,post_51_quad],"time")

    else:  # the whole period is at or after the cutoff year
        for idx in range(len(doub_ens_mem_list)):
            if _is_omitted(idx):
                continue
            quad_ens_mem_list[idx] = get_pp_av_data(exp_name_list[idx],start_years[idx],end_years[idx],chunk_length,
                                                    pp_type=pp_type,diag_file=diag_file,time_decoding=(idx == 0),
                                                    var=variable_list,debug=False)

    # keep only non-omitted members, remembering their original positions
    kept_idx = [idx for idx, mem in enumerate(quad_ens_mem_list) if mem is not None]
    if not kept_idx:
        raise ValueError("All ensemble members were omitted; cannot compute an ensemble mean.")
    quad_ens_mem_list = [quad_ens_mem_list[idx] for idx in kept_idx]

    if needs_alignment:
        # The first kept member lies member_spacing * (its position) years after member #1;
        # shift its times back into years 1-200 and give every kept member that time axis.
        year_offset = member_spacing * kept_idx[0]
        ref_times = quad_ens_mem_list[0].time.values
        if year_offset:
            first_time = ref_times[0] if len(ref_times) else None
            ref_times = [t.replace(year=t.year - year_offset) for t in ref_times]
            quad_ens_mem_list = [_with_times(mem, ref_times) for mem in quad_ens_mem_list]
            if debug and first_time is not None:
                print(f'time_val: {first_time}')
                print(f'New year: {ref_times[0].year}')
                print(f'New date: {ref_times[0]}')
                print(f'Actual xarray value: {quad_ens_mem_list[0].time.values[0]}')
        else:
            quad_ens_mem_list = [quad_ens_mem_list[0]] + [_with_times(mem, ref_times)
                                                          for mem in quad_ens_mem_list[1:]]

    # ensemble mean (a single member is its own mean)
    if len(quad_ens_mem_list) == 1:
        quad_ens_mean = copy.deepcopy(quad_ens_mem_list[0])
    else:
        quad_ens_mean = ensembles.create_ensemble(quad_ens_mem_list).mean("realization")

    ext_var_list = list(variable_list)
    if diag_file != 'ocean_scalar_monthly':
        ext_var_list.extend(["areacello","dxt","dyt","wet"]) #,"volcello"
    quad_ens_mean = quad_ens_mean[ext_var_list]

    return quad_ens_mem_list, quad_ens_mean

## Function to calculate ensemble-mean differences and horizontal mean differences

In [14]:
def calc_ens_diffs(diff_ens_name,ref_ens_list,perturb_ens_list,variable_list,diag_file='ocean_monthly_z',stdev_ds=None,verbose=False):
    """
    Compute the ensemble-mean difference (perturbed minus reference) and store it as a notebook global.

    Member i of perturb_ens_list is differenced against member i of ref_ens_list; the
    differences are stacked with xclim and averaged over "realization".

        Args:
            diff_ens_name (str): global name under which the mean difference is stored
            ref_ens_list (list of xarray.Dataset): reference members; its length sets the member count
            perturb_ens_list (list of xarray.Dataset): perturbed members, paired by position
            variable_list (str list): variables to difference
            diag_file (str): for anything other than 'ocean_scalar_monthly' the grid variables
                             wet, dxt, dyt, areacello (from ref member 0) are attached and
                             horizontal means are stored as f"{diff_ens_name}_mean"
            stdev_ds (xarray.Dataset or None): if given (and more than one member), a boolean
                             f"{var}_hatch" variable is added, True where every member's change
                             exceeds f"{var}_stdev" in magnitude but members disagree on its sign
            verbose (bool): if True, print the names of the stored datasets

        Returns:
            None; results are stored in myVars (the notebook's globals).

        Raises:
            IOError: if a variable is missing from a reference or perturbed member
    """
    num_ens_mem = len(ref_ens_list)
    is_gridded = diag_file != 'ocean_scalar_monthly'

    # raise an error if a variable isn't found in one of the members
    for i in range(num_ens_mem):
        for elem in variable_list:
            if elem not in ref_ens_list[i].variables:
                raise IOError(f"{elem} not found in ref_ens_list[{i}].")
            if elem not in perturb_ens_list[i].variables:
                raise IOError(f"{elem} not found in perturb_ens_list[{i}].")

    # take difference of variables, member by member
    diffs_mem_list = [perturb_ens_list[i][variable_list] - ref_ens_list[i][variable_list]
                      for i in range(num_ens_mem)]

    if num_ens_mem == 1:
        # a single member is its own ensemble mean (no spread, so no hatching)
        diff_ens = diffs_mem_list[0]
    else:
        diff_ens = ensembles.create_ensemble(diffs_mem_list)

        # create hatching variable
        if is_gridded and stdev_ds is not None:
            for var in variable_list:
                diff = diff_ens[var]
                nan_cond = np.isnan(diff).all(dim="realization")
                # every member's change is significant in magnitude; a signed `diff > stdev`
                # would only admit positive changes and make the hatch below identically False
                stdev_cond = (abs(diff) > stdev_ds[f"{var}_stdev"]).all(dim="realization")
                agree_cond = (diff > 0).all(dim="realization") | (diff < 0).all(dim="realization")
                disagree_cond = ~agree_cond & ~nan_cond

                # True where change is significant in all members but members disagree on sign
                diff_ens[f"{var}_hatch"] = disagree_cond & stdev_cond

        diff_ens = diff_ens.mean("realization")

    if is_gridded:
        # grid metrics needed downstream (e.g. by horizontal_mean)
        for grid_var in ("wet", "dxt", "dyt", "areacello"):  # ,"volcello"
            diff_ens[grid_var] = ref_ens_list[0][grid_var]

    myVars[diff_ens_name] = diff_ens
    if verbose:
        print(f'{diff_ens_name}')

    # alternative to get more stats is:
    # diff_ens = ensembles.ensemble_mean_std_max_min(diff_ens)

    if is_gridded:
        horiz_avg_diff_name = f"{diff_ens_name}_mean"
        horiz_avg_diff = xr.Dataset()
        for var in variable_list:
            horiz_avg_diff[var] = horizontal_mean(diff_ens[var],diff_ens)
        myVars[horiz_avg_diff_name] = horiz_avg_diff
        if verbose:
            print(f'{horiz_avg_diff_name} done')

## Main functions

In [6]:
def get_ens_dat(co2_scen, avg_period, mem1_start, mem1_end, var_list,
                pp_type='av-annual',
                diag_file='ocean_monthly_z',
                profiles = ['surf','therm','mid','bot'],
                power_inputs = ['0.1TW', '0.2TW', '0.5TW'],
                power_var_suff = ['0p1TW', '0p2TW', '0p5TW'],
                stdev_ds=None,
                verbose=False,
                debug=False):

    """
    Create notebook globals holding ensemble-mean raw data and ensemble-mean anomaly data.

    Anomalies are calculated member by member as the difference relative to the control run
    during the period corresponding to that ensemble member (i.e. the anomalies for ensemble
    member 2 for years 201 to 400 are taken as the difference relative to years 201 to 400 of
    the control run), and are then averaged over the ensemble by calc_ens_diffs().

        Args:
            co2_scen (str): one of ['const','doub','quad','const+doub','all']; difference datasets will only be created for the co2 scenario specified,
                            but ensembles + means may be created for control case of other CO2 scenarios
            avg_period (int): number of years for av/ts period
            mem1_start (int): start year of ens. mem. #1 (members #2 and #3 start 200 and 400 years later)
            mem1_end (int): end year of ens. mem. #1 (members #2 and #3 end 200 and 400 years later)
            var_list (str list): list of variables to read (e.g. var_list = ["temp", "N2", "age", "rhopot2", "salt"])
            pp_type (str): post-processing type, passed through to the ensemble-construction helpers
            diag_file (str): diagnostic file to read; 'ocean_scalar_monthly' skips grid variables,
                             hatching and horizontal means
            profiles (str list): list of profiles to get data for
            power_inputs (str list): list of power inputs to get data for
            power_var_suff (str list): list of variable suffixes for each power input (same order as power_inputs)
            stdev_ds (dataset): required to compute hatching variable in calc_ens_diffs();
                                typically the standard deviation of the control run, with variables named <var>_stdev
            verbose (bool): if True, print variable names after declaration
            debug (bool): passed to the ensemble-construction helpers

        Returns:
            has no return variables, but creates xarray datasets by using myVars = globals(), e.g.
            const_ctrl_<start>_<end> and const_ctrl_<start>_<end>_mem_list for the control runs, and
            difference datasets such as const_<prof>_<suff>_<start>_<end>_diff (see the *_diff_name
            variables below).

    """
    allowed_scen = ['const','doub','quad','const+doub','all']
    
    if co2_scen not in allowed_scen:
        raise ValueError(f"'co2_scen' must be one of {allowed_scen}.")

    num_ens_mem = 3

    # each ensemble member covers the same relative period, offset by 200 years per member
    start_yrs = [mem1_start,
                 mem1_start+200,
                 mem1_start+400]
    end_yrs = [mem1_end,
               mem1_end+200,
               mem1_end+400]

    # experiment-name building blocks (combined with profile and power input below)
    const_exp_root = '_1860IC_200yr_'
    mem1_doub_exp_root = '_2xCO2_1860IC_200yr_' # exp root for 2xCO2 exps is different for mem1 compared to mem2 and mem3
    mem2_3_doub_exp_root = '_2xCO2_200yr_'
    quad_exp_root = '_4xCO2_51-201_'

    ##### CONTROL RUNS #####
    
    ## const CO2 control ##
    # the const control is always read: every co2_scen differences against it
    const_ctrl_exps = ["tune_ctrl_const_200yr",#"tune_ctrl_1860IC_200yr",
                       "ctrl_1860IC_201-2001", #tune_ctrl_1860IC_201-2001
                       "ctrl_1860IC_201-2001"]

    const_ctrl_mem_list, const_ctrl = create_const_doub_ens_mean(const_ctrl_exps,start_yrs,end_yrs,avg_period,var_list,
                                                                 pp_type=pp_type,diag_file=diag_file,debug=debug)
    const_ctrl_name = f"const_ctrl_{mem1_start}_{mem1_end}"
    const_ctrl_mem_list_name = f"{const_ctrl_name}_mem_list"
    myVars.__setitem__(const_ctrl_name, const_ctrl)
    myVars.__setitem__(const_ctrl_mem_list_name, const_ctrl_mem_list)
    if verbose:
        print(f'{const_ctrl_name}, {const_ctrl_mem_list_name} done')

    if co2_scen != 'const':
        ## 2xCO2 control ##
        # also needed for 'quad': create_quad_ens_mean takes pre-year-51 data from the 2xCO2 members
        doub_ctrl_exps = ["tune_ctrl_2xCO2_1860IC_200yr",
                          "ens2_ctrl_2xCO2_200yr",
                          "ens3_ctrl_2xCO2_200yr"]
        
        doub_ctrl_mem_list, doub_ctrl = create_const_doub_ens_mean(doub_ctrl_exps,start_yrs,end_yrs,avg_period,var_list,
                                                                   pp_type=pp_type,diag_file=diag_file,debug=debug)
        doub_ctrl_name = f"doub_ctrl_{mem1_start}_{mem1_end}"
        doub_ctrl_mem_list_name = f"{doub_ctrl_name}_mem_list"
        myVars.__setitem__(doub_ctrl_name, doub_ctrl)
        myVars.__setitem__(doub_ctrl_mem_list_name, doub_ctrl_mem_list)
        if verbose:
            print(f'{doub_ctrl_name}, {doub_ctrl_mem_list_name} done')

        if (co2_scen == 'doub' or co2_scen == 'const+doub' or co2_scen == 'all'):
            # differences compared to constant CO2 control #
            calc_ens_diffs(f"doub_ctrl_{mem1_start}_{mem1_end}_diff",
                           const_ctrl_mem_list,doub_ctrl_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)

        if (co2_scen == 'quad' or co2_scen == 'all'):
            ## 4xCO2 control ##
            quad_ctrl_exps = ["tune_ctrl_4xCO2_51-201",
                              "ens2_ctrl_4xCO2_51-201",
                              "ens3_ctrl_4xCO2_51-201"]

            # if 'thetaoga' in var_list:
            #     quad_ctrl_omit_mems = [2]
            # else:
            #     quad_ctrl_omit_mems = []
            # NOTE(review): currently always empty, so the [2]-specific branch below is not taken
            quad_ctrl_omit_mems = []
        
            quad_ctrl_mem_list, quad_ctrl = create_quad_ens_mean(quad_ctrl_exps,doub_ctrl_mem_list,doub_ctrl,start_yrs,end_yrs,
                                                                 avg_period,var_list,pp_type=pp_type,diag_file=diag_file,
                                                                 omit_mems=quad_ctrl_omit_mems,debug=debug)
            quad_ctrl_name = f"quad_ctrl_{mem1_start}_{mem1_end}"
            quad_ctrl_mem_list_name = f"{quad_ctrl_name}_mem_list"
            myVars.__setitem__(quad_ctrl_name, quad_ctrl)
            myVars.__setitem__(quad_ctrl_mem_list_name, quad_ctrl_mem_list)
            if verbose:
                print(f'{quad_ctrl_name}, {quad_ctrl_mem_list_name} done')
        
            # differences compared to constant CO2 and 2xCO2 controls #
            # when members are omitted, the reference lists must be subset the same way so
            # calc_ens_diffs pairs members by position correctly
            if quad_ctrl_omit_mems == [2]:
                calc_ens_diffs(f"quad_ctrl_{mem1_start}_{mem1_end}_diff_const_ctrl",
                               [const_ctrl_mem_list[0],const_ctrl_mem_list[2]],quad_ctrl_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                calc_ens_diffs(f"quad_ctrl_{mem1_start}_{mem1_end}_diff_2xctrl",
                               [doub_ctrl_mem_list[0],doub_ctrl_mem_list[2]],quad_ctrl_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
            else:
                calc_ens_diffs(f"quad_ctrl_{mem1_start}_{mem1_end}_diff_const_ctrl",
                               const_ctrl_mem_list,quad_ctrl_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                calc_ens_diffs(f"quad_ctrl_{mem1_start}_{mem1_end}_diff_2xctrl",
                               doub_ctrl_mem_list,quad_ctrl_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)

    
    ##### PERTURBATION RUNS #####
    
    for prof in profiles:
        for index, power_str in enumerate(power_inputs):
            if verbose:
                print(f"{prof} {power_str} experiments")
                    
            const_exp_name_list = [None] * num_ens_mem
            doub_exp_name_list = [None] * num_ens_mem
            quad_exp_name_list = [None] * num_ens_mem

            # experiment names follow per-member naming conventions (prefixes differ by member
            # and, for const member 1, by power input)
            if power_str == '0.5TW':
                const_exp_name_list[0] = "mem1_"+prof+const_exp_root+power_str
            else:
                const_exp_name_list[0] = prof+const_exp_root+power_str

            const_exp_name_list[1] = "ens2_"+prof+const_exp_root+power_str
            const_exp_name_list[2] = "mem3_"+prof+const_exp_root+power_str
            
            doub_exp_name_list[0] = prof+mem1_doub_exp_root+power_str
            doub_exp_name_list[1] = "ens2_"+prof+mem2_3_doub_exp_root+power_str
            doub_exp_name_list[2] = "ens3_"+prof+mem2_3_doub_exp_root+power_str

            quad_exp_name_list[0] = prof+quad_exp_root+power_str
            quad_exp_name_list[1] = "ens2_"+prof+quad_exp_root+power_str
            quad_exp_name_list[2] = "ens3_"+prof+quad_exp_root+power_str

            # the const ensemble is always read: it is also the reference for the *_1860 diffs
            const_ens_mem_list, const_ens_mean = create_const_doub_ens_mean(const_exp_name_list,start_yrs,end_yrs,
                                                                            avg_period,var_list,
                                                                            pp_type=pp_type,diag_file=diag_file,debug=debug)
            
            if co2_scen != 'const':
                doub_ens_mem_list, doub_ens_mean = create_const_doub_ens_mean(doub_exp_name_list,start_yrs,end_yrs,
                                                                              avg_period,var_list,
                                                                              pp_type=pp_type,diag_file=diag_file,debug=debug)
                
                if (co2_scen == 'quad' or co2_scen == 'all'):
                    # NOTE(review): currently always empty, so the omit-specific branches below are not taken
                    quad_omit_mems = []
                    # if ('thetaoga' in var_list and prof=='surf' and power_str == '0.5TW'):
                    #     quad_omit_mems = [2]
                    # elif ('thetaoga' in var_list and prof=='therm' and power_str == '0.5TW'):
                    #     quad_omit_mems = [1,2]
                    # elif ('thetaoga' in var_list and prof=='mid' and power_str == '0.5TW'):
                    #     quad_omit_mems = [1,2]
                    # elif ('thetaoga' in var_list and prof=='bot' and power_str == '0.5TW'):
                    #     quad_omit_mems = [1,2]
                    # else:
                    #     quad_omit_mems = []
                    
                    # NOTE(review): this runs before the 2xCO2 differences below, which reuse
                    # doub_ens_mem_list; create_quad_ens_mean must not modify those members in place.
                    quad_ens_mem_list, quad_ens_mean = create_quad_ens_mean(quad_exp_name_list,doub_ens_mem_list,
                                                                            doub_ens_mean,start_yrs,end_yrs,
                                                                            avg_period,var_list,
                                                                            pp_type=pp_type,diag_file=diag_file,
                                                                            omit_mems=quad_omit_mems,debug=debug)

            ## COMPUTE DIFFERENCES ##
            diff_root = f"{prof}_{power_var_suff[index]}_{mem1_start}_{mem1_end}_diff"

            # differences wrt 1860 control
            const_diff_name = f"const_{diff_root}"
            doub_const_ctrl_diff_name = f"doub_{diff_root}_const_ctrl"
            quad_const_ctrl_diff_name = f"quad_{diff_root}_const_ctrl"

            # differences wrt 1860 experiment with same diffusivity history
            doub_1860_diff_name = f"doub_{diff_root}_1860"
            quad_1860_diff_name = f"quad_{diff_root}_1860"

            # differences wrt control for particular CO2 scenario
            doub_2xctrl_diff_name = f"doub_{diff_root}_2xctrl"
            quad_4xctrl_diff_name = f"quad_{diff_root}_4xctrl"

            ## CONST EXPERIMENTS
            if (co2_scen == 'const' or co2_scen == 'const+doub' or co2_scen == 'all'):
                calc_ens_diffs(const_diff_name,const_ctrl_mem_list,
                               const_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)

            ## 2xCO2 EXPERIMENTS
            if (co2_scen == 'doub' or co2_scen == 'const+doub' or co2_scen == 'all'):
                # differences wrt 1860 control
                calc_ens_diffs(doub_const_ctrl_diff_name,const_ctrl_mem_list,
                               doub_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                    
                # differences wrt 1860 experiment with same Kd history
                calc_ens_diffs(doub_1860_diff_name,const_ens_mem_list,
                               doub_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                
                # differences wrt control for particular CO2 scenario
                calc_ens_diffs(doub_2xctrl_diff_name,doub_ctrl_mem_list,
                               doub_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                
            ## 4xCO2 EXPERIMENTS
            # reference lists are subset to match quad_omit_mems so members pair up by position
            if (co2_scen == 'quad' or co2_scen == 'all'):
                # differences wrt 1860 control
                if quad_omit_mems == [2]:
                    calc_ens_diffs(quad_const_ctrl_diff_name,[const_ctrl_mem_list[0],const_ctrl_mem_list[2]],
                               quad_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                elif quad_omit_mems == [1,2]:
                    calc_ens_diffs(quad_const_ctrl_diff_name,[const_ctrl_mem_list[2]],
                               quad_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                else:
                    calc_ens_diffs(quad_const_ctrl_diff_name,const_ctrl_mem_list,
                                   quad_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                    
                # differences wrt 1860 experiment with same Kd history
                if quad_omit_mems == [2]:
                    calc_ens_diffs(quad_1860_diff_name,[const_ens_mem_list[0],const_ens_mem_list[2]],
                                   quad_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                elif quad_omit_mems == [1,2]:
                    calc_ens_diffs(quad_1860_diff_name,[const_ens_mem_list[2]],
                                   quad_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                else:
                    calc_ens_diffs(quad_1860_diff_name,const_ens_mem_list,
                                   quad_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                # differences wrt control for particular CO2 scenario
                # NOTE(review): unlike the diffs above, these calls do not subset their reference list
                # for omitted members, which only matters if quad_omit_mems becomes non-empty
                calc_ens_diffs(quad_4xctrl_diff_name,quad_ctrl_mem_list,
                               quad_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)

                # additional difference calcs for 4xCO2 cases #
                # difference wrt 2xCO2 ctrl
                quad_2xctrl_diff_name = f"quad_{diff_root}_2xctrl"
                calc_ens_diffs(quad_2xctrl_diff_name,doub_ctrl_mem_list,
                               quad_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)
                # difference wrt 2xCO2 experiment with same diffusivity history
                quad_2xCO2_diff_name = f"quad_{diff_root}_2xCO2"
                calc_ens_diffs(quad_2xCO2_diff_name,doub_ens_mem_list,
                               quad_ens_mem_list,var_list,diag_file=diag_file,stdev_ds=stdev_ds,verbose=verbose)


In [7]:
# function to get ensemble mean for every case, but not calculate any differences

def get_ens_means(co2_scen, avg_period, mem1_start, mem1_end, var_list,
                  pp_type='av-annual',
                  diag_file='ocean_monthly_z',
                  profiles=('surf', 'therm', 'mid', 'bot'),
                  power_inputs=('0.1TW', '0.2TW', '0.5TW'),
                  power_var_suff=None,
                  skip_ctrl=False,
                  ctrl_only=False,
                  verbose=False,
                  debug=False):

    """
    Read every requested ensemble and store its ensemble mean (no anomalies) as a global variable.

        Args:
            co2_scen (str): one of ['const','doub','quad','const+doub','all']
            avg_period (int): number of years for av/ts period
            mem1_start (int): start year of ens. mem. #1 (members #2 and #3 start 200 and 400 years later)
            mem1_end (int): end year of ens. mem. #1 (members #2 and #3 end 200 and 400 years later)
            var_list (str list): list of variables to read (e.g. var_list = ["Kd_int_tuned", "Kd_int_base", "Kd_interface"]);
                                 "Kd_int_tuned" and "Kd_int_base" are not read for the control runs
            pp_type (str): post-processing type passed to the readers (default 'av-annual')
            diag_file (str): diagnostic file to read; horizontal means are skipped for 'ocean_scalar_monthly'
            profiles (str sequence): profiles to get data for
            power_inputs (str sequence): power inputs to get data for
            power_var_suff (str sequence or None): variable-name suffix for each power input, in the same order.
                                 If None (default) it is derived from power_inputs by replacing '.' with 'p'
                                 (e.g. '0.1TW' -> '0p1TW').
            skip_ctrl (bool): if True, don't read control data
            ctrl_only (bool): if True, only read control data
            verbose (bool): if True, print var_list and the names of the control variables as they are created
                            (names of perturbation-run variables are always printed)
            debug (bool): passed through to the ensemble readers

        Returns:
            has no return variables, but creates xarray datasets by using myVars = globals():
              * controls (unless skip_ctrl): '<scen>_ctrl_<mem1_start>_<mem1_end>' (ensemble mean),
                '<scen>_ctrl_<mem1_start>_<mem1_end>_mem_list' (list of members) and, unless
                diag_file == 'ocean_scalar_monthly', '<scen>_ctrl_<mem1_start>_<mem1_end>_mean' (horizontal mean)
              * perturbation runs (unless ctrl_only): '<scen>_<prof>_<suffix>_<mem1_start>_<mem1_end>'
                (ensemble mean) and, unless diag_file == 'ocean_scalar_monthly', its '_mean' horizontal mean
              where <scen> is const / doub / quad according to co2_scen. For co2_scen == 'quad' the 2xCO2
              data is still read (create_quad_ens_mean needs it) but is not stored.

        Raises:
            ValueError: if co2_scen is not allowed, or power_var_suff and power_inputs differ in length.
    """
    allowed_scen = ['const', 'doub', 'quad', 'const+doub', 'all']

    # Kd_int_tuned / Kd_int_base are never requested from the control runs
    ctrl_vars_to_drop = ["Kd_int_tuned", "Kd_int_base"]
    ctrl_var_list = [var for var in var_list if var not in ctrl_vars_to_drop]

    if verbose:
        print(var_list)

    if co2_scen not in allowed_scen:
        raise ValueError(f"'co2_scen' must be one of {allowed_scen}.")

    # pair each power input with its suffix explicitly, so a subset of power_inputs
    # can no longer silently pick up the wrong (positional) default suffix
    power_inputs = list(power_inputs)
    if power_var_suff is None:
        power_var_suff = [power_str.replace('.', 'p') for power_str in power_inputs]
    elif len(power_var_suff) != len(power_inputs):
        raise ValueError("'power_var_suff' must have exactly one entry per element of 'power_inputs'.")

    want_const = co2_scen in ('const', 'const+doub', 'all')
    want_doub = co2_scen in ('doub', 'const+doub', 'all')   # 2xCO2 results are stored
    want_quad = co2_scen in ('quad', 'all')
    need_doub = co2_scen != 'const'                         # 2xCO2 data is read (also needed for 4xCO2)

    num_ens_mem = 3

    # members #2 and #3 are offset by 200 and 400 years from member #1
    start_yrs = [mem1_start + 200 * mem for mem in range(num_ens_mem)]
    end_yrs = [mem1_end + 200 * mem for mem in range(num_ens_mem)]

    const_exp_root = '_1860IC_200yr_'
    mem1_doub_exp_root = '_2xCO2_1860IC_200yr_' # exp root for 2xCO2 exps is different for mem1 compared to mem2 and mem3
    mem2_3_doub_exp_root = '_2xCO2_200yr_'
    quad_exp_root = '_4xCO2_51-201_'

    def _store_horiz_mean(ens_name, ens_ds, variables, announce):
        # store the horizontal mean of each variable as '<ens_name>_mean' (not for scalar diagnostics)
        if diag_file == 'ocean_scalar_monthly':
            return
        horiz_avg = xr.Dataset()
        for var in variables:
            if var == 'Kd_int_tuned':
                # zero values of the added diffusivity are excluded from the average; this doesn't make a
                # noticeable difference for global means, but perhaps would be important for regional averages
                kd_added = ens_ds['Kd_int_tuned']
                horiz_avg['Kd_int_tuned'] = horizontal_mean(kd_added.where(kd_added != 0, np.nan), ens_ds)
            else:
                horiz_avg[var] = horizontal_mean(ens_ds[var], ens_ds)
        horiz_avg_name = f"{ens_name}_mean"
        myVars[horiz_avg_name] = horiz_avg
        if announce:
            print(f'{horiz_avg_name} done')

    def _store_ctrl(scen, mem_list, ens_mean):
        # store a control ensemble mean, its member list and its horizontal mean
        ctrl_name = f"{scen}_ctrl_{mem1_start}_{mem1_end}"
        mem_list_name = f"{ctrl_name}_mem_list"
        myVars[ctrl_name] = ens_mean
        myVars[mem_list_name] = mem_list
        if verbose:
            print(f'{ctrl_name}, {mem_list_name} done')
        _store_horiz_mean(ctrl_name, ens_mean, ctrl_var_list, verbose)

    def _store_exp(ens_mean_name, ens_mean):
        # store a perturbation ensemble mean and its horizontal mean (progress is always printed)
        myVars[ens_mean_name] = ens_mean
        print(f'{ens_mean_name} done')
        _store_horiz_mean(ens_mean_name, ens_mean, var_list, True)

    def _exp_names(prof, power_str):
        # experiment names of the three members for each CO2 scenario
        # (member #1 of the 0.5 TW constant-CO2 runs carries a "mem1_" prefix)
        const_mem1_prefix = "mem1_" if power_str == '0.5TW' else ""
        const_exps = [const_mem1_prefix + prof + const_exp_root + power_str,
                      "ens2_" + prof + const_exp_root + power_str,
                      "mem3_" + prof + const_exp_root + power_str]
        doub_exps = [prof + mem1_doub_exp_root + power_str,
                     "ens2_" + prof + mem2_3_doub_exp_root + power_str,
                     "ens3_" + prof + mem2_3_doub_exp_root + power_str]
        quad_exps = [prof + quad_exp_root + power_str,
                     "ens2_" + prof + quad_exp_root + power_str,
                     "ens3_" + prof + quad_exp_root + power_str]
        return const_exps, doub_exps, quad_exps

    ##### CONTROL RUNS #####

    if not skip_ctrl:
        ## const CO2 control ##
        if want_const:
            const_ctrl_exps = ["tune_ctrl_const_200yr",  # "tune_ctrl_1860IC_200yr",
                               "ctrl_1860IC_201-2001",
                               "ctrl_1860IC_201-2001"]
            const_ctrl_mem_list, const_ctrl = create_const_doub_ens_mean(const_ctrl_exps, start_yrs, end_yrs, avg_period,
                                                                         ctrl_var_list, pp_type=pp_type,
                                                                         diag_file=diag_file, debug=debug)
            _store_ctrl('const', const_ctrl_mem_list, const_ctrl)

        ## 2xCO2 control (read for 4xCO2 too, but only stored when 2xCO2 results are requested) ##
        if need_doub:
            doub_ctrl_exps = ["tune_ctrl_2xCO2_1860IC_200yr",
                              "ens2_ctrl_2xCO2_200yr",
                              "ens3_ctrl_2xCO2_200yr"]
            doub_ctrl_mem_list, doub_ctrl = create_const_doub_ens_mean(doub_ctrl_exps, start_yrs, end_yrs, avg_period,
                                                                       ctrl_var_list, pp_type=pp_type,
                                                                       diag_file=diag_file, debug=debug)
            if want_doub:
                _store_ctrl('doub', doub_ctrl_mem_list, doub_ctrl)

        ## 4xCO2 control ##
        if want_quad:
            quad_ctrl_exps = ["tune_ctrl_4xCO2_51-201",
                              "ens2_ctrl_4xCO2_51-201",
                              "ens3_ctrl_4xCO2_51-201"]
            quad_ctrl_omit_mems = []
            quad_ctrl_mem_list, quad_ctrl = create_quad_ens_mean(quad_ctrl_exps, doub_ctrl_mem_list, doub_ctrl,
                                                                 start_yrs, end_yrs, avg_period, ctrl_var_list,
                                                                 pp_type=pp_type, diag_file=diag_file,
                                                                 omit_mems=quad_ctrl_omit_mems, debug=debug)
            _store_ctrl('quad', quad_ctrl_mem_list, quad_ctrl)

    ##### PERTURBATION RUNS #####

    if not ctrl_only:
        for prof in profiles:
            for power_str, power_suff in zip(power_inputs, power_var_suff):
                const_exps, doub_exps, quad_exps = _exp_names(prof, power_str)
                ens_mean_root = f"{prof}_{power_suff}_{mem1_start}_{mem1_end}"

                if want_const:
                    _, const_ens_mean = create_const_doub_ens_mean(const_exps, start_yrs, end_yrs,
                                                                   avg_period, var_list,
                                                                   pp_type=pp_type, diag_file=diag_file, debug=debug)
                    _store_exp(f"const_{ens_mean_root}", const_ens_mean)

                if need_doub:
                    doub_ens_mem_list, doub_ens_mean = create_const_doub_ens_mean(doub_exps, start_yrs, end_yrs,
                                                                                  avg_period, var_list,
                                                                                  pp_type=pp_type, diag_file=diag_file,
                                                                                  debug=debug)
                    if want_doub:
                        _store_exp(f"doub_{ens_mean_root}", doub_ens_mean)

                    if want_quad:
                        quad_omit_mems = []
                        _, quad_ens_mean = create_quad_ens_mean(quad_exps, doub_ens_mem_list, doub_ens_mean,
                                                                start_yrs, end_yrs, avg_period, var_list,
                                                                pp_type=pp_type, diag_file=diag_file,
                                                                omit_mems=quad_omit_mems, debug=debug)
                        _store_exp(f"quad_{ens_mean_root}", quad_ens_mean)
                    

In [8]:
def get_ens_diff_and_means(co2_scen, avg_period, mem1_start, mem1_end, var_list,
                pp_type='av-annual',
                diag_file='ocean_monthly_z',
                profiles=('surf', 'therm', 'mid', 'bot'),
                power_inputs=('0.1TW', '0.2TW', '0.5TW'),
                power_var_suff=None,
                stdev_ds=None,
                verbose=False,
                debug=False):

    """
    Creates dataset variables containing the ensemble-mean raw data and anomaly data. Anomalies are calculated as the difference relative to the control run
    during the period corresponding to an ensemble member (i.e. the anomalies for ensemble member 2 for year 201 to 400 are taken as the difference relative
    to year 201 to 400 of the control run).

        Args:
            co2_scen (str): one of ['const','doub','quad','const+doub','all']
            avg_period (int): number of years for av/ts period
            mem1_start (int): start year of ens. mem. #1 (members #2 and #3 start 200 and 400 years later)
            mem1_end (int): end year of ens. mem. #1 (members #2 and #3 end 200 and 400 years later)
            var_list (str list): list of variables to read (e.g. var_list = ["temp", "N2", "age", "rhopot2", "salt"]);
                                 the control runs are read with the same list, but their horizontal means
                                 omit "Kd_int_tuned" and "Kd_int_base"
            pp_type (str): post-processing type passed to the readers (default 'av-annual')
            diag_file (str): diagnostic file to read; horizontal means are skipped for 'ocean_scalar_monthly'
            profiles (str sequence): profiles to get data for
            power_inputs (str sequence): power inputs to get data for
            power_var_suff (str sequence or None): variable-name suffix for each power input, in the same order.
                                 If None (default) it is derived from power_inputs by replacing '.' with 'p'
                                 (e.g. '0.1TW' -> '0p1TW').
            stdev_ds (dataset): required to compute hatching variable in calc_ens_diffs();
                                typically the standard deviation of the control run, with variables named <var>_stdev
            verbose (bool): if True, print var_list, the current profile/power input and the names of the control
                            variables as they are created (names of perturbation-run variables are always printed);
                            also passed to calc_ens_diffs()
            debug (bool): passed through to the ensemble readers

        Returns:
            has no return variables, but creates xarray datasets by using myVars = globals():
              * controls: '<scen>_ctrl_<mem1_start>_<mem1_end>', its '_mem_list' and (unless
                diag_file == 'ocean_scalar_monthly') its '_mean' horizontal mean
              * perturbation runs: '<scen>_<prof>_<suffix>_<mem1_start>_<mem1_end>' and its '_mean'
              * differences, computed by calc_ens_diffs() under names built from
                '<scen>_<prof>_<suffix>_<mem1_start>_<mem1_end>_diff' (+ '_const_ctrl', '_1860', '_2xctrl',
                '_4xctrl', '_2xCO2' depending on the reference) and 'doub_ctrl_..._diff' / 'quad_ctrl_..._diff_*'
              The constant-CO2 control and experiments are always read, as they are the references for the
              2xCO2 and 4xCO2 differences.

        Raises:
            ValueError: if co2_scen is not allowed, or power_var_suff and power_inputs differ in length.
    """
    allowed_scen = ['const', 'doub', 'quad', 'const+doub', 'all']

    if co2_scen not in allowed_scen:
        raise ValueError(f"'co2_scen' must be one of {allowed_scen}.")

    # pair each power input with its suffix explicitly, so a subset of power_inputs
    # can no longer silently pick up the wrong (positional) default suffix
    power_inputs = list(power_inputs)
    if power_var_suff is None:
        power_var_suff = [power_str.replace('.', 'p') for power_str in power_inputs]
    elif len(power_var_suff) != len(power_inputs):
        raise ValueError("'power_var_suff' must have exactly one entry per element of 'power_inputs'.")

    # variables used for the control-run horizontal means
    ctrl_vars_to_drop = ["Kd_int_tuned", "Kd_int_base"]
    ctrl_var_list = [var for var in var_list if var not in ctrl_vars_to_drop]

    if verbose:
        print(var_list)

    want_const = co2_scen in ('const', 'const+doub', 'all')
    want_doub = co2_scen in ('doub', 'const+doub', 'all')
    want_quad = co2_scen in ('quad', 'all')
    need_doub = co2_scen != 'const'   # 2xCO2 data is read (also needed for 4xCO2)

    num_ens_mem = 3

    # members #2 and #3 are offset by 200 and 400 years from member #1
    start_yrs = [mem1_start + 200 * mem for mem in range(num_ens_mem)]
    end_yrs = [mem1_end + 200 * mem for mem in range(num_ens_mem)]

    const_exp_root = '_1860IC_200yr_'
    mem1_doub_exp_root = '_2xCO2_1860IC_200yr_' # exp root for 2xCO2 exps is different for mem1 compared to mem2 and mem3
    mem2_3_doub_exp_root = '_2xCO2_200yr_'
    quad_exp_root = '_4xCO2_51-201_'

    def _store_horiz_mean(ens_name, ens_ds, variables, announce):
        # store the horizontal mean of each variable as '<ens_name>_mean' (not for scalar diagnostics)
        if diag_file == 'ocean_scalar_monthly':
            return
        horiz_avg = xr.Dataset()
        for var in variables:
            if var == 'Kd_int_tuned':
                # zero values of the added diffusivity are excluded from the average; this doesn't make a
                # noticeable difference for global means, but perhaps would be important for regional averages
                kd_added = ens_ds['Kd_int_tuned']
                horiz_avg['Kd_int_tuned'] = horizontal_mean(kd_added.where(kd_added != 0, np.nan), ens_ds)
            else:
                horiz_avg[var] = horizontal_mean(ens_ds[var], ens_ds)
        horiz_avg_name = f"{ens_name}_mean"
        myVars[horiz_avg_name] = horiz_avg
        if announce:
            print(f'{horiz_avg_name} done')

    def _store_ctrl(scen, mem_list, ens_mean):
        # store a control ensemble mean, its member list and its horizontal mean
        ctrl_name = f"{scen}_ctrl_{mem1_start}_{mem1_end}"
        mem_list_name = f"{ctrl_name}_mem_list"
        myVars[ctrl_name] = ens_mean
        myVars[mem_list_name] = mem_list
        if verbose:
            print(f'{ctrl_name}, {mem_list_name} done')
        _store_horiz_mean(ctrl_name, ens_mean, ctrl_var_list, verbose)

    def _store_exp(ens_mean_name, ens_mean):
        # store a perturbation ensemble mean and its horizontal mean (progress is always printed)
        myVars[ens_mean_name] = ens_mean
        print(f'{ens_mean_name} done')
        _store_horiz_mean(ens_mean_name, ens_mean, var_list, True)

    def _diff(diff_name, ref_mem_list, exp_mem_list):
        # member-wise anomalies of exp_mem_list relative to ref_mem_list
        calc_ens_diffs(diff_name, ref_mem_list, exp_mem_list, var_list,
                       diag_file=diag_file, stdev_ds=stdev_ds, verbose=verbose)

    def _drop_members(mem_list, omit_mems):
        # reference members matching a 4xCO2 ensemble built with omit_mems (1-based member numbers):
        # [] -> all members, [2] -> members 1 and 3, [1, 2] -> member 3
        if not omit_mems:
            return mem_list
        return [mem for num, mem in enumerate(mem_list, start=1) if num not in omit_mems]

    def _exp_names(prof, power_str):
        # experiment names of the three members for each CO2 scenario
        # (member #1 of the 0.5 TW constant-CO2 runs carries a "mem1_" prefix)
        const_mem1_prefix = "mem1_" if power_str == '0.5TW' else ""
        const_exps = [const_mem1_prefix + prof + const_exp_root + power_str,
                      "ens2_" + prof + const_exp_root + power_str,
                      "mem3_" + prof + const_exp_root + power_str]
        doub_exps = [prof + mem1_doub_exp_root + power_str,
                     "ens2_" + prof + mem2_3_doub_exp_root + power_str,
                     "ens3_" + prof + mem2_3_doub_exp_root + power_str]
        quad_exps = [prof + quad_exp_root + power_str,
                     "ens2_" + prof + quad_exp_root + power_str,
                     "ens3_" + prof + quad_exp_root + power_str]
        return const_exps, doub_exps, quad_exps

    ##### CONTROL RUNS #####

    ## const CO2 control ##
    const_ctrl_exps = ["tune_ctrl_const_200yr",  # "tune_ctrl_1860IC_200yr",
                       "ctrl_1860IC_201-2001",   # tune_ctrl_1860IC_201-2001
                       "ctrl_1860IC_201-2001"]
    const_ctrl_mem_list, const_ctrl = create_const_doub_ens_mean(const_ctrl_exps, start_yrs, end_yrs, avg_period,
                                                                 var_list, pp_type=pp_type,
                                                                 diag_file=diag_file, debug=debug)
    _store_ctrl('const', const_ctrl_mem_list, const_ctrl)

    ## 2xCO2 control ##
    if need_doub:
        doub_ctrl_exps = ["tune_ctrl_2xCO2_1860IC_200yr",
                          "ens2_ctrl_2xCO2_200yr",
                          "ens3_ctrl_2xCO2_200yr"]
        doub_ctrl_mem_list, doub_ctrl = create_const_doub_ens_mean(doub_ctrl_exps, start_yrs, end_yrs, avg_period,
                                                                   var_list, pp_type=pp_type,
                                                                   diag_file=diag_file, debug=debug)
        _store_ctrl('doub', doub_ctrl_mem_list, doub_ctrl)

    if want_doub:
        # differences compared to constant CO2 control #
        _diff(f"doub_ctrl_{mem1_start}_{mem1_end}_diff", const_ctrl_mem_list, doub_ctrl_mem_list)

    ## 4xCO2 control ##
    if want_quad:
        quad_ctrl_exps = ["tune_ctrl_4xCO2_51-201",
                          "ens2_ctrl_4xCO2_51-201",
                          "ens3_ctrl_4xCO2_51-201"]
        quad_ctrl_omit_mems = []
        quad_ctrl_mem_list, quad_ctrl = create_quad_ens_mean(quad_ctrl_exps, doub_ctrl_mem_list, doub_ctrl,
                                                             start_yrs, end_yrs, avg_period, var_list,
                                                             pp_type=pp_type, diag_file=diag_file,
                                                             omit_mems=quad_ctrl_omit_mems, debug=debug)
        _store_ctrl('quad', quad_ctrl_mem_list, quad_ctrl)

        # differences compared to constant CO2 and 2xCO2 controls #
        quad_ctrl_diff_root = f"quad_ctrl_{mem1_start}_{mem1_end}_diff"
        _diff(f"{quad_ctrl_diff_root}_const_ctrl",
              _drop_members(const_ctrl_mem_list, quad_ctrl_omit_mems), quad_ctrl_mem_list)
        _diff(f"{quad_ctrl_diff_root}_2xctrl",
              _drop_members(doub_ctrl_mem_list, quad_ctrl_omit_mems), quad_ctrl_mem_list)

    ##### PERTURBATION RUNS #####

    for prof in profiles:
        for power_str, power_suff in zip(power_inputs, power_var_suff):
            if verbose:
                print(f"{prof} {power_str} experiments")

            const_exps, doub_exps, quad_exps = _exp_names(prof, power_str)
            ens_mean_root = f"{prof}_{power_suff}_{mem1_start}_{mem1_end}"

            # constant-CO2 experiments are always read: they are the "_1860" reference for 2xCO2/4xCO2
            const_ens_mem_list, const_ens_mean = create_const_doub_ens_mean(const_exps, start_yrs, end_yrs,
                                                                            avg_period, var_list,
                                                                            pp_type=pp_type, diag_file=diag_file,
                                                                            debug=debug)
            _store_exp(f"const_{ens_mean_root}", const_ens_mean)

            if need_doub:
                doub_ens_mem_list, doub_ens_mean = create_const_doub_ens_mean(doub_exps, start_yrs, end_yrs,
                                                                              avg_period, var_list,
                                                                              pp_type=pp_type, diag_file=diag_file,
                                                                              debug=debug)
                _store_exp(f"doub_{ens_mean_root}", doub_ens_mean)

            if want_quad:
                quad_omit_mems = []
                quad_ens_mem_list, quad_ens_mean = create_quad_ens_mean(quad_exps, doub_ens_mem_list, doub_ens_mean,
                                                                        start_yrs, end_yrs, avg_period, var_list,
                                                                        pp_type=pp_type, diag_file=diag_file,
                                                                        omit_mems=quad_omit_mems, debug=debug)
                _store_exp(f"quad_{ens_mean_root}", quad_ens_mean)

            ## COMPUTE DIFFERENCES ##
            diff_root = f"{prof}_{power_suff}_{mem1_start}_{mem1_end}_diff"

            ## CONST EXPERIMENTS: wrt 1860 control
            if want_const:
                _diff(f"const_{diff_root}", const_ctrl_mem_list, const_ens_mem_list)

            ## 2xCO2 EXPERIMENTS
            if want_doub:
                # wrt 1860 control
                _diff(f"doub_{diff_root}_const_ctrl", const_ctrl_mem_list, doub_ens_mem_list)
                # wrt 1860 experiment with same Kd history
                _diff(f"doub_{diff_root}_1860", const_ens_mem_list, doub_ens_mem_list)
                # wrt control for particular CO2 scenario
                _diff(f"doub_{diff_root}_2xctrl", doub_ctrl_mem_list, doub_ens_mem_list)

            ## 4xCO2 EXPERIMENTS
            if want_quad:
                # wrt 1860 control
                _diff(f"quad_{diff_root}_const_ctrl",
                      _drop_members(const_ctrl_mem_list, quad_omit_mems), quad_ens_mem_list)
                # wrt 1860 experiment with same Kd history
                _diff(f"quad_{diff_root}_1860",
                      _drop_members(const_ens_mem_list, quad_omit_mems), quad_ens_mem_list)
                # wrt control for particular CO2 scenario
                _diff(f"quad_{diff_root}_4xctrl", quad_ctrl_mem_list, quad_ens_mem_list)
                # wrt 2xCO2 ctrl
                # NOTE(review): the 2xCO2 references are not reduced by quad_omit_mems (as before);
                # revisit if member omission is re-enabled
                _diff(f"quad_{diff_root}_2xctrl", doub_ctrl_mem_list, quad_ens_mem_list)
                # wrt 2xCO2 experiment with same diffusivity history
                _diff(f"quad_{diff_root}_2xCO2", doub_ens_mem_list, quad_ens_mem_list)
