- Necessary Imports
- Declaring configuration type
- Declaring constants for configuration keys (used in setting non-method accessible attributes)

In [27]:
import json
from dataclasses import dataclass
from typing import Optional, Dict, Any, TypeVar, List, Callable, Self, Set, Tuple
from pyspark import SparkContext, SparkConf
from pandas import read_csv, to_datetime, DataFrame
import re
from enum import Enum

T = TypeVar("T")

class HashTreeNodeType(Enum):
    INTERNAL = 0
    LEAF = 1

# Fields for https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.SparkConf.html#pyspark.SparkConf
@dataclass
class Configuration:
    appName: str
    bindAddress: str
    bindPort: str
    masterUrl: str


CONFIG_BIND_ADDRESS_KEY = "spark.driver.bindAddress"
CONFIG_BIND_PORT_KEY = "spark.ui.port"

Parsing the spark configuration values from the json file

In [2]:
configuration_values: Optional[Dict[str, Any]] = None

with open("configuration.json", "r") as configuration_file:
    configuration_values = json.loads(configuration_file.read())

print("FAILED TO LOAD" if configuration_values == None else configuration_values)

{'appName': 'APriori Example', 'masterUrl': 'local', 'bindAddress': 'localhost', 'bindPort': '4050'}


Creating the spark configuration from the parsed json configuration

In [3]:
configuration = SparkConf()

spark_config: Optional[Configuration] = Configuration(**configuration_values)

if spark_config == None:
    raise ValueError("Must supply configuration, or keep defaults")

configuration.setAppName(spark_config.appName)
configuration.setMaster(spark_config.masterUrl)
configuration.set(CONFIG_BIND_ADDRESS_KEY, spark_config.bindAddress)
configuration.set(CONFIG_BIND_PORT_KEY, spark_config.bindPort)

<pyspark.conf.SparkConf at 0x2057f1f6030>

Instantiating & Creating the SparkContext

In [4]:
spark_context: Optional[SparkContext] = None

if spark_context is not None:
    spark_context.stop()

spark_context = SparkContext.getOrCreate(conf=configuration)

Create "baskets" for APriori algorithm

In [5]:
supermarket_df = read_csv("./data/supermarket_sales.csv")

## Some Data Augmenting
# Adding Month, Year, and Day columns which correspond to the respective Month/Year/Day values of the dates
supermarket_df["Date"] = to_datetime(supermarket_df["Date"])
supermarket_df["Hour"] = supermarket_df["Time"].map(lambda x: int(x.split(":")[0]))
supermarket_df["Minute"] = supermarket_df["Time"].map(lambda x: int(x.split(":")[1]))

# The goal is to form "baskets" from the transactions of each month, which we will use for our APriori analysis
supermarket_df = supermarket_df.sort_values(by="Date")

# The current basket id
basket_id = 1

# The collection of baskets
transactions = {}

# The minute threshold
minute_threshold = 10

# The unique dates, which represent the unique day/month/year values
unique_dates = supermarket_df["Date"].unique()

# Iterate over each unique date
for each_date in unique_dates:

    # Find all dates that match the current `each_date`
    x = supermarket_df[(supermarket_df["Date"].dt.year == each_date.year) & (supermarket_df["Date"].dt.month == each_date.month) & (supermarket_df["Date"].dt.day == each_date.day)]
    
    # Find all unique hours for that specific date
    unique_hours = set([int(y) for y in x["Hour"].unique()])

    # Iterate over all hours that transactions occurred in that date
    for each_hour in unique_hours:

        # Find all hour transactions that match the selected hour
        hour_transactions = x[x["Hour"] == each_hour]

        # Find all unique minutes within that hour
        unique_minutes = sorted([int(y) for y in hour_transactions["Minute"].unique()])

        # Running set of the currently "basket'd" transactions
        current_found_transaction_ids = set()
        for i in range(len(unique_minutes) - 1):

            # The current basket
            basket = []

            # The current minute
            curr_minute = unique_minutes[i]

            # All minutes that meet the threshold
            future_minutes = list(filter(lambda x: x - curr_minute <= minute_threshold, unique_minutes[i + 1:]))

            # Add current minute to the minutes to find
            total_minutes_to_find = [curr_minute] + future_minutes
            for each_minute in total_minutes_to_find:

                # Find the record matching to the minute we are looking for, grabbing it's product name
                records = x[x["Minute"] == each_minute]["Product line"].to_list()

                # Find the record matching to the minute we are looking for, extracting it's id to avoid double counting
                record_ids = x[x["Minute"] == each_minute].index.to_list()
                
                # "extend" the basket, which just appends all elements, but avoids duplicates
                basket.extend([re.match(r"\w+", records[i]).group(0).lower() for i in range(len(records)) if record_ids[i] not in current_found_transaction_ids])

                # For each record that meets the criteria of being in the threshold and already not basket'd
                for each_record_id in record_ids:
                    # Add the found record id to the currently found transaction ids
                    current_found_transaction_ids.add(each_record_id)
            
            # If the basket has stuff in it, then add it to the list of baskets
            if len(basket) > 0:
                transactions[basket_id] = basket
                basket_id += 1

