Commit
Merge pull request #5226 from chainer/integrate-chainermn
Integrate ChainerMN to Chainer repository
Showing 105 changed files with 11,965 additions and 2 deletions.
chainermn/__init__.py
@@ -0,0 +1,23 @@
import pkg_resources

from chainermn import communicators  # NOQA
from chainermn import datasets  # NOQA
from chainermn import extensions  # NOQA
from chainermn import functions  # NOQA
from chainermn import global_except_hook  # NOQA
from chainermn import iterators  # NOQA
from chainermn import links  # NOQA
from chainermn import optimizers  # NOQA

from chainermn.communicators import CommunicatorBase  # NOQA
from chainermn.communicators import create_communicator  # NOQA
from chainermn.datasets import DataSizeError  # NOQA
from chainermn.datasets import scatter_dataset  # NOQA
from chainermn.extensions import create_multi_node_checkpointer  # NOQA
from chainermn.extensions import create_multi_node_evaluator  # NOQA
from chainermn.links import MultiNodeChainList  # NOQA
from chainermn.optimizers import create_multi_node_optimizer  # NOQA

global_except_hook._add_hook_if_enabled()

__version__ = pkg_resources.get_distribution('chainer').version
chainermn/communicators/__init__.py
@@ -0,0 +1,99 @@
from chainermn.communicators.communicator_base import CommunicatorBase  # NOQA


def create_communicator(
        communicator_name='hierarchical', mpi_comm=None,
        allreduce_grad_dtype=None):
    """Create a ChainerMN communicator.

    Different communicators provide different approaches to communication,
    so they have different performance characteristics. The default
    communicator ``hierarchical`` is expected to perform well in a variety
    of environments, so one need not change communicators in most cases.
    However, choosing the proper communicator may give better performance.

    The following communicators are available.

    +---------------+---+---+--------+--------------------------------------+
    |Name           |CPU|GPU|NCCL    |Recommended Use Cases                 |
    +===============+===+===+========+======================================+
    |pure_nccl      |   |OK |Required|``pure_nccl`` is recommended when     |
    |               |   |   |(>= v2) |NCCL2 is available in the environment.|
    +---------------+---+---+--------+--------------------------------------+
    |hierarchical   |   |OK |Required|Each node has a single NIC or HCA     |
    +---------------+---+---+--------+--------------------------------------+
    |two_dimensional|   |OK |Required|Each node has multiple NICs or HCAs   |
    +---------------+---+---+--------+--------------------------------------+
    |single_node    |   |OK |Required|Single node with multiple GPUs        |
    +---------------+---+---+--------+--------------------------------------+
    |flat           |   |OK |        |N/A                                   |
    +---------------+---+---+--------+--------------------------------------+
    |naive          |OK |OK |        |Testing on CPU mode                   |
    +---------------+---+---+--------+--------------------------------------+

    Args:
        communicator_name: The name of the communicator (``naive``, ``flat``,
            ``hierarchical``, ``two_dimensional``, ``pure_nccl``, or
            ``single_node``)
        mpi_comm: mpi4py communicator
        allreduce_grad_dtype: Data type of the gradient used in All-Reduce.
            If ``None``, the dtype of the model is used.

    Returns:
        ChainerMN communicator that implements the methods defined in
        :class:`chainermn.CommunicatorBase`
    """
    if mpi_comm is None:
        import mpi4py.MPI
        mpi_comm = mpi4py.MPI.COMM_WORLD

    if communicator_name != 'pure_nccl' and allreduce_grad_dtype is not None:
        raise ValueError(
            'allreduce_grad_dtype is only available '
            'with the \'pure_nccl\' communicator.')

    if communicator_name == 'naive':
        from chainermn.communicators.naive_communicator \
            import NaiveCommunicator
        return NaiveCommunicator(mpi_comm=mpi_comm)

    elif communicator_name == 'flat':
        from chainermn.communicators.flat_communicator \
            import FlatCommunicator
        return FlatCommunicator(mpi_comm=mpi_comm)

    elif communicator_name == 'hierarchical':
        from chainermn.communicators.hierarchical_communicator \
            import HierarchicalCommunicator
        return HierarchicalCommunicator(mpi_comm=mpi_comm)

    elif communicator_name == 'two_dimensional':
        from chainermn.communicators.two_dimensional_communicator \
            import TwoDimensionalCommunicator
        return TwoDimensionalCommunicator(mpi_comm=mpi_comm)

    elif communicator_name == 'single_node':
        from chainermn.communicators.single_node_communicator \
            import SingleNodeCommunicator
        return SingleNodeCommunicator(mpi_comm=mpi_comm)

    elif communicator_name == 'non_cuda_aware':
        from chainermn.communicators.non_cuda_aware_communicator \
            import NonCudaAwareCommunicator
        return NonCudaAwareCommunicator(mpi_comm=mpi_comm)

    elif communicator_name == 'pure_nccl':
        from chainermn.communicators.pure_nccl_communicator \
            import PureNcclCommunicator
        return PureNcclCommunicator(
            mpi_comm=mpi_comm, allreduce_grad_dtype=allreduce_grad_dtype)

    elif communicator_name == 'dummy':
        from chainermn.communicators.dummy_communicator \
            import DummyCommunicator
        return DummyCommunicator(mpi_comm=mpi_comm)

    else:
        raise ValueError(
            'Unrecognized communicator: "{}"'.format(communicator_name))
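
A quick orientation for readers of the diff (annotation, not part of the commit): the sketch below shows how this factory is typically combined with the helpers re-exported from chainermn/__init__.py above. It assumes a launch via mpiexec with one process per GPU; `model` is a hypothetical placeholder for an arbitrary chainer.Chain.

    # Minimal usage sketch (assumptions: mpiexec launch, one process per
    # GPU, `model` is a hypothetical chainer.Chain).
    import chainer
    import chainermn

    comm = chainermn.create_communicator('hierarchical')
    device = comm.intra_rank  # one GPU per process within each node
    chainer.cuda.get_device_from_id(device).use()
    model.to_gpu()

    # Wrap any Chainer optimizer so gradients are all-reduced across workers.
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)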
chainermn/communicators/_communication_utility.py
@@ -0,0 +1,171 @@
import collections
import pickle

import mpi4py.MPI


def init_ranks(mpi_comm):
    """Returns rank information of the local process in ``mpi_comm``.

    Args:
        mpi_comm (mpi4py.MPI.Comm): MPI communicator from mpi4py.

    Returns:
        rank_info (tuple): Elements are:

            * rank (``mpi_comm.rank``)
            * intra_rank (rank within the local computing node)
            * intra_size (number of processes on the node)
            * inter_rank (rank of the node)
            * inter_size (number of computing nodes)
    """
    global_names = mpi_comm.gather(mpi4py.MPI.Get_processor_name())

    if mpi_comm.rank == 0:
        name_to_global_ranks = collections.defaultdict(list)
        for global_rank, name in enumerate(global_names):
            name_to_global_ranks[name].append(global_rank)

        for global_ranks in name_to_global_ranks.values():
            global_ranks.sort()

        inter_names = sorted(
            set(global_names), key=lambda name: name_to_global_ranks[name])
        name_to_inter_rank = {
            name: inter_rank
            for inter_rank, name in enumerate(inter_names)
        }
        inter_size = len(inter_names)

        all_ranks = []
        for global_rank, name in enumerate(global_names):
            ranks = name_to_global_ranks[name]
            intra_rank = ranks.index(global_rank)
            intra_size = len(ranks)
            inter_rank = name_to_inter_rank[name]
            all_ranks.append((
                global_rank, intra_rank, intra_size,
                inter_rank, inter_size))
        my_ranks = mpi_comm.scatter(all_ranks)
    else:
        my_ranks = mpi_comm.scatter(None)

    assert my_ranks[0] == mpi_comm.rank
    return my_ranks
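# Illustration for init_ranks() (annotation, not part of the original
# commit): with two nodes "host0" and "host1" running two processes each,
# the four processes receive the following
# (rank, intra_rank, intra_size, inter_rank, inter_size) tuples:
#
#   host0, process 0 -> (0, 0, 2, 0, 2)
#   host0, process 1 -> (1, 1, 2, 0, 2)
#   host1, process 2 -> (2, 0, 2, 1, 2)
#   host1, process 3 -> (3, 1, 2, 1, 2)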


def init_intra_mpi_comm(mpi_comm, intra_rank, inter_rank):
    return mpi_comm.Split(inter_rank, intra_rank)


def init_inter_mpi_comm(mpi_comm, intra_rank, inter_rank):
    return mpi_comm.Split(intra_rank, inter_rank)
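# Annotation (not part of the original commit): mpi_comm.Split(color, key)
# puts all processes that pass the same color into one new communicator and
# orders them by key.  init_intra_mpi_comm() therefore groups processes by
# node (same inter_rank), while init_inter_mpi_comm() connects processes
# that hold the same intra_rank across nodes.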


def init_nccl_comm(mpi_comm):
    from chainermn import nccl
    if mpi_comm.rank == 0:
        nccl_comm_id = nccl.get_unique_id()
    else:
        nccl_comm_id = None
    nccl_comm_id = mpi_comm.bcast(nccl_comm_id)
    return nccl.NcclCommunicator(mpi_comm.size, nccl_comm_id, mpi_comm.rank)


def inter_allreduce_gpu(
        inter_mpi_comm, size, gpu_buffer_a, gpu_buffer_b,
        n_bytes_buffer, n_elems_per_node, n_bytes_per_node, cuda_stream):
    inter_size = inter_mpi_comm.size

    # Exchange all data so that each node receives every copy of the
    # region it owns (bufferB -> bufferA)
    cuda_stream.synchronize()
    inter_mpi_comm.Alltoall(
        [gpu_buffer_b.buffer(n_bytes_buffer), mpi4py.MPI.FLOAT],
        [gpu_buffer_a.buffer(n_bytes_buffer), mpi4py.MPI.FLOAT])

    # Reduce the own-region data (in place in bufferA) and average
    ret = gpu_buffer_a.array(inter_size * n_elems_per_node) \
        .reshape(inter_size, n_elems_per_node) \
        .sum(axis=0) * (1.0 / size)

    # Gather the other nodes' region data (bufferA -> bufferB)
    for i in range(0, inter_size):
        gpu_buffer_a.from_device(
            ret, n_bytes_per_node, i * n_bytes_per_node)

    cuda_stream.synchronize()
    inter_mpi_comm.Alltoall(
        [gpu_buffer_a.buffer(n_bytes_buffer), mpi4py.MPI.FLOAT],
        [gpu_buffer_b.buffer(n_bytes_buffer), mpi4py.MPI.FLOAT])
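# Annotation (not part of the original commit): the routine above realizes
# an inter-node allreduce as a reduce-scatter followed by an allgather, both
# phrased as Alltoall.  Each node owns a 1/inter_size slice of the gradient
# buffer: the first Alltoall collects every node's copy of the locally owned
# slice into bufferA, the sum(axis=0) * (1.0 / size) line reduces and
# averages it, the loop replicates the reduced slice across all send
# regions, and the second Alltoall leaves the complete averaged gradient in
# bufferB.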


INT_MAX = 2147483647


def chunked_bcast_obj(obj, mpi_comm, max_buf_len=256 * 1024 * 1024,
                      root=0):
    '''Split an object into max_buf_len-sized chunks and broadcast them.

    As mpi4py does not accept an object whose pickled size is larger than
    the signed integer maximum (2147483647), the object is pickled and
    split into chunks.

    An alternative would be to simply try mpi_comm.bcast(obj), but then the
    root rank would receive an OverflowError from mpi4py while the other
    ranks would block forever in mpi_comm.bcast(obj).

    Args:
        obj: A Python object to be broadcast.
        mpi_comm: ChainerMN communicator or mpi4py communicator.
        max_buf_len (int): Max buffer size to be used at broadcasting
            binaries. Must not be larger than 2147483647 (INT_MAX).
            Default value is 256MB.
        root (int): The root process of the broadcast operation.

    Returns:
        The broadcast object.
    '''
    assert max_buf_len < INT_MAX
    assert max_buf_len > 0

    # Consistency check: obj must be non-None on the root rank and None
    # everywhere else.
    # rank \ obj | None | not None |
    #  == root   |  NG  |    OK    |
    #  != root   |  OK  |    NG    |
    assert not (obj is None and mpi_comm.rank == root)
    assert not (obj is not None and mpi_comm.rank != root)

    if obj is not None and mpi_comm.rank == root:
        pickled_bytes = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        pickled_bytes = bytearray()

    total_bytes = len(pickled_bytes)
    total_chunk_num = total_bytes // max_buf_len
    if (total_bytes % max_buf_len) > 0:
        total_chunk_num += 1

    # Let every rank know how many chunks to expect.
    data = mpi_comm.bcast((total_chunk_num, max_buf_len, total_bytes),
                          root=root)
    assert data is not None
    (total_chunk_num, max_buf_len, total_bytes) = data

    for i in range(total_chunk_num):
        b = i * max_buf_len
        e = min(b + max_buf_len, total_bytes)

        if mpi_comm.rank == root:
            buf = pickled_bytes[b:e]
        else:
            buf = bytearray(e - b)

        mpi_comm.Bcast(buf, root=root)

        if mpi_comm.rank != root:
            pickled_bytes[b:e] = buf

    # Unpickle on every non-root rank.
    if mpi_comm.rank != root:
        obj = pickle.loads(pickled_bytes)

    return obj
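
For reference (annotation, not part of the commit), a typical call site of chunked_bcast_obj; load_big_dataset is a hypothetical loader whose pickled form may exceed the 2 GB bcast limit:

    import mpi4py.MPI

    comm = mpi4py.MPI.COMM_WORLD
    obj = load_big_dataset() if comm.rank == 0 else None  # hypothetical
    obj = chunked_bcast_obj(obj, comm)
    # Every rank now holds an identical copy of the object.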