From 16056beca1ff91853bec248d49eb424c26cc4d5b Mon Sep 17 00:00:00 2001 From: "Golubev, Andrey" Date: Wed, 9 Jul 2025 09:56:39 +0000 Subject: [PATCH 1/3] [mlir][bufferization] Support custom types at function boundaries Support custom types (3/N): allow custom tensor and buffer types in function signatures and at call-sites. This is one of the major building blocks to move in the direction of module-level one-shot-bufferization support. In order to enable this, TensorLikeType is extended with a new interface method that is invoked solely within the function boundary bufferization. --- .../IR/BufferizationTypeInterfaces.h | 1 + .../IR/BufferizationTypeInterfaces.td | 12 +++ .../Bufferization/IR/BufferizationDialect.cpp | 13 +++ .../Bufferization/Transforms/Bufferize.cpp | 2 +- .../FuncBufferizableOpInterfaceImpl.cpp | 90 ++++++++++--------- .../Transforms/one-shot-module-bufferize.mlir | 56 ++++++++++++ mlir/test/lib/Dialect/Test/TestTypeDefs.td | 5 ++ mlir/test/lib/Dialect/Test/TestTypes.cpp | 8 ++ 8 files changed, 146 insertions(+), 41 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h index a2bfcb7ed2b75..9b052b8bb7e14 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h @@ -13,6 +13,7 @@ // Bufferization Type Interfaces //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Func/IR/FuncOps.h" // to access mlir::func::FuncOp #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Types.h" diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td index fb6fc4f5ad964..c4235cd067999 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td @@ -43,6 +43,18 @@ def Bufferization_TensorLikeTypeInterface /*args=*/(ins "::mlir::bufferization::BufferLikeType":$bufferType, "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) + >, + InterfaceMethod<[{ + Returns a BufferLike type for this TensorLike type in the context of + this type being function argument or result. + }], + /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>", + /*methodName=*/"getBufferTypeAtFunctionBoundary", + /*args=*/(ins + "::mlir::func::FuncOp":$funcOp, + "const ::mlir::bufferization::BufferizationOptions &":$options, + "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError + ) > ]; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index 6c08cdfb669f3..9b907922a24c4 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -87,6 +87,19 @@ struct BuiltinTensorExternalModel return mlir::success(); } + + llvm::FailureOr getBufferTypeAtFunctionBoundary( + mlir::Type tensor, mlir::func::FuncOp funcOp, + const BufferizationOptions &options, + llvm::function_ref emitError) const { + auto tensorType = cast(tensor); + auto memSpace = options.defaultMemorySpaceFn(tensorType); + if (!memSpace.has_value()) + return emitError() << "could not infer memory space"; + + return cast(options.functionArgTypeConverterFn( + tensorType, *memSpace, funcOp, options)); + } }; template diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 68ef51992efee..701ab52a491a8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -401,7 +401,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, // Compute the new signature. SmallVector newTypes; for (BlockArgument &bbArg : block->getArguments()) { - auto tensorType = dyn_cast(bbArg.getType()); + auto tensorType = dyn_cast(bbArg.getType()); if (!tensorType) { newTypes.push_back(bbArg.getType()); continue; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index f69efd1b3fa8c..b7bac9f4623f1 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -52,26 +52,35 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) { /// Return the index-th bufferized function argument type. This assumes that the /// specified argument is a tensor. If the tensor is ranked, a layout map may be /// specified by the user (as per `options.functionArgTypeConverterFn`). -static BaseMemRefType +static BufferLikeType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options) { auto tensorType = - dyn_cast(funcOp.getFunctionType().getInput(index)); - assert(tensorType && "expected TensorType"); - - BaseMemRefType memrefType = options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options); - - auto layoutAttr = funcOp.getArgAttrOfType( - index, BufferizationDialect::kBufferLayoutAttrName); - if (!layoutAttr) - return memrefType; - - auto rankedMemrefType = dyn_cast(memrefType); - assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); - return MemRefType::get(rankedMemrefType.getShape(), - rankedMemrefType.getElementType(), layoutAttr, - rankedMemrefType.getMemorySpace()); + dyn_cast(funcOp.getFunctionType().getInput(index)); + assert(tensorType && "expected TensorLikeType"); + auto maybeBufferType = tensorType.getBufferTypeAtFunctionBoundary( + funcOp, options, [&]() { return funcOp->emitError(); }); + assert(mlir::succeeded(maybeBufferType) && + "a valid buffer is always expected"); + + auto bufferType = *maybeBufferType; + + // Note: For builtin tensors there is additional logic related to layout. + if (isa(tensorType)) { + auto layoutAttr = funcOp.getArgAttrOfType( + index, BufferizationDialect::kBufferLayoutAttrName); + if (!layoutAttr) + return bufferType; + + auto rankedMemrefType = dyn_cast(bufferType); + assert(rankedMemrefType && + "buffer layout not supported on unranked tensors"); + return cast(MemRefType::get( + rankedMemrefType.getShape(), rankedMemrefType.getElementType(), + layoutAttr, rankedMemrefType.getMemorySpace())); + } + + return bufferType; } /// Return the FuncOp called by `callOp`. @@ -227,14 +236,13 @@ struct CallOpInterface FunctionType funcType = funcOp.getFunctionType(); Type resultType = funcType.getResult(cast(value).getResultNumber()); - if (auto bufferizedType = dyn_cast(resultType)) - return cast(bufferizedType); + if (auto bufferizedType = dyn_cast(resultType)) + return bufferizedType; // Otherwise, call the type converter to compute the bufferized type. - auto tensorType = cast(resultType); - return cast(options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, - options)); + auto tensorType = cast(resultType); + return tensorType.getBufferTypeAtFunctionBoundary( + funcOp, options, [&]() { return funcOp->emitError(); }); } /// All function arguments are writable. It is the responsibility of the @@ -248,7 +256,7 @@ struct CallOpInterface SmallVector resultTypes; for (Value result : callOp.getResults()) { Type returnType = result.getType(); - if (!isa(returnType)) { + if (!isa(returnType)) { // Non-tensor values are returned. resultTypes.push_back(returnType); continue; @@ -272,7 +280,7 @@ struct CallOpInterface for (OpOperand &opOperand : callOp->getOpOperands()) { // Non-tensor operands are just copied. - if (!isa(opOperand.get().getType())) { + if (!isa(opOperand.get().getType())) { newOperands.push_back(opOperand.get()); continue; } @@ -285,8 +293,8 @@ struct CallOpInterface Value buffer = *maybeBuffer; // Caller / callee type mismatch is handled with castOrReallocMemRefValue. - auto memRefType = funcType.getInput(opOperand.getOperandNumber()); - if (!isa(memRefType)) { + auto bufferType = funcType.getInput(opOperand.getOperandNumber()); + if (!isa(bufferType)) { // The called function was not bufferized yet. This can happen when // there cycles in the function call graph. Compute the bufferized // result type. @@ -296,7 +304,7 @@ struct CallOpInterface state); if (failed(maybeBufferType)) return failure(); - memRefType = *maybeBufferType; + bufferType = *maybeBufferType; } // Since we don't yet have a clear layout story, to_buffer may @@ -305,8 +313,8 @@ struct CallOpInterface // that will either canonicalize away or fail compilation until we can do // something better. Insert a reallocation + copy if it cannot be // statically guaranteed that a direct cast would be valid. - if (buffer.getType() != memRefType) { - auto memrefDstType = dyn_cast(memRefType); + if (buffer.getType() != bufferType) { + auto memrefDstType = dyn_cast(bufferType); assert(memrefDstType && "buffer layout not supported on unranked tensors"); FailureOr replacement = bufferization::castOrReallocMemRefValue( @@ -370,7 +378,7 @@ struct FuncOpInterface static bool supportsUnstructuredControlFlow() { return true; } bool hasTensorSemantics(Operation *op) const { - auto isaTensor = llvm::IsaPred; + auto isaTensor = llvm::IsaPred; // A function has tensor semantics if it has tensor arguments/results. auto funcOp = cast(op); @@ -406,8 +414,8 @@ struct FuncOpInterface // Function arguments are special. if (bbArg.getOwner() == &funcOp.getBody().front()) - return cast( - getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options)); + return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), + options); return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel:: getBufferType(op, value, options, state, invocationStack); @@ -430,7 +438,7 @@ struct FuncOpInterface SmallVector argTypes; for (const auto &it : llvm::enumerate(funcType.getInputs())) { Type argType = it.value(); - if (isa(argType)) { + if (isa(argType)) { argTypes.push_back( getBufferizedFunctionArgType(funcOp, it.index(), options)); continue; @@ -441,11 +449,13 @@ struct FuncOpInterface // Compute the result types. SmallVector retTypes; for (Type resultType : funcType.getResults()) { - if (auto tensorType = dyn_cast(resultType)) { - BaseMemRefType resultType = options.functionArgTypeConverterFn( - tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, - options); - retTypes.push_back(resultType); + if (auto tensorType = dyn_cast(resultType)) { + FailureOr resultType = + tensorType.getBufferTypeAtFunctionBoundary( + funcOp, options, [&]() { return funcOp->emitError(); }); + assert(mlir::succeeded(resultType) && + "a valid buffer is always expected"); + retTypes.push_back(*resultType); continue; } retTypes.push_back(resultType); @@ -473,7 +483,7 @@ struct FuncOpInterface SmallVector returnValues; for (auto [returnVal, bufferizedType] : llvm::zip_equal(returnOp->getOperands(), retTypes)) { - auto tensorType = dyn_cast(returnVal.getType()); + auto tensorType = dyn_cast(returnVal.getType()); rewriter.setInsertionPoint(returnOp); // If not a tensor type just forward it. diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index 2efb5893c8511..eb0093106dc11 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -810,3 +810,59 @@ module @inner_module { return %t : tensor<5xf32> } } + +// ----- + +// CHECK: func.func @custom_types( +// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: ) -> (!test.test_memref<[4, 8], f64>, +// CHECK-SAME: !test.test_memref<[4, 8], f64>) +func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>) + -> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) { + // CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) : + // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64> + %out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> + + // CHECK: %[[alloc:.*]] = "test.create_memref_op" + // CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]]) + // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64> + %alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64> + %out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> + + // CHECK: return %[[out1]], %[[out2]] + return %out1, %out2 : + !test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64> +} + +// ----- + +// CHECK: func.func @custom_types_foo( +// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64> +func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> { + // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]]) + %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + // CHECK: return %[[out]] + return %out : !test.test_tensor<[4, 4], f64> +} + +// CHECK: func.func @custom_types_bar( +// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: ) -> !test.test_memref<[4, 8], f64> +func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> { + // CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]]) + %call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + + // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]]) + %out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 8], f64> + + // CHECK: return %[[out]] + return %out : !test.test_tensor<[4, 8], f64> +} diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index ea20597231d58..562fc66acea2a 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -444,6 +444,11 @@ def TestTensorType : Test_Type<"TestTensor", ::mlir::LogicalResult verifyCompatibleBufferType( ::mlir::bufferization::BufferLikeType bufferType, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); + + ::mlir::FailureOr<::mlir::bufferization::BufferLikeType> + getBufferTypeAtFunctionBoundary(mlir::func::FuncOp funcOp, + const ::mlir::bufferization::BufferizationOptions& options, + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); }]; } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index bea043f56fe21..3c92fb94aebee 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -573,3 +573,11 @@ ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType( getElementType() == testMemref.getElementType(); return mlir::success(valid); } + +::mlir::FailureOr<::mlir::bufferization::BufferLikeType> +TestTensorType::getBufferTypeAtFunctionBoundary( + mlir::func::FuncOp, + const ::mlir::bufferization::BufferizationOptions &options, + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { + return getBufferType(options, emitError); +} From e40d66a643d9b69e073e0c7cb9914a32445b3160 Mon Sep 17 00:00:00 2001 From: "Golubev, Andrey" Date: Mon, 22 Sep 2025 11:22:53 +0000 Subject: [PATCH 2/3] Rely on BufferizationOptions::FunctionArgTypeConverterFn Transform FunctionArgTypeConverterFn into a tensor-like -> buffer-like converter so that it could be used as a generic function boundary conversion utility. --- .../IR/BufferizableOpInterface.h | 16 ++++--- .../IR/BufferizationTypeInterfaces.h | 1 - .../IR/BufferizationTypeInterfaces.td | 12 ----- .../IR/BufferizableOpInterface.cpp | 39 +++++++++++---- .../Bufferization/IR/BufferizationDialect.cpp | 13 ----- .../FuncBufferizableOpInterfaceImpl.cpp | 48 +++++++++++-------- mlir/test/lib/Dialect/Test/TestTypeDefs.td | 5 -- mlir/test/lib/Dialect/Test/TestTypes.cpp | 8 ---- 8 files changed, 67 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index f3b34f9fded7f..5bf3916630158 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -260,10 +260,10 @@ struct BufferizationOptions { std::function; /// Initializer function for analysis state. using AnalysisStateInitFn = std::function; - /// Tensor -> MemRef type converter. - /// Parameters: tensor type, memory space, func op, bufferization options + /// Tensor-like -> Buffer-like type converter. + /// Parameters: tensor-like type, memory space, func op, bufferization options using FunctionArgTypeConverterFn = - std::function; /// Tensor -> MemRef type converter. /// Parameters: tensor type, memory space, bufferization options @@ -335,10 +335,12 @@ struct BufferizationOptions { /// predictable. void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption); - /// Type converter from tensors to memrefs. This type converter is used to - /// determine bufferized function argument and result types. By default, a - /// type converter that returns a memref type with a fully dynamic layout map - /// is used. + /// Type converter from tensors to buffers. This type converter is used to + /// determine bufferized function argument and result types. + /// + /// By default, if tensor is a (builtin) tensor type, a type converter that + /// returns a memref type with a fully dynamic layout map is used; if tensor + /// is a (generic) tensor-like type, TensorLikeType::getBufferType() is used. /// /// If `bufferizeFunctionBoundaries` is not set, this function isn't used. FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr; diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h index 9b052b8bb7e14..a2bfcb7ed2b75 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h @@ -13,7 +13,6 @@ // Bufferization Type Interfaces //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Func/IR/FuncOps.h" // to access mlir::func::FuncOp #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Types.h" diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td index c4235cd067999..fb6fc4f5ad964 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td @@ -43,18 +43,6 @@ def Bufferization_TensorLikeTypeInterface /*args=*/(ins "::mlir::bufferization::BufferLikeType":$bufferType, "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) - >, - InterfaceMethod<[{ - Returns a BufferLike type for this TensorLike type in the context of - this type being function argument or result. - }], - /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>", - /*methodName=*/"getBufferTypeAtFunctionBoundary", - /*args=*/(ins - "::mlir::func::FuncOp":$funcOp, - "const ::mlir::bufferization::BufferizationOptions &":$options, - "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError - ) > ]; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index f7b0b87085f3d..fae1df69ed3e3 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -338,11 +338,21 @@ bool OpFilter::isOpAllowed(Operation *op) const { namespace { /// Default function arg type converter: Use a fully dynamic layout map. -BaseMemRefType -defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace, +BufferLikeType +defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp, const BufferizationOptions &options) { - return getMemRefTypeWithFullyDynamicLayout(type, memorySpace); + if (auto tensorType = mlir::dyn_cast(type)) { + return cast( + getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace)); + } + + // If not builtin, fallback to TensorLikeType::getBufferType() + auto bufferType = + type.getBufferType(options, [&]() { return funcOp->emitError(); }); + assert(mlir::succeeded(bufferType) && + "a valid buffer is always expected at function boundary"); + return *bufferType; } /// Default unknown type converter: Use a fully dynamic layout map. BaseMemRefType @@ -385,14 +395,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const { void BufferizationOptions::setFunctionBoundaryTypeConversion( LayoutMapOption layoutMapOption) { - functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace, + functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp, const BufferizationOptions &options) { - if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) - return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, - memorySpace); - return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, - memorySpace); + if (auto tensorType = mlir::dyn_cast(type)) { + if (layoutMapOption == LayoutMapOption::IdentityLayoutMap) + return cast( + bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, + memorySpace)); + return cast( + bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType, + memorySpace)); + } + + // If not builtin, fallback to TensorLikeType::getBufferType() + auto bufferType = + type.getBufferType(options, [&]() { return funcOp->emitError(); }); + assert(mlir::succeeded(bufferType) && + "a valid buffer is always expected at function boundary"); + return *bufferType; }; inferFunctionResultLayout = layoutMapOption == LayoutMapOption::InferLayoutMap; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp index 9b907922a24c4..6c08cdfb669f3 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -87,19 +87,6 @@ struct BuiltinTensorExternalModel return mlir::success(); } - - llvm::FailureOr getBufferTypeAtFunctionBoundary( - mlir::Type tensor, mlir::func::FuncOp funcOp, - const BufferizationOptions &options, - llvm::function_ref emitError) const { - auto tensorType = cast(tensor); - auto memSpace = options.defaultMemorySpaceFn(tensorType); - if (!memSpace.has_value()) - return emitError() << "could not infer memory space"; - - return cast(options.functionArgTypeConverterFn( - tensorType, *memSpace, funcOp, options)); - } }; template diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index b7bac9f4623f1..d9d69342e42a8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -49,30 +49,38 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) { #endif // NDEBUG } +// Note: this is a local adaptor to unify TensorType and TensorLikeType code +// paths that both work with BufferizationOptions. +static mlir::Attribute +getDefaultMemorySpace(const BufferizationOptions &options, + TensorLikeType type) { + if (auto tensorType = dyn_cast(type)) { + return *options.defaultMemorySpaceFn(tensorType); + } + return nullptr; +} + /// Return the index-th bufferized function argument type. This assumes that the /// specified argument is a tensor. If the tensor is ranked, a layout map may be /// specified by the user (as per `options.functionArgTypeConverterFn`). static BufferLikeType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options) { - auto tensorType = + auto type = dyn_cast(funcOp.getFunctionType().getInput(index)); - assert(tensorType && "expected TensorLikeType"); - auto maybeBufferType = tensorType.getBufferTypeAtFunctionBoundary( - funcOp, options, [&]() { return funcOp->emitError(); }); - assert(mlir::succeeded(maybeBufferType) && - "a valid buffer is always expected"); - - auto bufferType = *maybeBufferType; + assert(type && "expected TensorLikeType"); // Note: For builtin tensors there is additional logic related to layout. - if (isa(tensorType)) { + if (auto tensorType = dyn_cast(type)) { + BufferLikeType memrefType = options.functionArgTypeConverterFn( + type, *options.defaultMemorySpaceFn(tensorType), funcOp, options); + auto layoutAttr = funcOp.getArgAttrOfType( index, BufferizationDialect::kBufferLayoutAttrName); if (!layoutAttr) - return bufferType; + return memrefType; - auto rankedMemrefType = dyn_cast(bufferType); + auto rankedMemrefType = dyn_cast(memrefType); assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); return cast(MemRefType::get( @@ -80,7 +88,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, layoutAttr, rankedMemrefType.getMemorySpace())); } - return bufferType; + return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp, + options); } /// Return the FuncOp called by `callOp`. @@ -241,8 +250,9 @@ struct CallOpInterface // Otherwise, call the type converter to compute the bufferized type. auto tensorType = cast(resultType); - return tensorType.getBufferTypeAtFunctionBoundary( - funcOp, options, [&]() { return funcOp->emitError(); }); + return cast(options.functionArgTypeConverterFn( + tensorType, getDefaultMemorySpace(options, tensorType), funcOp, + options)); } /// All function arguments are writable. It is the responsibility of the @@ -450,12 +460,10 @@ struct FuncOpInterface SmallVector retTypes; for (Type resultType : funcType.getResults()) { if (auto tensorType = dyn_cast(resultType)) { - FailureOr resultType = - tensorType.getBufferTypeAtFunctionBoundary( - funcOp, options, [&]() { return funcOp->emitError(); }); - assert(mlir::succeeded(resultType) && - "a valid buffer is always expected"); - retTypes.push_back(*resultType); + BufferLikeType resultType = options.functionArgTypeConverterFn( + tensorType, getDefaultMemorySpace(options, tensorType), funcOp, + options); + retTypes.push_back(resultType); continue; } retTypes.push_back(resultType); diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 562fc66acea2a..ea20597231d58 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -444,11 +444,6 @@ def TestTensorType : Test_Type<"TestTensor", ::mlir::LogicalResult verifyCompatibleBufferType( ::mlir::bufferization::BufferLikeType bufferType, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); - - ::mlir::FailureOr<::mlir::bufferization::BufferLikeType> - getBufferTypeAtFunctionBoundary(mlir::func::FuncOp funcOp, - const ::mlir::bufferization::BufferizationOptions& options, - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError); }]; } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 3c92fb94aebee..bea043f56fe21 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -573,11 +573,3 @@ ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType( getElementType() == testMemref.getElementType(); return mlir::success(valid); } - -::mlir::FailureOr<::mlir::bufferization::BufferLikeType> -TestTensorType::getBufferTypeAtFunctionBoundary( - mlir::func::FuncOp, - const ::mlir::bufferization::BufferizationOptions &options, - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) { - return getBufferType(options, emitError); -} From 95e4724090591d6b034e68c2d624fe22fdda876b Mon Sep 17 00:00:00 2001 From: "Golubev, Andrey" Date: Wed, 24 Sep 2025 09:50:04 +0000 Subject: [PATCH 3/3] Change documentation wording for type converters --- .../IR/BufferizableOpInterface.h | 19 +++++++++---------- .../IR/BufferizableOpInterface.cpp | 4 ++-- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 5bf3916630158..dd693a25fd54f 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -260,12 +260,12 @@ struct BufferizationOptions { std::function; /// Initializer function for analysis state. using AnalysisStateInitFn = std::function; - /// Tensor-like -> Buffer-like type converter. + /// Tensor-like -> Buffer-like type conversion. /// Parameters: tensor-like type, memory space, func op, bufferization options using FunctionArgTypeConverterFn = std::function; - /// Tensor -> MemRef type converter. + /// Tensor -> MemRef type conversion. /// Parameters: tensor type, memory space, bufferization options using UnknownTypeConverterFn = std::function; @@ -335,12 +335,12 @@ struct BufferizationOptions { /// predictable. void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption); - /// Type converter from tensors to buffers. This type converter is used to + /// Type conversion from tensors to buffers. This type conversion is used to /// determine bufferized function argument and result types. /// - /// By default, if tensor is a (builtin) tensor type, a type converter that - /// returns a memref type with a fully dynamic layout map is used; if tensor - /// is a (generic) tensor-like type, TensorLikeType::getBufferType() is used. + /// By default, if tensor is a (builtin) tensor type, it is converted to a + /// memref type with a fully dynamic layout map; if tensor is a (generic) + /// tensor-like type, it is converted using TensorLikeType::getBufferType(). /// /// If `bufferizeFunctionBoundaries` is not set, this function isn't used. FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr; @@ -352,10 +352,9 @@ struct BufferizationOptions { /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect. bool inferFunctionResultLayout = true; - /// Type converter from tensors to memrefs. This type converter is used if no - /// memref type could be inferred during bufferization. By default, a type - /// converter that returns a memref type with a fully dynamic layout map is - /// used. + /// Type conversion from tensors to memrefs. This type conversion is used if + /// no memref type could be inferred during bufferization. By default, returns + /// a memref type with a fully dynamic layout map. UnknownTypeConverterFn unknownTypeConverterFn = nullptr; // Use during type conversion to determine the memory space for memref based diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index fae1df69ed3e3..e0cf353da207f 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -350,7 +350,7 @@ defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace, // If not builtin, fallback to TensorLikeType::getBufferType() auto bufferType = type.getBufferType(options, [&]() { return funcOp->emitError(); }); - assert(mlir::succeeded(bufferType) && + assert(succeeded(bufferType) && "a valid buffer is always expected at function boundary"); return *bufferType; } @@ -411,7 +411,7 @@ void BufferizationOptions::setFunctionBoundaryTypeConversion( // If not builtin, fallback to TensorLikeType::getBufferType() auto bufferType = type.getBufferType(options, [&]() { return funcOp->emitError(); }); - assert(mlir::succeeded(bufferType) && + assert(succeeded(bufferType) && "a valid buffer is always expected at function boundary"); return *bufferType; };