# Spatial splitting with stratification

The SRAI library contains dedicated functions for splitting a point dataset into train / test (and optionally validation) splits that are separated spatially while staying stratified on a given target.

These functions only work on point datasets. They use the [`H3`](https://h3geo.org/) indexing system to group nearby points into cells and assign whole H3 cells to different splits.

---

When working with most machine learning datasets, splitting into training and testing sets is straightforward: pick a random subset for testing, and (optionally) use stratification to keep the distribution of a target variable balanced between the two. This works fine when the data points are independent.

Geospatial data plays by different rules. Nearby locations often share similar characteristics, a phenomenon called spatial autocorrelation. If we split data randomly, our training and test sets might end up covering the same areas, meaning the model is “tested” on locations that are practically identical to ones it has already seen. This can make performance look much better than it really is, and it prevents us from testing how well the model generalizes to unseen areas.

That’s why for geo-related tasks, we need spatial splitting: making sure the training and test sets are separated in space so that evaluation reflects real-world conditions. Sometimes we also want to stratify these spatial splits by a numerical value to ensure both sets still have similar value distributions. Standard `train_test_split` functions can’t combine these two needs, so we provide a dedicated function for spatially aware splitting with optional stratification.

---

This notebook shows how the different splitting modes work, using a buildings dataset from [Overture Maps Foundation](https://overturemaps.org/).

### How does it work?

To separate the input dataset into multiple outputs, the H3 indexing system is used to keep groups of nearby points together.

First, the algorithm transforms the points into H3 cells at a given resolution and calculates statistics per H3 cell (the number of points per bucket / category).

Next, all H3 cells are shuffled (with an optional `random_state` to ensure reproducibility) and iterated over one by one.

For each split (train, validation, test) and each bucket within a split, the current number of assigned points is tracked. For every H3 cell, the algorithm calculates what each split's number of points would become if the cell's points were added to it, along with the resulting difference to the expected ratio. The current H3 cell is assigned to the split where this difference is the lowest.

After iterating over all H3 cells, the original dataset of points is split based on the list of assigned H3 cells.

Finally, a split report is printed with the differences between the expected and actual ratios.
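The assignment loop described above can be sketched in plain Python. This is a simplified illustration, not SRAI's implementation: cells are plain string ids instead of H3 indexes, each cell carries a `Counter` of points per bucket, and the helper `greedy_spatial_split` is hypothetical. The scoring rule (the signed difference between the ratio a split would reach and its target, weighted by the cell's points) is an assumption about how "difference to the expected ratio" is measured.

```python
import random
from collections import Counter
from typing import Dict, Optional


def greedy_spatial_split(
    cell_counts: Dict[str, Counter],
    expected_ratios: Dict[str, float],
    random_state: Optional[int] = None,
) -> Dict[str, str]:
    """Assign whole cells to splits (hypothetical helper, not part of SRAI).

    cell_counts maps a cell id to a Counter of points per bucket,
    expected_ratios maps a split name to its target fraction (summing to 1).
    Returns a mapping from cell id to split name.
    """
    # Total number of points per bucket in the whole dataset.
    totals: Counter = Counter()
    for counts in cell_counts.values():
        totals.update(counts)

    # Running number of points per bucket already assigned to each split.
    current = {split: Counter() for split in expected_ratios}

    # Shuffle the cells, reproducibly when random_state is given.
    cells = sorted(cell_counts)
    random.Random(random_state).shuffle(cells)

    assignment = {}
    for cell in cells:
        counts = cell_counts[cell]
        n_points = sum(counts.values())
        best_split, best_score = None, float("inf")
        for split, expected in expected_ratios.items():
            # Signed gap between the ratio this split would reach after taking
            # the cell and its target ratio, weighted by points per bucket.
            score = (
                sum(
                    n * ((current[split][bucket] + n) / totals[bucket] - expected)
                    for bucket, n in counts.items()
                )
                / n_points
            )
            # The most under-filled split (lowest score) wins the cell.
            if score < best_score:
                best_split, best_score = split, score
        assignment[cell] = best_split
        current[best_split].update(counts)
    return assignment


# Usage: 100 hypothetical cells, each holding a single "low" or "high" point.
cells = {f"cell_{i}": Counter({"low" if i % 2 else "high": 1}) for i in range(100)}
assignment = greedy_spatial_split(cells, {"train": 0.8, "test": 0.2}, random_state=42)
print(Counter(assignment.values()))
```

Using the signed gap rather than its absolute value matters in this sketch: with absolute values, the split with the smallest target would start closest to its goal and keep absorbing almost every cell.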

In [None]:
import contextily as cx
import geopandas as gpd
import matplotlib.pyplot as plt
import overturemaestro as om
import pyarrow.compute as pc
import seaborn as sns
from matplotlib.patches import Patch

from srai.h3 import h3_to_geoseries, shapely_geometry_to_h3
from srai.spatial_split import spatial_split_points, train_test_spatial_split

Let's start by downloading example data. Here we will use Vancouver buildings from the Overture Maps dataset.

We only want buildings with both `height` and `subtype` columns filled.

Height will be used in the numerical split example and subtype in the categorical split example.

---

Because the splitting only works on points, we will add a centroid for each building as an additional column. Centroids are calculated in a projected Coordinate Reference System (NAD83 / UTM zone 10N) and then converted back to WGS 84 (EPSG:4326).

In [None]:
VANCOUVER_BOUNDING_BOX = (-123.148670, 49.255555, -123.076572, 49.296907)
VANCOUVER_PROJECTED_CRS = 26910  # NAD83 / UTM zone 10N
H3_RESOLUTION = 9

In [None]:
buildings = om.convert_bounding_box_to_geodataframe(
    theme="buildings",
    type="building",
    bbox=VANCOUVER_BOUNDING_BOX,
    release="2025-07-23.0",
    pyarrow_filter=pc.field("subtype").is_valid() & pc.field("height").is_valid(),
    columns_to_download=["subtype", "height"],
)
buildings["centroid"] = buildings.to_crs(VANCOUVER_PROJECTED_CRS).centroid.to_crs(4326)
buildings

## Without target column - default

A target column isn't required for spatial splitting.

By default, the algorithm calculates the density of points per H3 cell and uses it for stratification. This way, both splits contain both dense and sparse regions.
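A minimal sketch of how point density could be turned into stratification buckets, assuming the per-cell point counts are binned into quantile buckets. The helper `density_buckets` and the quantile binning are hypothetical; the notebook only states that the number of points per H3 cell is used for stratification.

```python
import bisect
import statistics
from collections import Counter


def density_buckets(point_cells, n_bins=7):
    """Map each cell id to a density bucket (hypothetical helper, not SRAI's API).

    point_cells holds one cell id per point, e.g. the H3 cell containing it.
    """
    # Density = number of points falling into each cell.
    counts = Counter(point_cells)
    # n_bins - 1 quantile cut points over the per-cell counts.
    cut_points = statistics.quantiles(counts.values(), n=n_bins)
    # Bucket index in [0, n_bins - 1]; denser cells get higher indexes.
    return {cell: bisect.bisect_right(cut_points, count) for cell, count in counts.items()}


# Usage: five hypothetical cells with 1, 2, 5, 10 and 20 points, two buckets.
points = ["a"] * 1 + ["b"] * 2 + ["c"] * 5 + ["d"] * 10 + ["e"] * 20
print(density_buckets(points, n_bins=2))  # {'a': 0, 'b': 0, 'c': 1, 'd': 1, 'e': 1}
```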

In [None]:
train_default_gdf, test_default_gdf = train_test_spatial_split(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column=None,
    test_size=0.2,
    random_state=42,
)
train_default_gdf

In [None]:
covering_h3_cells = h3_to_geoseries(shapely_geometry_to_h3(buildings["centroid"], H3_RESOLUTION))

In [None]:
ax = train_default_gdf.plot(color="royalblue", figsize=(15, 15), label="train", legend=True)
test_default_gdf.plot(color="orange", ax=ax, label="test", legend=True)
covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3)

ax.set_title("Vancouver buildings data - count split")
ax.legend(
    handles=[Patch(facecolor="royalblue"), Patch(facecolor="orange")], labels=["Train", "Test"]
)
cx.add_basemap(ax, source=cx.providers.CartoDB.PositronNoLabels, crs=4326, zoom=15)
plt.show()

## With numerical target column

If a target column is provided, it is treated as numerical by default: its values are split into buckets (`n_bins`, default: `7`) and the split is stratified on those buckets. As a result, the value distribution is roughly the same in both splits.
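To check that the value distribution really is similar, one option is to bin both splits with the same quantile cut points and compare the share of points per bucket. A standard-library sketch; the helper `bucket_shares` and the sample heights are hypothetical, and these cut points are not necessarily the ones SRAI computes internally.

```python
import bisect
import statistics
from collections import Counter


def bucket_shares(values, cut_points):
    """Fraction of values in each bucket defined by the sorted cut points."""
    buckets = Counter(bisect.bisect_right(cut_points, v) for v in values)
    return {b: buckets[b] / len(values) for b in range(len(cut_points) + 1)}


# Hypothetical building heights (m) already split into train and test.
train_heights = [3.0, 4.5, 6.0, 8.0, 10.0, 12.0, 15.0, 20.0, 30.0, 45.0]
test_heights = [3.5, 7.0, 11.0, 25.0]

# Cut points computed once on all values, so both splits share bucket edges.
cut_points = statistics.quantiles(train_heights + test_heights, n=7)
print(bucket_shares(train_heights, cut_points))
print(bucket_shares(test_heights, cut_points))
```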

In [None]:
train_height_gdf, test_height_gdf = train_test_spatial_split(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column="height",
    n_bins=7,
    test_size=0.2,
    random_state=42,
)

In [None]:
ax = sns.kdeplot(
    data=train_height_gdf,
    x="height",
    fill=True,
    label="train",
    log_scale=True,
)
sns.kdeplot(
    data=test_height_gdf,
    x="height",
    fill=True,
    label="test",
    ax=ax,
)
ax.legend()
ax.set_xlim(left=1)
plt.show()

In [None]:
train_covering_h3_cells = h3_to_geoseries(
    shapely_geometry_to_h3(train_height_gdf["centroid"], H3_RESOLUTION)
)
test_covering_h3_cells = h3_to_geoseries(
    shapely_geometry_to_h3(test_height_gdf["centroid"], H3_RESOLUTION)
)

In [None]:
with plt.rc_context({"hatch.linewidth": 0.3}):
    ax = buildings.plot(
        gpd.pd.qcut(buildings["height"], 7),
        figsize=(15, 15),
        cmap="Spectral_r",
        legend=True,
        legend_kwds=dict(title="Height category (m)"),
    )
    train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3)
    test_covering_h3_cells.plot(
        ax=ax, linewidth=0.3, color=(0, 0, 0, 0), edgecolor="black", hatch="//"
    )

    ax2 = ax.twinx()
    ax2.get_yaxis().set_visible(False)
    ax2.legend(
        handles=[
            Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
            Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
        ],
        labels=["Train", "Test"],
        loc=2,
    )

    ax.set_title("Vancouver buildings data - numerical split")
    cx.add_basemap(ax, source=cx.providers.CartoDB.PositronNoLabels, crs=4326, zoom=15, alpha=0.5)
    plt.show()

## With categorical target column

Stratification can also be based on an existing categorical column, without bucketing.

In that case, the `categorical` parameter must be set to `True`.
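With a categorical target, the per-cell statistics reduce to counts of each category inside each cell, so no binning is needed. A sketch with hypothetical (cell id, subtype) pairs and a hypothetical helper `category_counts_per_cell`; in the notebook, the cell ids would come from H3 indexing of the centroids.

```python
from collections import Counter, defaultdict


def category_counts_per_cell(cells, categories):
    """Count points per category inside each cell (hypothetical helper)."""
    stats = defaultdict(Counter)
    for cell, category in zip(cells, categories):
        stats[cell][category] += 1
    return dict(stats)


# Hypothetical points: the cell containing each point and its building subtype.
cells = ["c1", "c1", "c1", "c2", "c2", "c3"]
subtypes = ["residential", "residential", "commercial", "residential", "industrial", "commercial"]
print(category_counts_per_cell(cells, subtypes))
```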

In [None]:
buildings["subtype"].value_counts()

In [None]:
train_categorical_gdf, test_categorical_gdf = train_test_spatial_split(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column="subtype",
    categorical=True,
    test_size=0.2,
    random_state=42,
)

In [None]:
train_categories_stats = train_categorical_gdf["subtype"].value_counts().reset_index()
train_categories_stats["count"] /= train_categories_stats["count"].sum()
train_categories_stats["split"] = "train"

test_categories_stats = test_categorical_gdf["subtype"].value_counts().reset_index()
test_categories_stats["count"] /= test_categories_stats["count"].sum()
test_categories_stats["split"] = "test"

In [None]:
sns.barplot(
    data=gpd.pd.concat([train_categories_stats, test_categories_stats]),
    x="count",
    y="subtype",
    hue="split",
)
plt.show()

In [None]:
train_covering_h3_cells = h3_to_geoseries(
    shapely_geometry_to_h3(train_categorical_gdf["centroid"], H3_RESOLUTION)
)
test_covering_h3_cells = h3_to_geoseries(
    shapely_geometry_to_h3(test_categorical_gdf["centroid"], H3_RESOLUTION)
)

In [None]:
with plt.rc_context({"hatch.linewidth": 0.3}):
    ax = buildings.plot(
        "subtype",
        figsize=(15, 15),
        cmap="Paired",
        legend=True,
        legend_kwds=dict(title="Building subtype"),
    )
    train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3)
    test_covering_h3_cells.plot(
        ax=ax, linewidth=0.3, color=(0, 0, 0, 0), edgecolor="black", hatch="//"
    )

    ax2 = ax.twinx()
    ax2.get_yaxis().set_visible(False)
    ax2.legend(
        handles=[
            Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
            Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
        ],
        labels=["Train", "Test"],
        loc=2,
    )

    ax.set_title("Vancouver buildings data - categorical split")
    cx.add_basemap(ax, source=cx.providers.CartoDB.PositronNoLabels, crs=4326, zoom=15, alpha=0.5)
    plt.show()

## Splitting into three datasets at once

Another function, `spatial_split_points`, splits the dataset into three groups at once (train, validation and test).

Usually, users split data into train and test sets and then run the splitting again to get a validation set, but `SRAI` exposes a function that splits directly into three sets. It returns a dictionary of split data keyed by split name (`"train"`, `"validation"`, `"test"`).
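The three-way call that follows passes `test_size` and `validation_size` as point counts rather than fractions. A sketch of how such sizes could be normalized into target ratios, assuming a scikit-learn-style convention (a float below 1 is a fraction, an integer is an absolute number of points). The helper `resolve_ratios` is hypothetical, and SRAI's exact validation rules may differ.

```python
from typing import Dict, Union

Size = Union[int, float]


def resolve_ratios(n_points: int, test_size: Size, validation_size: Size = 0) -> Dict[str, float]:
    """Turn fraction-or-count sizes into per-split target ratios (hypothetical)."""

    def to_ratio(size: Size) -> float:
        # Assumed convention: floats in [0, 1) are fractions, integers are point counts.
        if isinstance(size, float):
            if not 0.0 <= size < 1.0:
                raise ValueError(f"fractional size must be in [0, 1), got {size}")
            return size
        if not 0 <= size <= n_points:
            raise ValueError(f"size must be between 0 and {n_points}, got {size}")
        return size / n_points

    test_ratio = to_ratio(test_size)
    validation_ratio = to_ratio(validation_size)
    train_ratio = 1.0 - test_ratio - validation_ratio
    if train_ratio <= 0:
        raise ValueError("test and validation sizes leave no points for training")
    return {"train": train_ratio, "validation": validation_ratio, "test": test_ratio}


# Usage with hypothetical numbers: 10,000 points, 1,000 for test and 500 for validation.
print(resolve_ratios(10_000, test_size=1000, validation_size=500))
```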

In [None]:
splits = spatial_split_points(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column=None,
    # Can also be passed as an expected number of points, not only a fraction
    test_size=1000,
    validation_size=500,
    random_state=42,
)

In [None]:
print(splits.keys())

In [None]:
ax = splits["train"].plot(color="royalblue", figsize=(15, 15), label="train", legend=True)
splits["test"].plot(color="orange", ax=ax, label="test", legend=True)
splits["validation"].plot(color="limegreen", ax=ax, label="validation", legend=True)
covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3)

ax.set_title("Vancouver buildings data - count split into three sets")
ax.legend(
    handles=[Patch(facecolor="royalblue"), Patch(facecolor="orange"), Patch(facecolor="limegreen")],
    labels=["Train", "Test", "Validation"],
)
cx.add_basemap(ax, source=cx.providers.CartoDB.PositronNoLabels, crs=4326, zoom=15)
plt.show()

### Parse split report manually

Set `return_split_stats=True` to get the split report as a pandas DataFrame and validate the split ratios manually.

Set `verbose=False` to disable the printed report.

In [None]:
splits, split_report = spatial_split_points(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column=None,
    # Can also be passed as an expected number of points, not only a fraction
    test_size=1000,
    validation_size=500,
    random_state=42,
    return_split_stats=True,
    verbose=False,
)
split_report

In [None]:
split_report[["train_ratio", "validation_ratio", "test_ratio"]].mean()

In [None]:
split_report[
    ["train_ratio_difference", "validation_ratio_difference", "test_ratio_difference"]
].mean()

In [None]:
split_report[["train_points", "validation_points", "test_points"]].sum()

## Different H3 resolutions

You can perform splitting at different H3 resolutions, and the choice of resolution will affect the results.

- Higher resolutions (smaller hexagons) produce a split ratio closer to your target, but the regions are physically closer together, which reduces true spatial separation.
- Lower resolutions (larger hexagons) improve spatial separation but may cause the actual split ratio to deviate more from the target.

As a rule of thumb, choose the lowest resolution that still keeps the split ratio difference within an acceptable range for your use case.
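The rule of thumb can be expressed as a small selection step: measure the ratio difference at each candidate resolution (the loop further down reports it in each plot title) and keep the lowest resolution within tolerance. The helper `lowest_acceptable_resolution` and the difference values are hypothetical placeholders, not measured results.

```python
from typing import Dict, Optional


def lowest_acceptable_resolution(
    ratio_differences: Dict[int, float], tolerance: float
) -> Optional[int]:
    """Return the lowest H3 resolution whose |expected - actual| ratio gap is within tolerance."""
    acceptable = [res for res, diff in ratio_differences.items() if abs(diff) <= tolerance]
    # Lower resolution = larger hexagons = stronger spatial separation.
    return min(acceptable) if acceptable else None


# Placeholder differences per resolution (illustrative only, not measured).
differences = {7: 0.08, 8: 0.03, 9: 0.01, 10: 0.004}
print(lowest_acceptable_resolution(differences, tolerance=0.02))  # 9
```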

In [None]:
def split_per_resolution(resolution: int) -> None:
    """Split the data using given resolution."""
    test_ratio = 0.4
    _train_gdf, _test_gdf = train_test_spatial_split(
        input_gdf=buildings,
        parent_h3_resolution=resolution,
        geometry_column="centroid",
        target_column="height",
        test_size=test_ratio,
        random_state=42,
        verbose=False,
    )

    actual_test_ratio = len(_test_gdf) / len(buildings)
    test_ratio_diff = test_ratio - actual_test_ratio

    _train_covering_h3_cells = h3_to_geoseries(
        shapely_geometry_to_h3(_train_gdf["centroid"], resolution)
    )
    _test_covering_h3_cells = h3_to_geoseries(
        shapely_geometry_to_h3(_test_gdf["centroid"], resolution)
    )

    with plt.rc_context({"hatch.linewidth": 0.3}):
        ax = _train_gdf.plot(color="royalblue", figsize=(15, 15))
        _test_gdf.plot(color="orange", ax=ax)
        _train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3)
        _test_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3)

        ax.legend(
            handles=[Patch(facecolor="royalblue"), Patch(facecolor="orange")],
            labels=["Train", "Test"],
        )

        ax.set_title(
            f"Vancouver buildings data - numerical split (H3 resolution: {resolution})\n"
            f"Expected test ratio: {test_ratio:.2f}, Actual test ratio: {actual_test_ratio:.2f}, "
            f"Diff: {test_ratio_diff:.3f}"
        )
        cx.add_basemap(
            ax, source=cx.providers.CartoDB.PositronNoLabels, crs=4326, zoom=15, alpha=0.5
        )
        plt.show()

In [None]:
for h3_res in range(7, 11):
    split_per_resolution(h3_res)

## What to do with timeseries data?

TODO: add info about recommendation for working with timeseries