## Baseline GCN testing
Notebook to create a GCN and evaluate it against EBC at predicting the number of passing cyclists in Copenhagen.
- Preprocess EBC for graph DONE
- Assign Metrics from data
- Create Torch Graph
- Evaluate against SOTA

In [None]:
import torch
from torch_geometric.data import Data
import torch_geometric as tg
import osmnx as ox
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import geopandas as gpd
import folium
from folium import plugins
from folium.plugins import HeatMap
from shapely.geometry import Point, LineString, Polygon
import shapely
import momepy as mp 
import esda
import seaborn as sns
from shapely.strtree import STRtree
import pickle
from tqdm import tqdm

import os, glob, sys

# Sanity check: prints True when the helper module is reachable from the
# notebook's current working directory.
print(os.path.exists('../scripts/bike_functions.py'))
import sys
# Make ../scripts importable so the star import below can find the module.
sys.path.append('../scripts')
# NOTE(review): the project helpers called in later cells (get_city_graph,
# create_linegraph, calc_bc, ...) are not defined in this notebook and
# presumably come from this star import - confirm against bike_functions.py.
from bike_functions import *

# Apply seaborn's default theme to all subsequent matplotlib plots.
sns.set_theme()

In [None]:
# Centre point of the study area (Copenhagen coordinates).
lat = 55.6867243
lon = 12.5700724

# Distance in meters to fetch data around the centre point.
dist = 10_000

# OSM tag keys to request, kept in alphabetical order
# (per the original note: the list of all features in OSMnx).
features = [
    'aerialway', 'aeroway', 'amenity', 'barrier', 'boundary',
    'building', 'craft', 'emergency', 'geological', 'healthcare',
    'highway', 'historic', 'landuse', 'leisure', 'man_made',
    'military', 'natural', 'office', 'place', 'power',
    'public_transport', 'railway', 'route', 'service', 'shop',
    'telecom', 'tourism', 'water', 'waterway',
]

# Alternative, shorter key set kept for reference:
# features = 'amenity shop building aerialway aeroway barrier boundary craft emergency highway historic landuse leisure healthcare military natural office power public_transport railway place service tourism waterway route water'.split()

# Keys whose values should be expanded. Leave empty to expand none, or use
# `expand_features = features` to expand every key.
expand_features = list()


In [None]:
# Fetch the study-area data around (lat, lon) within `dist` using the project
# helper; it returns three objects that the following cells consume.
# NOTE(review): presumably `g` is the street network graph, `gdf` its
# GeoDataFrame and `amenities` the OSM objects matching `features` - confirm
# against get_city_graph in bike_functions.py.
g, gdf, amenities = get_city_graph(lat,
                                    lon,
                                    dist,
                                    features = features, 
                                    expand_features = expand_features)

In [None]:
# Derive H from g; later cells treat H as a networkx graph with node attributes.
# NOTE(review): presumably H is the line graph of g (one node per street
# segment) - confirm against create_linegraph.
H = create_linegraph(g)


### EBC Calculation

In [None]:
# Compute the centrality values for H and store them on its nodes under 'bc'.
# NOTE(review): presumably `bc` is a dict keyed by node of H - confirm against calc_bc.
bc = calc_bc(H)
nx.set_node_attributes(H, bc, 'bc')


In [None]:
# Load the raw traffic-count file and combine it with g / gdf.
# NOTE(review): presumably returns gdf extended with AADT columns (the next
# cell reads 'aadt_cykler' from it) - confirm against load_aadt.
gdf_new = load_aadt('../data/raw/trafiktaelling.json', g, gdf)


In [None]:
# Transfer the counts in column 'aadt_cykler' onto the graph; H is rebound to the result.
# NOTE(review): presumably the values end up in the 'aadt' node attribute of H,
# which the later default-fill cell checks for - confirm against the helper.
H = assign_aadt_to_graph_edges(g, gdf_new, H, aadt_col='aadt_cykler')


In [None]:
# Attach amenity-derived features to the nodes of H.
# NOTE(review): the return value is discarded, so this presumably mutates H in
# place - confirm against assign_features_to_nodes.
assign_features_to_nodes(H, amenities, geometry_col='geometry', amenity_col='amenity')

In [None]:
# Give every node of H an 'aadt' value: nodes that already have one keep it,
# all others receive 0.
# NOTE(review): this makes "no count available" indistinguishable from a
# measured count of 0 in the target - confirm this is intended.
for _, node_attrs in H.nodes(data=True):
    node_attrs.setdefault('aadt', 0)


In [None]:
# Clean/standardize the node attributes of H; the result is later passed to
# graph_to_linegraph_data as the list of node feature names.
# NOTE(review): presumably H is modified in place and remove_fields=None drops
# nothing - confirm against clean_and_standardize_node_features.
all_feats = clean_and_standardize_node_features(H, remove_fields=None)


