[mlir][linalg] fix crash in vectorization of elementwise operations
The current vectorization logic implicitly expects "elementwise"
linalg ops to have projected permutations as their indexing maps,
but the precondition logic omits this check. As a result, running
the generic vectorization transform on an op with a non-projected-
permutation input indexing map can crash. This change adds the
missing check to the preconditions and adds a test that crashes
without the fix.

Differential Revision: https://reviews.llvm.org/D127000
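
A brief illustration of the distinction the new precondition enforces (not part of the patch; the #proj and #not_proj aliases below are hypothetical): an indexing map is a projected permutation when every result is simply one of the loop dimensions, with dimensions possibly dropped but never combined arithmetically. The second map below is the convolution-style map used by the new test.

// Projected permutation: results only select and permute (a subset of) the dims.
#proj = affine_map<(d0, d1, d2) -> (d2, d0)>
// Not a projected permutation: results perform arithmetic on the dims.
#not_proj = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>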
christopherbate committed Jun 3, 2022
1 parent f608752 commit 9f819f4
Showing 2 changed files with 36 additions and 11 deletions.
23 changes: 12 additions & 11 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -432,13 +432,25 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
  return true;
}

/// Returns `true` if all indexing maps of the linalg op are projected
/// permutations.
static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

// Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return false;
  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(linalgOp))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
    if (!linalgOp.getTiedIndexingMap(opOperand).isPermutation())
@@ -564,17 +576,6 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
  return success();
}

/// Helper function to vectorize a `linalgOp` with contraction semantics in a
/// generic fashion.
/// This helper is needed atm because the truly generic implementation requires
/// good vector.multi_reduce folding patterns that are currently NYI.
// TODO: drop reliance on a specific pattern.
static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMaps(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
24 changes: 24 additions & 0 deletions mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1077,3 +1077,27 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {

  return %2 : tensor<f32>
}


// -----

// This test checks that vectorization does not occur when an input indexing map
// is not a projected permutation. In the future, this can be converted to a
// positive test when support is added.

// CHECK-LABEL: func @not_projected_permutation
func.func @not_projected_permutation(%arg0: tensor<8x8xf32>) -> tensor<6x6x3x3xf32> {
  %c0 = arith.constant 0.0 : f32
  %init = linalg.init_tensor [6, 6, 3, 3] : tensor<6x6x3x3xf32>
  %fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<6x6x3x3xf32>) -> tensor<6x6x3x3xf32>
  // CHECK: linalg.generic
  %result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>,
                                             affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
                            iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%arg0 : tensor<8x8xf32>)
    outs(%fill : tensor<6x6x3x3xf32>) {
  ^bb0(%arg7: f32, %arg9: f32):
    linalg.yield %arg7 : f32
  } -> tensor<6x6x3x3xf32>
  return %result : tensor<6x6x3x3xf32>
}
