In [None]:
# Environment bootstrap: install PySpark, PyDrive and a headless Java 8 JDK
# through shell commands.
!pip install pyspark
!pip install -U -q PyDrive
!apt install openjdk-8-jdk-headless -qq
import os
# Point JAVA_HOME at the JDK installed above (presumably so PySpark can find
# the JVM -- confirm the path matches the runtime image).
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"

In [None]:
import collections
import copy
import math
import sys
import time
from functools import reduce
from itertools import combinations
from operator import add

from pyspark import SparkContext

# TODO: changing the number of buckets will give a different result
BUCKET_NUMBER = 99

In [None]:
def hash_func(combination):
    """Map an itemset to a bucket index in ``range(BUCKET_NUMBER)``.

    Each item is converted with ``int`` and the values are summed, so the
    bucket does not depend on the order of the items.

    :param combination: iterable of items that ``int()`` can convert
    :return: the sum of the converted items modulo ``BUCKET_NUMBER``
    """
    total = 0
    for item in combination:
        total += int(item)
    return total % BUCKET_NUMBER


def check_bitmap(combination, bitmap):
    """Return the ``bitmap`` entry of the bucket that ``combination`` hashes to.

    :param combination: itemset passed to ``hash_func``
    :param bitmap: sequence indexed by bucket number
    :return: ``bitmap[hash_func(combination)]``
    """
    bucket = hash_func(combination)
    return bitmap[bucket]


def wrapper(singleton_list):
    """Convert each comma-separated string into a tuple of its fields.

    A string without a comma becomes a one-element tuple.

    :param singleton_list: iterable of strings
    :return: list of tuples, one per input string, in input order
    """
    wrapped = []
    for entry in singleton_list:
        wrapped.append(tuple(entry.split(",")))
    return wrapped


def shrink_basket(basket, frequent_singleton):
    """Keep only the basket items that are frequent singletons.

    :param basket: iterable of hashable items
    :param frequent_singleton: iterable of the items to keep
    :return: sorted list of the distinct items present in both arguments
    """
    keep = set(frequent_singleton)
    return sorted(item for item in set(basket) if item in keep)


def cmp(pair1, pair2):
    """Report whether two itemsets agree on everything but their last element.

    :param pair1: sliceable itemset (e.g. a tuple)
    :param pair2: sliceable itemset (e.g. a tuple)
    :return: True when ``pair1[:-1]`` equals ``pair2[:-1]``, otherwise False
    """
    prefix1 = pair1[:-1]
    prefix2 = pair2[:-1]
    return prefix1 == prefix2


In [None]:
def gen_permutation(combination_list):
    """Build the next-size candidate itemsets from a list of size-k itemsets.

    Two itemsets are joined when ``cmp`` reports that they share everything
    but their last element. The joined (k+1)-itemset is kept only if every
    one of its size-k subsets is itself present in ``combination_list``.

    :param combination_list: list of equally sized, hashable itemsets
        (tuples). The early ``break`` below relies on the list being sorted
        so that itemsets sharing a prefix are adjacent; the visible caller
        passes ``sorted(...)``.
    :return: list of sorted tuples of size k+1, or ``None`` when
        ``combination_list`` is ``None`` or empty
    """
    if combination_list is None or len(combination_list) == 0:
        # Kept as None (not []) so existing callers see the same value.
        return None

    size = len(combination_list[0])
    # Built once up front; it does not change while the pairs are joined.
    known_itemsets = set(combination_list)
    permutation_list = []
    for index, front_pair in enumerate(combination_list[:-1]):
        for back_pair in combination_list[index + 1:]:
            if not cmp(front_pair, back_pair):
                # Later itemsets are not expected to share this prefix
                # either (sorted input), so stop scanning for front_pair.
                break
            combination = tuple(sorted(set(front_pair).union(back_pair)))
            # Keep the join only if all of its size-k subsets are known.
            if all(subset in known_itemsets
                   for subset in combinations(combination, size)):
                permutation_list.append(combination)

    return permutation_list

