@@ -260,12 +260,12 @@ struct BufferizationOptions
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
/// Tensor -> MemRef type converter.
/// Parameters: tensor type, memory space, func op, bufferization options
/// Tensor-like -> Buffer-like type conversion.
/// Parameters: tensor-like type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
/// Tensor -> MemRef type converter.
/// Tensor -> MemRef type conversion.
/// Parameters: tensor type, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
TensorType, Attribute memorySpace, const BufferizationOptions &)>;
@@ -335,10 +335,12 @@ struct BufferizationOptions
/// predictable.
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);

/// Type converter from tensors to memrefs. This type converter is used to
/// determine bufferized function argument and result types. By default, a
/// type converter that returns a memref type with a fully dynamic layout map
/// is used.
/// Type conversion from tensors to buffers. This type conversion is used to
/// determine bufferized function argument and result types.
///
/// By default, if the tensor is a (builtin) tensor type, it is converted to
/// a memref type with a fully dynamic layout map; if it is a (generic)
/// tensor-like type, it is converted using TensorLikeType::getBufferType().
///
/// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
@@ -350,10 +352,9 @@ struct BufferizationOptions
/// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
bool inferFunctionResultLayout = true;

/// Type converter from tensors to memrefs. This type converter is used if no
/// memref type could be inferred during bufferization. By default, a type
/// converter that returns a memref type with a fully dynamic layout map is
/// used.
/// Type conversion from tensors to memrefs. This type conversion is used if
/// no memref type could be inferred during bufferization. By default, it
/// returns a memref type with a fully dynamic layout map.
UnknownTypeConverterFn unknownTypeConverterFn = nullptr;

// Use during type conversion to determine the memory space for memref based
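To make the new callback shape concrete, here is a minimal sketch of installing a custom function-boundary converter with the updated `FunctionArgTypeConverterFn` signature. The include set, the identity-layout choice, and the helper name `installCustomBoundaryConversion` are illustrative assumptions, not part of this patch.

```cpp
// Sketch only: assumes the usual bufferization headers and the new
// TensorLikeType/BufferLikeType callback signature introduced here.
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include <cassert>

using namespace mlir;
using namespace mlir::bufferization;

static void installCustomBoundaryConversion(BufferizationOptions &options) {
  options.functionArgTypeConverterFn =
      [](TensorLikeType type, Attribute memorySpace, func::FuncOp funcOp,
         const BufferizationOptions &opts) -> BufferLikeType {
    // Builtin tensors: use an identity layout (illustrative choice; the
    // default converter picks a fully dynamic layout instead).
    if (auto tensorType = dyn_cast<TensorType>(type))
      return cast<BufferLikeType>(
          getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace));
    // Any other tensor-like type: defer to its own buffer mapping.
    auto bufferType =
        type.getBufferType(opts, [&]() { return funcOp->emitError(); });
    assert(succeeded(bufferType) && "expected a bufferizable tensor-like type");
    return *bufferType;
  };
}
```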
39 changes: 30 additions & 9 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -338,11 +338,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
namespace {

/// Default function arg type converter: Use a fully dynamic layout map.
BaseMemRefType
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
BufferLikeType
defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace,
func::FuncOp funcOp,
const BufferizationOptions &options) {
return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
return cast<BufferLikeType>(
getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
}

// If not a builtin tensor, fall back to TensorLikeType::getBufferType().
auto bufferType =
type.getBufferType(options, [&]() { return funcOp->emitError(); });
assert(succeeded(bufferType) &&
"a valid buffer is always expected at function boundary");
return *bufferType;
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
@@ -385,14 +395,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {

void BufferizationOptions::setFunctionBoundaryTypeConversion(
LayoutMapOption layoutMapOption) {
functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace,
func::FuncOp funcOp,
const BufferizationOptions &options) {
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
memorySpace);
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
memorySpace);
if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
return cast<BufferLikeType>(
bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
memorySpace));
return cast<BufferLikeType>(
bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
memorySpace));
}

// If not a builtin tensor, fall back to TensorLikeType::getBufferType().
auto bufferType =
type.getBufferType(options, [&]() { return funcOp->emitError(); });
assert(succeeded(bufferType) &&
"a valid buffer is always expected at function boundary");
return *bufferType;
};
inferFunctionResultLayout =
layoutMapOption == LayoutMapOption::InferLayoutMap;
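For context, a hedged sketch of how these options are typically assembled when bufferizing across function boundaries. Only `bufferizeFunctionBoundaries` and `setFunctionBoundaryTypeConversion` are taken from the surrounding API; the helper name and the chosen layout option are assumptions.

```cpp
// Sketch: One-Shot Bufferize options that exercise the function-boundary
// conversion updated in this patch.
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"

using namespace mlir;
using namespace mlir::bufferization;

OneShotBufferizationOptions makeBoundaryBufferizationOptions() {
  OneShotBufferizationOptions options;
  // Bufferize function signatures, not only op bodies.
  options.bufferizeFunctionBoundaries = true;
  // Installs functionArgTypeConverterFn: builtin tensors get the requested
  // memref layout, other tensor-like types go through getBufferType().
  options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
  return options;
}
```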
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -401,7 +401,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
// Compute the new signature.
SmallVector<Type> newTypes;
for (BlockArgument &bbArg : block->getArguments()) {
auto tensorType = dyn_cast<TensorType>(bbArg.getType());
auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
if (!tensorType) {
newTypes.push_back(bbArg.getType());
continue;
@@ -49,29 +49,47 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
#endif // NDEBUG
}

// Note: this is a local adaptor to unify TensorType and TensorLikeType code
// paths that both work with BufferizationOptions.
static mlir::Attribute
getDefaultMemorySpace(const BufferizationOptions &options,
TensorLikeType type) {
if (auto tensorType = dyn_cast<TensorType>(type)) {
return *options.defaultMemorySpaceFn(tensorType);
}
return nullptr;
}
Review comment from the PR author:
note: I guess we could also make defaultMemorySpaceFn work with tensor-like types, but I'd prefer to do that separately to simplify the update for users.

otoh, TensorLikeType::getBufferType() already has access to the tensor type and can infer the memory space from it, so I am not sure why we need it explicitly at all.


/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user (as per `options.functionArgTypeConverterFn`).
static BaseMemRefType
static BufferLikeType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
const BufferizationOptions &options) {
auto tensorType =
dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
assert(tensorType && "expected TensorType");

BaseMemRefType memrefType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);

auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
index, BufferizationDialect::kBufferLayoutAttrName);
if (!layoutAttr)
return memrefType;

auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
return MemRefType::get(rankedMemrefType.getShape(),
rankedMemrefType.getElementType(), layoutAttr,
rankedMemrefType.getMemorySpace());
auto type =
dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
assert(type && "expected TensorLikeType");

// Note: For builtin tensors there is additional logic related to layout.
if (auto tensorType = dyn_cast<TensorType>(type)) {
BufferLikeType memrefType = options.functionArgTypeConverterFn(
type, *options.defaultMemorySpaceFn(tensorType), funcOp, options);

auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
index, BufferizationDialect::kBufferLayoutAttrName);
if (!layoutAttr)
return memrefType;

auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
assert(rankedMemrefType &&
"buffer layout not supported on unranked tensors");
return cast<BufferLikeType>(MemRefType::get(
rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
layoutAttr, rankedMemrefType.getMemorySpace()));
}

return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp,
options);
}

/// Return the FuncOp called by `callOp`.
@@ -227,13 +245,13 @@ struct CallOpInterface
FunctionType funcType = funcOp.getFunctionType();
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
return cast<BufferLikeType>(bufferizedType);
if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
return bufferizedType;

// Otherwise, call the type converter to compute the bufferized type.
auto tensorType = cast<TensorType>(resultType);
auto tensorType = cast<TensorLikeType>(resultType);
return cast<BufferLikeType>(options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
options));
}

@@ -248,7 +266,7 @@ struct CallOpInterface
SmallVector<Type> resultTypes;
for (Value result : callOp.getResults()) {
Type returnType = result.getType();
if (!isa<TensorType>(returnType)) {
if (!isa<TensorLikeType>(returnType)) {
// Non-tensor values are returned.
resultTypes.push_back(returnType);
continue;
@@ -272,7 +290,7 @@ struct CallOpInterface

for (OpOperand &opOperand : callOp->getOpOperands()) {
// Non-tensor operands are just copied.
if (!isa<TensorType>(opOperand.get().getType())) {
if (!isa<TensorLikeType>(opOperand.get().getType())) {
newOperands.push_back(opOperand.get());
continue;
}
@@ -285,8 +303,8 @@ struct CallOpInterface
Value buffer = *maybeBuffer;

// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
if (!isa<BaseMemRefType>(memRefType)) {
auto bufferType = funcType.getInput(opOperand.getOperandNumber());
if (!isa<BufferLikeType>(bufferType)) {
// The called function was not bufferized yet. This can happen when
// there cycles in the function call graph. Compute the bufferized
// result type.
@@ -296,7 +314,7 @@ struct CallOpInterface
state);
if (failed(maybeBufferType))
return failure();
memRefType = *maybeBufferType;
bufferType = *maybeBufferType;
}

// Since we don't yet have a clear layout story, to_buffer may
@@ -305,8 +323,8 @@ struct CallOpInterface
// that will either canonicalize away or fail compilation until we can do
// something better. Insert a reallocation + copy if it cannot be
// statically guaranteed that a direct cast would be valid.
if (buffer.getType() != memRefType) {
auto memrefDstType = dyn_cast<MemRefType>(memRefType);
if (buffer.getType() != bufferType) {
auto memrefDstType = dyn_cast<MemRefType>(bufferType);
assert(memrefDstType &&
"buffer layout not supported on unranked tensors");
FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
@@ -370,7 +388,7 @@ struct FuncOpInterface
static bool supportsUnstructuredControlFlow() { return true; }

bool hasTensorSemantics(Operation *op) const {
auto isaTensor = llvm::IsaPred<TensorType>;
auto isaTensor = llvm::IsaPred<TensorLikeType>;

// A function has tensor semantics if it has tensor arguments/results.
auto funcOp = cast<FuncOp>(op);
@@ -406,8 +424,8 @@ struct FuncOpInterface

// Function arguments are special.
if (bbArg.getOwner() == &funcOp.getBody().front())
return cast<BufferLikeType>(
getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
options);

return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
getBufferType(op, value, options, state, invocationStack);
@@ -430,7 +448,7 @@ struct FuncOpInterface
SmallVector<Type> argTypes;
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
Type argType = it.value();
if (isa<TensorType>(argType)) {
if (isa<TensorLikeType>(argType)) {
argTypes.push_back(
getBufferizedFunctionArgType(funcOp, it.index(), options));
continue;
@@ -441,9 +459,9 @@ struct FuncOpInterface
// Compute the result types.
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
if (auto tensorType = dyn_cast<TensorType>(resultType)) {
BaseMemRefType resultType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
BufferLikeType resultType = options.functionArgTypeConverterFn(
tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
options);
retTypes.push_back(resultType);
continue;
@@ -473,7 +491,7 @@ struct FuncOpInterface
SmallVector<Value> returnValues;
for (auto [returnVal, bufferizedType] :
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);

// If not a tensor type just forward it.
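The fallback paths above only work if the custom tensor type implements the TensorLikeType buffer mapping (as the test dialect's `!test.test_tensor` used in the tests below presumably does). A rough, hypothetical sketch for an imaginary `MyTensorType` follows; the method signature is inferred from the `getBufferType(options, emitError)` calls in this patch and may not match the real interface declaration verbatim.

```cpp
// Hypothetical out-of-tree type: maps !my.tensor<...> 1:1 onto !my.memref<...>.
// MyTensorType, MyMemRefType, and the exact signature are assumptions.
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "MyDialectTypes.h" // hypothetical dialect header

using namespace mlir;
using namespace mlir::bufferization;

FailureOr<BufferLikeType>
MyTensorType::getBufferType(const BufferizationOptions &options,
                            function_ref<InFlightDiagnostic()> emitError) const {
  if (!hasStaticShape()) {
    emitError() << "expected a statically shaped tensor";
    return failure();
  }
  return cast<BufferLikeType>(
      MyMemRefType::get(getContext(), getShape(), getElementType()));
}
```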
@@ -810,3 +810,59 @@ module @inner_module {
return %t : tensor<5xf32>
}
}

// -----

// CHECK: func.func @custom_types(
// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
// CHECK-SAME: ) -> (!test.test_memref<[4, 8], f64>,
// CHECK-SAME: !test.test_memref<[4, 8], f64>)
func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>)
-> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) {
// CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) :
// CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
%out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
-> !test.test_tensor<[4, 8], f64>

// CHECK: %[[alloc:.*]] = "test.create_memref_op"
// CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]])
// CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
%alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64>
%out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>)
-> !test.test_tensor<[4, 8], f64>

// CHECK: return %[[out1]], %[[out2]]
return %out1, %out2 :
!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>
}

// -----

// CHECK: func.func @custom_types_foo(
// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>)
-> !test.test_tensor<[4, 4], f64> {
// CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]])
%out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
-> !test.test_tensor<[4, 4], f64>
// CHECK: return %[[out]]
return %out : !test.test_tensor<[4, 4], f64>
}

// CHECK: func.func @custom_types_bar(
// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
// CHECK-SAME: ) -> !test.test_memref<[4, 8], f64>
func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
-> !test.test_tensor<[4, 8], f64> {
// CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]])
%call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>)
-> !test.test_tensor<[4, 4], f64>

// CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]])
%out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>)
-> !test.test_tensor<[4, 8], f64>

// CHECK: return %[[out]]
return %out : !test.test_tensor<[4, 8], f64>
}