Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NVTX tracing hooks for profiling with Nsight Systems #2723

Merged
merged 15 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Expand Up @@ -8,9 +8,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- Added NVTX tracing hooks for profiling with Nsight Systems. ([#2723](https://github.com/horovod/horovod/pull/2723))

### Changed

- Changed `alltoall` to return the received splits as a second return value if non-uniform splits are sent ([#2631](https://github.com/horovod/horovod/pull/2631))
- Changed `alltoall` to return the received splits as a second return value if non-uniform splits are sent. ([#2631](https://github.com/horovod/horovod/pull/2631))

### Deprecated

Expand Down
16 changes: 16 additions & 0 deletions CMakeLists.txt
Expand Up @@ -234,6 +234,22 @@ message(FATAL_ERROR "You should not mix NCCL and MPI GPU due to a possible deadl
"HOROVOD_ALLOW_MIXED_GPU_IMPL environment variable to '1'.")
endif()

# NVTX
# NVTX support is probed unless HOROVOD_WITHOUT_NVTX=1 is set in the environment.
# With HOROVOD_WITH_NVTX=1 the package is requested as REQUIRED, so a failed
# lookup aborts configuration instead of silently building without tracing.
if (NOT "$ENV{HOROVOD_WITHOUT_NVTX}" STREQUAL "1")
    if ("$ENV{HOROVOD_WITH_NVTX}" STREQUAL "1")
        set(NVTX_REQUIRED "REQUIRED")
    else()
        set(NVTX_REQUIRED "")
    endif()

    find_package(NVTX ${NVTX_REQUIRED})

    if (NVTX_FOUND)
        set(HAVE_NVTX TRUE)
        add_definitions(-DHAVE_NVTX=1)
        include_directories(SYSTEM ${NVTX_INCLUDE_DIRS})
        list(APPEND LINKER_LIBS ${NVTX_LIBRARIES})
        # The NVTX helper translation unit is only compiled when NVTX is available.
        list(APPEND SOURCES "${PROJECT_SOURCE_DIR}/horovod/common/nvtx_op_range.cc")
    endif()
endif()

# Gloo
if (NOT "$ENV{HOROVOD_WITHOUT_GLOO}" STREQUAL "1")
if(HAVE_MPI)
Expand Down
30 changes: 30 additions & 0 deletions cmake/Modules/FindNVTX.cmake
@@ -0,0 +1,30 @@
# Try to find NVTX
#
# NVTX comes with the CUDA toolkit so we use that root dir to search for the header-only variation of NVTX.
# Alternatively an explicit path can be given via the variable HOROVOD_NVTX_INCLUDE
# (its cache entry is seeded from the environment variable of the same name).
#
# The following are set after configuration is done:
#  NVTX_FOUND
#  NVTX_INCLUDE_DIRS
#  NVTX_LIBRARIES

set(HOROVOD_NVTX_INCLUDE $ENV{HOROVOD_NVTX_INCLUDE} CACHE PATH "Folder containing NVIDIA NVTX3 headers")

list(APPEND NVTX_ROOT ${CUDA_TOOLKIT_ROOT_DIR})
# Compatible layer for CMake <3.12 (newer versions consult <PackageName>_ROOT
# on their own, see policy CMP0074):
list(APPEND CMAKE_PREFIX_PATH ${NVTX_ROOT})

# Only the header is looked up; the result is stored in the cache entry NVTX_INCLUDE_DIR.
find_path(NVTX_INCLUDE_DIR
          NAMES nvtx3/nvToolsExt.h
          HINTS ${HOROVOD_NVTX_INCLUDE})

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NVTX DEFAULT_MSG NVTX_INCLUDE_DIR)

if (NVTX_FOUND)
    set(NVTX_INCLUDE_DIRS ${NVTX_INCLUDE_DIR})
    # -ldl for dlopen, dlclose:
    set(NVTX_LIBRARIES ${CMAKE_DL_LIBS})
    message(STATUS "Found NVTX (include: ${NVTX_INCLUDE_DIRS}, library: ${NVTX_LIBRARIES})")
    # mark_as_advanced() applies to cache entries. The cache entry here is
    # NVTX_INCLUDE_DIR (created by find_path); NVTX_INCLUDE_DIRS and
    # NVTX_LIBRARIES are ordinary variables and must not be listed.
    mark_as_advanced(NVTX_INCLUDE_DIR)
endif()
24 changes: 16 additions & 8 deletions horovod/common/common.cc
Expand Up @@ -20,35 +20,35 @@
#include <sstream>
#include <cassert>
#include <cstring>
#include <utility>
#include <limits.h>

namespace horovod {
namespace common {

Status::Status() = default;

Status::Status(StatusType type, std::string reason) {
type_ = type;
reason_ = reason;
Status::Status(StatusType type, std::string reason)
: type_(type), reason_(std::move(reason)) {
}

// Factory for a default-constructed Status.
Status Status::OK() {
  return Status{};
}

Status Status::UnknownError(std::string message) {
Status Status::UnknownError(const std::string& message) {
return Status(StatusType::UNKNOWN_ERROR, message);
}

Status Status::PreconditionError(std::string message) {
Status Status::PreconditionError(const std::string& message) {
return Status(StatusType::PRECONDITION_ERROR, message);
}

Status Status::Aborted(std::string message) {
Status Status::Aborted(const std::string& message) {
return Status(StatusType::ABORTED, message);
}

Status Status::InvalidArgument(std::string message) {
Status Status::InvalidArgument(const std::string& message) {
return Status(StatusType::INVALID_ARGUMENT, message);
}

Expand Down Expand Up @@ -82,7 +82,7 @@ void TensorShape::AppendShape(TensorShape& other) {
}
}

const std::string TensorShape::DebugString() const {
std::string TensorShape::DebugString() const {
std::stringstream args;
args << "[";
for (auto it = shape_.begin(); it != shape_.end(); ++it) {
Expand Down Expand Up @@ -199,5 +199,13 @@ void parse_and_set_affinity(const char* affinity, int local_size, int local_rank
free(affinity_copy);
}

// Hands `status` to this entry's callback (when one is set) and afterwards
// releases this entry's hold on its NVTX range.
void TensorTableEntry::FinishWithCallback(const Status& status) {
  // Callback can be null if the rank sent Join request.
  const bool has_callback = static_cast<bool>(callback);
  if (has_callback) {
    callback(status);
  }

  nvtx_op_range.End();
}

} // namespace common
} // namespace horovod
25 changes: 17 additions & 8 deletions horovod/common/common.h
Expand Up @@ -24,6 +24,7 @@
#include <unordered_map>

#include "message.h"
#include "nvtx_op_range.h"

namespace horovod {
namespace common {
Expand Down Expand Up @@ -92,6 +93,7 @@ namespace common {
#define HOROVOD_ADASUM_MPI_CHUNK_SIZE "HOROVOD_ADASUM_MPI_CHUNK_SIZE"
#define HOROVOD_THREAD_AFFINITY "HOROVOD_THREAD_AFFINITY"
#define HOROVOD_DISABLE_GROUP_FUSION "HOROVOD_DISABLE_GROUP_FUSION"
#define HOROVOD_DISABLE_NVTX_RANGES "HOROVOD_DISABLE_NVTX_RANGES"

// String constant for gloo interface.
#define GLOO_DEFAULT_IFACE ""
Expand Down Expand Up @@ -137,10 +139,10 @@ class Status {
public:
Status();
static Status OK();
static Status UnknownError(std::string message);
static Status PreconditionError(std::string message);
static Status Aborted(std::string message);
static Status InvalidArgument(std::string message);
static Status UnknownError(const std::string& message);
static Status PreconditionError(const std::string& message);
static Status Aborted(const std::string& message);
static Status InvalidArgument(const std::string& message);
static Status InProgress();
bool ok() const;
bool in_progress() const;
Expand All @@ -149,7 +151,7 @@ class Status {

private:
StatusType type_ = StatusType::OK;
std::string reason_ = "";
std::string reason_;
Status(StatusType type, std::string reason);
};

Expand All @@ -174,7 +176,7 @@ class TensorShape {
void AddDim(int64_t dim);
void AppendShape(TensorShape& other);

const std::string DebugString() const;
std::string DebugString() const;
int dims() const;
int64_t dim_size(int idx) const;
int64_t num_elements() const;
Expand Down Expand Up @@ -226,7 +228,7 @@ class OpContext {
virtual Status AllocateOutput(int output_index, TensorShape shape,
std::shared_ptr<Tensor>* tensor) {
if (output_index == 0) {
return AllocateOutput(shape, tensor);
return AllocateOutput(std::move(shape), tensor);
} else {
throw std::logic_error("output_index != 0 not supported");
}
Expand All @@ -243,7 +245,7 @@ class OpContext {
using StatusCallback = std::function<void(const Status&)>;

// Table storing Tensors to be reduced, keyed by unique name.
// This table contains everything necessary to do the reduction.
// This table contains everything necessary to do the distributed operation.
struct TensorTableEntry {
// Name of the tensor.
std::string tensor_name;
Expand All @@ -261,13 +263,20 @@ struct TensorTableEntry {
int device = CPU_DEVICE_ID;
// A callback to call with the status.
StatusCallback callback;
// If we build with NVTX support: A range marking the start
// and end of the distributed op for this tensor (may be
// shared by multiple tensors).
SharedNvtxOpRange nvtx_op_range;

// Alltoall splits (if tensor is for an Alltoall operation)
// Note: splits are stored in TensorTableEntry to avoid N^2
// storage complexity of collecting all worker split arrays
// on coordinator rank.
std::vector<int32_t> splits;
std::shared_ptr<Tensor> received_splits;

// Execute callback and end NVTX range
void FinishWithCallback(const Status& status);
};
using TensorTable = std::unordered_map<std::string, TensorTableEntry>;

Expand Down
36 changes: 36 additions & 0 deletions horovod/common/nvtx_op_range.cc
@@ -0,0 +1,36 @@
#include "nvtx_op_range.h"

#if HAVE_NVTX

namespace horovod {
namespace common {

// Creates the "HorovodOps" NVTX domain and registers one string handle per
// RegisteredNvtxOp enumerator, stored at the enumerator's integer index.
NvtxOpsHandle::NvtxOpsHandle() noexcept
    : domain_(nvtxDomainCreateA("HorovodOps")), op_names_{} {
  // The macro stringifies the enumerator, so the registered name is always
  // spelled exactly like the enum value it is stored under.
#define HVD_REGISTER_NVTX_OP_NAME(op)                                          \
  op_names_[static_cast<int>(RegisteredNvtxOp::op)] =                          \
      nvtxDomainRegisterStringA(domain_, #op)

  HVD_REGISTER_NVTX_OP_NAME(HorovodAllreduce);
  HVD_REGISTER_NVTX_OP_NAME(HorovodGroupedAllreduce);
  HVD_REGISTER_NVTX_OP_NAME(HorovodAllgather);
  HVD_REGISTER_NVTX_OP_NAME(HorovodBroadcast);
  HVD_REGISTER_NVTX_OP_NAME(HorovodAlltoall);

#undef HVD_REGISTER_NVTX_OP_NAME
}

// Tears the handle down by delegating to Disable(), which destroys the NVTX
// domain if it has not been destroyed already.
NvtxOpsHandle::~NvtxOpsHandle() {
Disable();
}

// Destroys the NVTX domain and clears the handle. Safe to call repeatedly:
// once domain_ is null, further calls do nothing.
void NvtxOpsHandle::Disable() {
  if (domain_ == nullptr) {
    return;  // already disabled
  }

  nvtxDomainDestroy(domain_);
  domain_ = nullptr;
}

// Out-of-class definition of the static NvtxOpsHandle that every NvtxOpRange
// starts and ends its range through.
NvtxOpsHandle NvtxOpRange::nvtx_ops_handle;

} // namespace common
} // namespace horovod

#endif // HAVE_NVTX
100 changes: 100 additions & 0 deletions horovod/common/nvtx_op_range.h
@@ -0,0 +1,100 @@
#ifndef HOROVOD_NVTX_OP_RANGE_H
#define HOROVOD_NVTX_OP_RANGE_H

// int64_t is used by both the NVTX and the stub implementation below.
#include <cstdint>

#if HAVE_NVTX
#include <memory>
#include <nvtx3/nvToolsExt.h>
#endif // HAVE_NVTX

namespace horovod {
namespace common {

// Operations for which an NVTX range can be opened. The enumerators are
// numbered consecutively from zero.
enum class RegisteredNvtxOp : int {
  HorovodAllreduce = 0,
  HorovodGroupedAllreduce,
  HorovodAllgather,
  HorovodBroadcast,
  HorovodAlltoall,
  // Insert new enum values above this line
  // END is a sentinel, not an operation: its value equals the number of
  // operations listed above.
  END,
};

#if HAVE_NVTX
// Owns an NVTX domain together with the string handles registered for each
// RegisteredNvtxOp, and starts/ends ranges inside that domain.
//
// After Disable() (or if no domain is held) StartRange() returns
// invalid_range_id and EndRange() does nothing.
class NvtxOpsHandle {
public:
  NvtxOpsHandle() noexcept;
  ~NvtxOpsHandle();

  // The destructor destroys domain_, so a copy would destroy the same domain
  // a second time. Deleting the copy operations also suppresses the implicit
  // move operations.
  NvtxOpsHandle(const NvtxOpsHandle&) = delete;
  NvtxOpsHandle& operator=(const NvtxOpsHandle&) = delete;

  // Starts a range labelled with the registered name of `msg` and carrying
  // `payload` as a 64-bit integer payload.
  // Returns the id to pass to EndRange(), or invalid_range_id when the handle
  // is disabled or `msg` is not a real operation (e.g. RegisteredNvtxOp::END).
  nvtxRangeId_t StartRange(RegisteredNvtxOp msg, int64_t payload) {
    const int op_index = static_cast<int>(msg);
    // Guard the op_names_ lookup: END and anything outside the enum would
    // index past the array.
    if (domain_ == nullptr || op_index < 0 ||
        op_index >= static_cast<int>(RegisteredNvtxOp::END)) {
      return invalid_range_id;
    }

    nvtxEventAttributes_t eventAttrib = {0};
    eventAttrib.version = NVTX_VERSION;
    eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
    eventAttrib.messageType = NVTX_MESSAGE_TYPE_REGISTERED;
    eventAttrib.message.registered = op_names_[op_index];
    eventAttrib.payloadType = NVTX_PAYLOAD_TYPE_INT64;
    eventAttrib.payload.llValue = payload;

    return nvtxDomainRangeStartEx(domain_, &eventAttrib);
  }

  // Ends a range previously returned by StartRange(). Ignored when the handle
  // is disabled or `range_id` is invalid_range_id.
  void EndRange(nvtxRangeId_t range_id) {
    if (domain_ == nullptr || range_id == invalid_range_id) {
      return;
    }
    nvtxDomainRangeEnd(domain_, range_id);
  }

  // Destroys the domain; subsequent StartRange()/EndRange() calls are no-ops.
  void Disable();

  // Sentinel id returned by StartRange() when no range was started.
  static constexpr nvtxRangeId_t invalid_range_id = 0xfffffffffffffffful;

private:
  nvtxDomainHandle_t domain_; // nullptr if disabled
  // One registered string per operation, indexed by the enumerator's value.
  nvtxStringHandle_t op_names_[static_cast<int>(RegisteredNvtxOp::END)];
};

// Scoped NVTX range: the constructor starts a range through the shared
// nvtx_ops_handle and the destructor ends it.
class NvtxOpRange {
public:
  // Starts a range for operation `msg` with integer payload `payload`.
  NvtxOpRange(RegisteredNvtxOp msg, int64_t payload)
      : range_id_(nvtx_ops_handle.StartRange(msg, payload)) {
  }

  ~NvtxOpRange() { nvtx_ops_handle.EndRange(range_id_); }

  // Each object ends its range exactly once in the destructor; a copy would
  // carry the same range id and end that range a second time.
  NvtxOpRange(const NvtxOpRange&) = delete;
  NvtxOpRange& operator=(const NvtxOpRange&) = delete;

  // Handle through which all ranges are started and ended.
  static NvtxOpsHandle nvtx_ops_handle;

private:
  nvtxRangeId_t range_id_;
};

class SharedNvtxOpRange {
public:
void Start(RegisteredNvtxOp msg, int64_t payload) {
p_ = std::make_shared<NvtxOpRange>(msg, payload);
}

void End() {
p_.reset();
}

private:
std::shared_ptr<NvtxOpRange> p_;
};

#else // HAVE_NVTX
// Stand-in used when HAVE_NVTX is not set: same member functions as the NVTX
// implementation, but neither of them does anything.
class SharedNvtxOpRange {
public:
  // Arguments are accepted and ignored.
  void Start(RegisteredNvtxOp /*msg*/, int64_t /*payload*/) {}
  void End() {}
};
#endif // HAVE_NVTX

} // namespace common
} // namespace horovod

#endif // HOROVOD_NVTX_OP_RANGE_H