diff --git a/torchx/schedulers/api.py b/torchx/schedulers/api.py index 0a3763496..c2bf2f54a 100644 --- a/torchx/schedulers/api.py +++ b/torchx/schedulers/api.py @@ -11,10 +11,11 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Generic, Iterable, List, Optional, TypeVar, Union +from typing import Generic, Iterable, List, Optional, TypeVar from torchx.specs import ( AppDef, + AppDryRunInfo, AppState, NONE, NULL_RESOURCE, @@ -95,11 +96,9 @@ def __hash__(self) -> int: T = TypeVar("T") -A = TypeVar("A") -D = TypeVar("D") -class Scheduler(abc.ABC, Generic[T, A, D]): +class Scheduler(abc.ABC, Generic[T]): """ An interface abstracting functionalities of a scheduler. Implementers need only implement those methods annotated with @@ -129,7 +128,7 @@ def close(self) -> None: def submit( self, - app: A, + app: AppDef, cfg: T, workspace: str | Workspace | None = None, ) -> str: @@ -157,7 +156,7 @@ def submit( return self.schedule(dryrun_info) @abc.abstractmethod - def schedule(self, dryrun_info: D) -> str: + def schedule(self, dryrun_info: AppDryRunInfo) -> str: """ Same as ``submit`` except that it takes an ``AppDryRunInfo``. Implementers are encouraged to implement this method rather than @@ -173,7 +172,7 @@ def schedule(self, dryrun_info: D) -> str: raise NotImplementedError() - def submit_dryrun(self, app: A, cfg: T) -> D: + def submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo: """ Rather than submitting the request to run the app, returns the request object that would have been submitted to the underlying @@ -187,15 +186,15 @@ def submit_dryrun(self, app: A, cfg: T) -> D: # pyre-fixme: _submit_dryrun takes Generic type for resolved_cfg dryrun_info = self._submit_dryrun(app, resolved_cfg) - if isinstance(app, AppDef): - for role in app.roles: - dryrun_info = role.pre_proc(self.backend, dryrun_info) + for role in app.roles: + dryrun_info = role.pre_proc(self.backend, dryrun_info) + dryrun_info._app = app dryrun_info._cfg = resolved_cfg return dryrun_info @abc.abstractmethod - def _submit_dryrun(self, app: A, cfg: T) -> D: + def _submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo: raise NotImplementedError() def run_opts(self) -> runopts: @@ -394,15 +393,12 @@ def _pre_build_validate(self, app: AppDef, scheduler: str, cfg: T) -> None: """ pass - def _validate(self, app: A, scheduler: str, cfg: T) -> None: + def _validate(self, app: AppDef, scheduler: str, cfg: T) -> None: """ Validates after workspace build whether application is consistent with the scheduler. Raises error if application is not compatible with scheduler """ - if not isinstance(app, AppDef): - return - for role in app.roles: if role.resource == NULL_RESOURCE: raise ValueError( diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py index 420d4b981..baaad159a 100644 --- a/torchx/schedulers/aws_batch_scheduler.py +++ b/torchx/schedulers/aws_batch_scheduler.py @@ -381,7 +381,7 @@ def wrapper() -> T: @_thread_local_cache -def _local_session() -> "boto3.session.Session": +def _local_session() -> "boto3.session.Session": # noqa: F821 import boto3.session return boto3.session.Session() @@ -399,9 +399,7 @@ class AWSBatchOpts(TypedDict, total=False): ulimits: Optional[list[str]] -class AWSBatchScheduler( - DockerWorkspaceMixin, Scheduler[AWSBatchOpts, AppDef, AppDryRunInfo[BatchJob]] -): +class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]): """ AWSBatchScheduler is a TorchX scheduling interface to AWS Batch. diff --git a/torchx/schedulers/aws_sagemaker_scheduler.py b/torchx/schedulers/aws_sagemaker_scheduler.py index 083ea0f7c..5511bc152 100644 --- a/torchx/schedulers/aws_sagemaker_scheduler.py +++ b/torchx/schedulers/aws_sagemaker_scheduler.py @@ -157,7 +157,7 @@ def _merge_ordered( class AWSSageMakerScheduler( DockerWorkspaceMixin, - Scheduler[AWSSageMakerOpts, AppDef, AppDryRunInfo[AWSSageMakerJob]], + Scheduler[AWSSageMakerOpts], ): """ AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker. diff --git a/torchx/schedulers/docker_scheduler.py b/torchx/schedulers/docker_scheduler.py index c35a6a034..fa3bfccbb 100644 --- a/torchx/schedulers/docker_scheduler.py +++ b/torchx/schedulers/docker_scheduler.py @@ -129,9 +129,7 @@ class DockerOpts(TypedDict, total=False): privileged: bool -class DockerScheduler( - DockerWorkspaceMixin, Scheduler[DockerOpts, AppDef, AppDryRunInfo[DockerJob]] -): +class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]): """ DockerScheduler is a TorchX scheduling interface to Docker. diff --git a/torchx/schedulers/kubernetes_mcad_scheduler.py b/torchx/schedulers/kubernetes_mcad_scheduler.py index e0ee17eb1..36922e7e1 100644 --- a/torchx/schedulers/kubernetes_mcad_scheduler.py +++ b/torchx/schedulers/kubernetes_mcad_scheduler.py @@ -796,10 +796,7 @@ class KubernetesMCADOpts(TypedDict, total=False): network: Optional[str] -class KubernetesMCADScheduler( - DockerWorkspaceMixin, - Scheduler[KubernetesMCADOpts, AppDef, AppDryRunInfo[KubernetesMCADJob]], -): +class KubernetesMCADScheduler(DockerWorkspaceMixin, Scheduler[KubernetesMCADOpts]): """ KubernetesMCADScheduler is a TorchX scheduling interface to Kubernetes. diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py index d2775923c..2fb77f3bf 100644 --- a/torchx/schedulers/kubernetes_scheduler.py +++ b/torchx/schedulers/kubernetes_scheduler.py @@ -591,10 +591,7 @@ class KubernetesOpts(TypedDict, total=False): validate_spec: Optional[bool] -class KubernetesScheduler( - DockerWorkspaceMixin, - Scheduler[KubernetesOpts, AppDef, AppDryRunInfo[KubernetesJob]], -): +class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]): """ KubernetesScheduler is a TorchX scheduling interface to Kubernetes. diff --git a/torchx/schedulers/local_scheduler.py b/torchx/schedulers/local_scheduler.py index c039ebf54..eed800736 100644 --- a/torchx/schedulers/local_scheduler.py +++ b/torchx/schedulers/local_scheduler.py @@ -529,7 +529,7 @@ def _register_termination_signals() -> None: signal.signal(signal.SIGINT, _terminate_process_handler) -class LocalScheduler(Scheduler[LocalOpts, AppDef, AppDryRunInfo[PopenRequest]]): +class LocalScheduler(Scheduler[LocalOpts]): """ Schedules on localhost. Containers are modeled as processes and certain properties of the container that are either not relevant diff --git a/torchx/schedulers/lsf_scheduler.py b/torchx/schedulers/lsf_scheduler.py index b2700a316..b260745b7 100644 --- a/torchx/schedulers/lsf_scheduler.py +++ b/torchx/schedulers/lsf_scheduler.py @@ -394,7 +394,7 @@ def __repr__(self) -> str: {self.materialize()}""" -class LsfScheduler(Scheduler[LsfOpts, AppDef, AppDryRunInfo]): +class LsfScheduler(Scheduler[LsfOpts]): """ **Example: hello_world** diff --git a/torchx/schedulers/slurm_scheduler.py b/torchx/schedulers/slurm_scheduler.py index 7fd926d99..0b3232bda 100644 --- a/torchx/schedulers/slurm_scheduler.py +++ b/torchx/schedulers/slurm_scheduler.py @@ -335,9 +335,7 @@ def __repr__(self) -> str: {self.materialize()}""" -class SlurmScheduler( - DirWorkspaceMixin, Scheduler[SlurmOpts, AppDef, AppDryRunInfo[SlurmBatchRequest]] -): +class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]): """ SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects that slurm CLI tools are locally installed and job accounting is enabled. diff --git a/torchx/schedulers/test/api_test.py b/torchx/schedulers/test/api_test.py index 1f65dd6b5..f2f3c9cbb 100644 --- a/torchx/schedulers/test/api_test.py +++ b/torchx/schedulers/test/api_test.py @@ -35,12 +35,10 @@ from torchx.workspace.api import WorkspaceMixin T = TypeVar("T") -A = TypeVar("A") -D = TypeVar("D") class SchedulerTest(unittest.TestCase): - class MockScheduler(Scheduler[T, A, D], WorkspaceMixin[None]): + class MockScheduler(Scheduler[T], WorkspaceMixin[None]): def __init__(self, session_name: str) -> None: super().__init__("mock", session_name)