diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h index 65b2928bfd546d..a43b2f10085619 100644 --- a/mlir/include/mlir/Reducer/ReductionNode.h +++ b/mlir/include/mlir/Reducer/ReductionNode.h @@ -20,6 +20,7 @@ #include #include +#include "mlir/IR/OwningOpRef.h" #include "mlir/Reducer/Tester.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ArrayRef.h" @@ -57,7 +58,7 @@ class ReductionNode { /// will have been applied certain reduction strategies. Note that it's not /// necessary to be an interesting case or a reduced module (has smaller size /// than parent's). - ModuleOp getModule() const { return module; } + ModuleOp getModule() const { return module.get(); } /// Return the region we're reducing. Region &getRegion() const { return *region; } @@ -141,7 +142,7 @@ class ReductionNode { /// This is a copy of module from parent node. All the reducer patterns will /// be applied to this instance. - ModuleOp module; + OwningOpRef module; /// The region of certain operation we're reducing in the module Region *region; diff --git a/mlir/lib/Reducer/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp index e2a96681e9aa1a..a9d17431defed5 100644 --- a/mlir/lib/Reducer/ReductionNode.cpp +++ b/mlir/lib/Reducer/ReductionNode.cpp @@ -112,6 +112,9 @@ void ReductionNode::update(std::pair result) { // This module may has been updated. Reset the range. ranges.clear(); ranges.push_back({0, std::distance(region->op_begin(), region->op_end())}); + } else { + // Release the uninteresting module to save some memory. + module.release()->erase(); } } diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp index 9fdec3e27c8fcb..d895432fa156a7 100644 --- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp +++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp @@ -28,7 +28,8 @@ using namespace mlir; // Parse and verify the input MLIR file. -static LogicalResult loadModule(MLIRContext &context, OwningModuleRef &module, +static LogicalResult loadModule(MLIRContext &context, + OwningOpRef &module, StringRef inputFilename) { module = parseSourceFile(inputFilename, &context); if (!module) @@ -75,7 +76,7 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv, if (!output) return failure(); - mlir::OwningModuleRef moduleRef; + OwningOpRef moduleRef; if (failed(loadModule(context, moduleRef, inputFilename))) return failure(); @@ -88,12 +89,12 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv, if (failed(parser.addToPipeline(pm, errorHandler))) return failure(); - ModuleOp m = moduleRef.get().clone(); + OwningOpRef m = moduleRef.get().clone(); - if (failed(pm.run(m))) + if (failed(pm.run(m.get()))) return failure(); - m.print(output->os()); + m->print(output->os()); output->keep(); return success();