diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index ff5f7f685903f..bf3f8343f35db 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -229,6 +229,15 @@ PowIStrengthReduction::matchAndRewrite( // Inverse the base for negative exponent, i.e. for // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`. if (exponentIsNegative) { + // For integer-base power ops (`math.ipowi`), a negative exponent produces + // an integer division `1/x^|n|`. This is: + // - Division by zero when x == 0 (undefined behaviour). + // - Integer truncation to 0 for |x| > 1 (almost certainly not intended). + // The signed interpretation of narrow integer types (e.g. i1 `true` == -1) + // makes this especially surprising. Don't perform this transformation for + // integer power; leave it to the runtime or other lowerings. + if constexpr (std::is_same_v) + return failure(); if constexpr (std::is_same_v) result = DivOpTy::create(rewriter, loc, op.getType(), bcast(one), result, op.getFastmathAttr()); diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir index 7342600748967..2c55a0b3acc04 100644 --- a/mlir/test/Dialect/Math/algebraic-simplification.mlir +++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir @@ -170,6 +170,19 @@ func.func @pow_0_75_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf3 return %0, %1 : f32, vector<4xf32> } +// CHECK-LABEL: @ipowi_i1_true_exp( +func.func @ipowi_i1_true_exp(%arg0: i1) -> i1 { + // The exponent `true : i1` is -1 when read as signed, so this ipowi has a + // negative exponent; it is not transformed since 1/x is UB when x == 0. 
+ // CHECK-NOT: arith.divsi + // CHECK: %[[TRUE:.*]] = arith.constant true + // CHECK: %[[RES:.*]] = math.ipowi %arg0, %[[TRUE]] : i1 + // CHECK: return %[[RES]] + %true = arith.constant true + %res = math.ipowi %arg0, %true : i1 + return %res : i1 +} + // CHECK-LABEL: @ipowi_zero_exp( // CHECK-SAME: %[[ARG0:.+]]: i32 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32> @@ -190,10 +203,14 @@ func.func @ipowi_zero_exp(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi3 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32> // CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) { func.func @ipowi_exp_one(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) { - // CHECK-DAG: %[[CST_S:.*]] = arith.constant 1 : i32 - // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32> - // CHECK: %[[SCALAR:.*]] = arith.divsi %[[CST_S]], %[[ARG0]] - // CHECK: %[[VECTOR:.*]] = arith.divsi %[[CST_V]], %[[ARG1]] + // Positive exponent 1: x^1 = x (identity). + // Negative exponent -1: not transformed to avoid potential division by zero + // (x^-1 = 1/x is UB when x == 0). 
+ // CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : i32 + // CHECK-DAG: %[[VM1:.*]] = arith.constant dense<-1> : vector<4xi32> + // CHECK-NOT: arith.divsi + // CHECK: %[[SCALAR:.*]] = math.ipowi %[[ARG0]], %[[CM1]] : i32 + // CHECK: %[[VECTOR:.*]] = math.ipowi %[[ARG1]], %[[VM1]] : vector<4xi32> // CHECK: return %[[ARG0]], %[[ARG1]], %[[SCALAR]], %[[VECTOR]] %c1 = arith.constant 1 : i32 %v1 = arith.constant dense <1> : vector<4xi32> @@ -211,14 +228,15 @@ func.func @ipowi_exp_one(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32> // CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) { func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) { - // CHECK-DAG: %[[CST_S:.*]] = arith.constant 1 : i32 - // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32> + // Positive exponent 2: x^2 = x*x. + // Negative exponent -2: not transformed (avoid potential division by zero). + // CHECK-DAG: %[[CM2:.*]] = arith.constant -2 : i32 + // CHECK-DAG: %[[VM2:.*]] = arith.constant dense<-2> : vector<4xi32> + // CHECK-NOT: arith.divsi // CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]] // CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]] - // CHECK: %[[SMUL:.*]] = arith.muli %[[ARG0]], %[[ARG0]] - // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL]] - // CHECK: %[[VMUL:.*]] = arith.muli %[[ARG1]], %[[ARG1]] - // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL]] + // CHECK: %[[SCALAR1:.*]] = math.ipowi %[[ARG0]], %[[CM2]] : i32 + // CHECK: %[[VECTOR1:.*]] = math.ipowi %[[ARG1]], %[[VM2]] : vector<4xi32> // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]] %c1 = arith.constant 2 : i32 %v1 = arith.constant dense <2> : vector<4xi32> @@ -236,18 +254,17 @@ func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32 // CHECK-SAME: %[[ARG1:.+]]: vector<4xi32> // CHECK-SAME: -> (i32, vector<4xi32>, i32, 
vector<4xi32>) { func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) { - // CHECK-DAG: %[[CST_S:.*]] = arith.constant 1 : i32 - // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32> + // Positive exponent 3: x^3 = x*x*x. + // Negative exponent -3: not transformed (avoid potential division by zero). + // CHECK-DAG: %[[CM3:.*]] = arith.constant -3 : i32 + // CHECK-DAG: %[[VM3:.*]] = arith.constant dense<-3> : vector<4xi32> + // CHECK-NOT: arith.divsi // CHECK: %[[SMUL0:.*]] = arith.muli %[[ARG0]], %[[ARG0]] // CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]] // CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]] // CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]] - // CHECK: %[[SMUL1:.*]] = arith.muli %[[ARG0]], %[[ARG0]] - // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[ARG0]] - // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL2]] - // CHECK: %[[VMUL1:.*]] = arith.muli %[[ARG1]], %[[ARG1]] - // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[ARG1]] - // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL2]] + // CHECK: %[[SCALAR1:.*]] = math.ipowi %[[ARG0]], %[[CM3]] : i32 + // CHECK: %[[VECTOR1:.*]] = math.ipowi %[[ARG1]], %[[VM3]] : vector<4xi32> // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]] %c1 = arith.constant 3 : i32 %v1 = arith.constant dense <3> : vector<4xi32>