# GraphFrames Basics

Examples of how to use GraphFrames for basic queries, motif finding, and general graph algorithms.

## Basic setup

GraphFrames is a package for Apache Spark which provides DataFrame-based graphs. It provides high-level APIs in Scala, Java, and Python. It aims to provide both the functionality of GraphX and extended functionality taking advantage of Spark DataFrames. This extended functionality includes motif finding, DataFrame-based serialization, and highly expressive graph queries.

To use this API, we first need to import the Python modules. This assumes the GraphFrames Spark package is already configured for this notebook (for example, by starting PySpark with the `--packages` option).

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, LongType, StringType
from graphframes import GraphFrame

# Reuse the notebook's SparkSession if one exists, otherwise create one
spark = SparkSession.builder.getOrCreate()

## Preparing the data

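A graph consists of a set of vertices and a set of edges, and both have to be provided as DataFrames. The vertex DataFrame must have an `id` column and the edge DataFrame must have `src` and `dst` columns; any other columns become vertex or edge attributes.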

In [None]:
n_schema = StructType([StructField("id", LongType()), StructField("sex", StringType())])
n_nodes = [[0, "m"], [1, "f"], [2, "m"], [3, "m"], [4, "m"], [5, "f"], [6, "m"], [7, "f"], [8, "m"], [9, "f"]]

e_schema = StructType([StructField("src", LongType()), StructField("dst", LongType())])
n_edges = [[1,2], [2,3], [3,1], [3,4], [4,2], [2,4], [9, 1], [8, 1], [7, 1], [6, 5], [9, 0], [0,1]]

nodes_df = spark.createDataFrame(n_nodes, schema=n_schema)
print("nodes:")
nodes_df.show()

edges_df = spark.createDataFrame(n_edges, schema=e_schema)
print("edges:")
edges_df.show()
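The edge DataFrame is simply an edge list: each row is one directed edge from `src` to `dst`. For intuition only, here is a plain-Python sketch (no Spark involved; `adjacency` is a hypothetical helper) that turns the same edge list into the adjacency-list view a graph algorithm would traverse:

```python
from collections import defaultdict

def adjacency(edges):
    """Build a directed adjacency list {src: [dst, ...]} from (src, dst) pairs."""
    adj = defaultdict(list)
    for src, dst in edges:
        adj[src].append(dst)
    return dict(adj)

# Same edge list as n_edges above, as plain Python lists
edges = [[1, 2], [2, 3], [3, 1], [3, 4], [4, 2], [2, 4],
         [9, 1], [8, 1], [7, 1], [6, 5], [9, 0], [0, 1]]
adj = adjacency(edges)
print(adj[2])  # → [3, 4]
```

Vertex 5 has no outgoing edges, so it is not a key of `adj`; it is still a vertex of the graph because it is listed in `nodes_df`.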

## Constructing the graph representation

Now we can create the GraphFrame from the node and edge DataFrames. It's handy to cache the result so that Spark does not recompute the underlying DataFrames every time the graph is queried.

In [None]:
g1 = GraphFrame(nodes_df, edges_df).cache()

## Creating another graph

If we don't need an explicit schema, there is an easier way to create a simple GraphFrame: pass a list of tuples together with the column names, and Spark infers the column types.

In [None]:
# Vertex DataFrame
v = spark.createDataFrame([
  (1, "Alice", 62),
  (2, "Bob", 12),
  (3, "Charlie", 55),
  (4, "David", 29),
  (5, "Esther", 32),
  (6, "Fanny", 14),
  (7, "Gabby", 60),
  (0, "Henry", 51)
], ["id", "name", "age"])

# Edge DataFrame
e = spark.createDataFrame([
  (1, 2, "friend"),
  (2, 3, "follow"),
  (3, 2, "follow"),
  (6, 3, "follow"),
  (5, 6, "follow"),
  (0, 2, "follow"),
  (0, 6, "follow"),
  (5, 4, "friend"),
  (4, 1, "friend"),
  (2, 0, "friend"),
  (1, 5, "friend")
], ["src", "dst", "relationship"])
# Create a GraphFrame
g2 = GraphFrame(v, e).cache()

We now have two different graphs, but we can merge them as long as their vertex ids refer to the same entities. First, we merge the vertices:

In [None]:
print("g1 nodes:")
g1.vertices.show()

print("g2 nodes:")
g2.vertices.show()

print("merged nodes:")
merged_nodes = g1.vertices.join(g2.vertices, 'id')
merged_nodes.show()
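`join` performs an inner join by default, so only ids present in both vertex sets survive: ids 8 and 9 exist only in `g1` and are dropped. A plain-Python sketch of the same semantics (dicts stand in for DataFrames; `inner_join_on_id` is a hypothetical helper):

```python
def inner_join_on_id(left, right):
    """Mimic DataFrame.join(other, 'id'): keep only ids present on both sides.

    left, right: dicts mapping vertex id -> tuple of attribute values
    """
    return {vid: left[vid] + right[vid] for vid in sorted(left) if vid in right}

g1_vertices = {0: ("m",), 1: ("f",), 2: ("m",), 3: ("m",), 4: ("m",),
               5: ("f",), 6: ("m",), 7: ("f",), 8: ("m",), 9: ("f",)}
g2_vertices = {1: ("Alice", 62), 2: ("Bob", 12), 3: ("Charlie", 55), 4: ("David", 29),
               5: ("Esther", 32), 6: ("Fanny", 14), 7: ("Gabby", 60), 0: ("Henry", 51)}

merged = inner_join_on_id(g1_vertices, g2_vertices)
print(sorted(merged))  # → [0, 1, 2, 3, 4, 5, 6, 7]
print(merged[0])       # → ('m', 'Henry', 51)
```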

We do the same for the edges, joining on both `src` and `dst`. A full outer join keeps edges that occur in only one of the two graphs.

In [None]:
print("g1 edges:")
g1.edges.show()

print("g2 edges:")
g2.edges.show()

print("merged edges:")
merged_edges_raw = g1.edges.join(g2.edges, ['src', 'dst'], 'fullouter')
merged_edges_raw.show()
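A full outer join keeps every edge from either side and leaves `null` wherever one side has no match; here, edges that come only from `g1` get a `null` relationship. A plain-Python sketch of this behaviour (a set and a dict in place of DataFrames; `full_outer_join_edges` is a hypothetical helper, and its `fill` argument mimics the `na.fill('other')` step that follows):

```python
def full_outer_join_edges(g1_edges, g2_edges, fill=None):
    """Mimic g1.edges.join(g2.edges, ['src', 'dst'], 'fullouter').

    g1_edges: set of (src, dst) pairs (no extra columns)
    g2_edges: dict mapping (src, dst) -> relationship
    Edges found only in g1 get relationship None (a SQL null) unless `fill` is given.
    Assumes no duplicate (src, dst) pairs on either side.
    """
    keys = sorted(g1_edges | set(g2_edges))
    return {k: g2_edges.get(k, fill) for k in keys}

g1_edges = {(1, 2), (2, 3), (3, 1), (3, 4), (4, 2), (2, 4),
            (9, 1), (8, 1), (7, 1), (6, 5), (9, 0), (0, 1)}
g2_edges = {(1, 2): "friend", (2, 3): "follow", (3, 2): "follow", (6, 3): "follow",
            (5, 6): "follow", (0, 2): "follow", (0, 6): "follow", (5, 4): "friend",
            (4, 1): "friend", (2, 0): "friend", (1, 5): "friend"}

merged = full_outer_join_edges(g1_edges, g2_edges)
print(len(merged), merged[(1, 2)], merged[(3, 1)])  # → 21 friend None
```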

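Edges that exist only in `g1` have no `relationship` value, so replace the resulting `null` values with the word 'other' (`na.fill` with a string value only touches string columns, so `src` and `dst` are unaffected).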

In [None]:
merged_edges = merged_edges_raw.na.fill('other')
merged_edges.show()

In [None]:
g = GraphFrame(merged_nodes, merged_edges)
g.cache()
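Because the vertices were merged with an inner join but the edges with a full outer join, the merged edges still reference ids 8 and 9, which no longer exist as vertices. A quick plain-Python check (hypothetical helper `dangling_edges`, no Spark involved) makes such inconsistencies visible; in Spark the same check could be written as an anti join (`how='left_anti'`) of the edges against the vertex ids, once for `src` and once for `dst`.

```python
def dangling_edges(vertex_ids, edges):
    """Return the edges whose src or dst is not among the known vertex ids."""
    ids = set(vertex_ids)
    return [(src, dst) for src, dst in edges if src not in ids or dst not in ids]

merged_vertex_ids = range(8)  # ids 0-7 survive the inner join
merged_edge_keys = [(1, 2), (2, 3), (3, 1), (3, 4), (4, 2), (2, 4), (9, 1), (8, 1),
                    (7, 1), (6, 5), (9, 0), (0, 1), (3, 2), (6, 3), (5, 6), (0, 2),
                    (0, 6), (5, 4), (4, 1), (2, 0), (1, 5)]
print(dangling_edges(merged_vertex_ids, merged_edge_keys))  # → [(9, 1), (8, 1), (9, 0)]
```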

## Simple algorithms

GraphFrames provides the same suite of standard graph algorithms as GraphX, plus some new ones. See the [API docs](https://graphframes.github.io/api/python/index.html) for more details.

### Vertex degrees

In [None]:
vertexInDegrees = g.inDegrees
vertexInDegrees.show()

vertexOutDegrees = g.outDegrees
vertexOutDegrees.show()

# vertex with the highest in-degree
top_in = vertexInDegrees.join(g.vertices, 'id') \
                        .orderBy("inDegree", ascending=False) \
                        .head()
print("highest in-degree: " + str(top_in))

# vertex with the highest out-degree
top_out = vertexOutDegrees.join(g.vertices, 'id') \
                          .orderBy("outDegree", ascending=False) \
                          .head()
print("highest out-degree: " + str(top_out))
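`inDegrees` and `outDegrees` are computed from the edges alone, so vertices without incoming (or outgoing) edges do not appear in them at all. A plain-Python sketch of the same counts with `collections.Counter` (hypothetical helper `degrees`), using `g1`'s edges:

```python
from collections import Counter

def degrees(edges):
    """Return (in_degree, out_degree) Counters for a directed edge list.

    Like g.inDegrees / g.outDegrees, vertices with degree 0 simply do not appear.
    """
    in_deg = Counter(dst for _, dst in edges)
    out_deg = Counter(src for src, _ in edges)
    return in_deg, out_deg

g1_edges = [(1, 2), (2, 3), (3, 1), (3, 4), (4, 2), (2, 4),
            (9, 1), (8, 1), (7, 1), (6, 5), (9, 0), (0, 1)]
in_deg, out_deg = degrees(g1_edges)
print(in_deg.most_common(1))  # → [(1, 5)]
```

When several vertices tie for the maximum, as vertices 2, 3 and 9 do for out-degree here, `orderBy(...).head()` returns just one of them.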

### Motif queries
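It's possible to use ASCII-like queries to find patterns in the graph structure. A pattern such as `(a)-[e]->(b)` names vertices in parentheses and edges in square brackets; several patterns can be combined with `;`, a leading `!` negates a pattern, and anonymous elements such as `[]` must match but get no column in the result. `find` returns an ordinary DataFrame with one column per named element, so the general form looks like:
```
(g.find("(a)-[e]->(b)")   # one row per match, with columns a, e, b
  .filter(...)             # any further DataFrame operations
  .groupBy(...))
```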
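As a rough plain-Python picture of what `g.find("(a)-[e]->(b)")` produces (hypothetical helper `find_edges`; dicts in place of DataFrames), every edge becomes one row holding its source vertex, the edge itself, and its destination vertex, and later filters are ordinary row filters:

```python
def find_edges(vertices, edges):
    """Rough analogue of g.find("(a)-[e]->(b)"): one row per edge.

    vertices: dict id -> attribute dict; edges: list of edge dicts with 'src' and 'dst'.
    Edges whose endpoints are missing from `vertices` produce no row, since the
    pattern needs both a and b to match a vertex.
    """
    return [{"a": vertices[e["src"]], "e": e, "b": vertices[e["dst"]]}
            for e in edges if e["src"] in vertices and e["dst"] in vertices]

vertices = {0: {"id": 0, "name": "Henry", "age": 51},
            2: {"id": 2, "name": "Bob", "age": 12}}
edges = [{"src": 0, "dst": 2, "relationship": "follow"},
         {"src": 2, "dst": 0, "relationship": "friend"},
         {"src": 9, "dst": 0, "relationship": "other"}]   # vertex 9 is missing

rows = find_edges(vertices, edges)
# The equivalent of .filter("e.relationship = 'follow'") is an ordinary filter on the rows
follows = [(r["a"]["name"], r["b"]["name"]) for r in rows if r["e"]["relationship"] == "follow"]
print(len(rows), follows)  # → 2 [('Henry', 'Bob')]
```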

Find all pairs where `a` follows `b` but `b` has no edge back to `a` (the negated pattern `!(b)-[]->(a)` ignores the relationship type).

In [None]:
motifs = g.find("(a)-[e]->(b); !(b)-[]->(a)") \
          .filter("e.relationship = 'follow'")
motifs.show()
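A plain-Python sketch of this query on `g2`'s edges alone (hypothetical helper `follows_not_followed_back`, no Spark involved), showing that any edge back, not only a "follow" edge, rules a pair out:

```python
def follows_not_followed_back(edges):
    """Plain-Python analogue of
    g.find("(a)-[e]->(b); !(b)-[]->(a)").filter("e.relationship = 'follow'").

    edges: list of (src, dst, relationship); the negation ignores the relationship type.
    """
    pairs = {(src, dst) for src, dst, _ in edges}
    return [(src, dst) for src, dst, rel in edges
            if rel == "follow" and (dst, src) not in pairs]

g2_edges = [(1, 2, "friend"), (2, 3, "follow"), (3, 2, "follow"), (6, 3, "follow"),
            (5, 6, "follow"), (0, 2, "follow"), (0, 6, "follow"), (5, 4, "friend"),
            (4, 1, "friend"), (2, 0, "friend"), (1, 5, "friend")]
print(follows_not_followed_back(g2_edges))  # → [(6, 3), (5, 6), (0, 6)]
```

(0, 2) is excluded even though it is a "follow" edge, because Bob has a "friend" edge back to Henry. On the merged graph `g`, the `g1` edge (6, 5) also excludes (5, 6).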

Find all people older than 40 who have an edge to at least two different people younger than 15 (the pattern does not restrict the relationship type, so any edge counts).

In [176]:
candidates = g.find("(a)-[]->(b); (a)-[]->(c)") \
              .filter("b != c") \
              .filter("a.age > 40") \
              .filter("b.age < 15") \
              .filter("c.age < 15")
candidates.show()

+--------------+--------------+--------------+
|             a|             b|             c|
+--------------+--------------+--------------+
|[0,m,Henry,51]|[6,m,Fanny,14]|  [2,m,Bob,12]|
|[0,m,Henry,51]|  [2,m,Bob,12]|[6,m,Fanny,14]|
+--------------+--------------+--------------+
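In the output above the same pair appears twice, as (Fanny, Bob) and as (Bob, Fanny), because `b` and `c` are interchangeable in the pattern. Adding `.filter("b.id < c.id")` to the query would keep one row per pair. A plain-Python sketch of the query with that extra condition (hypothetical helper `older_with_two_young`; ages and edges copied from the merged graph):

```python
from itertools import combinations

def older_with_two_young(ages, edges, old=40, young=15):
    """Plain-Python analogue of
    g.find("(a)-[]->(b); (a)-[]->(c)").filter("b != c")
     .filter("a.age > 40").filter("b.age < 15").filter("c.age < 15"),
    plus b < c so that every unordered pair {b, c} is reported once.

    ages: dict vertex id -> age; edges: iterable of (src, dst) pairs.
    Edges touching vertices without an age are skipped; parallel edges are collapsed.
    """
    targets = {}
    for src, dst in edges:
        if src in ages and dst in ages:
            targets.setdefault(src, set()).add(dst)
    matches = []
    for a, dsts in targets.items():
        if ages[a] <= old:
            continue
        young_dsts = sorted(d for d in dsts if ages[d] < young)
        matches.extend((a, b, c) for b, c in combinations(young_dsts, 2))
    return matches

ages = {0: 51, 1: 62, 2: 12, 3: 55, 4: 29, 5: 32, 6: 14, 7: 60}
edges = [(1, 2), (2, 3), (3, 1), (3, 4), (4, 2), (2, 4), (9, 1), (8, 1), (7, 1),
         (6, 5), (9, 0), (0, 1), (3, 2), (6, 3), (5, 6), (0, 2), (0, 6), (5, 4),
         (4, 1), (2, 0), (1, 5)]
print(older_with_two_young(ages, edges))  # → [(0, 2, 6)]
```

The single match (0, 2, 6) is Henry with Bob and Fanny.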



### Label Propagation
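Real-world networks tend to have community structure. [Label propagation](https://en.wikipedia.org/wiki/Label_Propagation_Algorithm) is an algorithm for finding these communities: every vertex starts with its own label and repeatedly adopts the most frequent label among its neighbors, so densely connected groups converge to a shared label.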

In [None]:
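# Note: these calls failed with an out-of-memory (OOM) error in our run,
# so they are left commented out.

# PageRank adds a "pagerank" column to the vertices and a "weight" column to the edges
#results = g.pageRank(resetProbability=0.15, maxIter=1)
#results.vertices.select("id", "pagerank").show()
#results.edges.select("src", "dst", "weight").show()

# Label propagation returns the vertices with an extra "label" column (the community id)
#labels = g.labelPropagation(maxIter=1)
#labels.show()

# Top 20 vertices by PageRank (display() renders tables only in some environments, e.g. Databricks)
#results.vertices.orderBy(results.vertices.pagerank.desc()).limit(20).show()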
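For intuition, here is a minimal plain-Python sketch of synchronous label propagation (hypothetical helper `label_propagation`; not GraphFrames' implementation). Edge direction is ignored and ties are broken by taking the smallest label; both are choices of this sketch:

```python
from collections import Counter

def label_propagation(vertices, edges, max_iter=5):
    """Minimal synchronous label propagation.

    Every vertex starts with its own id as label. In each round every vertex adopts
    the most frequent label among its neighbours (edge direction ignored); ties are
    broken by the smallest label so the result is deterministic.
    """
    neighbours = {v: [] for v in vertices}
    for src, dst in edges:
        neighbours[src].append(dst)
        neighbours[dst].append(src)
    labels = {v: v for v in neighbours}
    for _ in range(max_iter):
        new_labels = {}
        for v, nbrs in neighbours.items():
            if not nbrs:
                new_labels[v] = labels[v]  # an isolated vertex keeps its label
                continue
            counts = Counter(labels[n] for n in nbrs)
            best = max(counts.values())
            new_labels[v] = min(lab for lab, c in counts.items() if c == best)
        if new_labels == labels:
            break  # converged
        labels = new_labels
    return labels

# Two disconnected triangles form two communities
labels = label_propagation(range(6), [(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3)])
print(labels)  # → {0: 0, 1: 0, 2: 0, 3: 3, 4: 3, 5: 3}
```

Synchronous updates can oscillate on some graphs (for example bipartite ones), which is one reason implementations cap the number of iterations with a parameter like `maxIter`.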

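### Triangle count

For each vertex, `triangleCount` returns the number of triangles passing through it, in a `count` column.

In [None]: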
triangles = g.triangleCount()
triangles.show()
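A plain-Python sketch of the per-vertex count (hypothetical helper `triangle_count`), treating edges as undirected, which is our understanding of how GraphFrames counts triangles:

```python
from itertools import combinations

def triangle_count(edges):
    """Count, for every vertex, the triangles it belongs to.

    Edge direction and duplicate edges are ignored, i.e. the graph is treated as
    undirected and simple. Vertices without any edges are not reported.
    """
    nbrs = {}
    for src, dst in edges:
        if src == dst:
            continue  # self-loops cannot be part of a triangle
        nbrs.setdefault(src, set()).add(dst)
        nbrs.setdefault(dst, set()).add(src)
    counts = {v: 0 for v in nbrs}
    for v, ns in nbrs.items():
        # v lies on one triangle for every pair of its neighbours that are adjacent
        for x, y in combinations(sorted(ns), 2):
            if y in nbrs[x]:
                counts[v] += 1
    return counts

g1_edges = [(1, 2), (2, 3), (3, 1), (3, 4), (4, 2), (2, 4),
            (9, 1), (8, 1), (7, 1), (6, 5), (9, 0), (0, 1)]
counts = triangle_count(g1_edges)
print(counts[1], counts[4], counts[5])  # → 2 1 0
```

`g1` contains three triangles, {1, 2, 3}, {2, 3, 4} and {0, 1, 9}, so the per-vertex counts sum to 9.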

## Visualization of a sub-graph

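Real transaction data can be large (our bitcoin data contain 2,087,249 transactions among 546,651 wallets), so it is better to show only a small fraction of the transaction graph, such as all outgoing transactions of a particular bitcoin address. The graph in this notebook is small enough to collect in full; the cell below converts its vertices and edges into the node and edge dictionaries used by sigma.js.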

In [None]:
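import random

def node_to_dict(r):
    """Convert a vertex Row into a sigma.js node with a random initial position."""
    return {
        'id': str(r['id']),
        'label': str(r['name']),
        'x': random.uniform(0, 1),
        'y': random.uniform(0, 1),
        'size': random.uniform(0.2, 1)
    }

# A list (not map()) so that json.dumps can serialize it in Python 3
nodes_dict = [node_to_dict(r) for r in g.vertices.collect()]
node_ids = {n['id'] for n in nodes_dict}

def edge_to_dict(i, r):
    """Convert an edge Row into a sigma.js edge."""
    return {
        'id': 'e' + str(i),
        'source': str(r['src']),
        'target': str(r['dst'])
    }

# sigma.js expects every edge endpoint to exist in the node list, so drop edges
# pointing to vertices removed by the inner join on 'id' (ids 8 and 9)
edges_dict = [e for e in (edge_to_dict(i, r) for i, r in enumerate(g.edges.collect()))
              if e['source'] in node_ids and e['target'] in node_ids]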
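The cell above collects the whole graph, which is fine for this toy example. To show only the outgoing transactions of one particular vertex, as described above, the edges can be filtered before they reach the driver, e.g. with `g.edges.filter("src = 0")` in Spark. A plain-Python sketch of the same selection (hypothetical helper `outgoing_subgraph`; a few edges copied from the merged graph), which also keeps just the vertices touched by those edges:

```python
def outgoing_subgraph(vertex_ids, edges, source):
    """Return (vertices, edges) of the star formed by `source` and its outgoing edges.

    vertex_ids: iterable of known vertex ids; edges: iterable of (src, dst) pairs.
    Only edges whose both endpoints are known vertices are kept.
    """
    known = set(vertex_ids)
    sub_edges = [(s, d) for s, d in edges if s == source and s in known and d in known]
    sub_vertices = sorted({source} | {d for _, d in sub_edges})
    return sub_vertices, sub_edges

edges = [(0, 1), (0, 2), (0, 6), (2, 0), (9, 0), (1, 2)]
verts, es = outgoing_subgraph(range(8), edges, source=0)
print(verts, es)  # → [0, 1, 2, 6] [(0, 1), (0, 2), (0, 6)]
```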

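Now we are ready to show the data using the [sigma.js](https://sigmajs.org) library. The first cell below loads sigma.js 1.2.0 from a CDN through RequireJS, which relies on the `require` function available in the classic Jupyter Notebook interface.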

In [22]:
%%javascript
require.config({
    paths: {
        sigmajs: 'https://cdnjs.cloudflare.com/ajax/libs/sigma.js/1.2.0/sigma.min'
    }
});

require(['sigmajs']);


In [31]:
from IPython.display import HTML
from string import Template
import json

# js/sigma-graph.js is a string.Template filled with $graph_data and $container below
with open('js/sigma-graph.js', 'r') as f:
    js_text_template = Template(f.read())

graph_data = { 'nodes': nodes_dict, 'edges': edges_dict }

js_text = js_text_template.substitute({'graph_data': json.dumps(graph_data),
                                       'container': 'graph-div'})

html_template = Template('''
<div id="graph-div" style="height:400px"></div>
<script> $js_text </script>
''')

HTML(html_template.substitute({'js_text': js_text}))
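The rendering relies on `string.Template`: `$`-placeholders in the JavaScript file are replaced by the JSON-encoded graph and the container id. A self-contained sketch of that substitution step, using a made-up two-line template in place of `js/sigma-graph.js` (whose contents are not shown here); only the placeholder names match the notebook:

```python
import json
from string import Template

# Hypothetical stand-in for js/sigma-graph.js
js_template = Template("var data = $graph_data;\n"
                       "new sigma({graph: data, container: '$container'});")

graph_data = {"nodes": [{"id": "0", "label": "Henry", "x": 0.1, "y": 0.2, "size": 1}],
              "edges": []}
js_text = js_template.substitute({"graph_data": json.dumps(graph_data),
                                  "container": "graph-div"})
print(js_text.splitlines()[0])
```

Because `Template` treats every `$` as the start of a placeholder, a literal `$` in the JavaScript file (for example jQuery's `$`) has to be written as `$$`.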