Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Generalize expand_shape to take shape as explicit input #90040

Merged
merged 3 commits into from
Apr 30, 2024

Conversation

Shukla-Gaurav
Copy link
Contributor

This patch generalizes tensor.expand_shape and memref.expand_shape to consume the output shape as a list of SSA values. This enables us to implement generic reshape operations with dynamic shapes using collapse_shape/expand_shape pairs.

The output_shape input to expand_shape follows the static/dynamic representation that's also used in tensor.extract_slice.

Differential Revision: https://reviews.llvm.org/D140821

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 25, 2024

@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir-linalg

Author: Gaurav Shukla (Shukla-Gaurav)

Changes

This patch generalizes tensor.expand_shape and memref.expand_shape to consume the output shape as a list of SSA values. This enables us to implement generic reshape operations with dynamic shapes using collapse_shape/expand_shape pairs.

The output_shape input to expand_shape follows the static/dynamic representation that's also used in tensor.extract_slice.

Differential Revision: https://reviews.llvm.org/D140821


Patch is 259.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90040.diff

53 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+58-22)
  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+56-23)
  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+46-12)
  • (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+2-3)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+12-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+8-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+49-18)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp (+1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+5-3)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+75-7)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+9-2)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+71-14)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+3)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+16-8)
  • (modified) mlir/lib/Dialect/Utils/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+69-14)
  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+3-4)
  • (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+8-8)
  • (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+2-2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+19-10)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+81-33)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir (+3-1)
  • (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir (+10-10)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+18-12)
  • (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+71-37)
  • (modified) mlir/test/Dialect/Linalg/flatten-elementwise.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+57-44)
  • (modified) mlir/test/Dialect/Linalg/fusion-push-reshape.mlir (+14-10)
  • (modified) mlir/test/Dialect/Linalg/reshape_control_fusion.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+192-100)
  • (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-3)
  • (modified) mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir (+14-14)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+3-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+17-18)
  • (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+9-7)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+12-10)
  • (modified) mlir/test/Dialect/MemRef/invalid.mlir (+14-24)
  • (modified) mlir/test/Dialect/MemRef/ops.mlir (+40-32)
  • (modified) mlir/test/Dialect/MemRef/runtime-verification.mlir (+3-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_reshape.mlir (+6-6)
  • (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+14-10)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+53-59)
  • (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+2-3)
  • (modified) mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir (+4-2)
  • (modified) mlir/test/Dialect/Tensor/invalid.mlir (+5-16)
  • (modified) mlir/test/Dialect/Tensor/ops.mlir (+16-2)
  • (modified) mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir (+7-7)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+1)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..14b8d95ea15b41 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
 class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     MemRef_Op<mnemonic, !listconcat(traits,
       [Pure, ViewLikeOpInterface])>,
