Permalink
Browse files

Keep the original function name and add broadcast_keepdim & reduce_ke…

…epdim
  • Loading branch information...
1 parent a90696e commit 997d042e6da8ebf1c21f640a956d71eda9702fe3 @sxjscience sxjscience committed May 28, 2016
Showing with 47 additions and 15 deletions.
  1. +21 −6 mshadow/extension/broadcast_with_axis.h
  2. +22 −5 mshadow/extension/reduce_with_axis.h
  3. +4 −4 test/test_tblob.cc
@@ -34,8 +34,9 @@ struct BroadcastWithAxisExp:
/*! \brief size of the last dimension of src*/
index_t last_;
/*! constructor */
- BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size, int keepdim)
+ BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size)
: src_(src), size_(size) {
+ bool keepdim = (dimsrc == dimdst);
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_);
this->trailing_ = 1;
@@ -71,19 +72,33 @@ struct BroadcastWithAxisExp:
}; // struct BroadcastWithAxisExp
/*!
- * \brief Broadcasting the tensor in the given axis. If keepdim is off, insert the broadcasting dim after axis. Otherwise broadcasting axis.
- * \param keepdim whether to keepdim
+ * \brief Broadcasting the tensor after given axis.
* \param SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
*/
-template<int keepdim, typename SrcExp, typename DType, int etype>
+template<typename SrcExp, typename DType, int etype>
inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
- ExpInfo<SrcExp>::kDim + 1 - keepdim>
+ ExpInfo<SrcExp>::kDim + 1>
broadcast_with_axis(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) {
return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
- ExpInfo<SrcExp>::kDim + 1 - keepdim>(src.self(), axis, size, keepdim);
+ ExpInfo<SrcExp>::kDim + 1>(src.self(), axis, size);
}
+
+/*!
+* \brief Broadcasting the tensor in the given axis (keepdim turned on)
+* \param SrcExp source expression
+* \tparam DType data type
+* \tparam etype type of the expression
+*/
+template<typename SrcExp, typename DType, int etype>
+inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
+ ExpInfo<SrcExp>::kDim>
+ broadcast_keepdim(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) {
+ return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
+ ExpInfo<SrcExp>::kDim>(src.self(), axis, size);
+}
+
//----------------------
// Execution plan
//----------------------
@@ -32,8 +32,9 @@ struct ReduceWithAxisExp:
/*! \brief size of last src dimension */
index_t last_;
/*! constructor */
- explicit ReduceWithAxisExp(const SrcExp &src, int axis, int keepdim)
+ explicit ReduceWithAxisExp(const SrcExp &src, int axis)
: src_(src) {
+ bool keepdim = (dimsrc == dimdst);
CHECK(dimsrc > axis) << "reduce axis out of bound";
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src_);
for (index_t i = 0; i < axis; ++i) {
@@ -63,18 +64,34 @@ struct ReduceWithAxisExp:
* \brief reduce out the dimension of src labeled by axis.
* \param Reducer type of the reducing operation
* \param mask whether to output the unmask indices
- * \param keepdim the keepdim flag
* \tparam SrcExp source expression
* \tparam DType data type
* \tparam etype type of the expression
*/
-template<typename Reducer, bool mask, int keepdim, typename SrcExp, typename DType, int etype>
+template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
- ExpInfo<SrcExp>::kDim + keepdim - 1>
+ ExpInfo<SrcExp>::kDim - 1>
reduce_with_axis(const Exp<SrcExp, DType, etype> &src, int axis) {
return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
- ExpInfo<SrcExp>::kDim + keepdim - 1>(src.self(), axis, keepdim);
+ ExpInfo<SrcExp>::kDim- 1>(src.self(), axis);
}
+
+/*!
+* \brief reduce out the dimension of src labeled by axis, keepdim turned on.
+* \param Reducer type of the reducing operation
+* \param mask whether to output the unmask indices
+* \tparam SrcExp source expression
+* \tparam DType data type
+* \tparam etype type of the expression
+*/
+template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
+inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
+ ExpInfo<SrcExp>::kDim>
+ reduce_keepdim(const Exp<SrcExp, DType, etype> &src, int axis) {
+ return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
+ ExpInfo<SrcExp>::kDim>(src.self(), axis);
+}
+
//----------------------
// Execution plan
//----------------------
View
@@ -80,7 +80,7 @@ void test_broadcast_with_axis() {
input_tensor = 11;
mshadow::Tensor<mshadow::cpu, 4> n_tensor(NULL, test_shapes[dim + 1]);
mshadow::AllocSpace(&n_tensor);
- n_tensor = broadcast_with_axis<0>(input_tensor, dim, 5);
+ n_tensor = broadcast_with_axis(input_tensor, dim, 5);
printf("Test for keepdim = 0, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
@@ -100,7 +100,7 @@ void test_broadcast_with_axis() {
input_tensor = 11;
mshadow::Tensor<mshadow::cpu, 4> n_tensor(NULL, test_shapes[dim]);
mshadow::AllocSpace(&n_tensor);
- n_tensor = broadcast_with_axis<1>(input_tensor, dim, 5);
+ n_tensor = broadcast_keepdim(input_tensor, dim, 5);
printf("Test for keepdim = 1, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
@@ -136,7 +136,7 @@ void test_reduce_with_axis() {
input_tensor = 1;
mshadow::Tensor<mshadow::cpu, 3> n_tensor(NULL, mshadow::Shape3(2, 3, 4));
mshadow::AllocSpace(&n_tensor);
- n_tensor = reduce_with_axis<mshadow::red::sum, false, 0>(input_tensor, dim);
+ n_tensor = reduce_with_axis<mshadow::red::sum, false>(input_tensor, dim);
printf("Test for keepdim = 0, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {
@@ -154,7 +154,7 @@ void test_reduce_with_axis() {
input_tensor = 1;
mshadow::Tensor<mshadow::cpu, 4> n_tensor(NULL, keepdim_output_shapes[dim]);
mshadow::AllocSpace(&n_tensor);
- n_tensor = reduce_with_axis<mshadow::red::sum, false, 1>(input_tensor, dim);
+ n_tensor = reduce_keepdim<mshadow::red::sum, false>(input_tensor, dim);
printf("Test for keepdim = 1, dim = %d", dim);
for (index_t i = 0; i < n_tensor.shape_[0]; i++) {
for (index_t j = 0; j < n_tensor.shape_[1]; j++) {

0 comments on commit 997d042

Please sign in to comment.