# How to embed and push the entities to the Vector Database

In [None]:
import sys
sys.path.append('../src')

import os
os.environ["LANGUAGE"] = 'en' # Specify the language of the textified entities.

from wikidataDB import Session, WikidataID, WikidataEntity
from wikidataEmbed import WikidataTextifier
from wikidataRetriever import AstraDBConnect

import json
from tqdm import tqdm
import os
import pickle
from datetime import datetime
import hashlib
from astrapy import DataAPIClient

In [None]:
# Runtime configuration. Each value below can be overridden through an
# environment variable of the same name as the key passed to os.getenv.
MODEL = os.getenv("MODEL", "jina")
EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", 100))
QUERY_BATCH_SIZE = int(os.getenv("QUERY_BATCH_SIZE", 1000))
OFFSET = int(os.getenv("OFFSET", 0))
# Name of the token file inside ../API_tokens. The previous code called
# os.listdir() on the JSON file itself, which raises NotADirectoryError and
# also discarded the API_KEY override; the file name is now used directly.
API_KEY_FILENAME = os.getenv("API_KEY", "datastax_wikidata.json")
DUMPDATE = os.getenv("DUMPDATE", '09/18/2024')

COLLECTION_NAME = "wikidata"
LANGUAGE = "en" # Language of the SQLite database
TEXTIFIER_LANGUAGE = "rdf" # Language of the textifier (The name of the python script found in src/language_variables)

# Load the Astra DB credentials, closing the file as soon as it is parsed.
with open(f"../API_tokens/{API_KEY_FILENAME}") as token_file:
    datastax_token = json.load(token_file)

textifier = WikidataTextifier(language=TEXTIFIER_LANGUAGE)
graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model=MODEL, batch_size=EMBED_BATCH_SIZE, cache_embeddings=False)

#### Push the Wikidata entities whose QIDs are in the sample data

In [None]:
# Load the evaluation sample and keep only the rows flagged 'In Wikipedia'.
# NOTE(review): pickle executes arbitrary code on load; only use a trusted file.
with open("../data/Evaluation Data/Sample IDs (EN).pkl", "rb") as sample_file:
    sample_ids = pickle.load(sample_file)
sample_ids = sample_ids[sample_ids['In Wikipedia']]
# Used as the progress bar total when pushing the sample entities.
total_entities = len(sample_ids)

def get_entity(session):
    """Yield the WikidataEntity rows whose id is one of the sample QIDs.

    The first OFFSET sample QIDs are skipped and the remainder is looked up
    in slices of QUERY_BATCH_SIZE ids, one query per slice.
    """
    remaining_qids = list(sample_ids['QID'].values)[OFFSET:]

    for start in range(0, len(remaining_qids), QUERY_BATCH_SIZE):
        qid_slice = remaining_qids[start:start + QUERY_BATCH_SIZE]
        # One IN (...) query per slice of QIDs.
        query = session.query(WikidataEntity).filter(WikidataEntity.id.in_(qid_slice))
        yield from query.yield_per(QUERY_BATCH_SIZE)

In [None]:
# Textify every sample entity, split it into chunks and queue each chunk on
# graph_store under the id "<QID>_<LANGUAGE>_<chunk number>".
with tqdm(total=total_entities-OFFSET) as progressbar:
    with Session() as session:
        for entity in get_entity(session):
            progressbar.update(1)
            chunks = textifier.chunk_text(entity, graph_store.tokenizer, max_length=graph_store.max_token_size)
            # Chunk numbers are 1-based in both the document id and the metadata.
            for chunk_number, chunk in enumerate(chunks, start=1):
                md5_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest()
                metadata = {
                    "MD5": md5_hash,
                    "Label": entity.label,
                    "Description": entity.description,
                    "Aliases": entity.aliases,
                    "Date": datetime.now().isoformat(),
                    "QID": entity.id,
                    "ChunkID": chunk_number,
                    "Language": LANGUAGE,
                    "DumpDate": DUMPDATE
                }
                graph_store.add_document(id=f"{entity.id}_{LANGUAGE}_{chunk_number}", text=chunk, metadata=metadata)

            tqdm.write(progressbar.format_meter(progressbar.n, progressbar.total, progressbar.format_dict["elapsed"])) # tqdm is not working in docker compose. This is the alternative

        # Presumably sends whatever is still queued in graph_store - confirm in AstraDBConnect.
        graph_store.push_batch()

#### Push all Wikidata entities found in the SQLite database

In [None]:
# Number of entities flagged as being in Wikipedia. This must be an integer:
# the push cell computes `total_entities-OFFSET` for the progress bar, which
# fails on a Query object. OFFSET is not applied here because that cell
# already subtracts it.
with Session() as session:
    total_entities = (
        session.query(WikidataEntity)
        .join(WikidataID, WikidataEntity.id == WikidataID.id)
        .filter(WikidataID.in_wikipedia == True)  # noqa: E712 - builds a SQL expression
        .count()
    )

