In [None]:
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
import torch


## Load data

In [7]:
# Load the article titles together with the metadata columns that the
# vector-DB ingestion cell further down reads from `df` ('year',
# 'n_citation', 'gov_score'). Loading only 'title' leaves those columns
# missing when the notebook is run top to bottom.
# NOTE(review): assumes the CSV contains these three columns; confirm against
# the file (read_csv raises ValueError here if any of them is absent).
df = pd.read_csv(
    "..//data//interim//articles_with_score_df.csv",
    usecols=['title', 'year', 'n_citation', 'gov_score'],
)

### Prepare embeddings

In [None]:
# Decide once whether a CUDA device is visible to torch; the answer drives
# both the diagnostic printout and the device the model is loaded on.
use_cuda = torch.cuda.is_available()

if use_cuda:
    # Diagnostics are only printed when a GPU is going to be used.
    print("CUDA is available: ", use_cuda)
    print("Number of CUDA devices: ", torch.cuda.device_count())
    print("CUDA current device: ", torch.cuda.current_device())
    print("CUDA device name: ", torch.cuda.get_device_name(0))

# Same checkpoint on either device; only the placement differs.
model = SentenceTransformer("all-MiniLM-L6-v2", device='cuda' if use_cuda else 'cpu')

In [None]:
titles = df['title'].tolist()

# Encode every title in a single call so the model can batch the inputs
# internally instead of running one forward pass per title from a Python
# loop. `show_progress_bar=True` keeps progress feedback during the run.
# `convert_to_numpy=True` returns one 2-D array (n_titles x embedding_dim);
# wrapping it in list() yields one 1-D array per title, which is the per-row
# format the ingestion cell expects (it calls `.tolist()` on each entry).
titles_embeddings = list(
    model.encode(titles, show_progress_bar=True, convert_to_numpy=True)
)

df['embedding'] = titles_embeddings

In [10]:
# Persist the frame (including the 'embedding' column) so the embeddings
# can be reloaded later without re-encoding the titles.
embeddings_pickle_path = '..//data//interim//titles_with_embedings.pkl'
df.to_pickle(embeddings_pickle_path)

## Load data into the vector DB (Chroma)

In [None]:
import chromadb
from chromadb.config import Settings

In [11]:
# Reload the frame with the pre-computed embeddings from the interim pickle.
# Pickle files can execute code on load; only read files this project wrote.
embeddings_pickle_path = '..//data//interim//titles_with_embedings.pkl'
df = pd.read_pickle(embeddings_pickle_path)

In [12]:
import time

chroma_client = chromadb.PersistentClient(path="../data/chroma")
# Alternative: connect to a running Chroma server instead of a local store.
#chroma_client = chromadb.HttpClient(host="localhost", port=8000, settings=Settings(allow_reset=True, anonymized_telemetry=False))

# Opening the collection is retried a bounded number of times with a short
# pause between attempts. An unbounded retry loop would spin forever (and
# flood the output) on a permanent failure such as a bad path or permissions.
MAX_COLLECTION_ATTEMPTS = 5
COLLECTION_RETRY_DELAY_SECONDS = 2

collection_status = False
for attempt in range(1, MAX_COLLECTION_ATTEMPTS + 1):
    try:
        document_collection = chroma_client.get_or_create_collection(name="articles_with_score")
        collection_status = True
        break
    except Exception as e:
        print(f"Attempt {attempt}/{MAX_COLLECTION_ATTEMPTS} failed:", e)
        if attempt == MAX_COLLECTION_ATTEMPTS:
            # Out of attempts: surface the real error instead of hiding it.
            raise
        time.sleep(COLLECTION_RETRY_DELAY_SECONDS)

In [None]:
batch_size = 5000
last_confirmed_id = 0

# Columns copied into every record's metadata. Fail early with a clear
# message if `df` was loaded without them, instead of part-way through
# the first batch.
metadata_columns = ['year', 'n_citation', 'gov_score']
missing_columns = [column for column in metadata_columns if column not in df.columns]
if missing_columns:
    raise KeyError(f"df is missing metadata columns required for ingestion: {missing_columns}")

for batch_start in tqdm(range(0, df.shape[0], batch_size), desc='Batches', unit='batch'):
    batch_df = df.iloc[batch_start:batch_start + batch_size]

    # Each stored embedding is converted to a plain Python list.
    batch_embeddings = [embedding.tolist() for embedding in batch_df['embedding']]
    batch_documents = batch_df['title'].tolist()
    # Build the metadata column-wise rather than with iterrows(): iterrows()
    # creates a Series per row, which is slow and does not preserve per-column
    # dtypes. Series.tolist() yields plain Python scalars for each column.
    batch_metadatas = [
        {'year': year, 'n_citation': n_citation, 'gov_score': gov_score}
        for year, n_citation, gov_score in zip(
            batch_df['year'].tolist(),
            batch_df['n_citation'].tolist(),
            batch_df['gov_score'].tolist(),
        )
    ]
    # Ids are the DataFrame index shifted by one, stored as strings.
    batch_ids = [str(index + 1) for index in batch_df.index]

    document_collection.add(
        embeddings=batch_embeddings,
        documents=batch_documents,
        metadatas=batch_metadatas,
        ids=batch_ids
    )

    # Updated only after add() returns, so it always names the last id of a
    # batch that was handed to the collection without an exception.
    last_confirmed_id = batch_df.index[-1] + 1


Batches: 100%|████████████████████████████████████████████████████████████████████| 171/171 [20:01<00:00,  7.03s/batch]


Size of the collection: 850406


### Health check

In [None]:
# Query the collection size once and reuse it. Calling count() again outside
# the try block would re-raise the very error the except clause just handled.
try:
    collection_size = document_collection.count()
except Exception as e:
    collection_size = None
    print("Failed to get collection size:", e)
else:
    print("Size of the collection:", collection_size)

# The collection is expected to hold exactly one record per DataFrame row.
expected_size = df.shape[0]
if collection_size is None:
    print("Data inconsistency check skipped: collection size is unavailable.")
elif collection_size == expected_size:
    print("Correct size of the articles collection:", collection_size)
else:
    print("Data inconsistency detected!!!", f"expected {expected_size}, got {collection_size}")

Correct size of the articles collection: 850406
