Skip to content

Commit

Permalink
[mlir][vector] Distribute vector.insert op
Browse files Browse the repository at this point in the history
In case the distributed dim of the dest vector is also a dim of the src vector, each lane inserts a smaller part of the source vector. Otherwise, one lane inserts the entire src vector and the other lanes do nothing.

Differential Revision: https://reviews.llvm.org/D137953
  • Loading branch information
matthias-springer committed Jan 9, 2023
1 parent 6a6f62a commit 1523b72
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 2 deletions.
129 changes: 127 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,131 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};

struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand = getWarpResult(
warpOp, [](Operation *op) { return isa<vector::InsertOp>(op); });
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
Location loc = insertOp.getLoc();

// "vector.insert %v, %v[] : ..." can be canonicalized to %v.
if (insertOp.getPosition().empty())
return failure();

// Rewrite vector.insert with 1d dest to vector.insertelement.
if (insertOp.getDestVectorType().getRank() == 1) {
assert(insertOp.getPosition().size() == 1 && "expected 1 index");
int64_t pos = insertOp.getPosition()[0].cast<IntegerAttr>().getInt();
rewriter.setInsertionPoint(insertOp);
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
insertOp, insertOp.getSource(), insertOp.getDest(),
rewriter.create<arith::ConstantIndexOp>(loc, pos));
return success();
}

if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
// There is no distribution, this is a broadcast. Simply move the insert
// out of the warp op.
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
{insertOp.getSourceType(), insertOp.getDestVectorType()},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
Value newResult = rewriter.create<vector::InsertOp>(
loc, distributedSrc, distributedDest, insertOp.getPosition());
newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
return success();
}

// Find the distributed dimension. There should be exactly one.
auto distrDestType =
warpOp.getResult(operandNumber).getType().cast<VectorType>();
auto yieldedType = operand->get().getType().cast<VectorType>();
int64_t distrDestDim = -1;
for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
// Keep this assert here in case WarpExecuteOnLane0Op gets extended to
// support distributing multiple dimensions in the future.
assert(distrDestDim == -1 && "found multiple distributed dims");
distrDestDim = i;
}
}
assert(distrDestDim != -1 && "could not find distributed dimension");

// Compute the distributed source vector type.
VectorType srcVecType = insertOp.getSourceType().cast<VectorType>();
SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
srcVecType.getShape().end());
// E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
// Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
// insert a smaller vector<3xf32>.
// Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
// case, one lane will insert the source vector<96xf32>. The other
// lanes will not do anything.
int64_t distrSrcDim = distrDestDim - insertOp.getPosition().size();
if (distrSrcDim >= 0)
distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
auto distrSrcType =
VectorType::get(distrSrcShape, distrDestType.getElementType());

// Yield source and dest vectors from warp op.
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
{distrSrcType, distrDestType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

// Insert into the distributed vector.
Value newResult;
if (distrSrcDim >= 0) {
// Every lane inserts a small piece.
newResult = rewriter.create<vector::InsertOp>(
loc, distributedSrc, distributedDest, insertOp.getPosition());
} else {
// One lane inserts the entire source vector.
int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
SmallVector<int64_t> newPos = llvm::to_vector(
llvm::map_range(insertOp.getPosition(), [](Attribute attr) {
return attr.cast<IntegerAttr>().getInt();
}));
// tid of inserting lane: pos / elementsPerLane
Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
loc, newPos[distrDestDim] / elementsPerLane);
Value isInsertingLane = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
// Insert position: pos % elementsPerLane
newPos[distrDestDim] %= elementsPerLane;
auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
Value newInsert = builder.create<vector::InsertOp>(
loc, distributedSrc, distributedDest, newPos);
builder.create<scf::YieldOp>(loc, newInsert);
};
auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
builder.create<scf::YieldOp>(loc, distributedDest);
};
newResult = rewriter
.create<scf::IfOp>(loc, distrDestType, isInsertingLane,
/*thenBuilder=*/insertingBuilder,
/*elseBuilder=*/nonInsertingBuilder)
.getResult(0);
}

newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
return success();
}
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
Expand Down Expand Up @@ -1387,8 +1512,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
WarpOpConstant, WarpOpInsertElement>(patterns.getContext(),
benefit);
WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractElement>(patterns.getContext(),
warpShuffleFromIdxFn, benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
Expand Down
95 changes: 95 additions & 0 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -994,3 +994,98 @@ func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
}
return %r : vector<f32>
}

