In [9]:
import os
import jsonlines
import random
import pandas as pd
from time import time
from tqdm.notebook import tqdm
from itertools import product

from code.fasttext.embedding_utils import TableEncoder
from code.utils.settings import DefaultPath
from code.utils.utils import rebuild_table

import chromadb

In [2]:
# Medium-size Chroma DB whose records carry `table_id` metadata
# (the `where={'table_id': ...}` filter in the query cell relies on it).
# Plain string: the previous f'v_m_0' had no placeholders (F541).
CHROMA_DB_NAME = 'v_m_0'
chroma_client = chromadb.PersistentClient(DefaultPath.db_path.chroma + CHROMA_DB_NAME)

tabenc = TableEncoder()

In [3]:
# Table pairs (columns r_id / s_id) used to pick query tables; only the first
# N_PAIRS rows of the CSV are read to keep the notebook light.
N_PAIRS = 10000
pairs_csv_path = DefaultPath.data_path.wikitables + 'train_set_turl_malaguti.csv'
sloth_tested_pairs = pd.read_csv(pairs_csv_path, nrows=N_PAIRS)

In [7]:
# Index every table of the medium training split by its `_id`, so tables can be
# looked up by id (from the pairs file or from Chroma result ids).
wikitables = {}
tables_jsonl_path = DefaultPath.data_path.wikitables + 'medium_train_tables.jsonl'

with jsonlines.open(tables_jsonl_path, 'r') as reader:
    wikitables.update((record['_id'], record) for record in reader)

In [5]:
# Every distinct table id appearing on either side of a pair.
# Sorted so that positional indexing into `ids` is stable across kernel restarts:
# iteration order of a set of str depends on PYTHONHASHSEED. key=str keeps the sort
# safe even if a non-string value (e.g. NaN) slipped into the id columns.
ids = sorted(set(sloth_tested_pairs['r_id']).union(sloth_tested_pairs['s_id']), key=str)
len(ids)

14209

In [None]:
# random_id = '34163802-12'  # nice example: its first row has exact (distance 0.0) duplicates in other tables

In [52]:
# Pick one table id at random to use as the query table.
# `choice` only ever returns an existing element; note that random.randint(a, b)
# is inclusive of b, so indexing with randint(0, len(ids)) can raise IndexError.
RANDOM_SEED = None  # set to an int to make the pick repeatable (given the same `ids` order)
random_id = random.Random(RANDOM_SEED).choice(ids)
random_id

'34163802-12'

In [53]:
# Rebuild the query table from its raw record and show it.
query_record = wikitables[random_id]
table = rebuild_table(query_record)
table

Unnamed: 0,Blue Group,Skip,W,L
0,United States,Korey Dropkin,7,0
1,Switzerland,Michael Brunner,6,1
2,Czech Republic,Marek Černovský,4,3
3,China,Bai Yang,3,4
4,Norway,Markus Skogvold,3,4
5,South Korea,Kang Sue-yeon,2,5
6,New Zealand,Luke Steele,2,5
7,Estonia,Robert-Kent Päll,1,6


In [54]:
# Embed the query table. `full_embedding` returns a pair unpacked into row and column
# embeddings (presumably one vector per row / per column — confirm in TableEncoder).
query_embeddings_pair = tabenc.full_embedding(table)
row_embeddings, columns_embeddings = query_embeddings_pair

In [55]:
# Collections available in this Chroma DB (the saved run shows 'column-base' and 'row-base').
available_collections = chroma_client.list_collections()
available_collections

[Collection(name=column-base), Collection(name=row-base)]

In [56]:
# Nearest neighbours of the query table's first row in the row-level collection,
# excluding rows whose `table_id` metadata equals the query table's id.
row_collection = chroma_client.get_collection('row-base')
res = row_collection.query(
    query_embeddings=[row_embeddings[0].tolist()],
    n_results=3,
    where={'table_id': {"$ne": random_id}},
    include=['metadatas', 'distances'],
)
res

{'ids': [['34163295-11#0', '34116170-5#0', '31320420-16#6']],
 'distances': [[0.0, 0.0, 0.04128605127334595]],
 'metadatas': [[{'table_id': '34163295-11'},
   {'table_id': '34116170-5'},
   {'table_id': '31320420-16'}]],
 'embeddings': None,
 'documents': [[None, None, None]],
 'uris': None,
 'data': None}

In [57]:
# Values of the query row (index label 0), to compare with the neighbour rows printed next.
table.loc[0].to_list()

['United States', 'Korey Dropkin', '7', '0']

In [58]:
# For each neighbour id (format "<table_id>#<row_index>", as in `res['ids']`), rebuild
# its source table, keep it in `table_res`, and print the matched row for comparison
# with the query row.
table_res = []
for neighbour_id in res['ids'][0]:
    # The row index is the last '#'-separated field; rsplit keeps the unpacking valid
    # even if a table id itself contains '#'.
    table_id, row_id = neighbour_id.rsplit('#', 1)
    neighbour_table = rebuild_table(wikitables[table_id])
    table_res.append(neighbour_table)
    print(neighbour_table.loc[int(row_id)].tolist())

['United States', 'Korey Dropkin', '7', '0']
['United States', 'Korey Dropkin', '7', '0']
['7', 'Switzerland', '0.960', '+0.001']


In [59]:
# Full table behind the 1st nearest neighbour, for comparison with the query `table`.
first_neighbour_table = table_res[0]
first_neighbour_table

Unnamed: 0,Blue Group,Skip,W,L
0,United States,Korey Dropkin,7,0
1,Switzerland,Michael Brunner,6,1
2,Czech Republic,Marek Černovský,4,3
3,China,Bai Yang,3,4
4,Norway,Markus Skogvold,3,4
5,South Korea,Kang Sue-yeon,2,5
6,New Zealand,Luke Steele,2,5
7,Estonia,Robert-Kent Päll,1,6


In [60]:
# Full table behind the 2nd nearest neighbour, for comparison with the query `table`.
second_neighbour_table = table_res[1]
second_neighbour_table

Unnamed: 0,Blue Group,Skip,W,L
0,United States,Korey Dropkin,7,0
1,Switzerland,Michael Brunner,6,1
2,Czech Republic,Marek Černovský,4,3
3,China,Bai Yang,3,4
4,Norway,Markus Skogvold,3,4
5,South Korea,Kang Sue-yeon,2,5
6,New Zealand,Luke Steele,2,5
7,Estonia,Robert-Kent Päll,1,6


In [61]:
# Full table behind the 3rd nearest neighbour, for comparison with the query `table`.
third_neighbour_table = table_res[2]
third_neighbour_table

Unnamed: 0,Rank,Country,HDI,HDI.1
0,1,Norway,0.971,0.001
1,2,Iceland,0.969,0.002
2,3,Ireland,0.965,0.001
3,4,Netherlands,0.964,0.003
4,5,Sweden,0.963,0.002
5,6,France,0.961,0.003
6,7,Switzerland,0.96,0.001
7,8,Luxembourg,0.96,0.001
8,9,Finland,0.959,0.004
9,10,Austria,0.955,0.003
