[mlir][Linalg] Add an InitTensor -> DimOp canonicalization pattern.
Differential Revision: https://reviews.llvm.org/D105537
nicolasvasilache committed Jul 7, 2021
Commit 0c4e538 (1 parent: ce098cc)
Showing 2 changed files with 51 additions and 15 deletions.
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (20 additions & 1 deletion)
@@ -799,11 +799,30 @@ struct FoldInitTensorWithTensorReshapeOp
    return success();
  }
};

+/// Fold tensor.dim of a linalg.init_tensor: a dynamic dimension is replaced
+/// by the corresponding dynamic size operand, a static dimension by a
+/// constant index.
+struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
+    auto initTensorOp = dimOp.source().getDefiningOp<linalg::InitTensorOp>();
+    if (!initTensorOp || !maybeConstantIndex)
+      return failure();
+    if (initTensorOp.isDynamicSize(*maybeConstantIndex)) {
+      rewriter.replaceOp(dimOp,
+                         initTensorOp.getDynamicSize(*maybeConstantIndex));
+      return success();
+    }
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(dimOp, *maybeConstantIndex);
+    return success();
+  }
+};
} // namespace

void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldInitTensorWithExtractSliceOp,
+  results.add<FoldInitTensorWithDimOp, FoldInitTensorWithExtractSliceOp,
              FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
              FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
              ReplaceStaticShapeDims>(context);
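In IR terms, the new pattern folds tensor.dim ops whose source is a linalg.init_tensor. A minimal sketch of the fold (illustrative IR, not taken from the commit; %sz stands for an assumed dynamic-size operand):

  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = linalg.init_tensor [%sz, 42] : tensor<?x42xf32>
  %d0 = tensor.dim %0, %c0 : tensor<?x42xf32>   // dynamic dimension
  %d1 = tensor.dim %0, %c1 : tensor<?x42xf32>   // static dimension

Under canonicalization, %d0 is replaced by %sz and %d1 by "constant 42 : index"; once no uses remain, a dead init_tensor can be erased as well. This is also why the reshape tests below get shorter: the tensor.dim of the intermediate init_tensor now folds straight to the function argument, so the affine.apply consumes %[[ARG0]] directly.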
mlir/test/Dialect/Linalg/canonicalize.mlir (31 additions & 14 deletions)
@@ -540,13 +540,10 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @init_tensor_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C2:.+]] = constant 2
-// CHECK: %[[INIT1:.+]] = linalg.init_tensor [6, 5, %[[ARG0]]]
-// CHECK: %[[D0:.+]] = tensor.dim %[[INIT1]], %[[C2]]
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: %[[INIT2:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
-// CHECK: return %[[INIT2]]
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+// CHECK-NEXT: %[[INIT:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[D]], 7]
+// CHECK-NEXT: return %[[INIT]]

// -----

@@ -558,13 +555,10 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @init_tensor_reshape_collapse
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[C4:.+]] = constant 4
-// CHECK: %[[INIT1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[ARG0]], 7]
-// CHECK: %[[D0:.+]] = tensor.dim %[[INIT1]], %[[C4]]
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: %[[INIT2:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
-// CHECK: return %[[INIT2]]
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+// CHECK-NEXT: %[[INIT:.+]] = linalg.init_tensor [6, 5, %[[D]]]
+// CHECK-NEXT: return %[[INIT]]

// -----

@@ -873,3 +867,26 @@ func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<
  return %0 : tensor<2x3x4xf32>
}

+// -----
+
+func private @some_use(%i : index, %j : index)
+
+// CHECK-LABEL: func @init_canonicalize
+// CHECK-SAME: %[[I:.*]]: index
+func @init_canonicalize(%i : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // CHECK-NOT: init_tensor
+  %0 = linalg.init_tensor [%i, 42] : tensor<?x42xf32>
+
+  // CHECK-NOT: tensor.dim
+  %1 = tensor.dim %0, %c0 : tensor<?x42xf32>
+  %2 = tensor.dim %0, %c1 : tensor<?x42xf32>
+
+  // CHECK: %[[c42:.*]] = constant 42 : index
+  // CHECK: call @some_use(%[[I]], %[[c42]])
+  call @some_use(%1, %2) : (index, index) -> ()
+
+  return
+}
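The new test is driven by the file's existing RUN line, which sits outside this diff's context. A hedged sketch of a typical driver for a test file like this, assuming the upstream convention of splitting at the // ----- markers:

  // RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s

The -canonicalize pass gathers each op's patterns through getCanonicalizationPatterns, which is where the FoldInitTensorWithDimOp registration shown above takes effect.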
