[mlir][Linalg] Refactor HoistPadding and add support for hoisting in the absence of packing loops.

This revision cleans up the implementation of hoist padding and extends it to also work in the
absence of packing loops.
This allows better composition when hoisting the padded result of a DPS operation.

RewriterBase is now used systematically throughout the implementation.
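
For reference, a minimal usage sketch of the new rewriter-taking entry point; it mirrors the updated call in LinalgTransformOps.cpp below, with `padOp` and `numLoops` standing in for caller-provided values:

// Sketch only: `padOp` is the tensor.pad to hoist and `numLoops` the number
// of enclosing loops to hoist it above; both are supplied by the caller.
IRRewriter rewriter(padOp->getContext());
tensor::PadOp hoistedPadOp;
SmallVector<linalg::GenericOp> transposeOps;
FailureOr<Value> result = linalg::hoistPaddingOnTensors(
    rewriter, padOp, numLoops, /*transposeVector=*/{}, hoistedPadOp,
    transposeOps);
if (failed(result)) {
  // Hoisting did not apply; handle the failure in the caller.
}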

Depends on: D144856

Differential Revision: https://reviews.llvm.org/D144855
nicolasvasilache committed Feb 28, 2023
1 parent 4bc254c commit 2f07d62
Showing 8 changed files with 420 additions and 214 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -411,6 +411,12 @@ rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// }
/// ```
FailureOr<Value>
hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist,
int64_t numLoops, ArrayRef<int64_t> transposeVector,
tensor::PadOp &hoistedOp,
SmallVectorImpl<GenericOp> &transposeOps);
/// Calls into `hoistPaddingOnTensors` with a local IRRewriter.
FailureOr<Value>
hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops,
ArrayRef<int64_t> transposeVector,
tensor::PadOp &hoistedOp,
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -31,6 +31,12 @@ SmallVector<Value> createDynamicDimValues(OpBuilder &b, Location loc,
SmallVector<OpFoldResult> createDimValues(OpBuilder &b, Location loc,
Value rankedTensor);

/// Returns the transposed `rankedTensorType` if `transposeVector` is non-empty.
/// Fails if `transposeVector` is not a permutation matching the tensor rank.
FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);

} // namespace tensor
} // namespace mlir

8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1779,12 +1779,12 @@ DiagnosedSilenceableFailure
transform::HoistPadOp::applyToOne(tensor::PadOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
-  IRRewriter rewriter(target->getContext());
   tensor::PadOp hoistedPadOp;
   SmallVector<GenericOp> transposeOps;
-  // TODO: Pass rewriter down to hoistPaddingOnTensors, in a followup commit.
-  FailureOr<Value> result = hoistPaddingOnTensors(
-      target, getNumLoops(), getTranspose(), hoistedPadOp, transposeOps);
+  IRRewriter rewriter(target->getContext());
+  FailureOr<Value> result =
+      hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
+                            hoistedPadOp, transposeOps);
if (succeeded(result)) {
// We need to perform our own replacement here because this API is still
// used in patterns that "pad and hoist", for which the replacement values
530 changes: 324 additions & 206 deletions mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp

Large diffs are not rendered by default.

19 changes: 15 additions & 4 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2449,6 +2449,15 @@ LogicalResult PadOp::verify() {
auto resultType = getResult().getType().cast<RankedTensorType>();
auto expectedType =
PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
if (!expectedType) {
return emitError("failed to infer expectedType from sourceType ")
<< sourceType << ", specified resultType is " << resultType;
}
if (resultType.getRank() != expectedType.getRank()) {
return emitError("specified type ")
<< resultType << " does not match the inferred type "
<< expectedType;
}
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
if (resultType.getDimSize(i) == expectedType.getDimSize(i))
continue;
@@ -2490,10 +2499,12 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
ArrayRef<int64_t> staticHigh,
ArrayRef<int64_t> resultShape) {
unsigned rank = sourceType.getRank();
-  assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
-  assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
-  assert((resultShape.empty() || resultShape.size() == rank) &&
-         "unexpected resultShape size mismatch");
+  if (staticLow.size() != rank)
+    return RankedTensorType();
+  if (staticHigh.size() != rank)
+    return RankedTensorType();
+  if (!(resultShape.empty() || resultShape.size() == rank))
+    return RankedTensorType();

SmallVector<int64_t, 4> inferredShape;
for (auto i : llvm::seq<unsigned>(0, rank)) {
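
A short sketch of the resulting caller-side contract (variable names are placeholders): `PadOp::inferResultType` now returns a null type instead of asserting when the static low/high pad sizes do not match the source rank, so callers check the result, as the updated verifier does:

// Sketch: `sourceType`, `staticLow` and `staticHigh` are placeholders for
// values a caller already has; this mirrors the check added to PadOp::verify().
RankedTensorType expectedType =
    tensor::PadOp::inferResultType(sourceType, staticLow, staticHigh);
if (!expectedType) {
  // The low/high pad sizes do not match the source rank; report an error
  // instead of asserting.
}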
20 changes: 20 additions & 0 deletions mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -14,6 +14,7 @@

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"

using namespace mlir;
using namespace mlir::tensor;
@@ -65,3 +66,22 @@ mlir::tensor::createDimValues(OpBuilder &b, Location loc, Value rankedTensor) {
}
return dims;
}

FailureOr<RankedTensorType>
mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector) {
if (transposeVector.empty())
return rankedTensorType;
if (!isPermutationVector(transposeVector) ||
transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
return failure();

SmallVector<int64_t> transposedShape(rankedTensorType.getShape().begin(),
rankedTensorType.getShape().end());
applyPermutationToVector(transposedShape, transposeVector);

using RTTBuilder = RankedTensorType::Builder;
RankedTensorType transposedTensorType =
RTTBuilder(rankedTensorType).setShape(transposedShape);
return transposedTensorType;
}
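
A small usage sketch of the new helper (hypothetical values, assuming an OpBuilder `b` is in scope), using the same [1, 0] permutation as the tests below:

// Sketch: transpose the type tensor<24x25xf32> with permutation [1, 0].
RankedTensorType type = RankedTensorType::get({24, 25}, b.getF32Type());
FailureOr<RankedTensorType> transposed =
    tensor::computeTransposedType(type, /*transposeVector=*/{1, 0});
// On success, *transposed is tensor<25x24xf32>; a non-permutation or a
// rank mismatch yields failure().
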
43 changes: 43 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-hoist-pad.mlir
@@ -149,3 +149,46 @@ transform.sequence failures(propagate) {
transform.structured.hoist_pad %pad by 1 loops, transpose by [1, 0]
: (!pdl.operation) -> !pdl.operation
}

// -----

// CHECK-LABEL: pad_and_hoist_init
func.func @pad_and_hoist_init(
%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
-> tensor<24x25xf32>
{

// CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>) {
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}}
// CHECK: : tensor<?x25xf32> to tensor<5x25xf32>
// CHECK: scf.for %{{.*}} -> (tensor<?x25xf32>) {
// CHECK: %[[RES:.*]] = linalg.matmul {{.*}} outs(%[[PADDED]] : tensor<5x25xf32>
//
// TODO: atm we are missing the plumbing of packedTensor through the loop bbarg
// when required (i.e. when hoisting init tensors).
// CHECK: %[[RES_EXTRACTED:.*]] = tensor.extract_slice %[[RES]][0, 0] [%{{.*}}, 25] [1, 1]
// CHECK-SAME: : tensor<5x25xf32> to tensor<?x25xf32>
// CHECK: scf.yield %[[RES_EXTRACTED]] : tensor<?x25xf32>
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
func.return %0 : tensor<24x25xf32>
}

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
: (!pdl.operation) -> !pdl.operation


%matmul_l1, %loops_l1:2 = transform.structured.tile_to_scf_for %matmul [5, 0, 7]

%matmul_padded = transform.structured.pad %matmul_l1 {
padding_values=[0.0: f32, 0.0 : f32, 0.0 : f32],
padding_dimensions=[0, 1, 2]
}

%pad = transform.get_producer_of_operand %matmul_padded[2]
: (!pdl.operation) -> !transform.op<"tensor.pad">

transform.structured.hoist_pad %pad by 1 loops
: (!transform.op<"tensor.pad">) -> !pdl.operation
}
2 changes: 2 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5491,6 +5491,7 @@ cc_library(
deps = [
":AffineDialect",
":ArithDialect",
":DialectUtils",
":TensorDialect",
"//llvm:Support",
],
@@ -8467,6 +8468,7 @@ cc_library(
":LinalgPassIncGen",
":LinalgStructuredOpsIncGen",
":LinalgUtils",
":LoopLikeInterface",
":MaskableOpInterface",
":MathDialect",
":MemRefDialect",
