Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add new OP take (#4715)
Browse files Browse the repository at this point in the history
  • Loading branch information
WellyZhang authored and piiswrong committed Jan 18, 2017
1 parent d49d566 commit 8abf132
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 6 deletions.
46 changes: 44 additions & 2 deletions src/operator/tensor/indexing_op.cc
@@ -1,8 +1,8 @@
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2017 by Contributors
* \file indexing_op.cc
* \brief
* \author Siyi Li
* \author Siyi Li, Chi Zhang
*/

#include "./indexing_op.h"
Expand Down Expand Up @@ -48,5 +48,47 @@ NNVM_REGISTER_OP(_backward_Embedding)
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpBackward<cpu>);

// Registers the TakeParam parameter struct with dmlc.
DMLC_REGISTER_PARAMETER(TakeParam);

// Forward registration of the `take` operator together with its CPU kernel.
// Inputs are named "a" (source array) and "indices"; there is a single output.
NNVM_REGISTER_OP(take)
.MXNET_DESCRIBE("Take row vectors from an NDArray according to the indices."
                " For an input of index with shape (d1, ..., dK), the output"
                " shape is (d1, ..., dK, row_vector_length). All the input"
                " values should be integers in the range"
                " [0, column_vector_length).")
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(TakeParamParser<TakeParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"a", "indices"};
  })
.set_attr<nnvm::FInferShape>("FInferShape", TakeOpShape)
.set_attr<nnvm::FInferType>("FInferType", TakeOpType)
// One temp-space resource is requested for the operator.
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  })
.set_attr<FCompute>("FCompute<cpu>", TakeOpForward<cpu>)
// The gradient node `_backward_take` receives the output gradients followed by
// the forward `indices` input (n->inputs[1]), and inherits the forward attrs.
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
    heads.push_back(n->inputs[1]);
    return MakeGradNode("_backward_take", n, heads, n->attrs.dict);
  })
.add_argument("a", "Symbol", "The source array.")
.add_argument("indices", "Symbol", "The indices of the values to extract.")
.add_arguments(TakeParam::__FIELDS__());

// Backward registration for `take` with its CPU kernel. The node consumes two
// inputs and produces two outputs, is flagged as a backward operator, and
// requests one temp-space resource.
NNVM_REGISTER_OP(_backward_take)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& node_attrs) -> std::vector<ResourceRequest> {
    return {ResourceRequest::kTempSpace};
  })
.set_attr<FCompute>("FCompute<cpu>", TakeOpBackward<cpu>);
} // namespace op
} // namespace mxnet
10 changes: 8 additions & 2 deletions src/operator/tensor/indexing_op.cu
@@ -1,8 +1,8 @@
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2017 by Contributors
* \file indexing_op.cu
* \brief
* \author Siyi Li
* \author Siyi Li, Chi Zhang
*/

#include "./indexing_op.h"
Expand All @@ -13,6 +13,12 @@ NNVM_REGISTER_OP(Embedding)

// Attach the FCompute<gpu> implementation to the _backward_Embedding operator.
NNVM_REGISTER_OP(_backward_Embedding)
.set_attr<FCompute>("FCompute<gpu>", EmbeddingOpBackward<gpu>);

// Attach the FCompute<gpu> implementation to the take operator (forward pass).
NNVM_REGISTER_OP(take)
.set_attr<FCompute>("FCompute<gpu>", TakeOpForward<gpu>);

// Attach the FCompute<gpu> implementation to the _backward_take operator.
NNVM_REGISTER_OP(_backward_take)
.set_attr<FCompute>("FCompute<gpu>", TakeOpBackward<gpu>);
} // namespace op
} // namespace mxnet

172 changes: 170 additions & 2 deletions src/operator/tensor/indexing_op.h
@@ -1,8 +1,8 @@
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2017 by Contributors
* \file indexing_op.h
* \brief
* \author Bing Xu, Siyi Li
* \author Bing Xu, Siyi Li, Chi Zhang
*/
#ifndef MXNET_OPERATOR_TENSOR_INDEXING_OP_H_
#define MXNET_OPERATOR_TENSOR_INDEXING_OP_H_
Expand Down Expand Up @@ -158,6 +158,174 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs,
});
}

