In [0]:
from graphframes import GraphFrame # Use ML Runtime with CPUs for better performance
from graphframes.lib import Pregel
from graphframes.lib import AggregateMessages as AM
from pyspark.sql.functions import *
from pyspark.sql.window import Window

from typing import Dict, Callable, Union
from pyspark.sql.types import StructField
from pyspark.sql import Column

In [0]:
data = [
    (1, "Pepperoni Pizza", "Product", None, 0),
    (2, "Dough", "Complex Ingredient", 1, 10),
    (3, "Tomato Sauce", "Complex Ingredient", 1, 10),
    (4, "Cheese", "Complex Ingredient", 1, 10),
    (5, "Pepperoni", "Complex Ingredient", 1, 10),
    (6, "Tomato, Chopped", "Processed Ingredient" ,3, 10),
    (7, "Tomato, Washed", "Processed Ingredient", 6, 10),
    (8, "Tomato", "Produce", 7, 10),
]

columns = ["ingredientId", "ingredientName", "ingredientClass", "parentIngredientID", "carbon"]
ingredients_df = spark.createDataFrame(data, columns)
ingredients_df.display()

ingredientId,ingredientName,ingredientClass,parentIngredientID,carbon
1,Pepperoni Pizza,Product,,0
2,Dough,Complex Ingredient,1.0,10
3,Tomato Sauce,Complex Ingredient,1.0,10
4,Cheese,Complex Ingredient,1.0,10
5,Pepperoni,Complex Ingredient,1.0,10
6,"Tomato, Chopped",Processed Ingredient,3.0,10
7,"Tomato, Washed",Processed Ingredient,6.0,10
8,Tomato,Produce,7.0,10


In [0]:
# NOTE: DO NOT rename id, src, and dst - GraphFrames leverages these names to set up the correct graph path

# Create vertices and edges dataframes
vertices = ingredients_df.withColumnRenamed("ingredientID", "id") # Mandatory --> vertex id
print("Vertices")
vertices.display()

edges = (ingredients_df.filter("parentIngredientID IS NOT NULL") 
                       .withColumnRenamed("ingredientID", "src") # Mandatory --> source vertex
                       .withColumnRenamed("parentIngredientID", "dst") # Mandatory --> destination vertex
                       .select("src", "dst")
        )
print("Edges")
edges.display()

graph = GraphFrame(vertices, edges)

Vertices


id,ingredientName,ingredientClass,parentIngredientID,carbon
1,Pepperoni Pizza,Product,,0
2,Dough,Complex Ingredient,1.0,10
3,Tomato Sauce,Complex Ingredient,1.0,10
4,Cheese,Complex Ingredient,1.0,10
5,Pepperoni,Complex Ingredient,1.0,10
6,"Tomato, Chopped",Processed Ingredient,3.0,10
7,"Tomato, Washed",Processed Ingredient,6.0,10
8,Tomato,Produce,7.0,10


Edges


src,dst
2,1
3,1
4,1
5,1
6,3
7,6
8,7


In [0]:
"""
Initialize pregel algorithm to run aggregations over the tree

.setMaxIter --> Maximum iteration of messages propagated in the graph. It depends of the deepness of your graph structure

.withVertexColumn --> 

.sendMsgToDst --> 

.aggMsgs --> Defines how messages are aggregated after grouped by target vertex IDs.
"""
result = (graph.pregel
               .setMaxIter(5) 
               .withVertexColumn("totalCarbon", lit(0), coalesce(Pregel.msg(), col("carbon"))) 
               .sendMsgToDst(Pregel.src("totalCarbon") + Pregel.dst("carbon"))
               .aggMsgs(sum(Pregel.msg()).alias("msg"))
               .run()
        )

result.display()

