In [1]:
from __future__ import print_function
import copy

Import Spark and start a local context.

In [2]:
from pyspark import SparkConf, SparkContext
# Shared Spark context for the notebook: application name "MyApp", master "local".
sc = SparkContext(
    conf=SparkConf().setAppName("MyApp").setMaster("local"),
)

Define the mapping functions used by the RDD transformations below.

In [3]:
def parse_edge(s):
    """Parse one tab-separated input line into a ``(user, follower)`` pair.

    Args:
        s: A string holding exactly two integer fields separated by a tab.

    Returns:
        A 2-tuple of ints, in the same order as the fields on the line.

    Raises:
        ValueError: If the line does not split into exactly two fields, or
            if a field is not a valid integer.
    """
    user_field, follower_field = s.split("\t")
    return int(user_field), int(follower_field)

def step(item):
    """Advance a path by one edge.

    Args:
        item: A joined record ``(node, (path, neighbour))`` where ``path`` is
            the list of nodes visited before ``node``.

    Returns:
        ``(neighbour, path + [node])``. A new list is built, so the incoming
        ``path`` is left unmodified.
    """
    current = item[0]
    path_so_far = item[1][0]
    neighbour = item[1][1]
    # Concatenation creates a fresh list, so the input path is never mutated.
    return (neighbour, path_so_far + [current])

def complete(item):
    """Choose the path to keep for a node after a full outer join.

    Args:
        item: A joined record ``(node, (old_path, new_path))``; either side
            may be ``None`` when the node was missing from that side.

    Returns:
        ``(node, old_path)`` when ``old_path`` is not ``None`` (an empty list
        counts as present), otherwise ``(node, new_path)``.
    """
    node = item[0]
    old_path, new_path = item[1][0], item[1][1]
    if old_path is None:
        return (node, new_path)
    return (node, old_path)



Global settings. The starter example code uses 400 partitions, which is too many here,
so we reduce the number to 4.

In [4]:
# Number of partitions passed to partitionBy() and to the joins in the cells below.
# The starter example used 400; 4 is used here instead.
n = 4

Loading data


In [5]:
# Every input line is parsed by parse_edge into a (user, follower) pair of ints.
# Alternative input (presumably the larger sample; confirm before switching):
#   "/data/twitter/twitter_sample.txt"
edges = (
    sc.textFile("/data/twitter/twitter_sample_small.txt")
    .map(parse_edge)
    .cache()
)
# Swap each pair so the RDD is keyed by follower with the user as value,
# then hash-partition it into n partitions and keep it persisted for the joins.
forward_edges = (
    edges.map(lambda e: (e[1], e[0]))
    .partitionBy(n)
    .persist()
)

Define the start and end nodes. We track the full path to each node, not only its distance.

In [6]:
x = 12 # start node
t = 34 # end node
d = 0 # length of the paths currently stored in `distances`
# RDD of (node, path) pairs, where path is the list of nodes visited before
# reaching node. It starts with only the start node and an empty path.
distances = sc.parallelize([(x, [])]).partitionBy(n)
shortest_path = [] # set by the search loop once the end node is reached

Expand the search one hop at a time from the start node until the end node is reached.

In [7]:
# Level-by-level search: each iteration extends every path in `distances` by
# one edge, stops as soon as the end node `t` shows up, and otherwise moves on
# to the newly produced paths.
while True:
    # Keep the raw join aside so the edges of the nodes just expanded can be
    # removed from forward_edges at the end of the iteration.
    candidates_join = distances.join(forward_edges, n).persist()
    # Several expanded nodes can lead to the same neighbour. All those paths
    # have the same length, so keep a single one per node; otherwise the
    # duplicates are carried forward and multiply at every iteration.
    candidates = candidates_join.map(step).reduceByKey(lambda a, b: a, n)
    new_distances = distances.fullOuterJoin(candidates, n).map(complete, True).persist()
    # Did we hit the target? take(1) runs one job and fetches at most one
    # record, instead of a count() followed by a collect() of every match.
    targets = new_distances.filter(lambda i: i[0] == t)
    collected_targets = targets.take(1)
    if collected_targets:
        # Any match is a shortest path: all paths found at this step have equal length.
        shortest_path = collected_targets[0][1]
        break
    # Keep only the entries whose path has d + 1 nodes, i.e. the ones produced
    # by this expansion. If there are none, nothing new can be reached and the
    # loop stops; this guards against looping forever when the target is
    # unreachable.
    new_distances = new_distances.filter(lambda i: len(i[1]) == d + 1).persist()
    count = new_distances.count()
    if count > 0:
        d += 1
        distances = new_distances
        # Drop the outgoing edges of the nodes that were just expanded, so
        # they are never expanded a second time.
        forward_edges = forward_edges.subtractByKey(candidates_join).persist()
    else:
        break


In [8]:
# Output the result: the nodes of the path followed by the end node,
# comma-separated on a single line.
print(','.join(str(node) for node in list(shortest_path) + [t]))

12,422,53,52,107,20,23,274,34
