diff --git a/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/apis/vector_api.py b/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/apis/vector_api.py index bd4e755b9d..4c876bc89a 100644 --- a/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/apis/vector_api.py +++ b/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/apis/vector_api.py @@ -4,7 +4,7 @@ import time from crowd.eagle_eye.apis import EmbedAPI import itertools -from crowd.eagle_eye.config import QDRANT_HOST, QDRANT_PORT +from crowd.eagle_eye.config import QDRANT_HOST, QDRANT_PORT, QDRANT_API_KEY, IS_DEV_ENV from crowd.eagle_eye.infrastructure.logging import get_logger logger = get_logger(__name__) @@ -15,7 +15,7 @@ class VectorAPI: Class to interact with the vector database. """ - def __init__(self, do_init=False): + def __init__(self, do_init=False, cloud=True): """ Initialize the VectorAPI. @@ -24,17 +24,24 @@ def __init__(self, do_init=False): """ self.collection_name = "crowddev" - if not QDRANT_HOST: - host = "localhost" - else: - host = QDRANT_HOST + if cloud: - if not QDRANT_PORT: - port = 6333 - else: - port = QDRANT_PORT + if IS_DEV_ENV: + self.client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) - self.client = QdrantClient(host=host, port=port) + else: + self.client = QdrantClient( + host=QDRANT_HOST, + port=QDRANT_PORT, + prefer_grpc=True, + api_key=QDRANT_API_KEY, + ) + + else: + if IS_DEV_ENV: + self.client = QdrantClient(host='localhost', port=6333) + else: + self.client = QdrantClient(host='crowd-qdrant', port=6333) if do_init: self.client.recreate_collection( @@ -69,24 +76,28 @@ def _chunks(iterable, batch_size=80): yield chunk chunk = list(itertools.islice(it, batch_size)) - def upsert(self, points): + def upsert(self, points, processed=False): """ Upsert a list of points into the vector database. Args: points ([Point]): points to upsert. 
+ processed (Bool): whether the points have already been turned into Qdrant vectors """ if (len(points) == 0): return - vectors = [ - models.PointStruct( - id=point.id, - payload=point.payload_as_dict(), - vector=point.embed, - ) for point in points - ] + if not processed: + vectors = [ + models.PointStruct( + id=point.id, + payload=point.payload_as_dict(), + vector=point.embed, + ) for point in points + ] + else: + vectors = points for vectors_chunk in VectorAPI._chunks(vectors, batch_size=100): try: @@ -101,6 +112,9 @@ def upsert(self, points): return "OK" def count(self): + """ + Count the number of vectors in a collection. + """ return self.client.count( collection_name=self.collection_name, exact=True, @@ -262,3 +276,21 @@ def keyword_match(self, ndays, exclude, exact_keywords, platform=None): 'exact_keywords': exact_keywords, }) raise e + + def scroll(self, page): + """ + Iterate through points with pagination. + + Args: + next_page (int): the page to fetch + + Returns: + tuple(list, int): (vectors, next page) + """ + return self.client.scroll( + collection_name=self.collection_name, + offset=page, + limit=100, + with_payload=True, + with_vectors=True, + ) diff --git a/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/config.py b/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/config.py index 1eb37276fb..cb4b0614f0 100644 --- a/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/config.py +++ b/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/config.py @@ -21,6 +21,7 @@ QDRANT_HOST = os.environ.get("CROWD_QDRANT_HOST") QDRANT_PORT = os.environ.get("CROWD_QDRANT_PORT") +QDRANT_API_KEY = os.environ.get("CROWD_QDRANT_API_KEY") SQS_HOST = os.environ.get("CROWD_SQS_HOST") SQS_PORT = os.environ.get("CROWD_SQS_PORT") diff --git a/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/sync.py b/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/sync.py index a850494a1e..2b33b81294 100644 --- a/premium/eagle-eye/crowd-eagle-eye/crowd/eagle_eye/sync.py +++ 
"""
One-off sync script: copy every point from one Qdrant instance into another.

The source is the ``VectorAPI(cloud=False)`` instance and the destination is
the ``VectorAPI(cloud=True)`` instance. The destination is created with
``do_init=True``, which makes VectorAPI call ``recreate_collection`` on it.
Points are read page by page with ``VectorAPI.scroll`` and written with
``VectorAPI.upsert(..., processed=True)``, so stored vectors and payloads are
copied as-is and nothing is re-embedded.
"""
from pprint import pprint as pp
from crowd.eagle_eye.apis.vector_api import VectorAPI
from qdrant_client.http import models

import dotenv

# NOTE(review): crowd.eagle_eye.config reads the CROWD_QDRANT_* variables with
# os.environ.get at import time, and VectorAPI is imported above, before
# .env.sync is loaded here. Unless config loads an env file itself, these
# values arrive too late to affect the clients -- confirm the intended order.
found = dotenv.find_dotenv(".env.sync")
dotenv.load_dotenv(found)

# Source instance: read only, so its collection must not be re-created.
vector_out = VectorAPI(do_init=False, cloud=False)
print("Vector out initialised. It has {} vectors".format(vector_out.count()))

# Destination instance: do_init=True re-creates the collection before the copy.
vector_in = VectorAPI(do_init=True, cloud=True)
print("Vector in initialised")

number = vector_out.count()
offset = None  # None asks scroll for the first page

while True:
    # scroll returns (points on this page, offset of the next page).
    points, offset = vector_out.scroll(offset)

    vectors_to_add = [
        models.PointStruct(
            id=point.id,
            payload=point.payload,
            vector=point.vector,
        )
        for point in points
    ]
    # processed=True: these are already PointStructs, upsert must not convert them.
    vector_in.upsert(vectors_to_add, processed=True)

    # Report after every batch, including the last one.
    print(f"Synced {vector_in.count()} of {number} vectors")

    # Only None means "no further page"; a truthiness test would also stop on
    # a falsy but valid offset such as point id 0.
    if offset is None:
        break