Skip to content

Commit

Permalink
remove dynamic input specific op
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 510bce6 commit 7db4cea
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 80 deletions.
118 changes: 46 additions & 72 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,26 +550,6 @@ inline Array<Tensor> split(const Tensor& x, Array<PrimExpr> split_indices, int a
return result;
}

inline te::Tensor strided_slice_compute_common(const te::Tensor& x,
const Array<PrimExpr>& out_shape,
const Array<PrimExpr>& begin,
const Array<PrimExpr>& strides,
const Array<Integer>& axes, const std::string& name,
const std::string& tag) {
return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
for (size_t i = 0; i < axes.size(); ++i) {
PrimExpr ind = indices[axes[i]] * strides[i] + begin[i];
real_indices.Set(axes[i], ind);
}
return x(real_indices);
},
name, tag);
}

inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
std::string name = "T_dynamic_strided_slice",
Expand Down Expand Up @@ -645,34 +625,6 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
}

inline Tensor strided_slice_dynamic_input(const Tensor& x, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
std::string slice_mode = "end",
std::string name = "T_strided_slice_dynamic_input",
std::string tag = kInjective) {
size_t src_tensor_dim = x->shape.size();
ICHECK(begin.size() == src_tensor_dim)
<< "for dynamic inputs, len(begin) must equal the input dimension";
Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(tvm::tir::Var("dim"));
}
Array<PrimExpr> begin_expr, end_expr, strides_expr;
Array<Integer> axes;
for (size_t i = 0; i < src_tensor_dim; ++i) {
int64_t begin_i = begin[i]->value;
if (begin_i < 0) {
begin_i += topi::detail::GetConstInt(x->shape[i]);
}
begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i));
strides_expr.push_back(
tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()),
(i < strides.size() ? strides[i]->value : 1)));
axes.push_back(i);
}
return strided_slice_compute_common(x, out_shape, begin_expr, strides_expr, axes, name, tag);
}

inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
const Array<Integer>& axes, std::string slice_mode = "end",
Expand Down Expand Up @@ -729,34 +681,56 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& beg
Array<PrimExpr> begin_expr, strides_expr;
for (size_t i = 0; i < axes.size(); ++i) {
int64_t begin_range = stride_vec[i] < 0 ? -1 : 0;
ICHECK(x->shape[axes[i]]->IsInstance<tvm::IntImmNode>())
<< "Input shape at axis " << axes[i] << " is not static";
int64_t dim_i = GetConstInt(x->shape[axes[i]]);
int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i;
// transform negative indices to positive value, clips on the correct range
auto index_canonicalization = [dim_i, begin_range, end_range](int64_t index) {
if (index < 0) {
index += dim_i;
}
return std::min(std::max(index, begin_range), end_range);
};
if (x->shape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
int64_t dim_i = GetConstInt(x->shape[axes[i]]);
int64_t end_range = stride_vec[i] < 0 ? dim_i - 1 : dim_i;
// transform negative indices to positive value, clips on the correct range
auto index_canonicalization = [dim_i, begin_range, end_range](int64_t index) {
if (index < 0) {
index += dim_i;
}
return std::min(std::max(index, begin_range), end_range);
};

int64_t begin_i = index_canonicalization(begin_vec[i]);
int64_t end_i = index_canonicalization(end_vec[i]);
int64_t begin_i = index_canonicalization(begin_vec[i]);
int64_t end_i = index_canonicalization(end_vec[i]);

int interval = std::abs(end_i - begin_i);
int slice_size =
static_cast<int>((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
ICHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
<< "] is invalid for axis=" << i;
int interval = std::abs(end_i - begin_i);
int slice_size =
static_cast<int>((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
ICHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
<< "] is invalid for axis=" << i;

begin_expr.push_back(make_const(begin[0].dtype(), begin_i));
strides_expr.push_back(
make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i]));
out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size)));
begin_expr.push_back(make_const(begin[0].dtype(), begin_i));
out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size)));
} else {
auto idim = x->shape[axes[i]];
auto b = tvm::if_then_else(begin[i] < 0, begin[i] + idim, begin[i]);
auto s = strides[i]->value;
if (s < 0) {
b = tvm::min(b, idim - 1);
} else {
b = tvm::if_then_else(b < 0, 0, b);
}
out_shape.Set(axes[i], tvm::tir::Var("dim", out_shape[i]->dtype));
begin_expr.push_back(b);
}
strides_expr.push_back(make_const(strides[i].dtype(), stride_vec[i]));
}
return strided_slice_compute_common(x, out_shape, begin_expr, strides_expr, axes, name, tag);

return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
for (size_t i = 0; i < axes.size(); ++i) {
PrimExpr ind = indices[axes[i]] * strides_expr[i] + begin_expr[i];
real_indices.Set(axes[i], ind);
}
return x(real_indices);
},
name, tag);
}

/*!
Expand Down
3 changes: 0 additions & 3 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2739,9 +2739,6 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
auto axes = param->axes.value();
return Array<te::Tensor>{
topi::strided_slice_with_axes(inputs[0], begin, end, strides, axes, param->slice_mode)};
} else if (IsDynamic(out_type)) {
return Array<te::Tensor>{
topi::strided_slice_dynamic_input(inputs[0], begin, end, strides, param->slice_mode)};
}
return Array<te::Tensor>{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)};
}
Expand Down
6 changes: 1 addition & 5 deletions src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,7 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
Array<Integer> begin_static = args[1];
Array<Integer> end_static = args[2];
Array<Integer> strides_static = args[3];
if (IsConstIntArray(x->shape)) {
*rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode);
} else {
*rv = strided_slice_dynamic_input(x, begin_static, end_static, strides_static, slice_mode);
}
*rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode);
} else {
*rv = dynamic_strided_slice(x, begin, end, strides);
}
Expand Down

0 comments on commit 7db4cea

Please sign in to comment.