Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent fae57f9 commit 001982c
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Array<PrimExpr>& isha
} else {
auto idim = ishape[axes[i]];
auto b_expr = make_const(dtype, begin[i]);
PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
auto s = strides[i];
if (s < 0) {
b = tvm::min(b, idim - 1);
Expand All @@ -700,12 +700,13 @@ inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Array<PrimExpr>& isha
return begin_expr;
}

inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape, const std::vector<int64_t>& begin,
inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape,
const std::vector<int64_t>& begin,
const std::vector<int64_t>& end,
const std::vector<int64_t>& strides,
const Array<Integer>& axes, std::string slice_mode,
const Array<PrimExpr>& begin_canonicalized,
bool use_any=false) {
bool use_any = false) {
size_t src_tensor_dim = ishape.size();
Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
Expand Down Expand Up @@ -734,21 +735,18 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape, co
return out_shape;
}

inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape, const Array<Integer>& begin,
const Array<Integer>& end,
const Array<Integer>& strides,
const Array<Integer>& axes,
const std::string& slice_mode) {
inline Array<PrimExpr> StridedSliceOutputShape(
const Array<PrimExpr>& ishape, const Array<Integer>& begin, const Array<Integer>& end,
const Array<Integer>& strides, const Array<Integer>& axes, const std::string& slice_mode) {
ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
std::vector<int64_t> begin_vec, end_vec, strides_vec;
std::tie(begin_vec, end_vec, strides_vec) = ToVec(begin, end, strides, slice_mode);
auto begin_canonicalized =
StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode);
auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
begin[0]->dtype, slice_mode);
return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
begin_canonicalized, true);
}


/*!
* \brief strided_slice of a tensor
*
Expand All @@ -775,10 +773,10 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& beg
std::vector<int64_t> begin_vec, end_vec, strides_vec;
std::tie(begin_vec, end_vec, strides_vec) = ToVec(begin, end, strides, slice_mode);

auto begin_expr =
StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode);
auto out_shape =
StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes, slice_mode, begin_expr);
auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes,
begin[0]->dtype, slice_mode);
auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes,
slice_mode, begin_expr);

return te::compute(
out_shape,
Expand Down

0 comments on commit 001982c

Please sign in to comment.