In [14]:
import sys
sys.path.append('../src')

from wikidataDB import Session, WikidataID, WikidataEntity
from wikidataEmbed import WikidataTextifier, JinaAIEmbeddings
from tqdm import tqdm
import json
import os
import torch
from langchain_core.documents import Document
from ragstack_langchain.graph_store import CassandraGraphStore
import cassio

In [None]:
from datasets import load_dataset

# Convert the raw JSON test file into an on-disk Hugging Face dataset.
# A single local data file is exposed by `datasets` as one "train" split.
test_dataset = load_dataset(path="json", data_files="test.json", split="train")
test_dataset.save_to_disk(dataset_path="test.hf")

In [None]:
# Astra DB credentials for the Cassandra graph store. cassio.init(auto=True) below
# presumably picks these up from the environment variables set here.
# The file is read inside `with` so the handle is closed (json.load(open(...)) leaked it).
with open("../API_tokens/datastax_token.json") as token_file:
    datastax_token = json.load(token_file)
os.environ["ASTRA_DB_DATABASE_ID"] = datastax_token['database_id']
os.environ["ASTRA_DB_APPLICATION_TOKEN"] = datastax_token['token']

BATCH_SIZE = 1  # NOTE(review): not read by the Jina push loop, which batches by BATCH_LENGTH
BATCH_LENGTH = 5000  # flush a batch once its chunk texts total at least this many characters

cassio.init(auto=True)
embeddings = JinaAIEmbeddings(embedding_dim=1024)
textifier = WikidataTextifier(with_claim_aliases=True, with_property_aliases=False)
graph_store = CassandraGraphStore(
    embeddings,
    node_table="Wikidata_entities_v1",
)

In [16]:
with open("../data/Wikidata/pushed_embeddings.json", "r+") as file:
    json_str = file.read().strip()
    if json_str.endswith(','):
        json_str = json_str[:-1]+"}"
    prev_data = json.loads(json_str)

In [None]:
# Resume point for the push loop: one entity per logged first chunk ("<QID>_1").
# `endswith` is required -- a substring test for "_1" also matches "_10".."_19"
# and would overcount entities with ten or more chunks.
# Back off a few rows as a safety margin (already-pushed chunks are skipped via
# `prev_data` anyway), clamped at 0 because SQL OFFSET cannot be negative.
RESUME_SAFETY_MARGIN = 5
offset = max(0, sum(1 for key in prev_data if key.endswith("_1")) - RESUME_SAFETY_MARGIN)
offset

In [None]:
# Embed every Wikipedia-linked entity with the local Jina model and push its
# chunks to the Cassandra graph store, skipping chunks already listed in
# `prev_data`. Every pushed chunk is appended to the pushed-log file so an
# interrupted run can resume.
# NOTE(review): the entity query has no ORDER BY, so resuming via OFFSET assumes
# the database returns rows in a stable order; the `prev_data` check is what
# actually prevents duplicate pushes.
PUSHED_LOG_PATH = "../data/Wikidata/pushed_embeddings.json"
TOTAL_ENTITIES = 9203531  # hard-coded row count, used only for the progress bar


def _log_pushed(docs):
    """Append one `"<QID>_<chunk>": 1, ` record per document to the pushed log."""
    # Single-quoted f-string so the inner ["QID"] subscripts are valid before Python 3.12.
    records = "".join(f'"{d.metadata["QID"]}_{d.metadata["ChunkID"]}": 1, ' for d in docs)
    with open(PUSHED_LOG_PATH, "a+") as log_file:
        log_file.write(records)


def _push_batch(docs, progressbar):
    """Push `docs` in one call, falling back to one document at a time on CPU.

    The fallback moves the embedding model to CPU and always moves it back to
    CUDA afterwards, even if a single-document push raises (the exception then
    propagates and stops the cell).
    """
    try:
        graph_store.add_documents(docs)
        torch.cuda.empty_cache()
        _log_pushed(docs)
        progressbar.set_description(f"Batch Size: {len(docs)}")
    except Exception as e:
        # Presumably a GPU out-of-memory on a large batch (the fallback runs on CPU) -- TODO confirm.
        torch.cuda.empty_cache()
        progressbar.set_description("Batch Size: 1")
        print(e)
        embeddings.model = embeddings.model.to('cpu')
        try:
            for doc in reversed(docs):
                graph_store.add_documents([doc])
                torch.cuda.empty_cache()
                _log_pushed([doc])
        finally:
            embeddings.model = embeddings.model.to('cuda')


