From 3545795998e71f190abadb829aae21224d2cf58d Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Thu, 6 Jun 2024 18:40:28 +0100 Subject: [PATCH 1/3] mlir/Presburger: contribute a free-standing parser The Presburger library is already quite independent of MLIR, with the exception of the MLIR Support library. There is, however, one major exception: the test suite for the library depends on the core MLIR parser. To free it of this dependency, extract the parts of the core MLIR parser that are applicable to the Presburger test suite, author custom parsing data structures, and adapt the new parser to parse into these structures. This patch is part of a project to move the Presburger library into LLVM. --- .../include/mlir/Analysis/Presburger/Matrix.h | 3 + .../mlir}/Analysis/Presburger/Parser.h | 55 +- mlir/lib/Analysis/Presburger/CMakeLists.txt | 2 + mlir/lib/Analysis/Presburger/Matrix.cpp | 8 + .../Analysis/Presburger/Parser/CMakeLists.txt | 6 + .../Analysis/Presburger/Parser/Flattener.cpp | 88 +++ .../Analysis/Presburger/Parser/Flattener.h | 67 ++ mlir/lib/Analysis/Presburger/Parser/Lexer.cpp | 136 ++++ mlir/lib/Analysis/Presburger/Parser/Lexer.h | 56 ++ .../Presburger/Parser/ParseStructs.cpp | 268 ++++++++ .../Analysis/Presburger/Parser/ParseStructs.h | 353 ++++++++++ .../Analysis/Presburger/Parser/ParserImpl.cpp | 640 ++++++++++++++++++ .../Analysis/Presburger/Parser/ParserImpl.h | 160 +++++ mlir/lib/Analysis/Presburger/Parser/Token.cpp | 64 ++ mlir/lib/Analysis/Presburger/Parser/Token.h | 91 +++ .../Analysis/Presburger/Parser/TokenKinds.def | 72 ++ .../Analysis/Presburger/BarvinokTest.cpp | 4 +- .../Analysis/Presburger/CMakeLists.txt | 4 +- .../Analysis/Presburger/FractionTest.cpp | 1 - .../Presburger/GeneratingFunctionTest.cpp | 2 +- .../Presburger/IntegerPolyhedronTest.cpp | 2 +- .../Presburger/IntegerRelationTest.cpp | 2 +- .../Analysis/Presburger/MatrixTest.cpp | 3 +- .../Analysis/Presburger/PWMAFunctionTest.cpp | 5 +- .../Analysis/Presburger/ParserTest.cpp | 2 +- .../Presburger/PresburgerRelationTest.cpp | 2 +- .../Analysis/Presburger/PresburgerSetTest.cpp | 3 +- .../Presburger/QuasiPolynomialTest.cpp | 4 +- .../Analysis/Presburger/SimplexTest.cpp | 5 +- mlir/unittests/Analysis/Presburger/Utils.h | 1 - 30 files changed, 2045 insertions(+), 64 deletions(-) rename mlir/{unittests => include/mlir}/Analysis/Presburger/Parser.h (65%) create mode 100644 mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt create mode 100644 mlir/lib/Analysis/Presburger/Parser/Flattener.cpp create mode 100644 mlir/lib/Analysis/Presburger/Parser/Flattener.h create mode 100644 mlir/lib/Analysis/Presburger/Parser/Lexer.cpp create mode 100644 mlir/lib/Analysis/Presburger/Parser/Lexer.h create mode 100644 mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp create mode 100644 mlir/lib/Analysis/Presburger/Parser/ParseStructs.h create mode 100644 mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp create mode 100644 mlir/lib/Analysis/Presburger/Parser/ParserImpl.h create mode 100644 mlir/lib/Analysis/Presburger/Parser/Token.cpp create mode 100644 mlir/lib/Analysis/Presburger/Parser/Token.h create mode 100644 mlir/lib/Analysis/Presburger/Parser/TokenKinds.def diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h index 054eb7b26d06e..9b9cfd178ce2d 100644 --- a/mlir/include/mlir/Analysis/Presburger/Matrix.h +++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h @@ -103,6 +103,9 @@ class Matrix { /// Set the specified row to `elems`. void setRow(unsigned row, ArrayRef elems); + /// Add the specified row to `elems`. + void addToRow(unsigned row, ArrayRef elems); + /// Insert columns having positions pos, pos + 1, ... pos + count - 1. /// Columns that were at positions 0 to pos - 1 will stay where they are; /// columns that were at positions pos to nColumns - 1 will be pushed to the diff --git a/mlir/unittests/Analysis/Presburger/Parser.h b/mlir/include/mlir/Analysis/Presburger/Parser.h similarity index 65% rename from mlir/unittests/Analysis/Presburger/Parser.h rename to mlir/include/mlir/Analysis/Presburger/Parser.h index 75842fb054e2b..0e117e51e203e 100644 --- a/mlir/unittests/Analysis/Presburger/Parser.h +++ b/mlir/include/mlir/Analysis/Presburger/Parser.h @@ -11,32 +11,26 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_UNITTESTS_ANALYSIS_PRESBURGER_PARSER_H -#define MLIR_UNITTESTS_ANALYSIS_PRESBURGER_PARSER_H +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_H #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/PresburgerRelation.h" -#include "mlir/AsmParser/AsmParser.h" -#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/IntegerSet.h" - -namespace mlir { -namespace presburger { - -/// Parses an IntegerPolyhedron from a StringRef. It is expected that the string -/// represents a valid IntegerSet. -inline IntegerPolyhedron parseIntegerPolyhedron(StringRef str) { - MLIRContext context(MLIRContext::Threading::DISABLED); - return affine::FlatAffineValueConstraints(parseIntegerSet(str, &context)); -} + +namespace mlir::presburger { +using llvm::StringRef; + +/// Parses an IntegerPolyhedron from a StringRef. +IntegerPolyhedron parseIntegerPolyhedron(StringRef str); + +/// Parses a MultiAffineFunction from a StringRef. +MultiAffineFunction parseMultiAffineFunction(StringRef str); /// Parse a list of StringRefs to IntegerRelation and combine them into a -/// PresburgerSet by using the union operation. It is expected that the strings -/// are all valid IntegerSet representation and that all of them have compatible -/// spaces. +/// PresburgerSet by using the union operation. It is expected that the +/// strings are all valid IntegerSet representation and that all of them have +/// compatible spaces. inline PresburgerSet parsePresburgerSet(ArrayRef strs) { assert(!strs.empty() && "strs should not be empty"); @@ -47,25 +41,10 @@ inline PresburgerSet parsePresburgerSet(ArrayRef strs) { return result; } -inline MultiAffineFunction parseMultiAffineFunction(StringRef str) { - MLIRContext context(MLIRContext::Threading::DISABLED); - - // TODO: Add default constructor for MultiAffineFunction. - MultiAffineFunction multiAff(PresburgerSpace::getRelationSpace(), - IntMatrix(0, 1)); - if (getMultiAffineFunctionFromMap(parseAffineMap(str, &context), multiAff) - .failed()) - llvm_unreachable( - "Failed to parse MultiAffineFunction because of semi-affinity"); - return multiAff; -} - inline PWMAFunction parsePWMAF(ArrayRef> pieces) { assert(!pieces.empty() && "At least one piece should be present."); - MLIRContext context(MLIRContext::Threading::DISABLED); - IntegerPolyhedron initDomain = parseIntegerPolyhedron(pieces[0].first); MultiAffineFunction initMultiAff = parseMultiAffineFunction(pieces[0].second); @@ -100,8 +79,6 @@ parsePresburgerRelationFromPresburgerSet(ArrayRef strs, result.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain, 0); return result; } +} // namespace mlir::presburger -} // namespace presburger -} // namespace mlir - -#endif // MLIR_UNITTESTS_ANALYSIS_PRESBURGER_PARSER_H +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_H diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt index 6b9842a6160ab..03fe3eeb1658a 100644 --- a/mlir/lib/Analysis/Presburger/CMakeLists.txt +++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(Parser) + add_mlir_library(MLIRPresburger Barvinok.cpp IntegerRelation.cpp diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp index 110c5df1af37c..d7aee64044e8f 100644 --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -144,6 +144,14 @@ void Matrix::setRow(unsigned row, ArrayRef elems) { at(row, i) = elems[i]; } +template +void Matrix::addToRow(unsigned row, ArrayRef elems) { + assert(elems.size() == getNumColumns() && + "elems size must match row length!"); + for (unsigned i = 0, e = getNumColumns(); i < e; ++i) + at(row, i) += elems[i]; +} + template void Matrix::insertColumn(unsigned pos) { insertColumns(pos, 1); diff --git a/mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt b/mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt new file mode 100644 index 0000000000000..f708a5c8db949 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_library(MLIRPresburgerParser + Flattener.cpp + Lexer.cpp + ParserImpl.cpp + ParseStructs.cpp + Token.cpp) diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp b/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp new file mode 100644 index 0000000000000..90f645a9f20bd --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp @@ -0,0 +1,88 @@ +//===- Flattener.cpp - Presburger ParseStruct Flattener ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Flattener class for flattening the parse tree +// produced by the parser for the Presburger library. +// +//===----------------------------------------------------------------------===// + +#include "Flattener.h" +#include "ParseStructs.h" + +using namespace mlir; +using namespace presburger; + +void Flattener::visitDiv(const PureAffineExprImpl &div) { + int64_t divisor = div.getDivisor(); + + // First construct the linear part of the divisor. + auto dividend = div.collectLinearTerms().getPadded( + info.getLocalVarStartIdx() + localExprs.size() + 1); + + // Next, insert the non-linear coefficients. + for (const auto &[hash, adjustedMulFactor, adjustedLinearTerm] : + div.getNonLinearCoeffs()) { + + // adjustedMulFactor will be mulFactor * -divisor in case of mod, and + // mulFactor in case of floordiv. + dividend[lookupLocal(hash)] = adjustedMulFactor; + + // The expansion is either a modExpansion which we previously stored, or the + // adjustedLinearTerm, which is correct in the case when we're encountering + // the innermost mod for the first time. + CoefficientVector expansion = + lookupModExpansion(hash) + .value_or(adjustedLinearTerm) + .getPadded(info.getLocalVarStartIdx() + localExprs.size() + 1); + dividend += expansion; + + // If this is a mod, insert the new computed expansion, which is the + // dividend * mulFactor. + if (div.isMod()) + localModExpansion.insert({div.hash(), dividend * div.getMulFactor()}); + } + + cst.addLocalFloorDiv(dividend, divisor); + localExprs.insert(div.hash()); +} + +void Flattener::flatten(unsigned row, PureAffineExprImpl &div) { + // Visit divs inner to outer. + for (auto &nestedDiv : div.getNestedDivTerms()) + flatten(row, *nestedDiv); + + if (div.hasDivisor()) { + visitDiv(div); + return; + } + + // Hit multiple times every time we have a linear sub-expression, but the + // row is overwritten to consider only the outermost div, which is hit + // last. + + // Set the linear part of the row. + setRow(row, div.getLinearDividend()); + + // Set the non-linear coefficients. + for (const auto &[hash, adjustedMulFactor, adjustedLinearTerm] : + div.getNonLinearCoeffs()) { + flatMatrix(row, lookupLocal(hash)) = adjustedMulFactor; + CoefficientVector expansion = + lookupModExpansion(hash).value_or(adjustedLinearTerm); + addToRow(row, expansion); + } +} + +std::pair Flattener::flatten() { + // Use the same flattener to simplify each expression successively. This way + // local variables / expressions are shared. + for (const auto &[row, expr] : enumerate(exprs)) + flatten(row, *expr); + + return {flatMatrix, cst}; +} diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.h b/mlir/lib/Analysis/Presburger/Parser/Flattener.h new file mode 100644 index 0000000000000..d3043ed67f56d --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.h @@ -0,0 +1,67 @@ +//===- Flattener.h - Presburger ParseStruct Flattener -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_FLATTENER_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_FLATTENER_H + +#include "ParseStructs.h" +#include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::presburger { +using llvm::SmallDenseMap; +using llvm::SmallSetVector; + +class Flattener : public FinalParseResult { +public: + // The final flattened result is stored here. + IntMatrix flatMatrix; + + // We maintain a set of divs that we have seen while flattening. The size of + // this set is at most info.numDivs, hitting info.numDivs at the end of the + // flattening, if that expression contains all the possible divs. + SmallSetVector localExprs; + + // We maintain a mapping between local mods and their expansions. The vector + // is the dividend. + SmallDenseMap localModExpansion; + + Flattener(FinalParseResult &&parseResult) + : FinalParseResult(std::move(parseResult)), + flatMatrix(info.numExprs, info.getNumCols()) {} + + std::pair flatten(); + +private: + void flatten(unsigned row, PureAffineExprImpl &div); + void visitDiv(const PureAffineExprImpl &div); + + void addToRow(unsigned row, const CoefficientVector &l) { + flatMatrix.addToRow(row, + getDynamicAPIntVec(l.getPadded(info.getNumCols()))); + } + void setRow(unsigned row, const CoefficientVector &l) { + flatMatrix.setRow(row, getDynamicAPIntVec(l.getPadded(info.getNumCols()))); + } + + unsigned lookupLocal(size_t hash) { + const auto *it = find(localExprs, hash); + assert(it != localExprs.end() && + "Local expression not found; walking from inner to outer?"); + return info.getLocalVarStartIdx() + it - localExprs.begin(); + } + std::optional lookupModExpansion(size_t hash) { + return localModExpansion.contains(hash) + ? std::make_optional(localModExpansion.at(hash)) + : std::nullopt; + } +}; +} // namespace mlir::presburger + +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_FLATTENER_H diff --git a/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp b/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp new file mode 100644 index 0000000000000..8c037539094a6 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp @@ -0,0 +1,136 @@ +//===- Lexer.cpp - Presburger Lexer Implementation ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the lexer for the Presburger textual form. +// +//===----------------------------------------------------------------------===// + +#include "Lexer.h" +#include "Token.h" +#include "llvm/ADT/StringSwitch.h" + +using namespace mlir::presburger; + +Lexer::Lexer(const llvm::SourceMgr &sourceMgr) : sourceMgr(sourceMgr) { + auto bufferID = sourceMgr.getMainFileID(); + curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer(); + curPtr = curBuffer.begin(); +} + +Token Lexer::emitError(const char *loc, const llvm::Twine &message) { + sourceMgr.PrintMessage(SMLoc::getFromPointer(loc), llvm::SourceMgr::DK_Error, + message); + return formToken(Token::error, loc); +} + +Token Lexer::lexToken() { + while (true) { + const char *tokStart = curPtr; + + switch (*curPtr++) { + default: + // Handle bare identifiers. + if (isalpha(curPtr[-1])) + return lexBareIdentifierOrKeyword(tokStart); + return emitError(tokStart, "unexpected character"); + + case ' ': + case '\t': + case '\n': + case '\r': + // Handle whitespace. + continue; + + case 0: + // This may be the EOF marker that llvm::MemoryBuffer guarantees will be + // there. + if (curPtr - 1 == curBuffer.end()) + return formToken(Token::eof, tokStart); + return emitError(tokStart, "unexpected character"); + + case ':': + return formToken(Token::colon, tokStart); + case ',': + return formToken(Token::comma, tokStart); + case '(': + return formToken(Token::l_paren, tokStart); + case ')': + return formToken(Token::r_paren, tokStart); + case '{': + return formToken(Token::l_brace, tokStart); + case '}': + return formToken(Token::r_brace, tokStart); + case '[': + return formToken(Token::l_square, tokStart); + case ']': + return formToken(Token::r_square, tokStart); + case '<': + return formToken(Token::less, tokStart); + case '>': + return formToken(Token::greater, tokStart); + case '=': + return formToken(Token::equal, tokStart); + case '+': + return formToken(Token::plus, tokStart); + case '*': + return formToken(Token::star, tokStart); + case '-': + if (*curPtr == '>') { + ++curPtr; + return formToken(Token::arrow, tokStart); + } + return formToken(Token::minus, tokStart); + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + return lexNumber(tokStart); + } + } +} + +/// Lex a bare identifier or keyword that starts with a letter. +/// +/// bare-id ::= (letter|[_]) (letter|digit|[_$.])* +/// +Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) { + // Match the rest of the identifier regex: [0-9a-zA-Z_.$]* + while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' || + *curPtr == '$' || *curPtr == '.') + ++curPtr; + + // Check to see if this identifier is a keyword. + StringRef spelling(tokStart, curPtr - tokStart); + + Token::Kind kind = llvm::StringSwitch(spelling) +#define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING) +#include "TokenKinds.def" + .Default(Token::bare_identifier); + + return Token(kind, spelling); +} + +/// Lex a number literal. +/// +/// integer-literal ::= digit+ +/// +Token Lexer::lexNumber(const char *tokStart) { + assert(isdigit(curPtr[-1])); + + // Handle the normal decimal case. + while (isdigit(*curPtr)) + ++curPtr; + + return formToken(Token::integer, tokStart); +} diff --git a/mlir/lib/Analysis/Presburger/Parser/Lexer.h b/mlir/lib/Analysis/Presburger/Parser/Lexer.h new file mode 100644 index 0000000000000..1dc458a358a94 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Lexer.h @@ -0,0 +1,56 @@ +//===- Lexer.h - Presburger Lexer Interface ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Presburger Lexer class. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_LEXER_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_LEXER_H + +#include "Token.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir::presburger { +/// This class breaks up the current file into a token stream. +class Lexer { +public: + explicit Lexer(const llvm::SourceMgr &sourceMgr); + + Token lexToken(); + + /// Change the position of the lexer cursor. The next token we lex will start + /// at the designated point in the input. + void resetPointer(const char *newPointer) { curPtr = newPointer; } + + /// Returns the start of the buffer. + const char *getBufferBegin() { return curBuffer.data(); } + +private: + Token formToken(Token::Kind kind, const char *tokStart) { + return Token(kind, StringRef(tokStart, curPtr - tokStart)); + } + + Token emitError(const char *loc, const llvm::Twine &message); + + // Lexer implementation methods. + Token lexBareIdentifierOrKeyword(const char *tokStart); + Token lexNumber(const char *tokStart); + + const llvm::SourceMgr &sourceMgr; + + StringRef curBuffer; + const char *curPtr; + + Lexer(const Lexer &) = delete; + void operator=(const Lexer &) = delete; +}; +} // namespace mlir::presburger + +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_LEXER_H diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp new file mode 100644 index 0000000000000..b937ae27a8fea --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp @@ -0,0 +1,268 @@ +//===- ParseStructs.cpp - Presburger Parse Structrures ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the ParseStructs class that the parser for the +// Presburger library parses into. +// +//===----------------------------------------------------------------------===// + +#include "ParseStructs.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir::presburger; +using llvm::dbgs; +using llvm::divideFloorSigned; +using llvm::mod; + +CoefficientVector PureAffineExprImpl::collectLinearTerms() const { + CoefficientVector nestedLinear = std::accumulate( + nestedDivTerms.begin(), nestedDivTerms.end(), CoefficientVector(info), + [](const CoefficientVector &acc, const PureAffineExpr &div) { + return acc + div->getLinearDividend(); + }); + return nestedLinear += linearDividend; +} + +SmallVector, 8> +PureAffineExprImpl::getNonLinearCoeffs() const { + SmallVector, 8> ret; + // dividend `floordiv` divisor <=> q; adjustedMulFactor = 1, + // adjustedLinearTerm is empty. + // + // dividend `mod` divisor <=> dividend - divisor*q; adjustedMulFactor = + // -divisor, adjustedLinearTerm is the linear part of the dividend. + // + // where q is a floordiv id added by the flattener. + auto adjustedMulFactor = [](const PureAffineExprImpl &div) { + return div.mulFactor * (div.kind == DivKind::Mod ? -div.divisor : 1); + }; + auto adjustedLinearTerm = [](PureAffineExprImpl &div) { + return div.kind == DivKind::Mod ? div.linearDividend *= div.mulFactor + : CoefficientVector(div.info); + }; + if (hasDivisor()) + for (const auto &toplevel : nestedDivTerms) + for (const auto &div : toplevel->getNestedDivTerms()) + ret.emplace_back(div->hash(), adjustedMulFactor(*div), + adjustedLinearTerm(*div)); + else + for (const auto &div : nestedDivTerms) + ret.emplace_back(div->hash(), adjustedMulFactor(*div), + adjustedLinearTerm(*div)); + return ret; +} + +unsigned PureAffineExprImpl::countNestedDivs() const { + return hasDivisor() + + std::accumulate(getNestedDivTerms().begin(), getNestedDivTerms().end(), + 0, [](unsigned acc, const PureAffineExpr &div) { + return acc + div->countNestedDivs(); + }); +} + +PureAffineExpr mlir::presburger::operator+(PureAffineExpr &&lhs, + PureAffineExpr &&rhs) { + if (lhs->isLinear() && rhs->isLinear()) + return std::make_unique(lhs->getLinearDividend() + + rhs->getLinearDividend()); + if (lhs->isLinear()) + return std::make_unique( + std::move(rhs->addLinearTerm(lhs->getLinearDividend()))); + if (rhs->isLinear()) + return std::make_unique( + std::move(lhs->addLinearTerm(rhs->getLinearDividend()))); + + if (!(lhs->hasDivisor() ^ rhs->hasDivisor())) { + auto ret = PureAffineExprImpl(lhs->info); + ret.addDivTerm(std::move(*lhs)); + ret.addDivTerm(std::move(*rhs)); + return std::make_unique(std::move(ret)); + } + if (lhs->hasDivisor()) { + rhs->addDivTerm(std::move(*lhs)); + return rhs; + } + if (rhs->hasDivisor()) { + lhs->addDivTerm(std::move(*rhs)); + return lhs; + } + llvm_unreachable("Malformed AffineExpr"); +} + +PureAffineExpr mlir::presburger::operator*(PureAffineExpr &&expr, int64_t c) { + if (expr->isLinear()) + return std::make_unique(expr->getLinearDividend() * c); + return std::make_unique(std::move(expr->mulConstant(c))); +} + +PureAffineExpr mlir::presburger::operator+(PureAffineExpr &&expr, int64_t c) { + return std::move(expr) + std::make_unique(expr->info, c); +} + +PureAffineExpr mlir::presburger::div(PureAffineExpr &÷nd, int64_t divisor, + DivKind kind) { + assert(divisor > 0 && "floorDiv or mod with a negative divisor"); + + // Constant fold. + if (dividend->isConstant()) { + int64_t c = kind == DivKind::FloorDiv + ? divideFloorSigned(dividend->getConstant(), divisor) + : mod(dividend->getConstant(), divisor); + return std::make_unique(dividend->info, c); + } + + // Factor out mul, using gcd internally. + uint64_t exprMultiple; + if (dividend->isLinear()) { + exprMultiple = dividend->getLinearDividend().factorMulFromLinearTerm(); + } else { + // Canonicalize the div. + uint64_t constMultiple = + dividend->getLinearDividend().factorMulFromLinearTerm(); + dividend->divLinearDividend(static_cast(constMultiple)); + dividend->mulFactor *= constMultiple; + exprMultiple = dividend->mulFactor; + } + + // Perform gcd with divisor. + uint64_t gcd = std::gcd(std::abs(divisor), exprMultiple); + + // Divide according to the type. + if (gcd > 1) { + if (dividend->isLinear()) + dividend->linearDividend /= static_cast(gcd); + else + dividend->mulFactor /= (static_cast(gcd)); + + divisor /= static_cast(gcd); + } + + if (dividend->isLinear()) + return std::make_unique(dividend->getLinearDividend(), + divisor, kind); + + // x floordiv 1 <=> x, x % 1 <=> 0 + return divisor == 1 + ? (dividend->isMod() + ? std::make_unique(dividend->info) + : std::move(dividend)) + : std::make_unique(std::move(*dividend), + divisor, kind); + ; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +enum class BindingStrength { + Weak, // + and - + Strong, // All other binary operators. +}; + +static void printCoefficient(int64_t c, bool &isExprBegin, + const ParseInfo &info, int idx = -1) { + bool isConstant = idx != -1 && info.isConstantIdx(idx); + bool isDimOrSymbol = idx != -1 && !info.isConstantIdx(idx); + if (!c) + return; + if (!isExprBegin) + dbgs() << " "; + if (c < 0) { + dbgs() << "- "; + if (c != -1 || isConstant) + dbgs() << std::abs(c); + } else { + if (!isExprBegin) + dbgs() << "+ "; + if (c != 1 || isConstant) { + dbgs() << c; + isExprBegin = false; + } + } + if (isDimOrSymbol) { + if (std::abs(c) != 1) + dbgs() << " * "; + dbgs() << (info.isDimIdx(idx) ? 'd' : 's') << idx; + isExprBegin = false; + } +} + +static bool printCoefficientVec( + const CoefficientVector &linear, bool isExprBegin = true, + BindingStrength enclosingTightness = BindingStrength::Weak) { + if (enclosingTightness == BindingStrength::Strong && + linear.hasMultipleCoefficients()) { + dbgs() << '('; + isExprBegin = true; + } + for (auto [idx, c] : enumerate(linear.getCoefficients())) + printCoefficient(c, isExprBegin, linear.info, idx); + + if (enclosingTightness == BindingStrength::Strong && + linear.hasMultipleCoefficients()) + dbgs() << ')'; + return isExprBegin; +} + +static bool +printAffineExpr(const PureAffineExprImpl &expr, bool isExprBegin = true, + BindingStrength enclosingTightness = BindingStrength::Weak) { + if (expr.isLinear()) + return printCoefficientVec(expr.getLinearDividend(), isExprBegin, + enclosingTightness); + + const auto &div = expr; + const auto &linearDividend = div.getLinearDividend(); + const auto &divisor = div.getDivisor(); + const auto &mulFactor = div.getMulFactor(); + const auto &nestedDivs = div.getNestedDivTerms(); + + printCoefficient(mulFactor, isExprBegin, expr.info); + if (std::abs(mulFactor) != 1) + dbgs() << " * "; + + if (div.hasDivisor() || enclosingTightness == BindingStrength::Strong) { + dbgs() << '('; + isExprBegin = true; + } + + isExprBegin = printCoefficientVec(linearDividend, isExprBegin); + + for (const auto &div : nestedDivs) + isExprBegin = printAffineExpr(*div, isExprBegin, BindingStrength::Strong); + + if (div.hasDivisor()) + dbgs() << (expr.isMod() ? " % " : " floordiv ") << divisor; + + if (div.hasDivisor() || enclosingTightness == BindingStrength::Strong) + dbgs() << ')'; + + return isExprBegin; +} + +LLVM_DUMP_METHOD void CoefficientVector::dump() const { + printCoefficientVec(*this); + dbgs() << '\n'; +} + +LLVM_DUMP_METHOD void PureAffineExprImpl::dump() const { + printAffineExpr(*this); + dbgs() << '\n'; +} + +LLVM_DUMP_METHOD void FinalParseResult::dump() const { + dbgs() << "Exprs:\n"; + for (const auto &expr : exprs) + expr->dump(); + dbgs() << "EqFlags: "; + for (bool eqF : eqFlags) + dbgs() << eqF << ' '; + dbgs() << '\n'; +} +#endif diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h new file mode 100644 index 0000000000000..ecf8428acd656 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h @@ -0,0 +1,353 @@ +//===- ParseStructs.h - Presburger Parse Structrures ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H + +#include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/PresburgerSpace.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include + +namespace mlir::presburger { +using llvm::ArrayRef; +using llvm::SmallVector; +using llvm::SmallVectorImpl; + +/// This structure is central to the parser and flattener, and holds the number +/// of dimensions, symbols, locals, and the constant term. +struct ParseInfo { + unsigned numDims = 0; + unsigned numSymbols = 0; + unsigned numExprs = 0; + unsigned numDivs = 0; + + constexpr unsigned getDimStartIdx() const { return 0; } + constexpr unsigned getSymbolStartIdx() const { return numDims; } + constexpr unsigned getLocalVarStartIdx() const { + return numDims + numSymbols; + } + constexpr unsigned getNumCols() const { + return numDims + numSymbols + numDivs + 1; + } + constexpr unsigned getConstantIdx() const { return getNumCols() - 1; } + + constexpr bool isDimIdx(unsigned i) const { return i < getSymbolStartIdx(); } + constexpr bool isSymbolIdx(unsigned i) const { + return i >= getSymbolStartIdx() && i < getLocalVarStartIdx(); + } + constexpr bool isLocalVarIdx(unsigned i) const { + return i >= getLocalVarStartIdx() && i < getConstantIdx(); + } + constexpr bool isConstantIdx(unsigned i) const { + return i == getConstantIdx(); + } +}; + +/// Helper for storing coefficients in canonical form: dims followed by symbols, +/// followed by locals, and finally the constant term. +/// +/// (x, y)[a, b]: y * 91 + x + 3 * a + 7 +/// coefficients: [1, 91, 3, 0, 7] +struct CoefficientVector { + ParseInfo info; + SmallVector coefficients; + + CoefficientVector(const ParseInfo &info, int64_t c = 0) : info(info) { + coefficients.resize(info.getNumCols()); + coefficients[info.getConstantIdx()] = c; + } + + // Copyable and movable + CoefficientVector(const CoefficientVector &o) = default; + CoefficientVector &operator=(const CoefficientVector &o) = default; + CoefficientVector(CoefficientVector &&o) + : info(o.info), coefficients(std::move(o.coefficients)) { + o.coefficients.clear(); + } + + ArrayRef getCoefficients() const { return coefficients; } + int64_t getConstant() const { + return coefficients[info.getConstantIdx()]; + } + size_t size() const { return coefficients.size(); } + operator ArrayRef() const { return coefficients; } + void resize(size_t size) { coefficients.resize(size); } + operator bool() const { + return any_of(coefficients, [](int64_t c) { return c; }); + } + int64_t &operator[](unsigned i) { + assert(i < coefficients.size()); + return coefficients[i]; + } + int64_t &back() { return coefficients.back(); } + int64_t back() const { return coefficients.back(); } + void clear() { + for_each(coefficients, [](auto &coeff) { coeff = 0; }); + } + + CoefficientVector &operator+=(const CoefficientVector &l) { + coefficients.resize(l.size()); + for (auto [idx, c] : enumerate(l.getCoefficients())) + coefficients[idx] += c; + return *this; + } + CoefficientVector &operator*=(int64_t c) { + for_each(coefficients, [c](auto &coeff) { coeff *= c; }); + return *this; + } + CoefficientVector &operator/=(int64_t c) { + assert(c && "Division by zero"); + for_each(coefficients, [c](auto &coeff) { coeff /= c; }); + return *this; + } + + CoefficientVector operator+(const CoefficientVector &l) const { + CoefficientVector ret(*this); + return ret += l; + } + CoefficientVector operator*(int64_t c) const { + CoefficientVector ret(*this); + return ret *= c; + } + CoefficientVector operator/(int64_t c) const { + CoefficientVector ret(*this); + return ret /= c; + } + + bool isConstant() const { + return all_of(drop_end(coefficients), [](int64_t c) { return !c; }); + } + CoefficientVector getPadded(size_t newSize) const { + assert(newSize >= size() && + "Padding size should be greater than expr size"); + CoefficientVector ret(info); + ret.resize(newSize); + + // Start constructing the result by taking the dims and symbols of the + // coefficients. + for (const auto &[col, coeff] : enumerate(drop_end(coefficients))) + ret[col] = coeff; + + // Put the constant at the end. + ret.back() = back(); + return ret; + } + + uint64_t factorMulFromLinearTerm() const { + uint64_t gcd = 1; + for (int64_t val : coefficients) + gcd = std::gcd(gcd, std::abs(val)); + return gcd; + } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + bool hasMultipleCoefficients() const { + return count_if(coefficients, [](auto &coeff) { return coeff; }) > 1; + } + LLVM_DUMP_METHOD void dump() const; +#endif +}; + +enum class DimOrSymbolKind { + DimId, + Symbol, +}; + +using DimOrSymbolExpr = std::pair; +enum class DivKind { FloorDiv, Mod }; + +/// Represents a pure Affine expression. Linear expressions are represented with +/// divisor = 1, and no nestedDivTerms. +/// +/// 3 - a * (3 + (3 - x div 3) div 4 + y div 7) div 4 +/// ^ linearDivident = 3, mulFactor = 1, divisor = 1 +/// ^ nest: 1, mulFactor: -a +/// ^ nest: 1, linearDividend +/// ^ nest: 2, linearDividend +/// ^ nest: 3 +/// ^ nest: 2 +/// ^ nest: 2 +/// nest: 1, divisor ^ +/// +/// Where div = floordiv|mod; ceildiv is pre-reduced +struct PureAffineExprImpl { + ParseInfo info; + DivKind kind = DivKind::FloorDiv; + using PureAffineExpr = std::unique_ptr; + + int64_t mulFactor = 1; + CoefficientVector linearDividend; + int64_t divisor = 1; + SmallVector nestedDivTerms; + + PureAffineExprImpl(const ParseInfo &info, int64_t c = 0) + : info(info), linearDividend(info, c) {} + PureAffineExprImpl(const ParseInfo &info, DimOrSymbolExpr idExpr) + : PureAffineExprImpl(info) { + auto [kind, pos] = idExpr; + unsigned startIdx = kind == DimOrSymbolKind::Symbol + ? info.getSymbolStartIdx() + : info.getDimStartIdx(); + linearDividend.coefficients[startIdx + pos] = 1; + } + PureAffineExprImpl(const CoefficientVector &linearDividend, + int64_t divisor = 1, DivKind kind = DivKind::FloorDiv) + : info(linearDividend.info), kind(kind), linearDividend(linearDividend), + divisor(divisor) {} + PureAffineExprImpl(PureAffineExprImpl &&div, int64_t divisor, DivKind kind) + : info(div.info), kind(kind), linearDividend(div.info), divisor(divisor) { + addDivTerm(std::move(div)); + } + + // Non-copyable, only movable + PureAffineExprImpl(const PureAffineExprImpl &) = delete; + PureAffineExprImpl(PureAffineExprImpl &&o) + : info(o.info), kind(o.kind), mulFactor(o.mulFactor), + linearDividend(std::move(o.linearDividend)), divisor(o.divisor), + nestedDivTerms(std::move(o.nestedDivTerms)) { + o.nestedDivTerms.clear(); + o.divisor = o.mulFactor = 1; + } + + const CoefficientVector &getLinearDividend() const { return linearDividend; } + CoefficientVector collectLinearTerms() const; + SmallVector, 8> + getNonLinearCoeffs() const; + + constexpr bool isMod() const { return kind == DivKind::Mod; } + constexpr bool hasDivisor() const { return divisor != 1; } + bool isLinear() const { + return nestedDivTerms.empty() && !hasDivisor(); + } + constexpr int64_t getDivisor() const { return divisor; } + constexpr int64_t getMulFactor() const { return mulFactor; } + bool isConstant() const { + return nestedDivTerms.empty() && getLinearDividend().isConstant(); + } + int64_t getConstant() const { + return getLinearDividend().getConstant(); + } + size_t hash() const { + return std::hash{}(this); + } + ArrayRef getNestedDivTerms() const { return nestedDivTerms; } + + PureAffineExprImpl &mulConstant(int64_t c) { + // Canonicalize mulFactors in div terms without divisors. + mulFactor *= c; + distributeMulFactor(); + return *this; + } + PureAffineExprImpl &addLinearTerm(const CoefficientVector &l) { + if (hasDivisor()) + nestedDivTerms.emplace_back( + std::make_unique(std::move(*this))); + linearDividend += l; + return *this; + } + PureAffineExprImpl &addDivTerm(PureAffineExprImpl &&d) { + nestedDivTerms.emplace_back( + std::make_unique(std::move(d))); + return *this; + } + PureAffineExprImpl &divLinearDividend(int64_t c) { + linearDividend /= c; + return *this; + } + + void distributeMulFactor() { + // Canonicalize the -1 mulFactor for divs without divisor. + if (mulFactor != -1 || hasDivisor()) + return; + linearDividend *= -1; + for_each(nestedDivTerms, + [](const PureAffineExpr &div) { div->mulFactor *= -1; }); + mulFactor *= -1; + } + unsigned countNestedDivs() const; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const; +#endif +}; + +using PureAffineExpr = std::unique_ptr; + +/// This structure holds the final parse result, and is constructed +/// non-trivially to compute the total number of divs, which is then used to +/// compute the total number of columns, and construct the IntMatrix in the +/// Flattener. +struct FinalParseResult { + ParseInfo info; + SmallVector exprs; + SmallVector eqFlags; + IntegerPolyhedron cst; + + FinalParseResult(const ParseInfo &parseInfo, + SmallVectorImpl &&exprStack, + ArrayRef eqFlagStack) + : info(parseInfo), exprs(std::move(exprStack)), eqFlags(eqFlagStack), + cst(0, 0, info.numDims + info.numSymbols + 1, + PresburgerSpace::getSetSpace(info.numDims, info.numSymbols, 0)) { + auto &i = this->info; + i.numExprs = exprs.size(); + i.numDivs = std::accumulate(exprs.begin(), exprs.end(), 0, + [](unsigned acc, const PureAffineExpr &expr) { + return acc + expr->countNestedDivs(); + }); + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const; +#endif +}; + +// The main interface to the parser is a bunch of operators, which is used to +// successively build the final AffineExpr. +PureAffineExpr operator+(PureAffineExpr &&lhs, PureAffineExpr &&rhs); +PureAffineExpr operator*(PureAffineExpr &&expr, int64_t c); +PureAffineExpr operator+(PureAffineExpr &&expr, int64_t c); +PureAffineExpr div(PureAffineExpr &÷nd, int64_t divisor, DivKind kind); +inline PureAffineExpr floordiv(PureAffineExpr &&expr, int64_t c) { + return div(std::move(expr), c, DivKind::FloorDiv); +} +inline PureAffineExpr operator%(PureAffineExpr &&expr, int64_t c) { + return div(std::move(expr), c, DivKind::Mod); +} +inline PureAffineExpr operator*(int64_t c, PureAffineExpr &&expr) { + return std::move(expr) * c; +} +inline PureAffineExpr operator-(PureAffineExpr &&lhs, PureAffineExpr &&rhs) { + return std::move(lhs) + std::move(rhs) * -1; +} +inline PureAffineExpr operator+(int64_t c, PureAffineExpr &&expr) { + return std::move(expr) + c; +} +inline PureAffineExpr operator-(PureAffineExpr &&expr, int64_t c) { + return std::move(expr) + (-c); +} +inline PureAffineExpr operator-(int64_t c, PureAffineExpr &&expr) { + return -1 * std::move(expr) + c; +} +inline PureAffineExpr ceildiv(PureAffineExpr &&expr, int64_t c) { + // expr ceildiv c <=> (expr + c - 1) floordiv c + return floordiv(std::move(expr) + (c - 1), c); +} + +// Our final canonical expression, the outermost div, should have a divisor +// of 1. +inline PureAffineExpr canonicalize(PureAffineExpr &&expr) { + if (expr->hasDivisor()) + expr->addLinearTerm(CoefficientVector(expr->info)); + return expr; +} +} // namespace mlir::presburger + +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp new file mode 100644 index 0000000000000..20b2b6d08638e --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp @@ -0,0 +1,640 @@ +//===- ParserImpl.cpp - Presburger Parser Implementation --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the ParserImpl class for the Presburger textual form. +// +//===----------------------------------------------------------------------===// + +#include "ParserImpl.h" +#include "Flattener.h" +#include "ParseStructs.h" +#include "mlir/Analysis/Presburger/Parser.h" + +using namespace mlir; +using namespace presburger; +using llvm::MemoryBuffer; +using llvm::SourceMgr; + +static bool isIdentifier(const Token &token) { + return token.is(Token::bare_identifier) || token.isKeyword(); +} + +bool ParserImpl::parseToken(Token::Kind expectedToken, const Twine &message) { + if (consumeIf(expectedToken)) + return true; + return emitError(message); +} + +bool ParserImpl::parseCommaSepeatedList(Delimiter delimiter, + function_ref parseElementFn, + StringRef contextMessage) { + switch (delimiter) { + case Delimiter::None: + break; + case Delimiter::OptionalParen: + if (getToken().isNot(Token::l_paren)) + return true; + [[fallthrough]]; + case Delimiter::Paren: + if (!parseToken(Token::l_paren, "expected '('" + contextMessage)) + return false; + if (consumeIf(Token::r_paren)) + return true; + break; + case Delimiter::OptionalLessGreater: + if (getToken().isNot(Token::less)) + return true; + [[fallthrough]]; + case Delimiter::LessGreater: + if (!parseToken(Token::less, "expected '<'" + contextMessage)) + return true; + if (consumeIf(Token::greater)) + return true; + break; + case Delimiter::OptionalSquare: + if (getToken().isNot(Token::l_square)) + return true; + [[fallthrough]]; + case Delimiter::Square: + if (!parseToken(Token::l_square, "expected '['" + contextMessage)) + return false; + if (consumeIf(Token::r_square)) + return true; + break; + case Delimiter::OptionalBraces: + if (getToken().isNot(Token::l_brace)) + return true; + [[fallthrough]]; + case Delimiter::Braces: + if (!parseToken(Token::l_brace, "expected '{'" + contextMessage)) + return false; + if (consumeIf(Token::r_brace)) + return true; + break; + } + + if (!parseElementFn()) + return false; + + while (consumeIf(Token::comma)) + if (!parseElementFn()) + return false; + + switch (delimiter) { + case Delimiter::None: + return true; + case Delimiter::OptionalParen: + case Delimiter::Paren: + return parseToken(Token::r_paren, "expected ')'" + contextMessage); + case Delimiter::OptionalLessGreater: + case Delimiter::LessGreater: + return parseToken(Token::greater, "expected '>'" + contextMessage); + case Delimiter::OptionalSquare: + case Delimiter::Square: + return parseToken(Token::r_square, "expected ']'" + contextMessage); + case Delimiter::OptionalBraces: + case Delimiter::Braces: + return parseToken(Token::r_brace, "expected '}'" + contextMessage); + } + llvm_unreachable("Unknown delimiter"); +} + +bool ParserImpl::emitError(SMLoc loc, const Twine &message) { + // If we hit a parse error in response to a lexer error, then the lexer + // already reported the error. + if (getToken().isNot(Token::error)) + sourceMgr.PrintMessage(loc, SourceMgr::DK_Error, message); + return false; +} + +bool ParserImpl::emitError(const Twine &message) { + SMLoc loc = state.curToken.getLoc(); + if (state.curToken.isNot(Token::eof)) + return emitError(loc, message); + + // If the error is to be emitted at EOF, move it back one character. + return emitError(SMLoc::getFromPointer(loc.getPointer() - 1), message); +} + +/// Parse a bare id that may appear in an affine expression. +/// +/// affine-expr ::= bare-id +PureAffineExpr ParserImpl::parseBareIdExpr() { + if (!isIdentifier(getToken())) + return emitError("expected bare identifier"), nullptr; + + StringRef sRef = getTokenSpelling(); + for (const auto &entry : dimsAndSymbols) { + if (entry.first == sRef) { + consumeToken(); + return std::make_unique(info, entry.second); + } + } + + return emitError("use of undeclared identifier"), nullptr; +} + +/// Parse an affine expression inside parentheses. +/// +/// affine-expr ::= `(` affine-expr `)` +PureAffineExpr ParserImpl::parseParentheticalExpr() { + if (!parseToken(Token::l_paren, "expected '('")) + return nullptr; + if (getToken().is(Token::r_paren)) { + return emitError("no expression inside parentheses"), nullptr; + } + + PureAffineExpr expr = parseAffineExpr(); + if (!expr || !parseToken(Token::r_paren, "expected ')'")) + return nullptr; + + return expr; +} + +/// Parse the negation expression. +/// +/// affine-expr ::= `-` affine-expr +PureAffineExpr ParserImpl::parseNegateExpression(const PureAffineExpr &lhs) { + if (!parseToken(Token::minus, "expected '-'")) + return nullptr; + + PureAffineExpr operand = parseAffineOperandExpr(lhs); + // Since negation has the highest precedence of all ops (including high + // precedence ops) but lower than parentheses, we are only going to use + // parseAffineOperandExpr instead of parseAffineExpr here. + if (!operand) { + // Extra error message although parseAffineOperandExpr would have + // complained. Leads to a better diagnostic. + return emitError("missing operand of negation"), nullptr; + } + return -1 * std::move(operand); +} + +/// Parse a positive integral constant appearing in an affine expression. +/// +/// affine-expr ::= integer-literal +PureAffineExpr ParserImpl::parseIntegerExpr() { + std::optional val = getToken().getUInt64IntegerValue(); + if (!val) + return emitError("failed to parse constant"), nullptr; + int64_t ret = static_cast(*val); + if (ret < 0) + return emitError("constant too large"), nullptr; + + consumeToken(Token::integer); + return std::make_unique(info, ret); +} + +/// Parses an expression that can be a valid operand of an affine expression. +/// lhs: if non-null, lhs is an affine expression that is the lhs of a binary +/// operator, the rhs of which is being parsed. This is used to determine +/// whether an error should be emitted for a missing right operand. +// Eg: for an expression without parentheses (like i + j + k + l), each +// of the four identifiers is an operand. For i + j*k + l, j*k is not an +// operand expression, it's an op expression and will be parsed via +// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and +// -l are valid operands that will be parsed by this function. +PureAffineExpr ParserImpl::parseAffineOperandExpr(const PureAffineExpr &lhs) { + switch (getToken().getKind()) { + case Token::integer: + return parseIntegerExpr(); + case Token::l_paren: + return parseParentheticalExpr(); + case Token::minus: + return parseNegateExpression(lhs); + case Token::kw_ceildiv: + case Token::kw_floordiv: + case Token::kw_mod: + return parseBareIdExpr(); + case Token::plus: + case Token::star: + if (lhs) + emitError("missing right operand of binary operator"); + else + emitError("missing left operand of binary operator"); + return nullptr; + default: + if (isIdentifier(getToken())) + return parseBareIdExpr(); + + if (lhs) + emitError("missing right operand of binary operator"); + else + emitError("expected affine expression"); + return nullptr; + } +} + +PureAffineExpr ParserImpl::getAffineBinaryOpExpr(AffineHighPrecOp op, + PureAffineExpr &&lhs, + PureAffineExpr &&rhs, + SMLoc opLoc) { + switch (op) { + case Mul: + if (!lhs->isConstant() && !rhs->isConstant()) { + return emitError(opLoc, + "non-affine expression: at least one of the multiply " + "operands has to be a constant"), + nullptr; + } + if (rhs->isConstant()) + return std::move(lhs) * rhs->getConstant(); + return std::move(rhs) * lhs->getConstant(); + case FloorDiv: + if (!rhs->isConstant()) { + return emitError(opLoc, + "non-affine expression: right operand of floordiv " + "has to be a constant"), + nullptr; + } + return floordiv(std::move(lhs), rhs->getConstant()); + case CeilDiv: + if (!rhs->isConstant()) { + return emitError(opLoc, "non-affine expression: right operand of ceildiv " + "has to be a constant"), + nullptr; + } + return ceildiv(std::move(lhs), rhs->getConstant()); + case Mod: + if (!rhs->isConstant()) { + return emitError(opLoc, "non-affine expression: right operand of mod " + "has to be a constant"), + nullptr; + } + return std::move(lhs) % rhs->getConstant(); + case HNoOp: + llvm_unreachable("can't create affine expression for null high prec op"); + return nullptr; + } + llvm_unreachable("Unknown AffineHighPrecOp"); +} + +PureAffineExpr ParserImpl::getAffineBinaryOpExpr(AffineLowPrecOp op, + PureAffineExpr &&lhs, + PureAffineExpr &&rhs) { + switch (op) { + case AffineLowPrecOp::Add: + return std::move(lhs) + std::move(rhs); + case AffineLowPrecOp::Sub: + return std::move(lhs) - std::move(rhs); + case AffineLowPrecOp::LNoOp: + llvm_unreachable("can't create affine expression for null low prec op"); + return nullptr; + } + llvm_unreachable("Unknown AffineLowPrecOp"); +} + +AffineLowPrecOp ParserImpl::consumeIfLowPrecOp() { + switch (getToken().getKind()) { + case Token::plus: + consumeToken(Token::plus); + return AffineLowPrecOp::Add; + case Token::minus: + consumeToken(Token::minus); + return AffineLowPrecOp::Sub; + default: + return AffineLowPrecOp::LNoOp; + } +} + +/// Consume this token if it is a higher precedence affine op (there are only +/// two precedence levels) +AffineHighPrecOp ParserImpl::consumeIfHighPrecOp() { + switch (getToken().getKind()) { + case Token::star: + consumeToken(Token::star); + return Mul; + case Token::kw_floordiv: + consumeToken(Token::kw_floordiv); + return FloorDiv; + case Token::kw_ceildiv: + consumeToken(Token::kw_ceildiv); + return CeilDiv; + case Token::kw_mod: + consumeToken(Token::kw_mod); + return Mod; + default: + return HNoOp; + } +} + +/// Parse a high precedence op expression list: mul, div, and mod are high +/// precedence binary ops, i.e., parse a +/// expr_1 op_1 expr_2 op_2 ... expr_n +/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). +/// All affine binary ops are left associative. +/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is +/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is +/// null. llhsOpLoc is the location of the llhsOp token that will be used to +/// report an error for non-conforming expressions. +PureAffineExpr ParserImpl::parseAffineHighPrecOpExpr(PureAffineExpr &&llhs, + AffineHighPrecOp llhsOp, + SMLoc llhsOpLoc) { + PureAffineExpr lhs = parseAffineOperandExpr(llhs); + if (!lhs) + return nullptr; + + // Found an LHS. Parse the remaining expression. + SMLoc opLoc = getToken().getLoc(); + if (AffineHighPrecOp op = consumeIfHighPrecOp()) { + if (llhs) { + PureAffineExpr expr = + getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs), opLoc); + if (!expr) + return nullptr; + return parseAffineHighPrecOpExpr(std::move(expr), op, opLoc); + } + // No LLHS, get RHS + return parseAffineHighPrecOpExpr(std::move(lhs), op, opLoc); + } + + // This is the last operand in this expression. + if (llhs) + return getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs), + llhsOpLoc); + + return lhs; +} + +/// Parse affine expressions that are bare-id's, integer constants, +/// parenthetical affine expressions, and affine op expressions that are a +/// composition of those. +/// +/// All binary op's associate from left to right. +/// +/// {add, sub} have lower precedence than {mul, div, and mod}. +/// +/// Add, sub'are themselves at the same precedence level. Mul, floordiv, +/// ceildiv, and mod are at the same higher precedence level. Negation has +/// higher precedence than any binary op. +/// +/// llhs: the affine expression appearing on the left of the one being parsed. +/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, +/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned +/// if llhs is non-null; otherwise lhs is returned. This is to deal with left +/// associativity. +/// +/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function +/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where +/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). +PureAffineExpr ParserImpl::parseAffineLowPrecOpExpr(PureAffineExpr &&llhs, + AffineLowPrecOp llhsOp) { + PureAffineExpr lhs = parseAffineOperandExpr(llhs); + if (!lhs) + return nullptr; + + // Found an LHS. Deal with the ops. + if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { + if (llhs) { + PureAffineExpr sum = + getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs)); + return parseAffineLowPrecOpExpr(std::move(sum), lOp); + } + + return parseAffineLowPrecOpExpr(std::move(lhs), lOp); + } + SMLoc opLoc = getToken().getLoc(); + if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { + // We have a higher precedence op here. Get the rhs operand for the llhs + // through parseAffineHighPrecOpExpr. + PureAffineExpr highRes = + parseAffineHighPrecOpExpr(std::move(lhs), hOp, opLoc); + if (!highRes) + return nullptr; + + // If llhs is null, the product forms the first operand of the yet to be + // found expression. If non-null, the op to associate with llhs is llhsOp. + PureAffineExpr expr = llhs ? getAffineBinaryOpExpr(llhsOp, std::move(llhs), + std::move(highRes)) + : std::move(highRes); + + // Recurse for subsequent low prec op's after the affine high prec op + // expression. + if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) + return parseAffineLowPrecOpExpr(std::move(expr), nextOp); + return expr; + } + // Last operand in the expression list. + if (llhs) + return getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs)); + + return lhs; +} + +/// Parse an affine expression. +/// affine-expr ::= `(` affine-expr `)` +/// | `-` affine-expr +/// | affine-expr `+` affine-expr +/// | affine-expr `-` affine-expr +/// | affine-expr `*` affine-expr +/// | affine-expr `floordiv` affine-expr +/// | affine-expr `ceildiv` affine-expr +/// | affine-expr `mod` affine-expr +/// | bare-id +/// | integer-literal +/// +/// Additional conditions are checked depending on the production. For eg., +/// one of the operands for `*` has to be either constant/symbolic; the second +/// operand for floordiv, ceildiv, and mod has to be a positive integer. +PureAffineExpr ParserImpl::parseAffineExpr() { + return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); +} + +/// Parse a dim or symbol from the lists appearing before the actual +/// expressions of the affine map. Update our state to store the +/// dimensional/symbolic identifier. +bool ParserImpl::parseIdentifierDefinition(DimOrSymbolExpr idExpr) { + if (!isIdentifier(getToken())) + return emitError("expected bare identifier"); + + StringRef name = getTokenSpelling(); + for (const auto &entry : dimsAndSymbols) { + if (entry.first == name) + return emitError("redefinition of identifier '" + name + "'"); + } + consumeToken(); + + dimsAndSymbols.emplace_back(name, idExpr); + return true; +} + +/// Parse the list of dimensional identifiers to an affine map. +bool ParserImpl::parseDimIdList() { + auto parseElt = [&]() -> bool { + return parseIdentifierDefinition({DimOrSymbolKind::DimId, info.numDims++}); + }; + return parseCommaSepeatedList(Delimiter::Paren, parseElt, + " in dimensional identifier list"); +} + +/// Parse the list of symbolic identifiers to an affine map. +bool ParserImpl::parseSymbolIdList() { + auto parseElt = [&]() -> bool { + return parseIdentifierDefinition( + {DimOrSymbolKind::Symbol, info.numSymbols++}); + }; + return parseCommaSepeatedList(Delimiter::Square, parseElt, " in symbol list"); +} + +/// Parse the list of symbolic identifiers to an affine map. +bool ParserImpl::parseDimAndOptionalSymbolIdList() { + if (!parseDimIdList()) + return false; + if (getToken().isNot(Token::l_square)) { + info.numSymbols = 0; + return true; + } + return parseSymbolIdList(); +} + +/// Parse an affine constraint. +/// affine-constraint ::= affine-expr `>=` `affine-expr` +/// | affine-expr `<=` `affine-expr` +/// | affine-expr `==` `affine-expr` +/// +/// The constraint is normalized to +/// affine-constraint ::= affine-expr `>=` `0` +/// | affine-expr `==` `0` +/// before returning. +/// +/// isEq is set to true if the parsed constraint is an equality, false if it +/// is an inequality (greater than or equal). +PureAffineExpr ParserImpl::parseAffineConstraint(bool *isEq) { + PureAffineExpr lhsExpr = parseAffineExpr(); + if (!lhsExpr) + return nullptr; + + if (consumeIf(Token::greater) && consumeIf(Token::equal)) { + PureAffineExpr rhsExpr = parseAffineExpr(); + if (!rhsExpr) + return nullptr; + *isEq = false; + return std::move(lhsExpr) - std::move(rhsExpr); + } + + if (consumeIf(Token::less) && consumeIf(Token::equal)) { + PureAffineExpr rhsExpr = parseAffineExpr(); + if (!rhsExpr) + return nullptr; + *isEq = false; + return std::move(rhsExpr) - std::move(lhsExpr); + } + + if (consumeIf(Token::equal) && consumeIf(Token::equal)) { + PureAffineExpr rhsExpr = parseAffineExpr(); + if (!rhsExpr) + return nullptr; + *isEq = true; + return std::move(lhsExpr) - std::move(rhsExpr); + } + + return emitError("expected '==', '<=' or '>='"), nullptr; +} + +FinalParseResult ParserImpl::parseAffineMapOrIntegerSet() { + SmallVector exprs; + SmallVector eqFlags; + + if (!parseDimAndOptionalSymbolIdList()) + llvm_unreachable("expected dim and symbol list"); + + auto parseExpr = [&]() -> bool { + PureAffineExpr elt = canonicalize(parseAffineExpr()); + if (elt) { + exprs.emplace_back(std::move(elt)); + return true; + } + return false; + }; + + auto parseConstraint = [&]() -> bool { + bool isEq; + PureAffineExpr elt = canonicalize(parseAffineConstraint(&isEq)); + if (elt) { + exprs.emplace_back(std::move(elt)); + eqFlags.push_back(isEq); + return true; + } + return false; + }; + + if (consumeIf(Token::arrow)) { + /// Parse the range and sizes affine map definition inline. + /// + /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr + /// + /// multi-dim-affine-expr ::= `(` `)` + /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` + if (!parseCommaSepeatedList(Delimiter::Paren, parseExpr, + " in affine map range")) + llvm_unreachable("expected affine map range"); + + return {info, std::move(exprs), eqFlags}; + } + + if (!parseToken(Token::colon, "expected '->' or ':'")) + llvm_unreachable("Unexpected token"); + + /// Parse the constraints that are part of an integer set definition. + /// integer-set-inline + /// ::= dim-and-symbol-id-lists `:` + /// '(' affine-constraint-conjunction? ')' + /// affine-constraint-conjunction ::= affine-constraint (`,` + /// affine-constraint)* + /// + if (!parseCommaSepeatedList(Delimiter::Paren, parseConstraint, + " in integer set constraint list")) + llvm_unreachable("expected integer set"); + + return {info, std::move(exprs), eqFlags}; +} + +static MultiAffineFunction getMAF(FinalParseResult &&parseResult) { + auto info = parseResult.info; + auto [flatMatrix, cst] = Flattener(std::move(parseResult)).flatten(); + + DivisionRepr divs = cst.getLocalReprs(); + assert(divs.hasAllReprs() && + "AffineMap cannot produce divs without local representation"); + + return MultiAffineFunction( + PresburgerSpace::getRelationSpace(info.numDims, info.numExprs, + info.numSymbols, divs.getNumDivs()), + flatMatrix, divs); +} + +static IntegerPolyhedron getPoly(FinalParseResult &&parseResult) { + auto eqFlags = parseResult.eqFlags; + auto [flatMatrix, cst] = Flattener(std::move(parseResult)).flatten(); + + for (unsigned i = 0; i < flatMatrix.getNumRows(); ++i) { + if (eqFlags[i]) + cst.addEquality(flatMatrix.getRow(i)); + else + cst.addInequality(flatMatrix.getRow(i)); + } + + return cst; +} + +static FinalParseResult parseAffineMapOrIntegerSet(StringRef str) { + SourceMgr sourceMgr; + auto memBuffer = + MemoryBuffer::getMemBuffer(str, "", false); + sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); + ParserImpl parser(sourceMgr); + return parser.parseAffineMapOrIntegerSet(); +} + +IntegerPolyhedron mlir::presburger::parseIntegerPolyhedron(StringRef str) { + return getPoly(parseAffineMapOrIntegerSet(str)); +} + +MultiAffineFunction mlir::presburger::parseMultiAffineFunction(StringRef str) { + return getMAF(parseAffineMapOrIntegerSet(str)); +} diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h new file mode 100644 index 0000000000000..7fef4ded21673 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h @@ -0,0 +1,160 @@ +//===- ParserImpl.h - Presburger Parser Implementation ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H + +#include "Lexer.h" +#include "ParseStructs.h" + +namespace mlir::presburger { +template +using function_ref = llvm::function_ref; +using llvm::SourceMgr; +using llvm::Twine; + +/// This class refers to the lexing-related state for the parser. +struct ParserState { + ParserState(const llvm::SourceMgr &sourceMgr) + : lex(sourceMgr), curToken(lex.lexToken()), lastToken(Token::error, "") {} + ParserState(const ParserState &) = delete; + + Lexer lex; + Token curToken; + Token lastToken; +}; + +/// These are the supported delimiters around operand lists and region +/// argument lists, used by parseOperandList. +enum class Delimiter { + /// Zero or more operands with no delimiters. + None, + /// Parens surrounding zero or more operands. + Paren, + /// Square brackets surrounding zero or more operands. + Square, + /// <> brackets surrounding zero or more operands. + LessGreater, + /// {} brackets surrounding zero or more operands. + Braces, + /// Parens supporting zero or more operands, or nothing. + OptionalParen, + /// Square brackets supporting zero or more ops, or nothing. + OptionalSquare, + /// <> brackets supporting zero or more ops, or nothing. + OptionalLessGreater, + /// {} brackets surrounding zero or more operands, or nothing. + OptionalBraces, +}; + +/// Lower precedence ops (all at the same precedence level). LNoOp is false in +/// the boolean sense. +enum AffineLowPrecOp { + LNoOp, // Null value. + Add, + Sub +}; + +/// Higher precedence ops - all at the same precedence level. HNoOp is false +/// in the boolean sense. +enum AffineHighPrecOp { + HNoOp, // Null value. + Mul, + FloorDiv, + CeilDiv, + Mod +}; + +class ParserImpl { +public: + ParserImpl(const SourceMgr &sourceMgr) + : sourceMgr(sourceMgr), state(sourceMgr) {} + + FinalParseResult parseAffineMapOrIntegerSet(); + +private: + /// Parse a list of comma-separated items with an optional delimiter. If a + /// delimiter is provided, then an empty list is allowed. If not, then at + /// least one element will be parsed. + bool parseCommaSepeatedList(Delimiter delimiter, + function_ref parseElementFn, + StringRef contextMessage); + + //===--------------------------------------------------------------------===// + // Error Handling + //===--------------------------------------------------------------------===// + bool emitError(const Twine &message); + bool emitError(SMLoc loc, const Twine &message); + + //===--------------------------------------------------------------------===// + // Token Parsing + //===--------------------------------------------------------------------===// + const Token &getToken() const { return state.curToken; } + StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } + const Token &getLastToken() const { return state.lastToken; } + bool consumeIf(Token::Kind kind) { + if (state.curToken.isNot(kind)) + return false; + consumeToken(kind); + return true; + } + void consumeToken() { + assert(state.curToken.isNot(Token::eof, Token::error) && + "shouldn't advance past EOF or errors"); + state.lastToken = state.curToken; + state.curToken = state.lex.lexToken(); + } + void consumeToken(Token::Kind kind) { + assert(state.curToken.is(kind) && "consumed an unexpected token"); + consumeToken(); + } + void resetToken(const char *tokPos) { + state.lex.resetPointer(tokPos); + state.lastToken = state.curToken; + state.curToken = state.lex.lexToken(); + } + bool parseToken(Token::Kind expectedToken, const Twine &message); + + //===--------------------------------------------------------------------===// + // Affine Parsing + //===--------------------------------------------------------------------===// + AffineLowPrecOp consumeIfLowPrecOp(); + AffineHighPrecOp consumeIfHighPrecOp(); + + bool parseDimIdList(); + bool parseSymbolIdList(); + bool parseDimAndOptionalSymbolIdList(); + bool parseIdentifierDefinition(DimOrSymbolExpr idExpr); + + PureAffineExpr parseAffineExpr(); + PureAffineExpr parseParentheticalExpr(); + PureAffineExpr parseNegateExpression(const PureAffineExpr &lhs); + PureAffineExpr parseIntegerExpr(); + PureAffineExpr parseBareIdExpr(); + + PureAffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, + PureAffineExpr &&lhs, + PureAffineExpr &&rhs, SMLoc opLoc); + PureAffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, PureAffineExpr &&lhs, + PureAffineExpr &&rhs); + PureAffineExpr parseAffineOperandExpr(const PureAffineExpr &lhs); + PureAffineExpr parseAffineLowPrecOpExpr(PureAffineExpr &&llhs, + AffineLowPrecOp llhsOp); + PureAffineExpr parseAffineHighPrecOpExpr(PureAffineExpr &&llhs, + AffineHighPrecOp llhsOp, + SMLoc llhsOpLoc); + PureAffineExpr parseAffineConstraint(bool *isEq); + + const SourceMgr &sourceMgr; + ParserState state; + ParseInfo info; + SmallVector, 4> dimsAndSymbols; +}; +} // namespace mlir::presburger + +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H diff --git a/mlir/lib/Analysis/Presburger/Parser/Token.cpp b/mlir/lib/Analysis/Presburger/Parser/Token.cpp new file mode 100644 index 0000000000000..d675e0662c88a --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Token.cpp @@ -0,0 +1,64 @@ +//===- Token.cpp - Presburger Token Implementation --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Token class for the Presburger textual form. +// +//===----------------------------------------------------------------------===// + +#include "Token.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/SMLoc.h" +#include + +using namespace mlir::presburger; + +SMLoc Token::getLoc() const { return SMLoc::getFromPointer(spelling.data()); } + +SMLoc Token::getEndLoc() const { + return SMLoc::getFromPointer(spelling.data() + spelling.size()); +} + +SMRange Token::getLocRange() const { return SMRange(getLoc(), getEndLoc()); } + +/// For an integer token, return its value as a uint64_t. If it doesn't fit, +/// return std::nullopt. +std::optional Token::getUInt64IntegerValue(StringRef spelling) { + uint64_t result = 0; + if (spelling.getAsInteger(10, result)) + return std::nullopt; + return result; +} + +/// Given a punctuation or keyword token kind, return the spelling of the +/// token as a string. Warning: This will abort on markers, identifiers and +/// literal tokens since they have no fixed spelling. +StringRef Token::getTokenSpelling(Kind kind) { + switch (kind) { + default: + llvm_unreachable("This token kind has no fixed spelling"); +#define TOK_PUNCTUATION(NAME, SPELLING) \ + case NAME: \ + return SPELLING; +#define TOK_KEYWORD(SPELLING) \ + case kw_##SPELLING: \ + return #SPELLING; +#include "TokenKinds.def" + } +} + +/// Return true if this is one of the keyword token kinds (e.g. kw_if). +bool Token::isKeyword() const { + switch (kind) { + default: + return false; +#define TOK_KEYWORD(SPELLING) \ + case kw_##SPELLING: \ + return true; +#include "TokenKinds.def" + } +} diff --git a/mlir/lib/Analysis/Presburger/Parser/Token.h b/mlir/lib/Analysis/Presburger/Parser/Token.h new file mode 100644 index 0000000000000..5c84c1ba820ee --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Token.h @@ -0,0 +1,91 @@ +//===- Token.h - Presburger Token Interface ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_TOKEN_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_TOKEN_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/SMLoc.h" +#include + +namespace mlir::presburger { +using llvm::SMLoc; +using llvm::SMRange; +using llvm::StringRef; + +class Token { +public: + enum Kind { +#define TOK_MARKER(NAME) NAME, +#define TOK_IDENTIFIER(NAME) NAME, +#define TOK_LITERAL(NAME) NAME, +#define TOK_PUNCTUATION(NAME, SPELLING) NAME, +#define TOK_KEYWORD(SPELLING) kw_##SPELLING, +#include "TokenKinds.def" + }; + + Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} + + // Return the bytes that make up this token. + StringRef getSpelling() const { return spelling; } + + // Token classification. + Kind getKind() const { return kind; } + bool is(Kind k) const { return kind == k; } + + bool isAny(Kind k1, Kind k2) const { return is(k1) || is(k2); } + + /// Return true if this token is one of the specified kinds. + template + bool isAny(Kind k1, Kind k2, Kind k3, T... others) const { + if (is(k1)) + return true; + return isAny(k2, k3, others...); + } + + bool isNot(Kind k) const { return kind != k; } + + /// Return true if this token isn't one of the specified kinds. + template + bool isNot(Kind k1, Kind k2, T... others) const { + return !isAny(k1, k2, others...); + } + + /// Return true if this is one of the keyword token kinds (e.g. kw_if). + bool isKeyword() const; + + // Helpers to decode specific sorts of tokens. + + /// For an integer token, return its value as an uint64_t. If it doesn't fit, + /// return std::nullopt. + static std::optional getUInt64IntegerValue(StringRef spelling); + std::optional getUInt64IntegerValue() const { + return getUInt64IntegerValue(getSpelling()); + } + + // Location processing. + SMLoc getLoc() const; + SMLoc getEndLoc() const; + SMRange getLocRange() const; + + /// Given a punctuation or keyword token kind, return the spelling of the + /// token as a string. Warning: This will abort on markers, identifiers and + /// literal tokens since they have no fixed spelling. + static StringRef getTokenSpelling(Kind kind); + +private: + /// Discriminator that indicates the sort of token this is. + Kind kind; + + /// A reference to the entire token contents; this is always a pointer into + /// a memory buffer owned by the source manager. + StringRef spelling; +}; +} // namespace mlir::presburger + +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_TOKEN_H diff --git a/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def b/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def new file mode 100644 index 0000000000000..3285e77a40975 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def @@ -0,0 +1,72 @@ +//===- TokenKinds.def - Presburger Parser Token Description -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is intended to be #include'd multiple times to extract information +// about tokens for various clients in the lexer. +// +//===----------------------------------------------------------------------===// + +#if !defined(TOK_MARKER) && !defined(TOK_IDENTIFIER) && \ + !defined(TOK_LITERAL) && !defined(TOK_PUNCTUATION) && \ + !defined(TOK_KEYWORD) +#error Must define one of the TOK_ macros. +#endif + +#ifndef TOK_MARKER +#define TOK_MARKER(X) +#endif +#ifndef TOK_IDENTIFIER +#define TOK_IDENTIFIER(NAME) +#endif +#ifndef TOK_LITERAL +#define TOK_LITERAL(NAME) +#endif +#ifndef TOK_PUNCTUATION +#define TOK_PUNCTUATION(NAME, SPELLING) +#endif +#ifndef TOK_KEYWORD +#define TOK_KEYWORD(SPELLING) +#endif + +// Markers +TOK_MARKER(eof) +TOK_MARKER(error) + +// Identifiers. +TOK_IDENTIFIER(bare_identifier) // foo + +// Literals +TOK_LITERAL(integer) // 42 + +// Punctuation. +TOK_PUNCTUATION(arrow, "->") +TOK_PUNCTUATION(colon, ":") +TOK_PUNCTUATION(comma, ",") +TOK_PUNCTUATION(equal, "=") +TOK_PUNCTUATION(greater, ">") +TOK_PUNCTUATION(l_brace, "{") +TOK_PUNCTUATION(l_paren, "(") +TOK_PUNCTUATION(l_square, "[") +TOK_PUNCTUATION(less, "<") +TOK_PUNCTUATION(minus, "-") +TOK_PUNCTUATION(plus, "+") +TOK_PUNCTUATION(r_brace, "}") +TOK_PUNCTUATION(r_paren, ")") +TOK_PUNCTUATION(r_square, "]") +TOK_PUNCTUATION(star, "*") + +// Keywords. These turn "foo" into Token::kw_foo enums. +TOK_KEYWORD(ceildiv) +TOK_KEYWORD(floordiv) +TOK_KEYWORD(mod) + +#undef TOK_MARKER +#undef TOK_IDENTIFIER +#undef TOK_LITERAL +#undef TOK_PUNCTUATION +#undef TOK_KEYWORD diff --git a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp index 5e279b542fdf9..135697ec2d6a5 100644 --- a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp +++ b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp @@ -1,6 +1,6 @@ #include "mlir/Analysis/Presburger/Barvinok.h" -#include "./Utils.h" -#include "Parser.h" +#include "Utils.h" +#include "mlir/Analysis/Presburger/Parser.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt index b69f514711337..d65e32917a35d 100644 --- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt +++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt @@ -6,7 +6,6 @@ add_mlir_unittest(MLIRPresburgerTests IntegerRelationTest.cpp LinearTransformTest.cpp MatrixTest.cpp - Parser.h ParserTest.cpp PresburgerSetTest.cpp PresburgerRelationTest.cpp @@ -19,6 +18,5 @@ add_mlir_unittest(MLIRPresburgerTests target_link_libraries(MLIRPresburgerTests PRIVATE MLIRPresburger - MLIRAffineAnalysis - MLIRParser + MLIRPresburgerParser ) diff --git a/mlir/unittests/Analysis/Presburger/FractionTest.cpp b/mlir/unittests/Analysis/Presburger/FractionTest.cpp index c9fad953dacd5..92bdd57d6cf82 100644 --- a/mlir/unittests/Analysis/Presburger/FractionTest.cpp +++ b/mlir/unittests/Analysis/Presburger/FractionTest.cpp @@ -1,5 +1,4 @@ #include "mlir/Analysis/Presburger/Fraction.h" -#include "./Utils.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp index 3fc68cddaad00..759f4d6cc5194 100644 --- a/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp +++ b/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/GeneratingFunction.h" -#include "./Utils.h" +#include "Utils.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp index f64bb240b4ee4..bc0f7a0426d08 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" #include "Utils.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/PWMAFunction.h" +#include "mlir/Analysis/Presburger/Parser.h" #include "mlir/Analysis/Presburger/Simplex.h" #include diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp index 7df500bc9568a..a6e02f3982936 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/IntegerRelation.h" -#include "Parser.h" +#include "mlir/Analysis/Presburger/Parser.h" #include "mlir/Analysis/Presburger/PresburgerSpace.h" #include "mlir/Analysis/Presburger/Simplex.h" diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp index cb8df8b346011..2da49d7793b72 100644 --- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp +++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp @@ -7,8 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/Matrix.h" -#include "./Utils.h" +#include "Utils.h" #include "mlir/Analysis/Presburger/Fraction.h" +#include "mlir/Analysis/Presburger/PresburgerRelation.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp index ee2931e78185c..22557414b8cb0 100644 --- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp @@ -10,11 +10,8 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" - #include "mlir/Analysis/Presburger/PWMAFunction.h" -#include "mlir/Analysis/Presburger/PresburgerRelation.h" -#include "mlir/IR/MLIRContext.h" +#include "mlir/Analysis/Presburger/Parser.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/ParserTest.cpp b/mlir/unittests/Analysis/Presburger/ParserTest.cpp index 06b728cd1a8fa..81026cf4ee0a5 100644 --- a/mlir/unittests/Analysis/Presburger/ParserTest.cpp +++ b/mlir/unittests/Analysis/Presburger/ParserTest.cpp @@ -13,7 +13,7 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" +#include "mlir/Analysis/Presburger/Parser.h" #include diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp index ad71bb32a0688..b8e21dcf76200 100644 --- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/PresburgerRelation.h" -#include "Parser.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/Parser.h" #include "mlir/Analysis/Presburger/Simplex.h" #include diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp index 8e31a8bb2030b..ec265a71b38cd 100644 --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -14,10 +14,9 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" #include "Utils.h" +#include "mlir/Analysis/Presburger/Parser.h" #include "mlir/Analysis/Presburger/PresburgerRelation.h" -#include "mlir/IR/MLIRContext.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp b/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp index a84f0234067ab..bc3d788aeb4c3 100644 --- a/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp +++ b/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/QuasiPolynomial.h" -#include "./Utils.h" +#include "Utils.h" #include "mlir/Analysis/Presburger/Fraction.h" #include #include @@ -137,4 +137,4 @@ TEST(QuasiPolynomialTest, simplify) { {{{Fraction(1, 1), Fraction(3, 4), Fraction(5, 3)}, {Fraction(2, 1), Fraction(0, 1), Fraction(0, 1)}}, {{Fraction(1, 1), Fraction(4, 5), Fraction(6, 5)}}})); -} \ No newline at end of file +} diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp index 63d0243808555..59fe453eb6a05 100644 --- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp +++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp @@ -6,11 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" -#include "Utils.h" - #include "mlir/Analysis/Presburger/Simplex.h" -#include "mlir/IR/MLIRContext.h" +#include "mlir/Analysis/Presburger/Parser.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h index 58b9267168540..89d1b0e9127f6 100644 --- a/mlir/unittests/Analysis/Presburger/Utils.h +++ b/mlir/unittests/Analysis/Presburger/Utils.h @@ -14,7 +14,6 @@ #define MLIR_UNITTESTS_ANALYSIS_PRESBURGER_UTILS_H #include "mlir/Analysis/Presburger/GeneratingFunction.h" -#include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/Matrix.h" #include "mlir/Analysis/Presburger/QuasiPolynomial.h" From c0032043d6493c3147f97f20344265384065cb87 Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Tue, 2 Jul 2024 12:11:38 +0100 Subject: [PATCH 2/3] Presburger/Parser: git-clang-format --- .../Analysis/Presburger/Parser/ParseStructs.h | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h index ecf8428acd656..fcafcf940e36f 100644 --- a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h +++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h @@ -74,9 +74,7 @@ struct CoefficientVector { } ArrayRef getCoefficients() const { return coefficients; } - int64_t getConstant() const { - return coefficients[info.getConstantIdx()]; - } + int64_t getConstant() const { return coefficients[info.getConstantIdx()]; } size_t size() const { return coefficients.size(); } operator ArrayRef() const { return coefficients; } void resize(size_t size) { coefficients.resize(size); } @@ -223,20 +221,14 @@ struct PureAffineExprImpl { constexpr bool isMod() const { return kind == DivKind::Mod; } constexpr bool hasDivisor() const { return divisor != 1; } - bool isLinear() const { - return nestedDivTerms.empty() && !hasDivisor(); - } + bool isLinear() const { return nestedDivTerms.empty() && !hasDivisor(); } constexpr int64_t getDivisor() const { return divisor; } constexpr int64_t getMulFactor() const { return mulFactor; } bool isConstant() const { return nestedDivTerms.empty() && getLinearDividend().isConstant(); } - int64_t getConstant() const { - return getLinearDividend().getConstant(); - } - size_t hash() const { - return std::hash{}(this); - } + int64_t getConstant() const { return getLinearDividend().getConstant(); } + size_t hash() const { return std::hash{}(this); } ArrayRef getNestedDivTerms() const { return nestedDivTerms; } PureAffineExprImpl &mulConstant(int64_t c) { From f22875c140328bac9c78e724bc70a3bb4ce90f2b Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Sun, 7 Jul 2024 15:16:53 +0100 Subject: [PATCH 3/3] Presburger/Parser: address review --- .../include/mlir/Analysis/Presburger/Matrix.h | 2 +- .../include/mlir/Analysis/Presburger/Parser.h | 15 +++- .../Analysis/Presburger/PresburgerSpace.h | 54 ++++++++--- .../Analysis/Presburger/Parser/Flattener.cpp | 4 +- .../Analysis/Presburger/Parser/Flattener.h | 8 +- .../Presburger/Parser/ParseStructs.cpp | 25 +++--- .../Analysis/Presburger/Parser/ParseStructs.h | 90 +++++++------------ .../Analysis/Presburger/Parser/ParserImpl.cpp | 31 +++---- .../Analysis/Presburger/Parser/ParserImpl.h | 2 +- 9 files changed, 121 insertions(+), 110 deletions(-) diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h index 9b9cfd178ce2d..c68c7237fe7df 100644 --- a/mlir/include/mlir/Analysis/Presburger/Matrix.h +++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h @@ -103,7 +103,7 @@ class Matrix { /// Set the specified row to `elems`. void setRow(unsigned row, ArrayRef elems); - /// Add the specified row to `elems`. + /// Add `elems` to the specified row. void addToRow(unsigned row, ArrayRef elems); /// Insert columns having positions pos, pos + 1, ... pos + count - 1. diff --git a/mlir/include/mlir/Analysis/Presburger/Parser.h b/mlir/include/mlir/Analysis/Presburger/Parser.h index 0e117e51e203e..49c8e32b4a0f8 100644 --- a/mlir/include/mlir/Analysis/Presburger/Parser.h +++ b/mlir/include/mlir/Analysis/Presburger/Parser.h @@ -21,10 +21,21 @@ namespace mlir::presburger { using llvm::StringRef; -/// Parses an IntegerPolyhedron from a StringRef. +/// Parse an IntegerPolyhedron from a StringRef. +/// +/// integer-set ::= dim-and-symbol-id-lists `:` +/// '(' affine-constraint-conjunction? ')' +/// affine-constraint-conjunction ::= affine-constraint (`,` +/// affine-constraint)* +/// IntegerPolyhedron parseIntegerPolyhedron(StringRef str); -/// Parses a MultiAffineFunction from a StringRef. +/// Parse a MultiAffineFunction from a StringRef. +/// +/// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr +/// +/// multi-dim-affine-expr ::= `(` `)` +/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` MultiAffineFunction parseMultiAffineFunction(StringRef str); /// Parse a list of StringRefs to IntegerRelation and combine them into a diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h index cff7957989871..3e24517e17c0d 100644 --- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h @@ -173,6 +173,7 @@ class PresburgerSpace { return PresburgerSpace(/*numDomain=*/0, /*numRange=*/numDims, numSymbols, numLocals); } + PresburgerSpace() = default; /// Get the domain/range space of this space. The returned space is a set /// space. @@ -182,18 +183,43 @@ class PresburgerSpace { /// Get the space without local variables. PresburgerSpace getSpaceWithoutLocals() const; - unsigned getNumDomainVars() const { return numDomain; } - unsigned getNumRangeVars() const { return numRange; } - unsigned getNumSetDimVars() const { return numRange; } - unsigned getNumSymbolVars() const { return numSymbols; } - unsigned getNumLocalVars() const { return numLocals; } + constexpr unsigned getNumDomainVars() const { return numDomain; } + constexpr unsigned &numDomainVars() { return numDomain; } + constexpr unsigned getNumRangeVars() const { return numRange; } + constexpr unsigned &numRangeVars() { return numRange; } + constexpr unsigned getNumSetDimVars() const { return numRange; } + constexpr unsigned &numSetDimVars() { return numRange; } + constexpr unsigned getNumSymbolVars() const { return numSymbols; } + constexpr unsigned &numSymbolVars() { return numSymbols; } + constexpr unsigned getNumLocalVars() const { return numLocals; } + constexpr unsigned &numLocalVars() { return numLocals; } + constexpr unsigned getNumDimVars() const { return numDomain + numRange; } + constexpr unsigned getNumDimAndSymbolVars() const { + return getNumDimVars() + getNumSymbolVars(); + } + constexpr unsigned getNumVars() const { + return getNumDimAndSymbolVars() + getNumLocalVars(); + } + constexpr unsigned getNumCols() const { return getNumVars() + 1; } + + constexpr unsigned getSetDimStartIdx() const { return 0; } + constexpr unsigned getSymbolStartIdx() const { return getNumDimVars(); } + constexpr unsigned getLocalVarStartIdx() const { + return getNumDimAndSymbolVars(); + } + constexpr unsigned getConstantIdx() const { return getNumCols() - 1; } - unsigned getNumDimVars() const { return numDomain + numRange; } - unsigned getNumDimAndSymbolVars() const { - return numDomain + numRange + numSymbols; + constexpr bool isSetDimIdx(unsigned i) const { + return i < getSymbolStartIdx(); + } + constexpr bool isSymbolIdx(unsigned i) const { + return i >= getSymbolStartIdx() && i < getLocalVarStartIdx(); + } + constexpr bool isLocalVarIdx(unsigned i) const { + return i >= getLocalVarStartIdx() && i < getConstantIdx(); } - unsigned getNumVars() const { - return numDomain + numRange + numSymbols + numLocals; + constexpr bool isConstantIdx(unsigned i) const { + return i == getConstantIdx(); } /// Get the number of vars of the specified kind. @@ -321,18 +347,18 @@ class PresburgerSpace { private: // Number of variables corresponding to domain variables. - unsigned numDomain; + unsigned numDomain = 0; // Number of variables corresponding to range variables. - unsigned numRange; + unsigned numRange = 0; /// Number of variables corresponding to symbols (unknown but constant for /// analysis). - unsigned numSymbols; + unsigned numSymbols = 0; /// Number of variables corresponding to locals (variables corresponding /// to existentially quantified variables). - unsigned numLocals; + unsigned numLocals = 0; /// Stores whether or not identifiers are being used in this space. bool usingIds = false; diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp b/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp index 90f645a9f20bd..3413c1b2bbe31 100644 --- a/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp +++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp @@ -22,7 +22,7 @@ void Flattener::visitDiv(const PureAffineExprImpl &div) { // First construct the linear part of the divisor. auto dividend = div.collectLinearTerms().getPadded( - info.getLocalVarStartIdx() + localExprs.size() + 1); + space.getLocalVarStartIdx() + localExprs.size() + 1); // Next, insert the non-linear coefficients. for (const auto &[hash, adjustedMulFactor, adjustedLinearTerm] : @@ -38,7 +38,7 @@ void Flattener::visitDiv(const PureAffineExprImpl &div) { CoefficientVector expansion = lookupModExpansion(hash) .value_or(adjustedLinearTerm) - .getPadded(info.getLocalVarStartIdx() + localExprs.size() + 1); + .getPadded(space.getLocalVarStartIdx() + localExprs.size() + 1); dividend += expansion; // If this is a mod, insert the new computed expansion, which is the diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.h b/mlir/lib/Analysis/Presburger/Parser/Flattener.h index d3043ed67f56d..6bb5378cf8156 100644 --- a/mlir/lib/Analysis/Presburger/Parser/Flattener.h +++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.h @@ -34,7 +34,7 @@ class Flattener : public FinalParseResult { Flattener(FinalParseResult &&parseResult) : FinalParseResult(std::move(parseResult)), - flatMatrix(info.numExprs, info.getNumCols()) {} + flatMatrix(exprs.size(), space.getNumCols()) {} std::pair flatten(); @@ -44,17 +44,17 @@ class Flattener : public FinalParseResult { void addToRow(unsigned row, const CoefficientVector &l) { flatMatrix.addToRow(row, - getDynamicAPIntVec(l.getPadded(info.getNumCols()))); + getDynamicAPIntVec(l.getPadded(space.getNumCols()))); } void setRow(unsigned row, const CoefficientVector &l) { - flatMatrix.setRow(row, getDynamicAPIntVec(l.getPadded(info.getNumCols()))); + flatMatrix.setRow(row, getDynamicAPIntVec(l.getPadded(space.getNumCols()))); } unsigned lookupLocal(size_t hash) { const auto *it = find(localExprs, hash); assert(it != localExprs.end() && "Local expression not found; walking from inner to outer?"); - return info.getLocalVarStartIdx() + it - localExprs.begin(); + return space.getLocalVarStartIdx() + it - localExprs.begin(); } std::optional lookupModExpansion(size_t hash) { return localModExpansion.contains(hash) diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp index b937ae27a8fea..e1f55c3418fed 100644 --- a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp +++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "ParseStructs.h" +#include "mlir/Analysis/Presburger/PresburgerSpace.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" @@ -24,7 +25,7 @@ using llvm::mod; CoefficientVector PureAffineExprImpl::collectLinearTerms() const { CoefficientVector nestedLinear = std::accumulate( - nestedDivTerms.begin(), nestedDivTerms.end(), CoefficientVector(info), + nestedDivTerms.begin(), nestedDivTerms.end(), CoefficientVector(space), [](const CoefficientVector &acc, const PureAffineExpr &div) { return acc + div->getLinearDividend(); }); @@ -46,7 +47,7 @@ PureAffineExprImpl::getNonLinearCoeffs() const { }; auto adjustedLinearTerm = [](PureAffineExprImpl &div) { return div.kind == DivKind::Mod ? div.linearDividend *= div.mulFactor - : CoefficientVector(div.info); + : CoefficientVector(div.space); }; if (hasDivisor()) for (const auto &toplevel : nestedDivTerms) @@ -81,7 +82,7 @@ PureAffineExpr mlir::presburger::operator+(PureAffineExpr &&lhs, std::move(lhs->addLinearTerm(rhs->getLinearDividend()))); if (!(lhs->hasDivisor() ^ rhs->hasDivisor())) { - auto ret = PureAffineExprImpl(lhs->info); + auto ret = PureAffineExprImpl(lhs->space); ret.addDivTerm(std::move(*lhs)); ret.addDivTerm(std::move(*rhs)); return std::make_unique(std::move(ret)); @@ -104,7 +105,7 @@ PureAffineExpr mlir::presburger::operator*(PureAffineExpr &&expr, int64_t c) { } PureAffineExpr mlir::presburger::operator+(PureAffineExpr &&expr, int64_t c) { - return std::move(expr) + std::make_unique(expr->info, c); + return std::move(expr) + std::make_unique(expr->space, c); } PureAffineExpr mlir::presburger::div(PureAffineExpr &÷nd, int64_t divisor, @@ -116,7 +117,7 @@ PureAffineExpr mlir::presburger::div(PureAffineExpr &÷nd, int64_t divisor, int64_t c = kind == DivKind::FloorDiv ? divideFloorSigned(dividend->getConstant(), divisor) : mod(dividend->getConstant(), divisor); - return std::make_unique(dividend->info, c); + return std::make_unique(dividend->space, c); } // Factor out mul, using gcd internally. @@ -152,7 +153,7 @@ PureAffineExpr mlir::presburger::div(PureAffineExpr &÷nd, int64_t divisor, // x floordiv 1 <=> x, x % 1 <=> 0 return divisor == 1 ? (dividend->isMod() - ? std::make_unique(dividend->info) + ? std::make_unique(dividend->space) : std::move(dividend)) : std::make_unique(std::move(*dividend), divisor, kind); @@ -166,9 +167,9 @@ enum class BindingStrength { }; static void printCoefficient(int64_t c, bool &isExprBegin, - const ParseInfo &info, int idx = -1) { - bool isConstant = idx != -1 && info.isConstantIdx(idx); - bool isDimOrSymbol = idx != -1 && !info.isConstantIdx(idx); + const PresburgerSpace &space, int idx = -1) { + bool isConstant = idx != -1 && space.isConstantIdx(idx); + bool isDimOrSymbol = idx != -1 && !space.isConstantIdx(idx); if (!c) return; if (!isExprBegin) @@ -188,7 +189,7 @@ static void printCoefficient(int64_t c, bool &isExprBegin, if (isDimOrSymbol) { if (std::abs(c) != 1) dbgs() << " * "; - dbgs() << (info.isDimIdx(idx) ? 'd' : 's') << idx; + dbgs() << (space.isSetDimIdx(idx) ? 'd' : 's') << idx; isExprBegin = false; } } @@ -202,7 +203,7 @@ static bool printCoefficientVec( isExprBegin = true; } for (auto [idx, c] : enumerate(linear.getCoefficients())) - printCoefficient(c, isExprBegin, linear.info, idx); + printCoefficient(c, isExprBegin, linear.space, idx); if (enclosingTightness == BindingStrength::Strong && linear.hasMultipleCoefficients()) @@ -223,7 +224,7 @@ printAffineExpr(const PureAffineExprImpl &expr, bool isExprBegin = true, const auto &mulFactor = div.getMulFactor(); const auto &nestedDivs = div.getNestedDivTerms(); - printCoefficient(mulFactor, isExprBegin, expr.info); + printCoefficient(mulFactor, isExprBegin, expr.space); if (std::abs(mulFactor) != 1) dbgs() << " * "; diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h index fcafcf940e36f..92949c864743d 100644 --- a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h +++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h @@ -21,60 +21,31 @@ using llvm::ArrayRef; using llvm::SmallVector; using llvm::SmallVectorImpl; -/// This structure is central to the parser and flattener, and holds the number -/// of dimensions, symbols, locals, and the constant term. -struct ParseInfo { - unsigned numDims = 0; - unsigned numSymbols = 0; - unsigned numExprs = 0; - unsigned numDivs = 0; - - constexpr unsigned getDimStartIdx() const { return 0; } - constexpr unsigned getSymbolStartIdx() const { return numDims; } - constexpr unsigned getLocalVarStartIdx() const { - return numDims + numSymbols; - } - constexpr unsigned getNumCols() const { - return numDims + numSymbols + numDivs + 1; - } - constexpr unsigned getConstantIdx() const { return getNumCols() - 1; } - - constexpr bool isDimIdx(unsigned i) const { return i < getSymbolStartIdx(); } - constexpr bool isSymbolIdx(unsigned i) const { - return i >= getSymbolStartIdx() && i < getLocalVarStartIdx(); - } - constexpr bool isLocalVarIdx(unsigned i) const { - return i >= getLocalVarStartIdx() && i < getConstantIdx(); - } - constexpr bool isConstantIdx(unsigned i) const { - return i == getConstantIdx(); - } -}; - /// Helper for storing coefficients in canonical form: dims followed by symbols, /// followed by locals, and finally the constant term. /// /// (x, y)[a, b]: y * 91 + x + 3 * a + 7 /// coefficients: [1, 91, 3, 0, 7] struct CoefficientVector { - ParseInfo info; + PresburgerSpace space; SmallVector coefficients; - CoefficientVector(const ParseInfo &info, int64_t c = 0) : info(info) { - coefficients.resize(info.getNumCols()); - coefficients[info.getConstantIdx()] = c; + CoefficientVector(const PresburgerSpace &space, int64_t c = 0) + : space(space) { + coefficients.resize(space.getNumCols()); + coefficients[space.getConstantIdx()] = c; } // Copyable and movable CoefficientVector(const CoefficientVector &o) = default; CoefficientVector &operator=(const CoefficientVector &o) = default; CoefficientVector(CoefficientVector &&o) - : info(o.info), coefficients(std::move(o.coefficients)) { + : space(o.space), coefficients(std::move(o.coefficients)) { o.coefficients.clear(); } ArrayRef getCoefficients() const { return coefficients; } - int64_t getConstant() const { return coefficients[info.getConstantIdx()]; } + int64_t getConstant() const { return coefficients[space.getConstantIdx()]; } size_t size() const { return coefficients.size(); } operator ArrayRef() const { return coefficients; } void resize(size_t size) { coefficients.resize(size); } @@ -126,7 +97,7 @@ struct CoefficientVector { CoefficientVector getPadded(size_t newSize) const { assert(newSize >= size() && "Padding size should be greater than expr size"); - CoefficientVector ret(info); + CoefficientVector ret(space); ret.resize(newSize); // Start constructing the result by taking the dims and symbols of the @@ -176,7 +147,7 @@ enum class DivKind { FloorDiv, Mod }; /// /// Where div = floordiv|mod; ceildiv is pre-reduced struct PureAffineExprImpl { - ParseInfo info; + PresburgerSpace space; DivKind kind = DivKind::FloorDiv; using PureAffineExpr = std::unique_ptr; @@ -185,29 +156,30 @@ struct PureAffineExprImpl { int64_t divisor = 1; SmallVector nestedDivTerms; - PureAffineExprImpl(const ParseInfo &info, int64_t c = 0) - : info(info), linearDividend(info, c) {} - PureAffineExprImpl(const ParseInfo &info, DimOrSymbolExpr idExpr) - : PureAffineExprImpl(info) { + PureAffineExprImpl(const PresburgerSpace &space, int64_t c = 0) + : space(space), linearDividend(space, c) {} + PureAffineExprImpl(const PresburgerSpace &space, DimOrSymbolExpr idExpr) + : PureAffineExprImpl(space) { auto [kind, pos] = idExpr; unsigned startIdx = kind == DimOrSymbolKind::Symbol - ? info.getSymbolStartIdx() - : info.getDimStartIdx(); + ? space.getSymbolStartIdx() + : space.getSetDimStartIdx(); linearDividend.coefficients[startIdx + pos] = 1; } PureAffineExprImpl(const CoefficientVector &linearDividend, int64_t divisor = 1, DivKind kind = DivKind::FloorDiv) - : info(linearDividend.info), kind(kind), linearDividend(linearDividend), + : space(linearDividend.space), kind(kind), linearDividend(linearDividend), divisor(divisor) {} PureAffineExprImpl(PureAffineExprImpl &&div, int64_t divisor, DivKind kind) - : info(div.info), kind(kind), linearDividend(div.info), divisor(divisor) { + : space(div.space), kind(kind), linearDividend(div.space), + divisor(divisor) { addDivTerm(std::move(div)); } // Non-copyable, only movable PureAffineExprImpl(const PureAffineExprImpl &) = delete; PureAffineExprImpl(PureAffineExprImpl &&o) - : info(o.info), kind(o.kind), mulFactor(o.mulFactor), + : space(o.space), kind(o.kind), mulFactor(o.mulFactor), linearDividend(std::move(o.linearDividend)), divisor(o.divisor), nestedDivTerms(std::move(o.nestedDivTerms)) { o.nestedDivTerms.clear(); @@ -277,23 +249,23 @@ using PureAffineExpr = std::unique_ptr; /// compute the total number of columns, and construct the IntMatrix in the /// Flattener. struct FinalParseResult { - ParseInfo info; + PresburgerSpace space; SmallVector exprs; SmallVector eqFlags; IntegerPolyhedron cst; - FinalParseResult(const ParseInfo &parseInfo, + FinalParseResult(const PresburgerSpace &space, SmallVectorImpl &&exprStack, ArrayRef eqFlagStack) - : info(parseInfo), exprs(std::move(exprStack)), eqFlags(eqFlagStack), - cst(0, 0, info.numDims + info.numSymbols + 1, - PresburgerSpace::getSetSpace(info.numDims, info.numSymbols, 0)) { - auto &i = this->info; - i.numExprs = exprs.size(); - i.numDivs = std::accumulate(exprs.begin(), exprs.end(), 0, - [](unsigned acc, const PureAffineExpr &expr) { - return acc + expr->countNestedDivs(); - }); + : space(space), exprs(std::move(exprStack)), eqFlags(eqFlagStack), + cst(0, 0, space.getNumSetDimVars() + space.getNumSymbolVars() + 1, + PresburgerSpace::getSetSpace(space.getNumSetDimVars(), + space.getNumSymbolVars(), 0)) { + this->space.numLocalVars() = + std::accumulate(exprs.begin(), exprs.end(), 0, + [](unsigned acc, const PureAffineExpr &expr) { + return acc + expr->countNestedDivs(); + }); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -337,7 +309,7 @@ inline PureAffineExpr ceildiv(PureAffineExpr &&expr, int64_t c) { // of 1. inline PureAffineExpr canonicalize(PureAffineExpr &&expr) { if (expr->hasDivisor()) - expr->addLinearTerm(CoefficientVector(expr->info)); + expr->addLinearTerm(CoefficientVector(expr->space)); return expr; } } // namespace mlir::presburger diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp index 20b2b6d08638e..0eae49dc1a007 100644 --- a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp +++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp @@ -132,7 +132,7 @@ PureAffineExpr ParserImpl::parseBareIdExpr() { for (const auto &entry : dimsAndSymbols) { if (entry.first == sRef) { consumeToken(); - return std::make_unique(info, entry.second); + return std::make_unique(space, entry.second); } } @@ -187,7 +187,7 @@ PureAffineExpr ParserImpl::parseIntegerExpr() { return emitError("constant too large"), nullptr; consumeToken(Token::integer); - return std::make_unique(info, ret); + return std::make_unique(space, ret); } /// Parses an expression that can be a valid operand of an affine expression. @@ -466,7 +466,8 @@ bool ParserImpl::parseIdentifierDefinition(DimOrSymbolExpr idExpr) { /// Parse the list of dimensional identifiers to an affine map. bool ParserImpl::parseDimIdList() { auto parseElt = [&]() -> bool { - return parseIdentifierDefinition({DimOrSymbolKind::DimId, info.numDims++}); + return parseIdentifierDefinition( + {DimOrSymbolKind::DimId, space.numSetDimVars()++}); }; return parseCommaSepeatedList(Delimiter::Paren, parseElt, " in dimensional identifier list"); @@ -476,7 +477,7 @@ bool ParserImpl::parseDimIdList() { bool ParserImpl::parseSymbolIdList() { auto parseElt = [&]() -> bool { return parseIdentifierDefinition( - {DimOrSymbolKind::Symbol, info.numSymbols++}); + {DimOrSymbolKind::Symbol, space.numSymbolVars()++}); }; return parseCommaSepeatedList(Delimiter::Square, parseElt, " in symbol list"); } @@ -486,7 +487,7 @@ bool ParserImpl::parseDimAndOptionalSymbolIdList() { if (!parseDimIdList()) return false; if (getToken().isNot(Token::l_square)) { - info.numSymbols = 0; + space.numSymbolVars() = 0; return true; } return parseSymbolIdList(); @@ -574,16 +575,16 @@ FinalParseResult ParserImpl::parseAffineMapOrIntegerSet() { " in affine map range")) llvm_unreachable("expected affine map range"); - return {info, std::move(exprs), eqFlags}; + return {space, std::move(exprs), eqFlags}; } if (!parseToken(Token::colon, "expected '->' or ':'")) llvm_unreachable("Unexpected token"); /// Parse the constraints that are part of an integer set definition. - /// integer-set-inline - /// ::= dim-and-symbol-id-lists `:` - /// '(' affine-constraint-conjunction? ')' + /// + /// integer-set ::= dim-and-symbol-id-lists `:` + /// '(' affine-constraint-conjunction? ')' /// affine-constraint-conjunction ::= affine-constraint (`,` /// affine-constraint)* /// @@ -591,21 +592,21 @@ FinalParseResult ParserImpl::parseAffineMapOrIntegerSet() { " in integer set constraint list")) llvm_unreachable("expected integer set"); - return {info, std::move(exprs), eqFlags}; + return {space, std::move(exprs), eqFlags}; } static MultiAffineFunction getMAF(FinalParseResult &&parseResult) { - auto info = parseResult.info; + auto space = parseResult.space; + space.numDomainVars() = space.getNumSetDimVars(); + space.numRangeVars() = parseResult.exprs.size(); + auto [flatMatrix, cst] = Flattener(std::move(parseResult)).flatten(); DivisionRepr divs = cst.getLocalReprs(); assert(divs.hasAllReprs() && "AffineMap cannot produce divs without local representation"); - return MultiAffineFunction( - PresburgerSpace::getRelationSpace(info.numDims, info.numExprs, - info.numSymbols, divs.getNumDivs()), - flatMatrix, divs); + return MultiAffineFunction(space, flatMatrix, divs); } static IntegerPolyhedron getPoly(FinalParseResult &&parseResult) { diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h index 7fef4ded21673..5c3e061d0f896 100644 --- a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h +++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h @@ -152,7 +152,7 @@ class ParserImpl { const SourceMgr &sourceMgr; ParserState state; - ParseInfo info; + PresburgerSpace space; SmallVector, 4> dimsAndSymbols; }; } // namespace mlir::presburger