In [1]:
import os
import sys
from pyspark import SparkContext, SparkConf
import json
import itertools
import math
import numpy as np
import time

In [2]:
appName = 'assignment3'
master = 'local[*]'
conf = SparkConf().setAppName(appName).setMaster(master)
# conf = SparkConf().setAll([('spark.executor.memory', '8g'), ('spark.executor.cores', '3'), ('spark.cores.max', '3'), ('spark.driver.memory','8g')])
sc = SparkContext(conf=conf)
sc.setLogLevel("INFO")

In [3]:
sc

In [4]:
def read_csv_line(line):
    line = line.split(',')
    return (line[0].strip(), line[1].strip(), line[2].strip())

def read_business(data_path):
    rdd = sc.textFile(data_path)
    header = rdd.first()
    rdd = rdd.filter(lambda x: x != header).map(read_csv_line)
    unique_users = rdd.map(lambda x: x[0]).distinct().collect()    
    userid_index = {}
    for i, uid in enumerate(unique_users):
        userid_index[uid] = i
    
    business_users = rdd.map(lambda x: (x[1], userid_index[x[0]])).groupByKey().map(lambda x: (x[0], list(x[1])))
    
    return business_users

In [36]:
out = read_business('./data/yelp_train.csv').take(5)

In [37]:
out[:3]

[('3MntE_HWbNNoyiLGxywjYA', [5628, 791, 722, 7074, 2988]),
 ('YXohNvMTCmGhFMSQsDZq1g',
  [2,
   25,
   6283,
   1514,
   7139,
   2350,
   8823,
   3339,
   5723,
   282,
   9706,
   7950,
   10051,
   3200,
   975,
   2834,
   2841,
   1069,
   6774,
   5104,
   490,
   1391,
   4588,
   2435,
   7577,
   2187,
   1075,
   508,
   4442,
   9550,
   9,
   3472,
   587,
   3495,
   3153,
   5645,
   8596,
   7783,
   3848,
   9375,
   51,
   7047,
   1843,
   3734,
   601,
   153,
   1279,
   303,
   265,
   910,
   5647,
   1794,
   11212,
   8899,
   3820,
   424,
   10752,
   6276,
   643,
   1232,
   437,
   3127,
   1587,
   10390,
   6224,
   756,
   7092,
   3888,
   1669,
   6425,
   9438,
   7342,
   4747,
   1660,
   3390,
   7061,
   6072,
   8736,
   1431,
   1735,
   6523,
   2536,
   5740,
   2332,
   6615,
   1390,
   1050,
   6477,
   7153,
   1137,
   6501,
   308,
   10472,
   10882,
   4,
   8827,
   2824,
   3666,
   10730,
   8564,
   8029,
   2936,
   5948,
   3171

In [14]:
# Total number of users 11270
p = 13591
m = 64
number_of_hashes = 64
a = np.random.randint(1, high=p, size = number_of_hashes)
b = np.random.randint(0, high=p, size = number_of_hashes)

def minhash(vector, a, b, p, m):
    vector = np.array(vector)
    return list(np.min((a.reshape(1, a.shape[0]) * vector.reshape(vector.shape[0], 1) + b) %p %m, axis=0))

In [21]:
rdd = read_business('./data/yelp_train.csv')
st = time.time()
business_sig = rdd.map(lambda x: (x[0], minhash(x[1], a, b, p, m)))
print(time.time() - st)

0.00021839141845703125


In [34]:
from itertools import combinations

def hash_bands(x):
    doc_id, sig = x
    b = 16
    r = 4
    output = []
    for i in range(0, b):
        output.append(((i, hash(frozenset(sig[r*i:(i+1)*r]))), doc_id))
    return output


def jaccard_sim(x, y):
    x = set(x)
    y = set(y)
    return len(x.intersection(y)) / len(x.union(y))


def index_signatures(signatures):
    index = {}
    for x, y in signatures:
        index[x] = y
        
    return index

def find_similarity(candidate, index):
    bucket = candidate[1]
    output = []
    for candidate1, candidate2 in combinations(bucket, 2):
        sim = jaccard_sim(index[candidate1], index[candidate2])
        if(sim >= 0.5):
            output.append((candidate1, candidate2, sim))
            
    return output


rdd = read_business('./data/yelp_train.csv')
signatures = rdd.collect()
index = index_signatures(signatures)
business_sig = rdd.map(lambda x: (x[0], minhash(x[1], a, b, p, m)))
candidates = business_sig.flatMap(hash_bands).groupByKey().map(lambda x: (x[0], list(x[1]))).filter(lambda x: len(x[1]) > 1)
similar_pairs = candidates.flatMap(lambda x: find_similarity(x, index)).collect()

In [35]:
similar_pairs[:2]

[('H8mq-5oLkF9jlfyYi3vvOw', '4XGjbI2Ggi-kdgt9eZR83w', 0.6),
 ('6nMYROXu0VX4Ytpdsfi3XA', '-i3dOjumvOw-52aGXU1xDg', 0.5)]

In [16]:
hash(frozenset([1,2,3]))

-272375401224217160

In [18]:
hash(frozenset([4,2,3]))

7752673356882005635

In [13]:
# b*r = n
# (1/b)^(1/r) = threshold


24732