[mlir][Linalg] Add canonicalization to remove no-op linalg operations.
linalg.generic/indexed_generic operations on tensors whose body just
yields the (non-induction-variable) block arguments of the operation
can be canonicalized away by replacing uses of their results with the
corresponding operands.

Differential Revision: https://reviews.llvm.org/D94581
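
For illustration only (not part of the commit message): a minimal MLIR sketch of the rewrite on a single-result linalg.generic, using the tensor-based ins/outs form that the test in this patch uses. The names @forward, %init and #id are made up for the example. An op whose region only forwards its input block argument has its result replaced by that input tensor.

#id = affine_map<(d0, d1) -> (d0, d1)>

// Before canonicalization: the generic op merely forwards %arg0.
func @forward(%arg0 : tensor<?x?xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {
      indexing_maps = [#id, #id],
      iterator_types = ["parallel", "parallel"]
    } ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
    ^bb0(%in : f32, %out : f32):
      linalg.yield %in : f32
    } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// After canonicalization, uses of %0 are replaced by %arg0:
//   return %arg0 : tensor<?x?xf32>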
MaheshRavishankar committed Jan 14, 2021
1 parent 4183999 commit 722ae10
Showing 2 changed files with 80 additions and 2 deletions.
51 changes: 50 additions & 1 deletion mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2119,6 +2119,54 @@ struct DeduplicateInputs : public RewritePattern {
}
};

/// Remove generic/indexed_generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
struct RemoveIdentityLinalgOps : public RewritePattern {
  RemoveIdentityLinalgOps(PatternBenefit benefit = 1)
      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!isa<GenericOp, IndexedGenericOp>(op))
      return failure();
    LinalgOp genericOp = cast<LinalgOp>(op);
    if (!genericOp.hasTensorSemantics())
      return failure();
    // Check all indexing maps are identity.
    if (llvm::any_of(genericOp.getIndexingMaps(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = op->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    unsigned numIndexArgs = genericOp.getNumPayloadInductionVariables();
    SmallVector<Value, 4> returnedArgs;
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      if (argumentNumber < numIndexArgs)
        return failure();
      returnedArgs.push_back(op->getOperand(argumentNumber - numIndexArgs));
    }
    rewriter.replaceOp(genericOp, returnedArgs);
    return success();
  }
};

/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
/// with the corresponding output tensor argument of the linalg op.
struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
@@ -2143,7 +2191,8 @@ struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
#define CANONICALIZERS_AND_FOLDERS(XXX) \
  void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
                                        MLIRContext *context) { \
    results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>(); \
    results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
                   RemoveIdentityLinalgOps>(); \
    results.insert<ReplaceDimOfLinalgResult>(context); \
  } \
  \
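As a hedged counter-example to the requirements listed in the RemoveIdentityLinalgOps comment above (this snippet is not part of the diff; the names @not_a_noop, %init and #id are made up): a generic op whose yield produces a computed value rather than forwarding a block argument is left untouched, and the same holds when the indexing maps are not all identities.

#id = affine_map<(d0, d1) -> (d0, d1)>

// Not removed: the yielded value is defined by an op in the body, so the
// dyn_cast<BlockArgument> in the pattern fails and matchAndRewrite bails out.
func @not_a_noop(%arg0 : tensor<?x?xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {
      indexing_maps = [#id, #id],
      iterator_types = ["parallel", "parallel"]
    } ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
    ^bb0(%in : f32, %out : f32):
      %1 = addf %in, %in : f32
      linalg.yield %1 : f32
    } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}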
31 changes: 30 additions & 1 deletion mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -249,8 +249,10 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32>
  return %1: tensor<0xf32>
}
// CHECK-LABEL: @dce_zero_memref
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32>
// CHECK-NOT: linalg.copy
// CHECK-NEXT: linalg.generic
// CHECK-NEXT: return %[[ARG1]]

// -----

@@ -449,3 +451,30 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
// CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
// CHECK: return %[[T1]]

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %0 = dim %arg0, %c0 : tensor<?x?x?xf32>
  %1 = dim %arg0, %c1 : tensor<?x?x?xf32>
  %2 = dim %arg0, %c2 : tensor<?x?x?xf32>
  %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
  %4, %5 = linalg.generic {
      indexing_maps = [#map, #map, #map, #map],
      iterator_types = ["parallel", "parallel", "parallel"]
    } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
      outs(%3, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
    ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32):
      linalg.yield %arg3, %arg2 : f32, f32
    } -> tensor<?x?x?xf32>, tensor<?x?x?xf32>
  return %4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
// CHECK-LABEL: func @remove_no_op
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK: return %[[ARG1]], %[[ARG0]]
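
For reference, the CHECK lines above correspond to a canonicalized function along these lines (a sketch of the expected output, not part of the diff; it assumes the now-dead dim and linalg.init_tensor ops are also cleaned up by the canonicalizer):

func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  return %arg1, %arg0 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}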
