Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/apps/datapreproc/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def data_preproc(
],
env=env,
resource=resource,
).replicas(1)
num_replicas=1,
)

return specs.AppDef(name).of(ddp_role)
return specs.AppDef(name, roles=[ddp_role])
2 changes: 1 addition & 1 deletion torchx/cli/test/cmd_describe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_test_app(self) -> AppDef:
num_replicas=2,
nnodes="2:3",
)
return AppDef("my_train_job").of(trainer)
return AppDef("my_train_job", roles=[trainer])

def test_run(self) -> None:
parser = argparse.ArgumentParser()
Expand Down
11 changes: 7 additions & 4 deletions torchx/cli/test/cmd_log_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from unittest.mock import MagicMock, patch

from torchx.cli.cmd_log import ENDC, GREEN, get_logs
from torchx.specs.api import AppDef, Role, parse_app_handle
from torchx.specs import AppDef, Role, parse_app_handle


class SentinelError(Exception):
Expand All @@ -31,9 +31,12 @@ def __call__(self, name: Optional[str] = None) -> "MockRunner":

def describe(self, app_handle: str) -> AppDef:
scheduler_backend, session_name, app_id = parse_app_handle(app_handle)
return AppDef(name=app_id).of(
Role(name="master", image="test_image").replicas(1),
Role(name="trainer", image="test_image").replicas(3),
return AppDef(
name=app_id,
roles=[
Role(name="master", image="test_image", num_replicas=1),
Role(name="trainer", image="test_image", num_replicas=3),
],
)

def log_lines(
Expand Down
23 changes: 12 additions & 11 deletions torchx/components/base/roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,16 @@ def create_torch_dist_role(
entrypoint = os.path.join(macros.img_root, entrypoint)

args = [*torch_run_args, entrypoint, *args]
return (
Role(
name,
image=image,
base_image=base_image,
resource=resource,
port_map=port_map,
)
.runs(entrypoint_override, *args, **env)
.replicas(num_replicas)
.with_retry_policy(retry_policy, max_retries)
return Role(
name,
image=image,
base_image=base_image,
entrypoint=entrypoint_override,
args=args,
env=env,
num_replicas=num_replicas,
retry_policy=retry_policy,
max_retries=max_retries,
resource=resource,
port_map=port_map,
)
6 changes: 4 additions & 2 deletions torchx/components/base/test/roles_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def test_build_create_torch_dist_role(self) -> None:
nnodes="2:4",
max_restarts=3,
no_python=True,
).replicas(2)
num_replicas=2,
)
self.assertEqual("elastic_trainer", elastic_trainer.name)
self.assertEqual("python", elastic_trainer.entrypoint)
self.assertEqual(
Expand Down Expand Up @@ -153,7 +154,8 @@ def test_json_serialization_factory(self) -> None:
nnodes="2:4",
rdzv_backend="etcd",
rdzv_id="foobar",
).replicas(3)
num_replicas=3,
)

# this is effectively JSON
elastic_json = asdict(role)
Expand Down
5 changes: 3 additions & 2 deletions torchx/components/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ def ddp(
base_image=base_image,
entrypoint=entrypoint,
resource=resource or specs.NULL_RESOURCE,
num_replicas=nnodes,
script_args=list(script_args),
script_envs=env,
nproc_per_node=nproc_per_node,
nnodes=nnodes,
max_restarts=0,
).replicas(nnodes)
)

