Skip to content

Commit

Permalink
[mlir][vector] Add unrolling pattern for TransposeOp
Browse files Browse the repository at this point in the history
Support unrolling for vector.transpose following the same interface as
other vector unrolling ops.

Differential Revision: https://reviews.llvm.org/D123688
  • Loading branch information
ThomasRaoux committed Apr 13, 2022
1 parent 26eec9e commit 5b1b710
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 6 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Expand Up @@ -2217,6 +2217,7 @@ def Vector_CreateMaskOp :

def Vector_TransposeOp :
Vector_Op<"transpose", [NoSideEffect,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Expand Up @@ -4320,6 +4320,10 @@ LogicalResult vector::TransposeOp::verify() {
return success();
}

Optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultType().getShape());
}

namespace {

// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
Expand Down
52 changes: 50 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
Expand Up @@ -681,14 +681,62 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
const vector::UnrollVectorOptions options;
};

struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
UnrollTranposePattern(MLIRContext *context,
const vector::UnrollVectorOptions &options)
: OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
options(options) {}
LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
PatternRewriter &rewriter) const override {
if (tranposeOp.getResultType().getRank() == 0)
return failure();
auto targetShape = getTargetShape(options, tranposeOp);
if (!targetShape)
return failure();
auto originalVectorType = tranposeOp.getResultType();
SmallVector<int64_t, 4> strides(targetShape->size(), 1);
Location loc = tranposeOp.getLoc();
ArrayRef<int64_t> originalSize = originalVectorType.getShape();
SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
int64_t sliceCount = computeMaxLinearIndex(ratio);
// Prepare the result vector;
Value result = rewriter.create<arith::ConstantOp>(
loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
SmallVector<int64_t> permutation;
tranposeOp.getTransp(permutation);
for (int64_t i = 0; i < sliceCount; i++) {
SmallVector<int64_t, 4> elementOffsets =
getVectorOffset(originalSize, *targetShape, i);
SmallVector<int64_t, 4> permutedOffsets(elementOffsets.size());
SmallVector<int64_t, 4> permutedShape(elementOffsets.size());
// Compute the source offsets and shape.
for (auto &indices : llvm::enumerate(permutation)) {
permutedOffsets[indices.value()] = elementOffsets[indices.index()];
permutedShape[indices.value()] = (*targetShape)[indices.index()];
}
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
loc, tranposeOp.getVector(), permutedOffsets, permutedShape, strides);
Value tranposedSlice =
rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, tranposedSlice, result, elementOffsets, strides);
}
rewriter.replaceOp(tranposeOp, result);
return success();
}

private:
vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
RewritePatternSet &patterns, const UnrollVectorOptions &options) {
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
UnrollContractionPattern, UnrollElementwisePattern,
UnrollReductionPattern, UnrollMultiReductionPattern>(
patterns.getContext(), options);
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTranposePattern>(patterns.getContext(), options);
}

void mlir::vector::populatePropagateVectorDistributionPatterns(
Expand Down
40 changes: 36 additions & 4 deletions mlir/test/Dialect/Vector/vector-unroll-options.mlir
Expand Up @@ -107,6 +107,11 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK: return %[[V2]] : vector<4xf32>


func @vector_reduction(%v : vector<8xf32>) -> f32 {
%0 = vector.reduction <add>, %v : vector<8xf32> into f32
return %0 : f32
}
// CHECK-LABEL: func @vector_reduction(
// CHECK-SAME: %[[v:.*]]: vector<8xf32>
// CHECK: %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
Expand All @@ -121,8 +126,35 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
// CHECK: %[[r3:.*]] = vector.reduction <add>, %[[s3]]
// CHECK: %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
// CHECK: return %[[add3]]
func @vector_reduction(%v : vector<8xf32>) -> f32 {
%0 = vector.reduction <add>, %v : vector<8xf32> into f32
return %0 : f32
}

func @vector_tranpose(%v : vector<2x4x3x8xf32>) -> vector<2x3x8x4xf32> {
%t = vector.transpose %v, [0, 2, 3, 1] : vector<2x4x3x8xf32> to vector<2x3x8x4xf32>
return %t : vector<2x3x8x4xf32>
}
// CHECK-LABEL: func @vector_tranpose
// CHECK: %[[VI:.*]] = arith.constant dense<0.000000e+00> : vector<2x3x8x4xf32>
// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
// CHECK: %[[T0:.*]] = vector.transpose %[[E0]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
// CHECK: %[[V0:.*]] = vector.insert_strided_slice %[[T0]], %[[VI]] {offsets = [0, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
// CHECK: %[[T1:.*]] = vector.transpose %[[E1]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[T1]], %[[V0]] {offsets = [0, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
// CHECK: %[[T2:.*]] = vector.transpose %[[E2]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[T2]], %[[V1]] {offsets = [0, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[E3]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[T3]], %[[V2]] {offsets = [0, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
// CHECK: %[[T4:.*]] = vector.transpose %[[E4]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
// CHECK: %[[V4:.*]] = vector.insert_strided_slice %[[T4]], %[[V3]] {offsets = [1, 0, 0, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 0], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
// CHECK: %[[T5:.*]] = vector.transpose %[[E5]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
// CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[T5]], %[[V4]] {offsets = [1, 0, 0, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
// CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
// CHECK: %[[T6:.*]] = vector.transpose %[[E6]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
// CHECK: %[[V6:.*]] = vector.insert_strided_slice %[[T6]], %[[V5]] {offsets = [1, 0, 4, 0], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
// CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 2, 0, 4], sizes = [1, 2, 3, 4], strides = [1, 1, 1, 1]} : vector<2x4x3x8xf32> to vector<1x2x3x4xf32>
// CHECK: %[[T7:.*]] = vector.transpose %[[E7]], [0, 2, 3, 1] : vector<1x2x3x4xf32> to vector<1x3x4x2xf32>
// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[T7]], %[[V6]] {offsets = [1, 0, 4, 2], strides = [1, 1, 1, 1]} : vector<1x3x4x2xf32> into vector<2x3x8x4xf32>
// CHECK: return %[[V7]] : vector<2x3x8x4xf32>
6 changes: 6 additions & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Expand Up @@ -282,6 +282,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ReductionOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::TransposeOp>(op));
}));

if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
Expand Down

0 comments on commit 5b1b710

Please sign in to comment.