diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py index cb53d84fd..9663757bb 100644 --- a/torchx/schedulers/slurm_scheduler.py +++ b/torchx/schedulers/slurm_scheduler.py @@ -68,14 +68,15 @@ class SlurmReplicaRequest: @classmethod def from_role(cls, role: Role, cfg: RunConfig) -> "SlurmReplicaRequest": opts = {k: str(v) for k, v in cfg.cfgs.items()} - - if (resource := role.resource) != NONE: - if (cpu := resource.cpu) > 0: - opts["cpus-per-task"] = str(cpu) - if (memMB := resource.memMB) > 0: - opts["mem"] = str(memMB) - if (gpu := resource.gpu) > 0: - opts["gpus-per-task"] = str(gpu) + resource = role.resource + + if resource != NONE: + if resource.cpu > 0: + opts["cpus-per-task"] = str(resource.cpu) + if resource.memMB > 0: + opts["mem"] = str(resource.memMB) + if resource.gpu > 0: + opts["gpus-per-task"] = str(resource.gpu) return cls( dir=role.image, diff --git a/torchx/schedulers/test/slurm_scheduler_test.py b/torchx/schedulers/test/slurm_scheduler_test.py index 61ce27f70..3d2401c93 100644 --- a/torchx/schedulers/test/slurm_scheduler_test.py +++ b/torchx/schedulers/test/slurm_scheduler_test.py @@ -138,10 +138,9 @@ def test_run_multi_role(self, run: MagicMock) -> None: self.assertEqual(app_id, "1234") self.assertEqual(run.call_count, 1) - self.assertEqual( - run.call_args.kwargs, {"stdout": subprocess.PIPE, "check": True} - ) - (args,) = run.call_args.args + args, kwargs = run.call_args + self.assertEqual(kwargs, {"stdout": subprocess.PIPE, "check": True}) + (args,) = args self.assertEqual(len(args), 9) self.assertEqual(args[:4], ["sbatch", "--parsable", "--job-name", "foo"]) self.assertTrue(args[4].endswith("role-0-a-0.sh"))