## Resampling with XESMF (S3 storage, NetCDF file, H5NetCDF driver, earthaccess auth, and pre-generated weights)

In [None]:
import argparse
import itertools
import sys

import earthaccess
import numpy as np
import pyproj
import rasterio.transform
import xarray as xr
import xesmf as xe

In [None]:
# Put the parent directory on the import path so helper modules such as
# ``common`` (imported inside the functions below) can be found; presumably
# that module lives one directory up from this notebook.
sys.path.extend([".."])

In [None]:
def make_grid_ds(*, te, tilesize, dstSRS) -> xr.Dataset:
    """
    Make a dataset representing a square target grid of ``tilesize`` x ``tilesize`` cells

    Parameters
    ----------
    te : sequence of float
        Target extent in ``dstSRS`` units, ordered (xmin, ymin, xmax, ymax).
        The grid origin is the upper-left corner ``(te[0], te[3])``.
    tilesize : int
        Number of grid cells along each axis.
    dstSRS : str
        CRS definition accepted by ``pyproj.Proj`` (e.g. "EPSG:3857").

    Returns
    -------
    xr.Dataset
        Target grid dataset with the following variables:
        - "x": X coordinate in the ``dstSRS`` projection (grid cell center)
        - "y": Y coordinate in the ``dstSRS`` projection (grid cell center)
        - "lat": latitude coordinate (grid cell center)
        - "lon": longitude coordinate (grid cell center)
        - "lat_b": latitude bounds for grid cell
        - "lon_b": longitude bounds for grid cell

    Notes
    -----
    Modified from ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    xmin, ymin, xmax, ymax = te[0], te[1], te[2], te[3]

    # Pixel size from the extent's width/height. The y step is negative because
    # rows run downwards from ymax. For extents centred on the origin this equals
    # the ndpyramid ``2 * half-width`` form exactly, but it is also correct for
    # extents that are not symmetric about the origin.
    transform = rasterio.transform.Affine.translation(
        xmin, ymax
    ) * rasterio.transform.Affine.scale(
        (xmax - xmin) / tilesize, (ymin - ymax) / tilesize
    )

    p = pyproj.Proj(dstSRS)

    # Column/row pixel coordinates of cell centers (offset by half a cell) and of
    # cell corners. meshgrid's default "xy" indexing gives arrays shaped
    # (row, col), so element [i, j] corresponds to row i, column j.
    centers = np.arange(tilesize) + 0.5
    corners = np.arange(tilesize + 1)
    cols, rows = np.meshgrid(centers, centers)
    cols_b, rows_b = np.meshgrid(corners, corners)

    # The affine transform applies element-wise to (col, row) arrays, and
    # pyproj.Proj accepts arrays, so no per-cell Python loop is needed.
    xs, ys = transform * (cols, rows)
    lon, lat = p(xs, ys, inverse=True)

    x_b, y_b = transform * (cols_b, rows_b)
    lon_b, lat_b = p(x_b, y_b, inverse=True)

    return xr.Dataset(
        {
            "x": xr.DataArray(xs[0, :], dims=["x"]),
            "y": xr.DataArray(ys[:, 0], dims=["y"]),
            "lat": xr.DataArray(lat, dims=["y", "x"]),
            "lon": xr.DataArray(lon, dims=["y", "x"]),
            "lat_b": xr.DataArray(lat_b, dims=["y_b", "x_b"]),
            "lon_b": xr.DataArray(lon_b, dims=["y_b", "x_b"]),
        },
    )


def xesmf_weights_to_xarray(regridder) -> xr.Dataset:
    """
    Flatten an XESMF regridder's weight matrix into an xarray dataset

    All variables share one "n_s" dimension (one entry per non-zero weight):
    "S" holds the weight values, "row" the output index and "col" the input
    index, both stored 1-based. The regridder's ``n_in``/``n_out`` sizes are
    kept as dataset attributes so the full matrix shape can be rebuilt.

    Assumes ``regridder.weights.data`` is a sparse COO-style array exposing
    ``.data`` (values) and ``.coords`` (row indices first, then column indices).

    Notes
    -----
    From ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    matrix = regridder.weights.data
    out_index = matrix.coords[0, :]
    in_index = matrix.coords[1, :]
    # Shift indices from 0-based to 1-based storage
    variables = {
        "S": ("n_s", matrix.data),
        "col": ("n_s", in_index + 1),
        "row": ("n_s", out_index + 1),
    }
    sizes = {"n_in": regridder.n_in, "n_out": regridder.n_out}
    return xr.Dataset(variables, attrs=sizes)


