In [1]:
# run up the shell with  

# pyspark --packages graphframes:graphframes:0.5.0-spark2.1-s_2.11

# to allow for the graphframes import

In [2]:
from graphframes import GraphFrame
from pyspark.sql.functions import col
from pyspark.sql.types import StructField, ArrayType, StringType, StructType
from complex_graph_splitter import split_complex_sub_graphs

In [3]:
# Data comes in as an edge set already, 
# ID1, ID2, Arcweight

# preprocessing is required to generate the vertex set, need to get all unique values in ID1 and ID2 and strip them,
# this will mean we won't have any singles, which will be added in the add_single_legal_units 

In [4]:
def generate_vertex_set(dataframe):
    '''
    Generate a vertex set from the edge set
    param: dataframe with ID1, ID2
    '''
    uniqueid1 = dataframe.select('src').distinct().rdd.map(lambda r: r[0]).collect()
    uniqueid2 = dataframe.select('dst').distinct().rdd.map(lambda r: r[0]).collect()
    unique_ids = list(set(uniqueid1).union(uniqueid2))
    types = [vertex.split('_', 1)[0] for vertex in unique_ids]
    
    return sqlContext.createDataFrame(zip(unique_ids, types), ['id', "type"])

In [5]:
sc.setCheckpointDir(dirName='/Users/waltoj/development/datascience/graph-frames-prototype/checkpoints')

In [6]:
def get_subgraph_vertex_list(connected_components, component_id):
    return connected_components.where(col('component') == component_id)

In [7]:
def get_subgraph_edge_list(subgraph_vertex_list, edge_list):
    """
    Edge list for subgraph vertex list
    """
    vertex_ids = subgraph_vertex_list.distinct().rdd.map(lambda r: r[0]).collect()

    return edge_list.where(col('src').isin(vertex_ids) | col('dst').isin(vertex_ids))

In [8]:
def get_subgraph_components(component_id, connected_components, edge_list):
    """
    Return edge list and vertex lists of the connected component generated from the connected_components algorithm
    """
    subgraph_vertex_list = get_subgraph_vertex_list(connected_components, component_id) 
    subgraph_edge_list = get_subgraph_edge_list(subgraph_vertex_list , edge_list)
    
    return subgraph_vertex_list, subgraph_edge_list

In [9]:
def multiple_ch_nodes(vertex_list):  
    number_of_ch_nodes = vertex_list.where(vertex_list['type'] == 'CH').count()
    return number_of_ch_nodes >= 2

In [10]:
def process_subgraph(vertex_list, edge_list):
    """
    Check generated legal units are valid and decompose any which aren't into smaller legal units
    param: dataframe of the edge list of the sub graph
    
    Returns a list of the subgraphs (list of lists)
    [[CH!, VAT2], [CH2, VAT3]]
    
    """        
    if multiple_ch_nodes(vertex_list):
        return split_complex_sub_graphs(vertex_list, edge_list)
            
    return [[vertex['id'] for vertex in vertex_list.collect()]]     

In [14]:
def collapse_step(edge_list):
    '''
    create legal units from the edge list produced by the random forest step
    '''
    vertex_list = generate_vertex_set(edge_list)
    
    graph = GraphFrame(vertex_list, edge_list)
    
    connected_components = graph.connectedComponents()
    component_ids =  connected_components.select('component').distinct().rdd.map(lambda r: r[0]).collect()
    
    # Iterate over the subgraphs generated by the connected components algorithm 
    # applying business logic to split invalid ones
    sub_graphs = []

    for component_id in component_ids:
        subgraph_vertex_list, subgraph_edge_list = get_subgraph_components(component_id, connected_components, edge_list)
        sub_graphs += process_subgraph(subgraph_vertex_list, subgraph_edge_list)
    
    return sub_graphs

In [15]:
# Test data
edge_list = sqlContext.createDataFrame([
  ("CH_1", "VAT_1", "0.8"),
  ("CH_1", "VAT_2", "0.9"),
  ("CH_2", "VAT_2", "0.91"),
  ("VAT_1", "PAYE_1", "0.86"),
], ["src", "dst", 'weight'])

In [16]:
sub_graphs = collapse_step(edge_list)

In [17]:
sub_graphs

[['CH_2', 'VAT_2'], ['CH_1', 'PAYE_1', 'VAT_1']]

In [18]:
# TODO: output a dataframe of the expected format, 

In [19]:
fields = [StructField('value', ArrayType(elementType=StringType(), containsNull=False), nullable=False)]

In [20]:
output_dataframe.collect()

NameError: name 'output_dataframe' is not defined