// Index and mode constants for the take operator. They live in a dedicated
// namespace so the short enumerator names cannot clash with other operators.
namespace take_ {
// Slots of the operator inputs: the source array and the index array.
enum TakeOpInputs {
  kArr = 0,
  kIdx = 1
};
// Slot of the single operator output.
enum TakeOpOutputs {
  kOut = 0
};
// Slot of the requested temp-space resource.
enum TakeOpResource {
  kTempSpace = 0
};
// Values accepted by the `mode` parameter.
enum TakeOpMode {
  kRaise = 0,
  kWrap = 1,
  kClip = 2
};
}  // namespace take_

// TODO(somebody): behaviors specified by params
/*!
 * \brief Parameters accepted by the take operator.
 */
struct TakeParam: public dmlc::Parameter<TakeParam> {
  /*! \brief axis of the data tensor to take along; lower bound 0, default 0 */
  int axis;
  /*! \brief out-of-bound policy: "raise", "wrap" or "clip"; default "raise" */
  int mode;
  DMLC_DECLARE_PARAMETER(TakeParam) {
    DMLC_DECLARE_FIELD(axis)
    .set_lower_bound(0)
    .set_default(0)
    .describe("the axis of data tensor to be taken.");
    DMLC_DECLARE_FIELD(mode)
    .add_enum("raise", take_::kRaise)
    .add_enum("wrap", take_::kWrap)
    .add_enum("clip", take_::kClip)
    .set_default(take_::kRaise)
    .describe("specify how out-of-bound indices behave.");
  }
};

/*!
 * \brief Attribute parser for take: initializes a PType from the attribute
 *  dictionary and raises a fatal log error for settings that are not
 *  implemented. The parsed struct is only validated; it is not stored back
 *  into attrs.
 * \tparam PType parameter struct exposing `axis` and `mode` members
 * \param attrs node attributes whose `dict` holds the raw settings
 */
template<typename PType>
inline void TakeParamParser(nnvm::NodeAttrs *attrs) {
  PType parsed;
  parsed.Init(attrs->dict);
  const bool axis_supported = (parsed.axis == 0);
  const bool mode_supported = (parsed.mode == take_::kRaise);
  if (!axis_supported) {
    LOG(FATAL) << "Axis other than 0 currently not supported.";
  }
  if (!mode_supported) {
    LOG(FATAL) << "Mode other than raise currently not supported.";
  }
}

/*!
 * \brief Shape inference for take.
 *
 *  The output shape is the index shape followed by every dimension of the
 *  source array except its first one:
 *  idx (d1, ..., dK) and arr (n, r1, ..., rM) give out (d1, ..., dK, r1, ..., rM).
 *
 * \param attrs node attributes (unused)
 * \param in_attrs input shapes, indexed by take_::kArr and take_::kIdx
 * \param out_attrs receives the single output shape
 * \return false when either input shape has zero dimensions (treated here as
 *  not yet known), true once the output shape has been written
 */
inline bool TakeOpShape(const nnvm::NodeAttrs& attrs,
                        std::vector<TShape> *in_attrs,
                        std::vector<TShape> *out_attrs) {
  using namespace mshadow;
  const TShape &arrshape = (*in_attrs)[take_::kArr];
  const TShape &idxshape = (*in_attrs)[take_::kIdx];
  // Both shapes are needed. Without the array check, `arrshape.ndim() - 1`
  // would be evaluated on a zero rank, yielding a wrong output rank and an
  // out-of-bounds copy below.
  if (arrshape.ndim() == 0 || idxshape.ndim() == 0) return false;

  out_attrs->clear();

  TShape oshape(idxshape.ndim() + arrshape.ndim() - 1);
  // leading dimensions come from the index array
  for (size_t i = 0; i < idxshape.ndim(); ++i) {
    oshape[i] = idxshape[i];
  }
  // trailing dimensions are the array's dimensions after the first one
  for (size_t i = 1; i < arrshape.ndim(); ++i) {
    oshape[idxshape.ndim() + i - 1] = arrshape[i];
  }
  out_attrs->push_back(oshape);
  return true;
}