id,totalCarbon,ingredientName,ingredientClass,parentIngredientID,carbon
2,10,Dough,Complex Ingredient,1.0,10
3,40,Tomato Sauce,Complex Ingredient,1.0,10
7,20,"Tomato, Washed",Processed Ingredient,6.0,10
4,10,Cheese,Complex Ingredient,1.0,10
5,10,Pepperoni,Complex Ingredient,1.0,10
8,10,Tomato,Produce,7.0,10
1,70,Pepperoni Pizza,Product,,0
6,30,"Tomato, Chopped",Processed Ingredient,3.0,10


In [0]:
def set_tree_level_labels(graph: GraphFrame, max_iters= 50, min_iters= 5):

    window_id_order = Window().partitionBy().orderBy("id")
    branch_label_expr = array(row_number().over(window_id_order), lit(1))
    message_expr = struct(col("level"), col(f"branchLabel"))

    vertex_outbound_conns_df = graph.outDegrees

    vertices_leveled = (graph.vertices.join(vertex_outbound_conns_df, 
                                        on=["id"], 
                                        how="left")
                    .withColumn("level", when(col("outDegree").isNull(), lit(1)))
                    .withColumn("branchLabel", when(col("outDegree").isNull(), branch_label_expr))
                    .withColumn("message", when(col("outDegree").isNull(), message_expr))
                    .drop("outDegree")
    )   
    print("Vertices")
    vertices_leveled.display()

    window_src_dst = Window().partitionBy("dst").orderBy("src")

    edges_child_numbered = graph.edges.withColumn("childNumber", row_number().over(window_src_dst))
    print("Edges Child Numbered")
    edges_child_numbered.display()
    
    v = vertices_leveled
    e = edges_child_numbered
    gx = GraphFrame(v, e)

    for i in range(max_iters):
        print(f"Iteration {i}")

        aggregate_msg_expr = first(AM.msg).alias("message")
        msg_to_src_expr = struct((AM.dst["message"].getField("level") + 1).alias("level"),
                                  array_append(AM.dst["message"].getField("branchLabel"), AM.edge["childNumber"]).alias("branchLabel"))
        
        new_levels = gx.aggregateMessages(
            aggCol= aggregate_msg_expr,
            sendToSrc= msg_to_src_expr
        )

        print("New Levels Propagated")
        new_levels.display()
        
        if i >= min_iters and new_levels.count() == 0:
            break

        v = (
            v.alias("l")
            .join(new_levels.alias("r"), "id", "left_outer")
            .withColumn("level", coalesce(col("l.level"), col("r.message.level")))
            .withColumn(
                "branchLabel",
                coalesce(col("l.branchLabel"), col("r.message.branchLabel")),
            )
            .drop(col("r.message"))
        )
        print("Vertices to process")
        v.display()
        cachedNewVertices = AM.getCachedDataFrame(new_levels)
        gx = GraphFrame(cachedNewVertices, e)

    print("Final - All levels and tree branches set")
    v.display()
    return GraphFrame(v.drop("message"), e)

graph_leveled = set_tree_level_labels(graph, max_iters= 15) # You must set always a maximum level of iterations

max_levels = graph_leveled.vertices.groupBy().max("level").first()[0]
print("Check precisely how many levels we'll get in a graph")
max_levels

Vertices


id,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,message
1,Pepperoni Pizza,Product,,0,1.0,"List(1, 1)","List(1, List(1, 1))"
2,Dough,Complex Ingredient,1.0,10,,,
3,Tomato Sauce,Complex Ingredient,1.0,10,,,
4,Cheese,Complex Ingredient,1.0,10,,,
5,Pepperoni,Complex Ingredient,1.0,10,,,
6,"Tomato, Chopped",Processed Ingredient,3.0,10,,,
7,"Tomato, Washed",Processed Ingredient,6.0,10,,,
8,Tomato,Produce,7.0,10,,,


Edges Child Numbered


src,dst,childNumber
2,1,1
3,1,2
4,1,3
5,1,4
6,3,1
7,6,1
8,7,1


Iteration 0
New Levels Propagated


