# In [None]:
import hashlib
import itertools
import re

import pyspark

# Spark context for this job: master 'local', application name 'lsh'.
conf = (
    pyspark.SparkConf()
    .setAppName('lsh')
    .setMaster('local')
)
sc = pyspark.SparkContext(conf=conf)


def readData(x):
    """Turn one ``(path, content)`` record from ``wholeTextFiles`` into shingles.

    Each line is tokenised on whitespace. Every token has all leading and
    trailing non-alphanumeric characters stripped, and tokens that become
    empty are dropped. Each run of three consecutive tokens on a line is
    joined with single spaces and hashed to a 32-bit integer (first 4 bytes
    of its SHA-256 digest, little-endian).

    Args:
        x: ``(path, content)`` pair. The document id is parsed from the three
            characters just before the four-character extension, e.g.
            ``".../007.txt"`` -> ``7``.

    Returns:
        ``(file_id, shingles)``, where ``shingles`` lists the distinct
        shingle hashes in first-seen order.
    """
    path, content = x
    file_id = int(path[-7:-4])
    # Strip *all* edge punctuation (the old per-side regex removed only one
    # character, so '"word,"' kept its comma).
    edge_punct = re.compile(r'^[^a-zA-Z0-9]+|[^a-zA-Z0-9]+$')
    seen = {}  # dict as an insertion-ordered set
    for line in content.splitlines():
        # split() with no argument collapses whitespace runs, so repeated
        # spaces no longer yield empty tokens.
        words = [edge_punct.sub('', w) for w in line.split()]
        words = [w for w in words if w]
        for i in range(len(words) - 2):
            # Join with a separator: plain concatenation made "ab c d" and
            # "a bc d" produce the same shingle.
            gram = ' '.join(words[i:i + 3])
            digest = hashlib.sha256(gram.encode('utf8')).digest()
            seen[int.from_bytes(digest[:4], 'little')] = None
    return (file_id, list(seen))

def min_hash(x, num_rows=None, num_hashes=100, prime=2147483647):
    """Emit the MinHash candidates contributed by one characteristic-matrix row.

    ``x`` is ``(row_index, doc_ids)``: the documents containing one shingle.
    For each hash function ``h_i(r) = (i + 1) * (r + 1) % prime % num_rows``
    the row's value is emitted once per document, keyed by ``(i, doc_id)``.
    Reducing the output with ``min`` per key (as the driver does) yields the
    signature entry ``sig[i][doc]``.

    Args:
        x: ``(row_index, doc_ids)`` pair.
        num_rows: number of rows (distinct shingles). Defaults to the
            module-level ``keys_conut`` computed by the driver.
        num_hashes: number of hash functions, i.e. signature length.
        prime: modulus of the hash family; should exceed ``num_rows`` so
            distinct rows are not forced to collide.

    Returns:
        List of ``((hash_index, doc_id), hash_value)`` with
        ``0 <= hash_value < num_rows``.
    """
    if num_rows is None:
        num_rows = keys_conut  # global set by the driver before this runs
    row, docs = x
    # Both factors are kept non-zero: with ``i * row`` the first function was
    # constantly 0 and row 0 hashed to 0 under every function.
    hashes = [(i + 1) * (row + 1) % prime % num_rows for i in range(num_hashes)]
    # Key by (hash index, doc) -- NOT (row, doc, idx) -- so reduceByKey(min)
    # takes the minimum across rows, and hash_bucket/banding see the
    # signature index in key[0] and the document in key[1].
    return [((idx, doc), val) for doc in docs for idx, val in enumerate(hashes)]

def gen_candidate_pair(x):
    """Return every unordered pair of distinct documents sharing an LSH bucket.

    Args:
        x: ``(bucket_key, doc_ids)``. Only ``doc_ids`` is used, and it may
            contain repeated ids.

    Returns:
        Sorted list of ``(smaller_id, larger_id)`` tuples without duplicates.
    """
    docs = sorted(set(x[1]))
    # combinations() over the de-duplicated, sorted ids yields each pair once,
    # smaller id first. The old version built both orientations and then
    # de-duplicated, returning them in arbitrary set order.
    return list(itertools.combinations(docs, 2))

