In [1]:
from csv import reader
from collections import defaultdict, OrderedDict
from itertools import chain, combinations
from datetime import datetime

In [2]:
class Node:
    """One node of an FP-tree.

    Attributes:
        item_name: label of the item stored at this node.
        frequency: count accumulated for this node.
        parent:    node one level closer to the root (None for the root).
        children:  dict of child nodes, keyed by item.
        next:      following node in the header table's chain for the same
                   item, or None when this node is the last one linked.
    """

    def __init__(self, item_name, frequency, parent_node):
        # what the node stands for and how often it has been seen
        self.item_name, self.frequency = item_name, frequency
        # upward link and (initially empty) downward links
        self.parent, self.children = parent_node, {}
        # filled in later, when another node for the same item is linked
        self.next = None

    def addFrequency(self, frequency):
        """Increase this node's count by `frequency`."""
        self.frequency += frequency

In [3]:
# read the input file and create a list of itemset and a frequency list
def readFile(fpath):
    """Read a CSV file of transactions.

    Each CSV row becomes one itemset: a list of its fields with every space
    character removed. Rows that contain no non-empty field (blank lines,
    whitespace-only lines, rows of empty fields) are skipped.

    Args:
        fpath: path of the CSV file to read.

    Returns:
        (itemset_list, frequency): the list of itemsets and a parallel list
        holding the count 1 for each of them.
    """
    itemset_list = []

    # newline='' is what the csv module asks for so it can handle line
    # endings (including ones embedded in quoted fields) itself
    with open(fpath, 'r', newline='') as file:
        for row in reader(file):
            items = [field.replace(" ", "") for field in row]
            # csv.reader yields a *list* per line, so a blank line shows up as
            # [] (or as empty strings once spaces are removed) and never as
            # the string "" or " "; test the fields themselves
            if not any(items):
                continue
            itemset_list.append(items)

    # every row read from the file counts as one occurrence
    frequency = [1] * len(itemset_list)
    return itemset_list, frequency

In [4]:
def buildFPTree(itemset_list, frequency, min_sup):
    """Build an FP-tree and its header table from weighted itemsets.

    Args:
        itemset_list: list of itemsets (iterables of hashable items).
        frequency:    list parallel to `itemset_list`; frequency[i] is the
                      count carried by itemset_list[i].
        min_sup:      minimum support count an item needs to be kept.

    Returns:
        (fpTree, header_table) where `fpTree` is the root Node and
        `header_table` maps item -> [support, first node of that item];
        (None, None) when no item reaches `min_sup`.
    """
    # support of every item; dict.fromkeys drops repeats inside one itemset
    # (keeping first-seen order) so an item is counted once per itemset
    support = defaultdict(int)
    for itemset, count in zip(itemset_list, frequency):
        for item in dict.fromkeys(itemset):
            support[item] += count

    # header table - {item: [frequency, head node]} for items that satisfy min_sup
    header_table = {item: [sup, None] for item, sup in support.items() if sup >= min_sup}
    if len(header_table) == 0:
        return None, None

    # One global item order: descending support, ties broken by the order the
    # items were first seen. Sorting each itemset on support alone lets two
    # equally frequent items appear as A,B in one itemset and B,A in another;
    # they then sit on different branches and their joint support is split.
    ranked = sorted(header_table, key=lambda item: header_table[item][0], reverse=True)
    rank = {item: position for position, item in enumerate(ranked)}

    # initialize a Null head node
    fpTree = Node('None', 1, None)

    # insert every itemset as a path of its frequent items, in global order
    for itemset, count in zip(itemset_list, frequency):
        frequent_items = [item for item in dict.fromkeys(itemset) if item in rank]
        frequent_items.sort(key=rank.__getitem__)
        current_node = fpTree
        for item in frequent_items:
            # descend into (or create) the child for this item
            current_node = updateFPTree(item, current_node, header_table, count)

    return fpTree, header_table

In [5]:
def updateFPTree(item, node, header_table, frequency):
    """Add `frequency` occurrences of `item` below `node`.

    Returns the child of `node` that represents `item`, creating it (and
    registering it in `header_table`) when it does not exist yet.
    """
    children = node.children

    if item not in children:
        # first time this item appears under `node`: make a fresh child
        child = Node(item, frequency, node)
        children[item] = child
        # let the header table know about the new node
        updateHeaderTable(item, child, header_table)
        return child

    # the child is already there, so only its count grows
    existing = children[item]
    existing.addFrequency(frequency)
    return existing

In [6]:
def updateHeaderTable(item, node, header_table):
    """Link `node` into the chain kept for `item` in `header_table`.

    header_table[item] is a two-element list whose second slot holds the
    first node of the chain; `node` becomes that first node when the slot is
    empty, otherwise it is attached after the chain's last node.
    """
    entry = header_table[item]
    tail = entry[1]

    if tail is None:
        # no node registered yet for this item
        entry[1] = node
        return

    # walk along the `next` links until the final node, then attach there
    while tail.next is not None:
        tail = tail.next
    tail.next = node

In [7]:
# generate conditional FP trees
def condiFPTree(header_table, min_sup, pre, freq_item_list):
    """Recursively collect frequent itemsets into `freq_item_list`.

    Args:
        header_table:   {item: [support, head node]} of the current tree.
        min_sup:        minimum support passed on to buildFPTree.
        pre:            set of items already fixed for this branch.
        freq_item_list: output list; one new set is appended per item visited.
    """
    # visit the items from lowest to highest support
    by_support = sorted(header_table.items(), key=lambda entry: entry[1][0])

    for item, _ in by_support:
        # the prefix extended with this item is recorded as a frequent itemset
        grown = pre.copy()
        grown.add(item)
        freq_item_list.append(grown)

        # prefix paths of the item together with their counts
        paths, counts = findPath(item, header_table)
        # tree built from those paths only; just its header table is needed
        _, sub_header = buildFPTree(paths, counts, min_sup)
        if sub_header:
            condiFPTree(sub_header, min_sup, grown, freq_item_list)

In [8]:
def findPath(item, header_table):
    """Collect the prefix paths of every node linked for `item`.

    Follows the chain that starts at header_table[item][1]. For each node the
    path towards the root is gathered with pathToRoot; the node's own item is
    dropped and, when anything remains, the remainder is stored together with
    the node's frequency.

    Returns:
        (pattern, frequency): parallel lists of prefix paths and their counts.
    """
    pattern, frequency = [], []

    node = header_table[item][1]
    while node:
        path = []
        # fills `path` starting with the node's own item
        pathToRoot(node, path)
        prefix = path[1:]
        if prefix:
            pattern.append(prefix)
            frequency.append(node.frequency)
        node = node.next

    return pattern, frequency

def pathToRoot(node, path):
    """Append to `path` the item names from `node` upwards.

    Climbs through the `parent` links and stops at the first node whose
    parent is falsy; that node's own name is not appended.
    """
    current = node
    while current.parent:
        path.append(current.item_name)
        current = current.parent

In [9]:
def getSubsets(itemset):
    """Return an iterator over the non-empty proper subsets of `itemset`.

    Subsets are produced as tuples, smallest size first; the full itemset and
    the empty tuple are not included.
    """
    pool = tuple(itemset)
    return (
        combo
        for size in range(1, len(pool))
        for combo in combinations(pool, size)
    )

In [10]:
def getSupport(target_itemset, itemset_list):
    """Count the itemsets in `itemset_list` that contain every item of `target_itemset`."""
    return sum(
        1
        for candidate in itemset_list
        if set(target_itemset).issubset(candidate)
    )

In [11]:
def associationRule(freq_item_list, itemset_list, min_conf):
    """Derive association rules from frequent itemsets.

    For each itemset in `freq_item_list` and each subset returned by
    getSubsets, the confidence is support(itemset) / support(subset), both
    counted over `itemset_list` with getSupport.

    Returns:
        list of [antecedent set, consequent set, confidence] for every rule
        whose confidence is at least `min_conf`.
    """
    rules = []
    for itemset in freq_item_list:
        itemset_sup = getSupport(itemset, itemset_list)
        for antecedent in getSubsets(itemset):
            antecedent_sup = getSupport(antecedent, itemset_list)
            confidence = float(itemset_sup / antecedent_sup)
            if confidence >= min_conf:
                # consequent = items of the itemset that are not in the antecedent
                consequent = set(itemset.difference(antecedent))
                rules.append([set(antecedent), consequent, confidence])
    return rules

In [14]:
def fpgrowth(fpath, min_sup, min_conf):
    """Mine frequent itemsets and association rules from a CSV file and print them.

    Args:
        fpath:    path of the CSV file, one transaction per row.
        min_sup:  minimum support, as an absolute count of transactions.
        min_conf: minimum confidence a rule needs to be reported.

    Returns:
        (freq_item_list, rules): the frequent itemsets (list of sets) and the
        rules as [antecedent, consequent, confidence] lists. Both are empty
        when no item reaches `min_sup`.
    """
    itemset_list, frequency = readFile(fpath)
    _, header_table = buildFPTree(itemset_list, frequency, min_sup)

    freq_item_list = []
    # buildFPTree hands back (None, None) when nothing is frequent; there is
    # then no header table to mine
    if header_table:
        condiFPTree(header_table, min_sup, set(), freq_item_list)
    rules = associationRule(freq_item_list, itemset_list, min_conf)

    print("freq_item_list")
    for itemset in freq_item_list:
        print('itemset: ' + str(itemset))
    print("rules")
    for rule in rules:
        print(str(rule[0]) + ' ⇒ ' + str(rule[1]) + ', conf = ' + str(round(rule[2], 3)))

    return freq_item_list, rules

In [15]:
# Time one end-to-end fpgrowth run on ./adult.csv.
# min_sup is passed as an absolute count: 0.6*30162
# (30162 is presumably the number of rows in adult.csv - TODO confirm);
# min_conf is 0.7.
starting_time = datetime.now()
fpgrowth('./adult.csv', 0.6*30162, 0.7)
ending_time = datetime.now()
# elapsed wall-clock time in seconds
running_time = (ending_time - starting_time).total_seconds()
print("Running time is " + str(running_time) + "s")

freq_item_list
itemset: {'Male'}
itemset: {'United-States', 'Male'}
itemset: {'Private'}
itemset: {'White', 'Private'}
itemset: {'United-States', 'Private'}
itemset: {'<=50K'}
itemset: {'White', '<=50K'}
itemset: {'United-States', '<=50K'}
itemset: {'White'}
itemset: {'White', 'United-States'}
itemset: {'United-States'}
rules
{'Male'} ⇒ {'United-States'}, conf = 0.911
{'White'} ⇒ {'Private'}, conf = 0.738
{'Private'} ⇒ {'White'}, conf = 0.859
{'United-States'} ⇒ {'Private'}, conf = 0.732
{'Private'} ⇒ {'United-States'}, conf = 0.903
{'White'} ⇒ {'<=50K'}, conf = 0.736
{'<=50K'} ⇒ {'White'}, conf = 0.843
{'United-States'} ⇒ {'<=50K'}, conf = 0.746
{'<=50K'} ⇒ {'United-States'}, conf = 0.905
{'White'} ⇒ {'United-States'}, conf = 0.934
{'United-States'} ⇒ {'White'}, conf = 0.881
Running time is 1.078205s
