- Necessary Imports
- Declaring configuration type
- Declaring constants for configuration keys (used in setting non-method accessible attributes)

In [15]:
import json
from dataclasses import dataclass
from typing import Optional, Dict, Any, TypeVar, List, Callable, Self, Set
from pyspark import SparkContext, SparkConf
from pandas import read_csv, to_datetime, DataFrame
import re
from enum import Enum

T = TypeVar("T")

class HashTreeNodeType(Enum):
    INTERNAL = 0
    LEAF = 1

# Fields for https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.SparkConf.html#pyspark.SparkConf
@dataclass
class Configuration:
    appName: str
    bindAddress: str
    bindPort: str
    masterUrl: str


CONFIG_BIND_ADDRESS_KEY = "spark.driver.bindAddress"
CONFIG_BIND_PORT_KEY = "spark.ui.port"

Parsing the spark configuration values from the json file

In [16]:
configuration_values: Optional[Dict[str, Any]] = None

with open("configuration.json", "r") as configuration_file:
    configuration_values = json.loads(configuration_file.read())

print("FAILED TO LOAD" if configuration_values == None else configuration_values)

{'appName': 'APriori Example', 'masterUrl': 'local', 'bindAddress': 'localhost', 'bindPort': '4050'}


Creating the spark configuration from the parsed json configuration

In [17]:
configuration = SparkConf()

spark_config: Optional[Configuration] = Configuration(**configuration_values)

if spark_config == None:
    raise ValueError("Must supply configuration, or keep defaults")

configuration.setAppName(spark_config.appName)
configuration.setMaster(spark_config.masterUrl)
configuration.set(CONFIG_BIND_ADDRESS_KEY, spark_config.bindAddress)
configuration.set(CONFIG_BIND_PORT_KEY, spark_config.bindPort)

<pyspark.conf.SparkConf at 0x1eecdda9010>

Instantiating & Creating the SparkContext

In [18]:
spark_context: Optional[SparkContext] = None

if spark_context is not None:
    spark_context.stop()

spark_context = SparkContext.getOrCreate(conf=configuration)

Create "baskets" for APriori algorithm

In [19]:
supermarket_df = read_csv("./data/supermarket_sales.csv")

## Some Data Augmenting
# Adding Month, Year, and Day columns which correspond to the respective Month/Year/Day values of the dates
supermarket_df["Date"] = to_datetime(supermarket_df["Date"])
supermarket_df["Hour"] = supermarket_df["Time"].map(lambda x: int(x.split(":")[0]))
supermarket_df["Minute"] = supermarket_df["Time"].map(lambda x: int(x.split(":")[1]))

# The goal is to form "baskets" from the transactions of each month, which we will use for our APriori analysis
supermarket_df = supermarket_df.sort_values(by="Date")

# The current basket id
basket_id = 1

# The collection of baskets
baskets = {}

# The minute threshold
minute_threshold = 10

# The unique dates, which represent the unique day/month/year values
unique_dates = supermarket_df["Date"].unique()

# Iterate over each unique date
for each_date in unique_dates:

    # Find all dates that match the current `each_date`
    x = supermarket_df[(supermarket_df["Date"].dt.year == each_date.year) & (supermarket_df["Date"].dt.month == each_date.month) & (supermarket_df["Date"].dt.day == each_date.day)]
    
    # Find all unique hours for that specific date
    unique_hours = set([int(y) for y in x["Hour"].unique()])

    # Iterate over all hours that transactions occurred in that date
    for each_hour in unique_hours:

        # Find all hour transactions that match the selected hour
        hour_transactions = x[x["Hour"] == each_hour]

        # Find all unique minutes within that hour
        unique_minutes = sorted([int(y) for y in hour_transactions["Minute"].unique()])

        # Running set of the currently "basket'd" transactions
        current_found_transaction_ids = set()
        for i in range(len(unique_minutes) - 1):

            # The current basket
            basket = []

            # The current minute
            curr_minute = unique_minutes[i]

            # All minutes that meet the threshold
            future_minutes = list(filter(lambda x: x - curr_minute <= minute_threshold, unique_minutes[i + 1:]))

            # Add current minute to the minutes to find
            total_minutes_to_find = [curr_minute] + future_minutes
            for each_minute in total_minutes_to_find:

                # Find the record matching to the minute we are looking for, grabbing it's product name
                records = x[x["Minute"] == each_minute]["Product line"].to_list()

                # Find the record matching to the minute we are looking for, extracting it's id to avoid double counting
                record_ids = x[x["Minute"] == each_minute].index.to_list()
                
                # "extend" the basket, which just appends all elements, but avoids duplicates
                basket.extend([re.match(r"\w+", records[i]).group(0).lower() for i in range(len(records)) if record_ids[i] not in current_found_transaction_ids])

                # For each record that meets the criteria of being in the threshold and already not basket'd
                for each_record_id in record_ids:
                    # Add the found record id to the currently found transaction ids
                    current_found_transaction_ids.add(each_record_id)
            
            # If the basket has stuff in it, then add it to the list of baskets
            if len(basket) > 0:
                baskets[basket_id] = basket
                basket_id += 1

