Permalink
Browse files

boardcast mask

  • Loading branch information...
1 parent 787ee96 commit 69f09705f2ae24270514d506040dd59963a919b1 @antinucleon antinucleon committed Aug 12, 2016
Showing with 114 additions and 0 deletions.
  1. +16 −0 guide/basic.cpp
  2. +1 −0 mshadow/extension.h
  3. +97 −0 mshadow/extension/mask.h
View
@@ -140,6 +140,22 @@ int main(void) {
printf("\n");
}
+ printf("mask\n");
+ TensorContainer<cpu, 2> mask_data(Shape2(6, 8));
+ TensorContainer<cpu, 2> mask_out(Shape2(6, 8));
+ TensorContainer<cpu, 1> mask_src(Shape1(6));
+
+ mask_data = 1.0f;
+ for (int i = 0; i < 6; ++i) {
+ mask_src[i] = static_cast<float>(i);
+ }
+ mask_out = mask(mask_src, mask_data);
+ for (index_t i = 0; i < mask_out.size(0); ++i) {
+ for (index_t j = 0; j < mask_out.size(1); ++j) {
+ printf("%.2f ", mask_out[i][j]);
+ }
+ printf("\n");
+ }
ShutdownTensorEngine<cpu>();
return 0;
}
View
@@ -37,4 +37,5 @@
#include "./extension/flip.h"
#include "./extension/complex.h"
#include "./extension/range.h"
+#include "./extension/mask.h"
#endif // MSHADOW_EXTENSION_H_
@@ -0,0 +1,97 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file mask.h
+ * \brief
+ * \author Bing Xu
+*/
+#ifndef MSHADOW_EXTENSION_MASK_H_
+#define MSHADOW_EXTENSION_MASK_H_
+
+#include "../extension.h"
+
+namespace mshadow {
+namespace expr {
+
+/*! \brief Broadcast a mask and do element-wise multiplication
+ * \tparam IndexExp type of index expression
+ * \tparam SrcExp type of src expression
+ * \tparam DType data type
+ */
+template<typename IndexExp, typename SrcExp, typename DType>
+struct MaskExp: public Exp<MaskExp<IndexExp, SrcExp, DType>,
+ DType, type::kChainer> {
+ /*! \brief index oprand */
+ const IndexExp &index_;
+ /*! \brief matrix oprand */
+ const SrcExp &src_;
+ /*! constructor */
+ MaskExp(const IndexExp &index, const SrcExp &src)
+ : index_(index), src_(src) {}
+}; // struct MaskExp
+
+
+
+template<typename IndexExp,
+ typename SrcExp,
+ typename DType,
+ int e1, int e2>
+inline MaskExp<IndexExp, SrcExp, DType>
+mask(const Exp<IndexExp, DType, e1> &index,
+ const Exp<SrcExp, DType, e2> &src) {
+ return MaskExp<IndexExp, SrcExp, DType>(index.self(), src.self());
+}
+
+
+//----------------------
+// Execution plan
+//----------------------
+
+template<typename IndexExp, typename SrcExp, typename DType>
+struct Plan<MaskExp<IndexExp, SrcExp, DType>, DType> {
+ public:
+ explicit Plan(const MaskExp<IndexExp, SrcExp, DType> &e)
+ : index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) {
+ }
+
+ MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
+ return static_cast<DType>(src_.Eval(y, x) * index_.Eval(0, y));
+ }
+
+ private:
+ expr::Plan<IndexExp, DType> index_;
+ expr::Plan<SrcExp, DType> src_;
+}; // struct Plan
+
+template<typename IndexExp, typename SrcExp, typename DType>
+inline Plan<MaskExp<IndexExp, SrcExp, DType>, DType>
+MakePlan(const MaskExp<IndexExp, SrcExp, DType> &exp) {
+ return Plan<MaskExp<IndexExp, SrcExp, DType>, DType>(exp);
+}
+
+template<int dim, typename IndexExp, typename SrcExp, typename DType>
+struct ShapeCheck<dim, MaskExp<IndexExp, SrcExp, DType> > {
+ inline static Shape<dim>
+ Check(const MaskExp<IndexExp, SrcExp, DType> &t) {
+ CHECK(dim == 2)
+ << "MaskExp only support 2D output";
+ Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_);
+ Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_);
+ CHECK_EQ(dshape[0], wshape[0]) << "MaskExp require inputs in same first dimention";
+ Shape<dim> ret;
+ ret[0] = wshape[0];
+ ret[1] = wshape[1];
+ return ret;
+ }
+};
+
+
+template<typename IndexExp, typename SrcExp, typename DType>
+struct ExpInfo<MaskExp<IndexExp, SrcExp, DType> > {
+ static const int kDim = 2;
+ static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
+};
+
+} // namespace expr
+} // namespace mshadow
+
+#endif // MSHADOW_EXTENSION_MASK_H_

0 comments on commit 69f0970

Please sign in to comment.