In [2]:
import functools
import os
import warnings

import dask
import numpy as np
import pandas as pd
import xarray as xr

### Preliminaries

In [3]:
###############################
# Set paths
# UPDATE THIS FOR REPRODUCTION
###############################
# Single root of the data tree: every path below is derived from it, so relocating
# the project only requires editing this one line.
data_root = "/gpfs/group/kaf26/default/dcl5300/lafferty-sriver_inprep_tbh_DATA/"

# Inputs (all paths keep their trailing "/" because file names are appended by
# plain string concatenation further down)
nex_in = data_root + "metrics/nex-gddp/"  # location of NEX-GDDP metrics
cil_in = data_root + "metrics/cil-gdpcir/"  # location of CIL-GDPCIR metrics
isi_in = data_root + "metrics/isimip3b/regridded/conservative/"  # location of *regridded* ISIMIP metrics
cbp_in = data_root + "metrics/carbonplan/"  # location of carbonplan metrics

# Outputs
out_path = data_root + "results/"  # where to store UC results
poly_path = data_root + "forced_response/"  # where to store extracted forced responses
iav_path = data_root + "interannual_variability/"  # where to store rolling IAV estimates

In [4]:
###################
# Models
###################
from utils import (
    cil_ssp_dict,
    deepsdbc_dict,
    gardsv_ssp_dict,
    gardsv_var_dict,
    isimip_ssp_dict,
    nex_ssp_dict,
)

# Model names are the keys of the corresponding dictionaries imported from utils
nex_models = [*nex_ssp_dict.keys()]
cil_models = [*cil_ssp_dict.keys()]
isi_models = [*isimip_ssp_dict.keys()]
cbp_gard_models = [*gardsv_ssp_dict.keys()]
cbp_deep_models = [*deepsdbc_dict.keys()]

# GARD-SV models whose entry in gardsv_var_dict contains "pr"
cbp_gard_precip_models = [name for name in cbp_gard_models if "pr" in gardsv_var_dict[name]]

In [5]:
#######################
# Land mask (from NEX)
#######################
# NOTE: despite its name the mask is True where the reference `tas` field is missing
# (isnull); those are the points later blanked with xr.where(land_mask, np.nan, ...).
land_mask = xr.open_dataset(nex_in + "avg/CanESM5.nc").isel(ssp=0, time=0).tas.isnull()

# Shift longitudes above 180 down by 360, then restore ascending order
wrapped_lon = np.where(land_mask["lon"] > 180, land_mask["lon"] - 360, land_mask["lon"])
land_mask = land_mask.assign_coords(lon=wrapped_lon).sortby("lon")

In [6]:
############
# Dask
############
from dask.distributed import Client
from dask_jobqueue import PBSCluster

# Each job is configured with 1 core, 40GB of memory and a 6 hour walltime.
# NOTE(review): "#PBS -l feature=rhel7" looks like a job-script directive rather than
# a dask worker command-line argument -- confirm whether it was meant to be passed
# through `job_extra_directives` instead.
cluster = PBSCluster(
    cores=1,
    memory="40GB",
    resource_spec="pmem=40GB",
    # account='open',
    worker_extra_args=["#PBS -l feature=rhel7"],
    walltime="06:00:00",
)
cluster.scale(jobs=40)  # ask for jobs

# Connect a client to the cluster
client = Client(cluster)

client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.102.201.224:42639,Workers: 0
Dashboard: /proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# Total uncertainty: gridded metrics

### Function definition

In [None]:
#######################################################################
# Total uncertainty: variance across all models, scenarios, ensembles
#######################################################################
def uc_total(
    nex_in,
    nex_models,
    cil_in,
    cil_models,
    isi_in,
    isi_models,
    cbp_in,
    cbp_gard_models,
    cbp_deep_models,
    land_mask,
    metric,
    submetric,
    year,
):
    """
    Reads in all models, ssps, and calculates the total uncertainty (variance across
    all model, ssp, ensemble dimensions) for a given year (and possibly DataArray).
    For metrics like 'hot' where there are several sub-metrics based on different
    thresholds and/or observational data, we need to select a specific DataArray
    to keep the memory manageable.

    Parameters
    ----------
    nex_in, cil_in, isi_in, cbp_in : str
        Root directories of the NEX, CIL, ISIMIP and carbonplan metrics. File names
        are appended by plain string concatenation, so each should end with "/".
    nex_models, cil_models, isi_models, cbp_gard_models, cbp_deep_models : list of str
        File stems (netCDF) or store names (zarr) to read for each ensemble.
    land_mask : xarray.DataArray
        Boolean mask; points where it is True are set to NaN before the variance.
    metric : str
        Name of the sub-directory holding the metric files.
    submetric : list of str or False
        Variables to keep from each file; any falsy value keeps all of them.
    year : int
        Year to select along the time axis.

    Returns
    -------
    xarray object
        Variance over ("ensemble", "ssp", "model"), tagged with the scalar
        coordinates uncertainty="total_true" and time=year.

    Raises
    ------
    ValueError
        If an ensemble label has no known storage format.
    """

    # Suffixes appended to model file names for the temperature-based 'hot' metrics.
    # "_tas" must be stripped as well, otherwise those models keep a suffixed name.
    var_suffixes = ("_tasmin", "_tasmax", "_tas")

    # Subfunction for general preprocessing of each model/ensemble
    def read_and_process(ensemble, path_in, model, year, metric, submetric):
        # Read netcdf or zarr
        if ensemble in ["NEX", "ISIMIP", "GARD-SV"]:
            ds = xr.open_dataset(path_in + metric + "/" + model + ".nc")
        elif ensemble in ["CIL", "DeepSD-BC"]:
            ds = xr.open_dataset(path_in + metric + "/" + model, engine="zarr")
        else:
            # Fail loudly instead of hitting an unbound `ds` below
            raise ValueError("Unknown ensemble: " + str(ensemble))

        # Select submetric if chosen
        if submetric:
            ds = ds[submetric]

        # Common preprocessing
        ds["time"] = ds.indexes["time"].year  # index by integer year
        ds = ds.sel(time=year)
        ds = ds.sortby("ssp")
        ds = ds.assign_coords(ensemble=ensemble)
        ds = ds.sel(lat=slice(-60, 90))

        # Add model dimension (without any variable suffix from the file name)
        model_str = model
        for suffix in var_suffixes:
            if model_str.endswith(suffix):
                model_str = model_str[: -len(suffix)]
                break
        ds = ds.assign_coords(model=model_str)

        # Fix lon to [-180,180]
        if ds.lon.max() > 180:
            ds["lon"] = np.where(ds["lon"] > 180, ds["lon"] - 360, ds["lon"])
            ds = ds.sortby("lon")

        # Some models/methods are missing precip so fill with NaNs
        if (metric in ["max", "avg"]) and ("pr" not in ds.data_vars):
            ds["pr"] = xr.full_like(ds[list(ds.data_vars)[0]], np.nan)
        if (metric == "max5d") and ("RX5day" not in ds.data_vars):
            ds["RX5day"] = xr.full_like(ds[list(ds.data_vars)[0]], np.nan)

        # Drop member_id (keep the first member only)
        if "member_id" in list(ds.coords):
            ds = ds.isel(member_id=0).drop_vars("member_id")

        # Return
        return ds

    # Read every model of one ensemble and stack them along a new "model" dimension
    def read_ensemble(ensemble, path_in, models):
        per_model = [read_and_process(ensemble, path_in, model, year, metric, submetric) for model in models]
        return xr.concat(per_model, dim="model", fill_value=np.nan)

    ######################
    # Read all ensembles
    ######################
    ds_nex = read_ensemble("NEX", nex_in, nex_models)
    ds_cil = read_ensemble("CIL", cil_in, cil_models)
    ds_isi = read_ensemble("ISIMIP", isi_in, isi_models)
    ds_cbp_gard = read_ensemble("GARD-SV", cbp_in + "/regridded/conservative/GARD-SV/", cbp_gard_models)
    ds_cbp_deep = read_ensemble("DeepSD-BC", cbp_in + "native_grid/DeepSD-BC/", cbp_deep_models)

    ###########################
    # Merge all and mask ocean
    ###########################
    ds = xr.concat(
        [ds_nex, ds_cil, ds_isi, ds_cbp_gard, ds_cbp_deep],
        dim="ensemble",
        fill_value=np.nan,
    )

    # Mask out ocean points
    ds = xr.where(land_mask, np.nan, ds)

    ##########################
    # Uncertainty calculation
    ##########################
    ## Total uncertainty
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        U_total_true = ds.var(dim=["ensemble", "ssp", "model"], skipna=True)  # throws warning when all NaNs

    U_total_true = U_total_true.assign_coords(uncertainty="total_true")

    # Tag the result with its year, as uc_forced does, so that concatenating the
    # yearly results along "time" yields a labelled axis.
    # NOTE(review): the scalar time coordinate presumably does not survive the
    # xr.where above when land_mask carries its own `time` -- confirm.
    U_total_true = U_total_true.assign_coords(time=year)

    return U_total_true

