diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bbb9dd71ac3c9..ebf20acab4171 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -432,6 +432,14 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
   return true;
 }
 
+/// Returns `true` if all indexing maps of the linalg op are projected
+/// permutations.
+static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
+  return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
+    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
+  });
+}
+
 // Return true if the op is an element-wise linalg op.
 static bool isElementwise(Operation *op) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -439,6 +447,10 @@ static bool isElementwise(Operation *op) {
     return false;
   if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
     return false;
+
+  if (!allIndexingsAreProjectedPermutation(linalgOp))
+    return false;
+
   // TODO: relax the restrictions on indexing map.
   for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
     if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
@@ -564,17 +576,6 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
   return success();
 }
 
-/// Helper function to vectorize a `linalgOp` with contraction semantics in a
-/// generic fashion.
-/// This helper is needed atm because the truly generic implementation requires
-/// good vector.multi_reduce folding patterns that are currently NYI.
-// TODO: drop reliance on a specific pattern.
-static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
-  return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
-    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
-  });
-}
-
 // TODO: probably need some extra checks for reduction followed by consumer
 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 static LogicalResult reductionPreconditions(LinalgOp op) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 301e8f0dcd8da..99617d50684ed 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1077,3 +1077,27 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
   return %2 : tensor<f32>
 }
 
+
+// -----
+
+// This test checks that vectorization does not occur when an input indexing map
+// is not a projected permutation. In the future, this can be converted to a
+// positive test when support is added.
+
+// CHECK-LABEL: func @not_projected_permutation
+func.func @not_projected_permutation(%arg0: tensor<8x8xf32>) -> tensor<6x6x3x3xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %init = linalg.init_tensor [6, 6, 3, 3] : tensor<6x6x3x3xf32>
+  %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<6x6x3x3xf32>) -> tensor<6x6x3x3xf32>
+  // CHECK: linalg.generic
+  %result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>,
+                                             affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+                            iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+                            ins(%arg0 : tensor<8x8xf32>)
+                            outs(%fill : tensor<6x6x3x3xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+      linalg.yield %arg7 : f32
+  } -> tensor<6x6x3x3xf32>
+  return %result : tensor<6x6x3x3xf32>
+}