In [None]:
import warnings

import contextily as cx
import matplotlib.pyplot as plt
import pandas as pd
import torch
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import CSVLogger
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from srai.embedders import S2VecEmbedder
from srai.embedders.s2vec.s2_utils import get_patches_from_img_gdf
from srai.loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import GEOFABRIK_LAYERS
from srai.plotting import plot_regions
from srai.regionalizers import S2Regionalizer, geocode_to_region_gdf

In [None]:
# Single seed shared by Lightning's global seeding here and by the KMeans
# clustering further down (random_state=SEED).
SEED = 71
seed_everything(seed=SEED)

### Load data from OSM

First, use geocoding to get the boundary of the area of interest.

In [None]:
# Resolve the place name to a region GeoDataFrame, then show it on a map.
area_query = "Wrocław, Poland"
area_gdf = geocode_to_region_gdf(area_query)

area_map = plot_regions(area_gdf, tiles_style="CartoDB positron")
area_map

In [None]:
# S2 resolutions reused below: `img_resolution` for the regions created here,
# `patch_resolution` for the patches derived from them later on
# (they are also passed to S2VecEmbedder as img_res / patch_res).
img_resolution, patch_resolution = 12, 16

# The regionalizer is fed the area with a fresh positional index.
area_reindexed = area_gdf.reset_index(drop=True)
img_regionalizer = S2Regionalizer(resolution=img_resolution, buffer=True)
img_s2_regions = img_regionalizer.transform(area_reindexed)

# Single geometry covering all image regions.
img_s2_geometry = img_s2_regions.union_all()

print("Image regions:", len(img_s2_regions))

### Download the Data


Next, download the data for the selected region and the specified tags.


In [None]:
# Load OSM features for every image region, filtered by the Geofabrik layer tags.
loader = OSMPbfLoader()
tags = GEOFABRIK_LAYERS

features_gdf = loader.load(
    img_s2_regions,
    tags,
)

## Prepare the data for embedding


After downloading the data, we need to prepare it for embedding. In the previous step, we regionalized the selected area and buffered it; now we have to join the features with the prepared regions.


In [None]:
# Show the S2 image regions produced by the regionalizer on a map.
img_regions_map = plot_regions(img_s2_regions, tiles_style="CartoDB positron")
img_regions_map

## S2Vec Embedding


After preparing the data, we can proceed with generating embeddings for the regions.


In [None]:
# Embedder settings gathered in one place; resolutions reuse the values that
# were used to build the image regions.
embedder_config = {
    "target_features": GEOFABRIK_LAYERS,
    "batch_size": 8,
    "img_res": img_resolution,
    "patch_res": patch_resolution,
    "embedding_dim": 64,
    "decoder_dim": 32,
}
embedder = S2VecEmbedder(**embedder_config)

In [None]:
# Train the embedder and compute embeddings with all warnings silenced for the
# duration of the block.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Metrics are written under ./s2vec_logs and read back in the next cell.
    csv_logger = CSVLogger(save_dir="s2vec_logs")

    # CPU is forced whenever the MPS backend is available; otherwise Lightning
    # picks the accelerator itself.
    # NOTE(review): presumably a workaround for MPS issues - confirm.
    accelerator = "auto"
    if torch.backends.mps.is_available():
        accelerator = "cpu"

    trainer_kwargs = {
        # "max_epochs": 20, # uncomment for a longer training
        "max_epochs": 5,
        "accelerator": accelerator,
        "logger": csv_logger,
    }

    embeddings = embedder.fit_transform(
        regions_gdf=img_s2_regions,
        features_gdf=features_gdf,
        trainer_kwargs=trainer_kwargs,
        learning_rate=0.001,
    )

embeddings.head()

In [None]:
# Read the metrics logged by the CSV logger and keep only the rows that carry
# an epoch-level training loss.
metrics_path = csv_logger.log_dir + "/metrics.csv"
metrics_df = pd.read_csv(metrics_path).dropna(subset="train_loss_epoch")

# Training loss per epoch.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
loss_line = ax.plot(metrics_df["epoch"], metrics_df["train_loss_epoch"])
ax.set_title("Training metrics")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
plt.show()

In [None]:
# Visualise the embeddings: project them to three principal components, scale
# each component to [0, 1] and use the triple as an RGB colour per patch.
patch_s2_regions, _ = get_patches_from_img_gdf(
    img_s2_regions, target_level=patch_resolution
)

# The clustering cell below adds a "cluster" column to `embeddings` in place.
# Exclude it so that re-running this cell (or running cells out of order) never
# feeds the cluster labels into the projection.
pca_input = embeddings.drop(columns="cluster", errors="ignore")

# PCA with three components, one per colour channel.
pca = PCA(n_components=3)
pca_components = pd.DataFrame(pca.fit_transform(pca_input), index=pca_input.index)

# Min-max scale every component independently.
component_min = pca_components.min()
component_span = pca_components.max() - component_min
pca_embeddings = ((pca_components - component_min) / component_span).astype(float)

# One (r, g, b) tuple per embedded region; built before the column is attached
# so only the three component columns contribute.
pca_embeddings["rgb"] = pd.Series(
    list(pca_embeddings.itertuples(index=False, name=None)),
    index=pca_embeddings.index,
)

# Colours in the row order of `patch_s2_regions`, looked up by region id.
color_values = patch_s2_regions.index.map(pca_embeddings["rgb"].to_dict()).to_list()

# Colours are applied positionally, so the frame can be plotted as-is.
ax = patch_s2_regions.plot(
    color=color_values, antialiased=True, figsize=(20, 20), alpha=0.6
)
cx.add_basemap(ax, source=cx.providers.CartoDB.PositronNoLabels, crs=4326, zoom=12)
ax.set_axis_off()
ax.set_title("PCA representation of embeddings")
plt.show()

### Clustering


In [None]:
# Cluster the embeddings with KMeans and store the label of every region in
# `embeddings["cluster"]` (used by the map in the next cell).
#
# This cell writes the labels back into `embeddings`, so on a re-run the frame
# already contains a "cluster" column. Drop it before fitting; otherwise the
# previous labels would be used as an extra feature and the result would
# change from run to run.
cluster_input = embeddings.drop(columns="cluster", errors="ignore")

clusterizer = KMeans(n_clusters=5, random_state=SEED)
clusterizer.fit(cluster_input)

embeddings.index.name = "region_id"
embeddings["cluster"] = clusterizer.labels_
embeddings["cluster"]

In [None]:
# Draw every patch coloured by its cluster label on top of a basemap.
cluster_plot_style = {
    "antialiased": True,
    "figsize": (20, 20),
    "alpha": 0.6,
    "legend": True,
    "cmap": "Spectral",
    "categorical": True,
}
ax = patch_s2_regions.plot(embeddings["cluster"], **cluster_plot_style)

basemap_source = cx.providers.CartoDB.PositronNoLabels
cx.add_basemap(ax, source=basemap_source, crs=4326, zoom=12)

ax.set_title("Clustering result")
ax.set_axis_off()
plt.show()