In [1]:
import itertools

import faiss
import jsonlines
import numpy as np
import pandas as pd

from code.faiss.lut import LUT
from code.fasttext.embedding_utils import TableEncoder
from code.utils.settings import DefaultPath
from code.utils.utils import rebuild_table

## Sanity checks on v1_2 indexes

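Many tables share identical rows or columns, which yields L2 distances of exactly 0 between *different* tables. As a consequence, the nearest neighbour of a table's own row is not guaranteed to map back to that same table.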

In [2]:
# Table encoder built on the fastText model below; it embeds the rows and
# columns of every table queried in this notebook.
model_file = DefaultPath.model_path.fasttext + 'ft_cc.en.300_freqprune_400K_100K_pq_300.bin'
tabenc = TableEncoder(model_path=model_file)

In [36]:
# Sub-directory of the v1_2 index build under inspection.
INDEX_SUBDIR = 'v1_2/48_50000_add_label_True_L2/'
db_path = DefaultPath.db_path.faiss + INDEX_SUBDIR

In [37]:
# Load the row-level and the column-level FAISS indexes stored under db_path.
row_index, column_index = (
    faiss.read_index(db_path + fname)
    for fname in ('row_index.index', 'column_index.index')
)

In [38]:
# Number of vectors stored in the row index and in the column index.
row_index.ntotal, column_index.ntotal

(508855, 333343)

In [39]:
# Lookup tables that map a FAISS vector id back to the id of its source table,
# one for the row index and one for the column index.
row_lut, column_lut = LUT(), LUT()
for lut, lut_file in ((row_lut, 'row_lut.json'), (column_lut, 'column_lut.json')):
    lut.load(db_path + lut_file)
row_lut.ntotal, column_lut.ntotal

(50000, 50000)

In [16]:
# Query the row index with the rows of the first N_CHECK tables of the dump.
# The check assumes these tables were indexed, so each row's closest hit should
# be at distance ~0 and map back to the same table id. Positions failing the
# distance check are printed ("Errore nelle distanze" = distance error). Positions
# failing the id check ("Errore negli ID" = id error) are printed and kept in
# `bad_tables` for inspection.
N_CHECK = 160     # how many leading tables of sloth_tables.jsonl to check
N_NEIGHBOURS = 3  # neighbours retrieved per query row

bad_tables = []  # (position, table, table_id) of tables failing the id check
with jsonlines.open(DefaultPath.data_path.wikitables + 'sloth_tables.jsonl') as reader:
    for i, json_table in enumerate(itertools.islice(reader, N_CHECK)):
        id_table = json_table['_id']
        table = rebuild_table(json_table)

        row_emb, col_emb = tabenc.full_embedding(table, add_label=True)
        D, I = row_index.search(row_emb, N_NEIGHBOURS)

        # Column 0 of D / I holds the closest hit for each query row.
        if not all(round(xe[0], 5) == 0 for xe in D):
            print(f'Errore nelle distanze per i=={i}')

        if not all(row_lut.lookup(xi[0]) == id_table for xi in I):
            print(f'Errore negli ID per i == {i}')
            bad_tables.append((i, table, id_table))


Errore negli ID per i == 31
Errore negli ID per i == 55
Errore negli ID per i == 62
Errore negli ID per i == 71
Errore negli ID per i == 77
Errore negli ID per i == 85
Errore negli ID per i == 92
Errore negli ID per i == 93
Errore negli ID per i == 94
Errore negli ID per i == 97
Errore negli ID per i == 108
Errore negli ID per i == 149
Errore negli ID per i == 153


In [18]:
# Positions (within the scanned prefix of the dump) of the tables that failed the id check.
[position for position, _table, _table_id in bad_tables]

[31, 55, 62, 71, 77, 85, 92, 93, 94, 97, 108, 149, 153]

In [19]:
# Inspect one of the flagged tables; idx must be one of the positions listed above.
idx = 149
# i_table_id is the (position, table, table_id) tuple recorded by the sanity check.
i_table_id = next((bt for bt in bad_tables if bt[0] == idx), None)
if i_table_id is None:
    raise ValueError(
        f'position {idx} was not flagged; choose one of {[bt[0] for bt in bad_tables]}'
    )
# Embed the flagged table for the searches below, then display it.
row_emb, col_emb = tabenc.full_embedding(i_table_id[1], add_label=True)
i_table_id[1]

Unnamed: 0,Program,Bachelor's (5Yrs),Master's (7Yrs),Diploma (2Yrs)
0,Civil Engineering,B.E,M.E,A.G.T.I
1,Electronic and Communication,B.E.,,A.G.T.I
2,Electrical Power,B.E.,,A.G.T.I
3,Mechanical Engineering,B.E.,,A.G.T.I
4,Computer Numerical Control Engineering,,,A.G.T.I


In [20]:
# Search the row index with each row of the selected table and show the distances.
n_neighbours = 3
D, I = row_index.search(row_emb, n_neighbours)
D

array([[0.        , 0.18108512, 0.18108512],
       [0.        , 0.        , 0.00206247],
       [0.        , 0.        , 0.00542197],
       [0.        , 0.        , 0.00291783],
       [0.        , 0.1171796 , 0.1171796 ]], dtype=float32)

In [21]:
# Vector ids of the nearest neighbours for each query row (aligned with D).
I

array([[  1596,   1590,   1597],
       [  1590,   1597,   1595],
       [  1591,   1598,   1592],
       [  1592,   1599,   1590],
       [  1600, 137881, 137887]])

In [250]:
# Table id of the inspected flagged table (third field of the (position, table, id) tuple).
i_table_id[2]

'37803389-1'

In [22]:
# Translate every returned vector id into its table id: one line per query row,
# neighbours in rank order, each followed by a tab.
for neighbour_ids in I:
    print(''.join(f'{row_lut.lookup(vec_id)}\t' for vec_id in neighbour_ids))

37803389-1	37803315-1	37803389-1	
37803315-1	37803389-1	37803315-1	
37803315-1	37803389-1	37803315-1	
37803315-1	37803389-1	37803315-1	
37803389-1	26816358-12	26816358-16	


In [23]:
# Vector ids copied from the search output above: 1596 is the top hit of query
# row 0 and 1590 the top hit of query row 1. Map both to their table ids.
id_s = row_lut.lookup(1596)
id_t = row_lut.lookup(1590)
t, s = None, None

# Scan the dump once, stopping as soon as both tables have been rebuilt.
with jsonlines.open(DefaultPath.data_path.wikitables + 'sloth_tables.jsonl') as reader:
    for json_table in reader:
        if json_table['_id'] == id_t:
            t = rebuild_table(json_table)
        if json_table['_id'] == id_s:
            s = rebuild_table(json_table)
        if t is not None and s is not None:
            break

# Fail loudly here instead of carrying a None into the cells below.
assert t is not None, f'table {id_t} not found in sloth_tables.jsonl'
assert s is not None, f'table {id_s} not found in sloth_tables.jsonl'
t

Unnamed: 0,Program,Bachelor's (5Yrs),Master's (7Yrs),Diploma (2Yrs)
0,Civil Engineering,B.E,,A.G.T.I
1,Electronic and Communication,B.E.,,A.G.T.I
2,Electrical Power,B.E.,,A.G.T.I
3,Mechanical Engineering,B.E.,,A.G.T.I
4,Mining Engineering,B.E.,,A.G.T.I
5,Biotechnology,B.E,,A.G.T.I
6,Information Technology,B.E,,A.G.T.I


In [253]:
# `s` is the table that owns vector 1596; its content matches the flagged table displayed earlier.
s

Unnamed: 0,Program,Bachelor's (5Yrs),Master's (7Yrs),Diploma (2Yrs)
0,Civil Engineering,B.E,M.E,A.G.T.I
1,Electronic and Communication,B.E.,,A.G.T.I
2,Electrical Power,B.E.,,A.G.T.I
3,Mechanical Engineering,B.E.,,A.G.T.I
4,Computer Numerical Control Engineering,,,A.G.T.I


In [216]:
# Embed both tables with add_label=True, as in the cells above.
(row_t, col_t), (row_s, col_s) = (
    tabenc.full_embedding(tbl, add_label=True) for tbl in (t, s)
)
row_t.shape, row_s.shape

In [218]:
from code.fasttext.embedding_utils import compare_embeddings

# Row-wise comparison of the embeddings of t and s.
compare_kwargs = {'on': 'rows', 'add_label': True}
c = compare_embeddings(t, s, tabenc, **compare_kwargs)
c.head()

100%|██████████| 3/3 [00:00<00:00, 29.52it/s]


In [220]:
# Pairwise L2 distances between every row embedding of `s` (R1) and of `t` (R2),
# smallest first. The records are collected first and turned into a frame once.
# The previous approach, appending with `dist.loc[len(dist)] = ...` inside the
# loop, is quadratic and upcasts the integer row positions to floats.
dist = pd.DataFrame(
    [
        (i, j, np.linalg.norm(rs - rt))
        for i, rs in enumerate(row_s)
        for j, rt in enumerate(row_t)
    ],
    columns=['R1', 'R2', 'L2'],
)
dist.sort_values(by=['L2']).head()

Unnamed: 0,R1,R2,L2
0,0.0,0.0,0.0
9,3.0,0.0,0.123911
2,0.0,2.0,0.168474
11,3.0,2.0,0.190462
8,2.0,2.0,0.190531
