---
title: "ds.map_blocks(...)"
categories: [xarray]
date: 2025-04-11
---

## Example: using `ds.map_blocks(...)` for pixel-wise prediction

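This is the preferred approach for chunk-wise processing of an `xarray.Dataset`: each dask chunk is handed to a Python function as an in-memory Dataset, and the per-chunk results are reassembled lazily into a new Dataset.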

In [None]:
import dask.array as da
import xarray as xr
import numpy as np
from sklearn.ensemble import RandomForestClassifier as RF

In [None]:
# Number of features per pixel. It is reused below as the length of the time
# axis, so each pixel's time series is one feature vector for the classifier.
n_features = 50

In [None]:
n_samples = 1000
# Random training data: n_samples rows with n_features features each, binary labels.
X_train = da.random.random((n_samples, n_features))
y_train = da.random.randint(0, 2, n_samples)

rf = RF(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)

# Prediction later runs inside dask tasks, which already execute in parallel.
# Keeping n_jobs=-1 there would make every task start its own worker pool and
# oversubscribe the CPU, so predict single-threaded per block instead.
# Fitting above still used all cores; predictions are unaffected.
rf.set_params(n_jobs=1)

In [None]:
# Random input data to predict on: one n_features-long series per (lat, lon) pixel.
# Lazy dask array of 4000 * 6000 * n_features float64 values (~9.6 GB when computed).
lat = np.arange(4000)
lon = np.arange(6000)
time = np.arange(n_features)
# generic_func uses the whole time axis of each block as the feature vector, so
# time must be a single chunk (-1); lat/lon are chunked automatically by dask.
data = da.random.random(
    (lat.size, lon.size, time.size),
    chunks=("auto", "auto", -1),
)

In [None]:
# Wrap the lazy array in a labelled DataArray named "test" and turn it into a
# single-variable Dataset (the dask chunks are kept as-is).
test_da = xr.DataArray(
    data,
    dims=("lat", "lon", "time"),
    coords={"lat": lat, "lon": lon, "time": time},
    name="test",
)
ds = test_da.to_dataset()
del test_da

In [None]:
def generic_func(ds: xr.Dataset) -> xr.Dataset:
    """Predict one class label per (lat, lon) pixel of a single block.

    Meant to be passed to ``xr.Dataset.map_blocks``: ``ds`` is then one chunk of
    the full dataset, backed by in-memory arrays. Each pixel's ``time`` series is
    used as the feature vector for the globally fitted classifier ``rf``.

    Parameters
    ----------
    ds : xr.Dataset
        Block with a ``test`` variable over ``("lat", "lon", "time")``. The whole
        ``time`` axis must be present, since it forms the feature vector that
        ``rf`` was trained on.

    Returns
    -------
    xr.Dataset
        ``ds`` reduced to its first time step (``time`` kept as a scalar
        coordinate), with ``test`` replaced by the predicted labels on
        ``("lat", "lon")``, cast to the input dtype of ``test``.
    """
    # One row per pixel, one column per time step. Stacking "lat" before "lon"
    # flattens in row-major order, which the reshape below relies on.
    ds_stacked = ds.stack(ml=("lat", "lon")).transpose("ml", "time")

    # predict on input data
    X = ds_stacked["test"].data
    y_hat_1d = rf.predict(X)
    y_hat_2d = y_hat_1d.reshape((ds.sizes["lat"], ds.sizes["lon"]))

    # Index time step 0 directly: this removes only the "time" dimension.
    # A bare .squeeze() would also drop lat/lon when an edge block has size 1,
    # breaking the match with the template.
    data_out = ds.isel(time=0)

    # Build the output variable from the predictions instead of deep-copying data
    # that is immediately overwritten; copy(data=...) keeps dims, coords and attrs.
    # Cast to the input dtype so the computed block matches the dtype that a
    # template derived from ds declares to map_blocks (predict returns class
    # labels, e.g. int64).
    data_out["test"] = data_out["test"].copy(
        deep=False,
        data=y_hat_2d.astype(ds["test"].dtype, copy=False),
    )

    return data_out

In [None]:
# Lazily apply generic_func to every chunk. The template describes the output:
# the same lat/lon layout and chunks as ds, with the time dimension dropped
# (time=0 selects the first step and keeps time as a scalar coordinate).
ds_pred = ds.map_blocks(
    generic_func,
    template=ds.isel(time=0),
)

In [None]:
# Display the lazy, dask-backed result of map_blocks; nothing has been computed yet.
ds_pred

In [None]:
# Trigger the computation: runs generic_func on every block and returns the
# result as an in-memory Dataset.
ds_pred = ds_pred.compute()