---
title: "xr.apply_ufunc(...)"
categories: [xarray]
date: 2025-04-11
---


## Example: how `xr.apply_ufunc(...)` can be used for pixel-wise prediction

*Note: `ds.map_blocks()` is likely **much** faster than this version, since `rf.predict` is called once per pixel here!* This is for demonstration purposes only.

```python
import dask.array as da
import xarray as xr
import numpy as np
from sklearn.ensemble import RandomForestClassifier as RF
```

```python
n_features = 12
```

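```python
n_samples = 1000
# random training data (small, so plain NumPy arrays are enough for fitting;
# scikit-learn would convert dask arrays to NumPy anyway)
rng = np.random.default_rng(42)
X_train = rng.random((n_samples, n_features))
y_train = rng.integers(0, 2, n_samples)

rf = RF(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)
```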

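```python
# random input data to classify: one 12-step "time" series (the features) per lat/lon pixel
# keep it small: with vectorize=True the model is called once per pixel
lat = np.arange(40)
lon = np.arange(60)
time = np.arange(n_features)
data = da.random.random((lat.size, lon.size, time.size))
```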

```python
ds = xr.DataArray(
    data,
    coords=[lat, lon, time],
    dims=["lat", "lon", "time"],
    name="test",
).to_dataset()
```
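As built above, `test` is a single dask chunk, so `dask="parallelized"` has nothing to run in parallel. A minimal sketch, assuming you want several spatial chunks: split along `lat`/`lon` and keep the core dimension `time` in one chunk, because `apply_ufunc` expects core dimensions to be unchunked unless `allow_rechunk=True` is passed. The chunk sizes 20 and 30 are arbitrary illustration values, and `ds_chunked` is a hypothetical name.

```python
import dask.array as da
import numpy as np
import xarray as xr

n_features = 12
lat = np.arange(40)
lon = np.arange(60)
time = np.arange(n_features)

ds = xr.DataArray(
    da.random.random((lat.size, lon.size, time.size)),
    coords=[lat, lon, time],
    dims=["lat", "lon", "time"],
    name="test",
).to_dataset()

# Split the spatial dims into 2 x 2 chunks; -1 keeps "time" (the core dim) in one chunk
ds_chunked = ds.chunk({"lat": 20, "lon": 30, "time": -1})

print(ds_chunked["test"].chunks)  # ((20, 20), (30, 30), (12,))
```

Each of the four spatial chunks can then be processed by a separate dask task.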

```python
ds
```

| | Array | Chunk |
|---|---|---|
| Bytes | 225.00 kiB | 225.00 kiB |
| Shape | (40, 60, 12) | (40, 60, 12) |
| Dask graph | 1 chunks in 1 graph layer | 1 chunks in 1 graph layer |
| Data type | float64 numpy.ndarray | float64 numpy.ndarray |


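```python
def generic_func(arr):
    # arr: 1-D feature vector (the "time" values) of a single pixel.
    # scikit-learn expects 2-D input of shape (n_samples, n_features), hence the reshape;
    # predict returns a length-1 array, so [0] returns the scalar label.
    return rf.predict(arr.reshape(1, -1))[0]
```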
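With `vectorize=True`, `apply_ufunc` wraps the function in `numpy.vectorize` with a gufunc signature equivalent to `(time)->()`, so the function is called once per pixel with a 1-D array of length `n_features` and should return a scalar. The toy sketch below reproduces that wrapping directly with NumPy and counts the calls; `per_pixel_mean` is a hypothetical stand-in for `rf.predict`, chosen so the example needs no model.

```python
import numpy as np

calls = []  # records the shape of every input the function receives

def per_pixel_mean(arr):
    # Hypothetical stand-in for rf.predict: gets one pixel's 1-D feature vector
    calls.append(arr.shape)
    return arr.mean()  # a scalar, matching the "()" output core dims

# Same kind of wrapping vectorize=True applies: loop over all non-core axes
vec = np.vectorize(per_pixel_mean, signature="(n)->()", otypes=[np.float32])

data = np.random.default_rng(0).random((40, 60, 12))
out = vec(data)

print(out.shape, len(calls))  # (40, 60) 2400
```

The 2400 separate Python calls (40 x 60 pixels) are the per-pixel overhead the note at the top refers to.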

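```python
ds_ag = xr.apply_ufunc(
    generic_func,
    ds,
    input_core_dims=[["time"]],  # generic_func consumes the "time" dimension
    dask="parallelized",  # build a lazy dask graph instead of computing right away
    output_dtypes=[np.float32],  # predicted class labels are stored as float32
    vectorize=True,  # loop generic_func over every (lat, lon) pixel
    # allow dask to rechunk the core dim if "time" spans several chunks
    dask_gufunc_kwargs={"allow_rechunk": True},
)
```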
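`input_core_dims=[["time"]]` tells `apply_ufunc` which dimension the function consumes: xarray moves core dimensions to the last axis before calling the function and drops them from the output unless they are listed in `output_core_dims`. A small sketch with a NumPy-backed array (no dask, no model) makes this axis handling visible; `last_axis_mean` and the 12 x 4 x 5 array are hypothetical illustration choices.

```python
import numpy as np
import xarray as xr

# "time" deliberately comes first here, not last
arr = xr.DataArray(
    np.random.default_rng(0).random((12, 4, 5)),
    dims=["time", "lat", "lon"],
)

seen_shapes = []

def last_axis_mean(x):
    # apply_ufunc has already moved the core dim "time" to the last axis
    seen_shapes.append(x.shape)
    return x.mean(axis=-1)

# Without vectorize=True the function is called once on the whole array
result = xr.apply_ufunc(last_axis_mean, arr, input_core_dims=[["time"]])

print(seen_shapes)  # [(4, 5, 12)]
print(result.dims)  # ('lat', 'lon')
```

This is also why `generic_func` can treat its input as a feature vector without knowing where `time` sat in the original array.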

```python
ds_ag
```

| | Array | Chunk |
|---|---|---|
| Bytes | 9.38 kiB | 9.38 kiB |
| Shape | (40, 60) | (40, 60) |
| Dask graph | 1 chunks in 4 graph layers | 1 chunks in 4 graph layers |
| Data type | float32 numpy.ndarray | float32 numpy.ndarray |


```python
ds_ag.compute()
```
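As the note at the top of this post says, calling the model once per pixel is slow. A minimal sketch of a block-wise variant, assuming the same kind of random data and model: drop `vectorize=True`, let each dask chunk arrive as a `(lat, lon, time)` NumPy block, flatten it to `(n_pixels, n_features)`, and call `predict` once per chunk. This follows the same idea as `map_blocks` but stays within `apply_ufunc`; `predict_block`, `ds_block`, and the 2 x 2 chunking are hypothetical illustration choices, and no speedup is measured here.

```python
import dask.array as da
import numpy as np
import xarray as xr
from sklearn.ensemble import RandomForestClassifier

n_features = 12
rng = np.random.default_rng(42)

# Random training data, as in the walkthrough above
X_train = rng.random((1000, n_features))
y_train = rng.integers(0, 2, 1000)
rf = RandomForestClassifier(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)

# Random data to classify, split into 2 x 2 spatial chunks; "time" stays whole
ds = xr.DataArray(
    da.random.random((40, 60, n_features), chunks=(20, 30, -1)),
    coords=[np.arange(40), np.arange(60), np.arange(n_features)],
    dims=["lat", "lon", "time"],
    name="test",
).to_dataset()

def predict_block(block):
    # block: NumPy array of shape (lat_chunk, lon_chunk, n_features),
    # with the core dim "time" already moved to the last axis
    flat = block.reshape(-1, block.shape[-1])  # one row per pixel
    labels = rf.predict(flat)  # a single predict call for the whole chunk
    return labels.reshape(block.shape[:-1]).astype(np.float32)

ds_block = xr.apply_ufunc(
    predict_block,
    ds,
    input_core_dims=[["time"]],
    dask="parallelized",
    output_dtypes=[np.float32],
)

result = ds_block.compute()
print(result["test"].shape)  # (40, 60)
```

Because a random forest predicts each row independently, the block-wise labels should match what the per-pixel version produces for the same data.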