notebooker/serialization/mongo.py (23 changes: 18 additions & 5 deletions)
@@ -10,6 +10,9 @@
 from notebooker.constants import JobStatus, NotebookResultComplete, NotebookResultError, NotebookResultPending

 logger = getLogger(__name__)
+REMOVE_ID_PROJECTION = {"_id": 0}
+REMOVE_PAYLOAD_FIELDS_PROJECTION = {"raw_html_resources": 0, "raw_html": 0, "raw_ipynb_json": 0}
+REMOVE_PAYLOAD_FIELDS_AND_ID_PROJECTION = dict(REMOVE_PAYLOAD_FIELDS_PROJECTION, **REMOVE_ID_PROJECTION)


 class MongoResultSerializer:
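Aside (not part of the diff): dict(a, **b) returns a merged copy of both mappings, so the new combined constant is exactly the projection that get_all_results previously built inline. A minimal sketch in plain Python, using only names from the hunk above:

REMOVE_ID_PROJECTION = {"_id": 0}
REMOVE_PAYLOAD_FIELDS_PROJECTION = {"raw_html_resources": 0, "raw_html": 0, "raw_ipynb_json": 0}

# dict(a, **b) copies `a` and overlays `b`, giving the same exclusion
# projection that was previously written out by hand.
merged = dict(REMOVE_PAYLOAD_FIELDS_PROJECTION, **REMOVE_ID_PROJECTION)
assert merged == {"raw_html_resources": 0, "raw_html": 0, "raw_ipynb_json": 0, "_id": 0}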
@@ -232,9 +235,7 @@ def get_all_results(
             base_filter.update(mongo_filter)
         if since:
             base_filter.update({"update_time": {"$gt": since}})
-        projection = (
-            {"_id": 0} if load_payload else {"raw_html_resources": 0, "raw_html": 0, "raw_ipynb_json": 0, "_id": 0}
-        )
+        projection = REMOVE_ID_PROJECTION if load_payload else REMOVE_PAYLOAD_FIELDS_AND_ID_PROJECTION
         results = self.library.find(base_filter, projection).sort("update_time", -1).limit(limit)
         for res in results:
             if res:
@@ -247,8 +248,19 @@ def get_all_result_keys(self, limit: int = 0, mongo_filter: Optional[Dict] = None
base_filter = {"status": {"$ne": JobStatus.DELETED.value}}
if mongo_filter:
base_filter.update(mongo_filter)
projection = {"report_name": 1, "job_id": 1, "_id": 0}
for result in self.library.find(base_filter, projection).sort("update_time", -1).limit(limit):
results = self.library.aggregate(
[
stage
for stage in (
{"$match": base_filter},
{"$sort": {"update_time": -1}},
{"$limit": limit} if limit else {},
{"$project": {"report_name": 1, "job_id": 1}},
)
if stage
]
)
for result in results:
keys.append((result["report_name"], result["job_id"]))
return keys
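The list comprehension above exists because MongoDB rejects a {"$limit": 0} stage (the value must be positive), whereas the old cursor.limit(0) meant "no limit"; building the stages with an empty-dict placeholder and filtering it out preserves the "unlimited" behaviour. A minimal sketch of the same pattern outside the class (the helper name is illustrative, not from the PR):

def build_key_pipeline(base_filter, limit=0):
    # Only emit $limit when a positive limit was requested; a zero limit
    # would make the aggregation fail, so the {} placeholder is dropped.
    stages = (
        {"$match": base_filter},
        {"$sort": {"update_time": -1}},
        {"$limit": limit} if limit else {},
        {"$project": {"report_name": 1, "job_id": 1}},
    )
    return [stage for stage in stages if stage]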

@@ -307,6 +319,7 @@ def get_latest_successful_job_ids_for_name_all_params(self, report_name: str) ->
         results = self.library.aggregate(
             [
                 {"$match": mongo_filter},
+                {"$project": REMOVE_PAYLOAD_FIELDS_PROJECTION},
                 {"$sort": {"update_time": -1}},
                 {"$group": {"_id": "$overrides", "job_id": {"$first": "$job_id"}}},
             ]
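The new $project stage excludes raw_html, raw_html_resources and raw_ipynb_json before $sort and $group run, so only small metadata documents flow through the rest of the pipeline; presumably the point is to stay within the aggregation sort stage's in-memory budget (roughly 100 MB unless allowDiskUse is enabled), which rendered-notebook payloads can easily exhaust. A rough standalone sketch in pymongo terms (connection, collection name and match filter are illustrative, not from the PR):

from pymongo import MongoClient

db = MongoClient()["notebooker"]  # illustrative handle; the serializer wires up its own
results = db.notebook_results.aggregate(
    [
        {"$match": {"report_name": "my_report"}},
        # Drop the heavy payload fields early so $sort/$group shuffle small documents.
        {"$project": {"raw_html_resources": 0, "raw_html": 0, "raw_ipynb_json": 0}},
        {"$sort": {"update_time": -1}},
        {"$group": {"_id": "$overrides", "job_id": {"$first": "$job_id"}}},
    ]
)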
tests/unit/serialization/test_mongoose.py (10 changes: 7 additions & 3 deletions)
@@ -32,7 +32,11 @@ def test_get_latest_job_id_for_name_and_params(_get_all_job_ids, conn, gridfs):
 def test__get_all_job_ids(conn, gridfs):
     serializer = MongoResultSerializer()
     serializer._get_all_job_ids("report_name", None, limit=1)
-    serializer.library.find.assert_called_once_with(
-        {"status": {"$ne": JobStatus.DELETED.value}, "report_name": "report_name"},
-        {"_id": 0, "job_id": 1, "report_name": 1},
+    serializer.library.aggregate.assert_called_once_with(
+        [
+            {"$match": {"status": {"$ne": JobStatus.DELETED.value}, "report_name": "report_name"}},
+            {"$sort": {"update_time": -1}},
+            {"$limit": 1},
+            {"$project": {"report_name": 1, "job_id": 1}},
+        ]
     )
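A hypothetical companion test (not part of the PR) could also pin down the unlimited path, assuming _get_all_job_ids builds its pipeline through the same conditional-$limit pattern shown in get_all_result_keys above, so that limit=0 simply drops the $limit stage:

def test__get_all_job_ids_no_limit(conn, gridfs):
    serializer = MongoResultSerializer()
    serializer._get_all_job_ids("report_name", None, limit=0)
    serializer.library.aggregate.assert_called_once_with(
        [
            {"$match": {"status": {"$ne": JobStatus.DELETED.value}, "report_name": "report_name"}},
            {"$sort": {"update_time": -1}},
            {"$project": {"report_name": 1, "job_id": 1}},
        ]
    )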