# Generate Requested Remapped 2D Fields

In [1]:
import datetime
import glob
import os
import pprint

from dask import compute, delayed
import distributed
import ncar_jobqueue
import numpy as np
import numpy.ma as ma
import xarray as xr
import yaml

from ocean_remap import ocean_remap
from utils import time_set_mid
from utils_units import clean_units, conv_units

  from distributed.utils import tmpfile


In [2]:
# Globally enable xarray's keep_attrs option for the rest of the notebook.
# The trailing ';' suppresses the cell's output.
xr.set_options(keep_attrs=True);

In [3]:
# Read the simulation-name -> CESM case list mapping and echo it for reference.
with open("GCB_metadata.yaml") as metadata_file:
    GCB_metadata = yaml.safe_load(metadata_file)

pprint.pprint(GCB_metadata)

{'A': {'cases': ['g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BDRD.001',
                 'g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BDRD.002']},
 'B': {'cases': ['g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BCRC.001']},
 'C': {'cases': ['g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BDRC.001']},
 'D': {'cases': ['g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BCRD.001',
                 'g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BCRD.002']}}


In [4]:
# Root directory under which each case's time-series files are looked up
# (see gen_single_var_ds).
tseries_root = "/glade/campaign/cesm/development/bgcwg/projects/GCB_2022"
# Directory under which remap_2d_write places its output files,
# one subdirectory per GCB simulation name.
submission_dir = f"{tseries_root}/submission"

In [5]:
def gen_single_var_ds(CESM_cases, gcomp, freq, scomp, stream, varname, isel_dict=None):
    """Open the time-series files of one variable, across cases, as one Dataset.

    Parameters
    ----------
    CESM_cases : iterable of str
        Case names. Files are gathered case by case in the given order and,
        within a case, in sorted path order.
    gcomp, freq, scomp, stream, varname : str
        Pieces of the file pattern
        ``{tseries_root}/{case}/output/{gcomp}/proc/tseries/{freq}/``
        ``{case}.{scomp}.{stream}.{varname}.*.nc``.
    isel_dict : dict, optional
        Indexers passed to ``Dataset.isel``. Keys whose indexer is an ``int``
        are dropped from the dataset afterwards.

    Returns
    -------
    xarray.Dataset
        Dataset opened with ``chunks={"time": 12}``, with selected encoding
        copied from the first file, the ``grid_loc`` attribute removed from
        ``varname``, and ``time_set_mid(ds, "time")`` applied.

    Raises
    ------
    FileNotFoundError
        If no file matches the pattern for any case.
    """
    print("entering gen_single_var_ds")

    paths = []
    for case in CESM_cases:
        tseries_dir = f"{tseries_root}/{case}/output/{gcomp}/proc/tseries/{freq}"
        pattern = f"{tseries_dir}/{case}.{scomp}.{stream}.{varname}.*.nc"
        paths.extend(sorted(glob.glob(pattern)))

    # Fail early with a message that names what was being looked for.
    if not paths:
        raise FileNotFoundError(
            f"no {scomp}.{stream}.{varname} time-series files found for cases "
            f"{list(CESM_cases)} under {tseries_root}"
        )

    ds = xr.open_mfdataset(
        paths,
        chunks={"time": 12},
        compat="override",
        data_vars="minimal",
        coords="minimal",
        parallel=True,
    )

    if isel_dict is not None:
        ds = ds.isel(isel_dict)
        for key, value in isel_dict.items():
            if isinstance(value, int):
                ds = ds.drop_vars(key)

    # Copy metadata not propagated by open_mfdataset from the 1st file.
    # The context manager closes the file once the encodings are copied.
    with xr.open_dataset(paths[0]) as ds0:
        for key in ["unlimited_dims"]:
            if key in ds0.encoding:
                ds.encoding[key] = ds0.encoding[key]
        ds["time"].encoding = ds0["time"].encoding

    # Remove CESM specific variable attributes (tolerate files that lack it).
    ds[varname].attrs.pop("grid_loc", None)

    ds = time_set_mid(ds, "time")

    print("returning from gen_single_var_ds")

    return ds

