diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 8f5f87ba620ee..120d4e4a91372 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -12,6 +12,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" @@ -84,7 +85,8 @@ static void getBackwardSliceImpl(Operation *op, if (!op) return; - assert((op->getNumRegions() == 0 || isa(op)) && + assert((op->getNumRegions() == 0 || + isa(op)) && "unexpected generic op with regions"); // Evaluate whether we should keep this def. diff --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir new file mode 100644 index 0000000000000..731f3872f67dd --- /dev/null +++ b/mlir/test/IR/slice.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt -slice-analysis-test %s | FileCheck %s + +func @slicing_linalg_op(%arg0 : index, %arg1 : index, %arg2 : index) { + %a = alloc(%arg0, %arg2) : memref + %b = alloc(%arg2, %arg1) : memref + %c = alloc(%arg0, %arg1) : memref + %d = alloc(%arg0, %arg1) : memref + linalg.matmul %a, %b, %c : (memref, memref, memref) + linalg.matmul %a, %b, %d : (memref, memref, memref) + dealloc %c : memref + dealloc %b : memref + dealloc %a : memref + dealloc %d : memref + return +} + +// CHECK-LABEL: func @slicing_linalg_op__backward_slice__0 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[A:.+]] = alloc(%[[ARG0]], %[[ARG2]]) : memref +// CHECK-DAG: %[[B:.+]] = alloc(%[[ARG2]], %[[ARG1]]) : memref +// CHECK-DAG: %[[C:.+]] = alloc(%[[ARG0]], %[[ARG1]]) : memref +// CHECK: return + +// CHECK-LABEL: func @slicing_linalg_op__backward_slice__1 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[A:.+]] = alloc(%[[ARG0]], %[[ARG2]]) : memref +// CHECK-DAG: %[[B:.+]] = alloc(%[[ARG2]], %[[ARG1]]) : memref +// CHECK-DAG: %[[C:.+]] = alloc(%[[ARG0]], %[[ARG1]]) : memref +// CHECK: return diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt index cf4ecada0f3cb..a42f90bb92689 100644 --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_library(MLIRTestIR TestPrintDefUse.cpp TestPrintNesting.cpp TestSideEffects.cpp + TestSlicing.cpp TestSymbolUses.cpp TestTypes.cpp diff --git a/mlir/test/lib/IR/TestSlicing.cpp b/mlir/test/lib/IR/TestSlicing.cpp new file mode 100644 index 0000000000000..a95b2f84cfcf5 --- /dev/null +++ b/mlir/test/lib/IR/TestSlicing.cpp @@ -0,0 +1,81 @@ +//===- TestSlicing.cpp - Testing slice functionality ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple testing pass for slicing. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +using namespace mlir; + +/// Create a function with the same signature as the parent function of `op` +/// with name being the function name and a `suffix`. +static LogicalResult createBackwardSliceFunction(Operation *op, + StringRef suffix) { + FuncOp parentFuncOp = op->getParentOfType(); + OpBuilder builder(parentFuncOp); + Location loc = op->getLoc(); + std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str(); + FuncOp clonedFuncOp = + builder.create(loc, clonedFuncOpName, parentFuncOp.getType()); + BlockAndValueMapping mapper; + builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock()); + for (auto arg : enumerate(parentFuncOp.getArguments())) + mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index())); + llvm::SetVector slice; + getBackwardSlice(op, &slice); + for (Operation *slicedOp : slice) + builder.clone(*slicedOp, mapper); + builder.create(loc); + return success(); +} + +namespace { +/// Pass to test slice generated from slice analysis. +struct SliceAnalysisTestPass + : public PassWrapper> { + void runOnOperation() override; + SliceAnalysisTestPass() = default; + SliceAnalysisTestPass(const SliceAnalysisTestPass &) {} +}; +} // namespace + +void SliceAnalysisTestPass::runOnOperation() { + ModuleOp module = getOperation(); + auto funcOps = module.getOps(); + unsigned opNum = 0; + for (auto funcOp : funcOps) { + // TODO: For now this is just looking for Linalg ops. It can be generalized + // to look for other ops using flags. + funcOp.walk([&](Operation *op) { + if (!isa(op)) + return WalkResult::advance(); + std::string append = + std::string("__backward_slice__") + std::to_string(opNum); + createBackwardSliceFunction(op, append); + opNum++; + return WalkResult::advance(); + }); + } +} + +namespace mlir { +void registerSliceAnalysisTestPass() { + PassRegistration pass( + "slice-analysis-test", "Test Slice analysis functionality."); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 437b5f4b6f1a6..e46327aa63992 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -38,6 +38,7 @@ void registerPatternsTestPass(); void registerPrintOpAvailabilityPass(); void registerSideEffectTestPasses(); void registerSimpleParametricTilingPass(); +void registerSliceAnalysisTestPass(); void registerSymbolTestPasses(); void registerTestAffineDataCopyPass(); void registerTestAffineLoopUnswitchingPass(); @@ -88,6 +89,7 @@ void registerTestPasses() { registerPrintOpAvailabilityPass(); registerSideEffectTestPasses(); registerSimpleParametricTilingPass(); + registerSliceAnalysisTestPass(); registerSymbolTestPasses(); registerTestAffineDataCopyPass(); registerTestAllReduceLoweringPass();