diff --git a/.gitignore b/.gitignore index cde8b6f6b..dce165df2 100644 --- a/.gitignore +++ b/.gitignore @@ -154,7 +154,7 @@ RuntimeProfiler* checkpoint-*/ data*/ !mindnlp/data/ -!mindnlp/core/utils/data/ +!mindtorch/utils/data/ !mindnlp/dataset/ !docs/api/data/ !data2vec/ diff --git a/mindtorch/_C/__init__.py b/mindtorch/_C/__init__.py index 6d4ae8b7f..2531c05ab 100644 --- a/mindtorch/_C/__init__.py +++ b/mindtorch/_C/__init__.py @@ -211,4 +211,10 @@ def _log_api_usage_once(*args): pass ScriptDict = dict -ScriptList = list \ No newline at end of file +ScriptList = list + +class _DistStoreError(RuntimeError): pass + +def _get_accelerator(): + device_target = mindspore.get_context("device_target") + return device_(DEVICE_MAP[device_target]) \ No newline at end of file diff --git a/mindtorch/_C/_distributed_c10d.py b/mindtorch/_C/_distributed_c10d.py new file mode 100644 index 000000000..cce5a1789 --- /dev/null +++ b/mindtorch/_C/_distributed_c10d.py @@ -0,0 +1,117 @@ +import pickle +from typing import List, Any +from datetime import timedelta + +import mindtorch +from mindtorch import Tensor +from mindtorch.distributed import Store, TCPStore +from mindtorch.distributed.c10d import Backend, ReduceOp + + +class ProcessGroup: + pass + +class ProcessGroupGloo(Backend): + def __init__( + self, + store: Store, + rank: int, + size: int, + timeout: timedelta + ) -> None: + super().__init__(rank, size) + self.store = store + self.ranks = [] + self.pg = None + + def name(self) -> str: + return 'gloo' + + def allreduce(self, tensors: List[Tensor], opts: Any) -> Any: + if mindtorch.distributed.is_initialized(): + self._allreduce_new_pg(tensors[0], opts) + else: + self._allreduce_use_store(tensors, opts) + + def _allreduce_new_pg(self, tensor, opts): + # Get all global ranks + if len(self.ranks) == 0: + rank_bytes = pickle.dumps(mindtorch.distributed.get_rank()) + self.store.set(f'__ar_rank_local_to_global_{self.rank_}', rank_bytes) + for local_rank in range(self.size_): + global_rank = pickle.loads(self.store.get(f'__ar_rank_local_to_global_{local_rank}')) + self.ranks.append(global_rank) + + if self.pg is None: + self.pg = mindtorch.distributed.new_group(self.ranks, backend='gloo') + + mindtorch.distributed.all_reduce(tensor, op=opts.reduceOp, group=self.pg, async_op=False) + + def _allreduce_use_store(self, tensors: List[Tensor], opts: Any) -> Any: + tensor = tensors[0] + tensor_bytes = pickle.dumps(tensor) + self.store.set(f'__ar_data_{self.rank_}', tensor_bytes) + + # Gather all tensors + gathered = [] + for i in range(self.size_): + data = self.store.get(f'__ar_data_{i}') + gathered.append(pickle.loads(data)) + stacked = mindtorch.stack(gathered) + + reduce_op = opts.reduceOp + if reduce_op == ReduceOp.SUM: + result = stacked.sum(dim=0) + elif reduce_op == ReduceOp.MAX: + if stacked.dtype == mindtorch.int32: + result = stacked.to(mindtorch.int64).max(dim=0).values.to(mindtorch.int32) + else: + result = stacked.max(dim=0).values + elif reduce_op == ReduceOp.MIN: + if stacked.dtype == mindtorch.int32: + result = stacked.to(mindtorch.int64).min(dim=0)[0].to(mindtorch.int32) + else: + result = stacked.min(dim=0)[0] + elif reduce_op == ReduceOp.PRODUCT: + result = stacked.prod(dim=0) + else: + raise ValueError(f'Unsupported reduce operation: {reduce_op}') + + tensors[0].copy_(result) + self._synchronize_and_cleanup() + + def _synchronize_and_cleanup(self): + if self.rank_ == 0: + # Wait for the completion of allreduce() execution for other ranks and remove the tensor_i key + # to prevent subsequent 
allreduce() exceptions. + for i in range(1, self.size_): + self.store.get(f'__ar_finish_1_{i}') + for i in range(self.size_): + self.store.delete_key(f'__ar_data_{i}') + self.store.delete_key(f'__ar_finish_1_{i}') + + # Ensure that other ranks wait for the deletion of tensor_i key to complete. + self.store.set('__ar_finish_all', '') + + # Ensure that rank 0 exits last to prevent errors in other ranks. + for i in range(1, self.size_): + self.store.get(f'__ar_finish_2_{i}') + self.store.delete_key(f'__ar_finish_2_{i}') + self.store.delete_key('__ar_finish_all') + else: + self.store.set(f'__ar_finish_1_{self.rank_}', '') + self.store.get('__ar_finish_all') + self.store.set(f'__ar_finish_2_{self.rank_}', '') + + def _set_sequence_number_for_group(self): + pass + + +class ProcessGroupHCCL: + def __init__(self, group_name): + self.group_name = group_name + + def get_hccl_comm_name(self, global_rank): + return self.group_name + + class Options: ... diff --git a/mindtorch/distributed/__init__.py b/mindtorch/distributed/__init__.py index c7e7006eb..7c95bfc59 100644 --- a/mindtorch/distributed/__init__.py +++ b/mindtorch/distributed/__init__.py @@ -55,7 +55,7 @@ def is_available() -> bool: # set_debug_level, # set_debug_level_from_env, Store, - # TCPStore, + TCPStore, Work as _Work, ) diff --git a/mindtorch/distributed/c10d/__init__.py b/mindtorch/distributed/c10d/__init__.py index 4549853f2..c1d872364 100644 --- a/mindtorch/distributed/c10d/__init__.py +++ b/mindtorch/distributed/c10d/__init__.py @@ -1,4 +1,4 @@ -from .store import Store +from .store import Store, TCPStore, FileStore from .prefix_store import PrefixStore from .types import * from .process_group import ProcessGroup diff --git a/mindtorch/distributed/c10d/store.py b/mindtorch/distributed/c10d/store.py index 7f5a20a77..bff340e68 100644 --- a/mindtorch/distributed/c10d/store.py +++ b/mindtorch/distributed/c10d/store.py @@ -1,6 +1,10 @@ -import time from typing import List, Optional, Callable from abc import ABC, abstractmethod +from datetime import timedelta +try: + from mindspore.mint.distributed.distributed import TCPStore as MsTCPStore +except: + MsTCPStore = None class Store: kDefaultTimeout = 300 # in seconds @@ -98,3 +102,42 @@ def __copy__(self): def __move__(self): raise NotImplementedError("Moving not allowed") + +class TCPStore(Store): + def __init__( + self, + host_name: str, + port: int, + world_size: Optional[int] = None, + is_master: bool = False, + timeout: timedelta = timedelta(seconds=300), + wait_for_workers: bool = True, + multi_tenant: bool = False, + master_listen_fd: Optional[int] = None, + use_libuv: bool = True + ) -> None: + super().__init__(timeout) + self.ms_store = MsTCPStore(host_name, port, world_size, is_master, timeout, wait_for_workers, multi_tenant, master_listen_fd, use_libuv) + + @property + def host(self) -> str: + return self.ms_store.host + + @property + def port(self) -> int: + return self.ms_store.port + + def set(self, key: str, value: str) -> None: + self.ms_store.set(key, value) + + def add(self, key: str, value: int) -> int: + return self.ms_store.add(key, value) + + def get(self, key: str) -> bytes: + return self.ms_store.get(key) + + def delete_key(self, key: str) -> bool: + return self.ms_store.delete_key(key) + +class FileStore(Store): + def __init__(self, path: str, numWorkers: int = ...): ... 
diff --git a/mindtorch/distributed/distributed_c10d.py b/mindtorch/distributed/distributed_c10d.py index 8148a11d2..54c1f2c19 100644 --- a/mindtorch/distributed/distributed_c10d.py +++ b/mindtorch/distributed/distributed_c10d.py @@ -1639,7 +1639,7 @@ def _new_process_group_helper( "created, please use a different group name" ) - if device_id is not None and (device_id.index is None or device_id.type != "cuda"): + if device_id is not None and (device_id.index is None): raise ValueError( "init_process_group device_id parameter must be a cuda device with an " "id, e.g. cuda:0, not just cuda or cpu" diff --git a/mindtorch/distributed/elastic/__init__.py b/mindtorch/distributed/elastic/__init__.py deleted file mode 100644 index 427e1745c..000000000 --- a/mindtorch/distributed/elastic/__init__.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env/python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" - -Torchelastic agent and user worker failover contract: - -**TL;DR;**: - -* TE(torchelastic) expects user workers to finish with the 5 minutes drift -* It is better to design DDP app to fail for all workers, rather than a single one. -* TE does not synchronize number of restarts between agents -* TE re-rendezvous does not trigger restart decrease -* When a single agent finishes its job(successfully or not), it will close rendezvous. - If other agents still have workers in progress, they will be terminated. -* Based on above, scale down does not work if at least single agent finishes the job. -* When Scale up is detected by agents, it will not decrease ``max_restarts`` - - -In general TE(torchelastic) can launch arbitrary user code, but there is some -clarifications need to be done around what failover mechanism torchelastic -provides and what failover mechanism it expects from user workers. - -Torchelastic currently supports DDP style applications. That means that -TE expects *ALL* workers finish approximately at the same time. In practice, -it is nearly to impossible to guarantee that all workers in arbitrary -DDP application finish at the time, so TE provides a finalization barrier -that waits for TIMEOUT(5 minutes) for worker finalization. - -**Worker Failure** - -When worker fails, TE will check the number of restarts -available, if there is more than 0 restarts, TE will start a new rendezvous -round and restart the worker process. New rendezvous round will other -TE agents to terminate their workers. - -.. note:: The TE agent does not synchronize restarts between themselves. - When a single agent performs restart, it will trigger a local ``max_restarts`` - decrease, other agent will not decrease their ``max_restarts``. - the user to run the distributed application locally on a dev host. - -A single worker failure can cause the whole cluster to fail: -If a single worker is constantly failing, it will cause the TE agent -``max_restarts`` to go to zero. This will cause an agent to finish its -work and close rendezvous. If there are any other workers on different -agents, they will be terminated. - - -**Re-Rendezvous** - -Re-rendezvous occurs when TE agents detect a new node -trying to joint a cluster. TE will not decrease ``max_restarts``. TE agents -will terminate its workers and start a new rendezvous round. 
- -Note about DynamicRendezvous(etcd-v2, c10d-experimental): If the rendezvous -has already max_nodes, the new node won't be added to the wait list right -away since there is no need to tear down a rendezvous that is already fully -utilized. The new node will wait until its timeout (600 secs by default) -and periodically check the number of participants. If the number becomes -less than max_nodes, it will be added to the wait list; otherwise, it will time out after 600 secs. - -*Scale up event*. When scale up event happens, torchelastic rendezvous -will detect that there are new nodes trying to join. Torchelastic agent -will stop all workers and perform re-rendezvous. Note: when scale up event -happens, *``max_restarts``* will *not* decrease. - -*Scale down event*. When scale down event happens, rendezvous will not -notify the torchelastic agent about it. If TE agent launched with ``max_restarts=0`` , -it relies on the underlying scheduler to handle job restart. If the ``max_restarts>0`` , -TE agent will terminate workers and start a new rdzv round, which is a *Scale up event*. - -""" diff --git a/mindtorch/distributed/elastic/agent/server/__init__.py b/mindtorch/distributed/elastic/agent/server/__init__.py deleted file mode 100644 index 93f4f128e..000000000 --- a/mindtorch/distributed/elastic/agent/server/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -The elastic agent is the control plane of torchelastic. - -It is a process that launches and manages underlying worker processes. -The agent is responsible for: - -1. Working with distributed torch: the workers are started with all the - necessary information to successfully and trivially call - ``mindtorch.distributed.init_process_group()``. - -2. Fault tolerance: monitors workers and upon detecting worker failures - or unhealthiness, tears down all workers and restarts everyone. - -3. Elasticity: Reacts to membership changes and restarts workers with the new - members. - -The simplest agents are deployed per node and works with local processes. -A more advanced agent can launch and manage workers remotely. Agents can -be completely decentralized, making decisions based on the workers it manages. -Or can be coordinated, communicating to other agents (that manage workers -in the same job) to make a collective decision. -""" - -from .api import ( # noqa: F401 - ElasticAgent, - RunResult, - SimpleElasticAgent, - Worker, - WorkerGroup, - WorkerSpec, - WorkerState, -) -from .local_elastic_agent import TORCHELASTIC_ENABLE_FILE_TIMER, TORCHELASTIC_TIMER_FILE diff --git a/mindtorch/distributed/elastic/agent/server/api.py b/mindtorch/distributed/elastic/agent/server/api.py deleted file mode 100644 index b9be1021e..000000000 --- a/mindtorch/distributed/elastic/agent/server/api.py +++ /dev/null @@ -1,957 +0,0 @@ -# mypy: ignore-errors - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import abc -import json -import os -import signal -import socket -import time -import traceback -import warnings -from collections import defaultdict -from contextlib import contextmanager -from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import mindtorch.distributed.elastic.rendezvous as rdzv -import mindtorch.distributed.elastic.utils.store as store_util -from mindtorch.distributed.elastic.events import Event, EventSource, record -from mindtorch.distributed.elastic.metrics import prof, put_metric -from mindtorch.distributed.elastic.multiprocessing import ProcessFailure, SignalException -from mindtorch.distributed.elastic.rendezvous import RendezvousGracefulExitError -from mindtorch.distributed.elastic.utils.logging import get_logger - - -__all__ = [ - "WorkerSpec", - "Worker", - "WorkerState", - "WorkerGroup", - "RunResult", - "ElasticAgent", - "SimpleElasticAgent", -] -_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state" - -DEFAULT_ROLE = "default" -logger = get_logger(__name__) - - -@dataclass -class WorkerSpec: - """Blueprint information about a particular type of worker. - - For a given role, there must only exist a single worker spec. - Worker spec is expected to be homogeneous across all nodes (machine), - that is each node runs the same number of workers for a particular spec. - - Args: - role: user-defined role for the workers with this spec - local_world_size: number local workers to run - fn: (deprecated use entrypoint instead) - entrypoint: worker function or command - args: arguments to pass to ``entrypoint`` - rdzv_handler: handles rdzv for this set of workers - max_restarts: number of max retries for the workers - monitor_interval: monitor status of workers every ``n`` seconds - master_port: fixed port to run the c10d store on rank 0 - if not specified then will chose a random free port - master_addr: fixed master_addr to run the c10d store on rank 0 - if not specified then will chose hostname on agent rank 0 - redirects: redirect std streams to a file, - selectively redirect for a particular - local rank by passing a map - tee: tees the specified std stream(s) to console + file, - selectively tee for a particular local rank by passing a map, - takes precedence over ``redirects`` settings. - - """ - - role: str - local_world_size: int - rdzv_handler: rdzv.RendezvousHandler - fn: Optional[Callable] = None - # TODO @kiuk - make entrypoint a required field - entrypoint: Union[Callable, str, None] = None - args: Tuple = () - max_restarts: int = 3 - monitor_interval: float = 0.1 - master_port: Optional[int] = None - master_addr: Optional[str] = None - local_addr: Optional[str] = None - - def __post_init__(self): - assert self.local_world_size > 0 - assert self.monitor_interval > 0 - - if self.fn: - warnings.warn( - "WorkerSpec.fn will be deprecated," - " please use WorkerSpec.entrypoint instead", - category=DeprecationWarning, - ) - self.entrypoint = self.fn - assert self.entrypoint - - def get_entrypoint_name(self): - """Get the entry point name. - - If the entrypoint is a function (e.g. ``Callable``) returns its ``__qualname__`` - else if the entrypoint is a binary (e.g. ``str``), returns the binary name. - """ - if isinstance(self.entrypoint, str): - return os.path.basename(self.entrypoint) - else: - assert self.entrypoint is not None - return self.entrypoint.__qualname__ - - -class Worker: - """A worker instance. 
- - Contrast this with ``WorkerSpec`` that represents the specifications of a - worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to - a ``WorkerSpec`` as an object is to a class. - - The ``id`` of the worker is interpreted - by the specific implementation of ``ElasticAgent``. For a local - agent, it could be the ``pid (int)`` of the worker, for a remote - agent it could be encoded as ``host:port (string)``. - - Args: - id (Any): uniquely identifies a worker (interpreted by the agent) - local_rank (int): local rank of the worker - global_rank (int): global rank of the worker - role_rank (int): rank of the worker across all workers that have the same role - world_size (int): number of workers (globally) - role_world_size (int): number of workers that have the same role - """ - - __slots__ = [ - "id", - "local_rank", - "global_rank", - "role_rank", - "world_size", - "role_world_size", - ] - - def __init__( - self, - local_rank: int, - global_rank: int = -1, - role_rank: int = -1, - world_size: int = -1, - role_world_size: int = -1, - ): - # unique identifier for this worker - self.id: Any = None - - # rank of the worker among workers with the same role being monitored - # by the same ``agent`` instance. - self.local_rank: int = local_rank - - # rank of the worker among all the workers across all roles - # across all ``agent`` instances. - # Global rank is not stable between re-rendezvous. - self.global_rank: int = global_rank - - # rank of the worker among all the workers with the same role - # across all ``agent`` instances. - # Role rank is not stable between re-rendezvous. - self.role_rank: int = role_rank - - # total number of workers (globally). Due to elasticity - # the world size may change between re-rendezvous. - self.world_size: int = world_size - - # total number of workers that share the same role. Due to elasticity - # the role world size may change between re-rendezvous. - self.role_world_size: int = role_world_size - - def __str__(self): - return ( - f"local_rank={self.local_rank},global_rank={self.global_rank}" - f",role_rank={self.role_rank},world_size={self.world_size}" - f",role_world_size={self.role_world_size}" - ) - - def __repr__(self): - return str(self) - - -class WorkerState(str, Enum): - """A state of the ``WorkerGroup``. - - Workers in a worker group change state as a unit. If a single worker - in a worker group fails the entire set is considered failed:: - - UNKNOWN - agent lost track of worker group state, unrecoverable - INIT - worker group object created not yet started - HEALTHY - workers running and healthy - UNHEALTHY - workers running and unhealthy - STOPPED - workers stopped (interrupted) by the agent - SUCCEEDED - workers finished running (exit 0) - FAILED - workers failed to successfully finish (exit !0) - - - A worker group starts from an initial ``INIT`` state, - then progresses to ``HEALTHY`` or ``UNHEALTHY`` states, - and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state. - - Worker groups can be interrupted and temporarily put into ``STOPPED`` state - by the agent. Workers in ``STOPPED`` state are scheduled to be restarted - in the near future by the agent. Some examples of workers being put into - ``STOPPED`` state are: - - 1. Worker group failure|unhealthy observed - 2. Membership change detected - - When actions (start, stop, rdzv, retry, etc) on worker group fails - and results in the action being partially applied to the worker group - the state will be ``UNKNOWN``. 
Typically this happens on uncaught/unhandled - exceptions during state change events on the agent. The agent is not - expected to recover worker groups in ``UNKNOWN`` state and is better off - self terminating and allowing the job manager to retry the node. - """ - - UNKNOWN = "UNKNOWN" - INIT = "INIT" - HEALTHY = "HEALTHY" - UNHEALTHY = "UNHEALTHY" - STOPPED = "STOPPED" - SUCCEEDED = "SUCCEEDED" - FAILED = "FAILED" - - @staticmethod - def is_running(state: "WorkerState") -> bool: - """Return the state of the Worker. - - Returns: - True if the worker state represents workers still running - (e.g. that the process exists but not necessarily healthy). - """ - return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY} - - -class WorkerGroup: - """A set of ``Worker`` instances. - - The class defines a set of ``Worker`` instances for the given ``WorkerSpec`` managed by ``ElasticAgent``. Whether the worker - group contains cross instance workers or not depends on the implementation of the agent. - """ - - __slots__ = [ - "spec", - "workers", - "store", - "group_rank", - "group_world_size", - "state", - "master_addr", - "master_port", - ] - - def __init__(self, spec: WorkerSpec): - self.spec = spec - self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)] - - # assigned after rdzv - self.store = None - self.group_rank = None - self.group_world_size = None - self.master_addr = None - self.master_port = None - - self.state = WorkerState.INIT - - -class _RoleInstanceInfo: - """The class is used by the agent to exchange the information with other agents. - - The information is used to determine the rank of the workers that agent - manages in heterogeneous environments, where different agents can have - different number of workers. - """ - - __slots__ = ["role", "rank", "local_world_size"] - - def __init__(self, role: str, rank: int, local_world_size: int): - r"""Initialize the agent class instance. - - Args: - role (str): user-defined role for the workers with this spec - rank (int): the rank of the agent - local_world_size (int): number of local workers to run - """ - self.role = role - self.rank = rank - self.local_world_size = local_world_size - - def serialize(self) -> bytes: - dict_data = { - "role": self.role, - "rank": self.rank, - "local_world_size": self.local_world_size, - } - return json.dumps(dict_data).encode(encoding="UTF-8") - - @staticmethod - def deserialize(data: bytes): - dict_data = json.loads(data.decode(encoding="UTF-8")) - return _RoleInstanceInfo( - dict_data["role"], dict_data["rank"], dict_data["local_world_size"] - ) - - @staticmethod - def compare(obj1, obj2) -> int: - if obj1.role == obj2.role: - return obj1.rank - obj2.rank - elif obj1.role > obj2.role: - return 1 - else: - return -1 - - @staticmethod - def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]: - start_idx, end_idx = -1, -1 - for idx, role_info in enumerate(roles_infos): - if role_info.role == role: - if start_idx == -1: - start_idx = idx - end_idx = idx - return (start_idx, end_idx) - - -@dataclass -class RunResult: - """Return results of the worker executions. - - Run results follow an "all-or-nothing" policy where the run is successful if and - only if ALL local workers managed by this agent complete successfully. - - If the result is successful (e.g. ``is_failed() = False``) then the ``return_values`` - field contains the outputs (return values) of the workers managed by THIS agent mapped - by their GLOBAL ranks. 
That is ``result.return_values[0]`` is the return value of - global rank 0. - - .. note:: ``return_values`` are only meaningful for when the worker entrypoint - is a function. Workers specified as a binary entrypoint do not canonically - have a return value and the ``return_values`` field is meaningless and - may be empty. - - If ``is_failed()`` returns ``True`` then the ``failures`` field contains the - failure information, again, mapped by the GLOBAL rank of the worker that failed. - - The keys in ``return_values`` and ``failures`` are mutually exclusive, that is, - a worker's final state can only be one of: succeeded, failed. Workers intentionally - terminated by the agent according to the agent's restart policy, are not represented - in either ``return_values`` nor ``failures``. - """ - - state: WorkerState - return_values: Dict[int, Any] = field(default_factory=dict) - failures: Dict[int, ProcessFailure] = field(default_factory=dict) - - def is_failed(self) -> bool: - return self.state == WorkerState.FAILED - - -def _get_fq_hostname() -> str: - return socket.getfqdn(socket.gethostname()) - - -class ElasticAgent(abc.ABC): - """An agent process responsible for managing one or more worker processes. - - The worker processes are assumed to be regular distributed PyTorch scripts. - When the worker process is created by the agent, the agent provides the - necessary information for the worker processes to properly initialize - a torch process group. - - The exact deployment topology and ratio of agent-to-worker is dependent - on the specific implementation of the agent and the user's job placement - preferences. For instance, to run a distributed training job on GPU with - 8 trainers (one per GPU) one can: - - 1. Use 8 x single GPU instances, place an agent per instance, managing - 1 worker per agent. - 2. Use 4 x double GPU instances, place an agent per instance, managing - 2 workers per agent. - 3. Use 2 x quad GPU instances, place an agent per instance, managing - 4 workers per agent. - 4. Use 1 x 8 GPU instance, place an agent per instance, managing - 8 workers per agent. - - Usage - :: - - group_result = agent.run() - if group_result.is_failed(): - # workers failed - failure = group_result.failures[0] - logger.exception("worker 0 failed with exit code : %s", failure.exit_code) - else: - return group_result.return_values[0] # return rank 0's results - - """ - - @abc.abstractmethod - def run(self, role: str = DEFAULT_ROLE) -> RunResult: - """Run the agent. - - Supports retrying the worker group on failures up to ``max_restarts``. - - Returns: - The result of the execution, containing the return values or - failure details for each worker mapped by the worker's global rank. - - Raises: - Exception - any other failures NOT related to worker process - """ - raise NotImplementedError - - @abc.abstractmethod - def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: - """Return the ``WorkerGroup`` for the given ``role``. - - Note that the worker group is a mutable object and hence in a - multi-threaded/process environment it may change state. - Implementors are encouraged (but not required) to return - a defensive read-only copy. - """ - raise NotImplementedError - - -class SimpleElasticAgent(ElasticAgent): - """An ``ElasticAgent`` that manages one particular type of worker role. - - An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec`` - such as one particular type of worker role. 
- """ - - def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300): - self._worker_group = WorkerGroup(spec) - self._remaining_restarts = self._worker_group.spec.max_restarts - self._store = None - self._exit_barrier_timeout = exit_barrier_timeout - self._total_execution_time = 0 - - def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: - return self._worker_group - - @abc.abstractmethod - def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: - r"""Start ``worker_group.spec.local_world_size`` number of workers. - - This is according to worker spec for the worker group . - Returns a map of ``local_rank`` to worker ``id``. - """ - raise NotImplementedError - - @abc.abstractmethod - def _stop_workers( - self, worker_group: WorkerGroup, is_restart: bool = False - ) -> None: - r"""Stop all workers in the given worker group. - - Implementors must deal with workers in all states defined by - ``WorkerState``. That is, it must gracefully handle stopping - non-existent workers, unhealthy (stuck) workers, etc. - """ - raise NotImplementedError - - @abc.abstractmethod - def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: - r"""Check on the workers for the ``worker_group``. - - This function also returns the new state of the worker group. - """ - raise NotImplementedError - - @abc.abstractmethod - def _shutdown( - self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False - ) -> None: - """Clean up any resources that were allocated during the agent's work. - - Args: - death_sig: Signal to send to the child process, SIGTERM is default - """ - raise NotImplementedError - - @prof - def _rendezvous(self, worker_group: WorkerGroup) -> None: - r"""Run rendezvous for the workers specified by the worker spec. - - Assigns workers a new global rank and world size. - Updates the rendezvous store for the worker group. - """ - spec = worker_group.spec - - with self.record_duration("RENDEZVOUS"): - rdzv_info = spec.rdzv_handler.next_rendezvous() - store = rdzv_info.store - group_rank = rdzv_info.rank - group_world_size = rdzv_info.world_size - - # master_addr/master_port could be explicitly overriden - # TODO: BC - specific to static rdzv and can be simplifed further - master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr - master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port - - self._store = store - - with self.record_duration("ASSIGN_WORKER_RANKS"): - workers = self._assign_worker_ranks( - store, group_rank, group_world_size, spec - ) - worker_group.workers = workers - worker_group.store = store - worker_group.group_rank = group_rank - worker_group.group_world_size = group_world_size - worker_group.master_addr = master_addr - worker_group.master_port = master_port - - restart_count = spec.max_restarts - self._remaining_restarts - - logger.info( - "[%(role)s] Rendezvous complete for workers. 
Result:\n" - " restart_count=%(restart_count)s\n" - " master_addr=%(master_addr)s\n" - " master_port=%(master_port)s\n" - " group_rank=%(group_rank)s\n" - " group_world_size=%(group_world_size)s\n" - " local_ranks=%(local_ranks)s\n" - " role_ranks=%(role_ranks)s\n" - " global_ranks=%(global_ranks)s\n" - " role_world_sizes=%(role_world_sizes)s\n" - " global_world_sizes=%(global_world_sizes)s\n", - { - "role": spec.role, - "restart_count": restart_count, - "master_addr": master_addr, - "master_port": master_port, - "group_rank": group_rank, - "group_world_size": group_world_size, - "local_ranks": [worker.local_rank for worker in workers], - "role_ranks": [worker.role_rank for worker in workers], - "global_ranks": [worker.global_rank for worker in workers], - "role_world_sizes": [worker.role_world_size for worker in workers], - "global_world_sizes": [worker.world_size for worker in workers], - }, - ) - - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator - # `mindtorch.distributed.elastic.metrics.prof`. - @prof - def _assign_worker_ranks( - self, store, group_rank: int, group_world_size: int, spec: WorkerSpec - ) -> List[Worker]: - """Determine proper ranks for worker processes. - - Fast Path: when all workers have the same role and world size. We calculate - the global rank to be group_rank * group_world_size + local_rank. And the - `role_world_size` is the same as `global_world_size`. No TCP store is used in - this case. This is only enabled when users set the environment variable - `TORCH_ELASTIC_WORKER_IDENTICAL` to 1. - - Time complexity: each worker O(1), overall O(1) - - Slow Path: when workers have different roles and world sizes. We use the - the following algorithm: - - 1. Each agent writes its configuration(group_rank, group_world_size - , num_workers) to the common store. - 2. The rank 0 agent reads all the role_info from the store and - determines each agents worker ranks. - 3. Determine the global rank: the global rank of the workers is computed - by cumulative sum of the local_world_size for all workers in front of it. - For efficiency reasons each worker is assigned a base global rank - such that it's workers are in the range [base_global_rank, - base_global_rank + local_world_size). - 4. Determine the role rank: The role rank is determined using the algorithms - in the point 3 with the exception that the ranks are calculated with - respect to the role name. - 5. The rank 0 agent writes the assigned ranks to the store. - 6. Each agent reads the assigned ranks from the store. - - Time complexity: each worker O(1), rank0 O(n), overall O(n) - """ - - if os.environ.get("TORCH_ELASTIC_WORKER_IDENTICAL", "0") == "1": - global_world_size = group_world_size * spec.local_world_size - base_global_rank = group_rank * spec.local_world_size - base_role_rank = base_global_rank - role_world_size = global_world_size - else: - ROLE_INFO_PREFIX = "torchelastic/role_info/" - ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/" - - agent_role_info = _RoleInstanceInfo( - spec.role, group_rank, spec.local_world_size - ) - store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize()) - - # tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations. 
- if group_rank == 0: - role_infos_bytes = store.multi_get( - [f"torchelastic/role_info/{i}" for i in range(group_world_size)] - ) - role_infos = [ - _RoleInstanceInfo.deserialize(info_bytes) - for info_bytes in role_infos_bytes - ] - - role_sizes = defaultdict(lambda: 0) - global_size = 0 - for role_info in role_infos: - role_sizes[role_info.role] += role_info.local_world_size - global_size += role_info.local_world_size - - base_global_rank = 0 - role_ranks = defaultdict(lambda: 0) - - keys = [] - values = [] - for i, role_info in enumerate(role_infos): - keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}") - values.append( - json.dumps( - [ - base_global_rank, - global_size, - role_ranks[role_info.role], - role_sizes[role_info.role], - ] - ) - ) - - base_global_rank += role_info.local_world_size - role_ranks[role_info.role] += role_info.local_world_size - - store.multi_set(keys, values) - - # get will block until the data is available in the store. - ( - base_global_rank, - global_world_size, - base_role_rank, - role_world_size, - ) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")) - - workers = [] - for local_rank in range(spec.local_world_size): - worker = Worker( - local_rank=local_rank, - global_rank=base_global_rank + local_rank, - role_rank=base_role_rank + local_rank, - world_size=global_world_size, - role_world_size=role_world_size, - ) - workers.append(worker) - return workers - - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator - # `mindtorch.distributed.elastic.metrics.prof`. - @prof - def _initialize_workers(self, worker_group: WorkerGroup) -> None: - r"""Start a fresh set of workers for the worker_group. - - Essentially, a rendezvous followed by a ``start_workers``. - The caller should first call ``_stop_workers()`` to stop running workers - prior to calling this method. - - Optimistically sets the state of the worker group that - just started as ``HEALTHY`` and delegates the actual monitoring - of state to ``_monitor_workers()`` method - """ - role = worker_group.spec.role - logger.info("[%s] Rendezvous'ing worker group", role) - - # TODO after stopping workers, wait at least monitor_interval*2 for - # workers on different nodes to fail on a collective op before waiting - # on the rdzv barrier, this way we ensure that nodes enter rdzv - # at around the same time and reduce false positive rdzv timeout errors - self._rendezvous(worker_group) - - logger.info("[%s] Starting worker group", role) - worker_ids = self._start_workers(worker_group) - for local_rank, w_id in worker_ids.items(): - worker = worker_group.workers[local_rank] - worker.id = w_id - - worker_group.state = WorkerState.HEALTHY - - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator - # `mindtorch.distributed.elastic.metrics.prof`. - @prof - def _restart_workers(self, worker_group: WorkerGroup) -> None: - """Restart (stops, rendezvous, starts) all local workers in the group.""" - role = worker_group.spec.role - logger.info("[%s] Stopping worker group", role) - self._stop_workers(worker_group, is_restart=True) - worker_group.state = WorkerState.STOPPED - self._initialize_workers(worker_group) - - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator - # `mindtorch.distributed.elastic.metrics.prof`. 
- @prof - def run(self, role: str = DEFAULT_ROLE) -> RunResult: - start_time = time.monotonic() - shutdown_called: bool = False - try: - result = self._invoke_run(role) - self._total_execution_time = int(time.monotonic() - start_time) - self._record_metrics(result) - self._record_worker_events(result) - return result - except RendezvousGracefulExitError as e: - logger.info("Rendezvous gracefully exited: %s", e) - except SignalException as e: - logger.warning("Received %s death signal, shutting down workers", e.sigval) - self._shutdown(e.sigval) - shutdown_called = True - raise - finally: - if not shutdown_called: - self._shutdown() - # record the execution time in case there were any exceptions during run. - self._total_execution_time = int(time.monotonic() - start_time) - - def get_event_failed(self) -> Event: - return self._construct_event( - state="FAILED", - source=EventSource.AGENT, - raw_error=traceback.format_exc(), - ) - - def get_event_succeeded(self) -> Event: - return self._construct_event( - state="SUCCEEDED", - source=EventSource.AGENT, - ) - - def _record_worker_events(self, result: RunResult) -> None: - for worker in self._worker_group.workers: - failure = result.failures.get(worker.global_rank) - state: str = self._get_worker_state(worker, result) - raw_error = json.dumps(failure.error_file_data) if failure else None - record(self._construct_event(state, EventSource.WORKER, worker, raw_error)) - - def _get_worker_state(self, worker: Worker, result: RunResult) -> str: - failure = result.failures.get(worker.global_rank) - if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure: - # The worker got terminated by the torchelastic agent via SIGTERM signal - return "TERMINATED" - elif failure or worker.global_rank in result.return_values: - return result.state.value - else: - raise ValueError(f"Unknown worker: {worker.global_rank}") - - @contextmanager - def record_duration(self, state: str): - start_time = time.perf_counter() - try: - yield - finally: - end_time = time.perf_counter() - duration_ms = (end_time - start_time) * 1000 - record( - self._construct_event( - state=state, source=EventSource.AGENT, duration_ms=duration_ms - ) - ) - - def _construct_event( - self, - state: str, - source: EventSource, - worker: Optional[Worker] = None, - raw_error: Optional[str] = None, - duration_ms: Optional[float] = None, - ) -> Event: - wg = self._worker_group - spec = wg.spec - md = { - "group_world_size": wg.group_world_size, - "entry_point": spec.get_entrypoint_name(), - } - if worker: - md["local_rank"] = (worker.local_rank,) - md["role_rank"] = (worker.role_rank,) - md["role_world_size"] = (worker.role_world_size,) - global_rank = worker.global_rank - worker_id = str(worker.id) - else: - global_rank = None - worker_id = None - md_str = json.dumps(md) - metadata = { - "run_id": spec.rdzv_handler.get_run_id(), - "global_rank": global_rank, - "group_rank": wg.group_rank, - "worker_id": worker_id, - "role": spec.role, - "hostname": _get_fq_hostname(), - "state": state, - "total_run_time": self._total_execution_time, - "rdzv_backend": spec.rdzv_handler.get_backend(), - "raw_error": raw_error, - "metadata": md_str, - "agent_restarts": spec.max_restarts - self._remaining_restarts, - "duration_ms": duration_ms, - } - return Event( - f"torchelastic.worker.status.{state}", source=source, metadata=metadata - ) - - def _record_metrics(self, group_results: RunResult): - is_failed = group_results.is_failed() - self._record_flakiness_metric(is_failed) - spec = 
self._worker_group.spec - restarts_happened = self._remaining_restarts != spec.max_restarts - put_metric(f"workers.{spec.role}.run_total", 1) - self._record_metric_with_condition( - "run_success_with_retries", not is_failed and restarts_happened - ) - self._record_metric_with_condition( - "run_success_no_retries", not is_failed and not restarts_happened - ) - self._record_metric_with_condition( - "run_failed_with_retries", is_failed and restarts_happened - ) - self._record_metric_with_condition( - "run_failed_no_retries", is_failed and not restarts_happened - ) - - def _record_metric_with_condition(self, metric_name, condition): - spec = self._worker_group.spec - if condition: - put_metric(f"workers.{spec.role}.{metric_name}", 1) - else: - put_metric(f"workers.{spec.role}.{metric_name}", 0) - - def _record_flakiness_metric(self, is_failed: bool = False): - if is_failed: - flakiness = 100.0 - else: - spec = self._worker_group.spec - flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / ( - spec.max_restarts + 1 - ) - spec = self._worker_group.spec - - put_metric(f"workers.{spec.role}.flakiness", int(flakiness)) - - def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: - # NOTE: currently only works for a single role - - spec = self._worker_group.spec - role = spec.role - - logger.info( - "[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name() - ) - - self._initialize_workers(self._worker_group) - monitor_interval = spec.monitor_interval - rdzv_handler = spec.rdzv_handler - - while True: - assert self._worker_group.state != WorkerState.INIT - time.sleep(monitor_interval) - run_result = self._monitor_workers(self._worker_group) - state = run_result.state - self._worker_group.state = state - - put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts) - put_metric(f"workers.{role}.{state.name.lower()}", 1) - - if state == WorkerState.SUCCEEDED: - logger.info( - "[%s] worker group successfully finished." - " Waiting %s seconds for other agents to finish.", - role, - self._exit_barrier_timeout, - ) - self._exit_barrier() - return run_result - elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}: - if self._remaining_restarts > 0: - logger.info( - "[%s] Worker group %s. " - "%s/%s attempts left;" - " will restart worker group", - role, - state.name, - self._remaining_restarts, - spec.max_restarts, - ) - self._remaining_restarts -= 1 - self._restart_workers(self._worker_group) - else: - self._stop_workers(self._worker_group) - self._worker_group.state = WorkerState.FAILED - return run_result - elif state == WorkerState.HEALTHY: - # membership changes do not count as retries - num_nodes_waiting = rdzv_handler.num_nodes_waiting() - group_rank = self._worker_group.group_rank - if num_nodes_waiting > 0: - logger.info( - "[%s] Detected %s " - "new nodes from group_rank=%s; " - "will restart worker group", - role, - num_nodes_waiting, - group_rank, - ) - self._restart_workers(self._worker_group) - else: - raise Exception( # noqa: TRY002 - f"[{role}] Worker group in {state.name} state" - ) - - def _exit_barrier(self): - """ - Define a barrier that keeps the agent process alive until all workers finish. - - Wait for ``exit_barrier_timeout`` seconds for all agents to finish - executing their local workers (either successfully or not). This - acts as a safety guard against user scripts that terminate at different - times. - """ - logger.info( - "Local worker group finished (%s). 
" - "Waiting %s seconds for other agents to finish", - self._worker_group.state, - self._exit_barrier_timeout, - ) - start = time.time() - try: - store_util.barrier( - store=self._store, - world_size=self._worker_group.group_world_size, - key_prefix=_TERMINAL_STATE_SYNC_ID, - barrier_timeout=self._exit_barrier_timeout, - ) - logger.info( - "Done waiting for other agents. Elapsed: %s seconds", - time.time() - start, - ) - except SignalException as e: - logger.warning("Got termination signal: %s", e.sigval) - raise - except Exception: - logger.exception( - "Error waiting on exit barrier. Elapsed: %s seconds", - time.time() - start, - ) diff --git a/mindtorch/distributed/elastic/agent/server/health_check_server.py b/mindtorch/distributed/elastic/agent/server/health_check_server.py deleted file mode 100644 index 0587758e5..000000000 --- a/mindtorch/distributed/elastic/agent/server/health_check_server.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Callable - -from mindtorch.distributed.elastic.utils.logging import get_logger - - -log = get_logger(__name__) - -__all__ = ["HealthCheckServer", "create_healthcheck_server"] - - -class HealthCheckServer: - """ - Interface for health check monitoring server, which can be extended - by starting tcp/http server on the specified port. - - Args: - - alive_callback: Callable[[], int], callback to last progress time of agent - - port: int, port number to start tcp/http server - - timeout: int, timeout seconds to decide agent is alive/dead - """ - - _alive_callback: Callable[[], int] - _port: int - _timeout: int - - def __init__( - self, alive_callback: Callable[[], int], port: int, timeout: int - ) -> None: - self._alive_callback = alive_callback - self._port = port - self._timeout = timeout - - def start(self) -> None: - """ - Unsupported functionality for Pytorch, doesn't start any health check server - """ - log.warning("No health check server started") - - def stop(self) -> None: - """ - Function to stop health check server - """ - log.info("Stopping noop health check server.") - - -def create_healthcheck_server( - alive_callback: Callable[[], int], - port: int, - timeout: int, -) -> HealthCheckServer: - """ - creates health check server object - """ - return HealthCheckServer(alive_callback, port, timeout) diff --git a/mindtorch/distributed/elastic/agent/server/local_elastic_agent.py b/mindtorch/distributed/elastic/agent/server/local_elastic_agent.py deleted file mode 100644 index 9c84e97f3..000000000 --- a/mindtorch/distributed/elastic/agent/server/local_elastic_agent.py +++ /dev/null @@ -1,417 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- - -import json -import os -import signal -import socket -import time -import uuid -from string import Template -from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING - -import mindtorch.distributed.elastic.timer as timer -from mindtorch.distributed.elastic import events -from mindtorch.distributed.elastic.agent.server.api import ( - RunResult, - SimpleElasticAgent, - WorkerGroup, - WorkerSpec, - WorkerState, -) -from mindtorch.distributed.elastic.agent.server.health_check_server import ( - create_healthcheck_server, - HealthCheckServer, -) -from mindtorch.distributed.elastic.metrics.api import prof -from mindtorch.distributed.elastic.multiprocessing import ( - LogsSpecs, - PContext, - start_processes, -) -from mindtorch.distributed.elastic.utils import macros -from mindtorch.distributed.elastic.utils.logging import get_logger - - -if TYPE_CHECKING: - from mindtorch.distributed.elastic.events.api import EventMetadataValue - -logger = get_logger(__name__) - -__all__ = [ - "LocalElasticAgent", - "TORCHELASTIC_ENABLE_FILE_TIMER", - "TORCHELASTIC_TIMER_FILE", - "TORCHELASTIC_HEALTH_CHECK_PORT", -] - -TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER" -TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT" -TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE" - - -class LocalElasticAgent(SimpleElasticAgent): - """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers. - - This agent is deployed per host and is configured to spawn ``n`` workers. - When using GPUs, ``n`` maps to the number of GPUs available on the host. - - The local agent does not communicate to other local agents deployed on - other hosts, even if the workers may communicate inter-host. The worker id - is interpreted to be a local process. The agent starts and stops all worker - processes as a single unit. - - - The worker function and argument passed to the worker function must be - python multiprocessing compatible. To pass multiprocessing data structures - to the workers you may create the data structure in the same multiprocessing - context as the specified ``start_method`` and pass it as a function argument. - - The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait - for other agents to finish. This acts as a safety net to handle cases where - workers finish at different times, to prevent agents from viewing workers - that finished early as a scale-down event. It is strongly advised that the - user code deal with ensuring that workers are terminated in a synchronous - manner rather than relying on the exit_barrier_timeout. - - A named pipe based watchdog can be enabled in ```LocalElasticAgent``` if an - environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has - been defined in the ```LocalElasticAgent``` process. - Optionally, another environment variable ```TORCHELASTIC_TIMER_FILE``` - can be set with a unique file name for the named pipe. If the environment - variable ```TORCHELASTIC_TIMER_FILE``` is not set, ```LocalElasticAgent``` - will internally create a unique file name and set it to the environment - variable ```TORCHELASTIC_TIMER_FILE```, and this environment variable will - be propagated to the worker processes to allow them to connect to the same - named pipe that ```LocalElasticAgent``` uses. - - Logs are written to the specified log directory. Each log line will be by default - prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``). 
- Log prefixes can be customized by passing a `template string - `_ as the - ``log_line_prefix_template`` argument. - The following macros (identifiers) are substituted at runtime: - ``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with - global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:``. - - - Example launching function - - :: - - def trainer(args) -> str: - return "do train" - - def main(): - start_method="spawn" - shared_queue= multiprocessing.get_context(start_method).Queue() - spec = WorkerSpec( - role="trainer", - local_world_size=nproc_per_process, - entrypoint=trainer, - args=("foobar",), - ...) - agent = LocalElasticAgent(spec, start_method) - results = agent.run() - - if results.is_failed(): - print("trainer failed") - else: - print(f"rank 0 return value: {results.return_values[0]}") - # prints -> rank 0 return value: do train - - Example launching binary - - :: - - def main(): - spec = WorkerSpec( - role="trainer", - local_world_size=nproc_per_process, - entrypoint="/usr/local/bin/trainer", - args=("--trainer-args", "foobar"), - ...) - agent = LocalElasticAgent(spec) - results = agent.run() - - if not results.is_failed(): - print("binary launches do not have return values") - - """ - - def __init__( - self, - spec: WorkerSpec, - logs_specs: LogsSpecs, - start_method="spawn", - exit_barrier_timeout: float = 300, - log_line_prefix_template: Optional[str] = None, - ): - super().__init__(spec, exit_barrier_timeout) - self._start_method = start_method - self._pcontext: Optional[PContext] = None - self._rdzv_handler = spec.rdzv_handler - self._log_line_prefix_template = log_line_prefix_template - self._worker_watchdog: Optional[timer.FileTimerServer] = None - self._logs_specs = logs_specs - self._health_check_server: Optional[HealthCheckServer] = None - - def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None: - enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER - watchdog_enabled = os.getenv(enable_watchdog_env_name) - watchdog_file_env_name = TORCHELASTIC_TIMER_FILE - watchdog_file_path = os.getenv(watchdog_file_env_name) - if watchdog_enabled is not None and str(watchdog_enabled) == "1": - if watchdog_file_path is None: - watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4()) - logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path) - if not envs: - logger.warning( - "Empty envs variables, using empty run_id for FileTimerServer" - ) - run_id = "" - else: - run_id = envs[0]["TORCHELASTIC_RUN_ID"] - self._worker_watchdog = timer.FileTimerServer( - file_path=watchdog_file_path, - run_id=run_id, - max_interval=0.1, - daemon=True, - log_event=self._log_watchdog_event, - ) - self._worker_watchdog.start() - logger.info("FileTimerServer started") - else: - logger.info( - "Environment variable '%s' not found. 
Do not start FileTimerServer.", - enable_watchdog_env_name, - ) - # Propagate the watchdog file env to worker processes - if watchdog_file_path is not None: - for worker_env in envs.values(): - worker_env[watchdog_file_env_name] = watchdog_file_path - - @staticmethod - def _get_current_time_secs() -> int: - return int(time.time()) - - def _setup_healthcheck(self) -> None: - healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT - healthcheck_port = os.getenv(healthcheck_port_env_name) - if healthcheck_port is not None: - logger.info( - "Found healthcheck port %s: %s", - healthcheck_port_env_name, - healthcheck_port, - ) - if self._worker_watchdog is None: - logger.info( - "FileTimerServer doesn't exist, using current time as dummy callback" - ) - alive_callback = LocalElasticAgent._get_current_time_secs - else: - alive_callback = self._worker_watchdog.get_last_progress_time - - try: - healthcheck_port_as_int = int(healthcheck_port) - self._health_check_server = create_healthcheck_server( - alive_callback=alive_callback, - port=healthcheck_port_as_int, - timeout=60, - ) - self._health_check_server.start() - except ValueError: - logger.info( - "Invalid healthcheck port value: '%s', expecting integer. Not starting healthcheck server.", - healthcheck_port, - ) - else: - logger.info( - "Environment variable '%s' not found. Do not start health check.", - healthcheck_port_env_name, - ) - - def _get_fq_hostname(self) -> str: - return socket.getfqdn(socket.gethostname()) - - def _log_watchdog_event( - self, - name: str, - request: Optional[timer.FileTimerRequest], - ) -> None: - wg = self._worker_group - spec = wg.spec - md = {"watchdog_event": name} - if request is not None: - md["worker_pid"] = str(request.worker_pid) - md["scope_id"] = request.scope_id - md["expiration_time"] = str(request.expiration_time) - md["signal"] = str(request.signal) - md_str = json.dumps(md) - state = "RUNNING" - metadata: Dict[str, EventMetadataValue] = { - "run_id": spec.rdzv_handler.get_run_id(), - "global_rank": None, - "group_rank": wg.group_rank, - "worker_id": None, - "role": spec.role, - "hostname": self._get_fq_hostname(), - "state": state, - "total_run_time": self._total_execution_time, - "rdzv_backend": spec.rdzv_handler.get_backend(), - "raw_error": None, - "metadata": md_str, - "agent_restarts": spec.max_restarts - self._remaining_restarts, - } - # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later. - # The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry. - event = events.Event( - name=name, source=events.EventSource.AGENT, metadata=metadata - ) - events.record(event) - - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator - # `mindtorch.distributed.elastic.metrics.prof`. - @prof - def _stop_workers( - self, worker_group: WorkerGroup, is_restart: bool = False - ) -> None: - self._shutdown(is_restart=is_restart) - - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator - # `mindtorch.distributed.elastic.metrics.prof`. 
- @prof - def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: - spec = worker_group.spec - store = worker_group.store - assert store is not None - restart_count = spec.max_restarts - self._remaining_restarts - - use_agent_store: bool = spec.rdzv_handler.use_agent_store - logger.info("use_agent_store: %s", use_agent_store) - - args: Dict[int, Tuple] = {} - envs: Dict[int, Dict[str, str]] = {} - log_line_prefixes: Optional[Dict[int, str]] = ( - {} if self._log_line_prefix_template else None - ) - for worker in worker_group.workers: - local_rank = worker.local_rank - worker_env = { - "LOCAL_RANK": str(local_rank), - "RANK": str(worker.global_rank), - "GROUP_RANK": str(worker_group.group_rank), - "ROLE_RANK": str(worker.role_rank), - "ROLE_NAME": spec.role, - "LOCAL_WORLD_SIZE": str(spec.local_world_size), - "WORLD_SIZE": str(worker.world_size), - "GROUP_WORLD_SIZE": str(worker_group.group_world_size), - "ROLE_WORLD_SIZE": str(worker.role_world_size), - "MASTER_ADDR": worker_group.master_addr, - "MASTER_PORT": str(worker_group.master_port), - "TORCHELASTIC_RESTART_COUNT": str(restart_count), - "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts), - "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(), - "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store), - "TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv( - "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1) - ), - } - if "OMP_NUM_THREADS" in os.environ: - worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] - - if self._log_line_prefix_template: - log_line_prefix = Template( - self._log_line_prefix_template - ).safe_substitute( - role_name=spec.role, - rank=worker.global_rank, - local_rank=local_rank, - ) - log_line_prefixes[local_rank] = log_line_prefix - - envs[local_rank] = worker_env - worker_args = list(spec.args) - worker_args = macros.substitute(worker_args, str(local_rank)) - args[local_rank] = tuple(worker_args) - - self._setup_local_watchdog(envs=envs) - self._setup_healthcheck() - - assert spec.entrypoint is not None - assert self._logs_specs is not None - self._pcontext = start_processes( - name=spec.role, - entrypoint=spec.entrypoint, - args=args, - envs=envs, - logs_specs=self._logs_specs, - log_line_prefixes=log_line_prefixes, - start_method=self._start_method, - ) - - return self._pcontext.pids() - - def _shutdown( - self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False - ) -> None: - if self._worker_watchdog is not None: - self._worker_watchdog.stop() - self._worker_watchdog = None - if self._health_check_server is not None: - self._health_check_server.stop() - self._health_check_server = None - if self._pcontext: - self._pcontext.close(death_sig) - if not is_restart and self._rdzv_handler: - self._rdzv_handler.shutdown() - - # pyre-fixme[56]: Pyre was not able to infer the type of the decorator - # `mindtorch.distributed.elastic.metrics.prof`. - @prof - def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: - role = worker_group.spec.role - worker_pids = {w.id for w in worker_group.workers} - assert self._pcontext is not None - pc_pids = set(self._pcontext.pids().values()) - if worker_pids != pc_pids: - logger.error( - "[%s] worker pids do not match process_context pids." 
- " Expected: %s, actual: %s", - role, - worker_pids, - pc_pids, - ) - return RunResult(state=WorkerState.UNKNOWN) - - result = self._pcontext.wait(0) - if result: - if result.is_failed(): - # map local rank failure to global rank - worker_failures = {} - for local_rank, failure in result.failures.items(): - worker = worker_group.workers[local_rank] - worker_failures[worker.global_rank] = failure - return RunResult( - state=WorkerState.FAILED, - failures=worker_failures, - ) - else: - # copy ret_val_queue into a map with a global ranks - workers_ret_vals = {} - for local_rank, ret_val in result.return_values.items(): - worker = worker_group.workers[local_rank] - workers_ret_vals[worker.global_rank] = ret_val - return RunResult( - state=WorkerState.SUCCEEDED, - return_values=workers_ret_vals, - ) - else: - return RunResult(state=WorkerState.HEALTHY) diff --git a/mindtorch/distributed/elastic/control_plane.py b/mindtorch/distributed/elastic/control_plane.py deleted file mode 100644 index c7e97f099..000000000 --- a/mindtorch/distributed/elastic/control_plane.py +++ /dev/null @@ -1,52 +0,0 @@ -import os -from contextlib import contextmanager, ExitStack -from typing import Generator - -from mindtorch.distributed.elastic.multiprocessing.errors import record - - -__all__ = [ - "worker_main", -] - -TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET" - - -@contextmanager -def _worker_server(socket_path: str) -> Generator[None, None, None]: - from mindtorch._C._distributed_c10d import _WorkerServer - - server = _WorkerServer(socket_path) - try: - yield - finally: - server.shutdown() - - -@contextmanager -@record -def worker_main() -> Generator[None, None, None]: - """ - This is a context manager that wraps your main entry function. This combines - the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that - exposes handlers via a unix socket specified by - ``Torch_WORKER_SERVER_SOCKET``. - - Example - - :: - - @worker_main() - def main(): - pass - - if __name__=="__main__": - main() - - """ - with ExitStack() as stack: - socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET) - if socket_path is not None: - stack.enter_context(_worker_server(socket_path)) - - yield diff --git a/mindtorch/distributed/elastic/events/__init__.py b/mindtorch/distributed/elastic/events/__init__.py deleted file mode 100644 index 10fd83198..000000000 --- a/mindtorch/distributed/elastic/events/__init__.py +++ /dev/null @@ -1,170 +0,0 @@ -#!/usr/bin/env/python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Module contains events processing mechanisms that are integrated with the standard python logging. 
- -Example of usage: - -:: - - from mindtorch.distributed.elastic import events - event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...}) - events.get_logging_handler(destination="console").info(event) - -""" - -import inspect -import logging -import os -import socket -import traceback -from typing import Dict, Optional - -from mindtorch.distributed.elastic.events.handlers import get_logging_handler - -from .api import ( # noqa: F401 - Event, - EventMetadataValue, - EventSource, - NodeState, - RdzvEvent, -) - - -_events_loggers: Dict[str, logging.Logger] = {} - - -def _get_or_create_logger(destination: str = "null") -> logging.Logger: - """ - Construct python logger based on the destination type or extends if provided. - - Available destination could be found in ``handlers.py`` file. - The constructed logger does not propagate messages to the upper level loggers, - e.g. root logger. This makes sure that a single event can be processed once. - - Args: - destination: The string representation of the event handler. - Available handlers found in ``handlers`` module - """ - global _events_loggers - - if destination not in _events_loggers: - _events_logger = logging.getLogger(f"torchelastic-events-{destination}") - _events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO")) - # Do not propagate message to the root logger - _events_logger.propagate = False - - logging_handler = get_logging_handler(destination) - _events_logger.addHandler(logging_handler) - - # Add the logger to the global dictionary - _events_loggers[destination] = _events_logger - - return _events_loggers[destination] - - -def record(event: Event, destination: str = "null") -> None: - _get_or_create_logger(destination).info(event.serialize()) - - -def record_rdzv_event(event: RdzvEvent) -> None: - _get_or_create_logger("dynamic_rendezvous").info(event.serialize()) - - -def construct_and_record_rdzv_event( - run_id: str, - message: str, - node_state: NodeState, - name: str = "", - hostname: str = "", - pid: Optional[int] = None, - master_endpoint: str = "", - local_id: Optional[int] = None, - rank: Optional[int] = None, -) -> None: - """ - Initialize rendezvous event object and record its operations. - - Args: - run_id (str): The run id of the rendezvous. - message (str): The message describing the event. - node_state (NodeState): The state of the node (INIT, RUNNING, SUCCEEDED, FAILED). - name (str): Event name. (E.g. Current action being performed). - hostname (str): Hostname of the node. - pid (Optional[int]): The process id of the node. - master_endpoint (str): The master endpoint for the rendezvous store, if known. - local_id (Optional[int]): The local_id of the node, if defined in dynamic_rendezvous.py - rank (Optional[int]): The rank of the node, if known. - Returns: - None - Example: - >>> # See DynamicRendezvousHandler class - >>> def _record( - ... self, - ... message: str, - ... node_state: NodeState = NodeState.RUNNING, - ... rank: Optional[int] = None, - ... ) -> None: - ... construct_and_record_rdzv_event( - ... name=f"{self.__class__.__name__}.{get_method_name()}", - ... run_id=self._settings.run_id, - ... message=message, - ... node_state=node_state, - ... hostname=self._this_node.addr, - ... pid=self._this_node.pid, - ... local_id=self._this_node.local_id, - ... rank=rank, - ... ) - """ - # We don't want to perform an extra computation if not needed. - if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler): - return - - # Set up parameters. 
- if not hostname: - hostname = socket.getfqdn() - if not pid: - pid = os.getpid() - - # Determines which file called this function. - callstack = inspect.stack() - filename = "no_file" - if len(callstack) > 1: - stack_depth_1 = callstack[1] - filename = os.path.basename(stack_depth_1.filename) - if not name: - name = stack_depth_1.function - - # Delete the callstack variable. If kept, this can mess with python's - # garbage collector as we are holding on to stack frame information in - # the inspect module. - del callstack - - # Set up error trace if this is an exception - if node_state == NodeState.FAILED: - error_trace = traceback.format_exc() - else: - error_trace = "" - - # Initialize event object - event = RdzvEvent( - name=f"{filename}:{name}", - run_id=run_id, - message=message, - hostname=hostname, - pid=pid, - node_state=node_state, - master_endpoint=master_endpoint, - rank=rank, - local_id=local_id, - error_trace=error_trace, - ) - - # Finally, record the event. - record_rdzv_event(event) diff --git a/mindtorch/distributed/elastic/events/api.py b/mindtorch/distributed/elastic/events/api.py deleted file mode 100644 index f85fdd835..000000000 --- a/mindtorch/distributed/elastic/events/api.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import json -from dataclasses import asdict, dataclass, field -from enum import Enum -from typing import Dict, Optional, Union - - -__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"] - -EventMetadataValue = Union[str, int, float, bool, None] - - -class EventSource(str, Enum): - """Known identifiers of the event producers.""" - - AGENT = "AGENT" - WORKER = "WORKER" - - -@dataclass -class Event: - """ - The class represents the generic event that occurs during the torchelastic job execution. - - The event can be any kind of meaningful action. - - Args: - name: event name. - source: the event producer, e.g. agent or worker - timestamp: timestamp in milliseconds when event occurred. - metadata: additional data that is associated with the event. - """ - - name: str - source: EventSource - timestamp: int = 0 - metadata: Dict[str, EventMetadataValue] = field(default_factory=dict) - - def __str__(self): - return self.serialize() - - @staticmethod - def deserialize(data: Union[str, "Event"]) -> "Event": - if isinstance(data, Event): - return data - if isinstance(data, str): - data_dict = json.loads(data) - data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined] - return Event(**data_dict) - - def serialize(self) -> str: - return json.dumps(asdict(self)) - - -class NodeState(str, Enum): - """The states that a node can be in rendezvous.""" - - INIT = "INIT" - RUNNING = "RUNNING" - SUCCEEDED = "SUCCEEDED" - FAILED = "FAILED" - - -@dataclass -class RdzvEvent: - """ - Dataclass to represent any rendezvous event. - - Args: - name: Event name. (E.g. 
Current action being performed) - run_id: The run id of the rendezvous - message: The message describing the event - hostname: Hostname of the node - pid: The process id of the node - node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED) - master_endpoint: The master endpoint for the rendezvous store, if known - rank: The rank of the node, if known - local_id: The local_id of the node, if defined in dynamic_rendezvous.py - error_trace: Error stack trace, if this is an error event. - """ - - name: str - run_id: str - message: str - hostname: str - pid: int - node_state: NodeState - master_endpoint: str = "" - rank: Optional[int] = None - local_id: Optional[int] = None - error_trace: str = "" - - def __str__(self): - return self.serialize() - - @staticmethod - def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent": - if isinstance(data, RdzvEvent): - return data - if isinstance(data, str): - data_dict = json.loads(data) - data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined] - return RdzvEvent(**data_dict) - - def serialize(self) -> str: - return json.dumps(asdict(self)) diff --git a/mindtorch/distributed/elastic/events/handlers.py b/mindtorch/distributed/elastic/events/handlers.py deleted file mode 100644 index 51dd14280..000000000 --- a/mindtorch/distributed/elastic/events/handlers.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from typing import Dict - - -_log_handlers: Dict[str, logging.Handler] = { - "console": logging.StreamHandler(), - "dynamic_rendezvous": logging.NullHandler(), - "null": logging.NullHandler(), -} - - -def get_logging_handler(destination: str = "null") -> logging.Handler: - global _log_handlers - return _log_handlers[destination] diff --git a/mindtorch/distributed/elastic/metrics/__init__.py b/mindtorch/distributed/elastic/metrics/__init__.py deleted file mode 100644 index d02fe49e9..000000000 --- a/mindtorch/distributed/elastic/metrics/__init__.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env/python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Metrics API. - -**Overview**: - -The metrics API in torchelastic is used to publish telemetry metrics. -It is designed to be used by torchelastic's internal modules to -publish metrics for the end user with the goal of increasing visibility -and helping with debugging. However you may use the same API in your -jobs to publish metrics to the same metrics ``sink``. - -A ``metric`` can be thought of as timeseries data -and is uniquely identified by the string-valued tuple -``(metric_group, metric_name)``. - -torchelastic makes no assumptions about what a ``metric_group`` is -and what relationship it has with ``metric_name``. It is totally up -to the user to use these two fields to uniquely identify a metric. - -.. note:: The metric group ``torchelastic`` is reserved by torchelastic for - platform level metrics that it produces. 
- For instance torchelastic may output the latency (in milliseconds)
-    of a re-rendezvous operation from the agent as
-    ``(torchelastic, agent.rendezvous.duration.ms)``
-
-A sensible way to use metric groups is to map them to a stage or module
-in your job. You may also encode certain high level properties of
-the job such as the region or stage (dev vs prod).
-
-**Publish Metrics**:
-
-Using torchelastic's metrics API is similar to using python's logging
-framework. You first have to configure a metrics handler before
-trying to add metric data.
-
-The example below measures the latency for the ``calculate()`` function.
-
-::
-
-  import time
-  import mindtorch.distributed.elastic.metrics as metrics
-
-  # send all metrics except those from "my_module" to /dev/null
-  metrics.configure(metrics.NullMetricHandler())
-  metrics.configure(metrics.ConsoleMetricHandler(), "my_module")
-
-  def my_method():
-    start = time.time()
-    calculate()
-    end = time.time()
-    metrics.put_metric("calculate_latency", int(end-start), "my_module")
-
-You may also use the ``mindtorch.distributed.elastic.metrics.prof`` decorator
-to conveniently and succinctly profile functions
-
-::
-
-  # -- in module examples.foobar --
-
-  import mindtorch.distributed.elastic.metrics as metrics
-
-  metrics.configure(metrics.ConsoleMetricHandler(), "foobar")
-  metrics.configure(metrics.ConsoleMetricHandler(), "Bar")
-
-  @metrics.prof
-  def foo():
-    pass
-
-  class Bar():
-
-    @metrics.prof
-    def baz():
-      pass
-
-``@metrics.prof`` will publish the following metrics
-::
-
-  <leaf_module or classname>.success - 1 if the function finished successfully
-  <leaf_module or classname>.failure - 1 if the function threw an exception
-  <leaf_module or classname>.duration.ms - function duration in milliseconds
-
-**Configuring Metrics Handler**:
-
-`mindtorch.distributed.elastic.metrics.MetricHandler` is responsible for emitting
-the added metric values to a particular destination. Metric groups can be
-configured with different metric handlers.
-
-By default torchelastic emits all metrics to ``/dev/null``.
-By adding the following configuration, metrics in the
-``torchelastic`` and ``my_app`` metric groups will be printed out to
-the console.
-
-::
-
-  import mindtorch.distributed.elastic.metrics as metrics
-
-  metrics.configure(metrics.ConsoleMetricHandler(), group = "torchelastic")
-  metrics.configure(metrics.ConsoleMetricHandler(), group = "my_app")
-
-**Writing a Custom Metric Handler**:
-
-If you want your metrics to be emitted to a custom location, implement
-the `mindtorch.distributed.elastic.metrics.MetricHandler` interface
-and configure your job to use your custom metric handler.
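For instance, a minimal sketch of a handler that buffers ``MetricData`` entries
for later inspection (e.g. in tests), assuming only the ``MetricHandler``
interface shown in ``api.py`` below:

::

  import mindtorch.distributed.elastic.metrics as metrics

  class BufferingMetricHandler(metrics.MetricHandler):
      def __init__(self):
          self.buffer = []

      def emit(self, metric_data):
          # keep the MetricData tuple instead of writing it anywhere
          self.buffer.append(metric_data)

  handler = BufferingMetricHandler()
  metrics.configure(handler, group="my_app")
  metrics.put_metric("qps", 100, "my_app")
  assert handler.buffer[0].name == "qps"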
-
-Below is a toy example that prints the metrics to ``stdout``
-
-::
-
-  import mindtorch.distributed.elastic.metrics as metrics
-
-  class StdoutMetricHandler(metrics.MetricHandler):
-      def emit(self, metric_data):
-          ts = metric_data.timestamp
-          group = metric_data.group_name
-          name = metric_data.name
-          value = metric_data.value
-          print(f"[{ts}][{group}]: {name}={value}")
-
-  metrics.configure(StdoutMetricHandler(), group="my_app")
-
-Now all metrics in the group ``my_app`` will be printed to stdout as:
-
-::
-
-  [1574213883.4182858][my_app]: my_metric=<value>
-  [1574213940.5237644][my_app]: my_metric=<value>
-
-"""
-
-from typing import Optional
-
-from .api import (  # noqa: F401
-    configure,
-    ConsoleMetricHandler,
-    get_elapsed_time_ms,
-    getStream,
-    MetricData,
-    MetricHandler,
-    MetricsConfig,
-    NullMetricHandler,
-    prof,
-    profile,
-    publish_metric,
-    put_metric,
-)
-
-
-def initialize_metrics(cfg: Optional[MetricsConfig] = None):
-    pass
-
-
-try:
-    from mindtorch.distributed.elastic.metrics.static_init import *  # type: ignore[import] # noqa: F401 F403
-except ModuleNotFoundError:
-    pass
diff --git a/mindtorch/distributed/elastic/metrics/api.py b/mindtorch/distributed/elastic/metrics/api.py
deleted file mode 100644
index 81a62f66f..000000000
--- a/mindtorch/distributed/elastic/metrics/api.py
+++ /dev/null
@@ -1,216 +0,0 @@
-#!/usr/bin/env python3
-# mypy: allow-untyped-defs
-
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import abc
-import time
-from collections import namedtuple
-from functools import wraps
-from typing import Dict, Optional
-from typing_extensions import deprecated
-
-
-__all__ = [
-    "MetricsConfig",
-    "MetricHandler",
-    "ConsoleMetricHandler",
-    "NullMetricHandler",
-    "MetricStream",
-    "configure",
-    "getStream",
-    "prof",
-    "profile",
-    "put_metric",
-    "publish_metric",
-    "get_elapsed_time_ms",
-    "MetricData",
-]
-
-MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])
-
-
-class MetricsConfig:
-    __slots__ = ["params"]
-
-    def __init__(self, params: Optional[Dict[str, str]] = None):
-        self.params = params
-        if self.params is None:
-            self.params = {}
-
-
-class MetricHandler(abc.ABC):
-    @abc.abstractmethod
-    def emit(self, metric_data: MetricData):
-        pass
-
-
-class ConsoleMetricHandler(MetricHandler):
-    def emit(self, metric_data: MetricData):
-        print(
-            f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
-        )
-
-
-class NullMetricHandler(MetricHandler):
-    def emit(self, metric_data: MetricData):
-        pass
-
-
-class MetricStream:
-    def __init__(self, group_name: str, handler: MetricHandler):
-        self.group_name = group_name
-        self.handler = handler
-
-    def add_value(self, metric_name: str, metric_value: int):
-        self.handler.emit(
-            MetricData(time.time(), self.group_name, metric_name, metric_value)
-        )
-
-
-_metrics_map: Dict[str, MetricHandler] = {}
-_default_metrics_handler: MetricHandler = NullMetricHandler()
-
-
-# pyre-fixme[9]: group has type `str`; used as `None`.
-def configure(handler: MetricHandler, group: Optional[str] = None):
-    if group is None:
-        global _default_metrics_handler
-        # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
-        #  as `MetricHandler`.
- _default_metrics_handler = handler - else: - _metrics_map[group] = handler - - -def getStream(group: str): - if group in _metrics_map: - handler = _metrics_map[group] - else: - handler = _default_metrics_handler - return MetricStream(group, handler) - - -def _get_metric_name(fn): - qualname = fn.__qualname__ - split = qualname.split(".") - if len(split) == 1: - module = fn.__module__ - if module: - return module.split(".")[-1] + "." + split[0] - else: - return split[0] - else: - return qualname - - -def prof(fn=None, group: str = "torchelastic"): - r""" - @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates. - - The metric name defaults to the qualified name (``class_name.def_name``) of the function. - If the function does not belong to a class, it uses the leaf module name instead. - - Usage - - :: - - @metrics.prof - def x(): - pass - - @metrics.prof(group="agent") - def y(): - pass - """ - - def wrap(f): - @wraps(f) - def wrapper(*args, **kwargs): - key = _get_metric_name(f) - try: - start = time.time() - result = f(*args, **kwargs) - put_metric(f"{key}.success", 1, group) - except Exception: - put_metric(f"{key}.failure", 1, group) - raise - finally: - put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group) # type: ignore[possibly-undefined] - return result - - return wrapper - - if fn: - return wrap(fn) - else: - return wrap - - -@deprecated("Deprecated, use `@prof` instead", category=FutureWarning) -def profile(group=None): - """ - @profile decorator adds latency and success/failure metrics to any given function. - - Usage - - :: - - @metrics.profile("my_metric_group") - def some_function(): - """ - - def wrap(func): - @wraps(func) - def wrapper(*args, **kwargs): - try: - start_time = time.time() - result = func(*args, **kwargs) - publish_metric(group, f"{func.__name__}.success", 1) - except Exception: - publish_metric(group, f"{func.__name__}.failure", 1) - raise - finally: - publish_metric( - group, - f"{func.__name__}.duration.ms", - get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined] - ) - return result - - return wrapper - - return wrap - - -def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"): - """ - Publish a metric data point. - - Usage - - :: - - put_metric("metric_name", 1) - put_metric("metric_name", 1, "metric_group_name") - """ - getStream(metric_group).add_value(metric_name, metric_value) - - -@deprecated( - "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead", - category=FutureWarning, -) -def publish_metric(metric_group: str, metric_name: str, metric_value: int): - metric_stream = getStream(metric_group) - metric_stream.add_value(metric_name, metric_value) - - -def get_elapsed_time_ms(start_time_in_seconds: float): - """Return the elapsed time in millis from the given start time.""" - end_time = time.time() - return int((end_time - start_time_in_seconds) * 1000) diff --git a/mindtorch/distributed/elastic/multiprocessing/__init__.py b/mindtorch/distributed/elastic/multiprocessing/__init__.py deleted file mode 100644 index 4aaa6fb49..000000000 --- a/mindtorch/distributed/elastic/multiprocessing/__init__.py +++ /dev/null @@ -1,233 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
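Note that ``prof`` above supports both bare use (``@prof``) and parameterized
use (``@prof(group="agent")``): called with a function it decorates it
directly; called with only keyword arguments it returns the decorator. A
self-contained sketch of the same pattern, with ``print`` standing in for a
configured metrics handler (the names here are illustrative):

::

  import time
  from functools import wraps

  def prof(fn=None, group: str = "torchelastic"):
      def wrap(f):
          @wraps(f)
          def wrapper(*args, **kwargs):
              key = f.__qualname__
              start = time.time()
              try:
                  result = f(*args, **kwargs)
                  print(f"[{group}] {key}.success=1")
                  return result
              except Exception:
                  print(f"[{group}] {key}.failure=1")
                  raise
              finally:
                  print(f"[{group}] {key}.duration.ms={int((time.time() - start) * 1000)}")
          return wrapper
      # bare decoration passes the function; parameterized use returns `wrap`
      return wrap(fn) if fn else wrap

  @prof
  def foo():
      pass

  @prof(group="agent")
  def bar():
      pass

  foo()
  bar()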
- -""" -Library that launches and manages ``n`` copies of worker subprocesses either specified by a function or a binary. - -For functions, it uses ``mindtorch.multiprocessing`` (and therefore python -``multiprocessing``) to spawn/fork worker processes. For binaries it uses python -``subprocessing.Popen`` to create worker processes. - - -Usage 1: Launching two trainers as a function - -:: - - from mindtorch.distributed.elastic.multiprocessing import Std, start_processes - - def trainer(a, b, c): - pass # train - - - # runs two trainers - # LOCAL_RANK=0 trainer(1,2,3) - # LOCAL_RANK=1 trainer(4,5,6) - ctx = start_processes( - name="trainer", - entrypoint=trainer, - args={0: (1,2,3), 1: (4,5,6)}, - envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}}, - log_dir="/tmp/foobar", - redirects=Std.ALL, # write all worker stdout/stderr to a log file - tee={0: Std.ERR}, # tee only local rank 0's stderr to console - ) - - # waits for all copies of trainer to finish - ctx.wait() - -Usage 2: Launching 2 echo workers as a binary - -:: - - # same as invoking - # echo hello - # echo world > stdout.log - ctx = start_processes( - name="echo" - entrypoint="echo", - log_dir="/tmp/foobar", - args={0: "hello", 1: "world"}, - redirects={1: Std.OUT}, - ) - -Just like ``mindtorch.multiprocessing``, the return value of the function -:func:`start_processes` is a process context (:class:`api.PContext`). If a function -was launched, a :class:`api.MultiprocessContext` is returned and if a binary -was launched a :class:`api.SubprocessContext` is returned. Both are specific -implementations of the parent :class:`api.PContext` class. -""" - -from typing import Callable, Dict, Optional, Tuple, Union - -from mindtorch.distributed.elastic.multiprocessing.api import ( # noqa: F401 - _validate_full_rank, - DefaultLogsSpecs, - LogsDest, - LogsSpecs, - MultiprocessContext, - PContext, - ProcessFailure, - RunProcsResult, - SignalException, - Std, - SubprocessContext, - to_map, -) -from mindtorch.distributed.elastic.utils.logging import get_logger - - -__all__ = [ - "start_processes", - "MultiprocessContext", - "PContext", - "ProcessFailure", - "RunProcsResult", - "SignalException", - "Std", - "LogsDest", - "LogsSpecs", - "DefaultLogsSpecs", - "SubprocessContext", - "to_map", -] - - -def start_processes( - name: str, - entrypoint: Union[Callable, str], - args: Dict[int, Tuple], - envs: Dict[int, Dict[str, str]], - logs_specs: LogsSpecs, - log_line_prefixes: Optional[Dict[int, str]] = None, - start_method: str = "spawn", -) -> PContext: - """ - Start ``n`` copies of ``entrypoint`` processes with the provided options. - - ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary). - The number of copies is determined by the number of entries for ``args`` and - ``envs`` arguments, which need to have the same key set. - - ``args`` and ``env`` parameters are the arguments and environment variables - to pass down to the entrypoint mapped by the replica index (local rank). - All local ranks must be accounted for. - That is, the keyset should be ``{0,1,...,(nprocs-1)}``. - - .. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings. - If any other type is given, then it is casted to a string representation - (e.g. ``str(arg1)``). Furthermore, a binary failure will only write - an ``error.json`` error file if the main function is annotated with - ``mindtorch.distributed.elastic.multiprocessing.errors.record``. 
For function launches, - this is done by default and there is no need to manually annotate - with the ``@record`` annotation. - - ``redirects`` and ``tee`` are bitmasks specifying which std stream(s) to redirect - to a log file in the ``log_dir``. Valid mask values are defined in ``Std``. - To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as - the local rank to specify the redirect behavior for. - Any missing local ranks will default to ``Std.NONE``. - - ``tee`` acts like the unix "tee" command in that it redirects + prints to console. - To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter. - - For each process, the ``log_dir`` will contain: - - #. ``{local_rank}/error.json``: if the process failed, a file with the error info - #. ``{local_rank}/stdout.json``: if ``redirect & STDOUT == STDOUT`` - #. ``{local_rank}/stderr.json``: if ``redirect & STDERR == STDERR`` - - .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory. - - Example: - :: - - log_dir = "/tmp/test" - - # ok; two copies of foo: foo("bar0"), foo("bar1") - start_processes( - name="trainer", - entrypoint=foo, - args:{0:("bar0",), 1:("bar1",), - envs:{0:{}, 1:{}}, - log_dir=log_dir - ) - - # invalid; envs missing for local rank 1 - start_processes( - name="trainer", - entrypoint=foo, - args:{0:("bar0",), 1:("bar1",), - envs:{0:{}}, - log_dir=log_dir - ) - - # ok; two copies of /usr/bin/touch: touch file1, touch file2 - start_processes( - name="trainer", - entrypoint="/usr/bin/touch", - args:{0:("file1",), 1:("file2",), - envs:{0:{}, 1:{}}, - log_dir=log_dir - ) - - # caution; arguments casted to string, runs: - # echo "1" "2" "3" and echo "[1, 2, 3]" - start_processes( - name="trainer", - entrypoint="/usr/bin/echo", - args:{0:(1,2,3), 1:([1,2,3],), - envs:{0:{}, 1:{}}, - log_dir=log_dir - ) - - Args: - name: a human readable short name that describes what the processes are - (used as header when tee'ing stdout/stderr outputs) - entrypoint: either a ``Callable`` (function) or ``cmd`` (binary) - args: arguments to each replica - envs: env vars to each replica - log_dir: directory used to write log files - start_method: multiprocessing start method (spawn, fork, forkserver) - ignored for binaries - redirects: which std streams to redirect to a log file - tee: which std streams to redirect + print to console - local_ranks_filter: which ranks' logs to print to console - - """ - - nprocs = len(args) - _validate_full_rank(args, nprocs, "args") - _validate_full_rank(envs, nprocs, "envs") - - context: PContext - if isinstance(entrypoint, str): - context = SubprocessContext( - name=name, - entrypoint=entrypoint, - args=args, - envs=envs, - logs_specs=logs_specs, - log_line_prefixes=log_line_prefixes, - ) - else: - context = MultiprocessContext( - name=name, - entrypoint=entrypoint, - args=args, - envs=envs, - log_line_prefixes=log_line_prefixes, - start_method=start_method, - logs_specs=logs_specs, - ) - - try: - context.start() - return context - except Exception: - context.close() - raise diff --git a/mindtorch/distributed/elastic/multiprocessing/api.py b/mindtorch/distributed/elastic/multiprocessing/api.py deleted file mode 100644 index b5559190b..000000000 --- a/mindtorch/distributed/elastic/multiprocessing/api.py +++ /dev/null @@ -1,923 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import abc -import logging -import os -import re -import shutil -import signal -import subprocess -import sys -import tempfile -import threading -import time -from abc import ABC, abstractmethod -from contextlib import nullcontext -from dataclasses import dataclass, field -from enum import IntFlag -from multiprocessing import synchronize -from types import FrameType -from typing import Any, Callable, Dict, Optional, Set, Tuple, Union - -import mindtorch.multiprocessing as mp -from mindtorch.distributed.elastic.multiprocessing.errors import ProcessFailure, record -from mindtorch.distributed.elastic.multiprocessing.redirects import ( - redirect_stderr, - redirect_stdout, -) -from mindtorch.distributed.elastic.multiprocessing.subprocess_handler import ( - get_subprocess_handler, - SubprocessHandler, -) -from mindtorch.distributed.elastic.multiprocessing.tail_log import TailLog - - -IS_WINDOWS = sys.platform == "win32" -IS_MACOS = sys.platform == "darwin" - - -logger = logging.getLogger(__name__) - -__all__ = [ - "DefaultLogsSpecs", - "SignalException", - "Std", - "to_map", - "RunProcsResult", - "PContext", - "get_std_cm", - "MultiprocessContext", - "SubprocessContext", - "LogsDest", - "LogsSpecs", -] - - -class SignalException(Exception): - """ - Exception is raised inside the torchelastic agent process by the termination handler - if the death signal got received by the process. - """ - - def __init__(self, msg: str, sigval: signal.Signals) -> None: - super().__init__(msg) - self.sigval = sigval - - -def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None: - """Termination handler that raises exceptions on the main process. - - When the process receives death signal(SIGTERM, SIGINT), this termination handler will - be invoked. It raises the ``SignalException`` exception that should be processed by the - user code. Python does not terminate process after the termination handler is finished, - so the exception should not be silently ignored, otherwise the process will never - be terminated. - """ - sigval = signal.Signals(signum) - raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval) - - -def _get_kill_signal() -> signal.Signals: - """Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows.""" - if IS_WINDOWS: - return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 - else: - return signal.SIGKILL - - -def _get_default_signal() -> signal.Signals: - """Get the default termination signal. 
SIGTERM for unix, CTRL_C_EVENT for windows.""" - if IS_WINDOWS: - return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 - else: - return signal.SIGTERM - - -def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str): - actual_keys = set(d.keys()) - expected_keys = set(range(nprocs)) - - if actual_keys != expected_keys: - raise RuntimeError( - f"{what}, local rank mapping mismatch," - f" expected: {expected_keys}, actual: {actual_keys}" - ) - - -_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$" -_VALUE_REGEX = r"^[0123]$" - - -class Std(IntFlag): - NONE = 0 - OUT = 1 - ERR = 2 - ALL = OUT | ERR - - @classmethod - def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]: - """ - Example: - :: - - from_str("0") -> Std.NONE - from_str("1") -> Std.OUT - from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR} - - Any other input raises an exception - """ - - def to_std(v: str) -> Std: # type: ignore[return] - s = Std(int(v)) - if s in Std: - return s - # return None -> should NEVER reach here since we regex check input - - if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0) - return to_std(vm) - elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2) - d: Dict[int, Std] = {} - for m in vm.split(","): - i, v = m.split(":") - d[int(i)] = to_std(v) - return d - else: - raise ValueError( - f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>" - ) - - -def to_map( - val_or_map: Union[Std, Dict[int, Std]], local_world_size: int -) -> Dict[int, Std]: - """ - Certain APIs take redirect settings either as a single value (e.g. apply to all - local ranks) or as an explicit user-provided mapping. This method is a convenience - method that converts a value or mapping into a mapping. - - Example: - :: - - to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} - to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT} - to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT} - """ - if isinstance(val_or_map, Std): - return dict.fromkeys(range(local_world_size), val_or_map) - else: - map = {} - for i in range(local_world_size): - map[i] = val_or_map.get(i, Std.NONE) - return map - - -@dataclass -class LogsDest: - """ - For each log type, holds mapping of local rank ids to file paths. - """ - - stdouts: Dict[int, str] = field(default_factory=dict) - stderrs: Dict[int, str] = field(default_factory=dict) - tee_stdouts: Dict[int, str] = field(default_factory=dict) - tee_stderrs: Dict[int, str] = field(default_factory=dict) - error_files: Dict[int, str] = field(default_factory=dict) - - -class LogsSpecs(ABC): - """ - Defines logs processing and redirection for each worker process. - - Args: - log_dir: - Base directory where logs will be written. - redirects: - Streams to redirect to files. Pass a single ``Std`` - enum to redirect for all workers, or a mapping keyed - by local_rank to selectively redirect. - tee: - Streams to duplicate to stdout/stderr. - Pass a single ``Std`` enum to duplicate streams for all workers, - or a mapping keyed by local_rank to selectively duplicate. 
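        For example, both accepted forms (a sketch; ``DefaultLogsSpecs`` is the
        concrete implementation defined below, and the paths are illustrative):

        ::

            # one setting applied to every local rank
            DefaultLogsSpecs(log_dir="/tmp/test", redirects=Std.ALL)

            # per-rank mapping: redirect rank 0's streams to files and tee
            # only its stderr back to the console
            DefaultLogsSpecs(
                log_dir="/tmp/test",
                redirects={0: Std.ALL},
                tee={0: Std.ERR},
            )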
- """
-
-    def __init__(
-        self,
-        log_dir: Optional[str] = None,
-        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
-        tee: Union[Std, Dict[int, Std]] = Std.NONE,
-        local_ranks_filter: Optional[Set[int]] = None,
-    ) -> None:
-        self._root_log_dir = log_dir
-        self._redirects = redirects
-        self._tee = tee
-        self._local_ranks_filter = local_ranks_filter
-
-    @abstractmethod
-    def reify(
-        self,
-        envs: Dict[int, Dict[str, str]],
-    ) -> LogsDest:
-        """
-        Given the environment variables, builds destination of log files for each of the local ranks.
-
-        Envs parameter contains env variables dict for each of the local ranks, where entries are defined in:
-        :func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`.
-        """
-
-    @property
-    @abstractmethod
-    def root_log_dir(self) -> str:
-        pass
-
-
-class DefaultLogsSpecs(LogsSpecs):
-    """
-    Default LogsSpecs implementation:
-
-    - `log_dir` will be created if it doesn't exist
-    - Generates nested folders for each attempt and rank.
-    """
-
-    def __init__(
-        self,
-        log_dir: Optional[str] = None,
-        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
-        tee: Union[Std, Dict[int, Std]] = Std.NONE,
-        local_ranks_filter: Optional[Set[int]] = None,
-    ) -> None:
-        if log_dir != os.devnull:
-            if not log_dir:
-                log_dir = tempfile.mkdtemp(prefix="torchelastic_")
-            elif not os.path.exists(log_dir):
-                os.makedirs(log_dir, exist_ok=True)
-            else:
-                if os.path.isfile(log_dir):
-                    raise NotADirectoryError(f"log_dir: {log_dir} is a file")
-        super().__init__(log_dir, redirects, tee, local_ranks_filter)
-        # initialized only once
-        self._run_log_dir = None
-
-    @property
-    def root_log_dir(self) -> str:
-        return str(self._root_log_dir)
-
-    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
-        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
-        os.makedirs(base_log_dir, exist_ok=True)
-        dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
-        logger.info("log directory set to: %s", dir)
-        return dir
-
-    def reify(
-        self,
-        envs: Dict[int, Dict[str, str]],
-    ) -> LogsDest:
-        """
-        Uses the following scheme to build log destination paths:
-
-        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stdout.log`
-        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stderr.log`
-        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/error.json`
-        """
-        nprocs = len(envs)
-        global_env = {}  # use only to query properties that are not dependent on a rank
-        if nprocs > 0:
-            global_env = envs[0]
-        else:
-            logger.warning(
-                "Empty envs map provided when defining logging destinations."
-            )
-        # Keys are always defined, but values can be missing in unit tests
-        run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id")
-        restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0")
-
-        attempt_log_dir: str = ""
-        if self._root_log_dir != os.devnull:
-            if not self._run_log_dir:
-                self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)
-
-            attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}")  # type: ignore[call-overload]
-            shutil.rmtree(attempt_log_dir, ignore_errors=True)
-            os.makedirs(attempt_log_dir)
-
-        if self._root_log_dir == os.devnull:
-            attempt_log_dir = os.devnull
-
-        # create subdirs for each local rank in the logs_dir
-        # logs_dir
-        #       |- 0
-        #          |- error.json
-        #          |- stdout.log
-        #          |- stderr.log
-        #       |- ...
- # |- (nprocs-1) - redirs = to_map(self._redirects, nprocs) - ts = to_map(self._tee, nprocs) - - # to tee stdout/stderr we first redirect into a file - # then tail -f stdout.log/stderr.log so add tee settings to redirects - for local_rank, tee_std in ts.items(): - redirect_std = redirs[local_rank] - redirs[local_rank] = redirect_std | tee_std - - SYS_STREAM = "" # special case to indicate to output to console - stdouts = dict.fromkeys(range(nprocs), SYS_STREAM) - stderrs = dict.fromkeys(range(nprocs), SYS_STREAM) - tee_stdouts: Dict[int, str] = {} - tee_stderrs: Dict[int, str] = {} - error_files = {} - - for local_rank in range(nprocs): - if attempt_log_dir == os.devnull: - tee_stdouts[local_rank] = os.devnull - tee_stderrs[local_rank] = os.devnull - error_files[local_rank] = os.devnull - envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = "" - else: - clogdir = os.path.join(attempt_log_dir, str(local_rank)) - os.mkdir(clogdir) - - rd = redirs[local_rank] - if (rd & Std.OUT) == Std.OUT: - stdouts[local_rank] = os.path.join(clogdir, "stdout.log") - if (rd & Std.ERR) == Std.ERR: - stderrs[local_rank] = os.path.join(clogdir, "stderr.log") - - t = ts[local_rank] - if t & Std.OUT == Std.OUT: - tee_stdouts[local_rank] = stdouts[local_rank] - if t & Std.ERR == Std.ERR: - tee_stderrs[local_rank] = stderrs[local_rank] - - if ( - self._local_ranks_filter - and local_rank not in self._local_ranks_filter - ): - # If stream is tee'd, only write to file, but don't tail - if local_rank in tee_stdouts: - tee_stdouts.pop(local_rank, None) - if local_rank in tee_stderrs: - tee_stderrs.pop(local_rank, None) - - # If stream is not redirected, don't print - if stdouts[local_rank] == SYS_STREAM: - stdouts[local_rank] = os.devnull - if stderrs[local_rank] == SYS_STREAM: - stderrs[local_rank] = os.devnull - - error_file = os.path.join(clogdir, "error.json") - error_files[local_rank] = error_file - logger.info( - "Setting worker%s reply file to: %s", local_rank, error_file - ) - envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file - - return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files) - - def __repr__(self) -> str: - return ( - f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, " - f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})" - ) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, DefaultLogsSpecs): - return False - - return ( - self._root_log_dir == other._root_log_dir - and self._redirects == other._redirects - and self._tee == other._tee - and self._local_ranks_filter == other._local_ranks_filter - ) - - -@dataclass -class RunProcsResult: - """ - Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``. - - Note the following: - - 1. All fields are mapped by local rank - 2. ``return_values`` - only populated for functions (not the binaries). - 3. ``stdouts`` - path to stdout.log (empty string if no redirect) - 4. ``stderrs`` - path to stderr.log (empty string if no redirect) - - """ - - return_values: Dict[int, Any] = field(default_factory=dict) - failures: Dict[int, ProcessFailure] = field(default_factory=dict) - stdouts: Dict[int, str] = field(default_factory=dict) - stderrs: Dict[int, str] = field(default_factory=dict) - - def is_failed(self) -> bool: - return len(self.failures) > 0 - - -class PContext(abc.ABC): - """ - The base class that standardizes operations over a set of processes that are launched via different mechanisms. 
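    A sketch of the loop an agent runs over whichever ``PContext`` subclass
    ``start_processes`` returned (error handling elided; the arguments are as
    in the examples above):

    ::

        pc = start_processes(...)  # MultiprocessContext or SubprocessContext
        try:
            result = pc.wait(timeout=-1, period=1)  # block until all done or any failed
            if result and result.is_failed():
                first = min(result.failures.values(), key=lambda f: f.timestamp)
                print(f"rank {first.local_rank} failed with exitcode {first.exitcode}")
            elif result:
                print("return values:", result.return_values)
        finally:
            pc.close()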
- - The name ``PContext`` is intentional to disambiguate with ``mindtorch.multiprocessing.ProcessContext``. - - .. warning:: stdouts and stderrs should ALWAYS be a superset of - tee_stdouts and tee_stderrs (respectively) this is b/c - tee is implemented as a redirect + tail -f - """ - - def __init__( - self, - name: str, - entrypoint: Union[Callable, str], - args: Dict[int, Tuple], - envs: Dict[int, Dict[str, str]], - logs_specs: LogsSpecs, - log_line_prefixes: Optional[Dict[int, str]] = None, - ): - self.name = name - # validate that all mappings have the same number of keys and - # all local ranks are accounted for - nprocs = len(args) - - # TODO log_line_prefixes can be exanded too - logs_dest = logs_specs.reify(envs) - - _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts") - _validate_full_rank(logs_dest.stderrs, nprocs, "stderrs") - - self.entrypoint = entrypoint - self.args = args - self.envs = envs - self.stdouts = logs_dest.stdouts - self.stderrs = logs_dest.stderrs - self.error_files = logs_dest.error_files - self.nprocs = nprocs - - self._stdout_tail = TailLog( - name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes - ) - self._stderr_tail = TailLog( - name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes - ) - - def start(self) -> None: - """Start processes using parameters defined in the constructor.""" - if threading.current_thread() is threading.main_thread(): - signal.signal(signal.SIGTERM, _terminate_process_handler) - signal.signal(signal.SIGINT, _terminate_process_handler) - if not IS_WINDOWS: - signal.signal(signal.SIGHUP, _terminate_process_handler) - signal.signal(signal.SIGQUIT, _terminate_process_handler) - else: - logger.warning( - "Failed to register signal handlers since torchelastic is running on a child thread. " - "This could lead to orphaned worker processes if the torchrun is terminated." - ) - self._start() - self._stdout_tail.start() - self._stderr_tail.start() - - @abc.abstractmethod - def _start(self) -> None: - """Start processes using strategy defined in a particular context.""" - raise NotImplementedError - - @abc.abstractmethod - def _poll(self) -> Optional[RunProcsResult]: - """ - Poll the run status of the processes running under this context. - This method follows an "all-or-nothing" policy and returns - a ``RunProcessResults`` object if either all processes complete - successfully or any process fails. Returns ``None`` if - all processes are still running. - """ - raise NotImplementedError - - def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]: - """ - Wait for the specified ``timeout`` seconds, polling every ``period`` seconds - for the processes to be done. Returns ``None`` if the processes are still running - on timeout expiry. Negative timeout values are interpreted as "wait-forever". - A timeout value of zero simply queries the status of the processes (e.g. equivalent - to a poll). - - ..note: Multiprocessing library registers SIGTERM and SIGINT signal handlers that raise - ``SignalException`` when the signals received. It is up to the consumer of the code - to properly handle the exception. It is important not to swallow the exception otherwise - the process would not terminate. Example of the typical workflow can be: - - .. code-block:: python - pc = start_processes(...) - try: - pc.wait(1) - .. 
do some other work - except SignalException as e: - pc.shutdown(e.sigval, timeout=30) - - If SIGTERM or SIGINT occurs, the code above will try to shutdown child processes by propagating - received signal. If child processes will not terminate in the timeout time, the process will send - the SIGKILL. - """ - if timeout == 0: - return self._poll() - - if timeout < 0: - timeout = sys.maxsize - - expiry = time.time() + timeout - while time.time() < expiry: - pr = self._poll() - if pr: - return pr - time.sleep(period) - - return None - - @abc.abstractmethod - def pids(self) -> Dict[int, int]: - """Return pids of processes mapped by their respective local_ranks.""" - raise NotImplementedError - - @abc.abstractmethod - def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: - r""" - Terminates all processes managed by this context and cleans up any - meta resources (e.g. redirect, error_file files). - """ - raise NotImplementedError - - def close( - self, death_sig: Optional[signal.Signals] = None, timeout: int = 30 - ) -> None: - r""" - Terminates all processes managed by this context and cleans up any - meta resources (e.g. redirect, error_file files). - - Args: - death_sig: Death signal to terminate processes. - timeout: Time to wait for processes to finish, if process is - still alive after this time, it will be terminated via SIGKILL. - """ - if not death_sig: - death_sig = _get_default_signal() - self._close(death_sig=death_sig, timeout=timeout) - if self._stdout_tail: - self._stdout_tail.stop() - if self._stderr_tail: - self._stderr_tail.stop() - - -def get_std_cm(std_rd: str, redirect_fn): - if IS_WINDOWS or IS_MACOS or not std_rd: - return nullcontext() - else: - return redirect_fn(std_rd) - - -def _wrap( - local_rank: int, - fn: Callable, - args: Dict[int, Tuple], - envs: Dict[int, Dict[str, str]], - stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None) - stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None) - ret_vals: Dict[int, mp.SimpleQueue], - queue_finished_reading_event: synchronize.Event, -) -> None: - # get the per-rank params up front so we fail fast if no mapping is found - args_ = args[local_rank] - env_ = envs[local_rank] - ret_val_ = ret_vals[local_rank] - - stdout_rd = stdout_redirects[local_rank] - stderr_rd = stderr_redirects[local_rank] - - stdout_cm = get_std_cm(stdout_rd, redirect_stdout) - stderr_cm = get_std_cm(stderr_rd, redirect_stderr) - - for k, v in env_.items(): - os.environ[k] = v - - with stdout_cm, stderr_cm: - ret = record(fn)(*args_) - ret_val_.put(ret) - queue_finished_reading_event.wait() - - -class MultiprocessContext(PContext): - """``PContext`` holding worker processes invoked as a function.""" - - def __init__( - self, - name: str, - entrypoint: Callable, - args: Dict[int, Tuple], - envs: Dict[int, Dict[str, str]], - start_method: str, - logs_specs: LogsSpecs, - log_line_prefixes: Optional[Dict[int, str]] = None, - ): - super().__init__( - name, - entrypoint, - args, - envs, - logs_specs, - log_line_prefixes, - ) - - self.start_method = start_method - # each ret_val queue will always contain a single element. - self._ret_vals = { - local_rank: mp.get_context(self.start_method).SimpleQueue() - for local_rank in range(self.nprocs) - } - - # see comments in ``join()`` for what this is - self._return_values: Dict[int, Any] = {} - self._pc: Optional[mp.ProcessContext] = None - # Note: set method should ONLY be invoked for the use case when all processes finished - # successfully. 
If any process died on event.wait() calling set() method will deadlock. - self._worker_finished_event = mp.get_context(self.start_method).Event() - - def _start(self): - if self._pc: - raise ValueError( - "The process context already initialized." - " Most likely the start method got called twice." - ) - self._pc = mp.start_processes( - fn=_wrap, - args=( - self.entrypoint, - self.args, - self.envs, - self.stdouts, - self.stderrs, - self._ret_vals, - self._worker_finished_event, - ), - nprocs=self.nprocs, - join=False, - daemon=False, - start_method=self.start_method, - ) - - def _is_done(self) -> bool: - return len(self._return_values) == self.nprocs - - def _poll(self) -> Optional[RunProcsResult]: - assert self._pc is not None # assertion for mypy type checker - - try: - # mindtorch.mp.ProcessContext Throws an Exception if some/all of - # worker processes failed - # timeout < 0 checks worker status and return immediately - # Join will never return success since we use synchronize.Event to wait - # for all processes to finish. - self._pc.join(-1) - - # IMPORTANT: we use multiprocessing.Queue to carry worker return values - # back to the parent, the worker process will wait before terminating - # until all the buffered items are fed by the feeder thread to the underlying - # pipe. Hence to prevent deadlocks on large return values, - # we opportunistically try queue.get on each join call - # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms - for local_rank in range(0, self.nprocs): - return_queue = self._ret_vals[local_rank] - if not return_queue.empty(): - # save the return values temporarily into a member var - self._return_values[local_rank] = return_queue.get() - - if self._is_done(): - # we should ALWAYS have ALL the return values when all the processes are done - self._worker_finished_event.set() - - # At this point workers finished running the user function - # But the child process might still have not exited. Wait for them. - # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exits. - while not self._pc.join(): - logger.debug( - "entrypoint fn finished, waiting for all child procs to exit..." 
- ) - - _validate_full_rank( - self._return_values, self.nprocs, "return_value queue" - ) - self.close() - return RunProcsResult( - return_values=self._return_values, - stdouts=self.stdouts, - stderrs=self.stderrs, - ) - else: - return None - except (mp.ProcessRaisedException, mp.ProcessExitedException) as e: - failed_local_rank = e.error_index - - # entrypoint for MultiprocessContext will always be a Callable - fn_name = self.entrypoint.__qualname__ # type: ignore[union-attr] - failed_proc = self._pc.processes[failed_local_rank] - error_filepath = self.error_files[failed_local_rank] - - logger.exception( - "failed (exitcode: %s)" - " local_rank: %s (pid: %s)" - " of fn: %s (start_method: %s)", - failed_proc.exitcode, - failed_local_rank, - e.pid, - fn_name, - self.start_method, - ) - - self.close() - return RunProcsResult( - failures={ - failed_local_rank: ProcessFailure( - local_rank=failed_local_rank, - pid=e.pid, - exitcode=failed_proc.exitcode, - error_file=error_filepath, - ) - }, - stdouts=self.stdouts, - stderrs=self.stderrs, - ) - - def pids(self) -> Dict[int, int]: - assert self._pc is not None # assertion for mypy type checking - return dict(enumerate(self._pc.pids())) - - def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: - if not self._pc: - return - for proc in self._pc.processes: - if proc.is_alive(): - logger.warning( - "Closing process %s via signal %s", proc.pid, death_sig.name - ) - try: - os.kill(proc.pid, death_sig) - except ProcessLookupError: - # If the process exited because of some reason, - # `ProcessLookupError` will be raised, it is safe to ignore it. - pass - end = time.monotonic() + timeout - for proc in self._pc.processes: - time_to_wait = end - time.monotonic() - if time_to_wait <= 0: - break - proc.join(time_to_wait) - for proc in self._pc.processes: - if proc.is_alive(): - logger.warning( - "Unable to shutdown process %s via %s, forcefully exiting via %s", - proc.pid, - death_sig, - _get_kill_signal(), - ) - try: - os.kill(proc.pid, _get_kill_signal()) - except ProcessLookupError: - # If the process exited because of some reason, - # `ProcessLookupError` will be raised, it is safe to ignore it. - pass - proc.join() - - -class SubprocessContext(PContext): - """``PContext`` holding worker processes invoked as a binary.""" - - def __init__( - self, - name: str, - entrypoint: str, - args: Dict[int, Tuple], - envs: Dict[int, Dict[str, str]], - logs_specs: LogsSpecs, - log_line_prefixes: Optional[Dict[int, str]] = None, - ): - super().__init__( - name, - entrypoint, - args, - envs, - logs_specs, - log_line_prefixes, - ) - - # state vector; _vdone[local_rank] -> is local_rank finished or not - self._running_local_ranks: Set[int] = set(range(self.nprocs)) - self._failures: Dict[int, ProcessFailure] = {} - self.subprocess_handlers: Dict[int, SubprocessHandler] = {} - - def _start(self): - if self.subprocess_handlers: - raise ValueError( - "The subprocess handlers already initialized. Most likely the start method got called twice." 
- ) - self.subprocess_handlers = { - local_rank: get_subprocess_handler( - entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str - args=self.args[local_rank], - env=self.envs[local_rank], - stdout=self.stdouts[local_rank], - stderr=self.stderrs[local_rank], - local_rank_id=local_rank, - ) - for local_rank in range(self.nprocs) - } - - def _poll(self) -> Optional[RunProcsResult]: - done_local_ranks = set() - for local_rank in self._running_local_ranks: - handler = self.subprocess_handlers[local_rank] - exitcode = handler.proc.poll() - if exitcode is not None: - done_local_ranks.add(local_rank) - if exitcode != 0: # failed or signaled - self._failures[local_rank] = ProcessFailure( - local_rank=local_rank, - pid=handler.proc.pid, - exitcode=exitcode, - error_file=self.error_files[local_rank], - ) - # else: --> succeeded; nothing to do - - self._running_local_ranks.difference_update(done_local_ranks) - - # if ALL procs are finished or ANY have failed - if not self._running_local_ranks or self._failures: - self.close() # terminate all running procs - result = RunProcsResult( - failures=self._failures, - stdouts=self.stdouts, - stderrs=self.stderrs, - ) - if result.is_failed(): - first_failure = min(result.failures.values(), key=lambda f: f.timestamp) - logger.error( - "failed (exitcode: %s)" - " local_rank: %s (pid: %s)" - " of binary: %s", - first_failure.exitcode, - first_failure.local_rank, - first_failure.pid, - self.entrypoint, - ) - else: - # Populate return with dummy values. This provides consistency with MultiprocessingHandler - result.return_values = dict.fromkeys(range(self.nprocs)) - - return result - else: # there are no failures and procs still running - return None - - def pids(self) -> Dict[int, int]: - return { - local_rank: sh.proc.pid - for local_rank, sh in self.subprocess_handlers.items() - } - - def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: - if not self.subprocess_handlers: - return - for handler in self.subprocess_handlers.values(): - if handler.proc.poll() is None: - logger.warning( - "Sending process %s closing signal %s", - handler.proc.pid, - death_sig.name, - ) - handler.close(death_sig=death_sig) - end = time.monotonic() + timeout - for handler in self.subprocess_handlers.values(): - time_to_wait = end - time.monotonic() - if time_to_wait <= 0: - break - try: - handler.proc.wait(time_to_wait) - except subprocess.TimeoutExpired: - # Ignore the timeout expired exception, since - # the child process will be forcefully terminated via SIGKILL - pass - for handler in self.subprocess_handlers.values(): - if handler.proc.poll() is None: - logger.warning( - "Unable to shutdown process %s via %s, forcefully exiting via %s", - handler.proc.pid, - death_sig, - _get_kill_signal(), - ) - handler.close(death_sig=_get_kill_signal()) - handler.proc.wait() diff --git a/mindtorch/distributed/elastic/multiprocessing/errors/__init__.py b/mindtorch/distributed/elastic/multiprocessing/errors/__init__.py deleted file mode 100644 index 73aeb9225..000000000 --- a/mindtorch/distributed/elastic/multiprocessing/errors/__init__.py +++ /dev/null @@ -1,383 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
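``SubprocessContext._poll`` above implements an all-or-nothing policy over
plain ``subprocess`` handles: the run finishes when every worker has exited,
or as soon as any worker fails. A self-contained sketch of that loop (the toy
worker commands are illustrative):

::

  import subprocess
  import time

  # two toy workers; rank 1 exits non-zero to exercise the failure path
  cmds = {
      0: ["python", "-c", "import sys; sys.exit(0)"],
      1: ["python", "-c", "import sys; sys.exit(3)"],
  }
  procs = {rank: subprocess.Popen(cmd) for rank, cmd in cmds.items()}

  running, failures = set(procs), {}
  while running and not failures:
      for rank in list(running):
          exitcode = procs[rank].poll()
          if exitcode is not None:
              running.discard(rank)
              if exitcode != 0:
                  failures[rank] = exitcode
      time.sleep(0.1)

  # on any failure the real implementation closes the survivors (cf. _close)
  for p in procs.values():
      if p.poll() is None:
          p.terminate()

  print("failures:", failures)  # e.g. {1: 3}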
- -""" -Each host in a distributed PyTorch job runs with a single TorchElastic agent, -and multiple workers (as children processes of the TorchElastic agent). -Since the workers are user-provided (your PyTorch script/job), TorchElastic -has a way to propagate errors on the trainers through the agent and up to the -scheduler, which ultimately informs the end-user about the state of the job -and applies any retry policies. - -TorchElastic categorizes errors into 3 categories: - -+----------------+----------------+--------------------------------------------------------------+ -| Category | Sub-Category | Description | -+================+================+==============================================================+ -| User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. min > max nodes) | -| +----------------+--------------------------------------------------------------+ -| | Worker Failure | any failures on the worker child process | -+----------------+----------------+--------------------------------------------------------------+ -| Platform Error | n/a | failures caused by the agent | -+----------------+----------------+--------------------------------------------------------------+ -| Infra Error | n/a | failures outside the domain of the agent and workers | -| | | (e.g. host failures) | -+----------------+----------------+--------------------------------------------------------------+ - -All errors other than "Worker Failure" are either raised canonically from the -agent process or implicitly or explicitly crash the agent process. So the -standard language (python) provided exception handling strategies apply. - -Worker Failures are special because the exception/failure originates on a different -process from the agent so the error needs to be propagated inter-process -(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process). - -TorchElastic agents use :func:`mindtorch.distributed.elastic.multiprocessing.start_processes` -to launch the workers which has a simple file based inter-process error propagation -built-in. - -Any function or binary entrypoint decorated with :func:`record` -will write uncaught exceptions (with the trace information) to a file specified by the -environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent) -sets this env var on each child it launches, then aggregates the error files for all -children, and propagates the one with the **smallest** timestamp (e.g. the **first** error). -""" - -import json -import os -import signal -import socket -import time -import warnings -from dataclasses import dataclass, field -from datetime import datetime -from functools import wraps -from string import Template -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar - -from mindtorch.distributed.elastic.utils.logging import get_logger - -from .error_handler import ErrorHandler # noqa: F401 -from .handlers import get_error_handler # noqa: F401 - - -__all__ = [ - "ProcessFailure", - "ChildFailedError", - "record", - "ErrorHandler", - "get_error_handler", -] - -logger = get_logger(__name__) - - -JSON = Dict - -_EMPTY_ERROR_DATA = {"message": ""} -_NOT_AVAILABLE = "" - -T = TypeVar("T") - - -@dataclass -class ProcessFailure: - """ - Represent the failed process result. When the worker process fails, it may record failure root cause into the file. 
- - Tries to read the failure timestamp from the provided ``error_file``, - if the ``error_file`` does not exist, the timestamp is the current - timestamp (seconds since epoch). - - The ``message`` field is a concise explanation of the failure. If - the error file exists then the message is obtained from the error file. - Otherwise one is generated based on the failure signature. - - .. note:: It is assumed that the ``error_file`` is written by - ``mindtorch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``. - Otherwise the behavior is undefined. - - """ - - local_rank: int - pid: int - exitcode: int - error_file: str - error_file_data: JSON = field(init=False) - message: str = field(init=False) - timestamp: int = field(init=False) - - def __post_init__(self): - self.error_file_data = _EMPTY_ERROR_DATA - if os.path.isfile(self.error_file): - try: - with open(self.error_file) as fp: - self.error_file_data = json.load(fp) - logger.debug( - "User process failed with error data: %s", - json.dumps(self.error_file_data, indent=2), - ) - self.message, self.timestamp = self._get_error_data( - self.error_file_data - ) - except Exception: - logger.exception("Failed to parse reply file: %s", self.error_file) - raise - else: - self._set_no_reply_file() - - # make up an informative message if not already present - if not self.message: - # signals typically do not generate an error file message - if self.exitcode < 0: - self.message = ( - f"Signal {-self.exitcode} ({self.signal_name()})" - f" received by PID {self.pid}" - ) - else: - self.message = "To enable traceback see: https://pymindtorch.org/docs/stable/elastic/errors.html" - - def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]: - message = error_file_data["message"] - if isinstance(message, str): - timestamp = int(error_file_data.get("timestamp", 0)) - else: - timestamp = int(message["extraInfo"]["timestamp"]) - return (message, timestamp) - - def _set_no_reply_file(self): - self.error_file = _NOT_AVAILABLE - self.error_file_data = _EMPTY_ERROR_DATA - self.message = "" - self.timestamp = int(time.time()) - - def signal_name(self) -> str: - if self.exitcode < 0: - # We don't want to kill the parent process trying to find the signal name. - # if the signal doesn't map to a known name, use not available. - try: - return signal.Signals(-self.exitcode).name - except Exception: - return _NOT_AVAILABLE - else: - return _NOT_AVAILABLE - - def timestamp_isoformat(self): - """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS).""" - return datetime.fromtimestamp(self.timestamp).isoformat(sep="_") - - -GlobalRank = int - -_FAILURE_FORMAT_TEMPLATE = """[${idx}]: - time : ${time} - host : ${hostname} - rank : ${rank} (local_rank: ${local_rank}) - exitcode : ${exitcode} (pid: ${pid}) - error_file: ${error_file} - traceback : ${message}""" - -# extra new lines before and after are intentional -_MSG_FORMAT_TEMPLATE = """ -${boarder} -${title} -${section} -Failures: -${other_failures} -${section} -Root Cause (first observed failure): -${root_failure} -${boarder}""" - - -class ChildFailedError(Exception): - """ - Special exception type that can be raised from a function annotated with the - ``@record`` decorator to have the child process' (root exception) propagate - up the stack as-is (e.g. without being wrapped in the parent's traceback). - - Useful in cases where the parent is a simple nanny process - and the child (worker) processes are actually doing meaningful compute. 
- In this case, errors typically occur on the child process as the parent - is not doing anything non-trivial, and child errors should be propagated - to the scheduler for accurate root cause diagnostics. - - .. note:: The propagation relies on error files rather than exception handling to - support both function and binary launches. - - Example: - :: - - # process tree on a host (container) - 0: scheduler-init-process: - |- 1: torchelastic_agent: - |- 2: trainer_0 (ok) - |- 3: trainer_1 (fail) -> error.json - |- ... - |- n+2: trainer_n (ok) - |- n+3: other processes - |- ... - - In the example above, trainer 1's failure (written into error.json) is - the root cause and should be reported to the scheduler's init process. - The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})`` - upon detecting trainer 1's failure which would propagate the contents - of trainer 1's error file to the scheduler's init process. - """ - - def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]): - self.name = name - self.failures = failures - assert ( - self.failures - ) # does not make sense to create a ChildFaileError with no failures - super().__init__(self.format_msg()) - - def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]: - rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp) - return rank, self.failures[rank] - - def format_msg(self, boarder_delim="=", section_delim="-"): - title = f"{self.name} FAILED" - root_rank, _root_failure = self.get_first_failure() - - root_failure_fmt: str = "" - other_failures_fmt: List[str] = [] - width = len(title) - for idx, (rank, failure) in enumerate(self.failures.items()): - fmt, w = self._format_failure(idx, rank, failure) - width = max(width, w) - if rank == root_rank: - root_failure_fmt = fmt - else: - other_failures_fmt.append(fmt) - - # upper boundary on width - width = min(width, 60) - - return Template(_MSG_FORMAT_TEMPLATE).substitute( - boarder=boarder_delim * width, - title=title, - section=section_delim * width, - root_failure=root_failure_fmt, - other_failures="\n".join(other_failures_fmt or [" "]), - ) - - def _format_failure( - self, idx: int, rank: int, failure: ProcessFailure - ) -> Tuple[str, int]: - # failure.message is either a str (when the failure does not generate a traceback - e.g. signals) - # or a dict (json) of the form - # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}} - # so the display logic is: - # 1. if failure.message is not a dict (it is a str) just show it as is - # 2. else try to get the traceback (py_callstack) - # 3. if the traceback is not there, use the message - # 4. 
if the message is not there show - msg = failure.message - if isinstance(failure.message, dict): - msg = ( - failure.message.get("extraInfo", {}) - .get("py_callstack", failure.message.get("message", "")) - .replace("\n", "\n ") # to properly indent the traceback - ) - - fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute( - idx=idx, - time=failure.timestamp_isoformat(), - hostname=socket.getfqdn(), - rank=rank, - local_rank=failure.local_rank, - exitcode=failure.exitcode, - pid=failure.pid, - error_file=failure.error_file, - message=msg, - ) - width = 0 - for line in fmt.split("\n"): - width = max(width, len(line)) - return fmt, width - - -def record( - fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None -) -> Callable[..., T]: - """ - Syntactic sugar to record errors/exceptions that happened in the decorated - function using the provided ``error_handler``. - - Using this decorator is equivalent to: - - :: - - error_handler = get_error_handler() - error_handler.initialize() - try: - foobar() - except ChildFailedError as e: - _, failure = e.get_first_failure() - error_handler.dump_error_file(failure.error_file, failure.exitcode) - raise - except Exception as e: - error_handler.record(e) - raise - - .. important:: use this decorator once per process at the top level method, - typically this is the main method. - - Example - - :: - - @record - def main(): - pass - - if __name__=="__main__": - main() - - """ - if not error_handler: - error_handler = get_error_handler() - - def wrap(f): - @wraps(f) - def wrapper(*args, **kwargs): - assert error_handler is not None # assertion for mypy type checker - error_handler.initialize() - try: - return f(*args, **kwargs) - except SystemExit as se: - # For run_path based entrypoints, SystemExit with code = 0 will never exit. - # Handling it here by returning a value: - if se.code == 0: - return None - else: - raise - except ChildFailedError as e: - rank, failure = e.get_first_failure() - if failure.error_file != _NOT_AVAILABLE: - error_handler.dump_error_file(failure.error_file, failure.exitcode) - else: - logger.info( - ( - "local_rank %s FAILED with no error file." - " Decorate your entrypoint fn with @record for traceback info." - " See: https://pymindtorch.org/docs/stable/elastic/errors.html", - rank, - ) - ) - raise - except Exception as e: - error_handler.record_exception(e) - raise - - return wrapper - - return wrap(fn) diff --git a/mindtorch/distributed/elastic/multiprocessing/errors/error_handler.py b/mindtorch/distributed/elastic/multiprocessing/errors/error_handler.py deleted file mode 100644 index f0fb72ac9..000000000 --- a/mindtorch/distributed/elastic/multiprocessing/errors/error_handler.py +++ /dev/null @@ -1,166 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import faulthandler -import json -import logging -import os -import time -import traceback -import warnings -from typing import Any, Dict, Optional - - -__all__ = ["ErrorHandler"] - -logger = logging.getLogger(__name__) - - -class ErrorHandler: - """ - Write the provided exception object along with some other metadata about - the error in a structured way in JSON format to an error file specified by the - environment variable: ``TORCHELASTIC_ERROR_FILE``. 
If this environment - variable is not set, then simply logs the contents of what would have been - written to the error file. - - This handler may be subclassed to customize the handling of the error. - Subclasses should override ``initialize()`` and ``record_exception()``. - """ - - def _get_error_file_path(self) -> Optional[str]: - """ - Return the error file path. - - May return ``None`` to have the structured error be logged only. - """ - return os.environ.get("TORCHELASTIC_ERROR_FILE", None) - - def initialize(self) -> None: - """ - Call prior to running code that we wish to capture errors/exceptions. - - Typically registers signal/fault handlers. Users can override this - function to add custom initialization/registrations that aid in - propagation/information of errors/signals/exceptions/faults. - """ - try: - faulthandler.enable(all_threads=True) - except Exception as e: - warnings.warn(f"Unable to enable fault handler. {type(e).__name__}: {e}") - - def _write_error_file(self, file_path: str, error_msg: str) -> None: - """Write error message to the file.""" - try: - with open(file_path, "w") as fp: - fp.write(error_msg) - except Exception as e: - warnings.warn(f"Unable to write error to file. {type(e).__name__}: {e}") - - def record_exception(self, e: BaseException) -> None: - """ - Write a structured information about the exception into an error file in JSON format. - - If the error file cannot be determined, then logs the content - that would have been written to the error file. - """ - file = self._get_error_file_path() - if file: - data = { - "message": { - "message": f"{type(e).__name__}: {e}", - "extraInfo": { - "py_callstack": traceback.format_exc(), - "timestamp": str(int(time.time())), - }, - } - } - with open(file, "w") as fp: - json.dump(data, fp) - - def override_error_code_in_rootcause_data( - self, - rootcause_error_file: str, - rootcause_error: Dict[str, Any], - error_code: int = 0, - ): - """Modify the rootcause_error read from the file, to correctly set the exit code.""" - if "message" not in rootcause_error: - logger.warning( - "child error file (%s) does not have field `message`. \n" - "cannot override error code: %s", - rootcause_error_file, - error_code, - ) - elif isinstance(rootcause_error["message"], str): - logger.warning( - "child error file (%s) has a new message format. \n" - "skipping error code override", - rootcause_error_file, - ) - else: - rootcause_error["message"]["errorCode"] = error_code - - def dump_error_file(self, rootcause_error_file: str, error_code: int = 0): - """Dump parent error file from child process's root cause error and error code.""" - with open(rootcause_error_file) as fp: - rootcause_error = json.load(fp) - # Override error code since the child process cannot capture the error code if it - # is terminated by signals like SIGSEGV. 
- if error_code: - self.override_error_code_in_rootcause_data( - rootcause_error_file, rootcause_error, error_code - ) - logger.debug( - "child error file (%s) contents:\n" "%s", - rootcause_error_file, - json.dumps(rootcause_error, indent=2), - ) - - my_error_file = self._get_error_file_path() - if my_error_file: - # Guard against existing error files - # This can happen when the child is created using multiprocessing - # and the same env var (TORCHELASTIC_ERROR_FILE) is used on the - # parent and child to specify the error files (respectively) - # because the env vars on the child is set in the wrapper function - # and by default the child inherits the parent's env vars, if the child - # process receives a signal before the wrapper function kicks in - # and the signal handler writes to the error file, then the child - # will write to the parent's error file. In this case just log the - # original error file contents and overwrite the error file. - self._rm(my_error_file) - self._write_error_file(my_error_file, json.dumps(rootcause_error)) - logger.info("dumped error file to parent's %s", my_error_file) - else: - logger.error( - "no error file defined for parent, to copy child error file (%s)", - rootcause_error_file, - ) - - def _rm(self, my_error_file): - if os.path.isfile(my_error_file): - # Log the contents of the original file. - with open(my_error_file) as fp: - try: - original = json.dumps(json.load(fp), indent=2) - logger.warning( - "%s already exists" - " and will be overwritten." - " Original contents:\n%s", - my_error_file, - original, - ) - except json.decoder.JSONDecodeError: - logger.warning( - "%s already exists" - " and will be overwritten." - " Unable to load original contents:\n", - my_error_file, - ) - os.remove(my_error_file) diff --git a/mindtorch/distributed/elastic/multiprocessing/errors/handlers.py b/mindtorch/distributed/elastic/multiprocessing/errors/handlers.py deleted file mode 100644 index f48c9b4a8..000000000 --- a/mindtorch/distributed/elastic/multiprocessing/errors/handlers.py +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -# Multiprocessing error-reporting module - - -from mindtorch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler - - -__all__ = ["get_error_handler"] - - -def get_error_handler(): - return ErrorHandler() diff --git a/mindtorch/distributed/elastic/multiprocessing/redirects.py b/mindtorch/distributed/elastic/multiprocessing/redirects.py deleted file mode 100644 index 4553fbebd..000000000 --- a/mindtorch/distributed/elastic/multiprocessing/redirects.py +++ /dev/null @@ -1,104 +0,0 @@ -# mypy: allow-untyped-defs -# !/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
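The guard spelled out in the long comment above is easy to lose among the logging calls. A condensed, hypothetical sketch of the same overwrite behavior; ``write_parent_error_file`` is an illustrative name, not part of this module::

    import json
    import logging
    import os

    logger = logging.getLogger(__name__)

    def write_parent_error_file(path, rootcause_error):
        # A stale file can exist when parent and child share the same
        # TORCHELASTIC_ERROR_FILE value and the child wrote first; log the
        # old contents, then overwrite with the root-cause payload.
        if os.path.isfile(path):
            try:
                with open(path) as fp:
                    old = json.dumps(json.load(fp), indent=2)
                logger.warning("%s already exists and will be overwritten:\n%s", path, old)
            except json.JSONDecodeError:
                logger.warning("%s already exists and will be overwritten (not valid JSON).", path)
            os.remove(path)
        with open(path, "w") as fp:
            json.dump(rootcause_error, fp)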
- -# Taken and modified from original source: -# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ -import ctypes -import logging -import os -import sys -from contextlib import contextmanager -from functools import partial - - -IS_WINDOWS = sys.platform == "win32" -IS_MACOS = sys.platform == "darwin" - - -logger = logging.getLogger(__name__) - - -def get_libc(): - if IS_WINDOWS or IS_MACOS: - logger.warning( - "NOTE: Redirects are currently not supported in Windows or MacOs." - ) - return None - else: - return ctypes.CDLL("libc.so.6") - - -libc = get_libc() - - -def _c_std(stream: str): - return ctypes.c_void_p.in_dll(libc, stream) - - -def _python_std(stream: str): - return {"stdout": sys.stdout, "stderr": sys.stderr}[stream] - - -_VALID_STD = {"stdout", "stderr"} - - -@contextmanager -def redirect(std: str, to_file: str): - """ - Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``. - - This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``). - See usage for details. - - Directory of ``dst_filename`` is assumed to exist and the destination file - is overwritten if it already exists. - - .. note:: Due to buffering cross source writes are not guaranteed to - appear in wall-clock order. For instance in the example below - it is possible for the C-outputs to appear before the python - outputs in the log file. - - Usage: - - :: - - # syntactic-sugar for redirect("stdout", "tmp/stdout.log") - with redirect_stdout("/tmp/stdout.log"): - print("python stdouts are redirected") - libc = ctypes.CDLL("libc.so.6") - libc.printf(b"c stdouts are also redirected" - os.system("echo system stdouts are also redirected") - - print("stdout restored") - - """ - if std not in _VALID_STD: - raise ValueError( - f"unknown standard stream <{std}>, must be one of {_VALID_STD}" - ) - - c_std = _c_std(std) - python_std = _python_std(std) - std_fd = python_std.fileno() - - def _redirect(dst): - libc.fflush(c_std) - python_std.flush() - os.dup2(dst.fileno(), std_fd) - - with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst: - _redirect(dst) - try: - yield - finally: - _redirect(orig_std) - - -redirect_stdout = partial(redirect, "stdout") -redirect_stderr = partial(redirect, "stderr") diff --git a/mindtorch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py b/mindtorch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py deleted file mode 100644 index 24cb6f8e5..000000000 --- a/mindtorch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
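The ``redirect`` context manager deleted below works at the file-descriptor level, which is why output from C extensions and subprocesses is captured as well as Python's own. A self-contained sketch of the core ``dup``/``dup2`` dance; POSIX assumed, and ``redirect_stdout_fd`` is an illustrative name rather than the module's API::

    import os
    import sys
    from contextlib import contextmanager

    @contextmanager
    def redirect_stdout_fd(to_file):
        # Duplicate the current fd 1 so it can be restored afterwards.
        sys.stdout.flush()
        saved_fd = os.dup(1)
        try:
            with open(to_file, "w") as dst:
                os.dup2(dst.fileno(), 1)  # fd 1 now points at the file
                try:
                    yield
                finally:
                    sys.stdout.flush()
                    os.dup2(saved_fd, 1)  # restore the original stdout
        finally:
            os.close(saved_fd)

    with redirect_stdout_fd("/tmp/stdout.log"):
        print("python write, captured")
        os.system("echo subprocess write, also captured")
    print("back on the console")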
-from mindtorch.distributed.elastic.multiprocessing.subprocess_handler.handlers import ( - get_subprocess_handler, -) -from mindtorch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( - SubprocessHandler, -) - - -__all__ = ["SubprocessHandler", "get_subprocess_handler"] diff --git a/mindtorch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py b/mindtorch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py deleted file mode 100644 index 1ff071ec9..000000000 --- a/mindtorch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from typing import Dict, Tuple - -from mindtorch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import ( - SubprocessHandler, -) - - -__all__ = ["get_subprocess_handler"] - - -def get_subprocess_handler( - entrypoint: str, - args: Tuple, - env: Dict[str, str], - stdout: str, - stderr: str, - local_rank_id: int, -): - return SubprocessHandler( - entrypoint=entrypoint, - args=args, - env=env, - stdout=stdout, - stderr=stderr, - local_rank_id=local_rank_id, - ) diff --git a/mindtorch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/mindtorch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py deleted file mode 100644 index a00905af4..000000000 --- a/mindtorch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import os -import signal -import subprocess -import sys -from typing import Any, Dict, Optional, Tuple - - -__all__ = ["SubprocessHandler"] - -IS_WINDOWS = sys.platform == "win32" - - -def _get_default_signal() -> signal.Signals: - """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" - if IS_WINDOWS: - return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 - else: - return signal.SIGTERM - - -class SubprocessHandler: - """ - Convenience wrapper around python's ``subprocess.Popen``. Keeps track of - meta-objects associated to the process (e.g. stdout and stderr redirect fds). - """ - - def __init__( - self, - entrypoint: str, - args: Tuple, - env: Dict[str, str], - stdout: Optional[str], - stderr: Optional[str], - local_rank_id: int, - ): - self._stdout = open(stdout, "w") if stdout else None - self._stderr = open(stderr, "w") if stderr else None - # inherit parent environment vars - env_vars = os.environ.copy() - env_vars.update(env) - - args_str = (entrypoint, *[str(e) for e in args]) - self.local_rank_id = local_rank_id - self.proc: subprocess.Popen = self._popen(args_str, env_vars) - - def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen: - kwargs: Dict[str, Any] = {} - if not IS_WINDOWS: - kwargs["start_new_session"] = True - return subprocess.Popen( - # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes], - # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got - # `Tuple[str, *Tuple[Any, ...]]`. 
- args=args, - env=env, - stdout=self._stdout, - stderr=self._stderr, - **kwargs, - ) - - def close(self, death_sig: Optional[signal.Signals] = None) -> None: - if not death_sig: - death_sig = _get_default_signal() - if IS_WINDOWS: - self.proc.send_signal(death_sig) - else: - os.killpg(self.proc.pid, death_sig) - if self._stdout: - self._stdout.close() - if self._stderr: - self._stderr.close() diff --git a/mindtorch/distributed/elastic/multiprocessing/tail_log.py b/mindtorch/distributed/elastic/multiprocessing/tail_log.py deleted file mode 100644 index 9d4e649c3..000000000 --- a/mindtorch/distributed/elastic/multiprocessing/tail_log.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import os -import time -from concurrent.futures.thread import ThreadPoolExecutor -from threading import Event -from typing import Dict, List, Optional, TextIO, TYPE_CHECKING - - -if TYPE_CHECKING: - from concurrent.futures._base import Future - -__all__ = ["tail_logfile", "TailLog"] - -logger = logging.getLogger(__name__) - - -def tail_logfile( - header: str, file: str, dst: TextIO, finished: Event, interval_sec: float -): - while not os.path.exists(file): - if finished.is_set(): - return - time.sleep(interval_sec) - - with open(file, errors="replace") as fp: - while True: - line = fp.readline() - - if line: - dst.write(f"{header}{line}") - else: # reached EOF - if finished.is_set(): - # log line producer is finished - break - else: - # log line producer is still going - # wait for a bit before looping again - time.sleep(interval_sec) - - -class TailLog: - """ - Tail the given log files. - - The log files do not have to exist when the ``start()`` method is called. The tail-er will gracefully wait until - the log files are created by the producer and will tail the contents of the - log files until the ``stop()`` method is called. - - .. warning:: ``TailLog`` will wait indefinitely for the log file to be created! - - Each log file's line will be suffixed with a header of the form: ``[{name}{idx}]:``, - where the ``name`` is user-provided and ``idx`` is the index of the log file - in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the - header for each log file. - - Usage: - - :: - - log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"} - tailer = TailLog("trainer", log_files, sys.stdout).start() - # actually run the trainers to produce 0_stdout.log and 1_stdout.log - run_trainers() - tailer.stop() - - # once run_trainers() start writing the ##_stdout.log files - # the tailer will print to sys.stdout: - # >>> [trainer0]:log_line1 - # >>> [trainer1]:log_line1 - # >>> [trainer0]:log_line2 - # >>> [trainer0]:log_line3 - # >>> [trainer1]:log_line2 - - .. note:: Due to buffering log lines between files may not necessarily - be printed out in order. You should configure your application's - logger to suffix each log line with a proper timestamp. 
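The ``log_line_prefixes`` override mentioned in the docstring does not appear in its usage example. A sketch against the module as it stood before this removal; paths and prefixes are placeholders, and the file writes stand in for real worker processes::

    import sys
    import time

    from mindtorch.distributed.elastic.multiprocessing.tail_log import TailLog

    log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"}
    # Without this argument, lines would carry "[trainer0]:" / "[trainer1]:" headers.
    prefixes = {0: "[rank0|host-a]:", 1: "[rank1|host-b]:"}

    tailer = TailLog("trainer", log_files, sys.stdout, log_line_prefixes=prefixes).start()
    for rank in (0, 1):  # stand-in producers for the worker stdout files
        with open(log_files[rank], "a") as fp:
            fp.write(f"hello from rank {rank}\n")
    time.sleep(0.5)  # give the tail threads a chance to drain the files
    tailer.stop()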
- - """ - - def __init__( - self, - name: str, - log_files: Dict[int, str], - dst: TextIO, - log_line_prefixes: Optional[Dict[int, str]] = None, - interval_sec: float = 0.1, - ): - n = len(log_files) - self._threadpool = None - if n > 0: - self._threadpool = ThreadPoolExecutor( - max_workers=n, - thread_name_prefix=f"{self.__class__.__qualname__}_{name}", - ) - - self._name = name - self._dst = dst - self._log_files = log_files - self._log_line_prefixes = log_line_prefixes - self._finished_events: Dict[int, Event] = { - local_rank: Event() for local_rank in log_files.keys() - } - self._futs: List[Future] = [] - self._interval_sec = interval_sec - self._stopped = False - - def start(self) -> "TailLog": - if not self._threadpool: - return self - - for local_rank, file in self._log_files.items(): - header = f"[{self._name}{local_rank}]:" - if self._log_line_prefixes and local_rank in self._log_line_prefixes: - header = self._log_line_prefixes[local_rank] - self._futs.append( - self._threadpool.submit( - tail_logfile, - header=header, - file=file, - dst=self._dst, - finished=self._finished_events[local_rank], - interval_sec=self._interval_sec, - ) - ) - return self - - def stop(self) -> None: - for finished in self._finished_events.values(): - finished.set() - - for local_rank, f in enumerate(self._futs): - try: - f.result() - except Exception as e: - logger.error( - "error in log tailor for %s%s. %s: %s", - self._name, - local_rank, - e.__class__.__qualname__, - e, - ) - - if self._threadpool: - self._threadpool.shutdown(wait=True) - - self._stopped = True - - def stopped(self) -> bool: - return self._stopped diff --git a/mindtorch/distributed/elastic/rendezvous/__init__.py b/mindtorch/distributed/elastic/rendezvous/__init__.py deleted file mode 100644 index a9329ae69..000000000 --- a/mindtorch/distributed/elastic/rendezvous/__init__.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -In the context of Torch Distributed Elastic we use the term *rendezvous* to -refer to a particular functionality that combines a **distributed -synchronization** primitive with **peer discovery**. - -It is used by Torch Distributed Elastic to gather participants of a training -job (i.e. nodes) such that they all agree on the same list of participants and -everyone's roles, as well as make a consistent collective decision on when -training can begin/resume. - -Torch Distributed Elastic rendezvous provides the following critical -functionalities: - -**Barrier**: - -Nodes performing rendezvous will all block until the rendezvous is considered -complete - this happens when at least ``min`` total number of nodes have joined -the rendezvous barrier (for the same job). This also implies the barrier is not -necessarily of fixed size. - -There's an additional small waiting time after reaching ``min`` number of -nodes - this is used to ensure the rendezvous is not completed "too quickly" -(which could potentially exclude additional nodes attempting to join at -approximately the same time). - -If ``max`` number of nodes is gathered at the barrier, the rendezvous is -completed immediately. 
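The completion rule described above (complete immediately at ``max``, otherwise hold a short last-call window once ``min`` is reached) can be modeled in a few lines. A toy, single-process sketch only; real handlers coordinate this state through a shared store::

    import time
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _ToyBarrier:
        min_nodes: int
        max_nodes: int
        last_call_sec: float                  # grace period after reaching `min`
        joined: int = 0
        _min_reached_at: Optional[float] = None

        def join(self) -> None:
            self.joined += 1
            if self.joined >= self.min_nodes and self._min_reached_at is None:
                self._min_reached_at = time.monotonic()

        def is_complete(self) -> bool:
            if self.joined >= self.max_nodes:
                return True                   # full barrier completes immediately
            if self._min_reached_at is not None:
                # Hold the door open briefly so near-simultaneous joiners get in.
                return time.monotonic() - self._min_reached_at >= self.last_call_sec
            return False

    b = _ToyBarrier(min_nodes=2, max_nodes=4, last_call_sec=0.1)
    b.join(); b.join()
    print(b.is_complete())   # False until the last-call window elapses
    time.sleep(0.2)
    print(b.is_complete())   # True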
-
-There's also an overall timeout which causes the rendezvous to fail if ``min``
-number of nodes is never reached - this is meant to be a simple fail-safe to
-help release partially allocated job resources, in case there's a problem with
-the resource manager, and is meant to be interpreted as non-retryable.
-
-**Exclusivity**:
-
-A simple distributed barrier would not be sufficient, as we also need to ensure
-that only one group of nodes exists at any given time (for a given job). In
-other words, new nodes (i.e. joining late) should not be able to form a parallel
-independent group of workers for the same job.
-
-Torch Distributed Elastic rendezvous ensures that if a group of nodes has
-already completed a rendezvous (and hence might already be training), then
-additional "late" nodes attempting to rendezvous will only announce themselves
-as waiting, and will have to wait until the (previously completed) existing
-rendezvous is destroyed first.
-
-**Consistency**:
-
-When a rendezvous is completed, all its members will agree on the job membership
-and everyone's role in it. This role is represented using an integer, called
-rank, that is between 0 and world size.
-
-Note that ranks are *not stable*, in the sense that the same node can be
-assigned a different rank in the next (re-)rendezvous.
-
-**Fault-tolerance**:
-
-Torch Distributed Elastic rendezvous is designed to tolerate node failures
-during the rendezvous process. Should a process crash (or lose network
-connectivity, etc.) between joining the rendezvous and its completion, then
-a re-rendezvous with the remaining healthy nodes will happen automatically.
-
-A node can also fail *after* it has completed (or *has been observed* by other
-nodes to have completed) the rendezvous - this scenario will be handled by the
-Torch Distributed Elastic ``train_loop`` instead (where it will also trigger a
-re-rendezvous).
-
-**Shared key-value store**:
-
-When the rendezvous is completed, a shared key-value store is created and
-returned. This store implements a ``mindtorch.distributed.Store`` API (see
-`distributed communication docs
-`__).
-
-This store is only shared by the members of the completed rendezvous. It
-is intended to be used by Torch Distributed Elastic to exchange information
-necessary to initialize job control and data-planes.
-
-**Waiting workers and rendezvous closing**:
-
-The Torch Distributed Elastic rendezvous handler object provides additional
-functionalities, which are technically not part of the rendezvous process:
-
-1. Querying how many workers arrived late at the barrier, who can participate in
-   *next* rendezvous.
-
-2. Setting the rendezvous *closed* to signal all nodes not to participate in
-   next rendezvous.
-
-**DynamicRendezvousHandler**:
-
-Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler`
-class that implements the rendezvous mechanism described above. It is a backend-
-agnostic type that expects a particular :py:class:`.RendezvousBackend` instance
-to be specified during construction.
-
-Torch distributed users can either implement their own backend type or use one
-of the following implementations that come with PyTorch:
-
-- :py:class:`.C10dRendezvousBackend`: Uses a C10d store (by default
-  ``TCPStore``) as the rendezvous backend. The main advantage of using a C10d
-  store is that it requires no 3rd-party dependency (such as etcd) to establish
-  a rendezvous.
-- :py:class:`.EtcdRendezvousBackend`: Supersedes the legacy
-  :py:class:`.EtcdRendezvousHandler` class.
Passing an
-  :py:class:`.EtcdRendezvousBackend` instance to
-  :py:class:`.DynamicRendezvousHandler` is functionally equivalent to
-  instantiating an :py:class:`.EtcdRendezvousHandler`.
-
-  ::
-
-     store = TCPStore("localhost")
-
-     backend = C10dRendezvousBackend(store, "my_run_id")
-
-     rdzv_handler = DynamicRendezvousHandler.from_backend(
-         run_id="my_run_id",
-         store=store,
-         backend=backend,
-         min_nodes=2,
-         max_nodes=4
-     )
-"""
-
-from .api import (
-    rendezvous_handler_registry,
-    RendezvousClosedError,
-    RendezvousConnectionError,
-    RendezvousError,
-    RendezvousGracefulExitError,
-    RendezvousHandler,
-    RendezvousHandlerCreator,
-    RendezvousHandlerRegistry,
-    RendezvousInfo,
-    RendezvousParameters,
-    RendezvousStateError,
-    RendezvousStoreInfo,
-    RendezvousTimeoutError,
-)
-from .registry import _register_default_handlers, _register_out_of_tree_handlers
-
-
-_register_default_handlers()
-_register_out_of_tree_handlers()
-
-
-__all__ = [
-    "RendezvousClosedError",
-    "RendezvousConnectionError",
-    "RendezvousError",
-    "RendezvousGracefulExitError",
-    "RendezvousHandler",
-    "RendezvousHandlerCreator",
-    "RendezvousHandlerRegistry",
-    "RendezvousInfo",
-    "RendezvousParameters",
-    "RendezvousStateError",
-    "RendezvousStoreInfo",
-    "RendezvousTimeoutError",
-    "rendezvous_handler_registry",
-]
diff --git a/mindtorch/distributed/elastic/rendezvous/api.py b/mindtorch/distributed/elastic/rendezvous/api.py
deleted file mode 100644
index 70eefd868..000000000
--- a/mindtorch/distributed/elastic/rendezvous/api.py
+++ /dev/null
@@ -1,384 +0,0 @@
-# mypy: allow-untyped-defs
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import socket
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import Any, Callable, ClassVar, Dict, Optional
-
-from mindtorch.distributed import Store
-from mindtorch.distributed.elastic.utils.distributed import get_free_port
-
-
-__all__ = [
-    "RendezvousClosedError",
-    "RendezvousConnectionError",
-    "RendezvousError",
-    "RendezvousGracefulExitError",
-    "RendezvousHandler",
-    "RendezvousHandlerCreator",
-    "RendezvousHandlerRegistry",
-    "RendezvousInfo",
-    "RendezvousParameters",
-    "RendezvousStateError",
-    "RendezvousStoreInfo",
-    "RendezvousTimeoutError",
-    "rendezvous_handler_registry",
-]
-
-
-class RendezvousError(Exception):
-    """Represents the base type for rendezvous errors."""
-
-
-class RendezvousClosedError(RendezvousError):
-    """Raised when a rendezvous is closed."""
-
-
-class RendezvousTimeoutError(RendezvousError):
-    """Raised when a rendezvous did not complete on time."""
-
-
-class RendezvousConnectionError(RendezvousError):
-    """Raised when the connection to a rendezvous backend has failed."""
-
-
-class RendezvousStateError(RendezvousError):
-    """Raised when the state of a rendezvous is corrupt."""
-
-
-class RendezvousGracefulExitError(RendezvousError):
-    """Raised when a node was not included in the rendezvous and exits gracefully.
-
-    The exception is a mechanism to exit the stack; it does not indicate a failure.
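The error taxonomy above implies a retry discipline: connection errors to the backend are transient and worth retrying, while closed and timeout errors are terminal by design. A hypothetical wrapper illustrating that policy against the package as it stood before this removal; ``next_rendezvous_with_retry`` is not part of the module::

    import time

    from mindtorch.distributed.elastic.rendezvous import (
        RendezvousClosedError,
        RendezvousConnectionError,
        RendezvousTimeoutError,
    )

    def next_rendezvous_with_retry(handler, attempts=3, backoff_sec=5.0):
        for attempt in range(1, attempts + 1):
            try:
                return handler.next_rendezvous()
            except (RendezvousClosedError, RendezvousTimeoutError):
                raise  # terminal: job shut down, or the fail-safe timeout fired
            except RendezvousConnectionError:
                if attempt == attempts:
                    raise  # backend stayed unreachable; give up
                time.sleep(backoff_sec)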
- """ - - -@dataclass -class RendezvousStoreInfo: - """Store address and port that can be used to bootstrap trainer distributed comms""" - - MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR" - MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT" - master_addr: str - master_port: int - - @staticmethod - def build( - rank: int, - store: Store, - local_addr: Optional[str], - server_port: Optional[int] = None, - ) -> "RendezvousStoreInfo": - """Factory method, finds unused new port on rank0 host and addr/port info with all ranks. - - If master_addr/master_port is knowns (useful when sharing existing tcp store server) use the constructor. - - Args: - rank: rank of the current node - store: store to use for rendezvous - local_addr: address of the current node, if not provided will be resolved from hostname - server_port: port of the TCPStore server, when the TCPStore is shared. - """ - # TODO swap to collectives comms API - if rank == 0: - addr = local_addr or socket.getfqdn() - # When TCPStore is not shared, we fallback to get_free_port. - port = server_port or get_free_port() - store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type] - store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type] - - addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8") - port = int( - store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8") - ) - return RendezvousStoreInfo(master_addr=addr, master_port=port) - - -class RendezvousInfo: - """Holds the information about the rendezvous.""" - - def __init__( - self, - store: Store, - rank: int, - world_size: int, - bootstrap_store_info: RendezvousStoreInfo, - ): - self._store = store - self._rank = rank - self._world_size = world_size - self._bootstrap_store_info = bootstrap_store_info - - @property - def store(self) -> Store: - """Store used by torchelastic control plane""" - return self._store - - @property - def rank(self) -> int: - """Rank within a group""" - return self._rank - - @property - def world_size(self) -> int: - """Global group size""" - return self._world_size - - @property - def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]: - """Store information that can used by trainer code to bootstrap distributed comms.""" - return self._bootstrap_store_info - - -class RendezvousHandler(ABC): - """Main rendezvous interface. - - Note: - Distributed Torch users normally **do not** need to implement their own - ``RendezvousHandler``. An implementation based on C10d Store is already - provided, and is recommended for most users. - """ - - @abstractmethod - def get_backend(self) -> str: - """Return the name of the rendezvous backend.""" - - @property - def use_agent_store(self) -> bool: - """Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user - applications and will be available during application lifecyle. - - Rendezous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`. - Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store. - """ - return False - - @abstractmethod - def next_rendezvous(self) -> RendezvousInfo: - """Main entry-point into the rendezvous barrier. - - Blocks until the rendezvous is complete and the current process is - included in the formed worker group, or a timeout occurs, or the - rendezvous was marked closed. - - Returns: - Instance of :py:class:`RendezvousInfo`. 
- - Raises: - RendezvousClosedError: - The rendezvous is closed. - RendezvousConnectionError: - The connection to the rendezvous backend has failed. - RendezvousStateError: - The rendezvous state is corrupt. - RendezvousTimeoutError: - The rendezvous did not complete on time. - """ - - @abstractmethod - def is_closed(self) -> bool: - """Check whether the rendezvous has been closed. - - A closed rendezvous means all future attempts to re-rendezvous within - same job will fail. - - ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual - propagation and should not be used for synchronization. The intention is - that if at least one node decides the job is finished, it will close the - rendezvous, and other nodes will soon observe this and stop running as - well. - """ - - @abstractmethod - def set_closed(self): - """Mark the rendezvous as closed.""" - - @abstractmethod - def num_nodes_waiting(self) -> int: - """Return the number of nodes who arrived late at the rendezvous - barrier, hence were not included in the current worker group. - - Callers should periodically call this method to check whether new - nodes are waiting to join the job and if so admit them by calling - :py:meth:`next_rendezvous()` (re-rendezvous). - """ - - @abstractmethod - def get_run_id(self) -> str: - """Return the run id of the rendezvous. - - The run id is a user-defined id that uniquely identifies an instance of - a distributed application. It typically maps to a job id and is used to - allow nodes to join the correct distributed application. - """ - - @abstractmethod - def shutdown(self) -> bool: - """Close all resources that were open for the rendezvous. - - Example:: - - rdzv_handler = ... - try: - store, rank, world_size = rdzv_handler.next_rendezvous() - finally: - rdzv_handler.shutdown() - """ - - -class RendezvousParameters: - """Hold the parameters to construct a :py:class:`RendezvousHandler`. - - Args: - backend: - The name of the backend to use to handle the rendezvous. - endpoint: - The endpoint of the rendezvous, usually in form [:]. - run_id: - The id of the rendezvous. - min_nodes: - The minimum number of nodes to admit to the rendezvous. - max_nodes: - The maximum number of nodes to admit to the rendezvous. - local_addr: - The address of the local node. - **kwargs: - Additional parameters for the specified backend. - """ - - def __init__( - self, - backend: str, - endpoint: str, - run_id: str, - min_nodes: int, - max_nodes: int, - local_addr: Optional[str] = None, - **kwargs, - ): - if not backend: - raise ValueError("The rendezvous backend name must be a non-empty string.") - - if min_nodes < 1: - raise ValueError( - f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero." - ) - if max_nodes < min_nodes: - raise ValueError( - f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or " - f"equal to the minimum number of rendezvous nodes ({min_nodes})." 
- ) - - self.backend = backend - self.endpoint = endpoint - self.run_id = run_id - self.min_nodes = min_nodes - self.max_nodes = max_nodes - self.config = kwargs - self.local_addr = local_addr - - def get(self, key: str, default: Any = None) -> Any: - """Return the value for ``key`` if ``key`` exists, else ``default``.""" - return self.config.get(key, default) - - def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]: - """Return the value for ``key`` as a ``bool``.""" - value = self.get(key, default) - if value is None or isinstance(value, bool): - return value - if isinstance(value, int): - if value == 1: - return True - if value == 0: - return False - elif isinstance(value, str): - if value.lower() in ["1", "true", "t", "yes", "y"]: - return True - if value.lower() in ["0", "false", "f", "no", "n"]: - return False - raise ValueError( - f"The rendezvous configuration option '{key}' does not represent a valid boolean value." - ) - - def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]: - """Return the value for ``key`` as an ``int``.""" - value = self.get(key, default) - if value is None: - return value - try: - return int(value) - except ValueError as e: - raise ValueError( - f"The rendezvous configuration option '{key}' does not represent a valid integer " - "value." - ) from e - - -RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler] - - -class RendezvousHandlerRegistry: - """Represent a registry of :py:class:`RendezvousHandler` backends.""" - - _registry: Dict[str, RendezvousHandlerCreator] - - def __init__(self) -> None: - self._registry = {} - - def register(self, backend: str, creator: RendezvousHandlerCreator) -> None: - """Register a new rendezvous backend. - - Args: - backend: - The name of the backend. - creator: - The callback to invoke to construct the - :py:class:`RendezvousHandler`. - """ - if not backend: - raise ValueError("The rendezvous backend name must be a non-empty string.") - - current_creator: Optional[RendezvousHandlerCreator] - try: - current_creator = self._registry[backend] - except KeyError: - current_creator = None - - if current_creator is not None and current_creator != creator: - raise ValueError( - f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it " - f"is already registered with '{current_creator}'." - ) - - self._registry[backend] = creator - - def create_handler(self, params: RendezvousParameters) -> RendezvousHandler: - """Create a new :py:class:`RendezvousHandler`.""" - try: - creator = self._registry[params.backend] - except KeyError as e: - raise ValueError( - f"The rendezvous backend '{params.backend}' is not registered. Did you forget " - f"to call `{self.register.__name__}`?" - ) from e - - handler = creator(params) - - # Do some sanity check. - if handler.get_backend() != params.backend: - raise RuntimeError( - f"The rendezvous backend '{handler.get_backend()}' does not match the requested " - f"backend '{params.backend}'." - ) - - return handler - - -# The default global registry instance used by launcher scripts to instantiate -# rendezvous handlers. 
-rendezvous_handler_registry = RendezvousHandlerRegistry() diff --git a/mindtorch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/mindtorch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py deleted file mode 100644 index 51d449dcd..000000000 --- a/mindtorch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py +++ /dev/null @@ -1,273 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import binascii -import logging -import os -import tempfile -from base64 import b64decode, b64encode -from datetime import timedelta -from typing import Any, cast, Optional, Tuple - -from mindtorch.distributed import FileStore, Store, TCPStore -from mindtorch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState - -from .api import ( - RendezvousConnectionError, - RendezvousError, - RendezvousParameters, - RendezvousStateError, -) -from .dynamic_rendezvous import RendezvousBackend, Token -from .utils import _matches_machine_hostname, parse_rendezvous_endpoint - - -logger = logging.getLogger(__name__) - -# default port for the TCP store -DEFAULT_PORT = 29400 - - -class C10dRendezvousBackend(RendezvousBackend): - """Represents a C10d-backed rendezvous backend. - - Args: - store: - The :py:class:`mindtorch.distributed.Store` instance to use to - communicate with the C10d store. - run_id: - The run id of the rendezvous. - """ - - # See the explanation in the __init__ method. - _NULL_SENTINEL = "Y2FuaW1hZGFt" - - _store: Store - _key: str - - def __init__(self, store: Store, run_id: str) -> None: - if not run_id: - raise ValueError("The run id must be a non-empty string.") - - self._store = store - - self._key = "mindtorch.rendezvous." + run_id - - # The read operation of a store blocks the caller until the specified - # key becomes available. This behavior makes it tricky to use a store - # as a regular key-value dictionary. - # - # As a workaround we initially set a sentinel value as the rendezvous - # state. Whenever this value gets returned we treat it as a None. - self._call_store("compare_set", self._key, "", self._NULL_SENTINEL) - - @property - def name(self) -> str: - """See base class.""" - return "c10d" - - def get_state(self) -> Optional[Tuple[bytes, Token]]: - """See base class.""" - base64_state: bytes = self._call_store("get", self._key) - - return self._decode_state(base64_state) - - def set_state( - self, state: bytes, token: Optional[Token] = None - ) -> Optional[Tuple[bytes, Token, bool]]: - """See base class.""" - base64_state_str: str = b64encode(state).decode() - - if token: - # Shortcut if we know for sure that the token is not valid. - if not isinstance(token, bytes): - result = self.get_state() - if result is not None: - tmp = *result, False - # Python 3.6 does not support tuple unpacking in return - # statements. - return tmp - return None - - token = token.decode() - else: - token = self._NULL_SENTINEL - - base64_state: bytes = self._call_store( - "compare_set", self._key, token, base64_state_str - ) - - state_token_pair = self._decode_state(base64_state) - if state_token_pair is None: - return None - - new_state, new_token = state_token_pair - - # C10d Store's compare_set method does not offer an easy way to find out - # whether our write attempt was successful. 
As a brute-force solution we - # perform a bitwise comparison of our local state and the remote state. - return new_state, new_token, new_state == state - - def _call_store(self, store_op: str, *args, **kwargs) -> Any: - try: - return getattr(self._store, store_op)(*args, **kwargs) - except (ValueError, RuntimeError, TimeoutError) as exc: - raise RendezvousConnectionError( - "The connection to the C10d store has failed. See inner exception for details." - ) from exc - - def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]: - if base64_state == self._NULL_SENTINEL.encode(): - return None - - try: - state = b64decode(base64_state) - except binascii.Error as exc: - raise RendezvousStateError( - "The state object is corrupt. See inner exception for details." - ) from exc - - return state, base64_state - - -def _create_tcp_store(params: RendezvousParameters) -> TCPStore: - host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT) - - cfg_is_host = params.get_as_bool("is_host") - # If the user has explicitly specified whether our process should host the - # the store, respect it. - if cfg_is_host is not None: - is_host = cfg_is_host - # Otherwise try to determine whether we are the host based on our hostname - # and IP address. - else: - is_host = _matches_machine_hostname(host) - - # The timeout - read_timeout = cast(int, params.get_as_int("read_timeout", 60)) - if read_timeout <= 0: - raise ValueError("The read timeout must be a positive integer.") - - # In specific cases we attempt to instantiate the store twice. For details - # see the explanation in the except clause below. - for is_server in [is_host, False]: - try: - store = TCPStore( - host, - port, - is_master=is_server, - multi_tenant=True, - timeout=timedelta(seconds=read_timeout), - ) - - if is_server: - msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend." - construct_and_record_rdzv_event( - run_id=params.run_id, message=msg, node_state=NodeState.INIT - ) - logger.info(msg) - - break - except (ValueError, RuntimeError, TimeoutError) as exc: - # If we heuristically inferred the value of is_host as True and our - # first attempt to instantiate the TCP store has failed, try it one - # more time with is_host set to False. As an edge case there can be - # more than one process that is part of the same rendezvous on this - # machine and only one of them will eventually host the store. - - if not is_server or cfg_is_host is not None: - raise RendezvousConnectionError( - "The connection to the C10d store has failed. See inner exception for details." - ) from exc - - return store # type: ignore[possibly-undefined] - - -def _create_file_store(params: RendezvousParameters) -> FileStore: - # If a user specifies an endpoint, we treat it as a path to a file. - if params.endpoint: - path = params.endpoint - else: - try: - # The temporary file is readable and writable only by the user of - # this process. - _, path = tempfile.mkstemp() - except OSError as exc: - raise RendezvousError( - "The file creation for C10d store has failed. See inner exception for details." - ) from exc - - try: - store = FileStore(path) - except (ValueError, RuntimeError) as exc: - raise RendezvousConnectionError( - "The connection to the C10d store has failed. See inner exception for details." 
- ) from exc - - return store - - -def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]: - """Create a new :py:class:`C10dRendezvousBackend` from the specified parameters. - - +--------------+-----------------------------------------------------------+ - | Parameter | Description | - +==============+===========================================================+ - | store_type | The type of the C10d store. The currently supported types | - | | are "tcp" and "file" which correspond to | - | | :py:class:`mindtorch.distributed.TCPStore` and | - | | :py:class:`mindtorch.distributed.FileStore`, respectively. | - | | Defaults to "tcp". | - +--------------+-----------------------------------------------------------+ - | read_timeout | The read timeout, in seconds, for store operations. | - | | Defaults to 60 seconds. | - | | | - | | Note this only applies to | - | | :py:class:`mindtorch.distributed.TCPStore`. It is not relevant| - | | to :py:class:`mindtorch.distributed.FileStore` which does not | - | | take in timeout as a parameter. | - +--------------+-----------------------------------------------------------+ - | is_host | A boolean value indicating whether this backend instance | - | | will host the C10d store. If not specified it will be | - | | inferred heuristically by matching the hostname or the IP | - | | address of this machine against the specified rendezvous | - | | endpoint. Defaults to ``None``. | - | | | - | | Note that this configuration option only applies to | - | | :py:class:`mindtorch.distributed.TCPStore`. In normal | - | | circumstances you can safely skip it; the only time when | - | | it is needed is if its value cannot be correctly | - | | determined (e.g. the rendezvous endpoint has a CNAME as | - | | the hostname or does not match the FQDN of the machine). | - +--------------+-----------------------------------------------------------+ - """ - # As of today we only support TCPStore and FileStore. Other store types do - # not have the required functionality (e.g. compare_set) yet. - store_type = params.get("store_type", "tcp").strip().lower() - store: Store - - try: - if store_type == "file": - store = _create_file_store(params) - elif store_type == "tcp": - store = _create_tcp_store(params) - else: - raise ValueError( - "Invalid store type given. Currently only supports file and tcp." - ) - - backend = C10dRendezvousBackend(store, params.run_id) - - except Exception as e: - construct_and_record_rdzv_event( - message=f"{type(e).__name__}: {str(e)}", - run_id=params.run_id, - node_state=NodeState.FAILED, - ) - raise - - return backend, store diff --git a/mindtorch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/mindtorch/distributed/elastic/rendezvous/dynamic_rendezvous.py deleted file mode 100644 index 2c71a7b03..000000000 --- a/mindtorch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ /dev/null @@ -1,1431 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
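The parameter table above maps one-to-one onto ``RendezvousParameters`` keyword arguments, which land in ``params.config`` and are read back with ``get``/``get_as_bool``/``get_as_int``. A sketch of constructing the backend as the module allowed before this removal; the endpoint and option values are placeholders::

    from mindtorch.distributed.elastic.rendezvous.api import RendezvousParameters
    from mindtorch.distributed.elastic.rendezvous.c10d_rendezvous_backend import (
        create_backend,
    )

    params = RendezvousParameters(
        backend="c10d",
        endpoint="train-node-0:29400",  # host[:port]; the port defaults to 29400
        run_id="my_run_id",
        min_nodes=2,
        max_nodes=4,
        store_type="tcp",    # or "file" to back the rendezvous with a FileStore
        read_timeout=120,    # store read timeout in seconds (TCPStore only)
        is_host=True,        # bypass the hostname heuristic on the hosting node
    )

    backend, store = create_backend(params)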
- -import inspect -import logging -import os -import pickle -import socket -import threading -import time -import weakref -from abc import ABC, abstractmethod -from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Set, Tuple - -import mindtorch.distributed as dist -from mindtorch.distributed import Store -from mindtorch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState - -from .api import ( - RendezvousClosedError, - RendezvousError, - RendezvousGracefulExitError, - RendezvousHandler, - RendezvousInfo, - RendezvousParameters, - RendezvousStateError, - RendezvousStoreInfo, - RendezvousTimeoutError, -) -from .utils import _delay, _PeriodicTimer - - -__all__ = [ - "RendezvousBackend", - "RendezvousTimeout", - "RendezvousSettings", - "DynamicRendezvousHandler", - "create_handler", -] - -logger = logging.getLogger(__name__) - - -def get_method_name(depth=2): - if len(inspect.stack()) > depth: - return inspect.stack()[depth].function - return "no_method_name" - - -Token = Any -"""Represent an opaque fencing token used by the rendezvous backend.""" - - -class RendezvousBackend(ABC): - """Represent a backend that holds the rendezvous state.""" - - @property - @abstractmethod - def name(self) -> str: - """Get the name of the backend.""" - - @abstractmethod - def get_state(self) -> Optional[Tuple[bytes, Token]]: - """Get the rendezvous state. - - Returns: - A tuple of the encoded rendezvous state and its fencing token or - ``None`` if no state is found in the backend. - - Raises: - RendezvousConnectionError: - The connection to the backend has failed. - RendezvousStateError: - The rendezvous state is corrupt. - """ - - @abstractmethod - def set_state( - self, state: bytes, token: Optional[Token] = None - ) -> Optional[Tuple[bytes, Token, bool]]: - """Set the rendezvous state. - - The new rendezvous state is set conditionally: - - - If the specified ``token`` matches the fencing token stored in the - backend, the state will be updated. The new state will be returned - to the caller along with its fencing token. - - If the specified ``token`` does not match the fencing token stored - in the backend, the state won't be updated; instead the existing - state along with its fencing token will be returned to the caller. - - If the specified ``token`` is ``None``, the new state will be set - only if there is no existing state in the backend. Either the new - state or the existing state along with its fencing token will be - returned to the caller. - - Args: - state: - The encoded rendezvous state. - token: - An optional fencing token that was retrieved by a previous call - to :py:meth:`get_state` or ``set_state()``. - - Returns: - A tuple of the serialized rendezvous state, its fencing token, and - a boolean value indicating whether our set attempt succeeded. - - Raises: - RendezvousConnectionError: - The connection to the backend has failed. - RendezvousStateError: - The rendezvous state is corrupt. - """ - - -class RendezvousTimeout: - """Hold the timeout configuration of a rendezvous. - - Args: - join: - The time within which the rendezvous is expected to complete. - last_call: - An additional wait amount before completing the rendezvous once the - rendezvous has the minimum number of required participants. 
- close: - The time within which the rendezvous is expected to close after a - call to :py:meth:`RendezvousHandler.set_closed` or - :py:meth:`RendezvousHandler.shutdown`. - keep_alive: - The time within which a keep-alive heartbeat is expected to - complete. - """ - - _ZERO = timedelta(0) - - _DEFAULT_TIMEOUTS = { - "join": timedelta(seconds=600), - "last_call": timedelta(seconds=30), - "close": timedelta(seconds=30), - "heartbeat": timedelta(seconds=5), - } - - _join: timedelta - _last_call: timedelta - _close: timedelta - _heartbeat: timedelta - - def __init__( - self, - join: Optional[timedelta] = None, - last_call: Optional[timedelta] = None, - close: Optional[timedelta] = None, - heartbeat: Optional[timedelta] = None, - ) -> None: - self._set_timeouts( - join=join, last_call=last_call, close=close, heartbeat=heartbeat - ) - - @property - def join(self) -> timedelta: - """Get the join timeout.""" - return self._join - - @property - def last_call(self) -> timedelta: - """Get the last call timeout.""" - return self._last_call - - @property - def close(self) -> timedelta: - """Get the close timeout.""" - return self._close - - @property - def heartbeat(self) -> timedelta: - """Get the keep-alive heartbeat timeout.""" - return self._heartbeat - - def _set_timeouts(self, **timeouts: Optional[timedelta]): - for name, timeout in timeouts.items(): - if timeout is None: - timeout = self._DEFAULT_TIMEOUTS[name] - if timeout <= self._ZERO: - raise ValueError(f"The {name} timeout ({timeout}) must be positive.") - setattr(self, "_" + name, timeout) - - -@dataclass(repr=False, eq=False, frozen=True) -class RendezvousSettings: - """Hold the settings of the rendezvous. - - Attributes: - run_id: - The run id of the rendezvous. - min_nodes: - The minimum number of nodes to admit to the rendezvous. - max_nodes: - The maximum number of nodes to admit to the rendezvous. - timeout: - The timeout configuration of the rendezvous. - keep_alive_interval: - The amount of time a node waits before sending a heartbeat to keep - it alive in the rendezvous. - keep_alive_max_attempt: - The maximum number of failed heartbeat attempts after which a node - is considered dead. - """ - - run_id: str - min_nodes: int - max_nodes: int - timeout: RendezvousTimeout - keep_alive_interval: timedelta - keep_alive_max_attempt: int - - -@dataclass(eq=True, order=True, frozen=True) -class _NodeDesc: - """Describe a node in the rendezvous. - - Attributes: - addr: - The FQDN of the node or user specified local node address. - pid: - The id of the process in which the rendezvous handler runs. - local_id: - A process-wide unique id. - """ - - addr: str - pid: int - local_id: int - - def __repr__(self) -> str: - return f"{self.addr}_{self.pid}_{self.local_id}" - - -class _NodeDescGenerator: - """Generate node descriptors. - - A node descriptor is a combination of an FQDN, a process id, and an auto- - incremented integer that uniquely identifies a node in the rendezvous. - """ - - _lock: threading.Lock - _local_id: int - - def __init__(self) -> None: - self._lock = threading.Lock() - - # An integer that is incremented with each call to generate(). - self._local_id = 0 - - def generate(self, local_addr: Optional[str] = None) -> _NodeDesc: - # This method can be called by multiple threads concurrently; therefore, - # we must increment the integer atomically. 
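# (Editorial illustration, not in the original file.) With illustrative
# values, the first call on a host whose FQDN is "worker-0" in process 4242
# yields a descriptor whose repr is "worker-0_4242_0", per _NodeDesc.__repr__.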
with self._lock: - local_id = self._local_id - - self._local_id += 1 - - return _NodeDesc(local_addr or socket.getfqdn(), os.getpid(), local_id) - - -class _RendezvousState: - """Hold the state of a rendezvous. - - Attributes: - round: - The current round of the rendezvous. - complete: - A boolean value indicating whether the current round of the - rendezvous is complete. - deadline: - The time at which the current round of the rendezvous will be - considered complete if it is still waiting for nodes to join. - closed: - A boolean value indicating whether the rendezvous is closed. - participants: - A dictionary of the participants and their corresponding ranks. - wait_list: - A set of nodes that are waiting to participate in the next round of - the rendezvous. - redundancy_list: - A set of nodes that are redundant in the current round and can join - the next rendezvous without triggering re-rendezvous. - last_heartbeats: - A dictionary containing each node's last heartbeat time. - """ - - round: int - complete: bool - deadline: Optional[datetime] - closed: bool - participants: Dict[_NodeDesc, int] - wait_list: Set[_NodeDesc] - redundancy_list: Set[_NodeDesc] - last_heartbeats: Dict[_NodeDesc, datetime] - - def __init__(self) -> None: - self.round = 0 - self.complete = False - self.deadline = None - self.closed = False - self.participants = {} - self.wait_list = set() - self.redundancy_list = set() - self.last_heartbeats = {} - - -def _remove_participant_epilogue( - state: _RendezvousState, settings: RendezvousSettings -) -> None: - if state.complete: - # If we do not have any participants left, move to the next round. - if not state.participants: - msg = "No participants left in the rendezvous, marking rendezvous as incomplete" - logger.debug(msg) - state.complete = False - - state.round += 1 - else: - if len(state.participants) < settings.min_nodes: - msg = ( - f"Number of participants ({len(state.participants)}) is less than " - f"min_nodes {settings.min_nodes}, clearing deadline in state" - ) - logger.debug(msg) - state.deadline = None - - -class _RendezvousStateHolder(ABC): - """Hold the shared rendezvous state synced with other nodes.""" - - @property - @abstractmethod - def state(self) -> _RendezvousState: - """Get the local state.""" - - @abstractmethod - def sync(self) -> Optional[bool]: - """Reads or writes the latest state. - - Returns: - A boolean value indicating whether the local state, in case marked - as dirty, was successfully synced with other nodes. - """ - - @abstractmethod - def mark_dirty(self) -> None: - """Mark the local state as dirty.""" - - -class _BackendRendezvousStateHolder(_RendezvousStateHolder): - """Hold the rendezvous state synced with other nodes via a backend. - - Args: - backend: - The rendezvous backend to use. - settings: - The rendezvous settings. - cache_duration: - The amount of time, in seconds, to cache the last rendezvous state - before requesting it from the backend again.
- """ - - _backend: RendezvousBackend - _state: _RendezvousState - _settings: RendezvousSettings - _cache_duration: int - _token: Token - _dirty: bool - _last_sync_time: float - _dead_nodes: List[_NodeDesc] - - def __init__( - self, - backend: RendezvousBackend, - settings: RendezvousSettings, - cache_duration: int = 1, - ) -> None: - self._backend = backend - self._state = _RendezvousState() - self._settings = settings - self._cache_duration = cache_duration - self._token = None - self._dirty = False - self._last_sync_time = -1 - self._dead_nodes = [] - - def _record(self, message: str, node_state: NodeState = NodeState.RUNNING): - construct_and_record_rdzv_event( - name=f"{self.__class__.__name__}.{get_method_name()}", - run_id=self._settings.run_id, - message=message, - node_state=node_state, - ) - - @property - def state(self) -> _RendezvousState: - """See base class.""" - return self._state - - def sync(self) -> Optional[bool]: - """See base class.""" - state_bits: Optional[bytes] = None - - token = None - - has_set: Optional[bool] - - if self._dirty: - has_set = False - - state_bits = pickle.dumps(self._state) - - set_response = self._backend.set_state(state_bits, self._token) - if set_response is not None: - state_bits, token, has_set = set_response - else: - has_set = None - - if self._cache_duration > 0: - # Avoid overloading the backend if we are asked to retrieve the - # state repeatedly. Try to serve the cached state. - if self._last_sync_time >= max( - time.monotonic() - self._cache_duration, 0 - ): - return None - - get_response = self._backend.get_state() - if get_response is not None: - state_bits, token = get_response - - if state_bits is not None: - try: - self._state = pickle.loads(state_bits) - except pickle.PickleError as exc: - raise RendezvousStateError( - "The rendezvous state is corrupt. See inner exception for details." - ) from exc - else: - self._state = _RendezvousState() - - if has_set and self._dead_nodes and logger.isEnabledFor(logging.DEBUG): - node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes) - - msg = ( - f"As part of the sync operation the node(s) {node_list} have been removed from the " - f"rendezvous '{self._settings.run_id}' since they had no heartbeat." - ) - self._record(message=msg) - logger.debug(msg) - - self._token = token - - self._dirty = False - - self._last_sync_time = time.monotonic() - - self._sanitize() - - return has_set - - def _sanitize(self) -> None: - state = self._state - - expire_time = datetime.now(timezone.utc) - ( - self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt - ) - - # Filter out the dead nodes. - self._dead_nodes = [ - node - for node, last_heartbeat in state.last_heartbeats.items() - if last_heartbeat < expire_time - ] - - participant_removed = False - - for dead_node in self._dead_nodes: - msg = f"Detected dead node '{dead_node}', removing it from the rendezvous" - logger.debug(msg) - del state.last_heartbeats[dead_node] - - try: - del state.participants[dead_node] - - participant_removed = True - except KeyError: - pass - - try: - state.wait_list.remove(dead_node) - except KeyError: - pass - - try: - state.redundancy_list.remove(dead_node) - except KeyError: - pass - - if participant_removed: - # Common epilogue shared with the _remove_from_participants() - # function of _DistributedRendezvousOpExecutor. - _remove_participant_epilogue(state, self._settings) - - def mark_dirty(self) -> None: - """See base class. 
- - If the local rendezvous state is dirty, the next sync call will try to - write the changes back to the backend. However this attempt might fail - if another node, which had the same state, also made changes and wrote - them before us. - """ - self._dirty = True - - -class _Action(Enum): - """Specifies the possible actions based on the state of the rendezvous.""" - - KEEP_ALIVE = 1 - ADD_TO_PARTICIPANTS = 2 - ADD_TO_WAIT_LIST = 3 - ADD_TO_REDUNDANCY_LIST = 4 - REMOVE_FROM_PARTICIPANTS = 5 - REMOVE_FROM_WAIT_LIST = 6 - REMOVE_FROM_REDUNDANCY_LIST = 7 - MARK_RENDEZVOUS_COMPLETE = 8 - MARK_RENDEZVOUS_CLOSED = 9 - SYNC = 10 - ERROR_CLOSED = 11 - ERROR_TIMEOUT = 12 - FINISH = 13 - - -class _RendezvousContext: - """Holds the context of the rendezvous. - - Attributes: - node: - The node descriptor associated with the current rendezvous handler - instance. - state: - The current state of the rendezvous. - settings: - The rendezvous settings. - """ - - node: _NodeDesc - state: _RendezvousState - settings: RendezvousSettings - - def __init__( - self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings - ) -> None: - self.node = node - self.state = state - self.settings = settings - - -class _RendezvousOpExecutor(ABC): - """Execute rendezvous operations.""" - - @abstractmethod - def run( - self, - state_handler: Callable[[_RendezvousContext, float], _Action], - deadline: float, - update_deadline: Optional[Callable[[timedelta], float]] = None, - ) -> None: - """Execute a rendezvous operation. - - An operation is run inside a state machine and is expected to transition - the rendezvous from one state to another. - - Args: - state_handler: - A callable that is expected to return the next state transition - action based on the current state of the rendezvous. - deadline: - The time, in seconds, at which the operation will be considered - timed-out. - update_deadline: - Function to generate a new operation deadline if the current - node may participate in the next rendezvous. - """ - - -class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor): - """Execute rendezvous operations using a shared state. - - Args: - node: - The node descriptor associated with the current rendezvous handler - instance. - state_holder: - The ``RendezvousStateHolder`` to use to sync the rendezvous state - with other nodes. - settings: - The rendezvous settings. - """ - - _node: _NodeDesc - _state: _RendezvousState - _state_holder: _RendezvousStateHolder - _settings: RendezvousSettings - - def __init__( - self, - node: _NodeDesc, - state_holder: _RendezvousStateHolder, - settings: RendezvousSettings, - ) -> None: - self._node = node - self._state_holder = state_holder - self._settings = settings - - def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None: - construct_and_record_rdzv_event( - name=f"{self.__class__.__name__}.{get_method_name()}", - run_id=self._settings.run_id, - message=message, - node_state=node_state, - hostname=self._node.addr, - pid=self._node.pid, - local_id=self._node.local_id, - ) - - def run( - self, - state_handler: Callable[[_RendezvousContext, float], _Action], - deadline: float, - update_deadline: Optional[Callable[[timedelta], float]] = None, - ) -> None: - """See base class.""" - action = None - while action != _Action.FINISH: - # Reads or writes the latest rendezvous state shared by all nodes in - # the rendezvous. Note that our local changes might get overridden - # by another node if that node synced its changes before us. 
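# (Editorial illustration, not in the original file.) An operation is just a
# callable mapping (_RendezvousContext, deadline) -> _Action, so a minimal
# hypothetical heartbeat-only operation could be written as:
#
#     class _HeartbeatOnlyOp:
#         def __call__(self, ctx, deadline):
#             if _should_keep_alive(ctx):
#                 return _Action.KEEP_ALIVE
#             return _Action.FINISH
#
# which is essentially _RendezvousKeepAliveOp further below, minus its
# deadline check.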
- has_set = self._state_holder.sync() - if has_set is not None: - if has_set: - msg = ( - f"The node '{self._node}' has successfully synced its local changes with " - f"other nodes in the rendezvous '{self._settings.run_id}'." - ) - else: - msg = ( - f"The node '{self._node}' has a stale state and failed to sync its local " - f"changes with other nodes in the rendezvous '{self._settings.run_id}'." - ) - - self._record(message=msg) - logger.debug(msg) - - self._state = self._state_holder.state - - ctx = _RendezvousContext(self._node, self._state, self._settings) - - # Determine the next action to take based on the current state of - # the rendezvous. - action = state_handler(ctx, deadline) - - if action == _Action.FINISH: - continue - - if action == _Action.ERROR_CLOSED: - raise RendezvousClosedError - - if action == _Action.ERROR_TIMEOUT: - raise RendezvousTimeoutError - - if action == _Action.SYNC: - # Delay the execution by one second to avoid overloading the - # backend if we are asked to poll for state changes. - _delay(seconds=1) - else: - if action == _Action.KEEP_ALIVE: - self._keep_alive() - elif action == _Action.ADD_TO_PARTICIPANTS: - self._add_to_participants() - elif action == _Action.ADD_TO_WAIT_LIST: - self._add_to_wait_list() - elif action == _Action.ADD_TO_REDUNDANCY_LIST: - self._add_to_redundancy_list() - elif action == _Action.REMOVE_FROM_PARTICIPANTS: - self._remove_from_participants() - elif action == _Action.REMOVE_FROM_WAIT_LIST: - self._remove_from_wait_list() - elif action == _Action.REMOVE_FROM_REDUNDANCY_LIST: - self._remove_from_redundancy_list() - # update deadline since the node may participate in rendezvous process - if update_deadline: - deadline = update_deadline(self._settings.timeout.join) - elif action == _Action.MARK_RENDEZVOUS_COMPLETE: - self._mark_rendezvous_complete() - elif action == _Action.MARK_RENDEZVOUS_CLOSED: - self._mark_rendezvous_closed() - - # Attempt to sync our changes back to other nodes. - self._state_holder.mark_dirty() - - def _keep_alive(self) -> None: - msg = ( - f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous " - f"'{self._settings.run_id}'. Pending sync." - ) - self._record(message=msg) - logger.debug(msg) - - self._state.last_heartbeats[self._node] = datetime.now(timezone.utc) - - def _add_to_participants(self) -> None: - msg = ( - f"The node '{self._node}' added itself to the participants of round " - f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." - ) - self._record(message=msg) - logger.debug(msg) - - state = self._state - - try: - state.wait_list.remove(self._node) - except KeyError: - pass - - # The ranks of the participants will be set once the rendezvous is - # complete. - state.participants[self._node] = 0 - - self._keep_alive() - - if len(state.participants) == self._settings.min_nodes: - state.deadline = ( - datetime.now(timezone.utc) + self._settings.timeout.last_call - ) - - if len(state.participants) == self._settings.max_nodes: - self._mark_rendezvous_complete() - - def _add_to_wait_list(self) -> None: - msg = ( - f"The node '{self._node}' added itself to the wait list of round " - f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." 
) - self._record(message=msg) - logger.debug(msg) - - if self._node in self._state.redundancy_list: - self._state.redundancy_list.remove(self._node) - self._state.wait_list.add(self._node) - - self._keep_alive() - - def _add_to_redundancy_list(self) -> None: - msg = ( - f"The node '{self._node}' added itself to the redundancy list of round " - f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." - ) - self._record(message=msg) - logger.debug(msg) - - self._state.redundancy_list.add(self._node) - - self._keep_alive() - - def _remove_from_participants(self) -> None: - msg = ( - f"The node '{self._node}' removed itself from the participants of round " - f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." - ) - self._record(message=msg) - logger.debug(msg) - - state = self._state - - del state.participants[self._node] - - del state.last_heartbeats[self._node] - - # Common epilogue shared with the _sanitize() function of - # _BackendRendezvousStateHolder. - _remove_participant_epilogue(state, self._settings) - - def _remove_from_wait_list(self) -> None: - msg = ( - f"The node '{self._node}' removed itself from the wait list of round " - f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." - ) - self._record(message=msg) - logger.debug(msg) - - self._state.wait_list.remove(self._node) - - del self._state.last_heartbeats[self._node] - - def _remove_from_redundancy_list(self) -> None: - msg = ( - f"The node '{self._node}' removed itself from the redundant list of round " - f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." - ) - self._record(message=msg) - logger.debug(msg) - - self._state.redundancy_list.remove(self._node) - - del self._state.last_heartbeats[self._node] - - def _mark_rendezvous_complete(self) -> None: - msg = ( - f"The node '{self._node}' marked round {self._state.round} of the rendezvous " - f"'{self._settings.run_id}' as complete. Pending sync." - ) - self._record(message=msg, node_state=NodeState.SUCCEEDED) - logger.debug(msg) - - state = self._state - - state.complete = True - state.deadline = None - - # Assign the ranks. - for rank, node in enumerate(sorted(state.participants)): - state.participants[node] = rank - - def _mark_rendezvous_closed(self) -> None: - msg = ( - f"The node '{self._node}' marked the rendezvous '{self._settings.run_id}' as closed. " - "Pending sync." - ) - self._record(message=msg, node_state=NodeState.SUCCEEDED) - logger.debug(msg) - - self._state.closed = True - - -def _should_keep_alive(ctx: _RendezvousContext) -> bool: - """Determine whether a keep-alive heartbeat should be sent.""" - try: - last_heartbeat = ctx.state.last_heartbeats[ctx.node] - except KeyError: - return False - - return ( - last_heartbeat <= datetime.now(timezone.utc) - ctx.settings.keep_alive_interval - ) - - -class _RendezvousExitOp: - """Represent a rendezvous exit operation.""" - - def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: - if ctx.node in ctx.state.participants: - if time.monotonic() > deadline: - return _Action.ERROR_TIMEOUT - return _Action.REMOVE_FROM_PARTICIPANTS - return _Action.FINISH - - -class _RendezvousJoinOp: - """Represent a rendezvous join operation.""" - - def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: - state = ctx.state - - # A closed rendezvous means that it no longer accepts new nodes.
if state.closed: - if ctx.node in state.redundancy_list: - msg = f"The rendezvous '{ctx.settings.run_id}' is closed, terminating pending rendezvous." - raise RendezvousGracefulExitError(msg) - return _Action.ERROR_CLOSED - - if ctx.node in state.redundancy_list: - msg = f"The node {ctx.node} is in the redundancy list" - logger.debug(msg) - # don't apply the timeout logic here, since we want to allow the node to rejoin - if len(state.participants) == ctx.settings.max_nodes: - if _should_keep_alive(ctx): - return _Action.KEEP_ALIVE - else: - return _Action.SYNC - else: - # transition to waiting state that will respect timeouts. - msg = f"The node {ctx.node} is removed from the redundancy list" - logger.debug(msg) - return _Action.REMOVE_FROM_REDUNDANCY_LIST - - is_participant = ctx.node in state.participants - - # If we are part of the rendezvous and it is already complete there is - # no further action to take. - if state.complete and is_participant: - return _Action.FINISH - - now = time.monotonic() - if now > deadline: - rollback_period = 5 # 5 seconds - - # If we still have time to rollback (a short period on top of the - # operation deadline), try to remove ourselves from the rendezvous. - # It is okay if we can't, though, as our keep-alive will eventually - # expire. - if now <= deadline + rollback_period: - # If we are part of the rendezvous, it means we couldn't find - # enough participants to complete it on time. - if is_participant: - return _Action.REMOVE_FROM_PARTICIPANTS - # If we are in the wait list, it means we couldn't wait till the - # next round of the rendezvous. - if ctx.node in state.wait_list: - return _Action.REMOVE_FROM_WAIT_LIST - return _Action.ERROR_TIMEOUT - - if state.complete: - # If we are here, it means we are not part of the rendezvous. In - # case the rendezvous has capacity for additional participants add - # ourselves to the wait list for the next round. - if len(state.participants) < ctx.settings.max_nodes: - if ctx.node not in state.wait_list: - return _Action.ADD_TO_WAIT_LIST - elif len(state.participants) >= ctx.settings.max_nodes: - if ( - ctx.node not in state.redundancy_list - and ctx.node not in state.wait_list - ): - return _Action.ADD_TO_REDUNDANCY_LIST - elif is_participant: - # If the rendezvous has enough participants, including us, - # check whether we have passed the rendezvous deadline. If yes, - # complete it. - if ( - len(state.participants) >= ctx.settings.min_nodes - and len(state.participants) <= ctx.settings.max_nodes - and state.deadline is not None - ): - if state.deadline < datetime.now(timezone.utc): - msg = ( - f"The node '{ctx.node}' is marking the rendezvous complete, " - f"quorum established within deadline" - ) - logger.debug(msg) - return _Action.MARK_RENDEZVOUS_COMPLETE - else: - msg = f"The node '{ctx.node}' can't complete rendezvous: deadline reached" - logger.debug(msg) - else: - msg = f"The node '{ctx.node}' can't complete rendezvous: not enough participants" - logger.debug(msg) - else: - # The rendezvous is not complete yet and we are not part of it. Try - # to join. - return _Action.ADD_TO_PARTICIPANTS - - if _should_keep_alive(ctx): - return _Action.KEEP_ALIVE - - # At this point either the rendezvous is not complete, but we are part - # of it, which means we have to wait for other participants to join; or - # the rendezvous is complete, but we are not part of it, which means we - # have to wait for the next round.
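# (Editorial gloss, derived from the branches above; not in the original
# file.) The join operation's decision table is roughly:
#
#   closed                          -> ERROR_CLOSED (graceful exit if redundant)
#   past deadline (+5s rollback)    -> REMOVE_FROM_PARTICIPANTS / _WAIT_LIST,
#                                      then ERROR_TIMEOUT
#   complete, not a participant     -> ADD_TO_WAIT_LIST or ADD_TO_REDUNDANCY_LIST
#   incomplete, not a participant   -> ADD_TO_PARTICIPANTS
#   participant, quorum + deadline  -> MARK_RENDEZVOUS_COMPLETE
#   heartbeat due                   -> KEEP_ALIVE
#   otherwise                       -> SYNC (poll again)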
- return _Action.SYNC - - -class _RendezvousCloseOp: - """Represent a rendezvous close operation.""" - - def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: - if ctx.state.closed: - return _Action.FINISH - if time.monotonic() > deadline: - return _Action.ERROR_TIMEOUT - return _Action.MARK_RENDEZVOUS_CLOSED - - -class _RendezvousKeepAliveOp: - """Represent a rendezvous keep-alive update operation.""" - - def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: - if _should_keep_alive(ctx): - if time.monotonic() > deadline: - return _Action.ERROR_TIMEOUT - return _Action.KEEP_ALIVE - return _Action.FINISH - - -class DynamicRendezvousHandler(RendezvousHandler): - """Represent a handler that sets up a rendezvous among a set of nodes.""" - - # Static - _node_desc_generator = _NodeDescGenerator() - - _this_node: _NodeDesc - _settings: RendezvousSettings - _backend_name: str - _store: Store - _state_holder: _RendezvousStateHolder - _op_executor: _RendezvousOpExecutor - _heartbeat_lock: threading.Lock - _keep_alive_timer: Optional[_PeriodicTimer] - - @classmethod - def from_backend( - cls, - run_id: str, - store: Store, - backend: RendezvousBackend, - min_nodes: int, - max_nodes: int, - local_addr: Optional[str] = None, - timeout: Optional[RendezvousTimeout] = None, - ): - """Create a new :py:class:`DynamicRendezvousHandler`. - - Args: - run_id: - The run id of the rendezvous. - store: - The C10d store to return as part of the rendezvous. - backend: - The backend to use to hold the rendezvous state. - min_nodes: - The minimum number of nodes to admit to the rendezvous. - max_nodes: - The maximum number of nodes to admit to the rendezvous. - local_addr: - The local node address. - timeout: - The timeout configuration of the rendezvous. - """ - # We associate each handler instance with a unique node descriptor. - node = cls._node_desc_generator.generate(local_addr) - - settings = RendezvousSettings( - run_id, - min_nodes, - max_nodes, - timeout or RendezvousTimeout(), - keep_alive_interval=timedelta(seconds=5), - keep_alive_max_attempt=3, - ) - - state_holder = _BackendRendezvousStateHolder(backend, settings) - - return cls(node, settings, backend.name, store, state_holder) - - def __init__( - self, - node: _NodeDesc, - settings: RendezvousSettings, - backend_name: str, - store: Store, - state_holder: _RendezvousStateHolder, - ) -> None: - if not settings.run_id: - raise ValueError("The run id must be a non-empty string.") - - if settings.min_nodes < 1: - raise ValueError( - f"The minimum number of nodes ({settings.min_nodes}) must be greater than zero." - ) - - if settings.max_nodes < settings.min_nodes: - raise ValueError( - f"The maximum number of nodes ({settings.max_nodes}) must be greater than or equal " - f"to the minimum number of nodes ({settings.min_nodes})." 
- ) - - self._this_node = node - - self._settings = settings - - self._backend_name = backend_name - - self._store = store - - self._state_holder = state_holder - - self._op_executor = _DistributedRendezvousOpExecutor( - self._this_node, self._state_holder, self._settings - ) - - self._heartbeat_lock = threading.Lock() - - self._keep_alive_timer = None - - # Cached shared store server reference - self._shared_tcp_store_server: Optional[dist.Store] = None - - self._bootstrap_store_info: Optional[RendezvousStoreInfo] = None - - def _record( - self, - message: str, - node_state: NodeState = NodeState.RUNNING, - rank: Optional[int] = None, - ) -> None: - construct_and_record_rdzv_event( - name=f"{self.__class__.__name__}.{get_method_name()}", - run_id=self._settings.run_id, - message=message, - node_state=node_state, - hostname=self._this_node.addr, - pid=self._this_node.pid, - local_id=self._this_node.local_id, - rank=rank, - ) - - def _create_tcp_store_server(self, master_addr, master_port) -> dist.TCPStore: - return dist.TCPStore( - host_name=master_addr, - port=master_port, - is_master=True, - multi_tenant=True, - ) - - @property - def settings(self) -> RendezvousSettings: - """Get the settings of the rendezvous.""" - return self._settings - - def get_backend(self) -> str: - """See base class.""" - return self._backend_name - - @property - def use_agent_store(self) -> bool: - """See base class.""" - return os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") != "1" - - def next_rendezvous(self) -> RendezvousInfo: - """See base class.""" - msg = ( - f"The node '{self._this_node}' attempts to join the next round of the rendezvous " - f"'{self._settings.run_id}'." - ) - self._record(message=msg) - logger.info(msg) - - try: - self._stop_heartbeats() - - # Delay the execution for a small random amount of time if this is our - # first run. This will slightly skew the rendezvous attempts across the - # nodes and reduce the load on the backend. - if self._state_holder.state.round == 0: - _delay(seconds=(0, 0.3)) - - exit_op = _RendezvousExitOp() - join_op = _RendezvousJoinOp() - - deadline = self._get_deadline(self._settings.timeout.join) - self._op_executor.run(exit_op, deadline) - self._op_executor.run(join_op, deadline, self._get_deadline) - - self._start_heartbeats() - - rank, world_size = self._get_world() - store = self._get_store() - - except Exception as e: - self._record( - message=f"{type(e).__name__}: {str(e)}", - node_state=NodeState.FAILED, - ) - raise - - msg = ( - f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of " - f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size " - f"{world_size}." - ) - self._record(message=msg, rank=rank) - logger.info(msg) - - # opt-out option of TCPStore sharing - if os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") == "1": - bootstrap_store_info = RendezvousStoreInfo.build( - rank, store, local_addr=self._this_node.addr - ) - return RendezvousInfo( - store, - rank, - world_size, - bootstrap_store_info, - ) - - # This will only be hit when TCPStore sharing is enabled. - if self._bootstrap_store_info is None: - # To avoid race in get_free_port because we release the port after the call, - # we want to create a TCPStore server soon afterwards. 
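# (Editorial note, not in the original file.) Only rank 0 actually binds the
# shared TCPStore server here; RendezvousStoreInfo.build then publishes the
# chosen address and port through the rendezvous store so that the remaining
# ranks can discover them.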
- server_port = 0 - if rank == 0: - self._shared_tcp_store_server = self._create_tcp_store_server( - self._this_node.addr, server_port - ) - server_port = self._shared_tcp_store_server.port - self._bootstrap_store_info = RendezvousStoreInfo.build( - rank, - store, - local_addr=self._this_node.addr, - server_port=server_port, # For non-0 rank, this is a no-op - ) - - assert self._bootstrap_store_info is not None - if rank == 0: - assert self._shared_tcp_store_server is not None - - return RendezvousInfo( - store, - rank, - world_size, - self._bootstrap_store_info, # type: ignore[assignment] - ) - - def is_closed(self) -> bool: - """See base class.""" - try: - with self._heartbeat_lock: - self._state_holder.sync() - - return self._state_holder.state.closed - - except Exception as e: - self._record( - message=f"{type(e).__name__}: {str(e)}", - node_state=NodeState.FAILED, - ) - raise - - def set_closed(self) -> None: - """See base class.""" - try: - with self._heartbeat_lock: - self._close() - except Exception as e: - self._record( - message=f"{type(e).__name__}: {str(e)}", - node_state=NodeState.FAILED, - ) - raise - - def num_nodes_waiting(self) -> int: - """See base class.""" - try: - with self._heartbeat_lock: - self._state_holder.sync() - - return len(self._state_holder.state.wait_list) - - except Exception as e: - self._record( - message=f"{type(e).__name__}: {str(e)}", - node_state=NodeState.FAILED, - ) - raise - - def get_run_id(self) -> str: - """See base class.""" - return self._settings.run_id - - def shutdown(self) -> bool: - """See base class.""" - self._stop_heartbeats() - - try: - self._close() - - return True - except RendezvousError as ex: - msg = ( - f"The node '{self._this_node}' has failed to shutdown the rendezvous " - f"'{self._settings.run_id}' due to an error of type {type(ex).__name__}." - ) - self._record(message=msg, node_state=NodeState.FAILED) - logger.warning(msg) - - return False - except Exception as e: - self._record( - message=f"{type(e).__name__}: {str(e)}", - node_state=NodeState.FAILED, - ) - raise - - def _close(self) -> None: - op = _RendezvousCloseOp() - - deadline = self._get_deadline(self._settings.timeout.close) - - self._op_executor.run(op, deadline) - - msg = f"The node '{self._this_node}' has closed the rendezvous '{self._settings.run_id}'." - self._record(message=msg, node_state=NodeState.SUCCEEDED) - logger.info(msg) - - @staticmethod - def _keep_alive_weak(weak_self) -> None: - self = weak_self() - if self is not None: - self._keep_alive() - - def _keep_alive(self) -> None: - self._heartbeat_lock.acquire() - - op = _RendezvousKeepAliveOp() - - deadline = self._get_deadline(self._settings.timeout.heartbeat) - - try: - self._op_executor.run(op, deadline) - - msg = ( - f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous " - f"'{self._settings.run_id}'." - ) - self._record(message=msg) - logger.debug(msg) - except RendezvousError as ex: - msg = ( - f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the " - f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}." 
- ) - self._record(message=msg, node_state=NodeState.FAILED) - logger.warning(msg) - finally: - self._heartbeat_lock.release() - - def _start_heartbeats(self) -> None: - self._keep_alive_timer = _PeriodicTimer( - self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self) - ) - - self._keep_alive_timer.set_name( - f"RendezvousKeepAliveTimer_{self._this_node.local_id}" - ) - - self._keep_alive_timer.start() - - def _stop_heartbeats(self) -> None: - if self._keep_alive_timer is None: - return - - self._keep_alive_timer.cancel() - - def _get_world(self) -> Tuple[int, int]: - state = self._state_holder.state - - return state.participants[self._this_node], len(state.participants) - - def _wrap_store(self, store: Store) -> Store: - key_prefix = ( - f"mindtorch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}" - ) - - return dist.PrefixStore(key_prefix, store) - - def _get_store(self) -> Store: - return self._wrap_store(self._store) - - def _get_deadline(self, timeout: timedelta) -> float: - return time.monotonic() + timeout.total_seconds() - - -def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]: - timeout = params.get_as_int(key + "_timeout") - if timeout is None: - return None - return timedelta(seconds=timeout) - - -def create_handler( - store: Store, backend: RendezvousBackend, params: RendezvousParameters -) -> DynamicRendezvousHandler: - """Create a new :py:class:`DynamicRendezvousHandler` from the specified parameters. - - Args: - store: - The C10d store to return as part of the rendezvous. - backend: - The backend to use to hold the rendezvous state. - - +-------------------+------------------------------------------------------+ - | Parameter | Description | - +===================+======================================================+ - | join_timeout | The total time, in seconds, within which the | - | | rendezvous is expected to complete. Defaults to 600 | - | | seconds. | - +-------------------+------------------------------------------------------+ - | last_call_timeout | An additional wait amount, in seconds, before | - | | completing the rendezvous once the minimum number of | - | | nodes has been reached. Defaults to 30 seconds. | - +-------------------+------------------------------------------------------+ - | close_timeout | The time, in seconds, within which the rendezvous is | - | | expected to close after a call to | - | | :py:meth:`RendezvousHandler.set_closed` or | - | | :py:meth:`RendezvousHandler.shutdown`. Defaults to | - | | 30 seconds. | - +-------------------+------------------------------------------------------+ - """ - try: - timeout = RendezvousTimeout( - _get_timeout(params, "join"), - _get_timeout(params, "last_call"), - _get_timeout(params, "close"), - ) - - return DynamicRendezvousHandler.from_backend( - params.run_id, - store, - backend, - params.min_nodes, - params.max_nodes, - params.local_addr, - timeout, - ) - except Exception as e: - construct_and_record_rdzv_event( - message=f"{type(e).__name__}: {str(e)}", - run_id=params.run_id, - node_state=NodeState.FAILED, - ) - raise diff --git a/mindtorch/distributed/elastic/rendezvous/etcd_rendezvous.py b/mindtorch/distributed/elastic/rendezvous/etcd_rendezvous.py deleted file mode 100644 index 131cf832a..000000000 --- a/mindtorch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ /dev/null @@ -1,1077 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import json -import logging -import sys -import threading -import time -from typing import Optional - -import etcd # type: ignore[import] - -from mindtorch.distributed.elastic.rendezvous import ( - RendezvousClosedError, - RendezvousError, - RendezvousHandler, - RendezvousInfo, - RendezvousParameters, - RendezvousStoreInfo, - RendezvousTimeoutError, -) - -from .etcd_store import cas_delay, EtcdStore -from .utils import parse_rendezvous_endpoint - - -__all__ = [ - "EtcdRendezvousRetryableFailure", - "EtcdRendezvousRetryImmediately", - "EtcdRendezvousHandler", - "EtcdRendezvous", - "create_rdzv_handler", -] - -_log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s") -_log_handler = logging.StreamHandler(sys.stderr) -_log_handler.setFormatter(_log_fmt) - -logger = logging.getLogger(__name__) -logger.propagate = False -logger.setLevel(logging.INFO) -logger.addHandler(_log_handler) - - -# Retryable failure exception means that we were too late to make -# a desired state transition (e.g. because of a race condition), -# and should now restart from the beginning. -# A small delay is recommended to avoid spamming Etcd. -class EtcdRendezvousRetryableFailure(Exception): - pass - - -# Similar to retryable failure, but the new state we observed suggests we -# can re-try immediately, i.e. without a need for "safety delay". -class EtcdRendezvousRetryImmediately(Exception): - pass - - -# Default timeout for the rendezvous. -_DEFAULT_TIMEOUT: int = 600 # 10 minutes - -# Additional waiting time after reaching the minimum number of nodes -# in case the rendezvous is elastic (min != max). -_DEFAULT_LAST_CALL_TIMEOUT: int = 30 # 30 seconds - -# Various constants used internally in EtcdRendezvous -CONST_ETCD_SETUP_TTL = 5 -CONST_ETCD_FROZEN_TTL = 10 -CONST_ETCD_JOINABLE_EPHEMERAL_TTL = 10 - -# Ephemeral node TTL for worker's keep-alive key: -CONST_WORKER_KEEPALIVE_TTL = 10 - -# TTL for the ephemeral run_id-specific directory. All rendezvous state data -# for a specific run_id (job instance) is contained within this directory. -# Its only role is to clean up rendezvous data from old runs (for the case when -# etcd server is persistent), and has no effect on correctness, but should be -# larger than any timeouts that a worker process is expected to survive: -CONST_RUNID_SUBROOT_TTL = 7200 # 2 hours - - -class EtcdRendezvousHandler(RendezvousHandler): - """ - Implements a - :py:class:`mindtorch.distributed.elastic.rendezvous.RendezvousHandler` interface - backed by - :py:class:`mindtorch.distributed.elastic.rendezvous.etcd_rendezvous.EtcdRendezvous`. - ``EtcdRendezvousHandler`` uses a URL to configure the type of rendezvous to - use and to pass implementation-specific configurations to the rendezvous - module. The basic etcd rendezvous configuration URL looks like the following - :: - - etcd://<etcd_address>:<port>/<job_id>?min_workers=<min_workers>&max_workers=<max_workers> # noqa: W605 - - -- example -- - - etcd://localhost:2379/1234?min_workers=1&max_workers=3 - - The URL above is interpreted as follows: - - 1. Use the rendezvous handler that is registered with the ``etcd`` - scheme - 2. The ``etcd`` endpoint to use is ``localhost:2379`` - 3. ``job_id == 1234`` is used as the prefix in etcd (this allows one to - share a common etcd server for multiple jobs so long as the - ``job_ids`` are guaranteed to be unique). Note that the job id can be - any string (e.g. does not need to be a number) as long as it is - unique. - 4.
``min_workers=1`` and ``max_workers=3`` specify a range for - membership size - Torch Distributed Elastic starts running the job as - long as the cluster size is greater than or equal to ``min_workers`` - and admits up to ``max_workers`` into the cluster. - - Below is a full list of the parameters that can be passed to etcd - rendezvous: - - +--------------------------------------------+--------------------------+ - | Parameter | Description | - +============================================+==========================+ - | min_workers | minimum number of | - | | workers for the | - | | rendezvous to be valid | - +--------------------------------------------+--------------------------+ - | max_workers | maximum number of | - | | workers to admit | - +--------------------------------------------+--------------------------+ - | timeout | total timeout within | - | | which next_rendezvous is | - | | expected to succeed | - | | (default 600s) | - +--------------------------------------------+--------------------------+ - | last_call_timeout | additional wait amount | - | | ("last call") after min | - | | number of workers has | - | | been reached (defaults | - | | to 30s) | - +--------------------------------------------+--------------------------+ - | etcd_prefix | path prefix (from etcd | - | | root), inside which all | - | | etcd nodes will be | - | | created (defaults to | - | | ``/torchelastic/p2p``) | - +--------------------------------------------+--------------------------+ - """ - - def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]): - """ - Args: - rdzv_impl: the implementation of the rendezvous - local_addr: the local address of the current node - """ - - self._rdzv_impl = rdzv_impl - self._local_addr = local_addr - - def __del__(self): - # TODO: look into using weakref here instead. - del self._rdzv_impl - - def get_backend(self) -> str: - return "etcd" - - def next_rendezvous(self): - rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier() - - logger.info("Creating EtcdStore as the c10d::Store implementation") - store = self._rdzv_impl.setup_kv_store(rdzv_version) - - bootstrap_store_info = RendezvousStoreInfo.build( - rank, store, local_addr=self._local_addr - ) - return RendezvousInfo(store, rank, world_size, bootstrap_store_info) - - def is_closed(self): - try: - _, state = self._rdzv_impl.get_rdzv_state() - return state["status"] == "closed" - except etcd.EtcdKeyNotFound: - # No rendezvous state, so it cannot be closed. - return False - - def set_closed(self): - self._rdzv_impl.set_closed() - - def num_nodes_waiting(self): - try: - _, state = self._rdzv_impl.get_rdzv_state() - if state["status"] == "final": - return state["num_workers_waiting"] - except etcd.EtcdKeyNotFound: - pass - return 0 - - def get_run_id(self) -> str: - return self._rdzv_impl._run_id - - def shutdown(self) -> bool: - try: - self.set_closed() - return True - except BaseException as e: - logger.warning("Shutdown failed. Error occurred: %s", str(e)) - return False - - -# TODO: we should probably handle a few additional errors, -# like EtcdLeaderElectionInProgress and EtcdWatcherCleared. These are -# only relevant for a multi-node Etcd ensemble. A simple retry would work, -# but is verbose to add everywhere. Consider wrapping the client calls -# into auto-retry for these errors?
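As a concrete reading of that TODO, such a guard could look roughly like the following (an editorial sketch, not part of the original file; the helper name and retry policy are invented)::

    import functools
    import time

    import etcd  # type: ignore[import]

    def _with_etcd_retries(fn, attempts=3, delay=1.0):
        # Retry the transient ensemble errors named in the TODO above.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return fn(*args, **kwargs)
                except (etcd.EtcdLeaderElectionInProgress, etcd.EtcdWatcherCleared):
                    if attempt == attempts - 1:
                        raise
                    time.sleep(delay)
        return wrapper

    # e.g. client.read = _with_etcd_retries(client.read)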
-# -class EtcdRendezvous: - """A rendezvous implementation that uses `etcd <https://etcd.io/>`__ as the backend store.""" - - def __init__( - self, - client, - prefix, - run_id, - num_min_workers, - num_max_workers, - timeout, - last_call_timeout, - ): - self.client = client - logger.info("Etcd machines: %s", self.client.machines) - - self._prefix = prefix - self._run_id = run_id - self._num_min_workers = num_min_workers - self._num_max_workers = num_max_workers - self._timeout = timeout - self._last_call_timeout = last_call_timeout - - # For cleaning up TTL refresher threads (for ephemeral keys) - self._lease_run_id_stop = None - self._lease_this_rank_stop = None - - if not self._prefix.endswith("/"): - self._prefix += "/" - - # Set up a permanent prefix dir, if it didn't exist - if self._prefix != "/": - self.create_path_if_not_exists(self._prefix) - - # Lease a "sub-root" node specific to this job instance (run_id) - self.create_path_if_not_exists(self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL) - self._lease_run_id_stop = self.setup_lease_renewal( - self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL - ) - - # Subdir for all rendezvous work - self.create_path_if_not_exists(self.get_path("/rdzv")) - - # Create a rendezvous version counter, if it doesn't exist - try: - self.client.write( - key=self.get_path("/rdzv/version_counter"), value="0", prevExist=False - ) - except etcd.EtcdAlreadyExist: - pass - - def __del__(self): - # TODO: look into using weakref here instead. - if self._lease_run_id_stop is not None: - self._lease_run_id_stop.set() - - if self._lease_this_rank_stop is not None: - self._lease_this_rank_stop.set() - - def rendezvous_barrier(self): - """ - Main entry point for next rendezvous. - - This method is blocking until rendezvous succeeds or a timeout occurs. - - Returns: - ``(rdzv_version, rank, world_size)`` - - Raises: - RendezvousTimeoutError - timeout waiting for rendezvous - RendezvousClosedError - rendezvous is or was closed while waiting - RendezvousError - other persistent errors that - render the rendezvous non-retryable - """ - self._rendezvous_deadline = time.time() + self._timeout - while True: - if time.time() > self._rendezvous_deadline: - raise RendezvousTimeoutError - - logger.info("Attempting to join next rendezvous") - try: - # Dis-own our lease in the previous rendezvous, if it exists - if self._lease_this_rank_stop is not None: - self._lease_this_rank_stop.set() - - return self.init_phase() - - except EtcdRendezvousRetryImmediately: - # The type of failure suggests we can retry without delay - pass - - except EtcdRendezvousRetryableFailure: - # In case of retryable failure, wait a small delay - # to avoid spamming etcd - time.sleep(1) - - except RendezvousTimeoutError: - logger.info("Rendezvous timeout occurred in EtcdRendezvousHandler") - raise - - except RendezvousClosedError: - logger.info( - "Rendezvous for run_id=%s was observed to be closed", self._run_id - ) - raise - - except RendezvousError: - raise - - except Exception as e: - # In case of a general exception, wait a small delay - # to avoid spamming etcd - # FIXME: there are a few things that fall under this like - # etcd.EtcdKeyNotFound, etc., which could be handled more explicitly. - logger.info("Rendezvous attempt failed, will retry. Reason: %s", e) - time.sleep(1) - - def init_phase(self): - """ - Initially, the rendezvous state is expected to be one of: - - 1. empty (non-existent) - in this case we try to create a new one. - 2. joinable - we try to join it. - 3.
final - we announce ourselves as waiting, and go into monitoring mode - - Any other state is considered transitional, and will be retried after - a short delay. - - Returns: - ``(rdzv_version, rank, world_size)`` - - Raises: - RendezvousClosedError - current rendezvous was/is closed - EtcdRendezvousRetryableFailure - observed some intermediate - state, which is best handled by retrying later - """ - try: - active_version = self.try_create_rendezvous() - state = json.loads(active_version.value) - logger.info("New rendezvous state created: %s", state) - except etcd.EtcdAlreadyExist: - active_version, state = self.get_rdzv_state() - # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound), - # but this is ok for us - just means we'll restart from beginning. - logger.info("Observed existing rendezvous state: %s", state) - - if state["status"] == "closed": - raise RendezvousClosedError - - if state["status"] == "joinable": - return self.join_phase(state["version"]) - - if state["status"] == "final": - self.handle_existing_rendezvous(state["version"]) - raise EtcdRendezvousRetryImmediately - - self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1) - raise EtcdRendezvousRetryableFailure - - def join_phase(self, expected_version): - """ - We observed a rendezvous state in 'joinable' state, and attempt to join this - particular version, and then wait for all other peers to join. - """ - # Failure to join will propagate an exception, causing a re-entry. - active_version, this_rank = self.join_rendezvous(expected_version) - state = json.loads(active_version.value) - logger.info( - "Joined rendezvous version %s as rank %s. Full state: %s", - state["version"], - this_rank, - state, - ) - - # If this worker was first to reach num_min_workers requirement, - # and rendezvous is still joinable (therefore it is elastic), - # then this worker will be responsible for waiting out the "last call" - # timeout and closing (i.e. transitioning to 'frozen') the rendezvous - # afterwards. - # As a safety against a potential failure of this worker (during the - # last call timeout), the rendezvous state is made ephemeral - # when min_num_workers is reached. - - if this_rank == self._num_min_workers - 1 and state["status"] == "joinable": - logger.info("Rank %s is responsible for join last call.", this_rank) - last_call_deadline = time.time() + self._last_call_timeout - self.handle_join_last_call(expected_version, last_call_deadline) - logger.info("Rank %s finished join last call.", this_rank) - - # Wait for rendezvous state to be frozen, which means a fixed set of peers - logger.info("Waiting for remaining peers.") - active_version = self.wait_for_peers(expected_version) - state = json.loads(active_version.value) - - assert ( - state["version"] == expected_version - ), "Logic error: failed to observe version mismatch" - - return self.confirm_phase(expected_version, this_rank) - - def confirm_phase(self, expected_version, this_rank): - """ - Once the rendezvous state transitions from 'joinable' to 'frozen', - we have every participant confirm their membership and setup per-member - keep-alive TTL keys, and then wait for all other participants to confirm, - which would then successfully conclude this rendezvous. - """ - logger.info("All peers arrived. 
Confirming membership.") - self.confirm_membership(expected_version, this_rank) - - logger.info("Waiting for confirmations from all peers.") - active_version = self.wait_for_final(expected_version) - state = json.loads(active_version.value) - - logger.info( - "Rendezvous version %s is complete. Final state: %s", - state["version"], - state, - ) - - # Rendezvous version number; our rank in it; world size - return state["version"], this_rank, len(state["participants"]) - - def handle_existing_rendezvous(self, expected_version): - """ - Handle the case when there's an existing (state 'final') rendezvous already - in place, and we have to announce ourselves waiting, and wait until - the next rendezvous opportunity. - """ - # If state is 'final' -> increment num_workers_waiting - # Then, observe state changes: - # 1. if it's no longer final -> bail out and re-try - # 2. if keep-alives are missing, destroy it and bail out. - active_state = self.announce_self_waiting(expected_version) - logger.info( - "Added self to waiting list. Rendezvous full state: %s", active_state.value - ) - - self.wait_for_rendezvous_to_free(expected_version) - logger.info( - "Previously existing rendezvous state changed. Will re-try joining." - ) - - def try_create_rendezvous(self): - """ - Create a new rendezvous state or raise an exception that indicates an unexpected state (e.g. already exists). - - Raises: - RendezvousError - on unexpected state - """ - # Initially active_version is ephemeral - this is to handle the - # possibility that we might fail to complete the setup transaction, - # i.e. the transition "setup" -> "joinable". - active_version = self.client.write( - key=self.get_path("/rdzv/active_version"), - value=json.dumps({"status": "setup"}), - prevExist=False, - ttl=CONST_ETCD_SETUP_TTL, - ) - - try: - version_counter = self.client.get(self.get_path("/rdzv/version_counter")) - version_counter.value = str(int(version_counter.value) + 1) - self.client.update(version_counter) - except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e: - raise RendezvousError( - "Unexpected state of EtcdRendezvousHandler, worker needs to die." - ) from e - - # Any failure below results in declaring a retryable rendezvous failure. - # The ephemeral /rdzv/active_version will expire and someone can then - # re-try the setup process. - - # Create directory node for participant data - self.client.write( - key=self.get_path(f"/rdzv/v_{version_counter.value}"), - value=None, - dir=True, - prevExist=False, - ) - - # Publish rendezvous version and signal it is ready-to-be-joined. - # If rendezvous was set closed just before this, a retry will happen, - # where the closed condition will be handled. - return self.client.test_and_set( - key=self.get_path("/rdzv/active_version"), - value=json.dumps( - { - "status": "joinable", - "version": version_counter.value, - "participants": [], - } - ), - prev_value=active_version.value, - ) - - def join_rendezvous(self, expected_version): - """Helper method for the join phase.""" - # Use compare-and-swap to add self to rendezvous state: - while True: - cas_delay() - active_version, state = self.get_rdzv_state() - - if state["status"] != "joinable": - raise EtcdRendezvousRetryableFailure( - "Rendezvous state became non-joinable before we could join. " - "Must join next one." - ) - - if state["version"] != expected_version: - raise EtcdRendezvousRetryImmediately( - "Rendezvous version changed. Must try join the new one."
- ) - - assert ( - len(state["participants"]) < self._num_max_workers - ), "Logic error: joinable rendezvous should always have space left" - - this_rank = len(state["participants"]) - state["participants"].append(this_rank) - - # When reaching min workers, or changing state to frozen, we'll set - # the active_version node to be ephemeral. - set_ttl: Optional[int] = None - if len(state["participants"]) == self._num_max_workers: - state["status"] = "frozen" - state["keep_alives"] = [] - set_ttl = CONST_ETCD_FROZEN_TTL - elif len(state["participants"]) >= self._num_min_workers: - set_ttl = CONST_ETCD_JOINABLE_EPHEMERAL_TTL - - try: - # Compare-and-swap. - active_version = self.client.test_and_set( - key=self.get_path("/rdzv/active_version"), - value=json.dumps(state), - prev_value=active_version.value, - ttl=set_ttl, - ) - # We succeeded in joining. - return active_version, this_rank - - except etcd.EtcdCompareFailed: - logger.info("Join rendezvous CAS unsuccessful, retrying") - - def wait_for_peers(self, expected_version): - """Helper method for the join phase.""" - active_version, state = self.get_rdzv_state() - while True: - if state["status"] == "frozen" and state["version"] == expected_version: - # Success, all peers arrived. - return active_version - - elif state["status"] == "joinable" and state["version"] == expected_version: - # Continue waiting for any interesting events. - active_version, state = self.try_wait_for_state_change( - etcd_index=active_version.etcd_index + 1 - ) - - else: - # No valid transition possible at this point - raise EtcdRendezvousRetryableFailure( - "Rendezvous state transition no longer possible. Must re-enter." - ) - - def confirm_membership(self, expected_version, this_rank): - """Helper method for the confirm phase.""" - # Compare-and-swap loop - while True: - cas_delay() - active_version, state = self.get_rdzv_state() - - if state["status"] != "frozen": - raise EtcdRendezvousRetryImmediately( - "Rendezvous no longer frozen, before we confirmed. " - "Must join the next one." - ) - if state["version"] != expected_version: - raise EtcdRendezvousRetryImmediately( - "Rendezvous version changed. Must try to join the new one." - ) - - this_lease_key = self.get_path( - f"/rdzv/v_{expected_version}/rank_{this_rank}" - ) - self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL) - - state["keep_alives"].append(this_lease_key) - if len(state["keep_alives"]) == len(state["participants"]): - # Everyone confirmed (this rank is last to do so) - state["status"] = "final" - state["num_workers_waiting"] = 0 - finalize = True - else: - finalize = False - - try: - # Compare-and-swap. If new state is still frozen, keep it ephemeral. - active_version = self.client.test_and_set( - key=self.get_path("/rdzv/active_version"), - value=json.dumps(state), - prev_value=active_version.value, - ttl=None if finalize else CONST_ETCD_FROZEN_TTL, - ) - - self._lease_this_rank_stop = self.setup_lease_renewal( - this_lease_key, ttl=CONST_WORKER_KEEPALIVE_TTL - ) - return active_version - - except etcd.EtcdCompareFailed: - logger.info("Confirm membership CAS unsuccessful, retrying") - - def wait_for_final(self, expected_version): - """Helper method for the confirm phase.""" - active_version, state = self.get_rdzv_state() - while True: - if state["status"] == "final" and state["version"] == expected_version: - # Success. This rendezvous is final, and we accept it. 
- return active_version - - elif state["status"] == "frozen" and state["version"] == expected_version: - # Continue waiting for any interesting events. - active_version, state = self.try_wait_for_state_change( - etcd_index=active_version.etcd_index + 1 - ) - - else: - # No valid transition possible at this point - raise EtcdRendezvousRetryableFailure( - "Rendezvous state transition no longer possible. Must re-enter." - ) - - def announce_self_waiting(self, expected_version): - """ - Announce this worker is waiting (via num_workers_waiting counter) to join the - next rendezvous, but only if state and version match. - """ - while True: - cas_delay() - active_version, state = self.get_rdzv_state() - - if state["status"] != "final" or state["version"] != expected_version: - raise EtcdRendezvousRetryImmediately - - # Increment counter to signal an additional waiting worker. - state["num_workers_waiting"] += 1 - - try: - active_version = self.client.test_and_set( - key=self.get_path("/rdzv/active_version"), - value=json.dumps(state), - prev_value=active_version.value, - ) - return active_version - - except etcd.EtcdCompareFailed: - logger.info("Announce self as waiting CAS unsuccessful, retrying") - - def wait_for_rendezvous_to_free(self, expected_version): - """ - When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join. - - Such an opportunity may come from: - - 1. rendezvous state changed by someone else, in which case we unblock and retry. - 2. rendezvous becomes invalid because at least one member failed to renew their - leased keep_alive node. We detect this, and destroy the rendezvous. - """ - active_version, state = self.get_rdzv_state() - while True: - if state["status"] != "final" or state["version"] != expected_version: - return - - # Check if current rendezvous state is valid, in the sense that all - # its members are alive (renewing their lease). - # If not, try to destroy this rendezvous, so a new one can be created. - alive_members = self.client.get( - self.get_path(f"/rdzv/v_{expected_version}") - ) - keep_alive_keys = [ch.key for ch in alive_members.children] - - for key in state["keep_alives"]: - if key not in keep_alive_keys: - # This participant didn't renew their lease. We'll declare this - # rendezvous version as dead (but only if it hasn't changed) - logger.info("Keep-alive key %s is not renewed.", key) - logger.info( - "Rendezvous version %s is incomplete. ", expected_version - ) - logger.info("Attempting to destroy it.") - - # Compare-and-delete operation. Throws if compare failed, - # which means rendezvous was already destroyed/re-created/closed, - # and we can try to re-enter the barrier. - self.client.delete( - key=self.get_path("/rdzv/active_version"), - prevValue=active_version.value, - ) - - logger.info( - "Destroyed rendezvous version %s successfully.", - expected_version, - ) - - # We can return (and retry) immediately - return - - # Existing rendezvous seems valid, no reason to destroy it. - # We just have to wait until something changes and re-check. 
- try: - overall_timeout = ( - max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 - ) - self.client.watch( - key=self.get_path("/rdzv"), - index=active_version.etcd_index + 1, - recursive=True, - timeout=overall_timeout, - ) - except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): - pass - - if time.time() > self._rendezvous_deadline: - raise RendezvousTimeoutError - active_version, state = self.get_rdzv_state() - - def handle_join_last_call(self, expected_version, deadline): - """ - After we reach the min number of workers, one particular worker takes on the - responsibility of waiting an additional timeout before closing the join window. - If the worker responsible for this fails, the rendezvous will be destroyed due - to expiring TTL, and the other participants will re-rendezvous. - - Here we expect to see state <joinable, expected_version> - Exit gracefully if either: - - 1. state becomes <frozen, expected_version> - 2. timeout happens (reaching deadline), in which case - we try the transition to <frozen, expected_version> - - Exit with exception otherwise. - """ - active_version, state = self.get_rdzv_state() - while True: - if state["status"] == "frozen" and state["version"] == expected_version: - # Worker set became frozen before last-call timeout. This is possible - # when num_max_workers is reached before the timeout. - return - - if state["status"] != "joinable" or state["version"] != expected_version: - raise EtcdRendezvousRetryableFailure( - "Rendezvous state transition no longer possible. Must re-enter." - ) - - # If timeout occurred, attempt a state transition (joinable -> frozen) - if time.time() >= deadline: - state["status"] = "frozen" - state["keep_alives"] = [] - try: - active_version = self.client.test_and_set( - key=self.get_path("/rdzv/active_version"), - value=json.dumps(state), - prev_value=active_version.value, - ttl=CONST_ETCD_FROZEN_TTL, - ) - # We successfully made this rendezvous frozen. - return - except etcd.EtcdCompareFailed: - logger.info( - "Join last-call transition CAS unsuccessful. Will retry" - ) - cas_delay() - active_version, state = self.get_rdzv_state() - continue - - # Timeout did not occur, so we must refresh TTL, and wait for - # further changes. Note: we only want TTL to be refreshed if - # state is still joinable, hence we use CAS for that here, - # even though we don't change any of the data. - try: - active_version = self.client.test_and_set( - key=self.get_path("/rdzv/active_version"), - value=active_version.value, - prev_value=active_version.value, - ttl=CONST_ETCD_JOINABLE_EPHEMERAL_TTL, - ) - - # Minimize "oversleeping": - timeout = min( - CONST_ETCD_JOINABLE_EPHEMERAL_TTL / 2, - deadline - time.time() + 1.0, # Oversleeping by 1s is ok. - ) - active_version, state = self.try_wait_for_state_change( - etcd_index=active_version.etcd_index + 1, timeout=timeout - ) - except etcd.EtcdCompareFailed: - logger.info("Join last-call TTL refresh CAS unsuccessful, will retry") - cas_delay() - active_version, state = self.get_rdzv_state() - - def set_closed(self): - """ - Mark rendezvous 'closed' for the current run_id, which is used to signal other - participants to not attempt to perform (re-)rendezvous. This is useful - when one of the workers decides the job is complete. - """ - while True: - active_version, state = self.get_rdzv_state() - - if state["status"] == "closed": - # Already closed by someone else. 
- return - - state["status"] = "closed" - try: - self.client.test_and_set( - key=self.get_path("/rdzv/active_version"), - value=json.dumps(state), - prev_value=active_version.value, - ) - return - - except etcd.EtcdCompareFailed: - logger.info("Set closed CAS unsuccessful, retrying") - cas_delay() - - def get_rdzv_state(self): - active_version = self.client.get(key=self.get_path("/rdzv/active_version")) - return active_version, json.loads(active_version.value) - - def try_wait_for_state_change(self, etcd_index, timeout=None): - # Don't sleep past the overall deadline (by more than 1s) - overall_timeout = max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 - timeout = overall_timeout if timeout is None else min(timeout, overall_timeout) - - try: - self.client.watch( - self.get_path("/rdzv/active_version"), index=etcd_index, timeout=timeout - ) - except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): - pass - - if time.time() > self._rendezvous_deadline: - raise RendezvousTimeoutError - - # Unfortunately, we have to do another fetch in order to get the last etcd_index. - return self.get_rdzv_state() - - def get_path(self, path): - if not path.startswith("/"): - path = "/" + path - - return f"{self._prefix}run_{self._run_id}{path}" - - def create_path_if_not_exists(self, full_path, ttl=None): - try: - self.client.write( - key=full_path, value=None, dir=True, prevExist=False, ttl=ttl - ) - except etcd.EtcdAlreadyExist: - pass - - def setup_lease_renewal(self, full_path, ttl): - # NOTE: For ephemeral key TTL renewal (~lease) to work correctly, - # make sure you don't call any long-blocking methods that do not - # release Python's GIL! An example of this is calling a pybind11 - # extension function that is blocking / long-running, but is not - # doing a scoped release of the GIL. - def lease_worker(client, path, ttl, stop_event): - while True: - try: - client.refresh(path, ttl=ttl) - except etcd.EtcdKeyNotFound: - break - except ConnectionRefusedError: - # This error usually occurs during tests, when the server has already been - # terminated but the Python garbage collector has not yet invoked the __del__ method. - break - - if stop_event.wait(timeout=ttl / 2): - break - - lease_stop_event = threading.Event() - lease_thread = threading.Thread( - target=lease_worker, args=(self.client, full_path, ttl, lease_stop_event) - ) - - lease_thread.daemon = True - lease_thread.start() - - return lease_stop_event - - def store_extra_data(self, rdzv_version, key, value): - node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") - try: - # If this is the first time we are storing anything: - extra_data = self.client.write( - key=node, value=json.dumps({key: value}), prevExist=False - ) - return - except etcd.EtcdAlreadyExist: - pass - - # CAS loop, to make sure we don't lose concurrent stores. - while True: - # We never delete extra_data. Failure here should be fatal, no special handling. 
- extra_data = self.client.get(node) - - new_extra_data_value = json.loads(extra_data.value) - new_extra_data_value[key] = value - - try: - extra_data = self.client.test_and_set( - key=node, - value=json.dumps(new_extra_data_value), - prev_value=extra_data.value, - ) - return - except etcd.EtcdCompareFailed: - logger.info("Store extra_data CAS unsuccessful, retrying") - time.sleep(0.1) - - def load_extra_data(self, rdzv_version, key, timeout=None): - # 'extra_data' node itself, and the directory it is located in: - node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") - node_dir = self.get_path(f"/rdzv/v_{rdzv_version}") - - # TODO: implement timeout - # https://github.com/pytorch/elastic/issues/12 - while True: - # Combined wait for the node itself, and the key inside it. - root = self.client.get(node_dir) - - # Find the extra_data node, if it exists - extra_data = [n for n in root.children if n.key == node] - assert len(extra_data) <= 1 - - # Node for extra_data exists, check the desired key inside it. - if len(extra_data) == 1: - extra_data_dict = json.loads(extra_data[0].value) - if key in extra_data_dict: - return extra_data_dict[key] - - # The 'extra_data' node doesn't exist, or the key isn't published yet. - # Wait for interesting events on the extra_data node and retry. - try: - self.client.watch(node, index=root.etcd_index + 1) - except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): - pass - - def setup_kv_store(self, rdzv_version): - store_path = self.get_path(f"/rdzv/v_{rdzv_version}/kv") - self.create_path_if_not_exists(store_path) - return EtcdStore(etcd_client=self.client, etcd_store_prefix=store_path) - - -def _create_etcd_client(params: RendezvousParameters) -> etcd.Client: - """Create a new ``etcd.Client`` from the specified ``RendezvousParameters``.""" - hostname, port = parse_rendezvous_endpoint(params.endpoint, 2379) - - # The communication protocol - protocol = params.config.get("protocol") - if protocol is None: - protocol = "http" - else: - if protocol != "http" and protocol != "https": - raise ValueError("The etcd protocol must be HTTP or HTTPS.") - - # The SSL client certificate - ssl_cert = params.config.get("cert") - if ssl_cert is not None: - cert_key = params.config.get("key") - if cert_key is not None: - # The etcd client expects the certificate key as the second element - # of the `cert` tuple. 
- ssl_cert = (ssl_cert, cert_key) - - # The root certificate - ca_cert = params.config.get("cacert") - - return etcd.Client( - hostname, - port, - protocol=protocol, - cert=ssl_cert, - ca_cert=ca_cert, - allow_reconnect=True, - ) - - -# Handler for mindtorch.distributed "static" registration -def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: - """ - Usage: - - :: - - rdzv_params = RendezvousParameters( - backend="etcd", - endpoint="192.168.0.42:2379", - run_id="123", - min_nodes=4, - max_nodes=8, - timeout=300, - last_call_timeout=30, - etcd_prefix="custom_prefix", - protocol="https", - cacert="/etc/kubernetes/certs/ca.crt", - cert="/etc/kubernetes/certs/client.crt", - key="/etc/kubernetes/certs/client.key") - # -- or -- - rdzv_params = RendezvousParameters( - backend="etcd", - endpoint="192.168.0.42:2379", - run_id="123", - min_nodes=4, - max_nodes=8) - - etcd_rdzv_handler = create_rdzv_handler(rdzv_params) - - - Where: - run_id - unique id for this training job instance, - min_nodes - min number of workers expected to join the rendezvous, - max_nodes - max number of workers allowed to join the rendezvous, - defaults to min_nodes if not specified. - timeout - total timeout within which next_rendezvous is expected to - succeed; a RendezvousTimeoutError is raised otherwise; - Default is 600 (10 minutes). - last_call_timeout - additional wait amount ("last call") after - min number of workers has been reached. - Defaults to 30 seconds. - etcd_prefix - path prefix (from etcd root), inside which all - etcd nodes will be created. - Default is "/torchelastic/p2p". - protocol - http (default) or https to access etcd. - cacert - CA cert to access etcd, only makes sense with https. - cert - client cert to access etcd, only makes sense with https. - key - client key to access etcd, only makes sense with https. - """ - client = _create_etcd_client(params) - - etcd_prefix = params.get("etcd_prefix", "/torchelastic/p2p") - - rdzv = EtcdRendezvous( - client=client, - prefix=etcd_prefix, - run_id=params.run_id, - num_min_workers=params.min_nodes, - num_max_workers=params.max_nodes, - timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT), - last_call_timeout=params.get_as_int( - "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT - ), - ) - return EtcdRendezvousHandler( - rdzv_impl=rdzv, - local_addr=params.local_addr, - ) diff --git a/mindtorch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/mindtorch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py deleted file mode 100644 index 8366ecbfc..000000000 --- a/mindtorch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py +++ /dev/null @@ -1,217 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
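The file removed below implemented a RendezvousBackend on top of etcd's compare-and-swap primitives: get_state returns the decoded state together with a token (the etcd modifiedIndex of the state node), and set_state only succeeds if that token is still current, otherwise it hands back the fresher state. A sketch of how a caller might have driven it; the endpoint and run id are illustrative, not from this repository::

    params = RendezvousParameters(
        backend="etcd-v2",
        endpoint="localhost:2379",  # illustrative endpoint
        run_id="job-42",
        min_nodes=1,
        max_nodes=2,
    )
    backend, store = create_backend(params)

    result = backend.get_state()
    if result is None:
        # No state yet; token=None maps to a prevExist=False write.
        backend.set_state(b"initial-state", token=None)
    else:
        state, token = result
        # Compare-and-swap against the observed modifiedIndex; on a
        # conflict the returned triple carries the newer state and token.
        backend.set_state(state, token=token)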
- -import binascii -from base64 import b64decode, b64encode -from typing import cast, Optional, Tuple - -import urllib3.exceptions # type: ignore[import] -from etcd import ( # type: ignore[import] - Client as EtcdClient, - EtcdAlreadyExist, - EtcdCompareFailed, - EtcdException, - EtcdKeyNotFound, - EtcdResult, -) - -from mindtorch.distributed import Store - -from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError -from .dynamic_rendezvous import RendezvousBackend, Token -from .etcd_store import EtcdStore -from .utils import parse_rendezvous_endpoint - - -class EtcdRendezvousBackend(RendezvousBackend): - """Represents an etcd-based rendezvous backend. - - Args: - client: - The ``etcd.Client`` instance to use to communicate with etcd. - run_id: - The run id of the rendezvous. - key_prefix: - The path under which to store the rendezvous state in etcd. - ttl: - The TTL of the rendezvous state. If not specified, defaults to two hours. - """ - - _DEFAULT_TTL = 7200 # 2 hours - - _client: EtcdClient - _key: str - _ttl: int - - def __init__( - self, - client: EtcdClient, - run_id: str, - key_prefix: Optional[str] = None, - ttl: Optional[int] = None, - ) -> None: - if not run_id: - raise ValueError("The run id must be a non-empty string.") - - self._client = client - - if key_prefix: - self._key = key_prefix + "/" + run_id - else: - self._key = run_id - - if ttl and ttl > 0: - self._ttl = ttl - else: - self._ttl = self._DEFAULT_TTL - - @property - def name(self) -> str: - """See base class.""" - return "etcd-v2" - - def get_state(self) -> Optional[Tuple[bytes, Token]]: - """See base class.""" - try: - result = self._client.read(self._key) - except EtcdKeyNotFound: - return None - except (EtcdException, urllib3.exceptions.TimeoutError) as exc: - raise RendezvousConnectionError( - "The connection to etcd has failed. See inner exception for details." - ) from exc - - return self._decode_state(result) - - def set_state( - self, state: bytes, token: Optional[Token] = None - ) -> Optional[Tuple[bytes, Token, bool]]: - """See base class.""" - base64_state = b64encode(state).decode() - - kwargs = {} - - def get_state(): - result = self.get_state() - if result is not None: - tmp = *result, False - # Python 3.6 does not support tuple unpacking in return - # statements. - return tmp - return None - - if token: - try: - token = int(token) - except ValueError: - return get_state() - - if token: - kwargs["prevIndex"] = token - else: - kwargs["prevExist"] = False - - try: - result = self._client.write(self._key, base64_state, self._ttl, **kwargs) - except (EtcdAlreadyExist, EtcdCompareFailed): - result = None - except (EtcdException, urllib3.exceptions.TimeoutError) as exc: - raise RendezvousConnectionError( - "The connection to etcd has failed. See inner exception for details." - ) from exc - - if result is None: - return get_state() - - tmp = *self._decode_state(result), True - return tmp - - def _decode_state(self, result: EtcdResult) -> Tuple[bytes, Token]: - base64_state = result.value.encode() - - try: - state = b64decode(base64_state) - except binascii.Error as exc: - raise RendezvousStateError( - "The state object is corrupt. See inner exception for details." 
- ) from exc - - return state, result.modifiedIndex - - -def _create_etcd_client(params: RendezvousParameters) -> EtcdClient: - host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379) - - # The timeout - read_timeout = cast(int, params.get_as_int("read_timeout", 60)) - if read_timeout <= 0: - raise ValueError("The read timeout must be a positive integer.") - - # The communication protocol - protocol = params.get("protocol", "http").strip().lower() - if protocol != "http" and protocol != "https": - raise ValueError("The protocol must be HTTP or HTTPS.") - - # The SSL client certificate - ssl_cert = params.get("ssl_cert") - if ssl_cert: - ssl_cert_key = params.get("ssl_cert_key") - if ssl_cert_key: - # The etcd client expects the certificate key as the second element - # of the `cert` tuple. - ssl_cert = (ssl_cert, ssl_cert_key) - - # The root certificate - ca_cert = params.get("ca_cert") - - try: - return EtcdClient( - host, - port, - read_timeout=read_timeout, - protocol=protocol, - cert=ssl_cert, - ca_cert=ca_cert, - allow_reconnect=True, - ) - except (EtcdException, urllib3.exceptions.TimeoutError) as exc: - raise RendezvousConnectionError( - "The connection to etcd has failed. See inner exception for details." - ) from exc - - -def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]: - """Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters. - - +--------------+-----------------------------------------------------------+ - | Parameter | Description | - +==============+===========================================================+ - | read_timeout | The read timeout, in seconds, for etcd operations. | - | | Defaults to 60 seconds. | - +--------------+-----------------------------------------------------------+ - | protocol | The protocol to use to communicate with etcd. Valid | - | | values are "http" and "https". Defaults to "http". | - +--------------+-----------------------------------------------------------+ - | ssl_cert | The path to the SSL client certificate to use along with | - | | HTTPS. Defaults to ``None``. | - +--------------+-----------------------------------------------------------+ - | ssl_cert_key | The path to the private key of the SSL client certificate | - | | to use along with HTTPS. Defaults to ``None``. | - +--------------+-----------------------------------------------------------+ - | ca_cert | The path to the root SSL authority certificate. Defaults | - | | to ``None``. | - +--------------+-----------------------------------------------------------+ - """ - client = _create_etcd_client(params) - - backend = EtcdRendezvousBackend( - client, params.run_id, key_prefix="/torch/elastic/rendezvous" - ) - - store = EtcdStore(client, "/torch/elastic/store") - - return backend, store diff --git a/mindtorch/distributed/elastic/rendezvous/etcd_server.py b/mindtorch/distributed/elastic/rendezvous/etcd_server.py deleted file mode 100644 index 99623e0bb..000000000 --- a/mindtorch/distributed/elastic/rendezvous/etcd_server.py +++ /dev/null @@ -1,248 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
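The file removed below provided a self-managed etcd sidecar, useful for tests and single-node launches. A minimal usage sketch matching the class's own docstring (the etcd binary is resolved via the fallback chain described there)::

    server = EtcdServer()      # data dir defaults to a temporary directory
    server.start(timeout=60)   # spawns the etcd subprocess and waits for readiness
    try:
        client = server.get_client()
        print(client.version, server.get_endpoint())
    finally:
        server.stop()          # terminates the subprocess and removes the data dir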
-import atexit -import logging -import os -import shlex -import shutil -import socket -import subprocess -import tempfile -import time -from typing import Optional, TextIO, Union - - -try: - import etcd # type: ignore[import] -except ModuleNotFoundError: - pass - - -logger = logging.getLogger(__name__) - - -def find_free_port(): - """ - Find a free port and bind a temporary socket to it so that the port can be "reserved" until used. - - .. note:: the returned socket must be closed before using the port, - otherwise an ``address already in use`` error will happen. - The socket should be held and closed as close to the - consumer of the port as possible since otherwise, there - is a greater chance of a race condition where a different - process may see the port as being free and take it. - - Returns: a socket bound to the reserved free port - - Usage:: - - sock = find_free_port() - port = sock.getsockname()[1] - sock.close() - use_port(port) - """ - addrs = socket.getaddrinfo( - host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM - ) - - for addr in addrs: - family, type, proto, _, _ = addr - try: - s = socket.socket(family, type, proto) - s.bind(("localhost", 0)) - s.listen(0) - return s - except OSError as e: - s.close() # type: ignore[possibly-undefined] - print(f"Socket creation attempt failed: {e}") - raise RuntimeError("Failed to create a socket") - - -def stop_etcd(subprocess, data_dir: Optional[str] = None): - if subprocess and subprocess.poll() is None: - logger.info("stopping etcd server") - subprocess.terminate() - subprocess.wait() - - if data_dir: - logger.info("deleting etcd data dir: %s", data_dir) - shutil.rmtree(data_dir, ignore_errors=True) - - -class EtcdServer: - """ - .. note:: tested on etcd server v3.4.3. - - Starts and stops a local standalone etcd server on a random free - port. Useful for single node, multi-worker launches or testing, - where a sidecar etcd server is more convenient than having to - separately set up an etcd server. - - This class registers a termination handler to shut down the etcd - subprocess on exit. This termination handler is NOT a substitute for - calling the ``stop()`` method. - - The following fallback mechanism is used to find the etcd binary: - - 1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH - 2. Uses ``<this file root>/bin/etcd`` if one exists - 3. Uses ``etcd`` from ``PATH`` - - Usage - :: - - server = EtcdServer("/tmp/default.etcd") - server.start() - client = server.get_client() - # use client - server.stop() - - Args: - data_dir: path of the etcd data directory; a temporary directory is - created and used if not specified - """ - - def __init__(self, data_dir: Optional[str] = None): - self._port = -1 - self._host = "localhost" - - root = os.path.dirname(__file__) - default_etcd_bin = os.path.join(root, "bin/etcd") - self._etcd_binary_path = os.environ.get( - "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin - ) - if not os.path.isfile(self._etcd_binary_path): - self._etcd_binary_path = "etcd" - - self._base_data_dir = ( - data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data") - ) - self._etcd_cmd = None - self._etcd_proc: Optional[subprocess.Popen] = None - - def _get_etcd_server_process(self) -> subprocess.Popen: - if not self._etcd_proc: - raise RuntimeError( - "No etcd server process started. 
Call etcd_server.start() first" - ) - else: - return self._etcd_proc - - def get_port(self) -> int: - """Return the port the server is running on.""" - return self._port - - def get_host(self) -> str: - """Return the host the server is running on.""" - return self._host - - def get_endpoint(self) -> str: - """Return the etcd server endpoint (host:port).""" - return f"{self._host}:{self._port}" - - def start( - self, - timeout: int = 60, - num_retries: int = 3, - stderr: Union[int, TextIO, None] = None, - ) -> None: - """ - Start the server and wait for it to be ready. When this function returns, the server is ready to take requests. - - Args: - timeout: time (in seconds) to wait for the server to be ready - before giving up. - num_retries: number of retries to start the server. Each retry - will wait for at most ``timeout`` seconds before considering it failed. - stderr: the standard error file handle. Valid values are - `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file - descriptor (a positive integer), an existing file object, and - `None`. - - Raises: - TimeoutError: if the server is not ready within the specified timeout - """ - curr_retries = 0 - while True: - try: - data_dir = os.path.join(self._base_data_dir, str(curr_retries)) - os.makedirs(data_dir, exist_ok=True) - return self._start(data_dir, timeout, stderr) - except Exception as e: - curr_retries += 1 - stop_etcd(self._etcd_proc) - logger.warning( - "Failed to start etcd server, got error: %s, retrying", str(e) - ) - if curr_retries >= num_retries: - shutil.rmtree(self._base_data_dir, ignore_errors=True) - raise - atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir) - - def _start( - self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None - ) -> None: - sock = find_free_port() - sock_peer = find_free_port() - self._port = sock.getsockname()[1] - peer_port = sock_peer.getsockname()[1] - - etcd_cmd = shlex.split( - " ".join( - [ - self._etcd_binary_path, - "--enable-v2", - "--data-dir", - data_dir, - "--listen-client-urls", - f"http://{self._host}:{self._port}", - "--advertise-client-urls", - f"http://{self._host}:{self._port}", - "--listen-peer-urls", - f"http://{self._host}:{peer_port}", - ] - ) - ) - - logger.info("Starting etcd server: [%s]", etcd_cmd) - - sock.close() - sock_peer.close() - self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr) - self._wait_for_ready(timeout) - - def get_client(self): - """Return an etcd client object that can be used to make requests to this server.""" - return etcd.Client( - host=self._host, port=self._port, version_prefix="/v2", read_timeout=10 - ) - - def _wait_for_ready(self, timeout: int = 60) -> None: - client = etcd.Client( - host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5 - ) - max_time = time.time() + timeout - - while time.time() < max_time: - if self._get_etcd_server_process().poll() is not None: - # etcd server process finished - exitcode = self._get_etcd_server_process().returncode - raise RuntimeError( - f"Etcd server process exited with the code: {exitcode}" - ) - try: - logger.info("etcd server ready. version: %s", client.version) - return - except Exception: - time.sleep(1) - raise TimeoutError("Timed out waiting for etcd server to be ready!") - - def stop(self) -> None: - """Stop the server and clean up auto-generated resources (e.g. 
data dir).""" - logger.info("EtcdServer stop method called") - stop_etcd(self._etcd_proc, self._base_data_dir) diff --git a/mindtorch/distributed/elastic/rendezvous/etcd_store.py b/mindtorch/distributed/elastic/rendezvous/etcd_store.py deleted file mode 100644 index ef3911758..000000000 --- a/mindtorch/distributed/elastic/rendezvous/etcd_store.py +++ /dev/null @@ -1,212 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import datetime -import random -import time -from base64 import b64decode, b64encode -from typing import Optional - -import etcd # type: ignore[import] - -# pyre-ignore[21]: Could not find name `Store` in `mindtorch.distributed`. -from mindtorch.distributed import Store - - -# Delay (sleep) for a small random amount to reduce CAS failures. -# This does not affect correctness, but will reduce requests to the etcd server. -def cas_delay(): - time.sleep(random.uniform(0, 0.1)) - - -# pyre-fixme[11]: Annotation `Store` is not defined as a type. -class EtcdStore(Store): - """ - Implement a c10 Store interface by piggybacking on the rendezvous etcd instance. - - This is the store object returned by ``EtcdRendezvous``. - """ - - def __init__( - self, - etcd_client, - etcd_store_prefix, - # Default timeout same as in c10d/Store.hpp - timeout: Optional[datetime.timedelta] = None, - ): - super().__init__() # required for pybind trampoline. - - self.client = etcd_client - self.prefix = etcd_store_prefix - - if timeout is not None: - self.set_timeout(timeout) - - if not self.prefix.endswith("/"): - self.prefix += "/" - - def set(self, key, value): - """ - Write a key/value pair into ``EtcdStore``. - - Both key and value may be either Python ``str`` or ``bytes``. - """ - self.client.set(key=self.prefix + self._encode(key), value=self._encode(value)) - - def get(self, key) -> bytes: - """ - Get a value by key, possibly doing a blocking wait. - - If the key is not immediately present, this will do a blocking wait - for at most ``timeout`` duration or until the key is published. - - - Returns: - value ``(bytes)`` - - Raises: - LookupError - if the key is still not published after the timeout - """ - b64_key = self.prefix + self._encode(key) - kvs = self._try_wait_get([b64_key]) - - if kvs is None: - raise LookupError(f"Key {key} not found in EtcdStore") - - return self._decode(kvs[b64_key]) - - def add(self, key, num: int) -> int: - """ - Atomically increment a value by an integer amount. - - The integer is represented as a string using base 10. If the key is not present, - a default value of ``0`` will be assumed. - - Returns: - the new (incremented) value - - - """ - b64_key = self._encode(key) - # c10d Store assumes value is an integer represented as a decimal string - try: - # Assume default value "0" if this key didn't exist yet: - node = self.client.write( - key=self.prefix + b64_key, - value=self._encode(str(num)), # i.e. 0 + num - prevExist=False, - ) - return int(self._decode(node.value)) - except etcd.EtcdAlreadyExist: - pass - - while True: - # Note: c10d Store does not have a method to delete keys, so we - # can be sure it's still there. 
- node = self.client.get(key=self.prefix + b64_key) - new_value = self._encode(str(int(self._decode(node.value)) + num)) - try: - node = self.client.test_and_set( - key=node.key, value=new_value, prev_value=node.value - ) - return int(self._decode(node.value)) - except etcd.EtcdCompareFailed: - cas_delay() - - def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None): - """ - Wait until all of the keys are published, or until timeout. - - Raises: - LookupError - if timeout occurs - """ - b64_keys = [self.prefix + self._encode(key) for key in keys] - kvs = self._try_wait_get(b64_keys, override_timeout) - if kvs is None: - raise LookupError("Timeout while waiting for keys in EtcdStore") - # No return value on success - - def check(self, keys) -> bool: - """Check if all of the keys are immediately present (without waiting).""" - b64_keys = [self.prefix + self._encode(key) for key in keys] - kvs = self._try_wait_get( - b64_keys, - override_timeout=datetime.timedelta(microseconds=1), # as if no wait - ) - return kvs is not None - - # - # Encode key/value data in base64, so we can store arbitrary binary data - # in EtcdStore. Input can be `str` or `bytes`. - # In case of `str`, utf-8 encoding is assumed. - # - def _encode(self, value) -> str: - if type(value) == bytes: - return b64encode(value).decode() - elif type(value) == str: - return b64encode(value.encode()).decode() - raise ValueError("Value must be of type str or bytes") - - # - # Decode a base64 string (of type `str` or `bytes`). - # Return type is `bytes`, which is more convenient with the Store interface. - # - def _decode(self, value) -> bytes: - if type(value) == bytes: - return b64decode(value) - elif type(value) == str: - return b64decode(value.encode()) - raise ValueError("Value must be of type str or bytes") - - # - # Get all of the (base64-encoded) etcd keys at once, or wait until all the keys - # are published or timeout occurs. - # This is a helper method for the public interface methods. - # - # On success, a dictionary of {etcd key -> etcd value} is returned. - # On timeout, None is returned. - # - def _try_wait_get(self, b64_keys, override_timeout=None): - timeout = self.timeout if override_timeout is None else override_timeout # type: ignore[attr-defined] - deadline = time.time() + timeout.total_seconds() - - while True: - # Read whole directory (of keys), filter only the ones waited for - all_nodes = None - try: - all_nodes = self.client.get(key=self.prefix) - req_nodes = { - node.key: node.value - for node in all_nodes.children - if node.key in b64_keys - } - - if len(req_nodes) == len(b64_keys): - # All keys are available - return req_nodes - except etcd.EtcdKeyNotFound: - pass - - watch_timeout = deadline - time.time() - if watch_timeout <= 0: - return None - - try: - index = all_nodes.etcd_index + 1 if all_nodes else 0 - self.client.watch( - key=self.prefix, - recursive=True, - timeout=watch_timeout, - index=index, - ) - except etcd.EtcdWatchTimedOut: - if time.time() >= deadline: - return None - else: - continue - except etcd.EtcdEventIndexCleared: - continue diff --git a/mindtorch/distributed/elastic/rendezvous/registry.py b/mindtorch/distributed/elastic/rendezvous/registry.py deleted file mode 100644 index fa7d3e811..000000000 --- a/mindtorch/distributed/elastic/rendezvous/registry.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import sys - -from .api import ( - rendezvous_handler_registry as handler_registry, - RendezvousHandler, - RendezvousParameters, -) -from .dynamic_rendezvous import create_handler - - -if sys.version_info < (3, 10): - from importlib_metadata import entry_points -else: - from importlib.metadata import entry_points - -log = logging.getLogger(__name__) - -__all__ = ["get_rendezvous_handler"] - - -def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler: - from . import static_tcp_rendezvous - - return static_tcp_rendezvous.create_rdzv_handler(params) - - -def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler: - from . import etcd_rendezvous - - return etcd_rendezvous.create_rdzv_handler(params) - - -def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler: - from .etcd_rendezvous_backend import create_backend - - backend, store = create_backend(params) - - return create_handler(store, backend, params) - - -def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler: - from .c10d_rendezvous_backend import create_backend - - backend, store = create_backend(params) - - return create_handler(store, backend, params) - - -def _register_default_handlers() -> None: - handler_registry.register("etcd", _create_etcd_handler) - handler_registry.register("etcd-v2", _create_etcd_v2_handler) - handler_registry.register("c10d", _create_c10d_handler) - handler_registry.register("static", _create_static_handler) - - -def _register_out_of_tree_handlers() -> None: - discovered_handler_generators = entry_points(group="torchrun.handlers") - - for handler_generator in discovered_handler_generators: - try: - get_handler = discovered_handler_generators[handler_generator.name].load() - handler_registry.register(handler_generator.name, get_handler()) - except Exception: - log.warning( - "Exception while registering out-of-tree plugin %s: ", - handler_generator.name, - exc_info=True, - ) - - -def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler: - """ - Obtain a reference to a :py:class:`RendezvousHandler`. - - Custom rendezvous handlers can be registered by - - :: - - from mindtorch.distributed.elastic.rendezvous import rendezvous_handler_registry - from mindtorch.distributed.elastic.rendezvous.registry import get_rendezvous_handler - - def create_my_rdzv(params: RendezvousParameters): - return MyCustomRdzv(params) - - rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv) - - my_rdzv_handler = get_rendezvous_handler(RendezvousParameters(backend="my_rdzv_backend_name", ...)) - """ - return handler_registry.create_handler(params) diff --git a/mindtorch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/mindtorch/distributed/elastic/rendezvous/static_tcp_rendezvous.py deleted file mode 100644 index a24449093..000000000 --- a/mindtorch/distributed/elastic/rendezvous/static_tcp_rendezvous.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
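The file removed below implemented the "static" backend registered in the registry above. It expects the node rank inside params.config and an endpoint that carries an explicit port; a sketch of parameters that would pass its checks, assuming extra RendezvousParameters keyword arguments land in params.config (all values illustrative)::

    params = RendezvousParameters(
        backend="static",
        endpoint="10.0.0.1:29500",  # port is mandatory for static rendezvous
        run_id="job-42",
        min_nodes=2,
        max_nodes=2,   # used as the world size
        rank=0,        # normally supplied via --node-rank
        timeout=600,
    )
    handler = create_rdzv_handler(params)
    info = handler.next_rendezvous()  # RendezvousInfo: store, rank, world size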
- -import datetime -import logging -from typing import cast, Optional - -from mindtorch.distributed import PrefixStore, Store, TCPStore -from mindtorch.distributed.elastic.rendezvous import ( - RendezvousHandler, - RendezvousInfo, - RendezvousParameters, - RendezvousStoreInfo, -) -from mindtorch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint - - -__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"] - -logger = logging.getLogger(__name__) - -_default_timeout_seconds = 600 - - -class StaticTCPRendezvous(RendezvousHandler): - """ - Static rendezvous that is a wrapper around the TCPStore. - - Creates a TCPStore based on the input parameters, with the - listener on the agent with group_rank=0 - """ - - def __init__( - self, - master_addr: str, - master_port: int, - rank: int, - world_size: int, - run_id: str, - timeout: int, - ): - self.master_addr = master_addr - self.master_port = master_port - self.rank = rank - self.world_size = world_size - self.run_id = run_id - self.timeout = datetime.timedelta(seconds=timeout) - self._store: Optional[Store] = None - - def get_backend(self) -> str: - return "static" - - @property - def use_agent_store(self) -> bool: - return True - - def next_rendezvous(self) -> RendezvousInfo: - logger.info("Creating TCPStore as the c10d::Store implementation") - is_master = self.rank == 0 - if not self._store: - self._store = TCPStore( # type: ignore[call-arg] - self.master_addr, - self.master_port, - self.world_size, - is_master, - self.timeout, - multi_tenant=True, - ) - store = PrefixStore(self.run_id, self._store) - # TCPStore server instance is used by trainer code - bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port) - return RendezvousInfo( - store, - self.rank, - self.world_size, - bootstrap_store_info, - ) - - def is_closed(self): - return False - - def set_closed(self): - pass - - def num_nodes_waiting(self): - return 0 - - def get_run_id(self) -> str: - return self.run_id - - def shutdown(self) -> bool: - return True - - -def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: - if "rank" not in params.config: - raise ValueError( - "rank is absent in RendezvousParameters. " - "Try adding --node-rank to the cmd request" - ) - endpoint = params.endpoint.strip() - if not endpoint: - raise ValueError( - "endpoint is absent in RendezvousParameters. " - "Try adding --master-port and --master-addr to the cmd request" - ) - master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1) - if master_port == -1: - raise ValueError( - f"Port is absent in endpoint: {endpoint}. Try launching with --master-port" - ) - world_size = params.max_nodes - rank = cast(int, params.config.get("rank")) - run_id = params.run_id - if "timeout" in params.config: - timeout = int(params.config["timeout"]) - else: - timeout = _default_timeout_seconds - - return StaticTCPRendezvous( - master_addr, master_port, rank, world_size, run_id, timeout - ) diff --git a/mindtorch/distributed/elastic/rendezvous/utils.py b/mindtorch/distributed/elastic/rendezvous/utils.py deleted file mode 100644 index f946209a9..000000000 --- a/mindtorch/distributed/elastic/rendezvous/utils.py +++ /dev/null @@ -1,284 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
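The parsing utilities in the file removed below are small and pure; their contracts are easiest to see by example. The expected values follow from the parsing rules in the code itself::

    _parse_rendezvous_config("timeout=600,protocol=https")
    # -> {"timeout": "600", "protocol": "https"}

    parse_rendezvous_endpoint("node1:29500", default_port=2379)
    # -> ("node1", 29500)
    parse_rendezvous_endpoint(None, default_port=2379)
    # -> ("localhost", 2379)
    parse_rendezvous_endpoint("[::1]", default_port=29500)
    # -> ("::1", 29500)   # bracketed IPv6 form, port defaulted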
- -import ipaddress -import random -import re -import socket -import time -import weakref -from datetime import timedelta -from threading import Event, Thread -from typing import Any, Callable, Dict, Optional, Tuple, Union - - -__all__ = ["parse_rendezvous_endpoint"] - - -def _parse_rendezvous_config(config_str: str) -> Dict[str, str]: - """Extract key-value pairs from a rendezvous configuration string. - - Args: - config_str: - A string in format <key1>=<value1>,...,<keyN>=<valueN>. - """ - config: Dict[str, str] = {} - - config_str = config_str.strip() - if not config_str: - return config - - key_values = config_str.split(",") - for kv in key_values: - key, *values = kv.split("=", 1) - - key = key.strip() - if not key: - raise ValueError( - "The rendezvous configuration string must be in format " - "<key1>=<value1>,...,<keyN>=<valueN>." - ) - - value: Optional[str] - if values: - value = values[0].strip() - else: - value = None - if not value: - raise ValueError( - f"The rendezvous configuration option '{key}' must have a value specified." - ) - - config[key] = value - return config - - -def _try_parse_port(port_str: str) -> Optional[int]: - """Try to extract the port number from ``port_str``.""" - if port_str and re.match(r"^[0-9]{1,5}$", port_str): - return int(port_str) - return None - - -def parse_rendezvous_endpoint( - endpoint: Optional[str], default_port: int -) -> Tuple[str, int]: - """Extract the hostname and the port number from a rendezvous endpoint. - - Args: - endpoint: - A string in format <host>[:<port>]. - default_port: - The port number to use if the endpoint does not include one. - - Returns: - A tuple of hostname and port number. - """ - if endpoint is not None: - endpoint = endpoint.strip() - - if not endpoint: - return ("localhost", default_port) - - # An endpoint that starts and ends with brackets represents an IPv6 address. - if endpoint[0] == "[" and endpoint[-1] == "]": - host, *rest = endpoint, *[] - else: - host, *rest = endpoint.rsplit(":", 1) - - # Sanitize the IPv6 address. - if len(host) > 1 and host[0] == "[" and host[-1] == "]": - host = host[1:-1] - - if len(rest) == 1: - port = _try_parse_port(rest[0]) - if port is None or port >= 2**16: - raise ValueError( - f"The port number of the rendezvous endpoint '{endpoint}' must be an integer " - "between 0 and 65536." - ) - else: - port = default_port - - if not re.match(r"^[\w\.:-]+$", host): - raise ValueError( - f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of " - "labels, an IPv4 address, or an IPv6 address." - ) - - return host, port - - -def _matches_machine_hostname(host: str) -> bool: - """Indicate whether ``host`` matches the hostname of this machine. - - This function compares ``host`` to the hostname as well as to the IP - addresses of this machine. Note that it may return a false negative if this - machine has CNAME records beyond its FQDN or IP addresses assigned to - secondary NICs. 
- """ - if host == "localhost": - return True - - try: - addr = ipaddress.ip_address(host) - except ValueError: - addr = None - - if addr and addr.is_loopback: - return True - - try: - host_addr_list = socket.getaddrinfo( - host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME - ) - except (ValueError, socket.gaierror) as _: - host_addr_list = [] - - host_ip_list = [host_addr_info[4][0] for host_addr_info in host_addr_list] - - this_host = socket.gethostname() - if host == this_host: - return True - - addr_list = socket.getaddrinfo( - this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME - ) - for addr_info in addr_list: - # If we have an FQDN in the addr_info, compare it to `host`. - if addr_info[3] and addr_info[3] == host: - return True - - # Otherwise if `host` represents an IP address, compare it to our IP - # address. - if addr and addr_info[4][0] == str(addr): - return True - - # If the IP address matches one of the provided host's IP addresses - if addr_info[4][0] in host_ip_list: - return True - - return False - - -def _delay(seconds: Union[float, Tuple[float, float]]) -> None: - """Suspend the current thread for ``seconds``. - - Args: - seconds: - Either the delay, in seconds, or a tuple of a lower and an upper - bound within which a random delay will be picked. - """ - if isinstance(seconds, tuple): - seconds = random.uniform(*seconds) - # Ignore delay requests that are less than 10 milliseconds. - if seconds >= 0.01: - time.sleep(seconds) - - -class _PeriodicTimer: - """Represent a timer that periodically runs a specified function. - - Args: - interval: - The interval, in seconds, between each run. - function: - The function to run. - """ - - # The state of the timer is hold in a separate context object to avoid a - # reference cycle between the timer and the background thread. - class _Context: - interval: float - function: Callable[..., None] - args: Tuple[Any, ...] - kwargs: Dict[str, Any] - stop_event: Event - - _name: Optional[str] - _thread: Optional[Thread] - _finalizer: Optional[weakref.finalize] - - # The context that is shared between the timer and the background thread. - _ctx: _Context - - def __init__( - self, - interval: timedelta, - function: Callable[..., None], - *args: Any, - **kwargs: Any, - ) -> None: - self._name = None - - self._ctx = self._Context() - self._ctx.interval = interval.total_seconds() - self._ctx.function = function # type: ignore[assignment] - self._ctx.args = args or () - self._ctx.kwargs = kwargs or {} - self._ctx.stop_event = Event() - - self._thread = None - self._finalizer = None - - @property - def name(self) -> Optional[str]: - """Get the name of the timer.""" - return self._name - - def set_name(self, name: str) -> None: - """Set the name of the timer. - - The specified name will be assigned to the background thread and serves - for debugging and troubleshooting purposes. - """ - if self._thread: - raise RuntimeError("The timer has already started.") - - self._name = name - - def start(self) -> None: - """Start the timer.""" - if self._thread: - raise RuntimeError("The timer has already started.") - - self._thread = Thread( - target=self._run, - name=self._name or "PeriodicTimer", - args=(self._ctx,), - daemon=True, - ) - - # We avoid using a regular finalizer (a.k.a. __del__) for stopping the - # timer as joining a daemon thread during the interpreter shutdown can - # cause deadlocks. The weakref.finalize is a superior alternative that - # provides a consistent behavior regardless of the GC implementation. 
- self._finalizer = weakref.finalize( - self, self._stop_thread, self._thread, self._ctx.stop_event - ) - - # We do not attempt to stop our background thread during the interpreter - # shutdown. At that point we do not even know whether it still exists. - self._finalizer.atexit = False - - self._thread.start() - - def cancel(self) -> None: - """Stop the timer at the next opportunity.""" - if self._finalizer: - self._finalizer() - - @staticmethod - def _run(ctx) -> None: - while not ctx.stop_event.wait(ctx.interval): - ctx.function(*ctx.args, **ctx.kwargs) - - @staticmethod - def _stop_thread(thread, stop_event): - stop_event.set() - - thread.join() diff --git a/mindtorch/distributed/elastic/timer/__init__.py b/mindtorch/distributed/elastic/timer/__init__.py deleted file mode 100644 index 7876a99be..000000000 --- a/mindtorch/distributed/elastic/timer/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Expiration timers are set up on the same process as the agent and -used from your script to deal with stuck workers. When you go into -a code-block that has the potential to get stuck you can acquire -an expiration timer, which instructs the timer server to kill the -process if it does not release the timer by the self-imposed expiration -deadline. - -Usage:: - - import mindtorch.distributed.elastic.timer as timer - import mindtorch.distributed.elastic.agent.server as agent - - def main(): - start_method = "spawn" - message_queue = mp.get_context(start_method).Queue() - server = timer.LocalTimerServer(message_queue, max_interval=0.01) - server.start() # non-blocking - - spec = WorkerSpec( - fn=trainer_func, - args=(message_queue,), - ...) - agent = agent.LocalElasticAgent(spec, start_method) - agent.run() - - def trainer_func(message_queue): - timer.configure(timer.LocalTimerClient(message_queue)) - with timer.expires(after=60): # 60 second expiry - # do some work - -In the example above, if ``trainer_func`` takes more than 60 seconds to -complete, then the worker process is killed and the agent retries the worker group. -""" - -from .api import ( # noqa: F401 - configure, - expires, - TimerClient, - TimerRequest, - TimerServer, -) -from .file_based_local_timer import ( # noqa: F401 - FileTimerClient, - FileTimerRequest, - FileTimerServer, -) -from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401 diff --git a/mindtorch/distributed/elastic/timer/api.py b/mindtorch/distributed/elastic/timer/api.py deleted file mode 100644 index 34698ab47..000000000 --- a/mindtorch/distributed/elastic/timer/api.py +++ /dev/null @@ -1,283 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import abc -import logging -import threading -import time -from contextlib import contextmanager -from inspect import getframeinfo, stack -from typing import Any, Dict, List, Optional, Set - - -__all__ = [ - "TimerRequest", - "TimerClient", - "RequestQueue", - "TimerServer", - "configure", - "expires", -] - -logger = logging.getLogger(__name__) - - -class TimerRequest: - """ - Data object representing a countdown timer acquisition and release - that is used between the ``TimerClient`` and ``TimerServer``. 
- A negative ``expiration_time`` should be interpreted as a "release" - request. - - .. note:: the type of ``worker_id`` is implementation specific. - It is whatever the TimerServer and TimerClient implementations - use to uniquely identify a worker. - """ - - __slots__ = ["worker_id", "scope_id", "expiration_time"] - - def __init__(self, worker_id: Any, scope_id: str, expiration_time: float): - self.worker_id = worker_id - self.scope_id = scope_id - self.expiration_time = expiration_time - - def __eq__(self, other): - if isinstance(other, TimerRequest): - return ( - self.worker_id == other.worker_id - and self.scope_id == other.scope_id - and self.expiration_time == other.expiration_time - ) - return False - - -class TimerClient(abc.ABC): - """ - Client library to acquire and release countdown timers by communicating - with the TimerServer. - """ - - @abc.abstractmethod - def acquire(self, scope_id: str, expiration_time: float) -> None: - """ - Acquires a timer for the worker that holds this client object - given the scope_id and expiration_time. Typically registers - the timer with the TimerServer. - """ - - @abc.abstractmethod - def release(self, scope_id: str): - """ - Releases the timer for the ``scope_id`` on the worker this - client represents. After this method is - called, the countdown timer on the scope is no longer in effect. - """ - - -class RequestQueue(abc.ABC): - """ - Consumer queue holding timer acquisition/release requests - """ - - @abc.abstractmethod - def size(self) -> int: - """ - Returns the size of the queue at the time this method is called. - Note that by the time ``get`` is called the size of the queue - may have increased. The size of the queue should not decrease - until the ``get`` method is called. That is, the following assertion - should hold: - - size = q.size() - res = q.get(size, timeout=0) - assert size == len(res) - - -- or -- - - size = q.size() - res = q.get(size * 2, timeout=1) - assert size <= len(res) <= size * 2 - """ - - @abc.abstractmethod - def get(self, size: int, timeout: float) -> List[TimerRequest]: - """ - Gets up to ``size`` number of timer requests in a blocking fashion - (no more than ``timeout`` seconds). - """ - - -class TimerServer(abc.ABC): - """ - Entity that monitors active timers and expires them - in a timely fashion. This server is responsible for - reaping workers that have expired timers. - """ - - def __init__( - self, request_queue: RequestQueue, max_interval: float, daemon: bool = True - ): - """ - :param request_queue: Consumer ``RequestQueue`` - :param max_interval: max time (in seconds) to wait - for an item in the request_queue - :param daemon: whether to run the watchdog thread as a daemon - """ - super().__init__() - self._request_queue = request_queue - self._max_interval = max_interval - self._daemon = daemon - self._watchdog_thread: Optional[threading.Thread] = None - self._stop_signaled = False - - @abc.abstractmethod - def register_timers(self, timer_requests: List[TimerRequest]) -> None: - """ - Processes the incoming timer requests and registers them with the server. - The timer request can either be an acquire-timer or a release-timer request. - Timer requests with a negative expiration_time should be interpreted - as a release-timer request. - """ - - @abc.abstractmethod - def clear_timers(self, worker_ids: Set[Any]) -> None: - """ - Clears all timers for the given ``worker_ids``. 
- """ - - @abc.abstractmethod - def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]: - """ - Returns all expired timers for each worker_id. An expired timer - is a timer for which the expiration_time is less than or equal to - the provided deadline. - """ - - @abc.abstractmethod - def _reap_worker(self, worker_id: Any) -> bool: - """ - Reaps the given worker. Returns True if the worker has been - successfully reaped, False otherwise. If any uncaught exception - is thrown from this method, the worker is considered reaped - and all associated timers will be removed. - """ - - def _reap_worker_no_throw(self, worker_id: Any) -> bool: - """ - Wraps ``_reap_worker(worker_id)``, if an uncaught exception is - thrown, then it considers the worker as reaped. - """ - try: - return self._reap_worker(worker_id) - except Exception: - logger.exception( - "Uncaught exception thrown from _reap_worker(), " - "check that the implementation correctly catches exceptions", - ) - return True - - def _watchdog_loop(self): - while not self._stop_signaled: - try: - self._run_watchdog() - except Exception: - logger.exception("Error running watchdog") - - def _run_watchdog(self): - batch_size = max(1, self._request_queue.size()) - timer_requests = self._request_queue.get(batch_size, self._max_interval) - self.register_timers(timer_requests) - now = time.time() - reaped_worker_ids = set() - for worker_id, expired_timers in self.get_expired_timers(now).items(): - logger.info( - "Reaping worker_id=[%s]." " Expired timers: %s", - worker_id, - self._get_scopes(expired_timers), - ) - if self._reap_worker_no_throw(worker_id): - logger.info("Successfully reaped worker=[%s]", worker_id) - reaped_worker_ids.add(worker_id) - else: - logger.error( - "Error reaping worker=[%s]. Will retry on next watchdog.", worker_id - ) - self.clear_timers(reaped_worker_ids) - - def _get_scopes(self, timer_requests): - return [r.scope_id for r in timer_requests] - - def start(self) -> None: - logger.info( - "Starting %s..." " max_interval=%s," " daemon=%s", - type(self).__name__, - self._max_interval, - self._daemon, - ) - self._watchdog_thread = threading.Thread( - target=self._watchdog_loop, daemon=self._daemon - ) - logger.info("Starting watchdog thread...") - self._watchdog_thread.start() - - def stop(self) -> None: - logger.info("Stopping %s", type(self).__name__) - self._stop_signaled = True - if self._watchdog_thread: - logger.info("Stopping watchdog thread...") - self._watchdog_thread.join(self._max_interval) - self._watchdog_thread = None - else: - logger.info("No watchdog thread running, doing nothing") - - -_timer_client: Optional[TimerClient] = None - - -def configure(timer_client: TimerClient): - """ - Configures a timer client. Must be called before using ``expires``. - """ - global _timer_client - _timer_client = timer_client - logger.info("Timer client configured to: %s", type(_timer_client).__name__) - - -@contextmanager -def expires( - after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None -): - """ - Acquires a countdown timer that expires in ``after`` seconds from now, - unless the code-block that it wraps is finished within the timeframe. - When the timer expires, this worker is eligible to be reaped. The - exact meaning of "reaped" depends on the client implementation. In - most cases, reaping means to terminate the worker process. 
- Note that the worker is NOT guaranteed to be reaped at exactly - ``time.now() + after``, but rather the worker is "eligible" for being - reaped and the ``TimerServer`` that the client talks to will ultimately - make the decision when and how to reap the workers with expired timers. - - Usage:: - - mindtorch.distributed.elastic.timer.configure(LocalTimerClient()) - with expires(after=10): - mindtorch.distributed.all_reduce(...) - """ - if client is None: - if _timer_client is None: - raise RuntimeError("Configure timer client before using countdown timers.") - client = _timer_client - if scope is None: - # grab the caller file + lineno - caller = getframeinfo(stack()[1][0]) - scope = f"{caller.filename}#{caller.lineno}" - expiration = time.time() + after - client.acquire(scope, expiration) - try: - yield - finally: - client.release(scope) diff --git a/mindtorch/distributed/elastic/timer/debug_info_logging.py b/mindtorch/distributed/elastic/timer/debug_info_logging.py deleted file mode 100644 index 9b89768b3..000000000 --- a/mindtorch/distributed/elastic/timer/debug_info_logging.py +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Dict, List - -from mindtorch.distributed.elastic.utils.logging import get_logger - - -logger = get_logger(__name__) - -__all__ = ["log_debug_info_for_expired_timers"] - - -def log_debug_info_for_expired_timers( - run_id: str, - expired_timers: Dict[int, List[str]], -): - if expired_timers: - logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers) diff --git a/mindtorch/distributed/elastic/timer/file_based_local_timer.py b/mindtorch/distributed/elastic/timer/file_based_local_timer.py deleted file mode 100644 index e5f81a30b..000000000 --- a/mindtorch/distributed/elastic/timer/file_based_local_timer.py +++ /dev/null @@ -1,396 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import io -import json -import os -import select -import signal -import sys -import threading -import time -from typing import Callable, Dict, List, Optional, Set, Tuple - -from mindtorch.distributed.elastic.timer.api import TimerClient, TimerRequest -from mindtorch.distributed.elastic.timer.debug_info_logging import ( - log_debug_info_for_expired_timers, -) -from mindtorch.distributed.elastic.utils.logging import get_logger - - -__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"] - -logger = get_logger(__name__) - - -class FileTimerRequest(TimerRequest): - """ - Data object representing a countdown timer acquisition and release - that is used between the ``FileTimerClient`` and ``FileTimerServer``. - A negative ``expiration_time`` should be interpreted as a "release" - request. - ``signal`` is the signal to reap the worker process from the server - process. 
- """ - - __slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"] - - def __init__( - self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0 - ) -> None: - self.version = 1 - self.worker_pid = worker_pid - self.scope_id = scope_id - self.expiration_time = expiration_time - self.signal = signal - - def __eq__(self, other) -> bool: - if isinstance(other, FileTimerRequest): - return ( - self.version == other.version - and self.worker_pid == other.worker_pid - and self.scope_id == other.scope_id - and self.expiration_time == other.expiration_time - and self.signal == other.signal - ) - return False - - def to_json(self) -> str: - return json.dumps( - { - "version": self.version, - "pid": self.worker_pid, - "scope_id": self.scope_id, - "expiration_time": self.expiration_time, - "signal": self.signal, - }, - ) - - -class FileTimerClient(TimerClient): - """ - Client side of ``FileTimerServer``. This client is meant to be used - on the same host that the ``FileTimerServer`` is running on and uses - pid to uniquely identify a worker. - This client uses a named_pipe to send timer requests to the - ``FileTimerServer``. This client is a producer while the - ``FileTimerServer`` is a consumer. Multiple clients can work with - the same ``FileTimerServer``. - - Args: - - file_path: str, the path of a FIFO special file. ``FileTimerServer`` - must have created it by calling os.mkfifo(). - - signal: signal, the signal to use to kill the process. Using a - negative or zero signal will not kill the process. - """ - - def __init__( - self, - file_path: str, - signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT), # type: ignore[attr-defined] - ) -> None: - super().__init__() - self._file_path = file_path - self.signal = signal - - def _open_non_blocking(self) -> Optional[io.TextIOWrapper]: - try: - fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK) - return os.fdopen(fd, "wt") - except Exception: - return None - - def _send_request(self, request: FileTimerRequest) -> None: - # The server may have crashed or may haven't started yet. - # In such case, calling open() in blocking model blocks the client. - # To avoid such issue, open it in non-blocking mode, and an OSError will - # be raised if the server is not there. - file = self._open_non_blocking() - if file is None: - raise BrokenPipeError( - "Could not send the FileTimerRequest because FileTimerServer is not available." - ) - with file: - json_request = request.to_json() - # Write request with no greater than select.PIPE_BUF is guarantee to be atomic. - if len(json_request) > select.PIPE_BUF: - raise RuntimeError( - f"FileTimerRequest larger than {select.PIPE_BUF} bytes " - f"is not supported: {json_request}" - ) - file.write(json_request + "\n") - - def acquire(self, scope_id: str, expiration_time: float) -> None: - self._send_request( - request=FileTimerRequest( - worker_pid=os.getpid(), - scope_id=scope_id, - expiration_time=expiration_time, - signal=self.signal, - ), - ) - - def release(self, scope_id: str) -> None: - self._send_request( - request=FileTimerRequest( - worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0 - ), - ) - - -class FileTimerServer: - """ - Server that works with ``FileTimerClient``. Clients are expected to be - running on the same host as the process that is running this server. 
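On the wire, each request is a single newline-terminated JSON line produced by ``to_json()`` above; writes of at most ``select.PIPE_BUF`` bytes are atomic, so lines from concurrent clients never interleave. An acquire followed by a release for the same scope looks roughly like this (field values illustrative)::

    {"version": 1, "pid": 1234, "scope_id": "train.py#42", "expiration_time": 1700000060.0, "signal": 9}
    {"version": 1, "pid": 1234, "scope_id": "train.py#42", "expiration_time": -1, "signal": 0}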
- Each host in the job is expected to start its own timer server locally - and each server instance manages timers for local workers (running on - processes on the same host). - - Args: - - file_path: str, the path of a FIFO special file to be created. - - max_interval: float, max interval in seconds for each watchdog loop. - - daemon: bool, running the watchdog thread in daemon mode or not. - A daemon thread will not block a process to stop. - log_event: Callable[[Dict[str, str]], None], an optional callback for - logging the events in JSON format. - """ - - def __init__( - self, - file_path: str, - run_id: str, - max_interval: float = 10, - daemon: bool = True, - log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None, - ) -> None: - self._file_path = file_path - self._run_id = run_id - self._max_interval = max_interval - self._daemon = daemon - self._timers: Dict[Tuple[int, str], FileTimerRequest] = {} - self._stop_signaled = False - self._watchdog_thread: Optional[threading.Thread] = None - - self._is_client_started = False - if os.path.exists(self._file_path): - os.remove(self._file_path) - os.mkfifo(self._file_path) - # For test only. Count the number of requests received. - self._request_count = 0 - # For test only. Process all requests and stop the server. - self._run_once = False - self._log_event = ( - log_event if log_event is not None else lambda name, request: None - ) - self._last_progress_time = int(time.time()) - - def start(self) -> None: - logger.info( - "Starting %s... max_interval=%s, daemon=%s, file_path=%s", - type(self).__name__, - self._max_interval, - self._daemon, - self._file_path, - ) - self._watchdog_thread = threading.Thread( - target=self._watchdog_loop, daemon=self._daemon - ) - logger.info("Starting watchdog thread...") - self._watchdog_thread.start() - self._log_event("watchdog started", None) - - def stop(self) -> None: - logger.info("Stopping %s", type(self).__name__) - self._stop_signaled = True - if self._watchdog_thread: - logger.info("Stopping watchdog thread...") - self._watchdog_thread.join(self._max_interval) - self._watchdog_thread = None - else: - logger.info("No watchdog thread running, doing nothing") - if os.path.exists(self._file_path): - os.remove(self._file_path) - self._log_event("watchdog stopped", None) - - def run_once(self) -> None: - self._run_once = True - if self._watchdog_thread: - logger.info("Stopping watchdog thread...") - self._watchdog_thread.join() - self._watchdog_thread = None - else: - logger.info("No watchdog thread running, doing nothing") - if os.path.exists(self._file_path): - os.remove(self._file_path) - - @staticmethod - def is_process_running(pid: int): - """ - function to check process is running or not - """ - try: - # Check if the process exists and we can send signals to it - os.kill(pid, 0) - return True - except OSError: - return False - - def _watchdog_loop(self) -> None: - # Open the pipe in blocking mode blocks the server thread. - # This is fine for the following reasons: - # 1. No client case usually does not happen. - # 2. We are running the watchdog loop in a separate daemon - # thread, which will not block the process to stop. 
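Putting the two halves together, a minimal sketch of the file-based timer, compressed into one listing for brevity (in practice the server lives in the agent process and the client in each worker; the client opens the FIFO non-blocking, so the server must be started first)::

    import mindtorch.distributed.elastic.timer as timer

    fifo = "/tmp/watchdog_pipe"  # illustrative path

    # Agent side: creates the FIFO and runs the watchdog thread.
    server = timer.FileTimerServer(fifo, run_id="job-1", max_interval=5)
    server.start()

    # Worker side (same host): arm a 60-second timer around risky work.
    timer.configure(timer.FileTimerClient(fifo))
    with timer.expires(after=60):
        pass  # e.g. a collective that might hang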
- with open(self._file_path) as fd: - self._is_client_started = True - while not self._stop_signaled: - try: - run_once = self._run_once - self._run_watchdog(fd) - if run_once: - break - self._last_progress_time = int(time.time()) - except Exception: - logger.exception("Error running watchdog") - - def _run_watchdog(self, fd: io.TextIOWrapper) -> None: - timer_requests = self._get_requests(fd, self._max_interval) - self.register_timers(timer_requests) - now = time.time() - reaped_worker_pids = set() - - all_expired_timers = self.get_expired_timers(now) - log_debug_info_for_expired_timers( - self._run_id, - { - pid: [expired_timer.to_json() for expired_timer in expired_timers] - for pid, expired_timers in all_expired_timers.items() - }, - ) - - for worker_pid, expired_timers in all_expired_timers.items(): - logger.info( - "Reaping worker_pid=[%s]. Expired timers: %s", - worker_pid, - self._get_scopes(expired_timers), - ) - reaped_worker_pids.add(worker_pid) - # In case we have multiple expired timers, we find the first timer - # with a valid signal (>0) in the expiration time order. - expired_timers.sort(key=lambda timer: timer.expiration_time) - signal = 0 - expired_timer = None - for timer in expired_timers: - self._log_event("timer expired", timer) - if timer.signal > 0: - signal = timer.signal - expired_timer = timer - break - if signal <= 0: - logger.info( - "No signal specified with worker=[%s]. Do not reap it.", worker_pid - ) - continue - if self._reap_worker(worker_pid, signal): - logger.info( - "Successfully reaped worker=[%s] with signal=%s", worker_pid, signal - ) - self._log_event("kill worker process", expired_timer) - else: - logger.error( - "Error reaping worker=[%s]. Will retry on next watchdog.", - worker_pid, - ) - self.clear_timers(reaped_worker_pids) - - def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]: - return [r.scope_id for r in timer_requests] - - def _get_requests( - self, fd: io.TextIOWrapper, max_interval: float - ) -> List[FileTimerRequest]: - start = time.time() - requests = [] - while not self._stop_signaled or self._run_once: - # For named pipe, readline() is blocking when at least one writer opens. - # It returns only when flush() is called at the writer side. - # Note that flush() is automatically called inside close(). - # After the last writer closes, readline() is not blocking. - # It will return an empty string when it's at end-of-file. - # Since the client side always opens the pipe, writes a message and closes - # the pipe immediately, the readline() call below is not blocking for long. 
- json_request = fd.readline() - if len(json_request) == 0: - if self._run_once: - break - time.sleep(min(max_interval, 1)) - else: - request = json.loads(json_request) - pid = request["pid"] - scope_id = request["scope_id"] - expiration_time = request["expiration_time"] - signal = request["signal"] - requests.append( - FileTimerRequest( - worker_pid=pid, - scope_id=scope_id, - expiration_time=expiration_time, - signal=signal, - ) - ) - now = time.time() - if now - start > max_interval: - break - return requests - - def register_timers(self, timer_requests: List[FileTimerRequest]) -> None: - for request in timer_requests: - pid = request.worker_pid - scope_id = request.scope_id - expiration_time = request.expiration_time - self._request_count += 1 - - key = (pid, scope_id) - # negative expiration is a proxy for a release call - if expiration_time < 0: - if key in self._timers: - del self._timers[key] - else: - self._timers[key] = request - - def clear_timers(self, worker_pids: Set[int]) -> None: - for pid, scope_id in list(self._timers.keys()): - if pid in worker_pids or not FileTimerServer.is_process_running(pid): - del self._timers[(pid, scope_id)] - - def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]: - # pid -> [timer_requests...] - expired_timers: Dict[int, List[FileTimerRequest]] = {} - for request in self._timers.values(): - if request.expiration_time <= deadline: - expired_scopes = expired_timers.setdefault(request.worker_pid, []) - expired_scopes.append(request) - return expired_timers - - def _reap_worker(self, worker_pid: int, signal: int) -> bool: - try: - os.kill(worker_pid, signal) - return True - except ProcessLookupError: - logger.info("Process with pid=%s does not exist. Skipping", worker_pid) - return True - except Exception: - logger.exception("Error terminating pid=%s", worker_pid) - return False - - def get_last_progress_time(self) -> int: - return self._last_progress_time if self._is_client_started else int(time.time()) diff --git a/mindtorch/distributed/elastic/timer/local_timer.py b/mindtorch/distributed/elastic/timer/local_timer.py deleted file mode 100644 index d3562877a..000000000 --- a/mindtorch/distributed/elastic/timer/local_timer.py +++ /dev/null @@ -1,128 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import logging -import multiprocessing as mp -import os -import signal -import time -from queue import Empty -from typing import Any, Dict, List, Set, Tuple - -from .api import RequestQueue, TimerClient, TimerRequest, TimerServer - - -__all__ = ["LocalTimerClient", "MultiprocessingRequestQueue", "LocalTimerServer"] - -logger = logging.getLogger(__name__) - - -class LocalTimerClient(TimerClient): - """ - Client side of ``LocalTimerServer``. This client is meant to be used - on the same host that the ``LocalTimerServer`` is running on and uses - pid to uniquely identify a worker. This is particularly useful in situations - where one spawns a subprocess (trainer) per GPU on a host with multiple - GPU devices. 
- """ - - def __init__(self, mp_queue): - super().__init__() - self._mp_queue = mp_queue - - def acquire(self, scope_id, expiration_time): - pid = os.getpid() - acquire_request = TimerRequest(pid, scope_id, expiration_time) - self._mp_queue.put(acquire_request) - - def release(self, scope_id): - pid = os.getpid() - release_request = TimerRequest(pid, scope_id, -1) - self._mp_queue.put(release_request) - - -class MultiprocessingRequestQueue(RequestQueue): - """ - A ``RequestQueue`` backed by python ``multiprocessing.Queue`` - """ - - def __init__(self, mp_queue: mp.Queue): - super().__init__() - self._mp_queue = mp_queue - - def size(self) -> int: - return self._mp_queue.qsize() - - def get(self, size, timeout: float) -> List[TimerRequest]: - requests = [] - wait = timeout - for _ in range(0, size): - start = time.time() - - try: - r = self._mp_queue.get(block=True, timeout=wait) - except Empty: - break - - requests.append(r) - wait = wait - (time.time() - start) - if wait <= 0: - break - - return requests - - -class LocalTimerServer(TimerServer): - """ - Server that works with ``LocalTimerClient``. Clients are expected to be - subprocesses to the parent process that is running this server. Each host - in the job is expected to start its own timer server locally and each - server instance manages timers for local workers (running on processes - on the same host). - """ - - def __init__( - self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True - ): - super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon) - self._timers: Dict[Tuple[Any, str], TimerRequest] = {} - - def register_timers(self, timer_requests: List[TimerRequest]) -> None: - for request in timer_requests: - pid = request.worker_id - scope_id = request.scope_id - expiration_time = request.expiration_time - - # negative expiration is a proxy for a release call - if expiration_time < 0: - self._timers.pop((pid, scope_id), None) - else: - self._timers[(pid, scope_id)] = request - - def clear_timers(self, worker_ids: Set[int]) -> None: - for pid, scope_id in list(self._timers.keys()): - if pid in worker_ids: - self._timers.pop((pid, scope_id)) - - def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]: - # pid -> [timer_requests...] - expired_timers: Dict[Any, List[TimerRequest]] = {} - for request in self._timers.values(): - if request.expiration_time <= deadline: - expired_scopes = expired_timers.setdefault(request.worker_id, []) - expired_scopes.append(request) - return expired_timers - - def _reap_worker(self, worker_id: int) -> bool: - try: - os.kill(worker_id, signal.SIGKILL) - return True - except ProcessLookupError: - logger.info("Process with pid=%s does not exist. Skipping", worker_id) - return True - except Exception: - logger.exception("Error terminating pid=%s", worker_id) - return False diff --git a/mindtorch/distributed/elastic/utils/__init__.py b/mindtorch/distributed/elastic/utils/__init__.py deleted file mode 100644 index 5fbc76bf7..000000000 --- a/mindtorch/distributed/elastic/utils/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
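The deadline budgeting in ``MultiprocessingRequestQueue.get`` above, where every blocking ``get`` is charged against one shared timeout, is a reusable pattern; a standalone sketch with a hypothetical helper name::

    import time
    from queue import Empty, Queue

    def drain(q: Queue, size: int, timeout: float) -> list:
        """Fetch up to `size` items, spending at most `timeout` seconds in total."""
        items, wait = [], timeout
        for _ in range(size):
            start = time.time()
            try:
                items.append(q.get(block=True, timeout=wait))
            except Empty:
                break
            wait -= time.time() - start  # charge the elapsed time to the budget
            if wait <= 0:
                break
        return items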
- -from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401 diff --git a/mindtorch/distributed/elastic/utils/api.py b/mindtorch/distributed/elastic/utils/api.py deleted file mode 100644 index bff91438b..000000000 --- a/mindtorch/distributed/elastic/utils/api.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import socket -from string import Template -from typing import Any, List - - -def get_env_variable_or_raise(env_name: str) -> str: - r""" - Tries to retrieve environment variable. Raises ``ValueError`` - if no environment variable found. - - Args: - env_name (str): Name of the env variable - """ - value = os.environ.get(env_name, None) - if value is None: - msg = f"Environment variable {env_name} expected, but not set" - raise ValueError(msg) - return value - - -def get_socket_with_port() -> socket.socket: - addrs = socket.getaddrinfo( - host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM - ) - for addr in addrs: - family, type, proto, _, _ = addr - s = socket.socket(family, type, proto) - try: - s.bind(("localhost", 0)) - s.listen(0) - return s - except OSError: - s.close() - raise RuntimeError("Failed to create a socket") - - -class macros: - """ - Defines simple macros for caffe2.distributed.launch cmd args substitution - """ - - local_rank = "${local_rank}" - - @staticmethod - def substitute(args: List[Any], local_rank: str) -> List[str]: - args_sub = [] - for arg in args: - if isinstance(arg, str): - sub = Template(arg).safe_substitute(local_rank=local_rank) - args_sub.append(sub) - else: - args_sub.append(arg) - return args_sub diff --git a/mindtorch/distributed/elastic/utils/distributed.py b/mindtorch/distributed/elastic/utils/distributed.py deleted file mode 100644 index 00437ff83..000000000 --- a/mindtorch/distributed/elastic/utils/distributed.py +++ /dev/null @@ -1,184 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import datetime -import os -import socket -from contextlib import closing -from typing import Optional - -import mindtorch.distributed as dist -from mindtorch.distributed.elastic.utils.logging import get_logger -from mindtorch.distributed.elastic.utils.store import barrier - - -__all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"] - -logger = get_logger(__name__) - -_ADDRESS_IN_USE = "Address already in use" -_SOCKET_TIMEOUT = "Socket Timeout" - -_TCP_STORE_INIT = "_tcp_store/num_members" - - -def create_c10d_store( - is_server: bool, - server_addr: str, - server_port: int = -1, - world_size: int = 1, - timeout: float = (60 * 10), # 10 min - wait_for_workers: bool = True, - retries=3, - use_libuv: Optional[bool] = None, -): - if use_libuv is not None: - logger.warning( - "argument use_libuv is deprecated and ignored. Set USE_LIBUV environment " - 'variable to "0" to disable libuv, or "1" to enable it. If the env var ' - "is not set, libuv will be used by default." 
- ) - - # check os.environ for use_libuv - use_libuv = os.environ.get("USE_LIBUV", "1") == "1" # libuv is the default option - - if server_port == -1 and world_size > 1: - raise ValueError( - f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}" - ) - - if server_port != -1: - logger.info("sever_port: %s, specified, ignoring retries", server_port) - - # only retry when server_port is NOT static - attempt = retries if server_port == -1 else 1 - while True: - if server_port != -1: - port = server_port - else: - port = get_free_port() - - logger.info( - "Creating c10d store on %s:%s\n" - " world_size : %s\n" - " is_server : %s\n" - " timeout(sec): %s\n" - " use_libuv : %s\n", - server_addr, - port, - world_size, - is_server, - timeout, - use_libuv, - ) - - try: - store = dist.TCPStore( - host_name=server_addr, - port=port, - world_size=world_size, - is_master=is_server, - timeout=datetime.timedelta(seconds=timeout), - wait_for_workers=wait_for_workers, - use_libuv=use_libuv, - ) - # skips full rank check when we don't have to wait for all workers - if wait_for_workers: - _check_full_rank(store, world_size, timeout=timeout) - logger.info("Successfully created c10d store") - return store - except RuntimeError as e: - # this is brittle, but the underlying exception type is not properly pybinded - # so we parse the error msg for now, interestingly this is how torch itself - # detects timeouts and port conflicts in their own unittests - # see - caffe2/torch/testing/_internal/common_utils.py - # TODO properly map the exceptions in pybind (c10d/init.cpp) - if str(e) == _ADDRESS_IN_USE: # this will only happen on the server - if attempt < retries: - logger.warning( - "port: %s already in use, attempt: [%s/%s]", - port, - attempt, - retries, - ) - attempt += 1 - else: - raise RuntimeError( - f"on {server_addr}, port: {port} already in use" - ) from e - else: - raise - - -def _check_full_rank(store, world_size, timeout): - try: - barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout) - except RuntimeError as e: - if str(e) == _SOCKET_TIMEOUT: - raise TimeoutError( - f"timed out waiting for all {world_size} members to join" - ) from e - else: - raise - - -def get_free_port(): - """ - Returns an unused port on localhost. - - This function finds an unused port on localhost by opening to socket to bind - to a port and then closing it. - - Returns: - int: an unused port on localhost - - Example: - >>> # xdoctest: +SKIP("Nondeterministic") - >>> get_free_port() - 63976 - - ..note: - The port returned by :func:`get_free_port` is not reserved and may be - taken by another process after this function returns. - """ - sock = get_socket_with_port() - with closing(sock): - return sock.getsockname()[1] - - -def get_socket_with_port() -> socket.socket: - """ - Returns a free port on localhost that is "reserved" by binding a temporary - socket on it. Close the socket before passing the port to the entity - that requires it. 
Usage example - - :: - - sock = _get_socket_with_port() - with closing(sock): - port = sock.getsockname()[1] - sock.close() - # there is still a race-condition that some other process - # may grab this port before func() runs - func(port) - """ - - addrs = socket.getaddrinfo( - host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM - ) - for addr in addrs: - family, type, proto, _, _ = addr - s = socket.socket(family, type, proto) - try: - s.bind(("localhost", 0)) - s.listen(0) - return s - except OSError as e: - s.close() - logger.warning("Socket creation attempt failed.", exc_info=e) - raise RuntimeError("Failed to create a socket") diff --git a/mindtorch/distributed/elastic/utils/log_level.py b/mindtorch/distributed/elastic/utils/log_level.py deleted file mode 100644 index ace6e1a33..000000000 --- a/mindtorch/distributed/elastic/utils/log_level.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -def get_log_level() -> str: - """ - Return default log level for pymindtorch. - """ - return "WARNING" diff --git a/mindtorch/distributed/elastic/utils/logging.py b/mindtorch/distributed/elastic/utils/logging.py deleted file mode 100644 index 90b450f2e..000000000 --- a/mindtorch/distributed/elastic/utils/logging.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import inspect -import logging -import os -import warnings -from typing import Optional - -from mindtorch.distributed.elastic.utils.log_level import get_log_level - - -def get_logger(name: Optional[str] = None): - """ - Util function to set up a simple logger that writes - into stderr. The loglevel is fetched from the LOGLEVEL - env. variable or WARNING as default. The function will use the - module name of the caller if no name is provided. - - Args: - name: Name of the logger. If no name provided, the name will - be derived from the call stack. - """ - - # Derive the name of the caller, if none provided - # Use depth=2 since this function takes up one level in the call stack - return _setup_logger(name or _derive_module_name(depth=2)) - - -def _setup_logger(name: Optional[str] = None): - logger = logging.getLogger(name) - logger.setLevel(os.environ.get("LOGLEVEL", get_log_level())) - return logger - - -def _derive_module_name(depth: int = 1) -> Optional[str]: - """ - Derives the name of the caller module from the stack frames. - - Args: - depth: The position of the frame in the stack. - """ - try: - stack = inspect.stack() - assert depth < len(stack) - # FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index) - frame_info = stack[depth] - - module = inspect.getmodule(frame_info[0]) - if module: - module_name = module.__name__ - else: - # inspect.getmodule(frame_info[0]) does NOT work (returns None) in - # binaries built with @mode/opt - # return the filename (minus the .py extension) as modulename - filename = frame_info[1] - module_name = os.path.splitext(os.path.basename(filename))[0] - return module_name - except Exception as e: - warnings.warn( - f"Error deriving logger module name, using . 
Exception: {e}", - RuntimeWarning, - ) - return None diff --git a/mindtorch/distributed/elastic/utils/store.py b/mindtorch/distributed/elastic/utils/store.py deleted file mode 100644 index a3a9524c5..000000000 --- a/mindtorch/distributed/elastic/utils/store.py +++ /dev/null @@ -1,225 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from contextlib import contextmanager -from datetime import timedelta -from typing import Callable, Iterable, List, Optional - -import mindtorch - - -DistStoreError = mindtorch._C._DistStoreError - -_NUM_MEMBERS = "/num_members" -_LAST_MEMBER_CHECKIN = "/last_member" -_TRACE = "/TRACE" -_TRACING_GATE = "/TRACING_GATE" -_MAX_TRACE_MISSING_RANKS = 16 - - -__all__ = ["store_timeout", "get_all", "synchronize", "barrier"] - - -@contextmanager -def store_timeout(store, timeout: float): - """ - This sets the timeout and then restores the old timeout when the context - manager exits. - - Args: - store: the store to set the timeout on - timeout: the timeout to set - """ - - old_timeout = store.timeout - store.set_timeout(timedelta(seconds=timeout)) - yield - store.set_timeout(old_timeout) - - -def get_all(store, rank: int, prefix: str, world_size: int): - r""" - Given a store and a prefix, the method goes through the array of keys - of the following format: ``{prefix}{idx}``, where idx is in a range - from 0 to size, and tries to retrieve the data. - - The Rank0 process waits at the end to make sure all other processes - finished the procedure before exiting. - - Usage - - :: - - values = get_all(store, 'torchelastic/data', 3) - value1 = values[0] # retrieves the data for key torchelastic/data0 - value2 = values[1] # retrieves the data for key torchelastic/data1 - value3 = values[2] # retrieves the data for key torchelastic/data2 - - """ - data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)]) - - barrier_key = _barrier_nonblocking( - store=store, - world_size=world_size, - key_prefix=f"{prefix}/finished", - ) - if rank == 0: - # Rank0 runs the TCPStore daemon, as a result it needs to exit last. - # Otherwise, the barrier may timeout if rank0 process finished the work - # before other processes finished `get_all` method - store.wait([barrier_key]) - - return data_arr - - -def synchronize( - store, - data: bytes, - rank: int, - world_size: int, - key_prefix: str, - timeout: float = 300, -) -> List[bytes]: - """ - Synchronizes ``world_size`` agents between each other using the underlying c10d store. - The ``data`` will be available on each of the agents. - - Note: The data on the path is not deleted, as a result there can be stale data if - you use the same key_prefix twice. - - Time complexity: O(N) per worker, O(N^2) globally. 
- """ - with store_timeout(store, timeout): - store.set(f"{key_prefix}{rank}", data) - agent_data = get_all(store, rank, key_prefix, world_size) - return agent_data - - -def _try_detecting_missing_ranks( - store, - world_size: int, - key_prefix: str, - rank: int, - rank_decoder: Callable[[int], str], - trace_timeout: float, -) -> Optional[Iterable[str]]: - store.set(f"{key_prefix}{rank}{_TRACE}", "") - - def _find_missing_ranks(): - missing_rank_info = set() - ranks_missing = 0 - for i in range(1, world_size): - # reduce noise, assuming in general 8 ranks per node - # It is valuable to know that 1 or >1 nodes have timed-out. - if ranks_missing >= _MAX_TRACE_MISSING_RANKS: - break - try: - if ranks_missing == 0: - store.wait( - [f"{key_prefix}{i}{_TRACE}"], timedelta(seconds=trace_timeout) - ) - else: - # use a shortest timeout, some ranks have failed to check-in - store.wait([f"{key_prefix}{i}{_TRACE}"], timedelta(milliseconds=1)) - except DistStoreError: - ranks_missing += 1 - missing_rank_info.add(rank_decoder(i)) - return missing_rank_info - - def _checkin(): - try: - store.wait([f"{key_prefix}{_TRACING_GATE}"]) - return [f"[]"] - except DistStoreError: - # in case rank0 is the source of the timeout, original exception will be raised - return None - - if rank == 0: - missing_rank_info = _find_missing_ranks() - store.set(f"{key_prefix}{_TRACING_GATE}", "") - return missing_rank_info - else: - return _checkin() - - -def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str: - """ - Does all the non-blocking operations for a barrier and returns the final key - that can be waited on. - """ - num_members_key = key_prefix + _NUM_MEMBERS - last_member_key = key_prefix + _LAST_MEMBER_CHECKIN - - idx = store.add(num_members_key, 1) - if idx == world_size: - store.set(last_member_key, "") - - return last_member_key - - -def barrier( - store, - world_size: int, - key_prefix: str, - barrier_timeout: float = 300, - rank: Optional[int] = None, - rank_tracing_decoder: Optional[Callable[[int], str]] = None, - trace_timeout: float = 10, -) -> None: - """ - A global lock between agents. This will pause all workers until at least - ``world_size`` workers respond. - - This uses a fast incrementing index to assign waiting ranks and a success - flag set by the last worker. - - Time complexity: O(1) per worker, O(N) globally. - - Optionally, passing rank will enable tracing of missing ranks on timeouts. - `rank_tracing_decoder` lambda arg can be used to convert rank data - into a more meaninful information at an app level (e.g. hostname). - - Note: Since the data is not removed from the store, the barrier can be used - once per unique ``key_prefix``. 
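A minimal two-rank sketch of ``synchronize`` and ``barrier`` as documented above, assuming the ``TCPStore`` added elsewhere in this patch and a ``RANK`` variable set by the launcher (note these helpers also rely on ``multi_get``, ``wait`` and ``set_timeout``, which a full c10d store provides)::

    import os
    from datetime import timedelta

    import mindtorch.distributed as dist
    from mindtorch.distributed.elastic.utils.store import barrier, synchronize

    rank = int(os.environ["RANK"])
    store = dist.TCPStore("127.0.0.1", 29500, world_size=2, is_master=(rank == 0),
                          timeout=timedelta(seconds=300))

    # Every rank contributes one payload and receives all of them, in rank order.
    payloads = synchronize(store, f"hello from {rank}".encode(), rank, 2, "init/")

    # One-shot rendezvous; reuse needs a fresh key_prefix because keys persist.
    barrier(store, 2, key_prefix="epoch0/", rank=rank)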
- """ - - if rank is None: - assert rank_tracing_decoder is None, "Tracing requires rank information" - - with store_timeout(store, barrier_timeout): - last_member_key = _barrier_nonblocking( - store=store, world_size=world_size, key_prefix=key_prefix - ) - try: - store.wait([last_member_key]) - except DistStoreError as e: - if rank is None: - raise e - else: - missing_ranks = _try_detecting_missing_ranks( - store, - world_size, - key_prefix, - rank, - rank_tracing_decoder or (lambda x: str(x)), - trace_timeout, - ) - if missing_ranks is not None: - raise DistStoreError( - "Timed out waiting on barrier on " - "rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format( - rank, - key_prefix, - world_size, - f"[{', '.join(missing_ranks)}]", - barrier_timeout, - ) - ) from None - else: - raise e diff --git a/mindtorch/distributed/launcher/__init__.py b/mindtorch/distributed/launcher/__init__.py deleted file mode 100644 index fd86a42be..000000000 --- a/mindtorch/distributed/launcher/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env/python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from mindtorch.distributed.launcher.api import ( # noqa: F401 - elastic_launch, - launch_agent, - LaunchConfig, -) diff --git a/mindtorch/distributed/launcher/api.py b/mindtorch/distributed/launcher/api.py deleted file mode 100644 index 849cbb4d4..000000000 --- a/mindtorch/distributed/launcher/api.py +++ /dev/null @@ -1,289 +0,0 @@ -#!/usr/bin/env python3 -# mypy: allow-untyped-defs - -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import sys -import uuid -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import mindtorch.distributed.elastic.rendezvous.registry as rdzv_registry -from mindtorch.distributed.elastic import events, metrics -from mindtorch.distributed.elastic.agent.server.api import WorkerSpec -from mindtorch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent -from mindtorch.distributed.elastic.multiprocessing import ( - DefaultLogsSpecs, - LogsSpecs, - SignalException, -) -from mindtorch.distributed.elastic.multiprocessing.errors import ChildFailedError -from mindtorch.distributed.elastic.rendezvous import RendezvousParameters -from mindtorch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint -from mindtorch.distributed.elastic.utils.logging import get_logger - - -__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] - -logger = get_logger(__name__) - - -@dataclass -class LaunchConfig: - """ - Creates a rendezvous config. - - Args: - min_nodes: Minimum amount of nodes that the user function will - be launched on. Elastic agent ensures that the user - function start only when the min_nodes amount enters - the rendezvous. - max_nodes: Maximum amount of nodes that the user function - will be launched on. - nproc_per_node: On each node the elastic agent will launch - this amount of workers that will execute user - defined function. - rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd). - rdzv_endpoint: The endpoint of the rdzv sync. storage. 
- rdzv_configs: Key, value pair that specifies rendezvous specific configuration. - rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going - to be removed in future versions, see the note below. The default timeout is 900 seconds. - run_id: The unique run id of the job (if not passed a unique one will be - deduced from run environment - flow workflow id in flow - or auto generated). - role: User defined role of the worker (defaults to "trainer"). - max_restarts: The maximum amount of restarts that elastic agent will conduct - on workers before failure. - monitor_interval: The interval in seconds that is used by the elastic_agent - as a period of monitoring workers. - start_method: The method is used by the elastic agent to start the - workers (spawn, fork, forkserver). - metrics_cfg: configuration to initialize metrics. - local_addr: address of the local node if any. If not set, a lookup on the local - machine's FQDN will be performed. - local_ranks_filter: ranks for which to show logs in console. If not set, show from all. - ..note: - `rdzv_timeout` is a legacy argument that will be removed in future. - Set the timeout via `rdzv_configs['timeout']` - - """ - - min_nodes: int - max_nodes: int - nproc_per_node: int - logs_specs: Optional[LogsSpecs] = None - run_id: str = "" - role: str = "default_role" - rdzv_endpoint: str = "" - rdzv_backend: str = "etcd" - rdzv_configs: Dict[str, Any] = field(default_factory=dict) - rdzv_timeout: int = -1 - max_restarts: int = 3 - monitor_interval: float = 0.1 - start_method: str = "spawn" - log_line_prefix_template: Optional[str] = None - metrics_cfg: Dict[str, str] = field(default_factory=dict) - local_addr: Optional[str] = None - - def __post_init__(self): - default_timeout = 900 - if self.rdzv_timeout != -1: - self.rdzv_configs["timeout"] = self.rdzv_timeout - elif "timeout" not in self.rdzv_configs: - self.rdzv_configs["timeout"] = default_timeout - - # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage - if self.logs_specs is None: - self.logs_specs = DefaultLogsSpecs() - - -class elastic_launch: - """ - Launches an torchelastic agent on the container that invoked the entrypoint. - - 1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/ - ``entrypoint`` can be a function or a command. - 2. The return value is a map of each worker's output mapped - by their respective global rank. - - Usage - - :: - - def worker_fn(foo): - # ... - - def main(): - # entrypoint is a function. - outputs = elastic_launch(LaunchConfig, worker_fn)(foo) - # return rank 0's output - return outputs[0] - - # entrypoint is a command and ``script.py`` is the python module. - outputs = elastic_launch(LaunchConfig, "script.py")(args) - outputs = elastic_launch(LaunchConfig, "python")("script.py") - """ - - def __init__( - self, - config: LaunchConfig, - entrypoint: Union[Callable, str, None], - ): - self._config = config - self._entrypoint = entrypoint - - def __call__(self, *args): - return launch_agent(self._config, self._entrypoint, list(args)) - - -def _get_entrypoint_name( - entrypoint: Union[Callable, str, None], args: List[Any] -) -> str: - """Retrieve entrypoint name with the rule: - 1. If entrypoint is a function, use ``entrypoint.__qualname__``. - 2. If entrypoint is a string, check its value: - 2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args`` - which does not start with hifen letter (for example, "-u" will be skipped). 
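For reference, a concrete configuration matching the ``LaunchConfig`` fields documented above, with illustrative values and the ``etcd`` backend named in the docstring::

    from mindtorch.distributed.launcher import elastic_launch, LaunchConfig

    def trainer(batch_size):
        return batch_size * 2  # stand-in for a real training entrypoint

    config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=4,
        run_id="demo-job",
        rdzv_backend="etcd",
        rdzv_endpoint="localhost:2379",
        rdzv_configs={"timeout": 300},
    )
    outputs = elastic_launch(config, trainer)(32)  # {global_rank: return_value}
    print(outputs[0])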
- 2.2 otherwise, use ``entrypoint`` value. - 3. Otherwise, return empty string. - """ - if isinstance(entrypoint, Callable): # type: ignore[arg-type] - return entrypoint.__name__ # type: ignore[union-attr] - elif isinstance(entrypoint, str): - if entrypoint == sys.executable: - return next((arg for arg in args if arg[0] != "-"), "") - else: - return entrypoint - else: - return "" - - -def _get_addr_and_port( - rdzv_parameters: RendezvousParameters, -) -> Tuple[Optional[str], Optional[int]]: - if rdzv_parameters.backend != "static": - return (None, None) - endpoint = rdzv_parameters.endpoint - endpoint = endpoint.strip() - if not endpoint: - raise ValueError( - "Endpoint is missing in endpoint. Try to add --master-addr and --master-port" - ) - master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1) - if master_port == -1: - raise ValueError( - f"port is missing in endpoint: {endpoint}. Try to specify --master-port" - ) - return (master_addr, master_port) - - -def launch_agent( - config: LaunchConfig, - entrypoint: Union[Callable, str, None], - args: List[Any], -) -> Dict[int, Any]: - if not config.run_id: - run_id = str(uuid.uuid4().int) - logger.warning("config has no run_id, generated a random run_id: %s", run_id) - config.run_id = run_id - - entrypoint_name = _get_entrypoint_name(entrypoint, args) - - logger.info( - "Starting elastic_operator with launch configs:\n" - " entrypoint : %(entrypoint)s\n" - " min_nodes : %(min_nodes)s\n" - " max_nodes : %(max_nodes)s\n" - " nproc_per_node : %(nproc_per_node)s\n" - " run_id : %(run_id)s\n" - " rdzv_backend : %(rdzv_backend)s\n" - " rdzv_endpoint : %(rdzv_endpoint)s\n" - " rdzv_configs : %(rdzv_configs)s\n" - " max_restarts : %(max_restarts)s\n" - " monitor_interval : %(monitor_interval)s\n" - " log_dir : %(log_dir)s\n" - " metrics_cfg : %(metrics_cfg)s\n", - { - "entrypoint": entrypoint_name, - "min_nodes": config.min_nodes, - "max_nodes": config.max_nodes, - "nproc_per_node": config.nproc_per_node, - "run_id": config.run_id, - "rdzv_backend": config.rdzv_backend, - "rdzv_endpoint": config.rdzv_endpoint, - "rdzv_configs": config.rdzv_configs, - "max_restarts": config.max_restarts, - "monitor_interval": config.monitor_interval, - "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] - "metrics_cfg": config.metrics_cfg, - }, - ) - - rdzv_parameters = RendezvousParameters( - backend=config.rdzv_backend, - endpoint=config.rdzv_endpoint, - run_id=config.run_id, - min_nodes=config.min_nodes, - max_nodes=config.max_nodes, - local_addr=config.local_addr, - **config.rdzv_configs, - ) - - master_addr, master_port = _get_addr_and_port(rdzv_parameters) - - spec = WorkerSpec( - role=config.role, - local_world_size=config.nproc_per_node, - entrypoint=entrypoint, - args=tuple(args), - rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters), - max_restarts=config.max_restarts, - monitor_interval=config.monitor_interval, - master_addr=master_addr, - master_port=master_port, - local_addr=config.local_addr, - ) - - agent = LocalElasticAgent( - spec=spec, - logs_specs=config.logs_specs, # type: ignore[arg-type] - start_method=config.start_method, - log_line_prefix_template=config.log_line_prefix_template, - ) - - shutdown_rdzv = True - try: - metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg)) - - result = agent.run() - # records that agent.run() has succeeded NOT that workers have succeeded - events.record(agent.get_event_succeeded()) - - if result.is_failed(): - # ChildFailedError is treated 
specially by @record - # if the error files for the failed children exist - # @record will copy the first error (root cause) - # to the error file of the launcher process. - raise ChildFailedError( - name=entrypoint_name, - failures=result.failures, - ) - - return result.return_values - except ChildFailedError: - raise - except SignalException: - # when the agent dies with a signal do NOT shutdown the rdzv_handler - # since this closes the rendezvous on this rdzv_id permanently and - # prevents any additional scaling events - shutdown_rdzv = False - events.record(agent.get_event_failed()) - raise - except Exception: - events.record(agent.get_event_failed()) - raise - finally: - if shutdown_rdzv: - spec.rdzv_handler.shutdown() diff --git a/mindtorch/distributed/process_entity/__init__.py b/mindtorch/distributed/process_entity/__init__.py new file mode 100644 index 000000000..ffab7e487 --- /dev/null +++ b/mindtorch/distributed/process_entity/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Interfaces for ms_run""" +from ._api import _Node, _MetaServerNode, _ComputeGraphNode, _ProcessManager + +from ._utils import _generate_cmd, _generate_url, _is_local_ip, _send_scale_num diff --git a/mindtorch/distributed/process_entity/_api.py b/mindtorch/distributed/process_entity/_api.py new file mode 100644 index 000000000..a168271b6 --- /dev/null +++ b/mindtorch/distributed/process_entity/_api.py @@ -0,0 +1,581 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""API for ms_run""" +import os +import re +import sys +import signal +import subprocess +import socket +import psutil +import mindspore.log as logger +from ._utils import _generate_cmd_args_list, _generate_cmd_args_list_with_core, _generate_url, \ + _is_local_ip, _convert_addr_to_ip, _send_scale_num, _get_local_ip + + +class _Node: + """ + Base class for dynamic networking nodes. 
+
+    """
+
+    def __init__(self, worker_num, sched_host, sched_port, timeout, args_list, output_file, tail_worker_log,
+                 join, is_simulation):
+        self.worker_num = worker_num
+        self.sched_host = sched_host
+        self.sched_port = sched_port
+        self.args_list = args_list
+        self.output_file = output_file
+        self.timeout = timeout
+        self.tail_worker_log = tail_worker_log
+        self.join = join
+        self.is_simulation = is_simulation
+
+    def run(self):
+        """
+        Runs the node by setting environment variables before the subclasses execute
+        the entrypoint command or script.
+        """
+        os.environ["MS_WORKER_NUM"] = str(self.worker_num)
+        # If simulation level is set, environment variables for dynamic networking will not be set,
+        # and the scheduler will not be started.
+        if not self.is_simulation:
+            os.environ["MS_SCHED_HOST"] = self.sched_host
+            os.environ["MS_SCHED_PORT"] = str(self.sched_port)
+            os.environ["MS_TOPO_TIMEOUT"] = str(self.timeout)
+
+
+class _MetaServerNode(_Node):
+    """
+    Scheduler node for dynamic networking. Inherits from the ``_Node`` class.
+    """
+
+    def run(self):
+        """
+        Runs the MetaServerNode: sets the common environment variables, sets MS_ROLE to
+        "MS_SCHED", and executes the entrypoint command or script.
+        """
+        super().run()
+        os.environ["MS_ROLE"] = "MS_SCHED"
+        with open(self.output_file, "w") as file_handle:
+            return subprocess.Popen(self.args_list, stdout=file_handle, stderr=subprocess.STDOUT)
+
+
+class _ComputeGraphNode(_Node):
+    """
+    Worker node for dynamic networking. Inherits from the ``_Node`` class.
+    """
+
+    def __init__(self, worker_num, sched_host, sched_port, timeout, node_id, args_list, output_file,
+                 tail_worker_log, join, is_simulation):
+        super().__init__(worker_num, sched_host, sched_port, timeout, args_list, output_file,
+                         tail_worker_log, join, is_simulation)
+        self.node_id = node_id
+
+    def run(self):
+        """
+        Runs the ComputeGraphNode: sets the common environment variables, sets MS_NODE_ID to
+        the node ID, sets MS_ROLE to "MS_WORKER", and executes the entrypoint command or script.
+        """
+        super().run()
+        if self.node_id is not None:
+            os.environ["MS_NODE_ID"] = str(self.node_id)
+        # If simulation level is set, environment variable 'MS_ROLE' will not be set.
+        if not self.is_simulation:
+            os.environ["MS_ROLE"] = "MS_WORKER"
+        tail_worker_process = None
+        is_tail_worker_log = self.enable_tail_worker_log()
+        if self.join and not is_tail_worker_log:
+            logger.warning(f"'--tail_worker_log' is {self.tail_worker_log}, which does not include "
+                           f"this worker {self.node_id}, so its log will not be output to the console. "
+                           "Reset '--tail_worker_log' if you want this worker's log on the console.")
+        with open(self.output_file, "w") as file_handle:
+            worker_process = subprocess.Popen(self.args_list, preexec_fn=os.setsid, stdout=file_handle,
+                                              stderr=subprocess.STDOUT)
+            if self.join and is_tail_worker_log:
+                tail_worker_process = self.output_to_console()
+        return worker_process, tail_worker_process
+
+    def output_to_console(self):
+        """
+        Tail the worker log file to the console. An absolute path is used so the
+        call does not depend on the caller's PATH.
+        """
+        return subprocess.Popen(['/usr/bin/tail', '-f', self.output_file])
+
+    def enable_tail_worker_log(self):
+        """
+        Returns True if this worker's log should be tailed to the console.
+        '-1' means all workers; otherwise only the listed worker ids qualify.
+        """
+        if self.tail_worker_log == "-1":
+            return True
+        tail_worker_log_list = [int(num) for num in self.tail_worker_log.split(',')]
+        return self.node_id in tail_worker_log_list
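+
+# The environment a spawned process inherits from the node classes above, for an
+# illustrative eight-worker job (the values are examples, not defaults):
+#
+#   MS_WORKER_NUM=8           # total workers in the cluster
+#   MS_SCHED_HOST=10.0.0.1    # scheduler address
+#   MS_SCHED_PORT=8118        # scheduler port
+#   MS_TOPO_TIMEOUT=600       # cluster build timeout
+#   MS_ROLE=MS_WORKER         # MS_SCHED for the _MetaServerNode process
+#   MS_NODE_ID=3              # only exported when a rank id is known up front
+#
+# In simulation mode only MS_WORKER_NUM (and MS_NODE_ID, when assigned) are exported.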
+
+
+class _ProcessManager:
+    """
+    Manages the local dynamic networking process. Responsible for dynamic
+    networking and elastic training.
+    """
+
+    def __init__(self, args):
+        """
+        Initializes a _ProcessManager object.
+
+        Args:
+            args: An object containing the command-line arguments.
+        """
+        self.msn_process = None
+        self.cgn_processes = []
+        self.tail_cgn_processes = []
+
+        self.master_addr = _convert_addr_to_ip(args.master_addr)
+        self.master_port = args.master_port
+
+        # 'is_master' flags whether the current node is the master node.
+        self.is_master = _is_local_ip(self.master_addr)
+
+        self.worker_num = args.nproc_per_node * args.nnodes
+        if self.worker_num <= 0:
+            raise ValueError(f"worker_num must be greater than 0, but got {self.worker_num}.")
+        self.exported_rank_size = self.worker_num
+        self.local_worker_num = args.nproc_per_node
+        self.node_rank = args.node_rank
+
+        self.log_dir = args.log_dir
+        self.join = args.join
+        self.worker_log_name = args.worker_log_name
+        self.tail_worker_log = args.tail_worker_log
+        self.cluster_time_out = args.cluster_time_out
+        self.bind_core = args.bind_core
+        self.rank_table_file = args.rank_table_file
+
+        self.sim_level = args.sim_level
+        self.sim_rank_id = args.sim_rank_id
+        self.is_simulation = (self.sim_level != -1)
+        if self.is_simulation:
+            os.environ["MS_SIMULATION_LEVEL"] = str(self.sim_level)
+        elif os.getenv("MS_SIMULATION_LEVEL"):
+            self.is_simulation = True
+            self.sim_rank_id = int(os.getenv("RANK_ID", "-1"))
+            if os.getenv("RANK_SIZE"):
+                self.exported_rank_size = os.getenv("RANK_SIZE")
+        # If sim_rank_id is set, a single worker can be started.
+        if self.is_simulation and (self.sim_rank_id != -1):
+            logger.info(f"Simulation rank id is set to {self.sim_rank_id}, will dryrun a single process.")
+            self.local_worker_num = 1
+        if self.is_simulation and self.local_worker_num > 128:
+            self.local_worker_num = 1
+            self.sim_rank_id = 0
+            logger.warning("In the dryrun case, local worker num was set larger than 128. "
+                           "To avoid exhausting the system, local worker num is reset to 1.")
+
+        self.cmd = args.task_script
+        self.cmd_args = args.task_script_args
+
+        # 'is_scale' flags whether the current task is a scaling task and there is
+        # already a manager on the current node; 'scale_num' is only meaningful then.
+        self.is_scale = False
+        self.scale_num = 0
+        self.scheduler_url = _generate_url(self.master_addr, self.master_port)
+
+        # Create the log directory with owner-only (0700) permissions if it does not exist.
+        if self.log_dir and not os.path.exists(self.log_dir):
+            permissions = os.R_OK | os.W_OK | os.X_OK
+            origin_mask = os.umask(permissions << 3 | permissions)
+            try:
+                mode = permissions << 6
+                os.makedirs(self.log_dir, mode=mode, exist_ok=True)
+            finally:
+                os.umask(origin_mask)
+
+        self.proc_rank_map = {}
+        self.enable_mindx = False
+        tft_env = os.getenv("MS_ENABLE_TFT", "")
+        if ("TTP:1" in tft_env) or ("UCE:1" in tft_env) or ("ARF:1" in tft_env):
+            try:
+                from taskd.python.framework.agent.ms_mgr.msrun_plugin import MSRunPlugin
+                self.msmgr = MSRunPlugin()
+                self.msmgr.register_callbacks("KILL_WORKER", self.kill_workers)
+                self.msmgr.register_callbacks("START_ALL_WORKER", self.start_all_workers)
+                self.msmgr.register_callbacks("MONITOR", self.monitor_rank_status)
+                self.enable_mindx = True
+                os.environ["MS_ENABLE_RECOVERY"] = str(1)
+            except Exception as e:  # pylint: disable=broad-except
+                logger.warning("mindx is not available, falling back to the original "
+                               f"MindSpore recovery strategy: {str(e)}")
+
+    def run(self):
+        """
+        Runs the process manager.
+
+        """
+        os.environ["RANK_SIZE"] = str(self.exported_rank_size)
+        if self.rank_table_file != "":
+            os.environ["RANK_TABLE_FILE"] = self.rank_table_file
+            logger.warning(f"msrun launching distributed job with user configured rank table file path: "
+                           f"{self.rank_table_file}")
+        if self.is_scale:
+            response_message = _send_scale_num(self.scheduler_url, self.scale_num)
+            is_first_manager = response_message
+            if is_first_manager:
+                self.local_worker_num = 0
+            else:
+                sys.exit()
+        else:
+            if self.is_master and not self.is_simulation:
+                self.start_scheduler()
+            if self.enable_mindx:
+                self.msmgr.start()
+            else:
+                self.start_workers()
+                if self.join:
+                    logger.warning("Distributed job is spawned. Waiting for all processes to exit...")
+                    self.join_processes()
+
+    def start_scheduler(self):
+        """
+        Starts the scheduler node.
+
+        """
+        # For Scheduler, 'RANK_ID' is always 0.
+        os.environ['RANK_ID'] = str(0)
+        os.environ['RANK'] = str(0)
+        msn = _MetaServerNode(self.worker_num, self.master_addr, self.master_port, self.cluster_time_out,
+                              _generate_cmd_args_list(self.cmd, self.cmd_args),
+                              os.path.join(self.log_dir, "scheduler.log"), self.tail_worker_log, self.join,
+                              self.is_simulation)
+        self.msn_process = msn.run()
+
+    def start_workers(self):
+        """
+        Starts the worker nodes.
+
+        """
+        if self.local_worker_num == self.worker_num and self.node_rank not in [0, -1]:
+            # If only one node is involved, ignore invalid 'node_rank'.
+            logger.warning("All workers will be spawned on this node, "
+                           f"so 'node_rank': [{self.node_rank}] will be ignored.")
+        if self.local_worker_num < self.worker_num and self.node_rank == -1:
+            logger.warning("You are running a distributed job with multiple nodes but have not set '--node_rank'. So "
+                           "'rank_id' of each process will be assigned after the cluster is successfully built.\n"
+                           "You can access the 'RANK_ID' environment variable after calling "
+                           "'mindspore.communication.init()'")
+
+        for i in range(self.local_worker_num):
+            os.environ["DEVICE_ID"] = str(i)
+            os.environ["LOCAL_RANK"] = str(i)
+            node_id, log_name = self._get_node_id_and_log_path(i)
+            if node_id is None:
+                logger.warning("Rank ids will be assigned automatically; "
+                               "please use the \"grep -rn 'rank id:'\" command to check each worker log's rank id.")
+            else:
+                # If node_id is generated in the '_get_node_id_and_log_path' method, export the 'RANK_ID' environment
+                # variable. This is for compatibility with the rank_table method.
+                os.environ["RANK_ID"] = str(node_id)
+                os.environ["RANK"] = str(i)
+                print(f"Start worker process with rank id:{node_id}, log file:{log_name}. "
+                      f"Environment variable [RANK_ID={node_id}] is exported.", flush=True)
+            if self.is_simulation and (self.sim_rank_id != -1):
+                # Reset RANK_ID env to sim_rank_id if sim_rank_id is set. 
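+                # Illustrative effect (hypothetical values): launching with
+                # --sim_level=0 --sim_rank_id=5 spawns a single worker and
+                # forces its RANK_ID to 5 below.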
+                os.environ["RANK_ID"] = str(self.sim_rank_id)
+                logger.warning(f"In dryrun case, RANK_ID is assigned to {self.sim_rank_id}.")
+
+            if self.bind_core:
+                cpu_num = subprocess.getoutput("cat /proc/cpuinfo|grep processor|wc -l")
+                if not cpu_num.isdigit():
+                    raise RuntimeError(f"CPU number read from '/proc/cpuinfo' is {cpu_num}, "
+                                       "which is not a digit; failed to bind cores.")
+                avg = int(cpu_num) // self.local_worker_num
+                cpu_start = avg * i
+                cpu_end = cpu_start + avg - 1
+                cmd = _generate_cmd_args_list_with_core(self.cmd, self.cmd_args, cpu_start, cpu_end)
+            else:
+                cmd = _generate_cmd_args_list(self.cmd, self.cmd_args)
+            cgn = _ComputeGraphNode(self.worker_num, self.master_addr, self.master_port, self.cluster_time_out,
+                                    node_id, cmd, log_name, self.tail_worker_log, self.join, self.is_simulation)
+            process, tail_process = cgn.run()
+            self.cgn_processes.append(process)
+            self.tail_cgn_processes.append(tail_process)
+            self.proc_rank_map[i] = process
+
+    def join_processes(self):
+        """
+        Join all spawned processes until they stop.
+        If any process exits abnormally, the logs are analyzed so that an
+        understandable root cause of the exception can be reported.
+        """
+
+        def signal_handler(sig, frame):
+            logger.warning("msrun process received SIGINT (Ctrl+C), terminating all workers.")
+            self.kill_all_processes()
+            sys.exit(0)
+
+        has_exception = False
+        success_cgn_processes = set()
+        signal.signal(signal.SIGINT, signal_handler)
+        while True:
+            # Traverse all workers and kill them immediately if any exception happens.
+            for p in self.cgn_processes:
+                ret_code = p.poll()
+                if ret_code is None:
+                    # This means the process is still running; poll the next process.
+                    continue
+                elif ret_code != 0:
+                    has_exception = True
+                    logger.error(f"Worker process {p.pid} exited with an exception.")
+                    break
+                else:
+                    success_cgn_processes.add(p)
+
+            if has_exception:
+                logger.warning("A worker exited with an exception; killing all other workers.")
+                self.kill_worker_processes()
+                self.kill_tail_log_processes()
+                break
+            elif len(success_cgn_processes) == len(self.cgn_processes):
+                logger.info("All workers exited successfully!")
+                self.kill_tail_log_processes()
+                break
+
+        if self.msn_process:
+            self.msn_process.wait()
+            if self.msn_process.returncode != 0:
+                has_exception = True
+                logger.error(f"Scheduler process {self.msn_process.pid} exited with an exception.")
+
+        if has_exception:
+            logger.info("Analyzing exception log...")
+            self._analyze_log()
+            raise RuntimeError("Distributed job exited with exception. Please check logs in "
+                               f"directory: {self.log_dir}.")
+
+    def kill_tail_log_processes(self):
+        """
+        Kills all tail worker log processes.
+
+        """
+        for p_tail in self.tail_cgn_processes:
+            if p_tail is not None:
+                p_tail.kill()
+                logger.debug(f"Tail worker log process {p_tail.pid} has been killed!")
+
+    def kill_worker_processes(self):
+        """
+        Kills all worker processes.
+
+        """
+        for p in self.cgn_processes:
+            if p.poll() is None:
+                os.killpg(os.getpgid(p.pid), signal.SIGKILL)
+
+    def kill_all_processes(self):
+        """
+        Kills all running processes, including scheduler, worker and tail log.
+
+        """
+        self.kill_worker_processes()
+        self.kill_tail_log_processes()
+        if self.msn_process and self.msn_process.poll() is None:
+            self.msn_process.kill()
+
+    def stop_processes(self):
+        """
+        Stops all running processes.
+
+        """
+        for p in self.cgn_processes:
+            p.terminate()
+            # subprocess.Popen has no join(); wait() blocks until the process terminates.
+            p.wait()
+
+        if self.msn_process:
+            self.msn_process.terminate()
+            self.msn_process.wait()
+
+    def stop_and_restart(self):
+        """
+        Stops all running processes and restarts the scheduler and workers. 
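+
+        Sketch of the intended flow (descriptive note): stop every running
+        process first; then respawn the scheduler (on the master node only)
+        followed by the local workers.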
+
+        """
+        self.stop_processes()
+        if self.is_master:
+            self.start_scheduler()
+        self.start_workers()
+
+    def kill_all_workers(self):
+        """
+        Kill all running worker processes.
+
+        Args:
+            NA.
+        """
+        for p in self.cgn_processes:
+            if p.poll() is None:
+                p.kill()
+        self.cgn_processes.clear()
+
+        for p in self.tail_cgn_processes:
+            if p is not None:
+                p.kill()
+        self.tail_cgn_processes.clear()
+
+    def kill_single_worker(self, pid):
+        """
+        Kill one worker process with the specified pid.
+
+        Args:
+            pid: Worker process' pid.
+        """
+        kill_status = False
+        for i in range(len(self.cgn_processes)):
+            p = self.cgn_processes[i]
+            if p.pid == pid and p.poll() is None:
+                p.kill()
+                del self.cgn_processes[i]
+                tail_p = self.tail_cgn_processes[i]
+                if tail_p is not None:
+                    tail_p.kill()
+                del self.tail_cgn_processes[i]
+                kill_status = True
+                break
+        if not kill_status:
+            logger.warning(f"There's no active worker with pid: {pid}")
+
+    def kill_workers(self, pids):
+        """
+        Kill worker processes according to pids. A worker process whose pid is within the pids list will be killed.
+
+        Args:
+            pids(list): a list of worker process pids. When pids contains -1, all worker processes are killed.
+        """
+        if -1 in pids:
+            self.kill_all_workers()
+        else:
+            for pid in pids:
+                self.kill_single_worker(pid)
+        return 0
+
+    def monitor_rank_status(self, local_ranks):
+        """
+        Monitor the status of workers whose rank is within the local_ranks list.
+
+        Args:
+            local_ranks(list): a list of local worker ranks. When local_ranks contains -1,
+                monitor all workers' status.
+        """
+        rank_status = {}
+        if -1 in local_ranks:
+            local_ranks = list(range(self.local_worker_num))
+        for i in local_ranks:
+            single_status = self.monitor_single_rank(i)
+            if single_status:
+                rank_status[i] = single_status
+        return rank_status
+
+    def monitor_single_rank(self, rank_id):
+        """
+        Monitor the status of a single worker with rank_id.
+
+        Args:
+            rank_id: worker process's local rank, which is also device_id.
+        """
+        if 0 <= rank_id < self.local_worker_num:
+            global_rank_id = rank_id
+            if self.node_rank >= 0:
+                global_rank_id = self.node_rank * self.local_worker_num + rank_id
+            try:
+                p = self.proc_rank_map[rank_id]
+                p_status = p.poll()
+                if (not psutil.pid_exists(p.pid)) and (p_status != 0):
+                    # The process no longer exists and did not exit cleanly; report status code 300.
+                    p_status = 300
+                return {"pid": p.pid, "status": p_status, "global_rank": global_rank_id}
+            except KeyError:
+                # The worker has not been spawned yet; report status code 200.
+                logger.info(f"Process rank {rank_id} has not been initialized.")
+                return {"pid": None, "status": 200, "global_rank": global_rank_id}
+        else:
+            logger.warning(f"Invalid rank id: {rank_id}!")
+            return {}
+
+    def start_all_workers(self):
+        """
+        Start all worker processes after killing all workers.
+
+        Args:
+            NA.
+        """
+        if self.cgn_processes:
+            self.kill_all_workers()
+        self.start_workers()
+        worker_status = self.monitor_rank_status([-1])
+        for i in range(self.local_worker_num):
+            if worker_status[i]["status"] is not None:
+                return 1
+        return 0
+
+    def _get_node_id_and_log_path(self, index):
+        """
+        Generate the node id and log path for the corresponding process.
+        """
+        formatted_log_name = self.format_worker_log_name()
+        if self.local_worker_num > self.worker_num:
+            raise ValueError(f"Total worker number is {self.worker_num}, "
+                             f"but got a larger local worker number: {self.local_worker_num}.")
+        if self.local_worker_num == self.worker_num:
+            return index, os.path.join(self.log_dir, formatted_log_name + "_" + str(index) + ".log")
+
+        if self.node_rank >= 0:
+            # We assume that each node runs the same number of processes. 
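+            # Worked example with hypothetical values: node_rank=1,
+            # local_worker_num=8, index=3 -> node_id=11, and (with the default
+            # log name format) a log file named "<log_dir>/worker_11.log".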
+            node_id = self.node_rank * self.local_worker_num + index
+            log_name = os.path.join(self.log_dir, formatted_log_name + "_" + str(node_id) + ".log")
+        else:
+            # If node_rank is the default value -1, let MindSpore assign the rank id.
+            node_id = None
+            log_name = os.path.join(self.log_dir, formatted_log_name + "_" + str(index) + ".log")
+        return node_id, log_name
+
+    def _analyze_log(self):
+        """
+        Analyze exception logs.
+        """
+        scheduler_log_path = os.path.join(self.log_dir, "scheduler.log")
+        time_out_node_ids = []
+        if os.path.exists(scheduler_log_path):
+            with open(scheduler_log_path, "r") as log:
+                scheduler_log = log.read()
+            # Filter out abnormal logs.
+            time_out_node_log = re.findall(r"node: .* is timed out", scheduler_log)
+
+            # Filter out node ids of the processes which exited abnormally.
+            def node_id_splitter(node_id):
+                return re.split(" is timed out", re.split("node: ", node_id)[1])[0]
+            for node_id in time_out_node_log:
+                time_out_node_ids.append(node_id_splitter(node_id))
+            logger.error(f"Time out nodes are {time_out_node_ids}")
+
+        os.system(f"grep -rn -E 'ERROR|CRITICAL|Traceback|Error' -C 5 {self.log_dir}")
+
+    def format_worker_log_name(self):
+        """
+        Format the worker log file name.
+        """
+        if not self.worker_log_name:
+            formatted_worker_log_name = "worker"
+        else:
+            current_ip = _get_local_ip(self.master_addr)
+            formatted_worker_log_name = re.sub(r'\{ip\}', current_ip, self.worker_log_name)
+            formatted_worker_log_name = re.sub(r'\{hostname\}', socket.gethostname(), formatted_worker_log_name)
+        return formatted_worker_log_name
diff --git a/mindtorch/distributed/process_entity/_utils.py b/mindtorch/distributed/process_entity/_utils.py
new file mode 100644
index 000000000..1f9e7bd75
--- /dev/null
+++ b/mindtorch/distributed/process_entity/_utils.py
@@ -0,0 +1,136 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utils for ms_run"""
+import os
+import json
+import socket
+import ipaddress
+import mindspore.log as logger
+
+CURRENT_IP = None
+
+def _generate_cmd(cmd, cmd_args, output_name):
+    """
+    Generates a command string to execute a Python script in the background,
+    redirecting the output to a log file.
+
+    """
+    if cmd not in ['python', 'pytest', 'python3']:
+        # If the user doesn't set a binary file name, default to 'python' to launch the job.
+        command = f"python {cmd} {' '.join(cmd_args)} > {output_name} 2>&1 &"
+    else:
+        command = f"{cmd} {' '.join(cmd_args)} > {output_name} 2>&1 &"
+    return command
+
+
+def _generate_cmd_args_list(cmd, cmd_args):
+    """
+    Generates an arguments list for 'Popen'. It consists of a binary file name and subsequent arguments.
+    """
+    if cmd not in ['python', 'pytest', 'python3']:
+        # If the user doesn't set a binary file name, default to 'python' to launch the job. 
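+        # Illustrative call with a hypothetical script name:
+        #   _generate_cmd_args_list('train.py', ['--lr', '0.1'])
+        #   -> ['python', 'train.py', '--lr', '0.1']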
+        return ['python'] + [cmd] + cmd_args
+    return [cmd] + cmd_args
+
+
+def _generate_cmd_args_list_with_core(cmd, cmd_args, cpu_start, cpu_end):
+    """
+    Generates an arguments list for 'Popen'. It consists of a binary file name and subsequent arguments.
+    """
+    # Bind cpu cores to this process.
+    taskset_args = ['taskset'] + ['-c'] + [str(cpu_start) + '-' + str(cpu_end)]
+    final_cmd = []
+    if cmd not in ['python', 'pytest', 'python3']:
+        # If the user doesn't set a binary file name, default to 'python' to launch the job.
+        final_cmd = taskset_args + ['python'] + [cmd] + cmd_args
+    else:
+        final_cmd = taskset_args + [cmd] + cmd_args
+    logger.info(f"Launch process with command: {' '.join(final_cmd)}")
+    return final_cmd
+
+
+def _generate_url(addr, port):
+    """
+    Generates a URL string from addr and port.
+
+    """
+    url = f"http://{addr}:{port}/"
+    return url
+
+
+def _get_local_ip(ip_address):
+    """
+    Get the current IP address.
+
+    """
+    global CURRENT_IP
+    if CURRENT_IP is None:
+        try:
+            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+            s.connect((ip_address, 0))
+            CURRENT_IP = s.getsockname()[0]
+            s.close()
+        except Exception as e:
+            raise RuntimeError(f"Get local ip failed: {e}. Please check whether an accessible address "
+                               "is input by '--master_addr'.")
+    return CURRENT_IP
+
+
+def _is_local_ip(ip_address):
+    """
+    Check if the input IP address is a local IP address.
+
+    """
+    p = os.popen("ip -j addr")
+    addr_info_str = p.read()
+    p.close()
+    current_ip = _get_local_ip(ip_address)
+    if not addr_info_str:
+        return current_ip == ip_address
+
+    addr_infos = json.loads(addr_info_str)
+    for info in addr_infos:
+        for addr in info["addr_info"]:
+            if addr["local"] == ip_address:
+                logger.info(f"IP address found on this node. Address info:{addr}. Found address:{ip_address}")
+                return True
+    return False
+
+
+def _convert_addr_to_ip(master_addr):
+    """
+    Check whether the input parameter 'master_addr' is IPv4. If a host name is provided, it is resolved
+    to an IP address and used as the master host's IP.
+
+    """
+    try:
+        ipaddress.IPv4Address(master_addr)
+        return master_addr
+    except ipaddress.AddressValueError:
+        try:
+            ip_address = socket.gethostbyname(master_addr)
+            logger.info(f"Convert input host name:{master_addr} to ip address:{ip_address}.")
+            return ip_address
+        except socket.gaierror as e:
+            raise RuntimeError(f"DNS resolution failed: {e}. Please check whether a correct host name "
+                               "is input by '--master_addr'.")
+
+
+def _send_scale_num(url, scale_num):
+    """
+    Send an HTTP request to a specified URL, informing scale_num.
+
+    """
+    # Placeholder: elastic scaling is not wired up yet (is_scale is always False in _ProcessManager).
+    return ""
diff --git a/mindtorch/distributed/run.py b/mindtorch/distributed/run.py
index 3c4a17b8f..f0edb9ccb 100644
--- a/mindtorch/distributed/run.py
+++ b/mindtorch/distributed/run.py
@@ -1,922 +1,210 @@
-#!/usr/bin/env python3
-# mypy: allow-untyped-defs
-
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
+# Copyright 2023 Huawei Technologies Co., Ltd
 #
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Superset of ``mindtorch.distributed.launch``.
-
-``torchrun`` provides a superset of the functionality as ``mindtorch.distributed.launch``
-with the following additional functionalities:
-
-1. Worker failures are handled gracefully by restarting all workers.
-
-2. Worker ``RANK`` and ``WORLD_SIZE`` are assigned automatically.
-
-3. Number of nodes is allowed to change between minimum and maximum sizes (elasticity).
-
-.. 
note:: ``torchrun`` is a python - `console script `_ - to the main module - `mindtorch.distributed.run `_ - declared in the ``entry_points`` configuration in - `setup.py `_. - It is equivalent to invoking ``python -m mindtorch.distributed.run``. - - -Transitioning from mindtorch.distributed.launch to torchrun -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - -``torchrun`` supports the same arguments as ``mindtorch.distributed.launch`` **except** -for ``--use-env`` which is now deprecated. To migrate from ``mindtorch.distributed.launch`` -to ``torchrun`` follow these steps: - -1. If your training script is already reading ``local_rank`` from the ``LOCAL_RANK`` environment variable. - Then you need simply omit the ``--use-env`` flag, e.g.: - - +--------------------------------------------------------------------+--------------------------------------------+ - | ``mindtorch.distributed.launch`` | ``torchrun`` | - +====================================================================+============================================+ - | | | - | .. code-block:: shell-session | .. code-block:: shell-session | - | | | - | $ python -m mindtorch.distributed.launch --use-env train_script.py | $ torchrun train_script.py | - | | | - +--------------------------------------------------------------------+--------------------------------------------+ - -2. If your training script reads local rank from a ``--local-rank`` cmd argument. - Change your training script to read from the ``LOCAL_RANK`` environment variable as - demonstrated by the following code snippet: - - +-------------------------------------------------------+----------------------------------------------------+ - | ``mindtorch.distributed.launch`` | ``torchrun`` | - +=======================================================+====================================================+ - | | | - | .. code-block:: python | .. code-block:: python | - | | | - | | | - | import argparse | import os | - | parser = argparse.ArgumentParser() | local_rank = int(os.environ["LOCAL_RANK"]) | - | parser.add_argument("--local-rank", type=int) | | - | args = parser.parse_args() | | - | | | - | local_rank = args.local_rank | | - | | | - +-------------------------------------------------------+----------------------------------------------------+ - -.. versionchanged:: 2.0.0 - - The launcher will pass the ``--local-rank=`` argument to your script. - From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the - previously used underscored ``--local_rank``. - - For backward compatibility, it may be necessary for users to handle both - cases in their argument parsing code. This means including both ``"--local-rank"`` - and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is - provided, the launcher will trigger an error: "error: unrecognized arguments: - --local-rank=". For training code that only supports PyTorch 2.0.0+, - including ``"--local-rank"`` should be sufficient. - - :: - - >>> # xdoctest: +SKIP - >>> import argparse - >>> parser = argparse.ArgumentParser() - >>> parser.add_argument("--local-rank", "--local_rank", type=int) - >>> args = parser.parse_args() - -The aformentioned changes suffice to migrate from ``mindtorch.distributed.launch`` to ``torchrun``. -To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torchrun`` -please refer to: - -* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant. 
-* the rest of this page for more information on the features of ``torchrun``. - - -Usage --------- - -Single-node multi-worker -++++++++++++++++++++++++++++++ - -:: - - torchrun - --standalone - --nnodes=1 - --nproc-per-node=$NUM_TRAINERS - YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) - -Stacked single-node multi-worker -+++++++++++++++++++++++++++++++++++ - -To run multiple instances (separate jobs) of single-node, multi-worker on the -same host, we need to make sure that each instance (job) is -setup on different ports to avoid port conflicts (or worse, two jobs being merged -as a single job). To do this you have to run with ``--rdzv-backend=c10d`` -and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``. -For ``--nodes=1``, its often convenient to let ``torchrun`` pick a free random -port automatically instead of manually assigning different ports for each run. - -:: - - torchrun - --rdzv-backend=c10d - --rdzv-endpoint=localhost:0 - --nnodes=1 - --nproc-per-node=$NUM_TRAINERS - YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) - - -Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures) -++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -:: - - torchrun - --nnodes=$NUM_NODES - --nproc-per-node=$NUM_TRAINERS - --max-restarts=3 - --rdzv-id=$JOB_ID - --rdzv-backend=c10d - --rdzv-endpoint=$HOST_NODE_ADDR - YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) - -``HOST_NODE_ADDR``, in form [:] (e.g. node1.example.com:29400), specifies the node and -the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any -node in your training cluster, but ideally you should pick a node that has a high bandwidth. - -.. note:: - If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400. - -Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures) -+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - -:: - - torchrun - --nnodes=1:4 - --nproc-per-node=$NUM_TRAINERS - --max-restarts=3 - --rdzv-id=$JOB_ID - --rdzv-backend=c10d - --rdzv-endpoint=$HOST_NODE_ADDR - YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) - -``HOST_NODE_ADDR``, in form [:] (e.g. node1.example.com:29400), specifies the node and -the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any -node in your training cluster, but ideally you should pick a node that has a high bandwidth. - -.. note:: - If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400. - -Note on rendezvous backend ------------------------------- - -For multi-node training you need to specify: - -1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job) -2. ``--rdzv-backend``: An implementation of - :py:class:`mindtorch.distributed.elastic.rendezvous.RendezvousHandler` -3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form - ``host:port``. - -Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are -supported out of the box. To use ``etcd-v2`` or ``etcd``, setup an etcd server with the ``v2`` api -enabled (e.g. ``--enable-v2``). - -.. warning:: - ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd - server. Our tests use etcd v3.4.3. - -.. 
warning:: - For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd`` which is functionally - equivalent, but uses a revised implementation. ``etcd`` is in maintenance mode and will be - removed in a future version. - -Definitions --------------- - -1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with. - -2. ``Worker`` - A worker in the context of distributed training. - -3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers). - -4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node. - -5. ``RANK`` - The rank of the worker within a worker group. - -6. ``WORLD_SIZE`` - The total number of workers in a worker group. - -7. ``LOCAL_RANK`` - The rank of the worker within a local worker group. - -8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group. - -9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is - used by each node to join as a member of a particular worker group. - -9. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly - consistent key-value store. - -10. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``:``. - -A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of -all ``LocalWorkerGroups`` in the nodes in the job comprise the ``WorkerGroup``. - -Environment Variables ----------------------- - -The following environment variables are made available to you in your script: - -1. ``LOCAL_RANK`` - The local rank. - -2. ``RANK`` - The global rank. - -3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When - running a single worker group per node, this is the rank of the node. - -4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role - of the worker is specified in the ``WorkerSpec``. - -5. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. number of workers running locally); equals to - ``--nproc-per-node`` specified on ``torchrun``. - -6. ``WORLD_SIZE`` - The world size (total number of workers in the job). - -7. ``ROLE_WORLD_SIZE`` - The total number of workers that was launched with the same role specified - in ``WorkerSpec``. - -8. ``MASTER_ADDR`` - The FQDN of the host that is running worker with rank 0; used to initialize - the Torch Distributed backend. - -9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store. - -10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far. - -11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts. - -12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id). - -13. ``PYTHON_EXEC`` - System executable override. If provided, the python user script will - use the value of ``PYTHON_EXEC`` as executable. The `sys.executable` is used by default. - -Deployment ------------- - -1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be - passed as ``--rdzv-endpoint`` to the launcher script) - -2. Single-node multi-worker: Start the launcher on the host to start the agent process which - creates and monitors a local worker group. - -3. Multi-node multi-worker: Start the launcher with the same arguments on all the nodes - participating in training. 
- -When using a job/cluster manager the entry point command to the multi-node job should be this -launcher. - -Failure Modes ---------------- - -1. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail all workers - are stopped and restarted up to ``max_restarts``. - -2. Agent failure: An agent failure results in a local worker group failure. It is up to the job - manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors - are supported by the agent. - -3. Node failure: Same as agent failure. - -Membership Changes --------------------- - -1. Node departure (scale-down): The agent is notified of the departure, all existing workers are - stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and - ``WORLD_SIZE``. - -2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped, - a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and - ``WORLD_SIZE``. - -Important Notices --------------------- - -1. This utility and multi-process distributed (single-node or - multi-node) GPU training currently only achieves the best performance using - the NCCL distributed backend. Thus NCCL backend is the recommended backend to - use for GPU training. - -2. The environment variables necessary to initialize a Torch process group are provided to you by - this module, no need for you to pass ``RANK`` manually. To initialize a process group in your - training script, simply run: - -:: - - >>> # xdoctest: +SKIP("stub") - >>> import mindtorch.distributed as dist - >>> dist.init_process_group(backend="gloo|nccl") - -3. In your training program, you can either use regular distributed functions - or use :func:`mindtorch.nn.parallel.DistributedDataParallel` module. If your - training program uses GPUs for training and you would like to use - :func:`mindtorch.nn.parallel.DistributedDataParallel` module, - here is how to configure it. - -:: - - local_rank = int(os.environ["LOCAL_RANK"]) - model = mindtorch.nn.parallel.DistributedDataParallel(model, - device_ids=[local_rank], - output_device=local_rank) - -Please ensure that ``device_ids`` argument is set to be the only GPU device id -that your code will be operating on. This is generally the local rank of the -process. In other words, the ``device_ids`` needs to be ``[int(os.environ("LOCAL_RANK"))]``, -and ``output_device`` needs to be ``int(os.environ("LOCAL_RANK"))`` in order to use this -utility - - -4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to - checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance - for lost work. - -5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all - nodes run the same number of local workers (per role). - -6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a - different range of ranks than before. NEVER hard code any assumptions about the stable-ness of - ranks or some correlation between ``RANK`` and ``LOCAL_RANK``. - -7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about - ``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join. - -8. 
It is recommended for your script to have the following structure: - -:: - - def main(): - load_checkpoint(checkpoint_path) - initialize() - train() - - def train(): - for batch in iter(dataset): - train_step(batch) - - if should_checkpoint: - save_checkpoint(checkpoint_path) - -9. (Recommended) On worker errors, this tool will summarize the details of the error - (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp) - is heuristically reported as the "Root Cause" error. To get tracebacks as part of this - error summary print out, you must decorate your main entrypoint function in your - training script as shown in the example below. If not decorated, then the summary - will not include the traceback of the exception and will only contain the exitcode. - For details on torchelastic error handling see: https://pymindtorch.org/docs/stable/elastic/errors.html - -:: - - from mindtorch.distributed.elastic.multiprocessing.errors import record - - @record - def main(): - # do train - pass - - if __name__ == "__main__": - main() - -""" -import os -import sys -import uuid -from argparse import ArgumentParser, REMAINDER -from importlib import metadata -from typing import Callable, List, Optional, Set, Tuple, Type, Union - -import mindtorch -from mindtorch.distributed.argparse_util import check_env, env -from mindtorch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std -from mindtorch.distributed.elastic.multiprocessing.errors import record -from mindtorch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config -from mindtorch.distributed.elastic.utils import macros -from mindtorch.distributed.elastic.utils.logging import get_logger -from mindtorch.distributed.launcher.api import elastic_launch, LaunchConfig -from mindtorch.utils.backend_registration import _get_custom_mod_func - - -logger = get_logger(__name__) - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Entrypoint of ms_run""" +import ast +import re +import json +from argparse import REMAINDER, ArgumentParser, ArgumentTypeError +from .process_entity import _ProcessManager +from .argparse_util import check_env, env + + +def parse_and_validate_bind_core(value): + """ + Parse input argument of --bind_core. -def get_args_parser() -> ArgumentParser: - """Parse the command line options.""" - parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher") + """ + if value.lower() == "true": + return True + if value.lower() == "false": + return False - # - # Worker/node size related arguments. 
- # + try: + value_dict = json.loads(value) + except json.JSONDecodeError as e: + raise ArgumentTypeError("Failed to parse JSON into a dictionary") from e + + if isinstance(value_dict, dict): + range_pattern = re.compile(r'^\d+-\d+$') + for device_id, affinity_cpu_list in value_dict.items(): + if not re.fullmatch(r"device\d+", device_id): + raise ArgumentTypeError(f"Key '{device_id}' must be in format 'deviceX' (X ≥ 0).") + if not isinstance(affinity_cpu_list, list): + raise ArgumentTypeError(f"Value for '{device_id}':{affinity_cpu_list} should be a list, " + f"but got {type(affinity_cpu_list)}.") + + for cpu_range in affinity_cpu_list: + if not isinstance(cpu_range, str): + raise ArgumentTypeError(f"CPU range '{cpu_range}' in '{affinity_cpu_list}' should be a string.") + if not range_pattern.match(cpu_range): + raise ArgumentTypeError(f"CPU range '{cpu_range}' in '{affinity_cpu_list}' should be " + "in format 'cpuidX-cpuidY'.") + return value_dict + + raise ArgumentTypeError(f"Type of {value} should be bool or dict, but got {type(value)}.") + + +def get_args(): + """ + Parses and retrieves command-line arguments. + """ + parser = ArgumentParser() + # parser.add_argument( + # "--worker_num", type=int, default=8, + # help="the total number of nodes participating in the training, an integer variable, " + # "with a default value of 8." + # ) parser.add_argument( "--nnodes", action=env, - type=str, - default="1:1", + type=int, + default=1, help="Number of nodes, or the range of nodes in form :.", ) parser.add_argument( "--nproc-per-node", "--nproc_per_node", action=env, - type=str, - default="1", + type=int, + default=1, help="Number of workers per node; supported values: [auto, cpu, gpu, int].", ) - - # - # Rendezvous related arguments - # - + # parser.add_argument( + # "--local_worker_num", + # type=int, default=8, + # help="the number of nodes participating in local training, an integer variable, " + # "with a default value of 8." + # ) parser.add_argument( - "--rdzv-backend", - "--rdzv_backend", - action=env, - type=str, - default="static", - help="Rendezvous backend.", + "--master_addr", + default="127.0.0.1", type=str, + help="specifies the IP address or the host name of the scheduler and its data type is string." + " Allowed values: valid IP addresses or valid host name." ) parser.add_argument( - "--rdzv-endpoint", - "--rdzv_endpoint", - action=env, - type=str, - default="", - help="Rendezvous backend endpoint; usually in form :.", + "--master_port", default=8118, type=int, + help="specifies the port number of the scheduler, and its data type is integer." + " Allowed values: port numbers within the range of 1024 to 65535 that are not " + "already in use." ) parser.add_argument( - "--rdzv-id", - "--rdzv_id", - action=env, - type=str, - default="none", - help="User-defined group id.", + "--node_rank", default=-1, type=int, + help="specifies the rank of current physical node, and its data type is integer." + " This parameter is used for rank id assignment for each process on the node." + " If not set, MindSpore will assign rank ids automatically and" + " rank id of each process on the same node will be continuous." ) parser.add_argument( - "--rdzv-conf", - "--rdzv_conf", - action=env, - type=str, - default="", - help="Additional rendezvous configuration (=,=,...).", + "--log_dir", default="", type=str, + help="specifies the log output file path." 
) parser.add_argument( - "--standalone", - action=check_env, - help="Start a local standalone rendezvous backend that is represented by a C10d TCP store " - "on a free port. Useful when launching single-node, multi-worker job. If specified " - "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned and any explicitly set values " - "are ignored.", + "--join", + default=False, + type=ast.literal_eval, + choices=[True, False], + help="specifies whether msrun should join spawned processes and return distributed job results." + "If set to True, msrun will check process status and parse the log files." ) - - # - # User-code launch related arguments. - # - parser.add_argument( - "--max-restarts", - "--max_restarts", - action=env, + "--cluster_time_out", + default=600, type=int, - default=0, - help="Maximum number of worker group restarts before failing.", + help="specifies time out window of cluster building procedure in second. " + "If only scheduler is launched, or spawned worker number is not enough, " + "other processes will wait for 'cluster_time_out' seconds and then exit. " + "If this value is negative, other processes will wait infinitely." ) parser.add_argument( - "--monitor-interval", - "--monitor_interval", - action=env, - type=float, - default=0.1, - help="Interval, in seconds, to monitor the state of workers.", + "--bind_core", + default=False, + type=parse_and_validate_bind_core, + help="specifies whether msrun should bind CPU cores to spawned processes. " + "If set to True, msrun will bind core based on the environment automatically, " + "and if passed a dict, msrun will bind core based on this dict information." ) parser.add_argument( - "--start-method", - "--start_method", - action=env, - type=str, - default="spawn", - choices=["spawn", "fork", "forkserver"], - help="Multiprocessing start method to use when creating workers.", - ) - parser.add_argument( - "--role", - action=env, - type=str, - default="default", - help="User-defined role for the workers.", - ) - parser.add_argument( - "-m", - "--module", - action=check_env, - help="Change each process to interpret the launch script as a Python module, executing " - "with the same behavior as 'python -m'.", - ) - parser.add_argument( - "--no-python", - "--no_python", - action=check_env, - help="Skip prepending the training script with 'python' - just execute it directly. Useful " - "when the script is not a Python script.", - ) - - parser.add_argument( - "--run-path", - "--run_path", - action=check_env, - help="Run the training script with runpy.run_path in the same interpreter." - " Script must be provided as an abs path (e.g. /abs/path/script.py)." - " Takes precedence over --no-python.", - ) - parser.add_argument( - "--log-dir", - "--log_dir", - action=env, - type=str, - default=None, - help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same " - "directory is re-used for multiple runs (a unique job-level sub-directory is created with " - "rdzv_id as the prefix).", + "--sim_level", + default=-1, + type=int, + choices=[0, 1, 2, 3], + help="specifies simulation level. This argument activates dryrun mode, functioning " + "equivalently to environment variable 'MS_SIMULATION_LEVEL' while having higher priority." ) parser.add_argument( - "-r", - "--redirects", - action=env, - type=str, - default="0", - help="Redirect std streams into a log file in the log directory (e.g. 
[-r 3] redirects " - "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and " - "stderr for local rank 1).", + "--sim_rank_id", + default=-1, + type=int, + help="specifies simulation process's rank id. When this argument is set, only one process " + "is spawned on dryrun mode, functioning equivalently to environment variable 'RANK_ID' " + "while having higher priority." ) parser.add_argument( - "-t", - "--tee", - action=env, + "--rank_table_file", + default="", type=str, - default="0", - help="Tee std streams into a log file and also to console (see --redirects for format).", + help="specifies rank table file path. This path is not used to initialize distributed job in " + "'rank table file manner' but to help support other features." ) - parser.add_argument( - "--local-ranks-filter", - "--local_ranks_filter", - action=env, - type=str, + "--worker_log_name", default="", - help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will " - "only show logs from rank 0, 1 and 2). This will only apply to stdout and stderr, not to" - "log files saved via --redirect or --tee", - ) - - # - # Backwards compatible parameters with caffe2.distributed.launch. - # - - parser.add_argument( - "--node-rank", - "--node_rank", - type=int, - action=env, - default=0, - help="Rank of the node for multi-node distributed training.", - ) - parser.add_argument( - "--master-addr", - "--master_addr", - default="127.0.0.1", type=str, - action=env, - help="Address of the master node (rank 0) that only used for static rendezvous. It should " - "be either the IP address or the hostname of rank 0. For single node multi-proc training " - "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern " - "`[0:0:0:0:0:0:0:1]`.", + help="Specifies the worker log file name as a string for current node; the default is worker_[rankid]. " + "Support configuring the current IP address and host name by using {ip} and {hostname} respectively. " + "e.g. --worker_log_name=worker_{ip}_{hostname}_test, worker [rankid] log name for current node " + "will be worker_[real IP address]_[real host name]_test_[rankid]." ) parser.add_argument( - "--master-port", - "--master_port", - default=29500, - type=int, - action=env, - help="Port on the master node (rank 0) to be used for communication during distributed " - "training. It is only used for static rendezvous.", - ) - parser.add_argument( - "--local-addr", - "--local_addr", - default=None, + "--tail_worker_log", + default="-1", type=str, - action=env, - help="Address of the local node. If specified, will use the given address for connection. " - "Else, will look up the local node address instead. Else, it will be default to local " - "machine's FQDN.", + help="Only tail worker log to console when '--join=True' and the configured value should be within " + "[0, local_worker_num], otherwise worker log will not be tail. All worker logs will be tail by " + "default. Support tail the specified worker log (e.g. --tail_log=0 tail the worker 0 log to console)." ) - parser.add_argument( - "--logs-specs", - "--logs_specs", - default=None, + "task_script", type=str, - help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. " - "Can be used to override custom logging behavior.", + help="The full path to the script that will be launched in distributed manner, followed " + "by any additional arguments required by the script." ) - - # - # Positional arguments. 
- # - parser.add_argument( - "training_script", - type=str, - help="Full path to the (single GPU) training program/script to be launched in parallel, " - "followed by all the arguments for the training script.", + "task_script_args", nargs=REMAINDER, + help="Arguments for user-defined script." ) + return parser.parse_args() - # Rest from the training program. - parser.add_argument("training_script_args", nargs=REMAINDER) - - return parser - - -def parse_args(args): - parser = get_args_parser() - return parser.parse_args(args) - - -def parse_min_max_nnodes(nnodes: str): - arr = nnodes.split(":") - - if len(arr) == 1: - min_nodes = max_nodes = int(arr[0]) - elif len(arr) == 2: - min_nodes = int(arr[0]) - max_nodes = int(arr[1]) - else: - raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format') # noqa: E231 - - return min_nodes, max_nodes - - -def determine_local_world_size(nproc_per_node: str): - try: - logger.info("Using nproc_per_node=%s.", nproc_per_node) - return int(nproc_per_node) - except ValueError as e: - if nproc_per_node == "cpu": - num_proc = os.cpu_count() - device_type = "cpu" - elif nproc_per_node == "gpu": - if not mindtorch.cuda.is_available(): - raise ValueError("Cuda is not available.") from e - device_type = "gpu" - num_proc = mindtorch.cuda.device_count() - elif nproc_per_node == mindtorch._C._get_privateuse1_backend_name(): - if not _get_custom_mod_func("is_available")(): - raise ValueError(f"{nproc_per_node} is not available.") from e - device_type = nproc_per_node - num_proc = _get_custom_mod_func("device_count")() - elif nproc_per_node == "auto": - if mindtorch.cuda.is_available(): - num_proc = mindtorch.cuda.device_count() - device_type = "gpu" - elif ( - hasattr(torch, mindtorch._C._get_privateuse1_backend_name()) - and _get_custom_mod_func("is_available")() - ): - num_proc = _get_custom_mod_func("device_count")() - device_type = mindtorch._C._get_privateuse1_backend_name() - else: - num_proc = os.cpu_count() - device_type = "cpu" - else: - raise ValueError( - f"Unsupported nproc_per_node value: {nproc_per_node}" - ) from e - - logger.info( - "Using nproc_per_node=%s, setting nproc_per_node to %s since the instance has %s %s", - nproc_per_node, - num_proc, - num_proc, - device_type, - ) - return num_proc - - -def get_rdzv_endpoint(args): - if args.rdzv_backend == "static" and not args.rdzv_endpoint: - return f"{args.master_addr}:{args.master_port}" # noqa: E231 - return args.rdzv_endpoint - - -def get_use_env(args) -> bool: - """ - Retrieve ``use_env`` from the args. - - ``use_env`` is a legacy argument, if ``use_env`` is False, the - ``--node-rank`` argument will be transferred to all worker processes. - ``use_env`` is only used by the ``mindtorch.distributed.launch`` and will - be deprecated in future releases. - """ - if not hasattr(args, "use_env"): - return True - return args.use_env - - -def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]: - """ - Attemps to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param. - Provides plugin mechanism to provide custom implementation of LogsSpecs. - Returns `DefaultLogsSpecs` when logs_spec_name is None. - Raises ValueError when entrypoint for `logs_spec_name` can't be found in entrypoints. 
+def run(args): """ - logs_specs_cls = None - if logs_specs_name is not None: - eps = metadata.entry_points() - if hasattr(eps, "select"): # >= 3.10 - group = eps.select(group="torchrun.logs_specs") - if group.select(name=logs_specs_name): - logs_specs_cls = group[logs_specs_name].load() - - elif specs := eps.get("torchrun.logs_specs"): # < 3.10 - if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]: - logs_specs_cls = entrypoint_list[0].load() - - if logs_specs_cls is None: - raise ValueError( - f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key" - ) - - logger.info( - "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls) - ) - else: - logs_specs_cls = DefaultLogsSpecs - - return logs_specs_cls - - -def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]: - # If ``args`` not passed, defaults to ``sys.argv[:1]`` - min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes) - assert 0 < min_nodes <= max_nodes - assert args.max_restarts >= 0 - - if ( - hasattr(args, "master_addr") - and args.rdzv_backend != "static" - and not args.rdzv_endpoint - ): - logger.warning( - "master_addr is only used for static rdzv_backend and when rdzv_endpoint " - "is not specified." - ) - - nproc_per_node = determine_local_world_size(args.nproc_per_node) - if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1: - omp_num_threads = 1 - logger.warning( - "\n*****************************************\n" - "Setting OMP_NUM_THREADS environment variable for each process to be " - "%s in default, to avoid your system being overloaded, " - "please further tune the variable for optimal performance in " - "your application as needed. \n" - "*****************************************", - omp_num_threads, - ) - # This env variable will be passed down to the subprocesses - os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) - - log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE") - - rdzv_configs = _parse_rendezvous_config(args.rdzv_conf) - - if args.rdzv_backend == "static": - rdzv_configs["rank"] = args.node_rank - - rdzv_endpoint = get_rdzv_endpoint(args) - - ranks: Optional[Set[int]] = None - if args.local_ranks_filter: - try: - ranks = set(map(int, args.local_ranks_filter.split(","))) - assert ranks - except Exception as e: - raise ValueError( - "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2" - ) from e + Runs the dynamic networking process manager. 
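+
+    Descriptive note: this is a thin wrapper that builds a _ProcessManager from
+    the parsed arguments and delegates to its run() method.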
- logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs) - logs_specs = logs_specs_cls( - log_dir=args.log_dir, - redirects=Std.from_str(args.redirects), - tee=Std.from_str(args.tee), - local_ranks_filter=ranks, - ) - - config = LaunchConfig( - min_nodes=min_nodes, - max_nodes=max_nodes, - nproc_per_node=nproc_per_node, - run_id=args.rdzv_id, - role=args.role, - rdzv_endpoint=rdzv_endpoint, - rdzv_backend=args.rdzv_backend, - rdzv_configs=rdzv_configs, - max_restarts=args.max_restarts, - monitor_interval=args.monitor_interval, - start_method=args.start_method, - log_line_prefix_template=log_line_prefix_template, - local_addr=args.local_addr, - logs_specs=logs_specs, - ) - - with_python = not args.no_python - cmd: Union[Callable, str] - cmd_args = [] - use_env = get_use_env(args) - if args.run_path: - cmd = run_script_path - cmd_args.append(args.training_script) - else: - if with_python: - cmd = os.getenv("PYTHON_EXEC", sys.executable) - cmd_args.append("-u") - if args.module: - cmd_args.append("-m") - cmd_args.append(args.training_script) - else: - if args.module: - raise ValueError( - "Don't use both the '--no-python' flag" - " and the '--module' flag at the same time." - ) - cmd = args.training_script - if not use_env: - cmd_args.append(f"--local-rank={macros.local_rank}") - cmd_args.extend(args.training_script_args) - - return config, cmd, cmd_args - - -def run_script_path(training_script: str, *training_script_args: str): - """ - Run the provided `training_script` from within this interpreter. + Args: + args: An object containing the command-line arguments. - Usage: `script_as_function("/abs/path/to/script.py", "--arg1", "val1")` """ - import runpy - import sys + process_manager = _ProcessManager(args) + process_manager.run() - sys.argv = [training_script] + [*training_script_args] - runpy.run_path(sys.argv[0], run_name="__main__") - -def run(args): - mindtorch.multiprocessing._set_thread_name("pt_elastic") - - if args.standalone: - args.rdzv_backend = "c10d" - args.rdzv_endpoint = "localhost:0" - args.rdzv_id = str(uuid.uuid4()) - logger.info( - "\n**************************************\n" - "Rendezvous info:\n" - "--rdzv-backend=%s " - "--rdzv-endpoint=%s " - "--rdzv-id=%s\n" - "**************************************\n", - args.rdzv_backend, - args.rdzv_endpoint, - args.rdzv_id, - ) - - config, cmd, cmd_args = config_from_args(args) - elastic_launch( - config=config, - entrypoint=cmd, - )(*cmd_args) - - -@record -def main(args=None): - args = parse_args(args) +def main(): + """the main function""" + args = get_args() run(args) - if __name__ == "__main__": main() diff --git a/mindtorch/npu/__init__.py b/mindtorch/npu/__init__.py index 5582412ff..8255f4845 100644 --- a/mindtorch/npu/__init__.py +++ b/mindtorch/npu/__init__.py @@ -46,7 +46,7 @@ def device_count(): return 1 def current_device(): - return mindtorch.device('npu', 0) + return 0 def is_available(): return mindspore.get_context('device_target') == 'Ascend' @@ -102,9 +102,6 @@ def mem_get_info(device=None): return (res.free_memory, res.total_memory) -def current_device(): - return mindtorch.device('npu', 0) - def get_device_capability(device=None): return 10, 0 diff --git a/mindtorch/utils/data/__init__.py b/mindtorch/utils/data/__init__.py new file mode 100644 index 000000000..023eaf9de --- /dev/null +++ b/mindtorch/utils/data/__init__.py @@ -0,0 +1,10 @@ +from . 
import dataset
+from .dataset import *
+from .sampler import *
+from .dataloader import *
+from .dataloader import _DatasetKind
+
+__all__ = []
+__all__.extend(dataset.__all__)
+__all__.extend(sampler.__all__)
+__all__.extend(dataloader.__all__)
diff --git a/mindtorch/utils/data/_utils/__init__.py b/mindtorch/utils/data/_utils/__init__.py
new file mode 100644
index 000000000..5590b197a
--- /dev/null
+++ b/mindtorch/utils/data/_utils/__init__.py
@@ -0,0 +1,54 @@
+# mypy: allow-untyped-defs
+r"""Utility classes & functions for data loading. Code in this folder is mostly used by ../dataloader.py.
+
+A lot of multiprocessing is used in data loading, which only supports running
+functions defined in the global environment (py2 can't serialize static methods).
+Therefore, for code tidiness we put these functions into different files in this
+folder.
+"""
+
+import atexit
+import sys
+
+# old private location of the ExceptionWrapper that some users rely on:
+from mindtorch._utils import ExceptionWrapper
+
+
+IS_WINDOWS = sys.platform == "win32"
+
+
+MP_STATUS_CHECK_INTERVAL = 5.0
+r"""Interval (in seconds) to check status of processes to avoid hanging in
+    multiprocessing data loading. This is mainly used in getting data from
+    another process, in which case we need to periodically check whether the
+    sender is alive to prevent hanging."""
+
+
+python_exit_status = False
+r"""Whether Python is shutting down. This flag is guaranteed to be set before
+the Python core library resources are freed, but Python may already be exiting
+for some time when this is set.

+Hook to set this flag is `_set_python_exit_flag`, and is inspired by a similar
+hook in Python 3.7 multiprocessing library:
+https://github.com/python/cpython/blob/d4d60134b29290049e28df54f23493de4f1824b6/Lib/multiprocessing/util.py#L277-L327
+"""
+
+
+try:
+    import numpy
+
+    HAS_NUMPY = True
+except ModuleNotFoundError:
+    HAS_NUMPY = False
+
+
+def _set_python_exit_flag():
+    global python_exit_status
+    python_exit_status = True
+
+
+atexit.register(_set_python_exit_flag)
+
+
+from . import collate, fetch, pin_memory, signal_handling, worker
\ No newline at end of file
diff --git a/mindtorch/utils/data/_utils/collate.py b/mindtorch/utils/data/_utils/collate.py
new file mode 100644
index 000000000..231e2c362
--- /dev/null
+++ b/mindtorch/utils/data/_utils/collate.py
@@ -0,0 +1,398 @@
+# mypy: allow-untyped-defs
+r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
+
+These methods are used to collate samples fetched from a dataset into Tensor(s).
+These **need** to be in global scope since Py2 doesn't support serializing
+static methods.
+
+`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
+"""
+
+import collections
+import contextlib
+import copy
+import re
+from typing import Callable, Dict, Optional, Tuple, Type, Union
+
+import mindtorch
+
+
+np_str_obj_array_pattern = re.compile(r"[SaUO]")
+
+
+def default_convert(data):
+    r"""
+    Convert each NumPy array element into a :class:`mindtorch.Tensor`.
+
+    If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`mindtorch.Tensor`.
+    If the input is not a NumPy array, it is left unchanged.
+    This is used as the default function for collation when both `batch_sampler` and `batch_size`
+    are NOT defined in :class:`~mindtorch.utils.data.DataLoader`.
+
+    The general input type to output type mapping is similar to that
+    of :func:`~mindtorch.utils.data.default_collate`. 
See the description there for more details. + + Args: + data: a single data point to be converted + + Examples: + >>> # xdoctest: +SKIP + >>> # Example with `int` + >>> default_convert(0) + 0 + >>> # Example with NumPy array + >>> default_convert(np.array([0, 1])) + tensor([0, 1]) + >>> # Example with NamedTuple + >>> Point = namedtuple('Point', ['x', 'y']) + >>> default_convert(Point(0, 0)) + Point(x=0, y=0) + >>> default_convert(Point(np.array(0), np.array(0))) + Point(x=tensor(0), y=tensor(0)) + >>> # Example with List + >>> default_convert([np.array([0, 1]), np.array([2, 3])]) + [tensor([0, 1]), tensor([2, 3])] + """ + elem_type = type(data) + if isinstance(data, mindtorch.Tensor): + return data + elif ( + elem_type.__module__ == "numpy" + and elem_type.__name__ != "str_" + and elem_type.__name__ != "string_" + ): + # array of string classes and object + if ( + elem_type.__name__ == "ndarray" + and np_str_obj_array_pattern.search(data.dtype.str) is not None + ): + return data + return mindtorch.as_tensor(data) + elif isinstance(data, collections.abc.Mapping): + try: + if isinstance(data, collections.abc.MutableMapping): + # The mapping type may have extra properties, so we can't just + # use `type(data)(...)` to create the new mapping. + # Create a clone and update it if the mapping type is mutable. + clone = copy.copy(data) + clone.update({key: default_convert(data[key]) for key in data}) + return clone + else: + return elem_type({key: default_convert(data[key]) for key in data}) + except TypeError: + # The mapping type may not support `copy()` / `update(mapping)` + # or `__init__(iterable)`. + return {key: default_convert(data[key]) for key in data} + elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple + return elem_type(*(default_convert(d) for d in data)) + elif isinstance(data, tuple): + return [default_convert(d) for d in data] # Backwards compatibility. + elif isinstance(data, collections.abc.Sequence) and not isinstance( + data, (str, bytes) + ): + try: + if isinstance(data, collections.abc.MutableSequence): + # The sequence type may have extra properties, so we can't just + # use `type(data)(...)` to create the new sequence. + # Create a clone and update it if the sequence type is mutable. + clone = copy.copy(data) # type: ignore[arg-type] + for i, d in enumerate(data): + clone[i] = default_convert(d) + return clone + else: + return elem_type([default_convert(d) for d in data]) + except TypeError: + # The sequence type may not support `copy()` / `__setitem__(index, item)` + # or `__init__(iterable)` (e.g., `range`). + return [default_convert(d) for d in data] + else: + return data + + +default_collate_err_msg_format = ( + "default_collate: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}" +) + + +def collate( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None, +): + r""" + General collate function that handles collection type of element within each batch. + + The function also opens function registry to deal with specific element types. `default_collate_fn_map` + provides default collate functions for tensors, numpy arrays, numbers and strings. + + Args: + batch: a single batch to be collated + collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function. 
+ If the element type isn't present in this dictionary, + this function will go through each key of the dictionary in the insertion order to + invoke the corresponding collate function if the element type is a subclass of the key. + + Examples: + >>> def collate_tensor_fn(batch, *, collate_fn_map): + ... # Extend this function to handle batch of tensors + ... return mindtorch.stack(batch, 0) + >>> def custom_collate(batch): + ... collate_map = {mindtorch.Tensor: collate_tensor_fn} + ... return collate(batch, collate_fn_map=collate_map) + >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map` + >>> default_collate_fn_map.update({mindtorch.Tensor: collate_tensor_fn}) + + Note: + Each collate function requires a positional argument for batch and a keyword argument + for the dictionary of collate functions as `collate_fn_map`. + """ + elem = batch[0] + elem_type = type(elem) + + if collate_fn_map is not None: + if elem_type in collate_fn_map: + return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map) + + for collate_type in collate_fn_map: + if isinstance(elem, collate_type): + return collate_fn_map[collate_type]( + batch, collate_fn_map=collate_fn_map + ) + + if isinstance(elem, collections.abc.Mapping): + try: + if isinstance(elem, collections.abc.MutableMapping): + # The mapping type may have extra properties, so we can't just + # use `type(data)(...)` to create the new mapping. + # Create a clone and update it if the mapping type is mutable. + clone = copy.copy(elem) + clone.update( + { + key: collate( + [d[key] for d in batch], collate_fn_map=collate_fn_map + ) + for key in elem + } + ) + return clone + else: + return elem_type( + { + key: collate( + [d[key] for d in batch], collate_fn_map=collate_fn_map + ) + for key in elem + } + ) + except TypeError: + # The mapping type may not support `copy()` / `update(mapping)` + # or `__init__(iterable)`. + return { + key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) + for key in elem + } + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple + return elem_type( + *( + collate(samples, collate_fn_map=collate_fn_map) + for samples in zip(*batch) + ) + ) + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError("each element in list of batch should be of equal size") + transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. + + if isinstance(elem, tuple): + return [ + collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] # Backwards compatibility. + else: + try: + if isinstance(elem, collections.abc.MutableSequence): + # The sequence type may have extra properties, so we can't just + # use `type(data)(...)` to create the new sequence. + # Create a clone and update it if the sequence type is mutable. + clone = copy.copy(elem) # type: ignore[arg-type] + for i, samples in enumerate(transposed): + clone[i] = collate(samples, collate_fn_map=collate_fn_map) + return clone + else: + return elem_type( + [ + collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] + ) + except TypeError: + # The sequence type may not support `copy()` / `__setitem__(index, item)` + # or `__init__(iterable)` (e.g., `range`). 
+ return [ + collate(samples, collate_fn_map=collate_fn_map) + for samples in transposed + ] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) + + +def collate_tensor_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None, +): + elem = batch[0] + out = None + # if elem.is_nested: + # raise RuntimeError( + # "Batches of nested tensors are not currently supported by the default collate_fn; " + # "please provide a custom collate_fn to handle them appropriately." + # ) + # if elem.layout in { + # mindtorch.sparse_coo, + # mindtorch.sparse_csr, + # mindtorch.sparse_bsr, + # mindtorch.sparse_csc, + # mindtorch.sparse_bsc, + # }: + # raise RuntimeError( + # "Batches of sparse tensors are not currently supported by the default collate_fn; " + # "please provide a custom collate_fn to handle them appropriately." + # ) + # if mindtorch.utils.data.get_worker_info() is not None: + # # If we're in a background process, concatenate directly into a + # # shared memory tensor to avoid an extra copy + # numel = sum(x.numel() for x in batch) + # storage = elem._typed_storage()._new_shared(numel, device=elem.device) + # out = elem.new(storage).resize_(len(batch), *list(elem.size())) + return mindtorch.stack(batch, 0) + + +def collate_numpy_array_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None, +): + elem = batch[0] + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return collate([mindtorch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map) + + +def collate_numpy_scalar_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None, +): + return mindtorch.as_tensor(batch) + + +def collate_float_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None, +): + return mindtorch.tensor(batch, dtype=mindtorch.float64) + + +def collate_int_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None, +): + return mindtorch.tensor(batch) + + +def collate_str_fn( + batch, + *, + collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None, +): + return batch + + +default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = { + mindtorch.Tensor: collate_tensor_fn +} +with contextlib.suppress(ImportError): + import numpy as np + + # For both ndarray and memmap (subclass of ndarray) + default_collate_fn_map[np.ndarray] = collate_numpy_array_fn + # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html + # Skip string scalars + default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn +default_collate_fn_map[float] = collate_float_fn +default_collate_fn_map[int] = collate_int_fn +default_collate_fn_map[str] = collate_str_fn +default_collate_fn_map[bytes] = collate_str_fn + + +def default_collate(batch): + r""" + Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size. + + The exact output type can be a :class:`mindtorch.Tensor`, a `Sequence` of :class:`mindtorch.Tensor`, a + Collection of :class:`mindtorch.Tensor`, or left unchanged, depending on the input type. + This is used as the default function for collation when + `batch_size` or `batch_sampler` is defined in :class:`~mindtorch.utils.data.DataLoader`. 
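
As a concrete illustration of the type-dispatch registry that `collate` implements above, here is a small, hypothetical usage sketch (the `Point` type and `collate_point_fn` are made up for this example; the import path refers to the module added in this diff):

```python
from collections import namedtuple

import mindtorch
from mindtorch.utils.data._utils.collate import collate, default_collate_fn_map

Point = namedtuple("Point", ["x", "y"])

def collate_point_fn(batch, *, collate_fn_map=None):
    # Stack the x and y fields of every Point in the batch into tensors.
    return Point(
        mindtorch.tensor([p.x for p in batch]),
        mindtorch.tensor([p.y for p in batch]),
    )

# Copy the default registry and add an entry for the custom type. `collate`
# checks exact type matches first, so Point hits collate_point_fn before the
# generic namedtuple branch.
fn_map = dict(default_collate_fn_map)
fn_map[Point] = collate_point_fn

print(collate([Point(0, 1), Point(2, 3)], collate_fn_map=fn_map))
# Point(x=tensor([0, 2]), y=tensor([1, 3]))
```
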
+
+    Here is the general input type (based on the type of the element within the batch) to output type mapping:
+
+        * :class:`mindtorch.Tensor` -> :class:`mindtorch.Tensor` (with an added outer dimension batch size)
+        * NumPy Arrays -> :class:`mindtorch.Tensor`
+        * `float` -> :class:`mindtorch.Tensor`
+        * `int` -> :class:`mindtorch.Tensor`
+        * `str` -> `str` (unchanged)
+        * `bytes` -> `bytes` (unchanged)
+        * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
+        * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
+          default_collate([V2_1, V2_2, ...]), ...]`
+        * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
+          default_collate([V2_1, V2_2, ...]), ...]`
+
+    Args:
+        batch: a single batch to be collated
+
+    Examples:
+        >>> # xdoctest: +SKIP
+        >>> # Example with a batch of `int`s:
+        >>> default_collate([0, 1, 2, 3])
+        tensor([0, 1, 2, 3])
+        >>> # Example with a batch of `str`s:
+        >>> default_collate(['a', 'b', 'c'])
+        ['a', 'b', 'c']
+        >>> # Example with `Map` inside the batch:
+        >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
+        {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
+        >>> # Example with `NamedTuple` inside the batch:
+        >>> Point = namedtuple('Point', ['x', 'y'])
+        >>> default_collate([Point(0, 0), Point(1, 1)])
+        Point(x=tensor([0, 1]), y=tensor([0, 1]))
+        >>> # Example with `Tuple` inside the batch:
+        >>> default_collate([(0, 1), (2, 3)])
+        [tensor([0, 2]), tensor([1, 3])]
+        >>> # Example with `List` inside the batch:
+        >>> default_collate([[0, 1], [2, 3]])
+        [tensor([0, 2]), tensor([1, 3])]
+        >>> # Two options to extend `default_collate` to handle a specific type
+        >>> # Option 1: Write a custom collate function and invoke `default_collate`
+        >>> def custom_collate(batch):
+        ...     elem = batch[0]
+        ...     if isinstance(elem, CustomType):  # Some custom condition
+        ...         return ...
+        ...     else:  # Fall back to `default_collate`
+        ...         return default_collate(batch)
+        >>> # Option 2: In-place modify `default_collate_fn_map`
+        >>> def collate_customtype_fn(batch, *, collate_fn_map=None):
+        ...     return ...
+        >>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
+        >>> default_collate(batch)  # Handle `CustomType` automatically
+    """
+    return collate(batch, collate_fn_map=default_collate_fn_map)
\ No newline at end of file
diff --git a/mindtorch/utils/data/_utils/fetch.py b/mindtorch/utils/data/_utils/fetch.py
new file mode 100644
index 000000000..e09e151d6
--- /dev/null
+++ b/mindtorch/utils/data/_utils/fetch.py
@@ -0,0 +1,55 @@
+# mypy: allow-untyped-defs
+r"""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch data from an iterable-style or map-style dataset.
+
+This logic is shared in both single- and multi-processing data loading.
+""" + + +class _BaseDatasetFetcher: + def __init__(self, dataset, auto_collation, collate_fn, drop_last): + self.dataset = dataset + self.auto_collation = auto_collation + self.collate_fn = collate_fn + self.drop_last = drop_last + + def fetch(self, possibly_batched_index): + raise NotImplementedError + + +class _IterableDatasetFetcher(_BaseDatasetFetcher): + def __init__(self, dataset, auto_collation, collate_fn, drop_last): + super().__init__(dataset, auto_collation, collate_fn, drop_last) + self.dataset_iter = iter(dataset) + self.ended = False + + def fetch(self, possibly_batched_index): + if self.ended: + raise StopIteration + + if self.auto_collation: + data = [] + for _ in possibly_batched_index: + try: + data.append(next(self.dataset_iter)) + except StopIteration: + self.ended = True + break + if len(data) == 0 or ( + self.drop_last and len(data) < len(possibly_batched_index) + ): + raise StopIteration + else: + data = next(self.dataset_iter) + return self.collate_fn(data) + + +class _MapDatasetFetcher(_BaseDatasetFetcher): + def fetch(self, possibly_batched_index): + if self.auto_collation: + if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: + data = self.dataset.__getitems__(possibly_batched_index) + else: + data = [self.dataset[idx] for idx in possibly_batched_index] + else: + data = self.dataset[possibly_batched_index] + return self.collate_fn(data) \ No newline at end of file diff --git a/mindtorch/utils/data/_utils/pin_memory.py b/mindtorch/utils/data/_utils/pin_memory.py new file mode 100644 index 000000000..7d26bc0df --- /dev/null +++ b/mindtorch/utils/data/_utils/pin_memory.py @@ -0,0 +1,108 @@ +# mypy: allow-untyped-defs +r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory. + +These **needs** to be in global scope since Py2 doesn't support serializing +static methods. +""" + +import collections +import copy +import queue + +import mindtorch +from mindtorch._utils import ExceptionWrapper + +from . import MP_STATUS_CHECK_INTERVAL + + +def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device): + # This setting is thread local, and prevents the copy in pin_memory from + # consuming all CPU cores. + mindtorch.set_num_threads(1) + + mindtorch.multiprocessing._set_thread_name("pt_data_pin") + + if device == "cuda": + mindtorch.cuda.set_device(device_id) + elif device == "xpu": + mindtorch.xpu.set_device(device_id) # type: ignore[attr-defined] + elif device == mindtorch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr(torch, mindtorch._C._get_privateuse1_backend_name()) + custom_device_mod.set_device(device_id) + + def do_one_step(): + try: + r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + return + idx, data = r + if not done_event.is_set() and not isinstance(data, ExceptionWrapper): + try: + data = pin_memory(data, device) + except Exception: + data = ExceptionWrapper( + where=f"in pin memory thread for device {device_id}" + ) + r = (idx, data) + while not done_event.is_set(): + try: + out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL) + break + except queue.Full: + continue + + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. 
+    while not done_event.is_set():
+        # Make sure that we don't preserve any object from one iteration
+        # to the next
+        do_one_step()
+
+
+def pin_memory(data, device=None):
+    if isinstance(data, mindtorch.Tensor):
+        return data.pin_memory(device)
+    elif isinstance(data, (str, bytes)):
+        return data
+    elif isinstance(data, collections.abc.Mapping):
+        try:
+            if isinstance(data, collections.abc.MutableMapping):
+                # The mapping type may have extra properties, so we can't just
+                # use `type(data)(...)` to create the new mapping.
+                # Create a clone and update it if the mapping type is mutable.
+                clone = copy.copy(data)
+                clone.update(
+                    {k: pin_memory(sample, device) for k, sample in data.items()}
+                )
+                return clone
+            else:
+                return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
+        except TypeError:
+            # The mapping type may not support `copy()` / `update(mapping)`
+            # or `__init__(iterable)`.
+            return {k: pin_memory(sample, device) for k, sample in data.items()}
+    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
+        return type(data)(*(pin_memory(sample, device) for sample in data))
+    elif isinstance(data, tuple):
+        return [
+            pin_memory(sample, device) for sample in data
+        ]  # Backwards compatibility.
+    elif isinstance(data, collections.abc.Sequence):
+        try:
+            if isinstance(data, collections.abc.MutableSequence):
+                # The sequence type may have extra properties, so we can't just
+                # use `type(data)(...)` to create the new sequence.
+                # Create a clone and update it if the sequence type is mutable.
+                clone = copy.copy(data)  # type: ignore[arg-type]
+                for i, item in enumerate(data):
+                    clone[i] = pin_memory(item, device)
+                return clone
+            return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
+        except TypeError:
+            # The sequence type may not support `copy()` / `__setitem__(index, item)`
+            # or `__init__(iterable)` (e.g., `range`).
+            return [pin_memory(sample, device) for sample in data]
+    elif hasattr(data, "pin_memory"):
+        return data.pin_memory()
+    else:
+        return data
\ No newline at end of file
diff --git a/mindtorch/distributed/elastic/agent/__init__.py b/mindtorch/utils/data/_utils/signal_handling.py
similarity index 100%
rename from mindtorch/distributed/elastic/agent/__init__.py
rename to mindtorch/utils/data/_utils/signal_handling.py
diff --git a/mindtorch/utils/data/_utils/worker.py b/mindtorch/utils/data/_utils/worker.py
new file mode 100644
index 000000000..1d127162a
--- /dev/null
+++ b/mindtorch/utils/data/_utils/worker.py
@@ -0,0 +1,381 @@
+# mypy: allow-untyped-defs
+r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
+
+These **need** to be in global scope since Py2 doesn't support serializing
+static methods.
+"""
+
+import os
+import queue
+import random
+from dataclasses import dataclass
+from typing import Optional, TYPE_CHECKING, Union
+
+try:
+    from mindspore._c_expression import disable_multi_thread
+except ImportError:
+    disable_multi_thread = None
+
+import mindtorch
+from mindtorch._utils import ExceptionWrapper
+
+from .
import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling + + +if TYPE_CHECKING: + from mindtorch.utils.data import Dataset + +if IS_WINDOWS: + import ctypes + from ctypes.wintypes import BOOL, DWORD, HANDLE + + # On Windows, the parent ID of the worker process remains unchanged when the manager process + # is gone, and the only way to check it through OS is to let the worker have a process handle + # of the manager and ask if the process status has changed. + class ManagerWatchdog: + def __init__(self) -> None: + self.manager_pid = os.getppid() + + # mypy cannot detect this code is windows only + self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined] + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess( + SYNCHRONIZE, 0, self.manager_pid + ) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined] + + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = ( + self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + ) + return not self.manager_dead + +else: + + class ManagerWatchdog: # type: ignore[no-redef] + def __init__(self) -> None: + self.manager_pid = os.getppid() + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead + + +_worker_info: Optional["WorkerInfo"] = None + + +class WorkerInfo: + id: int + num_workers: int + seed: int + dataset: "Dataset" + __initialized = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + self.__keys = tuple(kwargs.keys()) + self.__initialized = True + + def __setattr__(self, key, val): + if self.__initialized: + raise RuntimeError( + f"Cannot assign attributes to {self.__class__.__name__} objects" + ) + return super().__setattr__(key, val) + + def __repr__(self): + items = [f"{k}={getattr(self, k)}" for k in self.__keys] + return f"{self.__class__.__name__}({', '.join(items)})" + + +def get_worker_info() -> Optional[WorkerInfo]: + r"""Returns the information about the current + :class:`~mindtorch.utils.data.DataLoader` iterator worker process. + + When called in a worker, this returns an object guaranteed to have the + following attributes: + + * :attr:`id`: the current worker id. + * :attr:`num_workers`: the total number of workers. + * :attr:`seed`: the random seed set for the current worker. This value is + determined by main process RNG and the worker id. See + :class:`~mindtorch.utils.data.DataLoader`'s documentation for more details. + * :attr:`dataset`: the copy of the dataset object in **this** process. Note + that this will be a different object in a different process than the one + in the main process. + + When called in the main process, this returns ``None``. + + .. 
note:: + When used in a :attr:`worker_init_fn` passed over to + :class:`~mindtorch.utils.data.DataLoader`, this method can be useful to + set up each worker process differently, for instance, using ``worker_id`` + to configure the ``dataset`` object to only read a specific fraction of a + sharded dataset, or use ``seed`` to seed other libraries used in dataset + code. + """ + return _worker_info + + +r"""Dummy class used to signal the end of an IterableDataset""" + + +@dataclass(frozen=True) +class _IterableDatasetStopIteration: + worker_id: int + + +r"""Dummy class used to resume the fetching when worker reuse is enabled""" + + +@dataclass(frozen=True) +class _ResumeIteration: + seed: Optional[int] = None + + +# The function `_generate_state` is adapted from `numpy.random.SeedSequence` +# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx +# It's MIT licensed, here is the copyright: + +# Copyright (c) 2015 Melissa E. O'Neill +# Copyright (c) 2019 NumPy Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +# This function generates an array of int32 as the seed for +# `numpy.random`, in order to prevent state collision due to same +# seed and algorithm for `numpy.random` and `random` modules. +# TODO: Implement `SeedSequence` like object for `mindtorch.random` +def _generate_state(base_seed, worker_id): + INIT_A = 0x43B0D7E5 + MULT_A = 0x931E8875 + INIT_B = 0x8B51F9DD + MULT_B = 0x58F38DED + MIX_MULT_L = 0xCA01F9DD + MIX_MULT_R = 0x4973F715 + XSHIFT = 4 * 8 // 2 + MASK32 = 0xFFFFFFFF + + entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0] + pool = [0] * 4 + + hash_const_A = INIT_A + + def hash(value): + nonlocal hash_const_A + value = (value ^ hash_const_A) & MASK32 + hash_const_A = (hash_const_A * MULT_A) & MASK32 + value = (value * hash_const_A) & MASK32 + value = (value ^ (value >> XSHIFT)) & MASK32 + return value + + def mix(x, y): + result_x = (MIX_MULT_L * x) & MASK32 + result_y = (MIX_MULT_R * y) & MASK32 + result = (result_x - result_y) & MASK32 + result = (result ^ (result >> XSHIFT)) & MASK32 + return result + + # Add in the entropy to the pool. + for i in range(len(pool)): + pool[i] = hash(entropy[i]) + + # Mix all bits together so late bits can affect earlier bits. 
+ for i_src in range(len(pool)): + for i_dst in range(len(pool)): + if i_src != i_dst: + pool[i_dst] = mix(pool[i_dst], hash(pool[i_src])) + + hash_const_B = INIT_B + state = [] + for i_dst in range(4): + data_val = pool[i_dst] + data_val = (data_val ^ hash_const_B) & MASK32 + hash_const_B = (hash_const_B * MULT_B) & MASK32 + data_val = (data_val * hash_const_B) & MASK32 + data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32 + state.append(data_val) + return state + + +def _worker_loop( + dataset_kind, + dataset, + index_queue, + data_queue, + done_event, + auto_collation, + collate_fn, + drop_last, + base_seed, + init_fn, + worker_id, + num_workers, + persistent_workers, + shared_seed, + use_pyboost=True, +): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + mindtorch.configs.set_pyboost(use_pyboost) + try: + # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal had already happened + # again. + # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers + # signal_handling._set_worker_signal_handlers() + + # mindtorch.multiprocessing._set_thread_name("pt_data_worker") + if disable_multi_thread is not None: + disable_multi_thread() + # mindtorch.set_num_threads(1) + seed = base_seed + worker_id + random.seed(seed) + mindtorch.manual_seed(seed) + if HAS_NUMPY: + np_seed = _generate_state(base_seed, worker_id) + import numpy as np + + np.random.seed(np_seed) + + # from mindtorch.utils.data import IterDataPipe + # from mindtorch.utils.data.graph_settings import apply_random_seed + + shared_rng = mindtorch.Generator() + # if isinstance(dataset, IterDataPipe): + # assert shared_seed is not None + # shared_rng.manual_seed(shared_seed) + # dataset = apply_random_seed(dataset, shared_rng) + + global _worker_info + _worker_info = WorkerInfo( + id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset + ) + + from mindtorch.utils.data import _DatasetKind + + init_exception = None + + try: + if init_fn is not None: + init_fn(worker_id) + + fetcher = _DatasetKind.create_fetcher( + dataset_kind, dataset, auto_collation, collate_fn, drop_last + ) + except Exception: + init_exception = ExceptionWrapper( + where=f"in DataLoader worker process {worker_id}" + ) + + # When using Iterable mode, some worker can exit earlier than others due + # to the IterableDataset behaving differently for different workers. + # When such things happen, an `_IterableDatasetStopIteration` object is + # sent over to the main process with the ID of this worker, so that the + # main process won't send more tasks to this worker, and will send + # `None` to this worker to properly exit it. + # + # Note that we cannot set `done_event` from a worker as it is shared + # among all processes. Instead, we set the `iteration_end` flag to + # signify that the iterator is exhausted. When either `done_event` or + # `iteration_end` is set, we skip all processing step and just wait for + # `None`. 
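
A toy sketch of the early-exit protocol described in the comment above (the queueing is elided; `StopSentinel` stands in for `_IterableDatasetStopIteration`):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class StopSentinel:
    worker_id: int

def worker_step(worker_id, dataset_iter, results):
    try:
        results.append(next(dataset_iter))  # a normal fetched sample
    except StopIteration:
        # Exhausted: report our id so the main process stops scheduling us.
        results.append(StopSentinel(worker_id))

results = []
it = iter([10, 20])
worker_step(0, it, results)  # 10
worker_step(0, it, results)  # 20
worker_step(0, it, results)  # StopSentinel(worker_id=0)
print(results)
```
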
+ iteration_end = False + + watchdog = ManagerWatchdog() + + while watchdog.is_alive(): + try: + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + if isinstance(r, _ResumeIteration): + # Acknowledge the main process + data_queue.put((r, None)) + iteration_end = False + + # if isinstance(dataset, IterDataPipe): + # assert r.seed is not None + # shared_rng.manual_seed(r.seed) + # dataset = apply_random_seed(dataset, shared_rng) + + # Recreate the fetcher for worker-reuse policy + fetcher = _DatasetKind.create_fetcher( + dataset_kind, dataset, auto_collation, collate_fn, drop_last + ) + continue + elif r is None: + # Received the final signal + assert done_event.is_set() or iteration_end + break + elif done_event.is_set() or iteration_end: + # `done_event` is set. But I haven't received the final signal + # (None) yet. I will keep continuing until get it, and skip the + # processing steps. + continue + idx, index = r + data: Union[_IterableDatasetStopIteration, ExceptionWrapper] + if init_exception is not None: + data = init_exception + init_exception = None + else: + try: + data = fetcher.fetch(index) # type: ignore[possibly-undefined] + except Exception as e: + if ( + isinstance(e, StopIteration) + and dataset_kind == _DatasetKind.Iterable + ): + data = _IterableDatasetStopIteration(worker_id) + # Set `iteration_end` + # (1) to save future `next(...)` calls, and + # (2) to avoid sending multiple `_IterableDatasetStopIteration`s. + iteration_end = True + else: + # It is important that we don't store exc_info in a variable. + # `ExceptionWrapper` does the correct thing. + # See NOTE [ Python Traceback Reference Cycle Problem ] + data = ExceptionWrapper( + where=f"in DataLoader worker process {worker_id}" + ) + data_queue.put((idx, data)) + del data, idx, index, r # save memory + except KeyboardInterrupt: + # Main process will raise KeyboardInterrupt anyways. + pass + if done_event.is_set(): + data_queue.cancel_join_thread() + data_queue.close() \ No newline at end of file diff --git a/mindtorch/utils/data/dataloader.py b/mindtorch/utils/data/dataloader.py new file mode 100644 index 000000000..6d51d250c --- /dev/null +++ b/mindtorch/utils/data/dataloader.py @@ -0,0 +1,1620 @@ +# mypy: allow-untyped-defs +r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter. + +To support these two classes, in `./_utils` we define many utility methods and +functions to be run in multiprocessing. E.g., the data loading worker loop is +in `./_utils/worker.py`. +""" + +import functools +import itertools +import logging +import multiprocessing as python_multiprocessing +import os +import queue +import threading +import warnings +from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar, Union + +import mindtorch +from mindtorch import distributed as dist +# import mindtorch.utils.data.graph_settings +from mindtorch._utils import ExceptionWrapper +from mindtorch.utils.data import _utils +# from mindtorch.utils.data.datapipes.datapipe import ( +# _IterDataPipeSerializationWrapper, +# _MapDataPipeSerializationWrapper, +# IterDataPipe, +# MapDataPipe, +# ) +from . 
import (
+    IterableDataset,
+    Sampler,
+    SequentialSampler,
+    RandomSampler,
+    BatchSampler,
+    Dataset,
+)
+
+
+__all__ = [
+    "DataLoader",
+    "get_worker_info",
+    "default_collate",
+    "default_convert",
+]
+
+
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+_worker_init_fn_t = Callable[[int], None]
+
+# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
+# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
+# See https://github.com/python/mypy/issues/3737.
+_collate_fn_t = Callable[[List[_T]], Any]
+
+
+# These functions used to be defined in this file. However, they were moved to
+# _utils/collate.py. Although it is rather hard to access this from user land
+# (one has to explicitly `import mindtorch.utils.data.dataloader`), there
+# probably is user code out there using it. This aliasing maintains BC in this
+# aspect.
+default_collate: _collate_fn_t = _utils.collate.default_collate
+default_convert = _utils.collate.default_convert
+
+get_worker_info = _utils.worker.get_worker_info
+
+logger = logging.getLogger(__name__)
+
+
+class _DatasetKind:
+    Map = 0
+    Iterable = 1
+
+    @staticmethod
+    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
+        if kind == _DatasetKind.Map:
+            return _utils.fetch._MapDatasetFetcher(
+                dataset, auto_collation, collate_fn, drop_last
+            )
+        else:
+            return _utils.fetch._IterableDatasetFetcher(
+                dataset, auto_collation, collate_fn, drop_last
+            )
+
+
+class _InfiniteConstantSampler(Sampler):
+    r"""Analogous to ``itertools.repeat(None, None)``.
+
+    Used as sampler for :class:`~mindtorch.utils.data.IterableDataset`.
+    """
+
+    def __iter__(self):
+        while True:
+            yield None
+
+
+def _get_distributed_settings():
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size(), dist.get_rank()
+    else:
+        return 1, 0
+
+
+def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
+    # DataPipe support is currently disabled in this port (see the commented-out
+    # imports above); `IterDataPipe`/`MapDataPipe` and `graph_settings` are not
+    # importable here, so the sharding logic is commented out to avoid NameErrors.
+    # global_worker_id = worker_id
+    # info = mindtorch.utils.data.get_worker_info()
+    # assert info is not None
+    # total_workers = info.num_workers
+    # datapipe = info.dataset
+    # assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
+    # # To distribute elements across distributed processes evenly, we should shard data on distributed
+    # # processes first, then shard on worker processes
+    # total_workers *= world_size
+    # global_worker_id = global_worker_id * world_size + rank_id
+    # # For BC, use default SHARDING_PRIORITIES
+    # mindtorch.utils.data.graph_settings.apply_sharding(
+    #     datapipe, total_workers, global_worker_id
+    # )
+    if worker_init_fn is not None:
+        worker_init_fn(worker_id)
+
+
+def _share_dist_seed(generator, pg):
+    _shared_seed = mindtorch.empty((), dtype=mindtorch.int64).random_(generator=generator)
+    if isinstance(pg, dist.ProcessGroup):
+        dist.broadcast(_shared_seed, src=0, group=pg)
+    return _shared_seed.item()
+
+
+class DataLoader(Generic[_T_co]):
+    r"""
+    Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.
+
+    The :class:`~mindtorch.utils.data.DataLoader` supports both map-style and
+    iterable-style datasets with single- or multi-process loading, customizing
+    loading order and optional automatic batching (collation) and memory pinning.
+
+    See :py:mod:`mindtorch.utils.data` documentation page for more details.
+
+    Args:
+        dataset (Dataset): dataset from which to load the data.
+        batch_size (int, optional): how many samples per batch to load
+            (default: ``1``).
+        shuffle (bool, optional): set to ``True`` to have the data reshuffled
+            at every epoch (default: ``False``).
+        sampler (Sampler or Iterable, optional): defines the strategy to draw
+            samples from the dataset. Can be any ``Iterable`` with ``__len__``
+            implemented. If specified, :attr:`shuffle` must not be specified.
+        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
+            returns a batch of indices at a time. Mutually exclusive with
+            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
+            and :attr:`drop_last`.
+        num_workers (int, optional): how many subprocesses to use for data
+            loading. ``0`` means that the data will be loaded in the main process.
+            (default: ``0``)
+        collate_fn (Callable, optional): merges a list of samples to form a
+            mini-batch of Tensor(s). Used when using batched loading from a
+            map-style dataset.
+        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
+            into device/CUDA pinned memory before returning them. If your data elements
+            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
+            see the example below.
+        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
+            if the dataset size is not divisible by the batch size. If ``False`` and
+            the size of the dataset is not divisible by the batch size, then the last batch
+            will be smaller. (default: ``False``)
+        timeout (numeric, optional): if positive, the timeout value for collecting a batch
+            from workers. Should always be non-negative. (default: ``0``)
+        worker_init_fn (Callable, optional): If not ``None``, this will be called on each
+            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+            input, after seeding and before data loading. (default: ``None``)
+        multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
+            ``None``, the default `multiprocessing context`_ of your operating system will
+            be used. (default: ``None``)
+        generator (mindtorch.Generator, optional): If not ``None``, this RNG will be used
+            by RandomSampler to generate random indices and by multiprocessing to generate
+            ``base_seed`` for workers. (default: ``None``)
+        prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
+            in advance by each worker. ``2`` means there will be a total of
+            2 * num_workers batches prefetched across all workers. (The default
+            depends on ``num_workers``: ``None`` if ``num_workers=0``, otherwise ``2``.)
+        persistent_workers (bool, optional): If ``True``, the data loader will not shut down
+            the worker processes after a dataset has been consumed once. This keeps the
+            worker `Dataset` instances alive. (default: ``False``)
+        pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
+            ``True``.
+        in_order (bool, optional): If ``False``, the data loader will not enforce that batches
+            are returned in a first-in, first-out order. Only applies when ``num_workers > 0``. (default: ``True``)
+
+
+    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
+        cannot be an unpicklable object, e.g., a lambda function. See
+        :ref:`multiprocessing-best-practices` for more details related
+        to multiprocessing in PyTorch.
+
+    .. warning:: The ``len(dataloader)`` heuristic is based on the length of the sampler used.
+        When :attr:`dataset` is an :class:`~mindtorch.utils.data.IterableDataset`,
+        it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
+        rounding depending on :attr:`drop_last`, regardless of multi-process loading
+        configurations. This represents the best guess PyTorch can make because PyTorch
+        trusts user :attr:`dataset` code to correctly handle multi-process
+        loading to avoid duplicate data.
+
+        However, if sharding results in multiple workers having incomplete last batches,
+        this estimate can still be inaccurate, because (1) an otherwise complete batch can
+        be broken into multiple ones and (2) more than one batch worth of samples can be
+        dropped when :attr:`drop_last` is set. Unfortunately, PyTorch cannot detect such
+        cases in general.
+
+        See `Dataset Types`_ for more details on these two types of datasets and how
+        :class:`~mindtorch.utils.data.IterableDataset` interacts with
+        `Multi-process data loading`_.
+
+    .. warning:: See the :ref:`reproducibility`, :ref:`dataloader-workers-random-seed`, and
+        :ref:`data-loading-randomness` notes for random seed related questions.
+
+    .. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data
+        distribution being fed to the trainer in cases with imbalanced data.
+
+    .. _multiprocessing context:
+        https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
+    """
+
+    dataset: Dataset[_T_co]
+    batch_size: Optional[int]
+    num_workers: int
+    pin_memory: bool
+    drop_last: bool
+    timeout: float
+    sampler: Union[Sampler, Iterable]
+    pin_memory_device: str
+    prefetch_factor: Optional[int]
+    _iterator: Optional["_BaseDataLoaderIter"]
+    __initialized = False
+
+    def __init__(
+        self,
+        dataset: Dataset[_T_co],
+        batch_size: Optional[int] = 1,
+        shuffle: Optional[bool] = None,
+        sampler: Union[Sampler, Iterable, None] = None,
+        batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
+        num_workers: int = 0,
+        collate_fn: Optional[_collate_fn_t] = None,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        timeout: float = 0,
+        worker_init_fn: Optional[_worker_init_fn_t] = None,
+        multiprocessing_context=None,
+        generator=None,
+        *,
+        prefetch_factor: Optional[int] = None,
+        persistent_workers: bool = False,
+        pin_memory_device: str = "",
+        in_order: bool = True,
+    ):
+        # mindtorch._C._log_api_usage_once("python.data_loader")
+
+        if num_workers < 0:
+            raise ValueError(
+                "num_workers option should be non-negative; "
+                "use num_workers=0 to disable multiprocessing."
+            )
+
+        if timeout < 0:
+            raise ValueError("timeout option should be non-negative")
+
+        if num_workers == 0 and prefetch_factor is not None:
+            raise ValueError(
+                "prefetch_factor option can only be specified in multiprocessing. "
+                "Let num_workers > 0 to enable multiprocessing; otherwise set prefetch_factor to None."
+ ) + elif num_workers > 0 and prefetch_factor is None: + prefetch_factor = 2 + elif prefetch_factor is not None and prefetch_factor < 0: + raise ValueError("prefetch_factor option should be non-negative") + + if persistent_workers and num_workers == 0: + raise ValueError("persistent_workers option needs num_workers > 0") + + self.dataset = dataset + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.pin_memory = pin_memory + self.pin_memory_device = pin_memory_device + self.timeout = timeout + self.worker_init_fn = worker_init_fn + self.multiprocessing_context = multiprocessing_context + self.in_order = in_order + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler + # if isinstance(self.dataset, IterDataPipe): + # self.dataset = _IterDataPipeSerializationWrapper(self.dataset) + # elif isinstance(self.dataset, MapDataPipe): + # self.dataset = _MapDataPipeSerializationWrapper(self.dataset) + + # Arg-check dataset related before checking samplers because we want to + # tell users that iterable-style datasets are incompatible with custom + # samplers first, so that they don't learn that this combo doesn't work + # after spending time fixing the custom sampler errors. + + if isinstance(dataset, IterableDataset): + self._dataset_kind = _DatasetKind.Iterable + # NOTE [ Custom Samplers and IterableDataset ] + # + # `IterableDataset` does not support custom `batch_sampler` or + # `sampler` since the key is irrelevant (unless we support + # generator-style dataset one day...). + # + # For `sampler`, we always create a dummy sampler. This is an + # infinite sampler even when the dataset may have an implemented + # finite `__len__` because in multi-process data loading, naive + # settings will return duplicated data (which may be desired), and + # thus using a sampler with length matching that of dataset will + # cause data lost (you may have duplicates of the first couple + # batches, but never see anything afterwards). Therefore, + # `Iterabledataset` always uses an infinite sampler, an instance of + # `_InfiniteConstantSampler` defined above. + # + # A custom `batch_sampler` essentially only controls the batch size. + # However, it is unclear how useful it would be since an iterable-style + # dataset can handle that within itself. Moreover, it is pointless + # in multi-process data loading as the assignment order of batches + # to workers is an implementation detail so users can not control + # how to batchify each worker's iterable. Thus, we disable this + # option. If this turns out to be useful in future, we can re-enable + # this, and support custom samplers that specify the assignments to + # specific workers. + # if isinstance(dataset, IterDataPipe): + # if shuffle is not None: + # dataset = mindtorch.utils.data.graph_settings.apply_shuffle_settings( + # dataset, shuffle=shuffle + # ) + # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default. 
+ if shuffle not in {False, None}: + raise ValueError( + f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}" + ) + + if sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}" + ) + elif batch_sampler is not None: + # See NOTE [ Custom Samplers and IterableDataset ] + raise ValueError( + "DataLoader with IterableDataset: expected unspecified " + f"batch_sampler option, but got batch_sampler={batch_sampler}" + ) + else: + shuffle = bool(shuffle) + self._dataset_kind = _DatasetKind.Map + + if sampler is not None and shuffle: + raise ValueError("sampler option is mutually exclusive with " "shuffle") + + if batch_sampler is not None: + # auto_collation with custom batch_sampler + if batch_size != 1 or shuffle or sampler is not None or drop_last: + raise ValueError( + "batch_sampler option is mutually exclusive " + "with batch_size, shuffle, sampler, and " + "drop_last" + ) + batch_size = None + drop_last = False + elif batch_size is None: + # no auto_collation + if drop_last: + raise ValueError( + "batch_size=None option disables auto-batching " + "and is mutually exclusive with drop_last" + ) + + if sampler is None: # give default samplers + if self._dataset_kind == _DatasetKind.Iterable: + # See NOTE [ Custom Samplers and IterableDataset ] + sampler = _InfiniteConstantSampler() + else: # map-style + if shuffle: + sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type] + else: + sampler = SequentialSampler(dataset) # type: ignore[arg-type] + + if batch_size is not None and batch_sampler is None: + # auto_collation without custom batch_sampler + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.batch_size = batch_size + self.drop_last = drop_last + self.sampler = sampler + self.batch_sampler = batch_sampler + self.generator = generator + + if collate_fn is None: + if self._auto_collation: + collate_fn = _utils.collate.default_collate + else: + collate_fn = _utils.collate.default_convert + + self.collate_fn = collate_fn + self.persistent_workers = persistent_workers + + self.__initialized = True + self._IterableDataset_len_called = ( + None # See NOTE [ IterableDataset and __len__ ] + ) + + self._iterator = None + + self.check_worker_number_rationality() + + # mindtorch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] + + def _get_iterator(self) -> "_BaseDataLoaderIter": + if self.num_workers == 0: + return _SingleProcessDataLoaderIter(self) + else: + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIter(self) + + @property + def multiprocessing_context(self): + return self.__multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context): + if multiprocessing_context is not None: + if self.num_workers > 0: + if isinstance(multiprocessing_context, str): + valid_start_methods = mindtorch.multiprocessing.get_all_start_methods() + if multiprocessing_context not in valid_start_methods: + raise ValueError( + "multiprocessing_context option " + f"should specify a valid start method in {valid_start_methods!r}, but got " + f"multiprocessing_context={multiprocessing_context!r}" + ) + multiprocessing_context = mindtorch.multiprocessing.get_context( + multiprocessing_context + ) + + if not isinstance( + multiprocessing_context, python_multiprocessing.context.BaseContext + ): + 
raise TypeError(
+                        "multiprocessing_context option should be a valid context "
+                        "object or a string specifying the start method, but got "
+                        f"multiprocessing_context={multiprocessing_context}"
+                    )
+            else:
+                raise ValueError(
+                    "multiprocessing_context can only be used with "
+                    "multi-process loading (num_workers > 0), but got "
+                    f"num_workers={self.num_workers}"
+                )
+
+        self.__multiprocessing_context = multiprocessing_context
+
+    def __setattr__(self, attr, val):
+        if self.__initialized and attr in (
+            "batch_size",
+            "batch_sampler",
+            "sampler",
+            "drop_last",
+            "dataset",
+            "persistent_workers",
+        ):
+            raise ValueError(
+                f"{attr} attribute should not be set after {self.__class__.__name__} is initialized"
+            )
+
+        super().__setattr__(attr, val)
+
+    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
+    # since '_BaseDataLoaderIter' references 'DataLoader'.
+    def __iter__(self) -> "_BaseDataLoaderIter":
+        # When using a single worker the returned iterator should be
+        # created every time to avoid resetting its state.
+        # However, in the case of a multiple-worker iterator,
+        # the iterator is only created once in the lifetime of the
+        # DataLoader object so that workers can be reused.
+        if self.persistent_workers and self.num_workers > 0:
+            if self._iterator is None:
+                self._iterator = self._get_iterator()
+            else:
+                self._iterator._reset(self)
+            return self._iterator
+        else:
+            return self._get_iterator()
+
+    @property
+    def _auto_collation(self):
+        return self.batch_sampler is not None
+
+    @property
+    def _index_sampler(self):
+        # The actual sampler used for generating indices for `_DatasetFetcher`
+        # (see _utils/fetch.py) to read data at each time. This would be
+        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
+        # We can't change `.sampler` and `.batch_sampler` attributes for BC
+        # reasons.
+        if self._auto_collation:
+            return self.batch_sampler
+        else:
+            return self.sampler
+
+    def __len__(self) -> int:
+        if self._dataset_kind == _DatasetKind.Iterable:
+            # NOTE [ IterableDataset and __len__ ]
+            #
+            # For `IterableDataset`, `__len__` could be inaccurate when one naively
+            # does multi-processing data loading, since the samples will be duplicated.
+            # However, no real use case should be actually using that behavior, so
+            # it should count as a user error. We should generally trust user
+            # code to do the proper thing (e.g., configure each replica differently
+            # in `__iter__`), and give us the correct `__len__` if they choose to
+            # implement it (this will still throw if the dataset does not implement
+            # a `__len__`).
+            #
+            # To provide a further warning, we track if `__len__` was called on the
+            # `DataLoader`, save the returned value in `self._len_called`, and warn
+            # if the iterator ends up yielding more than this number of samples.
+
+            # Cannot statically verify that dataset is Sized
+            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
+            if (
+                self.batch_size is not None
+            ):  # IterableDataset doesn't allow custom sampler or batch_sampler
+                from math import ceil
+
+                if self.drop_last:
+                    length = length // self.batch_size
+                else:
+                    length = ceil(length / self.batch_size)
+            return length
+        else:
+            return len(self._index_sampler)
+
+    def check_worker_number_rationality(self):
+        # This function checks whether the dataloader's worker number is reasonable based on
+        # the current system's resources. 
The current rule is that if the number of workers this
+        # Dataloader will create is bigger than the number of logical cpus that it is allowed to
+        # use, then we pop up a warning to let the user pay attention.
+        #
+        # e.g. If the current system has 2 physical CPUs with 16 cores each, and each core supports 2
+        # threads, then the total number of logical cpus here is 2 * 16 * 2 = 64. Let's say the current
+        # DataLoader process can use half of them, which is 32; then the reasonable max number of
+        # workers initiated from this process is 32.
+        # Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32.
+        # So the warning message is triggered to notify the user to lower the worker number if
+        # necessary.
+        #
+        #
+        # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
+        # available (available on most Linux systems, but not OSX and Windows).
+        # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
+        # it doesn't respect cpuset.
+        # We don't take threading into account since each worker process is single threaded
+        # at this time.
+        #
+        # We don't set any threading flags (e.g. OMP_NUM_THREADS, MKL_NUM_THREADS, etc.)
+        # other than setting `mindtorch.set_num_threads` to 1 in the worker process; if the passed-
+        # in functions use 3rd party modules that rely on those threading flags to determine
+        # how many threads to create (e.g. numpy, etc.), then it is the caller's responsibility to
+        # set those flags correctly.
+        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
+            suggested_max_worker_msg = (
+                (
+                    (
+                        "Our suggested max number of workers in the current system is {}{}, which is smaller "
+                        "than what this DataLoader is going to create."
+                    ).format(
+                        num_worker_suggest,
+                        (
+                            ""
+                            if cpuset_checked
+                            else " (`cpuset` is not taken into account)"
+                        ),
+                    )
+                )
+                if num_worker_suggest is not None
+                else (
+                    "DataLoader is not able to compute a suggested max number of workers in the current system."
+                )
+            )
+
+            warn_msg = (
+                f"This DataLoader will create {num_worker_created} worker processes in total. {suggested_max_worker_msg} "
+                "Please be aware that excessive worker creation might make DataLoader run slowly or even freeze; "
+                "lower the worker number to avoid potential slowness/freeze if necessary."
+            )
+            return warn_msg
+
+        if not self.num_workers or self.num_workers == 0:
+            return
+
+        # try to compute a suggested max number of workers based on the system's resources
+        max_num_worker_suggest = None
+        cpuset_checked = False
+        if hasattr(os, "sched_getaffinity"):
+            try:
+                max_num_worker_suggest = len(os.sched_getaffinity(0))
+                cpuset_checked = True
+            except Exception:
+                pass
+        if max_num_worker_suggest is None:
+            # os.cpu_count() could return Optional[int]
+            # get cpu count first and check None in order to satisfy mypy check
+            cpu_count = os.cpu_count()
+            if cpu_count is not None:
+                max_num_worker_suggest = cpu_count
+
+        if max_num_worker_suggest is None:
+            warnings.warn(
+                _create_warning_msg(
+                    max_num_worker_suggest, self.num_workers, cpuset_checked
+                )
+            )
+            return
+
+        if self.num_workers > max_num_worker_suggest:
+            warnings.warn(
+                _create_warning_msg(
+                    max_num_worker_suggest, self.num_workers, cpuset_checked
+                )
+            )
+
+
+class _BaseDataLoaderIter:
+    def __init__(self, loader: DataLoader) -> None:
+        self._dataset = loader.dataset
+        self._shared_seed = None
+        self._pg = None
+        # if isinstance(self._dataset, IterDataPipe):
+        #     if dist.is_available() and dist.is_initialized():
+        #         self._pg = dist.new_group(backend="gloo")
+        #     self._shared_seed = _share_dist_seed(loader.generator, self._pg)
+        #     shared_rng = mindtorch.Generator()
+        #     shared_rng.manual_seed(self._shared_seed)
+        #     self._dataset = mindtorch.utils.data.graph_settings.apply_random_seed(
+        #         self._dataset, shared_rng
+        #     )
+        self._dataset_kind = loader._dataset_kind
+        self._IterableDataset_len_called = loader._IterableDataset_len_called
+        self._auto_collation = loader._auto_collation
+        self._drop_last = loader.drop_last
+        self._index_sampler = loader._index_sampler
+        self._num_workers = loader.num_workers
+        ws, rank = _get_distributed_settings()
+        self._world_size = ws
+        self._rank = rank
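
As an aside on the worker-count heuristic implemented in `check_worker_number_rationality` above: prefer the process's CPU affinity mask when `os.sched_getaffinity` exists (it respects cpusets), and fall back to `os.cpu_count()` otherwise. A self-contained sketch of just that computation:

```python
import os

def suggested_max_workers():
    if hasattr(os, "sched_getaffinity"):
        try:
            # Number of CPUs this process may actually run on.
            return len(os.sched_getaffinity(0))
        except Exception:
            pass
    return os.cpu_count()  # may be None on exotic platforms

print(suggested_max_workers())
```
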
+        # For other backends, pin_memory_device needs to be set; if it is not set,
+        # the default behaviour is the CUDA device. If pin_memory_device is set
+        # but pin_memory is not, the default behaviour is False.
+        if len(loader.pin_memory_device) == 0:
+            self._pin_memory = loader.pin_memory and mindtorch.cuda.is_available()
+            self._pin_memory_device = None
+        else:
+            if not loader.pin_memory:
+                warn_msg = (
+                    "pin_memory_device is set but the pin_memory flag is False, so device pinned memory won't be used. "
+                    "Please set pin_memory to True if you need to use device pinned memory."
+                )
+                warnings.warn(warn_msg)
+
+            self._pin_memory = loader.pin_memory
+            self._pin_memory_device = loader.pin_memory_device
+        self._timeout = loader.timeout
+        self._collate_fn = loader.collate_fn
+        self._sampler_iter = iter(self._index_sampler)
+        self._base_seed = (
+            mindtorch.empty((), dtype=mindtorch.int64)
+            .random_(generator=loader.generator)
+            .item()
+        )
+        self._persistent_workers = loader.persistent_workers
+        self._num_yielded = 0
+        self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"
+
+    def __iter__(self) -> "_BaseDataLoaderIter":
+        return self
+
+    def _reset(self, loader, first_iter=False):
+        self._sampler_iter = iter(self._index_sampler)
+        self._num_yielded = 0
+        self._IterableDataset_len_called = loader._IterableDataset_len_called
+        # if isinstance(self._dataset, IterDataPipe):
+        #     self._shared_seed = _share_dist_seed(loader.generator, self._pg)
+        #     shared_rng = mindtorch.Generator()
+        #     shared_rng.manual_seed(self._shared_seed)
+        #     self._dataset = mindtorch.utils.data.graph_settings.apply_random_seed(
+        #         self._dataset, shared_rng
+        #     )
+
+    def _next_index(self):
+        return next(self._sampler_iter)  # may raise StopIteration
+
+    def _next_data(self):
+        raise NotImplementedError
+
+    def __next__(self) -> Any:
+        # with mindtorch.autograd.profiler.record_function(self._profile_name):
+        if self._sampler_iter is None:
+            # TODO(https://github.com/pytorch/pytorch/issues/76750)
+            self._reset()  # type: ignore[call-arg]
+        data = self._next_data()
+        self._num_yielded += 1
+        if (
+            self._dataset_kind == _DatasetKind.Iterable
+            and self._IterableDataset_len_called is not None
+            and self._num_yielded > self._IterableDataset_len_called
+        ):
+            warn_msg = (
+                f"Length of IterableDataset {self._dataset} was reported to be {self._IterableDataset_len_called} "
+                f"(when accessing len(dataloader)), but {self._num_yielded} samples have been fetched. "
+            )
+            if self._num_workers > 0:
+                warn_msg += (
+                    "For multiprocessing data-loading, this could be caused by not properly configuring the "
+                    "IterableDataset replica at each worker. Please see "
+                    "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
+                )
+            warnings.warn(warn_msg)
+        return data
+
+    def __len__(self) -> int:
+        return len(self._index_sampler)
+
+    def __getstate__(self):
+        # TODO: add limited pickling support for sharing an iterator
+        # across multiple threads for HOGWILD.
+        # Probably the best way to do this is by moving the sample pushing
+        # to a separate thread and then just sharing the data queue,
+        # but signalling the end is tricky without a non-blocking API.
+        raise NotImplementedError(f"{self.__class__.__name__} cannot be pickled")
+
+
+class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
+    def __init__(self, loader):
+        super().__init__(loader)
+        assert self._timeout == 0
+        assert self._num_workers == 0
+
+        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
+        #   Taking care of distributed sharding
+        # if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
+        #     # For BC, use default SHARDING_PRIORITIES
+        #     mindtorch.utils.data.graph_settings.apply_sharding(
+        #         self._dataset, self._world_size, self._rank
+        #     )
+
+        self._dataset_fetcher = _DatasetKind.create_fetcher(
+            self._dataset_kind,
+            self._dataset,
+            self._auto_collation,
+            self._collate_fn,
+            self._drop_last,
+        )
+
+    def _next_data(self):
+        index = self._next_index()  # may raise StopIteration
+        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
+        if self._pin_memory:
+            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
+        return data
+
+
+class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
+    r"""Iterates once over the DataLoader's dataset, as specified by the sampler."""
+
+    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
+    #
+    # Preliminary:
+    #
+    # Our data model looks like this (queues are indicated with curly brackets):
+    #
+    #                main process                              ||
+    #                     |                                    ||
+    #               {index_queue}                              ||
+    #                     |                                    ||
+    #              worker processes                            ||     DATA
+    #                     |                                    ||
+    #            {worker_result_queue}                         ||     FLOW
+    #                     |                                    ||
+    #      pin_memory_thread of main process                   ||   DIRECTION
+    #                     |                                    ||
+    #               {data_queue}                               ||
+    #                     |                                    ||
+    #                data output                               \/
+    #
+    # P.S. the `worker_result_queue` and `pin_memory_thread` part may be omitted if
+    #      `pin_memory=False`.
+    #
+    #
+    # Terminating multiprocessing logic requires very careful design. In
+    # particular, we need to make sure that
+    #
+    #   1. The iterator gracefully exits the workers when its last reference is
+    #      gone or it is depleted.
+    #
+    #      In this case, the workers should be gracefully exited because the
+    #      main process may still need to continue to run, and we want cleaning
+    #      up code in the workers to be executed (e.g., releasing GPU memory).
+    #      Naturally, we implement the shutdown logic in `__del__` of
+    #      DataLoaderIterator.
+    #
+    #      We delay the discussion on the logic in this case until later.
+    #
+    #   2. The iterator exits the workers when the loader process and/or worker
+    #      processes exit normally or with error.
+    #
+    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
+    #
+    #      You may ask, why can't we make the workers non-daemonic, and
+    #      gracefully exit using the same logic as we have in `__del__` when the
+    #      iterator gets deleted (see 1 above)?
+    #
+    #      First of all, `__del__` is **not** guaranteed to be called when the
+    #      interpreter exits. Even if it is called, by the time it executes,
+    #      many Python core library resources may already be freed, and even
+    #      simple things like acquiring an internal lock of a queue may hang.
+    #      Therefore, in this case, we actually need to prevent `__del__` from
+    #      being executed, and rely on the automatic termination of daemonic
+    #      children.
+    #
+    #      Thus, we register an `atexit` hook that sets a global flag
+    #      `_utils.python_exit_status`.
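
A minimal sketch of the `atexit` flag mechanism referenced above (this mirrors `_set_python_exit_flag` from `_utils/__init__.py` in this diff; `safe_cleanup` is a hypothetical consumer):

```python
import atexit

python_exit_status = False

def _set_python_exit_flag():
    global python_exit_status
    python_exit_status = True

atexit.register(_set_python_exit_flag)

def safe_cleanup():
    if python_exit_status:
        return  # interpreter is shutting down; don't touch mp internals
    # ... normal teardown work ...
```
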
+    #      Since `atexit` hooks are executed in the reverse order of
+    #      registration, we are guaranteed that this flag is set before library
+    #      resources we use are freed (which, at least in CPython, is done via
+    #      an `atexit` handler defined in
+    #      `multiprocessing/util.py`
+    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
+    #      registered when an object requiring this mechanism is first
+    #      created, e.g., `mp.Queue`
+    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
+    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
+    #      )
+    #
+    #      So in `__del__`, we check if `_utils.python_exit_status` is set or
+    #      `None` (freed), and perform no-op if so.
+    #
+    #      However, simply letting library clean-up codes run can also be bad,
+    #      because such codes (i.e., `multiprocessing.util._exit_function()`)
+    #      include joining putting threads for `mp.Queue`, which can be blocking.
+    #      Hence, the queues that the main process puts into have
+    #      `cancel_join_thread` called on them at creation. See later section
+    #      [ 3b. A process won't hang when putting into a queue; ]
+    #      for more details.
+    #
+    #      Here are two example cases where library clean-up codes can run
+    #      before `__del__` is called:
+    #
+    #        1. If we hold onto a reference to the iterator, it more often
+    #           than not tries to do `multiprocessing` library cleaning before
+    #           clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
+    #           and thus prevents our cleaning-up code from running first.
+    #
+    #        2. A similar issue arises when a `DataLoader` is used in a subprocess.
+    #           When a process ends, it shuts all its daemonic children
+    #           down with a SIGTERM (instead of joining them without a timeout).
+    #           Similarly for threads, but by a different mechanism. This fact,
+    #           together with a few implementation details of multiprocessing, forces
+    #           us to make workers daemonic. All of our problems arise when a
+    #           DataLoader is used in a subprocess, and are caused by multiprocessing
+    #           code which looks more or less like this:
+    #
+    #               try:
+    #                   your_function_using_a_dataloader()
+    #               finally:
+    #                   multiprocessing.util._exit_function()
+    #
+    #           The joining/termination mentioned above happens inside
+    #           `_exit_function()`. Now, if `your_function_using_a_dataloader()`
+    #           throws, the stack trace stored in the exception will prevent the
+    #           frame which uses `DataLoaderIter` from being freed. If the frame has any
+    #           reference to the `DataLoaderIter` (e.g., in a method of the iter),
+    #           its `__del__`, which starts the shutdown procedure, will not be
+    #           called. That, in turn, means that workers aren't notified. Attempting
+    #           to join in `_exit_function` will then result in a hang.
+    #
+    #           For context, `_exit_function` is also registered as an `atexit` call.
+    #           So it is unclear to me (@ssnl) why this is needed in a finally block.
+    #           The code dates back to 2008 and there is no comment on the original
+    #           PEP 371 or patch https://bugs.python.org/issue3050 (containing both
+    #           the finally block and the `atexit` registration) that explains this.
+    #
+    #
+    #      Finally, another choice is to just shutdown workers with logic in 1
+    #      above whenever we see an error in `next`. This isn't ideal because
+    #        a. It prevents users from using try-catch to resume data loading.
+    #        b. It doesn't prevent hanging if users have references to the
+    #           iterator.
+    #
+    #   3. All processes exit if any of them die unexpectedly by fatal signals.
+    #
+    #      As shown above, the workers are set as daemonic children of the main
+    #      process. However, automatic cleaning-up of such child processes only
+    #      happens if the parent process exits gracefully (e.g., not via fatal
+    #      signals like SIGKILL). So we must ensure that each process will exit
+    #      even if the process that should send/receive data to/from it was
+    #      killed, i.e.,
+    #
+    #        a. A process won't hang when getting from a queue.
+    #
+    #           Even with carefully designed data dependencies (i.e., a `put()`
+    #           always corresponding to a `get()`), hanging on `get()` can still
+    #           happen when data in queue is corrupted (e.g., due to
+    #           `cancel_join_thread` or unexpected exit).
+    #
+    #           For child exit, we set a timeout whenever we try to get data
+    #           from `data_queue`, and check the workers' status on each timeout
+    #           and error.
+    #           See `_MultiProcessingDataLoaderIter._get_data()` and
+    #           `_MultiProcessingDataLoaderIter._try_get_data()` for details.
+    #
+    #           Additionally, for child exit on non-Windows platforms, we also
+    #           register a SIGCHLD handler (which is not supported on Windows) on
+    #           the main process, which checks if any of the workers fail in the
+    #           (Python) handler. This is more efficient and faster in detecting
+    #           worker failures, compared to only using the above mechanism.
+    #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
+    #
+    #           For `.get()` calls where the sender(s) is not the workers, we
+    #           guard them with timeouts, and check the status of the sender
+    #           when a timeout happens:
+    #             + in the workers, the `_utils.worker.ManagerWatchdog` class
+    #               checks the status of the main process.
+    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
+    #               check `pin_memory_thread` status periodically until `.get()`
+    #               returns or see that `pin_memory_thread` died.
+    #
+    #        b. A process won't hang when putting into a queue;
+    #
+    #           We use `mp.Queue` which has a separate background thread to put
+    #           objects from an unbounded buffer array. The background thread is
+    #           daemonic and usually automatically joined when the process
+    #           *exits*.
+    #
+    #           In case the receiver has ended abruptly while
+    #           reading from the pipe, the join will hang forever. The usual
+    #           solution for this in Python is calling `q.cancel_join_thread`,
+    #           which prevents automatically joining it when finalizing
+    #           (exiting).
+    #
+    #           Nonetheless, `cancel_join_thread` must only be called when the
+    #           queue is **not** going to be read from or written into by another
+    #           process, because it may hold onto a lock or leave corrupted data
+    #           in the queue, leading other readers/writers to hang.
+    #
+    #           Hence,
+    #             + For worker processes, we only do so (for their output
+    #               queues, i.e., `worker_result_queue`) before exiting.
+    #             + For `pin_memory_thread`, its output queue `data_queue` is a
+    #               `queue.Queue` that does blocking `put` if the queue is full.
+    #               So there is no above problem, but as a result, in
+    #               `_pin_memory_loop`, we do need to wrap the `put` in a loop
+    #               that breaks not only upon success, but also when the main
+    #               process stops reading, i.e., is shutting down.
+    #             + For loader process, we `cancel_join_thread()` for all
+    #               `_index_queues` because the whole purpose of workers and
+    #               `pin_memory_thread` is to serve the loader process. If
+    #               loader process is already exiting, we don't really care if
+    #               the queues are corrupted.
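+    #
+    #           As a concrete (simplified) sketch of the guarded get described
+    #           in (a) -- the names here are illustrative; the real logic lives
+    #           in `_try_get_data()` below:
+    #
+    #               while True:
+    #                   try:
+    #                       return data_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+    #                   except queue.Empty:
+    #                       if any(not w.is_alive() for w in workers):
+    #                           raise RuntimeError("DataLoader worker exited unexpectedly")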
+    #
+    #
+    # Now let's get back to 1:
+    #   how we gracefully exit the workers when the last reference to the
+    #   iterator is gone.
+    #
+    # To achieve this, we implement the following logic along with the design
+    # choices mentioned above:
+    #
+    # `workers_done_event`:
+    #   A `multiprocessing.Event` shared among the main process and all worker
+    #   processes. This is used to signal the workers that the iterator is
+    #   shutting down. After it is set, they will not send processed data to
+    #   queues anymore, and only wait for the final `None` before exiting.
+    #   `done_event` isn't strictly needed. I.e., we can just check for `None`
+    #   from the input queue, but it allows us to skip wasting resources
+    #   processing data if we are already shutting down.
+    #
+    # `pin_memory_thread_done_event`:
+    #   A `threading.Event` for a similar purpose to that of
+    #   `workers_done_event`, but is for the `pin_memory_thread`. The reason
+    #   that separate events are needed is that `pin_memory_thread` reads from
+    #   the output queue of the workers. But the workers, upon seeing that
+    #   `workers_done_event` is set, only want to see the final `None`, and are
+    #   not required to flush all data in the output queue (e.g., a worker may
+    #   call `cancel_join_thread` on that queue if its `IterableDataset` iterator
+    #   happens to exhaust coincidentally, which is out of the control of the
+    #   main process). Thus, since we will exit `pin_memory_thread` before the
+    #   workers (see below), two separate events are used.
+    #
+    # NOTE: In short, the protocol is that the main process will set these
+    #       `done_event`s, then send each corresponding process/thread a `None`,
+    #       and that they may exit at any time after receiving the `None`.
+    #
+    # NOTE: Using `None` as the final signal is valid, since normal data will
+    #       always be a 2-tuple with the 1st element being the index of the data
+    #       transferred (different from dataset index/key), and the 2nd being
+    #       either the dataset key or the data sample (depending on which part
+    #       of the data model the queue is at).
+    #
+    # [ worker processes ]
+    #   While loader process is alive:
+    #     Get from `index_queue`.
+    #       If got a `None`, exit.
+    #       If got anything else,
+    #          Check `workers_done_event`.
+    #            If set, continue to next iteration
+    #                    i.e., keep getting until see the `None`, then exit.
+    #            Otherwise, process data:
+    #                If is fetching from an `IterableDataset` and the iterator
+    #                    is exhausted, send an `_IterableDatasetStopIteration`
+    #                    object to signal iteration end. The main process, upon
+    #                    receiving such an object, will send `None` to this
+    #                    worker and not use the corresponding `index_queue`
+    #                    anymore.
+    #       If timed out,
+    #          No matter `workers_done_event` is set (still need to see `None`)
+    #          or not, must continue to next iteration.
+    #   (outside loop)
+    #   If `workers_done_event` is set, (this can be False with `IterableDataset`)
+    #     `data_queue.cancel_join_thread()`.  (Everything is ending here:
+    #                                          main process won't read from it;
+    #                                          other workers will also call
+    #                                          `cancel_join_thread`.)
+    #
+    # [ pin_memory_thread ]
+    #   # No need to check main thread. If this thread is alive, the main loader
+    #   # thread must be alive, because this thread is set as daemonic.
+    #   While `pin_memory_thread_done_event` is not set:
+    #     Get from `worker_result_queue`.
+    #       If timed out, continue to get in the next iteration.
+    #       Otherwise, process data.
+    #       While `pin_memory_thread_done_event` is not set:
+    #         Put processed data to `data_queue` (a `queue.Queue` with blocking put)
+    #         If timed out, continue to put in the next iteration.
+    #         Otherwise, break, i.e., continuing to the outer loop.
+    #
+    # NOTE: we don't check the status of the main thread because
+    #         1. if the process is killed by fatal signal, `pin_memory_thread`
+    #            ends.
+    #         2. in other cases, either the cleaning-up in __del__ or the
+    #            automatic exit of daemonic thread will take care of it.
+    #            This won't busy-wait either because `.get(timeout)` does not
+    #            busy-wait.
+    #
+    # [ main process ]
+    #   In the DataLoader Iter's `__del__`
+    #
+    #     a. Exit `pin_memory_thread`
+    #          i.   Set `pin_memory_thread_done_event`.
+    #          ii.  Put `None` in `worker_result_queue`.
+    #          iii. Join the `pin_memory_thread`.
+    #          iv.  `worker_result_queue.cancel_join_thread()`.
+    #
+    #     b. Exit the workers.
+    #          i.   Set `workers_done_event`.
+    #          ii.  Put `None` in each worker's `index_queue`.
+    #          iii. Join the workers.
+    #          iv.  Call `.cancel_join_thread()` on each worker's `index_queue`.
+    #
+    #        NOTE: (b) is better placed after (a) because it may leave corrupted
+    #              data in `worker_result_queue`, which `pin_memory_thread`
+    #              reads from, in which case exiting the `pin_memory_thread` can
+    #              only happen at timing out, which is slow. Nonetheless, the same
+    #              thing happens if a worker is killed by signal at unfortunate
+    #              times, but in other cases, we are better off having a
+    #              non-corrupted `worker_result_queue` for `pin_memory_thread`.
+    #
+    #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (a)
+    #         can be omitted.
+    #
+    #   NB: `done_event`s aren't strictly needed. E.g., we can just check for
+    #       `None` from `index_queue`, but it allows us to skip wasting resources
+    #       processing indices already in `index_queue` if we are already shutting
+    #       down.
+
+    def __init__(self, loader):
+        super().__init__(loader)
+
+        self._prefetch_factor = loader.prefetch_factor
+        self._in_order = loader.in_order
+
+        assert self._num_workers > 0
+        assert self._prefetch_factor > 0
+
+        if loader.multiprocessing_context is None:
+            multiprocessing_context = mindtorch.multiprocessing
+        else:
+            multiprocessing_context = loader.multiprocessing_context
+
+        self._worker_init_fn = loader.worker_init_fn
+
+        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
+        #   Additional worker init function will take care of sharding in MP and Distributed
+        # if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
+        #     self._worker_init_fn = functools.partial(
+        #         _sharding_worker_init_fn,
+        #         self._worker_init_fn,
+        #         self._world_size,
+        #         self._rank,
+        #     )
+
+        # No certainty which module multiprocessing_context is
+        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
+        self._worker_pids_set = False
+        self._shutdown = False
+        self._workers_done_event = multiprocessing_context.Event()
+
+        self._index_queues = []
+        self._workers = []
+        for i in range(self._num_workers):
+            # No certainty which module multiprocessing_context is
+            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
+            # Need to `cancel_join_thread` here!
+            # See sections (2) and (3b) above.
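+            # `index_queue` is written only by this (the loader) process; if the
+            # loader is dying anyway, we would rather abandon the queue's feeder
+            # thread than block interpreter exit on joining it (see (3b) above).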
+            index_queue.cancel_join_thread()
+            w = multiprocessing_context.Process(
+                target=_utils.worker._worker_loop,
+                args=(
+                    self._dataset_kind,
+                    self._dataset,
+                    index_queue,
+                    self._worker_result_queue,
+                    self._workers_done_event,
+                    self._auto_collation,
+                    self._collate_fn,
+                    self._drop_last,
+                    self._base_seed,
+                    self._worker_init_fn,
+                    i,
+                    self._num_workers,
+                    self._persistent_workers,
+                    self._shared_seed,
+                    False,
+                ),
+            )
+            w.daemon = True
+            # NB: Process.start() actually takes some time as it needs to
+            #     start a process and pass the arguments over via a pipe.
+            #     Therefore, we only add a worker to self._workers list after
+            #     it started, so that we do not call .join() if program dies
+            #     before it starts, and __del__ tries to join but will get:
+            #     AssertionError: can only join a started process.
+            w.start()
+            self._index_queues.append(index_queue)
+            self._workers.append(w)
+
+        if self._pin_memory:
+            self._pin_memory_thread_done_event = threading.Event()
+
+            # Queue is not type-annotated
+            self._data_queue = queue.Queue()  # type: ignore[var-annotated]
+            if self._pin_memory_device == "xpu":
+                current_device = mindtorch.xpu.current_device()  # type: ignore[attr-defined]
+            elif self._pin_memory_device == mindtorch._C._get_privateuse1_backend_name():
+                custom_device_mod = getattr(
+                    mindtorch, mindtorch._C._get_privateuse1_backend_name()
+                )
+                current_device = custom_device_mod.current_device()
+            else:
+                current_device = mindtorch.cuda.current_device()  # default to cuda
+            pin_memory_thread = threading.Thread(
+                target=_utils.pin_memory._pin_memory_loop,
+                args=(
+                    self._worker_result_queue,
+                    self._data_queue,
+                    current_device,
+                    self._pin_memory_thread_done_event,
+                    self._pin_memory_device,
+                ),
+            )
+            pin_memory_thread.daemon = True
+            pin_memory_thread.start()
+            # Similar to workers (see comment above), we only register
+            # pin_memory_thread once it is started.
+            self._pin_memory_thread = pin_memory_thread
+        else:
+            self._data_queue = self._worker_result_queue  # type: ignore[assignment]
+
+        # In some rare cases, persistent workers (daemonic processes)
+        # would be terminated before `__del__` of iterator is invoked
+        # when main process exits.
+        # It would cause failure when pin_memory_thread tries to read
+        # corrupted data from worker_result_queue.
+        # atexit is used to shutdown thread and child processes in the
+        # right sequence before main process exits.
+        if self._persistent_workers and self._pin_memory:
+            import atexit
+
+            for w in self._workers:
+                atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
+
+        # .pid can be None only before process is spawned (not the case, so ignore)
+        # _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
+        # _utils.signal_handling._set_SIGCHLD_handler()
+        self._worker_pids_set = True
+        self._reset(loader, first_iter=True)
+
+    def _reset(self, loader, first_iter=False):
+        super()._reset(loader, first_iter)
+        self._send_idx = 0  # idx of the next task to be sent to workers
+        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
+        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
+        #   map: task idx => - (worker_id,)       if data isn't fetched (outstanding)
+        #                    \ (worker_id, data)  if data is already fetched (out-of-order)
+        self._task_info = {}
+        self._tasks_outstanding = (
+            0  # always equal to count(v for v in task_info.values() if len(v) == 1)
+        )
+        # A list of booleans representing whether each worker still has work to
+        # do, i.e., not having exhausted its iterable dataset object. It always
+        # contains all `True`s if not using an iterable-style dataset
+        # (i.e., if kind != Iterable).
+        # Note that this indicates that a worker still has work to do *for this epoch*.
+        # It does not mean that a worker is dead. In case of `_persistent_workers`,
+        # the worker will be reset to available in the next epoch.
+        self._workers_status = [True for i in range(self._num_workers)]
+        # Reset the worker queue cycle so it resumes next epoch at worker 0
+        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
+        # We resume the prefetching in case it was enabled
+        if not first_iter:
+            for idx in range(self._num_workers):
+                self._index_queues[idx].put(
+                    _utils.worker._ResumeIteration(self._shared_seed)
+                )
+            resume_iteration_cnt = self._num_workers
+            while resume_iteration_cnt > 0:
+                return_idx, return_data = self._get_data()
+                if isinstance(return_idx, _utils.worker._ResumeIteration):
+                    assert return_data is None
+                    resume_iteration_cnt -= 1
+        # prime the prefetch loop
+        for _ in range(self._prefetch_factor * self._num_workers):
+            self._try_put_index()
+
+    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
+        # Tries to fetch data from `self._data_queue` once for a given timeout.
+        # This can also be used as inner loop of fetching without timeout, with
+        # the sender status as the loop condition.
+        #
+        # This raises a `RuntimeError` if any worker died unexpectedly. This error
+        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
+        # (only for non-Windows platforms), or the manual check below on errors
+        # and timeouts.
+        #
+        # Returns a 2-tuple:
+        #   (bool: whether data was successfully fetched, any: the data if successful, else None)
+        try:
+            data = self._data_queue.get(timeout=timeout)
+            return (True, data)
+        except Exception as e:
+            # At timeout and error, we manually check whether any worker has
+            # failed. Note that this is the only mechanism for Windows to detect
+            # worker failures.
+            failed_workers = []
+            for worker_id, w in enumerate(self._workers):
+                if self._workers_status[worker_id] and not w.is_alive():
+                    failed_workers.append(w)
+                    self._mark_worker_as_unavailable(worker_id)
+            if len(failed_workers) > 0:
+                pids_str = ", ".join(str(w.pid) for w in failed_workers)
+                raise RuntimeError(
+                    f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly"
+                ) from e
+            if isinstance(e, queue.Empty):
+                return (False, None)
+
+            import errno
+            import tempfile
+
+            try:
+                # Raise an exception if we are this close to the FDs limit.
+                # Apparently, trying to open only one file is not a sufficient
+                # test.
+                # See NOTE [ DataLoader on Linux and open files limit ]
+                fds_limit_margin = 10
+                [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
+            except OSError as e:
+                if e.errno == errno.EMFILE:
+                    raise RuntimeError(
+                        "Too many open files. Communication with the"
+                        " workers is no longer possible. Please increase the"
+                        " limit using `ulimit -n` in the shell or change the"
+                        " sharing strategy by calling"
+                        " `mindtorch.multiprocessing.set_sharing_strategy('file_system')`"
+                        " at the beginning of your code"
+                    ) from None
+                raise
+
+    # NOTE [ DataLoader on Linux and open files limit ]
+    #
+    # On Linux when DataLoader is used with multiprocessing we pass the data between
+    # the root process and the workers through SHM files. We remove those files from
+    # the filesystem as soon as they are created and keep them alive by
+    # passing around their file descriptors through AF_UNIX sockets. (See
+    # docs/source/multiprocessing.rst and `Multiprocessing Technical Notes`
+    # in the wiki (https://github.com/pytorch/pytorch/wiki).)
+    #
+    # This sometimes leads us to exceeding the open files limit. When that happens,
+    # and the offending file descriptor is coming over a socket, the `socket` Python
+    # package silently strips the file descriptor from the message, setting only the
+    # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
+    # it _indicates that some control data were discarded due to lack of space in
+    # the buffer for ancillary data_). This might reflect the C implementation of
+    # AF_UNIX sockets.
+    #
+    # This behaviour can be reproduced with the script and instructions at the
+    # bottom of this note.
+    #
+    # When that happens, the standard Python `multiprocessing` (and not
+    # `mindtorch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
+    #
+    # Sometimes, instead of the FD being stripped, you may get an `OSError:
+    # Too many open files`, both in the script below and in DataLoader. However,
+    # this is rare and seems to be nondeterministic.
+    #
+    #
+    #   #!/usr/bin/env python3
+    #   import sys
+    #   import socket
+    #   import os
+    #   import array
+    #   import shutil
+    #
+    #
+    #   if len(sys.argv) != 4:
+    #       print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
+    #       sys.exit(1)
+    #
+    #   if __name__ == '__main__':
+    #       dirname = sys.argv[1]
+    #       sock_path = dirname + "/sock"
+    #       iterations = int(sys.argv[2])
+    #       def dummy_path(i):
+    #           return dirname + "/" + str(i) + ".dummy"
+    #
+    #
+    #       if sys.argv[3] == 'send':
+    #           while not os.path.exists(sock_path):
+    #               pass
+    #           client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+    #           client.connect(sock_path)
+    #           for i in range(iterations):
+    #               fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
+    #               ancdata = array.array('i', [fd])
+    #               msg = bytes([i % 256])
+    #               print("Sending fd ", fd, " (iteration #", i, ")")
+    #               client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
+    #
+    #
+    #       else:
+    #           assert sys.argv[3] == 'recv'
+    #
+    #           if os.path.exists(dirname):
+    #               raise Exception("Directory exists")
+    #
+    #           os.mkdir(dirname)
+    #
+    #           print("Opening socket...")
+    #           server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+    #           server.bind(sock_path)
+    #
+    #           print("Listening...")
+    #           for i in range(iterations):
+    #               a = array.array('i')
+    #               msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
+    #               assert(len(ancdata) == 1)
+    #               cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+    #               a.frombytes(cmsg_data)
+    #               print("Received fd ", a[0], " (iteration #", i, ")")
+    #
+    #           shutil.rmtree(dirname)
+    #
+    # Steps to reproduce:
+    #
+    # 1. Run two shells and set lower file descriptor limit in the receiving one:
+    #    (shell1) ulimit -n 1020
+    #    (shell2) ulimit -n 1022
+    #
+    # 2. Run the script above with the `recv` option in the first shell:
+    #    (shell1) ./test_socket.py sock_tmp 1017 recv
+    #
+    # 3. Run the script with the `send` option in the second shell:
+    #    (shell2) ./test_socket.py sock_tmp 1017 send
+
+    def _get_data(self):
+        # Fetches data from `self._data_queue`.
+        #
+        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
+        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
+        # in a loop. This is the only mechanism to detect worker failures for
+        # Windows. For other platforms, a SIGCHLD handler is also used for
+        # worker failure detection.
+        #
+        # If `pin_memory=True`, we also need to check if `pin_memory_thread` had
+        # died at timeouts.
+        if self._timeout > 0:
+            success, data = self._try_get_data(self._timeout)
+            if success:
+                return data
+            else:
+                raise RuntimeError(
+                    f"DataLoader timed out after {self._timeout} seconds"
+                )
+        elif self._pin_memory:
+            while self._pin_memory_thread.is_alive():
+                success, data = self._try_get_data()
+                if success:
+                    return data
+            else:
+                # while condition is false, i.e., pin_memory_thread died.
+                raise RuntimeError("Pin memory thread exited unexpectedly")
+            # In this case, `self._data_queue` is a `queue.Queue`. But we don't
+            # need to call `.task_done()` because we don't use `.join()`.
+        else:
+            while True:
+                success, data = self._try_get_data()
+                if success:
+                    return data
+
+    def _next_data(self):
+        while True:
+            # If the worker responsible for `self._rcvd_idx` has already ended
+            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
+            # we try to advance `self._rcvd_idx` to find the next valid index.
+            #
+            # This part needs to run in the loop because both the `self._get_data()`
+            # call and the `_IterableDatasetStopIteration` check below can mark
+            # extra worker(s) as dead.
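+            # (Reminder: `_task_info[idx]` holds `(worker_id,)` while the task is
+            # outstanding, and `(worker_id, data)` once its result has arrived out
+            # of order, so `len(info) == 2` below means the sample is ready.)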
+            while self._rcvd_idx < self._send_idx:
+                info = self._task_info.get(self._rcvd_idx, None)
+                if info:
+                    worker_id = info[0]
+                    if (
+                        len(info) == 2 or self._workers_status[worker_id]
+                    ):  # has data or is still active
+                        break
+                del self._task_info[self._rcvd_idx]
+                self._rcvd_idx += 1
+            else:
+                # no valid `self._rcvd_idx` is found (i.e., didn't break)
+                if not self._persistent_workers:
+                    self._shutdown_workers()
+                raise StopIteration
+
+            # Now `self._rcvd_idx` is the batch index we want to fetch
+
+            # Check if the next sample has already been generated
+            if len(self._task_info[self._rcvd_idx]) == 2:
+                data = self._task_info.pop(self._rcvd_idx)[1]
+                self._rcvd_idx += 1
+                return self._process_data(data)
+
+            assert not self._shutdown and self._tasks_outstanding > 0
+            idx, data = self._get_data()
+            self._tasks_outstanding -= 1
+            if self._dataset_kind == _DatasetKind.Iterable:
+                # Check for _IterableDatasetStopIteration
+                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
+                    if self._persistent_workers:
+                        self._workers_status[data.worker_id] = False
+                    else:
+                        self._mark_worker_as_unavailable(data.worker_id)
+                    self._try_put_index()
+                    continue
+
+            if idx != self._rcvd_idx:
+                if not self._in_order:
+                    # don't store it for later, process now
+                    del self._task_info[idx]
+                    return self._process_data(data)
+                # store out-of-order samples
+                self._task_info[idx] += (data,)
+            else:
+                del self._task_info[idx]
+                self._rcvd_idx += 1
+                return self._process_data(data)
+
+    def _try_put_index(self):
+        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
+
+        try:
+            index = self._next_index()
+        except StopIteration:
+            return
+        for _ in range(self._num_workers):  # find the next active worker, if any
+            worker_queue_idx = next(self._worker_queue_idx_cycle)
+            if self._workers_status[worker_queue_idx]:
+                break
+        else:
+            # not found (i.e., didn't break)
+            return
+
+        self._index_queues[worker_queue_idx].put((self._send_idx, index))  # type: ignore[possibly-undefined]
+        self._task_info[self._send_idx] = (worker_queue_idx,)
+        self._tasks_outstanding += 1
+        self._send_idx += 1
+
+    def _process_data(self, data):
+        self._try_put_index()
+        if isinstance(data, ExceptionWrapper):
+            data.reraise()
+        return data
+
+    def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
+        # Mark a worker as having finished its work e.g., due to
+        # exhausting an `IterableDataset`. This should be used only when this
+        # `_MultiProcessingDataLoaderIter` is going to continue running.
+
+        assert self._workers_status[worker_id] or (
+            self._persistent_workers and shutdown
+        )
+
+        # Signal termination to that specific worker.
+        q = self._index_queues[worker_id]
+        # Indicate that no more data will be put on this queue by the current
+        # process.
+        q.put(None)
+
+        # Note that we don't actually join the worker here, nor do we remove the
+        # worker's pid from C side struct because (1) joining may be slow, and
+        # (2) since we don't join, the worker may still raise error, and we
+        # prefer capturing those, rather than ignoring them, even though they
+        # are raised after the worker has finished its job.
+        # Joining is deferred to `_shutdown_workers`, which is called when
+        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
+        # when this iterator is garbage collected.
+
+        self._workers_status[worker_id] = False
+
+        assert self._workers_done_event.is_set() == shutdown
+
+    def _shutdown_workers(self):
+        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
+        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
+        # the logic of this function.
+        if (
+            _utils is None
+            or _utils.python_exit_status is True
+            or _utils.python_exit_status is None
+        ):
+            # See (2) of the note. If Python is shutting down, do no-op.
+            return
+        # Normal exit when last reference is gone / iterator is depleted.
+        # See (1) and the second half of the note.
+        if not self._shutdown:
+            self._shutdown = True
+            try:
+                # Exit `pin_memory_thread` first because exiting workers may leave
+                # corrupted data in `worker_result_queue` which `pin_memory_thread`
+                # reads from.
+                if hasattr(self, "_pin_memory_thread"):
+                    # Use hasattr in case error happens before we set the attribute.
+                    self._pin_memory_thread_done_event.set()
+                    # Send something to pin_memory_thread in case it is waiting
+                    # so that it can wake up and check `pin_memory_thread_done_event`
+                    self._worker_result_queue.put((None, None))
+                    self._pin_memory_thread.join()
+                    self._worker_result_queue.cancel_join_thread()
+                    self._worker_result_queue.close()
+
+                # Exit workers now.
+                self._workers_done_event.set()
+                for worker_id in range(len(self._workers)):
+                    # Get number of workers from `len(self._workers)` instead of
+                    # `self._num_workers` in case we error before starting all
+                    # workers.
+                    # If we are using workers_status with persistent_workers
+                    # we have to shut it down because the worker is paused.
+                    if self._persistent_workers or self._workers_status[worker_id]:
+                        self._mark_worker_as_unavailable(worker_id, shutdown=True)
+                for w in self._workers:
+                    # We should be able to join here, but in case anything went
+                    # wrong, we set a timeout and if the workers fail to join,
+                    # they are killed in the `finally` block.
+                    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+                for q in self._index_queues:
+                    q.cancel_join_thread()
+                    q.close()
+            finally:
+                # Even though all this function does is putting into queues that
+                # we have called `cancel_join_thread` on, weird things can
+                # happen when a worker is killed by a signal, e.g., hanging in
+                # `Event.set()`. So we need to guard this with the SIGCHLD handler,
+                # and remove pids from the C side data structure only at the
+                # end.
+                #
+                # FIXME: Unfortunately, for Windows, we are missing a worker
+                #        error detection mechanism here in this function, as it
+                #        doesn't provide a SIGCHLD handler.
+                if self._worker_pids_set:
+                    # _utils.signal_handling._remove_worker_pids(id(self))
+                    self._worker_pids_set = False
+                for w in self._workers:
+                    if w.is_alive():
+                        # Existing mechanisms try to make the workers exit
+                        # peacefully, but in case that we unfortunately reach
+                        # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
+                        # we kill the worker.
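+                        # Last resort: terminate the straggler. `_clean_up_worker`
+                        # below applies the same join-then-terminate fallback for
+                        # persistent workers at interpreter exit (via atexit).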
+                        w.terminate()
+
+    # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
+    @staticmethod
+    def _clean_up_worker(w):
+        try:
+            w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+        finally:
+            if w.is_alive():
+                w.terminate()
+
+    def __del__(self):
+        self._shutdown_workers()
\ No newline at end of file
diff --git a/mindtorch/utils/data/dataset.py b/mindtorch/utils/data/dataset.py
new file mode 100644
index 000000000..f460b91f6
--- /dev/null
+++ b/mindtorch/utils/data/dataset.py
@@ -0,0 +1,489 @@
+# mypy: allow-untyped-defs
+import bisect
+import itertools
+import math
+import warnings
+from typing import (
+    cast,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+)
+from typing_extensions import deprecated
+
+# No 'default_generator' in torch/__init__.pyi
+from ... import default_generator, Generator, Tensor
+from ...ops import randperm
+
+__all__ = [
+    "Dataset",
+    "IterableDataset",
+    "TensorDataset",
+    "StackDataset",
+    "ConcatDataset",
+    "ChainDataset",
+    "Subset",
+    "random_split",
+]
+
+
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+_T_dict = Dict[str, _T_co]
+_T_tuple = Tuple[_T_co, ...]
+_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict)
+
+
+class Dataset(Generic[_T_co]):
+    r"""An abstract class representing a :class:`Dataset`.
+
+    All datasets that represent a map from keys to data samples should subclass
+    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
+    data sample for a given key. Subclasses could also optionally overwrite
+    :meth:`__len__`, which is expected to return the size of the dataset by many
+    :class:`~mindtorch.utils.data.Sampler` implementations and the default options
+    of :class:`~mindtorch.utils.data.DataLoader`. Subclasses could also
+    optionally implement :meth:`__getitems__` to speed up batched sample
+    loading. This method accepts a list of sample indices for a batch and
+    returns a list of samples.
+
+    .. note::
+      :class:`~mindtorch.utils.data.DataLoader` by default constructs an index
+      sampler that yields integral indices. To make it work with a map-style
+      dataset with non-integral indices/keys, a custom sampler must be provided.
+    """
+
+    def __getitem__(self, index) -> _T_co:
+        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
+
+    # def __getitems__(self, indices: List) -> List[_T_co]:
+    # Not implemented to prevent false-positives in fetcher check in
+    # mindtorch.utils.data._utils.fetch._MapDatasetFetcher
+
+    def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
+        return ConcatDataset([self, other])
+
+    # No `def __len__(self)` default?
+    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
+    # in pytorch/torch/utils/data/sampler.py
+
+
+class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
+    r"""An iterable Dataset.
+
+    All datasets that represent an iterable of data samples should subclass it.
+    Such form of datasets is particularly useful when data come from a stream.
+
+    All subclasses should overwrite :meth:`__iter__`, which would return an
+    iterator of samples in this dataset.
+
+    When a subclass is used with :class:`~mindtorch.utils.data.DataLoader`, each
+    item in the dataset will be yielded from the :class:`~mindtorch.utils.data.DataLoader`
+    iterator.
+    When :attr:`num_workers > 0`, each worker process will have a
+    different copy of the dataset object, so it is often desired to configure
+    each copy independently to avoid having duplicate data returned from the
+    workers. :func:`~mindtorch.utils.data.get_worker_info`, when called in a worker
+    process, returns information about the worker. It can be used in either the
+    dataset's :meth:`__iter__` method or the :class:`~mindtorch.utils.data.DataLoader` 's
+    :attr:`worker_init_fn` option to modify each copy's behavior.
+
+    Example 1: splitting workload across all workers in :meth:`__iter__`::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
+        >>> # xdoctest: +SKIP("Fails on MacOS12")
+        >>> class MyIterableDataset(mindtorch.utils.data.IterableDataset):
+        ...     def __init__(self, start, end):
+        ...         super().__init__()
+        ...         assert end > start, "this example code only works with end > start"
+        ...         self.start = start
+        ...         self.end = end
+        ...
+        ...     def __iter__(self):
+        ...         worker_info = mindtorch.utils.data.get_worker_info()
+        ...         if worker_info is None:  # single-process data loading, return the full iterator
+        ...             iter_start = self.start
+        ...             iter_end = self.end
+        ...         else:  # in a worker process
+        ...             # split workload
+        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
+        ...             worker_id = worker_info.id
+        ...             iter_start = self.start + worker_id * per_worker
+        ...             iter_end = min(iter_start + per_worker, self.end)
+        ...         return iter(range(iter_start, iter_end))
+        ...
+        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
+        >>> ds = MyIterableDataset(start=3, end=7)
+
+        >>> # Single-process loading
+        >>> print(list(mindtorch.utils.data.DataLoader(ds, num_workers=0)))
+        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
+
+        >>> # xdoctest: +REQUIRES(POSIX)
+        >>> # Multi-process loading with two worker processes
+        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
+        >>> # xdoctest: +IGNORE_WANT("non deterministic")
+        >>> print(list(mindtorch.utils.data.DataLoader(ds, num_workers=2)))
+        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
+
+        >>> # With even more workers
+        >>> # xdoctest: +IGNORE_WANT("non deterministic")
+        >>> print(list(mindtorch.utils.data.DataLoader(ds, num_workers=12)))
+        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
+
+    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
+        >>> class MyIterableDataset(mindtorch.utils.data.IterableDataset):
+        ...     def __init__(self, start, end):
+        ...         super().__init__()
+        ...         assert end > start, "this example code only works with end > start"
+        ...         self.start = start
+        ...         self.end = end
+        ...
+        ...     def __iter__(self):
+        ...         return iter(range(self.start, self.end))
+        ...
+        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
+        >>> ds = MyIterableDataset(start=3, end=7)
+
+        >>> # Single-process loading
+        >>> print(list(mindtorch.utils.data.DataLoader(ds, num_workers=0)))
+        [3, 4, 5, 6]
+        >>>
+        >>> # Directly doing multi-process loading yields duplicate data
+        >>> print(list(mindtorch.utils.data.DataLoader(ds, num_workers=2)))
+        [3, 3, 4, 4, 5, 5, 6, 6]
+
+        >>> # Define a `worker_init_fn` that configures each dataset copy differently
+        >>> def worker_init_fn(worker_id):
+        ...     worker_info = mindtorch.utils.data.get_worker_info()
+        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
+        ...
+        ...     overall_start = dataset.start
+        ...     overall_end = dataset.end
+        ...     # configure the dataset to only process the split workload
+        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
+        ...     worker_id = worker_info.id
+        ...     dataset.start = overall_start + worker_id * per_worker
+        ...     dataset.end = min(dataset.start + per_worker, overall_end)
+        ...
+
+        >>> # Multi-process loading with the custom `worker_init_fn`
+        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
+        >>> print(list(mindtorch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
+        [3, 5, 4, 6]
+
+        >>> # With even more workers
+        >>> print(list(mindtorch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
+        [3, 4, 5, 6]
+    """
+
+    def __add__(self, other: Dataset[_T_co]):
+        return ChainDataset([self, other])
+
+    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
+    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
+
+
+class TensorDataset(Dataset[Tuple[Tensor, ...]]):
+    r"""Dataset wrapping tensors.
+
+    Each sample will be retrieved by indexing tensors along the first dimension.
+
+    Args:
+        *tensors (Tensor): tensors that have the same size in the first dimension.
+    """
+
+    tensors: Tuple[Tensor, ...]
+
+    def __init__(self, *tensors: Tensor) -> None:
+        assert all(
+            tensors[0].size(0) == tensor.size(0) for tensor in tensors
+        ), "Size mismatch between tensors"
+        self.tensors = tensors
+
+    def __getitem__(self, index):
+        return tuple(tensor[index] for tensor in self.tensors)
+
+    def __len__(self):
+        return self.tensors[0].size(0)
+
+
+class StackDataset(Dataset[_T_stack]):
+    r"""Dataset as a stacking of multiple datasets.
+
+    This class is useful to assemble different parts of complex input data, given as datasets.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> images = ImageDataset()
+        >>> texts = TextDataset()
+        >>> tuple_stack = StackDataset(images, texts)
+        >>> tuple_stack[0] == (images[0], texts[0])
+        >>> dict_stack = StackDataset(image=images, text=texts)
+        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}
+
+    Args:
+        *args (Dataset): Datasets for stacking returned as tuple.
+        **kwargs (Dataset): Datasets for stacking returned as dict.
+    """
+
+    datasets: Union[tuple, dict]
+
+    def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None:
+        if args:
+            if kwargs:
+                raise ValueError(
+                    "Supported either ``tuple``- (via ``args``) or "
+                    "``dict``- (via ``kwargs``) like input/output, but both types are given."
+                )
+            self._length = len(args[0])  # type: ignore[arg-type]
+            if any(self._length != len(dataset) for dataset in args):  # type: ignore[arg-type]
+                raise ValueError("Size mismatch between datasets")
+            self.datasets = args
+        elif kwargs:
+            tmp = list(kwargs.values())
+            self._length = len(tmp[0])  # type: ignore[arg-type]
+            if any(self._length != len(dataset) for dataset in tmp):  # type: ignore[arg-type]
+                raise ValueError("Size mismatch between datasets")
+            self.datasets = kwargs
+        else:
+            raise ValueError("At least one dataset should be passed")
+
+    def __getitem__(self, index):
+        if isinstance(self.datasets, dict):
+            return {k: dataset[index] for k, dataset in self.datasets.items()}
+        return tuple(dataset[index] for dataset in self.datasets)
+
+    def __getitems__(self, indices: list):
+        # add batched sampling support when parent datasets support it.
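+        # Each nested dataset that exposes `__getitems__` is asked for the whole
+        # batch in a single call; any other dataset is indexed one sample at a
+        # time. The per-dataset results are then regrouped per output sample.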
+ if isinstance(self.datasets, dict): + dict_batch: List[_T_dict] = [{} for _ in indices] + for k, dataset in self.datasets.items(): + if callable(getattr(dataset, "__getitems__", None)): + items = dataset.__getitems__(indices) # type: ignore[attr-defined] + if len(items) != len(indices): + raise ValueError( + "Nested dataset's output size mismatch." + f" Expected {len(indices)}, got {len(items)}" + ) + for data, d_sample in zip(items, dict_batch): + d_sample[k] = data + else: + for idx, d_sample in zip(indices, dict_batch): + d_sample[k] = dataset[idx] + return dict_batch + + # tuple data + list_batch: List[list] = [[] for _ in indices] + for dataset in self.datasets: + if callable(getattr(dataset, "__getitems__", None)): + items = dataset.__getitems__(indices) # type: ignore[attr-defined] + if len(items) != len(indices): + raise ValueError( + "Nested dataset's output size mismatch." + f" Expected {len(indices)}, got {len(items)}" + ) + for data, t_sample in zip(items, list_batch): + t_sample.append(data) + else: + for idx, t_sample in zip(indices, list_batch): + t_sample.append(dataset[idx]) + tuple_batch: List[_T_tuple] = [tuple(sample) for sample in list_batch] + return tuple_batch + + def __len__(self): + return self._length + + +class ConcatDataset(Dataset[_T_co]): + r"""Dataset as a concatenation of multiple datasets. + + This class is useful to assemble different existing datasets. + + Args: + datasets (sequence): List of datasets to be concatenated + """ + + datasets: List[Dataset[_T_co]] + cumulative_sizes: List[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + self.datasets = list(datasets) + assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type] + for d in self.datasets: + assert not isinstance( + d, IterableDataset + ), "ConcatDataset does not support IterableDataset" + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError( + "absolute value of index should not exceed dataset length" + ) + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + @deprecated( + "`cummulative_sizes` attribute is renamed to `cumulative_sizes`", + category=FutureWarning, + ) + def cummulative_sizes(self): + return self.cumulative_sizes + + +class ChainDataset(IterableDataset): + r"""Dataset for chaining multiple :class:`IterableDataset` s. + + This class is useful to assemble different existing dataset streams. The + chaining operation is done on-the-fly, so concatenating large-scale + datasets with this class will be efficient. 
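+
+    A minimal usage sketch (``stream_a`` and ``stream_b`` stand in for any two
+    :class:`IterableDataset` instances)::
+
+        chained = ChainDataset([stream_a, stream_b])
+        for sample in chained:  # yields all of stream_a, then all of stream_b
+            ...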
+ + Args: + datasets (iterable of IterableDataset): datasets to be chained together + """ + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + self.datasets = datasets + + def __iter__(self): + for d in self.datasets: + assert isinstance( + d, IterableDataset + ), "ChainDataset only supports IterableDataset" + yield from d + + def __len__(self): + total = 0 + for d in self.datasets: + assert isinstance( + d, IterableDataset + ), "ChainDataset only supports IterableDataset" + total += len(d) # type: ignore[arg-type] + return total + + +class Subset(Dataset[_T_co]): + r""" + Subset of a dataset at specified indices. + + Args: + dataset (Dataset): The whole Dataset + indices (sequence): Indices in the whole set selected for subset + """ + + dataset: Dataset[_T_co] + indices: Sequence[int] + + def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None: + self.dataset = dataset + self.indices = indices + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.dataset[[self.indices[i] for i in idx]] + return self.dataset[self.indices[idx]] + + def __getitems__(self, indices: List[int]) -> List[_T_co]: + # add batched sampling support when parent dataset supports it. + # see mindtorch.utils.data._utils.fetch._MapDatasetFetcher + if callable(getattr(self.dataset, "__getitems__", None)): + return self.dataset.__getitems__([self.indices[idx] for idx in indices]) # type: ignore[attr-defined] + else: + return [self.dataset[self.indices[idx]] for idx in indices] + + def __len__(self): + return len(self.indices) + + +def random_split( + dataset: Dataset[_T], + lengths: Sequence[Union[int, float]], + generator: Optional[Generator] = default_generator, +) -> List[Subset[_T]]: + r""" + Randomly split a dataset into non-overlapping new datasets of given lengths. + + If a list of fractions that sum up to 1 is given, + the lengths will be computed automatically as + floor(frac * len(dataset)) for each fraction provided. + + After computing the lengths, if there are any remainders, 1 count will be + distributed in round-robin fashion to the lengths + until there are no remainders left. + + Optionally fix the generator for reproducible results, e.g.: + + Example: + >>> # xdoctest: +SKIP + >>> generator1 = mindtorch.Generator().manual_seed(42) + >>> generator2 = mindtorch.Generator().manual_seed(42) + >>> random_split(range(10), [3, 7], generator=generator1) + >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2) + + Args: + dataset (Dataset): Dataset to be split + lengths (sequence): lengths or fractions of splits to be produced + generator (Generator): Generator used for the random permutation. + """ + if math.isclose(sum(lengths), 1) and sum(lengths) <= 1: + subset_lengths: List[int] = [] + for i, frac in enumerate(lengths): + if frac < 0 or frac > 1: + raise ValueError(f"Fraction at index {i} is not between 0 and 1") + n_items_in_split = int( + math.floor(len(dataset) * frac) # type: ignore[arg-type] + ) + subset_lengths.append(n_items_in_split) + remainder = len(dataset) - sum(subset_lengths) # type: ignore[arg-type] + # add 1 to all the lengths in round-robin fashion until the remainder is 0 + for i in range(remainder): + idx_to_add_at = i % len(subset_lengths) + subset_lengths[idx_to_add_at] += 1 + lengths = subset_lengths + for i, length in enumerate(lengths): + if length == 0: + warnings.warn( + f"Length of split at index {i} is 0. " + f"This might result in an empty dataset." 
+ ) + + # Cannot verify that dataset is Sized + if sum(lengths) != len(dataset): # type: ignore[arg-type] + raise ValueError( + "Sum of input lengths does not equal the length of the input dataset!" + ) + + indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload] + lengths = cast(Sequence[int], lengths) + return [ + Subset(dataset, indices[offset - length : offset]) + for offset, length in zip(itertools.accumulate(lengths), lengths) + ] \ No newline at end of file diff --git a/mindtorch/utils/data/distributed.py b/mindtorch/utils/data/distributed.py new file mode 100644 index 000000000..e274a0a54 --- /dev/null +++ b/mindtorch/utils/data/distributed.py @@ -0,0 +1,150 @@ +import math +from collections.abc import Iterator +from typing import Optional, TypeVar + +import mindtorch +from mindtorch import distributed as dist +from mindtorch.utils.data.dataset import Dataset +from mindtorch.utils.data.sampler import Sampler + + +__all__ = ["DistributedSampler"] + + +_T_co = TypeVar("_T_co", covariant=True) + + +class DistributedSampler(Sampler[_T_co]): + r"""Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`mindtorch.nn.parallel.DistributedDataParallel`. In such a case, each + process can pass a :class:`~mindtorch.utils.data.DistributedSampler` instance as a + :class:`~mindtorch.utils.data.DataLoader` sampler, and load a subset of the + original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size and that any instance of it always + returns the same elements in the same order. + + Args: + dataset: Dataset used for sampling. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, :attr:`world_size` is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True`` (default), sampler will shuffle the + indices. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``False``. + + .. warning:: + In distributed mode, calling the :meth:`set_epoch` method at + the beginning of each epoch **before** creating the :class:`DataLoader` iterator + is necessary to make shuffling work properly across multiple epochs. Otherwise, + the same ordering will be always used. + + Example:: + + >>> # xdoctest: +SKIP + >>> sampler = DistributedSampler(dataset) if is_distributed else None + >>> loader = DataLoader(dataset, shuffle=(sampler is None), + ... sampler=sampler) + >>> for epoch in range(start_epoch, n_epochs): + ... if is_distributed: + ... sampler.set_epoch(epoch) + ... 
train(loader) + """ + + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" + ) + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[_T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = mindtorch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = mindtorch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r""" + Set the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch \ No newline at end of file diff --git a/mindtorch/utils/data/sampler.py b/mindtorch/utils/data/sampler.py new file mode 100644 index 000000000..717485c1f --- /dev/null +++ b/mindtorch/utils/data/sampler.py @@ -0,0 +1,353 @@ +# mypy: allow-untyped-defs +import itertools +from typing import ( + Generic, + Iterable, + Iterator, + List, + Optional, + Sequence, + Sized, + TypeVar, + Union, +) + +import mindtorch + + +__all__ = [ + "BatchSampler", + "RandomSampler", + "Sampler", + "SequentialSampler", + "SubsetRandomSampler", + "WeightedRandomSampler", +] + + +_T_co = TypeVar("_T_co", covariant=True) + + +class Sampler(Generic[_T_co]): + r"""Base class for all Samplers. 
+
+    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
+    way to iterate over indices or lists of indices (batches) of dataset elements,
+    and may provide a :meth:`__len__` method that returns the length of the returned iterators.
+
+    Args:
+        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
+            You may still have custom implementation that utilizes it.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> class AscendingSequenceLengthSampler(Sampler[int]):
+        >>>     def __init__(self, data: List[str]) -> None:
+        >>>         self.data = data
+        >>>
+        >>>     def __len__(self) -> int:
+        >>>         return len(self.data)
+        >>>
+        >>>     def __iter__(self) -> Iterator[int]:
+        >>>         sizes = mindtorch.tensor([len(x) for x in self.data])
+        >>>         yield from mindtorch.argsort(sizes).tolist()
+        >>>
+        >>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
+        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
+        >>>         self.data = data
+        >>>         self.batch_size = batch_size
+        >>>
+        >>>     def __len__(self) -> int:
+        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
+        >>>
+        >>>     def __iter__(self) -> Iterator[List[int]]:
+        >>>         sizes = mindtorch.tensor([len(x) for x in self.data])
+        >>>         for batch in mindtorch.chunk(mindtorch.argsort(sizes), len(self)):
+        >>>             yield batch.tolist()
+
+    .. note:: The :meth:`__len__` method isn't strictly required by
+              :class:`~mindtorch.utils.data.DataLoader`, but is expected in any
+              calculation involving the length of a :class:`~mindtorch.utils.data.DataLoader`.
+    """
+
+    def __init__(self, data_source: Optional[Sized] = None) -> None:
+        if data_source is not None:
+            import warnings
+
+            warnings.warn(
+                "`data_source` argument is not used and will be removed in 2.2.0. "
+                "You may still have custom implementation that utilizes it."
+            )
+
+    def __iter__(self) -> Iterator[_T_co]:
+        raise NotImplementedError
+
+    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
+    #
+    # Many times we have an abstract class representing a collection/iterable of
+    # data, e.g., `mindtorch.utils.data.Sampler`, with its subclasses optionally
+    # implementing a `__len__` method. In such cases, we must make sure to not
+    # provide a default implementation, because both straightforward default
+    # implementations have their issues:
+    #
+    #   + `return NotImplemented`:
+    #     Calling `len(subclass_instance)` raises:
+    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
+    #
+    #   + `raise NotImplementedError`:
+    #     This prevents triggering some fallback behavior. E.g., the built-in
+    #     `list(X)` tries to call `len(X)` first, and executes a different code
+    #     path if the method is not found or `NotImplemented` is returned, while
+    #     raising a `NotImplementedError` will propagate and make the call fail
+    #     where it could have used `__iter__` to complete the call.
+    #
+    # Thus, the only two sensible things to do are
+    #
+    #   + **not** provide a default `__len__`.
+    #
+    #   + raise a `TypeError` instead, which is what Python uses when users call
+    #     a method that is not defined on an object.
+    #     (@ssnl verifies that this works on at least Python 3.7.)
+
+
+class SequentialSampler(Sampler[int]):
+    r"""Samples elements sequentially, always in the same order.
+
+
+class SequentialSampler(Sampler[int]):
+    r"""Samples elements sequentially, always in the same order.
+
+    Args:
+        data_source (Dataset): dataset to sample from
+    """
+
+    data_source: Sized
+
+    def __init__(self, data_source: Sized) -> None:
+        self.data_source = data_source
+
+    def __iter__(self) -> Iterator[int]:
+        return iter(range(len(self.data_source)))
+
+    def __len__(self) -> int:
+        return len(self.data_source)
+
+
+class RandomSampler(Sampler[int]):
+    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+
+    If with replacement, then the user can specify :attr:`num_samples` to draw.
+
+    Args:
+        data_source (Dataset): dataset to sample from
+        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
+        num_samples (int): number of samples to draw, default=`len(dataset)`.
+        generator (Generator): Generator used in sampling.
+    """
+
+    data_source: Sized
+    replacement: bool
+
+    def __init__(
+        self,
+        data_source: Sized,
+        replacement: bool = False,
+        num_samples: Optional[int] = None,
+        generator=None,
+    ) -> None:
+        self.data_source = data_source
+        self.replacement = replacement
+        self._num_samples = num_samples
+        self.generator = generator
+
+        if not isinstance(self.replacement, bool):
+            raise TypeError(
+                f"replacement should be a boolean value, but got replacement={self.replacement}"
+            )
+
+        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+            raise ValueError(
+                f"num_samples should be a positive integer value, but got num_samples={self.num_samples}"
+            )
+
+    @property
+    def num_samples(self) -> int:
+        # dataset size might change at runtime
+        if self._num_samples is None:
+            return len(self.data_source)
+        return self._num_samples
+
+    def __iter__(self) -> Iterator[int]:
+        n = len(self.data_source)
+        if self.generator is None:
+            seed = int(mindtorch.empty((), dtype=mindtorch.int64).random_().item())
+            generator = mindtorch.Generator()
+            generator.manual_seed(seed)
+        else:
+            generator = self.generator
+
+        if self.replacement:
+            for _ in range(self.num_samples // 32):
+                yield from mindtorch.randint(
+                    high=n, size=(32,), dtype=mindtorch.int64, generator=generator
+                ).tolist()
+            yield from mindtorch.randint(
+                high=n,
+                size=(self.num_samples % 32,),
+                dtype=mindtorch.int64,
+                generator=generator,
+            ).tolist()
+        else:
+            for _ in range(self.num_samples // n):
+                yield from mindtorch.randperm(n, generator=generator).tolist()
+            yield from mindtorch.randperm(n, generator=generator).tolist()[
+                : self.num_samples % n
+            ]
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+
+class SubsetRandomSampler(Sampler[int]):
+    r"""Samples elements randomly from a given list of indices, without replacement.
+
+    Args:
+        indices (sequence): a sequence of indices
+        generator (Generator): Generator used in sampling.
+    """
+
+    indices: Sequence[int]
+
+    def __init__(self, indices: Sequence[int], generator=None) -> None:
+        self.indices = indices
+        self.generator = generator
+
+    def __iter__(self) -> Iterator[int]:
+        for i in mindtorch.randperm(len(self.indices), generator=self.generator):
+            yield self.indices[i]
+
+    def __len__(self) -> int:
+        return len(self.indices)
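For context (not part of the diff), `SubsetRandomSampler` is the usual building block for reproducible train/validation splits. A sketch under the assumption that `DataLoader` accepts the familiar `sampler=` keyword; the list-of-ints dataset is a stand-in for a real `Dataset`:

```python
import mindtorch
from mindtorch.utils.data import DataLoader, SubsetRandomSampler

dataset = list(range(100))          # stand-in for a real Dataset
g = mindtorch.Generator()
g.manual_seed(0)                    # fixed seed -> the split is reproducible
perm = mindtorch.randperm(len(dataset), generator=g).tolist()
train_idx, val_idx = perm[:80], perm[80:]

train_loader = DataLoader(dataset, batch_size=16,
                          sampler=SubsetRandomSampler(train_idx))
val_loader = DataLoader(dataset, batch_size=16,
                        sampler=SubsetRandomSampler(val_idx))
```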
+
+
+class WeightedRandomSampler(Sampler[int]):
+    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
+
+    Args:
+        weights (sequence): a sequence of weights, not necessarily summing up to one
+        num_samples (int): number of samples to draw
+        replacement (bool): if ``True``, samples are drawn with replacement.
+            If not, they are drawn without replacement, which means that when a
+            sample index is drawn for a row, it cannot be drawn again for that row.
+        generator (Generator): Generator used in sampling.
+
+    Example:
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
+        [4, 4, 1, 4, 5]
+        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
+        [0, 1, 4, 3, 2]
+    """
+
+    weights: mindtorch.Tensor
+    num_samples: int
+    replacement: bool
+
+    def __init__(
+        self,
+        weights: Sequence[float],
+        num_samples: int,
+        replacement: bool = True,
+        generator=None,
+    ) -> None:
+        if (
+            not isinstance(num_samples, int)
+            or isinstance(num_samples, bool)
+            or num_samples <= 0
+        ):
+            raise ValueError(
+                f"num_samples should be a positive integer value, but got num_samples={num_samples}"
+            )
+        if not isinstance(replacement, bool):
+            raise ValueError(
+                f"replacement should be a boolean value, but got replacement={replacement}"
+            )
+
+        weights_tensor = mindtorch.as_tensor(weights, dtype=mindtorch.double)
+        if len(weights_tensor.shape) != 1:
+            raise ValueError(
+                "weights should be a 1d sequence but given "
+                f"weights have shape {tuple(weights_tensor.shape)}"
+            )
+
+        self.weights = weights_tensor
+        self.num_samples = num_samples
+        self.replacement = replacement
+        self.generator = generator
+
+    def __iter__(self) -> Iterator[int]:
+        rand_tensor = mindtorch.multinomial(
+            self.weights, self.num_samples, self.replacement, generator=self.generator
+        )
+        yield from iter(rand_tensor.tolist())
+
+    def __len__(self) -> int:
+        return self.num_samples
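A short sketch (not part of the diff) of the classic use case for `WeightedRandomSampler`: oversampling a rare class. The labels below are purely illustrative:

```python
from mindtorch.utils.data import WeightedRandomSampler

labels = [0] * 8 + [1] * 2                      # 80/20 class imbalance
counts = [labels.count(0), labels.count(1)]
weights = [1.0 / counts[y] for y in labels]     # rarer class -> larger weight

sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
picks = list(sampler)
# In expectation, about half of `picks` now index class-1 examples,
# even though they make up only 20% of the dataset.
```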
+
+
+class BatchSampler(Sampler[List[int]]):
+    r"""Wraps another sampler to yield a mini-batch of indices.
+
+    Args:
+        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If ``True``, the sampler will drop the last batch if
+            its size would be less than ``batch_size``
+
+    Example:
+        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
+        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
+        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
+        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+    """
+
+    def __init__(
+        self,
+        sampler: Union[Sampler[int], Iterable[int]],
+        batch_size: int,
+        drop_last: bool,
+    ) -> None:
+        # Since collections.abc.Iterable does not check for `__getitem__`, which
+        # is one way for an object to be an iterable, we don't do an `isinstance`
+        # check here.
+        if (
+            not isinstance(batch_size, int)
+            or isinstance(batch_size, bool)
+            or batch_size <= 0
+        ):
+            raise ValueError(
+                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
+            )
+        if not isinstance(drop_last, bool):
+            raise ValueError(
+                f"drop_last should be a boolean value, but got drop_last={drop_last}"
+            )
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+    def __iter__(self) -> Iterator[List[int]]:
+        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
+        if self.drop_last:
+            sampler_iter = iter(self.sampler)
+            while True:
+                try:
+                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
+                    yield batch
+                except StopIteration:
+                    break
+        else:
+            batch = [0] * self.batch_size
+            idx_in_batch = 0
+            for idx in self.sampler:
+                batch[idx_in_batch] = idx
+                idx_in_batch += 1
+                if idx_in_batch == self.batch_size:
+                    yield batch
+                    idx_in_batch = 0
+                    batch = [0] * self.batch_size
+            if idx_in_batch > 0:
+                yield batch[:idx_in_batch]
+
+    def __len__(self) -> int:
+        # Can only be called if self.sampler has __len__ implemented
+        # We cannot enforce this condition, so we turn off typechecking for the
+        # implementation below.
+        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
+        else:
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]
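Finally, a usage sketch (not part of the diff): `BatchSampler` is normally consumed through a `DataLoader`'s `batch_sampler=` argument, assuming that keyword mirrors the PyTorch API; when it is supplied, `batch_size`, `shuffle`, `sampler`, and `drop_last` are expected to stay at their defaults, since the batch sampler already decides the batching:

```python
from mindtorch.utils.data import BatchSampler, DataLoader, SequentialSampler

dataset = list(range(10))            # stand-in for a real Dataset
batches = BatchSampler(SequentialSampler(dataset), batch_size=3, drop_last=False)

print(list(batches))                 # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

loader = DataLoader(dataset, batch_sampler=batches)
for batch in loader:
    pass                             # each `batch` collates 3 items (the last one, 1)
```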