# Quick setup for PySpark and GraphFrames

In [None]:
# Install a headless OpenJDK 8 (stdout discarded) and download the Spark 3.0.3 build for Hadoop 2.7.
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://archive.apache.org/dist/spark/spark-3.0.3/spark-3.0.3-bin-hadoop2.7.tgz

In [None]:
# Unpack the Spark tarball into the current working directory
# (later cells expect the result at /content/spark-3.0.3-bin-hadoop2.7).
!tar xf /content/spark-3.0.3-bin-hadoop2.7.tgz

In [None]:
# Download the GraphFrames 0.8.2 jar built for Spark 3.0 / Scala 2.12.
!wget -q https://repos.spark-packages.org/graphframes/graphframes/0.8.2-spark3.0-s_2.12/graphframes-0.8.2-spark3.0-s_2.12.jar

In [None]:
## install graphframe library on Colab: move the downloaded jar into the Spark distribution's jars/ directory
## (presumably so Spark picks it up on its classpath at startup — confirm)
!mv /content/graphframes-0.8.2-spark3.0-s_2.12.jar /content/spark-3.0.3-bin-hadoop2.7/jars/graphframes-0.8.2-spark3.0-s_2.12.jar

In [None]:
# Install the Python-side packages quietly. NOTE(review): versions are not pinned,
# so the pip-installed pyspark may differ from the Spark 3.0.3 unpacked above — confirm which one is used.
!pip -q install findspark pyspark graphframes

In [None]:
import os
# Point JAVA_HOME / SPARK_HOME at the JDK and the Spark directory set up by the cells above.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.0.3-bin-hadoop2.7"

In [None]:
import findspark
# NOTE(review): presumably locates Spark via SPARK_HOME (set above) and makes its pyspark importable — confirm.
findspark.init()

In [None]:
from pyspark.sql import SparkSession

# Build (or reuse, via getOrCreate) a local-mode session named "Colab",
# with the Spark UI port configured as 4050.
spark = (
    SparkSession.builder
    .master("local")
    .appName("Colab")
    .config('spark.ui.port', '4050')
    .getOrCreate()
)

In [None]:
# confirm the spark installation: the bare expression is rendered as the cell's output
spark

In [None]:
# Extra environment for PySpark / GraphFrames, applied in one call.
# NOTE(review): these are set after a SparkSession was already created earlier in
# this notebook; PYSPARK_SUBMIT_ARGS presumably only matters when the JVM is first
# launched, so it may have no effect here — confirm.
os.environ.update({
    "HADOOP_HOME": os.environ["SPARK_HOME"],
    "PYSPARK_DRIVER_PYTHON": "jupyter",
    "PYSPARK_DRIVER_PYTHON_OPTS": "notebook",
    "PYSPARK_SUBMIT_ARGS": "--packages graphframes:graphframes:0.8.2-spark3.0-s_2.12 pyspark-shell",
})

In [None]:
from graphframes import *

In [None]:
def init_spark(app_name="HelloWorldApp", execution_mode="local[*]"):
  """Return a ``(SparkSession, SparkContext)`` pair.

  Args:
    app_name: value passed to the builder's ``appName``.
    execution_mode: value passed to the builder's ``master``.

  Returns:
    Tuple of the session obtained from ``getOrCreate()`` and that
    session's ``sparkContext``.
  """
  builder = SparkSession.builder.master(execution_mode).appName(app_name)
  session = builder.getOrCreate()
  return session, session.sparkContext

## Example: GraphFrame

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import SQLContext

# Only the SparkContext from init_spark() is used here; it backs the SQLContext.
_, sc = init_spark()
sqlContext = SQLContext(sc)

## the rest of this code (down below) comes from: https://graphframes.github.io/graphframes/docs/_site/quick-start.html#getting-started-with-apache-spark-and-spark-packages

# Vertex DataFrame: one row per vertex, keyed by the unique "id" column.
vertex_columns = ["id", "name", "age"]
vertex_rows = [
  ("a", "Alice", 34),
  ("b", "Bob", 36),
  ("c", "Charlie", 30),
]
v = sqlContext.createDataFrame(vertex_rows, vertex_columns)

# Edge DataFrame: "src" and "dst" hold vertex ids, plus a relationship label.
edge_columns = ["src", "dst", "relationship"]
edge_rows = [
  ("a", "b", "friend"),
  ("b", "c", "follow"),
  ("c", "b", "follow"),
]
e = sqlContext.createDataFrame(edge_rows, edge_columns)