id,message
2,"List(2, List(1, 1, 1))"
3,"List(2, List(1, 1, 2))"
4,"List(2, List(1, 1, 3))"
5,"List(2, List(1, 1, 4))"
6,"List(null, null)"
7,"List(null, null)"
8,"List(null, null)"


Vertices to process


id,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,message
2,Dough,Complex Ingredient,1.0,10,2.0,"List(1, 1, 1)",
3,Tomato Sauce,Complex Ingredient,1.0,10,2.0,"List(1, 1, 2)",
7,"Tomato, Washed",Processed Ingredient,6.0,10,,,
4,Cheese,Complex Ingredient,1.0,10,2.0,"List(1, 1, 3)",
5,Pepperoni,Complex Ingredient,1.0,10,2.0,"List(1, 1, 4)",
8,Tomato,Produce,7.0,10,,,
1,Pepperoni Pizza,Product,,0,1.0,"List(1, 1)","List(1, List(1, 1))"
6,"Tomato, Chopped",Processed Ingredient,3.0,10,,,


Iteration 1
New Levels Propagated


id,message
6,"List(3, List(1, 1, 2, 1))"
7,"List(null, null)"
8,"List(null, null)"


Vertices to process


id,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,message
1,Pepperoni Pizza,Product,,0,1.0,"List(1, 1)","List(1, List(1, 1))"
2,Dough,Complex Ingredient,1.0,10,2.0,"List(1, 1, 1)",
3,Tomato Sauce,Complex Ingredient,1.0,10,2.0,"List(1, 1, 2)",
4,Cheese,Complex Ingredient,1.0,10,2.0,"List(1, 1, 3)",
5,Pepperoni,Complex Ingredient,1.0,10,2.0,"List(1, 1, 4)",
6,"Tomato, Chopped",Processed Ingredient,3.0,10,3.0,"List(1, 1, 2, 1)",
7,"Tomato, Washed",Processed Ingredient,6.0,10,,,
8,Tomato,Produce,7.0,10,,,


Iteration 2
New Levels Propagated


id,message
7,"List(4, List(1, 1, 2, 1, 1))"
8,"List(null, null)"


Vertices to process


id,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,message
1,Pepperoni Pizza,Product,,0,1.0,"List(1, 1)","List(1, List(1, 1))"
2,Dough,Complex Ingredient,1.0,10,2.0,"List(1, 1, 1)",
3,Tomato Sauce,Complex Ingredient,1.0,10,2.0,"List(1, 1, 2)",
4,Cheese,Complex Ingredient,1.0,10,2.0,"List(1, 1, 3)",
5,Pepperoni,Complex Ingredient,1.0,10,2.0,"List(1, 1, 4)",
6,"Tomato, Chopped",Processed Ingredient,3.0,10,3.0,"List(1, 1, 2, 1)",
7,"Tomato, Washed",Processed Ingredient,6.0,10,4.0,"List(1, 1, 2, 1, 1)",
8,Tomato,Produce,7.0,10,,,


Iteration 3
New Levels Propagated


id,message
8,"List(5, List(1, 1, 2, 1, 1, 1))"


Vertices to process


id,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,message
2,Dough,Complex Ingredient,1.0,10,2,"List(1, 1, 1)",
3,Tomato Sauce,Complex Ingredient,1.0,10,2,"List(1, 1, 2)",
7,"Tomato, Washed",Processed Ingredient,6.0,10,4,"List(1, 1, 2, 1, 1)",
4,Cheese,Complex Ingredient,1.0,10,2,"List(1, 1, 3)",
5,Pepperoni,Complex Ingredient,1.0,10,2,"List(1, 1, 4)",
8,Tomato,Produce,7.0,10,5,"List(1, 1, 2, 1, 1, 1)",
1,Pepperoni Pizza,Product,,0,1,"List(1, 1)","List(1, List(1, 1))"
6,"Tomato, Chopped",Processed Ingredient,3.0,10,3,"List(1, 1, 2, 1)",


