diff --git a/chainermn/optimizers.py b/chainermn/optimizers.py index 901e7c9e..55eaf24e 100644 --- a/chainermn/optimizers.py +++ b/chainermn/optimizers.py @@ -1,6 +1,18 @@ import chainer import copy import multiprocessing.pool +import warnings + + +def _check_mp_start_method(comm): + """Show a warning if multiprocessing's start_method is not 'forkserver'.""" + method = multiprocessing.get_start_method() + + if comm.size > 1 and comm.rank == 0: + if method is not 'forkserver': + warnings.warn("multiprocessing's `start_method` must be " + "'forkserver' (now it's '{}')".format(method), + stacklevel=2) class _MultiNodeOptimizer(object): @@ -164,6 +176,8 @@ def create_multi_node_optimizer(actual_optimizer, communicator, Returns: The multi node optimizer based on ``actual_optimizer``. """ + _check_mp_start_method(communicator) + if double_buffering: from chainermn.communicators.pure_nccl_communicator \ import PureNcclCommunicator