diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f3b34f9fded7f..dd693a25fd54f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -260,12 +260,12 @@ struct BufferizationOptions {
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
-  /// Tensor -> MemRef type converter.
-  /// Parameters: tensor type, memory space, func op, bufferization options
+  /// Tensor-like -> Buffer-like type conversion.
+  /// Parameters: tensor-like type, memory space, func op, bufferization options
   using FunctionArgTypeConverterFn =
-      std::function<BaseMemRefType(TensorType, Attribute memorySpace,
-                                   func::FuncOp, const BufferizationOptions &)>;
-  /// Tensor -> MemRef type converter.
+      std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
+                                   func::FuncOp, const BufferizationOptions &)>;
+  /// Tensor -> MemRef type conversion.
   /// Parameters: tensor type, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
       TensorType, Attribute memorySpace, const BufferizationOptions &)>;
@@ -335,10 +335,12 @@ struct BufferizationOptions {
   /// predictable.
   void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
 
-  /// Type converter from tensors to memrefs. This type converter is used to
-  /// determine bufferized function argument and result types. By default, a
-  /// type converter that returns a memref type with a fully dynamic layout map
-  /// is used.
+  /// Type conversion from tensors to buffers. This type conversion is used to
+  /// determine bufferized function argument and result types.
+  ///
+  /// By default, a (builtin) tensor type is converted to a memref type with a
+  /// fully dynamic layout map, and a (generic) tensor-like type is converted
+  /// using TensorLikeType::getBufferType().
   ///
   /// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
   FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
@@ -350,10 +352,9 @@ struct BufferizationOptions {
   /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
   bool inferFunctionResultLayout = true;
 
-  /// Type converter from tensors to memrefs. This type converter is used if no
-  /// memref type could be inferred during bufferization. By default, a type
-  /// converter that returns a memref type with a fully dynamic layout map is
-  /// used.
+  /// Type conversion from tensors to memrefs. This type conversion is used if
+  /// no memref type could be inferred during bufferization. By default, a
+  /// memref type with a fully dynamic layout map is returned.
   UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
 
   // Use during type conversion to determine the memory space for memref based
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index f7b0b87085f3d..e0cf353da207f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -338,11 +338,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
 namespace {
 /// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
+BufferLikeType
+defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace,
                                 func::FuncOp funcOp,
                                 const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+  if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+    return cast<BufferLikeType>(
+        getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
+  }
+
+  // If not builtin, fall back to TensorLikeType::getBufferType().
+  auto bufferType =
+      type.getBufferType(options, [&]() { return funcOp->emitError(); });
+  assert(succeeded(bufferType) &&
+         "a valid buffer is always expected at function boundary");
+  return *bufferType;
 }
 
 /// Default unknown type converter: Use a fully dynamic layout map.
 BaseMemRefType
@@ -385,14 +395,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
 
 void BufferizationOptions::setFunctionBoundaryTypeConversion(
     LayoutMapOption layoutMapOption) {
-  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
+  functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace,
                                    func::FuncOp funcOp,
                                    const BufferizationOptions &options) {
-    if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
-      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
-                                                                  memorySpace);
-    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
-                                                              memorySpace);
+    if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+      if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
+        return cast<BufferLikeType>(
+            bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
+                                                                 memorySpace));
+      return cast<BufferLikeType>(
+          bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+                                                             memorySpace));
+    }
+
+    // If not builtin, fall back to TensorLikeType::getBufferType().
+    auto bufferType =
+        type.getBufferType(options, [&]() { return funcOp->emitError(); });
+    assert(succeeded(bufferType) &&
+           "a valid buffer is always expected at function boundary");
+    return *bufferType;
   };
   inferFunctionResultLayout =
       layoutMapOption == LayoutMapOption::InferLayoutMap;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 68ef51992efee..701ab52a491a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -401,7 +401,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
   // Compute the new signature.
   SmallVector<Type> newTypes;
   for (BlockArgument &bbArg : block->getArguments()) {
-    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
+    auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
     if (!tensorType) {
       newTypes.push_back(bbArg.getType());
       continue;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index f69efd1b3fa8c..d9d69342e42a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -49,29 +49,47 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }
 
+// Note: this is a local adaptor to unify the TensorType and TensorLikeType
+// code paths that both work with BufferizationOptions.
+static mlir::Attribute
+getDefaultMemorySpace(const BufferizationOptions &options,
+                      TensorLikeType type) {
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    return *options.defaultMemorySpaceFn(tensorType);
+  }
+  return nullptr;
+}
+
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
-static BaseMemRefType
+static BufferLikeType
 getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
-  auto tensorType =
-      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
-  assert(tensorType && "expected TensorType");
-
-  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
-      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
-
-  auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
-      index, BufferizationDialect::kBufferLayoutAttrName);
-  if (!layoutAttr)
-    return memrefType;
-
-  auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
-  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
-  return MemRefType::get(rankedMemrefType.getShape(),
-                         rankedMemrefType.getElementType(), layoutAttr,
-                         rankedMemrefType.getMemorySpace());
+  auto type =
+      dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
+  assert(type && "expected TensorLikeType");
+
+  // Note: For builtin tensors there is additional logic related to layout.
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    BufferLikeType memrefType = options.functionArgTypeConverterFn(
+        type, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+
+    auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
+        index, BufferizationDialect::kBufferLayoutAttrName);
+    if (!layoutAttr)
+      return memrefType;
+
+    auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
+    assert(rankedMemrefType &&
+           "buffer layout not supported on unranked tensors");
+    return cast<BufferLikeType>(MemRefType::get(
+        rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+        layoutAttr, rankedMemrefType.getMemorySpace()));
+  }
+
+  return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp,
+                                            options);
 }
 
 /// Return the FuncOp called by `callOp`.
@@ -227,13 +245,13 @@ struct CallOpInterface
     FunctionType funcType = funcOp.getFunctionType();
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
-    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return cast<BufferLikeType>(bufferizedType);
+    if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
+      return bufferizedType;
 
     // Otherwise, call the type converter to compute the bufferized type.
-    auto tensorType = cast<TensorType>(resultType);
+    auto tensorType = cast<TensorLikeType>(resultType);
     return cast<BufferLikeType>(options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
         options));
   }
 
@@ -248,7 +266,7 @@ struct CallOpInterface
     SmallVector<Type> resultTypes;
     for (Value result : callOp.getResults()) {
       Type returnType = result.getType();
-      if (!isa<TensorType>(returnType)) {
+      if (!isa<TensorLikeType>(returnType)) {
         // Non-tensor values are returned.
         resultTypes.push_back(returnType);
         continue;
@@ -272,7 +290,7 @@ struct CallOpInterface
 
     for (OpOperand &opOperand : callOp->getOpOperands()) {
       // Non-tensor operands are just copied.
-      if (!isa<TensorType>(opOperand.get().getType())) {
+      if (!isa<TensorLikeType>(opOperand.get().getType())) {
         newOperands.push_back(opOperand.get());
         continue;
       }
@@ -285,8 +303,8 @@ struct CallOpInterface
       Value buffer = *maybeBuffer;
 
       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
-      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
-      if (!isa<BaseMemRefType>(memRefType)) {
+      auto bufferType = funcType.getInput(opOperand.getOperandNumber());
+      if (!isa<BufferLikeType>(bufferType)) {
         // The called function was not bufferized yet. This can happen when
         // there cycles in the function call graph. Compute the bufferized
         // result type.
@@ -296,7 +314,7 @@ struct CallOpInterface
                               state);
         if (failed(maybeBufferType))
           return failure();
-        memRefType = *maybeBufferType;
+        bufferType = *maybeBufferType;
       }
 
       // Since we don't yet have a clear layout story, to_buffer may
@@ -305,8 +323,8 @@ struct CallOpInterface
       // that will either canonicalize away or fail compilation until we can do
       // something better. Insert a reallocation + copy if it cannot be
      // statically guaranteed that a direct cast would be valid.
-      if (buffer.getType() != memRefType) {
-        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+      if (buffer.getType() != bufferType) {
+        auto memrefDstType = dyn_cast<MemRefType>(bufferType);
         assert(memrefDstType &&
                "buffer layout not supported on unranked tensors");
         FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
@@ -370,7 +388,7 @@ struct FuncOpInterface
   static bool supportsUnstructuredControlFlow() { return true; }
 
   bool hasTensorSemantics(Operation *op) const {
-    auto isaTensor = llvm::IsaPred<TensorType>;
+    auto isaTensor = llvm::IsaPred<TensorLikeType>;
 
     // A function has tensor semantics if it has tensor arguments/results.
     auto funcOp = cast<FuncOp>(op);
@@ -406,8 +424,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return cast<BufferLikeType>(
-          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
+      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
+                                          options);
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
@@ -430,7 +448,7 @@ struct FuncOpInterface
     SmallVector<Type> argTypes;
     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
       Type argType = it.value();
-      if (isa<TensorType>(argType)) {
+      if (isa<TensorLikeType>(argType)) {
         argTypes.push_back(
             getBufferizedFunctionArgType(funcOp, it.index(), options));
         continue;
@@ -441,9 +459,9 @@ struct FuncOpInterface
     // Compute the result types.
     SmallVector<Type> retTypes;
     for (Type resultType : funcType.getResults()) {
-      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
-        BaseMemRefType resultType = options.functionArgTypeConverterFn(
-            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+      if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
+        BufferLikeType resultType = options.functionArgTypeConverterFn(
+            tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
             options);
         retTypes.push_back(resultType);
         continue;
@@ -473,7 +491,7 @@ struct FuncOpInterface
     SmallVector<Value> returnValues;
     for (auto [returnVal, bufferizedType] :
          llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+      auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
       rewriter.setInsertionPoint(returnOp);
 
       // If not a tensor type just forward it.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 2efb5893c8511..eb0093106dc11 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -810,3 +810,59 @@ module @inner_module {
     return %t : tensor<5xf32>
   }
 }
+
+// -----
+
+// CHECK: func.func @custom_types(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> (!test.test_memref<[4, 8], f64>,
+// CHECK-SAME: !test.test_memref<[4, 8], f64>)
+func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>)
+    -> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) {
+  // CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) :
+  // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+  %out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: %[[alloc:.*]] = "test.create_memref_op"
+  // CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]])
+  // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+  %alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64>
+  %out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: return %[[out1]], %[[out2]]
+  return %out1, %out2 :
+      !test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>
+}
+
+// -----
+
+// CHECK: func.func @custom_types_foo(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64> {
+  // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]])
+  %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 4], f64>
+  // CHECK: return %[[out]]
+  return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// CHECK: func.func @custom_types_bar(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> !test.test_memref<[4, 8], f64>
+func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 8], f64> {
+  // CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]])
+  %call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 4], f64>
+
+  // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]])
+  %out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: return %[[out]]
+  return %out : !test.test_tensor<[4, 8], f64>
+}
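
Note on usage: with this patch, a downstream pipeline that bufferizes across
function boundaries can install its own `FunctionArgTypeConverterFn` against
the new tensor-like/buffer-like signature. Below is a minimal sketch using
only APIs that appear in this diff; `setupCustomBoundaryConversion` is a
hypothetical helper name, and the identity-layout policy for builtin tensors
is an illustrative choice, not something this patch prescribes.

#include <cassert>

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;
using namespace mlir::bufferization;

// Hypothetical helper: keep a concrete layout policy for builtin tensors
// (static identity layout here) and route every other tensor-like type
// through its own TensorLikeType::getBufferType() rule, mirroring the
// default converter added in this patch.
static void setupCustomBoundaryConversion(BufferizationOptions &options) {
  options.bufferizeFunctionBoundaries = true;
  options.functionArgTypeConverterFn =
      [](TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp,
         const BufferizationOptions &options) -> BufferLikeType {
    // Builtin tensor: convert to a memref with an identity layout map.
    if (auto tensorType = dyn_cast<TensorType>(type))
      return cast<BufferLikeType>(
          getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace));
    // Any other tensor-like type: defer to the type's own conversion.
    auto bufferType =
        type.getBufferType(options, [&]() { return funcOp->emitError(); });
    assert(succeeded(bufferType) && "expected a valid buffer type");
    return *bufferType;
  };
}

The only difference from defaultFunctionArgTypeConverter above is the layout
policy chosen for builtin tensors; custom tensor-like types (such as
!test.test_tensor in the new tests) are handled identically.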