In [None]:
def graph_to_linegraph_data(H, all_feats, target_feat='aadt', osmid_feat='osmid'):
    """
    Convert a networkx graph H with node attributes into a PyTorch Geometric Data object.

    Parameters:
    - H: networkx graph whose nodes carry the attributes named in all_feats.
    - all_feats: list of node attribute names; target_feat and osmid_feat are
      excluded from the feature matrix if they appear in it.
    - target_feat: node attribute used as the target variable (default 'aadt').
    - osmid_feat: node attribute used as osmid identifier (default 'osmid').

    Returns:
    - (data, node_feat_names) where
      data is a PyTorch Geometric Data object with
        x          float tensor [num_nodes, num_features], missing attributes filled with 0.0,
        y          float tensor [num_nodes] holding target_feat,
        osmid      long tensor [num_nodes] holding osmid_feat,
        edge_index long tensor [2, num_edges] indexing nodes in H.nodes() order,
      and node_feat_names lists the columns of data.x in order.

    Raises:
    - KeyError: if any node lacks target_feat or osmid_feat.
    """
    # Column order of data.x; computed once instead of once per node.
    excluded = (target_feat, osmid_feat)
    node_feat_names = [name for name in all_feats if name not in excluded]

    node_list, x, y, osmid_list = [], [], [], []
    for node, feats in H.nodes(data=True):
        node_list.append(node)
        x.append([feats.get(name, 0.0) for name in node_feat_names])
        y.append(feats[target_feat])
        osmid_list.append(feats[osmid_feat])

    # Map the original node identifiers to consecutive row indices.
    node_idx = {node: idx for idx, node in enumerate(node_list)}

    # NOTE(review): every edge from H.edges() is written once, in the direction
    # networkx reports it. PyG treats edge_index as directed, so an undirected H
    # would need both directions added - confirm intent.
    edge_index = [[node_idx[s], node_idx[t]] for s, t in H.edges()]

    data = Data()
    data.num_nodes = len(node_list)
    # Explicit reshape keeps x two-dimensional even when H has no nodes.
    data.x = torch.tensor(x, dtype=torch.float).reshape(len(node_list), len(node_feat_names))
    data.y = torch.tensor(y, dtype=torch.float)
    # Requires every osmid to be a single integer-like value.
    data.osmid = torch.tensor(osmid_list, dtype=torch.long)
    # Reshape before transposing so a graph without edges still yields the
    # [2, 0] layout; contiguous() materialises the transposed view.
    data.edge_index = (
        torch.tensor(edge_index, dtype=torch.long)
        .reshape(len(edge_index), 2)
        .t()
        .contiguous()
    )
    # data.H = H  # Optional: Attach original H graph if needed

    return data, node_feat_names

In [None]:
# Convert H into a PyTorch Geometric Data object: 'aadt' becomes the target y,
# 'osmid' the identifier, and the remaining names in all_feats the columns of x.
linegraph, node_feat_names = graph_to_linegraph_data(H, all_feats, target_feat='aadt', osmid_feat='osmid')
# Keep the column names of linegraph.x on the Data object itself.
linegraph.feat_names = node_feat_names

In [None]:
# Structural sanity checks on the exported Data object.

# edge_index must be laid out as [2, num_edges].
assert linegraph.edge_index.shape[0] == 2

# When edge attributes exist there must be exactly one row per edge.
if 'edge_attr' in linegraph:
    assert linegraph.edge_index.shape[1] == linegraph.edge_attr.shape[0]

# The feature matrix must have one row per node.
assert linegraph.x.shape[0] == linegraph.num_nodes


In [None]:
# import contextily as cx
# import matplotlib.pyplot as plt
# from shapely.geometry import Point
# import geopandas as gpd

# # Reproject point to match gdf CRS
# point = Point(lon, lat)
# gdf_point = gpd.GeoDataFrame(geometry=[point], crs='epsg:4326').to_crs(gdf.crs)

# # Plot base gdf
# fig, ax = plt.subplots(figsize=(15, 10))
# gdf.plot(ax=ax, legend=True, markersize=1, alpha=0.5)

# # Add basemap
# cx.add_basemap(ax, crs=gdf.crs, source=cx.providers.CartoDB.Positron)

# # Now plot the point on top
# gdf_point.plot(ax=ax, color='red', markersize=100, label='Center Point')

# # Optional: remove axis
# plt.axis('off')
# plt.show()


In [None]:
# amenities = amenities.to_crs(gdf.crs)

In [None]:
# ### remove all linestring geometries in amenities
# # convert polygon geometries to centroids
# amenities['geometry'] = amenities.geometry.centroid
# # remove points outside the bounding box of the graph
# amenities = amenities[amenities.geometry.within(gdf.unary_union.envelope)]
# # convert amenities to points
# amenities = amenities.set_geometry('geometry')

In [None]:
# fig, ax = plt.subplots(figsize=(15, 10))
# amenities.plot(ax=ax, legend=True, markersize=10, alpha=0.5)
# gdf_point = gpd.GeoDataFrame(geometry=[point], crs='epsg:4326').to_crs(gdf.crs)

# # Add basemap
# cx.add_basemap(ax, crs=amenities.crs, source=cx.providers.CartoDB.Positron)
# gdf_point.plot(ax=ax, color='red', markersize=100, label='Center Point')

# # Optional: remove axis
# plt.axis('off')
# plt.show()

In [None]:
# Persist the Data object, both graphs and the configuration used to build them.
# NOTE(review): the output location and file format are decided inside
# save_graph_with_config - confirm against bike_functions.py.
save_graph_with_config(
    linegraph, 
    H, 
    g,
    features, 
    expand_features, 
    dist
)

In [None]:
# your_graph_number = 17
# ### save g and h
# with open(f'../data/graphs/{your_graph_number}/graph.pkl', 'wb') as f:
#     pickle.dump(g, f)
# with open(f'../data/graphs/{your_graph_number}/linegraph.pkl', 'wb') as f:
#     pickle.dump(H, f)