Skip to content

Commit

Permalink
[mlir][Vector] Enable masked vectorization of linalg.fill
Browse files Browse the repository at this point in the history
linalg.fill was already vectorizable with masks but not supported in the
dynamic pre-checks.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D146856
  • Loading branch information
dcaballe committed Mar 29, 2023
1 parent 7b70baa commit f18a861
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
8 changes: 1 addition & 7 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Expand Up @@ -1291,19 +1291,13 @@ static LogicalResult reductionPreconditions(LinalgOp op) {

static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
// TODO: Masking only supports dynamic generic ops for now.
if (!isa<linalg::GenericOp>(op))
if (!isa<linalg::GenericOp, linalg::FillOp>(op))
return failure();

// TODO: Index vectorization assumes static shape.
if (op.hasIndexSemantics())
return failure();

// TODO: 0-d vectors are not supported yet.
if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) {
return map.isEmpty() || map.getResults().empty();
}))
return failure();

LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
return success();
}
Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Linalg/vectorization.mlir
Expand Up @@ -2535,3 +2535,24 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
transform.structured.masked_vectorize %0 vector_sizes [8, 32]
}

// -----

func.func @vectorize_dynamic_fill(%A : tensor<?x?xf32>, %arg0 : f32) -> tensor<?x?xf32> {
%0 = linalg.fill ins(%arg0 : f32) outs(%A : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: func.func @vectorize_dynamic_fill
// CHECK: %[[DIM0:.*]] = tensor.dim
// CHECK: %[[DIM1:.*]] = tensor.dim
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<8x16xi1>
// CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f32 to vector<8x16xf32>
// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[BCAST]], {{.*}} {in_bounds = [true, true]} : vector<8x16xf32>, tensor<?x?xf32> } : vector<8x16xi1>

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!pdl.operation) -> !pdl.operation
transform.structured.masked_vectorize %0 vector_sizes [8, 16]
}

0 comments on commit f18a861

Please sign in to comment.