In [6]:
# Fields produced by remap_2d_driver, in processing order. Keys per entry:
#   src          name of the source variable (also used in the file pattern)
#   dst          name given to the remapped variable (used in the output file name)
#   isel         optional indexers handed to gen_single_var_ds
#   area_corr    value passed to remap_2d as apply_area_corr (default False)
#   surface      if true, prefix the long_name attribute with "Surface "
#   clean_units  if true, pass the units attribute through clean_units
#   units        if present, target units handed to conv_units
_REMAP_2D_FIELDS = (
    {"src": "FG_CO2", "dst": "fgco2", "area_corr": True, "units": "mol m-2 s-1"},
    {"src": "fCO2", "dst": "sfco2"},
    {"src": "IFRAC", "dst": "fice", "clean_units": True},
    {"src": "DIC", "dst": "dissicos", "isel": {"z_t": 0}, "surface": True, "units": "mol m-3"},
    {
        "src": "ALK",
        "dst": "talkos",
        "isel": {"z_t": 0},
        "surface": True,
        "clean_units": True,
        "units": "mol m-3",
    },
    {"src": "TEMP", "dst": "tos", "isel": {"z_t": 0}, "surface": True},
    {"src": "SALT", "dst": "sos", "isel": {"z_t": 0}, "surface": True},
    {"src": "DIC_zint", "dst": "intdic", "area_corr": True, "units": "mol m-2"},
)


def remap_2d_driver(GCB_name, CESM_cases, matrix):
    """Remap and write every field in ``_REMAP_2D_FIELDS`` for one GCB simulation.

    For each field the source variable is opened with ``gen_single_var_ds``,
    restricted to the 1959-2021 time window, remapped with ``remap_2d``,
    renamed, has its metadata/units adjusted as the table entry requests, and
    is written with ``remap_2d_write``.

    Parameters
    ----------
    GCB_name : str
        Simulation name; forwarded to ``remap_2d_write``.
    CESM_cases : list of str
        Case names whose time series are concatenated by ``gen_single_var_ds``.
    matrix : object
        Remapping object forwarded to ``remap_2d``; its ``dst_grid`` attribute
        is forwarded to ``remap_2d_write``.
    """
    print(GCB_name)
    print(CESM_cases)

    (yr_lo, yr_hi) = (1959, 2021)
    time_slice = slice(f"{yr_lo:04d}-01-01", f"{(yr_hi+1):04d}-01-01")

    for field in _REMAP_2D_FIELDS:
        src_name = field["src"]
        # Hand out a fresh dict so the shared table entry is never aliased.
        isel_dict = dict(field["isel"]) if "isel" in field else None

        ds_in = gen_single_var_ds(
            CESM_cases, "ocn", "month_1", "pop", "h", src_name, isel_dict
        ).sel(time=time_slice)
        da = remap_2d(ds_in, src_name, matrix, apply_area_corr=field.get("area_corr", False))
        da.name = field["dst"]

        # Metadata fix-ups run in a fixed order: long_name, units cleanup,
        # then unit conversion.
        if field.get("surface", False):
            da.attrs["long_name"] = "Surface " + da.attrs["long_name"]
        if field.get("clean_units", False):
            da.attrs["units"] = clean_units(da.attrs["units"])
        if "units" in field:
            da = conv_units(da, field["units"])

        remap_2d_write(GCB_name, ds_in, da, matrix.dst_grid, yr_lo, yr_hi)


def _remap_2d_segment_bounds(tlen, seg_cnt, align=12):
    """Split ``range(tlen)`` into at most ``seg_cnt`` contiguous (lo, hi) index pairs.

    Interior boundaries are rounded to the nearest multiple of ``align`` and
    clamped to ``tlen``. The outer boundaries are pinned to 0 and ``tlen`` so
    that no index is dropped when ``tlen`` is not a multiple of ``align``.
    Segments left empty by the rounding are omitted; if every segment is empty
    a single ``(0, tlen)`` pair is returned.
    """
    edges = [0]
    for seg_ind in range(1, seg_cnt):
        raw_edge = (seg_ind * tlen) // seg_cnt
        edges.append(min(tlen, align * round(raw_edge / align)))
    edges.append(tlen)

    bounds = [(lo, hi) for lo, hi in zip(edges[:-1], edges[1:]) if hi > lo]
    return bounds or [(0, tlen)]