def cal_sim(x):
    """Return the Jaccard similarity of a candidate pair's shingle collections.

    Args:
        x: ``(pair, (shingles_a, shingles_b))``.

    Returns:
        ``(pair, |A & B| / |A | B|)`` where ``A`` and ``B`` are the shingle
        sets. Returns 0.0 when both are empty, where the ratio is undefined
        (previously a ZeroDivisionError).
    """
    pair, (left, right) = x
    a, b = set(left), set(right)
    union = len(a | b)
    if union == 0:
        return (pair, 0.0)
    return (pair, len(a & b) / union)


def hash_bucket(x, num_rows=None, rows_per_band=2):
    """Map one signature entry to its LSH band and scaled band contribution.

    Signature row ``idx`` belongs to band ``idx // rows_per_band``. Its value
    is multiplied by ``num_rows ** (idx % rows_per_band)``, so summing a
    band's contributions per document (the driver's ``reduceByKey``) encodes
    the band's rows as digits of a base-``num_rows`` number. The encoding is
    collision-free as long as every value is below ``num_rows``.

    Args:
        x: ``((idx, doc_id), value)``. Only the first two key fields are read.
        num_rows: base of the encoding. Defaults to the module-level
            ``keys_conut`` computed by the driver.
        rows_per_band: signature rows per band (the LSH ``r``). The default
            of 2 reproduces the original two-row bands.

    Returns:
        ``((band, doc_id), scaled_value)``.

    Raises:
        ValueError: if ``rows_per_band`` is less than 1.
    """
    if rows_per_band < 1:
        raise ValueError('rows_per_band must be >= 1')
    if num_rows is None:
        num_rows = keys_conut  # global set by the driver before this runs
    idx, doc = x[0][0], x[0][1]
    band, offset = divmod(idx, rows_per_band)
    return ((band, doc), x[1] * num_rows ** offset)

try:
    # (file_id, shingle hashes). Cached because it feeds the characteristic
    # matrix below and both joins that fetch shingles for candidate pairs.
    original_data = sc.wholeTextFiles('./athletics/*.txt').map(readData).cache()

    # Characteristic matrix: one row per distinct shingle with the ids of the
    # documents containing it, numbered in sorted-shingle order.
    # groupByKey avoids the repeated list copying of reduceByKey(x + y).
    hash_matrix = (
        original_data.flatMapValues(lambda x: x)
        .map(lambda x: (x[1], x[0]))
        .groupByKey()
        .mapValues(list)
        .sortBy(lambda x: x[0], ascending=True)
    )
    # Cached: used both for the row count and for the MinHash pass.
    hash_matrix = hash_matrix.zipWithIndex().map(lambda x: (x[1], x[0][1])).cache()

    # Number of rows; min_hash and hash_bucket read this module-level name.
    keys_conut = hash_matrix.keys().count()
    signature_matrix = hash_matrix.flatMap(min_hash).reduceByKey(lambda x, y: x if x < y else y)

    # Band values per (band, doc), then group documents by (band, band value).
    lsh_matrix = signature_matrix.map(hash_bucket).reduceByKey(lambda x, y: x + y)
    lsh_matrix = (
        lsh_matrix.map(lambda x: ((x[0][0], x[1]), x[0][1]))
        .groupByKey()
        .mapValues(list)
    )
    candidate_pairs = lsh_matrix.flatMap(gen_candidate_pair).distinct()

    # (i, j) -> ((i, j), (shingles_i, shingles_j)), then exact Jaccard.
    sim_pairs = candidate_pairs.join(original_data).map(lambda x: (x[1][0], (x[0], x[1][1]))).join(
        original_data).map(lambda x: ((x[1][0][0], x[0]), (x[1][0][1], x[1][1])))
    sim_pairs = sim_pairs.map(cal_sim)

    # Top 10 by similarity without a full sort; ties broken by pair so the
    # printed order is deterministic.
    ans = sim_pairs.takeOrdered(10, key=lambda x: (-x[1], x[0]))
    for i in ans:
        print('(%03d, %03d): %.2f %%' % (i[0][0], i[0][1], i[1] * 100))
finally:
    # Release the Spark context even if a stage above fails.
    sc.stop()
