From d17cd98d3adda0bb30edfe67432f758c02663e77 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Thu, 5 Dec 2024 20:58:55 -0800 Subject: [PATCH] lintrunner: enable pyre --- .github/workflows/docs.yaml | 2 +- .github/workflows/lint.yaml | 9 +- .lintrunner.toml | 23 +++++ .watchmanconfig | 8 ++ tools/linter/adapters/pyre_linter.py | 124 +++++++++++++++++++++++++++ torchft/checkpointing.py | 17 ++-- torchft/data.py | 6 +- torchft/ddp_test.py | 2 +- torchft/http.py | 7 ++ torchft/manager.py | 39 +++++---- torchft/manager_test.py | 20 ++++- torchft/parameter_server.py | 5 +- torchft/process_group.py | 64 ++++++++------ torchft/process_group_test.py | 10 +-- torchft/torchft.pyi | 10 +++ 15 files changed, 278 insertions(+), 68 deletions(-) create mode 100644 .watchmanconfig create mode 100644 tools/linter/adapters/pyre_linter.py create mode 100644 torchft/http.py create mode 100644 torchft/torchft.pyi diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 707489cd..755ed058 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -8,7 +8,7 @@ on: jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - name: Setup Python uses: actions/setup-python@v3 diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 899d6530..2b82d471 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -8,7 +8,7 @@ on: jobs: lint: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - name: Setup Python uses: actions/setup-python@v3 @@ -31,7 +31,12 @@ jobs: run: | set -eux - lintrunner --force-color --all-files + lintrunner --skip PYRE --force-color --all-files + - name: Run pyre + run: | + set -eux + + pyre check - name: Run Rust Lint run: | set -eux diff --git a/.lintrunner.toml b/.lintrunner.toml index 22f84ebc..db84901f 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -42,3 +42,26 @@ command = [ '--', '@{{PATHSFILE}}', ] + +[[linter]] +code = 'PYRE' +include_patterns = [ + '**/*.py', + '**/*.pyi', +] +command = [ + 'python3', + 'tools/linter/adapters/pyre_linter.py', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'pip_init', + '--dry-run={{DRYRUN}}', + 'pyre-check==0.9.23', +] +is_formatter = false diff --git a/.watchmanconfig b/.watchmanconfig new file mode 100644 index 00000000..8d922fc1 --- /dev/null +++ b/.watchmanconfig @@ -0,0 +1,8 @@ +{ + "root_files": [ + "torchft", + "*.py", + ".pyre_configuration", + ".watchmanconfig" + ] +} diff --git a/tools/linter/adapters/pyre_linter.py b/tools/linter/adapters/pyre_linter.py new file mode 100644 index 00000000..d5c9ad84 --- /dev/null +++ b/tools/linter/adapters/pyre_linter.py @@ -0,0 +1,124 @@ +import argparse +import concurrent.futures +import json +import logging +import os +import subprocess +import sys +from enum import Enum +from pathlib import Path +from typing import Any, List, NamedTuple, Optional, Set, TypedDict + +logger: logging.Logger = logging.getLogger(__name__) + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + ADVICE = "advice" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: Optional[str] + line: Optional[int] + char: Optional[int] + code: str + severity: LintSeverity + name: str + original: Optional[str] + replacement: Optional[str] + description: Optional[str] + + +class PyreResult(TypedDict): + line: int + column: int + stop_line: int + stop_column: int + path: str + code: int + name: str + description: str + concise_description: str + + 
+def run_pyre() -> List[PyreResult]: + proc = subprocess.run( + ["pyre", "--output=json", "incremental"], + capture_output=True, + ) + return json.loads(proc.stdout) + + +def check_pyre( + filenames: Set[str], +) -> List[LintMessage]: + try: + results = run_pyre() + + return [ + LintMessage( + path=result["path"], + line=result["line"], + char=result["column"], + code="pyre", + severity=LintSeverity.WARNING, + name=result["name"], + description=result["description"], + original=None, + replacement=None, + ) + for result in results + ] + except Exception as err: + return [ + LintMessage( + path=None, + line=None, + char=None, + code="pyre", + severity=LintSeverity.ADVICE, + name="command-failed", + original=None, + replacement=None, + description=(f"Failed due to {err.__class__.__name__}:\n{err}"), + ) + ] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Checks files with pyre", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + args = parser.parse_args() + + logging.basicConfig( + format="<%(processName)s:%(levelname)s> %(message)s", + level=( + logging.NOTSET + if args.verbose + else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO + ), + stream=sys.stderr, + ) + + lint_messages = check_pyre(set(args.filenames)) + + for lint_message in lint_messages: + print(json.dumps(lint_message._asdict()), flush=True) + + +if __name__ == "__main__": + main() diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index becd57c4..c358d517 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -16,20 +16,19 @@ import socket import threading import urllib.request -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from typing import Callable +from http.server import BaseHTTPRequestHandler +from typing import Callable, Generic, TypeVar import torch -logger: logging.Logger = logging.getLogger(__name__) +from torchft.http import _IPv6HTTPServer +logger: logging.Logger = logging.getLogger(__name__) -class _IPv6HTTPServer(ThreadingHTTPServer): - address_family = socket.AF_INET6 - request_queue_size = 1024 +T = TypeVar("T") -class CheckpointServer: +class CheckpointServer(Generic[T]): """ This is an HTTP server that can be used to transfer checkpoints between workers. @@ -41,7 +40,7 @@ class CheckpointServer: state_dict: a callable that returns the state dict to be transferred """ - def __init__(self, state_dict: Callable[[], object]) -> None: + def __init__(self, state_dict: Callable[[], T]) -> None: self._checkpoint_lock = threading.Lock() self._disallowed = False self._step = -1 @@ -88,7 +87,7 @@ def err(self, msg: str) -> None: self._thread.start() @classmethod - def load_from_address(cls, address: str) -> object: + def load_from_address(cls, address: str) -> T: """ Loads a checkpoint from the given address. 
diff --git a/torchft/data.py b/torchft/data.py index cc2e227b..5e3a0f1f 100644 --- a/torchft/data.py +++ b/torchft/data.py @@ -47,7 +47,6 @@ def __init__( dataset: data.Dataset, replica_group: int, num_replica_groups: int, - *args, rank: Optional[int] = None, num_replicas: Optional[int] = None, **kwargs, @@ -69,5 +68,8 @@ def __init__( self.global_world_size = num_replicas * num_replica_groups super().__init__( - dataset, *args, rank=self.global_rank, num_replicas=self.global_world_size + dataset, + rank=self.global_rank, + num_replicas=self.global_world_size, + **kwargs, ) diff --git a/torchft/ddp_test.py b/torchft/ddp_test.py index 21993f1e..0145a9fa 100644 --- a/torchft/ddp_test.py +++ b/torchft/ddp_test.py @@ -44,7 +44,7 @@ def allreduce_grad(tensor: torch.Tensor) -> Future[torch.Tensor]: call_count += 1 - fut = Future() + fut = Future() # pyre-fixme[29]: not a function fut.set_result(tensor) return fut diff --git a/torchft/http.py b/torchft/http.py new file mode 100644 index 00000000..246f1d4e --- /dev/null +++ b/torchft/http.py @@ -0,0 +1,7 @@ +import socket +from http.server import ThreadingHTTPServer + + +class _IPv6HTTPServer(ThreadingHTTPServer): + address_family = socket.AF_INET6 + request_queue_size = 1024 diff --git a/torchft/manager.py b/torchft/manager.py index 0d1a64da..5e3b1a20 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -32,7 +32,7 @@ import uuid from concurrent.futures import ThreadPoolExecutor from datetime import timedelta -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast import torch from torch.distributed import PrefixStore, ReduceOp, TCPStore, Work @@ -51,6 +51,8 @@ MANAGER_ADDR_KEY: str = "manager_addr" MANAGER_DEFAULT_PORT: int = int(os.environ.get("TORCHFT_MANAGER_PORT", 29511)) +T = TypeVar("T") + class Manager: """ @@ -93,14 +95,15 @@ def __init__( """ self._load_state_dict = load_state_dict self._state_dict = state_dict + self._pending_state_dict: Optional[Dict[str, object]] = None self._use_async_quorum = use_async_quorum self._timeout = timeout store_addr = store_addr or os.environ["MASTER_ADDR"] store_port = store_port or int(os.environ["MASTER_PORT"]) - rank = rank or int(os.environ["RANK"]) + self._rank: int = rank or int(os.environ["RANK"]) + rank = self._rank world_size = world_size or int(os.environ["WORLD_SIZE"]) - self._rank = rank self._min_replica_size = min_replica_size self._ckpt_server = CheckpointServer( @@ -141,7 +144,6 @@ def __init__( self._store.set(MANAGER_ADDR_KEY, addr) addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8") - # pyre-fixme[16]: can't find rust module self._client = ManagerClient(addr, timeout=timeout) self._step = 0 @@ -149,7 +151,7 @@ def __init__( self._errored = False self._healing = False self._participating_replicas = 0 - self._pending_work: List[torch.futures.Future[torch.Tensor]] = [] + self._pending_work: List[torch.futures.Future[object]] = [] self._batches_committed = 0 # first step is 1 @@ -180,7 +182,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso a Future that will be completed with the allreduced gradient """ if self.errored(): - fut = torch.futures.Future() + fut = torch.futures.Future() # pyre-fixme[29]: not a function fut.set_result(grad) return fut @@ -200,7 +202,7 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso # on the Future def callback( fut: torch.futures.Future[List[torch.Tensor]], - ) -> 
torch.futures.Future[torch.Tensor]: + ) -> torch.Tensor: nonlocal grad fut.value() @@ -217,7 +219,7 @@ def callback( logger.exception(f"got exception in all reduce -- skipping remaining: {e}") self.report_error() - fut = torch.futures.Future() + fut = torch.futures.Future() # pyre-fixme[29]: not a function fut.set_result(grad) return fut @@ -242,7 +244,9 @@ def errored(self) -> bool: """ return self._errored - def wrap_future(self, fut: torch.futures.Future[object], default: object) -> None: + def wrap_future( + self, fut: torch.futures.Future[T], default: T + ) -> torch.futures.Future[T]: """ Wrap a Future and swallow any errors that occur and report them to the manager. @@ -255,8 +259,8 @@ def wrap_future(self, fut: torch.futures.Future[object], default: object) -> Non # schedule error handling as a continuation on the Future def callback( - fut: torch.futures.Future[List[torch.Tensor]], - ) -> torch.futures.Future[torch.Tensor]: + fut: torch.futures.Future[T], + ) -> T: nonlocal default try: @@ -267,7 +271,7 @@ def callback( return default fut = fut.then(callback) - self._pending_work.append(fut) + self._pending_work.append(cast(torch.futures.Future[object], fut)) return fut def step(self) -> None: @@ -334,14 +338,13 @@ def _async_quorum(self) -> None: logger.info("healing required") logger.info(f"fetching checkpoint server address from {address}") - # pyre-fixme[16]: can't find rust module primary_client = ManagerClient(address, timeout=self._timeout) checkpoint_server_address = primary_client.checkpoint_address(self._rank) - self._state_dict = CheckpointServer.load_from_address( + self._pending_state_dict = CheckpointServer.load_from_address( checkpoint_server_address ) - self.load_state_dict(self._state_dict["torchft"]) + self.load_state_dict(self._pending_state_dict["torchft"]) # we apply the user state dict only when safe from the main thread # This isn't strictly needed as loading the state_dict above should @@ -354,10 +357,10 @@ def _apply_pending_state_dict(self) -> None: # synchronize on future self._quorum_future.result() - assert self._state_dict is not None, "checkpoint was not staged" + assert self._pending_state_dict is not None, "checkpoint was not staged" - self._load_state_dict(self._state_dict["user"]) - self._state_dict = None + self._load_state_dict(self._pending_state_dict["user"]) + self._pending_state_dict = None def should_commit(self) -> bool: """ diff --git a/torchft/manager_test.py b/torchft/manager_test.py index d2ad0b5f..df19401d 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -16,6 +16,9 @@ class TestManager(TestCase): + store: TCPStore # pyre-fixme[13]: never initialized + load_state_dict: MagicMock # pyre-fixme[13]: never initialized + def _create_manager( self, use_async_quorum: bool = True, min_replica_size: int = 2 ) -> Manager: @@ -98,6 +101,7 @@ def test_quorum_happy(self, client_mock) -> None: self.assertEqual(manager._quorum_id, 123) self.assertEqual(manager._step, 1) + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.call_count, 1) manager.step() @@ -135,7 +139,9 @@ def test_quorum_heal_sync(self, client_mock) -> None: self.assertEqual(manager._quorum_id, 123) self.assertEqual(manager._step, 20) + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.call_count, 1) + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1) self.assertEqual(self.load_state_dict.call_count, 1) @@ -178,7 +184,9 @@ def 
test_quorum_heal_async_not_enough_participants(self, client_mock) -> None: self.assertEqual(manager._quorum_id, 123) self.assertEqual(manager._step, 20) + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.call_count, 1) + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1) self.assertEqual(self.load_state_dict.call_count, 1) @@ -224,7 +232,9 @@ def test_quorum_heal_async_zero_grad(self, client_mock) -> None: self.assertEqual(manager._quorum_id, 123) self.assertEqual(manager._step, 20) + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.call_count, 1) + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1) self.assertEqual(self.load_state_dict.call_count, 1) @@ -254,15 +264,18 @@ def test_allreduce_error(self, client_mock) -> None: manager.step() manager.allreduce_grad(torch.tensor([1.0])).wait() + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.call_count, 1) # inject failure when work queued + # pyre-ignore[16]: _pg is mocked manager._pg.allreduce.side_effect = RuntimeError("injected failure") manager.allreduce_grad(torch.tensor([1.0])).wait() self.assertTrue(manager._errored) # this should be skipped due to error manager.allreduce_grad(torch.tensor([1.0])).wait() self.assertEqual(manager._pg.allreduce.call_count, 2) + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1) self.assertFalse(manager.should_commit()) @@ -284,7 +297,7 @@ def test_allreduce_error(self, client_mock) -> None: ) manager.step() - bad_fut = torch.futures.Future() + bad_fut = torch.futures.Future() # pyre-fixme[29]: not a function bad_fut.set_exception(RuntimeError("injected failure")) manager._pg.allreduce.return_value.get_future.return_value = bad_fut manager.allreduce_grad(torch.tensor([1.0])).wait() @@ -326,7 +339,7 @@ def test_manager_wrap_future(self, client_mock) -> None: self.assertFalse(manager.errored()) - fut = torch.futures.Future() + fut = torch.futures.Future() # pyre-fixme[29]: not a function wrapped_fut = manager.wrap_future(fut, 2) fut.set_exception(RuntimeError("injected failure")) @@ -342,9 +355,10 @@ def test_manager_numerics(self, client_mock) -> None: manager._quorum_future = MagicMock() manager._participating_replicas = 5 self.assertEqual(manager.num_participants(), 5) + # pyre-ignore[16]: _pg is mocked manager._pg.allreduce.return_value = _DummyWork(None) - fut = torch.futures.Future() + fut = torch.futures.Future() # pyre-fixme[29]: not a function fut = manager.allreduce_grad(torch.tensor([1.0])) result = fut.value() torch.testing.assert_close(result, torch.tensor([1.0 / 5])) diff --git a/torchft/parameter_server.py b/torchft/parameter_server.py index 3823640d..3c844595 100644 --- a/torchft/parameter_server.py +++ b/torchft/parameter_server.py @@ -18,10 +18,11 @@ import urllib.request import uuid from abc import ABC, abstractmethod -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from http.server import BaseHTTPRequestHandler from torch.distributed import TCPStore +from torchft.http import _IPv6HTTPServer from torchft.process_group import ProcessGroup logger: logging.Logger = logging.getLogger(__name__) @@ -101,7 +102,7 @@ def do_GET(self): raise server_address = ("", port) - self._server = ThreadingHTTPServer(server_address, RequestHandler) + self._server = _IPv6HTTPServer(server_address, RequestHandler) self._server.daemon_threads = 
True logger.info(f"Started ParameterServer on {self.address()}...") diff --git a/torchft/process_group.py b/torchft/process_group.py index 83f248e5..eb8b6d19 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -29,6 +29,9 @@ _register_process_group, _unregister_process_group, ) + +# pyre-fixme[21]: no attribute ProcessGroupNCCL +# pyre-fixme[21]: no attribute ProcessGroupGloo from torch.distributed import ( BroadcastOptions, DeviceMesh, @@ -102,9 +105,11 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None: """ raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override def allgather( self, output_tensors: List[List[torch.Tensor]], @@ -113,6 +118,7 @@ def allgather( ) -> Work: raise NotImplementedError("not implemented") + # pyre-fixme[14]: inconsistent override def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: raise NotImplementedError("not implemented") @@ -127,7 +133,7 @@ def size(self) -> int: def getBackendName(self) -> str: raise NotImplementedError("not implemented") - def register(self, name: str) -> BaseProcessGroup: + def register(self, name: str) -> "ProcessGroup": """ Registers the process group with the global registry. This enables usage with things like functional_collectives which are compilable. @@ -184,32 +190,34 @@ def __repr__(self) -> str: class ProcessGroupWrapper(ProcessGroup): - PG_CLASS: Type[BaseProcessGroup] + PG_CLASS: Type[BaseProcessGroup] # pyre-fixme[13]: never initialized """ This is a wrapper around any ProcessGroup with a reconfiguration method. """ def __init__(self, pg: Optional[ProcessGroup] = None) -> None: super().__init__(0, 1) - self._pg = pg + self._pg: Optional[BaseProcessGroup] = pg def configure(self, store_addr: str, rank: int, world_size: int) -> None: - if isinstance(self._pg, ProcessGroup): - self._pg.configure(store_addr, rank, world_size) + pg = self._pg + if isinstance(pg, ProcessGroup): + pg.configure(store_addr, rank, world_size) return - if self._pg is not None: - if hasattr(self._pg, "abort"): - self._pg.abort() + if pg is not None: + if hasattr(pg, "abort"): + pg.abort() # pyre-fixme[16]: no attribute abort self._pg = None store = create_store_client(store_addr) # TODO: set global timeout + # pyre-fixme[20]: expects argument options self._pg = self.PG_CLASS(store, rank, world_size) def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: - return self._pg.allreduce(tensors, opts) + return self.parent.allreduce(tensors, opts) def allgather( self, @@ -217,15 +225,17 @@ def allgather( input_tensor: List[torch.Tensor], opts: object, ) -> Work: - return self._pg.allgather(output_tensors, input_tensor, opts) + return self.parent.allgather(output_tensors, input_tensor, opts) def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: - return self._pg.broadcast(tensor_list, opts) + return self.parent.broadcast(tensor_list, opts) def size(self) -> int: - return self._pg.size() + return self.parent.size() - def parent(self) -> ProcessGroup: + @property + def parent(self) -> BaseProcessGroup: + assert self._pg is not None, "process group not initialized" return self._pg def __repr__(self) -> str: @@ -237,7 +247,7 @@ class ProcessGroupGloo(ProcessGroupWrapper): This is a reconfigurable version of ProcessGroupGloo. 
""" - PG_CLASS = BaseProcessGroupGloo + PG_CLASS = BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo def getBackendName(self) -> str: return "torchft-gloo" @@ -254,7 +264,7 @@ class ProcessGroupNCCL(ProcessGroupWrapper): abort when reconfiguring, we need to ensure this is safe. """ - PG_CLASS = BaseProcessGroupNCCL + PG_CLASS = BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL def getBackendName(self) -> str: return "torchft-nccl" @@ -326,7 +336,7 @@ def getBackendName(self): class _ErrorSwallowingWork(Work): def __init__( self, - pg: "ErrorSwallowingProcessGroup", + pg: "ErrorSwallowingProcessGroupWrapper", work: Work, default_result: object, ): @@ -350,7 +360,7 @@ def get_future(self) -> Future: # schedule error handling as a continuation on the Future def callback( fut: torch.futures.Future[List[torch.Tensor]], - ) -> torch.futures.Future[torch.Tensor]: + ) -> object: try: return fut.value() except Exception as e: @@ -464,7 +474,7 @@ def __init__( self._op_id = op_id self._timeout = timeout - def wait(self) -> bool: + def wait(self, timeout: Optional[timedelta] = None) -> bool: self._tx.put(("wait", self._op_id), timeout=self._timeout) assert _get(self._rx, self._timeout) == self._op_id return True @@ -474,8 +484,9 @@ def get_future(self) -> Future: class _BabyWorkNCCL(_BabyWork): - def wait(self) -> bool: + def wait(self, timeout: Optional[timedelta] = None) -> bool: self._tx.put(("synchronize", self._op_id), timeout=self._timeout) + # pyre-fixme[23]: unable to unpack into 2 values op_id, event = _get(self._rx, self._timeout) assert op_id == self._op_id assert isinstance(event, torch.cuda.Event) @@ -495,7 +506,7 @@ class ProcessGroupBaby(ProcessGroup): """ - PG_CLASS: Type[BaseProcessGroup] + PG_CLASS: Type[BaseProcessGroup] # pyre-fixme[13]: never initialized WORK_CLASS: Type[_BabyWork] = _BabyWork def __init__(self, timeout: float = 60.0) -> None: @@ -508,6 +519,8 @@ def __init__(self, timeout: float = 60.0) -> None: self._rx = None self._future_queue = None self._future_thread = None + self._futures = {} + self._futures_lock = threading.Lock() self._timeout = timeout @@ -561,6 +574,7 @@ def _worker( try: store = create_store_client(store_addr) + # pyre-fixme[20]: expects argument options pg = cls.PG_CLASS(store, rank, world_size) work = {} @@ -635,7 +649,7 @@ def _future_handler(self, future_queue: mp.Queue) -> None: def _get_future(self, op_id: int) -> Future: with self._futures_lock: - fut = Future() + fut = Future() # pyre-fixme[29]: is not a function self._futures[op_id] = fut self._tx.put(("future", op_id), timeout=self._timeout) @@ -672,7 +686,7 @@ class ProcessGroupBabyGloo(ProcessGroupBaby): ProcessGroupBabyNCCL. """ - PG_CLASS = BaseProcessGroupGloo + PG_CLASS = BaseProcessGroupGloo # pyre-fixme[16]: no attribute ProcessGroupGloo def getBackendName(self): return "torchft-baby-gloo" @@ -694,7 +708,7 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby): tensors may leak in the current PyTorch implementation. 
TODO fix """ - PG_CLASS = BaseProcessGroupNCCL + PG_CLASS = BaseProcessGroupNCCL # pyre-fixme[16]: no attribute ProcessGroupNCCL WORK_CLASS = _BabyWorkNCCL def getBackendName(self): @@ -719,12 +733,12 @@ def extend_device_mesh( """ groups = mesh.get_all_groups() groups.insert(dim, pg) - mesh_dim_names = list(mesh.mesh_dim_names) + mesh_dim_names = list(mesh.mesh_dim_names or []) mesh_dim_names.insert(dim, name) return DeviceMesh.from_group( group=groups, device_type=mesh.device_type, mesh=mesh.mesh.unsqueeze(dim), - mesh_dim_names=mesh_dim_names, + mesh_dim_names=tuple(mesh_dim_names), ) diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index 0d8c0cca..d7ccecf6 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -6,6 +6,7 @@ import os from concurrent.futures import ThreadPoolExecutor +from typing import Tuple from unittest import TestCase, skipUnless from unittest.mock import Mock @@ -13,14 +14,13 @@ import torch.distributed as dist from torch import nn from torch._C._distributed_c10d import _resolve_process_group -from torch.distributed import ReduceOp, TCPStore, _functional_collectives +from torch.distributed import ReduceOp, TCPStore, Work, _functional_collectives from torch.distributed.device_mesh import init_device_mesh from torchft.manager import Manager from torchft.process_group import ( ErrorSwallowingProcessGroupWrapper, ManagedProcessGroup, - ProcessGroup, ProcessGroupBabyGloo, ProcessGroupBabyNCCL, ProcessGroupDummy, @@ -136,7 +136,7 @@ def test_baby_nccl_2gpu(self) -> None: store_addr = f"localhost:{store.port}/prefix" - def run(rank: int) -> None: + def run(rank: int) -> Tuple[torch.Tensor, Work]: a = ProcessGroupBabyNCCL() a.configure(store_addr, rank, 2) @@ -205,7 +205,7 @@ def test_functional_collectives(self) -> None: def test_process_group_wrapper(self) -> None: pg = ProcessGroupDummy(0, 1) wrapper = ProcessGroupWrapper(pg) - self.assertIs(wrapper.parent(), pg) + self.assertIs(wrapper.parent, pg) wrapper.configure("addr", 0, 1) self.assertEqual(pg.configure_count, 1) @@ -215,7 +215,7 @@ def test_process_group_wrapper(self) -> None: def test_error_swallowing_process_group_wrapper(self) -> None: pg = ProcessGroupDummy(0, 1) wrapper = ErrorSwallowingProcessGroupWrapper(pg) - self.assertIs(wrapper.parent(), pg) + self.assertIs(wrapper.parent, pg) t = torch.zeros(10) work = wrapper.allreduce([t], ReduceOp.SUM) diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi new file mode 100644 index 00000000..4076ba5d --- /dev/null +++ b/torchft/torchft.pyi @@ -0,0 +1,10 @@ +from datetime import timedelta +from typing import Tuple + +class ManagerClient: + def __init__(self, addr: str, timeout: timedelta) -> None: ... + def quorum( + self, rank: int, step: int, checkpoint_server_addr: str + ) -> Tuple[int, int, int, str, str, int, int, bool]: ... + def checkpoint_address(self, rank: int) -> str: ... + def should_commit(self, rank: int, step: int, should_commit: bool) -> bool: ...
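
Usage notes (editorial, not part of the patch): a minimal sketch of running
the new PYRE lint locally, assuming lintrunner, lintrunner-adapters, and the
pinned pyre-check==0.9.23 are installed via the init_command above:

    # one-time setup: runs each linter's init_command, including the pyre pin
    lintrunner init

    # mirror the CI lint job: run every lint except the pyre integration
    lintrunner --skip PYRE --force-color --all-files

    # mirror the CI pyre job: a one-shot, full-project type check
    pyre check

    # or invoke the adapter directly the way lintrunner does; it shells out
    # to `pyre incremental` (which keeps a pyre server running and uses
    # watchman, hence the new .watchmanconfig) and prints one JSON object
    # per lint message, e.g. (hypothetical values):
    #   {"path": "torchft/manager.py", "line": 42, "char": 9, "code": "pyre",
    #    "severity": "warning", "name": "Incompatible return type [7]", ...}
    python3 tools/linter/adapters/pyre_linter.py -- torchft/manager.py

Note that check_pyre runs pyre against the whole project rather than the
given paths, so the reported messages are not filtered to the files passed
by lintrunner.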