support for fp16 in naive comm (#6333)
* support fp16 in naive comm
* support for fp16 in bcast
shu65 authored and kuenishi committed Mar 4, 2019
1 parent d9bdfc6 commit b23f877
Showing 3 changed files with 19 additions and 2 deletions.
8 changes: 7 additions & 1 deletion chainermn/communicators/mpi_communicator_base.py

@@ -611,8 +611,14 @@ def allreduce_obj(self, obj):
     def bcast_data(self, model):
         for _, param in sorted(model.namedparams()):
             if param.data is not None:
-                buf = _memory_utility.array_to_buffer_object(param.data)
+                data = param.data
+                is_float16 = param.data.dtype == numpy.float16
+                if is_float16:
+                    data = data.astype(numpy.float32)
+                buf = _memory_utility.array_to_buffer_object(data)
                 self.mpi_comm.Bcast(buf)
+                if is_float16:
+                    param.data = data.astype(numpy.float16)

     # Private methods
     def _init_ranks(self):
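The change above casts float16 data to float32 around the collective, most likely because the MPI standard defines no half-precision datatype, so a raw float16 buffer cannot be handed to Bcast directly. A minimal standalone sketch of that upcast/communicate/downcast pattern, assuming mpi4py and NumPy (the helper name and calling code are illustrative, not ChainerMN API):

    # A minimal sketch, assuming mpi4py; `bcast_fp16_safe` is an
    # illustrative helper, not part of ChainerMN.
    import numpy as np
    from mpi4py import MPI

    def bcast_fp16_safe(comm, array, root=0):
        # MPI defines no standard half-precision datatype, so float16
        # payloads are broadcast as float32 and cast back afterwards.
        is_float16 = array.dtype == np.float16
        buf = array.astype(np.float32) if is_float16 else array
        comm.Bcast(buf, root=root)  # fills `buf` in place on non-root ranks
        return buf.astype(np.float16) if is_float16 else buf

    comm = MPI.COMM_WORLD
    x = np.arange(4, dtype=np.float16) if comm.rank == 0 \
        else np.empty(4, np.float16)
    x = bcast_fp16_safe(comm, x)  # every rank now holds rank 0's values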
9 changes: 8 additions & 1 deletion chainermn/communicators/naive_communicator.py

@@ -1,4 +1,5 @@
 import mpi4py.MPI
+import numpy as np

 from chainermn.communicators import _memory_utility
 from chainermn.communicators import mpi_communicator_base
@@ -11,6 +12,12 @@ def __init__(self, mpi_comm):

     def allreduce_grad(self, model):
         for param in _memory_utility.extract_params_set_grad(model):
-            buf = _memory_utility.array_to_buffer_object(param.grad)
+            grad = param.grad
+            is_float16 = param.grad.dtype == np.float16
+            if is_float16:
+                grad = grad.astype(np.float32)
+            buf = _memory_utility.array_to_buffer_object(grad)
             self.mpi_comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)
+            if is_float16:
+                param.grad = grad.astype(np.float16)
             param.grad /= self.size
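The allreduce path follows the same scheme: upcast to float32, sum in place across ranks, cast back, then divide by the communicator size to average the gradients. A sketch under the same assumptions (mpi4py; the helper is illustrative):

    import numpy as np
    import mpi4py.MPI

    def allreduce_mean_fp16_safe(comm, grad):
        # Sum `grad` elementwise across all ranks, then divide by the number
        # of ranks to obtain the mean; float16 inputs ride through the
        # collective as float32.
        is_float16 = grad.dtype == np.float16
        buf = grad.astype(np.float32) if is_float16 else grad
        comm.Allreduce(mpi4py.MPI.IN_PLACE, buf)  # in-place elementwise sum
        if is_float16:
            buf = buf.astype(np.float16)
        return buf / comm.size

As in the committed code, the division by size happens after the cast back to float16; dividing the float32 buffer before downcasting would be the alternative ordering.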
4 changes: 4 additions & 0 deletions tests/chainermn_tests/communicator_tests/test_communicator.py

@@ -64,6 +64,10 @@ def __init__(self, param):
     {
         'communicator_class': NaiveCommunicator,
         'multi_node': True,
+    }, {
+        'communicator_class': NaiveCommunicator,
+        'model_dtype': np.float16,
+        'multi_node': True,
     }, {
         'communicator_class': FlatCommunicator,
         'multi_node': True,
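The new entry adds a NaiveCommunicator case whose model parameters are float16, so both changed code paths get exercised. A hedged sketch of how a 'model_dtype' parameter like this might be consumed; the fixture below is illustrative, not the actual test model:

    import chainer
    import chainer.links as L
    import numpy as np

    class ExampleModel(chainer.Chain):
        # Illustrative fixture: a single linear link whose parameters are
        # cast to the requested dtype so the fp16 branches get exercised.
        def __init__(self, dtype=np.float32):
            super(ExampleModel, self).__init__()
            with self.init_scope():
                self.linear = L.Linear(3, 4)
            for _, param in self.namedparams():
                param.data = param.data.astype(dtype)

    model = ExampleModel(dtype=np.float16)  # parameters are now float16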
