Skip to content

Commit

Permalink
[mlir][Linalg][Vector] Add forwarding patterns between linalg.copy an…
Browse files Browse the repository at this point in the history
…d vector.transfer

This revision adds custom rewrites for patterns that arise during linalg structured
ops vectorization. These patterns allow the composition of linalg promotion,
vectorization and removal of redundant copies.

The patterns are voluntarily limited and restrictive atm.
More robust behavior will be implemented once more powerful side effect modeling and analyses are available on view/subview.

On the transfer_read side, the following pattern is rewritten:
```
   %alloc = ...
   [optional] %view = std.view %alloc ...
   %subView = subview %allocOrView ...
   [optional] linalg.fill(%allocOrView, %cst) ...
   ...
   linalg.copy(%in, %subView) ...
   vector.transfer_read %allocOrView[...], %cst ...
```
into
```
   [unchanged] %alloc = ...
   [unchanged] [optional] %view = std.view %alloc ...
   [unchanged] [unchanged] %subView = subview %allocOrView ...
   ...
   vector.transfer_read %in[...], %cst ...
```

On the transfer_write side, the following pattern is rewriten:
```
   %alloc = ...
   [optional] %view = std.view %alloc ...
   %subView = subview %allocOrView...
   ...
   vector.transfer_write %..., %allocOrView[...]
   linalg.copy(%subView, %out)
```

Differential Revision: https://reviews.llvm.org/D80728
  • Loading branch information
Nicolas Vasilache committed May 29, 2020
1 parent ea7db62 commit 1ee1143
Show file tree
Hide file tree
Showing 4 changed files with 443 additions and 46 deletions.
68 changes: 68 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
#include "llvm/ADT/SmallBitVector.h"

namespace mlir {
namespace vector {

class TransferReadOp;
class TransferWriteOp;

} // namespace vector

namespace linalg {

struct LinalgTilingOptions;
Expand Down Expand Up @@ -437,6 +444,67 @@ struct LinalgLoweringPattern : public RewritePattern {
LinalgLoweringType loweringType;
};

//===----------------------------------------------------------------------===//
// Op-specific patterns.
//===----------------------------------------------------------------------===//
/// Match and rewrite for the pattern:
/// ```
/// %alloc = ...
/// [optional] %view = std.view %alloc ...
/// %subView = subview %allocOrView ...
/// [optional] linalg.fill(%allocOrView, %cst) ...
/// ...
/// linalg.copy(%in, %subView) ...
/// vector.transfer_read %allocOrView[...], %cst ...
/// ```
/// into
/// ```
/// [unchanged] %alloc = ...
/// [unchanged] [optional] %view = std.view %alloc ...
/// [unchanged] [unchanged] %subView = subview %allocOrView ...
/// ...
/// vector.transfer_read %in[...], %cst ...
/// ```
/// Where there is no interleaved use between linalg.copy and transfer_read as
/// well as no interleaved use between linalg.fill and linalg.copy (if
/// linalg.fill is specified).
/// This is a custom rewrite to forward partial reads (with optional fills) to
/// vector.transfer_read.
struct LinalgCopyVTRForwardingPattern
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
PatternRewriter &rewriter) const override;
};

/// Match and rewrite for the pattern:
/// ```
/// %alloc = ...
/// [optional] %view = std.view %alloc ...
/// %subView = subview %allocOrView...
/// ...
/// vector.transfer_write %..., %allocOrView[...]
/// linalg.copy(%subView, %out)
/// ```
/// into
/// ```
/// [unchanged] %alloc = ...
/// [unchanged] [optional] %view = std.view %alloc ...
/// [unchanged] %subView = subview %allocOrView...
/// ...
/// vector.transfer_write %..., %out[...]
/// ```
/// Where there is no interleaved use between transfer_write and linalg.copy.
/// This is a custom rewrite to forward partial writes to vector.transfer_write.
struct LinalgCopyVTWForwardingPattern
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
PatternRewriter &rewriter) const override;
};

//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
Expand Down
177 changes: 171 additions & 6 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,13 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
llvm_unreachable("Unexpected conv with padding");
}

StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
(void)dbgPref;
edsc::ScopedContext scope(builder, op->getLoc());
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: Rewrite linalg.fill as vector.broadcast: "
<< *op << ":\n");
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
Value dst = std_load(memref);
Value res = vector_broadcast(dst.getType(), fillOp.value());
Expand All @@ -117,9 +118,8 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
}

// Vectorize other ops as vector contraction (currently only matmul).
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: Rewrite linalg op as vector.contract: "
<< *op << ":\n");
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg op as vector.contract: " << *op);
auto linalgOp = cast<linalg::LinalgOp>(op);
Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
Expand All @@ -129,3 +129,168 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
linalgOp.iterator_types());
std_store(res, memref);
}

/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
ValueRange values) {
StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
(void)dbgPref;
if (firstOp->getBlock() != secondOp->getBlock() ||
!firstOp->isBeforeInBlock(secondOp)) {
LLVM_DEBUG(llvm::dbgs()
<< dbgPref << "interleavedUses precondition failed, firstOp: "
<< *firstOp << ", second op: " << *secondOp);
return true;
}
for (auto v : values) {
for (auto &u : v.getUses()) {
Operation *owner = u.getOwner();
if (owner == firstOp || owner == secondOp)
continue;
// TODO: this is too conservative, use dominance info in the future.
if (owner->getBlock() == firstOp->getBlock() &&
(owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
continue;
LLVM_DEBUG(llvm::dbgs()
<< dbgPref << " found interleaved op " << *owner
<< ", firstOp: " << *firstOp << ", second op: " << *secondOp);
return true;
}
}
return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
SubViewOp subViewOp;
for (auto &u : v.getUses()) {
if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
if (subViewOp)
return SubViewOp();
subViewOp = newSubViewOp;
}
}
return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

// Transfer into `view`.
Value viewOrAlloc = xferOp.memref();
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
!viewOrAlloc.getDefiningOp<AllocOp>())
return failure();

StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
(void)dbgPref;
LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
if (!subViewOp)
return failure();
Value subView = subViewOp.getResult();
LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

// Find the copy into `subView` without interleaved uses.
CopyOp copyOp;
for (auto &u : subView.getUses()) {
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
if (newCopyOp.getOutputBuffer(0) != subView)
continue;
LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
continue;
copyOp = newCopyOp;
break;
}
}
if (!copyOp)
return failure();
LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

// Find the fill into `viewOrAlloc` without interleaved uses before the copy.
FillOp maybeFillOp;
for (auto &u : viewOrAlloc.getUses()) {
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
continue;
LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
continue;
maybeFillOp = newFillOp;
break;
}
}
// Ensure padding matches.
if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
return failure();
if (maybeFillOp)
LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

// `in` is the subview that linalg.copy reads. Replace it.
Value in = copyOp.getInput(0);

Value res = rewriter.create<vector::TransferReadOp>(
xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
xferOp.permutation_map(), xferOp.padding(),
xferOp.masked() ? *xferOp.masked() : ArrayAttr());

if (maybeFillOp)
rewriter.eraseOp(maybeFillOp);
rewriter.eraseOp(copyOp);
rewriter.replaceOp(xferOp, res);

return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
// Transfer into `viewOrAlloc`.
Value viewOrAlloc = xferOp.memref();
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
!viewOrAlloc.getDefiningOp<AllocOp>())
return failure();

// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
if (!subViewOp)
return failure();
Value subView = subViewOp.getResult();

// Find the copy from `subView` without interleaved uses.
CopyOp copyOp;
for (auto &u : subViewOp.getResult().getUses()) {
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
if (newCopyOp.getInput(0) != subView)
continue;
if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
continue;
copyOp = newCopyOp;
break;
}
}
if (!copyOp)
return failure();

// `out` is the subview copied into that we replace.
Value out = copyOp.getOutputBuffer(0);

// Forward vector.transfer into copy.
rewriter.create<vector::TransferWriteOp>(
xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
xferOp.permutation_map(),
xferOp.masked() ? *xferOp.masked() : ArrayAttr());

rewriter.eraseOp(copyOp);
rewriter.eraseOp(xferOp);

return success();
}
Loading

0 comments on commit 1ee1143

Please sign in to comment.