diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h index 216948b2b3df55..0b035836ec61dc 100644 --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -86,6 +86,15 @@ class SymbolTable { /// nullptr if no valid parent symbol table could be found. static Operation *getNearestSymbolTable(Operation *from); + /// Walks all symbol table operations nested within, and including, `op`. For + /// each symbol table operation, the provided callback is invoked with the op + /// and a boolean signifying if the symbols within that symbol table can be + /// treated as if all uses within the IR are visible to the caller. + /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols + /// within `op` are visible. + static void walkSymbolTables(Operation *op, bool allSymUsesVisible, + function_ref callback); + /// Returns the operation registered with the given symbol name with the /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation /// with the 'OpTrait::SymbolTable' trait. diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td index 0ff189de68000c..81ab52f197aafb 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -34,7 +34,8 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { InterfaceMethod<[{ Returns the callee of this call-like operation. A `callee` is either a reference to a symbol, via SymbolRefAttr, or a reference to a defined - SSA value. + SSA value. If the reference is an SSA value, the SSA value corresponds + to a region of a lambda-like operation. }], "CallInterfaceCallable", "getCallableForCallee" >, diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 487b51de8dc97c..dc4186eaf1296b 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -207,6 +207,35 @@ Operation *SymbolTable::getNearestSymbolTable(Operation *from) { return from; } +/// Walks all symbol table operations nested within, and including, `op`. For +/// each symbol table operation, the provided callback is invoked with the op +/// and a boolean signifying if the symbols within that symbol table can be +/// treated as if all uses are visible. `allSymUsesVisible` identifies whether +/// all of the symbol uses of symbols within `op` are visible. +void SymbolTable::walkSymbolTables( + Operation *op, bool allSymUsesVisible, + function_ref callback) { + bool isSymbolTable = op->hasTrait(); + if (isSymbolTable) { + SymbolOpInterface symbol = dyn_cast(op); + allSymUsesVisible |= !symbol || symbol.isPrivate(); + } else { + // Otherwise if 'op' is not a symbol table, any nested symbols are + // guaranteed to be hidden. + allSymUsesVisible = true; + } + + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Operation &nestedOp : block) + walkSymbolTables(&nestedOp, allSymUsesVisible, callback); + + // If 'op' had the symbol table trait, visit it after any nested symbol + // tables. + if (isSymbolTable) + callback(op, allSymUsesVisible); +} + /// Returns the operation registered with the given symbol name with the /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index 28c8216f833394..c0f89da300f1e1 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -31,29 +31,6 @@ using namespace mlir; // Symbol Use Tracking //===----------------------------------------------------------------------===// -/// Walk all of the symbol table operations nested with 'op' along with a -/// boolean signifying if the symbols within can be treated as if all uses are -/// visible. The provided callback is invoked with the symbol table operation, -/// and a boolean signaling if all of the uses within the symbol table are -/// visible. -static void walkSymbolTables(Operation *op, bool allSymUsesVisible, - function_ref callback) { - if (op->hasTrait()) { - SymbolOpInterface symbol = dyn_cast(op); - allSymUsesVisible = allSymUsesVisible || !symbol || symbol.isPrivate(); - callback(op, allSymUsesVisible); - } else { - // Otherwise if 'op' is not a symbol table, any nested symbols are - // guaranteed to be hidden. - allSymUsesVisible = true; - } - - for (Region ®ion : op->getRegions()) - for (Block &block : region) - for (Operation &nested : block) - walkSymbolTables(&nested, allSymUsesVisible, callback); -} - /// Walk all of the used symbol callgraph nodes referenced with the given op. static void walkReferencedSymbolNodes( Operation *op, CallGraph &cg, @@ -164,7 +141,8 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) { } } }; - walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), walkFn); + SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), + walkFn); // Drop the use information for any discardable nodes that are always live. for (auto &it : alwaysLiveNodes) diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp index 1d0a279cc592bc..c9fc4ba2f39530 100644 --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -116,12 +116,56 @@ class LatticeValue { Dialect *constantDialect; }; +/// This class contains various state used when computing the lattice of a +/// callable operation. +class CallableLatticeState { +public: + /// Build a lattice state with a given callable region, and a specified number + /// of results to be initialized to the default lattice value (Unknown). + CallableLatticeState(Region *callableRegion, unsigned numResults) + : callableArguments(callableRegion->front().getArguments()), + resultLatticeValues(numResults) {} + + /// Returns the arguments to the callable region. + Block::BlockArgListType getCallableArguments() const { + return callableArguments; + } + + /// Returns the lattice value for the results of the callable region. + MutableArrayRef getResultLatticeValues() { + return resultLatticeValues; + } + + /// Add a call to this callable. This is only used if the callable defines a + /// symbol. + void addSymbolCall(Operation *op) { symbolCalls.push_back(op); } + + /// Return the calls that reference this callable. This is only used + /// if the callable defines a symbol. + ArrayRef getSymbolCalls() const { return symbolCalls; } + +private: + /// The arguments of the callable region. + Block::BlockArgListType callableArguments; + + /// The lattice state for each of the results of this region. The return + /// values of the callable aren't SSA values, so we need to track them + /// separately. + SmallVector resultLatticeValues; + + /// The calls referencing this callable if this callable defines a symbol. + /// This removes the need to recompute symbol references during propagation. + /// Value based references are trivial to resolve, so they can be done + /// in-place. + SmallVector symbolCalls; +}; + /// This class represents the solver for the SCCP analysis. This class acts as /// the propagation engine for computing which values form constants. class SCCPSolver { public: - /// Initialize the solver with a given set of regions. - SCCPSolver(MutableArrayRef regions); + /// Initialize the solver with the given top-level operation. + SCCPSolver(Operation *op); /// Run the solver until it converges. void solve(); @@ -132,6 +176,11 @@ class SCCPSolver { void rewrite(MLIRContext *context, MutableArrayRef regions); private: + /// Initialize the set of symbol defining callables that can have their + /// arguments and results tracked. 'op' is the top-level operation that SCCP + /// is operating on. + void initializeSymbolCallables(Operation *op); + /// Replace the given value with a constant if the corresponding lattice /// represents a constant. Returns success if the value was replaced, failure /// otherwise. @@ -149,6 +198,13 @@ class SCCPSolver { /// Visit the given operation and compute any necessary lattice state. void visitOperation(Operation *op); + /// Visit the given call operation and compute any necessary lattice state. + void visitCallOperation(CallOpInterface op); + + /// Visit the given callable operation and compute any necessary lattice + /// state. + void visitCallableOperation(Operation *op); + /// Visit the given operation, which defines regions, and compute any /// necessary lattice state. This also resolves the lattice state of both the /// operation results and any nested regions. @@ -168,6 +224,11 @@ class SCCPSolver { void visitTerminatorOperation(Operation *op, ArrayRef constantOperands); + /// Visit the given terminator operation that exits a callable region. These + /// are terminators with no CFG successors. + void visitCallableTerminatorOperation(Operation *callable, + Operation *terminator); + /// Visit the given block and compute any necessary lattice state. void visitBlock(Block *block); @@ -235,11 +296,20 @@ class SCCPSolver { /// A worklist of operations that need to be processed. SmallVector opWorklist; + + /// The callable operations that have their argument/result state tracked. + DenseMap callableLatticeState; + + /// A map between a call operation and the resolved symbol callable. This + /// avoids re-resolving symbol references during propagation. Value based + /// callables are trivial to resolve, so they can be done in-place. + DenseMap callToSymbolCallable; }; } // end anonymous namespace -SCCPSolver::SCCPSolver(MutableArrayRef regions) { - for (Region ®ion : regions) { +SCCPSolver::SCCPSolver(Operation *op) { + /// Initialize the solver with the regions within this operation. + for (Region ®ion : op->getRegions()) { if (region.empty()) continue; Block *entryBlock = ®ion.front(); @@ -251,6 +321,7 @@ SCCPSolver::SCCPSolver(MutableArrayRef regions) { // as overdefined. markAllOverdefined(entryBlock->getArguments()); } + initializeSymbolCallables(op); } void SCCPSolver::solve() { @@ -310,6 +381,73 @@ void SCCPSolver::rewrite(MLIRContext *context, } } +void SCCPSolver::initializeSymbolCallables(Operation *op) { + // Initialize the set of symbol callables that can have their state tracked. + // This tracks which symbol callable operations we can propagate within and + // out of. + auto walkFn = [&](Operation *symTable, bool allUsesVisible) { + Region &symbolTableRegion = symTable->getRegion(0); + Block *symbolTableBlock = &symbolTableRegion.front(); + for (auto callable : symbolTableBlock->getOps()) { + // We won't be able to track external callables. + Region *callableRegion = callable.getCallableRegion(); + if (!callableRegion) + continue; + // We only care about symbol defining callables here. + auto symbol = dyn_cast(callable.getOperation()); + if (!symbol) + continue; + callableLatticeState.try_emplace(callable, callableRegion, + callable.getCallableResults().size()); + + // If not all of the uses of this symbol are visible, we can't track the + // state of the arguments. + if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) + markAllOverdefined(callableRegion->front().getArguments()); + } + if (callableLatticeState.empty()) + return; + + // After computing the valid callables, walk any symbol uses to check + // for non-call references. We won't be able to track the lattice state + // for arguments to these callables, as we can't guarantee that we can see + // all of its calls. + Optional uses = + SymbolTable::getSymbolUses(&symbolTableRegion); + if (!uses) { + // If we couldn't gather the symbol uses, conservatively assume that + // we can't track information for any nested symbols. + op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); }); + return; + } + + for (const SymbolTable::SymbolUse &use : *uses) { + // If the use is a call, track it to avoid the need to recompute the + // reference later. + if (auto callOp = dyn_cast(use.getUser())) { + Operation *symCallable = callOp.resolveCallable(); + auto callableLatticeIt = callableLatticeState.find(symCallable); + if (callableLatticeIt != callableLatticeState.end()) { + callToSymbolCallable.try_emplace(callOp, symCallable); + + // We only need to record the call in the lattice if it produces any + // values. + if (callOp.getOperation()->getNumResults()) + callableLatticeIt->second.addSymbolCall(callOp); + } + continue; + } + // This use isn't a call, so don't we know all of the callers. + auto *symbol = SymbolTable::lookupSymbolIn(op, use.getSymbolRef()); + auto it = callableLatticeState.find(symbol); + if (it != callableLatticeState.end()) + markAllOverdefined(it->second.getCallableArguments()); + } + }; + SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), + walkFn); +} + LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder, OperationFolder &folder, Value value) { @@ -347,6 +485,16 @@ void SCCPSolver::visitOperation(Operation *op) { if (op->isKnownTerminator()) visitTerminatorOperation(op, operandConstants); + // Process call operations. The call visitor processes result values, so we + // can exit afterwards. + if (CallOpInterface call = dyn_cast(op)) + return visitCallOperation(call); + + // Process callable operations. These are specially handled region operations + // that track dataflow via calls. + if (isa(op)) + return visitCallableOperation(op); + // Process region holding operations. The region visitor processes result // values, so we can exit afterwards. if (op->getNumRegions()) @@ -399,6 +547,62 @@ void SCCPSolver::visitOperation(Operation *op) { } } +void SCCPSolver::visitCallableOperation(Operation *op) { + // Mark the regions as executable. + bool isTrackingLatticeState = callableLatticeState.count(op); + for (Region ®ion : op->getRegions()) { + if (region.empty()) + continue; + Block *entryBlock = ®ion.front(); + markBlockExecutable(entryBlock); + + // If we aren't tracking lattice state for this callable, mark all of the + // region arguments as overdefined. + if (!isTrackingLatticeState) + markAllOverdefined(entryBlock->getArguments()); + } + + // TODO: Add support for non-symbol callables when necessary. If the callable + // has non-call uses we would mark overdefined, otherwise allow for + // propagating the return values out. + markAllOverdefined(op, op->getResults()); +} + +void SCCPSolver::visitCallOperation(CallOpInterface op) { + ResultRange callResults = op.getOperation()->getResults(); + + // Resolve the callable operation for this call. + Operation *callableOp = nullptr; + if (Value callableValue = op.getCallableForCallee().dyn_cast()) + callableOp = callableValue.getDefiningOp(); + else + callableOp = callToSymbolCallable.lookup(op); + + // The callable of this call can't be resolved, mark any results overdefined. + if (!callableOp) + return markAllOverdefined(op, callResults); + + // If this callable is tracking state, merge the argument operands with the + // arguments of the callable. + auto callableLatticeIt = callableLatticeState.find(callableOp); + if (callableLatticeIt == callableLatticeState.end()) + return markAllOverdefined(op, callResults); + + OperandRange callOperands = op.getArgOperands(); + auto callableArgs = callableLatticeIt->second.getCallableArguments(); + for (auto it : llvm::zip(callOperands, callableArgs)) { + BlockArgument callableArg = std::get<1>(it); + if (latticeValues[callableArg].meet(latticeValues[std::get<0>(it)])) + visitUsers(callableArg); + } + + // Merge in the lattice state for the callable results as well. + auto callableResults = callableLatticeIt->second.getResultLatticeValues(); + for (auto it : llvm::zip(callResults, callableResults)) + meet(/*owner=*/op, /*to=*/latticeValues[std::get<0>(it)], + /*from=*/std::get<1>(it)); +} + void SCCPSolver::visitRegionOperation(Operation *op, ArrayRef constantOperands) { // Check to see if we can reason about the internal control flow of this @@ -509,9 +713,14 @@ void SCCPSolver::visitTerminatorOperation( Operation *op, ArrayRef constantOperands) { // If this operation has no successors, we treat it as an exiting terminator. if (op->getNumSuccessors() == 0) { - // Check to see if the parent tracks region control flow. Region *parentRegion = op->getParentRegion(); Operation *parentOp = parentRegion->getParentOp(); + + // Check to see if this is a terminator for a callable region. + if (isa(parentOp)) + return visitCallableTerminatorOperation(parentOp, op); + + // Otherwise, check to see if the parent tracks region control flow. auto regionInterface = dyn_cast(parentOp); if (!regionInterface || !isBlockExecutable(parentOp->getBlock())) return; @@ -552,6 +761,42 @@ void SCCPSolver::visitTerminatorOperation( markEdgeExecutable(block, succ); } +void SCCPSolver::visitCallableTerminatorOperation(Operation *callable, + Operation *terminator) { + // If there are no exiting values, we have nothing to track. + if (terminator->getNumOperands() == 0) + return; + + // If this callable isn't tracking any lattice state there is nothing to do. + auto latticeIt = callableLatticeState.find(callable); + if (latticeIt == callableLatticeState.end()) + return; + assert(callable->getNumResults() == 0 && "expected symbol callable"); + + // If this terminator is not "return-like", conservatively mark all of the + // call-site results as overdefined. + auto callableResultLattices = latticeIt->second.getResultLatticeValues(); + if (!terminator->hasTrait()) { + for (auto &it : callableResultLattices) + it.markOverdefined(); + for (Operation *call : latticeIt->second.getSymbolCalls()) + markAllOverdefined(call, call->getResults()); + return; + } + + // Merge the terminator operands into the results. + bool anyChanged = false; + for (auto it : llvm::zip(terminator->getOperands(), callableResultLattices)) + anyChanged |= std::get<1>(it).meet(latticeValues[std::get<0>(it)]); + if (!anyChanged) + return; + + // If any of the result lattices changed, update the callers. + for (Operation *call : latticeIt->second.getSymbolCalls()) + for (auto it : llvm::zip(call->getResults(), callableResultLattices)) + meet(call, latticeValues[std::get<0>(it)], std::get<1>(it)); +} + void SCCPSolver::visitBlock(Block *block) { // If the block is not the entry block we need to compute the lattice state // for the block arguments. Entry block argument lattices are computed @@ -663,7 +908,7 @@ void SCCP::runOnOperation() { Operation *op = getOperation(); // Solve for SCCP constraints within nested regions. - SCCPSolver solver(op->getRegions()); + SCCPSolver solver(op); solver.solve(); // Cleanup any operations using the solver analysis. diff --git a/mlir/test/Transforms/sccp-callgraph.mlir b/mlir/test/Transforms/sccp-callgraph.mlir new file mode 100644 index 00000000000000..5d47a277df931c --- /dev/null +++ b/mlir/test/Transforms/sccp-callgraph.mlir @@ -0,0 +1,257 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -sccp -split-input-file | FileCheck %s -dump-input-on-failure +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="module(sccp)" -split-input-file | FileCheck %s --check-prefix=NESTED -dump-input-on-failure + +/// Check that a constant is properly propagated through the arguments and +/// results of a private function. + +// CHECK-LABEL: func @private( +func @private(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK: return %[[CST]] : i32 + + return %arg0 : i32 +} + +// CHECK-LABEL: func @simple_private( +func @simple_private() -> i32 { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK: return %[[CST]] : i32 + + %1 = constant 1 : i32 + %result = call @private(%1) : (i32) -> i32 + return %result : i32 +} + +// ----- + +/// Check that a constant is properly propagated through the arguments and +/// results of a visible nested function. + +// CHECK: func @nested( +func @nested(%arg0 : i32) -> i32 attributes { sym_visibility = "nested" } { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK: return %[[CST]] : i32 + + return %arg0 : i32 +} + +// CHECK-LABEL: func @simple_nested( +func @simple_nested() -> i32 { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK: return %[[CST]] : i32 + + %1 = constant 1 : i32 + %result = call @nested(%1) : (i32) -> i32 + return %result : i32 +} + +// ----- + +/// Check that non-visible nested functions do not track arguments. +module { + // NESTED-LABEL: module @nested_module + module @nested_module attributes { sym_visibility = "public" } { + + // NESTED: func @nested( + func @nested(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "nested" } { + // NESTED: %[[CST:.*]] = constant 1 : i32 + // NESTED: return %[[CST]], %arg0 : i32, i32 + + %1 = constant 1 : i32 + return %1, %arg0 : i32, i32 + } + + // NESTED: func @nested_not_all_uses_visible( + func @nested_not_all_uses_visible() -> (i32, i32) { + // NESTED: %[[CST:.*]] = constant 1 : i32 + // NESTED: %[[CALL:.*]]:2 = call @nested + // NESTED: return %[[CST]], %[[CALL]]#1 : i32, i32 + + %1 = constant 1 : i32 + %result:2 = call @nested(%1) : (i32) -> (i32, i32) + return %result#0, %result#1 : i32, i32 + } + } +} + +// ----- + +/// Check that public functions do not track arguments. + +// CHECK-LABEL: func @public( +func @public(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "public" } { + %1 = constant 1 : i32 + return %1, %arg0 : i32, i32 +} + +// CHECK-LABEL: func @simple_public( +func @simple_public() -> (i32, i32) { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK: %[[CALL:.*]]:2 = call @public + // CHECK: return %[[CST]], %[[CALL]]#1 : i32, i32 + + %1 = constant 1 : i32 + %result:2 = call @public(%1) : (i32) -> (i32, i32) + return %result#0, %result#1 : i32, i32 +} + +// ----- + +/// Check that functions with non-call users don't have arguments tracked. + +func @callable(%arg0 : i32) -> (i32, i32) attributes { sym_visibility = "private" } { + %1 = constant 1 : i32 + return %1, %arg0 : i32, i32 +} + +// CHECK-LABEL: func @non_call_users( +func @non_call_users() -> (i32, i32) { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK: %[[CALL:.*]]:2 = call @callable + // CHECK: return %[[CST]], %[[CALL]]#1 : i32, i32 + + %1 = constant 1 : i32 + %result:2 = call @callable(%1) : (i32) -> (i32, i32) + return %result#0, %result#1 : i32, i32 +} + +"live.user"() {uses = [@callable]} : () -> () + +// ----- + +/// Check that return values are overdefined in the presence of an unknown terminator. + +func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } { + "unknown.return"(%arg0) : (i32) -> () +} + +// CHECK-LABEL: func @unknown_terminator( +func @unknown_terminator() -> i32 { + // CHECK: %[[CALL:.*]] = call @callable + // CHECK: return %[[CALL]] : i32 + + %1 = constant 1 : i32 + %result = call @callable(%1) : (i32) -> i32 + return %result : i32 +} + +// ----- + +/// Check that return values are overdefined when the constant conflicts. + +func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } { + "unknown.return"(%arg0) : (i32) -> () +} + +// CHECK-LABEL: func @conflicting_constant( +func @conflicting_constant() -> (i32, i32) { + // CHECK: %[[CALL1:.*]] = call @callable + // CHECK: %[[CALL2:.*]] = call @callable + // CHECK: return %[[CALL1]], %[[CALL2]] : i32, i32 + + %1 = constant 1 : i32 + %2 = constant 2 : i32 + %result = call @callable(%1) : (i32) -> i32 + %result2 = call @callable(%2) : (i32) -> i32 + return %result, %result2 : i32, i32 +} + +// ----- + +/// Check that return values are overdefined when the constant conflicts with a +/// non-constant. + +func @callable(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } { + "unknown.return"(%arg0) : (i32) -> () +} + +// CHECK-LABEL: func @conflicting_constant( +func @conflicting_constant(%arg0 : i32) -> (i32, i32) { + // CHECK: %[[CALL1:.*]] = call @callable + // CHECK: %[[CALL2:.*]] = call @callable + // CHECK: return %[[CALL1]], %[[CALL2]] : i32, i32 + + %1 = constant 1 : i32 + %result = call @callable(%1) : (i32) -> i32 + %result2 = call @callable(%arg0) : (i32) -> i32 + return %result, %result2 : i32, i32 +} + +// ----- + +/// Check a more complex interaction with calls and control flow. + +// CHECK-LABEL: func @complex_inner_if( +func @complex_inner_if(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } { + // CHECK-DAG: %[[TRUE:.*]] = constant 1 : i1 + // CHECK-DAG: %[[CST:.*]] = constant 1 : i32 + // CHECK: cond_br %[[TRUE]], ^bb1 + + %cst_20 = constant 20 : i32 + %cond = cmpi "ult", %arg0, %cst_20 : i32 + cond_br %cond, ^bb1, ^bb2 + +^bb1: + // CHECK: ^bb1: + // CHECK: return %[[CST]] : i32 + + %cst_1 = constant 1 : i32 + return %cst_1 : i32 + +^bb2: + %cst_1_2 = constant 1 : i32 + %arg_inc = addi %arg0, %cst_1_2 : i32 + return %arg_inc : i32 +} + +func @complex_cond() -> i1 + +// CHECK-LABEL: func @complex_callee( +func @complex_callee(%arg0 : i32) -> i32 attributes { sym_visibility = "private" } { + // CHECK: %[[CST:.*]] = constant 1 : i32 + + %loop_cond = call @complex_cond() : () -> i1 + cond_br %loop_cond, ^bb1, ^bb2 + +^bb1: + // CHECK: ^bb1: + // CHECK-NEXT: return %[[CST]] : i32 + return %arg0 : i32 + +^bb2: + // CHECK: ^bb2: + // CHECK: call @complex_inner_if(%[[CST]]) : (i32) -> i32 + // CHECK: call @complex_callee(%[[CST]]) : (i32) -> i32 + // CHECK: return %[[CST]] : i32 + + %updated_arg = call @complex_inner_if(%arg0) : (i32) -> i32 + %res = call @complex_callee(%updated_arg) : (i32) -> i32 + return %res : i32 +} + +// CHECK-LABEL: func @complex_caller( +func @complex_caller(%arg0 : i32) -> i32 { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK: return %[[CST]] : i32 + + %1 = constant 1 : i32 + %result = call @complex_callee(%1) : (i32) -> i32 + return %result : i32 +} + +// ----- + +/// Check that non-symbol defining callables currently go to overdefined. + +// CHECK-LABEL: func @non_symbol_defining_callable +func @non_symbol_defining_callable() -> i32 { + // CHECK: %[[RES:.*]] = call_indirect + // CHECK: return %[[RES]] : i32 + + %fn = "test.functional_region_op"() ({ + %1 = constant 1 : i32 + "test.return"(%1) : (i32) -> () + }) : () -> (() -> i32) + %res = call_indirect %fn() : () -> (i32) + return %res : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 000a5722a76ae2..ad8c6fb99e6794 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1090,7 +1090,7 @@ def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> { //===----------------------------------------------------------------------===// def TestRegionBuilderOp : TEST_Op<"region_builder">; -def TestReturnOp : TEST_Op<"return", [Terminator]>, +def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]>, Arguments<(ins Variadic)>; def TestCastOp : TEST_Op<"cast">, Arguments<(ins Variadic)>, Results<(outs AnyType)>;