-    Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)>{
 
   code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Value getViewSource() { return getSrc(); }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     Example:
 
     ```mlir
-    %r = memref.expand_shape %0 [[0, 1], [2]]
-        : memref<?x?xf32> into memref<?x5x?xf32>
+    %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+        : memref<?x32xf32> into memref<?x?x32xf32>
     ```
 
-    At most one dimension of a reassociation group (e.g., [0, 1] above) may be
-    dynamic in the result type. Otherwise, the op would be ambiguous, as it
-    would not be clear how the source dimension is extended.
-
     If an op can be statically proven to be invalid (e.g, an expansion from
     `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
     it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,41 +1613,80 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     there must be a dynamic result dimension in the corresponding reassociation
     group. Same for strides.
 
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value.  There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
+
     Note: This op currently assumes that the inner strides are of the
     source/result layout map are the faster-varying ones.
   }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
+                       Variadic<Index>:$output_shape,
+                       DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation `output_shape`
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders using ReassociationIndices.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape)>,
+
+    // It will infer output shape using inferOutputShape() method.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation)>,
+
+    // Builder using ReassociationExprs.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation),
     [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-                          getReassociationIndicesAttribute($_builder, reassociation));
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices);
     }]>,
 
-    // Builder using ReassociationExprs.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps,
+            outputShape);
     }]>,
 
+    // Builder that infers the result layout map. The result shape must be
+    // specified. Otherwise, the op may be ambiguous. The output shape for 
+    // the op will be inferred using the inferOutputShape() method.
+    OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
+               "ArrayRef<ReassociationIndices>":$reassociation)>,
+
     // Builder that infers the result layout map. The result shape must be
     // specified. Otherwise, the op may be ambiguous.
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
-               "ArrayRef<ReassociationIndices>":$reassociation)>
+               "ArrayRef<ReassociationIndices>":$reassociation,
+               "ArrayRef<OpFoldResult>":$outputShape)>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     static FailureOr<MemRefType> computeExpandedType(
         MemRefType srcType, ArrayRef<int64_t> resultShape,
         ArrayRef<ReassociationIndices> reassociation);
+
+    // Infer the output shape for a memref.expand_shape when it is possible
+    // to do so.
+    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
+        OpBuilder &b, Location loc, MemRefType expandedType,
+        ArrayRef<ReassociationIndices> reassociation,
+        ArrayRef<OpFoldResult> inputShape);
   }];
 
   let hasVerifier = 1;
@@ -1707,6 +1737,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
     source/result layout map are the faster-varying ones.
   }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1718,7 +1754,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1736,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cf7f3e89079c1c..a403e89a39f98c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1062,8 +1062,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure])>,
-    Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
-    Results<(outs AnyRankedTensor:$result)> {
+    Results<(outs AnyTensor:$result)> {
 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1086,10 +1085,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1102,43 +1097,75 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     rank than the operand `src` whose dimension sizes are a reassociation of
     `src`.
 
-    A reassociation is defined as a continuous grouping of dimensions. It is
-    represented with an array of DenseI64ArrayAttr attribute. Entries in the
-    array are referred to as reassociation maps.
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of DenseI64ArrayAttr attribute.  The reassociation
+    maps applied to the result tensor with the higher rank must result in the
+    operand tensor with the smaller rank.
 
-    The reassociation maps are applied to the result shape to obtain the operand
-    shape.
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value.  There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
 
     Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
-    %b = tensor.expand_shape %a [[0, 1], [2]]
-        : tensor<?x?xf32> into tensor<?x?x?xf32>
+    %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+        : tensor<?x32xf32> into tensor<?x?x32xf32>
     ```
   }];
+
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+                       Variadic<Index>:$output_shape,
+                       DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation `output_shape`
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders using ReassociationIndices.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape)>,
+
+    // It will infer output shape using inferOutputShape() method.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation)>,
+
+    // Builder using ReassociationExprs.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation),
     [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-          getReassociationIndicesAttribute($_builder, reassociation));
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices);
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
-      auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices,
+            outputShape);
     }]>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
+
+    // Infer the output shape for a tensor.expand_shape when it is possible
+    // to do so.
+    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
+        OpBuilder &b, Location loc, RankedTensorType expandedType,
+        ArrayRef<ReassociationIndices> reassociation,
+        ArrayRef<OpFoldResult> inputShape);
   }];
 
   let hasVerifier = 1;
