Skip to content

Commit

Permalink
[mlir][vector] Add distribution for extract from 0d vector
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D135994
  • Loading branch information
ThomasRaoux committed Oct 14, 2022
1 parent cdfeeb8 commit 1757164
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
33 changes: 31 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,34 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
};

/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
return isa<vector::ExtractElementOp>(op);
});
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
if (extractOp.getVectorType().getRank() != 0)
return failure();
Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value newExtract = rewriter.create<vector::ExtractElementOp>(
loc, newWarpOp->getResult(newRetIndices[0]));
newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
return success();
}
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
Expand Down Expand Up @@ -1093,8 +1121,9 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
WarpOpScfForOp, WarpOpConstant>(patterns.getContext(), benefit);
WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
WarpOpForwardOperand, WarpOpScfForOp, WarpOpConstant>(
patterns.getContext(), benefit);
}

void mlir::vector::populateDistributeReduction(
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,24 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) {

// -----

// CHECK-PROP-LABEL: func.func @vector_extractelement_simple(
// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<f32>
// CHECK-PROP: vector.yield %[[V]] : vector<f32>
// CHECK-PROP: }
// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][] : vector<f32>
// CHECK-PROP: return %[[E]] : f32
func.func @vector_extractelement_simple(%laneid: index) -> (f32) {
%r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
%0 = "some_def"() : () -> (vector<f32>)
%1 = vector.extractelement %0[] : vector<f32>
vector.yield %1 : f32
}
return %r : f32
}

// -----

// CHECK-PROP: func @lane_dependent_warp_propagate_read
// CHECK-PROP-SAME: %[[ID:.*]]: index
func.func @lane_dependent_warp_propagate_read(
Expand Down

0 comments on commit 1757164

Please sign in to comment.