diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index d2e2e1314dc4f1..a198d1a208b726 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1524,59 +1524,62 @@ static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, // divisor * id <= expr <-- Upper bound for 'id' // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). // - // For example, if -32*k + 16*i + j >= 0 - // 32*k - 16*i - j + 31 >= 0 <=> - // k = ( 16*i + j ) floordiv 32 - unsigned seenDividends = 0; + // For example: + // 32*k >= 16*i + j - 31 <-- Lower bound for 'k' + // 32*k <= 16*i + j <-- Upper bound for 'k' + // expr = 16*i + j, divisor = 32 + // k = ( 16*i + j ) floordiv 32 + // + // 4*q >= i + j - 2 <-- Lower bound for 'q' + // 4*q <= i + j + 1 <-- Upper bound for 'q' + // expr = i + j + 1, divisor = 4 + // q = (i + j + 1) floordiv 4 for (auto ubPos : ubIndices) { for (auto lbPos : lbIndices) { - // Check if the lower bound's constant term is divisor - 1. The - // 'divisor' here is cst.atIneq(lbPos, pos) and we already know that it's - // positive (since cst.Ineq(lbPos, ...) is a lower bound expr for 'pos'. + // Due to the form of the inequalities, the sum of constants of upper + // bound and lower bound is divisor - 1. The 'divisor' here is + // cst.atIneq(lbPos, pos) and we already know that it's positive (since + // cst.atIneq(lbPos, ...) is a lower bound expr for 'pos'). + // Check if this sum of constants is divisor - 1. int64_t divisor = cst.atIneq(lbPos, pos); - int64_t lbConstTerm = cst.atIneq(lbPos, cst.getNumCols() - 1); - if (lbConstTerm != divisor - 1) - continue; - // Check if upper bound's constant term is 0. 
- if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0) + int64_t constantSum = cst.atIneq(lbPos, cst.getNumCols() - 1) + + cst.atIneq(ubPos, cst.getNumCols() - 1); + if (constantSum != divisor - 1) continue; // For the remaining part, check if the lower bound expr's coeff's are // negations of corresponding upper bound ones'. unsigned c, f; - for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { + for (c = 0, f = cst.getNumCols() - 1; c < f; ++c) if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c)) break; - if (c != pos && cst.atIneq(lbPos, c) != 0) - seenDividends++; - } // Lb coeff's aren't negative of ub coeff's (for the non constant term // part). if (c < f) continue; - if (seenDividends >= 1) { - // Construct the dividend expression. - auto dividendExpr = getAffineConstantExpr(0, context); - unsigned c, f; - for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { - if (c == pos) - continue; - int64_t ubVal = cst.atIneq(ubPos, c); - if (ubVal == 0) - continue; - if (!exprs[c]) - break; - dividendExpr = dividendExpr + ubVal * exprs[c]; - } - // Expression can't be constructed as it depends on a yet unknown - // identifier. - // TODO: Visit/compute the identifiers in an order so that this doesn't - // happen. More complex but much more efficient. - if (c < f) + // Due to the form of the upper bound inequality, the constant term of + // `expr` is the constant term of upper bound inequality. + int64_t divConstantTerm = cst.atIneq(ubPos, cst.getNumCols() - 1); + // Construct the dividend expression. + auto dividendExpr = getAffineConstantExpr(divConstantTerm, context); + for (c = 0, f = cst.getNumCols() - 1; c < f; ++c) { + if (c == pos) continue; - // Successfully detected the floordiv. - exprs[pos] = dividendExpr.floorDiv(divisor); - return true; + int64_t ubVal = cst.atIneq(ubPos, c); + if (ubVal == 0) + continue; + if (!exprs[c]) + break; + dividendExpr = dividendExpr + ubVal * exprs[c]; } + // Expression can't be constructed as it depends on a yet unknown + // identifier. 
+ // TODO: Visit/compute the identifiers in an order so that this doesn't + // happen. More complex but much more efficient. + if (c < f) + continue; + // Successfully detected the floordiv. + exprs[pos] = dividendExpr.floorDiv(divisor); + return true; } } return false; diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp index 3ee6f049c9a20e..a0344a43e670ff 100644 --- a/mlir/unittests/Analysis/AffineStructuresTest.cpp +++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineStructures.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/MLIRContext.h" #include <gmock/gmock.h> #include <gtest/gtest.h> @@ -587,4 +589,25 @@ TEST(FlatAffineConstraintsTest, clearConstraints) { EXPECT_EQ(fac.atIneq(0, 1), 0); } +TEST(FlatAffineConstraintsTest, constantDivs) { + // This test checks if floordivs with numerator containing non-zero constant + // term can be computed from a FlatAffineConstraints instance. + FlatAffineConstraints fac = makeFACFromConstraints(4, {}, {}); + + // Build a FlatAffineConstraints instance with floordivs containing numerator + // with non-zero constant term. + fac.addLocalFloorDiv({0, 1, 0, 0, 10}, 30); + fac.addLocalFloorDiv({1, 0, 0, 0, 0, 99}, 101); + + // Add inequalities using the local variables created above. + fac.addInequality({1, 0, 0, 0, 1, 0, 2}); + fac.addInequality({1, 0, 0, 0, 0, 1, 5}); + + // FlatAffineConstraints::getAsIntegerSet returns a null integer set if an + // explicit representation for each local variable could not be found. + MLIRContext ctx; + IntegerSet iSet = fac.getAsIntegerSet(&ctx); + EXPECT_TRUE((bool)iSet); +} + } // namespace mlir