---
title: "xr.apply_ufunc(...)"
categories: [xarray]
date: 2025-04-11
---


## Example of how `xr.apply_ufunc(...)` can be used for pixel-wise prediction

Note: `ds.map_blocks()` is likely faster than this approach. This example is mainly meant as a demonstration and as a template to adapt for other computations.
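For comparison, here is a minimal sketch of the `map_blocks` route mentioned in the note. It uses `xr.map_blocks` on random data and assumes a hypothetical `MeanThresholdModel` as a stand-in for the trained random forest (anything with a scikit-learn style `predict(X)` would do). Its likely advantage is that `predict` is called once per chunk on all pixels of that chunk, rather than once per pixel.

```python
import dask.array as dsa
import numpy as np
import xarray as xr


class MeanThresholdModel:
    """Hypothetical stand-in for a fitted scikit-learn classifier.

    Like sklearn's ``predict``, it maps an (n_samples, n_features) array
    to an (n_samples,) array of class labels.
    """

    def predict(self, X):
        return (X.mean(axis=1) > 0.5).astype(np.int64)


model = MeanThresholdModel()

# chunked (lat, lon, time) cube; time is kept in a single chunk
data = dsa.random.random((40, 60, 12), chunks=(4, 6, 12))
da = xr.DataArray(data, dims=("lat", "lon", "time"))


def predict_block(block):
    # block: one chunk as an in-memory DataArray with dims (lat, lon, time)
    n_lat, n_lon, n_time = block.shape
    # flatten pixels into rows so predict() is called once per chunk
    labels = model.predict(block.values.reshape(-1, n_time))
    return xr.DataArray(
        labels.reshape(n_lat, n_lon).astype(np.float32),
        dims=("lat", "lon"),
    )


# template describes the output: same lat/lon chunks, no time dimension
template = xr.DataArray(
    dsa.zeros((40, 60), chunks=(4, 6), dtype=np.float32), dims=("lat", "lon")
)
pred = xr.map_blocks(predict_block, da, template=template)
result = pred.compute()
```

The `template` is needed here because the output drops the `time` dimension, so xarray cannot infer the result's shape by itself.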

In [1]:
import sys

import xarray as xr
import numpy as np
import dask
import sklearn
from sklearn.ensemble import RandomForestClassifier as RF

# local helper module (not shown) that generates the random data
from util import generate_X_y, generate_3d_dataset
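The `util` module is not part of this post. To run the notebook on its own, the following hypothetical stand-ins may help; they only reproduce the shapes, dtype and chunking visible in the outputs of this notebook, and the real helpers may differ (for instance, the real `generate_3d_dataset` might return an `xr.Dataset` rather than a `DataArray`).

```python
import dask.array as dsa
import numpy as np
import xarray as xr


def generate_X_y(n_samples, n_features, n_classes, seed=0):
    """Hypothetical stand-in for util.generate_X_y: random features and labels."""
    rng = np.random.default_rng(seed)
    X = rng.random((n_samples, n_features))
    y = rng.integers(0, n_classes, size=n_samples)
    return X, y


def generate_3d_dataset(lat, lon, time, chunks=(4, 6, -1)):
    """Hypothetical stand-in for util.generate_3d_dataset.

    Returns a dask-backed DataArray with dims (lat, lon, time); the default
    chunks mirror the (4, 6, 12) chunking shown in the repr of ds.
    """
    data = dsa.random.random((lat, lon, time), chunks=chunks)
    return xr.DataArray(
        data,
        dims=("lat", "lon", "time"),
        coords={
            "lat": np.arange(lat),
            "lon": np.arange(lon),
            "time": np.arange(time),
        },
    )
```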

In [2]:
print(sys.version)
print(dask.__version__)
print(xr.__version__)
print(np.__version__)
print(sklearn.__version__)

3.13.1 | packaged by conda-forge | (main, Dec  5 2024, 21:23:54) [GCC 13.3.0]
2025.4.0
2025.3.1
2.2.0
1.6.0


## Generate some random data for training and inference

In [3]:
n_classes = 2
n_features = 12
n_samples = 1000

lat = 40
lon = 60
time = n_features  # each pixel's time series serves as its feature vector

In [4]:
# random training data
X_train, y_train = generate_X_y(n_samples, n_features, n_classes)

In [5]:
# random "real" data to predict on
ds = generate_3d_dataset(lat, lon, time)
ds

| | Array | Chunk |
|---|---|---|
| Bytes | 225.00 kiB | 2.25 kiB |
| Shape | (40, 60, 12) | (4, 6, 12) |
| Dask graph | 100 chunks in 1 graph layer | |
| Data type | float64 numpy.ndarray | |


## Train a dummy model

In [6]:
rf = RF(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)

## Define the function to be applied via `xr.apply_ufunc(...)`

In [7]:
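def generic_func(arr):
    # arr is one pixel's 1-D time series; reshape to (1, n_features) for sklearn
    # and take [0] so a scalar label is returned, as the "(time)->()" signature expects
    return rf.predict(arr.reshape(1, -1))[0]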
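With `vectorize=True`, `xr.apply_ufunc` wraps the function in `numpy.vectorize` with a gufunc signature built from the core dimensions, equivalent to `(n)->()` here. Each call therefore receives one pixel's 1-D time series, which `reshape(1, -1)` turns into a single-sample 2-D input. A minimal sketch of that mechanism in plain NumPy, assuming a hypothetical `MeanThresholdModel` in place of `rf`:

```python
import numpy as np


class MeanThresholdModel:
    """Hypothetical stand-in for the fitted classifier `rf`."""

    def predict(self, X):
        # X: (n_samples, n_features) -> (n_samples,) integer labels
        return (X.mean(axis=1) > 0.5).astype(np.int64)


model = MeanThresholdModel()


def predict_pixel(arr):
    # arr: one pixel's time series, shape (n_time,)
    # reshape to (1, n_time) = a single sample; [0] extracts the scalar label
    return model.predict(arr.reshape(1, -1))[0]


# roughly what vectorize=True builds: loop over all but the last axis
vectorized = np.vectorize(predict_pixel, signature="(n)->()", otypes=[np.float32])

cube = np.random.default_rng(0).random((4, 6, 12))  # (lat, lon, time)
labels = vectorized(cube)  # shape (4, 6): one label per pixel
```

Because `np.vectorize` is a Python-level loop, `predict` is called once per pixel (2,400 times for the 40 × 60 grid in this notebook), which is the main reason a block-wise approach is likely to be faster.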

In [8]:
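ds_ag = xr.apply_ufunc(
    generic_func,
    ds,
    input_core_dims=[["time"]],  # time is consumed by generic_func
    dask="parallelized",  # build a lazy graph, applied chunk by chunk
    output_dtypes=[np.float32],  # dtype of the result (documented as a list)
    vectorize=True,  # call generic_func once per (lat, lon) pixel
    dask_gufunc_kwargs={"allow_rechunk": True},
)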
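`input_core_dims=[["time"]]` tells `apply_ufunc` that `time` is consumed by the function: xarray moves it to the last axis before calling the function and, since no `output_core_dims` are given, the output keeps only the remaining dimensions. A small in-memory sketch (no dask, no model) makes this visible with a dimension order where `time` is deliberately not last:

```python
import numpy as np
import xarray as xr

# toy cube with time deliberately first
cube = xr.DataArray(
    np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4),
    dims=("time", "lat", "lon"),
)


def last_axis_mean(arr):
    # apply_ufunc has moved the core dim "time" (length 2) to the last axis
    assert arr.shape[-1] == 2
    return arr.mean(axis=-1)


out = xr.apply_ufunc(last_axis_mean, cube, input_core_dims=[["time"]])
```

In the notebook's call, `dask="parallelized"` additionally runs the function per chunk, and `output_dtypes` tells dask the result dtype without computing anything.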

In [9]:
ds_ag

| | Array | Chunk |
|---|---|---|
| Bytes | 9.38 kiB | 96 B |
| Shape | (40, 60) | (4, 6) |
| Dask graph | 100 chunks in 4 graph layers | |
| Data type | float32 numpy.ndarray | |
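The sizes in the input and output reprs follow directly from shape times item size: dropping the `time` core dimension and casting to `float32` shrinks each chunk from 2.25 kiB to 96 B, while the number of chunks stays at (40 / 4) × (60 / 6) = 100. A quick check of that arithmetic:

```python
import numpy as np

KIB = 1024

# input: float64 (lat, lon, time), chunks (4, 6, 12)
in_bytes = 40 * 60 * 12 * np.dtype(np.float64).itemsize
in_chunk = 4 * 6 * 12 * np.dtype(np.float64).itemsize

# output: float32 (lat, lon), chunks (4, 6)
out_bytes = 40 * 60 * np.dtype(np.float32).itemsize
out_chunk = 4 * 6 * np.dtype(np.float32).itemsize

print(in_bytes / KIB, in_chunk / KIB)  # 225.0 2.25
print(out_bytes / KIB, out_chunk)      # 9.375 96
```

Dask rounds 9.375 kiB to the 9.38 kiB shown in the repr.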


In [10]:
ds_ag.compute()
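A useful sanity check after computing: since the model only ever sees one pixel at a time, the per-pixel result should equal a single batch prediction over all pixels reshaped to `(n_pixels, n_time)`. A sketch of that check, assuming a hypothetical `MeanThresholdModel` in place of `rf` and random data in place of `ds`; with the real objects one would compare `ds_ag.compute()` against `rf.predict(...)` the same way.

```python
import numpy as np
import xarray as xr


class MeanThresholdModel:
    """Hypothetical stand-in for the fitted classifier `rf`."""

    def predict(self, X):
        return (X.mean(axis=1) > 0.5).astype(np.int64)


model = MeanThresholdModel()

# dask-backed (lat, lon, time) cube; time stays in one chunk (-1)
cube = xr.DataArray(
    np.random.default_rng(1).random((40, 60, 12)),
    dims=("lat", "lon", "time"),
).chunk({"lat": 4, "lon": 6, "time": -1})


def predict_pixel(arr):
    # arr: one pixel's time series; [0] turns the (1,) prediction into a scalar
    return model.predict(arr.reshape(1, -1))[0]


per_pixel = xr.apply_ufunc(
    predict_pixel,
    cube,
    input_core_dims=[["time"]],
    dask="parallelized",
    output_dtypes=[np.float32],
    vectorize=True,
).compute()

# batch prediction: one row per pixel, its time series as features
batch = model.predict(cube.values.reshape(-1, cube.sizes["time"])).reshape(40, 60)
print(np.array_equal(per_pixel.values, batch))  # True
```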