Skip to content

Commit

Permalink
[mlir][linalg] move isElementwise() to Linalg/Utils (NFC)
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D128398
  • Loading branch information
okkwon committed Jun 23, 2022
1 parent f4a3df1 commit 1dd2c93
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 42 deletions.
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ class LinalgDependenceGraph;
// General utilities
//===----------------------------------------------------------------------===//

/// Check if all indexing maps are projected permutations.
bool allIndexingsAreProjectedPermutation(LinalgOp op);

/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
bool hasOnlyScalarElementwiseOp(Region &r);

/// Check if a LinalgOp is an element-wise operation.
bool isElementwise(LinalgOp op);

/// Check if `permutation` is a permutation of the range
/// `[0, permutation.size())`.
bool isPermutation(ArrayRef<int64_t> permutation);
Expand Down
42 changes: 0 additions & 42 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,48 +417,6 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
llvm::to_vector<4>(returnTypes), op->getAttrs())};
}

/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
static bool hasOnlyScalarElementwiseOp(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
for (Operation &op : r.front()) {
if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
linalg::IndexOp>(op) ||
OpTrait::hasElementwiseMappableTraits(&op)) ||
llvm::any_of(op.getResultTypes(),
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
return false;
}
return true;
}

/// Returns `true` if all indexing maps of the linalg op are projected
/// permutations.
static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
return m.isProjectedPermutation(/*allowZeroInResults=*/true);
});
}

// Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) {
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return false;
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
return false;

if (!allIndexingsAreProjectedPermutation(linalgOp))
return false;

// TODO: relax the restrictions on indexing map.
for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
return false;
}
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
}

/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
/// 1. Verify the `linalgOp` has one non-empty region.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRLinalgUtils
MLIRAffineAnalysis
MLIRAffineUtils
MLIRArithmeticDialect
MLIRFuncDialect
MLIRIR
MLIRLinalgDialect
MLIRSCFDialect
Expand Down
36 changes: 36 additions & 0 deletions mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
Expand Down Expand Up @@ -141,6 +142,41 @@ static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
namespace mlir {
namespace linalg {

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
return m.isProjectedPermutation(/*allowZeroInResults=*/true);
});
}

bool hasOnlyScalarElementwiseOp(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
for (Operation &op : r.front()) {
if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
linalg::IndexOp>(op) ||
OpTrait::hasElementwiseMappableTraits(&op)) ||
llvm::any_of(op.getResultTypes(),
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
return false;
}
return true;
}

bool isElementwise(LinalgOp op) {
if (op.getNumLoops() != op.getNumParallelLoops())
return false;

if (!allIndexingsAreProjectedPermutation(op))
return false;

// TODO: relax the restrictions on indexing map.
for (OpOperand *opOperand : op.getOutputOperands()) {
if (!op.getTiedIndexingMap(opOperand).isPermutation())
return false;
}
return hasOnlyScalarElementwiseOp(op->getRegion(0));
}

bool isPermutation(ArrayRef<int64_t> permutation) {
// Count the number of appearances for all indices.
SmallVector<int64_t> indexCounts(permutation.size(), 0);
Expand Down
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7472,6 +7472,7 @@ cc_library(
":ArithmeticDialect",
":ArithmeticUtils",
":DialectUtils",
":FuncDialect",
":IR",
":LinalgAnalysis",
":LinalgDialect",
Expand Down

0 comments on commit 1dd2c93

Please sign in to comment.