Now, we can run the APriori algorithm :)

In [6]:
class APriori:
    def __init__(self):
        self.candidates: List[T] = []
        self.support_threshold = None
        self.baskets: Dict[int, List[T]]

    def prune_candidates(self):
        filtered_candidates = []
        for each_candidate_set  in self.candidates: # for each candidate that exists
            support_count = 0
            setted_candidate = set(each_candidate_set)
            for each_basket in self.baskets.values(): # test against all baskets computed
                if setted_candidate.issubset(set(each_basket)): # check if the candidate is a subset of the basket
                    support_count += 1

            if support_count >= self.support_threshold:
                filtered_candidates.append(each_candidate_set)
        if len(filtered_candidates) == 0:
            raise Exception("No candidates meet support threshold, terminating")

        self.candidates = filtered_candidates

    def extend_candidates(self):
        new_candidates = []
        candidates_len = range(len(self.candidates))
        for i in candidates_len:
            for j in candidates_len:
                if i != j:
                    candidate_a = self.candidates[i]
                    candidate_b = self.candidates[j]
                    set_a = set(candidate_a)
                    set_b = set(candidate_b)
                    union_a_b = set_a.union(set_b)
                    is_valid = len(union_a_b) == len(set_a) + 1 and len(union_a_b) == len(set_b) + 1 and not list(union_a_b) in new_candidates

                    if is_valid:
                        new_candidate = set_a.union(set_b)
                        new_candidates.append(list(new_candidate))

        if len(new_candidates) == 0:
            raise Exception("Cannot extend candidates further, frequent itemsets have been found")
        
        self.candidates = new_candidates
    
    def set_candidates(self, candidates: List[T]):
        self.candidates = candidates
        return self
    
    def set_support_threshold(self, support_threshold: int):
        self.support_threshold = support_threshold
        return self
    
    def set_baskets(self, baskets: List[List[T]]):
        self.baskets = baskets
        return self
    
    def run(self):
        while True:
            try:
                self.prune_candidates()
                self.extend_candidates()
            except:
                return self.candidates


    

In [7]:
converted_candidates = set()
for each_basket in transactions.values():
    converted_candidates.update(set(each_basket))

singleton_candidates = []
for each_item in converted_candidates:
    singleton_candidates.append([each_item])

In [8]:
%%timeit

apriori = APriori().set_baskets(transactions).set_candidates(singleton_candidates).set_support_threshold(9)
apriori.run()

2.52 ms ± 92.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


However, we can expand and improve the complexity of APriori, by using a Hash-Tree data structure, I will go over the details further down this document.

