Merged
42 changes: 41 additions & 1 deletion torchx/schedulers/aws_batch_scheduler.py
@@ -99,6 +99,37 @@
TAG_TORCHX_USER = "torchx.pytorch.org/user"


def parse_ulimits(ulimits_list: list[str]) -> List[Dict[str, Any]]:
"""
Parse ulimit strings in the format name:softLimit:hardLimit.
Multiple ulimits are passed comma-separated on the CLI and arrive here as a list of individual entries.
"""
if not ulimits_list:
return []

ulimits = []
for ulimit_str in ulimits_list:
if not ulimit_str.strip():
continue

parts = ulimit_str.strip().split(":")
if len(parts) != 3:
raise ValueError(
f"ulimit must be in format name:softLimit:hardLimit, got: {ulimit_str}"
)

name, soft_limit, hard_limit = parts
ulimits.append(
{
"name": name,
"softLimit": int(soft_limit) if soft_limit != "-1" else -1,
"hardLimit": int(hard_limit) if hard_limit != "-1" else -1,
}
)

return ulimits


if TYPE_CHECKING:
from docker import DockerClient

@@ -177,7 +208,8 @@ def _role_to_node_properties(
privileged: bool = False,
job_role_arn: Optional[str] = None,
execution_role_arn: Optional[str] = None,
) -> Dict[str, object]:
ulimits: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
role.mounts += get_device_mounts(role.resource.devices)

mount_points = []
@@ -239,6 +271,7 @@ def _role_to_node_properties(
"environment": [{"name": k, "value": v} for k, v in role.env.items()],
"privileged": privileged,
"resourceRequirements": resource_requirements_from_resource(role.resource),
**({"ulimits": ulimits} if ulimits else {}),
"linuxParameters": {
# To support PyTorch dataloaders we need to set /dev/shm to larger
# than the 64M default.
@@ -361,6 +394,7 @@ class AWSBatchOpts(TypedDict, total=False):
priority: int
job_role_arn: Optional[str]
execution_role_arn: Optional[str]
ulimits: Optional[list[str]]


class AWSBatchScheduler(
@@ -514,6 +548,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
privileged=cfg["privileged"],
job_role_arn=cfg.get("job_role_arn"),
execution_role_arn=cfg.get("execution_role_arn"),
ulimits=parse_ulimits(cfg.get("ulimits") or []),
)
)
node_idx += role.num_replicas
@@ -599,6 +634,11 @@ def _run_opts(self) -> runopts:
type_=str,
help="The Amazon Resource Name (ARN) of the IAM role that the ECS agent can assume for AWS permissions.",
)
opts.add(
"ulimits",
type_=List[str],
help="Ulimit settings in format: name:softLimit:hardLimit (multiple separated by commas)",
)
return opts

def _get_job_id(self, app_id: str) -> Optional[str]:
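For reference, a minimal usage sketch of the new helper, mirroring the tests below (the import path is taken from this diff; the limit values are illustrative):

from torchx.schedulers.aws_batch_scheduler import parse_ulimits

# Each list entry is one ulimit in name:softLimit:hardLimit form; the
# comma-separated CLI value is split into this list before it reaches
# parse_ulimits, which only parses the individual entries.
ulimits = parse_ulimits(["nofile:65536:65536", "memlock:-1:-1"])
# -> [{"name": "nofile", "softLimit": 65536, "hardLimit": 65536},
#     {"name": "memlock", "softLimit": -1, "hardLimit": -1}]

# Entries without exactly three ":"-separated fields raise ValueError:
# parse_ulimits(["invalid"])  # ValueError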
44 changes: 41 additions & 3 deletions torchx/schedulers/test/aws_batch_scheduler_test.py
@@ -23,6 +23,7 @@
AWSBatchScheduler,
create_scheduler,
ENV_TORCHX_ROLE_NAME,
parse_ulimits,
resource_from_resource_requirements,
resource_requirements_from_resource,
to_millis_since_epoch,
@@ -311,7 +312,6 @@ def test_volume_mounts(self) -> None:
)
props = _role_to_node_properties(role, 0)
self.assertEqual(
# pyre-fixme[16]: `object` has no attribute `__getitem__`.
props["container"]["volumes"],
[
{
@@ -350,7 +350,6 @@ def test_device_mounts(self) -> None:
)
props = _role_to_node_properties(role, 0)
self.assertEqual(
# pyre-fixme[16]: `object` has no attribute `__getitem__`.
props["container"]["linuxParameters"]["devices"],
[
{
@@ -375,7 +374,6 @@ def test_resource_devices(self) -> None:
)
props = _role_to_node_properties(role, 0)
self.assertEqual(
# pyre-fixme[16]: `object` has no attribute `__getitem__`.
props["container"]["linuxParameters"]["devices"],
[
{
@@ -396,6 +394,46 @@
],
)

def test_role_to_node_properties_ulimits(self) -> None:
role = specs.Role(
name="test",
image="test:latest",
entrypoint="test",
args=["test"],
resource=specs.Resource(cpu=1, memMB=1000, gpu=0),
)
ulimits = [
{"name": "nofile", "softLimit": 65536, "hardLimit": 65536},
{"name": "memlock", "softLimit": -1, "hardLimit": -1},
]
props = _role_to_node_properties(role, 0, ulimits=ulimits)
self.assertEqual(
props["container"]["ulimits"],
ulimits,
)

def test_parse_ulimits(self) -> None:
# Test single ulimit
result = parse_ulimits(["nofile:65536:65536"])
expected = [{"name": "nofile", "softLimit": 65536, "hardLimit": 65536}]
self.assertEqual(result, expected)

# Test multiple ulimits
result = parse_ulimits(["nofile:65536:65536", "memlock:-1:-1"])
expected = [
{"name": "nofile", "softLimit": 65536, "hardLimit": 65536},
{"name": "memlock", "softLimit": -1, "hardLimit": -1},
]
self.assertEqual(result, expected)

# Test empty list
result = parse_ulimits([])
self.assertEqual(result, [])

# Test invalid format
with self.assertRaises(ValueError):
parse_ulimits(["invalid"])

def _mock_scheduler_running_job(self) -> AWSBatchScheduler:
scheduler = AWSBatchScheduler(
"test",
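End to end, a hedged sketch of how the parsed ulimits reach the AWS Batch container properties, based on _submit_dryrun and _role_to_node_properties above (the Role fields and limit values mirror the new test and are illustrative; import paths are assumed from the test module):

from torchx import specs
from torchx.schedulers.aws_batch_scheduler import (
    _role_to_node_properties,
    parse_ulimits,
)

role = specs.Role(
    name="test",
    image="test:latest",
    entrypoint="test",
    args=["test"],
    resource=specs.Resource(cpu=1, memMB=1000, gpu=0),
)
props = _role_to_node_properties(role, 0, ulimits=parse_ulimits(["nofile:65536:65536"]))
# props["container"] now includes, alongside the usual keys:
#   "ulimits": [{"name": "nofile", "softLimit": 65536, "hardLimit": 65536}]
# When the ulimits cfg is unset, parse_ulimits returns [] and the key is omitted entirely.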