### Computation

In [8]:
%%time

metric = "avg"

# Dask delayed over years
delayed_res = []
for year in range(2015, 2100):
    # Read all ensembles and compute total uncertainty
    tmp_res = dask.delayed(uc_total)(
        nex_in,
        nex_models,
        cil_in,
        cil_models,
        isi_in,
        isi_models,
        cbp_in,
        cbp_gard_models,
        cbp_deep_models,
        land_mask,
        metric,
        False,
        year,
    )

    # Append
    delayed_res.append(tmp_res)

# Compute
res = dask.compute(*delayed_res)

# Merge and store
ds_out = xr.concat(res, dim="time")
ds_out.to_netcdf(out_path + "total_uncertainty/" + metric + ".nc")

CPU times: user 7min 6s, sys: 19.6 s, total: 7min 25s
Wall time: 53min 50s


In [9]:
%%time

metric = "max"

# Dask delayed over years
delayed_res = []
for year in range(2015, 2100):
    # Read all ensembles and compute total uncertainty
    tmp_res = dask.delayed(uc_total)(
        nex_in,
        nex_models,
        cil_in,
        cil_models,
        isi_in,
        isi_models,
        cbp_in,
        cbp_gard_models,
        cbp_deep_models,
        land_mask,
        metric,
        False,
        year,
    )

    # Append
    delayed_res.append(tmp_res)

# Compute
res = dask.compute(*delayed_res)

# Merge and store
ds_out = xr.concat(res, dim="time")
ds_out.to_netcdf(out_path + "total_uncertainty/" + metric + ".nc")

CPU times: user 10min 15s, sys: 26.9 s, total: 10min 42s
Wall time: 1h 19min 52s


In [None]:
%%time

metric = "max5d"

# Dask delayed over years
delayed_res = []
for year in range(2015, 2100):
    # Read all ensembles and compute total uncertainty
    tmp_res = dask.delayed(uc_total)(
        nex_in,
        nex_models,
        cil_in,
        cil_models,
        isi_in,
        isi_models,
        cbp_in,
        cbp_gard_models,
        cbp_deep_models,
        land_mask,
        metric,
        False,
        year,
    )

    # Append
    delayed_res.append(tmp_res)

# Compute
res = dask.compute(*delayed_res)

# Merge and store
ds_out = xr.concat(res, dim="time")
ds_out.to_netcdf(out_path + "total_uncertainty/" + metric + ".nc")

In [11]:
%%time

metric = "dry"

# Dask delayed over years
delayed_res = []
for year in range(2015, 2100):
    # Read all ensembles and compute total uncertainty
    tmp_res = dask.delayed(uc_total)(
        nex_in,
        nex_models,
        cil_in,
        cil_models,
        isi_in,
        isi_models,
        cbp_in,
        cbp_gard_precip_models,
        cbp_deep_models,
        land_mask,
        metric,
        False,
        year,
    )

    # Append
    delayed_res.append(tmp_res)

# Compute
res = dask.compute(*delayed_res)

# Merge and store
ds_out = xr.concat(res, dim="time")
ds_out.to_netcdf(out_path + "total_uncertainty/" + metric + ".nc")

CPU times: user 3min 21s, sys: 10.3 s, total: 3min 32s
Wall time: 26min 46s


In [16]:
%%time
# Hot + dry days
metric = "hotdry"
for thresh in ["q99", "rp10"]:
    for obs in ["gmfd", "era5"]:
        submetric_str = thresh + obs
        submetric = [
            "hotdry_" + submetric_str + "_count",
            "hotdry_" + submetric_str + "_streak",
        ]

        # Dask delayed over years
        delayed_res = []
        for year in range(2015, 2100):
            # Read all ensembles and compute total uncertainty
            tmp_res = dask.delayed(uc_total)(
                nex_in,
                nex_models,
                cil_in,
                cil_models,
                isi_in,
                isi_models,
                cbp_in,
                cbp_gard_precip_models,
                cbp_deep_models,
                land_mask,
                metric,
                submetric,
                year,
            )

            # Append
            delayed_res.append(tmp_res)

        # Compute
        res = dask.compute(*delayed_res)

        # Merge and store
        ds_out = xr.concat(res, dim="time")
        ds_out.to_netcdf(out_path + "total_uncertainty/" + metric + "_" + submetric_str + ".nc")

CPU times: user 53.5 s, sys: 9.57 s, total: 1min 3s
Wall time: 10min 15s


