[mlir][Vector] Add transformation + pattern to split vector.transfer_read into full and partial copies.

This revision adds a transformation and a pattern that rewrites a "maybe masked" `vector.transfer_read %view[...], %pad` into IR resembling:

```
   %1:3 = scf.if (%inBounds) {
      scf.yield %view : memref<A...>, index, index
   } else {
      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
      %3 = vector.type_cast %extra_alloc : memref<...> to memref<vector<...>>
      store %2, %3[] : memref<vector<...>>
      %4 = memref_cast %extra_alloc : memref<B...> to memref<A...>
      scf.yield %4 : memref<A...>, index, index
   }
   %res = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
```
where `extra_alloc` is a buffer `alloca`'ed at the top of the function, sized to hold a single vector.

This rewrite makes it possible to realize the "always full tile" abstraction where vector.transfer_read operations are guaranteed to read from a padded full buffer.
The extra work only occurs on the boundary tiles.
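For illustration, here is a minimal sketch of how the new pattern could be exercised from a greedy-rewrite function pass. Only `VectorTransferFullPartialRewriter`, its optional filter hook, and `VectorTransferOpInterface` come from this change; the pass name, the rank-2 filter, and the surrounding boilerplate are hypothetical and simply follow the usual MLIR test-pass conventions of this period:

```cpp
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
// Hypothetical test pass: split every candidate vector.transfer_read in the
// function into the scf.if-based full/partial form shown above.
struct TestVectorTransferFullPartialSplit
    : public PassWrapper<TestVectorTransferFullPartialSplit, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    // Optional filter: only split 2-D transfers (purely illustrative; omit
    // the filter to split every op that passes the precondition).
    auto filter = [](VectorTransferOpInterface op) {
      return success(op.getVectorType().getRank() == 2);
    };
    patterns.insert<vector::VectorTransferFullPartialRewriter>(&getContext(),
                                                               filter);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};
} // namespace
```

Omitting the filter splits every qualifying transfer; `splitFullAndPartialTransfer` can also be called directly with an `OpBuilder` when the caller needs a handle on the generated `scf.if`.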

Differential Revision: https://reviews.llvm.org/D84631
Nicolas Vasilache committed Aug 3, 2020
1 parent 4b1b109 commit d313e9c
Showing 6 changed files with 430 additions and 0 deletions.
64 changes: 64 additions & 0 deletions mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -17,6 +17,11 @@
namespace mlir {
class MLIRContext;
class OwningRewritePatternList;
class VectorTransferOpInterface;

namespace scf {
class IfOp;
} // namespace scf

/// Collect a set of patterns to convert from the Vector dialect to itself.
/// Should be merged with populateVectorToSCFLoweringPattern.
@@ -104,6 +109,65 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
FilterConstraintType filter;
};

/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
/// is `success`, the `ifOp` points to the newly created conditional upon
/// function return. To accommodate the fact that the original
/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
/// in the temporary buffer, the scf.if op returns a view and values of type
/// index. At this time, only vector.transfer_read is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
/// scf.yield %0 : memref<A...>, index, index
/// } else {
/// %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// %3 = vector.type_cast %extra_alloc : memref<...> to memref<vector<...>>
/// store %2, %3[] : memref<vector<...>>
/// %4 = memref_cast %extra_alloc : memref<B...> to memref<A...>
/// scf.yield %4 : memref<A...>, index, index
/// }
/// %res = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
LogicalResult splitFullAndPartialTransfer(OpBuilder &b,
VectorTransferOpInterface xferOp,
scf::IfOp *ifOp = nullptr);

/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
using FilterConstraintType =
std::function<LogicalResult(VectorTransferOpInterface op)>;

explicit VectorTransferFullPartialRewriter(
MLIRContext *context,
FilterConstraintType filter =
[](VectorTransferOpInterface op) { return success(); },
PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}

/// Performs the rewrite.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;

private:
FilterConstraintType filter;
};

} // namespace vector

//===----------------------------------------------------------------------===//
13 changes: 13 additions & 0 deletions mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -160,6 +160,19 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*defaultImplementation=*/
"return $_op.getMemRefType().getRank() - $_op.getTransferRank();"
>,
InterfaceMethod<
/*desc=*/[{ Returns true if at least one of the dimensions is masked.}],
/*retTy=*/"bool",
/*methodName=*/"hasMaskedDim",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
for (unsigned idx = 0, e = $_op.getTransferRank(); idx < e; ++idx)
if ($_op.isMaskedDim(idx))
return true;
return false;
}]
>,
InterfaceMethod<
/*desc=*/[{
Helper function to account for the fact that `permutationMap` results and
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVector
MLIRVectorOpsIncGen

LINK_LIBS PUBLIC
MLIRAffineEDSC
MLIREDSC
MLIRIR
MLIRStandardOps
234 changes: 234 additions & 0 deletions mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -12,9 +12,13 @@

#include <type_traits>

#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
@@ -1985,6 +1989,236 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,

} // namespace mlir

static Optional<int64_t> extractConstantIndex(Value v) {
if (auto cstOp = v.getDefiningOp<ConstantIndexOp>())
return cstOp.getValue();
if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
if (affineApplyOp.getAffineMap().isSingleConstant())
return affineApplyOp.getAffineMap().getSingleConstantResult();
return None;
}

// Missing foldings of scf.if make it necessary to perform poor man's folding
// eagerly, especially in the case of unrolling. In the future, this should go
// away once scf.if folds properly.
static Value createScopedFoldedSLE(Value v, Value ub) {
using namespace edsc::op;
auto maybeCstV = extractConstantIndex(v);
auto maybeCstUb = extractConstantIndex(ub);
if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
return Value();
return sle(v, ub);
}

// Operates under a scoped context to build the condition to ensure that a
// particular VectorTransferOpInterface is unmasked.
static Value createScopedInBoundsCond(VectorTransferOpInterface xferOp) {
assert(xferOp.permutation_map().isMinorIdentity() &&
"Expected minor identity map");
Value inBoundsCond;
xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
// Zip over the resulting vector shape and memref indices.
// If the dimension is known to be unmasked, it does not participate in the
// construction of `inBoundsCond`.
if (!xferOp.isMaskedDim(resultIdx))
return;
int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
using namespace edsc::op;
using namespace edsc::intrinsics;
// Fold or create the check that `index + vector_size` <= `memref_size`.
Value sum = xferOp.indices()[indicesIdx] + std_constant_index(vectorSize);
Value cond =
createScopedFoldedSLE(sum, std_dim(xferOp.memref(), indicesIdx));
if (!cond)
return;
// Conjunction over all dims for which we are in-bounds.
inBoundsCond = inBoundsCond ? inBoundsCond && cond : cond;
});
return inBoundsCond;
}

LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition(
VectorTransferOpInterface xferOp) {
// TODO: expand support to these 2 cases.
if (!xferOp.permutation_map().isMinorIdentity())
return failure();
// TODO: relax this precondition. This will require rank-reducing subviews.
if (xferOp.getMemRefType().getRank() != xferOp.getTransferRank())
return failure();
// Must have some masked dimension to be a candidate for splitting.
if (!xferOp.hasMaskedDim())
return failure();
// Don't split transfer operations under an scf.IfOp; this avoids applying
// the pattern recursively.
// TODO: improve the condition to make it more applicable.
if (xferOp.getParentOfType<scf::IfOp>())
return failure();
return success();
}

/// Compute a MemRefType to which both `aT` and `bT` can be cast: matching
/// static sizes and strides are preserved, mismatches become dynamic. Returns
/// a null MemRefType if no such type can be constructed.
MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
if (MemRefCastOp::areCastCompatible(aT, bT))
return aT;
if (aT.getRank() != bT.getRank())
return MemRefType();
int64_t aOffset, bOffset;
SmallVector<int64_t, 4> aStrides, bStrides;
if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
aStrides.size() != bStrides.size())
return MemRefType();

ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
int64_t resOffset;
SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
resStrides(bT.getRank(), 0);
for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
resShape[idx] =
(aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
resStrides[idx] = (aStrides[idx] == bStrides[idx])
? aStrides[idx]
: MemRefType::kDynamicStrideOrOffset;
}
resOffset =
(aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
return MemRefType::get(
resShape, aT.getElementType(),
makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
}

/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
/// is `success`, the `ifOp` points to the newly created conditional upon
/// function return. To accommodate the fact that the original
/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
/// in the temporary buffer, the scf.if op returns a view and values of type
/// index. At this time, only vector.transfer_read is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
/// %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
/// %1:3 = scf.if (%inBounds) {
/// scf.yield %0 : memref<A...>, index, index
/// } else {
/// %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// %3 = vector.type_cast %extra_alloc : memref<...> to memref<vector<...>>
/// store %2, %3[] : memref<vector<...>>
/// %4 = memref_cast %extra_alloc : memref<B...> to memref<A...>
/// scf.yield %4 : memref<A...>, index, index
/// }
/// %res = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
/// ```
/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector.
///
/// Preconditions:
/// 1. `xferOp.permutation_map()` must be a minor identity map
/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
/// must be equal. This will be relaxed in the future but requires
/// rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
OpBuilder &b, VectorTransferOpInterface xferOp, scf::IfOp *ifOp) {
using namespace edsc;
using namespace edsc::intrinsics;

assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
"Expected splitFullAndPartialTransferPrecondition to hold");
auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());

// TODO: add support for write case.
if (!xferReadOp)
return failure();

OpBuilder::InsertionGuard guard(b);
if (xferOp.memref().getDefiningOp())
b.setInsertionPointAfter(xferOp.memref().getDefiningOp());
else
b.setInsertionPoint(xferOp);
ScopedContext scope(b, xferOp.getLoc());
Value inBoundsCond = createScopedInBoundsCond(
cast<VectorTransferOpInterface>(xferOp.getOperation()));
if (!inBoundsCond)
return failure();

// Create a top-of-the-function `alloca` for transient storage.
Value alloc;
{
FuncOp funcOp = xferOp.getParentOfType<FuncOp>();
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(&funcOp.getRegion().front());
auto shape = xferOp.getVectorType().getShape();
Type elementType = xferOp.getVectorType().getElementType();
alloc = std_alloca(MemRefType::get(shape, elementType), ValueRange{},
b.getI64IntegerAttr(32));
}

Value memref = xferOp.memref();
SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
auto unmaskedAttr = b.getBoolArrayAttr(bools);

MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());

// Read case: full fill + partial copy -> unmasked vector.xfer_read.
Value zero = std_constant_index(0);
SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
b.getIndexType());
returnTypes[0] = compatibleMemRefType;
scf::IfOp fullPartialIfOp;
conditionBuilder(
returnTypes, inBoundsCond,
[&]() -> scf::ValueVector {
Value res = memref;
if (compatibleMemRefType != xferOp.getMemRefType())
res = std_memref_cast(memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
xferOp.indices().end());
return viewAndIndices;
},
[&]() -> scf::ValueVector {
Operation *newXfer =
ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
std_store(vector, vector_type_cast(
MemRefType::get({}, vector.getType()), alloc));

Value casted = std_memref_cast(alloc, compatibleMemRefType);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);

return viewAndIndices;
},
&fullPartialIfOp);
if (ifOp)
*ifOp = fullPartialIfOp;

// Unmask the existing read op; it now always reads from a full buffer.
for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);

return success();
}

LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
failed(filter(xferOp)))
return failure();
rewriter.startRootUpdate(xferOp);
if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp))) {
rewriter.finalizeRootUpdate(xferOp);
return success();
}
rewriter.cancelRootUpdate(xferOp);
return failure();
}

// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(