# Assemble the GraphFrame from the two DataFrames.
from graphframes import *
g = GraphFrame(v, e)

# Query: in-degree of each vertex.
in_degrees = g.inDegrees
in_degrees.show()

# Query: number of "follow" connections; left as the last expression so the
# notebook displays the count.
follow_edges = g.edges.filter("relationship = 'follow'")
follow_edges.count()

# Importing the Data into Apache Spark

In [None]:
!mkdir data
!wget -q https://github.com/chunsejin/KGProjects_DAU/raw/master/data/social-nodes.csv
!wget -q https://github.com/chunsejin/KGProjects_DAU/raw/master/data/social-relationships.csv
!mv social-nodes.csv data
!mv social-relationships.csv data

In [None]:
from graphframes import *
from pyspark import SparkContext

In [None]:
# Load vertices and edges from the downloaded CSVs (first row treated as the header)
# and build the graph used by the centrality sections below.
# NOTE(review): no schema / inferSchema is given, so columns are presumably read as strings — confirm.
# NOTE(review): assumes the CSV headers provide "id" (nodes) and "src"/"dst" (relationships), as in the example above — verify.
v = spark.read.csv("data/social-nodes.csv", header=True)
e = spark.read.csv("data/social-relationships.csv", header=True)
g = GraphFrame(v, e)

# Degree Centrality with Apache Spark

In [None]:
# Per-vertex degree tables exposed by the GraphFrame.
total_degree = g.degrees
in_degree = g.inDegrees
out_degree = g.outDegrees

# Left joins keep every row of total_degree; vertices absent from the in/out
# tables get nulls there, which fillna(0) replaces with 0 before sorting.
degree_table = total_degree.join(in_degree, "id", how="left")
degree_table = degree_table.join(out_degree, "id", how="left")
degree_table.fillna(0).sort("inDegree", ascending=False).show()

# Closeness Centrality with Apache Spark

In [None]:
from graphframes.lib import AggregateMessages as AM
from pyspark.sql import functions as F
from pyspark.sql.types import *
# `itemgetter` is the sort-key helper used by the path functions below
# (the previous spelling `itemgette` does not exist in `operator`).
from operator import itemgetter

In [None]:
def collect_paths(paths):
    """Return Spark's ``collect_set`` aggregate expression over ``paths``."""
    return F.collect_set(paths)

# NOTE(review): not referenced anywhere else in this notebook; kept so the name stays available.
collect_paths_udf = F.udf(collect_paths, ArrayType(StringType()))

# Return schema shared by the path UDFs: an array of (id, distance) structs.
# The distance is an integer because new_paths builds it as `distance + 1` / `1`.
# (The original expression was cut off after "distance"; it is completed here.)
paths_type = ArrayType(
    StructType([StructField("id", StringType()),
                StructField("distance", IntegerType())]))

def flatten(ids):
    """Merge several path lists into one list with a single entry per id.

    When the same id arrives with different distances, the smallest distance
    is kept. Previously the entry that happened to come last won, which could
    keep a longer path; this now agrees with ``merge_paths``, which also keeps
    the shortest.

    Args:
        ids: iterable of path lists; each path unpacks as ``(id, distance)``.
            ``None`` or an empty iterable yields an empty list.

    Returns:
        List of ``(id, distance)`` tuples ordered by id.
    """
    shortest = {}
    for sublist in ids or []:
        for node_id, distance in sublist:
            known = shortest.get(node_id)
            if known is None or distance < known:
                shortest[node_id] = distance
    # Ids are unique at this point, so sorting by id alone fixes the order.
    return sorted(shortest.items(), key=lambda entry: entry[0])

# Spark UDF wrapper around flatten, declared to return paths_type.
flatten_udf = F.udf(flatten, paths_type)

def new_paths(paths, id):
    """Build the path list a vertex sends to a neighbour.

    Args:
        paths: iterable of ``(id, distance)`` pairs known to the sender.
        id: the sender's own id.

    Returns:
        List of ``{"id", "distance"}`` dicts: every known path except the one
        to ``id`` itself, each one hop longer, followed by the sender at
        distance 1.
    """
    extended = []
    for node_id, hops in paths:
        if node_id == id:
            continue
        extended.append({"id": node_id, "distance": hops + 1})
    extended.append({"id": id, "distance": 1})
    return extended

# Spark UDF wrapper around new_paths, declared to return paths_type.
new_paths_udf = F.udf(new_paths, paths_type)

