Skip to content

Commit

Permalink
remove wip layout transform support for slice with axes
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 9bcb2ad commit 37eaf57
Showing 1 changed file with 0 additions and 212 deletions.
212 changes: 0 additions & 212 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2590,218 +2590,6 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
return true;
}

// Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
// const Array<Layout>& new_in_layouts,
// const Array<Layout>& old_in_layouts,
// const Array<tvm::relay::Type>& old_in_types)
// {
// Array<Array<IndexExpr>> old_in_shapes;
// for (auto old_in_t : old_in_types) {
// ICHECK(old_in_t.as<TensorTypeNode>());
// old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
// }

// ICHECK(old_in_layouts.defined());
// ICHECK_GE(old_in_layouts.size(), 1);
// ICHECK(old_in_shapes.defined());
// ICHECK_GE(old_in_shapes.size(), 1);

// auto layout = old_in_layouts[0];
// if (layout.defined() && new_in_layouts.defined()) {
// ICHECK_GE(new_in_layouts.size(), 1);
// auto new_layout = new_in_layouts[0];
// auto shape = old_in_shapes[0];

// // NOTE: Discard "const" qualifier here.
// auto* params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
// ICHECK(params != nullptr);
// Array<Integer> begin, end, strides;
// if (params->begin && params->end && params->strides) {
// for (Integer i : params->strides.value()) {
// ICHECK(i.defined());
// strides.push_back(params->slice_mode == "size" ? 1 : i->value);
// }

// for (Integer i : params->begin.value()) {
// ICHECK(i.defined());
// begin.push_back(i->value);
// }
// for (Integer i : params->end.value()) {
// ICHECK(i.defined());
// end.push_back(i->value);
// }
// }
// auto axes = params->axes;

// Array<Integer> new_begin, new_end, new_strides;

// // Handles layout conversion like NHWC -> NCHW
// auto old_layout_name = layout.name();
// auto new_layout_name = new_layout.name();

// if (old_layout_name.rfind(new_layout_name, 0) != 0 &&
// new_layout_name.rfind(old_layout_name, 0) != 0) {
// if (old_layout_name.size() != new_layout_name.size()) {
// // Not support NHW4c -> NCHW
// return {{Layout::Undef()}, {Layout::Undef()}};
// } else {
// if (params->axes) {
// auto axes = params->axes.value();
// Array<Integer> new_axes(axes);
// std::vector<int> axes_map(old_layout_name.size(), -1);
// for (size_t i = 0; i < axes.size(); ++i) {
// axes_map[axes[i]] = i;
// }
// LOG(INFO) << "old layout: " << old_layout_name;
// LOG(INFO) << "new layout: " << new_layout_name;
// for (size_t i = 0; i < new_layout_name.size(); ++i) {
// auto index = layout.IndexOf(new_layout[i]);
// if (index == -1) {
// return {{Layout::Undef()}, {Layout::Undef()}};
// }

// size_t new_index = static_cast<size_t>(index);
// if (axes_map[new_index] != -1) {
// ICHECK(strides[axes_map[new_index]].defined());
// new_axes.Set(axes_map[new_index], new_index);
// new_begin.push_back(begin[axes_map[new_index]]->value);
// new_end.push_back(end[axes_map[new_index]]->value);
// new_strides.push_back(strides[axes_map[new_index]]->value);
// }
// }
// params->axes = new_axes;
// } else {
// for (size_t i = 0; i < new_layout_name.size(); ++i) {
// auto index = layout.IndexOf(new_layout[i]);
// if (index == -1) {
// return {{Layout::Undef()}, {Layout::Undef()}};
// }

// size_t new_index = static_cast<size_t>(index);
// int64_t bg, ed, st;
// if (strides.defined() && new_index < strides.size() && strides[new_index].defined())
// {
// st = strides[new_index]->value;
// } else {
// st = 1;
// }
// if (new_index < begin.size() && begin[new_index].defined()) {
// bg = begin[new_index]->value;
// } else {
// bg = 0;
// }
// if (new_index < end.size() && end[new_index].defined()) {
// ed = end[new_index]->value;
// } else {
// ed = shape[new_index].as<IntImmNode>()->value;
// }

// new_begin.push_back(bg);
// new_end.push_back(ed);
// new_strides.push_back(st);
// }
// }
// params->begin = new_begin;
// params->end = new_end;
// params->strides = new_strides;
// layout = new_layout;
// }
// } else {
// if (params->axes) {
// auto axes = params->axes.value();
// LOG(INFO) << "old layout: " << old_layout_name;
// LOG(INFO) << "new layout: " << new_layout_name;
// for (size_t i = 0; i < axes.size(); i++) {
// const LayoutAxis& axis = layout[axes[i]];
// if (!axis.IsPrimal()) {
// // original layout that contains splitted axes is not supported
// return {{Layout::Undef()}, {Layout::Undef()}};
// }
// auto factor = new_layout.FactorOf(axis);
// if (factor == -1) {
// new_begin.push_back(begin[i]);
// new_end.push_back(end[i]);
// } else {
// if (strides.defined() && i < strides.size()) {
// auto stride = strides[i];
// // arbitrary stride is not supported
// if (stride.defined() && stride->value != 1) {
// return {{Layout::Undef()}, {Layout::Undef()}};
// }
// }
// int64_t bg = begin[i].defined() ? begin[i]->value : 0;
// int64_t ed;
// if (!end[i].defined()) {
// ed = shape[axes[i]].as<IntImmNode>()->value;
// } else if (params->slice_mode == "size") {
// if (end[i]->value < 0) {
// ed = shape[axes[i]].as<IntImmNode>()->value;
// } else {
// ed = bg + end[i]->value;
// }
// } else {
// ed = end[i]->value;
// }

// if (bg % factor || ed % factor) {
// // transform to original layout
// return {{Layout::Undef()}, {Layout::Undef()}};
// }
// new_begin.push_back(tvm::Integer(bg / factor));
// new_end.push_back(tvm::Integer(ed / factor));
// }
// }
// } else {
// for (size_t i = 0; i < begin.size(); i++) {
// const LayoutAxis& axis = layout[i];
// if (!axis.IsPrimal()) {
// // original layout that contains splitted axes is not supported
// return {{Layout::Undef()}, {Layout::Undef()}};
// }
// auto factor = new_layout.FactorOf(axis);
// if (factor == -1) {
// new_begin.push_back(begin[i]);
// new_end.push_back(end[i]);
// } else {
// if (strides.defined() && i < strides.size()) {
// auto stride = strides[i];
// // arbitrary stride is not supported
// if (stride.defined() && stride->value != 1) {
// return {{Layout::Undef()}, {Layout::Undef()}};
// }
// }
// int64_t bg = begin[i].defined() ? begin[i]->value : 0;
// int64_t ed;
// if (!end[i].defined()) {
// ed = shape[i].as<IntImmNode>()->value;
// } else if (params->slice_mode == "size") {
// if (end[i]->value < 0) {
// ed = shape[i].as<IntImmNode>()->value;
// } else {
// ed = bg + end[i]->value;
// }
// } else {
// ed = end[i]->value;
// }

// if (bg % factor || ed % factor) {
// // transform to original layout
// return {{Layout::Undef()}, {Layout::Undef()}};
// }
// new_begin.push_back(tvm::Integer(bg / factor));
// new_end.push_back(tvm::Integer(ed / factor));
// }
// }
// }

// layout = new_layout;
// params->begin = new_begin;
// params->end = new_end;
// }
// }
// return {{layout}, {layout}};
// }

Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
Expand Down

0 comments on commit 37eaf57

Please sign in to comment.