In [17]:
%%time
# Wet days
metric = "wet"
for thresh in ["q99", "rp10"]:
    for obs in ["gmfd", "era5"]:
        submetric_str = thresh + obs
        submetric = [
            "pr_" + submetric_str + "_count",
            "pr_" + submetric_str + "_streak",
        ]

        # Dask delayed over years
        delayed_res = []
        for year in range(2015, 2100):
            # Read all ensembles and compute total uncertainty
            tmp_res = dask.delayed(uc_total)(
                nex_in,
                nex_models,
                cil_in,
                cil_models,
                isi_in,
                isi_models,
                cbp_in,
                cbp_gard_precip_models,
                cbp_deep_models,
                land_mask,
                metric,
                submetric,
                year,
            )

            # Append
            delayed_res.append(tmp_res)

        # Compute
        res = dask.compute(*delayed_res)

        # Merge and store
        ds_out = xr.concat(res, dim="time")
        ds_out.to_netcdf(out_path + "total_uncertainty/" + metric + "_" + submetric_str + ".nc")

CPU times: user 58 s, sys: 9.06 s, total: 1min 7s
Wall time: 11min 12s


In [7]:
%%time
# Hot days
metric = "hot"
for thresh in ["q99", "rp10"]:
    for obs in ["gmfd", "era5"]:
        for submetric_var in ["tas", "tasmin", "tasmax"]:
            submetric_str = submetric_var + "_" + thresh + obs
            submetric = [submetric_str + "_count", submetric_str + "_streak"]

            # Dask delayed over years
            delayed_res = []
            for year in range(2015, 2100):
                # Read all ensembles and compute total uncertainty
                tmp_res = dask.delayed(uc_total)(
                    nex_in,
                    [model + "_" + submetric_var for model in nex_models],
                    cil_in,
                    cil_models,
                    isi_in,
                    [model + "_" + submetric_var for model in isi_models],
                    cbp_in,
                    [model + "_" + submetric_var for model in cbp_gard_models],
                    cbp_deep_models,
                    land_mask,
                    metric,
                    submetric,
                    year,
                )

                # Append
                delayed_res.append(tmp_res)

            # Compute
            res = dask.compute(*delayed_res)

            # Merge and store
            ds_out = xr.concat(res, dim="time")
            ds_out.to_netcdf(out_path + "total_uncertainty/" + metric + "_" + submetric_str + ".nc")

CPU times: user 7min 42s, sys: 33.1 s, total: 8min 15s
Wall time: 58min 34s


# Uncertainty partitioning: gridded metrics

### Extracting the forced response

