---
title: "ds.map_blocks(...)"
# categories: [xarray]
# date: 2025-04-11
---

`.map_blocks(...)` applies a function to each chunk of a dask-backed `xarray.Dataset`. The following example uses `ds.map_blocks(...)` to apply a machine learning model pixel-wise.

In [1]:
import sys

import dask
import xarray as xr
import numpy as np
import sklearn
from sklearn.ensemble import RandomForestClassifier as RF

from util import generate_X_y, generate_3d_dataset
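
The `util` helpers imported above are not shown in this post. The following is only a minimal sketch of what they might look like, assuming random inputs, a single data variable named `test`, and ten chunks along each spatial dimension; the function bodies, the `seed` and `n_splits` parameters, and the start date are assumptions, not the original implementation.

```python
import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr


def generate_X_y(n_samples, n_features, n_classes, seed=42):
    """Hypothetical stand-in: random features and random class labels."""
    rng = np.random.default_rng(seed)
    X = rng.random((n_samples, n_features))
    y = rng.integers(0, n_classes, size=n_samples)
    return X, y


def generate_3d_dataset(lat, lon, time, n_splits=10):
    """Hypothetical stand-in: lazy random cube with dims (lat, lon, time)."""
    data = da.random.random(
        (lat, lon, time),
        # split lat/lon into n_splits pieces each, keep time as one chunk
        chunks=(lat // n_splits, lon // n_splits, time),
    )
    return xr.Dataset(
        {"test": (("lat", "lon", "time"), data)},
        coords={
            "lat": np.arange(lat),
            "lon": np.arange(lon),
            "time": pd.date_range("2000-01-01", periods=time, freq="D"),
        },
    )
```

With `lat=4000` and `lon=6000`, this sketch would produce 100 chunks of shape `(400, 600, time)`, which is the layout the rest of the example relies on.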

In [2]:
print(sys.version)
print(dask.__version__)
print(xr.__version__)
print(np.__version__)
print(sklearn.__version__)

3.13.1 | packaged by conda-forge | (main, Dec  5 2024, 21:23:54) [GCC 13.3.0]
2025.4.0
2025.3.1
2.2.0
1.6.0


## Generate new data to predict on
The time dimension in the following example is only a placeholder for any kind of predictor dimension. For the example to make sense (and work!), the predictor/feature dimension (here, time) must not be split across chunks: it has to form a single chunk, because the model needs all features of a pixel at once.
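
If a dataset arrives with the feature dimension split, it can be rechunked before calling `map_blocks`. A minimal sketch, assuming a small made-up dataset (`ds_demo` is a hypothetical name) whose variable is called `test` as in the rest of this example:

```python
import numpy as np
import xarray as xr

# Hypothetical small dataset; "time" is deliberately split into two chunks.
ds_demo = xr.Dataset(
    {"test": (("lat", "lon", "time"), np.random.rand(8, 6, 12))}
).chunk({"lat": 4, "lon": 3, "time": 6})

# -1 means "one chunk spanning the whole dimension";
# dimensions not listed keep their current chunking.
ds_demo = ds_demo.chunk({"time": -1})

time_chunks = ds_demo.chunksizes["time"]
lat_chunks = ds_demo.chunksizes["lat"]
```

Inspecting `ds_demo.chunksizes` before running the prediction is a cheap way to confirm that the feature dimension really is a single chunk.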

In [3]:
#| code-fold: true

n_classes = 2
n_features = 12
n_samples = 1000

lat = 4000
lon = 6000
time = n_features

In [4]:
# random training data
X_train, y_train = generate_X_y(n_samples, n_features, n_classes)

In [5]:
# random features to predict on, in a "real" shape (lat, lon, time)
ds = generate_3d_dataset(lat, lon, time)
ds

|            | Array                       | Chunk                       |
|------------|-----------------------------|-----------------------------|
| Bytes      | 2.15 GiB                    | 21.97 MiB                   |
| Shape      | (4000, 6000, 12)            | (400, 600, 12)              |
| Dask graph | 100 chunks in 1 graph layer | 100 chunks in 1 graph layer |
| Data type  | float64 numpy.ndarray       | float64 numpy.ndarray       |


## Train a dummy model

In [6]:
rf = RF(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)

## Function for chunk-wise application

In [7]:
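def generic_func(ds: xr.Dataset) -> xr.Dataset:
    """
    Apply the fitted Random Forest model to a single chunk:
    1. flatten the chunk to (pixels, features),
    2. predict one class label per pixel,
    3. reshape the predictions back to the chunk's 2D (lat, lon) shape.
    """
    # one row per pixel, one column per feature (time step)
    ds_stacked = ds.stack(ml=("lat", "lon")).transpose("ml", "time")

    # predict on input data (`rf` is the model fitted above)
    X = ds_stacked.test.data
    y_hat_1d = rf.predict(X)
    y_hat_2d = y_hat_1d.reshape((ds.lat.size, ds.lon.size))

    # copy the chunk but remove (squeeze) the time dimension
    data_out = ds.isel(time=[0]).squeeze().copy(deep=True)
    data_out.test.data = y_hat_2d

    return data_out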
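
The function above relies on the stacked pixel order matching the order that `reshape` expects. A minimal sketch to check that round trip on a tiny NumPy-backed chunk; the dataset `chunk` is made up for illustration, and picking the first feature stands in for `rf.predict`:

```python
import numpy as np
import xarray as xr

# Hypothetical 2 x 3 pixel chunk with 4 features per pixel.
chunk = xr.Dataset(
    {"test": (("lat", "lon", "time"), np.arange(24, dtype=float).reshape(2, 3, 4))},
    coords={"lat": [0, 1], "lon": [0, 1, 2]},
)

# Flatten to (pixels, features), exactly as in generic_func.
X = chunk.stack(ml=("lat", "lon")).transpose("ml", "time").test.data

# Stand-in for a model: return the first feature of every pixel.
y_hat_1d = X[:, 0]
y_hat_2d = y_hat_1d.reshape((chunk.lat.size, chunk.lon.size))

# The reshaped result should equal the first time slice of the input.
roundtrip_ok = np.array_equal(y_hat_2d, chunk.test.isel(time=0).data)
```

Because `stack` flattens `("lat", "lon")` with `lat` as the slower-varying index, reshaping to `(lat, lon)` puts every prediction back on its own pixel.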

In [8]:
ds_pred = ds.map_blocks(generic_func, template=ds.isel(time=[0]).squeeze())

In [9]:
ds_pred

`time` coordinate (scalar):

|            | Array                       | Chunk                       |
|------------|-----------------------------|-----------------------------|
| Bytes      | 8 B                         | 8 B                         |
| Shape      | ()                          | ()                          |
| Dask graph | 1 chunks in 4 graph layers  | 1 chunks in 4 graph layers  |
| Data type  | datetime64[ns] numpy.ndarray | datetime64[ns] numpy.ndarray |

`test` variable:

|            | Array                        | Chunk                        |
|------------|------------------------------|------------------------------|
| Bytes      | 183.11 MiB                   | 1.83 MiB                     |
| Shape      | (4000, 6000)                 | (400, 600)                   |
| Dask graph | 100 chunks in 4 graph layers | 100 chunks in 4 graph layers |
| Data type  | float64 numpy.ndarray        | float64 numpy.ndarray        |


In [10]:
ds_pred = ds_pred.compute()
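
One way to gain confidence in a chunk-wise prediction is to compare it with a prediction on the full array in memory. The following is a minimal sketch on a deliberately tiny dataset; `rf_small`, `ds_small`, and `predict_block` are hypothetical names, and the block function is a simplified variant of `generic_func` that builds its output from scratch instead of copying the chunk.

```python
import dask.array as da
import numpy as np
import xarray as xr
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)

# Tiny model on random data: 50 samples, 4 features, 2 classes.
rf_small = RandomForestClassifier(n_estimators=5, random_state=0)
rf_small.fit(rng.random((50, 4)), rng.integers(0, 2, size=50))

# 6 x 8 pixels, 4 features; 2 x 2 spatial chunks, features in one chunk.
values = rng.random((6, 8, 4))
ds_small = xr.Dataset(
    {"test": (("lat", "lon", "time"), da.from_array(values, chunks=(3, 4, 4)))}
)


def predict_block(block: xr.Dataset) -> xr.Dataset:
    # flatten to (pixels, features), predict, reshape back to (lat, lon)
    X = block.test.stack(ml=("lat", "lon")).transpose("ml", "time").data
    y = rf_small.predict(X).reshape(block.sizes["lat"], block.sizes["lon"])
    return xr.Dataset({"test": (("lat", "lon"), y.astype("float64"))})


# The template describes the output: same variable, no time dimension.
template = ds_small.isel(time=0, drop=True)
chunked = ds_small.map_blocks(predict_block, template=template).compute()

# Reference: predict on all pixels at once, without dask.
direct = rf_small.predict(values.reshape(-1, 4)).reshape(6, 8)
same = np.array_equal(chunked.test.data, direct)
```

Since each pixel is predicted independently of its neighbours, the chunked and the direct result should agree; a mismatch would point to a problem in the stacking or reshaping step.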