diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir index a101b76ef186b5..db882f7a54d392 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-expand-tensor.mlir @@ -35,6 +35,12 @@ func.func @main() { func.func private @printMemrefF32(%ptr : tensor<*xf32>) func.func @expand_dynamic_shape(%arg0 : tensor<2x?x?xf32>) -> tensor<2x2x?x1x?xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1, 2, 3], [4]]: tensor<2x?x?xf32> into tensor<2x2x?x1x?xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %d1 = tensor.dim %arg0, %c1 : tensor<2x?x?xf32> + %d2 = tensor.dim %arg0, %c2 : tensor<2x?x?xf32> + %sz1 = arith.divui %d1, %c2 : index + %0 = tensor.expand_shape %arg0 [[0], [1, 2, 3], [4]] output_shape [2, 2, %sz1, 1, %d2] : tensor<2x?x?xf32> into tensor<2x2x?x1x?xf32> return %0 : tensor<2x2x?x1x?xf32> }