In [9]:
from itertools import islice
from time import time

import chromadb
import jsonlines
import pandas as pd
from tqdm.notebook import tqdm

from code.fasttext.embedding_utils import TableEncoder
from code.utils.settings import DefaultPath
from code.utils.utils import rebuild_table

In [2]:
# Table encoder (presumably fastText-based, given its module path) whose
# `full_embedding(table, ...)` is called below to get row and column embeddings.
tabenc = TableEncoder()

In [4]:
# Read only the first 10,000 rows of the pair training set; the rows carry the
# r_id / s_id columns consumed by the next cell.
sloth_tested_pairs = pd.read_csv(
    filepath_or_buffer=DefaultPath.data_path.wikitables + 'train_set_turl_malaguti.csv',
    nrows=10_000,
)

In [5]:
# Distinct ids appearing in either the r_id or the s_id column of the loaded pairs.
ids = list({*sloth_tested_pairs['r_id'], *sloth_tested_pairs['s_id']})
len(ids)

14209

In [6]:
# Open the persistent Chroma store and start from empty collections, so that
# re-running this cell does not mix embeddings from previous runs.
chroma_client = chromadb.PersistentClient(DefaultPath.db_path.chroma + 'v0')

for collection_name in ("column-base", "row-base"):
    # Delete each collection independently: with a single try block, a missing
    # "column-base" skipped deleting "row-base", and create_collection below
    # then failed because "row-base" still existed.
    try:
        chroma_client.delete_collection(collection_name)
        print(f'Collection {collection_name!r} deleted')
    except Exception as exc:
        # NOTE(review): the exception raised for a missing collection has differed
        # across chromadb versions; narrow this once the pinned version is known.
        print(f'Collection {collection_name!r} not deleted: {exc}')

collection_column_base = chroma_client.create_collection(name='column-base')
collection_row_base = chroma_client.create_collection(name='row-base')
print('Collections created')

Collections created


### Store table embeddings in the persistent Chroma client (batched)

In [8]:
# Tables per Chroma add() call.
# NOTE(review): confirm the row embeddings of one batch stay under the Chroma client's max batch size.
batch_size = 500
batch_column_ids, batch_row_ids = [], []
batch_column_embeddings, batch_row_embeddings = [], []

# Load the whole file into memory, keyed by '_id' (a repeated '_id' keeps the last record).
wikitables = {}
with jsonlines.open(DefaultPath.data_path.wikitables + 'small_train_tables.jsonl', 'r') as reader:
    for obj in reader:
        wikitables[obj['_id']] = obj

n_tables = len(wikitables)

for i, (table_id, wikitable) in enumerate(tqdm(wikitables.items()), start=1):
    table = rebuild_table(wikitable)
    row_embeddings, column_embeddings = tabenc.full_embedding(table, False, False)

    # Chroma IDs are '<table_id>#<position>'.
    batch_column_ids.extend(f'{table_id}#{idx}' for idx in range(column_embeddings.shape[0]))
    batch_row_ids.extend(f'{table_id}#{idx}' for idx in range(row_embeddings.shape[0]))

    batch_column_embeddings.extend(column_embeddings.tolist())
    batch_row_embeddings.extend(row_embeddings.tolist())

    # Flush when the batch is full OR after the last table. Flushing only on
    # multiples of batch_size silently drops the final partial batch.
    if i % batch_size == 0 or i == n_tables:
        # Skip empty halves: chromadb may reject add() with an empty ids list.
        if batch_column_ids:
            collection_column_base.add(
                ids=batch_column_ids,
                embeddings=batch_column_embeddings
            )
        if batch_row_ids:
            collection_row_base.add(
                ids=batch_row_ids,
                embeddings=batch_row_embeddings
            )

        batch_column_ids, batch_row_ids = [], []
        batch_column_embeddings, batch_row_embeddings = [], []

  0%|          | 0/1904 [00:00<?, ?it/s]

### Same pipeline, streaming tables directly from the JSONL reader

In [8]:
# Tables per Chroma add() call.
# NOTE(review): confirm the row embeddings of one batch stay under the Chroma client's max batch size.
batch_size = 500
batch_column_ids, batch_row_ids = [], []
batch_column_embeddings, batch_row_embeddings = [], []

i = 0  # number of tables embedded so far

with jsonlines.open(DefaultPath.data_path.wikitables + 'medium_train_tables.jsonl', 'r') as reader:
    tables = iter(reader)
    while True:
        # Pull at most `batch_size` tables at a time. The last chunk may be
        # shorter, so the tail of the file is stored without a separate flush.
        chunk = list(islice(tables, batch_size))
        if not chunk:
            break

        for wikitable in chunk:
            table_id = wikitable['_id']
            table = rebuild_table(wikitable)
            row_embeddings, column_embeddings = tabenc.full_embedding(table, False, False)

            # Chroma IDs are '<table_id>#<position>'.
            batch_column_ids.extend(f"{table_id}#{idx}" for idx in range(column_embeddings.shape[0]))
            batch_row_ids.extend(f"{table_id}#{idx}" for idx in range(row_embeddings.shape[0]))

            batch_column_embeddings.extend(column_embeddings.tolist())
            batch_row_embeddings.extend(row_embeddings.tolist())

        i += len(chunk)

        # Guard each half separately: the old tail flush checked only the column
        # batch, and chromadb may reject add() with an empty ids list.
        if batch_column_ids:
            collection_column_base.add(
                ids=batch_column_ids,
                embeddings=batch_column_embeddings
            )
        if batch_row_ids:
            collection_row_base.add(
                ids=batch_row_ids,
                embeddings=batch_row_embeddings
            )

        batch_column_ids, batch_row_ids = [], []
        batch_column_embeddings, batch_row_embeddings = [], []
        print(f"Loaded in db {i}...", end='\r')

Loaded in db 14000...