diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py
index eab0490cb..efcaaf998 100644
--- a/torchx/schedulers/aws_batch_scheduler.py
+++ b/torchx/schedulers/aws_batch_scheduler.py
@@ -99,6 +99,37 @@ TAG_TORCHX_USER = "torchx.pytorch.org/user"
 
 
 
+def parse_ulimits(ulimits_list: List[str]) -> List[Dict[str, Any]]:
+    """
+    Parse ulimit entries in the format: name:softLimit:hardLimit.
+    Multiple entries are passed as a list (comma-separated on the command line).
+    """
+    if not ulimits_list:
+        return []
+
+    ulimits = []
+    for ulimit_str in ulimits_list:
+        if not ulimit_str.strip():
+            continue
+
+        parts = ulimit_str.strip().split(":")
+        if len(parts) != 3:
+            raise ValueError(
+                f"ulimit must be in format name:softLimit:hardLimit, got: {ulimit_str}"
+            )
+
+        name, soft_limit, hard_limit = parts
+        ulimits.append(
+            {
+                "name": name,
+                "softLimit": int(soft_limit) if soft_limit != "-1" else -1,
+                "hardLimit": int(hard_limit) if hard_limit != "-1" else -1,
+            }
+        )
+
+    return ulimits
+
+
 if TYPE_CHECKING:
     from docker import DockerClient
 
@@ -177,7 +208,8 @@ def _role_to_node_properties(
     privileged: bool = False,
     job_role_arn: Optional[str] = None,
     execution_role_arn: Optional[str] = None,
-) -> Dict[str, object]:
+    ulimits: Optional[List[Dict[str, Any]]] = None,
+) -> Dict[str, Any]:
     role.mounts += get_device_mounts(role.resource.devices)
 
     mount_points = []
@@ -239,6 +271,7 @@ def _role_to_node_properties(
         "environment": [{"name": k, "value": v} for k, v in role.env.items()],
         "privileged": privileged,
         "resourceRequirements": resource_requirements_from_resource(role.resource),
+        **({"ulimits": ulimits} if ulimits else {}),
         "linuxParameters": {
             # To support PyTorch dataloaders we need to set /dev/shm to larger
             # than the 64M default.
@@ -361,6 +394,7 @@ class AWSBatchOpts(TypedDict, total=False):
     priority: int
     job_role_arn: Optional[str]
     execution_role_arn: Optional[str]
+    ulimits: Optional[List[str]]
 
 
 class AWSBatchScheduler(
@@ -514,6 +548,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
                     privileged=cfg["privileged"],
                     job_role_arn=cfg.get("job_role_arn"),
                     execution_role_arn=cfg.get("execution_role_arn"),
+                    ulimits=parse_ulimits(cfg.get("ulimits") or []),
                 )
             )
             node_idx += role.num_replicas
@@ -599,6 +634,11 @@ def _run_opts(self) -> runopts:
             type_=str,
             help="The Amazon Resource Name (ARN) of the IAM role that the ECS agent can assume for AWS permissions.",
         )
+        opts.add(
+            "ulimits",
+            type_=List[str],
+            help="Ulimit settings in format: name:softLimit:hardLimit (multiple separated by commas)",
+        )
         return opts
 
     def _get_job_id(self, app_id: str) -> Optional[str]:
diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py
index c2a5f65f6..7379f08d5 100644
--- a/torchx/schedulers/test/aws_batch_scheduler_test.py
+++ b/torchx/schedulers/test/aws_batch_scheduler_test.py
@@ -23,6 +23,7 @@
     AWSBatchScheduler,
     create_scheduler,
     ENV_TORCHX_ROLE_NAME,
+    parse_ulimits,
     resource_from_resource_requirements,
     resource_requirements_from_resource,
     to_millis_since_epoch,
@@ -311,7 +312,6 @@ def test_volume_mounts(self) -> None:
         )
         props = _role_to_node_properties(role, 0)
         self.assertEqual(
-            # pyre-fixme[16]: `object` has no attribute `__getitem__`.
             props["container"]["volumes"],
             [
                 {
@@ -350,7 +350,6 @@ def test_device_mounts(self) -> None:
         )
         props = _role_to_node_properties(role, 0)
         self.assertEqual(
-            # pyre-fixme[16]: `object` has no attribute `__getitem__`.
props["container"]["linuxParameters"]["devices"], [ { @@ -375,7 +374,6 @@ def test_resource_devices(self) -> None: ) props = _role_to_node_properties(role, 0) self.assertEqual( - # pyre-fixme[16]: `object` has no attribute `__getitem__`. props["container"]["linuxParameters"]["devices"], [ { @@ -396,6 +394,46 @@ def test_resource_devices(self) -> None: ], ) + def test_role_to_node_properties_ulimits(self) -> None: + role = specs.Role( + name="test", + image="test:latest", + entrypoint="test", + args=["test"], + resource=specs.Resource(cpu=1, memMB=1000, gpu=0), + ) + ulimits = [ + {"name": "nofile", "softLimit": 65536, "hardLimit": 65536}, + {"name": "memlock", "softLimit": -1, "hardLimit": -1}, + ] + props = _role_to_node_properties(role, 0, ulimits=ulimits) + self.assertEqual( + props["container"]["ulimits"], + ulimits, + ) + + def test_parse_ulimits(self) -> None: + # Test single ulimit + result = parse_ulimits(["nofile:65536:65536"]) + expected = [{"name": "nofile", "softLimit": 65536, "hardLimit": 65536}] + self.assertEqual(result, expected) + + # Test multiple ulimits + result = parse_ulimits(["nofile:65536:65536", "memlock:-1:-1"]) + expected = [ + {"name": "nofile", "softLimit": 65536, "hardLimit": 65536}, + {"name": "memlock", "softLimit": -1, "hardLimit": -1}, + ] + self.assertEqual(result, expected) + + # Test empty list + result = parse_ulimits([]) + self.assertEqual(result, []) + + # Test invalid format + with self.assertRaises(ValueError): + parse_ulimits(["invalid"]) + def _mock_scheduler_running_job(self) -> AWSBatchScheduler: scheduler = AWSBatchScheduler( "test",