# In [160]:
# References:
# http://hadooptutorial.wikispaces.com/Iterative+MapReduce+and+Counters
# http://www.slideshare.net/jhammerb/lec5-pagerank
# The CSV and Spark manual pages
import csv
import time

# Sentinel distance given to every node the search has not reached yet.
INF = 999999999

# Character name -> integer node index, and the inverse (index -> name).
charLookup, revLookup = {}, []
heroCount = 0  # index handed to the next previously unseen character

# issues:  issue id   -> set of character names appearing in that issue
# adjList: node index -> set of neighbouring node indices
issues, adjList = {}, {}

# Spark accumulator; nTouched() resets it each round and uses it to count
# the nodes that are still GRAY.
updates = sc.accumulator(0)

# Read the (character, issue) rows of the CSV file.  Every previously unseen
# character name gets the next free integer index; every issue collects the
# set of character names that appear in it.
with open('source.csv','rb') as f:
    for row in csv.reader(f):
        if not row:
            # csv.reader yields an empty list for a blank line; indexing it
            # would raise IndexError, so skip it.
            continue
        name, issue = row[0].strip(), row[1].strip()
        if name not in charLookup:
            charLookup[name] = heroCount
            revLookup.append(name)
            heroCount += 1
        issues.setdefault(issue, set()).add(name)

# Build the undirected co-appearance graph: all characters sharing an issue
# are linked to each other, with every edge stored in both directions.
for names in issues.values():
    members = set(charLookup[n] for n in names)
    for node in members:
        # setdefault also registers a character that is alone in an issue,
        # leaving it with an empty neighbour set.
        adjList.setdefault(node, set()).update(members - {node})

# Map: emit neighbour nodes if GRAY, otherwise
# just emit the same thing if WHITE or BLACK
def processNode(kv):
    """Map step of one BFS round for a single graph record.

    Args:
        kv: a ``(node, (dist, adj, color))`` pair, where ``color`` is
            0 (WHITE), 1 (GRAY) or 2 (BLACK).

    Returns:
        A list of ``(node, (dist, adj, color))`` records.  A GRAY node yields
        one GRAY record per neighbour (distance ``dist + 1`` and an empty
        adjacency list) followed by its own record recoloured BLACK.  Any
        other node is returned unchanged as a one-element list.
    """
    # Unpack inside the body: tuple parameters (``def f((k, v))``) are a
    # SyntaxError on Python 3, while this form runs on both 2 and 3.
    k, v = kv
    curDist, adj, col = v
    if col != 1:
        return [(k, v)]

    # Frontier node: push the distance to every neighbour, then mark this
    # node as done while keeping its adjacency list.
    ret = [(nb, (curDist + 1, [], 1)) for nb in adj]
    ret.append((k, (curDist, adj, 2)))
    return ret

# Combine nodes by key
def combineNode(a, b):
    """Merge two ``(dist, adj, color)`` records that belong to the same node.

    The result keeps the smaller distance, the higher colour and the longer
    of the two adjacency lists (``b``'s list when both have equal length).
    """
    if len(a[1]) > len(b[1]):
        adjacency = a[1]
    else:
        adjacency = b[1]
    return (min(a[0], b[0]), adjacency, max(a[2], b[2]))

# Format of the RDD:
# node_index | current_dist | adj_list | color
# color = {WHITE=0, GRAY=1, BLACK=2}
def nTouched(adjList, startN=0):
    """Count the nodes reachable from ``startN`` with an iterative Spark BFS.

    Each round maps every record through ``processNode`` and merges the
    records of each node with ``combineNode``.  Rounds repeat until a round
    leaves no GRAY node behind.

    Args:
        adjList: dict mapping a node index to an iterable of neighbour
            indices.
        startN: index of the source node (defaults to 0).

    Returns:
        The number of nodes whose distance is no longer ``INF``, the source
        node included.  0 when ``adjList`` is empty or ``startN`` is not one
        of its keys.
    """
    rddList = []
    for node, adj in adjList.items():
        isStart = node == startN
        # The source starts GRAY at distance 0; everything else is WHITE at
        # distance INF.
        rddList.append((node, (0 if isStart else INF, list(adj), int(isStart))))

    graphRDD = sc.parallelize(rddList, 2)
    while True:  # While gray nodes still exist
        updates.value = 0
        graphRDD = graphRDD.flatMap(processNode).reduceByKey(combineNode)

        # foreach is an action, so it forces this round to run and leaves
        # the number of GRAY nodes in the accumulator.
        graphRDD.foreach(lambda kv: updates.add(int(kv[1][2] == 1)))
        if updates.value == 0:
            break

    # count() returns 0 on an empty RDD, whereas reduce() raises ValueError.
    return graphRDD.filter(lambda kv: kv[1][0] != INF).count()


# Run the BFS from each sample source character, then report the total
# wall-clock time.  Each print() call takes one pre-formatted string so the
# output is the same whether print is a statement (Python 2) or a function
# (Python 3).
sT = time.time()
sList = ['CAPTAIN AMERICA', 'MISS THING/MARY', 'ORWELL']
for char in sList:
    print('Source = %s' % char)
    print('Touched %s nodes' % nTouched(adjList, charLookup[char]))
print('%s seconds' % (time.time() - sT))

# # print res
# with open('output.txt','w') as f:
#     for line in res:
#         f.write(str(line))



# Output of the cell above:
# Source = CAPTAIN AMERICA
# Touched 6403 nodes
# Source = MISS THING/MARY
# Touched 7 nodes
# Source = ORWELL
# Touched 9 nodes
# 1.39092302322 seconds