def remap_2d(ds, varname, matrix, apply_area_corr=False):
    """Remap ``ds[varname]`` in time segments via dask and concatenate the results.

    The time axis is cut into up to 7 segments whose interior boundaries fall
    on multiples of 12 time steps (presumably to line up with the 12-step time
    chunks used when the files are opened -- confirm). Each segment is handed
    to ``remap_2d_core`` through ``dask.delayed``; the results are concatenated
    along ``time``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing ``varname``.
    varname : str
        Variable to remap. Its first dimension is taken as the time length
        (assumes time is the leading dimension -- TODO confirm for new inputs).
    matrix : object
        Remapping object forwarded unchanged to ``remap_2d_core``.
    apply_area_corr : bool, optional
        Forwarded unchanged to ``remap_2d_core``.

    Returns
    -------
    xarray.DataArray
        Concatenation along ``time`` of the per-segment results.
    """
    print("entering remap_2d")

    tlen = ds[varname].shape[0]
    # Same number as the workers requested with cluster.scale(7) in this notebook.
    seg_cnt = 7

    # NOTE(review): the recorded run output shows dask's "Consider scattering
    # large objects ahead of time" warning for this compute; if graph size
    # becomes a problem, look at how `matrix` and `ds_seg` are shipped.
    objs = []
    for ind_lo, ind_hi in _remap_2d_segment_bounds(tlen, seg_cnt):
        ds_seg = ds.isel(time=slice(ind_lo, ind_hi))
        objs.append(delayed(remap_2d_core)(ds_seg, varname, matrix, apply_area_corr))

    print("calling compute")

    da_list = compute(objs)[0]

    print("concatenating")

    da = xr.concat(da_list, dim="time", coords="minimal", compat="override")

    print("returning from remap_2d")

    return da


def remap_2d_core(ds, varname, matrix, apply_area_corr=False):
    """Remap ``ds[varname]`` to the destination grid of ``matrix``.

    Parameters
    ----------
    ds : xarray.Dataset
        Source dataset. Must contain ``varname`` and ``time``; when
        ``apply_area_corr`` is true it must also contain ``KMT``, ``TAREA``
        and ``radius``.
    varname : str
        Variable to remap. Its encoding must define ``_FillValue``.
    matrix : object
        Remapping object providing ``src_grid``, ``dst_grid`` and
        ``remap_var(values, fill_value=...)``.
    apply_area_corr : bool, optional
        If true, scale the source values by the ratio of the model cell area
        (``TAREA`` where ``KMT > 0``, else 0, divided by ``radius**2``) to the
        matrix source-grid area (``src_grid.area * src_grid.frac``) before
        remapping.

    Returns
    -------
    xarray.DataArray
        Remapped values on ``(time, lat, lon)`` with fill values replaced by
        NaN. Attributes are copied from the source variable, with
        ``cell_measures`` and ``comment`` added. The encoding is copied from
        the source variable, minus its ``coordinates`` entry.
    """
    da_src = ds[varname].load()
    fill_value = da_src.encoding["_FillValue"]
    vals_src = da_src.to_masked_array()

    if apply_area_corr:
        src_grid = matrix.src_grid
        src_area_rad2 = (src_grid.area * src_grid.frac).reshape(src_grid.dims)
        POP_area_cm2 = xr.where(ds["KMT"] > 0, ds["TAREA"], 0.0).to_masked_array()
        rearth = ds["radius"].values
        # Ratio of model cell area to matrix cell area; presumably both end up
        # in radians**2 (TAREA in cm2, radius in cm) -- confirm units.
        mdl2src = POP_area_cm2 / rearth**2 / src_area_rad2
        vals_src *= mdl2src

    vals_dst = matrix.remap_var(vals_src, fill_value=fill_value)
    vals_dst = np.where(vals_dst != fill_value, vals_dst, np.nan)

    dst_grid = matrix.dst_grid
    da_dst = xr.DataArray(
        vals_dst,
        coords={"time": ds["time"], "lat": dst_grid.lat, "lon": dst_grid.lon},
        attrs=da_src.attrs,
    )
    da_dst.attrs["cell_measures"] = "area: area"
    da_dst.attrs["comment"] = (
        "Remapped from native grid to regular 1x1 grid. "
        "Cells on 1x1 grid fractionally covered by native grid set to all "
        "ocean or land if fractional coverage is greater than or less than "
        "frac_thres respectively. Note that this leads to loss of exact "
        "conservation."
    )
    da_dst.encoding = da_src.encoding
    # The source "coordinates" entry does not apply to the remapped array.
    # Use pop so a source without that entry does not raise KeyError.
    da_dst.encoding.pop("coordinates", None)

    return da_dst


