diff --git a/dpdispatcher/machine.py b/dpdispatcher/machine.py index 01a51557..d6dbf785 100644 --- a/dpdispatcher/machine.py +++ b/dpdispatcher/machine.py @@ -82,6 +82,7 @@ def __init__( local_root=None, remote_root=None, remote_profile={}, + retry_count=3, *, context=None, ): @@ -96,6 +97,7 @@ def __init__( else: pass self.bind_context(context=context) + self.retry_count = retry_count def bind_context(self, context): self.context = context @@ -148,7 +150,8 @@ def load_from_dict(cls, machine_dict): base.check_value(machine_dict, strict=False) context = BaseContext.load_from_dict(machine_dict) - machine = machine_class(context=context) + retry_count = machine_dict.get("retry_count", 3) + machine = machine_class(context=context, retry_count=retry_count) return machine def serialize(self, if_empty_remote_profile=False): @@ -161,6 +164,7 @@ def serialize(self, if_empty_remote_profile=False): machine_dict["remote_profile"] = self.context.remote_profile else: machine_dict["remote_profile"] = {} + machine_dict["retry_count"] = self.retry_count # normalize the dict base = self.arginfo() machine_dict = base.normalize_value(machine_dict, trim_pattern="_*") @@ -396,6 +400,7 @@ def arginfo(cls): doc_clean_asynchronously = ( "Clean the remote directory asynchronously after the job finishes." ) + doc_retry_count = "Number of retries to resubmit failed jobs." machine_args = [ Argument("batch_type", str, optional=False, doc=doc_batch_type), @@ -413,6 +418,7 @@ def arginfo(cls): default=False, doc=doc_clean_asynchronously, ), + Argument("retry_count", int, optional=True, default=3, doc=doc_retry_count), ] context_variant = Variant( diff --git a/dpdispatcher/machines/dp_cloud_server.py b/dpdispatcher/machines/dp_cloud_server.py index b4719bfe..001a17fe 100644 --- a/dpdispatcher/machines/dp_cloud_server.py +++ b/dpdispatcher/machines/dp_cloud_server.py @@ -19,7 +19,8 @@ class Bohrium(Machine): alias = ("Lebesgue", "DpCloudServer") - def __init__(self, context): + def __init__(self, context, **kwargs): + super().__init__(context=context, **kwargs) self.context = context self.input_data = context.remote_profile["input_data"].copy() self.api_version = 2 @@ -32,7 +33,6 @@ def __init__(self, context): phone = context.remote_profile.get("phone", None) username = context.remote_profile.get("username", None) password = context.remote_profile.get("password", None) - self.retry_count = context.remote_profile.get("retry_count", 3) self.ignore_exit_code = context.remote_profile.get("ignore_exit_code", True) ticket = os.environ.get("BOHR_TICKET", None) diff --git a/dpdispatcher/machines/openapi.py b/dpdispatcher/machines/openapi.py index c6926b6c..9a2941a1 100644 --- a/dpdispatcher/machines/openapi.py +++ b/dpdispatcher/machines/openapi.py @@ -29,7 +29,8 @@ def unzip_file(zip_file, out_dir="./"): class OpenAPI(Machine): - def __init__(self, context): + def __init__(self, context, **kwargs): + super().__init__(context=context, **kwargs) if not found_bohriumsdk: raise ModuleNotFoundError( "bohriumsdk not installed. Install dpdispatcher with `pip install dpdispatcher[bohrium]`" @@ -38,7 +39,6 @@ def __init__(self, context): self.remote_profile = context.remote_profile.copy() self.grouped = self.remote_profile.get("grouped", True) - self.retry_count = self.remote_profile.get("retry_count", 3) self.ignore_exit_code = context.remote_profile.get("ignore_exit_code", True) access_key = ( diff --git a/dpdispatcher/machines/pbs.py b/dpdispatcher/machines/pbs.py index 35ef4c44..7b81a656 100644 --- a/dpdispatcher/machines/pbs.py +++ b/dpdispatcher/machines/pbs.py @@ -17,6 +17,9 @@ class PBS(Machine): + # def __init__(self, **kwargs): + # super().__init__(**kwargs) + def gen_script(self, job): pbs_script = super().gen_script(job) return pbs_script @@ -188,24 +191,8 @@ def gen_script_header(self, job): class SGE(PBS): - def __init__( - self, - batch_type=None, - context_type=None, - local_root=None, - remote_root=None, - remote_profile={}, - *, - context=None, - ): - super(PBS, self).__init__( - batch_type, - context_type, - local_root, - remote_root, - remote_profile, - context=context, - ) + def __init__(self, **kwargs): + super().__init__(**kwargs) def gen_script_header(self, job): ### Ref:https://softpanorama.org/HPC/PBS_and_derivatives/Reference/pbs_command_vs_sge_commands.shtml diff --git a/pyproject.toml b/pyproject.toml index dbf1814c..b167c046 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,9 +6,7 @@ build-backend = "setuptools.build_meta" name = "dpdispatcher" dynamic = ["version"] description = "Generate HPC scheduler systems jobs input scripts, submit these scripts to HPC systems, and poke until they finish" -authors = [ - { name = "DeepModeling" }, -] +authors = [{ name = "DeepModeling" }] license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: 3.7", @@ -32,7 +30,15 @@ dependencies = [ ] requires-python = ">=3.7" readme = "README.md" -keywords = ["dispatcher", "hpc", "slurm", "lsf", "pbs", "ssh", "jh_unischeduler"] +keywords = [ + "dispatcher", + "hpc", + "slurm", + "lsf", + "pbs", + "ssh", + "jh_unischeduler", +] [project.urls] Homepage = "https://github.com/deepmodeling/dpdispatcher" @@ -59,12 +65,8 @@ docs = [ ] cloudserver = ["oss2", "tqdm", "bohrium-sdk"] bohrium = ["oss2", "tqdm", "bohrium-sdk"] -gui = [ - "dpgui", -] -test = [ - "dpgui", -] +gui = ["dpgui"] +test = ["dpgui"] [tool.setuptools.packages.find] include = ["dpdispatcher*"] @@ -84,11 +86,11 @@ profile = "black" [tool.ruff.lint] select = [ - "E", # errors - "F", # pyflakes - "D", # pydocstyle + "E", # errors + "F", # pyflakes + "D", # pydocstyle "UP", # pyupgrade - "I", # isort + "I", # isort ] ignore = [ "E501", # line too long @@ -113,3 +115,6 @@ ignore = [ [tool.ruff.lint.pydocstyle] convention = "numpy" + +[tool.ruff] +line-length = 88 diff --git a/tests/test_argcheck.py b/tests/test_argcheck.py index b87f39fc..637c5254 100644 --- a/tests/test_argcheck.py +++ b/tests/test_argcheck.py @@ -27,6 +27,7 @@ def test_machine_argcheck(self): "symlink": True, }, "clean_asynchronously": False, + "retry_count": 3, } self.assertDictEqual(norm_dict, expected_dict)