diff --git a/CHANGELOG.md b/CHANGELOG.md index f7c18da5a..cede037ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,7 +77,7 @@ Milestone: https://github.com/pytorch/torchx/milestones/3 * Slurm jobs will by default launch in the current working directory to match `local_cwd` and workspace behavior. #372 * Replicas now have their own log files and can be accessed programmatically. #373 * Support for `comment`, `mail-user` and `constraint` fields. #391 - * Workspace support (prototype) - Slurm jobs can now be launched in isolated experiment directories. #416 + * WorkspaceMixin support (prototype) - Slurm jobs can now be launched in isolated experiment directories. #416 * Kubernetes * Support for running jobs under service accounts. #408 * Support for specifying instance types. #433 diff --git a/docs/source/workspace.rst b/docs/source/workspace.rst index feec50b1b..a3fd0fefa 100644 --- a/docs/source/workspace.rst +++ b/docs/source/workspace.rst @@ -6,7 +6,7 @@ torchx.workspace .. currentmodule:: torchx.workspace -.. autoclass:: Workspace +.. autoclass:: WorkspaceMixin :members: .. autofunction:: walk_workspace @@ -18,7 +18,7 @@ torchx.workspace.docker_workspace .. automodule:: torchx.workspace.docker_workspace .. currentmodule:: torchx.workspace.docker_workspace -.. autoclass:: DockerWorkspace +.. autoclass:: DockerWorkspaceMixin :members: :private-members: _update_app_images, _push_images @@ -29,7 +29,7 @@ torchx.workspace.dir_workspace .. automodule:: torchx.workspace.dir_workspace .. currentmodule:: torchx.workspace.dir_workspace -.. autoclass:: DirWorkspace +.. autoclass:: DirWorkspaceMixin :members: .. fbcode:: @@ -40,6 +40,6 @@ torchx.workspace.dir_workspace .. automodule:: torchx.workspace.fb.jetter_workspace .. currentmodule:: torchx.workspace.fb.jetter_workspace - .. autoclass:: JetterWorkspace + .. 
autoclass:: JetterWorkspaceMixin :members: :show-inheritance: diff --git a/torchx/runner/api.py b/torchx/runner/api.py index d37b60228..194530372 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -34,7 +34,7 @@ from torchx.tracker.api import tracker_config_env_var_name, TRACKER_ENV_VAR_NAME from torchx.util.types import none_throws -from torchx.workspace.api import Workspace +from torchx.workspace.api import WorkspaceMixin from .config import get_config, get_configs @@ -363,7 +363,7 @@ def dryrun( with log_event("dryrun", scheduler, runcfg=json.dumps(cfg) if cfg else None): sched = self._scheduler(scheduler) - if workspace and isinstance(sched, Workspace): + if workspace and isinstance(sched, WorkspaceMixin): role = app.roles[0] old_img = role.image diff --git a/torchx/runner/test/api_test.py b/torchx/runner/test/api_test.py index df5f6d1f8..95c9e6b04 100644 --- a/torchx/runner/test/api_test.py +++ b/torchx/runner/test/api_test.py @@ -26,7 +26,7 @@ from torchx.specs.finder import ComponentNotFoundException from torchx.util.types import none_throws -from torchx.workspace import Workspace +from torchx.workspace import WorkspaceMixin GET_SCHEDULER_FACTORIES = "torchx.runner.api.get_scheduler_factories" @@ -293,9 +293,9 @@ def test_dryrun_setup_trackers_as_env_variable(self, _) -> None: ) def test_dryrun_with_workspace(self, _) -> None: - class TestScheduler(Scheduler, Workspace): + class TestScheduler(WorkspaceMixin[None], Scheduler): def __init__(self, build_new_img: bool): - Scheduler.__init__(self, backend="ignored", session_name="ignored") + super().__init__(backend="ignored", session_name="ignored") self.build_new_img = build_new_img def schedule(self, dryrun_info: AppDryRunInfo) -> str: diff --git a/torchx/runner/test/config_test.py b/torchx/runner/test/config_test.py index e2e3068d1..78274c01b 100644 --- a/torchx/runner/test/config_test.py +++ b/torchx/runner/test/config_test.py @@ -62,7 +62,7 @@ def log_iter( def list(self) -> 
List[ListAppResponse]: raise NotImplementedError() - def run_opts(self) -> runopts: + def _run_opts(self) -> runopts: opts = runopts() opts.add( "i", diff --git a/torchx/schedulers/api.py b/torchx/schedulers/api.py index e059f8c69..450ee7c68 100644 --- a/torchx/schedulers/api.py +++ b/torchx/schedulers/api.py @@ -22,7 +22,7 @@ RoleStatus, runopts, ) -from torchx.workspace.api import Workspace +from torchx.workspace.api import WorkspaceMixin DAYS_IN_2_WEEKS = 14 @@ -138,7 +138,7 @@ def submit( """ if workspace: sched = self - assert isinstance(sched, Workspace) + assert isinstance(sched, WorkspaceMixin) role = app.roles[0] sched.build_workspace_and_update_role(role, workspace, cfg) dryrun_info = self.submit_dryrun(app, cfg) @@ -189,6 +189,12 @@ def run_opts(self) -> runopts: Returns the run configuration options expected by the scheduler. Basically a ``--help`` for the ``run`` API. """ + opts = self._run_opts() + if isinstance(self, WorkspaceMixin): + opts.update(self.workspace_opts()) + return opts + + def _run_opts(self) -> runopts: return runopts() @abc.abstractmethod diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py index d558079a6..d50ea55af 100644 --- a/torchx/schedulers/aws_batch_scheduler.py +++ b/torchx/schedulers/aws_batch_scheduler.py @@ -42,9 +42,11 @@ from typing import ( Any, Callable, + cast, Dict, Iterable, List, + Mapping, Optional, Tuple, TYPE_CHECKING, @@ -67,13 +69,14 @@ AppDef, AppState, BindMount, + CfgVal, DeviceMount, macros, Role, runopts, VolumeMount, ) -from torchx.workspace.docker_workspace import DockerWorkspace +from torchx.workspace.docker_workspace import DockerWorkspaceMixin from typing_extensions import TypedDict if TYPE_CHECKING: @@ -246,7 +249,7 @@ class AWSBatchOpts(TypedDict, total=False): priority: Optional[int] -class AWSBatchScheduler(Scheduler[AWSBatchOpts], DockerWorkspace): +class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]): """ AWSBatchScheduler is a 
TorchX scheduling interface to AWS Batch. @@ -308,8 +311,7 @@ def __init__( log_client: Optional[Any] = None, docker_client: Optional["DockerClient"] = None, ) -> None: - Scheduler.__init__(self, "aws_batch", session_name) - DockerWorkspace.__init__(self, docker_client) + super().__init__("aws_batch", session_name, docker_client=docker_client) # pyre-fixme[4]: Attribute annotation cannot be `Any`. self.__client = client @@ -335,7 +337,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str: assert cfg is not None, f"{dryrun_info} missing cfg" images_to_push = dryrun_info.request.images_to_push - self._push_images(images_to_push) + self.push_images(images_to_push) req = dryrun_info.request self._client.register_job_definition(**req.job_def) @@ -370,7 +372,7 @@ def _submit_dryrun(self, app: AppDef, cfg: AWSBatchOpts) -> AppDryRunInfo[BatchJ name = make_unique(f"{app.name}{name_suffix}") # map any local images to the remote image - images_to_push = self._update_app_images(app, cfg.get("image_repo")) + images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg)) nodes = [] @@ -450,14 +452,9 @@ def _cancel_existing(self, app_id: str) -> None: reason="killed via torchx CLI", ) - def run_opts(self) -> runopts: + def _run_opts(self) -> runopts: opts = runopts() opts.add("queue", type_=str, help="queue to schedule job in", required=True) - opts.add( - "image_repo", - type_=str, - help="The image repository to use when pushing patched images, must have push access. 
Ex: example.com/your/container", - ) opts.add( "share_id", type_=str, diff --git a/torchx/schedulers/docker_scheduler.py b/torchx/schedulers/docker_scheduler.py index 6c93a3072..a24c9c07d 100644 --- a/torchx/schedulers/docker_scheduler.py +++ b/torchx/schedulers/docker_scheduler.py @@ -38,7 +38,7 @@ runopts, VolumeMount, ) -from torchx.workspace.docker_workspace import DockerWorkspace +from torchx.workspace.docker_workspace import DockerWorkspaceMixin from typing_extensions import TypedDict @@ -76,7 +76,7 @@ def __repr__(self) -> str: return str(self) -LABEL_VERSION: str = DockerWorkspace.LABEL_VERSION +LABEL_VERSION: str = DockerWorkspaceMixin.LABEL_VERSION LABEL_APP_ID: str = "torchx.pytorch.org/app-id" LABEL_ROLE_NAME: str = "torchx.pytorch.org/role-name" LABEL_REPLICA_ID: str = "torchx.pytorch.org/replica-id" @@ -98,7 +98,7 @@ class DockerOpts(TypedDict, total=False): copy_env: Optional[List[str]] -class DockerScheduler(Scheduler[DockerOpts], DockerWorkspace): +class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]): """ DockerScheduler is a TorchX scheduling interface to Docker. 
@@ -143,8 +143,7 @@ class DockerScheduler(Scheduler[DockerOpts], DockerWorkspace): """ def __init__(self, session_name: str) -> None: - Scheduler.__init__(self, "docker", session_name) - DockerWorkspace.__init__(self) + super().__init__("docker", session_name) def _ensure_network(self) -> None: import filelock @@ -346,7 +345,7 @@ def _cancel_existing(self, app_id: str) -> None: for container in containers: container.stop() - def run_opts(self) -> runopts: + def _run_opts(self) -> runopts: opts = runopts() opts.add( "copy_env", diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py index 7a79155b6..752c796d7 100644 --- a/torchx/schedulers/kubernetes_scheduler.py +++ b/torchx/schedulers/kubernetes_scheduler.py @@ -33,7 +33,17 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING +from typing import ( + Any, + cast, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + TYPE_CHECKING, +) import torchx import yaml @@ -51,6 +61,7 @@ AppDef, AppState, BindMount, + CfgVal, DeviceMount, macros, ReplicaState, @@ -61,7 +72,7 @@ runopts, VolumeMount, ) -from torchx.workspace.docker_workspace import DockerWorkspace +from torchx.workspace.docker_workspace import DockerWorkspaceMixin from typing_extensions import TypedDict @@ -441,7 +452,7 @@ class KubernetesOpts(TypedDict, total=False): priority_class: Optional[str] -class KubernetesScheduler(Scheduler[KubernetesOpts], DockerWorkspace): +class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]): """ KubernetesScheduler is a TorchX scheduling interface to Kubernetes. 
@@ -535,8 +546,7 @@ def __init__( client: Optional["ApiClient"] = None, docker_client: Optional["DockerClient"] = None, ) -> None: - Scheduler.__init__(self, "kubernetes", session_name) - DockerWorkspace.__init__(self, docker_client) + super().__init__("kubernetes", session_name, docker_client=docker_client) self._client = client @@ -575,7 +585,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str: namespace = cfg.get("namespace") or "default" images_to_push = dryrun_info.request.images_to_push - self._push_images(images_to_push) + self.push_images(images_to_push) resource = dryrun_info.request.resource try: @@ -605,7 +615,7 @@ def _submit_dryrun( raise TypeError(f"config value 'queue' must be a string, got {queue}") # map any local images to the remote image - images_to_push = self._update_app_images(app, cfg.get("image_repo")) + images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg)) service_account = cfg.get("service_account") assert service_account is None or isinstance( @@ -642,7 +652,7 @@ def _cancel_existing(self, app_id: str) -> None: name=name, ) - def run_opts(self) -> runopts: + def _run_opts(self) -> runopts: opts = runopts() opts.add( "namespace", @@ -656,11 +666,6 @@ def run_opts(self) -> runopts: help="Volcano queue to schedule job in", required=True, ) - opts.add( - "image_repo", - type_=str, - help="The image repository to use when pushing patched images, must have push access. 
Ex: example.com/your/container", - ) opts.add( "service_account", type_=str, diff --git a/torchx/schedulers/local_scheduler.py b/torchx/schedulers/local_scheduler.py index d4df2fb86..336d78994 100644 --- a/torchx/schedulers/local_scheduler.py +++ b/torchx/schedulers/local_scheduler.py @@ -573,7 +573,7 @@ def __init__( self._base_log_dir: Optional[str] = None self._created_tmp_log_dir: bool = False - def run_opts(self) -> runopts: + def _run_opts(self) -> runopts: opts = runopts() opts.add( "log_dir", diff --git a/torchx/schedulers/lsf_scheduler.py b/torchx/schedulers/lsf_scheduler.py index e2c62d921..1fc229dfc 100644 --- a/torchx/schedulers/lsf_scheduler.py +++ b/torchx/schedulers/lsf_scheduler.py @@ -440,7 +440,7 @@ class LsfScheduler(Scheduler[LsfOpts]): def __init__(self, session_name: str) -> None: super().__init__("lsf", session_name) - def run_opts(self) -> runopts: + def _run_opts(self) -> runopts: opts = runopts() opts.add( "lsf_queue", diff --git a/torchx/schedulers/ray_scheduler.py b/torchx/schedulers/ray_scheduler.py index 3cdafacf2..c9a981a78 100644 --- a/torchx/schedulers/ray_scheduler.py +++ b/torchx/schedulers/ray_scheduler.py @@ -29,7 +29,7 @@ from torchx.schedulers.ids import make_unique from torchx.schedulers.ray.ray_common import RayActor, TORCHX_RANK0_HOST from torchx.specs import AppDef, macros, NONE, ReplicaStatus, Role, RoleStatus, runopts -from torchx.workspace.dir_workspace import TmpDirWorkspace +from torchx.workspace.dir_workspace import TmpDirWorkspaceMixin from typing_extensions import TypedDict @@ -113,7 +113,7 @@ class RayJob: requirements: Optional[str] = None actors: List[RayActor] = field(default_factory=list) - class RayScheduler(Scheduler[RayOpts], TmpDirWorkspace): + class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]): """ RayScheduler is a TorchX scheduling interface to Ray. 
The job def workers will be launched as Ray actors @@ -153,7 +153,7 @@ def __init__(self, session_name: str) -> None: super().__init__("ray", session_name) # TODO: Add address as a potential CLI argument after writing ray.status() or passing in config file - def run_opts(self) -> runopts: + def _run_opts(self) -> runopts: opts = runopts() opts.add( "cluster_config_file", diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py index c937a2cd4..0bbd1c5d0 100644 --- a/torchx/schedulers/slurm_scheduler.py +++ b/torchx/schedulers/slurm_scheduler.py @@ -41,7 +41,7 @@ RoleStatus, runopts, ) -from torchx.workspace.dir_workspace import DirWorkspace +from torchx.workspace.dir_workspace import DirWorkspaceMixin from typing_extensions import TypedDict SLURM_JOB_DIRS = ".torchxslurmjobdirs" @@ -257,7 +257,7 @@ def __repr__(self) -> str: {self.materialize()}""" -class SlurmScheduler(Scheduler[SlurmOpts], DirWorkspace): +class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]): """ SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects that slurm CLI tools are locally installed and job accounting is enabled. @@ -311,7 +311,7 @@ class SlurmScheduler(Scheduler[SlurmOpts], DirWorkspace): Partial support. SlurmScheduler will return job and replica status but does not provide the complete original AppSpec. workspaces: | - If ``job_dir`` is specified the DirWorkspace will create a new + If ``job_dir`` is specified the DirWorkspaceMixin will create a new isolated directory with a snapshot of the workspace. 
mounts: false elasticity: false @@ -323,7 +323,7 @@ class SlurmScheduler(Scheduler[SlurmOpts], DirWorkspace): def __init__(self, session_name: str) -> None: super().__init__("slurm", session_name) - def run_opts(self) -> runopts: + def _run_opts(self) -> runopts: opts = runopts() opts.add( "partition", diff --git a/torchx/schedulers/test/api_test.py b/torchx/schedulers/test/api_test.py index ee3a980f5..9d36e5946 100644 --- a/torchx/schedulers/test/api_test.py +++ b/torchx/schedulers/test/api_test.py @@ -29,13 +29,13 @@ Role, runopts, ) -from torchx.workspace.api import Workspace +from torchx.workspace.api import WorkspaceMixin T = TypeVar("T") class SchedulerTest(unittest.TestCase): - class MockScheduler(Scheduler[T], Workspace): + class MockScheduler(Scheduler[T], WorkspaceMixin[None]): def __init__(self, session_name: str) -> None: super().__init__("mock", session_name) @@ -73,7 +73,7 @@ def log_iter( def list(self) -> List[ListAppResponse]: return [] - def run_opts(self) -> runopts: + def _run_opts(self) -> runopts: opts = runopts() opts.add("foo", type_=str, required=True, help="required option") return opts diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 34d0a6b81..0a482ca15 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -892,6 +892,9 @@ def add( self._opts[cfg_key] = runopt(default, type_, required, help) + def update(self, other: "runopts") -> None: + self._opts.update(other._opts) + def __repr__(self) -> str: required = [(key, opt) for key, opt in self._opts.items() if opt.is_required] optional = [ diff --git a/torchx/workspace/__init__.py b/torchx/workspace/__init__.py index d4d47d7ee..3f75bf1f9 100644 --- a/torchx/workspace/__init__.py +++ b/torchx/workspace/__init__.py @@ -20,4 +20,4 @@ * ``memory://foo-bar/`` an in-memory workspace for notebook/programmatic usage """ -from torchx.workspace.api import walk_workspace, Workspace # noqa: F401 +from torchx.workspace.api import walk_workspace, WorkspaceMixin # noqa: F401 diff --git 
a/torchx/workspace/api.py b/torchx/workspace/api.py index e965fe19a..239b341ac 100644 --- a/torchx/workspace/api.py +++ b/torchx/workspace/api.py @@ -7,17 +7,19 @@ import abc import fnmatch import posixpath -from typing import Iterable, Mapping, Tuple, TYPE_CHECKING +from typing import Generic, Iterable, Mapping, Tuple, TYPE_CHECKING, TypeVar -from torchx.specs import CfgVal, Role +from torchx.specs import AppDef, CfgVal, Role, runopts if TYPE_CHECKING: from fsspec import AbstractFileSystem TORCHX_IGNORE = ".torchxignore" +T = TypeVar("T") -class Workspace(abc.ABC): + +class WorkspaceMixin(abc.ABC, Generic[T]): """ Note: (Prototype) this interface may change without notice! @@ -32,6 +34,16 @@ class Workspace(abc.ABC): workspace build artifact is, is implementation dependent. """ + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + + def workspace_opts(self) -> runopts: + """ + Returns the run configuration options expected by the workspace. + Basically a ``--help`` for the ``run`` API. + """ + return runopts() + @abc.abstractmethod def build_workspace_and_update_role( self, role: Role, workspace: str, cfg: Mapping[str, CfgVal] @@ -47,6 +59,21 @@ def build_workspace_and_update_role( """ ... + def dryrun_push_images(self, app: AppDef, cfg: Mapping[str, CfgVal]) -> T: + """ + dryrun_push_images does a dryrun of the image push and updates the app to have + the final values. Only called for remote jobs. + + ``push_images`` must be called before scheduling the job. + """ + raise NotImplementedError("dryrun_push is not implemented") + + def push_images(self, images_to_push: T) -> None: + """ + push_images pushes any images to the remote repo if required. 
+ """ + raise NotImplementedError("push is not implemented") + def _ignore(s: str, patterns: Iterable[str]) -> Tuple[int, bool]: last_matching_pattern = -1 diff --git a/torchx/workspace/dir_workspace.py b/torchx/workspace/dir_workspace.py index 6b29d2e68..e19afc9e6 100644 --- a/torchx/workspace/dir_workspace.py +++ b/torchx/workspace/dir_workspace.py @@ -13,10 +13,10 @@ import fsspec from torchx.specs import CfgVal, Role -from torchx.workspace.api import walk_workspace, Workspace +from torchx.workspace.api import walk_workspace, WorkspaceMixin -class TmpDirWorkspace(Workspace): +class TmpDirWorkspaceMixin(WorkspaceMixin[None]): def build_workspace_and_update_role( self, role: Role, workspace: str, cfg: Mapping[str, CfgVal] ) -> None: @@ -31,7 +31,7 @@ def build_workspace_and_update_role( role.image = job_dir -class DirWorkspace(Workspace): +class DirWorkspaceMixin(WorkspaceMixin[None]): def build_workspace_and_update_role( self, role: Role, workspace: str, cfg: Mapping[str, CfgVal] ) -> None: diff --git a/torchx/workspace/docker_workspace.py b/torchx/workspace/docker_workspace.py index 0cd893b83..222b72088 100644 --- a/torchx/workspace/docker_workspace.py +++ b/torchx/workspace/docker_workspace.py @@ -14,8 +14,8 @@ import fsspec import torchx -from torchx.specs import AppDef, CfgVal, Role -from torchx.workspace.api import walk_workspace, Workspace +from torchx.specs import AppDef, CfgVal, Role, runopts +from torchx.workspace.api import walk_workspace, WorkspaceMixin if TYPE_CHECKING: from docker import DockerClient @@ -26,9 +26,9 @@ TORCHX_DOCKERFILE = "Dockerfile.torchx" -class DockerWorkspace(Workspace): +class DockerWorkspaceMixin(WorkspaceMixin[Dict[str, Tuple[str, str]]]): """ - DockerWorkspace will build patched docker images from the workspace. These + DockerWorkspaceMixin will build patched docker images from the workspace. 
These patched images are docker images and can be either used locally via the docker daemon or pushed using the helper methods to a remote repository for remote jobs. @@ -50,7 +50,13 @@ class DockerWorkspace(Workspace): LABEL_VERSION: str = "torchx.pytorch.org/version" - def __init__(self, docker_client: Optional["DockerClient"] = None) -> None: + def __init__( + self, + *args: object, + docker_client: Optional["DockerClient"] = None, + **kwargs: object, + ) -> None: + super().__init__(*args, **kwargs) self.__docker_client = docker_client @property @@ -63,6 +69,15 @@ def _docker_client(self) -> "DockerClient": self.__docker_client = client return client + def workspace_opts(self) -> runopts: + opts = runopts() + opts.add( + "image_repo", + type_=str, + help="(remote jobs) the image repository to use when pushing patched images, must have push access. Ex: example.com/your/container", + ) + return opts + def build_workspace_and_update_role( self, role: Role, workspace: str, cfg: Mapping[str, CfgVal] ) -> None: @@ -98,21 +113,22 @@ def build_workspace_and_update_role( finally: context.close() - def _update_app_images( - self, app: AppDef, image_repo: Optional[str] = None + def dryrun_push_images( + self, app: AppDef, cfg: Mapping[str, CfgVal] ) -> Dict[str, Tuple[str, str]]: """ _update_app_images replaces the local Docker images (identified via ``sha256:...``) in the provided ``AppDef`` with the remote path that they will be uploaded to and returns a mapping of local to remote names. - ``_push_images`` must be called with the returned mapping before + ``push_images`` must be called with the returned mapping before launching the job. Returns: A dict of [local image name, (remote repo, tag)]. 
""" HASH_PREFIX = "sha256:" + image_repo = cfg.get("image_repo") images_to_push = {} for role in app.roles: @@ -132,7 +148,7 @@ def _update_app_images( role.image = remote_image return images_to_push - def _push_images(self, images_to_push: Dict[str, Tuple[str, str]]) -> None: + def push_images(self, images_to_push: Dict[str, Tuple[str, str]]) -> None: """ _push_images pushes the specified images to the remote container repository with the specified tag. The docker daemon must be diff --git a/torchx/workspace/test/dir_workspace_test.py b/torchx/workspace/test/dir_workspace_test.py index e0b4f5b4a..b2160b77a 100644 --- a/torchx/workspace/test/dir_workspace_test.py +++ b/torchx/workspace/test/dir_workspace_test.py @@ -12,12 +12,16 @@ import fsspec from torchx.specs import Role -from torchx.workspace.dir_workspace import _copy_to_dir, DirWorkspace, TmpDirWorkspace +from torchx.workspace.dir_workspace import ( + _copy_to_dir, + DirWorkspaceMixin, + TmpDirWorkspaceMixin, +) class DirWorkspaceTest(unittest.TestCase): def test_build_workspace_no_job_dir(self) -> None: - w = DirWorkspace() + w = DirWorkspaceMixin() role = Role( name="role", image="blah", @@ -27,7 +31,7 @@ def test_build_workspace_no_job_dir(self) -> None: self.assertEqual(role.image, "blah") def test_build_workspace(self) -> None: - w = DirWorkspace() + w = DirWorkspaceMixin() role = Role( name="role", image="blah", @@ -108,7 +112,7 @@ def test_torchxignore(self) -> None: class TmpDirWorkspaceTest(unittest.TestCase): def test_build_workspace(self) -> None: - w = TmpDirWorkspace() + w = TmpDirWorkspaceMixin() role = Role( name="role", image="blah", diff --git a/torchx/workspace/test/docker_workspace_test.py b/torchx/workspace/test/docker_workspace_test.py index fc83f074a..87b7cad13 100644 --- a/torchx/workspace/test/docker_workspace_test.py +++ b/torchx/workspace/test/docker_workspace_test.py @@ -14,7 +14,7 @@ import fsspec from torchx.specs import AppDef, Role -from torchx.workspace.docker_workspace import 
_build_context, DockerWorkspace +from torchx.workspace.docker_workspace import _build_context, DockerWorkspaceMixin def has_docker() -> bool: @@ -43,7 +43,7 @@ def test_docker_workspace(self) -> None: args=["bar/foo.sh"], ) - workspace = DockerWorkspace() + workspace = DockerWorkspaceMixin() workspace.build_workspace_and_update_role( role, "memory://test_workspace", {} ) @@ -52,6 +52,12 @@ def test_docker_workspace(self) -> None: class DockerWorkspaceMockTest(unittest.TestCase): + def test_runopts(self) -> None: + self.assertCountEqual( + DockerWorkspaceMixin().workspace_opts()._opts.keys(), + {"image_repo"}, + ) + def test_update_app_images(self) -> None: app = AppDef( name="foo", @@ -86,9 +92,11 @@ def test_update_app_images(self) -> None: ) # no image_repo with self.assertRaisesRegex(KeyError, "image_repo"): - DockerWorkspace()._update_app_images(app) + DockerWorkspaceMixin().dryrun_push_images(app, {}) # with image_repo - images_to_push = DockerWorkspace()._update_app_images(app, "example.com/repo") + images_to_push = DockerWorkspaceMixin().dryrun_push_images( + app, {"image_repo": "example.com/repo"} + ) self.assertEqual( images_to_push, { @@ -102,8 +110,8 @@ def test_push_images(self) -> None: client = MagicMock() img = MagicMock() client.images.get.return_value = img - workspace = DockerWorkspace(docker_client=client) - workspace._push_images( + workspace = DockerWorkspaceMixin(docker_client=client) + workspace.push_images( { "sha256:hasha": ("example.com/repo", "hasha"), "sha256:hashb": ("example.com/repo", "hashb"), @@ -114,8 +122,8 @@ def test_push_images(self) -> None: self.assertEqual(client.images.push.call_count, 2) def test_push_images_empty(self) -> None: - workspace = DockerWorkspace() - workspace._push_images({}) + workspace = DockerWorkspaceMixin() + workspace.push_images({}) def test_dockerignore(self) -> None: fs = fsspec.filesystem("memory")