Skip to content

Commit

Permalink
[mlir][Shape] Add a pattern to turn extract from shape_of into tensor…
Browse files Browse the repository at this point in the history
….dim

If I remember correctly this wasn't done previously because dim used to
be in the memref dialect.

Differential Revision: https://reviews.llvm.org/D111651
  • Loading branch information
d0k committed Oct 12, 2021
1 parent 9cf995b commit f67d57c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Shape/IR/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
ExtractFromShapeOfExtentTensor>(context);
}

LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def HasStaticShape : Constraint<CPred< [{
$0.getType().dyn_cast<ShapedType>().hasStaticShape()
}]>>;

// Helper that takes the first element of a range.
def TakeFront : NativeCodeCall<"$0.front()">;

// Canonicalization patterns.

def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
Expand Down Expand Up @@ -43,3 +46,9 @@ def SizeToIndexToSizeCanonicalization : Pat<
def TensorCastConstShape : Pat <
(Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg),
[(HasStaticShape $res)]>;

// tensor.extract from shape_of -> tensor.dim. We can take the first index
// because shape_of always returns a 1D tensor.
def ExtractFromShapeOfExtentTensor : Pat<
(Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices),
(Tensor_DimOp $arg, (TakeFront $indices))>;
14 changes: 14 additions & 0 deletions mlir/test/Dialect/Shape/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1380,3 +1380,17 @@ func @concretize_broadcast_result_type(%arg0 : tensor<2xindex>,
-> tensor<?xindex>
return %0 : tensor<?xindex>
}

// -----

// CHECK-LABEL: func @extract_shapeof
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf64>
func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
%c1 = constant 1 : index
// CHECK: %[[C1:.*]] = constant 1
%shape = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
%result = tensor.extract %shape[%c1] : tensor<2xindex>
// CHECK: return %[[DIM]]
return %result : index
}

0 comments on commit f67d57c

Please sign in to comment.