// -----

// CHECK-PROP-LABEL: func @vector_insert_1d(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index
// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-PROP-DAG: %[[C26:.*]] = arith.constant 26 : index
// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32)
// CHECK-PROP: %[[VEC:.*]] = "some_def"
// CHECK-PROP: %[[VAL:.*]] = "another_def"
// CHECK-PROP: vector.yield %[[VEC]], %[[VAL]]
// CHECK-PROP: %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[C26]]
// CHECK-PROP: %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) {
// CHECK-PROP: %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[C1]] : index]
// CHECK-PROP: scf.yield %[[INSERT]]
// CHECK-PROP: } else {
// CHECK-PROP: scf.yield %[[W]]#0
// CHECK-PROP: }
// CHECK-PROP: return %[[R]]
func.func @vector_insert_1d(%laneid: index) -> (vector<3xf32>) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
%0 = "some_def"() : () -> (vector<96xf32>)
%f = "another_def"() : () -> (f32)
%1 = vector.insert %f, %0[76] : f32 into vector<96xf32>
vector.yield %1 : vector<96xf32>
}
return %r : vector<3xf32>
}

// -----

// CHECK-PROP-LABEL: func @vector_insert_2d_distr_src(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index
// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, vector<4x3xf32>)
// CHECK-PROP: %[[VEC:.*]] = "some_def"
// CHECK-PROP: %[[VAL:.*]] = "another_def"
// CHECK-PROP: vector.yield %[[VAL]], %[[VEC]]
// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#0, %[[W]]#1 [2] : vector<3xf32> into vector<4x3xf32>
// CHECK-PROP: return %[[INSERT]]
func.func @vector_insert_2d_distr_src(%laneid: index) -> (vector<4x3xf32>) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x3xf32>) {
%0 = "some_def"() : () -> (vector<4x96xf32>)
%s = "another_def"() : () -> (vector<96xf32>)
%1 = vector.insert %s, %0[2] : vector<96xf32> into vector<4x96xf32>
vector.yield %1 : vector<4x96xf32>
}
return %r : vector<4x3xf32>
}

// -----

// CHECK-PROP-LABEL: func @vector_insert_2d_distr_pos(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index
// CHECK-PROP: %[[C19:.*]] = arith.constant 19 : index
// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, vector<4x96xf32>)
// CHECK-PROP: %[[VEC:.*]] = "some_def"
// CHECK-PROP: %[[VAL:.*]] = "another_def"
// CHECK-PROP: vector.yield %[[VAL]], %[[VEC]]
// CHECK-PROP: %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[C19]]
// CHECK-PROP: %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<4x96xf32>) {
// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#0, %[[W]]#1 [3] : vector<96xf32> into vector<4x96xf32>
// CHECK-PROP: scf.yield %[[INSERT]]
// CHECK-PROP: } else {
// CHECK-PROP: scf.yield %[[W]]#1
// CHECK-PROP: }
// CHECK-PROP: return %[[R]]
func.func @vector_insert_2d_distr_pos(%laneid: index) -> (vector<4x96xf32>) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
%0 = "some_def"() : () -> (vector<128x96xf32>)
%s = "another_def"() : () -> (vector<96xf32>)
%1 = vector.insert %s, %0[79] : vector<96xf32> into vector<128x96xf32>
vector.yield %1 : vector<128x96xf32>
}
return %r : vector<4x96xf32>
}

// -----

// CHECK-PROP-LABEL: func @vector_insert_2d_broadcast(
// CHECK-PROP-SAME: %[[LANEID:.*]]: index
// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, vector<4x96xf32>)
// CHECK-PROP: %[[VEC:.*]] = "some_def"
// CHECK-PROP: %[[VAL:.*]] = "another_def"
// CHECK-PROP: vector.yield %[[VAL]], %[[VEC]]
// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#0, %[[W]]#1 [2] : vector<96xf32> into vector<4x96xf32>
// CHECK-PROP: return %[[INSERT]]
func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
%0 = "some_def"() : () -> (vector<4x96xf32>)
%s = "another_def"() : () -> (vector<96xf32>)
%1 = vector.insert %s, %0[2] : vector<96xf32> into vector<4x96xf32>
vector.yield %1 : vector<4x96xf32>
}
return %r : vector<4x96xf32>
}

0 comments on commit 1523b72

Please sign in to comment.