Iteration 4
New Levels Propagated


id,message


Vertices to process


id,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,message
1,Pepperoni Pizza,Product,,0,1,"List(1, 1)","List(1, List(1, 1))"
2,Dough,Complex Ingredient,1.0,10,2,"List(1, 1, 1)",
3,Tomato Sauce,Complex Ingredient,1.0,10,2,"List(1, 1, 2)",
4,Cheese,Complex Ingredient,1.0,10,2,"List(1, 1, 3)",
5,Pepperoni,Complex Ingredient,1.0,10,2,"List(1, 1, 4)",
6,"Tomato, Chopped",Processed Ingredient,3.0,10,3,"List(1, 1, 2, 1)",
7,"Tomato, Washed",Processed Ingredient,6.0,10,4,"List(1, 1, 2, 1, 1)",
8,Tomato,Produce,7.0,10,5,"List(1, 1, 2, 1, 1, 1)",


Iteration 5
New Levels Propagated


id,message


Final - All levels and tree branches set


id,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,message
2,Dough,Complex Ingredient,1.0,10,2,"List(1, 1, 1)",
3,Tomato Sauce,Complex Ingredient,1.0,10,2,"List(1, 1, 2)",
7,"Tomato, Washed",Processed Ingredient,6.0,10,4,"List(1, 1, 2, 1, 1)",
4,Cheese,Complex Ingredient,1.0,10,2,"List(1, 1, 3)",
5,Pepperoni,Complex Ingredient,1.0,10,2,"List(1, 1, 4)",
8,Tomato,Produce,7.0,10,5,"List(1, 1, 2, 1, 1, 1)",
1,Pepperoni Pizza,Product,,0,1,"List(1, 1)","List(1, List(1, 1))"
6,"Tomato, Chopped",Processed Ingredient,3.0,10,3,"List(1, 1, 2, 1)",


Check precisely how many levels we'll get in a graph


5

In [0]:
"""
Sometimes we might have cases where we model our data and generate orphan nodes (vertex) by dropping certain rows to meet certain criteria. In order to propage effectively, we need to remove them. Be mindful to keep your graph structure intact and robust before using any MPI algorithm/interface like Pregel, Aggregate Message, Page Rank etc.
"""
new_vertices = (graph_leveled.vertices.filter(col("level").isNotNull())
                                      .withColumn("graphProcessOrder", lit(max_levels) - col(f"level"))
               )

edges_vertices_join_expr = (col("id") == col("src")) | (col("id") == col("dst"))

new_edges = graph_leveled.edges.join(new_vertices, edges_vertices_join_expr, "leftsemi")

graph_with_levels = GraphFrame(new_vertices, new_edges).cache() # Caching is extremely important when dealing with iterative algorithms. Don't fogert to unpersist to avoid GC pressure

In [0]:
# NOTE: Struct is a helpful structure to manage MPI across the nodes. we can simply propagate more than one column and read/aggregate easily.

import pyspark.sql.functions as F

from graphframes import GraphFrame
from graphframes.lib import Pregel
from typing import Dict, Callable, Union
from pyspark.sql.types import StructField
from pyspark.sql import Column


