In [1]:
from pyspark import SparkContext
from pyspark import AccumulatorParam
sc = SparkContext("local", "P4")

In [2]:
source = sc.textFile("source.csv")
source.take(1)

[u'"FROST, CARMILLA","AA2 35"']

In [3]:
def comic_KV(x):
    K_V = x.split('","')
    return (K_V[1][0:-1], [K_V[0][1:]])
def node_KV(x):
    K_V = x.split('","')
    return (K_V[1][0:-1], K_V[0][1:])
def get_neighbors(val):
    c_n = val[1]
    neighbors = c_n[1][:]
    neighbors.remove(c_n[0])
    return (c_n[0], neighbors)
def group_neighbors(x, y):
    return set(x) | set(y)

nodes = source.map(node_KV)
comics = source.map(comic_KV)
comics = comics.reduceByKey(lambda x,y: x + y)

In [4]:
nodes_neighbors = nodes.join(comics).map(get_neighbors).reduceByKey(group_neighbors)
sorted_neighbors = nodes_neighbors.sortByKey()

In [9]:
class DistanceAccumulatorParam(AccumulatorParam):
    def zero(self, initialValue):
        return initialValue

    def addInPlace(self, v1, v2):
        for key in v2.keys():
            if key not in v1:
                v1[key] = v2[key]
        return v1
    
def ss_bfs_accum(rdd, root, diameter = -1):
    distance_hash = rdd.context.accumulator({root: 0}, DistanceAccumulatorParam())
    next_hop = rdd.lookup(root)[0]
    hops = 1
    while (hops <= diameter or diameter < 0) and len(next_hop) > 0:
        next_rdd = rdd.filter(lambda x: x[0] in next_hop)
        next_rdd.foreach(lambda x: distance_hash.add({x[0] : hops}))
        next_hop = set(next_rdd.flatMap(lambda x: x[1]).collect()) - set(distance_hash.value.keys())
        hops += 1
    return distance_hash.value

def ss_bfs(rdd, root, diameter = -1):
    next_hop = rdd.lookup(root)[0]
    distances = {}
    hops = 1
    while (hops <= diameter or diameter < 0) and len(next_hop) > 0:
        for node in next_hop:
            if node not in distances:
                distances[node] = hops
        next_hop = set(rdd.filter(lambda x: x[0] in next_hop).flatMap(lambda x: x[1]).collect()) - set(distances.keys())
        hops += 1
    return distances

In [6]:
%%timeit
roots = [u'CAPTAIN AMERICA', u'MISS THING/MARY', u'ORWELL']
for r in roots:
    print len(ss_bfs_accum(sorted_neighbors, r))

6408
7
9
6408
7
9
6408
7
9
6408
7
9
1 loops, best of 3: 2.81 s per loop


In [10]:
%%timeit
roots = [u'CAPTAIN AMERICA', u'MISS THING/MARY', u'ORWELL']
for r in roots:
    print len(ss_bfs(sorted_neighbors, r))

6408
7
9
6408
7
9
6408
7
9
6408
7
9
1 loops, best of 3: 1.94 s per loop