Now, we can run the APriori algorithm :)

In [20]:
class APriori:
    def __init__(self):
        self.candidates: List[T] = []
        self.support_threshold = None
        self.baskets: Dict[int, List[T]]

    def prune_candidates(self):
        filtered_candidates = []
        for each_candidate_set  in self.candidates: # for each candidate that exists
            support_count = 0
            setted_candidate = set(each_candidate_set)
            for each_basket in self.baskets.values(): # test against all baskets computed
                if setted_candidate.issubset(set(each_basket)): # check if the candidate is a subset of the basket
                    support_count += 1

            if support_count >= self.support_threshold:
                filtered_candidates.append(each_candidate_set)
        if len(filtered_candidates) == 0:
            raise Exception("No candidates meet support threshold, terminating")

        self.candidates = filtered_candidates

    def extend_candidates(self):
        new_candidates = []
        candidates_len = range(len(self.candidates))
        for i in candidates_len:
            for j in candidates_len:
                if i != j:
                    candidate_a = self.candidates[i]
                    candidate_b = self.candidates[j]
                    set_a = set(candidate_a)
                    set_b = set(candidate_b)
                    union_a_b = set_a.union(set_b)
                    is_valid = len(union_a_b) == len(set_a) + 1 and len(union_a_b) == len(set_b) + 1 and not list(union_a_b) in new_candidates

                    if is_valid:
                        new_candidate = set_a.union(set_b)
                        new_candidates.append(list(new_candidate))

        if len(new_candidates) == 0:
            raise Exception("Cannot extend candidates further, frequent itemsets have been found")
        
        self.candidates = new_candidates
    
    def set_candidates(self, candidates: List[T]):
        self.candidates = candidates
        return self
    
    def set_support_threshold(self, support_threshold: int):
        self.support_threshold = support_threshold
        return self
    
    def set_baskets(self, baskets: List[List[T]]):
        self.baskets = baskets
        return self
    
    def run(self):
        while True:
            try:
                self.prune_candidates()
                self.extend_candidates()
            except:
                return self.candidates


    

In [21]:
converted_candidates = set()
for each_basket in baskets.values():
    converted_candidates.update(set(each_basket))

singleton_candidates = []
for each_item in converted_candidates:
    singleton_candidates.append([each_item])

In [29]:
%%timeit

apriori = APriori().set_baskets(baskets).set_candidates(singleton_candidates).set_support_threshold(9)
apriori.run()

2.5 ms ± 122 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


However, we can expand and improve the complexity of APriori, by using a Hash-Tree data structure, I will go over the details further down this document.

