In [None]:
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import OSMOnlineLoader
from srai.neighbourhoods import H3Neighbourhood
from srai.regionalizers import H3Regionalizer
from srai.utils import geocode_to_region_gdf
from srai.plotting import plot_regions, plot_numeric_data
from pytorch_lightning import seed_everything

In [None]:
# Fixed seed handed to seed_everything so that repeated runs of the notebook
# start from the same random state.
SEED = 71
seed_everything(SEED)

## Load data from OSM

First, use geocoding to get the area.

In [None]:
# Geocode the place name into a GeoDataFrame describing the study area,
# then draw it using the "CartoDB positron" tile style.
area_gdf = geocode_to_region_gdf("Wrocław, Poland")
plot_regions(area_gdf, tiles_style="CartoDB positron")

Next, download the data for the selected region and the specified tags. We're using `OSMOnlineLoader` here, as it's faster for a small number of tags. In a real-life scenario with more tags, you would likely want to use `OSMPbfLoader` instead.

In [None]:
# OSM tag filter: every key is an OSM tag, every value is the accepted value
# (or a list of accepted values) for that tag.
tags = dict(
    leisure="park",
    landuse="forest",
    amenity=["bar", "restaurant", "cafe"],
    water="river",
    sport="soccer",
)
loader = OSMOnlineLoader()

features_gdf = loader.load(area_gdf, tags)

# Draw the area outline with a fully transparent fill first, then overlay the
# downloaded features on that same map.
folium_map = plot_regions(
    area_gdf,
    tiles_style="CartoDB positron",
    colormap=["rgba(0,0,0,0)"],
)
features_gdf.explore(m=folium_map)

## Prepare the data for embedding

After downloading the data, we need to prepare it for embedding: regionalize the selected area and join the features with the regions.

In [None]:
# Split the study area into H3 cells at resolution 9 and plot the resulting
# regions on the same tile style used above.
regionalizer = H3Regionalizer(resolution=9)
regions_gdf = regionalizer.transform(area_gdf)
plot_regions(regions_gdf, tiles_style="CartoDB positron")

In [None]:
# Pair regions with the downloaded OSM features. Judging by the joiner's name,
# a (region, feature) pair is kept when the geometries intersect - confirm
# against the srai documentation.
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(regions_gdf, features_gdf)
joint_gdf

## Embedding

After preparing the data, we can proceed with generating embeddings for the regions.

In [None]:
import warnings

neighbourhood = H3Neighbourhood(regions_gdf)
# [15, 10] is passed straight to the embedder; presumably the encoder layer
# sizes - confirm against the Hex2VecEmbedder documentation.
embedder = Hex2VecEmbedder([15, 10])

# Warnings are silenced only for the duration of training; the previous
# warning filters are restored when the `with` block exits.
with warnings.catch_warnings():
    warnings.simplefilter(action="ignore")
    embeddings = embedder.fit_transform(
        regions_gdf,
        features_gdf,
        joint_gdf,
        neighbourhood,
        batch_size=100,
        trainer_kwargs=dict(max_epochs=5, accelerator="cpu"),
    )
embeddings

### Saving and loading fitted embedders

In [None]:
# Persist the fitted Hex2Vec embedder under the "./modello" path.
embedder.save("./modello")

In [None]:
# Restore a Hex2Vec embedder from the path written by the save cell and
# display the loaded object.
embedder_loaded = Hex2VecEmbedder.load("./modello")
embedder_loaded

In [None]:
from srai.embedders import Highway2VecEmbedder
from srai.loaders import OSMNetworkType, OSMWayLoader

d = OSMWayLoader(OSMNetworkType.DRIVE).load(area_gdf)

In [None]:
joint = joiner.transform(regions_gdf, d[1])

In [None]:
highway2vec = Highway2VecEmbedder()
highway2vec.fit(regions_gdf, d[1], joint)

In [None]:
# Persist the fitted Highway2Vec embedder under the "highway2vec" path.
highway2vec.save("highway2vec")

In [None]:
# Show the embedder's instance attributes as a dict (debug-style inspection of
# what the object currently holds).
vars(highway2vec)

In [None]:
# Load an embedder back from the "highway2vec" path; the result is only
# displayed as the cell output, not bound to a name.
Highway2VecEmbedder.load("highway2vec")

In [None]:
import geopandas as gpd
import pandas as pd
from srai.embedders import GTFS2VecEmbedder
from srai.constants import REGIONS_INDEX
from shapely.geometry import Polygon

# Toy stop-level features: three stops (ids 1-3) with trip counts for hours
# 6-8 and the set of directions served at hour 6.
# NOTE: this re-binds `features_gdf`, `regions_gdf`, `joint_gdf` and
# `embedder`, replacing the objects created earlier in the notebook.
features_gdf = gpd.GeoDataFrame(
    data={
        "trip_count_at_6": [1, 0, 0],
        "trip_count_at_7": [1, 1, 0],
        "trip_count_at_8": [0, 0, 1],
        "directions_at_6": [
            {"A", "A1"},
            {"B", "B1"},
            {"C"},
        ],
    },
    index=pd.Index([1, 2, 3], name="stop_id"),
    geometry=gpd.points_from_xy(x=[1, 2, 5], y=[1, 2, 2]),
)

# Three 3x3 square regions whose left edges sit at x = 0, 4 and 8.
regions_gdf = gpd.GeoDataFrame(
    index=pd.Index(["ff1", "ff2", "ff3"], name=REGIONS_INDEX),
    geometry=[
        Polygon([(x_min, 0), (x_min, 3), (x_min + 3, 3), (x_min + 3, 0)])
        for x_min in (0, 4, 8)
    ],
)

# Region-stop pairs: stops 1 and 2 belong to "ff1", stop 3 to "ff2";
# "ff3" has no stops. Only the index carries information here.
joint_gdf = gpd.GeoDataFrame()
joint_gdf.index = pd.MultiIndex.from_tuples(
    [("ff1", 1), ("ff1", 2), ("ff2", 3)], names=[REGIONS_INDEX, "stop_id"]
)

embedder = GTFS2VecEmbedder(hidden_size=2, embedding_size=4)
embedder.fit(regions_gdf, features_gdf, joint_gdf)
res = embedder.transform(regions_gdf, features_gdf, joint_gdf)
res

In [None]:
# Persist the fitted GTFS2Vec embedder under the "gtfs2vec" path.
embedder.save("gtfs2vec")

In [None]:
a = embedder.load("gtfs2vec")

In [None]:
a = embedder.transform(regions_gdf, features_gdf, joint_gdf)

In [None]:
a