In [None]:
from qdrant_client import QdrantClient, models
from app.core.config import settings as config
from app.utils.utils import split_docs
from app.ingestion.web_loader.bs_loader import load_web_docs
from app.ingestion.web_loader.bs_utils import urls
from app.db.vector_db import VectorDB
import os
import pickle
from tenacity import (
    retry,
    wait_exponential,
    stop_after_attempt,
    retry_if_exception_type,
)
import requests
from fastembed import SparseTextEmbedding

In [None]:
# Admin client for collection management. The endpoint is passed via the explicit
# `url=` keyword (the first positional parameter of QdrantClient is `location`).
client = QdrantClient(url=config.qdrant_url, api_key=config.qdrant_api_key)

In [None]:
# Hybrid-search collection: a named dense vector ("dense", cosine distance,
# size from config) and a named sparse vector ("bm25", IDF modifier).
# create_collection errors if the collection already exists, so guard it to keep
# this cell safe under Restart & Run All.
if client.collection_exists(collection_name=config.qdrant_collection_name):
    print(
        f"Collection '{config.qdrant_collection_name}' already exists; "
        "skipping creation."
    )
else:
    client.create_collection(
        collection_name=config.qdrant_collection_name,
        vectors_config={
            "dense": models.VectorParams(
                size=config.embeddings_dim, distance=models.Distance.COSINE
            )
        },
        sparse_vectors_config={
            "bm25": models.SparseVectorParams(modifier=models.Modifier.IDF)
        },
    )

In [None]:
# client.delete_collection(collection_name=config.qdrant_collection_name)

In [None]:
# Verify the target collection is visible to this client rather than eyeballing
# the raw listing; the names list is shown as the cell output.
collection_names = [c.name for c in client.get_collections().collections]
assert config.qdrant_collection_name in collection_names, (
    f"Collection '{config.qdrant_collection_name}' not found; "
    f"available: {collection_names}"
)
collection_names

In [None]:
CACHE_FILE = "all_docs.pkl"

# Scraping the URLs is slow, so reuse documents pickled by a previous run.
# NOTE: pickle.load can execute arbitrary code; only load a cache this notebook wrote.
if not os.path.exists(CACHE_FILE):
    all_docs = load_web_docs(urls)
    print("Saving documents to cache...")
    with open(CACHE_FILE, "wb") as fh:
        pickle.dump(all_docs, fh)
else:
    print("Loading cached documents...")
    with open(CACHE_FILE, "rb") as fh:
        all_docs = pickle.load(fh)

In [None]:
chunks = split_docs(all_docs)
# An empty result would silently lead to an empty collection further down.
assert chunks, "split_docs returned no chunks; check the loaded documents"

In [None]:
print(len(chunks))
# Peek at one sample chunk. Index 20 is arbitrary; fall back to the last chunk
# so this cell does not raise IndexError on small corpora.
if chunks:
    print(chunks[min(20, len(chunks) - 1)])

In [None]:
# Project wrapper around the vector store (also used here for dense embeddings).
vector_db = VectorDB(config)

# fastembed BM25 model for the "bm25" sparse vectors.
sparse_model = SparseTextEmbedding("Qdrant/bm25")

# Accumulators for per-chunk dense and sparse embeddings.
dense_embeddings, sparse_embeddings = [], []

# Embeddings Generation

In [None]:
@retry(
    # Retry on any Exception. The original tuple also listed
    # requests.exceptions.RequestException, but that is already a subclass of
    # Exception, so the effective policy is unchanged.
    # NOTE(review): this also retries programming errors (TypeError, ...);
    # consider narrowing to the embedding backend's transient error types.
    retry=retry_if_exception_type(Exception),
    wait=wait_exponential(multiplier=1, min=4, max=60),  # Exponential backoff, bounded 4s..60s
    stop=stop_after_attempt(5),
    # After the final attempt, re-raise the underlying exception instead of a
    # tenacity.RetryError wrapper so callers log the real cause.
    reraise=True,
)
def get_embedding_with_retry(text):
    """Embed one text with both the dense model and the BM25 sparse model.

    Args:
        text: The chunk text to embed.

    Returns:
        A ``(dense_embedding, sparse_embedding)`` tuple: the result of
        ``vector_db.get_embeddings(text)`` and the first embedding yielded by
        ``sparse_model.embed([text])``.

    Raises:
        Exception: whatever the last of the 5 attempts raised.
    """
    dense_embedding = vector_db.get_embeddings(text)
    sparse_embedding = list(sparse_model.embed([text]))[0]

    return dense_embedding, sparse_embedding


# Reset the accumulators so re-running this cell does not append duplicates.
dense_embeddings = []
sparse_embeddings = []
failed_chunk_ids = []

for i, chunk in enumerate(chunks):
    try:
        dense_embedding, sparse_embedding = get_embedding_with_retry(chunk.text)
    except Exception as e:
        # Record the failure and keep going so all bad chunks are reported at once.
        failed_chunk_ids.append(i)
        print(f"Failed after retries on chunk {i}: {str(e)}")
        continue

    dense_embeddings.append(dense_embedding)
    sparse_embeddings.append(sparse_embedding)

    processed = i + 1
    if processed % 100 == 0 or processed == len(chunks):
        print(f"Processed {processed}/{len(chunks)} chunks")

