In [2]:
#!/usr/bin/env python3
"""
Download and prepare a drive-only OSM network around a chosen center.

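Outputs in the 'data' directory, named after the selected center key:
    - <center_key>_drive_nodes.parquet (e.g. capelle_drive_nodes.parquet)
    - <center_key>_drive_edges.parquet (e.g. capelle_drive_edges.parquet)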

These files are in the format expected by the Streamlit routing app.
"""

from pathlib import Path

import geopandas as gpd
import networkx as nx
import osmnx as ox


DATA_DIR: Path = Path('data')

# Named centers as (latitude, longitude) pairs.
CENTERS: dict[str, tuple[float, float]] = dict(
    # Hospital Nossa Senhora do Rosario, Barreiro
    barreiro=(38.657111, -9.059832),
    # Spoorlaan 6 2908 BG Capelle aan den IJssel, Netherlands
    capelle=(51.954087208237304, 4.578024796888894),
)

# Center key that main() processes.
DEFAULT_CENTER_KEY: str = 'capelle'

# Per-center extraction distance in meters; keys missing here fall back to
# the default passed to get_radius_meters().
RADIUS_BY_CENTER_METERS: dict[str, int] = dict(
    barreiro=30_000,
    capelle=50_000,
)


def configure_osmnx() -> None:
    """
    Turn on OSMnx response caching and console logging.

    Both flags are set on the shared ``ox.settings`` object rather than
    passed per call.
    """
    settings = ox.settings
    settings.log_console = True
    settings.use_cache = True


def get_center_point(center_key: str) -> tuple[float, float]:
    """
    Get a center point (lat, lon) by key.

    Args:
        center_key: Key into ``CENTERS`` (e.g., 'barreiro' or 'capelle').

    Returns:
        (latitude, longitude) tuple.

    Raises:
        KeyError: If the center key is unknown. The message lists the
            known keys so a typo is easy to spot.
    """
    try:
        return CENTERS[center_key]
    except KeyError:
        known = ', '.join(sorted(CENTERS))
        # Suppress the bare KeyError context; the new message is complete.
        raise KeyError(
            f'Unknown center key {center_key!r}; expected one of: {known}'
        ) from None


def get_radius_meters(center_key: str, default_radius_meters: int) -> int:
    """
    Resolve the extraction distance (meters) configured for a center.

    Args:
        center_key: Key identifying the center.
        default_radius_meters: Value returned when ``RADIUS_BY_CENTER_METERS``
            has no entry for ``center_key``.

    Returns:
        Distance in meters.
    """
    if center_key in RADIUS_BY_CENTER_METERS:
        return RADIUS_BY_CENTER_METERS[center_key]
    return default_radius_meters


def download_drive_graph(
    center_point: tuple[float, float],
    dist_meters: int,
    *,
    dist_type: str = 'bbox',
) -> nx.MultiDiGraph:
    """
    Download a simplified drive network around a center point from OSM.

    With the default ``dist_type='bbox'``, OSMnx does not extract a circle.
    It queries a square bounding box extending ``dist_meters`` from the
    center in each cardinal direction. Per the OSMnx docs,
    ``dist_type='network'`` instead keeps only nodes within ``dist_meters``
    network distance of the node nearest the center.

    Args:
        center_point: (latitude, longitude) of the center.
        dist_meters: Distance in meters from the center (see ``dist_type``).
        dist_type: OSMnx distance mode, ``'bbox'`` (the OSMnx default) or
            ``'network'``.

    Returns:
        OSMnx MultiDiGraph for the drive network.
    """
    # dist_type is passed explicitly so the extracted area does not change
    # silently if the library default ever changes.
    graph: nx.MultiDiGraph = ox.graph_from_point(
        center_point,
        dist=dist_meters,
        dist_type=dist_type,
        network_type='drive',
        simplify=True,
    )
    return graph


def graph_to_parquet(
    graph: nx.MultiDiGraph,
    output_dir: Path,
    prefix: str = 'drive',
) -> None:
    """
    Project graph to a metric CRS, convert to GeoDataFrames, and write Parquet.

    This produces:
        output_dir / f"{prefix}_nodes.parquet"
        output_dir / f"{prefix}_edges.parquet"

    The nodes file has a "node_id" column copied from the node GeoDataFrame
    index, plus "x", "y", and "geometry" when present. The edges file has
    "u", "v", "length", and (when present) "geometry". The MultiDiGraph edge
    "key" is not written, so parallel edges share the same (u, v) pair.

    Args:
        graph: OSMnx graph to export; projected with ``ox.project_graph``.
        output_dir: Directory for the Parquet files; created if missing.
        prefix: File-name prefix shared by both outputs.

    Raises:
        RuntimeError: If the edges lack any of "u", "v", or "length".
    """
    graph_proj: nx.MultiDiGraph = ox.project_graph(graph)
    nodes, edges = ox.graph_to_gdfs(graph_proj)

    # node_id is derived from the index, so it always exists; only the
    # edge columns need validating.
    nodes = nodes.assign(node_id=nodes.index)
    edges = edges.reset_index(drop=False)

    for col in ('u', 'v', 'length'):
        if col not in edges.columns:
            raise RuntimeError(f"edges Parquet must contain column '{col}'.")

    node_cols = [
        col for col in ('node_id', 'x', 'y', 'geometry')
        if col in nodes.columns
    ]
    edge_cols = [
        col for col in ('u', 'v', 'length', 'geometry')
        if col in edges.columns
    ]
    nodes_out: gpd.GeoDataFrame = nodes[node_cols].copy()
    edges_out: gpd.GeoDataFrame = edges[edge_cols].copy()

    # Create the directory only once validated data is ready to write, so a
    # failed projection or validation leaves no empty directory behind.
    output_dir.mkdir(parents=True, exist_ok=True)

    nodes_path: Path = output_dir / f'{prefix}_nodes.parquet'
    edges_path: Path = output_dir / f'{prefix}_edges.parquet'

    nodes_out.to_parquet(nodes_path)
    edges_out.to_parquet(edges_path)

    print(f'Saved drive nodes to {nodes_path}')
    print(f'Saved drive edges to {edges_path}')
    print(f'Number of nodes: {len(nodes_out)}')
    print(f'Number of edges: {len(edges_out)}')


def main(
    center_key: str = DEFAULT_CENTER_KEY,
    *,
    default_radius_meters: int = 30_000,
) -> None:
    """
    Configure OSMnx, download the drive graph for a center, and export it.

    Output files are named after the processed center, e.g.
    ``data/capelle_drive_nodes.parquet``.

    Args:
        center_key: Key into ``CENTERS``; defaults to ``DEFAULT_CENTER_KEY``.
        default_radius_meters: Distance used when the center has no entry in
            ``RADIUS_BY_CENTER_METERS``.
    """
    configure_osmnx()

    center_point = get_center_point(center_key)
    radius_meters = get_radius_meters(
        center_key, default_radius_meters=default_radius_meters
    )

    print('Downloading drive network...')
    print(f'Center key: {center_key}')
    print(f'Center: {center_point}, radius: {radius_meters} m')

    graph = download_drive_graph(center_point, radius_meters)

    print(
        'Downloaded graph with '
        f'{graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges'
    )

    print('Exporting drive network to Parquet...')
    # Name the files after the center actually processed rather than the
    # module default, so a non-default run cannot mislabel or overwrite them.
    graph_to_parquet(graph, DATA_DIR, prefix=f'{center_key}_drive')
    print('All done.')


if __name__ == '__main__':
    main()

Downloading drive network...
Center key: capelle
Center: (51.954087208237304, 4.578024796888894), radius: 50000 m
Downloaded graph with 253608 nodes and 598612 edges
Exporting drive network to Parquet...
Saved drive nodes to data\capelle_drive_nodes.parquet
Saved drive edges to data\capelle_drive_edges.parquet
Number of nodes: 253608
Number of edges: 598612
All done.
