---
title: "ds.map_blocks(...)"
categories: [xarray]
date: 2025-04-11
---

## Example of how to use `ds.map_blocks(...)` for pixel-wise prediction

`.map_blocks(...)` applies a function to each chunk of a dask-backed `xarray.Dataset`.
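
Before the full example, here is a minimal sketch of the mechanics on a tiny grid. The names `ds_small` and `add_one` are made up for illustration and are not used in the rest of this post; the function receives one chunk at a time as an in-memory `xarray.Dataset`.

```python
import dask.array as da
import numpy as np
import xarray as xr

# Small dask-backed Dataset: a 4 x 6 grid split into four 2 x 3 chunks.
ds_small = xr.Dataset(
    {"test": (("lat", "lon"), da.ones((4, 6), chunks=(2, 3)))},
    coords={"lat": np.arange(4), "lon": np.arange(6)},
)


def add_one(block: xr.Dataset) -> xr.Dataset:
    # `block` holds a single chunk; return an object with the same structure.
    return block + 1


# The output has the same dims, coords and dtype as the input,
# so the input itself can serve as the template.
result = ds_small.map_blocks(add_one, template=ds_small)  # still lazy
computed = result.compute()  # runs add_one on every chunk
```

Nothing is evaluated until `.compute()` is called, which is why the pattern scales to grids that do not fit in memory at once.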

In [1]:
import dask.array as da
import xarray as xr
import numpy as np
from sklearn.ensemble import RandomForestClassifier as RF

In [2]:
n_features = 50

## Generate a dummy model

In [3]:
n_samples = 1000
# random training data
X_train = da.random.random((n_samples, n_features))
y_train = da.random.randint(0, 2, n_samples)

rf = RF(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)

## Generate new data to predict on
The time dimension in the following example is only a placeholder for any kind of predictor dimension. For the example to make sense (and work!), the predictor/feature (i.e., time) dimension must not be chunked: every block handed to the model has to contain all `n_features` values for each pixel.
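
If you are unsure how an existing array is chunked, a quick check along these lines may help. This is a sketch on a hypothetical small array (`arr` and `arr_fixed` are illustrative names) whose feature dimension is deliberately split into two chunks and then merged back into one.

```python
import dask.array as da
import xarray as xr

# Hypothetical array whose feature ("time") dimension is split into two chunks.
arr = xr.DataArray(
    da.zeros((4, 6, 10), chunks=(2, 3, 5)),
    dims=["lat", "lon", "time"],
    name="test",
)

# Chunk sizes along the feature axis: more than one entry means it is chunked.
axis = arr.get_axis_num("time")
chunks_before = arr.chunks[axis]

# Rechunk so the whole feature axis lives in a single chunk (-1 = one chunk).
arr_fixed = arr.chunk({"time": -1})
chunks_after = arr_fixed.chunks[axis]
```

The spatial chunking is left untouched; only the feature axis is merged.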

In [4]:
# random data to predict on
lat = np.arange(4000)
lon = np.arange(6000)
time = np.arange(n_features)

# the time dimension must be a single chunk (-1), otherwise it won't work
data = da.random.random(
    (lat.size, lon.size, time.size),
    chunks=(200, 200, -1),
)

In [5]:
ds = xr.DataArray(
    data,
    coords=[lat, lon, time],
    dims=["lat", "lon", "time"],
    name="test",
).to_dataset()

In [6]:
def generic_func(ds: xr.Dataset):
    """
    Flatten chunk
    Apply Random Forest model
    Recover original 2D shape
    """
    ds_stacked = ds.stack(ml=("lat", "lon")).transpose("ml", "time")

    # predict on input data
    X = ds_stacked.test.data
    y_hat_1d = rf.predict(X)
    y_hat_2d = y_hat_1d.reshape((ds.lat.size, ds.lon.size))

    # copy the chunk but remove (squeeze) only the time dimension,
    # so that a spatial dimension of length 1 is never dropped by accident
    data_out = ds.isel(time=[0]).squeeze("time").copy(deep=True)
    data_out.test.data = y_hat_2d

    return data_out
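
The flatten/reshape round trip is the step most likely to scramble pixels silently, so it can be worth checking on a tiny in-memory block. The sketch below replaces `rf.predict` with a stand-in (the per-pixel sum of the features), which is an assumption made purely so that the expected result is known.

```python
import numpy as np
import xarray as xr

n_lat, n_lon, n_feat = 3, 4, 5
values = np.random.default_rng(0).random((n_lat, n_lon, n_feat))
block = xr.Dataset(
    {"test": (("lat", "lon", "time"), values)},
    coords={
        "lat": np.arange(n_lat),
        "lon": np.arange(n_lon),
        "time": np.arange(n_feat),
    },
)

# Same flattening as in generic_func: one row per pixel, one column per feature.
stacked = block.stack(ml=("lat", "lon")).transpose("ml", "time")
X = stacked.test.data

# Stand-in for rf.predict (assumption): sum each pixel's features.
y_hat_1d = X.sum(axis=1)
y_hat_2d = y_hat_1d.reshape((block.lat.size, block.lon.size))

# Reference computed directly on the 3D block, without stacking.
expected = block.test.sum("time").data
roundtrip_ok = np.allclose(y_hat_2d, expected)
```

Because `lat` is listed first in `stack`, `lon` varies fastest in the flattened index, which matches a row-major reshape to `(lat, lon)`.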

In [7]:
# the template describes the expected output: same lat/lon grid, no time dimension
ds_pred = ds.map_blocks(generic_func, template=ds.isel(time=[0]).squeeze("time"))
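
One detail worth knowing about `template`: as far as I can tell, the lazy result takes its dtype from the template, not from what the function actually returns. A classifier typically returns integer labels, so a float template may misreport the dtype until `.compute()` is called. The sketch below uses a hypothetical threshold function `classify` in place of the random forest and casts the template to the integer dtype up front.

```python
import dask.array as da
import numpy as np
import xarray as xr

ds_small = xr.DataArray(
    da.random.random((4, 6, 3), chunks=(2, 3, -1)),
    coords=[np.arange(4), np.arange(6), np.arange(3)],
    dims=["lat", "lon", "time"],
    name="test",
).to_dataset()


def classify(block: xr.Dataset) -> xr.Dataset:
    # Hypothetical stand-in for a classifier:
    # label 1 where the mean over all features exceeds 0.5, else 0.
    return (block.mean("time") > 0.5).astype("int64")


# Template with the output's dims and coords (no time) and its integer dtype.
template = ds_small.isel(time=0, drop=True).astype("int64")

pred = ds_small.map_blocks(classify, template=template)
lazy_dtype = pred.test.dtype  # taken from the template
computed = pred.compute()
```

With the template cast this way, the dtype shown before and after computing agrees.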

In [8]:
ds_pred

| Variable | Bytes (array) | Bytes (chunk) | Shape (array) | Shape (chunk) | Dask graph | Data type |
|---|---|---|---|---|---|---|
| `time` (scalar coordinate) | 8 B | 8 B | () | () | 1 chunks in 4 graph layers | int64 numpy.ndarray |
| `test` | 183.11 MiB | 312.50 kiB | (4000, 6000) | (200, 200) | 600 chunks in 4 graph layers | float64 numpy.ndarray |

In [9]:
ds_pred = ds_pred.compute()