Skip to content

Commit

Permalink
[mlir][linalg][bufferize] Fix CallOps with non-tensor operands
Browse files Browse the repository at this point in the history
Such CallOps were not handled properly. When computing the new result types (and replacement values) of a CallOp, non-tensor return values were not accounted for.

Differential Revision: https://reviews.llvm.org/D116445
  • Loading branch information
matthias-springer committed Jan 5, 2022
1 parent d716cfc commit a98c5a0
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 32 deletions.
Expand Up @@ -490,6 +490,19 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {

/// Return the index of the parent function's bbArg that is equivalent to the
/// given ReturnOp operand (if any).
static Optional<int64_t>
getEquivalentFuncArgIdx(ModuleBufferizationState &state,
OpOperand &returnOperand) {
FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
// Return value has no equivalent bbArg.
return None;

return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
}

struct CallOpInterface
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
Expand All @@ -515,57 +528,67 @@ struct CallOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
CallOp callOp = cast<CallOp>(op);
unsigned numResults = callOp.getNumResults();
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
"expected Callop to a FuncOp");
"expected CallOp to a FuncOp");
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);

// 1. Filter return types:
// - if the callee is bodiless / external, we cannot inspect it and we
// cannot assume anything. We can just assert that it does not return a
// tensor as this would have to bufferize to "return a memref", whose
// semantics is ill-defined.
// - if the callee has a body, we perform inter-procedural equivalence
// analysis. When successful, a result folds onto an operand. When
// unsuccessful, additional work is needed (TODO) to either:
// * hoist a result into an inplaceable operand or
// * devise a better representation to truly return a buffer.
// Result types of the bufferized CallOp.
SmallVector<Type> resultTypes;
// Replacement values for the existing CallOp. These are usually the results
// of the bufferized CallOp, unless a tensor result folds onto an operand.
SmallVector<Value> replacementValues(numResults, Value());
// For non-tensor results: A mapping from return val indices of the old
// CallOp to return val indices of the bufferized CallOp.
SmallVector<Optional<unsigned>> retValMapping(numResults, None);

if (funcOp.body().empty()) {
if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
return callOp->emitError()
<< "cannot bufferize bodiless function that returns a tensor";
// The callee is bodiless / external, so we cannot inspect it and we
// cannot assume anything. We can just assert that it does not return a
// tensor as this would have to bufferize to "return a memref", whose
// semantics is ill-defined.
for (int i = 0; i < numResults; ++i) {
Type returnType = callOp.getResult(i).getType();
if (isaTensor(returnType))
return callOp->emitError()
<< "cannot bufferize bodiless function that returns a tensor";
resultTypes.push_back(returnType);
retValMapping[i] = i;
}
} else {
// The callee has a body. Based on previously gathered equivalence
// information, we know if a tensor result folds onto an operand. These
// are the only tensor value returns that are supported at the moment.
//
// For tensors return values that do not fold onto an operand, additional
// work is needed (TODO) to either:
// * hoist a result into an inplaceable operand or
// * devise a better representation to truly return a buffer.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");

// For each FuncOp result, keep track of which inplace argument it reuses.
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
unsigned returnIdx = returnOperand.getOperandNumber();
Type returnType = returnOperand.get().getType();
if (!isaTensor(returnType)) {
// Non-tensor values are returned.
retValMapping[returnIdx] = resultTypes.size();
resultTypes.push_back(returnType);
continue;
}

// If return operand is equivalent to some bbArg, no need to return it.
if (moduleState.equivalentFuncArgs[funcOp].count(
returnOperand.getOperandNumber())) {
int64_t idx =
moduleState
.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
Value buffer = state.lookupBuffer(rewriter, callOp->getOperand(idx));
// Add a ToTensorOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
// This ToTensorOp must fold/DCE away or bufferization should be
// considered failed.
Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(toTensorOp);
if (Optional<int64_t> bbArgIdx =
getEquivalentFuncArgIdx(moduleState, returnOperand)) {
// Return operands that are equivalent to some bbArg, are not
// returned.
replacementValues[returnIdx] =
state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
continue;
}

resultTypes.push_back(returnType);
llvm_unreachable("returning non-equivalent tensors not supported");
}
}

Expand Down Expand Up @@ -612,8 +635,13 @@ struct CallOpInterface
callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());

// 5. Delete the op at the end of bufferization.
callOp->erase();
// 5. Replace the old op with the new op.
for (int i = 0; i < replacementValues.size(); ++i) {
if (replacementValues[i])
continue;
replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
}
state.replaceOp(rewriter, callOp, replacementValues);

return success();
}
Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Expand Up @@ -1000,6 +1000,34 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},

// -----

// CHECK-LABEL: func @inner_func(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
// CHECK-NOT: copy
%f = arith.constant 1.0 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// CHECK: memref.store %{{.*}}, %[[arg0]]
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
// CHECK: %[[load:.*]] = memref.load %[[arg0]]
%1 = tensor.extract %0[%c1] : tensor<?xf32>
// CHECK: return %[[load]] : f32
return %0, %1 : tensor<?xf32>, f32
}

// CHECK-LABEL: func @call_func_with_non_tensor_return(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
func @call_func_with_non_tensor_return(
%t0: tensor<?xf32> {linalg.inplaceable = true}) -> (f32, tensor<?xf32>) {
// CHECK-NOT: copy
// CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
// CHECK: return %[[call]] : f32
return %1, %0 : f32, tensor<?xf32>
}

// -----

// CHECK-LABEL: func @func_without_tensor_args
func @func_without_tensor_args(%v : vector<10xf32>) -> () {
// CHECK: %[[alloc:.*]] = memref.alloc()
Expand Down

0 comments on commit a98c5a0

Please sign in to comment.