In [1]:
import findspark

# Locate the local Spark installation and make `pyspark` importable; this has
# to run before the first `pyspark` import below.
findspark.init()

from pyspark.sql import SparkSession

# Maven coordinate of the GraphFrames package fetched when the session starts.
GRAPHFRAMES_PACKAGE = 'graphframes:graphframes:0.8.2-spark3.0-s_2.12'

spark = (
    SparkSession.builder
    .config('spark.jars.packages', GRAPHFRAMES_PACKAGE)
    .getOrCreate()
)
sc = spark.sparkContext

In [2]:
from graphframes import GraphFrame

# Both CSV files have a header row; let Spark infer the column types.
csv_options = {"header": True, "inferSchema": True}
df_V = spark.read.csv("transport-nodes.csv", **csv_options)
df_E = spark.read.csv("transport-relationships.csv", **csv_options)

# Vertices come from the nodes file, edges from the relationships file.
graph = GraphFrame(df_V, df_E)

<img src="ShortestPath.png" alt="Shortest path illustration" align="left">

In [3]:
# SQL predicates selecting the start and end vertices of the search.
from_expr = "id = 'A城市'"
to_expr = "id = 'E城市'"

# Breadth-first search between the two vertices; each result row is one path
# laid out as from / e0 / v1 / e1 / ... / to.
result = graph.bfs(fromExpr=from_expr, toExpr=to_expr)
result.show(truncate=False)

+------------------------------------+------------------------+---------------------------------+------------------------+---------------------------------+
|from                                |e0                      |v1                               |e1                      |to                               |
+------------------------------------+------------------------+---------------------------------+------------------------+---------------------------------+
|[A城市, 52.078663, 4.288788, 514861]|[A城市, C城市, ROAD, 10]|[C城市, 51.9225, 4.47917, 623652]|[C城市, E城市, ROAD, 20]|[E城市, 52.01667, 4.70833, 70990]|
|[A城市, 52.078663, 4.288788, 514861]|[A城市, B城市, ROAD, 26]|[B城市, 51.9775, 4.13333, 9382]  |[B城市, E城市, ROAD, 5] |[E城市, 52.01667, 4.70833, 70990]|
|[A城市, 52.078663, 4.288788, 514861]|[A城市, D城市, ROAD, 25]|[D城市, 52.01667, 4.70833, 70939]|[D城市, E城市, ROAD, 45]|[E城市, 52.01667, 4.70833, 70990]|
+------------------------------------+------------------------+---------------------------------+------------------------+---------------------------------+

In [6]:
# Keep only the vertex columns of each BFS path (the edge columns e0/e1 are dropped).
path_vertices = result.select('from', 'v1', 'to')
path_vertices.show(truncate=False)

+------------------------------------+---------------------------------+---------------------------------+
|from                                |v1                               |to                               |
+------------------------------------+---------------------------------+---------------------------------+
|[A城市, 52.078663, 4.288788, 514861]|[C城市, 51.9225, 4.47917, 623652]|[E城市, 52.01667, 4.70833, 70990]|
|[A城市, 52.078663, 4.288788, 514861]|[B城市, 51.9775, 4.13333, 9382]  |[E城市, 52.01667, 4.70833, 70990]|
|[A城市, 52.078663, 4.288788, 514861]|[D城市, 52.01667, 4.70833, 70939]|[E城市, 52.01667, 4.70833, 70990]|
+------------------------------------+---------------------------------+---------------------------------+



In [None]:
from pyspark.sql.types import ArrayType, StringType
from graphframes.lib import AggregateMessages as AM
from pyspark.sql import functions as F

# Spark UDF: returns a new array<string> equal to `path` with `vertex_id`
# appended (list concatenation, so the input list is not modified).
add_path_udf = F.udf(
    lambda path, vertex_id: path + [vertex_id],
    returnType=ArrayType(StringType()),
)

