diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py index f74dd2448..fde0fbf96 100644 --- a/torchx/schedulers/slurm_scheduler.py +++ b/torchx/schedulers/slurm_scheduler.py @@ -210,6 +210,7 @@ def from_role( sbatch_opts.setdefault("gpus-per-node", str(resource.gpu)) else: sbatch_opts.setdefault("gpus-per-task", str(resource.gpu)) + sbatch_opts.setdefault("ntasks", "1") srun_opts = { "output": f"slurm-{macros.app_id}-{name}.out", diff --git a/torchx/schedulers/test/slurm_scheduler_test.py b/torchx/schedulers/test/slurm_scheduler_test.py index 480f02bc8..ef7f3383e 100644 --- a/torchx/schedulers/test/slurm_scheduler_test.py +++ b/torchx/schedulers/test/slurm_scheduler_test.py @@ -128,6 +128,7 @@ def test_replica_request(self, mock_version: MagicMock) -> None: "--cpus-per-task=2", "--mem=10", "--gpus-per-task=3", + "--ntasks=1", ], ) self.assertEqual( @@ -163,6 +164,7 @@ def test_replica_request_nomem(self, mock_version: MagicMock) -> None: "--ntasks-per-node=1", "--cpus-per-task=2", "--gpus-per-task=3", + "--ntasks=1", ], )