[mlir][linalg] Refine how contiguous loads are identified
Vectorization of `tensor.extract` using contiguous loads
(`vector.transfer_read`) was introduced in [1]. This patch updates and
refines the existing logic so that more cases of contiguous loads can
be identified, and adds more tests.

Specifically, contiguous load operations are identified by making sure
that:
  1. non-trailing indices for `tensor.extract` are loop invariant (so,
     e.g., there are no "jumps" from one row to the other between
     iterations),
  2. the trailing index for `tensor.extract` increments by 1 with every
     loop iteration (so that it's always adjacent elements that are
     loaded), as illustrated in the sketch below.
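
For illustration, a kernel along the following lines satisfies both
conditions and is vectorized using `vector.transfer_read` (a hand-written
sketch in the spirit of the new tests, not a verbatim copy of any of them):

  func.func @contiguous(%src: tensor<8x8xf32>,
                        %init: tensor<1x4xf32>) -> tensor<1x4xf32> {
    %res = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]
    } outs(%init : tensor<1x4xf32>) {
    ^bb0(%out: f32):
      // Non-trailing index: loop invariant, since dim 0 has trip count 1.
      %i = linalg.index 0 : index
      // Trailing index: increments by 1 with every iteration of the
      // trailing loop, so adjacent elements are loaded.
      %j = linalg.index 1 : index
      %v = tensor.extract %src[%i, %j] : tensor<8x8xf32>
      linalg.yield %v : f32
    } -> tensor<1x4xf32>
    return %res : tensor<1x4xf32>
  }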
This patch introduces:
  * `isLoopInvariantIdx` for step 1., and
  * `isContiguousLoadIdx` for step 2.
These new methods replace the previous helpers:
  * `isContiguousLoadIdx` (in its original form), and
  * `isBasedOnIndexOp`.

Both approaches lead to the same end result (none of the existing tests
required updating). However, with the updated approach it is much easier
to treat the trailing and non-trailing indices separately and to add
more cases for which contiguous loads can be used; one case that still
falls back to a gather load is sketched below.
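
Conversely, a trailing index that moves with a stride other than 1 is
still classified as a gather load. E.g., swapping the trailing index in
the sketch above for a scaled one (`arith.muli` is not in the set of
accepted ops, so `isContiguousLoadIdx` conservatively rejects it):

  %j = linalg.index 1 : index
  %c2 = arith.constant 2 : index
  // Stride-2 access - conservatively vectorized as a gather.
  %jj = arith.muli %j, %c2 : index
  %v = tensor.extract %src[%i, %jj] : tensor<16x16xf32>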

[1] https://reviews.llvm.org/D141998
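
For reference, the new tests drive this code path via the transform
dialect, roughly as follows (a sketch based on the transform dialect
syntax in use around this time; the exact handle types and op spellings
may differ):

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
    %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
  }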

Differential Revision: https://reviews.llvm.org/D145385
banach-space committed Mar 8, 2023
1 parent c7fcae5 commit 7a078b6
Showing 2 changed files with 293 additions and 65 deletions.
153 changes: 88 additions & 65 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -636,81 +636,112 @@ enum VectorMemoryAccessKind {
   Gather
 };
 
-/// Check whether /p val can be used for calculating an index for a contiguous
-/// load operation. This means that /p val should either:
-///   * be invariant with respect to /p linalgOp, or
-///   * increment by 1 with every loop iterator (when /p shouldBeConstant is
-///   false).
-/// Parameters /p trailingLoopDim and /p shouldBeConstant are used to analyze
-/// `linalg.index` ops.
-static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
-                                size_t trailingLoopDim, bool shouldBeConstant) {
-  auto *block = linalgOp.getBlock();
+/// Checks whether \p val can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
-  // Bail out if this is a block argument for this linalg.generic Op.
+  auto targetShape = linalgOp.getStaticLoopRanges();
+  assert(((llvm::count_if(targetShape,
+                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+         "n-D vectors are not yet supported");
+  assert(targetShape.back() != 1 &&
+         "1-D vectors with the trailing dim equal 1 are not yet supported");
+
+  // Blocks outside _this_ linalg.generic are effectively loop invariant.
+  // However, analysing block arguments for _this_ linalg.generic Op is a bit
+  // tricky. Just bail out in the latter case.
   // TODO: We could try analysing the corresponding affine map here.
-  if (val.dyn_cast<BlockArgument>())
+  auto *block = linalgOp.getBlock();
+  if (isa<BlockArgument>(val))
     return llvm::all_of(block->getArguments(),
                         [&val](Value v) { return (v != val); });
 
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // We know that we are reading into a 1-D tensor like this:
-  // `tensor<1x1x4xi32>`. Given this assumption, the following Op:
-  //    * `%idx = linalg.index dim : index`,
-  // will either:
-  //    1. produce a constant when `dim` _is not_ the trailing loop dim, or
-  //    2. increment with stride one when `dim` _is_ the trailing loop dim.
+  // IndexOp is loop invariant as long as its result remains constant across
+  // iterations. Given the assumptions on the loop ranges above, only the
+  // trailing loop dim ever changes.
+  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return shouldBeConstant ? (indexOp.getDim() != trailingLoopDim)
-                            : (indexOp.getDim() == trailingLoopDim);
+    return (indexOp.getDim() != trailingLoopDim);
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
-  // Values define outside `linalgOp`.
+  // Values defined outside `linalgOp` are loop invariant.
   if (!ancestor)
     return true;
 
-  // Values defined inside `linalgOp`, which are constant.
-  if (dyn_cast<arith::ConstantOp>(ancestor))
+  // Values defined inside `linalgOp`, which are constant, are loop invariant.
+  if (isa<arith::ConstantOp>(ancestor))
     return true;
 
   // Conservatively reject Ops that could lead to non-contiguous accesses.
   if (!isa<arith::AddIOp, arith::SubIOp, linalg::IndexOp>(ancestor))
     return false;
 
   bool result = true;
   for (auto op : ancestor->getOperands())
-    result &=
-        isContiguousLoadIdx(linalgOp, op, trailingLoopDim, shouldBeConstant);
+    result &= isLoopInvariantIdx(linalgOp, op);
 
   return result;
 }
 
-/// Check whether the calculation of \p val is based on linalg.index Op with
-/// the dim attribute matching \p dim.
-static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
-  auto *block = linalgOp.getBlock();
-  auto targetShape = linalgOp.getStaticLoopRanges();
+/// Check whether \p val could be used for calculating the trailing index for a
+/// contiguous load operation.
+///
+/// There are currently 3 types of values that are allowed here:
+///   1. loop-invariant values,
+///   2. values that increment by 1 with every loop iteration,
+///   3. results of basic arithmetic operations (linear and continuous)
+///      involving 1., 2. and 3.
+/// This method returns `true` if indeed only such values are used in
+/// calculating \p val.
+///
+/// Additionally, the trailing index for a contiguous load operation should
+/// increment by 1 with every loop iteration, i.e. be based on:
+///   * `linalg.index <dim>`,
+/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
+/// updated to `true` when such an op is found.
+static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
+                                bool &foundIndexOp) {
 
-  if (val.isa<BlockArgument>())
-    return false;
+  auto targetShape = linalgOp.getStaticLoopRanges();
+  assert(((llvm::count_if(targetShape,
+                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+         "n-D vectors are not yet supported");
+  assert(targetShape.back() != 1 &&
+         "1-D vectors with the trailing dim 1 are not yet supported");
+
+  // Blocks outside _this_ linalg.generic are effectively loop invariant.
+  // However, analysing block arguments for _this_ linalg.generic Op is a bit
+  // tricky. Just bail out in the latter case.
+  // TODO: We could try analysing the corresponding affine map here.
+  auto *block = linalgOp.getBlock();
+  if (isa<BlockArgument>(val))
+    return llvm::all_of(block->getArguments(),
+                        [&val](Value v) { return (v != val); });
 
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return (indexOp.getDim() == dim);
+  // Given the assumption on the loop ranges above, only the trailing loop
+  // index is not constant.
+  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
+    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+    return true;
+  }
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
   if (!ancestor)
     return false;
 
   // Conservatively reject Ops that could lead to indices with stride other
   // than 1.
   if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
           ancestor))
     return false;
 
   bool result = false;
   for (auto op : ancestor->getOperands())
-    result |= isBasedOnIndexOp(linalgOp, op, dim);
+    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
 
   return result;
 }
@@ -725,7 +756,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 
   auto targetShape = linalgOp.getStaticLoopRanges();
 
-  // Assume that it's a gather load when reading _into_:
+  // 1. Assume that it's a gather load when reading _into_:
   //    * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
   //    * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
   // TODO: Relax these conditions.
@@ -736,44 +767,36 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
 
   auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
 
-  // Assume that it's a gather load when reading _from_ a tensor for which the
-  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+  // 2. Assume that it's a gather load when reading _from_ a tensor for which
+  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
   // TODO: Relax this condition.
   if (inputShape.getShape().back() == 1)
     return VectorMemoryAccessKind::Gather;
 
-  // The trailing loop dim is needed when analyzing ops like:
-  //   * %idx = `linalg.index <dim> : index`.
-  auto trailingLoopDim = targetShape.size() - 1;
-
   bool isContiguous = true;
 
-  // Iterate over all indices. Analyze the way each index is calculated and
-  // decide whether it is suitable for a contiguous load (e.g. loop invariant).
+  // 3a. Analyze the leading indices of `extractOp`.
+  // Look at the way each index is calculated and decide whether it is suitable
+  // for a contiguous load, i.e. whether it's loop invariant.
   auto indices = extractOp.getIndices();
-  for (auto [i, indexVal] : llvm::enumerate(indices)) {
-    if (inputShape.getShape()[i] == 1) {
-      // This index will always be equal 0, so it is a loop-invariant constant.
-      continue;
-    }
+  auto leadIndices = ValueRange(indices.drop_back(1));
 
-    // Should this index be loop invariant?
-    //   * _no_ if this is the trailing index,
-    //   * _yes_ otherwise.
-    auto extractOpBottomIdx = indices.size() - 1;
-    bool loopInvariantIndex = (i != extractOpBottomIdx);
+  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
+    if (inputShape.getShape()[i] == 1)
+      continue;
 
-    isContiguous &= isContiguousLoadIdx(linalgOp, indexVal, trailingLoopDim,
-                                        loopInvariantIndex);
+    isContiguous &= isLoopInvariantIdx(linalgOp, indexVal);
   }
 
-  // The trailing index in the extract Op must increment with every iteration,
-  // which means that it must be based on a loop index. Given the assumption
-  // on the output tensor, only the trailing loop index is not constant, so
-  // that's what we need to check against.
+  // 3b. Analyze the trailing index for `extractOp`.
   auto extractOpTrailingIdx = indices.back();
+  // For contiguous loads, the trailing `extractOp` index should increment with
+  // every loop iteration. This effectively means that it must be based on the
+  // trailing loop index. This is what the following bool captures.
+  bool foundIndexOp = false;
   isContiguous &=
-      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, trailingLoopDim);
+      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+  isContiguous &= foundIndexOp;
 
   if (isContiguous) {
     LDBG("Found contiguous load: " << extractOp);
