diff --git a/CHANGELOG.md b/CHANGELOG.md index e108f1d8..ce98805d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +0.6.1 (2023-09-13) +------------------ +* Feature: GridFS document storage in Mongo-backed instances is now sharded if the mongo server supports it. + 0.6.0 (2023-09-01) ------------------ * Feature: Reports are now grouped by their containing folder on the main UI. diff --git a/notebooker/serialization/mongo.py b/notebooker/serialization/mongo.py index 13804e65..f81f90e3 100644 --- a/notebooker/serialization/mongo.py +++ b/notebooker/serialization/mongo.py @@ -86,14 +86,22 @@ def load_files_from_gridfs(result_data_store: gridfs.GridFS, result: Dict, do_re class MongoResultSerializer(ABC): + instance = None + + def __new__(cls, *args, **kwargs): + if not isinstance(cls.instance, cls): + cls.instance = object.__new__(cls) + return cls.instance + # This class is the interface between Mongo and the rest of the application def __init__(self, database_name="notebooker", mongo_host="localhost", result_collection_name="NOTEBOOK_OUTPUT"): self.database_name = database_name self.mongo_host = mongo_host self.result_collection_name = result_collection_name - mongo_connection = self.get_mongo_database() - self.library = mongo_connection[result_collection_name] - self.result_data_store = gridfs.GridFS(mongo_connection, "notebook_data") + + mongo_database = self.get_mongo_database() + self.library = mongo_database[result_collection_name] + self.result_data_store = gridfs.GridFS(mongo_database, "notebook_data") def __init_subclass__(cls, cli_options: click.Command = None, **kwargs): if cli_options is None: @@ -104,6 +112,16 @@ def __init_subclass__(cls, cli_options: click.Command = None, **kwargs): cls.cli_options = cli_options super().__init_subclass__(**kwargs) + def enable_sharding(self): + conn = self.get_mongo_connection() + try: + conn.admin.command("enableSharding", self.database_name) + conn.admin.command({"shardCollection": f"{self.database_name}.notebook_data.chunks", + "key": {"files_id": 1, "n": 1}}) + logger.info(f"Successfully sharded GridFS collection for {self.database_name}") + except pymongo.errors.OperationFailure: + logger.error(f"Could not shard {self.database_name}. Continuing.") + def serializer_args_to_cmdline_args(self) -> List[str]: args = [] for cli_arg in self.cli_options.params: diff --git a/notebooker/web/app.py b/notebooker/web/app.py index 6d042c47..c7aaeb7d 100644 --- a/notebooker/web/app.py +++ b/notebooker/web/app.py @@ -138,6 +138,11 @@ def main(web_config: WebappConfig): GLOBAL_CONFIG = web_config flask_app = create_app(web_config) flask_app = setup_app(flask_app, web_config) + serializer = get_serializer_from_cls(web_config.SERIALIZER_CLS, **web_config.SERIALIZER_CONFIG) + try: + serializer.enable_sharding() + except AttributeError: + pass start_app(web_config) logger.info("Notebooker is now running at http://0.0.0.0:%d", web_config.PORT) http_server = WSGIServer(("0.0.0.0", web_config.PORT), flask_app) diff --git a/tests/unit/serialization/test_mongoose.py b/tests/unit/serialization/test_mongoose.py index 7617ba39..201be808 100644 --- a/tests/unit/serialization/test_mongoose.py +++ b/tests/unit/serialization/test_mongoose.py @@ -21,7 +21,8 @@ def test_mongo_filter_status(): @patch("notebooker.serialization.mongo.gridfs") @patch("notebooker.serialization.mongo.MongoResultSerializer.get_mongo_database") @patch("notebooker.serialization.mongo.MongoResultSerializer._get_all_job_ids") -def test_get_latest_job_id_for_name_and_params(_get_all_job_ids, conn, gridfs): +@patch("notebooker.serialization.mongo.MongoResultSerializer.get_mongo_connection") +def test_get_latest_job_id_for_name_and_params(conn, _get_all_job_ids, db, gridfs): serializer = MongoResultSerializer() serializer.get_latest_job_id_for_name_and_params("report_name", None) _get_all_job_ids.assert_called_once_with("report_name", None, as_of=None, limit=1) @@ -29,7 +30,8 @@ def test_get_latest_job_id_for_name_and_params(_get_all_job_ids, conn, gridfs): @patch("notebooker.serialization.mongo.gridfs") @patch("notebooker.serialization.mongo.MongoResultSerializer.get_mongo_database") -def test__get_all_job_ids(conn, gridfs): +@patch("notebooker.serialization.mongo.MongoResultSerializer.get_mongo_connection") +def test__get_all_job_ids(conn, db, gridfs): serializer = MongoResultSerializer() serializer._get_all_job_ids("report_name", None, limit=1) serializer.library.aggregate.assert_called_once_with(