27 changes: 21 additions & 6 deletions torchx/schedulers/kubernetes_scheduler.py
@@ -161,6 +161,7 @@
ANNOTATION_ISTIO_SIDECAR = "sidecar.istio.io/inject"

LABEL_INSTANCE_TYPE = "node.kubernetes.io/instance-type"
TPU_TF_VERSION = "tf-version.cloud-tpus.google.com"


def sanitize_for_serialization(obj: object) -> object:
@@ -314,6 +315,14 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
security_context=security_context,
)

annotations = {
# Disable the istio sidecar as it prevents the containers from
# exiting once finished.
ANNOTATION_ISTIO_SIDECAR: "false",
}
if TPU_TF_VERSION in resource.capabilities:
annotations[TPU_TF_VERSION] = resource.capabilities[TPU_TF_VERSION]

return V1Pod(
spec=V1PodSpec(
containers=[container],
@@ -323,11 +332,7 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
node_selector=node_selector,
),
metadata=V1ObjectMeta(
annotations={
# Disable the istio sidecar as it prevents the containers from
# exiting once finished.
ANNOTATION_ISTIO_SIDECAR: "false",
},
annotations=annotations,
labels={},
),
)
@@ -362,6 +367,7 @@ def app_to_resource(
job level. When using the APPLICATION retry policy, the job level retry
count is set to the minimum of the max_retries of the roles.
"""
scheduler_name: str = "volcano"
tasks = []
unique_app_id = cleanup_str(make_unique(app.name))
for role_idx, role in enumerate(app.roles):
@@ -386,6 +392,12 @@
"name": name,
"template": pod,
}
if TPU_TF_VERSION in pod.metadata.annotations:
# Volcano can't handle TPUs, so fall back to the default Pod
# scheduling behavior.
task["minAvailable"] = 0
scheduler_name = "default-scheduler"

if role.max_retries > 0:
task["maxRetry"] = role.max_retries
task["policies"] = RETRY_POLICIES[role.retry_policy]
@@ -402,7 +414,7 @@
"kind": "Job",
"metadata": {"name": f"{unique_app_id}"},
"spec": {
"schedulerName": "volcano",
"schedulerName": scheduler_name,
"queue": queue,
"tasks": tasks,
"maxRetry": job_retries,
@@ -680,6 +692,9 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
roles_statuses[role].replicas.append(
ReplicaStatus(id=int(idx), role=role, state=state, hostname="")
)
elif app_state == AppState.RUNNING:
# if no tasks and running -- pods haven't been created yet
app_state = AppState.PENDING
else:
app_state = AppState.UNKNOWN
return DescribeAppResponse(
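In short, the scheduler change keys off the TPU annotation. A minimal sketch of the fallback logic follows; pick_scheduler is a hypothetical helper written for illustration, not a function in this diff:

from typing import Dict

TPU_TF_VERSION = "tf-version.cloud-tpus.google.com"

def pick_scheduler(annotations: Dict[str, str]) -> str:
    # TPU pods carry the tf-version annotation; Volcano can't schedule
    # them, so they fall back to the default Kubernetes scheduler.
    if TPU_TF_VERSION in annotations:
        return "default-scheduler"
    return "volcano"

assert pick_scheduler({TPU_TF_VERSION: "pytorch-1.11"}) == "default-scheduler"
assert pick_scheduler({}) == "volcano"
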
22 changes: 22 additions & 0 deletions torchx/schedulers/test/kubernetes_scheduler_test.py
@@ -26,6 +26,7 @@
role_to_pod,
LABEL_INSTANCE_TYPE,
)
from torchx.specs.named_resources_tpu import tpu_v3_8

SKIP_DOCKER: bool = not has_docker()

@@ -727,6 +728,27 @@ def test_push_patches(self) -> None:
self.assertEqual(client.images.get().tag.call_count, 1)
self.assertEqual(client.images.push.call_count, 1)

def test_tpu(self) -> None:
scheduler = create_scheduler("test")

role = specs.Role(
name="foo",
image="",
resource=tpu_v3_8(),
)
app = specs.AppDef("test", roles=[role])
info = scheduler._submit_dryrun(app, cfg={"queue": "blah"})
res = info.request.resource
# pyre-ignore
self.assertEqual(res["spec"]["schedulerName"], "default-scheduler")
self.assertEqual(
res["spec"]["tasks"][0]["template"].metadata.annotations[
"tf-version.cloud-tpus.google.com"
],
"pytorch-1.11",
)
self.assertEqual(res["spec"]["tasks"][0]["minAvailable"], 0)


class KubernetesSchedulerNoImportTest(unittest.TestCase):
"""
6 changes: 5 additions & 1 deletion torchx/specs/__init__.py
@@ -14,6 +14,7 @@
from typing import Dict, Optional

from torchx.specs.named_resources_aws import NAMED_RESOURCES as AWS_NAMED_RESOURCES
from torchx.specs.named_resources_tpu import NAMED_RESOURCES as TPU_NAMED_RESOURCES
from torchx.util.entrypoints import load_group

from .api import ( # noqa: F401 F403
@@ -58,7 +59,10 @@
def _load_named_resources() -> Dict[str, Resource]:
resource_methods = load_group("torchx.named_resources", default={})
materialized_resources = {}
default = AWS_NAMED_RESOURCES
default = {
**AWS_NAMED_RESOURCES,
**TPU_NAMED_RESOURCES,
}
for name, resource in default.items():
materialized_resources[name] = resource()
for resource_name, resource_method in resource_methods.items():
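With the TPU resources merged into the defaults, they resolve through the same lookup as the existing AWS ones; a short usage sketch (printed values abbreviated):

from torchx.specs import named_resources

print(named_resources["aws_t3.medium"])  # pre-existing AWS named resource
print(named_resources["tpu_v3_8"])       # TPU named resource added by this diff
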
2 changes: 1 addition & 1 deletion torchx/specs/named_resources_aws.py
@@ -19,7 +19,7 @@

Usage:

::
.. doctest::

from torchx.specs import named_resources
print(named_resources["aws_t3.medium"])
87 changes: 87 additions & 0 deletions torchx/specs/named_resources_tpu.py
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

r"""
`torchx.specs.named_resources_tpu` contains resource definitions that represent
the corresponding Google Cloud TPU VMs.

TPUs require a matching PyTorch version, so the named resources read the local
torch version to set the `tf-version.cloud-tpus.google.com` annotation correctly.

.. note::
These resource definitions may change in the future. Each user is expected to
manage their own resources. Follow https://pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
to set up named resources.

Usage:

.. doctest::

from torchx.specs import named_resources
print(named_resources["tpu_v2_8"])
print(named_resources["tpu_v3_8"])
print(named_resources["tpu_preemptible_v3_8"])
print(named_resources["tpu_v3_2048"])
"""

from typing import Dict, Callable, Optional

from torchx.specs.api import Resource

NAMED_RESOURCES: Dict[str, Callable[[], Resource]] = {}


def _get_tf_version(version: Optional[str] = None) -> str:
if version is None:
try:
from torch.version import __version__

version = __version__
except ImportError:
version = "1.11"
if "dev" in version:
return "pytorch-nightly"
short_ver = ".".join(version.split(".")[:2])
return f"pytorch-{short_ver}"


def _register_type(ver: str, cores: int) -> Callable[[], Resource]:
device: str = "cloud-tpus.google.com/" + ver

def resource() -> Resource:
return Resource(
cpu=0,
memMB=0,
gpu=0,
capabilities={
"tf-version.cloud-tpus.google.com": _get_tf_version(),
},
devices={
device: int(cores),
},
)

resource_name = f"tpu_{ver.replace('-', '_')}_{cores}"
NAMED_RESOURCES[resource_name] = resource
return resource


tpu_v2_8: Callable[[], Resource] = _register_type("v2", 8)
tpu_preemptible_v2_8: Callable[[], Resource] = _register_type("preemptible-v2", 8)
tpu_v2_32: Callable[[], Resource] = _register_type("v2", 32)
tpu_v2_128: Callable[[], Resource] = _register_type("v2", 128)
tpu_v2_256: Callable[[], Resource] = _register_type("v2", 256)
tpu_v2_512: Callable[[], Resource] = _register_type("v2", 512)

tpu_v3_8: Callable[[], Resource] = _register_type("v3", 8)
tpu_preemptible_v3_8: Callable[[], Resource] = _register_type("preemptible-v3", 8)
tpu_v3_32: Callable[[], Resource] = _register_type("v3", 32)
tpu_v3_64: Callable[[], Resource] = _register_type("v3", 64)
tpu_v3_128: Callable[[], Resource] = _register_type("v3", 128)
tpu_v3_256: Callable[[], Resource] = _register_type("v3", 256)
tpu_v3_512: Callable[[], Resource] = _register_type("v3", 512)
tpu_v3_1024: Callable[[], Resource] = _register_type("v3", 1024)
tpu_v3_2048: Callable[[], Resource] = _register_type("v3", 2048)
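
The mapping performed by _get_tf_version can be summarized with a few examples; the last two expected values mirror the unit tests below, and the first is derived from the same major.minor rule:

from torchx.specs import named_resources_tpu as tpu

# Release builds keep major.minor; dev builds map to the nightly tag.
assert tpu._get_tf_version("1.11.0") == "pytorch-1.11"
assert tpu._get_tf_version("2.123.0+cu102") == "pytorch-2.123"
assert tpu._get_tf_version("1.12.0.dev20220419+cu113") == "pytorch-nightly"
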
94 changes: 94 additions & 0 deletions torchx/specs/test/named_resource_tpu_test.py
@@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

from torchx.specs import Resource
from torchx.specs import named_resources_tpu as tpu


class NamedResourcesTest(unittest.TestCase):
def test_tf_version(self) -> None:
self.assertEqual(tpu._get_tf_version("2.123.0+cu102"), "pytorch-2.123")
self.assertEqual(
tpu._get_tf_version("1.12.0.dev20220419+cu113"), "pytorch-nightly"
)

def test_tpu_v3_8(self) -> None:
want = Resource(
cpu=0,
memMB=0,
gpu=0,
capabilities={
"tf-version.cloud-tpus.google.com": "pytorch-1.11",
},
devices={
"cloud-tpus.google.com/v3": 8,
},
)
self.assertEqual(tpu.tpu_v3_8(), want)
self.assertEqual(tpu.NAMED_RESOURCES["tpu_v3_8"](), want)

def test_tpu_v3_2048(self) -> None:
want = Resource(
cpu=0,
memMB=0,
gpu=0,
capabilities={
"tf-version.cloud-tpus.google.com": "pytorch-1.11",
},
devices={
"cloud-tpus.google.com/v3": 2048,
},
)
self.assertEqual(tpu.tpu_v3_2048(), want)
self.assertEqual(tpu.NAMED_RESOURCES["tpu_v3_2048"](), want)

def test_tpu_v2_8(self) -> None:
want = Resource(
cpu=0,
memMB=0,
gpu=0,
capabilities={
"tf-version.cloud-tpus.google.com": "pytorch-1.11",
},
devices={
"cloud-tpus.google.com/v2": 8,
},
)
self.assertEqual(tpu.tpu_v2_8(), want)
self.assertEqual(tpu.NAMED_RESOURCES["tpu_v2_8"](), want)

def test_tpu_preemptible_v2_8(self) -> None:
want = Resource(
cpu=0,
memMB=0,
gpu=0,
capabilities={
"tf-version.cloud-tpus.google.com": "pytorch-1.11",
},
devices={
"cloud-tpus.google.com/preemptible-v2": 8,
},
)
self.assertEqual(tpu.tpu_preemptible_v2_8(), want)
self.assertEqual(tpu.NAMED_RESOURCES["tpu_preemptible_v2_8"](), want)

def test_tpu_preemptible_v3_8(self) -> None:
want = Resource(
cpu=0,
memMB=0,
gpu=0,
capabilities={
"tf-version.cloud-tpus.google.com": "pytorch-1.11",
},
devices={
"cloud-tpus.google.com/preemptible-v3": 8,
},
)
self.assertEqual(tpu.tpu_preemptible_v3_8(), want)
self.assertEqual(tpu.NAMED_RESOURCES["tpu_preemptible_v3_8"](), want)