def pregel_aggregation_by_levels(
    g: GraphFrame,
    vertex_name: str, 
    vertex_initial_value: Dict[str, Column], 
    child_msg_value: Callable[[str], StructField], 
    children_aggregation: Dict[str, Column], 
    vertex_update_after_aggregation: Union[Column, Dict[str, Column]],
    order_column: str,
    max_iters: int = 10,
    debug=True,
):
        
    vertex_initial_value = F.struct(*[v.alias(k) for k, v in vertex_initial_value.items()])  # Convert dict to actual struct
    children_aggregation = F.struct(*[v.alias(k) for k, v in children_aggregation.items()])  # Convert dict to actual struct
    
    if isinstance(vertex_update_after_aggregation, dict):
        vertex_update_after_aggregation = F.struct(*[v.alias(k) for k, v in vertex_update_after_aggregation.items()])  # Convert dict to actual struct

    initial_iteration_message = F.when(F.col("inDegree").isNull(), vertex_initial_value)  # Start the loop with only the leaves

    v = (        
        g.vertices.join(g.inDegrees, ["id"],"left")
        .cache()
    )

    v.display()

    e = g.edges.cache()
    gx = GraphFrame(v, e)

    __currentIterationMessage = "__currentIterationMessage"
    __currentIterationOrder = "__currentIterationOrder"

    nV = (
        gx.pregel
        .setMaxIter(max_iters)  # This is important, pregel will do this number of iterations 
        .withVertexColumn(
            colName=vertex_name,
            initialExpr=vertex_initial_value,
            updateAfterAggMsgsExpr=vertex_update_after_aggregation,
        )
        .withVertexColumn(
            colName=__currentIterationMessage, 
            initialExpr=initial_iteration_message,
            updateAfterAggMsgsExpr=F.when(
                (
                    (F.col(__currentIterationOrder) <= 0) |  # Update all messages that are down the current level
                    ((F.col(__currentIterationOrder) <= 1) & Pregel.msg().isNotNull())  # For current level messages, update them
                ), 
                Pregel.msg())
                .otherwise(F.col(__currentIterationMessage)),
        )
        .withVertexColumn(
            colName=__currentIterationOrder,
            initialExpr=F.col(order_column),
            updateAfterAggMsgsExpr=F.col(__currentIterationOrder) - 1
        )
        .sendMsgToDst(
            msgExpr=F.when(Pregel.src(__currentIterationOrder) == 0, child_msg_value(__currentIterationMessage))
        )
        .aggMsgs(
            aggExpr=children_aggregation
        )
        .run()
    )
    
    if not debug:
        nV = nV.drop("inDegree", __currentIterationMessage, __currentIterationOrder)

    nV.display()

    return GraphFrame(nV, e)

In [0]:
graph_with_levels.vertices.display()

graph_with_levels.edges.display()

id,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,graphProcessOrder
1,Pepperoni Pizza,Product,,0,1,"List(1, 1)",4
2,Dough,Complex Ingredient,1.0,10,2,"List(1, 1, 1)",3
3,Tomato Sauce,Complex Ingredient,1.0,10,2,"List(1, 1, 2)",3
4,Cheese,Complex Ingredient,1.0,10,2,"List(1, 1, 3)",3
5,Pepperoni,Complex Ingredient,1.0,10,2,"List(1, 1, 4)",3
6,"Tomato, Chopped",Processed Ingredient,3.0,10,3,"List(1, 1, 2, 1)",2
7,"Tomato, Washed",Processed Ingredient,6.0,10,4,"List(1, 1, 2, 1, 1)",1
8,Tomato,Produce,7.0,10,5,"List(1, 1, 2, 1, 1, 1)",0


src,dst,childNumber
6,3,1
8,7,1
2,1,1
3,1,2
4,1,3
5,1,4
7,6,1


In [0]:
# The best way for message passing development is setting the iteration as 1 and display the internals. You will have a clear picture of what will happen in further nodes. Additionally, motif finding is also helpful to understand source -> edge -> destination data being propagated through the messages

_msg = Pregel.msg().getField # Label to simplify the struct().getField - note that our message is a struct

initialize_carbon_value = {"totalCarbon": col("carbon")}

child_msg_value = lambda v: (Pregel.src(v).withField("contributionFactor", lit(1.0))) # Show case weighting example later

aggredate_children_expr = {"totalCarbon": sum(_msg("totalCarbon"))}

vertex_update_after_agg_expr = coalesce(_msg("totalCarbon") + lit(0)) + col("impactsAgg.totalCarbon")

