From 3624e3afa1d2630d51e149fdd3f7e87afdbdc0e7 Mon Sep 17 00:00:00 2001 From: Kai Sasaki Date: Sat, 22 Feb 2025 17:00:07 +0900 Subject: [PATCH 1/3] [mlir][math] expand-math pass assumes the static shaped type In the process of `expand-math` pass, the conversion of ceil op assumes the static shaped type as input as it needs create 0 and 1 constant values whose type is aligned with the op type. Fixes https://github.com/llvm/llvm-project/issues/128275 --- .../Math/Transforms/ExpandPatterns.cpp | 5 +++++ mlir/test/Dialect/Math/expand-math.mlir | 22 +++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 23356d752146d..67e8dbba989b7 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -222,6 +222,11 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { // if (x > y) then incr = 1 else incr = 0 // y = y + incr <= replace this op with the ceilf op. static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { + // Creating constants assumes the statis shaped type. + auto shapedType = dyn_cast(op.getType()); + if (shapedType && !shapedType.hasStaticShape()) + return failure(); + ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operand = op.getOperand(); Type opType = operand.getType(); diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 1fdfb854325b4..4e249ec510afa 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -761,3 +761,25 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>) { %float_result = math.rsqrt %float : tensor<5x8xf32> return %float_result : tensor<5x8xf32> } + +// ----- + +// CHECK-LABEL func.func @non_static_shape_ceil_op +// CHECK: %[[IDX:.*]] = index.constant 0 +// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK: %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor +// CHECK: %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor +// CHECK: %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor +// CHECK: vector.print %[[DIM]] : index +// CHECK: return + +func.func @non_static_shape_ceil_op() { + %idx0 = index.constant 0 + %cst_90 = arith.constant 1.000000e+00 : f32 + %from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32> + %cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor + %112 = math.ceil %cast_93 : tensor + %dim_233 = tensor.dim %112, %idx0 : tensor + vector.print %dim_233 : index + return +} From a436c01811461b624dd049a4f89da45c02ef9a97 Mon Sep 17 00:00:00 2001 From: Kai Sasaki Date: Sun, 23 Feb 2025 14:38:15 +0900 Subject: [PATCH 2/3] Post review follow-up --- .../Math/Transforms/ExpandPatterns.cpp | 2 +- mlir/test/Dialect/Math/expand-math.mlir | 27 +++++++------------ 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 67e8dbba989b7..bb592c667549c 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -222,7 +222,7 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { // if (x > y) then incr = 1 else incr = 0 // y = y + incr <= replace this op with the ceilf op. static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { - // Creating constants assumes the statis shaped type. + // Creating constants assumes the static shaped type. auto shapedType = dyn_cast(op.getType()); if (shapedType && !shapedType.hasStaticShape()) return failure(); diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 4e249ec510afa..56d562ad0b3fe 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -764,22 +764,13 @@ func.func @rsqrt_tns(%float: tensor<5x8xf32>) -> (tensor<5x8xf32>) { // ----- -// CHECK-LABEL func.func @non_static_shape_ceil_op -// CHECK: %[[IDX:.*]] = index.constant 0 -// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32> -// CHECK: %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<2xf32> to tensor -// CHECK: %[[CEIL:.*]] = math.ceil %[[CAST]] : tensor -// CHECK: %[[DIM:.*]] = tensor.dim %[[CEIL]], %[[IDX]] : tensor -// CHECK: vector.print %[[DIM]] : index -// CHECK: return - -func.func @non_static_shape_ceil_op() { - %idx0 = index.constant 0 - %cst_90 = arith.constant 1.000000e+00 : f32 - %from_elements_92 = tensor.from_elements %cst_90, %cst_90 : tensor<2xf32> - %cast_93 = tensor.cast %from_elements_92 : tensor<2xf32> to tensor - %112 = math.ceil %cast_93 : tensor - %dim_233 = tensor.dim %112, %idx0 : tensor - vector.print %dim_233 : index - return +// CHECK-LABEL: func.func @non_static_shape_ceil_op +// CHECK-SAME: (%[[ARG:.*]]: tensor) +// CHECK-SAME: -> tensor +// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG]] : tensor +// CHECK: return %[[CEIL]] : tensor + +func.func @non_static_shape_ceil_op(%arg: tensor) -> tensor{ + %a = math.ceil %arg : tensor + return %a: tensor } From 9bc81d7705862a0c3031d836ad5298395be53f64 Mon Sep 17 00:00:00 2001 From: Kai Sasaki Date: Tue, 25 Feb 2025 09:58:54 +0900 Subject: [PATCH 3/3] Add test case for unranked type --- mlir/test/Dialect/Math/expand-math.mlir | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 56d562ad0b3fe..946a411e4cc4b 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -774,3 +774,16 @@ func.func @non_static_shape_ceil_op(%arg: tensor) -> tensor{ %a = math.ceil %arg : tensor return %a: tensor } + +// ----- + +// CHECK-LABEL: func.func @unranked_ceil_op +// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) +// CHECK-SAME: -> tensor<*xf32> +// CHECK: %[[CEIL:.*]] = math.ceil %[[ARG]] : tensor<*xf32> +// CHECK: return %[[CEIL]] : tensor<*xf32> + +func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{ + %a = math.ceil %arg : tensor<*xf32> + return %a: tensor<*xf32> +}