@@ -3,7 +3,7 @@
from django.core.management.base import BaseCommand

from vector_search.utils import (
-    create_qdrand_collections,
+    create_qdrant_collections,
)


@@ -26,8 +26,8 @@ def handle(self, *args, **options):  # noqa: ARG002
"""Create Qdrant collections"""

if options["force"]:
-            create_qdrand_collections(force_recreate=True)
+            create_qdrant_collections(force_recreate=True)
else:
-            create_qdrand_collections(force_recreate=False)
+            create_qdrant_collections(force_recreate=False)

self.stdout.write("Created Qdrant collections")
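A minimal sketch of exercising the renamed helper through Django's call_command. The first file's path is cut off in this view, so the command name create_qdrant_collections and the force option dest are assumptions:

    # Sketch only: command name and option dest are assumed, not confirmed by the diff.
    from django.core.management import call_command

    # Equivalent of `python manage.py create_qdrant_collections --force`:
    # drops and recreates the Qdrant collections.
    call_command("create_qdrant_collections", force=True)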
15 changes: 12 additions & 3 deletions vector_search/management/commands/generate_embeddings.py
@@ -6,7 +6,7 @@
from main.utils import clear_search_cache, now_in_utc
from vector_search.tasks import embed_learning_resources_by_id, start_embed_resources
from vector_search.utils import (
-    create_qdrand_collections,
+    create_qdrant_collections,
)


@@ -42,6 +42,12 @@ def add_arguments(self, parser):
action="store_true",
help="Skip embedding content files",
)
+        parser.add_argument(
+            "--overwrite",
+            dest="overwrite",
+            action="store_true",
+            help="Force overwrite existing embeddings",
+        )

for object_type in sorted(LEARNING_RESOURCE_TYPES):
parser.add_argument(
@@ -71,18 +77,21 @@ def handle(self, *args, **options):  # noqa: ARG002
self.stdout.write(f" --{object_type}s")
return
if options["recreate_collections"]:
-            create_qdrand_collections(force_recreate=True)
+            create_qdrant_collections(force_recreate=True)
if options["resource-ids"]:
task = embed_learning_resources_by_id.delay(
[
int(resource_id)
for resource_id in options["resource-ids"].split(",")
],
skip_content_files=options["skip_content_files"],
+                overwrite=options["overwrite"],
)
else:
task = start_embed_resources.delay(
-                indexes_to_update, skip_content_files=options["skip_content_files"]
+                indexes_to_update,
+                skip_content_files=options["skip_content_files"],
+                overwrite=options["overwrite"],
)
self.stdout.write(
f"Started celery task {task} to index content for the following"
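Since the module is vector_search/management/commands/generate_embeddings.py, the new flag can be tried as below; the per-resource-type option name (courses) is an assumption based on the `--{object_type}s` help output above:

    # Sketch under assumptions: `courses` mirrors the printed per-type flags.
    from django.core.management import call_command

    # Re-embed all published courses, overwriting embeddings already in Qdrant.
    call_command(
        "generate_embeddings",
        courses=True,
        skip_content_files=True,
        overwrite=True,
    )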
57 changes: 17 additions & 40 deletions vector_search/tasks.py
@@ -31,10 +31,7 @@
chunks,
now_in_utc,
)
-from vector_search.constants import (
-    RESOURCES_COLLECTION_NAME,
-)
-from vector_search.utils import embed_learning_resources, filter_existing_qdrant_points
+from vector_search.utils import embed_learning_resources

log = logging.getLogger(__name__)

@@ -46,7 +43,7 @@
retry_backoff=True,
rate_limit="600/m",
)
-def generate_embeddings(ids, resource_type):
+def generate_embeddings(ids, resource_type, overwrite):
"""
Generate learning resource embeddings and index in Qdrant

@@ -57,7 +54,7 @@ def generate_embeddings(ids, resource_type):
"""
try:
with wrap_retry_exception(*SEARCH_CONN_EXCEPTIONS):
-            embed_learning_resources(ids, resource_type)
+            embed_learning_resources(ids, resource_type, overwrite)
except (RetryError, Ignore):
raise
except SystemExit as err:
@@ -69,7 +66,7 @@ def generate_embeddings(ids, resource_type):


@app.task(bind=True)
-def start_embed_resources(self, indexes, skip_content_files):
+def start_embed_resources(self, indexes, skip_content_files, overwrite):
"""
Celery task to embed all learning resources for given indexes

@@ -89,7 +86,7 @@ def start_embed_resources(self, indexes, skip_content_files):
blocklisted_ids = load_course_blocklist()

index_tasks = [
-            generate_embeddings.si(ids, COURSE_TYPE)
+            generate_embeddings.si(ids, COURSE_TYPE, overwrite)
for ids in chunks(
Course.objects.filter(learning_resource__published=True)
.exclude(learning_resource__readable_id=blocklisted_ids)
@@ -123,10 +120,7 @@ def start_embed_resources(self, indexes, skip_content_files):
)

index_tasks = index_tasks + [
-            generate_embeddings.si(
-                ids,
-                CONTENT_FILE_TYPE,
-            )
+            generate_embeddings.si(ids, CONTENT_FILE_TYPE, overwrite)
for ids in chunks(
run_contentfiles,
chunk_size=settings.QDRANT_CHUNK_SIZE,
@@ -150,10 +144,7 @@ def start_embed_resources(self, indexes, skip_content_files):
chunk_size=settings.QDRANT_CHUNK_SIZE,
):
index_tasks.append(
-                generate_embeddings.si(
-                    ids,
-                    resource_type,
-                )
+                generate_embeddings.si(ids, resource_type, overwrite)
)
except: # noqa: E722
error = "start_embed_resources threw an error"
@@ -166,7 +157,7 @@ def start_embed_resources(self, indexes, skip_content_files):


@app.task(bind=True)
-def embed_learning_resources_by_id(self, ids, skip_content_files):
+def embed_learning_resources_by_id(self, ids, skip_content_files, overwrite):
"""
Celery task to embed specific resources

@@ -190,10 +181,7 @@ def embed_learning_resources_by_id(self, ids, skip_content_files):
embed_resources = resources.filter(resource_type=resource_type)
[
index_tasks.append(
-                generate_embeddings.si(
-                    chunk_ids,
-                    resource_type,
-                )
+                generate_embeddings.si(chunk_ids, resource_type, overwrite)
)
for chunk_ids in chunks(
embed_resources.order_by("id").values_list("id", flat=True),
@@ -216,10 +204,7 @@ def embed_learning_resources_by_id(self, ids, skip_content_files):
).order_by("id")
content_ids = run_contentfiles.values_list("id", flat=True)
index_tasks = index_tasks + [
-            generate_embeddings.si(
-                ids,
-                CONTENT_FILE_TYPE,
-            )
+            generate_embeddings.si(ids, CONTENT_FILE_TYPE, overwrite)
for ids in chunks(
content_ids,
chunk_size=settings.QDRANT_CHUNK_SIZE,
@@ -249,27 +234,19 @@ def embed_new_learning_resources(self):
published=True,
created_on__gt=since,
).exclude(resource_type=CONTENT_FILE_TYPE)
-    existing_readable_ids = [
-        learning_resource.readable_id for learning_resource in new_learning_resources
-    ]
-    filtered_readable_ids = filter_existing_qdrant_points(
-        values=existing_readable_ids,
-        lookup_field="readable_id",
-        collection_name=RESOURCES_COLLECTION_NAME,
-    )
-    filtered_resources = LearningResource.objects.filter(
-        readable_id__in=filtered_readable_ids
-    )
-    resource_types = list(filtered_resources.values_list("resource_type", flat=True))
+
+    resource_types = list(
+        new_learning_resources.values_list("resource_type", flat=True)
+    )
tasks = []
for resource_type in resource_types:
tasks.extend(
[
-                generate_embeddings.si(ids, resource_type)
+                generate_embeddings.si(ids, resource_type, overwrite=False)
for ids in chunks(
-                    filtered_resources.filter(resource_type=resource_type).values_list(
-                        "id", flat=True
-                    ),
+                    new_learning_resources.filter(
+                        resource_type=resource_type
+                    ).values_list("id", flat=True),
chunk_size=settings.QDRANT_CHUNK_SIZE,
)
]
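A minimal sketch of invoking the updated tasks directly, mirroring the new signatures; the index name and resource IDs are illustrative:

    # Illustrative values; signatures match the task definitions in this diff.
    from vector_search.tasks import embed_learning_resources_by_id, start_embed_resources

    # Fan out embedding of a whole index, overwriting existing Qdrant points.
    start_embed_resources.delay(["course"], skip_content_files=True, overwrite=True)

    # Embed two specific resources, skipping any whose points already exist.
    embed_learning_resources_by_id.delay(
        [1234, 5678], skip_content_files=False, overwrite=False
    )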
14 changes: 10 additions & 4 deletions vector_search/tasks_test.py
@@ -66,9 +66,13 @@ def test_start_embed_resources(mocker, mocked_celery, index):
)

with pytest.raises(mocked_celery.replace_exception_class):
-        start_embed_resources.delay([index], skip_content_files=True)
+        start_embed_resources.delay([index], skip_content_files=True, overwrite=True)

-    generate_embeddings_mock.si.assert_called_once_with(resource_ids, index)
+    generate_embeddings_mock.si.assert_called_once_with(
+        resource_ids,
+        index,
+        True,  # noqa: FBT003
+    )
assert mocked_celery.replace.call_count == 1
assert mocked_celery.replace.call_args[0][1] == mocked_celery.chain.return_value

@@ -101,7 +105,7 @@ def test_start_embed_resources_without_settings(mocker, mocked_celery, index):
generate_embeddings_mock = mocker.patch(
"vector_search.tasks.generate_embeddings", autospec=True
)
-    start_embed_resources.delay([index], skip_content_files=True)
+    start_embed_resources.delay([index], skip_content_files=True, overwrite=True)

generate_embeddings_mock.si.assert_not_called()

@@ -172,7 +176,9 @@ def test_embed_learning_resources_by_id(mocker, mocked_celery):
content_ids.append(cf.id)

with pytest.raises(mocked_celery.replace_exception_class):
-        embed_learning_resources_by_id.delay(resource_ids, skip_content_files=False)
+        embed_learning_resources_by_id.delay(
+            resource_ids, skip_content_files=False, overwrite=True
+        )
for mock_call in generate_embeddings_mock.si.mock_calls[1:]:
assert mock_call.args[0][0] in content_ids
assert mock_call.args[1] == "content_file"
87 changes: 72 additions & 15 deletions vector_search/utils.py
@@ -63,7 +63,7 @@ def points_generator(
yield models.PointStruct(id=idx, payload=payload, vector=point_vector)


-def create_qdrand_collections(force_recreate):
+def create_qdrant_collections(force_recreate):
"""
Create or recreate Qdrant collections

@@ -174,8 +174,10 @@ def _process_resource_embeddings(serialized_resources):
docs.append(
f"{doc.get('title')} {doc.get('description')} {doc.get('full_description')}"
)
-    embeddings = encoder.embed_documents(docs)
-    return points_generator(ids, metadata, embeddings, vector_name)
+    if len(docs) > 0:
+        embeddings = encoder.embed_documents(docs)
+        return points_generator(ids, metadata, embeddings, vector_name)
+    return None


def _chunk_documents(encoder, texts, metadatas):
@@ -282,10 +284,12 @@ def _process_content_embeddings(serialized_content):
except Exception as e: # noqa: BLE001
msg = f"Exceeded multi-vector max size: {e}"
logger.warning(msg)
-    return points_generator(ids, metadata, embeddings, vector_name)
+    if ids:
+        return points_generator(ids, metadata, embeddings, vector_name)
+    return None


-def embed_learning_resources(ids, resource_type):
+def embed_learning_resources(ids, resource_type, overwrite):
"""
Embed learning resources

@@ -296,20 +300,47 @@

client = qdrant_client()

-    resources_collection_name = RESOURCES_COLLECTION_NAME
-    content_files_collection_name = CONTENT_FILES_COLLECTION_NAME

-    create_qdrand_collections(force_recreate=False)
+    create_qdrant_collections(force_recreate=False)
if resource_type != CONTENT_FILE_TYPE:
-        serialized_resources = serialize_bulk_learning_resources(ids)
-        collection_name = resources_collection_name
+        serialized_resources = list(serialize_bulk_learning_resources(ids))
+        points = [
+            (vector_point_id(serialized["readable_id"]), serialized)
+            for serialized in serialized_resources
+        ]
+        if not overwrite:
+            filtered_point_ids = filter_existing_qdrant_points_by_ids(
+                [point[0] for point in points],
+                collection_name=RESOURCES_COLLECTION_NAME,
+            )
+            serialized_resources = [
+                point[1] for point in points if point[0] in filtered_point_ids
+            ]
+
+        collection_name = RESOURCES_COLLECTION_NAME
+        points = _process_resource_embeddings(serialized_resources)
else:
-        serialized_resources = serialize_bulk_content_files(ids)
-        collection_name = content_files_collection_name
+        serialized_resources = list(serialize_bulk_content_files(ids))
+        collection_name = CONTENT_FILES_COLLECTION_NAME
+        points = [
+            (
+                vector_point_id(
+                    f"{doc['resource_readable_id']}.{doc['run_readable_id']}.{doc['key']}.0"
+                ),
+                doc,
+            )
+            for doc in serialized_resources
+        ]
+        if not overwrite:
+            filtered_point_ids = filter_existing_qdrant_points_by_ids(
+                [point[0] for point in points],
+                collection_name=CONTENT_FILES_COLLECTION_NAME,
+            )
+            serialized_resources = [
+                point[1] for point in points if point[0] in filtered_point_ids
+            ]
+        points = _process_content_embeddings(serialized_resources)

-    client.upload_points(collection_name, points=points, wait=False)
+    if points:
+        client.upload_points(collection_name, points=points, wait=False)


def _resource_vector_hits(search_result):
Expand Down Expand Up @@ -395,6 +426,17 @@ def vector_search(
}


+def document_exists(document, collection_name=RESOURCES_COLLECTION_NAME):
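+    """Check whether any point whose payload matches the given document exists in the collection"""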
+    client = qdrant_client()
+    count_result = client.count(
+        collection_name=collection_name,
+        count_filter=models.Filter(
+            must=qdrant_query_conditions(document, collection_name=collection_name)
+        ),
+    )
+    return count_result.count > 0


def qdrant_query_conditions(params, collection_name=RESOURCES_COLLECTION_NAME):
"""
Generate Qdrant query conditions from query params
@@ -432,6 +474,21 @@ def qdrant_query_conditions(params, collection_name=RESOURCES_COLLECTION_NAME):
return conditions


+def filter_existing_qdrant_points_by_ids(
+    point_ids, collection_name=RESOURCES_COLLECTION_NAME
+):
+    """
+    Return only the point IDs that don't already exist in the Qdrant collection
+    """
+    client = qdrant_client()
+    response = client.retrieve(
+        collection_name=collection_name,
+        ids=point_ids,
+    )
+    existing = [record.id for record in response]
+    return [point_id for point_id in point_ids if point_id not in existing]


def filter_existing_qdrant_points(
values,
lookup_field="readable_id",
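A short sketch of how the two new helpers compose, assuming vector_point_id is importable from vector_search.utils as used in embed_learning_resources above; the readable_id values and the payload key passed to document_exists are made up:

    # Hypothetical IDs; imports assumed to resolve as used elsewhere in utils.py.
    from vector_search.constants import RESOURCES_COLLECTION_NAME
    from vector_search.utils import (
        document_exists,
        filter_existing_qdrant_points_by_ids,
        vector_point_id,
    )

    candidates = [
        vector_point_id("course-v1:MITx+6.00x"),
        vector_point_id("course-v1:MITx+8.01x"),
    ]
    # Only IDs with no existing point come back, which is how overwrite=False
    # lets embed_learning_resources skip already-indexed resources.
    missing = filter_existing_qdrant_points_by_ids(
        candidates, collection_name=RESOURCES_COLLECTION_NAME
    )

    # Payload-level existence check; the readable_id key is an assumption here.
    exists = document_exists(
        {"readable_id": "course-v1:MITx+6.00x"},
        collection_name=RESOURCES_COLLECTION_NAME,
    )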