/*!
 * \brief Type inference for take: every input and the output share one dtype.
 *
 *  The dtype of the index input (slot 1) is the reference and must already be
 *  specified. Inputs with an unspecified type (-1) are set to it; inputs with
 *  a different specified type trigger a check failure.
 *
 * \param attrs node attributes (unused)
 * \param in_type input type flags, updated in place
 * \param out_type receives the single output type flag
 * \return always true when no check fails
 */
inline bool TakeOpType(const nnvm::NodeAttrs& attrs,
                       std::vector<int> *in_type,
                       std::vector<int> *out_type) {
  CHECK_GE(in_type->size(), 2);
  const int ref_type = (*in_type)[1];
  CHECK_NE(ref_type, -1) << "idx must have specified type";
  for (index_t k = 0; k < in_type->size(); ++k) {
    int &slot = (*in_type)[k];
    if (slot != -1) {
      CHECK_EQ(slot, ref_type) << "This layer requires uniform type. "
                               << "Expected " << ref_type << " v.s. given "
                               << slot;
    } else {
      slot = ref_type;
    }
  }
  out_type->clear();
  out_type->push_back(ref_type);
  return true;
}

/*!
 * \brief Forward pass of take.
 *
 *  The index input is viewed as a flat 1-D tensor, the source array as a 2-D
 *  tensor (first dimension x product of the remaining ones) and the output as
 *  a 2-D tensor (number of indices x product of the remaining dimensions);
 *  the result is then computed with the mshadow `take` expression.
 *
 * \tparam xpu device type
 * \param attrs node attributes (unused)
 * \param ctx operator context supplying the stream
 * \param inputs source array (take_::kArr) and indices (take_::kIdx)
 * \param req write request for the output; must be kWriteTo
 * \param outputs single output blob (take_::kOut)
 */
template<typename xpu>
void TakeOpForward(const nnvm::NodeAttrs& attrs,
                   const OpContext& ctx,
                   const std::vector<TBlob>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  CHECK_EQ(req[take_::kOut], kWriteTo);
  CHECK_EQ(inputs.size(), 2);
  CHECK_EQ(outputs.size(), 1);

  const TBlob& arr_blob = inputs[take_::kArr];
  const TBlob& idx_blob = inputs[take_::kIdx];
  const TBlob& out_blob = outputs[take_::kOut];
  CHECK_GE(arr_blob.ndim(), 2)
    << "take layer expects its array's size to be at least 2. "
    << arr_blob.ndim()
    << " dimensional input is given instead";

  const TShape& idx_dims = idx_blob.shape_;
  const TShape& arr_dims = arr_blob.shape_;
  const TShape& out_dims = out_blob.shape_;
  const int idx_rank = idx_dims.ndim();

  Stream<xpu> *stream = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    // collapse each blob to the rank expected by the take expression
    const Shape<1> idx_flat = Shape1(idx_dims.ProdShape(0, idx_rank));
    const Shape<2> arr_flat = Shape2(arr_dims[0],
                                     arr_dims.ProdShape(1, arr_dims.ndim()));
    const Shape<2> out_flat = Shape2(out_dims.ProdShape(0, idx_rank),
                                     out_dims.ProdShape(idx_rank, out_dims.ndim()));
    Tensor<xpu, 1, DType> idx = idx_blob.get_with_shape<xpu, 1, DType>(idx_flat, stream);
    Tensor<xpu, 2, DType> data = arr_blob.get_with_shape<xpu, 2, DType>(arr_flat, stream);
    Tensor<xpu, 2, DType> out = out_blob.get_with_shape<xpu, 2, DType>(out_flat, stream);
    out = take(idx, data);
  });
}

/*!
 * \brief Backward pass of take: accumulates the output gradient rows into the
 *  gradient of the source array at the positions named by the indices.
 *
 *  No gradient is produced for the index input; its request must be kNullOp.
 *
 * \tparam xpu device type
 * \param attrs node attributes (unused)
 * \param ctx operator context supplying the stream and the temp space
 * \param inputs inputs[0] is the gradient w.r.t. the forward output,
 *  inputs[1] is the forward index input (as wired up in the .cc file)
 * \param req write requests for the two gradients (take_::kArr, take_::kIdx)
 * \param outputs outputs[0] is the gradient w.r.t. the source array
 */
template<typename xpu>
void TakeOpBackward(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
                    const std::vector<TBlob>& inputs,
                    const std::vector<OpReqType>& req,
                    const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  CHECK_EQ(inputs.size(), 2);
  CHECK_EQ(outputs.size(), 2);
  CHECK_EQ(req[take_::kIdx], kNullOp)
    << "take layer doesn't support gradient into index";

  const TBlob& ograd_blob = inputs[0];
  const TBlob& idx_blob = inputs[1];
  const TBlob& igrad_blob = outputs[0];

  const TShape& idx_dims = idx_blob.shape_;
  const TShape& arr_dims = igrad_blob.shape_;
  const TShape& ograd_dims = ograd_blob.shape_;
  const int idx_rank = idx_dims.ndim();
  const OpReqType arr_req = req[take_::kArr];

  Stream<xpu> *stream = ctx.get_stream<xpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    Tensor<xpu, 1, DType> idx = idx_blob.get_with_shape<xpu, 1, DType>(
      Shape1(idx_dims.ProdShape(0, idx_rank)), stream);
    // gradient of the forward output, one row per index
    Tensor<xpu, 2, DType> grad_out = ograd_blob.get_with_shape<xpu, 2, DType>(
      Shape2(ograd_dims.ProdShape(0, idx_rank),
             ograd_dims.ProdShape(idx_rank, ograd_dims.ndim())), stream);
    // gradient of the forward source array
    Tensor<xpu, 2, DType> grad_in = igrad_blob.get_with_shape<xpu, 2, DType>(
      Shape2(arr_dims[0], arr_dims.ProdShape(1, arr_dims.ndim())), stream);

    if (arr_req != kWriteTo && arr_req != kAddTo) {
      LOG(FATAL) << "wrong req";
    } else {
      // kWriteTo starts from zero; kAddTo accumulates onto the existing values
      if (arr_req == kWriteTo) {
        grad_in = scalar<DType>(0.0f);
      }
      const bool few_rows = (grad_out.shape_[0] < grad_out.shape_[1]) &&
                            (grad_out.shape_[0] < 512);
      if (few_rows) {
        AddTakeGrad(grad_in, idx, grad_out);
      } else {
        // Sort-based path: row 0 of the workspace holds the indices cast to
        // int and sorted, row 1 holds their original positions.
        const size_t num_idx = idx.shape_.Size();
        Tensor<xpu, 2, int> workspace =
          ctx.requested[take_::kTempSpace].get_space_typed<xpu, 2, int>(
            mshadow::Shape2(2, num_idx), stream);
        Tensor<xpu, 1, int> sorted_idx = workspace[0];
        Tensor<xpu, 1, int> original_idx = workspace[1];
        sorted_idx = tcast<int>(idx);
        original_idx = range<int>(0, num_idx);
        SortByKey(sorted_idx, original_idx, true);
        AddTakeGradLargeBatch(grad_in, sorted_idx, original_idx, grad_out);
      }
    }
  });
}

} // namespace op
} // namespace mxnet
Expand Down
44 changes: 44 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
Expand Up @@ -208,6 +208,49 @@ def test_embedding_with_type():
check_consistency(sym, ctx_list, grad_req={'embedding_data': 'null','embedding_weight': 'write'},
arg_params=arg_params)

def test_take_with_type():
    """Run check_consistency on `take` across GPU/CPU contexts and float dtypes.

    Random data ranks (2-4) and index ranks (1-3) are tried; the gradient is
    requested for the data input only.
    """
    sym = mx.sym.take(name='take')
    float_types = (np.float64, np.float32, np.float16)
    for data_ndim in range(2, 5):
        for idx_ndim in range(1, 4):
            data_shape = tuple(np.random.randint(low=3, high=6)
                               for _ in range(data_ndim))
            idx_shape = tuple(np.random.randint(low=3, high=5)
                              for _ in range(idx_ndim))
            # all GPU entries first, then all CPU entries, each in float_types order
            ctx_list = [{'ctx': make_ctx(0),
                         'take_indices': idx_shape,
                         'take_a': data_shape,
                         'type_dict': {'take_indices': dtype,
                                       'take_a': dtype}}
                        for make_ctx in (mx.gpu, mx.cpu)
                        for dtype in float_types]
            # valid indices address the first axis of the data array
            indices = np.random.randint(low=0, high=data_shape[0],
                                        size=idx_shape)
            values = np.random.normal(size=data_shape)
            arg_params = {'take_indices': indices, 'take_a': values}
            check_consistency(sym, ctx_list,
                              grad_req={'take_indices': 'null',
                                        'take_a': 'write'},
                              arg_params=arg_params)