def add_grid_vars(dst_ds, dst_grid, ds_in):
    """Attach destination-grid ``area`` and ``mask_sfc`` variables to ``dst_ds``.

    ``dst_ds`` is modified in place; nothing is returned. ``area`` is the
    grid's ``area`` array multiplied by the square of ``ds_in["radius"]``
    converted to metres; ``mask_sfc`` is the grid's ``mask`` array. Both are
    reshaped to ``dst_grid.dims`` and flagged for zlib compression at level 1.
    """

    def _attach(name, values, long_name, units):
        # Assign first, then set the encoding on the variable stored in dst_ds.
        dst_ds[name] = xr.DataArray(
            values,
            coords={"lat": dst_grid.lat, "lon": dst_grid.lon},
            attrs={"long_name": long_name, "units": units},
        )
        dst_ds[name].encoding.update(zlib=True, complevel=1)

    radius_sq_m2 = (conv_units(ds_in["radius"], "m") ** 2).values
    _attach(
        "area",
        radius_sq_m2 * dst_grid.area.reshape(dst_grid.dims),
        "grid cell area",
        "m2",
    )
    _attach(
        "mask_sfc",
        dst_grid.mask.reshape(dst_grid.dims),
        "surface grid cell mask",
        "1",
    )


def remap_2d_write(GCB_name, ds_in, da, dst_grid, yr_lo, yr_hi):
    """Write one remapped field, with time, time bounds and grid variables, to netCDF.

    The file is written to
    ``{submission_dir}/{GCB_name}/{da.name}_CESM2_{GCB_name}_1_gr_{yr_lo}01-{yr_hi}12_v{YYYYMMDD}.nc``;
    the directory is created if it does not exist.

    Parameters
    ----------
    GCB_name : str
        Simulation name; used for the output subdirectory and file name.
    ds_in : xarray.Dataset
        Source dataset. Supplies ``time``, the variable named by
        ``time.attrs["bounds"]``, the dataset encoding, and ``radius`` (read by
        ``add_grid_vars``).
    da : xarray.DataArray
        Remapped field; ``da.name`` becomes the variable name and the
        ``variable_id`` attribute.
    dst_grid : object
        Destination grid forwarded to ``add_grid_vars``.
    yr_lo, yr_hi : int
        First and last year, used only in the file name.
    """
    # Sample the clock once so the creation_date attribute and the file-name
    # date stamp cannot disagree across midnight.
    now = datetime.datetime.now()

    tb_name = ds_in["time"].attrs["bounds"]
    dst_ds = xr.Dataset({"time": ds_in["time"], tb_name: ds_in[tb_name], da.name: da})
    dst_ds.encoding = ds_in.encoding
    dst_ds.attrs["source_id"] = "CESM2"
    dst_ds.attrs["institution_id"] = "NCAR"
    dst_ds.attrs["variable_id"] = da.name
    # NOTE(review): presumably the coverage threshold used when the remapping
    # matrix/mask was generated -- confirm it matches the matrix file in use.
    dst_ds.attrs["frac_thres"] = 0.525
    dst_ds.attrs["contact"] = "klindsay@ucar.edu"
    dst_ds.attrs["creation_date"] = now.strftime("%Y-%m-%d")

    add_grid_vars(dst_ds, dst_grid, ds_in)

    dst_ds["lat"].attrs = {"long_name": "latitude", "units": "degrees_north"}
    dst_ds["lon"].attrs = {"long_name": "longitude", "units": "degrees_east"}

    # Ensure NaN _FillValues do not get generated.
    # Dataset.variables already includes the coordinates.
    for name in dst_ds.variables:
        dst_ds[name].encoding.setdefault("_FillValue", None)

    datestamp = now.strftime("%Y%m%d")

    # Zero-pad the years so years below 1000 do not put spaces in the file name.
    timestring = f"{yr_lo:04d}01-{yr_hi:04d}12"
    out_dir = f"{submission_dir}/{GCB_name}"
    # Create the directory up front so the write cannot fail on a missing
    # directory after the remap has already been computed.
    os.makedirs(out_dir, exist_ok=True)
    path = f"{out_dir}/{da.name}_CESM2_{GCB_name}_1_gr_{timestring}_v{datestamp}.nc"
    print(f"writing remapped monthly field to {path}")
    dst_ds.to_netcdf(path)

