Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions torchx/schedulers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Generic, Iterable, List, Optional, TypeVar, Union
from typing import Generic, Iterable, List, Optional, TypeVar

from torchx.specs import (
AppDef,
AppDryRunInfo,
AppState,
NONE,
NULL_RESOURCE,
Expand Down Expand Up @@ -95,11 +96,9 @@ def __hash__(self) -> int:


T = TypeVar("T")
A = TypeVar("A")
D = TypeVar("D")


class Scheduler(abc.ABC, Generic[T, A, D]):
class Scheduler(abc.ABC, Generic[T]):
"""
An interface abstracting functionalities of a scheduler.
Implementers need only implement those methods annotated with
Expand Down Expand Up @@ -129,7 +128,7 @@ def close(self) -> None:

def submit(
self,
app: A,
app: AppDef,
cfg: T,
workspace: str | Workspace | None = None,
) -> str:
Expand Down Expand Up @@ -157,7 +156,7 @@ def submit(
return self.schedule(dryrun_info)

@abc.abstractmethod
def schedule(self, dryrun_info: D) -> str:
def schedule(self, dryrun_info: AppDryRunInfo) -> str:
"""
Same as ``submit`` except that it takes an ``AppDryRunInfo``.
Implementers are encouraged to implement this method rather than
Expand All @@ -173,7 +172,7 @@ def schedule(self, dryrun_info: D) -> str:

raise NotImplementedError()

def submit_dryrun(self, app: A, cfg: T) -> D:
def submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
"""
Rather than submitting the request to run the app, returns the
request object that would have been submitted to the underlying
Expand All @@ -187,15 +186,15 @@ def submit_dryrun(self, app: A, cfg: T) -> D:
# pyre-fixme: _submit_dryrun takes Generic type for resolved_cfg
dryrun_info = self._submit_dryrun(app, resolved_cfg)

if isinstance(app, AppDef):
for role in app.roles:
dryrun_info = role.pre_proc(self.backend, dryrun_info)
for role in app.roles:
dryrun_info = role.pre_proc(self.backend, dryrun_info)

dryrun_info._app = app
dryrun_info._cfg = resolved_cfg
return dryrun_info

@abc.abstractmethod
def _submit_dryrun(self, app: A, cfg: T) -> D:
def _submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
raise NotImplementedError()

def run_opts(self) -> runopts:
Expand Down Expand Up @@ -394,15 +393,12 @@ def _pre_build_validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
"""
pass

def _validate(self, app: A, scheduler: str, cfg: T) -> None:
def _validate(self, app: AppDef, scheduler: str, cfg: T) -> None:
"""
Validates after workspace build whether application is consistent with the scheduler.

Raises error if application is not compatible with scheduler
"""
if not isinstance(app, AppDef):
return

for role in app.roles:
if role.resource == NULL_RESOURCE:
raise ValueError(
Expand Down
6 changes: 2 additions & 4 deletions torchx/schedulers/aws_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def wrapper() -> T:


@_thread_local_cache
def _local_session() -> "boto3.session.Session":
def _local_session() -> "boto3.session.Session": # noqa: F821
import boto3.session

return boto3.session.Session()
Expand All @@ -399,9 +399,7 @@ class AWSBatchOpts(TypedDict, total=False):
ulimits: Optional[list[str]]


class AWSBatchScheduler(
DockerWorkspaceMixin, Scheduler[AWSBatchOpts, AppDef, AppDryRunInfo[BatchJob]]
):
class AWSBatchScheduler(DockerWorkspaceMixin, Scheduler[AWSBatchOpts]):
"""
AWSBatchScheduler is a TorchX scheduling interface to AWS Batch.

Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/aws_sagemaker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _merge_ordered(

class AWSSageMakerScheduler(
DockerWorkspaceMixin,
Scheduler[AWSSageMakerOpts, AppDef, AppDryRunInfo[AWSSageMakerJob]],
Scheduler[AWSSageMakerOpts],
):
"""
AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker.
Expand Down
4 changes: 1 addition & 3 deletions torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,7 @@ class DockerOpts(TypedDict, total=False):
privileged: bool


class DockerScheduler(
DockerWorkspaceMixin, Scheduler[DockerOpts, AppDef, AppDryRunInfo[DockerJob]]
):
class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
"""
DockerScheduler is a TorchX scheduling interface to Docker.

Expand Down
5 changes: 1 addition & 4 deletions torchx/schedulers/kubernetes_mcad_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,10 +796,7 @@ class KubernetesMCADOpts(TypedDict, total=False):
network: Optional[str]


class KubernetesMCADScheduler(
DockerWorkspaceMixin,
Scheduler[KubernetesMCADOpts, AppDef, AppDryRunInfo[KubernetesMCADJob]],
):
class KubernetesMCADScheduler(DockerWorkspaceMixin, Scheduler[KubernetesMCADOpts]):
"""
KubernetesMCADScheduler is a TorchX scheduling interface to Kubernetes.

Expand Down
5 changes: 1 addition & 4 deletions torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,10 +591,7 @@ class KubernetesOpts(TypedDict, total=False):
validate_spec: Optional[bool]


class KubernetesScheduler(
DockerWorkspaceMixin,
Scheduler[KubernetesOpts, AppDef, AppDryRunInfo[KubernetesJob]],
):
class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]):
"""
KubernetesScheduler is a TorchX scheduling interface to Kubernetes.

Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def _register_termination_signals() -> None:
signal.signal(signal.SIGINT, _terminate_process_handler)


class LocalScheduler(Scheduler[LocalOpts, AppDef, AppDryRunInfo[PopenRequest]]):
class LocalScheduler(Scheduler[LocalOpts]):
"""
Schedules on localhost. Containers are modeled as processes and
certain properties of the container that are either not relevant
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/lsf_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def __repr__(self) -> str:
{self.materialize()}"""


class LsfScheduler(Scheduler[LsfOpts, AppDef, AppDryRunInfo]):
class LsfScheduler(Scheduler[LsfOpts]):
"""
**Example: hello_world**

Expand Down
4 changes: 1 addition & 3 deletions torchx/schedulers/slurm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,7 @@ def __repr__(self) -> str:
{self.materialize()}"""


class SlurmScheduler(
DirWorkspaceMixin, Scheduler[SlurmOpts, AppDef, AppDryRunInfo[SlurmBatchRequest]]
):
class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
"""
SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects
that slurm CLI tools are locally installed and job accounting is enabled.
Expand Down
4 changes: 1 addition & 3 deletions torchx/schedulers/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@
from torchx.workspace.api import WorkspaceMixin

T = TypeVar("T")
A = TypeVar("A")
D = TypeVar("D")


class SchedulerTest(unittest.TestCase):
class MockScheduler(Scheduler[T, A, D], WorkspaceMixin[None]):
class MockScheduler(Scheduler[T], WorkspaceMixin[None]):
def __init__(self, session_name: str) -> None:
super().__init__("mock", session_name)

Expand Down