-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][bufferization] Support custom types at function boundaries #159766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][bufferization] Support custom types at function boundaries #159766
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-bufferization Author: Andrei Golubev (andrey-golubev) ChangesSupport custom types (3/N): allow custom tensor and buffer types in function signatures and at call-sites. This is one of the major building blocks to move in the direction of module-level one-shot-bufferization support. In order to enable this, TensorLikeType is extended with a new interface method that is invoked solely within the function boundary bufferization. Full diff: https://github.com/llvm/llvm-project/pull/159766.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index a2bfcb7ed2b75..9b052b8bb7e14 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Func/IR/FuncOps.h" // to access mlir::func::FuncOp
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index fb6fc4f5ad964..c4235cd067999 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -43,6 +43,18 @@ def Bufferization_TensorLikeTypeInterface
/*args=*/(ins
"::mlir::bufferization::BufferLikeType":$bufferType,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+ >,
+ InterfaceMethod<[{
+ Returns a BufferLike type for this TensorLike type in the context of
+ this type being function argument or result.
+ }],
+ /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
+ /*methodName=*/"getBufferTypeAtFunctionBoundary",
+ /*args=*/(ins
+ "::mlir::func::FuncOp":$funcOp,
+ "const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
+ )
>
];
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdfb669f3..9b907922a24c4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -87,6 +87,19 @@ struct BuiltinTensorExternalModel
return mlir::success();
}
+
+ llvm::FailureOr<BufferLikeType> getBufferTypeAtFunctionBoundary(
+ mlir::Type tensor, mlir::func::FuncOp funcOp,
+ const BufferizationOptions &options,
+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+ auto tensorType = cast<TensorType>(tensor);
+ auto memSpace = options.defaultMemorySpaceFn(tensorType);
+ if (!memSpace.has_value())
+ return emitError() << "could not infer memory space";
+
+ return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+ tensorType, *memSpace, funcOp, options));
+ }
};
template <typename MemRef>
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 68ef51992efee..701ab52a491a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -401,7 +401,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
// Compute the new signature.
SmallVector<Type> newTypes;
for (BlockArgument &bbArg : block->getArguments()) {
- auto tensorType = dyn_cast<TensorType>(bbArg.getType());
+ auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
if (!tensorType) {
newTypes.push_back(bbArg.getType());
continue;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index f69efd1b3fa8c..b7bac9f4623f1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -52,26 +52,35 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user (as per `options.functionArgTypeConverterFn`).
-static BaseMemRefType
+static BufferLikeType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
const BufferizationOptions &options) {
auto tensorType =
- dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
- assert(tensorType && "expected TensorType");
-
- BaseMemRefType memrefType = options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
-
- auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
- index, BufferizationDialect::kBufferLayoutAttrName);
- if (!layoutAttr)
- return memrefType;
-
- auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
- assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
- return MemRefType::get(rankedMemrefType.getShape(),
- rankedMemrefType.getElementType(), layoutAttr,
- rankedMemrefType.getMemorySpace());
+ dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
+ assert(tensorType && "expected TensorLikeType");
+ auto maybeBufferType = tensorType.getBufferTypeAtFunctionBoundary(
+ funcOp, options, [&]() { return funcOp->emitError(); });
+ assert(mlir::succeeded(maybeBufferType) &&
+ "a valid buffer is always expected");
+
+ auto bufferType = *maybeBufferType;
+
+ // Note: For builtin tensors there is additional logic related to layout.
+ if (isa<TensorType>(tensorType)) {
+ auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
+ index, BufferizationDialect::kBufferLayoutAttrName);
+ if (!layoutAttr)
+ return bufferType;
+
+ auto rankedMemrefType = dyn_cast<MemRefType>(bufferType);
+ assert(rankedMemrefType &&
+ "buffer layout not supported on unranked tensors");
+ return cast<BufferLikeType>(MemRefType::get(
+ rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+ layoutAttr, rankedMemrefType.getMemorySpace()));
+ }
+
+ return bufferType;
}
/// Return the FuncOp called by `callOp`.
@@ -227,14 +236,13 @@ struct CallOpInterface
FunctionType funcType = funcOp.getFunctionType();
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
- if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
- return cast<BufferLikeType>(bufferizedType);
+ if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
+ return bufferizedType;
// Otherwise, call the type converter to compute the bufferized type.
- auto tensorType = cast<TensorType>(resultType);
- return cast<BufferLikeType>(options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
- options));
+ auto tensorType = cast<TensorLikeType>(resultType);
+ return tensorType.getBufferTypeAtFunctionBoundary(
+ funcOp, options, [&]() { return funcOp->emitError(); });
}
/// All function arguments are writable. It is the responsibility of the
@@ -248,7 +256,7 @@ struct CallOpInterface
SmallVector<Type> resultTypes;
for (Value result : callOp.getResults()) {
Type returnType = result.getType();
- if (!isa<TensorType>(returnType)) {
+ if (!isa<TensorLikeType>(returnType)) {
// Non-tensor values are returned.
resultTypes.push_back(returnType);
continue;
@@ -272,7 +280,7 @@ struct CallOpInterface
for (OpOperand &opOperand : callOp->getOpOperands()) {
// Non-tensor operands are just copied.
- if (!isa<TensorType>(opOperand.get().getType())) {
+ if (!isa<TensorLikeType>(opOperand.get().getType())) {
newOperands.push_back(opOperand.get());
continue;
}
@@ -285,8 +293,8 @@ struct CallOpInterface
Value buffer = *maybeBuffer;
// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
- auto memRefType = funcType.getInput(opOperand.getOperandNumber());
- if (!isa<BaseMemRefType>(memRefType)) {
+ auto bufferType = funcType.getInput(opOperand.getOperandNumber());
+ if (!isa<BufferLikeType>(bufferType)) {
// The called function was not bufferized yet. This can happen when
// there cycles in the function call graph. Compute the bufferized
// result type.
@@ -296,7 +304,7 @@ struct CallOpInterface
state);
if (failed(maybeBufferType))
return failure();
- memRefType = *maybeBufferType;
+ bufferType = *maybeBufferType;
}
// Since we don't yet have a clear layout story, to_buffer may
@@ -305,8 +313,8 @@ struct CallOpInterface
// that will either canonicalize away or fail compilation until we can do
// something better. Insert a reallocation + copy if it cannot be
// statically guaranteed that a direct cast would be valid.
- if (buffer.getType() != memRefType) {
- auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+ if (buffer.getType() != bufferType) {
+ auto memrefDstType = dyn_cast<MemRefType>(bufferType);
assert(memrefDstType &&
"buffer layout not supported on unranked tensors");
FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
@@ -370,7 +378,7 @@ struct FuncOpInterface
static bool supportsUnstructuredControlFlow() { return true; }
bool hasTensorSemantics(Operation *op) const {
- auto isaTensor = llvm::IsaPred<TensorType>;
+ auto isaTensor = llvm::IsaPred<TensorLikeType>;
// A function has tensor semantics if it has tensor arguments/results.
auto funcOp = cast<FuncOp>(op);
@@ -406,8 +414,8 @@ struct FuncOpInterface
// Function arguments are special.
if (bbArg.getOwner() == &funcOp.getBody().front())
- return cast<BufferLikeType>(
- getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
+ return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
+ options);
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
getBufferType(op, value, options, state, invocationStack);
@@ -430,7 +438,7 @@ struct FuncOpInterface
SmallVector<Type> argTypes;
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
Type argType = it.value();
- if (isa<TensorType>(argType)) {
+ if (isa<TensorLikeType>(argType)) {
argTypes.push_back(
getBufferizedFunctionArgType(funcOp, it.index(), options));
continue;
@@ -441,11 +449,13 @@ struct FuncOpInterface
// Compute the result types.
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
- if (auto tensorType = dyn_cast<TensorType>(resultType)) {
- BaseMemRefType resultType = options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
- options);
- retTypes.push_back(resultType);
+ if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
+ FailureOr<BufferLikeType> resultType =
+ tensorType.getBufferTypeAtFunctionBoundary(
+ funcOp, options, [&]() { return funcOp->emitError(); });
+ assert(mlir::succeeded(resultType) &&
+ "a valid buffer is always expected");
+ retTypes.push_back(*resultType);
continue;
}
retTypes.push_back(resultType);
@@ -473,7 +483,7 @@ struct FuncOpInterface
SmallVector<Value> returnValues;
for (auto [returnVal, bufferizedType] :
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
- auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+ auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);
// If not a tensor type just forward it.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 2efb5893c8511..eb0093106dc11 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -810,3 +810,59 @@ module @inner_module {
return %t : tensor<5xf32>
}
}
+
+// -----
+
+// CHECK: func.func @custom_types(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> (!test.test_memref<[4, 8], f64>,
+// CHECK-SAME: !test.test_memref<[4, 8], f64>)
+func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>)
+ -> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) {
+ // CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) :
+ // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+ %out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 8], f64>
+
+ // CHECK: %[[alloc:.*]] = "test.create_memref_op"
+ // CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]])
+ // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+ %alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64>
+ %out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 8], f64>
+
+ // CHECK: return %[[out1]], %[[out2]]
+ return %out1, %out2 :
+ !test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>
+}
+
+// -----
+
+// CHECK: func.func @custom_types_foo(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64> {
+ // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]])
+ %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64>
+ // CHECK: return %[[out]]
+ return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// CHECK: func.func @custom_types_bar(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> !test.test_memref<[4, 8], f64>
+func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 8], f64> {
+ // CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]])
+ %call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64>
+
+ // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]])
+ %out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 8], f64>
+
+ // CHECK: return %[[out]]
+ return %out : !test.test_tensor<[4, 8], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index ea20597231d58..562fc66acea2a 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -444,6 +444,11 @@ def TestTensorType : Test_Type<"TestTensor",
::mlir::LogicalResult verifyCompatibleBufferType(
::mlir::bufferization::BufferLikeType bufferType,
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
+
+ ::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
+ getBufferTypeAtFunctionBoundary(mlir::func::FuncOp funcOp,
+ const ::mlir::bufferization::BufferizationOptions& options,
+ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
}];
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index bea043f56fe21..3c92fb94aebee 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -573,3 +573,11 @@ ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
getElementType() == testMemref.getElementType();
return mlir::success(valid);
}
+
+::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
+TestTensorType::getBufferTypeAtFunctionBoundary(
+ mlir::func::FuncOp,
+ const ::mlir::bufferization::BufferizationOptions &options,
+ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
+ return getBufferType(options, emitError);
+}
|
@@ -43,6 +43,18 @@ def Bufferization_TensorLikeTypeInterface | |||
/*args=*/(ins | |||
"::mlir::bufferization::BufferLikeType":$bufferType, | |||
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError) | |||
>, | |||
InterfaceMethod<[{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this interface function necessary? We already have BufferizationOptions::functionArgTypeConverterFn
. We can change the signature to:
/// Tensor -> MemRef type converter.
/// Parameters: tensor type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
Do you need access to the function? In that case, we can pass an additional FunctionOpInterface
to the lambda.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i don't need to access the function (at least in my custom types), i'm actually intrigued why the FunctionArgTypeConverterFn
needs it (e.g. the option function doesn't even take an index, so one has no info about which argument they're dealing with).
thought of the interface here because this seems to be aligned with what we've discussed before:
- we (mostly) have options to customize builtin types behaviour
- interface methods on top of tensor-like / buffer-like is where users provide customization for user types
out of the obvious challenges, doing this via options practically means users have to create their own "pass" (i.e. bufferizeOp()'s options have to be seeded with different functions that aren't really configurable outside of C++).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
main issue with this new interface method is that I can't seem to provide a default implementation. i'd go with new method
== getBufferType
by default but it fails to compile since builtin types get this interface dynamically attached.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'm actually intrigued why the FunctionArgTypeConverterFn needs it
I don't know either. But it could be useful to get more context to do the type conversion.
thought of the interface here because this seems to be aligned with what we've discussed before:
The problem that I'm seeing is that there are now two ways to configure the type conversion, which is a bit confusing. If we make this interface-based, we should delete the options callback. But I think that's not possible because you wouldn't be able to configure identity/fully dynamic layout maps at the function boundary.
Function boundaries are always a bit special, so I think having an options callback is desirable also for custom tensor/buffer types.
doing this via options practically means users have to create their own "pass" (i.e. bufferizeOp()'s options have to be seeded with different functions that aren't really configurable outside of C++)
That is the case in every production compiler that I'm aware of. This is not the only flag that cannot be customized as a pass flag. AllocationFn
, MemCpyFn
is another one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I think that's not possible because you wouldn't be able to configure identity/fully dynamic layout maps at the function boundary.
I'm not sure I understand this part, can you explain more? Do you mean the void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
setter that makes the conversion LayoutMapOption
value-dependent? I think, yes, if this is a hard blocker, an option callback must stay - passing LayoutMapOption
via interfaces is probably a pain (if possible without hacks or a significant redesign).
Function boundaries are always a bit special, so I think having an options callback is desirable also for custom tensor/buffer types.
Fair point. Let's go with this then. I think I could do it a bit differently: for TensorType
(current API), we keep the logic "as is", otherwise - for TensorLikeType
- we always call it's getBufferType()
? This would preserve single conversion "interface" - the options callback and also give us some sane default behavior for any type (i.e. i wouldn't need to create --test-one-shot-bufferize pass yet).
That is the case in every production compiler that I'm aware of. This is not the only flag that cannot be customized as a pass flag. AllocationFn, MemCpyFn is another one.
Right, also in our case that's the way it is (fwiw, this is probably inevitable). I was a bit lazy to have a test pass in MLIR for custom types though (i'll probably have one anyway but later: to test how builtins with tensor encoding
-> memref layout
behave, this is one of the important things we use everywhere in our downstream).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed this interface method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I understand this part, can you explain more?
Interface method implementation cannot be overridden by the user. In the interface implementation you have to make a choice whether to convert a tensor type to a memref with identity layout or fully dynamic layout map. We want users to be able to customize behavior, therefore, there must be some kind of hook in the BufferizationOptions
.
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
Outdated
Show resolved
Hide resolved
return *options.defaultMemorySpaceFn(tensorType); | ||
} | ||
return nullptr; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note: i guess we could also make defaultMemorySpaceFn
work with tensor-like type. but i'd prefer to do it separately to simplify the updates for the users.
otoh, TensorLikeType::getBufferType() already has access to the tensor type and can infer memory space from it. I am not sure why we need it explicitly at all.
/// Type converter from tensors to buffers. This type converter is used to | ||
/// determine bufferized function argument and result types. | ||
/// | ||
/// By default, if tensor is a (builtin) tensor type, a type converter that |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Let's not use the term "type converter" here because that has a different meaning in a dialect conversion. I'd rephrase as "the type is converted to...". Or just write "type conversion" instead of "type converter".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack. i think this is original wording (just slightly refurbished).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed to "type conversion". also for UnknownTypeConverterFn
.
Support custom types (3/N): allow custom tensor and buffer types in function signatures and at call-sites. This is one of the major building blocks to move in the direction of module-level one-shot-bufferization support. In order to enable this, TensorLikeType is extended with a new interface method that is invoked solely within the function boundary bufferization.
Transform FunctionArgTypeConverterFn into a tensor-like -> buffer-like converter so that it could be used as a generic function boundary conversion utility.
d7373f3
to
95e4724
Compare
…vm#159766) Support custom types (3/N): allow custom tensor and buffer types in function signatures and at call-sites. This is one of the major building blocks to move in the direction of module-level one-shot-bufferization support. To achieve this, `BufferizationOptions::FunctionArgTypeConverterFn` callback is converted to work with tensor-like and buffer-like types, instead of the builtin counterparts. The default behavior for builtins remains unchanged, while custom types by default go through `TensorLikeType::getBufferType()` which is a general conversion interface.
Support custom types (3/N): allow custom tensor and buffer types in function signatures and at call-sites. This is one of the major building blocks to move in the direction of module-level one-shot-bufferization support.
To achieve this,
BufferizationOptions::FunctionArgTypeConverterFn
callback is converted to work with tensor-like and buffer-like types, instead of the builtin counterparts. The default behavior for builtins remains unchanged, while custom types by default go throughTensorLikeType::getBufferType()
which is a general conversion interface.