In [1]:
import os
from pyspark import SparkContext, SparkConf

In [27]:
from pyspark import intersection
from itertools import product

In [2]:
# sc = SparkContext(appName="Spark Apriori")
conf = SparkConf().setAppName("Spark Apriori").setMaster(os.getenv('SPARK_URL'))
sc = SparkContext(conf=conf)

23/05/07 11:05:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [3]:
sc

## Frequent K-item Sets

In [4]:
def generate_next_c(f_k, k):
    next_c = [var1 | var2 for index, var1 in enumerate(f_k) for var2 in f_k[index + 1:] if
              list(var1)[:k - 2] == list(var2)[:k - 2]]
    return next_c

def generate_f_k(sc, c_k, shared_itemset, sup):
    def get_sup(x):
        x_sup = len([1 for t in shared_itemset.value if x.issubset(t)])
        if x_sup >= sup:
            return x, x_sup
        else:
            return ()

    f_k = sc.parallelize(c_k).map(get_sup).filter(lambda x: x).collect()
    return f_k

In [79]:
def apriori(sc, f_input, f_output, min_support):
    # read the raw data
    data = sc.textFile(f_input)
    # count the total number of samples
    n_samples = data.count()
    # min_support to frequency
    min_support_count = n_samples * min_support
    # split sort
    itemset = data.map(lambda line: sorted([int(item) for item in line.strip().split(' ')]))
    # share the whole itemset with all workers
    shared_itemset = sc.broadcast(itemset.map(lambda x: set(x)).collect())
    # store for all freq_k
    frequent_itemset = []

    # prepare candidate_1
    k = 1
    c_k = itemset.flatMap(lambda x: set(x)).distinct().collect()
    c_k = [{x} for x in c_k]

    # when candidate_k is not empty
    while len(c_k) > 0:
        # generate freq_k
#         print("C{}: {}".format(k, c_k))
        f_k = generate_f_k(sc, c_k, shared_itemset, min_support_count)
#         print("F{}: {}".format(k, f_k))

        frequent_itemset.append(f_k)
        # generate candidate_k+1
        k += 1
        c_k = generate_next_c([set(item) for item in map(lambda x: x[0], f_k)], k)

    if len(frequent_itemset[-1]) == 0: 
        frequent_itemset = frequent_itemset[:-1]
    # output the result to file system
    sc.parallelize(frequent_itemset, numSlices=1).saveAsTextFile(f_output)
#     return sc.parallelize(frequent_itemset).map(lambda x: list(x[0])).collect()
    return frequent_itemset

In [80]:
!rm -r ./output/frequent_set/

In [81]:
input_path = './dataset/apriori_ratings.txt'
output_path = './output/frequent_set'
min_support = 0.02
freqset = apriori(sc, input_path, output_path, min_support)
# apriori(sc, input_path, output_path, min_support)

                                                                                

In [84]:
for k, k_set in enumerate(freqset): 
    print(k+1, len(k_set))

1 127
2 467
3 477
4 176
5 30
6 1


## Association Rule

In [85]:
def confidence(sc, freq_parent, freq_child, conf): 
    def get_confidence(x): 
        result = [(x[0], child[0], round(child[1]/x[1], 2)) for child in freq_child.value if x[0].issubset(child[0]) and child[1]/x[1]>conf]
        return result
        
    return sc.parallelize(freq_parent).flatMap(get_confidence).filter(lambda x:x).collect()

In [86]:
len(freqset)
association_rules = []
for k in range(len(freqset) - 1): 
    print(freqset[k][0])
    freq_parent = freqset[k]
    freq_child = sc.broadcast(freqset[k+1])
    association_rules.append(confidence(sc, freq_parent, freq_child, conf = 0.6))

({41566}, 4604)
({858, 1262}, 5903)
({318, 858, 1262}, 4274)
({4226, 2324, 858, 318}, 3763)
({4226, 293, 858, 1213, 318}, 3639)


# Serialize (store) the frequent set result 

In [88]:
import pickle

In [89]:
with open('./output/association_rules', 'wb') as f:
    pickle.dump(association_rules, f)

In [90]:
with open('./output/association_rules', 'rb') as f:
    rules = pickle.load(f)