Commit 4281946 — author: MaheshRavishankar

[mlir][Tensor] Add ReifyRankedShapedTypeOpInterface to tensor.extract_slice.

Differential Revision: https://reviews.llvm.org/D111263
Parent: 9fad9de

3 files changed: +172 −5 lines

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,10 @@ def Tensor_ExtractOp : Tensor_Op<"extract",
158158
//===----------------------------------------------------------------------===//
159159

160160
def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
161-
Tensor_Dialect, "extract_slice", [NoSideEffect, AttrSizedOperandSegments,
162-
OffsetSizeAndStrideOpInterface]> {
161+
Tensor_Dialect, "extract_slice",
162+
[NoSideEffect, AttrSizedOperandSegments,
163+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
164+
OffsetSizeAndStrideOpInterface]> {
163165
let summary = "extract slice operation";
164166
let description = [{
165167
The "extract_slice" operation extract a tensor from another tensor as
@@ -284,6 +286,11 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
284286
/// Return the number of leading operands before the `offsets`, `sizes` and
285287
/// and `strides` operands.
286288
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
289+
290+
/// Return the dimensions of the source that are dropped in the
291+
/// result when the result is rank-reduced.
292+
llvm::SmallDenseSet<unsigned> getDroppedDims();
293+
287294
}];
288295

289296
let hasCanonicalizer = 1;

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,12 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
277277
unsigned unsignedIndex = index.getValue().getZExtValue();
278278

279279
if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
280-
assert(sliceOp.isDynamicSize(unsignedIndex) &&
281-
"Expected dynamic slice size");
282-
return sliceOp.getDynamicSize(unsignedIndex);
280+
// Fold only for non-rank reduced ops. For the rank-reduced version, rely on
281+
// `resolve-shaped-type-result-dims` pass.
282+
if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
283+
sliceOp.isDynamicSize(unsignedIndex)) {
284+
return {sliceOp.getDynamicSize(unsignedIndex)};
285+
}
283286
}
284287

285288
// dim(cast) -> dim
@@ -895,6 +898,46 @@ getCanonicalSliceResultType(unsigned resultRank, RankedTensorType sourceType,
895898
return resultType;
896899
}
897900

901+
/// Compute the set of source dimensions that are rank-reduced away in the
/// result of this extract_slice.
///
/// The slice sizes are walked in parallel with the (possibly rank-reduced)
/// result shape. A source dimension is dropped exactly when its slice size is
/// statically 1 and the result does not carry a matching static unit dimension
/// at the current position; otherwise the dimension is preserved and the
/// result-shape cursor advances.
llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() {
  llvm::SmallDenseSet<unsigned> dropped;
  ArrayRef<int64_t> resultShape = getType().getShape();
  SmallVector<OpFoldResult> sliceSizes = getMixedSizes();
  unsigned resultDim = 0;
  for (unsigned srcDim = 0, e = sliceSizes.size(); srcDim != e; ++srcDim) {
    Optional<int64_t> staticSize = getConstantIntValue(sliceSizes[srcDim]);
    bool isUnitSize = staticSize && staticSize.getValue() == 1;
    // A static unit dimension in the result "consumes" a unit-sized slice
    // dimension instead of dropping it.
    bool resultHasUnitDim =
        resultDim < resultShape.size() && resultShape[resultDim] == 1;
    if (isUnitSize && !resultHasUnitDim) {
      // Rank-reduced away: record the source dimension as dropped and do not
      // advance the result-shape cursor.
      dropped.insert(srcDim);
      continue;
    }
    // Preserved dimension; it occupies the next result position.
    ++resultDim;
  }
  return dropped;
}
920+
921+
/// Reify the shape of the single result of this extract_slice.
///
/// The result shape is the list of slice sizes, skipping every source
/// dimension that rank reduction drops. Dynamic sizes are already SSA values
/// and are forwarded as-is; static sizes are materialized as constant index
/// ops at this op's location.
LogicalResult ExtractSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1);
  auto &resultDims = reifiedReturnShapes[0];
  resultDims.reserve(getType().getRank());
  llvm::SmallDenseSet<unsigned> dropped = getDroppedDims();
  SmallVector<OpFoldResult> sizes = getMixedSizes();
  Location loc = getLoc();
  for (unsigned dim = 0, e = sizes.size(); dim != e; ++dim) {
    if (dropped.count(dim))
      continue;
    OpFoldResult size = sizes[dim];
    if (auto dynamicSize = size.dyn_cast<Value>()) {
      resultDims.push_back(dynamicSize);
    } else {
      int64_t staticSize =
          size.get<Attribute>().cast<IntegerAttr>().getInt();
      resultDims.push_back(builder.create<ConstantIndexOp>(loc, staticSize));
    }
  }
  return success();
}
940+
898941
namespace {
899942
/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
900943
/// This essentially pushes memref_cast past its consuming slice when

mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,120 @@ func @insert_slice(
2525
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
2626
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
2727
// CHECK: return %[[D0]], %[[D1]], %[[D2]]
// -----

// Non-rank-reduced extract_slice: every result dim resolves directly to the
// corresponding dynamic slice size.
func @extract_slice(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 : index,
    %arg3 : index) -> (index, index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, %arg2, %arg3] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<?x?x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
  %2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
  %3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
  return %1, %2, %3 : index, index, index
}
// CHECK-LABEL: func @extract_slice(
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]], %[[ARG2]], %[[ARG3]]

// -----

// Rank-reduced to 1-D: both unit dims are dropped, so result dim 0 is the
// dynamic size %arg1.
func @extract_slice_rank_reduced_1(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c0 = constant 0 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?xf32>
  return %1 : index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_1(
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Rank-reduced to ?x1: the trailing static unit dim is preserved, the leading
// unit dim is dropped.
func @extract_slice_rank_reduced_2(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c0 = constant 0 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<?x1xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x1xf32>
  return %1 : index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_2(
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Rank-reduced to 1x?: the leading static unit dim is preserved, the trailing
// unit dim is dropped.
func @extract_slice_rank_reduced_3(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c1 = constant 1 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<1x?xf32>
  %1 = tensor.dim %0, %c1 : tensor<1x?xf32>
  return %1 : index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_3(
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// No rank reduction (all unit dims kept as static 1s in the result).
func @extract_slice_rank_reduced_4(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c1 = constant 1 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<1x?x1xf32>
  %1 = tensor.dim %0, %c1 : tensor<1x?x1xf32>
  return %1 : index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_4(
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Middle unit dim dropped: two dynamic sizes map onto a rank-2 result.
func @extract_slice_rank_reduced_5(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> (index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<?x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
  %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
  return %1, %2 : index, index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_5(
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]], %[[ARG2]]

// -----

// Middle unit dim kept as a static 1 in the result; dims 0 and 2 stay dynamic.
func @extract_slice_rank_reduced_6(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> (index, index) {
  %c0 = constant 0 : index
  %c2 = constant 2 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<?x1x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x1x?xf32>
  %2 = tensor.dim %0, %c2 : tensor<?x1x?xf32>
  return %1, %2 : index, index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_6(
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]], %[[ARG2]]

Comments (0)