In [1]:
import pandas as pd
import numpy as np
import geopandas as gpd



In [62]:
from typing import Dict, List, Optional

import osmnx as ox
import numpy as np
import networkx as nx

from pyrosm import OSM

from sklearn.neighbors import BallTree
from sklearn.metrics import pairwise

TRAVEL_TYPES = {
    "driving": "driving",
    "cycling": "cycling",
    "walking": "walking",
    "service": "driving+service",
}


class ShortestRoute:

    graphs: Dict = dict()

    def __init__(
        self, osm: OSM, travel_types: List = list(TRAVEL_TYPES.keys())
    ) -> None:
        self.validate_dependencies()
        if osm is not None:
            for travel_type in TRAVEL_TYPES.keys():
                if travel_type not in travel_types:
                    continue
                network_type = TRAVEL_TYPES[travel_type]
                nodes, edges = osm.get_network(nodes=True, network_type=network_type)
                graph = osm.to_graph(nodes, edges, graph_type="networkx")
                graph = ox.add_edge_speeds(graph)
                graph = ox.add_edge_travel_times(graph)
                graph = nx.convert_node_labels_to_integers(graph)
                self.graphs[travel_type] = graph

    def validate_dependencies(self) -> None:
        if pairwise.distance_metrics is None:
            raise Warning(
                "There is no valid distance metrics available on the current scikit-learn installation."
            )

    def get_valid_distance_metrics(self) -> List[str]:
        return pairwise.distance_metrics()

    def validate_distance_metric(self, distance_metric: str) -> None:
        if distance_metric not in pairwise.distance_metrics():
            raise Warning(
                f"{distance_metric} is not a valid distance metric. Please use one of the following: {self.get_valid_distance_metrics}"
            )

    def get_available_travel_types(self) -> List:
        return list(self.graphs.keys())

    def validate_travel_type(self, travel_type: str = "driving") -> None:
        if travel_type not in self.get_available_travel_types():
            raise Warning(
                f"Travel type is not available. Available travel types are: {self.get_available_travel_types()}"
            )

    def get_nodes_positions(self, travel_type: str = "driving") -> np.ndarray:
        self.validate_travel_type(travel_type)
        graph = self.graphs[travel_type]
        node_pos = np.zeros((graph.number_of_nodes(), 2))
        for node, data in graph.nodes(data=True):
            node_pos[node, [0, 1]] = data["x"], data["y"]
        return node_pos

    def build_nodes_tree(
        self, travel_type: str = "driving", distance_metric: str = "haversine"
    ):
        self.validate_travel_type(travel_type)
        self.validate_distance_metric(distance_metric)
        nodes = self.get_nodes_positions(travel_type)
        tree = BallTree(nodes, metric=distance_metric)
        return tree

    def get_nearest_node(
        self,
        x: float,
        y: float,
        travel_type: str = "driving",
        return_dist: bool = False,
    ):
        self.validate_travel_type(travel_type)
        return ox.distance.nearest_nodes(
            self.graphs[travel_type], x, y, return_dist=return_dist
        )

    def get_nearest_nodes(
        self,
        sources: np.ndarray,
        travel_type: str = "driving",
        return_dist: bool = False,
        tree: Optional[BallTree] = None
    ):
        if tree is None:
            tree = self.build_nodes_tree(travel_type)
        if return_dist:
            d, i = tree.query(sources, k=1, return_distance=return_dist)
            return d[:, 0], i[:, 0]
        else:
            i = tree.query(sources, k=1, return_distance=return_dist)
            return i[:, 0]

    def compute_cost_matrix(self, sources: np.ndarray, targets: np.ndarray, travel_type: str = 'driving', weight: str = 'travel_time'):
        self.validate_travel_type(travel_type)
        graph = self.graphs[travel_type]
        tree = self.build_nodes_tree(travel_type)
        source_nodes = self.get_nearest_nodes(sources, travel_type=travel_type, return_dist=False, tree=tree)
        target_nodes = self.get_nearest_nodes(targets, travel_type=travel_type, return_dist=False, tree=tree)
        cost_matrix = np.full( (len(sources), len(targets)), np.inf )
        for index_source, source in enumerate(source_nodes):
            distances = nx.shortest_path_length(graph, source=source, weight=weight)
            for index_target, target in enumerate(target_nodes):
                cost_matrix[index_source, index_target] = distances[target]
        return cost_matrix

In [63]:
osm = OSM('../data/protobuf/macul.osm.pbf')
sr = ShortestRoute(osm, travel_types=['driving'])

In [64]:
sr.compute_cost_matrix([[-70.587114, -33.472195], [-70.586961, -33.472222]], [[-70.587114, -33.472195], [-70.586961, -33.472222]])

array([[0. , 1.4],
       [1.4, 0. ]])