def shortest_path(g, origin, destination, column_name="cost"):
    """Return the cheapest path from ``origin`` to ``destination``.

    Runs Dijkstra's algorithm on top of GraphFrames ``aggregateMessages``:
    every round settles the closest unvisited vertex and relaxes the edges
    leaving it.

    Args:
        g: GraphFrame whose vertices have an ``id`` column.
        origin: id of the start vertex.
        destination: id of the target vertex.
        column_name: name of the edge column holding the edge weight
            (assumed non-negative, as Dijkstra requires).

    Returns:
        A DataFrame with the vertex columns of ``g`` plus ``distance``
        (double) and ``path`` (array<string>: vertex ids from ``origin`` to
        ``destination`` inclusive). It has one row when a path exists and
        zero rows when ``destination`` is not a vertex of ``g`` or cannot be
        reached from ``origin``.
    """
    infinity = float("inf")

    def _no_path():
        # Zero-row result with the same columns as a successful one, so
        # callers can always select "distance" and "path". Built from
        # g.vertices so it needs no notebook-level `spark` / `sc` globals.
        return (g.vertices.limit(0)
                .withColumn("distance", F.lit(None).cast("double"))
                .withColumn("path", F.array().cast("array<string>")))

    if g.vertices.filter(g.vertices.id == destination).count() == 0:
        return _no_path()

    # The origin starts at distance 0, every other vertex at infinity.
    # The explicit cast matters on Spark 3.x, where a bare array() is typed
    # array<null> rather than array<string>.
    vertices = (g.vertices.withColumn("visited", F.lit(False))
                .withColumn("distance", F.when(g.vertices["id"] == origin, 0)
                            .otherwise(infinity))
                .withColumn("path", F.array().cast("array<string>")))
    cached_vertices = AM.getCachedDataFrame(vertices)
    g2 = GraphFrame(cached_vertices, g.edges)

    while True:
        # Closest unvisited vertex, fetched with a single Spark job.
        current = (g2.vertices.filter('visited == False')
                   .sort("distance").first())
        # Stop when everything is visited, or when the closest remaining
        # vertex is at infinite distance: nothing left is reachable from the
        # origin, so the destination cannot be reached either.
        if current is None or current.distance == infinity:
            break
        current_node_id = current.id

        # Only the current vertex sends messages: the tentative distance and
        # path for each of its out-neighbours. The struct fields are named
        # explicitly instead of relying on Spark's automatic col1/col2 names.
        msg_distance = AM.edge[column_name] + AM.src['distance']
        msg_path = add_path_udf(AM.src["path"], AM.src["id"])
        msg_for_dst = F.when(AM.src['id'] == current_node_id,
                             F.struct(msg_distance.alias("col1"),
                                      msg_path.alias("col2")))
        new_distances = g2.aggregateMessages(F.min(AM.msg).alias("aggMess"),
                                             sendToDst=msg_for_dst)

        new_visited_col = F.when(
            g2.vertices.visited | (g2.vertices.id == current_node_id), True).otherwise(False)
        # Keep the old distance/path unless the message offers a strictly
        # shorter distance.
        new_distance_col = F.when(new_distances["aggMess"].isNotNull() &
                                  (new_distances.aggMess["col1"] < g2.vertices.distance),
                                  new_distances.aggMess["col1"]) \
            .otherwise(g2.vertices.distance)
        new_path_col = F.when(new_distances["aggMess"].isNotNull() &
                              (new_distances.aggMess["col1"] < g2.vertices.distance),
                              new_distances.aggMess["col2"].cast("array<string>")) \
            .otherwise(g2.vertices.path)

        new_vertices = (g2.vertices.join(new_distances, on="id", how="left_outer")
                        .drop(new_distances["id"])
                        .withColumn("visited", new_visited_col)
                        .withColumn("newDistance", new_distance_col)
                        .withColumn("newPath", new_path_col)
                        .drop("aggMess", "distance", "path")
                        .withColumnRenamed('newDistance', 'distance')
                        .withColumnRenamed('newPath', 'path'))
        cached_new_vertices = AM.getCachedDataFrame(new_vertices)
        g2 = GraphFrame(cached_new_vertices, g2.edges)

        # Once the destination has been visited its distance is final.
        if g2.vertices.filter(g2.vertices.id == destination).first().visited:
            # The stored path ends at the predecessor; append the
            # destination itself before returning.
            return (g2.vertices.filter(g2.vertices.id == destination)
                    .withColumn("newPath", add_path_udf("path", "id"))
                    .drop("visited", "path")
                    .withColumnRenamed("newPath", "path"))

    return _no_path()

In [None]:
# Cheapest route between the two cities, weighted by the edges' "cost" column.
result = shortest_path(g=graph, origin="A城市", destination="E城市", column_name="cost")
result.select("id", "distance", "path").show(truncate=False)