In [8]:
import bisect
import json

import faiss
import jsonlines
import pandas as pd

from code.fasttext.embedding_utils import TableEncoder
from code.fasttext.embedding_utils import np_cosine_similarity
from code.utils.settings import DefaultPath
from code.utils.utils import rebuild_table

### Create index and LUTs

In [77]:
d = 300  # fastText vectors are 300D

# Table encoder loaded from the fastText model binary.
tabenc = TableEncoder(
    model_path=DefaultPath.model_path.fasttext + 'ft_cc.en.300_freqprune_400K_100K_pq_300.bin'
)

# Two independent flat L2 indexes of dimension d:
# one collects row vectors, the other column vectors.
row_index, column_index = faiss.IndexFlatL2(d), faiss.IndexFlatL2(d)

In [78]:
# Each LUT is a pair of parallel lists:
#   [0] steps     -> for every indexed table, the index-wide id of its last vector
#   [1] table IDs -> the '_id' of the table that owns those vectors
row_LUT = [[], []]      # look-up table for the row index
column_LUT = [[], []]   # look-up table for the column index

In [79]:
def get_table_from_jsonl(table_id):
    """Return the rebuilt table whose '_id' field equals `table_id`.

    The file 'sloth_tables.jsonl' is scanned sequentially from its first record
    until a match is found.

    Parameters
    ----------
    table_id
        Value compared against the '_id' field of every JSON record.

    Returns
    -------
    The result of `rebuild_table` applied to the matching record.

    Raises
    ------
    EOFError
        If the end of the file is reached without finding `table_id`.
    """
    tables_path = DefaultPath.data_path.wikitables + 'sloth_tables.jsonl'
    with jsonlines.open(tables_path) as reader:
        for json_table in reader:
            if json_table['_id'] == table_id:
                return rebuild_table(json_table)
    # Keep the exception type that an exhausted reader.read() used to raise,
    # but say which id was missing and where it was searched.
    raise EOFError(f"table id {table_id!r} not found in {tables_path}")

In [80]:
n_tables_to_read = 100
tables = []

# The indexes and the LUTs are created in earlier cells: empty them here so that
# re-running this cell rebuilds them instead of appending duplicates.
row_index.reset()
column_index.reset()
for lut in (row_LUT, column_LUT):
    lut[0].clear()
    lut[1].clear()

with jsonlines.open(DefaultPath.data_path.wikitables + 'sloth_tables.jsonl') as reader:
    for i, json_table in enumerate(reader):
        if i >= n_tables_to_read:
            break
        table = rebuild_table(json_table)
        tables.append(table)
        row_embeddings, column_embeddings = tabenc.full_embedding(table, add_label=True)

        # Every LUT step is the index-wide id of the last vector of the table, so
        # bisect.bisect_left(steps, vector_id) is the position of the owning table.
        # Reading it from ntotal right after add() keeps the LUT aligned with the index.
        row_index.add(row_embeddings)
        row_LUT[0].append(row_index.ntotal - 1)
        row_LUT[1].append(json_table['_id'])

        column_index.add(column_embeddings)
        column_LUT[0].append(column_index.ntotal - 1)
        column_LUT[1].append(json_table['_id'])

In [81]:
row_index.ntotal, column_index.ntotal

(1039, 709)

### Saving index and LUTs

In [7]:
# The two index files and the LUT file share the same path prefix.
_faiss_db_path = DefaultPath.db_path.faiss

ROW_INDEX_FILEPATH = _faiss_db_path + 'row_index.index'
COLUMN_INDEX_FILEPATH = _faiss_db_path + 'column_index.index'
LUT_FILEPATH = _faiss_db_path + 'lut.json'

In [9]:
# Write the two FAISS indexes to disk...
faiss.write_index(row_index, ROW_INDEX_FILEPATH)
faiss.write_index(column_index, COLUMN_INDEX_FILEPATH)

# ...and store both look-up tables together in a single JSON document.
luts_payload = {'row_LUT': row_LUT, 'column_LUT': column_LUT}
with open(LUT_FILEPATH, 'w') as lut_file:
    json.dump(luts_payload, lut_file)


### Reload index and LUTs

In [10]:
# Load the saved indexes under new names so they can be compared with the in-memory ones.
row_index_2 = faiss.read_index(ROW_INDEX_FILEPATH)
column_index_2 = faiss.read_index(COLUMN_INDEX_FILEPATH)

# `luts` is the dict written above, with keys 'row_LUT' and 'column_LUT'.
with open(LUT_FILEPATH) as lut_file:
    luts = json.load(lut_file)

In [12]:
row_index.ntotal, row_index_2.ntotal

(75, 75)

In [13]:
column_index.ntotal, column_index_2.ntotal

(38, 38)

In [14]:
luts

{'row_LUT': [[18, 28, 39, 64, 74],
  ['27283419-8', '27284790-1', '27285187-1', '27285645-2', '27285838-1']],
 'column_LUT': [[6, 16, 26, 32, 37],
  ['27283419-8', '27284790-1', '27285187-1', '27285645-2', '27285838-1']]}

### Retrieving vectors by id

In [83]:
vid = 0 # --> first row of the first table in 'tables'
# Fetch the stored vector by its id; use `vid` rather than repeating the literal
vec = row_index.reconstruct(vid)
vec.shape

(300,)

In [85]:
# Re-embed the first table and compare its first row with the vector reconstructed from the index
row_embeddings, column_embeddings = tabenc.full_embedding(tables[0], add_label=True)
np_cosine_similarity(vec, row_embeddings[0])   # 1.0 --> the stored vector points in the same direction as the fresh embedding

1.0

### Sanity check

In [23]:
row_LUT

[[18, 28, 39, 64, 74],
 ['27283419-8', '27284790-1', '27285187-1', '27285645-2', '27285838-1']]

In [30]:
# First table that was added to the indexes; show its first rows
tab0 = tables[0]
tab0.head()

Unnamed: 0,Pos,Driver,MNZ,NUR,ADR,VAL,Pts
0,1,Giuliano Alessi,3,2.0,1,1,75
1,2,Alessandro Balzan,Ret,1.0,3,3,49
2,3,Alessandro Battaglin,2,,2,4,41
3,4,Kristian Ghedina,5,8.0,6,2,39
4,5,Max Pigoli,1,3.0,Ret,10,35


In [44]:
# Embed tab0 again with the same settings used when the indexes were built
t0_row_embeddings, t0_col_embeddings = tabenc.full_embedding(tab0, add_label=True)

In [45]:
# For each row of tab0 fetch the 4 closest indexed rows: D holds the distances, I the vector ids
D, I = row_index.search(t0_row_embeddings, 4)

In [46]:
I

array([[ 0, 47, 48, 56],
       [ 1,  2,  4,  0],
       [ 2,  1,  4, 43],
       [ 3, 43,  2,  6],
       [ 4,  1,  2,  6],
       [ 5,  7, 12, 11],
       [ 6, 10,  7,  5],
       [ 7, 10,  5,  6],
       [ 8,  9, 13, 12],
       [ 9,  8, 12, 11],
       [10,  7,  5,  6],
       [11, 12,  8, 13],
       [12, 11,  8, 13],
       [13,  8, 12, 11],
       [14, 15, 13,  8],
       [15, 14, 17, 16],
       [16, 17, 15, 18],
       [17, 16, 15, 18],
       [18, 17, 16, 73]])

In [47]:
D

array([[0.        , 0.09946109, 0.10641481, 0.13238117],
       [0.        , 0.02899076, 0.09312305, 0.14556909],
       [0.        , 0.02899076, 0.11655647, 0.14818348],
       [0.        , 0.1378604 , 0.16162254, 0.16894306],
       [0.        , 0.09312305, 0.11655647, 0.1329234 ],
       [0.        , 0.04215686, 0.09135455, 0.0938854 ],
       [0.        , 0.10489605, 0.1061948 , 0.11966924],
       [0.        , 0.02832747, 0.04215686, 0.1061948 ],
       [0.        , 0.01623926, 0.02424421, 0.02504677],
       [0.        , 0.01623926, 0.03399125, 0.03717704],
       [0.        , 0.02832747, 0.09542399, 0.10489605],
       [0.        , 0.00856783, 0.02585012, 0.03472276],
       [0.        , 0.00856783, 0.02504677, 0.03047948],
       [0.        , 0.02424421, 0.03047948, 0.03472276],
       [0.        , 0.07427008, 0.13655701, 0.15150769],
       [0.        , 0.07427008, 0.09341952, 0.09413615],
       [0.        , 0.00054896, 0.09413615, 0.09648809],
       [0.        , 0.00054896,

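The sanity check passes: the nearest neighbour of every row of the first table is the row itself (first column of `I`), at distance 0 (first column of `D`).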

### Search example

In [48]:
new_tables = []
with jsonlines.open(DefaultPath.data_path.wikitables + 'sloth_tables.jsonl') as reader:
    # Skip the first n_tables_to_read records (the ones read when building the indexes)...
    for i in range(n_tables_to_read):
        reader.read()
    # ...then rebuild the 5 records that follow them.
    new_tables.extend(rebuild_table(reader.read()) for _ in range(5))

In [52]:
new_tables[2]

Unnamed: 0,Year,Round,Pld,W,D,L,GS,GA
0,2000,Champions,5,4,1,0,7,1
1,2002,Third Place,4,1,2,1,4,3
2,2004,Champions,4,4,0,0,17,3
3,2007,Champions,4,3,1,0,5,1
4,2008,Champions,4,4,0,0,13,2
5,2010,RunnerUp,4,2,1,1,8,5
6,2012,Group Stage,3,1,2,0,2,1
7,Total,7 Titles,28,19,7,2,56,16


In [53]:
# Embed a table that was read after the indexed ones and use it as the query
query_row_emb, query_col_emb = tabenc.full_embedding(new_tables[2], add_label=True)

In [54]:
# For each query row fetch the 4 closest indexed rows (overwrites D and I from the sanity check)
D, I = row_index.search(query_row_emb, 4)

In [55]:
I

array([[45, 42, 51, 50],
       [42, 48, 45, 44],
       [45, 42, 52, 51],
       [42, 45, 51, 50],
       [45, 42, 51, 50],
       [46, 44, 48, 47],
       [42, 51, 50, 56],
       [ 3, 41, 43,  6]])

In [56]:
D

array([[0.20610723, 0.21844062, 0.22837846, 0.23063575],
       [0.23178735, 0.25529832, 0.2558034 , 0.26046154],
       [0.27518916, 0.27648422, 0.35642472, 0.35974297],
       [0.19421408, 0.20366746, 0.23187867, 0.2337852 ],
       [0.2730953 , 0.27647525, 0.3351663 , 0.3358405 ],
       [0.22735554, 0.23806359, 0.24144514, 0.24659696],
       [0.22766678, 0.23879595, 0.23965047, 0.25475562],
       [0.3527736 , 0.48265165, 0.49419552, 0.4999259 ]], dtype=float32)

In [70]:
# Map a vector id back to its table: row_LUT[0] holds, per table, the id of its
# last vector, so bisect_left finds the first table whose last id is >= the query id.
lut_position = bisect.bisect_left(row_LUT[0], 3)
table_id = row_LUT[1][lut_position]
table_id

'27283419-8'

In [71]:
get_table_from_jsonl(table_id)

Unnamed: 0,Pos,Driver,MNZ,NUR,ADR,VAL,Pts
0,1,Giuliano Alessi,3,2,1,1,75
1,2,Alessandro Balzan,Ret,1,3,3,49
2,3,Alessandro Battaglin,2,,2,4,41
3,4,Kristian Ghedina,5,8,6,2,39
4,5,Max Pigoli,1,3,Ret,10,35
5,6,Marco Gregori,4,,,5,18
6,7,Francesco Ascani,6,5,Ret,7,18
7,8,Mauro Simoncini,,7,4,,14
8,9,Moreno Petrini,Ret,4,Ret,,10
9,10,Corrado Canneori,,,5,,8