In [65]:
d, i = sr.get_nearest_nodes([[-70.587114, -33.472195], [-70.586961, -33.472222]], return_dist=True)

In [54]:
np.full((2, 3), np.inf)

array([[inf, inf, inf],
       [inf, inf, inf]])

In [55]:
for t, distance in nx.shortest_path_length(sr.graphs['driving'], source=0, weight='length').items():
    

{0: 0,
 1: 11.475,
 2: 14.551,
 67: 46.326,
 3: 54.666000000000004,
 66: 56.371,
 20767: 57.168,
 17412: 66.77600000000001,
 12394: 67.634,
 17413: 74.44900000000001,
 12393: 75.713,
 65: 88.905,
 8871: 96.259,
 12392: 100.41399999999999,
 68: 100.587,
 4378: 102.035,
 8872: 103.619,
 8873: 111.065,
 23805: 111.211,
 23804: 113.651,
 52: 118.506,
 4: 122.02199999999999,
 69: 124.947,
 53: 127.712,
 4381: 136.25900000000001,
 17414: 136.78000000000003,
 4379: 143.85399999999998,
 54: 146.37800000000001,
 70: 153.974,
 55: 155.30200000000002,
 56: 164.656,
 4447: 164.664,
 71: 167.653,
 12391: 171.95999999999998,
 57: 173.55,
 4334: 173.607,
 12390: 175.12099999999998,
 4353: 175.82399999999998,
 12389: 178.27999999999997,
 4461: 178.897,
 4380: 179.60500000000002,
 4338: 181.44299999999998,
 58: 181.735,
 4339: 184.611,
 4340: 187.768,
 59: 189.61700000000002,
 4341: 190.935,
 72: 198.35899999999998,
 60: 200.346,
 73: 202.87699999999998,
 74: 204.13599999999997,
 61: 206.51,
 694: 207.

In [41]:
d[]

{0: 0,
 1: 0,
 5: 0,
 2: 0.3,
 6: 0.4,
 7: 0.8,
 8: 1.3,
 9: 1.7000000000000002,
 10: 2.3000000000000003,
 11: 2.8000000000000003,
 681: 3.1,
 682: 3.5,
 67: 3.6,
 12: 3.7,
 683: 3.8,
 3: 4.1,
 684: 4.1,
 66: 4.3,
 13: 4.3,
 20767: 4.8,
 14: 4.8999999999999995,
 17412: 5.3,
 15: 5.3999999999999995,
 12394: 5.6,
 16: 5.999999999999999,
 17413: 6.1,
 17: 6.499999999999999,
 12393: 6.5,
 685: 6.8,
 65: 6.8,
 18: 6.999999999999999,
 8871: 7.3999999999999995,
 19: 7.599999999999999,
 686: 7.6,
 68: 7.800000000000001,
 8872: 7.999999999999999,
 20: 8.299999999999999,
 687: 8.299999999999999,
 8873: 8.6,
 23805: 8.8,
 21: 8.999999999999998,
 688: 8.999999999999998,
 23804: 9.0,
 52: 9.2,
 22: 9.299999999999999,
 12392: 9.3,
 4378: 9.399999999999999,
 689: 9.699999999999998,
 23: 9.7,
 23468: 9.7,
 69: 9.700000000000001,
 53: 9.899999999999999,
 23469: 10.2,
 24: 10.299999999999999,
 690: 10.399999999999997,
 4: 10.5,
 25: 10.899999999999999,
 4381: 11.000000000000002,
 54: 11.299999999999999,

In [9]:
nodes, edges = osm.get_network(nodes=True, network_type='driving')

In [10]:
nodes

Unnamed: 0,lon,lat,tags,timestamp,version,changeset,id,geometry
0,-70.587114,-33.472195,,0,0,0,282822274,POINT (-70.58711 -33.47219)
1,-70.586993,-33.472216,"{'crossing': 'uncontrolled', 'crossing_ref': '...",0,0,0,9316857831,POINT (-70.58699 -33.47222)
2,-70.586961,-33.472222,"{'direction': 'backward', 'highway': 'give_way'}",0,0,0,9316857832,POINT (-70.58696 -33.47222)
3,-70.586538,-33.472299,,0,0,0,282822382,POINT (-70.58654 -33.47230)
4,-70.585829,-33.472428,,0,0,0,4437997596,POINT (-70.58583 -33.47243)
...,...,...,...,...,...,...,...,...
25415,-70.598483,-33.507286,,0,0,0,9750122029,POINT (-70.59848 -33.50729)
25416,-70.598522,-33.507283,"{'access': 'permissive', 'barrier': 'gate'}",0,0,0,9750122030,POINT (-70.59852 -33.50728)
25417,-70.598781,-33.507245,,0,0,0,9750122032,POINT (-70.59878 -33.50724)
25418,-70.599056,-33.507207,,0,0,0,9750122031,POINT (-70.59906 -33.50721)
