In [2]:
import faiss
import pandas as pd
import polars as pl

from code.faiss.lut import LUT
from code.utils.table import Table
from code.utils.settings import DefaultPath

from sentence_transformers import SentenceTransformer

In [3]:
test_root_dir = DefaultPath.data_path.wikitables + 'threshold-r5-c2-a50'

results = pl.scan_csv(f'{test_root_dir}/sloth-results-r5-c2-a50.csv')
tables_path = f'{test_root_dir}/sloth-tables-r5-c2-a50.jsonl'

Carico ogni tabella in formato csv solo per far prima, altrimenti ogni volta c'è da andare a leggere
il jsonl ed è una noia.

In [3]:
from code.utils.utils import rebuild_table
from tqdm.notebook import tqdm

# with jsonlines.open(tables_path) as reader:
#     for jtable in tqdm(reader, total=45673):
#         rebuild_table(jtable).to_csv(f'{test_root_dir}/csv/{jtable["_id"]}', index=False)

Dai risultati di SLOTH, prendo le prime 1000 righe: con gli elementi unici a sinistra delle coppie (r_id) creo un indice, con quelli a destra (s_id) faccio l'operazione di testing per fare qualche valutazione.

In [4]:
index_ids = set(results.select(pl.col('r_id')).head(100).unique().collect()['r_id'].to_list())
len(index_ids)

98

In [5]:
model = SentenceTransformer("all-MiniLM-L6-v2")

In [7]:
def encode_table(table:Table, model, normalize_embeddings=False):
  row_embs = model.encode([
        '|'.join([f'{h},{cell}' for (h, cell) in zip(table.headers, t)]) 
        for t in table.get_tuples()
    ],
    normalize_embeddings=normalize_embeddings
  )

  col_embs = model.encode([
        f"{h},{','.join(map(str, col))}"
        for h, col in zip(table.headers, table.columns)
    ],
    normalize_embeddings=normalize_embeddings
  )

  return row_embs, col_embs

In [None]:
dim = model.get_sentence_embedding_dimension()
row_index, column_index = faiss.IndexFlatL2(dim), faiss.IndexFlatL2(dim) 
row_lut, column_lut = LUT(), LUT()


for id in tqdm(index_ids):
    df = pd.read_csv(test_root_dir + f'/csv/{id}')
    table = Table()
    table.from_pandas(df)

    row_emb, col_emb = encode_table(table, model, True)

    row_lut.insert_index(row_emb.shape[0], id)
    column_lut.insert_index(col_emb.shape[0], id)

    row_index.add(row_emb)
    column_index.add(col_emb)


In [9]:
faiss.write_index(row_index, 'basic_index/row_index.index')
faiss.write_index(column_index, 'basic_index/column_index.index')
import json
with open('basic_index/row_lut.json', 'w') as writer:
    json.dump(
        { 'idxs': row_lut.idxs, 'ids': row_lut.table_ids },
        writer
    )
with open('basic_index/column_lut.json', 'w') as writer:
    json.dump(
        { 'idxs': column_lut.idxs, 'ids':  column_lut.table_ids },
        writer
    )

In [8]:
dim = model.get_sentence_embedding_dimension()
row_index, column_index = faiss.read_index('basic_index/row_index.index'), faiss.read_index('basic_index/column_index.index')
row_lut, column_lut = LUT(), LUT()
row_lut.load('basic_index/row_lut.json')
column_lut.load('basic_index/column_lut.json')

In [9]:
test_ids = set(results.select(pl.col('s_id')).head(100).unique().collect()['s_id'].to_list())
len(test_ids)

99

In [10]:
test_ids = list(test_ids.difference(index_ids))
len(test_ids)

98

In [11]:
results.head(100).filter(pl.col('s_id') == test_ids[0]).collect()

r_id,s_id,jsim,o_a,a%
str,str,f64,i64,f64
"""4918676-1""","""5116633-1""",0.468085,44,0.174603


In [12]:
df = pl.read_csv(test_root_dir +  f'/csv/{test_ids[0]}')
df.head()

Club,Season,League,League_duplicated_0,FA Cup,FA Cup_duplicated_0,League Cup,League Cup_duplicated_0,Other,Other_duplicated_0,Total,Total_duplicated_0
str,str,i64,i64,i64,i64,i64,i64,str,str,i64,i64
"""Reading""","""1999–00""",6,0,0,0,0,0,"""0""","""0""",6,0
"""Reading""","""2000–01""",4,0,0,0,1,0,"""1""","""0""",6,0
"""Reading""","""2001–02""",38,7,2,0,3,2,"""2""","""2""",45,11
"""Reading""","""2002–03""",22,4,1,0,0,0,"""2""","""0""",25,4
"""Reading""","""2003–04""",1,0,0,0,0,0,"""—""","""—""",1,0


In [13]:
table = Table()
table.from_polars(df)

In [14]:
row_emb, col_emb = encode_table(table, model, True)

In [15]:
D, I = row_index.search(row_emb, 3)

In [17]:
import numpy as np

cnt = np.unique(np.vectorize(row_lut.lookup)(I), return_counts=True)
cnt

(array(['10777756-1', '39176684-1', '40082725-1', '4918676-1'],
       dtype='<U10'),
 array([12,  3,  1, 68]))

In [52]:
r = list(zip(cnt[0], cnt[1]))
r

[('10777756-1', 12), ('39176684-1', 3), ('40082725-1', 1), ('4918676-1', 68)]

In [53]:
r.append(('39176684-1', 44))
r = sorted(r, key=lambda x: x[0])
r

[('10777756-1', 12),
 ('39176684-1', 3),
 ('39176684-1', 44),
 ('40082725-1', 1),
 ('4918676-1', 68)]

In [59]:
from itertools import groupby
from functools import reduce

an_iterator = groupby(r, lambda x : x[0]) 
  
for key, group in an_iterator: 
    key_and_group = {key : reduce(lambda x, y: x[1] + y[1] if y != None else x[1], list(group))} 
    print(key_and_group) 

{'10777756-1': ('10777756-1', 12)}
{'39176684-1': (<class 'tuple'>, <class 'tuple'>)}
{'40082725-1': ('40082725-1', 1)}
{'4918676-1': ('4918676-1', 68)}
