# cf_xarray: Scale your analysis across datasets with less data wrangling and more metadata handling

_Deepak Cherian, Mattia Almansi, Pascal Bourgault_

There has been an explosion in the availability of terabyte- to petabyte-scale
geoscience datasets, particularly on the cloud, which has prompted EarthCube
projects such as Pangeo to develop scalable tools and workflows for handling such
large datasets. There is a parallel need for tools that enable the analysis of
datasets from a wide variety of sources, each with its own nomenclature.

Xarray is a Python package that enables easy and convenient labelled data
analytics by allowing users to leverage metadata such as dimension names and
coordinate labels. cf_xarray is an open-source, Apache-licensed Xarray extension
that decodes the Climate and Forecast (CF) Metadata Conventions adopted by the
geoscience community, allowing users to make extensive use of standardized
metadata such as “standard names” in their analysis pipelines. For example, the
zonal average of an Xarray dataset `ds` is calculated as
`ds.cf.mean("longitude")` on a wide variety of CF-compliant datasets, regardless
of the actual name of the “longitude” variable (e.g. “lon”, “lon_rho”, “long”).
cf_xarray also provides tools and heuristics to optionally guess absent
attributes, allowing usage on incompletely tagged datasets. cf_xarray is now
seeing adoption in other packages, such as xESMF, a package for regridding
Xarray datasets, and in the diagnostic workflow of NOAA’s Model Diagnostics Task
Force (MDTF) for validating model simulations.

Our notebook will demonstrate the use of cf_xarray to build an analysis pipeline
that works on a wide variety of cloud-available datasets (such as the CMIP6
archive, the CESM Large Ensemble, and various satellite datasets) and that uses
xESMF to regrid them to a common grid to facilitate the analysis of anomalies.


## Imports


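```python
import cf_xarray  # noqa: F401 (importing registers the .cf accessor)

import xarray as xr
import dask

import matplotlib.pyplot as plt

# Do not split large chunks when slicing (this also silences the related warning)
dask.config.set(**{"array.slicing.split_large_chunks": False})
```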

## Utility functions

The following functions are used in this notebook to create an example dataset
and to add the CF metadata that it is missing.


```python
def open_example_dataset():
    """Open the OM4 (0.5 degree) ocean dataset used as an example."""

    # Open the grid and the monthly variables, then merge them
    grid = xr.open_dataset("data/ocean_grid_sym_OM4_05.nc")
    ds = xr.open_dataset(
        "http://35.188.34.63:8080/thredds/dodsC/OM4p5/ocean_monthly_z.200301-200712.nc4",
        chunks={"time": 1, "z_l": 5},
    )
    ds = xr.merge([grid, ds], compat="override")

    # Mimic a curvilinear grid, where the axes (dimension coordinates) differ
    # from the geographic coordinates: replace the 1D index coordinates
    # with plain positional indices
    axes = ["xh", "xq", "yh", "yq"]
    ds = ds.drop_vars(axes)
    ds = ds.assign_coords({axis: ds[axis] for axis in axes})

    # Promote the geographic variables (geolon, geolat, ...) to coordinates
    ds = ds.set_coords([var for var in ds.variables if var.startswith("geo")])

    return ds


def assign_coordinates_and_cell_measures(ds):
    """Add CF `coordinates` and `cell_measures` attributes to all data variables.

    The attributes of the variables in `ds` are modified in place.
    """

    # Some CF metadata is missing in the example dataset.
    # Datasets fully compliant with CF conventions do not need this step.
    # Furthermore, functions to automatically assign missing coordinates
    # and measures metadata will be implemented in cf_xarray:
    # https://github.com/xarray-contrib/cf-xarray/issues/201

    for varname, variable in ds.data_vars.items():

        # Link every CF coordinate whose dimensions are a subset of the variable's
        coordinates = []
        for coord in sum(ds.cf.coordinates.values(), []):
            if set(ds[coord].dims) <= set(variable.dims):
                coordinates.append(coord)
        if coordinates:
            variable.attrs["coordinates"] = " ".join(coordinates)
        else:
            variable.attrs.pop("coordinates", None)

        # Link the cell measures (thickness, area, volume) defined on compatible dimensions
        cell_measures = {}
        for stdname in ("cell_thickness", "cell_area", "ocean_volume"):
            key = stdname.split("_")[-1]  # "thickness", "area" or "volume"
            # Skip standard names that are absent from the dataset
            for measure in ds.cf.standard_names.get(stdname, []):
                if (
                    set(ds[measure].dims) <= set(variable.dims)
                    and measure != varname
                ):
                    cell_measures[key] = measure
        if cell_measures:
            variable.attrs["cell_measures"] = " ".join(
                [f"{k}: {v}" for k, v in cell_measures.items()]
            )
        else:
            variable.attrs.pop("cell_measures", None)
```
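The rule applied by `assign_coordinates_and_cell_measures` can be sketched without xarray: a coordinate or cell measure is linked to a variable when its dimensions are a subset of the variable's dimensions, and the links are serialized into CF attribute strings. The function `build_cf_link_attrs` and the dimension layout below are hypothetical, and the sketch omits the helper's check that a measure is not the variable itself:

```python
def build_cf_link_attrs(var_dims, coord_dims, measure_dims):
    """Return the CF `coordinates` and `cell_measures` attributes for one variable.

    var_dims: dimensions of the data variable
    coord_dims: {coordinate name: its dimensions}
    measure_dims: {measure key, e.g. "area": (variable name, its dimensions)}
    """
    attrs = {}
    coords = [name for name, dims in coord_dims.items() if set(dims) <= set(var_dims)]
    if coords:
        attrs["coordinates"] = " ".join(coords)
    measures = {
        key: name
        for key, (name, dims) in measure_dims.items()
        if set(dims) <= set(var_dims)
    }
    if measures:
        attrs["cell_measures"] = " ".join(f"{k}: {v}" for k, v in measures.items())
    return attrs


# Hypothetical tracer on (yh, xh) points: the corner coordinate on (yq, xq) is not linked
attrs = build_cf_link_attrs(
    var_dims=("time", "z_l", "yh", "xh"),
    coord_dims={"geolat": ("yh", "xh"), "geolon": ("yh", "xh"), "geolat_c": ("yq", "xq")},
    measure_dims={
        "area": ("areacello", ("yh", "xh")),
        "thickness": ("thkcello", ("time", "z_l", "yh", "xh")),
    },
)
print(attrs)
```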

# `cf_xarray` features


## 1. Access wrapped functions through xarray's API

When `cf_xarray` is imported, a `cf` accessor is automatically added to every
xarray `Dataset` and `DataArray`.


```python
ds = open_example_dataset()

# The cf_xarray accessor is present
assert hasattr(ds, "cf")

# cf_xarray preserves xarray's API
for obj in [ds, ds.cf]:
    assert hasattr(obj, "squeeze")
```
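The accessor relies on xarray's extension mechanism. As a minimal sketch, assuming a hypothetical accessor named `demo` with a hypothetical `variable_names` method (neither is part of cf_xarray), registering a class with `xr.register_dataset_accessor` is enough to make a new namespace appear on every `Dataset`; cf_xarray registers its `cf` namespace for `Dataset` and `DataArray` in the same way when it is imported:

```python
import xarray as xr


# Hypothetical accessor; xarray instantiates it with the wrapped object on first access
@xr.register_dataset_accessor("demo")
class DemoAccessor:
    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def variable_names(self):
        """Return the sorted names of the data variables of the wrapped dataset."""
        return sorted(self._obj.data_vars)


ds = xr.Dataset({"b": ("x", [1, 2]), "a": ("x", [3, 4])})
names = ds.demo.variable_names()
print(names)
```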

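## 2. Standardize a dataset by guessing missing metadata

The `axis` attribute is missing in the example dataset.  
Here we use `cf_xarray` to identify all spatial and time axes and to generate
the missing metadata. The helper `assign_coordinates_and_cell_measures` then
adds the `coordinates` and `cell_measures` attributes.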


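```python
# Guess the axes from the coordinate names and add the missing attributes
ds = ds.cf.guess_coord_axis(verbose=True)
# Add `coordinates` and `cell_measures` attributes (modifies ds in place)
assign_coordinates_and_cell_measures(ds)
```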
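To illustrate the idea behind name-based guessing, here is a deliberately simplified, stdlib-only sketch. The patterns in `AXIS_PATTERNS` and the function `guess_axis_attrs` are hypothetical and are not cf_xarray's actual heuristics, which cover many more naming schemes:

```python
import re

# Hypothetical, simplified name patterns (NOT cf_xarray's actual rules)
AXIS_PATTERNS = {
    "T": re.compile(r"^(time|t)$"),
    "Z": re.compile(r"^(z|depth|lev|z_[a-z]+)$"),
    "Y": re.compile(r"^(y|lat|y[a-z])$"),
    "X": re.compile(r"^(x|lon|x[a-z])$"),
}


def guess_axis_attrs(coord_names):
    """Map each coordinate name to an {"axis": ...} attribute dict when a pattern matches."""
    guessed = {}
    for name in coord_names:
        for axis, pattern in AXIS_PATTERNS.items():
            if pattern.match(name.lower()):
                guessed[name] = {"axis": axis}
                break  # stop at the first matching axis
    return guessed


guesses = guess_axis_attrs(["xh", "xq", "yh", "yq", "z_l", "time", "geolon"])
print(guesses)
```

Note that the 2D coordinate `geolon` is left untagged: in this sketch only names that look like axes receive an `axis` attribute.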

## 3. Concisely display the metadata interpreted by `cf_xarray`

The following cell displays the representation of the `cf_xarray` accessor.  
Note that the variables lie on staggered grids, and therefore there are multiple
variables associated with the same Axis/Coordinate/Measure.


```python
ds.cf
```

## 4. Extract variables using standard names

`standard_name` is a CF attribute that precisely describes the physical
quantity represented by a variable. `cf_xarray` allows users to extract
variables using their standard names. The advantages of using standard names
rather than variable names are:

1. Code written for a specific dataset can be applied to a wide variety of
   datasets that each have their own nomenclature.
2. Unlike standard names, which come from a controlled vocabulary, arbitrary
   variable names can be ambiguous or misleading.

Here we use `cf_xarray` to extract the oceanic bathymetry using the appropriate
standard name.


```python
ds.cf["sea_floor_depth_below_geoid"]
```

## 5. Identify links between variables

CF conventions allow variables to be linked with each other through metadata
attributes (e.g., `coordinates`, `cell_measures`, `ancillary_variables`).  
For example, when we extract the bathymetry using `cf_xarray`, an additional
variable holding the area of each grid cell is attached to the extracted
`DataArray` as a coordinate.


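```python
# Coordinates attached by cf_xarray vs. by plain xarray
cf_coords = set(ds.cf["sea_floor_depth_below_geoid"].coords)
xr_coords = set(ds["deptho"].coords)
# The extra coordinate is the cell-area variable linked through `cell_measures`
additional_coord = list(cf_coords - xr_coords)[0]
ds.cf[additional_coord]
```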

## 6. Perform operations on multiple dimensions associated with the same CF key

As mentioned above, the example dataset has multiple dimensions associated with
the same spatial axes.  
Such information is decoded by `cf_xarray` and can be used by many wrapped
functions.  
For example, here we use the CF axes to slice multiple dimensions at once:


```python
ds_sliced = ds.cf.isel(X=slice(10), Y=slice(10), Z=slice(10), T=slice(10))
print("Original dataset sizes:", ds.sizes)
print("  Sliced dataset sizes:", ds_sliced.sizes)
```

Similarly, we can apply a spatial average to all variables in the dataset
without having to pass the names of all the staggered dimensions.


```python
ds.cf.mean(["X", "Y", "Z"])
```

## 7. Automatically set xarray keyword arguments

`cf_xarray` automatically sets some of the keyword arguments of wrapped
functions.  
In the example below, `xarray` plots the data against the dimension indices,
whereas `cf_xarray` assigns the appropriate coordinates to the plot axes.


```python
da = ds.cf["sea_floor_depth_below_geoid"]
fig, (xr_ax, cf_ax) = plt.subplots(1, 2, figsize=(10, 5))
da.plot(ax=xr_ax)
da.cf.plot(ax=cf_ax)
_ = xr_ax.set_title("xarray")
_ = cf_ax.set_title("cf_xarray")
plt.tight_layout()
```
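For comparison, plain xarray only uses 2D geographic coordinates when they are passed explicitly as `x` and `y`. The following sketch assumes a hypothetical curvilinear grid with integer dimensions (`yh`, `xh`) and 2D coordinates `geolon` and `geolat`; these are the arguments that `cf_xarray` fills in from the CF metadata:

```python
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Hypothetical curvilinear grid: dimensions without coordinates plus 2D lon/lat
ny, nx = 3, 4
jj, ii = np.meshgrid(np.arange(ny), np.arange(nx), indexing="ij")
da = xr.DataArray(
    np.random.default_rng(0).random((ny, nx)),
    dims=("yh", "xh"),
    coords={
        "geolon": (("yh", "xh"), 10.0 * ii + 2.0 * jj),
        "geolat": (("yh", "xh"), 5.0 * jj + 0.5 * ii),
    },
    name="deptho",
)

fig, ax = plt.subplots()
# Without explicit x/y, xarray would plot against the dimension positions instead
da.plot.pcolormesh(x="geolon", y="geolat", ax=ax)
```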

## 8. Example workflow

The function below performs several operations using only CF attributes.
Therefore, it can be applied to any CF-compliant dataset.

1. Fill NaN values of all cell measures with zeros
2. Extract sea water potential temperature
3. Select all levels in the top 100 m
4. Compute and plot maps of thickness-weighted yearly means
5. Compute and plot maps of differences from climatology
6. Compute and plot the time series of area-weighted yearly means


```python
def plot_top_100m_temp(ds):

    # Fill missing values of the cell measures with zeros, because
    # xarray's weighted operations do not accept NaN weights.
    # Note that this modifies the input dataset.
    for var in sum(ds.cf.cell_measures.values(), []):
        ds[var] = ds[var].fillna(0)

    # Compute thickness-weighted means over the top 100 m, then yearly means
    with xr.set_options(keep_attrs=True):
        da_maps = (
            ds.cf["sea_water_potential_temperature"]
            .cf.sel(Z=slice(0, 100))
            .cf.weighted("thickness")
            .mean(["Z"])
            .cf.groupby("T.year")
            .mean()
        )
    # Grouping replaces time with a new `year` dimension: guess its axis again
    da_maps = da_maps.cf.guess_coord_axis()
    da_maps = da_maps.load()

    # Plot yearly mean maps
    da_maps.cf.plot(col="T", center=False, robust=True)
    plt.show()

    # Plot difference from climatology
    with xr.set_options(keep_attrs=True):
        da_anomaly = da_maps - da_maps.cf.mean("T")
    da_anomaly.attrs["standard_name"] += "_anomaly"
    da_anomaly.attrs["long_name"] += " Anomaly"
    da_anomaly.cf.plot(col="T", robust=True)
    plt.show()

    # Compute and plot the area-weighted horizontal mean of the yearly maps
    da_line = da_maps.cf.weighted("area").mean(["X", "Y"])
    da_line.cf.plot()
    plt.show()
```

```python
plot_top_100m_temp(ds)
```
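The arithmetic of steps 3 to 5 can be checked on a single water column. The sketch below uses plain xarray with explicit dimension names where `plot_top_100m_temp` uses CF keys; the column, the names `thetao` and `thkcello`, and all values are hypothetical. It also shows why step 1 matters: `weighted` raises an error if the weights contain NaN.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical water column: two years with two samples each, four layers
time = pd.to_datetime(["2003-01-15", "2003-07-15", "2004-01-15", "2004-07-15"])
z_l = [5.0, 15.0, 60.0, 150.0]  # layer centres (m)
thkcello = xr.DataArray([10.0, 10.0, 80.0, 100.0], dims="z_l", coords={"z_l": z_l})  # thickness (m)

base = np.array([20.0, 18.0, 10.0, 4.0])  # temperature profile (degC)
warming = np.array([0.0, 0.0, 1.0, 1.0])  # +1 degC in 2004
thetao = xr.DataArray(
    base[np.newaxis, :] + warming[:, np.newaxis],
    dims=("time", "z_l"),
    coords={"time": time, "z_l": z_l},
)

# Step 3: keep the layers in the top 100 m (the 150 m layer is dropped)
top = thetao.sel(z_l=slice(0, 100))
weights = thkcello.sel(z_l=slice(0, 100))

# Step 4: thickness-weighted vertical mean, then yearly means
column_mean = top.weighted(weights).mean("z_l")
yearly = column_mean.groupby("time.year").mean()

# Step 5: anomaly with respect to the climatology (mean over all years)
anomaly = yearly - yearly.mean("year")
print(yearly.values, anomaly.values)
```

For 2003 the weighted mean is `(20*10 + 18*10 + 10*80) / 100 = 11.8`, and for 2004 it is `12.8`, so the anomalies are `-0.5` and `+0.5`.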

## 9. TODO: Apply the workflow to one of the CMIP6 datasets


```python
# TODO: run plot_top_100m_temp on one of the CMIP6 datasets available on Pangeo.
```