add oneflow dist backend #59

Merged · 21 commits · Dec 6, 2022
7 changes: 6 additions & 1 deletion .circleci/config.yml
@@ -52,6 +52,10 @@ jobs:
           name: Install TensorFlow via pip
           command: |
             pip install tensorflow==<< parameters.tensorflow >>
+      - run:
+          name: Install OneFlow via pip
+          command: |
+            pip install -f https://release.oneflow.info oneflow==0.8.0+cpu
       - run:
           name: Install mmeval and dependencies
           command: |
@@ -120,13 +124,14 @@ jobs:
             sudo apt-get install libcudnn8=${cudnn_version}-1+${cuda_version}
             sudo apt-get install libcudnn8-dev=${cudnn_version}-1+${cuda_version}
       - run:
-          name: Install PyTorch and Paddle via pip
+          name: Install PyTorch, Paddle and OneFlow via pip
           command: |
             pyenv global 3.9.2
             pip install --upgrade pip
             python -V
             pip install torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
             pip install paddlepaddle-gpu==2.3.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
+            pip install --pre oneflow -f https://staging.oneflow.info/branch/master/cu112
       - run:
           name: Install mmeval and dependencies
           command: |
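
A quick way to sanity-check the new install steps locally — a minimal sketch, not part of the PR (the version pins mirror the commands above):

import oneflow as flow

# The CPU job pins oneflow==0.8.0+cpu; the GPU job tracks the cu112 nightly,
# so relax the assertion there.
print(flow.__version__)
assert flow.__version__.startswith('0.8.0'), 'unexpected OneFlow version'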
6 changes: 4 additions & 2 deletions mmeval/core/dist.py
@@ -2,12 +2,14 @@

 from typing import List, Optional, no_type_check

-from .dist_backends import (BaseDistBackend, MPI4PyDist, NonDist, PaddleDist,
-                            TFHorovodDist, TorchCPUDist, TorchCUDADist)
+from .dist_backends import (BaseDistBackend, MPI4PyDist, NonDist, OneFlowDist,
+                            PaddleDist, TFHorovodDist, TorchCPUDist,
+                            TorchCUDADist)

 _DIST_BACKENDS = {
     'non_dist': NonDist,
     'mpi4py': MPI4PyDist,
+    'oneflow': OneFlowDist,
     'tf_horovod': TFHorovodDist,
     'torch_cpu': TorchCPUDist,
     'torch_cuda': TorchCUDADist,
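
With 'oneflow' registered in _DIST_BACKENDS, the backend can be selected by name. A minimal sketch, assuming the existing helpers in mmeval/core/dist.py (list_all_backends, set_default_dist_backend, get_dist_backend) keep their current names:

from mmeval.core import dist

# 'oneflow' should now appear among the registered backend names.
print(dist.list_all_backends())

# Use OneFlow for metric synchronization, either as the process-wide default...
dist.set_default_dist_backend('oneflow')

# ...or by fetching a backend instance explicitly.
dist_comm = dist.get_dist_backend('oneflow')
print(dist_comm.rank, dist_comm.world_size)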
4 changes: 3 additions & 1 deletion mmeval/core/dist_backends/__init__.py
@@ -3,12 +3,14 @@
 from .base_backend import BaseDistBackend, TensorBaseDistBackend
 from .mpi4py import MPI4PyDist
 from .non_dist import NonDist
+from .oneflow_dist import OneFlowDist
 from .paddle_dist import PaddleDist
 from .tf_horovod import TFHorovodDist
 from .torch_cpu import TorchCPUDist
 from .torch_cuda import TorchCUDADist

 __all__ = [
     'BaseDistBackend', 'TensorBaseDistBackend', 'MPI4PyDist', 'NonDist',
-    'TFHorovodDist', 'TorchCPUDist', 'TorchCUDADist', 'PaddleDist'
+    'OneFlowDist', 'TFHorovodDist', 'TorchCPUDist', 'TorchCUDADist',
+    'PaddleDist'
 ]
134 changes: 134 additions & 0 deletions mmeval/core/dist_backends/oneflow_dist.py
@@ -0,0 +1,134 @@
# Copyright (c) OpenMMLab. All rights reserved.

import numpy as np
import pickle
from typing import TYPE_CHECKING, Any, List, Tuple, TypeVar, Union

from mmeval.utils import try_import
from .base_backend import TensorBaseDistBackend

if TYPE_CHECKING:
    import oneflow
    import oneflow as flow
    import oneflow.framework.check_point_v2 as check_point_v2
else:
    flow = try_import('oneflow')
    check_point_v2 = try_import('oneflow.framework.check_point_v2')

Tensor = TypeVar('Tensor', bound='oneflow.Tensor')


class OneFlowDist(TensorBaseDistBackend):
    """A distributed communication backend for oneflow."""

    def __init__(self) -> None:
        super().__init__()
        if flow is None:
            raise ImportError(f'For availability of {self.__class__.__name__},'
                              ' please install oneflow first.')

    @property
    def is_initialized(self) -> bool:
        """Returns True if the distributed environment has been initialized.

        Returns:
            bool: Returns True if the distributed environment has been
            initialized, otherwise returns False.
        """
        try:
            flow.env.get_world_size()
            is_init = True
        except ValueError:
            is_init = False
        return is_init

    @property
    def rank(self) -> int:
        """Returns the rank index of the current process group."""
        return flow.env.get_rank()

    @property
    def world_size(self) -> int:
        """Returns the world size of the current process group."""
        return flow.env.get_world_size()

    def _object_to_tensor(self, obj: Any) -> Tuple[Tensor, Tensor]:
        """Convert the given object to a tensor via `pickle.dumps`.

        Args:
            obj (any): Any pickle-able python object.

        Returns:
            Tuple: A tuple of the tensor converted from the given object and
            the tensor size.
        """
        buffer = pickle.dumps(obj)
        storage = np.frombuffer(buffer, dtype=np.int8)
        obj_tensor = flow.tensor(storage)
        obj_size_tensor = flow.tensor([obj_tensor.numel()])
        return obj_tensor, obj_size_tensor

    def _tensor_to_object(self, tensor: Tensor,
                          tensor_size: Union[int, Tensor]) -> Any:
        """Convert the given tensor to an object via `pickle.loads`.

        Args:
            tensor (Tensor): A tensor-like data.
            tensor_size (int or Tensor): The size of the given tensor to be
                converted to an object.

        Returns:
            Any: The object converted from the given tensor.
        """
        size = int(tensor_size)
        buffer = tensor.cpu().numpy().tobytes()[:size]
        obj = pickle.loads(buffer)
        return obj

    def _pad_tensor(self, tensor: Tensor,
                    max_size: Union[int, Tensor]) -> Tensor:  # yapf: disable
        """Pad the given tensor to the given size.

        Args:
            tensor (Tensor): A tensor-like data to be padded.
            max_size (int or Tensor): The max tensor size used for tensor
                padding.

        Returns:
            Tensor: The padded tensor.
        """
        max_size = int(max_size)
        padding = flow.zeros((max_size - tensor.numel(), ),
                             dtype=flow.int8,
                             device=tensor.device)
        tensor = flow.cat((tensor, padding), dim=0)
        return tensor

    def _all_gather(self, tensor: Tensor) -> List[Tensor]:
        """All gather the given tensor.

        Args:
            tensor (Tensor): The tensor for all gather.

        Returns:
            list: A list of the gathered tensors.
        """
        tensor_list = [
            flow.empty_like(tensor).to(tensor.device)
            for _ in range(self.world_size)
        ]
        flow.comm.all_gather(tensor_list, tensor)
        return tensor_list

    def _broadcast(self, tensor: Tensor, src: int = 0) -> Tensor:
        """Broadcast the given tensor from the source rank.

        Args:
            tensor (Tensor): The tensor for broadcast.
            src (int): The source rank index.

        Returns:
            Tensor: The broadcast tensor.
        """
        flow.comm.broadcast(tensor, src=src)
        return tensor
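
The _object_to_tensor/_tensor_to_object pair defines the serialization scheme that TensorBaseDistBackend builds all_gather_object and broadcast_object on top of: pickle the object, view the bytes as an int8 tensor, and ship the byte count alongside so padding can be stripped on the receiving side. A minimal single-process sketch of that round trip (plain numpy stands in for the oneflow tensor, so it runs without any distributed setup):

import pickle

import numpy as np

obj = {'rank': 0, 'data': [1.0, 2.0, 3.0]}

# Serialize: object -> bytes -> int8 "tensor", plus its true size.
buffer = pickle.dumps(obj)
storage = np.frombuffer(buffer, dtype=np.int8)
size = storage.size

# Pad to a fixed maximum, as _pad_tensor does before the all-gather.
max_size = size + 16
padded = np.concatenate([storage, np.zeros(max_size - size, dtype=np.int8)])

# Deserialize: slice off the padding using the transmitted size, then unpickle.
restored = pickle.loads(padded.tobytes()[:size])
assert restored == obj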
133 changes: 133 additions & 0 deletions tests/test_core/test_dist_backends/test_oneflow_dist_single_node.py
@@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.

import multiprocessing as mp
import numpy as np
import os
import pytest

# Skip if the current process was launched via mpirun.
if os.environ.get('OMPI_COMM_WORLD_SIZE', '0') != '0':
    pytest.skip(allow_module_level=True)

from mmeval.core.dist_backends.oneflow_dist import OneFlowDist

flow = pytest.importorskip('oneflow')


def equal(a, b):
    if isinstance(a, dict):
        return all(equal(a[k], b[k]) for k in a.keys())
    elif isinstance(a, (list, tuple)):
        return all(equal(ai, bi) for ai, bi in zip(a, b))
    elif isinstance(a, (int, float, bool, str)):
        return a == b
    elif isinstance(a, flow.Tensor):
        return np.all(a.numpy() == b.numpy())
    else:
        return False


def _create_obj_list(rank, world_size):
    obj_list = []
    for idx in range(world_size):
        rank = idx + 1
        obj = dict()
        obj['rank'] = idx
        obj['ranks'] = list(range(world_size))
        obj['world_size'] = world_size
        obj['data'] = [flow.tensor([rank * 1.0, rank * 2.0, rank * 3.0])]
        obj_list.append(obj)
    return obj_list


def _oneflow_dist_all_gather_fn(rank, world_size):
    dist_comm = OneFlowDist()

    assert dist_comm.is_initialized
    assert dist_comm.world_size == world_size

    rank = dist_comm.rank

    obj_list = _create_obj_list(rank, world_size)

    local_obj = obj_list[rank]
    print(f'rank {rank}, local_obj {local_obj}')

    gather_obj_list = dist_comm.all_gather_object(local_obj)
    print(f'rank {rank}, gather_obj_list {gather_obj_list}')
    assert equal(gather_obj_list, obj_list)


def _oneflow_dist_broadcast_fn(rank, world_size):
    dist_comm = OneFlowDist()

    assert dist_comm.is_initialized
    assert dist_comm.world_size == world_size

    rank = dist_comm.rank

    obj_list = _create_obj_list(rank, world_size)

    local_obj = obj_list[rank]

    print(f'rank {rank}, obj {local_obj}')
    broadcast_obj = dist_comm.broadcast_object(local_obj, src=0)
    print(f'rank {rank}, broadcast_obj {broadcast_obj}')

    assert equal(broadcast_obj, obj_list[0])


def _init_oneflow_dist(local_rank, world_size, port):
    os.environ['RANK'] = str(local_rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_PORT'] = str(port)
    os.environ['MASTER_ADDR'] = '127.0.0.1'


def _reset_dist_env():
    os.environ['RANK'] = '0'
    os.environ['LOCAL_RANK'] = '0'
    os.environ['WORLD_SIZE'] = '1'


def _launch_dist_fn(target_fn, process_num, comm_port):
    ctx = mp.get_context('spawn')
    process_list = []
    for rank in range(process_num):
        _init_oneflow_dist(rank, process_num, comm_port)
        p = ctx.Process(target=target_fn, args=(rank, process_num))
        p.start()
        process_list.append(p)

    for p in process_list:
        p.join()

    # Reset the env variables to prevent getting stuck when importing oneflow.
    _reset_dist_env()


@pytest.mark.parametrize(
    argnames=['process_num', 'comm_port'],
    argvalues=[
        (1, 2350),
        (2, 2350),
        (4, 2350),
    ])
def test_broadcast_object(process_num, comm_port):
    _launch_dist_fn(_oneflow_dist_broadcast_fn, process_num, comm_port)


@pytest.mark.parametrize(
    argnames=['process_num', 'comm_port'],
    argvalues=[
        (1, 2350),
        (2, 2350),
        (4, 2350),
    ])
def test_all_gather_object(process_num, comm_port):
    _launch_dist_fn(_oneflow_dist_all_gather_fn, process_num, comm_port)


if __name__ == '__main__':
    pytest.main([__file__, '-vvv', '--capture=no'])
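
The launcher above hands each spawned worker its rank through environment variables; OneFlow reads RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT when it is imported. A minimal single-node sketch of the same bootstrap outside pytest (values are illustrative):

import os

# Assumption: oneflow picks these variables up at import time, which is
# what _init_oneflow_dist above relies on.
os.environ.update({
    'RANK': '0',
    'LOCAL_RANK': '0',
    'WORLD_SIZE': '1',
    'MASTER_ADDR': '127.0.0.1',
    'MASTER_PORT': '2350',
})

import oneflow as flow  # noqa: E402

print(flow.env.get_rank(), flow.env.get_world_size())  # 0 1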
@@ -114,6 +114,9 @@ def test_mpi_all_gather_object(process_num, comm_port):

 try:
     nccl_version = torch.cuda.nccl.version()
+    if isinstance(nccl_version, tuple):
+        MAJOR, MINOR, PATCH = nccl_version
+        nccl_version = MAJOR * 1000 + MINOR * 100 + PATCH
 except Exception:
     nccl_version = 0

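Recent PyTorch releases return torch.cuda.nccl.version() as a (major, minor, patch) tuple rather than the old packed integer, which is what the new guard normalizes. The packing mirrors the legacy int form, and inherits its ambiguity once minor >= 10:

# e.g. (2, 7, 8) -> 2708, matching the old integer encoding;
#      (2, 10, 3) -> 3003, which collides with (3, 0, 3).
major, minor, patch = (2, 10, 3)
print(major * 1000 + minor * 100 + patch)  # 3003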
@@ -122,7 +125,7 @@ def test_mpi_all_gather_object(process_num, comm_port):
     not torch_dist.is_nccl_available(),
     reason='NCCL backend is not available.')
 @pytest.mark.skipif(
-    torch.cuda.device_count() < 0,
+    torch.cuda.device_count() < 1,
     reason='CUDA device count must greater than 0.')
 @pytest.mark.parametrize(
     argnames=['process_num', 'comm_port'],
@@ -175,7 +178,7 @@ def test_mpi_broadcast_object(process_num, comm_port):
         1,
         2350,
         marks=pytest.mark.skipif(
-            torch.cuda.device_count() < 0,
+            torch.cuda.device_count() < 1,
             reason='CUDA device count must greater than 0.')),
     pytest.param(
         2,