Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion torchx/schedulers/aws_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def __init__(
log_client: Optional[Any] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("aws_batch", session_name, docker_client=docker_client)

# pyre-fixme[4]: Attribute annotation cannot be `Any`.
Expand Down Expand Up @@ -796,7 +797,18 @@ def _stream_events(
yield event["message"] + "\n"


def create_scheduler(session_name: str, **kwargs: object) -> AWSBatchScheduler:
def create_scheduler(
session_name: str,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
client: Optional[Any] = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
log_client: Optional[Any] = None,
docker_client: Optional["DockerClient"] = None,
**kwargs: object,
) -> AWSBatchScheduler:
return AWSBatchScheduler(
session_name=session_name,
client=client,
log_client=log_client,
docker_client=docker_client,
)
1 change: 1 addition & 0 deletions torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
"""

def __init__(self, session_name: str) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("docker", session_name)

def schedule(self, dryrun_info: AppDryRunInfo[DockerJob]) -> str:
Expand Down
9 changes: 8 additions & 1 deletion torchx/schedulers/gcp_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
client: Optional[Any] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
Scheduler.__init__(self, "gcp_batch", session_name)
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
self.__client = client
Expand Down Expand Up @@ -474,7 +475,13 @@ def _cancel_existing(self, app_id: str) -> None:
self._client.delete_job(request=request)


def create_scheduler(session_name: str, **kwargs: object) -> GCPBatchScheduler:
def create_scheduler(
session_name: str,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
client: Optional[Any] = None,
**kwargs: object,
) -> GCPBatchScheduler:
return GCPBatchScheduler(
session_name=session_name,
client=client,
)
10 changes: 9 additions & 1 deletion torchx/schedulers/kubernetes_mcad_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,7 @@ def __init__(
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("kubernetes_mcad", session_name, docker_client=docker_client)

self._client = client
Expand Down Expand Up @@ -1230,9 +1231,16 @@ def list(self) -> List[ListAppResponse]:
]


def create_scheduler(session_name: str, **kwargs: Any) -> KubernetesMCADScheduler:
def create_scheduler(
session_name: str,
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
**kwargs: Any,
) -> KubernetesMCADScheduler:
return KubernetesMCADScheduler(
session_name=session_name,
client=client,
docker_client=docker_client,
)


Expand Down
10 changes: 9 additions & 1 deletion torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def __init__(
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("kubernetes", session_name, docker_client=docker_client)

self._client = client
Expand Down Expand Up @@ -777,9 +778,16 @@ def list(self) -> List[ListAppResponse]:
]


def create_scheduler(session_name: str, **kwargs: Any) -> KubernetesScheduler:
def create_scheduler(
session_name: str,
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
**kwargs: Any,
) -> KubernetesScheduler:
return KubernetesScheduler(
session_name=session_name,
client=client,
docker_client=docker_client,
)


Expand Down
11 changes: 9 additions & 2 deletions torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def __init__(
cache_size: int = 100,
extra_paths: Optional[List[str]] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("local", session_name)

# TODO T72035686 replace dict with a proper LRUCache data structure
Expand Down Expand Up @@ -1124,9 +1125,15 @@ def __next__(self) -> str:
return line


def create_scheduler(session_name: str, **kwargs: Any) -> LocalScheduler:
def create_scheduler(
session_name: str,
cache_size: int = 100,
extra_paths: Optional[List[str]] = None,
**kwargs: Any,
) -> LocalScheduler:
return LocalScheduler(
session_name=session_name,
cache_size=kwargs.get("cache_size", 100),
image_provider_class=CWDImageProvider,
cache_size=cache_size,
extra_paths=extra_paths,
)
1 change: 1 addition & 0 deletions torchx/schedulers/lsf_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ class LsfScheduler(Scheduler[LsfOpts]):
"""

def __init__(self, session_name: str) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("lsf", session_name)

def _run_opts(self) -> runopts:
Expand Down
7 changes: 5 additions & 2 deletions torchx/schedulers/ray_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]):
def __init__(
self, session_name: str, ray_client: Optional[JobSubmissionClient] = None
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("ray", session_name)

# w/o Final None check in _get_ray_client does not work as it pyre assumes mutability
Expand Down Expand Up @@ -441,10 +442,12 @@ def list(self) -> List[ListAppResponse]:
]


