Expose kWriteInplace for imperative execution (fcompute_ex and fstatefulcompute_ex) (apache#133)

* expose kWriteInplace to FComputeEx and FStatefulComputeEx

* refactor code

* remove duplicated test
eric-haibin-lin committed Jul 28, 2017
1 parent 88eaac6 commit 3b94a3c
Showing 3 changed files with 36 additions and 61 deletions.
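
For context, here is a minimal sketch (not part of this commit; the operator name and body are hypothetical) of how an FComputeEx / FStatefulComputeEx implementation is expected to consume the request vector now that the imperative dispatcher can pass kWriteInplace instead of always kWriteTo:

void MyOpForwardEx(const nnvm::NodeAttrs& attrs,
                   const OpContext& ctx,
                   const std::vector<NDArray>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<NDArray>& outputs) {
  CHECK_EQ(req.size(), outputs.size());
  if (req[0] == kNullOp) return;        // caller asked for no write at all
  if (req[0] == kWriteInplace) {
    // outputs[0] shares its variable (and storage) with one of the inputs:
    // update it in place rather than zero-filling a fresh output first.
  } else {  // kWriteTo
    // outputs[0] is a separate NDArray: write it from scratch.
  }
}

The sparse SGD updates changed below rely on exactly this distinction and now assert kWriteInplace instead of patching the request themselves.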
17 changes: 17 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -270,6 +270,21 @@ void SetDependency(std::vector<engine::VarHandle> *p_read_vars,
Engine::Get()->DeduplicateVarHandle(&read_vars, &write_vars);
}

inline void SetWriteInplaceReq(const std::vector<NDArray> &ndinputs,
                               const std::vector<NDArray> &ndoutputs,
                               std::vector<OpReqType> *req) {
  std::unordered_set<engine::VarHandle> in_vars;
  for (auto &nd : ndinputs) {
    in_vars.insert(nd.var());
  }
  for (size_t i = 0; i < ndoutputs.size(); i++) {
    // output NDArray shares the memory with the input NDArray
    if (in_vars.find(ndoutputs[i].var()) != in_vars.end()) {
      req->at(i) = kWriteInplace;
    }
  }
}

void PushFCompute(const FCompute& fn,
const nnvm::Op* op,
const nnvm::NodeAttrs& attrs,
@@ -332,6 +347,7 @@ void PushFComputeEx(const FComputeEx& fn,
engine::CallbackOnComplete(),
requested};
std::vector<OpReqType> req(ndoutputs.size(), kWriteTo);
SetWriteInplaceReq(ndinputs, ndoutputs, &req);
fn(attrs, opctx, ndinputs, req, ndoutputs);
if (ctx.dev_mask() == gpu::kDevMask) {
rctx.get_stream<gpu>()->Wait();
@@ -406,6 +422,7 @@ void PushOperator(const OpStatePtr& state,
engine::CallbackOnComplete on_complete) {
OpContext opctx{is_train, rctx, on_complete, requested};
std::vector<OpReqType> req(ndoutputs.size(), kWriteTo);
SetWriteInplaceReq(ndinputs, ndoutputs, &req);
fcompute_ex(state, opctx, ndinputs, req, ndoutputs);
if (exec_type == ExecType::kSync) {
if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
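
To illustrate the effect of the new SetWriteInplaceReq helper on the request vector, consider a hypothetical aliasing call (weight and grad are assumed to be existing NDArrays, e.g. from an imperative sgd_update whose output is the weight array itself):

std::vector<NDArray> ndinputs  = {weight, grad};
std::vector<NDArray> ndoutputs = {weight};  // output created from the same variable as an input
std::vector<OpReqType> req(ndoutputs.size(), kWriteTo);
SetWriteInplaceReq(ndinputs, ndoutputs, &req);
// req[0] is now kWriteInplace because ndoutputs[0].var() == ndinputs[0].var()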
45 changes: 17 additions & 28 deletions src/operator/optimizer_op-inl.h
@@ -93,13 +93,13 @@ struct SGDDnsRspKernel {
// IType is row sparse idx type
// i is the ith row in row sparse gradient
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, size_t row_length, DType* out, const DType* weight,
MSHADOW_XINLINE static void Map(int i, const index_t row_length, DType* out, const DType* weight,
const IType* grad_idx, const DType *grad_val,
const DType clip_gradient, const DType lr,
const DType wd, const DType rescale_grad) {
for (size_t j = 0; j < row_length; j++) {
uint64_t data_i = grad_idx[i] * row_length + j;
uint64_t grad_i = i * row_length + j;
for (index_t j = 0; j < row_length; j++) {
index_t data_i = grad_idx[i] * row_length + j;
index_t grad_i = i * row_length + j;
if (clip_gradient >= 0.0f) {
KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] -
(lr) * mshadow_op::clip::Map(rescale_grad * grad_val[grad_i], clip_gradient));
Expand All @@ -126,6 +126,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
CHECK_EQ(grad.storage_type(), kRowSparseStorage);
// if gradients are zeros, no weights are updated
if (!grad.storage_initialized() || req == kNullOp) return;
CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_update";
CHECK_GT(weight.shape_.Size(), 0);

MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
@@ -151,7 +152,7 @@
template<int req>
struct SGDRspDnsKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, const DType* weight,
MSHADOW_XINLINE static void Map(int i, const index_t num_cols, DType* out, const DType* weight,
const DType *grad, const DType clip_gradient, const DType lr,
const DType wd, const DType rescale_grad) {
bool contains_non_zeros = false;
@@ -191,6 +192,7 @@ inline void SGDUpdateRspDnsImpl(const SGDParam& param,
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights");
CHECK_EQ(weight.storage_type(), kRowSparseStorage);
if (req == kNullOp) return;
CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_update";
CHECK(weight.storage_initialized());
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
@@ -216,14 +218,9 @@ inline void SGDUpdateRspRspImpl(const SGDParam& param,
const OpReqType& req,
NDArray *out) {
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights");
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
// reuse dns rsp implementation when storage_shape == shape
TBlob out_blob = out->data();
SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, out_req, &out_blob);
SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, req, &out_blob);
}

template<typename xpu>
@@ -425,14 +422,14 @@ inline void MP_SGDMomUpdate(const nnvm::NodeAttrs& attrs,
template<int req>
struct SGDMomDnsRspDnsKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, size_t row_length, DType* out_data,
MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
DType* mom_data, const DType* weight_data, const IType* grad_idx,
const DType* grad_data, const DType clip_gradient, const DType momentum,
const DType lr, const DType wd, const DType rescale_grad) {
const DType rate = lr * wd;
for (size_t j = 0; j < row_length; j++) {
uint64_t data_i = grad_idx[i] * row_length + j;
uint64_t grad_i = i * row_length + j;
for (index_t j = 0; j < row_length; j++) {
index_t data_i = grad_idx[i] * row_length + j;
index_t grad_i = i * row_length + j;
if (clip_gradient >= 0.0f) {
mom_data[data_i] = momentum * mom_data[data_i]
- rate * weight_data[data_i]
@@ -461,6 +458,7 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param,
using namespace rowsparse;
Stream<xpu>* s = ctx.get_stream<xpu>();
if (!grad.storage_initialized() || req == kNullOp) return;
CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
CHECK_GT(weight.shape_.Size(), 0);
CHECK_GT(mom.shape_.Size(), 0);

@@ -487,7 +485,7 @@
template<int req>
struct SGDMomRspDnsKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, DType* mom,
MSHADOW_XINLINE static void Map(int i, index_t num_cols, DType* out, DType* mom,
const DType* weight, const DType *grad,
const DType clip_gradient, const DType momentum,
const DType lr, const DType wd, const DType rescale_grad) {
@@ -531,19 +529,15 @@ inline void SGDMomUpdateRspDnsImpl(const SGDMomParam& param,
Stream<xpu>* s = ctx.get_stream<xpu>();
CHECK_EQ(weight.storage_type(), kRowSparseStorage);
if (req == kNullOp) return;
CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
CHECK(weight.storage_initialized());
// fill mom with zero values if not initialized yet
if (!mom.storage_initialized()) {
NDArray mom_zeros = mom;
FillDnsZerosRspImpl(s, &mom_zeros);
}
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, {
MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
DType* weight_data = weight.data().dptr<DType>();
DType* grad_data = grad.dptr<DType>();
DType* mom_data = mom.data().dptr<DType>();
@@ -578,15 +572,10 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
NDArray mom_zeros = mom;
FillDnsZerosRspImpl(s, &mom_zeros);
}
// TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only
// feed in kWriteTo as req for all operators.
// For sgd we don't want to assign zeros to the output values when req == kWriteTo
auto out_req = req;
if (out_req == kWriteTo) out_req = kWriteInplace;
TBlob out_blob = out->data();
// reuse dns rsp implementation when storage_shape == shape
SGDMomUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
mom.data(), out_req, &out_blob);
mom.data(), req, &out_blob);
}

template<typename xpu>
35 changes: 2 additions & 33 deletions tests/python/unittest/test_sparse_ndarray.py
@@ -81,45 +81,14 @@ def check_sparse_nd_copy(from_stype, to_stype, shape):


def test_sparse_nd_basic():
def check_rsp_creation(values, indices, shape):
rsp = mx.nd.row_sparse(values, indices, shape)
dns = mx.nd.zeros(shape)
dns[1] = mx.nd.array(values[0])
dns[3] = mx.nd.array(values[1])
indices_np = mx.nd.array(indices, dtype='int64').asnumpy()
assert_almost_equal(rsp.indices.asnumpy(), indices_np)

def check_csr_creation(shape):
csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr')
assert_almost_equal(csr.indptr.asnumpy(), indptr)
assert_almost_equal(csr.indices.asnumpy(), indices)
assert_almost_equal(csr.data.asnumpy(), values)

def check_sparse_nd_rsp_aux():
def check_sparse_nd_basic_rsp():
storage_type = 'row_sparse'
shape = rand_shape_2d()
nd, (v, idx) = rand_sparse_ndarray(shape, storage_type)
assert(nd._num_aux == 1)
assert(nd.indices.dtype == np.int64)
assert(nd.stype == 'row_sparse')
assert_almost_equal(nd.indices.asnumpy(), idx)
assert_almost_equal(nd.data.asnumpy(), v)

shape = (4,2)
values = np.random.rand(2,2)
indices = np.array([1,3], dtype='int64')
check_rsp_creation(values, indices, shape)

values = mx.nd.array(np.random.rand(2,2))
indices = mx.nd.array([1,3], dtype='int64')
check_rsp_creation(values, indices, shape)

values = [[0.1, 0.2], [0.3, 0.4]]
indices = [1,3]
check_rsp_creation(values, indices, shape)

check_csr_creation(shape)
check_sparse_nd_rsp_aux()
check_sparse_nd_basic_rsp()


def test_sparse_nd_setitem():
