In [1]:
import os
import itertools
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

In [2]:
# Global configuration constants.
# NOTE(review): NOISE, TEST_SIZE, ALPHA and N_SAMPLES are not referenced by the
# association-rule cells visible here -- TODO confirm they are still needed.
RANDOM_STATE = 24
NOISE = 0.1
TEST_SIZE = 0.2
ALPHA = 0.001
N_SAMPLES = 1000

# Seed NumPy's legacy global RNG so any later random draws are reproducible.
np.random.seed(RANDOM_STATE)

In [3]:
# Apply the base style sheet FIRST. `plt.style.use` copies every rcParam the
# style defines into plt.rcParams, overwriting earlier customisations; the
# seaborn-v0_8 style sheets set `image.cmap`, which previously undid the
# colormap chosen below when the style was applied last.
plt.style.use('seaborn-v0_8-darkgrid')

params = {'legend.fontsize': 'large',
          'figure.figsize': (15, 7),
          'axes.labelsize': 'large',
          'axes.titlesize': 'x-large',
          'xtick.labelsize': 'x-large',
          'ytick.labelsize': 'large',
          'savefig.dpi': 75,
          'image.interpolation': 'none',
          'savefig.bbox': 'tight',
          'lines.linewidth': 1,
          'legend.numpoints': 1,
          'scatter.edgecolors': 'b'
         }

CMAP = plt.cm.brg
plt.rcParams.update(params)
# Set the default colormap through rcParams rather than plt.set_cmap():
# set_cmap() looks up the current image via gcf(), which opened an empty
# figure as a side effect (the stray "<Figure ... with 0 Axes>" output).
plt.rcParams['image.cmap'] = CMAP.name

<Figure size 1500x700 with 0 Axes>

In [4]:
fileName = "groceries_mba.csv"  # input data: one comma-separated basket per line

In [5]:
# Load the transactions. Each line of the file is one basket: a comma-separated
# list of item names. Items are de-duplicated and sorted so that itemset keys
# built later with ','.join(...) are canonical ('a,b', never 'b,a').
# Empty fields (blank lines, trailing or doubled commas) are dropped so that ''
# is never counted as an item, and lines with no items are not counted as
# transactions (len(lines) is used as the transaction count).
with open(fileName) as file:
    lines = []
    for raw_line in file:
        items = {item.strip() for item in raw_line.split(',')}
        items.discard('')
        if items:
            lines.append(sorted(items))

In [6]:
lines[0:10]  # peek at the first ten parsed transactions

[['citrus_fruit', 'margarine', 'ready_soups', 'semi-finished_bread'],
 ['coffee', 'tropical_fruit', 'yogurt'],
 ['whole_milk'],
 ['cream_cheese_', 'meat_spreads', 'pip_fruit', 'yogurt'],
 ['condensed_milk',
  'long_life_bakery_product',
  'other_vegetables',
  'whole_milk'],
 ['abrasive_cleaner', 'butter', 'rice', 'whole_milk', 'yogurt'],
 ['rolls/buns'],
 ['UHT-milk',
  'bottled_beer',
  'liquor_(appetizer)',
  'other_vegetables',
  'rolls/buns'],
 ['pot_plants'],
 ['cereals', 'whole_milk']]

In [7]:
# supps collects the frequent itemsets level by level: supps[0] for single
# items, supps[1] for pairs, supps[2] for triples; each maps a ','-joined
# itemset key to its transaction count.
supps = []

# Count, for every individual item, how many transactions contain it.
supp = {}
for basket in lines:
    for item in basket:
        supp[item] = supp.get(item, 0) + 1

In [8]:
len(supp.keys())  # number of distinct items seen across all transactions

169