with tqdm(total=TOTAL_ENTITIES) as progressbar:
    with Session() as session:
        entities = (
            session.query(WikidataEntity)
            .join(WikidataID, WikidataEntity.id == WikidataID.id)
            .filter(WikidataID.in_wikipedia == True)  # noqa: E712 -- SQLAlchemy column expression
            .offset(offset)
            .yield_per(1000)
        )
        doc_batch = []
        batch_length = 0  # total characters of chunk text currently buffered
        progressbar.update(offset)

        for entity in entities:
            progressbar.update(1)
            chunks = embeddings.chunk_text(entity, textifier)
            for chunk_id, chunk in enumerate(chunks, start=1):
                if f"{entity.id}_{chunk_id}" in prev_data:
                    continue
                doc_batch.append(Document(page_content=chunk, metadata={"QID": entity.id, "ChunkID": chunk_id}))
                batch_length += len(chunk)

                if batch_length >= BATCH_LENGTH:
                    _push_batch(doc_batch, progressbar)
                    doc_batch = []
                    batch_length = 0

        # Flush the final partial batch, which would otherwise never be pushed.
        if doc_batch:
            _push_batch(doc_batch, progressbar)
            doc_batch = []
            batch_length = 0

In [18]:
import torch.cuda  # binds `torch`, like the top-level import; kept so this cell runs on its own

# Hand cached-but-unused GPU memory back to the driver after the push loop.
torch.cuda.empty_cache()

In [32]:
# Smoke-test the graph store with a broad similarity search (up to 1000 hits).
# "depth": 0 presumably disables graph traversal, i.e. pure vector search -- TODO confirm.
# `invoke` replaces the deprecated BaseRetriever.get_relevant_documents.
vector_retriever = graph_store.as_retriever(search_kwargs={"k": 1000, "depth": 0})
results = vector_retriever.invoke("Is this a question?")

In [6]:
from cassio.config import check_resolve_keyspace, check_resolve_session

# Re-initialise cassio in auto mode, then pull out the raw CQL session and the
# default keyspace so the next cell can run statements directly.
cassio.init(auto=True)
keyspace = check_resolve_keyspace()
session = check_resolve_session()

In [None]:
# WARNING: destructive -- permanently drops `wikidata_entities` from the cassio keyspace.
# NOTE(review): the graph store above writes to "Wikidata_entities_v1"; this looks
# like an older table, but confirm it is really meant to be deleted on every run.
drop_stmt = f"DROP TABLE IF EXISTS {keyspace}.wikidata_entities"
session.execute(drop_stmt)

In [12]:
import sys
if '../src' not in sys.path:  # don't grow sys.path every time the cell is re-run
    sys.path.append('../src')

from wikidataDB import Session, WikidataID, WikidataEntity
from wikidataEmbed import WikidataTextifier

import json
from langchain_astradb import AstraDBVectorStore
from langchain_core.documents import Document
from astrapy.info import CollectionVectorServiceOptions
from transformers import AutoTokenizer
from tqdm import tqdm
import requests
import time

# Credentials for the Astra DB instance whose collection embeds with the NVIDIA
# vector service. Read inside `with` so the file handle is closed.
with open("../API_tokens/datastax_wikidata_nvidia.json") as token_file:
    datastax_token = json.load(token_file)
ASTRA_DB_DATABASE_ID = datastax_token['ASTRA_DB_DATABASE_ID']
ASTRA_DB_APPLICATION_TOKEN = datastax_token['ASTRA_DB_APPLICATION_TOKEN']
ASTRA_DB_API_ENDPOINT = datastax_token["ASTRA_DB_API_ENDPOINT"]
ASTRA_DB_KEYSPACE = datastax_token["ASTRA_DB_KEYSPACE"]

BATCH_SIZE = 100  # documents per add_documents request (also the DB yield_per size)

textifier = WikidataTextifier(with_claim_aliases=False, with_property_aliases=False)
# e5 is a standard Hub model, so trust_remote_code (which allows executing code
# downloaded from the model repo) is not needed.
# NOTE(review): chunks are sized with the e5 tokenizer while the collection embeds
# with NV-Embed-QA; confirm the chunk limit also fits that model's tokenizer.
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-large-unsupervised', clean_up_tokenization_spaces=False)

collection_vector_service_options = CollectionVectorServiceOptions(
    provider="nvidia",
    model_name="NV-Embed-QA"
)

# Named `graph_store` for continuity with the cells below, but this is a plain
# vector store: no local embedding model is passed because the collection's
# vector service (provider "nvidia") computes the embeddings.
graph_store = AstraDBVectorStore(
    collection_name="wikidata",
    collection_vector_service_options=collection_vector_service_options,
    token=ASTRA_DB_APPLICATION_TOKEN,
    api_endpoint=ASTRA_DB_API_ENDPOINT,
    namespace=ASTRA_DB_KEYSPACE,
)

# Smoke test. "depth" is a graph-store option and is dropped for a plain vector
# store; `invoke` replaces the deprecated get_relevant_documents.
vector_retriever = graph_store.as_retriever(search_kwargs={"k": 1000})
results = vector_retriever.invoke("Is this a question?")

In [5]:
from cassio.config import check_resolve_keyspace, check_resolve_session
import cassio

