This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Add transpose_indices
Fix SortByKey on CPU

Use float for start/stop/step

Add index_fill

Remove debugging output

remove debugging info

fix error

fix doc + lint

fix lint

update comment
sxjscience committed Nov 23, 2016
1 parent 73f8776 commit 1877b90
Showing 6 changed files with 248 additions and 35 deletions.
46 changes: 46 additions & 0 deletions mshadow/cuda/tensor_gpu-inl.cuh
@@ -552,6 +552,9 @@ template<typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(index.CheckContiguous(), true);
CHECK_EQ(src.CheckContiguous(), true);
const int kUnitBits = kMemUnitBits + 1;
dim3 dimBlock(1 << kUnitBits);
dim3 dimGrid((dst.size(1) + (1 << kUnitBits) - 1) >> kUnitBits);
@@ -575,6 +578,10 @@ inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(sorted.CheckContiguous(), true);
CHECK_EQ(index.CheckContiguous(), true);
CHECK_EQ(src.CheckContiguous(), true);
const int kWarpBits = kMemUnitBits;
const int SZ = 4;
const int block_dim_x = 1 << kWarpBits;
Expand All @@ -599,6 +606,45 @@ inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
static_cast<int>(src.size(1)));
}

template<int warp_bits, typename DType, typename DstPlan, typename IndexPlan, typename SrcPlan>
__global__ void IndexFillKernel(DstPlan dst,
IndexPlan index, SrcPlan src,
index_t ymax, int xmax) {
int src_idx = blockIdx.x * blockDim.y + threadIdx.y;
if (src_idx < ymax) {
int dst_idx = static_cast<int>(index.Eval(0, src_idx));
for (int i = threadIdx.x; i < xmax; i += blockDim.x) {
dst.REval(dst_idx, i) = src.Eval(src_idx, i);
}
}
}

template<typename IndexType, typename DType>
inline void IndexFill(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(index.CheckContiguous(), true);
CHECK_EQ(src.CheckContiguous(), true);
CHECK_EQ(dst.size(1), src.size(1)) << "IndexFill: shape mismatch";
CHECK_EQ(index.size(0), src.size(0)) << "IndexFill: shape mismatch";
const int block_dim_x = 1 << kMemUnitBits;
const int block_dim_y = 4;
const int grid_dim_x = (src.size(0) + block_dim_y - 1) / block_dim_y;
dim3 dimBlock(block_dim_x, block_dim_y);
dim3 dimGrid(grid_dim_x);
CheckLaunchParam(dimGrid, dimBlock, "IndexFill");
cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);

IndexFillKernel<kMemUnitBits, DType>
<<<dimGrid, dimBlock, 0, stream>>>
(expr::MakePlan(dst),
expr::MakePlan(index),
expr::MakePlan(src),
src.size(0),
src.size(1));
}

template<typename KDType, typename VDType>
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
bool is_ascend) {
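For reference, the contract of the new IndexFill is dst[index[i]] = src[i], applied row by row. Below is a minimal CPU sketch of that contract, written for this page rather than taken from the commit; the function name is hypothetical, and it assumes contiguous, row-major 2-D tensors with indices in [0, dst.size(0)).

// A hedged reference for IndexFill's contract: dst[index[i]] = src[i], row by row.
// Illustration only; the library's own CPU overload is declared in tensor.h below.
#include "mshadow/tensor.h"

template<typename IndexType, typename DType>
inline void IndexFillReference(mshadow::Tensor<mshadow::cpu, 2, DType> dst,
                               const mshadow::Tensor<mshadow::cpu, 1, IndexType>& index,
                               const mshadow::Tensor<mshadow::cpu, 2, DType>& src) {
  for (mshadow::index_t i = 0; i < src.size(0); ++i) {
    const mshadow::index_t row = static_cast<mshadow::index_t>(index[i]);  // destination row
    for (mshadow::index_t j = 0; j < src.size(1); ++j) {
      dst[row][j] = src[i][j];  // plain overwrite -- no accumulation, unlike AddTakeGrad
    }
  }
}

If index repeats a row, concurrent thread blocks in the GPU kernel above race on that row and the result is non-deterministic; the sequential sketch simply keeps the last write.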
23 changes: 12 additions & 11 deletions mshadow/extension/range.h
@@ -23,18 +23,18 @@ namespace expr {
template<typename DType>
struct RangeExp:
public Exp<RangeExp<DType>, DType, type::kMapper> {
- const int start_;
- const int stop_;
- const int step_;
+ const float start_;
+ const float stop_;
+ const float step_;
const int repeat_;
/*! \brief constructor */
- RangeExp(int start, int stop, int step, int repeat)
+ RangeExp(float start, float stop, float step, int repeat)
: start_(start), stop_(stop), step_(step), repeat_(repeat) {}
};

template<typename DType>
inline RangeExp<DType>
- range(int start, int stop, int step = 1, int repeat = 1) {
+ range(float start, float stop, float step = 1, int repeat = 1) {
return RangeExp<DType>(start, stop, step, repeat);
}

@@ -51,13 +51,14 @@ struct Plan<RangeExp<DType>, DType> {
repeat_(e.repeat_) {
}
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
- return static_cast<DType>(start_ + (static_cast<int>(x) / repeat_) * step_);
+ return static_cast<DType>(start_ +
+ static_cast<float>((static_cast<int>(x) / repeat_)) * step_);
}

private:
- const int start_;
- const int stop_;
- const int step_;
+ const float start_;
+ const float stop_;
+ const float step_;
const int repeat_;
};

@@ -80,11 +81,11 @@ struct ShapeCheck<dim, RangeExp<DType> > {
if (t.step_ > 0) {
CHECK(t.start_ < t.stop_) << "RangeExp does not support (start, stop, step) = "
<< "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
- return Shape1(t.repeat_ * ((t.stop_ - 1 - t.start_) / t.step_ + 1));
+ return Shape1(t.repeat_ * ceil((t.stop_ - t.start_) / t.step_));
} else {
CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step) = "
<< "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
- return Shape1(t.repeat_ * ((t.start_ - 1 - t.stop_) / (- t.step_) + 1));
+ return Shape1(t.repeat_ * ceil((t.stop_ - t.start_) / t.step_));
}
}
};
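A quick worked check of the float-based range and the ceil-based shape rule, with values chosen for illustration rather than taken from the commit:

#include <cmath>
#include <cstdio>

// Worked check of the new ShapeCheck and Plan::Eval rules above.
int main() {
  const float start = 0.0f, stop = 1.0f, step = 0.3f;
  const int repeat = 1;
  // ShapeCheck: repeat_ * ceil((stop_ - start_) / step_)  ->  1 * ceil(3.33...) = 4
  const int n = repeat * static_cast<int>(std::ceil((stop - start) / step));
  for (int x = 0; x < n; ++x) {
    // Plan::Eval: start_ + float(x / repeat_) * step_
    std::printf("%g ", start + static_cast<float>(x / repeat) * step);  // 0 0.3 0.6 0.9
  }
  std::printf("\n");
  return 0;
}

The previous integer-only formula, (stop - 1 - start) / step + 1, could not represent a fractional step at all.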
126 changes: 119 additions & 7 deletions mshadow/extension/transpose.h
@@ -28,17 +28,17 @@ struct TransposeExExp:
/*! \brief source expression */
const SrcExp &src_;
const Shape<dimsrc> axes_;
- Shape<dimsrc> dst_stride_;
+ Shape<dimsrc> dst_in_src_stride_; // Holds the corresponding stride of the dst axes in src
index_t src_stride_;
/*! \brief constructor */
explicit TransposeExExp(const SrcExp &src, Shape<dimsrc> axes) : src_(src), axes_(axes) {
Shape<dimsrc> src_shape = ShapeCheck<dimsrc, SrcExp>::Check(src);
- src_stride_ = src_shape[dimsrc-1];
+ src_stride_ = src_shape[dimsrc - 1];
Shape<dimsrc> src_stride;
src_stride[dimsrc-1] = 1;
for (int i = dimsrc-2; i >= 0; --i) src_stride[i] = src_shape[i+1]*src_stride[i+1];
for (int i = 0; i < dimsrc; ++i) {
- dst_stride_[i] = src_stride[axes[i]];
+ dst_in_src_stride_[i] = src_stride[axes[i]];
this->shape_[i] = src_shape[axes[i]];
}
}
@@ -65,13 +65,13 @@ struct Plan<TransposeExExp<SrcExp, DType, dimsrc>, DType> {
explicit Plan(const TransposeExExp<SrcExp, DType, dimsrc> &e)
: src_(MakePlan(e.src_)),
src_stride_(e.src_stride_),
- dst_stride_(e.dst_stride_),
+ dst_in_src_stride_(e.dst_in_src_stride_),
dst_shape_(e.shape_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
- index_t idx = j*dst_stride_[dimsrc-1];
+ index_t idx = j * dst_in_src_stride_[dimsrc - 1];
#pragma unroll
for (int k = dimsrc-2; k >= 0; --k) {
- idx += (i%dst_shape_[k])*dst_stride_[k];
+ idx += (i % dst_shape_[k]) * dst_in_src_stride_[k];
i /= dst_shape_[k];
}
return src_.Eval(idx/src_stride_, idx%src_stride_);
@@ -80,9 +80,121 @@ struct Plan<TransposeExExp<SrcExp, DType, dimsrc>, DType> {
private:
Plan<SrcExp, DType> src_;
const index_t src_stride_;
- const Shape<dimsrc> dst_stride_, dst_shape_;
+ const Shape<dimsrc> dst_in_src_stride_, dst_shape_;
};

/*!
* \brief map flat (contiguous) indices of the source tensor to indices of the transposed tensor.
* input: Tensor<Device, k>: ishape
* output: Tensor<Device, k>: oshape = ishape
*
* \tparam SrcExp type of source expression
* \tparam DType the type of elements
* \tparam dimsrc source dimension
* \tparam etype source type
*/
template<typename SrcExp, typename DType, int dimsrc, int etype>
struct TransposeIndicesExp:
public Exp<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType, etype> {
/*! \brief source expression */
const SrcExp &src_indices_; // Expression of the source indices
const Shape<dimsrc> axes_; // The transpose axes
Shape<dimsrc> src_in_dst_stride_; // Holds the corresponding stride of the source axes in dst
Shape<dimsrc> src_shape_; // The shape of the source tensor
/*! \brief constructor */
explicit TransposeIndicesExp(const SrcExp &src_indices,
Shape<dimsrc> src_shape,
Shape<dimsrc> axes) : src_indices_(src_indices),
axes_(axes), src_shape_(src_shape) {
Shape<dimsrc> dst_shape;
Shape<dimsrc> dst_stride;
bool axes_checking_flag[dimsrc] = { false };
for (int i = 0; i < dimsrc; ++i) {
CHECK_LT(axes[i], dimsrc)
<< "Invalid axes input! All elements of axes must be smaller than " << dimsrc
<< ", found axes=" << axes;
dst_shape[i] = src_shape[axes[i]];
axes_checking_flag[axes[i]] = true;
}
// check that the input axes form a valid permutation
for (int i = 0; i < dimsrc; ++i) {
CHECK_EQ(axes_checking_flag[i], true)
<< "Invalid axes input! axes must be a permutation of 0 to " << dimsrc - 1
<< ", found axes=" << axes;
}
dst_stride[dimsrc - 1] = 1;
for (int i = dimsrc - 2; i >= 0; --i) dst_stride[i] = dst_shape[i+1] * dst_stride[i+1];
for (int i = 0; i < dimsrc; ++i) {
src_in_dst_stride_[axes[i]] = dst_stride[i];
}
}
};

/*!
* \brief map flat indices into the source tensor to flat indices into the transposed tensor
* \param src_indices expression producing flat indices into the source tensor
* \param src_shape shape of the source tensor
* \param axes the transpose axes
* \return an expression of the corresponding flat indices into the transposed tensor
* \tparam SrcExp source expression
* \tparam DType the type of elements
* \tparam dimsrc source dimension
* \tparam etype source expression type
*/
template<typename SrcExp, typename DType, int dimsrc, int etype>
inline TransposeIndicesExp<SrcExp, DType, dimsrc, etype>
transpose_indices(const Exp<SrcExp, DType, etype> &src_indices,
Shape<dimsrc> src_shape,
Shape<dimsrc> axes) {
return TransposeIndicesExp<SrcExp, DType, dimsrc, etype>(src_indices.self(), src_shape, axes);
}

template<typename SrcExp, typename DType, int dimsrc, int etype>
struct Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType> {
public:
explicit Plan(const TransposeIndicesExp<SrcExp, DType, dimsrc, etype> &e)
: src_indices_(MakePlan(e.src_indices_)),
src_in_dst_stride_(e.src_in_dst_stride_),
src_shape_(e.src_shape_) {}
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
index_t src_idx = static_cast<index_t>(src_indices_.Eval(i, j));
index_t dst_idx = 0;
#pragma unroll
for (int k = dimsrc - 1; k >= 0; --k) {
dst_idx += (src_idx % src_shape_[k]) * src_in_dst_stride_[k];
src_idx /= src_shape_[k];
}
return static_cast<DType>(dst_idx);
}

private:
Plan<SrcExp, DType> src_indices_;
const Shape<dimsrc> src_in_dst_stride_, src_shape_;
};

//----------------------
// Execution plan
//----------------------
/*! \brief make expression */
template<typename SrcExp, typename DType, int dimsrc, int etype>
inline Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType>
MakePlan(const TransposeIndicesExp<SrcExp, DType, dimsrc, etype> &e) {
return Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType>(e);
}

template<int dim, typename SrcExp, typename DType, int dimsrc, int etype>
struct ShapeCheck<dim, TransposeIndicesExp<SrcExp, DType, dimsrc, etype> > {
inline static Shape<dim>
Check(const TransposeIndicesExp<SrcExp, DType, dimsrc, etype> &t) {
Shape<dim> s = ShapeCheck<dim, SrcExp>::Check(t.src_indices_);
return s;
}
};

template<typename SrcExp, typename DType, int dimsrc, int etype>
struct ExpInfo<TransposeIndicesExp<SrcExp, DType, dimsrc, etype> > {
static const int kDim = ExpInfo<SrcExp>::kDim;
static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_EXTENSION_TRANSPOSE_H_
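To make the stride bookkeeping concrete, here is a standalone sketch of the mapping transpose_indices performs for one small case; the shapes are chosen for illustration, and the code mirrors Plan::Eval above without using mshadow itself. For src_shape = (2, 3) and axes = (1, 0), the transposed shape is (3, 2) with strides (2, 1), so src_in_dst_stride_[axes[i]] = dst_stride[i] gives src_in_dst_stride = (1, 2):

#include <cstdio>

// Standalone sketch of TransposeIndicesExp's index mapping for a 2x3 tensor
// transposed with axes = (1, 0).
int main() {
  const int dimsrc = 2;
  const int src_shape[dimsrc] = {2, 3};
  const int src_in_dst_stride[dimsrc] = {1, 2};
  for (int src_idx = 0; src_idx < 6; ++src_idx) {
    int idx = src_idx;
    int dst_idx = 0;
    for (int k = dimsrc - 1; k >= 0; --k) {  // peel off one source coordinate per step
      dst_idx += (idx % src_shape[k]) * src_in_dst_stride[k];
      idx /= src_shape[k];
    }
    std::printf("%d -> %d\n", src_idx, dst_idx);  // 0->0, 1->2, 2->4, 3->1, 4->3, 5->5
  }
  return 0;
}

Flat source index 3r + c lands at flat transposed index 2c + r, exactly what the loop computes one coordinate at a time.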
46 changes: 37 additions & 9 deletions mshadow/tensor.h
@@ -766,7 +766,8 @@ inline void SoftmaxGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 2, DType> &src,
const Tensor<gpu, 1, DType> &label);
/*!
- * \brief CPU/GPU: Gradient accumulate of embedding matrix. dst += take_grad(src, index)
+ * \brief CPU/GPU: Gradient accumulate of embedding matrix.
+ dst[index[i]] += src[i]
Called when the featuredim of src is much larger than the batchsize
* \param dst destination
* \param index index to take
@@ -777,7 +778,8 @@ inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src);
/*!
- * \brief CPU/GPU: Gradient accumulate of embedding matrix. dst += take_grad(src, index)
+ * \brief CPU/GPU: Gradient accumulate of embedding matrix.
+ dst[index[i]] += src[i]
Called when the featuredim of src is much larger than the batchsize
* \param dst destination
* \param index index to take
@@ -788,7 +790,8 @@ inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src);
/*!
- * \brief CPU/GPU: Gradient accumulate of embedding matrix. dst += take_grad(src, index)
+ * \brief CPU/GPU: Gradient accumulate of embedding matrix.
+ dst[sorted[i]] += src[index[i]]
Called when the batchsize of src is larger than the featuredim
* \param dst destination
* \param sorted the sorted indices
@@ -801,7 +804,8 @@ inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 1, IndexType>& index,
const Tensor<cpu, 2, DType> &src);
/*!
- * \brief CPU/GPU: Gradient accumulate of embedding matrix. dst += take_grad(src, index)
+ * \brief CPU/GPU: Gradient accumulate of embedding matrix.
+ dst[sorted[i]] += src[index[i]]
Called when the batchsize of src is larger than the featuredim
* \param dst destination
* \param sorted the sorted indices
@@ -814,11 +818,35 @@ inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src);
/*!
- * \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!)
- * \param keys the keys to sort
- * \param values the values that sorts w.r.t the key
- * \param is_ascend whether to sort key in ascending order
- */
+ * \brief CPU/GPU: Copy rows of the source matrix into the rows of the destination matrix selected by index.
+ dst[index[i]] = src[i]
+ If index contains duplicate row values, the result is non-deterministic.
+ * \param dst destination
+ * \param index the rows of dst to fill
+ * \param src source
+ */
+ template<typename IndexType, typename DType>
+ inline void IndexFill(Tensor<cpu, 2, DType> dst,
+ const Tensor<cpu, 1, IndexType>& index,
+ const Tensor<cpu, 2, DType> &src);
+ /*!
+ * \brief CPU/GPU: Copy rows of the source matrix into the rows of the destination matrix selected by index.
+ dst[index[i]] = src[i]
+ If index contains duplicate row values, the result is non-deterministic.
+ * \param dst destination
+ * \param index the rows of dst to fill
+ * \param src source
+ */
+ template<typename IndexType, typename DType>
+ inline void IndexFill(Tensor<gpu, 2, DType> dst,
+ const Tensor<gpu, 1, IndexType>& index,
+ const Tensor<gpu, 2, DType> &src);
+ /*!
+ * \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!)
+ * \param keys the keys to sort
+ * \param values the values sorted w.r.t. the key
+ * \param is_ascend whether to sort key in ascending order
+ */
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values,
bool is_ascend = true);
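The rewritten doc comments pin each primitive to a formula: AddTakeGrad computes dst[index[i]] += src[i], AddTakeGradLargeBatch computes dst[sorted[i]] += src[index[i]] on indices pre-sorted with SortByKey, and IndexFill computes dst[index[i]] = src[i]. For reference, a minimal loop for the first contract; this is a sketch written for this page with a hypothetical name, not the library's implementation:

// A hedged reference for AddTakeGrad's contract: dst[index[i]] += src[i].
#include "mshadow/tensor.h"

template<typename IndexType, typename DType>
inline void AddTakeGradReference(mshadow::Tensor<mshadow::cpu, 2, DType> dst,
                                 const mshadow::Tensor<mshadow::cpu, 1, IndexType>& index,
                                 const mshadow::Tensor<mshadow::cpu, 2, DType>& src) {
  for (mshadow::index_t i = 0; i < src.size(0); ++i) {
    const mshadow::index_t row = static_cast<mshadow::index_t>(index[i]);  // destination row
    for (mshadow::index_t j = 0; j < src.size(1); ++j) {
      dst[row][j] += src[i][j];  // accumulate, unlike IndexFill's overwrite
    }
  }
}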
