diff --git a/torchx/specs/__init__.py b/torchx/specs/__init__.py index ae1622e61..efb5df037 100644 --- a/torchx/specs/__init__.py +++ b/torchx/specs/__init__.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # @@ -52,14 +51,19 @@ from torchx.util.modules import import_attr -AWS_NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = import_attr( +GiB: int = 1024 + +ResourceFactory = Callable[[], Resource] + +AWS_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr( "torchx.specs.named_resources_aws", "NAMED_RESOURCES", default={} ) -GENERIC_NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = import_attr( +GENERIC_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr( "torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={} ) - -GiB: int = 1024 +FB_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr( + "torchx.specs.fb.named_resources", "NAMED_RESOURCES", default={} +) def _load_named_resources() -> Dict[str, Callable[[], Resource]]: @@ -69,6 +73,7 @@ def _load_named_resources() -> Dict[str, Callable[[], Resource]]: for name, resource in { **GENERIC_NAMED_RESOURCES, **AWS_NAMED_RESOURCES, + **FB_NAMED_RESOURCES, **resource_methods, }.items(): materialized_resources[name] = resource diff --git a/torchx/specs/api.py b/torchx/specs/api.py index e3e954a5b..02657aa4e 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # @@ -83,6 +82,8 @@ class Resource: memMB: MB of ram capabilities: additional hardware specs (interpreted by scheduler) devices: a list of named devices with their quantities + tags: metadata tags for the resource (not interpreted by schedulers) + used to add non-functional information about resources (e.g. whether it is an alias of another resource) Note: you should prefer to use named_resources instead of specifying the raw resource requirement directly. @@ -93,6 +94,7 @@ class Resource: memMB: int capabilities: Dict[str, Any] = field(default_factory=dict) devices: Dict[str, int] = field(default_factory=dict) + tags: Dict[str, object] = field(default_factory=dict) @staticmethod def copy(original: "Resource", **capabilities: Any) -> "Resource": @@ -101,6 +103,7 @@ def copy(original: "Resource", **capabilities: Any) -> "Resource": are present in the original resource and as parameter, the one from parameter will be used. """ + res_capabilities = dict(original.capabilities) res_capabilities.update(capabilities) return Resource(