In [9]:
# Association-rule thresholds.
SUPP_THRESHOLD = 100   # min. number of transactions containing an itemset
CONF_THRESHOLD = 0.4   # min. confidence for a rule to be kept
# NOTE(review): the two thresholds below are not applied by the rule-generation
# cell visible here -- TODO confirm whether they are used later.
LIFT_THRESHOLD = 20.0
CONV_THRESHOLD = 5.0

In [10]:
# Keep only items that appear in at least SUPP_THRESHOLD transactions.
f_supp = {item: count
          for item, count in supp.items()
          if count >= SUPP_THRESHOLD}
len(f_supp)

88

In [11]:
# Store the frequent single items as level 1 (index 0). Slice assignment instead
# of append keeps this cell idempotent: re-running it no longer pushes a
# duplicate level that would shift every later supps[k] index. Any higher
# levels are discarded because they were derived from the previous level 1.
supps[:] = [f_supp]

In [12]:
list(supps)  # frequent itemsets computed so far, one dict per itemset size

[{'citrus_fruit': 814,
  'margarine': 576,
  'semi-finished_bread': 174,
  'coffee': 571,
  'tropical_fruit': 1032,
  'yogurt': 1372,
  'whole_milk': 2513,
  'cream_cheese_': 390,
  'pip_fruit': 744,
  'condensed_milk': 101,
  'long_life_bakery_product': 368,
  'other_vegetables': 1903,
  'butter': 545,
  'rolls/buns': 1809,
  'UHT-milk': 329,
  'bottled_beer': 792,
  'pot_plants': 170,
  'bottled_water': 1087,
  'chocolate': 488,
  'white_bread': 414,
  'curd': 524,
  'dishes': 173,
  'flour': 171,
  'beef': 516,
  'frankfurter': 580,
  'soda': 1715,
  'chicken': 422,
  'fruit/vegetable_juice': 711,
  'newspapers': 785,
  'sugar': 333,
  'packaged_fruit/vegetables': 128,
  'specialty_bar': 269,
  'butter_milk': 275,
  'pastry': 875,
  'detergent': 189,
  'processed_cheese': 163,
  'candy': 294,
  'frozen_dessert': 106,
  'root_vegetables': 1072,
  'salty_snack': 372,
  'waffles': 378,
  'canned_beer': 764,
  'sausage': 924,
  'brown_bread': 638,
  'shopping_bags': 969,
  'beverages': 

In [13]:
# Count pairs of items purchased together. A pair is a candidate only if both
# of its items are frequent singletons (Apriori pruning). Each transaction is
# sorted, so combinations() yields pairs in canonical order and the joined key
# is always 'a,b', never 'b,a'.
supp = {}
frequent_items = supps[0]

for basket in lines:
    for pair in itertools.combinations(basket, 2):
        if pair[0] in frequent_items and pair[1] in frequent_items:
            key = ','.join(pair)
            supp[key] = supp.get(key, 0) + 1

f_supp = {k: v for k, v in supp.items() if v >= SUPP_THRESHOLD}

# Store as level 2 (index 1). Slice assignment keeps the cell idempotent on
# re-run instead of appending a duplicate level.
supps[1:] = [f_supp]

In [14]:
# Count triples of items purchased together. A triple is a candidate only if
# all three of its sub-pairs are frequent (Apriori pruning). Transactions are
# sorted, so both the triple and its sub-pairs come out in canonical key order.
supp = {}
frequent_pairs = supps[1]

for basket in lines:
    for triple in itertools.combinations(basket, 3):
        if all(','.join(pair) in frequent_pairs
               for pair in itertools.combinations(triple, 2)):
            key = ','.join(triple)
            supp[key] = supp.get(key, 0) + 1

f_supp = {k: v for k, v in supp.items() if v >= SUPP_THRESHOLD}

# Store as level 3 (index 2). Slice assignment keeps the cell idempotent on
# re-run instead of appending a duplicate level.
supps[2:] = [f_supp]

In [15]:
def measures(supp_ab, supp_a, supp_b, transaction_count):
    """Interestingness measures for the association rule a -> b.

    Parameters
    ----------
    supp_ab : int
        Number of transactions containing every item of a and of b.
    supp_a : int
        Number of transactions containing the antecedent a (must be > 0).
    supp_b : int
        Number of transactions containing the consequent b.
    transaction_count : int
        Total number of transactions.

    Returns
    -------
    list
        ``[conf, sup, lift, conv]`` where ``conf = supp_ab / supp_a``,
        ``sup`` is the relative support of the consequent b (not of the whole
        rule), ``lift = conf / sup`` and ``conv = (1 - sup) / (1 - conf)``.
        ``conv`` is ``inf`` when the rule always holds (``conf == 1``).
    """
    # Plain Python float (double-precision) division. The former np.float32
    # casts made the precision depend on the NumPy version: under NumPy >= 2
    # (NEP 50) float32 / int stays float32, so confidences were rounded to ~7
    # significant digits before being compared against the threshold.
    conf = supp_ab / supp_a
    sup = supp_b / transaction_count
    lift = conf / sup

    if conf == 1.0:
        # Conviction is undefined (division by zero) for a rule that always holds.
        conv = float('inf')
    else:
        conv = (1 - sup) / (1 - conf)

    return [conf, sup, lift, conv]

In [16]:
# Generate association rules a -> b from every frequent itemset of size >= 2.
# supps[k] holds the frequent itemsets of size k + 1, so rules start at the
# pair level (supps[1]).
transaction_count = len(lines)

rules = []

for level in supps[1:]:
    for itemset_key, itemset_count in level.items():
        items = itemset_key.split(',')
        # Every non-empty proper subset of the itemset is a candidate antecedent.
        for j in range(1, len(items)):
            for a in itertools.combinations(items, j):
                b = tuple(w for w in items if w not in a)
                # Subsets of a frequent itemset are frequent themselves, so
                # both lookups below are guaranteed to hit.
                conf, _, lift, conv = measures(itemset_count,
                                               supps[len(a) - 1][','.join(a)],
                                               supps[len(b) - 1][','.join(b)],
                                               transaction_count)
                if conf >= CONF_THRESHOLD:
                    rules.append((a, b, conf, lift, conv))

# Sort once, after all rules are collected (previously two full sorts ran for
# every itemset). Order: highest confidence first, ties broken by
# (antecedent, consequent) ascending -- identical to the former stable
# two-pass sort.
rules = sorted(rules, key=lambda rule: (-rule[2], rule[0], rule[1]))

In [35]:
list(rules)  # (antecedent, consequent, confidence, lift, conviction), best confidence first

[(('citrus_fruit', 'root_vegetables'),
  ('other_vegetables',),
  0.5862068965517241,
  3.0296084222733612,
  1.9490594814438227),
 (('root_vegetables', 'tropical_fruit'),
  ('other_vegetables',),
  0.5845410628019324,
  3.020999134344196,
  1.941244487532661),
 (('butter', 'other_vegetables'),
  ('whole_milk',),
  0.5736040609137056,
  2.244884973770909,
  1.745992204711066),
 (('root_vegetables', 'tropical_fruit'),
  ('whole_milk',),
  0.5700483091787439,
  2.2309690094599866,
  1.7315526410492221),
 (('root_vegetables', 'yogurt'),
  ('whole_milk',),
  0.562992125984252,
  2.2033535849801504,
  1.7035939854445192),
 (('domestic_eggs', 'other_vegetables'),
  ('whole_milk',),
  0.5525114155251142,
  2.1623357627097084,
  1.663693804924105),
 (('whipped/sour_cream', 'yogurt'),
  ('whole_milk',),
  0.5245098039215687,
  2.052747282757114,
  1.5657188978977878),
 (('rolls/buns', 'root_vegetables'),
  ('whole_milk',),
  0.5230125523012552,
  2.04688756541299,
  1.5608041455953048),
 (('oth