[mlir][transform] Add PromoteTensorOp #158318
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Hendrik_Klug (Jimmy2027)

Changes: Transform op to request a tensor value to live in a specific memory space after bufferization.

Full diff: https://github.com/llvm/llvm-project/pull/158318.diff

3 Files Affected:
- mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
- mlir/test/Dialect/Transform/test-promote-tensors.mlir
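For orientation before the diff, a minimal sketch of the new op in use inside a transform script (it mirrors the tests added below; the memory space 1 and the handle names are illustrative):

    %mm = transform.structured.match ops{["linalg.matmul"]} in %root
        : (!transform.any_op) -> !transform.any_op
    // Take a value handle to the first input operand...
    %in0 = transform.get_operand %mm[0]
        : (!transform.any_op) -> !transform.any_value
    // ...and request that it live in memory space 1 after bufferization.
    %promoted = transform.structured.promote_tensor to 1 %in0
        : !transform.any_value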
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a19cce4b919a8..b4c62baad11bf 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/RegionKindInterface.td"
@@ -236,11 +237,51 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
Transform_AnyOpType:$new_ops);
let assemblyFormat = "$target attr-dict `:` type($target)";
let hasVerifier = 1;
+}
- let builders = [
- OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>,
- OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)>
- ];
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+def PromoteTensorOp : Op<Transform_Dialect, "structured.promote_tensor",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ SameOperandsAndResultType]> {
+ let summary = "Request a tensor value to live in a specific memory space "
+ "after bufferization";
+ let description = [{
+ Requests that a tensor value lives in a specific memory space for its
+ lifetime. This is achieved by allocating a new tensor in the desired
+ memory space with `bufferization.alloc_tensor` and optionally materializing
+ the source value into that allocation with
+ `bufferization.materialize_in_destination`. All uses of the original value
+ are then redirected to the promoted value.
+
+ The generated code for promoting tensor value %0 resembles the following:
+
+ %1 = bufferization.alloc_tensor(<dynamic dims of %0>)
+ { memory_space = memory_space }
+ // Note: the materialization is omitted if %0 is never read and is only
+ // written into (i.e., it behaves as a result tensor).
+ %2 = bufferization.materialize_in_destination %0 in %1
+ // ...
+ <all users of %0 now use %2 instead>
+
+ Deallocation is not handled by this transform.
+
+ Return modes:
+ - Produces a silenceable failure if the given handle does not point to
+ tensor-typed values.
+ - Succeeds otherwise and returns a handle to the promoted value(s), i.e.,
+ the result of materialization if present and the allocation otherwise.
+ }];
+
+ let arguments = (ins TransformValueHandleTypeInterface:$tensor,
+ OptionalAttr<AnyAttr>:$memory_space);
+ let results = (outs TransformValueHandleTypeInterface:$promoted);
+
+ let assemblyFormat =
+ "(`to` $memory_space^)? $tensor attr-dict `:` type($tensor)";
}
//===----------------------------------------------------------------------===//
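Since $memory_space is declared as OptionalAttr above, the `to` clause in the custom assembly format may be omitted. A sketch of both spellings (handle name illustrative):

    // Explicit memory space: the attribute is set on the allocation.
    %p0 = transform.structured.promote_tensor to 1 %v : !transform.any_value
    // No `to` clause: the alloc_tensor keeps its default memory space.
    %p1 = transform.structured.promote_tensor %v : !transform.any_value

Only the operand type is spelled out because SameOperandsAndResultType lets the result type be inferred.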
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..017886ef4fcd3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -41,6 +41,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
@@ -272,32 +273,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
- OperationState &result,
- Value target,
- Attribute memorySpace) {
- SmallVector<Type> resultTypes;
- resultTypes.push_back(b.getType<transform::AnyValueType>());
- resultTypes.push_back(b.getType<transform::AnyOpType>());
- return build(b, result,
- /*resultTypes=*/resultTypes,
- /*target=*/target,
- /*memorySpace=*/memorySpace);
-}
-
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
- OperationState &result,
- Value target,
- int64_t memorySpace) {
- SmallVector<Type> resultTypes;
- resultTypes.push_back(b.getType<transform::AnyValueType>());
- resultTypes.push_back(b.getType<transform::AnyOpType>());
- return build(b, result,
- /*resultTypes=*/resultTypes,
- /*target=*/target,
- /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
-}
-
namespace {
class NewOpsListener : public RewriterBase::ForwardingListener {
public:
@@ -407,6 +382,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the operand may be read from by its owner. This is currently
+/// very conservative and only looks inside linalg operations to prevent
+/// unintentional data loss.
+static bool mayBeRead(OpOperand &operand) {
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+
+ // Be conservative about ops we cannot analyze deeper.
+ if (!linalgOp)
+ return true;
+
+ // Look inside linalg ops.
+ Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
+ return !blockArgument.use_empty();
+}
+
+/// Return true if the value may be read through any of its uses.
+static bool mayBeRead(Value value) {
+ // If the value has a reference semantics, it
+ // may be read through any alias...
+ if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
+ return true;
+ return llvm::any_of(value.getUses(),
+ static_cast<bool (&)(OpOperand &)>(mayBeRead));
+}
+
+DiagnosedSilenceableFailure
+transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ SmallVector<Value> promoted;
+ for (Value tensor : state.getPayloadValues(getTensor())) {
+ auto type = dyn_cast<RankedTensorType>(tensor.getType());
+ if (!type) {
+ return emitSilenceableError() << "non-tensor type: " << tensor;
+ }
+
+ Operation *definingOp = tensor.getDefiningOp();
+ if (definingOp)
+ rewriter.setInsertionPointAfter(definingOp);
+ else
+ rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
+
+ // Check this before we emit operations using this value.
+ bool needsMaterialization = mayBeRead(tensor);
+
+ SmallVector<Value> dynamicDims;
+ llvm::SmallPtrSet<Operation *, 4> preservedOps;
+ for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
+ if (!ShapedType::isDynamic(dim))
+ continue;
+ Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
+ auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+ preservedOps.insert(dimOp);
+ dynamicDims.push_back(dimOp);
+ }
+ auto allocation = rewriter.create<bufferization::AllocTensorOp>(
+ tensor.getLoc(), type, dynamicDims);
+ // Set memory space if provided.
+ if (getMemorySpaceAttr())
+ allocation.setMemorySpaceAttr(getMemorySpaceAttr());
+ Value allocated = allocation;
+
+ // Only insert a materialization (typically bufferizes to a copy) when the
+ // value may be read from.
+ if (needsMaterialization) {
+ auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+ tensor.getLoc(), tensor, allocated);
+ preservedOps.insert(copy);
+ promoted.push_back(copy.getResult());
+ } else {
+ promoted.push_back(allocated);
+ }
+ rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
+ }
+ results.setValues(cast<OpResult>(getPromoted()), promoted);
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PromoteTensorOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getTensorMutable(), effects);
+ transform::producesHandle(getOperation()->getOpResults(), effects);
+ transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//
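To make the mayBeRead() gating in apply() above concrete, here is a sketch of the IR emitted in the two cases (shapes and memory space are illustrative; the tests below check the same patterns):

    // Case 1: %t may be read. Allocate, then materialize %t into the
    // allocation; the remaining users of %t are redirected to %mat.
    %c0 = arith.constant 0 : index
    %d0 = tensor.dim %t, %c0 : tensor<?x42xf32>
    %alloc = bufferization.alloc_tensor(%d0) {memory_space = 1 : i64}
        : tensor<?x42xf32>
    %mat = bufferization.materialize_in_destination %t in %alloc
        : (tensor<?x42xf32>, tensor<?x42xf32>) -> tensor<?x42xf32>

    // Case 2: %t is only written into (e.g., an unused linalg init).
    // Users are redirected straight to %alloc; no copy is materialized.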
diff --git a/mlir/test/Dialect/Transform/test-promote-tensors.mlir b/mlir/test/Dialect/Transform/test-promote-tensors.mlir
new file mode 100644
index 0000000000000..bc9a05af64156
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-promote-tensors.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @promote_in0
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}, %{{.*}})
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM]]) {memory_space = 1 : i64}
+// CHECK: %[[MAT:.+]] = bufferization.materialize_in_destination %[[ARG0]] in %[[ALLOC]]
+// CHECK: linalg.matmul ins(%[[MAT]], %{{.*}}
+func.func @promote_in0(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
+ outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ %mm = transform.structured.match ops{["linalg.matmul"]} in %root
+ : (!transform.any_op) -> !transform.any_op
+ %op0 = transform.get_operand %mm[0]
+ : (!transform.any_op) -> !transform.any_value
+ transform.structured.promote_tensor to 1 %op0 : !transform.any_value
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @promote_out
+// CHECK-SAME: (%{{.*}}: tensor<?x42xf32>, %{{.*}}: tensor<?x42xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
+func.func @promote_out(%arg0: tensor<?x42xf32>, %arg1: tensor<?x42xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[C0:.+]] = arith.constant 0
+ // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+ // CHECK: %[[C1:.+]] = arith.constant 1
+ // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
+ // CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM0]], %[[DIM1]]) {memory_space = 1 : i64}
+ // CHECK-NOT: materialize_in_destination
+ // CHECK: linalg.add {{.*}} outs(%[[ALLOC]]
+ %0 = linalg.add ins(%arg0, %arg1 : tensor<?x42xf32>, tensor<?x42xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ %la = transform.structured.match ops{["linalg.add"]} in %root
+ : (!transform.any_op) -> !transform.any_op
+ %init = transform.get_operand %la[2]
+ : (!transform.any_op) -> !transform.any_value
+ transform.structured.promote_tensor to 1 %init : !transform.any_value
+
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @promote_in0_out_bufferize
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}: tensor<42x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
+func.func @promote_in0_out_bufferize(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[IN1:.+]] = bufferization.to_buffer %arg1 : tensor<42x?xf32> to memref<42x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[IN0:.+]] = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %{{.+}} = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C0]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[C1:.+]] = arith.constant 1 : index
+ // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C1]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[ALLOC_OUT:.+]] = memref.alloc(%{{.+}}, %{{.+}}) {alignment = 64 : i64} : memref<?x?xf32, 1>
+ // CHECK: %{{.+}} = arith.constant 0 : index
+ // CHECK: %{{.+}} = memref.dim %{{.+}}, %{{.+}} : memref<?x42xf32, strided<[?, ?], offset: ?>>
+ // CHECK: %[[ALLOC_IN:.+]] = memref.alloc(%{{.+}}) {alignment = 64 : i64} : memref<?x42xf32, 1>
+ // CHECK: memref.copy %[[IN0]], %[[ALLOC_IN]] : memref<?x42xf32, strided<[?, ?], offset: ?>> to memref<?x42xf32, 1>
+ // CHECK: linalg.add ins(%[[ALLOC_IN]], %[[IN1]] : memref<?x42xf32, 1>, memref<42x?xf32, strided<[?, ?], offset: ?>>) outs(%[[ALLOC_OUT]] : memref<?x?xf32, 1>)
+ %0 = linalg.add ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
+ outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ %la = transform.structured.match ops{["linalg.add"]} in %root
+ : (!transform.any_op) -> !transform.any_op
+ %op0 = transform.get_operand %la[0]
+ : (!transform.any_op) -> !transform.any_value
+ transform.structured.promote_tensor to 1 %op0 : !transform.any_value
+
+ %init = transform.get_operand %la[2]
+ : (!transform.any_op) -> !transform.any_value
+ transform.structured.promote_tensor to 1 %init : !transform.any_value
+
+ %func = transform.structured.match ops{["func.func"]} in %root
+ : (!transform.any_op) -> !transform.any_op
+
+ %bufferized = transform.bufferization.one_shot_bufferize %func
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.yield
+ }
+}
+
+
+
Thank you!
@Jimmy2027 could you rebase and land plz?
Force-pushed from 5962919 to 7fe4cbc.
Transform op to request a tensor value to live in a specific memory space after bufferization.

Remove hard-coded types from BufferizeToAllocationOp constructor.

Co-authored-by: Nicolas Vasilache <Nico.Vasilache@amd.com>
Co-authored-by: Alex Zinenko <ftynse@gmail.com>
Force-pushed from 19f2cd0 to c7c1fae.
Merged on @Jimmy2027's behalf as he does not yet have commit access.