In [2]:
import pandas as pd
import jsonlines
from tqdm.notebook import tqdm
from time import time
import chromadb

from code.fasttext.embedding_utils import TableEncoder

In [3]:
# On-disk Chroma client; no path is passed, so the library's default storage
# location is used.
chroma_client = chromadb.PersistentClient()

# Project-local encoder used further down to turn a table into row and column
# embeddings (via `full_embedding`).
tabenc = TableEncoder()

In [4]:
def rebuild_table(table):
    """Turn a raw table record into a pandas DataFrame.

    Parameters
    ----------
    table : mapping
        Must provide ``'tableData'`` (an iterable of rows, each row an iterable
        of cells exposing a ``'text'`` entry) and ``'tableHeaders'`` (a sequence
        whose first element is used as the column labels).

    Returns
    -------
    pandas.DataFrame
        One row per entry of ``table['tableData']``, holding the cells' text.
    """
    header = table['tableHeaders'][0]

    rows = []
    for raw_row in table['tableData']:
        # Keep only the textual content of every cell.
        rows.append([cell['text'] for cell in raw_row])

    return pd.DataFrame(data=rows, columns=header)

In [5]:
# CSV of table pairs; only the first 1000 rows are loaded to keep the run small.
pairs_csv_path = '/home/giovanni/unimore/TESI/src/data/train_set_turl_malaguti.csv'
sloth_tested_pairs = pd.read_csv(pairs_csv_path, nrows=1000)

In [6]:
# Distinct table ids that appear on either side ('r_id' or 's_id') of a pair.
r_side_ids = set(sloth_tested_pairs['r_id'])
ids = list(r_side_ids.union(sloth_tested_pairs['s_id']))
len(ids)

1905

In [7]:
# In-memory index of the JSONL table records, keyed by each record's '_id'.
wikitables = {}
with jsonlines.open('/home/giovanni/unimore/TESI/src/data/small_train_tables.jsonl', 'r') as reader:
    wikitables.update((record['_id'], record) for record in reader)

In [8]:
# Start every run from empty collections: drop whatever a previous run left
# behind, then recreate both collections.
#
# Each collection is deleted in its own try block. With a single shared block,
# a failure on "column-base" would skip the deletion of "row-base", and the
# create_collection call below would then fail on the stale collection.
for collection_name in ("column-base", "row-base"):
    try:
        chroma_client.delete_collection(collection_name)
        print(f'Collection {collection_name} deleted')
    except Exception as exc:
        # Best-effort cleanup: most likely the collection simply does not exist
        # yet (e.g. first run). Report it instead of hiding it; unlike a bare
        # `except:`, this lets KeyboardInterrupt / SystemExit propagate.
        print(f'Collection {collection_name} not deleted: {exc}')

collection_column_base = chroma_client.create_collection(name='column-base')
collection_row_base = chroma_client.create_collection(name='row-base')
print('Collections created')

Collections created


### Storing the embeddings in the persistent Chroma client, in batches

In [10]:
# Embed every table referenced in `ids` and store its row / column vectors in
# the two Chroma collections. Vectors are buffered and written once every
# `batch_size` tables; whatever is still buffered when the loop ends is written
# in a final flush, so the last (partial) batch is not lost.
batch_size = 100
batch_column_ids, batch_row_ids = [], []
batch_column_embeddings, batch_row_embeddings = [], []

i = 0  # number of tables embedded so far (tables missing from `wikitables` are not counted)


def _add_if_pending(collection, pending_ids, pending_embeddings):
    """Add the buffered vectors to `collection`; do nothing when the buffer is empty."""
    if pending_ids:
        collection.add(ids=pending_ids, embeddings=pending_embeddings)


for table_id in tqdm(ids):
    try:
        wikitable = wikitables[table_id]
    except KeyError:
        print(f'Table ID {table_id} not found')
        continue

    i += 1

    table = rebuild_table(wikitable)
    # NOTE(review): the meaning of the two boolean flags is defined by
    # TableEncoder.full_embedding, which is not visible here — confirm.
    row_embeddings, column_embeddings = tabenc.full_embedding(table, False, False)

    # Vector ids have the form "<table id>#<position>", where the position is
    # the index along the first axis of the corresponding embedding array.
    batch_column_ids.extend(f"{table_id}#{col_idx}" for col_idx in range(column_embeddings.shape[0]))
    batch_row_ids.extend(f"{table_id}#{row_idx}" for row_idx in range(row_embeddings.shape[0]))

    batch_column_embeddings.extend(column_embeddings.tolist())
    batch_row_embeddings.extend(row_embeddings.tolist())

    if i % batch_size == 0:
        print(i)
        _add_if_pending(collection_column_base, batch_column_ids, batch_column_embeddings)
        _add_if_pending(collection_row_base, batch_row_ids, batch_row_embeddings)

        batch_column_ids, batch_row_ids = [], []
        batch_column_embeddings, batch_row_embeddings = [], []

# Final flush: the number of processed tables is generally not a multiple of
# `batch_size`, so the trailing tables are still sitting in the buffers here.
_add_if_pending(collection_column_base, batch_column_ids, batch_column_embeddings)
_add_if_pending(collection_row_base, batch_row_ids, batch_row_embeddings)

batch_column_ids, batch_row_ids = [], []
batch_column_embeddings, batch_row_embeddings = [], []

100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
