From 52fc4aa7a83332eeab6ff8ddc9f68668144e3f38 Mon Sep 17 00:00:00 2001 From: Johannes Reifferscheid Date: Thu, 27 Jun 2024 13:04:50 +0200 Subject: [PATCH] Do not trigger UB during AffineExpr parsing. Currently, parsing expressions that are undefined will trigger UB during compilation (e.g. INT_MIN / -1). This change instead leaves the expressions as they were. This change is an NFC for compilations that did not previously involve UB. --- llvm/include/llvm/Support/MathExtras.h | 6 ++-- mlir/lib/IR/AffineExpr.cpp | 44 ++++++++++++++++++------ mlir/unittests/IR/AffineExprTest.cpp | 46 ++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 12 deletions(-) diff --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h index 3bba999fb00e9..6de754f472635 100644 --- a/llvm/include/llvm/Support/MathExtras.h +++ b/llvm/include/llvm/Support/MathExtras.h @@ -435,7 +435,8 @@ inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) { } /// Returns the integer ceil(Numerator / Denominator). Signed version. -/// Guaranteed to never overflow. +/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator +/// is -1. inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) { assert(Denominator && "Division by zero"); if (!Numerator) @@ -448,7 +449,8 @@ inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) { } /// Returns the integer floor(Numerator / Denominator). Signed version. -/// Guaranteed to never overflow. +/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator +/// is -1. inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) { assert(Denominator && "Division by zero"); if (!Numerator) diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 1fab33327ba76..cf8157cf7bb8c 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include +#include #include #include "AffineExprDetail.h" @@ -645,10 +647,14 @@ mlir::getAffineConstantExprs(ArrayRef constants, static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { auto lhsConst = dyn_cast(lhs); auto rhsConst = dyn_cast(rhs); - // Fold if both LHS, RHS are a constant. - if (lhsConst && rhsConst) - return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), - lhs.getContext()); + // Fold if both LHS, RHS are a constant and the sum does not overflow. + if (lhsConst && rhsConst) { + int64_t sum; + if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) { + return nullptr; + } + return getAffineConstantExpr(sum, lhs.getContext()); + } // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). // If only one of them is a symbolic expressions, make it the RHS. @@ -774,9 +780,13 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) { auto lhsConst = dyn_cast(lhs); auto rhsConst = dyn_cast(rhs); - if (lhsConst && rhsConst) - return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(), - lhs.getContext()); + if (lhsConst && rhsConst) { + int64_t product; + if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) { + return nullptr; + } + return getAffineConstantExpr(product, lhs.getContext()); + } if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) return nullptr; @@ -849,10 +859,16 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { if (!rhsConst || rhsConst.getValue() < 1) return nullptr; - if (lhsConst) + if (lhsConst) { + // divideFloorSigned can only overflow in this case: + if (lhsConst.getValue() == std::numeric_limits::min() && + rhsConst.getValue() == -1) { + return nullptr; + } return getAffineConstantExpr( divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); + } // Fold floordiv of a multiply with a constant that is a multiple of the // divisor. Eg: (i * 128) floordiv 64 = i * 2. @@ -905,10 +921,16 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) { if (!rhsConst || rhsConst.getValue() < 1) return nullptr; - if (lhsConst) + if (lhsConst) { + // divideCeilSigned can only overflow in this case: + if (lhsConst.getValue() == std::numeric_limits::min() && + rhsConst.getValue() == -1) { + return nullptr; + } return getAffineConstantExpr( divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); + } // Fold ceildiv of a multiply with a constant that is a multiple of the // divisor. Eg: (i * 128) ceildiv 64 = i * 2. @@ -950,9 +972,11 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { if (!rhsConst || rhsConst.getValue() < 1) return nullptr; - if (lhsConst) + if (lhsConst) { + // mod never overflows. return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext()); + } // Fold modulo of an expression that is known to be a multiple of a constant // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128) diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp index ff154eb29807c..9740165c6b324 100644 --- a/mlir/unittests/IR/AffineExprTest.cpp +++ b/mlir/unittests/IR/AffineExprTest.cpp @@ -6,6 +6,9 @@ // //===----------------------------------------------------------------------===// +#include +#include + #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" #include "gtest/gtest.h" @@ -30,3 +33,46 @@ TEST(AffineExprTest, constructFromBinaryOperators) { ASSERT_EQ(product.getKind(), AffineExprKind::Mul); ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod); } + +TEST(AffineExprTest, constantFolding) { + MLIRContext ctx; + OpBuilder b(&ctx); + auto cn1 = b.getAffineConstantExpr(-1); + auto c0 = b.getAffineConstantExpr(0); + auto c1 = b.getAffineConstantExpr(1); + auto c2 = b.getAffineConstantExpr(2); + auto c3 = b.getAffineConstantExpr(3); + auto c6 = b.getAffineConstantExpr(6); + auto cmax = b.getAffineConstantExpr(std::numeric_limits::max()); + auto cmin = b.getAffineConstantExpr(std::numeric_limits::min()); + + ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Add, c1, c2), c3); + ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Mul, c2, c3), c6); + ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c2), c1); + ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c2), c2); + + // Test division by zero: + auto c3ceildivc0 = getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c0); + ASSERT_EQ(c3ceildivc0.getKind(), AffineExprKind::CeilDiv); + + auto c3floordivc0 = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c0); + ASSERT_EQ(c3floordivc0.getKind(), AffineExprKind::FloorDiv); + + auto c3modc0 = getAffineBinaryOpExpr(AffineExprKind::Mod, c3, c0); + ASSERT_EQ(c3modc0.getKind(), AffineExprKind::Mod); + + // Test overflow: + auto cmaxplusc1 = getAffineBinaryOpExpr(AffineExprKind::Add, cmax, c1); + ASSERT_EQ(cmaxplusc1.getKind(), AffineExprKind::Add); + + auto cmaxtimesc2 = getAffineBinaryOpExpr(AffineExprKind::Mul, cmax, c2); + ASSERT_EQ(cmaxtimesc2.getKind(), AffineExprKind::Mul); + + auto cminceildivcn1 = + getAffineBinaryOpExpr(AffineExprKind::CeilDiv, cmin, cn1); + ASSERT_EQ(cminceildivcn1.getKind(), AffineExprKind::CeilDiv); + + auto cminfloordivcn1 = + getAffineBinaryOpExpr(AffineExprKind::FloorDiv, cmin, cn1); + ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv); +}