diff --git a/vector_search/management/commands/create_qdrant_collections.py b/vector_search/management/commands/create_qdrant_collections.py index 762abc7d26..17b3a419ed 100644 --- a/vector_search/management/commands/create_qdrant_collections.py +++ b/vector_search/management/commands/create_qdrant_collections.py @@ -3,7 +3,7 @@ from django.core.management.base import BaseCommand from vector_search.utils import ( - create_qdrand_collections, + create_qdrant_collections, ) @@ -26,8 +26,8 @@ def handle(self, *args, **options): # noqa: ARG002 """Create Qdrant collections""" if options["force"]: - create_qdrand_collections(force_recreate=True) + create_qdrant_collections(force_recreate=True) else: - create_qdrand_collections(force_recreate=False) + create_qdrant_collections(force_recreate=False) self.stdout.write("Created Qdrant collections") diff --git a/vector_search/management/commands/generate_embeddings.py b/vector_search/management/commands/generate_embeddings.py index 324654da71..3e75f6fe4f 100644 --- a/vector_search/management/commands/generate_embeddings.py +++ b/vector_search/management/commands/generate_embeddings.py @@ -6,7 +6,7 @@ from main.utils import clear_search_cache, now_in_utc from vector_search.tasks import embed_learning_resources_by_id, start_embed_resources from vector_search.utils import ( - create_qdrand_collections, + create_qdrant_collections, ) @@ -42,6 +42,12 @@ def add_arguments(self, parser): action="store_true", help="Skip embedding content files", ) + parser.add_argument( + "--overwrite", + dest="overwrite", + action="store_true", + help="Force overwrite existing embeddings", + ) for object_type in sorted(LEARNING_RESOURCE_TYPES): parser.add_argument( @@ -71,7 +77,7 @@ def handle(self, *args, **options): # noqa: ARG002 self.stdout.write(f" --{object_type}s") return if options["recreate_collections"]: - create_qdrand_collections(force_recreate=True) + create_qdrant_collections(force_recreate=True) if options["resource-ids"]: task = 
embed_learning_resources_by_id.delay( [ @@ -79,10 +85,13 @@ def handle(self, *args, **options): # noqa: ARG002 for resource_id in options["resource-ids"].split(",") ], skip_content_files=options["skip_content_files"], + overwrite=options["overwrite"], ) else: task = start_embed_resources.delay( - indexes_to_update, skip_content_files=options["skip_content_files"] + indexes_to_update, + skip_content_files=options["skip_content_files"], + overwrite=options["overwrite"], ) self.stdout.write( f"Started celery task {task} to index content for the following" diff --git a/vector_search/tasks.py b/vector_search/tasks.py index dcba089cb6..1a5c5948e3 100644 --- a/vector_search/tasks.py +++ b/vector_search/tasks.py @@ -31,10 +31,7 @@ chunks, now_in_utc, ) -from vector_search.constants import ( - RESOURCES_COLLECTION_NAME, -) -from vector_search.utils import embed_learning_resources, filter_existing_qdrant_points +from vector_search.utils import embed_learning_resources log = logging.getLogger(__name__) @@ -46,7 +43,7 @@ retry_backoff=True, rate_limit="600/m", ) -def generate_embeddings(ids, resource_type): +def generate_embeddings(ids, resource_type, overwrite): """ Generate learning resource embeddings and index in Qdrant @@ -57,7 +54,7 @@ def generate_embeddings(ids, resource_type): """ try: with wrap_retry_exception(*SEARCH_CONN_EXCEPTIONS): - embed_learning_resources(ids, resource_type) + embed_learning_resources(ids, resource_type, overwrite) except (RetryError, Ignore): raise except SystemExit as err: @@ -69,7 +66,7 @@ def generate_embeddings(ids, resource_type): @app.task(bind=True) -def start_embed_resources(self, indexes, skip_content_files): +def start_embed_resources(self, indexes, skip_content_files, overwrite): """ Celery task to embed all learning resources for given indexes @@ -89,7 +86,7 @@ def start_embed_resources(self, indexes, skip_content_files): blocklisted_ids = load_course_blocklist() index_tasks = [ - generate_embeddings.si(ids, COURSE_TYPE) + 
generate_embeddings.si(ids, COURSE_TYPE, overwrite) for ids in chunks( Course.objects.filter(learning_resource__published=True) .exclude(learning_resource__readable_id=blocklisted_ids) @@ -123,10 +120,7 @@ def start_embed_resources(self, indexes, skip_content_files): ) index_tasks = index_tasks + [ - generate_embeddings.si( - ids, - CONTENT_FILE_TYPE, - ) + generate_embeddings.si(ids, CONTENT_FILE_TYPE, overwrite) for ids in chunks( run_contentfiles, chunk_size=settings.QDRANT_CHUNK_SIZE, @@ -150,10 +144,7 @@ def start_embed_resources(self, indexes, skip_content_files): chunk_size=settings.QDRANT_CHUNK_SIZE, ): index_tasks.append( - generate_embeddings.si( - ids, - resource_type, - ) + generate_embeddings.si(ids, resource_type, overwrite) ) except: # noqa: E722 error = "start_embed_resources threw an error" @@ -166,7 +157,7 @@ def start_embed_resources(self, indexes, skip_content_files): @app.task(bind=True) -def embed_learning_resources_by_id(self, ids, skip_content_files): +def embed_learning_resources_by_id(self, ids, skip_content_files, overwrite): """ Celery task to embed specific resources @@ -190,10 +181,7 @@ def embed_learning_resources_by_id(self, ids, skip_content_files): embed_resources = resources.filter(resource_type=resource_type) [ index_tasks.append( - generate_embeddings.si( - chunk_ids, - resource_type, - ) + generate_embeddings.si(chunk_ids, resource_type, overwrite) ) for chunk_ids in chunks( embed_resources.order_by("id").values_list("id", flat=True), @@ -216,10 +204,7 @@ def embed_learning_resources_by_id(self, ids, skip_content_files): ).order_by("id") content_ids = run_contentfiles.values_list("id", flat=True) index_tasks = index_tasks + [ - generate_embeddings.si( - ids, - CONTENT_FILE_TYPE, - ) + generate_embeddings.si(ids, CONTENT_FILE_TYPE, overwrite) for ids in chunks( content_ids, chunk_size=settings.QDRANT_CHUNK_SIZE, @@ -249,27 +234,19 @@ def embed_new_learning_resources(self): published=True, created_on__gt=since, 
).exclude(resource_type=CONTENT_FILE_TYPE) - existing_readable_ids = [ - learning_resource.readable_id for learning_resource in new_learning_resources - ] - filtered_readable_ids = filter_existing_qdrant_points( - values=existing_readable_ids, - lookup_field="readable_id", - collection_name=RESOURCES_COLLECTION_NAME, - ) - filtered_resources = LearningResource.objects.filter( - readable_id__in=filtered_readable_ids + + resource_types = list( + new_learning_resources.values_list("resource_type", flat=True) ) - resource_types = list(filtered_resources.values_list("resource_type", flat=True)) tasks = [] for resource_type in resource_types: tasks.extend( [ - generate_embeddings.si(ids, resource_type) + generate_embeddings.si(ids, resource_type, overwrite=False) for ids in chunks( - filtered_resources.filter(resource_type=resource_type).values_list( - "id", flat=True - ), + new_learning_resources.filter( + resource_type=resource_type + ).values_list("id", flat=True), chunk_size=settings.QDRANT_CHUNK_SIZE, ) ] diff --git a/vector_search/tasks_test.py b/vector_search/tasks_test.py index 9c09acc615..7e87d18368 100644 --- a/vector_search/tasks_test.py +++ b/vector_search/tasks_test.py @@ -66,9 +66,13 @@ def test_start_embed_resources(mocker, mocked_celery, index): ) with pytest.raises(mocked_celery.replace_exception_class): - start_embed_resources.delay([index], skip_content_files=True) + start_embed_resources.delay([index], skip_content_files=True, overwrite=True) - generate_embeddings_mock.si.assert_called_once_with(resource_ids, index) + generate_embeddings_mock.si.assert_called_once_with( + resource_ids, + index, + True, # noqa: FBT003 + ) assert mocked_celery.replace.call_count == 1 assert mocked_celery.replace.call_args[0][1] == mocked_celery.chain.return_value @@ -101,7 +105,7 @@ def test_start_embed_resources_without_settings(mocker, mocked_celery, index): generate_embeddings_mock = mocker.patch( "vector_search.tasks.generate_embeddings", autospec=True ) - 
start_embed_resources.delay([index], skip_content_files=True) + start_embed_resources.delay([index], skip_content_files=True, overwrite=True) generate_embeddings_mock.si.assert_not_called() @@ -172,7 +176,9 @@ def test_embed_learning_resources_by_id(mocker, mocked_celery): content_ids.append(cf.id) with pytest.raises(mocked_celery.replace_exception_class): - embed_learning_resources_by_id.delay(resource_ids, skip_content_files=False) + embed_learning_resources_by_id.delay( + resource_ids, skip_content_files=False, overwrite=True + ) for mock_call in generate_embeddings_mock.si.mock_calls[1:]: assert mock_call.args[0][0] in content_ids assert mock_call.args[1] == "content_file" diff --git a/vector_search/utils.py b/vector_search/utils.py index 054ae4e1d4..43f43eccca 100644 --- a/vector_search/utils.py +++ b/vector_search/utils.py @@ -63,7 +63,7 @@ def points_generator( yield models.PointStruct(id=idx, payload=payload, vector=point_vector) -def create_qdrand_collections(force_recreate): +def create_qdrant_collections(force_recreate): """ Create or recreate QDrant collections @@ -174,8 +174,10 @@ def _process_resource_embeddings(serialized_resources): docs.append( f"{doc.get('title')} {doc.get('description')} {doc.get('full_description')}" ) - embeddings = encoder.embed_documents(docs) - return points_generator(ids, metadata, embeddings, vector_name) + if len(docs) > 0: + embeddings = encoder.embed_documents(docs) + return points_generator(ids, metadata, embeddings, vector_name) + return None def _chunk_documents(encoder, texts, metadatas): @@ -282,10 +284,12 @@ def _process_content_embeddings(serialized_content): except Exception as e: # noqa: BLE001 msg = f"Exceeded multi-vector max size: {e}" logger.warning(msg) - return points_generator(ids, metadata, embeddings, vector_name) + if ids: + return points_generator(ids, metadata, embeddings, vector_name) + return None -def embed_learning_resources(ids, resource_type): +def embed_learning_resources(ids, resource_type, 
overwrite): """ Embed learning resources @@ -296,20 +300,47 @@ def embed_learning_resources(ids, resource_type): client = qdrant_client() - resources_collection_name = RESOURCES_COLLECTION_NAME - content_files_collection_name = CONTENT_FILES_COLLECTION_NAME - - create_qdrand_collections(force_recreate=False) + create_qdrant_collections(force_recreate=False) if resource_type != CONTENT_FILE_TYPE: - serialized_resources = serialize_bulk_learning_resources(ids) - collection_name = resources_collection_name + serialized_resources = list(serialize_bulk_learning_resources(ids)) + points = [ + (vector_point_id(serialized["readable_id"]), serialized) + for serialized in serialized_resources + ] + if not overwrite: + filtered_point_ids = filter_existing_qdrant_points_by_ids( + [point[0] for point in points], + collection_name=RESOURCES_COLLECTION_NAME, + ) + serialized_resources = [ + point[1] for point in points if point[0] in filtered_point_ids + ] + + collection_name = RESOURCES_COLLECTION_NAME points = _process_resource_embeddings(serialized_resources) else: - serialized_resources = serialize_bulk_content_files(ids) - collection_name = content_files_collection_name + serialized_resources = list(serialize_bulk_content_files(ids)) + collection_name = CONTENT_FILES_COLLECTION_NAME + points = [ + ( + vector_point_id( + f"{doc['resource_readable_id']}.{doc['run_readable_id']}.{doc['key']}.0" + ), + doc, + ) + for doc in serialized_resources + ] + if not overwrite: + filtered_point_ids = filter_existing_qdrant_points_by_ids( + [point[0] for point in points], + collection_name=CONTENT_FILES_COLLECTION_NAME, + ) + serialized_resources = [ + point[1] for point in points if point[0] in filtered_point_ids + ] points = _process_content_embeddings(serialized_resources) - - client.upload_points(collection_name, points=points, wait=False) + if points: + client.upload_points(collection_name, points=points, wait=False) def _resource_vector_hits(search_result): @@ -395,6 +426,17 @@ def 
vector_search( } +def document_exists(document, collection_name=RESOURCES_COLLECTION_NAME): + client = qdrant_client() + count_result = client.count( + collection_name=collection_name, + count_filter=models.Filter( + must=qdrant_query_conditions(document, collection_name=collection_name) + ), + ) + return count_result.count > 0 + + def qdrant_query_conditions(params, collection_name=RESOURCES_COLLECTION_NAME): """ Generate Qdrant query conditions from query params @@ -432,6 +474,21 @@ return conditions +def filter_existing_qdrant_points_by_ids( + point_ids, collection_name=RESOURCES_COLLECTION_NAME +): + """ + Return only the point ids that don't already exist in qdrant + """ + client = qdrant_client() + response = client.retrieve( + collection_name=collection_name, + ids=point_ids, + ) + existing = [record.id for record in response] + return [point_id for point_id in point_ids if point_id not in existing] + + def filter_existing_qdrant_points( values, lookup_field="readable_id", diff --git a/vector_search/utils_test.py b/vector_search/utils_test.py index fb7e83a3ec..f2d9d6273f 100644 --- a/vector_search/utils_test.py +++ b/vector_search/utils_test.py @@ -13,7 +13,7 @@ from vector_search.encoders.utils import dense_encoder from vector_search.utils import ( _chunk_documents, - create_qdrand_collections, + create_qdrant_collections, embed_learning_resources, filter_existing_qdrant_points, qdrant_query_conditions, @@ -35,8 +35,14 @@ def test_vector_point_id_used_for_embed(mocker, content_type): "vector_search.utils.qdrant_client", return_value=mock_qdrant, ) - - embed_learning_resources([resource.id for resource in resources], content_type) + if content_type == "learning_resource": + mocker.patch( + "vector_search.utils.filter_existing_qdrant_points", + return_value=[r.readable_id for r in resources], + ) + embed_learning_resources( + [resource.id for resource in resources], content_type, overwrite=True + 
) if content_type == "learning_resource": point_ids = [vector_point_id(resource.readable_id) for resource in resources] @@ -47,12 +53,49 @@ def test_vector_point_id_used_for_embed(mocker, content_type): ) for resource in serialize_bulk_content_files([r.id for r in resources]) ] - assert sorted( [p.id for p in mock_qdrant.upload_points.mock_calls[0].kwargs["points"]] ) == sorted(point_ids) +@pytest.mark.parametrize("content_type", ["learning_resource", "content_file"]) +def test_embed_learning_resources_no_overwrite(mocker, content_type): + # test when overwrite flag is false we dont re-embed existing resources + if content_type == "learning_resource": + resources = LearningResourceFactory.create_batch(5) + else: + resources = ContentFileFactory.create_batch(5, content="test content") + mock_qdrant = mocker.patch("qdrant_client.QdrantClient") + mocker.patch( + "vector_search.utils.qdrant_client", + return_value=mock_qdrant, + ) + if content_type == "learning_resource": + # filter out 3 resources that are already embedded + mocker.patch( + "vector_search.utils.filter_existing_qdrant_points_by_ids", + return_value=[vector_point_id(r.readable_id) for r in resources[0:2]], + ) + else: + # all contentfiles exist in qdrant + mocker.patch( + "vector_search.utils.filter_existing_qdrant_points_by_ids", + return_value=[ + vector_point_id( + f"{doc['resource_readable_id']}.{doc['run_readable_id']}.{doc['key']}.0" + ) + for doc in serialize_bulk_content_files([r.id for r in resources[0:3]]) + ], + ) + embed_learning_resources( + [resource.id for resource in resources], content_type, overwrite=False + ) + if content_type == "learning_resource": + assert len(list(mock_qdrant.upload_points.mock_calls[0].kwargs["points"])) == 2 + else: + assert len(list(mock_qdrant.upload_points.mock_calls[0].kwargs["points"])) == 3 + + def test_filter_existing_qdrant_points(mocker): """ Test that filter_existing_qdrant_points filters out @@ -96,7 +139,7 @@ def 
test_filter_existing_qdrant_points(mocker): assert filtered_resources.count() == 7 -def test_force_create_qdrand_collections(mocker): +def test_force_create_qdrant_collections(mocker): """ Test that the force flag will recreate collections even if they exist @@ -107,7 +150,7 @@ def test_force_create_qdrand_collections(mocker): return_value=mock_qdrant, ) mock_qdrant.collection_exists.return_value = True - create_qdrand_collections(force_recreate=True) + create_qdrant_collections(force_recreate=True) assert ( mock_qdrant.recreate_collection.mock_calls[0].kwargs["collection_name"] == RESOURCES_COLLECTION_NAME @@ -126,7 +169,7 @@ def test_force_create_qdrand_collections(mocker): ) -def test_auto_create_qdrand_collections(mocker): +def test_auto_create_qdrant_collections(mocker): """ Test that collections will get autocreated if they don't exist @@ -137,7 +180,7 @@ def test_auto_create_qdrand_collections(mocker): return_value=mock_qdrant, ) mock_qdrant.collection_exists.return_value = False - create_qdrand_collections(force_recreate=False) + create_qdrant_collections(force_recreate=False) assert ( mock_qdrant.recreate_collection.mock_calls[0].kwargs["collection_name"] == RESOURCES_COLLECTION_NAME @@ -167,7 +210,7 @@ def test_skip_creating_qdrand_collections(mocker): return_value=mock_qdrant, ) mock_qdrant.collection_exists.return_value = False - create_qdrand_collections(force_recreate=False) + create_qdrant_collections(force_recreate=False) assert ( mock_qdrant.recreate_collection.mock_calls[0].kwargs["collection_name"] == RESOURCES_COLLECTION_NAME