In [12]:
###############################################################################
# Author: Carlos Bobed
# Date: Sept 2020
# Comments: Code to obtain a word embedding model out from a transactional database
# in Vreeken et al. database format
# Modifications:
###############################################################################

import gensim, logging, os, sys, gzip
import time

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',filename='word2vec.out', level=logging.INFO)

## Adapted

class MySentences(object): 
    def __init__(self, filename):
        self.filename = filename

    def __iter__(self):
        try: 
            for line in open(self.filename, mode='rt', encoding='UTF-8'): 
                if (line.split(':')[0].isnumeric()): 
                    aux = line.split(':')[1].rstrip('\n')
                    words = filter(None,aux.split(' '))
                    yield list(words)
        except Exception: 
            print ('Failed reading file: ')
            print (self.filename)

def custom_order (item_list): 
    atEnd = False; 
    result = []
    for it in sorted(item_list): 
        if not atEnd: 
            result.insert(0,it)
        else:
            result.append(it)
        atEnd = not atEnd
    return result

# THis method assumes that the database items are ordered 
# according to their support, this is, is a conversed database (vreeken's format)
# 0 => highest suppport
class MyOrderedSentences(object): 
    def __init__(self, filename):
        self.filename = filename

    def __iter__(self):
        try: 
            for line in open(self.filename, mode='rt', encoding='UTF-8'): 
                if (line.split(':')[0].isnumeric()): 
                    aux = line.split(':')[1].rstrip('\n')
                    words = filter(None,aux.split(' '))
                    ints = [int(it) for it in words]
                    words = [str(it) for it in custom_order(ints)]
                    yield list(words)
        except Exception: 
            print ('Failed reading file: ')
            print (self.filename)

            
class MySentencesDat(object): 
    def __init__(self, filename):
        self.filename = filename

    def __iter__(self):
        try: 
            for line in open(self.filename, mode='rt', encoding='UTF-8'): 
                aux = line.rstrip('\n')
                words = filter(None,aux.split(' '))
                yield list(words)
        except Exception: 
            print ('Failed reading file: ')
            print (self.filename)
            
            
## method to read for the Vreeken's codetable format 
## we don't need it to be a generator 
## we do label each code and honor the order in the code table (length, support, lexicographical)
## following Pierre's suggestion, we keep track of the codes and the transaction IDs 
def read_codetable(filename, load_all): 
    codes = {}
    label = 0 
    with open(filename, mode='rt', encoding='UTF-8') as file: 
        for line in file: 
            item_line = list(filter(None, line.rstrip('\n').split(' ')))
            ## only_used => those codes whose usage is > 0
            ## we get the last token, check whether it ends with )
            ## then, we get exactly the contents and check whether the first 
            ## component is different from 0
            if (item_line[-1].endswith(')')): 
                usage,support = item_line[-1][1:-1].split(',')
                if (load_all or int(usage) != 0):
                    codes[label]={'code': item_line[:-1], 'usage':int(usage), 'support':int(support)}
                    label+=1
    return codes        

# def read_codetable_supports(filename): 
#     codes = {}
#     label = 0 
#     with open(filename, mode='rt', encoding='UTF-8') as file: 
#         for line in file: 
#             item_line = list(filter(None, line.rstrip('\n').split(' ')))
#             ## only_used => those codes whose usage is > 0
#             ## we get the last token, check whether it ends with )
#             ## then, we get exactly the contents and check whether the first 
#             ## component is different from 0
#             if (item_line[-1].endswith(')')): 
#                 usage,support = item_line[-1][1:-1].split(',')
#                 codes[label] = {'usage':usage, 'support':support}
#                 label+=1
#     return codes     


## to keep track of the transactions id in the database, we read them in a different method
def read_database_db (filename): 
    transactions = {}
    label = 0
    with open(filename, mode='rt', encoding='UTF-8') as file: 
        for line in open(filename, mode='rt', encoding='UTF-8'): 
            if (line.split(':')[0].isnumeric()): 
                aux = line.split(':')[1].rstrip('\n')
                words = filter(None,aux.split(' '))
                transactions[label] = list(words)
                label+=1
    return transactions

def read_database_dat(filename): 
    transactions = {}
    label = 0
    with open(filename, mode='rt', encoding='UTF-8') as file: 
        for line in open(filename, mode='rt', encoding='UTF-8'): 
            aux = line.rstrip('\n')
            words = filter(None,aux.split(' '))
            transactions[label] = list(words)
            label+=1
    return transactions

