In [None]:
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
from pytorch_lightning import seed_everything
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import h3
from shapely.ops import transform
from functools import partial
import pyproj

from srai.embedders import Highway2VecEmbedder, Hex2VecEmbedder, GTFS2VecEmbedder, GeoVexEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import OSMNetworkType, OSMWayLoader, OSMOnlineLoader, OSMPbfLoader, GTFSLoader
from srai.loaders.osm_loaders.filters import GEOFABRIK_LAYERS, HEX2VEC_FILTER
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting import plot_regions, plot_numeric_data
from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf
from srai.h3 import ring_buffer_h3_regions_gdf

from pathlib import Path
from tqdm import tqdm
import torch
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader


In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Attribute columns that are not needed downstream; the original author notes
# that leaving them in gives a value error when doing hex2vec.
_IRRELEVANT_REGION_COLUMNS = ['IntersectionArea', 'lbm', 'afw', 'fys', 'onv', 'soc', 'vrz', 'won']


def _read_regions(geojson_name):
    """Read a regions GeoJSON, index it by `region_id` and drop the unused columns."""
    return (
        gpd.read_file(geojson_name)
        .set_index("region_id")
        .drop(columns=_IRRELEVANT_REGION_COLUMNS)
    )


# Plain and buffered variants of the selected regions, both indexed by region_id.
selected_regions_gdf = _read_regions("selected_regions_10.geojson")
selected_regions_buffered_gdf = _read_regions("selected_regions_buffered_10.geojson")

# Embeddings GTFS

In [None]:
# Location of the GTFS feed archive.
# NOTE(review): this is an absolute path on a local drive; adjust it when
# running the notebook on another machine.
gtfs_zip_path = "D:\\tu delft\\Afstuderen\\gtfs_nl.zip"

# Load the feed into a GeoDataFrame of features.
gtfs_loader = GTFSLoader()
features_gdf = gtfs_loader.load(gtfs_zip_path)

In [None]:
def _missing_to_empty_set(value):
    """Return an empty set for a missing value; pass any other value through."""
    return set() if pd.isna(value) else value


# Step 1: in every object-dtype column (expected to contain sets), swap NaN for an empty set.
for col_name in features_gdf.columns:
    if features_gdf[col_name].dtype == 'object':
        features_gdf[col_name] = features_gdf[col_name].apply(_missing_to_empty_set)

# Step 2: float columns become ints, with missing values counted as 0.
# This pass deliberately runs after step 1 has finished for all columns.
for col_name in features_gdf.columns:
    if features_gdf[col_name].dtype == float:
        filled_column = features_gdf[col_name].fillna(0)
        features_gdf[col_name] = filled_column.astype(int)

# Step 3: restore the index name (merging the features messed it up).
features_gdf.index.name = "feature_id"

# Step 4: with the data in the expected format, match features to the buffered regions.
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(selected_regions_buffered_gdf, features_gdf)

# Step 5: seed the RNGs, then fit the GTFS embedder and embed the regions in one call.
seed_everything(42)
embedder = GTFS2VecEmbedder()
embeddings_GTFS = embedder.fit_transform(selected_regions_buffered_gdf, features_gdf, joint_gdf)

# Embeddings RN (road network)

In [None]:
# Fix the random state up front for reproducibility.
seed_everything(42)

# Geocode the study area and load its OSM ways using the DRIVE network type.
area_southholland_gdf = geocode_to_region_gdf("South Holland, Netherlands")
loader = OSMWayLoader(OSMNetworkType.DRIVE)
nodes_gdf, edges_gdf = loader.load(area_southholland_gdf)

# Quick visual check: edges as lines with the nodes drawn on top in red.
fig, ax = plt.subplots(figsize=(12, 7))
edges_gdf.plot(ax=ax, linewidth=1)
nodes_gdf.plot(ax=ax, markersize=3, color="red")

# Put edges and buffered regions in the same CRS (EPSG:4326).
edges_gdf = edges_gdf.to_crs(epsg=4326)
selected_regions_buffered_gdf = selected_regions_buffered_gdf.to_crs(epsg=4326)

In [None]:
# Join regions and edges.
# A fresh IntersectionJoiner is created here so that this cell does not rely on
# the `joiner` variable left in the kernel by the GTFS embedding step.
# NOTE: `joint_gdf` is rebound here; from this point on it holds the
# region/road-edge join, not the region/GTFS-feature join.
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(selected_regions_buffered_gdf, edges_gdf)

In [None]:
# Embed the road network.
# Re-seed immediately before training (as the GTFS section does) so that
# re-running this cell on its own reproduces the same embeddings instead of
# depending on how much random state was consumed since the last seeding.
seed_everything(42)

# Fit on the region/edge join and embed the same regions in a single call,
# matching how the GTFS embedder is used.
embedder = Highway2VecEmbedder()
embeddings_roadnetwork = embedder.fit_transform(selected_regions_buffered_gdf, edges_gdf, joint_gdf)

In [None]:
# Visualization and Export
# Plot helpers come from the project-local `Plotting` module.
from Plotting import cluster_plot, pca_plot

# Install the blanket "ignore" warning filter (again) before plotting/export.
warnings.filterwarnings(action='ignore')
#pca_plot(embeddings_GTFS, selected_regions_buffered_gdf)

In [None]:
#pca_plot(embeddings_roadnetwork, selected_regions_buffered_gdf)

In [None]:
#pca_plot(embeddings_GTFS, selected_regions_gdf)

In [None]:
#cluster_plot(embeddings_roadnetwork, selected_regions_gdf, 6)

In [None]:
# Write each embedding table to its own CSV file in the working directory
# (road network first, then GTFS).
for csv_name, embedding_table in (
    ("embeddings_roadnetwork_10.csv", embeddings_roadnetwork),
    ("embeddings_GTFS_10.csv", embeddings_GTFS),
):
    embedding_table.to_csv(csv_name)