[fix] don't import ProcessGroup eagerly (#1074)
* [fix] don't import ProcessGroup eagerly

- move the import under TYPE_CHECKING since it is only used for type checking
- fixes #1057

* more fixes

* one more

* tested at least

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Sep 23, 2022
1 parent d8fc94d commit 47ce21a
Showing 5 changed files with 42 additions and 17 deletions.
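
For context, the change applies a standard lazy-typing pattern: import ProcessGroup only under typing.TYPE_CHECKING and quote the annotations, so the name is resolved by type checkers but never at runtime. A minimal standalone sketch of the pattern (illustrative module and function, not fairscale code):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers (e.g. mypy), never at runtime, so
    # platforms whose torch.distributed lacks ProcessGroup can still import this module.
    from torch.distributed import ProcessGroup


def world_size(group: Optional["ProcessGroup"] = None) -> int:
    # The quoted annotation keeps ProcessGroup out of runtime name resolution.
    import torch.distributed as dist

    return dist.get_world_size(group) if dist.is_available() and dist.is_initialized() else 1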
11 changes: 7 additions & 4 deletions fairscale/internal/parallel.py
@@ -7,13 +7,16 @@
 
 from enum import Enum
 import sys
-from typing import List, Optional, Sequence
+from typing import TYPE_CHECKING, List, Optional, Sequence
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 import torch.nn.functional as F
 
+if TYPE_CHECKING:
+    # See comments in FSDP code for reason of this import.
+    from torch.distributed import ProcessGroup
+
 
 def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
     """Chunk a given Tensor into num_chunks parts and add any necessary padding."""
@@ -27,7 +30,7 @@ def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
     return chunks
 
 
-def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
+def validate_process_group(device: torch.device, process_group: "ProcessGroup") -> None:
     """Do a quick test in case user called FSDP without calling torch.cuda.set_device()
     correctly. This can easily happen in cpu_offload case where the model resides on
    the CPU.
@@ -67,7 +70,7 @@ class ProcessGroupName(str, Enum):
 
 def get_process_group_cached(
     name: ProcessGroupName = ProcessGroupName.default, ranks: Optional[Sequence[int]] = None
-) -> ProcessGroup:
+) -> "ProcessGroup":
     """
     Singleton PyTorch distributed group cache. Inspired by the code from fairseq.
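
The validate_process_group docstring above points at the usual remedy on the user side: pin each rank to its GPU before building the wrapped model. A minimal sketch, assuming one GPU per rank and a launcher (e.g. torchrun) that provides the rendezvous environment variables:

import torch
import torch.distributed as dist


def init_distributed(rank: int, world_size: int) -> None:
    # Set the current CUDA device *before* constructing FSDP-wrapped modules;
    # with cpu_offload the model lives on the CPU, so a wrong device is easy to miss.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)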
14 changes: 8 additions & 6 deletions fairscale/internal/reduce_scatter_bucketer.py
@@ -5,12 +5,14 @@
 
 import functools
 import os
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 
+if TYPE_CHECKING:
+    from torch.distributed import ProcessGroup
+
 # TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
 if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
@@ -20,7 +22,7 @@
 
 
 class Bucket:
-    def __init__(self, data: Tensor, group: ProcessGroup):
+    def __init__(self, data: Tensor, group: "ProcessGroup"):
         self.data = data
         self.group = group
         self.offset = 0
@@ -99,13 +101,13 @@ class ReduceScatterBucketer:
 
     def __init__(self, bucket_cap_mb: int = 25):
         self.bucket_cap_mb = bucket_cap_mb
-        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
+        self.buckets: Dict[Tuple[torch.dtype, torch.device, "ProcessGroup"], Bucket] = {}
 
     @torch.no_grad()
     def reduce_scatter_async(
         self,
         input_list: List[Tensor],
-        group: ProcessGroup,
+        group: "ProcessGroup",
         callback_fn: Optional[Callable] = None,
     ) -> None:
         """
@@ -186,7 +188,7 @@ def _get_shard_size(self, element_size: int, num_shards: int) -> int:
         bucket_size = self.bucket_cap_mb * MB / element_size
         return int(bucket_size // num_shards)
 
-    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
+    def _get_bucket(self, tensor: Tensor, group: "ProcessGroup") -> Bucket:
         # TODO (Min): the `group` used here in the key is the object hash, not the content
         # hash. That means if FSDP instances are initialized with different process groups,
         # even when the group members are in fact the same, we end up creating different
13 changes: 10 additions & 3 deletions fairscale/nn/__init__.py
@@ -5,11 +5,18 @@
 
 from typing import List
 
+import torch.distributed as dist
+
 from .checkpoint import checkpoint_wrapper
-from .data_parallel import FullyShardedDataParallel, ShardedDataParallel
+from .data_parallel import FullyShardedDataParallel
+
+if dist.is_available():
+    # Prevent import failure if dist is not available. #1057
+    from .data_parallel import ShardedDataParallel
+    from .moe import MOELayer, Top2Gate
+    from .pipe import Pipe, PipeRPCWrapper
+
 from .misc import FlattenParamsWrapper
-from .moe import MOELayer, Top2Gate
-from .pipe import Pipe, PipeRPCWrapper
 from .wrap import auto_wrap, config_auto_wrap_policy, default_auto_wrap_policy, enable_wrap, wrap
 
 __all__: List[str] = []
7 changes: 6 additions & 1 deletion fairscale/nn/data_parallel/__init__.py
@@ -5,13 +5,18 @@
 
 from typing import List
 
+import torch.distributed as dist
+
 from .fully_sharded_data_parallel import (
     FullyShardedDataParallel,
     OffloadConfig,
     TrainingState,
     auto_wrap_bn,
     no_pre_load_state_dict_hook,
 )
-from .sharded_ddp import ShardedDataParallel
+
+if dist.is_available():
+    # Prevent import failure if dist is not available. #1057
+    from .sharded_ddp import ShardedDataParallel
 
 __all__: List[str] = []
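
A rough sketch of what the guarded exports above mean for downstream code (hypothetical consumer, not part of this commit): on PyTorch builds where torch.distributed is unavailable, ShardedDataParallel is simply not exported, so callers should check before importing it.

import torch.distributed as dist

if dist.is_available():
    # Only exported on platforms where torch.distributed exists (see #1057).
    from fairscale.nn import ShardedDataParallel
else:
    ShardedDataParallel = None  # real code would raise a clearer error when it is used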
14 changes: 11 additions & 3 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -35,7 +35,6 @@
 import torch
 from torch.autograd import Variable
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
@@ -58,6 +57,12 @@
 
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
+
+    # See #1057. On some platforms, torch.distributed may not have ProcessGroup,
+    # so we only import it during type checking, which does not happen on a
+    # default import and is only done by developers on supported platforms.
+    from torch.distributed import ProcessGroup
+
 # TODO: Remove the toggle here when github open issue #801 is resolved.
 if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
     enable_nccl_base_collectives = False
@@ -308,7 +313,7 @@ class FullyShardedDataParallel(nn.Module):
     def __init__(
         self,
         module: nn.Module,
-        process_group: Optional[ProcessGroup] = None,
+        process_group: Optional["ProcessGroup"] = None,
         # The type for the process_group_reduce_scatter only can be either ProcessGroup or ProcessGroupName
         process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter,
         reshard_after_forward: bool = True,
@@ -352,6 +357,9 @@ def __init__(
             self.process_group_reduce_scatter = get_process_group_cached(ProcessGroupName.reduce_scatter)
         else:
             # If a specific process group is passed in, the reduce_scatter will use the passed in process group.
+            # Delay the import here since this type may not be available on certain platforms.
+            from torch.distributed import ProcessGroup
+
             if isinstance(process_group_reduce_scatter, ProcessGroup):
                 self.process_group_reduce_scatter = process_group_reduce_scatter
             else:
@@ -2648,7 +2656,7 @@ def _unpad(shard: torch.Tensor, pad: int) -> torch.Tensor:
 def auto_wrap_bn(
     module: nn.Module,
     single_rank_pg: bool = False,
-    process_group: Optional[ProcessGroup] = None,
+    process_group: Optional["ProcessGroup"] = None,
     fsdp_config: Optional[Dict[str, Any]] = None,
     wrap_it: bool = True,
     assert_on_collision: bool = True,
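
The hunk in FullyShardedDataParallel.__init__ above also uses the other flavor of the fix: a function-local import, deferred until the branch that actually needs the class at runtime. A standalone sketch of that idiom (illustrative names, not the FSDP code):

from typing import Any


def resolve_reduce_scatter_group(group_or_name: Any) -> Any:
    # Import only when the isinstance() check is actually needed, so merely
    # importing this module never touches torch.distributed.ProcessGroup.
    from torch.distributed import ProcessGroup

    if isinstance(group_or_name, ProcessGroup):
        return group_or_name
    raise ValueError(f"expected a ProcessGroup, got {type(group_or_name)!r}")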
