
Revert "[MLIR] Generalize expand_shape to take shape as explicit input" #89540

Merged 1 commit into main on Apr 21, 2024

Conversation

joker-eph (Collaborator)

Reverts #69267

This broke some bots.

joker-eph added the skip-precommit-approval (PR for CI feedback, not intended for review) label on Apr 21, 2024
joker-eph merged commit 8c0341d into main on Apr 21, 2024
5 of 6 checks passed
joker-eph deleted the revert-69267-arcpatch-D140821 branch on April 21, 2024 at 12:33
llvmbot added labels on Apr 21, 2024: mlir:linalg, mlir:sparse (Sparse compiler in MLIR), mlir, mlir:tensor, mlir:bufferization (Bufferization infrastructure), mlir:memref, bazel ("Peripheral" support tier build system: utils/bazel)
llvmbot (Collaborator) commented on Apr 21, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir-tensor

Author: Mehdi Amini (joker-eph)

Changes

Reverts llvm/llvm-project#69267

This broke some bots.


Patch is 259.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89540.diff

52 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+22-58)
  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+23-56)
  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+12-46)
  • (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+3-2)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+4-12)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+4-6)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+5-8)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+18-49)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp (-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+3-5)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+7-75)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+2-9)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+14-71)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (-3)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp (+8-16)
  • (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+14-69)
  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+4-3)
  • (modified) mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir (+8-8)
  • (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+2-2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+10-19)
  • (modified) mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir (+33-81)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/bubble-up-extract-slice-op.mlir (+1-3)
  • (modified) mlir/test/Dialect/Linalg/collapse-dim.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir (+10-10)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+12-18)
  • (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+37-71)
  • (modified) mlir/test/Dialect/Linalg/flatten-elementwise.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+44-57)
  • (modified) mlir/test/Dialect/Linalg/fusion-push-reshape.mlir (+10-14)
  • (modified) mlir/test/Dialect/Linalg/reshape_control_fusion.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+100-192)
  • (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+3-2)
  • (modified) mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir (+14-14)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+1-3)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+18-17)
  • (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+7-9)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+10-12)
  • (modified) mlir/test/Dialect/MemRef/invalid.mlir (+24-14)
  • (modified) mlir/test/Dialect/MemRef/ops.mlir (+32-40)
  • (modified) mlir/test/Dialect/MemRef/runtime-verification.mlir (+2-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_reshape.mlir (+6-6)
  • (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+10-14)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+59-53)
  • (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+3-2)
  • (modified) mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir (+2-4)
  • (modified) mlir/test/Dialect/Tensor/invalid.mlir (+16-5)
  • (modified) mlir/test/Dialect/Tensor/ops.mlir (+2-16)
  • (modified) mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir (+7-7)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (-1)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 14b8d95ea15b41..39e66cd9e6e5ab 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1548,6 +1548,7 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
 class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     MemRef_Op<mnemonic, !listconcat(traits,
       [Pure, ViewLikeOpInterface])>,
+    Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyStridedMemRef:$result)>{
 
   code commonExtraClassDeclaration = [{
@@ -1572,6 +1573,10 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Value getViewSource() { return getSrc(); }
   }];
 
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1593,10 +1598,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     Example:
 
     ```mlir
-    %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
-        : memref<?x32xf32> into memref<?x?x32xf32>
+    %r = memref.expand_shape %0 [[0, 1], [2]]
+        : memref<?x?xf32> into memref<?x5x?xf32>
     ```
 
+    At most one dimension of a reassociation group (e.g., [0, 1] above) may be
+    dynamic in the result type. Otherwise, the op would be ambiguous, as it
+    would not be clear how the source dimension is extended.
+
     If an op can be statically proven to be invalid (e.g, an expansion from
     `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
     it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1613,80 +1622,41 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     there must be a dynamic result dimension in the corresponding reassociation
     group. Same for strides.
 
-    The representation for the output shape supports a partially-static
-    specification via attributes specified through the `static_output_shape`
-    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
-    corresponding entry has a dynamic value.  There must be exactly as many SSA
-    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
-    `static_output_shape`.
-
     Note: This op currently assumes that the inner strides are of the
     source/result layout map are the faster-varying ones.
   }];
 
-  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
-                       Variadic<Index>:$output_shape,
-                       DenseI64ArrayAttr:$static_output_shape);
-
-  let assemblyFormat = [{
-    $src $reassociation `output_shape`
-    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
-    type($src) `into` type($result)
-  }];
-
   let builders = [
     // Builders using ReassociationIndices.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      "ArrayRef<OpFoldResult>":$outputShape)>,
-
-    // It will infer output shape using inferOutputShape() method.
-    OpBuilder<(ins "Type":$resultType, "Value":$src,
-      "ArrayRef<ReassociationIndices>":$reassociation)>,
-
-    // Builder using ReassociationExprs.
-    OpBuilder<(ins "Type":$resultType, "Value":$src,
-      "ArrayRef<ReassociationExprs>":$reassociation),
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
-      auto reassociationIndices =
-          convertReassociationMapsToIndices(reassociation);
-      build($_builder, $_state, resultType, src, reassociationIndices);
+      build($_builder, $_state, resultType, src, attrs);
+      $_state.addAttribute("reassociation",
+                          getReassociationIndicesAttribute($_builder, reassociation));
     }]>,
 
+    // Builder using ReassociationExprs.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      "ArrayRef<OpFoldResult>":$outputShape),
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices(reassociation);
-      build($_builder, $_state, resultType, src, reassociationMaps,
-            outputShape);
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>,
 
-    // Builder that infers the result layout map. The result shape must be
-    // specified. Otherwise, the op may be ambiguous. The output shape for 
-    // the op will be inferred using the inferOutputShape() method.
-    OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
-               "ArrayRef<ReassociationIndices>":$reassociation)>,
-
     // Builder that infers the result layout map. The result shape must be
     // specified. Otherwise, the op may be ambiguous.
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
-               "ArrayRef<ReassociationIndices>":$reassociation,
-               "ArrayRef<OpFoldResult>":$outputShape)>
+               "ArrayRef<ReassociationIndices>":$reassociation)>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     static FailureOr<MemRefType> computeExpandedType(
         MemRefType srcType, ArrayRef<int64_t> resultShape,
         ArrayRef<ReassociationIndices> reassociation);
-
-    // Infer the output shape for a memref.expand_shape when it is possible
-    // to do so.
-    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
-        OpBuilder &b, Location loc, MemRefType expandedType,
-        ArrayRef<ReassociationIndices> reassociation,
-        ArrayRef<OpFoldResult> inputShape);
   }];
 
   let hasVerifier = 1;
@@ -1737,12 +1707,6 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
     source/result layout map are the faster-varying ones.
   }];
 
-  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
-
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1754,7 +1718,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices(reassociation);
+          convertReassociationMapsToIndices($_builder, reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1772,7 +1736,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices(reassociation);
+          convertReassociationMapsToIndices($_builder, reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a403e89a39f98c..cf7f3e89079c1c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1062,7 +1062,8 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure])>,
-    Results<(outs AnyTensor:$result)> {
+    Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
+    Results<(outs AnyRankedTensor:$result)> {
 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1085,6 +1086,10 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     }
   }];
 
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1097,75 +1102,43 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     rank than the operand `src` whose dimension sizes are a reassociation of
     `src`.
 
-    A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.  The reassociation
-    maps applied to the result tensor with the higher rank must result in the
-    operand tensor with the smaller rank.
+    A reassociation is defined as a continuous grouping of dimensions. It is
+    represented with an array of DenseI64ArrayAttr attribute. Entries in the
+    array are referred to as reassociation maps.
 
-    The representation for the output shape supports a partially-static
-    specification via attributes specified through the `static_output_shape`
-    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
-    corresponding entry has a dynamic value.  There must be exactly as many SSA
-    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
-    `static_output_shape`.
+    The reassociation maps are applied to the result shape to obtain the operand
+    shape.
 
     Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
-    %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
-        : tensor<?x32xf32> into tensor<?x?x32xf32>
+    %b = tensor.expand_shape %a [[0, 1], [2]]
+        : tensor<?x?xf32> into tensor<?x?x?xf32>
     ```
   }];
-
-  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
-                       Variadic<Index>:$output_shape,
-                       DenseI64ArrayAttr:$static_output_shape);
-
-  let assemblyFormat = [{
-    $src $reassociation `output_shape`
-    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
-    type($src) `into` type($result)
-  }];
-
   let builders = [
     // Builders using ReassociationIndices.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      "ArrayRef<OpFoldResult>":$outputShape)>,
-
-    // It will infer output shape using inferOutputShape() method.
-    OpBuilder<(ins "Type":$resultType, "Value":$src,
-      "ArrayRef<ReassociationIndices>":$reassociation)>,
-
-    // Builder using ReassociationExprs.
-    OpBuilder<(ins "Type":$resultType, "Value":$src,
-      "ArrayRef<ReassociationExprs>":$reassociation),
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
-      auto reassociationIndices =
-          convertReassociationMapsToIndices(reassociation);
-      build($_builder, $_state, resultType, src, reassociationIndices);
+      build($_builder, $_state, resultType, src, attrs);
+      $_state.addAttribute("reassociation",
+          getReassociationIndicesAttribute($_builder, reassociation));
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      "ArrayRef<OpFoldResult>":$outputShape),
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
-      auto reassociationIndices =
-          convertReassociationMapsToIndices(reassociation);
-      build($_builder, $_state, resultType, src, reassociationIndices,
-            outputShape);
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
-
-    // Infer the output shape for a tensor.expand_shape when it is possible
-    // to do so.
-    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
-        OpBuilder &b, Location loc, RankedTensorType expandedType,
-        ArrayRef<ReassociationIndices> reassociation,
-        ArrayRef<OpFoldResult> inputShape);
   }];
 
   let hasVerifier = 1;
@@ -1173,7 +1146,6 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
-  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
   let description = [{
     The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
     rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1191,11 +1163,6 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
         : tensor<?x?x?xf32> into tensor<?x?xf32>
     ```
   }];
-
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1207,7 +1174,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices(reassociation);
+          convertReassociationMapsToIndices($_builder, reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1225,7 +1192,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices(reassociation);
+          convertReassociationMapsToIndices($_builder, reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 8a41a0a18b0ab3..ae9824f728da4d 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,27 +30,6 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
 
-// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
-// to do so.
-//
-// Note: This should *only* be used to implement
-// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
-// If you need to infer the output shape you should use the static method of
-// `ExpandShapeOp` instead of calling this.
-//
-// `inputShape` is the shape of the tensor or memref being expanded as a
-// sequence of SSA values or constants. `expandedType` is the output shape of
-// the expand_shape operation. `reassociation` is the reassociation denoting
-// the output dims each input dim is mapped to.
-//
-// Returns the output shape in `outputShape` and `staticOutputShape`, following
-// the conventions for the output_shape and static_output_shape inputs to the
-// expand_shape ops.
-std::optional<SmallVector<OpFoldResult>>
-inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
-                            ArrayRef<ReassociationIndices> reassociation,
-                            ArrayRef<OpFoldResult> inputShape);
-
 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
@@ -83,7 +62,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
 
 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
 SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
-    ArrayRef<ReassociationExprs> reassociationExprs);
+    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
 
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return std::nullopt when this computation
@@ -161,11 +140,14 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
       op.getReassociationIndices(), isExpansion);
 }
 
-/// Verify that shapes of the reshaped types using following rule:
-/// if a dimension in the collapsed type is static, then the corresponding
-/// dimensions in the expanded shape should be
+/// Verify that shapes of the reshaped types using following rules
+/// 1) if a dimension in the collapsed type is static, then the corresponding
+///    dimensions in the expanded shape should be
 ///    a) static
 ///    b) the product should be same as the collaped shape.
+/// 2) if a dimension in the collaped type is dynamic, one and only one of the
+///    corresponding dimensions in the expanded type should be dynamic. This
+///    rule is only needed with reshape operations that are expanding.
 LogicalResult reshapeLikeShapesAreCompatible(
     function_ref<LogicalResult(const Twine &)> emitError,
     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
@@ -174,11 +156,9 @@ LogicalResult reshapeLikeShapesAreCompatible(
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
 
-enum class ReshapeOpKind { kExpand, kCollapse };
-
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy, ReshapeOpKind opKind>
+template <typename ReshapeOpTy>
 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -201,18 +181,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
                                     rewriter.getContext());
     if (!reassociationIndices)
       return failure();
-
-    if constexpr (opKind == ReshapeOpKind::kExpand) {
-      SmallVector<OpFoldResult> outputShape(
-          getMixedValues(reshapeOp.getStaticOutputShape(),
-                         reshapeOp.getOutputShape(), rewriter));
-      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
-          outputShape);
-    } else {
-      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
-    }
+    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+        reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
     return success();
   }
 };
@@ -245,8 +215,7 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 //
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
-          typename DimOpTy, typename TensorTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -353,11 +322,8 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
     if (!composedReassociation)
       return failure();
 
-    SmallVector<OpFoldResult> outputShape(getMixedValues(
-        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
     rewriter.replaceOpWithNewOp<ExpandOpTy>(
-        expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
-        outputShape);
+        expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
     return success();
   }
 
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 594bcf5dbb399a..20f019666a2e6a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,8 +125,9 @@ SmallVect...
[truncated]
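
For readers who only skim the truncated diff above: the user-visible effect of this revert is that `expand_shape` loses its explicit `output_shape` operand and goes back to inferring the expanded sizes from the result type. A minimal before/after sketch, taken from the documentation examples in the MemRefOps.td hunk above (the SSA names `%0`, `%sz0`, `%sz1` are purely illustrative):

```mlir
// Form being reverted (introduced by #69267): the output shape is an
// explicit, partially dynamic operand list.
%r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
    : memref<?x32xf32> into memref<?x?x32xf32>

// Form restored by this revert: the expanded shape comes from the result
// type; at most one dimension per reassociation group may be dynamic.
%r = memref.expand_shape %0 [[0, 1], [2]]
    : memref<?x?xf32> into memref<?x5x?xf32>
```

`tensor.expand_shape` changes in the same way in the TensorOps.td hunk.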

llvmbot (Collaborator) commented on Apr 21, 2024

@llvm/pr-subscribers-mlir-memref


llvmbot (Collaborator) commented on Apr 21, 2024

@llvm/pr-subscribers-mlir-bufferization

-    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
-        OpBuilder &b, Location loc, MemRefType expandedType,
-        ArrayRef<ReassociationIndices> reassociation,
-        ArrayRef<OpFoldResult> inputShape);
   }];
 
   let hasVerifier = 1;
@@ -1737,12 +1707,6 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
     source/result layout map are the faster-varying ones.
   }];
 
-  let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
-
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1754,7 +1718,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices(reassociation);
+          convertReassociationMapsToIndices($_builder, reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1772,7 +1736,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices(reassociation);
+          convertReassociationMapsToIndices($_builder, reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a403e89a39f98c..cf7f3e89079c1c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1062,7 +1062,8 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure])>,
-    Results<(outs AnyTensor:$result)> {
+    Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
+    Results<(outs AnyRankedTensor:$result)> {
 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1085,6 +1086,10 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     }
   }];
 
+  let assemblyFormat = [{
+    $src $reassociation attr-dict `:` type($src) `into` type($result)
+  }];
+
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -1097,75 +1102,43 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     rank than the operand `src` whose dimension sizes are a reassociation of
     `src`.
 
-    A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of DenseI64ArrayAttr attribute.  The reassociation
-    maps applied to the result tensor with the higher rank must result in the
-    operand tensor with the smaller rank.
+    A reassociation is defined as a continuous grouping of dimensions. It is
+    represented with an array of DenseI64ArrayAttr attribute. Entries in the
+    array are referred to as reassociation maps.
 
-    The representation for the output shape supports a partially-static
-    specification via attributes specified through the `static_output_shape`
-    argument.  A special sentinel value `ShapedType::kDynamic` encodes that the
-    corresponding entry has a dynamic value.  There must be exactly as many SSA
-    inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
-    `static_output_shape`.
+    The reassociation maps are applied to the result shape to obtain the operand
+    shape.
 
     Example:
 
     ```mlir
     // Dimension expansion i -> (i', j') and (k) -> (k')
-    %b = tensor.expand_shape %a [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
-        : tensor<?x32xf32> into tensor<?x?x32xf32>
+    %b = tensor.expand_shape %a [[0, 1], [2]]
+        : tensor<?x?xf32> into tensor<?x?x?xf32>
     ```
   }];
-
-  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation,
-                       Variadic<Index>:$output_shape,
-                       DenseI64ArrayAttr:$static_output_shape);
-
-  let assemblyFormat = [{
-    $src $reassociation `output_shape`
-    custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
-    type($src) `into` type($result)
-  }];
-
   let builders = [
     // Builders using ReassociationIndices.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationIndices>":$reassociation,
-      "ArrayRef<OpFoldResult>":$outputShape)>,
-
-    // It will infer output shape using inferOutputShape() method.
-    OpBuilder<(ins "Type":$resultType, "Value":$src,
-      "ArrayRef<ReassociationIndices>":$reassociation)>,
-
-    // Builder using ReassociationExprs.
-    OpBuilder<(ins "Type":$resultType, "Value":$src,
-      "ArrayRef<ReassociationExprs>":$reassociation),
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
-      auto reassociationIndices =
-          convertReassociationMapsToIndices(reassociation);
-      build($_builder, $_state, resultType, src, reassociationIndices);
+      build($_builder, $_state, resultType, src, attrs);
+      $_state.addAttribute("reassociation",
+          getReassociationIndicesAttribute($_builder, reassociation));
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
-      "ArrayRef<OpFoldResult>":$outputShape),
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
-      auto reassociationIndices =
-          convertReassociationMapsToIndices(reassociation);
-      build($_builder, $_state, resultType, src, reassociationIndices,
-            outputShape);
+      auto reassociationMaps =
+          convertReassociationMapsToIndices($_builder, reassociation);
+      build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
 
   let extraClassDeclaration = commonExtraClassDeclaration # [{
     int64_t getCorrespondingSourceDim(int64_t resultDim);
-
-    // Infer the output shape for a tensor.expand_shape when it is possible
-    // to do so.
-    static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
-        OpBuilder &b, Location loc, RankedTensorType expandedType,
-        ArrayRef<ReassociationIndices> reassociation,
-        ArrayRef<OpFoldResult> inputShape);
   }];
 
   let hasVerifier = 1;
@@ -1173,7 +1146,6 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
 
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
   let summary = "operation to produce a tensor with a smaller rank";
-  let arguments = (ins AnyTensor:$src, IndexListArrayAttr:$reassociation);
   let description = [{
     The `tensor.collapse_shape` op produces a new tensor of lower (or equal)
     rank whose dimension sizes are a reassociation of the original `src` dimensions.
@@ -1191,11 +1163,6 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
         : tensor<?x?x?xf32> into tensor<?x?xf32>
     ```
   }];
-
-  let assemblyFormat = [{
-    $src $reassociation attr-dict `:` type($src) `into` type($result)
-  }];
-
   let builders = [
     // Builders for a contracting reshape whose result type is computed from
     // `src` and `reassociation`.
@@ -1207,7 +1174,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices(reassociation);
+          convertReassociationMapsToIndices($_builder, reassociation);
       build($_builder, $_state, src, reassociationMaps, attrs);
     }]>,
 
@@ -1225,7 +1192,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       auto reassociationMaps =
-          convertReassociationMapsToIndices(reassociation);
+          convertReassociationMapsToIndices($_builder, reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
     }]>
   ];
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 8a41a0a18b0ab3..ae9824f728da4d 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -30,27 +30,6 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
 
-// Infer the output shape for a {memref|tensor}.expand_shape when it is possible
-// to do so.
-//
-// Note: This should *only* be used to implement
-// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
-// If you need to infer the output shape you should use the static method of
-// `ExpandShapeOp` instead of calling this.
-//
-// `inputShape` is the shape of the tensor or memref being expanded as a
-// sequence of SSA values or constants. `expandedType` is the output shape of
-// the expand_shape operation. `reassociation` is the reassociation denoting
-// the output dims each input dim is mapped to.
-//
-// Returns the output shape in `outputShape` and `staticOutputShape`, following
-// the conventions for the output_shape and static_output_shape inputs to the
-// expand_shape ops.
-std::optional<SmallVector<OpFoldResult>>
-inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
-                            ArrayRef<ReassociationIndices> reassociation,
-                            ArrayRef<OpFoldResult> inputShape);
-
 /// Compose reassociation maps that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
@@ -83,7 +62,7 @@ getReassociationIndicesAttribute(OpBuilder &b,
 
 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
 SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
-    ArrayRef<ReassociationExprs> reassociationExprs);
+    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
 
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return std::nullopt when this computation
@@ -161,11 +140,14 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
       op.getReassociationIndices(), isExpansion);
 }
 
-/// Verify that shapes of the reshaped types using following rule:
-/// if a dimension in the collapsed type is static, then the corresponding
-/// dimensions in the expanded shape should be
+/// Verify that shapes of the reshaped types using following rules
+/// 1) if a dimension in the collapsed type is static, then the corresponding
+///    dimensions in the expanded shape should be
 ///    a) static
 ///    b) the product should be same as the collaped shape.
+/// 2) if a dimension in the collaped type is dynamic, one and only one of the
+///    corresponding dimensions in the expanded type should be dynamic. This
+///    rule is only needed with reshape operations that are expanding.
 LogicalResult reshapeLikeShapesAreCompatible(
     function_ref<LogicalResult(const Twine &)> emitError,
     ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
@@ -174,11 +156,9 @@ LogicalResult reshapeLikeShapesAreCompatible(
 /// Returns true iff the type is a MemRefType and has a non-identity layout.
 bool hasNonIdentityLayout(Type type);
 
-enum class ReshapeOpKind { kExpand, kCollapse };
-
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
 /// dimensions or are both expanding dimensions.
-template <typename ReshapeOpTy, ReshapeOpKind opKind>
+template <typename ReshapeOpTy>
 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
   using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
@@ -201,18 +181,8 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
                                     rewriter.getContext());
     if (!reassociationIndices)
       return failure();
-
-    if constexpr (opKind == ReshapeOpKind::kExpand) {
-      SmallVector<OpFoldResult> outputShape(
-          getMixedValues(reshapeOp.getStaticOutputShape(),
-                         reshapeOp.getOutputShape(), rewriter));
-      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
-          outputShape);
-    } else {
-      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
-          reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
-    }
+    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+        reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
     return success();
   }
 };
@@ -245,8 +215,7 @@ struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
 //
 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
 /// `reassociation_2` and produce `expand_shape`.
-template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
-          typename DimOpTy, typename TensorTy>
+template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
   using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
@@ -353,11 +322,8 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
     if (!composedReassociation)
       return failure();
 
-    SmallVector<OpFoldResult> outputShape(getMixedValues(
-        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
     rewriter.replaceOpWithNewOp<ExpandOpTy>(
-        expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
-        outputShape);
+        expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
     return success();
   }
 
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 594bcf5dbb399a..20f019666a2e6a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -125,8 +125,9 @@ SmallVect...
[truncated]
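For reference, the user-visible syntax change that this revert undoes is easiest to see in the op documentation examples touched in the hunks above. A minimal before/after sketch, adapted from those doc examples (the result names `%r0`/`%r1` and the SSA operands `%0`, `%sz0`, `%sz1` are just the placeholders used in the docs):

```mlir
// Syntax restored by this revert: the expanded sizes are inferred from the
// result type, so at most one dimension per reassociation group may be dynamic.
%r0 = memref.expand_shape %0 [[0, 1], [2]]
    : memref<?x?xf32> into memref<?x5x?xf32>

// Syntax introduced by #69267 (now reverted): the output shape is passed
// explicitly as a mix of static sizes and SSA index values.
%r1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
    : memref<?x32xf32> into memref<?x?x32xf32>
```

The same change applies symmetrically to `tensor.expand_shape`, which is why the tensor, memref, bufferization, linalg, and sparse test suites all appear in the file list above.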
