diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index ae8e3528b02e0..acedf51d0e240 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -160,6 +160,12 @@ struct BubbleUpExpandThroughParallelCollapse auto expandReInds = expandOp.getReassociationIndices(); auto collapseReInds = collapseOp.getReassociationIndices(); + // Special case where the collapsed tensor to expand is a 0-D tensor, + // then the reassociation maps will be empty and not produce valid results. + if (expandReInds.size() == 0) { + return failure(); + } + // Reshapes are parallel to each other if none of the reassociation indices // have greater than 1 index for both reshapes. for (auto [expandReassociation, collapseReassociation] : diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir index cf6b12852bcd3..eeed794884942 100644 --- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir +++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir @@ -45,3 +45,17 @@ func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor, % // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]] // CHECK: return %[[EXPAND]] + +// ----- + +func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor { + %collapse = tensor.collapse_shape %arg0 [] : tensor into tensor + %expand = tensor.expand_shape %collapse [] + output_shape [%s0, %s1, %s2, %s3] : tensor into tensor + return %expand : tensor +} +// CHECK: func @no_bubble_0d_tensor_reshapes +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}] +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}] +// CHECK: return %[[EXPAND]]