@@ -1146,6 +1173,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
   let description = [{
     The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
     rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1163,6 +1191,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
         : tensor<?x?x?xf32> into tensor<?x?xf32>
     ```
   }];
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1174,7 +1207,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1192,7 +1225,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index ae9824f728da4d..8a41a0a18b0ab3 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,6 +30,27 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
 
+// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
+// to do so.
+//
+// Note: This should *only* be used to implement
+// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
+// If you need to infer the output shape you should use the static method of
+// `ExpandShapeOp` instead of calling this.
+//
+// `inputShape` is the shape of the tensor or memref being expanded as a
+// sequence of SSA values or constants. `expandedType` is the output shape of
+// the expand_shape operation. `reassociation` is the reassociation denoting
+// the output dims each input dim is mapped to.
+//
+// Returns the output shape in `outputShape` and `staticOutputShape`, following
+// the conventions for the output_shape and static_output_shape inputs to the
+// expand_shape ops.
+std::optional<SmallVector<OpFoldResult>>
+inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
+                            ArrayRef<ReassociationIndices> reassociation,
+                            ArrayRef<OpFoldResult> inputShape);
+
 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
@@ -62,7 +83,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
 
 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
 SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
-    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+    ArrayRef<ReassociationExprs> reassociationExprs);
 
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return std::nullopt when this computation
@@ -140,14 +161,11 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
       op.getReassociationIndices(), isExpansion);
 }
 
-/// Verify that shapes of the reshaped types using following rules
-/// 1) if a dimension in the collapsed type is static, then the corresponding
-///    dimensions in the expanded shape should be
+/// Verify that shapes of the reshaped types using following rule:
+/// if a dimension in the collapsed type is static, then the corresponding
+/// dimensions in the expanded shape should be
 ///    a) static
 ///    b) the product should be same as the collaped shape.
-/// 2) if a dimension in the collaped type is dynamic, one and only one of the
-///    corresponding dimensions in the expanded type should be dynamic. This
-///    rule is only needed with reshape operations that are expanding.
 LogicalResult reshapeLikeShapesAreCompatible(
     function_ref<LogicalResult(const Twine &)> emitError,
     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
@@ -156,9 +174,11 @@ LogicalResult reshapeLikeShapesAreCompatible(
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
 
+enum class ReshapeOpKind { kExpand, kCollapse };
+
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
+template <typename ReshapeOpTy, ReshapeOpKind opKind>
 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -181,8 +201,18 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
                                     rewriter.getContext());
     if (!reassociationIndices)
       return failure();
-    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-        reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+
+    if constexpr (opKind == ReshapeOpKind::kExpand) {
+      SmallVector<OpFoldResult> outputShape(
+          getMixedValues(reshapeOp.getStaticOutputShape(),
+                         reshapeOp.getOutputShape(), rewriter));
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
+          outputShape);
+    } else {
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+    }
     return success();
   }
 };
@@ -215,7 +245,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 //
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
+          typename DimOpTy, typename TensorTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -322,8 +353,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
     if (!composedReassociation)
       return failure();
 
+    SmallVector<OpFoldResult> outputShape(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
     rewriter.replaceOpWithNewOp<ExpandOpTy>(
-        expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
+        expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
+        outputShape);
     return success();
   }
 
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 20f019666a2e6a..594bcf5dbb399a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,9 +125,8 @@ SmallVect...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 25, 2024

@llvm/pr-subscribers-mlir-memref

Author: Gaurav Shukla (Shukla-Gaurav)

Changes

This patch generalizes tensor.expand_shape and memref.expand_shape to consume the output shape as a list of SSA values. This enables us to implement generic reshape operations with dynamic shapes using collapse_shape/expand_shape pairs.

The output_shape input to expand_shape follows the static/dynamic representation that's also used in tensor.extract_slice.

Differential Revision: https://reviews.llvm.org/D140821


Patch is 259.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90040.diff

53 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+58-22)
  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+56-23)
  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+46-12)
  • (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+2-3)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+12-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+8-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+49-18)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp (+1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+5-3)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+75-7)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+9-2)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+71-14)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+3)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+16-8)
  • (modified) mlir/lib/Dialect/Utils/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+69-14)
  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+3-4)
  • (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+8-8)
  • (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+2-2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+19-10)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+81-33)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir (+3-1)
  • (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir (+10-10)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+18-12)
  • (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+71-37)
  • (modified) mlir/test/Dialect/Linalg/flatten-elementwise.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+57-44)
  • (modified) mlir/test/Dialect/Linalg/fusion-push-reshape.mlir (+14-10)
  • (modified) mlir/test/Dialect/Linalg/reshape_control_fusion.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+192-100)
  • (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-3)
  • (modified) mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir (+14-14)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+3-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+17-18)
  • (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+9-7)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+12-10)
  • (modified) mlir/test/Dialect/MemRef/invalid.mlir (+14-24)
  • (modified) mlir/test/Dialect/MemRef/ops.mlir (+40-32)
  • (modified) mlir/test/Dialect/MemRef/runtime-verification.mlir (+3-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_reshape.mlir (+6-6)
  • (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+14-10)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+53-59)
  • (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+2-3)
  • (modified) mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir (+4-2)
  • (modified) mlir/test/Dialect/Tensor/invalid.mlir (+5-16)
  • (modified) mlir/test/Dialect/Tensor/ops.mlir (+16-2)
  • (modified) mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir (+7-7)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+1)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..14b8d95ea15b41 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
 class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     MemRef_Op<mnemonic, !listconcat(traits,
       [Pure, ViewLikeOpInterface])>,
-    Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)>{
 
   code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Value getViewSource() { return getSrc(); }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     Example:
 
     ```mlir
-    %r = memref.expand_shape %0 [[0, 1], [2]]
-        : memref<?x?xf32> into memref<?x5x?xf32>
+    %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+        : memref<?x32xf32> into memref<?x?x32xf32>
     ```
 
-    At most one dimension of a reassociation group (e.g., [0, 1] above) may be
-    dynamic in the result type. Otherwise, the op would be ambiguous, as it
-    would not be clear how the source dimension is extended.
-
     If an op can be statically proven to be invalid (e.g, an expansion from
     `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
     it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,41 +1613,80 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     there must be a dynamic result dimension in the corresponding reassociation
     group. Same for strides.
 
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value.  There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
+
     Note: This op currently assumes that the inner strides are of the
     source/result layout map are the faster-varying ones.
   }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
+                       Variadic<Index>:$output_shape,
+                       DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation `output_shape`
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders using ReassociationIndices.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape)>,
+
+    // It will infer output shape using inferOutputShape() method.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation)>,
+
+    // Builder using ReassociationExprs.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation),
     [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-                          getReassociationIndicesAttribute($_builder, reassociation));
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices);
     }]>,
 