# Obtain Computational Resources

In [7]:
# Create a dask cluster through ncar_jobqueue and attach a distributed client to it.
cluster = ncar_jobqueue.NCARCluster(
    cores=1,  # The number of cores you want
    memory='4GB',  # Amount of memory
    processes=1,  # How many processes
    walltime='01:00:00',  # Amount of wall time
)

# Scale to 7 workers; remap_2d splits each field into seg_cnt = 7 time segments.
cluster.scale(7)

client = distributed.Client(cluster)
# Bare expression so the client summary is shown as the cell output.
client

Perhaps you already have a cluster running?
Hosting the HTTP server on port 43554 instead


0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/klindsay/GCB_2022/proxy/43554/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/klindsay/GCB_2022/proxy/43554/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.12.206.48:44602,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/klindsay/GCB_2022/proxy/43554/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [8]:
# Remapping matrix file (POP gx1v7 -> 1x1 lat/lon, going by the file name).
matrix_2d_fname = "POP_gx1v7_to_latlon_1x1_0E_mask_conserve_20220714.nc"
matrix_2d = ocean_remap(matrix_2d_fname)

# Produce every requested field for each simulation listed in the metadata.
for GCB_name, GCB_entry in GCB_metadata.items():
    remap_2d_driver(GCB_name, GCB_entry["cases"], matrix_2d)

A
['g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BDRD.001', 'g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.GCB_2022.BDRD.002']
entering gen_single_var_ds
returning from gen_single_var_ds
entering remap_2d
calling compute


  ('finalize-bf5ba389-84a9-4bee-bb41-830557a7b0d5',  ... bf6e2e0>, True)
Consider scattering large objects ahead of time
with client.scatter to reduce scheduler burden and 
keep data on workers

    future = client.submit(func, big_data)    # bad

    big_future = client.scatter(big_data)     # good
    future = client.submit(func, big_future)  # good


concatenating
returning from remap_2d
writing remapped monthly field to /glade/campaign/cesm/development/bgcwg/projects/GCB_2022/submission/A/fgco2_CESM2_A_1_gr_195901-202112_v20220719.nc
entering gen_single_var_ds
returning from gen_single_var_ds
entering remap_2d
calling compute
concatenating
returning from remap_2d
writing remapped monthly field to /glade/campaign/cesm/development/bgcwg/projects/GCB_2022/submission/A/sfco2_CESM2_A_1_gr_195901-202112_v20220719.nc
entering gen_single_var_ds
returning from gen_single_var_ds
entering remap_2d
calling compute
concatenating
returning from remap_2d
writing remapped monthly field to /glade/campaign/cesm/development/bgcwg/projects/GCB_2022/submission/A/fice_CESM2_A_1_gr_195901-202112_v20220719.nc
entering gen_single_var_ds
returning from gen_single_var_ds
entering remap_2d
calling compute
concatenating
returning from remap_2d
writing remapped monthly field to /glade/campaign/cesm/development/bgcwg/projects/GCB_2022/submission/A/dissicos_CESM

# Release Computational Resources

In [9]:
# Close the client first, then the cluster it was connected to.
client.close()
cluster.close()
