
Conversation

andrey-golubev (Contributor) commented Sep 19, 2025

Support custom types (3/N): allow custom tensor and buffer types in function signatures and at call-sites. This is one of the major building blocks to move in the direction of module-level one-shot-bufferization support.

To achieve this, the `BufferizationOptions::FunctionArgTypeConverterFn` callback is converted to work with tensor-like and buffer-like types instead of the builtin counterparts. The default behavior for builtins remains unchanged, while custom types by default go through `TensorLikeType::getBufferType()`, which is the general conversion interface.
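In rough strokes, the new default callback can be pictured as follows (a minimal sketch, not the verbatim patch: the function name `defaultFunctionArgTypeConverter` and the error handling are simplified here; `getMemRefTypeWithFullyDynamicLayout` is the existing bufferization helper):

  #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
  #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include <cassert>
  using namespace mlir;
  using namespace mlir::bufferization;

  BufferLikeType defaultFunctionArgTypeConverter(
      TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp,
      const BufferizationOptions &options) {
    // Builtin tensors keep the pre-existing behavior (fully dynamic layout).
    if (auto tensorType = dyn_cast<TensorType>(type))
      return cast<BufferLikeType>(
          getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
    // Custom tensor-like types go through the generic conversion interface.
    FailureOr<BufferLikeType> bufferType =
        type.getBufferType(options, [&]() { return funcOp->emitError(); });
    assert(succeeded(bufferType) && "expected a valid buffer type");
    return *bufferType;
  }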

llvmbot added the mlir and mlir:bufferization (Bufferization infrastructure) labels on Sep 19, 2025
llvmbot (Member) commented Sep 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Andrei Golubev (andrey-golubev)

Changes

Support custom types (3/N): allow custom tensor and buffer types in function signatures and at call-sites. This is one of the major building blocks to move in the direction of module-level one-shot-bufferization support.

In order to enable this, TensorLikeType is extended with a new interface method that is invoked solely during function-boundary bufferization.


Full diff: https://github.com/llvm/llvm-project/pull/159766.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td (+12)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp (+13)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+50-40)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+56)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+5)
  • (modified) mlir/test/lib/Dialect/Test/TestTypes.cpp (+8)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index a2bfcb7ed2b75..9b052b8bb7e14 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Func/IR/FuncOps.h" // to access mlir::func::FuncOp
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index fb6fc4f5ad964..c4235cd067999 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -43,6 +43,18 @@ def Bufferization_TensorLikeTypeInterface
       /*args=*/(ins
         "::mlir::bufferization::BufferLikeType":$bufferType,
         "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+    >,
+    InterfaceMethod<[{
+        Returns a BufferLike type for this TensorLike type in the context of
+        this type being function argument or result.
+      }],
+      /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
+      /*methodName=*/"getBufferTypeAtFunctionBoundary",
+      /*args=*/(ins
+        "::mlir::func::FuncOp":$funcOp,
+        "const ::mlir::bufferization::BufferizationOptions &":$options,
+        "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
+      )
     >
   ];
 }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdfb669f3..9b907922a24c4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -87,6 +87,19 @@ struct BuiltinTensorExternalModel
 
     return mlir::success();
   }
+
+  llvm::FailureOr<BufferLikeType> getBufferTypeAtFunctionBoundary(
+      mlir::Type tensor, mlir::func::FuncOp funcOp,
+      const BufferizationOptions &options,
+      llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+    auto tensorType = cast<TensorType>(tensor);
+    auto memSpace = options.defaultMemorySpaceFn(tensorType);
+    if (!memSpace.has_value())
+      return emitError() << "could not infer memory space";
+
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *memSpace, funcOp, options));
+  }
 };
 
 template <typename MemRef>
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 68ef51992efee..701ab52a491a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -401,7 +401,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
   // Compute the new signature.
   SmallVector<Type> newTypes;
   for (BlockArgument &bbArg : block->getArguments()) {
-    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
+    auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
     if (!tensorType) {
       newTypes.push_back(bbArg.getType());
       continue;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index f69efd1b3fa8c..b7bac9f4623f1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -52,26 +52,35 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
-static BaseMemRefType
+static BufferLikeType
 getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
   auto tensorType =
-      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
-  assert(tensorType && "expected TensorType");
-
-  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
-      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
-
-  auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
-      index, BufferizationDialect::kBufferLayoutAttrName);
-  if (!layoutAttr)
-    return memrefType;
-
-  auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
-  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
-  return MemRefType::get(rankedMemrefType.getShape(),
-                         rankedMemrefType.getElementType(), layoutAttr,
-                         rankedMemrefType.getMemorySpace());
+      dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
+  assert(tensorType && "expected TensorLikeType");
+  auto maybeBufferType = tensorType.getBufferTypeAtFunctionBoundary(
+      funcOp, options, [&]() { return funcOp->emitError(); });
+  assert(mlir::succeeded(maybeBufferType) &&
+         "a valid buffer is always expected");
+
+  auto bufferType = *maybeBufferType;
+
+  // Note: For builtin tensors there is additional logic related to layout.
+  if (isa<TensorType>(tensorType)) {
+    auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
+        index, BufferizationDialect::kBufferLayoutAttrName);
+    if (!layoutAttr)
+      return bufferType;
+
+    auto rankedMemrefType = dyn_cast<MemRefType>(bufferType);
+    assert(rankedMemrefType &&
+           "buffer layout not supported on unranked tensors");
+    return cast<BufferLikeType>(MemRefType::get(
+        rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+        layoutAttr, rankedMemrefType.getMemorySpace()));
+  }
+
+  return bufferType;
 }
 
 /// Return the FuncOp called by `callOp`.
@@ -227,14 +236,13 @@ struct CallOpInterface
     FunctionType funcType = funcOp.getFunctionType();
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
-    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return cast<BufferLikeType>(bufferizedType);
+    if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
+      return bufferizedType;
 
     // Otherwise, call the type converter to compute the bufferized type.
-    auto tensorType = cast<TensorType>(resultType);
-    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
-        options));
+    auto tensorType = cast<TensorLikeType>(resultType);
+    return tensorType.getBufferTypeAtFunctionBoundary(
+        funcOp, options, [&]() { return funcOp->emitError(); });
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -248,7 +256,7 @@ struct CallOpInterface
     SmallVector<Type> resultTypes;
     for (Value result : callOp.getResults()) {
       Type returnType = result.getType();
-      if (!isa<TensorType>(returnType)) {
+      if (!isa<TensorLikeType>(returnType)) {
         // Non-tensor values are returned.
         resultTypes.push_back(returnType);
         continue;
@@ -272,7 +280,7 @@ struct CallOpInterface
 
     for (OpOperand &opOperand : callOp->getOpOperands()) {
       // Non-tensor operands are just copied.
-      if (!isa<TensorType>(opOperand.get().getType())) {
+      if (!isa<TensorLikeType>(opOperand.get().getType())) {
         newOperands.push_back(opOperand.get());
         continue;
       }
@@ -285,8 +293,8 @@ struct CallOpInterface
       Value buffer = *maybeBuffer;
 
       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
-      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
-      if (!isa<BaseMemRefType>(memRefType)) {
+      auto bufferType = funcType.getInput(opOperand.getOperandNumber());
+      if (!isa<BufferLikeType>(bufferType)) {
         // The called function was not bufferized yet. This can happen when
         // there cycles in the function call graph. Compute the bufferized
         // result type.
@@ -296,7 +304,7 @@ struct CallOpInterface
                 state);
         if (failed(maybeBufferType))
           return failure();
-        memRefType = *maybeBufferType;
+        bufferType = *maybeBufferType;
       }
 
       // Since we don't yet have a clear layout story, to_buffer may
@@ -305,8 +313,8 @@ struct CallOpInterface
       // that will either canonicalize away or fail compilation until we can do
       // something better. Insert a reallocation + copy if it cannot be
       // statically guaranteed that a direct cast would be valid.
-      if (buffer.getType() != memRefType) {
-        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+      if (buffer.getType() != bufferType) {
+        auto memrefDstType = dyn_cast<MemRefType>(bufferType);
         assert(memrefDstType &&
                "buffer layout not supported on unranked tensors");
         FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
@@ -370,7 +378,7 @@ struct FuncOpInterface
   static bool supportsUnstructuredControlFlow() { return true; }
 
   bool hasTensorSemantics(Operation *op) const {
-    auto isaTensor = llvm::IsaPred<TensorType>;
+    auto isaTensor = llvm::IsaPred<TensorLikeType>;
 
     // A function has tensor semantics if it has tensor arguments/results.
     auto funcOp = cast<FuncOp>(op);
@@ -406,8 +414,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return cast<BufferLikeType>(
-          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
+      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
+                                          options);
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
@@ -430,7 +438,7 @@ struct FuncOpInterface
     SmallVector<Type> argTypes;
     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
       Type argType = it.value();
-      if (isa<TensorType>(argType)) {
+      if (isa<TensorLikeType>(argType)) {
         argTypes.push_back(
             getBufferizedFunctionArgType(funcOp, it.index(), options));
         continue;
@@ -441,11 +449,13 @@ struct FuncOpInterface
     // Compute the result types.
     SmallVector<Type> retTypes;
     for (Type resultType : funcType.getResults()) {
-      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
-        BaseMemRefType resultType = options.functionArgTypeConverterFn(
-            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
-            options);
-        retTypes.push_back(resultType);
+      if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
+        FailureOr<BufferLikeType> resultType =
+            tensorType.getBufferTypeAtFunctionBoundary(
+                funcOp, options, [&]() { return funcOp->emitError(); });
+        assert(mlir::succeeded(resultType) &&
+               "a valid buffer is always expected");
+        retTypes.push_back(*resultType);
         continue;
       }
       retTypes.push_back(resultType);
@@ -473,7 +483,7 @@ struct FuncOpInterface
       SmallVector<Value> returnValues;
       for (auto [returnVal, bufferizedType] :
            llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+        auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
         rewriter.setInsertionPoint(returnOp);
 
         // If not a tensor type just forward it.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 2efb5893c8511..eb0093106dc11 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -810,3 +810,59 @@ module @inner_module {
     return %t : tensor<5xf32>
   }
 }
+
+// -----
+
+// CHECK:   func.func @custom_types(
+// CHECK-SAME:    %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME:  ) -> (!test.test_memref<[4, 8], f64>,
+// CHECK-SAME:        !test.test_memref<[4, 8], f64>)
+func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>)
+    -> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) {
+  // CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) :
+  // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+  %out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: %[[alloc:.*]] = "test.create_memref_op"
+  // CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]])
+  // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+  %alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64>
+  %out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: return %[[out1]], %[[out2]]
+  return %out1, %out2 :
+    !test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>
+}
+
+// -----
+
+// CHECK:   func.func @custom_types_foo(
+// CHECK-SAME:    %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME:  ) -> !test.test_memref<[4, 4], f64>
+func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64> {
+  // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]])
+  %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64>
+  // CHECK: return %[[out]]
+  return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// CHECK:   func.func @custom_types_bar(
+// CHECK-SAME:    %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME:  ) -> !test.test_memref<[4, 8], f64>
+func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 8], f64> {
+  // CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]])
+  %call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64>
+
+  // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]])
+  %out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: return %[[out]]
+  return %out : !test.test_tensor<[4, 8], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index ea20597231d58..562fc66acea2a 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -444,6 +444,11 @@ def TestTensorType : Test_Type<"TestTensor",
     ::mlir::LogicalResult verifyCompatibleBufferType(
         ::mlir::bufferization::BufferLikeType bufferType,
         ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
+
+    ::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
+    getBufferTypeAtFunctionBoundary(mlir::func::FuncOp funcOp,
+        const ::mlir::bufferization::BufferizationOptions& options,
+        ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
   }];
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index bea043f56fe21..3c92fb94aebee 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -573,3 +573,11 @@ ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
                      getElementType() == testMemref.getElementType();
   return mlir::success(valid);
 }
+
+::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
+TestTensorType::getBufferTypeAtFunctionBoundary(
+    mlir::func::FuncOp,
+    const ::mlir::bufferization::BufferizationOptions &options,
+    ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
+  return getBufferType(options, emitError);
+}

@@ -43,6 +43,18 @@ def Bufferization_TensorLikeTypeInterface
      /*args=*/(ins
        "::mlir::bufferization::BufferLikeType":$bufferType,
        "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
    >,
    InterfaceMethod<[{
Member commented:
Why is this interface function necessary? We already have BufferizationOptions::functionArgTypeConverterFn. We can change the signature to:

  /// Tensor -> MemRef type converter.
  /// Parameters: tensor type, memory space, func op, bufferization options
  using FunctionArgTypeConverterFn =
      std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
                                   func::FuncOp, const BufferizationOptions &)>;

Do you need access to the function? In that case, we can pass an additional FunctionOpInterface to the lambda.

andrey-golubev (Contributor, Author) commented Sep 22, 2025:
I don't need to access the function (at least in my custom types). I'm actually intrigued why the FunctionArgTypeConverterFn needs it (e.g. the callback doesn't even take an index, so one has no info about which argument they're dealing with).

I thought of the interface here because this seems to be aligned with what we've discussed before:

  • we (mostly) have options to customize builtin types' behaviour
  • interface methods on top of tensor-like / buffer-like are where users provide customization for user types

Aside from the obvious challenges, doing this via options practically means users have to create their own "pass" (i.e. bufferizeOp()'s options have to be seeded with different functions that aren't really configurable outside of C++).

andrey-golubev (Contributor, Author) commented:
The main issue with this new interface method is that I can't seem to provide a default implementation. I'd make the new method behave as getBufferType by default, but it fails to compile since builtin types get this interface attached dynamically.

Member commented:
> I'm actually intrigued why the FunctionArgTypeConverterFn needs it

I don't know either. But it could be useful to get more context to do the type conversion.

> I thought of the interface here because this seems to be aligned with what we've discussed before:

The problem that I'm seeing is that there are now two ways to configure the type conversion, which is a bit confusing. If we make this interface-based, we should delete the options callback. But I think that's not possible because you wouldn't be able to configure identity/fully dynamic layout maps at the function boundary.

Function boundaries are always a bit special, so I think having an options callback is desirable also for custom tensor/buffer types.

> doing this via options practically means users have to create their own "pass" (i.e. bufferizeOp()'s options have to be seeded with different functions that aren't really configurable outside of C++)

That is the case in every production compiler that I'm aware of. This is not the only flag that cannot be customized as a pass flag. AllocationFn and MemCpyFn are other examples.
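As an illustration, seeding the options from C++ might look like this (a sketch only: the callback signature is the one proposed above, and `getMemRefTypeWithStaticIdentityLayout` is the existing bufferization helper):

  OneShotBufferizationOptions options;
  options.bufferizeFunctionBoundaries = true;
  // Not exposed as a pass flag; downstream pipelines set it in C++.
  options.functionArgTypeConverterFn =
      [](TensorLikeType tensor, Attribute memorySpace, func::FuncOp funcOp,
         const BufferizationOptions &opts) -> BufferLikeType {
    // E.g. force identity layouts for builtin tensors at function boundaries.
    if (auto tensorType = dyn_cast<TensorType>(tensor))
      return cast<BufferLikeType>(
          getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace));
    // Custom types: defer to the generic conversion interface.
    return *tensor.getBufferType(opts, [&]() { return funcOp->emitError(); });
  };
  // options.allocationFn / options.memCpyFn are customized the same way.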

andrey-golubev (Contributor, Author) commented:
> But I think that's not possible because you wouldn't be able to configure identity/fully dynamic layout maps at the function boundary.

I'm not sure I understand this part, can you explain more? Do you mean the `void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);` setter that makes the conversion LayoutMapOption-value-dependent? I think, yes, if this is a hard blocker, an options callback must stay - passing LayoutMapOption via interfaces is probably a pain (if possible at all without hacks or a significant redesign).

> Function boundaries are always a bit special, so I think having an options callback is desirable also for custom tensor/buffer types.

Fair point. Let's go with this then. I think I could do it a bit differently: for TensorType (current API), we keep the logic "as is"; otherwise - for TensorLikeType - we always call its getBufferType(). This would preserve a single conversion "interface" - the options callback - and also give us some sane default behavior for any type (i.e. I wouldn't need to create a --test-one-shot-bufferize pass yet).

> That is the case in every production compiler that I'm aware of. This is not the only flag that cannot be customized as a pass flag. AllocationFn and MemCpyFn are other examples.

Right, that's the way it is in our case too (fwiw, this is probably inevitable). I was a bit lazy about adding a test pass in MLIR for custom types though (I'll probably add one later anyway, to test how builtins with tensor encoding -> memref layout behave; this is one of the important things we use everywhere in our downstream).

andrey-golubev (Contributor, Author) commented:
Removed this interface method.

Member commented:
> I'm not sure I understand this part, can you explain more?

An interface method implementation cannot be overridden by the user. In the interface implementation you have to choose whether to convert a tensor type to a memref with an identity layout or a fully dynamic layout map. We want users to be able to customize this behavior, so there must be some kind of hook in the BufferizationOptions.
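Concretely, the two choices look like this for a builtin tensor<4x8xf32> (a sketch using the existing bufferization helpers; this is the selection that setFunctionBoundaryTypeConversion(LayoutMapOption) toggles):

  MLIRContext ctx;
  auto tensorType = RankedTensorType::get({4, 8}, Float32Type::get(&ctx));
  // LayoutMapOption::IdentityLayoutMap -> memref<4x8xf32>
  auto identityMemref = getMemRefTypeWithStaticIdentityLayout(tensorType);
  // LayoutMapOption::FullyDynamicLayoutMap
  //   -> memref<4x8xf32, strided<[?, ?], offset: ?>>
  auto dynamicMemref = getMemRefTypeWithFullyDynamicLayout(tensorType);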

    return *options.defaultMemorySpaceFn(tensorType);
  }
  return nullptr;
}
andrey-golubev (Contributor, Author) commented:
Note: I guess we could also make defaultMemorySpaceFn work with tensor-like types, but I'd prefer to do that separately to simplify the updates for users.

OTOH, TensorLikeType::getBufferType() already has access to the tensor type and can infer the memory space from it. I'm not sure why we need it explicitly at all.
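(For reference, a sketch of the current builtin-tensor-based hook; the integer memory space value is just an example:)

  // defaultMemorySpaceFn today takes a builtin TensorType; e.g. place every
  // builtin tensor into memory space 1:
  options.defaultMemorySpaceFn = [](TensorType t) -> std::optional<Attribute> {
    return IntegerAttr::get(IntegerType::get(t.getContext(), 64), 1);
  };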

/// Type converter from tensors to buffers. This type converter is used to
/// determine bufferized function argument and result types.
///
/// By default, if tensor is a (builtin) tensor type, a type converter that
Member commented:
nit: Let's not use the term "type converter" here because that has a different meaning in a dialect conversion. I'd rephrase as "the type is converted to...". Or just write "type conversion" instead of "type converter".

andrey-golubev (Contributor, Author) commented:
Ack. I think this is the original wording (just slightly refurbished).

andrey-golubev (Contributor, Author) commented:
Changed to "type conversion"; also for UnknownTypeConverterFn.

Support custom types (3/N): allow custom tensor and buffer types in
function signatures and at call-sites. This is one of the major building
blocks to move in the direction of module-level one-shot-bufferization
support.

In order to enable this, TensorLikeType is extended with a new interface
method that is invoked solely within the function boundary
bufferization.
Transform FunctionArgTypeConverterFn into a tensor-like -> buffer-like
converter so that it could be used as a generic function boundary
conversion utility.
andrey-golubev force-pushed the function_boundary_bufferization branch from d7373f3 to 95e4724 on September 24, 2025 09:47
andrey-golubev merged commit ff4c499 into llvm:main on Sep 24, 2025
9 checks passed
andrey-golubev deleted the function_boundary_bufferization branch on September 24, 2025 11:09
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025
…vm#159766)

Support custom types (3/N): allow custom tensor and buffer types in
function signatures and at call-sites. This is one of the major building
blocks to move in the direction of module-level one-shot-bufferization
support.

To achieve this, `BufferizationOptions::FunctionArgTypeConverterFn`
callback is converted to work with tensor-like and buffer-like types,
instead of the builtin counterparts. The default behavior for builtins
remains unchanged, while custom types by default go through
`TensorLikeType::getBufferType()` which is a general conversion
interface.