-    // Builder using ReassociationExprs.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps,
+            outputShape);
     }]>,
 
+    // Builder that infers the result layout map. The result shape must be
+    // specified. Otherwise, the op may be ambiguous. The output shape for 
+    // the op will be inferred using the inferOutputShape() method.
+    OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
+               "ArrayRef<ReassociationIndices>":$reassociation)>,
+
     // Builder that infers the result layout map. The result shape must be
     // specified. Otherwise, the op may be ambiguous.
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
-               "ArrayRef<ReassociationIndices>":$reassociation)>
+               "ArrayRef<ReassociationIndices>":$reassociation,
+               "ArrayRef<OpFoldResult>":$outputShape)>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     static FailureOr<MemRefType> computeExpandedType(
         MemRefType srcType, ArrayRef<int64_t> resultShape,
         ArrayRef<ReassociationIndices> reassociation);
+
+    // Infer the output shape for a memref.expand_shape when it is possible
+    // to do so.
+    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
+        OpBuilder &b, Location loc, MemRefType expandedType,
+        ArrayRef<ReassociationIndices> reassociation,
+        ArrayRef<OpFoldResult> inputShape);
   }];
 
   let hasVerifier = 1;
@@ -1707,6 +1737,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
     source/result layout map are the faster-varying ones.
   }];
 
+  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1718,7 +1754,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1736,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cf7f3e89079c1c..a403e89a39f98c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1062,8 +1062,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure])>,
-    Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
-    Results<(outs AnyRankedTensor:$result)> {
+    Results<(outs AnyTensor:$result)> {
 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1086,10 +1085,6 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     }
   }];
 
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1102,43 +1097,75 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     rank than the operand `src` whose dimension sizes are a reassociation of
     `src`.
 
-    A reassociation is defined as a continuous grouping of dimensions. It is
-    represented with an array of DenseI64ArrayAttr attribute. Entries in the
-    array are referred to as reassociation maps.
+    A reassociation is defined as a continuous grouping of dimensions and is
+    represented with an array of DenseI64ArrayAttr attribute.  The reassociation
+    maps applied to the result tensor with the higher rank must result in the
+    operand tensor with the smaller rank.
 
-    The reassociation maps are applied to the result shape to obtain the operand
-    shape.
+    The representation for the output shape supports a partially-static
+    specification via attributes specified through the `static_output_shape`
+    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
+    corresponding entry has a dynamic value.  There must be exactly as many SSA
+    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
+    `static_output_shape`.
 
     Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
-    %b = tensor.expand_shape %a [[0, 1], [2]]
-        : tensor<?x?xf32> into tensor<?x?x?xf32>
+    %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
+        : tensor<?x32xf32> into tensor<?x?x32xf32>
     ```
   }];
+
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
+                       Variadic<Index>:$output_shape,
+                       DenseI64ArrayAttr:$static_output_shape);
+
+  let assemblyFormat = [{
+    $src $reassociation `output_shape`
+    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
+    type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders using ReassociationIndices.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape)>,
+
+    // It will infer output shape using inferOutputShape() method.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationIndices>":$reassociation)>,
+
+    // Builder using ReassociationExprs.
+    OpBuilder<(ins "Type":$resultType, "Value":$src,
+      "ArrayRef<ReassociationExprs>":$reassociation),
     [{
-      build($_builder, $_state, resultType, src, attrs);
-      $_state.addAttribute("reassociation",
-          getReassociationIndicesAttribute($_builder, reassociation));
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices);
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
+      "ArrayRef<OpFoldResult>":$outputShape),
     [{
-      auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
+      auto reassociationIndices =
+          convertReassociationMapsToIndices(reassociation);
+      build($_builder, $_state, resultType, src, reassociationIndices,
+            outputShape);
     }]>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
+
+    // Infer the output shape for a tensor.expand_shape when it is possible
+    // to do so.
+    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
+        OpBuilder &b, Location loc, RankedTensorType expandedType,
+        ArrayRef<ReassociationIndices> reassociation,
+        ArrayRef<OpFoldResult> inputShape);
   }];
 
   let hasVerifier = 1;
@@ -1146,6 +1173,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
+  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
   let description = [{
     The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
     rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1163,6 +1191,11 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
         : tensor<?x?x?xf32> into tensor<?x?xf32>
     ```
   }];
+
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1174,7 +1207,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1192,7 +1225,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices($_builder, reassociation);
+          convertReassociationMapsToIndices(reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index ae9824f728da4d..8a41a0a18b0ab3 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,6 +30,27 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
 
+// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
+// to do so.
+//
+// Note: This should *only* be used to implement
+// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
+// If you need to infer the output shape you should use the static method of
+// `ExpandShapeOp` instead of calling this.
+//
+// `inputShape` is the shape of the tensor or memref being expanded as a
+// sequence of SSA values or constants. `expandedType` is the output shape of
+// the expand_shape operation. `reassociation` is the reassociation denoting
+// the output dims each input dim is mapped to.
+//
+// Returns the output shape in `outputShape` and `staticOutputShape`, following
+// the conventions for the output_shape and static_output_shape inputs to the
+// expand_shape ops.
+std::optional<SmallVector<OpFoldResult>>
+inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
+                            ArrayRef<ReassociationIndices> reassociation,
+                            ArrayRef<OpFoldResult> inputShape);
+
 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
@@ -62,7 +83,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
 
 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
 SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
-    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
+    ArrayRef<ReassociationExprs> reassociationExprs);
 
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return std::nullopt when this computation
@@ -140,14 +161,11 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
       op.getReassociationIndices(), isExpansion);
 }
 
-/// Verify that shapes of the reshaped types using following rules
-/// 1) if a dimension in the collapsed type is static, then the corresponding
-///    dimensions in the expanded shape should be
+/// Verify that shapes of the reshaped types using following rule:
+/// if a dimension in the collapsed type is static, then the corresponding
+/// dimensions in the expanded shape should be
 ///    a) static
 ///    b) the product should be same as the collaped shape.
-/// 2) if a dimension in the collaped type is dynamic, one and only one of the
-///    corresponding dimensions in the expanded type should be dynamic. This
-///    rule is only needed with reshape operations that are expanding.
 LogicalResult reshapeLikeShapesAreCompatible(
     function_ref<LogicalResult(const Twine &)> emitError,
     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
@@ -156,9 +174,11 @@ LogicalResult reshapeLikeShapesAreCompatible(
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
 
+enum class ReshapeOpKind { kExpand, kCollapse };
+
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy>
+template <typename ReshapeOpTy, ReshapeOpKind opKind>
 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -181,8 +201,18 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
                                     rewriter.getContext());
     if (!reassociationIndices)
       return failure();
-    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-        reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+
+    if constexpr (opKind == ReshapeOpKind::kExpand) {
+      SmallVector<OpFoldResult> outputShape(
+          getMixedValues(reshapeOp.getStaticOutputShape(),
+                         reshapeOp.getOutputShape(), rewriter));
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
+          outputShape);
+    } else {
+      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
+    }
     return success();
   }
 };