# The embedding lists are later uploaded alongside `chunks` as parallel
# sequences; a skipped chunk would shift every later embedding onto the wrong
# chunk. Stop here instead of silently corrupting the collection.
if failed_chunk_ids:
    raise RuntimeError(
        f"{len(failed_chunk_ids)} chunk(s) failed to embed "
        f"(first indices: {failed_chunk_ids[:20]}); embeddings are not aligned "
        "with `chunks`. Fix the cause and re-run this cell."
    )

In [None]:
# Spot-check one dense vector: its length (compare with the collection's
# configured size) and a few leading values, instead of dumping the full vector.
# NOTE(review): assumes each dense embedding is a flat list/array of floats — confirm.
if len(dense_embeddings) > 4:
    sample_dense = dense_embeddings[4]
    print(f"dim={len(sample_dense)} (collection expects {config.embeddings_dim})")
    print(sample_dense[:8])

In [None]:
# Printing every sparse embedding floods the notebook; show the count and one sample.
print(f"{len(sparse_embeddings)} sparse embeddings")
if sparse_embeddings:
    print(sparse_embeddings[0])

# Cache and Load Embeddings

In [None]:
DENSE_EMBEDDINGS_CACHE_FILE = "dense_embeddings.pkl"

# Only persist a complete result that lines up with `chunks`. Without this
# guard, running the notebook with the embedding loop skipped (so
# dense_embeddings is still []) would overwrite a good cache with an empty list.
if dense_embeddings and len(dense_embeddings) == len(chunks):
    with open(DENSE_EMBEDDINGS_CACHE_FILE, "wb") as f:
        pickle.dump(dense_embeddings, f)
else:
    print(
        f"Not caching dense embeddings: have {len(dense_embeddings)}, "
        f"expected {len(chunks)}"
    )

In [None]:
if os.path.exists(DENSE_EMBEDDINGS_CACHE_FILE):
    print("Loading cached dense embeddings...")
    # NOTE: pickle.load can execute arbitrary code; only load a cache this
    # notebook produced.
    with open(DENSE_EMBEDDINGS_CACHE_FILE, "rb") as f:
        dense_embeddings = pickle.load(f)

    # The cache holds only dense vectors, matched to `chunks` by position.
    # A cache written for a different chunking would silently misalign.
    if len(dense_embeddings) != len(chunks):
        raise ValueError(
            f"Cached dense embeddings ({len(dense_embeddings)}) do not match the "
            f"current chunk count ({len(chunks)}); delete "
            f"{DENSE_EMBEDDINGS_CACHE_FILE} and regenerate."
        )

    # Regenerate sparse embeddings quickly: one batched embed() call over all
    # chunk texts (results are yielded in input order) instead of one call per chunk.
    print("Regenerating sparse embeddings...")
    sparse_embeddings = list(sparse_model.embed([chunk.text for chunk in chunks]))

In [None]:
print(len(dense_embeddings))
print(len(sparse_embeddings))
# The upload passes these lists alongside `chunks` as parallel sequences, so all
# three must have the same length before anything is written to the store.
assert len(dense_embeddings) == len(sparse_embeddings) == len(chunks), (
    f"Misaligned inputs: {len(chunks)} chunks, {len(dense_embeddings)} dense, "
    f"{len(sparse_embeddings)} sparse embeddings"
)

# Add Embeddings to VectorDB

In [None]:
# Upload every chunk together with its dense and sparse vectors.
# NOTE(review): re-running this cell may insert duplicate points, depending on
# how VectorDB.add_documents assigns IDs — TODO confirm.
document_ids = vector_db.add_documents(
    docs=chunks,
    sparse_embeddings=sparse_embeddings,
    dense_embeddings=dense_embeddings,
)

In [None]:
collection_info = vector_db.client.get_collection(vector_db.collection_name)
print(f"Collection info: {collection_info}")

# Count the number of points in the collection. count() returns a CountResult
# object, so report its integer `.count` rather than the "count=N" repr.
point_count = vector_db.client.count(vector_db.collection_name, exact=True).count
print(f"Number of documents in vector store: {point_count}")

# Test Query against VectorDB

In [None]:
# Sparse (BM25) retrieval: encode the query with the same fastembed model and
# search the "bm25" named vector.
query = "what is sutd?"
bm25_query = next(sparse_model.query_embed(query))
bm25_vector = models.SparseVector(**bm25_query.as_object())
results = vector_db.client.query_points(
    collection_name=vector_db.collection_name,
    query=bm25_vector,
    using="bm25",
    limit=3,
)
print(results)

In [None]:
# Dense retrieval: embed the query with the same dense model used for the
# chunks and search the "dense" named vector.
query = "what is sutd?"
dense_query = vector_db.get_embeddings(query)
results = vector_db.client.query_points(
    collection_name=vector_db.collection_name,
    query=dense_query,
    using="dense",
    limit=3,
)
print(results)