[mlir][linalg] Vectorize tensor.extract using contiguous loads
This patch implements vectorization of tensor.extract for n-D tensors (n
>= 2) using contiguous load operations, i.e. `vector.transfer_read`. This
is a follow-up to https://reviews.llvm.org/D137660, in which gather loads
were used, i.e. `vector.gather`.

It is always safe to use gather load operations when the underlying
memory access pattern is contiguous, but not vice versa. At the moment,
the following conditions have to be met for contiguous loads to be
generated:
  1. The _output tensor_ must be effectively 1-D, i.e. the trailing dim
     is the only dim > 1, e.g. `tensor<1x1x4xi32>`,
  2. The trailing dim in the _input tensor_ must be > 1, e.g.
     `tensor<1x1x4xi32>` would be fine, but not `tensor<1x4x1xi32>`.
If these conditions are not satisfied, gather loads are generated
instead.
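
For illustration, here is a hypothetical kernel (hand-written for this
summary rather than taken from the test file; all names and shapes are
made up) that satisfies both conditions: the output shape `1x1x3` has
only the trailing dim > 1, the input's trailing dim is > 1, and the
trailing extract index follows the innermost loop:

  #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  func.func @extract_contiguous(%src: tensor<3x3xf32>,
                                %init: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
    %c1 = arith.constant 1 : index
    %res = linalg.generic
        {indexing_maps = [#map],
         iterator_types = ["parallel", "parallel", "parallel"]}
        outs(%init : tensor<1x1x3xf32>) {
    ^bb0(%out: f32):
      // Non-trailing extract index: loop invariant.
      // Trailing extract index: follows the innermost loop (stride one).
      %i = linalg.index 2 : index
      %v = tensor.extract %src[%c1, %i] : tensor<3x3xf32>
      linalg.yield %v : f32
    } -> tensor<1x1x3xf32>
    return %res : tensor<1x1x3xf32>
  }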

Condition 1 guarantees that the iteration space of the corresponding
`linalg.generic` Op is relatively simple. That makes analysing the
indices for `tensor.extract` rather straightforward.

Condition 2 is mostly there to avoid weird vectorisation patterns
resulting in vectors like: `vector<1x1x1xi32>`. In practice, tensors
like `tensor<1x4x1xi32>` should be collapsed to `tensor<1x4xi32>` before
vectorisation, but that's beyond the scope of this patch.
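
With this patch, and modulo masking and the exact index computation, a
kernel like the one sketched above would be vectorised roughly as
follows (again hand-written rather than copied from the test file;
details such as the permutation map and in-bounds flags are
illustrative): the unit dims are broadcast by the read instead of being
gathered.

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %pad = arith.constant 0.0 : f32
  %read = vector.transfer_read %src[%c1, %c0], %pad
      {in_bounds = [true, true, true],
       permutation_map = affine_map<(d0, d1) -> (0, 0, d1)>}
      : tensor<3x3xf32>, vector<1x1x3xf32>
  %res = vector.transfer_write %read, %init[%c0, %c0, %c0]
      {in_bounds = [true, true, true]}
      : vector<1x1x3xf32>, tensor<1x1x3xf32>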

If needed, both conditions can be relaxed. I have not been able to find
a good motivating example for relaxing them, hence skipping that for
now. For reference, `tosa.resize` (lowered to Linalg) was the driving
example used here.

As a bonus, the test from "vectorization-unsupported.mlir" is moved to
"vectorization.mlir" with proper CHECK lines added.

Differential Revision: https://reviews.llvm.org/D141998

Co-authored-by: Diego Caballero <diegocaballero@google.com>
banach-space and dcaballe committed Feb 22, 2023
1 parent f7b7c69 commit 89b144e
Showing 3 changed files with 307 additions and 48 deletions.
206 changes: 196 additions & 10 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -611,11 +611,11 @@ static Value calculateGatherOffset(RewriterBase &rewriter,

const size_t numIndices = extractOp.getIndices().size();
for (size_t i = 1; i < numIndices; i++) {
Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

auto dimSize = broadcastIfNeeded(
rewriter,
rewriter.create<arith::ConstantIndexOp>(
loc,
extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
indexVecType.getShape());

offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
@@ -630,6 +630,143 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
return offset;
}

enum VectorMemoryAccessKind {
// TODO: ScalarBroadcast,
Contiguous,
Gather
};

/// Check whether \p val can be used for calculating an index for a contiguous
/// load operation, i.e. whether \p val:
/// * is invariant with respect to \p linalgOp, i.e. whether it remains
/// constant for all iterations, or
/// * increments with the loop iterator (when \p strideZero is false) or is
/// not affected by the loop indices (when \p strideZero is true).
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, size_t dim,
bool strideZero) {
auto *block = linalgOp.getBlock();

// Bail out if this is a block argument for this linalg.generic Op.
// TODO: We could try analysing the corresponding affine map here.
if (val.dyn_cast<BlockArgument>())
return llvm::all_of(block->getArguments(),
[&val](Value v) { return (v != val); });

Operation *defOp = val.getDefiningOp();
assert(defOp && "This is neither a block argument nor an operation result");

// Given the assumption on the shape of the target tensor, the index Op is
// either:
// * is constant (for non-trailing dims), or
// * increments with stride one together with the trailing dimension.
// Both cases are fine for contiguous loads.
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
return strideZero ? (indexOp.getDim() != dim) : (indexOp.getDim() == dim);

auto *ancestor = block->findAncestorOpInBlock(*defOp);

// Values defined outside `linalgOp`.
if (!ancestor)
return true;

// Values defined inside `linalgOp`, which are constant.
if (dyn_cast<arith::ConstantOp>(ancestor))
return true;

bool result = true;
for (auto op : ancestor->getOperands())
result &= isContiguousLoadIdx(linalgOp, op, dim, strideZero);

return result;
}

/// Check whether the calculation of \p val is based on linalg.index Op with
/// the dim attribute matching \p dim.
static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
auto *block = linalgOp.getBlock();
auto targetShape = linalgOp.getStaticLoopRanges();

if (val.isa<BlockArgument>())
return false;

Operation *defOp = val.getDefiningOp();
assert(defOp && "This is neither a block argument nor an operation result");

if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
return (indexOp.getDim() == dim);

auto *ancestor = block->findAncestorOpInBlock(*defOp);

if (!ancestor)
return false;

bool result = false;
for (auto op : ancestor->getOperands())
result |= isBasedOnIndexOp(linalgOp, op, dim);

return result;
}

/// Check whether \p extractOp would be a gather or a contiguous load Op after
/// vectorising \p linalgOp. Note that it is always safe to use gather load
/// operations for contiguous loads (albeit slow), but not vice-versa. When in
/// doubt, bail out and assume that \p extractOp is a gather load.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
LinalgOp &linalgOp) {

auto targetShape = linalgOp.getStaticLoopRanges();

// Assume that it's a gather load when reading _into_:
// * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
// * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
// TODO: Relax these conditions.
if ((llvm::count_if(targetShape,
[](int64_t dimSize) { return dimSize > 1; }) != 1) ||
targetShape.back() == 1)
return VectorMemoryAccessKind::Gather;

auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();

// Assume that it's a gather load when reading _from_ a tensor for which the
// trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
// TODO: Relax this condition.
if (inputShape.getShape().back() == 1)
return VectorMemoryAccessKind::Gather;

bool isContiguous = true;

// Iterate over all indices. Analyze whether the way each index is calculated
// is suitable for contiguous load operations (e.g. loop invariant).
auto indices = extractOp.getIndices();
for (auto [i, indexVal] : llvm::enumerate(indices)) {
if (inputShape.getShape()[i] == 1) {
// This extractOp index must be a loop-invariant constant
continue;
}

auto extractOpBottomIdx = indices.size() - 1;
auto strideOneDim = targetShape.size() - 1;
bool strideZero = (i != extractOpBottomIdx);
isContiguous &=
isContiguousLoadIdx(linalgOp, indexVal, strideOneDim, strideZero);
}

// The calculation of the trailing index must include the loop index. Given
// the assumption on the output tensor (which is defined by the iteration
// space), only the trailing dim matters.
auto extractOpTrailingIdx = indices.back();
isContiguous &=
isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, targetShape.size() - 1);

if (isContiguous) {
LDBG("Found contigous load: " << extractOp);
return VectorMemoryAccessKind::Contiguous;
}

return VectorMemoryAccessKind::Gather;
}

/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
@@ -660,15 +797,64 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
extractOp.getIndices().size(),
rewriter.create<arith::ConstantIndexOp>(loc, 0));

Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
VectorMemoryAccessKind memAccessKind =
getTensorExtractMemoryAccessPattern(extractOp, linalgOp);

// 1. Handle gather access
if (memAccessKind == VectorMemoryAccessKind::Gather) {
Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);

// Generate the gather load
Operation *gatherOp = rewriter.create<vector::GatherOp>(
loc, resultType, extractOp.getTensor(), baseIndices, offset,
maskConstantOp, passThruConstantOp);
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

LDBG("Vectorised as gather load: " << extractOp);
return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
}

// 2. Handle contiguous access.
SmallVector<Value> transferReadIdxs;
auto resTrailingDim = resultType.getShape().back();
auto zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));

// Collect indices for `vector.transfer_read`. At this point, the indices will
// either be scalars or will have been broadcast to vectors matching the
// result type. For indices that are vectors, there are two options:
// * for non-trailing indices, all elements are identical (contiguous
// loads are identified by looking for non-trailing indices that are
// invariant with respect to the corresponding linalg.generic), or
// * for trailing indices, the index vector will contain values with stride
// one, but for `vector.transfer_read` only the first (i.e. 0th) index is
// needed.
// This means that
// * for a scalar index - just re-use it,
// * for a vector index (e.g. `vector<1x1x4xindex>`) - extract the bottom
// (0th) element and use that.
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
auto idx = bvm.lookup(extractOp.getIndices()[i]);
if (idx.getType().isIndex()) {
transferReadIdxs.push_back(idx);
continue;
}

auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
bvm.lookup(extractOp.getIndices()[i]));
transferReadIdxs.push_back(
rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
}

// `tensor.extract` is always in-bounds, hence the following holds.
SmallVector<bool> inBounds(resultType.getRank(), true);

// Generate the gather load
Operation *gatherOp = rewriter.create<vector::GatherOp>(
loc, resultType, extractOp.getTensor(), baseIndices, offset,
maskConstantOp, passThruConstantOp);
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);

return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
LDBG("Vectorised as contiguous load: " << extractOp);
return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}

/// Emit reduction operations if the shapes of the value to reduce is different
29 changes: 0 additions & 29 deletions mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

This file was deleted.