@@ -215,7 +245,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 //
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
+          typename DimOpTy, typename TensorTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -322,8 +353,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
     if (!composedReassociation)
       return failure();
 
+    SmallVector<OpFoldResult> outputShape(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
     rewriter.replaceOpWithNewOp<ExpandOpTy>(
-        expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
+        expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
+        outputShape);
     return success();
   }
 
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 20f019666a2e6a..594bcf5dbb399a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,9 +125,8 @@ SmallVect...
[truncated]

@Shukla-Gaurav
Copy link
Contributor Author

@MaheshRavishankar @joker-eph This PR is an exact copy of #69267 with one line change(mlir/lib/Dialect/Utils/CMakeLists.txt) to fix the bot failure. Could you please approve this?

@Shukla-Gaurav
Copy link
Contributor Author

@bjacob fyi

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have concerns with layering here. This adds a dependency on the arithmetic dialect from dialect utils. It's not exactly clear where such a code should go since it seems to be shared between memref and tensor, maybe somewhere next to type inference.

Incidentally, this dependency is also what broke the build when this was first landed.

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp Show resolved Hide resolved
mlir/lib/Dialect/Utils/CMakeLists.txt Outdated Show resolved Hide resolved
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp Outdated Show resolved Hide resolved
@MaheshRavishankar
Copy link
Contributor

@ftynse unless this arith -> utils dependency breaks builds (I dont think it does, cause there is no cyclic dependency afaics), i think it might be better to land this and then remove the dependency as a quick follow up. This is fairly big change, and more things might get added that will have to be resolved. I am inclined to take the snap shot as is, but fix the dependency issue quickly after.

@ftynse
Copy link
Member

ftynse commented Apr 26, 2024

@ftynse unless this arith -> utils dependency breaks builds (I dont think it does, cause there is no cyclic dependency afaics), i think it might be better to land this and then remove the dependency as a quick follow up.

I don't see a cyclic dependency immediately. This is a bit of a bikeshedding discussion, one can simply move a file under a new ShapedTypeArithUtils and resolve the surface concern: DialectUtils suggests that it's something cheap and useful for all dialects, but it comes with a dependency on Arith (and UB, etc.). There's a deeper concern: we don't have a dialect-agnostic way to create constants, so anything that creates constants will have to depend on Arith or similar. I certainly won't block this on resolving that concern.

This is fairly big change, and more things might get added that will have to be resolved. I am inclined to take the snap shot as is, but fix the dependency issue quickly after.

If the issue can be fixed quickly, why not quickly add the solution to the PR? :)

My worry here is by allowing this dependency temporarily, more uses will creep in, making it harder to remove as the time goes. I have made a similar comment this week on another PR...

How about we ask a third opinion, @joker-eph ?

@Shukla-Gaurav
Copy link
Contributor Author

@ftynse unless this arith -> utils dependency breaks builds (I dont think it does, cause there is no cyclic dependency afaics), i think it might be better to land this and then remove the dependency as a quick follow up.

I don't see a cyclic dependency immediately. This is a bit of a bikeshedding discussion, one can simply move a file under a new ShapedTypeArithUtils and resolve the surface concern: DialectUtils suggests that it's something cheap and useful for all dialects, but it comes with a dependency on Arith (and UB, etc.). There's a deeper concern: we don't have a dialect-agnostic way to create constants, so anything that creates constants will have to depend on Arith or similar. I certainly won't block this on resolving that concern.

This is fairly big change, and more things might get added that will have to be resolved. I am inclined to take the snap shot as is, but fix the dependency issue quickly after.

If the issue can be fixed quickly, why not quickly add the solution to the PR? :)

