Skip to content
This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Commit

Permalink
boardcast mask
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Aug 12, 2016
1 parent 787ee96 commit 69f0970
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
16 changes: 16 additions & 0 deletions guide/basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,22 @@ int main(void) {
printf("\n");
}

printf("mask\n");
TensorContainer<cpu, 2> mask_data(Shape2(6, 8));
TensorContainer<cpu, 2> mask_out(Shape2(6, 8));
TensorContainer<cpu, 1> mask_src(Shape1(6));

mask_data = 1.0f;
for (int i = 0; i < 6; ++i) {
mask_src[i] = static_cast<float>(i);
}
mask_out = mask(mask_src, mask_data);
for (index_t i = 0; i < mask_out.size(0); ++i) {
for (index_t j = 0; j < mask_out.size(1); ++j) {
printf("%.2f ", mask_out[i][j]);
}
printf("\n");
}
ShutdownTensorEngine<cpu>();
return 0;
}
1 change: 1 addition & 0 deletions mshadow/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@
#include "./extension/flip.h"
#include "./extension/complex.h"
#include "./extension/range.h"
#include "./extension/mask.h"
#endif // MSHADOW_EXTENSION_H_
97 changes: 97 additions & 0 deletions mshadow/extension/mask.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*!
* Copyright (c) 2016 by Contributors
* \file mask.h
* \brief
* \author Bing Xu
*/
#ifndef MSHADOW_EXTENSION_MASK_H_
#define MSHADOW_EXTENSION_MASK_H_

#include "../extension.h"

namespace mshadow {
namespace expr {

/*! \brief Broadcast a mask and do element-wise multiplication
* \tparam IndexExp type of index expression
* \tparam SrcExp type of src expression
* \tparam DType data type
*/
template<typename IndexExp, typename SrcExp, typename DType>
struct MaskExp: public Exp<MaskExp<IndexExp, SrcExp, DType>,
DType, type::kChainer> {
/*! \brief index oprand */
const IndexExp &index_;
/*! \brief matrix oprand */
const SrcExp &src_;
/*! constructor */
MaskExp(const IndexExp &index, const SrcExp &src)
: index_(index), src_(src) {}
}; // struct MaskExp



template<typename IndexExp,
typename SrcExp,
typename DType,
int e1, int e2>
inline MaskExp<IndexExp, SrcExp, DType>
mask(const Exp<IndexExp, DType, e1> &index,
const Exp<SrcExp, DType, e2> &src) {
return MaskExp<IndexExp, SrcExp, DType>(index.self(), src.self());
}


//----------------------
// Execution plan
//----------------------

template<typename IndexExp, typename SrcExp, typename DType>
struct Plan<MaskExp<IndexExp, SrcExp, DType>, DType> {
public:
explicit Plan(const MaskExp<IndexExp, SrcExp, DType> &e)
: index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) {
}

MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
return static_cast<DType>(src_.Eval(y, x) * index_.Eval(0, y));
}

private:
expr::Plan<IndexExp, DType> index_;
expr::Plan<SrcExp, DType> src_;
}; // struct Plan

template<typename IndexExp, typename SrcExp, typename DType>
inline Plan<MaskExp<IndexExp, SrcExp, DType>, DType>
MakePlan(const MaskExp<IndexExp, SrcExp, DType> &exp) {
return Plan<MaskExp<IndexExp, SrcExp, DType>, DType>(exp);
}

template<int dim, typename IndexExp, typename SrcExp, typename DType>
struct ShapeCheck<dim, MaskExp<IndexExp, SrcExp, DType> > {
inline static Shape<dim>
Check(const MaskExp<IndexExp, SrcExp, DType> &t) {
CHECK(dim == 2)
<< "MaskExp only support 2D output";
Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_);
Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_);
CHECK_EQ(dshape[0], wshape[0]) << "MaskExp require inputs in same first dimention";
Shape<dim> ret;
ret[0] = wshape[0];
ret[1] = wshape[1];
return ret;
}
};


template<typename IndexExp, typename SrcExp, typename DType>
struct ExpInfo<MaskExp<IndexExp, SrcExp, DType> > {
static const int kDim = 2;
static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
};

} // namespace expr
} // namespace mshadow

#endif // MSHADOW_EXTENSION_MASK_H_

0 comments on commit 69f0970

Please sign in to comment.