In [None]:
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import OSMOnlineLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.regionizers import H3Regionizer
from srai.utils import geocode_to_region_gdf
from srai.plotting.folium_wrapper import plot_regions
from pytorch_lightning import seed_everything

In [None]:
# Single global seed for the whole notebook; Lightning's `seed_everything`
# applies it to the global random number generators used further down.
SEED = 71
seed_everything(seed=SEED)

## Load data from OSM

First, use geocoding to get the boundary of the area of interest.

In [None]:
# Geocode the city name into a GeoDataFrame with its boundary, then preview it on a map.
area_gdf = geocode_to_region_gdf(
    "Wrocław, Poland"
)
plot_regions(area_gdf)

Next, download the data for the selected region and the specified tags. We're using `OSMOnlineLoader` here because it's faster for a small number of tags. In a real-life scenario with more tags, you would likely want to use `OSMPbfLoader` instead.

In [None]:
# OSM tags to download: each key maps to a single accepted value or a list of values.
tags = dict(
    leisure="park",
    landuse="forest",
    amenity=["bar", "restaurant", "cafe"],
    water="river",
    sport="soccer",
)
loader = OSMOnlineLoader()

# Fetch the features matching `tags` for the geocoded area and show them on a map.
features_gdf = loader.load(area_gdf, tags)
features_gdf.explore()

## Prepare the data for embedding

After downloading the data, we need to prepare it for embedding: we regionize the selected area (split it into H3 cells) and join the features with those regions.

In [None]:
# Tile the area with H3 hexagons (resolution 9) and preview the resulting grid.
regionizer = H3Regionizer(resolution=9)
regions_gdf = regionizer.transform(area_gdf)

plot_regions(regions_gdf)

In [None]:
# Spatially join regions with the downloaded features; presumably one row per
# intersecting (region, feature) pair — see srai's IntersectionJoiner docs.
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)

joint_gdf

## Embedding

After preparing the data, we can proceed with generating embeddings for the regions.

In [None]:
import warnings

# Neighbourhood relation between the H3 regions, passed to Hex2Vec for training.
neighbourhood = H3Neighbourhood(regions_gdf)
# Encoder layer sizes; the last entry is presumably the embedding size — confirm in srai docs.
embedder = Hex2VecEmbedder([15, 10])

# Short, CPU-only, deterministic training run (reproducible together with SEED).
# Every warning raised while fitting is silenced to keep the notebook output readable.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    embeddings = embedder.fit_transform(
        regions_gdf,
        features_gdf,
        joint_gdf,
        neighbourhood,
        batch_size=100,
        trainer_kwargs=dict(max_epochs=5, accelerator="cpu", deterministic=True),
    )

embeddings

### Visualizing the embeddings' similarity

In [None]:
from sklearn.cluster import KMeans

# Group regions with similar embeddings into 5 clusters.
# `random_state` pins the centroid initialisation. Without it, KMeans draws from the
# global NumPy RNG, so the result would depend on how much randomness earlier cells
# (e.g. training) consumed and would change every time this cell is re-run.
clusterizer = KMeans(n_clusters=5, random_state=SEED)
clusterizer.fit(embeddings)
clusterizer.labels_

In [None]:
# Attach each region's cluster label aligned on the index shared by `embeddings`
# and `regions_gdf` (assumed to be the region id — the embedder was fit on
# `regions_gdf`), instead of by row position. Assigning the raw `labels_` array
# would silently mislabel regions if the two frames were ordered differently.
# Regions without an embedding would get NaN rather than a wrong label.
cluster_labels = embeddings.assign(cluster=clusterizer.labels_)["cluster"]
regions_gdf["cluster"] = cluster_labels
# Cluster ids are nominal (no order), so use a categorical colour map.
regions_gdf.explore("cluster", categorical=True)