# Point cassio explicitly at the NVIDIA-backed database (rather than auto mode)
# and grab a raw CQL session for direct existence queries on the collection table.
cassio.init(token=ASTRA_DB_APPLICATION_TOKEN, database_id=ASTRA_DB_DATABASE_ID)
keyspace = check_resolve_keyspace()
session_cassio = check_resolve_session()

In [None]:
# Find where the previous Astra push run stopped. Starting from a known-good
# offset, walk the same entity query and probe the `wikidata` table for each
# entity's first chunk ("<QID>_1"). The first entity without one is where the
# push loop should resume; that offset is displayed at the end of the cell.
# NOTE(review): the entity query has no ORDER BY, so this assumes the database
# returns rows in the same order on every run -- confirm before relying on it.
offset = 8342800
n = 0  # entities probed so far
resume_offset = None
with tqdm(total=9203531) as progressbar:
    with Session() as session:
        entities = (
            session.query(WikidataEntity)
            .join(WikidataID, WikidataEntity.id == WikidataID.id)
            .filter(WikidataID.in_wikipedia == True)  # noqa: E712 -- SQLAlchemy column expression
            .offset(offset)
            .yield_per(BATCH_SIZE)
        )
        progressbar.update(offset)
        for entity in entities:
            progressbar.update(1)
            # Only existence matters, so fetch the key column rather than the whole row.
            row = session_cassio.execute(
                f"SELECT key FROM {ASTRA_DB_KEYSPACE}.wikidata WHERE key = (1, '{entity.id}_1') LIMIT 1;"
            ).one()
            n += 1
            if row is None:
                # The missing entity is the n-th row after `offset`, i.e. index offset + n - 1.
                resume_offset = offset + n - 1
                break
        else:
            # Every scanned entity already has its first chunk stored.
            resume_offset = offset + n

resume_offset

In [None]:
# Push every Wikipedia-linked entity to the Astra `wikidata` collection,
# BATCH_SIZE chunks per request, with ids "<QID>_<chunk>". With explicit ids,
# re-pushing an existing chunk presumably replaces it rather than duplicating it
# -- TODO confirm against the langchain_astradb docs.
# NOTE(review): the entity query has no ORDER BY, so resuming via OFFSET assumes
# a stable row order.
offset = 8342800
MAX_PUSH_ATTEMPTS = 5  # attempts per batch before the error is re-raised
CONNECTIVITY_PROBE_URL = "https://www.google.com"


def _wait_for_connection():
    """Block until CONNECTIVITY_PROBE_URL answers HTTP 200, polling every 5 seconds."""
    while True:
        try:
            if requests.get(CONNECTIVITY_PROBE_URL, timeout=5).status_code == 200:
                return
        except Exception:
            print("Waiting for internet connection...")
        # Sleep on non-200 responses too, so the probe never spins in a tight loop.
        time.sleep(5)


def _push_with_retry(docs, ids):
    """Push one batch, retrying the SAME batch after connectivity returns.

    Re-raises the last error after MAX_PUSH_ATTEMPTS failures, so persistent
    (non-network) errors surface instead of looping forever. The caller's buffers
    are left untouched on failure, so the recovery cell can push them later.
    """
    for attempt in range(1, MAX_PUSH_ATTEMPTS + 1):
        try:
            graph_store.add_documents(docs, ids=ids)
            return
        except Exception as e:
            print(e)
            if attempt == MAX_PUSH_ATTEMPTS:
                raise
            _wait_for_connection()


with tqdm(total=9203531) as progressbar:
    with Session() as session:
        entities = (
            session.query(WikidataEntity)
            .join(WikidataID, WikidataEntity.id == WikidataID.id)
            .filter(WikidataID.in_wikipedia == True)  # noqa: E712 -- SQLAlchemy column expression
            .offset(offset)
            .yield_per(BATCH_SIZE)
        )
        progressbar.update(offset)
        doc_batch = []
        ids_batch = []

        for entity in entities:
            progressbar.update(1)
            chunks = textifier.chunk_text(entity, tokenizer)
            for chunk_id, chunk in enumerate(chunks, start=1):
                doc_batch.append(Document(page_content=chunk, metadata={"QID": entity.id, "ChunkID": chunk_id}))
                ids_batch.append(f"{entity.id}_{chunk_id}")

                if len(doc_batch) >= BATCH_SIZE:
                    _push_with_retry(doc_batch, ids_batch)
                    doc_batch = []
                    ids_batch = []

        # Flush the final partial batch.
        if doc_batch:
            _push_with_retry(doc_batch, ids_batch)

In [None]:
# Manual recovery: push whatever is still buffered in `doc_batch`/`ids_batch`
# (e.g. after the push loop raised). The guard avoids sending an empty request,
# and the buffers are cleared on success so re-running this cell is a no-op.
if doc_batch:
    graph_store.add_documents(doc_batch, ids=ids_batch)
    doc_batch = []
    ids_batch = []