Skip to content

Conversation

Jimmy2027
Copy link
Contributor

Transform op to request a tensor value to live in a specific memory space after bufferization

@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Hendrik_Klug (Jimmy2027)

Changes

Transform op to request a tensor value to live in a specific memory space after bufferization


Full diff: https://github.com/llvm/llvm-project/pull/158318.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+45-4)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+90-26)
  • (added) mlir/test/Dialect/Transform/test-promote-tensors.mlir (+104)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a19cce4b919a8..b4c62baad11bf 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/RegionKindInterface.td"
 
@@ -236,11 +237,51 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
                       Transform_AnyOpType:$new_ops);
   let assemblyFormat = "$target attr-dict `:` type($target)";
   let hasVerifier = 1;
+}
 
-  let builders = [
-    OpBuilder<(ins "Value":$target, "Attribute":$memorySpace)>,
-    OpBuilder<(ins "Value":$target, "int64_t":$memorySpace)>
-  ];
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+def PromoteTensorOp : Op<Transform_Dialect, "structured.promote_tensor",
+                         [DeclareOpInterfaceMethods<TransformOpInterface>,
+                          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+                          SameOperandsAndResultType]> {
+  let summary = "Request a tensor value to live in a specific memory space "
+                "after bufferization";
+  let description = [{
+    Requests that a tensor value lives in a specific memory space for its
+    lifetime. This is achieved by allocating a new tensor in the desired
+    memory space with `bufferization.alloc_tensor` and optionally materializing
+    the source value into that allocation with
+    `bufferization.materialize_in_destination`. All uses of the original value
+    are then redirected to the promoted value.
+
+    The generated code for promoting tensor value %0 resembles the following:
+
+      %1 = bufferization.alloc_tensor(<dynamic dims of %0>)
+           { memory_space = memory_space }
+      // Note: the materialization is omitted if %0 is never read and is only
+      // written into (i.e., it behaves as a result tensor).
+      %2 = bufferization.materialize_in_destination %0 in %1
+      // ...
+      <all users of %0 now use %2 instead>
+
+    Deallocation is not handled by this transform.
+
+    Return modes:
+    - Produces a silenceable failure if the given handle does not point to
+      tensor-typed values.
+    - Succeeds otherwise and returns a handle to the promoted value(s), i.e.,
+      the result of materialization if present and the allocation otherwise.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$tensor,
+      OptionalAttr<AnyAttr>:$memory_space);
+  let results = (outs TransformValueHandleTypeInterface:$promoted);
+
+  let assemblyFormat =
+      "(`to` $memory_space^)? $tensor attr-dict `:` type($tensor)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..017886ef4fcd3 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -41,6 +41,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/LogicalResult.h"
@@ -272,32 +273,6 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
 
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
-                                               OperationState &result,
-                                               Value target,
-                                               Attribute memorySpace) {
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(b.getType<transform::AnyValueType>());
-  resultTypes.push_back(b.getType<transform::AnyOpType>());
-  return build(b, result,
-               /*resultTypes=*/resultTypes,
-               /*target=*/target,
-               /*memorySpace=*/memorySpace);
-}
-
-void transform::BufferizeToAllocationOp::build(OpBuilder &b,
-                                               OperationState &result,
-                                               Value target,
-                                               int64_t memorySpace) {
-  SmallVector<Type> resultTypes;
-  resultTypes.push_back(b.getType<transform::AnyValueType>());
-  resultTypes.push_back(b.getType<transform::AnyOpType>());
-  return build(b, result,
-               /*resultTypes=*/resultTypes,
-               /*target=*/target,
-               /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
-}
-
 namespace {
 class NewOpsListener : public RewriterBase::ForwardingListener {
 public:
@@ -407,6 +382,95 @@ LogicalResult transform::BufferizeToAllocationOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PromoteTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the operand may be read from by its owner. This is currently
+/// very conservative and only looks inside linalg operations to prevent
+/// unintentional data loss.
+static bool mayBeRead(OpOperand &operand) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
+
+  // Be conservative about ops we cannot analyze deeper.
+  if (!linalgOp)
+    return true;
+
+  // Look inside linalg ops.
+  Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
+  return !blockArgument.use_empty();
+}
+
+/// Return true if the value may be read through any of its uses.
+static bool mayBeRead(Value value) {
+  // If the value has a reference semantics, it
+  // may be read through any alias...
+  if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
+    return true;
+  return llvm::any_of(value.getUses(),
+                      static_cast<bool (&)(OpOperand &)>(mayBeRead));
+}
+
+DiagnosedSilenceableFailure
+transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  SmallVector<Value> promoted;
+  for (Value tensor : state.getPayloadValues(getTensor())) {
+    auto type = dyn_cast<RankedTensorType>(tensor.getType());
+    if (!type) {
+      return emitSilenceableError() << "non-tensor type: " << tensor;
+    }
+
+    Operation *definingOp = tensor.getDefiningOp();
+    if (definingOp)
+      rewriter.setInsertionPointAfter(definingOp);
+    else
+      rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
+
+    // Check this before we emit operations using this value.
+    bool needsMaterialization = mayBeRead(tensor);
+
+    SmallVector<Value> dynamicDims;
+    llvm::SmallPtrSet<Operation *, 4> preservedOps;
+    for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
+      if (!ShapedType::isDynamic(dim))
+        continue;
+      Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
+      auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+      preservedOps.insert(dimOp);
+      dynamicDims.push_back(dimOp);
+    }
+    auto allocation = rewriter.create<bufferization::AllocTensorOp>(
+        tensor.getLoc(), type, dynamicDims);
+    // Set memory space if provided.
+    if (getMemorySpaceAttr())
+      allocation.setMemorySpaceAttr(getMemorySpaceAttr());
+    Value allocated = allocation;
+
+    // Only insert a materialization (typically bufferizes to a copy) when the
+    // value may be read from.
+    if (needsMaterialization) {
+      auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
+          tensor.getLoc(), tensor, allocated);
+      preservedOps.insert(copy);
+      promoted.push_back(copy.getResult());
+    } else {
+      promoted.push_back(allocated);
+    }
+    rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
+  }
+  results.setValues(cast<OpResult>(getPromoted()), promoted);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::PromoteTensorOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTensorMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-promote-tensors.mlir b/mlir/test/Dialect/Transform/test-promote-tensors.mlir
new file mode 100644
index 0000000000000..bc9a05af64156
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-promote-tensors.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @promote_in0
+// CHECK-SAME:  (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}, %{{.*}})
+// CHECK:  %[[C0:.+]] = arith.constant 0
+// CHECK:  %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK:  %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM]]) {memory_space = 1 : i64}
+// CHECK:  %[[MAT:.+]] = bufferization.materialize_in_destination %[[ARG0]] in %[[ALLOC]]
+// CHECK:  linalg.matmul ins(%[[MAT]], %{{.*}}
+func.func @promote_in0(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
+                       outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%root: !transform.any_op) {
+        %mm = transform.structured.match ops{["linalg.matmul"]} in %root
+            : (!transform.any_op) -> !transform.any_op
+        %op0 = transform.get_operand %mm[0]
+            : (!transform.any_op) -> !transform.any_value
+        transform.structured.promote_tensor to 1 %op0 : !transform.any_value
+        transform.yield
+    }
+}
+
+// -----
+
+// CHECK-LABEL: @promote_out
+// CHECK-SAME: (%{{.*}}: tensor<?x42xf32>, %{{.*}}: tensor<?x42xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
+func.func @promote_out(%arg0: tensor<?x42xf32>, %arg1: tensor<?x42xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    // CHECK:  %[[C0:.+]] = arith.constant 0
+    // CHECK:  %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[C0]]
+    // CHECK:  %[[C1:.+]] = arith.constant 1
+    // CHECK:  %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
+    // CHECK:  %[[ALLOC:.+]] = bufferization.alloc_tensor(%[[DIM0]], %[[DIM1]]) {memory_space = 1 : i64}
+    // CHECK-NOT: materialize_in_destination
+    // CHECK:  linalg.add {{.*}} outs(%[[ALLOC]]
+    %0 = linalg.add ins(%arg0, %arg1 : tensor<?x42xf32>, tensor<?x42xf32>)
+                    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%root: !transform.any_op) {
+        %la = transform.structured.match ops{["linalg.add"]} in %root
+            : (!transform.any_op) -> !transform.any_op
+        %init = transform.get_operand %la[2]
+                : (!transform.any_op) -> !transform.any_value
+        transform.structured.promote_tensor to 1 %init : !transform.any_value
+
+        transform.yield
+    }
+}
+
+// -----
+
+// CHECK-LABEL: @promote_in0_out_bufferize
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x42xf32>, %{{.*}}: tensor<42x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>)
+func.func @promote_in0_out_bufferize(%arg0: tensor<?x42xf32>, %arg1: tensor<42x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    // CHECK:  %[[IN1:.+]] = bufferization.to_buffer %arg1 : tensor<42x?xf32> to memref<42x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %[[IN0:.+]] = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %{{.+}} = bufferization.to_buffer %arg0 : tensor<?x42xf32> to memref<?x42xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %{{.+}} = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %[[C0:.+]] = arith.constant 0 : index
+    // CHECK:  %{{.+}} = memref.dim %{{.+}}, %[[C0]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %[[C1:.+]] = arith.constant 1 : index
+    // CHECK:  %{{.+}} = memref.dim %{{.+}}, %[[C1]] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %[[ALLOC_OUT:.+]] = memref.alloc(%{{.+}}, %{{.+}}) {alignment = 64 : i64} : memref<?x?xf32, 1>
+    // CHECK:  %{{.+}} = arith.constant 0 : index
+    // CHECK:  %{{.+}} = memref.dim %{{.+}}, %{{.+}} : memref<?x42xf32, strided<[?, ?], offset: ?>>
+    // CHECK:  %[[ALLOC_IN:.+]] = memref.alloc(%{{.+}}) {alignment = 64 : i64} : memref<?x42xf32, 1>
+    // CHECK:  memref.copy %[[IN0]], %[[ALLOC_IN]] : memref<?x42xf32, strided<[?, ?], offset: ?>> to memref<?x42xf32, 1>
+    // CHECK: linalg.add ins(%[[ALLOC_IN]], %[[IN1]] : memref<?x42xf32, 1>, memref<42x?xf32, strided<[?, ?], offset: ?>>) outs(%[[ALLOC_OUT]] : memref<?x?xf32, 1>)
+    %0 = linalg.add ins(%arg0, %arg1: tensor<?x42xf32>, tensor<42x?xf32>)
+                    outs(%arg2: tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%root: !transform.any_op) {
+        %la = transform.structured.match ops{["linalg.add"]} in %root
+            : (!transform.any_op) -> !transform.any_op
+        %op0 = transform.get_operand %la[0]
+            : (!transform.any_op) -> !transform.any_value
+        transform.structured.promote_tensor to 1 %op0 : !transform.any_value
+
+        %init = transform.get_operand %la[2]
+                : (!transform.any_op) -> !transform.any_value
+        transform.structured.promote_tensor to 1 %init : !transform.any_value
+
+        %func = transform.structured.match ops{["func.func"]} in %root
+                : (!transform.any_op) -> !transform.any_op
+
+        %bufferized = transform.bufferization.one_shot_bufferize %func
+            : (!transform.any_op) -> !transform.any_op
+
+        transform.yield
+    }
+}
+
+
+

