In [5]:
import json
print("READING DATASET")
dataset = None
with open('modified_coco.json') as f:
    dataset = json.load(f)
print("DONE")
# dataset row:
# {'file_name': '000000095096.png', 'image_id': 95096, 'annotations': ['car', 'car', 'train', 'stop sign']}

READING DATASET
DONE


In [6]:
from tqdm import tqdm

def find_items(dataset):
    """Given a dataset of transactions return all possible values.
    
    Input: dataset, a list of tuples representing transactions
    Output: a tuple of all possible items ordered lexicographically
    """
    items = set()
    for t in dataset:
        for item in t:
            items.add(item)
    items = list(items)
    items.sort()
    return tuple(items)

def contain(subset, superset):
    """Return True iff subset is contained in superset
    
    Input: subset, a tuple of values
           superset, a tuple of values
    Output: True iff superset contains subset, false otherwise
    """
    result = True
    for item in subset:
        if item not in superset:
            result = False
            break
    return result


def support(itemset, dataset):
    """Return the support value of a given itemset among a given dataset

    Input: itemset, a tuple of string values;
           dataset, a list of tuples of string values
    Output: a float number representing the support value
    """
    num = 0.0
    for item in dataset:
        num += 1 if contain(itemset, item) else 0
    den = float(len(dataset))
    return num / den

def check_prefix(t1, t2, length):
    """Check if 2 structures t1 and t2 has the same prefix of a specified length.

    Input: - t1 & t2: two ordered structures of the same length
           - length: the prefix length to check
    Output: true iff t1 and t2 have the same prefix, false otherwise.
    """
    if len(t1) != len(t2) or len(t1) < length:
        print("ERROR")
        return False
    for i in range(length):
        if t1[i] != t2[i]:
            return False
    return True

def apriori(dataset, min_sup=0.5):
    """Apriori algorithm implementation for finding association rules.

    Input: dataset, the entire dataset as list of lists representing transactions;
           min_sup, float value representing the minimum support for result filtering.
    Output: a matrix with two columns:
    the first contain an itemset tuple, the second its support value.
    A row is present iff the value in the second column is >= min_sup.
    """
    result = [[], []]
    items = find_items(dataset) # tuple
    itemsets = [(i,) for i in items]
    for level in tqdm(range(len(items))):
        if len(itemsets) == 0:
            break
        for i in itemsets:
            sup = support(i, dataset)
            if sup >= min_sup:
                result[0].append(i)
                result[1].append(sup)
            else:
                itemsets.remove(i)

        new_itemsets = []
        for i in range(len(itemsets)-1):
            for j in range(i+1, len(itemsets)):
                if check_prefix(itemsets[i], itemsets[j], level) == True:
                    tmp = list(set(itemsets[i] + itemsets[j]))
                    tmp.sort()
                    new_itemsets.append(tuple(tmp))
                else:
                    break
                    
        itemsets = new_itemsets
    return result

In [7]:
print("ELABORATING DATASET")
ds_elaborated = []
for row in dataset:
    tmp = set(row["annotations"])
    ds_elaborated.append(tmp)
print("DONE")

ELABORATING DATASET
DONE


In [8]:
res = apriori(ds_elaborated, 0.02)
for i in range(len(res[0])):
    print(res[0][i], " -> ", res[1][i])

  8%|███▍                                        | 6/78 [00:05<01:11,  1.00it/s]

('backpack',)  ->  0.0852
('baseball glove',)  ->  0.03
('bench',)  ->  0.4338
('bicycle',)  ->  0.0762
('bird',)  ->  0.02
('bottle',)  ->  0.028
('bus',)  ->  0.0912
('cell phone',)  ->  0.0354
('chair',)  ->  0.0602
('clock',)  ->  0.0332
('cup',)  ->  0.0222
('dining table',)  ->  0.0386
('dog',)  ->  0.0276
('fire hydrant',)  ->  0.1346
('person',)  ->  0.5886
('skateboard',)  ->  0.0344
('stop sign',)  ->  0.1332
('suitcase',)  ->  0.0202
('tennis racket',)  ->  0.0214
('train',)  ->  0.0492
('truck',)  ->  0.1286
('backpack', 'car')  ->  0.0356
('backpack', 'person')  ->  0.0826
('backpack', 'traffic light')  ->  0.0352
('baseball bat', 'baseball glove')  ->  0.0242
('baseball glove', 'bench')  ->  0.0292
('bench', 'bottle')  ->  0.0244
('bench', 'car')  ->  0.0712
('bench', 'chair')  ->  0.0506
('bench', 'dog')  ->  0.0206
('bench', 'person')  ->  0.3208
('bench', 'potted plant')  ->  0.0286
('bench', 'train')  ->  0.0206
('bicycle', 'car')  ->  0.042
('bicycle', 'person')  -> 


