Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions python/monarch/tools/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import argparse
import functools
import inspect
import logging
import os
import time
from datetime import timedelta
from typing import Any, Callable, Mapping, Optional, Union

from monarch.tools.config import ( # @manual=//monarch/python/monarch/tools/config/meta:defaults
Expand All @@ -18,12 +21,13 @@
)

from monarch.tools.mesh_spec import mesh_spec_from_metadata, ServerSpec

from torchx.runner import Runner
from torchx.specs import AppDef, AppDryRunInfo, CfgVal
from torchx.specs import AppDef, AppDryRunInfo, AppState, CfgVal
from torchx.specs.builders import parse_args
from torchx.util.types import decode, decode_optional

logger: logging.Logger = logging.getLogger(__name__)


def torchx_runner() -> Runner:
# namespace is currently unused so make it empty str
Expand Down Expand Up @@ -165,15 +169,73 @@ def info(server_handle: str) -> Optional[ServerSpec]:
if appdef is None:
return None

# host status grouped by mesh (role) names
replica_status = {r.role: r.replicas for r in status.roles}

mesh_specs = []
for role in appdef.roles:
spec = mesh_spec_from_metadata(appdef, role.name)
assert spec is not None, "cannot be 'None' since we iterate over appdef's roles"

# null-guard since some schedulers do not fill replica_status
if host_status := replica_status.get(role.name):
spec.hostnames = [h.hostname for h in host_status]

mesh_specs.append(spec)

return ServerSpec(name=appdef.name, state=status.state, meshes=mesh_specs)


_5_SECONDS = timedelta(seconds=5)


async def server_ready(
    server_handle: str, check_interval: timedelta = _5_SECONDS
) -> Optional[ServerSpec]:
    """Waits until the server's job is in RUNNING state and returns the server spec.

    Returns ``None`` if the server does not exist.

    NOTE: Certain fields such as `hostnames` are only filled (and valid) when the
    server is RUNNING.

    Args:
        server_handle: scheduler handle of the server (e.g. ``slurm:///123``).
        check_interval: how long to wait between status polls while the job is
            still UNSUBMITTED, SUBMITTED, or PENDING.

    Usage:

    .. code-block:: python

        server_info = await server_ready("slurm:///123")
        if not server_info:
            print(f"Job does not exist")
        else:
            if server_info.is_running:
                for mesh in server_info.meshes:
                    connect_to(mesh.hostnames)
            else:
                print(f"Job in {server_info.state} state. Hostnames are not available")

    """
    # local import; asyncio is only needed by this coroutine
    import asyncio

    while True:
        server_spec = info(server_handle)

        if not server_spec:  # server not found
            return None

        if server_spec.state <= AppState.PENDING:  # UNSUBMITTED or SUBMITTED or PENDING
            # NOTE: TorchX currently does not have async APIs so need to loop-on-interval
            # TODO maybe inverse exponential backoff instead of constant interval?
            check_interval_seconds = check_interval.total_seconds()
            logger.info(
                "waiting for %s to be %s (current: %s), will check again in %g seconds...",
                server_handle,
                AppState.RUNNING,
                server_spec.state,
                check_interval_seconds,
            )
            # BUG FIX: `time.sleep()` would block the running event loop for the
            # whole poll interval, starving every other coroutine. In an async
            # function the non-blocking `asyncio.sleep()` must be awaited instead.
            await asyncio.sleep(check_interval_seconds)
        else:
            return server_spec


def kill(server_handle: str) -> None:
    """Cancels (kills) the job backing the given server.

    Args:
        server_handle: scheduler handle of the server to cancel
            (e.g. ``slurm:///123``).
    """
    # torchx_runner() is used as a context manager so the runner's
    # scheduler sessions are closed even if cancel() raises
    with torchx_runner() as runner:
        runner.cancel(server_handle)
Expand Down
8 changes: 7 additions & 1 deletion python/monarch/tools/mesh_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# pyre-strict
import string
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Optional

from torchx import specs
Expand All @@ -29,6 +29,7 @@ class MeshSpec:
host_type: str
gpus: int
port: int = DEFAULT_REMOTE_ALLOCATOR_PORT
hostnames: list[str] = field(default_factory=list)


def _tag(mesh_name: str, tag_template: str) -> str:
Expand Down Expand Up @@ -84,6 +85,10 @@ class ServerSpec:
state: specs.AppState
meshes: list[MeshSpec]

@property
def is_running(self) -> bool:
    """Returns ``True`` iff this server's job is currently in the RUNNING state."""
    return self.state == specs.AppState.RUNNING

def get_mesh_spec(self, mesh_name: str) -> MeshSpec:
for mesh_spec in self.meshes:
if mesh_spec.name == mesh_name:
Expand Down Expand Up @@ -115,6 +120,7 @@ def to_json(self) -> dict[str, Any]:
"host_type": mesh.host_type,
"hosts": mesh.num_hosts,
"gpus": mesh.gpus,
"hostnames": mesh.hostnames,
}
for mesh in self.meshes
},
Expand Down
6 changes: 4 additions & 2 deletions python/tests/tools/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@ def test_info(self, mock_cmd_info: mock.MagicMock) -> None:
"trainer": {
"host_type": "gpu.medium",
"hosts": 4,
"gpus": 2
"gpus": 2,
"hostnames": []
},
"generator": {
"host_type": "gpu.small",
"hosts": 16,
"gpus": 1
"gpus": 1,
"hostnames": []
}
}
}
Expand Down
78 changes: 77 additions & 1 deletion python/tests/tools/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
# pyre-strict

import unittest
from datetime import timedelta
from unittest import mock

from monarch.tools import commands
from monarch.tools.commands import component_args_from_cli
from monarch.tools.commands import component_args_from_cli, server_ready

from monarch.tools.config import ( # @manual=//monarch/python/monarch/tools/config/meta:defaults
defaults,
Expand Down Expand Up @@ -101,3 +102,78 @@ def test_info(
),
commands.info("slurm:///job-id"),
)


# placeholder for fields whose values are irrelevant to the test
UNUSED = "__UNUSED__"
# short poll interval so tests don't spend real time sleeping
_5_MS = timedelta(milliseconds=5)


def server(state: AppState) -> ServerSpec:
    """Builds a two-mesh (`x`: 2 hosts, `y`: 4 hosts) ServerSpec fixture in `state`.

    Hostnames (`node0..node{n-1}` per mesh) are populated only for RUNNING,
    mirroring how real schedulers expose hosts only once a job is running.
    """
    meshes = [
        MeshSpec(name="x", num_hosts=2, host_type=UNUSED, gpus=-1),
        MeshSpec(name="y", num_hosts=4, host_type=UNUSED, gpus=-1),
    ]

    if state == AppState.RUNNING:
        for m in meshes:
            m.hostnames = [f"node{i}" for i in range(m.num_hosts)]

    return ServerSpec(name=UNUSED, state=state, meshes=meshes)


class TestCommandsAsync(unittest.IsolatedAsyncioTestCase):
    """Tests the async `server_ready` polling helper in `monarch.tools.commands`."""

    async def test_server_ready_server_does_not_exist(self) -> None:
        # info() returning None models a handle that maps to no job;
        # server_ready should propagate that as None rather than polling
        with mock.patch(
            "monarch.tools.commands.info",
            return_value=None,
        ):
            server_info = await server_ready("slurm:///123", check_interval=_5_MS)
            self.assertIsNone(server_info)

    async def test_server_ready_pending_to_running(self) -> None:
        # side_effect yields one state per info() call: server_ready should poll
        # through the pre-RUNNING states, return at RUNNING, and never consume
        # the trailing CANCELLED entry
        with mock.patch(
            "monarch.tools.commands.info",
            side_effect=[
                server(AppState.UNSUBMITTED),
                server(AppState.SUBMITTED),
                server(AppState.PENDING),
                server(AppState.PENDING),
                server(AppState.RUNNING),
                server(AppState.CANCELLED),
            ],
        ) as mock_info:
            server_info = await server_ready("slurm:///123", check_interval=_5_MS)

            self.assertIsNotNone(server_info)
            self.assertTrue(server_info.is_running)
            self.assertEqual(server_info.state, AppState.RUNNING)

            # the server() fixture fills hostnames only for the RUNNING state
            mesh_x = server_info.get_mesh_spec("x")
            mesh_y = server_info.get_mesh_spec("y")
            self.assertListEqual(mesh_x.hostnames, ["node0", "node1"])
            self.assertListEqual(mesh_y.hostnames, ["node0", "node1", "node2", "node3"])

            mock_info.assert_called()
            # called 5 times, once for UNSUBMITTED, SUBMITTED, PENDING, PENDING, and RUNNING
            self.assertEqual(mock_info.call_count, 5)

    async def test_server_ready_pending_to_terminal(self) -> None:
        # server_ready must also return (not loop forever) when the job skips
        # RUNNING and lands directly in a terminal state
        for terminal_state in [AppState.SUCCEEDED, AppState.FAILED, AppState.CANCELLED]:
            with self.subTest(terminal_state=terminal_state):
                with mock.patch(
                    "monarch.tools.commands.info",
                    side_effect=[
                        server(AppState.SUBMITTED),
                        server(AppState.PENDING),
                        server(AppState.PENDING),
                        server(terminal_state),
                    ],
                ) as mock_info:
                    server_info = await server_ready(
                        "slurm:///123",
                        check_interval=_5_MS,
                    )

                    self.assertIsNotNone(server_info)
                    self.assertEqual(server_info.state, terminal_state)
                    mock_info.assert_called()
                    # one info() call per state in side_effect
                    self.assertEqual(mock_info.call_count, 4)
14 changes: 12 additions & 2 deletions python/tests/tools/test_mesh_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,25 @@ def test_mesh_spec_from_metadata(self) -> None:

def test_mesh_spec_can_dump_as_json(self) -> None:
mesh_spec = MeshSpec(
name="trainer", num_hosts=4, host_type="gpu.medium", gpus=2
name="trainer",
num_hosts=4,
host_type="gpu.medium",
gpus=2,
hostnames=["n0", "n1", "n2", "n3"],
)
expected = """
{
"name": "trainer",
"num_hosts": 4,
"host_type": "gpu.medium",
"gpus": 2,
"port": 26600
"port": 26600,
"hostnames": [
"n0",
"n1",
"n2",
"n3"
]
}
"""
self.assertEqual(expected.strip("\n"), json.dumps(asdict(mesh_spec), indent=2))
Expand Down