In [24]:
# Start the PySpark shell with
#
#   pyspark --packages graphframes:graphframes:0.5.0-spark2.1-s_2.11
#
# so that the graphframes package is available to import.

from collections import defaultdict
from graphframes import GraphFrame

In [25]:
# Vertex set: every vertex is created together with its type (CH, VAT or PAYE).
v = sqlContext.createDataFrame(
    [
        ("CH1", "CH"),
        ("CH2", "CH"),
        ("VAT1", "VAT"),
        ("VAT2", "VAT"),
        ("PAYE1", "PAYE"),
        ("PAYE2", "PAYE"),
    ],
    schema=["id", "type"],
)

In [26]:
# Edge set: weighted links between vertices. Weights are floats so the
# weight column is numeric rather than a string column.
e = sqlContext.createDataFrame([
  ("CH1", "VAT1", 0.8),
  ("CH1", "VAT2", 0.9),
  ("CH2", "VAT2", 0.92),
  ("VAT1", "PAYE1", 0.8),
  ("VAT1", "PAYE2", 0.87),
], ["src", "dst", "weight"])

In [27]:
def add_opposite_direction_edges(edge_list):
    """
    Return edge_list plus a reversed copy of every edge.

    The path search only follows edges from src to dst; adding the reverse
    of each edge with the same weight makes the graph behave as undirected.

    edge_list: DataFrame with (at least) the columns src, dst and weight.
    Returns a DataFrame with exactly the columns src, dst, weight.
    """
    forward = edge_list.select('src', 'dst', 'weight')
    # union() matches columns by position, not by name, so both sides are
    # projected to the same column order (extra input columns are dropped).
    backward = edge_list.select(
        edge_list['dst'].alias('src'),
        edge_list['src'].alias('dst'),
        edge_list['weight'],
    )
    return forward.union(backward)

In [28]:
def is_better_path(newpath, oldpath):
    """
    Return True if newpath has a strictly larger total weight than oldpath.

    newpath, oldpath: lists of (node_id, weight) tuples; each weight may be a
    number or a numeric string and is converted with float().

    NOTE(review): summing weights favours longer paths, so this picks the
    most strongly weighted path rather than the shortest one; a long chain
    of weak edges can outscore a single strong edge.
    """
    def total_weight(path):
        # Sum of the weight component of every (node_id, weight) entry.
        return sum(float(node[1]) for node in path)

    return total_weight(newpath) > total_weight(oldpath)

In [29]:
def find_closest_ch_path(edge_list, vertex, path=None):
    """
    Find the most strongly weighted path from ``vertex`` to a CH vertex.

    edge_list: DataFrame with columns src, dst and weight. Edges are only
        followed from src to dst, so pass both directions (e.g. the output
        of add_opposite_direction_edges) for an undirected search.
    vertex: (node_id, weight) tuple to start from; split_complex_sub_graphs
        passes weight 0 for the start vertex.
    path: (node_id, weight) tuples to prepend to the result; defaults to an
        empty path.

    Returns ``path + [vertex]`` extended to a node whose id starts with 'CH'
    (CH nodes end a branch of the search), as a list of (node_id, weight)
    tuples, or None if no such node is reachable. Candidate paths are
    compared with is_better_path; on a tie the path found first is kept.
    """
    # Collect the edges once so the recursive search runs on a local
    # adjacency dict instead of launching a Spark job for every visited node.
    # Neighbours keep the order in which Spark returns the rows, which is the
    # order the search explores them in (and therefore decides ties).
    adjacency = defaultdict(list)
    for src, dst, weight in edge_list.select('src', 'dst', 'weight').collect():
        adjacency[src].append((dst, weight))
    return _find_closest_ch_path_local(adjacency, vertex, list(path or []))


