diff --git a/notebooker/serialization/mongo.py b/notebooker/serialization/mongo.py index 0a181f93..17a1e2a4 100644 --- a/notebooker/serialization/mongo.py +++ b/notebooker/serialization/mongo.py @@ -10,6 +10,9 @@ from notebooker.constants import JobStatus, NotebookResultComplete, NotebookResultError, NotebookResultPending logger = getLogger(__name__) +REMOVE_ID_PROJECTION = {"_id": 0} +REMOVE_PAYLOAD_FIELDS_PROJECTION = {"raw_html_resources": 0, "raw_html": 0, "raw_ipynb_json": 0} +REMOVE_PAYLOAD_FIELDS_AND_ID_PROJECTION = dict(REMOVE_PAYLOAD_FIELDS_PROJECTION, **REMOVE_ID_PROJECTION) class MongoResultSerializer: @@ -232,9 +235,7 @@ def get_all_results( base_filter.update(mongo_filter) if since: base_filter.update({"update_time": {"$gt": since}}) - projection = ( - {"_id": 0} if load_payload else {"raw_html_resources": 0, "raw_html": 0, "raw_ipynb_json": 0, "_id": 0} - ) + projection = REMOVE_ID_PROJECTION if load_payload else REMOVE_PAYLOAD_FIELDS_AND_ID_PROJECTION results = self.library.find(base_filter, projection).sort("update_time", -1).limit(limit) for res in results: if res: @@ -247,8 +248,19 @@ def get_all_result_keys(self, limit: int = 0, mongo_filter: Optional[Dict] = Non base_filter = {"status": {"$ne": JobStatus.DELETED.value}} if mongo_filter: base_filter.update(mongo_filter) - projection = {"report_name": 1, "job_id": 1, "_id": 0} - for result in self.library.find(base_filter, projection).sort("update_time", -1).limit(limit): + results = self.library.aggregate( + [ + stage + for stage in ( + {"$match": base_filter}, + {"$sort": {"update_time": -1}}, + {"$limit": limit} if limit else {}, + {"$project": {"report_name": 1, "job_id": 1}}, + ) + if stage + ] + ) + for result in results: keys.append((result["report_name"], result["job_id"])) return keys @@ -307,6 +319,7 @@ def get_latest_successful_job_ids_for_name_all_params(self, report_name: str) -> results = self.library.aggregate( [ {"$match": mongo_filter}, + {"$project": REMOVE_PAYLOAD_FIELDS_PROJECTION}, {"$sort": {"update_time": -1}}, {"$group": {"_id": "$overrides", "job_id": {"$first": "$job_id"}}}, ] diff --git a/tests/unit/serialization/test_mongoose.py b/tests/unit/serialization/test_mongoose.py index 74bb414b..7617ba39 100644 --- a/tests/unit/serialization/test_mongoose.py +++ b/tests/unit/serialization/test_mongoose.py @@ -32,7 +32,11 @@ def test_get_latest_job_id_for_name_and_params(_get_all_job_ids, conn, gridfs): def test__get_all_job_ids(conn, gridfs): serializer = MongoResultSerializer() serializer._get_all_job_ids("report_name", None, limit=1) - serializer.library.find.assert_called_once_with( - {"status": {"$ne": JobStatus.DELETED.value}, "report_name": "report_name"}, - {"_id": 0, "job_id": 1, "report_name": 1}, + serializer.library.aggregate.assert_called_once_with( + [ + {"$match": {"status": {"$ne": JobStatus.DELETED.value}, "report_name": "report_name"}}, + {"$sort": {"update_time": -1}}, + {"$limit": 1}, + {"$project": {"report_name": 1, "job_id": 1}}, + ] )