g_results = pregel_aggregation_by_levels(graph_with_levels,                                                    
                                         vertex_name= "impactsAgg",
                                         vertex_initial_value = initialize_carbon_value,
                                         child_msg_value = child_msg_value,
                                         children_aggregation = aggredate_children_expr,
                                         vertex_update_after_aggregation = vertex_update_after_agg_expr,
                                         order_column="graphProcessOrder",
                                         max_iters=1,
                                         debug=True)

id,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,graphProcessOrder,inDegree
1,Pepperoni Pizza,Product,,0,1,"List(1, 1)",4,4.0
2,Dough,Complex Ingredient,1.0,10,2,"List(1, 1, 1)",3,
3,Tomato Sauce,Complex Ingredient,1.0,10,2,"List(1, 1, 2)",3,1.0
4,Cheese,Complex Ingredient,1.0,10,2,"List(1, 1, 3)",3,
5,Pepperoni,Complex Ingredient,1.0,10,2,"List(1, 1, 4)",3,
6,"Tomato, Chopped",Processed Ingredient,3.0,10,3,"List(1, 1, 2, 1)",2,1.0
7,"Tomato, Washed",Processed Ingredient,6.0,10,4,"List(1, 1, 2, 1, 1)",1,1.0
8,Tomato,Produce,7.0,10,5,"List(1, 1, 2, 1, 1, 1)",0,


id,efAgg,__currentIterationMessage,__currentIterationOrder,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,graphProcessOrder,inDegree
2,,List(10),2,Dough,Complex Ingredient,1.0,10,2,"List(1, 1, 1)",3,
3,,,2,Tomato Sauce,Complex Ingredient,1.0,10,2,"List(1, 1, 2)",3,1.0
7,20.0,List(10),0,"Tomato, Washed",Processed Ingredient,6.0,10,4,"List(1, 1, 2, 1, 1)",1,1.0
4,,List(10),2,Cheese,Complex Ingredient,1.0,10,2,"List(1, 1, 3)",3,
5,,List(10),2,Pepperoni,Complex Ingredient,1.0,10,2,"List(1, 1, 4)",3,
8,,,-1,Tomato,Produce,7.0,10,5,"List(1, 1, 2, 1, 1, 1)",0,
1,,,3,Pepperoni Pizza,Product,,0,1,"List(1, 1)",4,4.0
6,,,1,"Tomato, Chopped",Processed Ingredient,3.0,10,3,"List(1, 1, 2, 1)",2,1.0


In [0]:
g_results.find("(c)-[e]->(p)") # Expression for child -> edge -> parent. GraphFrame motif finding uses a simple Domain-Specific Language (DSL) for expressing structural queries.

# See more: https://graphframes.github.io/graphframes/docs/_site/user-guide.html#motif-finding

id,efAgg,__currentIterationMessage,__currentIterationOrder,ingredientName,ingredientClass,parentIngredientID,carbon,level,branchLabel,graphProcessOrder,inDegree
2,List(10),,-7,Dough,Complex Ingredient,1.0,10,2,"List(1, 1, 1)",3,
3,List(10),,-7,Tomato Sauce,Complex Ingredient,1.0,10,2,"List(1, 1, 2)",3,1.0
7,List(10),,-9,"Tomato, Washed",Processed Ingredient,6.0,10,4,"List(1, 1, 2, 1, 1)",1,1.0
4,List(10),,-7,Cheese,Complex Ingredient,1.0,10,2,"List(1, 1, 3)",3,
5,List(10),,-7,Pepperoni,Complex Ingredient,1.0,10,2,"List(1, 1, 4)",3,
8,List(10),,-10,Tomato,Produce,7.0,10,5,"List(1, 1, 2, 1, 1, 1)",0,
1,List(40),,-6,Pepperoni Pizza,Product,,0,1,"List(1, 1)",4,4.0
6,List(10),,-8,"Tomato, Chopped",Processed Ingredient,3.0,10,3,"List(1, 1, 2, 1)",2,1.0