In [166]:
class HashTreeNode:
    def __init__(self, branching_factor: int, hashing_function: Callable[[List[T]], int], storage_threshold: int, support_threshold: int):    
        self.storage_threshold = storage_threshold
        self.hashing_function = hashing_function
        self.branching_factor = branching_factor
        self.stored_itemsets: Set[str] = set()
        self.support_count = 0
        self.type = HashTreeNodeType.LEAF
        self.children: List[Optional[Self]] = [None] * self.branching_factor
        self.support_threshold = support_threshold
        self.candidates_met_threshold = False

    def add_itemset(self, itemset: List[T]):

        if self.storage_threshold == len(self.stored_itemsets):
            self.type = HashTreeNodeType.INTERNAL
            cloned_itemsets = self.stored_itemsets.copy()
            self.stored_itemsets = set()
            self.support_count = 0

            for each_itemset in cloned_itemsets:
                self.add_itemset(each_itemset)

        if self.type == HashTreeNodeType.INTERNAL:
            found_branch = self.hashing_function(itemset) % self.branching_factor

            if self.children[found_branch] is not None:
                if self.children[found_branch].type == HashTreeNodeType.INTERNAL:
                    self.children[found_branch].add_itemset(itemset)
                else:
                    self.children[found_branch].stored_itemsets.add(str(set(itemset)))
            else:
                self.children[found_branch] = HashTreeNode(self.branching_factor, self.hashing_function, self.storage_threshold, self.support_threshold)
                self.children[found_branch].stored_itemsets.add(str(set(itemset)))
        
        else:
            self.stored_itemsets.add(str(set(itemset)))
    
    def process_itemset(self, itemset: List[T]) -> Tuple[bool, Set[str]]:
        if self.type == HashTreeNodeType.INTERNAL:
            found_branch = self.hashing_function(itemset) % self.branching_factor
            
            if self.children[found_branch] is not None:
                if self.children[found_branch].type == HashTreeNodeType.LEAF:
                    leaf_node = self.children[found_branch]
                    setted_itemset = set(itemset)
                    is_in_children = any([eval(each_child) in setted_itemset for each_child in leaf_node.stored_itemsets])
                    
                    if is_in_children:
                        self.children[found_branch].support_count += 1
                        
                        if leaf_node.support_count >= leaf_node.support_threshold and not leaf_node.candidates_met_threshold:
                            leaf_node.candidates_met_threshold = True
                            return (True, leaf_node.stored_itemsets)
                        
                    return (False, [])
                return self.children[found_branch].process_itemset(itemset)
        else:
            setted_itemset = set(itemset)
            is_in_children = any([eval(each_child).issubset(setted_itemset) for each_child in self.stored_itemsets])

            if is_in_children:
                return (True, self.stored_itemsets)
            
        return (False, [])
                    

class HashTree:
    def __init__(self, branching_factor: int, hashing_function: Callable[[List[T]], int], node_storage_threshold: int, support_threshold: int):

        if branching_factor is None or hashing_function is None:
            raise Exception("Must define branching factor and hashing function in order to properly use the Hash-Tree data structure")
        
        if branching_factor == 0:
            raise Exception("Cannot have branching factor of 0")

        if node_storage_threshold <= 0:
            raise ValueError("Invalid `node_storage_threshold` value")
        
        if support_threshold <= 0:
            raise ValueError("Invalid `support_threshold` value")

        self.branching_factor = branching_factor
        self.hashing_function = hashing_function
        self.root: Optional[HashTreeNode] = None
        self.node_storage_threshold = node_storage_threshold
        self.support_threshold = support_threshold

    def add_itemset(self, itemset: List[T]):
        if self.root is None:
            self.root = HashTreeNode(self.branching_factor, self.hashing_function, self.node_storage_threshold, self.support_threshold)

        self.root.add_itemset(itemset)

    def process_itemset(self, itemset: List[T]):
        if self.root is None:
            return (False, [])
        
        return self.root.process_itemset(itemset)
        
        




