In [1]:
import sys
sys.path.append('../src')

from wikidata_dumpreader import WikidataDumpReader
from wikidataDB import WikidataID, WikidataEntity, Session
from sqlalchemy import desc

# Row counts of the two tables; each count runs in its own short-lived session.
with Session() as session:
    entity_rows = session.query(WikidataEntity)
    count_entity = entity_rows.count()

with Session() as session:
    id_rows = session.query(WikidataID)
    count_id = id_rows.count()

In [1]:
import sys
sys.path.append('../src')

from wikidataDB import Session, WikidataID, WikidataEntity
from wikidataEmbed import WikidataEmbed
from tqdm import tqdm
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)

# Survey the token length of the text representation of every entity that is
# flagged as being in Wikipedia, and how many 7500-token pieces each one needs.
with Session() as session:
    # yield_per streams the result set in chunks of 1000 rows instead of
    # loading every entity at once.
    entities = session.query(WikidataEntity).join(WikidataID, WikidataEntity.id == WikidataID.id).filter(WikidataID.in_wikipedia == True).yield_per(1000)
    with_desc_tokens = []
    with_desc_embeds = []
    # Running totals, so the periodic averages below do not have to re-sum
    # the ever-growing lists.
    total_tokens = 0
    total_embeds = 0

    # The context manager closes the bar even when the loop is interrupted
    # (e.g. KeyboardInterrupt), so no manual cleanup cell is required.
    with tqdm(total=9203531) as progressbar:
        for entity in entities:
            progressbar.update(1)
            text = WikidataEmbed.entity_to_text(entity, with_desc=True)
            tokens_ids = tokenizer.encode(text)
            n_tokens = len(tokens_ids)

            # Show unusually short texts together with their entity ID for inspection.
            if n_tokens < 100:
                print(text)
                print(entity.id)

            # NOTE(review): 7500 looks like the per-embedding token budget;
            # a text of exactly 7500 tokens is counted as 2 pieces - confirm intended.
            n_embeds = (n_tokens // 7500) + 1
            with_desc_tokens.append(n_tokens)
            with_desc_embeds.append(n_embeds)
            total_tokens += n_tokens
            total_embeds += n_embeds

            # Update the progress description every 100 iterations
            if progressbar.n % 100 == 0:
                progressbar.set_description(f"Avg Token Size: {total_tokens/len(with_desc_embeds)}, Avg Embed Size: {total_embeds/len(with_desc_embeds)}")

  0%|          | 81/9203531 [00:00<17:14:52, 148.22it/s]

NMP22, Nuclear Matrix Protein.
Q3869592
NWA World Tag Team Championship, Professional wrestling championship. Attributes include: 
- instance of, that class of which this subject is a particular example and member; different from P279 (subclass of); for example: K2 is an instance of mountain; volcano is a subclass of mountain (and an instance of volcanic landform): professional wrestling championship, professional wrestling competition
Q3869701


Avg Token Size: 408.24, Avg Embed Size: 1.0:   0%|          | 110/9203531 [00:00<20:30:57, 124.61it/s]

KeyboardInterrupt: 

In [3]:
# Make sure the progress bar is closed, e.g. after the loop that created it
# was interrupted before it could finish.
progressbar.close()

Avg Token Size: 408.24, Avg Embed Size: 1.0:   0%|          | 112/9203531 [05:36<7681:11:10,  3.00s/it]


In [4]:
import sys
sys.path.append('../src')

import json
import astrapy

# Read the Astra DB credentials from a local JSON file. The context manager
# closes the file handle as soon as the JSON has been parsed.
with open("../API tokens/datastax_token.json") as token_file:
    datastax_token = json.load(token_file)
ASTRA_DB_APPLICATION_TOKEN = datastax_token['token']
ASTRA_DB_API_ENDPOINT = datastax_token['endpoint']
EMBED_DIM = 1024
SIMILARITY_METRIC = astrapy.constants.VectorMetric.COSINE
COLLECTION_NAME = "wikidata_en_v1"

client = astrapy.DataAPIClient(ASTRA_DB_APPLICATION_TOKEN)
database = client.get_database_by_api_endpoint(ASTRA_DB_API_ENDPOINT)

# Create the vector collection only if it does not exist yet, so re-running
# this cell does not attempt to create it a second time.
if COLLECTION_NAME not in database.list_collection_names():
    datastax_db = database.create_collection(
        COLLECTION_NAME,
        dimension=EMBED_DIM,
        metric=SIMILARITY_METRIC
    )

# Whether the collection was just created or already existed, bind the handle
# that the rest of the notebook uses.
datastax_db = database.get_collection(COLLECTION_NAME)

In [6]:
# Bare expression so the notebook displays the collection names currently in the database.
database.list_collection_names()

['wikidata_en_v1']

In [None]:
from transformers import AutoModel

# Load the pretrained embedding model first, then move it to the CUDA device.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)
model = model.to('cuda')

In [2]:
import sys
sys.path.append('../src')

from wikidataDB import Session, WikidataID, WikidataEntity
from tqdm import tqdm

batch_size = 100


def flush_embedding_batch(ids, texts):
    """Embed a batch of entity texts and store the vectors in ``datastax_db``.

    Relies on ``model`` and ``datastax_db`` created in earlier cells.

    Args:
        ids: Entity IDs, used as the document ``_id`` values.
        texts: Text of each entity, in the same order as ``ids``.

    Does nothing when the batch is empty.
    """
    if not ids:
        return
    embeddings = model.encode(texts, task="text-matching")
    datastax_db.insert_many(
        [
            {
                # One document per entity, keyed by its entity ID.
                "_id": entity_id,
                # Plain Python floats so the vector is JSON-serialisable.
                "$vector": [float(value) for value in vector],
            }
            for entity_id, vector in zip(ids, embeddings)
        ]
    )


# Embed every entity flagged as being in Wikipedia and upload the vectors in
# batches. WikidataEmbed is imported in an earlier cell of this notebook.
with Session() as session:
    entities = session.query(WikidataEntity).join(WikidataID, WikidataEntity.id == WikidataID.id).filter(WikidataID.in_wikipedia == True).yield_per(batch_size)

    text_batch = []
    id_batch = []
    # The context manager closes the bar even if embedding or insertion fails.
    with tqdm(total=9203531) as progressbar:
        for entity in entities:
            progressbar.update(1)
            text = WikidataEmbed.entity_to_text(entity, with_desc=True)
            text_batch.append(text)
            id_batch.append(entity.id)

            if len(id_batch) >= batch_size:
                flush_embedding_batch(id_batch, text_batch)
                text_batch = []
                id_batch = []

        # Upload the final partial batch that never reached batch_size.
        flush_embedding_batch(id_batch, text_batch)

Entity Types: {'Q': 9203000}: 100%|██████████| 9203531/9203531 [15:47<00:00, 9710.35it/s]  
