# Spatial splitting with stratification

The SRAI library contains dedicated functions for splitting a points dataset into train / test (and optionally validation) splits by separating the points spatially, while also keeping them stratified based on a given target.

These functions only work for point datasets. They use the [`H3`](https://h3geo.org/) indexing system to group nearby points together and assign whole H3 cells to different splits.

---

When working with most machine learning datasets, splitting into training and testing sets is straightforward: pick a random subset for testing, and (optionally) use stratification to keep the distribution of a target variable balanced between the two. This works fine when the data points are independent.

Geospatial data plays by different rules. Nearby locations often share similar characteristics, a phenomenon called spatial autocorrelation. If we split data randomly, our training and test sets end up covering the same areas, so the model is “tested” on locations that are practically identical to ones it has already seen. This can make performance look much better than it really is, and it doesn't tell us how well the model generalizes to new, unseen areas.

That’s why for geo-related tasks we need spatial splitting: making sure the training and test sets are separated in space so that evaluation reflects real-world conditions. Sometimes we also want to stratify these spatial splits by a target variable (numerical or categorical) to ensure both sets still have similar value distributions. Standard `train_test_split` functions can’t combine these two needs, so we provide a dedicated function for spatially aware splitting with optional stratification.

---

This notebook shows how the different splitting modes work, using a buildings dataset from [Overture Maps Foundation](https://overturemaps.org/).

## How does it work?

To separate the input dataset into multiple splits, the H3 indexing system is used to keep nearby points together: all points that fall into the same H3 cell end up in the same split.

First, the algorithm transforms the points into H3 cells at a given resolution and calculates statistics per H3 cell (number of points per bucket / category).

Next, all H3 cells are shuffled (with an optional `random_state` to ensure reproducibility) and iterated one by one.

For each split (train, validation, test) and each bucket, the algorithm keeps track of the current number of assigned points. For every H3 cell, it calculates what the new number of points would be if the cell were added to each split, and how far the resulting ratio would be from the expected ratio. The cell is assigned to the split where this difference is the lowest.

After all H3 cells have been processed, the original dataset of points is split based on the H3 cell assignments.

Finally, a split report is printed with the differences between the expected and actual ratios.
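The assignment loop described above can be sketched in plain Python. This is a simplified illustration under our own assumptions, not SRAI's implementation: the helper `greedy_assign_cells`, its input format (a mapping from cell ID to per-bucket point counts) and the scoring rule (the signed difference between the would-be ratio and the expected ratio, summed over the buckets present in the cell, so the split that is furthest below its target wins) are all hypothetical.

```python
import random


def greedy_assign_cells(cell_stats, expected_ratios, random_state=None):
    """Greedily assign cells to splits so per-bucket ratios track the expected ones.

    cell_stats: {cell_id: {bucket: number_of_points}}
    expected_ratios: {split_name: expected_fraction}, fractions summing to 1
    Returns {cell_id: split_name}.
    """
    # Total number of points per bucket over the whole dataset
    totals = {}
    for counts in cell_stats.values():
        for bucket, n in counts.items():
            totals[bucket] = totals.get(bucket, 0) + n

    # Running number of assigned points per split and bucket
    current = {split: {bucket: 0 for bucket in totals} for split in expected_ratios}

    # Shuffle cells reproducibly (sorted first so the input order does not matter)
    cells = sorted(cell_stats)
    random.Random(random_state).shuffle(cells)

    assignment = {}
    for cell in cells:
        counts = cell_stats[cell]
        best_split, best_diff = None, float("inf")
        for split, expected in expected_ratios.items():
            # Signed difference between the would-be ratio and the expected ratio,
            # summed over the buckets that actually have points in this cell
            diff = sum(
                (current[split][bucket] + n) / totals[bucket] - expected
                for bucket, n in counts.items()
                if n > 0
            )
            if diff < best_diff:  # ties go to the split listed first
                best_split, best_diff = split, diff
        assignment[cell] = best_split
        for bucket, n in counts.items():
            current[best_split][bucket] += n
    return assignment


# Toy example: 10 cells with one point each, all in the same bucket
toy_cells = {f"cell_{i}": {"bucket_0": 1} for i in range(10)}
toy_assignment = greedy_assign_cells(toy_cells, {"train": 0.8, "test": 0.2}, random_state=42)
print(sum(split == "test" for split in toy_assignment.values()), "cells in test")
```

With 10 identical single-point cells and an 80 / 20 target, this loop ends with 8 cells in `train` and 2 in `test`.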

In [None]:
import contextily as cx
import geopandas as gpd
import matplotlib.pyplot as plt
import overturemaestro as om
import pyarrow.compute as pc
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.patches import Patch
from sklearn.model_selection import train_test_split

from srai.h3 import h3_to_geoseries, shapely_geometry_to_h3
from srai.spatial_split import spatial_split_points, train_test_spatial_split

Let's start by downloading example data. Here we will use Vancouver buildings from the Overture Maps dataset.

We only want buildings with both `height` and `subtype` columns filled.

Height will be used in the numerical split example and subtype in the categorical split example.

---

Because the splitting only works on points, we add the centroid of each building as an additional column. Centroids are calculated in a projected Coordinate Reference System (NAD83 / UTM zone 10N) and then converted back to WGS84 (EPSG:4326).

In [None]:
VANCOUVER_BOUNDING_BOX = (-123.148670, 49.255555, -123.076572, 49.296907)
VANCOUVER_PROJECTED_CRS = 26910  # NAD83 / UTM zone 10N
H3_RESOLUTION = 9

In [None]:
buildings = om.convert_bounding_box_to_geodataframe(
    theme="buildings",
    type="building",
    bbox=VANCOUVER_BOUNDING_BOX,
    release="2025-07-23.0",
    pyarrow_filter=pc.field("subtype").is_valid() & pc.field("height").is_valid(),
    columns_to_download=["subtype", "height"],
)
buildings["centroid"] = buildings.to_crs(VANCOUVER_PROJECTED_CRS).centroid.to_crs(4326)
buildings

First, let's see what a random split without spatial context looks like.

In [None]:
# train_test_split function from scikit-learn
random_train_gdf, random_test_gdf = train_test_split(buildings, test_size=0.2, random_state=42)

ax = random_train_gdf.plot(color="#1E88E5", figsize=(15, 12), label="train", legend=True)
random_test_gdf.plot(color="#FFC107", ax=ax, label="test", legend=True)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5)

ax.set_title("Vancouver buildings data - random split")
ax.legend(
    handles=[Patch(facecolor="#1E88E5"), Patch(facecolor="#FFC107")],
    labels=["Train", "Test"],
)
cx.add_basemap(
    ax,
    source=cx.providers.CartoDB.VoyagerNoLabels,
    crs=4326,
    zoom=15,
    alpha=0.8,
)
ax.set_axis_off()
plt.show()

As shown, the buildings are split at random, resulting in both sets covering the same geographic area.

With this approach, you can’t properly evaluate the model’s ability to generalize based on spatial patterns.
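One way to quantify the problem is to count how many test points share an H3 cell with at least one training point. The sketch below does this in plain Python for a random point-level split; the cell IDs are made-up labels standing in for real H3 indexes (for example, the ones returned by `shapely_geometry_to_h3`), and the helper name `shared_cell_fraction` is hypothetical.

```python
import random


def shared_cell_fraction(train_cells, test_cells):
    """Fraction of test points whose cell also contains at least one training point."""
    if not test_cells:
        return 0.0
    train_set = set(train_cells)
    return sum(cell in train_set for cell in test_cells) / len(test_cells)


# Toy data: 1000 points scattered at random over 50 made-up cells, then split 80 / 20
rng = random.Random(42)
point_cells = [f"cell_{rng.randrange(50)}" for _ in range(1000)]
train_cells, test_cells = point_cells[:800], point_cells[800:]
share = shared_cell_fraction(train_cells, test_cells)
print(f"Test points sharing a cell with train: {share:.0%}")
```

A spatial split that assigns whole cells to a single set brings this fraction to zero at the splitting resolution by construction.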

## Without target column - default

A target column isn't required for spatial splitting.

By default, the algorithm calculates the density of points (number of points) per H3 cell and uses it for stratification. This way, both splits contain both dense and sparse regions.
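The idea behind the default mode can be illustrated by counting points per cell and grouping cells into density buckets, which then act as the stratification target. This is a NumPy sketch under our own assumptions: quantile-based bucketing, the hypothetical `density_buckets` helper and toy cell labels. SRAI's actual bucketing may differ.

```python
from collections import Counter

import numpy as np


def density_buckets(point_cells, n_bins=3):
    """Assign every cell a density bucket (0 = sparsest) based on its point count."""
    counts = Counter(point_cells)
    cells = sorted(counts)
    values = np.array([counts[cell] for cell in cells], dtype=float)
    # Inner quantile edges, e.g. the 1/3 and 2/3 quantiles when n_bins=3
    edges = np.quantile(values, np.linspace(0, 1, n_bins + 1)[1:-1])
    # With n_bins - 1 inner edges, np.digitize yields labels 0 .. n_bins - 1
    labels = np.digitize(values, edges)
    return {cell: int(label) for cell, label in zip(cells, labels)}


# Toy example: one sparse, one medium and one dense cell
print(density_buckets(["a"] + ["b"] * 5 + ["c"] * 20))  # → {'a': 0, 'b': 1, 'c': 2}
```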

In [None]:
train_default_gdf, test_default_gdf = train_test_spatial_split(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column=None,
    test_size=0.2,
    random_state=42,
)
train_default_gdf

In [None]:
covering_h3_cells = h3_to_geoseries(shapely_geometry_to_h3(buildings["centroid"], H3_RESOLUTION))

In [None]:
ax = train_default_gdf.plot(color="#1E88E5", figsize=(15, 12), zorder=2)
test_default_gdf.plot(color="#FFC107", ax=ax, zorder=2)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5, zorder=2)
covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, alpha=0.5, zorder=1)

ax.set_title("Vancouver buildings data - count split")
ax.legend(
    handles=[Patch(facecolor="#1E88E5"), Patch(facecolor="#FFC107")],
    labels=["Train", "Test"],
)
cx.add_basemap(
    ax,
    source=cx.providers.CartoDB.VoyagerNoLabels,
    crs=4326,
    zoom=15,
    alpha=0.8,
)
ax.set_axis_off()
plt.show()

## With numerical target column

If a target column is provided, it is treated as a numerical column by default, split into buckets (default: `7`, controlled by `n_bins`) and stratified based on those buckets. The value distribution will be roughly the same in both splits.
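The per-cell statistics from the algorithm description can be sketched with pandas: bin the target into `n_bins` buckets and count points per (cell, bucket) pair. Quantile bins via `pd.qcut` are an assumption here (SRAI may define its buckets differently), the helper `cell_bucket_counts` is hypothetical, and the cell IDs and heights are toy values.

```python
import pandas as pd


def cell_bucket_counts(df, cell_column, target_column, n_bins=7):
    """Return a cells x buckets table with the number of points per bucket in each cell."""
    # Quantile-based buckets labelled 0 .. n_bins - 1 (duplicate edges are dropped)
    buckets = pd.qcut(df[target_column], q=n_bins, labels=False, duplicates="drop")
    return (
        df.assign(bucket=buckets)
        .groupby([cell_column, "bucket"])
        .size()
        .unstack(fill_value=0)
    )


toy = pd.DataFrame(
    {
        "h3_cell": ["c1", "c1", "c2", "c2", "c3", "c3"],
        "height": [3.0, 4.0, 10.0, 12.0, 30.0, 45.0],
    }
)
print(cell_bucket_counts(toy, "h3_cell", "height", n_bins=3))
```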

In [None]:
train_height_gdf, test_height_gdf = train_test_spatial_split(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column="height",
    n_bins=7,
    test_size=0.2,
    random_state=42,
)

In [None]:
ax = sns.kdeplot(
    data=train_height_gdf,
    x="height",
    fill=True,
    label="train",
    log_scale=True,
)
sns.kdeplot(
    data=test_height_gdf,
    x="height",
    fill=True,
    label="test",
    ax=ax,
)
ax.legend()
ax.set_xlim(left=1)
plt.show()

In [None]:
train_covering_h3_cells = h3_to_geoseries(
    shapely_geometry_to_h3(train_height_gdf["centroid"], H3_RESOLUTION)
)
test_covering_h3_cells = h3_to_geoseries(
    shapely_geometry_to_h3(test_height_gdf["centroid"], H3_RESOLUTION)
)

In [None]:
with plt.rc_context({"hatch.linewidth": 0.3}):
    ax = buildings.plot(
        gpd.pd.qcut(buildings["height"], 7),
        figsize=(15, 12),
        cmap="Spectral_r",
        legend=True,
        legend_kwds=dict(title="Height category (m)"),
        zorder=2,
    )
    buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, zorder=2, alpha=0.5)
    train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, zorder=1)
    test_covering_h3_cells.plot(
        ax=ax,
        linewidth=0.3,
        color=(0, 0, 0, 0),
        edgecolor="black",
        hatch="//",
        zorder=1,
    )

    ax2 = ax.twinx()
    ax2.get_yaxis().set_visible(False)
    ax2.legend(
        handles=[
            Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
            Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
        ],
        labels=["Train", "Test"],
        loc=2,
    )

    ax.set_title("Vancouver buildings data - numerical split")
    cx.add_basemap(
        ax,
        source=cx.providers.CartoDB.VoyagerNoLabels,
        crs=4326,
        zoom=15,
        alpha=0.5,
    )
    ax.set_axis_off()
    plt.show()

## With categorical target column

Stratification can also be done based on an existing categorical column, without using buckets.

In that case, the `categorical` parameter must be set to `True`.
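For a categorical target, a quick way to check how well the stratification worked is to compare each category's share in the two splits. The hypothetical helper below uses only the standard library and assumes both inputs are non-empty; it is a validation sketch, not part of SRAI.

```python
from collections import Counter


def max_category_share_gap(train_labels, test_labels):
    """Largest absolute difference in a category's share between two splits."""
    train_counts, test_counts = Counter(train_labels), Counter(test_labels)
    n_train, n_test = sum(train_counts.values()), sum(test_counts.values())
    categories = set(train_counts) | set(test_counts)
    return max(
        abs(train_counts[c] / n_train - test_counts[c] / n_test) for c in categories
    )


print(max_category_share_gap(["a", "a", "a", "b"], ["a", "b"]))
```

After the split below, it could be called as `max_category_share_gap(train_categorical_gdf["subtype"], test_categorical_gdf["subtype"])`.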

In [None]:
buildings["subtype"].value_counts()

In [None]:
train_categorical_gdf, test_categorical_gdf = train_test_spatial_split(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column="subtype",
    categorical=True,
    test_size=0.2,
    random_state=42,
)

In [None]:
train_categories_stats = train_categorical_gdf["subtype"].value_counts().reset_index()
train_categories_stats["count"] /= train_categories_stats["count"].sum()
train_categories_stats["split"] = "train"

test_categories_stats = test_categorical_gdf["subtype"].value_counts().reset_index()
test_categories_stats["count"] /= test_categories_stats["count"].sum()
test_categories_stats["split"] = "test"

In [None]:
sns.barplot(
    data=gpd.pd.concat([train_categories_stats, test_categories_stats]),
    x="count",
    y="subtype",
    hue="split",
)
plt.show()

In [None]:
train_covering_h3_cells = h3_to_geoseries(
    shapely_geometry_to_h3(train_categorical_gdf["centroid"], H3_RESOLUTION)
)
test_covering_h3_cells = h3_to_geoseries(
    shapely_geometry_to_h3(test_categorical_gdf["centroid"], H3_RESOLUTION)
)

In [None]:
with plt.rc_context({"hatch.linewidth": 0.3}):
    ax = buildings.plot(
        "subtype",
        categories=buildings["subtype"].value_counts().index,
        figsize=(15, 12),
        cmap="Set3",
        legend=True,
        legend_kwds=dict(title="Building subtype"),
        zorder=2,
    )
    buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, zorder=2, alpha=0.5)
    train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, zorder=1)
    test_covering_h3_cells.plot(
        ax=ax,
        linewidth=0.3,
        color=(0, 0, 0, 0),
        edgecolor="black",
        hatch="//",
        zorder=1,
    )

    ax2 = ax.twinx()
    ax2.get_yaxis().set_visible(False)
    ax2.legend(
        handles=[
            Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
            Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
        ],
        labels=["Train", "Test"],
        loc=2,
    )

    ax.set_title("Vancouver buildings data - categorical split")
    cx.add_basemap(
        ax,
        source=cx.providers.CartoDB.VoyagerNoLabels,
        crs=4326,
        zoom=15,
        alpha=0.5,
    )
    ax.set_axis_off()
    plt.show()

## Splitting into three datasets at once

Using another function, `spatial_split_points`, you can split the dataset into three groups at once (train, validation, test).

Usually, users split the data into train and test sets and then run the splitting again to get a validation set, but `SRAI` exposes a function that splits directly into three sets. It returns a dictionary with the split data, keyed by split name (`train`, `validation`, `test`).

In [None]:
splits = spatial_split_points(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column=None,
    # Size can also be passed as an expected number of points, not only a fraction
    test_size=1000,
    validation_size=500,
    random_state=42,
)
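The comment in the cell above notes that sizes can be passed either as a fraction or as an expected number of points. A hypothetical helper that turns both forms into expected ratios could look like the sketch below; treating `int` values as point counts and `float` values as fractions mirrors the scikit-learn `train_test_split` convention and is our assumption, not a documented SRAI rule.

```python
def sizes_to_ratios(n_points, test_size, validation_size=0):
    """Turn split sizes (fractions or point counts) into expected ratios per split."""

    def to_ratio(size):
        # Assumption: ints are absolute point counts, floats are fractions of all points
        if isinstance(size, int):
            return size / n_points
        return float(size)

    test_ratio = to_ratio(test_size)
    validation_ratio = to_ratio(validation_size)
    train_ratio = 1.0 - test_ratio - validation_ratio
    if train_ratio <= 0:
        raise ValueError("test and validation sizes leave no points for training")
    return {"train": train_ratio, "validation": validation_ratio, "test": test_ratio}


# e.g. 1000 test and 500 validation points out of a hypothetical 10,000 buildings
print(sizes_to_ratios(10_000, test_size=1000, validation_size=500))
```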

In [None]:
print(splits.keys())

In [None]:
ax = splits["train"].plot(color="#1E88E5", figsize=(15, 12), zorder=2)
splits["test"].plot(color="#FFC107", ax=ax, zorder=2)
splits["validation"].plot(color="#D81B60", ax=ax, zorder=2)
buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5, zorder=2)
covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, alpha=0.5, zorder=1)

ax.set_title("Vancouver buildings data - count split into three sets")
ax.legend(
    handles=[
        Patch(facecolor="#1E88E5"),
        Patch(facecolor="#FFC107"),
        Patch(facecolor="#D81B60"),
    ],
    labels=["Train", "Test", "Validation"],
)
cx.add_basemap(
    ax,
    source=cx.providers.CartoDB.VoyagerNoLabels,
    crs=4326,
    zoom=15,
    alpha=0.8,
)
ax.set_axis_off()
plt.show()

### Parse split report manually

You can use the `return_split_stats` parameter to get the split report as a pandas DataFrame and validate the split ratios manually.

You can also set `verbose=False` to suppress the printed report.

In [None]:
splits, split_report = spatial_split_points(
    input_gdf=buildings,
    parent_h3_resolution=H3_RESOLUTION,
    geometry_column="centroid",
    target_column=None,
    # Can also be passed as an expected number of points, not only a fraction
    test_size=1000,
    validation_size=500,
    random_state=42,
    return_split_stats=True,
    verbose=False,
)
split_report

In [None]:
split_report[["train_ratio", "validation_ratio", "test_ratio"]].mean()

In [None]:
split_report[
    [
        "train_ratio_difference",
        "validation_ratio_difference",
        "test_ratio_difference",
    ]
].mean()

In [None]:
split_report[["train_points", "validation_points", "test_points"]].sum()
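Building on the report columns used above, a hypothetical check can flag any bucket whose actual ratio drifts too far from the target. It assumes the report is a pandas DataFrame with one row per bucket and `*_ratio_difference` columns, as in the cells above; the tolerance of 0.05 is an arbitrary example value and the toy report is made up.

```python
import pandas as pd


def buckets_exceeding_tolerance(split_report, tolerance=0.05):
    """Return the report rows where any split's ratio difference exceeds the tolerance."""
    diff_columns = [c for c in split_report.columns if c.endswith("_ratio_difference")]
    # abs() so that both over- and under-represented buckets are flagged
    mask = split_report[diff_columns].abs().gt(tolerance).any(axis=1)
    return split_report[mask]


toy_report = pd.DataFrame(
    {
        "train_ratio_difference": [0.01, -0.08],
        "test_ratio_difference": [-0.01, 0.08],
    }
)
print(buckets_exceeding_tolerance(toy_report))
```

On the real data, this would be `buckets_exceeding_tolerance(split_report)`.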

## Different H3 resolutions

You can perform splitting at different H3 resolutions, and the choice of resolution will affect the results.

- Higher resolutions (smaller hexagons) typically produce a split ratio closer to your target, but cells assigned to different splits lie physically closer together, which reduces true spatial separation.
- Lower resolutions (larger hexagons) improve spatial separation but may cause the actual split ratio to deviate more from the target.
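This trade-off can be turned into a simple search: try resolutions from coarse to fine and keep the first one whose ratio difference is acceptable. In this sketch, `ratio_difference_for` is a hypothetical callable you would implement by running the split at a given resolution and measuring the deviation (for example, the difference between the expected and actual test ratio, as computed in the plots below); the numbers in the example are invented for illustration.

```python
from typing import Callable, Iterable, Optional


def lowest_acceptable_resolution(
    ratio_difference_for: Callable[[int], float],
    resolutions: Iterable[int] = range(6, 11),
    tolerance: float = 0.02,
) -> Optional[int]:
    """Return the coarsest H3 resolution whose absolute ratio difference is within tolerance."""
    for resolution in sorted(resolutions):  # lower number = larger hexagons
        if abs(ratio_difference_for(resolution)) <= tolerance:
            return resolution
    return None  # no resolution in the range is accurate enough


# Made-up differences that shrink as the resolution increases
example_differences = {7: 0.09, 8: 0.04, 9: 0.015, 10: 0.004}
print(lowest_acceptable_resolution(lambda r: example_differences[r], example_differences))
```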

<div class="admonition tip">
    <p class="admonition-title">Selecting proper H3 resolution</p>
    <p>
    As a rule of thumb, choose the lowest resolution that still keeps the split ratio difference within an acceptable range for your use case.
    </p>
</div>

In [None]:
def split_per_resolution(resolution: int, ax: Axes, h3_edge_alpha: float) -> None:
    """Split the data using given resolution."""
    test_ratio = 0.4
    _train_gdf, _test_gdf = train_test_spatial_split(
        input_gdf=buildings,
        parent_h3_resolution=resolution,
        geometry_column="centroid",
        target_column="height",
        test_size=test_ratio,
        random_state=42,
        verbose=False,
    )
    buildings.exterior.plot(ax=ax, color="black", linewidth=0.3, alpha=0.5, zorder=2)

    actual_test_ratio = len(_test_gdf) / len(buildings)
    test_ratio_diff = test_ratio - actual_test_ratio

    _train_covering_h3_cells = h3_to_geoseries(
        shapely_geometry_to_h3(_train_gdf["centroid"], resolution)
    )
    _test_covering_h3_cells = h3_to_geoseries(
        shapely_geometry_to_h3(_test_gdf["centroid"], resolution)
    )

    _train_gdf.plot(color="#1E88E5", ax=ax, zorder=2)
    _test_gdf.plot(color="#FFC107", ax=ax, zorder=2)
    _train_covering_h3_cells.boundary.plot(
        color="black", ax=ax, linewidth=0.3, alpha=h3_edge_alpha, zorder=1
    )
    _test_covering_h3_cells.boundary.plot(
        color="black", ax=ax, linewidth=0.3, alpha=h3_edge_alpha, zorder=1
    )

    ax.legend(
        handles=[Patch(facecolor="#1E88E5"), Patch(facecolor="#FFC107")],
        labels=["Train", "Test"],
    )

    ax.set_title(
        f"Vancouver buildings data - numerical split (H3 resolution: {resolution})\n"
        f"Expected test ratio: {test_ratio:.2f}, "
        f"Actual test ratio: {actual_test_ratio:.2f}, "
        f"Diff: {test_ratio_diff:.3f}"
    )
    cx.add_basemap(
        ax,
        source=cx.providers.CartoDB.VoyagerNoLabels,
        crs=4326,
        zoom=15,
        alpha=0.8,
    )
    ax.set_axis_off()


with plt.rc_context({"hatch.linewidth": 0.3}):
    fig, axes = plt.subplots(2, 2, figsize=(20, 18), sharex=True, sharey=True)
    pairs = [
        (7, axes[0][0], 1.0),
        (8, axes[0][1], 0.9),
        (9, axes[1][0], 0.8),
        (10, axes[1][1], 0.7),
    ]
    for h3_res, ax, h3_edge_alpha in pairs:
        split_per_resolution(h3_res, ax, h3_edge_alpha)

buildings_bounds = buildings.total_bounds
for ax in axes.flatten():
    ax.set_xlim(buildings_bounds[0] - 0.001, buildings_bounds[2] + 0.001)
    ax.set_ylim(buildings_bounds[1] - 0.001, buildings_bounds[3] + 0.001)

plt.tight_layout()
plt.show()

## What to do with timeseries data?

When working with geospatial datasets that include a time component (for example, store locations with monthly performance data over the past year), it’s important to consider how the split is performed.

If you split purely at the row level, the same store might appear in both training and test sets for different months. This creates data leakage: the model could learn store-specific patterns from the training set and then see almost the same data in the test set, inflating performance metrics.

A better approach is to **split at the entity level**. For stores, that means assigning each store to a single split (train or test) and including all its historical monthly records in that split. This ensures that the model is evaluated on entirely unseen stores, which is especially important when the goal is to build a whitespot model for identifying promising new locations.
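In code, entity-level splitting boils down to deciding the split per entity ID and then selecting every row that belongs to those IDs. The sketch below uses pandas with a toy table and a hypothetical helper name; in practice, the ID assignment itself would come from a spatial split such as `train_test_spatial_split`.

```python
import pandas as pd


def rows_for_entities(records, train_ids, test_ids):
    """Select all records of train and test entities; fail if an entity is in both."""
    overlap = set(train_ids) & set(test_ids)
    if overlap:
        raise ValueError(f"Entities assigned to both splits: {sorted(overlap)}")
    train_rows = records[records.index.isin(train_ids)]
    test_rows = records[records.index.isin(test_ids)]
    return train_rows, test_rows


# Toy monthly records indexed by store ID
records = pd.DataFrame(
    {"month": ["2025-01", "2025-02"] * 3, "sales": [10, 11, 20, 21, 30, 31]},
    index=pd.Index(["s1", "s1", "s2", "s2", "s3", "s3"], name="id"),
)
train_rows, test_rows = rows_for_entities(records, ["s1", "s2"], ["s3"])
print(len(train_rows), len(test_rows))
```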

<div class="admonition tip">
    <p class="admonition-title">Utilizing temporal component</p>
    <p>
    If your dataset is big enough (data from multiple years), you can combine spatial splitting with temporal splitting to test how the model generalizes to both unseen stores and future time periods.
    </p>
</div>

---
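Before the full example, here is a minimal sketch of the spatial plus temporal combination mentioned in the tip above: train on the earlier months of the training stores and evaluate on the later months of the test stores. It assumes a table indexed by store ID with a `month` column in `YYYY-MM` format (the format used by the dummy data generated below) and store IDs already assigned to train and test by a spatial split; the helper name, the cutoff month and the toy records are illustrative.

```python
import pandas as pd


def spatio_temporal_split(records, train_ids, test_ids, cutoff_month):
    """Train: train stores before the cutoff. Test: test stores from the cutoff onwards."""
    # 'YYYY-MM' strings sort chronologically, so plain string comparison works
    before = records["month"] < cutoff_month
    in_train = records.index.isin(train_ids)
    in_test = records.index.isin(test_ids)
    return records[before & in_train], records[~before & in_test]


toy_records = pd.DataFrame(
    {"month": ["2025-01", "2025-02"] * 3, "sales": [10, 11, 20, 21, 30, 31]},
    index=pd.Index(["s1", "s1", "s2", "s2", "s3", "s3"], name="id"),
)
train_part, test_part = spatio_temporal_split(toy_records, ["s1", "s2"], ["s3"], "2025-02")
print(train_part, test_part, sep="\n\n")
```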

The example below shows how to use monthly transaction data to split store locations for a whitespot analysis.

First, let's select only the commercial buildings from the dataset; they will serve as our stores.

In [None]:
stores = buildings[buildings["subtype"] == "commercial"].copy()
stores

In [None]:
ax = stores.plot(figsize=(15, 15))
cx.add_basemap(
    ax,
    source=cx.providers.CartoDB.VoyagerNoLabels,
    crs=4326,
    zoom=15,
    alpha=0.8,
)
ax.set_axis_off()
ax.set_title("Vancouver - commercial buildings")
plt.show()

Next, we generate dummy monthly sales data for the last 12 months.

In [None]:
import numpy as np
import pandas as pd


def generate_monthly_sales(store_ids: pd.Index, seed=None):
    """
    Generate dummy monthly sales data for the past 12 months.

    Args:
        store_ids (pd.Index): IDs of locations.
        seed (int, optional): Random seed for reproducibility.

    Returns:
        pd.DataFrame: Indexed by location ID ('id'), with columns ['month', 'sales'].
    """
    rng = np.random.default_rng(seed=seed)

    # Generate month labels (last 12 months, newest last); "ME" = month end (pandas >= 2.2)
    months = pd.date_range(end=pd.Timestamp.today(), periods=12, freq="ME")
    month_labels = months.strftime("%Y-%m").tolist()

    # Seasonal multiplier with peak at December (sinusoidal pattern)
    phases = 2 * np.pi * (months.month - 12) / 12.0
    seasonal_factor = 1.0 + 0.2 * np.cos(phases)

    data = []

    for loc_id in store_ids:
        # Start with a base sales value for this location
        base_sales = rng.integers(8000, 20000)

        # Create gradual monthly changes using a small random walk
        gradual_changes = np.cumsum(rng.normal(loc=0, scale=300, size=12))

        # Combine base + changes + seasonality
        sales = (base_sales + gradual_changes) * seasonal_factor

        # Ensure sales are positive
        sales = np.clip(sales, 0, None)

        # Append to dataset
        for month_str, value in zip(month_labels, sales):
            data.append((loc_id, month_str, round(value, 2)))

    df = pd.DataFrame(data, columns=["id", "month", "sales"]).set_index("id")
    return df


df_sales = generate_monthly_sales(store_ids=stores.index, seed=42)
df_sales

In [None]:
fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(df_sales, x="month", y="sales", hue="id", legend=False, alpha=0.4, ax=ax)
ax.set_title("Monthly sales data per store")
ax.set_ylabel("Sales")
ax.set_xlabel("Month")
plt.tight_layout()
plt.show()

Now that we have store locations and a DataFrame with monthly sales per location, we calculate the mean monthly sales per store and use it to stratify the spatial split.

In [None]:
mean_monthly_sales = df_sales.groupby("id")["sales"].mean()  # You can also use median
sns.histplot(mean_monthly_sales, kde=True)
plt.show()

Next, we assign the mean values to the `stores` GeoDataFrame; pandas aligns them by store ID.

In [None]:
stores["mean_monthly_sales"] = mean_monthly_sales
stores

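Let's do the split based on the mean monthly sales. Because the commercial subset is much smaller than the full dataset, we reduce the number of bins to 5 so that each bin contains more stores, which makes it easier to keep the actual ratios close to the expected ones.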
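Why do fewer bins help? With fewer stores per bucket, there are fewer ways to hit the expected ratio. The back-of-the-envelope sketch below assumes each point could be assigned individually; since the real algorithm moves whole H3 cells together, the value it computes is only a lower bound on the per-bucket error. The helper name is hypothetical.

```python
def best_bucket_ratio_error(n_points, test_ratio=0.2):
    """Lower bound on |actual - expected| test ratio for a bucket with n_points points."""
    # The best possible outcome puts the nearest whole number of points into the test set
    best_test_count = round(test_ratio * n_points)
    return abs(best_test_count / n_points - test_ratio)


for n in (3, 7, 10, 50):
    print(f"{n:>3} points per bucket -> error >= {best_bucket_ratio_error(n):.3f}")
```

For example, a bucket of 3 points can at best reach a test ratio of 1/3, about 13 percentage points off the 0.2 target, while buckets of 5 or 10 points can hit 0.2 exactly.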

In [None]:
train_sales_gdf, test_sales_gdf = train_test_spatial_split(
    input_gdf=stores,
    parent_h3_resolution=8,
    geometry_column="centroid",
    target_column="mean_monthly_sales",
    n_bins=5,
    test_size=0.2,
    random_state=42,
)

Here is the distribution of mean monthly sales in both sets.

In [None]:
ax = sns.kdeplot(
    data=train_sales_gdf,
    x="mean_monthly_sales",
    fill=True,
    label="train",
)
sns.kdeplot(
    data=test_sales_gdf,
    x="mean_monthly_sales",
    fill=True,
    label="test",
    ax=ax,
)
ax.legend()
ax.set_title("Mean monthly sales distribution per split")
plt.show()

In [None]:
train_covering_h3_cells = h3_to_geoseries(shapely_geometry_to_h3(train_sales_gdf["centroid"], 8))
test_covering_h3_cells = h3_to_geoseries(shapely_geometry_to_h3(test_sales_gdf["centroid"], 8))

In [None]:
with plt.rc_context({"hatch.linewidth": 0.3}):
    ax = stores.plot(
        "mean_monthly_sales",
        figsize=(15, 10),
        cmap="RdYlBu_r",
        legend=True,
        legend_kwds=dict(
            shrink=0.9,
            orientation="horizontal",
            pad=0.01,
            label="Mean monthly sales",
            aspect=60,
            fraction=0.03,
        ),
        zorder=2,
    )
    stores.exterior.plot(ax=ax, color="black", linewidth=0.3, zorder=2)
    train_covering_h3_cells.boundary.plot(color="black", ax=ax, linewidth=0.3, zorder=1)
    test_covering_h3_cells.plot(
        ax=ax,
        linewidth=0.3,
        color=(0, 0, 0, 0),
        edgecolor="black",
        hatch="/",
        zorder=1,
    )

    ax2 = ax.twinx()
    ax2.get_yaxis().set_visible(False)
    ax2.legend(
        handles=[
            Patch(edgecolor="black", facecolor=(0, 0, 0, 0)),
            Patch(edgecolor="black", facecolor=(0, 0, 0, 0), hatch="///"),
        ],
        labels=["Train", "Test"],
        loc=2,
    )

    ax.set_title("Vancouver - commercial buildings - sales numerical split")
    cx.add_basemap(
        ax,
        source=cx.providers.CartoDB.VoyagerNoLabels,
        crs=4326,
        zoom=15,
        alpha=0.8,
    )
    ax.set_axis_off()
    stores_bounds = stores.total_bounds
    ax.set_xlim(stores_bounds[0] - 0.01, stores_bounds[2] + 0.01)
    ax.set_ylim(stores_bounds[1] - 0.01, stores_bounds[3] + 0.01)
    plt.show()

Finally, we can select the monthly sales records for each split based on the store IDs, so that each store's full history ends up in exactly one split.

In [None]:
train_store_sales = df_sales.loc[train_sales_gdf.index]
test_store_sales = df_sales.loc[test_sales_gdf.index]

print(len(train_store_sales), len(test_store_sales))
train_store_sales

In [None]:
fig, ax = plt.subplots(figsize=(12, 6))

sns.lineplot(train_store_sales, x="month", y="sales", legend=True, ax=ax, label="train")
sns.lineplot(test_store_sales, x="month", y="sales", legend=True, ax=ax, label="test")

ax.set_title("Mean monthly sales per split")
ax.set_ylabel("Sales")
ax.set_xlabel("Month")

plt.tight_layout()
plt.show()

In [None]:
ax = sns.kdeplot(
    data=train_store_sales,
    x="sales",
    fill=True,
    label="train",
)
sns.kdeplot(
    data=test_store_sales,
    x="sales",
    fill=True,
    label="test",
    ax=ax,
)
ax.legend()
ax.set_title("Sales distribution per split")
plt.show()