From c7c1fae34ae14f64c214e169304c6e81a2e7c669 Mon Sep 17 00:00:00 2001 From: Hendrik Klug Date: Thu, 11 Sep 2025 14:45:26 +0000 Subject: [PATCH] [mlir][transform] Add PromoteTensorOp Transform op to request a tensor value to live in a specific memory space after bufferization Remove hard-coded types from BufferizeToAllocationOp constructor --------- Co-authored-by: Nicolas Vasilache Co-authored-by: Alex Zinenko --- .../Linalg/TransformOps/LinalgTransformOps.td | 49 +++++++- .../TransformOps/LinalgTransformOps.cpp | 116 ++++++++++++++---- .../mlir/dialects/transform/structured.py | 6 - .../Transform/test-promote-tensors.mlir | 104 ++++++++++++++++ 4 files changed, 239 insertions(+), 36 deletions(-) create mode 100644 mlir/test/Dialect/Transform/test-promote-tensors.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 8f3232f01544f..0d6ebc087e2f3 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/OpBase.td" include "mlir/IR/RegionKindInterface.td" @@ -236,11 +237,51 @@ def BufferizeToAllocationOp : Op, - OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)> - ]; +//===----------------------------------------------------------------------===// +// PromoteTensorOp +//===----------------------------------------------------------------------===// + +def PromoteTensorOp : Op, + DeclareOpInterfaceMethods, + SameOperandsAndResultType]> { + let summary = "Request a tensor value to live in a specific memory space " + "after bufferization"; + let description = [{ + Requests that a tensor value lives in a specific memory space for its + lifetime. This is achieved by allocating a new tensor in the desired + memory space with `bufferization.alloc_tensor` and optionally materializing + the source value into that allocation with + `bufferization.materialize_in_destination`. All uses of the original value + are then redirected to the promoted value. + + The generated code for promoting tensor value %0 resembles the following: + + %1 = bufferization.alloc_tensor() + { memory_space = memory_space } + // Note: the materialization is omitted if %0 is never read and is only + // written into (i.e., it behaves as a result tensor). + %2 = bufferization.materialize_in_destination %0 in %1 + // ... + + + Deallocation is not handled by this transform. + + Return modes: + - Produces a silenceable failure if the given handle does not point to + tensor-typed values. + - Succeeds otherwise and returns a handle to the promoted value(s), i.e., + the result of materialization if present and the allocation otherwise. + }]; + + let arguments = (ins TransformValueHandleTypeInterface:$tensor, + OptionalAttr:$memory_space); + let results = (outs TransformValueHandleTypeInterface:$promoted); + + let assemblyFormat = + "(`to` $memory_space^)? $tensor attr-dict `:` type($tensor)"; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 3f0b0bacd9756..dd9b4c2490ef4 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -42,6 +42,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/LogicalResult.h" @@ -273,32 +274,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns( // BufferizeToAllocationOp //===----------------------------------------------------------------------===// -void transform::BufferizeToAllocationOp::build(OpBuilder &b, - OperationState &result, - Value target, - Attribute memorySpace) { - SmallVector resultTypes; - resultTypes.push_back(b.getType()); - resultTypes.push_back(b.getType()); - return build(b, result, - /*resultTypes=*/resultTypes, - /*target=*/target, - /*memory_space=*/memorySpace); -} - -void transform::BufferizeToAllocationOp::build(OpBuilder &b, - OperationState &result, - Value target, - int64_t memorySpace) { - SmallVector resultTypes; - resultTypes.push_back(b.getType()); - resultTypes.push_back(b.getType()); - return build(b, result, - /*resultTypes=*/resultTypes, - /*target=*/target, - /*memory_space=*/b.getI64IntegerAttr(memorySpace)); -} - namespace { class NewOpsListener : public RewriterBase::ForwardingListener { public: @@ -408,6 +383,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// PromoteTensorOp +//===----------------------------------------------------------------------===// + +/// Return true if the operand may be read from by its owner. This is currently +/// very conservative and only looks inside linalg operations to prevent +/// unintentional data loss. +static bool mayBeRead(OpOperand &operand) { + auto linalgOp = dyn_cast(operand.getOwner()); + + // Be conservative about ops we cannot analyze deeper. + if (!linalgOp) + return true; + + // Look inside linalg ops. + Value blockArgument = linalgOp.getMatchingBlockArgument(&operand); + return !blockArgument.use_empty(); +} + +/// Return true if the value may be read through any of its uses. +static bool mayBeRead(Value value) { + // If the value has a reference semantics, it + // may be read through any alias... + if (!isa(value.getType())) + return true; + return llvm::any_of(value.getUses(), + static_cast(mayBeRead)); +} + +DiagnosedSilenceableFailure +transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + SmallVector promoted; + for (Value tensor : state.getPayloadValues(getTensor())) { + auto type = dyn_cast(tensor.getType()); + if (!type) { + return emitSilenceableError() << "non-tensor type: " << tensor; + } + + Operation *definingOp = tensor.getDefiningOp(); + if (definingOp) + rewriter.setInsertionPointAfter(definingOp); + else + rewriter.setInsertionPointToStart(cast(tensor).getOwner()); + + // Check this before we emit operations using this value. + bool needsMaterialization = mayBeRead(tensor); + + SmallVector dynamicDims; + llvm::SmallPtrSet preservedOps; + for (auto [pos, dim] : llvm::enumerate(type.getShape())) { + if (!ShapedType::isDynamic(dim)) + continue; + Value cst = rewriter.create(tensor.getLoc(), pos); + auto dimOp = rewriter.create(tensor.getLoc(), tensor, cst); + preservedOps.insert(dimOp); + dynamicDims.push_back(dimOp); + } + auto allocation = rewriter.create( + tensor.getLoc(), type, dynamicDims); + // Set memory space if provided. + if (getMemorySpaceAttr()) + allocation.setMemorySpaceAttr(getMemorySpaceAttr()); + Value allocated = allocation; + + // Only insert a materialization (typically bufferizes to a copy) when the + // value may be read from. + if (needsMaterialization) { + auto copy = rewriter.create( + tensor.getLoc(), tensor, allocated); + preservedOps.insert(copy); + promoted.push_back(copy.getResult()); + } else { + promoted.push_back(allocated); + } + rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps); + } + results.setValues(cast(getPromoted()), promoted); + return DiagnosedSilenceableFailure::success(); +} + +void transform::PromoteTensorOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getTensorMutable(), effects); + transform::producesHandle(getOperation()->getOpResults(), effects); + transform::modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // DecomposeOp //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index bf40cc532065d..e3bacb5777d9f 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -44,18 +44,12 @@ def __init__( loc=None, ip=None, ): - # No other types are allowed, so hard-code those here. - allocated_buffer_type = transform.AnyValueType.get() - new_ops_type = transform.AnyOpType.get() - if isinstance(memory_space, int): memory_space = str(memory_space) if isinstance(memory_space, str): memory_space = Attribute.parse(memory_space) super().__init__( - allocated_buffer_type, - new_ops_type, target, memory_space=memory_space, memcpy_op=memcpy_op, diff --git a/mlir/test/Dialect/Transform/test-promote-tensors.mlir b/mlir/test/Dialect/Transform/test-promote-tensors.mlir new file mode 100644 index 0000000000000..bc9a05af64156 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-promote-tensors.mlir @@ -0,0 +1,104 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: @promote_in0 +// CHECK-SAME: (%[[ARG0:.+]]: tensor, %{{.*}}, %{{.*}}) +// CHECK: %[[C0:.+]] = arith.constant 0 +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM]]) {memory_space = 1 : i64} +// CHECK: %[[MAT:.+]] = bufferization.materialize_in_destination %[[ARG0]] in %[[ALLOC]] +// CHECK: linalg.matmul ins(%[[MAT]], %{{.*}} +func.func @promote_in0(%arg0: tensor, %arg1: tensor<42x?xf32>, %arg2: tensor) -> tensor { + %0 = linalg.matmul ins(%arg0, %arg1: tensor, tensor<42x?xf32>) + outs(%arg2: tensor) -> tensor + return %0 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root: !transform.any_op) { + %mm = transform.structured.match ops{["linalg.matmul"]} in %root + : (!transform.any_op) -> !transform.any_op + %op0 = transform.get_operand %mm[0] + : (!transform.any_op) -> !transform.any_value + transform.structured.promote_tensor to 1 %op0 : !transform.any_value + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @promote_out +// CHECK-SAME: (%{{.*}}: tensor, %{{.*}}: tensor, %[[ARG2:.+]]: tensor) +func.func @promote_out(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: %[[C0:.+]] = arith.constant 0 + // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]] + // CHECK: %[[C1:.+]] = arith.constant 1 + // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[C1]] + // CHECK: %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM0]], %[[DIM1]]) {memory_space = 1 : i64} + // CHECK-NOT: materialize_in_destination + // CHECK: linalg.add {{.*}} outs(%[[ALLOC]] + %0 = linalg.add ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root: !transform.any_op) { + %la = transform.structured.match ops{["linalg.add"]} in %root + : (!transform.any_op) -> !transform.any_op + %init = transform.get_operand %la[2] + : (!transform.any_op) -> !transform.any_value + transform.structured.promote_tensor to 1 %init : !transform.any_value + + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @promote_in0_out_bufferize +// CHECK-SAME: (%[[ARG0:.+]]: tensor, %{{.*}}: tensor<42x?xf32>, %[[ARG2:.+]]: tensor) +func.func @promote_in0_out_bufferize(%arg0: tensor, %arg1: tensor<42x?xf32>, %arg2: tensor) -> tensor { + // CHECK: %[[IN1:.+]] = bufferization.to_buffer %arg1 : tensor<42x?xf32> to memref<42x?xf32, strided<[?, ?], offset: ?>> + // CHECK: %[[IN0:.+]] = bufferization.to_buffer %arg0 : tensor to memref> + // CHECK: %{{.+}} = bufferization.to_buffer %arg0 : tensor to memref> + // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor to memref> + // CHECK: %{{.+}} = bufferization.to_buffer %arg2 : tensor to memref> + // CHECK: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C0]] : memref> + // CHECK: %[[C1:.+]] = arith.constant 1 : index + // CHECK: %{{.+}} = memref.dim %{{.+}}, %[[C1]] : memref> + // CHECK: %[[ALLOC_OUT:.+]] = memref.alloc(%{{.+}}, %{{.+}}) {alignment = 64 : i64} : memref + // CHECK: %{{.+}} = arith.constant 0 : index + // CHECK: %{{.+}} = memref.dim %{{.+}}, %{{.+}} : memref> + // CHECK: %[[ALLOC_IN:.+]] = memref.alloc(%{{.+}}) {alignment = 64 : i64} : memref + // CHECK: memref.copy %[[IN0]], %[[ALLOC_IN]] : memref> to memref + // CHECK: linalg.add ins(%[[ALLOC_IN]], %[[IN1]] : memref, memref<42x?xf32, strided<[?, ?], offset: ?>>) outs(%[[ALLOC_OUT]] : memref) + %0 = linalg.add ins(%arg0, %arg1: tensor, tensor<42x?xf32>) + outs(%arg2: tensor) -> tensor + return %0 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root: !transform.any_op) { + %la = transform.structured.match ops{["linalg.add"]} in %root + : (!transform.any_op) -> !transform.any_op + %op0 = transform.get_operand %la[0] + : (!transform.any_op) -> !transform.any_value + transform.structured.promote_tensor to 1 %op0 : !transform.any_value + + %init = transform.get_operand %la[2] + : (!transform.any_op) -> !transform.any_value + transform.structured.promote_tensor to 1 %init : !transform.any_value + + %func = transform.structured.match ops{["func.func"]} in %root + : (!transform.any_op) -> !transform.any_op + + %bufferized = transform.bufferization.one_shot_bufferize %func + : (!transform.any_op) -> !transform.any_op + + transform.yield + } +} + + +