return specs.AppDef(name).of(ddp_role)
return specs.AppDef(name, roles=[ddp_role])
33 changes: 14 additions & 19 deletions torchx/pipelines/kfp/test/adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,22 @@ class KFPSpecsTest(unittest.TestCase):
"""

def _test_app(self) -> api.AppDef:
trainer_role = (
api.Role(
name="trainer",
image="pytorch/torchx:latest",
resource=api.Resource(
cpu=2,
memMB=3000,
gpu=4,
),
port_map={"foo": 1234},
)
.runs(
"main",
"--output-path",
"blah",
FOO="bar",
)
.replicas(1)
trainer_role = api.Role(
name="trainer",
image="pytorch/torchx:latest",
entrypoint="main",
args=["--output-path", "blah"],
env={"FOO": "bar"},
resource=api.Resource(
cpu=2,
memMB=3000,
gpu=4,
),
port_map={"foo": 1234},
num_replicas=1,
)

return api.AppDef("test").of(trainer_role)
return api.AppDef("test", roles=[trainer_role])

def _compile_pipeline(self, pipeline: Callable[[], None]) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down
6 changes: 3 additions & 3 deletions torchx/runner/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def dryrun(
# input validation
if not app.roles:
raise ValueError(
f"No roles for app: {app.name}. Did you forget to call app.of(roles..)?"
f"No roles for app: {app.name}. Did you forget to add roles to AppDef?"
)

for role in app.roles:
Expand All @@ -248,7 +248,7 @@ def dryrun(
if role.num_replicas <= 0:
raise ValueError(
f"Non-positive replicas for role: {role.name}."
f" Did you forget to call role.replicas(positive_number)?"
f" Did you forget to set role.num_replicas?"
)
sched = self._scheduler(scheduler)
sched._validate(app, scheduler)
Expand Down Expand Up @@ -394,7 +394,7 @@ def describe(self, app_handle: AppHandle) -> Optional[AppDef]:
if not app:
desc = scheduler.describe(app_id)
if desc:
app = AppDef(name=app_id).of(*desc.roles)
app = AppDef(name=app_id, roles=desc.roles)
return app

def log_lines(
Expand Down
118 changes: 78 additions & 40 deletions torchx/runner/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,24 @@ def test_validate_no_roles(self, _) -> None:
def test_validate_no_resource(self, _) -> None:
runner = Runner("test", schedulers={"default": self.scheduler})
with self.assertRaises(ValueError):
role = Role("no resource", image="no_image").runs("echo", "hello_world")
app = AppDef("no resource").of(role)
role = Role(
"no resource", image="no_image", entrypoint="echo", args=["hello_world"]
)
app = AppDef("no resource", roles=[role])
runner.run(app)

def test_validate_invalid_replicas(self, _) -> None:
runner = Runner("test", schedulers={"default": self.scheduler})
with self.assertRaises(ValueError):
role = (
Role(
"invalid replicas",
image="torch",
resource=Resource(cpu=1, gpu=0, memMB=500),
)
.runs("echo", "hello_world")
.replicas(0)
role = Role(
"invalid replicas",
image="torch",
entrypoint="echo",
args=["hello_world"],
num_replicas=0,
resource=Resource(cpu=1, gpu=0, memMB=500),
)
app = AppDef("invalid replicas").of(role)
app = AppDef("invalid replicas", roles=[role])
runner.run(app)

def test_run(self, _) -> None:
Expand All @@ -95,10 +96,14 @@ def test_run(self, _) -> None:
wait_interval=1,
)
self.assertEqual(1, len(session.scheduler_backends()))
role = Role(name="touch", image=self.test_dir, resource=resource.SMALL).runs(
"touch.sh", test_file
role = Role(
name="touch",
image=self.test_dir,
resource=resource.SMALL,
entrypoint="touch.sh",
args=[test_file],
)
app = AppDef("name").of(role)
app = AppDef("name", roles=[role])

app_handle = session.run(app, cfg=self.cfg)
app_status = none_throws(session.wait(app_handle))
Expand All @@ -109,20 +114,28 @@ def test_dryrun(self, _) -> None:
session = Runner(
name=SESSION_NAME, schedulers={"default": scheduler_mock}, wait_interval=1
)
role = Role(name="touch", image=self.test_dir, resource=resource.SMALL).runs(
"echo", "hello world"
role = Role(
name="touch",
image=self.test_dir,
resource=resource.SMALL,
entrypoint="echo",
args=["hello world"],
)
app = AppDef("name").of(role)
app = AppDef("name", roles=[role])
session.dryrun(app, "default", cfg=self.cfg)
scheduler_mock.submit_dryrun.assert_called_once_with(app, self.cfg)
scheduler_mock._validate.assert_called_once()

def test_describe(self, _) -> None:
session = Runner(name=SESSION_NAME, schedulers={"default": self.scheduler})
role = Role(name="sleep", image=self.test_dir, resource=resource.SMALL).runs(
"sleep.sh", "60"
role = Role(
name="sleep",
image=self.test_dir,
resource=resource.SMALL,
entrypoint="sleep.sh",
args=["60"],
)
app = AppDef("sleeper").of(role)
app = AppDef("sleeper", roles=[role])

app_handle = session.run(app, cfg=self.cfg)
self.assertEqual(app, session.describe(app_handle))
Expand All @@ -133,10 +146,14 @@ def test_list(self, _) -> None:
session = Runner(
name=SESSION_NAME, schedulers={"default": self.scheduler}, wait_interval=1
)
role = Role(name="touch", image=self.test_dir, resource=resource.SMALL).runs(
"sleep.sh", "1"
role = Role(
name="touch",
image=self.test_dir,
resource=resource.SMALL,
entrypoint="sleep.sh",
args=["1"],
)
app = AppDef("sleeper").of(role)
app = AppDef("sleeper", roles=[role])

num_apps = 4

Expand All @@ -159,10 +176,14 @@ def test_evict_non_existent_app(self, _) -> None:
name=SESSION_NAME, schedulers={"default": scheduler}, wait_interval=1
)
test_file = os.path.join(self.test_dir, "test_file")
role = Role(name="touch", image=self.test_dir, resource=resource.SMALL).runs(
"touch.sh", test_file
role = Role(
name="touch",
image=self.test_dir,
resource=resource.SMALL,
entrypoint="touch.sh",
args=[test_file],
)
app = AppDef("touch_test_file").of(role)
app = AppDef("touch_test_file", roles=[role])

# local scheduler was setup with a cache size of 1
# run the same app twice (the first will be removed from the scheduler's cache)
Expand All @@ -183,10 +204,14 @@ def test_status(self, _) -> None:
session = Runner(
name=SESSION_NAME, schedulers={"default": self.scheduler}, wait_interval=1
)
role = Role(name="sleep", image=self.test_dir, resource=resource.SMALL).runs(
"sleep.sh", "60"
role = Role(
name="sleep",
image=self.test_dir,
resource=resource.SMALL,
entrypoint="sleep.sh",
args=["60"],
)
app = AppDef("sleeper").of(role)
app = AppDef("sleeper", roles=[role])
app_handle = session.run(app, cfg=self.cfg)
app_status = none_throws(session.status(app_handle))
self.assertEqual(AppState.RUNNING, app_status.state)
Expand All @@ -213,10 +238,13 @@ def test_status_ui_url(self, json_dumps_mock: MagicMock, _) -> None:
session = Runner(
name="test_ui_url_session", schedulers={"default": mock_scheduler}
)
role = Role("ignored", image=self.test_dir, resource=resource.SMALL).runs(
"/bin/echo"
role = Role(
"ignored",
image=self.test_dir,
resource=resource.SMALL,
entrypoint="/bin/echo",
)
app_handle = session.run(AppDef(app_id).of(role))
app_handle = session.run(AppDef(app_id, roles=[role]))
status = none_throws(session.status(app_handle))
self.assertEquals(resp.ui_url, status.ui_url)

Expand All @@ -233,10 +261,13 @@ def test_status_structured_msg(self, json_dumps_mock: MagicMock, _) -> None:
session = Runner(
name="test_structured_msg", schedulers={"default": mock_scheduler}
)
role = Role("ignored", image=self.test_dir, resource=resource.SMALL).runs(
"/bin/echo"
role = Role(
"ignored",
image=self.test_dir,
resource=resource.SMALL,
entrypoint="/bin/echo",
)
app_handle = session.run(AppDef(app_id).of(role))
app_handle = session.run(AppDef(app_id, roles=[role]))
status = none_throws(session.status(app_handle))
self.assertEquals(resp.structured_error_msg, status.structured_error_msg)

Expand Down Expand Up @@ -305,10 +336,14 @@ def test_get_schedulers(self, json_dumps_mock: MagicMock, _) -> None:
schedulers = {"default": default_sched_mock, "local": local_sched_mock}
session = Runner(name="test_session", schedulers=schedulers)

role = Role(name="sleep", image=self.test_dir, resource=resource.SMALL).runs(
"sleep.sh", "60"
role = Role(
name="sleep",
image=self.test_dir,
resource=resource.SMALL,
entrypoint="sleep.sh",
args=["60"],
)
app = AppDef("sleeper").of(role)
app = AppDef("sleeper", roles=[role])
cfg = RunConfig()
session.run(app, scheduler="local", cfg=cfg)
local_sched_mock.submit.called_once_with(app, cfg)
Expand Down Expand Up @@ -353,8 +388,11 @@ def test_run_from_file(self, _) -> None:
expected_app = AppDef(
"ddp_app",
roles=[
Role("worker", image="dummy_image", resource=Resource(1, 0, 1)).runs(
entrypoint
Role(
"worker",
image="dummy_image",
resource=Resource(1, 0, 1),
entrypoint=entrypoint,
)
],
)
Expand Down
Loading