diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index e0c65b0e09774..00701f6675bdd 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -235,6 +235,30 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { op->erase(); } +// Remove the dead functions from moduleOp. +static void deleteDeadFunction(Operation *module) { + auto functions = module->getRegion(0).getOps(); + llvm::DenseSet tasks(functions.begin(), functions.end()); + while (!tasks.empty()) { + llvm::DenseSet nextTasks; + for (FunctionOpInterface funcOp : tasks) { + if (funcOp.isPublic() || funcOp.isExternal()) + return; + SymbolTable::UseRange uses = *funcOp.getSymbolUses(module); + auto callSites = funcOp.getFunctionBody().getOps(); + if (uses.empty() && !callSites.empty()) { + for (CallOpInterface callOp : callSites) { + nextTasks.insert(cast(callOp.resolveCallable())); + } + } + + if (uses.empty() && !nextTasks.contains(funcOp)) + funcOp.erase(); + } + tasks = nextTasks; + } +} + /// Convert a list of `Operand`s to a list of `OpOperand`s. static SmallVector operandsToOpOperands(OperandRange operands) { OpOperand *values = operands.getBase(); @@ -881,6 +905,8 @@ void RemoveDeadValues::runOnOperation() { // end of this pass. RDVFinalCleanupList finalCleanupList; + // Remove the dead function in advance. + deleteDeadFunction(module); module->walk([&](Operation *op) { if (auto funcOp = dyn_cast(op)) { processFuncOp(funcOp, module, la, deadVals, finalCleanupList); diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 56449469dc29f..1efaf2ecc11ce 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -455,7 +455,7 @@ module @llvm_unreachable { func.func private @fn_with_llvm_unreachable(%arg0: tensor<4x4xf32>) -> tensor<4x4xi1> { llvm.unreachable } - func.func private @main(%arg0: tensor<4x4xf32>) { + func.func @main(%arg0: tensor<4x4xf32>) { %0 = call @fn_with_llvm_unreachable(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xi1> llvm.return } @@ -649,3 +649,49 @@ func.func @callee(%arg0: index, %arg1: index, %arg2: index) -> index { %res = call @mutl_parameter(%arg0, %arg1, %arg2) : (index, index, index) -> (index) return %res : index } + +// ----- + +// Test the elimination of dead functions. + +// CHECK-NOT: func private @single_private_func +func.func private @single_private_func(%arg0: i64) -> (i64) { + %c0_i64 = arith.constant 0 : i64 + %2 = arith.cmpi eq, %arg0, %c0_i64 : i64 + cf.cond_br %2, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + %c1_i64 = arith.constant 1 : i64 + return %c1_i64 : i64 + ^bb2: // pred: ^bb0 + %c3_i64 = arith.constant 3 : i64 + return %c3_i64 : i64 +} + +// ----- + +// Test the elimination of dead functions. + +// CHECK-NOT: @single_parameter +func.func private @single_parameter(%arg0: index) { + return +} + +// CHECK-NOT: @mutl_parameter +func.func private @mutl_parameter(%arg0: index, %arg1: index, %arg2: index) -> index { + return %arg1 : index +} + +// CHECK-NOT: @eliminate_parameter +func.func private @eliminate_parameter(%arg0: index, %arg1: index) { + call @single_parameter(%arg0) : (index) -> () + return +} + +// CHECK-NOT: @callee +func.func private @callee(%arg0: index, %arg1: index, %arg2: index) -> index { + // CHECK-NOT: call @eliminate_parameter + call @eliminate_parameter(%arg0, %arg1) : (index, index) -> () + // CHECK-NOT: call @mutl_parameter + %res = call @mutl_parameter(%arg0, %arg1, %arg2) : (index, index, index) -> (index) + return %res : index +}