
Optimization for adjust_3d #47

Closed
riley-brady opened this issue Oct 25, 2023 · 2 comments · Fixed by #48

riley-brady commented Oct 25, 2023

Hello! Thank you for this fantastic package. It has been very helpful for getting some quantile mapping up and running quickly. As a thank you, I wanted to offer an optimization I did on the gridded bias correction that you might find useful. I specialize in vectorizing/optimizing gridded data, but I don't have the capacity right now to open a full PR.

I used cm.adjust_3d on a ~111x111 lat/lon grid with 40,000 time steps. The progress bar estimated around 2.5 hours for this, but I didn't run it to completion. With the implementation below, it ran in about 1 minute.

You need a running Dask cluster and a Dask-backed dataset to reap the full benefits, of course, but the implementation will speed up in-memory datasets too.


import numpy as np
import xarray as xr
from cmethods import CMethods as cm


def quantile_map_3d(
    obs: xr.DataArray,
    simh: xr.DataArray,
    simp: xr.DataArray,
    n_quantiles: int,
    kind: str,
):
    """Quantile mapping vectorized for 3D operations."""

    def qmap(
        obs: np.ndarray,
        simh: np.ndarray,
        simp: np.ndarray,
        n_quantiles: int,
        kind: str,
    ) -> np.ndarray:
        """Helper for apply_ufunc to vectorize/parallelize the bias correction step.

        With ``vectorize=True``, the inputs arrive as 1D numpy arrays (one time
        series per grid cell).
        """
        return cm.quantile_mapping(
            obs=obs, simh=simh, simp=simp, n_quantiles=n_quantiles, kind=kind
        )

    result = xr.apply_ufunc(
        qmap,
        obs,
        simh,
        # Need to spoof a fake time axis since 'time' coord on full dataset is different
        # than 'time' coord on training dataset.
        simp.rename({"time": "t2"}),
        dask="parallelized",
        vectorize=True,
        # This will vectorize over the time dimension, so will submit each grid cell
        # independently
        input_core_dims=[["time"], ["time"], ["t2"]],
        # Need to denote that the final output dataset will be labeled with the
        # spoofed time coordinate
        output_core_dims=[["t2"]],
        kwargs={"n_quantiles": n_quantiles, "kind": kind},
    )

    # Rename to proper coordinate name.
    result = result.rename({"t2": "time"})

    # ufunc will put the core dimension to the end (time), so want to preserve original
    # order where time is commonly first.
    result = result.transpose(*obs.dims)
    return result

The nice thing about this is that it handles 1D datasets without any issue. The limitation is that the inputs always have to be xarray objects. But it works with Dask or in-memory datasets and arbitrary dimensions, as long as a labeled time dimension exists.

The other great thing is that you could apply this apply_ufunc wrapper to every single bias correction method, removing the need for a separate adjust_3d function; a user could then pass in 1D or 2D+ data without any change in code. A rough sketch of such a generic wrapper is below.
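
For illustration, here is roughly what such a wrapper factory could look like. The name make_vectorized is just something I'm making up here (it is not part of cmethods), and the keyword handling would need to match each method's actual signature:

import functools

import xarray as xr
from cmethods import CMethods as cm


def make_vectorized(method):
    """Hypothetical factory: wrap a 1D bias correction method
    (e.g. cm.quantile_mapping) so it accepts labeled xarray inputs
    of any dimensionality."""

    def _call(obs, simh, simp, **kwargs):
        # Called by apply_ufunc with 1D numpy arrays, one grid cell at a time.
        return method(obs=obs, simh=simh, simp=simp, **kwargs)

    @functools.wraps(method)
    def wrapper(obs: xr.DataArray, simh: xr.DataArray, simp: xr.DataArray, **kwargs):
        result = xr.apply_ufunc(
            _call,
            obs,
            simh,
            # Same trick as in quantile_map_3d: spoof the prediction time axis.
            simp.rename({"time": "t2"}),
            dask="parallelized",
            vectorize=True,
            input_core_dims=[["time"], ["time"], ["t2"]],
            output_core_dims=[["t2"]],
            kwargs=kwargs,
        )
        return result.rename({"t2": "time"}).transpose(*obs.dims)

    return wrapper


# Usage (illustrative): the same wrapped function works for 1D and N-D inputs.
quantile_mapping_nd = make_vectorized(cm.quantile_mapping)
# corrected = quantile_mapping_nd(obs, simh, simp, n_quantiles=250, kind="*")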

Example usage of quantile_map_3d:

import pandas as pd

obs = xr.DataArray(
    [[1, 2, 3, 4], [2, 3, 4, 5]],
    dims=["x", "time"],
    coords={"x": [0, 1], "time": pd.date_range("2023-10-25", freq="D", periods=4)},
).transpose("time", "x")
simh = xr.DataArray(
    [[2, 1, 5, 4], [3, 9, 1, 4]],
    dims=["x", "time"],
    coords={"x": [0, 1], "time": pd.date_range("2023-10-25", freq="D", periods=4)},
).transpose("time", "x")
simp = xr.DataArray(
    [[7, 9, 10, 14], [12, 13, 14, 15]],
    dims=["x", "time"],
    coords={"x": [0, 1], "time": pd.date_range("2040-10-25", freq="D", periods=4)},
).transpose("time", "x")

# 2D dataset
>>> quantile_map_3d(obs, simh, simp, 250, "*")
<xarray.DataArray (time: 4, x: 2)>
array([[5., 9.],
       [5., 9.],
       [5., 9.],
       [5., 9.]])
Coordinates:
  * x        (x) int64 0 1
  * time     (time) datetime64[ns] 2040-10-25 2040-10-26 2040-10-27 2040-10-28

# 1D dataset
>>> quantile_map_3d(obs.isel(x=0), simh.isel(x=0), simp.isel(x=0), 250, "*")
<xarray.DataArray (time: 4)>
array([5., 5., 5., 5.])
Coordinates:
    x        int64 0
  * time     (time) datetime64[ns] 2040-10-25 2040-10-26 2040-10-27 2040-10-28
riley-brady added the Feature (New feature or request) label on Oct 25, 2023
@btschwertfeger (Owner) commented

Hey @riley-brady, thank you for sharing your suggestions! I will try it out, and if it brings such improvements, it will definitely be part of the next release. I was looking for a solution like this for a long time while developing this package at my old workplace. I created issue #6 back then, but then I got lost in other projects.

I also need some time to explore and implement this solution - hopefully during the upcoming week.

Thanks again!

@riley-brady (Author) commented

Glad to hear it might help! Now if you combine https://github.com/btschwertfeger/BiasAdjustCXX with Cython and this... you'll be super efficient. :)

To really get the power of this on a larger dataset, you need your data stored as a chunked Zarr and a local Dask cluster running. Again, it will still work, and be much faster than the nested for loop, for simple in-memory datasets (1D, 2D, 3D, and above).

But to the point of your issue #6 and the Stack Overflow post, if you're going bigger than your RAM, you definitely want the data stored as Zarr and processed with Dask.
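
If it helps, here's a minimal sketch of that setup (the paths, variable name, and chunk sizes are just placeholders):

import xarray as xr
from dask.distributed import Client

client = Client()  # start a local Dask cluster

# Open chunked Zarr stores. Keep "time" in a single chunk since it is the
# core dimension of the ufunc; chunk over the spatial dimensions instead.
chunks = {"time": -1, "lat": 10, "lon": 10}
obs = xr.open_zarr("obs.zarr")["tas"].chunk(chunks)
simh = xr.open_zarr("simh.zarr")["tas"].chunk(chunks)
simp = xr.open_zarr("simp.zarr")["tas"].chunk(chunks)

result = quantile_map_3d(obs, simh, simp, n_quantiles=250, kind="*")

# Writing triggers the (lazy) computation on the cluster.
result.to_dataset(name="tas").to_zarr("tas_corrected.zarr")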

If you haven't used zarr or dask before and need some pointers, please let me know!
