Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ Milestone: https://github.com/pytorch/torchx/milestones/3
* Slurm jobs will by default launch in the current working directory to match `local_cwd` and workspace behavior. #372
* Replicas now have their own log files and can be accessed programmatically. #373
* Support for `comment`, `mail-user` and `constraint` fields. #391
* Workspace support (prototype) - Slurm jobs can now be launched in isolated experiment directories. #416
* WorkspaceMixin support (prototype) - Slurm jobs can now be launched in isolated experiment directories. #416
* Kubernetes
* Support for running jobs under service accounts. #408
* Support for specifying instance types. #433
Expand Down
8 changes: 4 additions & 4 deletions docs/source/workspace.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ torchx.workspace

.. currentmodule:: torchx.workspace

.. autoclass:: Workspace
.. autoclass:: WorkspaceMixin
:members:

.. autofunction:: walk_workspace
Expand All @@ -18,7 +18,7 @@ torchx.workspace.docker_workspace
.. automodule:: torchx.workspace.docker_workspace
.. currentmodule:: torchx.workspace.docker_workspace

.. autoclass:: DockerWorkspace
.. autoclass:: DockerWorkspaceMixin
:members:
:private-members: _update_app_images, _push_images

Expand All @@ -29,7 +29,7 @@ torchx.workspace.dir_workspace
.. automodule:: torchx.workspace.dir_workspace
.. currentmodule:: torchx.workspace.dir_workspace

.. autoclass:: DirWorkspace
.. autoclass:: DirWorkspaceMixin
:members:

.. fbcode::
Expand All @@ -40,6 +40,6 @@ torchx.workspace.dir_workspace
.. automodule:: torchx.workspace.fb.jetter_workspace
.. currentmodule:: torchx.workspace.fb.jetter_workspace

.. autoclass:: JetterWorkspace
.. autoclass:: JetterWorkspaceMixin
:members:
:show-inheritance:
4 changes: 2 additions & 2 deletions torchx/runner/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from torchx.tracker.api import tracker_config_env_var_name, TRACKER_ENV_VAR_NAME

from torchx.util.types import none_throws
from torchx.workspace.api import Workspace
from torchx.workspace.api import WorkspaceMixin

from .config import get_config, get_configs

Expand Down Expand Up @@ -363,7 +363,7 @@ def dryrun(
with log_event("dryrun", scheduler, runcfg=json.dumps(cfg) if cfg else None):
sched = self._scheduler(scheduler)

if workspace and isinstance(sched, Workspace):
if workspace and isinstance(sched, WorkspaceMixin):
role = app.roles[0]
old_img = role.image

Expand Down
6 changes: 3 additions & 3 deletions torchx/runner/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torchx.specs.finder import ComponentNotFoundException

from torchx.util.types import none_throws
from torchx.workspace import Workspace
from torchx.workspace import WorkspaceMixin


GET_SCHEDULER_FACTORIES = "torchx.runner.api.get_scheduler_factories"
Expand Down Expand Up @@ -293,9 +293,9 @@ def test_dryrun_setup_trackers_as_env_variable(self, _) -> None:
)

def test_dryrun_with_workspace(self, _) -> None:
class TestScheduler(Scheduler, Workspace):
class TestScheduler(WorkspaceMixin[None], Scheduler):
def __init__(self, build_new_img: bool):
Scheduler.__init__(self, backend="ignored", session_name="ignored")
super().__init__(backend="ignored", session_name="ignored")
self.build_new_img = build_new_img

