diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 7ac71ceeb588a..8b8e6a1001574 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -145,6 +145,10 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> { def PrintOpStats : Pass<"print-op-stats"> { let summary = "Print statistics of operations"; let constructor = "mlir::createPrintOpStatsPass()"; + let options = [ + Option<"printAsJSON", "json", "bool", /*default=*/"false", + "print the stats as JSON"> + ]; } def SCCP : Pass<"sccp"> { diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp index 8adc4117f9542..e7740abb3d19f 100644 --- a/mlir/lib/Transforms/OpStats.cpp +++ b/mlir/lib/Transforms/OpStats.cpp @@ -27,6 +27,9 @@ struct PrintOpStatsPass : public PrintOpStatsBase { // Print summary of op stats. void printSummary(); + // Print symmary of op stats in JSON. + void printSummaryInJSON(); + private: llvm::StringMap opCount; raw_ostream &os; @@ -37,8 +40,12 @@ void PrintOpStatsPass::runOnOperation() { opCount.clear(); // Compute the operation statistics for the currently visited operation. - getOperation()->walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); - printSummary(); + getOperation()->walk( + [&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); + if (printAsJSON) { + printSummaryInJSON(); + } else + printSummary(); } void PrintOpStatsPass::printSummary() { @@ -80,6 +87,23 @@ void PrintOpStatsPass::printSummary() { } } +void PrintOpStatsPass::printSummaryInJSON() { + SmallVector sorted(opCount.keys()); + llvm::sort(sorted); + + os << "{\n"; + + for (unsigned i = 0, e = sorted.size(); i != e; ++i) { + const auto &key = sorted[i]; + os << " \"" << key << "\" : " << opCount[key]; + if (i != e - 1) + os << ",\n"; + else + os << "\n"; + } + os << "}\n"; +} + std::unique_ptr mlir::createPrintOpStatsPass(raw_ostream &os) { return std::make_unique(os); } diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c index 63aba29e56f61..f73398e817beb 100644 --- a/mlir/test/CAPI/pass.c +++ b/mlir/test/CAPI/pass.c @@ -138,14 +138,14 @@ void testPrintPassPipeline() { mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass); // Print the top level pass manager - // CHECK: Top-level: builtin.module(func.func(print-op-stats)) + // CHECK: Top-level: builtin.module(func.func(print-op-stats{json=false})) fprintf(stderr, "Top-level: "); mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr, NULL); fprintf(stderr, "\n"); // Print the pipeline nested one level down - // CHECK: Nested Module: func.func(print-op-stats) + // CHECK: Nested Module: func.func(print-op-stats{json=false}) fprintf(stderr, "Nested Module: "); mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL); fprintf(stderr, "\n"); @@ -166,8 +166,9 @@ void testParsePassPipeline() { // Try parse a pipeline. MlirLogicalResult status = mlirParsePassPipeline( mlirPassManagerGetAsOpPassManager(pm), - mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats)," - " func.func(print-op-stats))")); + mlirStringRefCreateFromCString( + "builtin.module(func.func(print-op-stats{json=false})," + " func.func(print-op-stats{json=false}))")); // Expect a failure, we haven't registered the print-op-stats pass yet. if (mlirLogicalResultIsSuccess(status)) { fprintf( @@ -179,8 +180,9 @@ void testParsePassPipeline() { mlirRegisterTransformsPrintOpStats(); status = mlirParsePassPipeline( mlirPassManagerGetAsOpPassManager(pm), - mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats)," - " func.func(print-op-stats))")); + mlirStringRefCreateFromCString( + "builtin.module(func.func(print-op-stats{json=false})," + " func.func(print-op-stats{json=false}))")); // Expect a failure, we haven't registered the print-op-stats pass yet. if (mlirLogicalResultIsFailure(status)) { fprintf(stderr, @@ -188,8 +190,8 @@ void testParsePassPipeline() { exit(EXIT_FAILURE); } - // CHECK: Round-trip: builtin.module(func.func(print-op-stats), - // func.func(print-op-stats)) + // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}), + // func.func(print-op-stats{json=false})) fprintf(stderr, "Round-trip: "); mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr, NULL); diff --git a/mlir/test/IR/op-stats-json.mlir b/mlir/test/IR/op-stats-json.mlir new file mode 100644 index 0000000000000..40b0602a95897 --- /dev/null +++ b/mlir/test/IR/op-stats-json.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats=json %s -o=/dev/null 2>&1 | FileCheck %s + +func.func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> { +^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>): + %0 = arith.addf %arg0, %arg1 : tensor<4xf32> + %1 = arith.addf %arg0, %arg1 : tensor<4xf32> + %2 = arith.addf %arg0, %arg1 : tensor<4xf32> + %3 = arith.addf %arg0, %arg1 : tensor<4xf32> + %4 = arith.addf %arg0, %arg1 : tensor<4xf32> + %5 = arith.addf %arg0, %arg1 : tensor<4xf32> + %10 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %11 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %12 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %13 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %14 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %15 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %16 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %17 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %18 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %19 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %20 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %21 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %22 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %23 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %24 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %25 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %26 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + %30 = "long_op_name"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32> + return %1 : tensor<4xf32> +} + +// CHECK: { +// CHECK: "arith.addf" : 6, +// CHECK: "func.return" : 1, +// CHECK: "long_op_name" : 1, +// CHECK: "xla.add" : 17 +// CHECK: } diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index c046bb818632e..6cc627d542337 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -36,19 +36,19 @@ def testParseSuccess(): # A first import is expected to fail because the pass isn't registered # until we import mlir.transforms try: - pm = PassManager.parse("builtin.module(func.func(print-op-stats))") + pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))") # TODO: this error should be propagate to Python but the C API does not help right now. # CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline except ValueError as e: - # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(print-op-stats))'. + # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(print-op-stats{json=false}))'. log("ValueError exception:", e) else: log("Exception not produced") # This will register the pass and round-trip should be possible now. import mlir.transforms - pm = PassManager.parse("builtin.module(func.func(print-op-stats))") - # CHECK: Roundtrip: builtin.module(func.func(print-op-stats)) + pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))") + # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false})) log("Roundtrip: ", pm) run(testParseSuccess) @@ -86,7 +86,7 @@ def testInvalidNesting(): # CHECK-LABEL: TEST: testRun def testRunPipeline(): with Context(): - pm = PassManager.parse("print-op-stats") + pm = PassManager.parse("print-op-stats{json=false}") module = Module.parse(r"""func.func @successfulParse() { return }""") pm.run(module) # CHECK: Operations encountered: