diff --git a/.lintrunner.toml b/.lintrunner.toml
index db84901f..27fb98d0 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -1,6 +1,7 @@
 [[linter]]
 code = 'BLACK-ISORT'
 include_patterns = [
+    '*.py',
     '**/*.py',
 ]
 exclude_patterns = []
@@ -46,6 +47,7 @@ command = [
 [[linter]]
 code = 'PYRE'
 include_patterns = [
+    '*.py',
     '**/*.py',
     '**/*.pyi',
 ]
diff --git a/src/lib.rs b/src/lib.rs
index 186f0108..f1b5a6b5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -9,6 +9,7 @@ pub mod manager;
 
 use core::time::Duration;
 use std::env;
+use std::sync::Arc;
 
 use anyhow::Result;
 use pyo3::exceptions::PyRuntimeError;
@@ -28,6 +29,8 @@ use pyo3::prelude::*;
 #[pyclass]
 struct Manager {
     handle: JoinHandle<Result<()>>,
+    manager: Arc<manager::Manager>,
+    _runtime: Runtime,
 }
 
 #[pymethods]
@@ -55,10 +58,18 @@ impl Manager {
                 ))
                 .unwrap();
             let handle = runtime.spawn(manager.clone().run());
-            Self { handle: handle }
+            Self {
+                handle: handle,
+                manager: manager,
+                _runtime: runtime,
+            }
         })
     }
 
+    fn address(&self) -> PyResult<String> {
+        Ok(self.manager.address().to_string())
+    }
+
     fn shutdown(&self, py: Python<'_>) {
         py.allow_threads(move || {
             self.handle.abort();
@@ -200,6 +211,48 @@ async fn lighthouse_main_async(opt: lighthouse::LighthouseOpt) -> Result<()> {
     Ok(())
 }
 
+#[pyclass]
+struct Lighthouse {
+    lighthouse: Arc<lighthouse::Lighthouse>,
+    handle: JoinHandle<Result<()>>,
+    _runtime: Runtime,
+}
+
+#[pymethods]
+impl Lighthouse {
+    #[new]
+    fn new(py: Python<'_>, bind: String, min_replicas: u64) -> PyResult<Self> {
+        py.allow_threads(move || {
+            let rt = Runtime::new().unwrap();
+
+            let lighthouse = rt
+                .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt {
+                    bind: bind,
+                    min_replicas: min_replicas,
+                    join_timeout_ms: 100,
+                    quorum_tick_ms: 100,
+                }))
+                .unwrap();
+
+            Ok(Self {
+                handle: rt.spawn(lighthouse.clone().run()),
+                lighthouse: lighthouse,
+                _runtime: rt,
+            })
+        })
+    }
+
+    fn address(&self) -> PyResult<String> {
+        Ok(self.lighthouse.address().to_string())
+    }
+
+    fn shutdown(&self, py: Python<'_>) {
+        py.allow_threads(move || {
+            self.handle.abort();
+        })
+    }
+}
+
 #[pymodule]
 fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
@@ -212,6 +265,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
 
     m.add_class::<Manager>()?;
     m.add_class::<ManagerClient>()?;
+    m.add_class::<Lighthouse>()?;
     m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;
 
     Ok(())
diff --git a/torchft/manager.py b/torchft/manager.py
index 207994b2..e8b8d3ba 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -39,8 +39,6 @@
 from torch.optim import Optimizer
 
 from torchft.checkpointing import CheckpointServer
-
-# pyre-fixme[21]: can't find rust module
 from torchft.torchft import Manager as _Manager, ManagerClient
 
 if TYPE_CHECKING:
@@ -121,7 +119,7 @@ def __init__(
 
         store_addr = store_addr or os.environ["MASTER_ADDR"]
         store_port = store_port or int(os.environ["MASTER_PORT"])
-        self._rank: int = rank or int(os.environ["RANK"])
+        self._rank: int = rank if rank is not None else int(os.environ["RANK"])
         rank = self._rank
         world_size = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size
@@ -151,7 +149,6 @@ def __init__(
 
         if replica_id is None:
             replica_id = str(uuid.uuid4())
-        # pyre-fixme[16]: can't find rust module
         self._manager = _Manager(
             replica_id=replica_id,
             lighthouse_addr=lighthouse_addr,
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
new file mode 100644
index 00000000..e9c9261d
--- /dev/null
+++ b/torchft/manager_integ_test.py
@@ -0,0 +1,100 @@
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from unittest import TestCase
+
+import torch
+import torch.distributed as dist
+from torch import nn, optim
+
+from torchft.ddp import DistributedDataParallel
+from torchft.manager import Manager
+from torchft.optim import OptimizerWrapper
+from torchft.process_group import ProcessGroupGloo
+from torchft.torchft import Lighthouse
+
+
+class MyModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = nn.Sequential(
+            nn.Linear(3, 4),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        return self.model(x)
+
+
+def train_loop(replica_id: int, lighthouse_address: str) -> None:
+    store = dist.TCPStore(
+        host_name="localhost",
+        port=0,
+        is_master=True,
+        wait_for_workers=False,
+    )
+
+    def load_state_dict(state_dict):
+        m.load_state_dict(state_dict["model"])
+        optimizer.load_state_dict(state_dict["optim"])
+
+    def state_dict():
+        return {
+            "model": m.state_dict(),
+            "optim": optimizer.state_dict(),
+        }
+
+    pg = ProcessGroupGloo()
+    manager = Manager(
+        pg=pg,
+        min_replica_size=2,
+        load_state_dict=load_state_dict,
+        state_dict=state_dict,
+        replica_id=str(replica_id),
+        store_addr="localhost",
+        store_port=store.port,
+        rank=0,
+        world_size=1,
+        lighthouse_addr=lighthouse_address,
+        port=19530 + replica_id,
+    )
+    m = DistributedDataParallel(manager, MyModel())
+    optimizer = OptimizerWrapper(manager, optim.Adam(m.parameters()))
+    criterion = nn.CrossEntropyLoss()
+
+    while True:
+        inputs = torch.rand(2, 3)
+        labels = torch.randint(4, (2,))
+
+        optimizer.zero_grad()
+        out = m(inputs)
+        loss = criterion(out, labels)
+
+        loss.backward()
+        optimizer.step()
+
+        # TODO: assert weights are equal across replicas
+
+        if manager.current_step() >= 5:
+            break
+
+    manager.shutdown()
+
+
+class ManagerIntegTest(TestCase):
+    def test_ddp(self):
+        lighthouse = Lighthouse(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id in range(num_replicas):
+                futures.append(
+                    executor.submit(train_loop, replica_id, lighthouse.address())
+                )
+
+            for fut in as_completed(futures):
+                fut.result()
+
+        lighthouse.shutdown()
diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi
index c86c4e5d..c3fc2b37 100644
--- a/torchft/torchft.pyi
+++ b/torchft/torchft.pyi
@@ -8,3 +8,21 @@
     ) -> Tuple[int, int, int, str, str, int, Optional[int], int, bool]: ...
    def checkpoint_address(self, rank: int) -> str: ...
    def should_commit(self, rank: int, step: int, should_commit: bool) -> bool: ...
+
+class Manager:
+    def __init__(
+        self,
+        replica_id: str,
+        lighthouse_addr: str,
+        address: str,
+        bind: str,
+        store_addr: str,
+        world_size: int,
+    ) -> None: ...
+    def address(self) -> str: ...
+    def shutdown(self) -> None: ...
+
+class Lighthouse:
+    def __init__(self, bind: str, min_replicas: int) -> None: ...
+    def address(self) -> str: ...
+    def shutdown(self) -> None: ...
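
A minimal sketch (not part of the patch) of how the new Lighthouse binding is driven from Python. It uses only the constructor arguments and methods declared in the torchft/torchft.pyi stubs above, with the same bind and min_replicas values as ManagerIntegTest.test_ddp; everything else is illustrative:

    # Sketch only: exercises the Lighthouse surface declared in torchft/torchft.pyi.
    from torchft.torchft import Lighthouse

    # Bind to an ephemeral port on all interfaces; quorum requires two replicas.
    lighthouse = Lighthouse(bind="[::]:0", min_replicas=2)

    # The resolved address is what each replica passes as `lighthouse_addr`
    # when constructing a torchft.manager.Manager.
    addr = lighthouse.address()

    # Aborts the background task driving the lighthouse server.
    lighthouse.shutdown()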