In [1]:
# run up the shell with  

# pyspark --packages graphframes:graphframes:0.5.0-spark2.1-s_2.11

# to allow for the graphframes import

from collections import defaultdict

from graphframes import GraphFrame
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StructField, ArrayType, StringType, StructType

In [2]:
# Type added when creating the vertex set to begin with

v = sqlContext.createDataFrame([
  ("CH1", "CH"),
  ("CH2", "CH"),
  ("VAT1", "VAT"),
  ("VAT2", "VAT"),
  ("PAYE1", "PAYE"),
  ("PAYE2", "PAYE"),
], ["id", "type"])

In [3]:
e = sqlContext.createDataFrame([
  ("CH1", "VAT1", "0.8"),
  ("CH1", "VAT2", "0.9"),
  ("CH2", "VAT2", "0.91"),
  ("VAT1", "PAYE1", "0.8"),
  ("VAT1", "PAYE2", "0.87"),    
], ["src", "dst", 'weight'])

In [4]:
def get_closest_ch(distances):
    '''
    Function to iterate over output of shortestPaths graphframes algorithm
    '''
    closest_match_distance = 10000 # np.inf?
    closest_match = None # Will always be a CH node as we are partitioning a complex unit which has been tested for multiple CH nodes
    
    for node, distance in distances.items():
        if node.startswith("CH"):
            if distance < closest_match_distance:
                closest_match_distance = distance
                closest_match = node
                                
    return closest_match

In [5]:
getClosestCHUDF = udf(get_closest_ch, StringType())

In [6]:
def get_subgraphs_dict(ch_list, vertex_list):
    """
    Get dict of subraphs keyed by their CH node
    {
        "CH1": ["VAT1", "CH!", "VAT2"]
        "CH2": ["VAT12", "CH2"]
    }
    """
    subgraphs_dict = defaultdict(list)
    
    for node in vertex_list.collect():
        subgraphs_dict[node['closest_ch']].append(node['id'])
        
    return subgraphs_dict


In [7]:
def get_ch_nodes_list(vertex_list):
    ch_nodes = vertex_list.where(vertex_list['type'] == 'CH')
    return ch_nodes.rdd.map(lambda r: r[0]).collect()

In [8]:
def add_opposite_direction_edges(edge_list):
    """
    The shortest path algorithm uses the direction of the edges are directional, 
    adding an edge in other direction of equal weight gives the effect of undirected edges.
    """
    inverse_edge_list = edge_list
    inverse_edge_list = inverse_edge_list.withColumnRenamed('src', 'dst_copy')
    inverse_edge_list = inverse_edge_list.withColumnRenamed('dst', 'src')
    inverse_edge_list = inverse_edge_list.withColumnRenamed('dst_copy', 'dst')
    return edge_list.union(inverse_edge_list.select('src', 'dst', 'weight'))

In [85]:
def find_closest_ch(edge_list, vertex, path=[], path_weight=0):
    path = path + [vertex[0]]

    print('path: ', path)
    print(vertex[0])
    
    if vertex[0].startswith('CH'):
        return vertex
    
    #print(edge_list.collect())
    
    possible_dsts = edge_list.filter(edge_list['src'] == vertex[0]).rdd.map(lambda r: (r[1], r[2])).collect() # possible dst nodes from src node
    
    print(possible_dsts)
    if possible_dsts == []:
        return None
    
    shortest = None

    for node in possible_dsts:
        if node[0] not in path:
            newpath = find_closest_ch(edge_list, node, path, path_weight)
            if newpath:
                print('newpath: ', newpath)
                if not shortest or len(newpath) < len(shortest): # Need to take into account the weight of the path?
                    shortest = newpath
    return shortest

In [86]:
# test the find_closes_ch algorithm
#e = add_opposite_direction_edges(e)
#e.collect()
#find_closest_ch(e, 'PAYE1')

In [87]:
def split_complex_sub_graphs(vertex_list, edge_list):
    """
    Split sub graph, return a list of the split subgraphs vertex lists
    
    """
    edge_list = add_opposite_direction_edges(edge_list)
    
    #G = GraphFrame(vertex_list, edge_list)
    
    #ch_node_list = get_ch_nodes_list(vertex_list)
    
    # Doesn't take into account weights
    #shortest_paths_to_ch_nodes = G.shortestPaths(landmarks=ch_node_list)
    
    # TODO: weighted shortest paths  or closest CH algorithm (djikstras with shortest path to node of type "CH")
    
    subgraphs_dict = defaultdict(list)

    all_nodes = vertex_list.rdd.map(lambda r: (r[0],r[1])).collect()
    for node in all_nodes:
        #print(node)
        closest_ch_node = find_closest_ch(edge_list, node)
        #print('closest ch node', closest_ch_node)
        subgraphs_dict[closest_ch_node[0]].append(node[0])
    
    print(subgraphs_dict)
    #nodes_with_closest_ch_distance = shortest_paths_to_ch_nodes.withColumn("closest_ch", getClosestCHUDF(shortest_paths_to_ch_nodes.distances))
    
    # Dictionary of subgraphs, keyed by CH vertex
    #subgraphs_dict = get_subgraphs_dict(ch_node_list, nodes_with_closest_ch_distance)
    
    return [subgraph for subgraph in subgraphs_dict.values()]

In [88]:
split_complex_sub_graphs(v, e)

path:  ['CH1']
CH1
path:  ['CH2']
CH2
path:  ['VAT1']
VAT1
[('PAYE1', '0.8'), ('PAYE2', '0.87'), ('CH1', '0.8')]
path:  ['VAT1', 'PAYE1']
PAYE1
[('VAT1', '0.8')]
path:  ['VAT1', 'PAYE2']
PAYE2
[('VAT1', '0.87')]
path:  ['VAT1', 'CH1']
CH1
newpath:  ('CH1', '0.8')
path:  ['VAT2']
VAT2
[('CH1', '0.9'), ('CH2', '0.91')]
path:  ['VAT2', 'CH1']
CH1
newpath:  ('CH1', '0.9')
path:  ['VAT2', 'CH2']
CH2
newpath:  ('CH2', '0.91')
path:  ['PAYE1']
PAYE1
[('VAT1', '0.8')]
path:  ['PAYE1', 'VAT1']
VAT1
[('PAYE1', '0.8'), ('PAYE2', '0.87'), ('CH1', '0.8')]
path:  ['PAYE1', 'VAT1', 'PAYE2']
PAYE2
[('VAT1', '0.87')]
path:  ['PAYE1', 'VAT1', 'CH1']
CH1
newpath:  ('CH1', '0.8')
newpath:  ('CH1', '0.8')
path:  ['PAYE2']
PAYE2
[('VAT1', '0.87')]
path:  ['PAYE2', 'VAT1']
VAT1
[('PAYE1', '0.8'), ('PAYE2', '0.87'), ('CH1', '0.8')]
path:  ['PAYE2', 'VAT1', 'PAYE1']
PAYE1
[('VAT1', '0.8')]
path:  ['PAYE2', 'VAT1', 'CH1']
CH1
newpath:  ('CH1', '0.8')
newpath:  ('CH1', '0.8')
defaultdict(<class 'list'>, {'CH1': 

[['CH1', 'VAT1', 'VAT2', 'PAYE1', 'PAYE2'], ['CH2']]