def get_entity(session):
    """Yield every WikidataEntity whose WikidataID row has in_wikipedia set.

    The first OFFSET rows of the query are skipped and the results are
    requested with yield_per(QUERY_BATCH_SIZE).
    """
    in_wikipedia_query = (
        session.query(WikidataEntity)
        .join(WikidataID, WikidataEntity.id == WikidataID.id)
        .filter(WikidataID.in_wikipedia == True)
    )
    yield from in_wikipedia_query.offset(OFFSET).yield_per(QUERY_BATCH_SIZE)

In [None]:
# Textify every entity returned by get_entity, split it into chunks and queue
# each chunk on graph_store under the id "<QID>_<LANGUAGE>_<chunk number>".
with tqdm(total=total_entities-OFFSET) as progressbar:
    with Session() as session:
        for entity in get_entity(session):
            progressbar.update(1)
            chunks = textifier.chunk_text(entity, graph_store.tokenizer, max_length=graph_store.max_token_size)
            # Chunk numbers are 1-based in both the document id and the metadata.
            for chunk_number, chunk in enumerate(chunks, start=1):
                md5_hash = hashlib.md5(chunk.encode('utf-8')).hexdigest()
                metadata = {
                    "MD5": md5_hash,
                    "Label": entity.label,
                    "Description": entity.description,
                    "Aliases": entity.aliases,
                    "Date": datetime.now().isoformat(),
                    "QID": entity.id,
                    "ChunkID": chunk_number,
                    "Language": LANGUAGE,
                    "DumpDate": DUMPDATE
                }
                graph_store.add_document(id=f"{entity.id}_{LANGUAGE}_{chunk_number}", text=chunk, metadata=metadata)

            tqdm.write(progressbar.format_meter(progressbar.n, progressbar.total, progressbar.format_dict["elapsed"])) # tqdm is not working in docker compose. This is the alternative

        # Presumably sends whatever is still queued in graph_store - confirm in AstraDBConnect.
        graph_store.push_batch()

#### Copy from one Astra collection to another

In [None]:
# Copy the documents whose QID is in the sample from 'wikidata_1' to 'wikidata_2'.
with open("../API_tokens/datastax_wikidata_nvidia.json") as token_file:
    datastax_token = json.load(token_file)

# Source collection, read directly through the Data API client.
COLLECTION_NAME = 'wikidata_1'
client = DataAPIClient(datastax_token['ASTRA_DB_APPLICATION_TOKEN'])
database0 = client.get_database(datastax_token['ASTRA_DB_API_ENDPOINT'])
wikiDataCollection = database0.get_collection(COLLECTION_NAME)

# Destination collection, written through AstraDBConnect.
COLLECTION_NAME = 'wikidata_2'
graph_store = AstraDBConnect(datastax_token, COLLECTION_NAME, model='jina', batch_size=4, cache_embeddings=False)

# Build the lookup once: testing membership against the array re-scans every
# sample QID for each document in the collection.
sample_qids = set(sample_ids['QID'].values)

# NOTE(review): the total is a hard-coded document count - confirm it still
# matches the source collection.
with tqdm(total=1347786) as progressbar:
    for item in wikiDataCollection.find():
        progressbar.update(1)
        if item['metadata']['QID'] in sample_qids:
            graph_store.add_document(id=item['_id'], text=item['content'], metadata=item['metadata'])

    # Presumably sends whatever is still queued in graph_store - confirm in AstraDBConnect.
    graph_store.push_batch()

#### Check if all sample IDs are in Astra

In [None]:
# Record, for every sample QID, whether an English document for it exists in
# the collection named by COLLECTION_NAME.
with open("../API_tokens/datastax_wikidata_nvidia.json") as token_file:
    datastax_token = json.load(token_file)
COLLECTION_NAME = 'wikidata_1'

client = DataAPIClient(datastax_token['ASTRA_DB_APPLICATION_TOKEN'])
database0 = client.get_database(datastax_token['ASTRA_DB_API_ENDPOINT'])
wikiDataCollection = database0.get_collection(COLLECTION_NAME)

# NOTE(review): pickle executes arbitrary code on load; only use a trusted file.
with open("../data/Evaluation Data/Sample IDs (EN).pkl", "rb") as sample_file:
    sample_ids = pickle.load(sample_file)

# Use one column name throughout. The previous code initialised
# f'in_{COLLECTION_NAME}' but then read and wrote 'in_wikidata_test_v1'.
presence_column = f'in_{COLLECTION_NAME}'
sample_ids[presence_column] = False

# Only rows not yet marked present are queried (all rows right after the reset above).
for qid in tqdm(sample_ids[~sample_ids[presence_column]]['QID'].values):
    item = wikiDataCollection.find_one({'metadata.QID': f'{qid}', 'metadata.Language': 'en'})
    if item is not None:
        sample_ids.loc[sample_ids['QID'] == qid, presence_column] = True