In [1]:
from collections import defaultdict

from graphframes import GraphFrame
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StructField, ArrayType, StringType, StructType

In [2]:
v = sqlContext.createDataFrame([
  ("CH1", "CH"),
  ("CH2", "CH"),
  ("VAT1", "VAT"),
  ("VAT2", "VAT"),
  ("PAYE1", "PAYE"),
  ("PAYE2", "PAYE"),
], ["id", "type"])

In [3]:
e = sqlContext.createDataFrame([
  ("CH1", "VAT1", "1"),
  ("CH1", "VAT2", "1"),
  ("CH2", "VAT2", "1"),
  ("VAT1", "PAYE1", "1"),
  ("VAT1", "PAYE2", "1"),    
], ["src", "dst", 'weight'])

In [4]:
def get_closest_ch(distances):
    '''
    Function to iterate over output of shortestPaths graphframes algorithm
    '''
    closest_match_distance = 10000 # np.inf?
    closest_match = None # should always be a CH node, 
    
    for node, distance in distances.items():
        if node.startswith("CH"):
            if distance < closest_match_distance:
                closest_match_distance = distance
                closest_match = node
                                
    return closest_match

In [5]:
getClosestCHUDF = udf(get_closest_ch, StringType())

In [6]:
def get_subgraphs_dict(ch_list, vertex_list):
    """
    Get dict of subraphs keyed by their CH node
    {
        "CH1": ["VAT1", "CH!", "VAT2"]
        "CH2": ["VAT12", "CH2"]
    }
    """
    subgraphs_dict = defaultdict(list)
    
    for node in vertex_list.collect():
        subgraphs_dict[node['closest_ch']].append(node['id'])
        
    return subgraphs_dict


In [7]:
def get_ch_nodes_list(vertex_list):
    ch_nodes = vertex_list.where(vertex_list['type'] == 'CH')
    return ch_nodes.rdd.map(lambda r: r[0]).collect()

In [8]:
def split_complex_sub_graphs(vertex_list, edge_list):
    """
    Split sub graph, return a list of the split subgraphs vertex lists
    """
    G = GraphFrame(vertex_list, edge_list)
    
    ch_node_list = get_ch_nodes_list(vertex_list)
    
    shortest_paths_to_ch_nodes = G.shortestPaths(ch_node_list)
    
    # add a closest ch column
    closest_ch_nodes = shortest_paths_to_ch_nodes.withColumn("closest_ch", getClosestCHUDF(shortest_paths_to_ch_nodes.distances))
    
    # Dictionary of subgraphs, keyed by CH vertex
    subgraphs_dict = get_subgraphs_dict(ch_node_list, closest_ch_nodes)
    
    return [value for value in subgraphs_dict.values()]

In [10]:
split_complex_sub_graphs(v, e)

[['PAYE2', 'PAYE1', 'VAT2', 'VAT1'], ['CH1'], ['CH2']]