def schedule(self, dryrun_info: AppDryRunInfo) -> str:
Expand Down
2 changes: 1 addition & 1 deletion torchx/runner/test/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def log_iter(
def list(self) -> List[ListAppResponse]:
raise NotImplementedError()

def run_opts(self) -> runopts:
def _run_opts(self) -> runopts:
opts = runopts()
opts.add(
"i",
Expand Down
10 changes: 8 additions & 2 deletions torchx/schedulers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
RoleStatus,
runopts,
)
from torchx.workspace.api import Workspace
from torchx.workspace.api import WorkspaceMixin


DAYS_IN_2_WEEKS = 14
Expand Down Expand Up @@ -138,7 +138,7 @@ def submit(
"""
if workspace:
sched = self
assert isinstance(sched, Workspace)
assert isinstance(sched, WorkspaceMixin)
role = app.roles[0]
sched.build_workspace_and_update_role(role, workspace, cfg)
dryrun_info = self.submit_dryrun(app, cfg)
Expand Down Expand Up @@ -189,6 +189,12 @@ def run_opts(self) -> runopts:
Returns the run configuration options expected by the scheduler.
Basically a ``--help`` for the ``run`` API.
"""
opts = self._run_opts()
if isinstance(self, WorkspaceMixin):
opts.update(self.workspace_opts())
return opts

def _run_opts(self) -> runopts:
return runopts()

@abc.abstractmethod
Expand Down
21 changes: 9 additions & 12 deletions torchx/schedulers/aws_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
TYPE_CHECKING,
Expand All @@ -67,13 +69,14 @@
AppDef,
AppState,
BindMount,
CfgVal,
DeviceMount,
macros,
Role,
runopts,
VolumeMount,
)
from torchx.workspace.docker_workspace import DockerWorkspace
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
from typing_extensions import TypedDict

if TYPE_CHECKING:
Expand Down Expand Up @@ -246,7 +249,7 @@ class AWSBatchOpts(TypedDict, total=False):
priority: Optional[int]


class AWSBatchScheduler(Scheduler[AWSBatchOpts], DockerWorkspace):
class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
"""
AWSBatchScheduler is a TorchX scheduling interface to AWS Batch.

Expand Down Expand Up @@ -308,8 +311,7 @@ def __init__(
log_client: Optional[Any] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
Scheduler.__init__(self, "aws_batch", session_name)
DockerWorkspace.__init__(self, docker_client)
super().__init__("aws_batch", session_name, docker_client=docker_client)

# pyre-fixme[4]: Attribute annotation cannot be `Any`.
self.__client = client
Expand All @@ -335,7 +337,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
assert cfg is not None, f"{dryrun_info} missing cfg"

images_to_push = dryrun_info.request.images_to_push
self._push_images(images_to_push)
self.push_images(images_to_push)

req = dryrun_info.request
self._client.register_job_definition(**req.job_def)
Expand Down Expand Up @@ -370,7 +372,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ
name = make_unique(f"{app.name}{name_suffix}")

# map any local images to the remote image
images_to_push = self._update_app_images(app, cfg.get("image_repo"))
images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg))

nodes = []

Expand Down Expand Up @@ -450,14 +452,9 @@ def _cancel_existing(self, app_id: str) -> None:
reason="killed via torchx CLI",
)

def run_opts(self) -> runopts:
def _run_opts(self) -> runopts:
opts = runopts()
opts.add("queue", type_=str, help="queue to schedule job in", required=True)
opts.add(
"image_repo",
type_=str,
help="The image repository to use when pushing patched images, must have push access. Ex: example.com/your/container",
)
opts.add(
"share_id",
type_=str,
Expand Down
11 changes: 5 additions & 6 deletions torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
runopts,
VolumeMount,
)
from torchx.workspace.docker_workspace import DockerWorkspace
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
from typing_extensions import TypedDict


Expand Down Expand Up @@ -76,7 +76,7 @@ def __repr__(self) -> str:
return str(self)


LABEL_VERSION: str = DockerWorkspace.LABEL_VERSION
LABEL_VERSION: str = DockerWorkspaceMixin.LABEL_VERSION
LABEL_APP_ID: str = "torchx.pytorch.org/app-id"
LABEL_ROLE_NAME: str = "torchx.pytorch.org/role-name"
LABEL_REPLICA_ID: str = "torchx.pytorch.org/replica-id"
Expand All @@ -98,7 +98,7 @@ class DockerOpts(TypedDict, total=False):
copy_env: Optional[List[str]]


class DockerScheduler(Scheduler[DockerOpts], DockerWorkspace):
class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
"""
DockerScheduler is a TorchX scheduling interface to Docker.

Expand Down Expand Up @@ -143,8 +143,7 @@ class DockerScheduler(Scheduler[DockerOpts], DockerWorkspace):
"""

def __init__(self, session_name: str) -> None:
Scheduler.__init__(self, "docker", session_name)
DockerWorkspace.__init__(self)
super().__init__("docker", session_name)

def _ensure_network(self) -> None:
import filelock
Expand Down Expand Up @@ -346,7 +345,7 @@ def _cancel_existing(self, app_id: str) -> None:
for container in containers:
container.stop()

def run_opts(self) -> runopts:
def _run_opts(self) -> runopts:
opts = runopts()
opts.add(
"copy_env",
Expand Down
31 changes: 18 additions & 13 deletions torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,17 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING
from typing import (
Any,
cast,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
TYPE_CHECKING,
)

import torchx
import yaml
Expand All @@ -51,6 +61,7 @@
AppDef,
AppState,
BindMount,
CfgVal,
DeviceMount,
macros,
ReplicaState,
Expand All @@ -61,7 +72,7 @@
runopts,
VolumeMount,
)
from torchx.workspace.docker_workspace import DockerWorkspace
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
from typing_extensions import TypedDict


Expand Down Expand Up @@ -441,7 +452,7 @@ class KubernetesOpts(TypedDict, total=False):
priority_class: Optional[str]


class KubernetesScheduler(Scheduler[KubernetesOpts], DockerWorkspace):
class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]):
"""
KubernetesScheduler is a TorchX scheduling interface to Kubernetes.

Expand Down Expand Up @@ -535,8 +546,7 @@ def __init__(
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
Scheduler.__init__(self, "kubernetes", session_name)
DockerWorkspace.__init__(self, docker_client)
super().__init__("kubernetes", session_name, docker_client=docker_client)

self._client = client

Expand Down Expand Up @@ -575,7 +585,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str:
namespace = cfg.get("namespace") or "default"

images_to_push = dryrun_info.request.images_to_push
self._push_images(images_to_push)
self.push_images(images_to_push)

resource = dryrun_info.request.resource
try:
Expand Down Expand Up @@ -605,7 +615,7 @@ def _submit_dryrun(
raise TypeError(f"config value 'queue' must be a string, got {queue}")

# map any local images to the remote image
images_to_push = self._update_app_images(app, cfg.get("image_repo"))
images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg))

service_account = cfg.get("service_account")
assert service_account is None or isinstance(
Expand Down Expand Up @@ -642,7 +652,7 @@ def _cancel_existing(self, app_id: str) -> None:
name=name,
)

def run_opts(self) -> runopts:
def _run_opts(self) -> runopts:
opts = runopts()
opts.add(
"namespace",
Expand All @@ -656,11 +666,6 @@ def run_opts(self) -> runopts:
help="Volcano queue to schedule job in",
required=True,
)
opts.add(
"image_repo",
type_=str,
help="The image repository to use when pushing patched images, must have push access. Ex: example.com/your/container",
)
opts.add(
"service_account",
type_=str,
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def __init__(
self._base_log_dir: Optional[str] = None
self._created_tmp_log_dir: bool = False

def run_opts(self) -> runopts:
def _run_opts(self) -> runopts:
opts = runopts()
opts.add(
"log_dir",
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/lsf_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ class LsfScheduler(Scheduler[LsfOpts]):
def __init__(self, session_name: str) -> None:
super().__init__("lsf", session_name)

def run_opts(self) -> runopts:
def _run_opts(self) -> runopts:
opts = runopts()
opts.add(
"lsf_queue",
Expand Down
6 changes: 3 additions & 3 deletions torchx/schedulers/ray_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from torchx.schedulers.ids import make_unique
from torchx.schedulers.ray.ray_common import RayActor, TORCHX_RANK0_HOST
from torchx.specs import AppDef, macros, NONE, ReplicaStatus, Role, RoleStatus, runopts
from torchx.workspace.dir_workspace import TmpDirWorkspace
from torchx.workspace.dir_workspace import TmpDirWorkspaceMixin
from typing_extensions import TypedDict


Expand Down Expand Up @@ -113,7 +113,7 @@ class RayJob:
requirements: Optional[str] = None
actors: List[RayActor] = field(default_factory=list)

class RayScheduler(Scheduler[RayOpts], TmpDirWorkspace):
class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]):
"""
RayScheduler is a TorchX scheduling interface to Ray. The job def
workers will be launched as Ray actors
Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(self, session_name: str) -> None:
super().__init__("ray", session_name)

# TODO: Add address as a potential CLI argument after writing ray.status() or passing in config file
def run_opts(self) -> runopts:
def _run_opts(self) -> runopts:
opts = runopts()
opts.add(
"cluster_config_file",
Expand Down
Loading