From 7dfac17a5dff2852d6fcaa21426aff90d6cf8e81 Mon Sep 17 00:00:00 2001 From: Kunal Goswami Date: Wed, 27 Aug 2025 11:29:05 -0700 Subject: [PATCH] fix: Pass instance type into aws batch job definition even when num_replicas = 1 --- torchx/schedulers/aws_batch_scheduler.py | 2 +- torchx/schedulers/test/aws_batch_scheduler_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py index 76e285539..eab0490cb 100644 --- a/torchx/schedulers/aws_batch_scheduler.py +++ b/torchx/schedulers/aws_batch_scheduler.py @@ -255,7 +255,7 @@ def _role_to_node_properties( container["jobRoleArn"] = job_role_arn if execution_role_arn: container["executionRoleArn"] = execution_role_arn - if role.num_replicas > 1: + if role.num_replicas > 0: instance_type = instance_type_from_resource(role.resource) if instance_type is not None: container["instanceType"] = instance_type diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py index f8773d081..c2a5f65f6 100644 --- a/torchx/schedulers/test/aws_batch_scheduler_test.py +++ b/torchx/schedulers/test/aws_batch_scheduler_test.py @@ -195,7 +195,7 @@ def test_submit_dryrun_instance_type_multinode(self) -> None: node_groups[0]["container"]["instanceType"], ) - def test_submit_dryrun_no_instance_type_singlenode(self) -> None: + def test_submit_dryrun_instance_type_singlenode(self) -> None: cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True}) resource = specs.named_resources_aws.aws_p3dn_24xlarge() app = _test_app(num_replicas=1, resource=resource) @@ -203,7 +203,7 @@ def test_submit_dryrun_no_instance_type_singlenode(self) -> None: # pyre-ignore[16] node_groups = info.request.job_def["nodeProperties"]["nodeRangeProperties"] self.assertEqual(1, len(node_groups)) - self.assertTrue("instanceType" not in node_groups[0]["container"]) + self.assertTrue("instanceType" in node_groups[0]["container"]) def test_submit_dryrun_no_instance_type_non_aws(self) -> None: cfg = AWSBatchOpts({"queue": "ignored_in_test", "privileged": True})