# Spatial splitting with stratification

The SRAI library contains a dedicated function for splitting a point dataset into train/test (and optionally validation) splits. It separates the points spatially while keeping the splits stratified by a given target.

The function only works with point datasets. It uses the [`H3`](https://h3geo.org/) indexing system to group nearby points into cells and assigns whole H3 cells to different splits.
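The core idea of cell-based splitting can be sketched in plain Python. This is not SRAI's implementation. As a stand-in for H3 it uses a hypothetical square lon/lat grid (`cell_id`, `cell_size` in degrees), and `split_by_cells` assigns whole cells to the test split at random, without any stratification. What it illustrates is that every point inherits its cell's split, so no cell contributes points to both splits.

```python
import random
from collections import defaultdict


def cell_id(lon, lat, cell_size=0.01):
    """Hypothetical stand-in for an H3 index: a square lon/lat grid cell."""
    return (int(lon // cell_size), int(lat // cell_size))


def split_by_cells(points, test_size=0.2, cell_size=0.01, seed=42):
    """Assign whole cells (never individual points) to the train or test split."""
    cells = defaultdict(list)
    for lon, lat in points:
        cells[cell_id(lon, lat, cell_size)].append((lon, lat))

    # Shuffle the cells, not the points, and take a share of them as the test split.
    cell_ids = sorted(cells)
    random.Random(seed).shuffle(cell_ids)
    n_test_cells = round(len(cell_ids) * test_size)
    test_cells = set(cell_ids[:n_test_cells])

    train = [p for c in cell_ids[n_test_cells:] for p in cells[c]]
    test = [p for c in cell_ids[:n_test_cells] for p in cells[c]]
    return train, test, test_cells


# Random points roughly inside the San Francisco bounding box used below.
rng = random.Random(0)
points = [(rng.uniform(-122.53, -122.35), rng.uniform(37.70, 37.81)) for _ in range(1000)]
train, test, test_cells = split_by_cells(points)
print(len(train), len(test), len(test_cells))
```

Because cells hold different numbers of points, `test_size` here controls the share of cells, and the share of points in the test split only approximates it. This is one reason why stratifying cells by density or by a target is useful.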

---

When working with most machine learning datasets, splitting into training and testing sets is straightforward: pick a random subset for testing, and (optionally) use stratification to keep the distribution of a target variable balanced between the two. This works fine when the data points are independent.

Geospatial data plays by different rules. Nearby locations often share similar characteristics, a phenomenon called spatial autocorrelation. If we split data randomly, our training and test sets might end up covering the same areas, meaning the model is “tested” on locations that are practically identical to ones it has already seen. This can make performance look much better than it really is.

That’s why for geo-related tasks, we need spatial splitting: making sure the training and test sets are separated in space so that evaluation reflects real-world conditions. Sometimes we also want to stratify these spatial splits by a numerical value to ensure both sets still have similar value distributions. Standard `train_test_split` functions can’t combine these two needs, so we provide a dedicated function for spatially aware splitting with optional stratification.
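To make the leakage of a random split concrete, the sketch below splits points randomly and measures how many test points fall into a cell that also contains training points. It again uses a hypothetical square grid (`cell_id`) in place of H3, and `leaked_share` is an illustrative helper, not part of SRAI. When cells hold several points each, most test points end up sharing a cell with training data.

```python
import random


def cell_id(lon, lat, cell_size=0.01):
    """Hypothetical square-grid stand-in for an H3 cell index."""
    return (int(lon // cell_size), int(lat // cell_size))


def leaked_share(train, test, cell_size=0.01):
    """Fraction of test points whose cell also contains at least one training point."""
    train_cells = {cell_id(*p, cell_size=cell_size) for p in train}
    leaked = sum(cell_id(*p, cell_size=cell_size) in train_cells for p in test)
    return leaked / len(test)


rng = random.Random(0)
points = [(rng.uniform(-122.53, -122.35), rng.uniform(37.70, 37.81)) for _ in range(1000)]

# A plain random split at the point level, ignoring location entirely.
shuffled = points[:]
rng.shuffle(shuffled)
test, train = shuffled[:200], shuffled[200:]

share = leaked_share(train, test)
print(f"{share:.0%} of test points share a cell with training points")
```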

---

This notebook shows how the different splitting modes work, using a buildings dataset from the [Overture Maps Foundation](https://overturemaps.org/).

In [None]:
import geopandas as gpd
import overturemaestro as om
import pyarrow.compute as pc
import seaborn as sns

from srai.spatial_split import train_test_spatial_split

In [None]:
buildings = om.convert_bounding_box_to_geodataframe(
    theme="buildings",
    type="building",
    bbox=(-122.531822, 37.700213, -122.353294, 37.814456),  # SF
    release="2025-07-23.0",
    pyarrow_filter=pc.field("subtype").is_valid() & pc.field("height").is_valid(),
    columns_to_download=["subtype", "height"],
)
buildings

In [None]:
SAN_FRANCISCO_PROJECTED_CRS = 7131  # NAD83(2011) / San Francisco CS13

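# Compute centroids in a projected CRS (centroids computed directly on lon/lat
# coordinates are inaccurate), then convert them back to WGS84 (EPSG:4326).
buildings_with_centroid = gpd.GeoDataFrame(
    buildings,
    geometry=buildings.to_crs(SAN_FRANCISCO_PROJECTED_CRS).centroid,
    crs=SAN_FRANCISCO_PROJECTED_CRS,
).to_crs(4326)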
buildings_with_centroid

In [None]:
buildings["subtype"].value_counts()

## Spatial splitting without a target column

A target column isn't required for spatial splitting.

In that case, the number of points in each H3 cell is counted and used for stratification, so that each split receives a similar mix of sparse and dense cells.
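Stratifying cells by point density could look roughly like this: count the points per cell, order the cells from sparsest to densest, cut that order into density groups, and draw the same share of test cells from every group. This is a sketch under assumptions, not SRAI's code: the rank-based grouping into `n_bins` equal-sized groups and the helper name `density_stratified_cells` are illustrative.

```python
import random
from collections import Counter


def density_stratified_cells(point_cells, test_size=0.2, n_bins=3, seed=42):
    """Choose test cells so that sparse and dense cells end up in both splits.

    point_cells: one cell id per point (for example an H3 index per point).
    Returns the set of cell ids assigned to the test split.
    """
    counts = Counter(point_cells)  # number of points in each cell
    # Order cells from sparsest to densest and cut that order into n_bins groups.
    ranked = sorted(counts, key=lambda cell: (counts[cell], cell))
    rng = random.Random(seed)
    test_cells = set()
    for i in range(n_bins):
        group = ranked[i * len(ranked) // n_bins : (i + 1) * len(ranked) // n_bins]
        rng.shuffle(group)
        # Take the same share of cells from every density group.
        test_cells.update(group[: round(len(group) * test_size)])
    return test_cells


# Hypothetical cells: 20 sparse cells with 1 point and 20 dense cells with 10 points.
point_cells = [f"s{i}" for i in range(20)] + [f"d{i}" for i in range(20) for _ in range(10)]
test_cells = density_stratified_cells(point_cells, test_size=0.2, n_bins=2)
print(sorted(test_cells))
```

With a purely random choice of cells, the test split could by chance contain only sparse cells; drawing per density group rules that out.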

In [None]:
train_default_gdf, test_default_gdf = train_test_spatial_split(
    input_gdf=buildings_with_centroid,
    parent_h3_resolution=9,
    target_column=None,
    test_size=0.2,
    random_state=42,
)
train_default_gdf

In [None]:
# TODO: add spatial plots

## Spatial splitting with a numerical target column

If a target column is provided, it is treated as numerical by default: its values are split into buckets (`7` by default, controlled by `n_bins`) and the split is stratified on those buckets. As a result, the value distribution should be approximately the same in both splits.
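The bucketing step can be sketched with the standard library. This sketch assumes quantile-based (equal-frequency) buckets, which is not confirmed to be how SRAI bins values, and `bucketize` is a hypothetical helper name; the heights are made-up sample values.

```python
import bisect
import statistics


def bucketize(values, n_bins=7):
    """Map each numeric value to a bucket index 0..n_bins-1 using quantile cut points."""
    # n_bins buckets need n_bins - 1 interior cut points.
    cuts = statistics.quantiles(values, n=n_bins)
    return [bisect.bisect_right(cuts, v) for v in values]


heights = [3.0, 4.5, 6.0, 8.0, 10.0, 12.0, 15.0, 20.0, 25.0, 35.0, 50.0, 80.0, 120.0, 200.0]
buckets = bucketize(heights, n_bins=7)
print(list(zip(heights, buckets)))
```

Quantile cut points keep the buckets roughly equally populated even for a long-tailed target such as building height, which is why equal-width buckets would be a weaker choice here.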

In [None]:
train_height_gdf, test_height_gdf = train_test_spatial_split(
    input_gdf=buildings_with_centroid,
    parent_h3_resolution=9,
    target_column="height",
    n_bins=7,
    test_size=0.2,
    random_state=42,
)
train_height_gdf

In [None]:
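# Compare the height distributions of both splits on a log scale.
ax = sns.kdeplot(
    data=train_height_gdf,
    x="height",
    fill=True,
    label="train",
    log_scale=True,
)
sns.kdeplot(
    data=test_height_gdf,
    x="height",
    fill=True,
    label="test",
    log_scale=True,
    ax=ax,
)
ax.legend()
ax.set_xlim(left=1)  # start the x axis at 1 m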

In [None]:
# TODO: add spatial plots

## Spatial splitting with a categorical target column

Stratification can also be based on an existing categorical column, without any bucketing.

In that case, the `categorical` parameter must be set to `True`.
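One way to check that a categorical stratification worked is to compare each category's share of points in the two splits. The sketch below uses only the standard library; `category_shares` and `max_share_gap` are hypothetical helper names, and the labels are made up for illustration.

```python
from collections import Counter


def category_shares(labels):
    """Return each category's fraction of the given labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {category: n / total for category, n in counts.items()}


def max_share_gap(train_labels, test_labels):
    """Largest absolute difference in category share between the two splits."""
    train_shares = category_shares(train_labels)
    test_shares = category_shares(test_labels)
    categories = train_shares.keys() | test_shares.keys()
    return max(abs(train_shares.get(c, 0.0) - test_shares.get(c, 0.0)) for c in categories)


train_labels = ["residential"] * 80 + ["commercial"] * 15 + ["industrial"] * 5
test_labels = ["residential"] * 20 + ["commercial"] * 4 + ["industrial"] * 1
print(max_share_gap(train_labels, test_labels))
```

A gap close to zero means both splits contain the categories in nearly the same proportions; a category missing from one split counts with a share of zero there.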

In [None]:
train_categorical_gdf, test_categorical_gdf = train_test_spatial_split(
    input_gdf=buildings_with_centroid,
    parent_h3_resolution=9,
    target_column="subtype",
    categorical=True,
    test_size=0.2,
    random_state=42,
)
train_categorical_gdf

In [None]:
# Scale counts relative to the most frequent subtype in each split,
# so that splits of different sizes can be compared.
train_categories_stats = train_categorical_gdf["subtype"].value_counts().reset_index()
train_categories_stats["count"] /= train_categories_stats["count"].max()
train_categories_stats["split"] = "train"

test_categories_stats = test_categorical_gdf["subtype"].value_counts().reset_index()
test_categories_stats["count"] /= test_categories_stats["count"].max()
test_categories_stats["split"] = "test"

In [None]:
import pandas as pd

sns.barplot(
    data=pd.concat([train_categories_stats, test_categories_stats]),
    x="count",
    y="subtype",
    hue="split",
)

In [None]:
# TODO: add spatial plots