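# Generate Requested XY Integrals and Averages

For each GCB simulation listed in `GCB_metadata.yaml`, compute horizontal (XY) integrals and averages of CESM POP ocean output and write them as netCDF files to the GCB 2022 submission directory.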

In [1]:
import datetime
import glob
import pprint

import cftime
import distributed
import ncar_jobqueue
import numpy as np
import xarray as xr
import yaml

from utils import time_set_mid
from utils_units import clean_units, conv_units

  from distributed.utils import tmpfile


In [2]:
# keep variable attributes (e.g. units, long_name) through xarray arithmetic
# and reductions; the trailing ";" suppresses the returned object's repr
xr.set_options(keep_attrs=True);

In [3]:
# GCB simulation label -> {"cases": [CESM case names]}
with open("GCB_metadata.yaml") as metadata_file:
    GCB_metadata = yaml.safe_load(metadata_file)
pprint.pprint(GCB_metadata)

{'A': {'cases': ['g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BDRD.001',
                 'g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BDRD.002']},
 'B': {'cases': ['g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BCRC.001']},
 'C': {'cases': ['g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BDRC.001']},
 'D': {'cases': ['g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BCRD.001',
                 'g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BCRD.002']}}


In [4]:
# root of the per-case CESM output trees; submission files are written below it
tseries_root = "/glade/campaign/cesm/development/bgcwg/projects/GCB_2022"
submission_dir = "/".join([tseries_root, "submission"])

In [5]:
def gen_single_var_ds(CESM_cases, gcomp, freq, scomp, stream, varname):
    """Open one variable's CESM timeseries files, concatenated across cases.

    Files are looked up in
    ``{tseries_root}/{case}/output/{gcomp}/proc/tseries/{freq}`` with the glob
    ``{case}.{scomp}.{stream}.{varname}.*.nc``. Paths are sorted within each
    case, and cases are concatenated in the order given.

    Parameters
    ----------
    CESM_cases : list of str
        CESM case names, in the order their files should be concatenated.
    gcomp, freq, scomp, stream : str
        Component directory, output-frequency directory, component file tag
        and history stream used to build the directory and filename glob.
    varname : str
        Variable to open.

    Returns
    -------
    xarray.Dataset
        The combined dataset passed through ``time_set_mid(ds, "time")``
        (presumably moving time values to interval midpoints — see utils).

    Raises
    ------
    FileNotFoundError
        If no file matches for any of the cases (a subclass of OSError, which
        is what open_mfdataset raises for an empty path list).
    """
    paths = []
    for case in CESM_cases:
        tseries_dir = f"{tseries_root}/{case}/output/{gcomp}/proc/tseries/{freq}"
        pattern = f"{tseries_dir}/{case}.{scomp}.{stream}.{varname}.*.nc"
        paths.extend(sorted(glob.glob(pattern)))

    if not paths:
        raise FileNotFoundError(
            f"no {varname} timeseries files found for cases {CESM_cases}"
        )

    ds = xr.open_mfdataset(
        paths, compat="override", data_vars="minimal", coords="minimal", parallel=True,
    )

    # copy metadata not propagated by open_mfdataset from the 1st file;
    # the context manager releases that file's handle once we are done
    with xr.open_dataset(paths[0]) as ds0:
        if "unlimited_dims" in ds0.encoding:
            ds.encoding["unlimited_dims"] = ds0.encoding["unlimited_dims"]
        ds["time"].encoding = dict(ds0["time"].encoding)

    # remove CESM specific variable attributes, tolerating their absence
    ds[varname].attrs.pop("grid_loc", None)

    return time_set_mid(ds, "time")

In [6]:
def reduce_2d_driver(GCB_name, CESM_cases):
    """Compute and write all requested XY-reduced diagnostics for one GCB simulation.

    POP monthly history variables -> output variables:

    - FG_CO2 -> fgco2_glob, fgco2_reg: integral in PgC yr-1; monthly and annual
      files for 1959-2021 and 1850-2021
    - DIC_zint -> intDIC_{1994,2007}_{glob,reg}: integral in PgC; annual file
      for the single year
    - ATM_CO2 -> Atm_CO2: TAREA-weighted average; monthly and annual, 1959-2021
    - DIC_RIV_FLUX + DOC_RIV_FLUX + DOCr_RIV_FLUX -> RivCin: integral in
      PgC yr-1; time-mean file over 1959-2021
    - pocToSed + calcToSed -> Burial: integral in PgC yr-1; time-mean files

    "glob" results use TAREA weights, "reg" results use gen_weight_reg.

    Parameters
    ----------
    GCB_name : str
        GCB simulation label; used as the output subdirectory and in file names.
    CESM_cases : list of str
        CESM case names whose output is concatenated in time.
    """
    print(GCB_name)
    print(CESM_cases)

    xy_dims = ["nlat", "nlon"]

    def open_pop_var(varname):
        # monthly POP "h" stream timeseries of varname across all CESM_cases
        return gen_single_var_ds(CESM_cases, "ocn", "month_1", "pop", "h", varname)

    def area_weight(ds):
        # cell area, with missing values given zero weight
        return ds["TAREA"].fillna(0).load()

    def to_carbon_units(da, units):
        # multiply by 12 g/mol, then convert to the requested units
        da.attrs["units"] += "(12 g)/(mol)"
        return conv_units(da, units)

    def integrate_summed(varnames):
        # integrate each variable using the TAREA weight of the first one,
        # accumulating in place; also return the last dataset opened
        ds = open_pop_var(varnames[0])
        weight = area_weight(ds)
        total = reduce_2d_comp(ds, weight, varnames[0], "integrate", xy_dims)
        for varname in varnames[1:]:
            ds = open_pop_var(varname)
            total += reduce_2d_comp(ds, weight, varname, "integrate", xy_dims)
        return ds, total

    # FG_CO2: global then regional integrals
    ds_fg = open_pop_var("FG_CO2")
    for make_weight, out_name in ((area_weight, "fgco2_glob"), (gen_weight_reg, "fgco2_reg")):
        da = reduce_2d_comp(ds_fg, make_weight(ds_fg), "FG_CO2", "integrate", xy_dims)
        da = to_carbon_units(da, "PgC yr-1")
        da.name = out_name
        for yr_lo in (1959, 1850):
            reduce_2d_write(GCB_name, ds_fg, da, yr_lo, 2021, write_ann=True)

    # DIC_zint: global then regional integrals, single-year annual values
    ds_dic = open_pop_var("DIC_zint")
    for make_weight, suffix in ((area_weight, "glob"), (gen_weight_reg, "reg")):
        da = reduce_2d_comp(ds_dic, make_weight(ds_dic), "DIC_zint", "integrate", xy_dims)
        da = to_carbon_units(da, "PgC")
        for year in (1994, 2007):
            da.name = f"intDIC_{year}_{suffix}"
            reduce_2d_write(GCB_name, ds_dic, da, year, year, write_mon=False, write_ann=True)

    # ATM_CO2: TAREA-weighted global average
    ds_atm = open_pop_var("ATM_CO2")
    da = reduce_2d_comp(ds_atm, area_weight(ds_atm), "ATM_CO2", "average", xy_dims)
    da.name = "Atm_CO2"
    reduce_2d_write(GCB_name, ds_atm, da, 1959, 2021, write_ann=True)

    # river carbon input: summed DIC, DOC and DOCr river-flux integrals, time mean
    ds_riv, da = integrate_summed(["DIC_RIV_FLUX", "DOC_RIV_FLUX", "DOCr_RIV_FLUX"])
    da = to_carbon_units(da, "PgC yr-1")
    da.name = "RivCin"
    reduce_2d_write(GCB_name, ds_riv, da, 1959, 2021, write_mon=False, write_mean=True)

    # burial: summed pocToSed and calcToSed integrals, time means
    ds_bur, da = integrate_summed(["pocToSed", "calcToSed"])
    da = to_carbon_units(da, "PgC yr-1")
    da.name = "Burial"
    # NOTE(review): the time-mean file name carries no year range, so the
    # 1959-1960 file is overwritten by the 1959-2021 one written right after
    # it on the same day; confirm which period(s) are actually wanted.
    for yr_hi in (1960, 2021):
        reduce_2d_write(GCB_name, ds_bur, da, 1959, yr_hi, write_mon=False, write_mean=True)


def gen_weight_reg(ds):
    """Build per-region area weights for the South, Tropics and North bands.

    For each region the weight is TAREA (missing values set to 0) on cells
    with KMT > 0 whose TLAT lies in the band, and 0 elsewhere:

    - South:   TLAT < -30
    - Tropics: -30 <= TLAT < 30
    - North:   TLAT >= 30

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain TAREA, KMT and TLAT on the same lateral grid.

    Returns
    -------
    xarray.DataArray
        Dims ("region", *KMT.dims) with a "region" coordinate holding the
        names above; attrs and encoding are copied from TAREA (reduce_2d_comp
        reads the weight's "units" attribute).
    """
    lateral_dims = ds["KMT"].dims
    TAREA = ds["TAREA"].fillna(0).load()
    KMT = ds["KMT"].fillna(0).load()
    TLAT = ds["TLAT"].load()
    ocean = KMT > 0

    rmask_dict = {
        "South": xr.where(ocean & (TLAT < -30.0), TAREA, 0.0),
        "Tropics": xr.where(ocean & (TLAT >= -30.0) & (TLAT < 30.0), TAREA, 0.0),
        "North": xr.where(ocean & (TLAT >= 30.0), TAREA, 0.0),
    }

    # use ds.sizes (dim name -> length); mapping access through ds.dims is
    # deprecated in newer xarray
    shape = (len(rmask_dict),) + tuple(ds.sizes[dim] for dim in lateral_dims)
    rmask = xr.DataArray(
        np.zeros(shape),
        dims=("region",) + tuple(lateral_dims),
        coords={"region": list(rmask_dict)},
    )
    rmask.attrs = TAREA.attrs
    rmask.encoding = TAREA.encoding
    # write region names as a character array in netCDF output
    rmask.region.encoding["dtype"] = "S1"

    for i, rmask_field in enumerate(rmask_dict.values()):
        rmask.values[i, ...] = rmask_field

    return rmask


def reduce_2d_comp(ds, weight, varname, reduce_op, reduce_dims):
    """Reduce ``ds[varname]`` over ``reduce_dims`` by a weighted integral or average.

    Parameters
    ----------
    ds : xarray.Dataset
        Contains ``varname``, which must carry "units" and "long_name" attributes.
    weight : xarray.DataArray
        Weights (e.g. cell area); for "integrate" it must carry a "units" attribute.
    varname : str
        Variable to reduce.
    reduce_op : {"integrate", "average"}
        "integrate": weighted sum; units become "(weight units)(variable units)".
        "average": weighted sum divided by the sum of the weights over points
        where the variable is not missing; units are the variable's cleaned units.
    reduce_dims : list of str
        Dimensions to reduce over.

    Returns
    -------
    xarray.DataArray
        Reduced field with updated "long_name" and "units", carrying the input
        variable's encoding without its "coordinates" entry.

    Raises
    ------
    ValueError
        If ``reduce_op`` is not "integrate" or "average".
    """
    # reject unknown ops instead of silently returning a weighted sum that is
    # still labeled with the input's long_name and units
    if reduce_op not in ("integrate", "average"):
        raise ValueError(
            f"unknown reduce_op {reduce_op!r}; expected 'integrate' or 'average'"
        )

    da_in = ds[varname]
    da_in_units = clean_units(da_in.attrs["units"])

    da_out = da_in.weighted(weight).sum(dim=reduce_dims)
    da_out.encoding = da_in.encoding
    # NOTE(review): the "coordinates" encoding presumably names coordinates
    # removed by the reduction; drop it if present
    da_out.encoding.pop("coordinates", None)

    if reduce_op == "integrate":
        da_out.attrs["long_name"] = "Integrated " + da_in.attrs["long_name"]
        da_out.attrs["units"] = f"({weight.attrs['units']})({da_in_units})"
    else:  # "average"
        # denominator: sum of the weights over points where da_in is not missing
        ones_masked = xr.ones_like(da_in).where(da_in.notnull())
        denom = ones_masked.weighted(weight).sum(dim=reduce_dims)
        da_out /= denom
        da_out.attrs["long_name"] = "Averaged " + da_in.attrs["long_name"]
        da_out.attrs["units"] = da_in_units

    return da_out


def reduce_2d_write(
    GCB_name, ds_in, da, yr_lo, yr_hi, write_mon=True, write_ann=False, write_mean=False,
):
    """Write ``da`` for years ``yr_lo``..``yr_hi`` to the GCB submission directory.

    Files go to ``{submission_dir}/{GCB_name}/``:

    - write_mon:  ``{name}_CESM2_{GCB_name}_1_gr_{yr_lo}01-{yr_hi}12_v{YYYYMMDD}.nc``
      with the monthly values
    - write_ann:  ``{name}_CESM2_{GCB_name}_1_gr_{yr_lo}-{yr_hi}_v{YYYYMMDD}.nc``
      with annual means from gen_ds_ann
    - write_mean: ``{name}_CESM2_{GCB_name}_1_gr_v{YYYYMMDD}.nc`` with the time
      mean from gen_ds_mean. This name has no year range, so repeated calls
      for the same ``da.name`` on the same day overwrite one another.

    Parameters
    ----------
    GCB_name : str
        GCB simulation label (output subdirectory and file-name component).
    ds_in : xarray.Dataset
        Source of ``time``, its bounds variable and the global encoding.
    da : xarray.DataArray
        Named, reduced field with a ``time`` dimension.
    yr_lo, yr_hi : int
        First and last year (inclusive) to write.
    write_mon, write_ann, write_mean : bool
        Which of the three outputs to produce.

    Raises
    ------
    ValueError
        If no time samples fall within the requested years.
    """
    tb_name = ds_in["time"].attrs["bounds"]
    data_vars = {"time": ds_in["time"], tb_name: ds_in[tb_name], da.name: da}
    # zero-pad years so date strings and file names stay valid for years < 1000;
    # the upper limit is the first instant of yr_hi + 1, which excludes that
    # year as long as time values are mid-interval
    time_slice = slice(f"{yr_lo:04d}-01-01", f"{yr_hi + 1:04d}-01-01")
    ds_out = xr.Dataset(data_vars).sel(time=time_slice)
    if ds_out.sizes["time"] == 0:
        raise ValueError(f"no time samples for {da.name} in years {yr_lo}-{yr_hi}")

    ds_out.encoding = ds_in.encoding
    ds_out.attrs["source_id"] = "CESM2"
    ds_out.attrs["institution_id"] = "NCAR"
    ds_out.attrs["variable_id"] = da.name
    ds_out.attrs["contact"] = "klindsay@ucar.edu"
    ds_out.attrs["creation_date"] = datetime.datetime.now().strftime("%Y-%m-%d")

    # ensure NaN _FillValues do not get generated
    # (ds_out.variables already includes the coordinate variables)
    for var in ds_out.variables:
        ds_out[var].encoding.setdefault("_FillValue", None)

    datestamp = datetime.datetime.now().strftime("%Y%m%d")
    out_dir = f"{submission_dir}/{GCB_name}"

    if write_mon:
        timestring = f"{yr_lo:04d}01-{yr_hi:04d}12"
        path = f"{out_dir}/{da.name}_CESM2_{GCB_name}_1_gr_{timestring}_v{datestamp}.nc"
        print(f"writing reduced monthly field to {path}")
        ds_out.to_netcdf(path)

    if write_ann:
        ds_out_ann = gen_ds_ann(ds_out, da)
        timestring = f"{yr_lo:04d}-{yr_hi:04d}"
        path = f"{out_dir}/{da.name}_CESM2_{GCB_name}_1_gr_{timestring}_v{datestamp}.nc"
        print(f"writing reduced annual field to {path}")
        ds_out_ann.to_netcdf(path)

    if write_mean:
        ds_out_mean = gen_ds_mean(ds_out, da)
        path = f"{out_dir}/{da.name}_CESM2_{GCB_name}_1_gr_v{datestamp}.nc"
        print(f"writing reduced mean field to {path}")
        # without .load(), the netCDF4 backend raises KeyError for da.name
        # (root cause not understood)
        ds_out_mean.load().to_netcdf(path)


def gen_ds_ann(ds, da):
    """Build a Dataset of annual means of ``da``, weighted by time-interval length.

    Assumes time values are mid-interval, so grouping on ``time.year`` puts
    every interval in the year it belongs to.

    Parameters
    ----------
    ds : xarray.Dataset
        Provides ``time`` (with a "bounds" attribute and "units"/"calendar"
        encoding) and the time-bounds variable that attribute names.
    da : xarray.DataArray
        Field to average; aligned with ``ds`` on ``time``.

    Returns
    -------
    xarray.Dataset
        ``time`` (midpoint of each year's bounds), the annual time bounds, and
        ``da.name`` holding the annual means; attrs and encoding come from ``ds``.
    """
    time = ds["time"]
    tb_name = time.attrs["bounds"]
    tb = ds[tb_name]

    # length of each time interval, in the units of the time encoding
    time_bounds = cftime.date2num(
        tb, time.encoding["units"], time.encoding["calendar"],
    )
    dt = xr.DataArray(
        time_bounds[:, -1] - time_bounds[:, 0], dims="time", coords={"time": time},
    )

    da_ann = (da * dt).groupby("time.year").sum() / dt.groupby("time.year").sum()
    da_ann.encoding = da.encoding

    # annual bounds: earliest lower bound and latest upper bound of each year
    tb_ann = xr.concat(
        [tb[:, 0].groupby("time.year").min(), tb[:, 1].groupby("time.year").max()],
        tb.dims[-1],
    ).transpose()
    tb_ann.attrs = tb.attrs
    tb_ann.encoding = tb.encoding

    # time_ann is the midpoint of the annual bounds; tb_ann is loaded first
    # because otherwise you get the error
    # NotImplementedError: Computing the mean of an array containing cftime.datetime
    # objects is not yet implemented on dask arrays.
    time_ann = tb_ann.load().mean(axis=-1)
    time_ann.attrs = time.attrs
    time_ann.encoding = time.encoding

    # index the result by time rather than by the groupby "year" coordinate;
    # drop_vars replaces the deprecated Dataset.drop
    data_vars = {"time": time_ann, tb_name: tb_ann, da.name: da_ann}
    ds_ann = xr.Dataset(data_vars).swap_dims({"year": "time"}).drop_vars("year")
    ds_ann.encoding = ds.encoding
    ds_ann.attrs = ds.attrs

    return ds_ann


def gen_ds_mean(ds, da):
    """Build a Dataset holding the time mean of ``da``, weighted by interval length.

    Only the ``time`` dimension is averaged; any other dimensions of ``da``
    (e.g. the "region" dimension from regional weights) are kept.

    Parameters
    ----------
    ds : xarray.Dataset
        Provides ``time`` (with a "bounds" attribute and "units"/"calendar"
        encoding) and the time-bounds variable that attribute names.
    da : xarray.DataArray
        Field to average; aligned with ``ds`` on ``time``.

    Returns
    -------
    xarray.Dataset
        Scalar ``time`` (midpoint of the overall bounds), the overall time
        bounds, and ``da.name`` holding the mean; attrs and encoding from ``ds``.
    """
    time = ds["time"]
    tb_name = time.attrs["bounds"]
    tb = ds[tb_name]

    # length of each time interval, in the units of the time encoding
    time_bounds = cftime.date2num(
        tb, time.encoding["units"], time.encoding["calendar"],
    )
    dt = xr.DataArray(
        time_bounds[:, -1] - time_bounds[:, 0], dims="time", coords={"time": time},
    )

    # reduce over time only, so non-time dimensions are not summed away
    da_mean = (da * dt).sum("time") / dt.sum("time")
    da_mean.encoding = da.encoding

    # overall bounds: earliest lower bound and latest upper bound (1-D result)
    tb_mean = xr.concat([tb[:, 0].min(), tb[:, 1].max()], tb.dims[-1])
    tb_mean.attrs = tb.attrs
    tb_mean.encoding = tb.encoding

    # time_mean is the midpoint of tb_mean; tb_mean is loaded first because
    # otherwise you get the error
    # NotImplementedError: Computing the mean of an array containing cftime.datetime
    # objects is not yet implemented on dask arrays.
    time_mean = tb_mean.load().mean(axis=-1)
    time_mean.attrs = time.attrs
    time_mean.encoding = time.encoding

    data_vars = {"time": time_mean, tb_name: tb_mean, da.name: da_mean}
    ds_mean = xr.Dataset(data_vars)
    ds_mean.encoding = ds.encoding
    ds_mean.attrs = ds.attrs

    return ds_mean


# Obtain Computational Resources

In [7]:
# each job: 1 core, 4 GB, a single worker process, 1 hour of walltime;
# then request 8 workers and connect a client to the cluster
cluster = ncar_jobqueue.NCARCluster(
    cores=1,
    processes=1,
    memory="4GB",
    walltime="01:00:00",
)
cluster.scale(8)

client = distributed.Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/klindsay/GCB_2022/proxy/8787/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/klindsay/GCB_2022/proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.12.206.54:42712,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/klindsay/GCB_2022/proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [8]:
# reduce and write all requested diagnostics for every GCB simulation
for GCB_name, GCB_entry in GCB_metadata.items():
    reduce_2d_driver(GCB_name, GCB_entry["cases"])

A
['g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BDRD.001', 'g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BDRD.002']
writing reduced monthly field to /glade/campaign/cesm/development/bgcwg/projects/GCB_2022/submission/A/fgco2_glob_CESM2_A_1_gr_195901-202112_v20220712.nc
writing reduced annual field to /glade/campaign/cesm/development/bgcwg/projects/GCB_2022/submission/A/fgco2_glob_CESM2_A_1_gr_1959-2021_v20220712.nc
writing reduced monthly field to /glade/campaign/cesm/development/bgcwg/projects/GCB_2022/submission/A/fgco2_glob_CESM2_A_1_gr_185001-202112_v20220712.nc
writing reduced annual field to /glade/campaign/cesm/development/bgcwg/projects/GCB_2022/submission/A/fgco2_glob_CESM2_A_1_gr_1850-2021_v20220712.nc
writing reduced monthly field to /glade/campaign/cesm/development/bgcwg/projects/GCB_2022/submission/A/fgco2_reg_CESM2_A_1_gr_195901-202112_v20220712.nc
writing reduced annual field to /glade/campaign/cesm/development/bgcwg/projects/GCB_2022/submission/A/fgco2_reg_CE

# Release Computational Resources

In [9]:
# close the dask client first, then the cluster it is connected to
for dask_resource in (client, cluster):
    dask_resource.close()