Copy link
Contributor

@nicolasvasilache nicolasvasilache left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@llvmbot llvmbot added the mlir:python MLIR Python bindings label Sep 13, 2025
@nicolasvasilache
Copy link
Contributor

@Jimmy2027 could you rebase and land plz?

@Jimmy2027 Jimmy2027 force-pushed the hendrik/transform/promote_tensor branch from 5962919 to 7fe4cbc Compare September 30, 2025 19:35
Transform op to request a tensor value to live in a
specific memory space after bufferization

Remove hard-coded types from BufferizeToAllocationOp constructor

---------

Co-authored-by: Nicolas Vasilache <Nico.Vasilache@amd.com>
Co-authored-by: Alex Zinenko <ftynse@gmail.com>
@Jimmy2027 Jimmy2027 force-pushed the hendrik/transform/promote_tensor branch from 19f2cd0 to c7c1fae Compare September 30, 2025 19:47
@nicolasvasilache nicolasvasilache merged commit 3c0f7b1 into llvm:main Oct 1, 2025
9 checks passed
@nicolasvasilache
Copy link
Contributor

Merged on @Jimmy2027 's behalf as he does not yet have commit access.

mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
Transform op to request a tensor value to live in a specific memory
space after bufferization

Co-authored-by: Nicolas Vasilache <Nico.Vasilache@amd.com>
Co-authored-by: Alex Zinenko <ftynse@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants