diff --git a/RELEASE.rst b/RELEASE.rst
index eb09661e7b..dbb96eb892 100644
--- a/RELEASE.rst
+++ b/RELEASE.rst
@@ -1,6 +1,14 @@
 Release Notes
 =============
 
+Version 0.47.8
+--------------
+
+- upgrade course-search-utils (#2656)
+- Vector based topics tagging for videos (#2649)
+- Update dependency ruff to v0.14.2 (#2632)
+- Update Node.js to v22.21.0 (#2634)
+
 Version 0.47.7 (Released October 28, 2025)
 --------------
 
diff --git a/docker-compose.apps.yml b/docker-compose.apps.yml
index 64597a3ea1..5a57cc6449 100644
--- a/docker-compose.apps.yml
+++ b/docker-compose.apps.yml
@@ -31,7 +31,7 @@ services:
     profiles:
       - frontend
     working_dir: /src
-    image: node:22.20
+    image: node:22.21
     entrypoint: ["/bin/sh", "-c"]
    command:
       - |
diff --git a/frontends/main/package.json b/frontends/main/package.json
index 92dd739496..b53d266257 100644
--- a/frontends/main/package.json
+++ b/frontends/main/package.json
@@ -13,7 +13,7 @@
     "@ebay/nice-modal-react": "^1.2.13",
     "@emotion/cache": "^11.13.1",
     "@emotion/styled": "^11.11.0",
-    "@mitodl/course-search-utils": "^3.4.1",
+    "@mitodl/course-search-utils": "^3.5.0",
     "@mitodl/mitxonline-api-axios": "^2025.10.21",
     "@mitodl/smoot-design": "^6.17.1",
     "@next/bundle-analyzer": "^14.2.15",
diff --git a/frontends/main/src/app/c/[channelType]/[name]/page.tsx b/frontends/main/src/app/c/[channelType]/[name]/page.tsx
index 560a4df50a..1ac3631877 100644
--- a/frontends/main/src/app/c/[channelType]/[name]/page.tsx
+++ b/frontends/main/src/app/c/[channelType]/[name]/page.tsx
@@ -93,8 +93,6 @@ const Page: React.FC> = async ({
   )
 
   const searchRequest = getSearchParams({
-    // @ts-expect-error Local openapi client https://www.npmjs.com/package/@mitodl/open-api-axios
-    // out of sync while we adding an enum value.
     requestParams: validateRequestParams(search),
     constantSearchParams,
     facetNames,
diff --git a/frontends/main/src/app/search/page.tsx b/frontends/main/src/app/search/page.tsx
index ea975717d6..3061a0794c 100644
--- a/frontends/main/src/app/search/page.tsx
+++ b/frontends/main/src/app/search/page.tsx
@@ -38,8 +38,6 @@ const Page: React.FC> = async ({ searchParams }) => {
   }
 
   const params = getSearchParams({
-    // @ts-expect-error Local openapi client https://www.npmjs.com/package/@mitodl/open-api-axios
-    // out of sync while we adding an enum value.
     requestParams: validateRequestParams(search),
     constantSearchParams: {},
     facetNames,
diff --git a/learning_resources/etl/loaders_test.py b/learning_resources/etl/loaders_test.py
index 91c16bdb8c..72b074754a 100644
--- a/learning_resources/etl/loaders_test.py
+++ b/learning_resources/etl/loaders_test.py
@@ -131,6 +131,14 @@ def mock_duplicates(mocker):
     )
 
 
+@pytest.fixture
+def mock_get_similar_topics_qdrant(mocker):
+    mocker.patch(
+        "learning_resources_search.plugins.get_similar_topics_qdrant",
+        return_value=["topic1", "topic2"],
+    )
+
+
 @pytest.fixture(autouse=True)
 def mock_upsert_tasks(mocker):
     """Mock out the upsert task helpers"""
@@ -1465,9 +1473,10 @@ def test_load_video(mocker, video_exists, is_published, pass_topics):
         assert getattr(result, key) == value, f"Property {key} should equal {value}"
 
 
-def test_load_videos():
+def test_load_videos(mocker, mock_get_similar_topics_qdrant):
     """Verify that load_videos loads a list of videos"""
     assert Video.objects.count() == 0
+
     video_resources = [video.learning_resource for video in VideoFactory.build_batch(5)]
     videos_data = [
         {
@@ -1486,13 +1495,14 @@
 
 
 @pytest.mark.parametrize("playlist_exists", [True, False])
-def test_load_playlist(mocker, playlist_exists):
+def test_load_playlist(mocker, playlist_exists, mock_get_similar_topics_qdrant):
     """Test load_playlist"""
     expected_topics = [{"name": "Biology"}, {"name": "Physics"}]
     [
         LearningResourceTopicFactory.create(name=topic["name"])
         for topic in expected_topics
     ]
+
     mock_most_common_topics = mocker.patch(
         "learning_resources.etl.loaders.most_common_topics",
         return_value=expected_topics,
@@ -1904,7 +1914,7 @@ def test_course_with_unpublished_force_ingest_is_test_mode():
 
 
 @pytest.mark.django_db
-def test_load_articles(mocker, climate_platform):
+def test_load_articles(mocker, climate_platform, mock_get_similar_topics_qdrant):
     articles_data = [
         {
             "title": "test",
diff --git a/learning_resources_search/api.py b/learning_resources_search/api.py
index 6ca4c3422b..b6108f04fa 100644
--- a/learning_resources_search/api.py
+++ b/learning_resources_search/api.py
@@ -35,7 +35,11 @@
     adjust_search_for_percolator,
     document_percolated_actions,
 )
-from vector_search.constants import RESOURCES_COLLECTION_NAME
+from vector_search.constants import (
+    RESOURCES_COLLECTION_NAME,
+    TOPICS_COLLECTION_NAME,
+)
+from vector_search.encoders.utils import dense_encoder
 
 log = logging.getLogger(__name__)
 
@@ -830,6 +834,51 @@ def user_subscribed_to_query(
     )
 
 
+def get_similar_topics_qdrant(
+    resource: LearningResource, value_doc: dict, num_topics: int
+) -> list[str]:
+    """
+    Get a list of similar topics based on vector similarity
+
+    Args:
+        resource (LearningResource):
+            the learning resource to find similar topics for
+        value_doc (dict):
+            a document representing the data fields we want to search with
+        num_topics (int):
+            number of topics to return
+    Returns:
+        list of str:
+            list of topic values
+    """
+    from vector_search.utils import qdrant_client, vector_point_id
+
+    encoder = dense_encoder()
+    client = qdrant_client()
+
+    # Reuse the embedding already stored in Qdrant for this resource, if any
+    response = client.retrieve(
+        collection_name=RESOURCES_COLLECTION_NAME,
+        ids=[vector_point_id(resource.readable_id)],
+        with_vectors=True,
+    )
+
+    embedding_context = "\n".join(
+        [value_doc[key] for key in value_doc if value_doc[key] is not None]
+    )
+    if response and len(response) > 0:
+        embeddings = response[0].vector.get(encoder.model_short_name())
+    else:
+        # No stored vector; embed the document fields directly
+        embeddings = encoder.embed(embedding_context)
+
+    return [
+        hit["name"]
+        for hit in _qdrant_similar_results(
+            input_query=embeddings,
+            num_resources=num_topics,
+            collection_name=TOPICS_COLLECTION_NAME,
+            score_threshold=0.2,
+        )
+    ]
+
+
 def get_similar_topics(
     value_doc: dict, num_topics: int, min_term_freq: int, min_doc_freq: int
 ) -> list[str]:
@@ -909,7 +958,12 @@ def get_similar_resources(
     )
 
 
-def _qdrant_similar_results(doc, num_resources):
+def _qdrant_similar_results(
+    input_query,
+    num_resources=6,
+    collection_name=RESOURCES_COLLECTION_NAME,
+    score_threshold=0,
+):
     """
     Get similar resources from qdrant
 
@@ -924,9 +978,7 @@
         list of serialized resources
     """
     from vector_search.utils import (
-        dense_encoder,
         qdrant_client,
-        vector_point_id,
     )
 
     encoder = dense_encoder()
@@ -934,10 +986,11 @@
     return [
         hit.payload
         for hit in client.query_points(
-            collection_name=RESOURCES_COLLECTION_NAME,
-            query=vector_point_id(doc["readable_id"]),
+            collection_name=collection_name,
+            query=input_query,
             limit=num_resources,
             using=encoder.model_short_name(),
+            score_threshold=score_threshold,
         ).points
     ]
 
@@ -956,7 +1009,12 @@
         list of str:
             list of learning resources
     """
-    hits = _qdrant_similar_results(value_doc, num_resources)
+    from vector_search.utils import vector_point_id
+
+    hits = _qdrant_similar_results(
+        input_query=vector_point_id(value_doc["readable_id"]),
+        num_resources=num_resources,
+    )
     return (
         LearningResource.objects.for_search_serialization()
         .filter(
diff --git a/learning_resources_search/api_test.py b/learning_resources_search/api_test.py
index c60ac3e0ee..25f5117ef2 100644
--- a/learning_resources_search/api_test.py
+++ b/learning_resources_search/api_test.py
@@ -1,6 +1,6 @@
 """Search API function tests"""
 
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock
 
 import pytest
 from freezegun import freeze_time
@@ -21,6 +21,7 @@
     generate_sort_clause,
     generate_suggest_clause,
     get_similar_topics,
+    get_similar_topics_qdrant,
     percolate_matches_for_document,
     relevant_indexes,
 )
@@ -3266,3 +3267,39 @@ def test_dev_mode(dev_mode):
         assert construct_search(search_params).to_dict().get("explain")
     else:
         assert construct_search(search_params).to_dict().get("explain") is None
+
+
+@pytest.mark.django_db
+def test_get_similar_topics_qdrant_uses_cached_embedding(mocker):
+    """
+    Test that get_similar_topics_qdrant uses a cached embedding when available
+    """
+    resource = MagicMock()
+    resource.readable_id = "test-resource"
+    value_doc = {"title": "Test Title", "description": "Test Description"}
+    num_topics = 3
+
+    mock_encoder = mocker.patch("learning_resources_search.api.dense_encoder")
+    encoder_instance = mock_encoder.return_value
+    encoder_instance.model_short_name.return_value = "test-model"
+    encoder_instance.embed.return_value = [0.1, 0.2, 0.3]
+
+    mock_client = mocker.patch("vector_search.utils.qdrant_client")
+    client_instance = mock_client.return_value
+
+    # Simulate a cached embedding in the response
+    client_instance.retrieve.return_value = [
+        MagicMock(vector={"test-model": [0.9, 0.8, 0.7]})
+    ]
+
+    mocker.patch(
+        "learning_resources_search.api._qdrant_similar_results",
+        return_value=[{"name": "topic1"}, {"name": "topic2"}],
+    )
+
+    result = get_similar_topics_qdrant(resource, value_doc, num_topics)
+
+    # Assert that embed was NOT called (cached embedding used)
+    encoder_instance.embed.assert_not_called()
+    # Assert that the result is as expected
+    assert result == ["topic1", "topic2"]
diff --git a/learning_resources_search/plugins.py b/learning_resources_search/plugins.py
index 8eca2f981a..f938d316e1 100644
--- a/learning_resources_search/plugins.py
+++ b/learning_resources_search/plugins.py
@@ -7,7 +7,7 @@
 from django.conf import settings as django_settings
 
 from learning_resources_search import tasks
-from learning_resources_search.api import get_similar_topics
+from learning_resources_search.api import get_similar_topics_qdrant
 from learning_resources_search.constants import (
     COURSE_TYPE,
     PERCOLATE_INDEX_TYPE,
@@ -125,11 +125,10 @@ def resource_similar_topics(self, resource) -> list[dict]:
             "full_description": resource.full_description,
         }
 
-        topic_names = get_similar_topics(
+        topic_names = get_similar_topics_qdrant(
+            resource,
             text_doc,
             settings.OPEN_VIDEO_MAX_TOPICS,
-            settings.OPEN_VIDEO_MIN_TERM_FREQ,
-            settings.OPEN_VIDEO_MIN_DOC_FREQ,
         )
 
         return [{"name": topic_name} for topic_name in topic_names]
diff --git a/learning_resources_search/plugins_test.py b/learning_resources_search/plugins_test.py
index 4bb24b8ca9..8eacaa65d0 100644
--- a/learning_resources_search/plugins_test.py
+++ b/learning_resources_search/plugins_test.py
@@ -128,19 +128,18 @@ def test_resource_similar_topics(mocker, settings):
     """The plugin function should return expected topics for a resource"""
     expected_topics = ["topic1", "topic2"]
     mock_similar_topics = mocker.patch(
-        "learning_resources_search.plugins.get_similar_topics",
+        "learning_resources_search.plugins.get_similar_topics_qdrant",
         return_value=expected_topics,
     )
     resource = LearningResourceFactory.create()
     topics = SearchIndexPlugin().resource_similar_topics(resource)
     assert topics == [{"name": topic} for topic in expected_topics]
     mock_similar_topics.assert_called_once_with(
+        resource,
         {
             "title": resource.title,
             "description": resource.description,
             "full_description": resource.full_description,
         },
         settings.OPEN_VIDEO_MAX_TOPICS,
-        settings.OPEN_VIDEO_MIN_TERM_FREQ,
-        settings.OPEN_VIDEO_MIN_DOC_FREQ,
     )
diff --git a/main/settings.py b/main/settings.py
index 8cd321eeb6..c5d47ab0cd 100644
--- a/main/settings.py
+++ b/main/settings.py
@@ -34,7 +34,7 @@
 from main.settings_pluggy import *  # noqa: F403
 from openapi.settings_spectacular import open_spectacular_settings
 
-VERSION = "0.47.7"
+VERSION = "0.47.8"
 
 log = logging.getLogger()
 
diff --git a/main/settings_course_etl.py b/main/settings_course_etl.py
index a5bf02954e..6f70c4d6c4 100644
--- a/main/settings_course_etl.py
+++ b/main/settings_course_etl.py
@@ -96,7 +96,7 @@
 # course catalog video etl settings
 OPEN_VIDEO_DATA_BRANCH = get_string("OPEN_VIDEO_DATA_BRANCH", "master")
 OPEN_VIDEO_USER_LIST_OWNER = get_string("OPEN_VIDEO_USER_LIST_OWNER", None)
-OPEN_VIDEO_MAX_TOPICS = get_int("OPEN_VIDEO_MAX_TOPICS", 3)
+OPEN_VIDEO_MAX_TOPICS = get_int("OPEN_VIDEO_MAX_TOPICS", 2)
 OPEN_VIDEO_MIN_TERM_FREQ = get_int("OPEN_VIDEO_MIN_TERM_FREQ", 1)
 OPEN_VIDEO_MIN_DOC_FREQ = get_int("OPEN_VIDEO_MIN_DOC_FREQ", 15)
 
diff --git a/poetry.lock b/poetry.lock
index 43837616dc..ebca7b93cd 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -7664,31 +7664,31 @@ files = [
 
 [[package]]
 name = "ruff"
-version = "0.14.1"
+version = "0.14.2"
 description = "An extremely fast Python linter and code formatter, written in Rust."
 optional = false
 python-versions = ">=3.7"
 groups = ["dev"]
 files = [
-    {file = "ruff-0.14.1-py3-none-linux_armv6l.whl", hash = "sha256:083bfc1f30f4a391ae09c6f4f99d83074416b471775b59288956f5bc18e82f8b"},
-    {file = "ruff-0.14.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f6fa757cd717f791009f7669fefb09121cc5f7d9bd0ef211371fad68c2b8b224"},
-    {file = "ruff-0.14.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d6191903d39ac156921398e9c86b7354d15e3c93772e7dbf26c9fcae59ceccd5"},
-    {file = "ruff-0.14.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed04f0e04f7a4587244e5c9d7df50e6b5bf2705d75059f409a6421c593a35896"},
-    {file = "ruff-0.14.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5c9e6cf6cd4acae0febbce29497accd3632fe2025c0c583c8b87e8dbdeae5f61"},
-    {file = "ruff-0.14.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6fa2458527794ecdfbe45f654e42c61f2503a230545a91af839653a0a93dbc6"},
-    {file = "ruff-0.14.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:39f1c392244e338b21d42ab29b8a6392a722c5090032eb49bb4d6defcdb34345"},
-    {file = "ruff-0.14.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7382fa12a26cce1f95070ce450946bec357727aaa428983036362579eadcc5cf"},
-    {file = "ruff-0.14.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd0bf2be3ae8521e1093a487c4aa3b455882f139787770698530d28ed3fbb37c"},
-    {file = "ruff-0.14.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabcaa9ccf8089fb4fdb78d17cc0e28241520f50f4c2e88cb6261ed083d85151"},
-    {file = "ruff-0.14.1-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:747d583400f6125ec11a4c14d1c8474bf75d8b419ad22a111a537ec1a952d192"},
-    {file = "ruff-0.14.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5a6e74c0efd78515a1d13acbfe6c90f0f5bd822aa56b4a6d43a9ffb2ae6e56cd"},
-    {file = "ruff-0.14.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0ea6a864d2fb41a4b6d5b456ed164302a0d96f4daac630aeba829abfb059d020"},
-    {file = "ruff-0.14.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:0826b8764f94229604fa255918d1cc45e583e38c21c203248b0bfc9a0e930be5"},
-    {file = "ruff-0.14.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cbc52160465913a1a3f424c81c62ac8096b6a491468e7d872cb9444a860bc33d"},
-    {file = "ruff-0.14.1-py3-none-win32.whl", hash = "sha256:e037ea374aaaff4103240ae79168c0945ae3d5ae8db190603de3b4012bd1def6"},
-    {file = "ruff-0.14.1-py3-none-win_amd64.whl", hash = "sha256:59d599cdff9c7f925a017f6f2c256c908b094e55967f93f2821b1439928746a1"},
-    {file = "ruff-0.14.1-py3-none-win_arm64.whl", hash = "sha256:e3b443c4c9f16ae850906b8d0a707b2a4c16f8d2f0a7fe65c475c5886665ce44"},
-    {file = "ruff-0.14.1.tar.gz", hash = "sha256:1dd86253060c4772867c61791588627320abcb6ed1577a90ef432ee319729b69"},
+    {file = "ruff-0.14.2-py3-none-linux_armv6l.whl", hash = "sha256:7cbe4e593505bdec5884c2d0a4d791a90301bc23e49a6b1eb642dd85ef9c64f1"},
+    {file = "ruff-0.14.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8d54b561729cee92f8d89c316ad7a3f9705533f5903b042399b6ae0ddfc62e11"},
+    {file = "ruff-0.14.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5c8753dfa44ebb2cde10ce5b4d2ef55a41fb9d9b16732a2c5df64620dbda44a3"},
+    {file = "ruff-0.14.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d0bbeffb8d9f4fccf7b5198d566d0bad99a9cb622f1fc3467af96cb8773c9e3"},
+    {file = "ruff-0.14.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7047f0c5a713a401e43a88d36843d9c83a19c584e63d664474675620aaa634a8"},
+    {file = "ruff-0.14.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bf8d2f9aa1602599217d82e8e0af7fd33e5878c4d98f37906b7c93f46f9a839"},
+    {file = "ruff-0.14.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1c505b389e19c57a317cf4b42db824e2fca96ffb3d86766c1c9f8b96d32048a7"},
+    {file = "ruff-0.14.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a307fc45ebd887b3f26b36d9326bb70bf69b01561950cdcc6c0bdf7bb8e0f7cc"},
+    {file = "ruff-0.14.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:61ae91a32c853172f832c2f40bd05fd69f491db7289fb85a9b941ebdd549781a"},
+    {file = "ruff-0.14.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1967e40286f63ee23c615e8e7e98098dedc7301568bd88991f6e544d8ae096"},
+    {file = "ruff-0.14.2-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:2877f02119cdebf52a632d743a2e302dea422bfae152ebe2f193d3285a3a65df"},
+    {file = "ruff-0.14.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e681c5bc777de5af898decdcb6ba3321d0d466f4cb43c3e7cc2c3b4e7b843a05"},
+    {file = "ruff-0.14.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e21be42d72e224736f0c992cdb9959a2fa53c7e943b97ef5d081e13170e3ffc5"},
+    {file = "ruff-0.14.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b8264016f6f209fac16262882dbebf3f8be1629777cf0f37e7aff071b3e9b92e"},
+    {file = "ruff-0.14.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5ca36b4cb4db3067a3b24444463ceea5565ea78b95fe9a07ca7cb7fd16948770"},
+    {file = "ruff-0.14.2-py3-none-win32.whl", hash = "sha256:41775927d287685e08f48d8eb3f765625ab0b7042cc9377e20e64f4eb0056ee9"},
+    {file = "ruff-0.14.2-py3-none-win_amd64.whl", hash = "sha256:0df3424aa5c3c08b34ed8ce099df1021e3adaca6e90229273496b839e5a7e1af"},
+    {file = "ruff-0.14.2-py3-none-win_arm64.whl", hash = "sha256:ea9d635e83ba21569fbacda7e78afbfeb94911c9434aff06192d9bc23fd5495a"},
+    {file = "ruff-0.14.2.tar.gz", hash = "sha256:98da787668f239313d9c902ca7c523fe11b8ec3f39345553a51b25abc4629c96"},
 ]
 
 [[package]]
@@ -9255,4 +9255,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.1"
 python-versions = "~3.12"
-content-hash = "3375bf2da141502912a53f6ded4b5673b3de08e29ad16534e795f3b9dbb8e20d"
+content-hash = "ceab94db56105439c94cf4b8895b8c50d21deb17b0d59445e7cc813782f6a129"
diff --git a/pyproject.toml b/pyproject.toml
index 362fba8ab7..80cb94e57c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -134,7 +134,7 @@ pytest-env = "^1.0.0"
 pytest-freezegun = "^0.4.2"
 pytest-mock = "^3.10.0"
 responses = "^0.25.0"
-ruff = "0.14.1"
+ruff = "0.14.2"
 safety = "^3.0.0"
 semantic-version = "^2.10.0"
 freezegun = "^1.4.0"
diff --git a/vector_search/constants.py b/vector_search/constants.py
index 97445d892d..d057c2ea98 100644
--- a/vector_search/constants.py
+++ b/vector_search/constants.py
@@ -3,6 +3,7 @@
 
 RESOURCES_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.resources"
 CONTENT_FILES_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.content_files"
+TOPICS_COLLECTION_NAME = f"{settings.QDRANT_BASE_COLLECTION_NAME}.topics"
 
 QDRANT_CONTENT_FILE_PARAM_MAP = {
     "key": "key",
@@ -43,6 +44,10 @@
 }
 
 
+QDRANT_TOPICS_PARAM_MAP = {
+    "name": "name",
+}
+
 QDRANT_LEARNING_RESOURCE_INDEXES = {
     "readable_id": models.PayloadSchemaType.KEYWORD,
     "resource_type": models.PayloadSchemaType.KEYWORD,
@@ -82,3 +87,7 @@
     "edx_block_id": models.PayloadSchemaType.KEYWORD,
     "url": models.PayloadSchemaType.KEYWORD,
 }
+
+QDRANT_TOPIC_INDEXES = {
+    "name": models.PayloadSchemaType.KEYWORD,
+}
diff --git a/vector_search/management/commands/sync_topic_embeddings.py b/vector_search/management/commands/sync_topic_embeddings.py
new file mode 100644
index 0000000000..8a5d1a8487
--- /dev/null
+++ b/vector_search/management/commands/sync_topic_embeddings.py
@@ -0,0 +1,27 @@
+"""Management command to update or create the topics collection in Qdrant"""
+
+from django.core.management.base import BaseCommand, CommandError
+
+from main.utils import clear_search_cache, now_in_utc
+from vector_search.tasks import sync_topics
+
+
+class Command(BaseCommand):
+    """Syncs embeddings for topics in Qdrant"""
+
+    help = "update or create the topics collection in Qdrant"
+
+    def handle(self, *args, **options):  # noqa: ARG002
+        """Sync the topics collection"""
+        task = sync_topics.apply()
+        self.stdout.write("Waiting on task...")
+        start = now_in_utc()
+        error = task.get()
+        if error:
+            msg = f"Generate embeddings errored: {error}"
+            raise CommandError(msg)
+        clear_search_cache()
+        total_seconds = (now_in_utc() - start).total_seconds()
+        self.stdout.write(
+            f"Embeddings generated and stored, took {total_seconds} seconds"
+        )
diff --git a/vector_search/tasks.py b/vector_search/tasks.py
index ce7ab862ca..3e497a3d9e 100644
--- a/vector_search/tasks.py
+++ b/vector_search/tasks.py
@@ -32,7 +32,11 @@
     chunks,
     now_in_utc,
 )
-from vector_search.utils import embed_learning_resources, remove_qdrant_records
+from vector_search.utils import (
+    embed_learning_resources,
+    embed_topics,
+    remove_qdrant_records,
+)
 
 log = logging.getLogger(__name__)
 
@@ -362,3 +366,11 @@ def remove_run_content_files(run_id):
             for ids in chunks(content_file_ids, chunk_size=settings.QDRANT_CHUNK_SIZE)
         ]
     )
+
+
+@app.task
+def sync_topics():
+    """
+    Sync topics to the Qdrant collection
+    """
+    embed_topics()
diff --git a/vector_search/utils.py b/vector_search/utils.py
index ab9160b21b..1f9d55a735 100644
--- a/vector_search/utils.py
+++ b/vector_search/utils.py
@@ -2,12 +2,17 @@
 import uuid
 
 from django.conf import settings
+from django.db.models import Q
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_experimental.text_splitter import SemanticChunker
 from qdrant_client import QdrantClient, models
 
 from learning_resources.content_summarizer import ContentSummarizer
-from learning_resources.models import ContentFile, LearningResource
+from learning_resources.models import (
+    ContentFile,
+    LearningResource,
+    LearningResourceTopic,
+)
 from learning_resources.serializers import (
     ContentFileSerializer,
     LearningResourceMetadataDisplaySerializer,
@@ -28,7 +33,10 @@
     QDRANT_CONTENT_FILE_PARAM_MAP,
     QDRANT_LEARNING_RESOURCE_INDEXES,
     QDRANT_RESOURCE_PARAM_MAP,
+    QDRANT_TOPIC_INDEXES,
+    QDRANT_TOPICS_PARAM_MAP,
     RESOURCES_COLLECTION_NAME,
+    TOPICS_COLLECTION_NAME,
 )
 from vector_search.encoders.utils import dense_encoder
 
@@ -83,48 +91,29 @@ def create_qdrant_collections(force_recreate):
         force_recreate (bool): Whether to recreate the collections
             even if they already exist
     """
+
+    collections = [
+        RESOURCES_COLLECTION_NAME,
+        CONTENT_FILES_COLLECTION_NAME,
+        TOPICS_COLLECTION_NAME,
+    ]
+    for collection_name in collections:
+        create_qdrant_collection(collection_name, force_recreate)
+
+    update_qdrant_indexes()
+
+
+def create_qdrant_collection(collection_name, force_recreate):
+    """
+    Create or recreate a Qdrant collection
+    """
     client = qdrant_client()
-    resources_collection_name = RESOURCES_COLLECTION_NAME
-    content_files_collection_name = CONTENT_FILES_COLLECTION_NAME
     encoder = dense_encoder()
     # True if either of the collections were recreated
-
-    if (
-        not client.collection_exists(collection_name=resources_collection_name)
-        or force_recreate
-    ):
-        client.delete_collection(resources_collection_name)
+    if not client.collection_exists(collection_name=collection_name) or force_recreate:
+        client.delete_collection(collection_name)
         client.recreate_collection(
-            collection_name=resources_collection_name,
-            on_disk_payload=True,
-            vectors_config={
-                encoder.model_short_name(): models.VectorParams(
-                    size=encoder.dim(), distance=models.Distance.COSINE
-                ),
-            },
-            replication_factor=2,
-            shard_number=6,
-            strict_mode_config=models.StrictModeConfig(
-                enabled=True,
-                unindexed_filtering_retrieve=False,
-                unindexed_filtering_update=False,
-            ),
-            sparse_vectors_config=client.get_fastembed_sparse_vector_params(),
-            optimizers_config=models.OptimizersConfigDiff(default_segment_number=2),
-            quantization_config=models.BinaryQuantization(
-                binary=models.BinaryQuantizationConfig(
-                    always_ram=True,
-                ),
-            ),
-        )
-
-    if (
-        not client.collection_exists(collection_name=content_files_collection_name)
-        or force_recreate
-    ):
-        client.delete_collection(content_files_collection_name)
-        client.recreate_collection(
-            collection_name=content_files_collection_name,
+            collection_name=collection_name,
             on_disk_payload=True,
             vectors_config={
                 encoder.model_short_name(): models.VectorParams(
@@ -146,7 +135,6 @@ def create_qdrant_collections(force_recreate):
                 ),
             ),
         )
-    update_qdrant_indexes()
 
 
 def update_qdrant_indexes():
@@ -158,6 +146,7 @@
     for index in [
         (QDRANT_LEARNING_RESOURCE_INDEXES, RESOURCES_COLLECTION_NAME),
         (QDRANT_CONTENT_FILE_INDEXES, CONTENT_FILES_COLLECTION_NAME),
+        (QDRANT_TOPIC_INDEXES, TOPICS_COLLECTION_NAME),
     ]:
         indexes = index[0]
         collection_name = index[1]
@@ -188,6 +177,60 @@
     return str(uuid.uuid5(uuid.NAMESPACE_DNS, readable_id))
 
 
+def embed_topics():
+    """
+    Embed and store new (sub)topics and remove non-existent ones from Qdrant
+    """
+    client = qdrant_client()
+    create_qdrant_collections(force_recreate=False)
+    indexed_count = client.count(collection_name=TOPICS_COLLECTION_NAME).count
+
+    # Names of all topics in the database, both top-level topics and subtopics
+    topic_names = set(
+        LearningResourceTopic.objects.filter(
+            Q(parent=None) | Q(parent__isnull=False)
+        ).values_list("name", flat=True)
+    )
+
+    if indexed_count > 0:
+        existing = vector_search(
+            query_string="",
+            params={},
+            search_collection=TOPICS_COLLECTION_NAME,
+            limit=indexed_count,
+        )
+        indexed_topic_names = {hit["name"] for hit in existing["hits"]}
+    else:
+        indexed_topic_names = set()
+
+    new_topics = topic_names - indexed_topic_names
+    remove_topics = indexed_topic_names - topic_names
+    for remove_topic in remove_topics:
+        remove_points_matching_params(
+            {"name": remove_topic}, collection_name=TOPICS_COLLECTION_NAME
+        )
+
+    docs = []
+    metadata = []
+    ids = []
+
+    filtered_topics = LearningResourceTopic.objects.filter(name__in=new_topics)
+
+    for topic in filtered_topics:
+        docs.append(topic.name)
+        metadata.append(
+            {
+                "name": topic.name,
+            }
+        )
+        ids.append(str(topic.topic_uuid))
+    if len(docs) > 0:
+        encoder = dense_encoder()
+        embeddings = encoder.embed_documents(docs)
+        vector_name = encoder.model_short_name()
+        points = points_generator(ids, metadata, embeddings, vector_name)
+        client.upload_points(TOPICS_COLLECTION_NAME, points=points, wait=False)
+
+
 def _chunk_documents(encoder, texts, metadatas):
     # chunk the documents. use semantic chunking if enabled
     chunk_params = {
@@ -757,6 +800,8 @@ def qdrant_query_conditions(params, collection_name=RESOURCES_COLLECTION_NAME):
     conditions = []
     if collection_name == RESOURCES_COLLECTION_NAME:
         QDRANT_PARAM_MAP = QDRANT_RESOURCE_PARAM_MAP
+    elif collection_name == TOPICS_COLLECTION_NAME:
+        QDRANT_PARAM_MAP = QDRANT_TOPICS_PARAM_MAP
     else:
         QDRANT_PARAM_MAP = QDRANT_CONTENT_FILE_PARAM_MAP
     if not params:
diff --git a/vector_search/utils_test.py b/vector_search/utils_test.py
index 346ccda029..c4e48cc17b 100644
--- a/vector_search/utils_test.py
+++ b/vector_search/utils_test.py
@@ -1,4 +1,5 @@
 from decimal import Decimal
+from unittest.mock import MagicMock
 
 import pytest
 from django.conf import settings
@@ -10,6 +11,7 @@
     LearningResourceFactory,
     LearningResourcePriceFactory,
     LearningResourceRunFactory,
+    LearningResourceTopicFactory,
 )
 from learning_resources.models import LearningResource
 from learning_resources.serializers import LearningResourceMetadataDisplaySerializer
@@ -35,6 +37,7 @@
     _embed_course_metadata_as_contentfile,
     create_qdrant_collections,
     embed_learning_resources,
+    embed_topics,
     filter_existing_qdrant_points,
     qdrant_query_conditions,
     should_generate_content_embeddings,
@@ -915,3 +918,64 @@ def test_update_qdrant_indexes_updates_mismatched_field_type(mocker):
         for index_field in QDRANT_CONTENT_FILE_INDEXES
     ]
     mock_client.create_payload_index.assert_has_calls(expected_calls, any_order=True)
+
+
+def test_embed_topics_no_new_topics(mocker):
+    """
+    Test embed_topics when there are no new topics to embed
+    """
+    mock_client = MagicMock()
+    mock_qdrant_client = mocker.patch("vector_search.utils.qdrant_client")
+    mock_qdrant_client.return_value = mock_client
+    mock_client.count.return_value.count = 1
+    mock_vector_search = mocker.patch("vector_search.utils.vector_search")
+    mock_vector_search.return_value = {"hits": [{"name": "topic1"}]}
+    LearningResourceTopicFactory.create(name="topic1", parent=None)
+    mock_remove_points_matching_params = mocker.patch(
+        "vector_search.utils.remove_points_matching_params"
+    )
+    embed_topics()
+    mock_remove_points_matching_params.assert_not_called()
+    mock_client.upload_points.assert_not_called()
+
+
+def test_embed_topics_new_topics(mocker):
+    """
+    Test embed_topics when there are new topics
+    """
+    mock_client = MagicMock()
+    mock_qdrant_client = mocker.patch("vector_search.utils.qdrant_client")
+    mock_qdrant_client.return_value = mock_client
+    mock_client.count.return_value.count = 1
+    mock_vector_search = mocker.patch("vector_search.utils.vector_search")
+    mock_vector_search.return_value = {"hits": [{"name": "topic1"}]}
+    LearningResourceTopicFactory.create(name="topic1", parent=None)
+    LearningResourceTopicFactory.create(name="topic2", parent=None)
+    LearningResourceTopicFactory.create(name="topic3", parent=None)
+    mocker.patch("vector_search.utils.remove_points_matching_params")
+    embed_topics()
+    mock_client.upload_points.assert_called_once()
+    assert len(list(mock_client.upload_points.mock_calls[0][2]["points"])) == 2
+
+
+def test_embed_topics_remove_topics(mocker):
+    """
+    Test embed_topics when there are topics to remove
+    """
+    mock_client = MagicMock()
+    mock_qdrant_client = mocker.patch("vector_search.utils.qdrant_client")
+    mock_qdrant_client.return_value = mock_client
+    mock_client.count.return_value.count = 1
+    mock_vector_search = mocker.patch("vector_search.utils.vector_search")
+    mock_vector_search.return_value = {"hits": [{"name": "remove-topic"}]}
+
+    LearningResourceTopicFactory.create(name="topic2", parent=None)
+    LearningResourceTopicFactory.create(name="topic3", parent=None)
+    mock_remove_points_matching_params = mocker.patch(
+        "vector_search.utils.remove_points_matching_params"
+    )
+    embed_topics()
+    mock_remove_points_matching_params.assert_called_once()
+    assert (
+        mock_remove_points_matching_params.mock_calls[0][1][0]["name"] == "remove-topic"
+    )
diff --git a/yarn.lock b/yarn.lock
index e04f130b51..482e84ed98 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -3169,11 +3169,11 @@ __metadata:
   languageName: node
   linkType: hard
 
-"@mitodl/course-search-utils@npm:^3.4.1":
-  version: 3.4.1
-  resolution: "@mitodl/course-search-utils@npm:3.4.1"
+"@mitodl/course-search-utils@npm:^3.5.0":
+  version: 3.5.0
+  resolution: "@mitodl/course-search-utils@npm:3.5.0"
   dependencies:
-    "@mitodl/open-api-axios": "npm:2024.9.16"
+    "@mitodl/mit-learn-api-axios": "npm:2025.10.28"
     "@remixicon/react": "npm:^4.2.0"
     axios: "npm:^1.6.7"
     fuse.js: "npm:^7.0.0"
@@ -3191,27 +3191,27 @@
     optional: true
   react-router:
     optional: true
-  checksum: 10/cccb99931ef96b25d788f26d9bb6b12f61e735a550127cb2eb0dd14fb8210f303c36c50d7a8af2ccb546aa82bb3051c5242fa55fb494fc17c98e1147b6c99063
+  checksum: 10/f7b31dc385bc9220b54ec580f21c9a7c90c5a4e69ee86617642f8934bdd79bfd1f2bac4dfdbeadd7edf71c768e67fd5b3af0cd526fdaaacbcfb13eec9a4f7afb
   languageName: node
   linkType: hard
 
-"@mitodl/mitxonline-api-axios@npm:^2025.10.21":
-  version: 2025.10.21
-  resolution: "@mitodl/mitxonline-api-axios@npm:2025.10.21"
+"@mitodl/mit-learn-api-axios@npm:2025.10.28":
+  version: 2025.10.28
+  resolution: "@mitodl/mit-learn-api-axios@npm:2025.10.28"
   dependencies:
-    "@types/node": "npm:^20.11.19"
+    "@types/node": "npm:^22.0.0"
     axios: "npm:^1.6.5"
-  checksum: 10/2285ecfb20946dabd8efcce5a0614c5370c2d7759e01417d799acab25d0e5964bdc7a5789989e075e23e180cc039ccfca4a0c03ee97904b7e2e1fafed193f448
+  checksum: 10/0832d24ab89078c315dc01ce40a3350752dee76a1a26649807c846858365e13834dc2ba76c8f4324667f5982690675445ad7c1b44e9d847bccfd7290b29df9b6
   languageName: node
   linkType: hard
 
-"@mitodl/open-api-axios@npm:2024.9.16":
-  version: 2024.9.16
-  resolution: "@mitodl/open-api-axios@npm:2024.9.16"
+"@mitodl/mitxonline-api-axios@npm:^2025.10.21":
+  version: 2025.10.21
+  resolution: "@mitodl/mitxonline-api-axios@npm:2025.10.21"
   dependencies:
     "@types/node": "npm:^20.11.19"
     axios: "npm:^1.6.5"
-  checksum: 10/ff40f62d087e3168e1d9ba0887cc41d3db4d2eb868e43317feb18510cf1f7d9b535d32acd478b9dd1bacfee8ea7e4967d2ed84499f5bf431317c614555f19416
+  checksum: 10/2285ecfb20946dabd8efcce5a0614c5370c2d7759e01417d799acab25d0e5964bdc7a5789989e075e23e180cc039ccfca4a0c03ee97904b7e2e1fafed193f448
   languageName: node
   linkType: hard
 
@@ -14058,7 +14058,7 @@ __metadata:
     "@emotion/cache": "npm:^11.13.1"
     "@emotion/styled": "npm:^11.11.0"
     "@faker-js/faker": "npm:^10.0.0"
-    "@mitodl/course-search-utils": "npm:^3.4.1"
+    "@mitodl/course-search-utils": "npm:^3.5.0"
     "@mitodl/mitxonline-api-axios": "npm:^2025.10.21"
     "@mitodl/smoot-design": "npm:^6.17.1"
     "@next/bundle-analyzer": "npm:^14.2.15"