Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vector_search/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
QDRANT_LEARNING_RESOURCE_INDEXES = {
"readable_id": models.PayloadSchemaType.KEYWORD,
"resource_type": models.PayloadSchemaType.KEYWORD,
"certification": models.PayloadSchemaType.KEYWORD,
"certification": models.PayloadSchemaType.BOOL,
"certification_type.code": models.PayloadSchemaType.KEYWORD,
"professional": models.PayloadSchemaType.BOOL,
"published": models.PayloadSchemaType.BOOL,
Expand Down
37 changes: 19 additions & 18 deletions vector_search/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,24 +155,25 @@ def update_qdrant_indexes():
Create or update Qdrant indexes based on mapping in constants
"""
client = qdrant_client()
for index_field in QDRANT_LEARNING_RESOURCE_INDEXES:
collection_name = RESOURCES_COLLECTION_NAME
collection = client.get_collection(collection_name=collection_name)
if index_field not in collection.payload_schema:
client.create_payload_index(
collection_name=collection_name,
field_name=index_field,
field_schema=QDRANT_LEARNING_RESOURCE_INDEXES[index_field],
)
for index_field in QDRANT_CONTENT_FILE_INDEXES:
collection_name = CONTENT_FILES_COLLECTION_NAME
collection = client.get_collection(collection_name=collection_name)
if index_field not in collection.payload_schema:
client.create_payload_index(
collection_name=collection_name,
field_name=index_field,
field_schema=QDRANT_CONTENT_FILE_INDEXES[index_field],
)

for index in [
(QDRANT_LEARNING_RESOURCE_INDEXES, RESOURCES_COLLECTION_NAME),
(QDRANT_CONTENT_FILE_INDEXES, CONTENT_FILES_COLLECTION_NAME),
]:
indexes = index[0]
collection_name = index[1]
for index_field in indexes:
collection = client.get_collection(collection_name=collection_name)
if (
index_field not in collection.payload_schema
or indexes[index_field]
!= collection.payload_schema[index_field].dict()["data_type"]
):
client.create_payload_index(
collection_name=collection_name,
field_name=index_field,
field_schema=indexes[index_field],
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed there's no unit tests for this function, would be good to add some. Could be this PR if there's no rush to get this out, otherwise in a subsequent PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good call. added some tests



def vector_point_id(readable_id):
Expand Down
64 changes: 64 additions & 0 deletions vector_search/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from main.utils import checksum_for_content
from vector_search.constants import (
CONTENT_FILES_COLLECTION_NAME,
QDRANT_CONTENT_FILE_INDEXES,
QDRANT_CONTENT_FILE_PARAM_MAP,
QDRANT_LEARNING_RESOURCE_INDEXES,
QDRANT_RESOURCE_PARAM_MAP,
RESOURCES_COLLECTION_NAME,
)
Expand All @@ -39,6 +41,7 @@
should_generate_resource_embeddings,
update_content_file_payload,
update_learning_resource_payload,
update_qdrant_indexes,
vector_point_id,
vector_search,
)
Expand Down Expand Up @@ -851,3 +854,64 @@ def test_embed_learning_resources_contentfile_summarization_filter(mocker):

# Assert that the summarizer was called with the correct content file IDs
assert sorted(mock_content_summarizer.mock_calls[0].args[0]) == sorted(cf_ids)


@pytest.mark.django_db
def test_update_qdrant_indexes_adds_missing_index(mocker):
"""
Test that update_qdrant_indexes adds an index if it doesn't already exist
"""
mock_client = mocker.patch("vector_search.utils.qdrant_client").return_value
mock_client.get_collection.return_value.payload_schema = {}

update_qdrant_indexes()

# Ensure create_payload_index is called for missing indexes
expected_calls = [
mocker.call(
collection_name=RESOURCES_COLLECTION_NAME,
field_name=index_field,
field_schema=QDRANT_LEARNING_RESOURCE_INDEXES[index_field],
)
for index_field in QDRANT_LEARNING_RESOURCE_INDEXES
] + [
mocker.call(
collection_name=CONTENT_FILES_COLLECTION_NAME,
field_name=index_field,
field_schema=QDRANT_CONTENT_FILE_INDEXES[index_field],
)
for index_field in QDRANT_CONTENT_FILE_INDEXES
]
mock_client.create_payload_index.assert_has_calls(expected_calls, any_order=True)


@pytest.mark.django_db
def test_update_qdrant_indexes_updates_mismatched_field_type(mocker):
"""
Test that update_qdrant_indexes updates the index if the field types mismatch
"""
mock_client = mocker.patch("vector_search.utils.qdrant_client").return_value
mock_client.get_collection.return_value.payload_schema = {
index_field: mocker.MagicMock(data_type="wrong_type")
for index_field in QDRANT_LEARNING_RESOURCE_INDEXES
}

update_qdrant_indexes()

# Ensure create_payload_index is called for mismatched field types
expected_calls = [
mocker.call(
collection_name=RESOURCES_COLLECTION_NAME,
field_name=index_field,
field_schema=QDRANT_LEARNING_RESOURCE_INDEXES[index_field],
)
for index_field in QDRANT_LEARNING_RESOURCE_INDEXES
] + [
mocker.call(
collection_name=CONTENT_FILES_COLLECTION_NAME,
field_name=index_field,
field_schema=QDRANT_CONTENT_FILE_INDEXES[index_field],
)
for index_field in QDRANT_CONTENT_FILE_INDEXES
]
mock_client.create_payload_index.assert_has_calls(expected_calls, any_order=True)
Loading