In [None]:
def find_candidate_itemset(data_baskets, original_support, whole_length):
    """Find the itemsets that are frequent within one subset of the baskets.

    The support threshold is first rescaled for this subset by
    ``gen_ps_threshold``. A first pass (``init_singleton_and_bitmap``) yields
    the frequent singletons and the bucket bitmap; later passes count pairs
    (filtered through the bitmap) and then larger candidates produced by
    ``gen_permutation`` until a pass finds nothing frequent.

    :param data_baskets: iterable of baskets (each an iterable of items)
    :param original_support: support threshold for the whole data set
    :param whole_length: number of baskets in the whole data set
    :return: generator yielding a single flat list: the singletons as
        1-tuples (via ``wrapper``), followed by each larger size's frequent
        itemsets in first-seen order
    """
    # Compute the support threshold that applies to this subset of baskets.
    support, data_baskets = gen_ps_threshold(data_baskets, original_support, whole_length)
    baskets_list = list(data_baskets)

    # First phase of the PCY algorithm: frequent singletons and the bitmap.
    frequent_singleton, bitmap = init_singleton_and_bitmap(baskets_list, support)
    all_candidates = list(wrapper(frequent_singleton))

    # Items that are not frequent singletons can never be part of a frequent
    # itemset, so drop them. frequent_singleton never changes, so this is
    # done once here instead of once per basket on every pass.
    shrunk_baskets = [shrink_basket(basket, frequent_singleton) for basket in baskets_list]
    shrunk_sets = [set(basket) for basket in shrunk_baskets]

    index = 1
    candidate_list = frequent_singleton
    # Second phase, third phase, ... until the candidate list is empty.
    while candidate_list is not None and len(candidate_list) > 0:
        index += 1
        counter = collections.Counter()
        for basket, basket_items in zip(shrunk_baskets, shrunk_sets):
            if len(basket) < index:
                continue
            if index == 2:
                # Every item left in the basket is a frequent singleton, so
                # only the bitmap check is needed for pairs.
                for pair in combinations(basket, index):
                    if check_bitmap(pair, bitmap):
                        counter[pair] += 1
            else:
                for candidate_item in candidate_list:
                    if basket_items.issuperset(candidate_item):
                        counter[candidate_item] += 1

        # Keep the itemsets that reach the threshold, in first-seen order.
        frequent = [itemset for itemset, count in counter.items() if count >= support]
        if not frequent:
            break
        all_candidates.extend(frequent)
        # Generate the candidate list for the next size.
        candidate_list = gen_permutation(sorted(frequent))

    yield all_candidates


In [None]:
def init_singleton_and_bitmap(baskets, support):
    """Count single items and hashed pairs in one pass over the baskets.

    :param baskets: iterable of baskets (each an iterable of items)
    :param support: minimum count for an item or a bucket to be frequent
    :return: ``(frequent_singleton, bitmap)`` where ``frequent_singleton`` is
        the sorted list of items seen at least ``support`` times and
        ``bitmap`` is a list of ``BUCKET_NUMBER`` booleans, True for each
        bucket that at least ``support`` pairs hashed into
    """
    item_counts = collections.Counter()
    bucket_counts = [0] * BUCKET_NUMBER
    for basket in baskets:
        # Tally every item occurrence.
        for item in basket:
            item_counts[item] += 1
        # Tally the bucket of every pair in the basket.
        for pair in combinations(basket, 2):
            bucket_counts[hash_func(pair)] += 1

    frequent_singleton = sorted(
        item for item, count in item_counts.items() if count >= support
    )
    bitmap = [count >= support for count in bucket_counts]
    return frequent_singleton, bitmap


In [None]:
def count_frequent_itemset(data_baskets, candidate_pairs):
    """Count which candidate itemsets are contained in a single basket.

    Despite the plural name, ``data_baskets`` is treated as one basket: it
    is turned into a set of items and each candidate is tested against it.

    :param data_baskets: iterable of the hashable items of one basket
    :param candidate_pairs: iterable of hashable candidate itemsets (tuples)
    :return: generator yielding a single list of ``(itemset, count)`` tuples
        for the candidates whose items all occur in the basket; ``count`` is
        the number of times that itemset appears in ``candidate_pairs``
    """
    # Build the basket's item set once rather than once per candidate.
    basket_items = set(data_baskets)
    counts = collections.Counter()
    for itemset in candidate_pairs:
        if basket_items.issuperset(itemset):
            counts[itemset] += 1

    yield list(counts.items())

In [None]:
def gen_ps_threshold(partition, support, whole_length):
    """Scale the support threshold to the size of one partition.

    :param partition: iterable of baskets; it is consumed and deep-copied
    :param support: support threshold for the whole data set
    :param whole_length: number of baskets in the whole data set
    :return: ``(threshold, baskets)`` where ``threshold`` is
        ``ceil(support * len(baskets) / whole_length)`` and ``baskets`` is
        the deep-copied list of the partition's baskets
    """
    baskets = copy.deepcopy(list(partition))
    local_support = math.ceil(support * len(baskets) / whole_length)
    return local_support, baskets

In [None]:
def reformat(itemset_data):
    """Render itemsets as text, one comma-separated group per itemset size.

    One-element itemsets are written as ``(item)`` (without the tuple's
    trailing comma); larger itemsets use their normal ``str()`` form.
    Whenever a non-singleton itemset has a different size from the current
    group, a blank line is inserted before it.

    :param itemset_data: iterable of itemsets (tuples), expected to be
        ordered by size
    :return: the formatted string, without a trailing comma
    """
    pieces = []
    current_size = 1
    for itemset in itemset_data:
        size = len(itemset)
        if size == 1:
            # str(('a',)) is "('a',)"; the slice drops "(" and ",)".
            pieces.append("(" + str(itemset)[1:-2] + "),")
            continue

        if size != current_size:
            # Close the previous group: drop its trailing comma, then leave
            # a blank line before the new group.
            if pieces:
                pieces[-1] = pieces[-1][:-1]
            pieces.append("\n\n")
            current_size = size
        pieces.append(str(itemset) + ",")

    # Drop the trailing comma of the final group.
    return "".join(pieces)[:-1]