In [13]:
a = [23,49,0,9,3,12, 74]
b = []
atEnd = False; 
for it in sorted(a): 
    if not atEnd: 
        b.insert(0,it)
    else:
        b.append(it)
    atEnd = not atEnd
print(f'{a} {b} {custom_order(a)}')

[23, 49, 0, 9, 3, 12, 74] [74, 23, 9, 0, 3, 12, 49] [74, 23, 9, 0, 3, 12, 49]


In [15]:
database_name = 'connect'
database_filename = database_name + '.db'
database_model_filename = database_filename + '.vect'
codetable_filename = database_name+'-latest-SLIM.ct'

In [16]:
# for i in read_codetable(os.path.join('.', 'databases', 'chess-latest-SLIM.ct')): 
#     print (' '.join(i))
print (len(read_codetable(os.path.join('.', 'databases', codetable_filename), True)))
print (len(read_codetable(os.path.join('.', 'databases', codetable_filename), False)))

code_table = read_codetable(os.path.join('.', 'databases', codetable_filename), False)
database_transactions = read_database_dat(os.path.join('.', 'databases', database_filename))

1799
1672


In [17]:
[x for x in sorted(code_table.values(), key=lambda x: x['support'], reverse=True) if x['usage'] != 0]

[{'code': ['0', '1', '2', '3', '7', '9', '10', '16', '23', '30', '39'],
  'usage': 210,
  'support': 28135},
 {'code': ['11', '18', '25', '32', '35'], 'usage': 83, 'support': 26337},
 {'code': ['41'], 'usage': 14, 'support': 25889},
 {'code': ['4', '5', '6', '8', '12', '19', '26', '33', '42'],
  'usage': 108,
  'support': 23869},
 {'code': ['46'], 'usage': 12, 'support': 19896},
 {'code': ['38', '50'], 'usage': 13, 'support': 12257},
 {'code': ['47', '48'], 'usage': 13, 'support': 10808},
 {'code': ['11', '18', '25', '35', '54'], 'usage': 269, 'support': 10532},
 {'code': ['0',
   '1',
   '2',
   '3',
   '7',
   '9',
   '10',
   '11',
   '16',
   '18',
   '23',
   '25',
   '30',
   '32',
   '35',
   '39'],
  'usage': 359,
  'support': 9808},
 {'code': ['13', '14', '15', '20', '22', '27', '29', '34', '37', '44'],
  'usage': 6534,
  'support': 9312},
 {'code': ['30', '52'], 'usage': 4, 'support': 9171},
 {'code': ['4', '5', '6', '8', '12', '19', '26', '33', '45'],
  'usage': 158,
  'supp

In [18]:
def calculate_cover(transaction, code_table): 
    item_set = set(transaction)
    codes = []
    current_code = 0
    while (len(item_set) != 0 and current_code < len(code_table) ):
        aux_code_set = set(code_table[current_code]['code'])
        if (aux_code_set.issubset(item_set)): 
            codes.append(current_code)
            item_set.difference_update(aux_code_set)
        current_code+=1
    return codes

In [19]:
trans = database_transactions[0]

c = calculate_cover (trans, code_table)
print (trans)
print (f'with length: {len(trans)}')
print ('codified as: ')
for aux in c: 
    print (f"{aux} --> {code_table[aux]} --> len == {len(code_table[aux]['code'])}")

['fic-1.5']
with length: 1
codified as: 


In [20]:
# testing the behaviour
sentences = MyOrderedSentences(os.path.join('.', 'databases', database_filename))
for sent in sentences: 
    print (sent)

['118', '93', '64', '47', '42', '38', '35', '33', '30', '28', '26', '23', '21', '19', '16', '14', '12', '9', '7', '5', '2', '0', '1', '4', '6', '8', '11', '13', '15', '18', '20', '22', '25', '27', '29', '32', '34', '37', '39', '44', '56', '80', '108']
['118', '93', '64', '47', '42', '38', '36', '33', '30', '28', '26', '23', '21', '19', '16', '14', '12', '9', '7', '5', '2', '0', '1', '4', '6', '8', '11', '13', '15', '18', '20', '22', '25', '27', '29', '32', '34', '37', '39', '44', '49', '80', '108']
['118', '93', '64', '46', '39', '37', '35', '33', '30', '28', '26', '23', '21', '19', '16', '14', '12', '9', '7', '5', '2', '0', '1', '4', '6', '8', '11', '13', '15', '18', '20', '22', '25', '27', '29', '32', '34', '36', '38', '42', '47', '80', '108']
['118', '93', '64', '47', '42', '38', '35', '33', '30', '28', '26', '23', '21', '19', '16', '14', '12', '9', '7', '5', '2', '0', '1', '4', '6', '8', '11', '13', '15', '18', '20', '22', '25', '27', '29', '32', '34', '36', '39', '44', '53', '80',

['65', '61', '47', '45', '37', '35', '32', '28', '26', '24', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '23', '25', '27', '29', '34', '36', '38', '46', '52', '64']
['66', '62', '52', '46', '38', '36', '34', '28', '26', '24', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '23', '25', '27', '29', '35', '37', '45', '47', '54', '64']
['81', '64', '52', '46', '38', '36', '34', '29', '27', '25', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '24', '26', '28', '33', '35', '37', '45', '47', '54', '66']
['81', '64', '52', '46', '38', '35', '33', '29', '27', '25', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '24', '26', '28', '32', '34', '37', '42', '47', '58', '65']
['81

['79', '71', '63', '53', '47', '42', '34', '30', '27', '25', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '23', '26', '29', '33', '39', '44', '48', '56', '70', '74']
['79', '70', '56', '47', '44', '37', '33', '30', '27', '25', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '23', '26', '29', '32', '34', '39', '45', '48', '63', '71']
['79', '71', '63', '51', '47', '42', '34', '30', '27', '25', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '23', '26', '29', '33', '37', '44', '48', '56', '70', '74']
['92', '73', '63', '54', '47', '42', '37', '33', '29', '27', '23', '21', '19', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '20', '22', '26', '28', '30', '34', '39', '44', '48', '56', '70', '79']
['79

['85', '69', '58', '48', '44', '37', '34', '31', '27', '25', '23', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '22', '24', '26', '29', '33', '35', '43', '47', '52', '65', '74']
['85', '69', '49', '47', '43', '37', '33', '31', '29', '26', '23', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '22', '24', '27', '30', '32', '34', '39', '44', '48', '58', '73']
['92', '73', '58', '48', '44', '39', '34', '31', '29', '26', '23', '21', '19', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '20', '22', '24', '27', '30', '33', '37', '43', '47', '49', '69', '74']
['73', '58', '48', '46', '39', '34', '32', '30', '27', '24', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '23', '26', '29', '31', '33', '37', '43', '47', '49', '69']
['73

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



['109', '82', '66', '51', '44', '40', '37', '34', '31', '28', '26', '22', '20', '18', '15', '13', '11', '8', '6', '4', '2', '0', '1', '3', '5', '7', '10', '12', '14', '17', '19', '21', '24', '27', '29', '33', '36', '38', '42', '49', '54', '78', '97']
['82', '66', '56', '46', '40', '37', '33', '31', '28', '25', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '24', '26', '29', '32', '35', '38', '42', '51', '59', '75']
['115', '82', '66', '54', '44', '40', '37', '34', '31', '27', '25', '22', '20', '18', '15', '13', '11', '8', '6', '4', '2', '0', '1', '3', '5', '7', '10', '12', '14', '17', '19', '21', '24', '26', '29', '33', '35', '38', '42', '51', '56', '69', '95']
['97', '74', '59', '46', '40', '37', '35', '31', '28', '26', '24', '21', '19', '17', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '18', '20', '22', '25', '27', '29', '33', '36', '38', '43', '51', '65', '82']


['81', '60', '52', '49', '41', '36', '32', '29', '27', '25', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '24', '26', '28', '31', '33', '37', '43', '50', '55', '65']
['84', '65', '55', '52', '42', '36', '33', '31', '27', '25', '23', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '24', '26', '28', '32', '35', '41', '50', '53', '60', '67']
['72', '60', '53', '50', '41', '35', '32', '30', '26', '24', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '23', '25', '28', '31', '33', '36', '45', '52', '55', '67']
['83', '64', '55', '53', '45', '39', '35', '30', '27', '25', '23', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '24', '26', '28', '33', '36', '41', '50', '54', '60', '67']
['72

['78', '57', '49', '43', '40', '33', '31', '29', '27', '24', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '23', '26', '28', '30', '32', '34', '41', '48', '51', '58']
['72', '58', '52', '45', '40', '33', '31', '29', '26', '24', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '23', '25', '28', '30', '32', '35', '41', '48', '57', '59']
['75', '59', '53', '48', '41', '35', '32', '30', '26', '24', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '23', '25', '28', '31', '33', '40', '42', '51', '58', '67']
['81', '59', '57', '48', '41', '35', '32', '29', '27', '25', '22', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '24', '26', '28', '31', '33', '40', '42', '51', '58', '66']
['69

['95', '66', '54', '48', '41', '37', '34', '31', '28', '26', '24', '21', '19', '17', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '18', '20', '22', '25', '27', '29', '33', '35', '40', '45', '52', '56', '82']
['83', '67', '53', '48', '41', '36', '34', '32', '28', '26', '24', '20', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '19', '21', '25', '27', '31', '33', '35', '40', '42', '52', '65', '81']
['88', '74', '57', '48', '41', '36', '34', '30', '28', '25', '23', '21', '18', '16', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '17', '20', '22', '24', '27', '29', '31', '35', '40', '43', '51', '61', '76']
['95', '66', '53', '48', '41', '35', '33', '31', '28', '26', '24', '21', '19', '17', '14', '12', '10', '8', '6', '4', '2', '0', '1', '3', '5', '7', '9', '11', '13', '15', '18', '20', '22', '25', '27', '29', '32', '34', '40', '42', '52', '58', '81']
['75

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [21]:
import time

start = time.time()

dir_name = 'databases'
dim = 200
win = 10
epochs = 10

# for fname in os.listdir(os.path.join('.', dir_name)): 
#     if (fname.endswith('.db')): 
#         sentences = MySentences(os.path.join(dir_name, fname))
#         out_name = fname+'.vect'
#         #we force min_count to 1 in order not to miss any item
#         model = gensim.models.Word2Vec(size=dim, workers=12, window=win, sg=1, negative=15, iter=epochs, min_count=1)
#         model.build_vocab(sentences, progress_per=10000)
#         start_time = time.time()
#         model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs, report_delay=1)
#         end_time = time.time()

#         model.save(os.path.join(dir_name,out_name))
#         with open(os.path.join(dir_name,out_name+'-times'), 'w+') as out:
#             print (str(end_time-start_time))
#             out.write(' time to train: '+str(end_time-start_time))
#     else: 
#         print('skipping '+fname)

## create all the vects for a particular format
# for fname in [f for f in os.listdir('databases') if f.endswith('.dat')]: 
for fname in [database_filename]: 
    if (fname.endswith('.db') or fname.endswith('.dat')): 
        
        if (fname.endswith('.db')): 
            sentences = MyOrderedSentences(os.path.join(dir_name, fname))
        else: 
            sentences = MySentencesDat(os.path.join(dir_name,fname))
            
        out_name = fname+'_200d_5_10.vect'
        #we force min_count to 1 in order not to miss any item
        model = gensim.models.Word2Vec(size=dim, workers=12, window=win, sg=1, negative=15, iter=epochs, min_count=1)
        model.build_vocab(sentences, progress_per=10000)
        start_time = time.time()
        model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs, report_delay=1)
        end_time = time.time()

        model.save(os.path.join(dir_name,out_name))
        with open(os.path.join(dir_name,out_name+'-times'), 'w+') as out:
            print (str(end_time-start_time))
            out.write(' time to train: '+str(end_time-start_time))
    else: 
        print('skipping '+fname)
        
end = time.time()
print(end - start)


75.22903656959534
76.6860580444336


FROM HERE ONWARDS ONLY FOR DATABASES WITH BOTH WORD VECTORS AND A CODE TABLE ALREADY CALCULATED

In [28]:
from gensim.models import Word2Vec
# from sklearn.manifold import TSNE
from tsnecuda import TSNE


model = Word2Vec.load(os.path.join(dir_name, database_model_filename))
labelled_vects = {int(x): model.wv[x] for x in model.wv.vocab}

In [29]:
def calculate_centroids (model, labelled_transactions): 

#     initial code
#     centroids = []
#     for transaction in transaction_list: 
#         words = [model.wv[it] for it in transaction]
#         centroids.append(np.mean(words, axis=0))
#     return centroids

    # more pythonic way
    return { label: np.mean([model.wv[it] for it in  labelled_transactions[label]], axis=0) for label in labelled_transactions}

In [30]:
a ={0:'a', 1:'b'}
for i in a.values(): 
    print(i)
    
for i in enumerate([1,2,3]): 
    print(i)

a
b
(0, 1)
(1, 2)
(2, 3)


In [31]:
import numpy as np
database_transactions = read_database(os.path.join('.', 'databases', database_filename))
# aux_centroids = {}
# label = 0
# for current_trans in database_transactions: 
#     words = [model.wv[it] for it in database_transactions[current_trans]]
#     aux_centroids[label] = np.mean(words,axis=0)
#     label+=1
    
centroids = calculate_centroids(model, database_transactions)

NameError: name 'read_database' is not defined

In [39]:
# print (all([np.array_equal(aux_centroids[i], centroids[i]) for i in database_transactions]))    
# print (aux_centroids[1][:10]) 
# print (centroids[1][:10])

In [40]:
codes = read_codetable(os.path.join('.', 'databases', codetable_filename), False)
centroids_codes = calculate_centroids(model, {label:codes[label]['code'] for label in codes})

KeyError: "word '0' not in vocabulary"

In [41]:
# there are a number of blogs and people that suggest reducing the dimensions of the space 
# before using t-SNE due to its cost
# one way is to use PCA 

# taken from https://towardsdatascience.com/visualising-high-dimensional-datasets-using-pca-and-t-sne-in-python-8ef87e7915b

# pca = PCA(n_components=3)
# pca_result = pca.fit_transform(df[feat_cols].values)
# df['pca-one'] = pca_result[:,0]
# df['pca-two'] = pca_result[:,1] 
# df['pca-three'] = pca_result[:,2]
# print('Explained variation per principal component: {}'.format(pca.explained_variance_ratio_))
# Explained variation per principal component: [0.09746116 0.07155445 0.06149531]

# besides
# “Since t-SNE scales quadratically in the number of objects N, its applicability is limited to data 
# sets with only a few thousand input objects; beyond that, learning becomes too slow to be practical 
# (and the memory requirements become too large)”.

# However, I have to say that I've tried it with some tens of thousands 
# (visualizing the item embeddings of DBpedia 2016 conversion) and it was somehow bearable (can't recall the exact time it took, 
# but it ended ... which is nice :P) 

# for 'connect database' 65K transactions (and a non-very optimized version of t-SNE - there are multithreaded ones) 
# it has taken about 5 minutes (MSI-laptop) 

In [42]:
print(all([(x in labelled_vects) for x in range(len(labelled_vects))]))
print(all([(x in database_transactions) for x in range(len(database_transactions))]))
print(all([(x in code_table) for x in range(len(code_table))]))
print(all([(x in centroids_codes) for x in range(len(centroids_codes))]))
print(all([(x in centroids) for x in range(len(centroids))]))


False
True
True


NameError: name 'centroids_codes' is not defined

In [None]:
test = TSNE(n_components=2).fit_transform(np.array([labelled_vects[x] for x in range(len(labelled_vects))]))

In [None]:
top_position = len(centroids)
# 25000 if len(centroids)>5000 else len(centroids)
## 0 .. len(labelled_vects)-1 => projected vectors of items
## len(labelled_vects) .. len(labelled_vects) + len(centroids_codes) -1 => projected vectors of centroids of codes
## rest => transactions 
start_centroids_codes = len(labelled_vects)
start_centroids = len(labelled_vects) + len(centroids_codes)
positions_to_draw = np.array([labelled_vects[x] for x in range(len(labelled_vects))] + [centroids_codes[i] for i in range(len(centroids_codes))] + [centroids[i] for i in range(top_position)])

## We have to transform all of the ones that are going to be displayed at once
Z = TSNE(n_components=2).fit_transform(positions_to_draw)

In [None]:
## sanity check 
print(len(labelled_vects))
print(len(centroids_codes))
print(len(centroids))
print(np.array_equal(labelled_vects[0], positions_to_draw[0]))
print(np.array_equal(centroids_codes[0], positions_to_draw[start_centroids_codes]))
print(np.array_equal(centroids[0], positions_to_draw[start_centroids]))

In [None]:
codification = {}
for trans in database_transactions: 
    codification[trans] = calculate_cover(database_transactions[trans], code_table)


In [None]:
#Plot helpers
import matplotlib
import matplotlib.pyplot as plt
#Enable matplotlib to be interactive (zoom etc)
# %matplotlib notebook

In [None]:
a=[1,2,3,4,5]
print(a[:2])

In [None]:
fig_size = 5
fig, (ax1,ax2) = plt.subplots(1,2)
fig.set_size_inches(fig_size*2,fig_size, forward=True)
ax1.plot(Z[:len(labelled_vects), 0], Z[:len(labelled_vects), 1], 'o')
ax1.set_title(f'Items - d{dim} w{win} e{epochs} - {database_name}')
ax1.set_yticklabels([]) #Hide ticks
ax1.set_xticklabels([]) #Hide ticks
#Transactions
ax2.plot(Z[start_centroids:, 0], Z[start_centroids:, 1], 'o')
print (len(Z[start_centroids:,0]))
print (len(Z[start_centroids_codes:start_centroids_codes+len(centroids_codes), 1]))
#codes
ax2.plot(Z[start_centroids_codes:start_centroids_codes+len(centroids_codes), 0], \
         Z[start_centroids_codes:start_centroids_codes+len(centroids_codes), 1], 'x')
ax2.set_title(f'Transactions - d{dim} w{win} e{epochs} - {database_name}')
ax2.set_yticklabels([]) #Hide ticks
ax2.set_xticklabels([]) #Hide ticks
plt.show()

In [None]:
highest_usage_code = [x for x in sorted(code_table, key=lambda y: code_table[y]['usage'], reverse=True)][:5]

print(highest_usage_code)
[print(f"U:{code_table[x]['usage']} S:{code_table[x]['support']} C:{code_table[x]['code']}") for x in highest_usage_code]

In [None]:
num_codes_to_draw = 10
num_highest_usage = 5

fig_size = 5
highest_usage_code = [x for x in sorted(code_table, key=lambda y: code_table[y]['usage'], reverse=True)][:num_highest_usage]
fig, ax = plt.subplots(num_codes_to_draw,1)
fig.set_size_inches(fig_size,num_codes_to_draw*fig_size, forward=True)
for i in range(num_codes_to_draw-num_highest_usage):     
    ## all the transactions 
    ax[i].plot(Z[start_centroids:, 0], Z[start_centroids:, 1], '3')
    ## the centroid of the current code
    ax[i].plot(Z[start_centroids_codes+i,0], Z[start_centroids_codes+i,1], 's')
    ## the transactions covered by the current code
    affected_x =[Z[start_centroids+cod, 0] for cod in codification if i in set(codification[cod]) and cod <top_position] 
    affected_y =[Z[start_centroids+cod, 1] for cod in codification if i in set(codification[cod]) and cod <top_position]
    ax[i].plot(affected_x, affected_y, 'o')
    ## the items of the current code
    code = code_table[i]['code']
    items_x = [Z[int(x),0] for x in code]
    items_y = [Z[int(x),1] for x in code]
    ax[i].plot(items_x, items_y, '*')
    ax[i].set_title(f"{database_name} - Influence code {i} - usage:{code_table[i]['usage']} - support:{code_table[i]['support']}")
    
for idx in enumerate(highest_usage_code): 
    ## all the transactions 
    pos = num_codes_to_draw-num_highest_usage+idx[0]
    ax[pos].plot(Z[start_centroids:, 0], Z[start_centroids:, 1], '3')
    ## the centroid of the current code
    ax[pos].plot(Z[start_centroids_codes+idx[1],0], Z[start_centroids_codes+idx[1],1], 's')
    ## the transactions covered by the current code
    affected_x =[Z[start_centroids+cod, 0] for cod in codification if idx[1] in set(codification[cod]) and cod <top_position] 
    affected_y =[Z[start_centroids+cod, 1] for cod in codification if idx[1] in set(codification[cod]) and cod <top_position]
    ax[pos].plot(affected_x, affected_y, 'o')
    ## the items of the current code
    code = code_table[idx[1]]['code']
    items_x = [Z[int(x),0] for x in code]
    items_y = [Z[int(x),1] for x in code]
    ax[pos].plot(items_x, items_y, '*')
    ax[pos].set_title(f"{database_name} - Influence code {idx[1]} - usage:{code_table[idx[1]]['usage']} - support:{code_table[idx[1]]['support']}")
    
    
plt.show()