In [None]:
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
from pytorch_lightning import seed_everything
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import h3
from shapely.ops import transform
from functools import partial
import pyproj

from srai.embedders import Highway2VecEmbedder, Hex2VecEmbedder, GTFS2VecEmbedder, GeoVexEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import OSMNetworkType, OSMWayLoader, OSMOnlineLoader, OSMPbfLoader, GTFSLoader
from srai.loaders.osm_loaders.filters import GEOFABRIK_LAYERS, HEX2VEC_FILTER
from srai.neighbourhoods import H3Neighbourhood
from srai.plotting import plot_regions, plot_numeric_data
from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf
from srai.h3 import ring_buffer_h3_regions_gdf

from pathlib import Path
from tqdm import tqdm
import torch
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader


In [None]:
import warnings
warnings.filterwarnings('ignore')

# Import regions

In [None]:
# Columns inherited from the source GeoJSON that are not needed downstream; left in,
# they gave a ValueError when running Hex2Vec. Kept in one place so both region sets
# are cleaned identically.
IRRELEVANT_COLUMNS = ['IntersectionArea', 'lbm', 'afw', 'fys', 'onv', 'soc', 'vrz', 'won']


def load_regions(path):
    """Read a region GeoJSON, index it by ``region_id`` and drop ``IRRELEVANT_COLUMNS``.

    Args:
        path: Path of the GeoJSON file to read.

    Returns:
        A new frame indexed by ``region_id`` without the irrelevant columns
        (nothing is mutated in place).

    Raises:
        KeyError: if ``region_id`` or one of ``IRRELEVANT_COLUMNS`` is missing,
        exactly as the original inline ``set_index``/``drop`` calls did.
    """
    return gpd.read_file(path).set_index("region_id").drop(columns=IRRELEVANT_COLUMNS)


selected_regions_gdf = load_regions("selected_regions_10.geojson")
selected_regions_buffered_gdf = load_regions("selected_regions_buffered_10.geojson")

# Prepare embeddings

In [None]:
# Disabled alternative: fetch HEX2VEC_FILTER features online (OSMOnlineLoader) for the
# whole of South Holland, instead of loading GEOFABRIK_LAYERS with OSMPbfLoader for the
# buffered study regions as the next cell does.
# tags = HEX2VEC_FILTER
# loader = OSMOnlineLoader()
# area_southholland_gdf = geocode_to_region_gdf("South Holland, Netherlands")
# features_gdf = loader.load(area_southholland_gdf, tags)

In [None]:
# Load the OSM features described by the Geofabrik layer definitions for the
# buffered study regions, using the PBF-based loader.
loader = OSMPbfLoader()
tags = GEOFABRIK_LAYERS
features_gdf = loader.load(selected_regions_buffered_gdf, tags)

In [None]:
# Preview the first rows of the loaded OSM features.
features_gdf.head(5)

In [None]:
# Seed the global RNGs (pytorch_lightning.seed_everything) before the stochastic
# GeoVex training further below.
seed_everything(42)

# Join the OSM features to the buffered regions by geometric intersection.
joiner = IntersectionJoiner()
joint_gdf = joiner.transform(selected_regions_buffered_gdf, features_gdf)

In [None]:
# Disabled: Hex2Vec training. The Hex2Vec embeddings used for comparison are instead
# loaded from "embeddings_POI_hex2vec_10.csv" further down in this notebook.
# neighbourhood = H3Neighbourhood(selected_regions_buffered_gdf)
# 
# embedder = Hex2VecEmbedder()
# 
# embeddings_POI_hex2vec = embedder.fit_transform(
#     selected_regions_buffered_gdf,
#     features_gdf,
#     joint_gdf,
#     neighbourhood,
#     trainer_kwargs={"max_epochs": 20, "accelerator": "gpu"},
#     batch_size=128,
# )

In [None]:
neighbourhood = H3Neighbourhood(selected_regions_buffered_gdf)

# GeoVex configuration.
# neighbourhood_radius: step 0 (study-area preparation) buffered the regions by 15
# hexagonal rings (3x the res-9 buffer, since ~3 of these hexagons span one res-9
# hexagon). A radius of 15 is too much here; 5 is enough because neighbourhood
# aggregation happens again in a later step - this cell only builds the single-view
# embedding.
embedder = GeoVexEmbedder(
    target_features=GEOFABRIK_LAYERS,
    neighbourhood_radius=5,
    convolutional_layers=2,
    embedding_size=50,
    batch_size=128,
)

In [None]:
# Train GeoVex on the buffered regions and embed them.
with warnings.catch_warnings():
    # Scoped suppression: catch_warnings restores the previous filters on exit.
    warnings.simplefilter("ignore")

    # Kept small for a quick run; use 20 for a longer training.
    MAX_EPOCHS = 4
    # GeoVexEmbedder does not support MPS, so fall back to CPU when MPS is available;
    # otherwise let the trainer choose ("auto").
    accelerator = "cpu" if torch.backends.mps.is_available() else "auto"

    embeddings_POI_geovex = embedder.fit_transform(
        regions_gdf=selected_regions_buffered_gdf,
        features_gdf=features_gdf,
        joint_gdf=joint_gdf,
        neighbourhood=neighbourhood,
        trainer_kwargs={"max_epochs": MAX_EPOCHS, "accelerator": accelerator},
        learning_rate=0.001,
    )

In [None]:
from Plotting import pca_plot, cluster_plot

# Load the previously exported Hex2Vec embeddings for plotting and comparison (the
# Hex2Vec training cell above is disabled). Reading region_id straight into the index
# replaces a separate set_index(..., inplace=True) cell, which raised a KeyError when
# re-run because region_id was no longer a column after the first run.
embeddings_POI_hex2vec = pd.read_csv("embeddings_POI_hex2vec_10.csv", index_col="region_id")

In [None]:
# Visualise the Hex2Vec embeddings with the project's pca_plot helper (Plotting module),
# using the unbuffered study regions.
pca_plot(embeddings_POI_hex2vec, selected_regions_gdf)

In [None]:
# Same visualisation for the GeoVex embeddings. These were trained on the buffered
# regions but are plotted against the unbuffered study regions.
pca_plot(embeddings_POI_geovex, selected_regions_gdf)

In [None]:
# Export the GeoVex embeddings as CSV (index = region_id).
# The Hex2Vec embeddings are deliberately not re-exported. In this notebook they are
# only read from "embeddings_POI_hex2vec_10.csv", so writing them back would just
# overwrite that input file with a copy of itself.
embeddings_POI_geovex.to_csv("embeddings_POI_geovex_10.csv")

In [None]:
# Preview the first rows of the GeoVex embeddings.
embeddings_POI_geovex.head(5)