Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions torchx/specs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -52,14 +51,19 @@

from torchx.util.modules import import_attr

AWS_NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = import_attr(
GiB: int = 1024

ResourceFactory = Callable[[], Resource]

AWS_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
"torchx.specs.named_resources_aws", "NAMED_RESOURCES", default={}
)
GENERIC_NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = import_attr(
GENERIC_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
"torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={}
)

GiB: int = 1024
FB_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
"torchx.specs.fb.named_resources", "NAMED_RESOURCES", default={}
)


def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
Expand All @@ -69,6 +73,7 @@ def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
for name, resource in {
**GENERIC_NAMED_RESOURCES,
**AWS_NAMED_RESOURCES,
**FB_NAMED_RESOURCES,
**resource_methods,
}.items():
materialized_resources[name] = resource
Expand Down
5 changes: 4 additions & 1 deletion torchx/specs/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -83,6 +82,8 @@ class Resource:
memMB: MB of ram
capabilities: additional hardware specs (interpreted by scheduler)
devices: a list of named devices with their quantities
tags: metadata tags for the resource (not interpreted by schedulers)
used to add non-functional information about resources (e.g. whether it is an alias of another resource)

Note: you should prefer to use named_resources instead of specifying the raw
resource requirement directly.
Expand All @@ -93,6 +94,7 @@ class Resource:
memMB: int
capabilities: Dict[str, Any] = field(default_factory=dict)
devices: Dict[str, int] = field(default_factory=dict)
tags: Dict[str, object] = field(default_factory=dict)

@staticmethod
def copy(original: "Resource", **capabilities: Any) -> "Resource":
Expand All @@ -101,6 +103,7 @@ def copy(original: "Resource", **capabilities: Any) -> "Resource":
are present in the original resource and as parameter, the one from parameter
will be used.
"""

res_capabilities = dict(original.capabilities)
res_capabilities.update(capabilities)
return Resource(
Expand Down
Loading