def merge_paths(ids, new_ids, id):
    """Combine a vertex's known paths with newly received ones.

    Args:
        ids: list of ``(id, distance)`` pairs already known to the vertex.
        new_ids: list of newly received pairs, or a falsy value for none.
        id: the vertex's own id; paths to itself are dropped.

    Returns:
        List of ``{"id", "distance"}`` dicts with one entry per id, holding
        the smallest distance seen for that id.
    """
    candidates = ids + (new_ids or [])
    others = [(node_id, hops) for node_id, hops in candidates if node_id != id]
    # Longest first: when an id repeats, later (shorter) entries overwrite
    # earlier ones in the dict, so the shortest distance is what remains.
    others.sort(key=lambda pair: pair[1], reverse=True)
    shortest = {}
    for node_id, hops in others:
        shortest[node_id] = hops
    return [{"id": node_id, "distance": hops} for node_id, hops in shortest.items()]

# Spark UDF wrapper around merge_paths, declared to return paths_type.
merge_paths_udf = F.udf(merge_paths, paths_type)

def calculate_closeness(ids):
    """Closeness score: number of known paths divided by their total distance.

    Args:
        ids: iterable of ``(id, distance)`` pairs; ``None`` or empty means the
            vertex has no known paths.

    Returns:
        A float in every case. ``0.0`` when there are no paths or the total
        distance is zero.
    """
    # Always return a float: this function backs a UDF declared as DoubleType,
    # and a Python int returned from such a UDF is commonly turned into null
    # by Spark rather than being coerced.
    if not ids:
        return 0.0
    total_distance = sum(distance for _, distance in ids)
    if total_distance == 0:
        return 0.0
    return len(ids) / float(total_distance)

# Spark UDF wrapper around calculate_closeness, declared to return a double.
closeness_udf = F.udf(calculate_closeness, DoubleType())

In [None]:
# Give every vertex an initially empty "ids" array (its list of known paths).
vertices = g.vertices.withColumn("ids", F.array())
# Build the working graph g2 from the DataFrame returned by AM.getCachedDataFrame, reusing g's edges.
cached_vertices = AM.getCachedDataFrame(vertices)
g2 = GraphFrame(cached_vertices, g.edges)

In [None]:
# Message expressions: each endpoint of an edge sends new_paths(ids, id) of
# itself to the opposite endpoint. They only reference AM.src / AM.dst columns,
# so they are built once rather than on every round.
msg_dst = new_paths_udf(AM.src["ids"], AM.src["id"])
msg_src = new_paths_udf(AM.dst["ids"], AM.dst["id"])

# One round per vertex; the count is evaluated once, when range() is created.
for i in range(0, g2.vertices.count()):
    agg = g2.aggregateMessages(F.collect_set(AM.msg).alias("agg"),
                               sendToSrc=msg_src, sendToDst=msg_dst)
    res = agg.withColumn("newIds", flatten_udf("agg")).drop("agg")
    # Left outer join: vertices that received no message keep a null "newIds".
    joined = g2.vertices.join(res, on="id", how="left_outer")
    merged = joined.withColumn("mergedIds", merge_paths_udf("ids", "newIds", "id"))
    # Replace the old "ids" column with the merged result.
    new_vertices = merged.drop("ids", "newIds").withColumnRenamed("mergedIds", "ids")
    cached_new_vertices = AM.getCachedDataFrame(new_vertices)
    g2 = GraphFrame(cached_new_vertices, g2.edges)

In [None]:
# Score each vertex from its final "ids" list and print the vertices,
# highest closeness first, without truncating column contents.
(g2.vertices
 .withColumn("closeness", closeness_udf("ids"))
 .sort("closeness", ascending=False)
 .show(truncate=False))

# PageRank with Apache Spark

In [None]:
# PageRank with a fixed iteration budget (maxIter=20) and reset probability 0.15;
# vertices are shown with the highest "pagerank" first.
results = g.pageRank(resetProbability=0.15, maxIter=20)
results.vertices.sort("pagerank", ascending=False).show()

In [None]:
# Same query, but driven by a tolerance (tol=0.01) instead of maxIter
# (presumably runs until scores change by less than tol — confirm in the GraphFrames docs).
# Note: this rebinds `results` from the previous cell.
results = g.pageRank(resetProbability=0.15, tol=0.01)
results.vertices.sort("pagerank", ascending=False).show()

# Appendix: Timing and Profiling
%time: Time the execution of a single statement

%timeit: Time repeated execution of a single statement for more accuracy

%prun: Run code with the profiler

%lprun: Run code with the line-by-line profiler

%memit: Measure the memory use of a single statement

%mprun: Run code with the line-by-line memory profiler