Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
0.6.1 (2023-09-13)
------------------
* Feature: GridFS document storage in Mongo-backed instances is now sharded when the MongoDB server supports sharding.

0.6.0 (2023-09-01)
------------------
* Feature: Reports are now grouped by their containing folder on the main UI.
Expand Down
24 changes: 21 additions & 3 deletions notebooker/serialization/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,22 @@ def load_files_from_gridfs(result_data_store: gridfs.GridFS, result: Dict, do_re


class MongoResultSerializer(ABC):
instance = None

def __new__(cls, *args, **kwargs):
    """Return the shared instance for ``cls``, creating it on first use.

    The instance is cached on the ``instance`` class attribute. A cached
    object is reused only when it is an instance of ``cls``; otherwise a new
    object is allocated and stored. The positional and keyword arguments are
    accepted so that construction calls work, but they are not used here.
    Python still runs ``__init__`` on the returned object for every call.
    """
    cached = cls.instance
    if isinstance(cached, cls):
        return cached
    # Assigning via ``cls`` stores the new object on the class being constructed.
    cls.instance = object.__new__(cls)
    return cls.instance

# This class is the interface between Mongo and the rest of the application
def __init__(self, database_name="notebooker", mongo_host="localhost", result_collection_name="NOTEBOOK_OUTPUT"):
    """Record the connection settings and open the result collection and GridFS store.

    :param database_name: Stored as ``self.database_name``.
    :param mongo_host: Stored as ``self.mongo_host``; presumably read by the
        connection helpers, which are not shown here.
    :param result_collection_name: Name of the collection looked up on the
        database and exposed as ``self.library``.
    """
    self.database_name = database_name
    self.mongo_host = mongo_host
    self.result_collection_name = result_collection_name

    # Resolve the database handle once and reuse it for both the result
    # collection and the GridFS bucket.
    mongo_database = self.get_mongo_database()
    self.library = mongo_database[result_collection_name]
    self.result_data_store = gridfs.GridFS(mongo_database, "notebook_data")

def __init_subclass__(cls, cli_options: click.Command = None, **kwargs):
if cli_options is None:
Expand All @@ -104,6 +112,16 @@ def __init_subclass__(cls, cli_options: click.Command = None, **kwargs):
cls.cli_options = cli_options
super().__init_subclass__(**kwargs)

def enable_sharding(self):
    """Try to shard the GridFS chunks collection, logging instead of raising on refusal.

    Issues ``enableSharding`` for ``self.database_name`` and then
    ``shardCollection`` for ``<database_name>.notebook_data.chunks`` with the
    key ``{"files_id": 1, "n": 1}`` through the admin database of the
    connection returned by ``self.get_mongo_connection()``.

    A ``pymongo.errors.OperationFailure`` from either command is logged
    together with its reason and then suppressed, so callers can continue
    when the server cannot or will not shard. Other exceptions propagate.
    """
    conn = self.get_mongo_connection()
    # "notebook_data" is the GridFS bucket name used when the store is created.
    chunks_namespace = f"{self.database_name}.notebook_data.chunks"
    try:
        conn.admin.command("enableSharding", self.database_name)
        conn.admin.command({"shardCollection": chunks_namespace,
                            "key": {"files_id": 1, "n": 1}})
    except pymongo.errors.OperationFailure as exc:
        # Include the server's reason so an unsharded deployment can be told
        # apart from e.g. a permissions problem.
        logger.error(f"Could not shard {self.database_name}. Continuing. Reason: {exc}")
    else:
        logger.info(f"Successfully sharded GridFS collection for {self.database_name}")

def serializer_args_to_cmdline_args(self) -> List[str]:
args = []
for cli_arg in self.cli_options.params:
Expand Down
5 changes: 5 additions & 0 deletions notebooker/web/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ def main(web_config: WebappConfig):
GLOBAL_CONFIG = web_config
flask_app = create_app(web_config)
flask_app = setup_app(flask_app, web_config)
serializer = get_serializer_from_cls(web_config.SERIALIZER_CLS, **web_config.SERIALIZER_CONFIG)
try:
serializer.enable_sharding()
except AttributeError:
pass
start_app(web_config)
logger.info("Notebooker is now running at http://0.0.0.0:%d", web_config.PORT)
http_server = WSGIServer(("0.0.0.0", web_config.PORT), flask_app)
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/serialization/test_mongoose.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@ def test_mongo_filter_status():
@patch("notebooker.serialization.mongo.gridfs")
@patch("notebooker.serialization.mongo.MongoResultSerializer.get_mongo_database")
@patch("notebooker.serialization.mongo.MongoResultSerializer._get_all_job_ids")
@patch("notebooker.serialization.mongo.MongoResultSerializer.get_mongo_connection")
def test_get_latest_job_id_for_name_and_params(conn, _get_all_job_ids, db, gridfs):
    """Check that the latest-job lookup delegates to ``_get_all_job_ids`` with ``limit=1``."""
    # ``patch`` decorators are applied bottom-up, so the mocks arrive in the
    # reverse of the order in which the decorators are written.
    serializer = MongoResultSerializer()
    serializer.get_latest_job_id_for_name_and_params("report_name", None)
    _get_all_job_ids.assert_called_once_with("report_name", None, as_of=None, limit=1)


@patch("notebooker.serialization.mongo.gridfs")
@patch("notebooker.serialization.mongo.MongoResultSerializer.get_mongo_database")
def test__get_all_job_ids(conn, gridfs):
@patch("notebooker.serialization.mongo.MongoResultSerializer.get_mongo_connection")
def test__get_all_job_ids(conn, db, gridfs):
serializer = MongoResultSerializer()
serializer._get_all_job_ids("report_name", None, limit=1)
serializer.library.aggregate.assert_called_once_with(
Expand Down