def create_scheduler(session_name: str, **kwargs: Any) -> "RayScheduler":
def create_scheduler(
session_name: str, ray_client: Optional[JobSubmissionClient] = None, **kwargs: Any
) -> "RayScheduler":
if not has_ray(): # pragma: no cover
raise ModuleNotFoundError(
"Ray is not installed in the current Python environment."
)

return RayScheduler(session_name=session_name)
return RayScheduler(session_name=session_name, ray_client=ray_client)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be nice to add some unit tests for this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add some unit tests this week 👍🏻

1 change: 1 addition & 0 deletions torchx/schedulers/slurm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
"""

def __init__(self, session_name: str) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("slurm", session_name)

def _run_opts(self) -> runopts:
Expand Down
10 changes: 9 additions & 1 deletion torchx/schedulers/test/aws_batch_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,16 @@ def paginate(self, *_1: Any, **_2: Any) -> Iterable[Dict[str, Any]]:

class AWSBatchSchedulerTest(unittest.TestCase):
def test_create_scheduler(self) -> None:
scheduler = create_scheduler("foo")
client = MagicMock()
log_client = MagicMock()
docker_client = MagicMock()
scheduler = create_scheduler(
"foo", client=client, log_client=log_client, docker_client=docker_client
)
self.assertIsInstance(scheduler, AWSBatchScheduler)
self.assertEqual(scheduler._client, client)
self.assertEqual(scheduler._log_client, log_client)
self.assertEqual(scheduler._docker_client, docker_client)

def test_submit_dryrun_with_share_id(self) -> None:
app = _test_app()
Expand Down
4 changes: 3 additions & 1 deletion torchx/schedulers/test/gcp_batch_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def mock_rand() -> Generator[None, None, None]:

class GCPBatchSchedulerTest(unittest.TestCase):
def test_create_scheduler(self) -> None:
scheduler = create_scheduler("foo")
client = MagicMock()
scheduler = create_scheduler("foo", client=client)
self.assertIsInstance(scheduler, GCPBatchScheduler)
self.assertEqual(scheduler._client, client)

@mock_rand()
def test_submit_dryrun(self) -> None:
Expand Down
6 changes: 5 additions & 1 deletion torchx/schedulers/test/kubernetes_mcad_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,14 @@ def _test_app(num_replicas: int = 1) -> specs.AppDef:

class KubernetesMCADSchedulerTest(unittest.TestCase):
def test_create_scheduler(self) -> None:
scheduler = create_scheduler("foo")
client = MagicMock()
docker_client = MagicMock()
scheduler = create_scheduler("foo", client=client, docker_client=docker_client)
self.assertIsInstance(
scheduler, kubernetes_mcad_scheduler.KubernetesMCADScheduler
)
self.assertEquals(client, scheduler._client)
self.assertEquals(docker_client, scheduler._docker_client)

def test_app_to_resource_resolved_macros(self) -> None:
app = _test_app()
Expand Down
6 changes: 5 additions & 1 deletion torchx/schedulers/test/kubernetes_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ def _test_app(num_replicas: int = 1) -> specs.AppDef:

class KubernetesSchedulerTest(unittest.TestCase):
def test_create_scheduler(self) -> None:
scheduler = create_scheduler("foo")
client = MagicMock()
docker_client = MagicMock
scheduler = create_scheduler("foo", client=client, docker_client=docker_client)
self.assertIsInstance(scheduler, kubernetes_scheduler.KubernetesScheduler)
self.assertEquals(scheduler._docker_client, docker_client)
self.assertEquals(scheduler._client, client)

def test_app_to_resource_resolved_macros(self) -> None:
app = _test_app()
Expand Down
4 changes: 3 additions & 1 deletion torchx/schedulers/test/local_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,11 @@ def test_get_entrypoint(self) -> None:
self.assertEqual(self.provider.get_entrypoint("asdf", role), "entrypoint.sh")

def test_create_scheduler(self) -> None:
sched = create_scheduler("foo")
sched = create_scheduler("foo", cache_size=20, extra_paths=["foo"])
self.assertEqual(sched.session_name, "foo")
self.assertEqual(sched._image_provider_class, CWDImageProvider)
self.assertEqual(sched._cache_size, 20)
self.assertEqual(len(sched._extra_paths), 1)


LOCAL_SCHEDULER_MAKE_UNIQUE = "torchx.schedulers.local_scheduler.make_unique"
Expand Down