My worry here is by allowing this dependency temporarily, more uses will creep in, making it harder to remove as the time goes. I have made a similar comment this week on another PR...

How about we ask a third opinion, @joker-eph ?

The utility has been moved to Arith/Utils, so no specific dialect dependency on the utils dialect. Let me know if Arith/Utils is the right place to put the utility inferExpandShapeOutputShape?

@ftynse
Copy link
Member

ftynse commented Apr 29, 2024

The utility has been moved to Arith/Utils, so no specific dialect dependency on the utils dialect. Let me know if Arith/Utils is the right place to put the utility inferExpandShapeOutputShape?

Maybe not, but it's a separate discussion. Thanks for addressing my concerns!

@joker-eph
Copy link
Collaborator

joker-eph commented Apr 29, 2024

Thanks @ftynse, and @Shukla-Gaurav for addressing quickly!

ramiro050 and others added 3 commits April 30, 2024 11:31
This patch generalizes tensor.expand_shape and memref.expand_shape to consume
the output shape as a list of SSA values.  This enables us to implement generic
reshape operations with dynamic shapes using collapse_shape/expand_shape pairs.

The output_shape input to expand_shape follows the static/dynamic representation
that's also used in `tensor.extract_slice`.

Differential Revision: https://reviews.llvm.org/D140821
Signed-Off-by: Gaurav Shukla<gaurav.shukla@amd.com>
This commit moves the inferExpandShapeOutputShape utility from the
Dialect/Utils/ReshapeOpsUtils.cpp to Arith/Utils/Utils.cpp in order to
remove specific dialect dependencies from the DialectUtils.

Signed-Off-by: Gaurav Shukla <gaurav.shukla@amd.com>
@rsuderman rsuderman merged commit 97069a8 into llvm:main Apr 30, 2024
4 checks passed
@joker-eph
Copy link
Collaborator

This broke the bot here: https://lab.llvm.org/buildbot/#/builders/264/builds/10013

Unless you have a fix available, please revert.

@PeimingLiu
Copy link
Member

PeimingLiu commented Apr 30, 2024

FYI, tests related to sparse tensor dialect should be fixed by #90637 (though not yet submitted).

@joker-eph
Copy link
Collaborator

joker-eph commented Apr 30, 2024

Priority is always getting the bot back green: so please either merge immediately or revert the original PR in the meantime.

PeimingLiu added a commit that referenced this pull request Apr 30, 2024
@PeimingLiu
Copy link
Member

Merged, but there is one more remaining outside sparse tensor tests.

@hanhanW
Copy link
Contributor

hanhanW commented Apr 30, 2024

Merged, but there is one more remaining outside sparse tensor tests.

Testing now, #90649 should fix the issue.

@hanhanW
Copy link
Contributor

hanhanW commented Apr 30, 2024

side question: are we able to enable integration tests in presubmit? All the integration tests are marked UNSUPPORTED in https://buildkite.com/llvm-project/github-pull-requests/builds/59998#018f2d98-8ac6-47ac-8d79-65d3e7cd6c28

@hanhanW
Copy link
Contributor

hanhanW commented Apr 30, 2024

Merged, but there is one more remaining outside sparse tensor tests.

Testing now, #90649 should fix the issue.

I verified that it is passing with the PR.

hanhanW added a commit that referenced this pull request Apr 30, 2024
bjacob added a commit that referenced this pull request May 3, 2024
#90975)

This is a new take on #89111. Now that #90040 is merged, this has become
trivial to implement. The added test shows the kind of benefit that we
get from this: now dim-of-expand-shape naturally folds without us
needing to implement an ad-hoc folding rewrite.
@Shukla-Gaurav Shukla-Gaurav deleted the gaurav/generalize_expand_shape branch May 9, 2024 15:01
newling added a commit to nod-ai/iree-amd-aie that referenced this pull request May 21, 2024
All bumped to what is currently the most recent version available 

lit test updates for: 
https://github.com/llvm/llvm-project/pull/90897/files and
llvm/llvm-project#90040
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel mlir:arith mlir:bufferization Bufferization infrastructure mlir:linalg mlir:memref mlir:sparse Sparse compiler in MLIR mlir:tensor mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

9 participants