In [168]:
class EnhancedAPriori:
    def __init__(self):
        self.candidates: List[List[T]] = []
        self.support_threshold = None
        self.baskets: Dict[int, List[T]]
        self.branching_factor = None
        self.hashing_function = None
        self.node_storage_threshold = None
        self.hash_tree : Optional[HashTree] = None
        self.candidate_supports = {}
        
        
    def create_hash_tree(self) -> HashTree:
        if self.support_threshold is None or self.branching_factor is None or self.hashing_function is None or self.node_storage_threshold is None:
            raise ValueError("Invalid configuration for HashTree instantiation")
        
        self.hash_tree = HashTree(self.branching_factor, self.hashing_function, self.node_storage_threshold, self.support_threshold)

    def prune_candidates(self) -> Self:
        """
        "Prunes" the existing candidate set, using the support threshold to determine whether any candidates in the current set
        are not considered "frequent" or not

        Raises:
            Exception: If the pruning step removes all existing candidates, then return the candidates before pruning, as they are the most frequent

        Returns:
            Self: The mutated APriori instance
        """
        filtered_candidates = []

        for each_basket in self.baskets.values(): # test against all baskets computed
            (found, candidates) = self.hash_tree.process_itemset(each_basket)
            
            if not found:
                continue
            
            for each_candidate in candidates:
                self.candidate_supports[str(each_candidate)] += 1
                
                
        for each_candidate, each_candidate_count in self.candidate_supports.items():
            if each_candidate_count >= self.support_threshold:
                filtered_candidates.append(eval(each_candidate))

        if len(filtered_candidates) == 0:
            raise Exception("No candidates meet support threshold, terminating")

        self.candidates = filtered_candidates
        return self

    def extend_candidates(self) -> Self:
        """
        Extends the candidates, extension means forming a new candidate set which we will further prune as well.

        Raises:
            Exception: If it's impossible to further extend the candidate set

        Returns:
            Self: The mutated APriori instance
        """
        new_candidates = []
        candidates_len = range(len(self.candidates))

        for i in candidates_len:
            for j in candidates_len:
                if i != j:
                    candidate_i = self.candidates[i]
                    candidate_j = self.candidates[j]
                    set_i = set(candidate_i)
                    set_j = set(candidate_j)
                    union_i_j = set_i.union(set_j)
                    is_extended_itemset_valid = len(union_i_j) == len(set_i) + 1 and len(union_i_j) == len(set_j) + 1 and not list(union_i_j) in new_candidates

                    if is_extended_itemset_valid:
                        new_candidate = set_i.union(set_j)
                        new_candidates.append(list(new_candidate))

        if len(new_candidates) == 0:
            raise Exception("Cannot extend candidates further, frequent itemsets have been found")
        
        return self.set_candidates(new_candidates)
    
    def set_candidates(self, candidates: List[List[T]]) -> Self:
        """
        Sets the "candidates", or the itemsets which could potentially be frequent itemsets

        Args:
            candidates (List[T]): The currently computed frequent itemsets

        Returns:
            Self: The mutated APriori instance
        """
        self.candidates = candidates
        self.hash_tree_candidates()
        return self
    
    def hash_tree_candidates(self):
        
        if self.hash_tree is None:
            raise ValueError("HashTree is not instantiated")
        
        for each_candidate in self.candidates:
            self.candidate_supports[str(set(each_candidate))] = 0
            self.hash_tree.add_itemset(each_candidate)
    
    def set_support_threshold(self, support_threshold: int) -> Self:
        """
        Sets the "support threshold" for the APriori class, this is the threshold by which we consider a itemset "frequent" or not

        Args:
            support_threshold (int): The support threshold we apply to the candidate sets which is involved in the "pruning" step

        Returns:
            Self: The mutated APriori instance
        """
        self.support_threshold = support_threshold
        return self
    
    def set_baskets(self, baskets: List[List[T]]) -> Self:
        """
        Sets the "baskets", or the parsed transaction database, into a list of transactions that we use to measure the
        support value of the candidate set

        Args:
            baskets (List[List[T]]): The parsed transactions

        Returns:
            Self: The mutated APriori instance
        """
        self.baskets = baskets
        return self
    
    def set_branching_factor(self, branching_factor: int = 5) -> Self:
        self.branching_factor = branching_factor
        return self
    
    def set_hashing_function(self, hashing_function: Callable[[List[T]], int]) -> Self:
        self.hashing_function = hashing_function
        return self
    
    def set_node_storage_threshold(self, node_storage_threshold: int = 10) -> Self:
        self.node_storage_threshold = node_storage_threshold
        return self
    
    def run(self) -> List[T]:
        """
        This is the core function that runs the APriori algorithm on the dataset fields set in the instantiation of the APriori class

        Returns:
            List[T]: The frequent itemsets calculated using the APriori algorithm
        """
        while True:
            try:
                self.prune_candidates()
                self.create_hash_tree()
                self.extend_candidates()
            except Exception as error:
                return self.candidates

For the APriori with HashTree to work, we first must define a hashing function, branching factor, and node_storage_threshold.
Let's do just that

In [16]:
branching_factor = 10
hashing_function = lambda x: sum([sum([ord(z) for z in y]) for y in x])
node_storage_threshold = 7

In [169]:
new_apriori = EnhancedAPriori().set_support_threshold(9).set_branching_factor(10).set_hashing_function(hashing_function).set_node_storage_threshold(7)
new_apriori.create_hash_tree()
new_apriori.set_baskets(transactions).set_candidates(singleton_candidates)
new_apriori.run()