In [23]:
class HashTreeNode:
    def __init__(self, branching_factor: int, hashing_function: Callable[[List[T]], int], storage_threshold: int):
        self.storage_threshold = storage_threshold
        self.stored_itemsets: List[Set[T]] = []
        self.type = HashTreeNodeType.LEAF
        self.children: List[Optional[Self]] = [None] * self.branching_factor
        self.hashing_function = hashing_function
        self.branching_factor = branching_factor

    def add_itemset(self, itemset: List[T]):

        if self.storage_threshold == len(self.stored_itemsets):
            self.type = HashTreeNodeType.INTERNAL
            cloned_itemsets = self.stored_itemsets[:]
            self.stored_itemsets = []

            for each_itemset in cloned_itemsets:
                self.add_itemset(each_itemset)

        if self.type == HashTreeNodeType.INTERNAL:
            found_branch = self.hashing_function(itemset) % self.branching_factor

            if self.children[found_branch] is not None:
                if self.children[found_branch].type == HashTreeNodeType.INTERNAL:
                    self.children[found_branch].add_itemset(itemset)
                else:
                    self.children[found_branch].stored_itemsets.append(set(itemset))
            else:
                self.children[found_branch] = HashTreeNode(self.branching_factor, self.hashing_function, self.storage_threshold)
                self.children[found_branch].stored_itemsets.append(set(itemset))

    def check_itemset(self, itemset: List[T]):
        found_branch = self.hashing_function(itemset) % self.branching_factor

        if self.children[found_branch] is not None:
            if self.children[found_branch].type == HashTreeNodeType.LEAF:
                return set(itemset) in self.children[found_branch].stored_itemsets
            
            return self.children[found_branch].check_itemset(itemset)
                    

class HashTree:
    def __init__(self, branching_factor: int, hashing_function: Callable[[List[T]], int], node_storage_threshold: int):

        if branching_factor is None or hashing_function is None:
            raise Exception("Must define branching factor and hashing function in order to properly use the Hash-Tree data structure")
        
        if branching_factor == 0:
            raise Exception("Cannot have branching factor of 0")

        self.branching_factor = branching_factor
        self.hashing_function = hashing_function
        self.root: Optional[HashTreeNode] = None
        self.node_storage_threshold = node_storage_threshold

    def add_itemset(self, itemset: List[T]):
        if self.root is None:
            self.root = HashTreeNode(self.branching_factor, self.hashing_function, self.node_storage_threshold)

        self.root.add_itemset(itemset)

    def check_itemset(self, itemset: List[T]):
        if self.root is None:
            return False
        
        return self.root.check_itemset(itemset)
        




In [24]:
class EnhancedAPriori:
    def __init__(self):
        self.candidates: List[List[T]] = []
        self.support_threshold = None
        self.baskets: Dict[int, List[T]]
        self.hash_tree : Optional[HashTree] = None

    def prune_candidates(self):
        filtered_candidates = []
        for each_candidate_set  in self.candidates: # for each candidate that exists
            support_count = 0
            setted_candidate = set(each_candidate_set)
            for each_basket in self.baskets.values(): # test against all baskets computed
                if setted_candidate.issubset(set(each_basket)): # check if the candidate is a subset of the basket
                    support_count += 1

            if support_count >= self.support_threshold:
                filtered_candidates.append(each_candidate_set)
        if len(filtered_candidates) == 0:
            raise Exception("No candidates meet support threshold, terminating")

        self.candidates = filtered_candidates

    def extend_candidates(self):
        new_candidates = []
        candidates_len = range(len(self.candidates))
        for i in candidates_len:
            for j in candidates_len:
                if i != j:
                    candidate_a = self.candidates[i]
                    candidate_b = self.candidates[j]
                    set_a = set(candidate_a)
                    set_b = set(candidate_b)
                    union_a_b = set_a.union(set_b)
                    is_valid = len(union_a_b) == len(set_a) + 1 and len(union_a_b) == len(set_b) + 1 and not list(union_a_b) in new_candidates

                    if is_valid:
                        new_candidate = set_a.union(set_b)
                        new_candidates.append(list(new_candidate))

        if len(new_candidates) == 0:
            raise Exception("Cannot extend candidates further, frequent itemsets have been found")
        
        self.candidates = new_candidates
    
    def set_candidates(self, candidates: List[T]):
        self.candidates = candidates
        return self
    
    def set_support_threshold(self, support_threshold: int):
        self.support_threshold = support_threshold
        return self
    
    def set_baskets(self, baskets: List[List[T]]):
        self.baskets = baskets
        return self
    
    def run(self):
        while True:
            try:
                self.prune_candidates()
                self.extend_candidates()
            except:
                return self.candidates

In [28]:
%%timeit

new_apriori = EnhancedAPriori().set_baskets(baskets).set_candidates(singleton_candidates).set_support_threshold(9)
new_apriori.run()

2.36 ms ± 143 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