if __name__ == '__main__':
test_convolution_options()
test_convolution_with_type()
Expand All @@ -224,4 +267,5 @@ def test_embedding_with_type():
test_fullyconnected_with_type()
test_activation_with_type()
test_embedding_with_type()
test_take_with_type()

17 changes: 17 additions & 0 deletions tests/python/unittest/test_ndarray.py
Expand Up @@ -563,6 +563,22 @@ def test_ndarray_lesser_equal():
z = 1 <= y
assert (z.asnumpy() == np.ones((2, 3))).all()

def test_take():
    """Compare mx.nd.take against numpy indexing along the first axis."""
    for data_ndim in range(2, 5):
        for idx_ndim in range(1, 4):
            data_shape = tuple(np.random.randint(low=3, high=6)
                               for _ in range(data_ndim))
            data_real = np.random.normal(size=data_shape).astype('float32')
            idx_shape = tuple(np.random.randint(low=3, high=5)
                              for _ in range(idx_ndim))
            idx_real = np.random.randint(low=0, high=data_shape[0],
                                         size=idx_shape)
            expected = np.take(data_real, idx_real, axis=0)
            result = mx.nd.take(mx.nd.array(data_real), mx.nd.array(idx_real))
            assert reldiff(result.asnumpy(), expected) < 1e-6

if __name__ == '__main__':
test_broadcast_binary()
test_ndarray_setitem()
Expand All @@ -586,3 +602,4 @@ def test_ndarray_lesser_equal():
test_arange()
test_order()
test_ndarray_equal()
test_take()
36 changes: 36 additions & 0 deletions tests/python/unittest/test_operator.py
Expand Up @@ -2268,6 +2268,41 @@ def test_blockgrad():
assert_almost_equal(exe.outputs[0].asnumpy(), a_npy)
exe.backward() # No error if BlockGrad works

def test_take():
    """Check the forward output and the data gradient of the symbolic take op."""
    data = mx.sym.Variable('a')
    idx = mx.sym.BlockGrad(mx.sym.Variable('indices'))
    result = mx.sym.take(a=data, indices=idx)

    def check_output_n_grad(data_shape, idx_shape):
        exe = result.simple_bind(default_context(), a=data_shape,
                                 indices=idx_shape)
        data_real = np.random.normal(size=data_shape).astype('float32')
        idx_real = np.random.randint(low=0, high=data_shape[0], size=idx_shape)
        grad_out = np.ones(idx_shape + data_shape[1:], dtype='float32')
        grad_in = np.zeros(data_shape, dtype='float32')

        exe.arg_dict['a'][:] = mx.nd.array(data_real)
        exe.arg_dict['indices'][:] = mx.nd.array(idx_real)
        exe.forward()
        assert reldiff(exe.outputs[0].asnumpy(), data_real[idx_real]) < 1e-6

        # With an all-ones head gradient, each row of the expected gradient
        # counts how many times that row was selected.
        np.add.at(grad_in, idx_real, 1.0)

        exe.backward([mx.nd.array(grad_out)])
        assert reldiff(exe.grad_dict['a'].asnumpy(), grad_in) < 1e-6

    for data_ndim in range(2, 5):
        for idx_ndim in range(1, 4):
            data_shape = tuple(np.random.randint(low=3, high=6)
                               for _ in range(data_ndim))
            idx_shape = tuple(np.random.randint(low=3, high=5)
                              for _ in range(idx_ndim))
            check_output_n_grad(data_shape, idx_shape)


if __name__ == '__main__':
test_init()
Expand Down Expand Up @@ -2319,3 +2354,4 @@ def test_blockgrad():
test_special_functions_using_scipy()
test_order()
test_blockgrad()
test_take()

0 comments on commit 8abf132

Please sign in to comment.