diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py index 0b3232bda..396179559 100644 --- a/torchx/schedulers/slurm_scheduler.py +++ b/torchx/schedulers/slurm_scheduler.py @@ -135,6 +135,7 @@ def _should_use_gpus_per_node_from_version() -> bool: "comment", "mail-user", "mail-type", + "account", } SBATCH_GROUP_OPTIONS = { "partition", @@ -159,6 +160,7 @@ def _apply_app_id_env(s: str) -> str: SlurmOpts = TypedDict( "SlurmOpts", { + "account": Optional[str], "partition": str, "time": str, "comment": Optional[str], @@ -404,6 +406,12 @@ def __init__(self, session_name: str) -> None: def _run_opts(self) -> runopts: opts = runopts() + opts.add( + "account", + type_=str, + help="The account to use for the slurm job.", + default=None, + ) opts.add( "partition", type_=str, diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py index 040a3f3c7..de63f450f 100644 --- a/torchx/schedulers/test/aws_batch_scheduler_test.py +++ b/torchx/schedulers/test/aws_batch_scheduler_test.py @@ -159,7 +159,6 @@ def test_submit_dryrun_tags(self, _) -> None: def test_submit_dryrun_job_role_arn(self) -> None: cfg = AWSBatchOpts({"queue": "ignored_in_test", "job_role_arn": "fizzbuzz"}) info = create_scheduler("test").submit_dryrun(_test_app(), cfg) - # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertEqual(cfg["job_role_arn"], node_groups[0]["container"]["jobRoleArn"]) @@ -169,7 +168,6 @@ def test_submit_dryrun_execution_role_arn(self) -> None: {"queue": "ignored_in_test", "execution_role_arn": "veryexecutive"} ) info = create_scheduler("test").submit_dryrun(_test_app(), cfg) - # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertEqual( @@ -179,7 +177,6 @@ def test_submit_dryrun_execution_role_arn(self) -> None: def test_submit_dryrun_privileged(self) -> None: cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True}) info = create_scheduler("test").submit_dryrun(_test_app(), cfg) - # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertTrue(node_groups[0]["container"]["privileged"]) @@ -189,7 +186,6 @@ def test_submit_dryrun_instance_type_multinode(self) -> None: resource = specs.named_resources_aws.aws_p3dn_24xlarge() app = _test_app(num_replicas=2, resource=resource) info = create_scheduler("test").submit_dryrun(app, cfg) - # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertEqual( @@ -202,7 +198,6 @@ def test_submit_dryrun_instance_type_singlenode(self) -> None: resource = specs.named_resources_aws.aws_p3dn_24xlarge() app = _test_app(num_replicas=1, resource=resource) info = create_scheduler("test").submit_dryrun(app, cfg) - # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertTrue("instanceType" in node_groups[0]["container"]) @@ -212,7 +207,6 @@ def test_submit_dryrun_no_instance_type_non_aws(self) -> None: resource = specs.named_resources_aws.aws_p3dn_24xlarge() app = _test_app(num_replicas=2) info = create_scheduler("test").submit_dryrun(app, cfg) - # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) self.assertTrue("instanceType" not in node_groups[0]["container"]) diff --git a/torchx/schedulers/test/slurm_scheduler_test.py b/torchx/schedulers/test/slurm_scheduler_test.py index 16e457b39..c931dfba3 100644 --- a/torchx/schedulers/test/slurm_scheduler_test.py +++ b/torchx/schedulers/test/slurm_scheduler_test.py @@ -696,6 +696,24 @@ def test_dryrun_comment(self, mock_version: MagicMock) -> None: info.request.cmd, ) + @patch( + "torchx.schedulers.slurm_scheduler.version", + return_value=SLURM_VERSION_24_5, + ) + def test_account(self, mock_version: MagicMock) -> None: + scheduler = create_scheduler("foo") + app = simple_app() + info = scheduler.submit_dryrun( + app, + cfg={ + "account": "foobar", + }, + ) + self.assertIn( + "--account=foobar", + info.request.cmd, + ) + @patch( "torchx.schedulers.slurm_scheduler.version", return_value=SLURM_VERSION_24_5,