Skip to content

Commit

Permalink
[MLIR] Vectorize tensor.extract on 1-d tensor
Browse files Browse the repository at this point in the history
This patch implements the vectorization of tensor.extract for the
basic 1-d lookup case. It only vectorizes the tensor.extract to a
vector.gather when the op extracts value from an 1-d tensor.

Related discussion: iree-org/iree#9198

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D133786
  • Loading branch information
Che-Yu Wu authored and dcaballe committed Oct 18, 2022
1 parent d3fcbee commit d09bef8
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 13 deletions.
108 changes: 97 additions & 11 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Expand Up @@ -232,6 +232,12 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
return Value();
}

// Custom vectorization precondition function type. This is intented to be used
// with CustomVectorizationHook. Returns success if the correpsonding custom
// hook can vectorize the op.
using CustomVectorizationPrecondition =
std::function<LogicalResult(Operation *)>;

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return nullptr if the Operation cannot be vectorized.
Expand Down Expand Up @@ -300,6 +306,69 @@ static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}

/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) {
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
if (!extractOp)
return failure();

// Currently only supports extraction with an 1-D index.
if (extractOp.getIndices().size() != 1)
return failure();

if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
return failure();

if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
return !VectorType::isValidElementType(type);
})) {
return failure();
}

return success();
}

/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp,
const BlockAndValueMapping &bvm) {
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
if (!extractOp)
return VectorizationResult{VectorizationStatus::Failure, nullptr};
auto loc = extractOp.getLoc();

// Currently only supports extraction with an 1-D index. Checked in the
// tensorExtractVectorizationPrecondition.
assert(extractOp.getIndices().size() == 1);

auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
// Compute the static loop sizes of the extract op.
auto targetShape = linalgOp.computeStaticLoopSizes();

SmallVector<Value> gatherIndices;
gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0));

auto maskConstantOp = b.create<arith::ConstantOp>(
loc,
DenseIntElementsAttr::get(VectorType::get(targetShape, b.getI1Type()),
/*value=*/true));

auto resultType =
VectorType::get(targetShape, extractOp.getResult().getType());
auto passThruConstantOp =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(resultType));

auto gatherOp = b.create<vector::GatherOp>(
loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
maskConstantOp, passThruConstantOp);

return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
}

/// Emit reduction operations if the shapes of the value to reduce is different
/// that the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
Expand Down Expand Up @@ -515,6 +584,14 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
};
hooks.push_back(vectorizeIndex);

// 4c. Register CustomVectorizationHook for extractOp.
CustomVectorizationHook vectorizeExtract =
[&](Operation *op,
const BlockAndValueMapping &bvm) -> VectorizationResult {
return vectorizeTensorExtract(b, op, linalgOp, bvm);
};
hooks.push_back(vectorizeExtract);

// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
for (Operation &op : block->getOperations()) {
VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
Expand Down Expand Up @@ -552,9 +629,20 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
return success();
}

static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
static LogicalResult vectorizeStaticLinalgOpPrecondition(
linalg::LinalgOp op,
ArrayRef<CustomVectorizationPrecondition> customPreconditions) {

// All types in the body should be a supported element type for VectorType.
for (Operation &innerOp : op->getRegion(0).front()) {
// Check if any custom hook can vectorize the inner op.
if (llvm::any_of(
customPreconditions,
[&](const CustomVectorizationPrecondition &customPrecondition) {
return succeeded(customPrecondition(&innerOp));
})) {
continue;
}
if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
return !VectorType::isValidElementType(type);
})) {
Expand All @@ -566,16 +654,8 @@ static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
return failure();
}
}
if (isElementwise(op)) {
// Some operations in the body cannot be vectorized.
for (Operation &payloadOp : *op.getBlock()) {
if (isa<tensor::ExtractOp>(payloadOp)) {
LDBG("precondition failed: `tensor.extract` not vectorizable");
return failure();
}
}
if (isElementwise(op))
return success();
}
// TODO: isaConvolutionOpInterface that can also infer from generic features.
// But we will still need stride/dilation attributes that will be annoying to
// reverse-engineer...
Expand All @@ -601,7 +681,13 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
LDBG("precondition failed: dynamic shape");
return failure();
}
return vectorizeStaticLinalgOpPrecondition(linalgOp);

SmallVector<CustomVectorizationPrecondition> customPreconditions;

// Register CustomVectorizationPrecondition for extractOp.
customPreconditions.push_back(tensorExtractVectorizationPrecondition);

return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions);
}

LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Expand Up @@ -161,8 +161,8 @@ bool hasOnlyScalarElementwiseOp(Region &r) {
if (!llvm::hasSingleElement(r))
return false;
for (Operation &op : r.front()) {
if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
linalg::IndexOp>(op) ||
if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
linalg::YieldOp, linalg::IndexOp>(op) ||
OpTrait::hasElementwiseMappableTraits(&op)) ||
llvm::any_of(op.getResultTypes(),
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
Expand Down
65 changes: 65 additions & 0 deletions mlir/test/Dialect/Linalg/vectorization.mlir
Expand Up @@ -1457,3 +1457,68 @@ transform.sequence failures(propagate) {
%1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
%2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
}

// -----

#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @vectorize_1d_tensor_extract(%arg0: tensor<3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x7x2xf32>, %arg3: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
%2 = linalg.generic {
indexing_maps = [#map0, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} ins(%arg1, %arg2 : tensor<4x3xi32>, tensor<4x7x2xf32>) outs(%arg3 : tensor<4x7x3x2xf32>) {
^bb0(%arg4: i32, %arg5: f32, %arg6: f32):
%3 = arith.index_cast %arg4 : i32 to index
%7 = tensor.extract %arg0[%3] : tensor<3xf32>
linalg.yield %7 : f32
} -> tensor<4x7x3x2xf32>
return %2 : tensor<4x7x3x2xf32>
}
// CHECK-LABEL: func.func @vectorize_1d_tensor_extract
// CHECK-SAME: %[[ARG0:.*]]: tensor<3xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<4x3xi32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]]
// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]]
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST]]
// CHECK: %[[INDICES:.*]] = vector.transpose %[[BROADCAST]]
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]]] [%[[INDICES]]], %[[MASK]], %[[PASSTHRU]]
// CHECK: vector.transfer_write %[[GATHER]]

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
%2 = transform.structured.vectorize %1
}

// -----

#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
%2 = linalg.generic {
indexing_maps = [#map0, #map0, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} ins(%arg1, %arg2, %arg3 : tensor<4x3xi32>, tensor<4x3xi32>, tensor<4x7x2xf32>) outs(%arg4 : tensor<4x7x3x2xf32>) {
^bb0(%arg5: i32, %arg6: i32, %arg7: f32, %arg8: f32):
%3 = arith.index_cast %arg5 : i32 to index
%4 = arith.index_cast %arg6 : i32 to index
%7 = tensor.extract %arg0[%3, %4] : tensor<3x3xf32>
linalg.yield %7 : f32
} -> tensor<4x7x3x2xf32>
return %2 : tensor<4x7x3x2xf32>
}
// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract
// CHECK: tensor.extract

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
%2 = transform.structured.vectorize %1
}

0 comments on commit d09bef8

Please sign in to comment.