diff --git a/vector_search/constants.py b/vector_search/constants.py
index 2b322f049c..97445d892d 100644
--- a/vector_search/constants.py
+++ b/vector_search/constants.py
@@ -46,7 +46,7 @@
 QDRANT_LEARNING_RESOURCE_INDEXES = {
     "readable_id": models.PayloadSchemaType.KEYWORD,
     "resource_type": models.PayloadSchemaType.KEYWORD,
-    "certification": models.PayloadSchemaType.KEYWORD,
+    "certification": models.PayloadSchemaType.BOOL,
     "certification_type.code": models.PayloadSchemaType.KEYWORD,
     "professional": models.PayloadSchemaType.BOOL,
     "published": models.PayloadSchemaType.BOOL,
diff --git a/vector_search/utils.py b/vector_search/utils.py
index 4a5e2aa37f..67e485943b 100644
--- a/vector_search/utils.py
+++ b/vector_search/utils.py
@@ -155,24 +155,27 @@ def update_qdrant_indexes():
     Create or update Qdrant indexes based on mapping in constants
     """
     client = qdrant_client()
-    for index_field in QDRANT_LEARNING_RESOURCE_INDEXES:
-        collection_name = RESOURCES_COLLECTION_NAME
-        collection = client.get_collection(collection_name=collection_name)
-        if index_field not in collection.payload_schema:
-            client.create_payload_index(
-                collection_name=collection_name,
-                field_name=index_field,
-                field_schema=QDRANT_LEARNING_RESOURCE_INDEXES[index_field],
-            )
-    for index_field in QDRANT_CONTENT_FILE_INDEXES:
-        collection_name = CONTENT_FILES_COLLECTION_NAME
-        collection = client.get_collection(collection_name=collection_name)
-        if index_field not in collection.payload_schema:
-            client.create_payload_index(
-                collection_name=collection_name,
-                field_name=index_field,
-                field_schema=QDRANT_CONTENT_FILE_INDEXES[index_field],
-            )
+
+    index_specs = (
+        (QDRANT_LEARNING_RESOURCE_INDEXES, RESOURCES_COLLECTION_NAME),
+        (QDRANT_CONTENT_FILE_INDEXES, CONTENT_FILES_COLLECTION_NAME),
+    )
+    for indexes, collection_name in index_specs:
+        # The schema does not depend on the field being checked, so fetch it
+        # once per collection instead of once per index field.
+        payload_schema = client.get_collection(
+            collection_name=collection_name
+        ).payload_schema
+        for index_field, field_schema in indexes.items():
+            existing_index = payload_schema.get(index_field)
+            # (Re)create the index when it is missing or when its type no
+            # longer matches the mapping in constants.
+            if existing_index is None or existing_index.data_type != field_schema:
+                client.create_payload_index(
+                    collection_name=collection_name,
+                    field_name=index_field,
+                    field_schema=field_schema,
+                )
 
 
 def vector_point_id(readable_id):
diff --git a/vector_search/utils_test.py b/vector_search/utils_test.py
index 35f3552feb..346ccda029 100644
--- a/vector_search/utils_test.py
+++ b/vector_search/utils_test.py
@@ -23,7 +23,9 @@
 from main.utils import checksum_for_content
 from vector_search.constants import (
     CONTENT_FILES_COLLECTION_NAME,
+    QDRANT_CONTENT_FILE_INDEXES,
     QDRANT_CONTENT_FILE_PARAM_MAP,
+    QDRANT_LEARNING_RESOURCE_INDEXES,
     QDRANT_RESOURCE_PARAM_MAP,
     RESOURCES_COLLECTION_NAME,
 )
@@ -39,6 +41,7 @@
     should_generate_resource_embeddings,
     update_content_file_payload,
     update_learning_resource_payload,
+    update_qdrant_indexes,
     vector_point_id,
     vector_search,
 )
@@ -851,3 +854,64 @@ def test_embed_learning_resources_contentfile_summarization_filter(mocker):
 
     # Assert that the summarizer was called with the correct content file IDs
     assert sorted(mock_content_summarizer.mock_calls[0].args[0]) == sorted(cf_ids)
+
+
+@pytest.mark.django_db
+def test_update_qdrant_indexes_adds_missing_index(mocker):
+    """
+    Test that update_qdrant_indexes adds an index if it doesn't already exist
+    """
+    mock_client = mocker.patch("vector_search.utils.qdrant_client").return_value
+    mock_client.get_collection.return_value.payload_schema = {}
+
+    update_qdrant_indexes()
+
+    # Ensure create_payload_index is called for missing indexes
+    expected_calls = [
+        mocker.call(
+            collection_name=RESOURCES_COLLECTION_NAME,
+            field_name=index_field,
+            field_schema=QDRANT_LEARNING_RESOURCE_INDEXES[index_field],
+        )
+        for index_field in QDRANT_LEARNING_RESOURCE_INDEXES
+    ] + [
+        mocker.call(
+            collection_name=CONTENT_FILES_COLLECTION_NAME,
+            field_name=index_field,
+            field_schema=QDRANT_CONTENT_FILE_INDEXES[index_field],
+        )
+        for index_field in QDRANT_CONTENT_FILE_INDEXES
+    ]
+    mock_client.create_payload_index.assert_has_calls(expected_calls, any_order=True)
+
+
+@pytest.mark.django_db
+def test_update_qdrant_indexes_updates_mismatched_field_type(mocker):
+    """
+    Test that update_qdrant_indexes updates the index if the field types mismatch
+    """
+    mock_client = mocker.patch("vector_search.utils.qdrant_client").return_value
+    mock_client.get_collection.return_value.payload_schema = {
+        index_field: mocker.MagicMock(data_type="wrong_type")
+        for index_field in QDRANT_LEARNING_RESOURCE_INDEXES
+    }
+
+    update_qdrant_indexes()
+
+    # Ensure create_payload_index is called for mismatched field types
+    expected_calls = [
+        mocker.call(
+            collection_name=RESOURCES_COLLECTION_NAME,
+            field_name=index_field,
+            field_schema=QDRANT_LEARNING_RESOURCE_INDEXES[index_field],
+        )
+        for index_field in QDRANT_LEARNING_RESOURCE_INDEXES
+    ] + [
+        mocker.call(
+            collection_name=CONTENT_FILES_COLLECTION_NAME,
+            field_name=index_field,
+            field_schema=QDRANT_CONTENT_FILE_INDEXES[index_field],
+        )
+        for index_field in QDRANT_CONTENT_FILE_INDEXES
+    ]
+    mock_client.create_payload_index.assert_has_calls(expected_calls, any_order=True)