Skip to content
Permalink
Browse files

Refactor operations into separate components by framework (#826)

  • Loading branch information...
tgaddair committed Mar 9, 2019
1 parent 91d5ae8 commit aa605d6de1e07d1184e8793b6d4f65cee9d4efb3
@@ -48,10 +48,18 @@ Status Status::InvalidArgument(std::string message) {
return Status(StatusType::INVALID_ARGUMENT, message);
}

Status Status::InProgress() {
  // Factory for the transient status reported while an asynchronous
  // operation has been submitted but not yet completed. Carries no message.
  return Status(StatusType::IN_PROGRESS, std::string());
}

bool Status::ok() const {
  // Success check: true only when this status carries no error.
  return StatusType::OK == type_;
}

bool Status::in_progress() const {
  // True while the associated asynchronous operation is still running.
  return StatusType::IN_PROGRESS == type_;
}

StatusType Status::type() const {
  // Accessor for the underlying status category.
  return this->type_;
}
@@ -16,24 +16,69 @@
#ifndef HOROVOD_COMMON_H
#define HOROVOD_COMMON_H

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

#include "message.h"

namespace horovod {
namespace common {

// Activity names, see Horovod Timeline for more details.
// Each string labels one phase of a collective operation so the timeline
// trace can attribute time spent to queuing, copies, and communication.
#define INIT_FUSION_BUFFER "INIT_FUSION_BUFFER"
#define WAIT_FOR_DATA "WAIT_FOR_DATA"
#define WAIT_FOR_OTHER_TENSOR_DATA "WAIT_FOR_OTHER_TENSOR_DATA"
#define ALLOCATE_OUTPUT "ALLOCATE_OUTPUT"
#define MPI_CROSS_ALLGATHER "MPI_CROSS_ALLGATHER"
#define MPI_ALLGATHER "MPI_ALLGATHER"
#define INIT_NCCL "INIT_NCCL"
#define QUEUE "QUEUE"
#define MEMCPY_IN_FUSION_BUFFER "MEMCPY_IN_FUSION_BUFFER"
#define MEMCPY_IN_HOST_BUFFER "MEMCPY_IN_HOST_BUFFER"
#define MEMCPY_IN_SHARED_BUFFER "MEMCPY_IN_SHARED_BUFFER"
#define MPI_ALLREDUCE "MPI_ALLREDUCE"
#define MEMCPY_OUT_HOST_BUFFER "MEMCPY_OUT_HOST_BUFFER"
#define NCCL_ALLREDUCE "NCCL_ALLREDUCE"
#define MEMCPY_OUT_FUSION_BUFFER "MEMCPY_OUT_FUSION_BUFFER"
#define MPI_BCAST "MPI_BCAST"
#define NCCL_REDUCESCATTER "NCCL_REDUCESCATTER"
#define NCCL_ALLGATHER "NCCL_ALLGATHER"
#define NCCL_REDUCE "NCCL_REDUCE"
#define NCCL_BCAST "NCCL_BCAST"
#define COPY_ALLGATHER_OUTPUT "COPY_ALLGATHER_OUTPUT"
#define ALLOCATE_SHARED_BUFFER "ALLOCATE_SHARED_BUFFER"

// Device ID used for CPU.
// Sentinel value: non-negative ids denote GPU devices (see
// TensorTableEntry::device).
#define CPU_DEVICE_ID (-1)

// List of supported frameworks.
enum Framework {
  TENSORFLOW,
  PYTORCH,
  MXNET
};

// Status categories returned by Horovod operations. IN_PROGRESS marks an
// asynchronous operation that has been submitted but not yet completed.
// NOTE: the flattened diff text retained both the pre- and post-change enum
// line, which would be a redeclaration; only the updated definition is kept.
enum StatusType { OK, UNKNOWN_ERROR, PRECONDITION_ERROR, ABORTED, INVALID_ARGUMENT, IN_PROGRESS };

enum DeviceType {
  CPU,
  GPU
};

// Scopes over which a collective may run: all ranks, ranks on one node,
// or one rank per node.
enum Communicator {
  GLOBAL = 0,
  LOCAL = 1,
  CROSS = 2
};

// Human-readable name for a communicator scope, e.g. for error messages.
inline std::string CommunicatorName(Communicator comm) {
  if (comm == GLOBAL) {
    return "global";
  }
  if (comm == LOCAL) {
    return "local";
  }
  if (comm == CROSS) {
    return "cross";
  }
  return "<unknown>";
}

class Status {
public:
Status();
@@ -42,7 +87,9 @@ class Status {
static Status PreconditionError(std::string message);
static Status Aborted(std::string message);
static Status InvalidArgument(std::string message);
static Status InProgress();
bool ok() const;
bool in_progress() const;
StatusType type() const;
const std::string& reason() const;

@@ -109,6 +156,33 @@ class OpContext {
virtual ~OpContext() = default;
};

// A callback to call after the MPI communication completes. Since the
// allreduce and allgather ops are asynchronous, this callback is what resumes
// computation after the reduction is completed.
// The Status argument carries the outcome (success, error, or in-progress).
using StatusCallback = std::function<void(const Status&)>;

// Table storing Tensors to be reduced, keyed by unique name.
// This table contains everything necessary to do the reduction.
struct TensorTableEntry {
// Name of the tensor. Serves as the key that matches requests across ranks.
std::string tensor_name;
// Operation context (framework-specific allocation/status services).
std::shared_ptr<OpContext> context;
// Input tensor.
std::shared_ptr<Tensor> tensor;
// Pre-allocated output tensor.
std::shared_ptr<Tensor> output;
// Root rank for broadcast operation. Unused by allreduce/allgather.
int root_rank = 0;
// Event indicating that data is ready.
std::shared_ptr<ReadyEvent> ready_event;
// GPU to do reduction on, or CPU_DEVICE_ID in case of CPU.
int device = CPU_DEVICE_ID;
// A callback to call with the status once the operation finishes.
StatusCallback callback;
};
// Pending operations keyed by unique tensor name.
using TensorTable = std::unordered_map<std::string, TensorTableEntry>;

} // namespace common
} // namespace horovod

@@ -0,0 +1,133 @@
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
// Modifications copyright (C) 2019 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef HOROVOD_GLOBAL_STATE_H
#define HOROVOD_GLOBAL_STATE_H

#include <queue>
#include <thread>

#include "fusion_buffer_manager.h"
#include "parameter_manager.h"
#include "timeline.h"

namespace horovod {
namespace common {

// Table for storing Tensor metadata on rank zero. This is used for error
// checking, stall checking and size calculations, as well as determining
// when a reduction is ready to be done (when all nodes are ready to do it).
// Maps tensor name -> (requests collected so far, time point used for the
// stall check on that tensor).
using MessageTable = std::unordered_map<
std::string,
std::tuple<std::vector<Request>, std::chrono::steady_clock::time_point>>;

// The global state required for the MPI ops.
//
// MPI is a library that stores a lot of global per-program state and often
// requires running on a single thread. As a result, we have to have a single
// background thread responsible for all MPI operations, and communicate with
// that background thread through global state.
struct HorovodGlobalState {
// An atomic flag that is test-and-set when the background thread is started.
// This ensures that only one background thread is spawned.
std::atomic_flag initialize_flag = ATOMIC_FLAG_INIT;

// A mutex that needs to be used whenever MPI operations are done.
std::mutex mutex;

// Tensors waiting to be allreduced or allgathered.
TensorTable tensor_table;

// Background thread running MPI communication.
std::thread background_thread;

// Whether the background thread should shutdown.
std::atomic_bool shut_down {false};

// Whether Horovod should finalize MPI (only if it has initialized it).
bool should_finalize = false;

// Time point when coordinator last checked for stalled tensors.
std::chrono::steady_clock::time_point last_stall_check;

// Flag indicating whether to perform stall tensor check.
bool perform_stall_check = true;

// Timeline writer.
Timeline timeline;

// Flag indicating whether to mark cycles in the timeline.
bool mark_cycles_in_timeline = false;

// Auto-tuning state for runtime parameters (see parameter_manager.h).
ParameterManager param_manager;

// Encapsulates the fusion buffers, handles resizing and auto-tuning of buffer size.
FusionBufferManager fusion_buffer;

// Time point when last cycle started.
std::chrono::steady_clock::time_point last_cycle_start;

// Whether MPI_Init has been completed on the background thread.
std::atomic_bool initialization_done {false};

// The MPI rank, local rank, size, local size, flag indicating whether MPI
// multi-threading is supported, ranks from which the MPI communicator will
// be made and the communicator itself.
int rank = 0;
int local_rank = 0;
int cross_rank = 0;
int size = 1;
int local_size = 1;
int cross_size = 1;
bool mpi_threads_supported = false;
bool is_homogeneous = false;
std::vector<int> ranks;

// COMM_WORLD ranks of processes running on this node.
std::vector<int> local_comm_ranks;

// Number of ranks running on each node.
std::vector<int> local_sizes;

// Pointer to shared buffer for allgather.
void* shared_buffer = nullptr;

// Current shared buffer size, in bytes.
int64_t shared_buffer_size = 0;

// Queue of MPI requests waiting to be sent to the coordinator node.
std::queue<Request> message_queue;

// Only exists on the coordinator node (rank zero). Maintains a count of
// how many nodes are ready to allreduce every tensor (keyed by tensor
// name) and time point when tensor started allreduce op.
std::unique_ptr<MessageTable> message_table;

~HorovodGlobalState() {
// Make sure that the destructor of the background thread is safe to
// call. If a thread is still joinable (not detached or complete) its
// destructor cannot be called.
if (background_thread.joinable()) {
shut_down = true;
background_thread.join();
}
}
};

} // namespace common
} // namespace horovod

#endif //HOROVOD_GLOBAL_STATE_H
@@ -33,7 +33,8 @@ enum DataType {
HOROVOD_FLOAT16 = 6,
HOROVOD_FLOAT32 = 7,
HOROVOD_FLOAT64 = 8,
HOROVOD_BOOL = 9
HOROVOD_BOOL = 9,
HOROVOD_BYTE = 10,
};

const std::string& DataType_Name(DataType value);
@@ -0,0 +1,81 @@
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
// Modifications copyright (C) 2019 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mpi_context.h"

namespace horovod {
namespace common {

MPI_Datatype MPIContext::GetMPIDataType(const std::shared_ptr<Tensor> tensor) {
  // Convenience overload: delegate to the dtype-based mapping.
  const auto dtype = tensor->dtype();
  return GetMPIDataType(dtype);
}

MPI_Datatype MPIContext::GetMPIDataType(const DataType dtype) {
  // Translate a Horovod dtype into the MPI datatype used on the wire.
  // FLOAT16 has no predefined MPI type, so it maps to the custom
  // mpi_float16_t registered on this context.
  switch (dtype) {
  case HOROVOD_UINT8:   return MPI_UINT8_T;
  case HOROVOD_INT8:    return MPI_INT8_T;
  case HOROVOD_UINT16:  return MPI_UINT16_T;
  case HOROVOD_INT16:   return MPI_INT16_T;
  case HOROVOD_INT32:   return MPI_INT32_T;
  case HOROVOD_INT64:   return MPI_INT64_T;
  case HOROVOD_FLOAT16: return mpi_float16_t;
  case HOROVOD_FLOAT32: return MPI_FLOAT;
  case HOROVOD_FLOAT64: return MPI_DOUBLE;
  case HOROVOD_BOOL:    return MPI_C_BOOL;
  case HOROVOD_BYTE:    return MPI_BYTE;
  default:
    // Unknown dtype is a programming error, not a runtime condition.
    throw std::logic_error("Type " + DataType_Name(dtype) +
                           " is not supported in MPI mode.");
  }
}

MPI_Op MPIContext::GetMPISumOp(DataType dtype) {
  // FLOAT16 needs the custom half-precision sum op registered on this
  // context; all other dtypes use the built-in MPI_SUM.
  if (dtype == HOROVOD_FLOAT16) {
    return mpi_float16_sum;
  }
  return MPI_SUM;
}

MPI_Comm MPIContext::GetMPICommunicator(Communicator comm) {
  // Resolve a communicator scope to the concrete MPI communicator
  // held by this context.
  if (comm == GLOBAL) {
    return mpi_comm;
  }
  if (comm == LOCAL) {
    return local_comm;
  }
  if (comm == CROSS) {
    return cross_comm;
  }
  throw std::logic_error("Communicator " + CommunicatorName(comm) +
                         " is not supported in MPI mode.");
}

int MPIContext::GetMPITypeSize(DataType dtype) {
  // Ask MPI for the size, in bytes, of the wire type backing this dtype.
  int size_bytes;
  MPI_Type_size(GetMPIDataType(dtype), &size_bytes);
  return size_bytes;
}

} // namespace common
} // namespace horovod
Oops, something went wrong.

0 comments on commit aa605d6

Please sign in to comment.
You can’t perform that action at this time.