Skip to content

Commit

Permalink
[mlir][vector] Add patterns for vector distribution
Browse files Browse the repository at this point in the history
Add pattern to hoist scalar code outside of warp distribute region as
those cannot be distributed and we would want to execute them on all
the lanes.
Add patterns to distribute transfer_write ops. Those operations can be
distributed in different ways and it is control by user.

Differential Revision: https://reviews.llvm.org/D127152
  • Loading branch information
ThomasRaoux committed Jun 10, 2022
1 parent b5019ff commit ed0288f
Show file tree
Hide file tree
Showing 4 changed files with 374 additions and 3 deletions.
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
Expand Up @@ -39,6 +39,32 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
RewritePatternSet &patterns,
const WarpExecuteOnLane0LoweringOptions &options);

using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
/// ...
/// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
/// vector.yield
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
/// ...
/// vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
void populateDistributeTransferWriteOpPatterns(
RewritePatternSet &patterns, DistributionMapFn distributionMapFn);

/// Move scalar operations with no dependency on the warp op outside of the
/// region.
void moveScalarUniformCode(WarpExecuteOnLane0Op op);

} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
238 changes: 236 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Expand Up @@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Transforms/SideEffectUtils.h"

using namespace mlir;
using namespace mlir::vector;
Expand Down Expand Up @@ -93,8 +95,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
if (resultType == val.getType()) {
// Result type and yielded value type are the same. This is a broadcast.
// E.g.:
// %r = vector_ext.warp_execute_on_lane_0(...) -> (f32) {
// vector_ext.yield %cst : f32
// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
// vector.yield %cst : f32
// }
// Both types are f32. The constant %cst is broadcasted to all lanes.
// This is described in more detail in the documentation of the op.
Expand Down Expand Up @@ -131,6 +133,54 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes) {
// Create a new op before the existing one, with the extra operands.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(warpOp);
auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

Region &opBody = warpOp.getBodyRegion();
Region &newOpBody = newWarpOp.getBodyRegion();
rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
auto yield =
cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

rewriter.updateRootInPlace(
yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
ValueRange newYieldedValues, TypeRange newReturnTypes) {
SmallVector<Type> types(warpOp.getResultTypes().begin(),
warpOp.getResultTypes().end());
types.append(newReturnTypes.begin(), newReturnTypes.end());
auto yield = cast<vector::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
SmallVector<Value> yieldValues(yield.getOperands().begin(),
yield.getOperands().end());
yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, yieldValues, types);
rewriter.replaceOp(warpOp,
newWarpOp.getResults().take_front(warpOp.getNumResults()));
return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
function_ref<bool(Value)> definedOutside) {
return llvm::all_of(op->getOperands(), definedOutside) &&
isSideEffectFree(op) && op->getNumRegions() == 0;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
Expand All @@ -149,10 +199,194 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
const WarpExecuteOnLane0LoweringOptions &options;
};

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
/// ...
/// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
/// vector.yield
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
/// ...
/// vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
PatternBenefit b = 1)
: OpRewritePattern<vector::TransferWriteOp>(ctx, b),
distributionMapFn(fn) {}

/// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
/// are multiples of the distribution ratio are supported at the moment.
LogicalResult tryDistributeOp(RewriterBase &rewriter,
vector::TransferWriteOp writeOp,
WarpExecuteOnLane0Op warpOp) const {
AffineMap map = distributionMapFn(writeOp);
SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
writeOp.getVectorType().getShape().end());
assert(map.getNumResults() == 1 &&
"multi-dim distribution not implemented yet");
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
if (targetShape[position] % warpOp.getWarpSize() != 0)
return failure();
targetShape[position] = targetShape[position] / warpOp.getWarpSize();
}
VectorType targetType =
VectorType::get(targetShape, writeOp.getVectorType().getElementType());

SmallVector<Value> yieldValues = {writeOp.getVector()};
SmallVector<Type> retTypes = {targetType};
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, yieldValues, retTypes);
rewriter.setInsertionPointAfter(newWarpOp);

// Move op outside of region: Insert clone at the insertion point and delete
// the old op.
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
rewriter.eraseOp(writeOp);

rewriter.setInsertionPoint(newWriteOp);
AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
Location loc = newWriteOp.getLoc();
SmallVector<Value> indices(newWriteOp.getIndices().begin(),
newWriteOp.getIndices().end());
for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
AffineExpr d0, d1;
bindDims(newWarpOp.getContext(), d0, d1);
auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
if (!indexExpr)
continue;
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
auto scale =
getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
indices[indexPos] =
makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
{indices[indexPos], newWarpOp.getLaneid()});
}
newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
newWriteOp.getIndicesMutable().assign(indices);

return success();
}

/// Extract TransferWriteOps of vector<1x> into a separate warp op.
LogicalResult tryExtractOp(RewriterBase &rewriter,
vector::TransferWriteOp writeOp,
WarpExecuteOnLane0Op warpOp) const {
Location loc = writeOp.getLoc();
VectorType vecType = writeOp.getVectorType();

// Only vector<1x> is supported at the moment.
if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
return failure();

// Do not process warp ops that contain only TransferWriteOps.
if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
}))
return failure();

SmallVector<Value> yieldValues = {writeOp.getVector()};
SmallVector<Type> retTypes = {vecType};
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, yieldValues, retTypes);
rewriter.setInsertionPointAfter(newWarpOp);

// Create a second warp op that contains only writeOp.
auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
Block &body = secondWarpOp.getBodyRegion().front();
rewriter.setInsertionPointToStart(&body);
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
newWriteOp.getVectorMutable().assign(
newWarpOp.getResult(newWarpOp.getNumResults() - 1));
rewriter.eraseOp(writeOp);
rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
return success();
}

LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
// Ops with mask not supported yet.
if (writeOp.getMask())
return failure();

auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
if (!warpOp)
return failure();

// There must be no op with a side effect after writeOp.
Operation *nextOp = writeOp.getOperation();
while ((nextOp = nextOp->getNextNode()))
if (!isSideEffectFree(nextOp))
return failure();

if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
return writeOp.getVector() == value ||
warpOp.isDefinedOutsideOfRegion(value);
}))
return failure();

if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
return success();

if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
return success();

return failure();
}

private:
DistributionMapFn distributionMapFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
RewritePatternSet &patterns,
const WarpExecuteOnLane0LoweringOptions &options) {
patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
Block *body = warpOp.getBody();

// Keep track of the ops we want to hoist.
llvm::SmallSetVector<Operation *, 8> opsToMove;

// Helper to check if a value is or will be defined outside of the region.
auto isDefinedOutsideOfBody = [&](Value value) {
auto *definingOp = value.getDefiningOp();
return (definingOp && opsToMove.count(definingOp)) ||
warpOp.isDefinedOutsideOfRegion(value);
};

// Do not use walk here, as we do not want to go into nested regions and hoist
// operations from there.
for (auto &op : body->without_terminator()) {
bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
return result.getType().isa<VectorType>();
});
if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
opsToMove.insert(&op);
}

// Move all the ops marked as uniform outside of the region.
for (Operation *op : opsToMove)
op->moveBefore(warpOp);
}
75 changes: 75 additions & 0 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1,4 +1,6 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s

// CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
// CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
Expand Down Expand Up @@ -52,3 +54,76 @@ func.func @rewrite_warp_op_to_scf_if(%laneid: index,
"some_use"(%r#1) : (vector<2xf32>) -> ()
return
}

// -----

// CHECK-D-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 * 2 + 32)>

// CHECK-ALL-LABEL: func @warp(
// CHECK-HOIST: memref.subview
// CHECK-HOIST: memref.subview
// CHECK-HOIST: memref.subview
// CHECK-HOIST: vector.warp_execute_on_lane_0

// CHECK-D: %[[R:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>, vector<1xf32>) {
// CHECK-D: arith.addf {{.*}} : vector<32xf32>
// CHECK-D: arith.addf {{.*}} : vector<64xf32>
// CHECK-D: vector.yield %{{.*}}, %{{.*}} : vector<64xf32>, vector<32xf32>
// CHECK-D-DAG: vector.transfer_write %[[R]]#1, %{{.*}}[%{{.*}}] {in_bounds = [true]} : vector<1xf32>, memref<128xf32
// CHECK-D-DAG: %[[ID1:.*]] = affine.apply #[[MAP1]]()[%{{.*}}]
// CHECK-D-DAG: vector.transfer_write %[[R]]#0, %2[%[[ID1]]] {in_bounds = [true]} : vector<2xf32>, memref<128xf32

// CHECK-ALL-NOT: vector.warp_execute_on_lane_0
// CHECK-ALL: vector.transfer_read {{.*}} vector<1xf32>
// CHECK-ALL: vector.transfer_read {{.*}} vector<1xf32>
// CHECK-ALL: vector.transfer_read {{.*}} vector<2xf32>
// CHECK-ALL: vector.transfer_read {{.*}} vector<2xf32>
// CHECK-ALL: arith.addf {{.*}} : vector<1xf32>
// CHECK-ALL: arith.addf {{.*}} : vector<2xf32>
// CHECK-ALL: vector.transfer_write {{.*}} : vector<1xf32>
// CHECK-ALL: vector.transfer_write {{.*}} : vector<2xf32>

#map0 = affine_map<(d0)[s0] -> (d0 + s0)>
func.func @warp(%laneid: index, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>,
%arg3: memref<1024xf32>, %gid : index) {
vector.warp_execute_on_lane_0(%laneid)[32] {
%sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
%sb = memref.subview %arg2[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
%sc = memref.subview %arg3[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0>
%c0 = arith.constant 0 : index
%c32 = arith.constant 32 : index
%cst = arith.constant 0.000000e+00 : f32
%2 = vector.transfer_read %sa[%c0], %cst : memref<128xf32, #map0>, vector<32xf32>
%3 = vector.transfer_read %sa[%c32], %cst : memref<128xf32, #map0>, vector<32xf32>
%4 = vector.transfer_read %sb[%c0], %cst : memref<128xf32, #map0>, vector<64xf32>
%5 = vector.transfer_read %sb[%c32], %cst : memref<128xf32, #map0>, vector<64xf32>
%6 = arith.addf %2, %3 : vector<32xf32>
%7 = arith.addf %4, %5 : vector<64xf32>
vector.transfer_write %6, %sc[%c0] : vector<32xf32>, memref<128xf32, #map0>
vector.transfer_write %7, %sc[%c32] : vector<64xf32>, memref<128xf32, #map0>
}
return
}

// -----

// CHECK-D-LABEL: func @warp_extract(
// CHECK-D: %[[WARPOP:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>)
// CHECK-D: "test.dummy_op"
// CHECK-D: vector.yield %{{.*}} : vector<1xf32>
// CHECK-D: }
// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] {
// CHECK-D: vector.transfer_write %[[WARPOP]], %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32>
// CHECK-D: }

#map2 = affine_map<(d0)[s0] -> (d0 + s0)>

func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) {
vector.warp_execute_on_lane_0(%laneid)[32] {
%sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map2>
%c0 = arith.constant 0 : index
%v = "test.dummy_op"() : () -> (vector<1xf32>)
vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2>
}
return
}

0 comments on commit ed0288f

Please sign in to comment.