def _find_closest_ch_path_local(adjacency, vertex, path):
    """Depth-first search for find_closest_ch_path over a {src: [(dst, weight), ...]} dict."""
    path = path + [vertex]

    # CH vertices are recognised by id prefix.
    # NOTE(review): the vertex 'type' column is not consulted here.
    if vertex[0].startswith('CH'):
        return path

    # TODO: max path length constraint?

    best = None
    for node in adjacency.get(vertex[0], []):
        # NOTE(review): visited nodes are tracked as (node_id, weight) tuples,
        # so a node reached again over an edge of a different weight is not
        # seen as visited, and returned paths can contain cycles (which also
        # inflate their summed weight). E.g. the v1/e1 test case only expects
        # VAT1 in the CH1 group because of this: with an id-based check the
        # path via VAT2 (0.8 + 0.92) would beat the direct edge to CH1 (0.9).
        if node not in path:
            candidate = _find_closest_ch_path_local(adjacency, node, path)
            if candidate and (not best or is_better_path(candidate, best)):
                best = candidate
    return best

In [30]:
def split_complex_sub_graphs(vertex_list, edge_list):
    """
    Split a graph into subgraphs, one per CH vertex.

    Every vertex is assigned to the CH vertex at the end of the path chosen
    for it by find_closest_ch_path (a CH vertex is assigned to itself).

    vertex_list: DataFrame whose first column is the vertex id.
    edge_list: DataFrame with columns src, dst, weight; edges are treated as
        undirected.

    Returns a list of vertex-id lists, one per CH vertex that received at
    least one vertex. The order of the subgraphs, and of the ids inside them,
    follows collect() and dict iteration order and is not guaranteed, so
    compare results order-insensitively.

    Raises ValueError if a vertex has no path to any CH vertex.
    """
    # graph is undirected so need edges in both directions
    edge_list = add_opposite_direction_edges(edge_list)

    subgraphs_dict = defaultdict(list)

    # Each start vertex gets weight 0 so it adds nothing to its path's weight.
    all_nodes = vertex_list.rdd.map(lambda r: (r[0], 0)).collect()

    for node in all_nodes:
        closest_ch_path = find_closest_ch_path(edge_list, node)
        if closest_ch_path is None:
            # Without this check the failure surfaces as an opaque
            # "'NoneType' object is not subscriptable" TypeError.
            raise ValueError(
                'vertex {!r} has no path to a CH vertex'.format(node[0]))
        closest_ch_node = closest_ch_path[-1][0]
        subgraphs_dict[closest_ch_node].append(node[0])

    return list(subgraphs_dict.values())

In [31]:
# Split the example graph into one vertex-id list per CH vertex.
split_complex_sub_graphs(vertex_list=v, edge_list=e)

[['CH1', 'VAT1', 'PAYE1', 'PAYE2'], ['CH2', 'VAT2']]

In [32]:
############### Testing ####################

In [33]:
# Two CH vertices, each attached to one VAT vertex, with the two VAT
# vertices also linked to each other: CH1 - VAT1 - VAT2 - CH2.
v1 = sqlContext.createDataFrame([
  ("CH1", "CH"),
  ("CH2", "CH"),
  ("VAT1", "VAT"),
  ("VAT2", "VAT"),
], ["id", "type"])

# Weights are floats so the weight column is numeric, not string.
e1 = sqlContext.createDataFrame([
  ("CH1", "VAT1", 0.9),
  ("VAT1", "VAT2", 0.8),
  ("CH2", "VAT2", 0.92),
], ["src", "dst", "weight"])

In [34]:
# Subgraph order (and id order within a subgraph) comes from collect() and
# dict iteration order, so compare order-insensitively and show the actual
# grouping if the check fails.
actual = split_complex_sub_graphs(v1, e1)
expected = [['CH1', 'VAT1'], ['CH2', 'VAT2']]
assert sorted(map(sorted, actual)) == sorted(map(sorted, expected)), actual

In [35]:
# VAT1 is linked to both CH vertices; VAT2 and PAYE1 are linked only to CH2.
v2 = sqlContext.createDataFrame([
  ("CH1", "CH"),
  ("CH2", "CH"),
  ("VAT1", "VAT"),
  ("VAT2", "VAT"),
  ("PAYE1", "PAYE"),
], ["id", "type"])