def _reconstruct_xesmf_weights(ds_w):
    """
    Rebuild the sparse weight matrix XESMF expects from a flattened weights dataset

    ``ds_w`` must provide 1-based "row"/"col" index variables, the "S" weight
    values, and ``n_out``/``n_in`` attributes giving the matrix shape. The result
    is a DataArray named "weights" with dims ("out_dim", "in_dim") that wraps a
    ``sparse.COO`` array.

    Notes
    -----
    From ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    import sparse
    import xarray as xr

    # Stored indices are 1-based; sparse.COO needs 0-based coordinates
    rows_0 = ds_w["row"].values - 1
    cols_0 = ds_w["col"].values - 1
    coordinates = np.stack([rows_0, cols_0])
    matrix_shape = (ds_w.attrs["n_out"], ds_w.attrs["n_in"])
    matrix = sparse.COO(coordinates, ds_w["S"].values, matrix_shape)
    return xr.DataArray(matrix, dims=("out_dim", "in_dim"), name="weights")


def reconstruct_weights(weights_fp):
    """
    Open a cached weights Zarr store and rebuild it into the format XESMF understands

    ``weights_fp`` is any location accepted by ``xr.open_zarr``. In this notebook
    it is an S3 URI produced by ``get_weights_fp``.

    Notes
    -----
    From ndpyramid - https://github.com/carbonplan/ndpyramid
    """
    cached = xr.open_zarr(weights_fp)
    rebuilt = _reconstruct_xesmf_weights(cached)
    return rebuilt

In [None]:
# S3 prefix under which cached regridding weights and target grids are stored
_CACHE_PREFIX = "s3://nasa-veda-scratch/resampling/test-weight-caching/"


def get_weights_fp(dataset):
    """Return the S3 URI of the cached regridding-weights Zarr store for ``dataset``."""
    return "".join((_CACHE_PREFIX, dataset, "-weights.zarr"))


def get_target_grid_fp(dataset):
    """Return the S3 URI of the cached target-grid Zarr store for ``dataset``."""
    return "".join((_CACHE_PREFIX, dataset, "-target.zarr"))

In [None]:
def generate_weights(dataset):
    """
    Build the target grid and XESMF regridding weights for ``dataset`` and cache both

    A 256 x 256 EPSG:3857 grid over ``common.target_extent`` is written to
    ``get_target_grid_fp(dataset)``. The nearest-neighbour regridder weights from
    the source variable to that grid are written to ``get_weights_fp(dataset)``.
    Both are Zarr stores written with ``mode="w"``, so existing stores are
    overwritten.

    Parameters
    ----------
    dataset : str
        Key into ``common.earthaccess_args`` (e.g. "gpm_imerg" or "mursst").
    """
    from common import earthaccess_args
    from common import target_extent as te

    # Define filepath, driver, and variable information
    args = earthaccess_args[dataset]
    weights_fp = get_weights_fp(dataset)
    target_grid_fp = get_target_grid_fp(dataset)
    input_uri = f'{args["folder"]}/{args["filename"]}'
    src = f's3://{args["bucket"]}/{input_uri}'
    # Create grid to hold result
    target_grid = make_grid_ds(te=te, tilesize=256, dstSRS="EPSG:3857")
    # Cache target grid (listings cache disabled, presumably so an overwritten
    # store is not shadowed by a stale S3 directory listing)
    target_grid.to_zarr(
        target_grid_fp, mode="w", storage_options={"use_listings_cache": False}
    )
    # Authenticate with earthaccess. Use get_s3_filesystem, as regrid() does;
    # get_s3fs_session is deprecated in recent earthaccess releases.
    fs = earthaccess.get_s3_filesystem(daac=args["daac"])
    # Specify fsspec caching since default options don't work well for raster data
    fsspec_caching = {
        "cache_type": "none",
    }
    with fs.open(src, **fsspec_caching) as f:
        # Open dataset
        da = xr.open_dataset(f, engine="h5netcdf", chunks={}, mask_and_scale=True)[
            args["variable"]
        ]
        # Create XESMF regridder; the weights are computed here, while the
        # source file is still open
        regridder = xe.Regridder(
            da,
            target_grid,
            "nearest_s2d",
            periodic=True,
            extrap_method="nearest_s2d",
            ignore_degenerate=True,
        )
    # Cache weights. They live in the regridder's in-memory weight matrix, so the
    # remote source file does not need to stay open during the S3 write.
    weights = xesmf_weights_to_xarray(regridder)
    weights.to_zarr(
        weights_fp, mode="w", storage_options={"use_listings_cache": False}
    )

In [None]:
def regrid(dataset):
    """
    Regrid ``dataset`` onto its cached target grid using cached XESMF weights

    The weights and target grid must already exist at ``get_weights_fp(dataset)``
    and ``get_target_grid_fp(dataset)`` (``generate_weights`` writes them).

    Parameters
    ----------
    dataset : str
        Key into ``common.earthaccess_args`` (e.g. "gpm_imerg" or "mursst").

    Returns
    -------
    The regridded variable, loaded into memory before the source file is closed.
    """
    from common import earthaccess_args

    args = earthaccess_args[dataset]

    # Pre-generated weights and target grid
    weights = reconstruct_weights(get_weights_fp(dataset))
    grid = xr.open_zarr(get_target_grid_fp(dataset))

    # Source object location on S3
    src = f's3://{args["bucket"]}/{args["folder"]}/{args["filename"]}'

    # earthaccess-authenticated S3 filesystem. Caching is disabled because the
    # default fsspec options don't work well for raster data.
    fs = earthaccess.get_s3_filesystem(daac=args["daac"])
    with fs.open(src, cache_type="none") as f:
        source = xr.open_dataset(f, engine="h5netcdf", mask_and_scale=True)
        da = source[args["variable"]]
        # Supplying the cached weights means XESMF does not recompute them
        regridder = xe.Regridder(
            da,
            grid,
            "nearest_s2d",
            periodic=True,
            extrap_method="nearest_s2d",
            ignore_degenerate=True,
            reuse_weights=True,
            weights=weights,
        )
        # Load eagerly: the data can't be read once the file handle is closed
        return regridder(da).load()

In [None]:
if __name__ == "__main__":
    # get_ipython is defined in the namespace when running under IPython/Jupyter
    if "get_ipython" in dir():
        # Running as a Jupyter Notebook: build and cache the target grid and
        # weights for the default dataset, then regrid using them
        dataset = "gpm_imerg"
        generate_weights(dataset)
        da = regrid(dataset)
    else:
        # Running via the CLI: select the dataset with argparse. Only regridding
        # runs here, so the chosen dataset's weights and target grid must already
        # have been cached by generate_weights.
        parser = argparse.ArgumentParser(description="Set environment for the script.")
        parser.add_argument(
            "--dataset",
            default="gpm_imerg",
            help="Dataset to resample.",
            choices=["gpm_imerg", "mursst"],
        )
        user_args = parser.parse_args()
        da = regrid(user_args.dataset)