multi processing and fork fix (apache#8677)
piiswrong authored and eric-haibin-lin committed Dec 3, 2017
1 parent 9c6c258 commit b3ac0ab
Showing 22 changed files with 651 additions and 98 deletions.
2 changes: 1 addition & 1 deletion dmlc-core
32 changes: 27 additions & 5 deletions include/mxnet/base.h
@@ -143,7 +143,8 @@ struct Context {
enum DeviceType {
kCPU = cpu::kDevMask,
kGPU = gpu::kDevMask,
kCPUPinned = 3
kCPUPinned = 3,
kCPUShared = 5,
};
/*! \brief the device type we run the op on */
DeviceType dev_type;
@@ -155,10 +156,17 @@
* \brief Get corresponding device mask
* \return cpu::kDevMask or gpu::kDevMask
*/
inline int dev_mask() const {
if (dev_type == kCPUPinned) return cpu::kDevMask;
inline DeviceType dev_mask() const {
if (dev_type == kCPUPinned || dev_type == kCPUShared) return kCPU;
return dev_type;
}
/*!
* \brief Returns dev_id for kGPU, 0 otherwise
*/
inline int real_dev_id() const {
if (dev_type == kGPU) return dev_id;
return 0;
}
/*!
* \brief Comparator, used to enable Context as std::map key.
* \param b another context to compare
@@ -200,7 +208,7 @@ struct Context {
return true;
}
/*! \brief the maximal device type */
static const int32_t kMaxDevType = 4;
static const int32_t kMaxDevType = 6;
/*! \brief the maximal device index */
static const int32_t kMaxDevID = 16;
/*!
@@ -223,6 +231,12 @@ struct Context {
* \return Pinned CPU context. -1 for current GPU.
*/
inline static Context CPUPinned(int32_t dev_id = -1);
/*!
* Create a CPU shared memory context.
* \param dev_id dummy device id.
* \return CPU shared memory context.
*/
inline static Context CPUShared(int32_t dev_id = 0);
/*!
* Create a context from string of the format [cpu|gpu|cpu_pinned](n)
* \param str the string pattern
@@ -273,7 +287,7 @@ inline Context Context::Create(DeviceType dev_type, int32_t dev_id) {
ctx.dev_type = dev_type;
if (dev_id < 0) {
ctx.dev_id = 0;
if (dev_type != kCPU) {
if (dev_type & kGPU) {
#if MXNET_USE_CUDA
CHECK_EQ(cudaGetDevice(&ctx.dev_id), cudaSuccess);
#else
@@ -293,6 +307,10 @@ inline Context Context::CPUPinned(int32_t dev_id) {
return Create(kCPUPinned, dev_id);
}

inline Context Context::CPUShared(int32_t dev_id) {
return Create(kCPUShared, dev_id);
}

inline Context Context::GPU(int32_t dev_id) {
return Create(kGPU, dev_id);
}
@@ -313,6 +331,8 @@ inline Context Context::FromString(std::string str) {
ret = GPU(id);
} else if (type == "cpu_pinned") {
ret = CPUPinned(id);
} else if (type == "cpu_shared") {
ret = CPUShared(id);
} else {
LOG(FATAL) << "Invalid context string " << str;
}
@@ -329,6 +349,8 @@ inline std::ostream& operator<<(std::ostream &out, const Context &ctx) {
out << "gpu(";
} else if (ctx.dev_type == Context::kCPUPinned) {
out << "cpu_pinned(";
} else if (ctx.dev_type == Context::kCPUShared) {
out << "cpu_shared(";
} else {
out << "unknown(";
}
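To make the addition concrete, here is a minimal usage sketch for the new device type, assuming only what this header declares: CPUShared() mirrors the existing CPU()/CPUPinned() factories, dev_mask() folds pinned and shared memory back to the CPU mask, and FromString() accepts the new "cpu_shared(n)" spelling.

#include <mxnet/base.h>

void ContextDemo() {
  using mxnet::Context;
  Context shared = Context::CPUShared(0);
  CHECK_EQ(shared.dev_mask(), Context::kCPU);      // pinned and shared memory both report the CPU mask
  CHECK_EQ(shared.real_dev_id(), 0);               // only kGPU contexts report their actual device id
  Context parsed = Context::FromString("cpu_shared(0)");
  CHECK_EQ(parsed.dev_type, Context::kCPUShared);  // round-trips through the string form
}
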
20 changes: 20 additions & 0 deletions include/mxnet/c_api.h
@@ -2013,6 +2013,26 @@ MXNET_DLL int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** ar
mx_uint grid_dim_z, mx_uint block_dim_x,
mx_uint block_dim_y, mx_uint block_dim_z,
mx_uint shared_mem);
/*!
* \brief Get shared memory handle from NDArray
* \param handle NDArray handle.
* \param shared_pid output PID
* \param shared_id output shared memory id.
*/
MXNET_DLL int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid,
int* shared_id);
/*!
* \brief Reconstruct NDArray from shared memory handle
* \param shared_pid shared PID
* \param shared_id shared memory id
* \param shape pointer to NDArray dimensions
* \param ndim number of NDArray dimensions
* \param dtype data type of NDArray
* \param out constructed NDArray
*/
MXNET_DLL int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const mx_uint *shape,
mx_uint ndim, int dtype, NDArrayHandle *out);


#ifdef __cplusplus
}
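A hedged sketch of how the two new calls could be used to hand an NDArray from one process to another. Error-code checking, the fork, and the channel that carries the two integers are omitted; the helper names, the shape, and the dtype value (0, i.e. float32) are illustrative assumptions rather than part of this commit.

#include <mxnet/c_api.h>

// Producer process: expose an existing cpu_shared NDArray as a (pid, id) pair.
int ShareArray(NDArrayHandle src, int* shared_pid, int* shared_id) {
  return MXNDArrayGetSharedMemHandle(src, shared_pid, shared_id);
}

// Consumer process: rebuild an NDArray over the same shared memory segment.
int AttachArray(int shared_pid, int shared_id, NDArrayHandle* out) {
  const mx_uint shape[2] = {2, 3};  // illustrative shape
  return MXNDArrayCreateFromSharedMem(shared_pid, shared_id, shape, 2,
                                      0 /* float32 */, out);
}
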
12 changes: 12 additions & 0 deletions include/mxnet/engine.h
@@ -112,6 +112,18 @@ class MXNET_API Engine {
* \return 0 when success, -1 when failure happens.
*/
virtual void NotifyShutdown() = 0;
/*!
*\brief Stop all workers in the engine
*/
virtual void Stop() {
LOG(FATAL) << "Engine cannot be stopped";
}
/*!
* \brief Restart all workers in the engine
*/
virtual void Start() {
LOG(FATAL) << "Engine cannot be restarted";
}
/*!
* \brief Allocate a new variable, the variable can then
* be used to schedule the operation concurrently via dependency
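The new Stop()/Start() hooks exist so engine worker threads can be torn down before a fork() and brought back afterwards; the base-class defaults abort, so this only works with an engine implementation that overrides them. The sketch below shows the intended pattern, with the fork wiring assumed rather than taken from this diff.

#include <unistd.h>

#include <mxnet/engine.h>

pid_t ForkWithEngineStopped() {
  mxnet::Engine::Get()->Stop();    // quiesce worker threads so they are not duplicated into the child
  pid_t pid = fork();
  mxnet::Engine::Get()->Start();   // restart workers; this line runs in both parent and child
  return pid;
}
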
27 changes: 27 additions & 0 deletions include/mxnet/ndarray.h
@@ -160,6 +160,14 @@ class NDArray {
Mkl_mem_ = std::make_shared<MKLMemHolder>();
#endif
}
/*! \brief create ndarray from shared memory */
NDArray(int shared_pid, int shared_id, const TShape& shape, int dtype)
: ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)), shape_(shape),
dtype_(dtype), storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
#if MKL_EXPERIMENTAL == 1
Mkl_mem_ = std::make_shared<MKLMemHolder>();
#endif
}

/*!
* \brief constructing a static NDArray of non-default storage that shares data with TBlob
@@ -317,6 +325,13 @@
}
return true;
}
/*! \brief get storage handle */
inline Storage::Handle storage_handle() const {
CHECK(!is_none());
CHECK_EQ(storage_type(), kDefaultStorage);
CheckAndAlloc();
return ptr_->shandle;
}
/*!
* \brief Block until all the pending write operations with respect
* to current NDArray are finished, and read can be performed.
@@ -682,6 +697,18 @@
shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_);
storage_shape = data.shape_;
}

Chunk(int shared_pid, int shared_id, const TShape& shape, int dtype)
: static_data(false), delay_alloc(false) {
var = Engine::Get()->NewVariable();
ctx = Context::CPUShared(0);
      shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
shandle.ctx = ctx;
shandle.shared_pid = shared_pid;
shandle.shared_id = shared_id;
Storage::Get()->Alloc(&shandle);
storage_shape = shape;
}
// Constructor for a non-default storage chunk
Chunk(NDArrayStorageType storage_type_, const TShape &storage_shape_, Context ctx_,
bool delay_alloc_, int dtype, const std::vector<int> &aux_types_,
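Putting the new Chunk constructor and the storage_handle() accessor together, a shared-memory round trip might look like the sketch below. In real use the (pid, id) pair would be sent to a different process; running both halves in one function is purely illustrative, and it assumes the cpu_shared storage manager fills in shared_pid/shared_id when the source array is allocated.

#include <mxnet/ndarray.h>

mxnet::NDArray SharedRoundTrip(const mxnet::TShape& shape, int dtype) {
  // Producer side: allocate the array directly in cpu_shared memory.
  mxnet::NDArray src(shape, mxnet::Context::CPUShared(0), false, dtype);
  mxnet::Storage::Handle h = src.storage_handle();
  // Consumer side: attach a new NDArray to the same segment via (pid, id).
  return mxnet::NDArray(h.shared_pid, h.shared_id, shape, dtype);
}
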
23 changes: 22 additions & 1 deletion include/mxnet/storage.h
@@ -50,14 +50,35 @@ class Storage {
* \brief Context information about device and ID.
*/
Context ctx;
/*!
* \brief Id for IPC shared memory
*/
int shared_pid{-1};
int shared_id{-1};
};
/*!
* \brief Allocate a new contiguous memory for a given size.
* \param size Total size of memory in bytes.
* \param ctx Context information about the device and ID.
* \return Handle struct.
*/
virtual Handle Alloc(size_t size, Context ctx) = 0;
Handle Alloc(size_t size, Context ctx) {
Handle hd;
hd.size = size;
hd.ctx = ctx;
this->Alloc(&hd);
return hd;
}
/*!
* \brief Allocate a new contiguous memory for a given size.
* \param handle handle initialized with size and ctx
*/
virtual void Alloc(Handle* handle) = 0;
/*!
* \brief Increase ref counter on shared memory.
* \param handle handle to shared memory.
*/
virtual void SharedIncrementRefCount(Handle handle) = 0;
/*!
* \brief Free storage.
* \param handle Handle struct.
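The abstract Alloc(size, ctx) is now a concrete wrapper over a new virtual Alloc(Handle*), which lets callers pre-fill handle fields before allocation. That is exactly what the shared-memory Chunk in ndarray.h does: it sets shared_pid/shared_id first so the storage manager can attach to an existing IPC segment instead of creating a fresh one (assumed behaviour of the cpu_shared manager). A condensed sketch:

#include <mxnet/storage.h>

mxnet::Storage::Handle AttachSharedSegment(int shared_pid, int shared_id, size_t size) {
  mxnet::Storage::Handle hd;
  hd.size = size;
  hd.ctx = mxnet::Context::CPUShared(0);
  hd.shared_pid = shared_pid;         // pre-filled so Alloc() attaches rather than allocates
  hd.shared_id = shared_id;
  mxnet::Storage::Get()->Alloc(&hd);  // new overload consumes the pre-filled handle
  return hd;
}
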
51 changes: 41 additions & 10 deletions python/mxnet/context.py
@@ -62,8 +62,8 @@ class Context(object):
"""
# static class variable
default_ctx = None
devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned'}
devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3}
devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'}
devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5}
def __init__(self, device_type, device_id=0):
if isinstance(device_type, Context):
self.device_typeid = device_type.device_typeid
@@ -128,14 +128,13 @@ def cpu(device_id=0):
Examples
----------
>>> with mx.Context('cpu', 1):
>>> with mx.cpu():
... cpu_array = mx.nd.ones((2, 3))
>>> cpu_array.context
cpu(1)
>>> with mx.cpu(1):
... cpu_array = mx.nd.ones((2, 3))
cpu(0)
>>> cpu_array = mx.nd.ones((2, 3), ctx=mx.cpu())
>>> cpu_array.context
cpu(1)
cpu(0)
Parameters
----------
Expand All @@ -151,6 +150,36 @@ def cpu(device_id=0):
return Context('cpu', device_id)


def cpu_pinned(device_id=0):
"""Returns a CPU pinned memory context. Copying from CPU pinned memory to GPU
is faster than from normal CPU memory.
This function is a short cut for ``Context('cpu_pinned', device_id)``.
Examples
----------
>>> with mx.cpu_pinned():
... cpu_array = mx.nd.ones((2, 3))
>>> cpu_array.context
cpu_pinned(0)
>>> cpu_array = mx.nd.ones((2, 3), ctx=mx.cpu_pinned())
>>> cpu_array.context
cpu_pinned(0)
Parameters
----------
device_id : int, optional
The device id of the device. `device_id` is not needed for CPU.
This is included to make interface compatible with GPU.
Returns
-------
context : Context
The corresponding CPU pinned memory context.
"""
return Context('cpu_pinned', device_id)


def gpu(device_id=0):
"""Returns a GPU context.
@@ -159,12 +188,14 @@ def gpu(device_id=0):
Examples
----------
>>> with mx.Context('gpu', 1):
>>> cpu_array = mx.nd.ones((2, 3))
>>> cpu_array.context
cpu(0)
>>> with mx.gpu(1):
... gpu_array = mx.nd.ones((2, 3))
>>> gpu_array.context
gpu(1)
>>> with mx.gpu(1):
... gpu_array = mx.nd.ones((2, 3))
>>> gpu_array = mx.nd.ones((2, 3), ctx=mx.gpu(1))
>>> gpu_array.context
gpu(1)
