Skip to content

Commit

Permalink
[mlir][CallOpInterface] Add setCalleeFromCallable method
Browse files Browse the repository at this point in the history
Currently `CallOpInterface` has a method `getCallableForCallee` to have a consistent way to get the callee from an operation with `CallOpInterface`, but missing a consistent way to set a callee for an operation with `CallOpInterface`.

A set callee method is useful for transformations that operate on `CallOpInterface`, and change the callee, e.g., a pass that specialize function, which clone the callee, and change the `CallOpInterface`'s callee to the cloned version. Without such method, transformation would need to understand the implementation for every operations with `CallOpInterface`, and have a type switch to handle them.

This review adds a method to set callee for operation with `CallOpInterface`.

Reviewed By: gysit, zero9178o

Differential Revision: https://reviews.llvm.org/D149763
  • Loading branch information
whitneywhtsang committed May 8, 2023
1 parent 9fca031 commit a2ab6a5
Show file tree
Hide file tree
Showing 14 changed files with 99 additions and 0 deletions.
8 changes: 8 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2357,6 +2357,14 @@ def fir_CallOp : fir_Op<"call",
return calling;
return getOperand(0);
}

/// Set the callee for this operation.
void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
if (auto calling =
(*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
(*this)->setAttr(getCalleeAttrName(), callee.get<mlir::SymbolRefAttr>());
setOperand(0, callee.get<mlir::Value>());
}
}];
}

Expand Down
1 change: 1 addition & 0 deletions mlir/docs/Interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,7 @@ interface section goes as follows:

* `CallOpInterface` - Used to represent operations like 'call'
- `CallInterfaceCallable getCallableForCallee()`
- `void setCalleeFromCallable(CallInterfaceCallable)`
* `CallableOpInterface` - Used to represent the target callee of call.
- `Region * getCallableRegion()`
- `ArrayRef<Type> getCallableResults()`
Expand Down
6 changes: 6 additions & 0 deletions mlir/docs/Tutorials/Toy/Ch-4.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for the generic call operation, this is required by the call
/// interface.
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
Expand Down
6 changes: 6 additions & 0 deletions mlir/examples/toy/Ch4/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for the generic call operation, this is required by the call
/// interface.
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
Expand Down
6 changes: 6 additions & 0 deletions mlir/examples/toy/Ch5/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for the generic call operation, this is required by the call
/// interface.
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
Expand Down
6 changes: 6 additions & 0 deletions mlir/examples/toy/Ch6/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for the generic call operation, this is required by the call
/// interface.
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
Expand Down
6 changes: 6 additions & 0 deletions mlir/examples/toy/Ch7/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for the generic call operation, this is required by the call
/// interface.
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,11 @@ def Async_CallOp : Async_Op<"call",
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for this operation.
void setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}
}];

let assemblyFormat = [{
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Func/IR/FuncOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def CallOp : Func_Op<"call",
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for this operation.
void setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}
}];

let assemblyFormat = [{
Expand Down Expand Up @@ -153,6 +158,11 @@ def CallIndirectOp : Func_Op<"call_indirect", [

/// Return the callee of this operation.
CallInterfaceCallable getCallableForCallee() { return getCallee(); }

/// Set the callee for this operation.
void setCalleeFromCallable(CallInterfaceCallable callee) {
setOperand(0, callee.get<Value>());
}
}];

let hasCanonicalizeMethod = 1;
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ def IncludeOp : TransformDialectOp<"include",
return getTarget();
}

void setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
setTargetAttr(callee.get<SymbolRefAttr>());
}

::mlir::Operation::operand_range getArgOperands() {
return getOperands();
}
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Interfaces/CallInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
}],
"::mlir::CallInterfaceCallable", "getCallableForCallee"
>,
InterfaceMethod<[{
Sets the callee of this call-like operation. A `callee` is either a
reference to a symbol, via SymbolRefAttr, or a reference to a defined
SSA value. The type of the `callee` is expected to be the same as the
return type of `getCallableForCallee`, e.g., `callee` should be
SymbolRefAttr for `func.call`.
}],
"void", "setCalleeFromCallable", (ins "::mlir::CallInterfaceCallable":$callee)
>,
InterfaceMethod<[{
Returns the operands within this call that are used as arguments to the
callee.
Expand Down
20 changes: 20 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,16 @@ CallInterfaceCallable CallOp::getCallableForCallee() {
return getOperand(0);
}

void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
// Direct call.
if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
auto symRef = callee.get<SymbolRefAttr>();
return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
}
// Indirect call, callee Value is the first operand.
return setOperand(0, callee.get<Value>());
}

Operation::operand_range CallOp::getArgOperands() {
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
}
Expand Down Expand Up @@ -1157,6 +1167,16 @@ CallInterfaceCallable InvokeOp::getCallableForCallee() {
return getOperand(0);
}

void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
// Direct call.
if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
auto symRef = callee.get<SymbolRefAttr>();
return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
}
// Indirect call, callee Value is the first operand.
return setOperand(0, callee.get<Value>());
}

Operation::operand_range InvokeOp::getArgOperands() {
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
}
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2576,6 +2576,11 @@ CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
}

void spirv::FunctionCallOp::setCalleeFromCallable(
CallInterfaceCallable callee) {
(*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
}

Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
return getArguments();
}
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,18 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
let extraClassDeclaration = [{
/// Return the callee of this operation.
::mlir::CallInterfaceCallable getCallableForCallee();

/// Set the callee for this operation.
void setCalleeFromCallable(::mlir::CallInterfaceCallable);
}];
let extraClassDefinition = [{
::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
return (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee");
}

void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}
}];
}

Expand Down

0 comments on commit a2ab6a5

Please sign in to comment.