Skip to content

Commit

Permalink
[mlir][vector] Add folding for ExtractOp with ShapeCastOp source
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D89853
  • Loading branch information
ThomasRaoux committed Oct 23, 2020
1 parent 4b90a25 commit 8c72eea
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
57 changes: 57 additions & 0 deletions mlir/lib/Dialect/Vector/VectorOps.cpp
Expand Up @@ -843,6 +843,61 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return Value();
}

// Fold extractOp with source coming from ShapeCast op.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
if (!shapeCastOp)
return Value();
// Get the nth dimension size starting from lowest dimension.
auto getDimReverse = [](VectorType type, int64_t n) {
return type.getDimSize(type.getRank() - n - 1);
};
int64_t destinationRank =
extractOp.getVectorType().getRank() - extractOp.position().size();
if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
return Value();
if (destinationRank > 0) {
auto destinationType = extractOp.getResult().getType().cast<VectorType>();
for (int64_t i = 0; i < destinationRank; i++) {
// The lowest dimension of of the destination must match the lowest
// dimension of the shapecast op source.
if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
getDimReverse(destinationType, i))
return Value();
}
}
// Extract the strides associated with the extract op vector source. Then use
// this to calculate a linearized position for the extract.
auto extractedPos = extractVector<int64_t>(extractOp.position());
std::reverse(extractedPos.begin(), extractedPos.end());
SmallVector<int64_t, 4> strides;
int64_t stride = 1;
for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
strides.push_back(stride);
stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
}

int64_t position = linearize(extractedPos, strides);
// Then extract the strides assoociated to the shapeCast op vector source and
// delinearize the position using those strides.
SmallVector<int64_t, 4> newStrides;
int64_t numDimension =
shapeCastOp.getSourceVectorType().getRank() - destinationRank;
stride = 1;
for (int64_t i = 0; i < numDimension; i++) {
newStrides.push_back(stride);
stride *=
getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
}
std::reverse(newStrides.begin(), newStrides.end());
SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
OpBuilder b(extractOp.getContext());
extractOp.setAttr(ExtractOp::getPositionAttrName(),
b.getI64ArrayAttr(newPosition));
extractOp.setOperand(shapeCastOp.source());
return extractOp.getResult();
}

OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
Expand All @@ -852,6 +907,8 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
return val;
if (auto val = foldExtractFromBroadcast(*this))
return val;
if (auto val = foldExtractFromShapeCast(*this))
return val;
return OpFoldResult();
}

Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Expand Up @@ -394,6 +394,39 @@ func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
return %r : vector<4xf32>
}

// -----

// CHECK-LABEL: func @fold_extract_shapecast
// CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
// CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
// CHECK: %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32>
// CHECK: %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32>
// CHECK: return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32>
func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x4x2xf32>)
-> (f32, vector<2xf32>, vector<4x2xf32>) {
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
%1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32>
%r1 = vector.extract %0[4, 1] : vector<15x2xf32>
%r2 = vector.extract %0[5] : vector<15x2xf32>
%r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32>
return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32>
}

// -----

// CHECK-LABEL: fold_extract_shapecast_negative
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32>
// CHECK: return %[[R]] : vector<4x2xf32>
func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
%arg1 : vector<8x4x2xf32>) -> vector<4x2xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
%r = vector.extract %0[1] : vector<2x4x2xf32>
return %r : vector<4x2xf32>
}


// -----

// CHECK-LABEL: fold_vector_transfers
Expand Down

0 comments on commit 8c72eea

Please sign in to comment.