KeyboardInterrupt: 

In [25]:
class APrioriTid:
    def __init__(self):
        self.candidates: List[List[T]] = []
        self.support_threshold = None
        self.transactions: Dict[int, List[T]] = {}
        self.hash_tree : Optional[HashTree] = None
        self.transaction_db: Dict[int, List[T]] = {}

    def prune_candidates(self) -> Self:
        """
        "Prunes" the existing candidate set, using the support threshold to determine whether any candidates in the current set
        are not considered "frequent" or not

        Raises:
            Exception: If the pruning step removes all existing candidates, then return the candidates before pruning, as they are the most frequent

        Returns:
            Self: The mutated APriori instance
        """
        filtered_candidates = []
        for each_candidate_set  in self.candidates: # for each candidate that exists
            support_count = 0
            setted_candidate = set(each_candidate_set)

            for each_transaction_id in self.transactions: # test against all baskets computed
                each_transaction = self.transactions[each_transaction_id]
                if setted_candidate.issubset(set(each_transaction)): # check if the candidate is a subset of the basket
                    support_count += 1
                    self.transaction_db[each_transaction_id].append(setted_candidate) # add the candidate to the list

            if support_count >= self.support_threshold:
                filtered_candidates.append(each_candidate_set)

        # APrioriTid implementation
        delete_keys = []
        for each_key in self.transaction_db:
            if len(self.transaction_db[each_key]) == 0:
                delete_keys.append(each_key)

        for each_delete_key in delete_keys:
            del self.transactions[each_delete_key]

            self.transaction_db[each_key] = [] # reset list

        if len(filtered_candidates) == 0:
            raise Exception("No candidates meet support threshold, terminating")

        self.candidates = filtered_candidates
        return self

    def extend_candidates(self) -> Self:
        """
        Extends the candidates, extension means forming a new candidate set which we will further prune as well.

        Raises:
            Exception: If it's impossible to further extend the candidate set

        Returns:
            Self: The mutated APriori instance
        """
        new_candidates = []
        candidates_len = range(len(self.candidates))

        for i in candidates_len:
            for j in candidates_len:
                if i != j:
                    candidate_i = self.candidates[i]
                    candidate_j = self.candidates[j]
                    set_i = set(candidate_i)
                    set_j = set(candidate_j)
                    union_i_j = set_i.union(set_j)
                    is_extended_itemset_valid = len(union_i_j) == len(set_i) + 1 and len(union_i_j) == len(set_j) + 1 and not list(union_i_j) in new_candidates

                    if is_extended_itemset_valid:
                        new_candidate = set_i.union(set_j)
                        new_candidates.append(list(new_candidate))

        if len(new_candidates) == 0:
            raise Exception("Cannot extend candidates further, frequent itemsets have been found")
        
        self.candidates = new_candidates

        return self
    
    def set_candidates(self, candidates: List[List[T]]) -> Self:
        """
        Sets the "candidates", or the itemsets which could potentially be frequent itemsets

        Args:
            candidates (List[T]): The currently computed frequent itemsets

        Returns:
            Self: The mutated APriori instance
        """
        self.candidates = candidates
        return self
    
    def set_support_threshold(self, support_threshold: int) -> Self:
        """
        Sets the "support threshold" for the APriori class, this is the threshold by which we consider a itemset "frequent" or not

        Args:
            support_threshold (int): The support threshold we apply to the candidate sets which is involved in the "pruning" step

        Returns:
            Self: The mutated APriori instance
        """
        self.support_threshold = support_threshold
        return self
    
    def set_baskets(self, baskets: Dict[int, List[T]]) -> Self:
        """
        Sets the "baskets", or the parsed transaction database, into a list of transactions that we use to measure the
        support value of the candidate set

        Args:
            baskets (List[List[T]]): The parsed transactions

        Returns:
            Self: The mutated APriori instance
        """
        self.transactions = baskets

        for each_transaction_id in baskets:
            self.transaction_db[each_transaction_id] = set()

        return self
    
    def run(self) -> List[T]:
        """
        This is the core function that runs the APriori algorithm on the dataset fields set in the instantiation of the APriori class

        Returns:
            List[T]: The frequent itemsets calculated using the APriori algorithm
        """
        while True:
            try:
                self.prune_candidates()
                self.extend_candidates()
            except:
                return self.candidates