# Weights are floats so the weight column is numeric, not string.
e2 = sqlContext.createDataFrame([
  ("CH1", "VAT1", 0.9),
  ("CH2", "VAT1", 0.92),
  ("CH2", "PAYE1", 0.92),
  ("CH2", "VAT2", 0.92),
], ["src", "dst", "weight"])

In [36]:
# Subgraph order (and id order within a subgraph) comes from collect() and
# dict iteration order, so compare order-insensitively and show the actual
# grouping if the check fails.
actual = split_complex_sub_graphs(v2, e2)
expected = [['CH1'], ['CH2', 'VAT1', 'VAT2', 'PAYE1']]
assert sorted(map(sorted, actual)) == sorted(map(sorted, expected)), actual

In [37]:
# Both VAT vertices are linked to both CH vertices, with different weights.
v3 = sqlContext.createDataFrame([
  ("CH1", "CH"),
  ("CH2", "CH"),
  ("VAT1", "VAT"),
  ("VAT2", "VAT"),
], ["id", "type"])

# Weights are floats so the weight column is numeric, not string.
e3 = sqlContext.createDataFrame([
  ("CH1", "VAT1", 0.9),
  ("CH2", "VAT1", 0.92),
  ("CH2", "VAT2", 0.92),
  ("CH1", "VAT2", 0.93),
], ["src", "dst", "weight"])

In [38]:
# Subgraph order (and id order within a subgraph) comes from collect() and
# dict iteration order, so compare order-insensitively and show the actual
# grouping if the check fails.
actual = split_complex_sub_graphs(v3, e3)
expected = [['CH1', 'VAT2'], ['CH2', 'VAT1']]
assert sorted(map(sorted, actual)) == sorted(map(sorted, expected)), actual

In [39]:
# A single chain: CH1 - VAT1 - VAT2 - PAYE3 - CH2.
v4 = sqlContext.createDataFrame([
  ("CH1", "CH"),
  ("CH2", "CH"),
  ("VAT1", "VAT"),
  ("VAT2", "VAT"),
  ("PAYE3", "PAYE"),
], ["id", "type"])

# Weights are floats so the weight column is numeric, not string.
e4 = sqlContext.createDataFrame([
  ("CH1", "VAT1", 0.9),
  ("VAT1", "VAT2", 0.7),
  ("VAT2", "PAYE3", 0.8),
  ("CH2", "PAYE3", 0.95),
], ["src", "dst", "weight"])

In [40]:
# Display the grouping produced for the v4/e4 chain.
split_complex_sub_graphs(vertex_list=v4, edge_list=e4)

[['CH1', 'VAT1', 'VAT2'], ['CH2', 'PAYE3']]

In [41]:
# Subgraph order (and id order within a subgraph) comes from collect() and
# dict iteration order, so compare order-insensitively and show the actual
# grouping if the check fails.
actual = split_complex_sub_graphs(v4, e4)
expected = [['CH1', 'VAT1', 'VAT2'], ['CH2', 'PAYE3']]
assert sorted(map(sorted, actual)) == sorted(map(sorted, expected)), actual

In [42]:
# Same chain shape as the v1/e1 case (CH1 - VAT1 - VAT2 - CH2), but the
# VAT1-VAT2 and CH2-VAT2 edges have equal weight.
v5 = sqlContext.createDataFrame([
  ("CH1", "CH"),
  ("CH2", "CH"),
  ("VAT1", "VAT"),
  ("VAT2", "VAT"),
], ["id", "type"])

# Weights are floats so the weight column is numeric, not string.
e5 = sqlContext.createDataFrame([
  ("CH1", "VAT1", 0.9),
  ("VAT1", "VAT2", 0.8),
  ("CH2", "VAT2", 0.8),
], ["src", "dst", "weight"])

In [43]:
# Subgraph order (and id order within a subgraph) comes from collect() and
# dict iteration order, so compare order-insensitively and show the actual
# grouping if the check fails.
actual = split_complex_sub_graphs(v5, e5)
expected = [['CH1', 'VAT1'], ['CH2', 'VAT2']]
assert sorted(map(sorted, actual)) == sorted(map(sorted, expected)), actual