In [1]:
import os
import jsonlines
import random
import numpy as np
import pandas as pd
from time import time
from tqdm.notebook import tqdm
from itertools import product

from code.fasttext.embedding_utils import TableEncoder
from code.utils.settings import DefaultPath
from code.utils.utils import rebuild_table

import chromadb

In [2]:
tabenc = TableEncoder()

In [3]:
sloth_tested_pairs = pd.read_csv(
    DefaultPath.data_path.wikitables + 'train_set_turl_malaguti.csv',
    nrows=1000
)

wikitables = {}

with jsonlines.open(DefaultPath.data_path.wikitables + 'small_train_tables.jsonl', 'r') as reader:
    for obj in reader:
        wikitables[obj['_id']] = obj

ids = list(set(sloth_tested_pairs['r_id']).union(sloth_tested_pairs['s_id']))
len(ids)        

1905

In [4]:
table_random_id = ids[random.randint(0, len(ids))]
table_random_id

'30454545-1'

In [5]:
table = rebuild_table(wikitables[table_random_id])
table

Unnamed: 0,Rank,Nation,Gold,Silver,Bronze,Total
0,1,Russia,80,14,10,104
1,2,Turkey,15,29,22,66
2,3,Ukraine,8,29,12,49
3,4,Italy,7,6,4,17
4,5,Netherlands,5,-,4,9
5,6,Bulgaria,3,2,2,7
6,7,Romania,2,8,9,19
7,8,United Kingdom,2,1,2,5
8,9,Armenia,2,-,10,12
9,10,Israel,1,4,5,10


In [6]:
row_embeddings, columns_embeddings = tabenc.full_embedding(table)

In [7]:
table.shape

(25, 6)

In [8]:
row_embeddings.shape, columns_embeddings.shape

((25, 300), (6, 300))

In [9]:
row_embeddings.dtype, columns_embeddings.dtype

(dtype('float32'), dtype('float32'))

In [10]:
row_embeddings.nbytes, columns_embeddings.nbytes

(30000, 7200)

In [11]:
# memory usage per embedding
row_embeddings.nbytes / row_embeddings.shape[0], columns_embeddings.nbytes / columns_embeddings.shape[0]

(1200.0, 1200.0)

In [12]:
row_embeddings.shape[0] * row_embeddings.shape[1] * 4, \
    columns_embeddings.shape[0] * columns_embeddings.shape[1] * 4

(30000, 7200)

In [13]:
row_embeddings.shape[0] * row_embeddings.shape[1] * 4 + \
    columns_embeddings.shape[0] * columns_embeddings.shape[1] * 4

37200

In [14]:
chroma_client = chromadb.PersistentClient(
    path=DefaultPath.db_path.chroma + f'tester0',
    settings=chromadb.config.Settings(anonymized_telemetry=False)
    )

In [15]:
chroma_client.list_collections()

[]

In [16]:
try:
    chroma_client.delete_collection("rows")
except: pass
finally:
    collection_row = chroma_client.create_collection(name='rows')
    print('Collections "rows" created')

Collections "rows" created


In [17]:
try:
    chroma_client.delete_collection("columns")
    print('Collections "columns" deleted')
except: pass
finally:
    collection_column = chroma_client.create_collection(name='columns')
    print('Collections "columns" created')

Collections "columns" created


In [18]:
collection_row.add(
    ids=list(map(lambda id: f'{table_random_id}#{id}', range(row_embeddings.shape[0]))),
    embeddings=row_embeddings
)

#01 :  ['chroma.sqlite3']
#02 :  ['chroma.sqlite3']
#03 :  ['chroma.sqlite3']
#04 :  ['chroma.sqlite3']
#05 :  ['chroma.sqlite3']
#06 :  ['chroma.sqlite3']
#07 :  ['chroma.sqlite3']
#08 :  ['chroma.sqlite3']


In [32]:
type(collection_row._client).__mro__

(chromadb.api.segment.SegmentAPI,
 chromadb.api.ServerAPI,
 chromadb.api.BaseAPI,
 chromadb.api.AdminAPI,
 chromadb.config.Component,
 abc.ABC,
 overrides.enforce.EnforceOverrides,
 object)

In [22]:
type(collection_row._client._manager).__mro__

(chromadb.segment.impl.manager.local.LocalSegmentManager,
 chromadb.segment.SegmentManager,
 chromadb.config.Component,
 abc.ABC,
 overrides.enforce.EnforceOverrides,
 object)