In [None]:
def export_2_file(candidate_data, frequent_data, file_path):
    """Write the candidate and frequent itemsets to a text file.

    The file holds a ``Candidates:`` section and a ``Frequent Itemsets:``
    section, each formatted by ``reformat`` and separated by a blank line.

    :param candidate_data: itemsets for the first section
    :param frequent_data: itemsets for the second section
    :param file_path: destination path; the file is created or truncated
    """
    # The with-block closes the file on exit, so no explicit close is needed.
    with open(file_path, 'w+') as output_file:
        sections = [
            'Candidates:\n' + reformat(candidate_data),
            'Frequent Itemsets:\n' + reformat(frequent_data),
        ]
        output_file.write('\n\n'.join(sections))

<h2>Main</h2>

In [None]:
# Wall-clock start time (not read again in the cells shown here).
start = time.time()
# Run parameters, kept as strings and converted with int() where they are used.
case_number = "1"  # "1": baskets keyed by CSV column 0; "2": keyed by column 1
support_threshold = "4"
input_csv_path = '/content/drive/MyDrive/yelp_sample/small2.csv'
output_file_path = '/content/drive/MyDrive/yelp_sample/case_1_out_new'

# Passed as the second argument of sc.textFile below.
partition_number = 2

sc = SparkContext.getOrCreate()

raw_rdd = sc.textFile(input_csv_path, partition_number)
# Take the first line as the header and drop every line equal to it.
header = raw_rdd.first()
data_rdd = raw_rdd.filter(lambda line: line != header)
# Placeholders; both are assigned in later cells.
whole_data_size = None
basket_rdd = None


In [None]:
# Build one basket per key: a sorted list of the distinct values that share
# the same key column.
#   case 1 -> key is CSV column 0, basket items come from column 1
#   case 2 -> key is CSV column 1, basket items come from column 0
_CASE_COLUMNS = {1: (0, 1), 2: (1, 0)}

_case = int(case_number)
if _case not in _CASE_COLUMNS:
    # Fail here with a clear message instead of leaving basket_rdd as None.
    raise ValueError("case_number must be 1 or 2, got {!r}".format(case_number))
_key_col, _value_col = _CASE_COLUMNS[_case]


def _to_key_value(line, key_col=_key_col, value_col=_value_col):
    """Split a CSV line once and return its (key, value) fields."""
    fields = line.split(',')
    return fields[key_col], fields[value_col]


basket_rdd = (
    data_rdd
    .map(_to_key_value)
    .groupByKey()
    # Only the de-duplicated, sorted values are kept; the key is dropped.
    .map(lambda key_values: sorted(set(key_values[1])))
)

basket_rdd.take(3)

[['100', '103', '105', '106', '107', '97'],
 ['100', '101', '103', '104', '106', '107', '108', '97'],
 ['97', '98']]

In [None]:
# Total number of baskets; passed to find_candidate_itemset as whole_length
# so each partition can scale the support threshold to its own size.
whole_data_size = basket_rdd.count()
whole_data_size

49

In [None]:
# Find the locally frequent itemsets of every partition, merge and
# de-duplicate them, then collect them ordered by size and by content.
candidate_itemset = (
    basket_rdd
    .mapPartitions(
        lambda baskets: find_candidate_itemset(
            data_baskets=baskets,
            original_support=int(support_threshold),
            whole_length=whole_data_size,
        )
    )
    .flatMap(lambda itemsets: itemsets)
    .distinct()
    .sortBy(lambda itemset: (len(itemset), itemset))
    .collect()
)

In [None]:
# Preview a ten-element slice of the collected candidate list.
candidate_itemset[100:110]

[('104', '106'),
 ('104', '107'),
 ('104', '108'),
 ('104', '109'),
 ('104', '110'),
 ('104', '111'),
 ('104', '112'),
 ('104', '113'),
 ('104', '114'),
 ('104', '116')]

In [None]:
# Count every candidate against every basket, sum the counts per itemset,
# keep those that reach the global support threshold, and collect them
# ordered by size and by content.
frequent_itemset = (
    basket_rdd
    .flatMap(
        lambda basket: count_frequent_itemset(
            data_baskets=basket,
            candidate_pairs=candidate_itemset,
        )
    )
    .flatMap(lambda itemset_counts: itemset_counts)
    .reduceByKey(add)
    .filter(lambda itemset_count: itemset_count[1] >= int(support_threshold))
    .map(lambda itemset_count: itemset_count[0])
    .sortBy(lambda itemset: (len(itemset), itemset))
    .collect()
)

In [None]:
# Preview a ten-element slice of the collected frequent itemsets.
frequent_itemset[100:110]

[('104', '107'),
 ('104', '108'),
 ('104', '109'),
 ('104', '110'),
 ('104', '111'),
 ('104', '112'),
 ('104', '113'),
 ('104', '114'),
 ('104', '116'),
 ('104', '117')]

In [None]:
# Write both result lists to output_file_path.
export_2_file(candidate_data=candidate_itemset,frequent_data=frequent_itemset,file_path=output_file_path)