In [16]:
def get_forced_poly(
    nex_in,
    nex_models,
    cil_in,
    cil_models,
    isi_in,
    isi_models,
    cbp_in,
    cbp_gard_models,
    cbp_deep_models,
    land_mask,
    metric,
    submetric,
    submetric_var,
    poly_path,
    deg,
):
    """
    Reads in all models, ssps, and calculates the 'forced response' as a deg-th order
    polynomial.
    For metrics like 'hot' where there are several sub-metrics based on different
    thresholds and/or observational data, we need to select a specific DataArray
    to keep the memory manageable.

    Parameters
    ----------
    nex_in, cil_in, isi_in, cbp_in : str
        Root directories of the NEX, CIL, ISIMIP and carbonplan metrics. File names
        are appended by plain string concatenation, so each should end with "/".
    nex_models, cil_models, isi_models, cbp_gard_models, cbp_deep_models : list of str
        Model names to process for each ensemble.
    land_mask :
        Not used by this function.
    metric : str
        Name of the metric sub-directory (for both the inputs and the outputs).
    submetric : list of str or False
        Variables to keep from each file; any falsy value keeps all of them.
    submetric_var : str or False
        If truthy, appended (after "_") to netCDF file names when reading.
    poly_path : str
        Root directory for the fitted responses.
    deg : int
        Degree of the polynomial fitted along time.

    Returns
    -------
    None
        One netCDF file is written per ensemble/model; existing files are skipped.

    Raises
    ------
    ValueError
        If an ensemble label has no known storage format.
    """

    # Single source of truth for the output file name (used both for the existence
    # check and for writing): assumes submetrics are of the form ['X_count', 'X_streak']
    def construct_out_str(poly_path, ensemble, model, metric, submetric, submetric_var, deg):
        out_str = poly_path + metric + "/"

        if submetric:
            submetric_str = submetric[0].replace("_count", "").replace("_streak", "")
            out_str = out_str + submetric_str + "_"

        return out_str + ensemble + "_" + model + "_deg" + str(deg) + ".nc"

    # Subfunction to get polynomial for each model/ensemble
    def get_poly_coeffs(ensemble, path_in, model, metric, submetric, submetric_var, deg):
        # Read netcdf or zarr
        if ensemble in ["NEX", "ISIMIP", "GARD-SV"]:
            if submetric_var:
                ds = xr.open_dataset(path_in + metric + "/" + model + "_" + submetric_var + ".nc")
            else:
                ds = xr.open_dataset(path_in + metric + "/" + model + ".nc")
        elif ensemble in ["CIL", "DeepSD-BC"]:
            ds = xr.open_dataset(path_in + metric + "/" + model, engine="zarr")
        else:
            # Fail loudly instead of hitting an unbound `ds` below
            raise ValueError("Unknown ensemble: " + str(ensemble))

        # Select submetric if chosen
        if submetric:
            ds = ds[submetric]

        # Common preprocessing
        ds = ds.sel(lat=slice(-60, 90))
        ds = ds.sortby("ssp")
        if ds.lon.max() > 180:
            ds["lon"] = np.where(ds["lon"] > 180, ds["lon"] - 360, ds["lon"])
            ds = ds.sortby("lon")

        # Forced response via polynomial: fit along time, then evaluate the fit on
        # the original time axis and restore the original variable names
        ds = xr.polyval(coord=ds["time"], coeffs=ds.polyfit(dim="time", deg=deg))
        ds = ds.rename({name: name.replace("_polyfit_coefficients", "") for name in list(ds.data_vars)})

        # Drop member_id (keep the first member only)
        if "member_id" in list(ds.coords):
            ds = ds.isel(member_id=0).drop_vars("member_id")

        # Write, creating the metric directory on first use
        out_file = construct_out_str(poly_path, ensemble, model, metric, submetric, submetric_var, deg)
        out_dir = os.path.dirname(out_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        ds.to_netcdf(out_file)

    #######################
    # Apply to all ensembles
    #######################
    # (ensemble label, input directory, models) for every ensemble
    ensemble_specs = [
        ("NEX", nex_in, nex_models),
        ("CIL", cil_in, cil_models),
        ("ISIMIP", isi_in, isi_models),
        ("GARD-SV", cbp_in + "/regridded/conservative/GARD-SV/", cbp_gard_models),
        ("DeepSD-BC", cbp_in + "native_grid/DeepSD-BC/", cbp_deep_models),
    ]

    # Dask delayed: only schedule fits whose output file does not exist yet
    res = []
    for ensemble, path_in, models in ensemble_specs:
        for model in models:
            if not os.path.isfile(construct_out_str(poly_path, ensemble, model, metric, submetric, submetric_var, deg)):
                res.append(
                    dask.delayed(get_poly_coeffs)(ensemble, path_in, model, metric, submetric, submetric_var, deg)
                )

    # Compute
    if len(res) > 0:
        dask.compute(*res)

### Uncertainty characterization of forced response

In [7]:
def uc_forced(
    nex_in,
    nex_models,
    cil_in,
    cil_models,
    isi_in,
    isi_models,
    cbp_in,
    cbp_gard_models,
    cbp_deep_models,
    land_mask,
    metric,
    submetric,
    submetric_var,
    year,
    poly_path,
    deg,
    weighted,
):
    """
    Reads in all models, ssps, and calculates the uncertainty in the 'forced response'
    along each dimension (ssp, model, ens) for a given year (and possibly DataArray).
    For metrics like 'hot' where there are several sub-metrics based on different
    thresholds and/or observational data, we need to select a specific DataArray
    to keep the memory manageable.

    Parameters
    ----------
    nex_in, cil_in, isi_in, cbp_in : str
        Forwarded to the reader but not used by it: the forced responses are read
        from poly_path.
    nex_models, cil_models, isi_models, cbp_gard_models, cbp_deep_models : list of str
        Model names to read for each ensemble.
    land_mask : xarray.DataArray
        Boolean mask; points where it is True are set to NaN before any variance.
    metric : str
        Name of the metric sub-directory under poly_path.
    submetric : list of str or False
        Variables to keep; also determines the file-name prefix. Falsy keeps all.
    submetric_var :
        Forwarded to the reader but not used by it.
    year : int
        Year to select along the time axis.
    poly_path : str
        Root directory of the pre-computed forced responses.
    deg : int
        Polynomial degree, part of the file name.
    weighted : bool
        If truthy, the model and ensemble variances are averaged with weights equal
        to the number of entries behind each variance (0 when there was only one).

    Returns
    -------
    xarray object
        The four estimates stacked along "uncertainty" ("scenario_hs09",
        "scenario_bb13", "model", "ensemble"), tagged with time=year.
    """

    # Subfunction for general preprocessing of each model/ensemble
    def read_and_process(ensemble, path_in, model, year, metric, submetric, submetric_var, deg):
        # Polynomial responses should have already been calculated
        poly_str = poly_path + metric + "/"
        if submetric:
            submetric_str = submetric[0].replace("_count", "").replace("_streak", "")
            poly_str = poly_str + submetric_str + "_"
        ds = xr.open_dataset(poly_str + ensemble + "_" + model + "_deg" + str(deg) + ".nc")
        ds["time"] = ds.indexes["time"].year  # index by integer year
        ds = ds.sel(time=year)

        # Select submetric if chosen
        if submetric:
            ds = ds[submetric]

        # Common preprocessing
        ds = ds.sel(lat=slice(-60, 90))
        ds = ds.sortby("ssp")
        ds = ds.assign_coords(ensemble=ensemble)
        ds = ds.assign_coords(model=model)

        # Fix lon to [-180,180]
        if ds.lon.max() > 180:
            ds["lon"] = np.where(ds["lon"] > 180, ds["lon"] - 360, ds["lon"])
            ds = ds.sortby("lon")

        # Some models/methods are missing precip so fill with NaNs
        if (metric in ["max", "avg"]) and ("pr" not in ds.data_vars):
            ds["pr"] = xr.full_like(ds[list(ds.data_vars)[0]], np.nan)
        if (metric == "max5d") and ("RX5day" not in ds.data_vars):
            ds["RX5day"] = xr.full_like(ds[list(ds.data_vars)[0]], np.nan)

        # Drop member_id (keep the first member only)
        if "member_id" in list(ds.coords):
            ds = ds.isel(member_id=0).drop_vars("member_id")

        # Forced response should always be >= 0.
        # Only negative values are replaced: `ds >= 0` is False for NaN, so testing
        # on it would turn missing data into a real 0 that then enters the variances
        # and the weight counts below.
        if metric in ["hot", "wet", "dry", "hotdry"]:
            ds = ds.where(~(ds < 0), 0)

        # Return
        return ds

    # Read every model of one ensemble and stack them along a new "model" dimension
    def read_ensemble(ensemble, path_in, models):
        per_model = [
            read_and_process(ensemble, path_in, model, year, metric, submetric, submetric_var, deg)
            for model in models
        ]
        return xr.concat(per_model, dim="model", fill_value=np.nan)

    ######################
    # Read all ensembles
    ######################
    ds_nex = read_ensemble("NEX", nex_in, nex_models)
    ds_cil = read_ensemble("CIL", cil_in, cil_models)
    ds_isi = read_ensemble("ISIMIP", isi_in, isi_models)
    ds_cbp_gard = read_ensemble("GARD-SV", cbp_in + "/regridded/conservative/GARD-SV/", cbp_gard_models)
    ds_cbp_deep = read_ensemble("DeepSD-BC", cbp_in + "native_grid/DeepSD-BC/", cbp_deep_models)

    ###########################
    # Merge all and mask ocean
    ###########################
    ds = xr.concat(
        [ds_nex, ds_cil, ds_isi, ds_cbp_gard, ds_cbp_deep],
        dim="ensemble",
        fill_value=np.nan,
    )
    ds = xr.where(land_mask, np.nan, ds)

    ##########################
    # Uncertainty calculation
    ##########################

    ## Scenario uncertainty
    # HS09 approach: variance across multi-model means
    U_scen_hs09 = ds.mean(dim=["model", "ensemble"]).var(dim="ssp")
    # BB13 approach: variance across scenarios, averaged over models and ensembles (no weighting)
    U_scen_bb13 = ds.var(dim="ssp").mean(dim=["model", "ensemble"])

    ##  Model uncertainty
    # Variance across models, averaged over scenarios and ensembles
    U_model = ds.var(dim="model")

    if weighted:
        # Number of models available per (ssp, ensemble), counted at one fixed grid
        # cell (indices 300/800 are hard-coded and meant to fall on land)
        weights = (
            ds.isel(lat=300, lon=800)[list(ds.data_vars)[0]].count(dim="model").rename("weights")
        )  # weights (choose point over land)
        weights = xr.where(weights == 1, 0, weights)  # remove combinations where variance was calculated over 1 entry
        U_model = U_model.weighted(weights).mean(dim=["ssp", "ensemble"])  # weighted average
    else:
        U_model = U_model.mean(dim=["ssp", "ensemble"])  # simple average

    ## Downscaling uncertainty
    # Variance across ensembles, averaged over models and scenarios
    U_ens = ds.var(dim="ensemble")

    if weighted:
        # Number of ensembles available per (ssp, model), counted at the same cell
        weights = (
            ds.isel(lat=300, lon=800)[list(ds.data_vars)[0]].count(dim="ensemble").rename("weights")
        )  # weights (choose point over land)
        weights = xr.where(weights == 1, 0, weights)  # remove combinations where variance was calculated over 1 entry
        U_ens = U_ens.weighted(weights).mean(dim=["ssp", "model"])  # weighted average
    else:
        U_ens = U_ens.mean(dim=["ssp", "model"])  # simple average

    ## Merge and return
    U_model = U_model.assign_coords(uncertainty="model")
    U_ens = U_ens.assign_coords(uncertainty="ensemble")
    U_scen_hs09 = U_scen_hs09.assign_coords(uncertainty="scenario_hs09")
    U_scen_bb13 = U_scen_bb13.assign_coords(uncertainty="scenario_bb13")

    U_out = xr.concat([U_scen_hs09, U_scen_bb13, U_model, U_ens], dim="uncertainty")

    U_out = U_out.assign_coords(time=year)

    return U_out

In [8]:
def calculate_forced_uc(metric, submetric, submetric_var, poly_path, deg, weighted, save_str):
    """
    Calculate the uncertainty partition of the forced response
    for a given selection of parameters/settings.

    One uc_forced task is run per year from 2015 through 2099; the results are
    concatenated along time and written under out_path. Nothing is done (and None
    is returned) when the output file already exists.
    """
    out_file = out_path + "uncertainty_partitioning/" + metric + "_" + save_str + ".nc"
    if os.path.isfile(out_file):
        return None

    # carbonplan GARD-SV precip models
    precip_metric = metric in ["wet", "dry", "hotdry"]
    cbp_gard_models_in = cbp_gard_precip_models if precip_metric else cbp_gard_models

    # Read all ensembles and compute UC, one delayed task per year
    delayed_res = [
        dask.delayed(uc_forced)(
            nex_in,
            nex_models,
            cil_in,
            cil_models,
            isi_in,
            isi_models,
            cbp_in,
            cbp_gard_models_in,
            cbp_deep_models,
            land_mask,
            metric,
            submetric,
            submetric_var,
            year,
            poly_path,
            deg,
            weighted,
        )
        for year in range(2015, 2100)
    ]

    # Compute
    res = dask.compute(*delayed_res)

    # Merge and store
    xr.concat(res, dim="time").to_netcdf(out_file)

### Measure of interannual variability

In [17]:
def calculate_iav(
    path_in,
    ensemble,
    model,
    land_mask,
    metric,
    submetric,
    submetric_var,
    poly_path,
    deg,
    const_iav,
    iav_path,
):
    """
    Calculates the internal variability (variance over all years
    of residuals from forced response) for a given model-ssp-ensemble

    Parameters
    ----------
    path_in : str
        Root directory of the raw metrics for this ensemble (should end with "/").
    ensemble : str
        One of "NEX", "ISIMIP", "GARD-SV" (netCDF) or "CIL", "DeepSD-BC" (zarr).
    model : str
        Model name.
    land_mask : xarray.DataArray
        Boolean mask; points where it is True are set to NaN in the raw data.
    metric : str
        Name of the metric sub-directory.
    submetric : list of str or False
        Variables to keep; also determines the file-name prefix. Falsy keeps all.
    submetric_var : str or False
        If truthy, appended (after "_") to netCDF file names when reading.
    poly_path : str
        Root directory of the pre-computed forced responses.
    deg : int
        Polynomial degree, part of the file names.
    const_iav : bool
        Truthy: return one variance over all years. Falsy: write an 11-year centred
        rolling variance to iav_path.
    iav_path : str
        Root directory for the rolling estimates (only used when const_iav is falsy).

    Returns
    -------
    xarray object or None
        The variance over time of the residuals when const_iav is truthy; otherwise
        None (the rolling estimate is written to disk, or was already there).

    Raises
    ------
    ValueError
        If the ensemble label has no known storage format.
    """

    # Subfunction for general preprocessing of each model/ensemble
    def read_and_process(ensemble, path_in, model, metric, submetric, submetric_var):
        # Read netcdf or zarr
        if ensemble in ["NEX", "ISIMIP", "GARD-SV"]:
            if submetric_var:
                ds = xr.open_dataset(path_in + metric + "/" + model + "_" + submetric_var + ".nc")
            else:
                ds = xr.open_dataset(path_in + metric + "/" + model + ".nc")
        elif ensemble in ["CIL", "DeepSD-BC"]:
            ds = xr.open_dataset(path_in + metric + "/" + model, engine="zarr")
        else:
            # Fail loudly instead of hitting an unbound `ds` below
            raise ValueError("Unknown ensemble: " + str(ensemble))

        # Some models/methods are missing precip so fill with NaNs
        if (metric in ["max", "avg"]) and ("pr" not in ds.data_vars):
            ds["pr"] = xr.full_like(ds[list(ds.data_vars)[0]], np.nan)
        if (metric == "max5d") and ("RX5day" not in ds.data_vars):
            ds["RX5day"] = xr.full_like(ds[list(ds.data_vars)[0]], np.nan)

        # Select submetric if chosen
        if submetric:
            ds = ds[submetric]

        # Common preprocessing
        ds["time"] = ds.indexes["time"].year  # index by integer year
        ds = ds.sortby("ssp")
        ds = ds.assign_coords(ensmod=ensemble + "__" + model)
        ds = ds.sel(lat=slice(-60, 90))

        # Fix lon to [-180,180]
        if ds.lon.max() > 180:
            ds["lon"] = np.where(ds["lon"] > 180, ds["lon"] - 360, ds["lon"])
            ds = ds.sortby("lon")

        # Drop member_id (keep the first member only)
        if "member_id" in list(ds.coords):
            ds = ds.isel(member_id=0).drop_vars("member_id")

        # Return
        return ds

    # File name shared by the forced-response input and the IAV output:
    # assumes submetrics are of the form ['X_count', 'X_streak']
    file_name = ensemble + "_" + model + "_deg" + str(deg) + ".nc"
    if submetric:
        submetric_str = submetric[0].replace("_count", "").replace("_streak", "")
        file_name = submetric_str + "_" + file_name

    #########################
    # Check if already done
    # for non-const case
    #########################
    if not const_iav:
        out_dir = iav_path + metric + "/"
        out_str = out_dir + file_name

        if os.path.isfile(out_str):
            return None

    ###################
    # Read raw outputs
    ###################
    ds = read_and_process(ensemble, path_in, model, metric, submetric, submetric_var)
    # Mask out ocean points
    ds = xr.where(land_mask, np.nan, ds)

    ########################
    # Read forced response
    ########################
    ds_forced = xr.open_dataset(poly_path + metric + "/" + file_name)
    ds_forced["time"] = ds_forced.indexes["time"].year

    # Forced response should always be >= 0.
    # Only negative values are replaced so that missing data stays NaN instead of
    # silently becoming 0.
    if metric in ["hot", "wet", "dry", "hotdry"]:
        ds_forced = ds_forced.where(~(ds_forced < 0), 0)

    # Some models/methods are missing precip so fill with NaNs
    if (metric in ["max", "avg"]) and ("pr" not in ds_forced.data_vars):
        ds_forced["pr"] = xr.full_like(ds_forced[list(ds_forced.data_vars)[0]], np.nan)
    if (metric == "max5d") and ("RX5day" not in ds_forced.data_vars):
        ds_forced["RX5day"] = xr.full_like(ds_forced[list(ds_forced.data_vars)[0]], np.nan)

    ################################
    # Get IAV estimate
    # Variance of residuals
    ################################
    # IAV can be constant value or rolling
    if const_iav:
        iav = (ds - ds_forced).var(dim="time")
        return iav
    else:
        iav = (ds - ds_forced).rolling(time=11, center=True).var()

        # Create the metric directory on first use, then write
        os.makedirs(out_dir, exist_ok=True)
        iav.to_netcdf(out_str)

        # #### NOTE: these time slices need to be the same as for the UC map plot!
        # iav_early = (ds - ds_forced).sel(time=slice(2020,2039)).var(dim='time').assign_coords(time = 'early')
        # iav_mid = (ds - ds_forced).sel(time=slice(2050,2069)).var(dim='time').assign_coords(time = 'mid')
        # iav_late = (ds - ds_forced).sel(time=slice(2080,2099)).var(dim='time').assign_coords(time = 'late')
        # iav = xr.concat([iav_early, iav_mid, iav_late], dim='time')

In [18]:
def make_delayed_list_iav(metric, submetric, submetric_var, poly_path, deg, const_iav, iav_path):
    """
    Build the list of dask-delayed `calculate_iav` calls, one per model of
    every downscaled ensemble (NEX, CIL, ISIMIP, GARD-SV, DeepSD-BC).

    Nothing is computed here; the caller is expected to pass the returned
    list to `dask.compute`. All arguments are forwarded unchanged to
    `calculate_iav`; `metric` additionally decides which GARD-SV models are
    included.

    Returns
    -------
    list
        Delayed objects, ordered NEX -> CIL -> ISIMIP -> GARD-SV -> DeepSD-BC.
    """
    # For the wet/dry/hotdry metrics only the GARD-SV models listed in
    # cbp_gard_precip_models (those with "pr" available) are used
    if metric in ["wet", "dry", "hotdry"]:
        gard_models = cbp_gard_precip_models
    else:
        gard_models = cbp_gard_models

    # (input path, ensemble label, models) in the order the tasks are queued
    ensemble_specs = [
        (nex_in, "NEX", nex_models),
        (cil_in, "CIL", cil_models),
        (isi_in, "ISIMIP", isi_models),
        (cbp_in + "/regridded/conservative/GARD-SV/", "GARD-SV", gard_models),
        (cbp_in + "native_grid/DeepSD-BC/", "DeepSD-BC", cbp_deep_models),
    ]

    # Parallelize with dask over models: one delayed task per (ensemble, model)
    return [
        dask.delayed(calculate_iav)(
            path_in,
            ensemble,
            model,
            land_mask,
            metric,
            submetric,
            submetric_var,
            poly_path,
            deg,
            const_iav,
            iav_path,
        )
        for path_in, ensemble, ensemble_models in ensemble_specs
        for model in ensemble_models
    ]

In [19]:
def calculate_all_iav(metric, submetric, submetric_var, poly_path, deg, const_iav, iav_path, save_str):
    """
    Calculate the internal variability uncertainty
    for a given selection of parameters/settings.

    The per-model IAV estimates are computed in parallel with dask, averaged
    over the "ensmod" (ensemble + model) and "ssp" dimensions, and written to
    out_path + "uncertainty_partitioning/" + metric + "_" + save_str + ".nc".
    If that file already exists nothing is computed.

    Parameters
    ----------
    metric : str
        Metric name; also the sub-directory of `iav_path` that holds the
        per-model rolling IAV files.
    submetric : list of str or False
        If given, its first entry (stripped of "_count"/"_streak") is the
        file-name prefix of the per-model rolling IAV files.
    submetric_var
        Forwarded to `calculate_iav` via `make_delayed_list_iav`.
    poly_path : str
        Location of the forced-response files (forwarded).
    deg : int
        Polynomial degree of the forced response; part of the file names.
    const_iav : bool
        True -> one constant IAV per model, averaged in memory.
        False -> rolling IAV, read back from the per-model files on disk.
    iav_path : str
        Root directory of the per-model rolling IAV files.
    save_str : str
        Suffix that identifies this configuration in the output file name.

    Returns
    -------
    None
    """
    out_file = out_path + "uncertainty_partitioning/" + metric + "_" + save_str + ".nc"

    # Check if already done
    if os.path.isfile(out_file):
        return None

    # Make delayed list
    delayed_res = make_delayed_list_iav(
        metric=metric,
        submetric=submetric,
        submetric_var=submetric_var,
        poly_path=poly_path,
        deg=deg,
        const_iav=const_iav,
        iav_path=iav_path,
    )
    # Compute
    res = dask.compute(*delayed_res)

    if const_iav:
        # Merge and average over ensemble + model (ensmod) and ssp
        ds_out = xr.concat(res, dim="ensmod").mean(dim=["ensmod", "ssp"])
        ds_out.to_netcdf(out_file)
    else:
        # If rolling IAV, each ensmod was saved individually so now read with dask
        # and calculate average (otherwise would run out of memory)
        out_str = iav_path + metric + "/"

        if submetric:
            submetric_str = submetric[0].replace("_count", "").replace("_streak", "")
            out_str = out_str + submetric_str + "_"

        with dask.config.set(**{"array.slicing.split_large_chunks": False}):
            ds_all = xr.open_mfdataset(
                out_str + "*_deg" + str(deg) + ".nc",
                combine="nested",
                concat_dim="ensmod",
                parallel=True,
            )

        # This function is called many times per kernel session, so release
        # the per-model file handles as soon as the average has been written.
        with ds_all:
            ds_all.mean(dim=["ensmod", "ssp"]).to_netcdf(out_file)

## Computations

In [None]:
%%time

# Annual averages: forced response -> interannual variability -> forced-response UC.
# The three stages must run in this order because each reads files written by
# the previous one.
metric = "avg"

##############################
# Get the forced response
##############################
for deg in [2, 4]:
    get_forced_poly(
        nex_in,
        nex_models,
        cil_in,
        cil_models,
        isi_in,
        isi_models,
        cbp_in,
        cbp_gard_models,
        cbp_deep_models,
        land_mask,
        metric,
        False,
        False,
        poly_path,
        deg,
    )

##############################
# Interannual variability
##############################
for deg in [2, 4]:
    for const_iav in [True, False]:
        calculate_all_iav(
            metric=metric,
            submetric=False,
            submetric_var=False,
            poly_path=poly_path,
            deg=deg,
            const_iav=const_iav,
            iav_path=iav_path,
            # -> "deg<deg>_const_iav" or "deg<deg>_nonconst_iav"
            save_str="deg" + str(deg) + "_" + ("non" * (not const_iav)) + "const_iav",
        )

#############################
# UC on forced response
#############################
for deg in [2, 4]:
    for weighted in [True, False]:
        calculate_forced_uc(
            metric=metric,
            submetric=False,
            submetric_var=False,
            poly_path=poly_path,
            deg=deg,
            weighted=weighted,
            # Tag the output with the degree actually used (was hard-coded to 2,
            # so deg=4 results were labelled "deg2")
            save_str="deg" + str(deg) + "_nonWeighted" * (not weighted),
        )

In [None]:
%%time

# Annual maxima: forced response -> interannual variability -> forced-response UC.
# The three stages must run in this order because each reads files written by
# the previous one.
metric = "max"

#######################################
# Get the forced response (polynomial)
#######################################
for deg in [2, 4]:
    get_forced_poly(
        nex_in,
        nex_models,
        cil_in,
        cil_models,
        isi_in,
        isi_models,
        cbp_in,
        cbp_gard_models,
        cbp_deep_models,
        land_mask,
        metric,
        False,
        False,
        poly_path,
        deg,
    )

####################################
# Interannual variability
####################################
for deg in [2, 4]:
    for const_iav in [True, False]:
        calculate_all_iav(
            metric=metric,
            submetric=False,
            submetric_var=False,
            poly_path=poly_path,
            deg=deg,
            const_iav=const_iav,
            iav_path=iav_path,
            # -> "deg<deg>_const_iav" or "deg<deg>_nonconst_iav"
            save_str="deg" + str(deg) + "_" + ("non" * (not const_iav)) + "const_iav",
        )

#############################
# UC on forced response
#############################
for deg in [2, 4]:
    for weighted in [True, False]:
        calculate_forced_uc(
            metric=metric,
            submetric=False,
            submetric_var=False,
            poly_path=poly_path,
            deg=deg,
            weighted=weighted,
            # Tag the output with the degree actually used (was hard-coded to 2,
            # so deg=4 results were labelled "deg2")
            save_str="deg" + str(deg) + "_nonWeighted" * (not weighted),
        )

In [None]:
%%time

# "max5d" metric: forced response -> interannual variability -> forced-response UC.
# The three stages must run in this order because each reads files written by
# the previous one.
metric = "max5d"

#######################################
# Get the forced response (polynomial)
#######################################
for deg in [2, 4]:
    get_forced_poly(
        nex_in,
        nex_models,
        cil_in,
        cil_models,
        isi_in,
        isi_models,
        cbp_in,
        cbp_gard_models,
        cbp_deep_models,
        land_mask,
        metric,
        False,
        False,
        poly_path,
        deg,
    )

####################################
# Interannual variability
####################################
for deg in [2, 4]:
    for const_iav in [True, False]:
        calculate_all_iav(
            metric=metric,
            submetric=False,
            submetric_var=False,
            poly_path=poly_path,
            deg=deg,
            const_iav=const_iav,
            iav_path=iav_path,
            # -> "deg<deg>_const_iav" or "deg<deg>_nonconst_iav"
            save_str="deg" + str(deg) + "_" + ("non" * (not const_iav)) + "const_iav",
        )

#############################
# UC on forced response
#############################
for deg in [2, 4]:
    for weighted in [True, False]:
        calculate_forced_uc(
            metric=metric,
            submetric=False,
            submetric_var=False,
            poly_path=poly_path,
            deg=deg,
            weighted=weighted,
            # Tag the output with the degree actually used (was hard-coded to 2,
            # so deg=4 results were labelled "deg2")
            save_str="deg" + str(deg) + "_nonWeighted" * (not weighted),
        )

In [None]:
%%time

# "dry" metric: forced response -> interannual variability -> forced-response UC.
# The three stages must run in this order because each reads files written by
# the previous one.
metric = "dry"

#######################################
# Get the forced response (polynomial)
#######################################
for deg in [2, 4]:
    get_forced_poly(
        nex_in,
        nex_models,
        cil_in,
        cil_models,
        isi_in,
        isi_models,
        cbp_in,
        cbp_gard_precip_models,  # only the GARD-SV models that provide precip
        cbp_deep_models,
        land_mask,
        metric,
        False,
        False,
        poly_path,
        deg,
    )

####################################
# Interannual variability
####################################
for deg in [2, 4]:
    for const_iav in [True, False]:
        calculate_all_iav(
            metric=metric,
            submetric=False,
            submetric_var=False,
            poly_path=poly_path,
            deg=deg,
            const_iav=const_iav,
            iav_path=iav_path,
            # -> "deg<deg>_const_iav" or "deg<deg>_nonconst_iav"
            save_str="deg" + str(deg) + "_" + ("non" * (not const_iav)) + "const_iav",
        )

#############################
# UC on forced response
#############################
for deg in [2, 4]:
    for weighted in [True, False]:
        calculate_forced_uc(
            metric=metric,
            submetric=False,
            submetric_var=False,
            poly_path=poly_path,
            deg=deg,
            weighted=weighted,
            # Tag the output with the degree actually used (was hard-coded to 2,
            # so deg=4 results were labelled "deg2")
            save_str="deg" + str(deg) + "_nonWeighted" * (not weighted),
        )

In [None]:
%%time
# Wet days
# For every threshold/observational-product combination:
# forced response -> interannual variability -> forced-response UC
# (in this order, because each stage reads files written by the previous one).
metric = "wet"

for thresh in ["q99", "rp10"]:
    for obs in ["gmfd", "era5"]:
        submetric_str = thresh + obs
        # Count and streak variables belonging to this threshold/obs pair
        submetric = [
            "pr_" + submetric_str + "_count",
            "pr_" + submetric_str + "_streak",
        ]

        #######################################
        # Get the forced response (polynomial)
        #######################################
        for deg in [2, 4]:
            get_forced_poly(
                nex_in,
                nex_models,
                cil_in,
                cil_models,
                isi_in,
                isi_models,
                cbp_in,
                cbp_gard_precip_models,  # only the GARD-SV models that provide precip
                cbp_deep_models,
                land_mask,
                metric,
                submetric,
                False,
                poly_path,
                deg,
            )

        ####################################
        # Interannual variability
        ####################################
        for deg in [2, 4]:
            for const_iav in [True, False]:
                calculate_all_iav(
                    metric=metric,
                    submetric=submetric,
                    submetric_var=False,
                    poly_path=poly_path,
                    deg=deg,
                    const_iav=const_iav,
                    iav_path=iav_path,
                    # -> "<thresh><obs>_deg<deg>_const_iav" or "..._nonconst_iav"
                    save_str=submetric_str + "_deg" + str(deg) + "_" + ("non" * (not const_iav)) + "const_iav",
                )

        ################################
        # UC on forced response
        ################################
        for deg in [2, 4]:
            for weighted in [True, False]:
                calculate_forced_uc(
                    metric=metric,
                    submetric=submetric,
                    submetric_var=False,
                    poly_path=poly_path,
                    deg=deg,
                    weighted=weighted,
                    # Tag the output with the degree actually used (was hard-coded
                    # to 2, so deg=4 results were labelled "deg2").
                    # NOTE(review): unlike the IAV save_str above, this one does not
                    # contain submetric_str -- presumably calculate_forced_uc adds the
                    # submetric to the file name itself; confirm, otherwise the four
                    # thresh/obs combinations share one output name.
                    save_str="deg" + str(deg) + "_nonWeighted" * (not weighted),
                )

In [None]:
%%time
# Hot days
# For every selected submetric:
# forced response -> interannual variability -> forced-response UC
# (in this order, because each stage reads files written by the previous one).
metric = "hot"

# Subselection of submetrics to analyze, named "<variable>_<threshold><obs>".
# This is a hand-picked subset of {tas, tasmax, tasmin} x {q99, rp10} x
# {gmfd, era5}: the rp10 + era5 combinations are left out.
submetric_strs = [
    "tas_q99gmfd",
    "tasmax_q99gmfd",
    "tasmin_q99gmfd",
    "tas_q99era5",
    "tasmax_q99era5",
    "tasmin_q99era5",
    "tas_rp10gmfd",
    "tasmax_rp10gmfd",
    "tasmin_rp10gmfd",
]

for submetric_str in submetric_strs:
    # Variable name is the part before the first underscore, e.g. "tasmax"
    submetric_var = submetric_str.split("_")[0]
    # Count and streak variables belonging to this submetric
    submetric = [submetric_str + "_count", submetric_str + "_streak"]

    #######################################
    # Get the forced response (polynomial)
    #######################################
    for deg in [2, 4]:
        get_forced_poly(
            nex_in,
            nex_models,
            cil_in,
            cil_models,
            isi_in,
            isi_models,
            cbp_in,
            cbp_gard_models,
            cbp_deep_models,
            land_mask,
            metric,
            submetric,
            submetric_var,
            poly_path,
            deg,
        )

    ####################################
    # Interannual variability
    ####################################
    for deg in [2, 4]:
        for const_iav in [True, False]:
            calculate_all_iav(
                metric=metric,
                submetric=submetric,
                submetric_var=submetric_var,
                poly_path=poly_path,
                deg=deg,
                const_iav=const_iav,
                iav_path=iav_path,
                # -> "<submetric_str>_deg<deg>_const_iav" or "..._nonconst_iav"
                save_str=submetric_str + "_deg" + str(deg) + "_" + ("non" * (not const_iav)) + "const_iav",
            )

    ################################
    # UC on forced response
    ################################
    for deg in [2, 4]:
        for weighted in [True, False]:
            calculate_forced_uc(
                metric=metric,
                submetric=submetric,
                submetric_var=submetric_var,
                poly_path=poly_path,
                deg=deg,
                weighted=weighted,
                # Tag the output with the degree actually used (was hard-coded
                # to 2, so deg=4 results were labelled "deg2").
                # NOTE(review): unlike the IAV save_str above, this one does not
                # contain submetric_str -- presumably calculate_forced_uc adds the
                # submetric to the file name itself; confirm, otherwise all
                # submetrics share one output name.
                save_str="deg" + str(deg) + "_nonWeighted" * (not weighted),
            )

In [None]:
%%time
# Hot + dry days
# For every threshold/observational-product combination:
# forced response -> interannual variability -> forced-response UC
# (in this order, because each stage reads files written by the previous one).
metric = "hotdry"

for thresh in ["q99", "rp10"]:
    for obs in ["gmfd", "era5"]:
        submetric_str = thresh + obs
        # Count and streak variables belonging to this threshold/obs pair
        submetric = [
            "hotdry_" + submetric_str + "_count",
            "hotdry_" + submetric_str + "_streak",
        ]

        #######################################
        # Get the forced response (polynomial)
        #######################################
        for deg in [2, 4]:
            get_forced_poly(
                nex_in,
                nex_models,
                cil_in,
                cil_models,
                isi_in,
                isi_models,
                cbp_in,
                cbp_gard_precip_models,  # only the GARD-SV models that provide precip
                cbp_deep_models,
                land_mask,
                metric,
                submetric,
                False,
                poly_path,
                deg,
            )

        ####################################
        # Interannual variability
        ####################################
        for deg in [2, 4]:
            for const_iav in [True, False]:
                calculate_all_iav(
                    metric=metric,
                    submetric=submetric,
                    submetric_var=False,
                    poly_path=poly_path,
                    deg=deg,
                    const_iav=const_iav,
                    iav_path=iav_path,
                    # -> "<thresh><obs>_deg<deg>_const_iav" or "..._nonconst_iav"
                    save_str=submetric_str + "_deg" + str(deg) + "_" + ("non" * (not const_iav)) + "const_iav",
                )

        ################################
        # UC on forced response
        ################################
        for deg in [2, 4]:
            for weighted in [True, False]:
                calculate_forced_uc(
                    metric=metric,
                    submetric=submetric,
                    submetric_var=False,
                    poly_path=poly_path,
                    deg=deg,
                    weighted=weighted,
                    # Tag the output with the degree actually used (was hard-coded
                    # to 2, so deg=4 results were labelled "deg2").
                    # NOTE(review): unlike the IAV save_str above, this one does not
                    # contain submetric_str -- presumably calculate_forced_uc adds the
                    # submetric to the file name itself; confirm, otherwise the four
                    # thresh/obs combinations share one output name.
                    save_str="deg" + str(deg) + "_nonWeighted" * (not weighted),
                )