Merge branch 'master' into branch-unpin-torchhead
tgaddair committed May 15, 2021
2 parents c5cf6fb + cbb343d commit 38d8ae0
Showing 22 changed files with 232 additions and 127 deletions.
CHANGELOG.md: 3 additions & 0 deletions

@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

+ - Estimator: add petastorm reader_pool_type into constructor ([#2903](https://github.com/horovod/horovod/pull/2903))
- Added NVTX tracing hooks for profiling with Nsight Systems. ([#2723](https://github.com/horovod/horovod/pull/2723))
- Added a generic `num_workers` API for ``RayExecutor`` ([#2870](https://github.com/horovod/horovod/pull/2870))

@@ -28,6 +29,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
### Fixed

- Changed RayExecutor to use Ray node ID to enable multi-container:single-host setups. ([#2883](https://github.com/horovod/horovod/pull/2883))
+ - Support sparse gradients aggregation in TF1 Keras. ([#2879](https://github.com/horovod/horovod/pull/2879))
+ - Respect `global_step` parameter for LegacyOptimizers when aggregating gradients. ([#2879](https://github.com/horovod/horovod/pull/2879))

## [v0.21.3] - 2021-02-15

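The generic `num_workers` API noted in the Added section above might be exercised like this — a minimal sketch, assuming a local Ray cluster and a user-defined `train_fn`, neither of which is part of this diff:

```python
# Hedged sketch of the generic num_workers API for RayExecutor (#2870).
# Assumes Ray and Horovod with Ray support are installed; train_fn is hypothetical.
import ray
from horovod.ray import RayExecutor

def train_fn():
    import horovod.torch as hvd
    hvd.init()
    return hvd.rank()

ray.init()
settings = RayExecutor.create_settings(timeout_s=30)
# num_workers sizes the job directly instead of num_hosts x num_slots.
executor = RayExecutor(settings, num_workers=2, use_gpu=False)
executor.start()
results = executor.run(train_fn)
executor.shutdown()
```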
Dockerfile.test.cpu: 1 addition & 1 deletion

@@ -203,7 +203,7 @@ RUN if [[ ${PYTORCH_PACKAGE} == "torch-nightly" ]]; then \

# Install MXNet (nightly).
RUN if [[ ${MXNET_PACKAGE} == "mxnet-nightly" ]]; then \
- pip install --no-cache-dir --pre mxnet==2.0.0b20210319 -f https://dist.mxnet.io/python/all; \
+ pip install --no-cache-dir --pre mxnet -f https://dist.mxnet.io/python/all; \
fi

# Install Horovod.
Dockerfile.test.gpu: 1 addition & 1 deletion

@@ -175,7 +175,7 @@ RUN if [[ ${PYTORCH_PACKAGE} == "torch-nightly-cu"* ]]; then \

# Install MXNet (nightly).
RUN if [[ ${MXNET_PACKAGE} == "mxnet-nightly-cu"* ]]; then \
- pip install --no-cache-dir --pre ${MXNET_PACKAGE/-nightly/}==2.0.0b20210319 -f https://dist.mxnet.io/python/${MXNET_PACKAGE/#mxnet-nightly-/}; \
+ pip install --no-cache-dir --pre ${MXNET_PACKAGE/-nightly/} -f https://dist.mxnet.io/python/${MXNET_PACKAGE/#mxnet-nightly-/}; \
fi

# Install Horovod.
README.rst: 1 addition & 1 deletion

@@ -304,7 +304,7 @@ See `Run Horovod <docs/running.rst>`_ for more details, including RoCE/InfiniBand tweaks and tips for dealing with hangs.

9. To run in a LSF HPC cluster (e.g. Summit), see `LSF <docs/lsf.rst>`_.

- 10. To run on Hadoop Yarn, see `TonY <https://github.com/linkedin/TonY/>`_
+ 10. To run on Hadoop Yarn, see `TonY <https://github.com/linkedin/TonY/>`_.

Gloo
----
cmake/Modules/FindMxnet.cmake: 20 additions & 16 deletions

@@ -6,6 +6,7 @@
# Mxnet_LIBRARIES
# Mxnet_COMPILE_FLAGS
# Mxnet_USE_MKLDNN
+ # Mxnet_USE_ONEDNN
# Mxnet_USE_CUDA
# Mxnet_VERSION
# Mxnet_CXX11
@@ -15,37 +16,42 @@ list(APPEND CMAKE_PREFIX_PATH ${Mxnet_ROOT})
set(Mxnet_COMPILE_FLAGS "")

set(ENV{PYTHONPATH} "${PROJECT_SOURCE_DIR}/cmake:$ENV{PYTHONPATH}")
- execute_process(COMMAND ${PY_EXE} -c "import os; import mxnet as mx; import build_utils; print(mx.__version__); print(mx.libinfo.find_include_path()); print(' '.join(mx.libinfo.find_lib_path())); print(build_utils.is_mx_mkldnn()); print(build_utils.is_mx_cuda())"
+ execute_process(COMMAND ${PY_EXE} -c "import os; import mxnet as mx; import build_utils; print(mx.__version__); print(mx.libinfo.find_include_path()); print(' '.join(mx.libinfo.find_lib_path())); print(build_utils.is_mx_mkldnn()); print(build_utils.is_mx_onednn()); print(build_utils.is_mx_cuda())"
OUTPUT_VARIABLE Mxnet_OUTPUT OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET)
string(REGEX REPLACE "\n" ";" Mxnet_OUTPUT "${Mxnet_OUTPUT}")
list(LENGTH Mxnet_OUTPUT LEN)
if (LEN EQUAL "5")
if (LEN EQUAL "6")
list(GET Mxnet_OUTPUT 0 Mxnet_VERSION)
list(GET Mxnet_OUTPUT 1 Mxnet_INCLUDE_DIRS)
list(GET Mxnet_OUTPUT 2 Mxnet_LIBRARIES)
string(REPLACE " " ";" Mxnet_LIBRARIES "${Mxnet_LIBRARIES}")
list(GET Mxnet_OUTPUT 3 Mxnet_USE_MKLDNN)
- list(GET Mxnet_OUTPUT 4 Mxnet_USE_CUDA)
+ list(GET Mxnet_OUTPUT 4 Mxnet_USE_ONEDNN)
+ list(GET Mxnet_OUTPUT 5 Mxnet_USE_CUDA)
string(TOUPPER ${Mxnet_USE_MKLDNN} Mxnet_USE_MKLDNN)
+ string(TOUPPER ${Mxnet_USE_ONEDNN} Mxnet_USE_ONEDNN)
string(TOUPPER ${Mxnet_USE_CUDA} Mxnet_USE_CUDA)
- if (Mxnet_USE_MKLDNN)
- if (NOT EXISTS ${Mxnet_INCLUDE_DIRS}/mkldnn AND NOT EXISTS ${Mxnet_INCLUDE_DIRS}/onednn)
+ if (Mxnet_USE_MKLDNN OR Mxnet_USE_ONEDNN)
+ if (Mxnet_USE_MKLDNN AND EXISTS ${Mxnet_INCLUDE_DIRS}/mkldnn)
+ set(Mxnet_COMPILE_FLAGS "${Mxnet_COMPILE_FLAGS} -DMXNET_USE_MKLDNN=1 -DMXNET_USE_ONEDNN=0")
+ list(APPEND Mxnet_INCLUDE_DIRS "${Mxnet_INCLUDE_DIRS}/mkldnn")
+ elseif (Mxnet_USE_ONEDNN AND EXISTS ${Mxnet_INCLUDE_DIRS}/onednn)
+ set(Mxnet_COMPILE_FLAGS "${Mxnet_COMPILE_FLAGS} -DMXNET_USE_MKLDNN=0 -DMXNET_USE_ONEDNN=1")
+ list(APPEND Mxnet_INCLUDE_DIRS "${Mxnet_INCLUDE_DIRS}/onednn")
+ else()
if (Mxnet_FIND_REQUIRED)
set(MSG_LEVEL "FATAL_ERROR")
else()
set(MSG_LEVEL "WARNING")
endif()
set(MXNET_FOUND FALSE)
- message(${MSG_LEVEL} "MXNet was found with mkl-dnn / onednn support but mkldnn / onednn header files are missing. Please install MXNet with mkldnn / onednn header files.")
+ if (Mxnet_USE_MKLDNN)
+ message(${MSG_LEVEL} "MXNet was found with mkl-dnn support but mkldnn header files are missing. Please install MXNet with mkldnn header files.")
+ elseif (Mxnet_USE_ONEDNN)
+ message(${MSG_LEVEL} "MXNet was found with onednn support but onednn header files are missing. Please install MXNet with onednn header files.")
+ endif()
return()
endif()
- if (EXISTS ${Mxnet_INCLUDE_DIRS}/mkldnn)
- set(Mxnet_COMPILE_FLAGS "${Mxnet_COMPILE_FLAGS} -DMXNET_USE_MKLDNN=1 -DMXNET_USE_ONEDNN=0")
- list(APPEND Mxnet_INCLUDE_DIRS "${Mxnet_INCLUDE_DIRS}/mkldnn")
- elseif (EXISTS ${Mxnet_INCLUDE_DIRS}/onednn)
- set(Mxnet_COMPILE_FLAGS "${Mxnet_COMPILE_FLAGS} -DMXNET_USE_MKLDNN=0 -DMXNET_USE_ONEDNN=1")
- list(APPEND Mxnet_INCLUDE_DIRS "${Mxnet_INCLUDE_DIRS}/onednn")
- endif()
else()
set(Mxnet_COMPILE_FLAGS "${Mxnet_COMPILE_FLAGS} -DMXNET_USE_MKLDNN=0 -DMXNET_USE_ONEDNN=0")
endif()
@@ -71,6 +77,4 @@ else()
set(Mxnet_COMPILE_FLAGS "${Mxnet_COMPILE_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
endif()

- mark_as_advanced(Mxnet_INCLUDE_DIRS Mxnet_LIBRARIES Mxnet_COMPILE_FLAGS Mxnet_USE_MKLDNN Mxnet_USE_CUDA Mxnet_VERSION)
-
- mark_as_advanced(Mxnet_INCLUDE_DIRS Mxnet_LIBRARIES Mxnet_COMPILE_FLAGS Mxnet_USE_MKLDNN Mxnet_USE_CUDA Mxnet_VERSION)
+ mark_as_advanced(Mxnet_INCLUDE_DIRS Mxnet_LIBRARIES Mxnet_COMPILE_FLAGS Mxnet_USE_MKLDNN Mxnet_USE_ONEDNN Mxnet_USE_CUDA Mxnet_VERSION)
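For orientation, the probe that FindMxnet.cmake runs now prints six newline-separated fields instead of five, which is why the check became `LEN EQUAL "6"`. A minimal Python sketch of that probe and its parse — assuming mxnet and the repo's `cmake/build_utils.py` are importable:

```python
# Hedged re-creation of the FindMxnet.cmake probe: six newline-separated
# fields, matching the new `LEN EQUAL "6"` check. Assumes mxnet and
# cmake/build_utils.py are on PYTHONPATH.
import mxnet as mx
import build_utils

fields = [
    mx.__version__,                          # Mxnet_VERSION
    mx.libinfo.find_include_path(),          # Mxnet_INCLUDE_DIRS
    ' '.join(mx.libinfo.find_lib_path()),    # Mxnet_LIBRARIES
    str(build_utils.is_mx_mkldnn()),         # Mxnet_USE_MKLDNN
    str(build_utils.is_mx_onednn()),         # Mxnet_USE_ONEDNN (new field)
    str(build_utils.is_mx_cuda()),           # Mxnet_USE_CUDA
]
output = '\n'.join(fields)
assert len(output.split('\n')) == 6         # cmake: if (LEN EQUAL "6")
print(output)
```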
cmake/build_utils.py: 18 additions & 12 deletions

@@ -37,38 +37,44 @@ def is_mx_cuda():
return False

def is_mx_mkldnn():
+ return is_mx_dnn('MKLDNN')
+
+ def is_mx_onednn():
+ return is_mx_dnn('ONEDNN')
+
+ def is_mx_dnn(dnn_flavour: str):
"""
- Detects if MXNet is built with MKLDNN or oneDNN support.
- MXNet ≥ 2.0.0 uses oneDNN (renamed from MKLDNN) but still calls the feature 'MKLDNN'.
+ Detects if MXNet is built with the given DNN flavour (MKLDNN or oneDNN) support.
+ MXNet ≥ 2.0.0 uses oneDNN (renamed from MKLDNN); < 2.0.0 uses MKLDNN.
"""
+ dnn_flavour_lower = dnn_flavour.lower()
+ dnn_flavour = dnn_flavour.upper()
try:
from mxnet import runtime
features = runtime.Features()
- return features.is_enabled('MKLDNN')
+ return features.is_enabled(dnn_flavour)
except Exception:
- msg = f'INFO: Cannot detect if MKLDNN / ONEDNN is enabled in MXNet. Please ' \
- f'set MXNET_USE_MKLDNN=1 if MKLDNN / MXNET_USE_ONEDNN=1 if ONEDNN is ' \
+ msg = f'INFO: Cannot detect if {dnn_flavour} is enabled in MXNet. Please ' \
+ f'set MXNET_USE_{dnn_flavour}=1 if {dnn_flavour} is ' \
f'enabled in your MXNet build.'
if 'linux' not in sys.platform:
# MKLDNN / oneDNN is only enabled by default in MXNet Linux build. Return
# False by default for non-linux build but still allow users to
# enable it by using MXNET_USE_MKLDNN / MXNET_USE_ONEDNN env variable.
- print(msg)
- return os.environ.get(f'MXNET_USE_MKLDNN', '0') == '1' or \
- os.environ.get(f'MXNET_USE_ONEDNN', '0') == '1'
+ print(msg, file=sys.stderr)
+ return os.environ.get(f'MXNET_USE_{dnn_flavour}', '0') == '1'
else:
try:
import mxnet as mx
mx_libs = mx.libinfo.find_lib_path()
for mx_lib in mx_libs:
output = subprocess.check_output(['readelf', '-d', mx_lib])
- if 'mkldnn' in str(output):
+ if dnn_flavour_lower in str(output):
return True
return False
except Exception:
- print(msg)
- return os.environ.get(f'MXNET_USE_MKLDNN', '0') == '1' or \
- os.environ.get(f'MXNET_USE_ONEDNN', '0') == '1'
+ print(msg, file=sys.stderr)
+ return os.environ.get(f'MXNET_USE_{dnn_flavour}', '0') == '1'

def get_nvcc_bin():
cuda_home = os.environ.get('HOROVOD_CUDA_HOME', '/usr/local/cuda')
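The detection strategy this refactor implements — runtime feature flag first, then a `readelf` scan of the shared libraries on Linux, then an environment-variable override — compresses to the following hedged sketch (restructured for brevity, not the diff's exact code):

```python
# Compressed sketch of the detection strategy in is_mx_dnn above:
# 1) ask mxnet.runtime for the feature flag, 2) on Linux, grep the shared
# libraries with readelf, 3) fall back to an env-var override.
import os
import subprocess
import sys

def detect_dnn(dnn_flavour: str) -> bool:
    dnn_flavour = dnn_flavour.upper()
    try:
        from mxnet import runtime
        return runtime.Features().is_enabled(dnn_flavour)
    except Exception:
        pass
    if 'linux' in sys.platform:
        try:
            import mxnet as mx
            for lib in mx.libinfo.find_lib_path():
                out = subprocess.check_output(['readelf', '-d', lib])
                if dnn_flavour.lower() in str(out):
                    return True
            return False
        except Exception:
            pass
    # e.g. MXNET_USE_ONEDNN=1 for a non-default build
    return os.environ.get(f'MXNET_USE_{dnn_flavour}', '0') == '1'

print(detect_dnn('ONEDNN'))
```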
docs/summary.rst: 2 additions & 0 deletions

@@ -296,6 +296,8 @@ See `Run Horovod <running.rst>`_ for more details, including RoCE/InfiniBand tweaks and tips for dealing with hangs.

9. To run in a LSF HPC cluster (e.g. Summit), see `LSF <lsf.rst>`_.

+ 10. To run on Hadoop Yarn, see `TonY <https://github.com/linkedin/TonY/>`_.
+
Gloo
----
`Gloo <https://github.com/facebookincubator/gloo>`_ is an open source collective communications library developed by Facebook.
horovod/common/controller.cc: 1 addition & 0 deletions

@@ -434,6 +434,7 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down,
// order consistently across workers.
for (auto& response : response_list.responses()) {
if ((response.response_type() == Response::ResponseType::ALLREDUCE ||
+ response.response_type() == Response::ResponseType::ALLGATHER ||
response.response_type() == Response::ResponseType::ADASUM ||
response.response_type() == Response::ResponseType::ALLTOALL) &&
(int)response.devices().size() == size_) {
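Context for this one-liner: allgather concatenates along the first dimension, so each rank may contribute a different first-dim size, and allgather responses must now pass the same cross-worker ordering check as allreduce, adasum, and alltoall. A hedged two-worker illustration (hypothetical run, not part of the diff):

```python
# Hypothetical 2-process run (horovodrun -np 2 python this_file.py).
# Allgather concatenates along dim 0; ranks may contribute different
# first-dim sizes, which is why allgather responses carry per-rank sizes.
import torch
import horovod.torch as hvd

hvd.init()
x = torch.ones(hvd.rank() + 1, 3)  # rank 0: (1, 3); rank 1: (2, 3)
y = hvd.allgather(x)               # both ranks receive (3, 3)
print(hvd.rank(), y.shape)
```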
horovod/common/response_cache.cc: 6 additions & 1 deletion

@@ -192,6 +192,7 @@ void ResponseCache::put(const Response& response, TensorQueue& tensor_queue, bool joined) {
joined);
}

+ int64_t global_size = response.tensor_sizes().size() / response.tensor_names().size();
// If response is fused, split back into individual responses
if (response.tensor_names().size() > 1) {
int64_t i = 0;
@@ -200,7 +201,11 @@ void ResponseCache::put(const Response& response, TensorQueue& tensor_queue, bool joined) {
new_response.add_tensor_name(name);
new_response.set_response_type(response.response_type());
new_response.set_devices(response.devices());
- new_response.add_tensor_size(response.tensor_sizes()[i]);
+ // For allreduce, adasum, and alltoall, tensor_sizes holds the num_elements of the response.
+ // For allgather, tensor_sizes holds the first dims from all ranks.
+ for (int64_t j = 0; j < global_size; ++j) {
+ new_response.add_tensor_size(response.tensor_sizes()[i * global_size + j]);
+ }
new_response.set_tensor_type(response.tensor_type());
new_response.set_prescale_factor(response.prescale_factor());
new_response.set_postscale_factor(response.postscale_factor());
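The index arithmetic above splits a fused response's flat `tensor_sizes` back into per-tensor chunks of `global_size` entries each. A small Python model of the same slicing, with made-up numbers:

```python
# Hedged model of the C++ splitting logic above, using illustrative data.
# A fused allgather response over 2 tensors on 3 ranks stores 6 sizes:
# [t0_rank0, t0_rank1, t0_rank2, t1_rank0, t1_rank1, t1_rank2].
tensor_names = ['t0', 't1']
tensor_sizes = [1, 2, 3, 4, 5, 6]

global_size = len(tensor_sizes) // len(tensor_names)  # 3 ranks

for i, name in enumerate(tensor_names):
    # Same indexing as tensor_sizes()[i * global_size + j] in the diff.
    sizes = [tensor_sizes[i * global_size + j] for j in range(global_size)]
    print(name, sizes)   # t0 [1, 2, 3]; t1 [4, 5, 6]
```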
horovod/spark/common/params.py: 8 additions & 0 deletions

@@ -28,6 +28,7 @@ class EstimatorParams(Params):
'number of parallel worker processes to read train data')
val_reader_num_workers = Param(Params._dummy(), 'val_reader_num_workers',
'number of parallel worker processes to read validation data')
+ reader_pool_type = Param(Params._dummy(), 'reader_pool_type', 'type of worker pool to read data')
optimizer = Param(Params._dummy(), 'optimizer', 'optimizer')
model = Param(Params._dummy(), 'model', 'model')
backend = Param(Params._dummy(), 'backend', 'backend')
@@ -122,6 +123,7 @@ def __init__(self):
transformation_fn=None,
train_reader_num_workers=2,
val_reader_num_workers=2,
+ reader_pool_type='process',
label_shapes=None)

def _check_params(self, metadata):
@@ -309,6 +311,12 @@ def setValReaderNumWorker(self, value):
def getValReaderNumWorker(self):
return self.getOrDefault(self.val_reader_num_workers)

+ def setReaderPoolType(self, value):
+ return self._set(reader_pool_type=value)
+
+ def getReaderPoolType(self):
+ return self.getOrDefault(self.reader_pool_type)
+
def setLabelShapes(self, value):
return self._set(label_shapes=value)

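A hedged sketch of the new accessor pair on any Horovod Spark estimator — estimator construction elided; 'thread' and 'process' map to petastorm's thread and process reader pools:

```python
# Hedged sketch: exercising the new reader_pool_type param accessors.
# `estimator` stands in for any HorovodEstimator subclass instance.
estimator.setReaderPoolType('thread')            # switch from the default 'process'
assert estimator.getReaderPoolType() == 'thread'

# Fluent style also works, since Spark's Params._set returns self:
estimator = estimator.setReaderPoolType('process')
```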
horovod/spark/keras/estimator.py: 3 additions & 0 deletions

@@ -158,6 +158,8 @@ class KerasEstimator(HorovodEstimator, KerasEstimatorParamsReadable,
high enough, or users need to apply transformation such as
decompression or data augmentation on raw data.
val_reader_num_workers: Similar to the train_reader_num_workers.
+ reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
+ Should be one of ['thread', 'process']. Defaults to 'process'.
"""

custom_objects = Param(Params._dummy(), 'custom_objects', 'custom objects')
@@ -194,6 +196,7 @@ def __init__(self,
transformation_fn=None,
train_reader_num_workers=None,
val_reader_num_workers=None,
+ reader_pool_type=None,
label_shapes=None,
checkpoint_callback=None):

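Putting the new constructor argument together (#2903) — a hedged end-to-end sketch where the model, optimizer, store, and DataFrame are placeholders, not part of this diff:

```python
# Hedged sketch of the new reader_pool_type constructor argument.
# model, optimizer, store, and train_df are assumed to be defined elsewhere.
import horovod.spark.keras as hvd_keras

keras_estimator = hvd_keras.KerasEstimator(
    num_proc=2,
    model=model,                  # assumed: a compiled tf.keras model
    optimizer=optimizer,          # assumed: a tf.keras optimizer
    loss='mse',
    store=store,                  # assumed: a horovod.spark.common Store
    feature_cols=['features'],
    label_cols=['label'],
    reader_pool_type='thread',    # new: petastorm thread pool instead of 'process'
)
keras_model = keras_estimator.fit(train_df)   # train_df: an assumed Spark DataFrame
```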
horovod/spark/keras/remote.py: 3 additions & 2 deletions

@@ -54,6 +54,7 @@ def RemoteTrainer(estimator, metadata, keras_utils, run_id, dataset_idx):
# Data reader parameters
train_reader_worker_count = estimator.getTrainReaderNumWorker()
val_reader_worker_count = estimator.getValReaderNumWorker()
+ reader_pool_type = estimator.getReaderPoolType()

# Model parameters
input_shapes, output_shapes = estimator.get_model_shapes()
@@ -214,7 +215,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
with reader_factory(remote_store.train_data_path,
num_epochs=None,
cur_shard=hvd.rank(),
- reader_pool_type='process',
+ reader_pool_type=reader_pool_type,
workers_count=train_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
@@ -224,7 +225,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
with reader_factory(remote_store.val_data_path,
num_epochs=None,
cur_shard=hvd.rank(),
- reader_pool_type='process',
+ reader_pool_type=reader_pool_type,
workers_count=val_reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
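The `reader_factory` in this trainer resolves to petastorm's `make_reader`/`make_batch_reader`, both of which accept `reader_pool_type`. A hedged standalone sketch of the call the trainer now parameterizes, using a hypothetical dataset path:

```python
# Hedged sketch of the petastorm call the trainer now parameterizes.
# file:///tmp/train_data is a hypothetical petastorm dataset path.
from petastorm import make_batch_reader

with make_batch_reader('file:///tmp/train_data',
                       num_epochs=1,
                       cur_shard=0,                # hvd.rank() in the real trainer
                       shard_count=2,              # hvd.size() in the real trainer
                       reader_pool_type='thread',  # was hardcoded to 'process'
                       workers_count=2) as reader:
    for batch in reader:
        print(batch)
        break
```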
horovod/spark/lightning/estimator.py: 3 additions & 0 deletions

@@ -150,6 +150,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
high enough, or users need to apply transformation such as
decompression or data augmentation on raw data.
val_reader_num_workers: Similar to the train_reader_num_workers.
+ reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
+ Should be one of ['thread', 'process']. Defaults to 'process'.
"""

input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
@@ -193,6 +195,7 @@ def __init__(self,
transformation_fn=None,
train_reader_num_workers=None,
val_reader_num_workers=None,
+ reader_pool_type=None,
label_shapes=None,
inmemory_cache_all=False):

horovod/spark/lightning/remote.py: 5 additions & 4 deletions

@@ -56,6 +56,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_rows
# Data reader parameters
train_reader_worker_count = estimator.getTrainReaderNumWorker()
val_reader_worker_count = estimator.getValReaderNumWorker()
+ reader_pool_type = estimator.getReaderPoolType()

# Utility functions
deserialize = deserialize_fn()
@@ -126,9 +127,9 @@ def train(serialized_model):
# print(row_group)

with make_petastorm_reader(model, remote_store.train_data_path, 'train_dataloader',
- train_reader_worker_count), \
+ train_reader_worker_count, reader_pool_type), \
make_petastorm_reader(model, remote_store.val_data_path, 'val_dataloader',
- val_reader_worker_count, should_validate):
+ val_reader_worker_count, reader_pool_type, should_validate):

trainer.fit(model)

@@ -168,7 +169,7 @@ def on_sanity_check_end(self, trainer, model):
def _make_petastorm_reader_fn(transformation, schema_fields, batch_size, calculate_shuffle_buffer_size, dataloader_cls):

@contextlib.contextmanager
- def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count, should_read=True):
+ def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count, reader_pool_type, should_read=True):
from petastorm import TransformSpec, make_reader, make_batch_reader
import horovod.torch as hvd

@@ -201,7 +202,7 @@ def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count
with reader_factory(data_path,
num_epochs=1,
cur_shard=hvd.rank(),
- reader_pool_type='process',
+ reader_pool_type=reader_pool_type,
workers_count=reader_worker_count,
shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER,
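The lightning trainer threads `reader_pool_type` through a context manager rather than a plain call. A hedged miniature of that plumbing, with a stub reader standing in for petastorm:

```python
# Hedged miniature of the make_petastorm_reader plumbing: the new
# reader_pool_type argument flows from the estimator into the context
# manager. The stub reader replaces petastorm for illustration.
import contextlib

@contextlib.contextmanager
def make_petastorm_reader(data_path, reader_worker_count, reader_pool_type,
                          should_read=True):
    if not should_read:
        yield            # validation may be skipped entirely
        return
    reader = f'{reader_pool_type}-pool reader on {data_path}'  # stub
    try:
        yield reader
    finally:
        pass             # a real petastorm reader would be closed here

with make_petastorm_reader('/tmp/train', 2, 'thread') as train_reader, \
     make_petastorm_reader('/tmp/val', 2, 'thread', should_read=True) as val_reader:
    print(train_reader, val_reader)
```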
horovod/spark/torch/estimator.py: 3 additions & 0 deletions

@@ -142,6 +142,8 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
high enough, or users need to apply transformation such as
decompression or data augmentation on raw data.
val_reader_num_workers: Similar to the train_reader_num_workers.
+ reader_pool_type: Type of worker pool used to parallelize reading data from the dataset.
+ Should be one of ['thread', 'process']. Defaults to 'process'.
"""

input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
@@ -185,6 +187,7 @@ def __init__(self,
transformation_fn=None,
train_reader_num_workers=None,
val_reader_num_workers=None,
+ reader_pool_type=None,
label_shapes=None,
inmemory_cache_all=False):

(Diff for the remaining 7 of the 22 changed files not shown.)
