diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h index 054eb7b26d06ee..9b9cfd178ce2dd 100644 --- a/mlir/include/mlir/Analysis/Presburger/Matrix.h +++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h @@ -103,6 +103,9 @@ class Matrix { /// Set the specified row to `elems`. void setRow(unsigned row, ArrayRef elems); + /// Add the specified row to `elems`. + void addToRow(unsigned row, ArrayRef elems); + /// Insert columns having positions pos, pos + 1, ... pos + count - 1. /// Columns that were at positions 0 to pos - 1 will stay where they are; /// columns that were at positions pos to nColumns - 1 will be pushed to the diff --git a/mlir/unittests/Analysis/Presburger/Parser.h b/mlir/include/mlir/Analysis/Presburger/Parser.h similarity index 65% rename from mlir/unittests/Analysis/Presburger/Parser.h rename to mlir/include/mlir/Analysis/Presburger/Parser.h index 75842fb054e2b9..0e117e51e203eb 100644 --- a/mlir/unittests/Analysis/Presburger/Parser.h +++ b/mlir/include/mlir/Analysis/Presburger/Parser.h @@ -11,32 +11,26 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_UNITTESTS_ANALYSIS_PRESBURGER_PARSER_H -#define MLIR_UNITTESTS_ANALYSIS_PRESBURGER_PARSER_H +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_H #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/PWMAFunction.h" #include "mlir/Analysis/Presburger/PresburgerRelation.h" -#include "mlir/AsmParser/AsmParser.h" -#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/IntegerSet.h" - -namespace mlir { -namespace presburger { - -/// Parses an IntegerPolyhedron from a StringRef. It is expected that the string -/// represents a valid IntegerSet. -inline IntegerPolyhedron parseIntegerPolyhedron(StringRef str) { - MLIRContext context(MLIRContext::Threading::DISABLED); - return affine::FlatAffineValueConstraints(parseIntegerSet(str, &context)); -} + +namespace mlir::presburger { +using llvm::StringRef; + +/// Parses an IntegerPolyhedron from a StringRef. +IntegerPolyhedron parseIntegerPolyhedron(StringRef str); + +/// Parses a MultiAffineFunction from a StringRef. +MultiAffineFunction parseMultiAffineFunction(StringRef str); /// Parse a list of StringRefs to IntegerRelation and combine them into a -/// PresburgerSet by using the union operation. It is expected that the strings -/// are all valid IntegerSet representation and that all of them have compatible -/// spaces. +/// PresburgerSet by using the union operation. It is expected that the +/// strings are all valid IntegerSet representation and that all of them have +/// compatible spaces. inline PresburgerSet parsePresburgerSet(ArrayRef strs) { assert(!strs.empty() && "strs should not be empty"); @@ -47,25 +41,10 @@ inline PresburgerSet parsePresburgerSet(ArrayRef strs) { return result; } -inline MultiAffineFunction parseMultiAffineFunction(StringRef str) { - MLIRContext context(MLIRContext::Threading::DISABLED); - - // TODO: Add default constructor for MultiAffineFunction. - MultiAffineFunction multiAff(PresburgerSpace::getRelationSpace(), - IntMatrix(0, 1)); - if (getMultiAffineFunctionFromMap(parseAffineMap(str, &context), multiAff) - .failed()) - llvm_unreachable( - "Failed to parse MultiAffineFunction because of semi-affinity"); - return multiAff; -} - inline PWMAFunction parsePWMAF(ArrayRef> pieces) { assert(!pieces.empty() && "At least one piece should be present."); - MLIRContext context(MLIRContext::Threading::DISABLED); - IntegerPolyhedron initDomain = parseIntegerPolyhedron(pieces[0].first); MultiAffineFunction initMultiAff = parseMultiAffineFunction(pieces[0].second); @@ -100,8 +79,6 @@ parsePresburgerRelationFromPresburgerSet(ArrayRef strs, result.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain, 0); return result; } +} // namespace mlir::presburger -} // namespace presburger -} // namespace mlir - -#endif // MLIR_UNITTESTS_ANALYSIS_PRESBURGER_PARSER_H +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_H diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt index 6b9842a6160ab7..03fe3eeb1658ab 100644 --- a/mlir/lib/Analysis/Presburger/CMakeLists.txt +++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(Parser) + add_mlir_library(MLIRPresburger Barvinok.cpp IntegerRelation.cpp diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp index 110c5df1af37c0..d7aee64044e8f4 100644 --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -144,6 +144,14 @@ void Matrix::setRow(unsigned row, ArrayRef elems) { at(row, i) = elems[i]; } +template +void Matrix::addToRow(unsigned row, ArrayRef elems) { + assert(elems.size() == getNumColumns() && + "elems size must match row length!"); + for (unsigned i = 0, e = getNumColumns(); i < e; ++i) + at(row, i) += elems[i]; +} + template void Matrix::insertColumn(unsigned pos) { insertColumns(pos, 1); diff --git a/mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt b/mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt new file mode 100644 index 00000000000000..f708a5c8db949a --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_library(MLIRPresburgerParser + Flattener.cpp + Lexer.cpp + ParserImpl.cpp + ParseStructs.cpp + Token.cpp) diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp b/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp new file mode 100644 index 00000000000000..90f645a9f20bdb --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.cpp @@ -0,0 +1,88 @@ +//===- Flattener.cpp - Presburger ParseStruct Flattener ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Flattener class for flattening the parse tree +// produced by the parser for the Presburger library. +// +//===----------------------------------------------------------------------===// + +#include "Flattener.h" +#include "ParseStructs.h" + +using namespace mlir; +using namespace presburger; + +void Flattener::visitDiv(const PureAffineExprImpl &div) { + int64_t divisor = div.getDivisor(); + + // First construct the linear part of the divisor. + auto dividend = div.collectLinearTerms().getPadded( + info.getLocalVarStartIdx() + localExprs.size() + 1); + + // Next, insert the non-linear coefficients. + for (const auto &[hash, adjustedMulFactor, adjustedLinearTerm] : + div.getNonLinearCoeffs()) { + + // adjustedMulFactor will be mulFactor * -divisor in case of mod, and + // mulFactor in case of floordiv. + dividend[lookupLocal(hash)] = adjustedMulFactor; + + // The expansion is either a modExpansion which we previously stored, or the + // adjustedLinearTerm, which is correct in the case when we're encountering + // the innermost mod for the first time. + CoefficientVector expansion = + lookupModExpansion(hash) + .value_or(adjustedLinearTerm) + .getPadded(info.getLocalVarStartIdx() + localExprs.size() + 1); + dividend += expansion; + + // If this is a mod, insert the new computed expansion, which is the + // dividend * mulFactor. + if (div.isMod()) + localModExpansion.insert({div.hash(), dividend * div.getMulFactor()}); + } + + cst.addLocalFloorDiv(dividend, divisor); + localExprs.insert(div.hash()); +} + +void Flattener::flatten(unsigned row, PureAffineExprImpl &div) { + // Visit divs inner to outer. + for (auto &nestedDiv : div.getNestedDivTerms()) + flatten(row, *nestedDiv); + + if (div.hasDivisor()) { + visitDiv(div); + return; + } + + // Hit multiple times every time we have a linear sub-expression, but the + // row is overwritten to consider only the outermost div, which is hit + // last. + + // Set the linear part of the row. + setRow(row, div.getLinearDividend()); + + // Set the non-linear coefficients. + for (const auto &[hash, adjustedMulFactor, adjustedLinearTerm] : + div.getNonLinearCoeffs()) { + flatMatrix(row, lookupLocal(hash)) = adjustedMulFactor; + CoefficientVector expansion = + lookupModExpansion(hash).value_or(adjustedLinearTerm); + addToRow(row, expansion); + } +} + +std::pair Flattener::flatten() { + // Use the same flattener to simplify each expression successively. This way + // local variables / expressions are shared. + for (const auto &[row, expr] : enumerate(exprs)) + flatten(row, *expr); + + return {flatMatrix, cst}; +} diff --git a/mlir/lib/Analysis/Presburger/Parser/Flattener.h b/mlir/lib/Analysis/Presburger/Parser/Flattener.h new file mode 100644 index 00000000000000..d3043ed67f56d5 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Flattener.h @@ -0,0 +1,67 @@ +//===- Flattener.h - Presburger ParseStruct Flattener -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_FLATTENER_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_FLATTENER_H + +#include "ParseStructs.h" +#include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SetVector.h" + +namespace mlir::presburger { +using llvm::SmallDenseMap; +using llvm::SmallSetVector; + +class Flattener : public FinalParseResult { +public: + // The final flattened result is stored here. + IntMatrix flatMatrix; + + // We maintain a set of divs that we have seen while flattening. The size of + // this set is at most info.numDivs, hitting info.numDivs at the end of the + // flattening, if that expression contains all the possible divs. + SmallSetVector localExprs; + + // We maintain a mapping between local mods and their expansions. The vector + // is the dividend. + SmallDenseMap localModExpansion; + + Flattener(FinalParseResult &&parseResult) + : FinalParseResult(std::move(parseResult)), + flatMatrix(info.numExprs, info.getNumCols()) {} + + std::pair flatten(); + +private: + void flatten(unsigned row, PureAffineExprImpl &div); + void visitDiv(const PureAffineExprImpl &div); + + void addToRow(unsigned row, const CoefficientVector &l) { + flatMatrix.addToRow(row, + getDynamicAPIntVec(l.getPadded(info.getNumCols()))); + } + void setRow(unsigned row, const CoefficientVector &l) { + flatMatrix.setRow(row, getDynamicAPIntVec(l.getPadded(info.getNumCols()))); + } + + unsigned lookupLocal(size_t hash) { + const auto *it = find(localExprs, hash); + assert(it != localExprs.end() && + "Local expression not found; walking from inner to outer?"); + return info.getLocalVarStartIdx() + it - localExprs.begin(); + } + std::optional lookupModExpansion(size_t hash) { + return localModExpansion.contains(hash) + ? std::make_optional(localModExpansion.at(hash)) + : std::nullopt; + } +}; +} // namespace mlir::presburger + +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_FLATTENER_H diff --git a/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp b/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp new file mode 100644 index 00000000000000..8c037539094a61 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Lexer.cpp @@ -0,0 +1,136 @@ +//===- Lexer.cpp - Presburger Lexer Implementation ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the lexer for the Presburger textual form. +// +//===----------------------------------------------------------------------===// + +#include "Lexer.h" +#include "Token.h" +#include "llvm/ADT/StringSwitch.h" + +using namespace mlir::presburger; + +Lexer::Lexer(const llvm::SourceMgr &sourceMgr) : sourceMgr(sourceMgr) { + auto bufferID = sourceMgr.getMainFileID(); + curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer(); + curPtr = curBuffer.begin(); +} + +Token Lexer::emitError(const char *loc, const llvm::Twine &message) { + sourceMgr.PrintMessage(SMLoc::getFromPointer(loc), llvm::SourceMgr::DK_Error, + message); + return formToken(Token::error, loc); +} + +Token Lexer::lexToken() { + while (true) { + const char *tokStart = curPtr; + + switch (*curPtr++) { + default: + // Handle bare identifiers. + if (isalpha(curPtr[-1])) + return lexBareIdentifierOrKeyword(tokStart); + return emitError(tokStart, "unexpected character"); + + case ' ': + case '\t': + case '\n': + case '\r': + // Handle whitespace. + continue; + + case 0: + // This may be the EOF marker that llvm::MemoryBuffer guarantees will be + // there. + if (curPtr - 1 == curBuffer.end()) + return formToken(Token::eof, tokStart); + return emitError(tokStart, "unexpected character"); + + case ':': + return formToken(Token::colon, tokStart); + case ',': + return formToken(Token::comma, tokStart); + case '(': + return formToken(Token::l_paren, tokStart); + case ')': + return formToken(Token::r_paren, tokStart); + case '{': + return formToken(Token::l_brace, tokStart); + case '}': + return formToken(Token::r_brace, tokStart); + case '[': + return formToken(Token::l_square, tokStart); + case ']': + return formToken(Token::r_square, tokStart); + case '<': + return formToken(Token::less, tokStart); + case '>': + return formToken(Token::greater, tokStart); + case '=': + return formToken(Token::equal, tokStart); + case '+': + return formToken(Token::plus, tokStart); + case '*': + return formToken(Token::star, tokStart); + case '-': + if (*curPtr == '>') { + ++curPtr; + return formToken(Token::arrow, tokStart); + } + return formToken(Token::minus, tokStart); + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + return lexNumber(tokStart); + } + } +} + +/// Lex a bare identifier or keyword that starts with a letter. +/// +/// bare-id ::= (letter|[_]) (letter|digit|[_$.])* +/// +Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) { + // Match the rest of the identifier regex: [0-9a-zA-Z_.$]* + while (isalpha(*curPtr) || isdigit(*curPtr) || *curPtr == '_' || + *curPtr == '$' || *curPtr == '.') + ++curPtr; + + // Check to see if this identifier is a keyword. + StringRef spelling(tokStart, curPtr - tokStart); + + Token::Kind kind = llvm::StringSwitch(spelling) +#define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING) +#include "TokenKinds.def" + .Default(Token::bare_identifier); + + return Token(kind, spelling); +} + +/// Lex a number literal. +/// +/// integer-literal ::= digit+ +/// +Token Lexer::lexNumber(const char *tokStart) { + assert(isdigit(curPtr[-1])); + + // Handle the normal decimal case. + while (isdigit(*curPtr)) + ++curPtr; + + return formToken(Token::integer, tokStart); +} diff --git a/mlir/lib/Analysis/Presburger/Parser/Lexer.h b/mlir/lib/Analysis/Presburger/Parser/Lexer.h new file mode 100644 index 00000000000000..1dc458a358a949 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Lexer.h @@ -0,0 +1,56 @@ +//===- Lexer.h - Presburger Lexer Interface ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Presburger Lexer class. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_LEXER_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_LEXER_H + +#include "Token.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir::presburger { +/// This class breaks up the current file into a token stream. +class Lexer { +public: + explicit Lexer(const llvm::SourceMgr &sourceMgr); + + Token lexToken(); + + /// Change the position of the lexer cursor. The next token we lex will start + /// at the designated point in the input. + void resetPointer(const char *newPointer) { curPtr = newPointer; } + + /// Returns the start of the buffer. + const char *getBufferBegin() { return curBuffer.data(); } + +private: + Token formToken(Token::Kind kind, const char *tokStart) { + return Token(kind, StringRef(tokStart, curPtr - tokStart)); + } + + Token emitError(const char *loc, const llvm::Twine &message); + + // Lexer implementation methods. + Token lexBareIdentifierOrKeyword(const char *tokStart); + Token lexNumber(const char *tokStart); + + const llvm::SourceMgr &sourceMgr; + + StringRef curBuffer; + const char *curPtr; + + Lexer(const Lexer &) = delete; + void operator=(const Lexer &) = delete; +}; +} // namespace mlir::presburger + +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_LEXER_H diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp new file mode 100644 index 00000000000000..b937ae27a8fea8 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.cpp @@ -0,0 +1,268 @@ +//===- ParseStructs.cpp - Presburger Parse Structrures ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the ParseStructs class that the parser for the +// Presburger library parses into. +// +//===----------------------------------------------------------------------===// + +#include "ParseStructs.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir::presburger; +using llvm::dbgs; +using llvm::divideFloorSigned; +using llvm::mod; + +CoefficientVector PureAffineExprImpl::collectLinearTerms() const { + CoefficientVector nestedLinear = std::accumulate( + nestedDivTerms.begin(), nestedDivTerms.end(), CoefficientVector(info), + [](const CoefficientVector &acc, const PureAffineExpr &div) { + return acc + div->getLinearDividend(); + }); + return nestedLinear += linearDividend; +} + +SmallVector, 8> +PureAffineExprImpl::getNonLinearCoeffs() const { + SmallVector, 8> ret; + // dividend `floordiv` divisor <=> q; adjustedMulFactor = 1, + // adjustedLinearTerm is empty. + // + // dividend `mod` divisor <=> dividend - divisor*q; adjustedMulFactor = + // -divisor, adjustedLinearTerm is the linear part of the dividend. + // + // where q is a floordiv id added by the flattener. + auto adjustedMulFactor = [](const PureAffineExprImpl &div) { + return div.mulFactor * (div.kind == DivKind::Mod ? -div.divisor : 1); + }; + auto adjustedLinearTerm = [](PureAffineExprImpl &div) { + return div.kind == DivKind::Mod ? div.linearDividend *= div.mulFactor + : CoefficientVector(div.info); + }; + if (hasDivisor()) + for (const auto &toplevel : nestedDivTerms) + for (const auto &div : toplevel->getNestedDivTerms()) + ret.emplace_back(div->hash(), adjustedMulFactor(*div), + adjustedLinearTerm(*div)); + else + for (const auto &div : nestedDivTerms) + ret.emplace_back(div->hash(), adjustedMulFactor(*div), + adjustedLinearTerm(*div)); + return ret; +} + +unsigned PureAffineExprImpl::countNestedDivs() const { + return hasDivisor() + + std::accumulate(getNestedDivTerms().begin(), getNestedDivTerms().end(), + 0, [](unsigned acc, const PureAffineExpr &div) { + return acc + div->countNestedDivs(); + }); +} + +PureAffineExpr mlir::presburger::operator+(PureAffineExpr &&lhs, + PureAffineExpr &&rhs) { + if (lhs->isLinear() && rhs->isLinear()) + return std::make_unique(lhs->getLinearDividend() + + rhs->getLinearDividend()); + if (lhs->isLinear()) + return std::make_unique( + std::move(rhs->addLinearTerm(lhs->getLinearDividend()))); + if (rhs->isLinear()) + return std::make_unique( + std::move(lhs->addLinearTerm(rhs->getLinearDividend()))); + + if (!(lhs->hasDivisor() ^ rhs->hasDivisor())) { + auto ret = PureAffineExprImpl(lhs->info); + ret.addDivTerm(std::move(*lhs)); + ret.addDivTerm(std::move(*rhs)); + return std::make_unique(std::move(ret)); + } + if (lhs->hasDivisor()) { + rhs->addDivTerm(std::move(*lhs)); + return rhs; + } + if (rhs->hasDivisor()) { + lhs->addDivTerm(std::move(*rhs)); + return lhs; + } + llvm_unreachable("Malformed AffineExpr"); +} + +PureAffineExpr mlir::presburger::operator*(PureAffineExpr &&expr, int64_t c) { + if (expr->isLinear()) + return std::make_unique(expr->getLinearDividend() * c); + return std::make_unique(std::move(expr->mulConstant(c))); +} + +PureAffineExpr mlir::presburger::operator+(PureAffineExpr &&expr, int64_t c) { + return std::move(expr) + std::make_unique(expr->info, c); +} + +PureAffineExpr mlir::presburger::div(PureAffineExpr &÷nd, int64_t divisor, + DivKind kind) { + assert(divisor > 0 && "floorDiv or mod with a negative divisor"); + + // Constant fold. + if (dividend->isConstant()) { + int64_t c = kind == DivKind::FloorDiv + ? divideFloorSigned(dividend->getConstant(), divisor) + : mod(dividend->getConstant(), divisor); + return std::make_unique(dividend->info, c); + } + + // Factor out mul, using gcd internally. + uint64_t exprMultiple; + if (dividend->isLinear()) { + exprMultiple = dividend->getLinearDividend().factorMulFromLinearTerm(); + } else { + // Canonicalize the div. + uint64_t constMultiple = + dividend->getLinearDividend().factorMulFromLinearTerm(); + dividend->divLinearDividend(static_cast(constMultiple)); + dividend->mulFactor *= constMultiple; + exprMultiple = dividend->mulFactor; + } + + // Perform gcd with divisor. + uint64_t gcd = std::gcd(std::abs(divisor), exprMultiple); + + // Divide according to the type. + if (gcd > 1) { + if (dividend->isLinear()) + dividend->linearDividend /= static_cast(gcd); + else + dividend->mulFactor /= (static_cast(gcd)); + + divisor /= static_cast(gcd); + } + + if (dividend->isLinear()) + return std::make_unique(dividend->getLinearDividend(), + divisor, kind); + + // x floordiv 1 <=> x, x % 1 <=> 0 + return divisor == 1 + ? (dividend->isMod() + ? std::make_unique(dividend->info) + : std::move(dividend)) + : std::make_unique(std::move(*dividend), + divisor, kind); + ; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +enum class BindingStrength { + Weak, // + and - + Strong, // All other binary operators. +}; + +static void printCoefficient(int64_t c, bool &isExprBegin, + const ParseInfo &info, int idx = -1) { + bool isConstant = idx != -1 && info.isConstantIdx(idx); + bool isDimOrSymbol = idx != -1 && !info.isConstantIdx(idx); + if (!c) + return; + if (!isExprBegin) + dbgs() << " "; + if (c < 0) { + dbgs() << "- "; + if (c != -1 || isConstant) + dbgs() << std::abs(c); + } else { + if (!isExprBegin) + dbgs() << "+ "; + if (c != 1 || isConstant) { + dbgs() << c; + isExprBegin = false; + } + } + if (isDimOrSymbol) { + if (std::abs(c) != 1) + dbgs() << " * "; + dbgs() << (info.isDimIdx(idx) ? 'd' : 's') << idx; + isExprBegin = false; + } +} + +static bool printCoefficientVec( + const CoefficientVector &linear, bool isExprBegin = true, + BindingStrength enclosingTightness = BindingStrength::Weak) { + if (enclosingTightness == BindingStrength::Strong && + linear.hasMultipleCoefficients()) { + dbgs() << '('; + isExprBegin = true; + } + for (auto [idx, c] : enumerate(linear.getCoefficients())) + printCoefficient(c, isExprBegin, linear.info, idx); + + if (enclosingTightness == BindingStrength::Strong && + linear.hasMultipleCoefficients()) + dbgs() << ')'; + return isExprBegin; +} + +static bool +printAffineExpr(const PureAffineExprImpl &expr, bool isExprBegin = true, + BindingStrength enclosingTightness = BindingStrength::Weak) { + if (expr.isLinear()) + return printCoefficientVec(expr.getLinearDividend(), isExprBegin, + enclosingTightness); + + const auto &div = expr; + const auto &linearDividend = div.getLinearDividend(); + const auto &divisor = div.getDivisor(); + const auto &mulFactor = div.getMulFactor(); + const auto &nestedDivs = div.getNestedDivTerms(); + + printCoefficient(mulFactor, isExprBegin, expr.info); + if (std::abs(mulFactor) != 1) + dbgs() << " * "; + + if (div.hasDivisor() || enclosingTightness == BindingStrength::Strong) { + dbgs() << '('; + isExprBegin = true; + } + + isExprBegin = printCoefficientVec(linearDividend, isExprBegin); + + for (const auto &div : nestedDivs) + isExprBegin = printAffineExpr(*div, isExprBegin, BindingStrength::Strong); + + if (div.hasDivisor()) + dbgs() << (expr.isMod() ? " % " : " floordiv ") << divisor; + + if (div.hasDivisor() || enclosingTightness == BindingStrength::Strong) + dbgs() << ')'; + + return isExprBegin; +} + +LLVM_DUMP_METHOD void CoefficientVector::dump() const { + printCoefficientVec(*this); + dbgs() << '\n'; +} + +LLVM_DUMP_METHOD void PureAffineExprImpl::dump() const { + printAffineExpr(*this); + dbgs() << '\n'; +} + +LLVM_DUMP_METHOD void FinalParseResult::dump() const { + dbgs() << "Exprs:\n"; + for (const auto &expr : exprs) + expr->dump(); + dbgs() << "EqFlags: "; + for (bool eqF : eqFlags) + dbgs() << eqF << ' '; + dbgs() << '\n'; +} +#endif diff --git a/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h new file mode 100644 index 00000000000000..ecf8428acd6569 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/ParseStructs.h @@ -0,0 +1,353 @@ +//===- ParseStructs.h - Presburger Parse Structrures ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H + +#include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/PresburgerSpace.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include + +namespace mlir::presburger { +using llvm::ArrayRef; +using llvm::SmallVector; +using llvm::SmallVectorImpl; + +/// This structure is central to the parser and flattener, and holds the number +/// of dimensions, symbols, locals, and the constant term. +struct ParseInfo { + unsigned numDims = 0; + unsigned numSymbols = 0; + unsigned numExprs = 0; + unsigned numDivs = 0; + + constexpr unsigned getDimStartIdx() const { return 0; } + constexpr unsigned getSymbolStartIdx() const { return numDims; } + constexpr unsigned getLocalVarStartIdx() const { + return numDims + numSymbols; + } + constexpr unsigned getNumCols() const { + return numDims + numSymbols + numDivs + 1; + } + constexpr unsigned getConstantIdx() const { return getNumCols() - 1; } + + constexpr bool isDimIdx(unsigned i) const { return i < getSymbolStartIdx(); } + constexpr bool isSymbolIdx(unsigned i) const { + return i >= getSymbolStartIdx() && i < getLocalVarStartIdx(); + } + constexpr bool isLocalVarIdx(unsigned i) const { + return i >= getLocalVarStartIdx() && i < getConstantIdx(); + } + constexpr bool isConstantIdx(unsigned i) const { + return i == getConstantIdx(); + } +}; + +/// Helper for storing coefficients in canonical form: dims followed by symbols, +/// followed by locals, and finally the constant term. +/// +/// (x, y)[a, b]: y * 91 + x + 3 * a + 7 +/// coefficients: [1, 91, 3, 0, 7] +struct CoefficientVector { + ParseInfo info; + SmallVector coefficients; + + CoefficientVector(const ParseInfo &info, int64_t c = 0) : info(info) { + coefficients.resize(info.getNumCols()); + coefficients[info.getConstantIdx()] = c; + } + + // Copyable and movable + CoefficientVector(const CoefficientVector &o) = default; + CoefficientVector &operator=(const CoefficientVector &o) = default; + CoefficientVector(CoefficientVector &&o) + : info(o.info), coefficients(std::move(o.coefficients)) { + o.coefficients.clear(); + } + + ArrayRef getCoefficients() const { return coefficients; } + int64_t getConstant() const { + return coefficients[info.getConstantIdx()]; + } + size_t size() const { return coefficients.size(); } + operator ArrayRef() const { return coefficients; } + void resize(size_t size) { coefficients.resize(size); } + operator bool() const { + return any_of(coefficients, [](int64_t c) { return c; }); + } + int64_t &operator[](unsigned i) { + assert(i < coefficients.size()); + return coefficients[i]; + } + int64_t &back() { return coefficients.back(); } + int64_t back() const { return coefficients.back(); } + void clear() { + for_each(coefficients, [](auto &coeff) { coeff = 0; }); + } + + CoefficientVector &operator+=(const CoefficientVector &l) { + coefficients.resize(l.size()); + for (auto [idx, c] : enumerate(l.getCoefficients())) + coefficients[idx] += c; + return *this; + } + CoefficientVector &operator*=(int64_t c) { + for_each(coefficients, [c](auto &coeff) { coeff *= c; }); + return *this; + } + CoefficientVector &operator/=(int64_t c) { + assert(c && "Division by zero"); + for_each(coefficients, [c](auto &coeff) { coeff /= c; }); + return *this; + } + + CoefficientVector operator+(const CoefficientVector &l) const { + CoefficientVector ret(*this); + return ret += l; + } + CoefficientVector operator*(int64_t c) const { + CoefficientVector ret(*this); + return ret *= c; + } + CoefficientVector operator/(int64_t c) const { + CoefficientVector ret(*this); + return ret /= c; + } + + bool isConstant() const { + return all_of(drop_end(coefficients), [](int64_t c) { return !c; }); + } + CoefficientVector getPadded(size_t newSize) const { + assert(newSize >= size() && + "Padding size should be greater than expr size"); + CoefficientVector ret(info); + ret.resize(newSize); + + // Start constructing the result by taking the dims and symbols of the + // coefficients. + for (const auto &[col, coeff] : enumerate(drop_end(coefficients))) + ret[col] = coeff; + + // Put the constant at the end. + ret.back() = back(); + return ret; + } + + uint64_t factorMulFromLinearTerm() const { + uint64_t gcd = 1; + for (int64_t val : coefficients) + gcd = std::gcd(gcd, std::abs(val)); + return gcd; + } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + bool hasMultipleCoefficients() const { + return count_if(coefficients, [](auto &coeff) { return coeff; }) > 1; + } + LLVM_DUMP_METHOD void dump() const; +#endif +}; + +enum class DimOrSymbolKind { + DimId, + Symbol, +}; + +using DimOrSymbolExpr = std::pair; +enum class DivKind { FloorDiv, Mod }; + +/// Represents a pure Affine expression. Linear expressions are represented with +/// divisor = 1, and no nestedDivTerms. +/// +/// 3 - a * (3 + (3 - x div 3) div 4 + y div 7) div 4 +/// ^ linearDivident = 3, mulFactor = 1, divisor = 1 +/// ^ nest: 1, mulFactor: -a +/// ^ nest: 1, linearDividend +/// ^ nest: 2, linearDividend +/// ^ nest: 3 +/// ^ nest: 2 +/// ^ nest: 2 +/// nest: 1, divisor ^ +/// +/// Where div = floordiv|mod; ceildiv is pre-reduced +struct PureAffineExprImpl { + ParseInfo info; + DivKind kind = DivKind::FloorDiv; + using PureAffineExpr = std::unique_ptr; + + int64_t mulFactor = 1; + CoefficientVector linearDividend; + int64_t divisor = 1; + SmallVector nestedDivTerms; + + PureAffineExprImpl(const ParseInfo &info, int64_t c = 0) + : info(info), linearDividend(info, c) {} + PureAffineExprImpl(const ParseInfo &info, DimOrSymbolExpr idExpr) + : PureAffineExprImpl(info) { + auto [kind, pos] = idExpr; + unsigned startIdx = kind == DimOrSymbolKind::Symbol + ? info.getSymbolStartIdx() + : info.getDimStartIdx(); + linearDividend.coefficients[startIdx + pos] = 1; + } + PureAffineExprImpl(const CoefficientVector &linearDividend, + int64_t divisor = 1, DivKind kind = DivKind::FloorDiv) + : info(linearDividend.info), kind(kind), linearDividend(linearDividend), + divisor(divisor) {} + PureAffineExprImpl(PureAffineExprImpl &&div, int64_t divisor, DivKind kind) + : info(div.info), kind(kind), linearDividend(div.info), divisor(divisor) { + addDivTerm(std::move(div)); + } + + // Non-copyable, only movable + PureAffineExprImpl(const PureAffineExprImpl &) = delete; + PureAffineExprImpl(PureAffineExprImpl &&o) + : info(o.info), kind(o.kind), mulFactor(o.mulFactor), + linearDividend(std::move(o.linearDividend)), divisor(o.divisor), + nestedDivTerms(std::move(o.nestedDivTerms)) { + o.nestedDivTerms.clear(); + o.divisor = o.mulFactor = 1; + } + + const CoefficientVector &getLinearDividend() const { return linearDividend; } + CoefficientVector collectLinearTerms() const; + SmallVector, 8> + getNonLinearCoeffs() const; + + constexpr bool isMod() const { return kind == DivKind::Mod; } + constexpr bool hasDivisor() const { return divisor != 1; } + bool isLinear() const { + return nestedDivTerms.empty() && !hasDivisor(); + } + constexpr int64_t getDivisor() const { return divisor; } + constexpr int64_t getMulFactor() const { return mulFactor; } + bool isConstant() const { + return nestedDivTerms.empty() && getLinearDividend().isConstant(); + } + int64_t getConstant() const { + return getLinearDividend().getConstant(); + } + size_t hash() const { + return std::hash{}(this); + } + ArrayRef getNestedDivTerms() const { return nestedDivTerms; } + + PureAffineExprImpl &mulConstant(int64_t c) { + // Canonicalize mulFactors in div terms without divisors. + mulFactor *= c; + distributeMulFactor(); + return *this; + } + PureAffineExprImpl &addLinearTerm(const CoefficientVector &l) { + if (hasDivisor()) + nestedDivTerms.emplace_back( + std::make_unique(std::move(*this))); + linearDividend += l; + return *this; + } + PureAffineExprImpl &addDivTerm(PureAffineExprImpl &&d) { + nestedDivTerms.emplace_back( + std::make_unique(std::move(d))); + return *this; + } + PureAffineExprImpl &divLinearDividend(int64_t c) { + linearDividend /= c; + return *this; + } + + void distributeMulFactor() { + // Canonicalize the -1 mulFactor for divs without divisor. + if (mulFactor != -1 || hasDivisor()) + return; + linearDividend *= -1; + for_each(nestedDivTerms, + [](const PureAffineExpr &div) { div->mulFactor *= -1; }); + mulFactor *= -1; + } + unsigned countNestedDivs() const; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const; +#endif +}; + +using PureAffineExpr = std::unique_ptr; + +/// This structure holds the final parse result, and is constructed +/// non-trivially to compute the total number of divs, which is then used to +/// compute the total number of columns, and construct the IntMatrix in the +/// Flattener. +struct FinalParseResult { + ParseInfo info; + SmallVector exprs; + SmallVector eqFlags; + IntegerPolyhedron cst; + + FinalParseResult(const ParseInfo &parseInfo, + SmallVectorImpl &&exprStack, + ArrayRef eqFlagStack) + : info(parseInfo), exprs(std::move(exprStack)), eqFlags(eqFlagStack), + cst(0, 0, info.numDims + info.numSymbols + 1, + PresburgerSpace::getSetSpace(info.numDims, info.numSymbols, 0)) { + auto &i = this->info; + i.numExprs = exprs.size(); + i.numDivs = std::accumulate(exprs.begin(), exprs.end(), 0, + [](unsigned acc, const PureAffineExpr &expr) { + return acc + expr->countNestedDivs(); + }); + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() const; +#endif +}; + +// The main interface to the parser is a bunch of operators, which is used to +// successively build the final AffineExpr. +PureAffineExpr operator+(PureAffineExpr &&lhs, PureAffineExpr &&rhs); +PureAffineExpr operator*(PureAffineExpr &&expr, int64_t c); +PureAffineExpr operator+(PureAffineExpr &&expr, int64_t c); +PureAffineExpr div(PureAffineExpr &÷nd, int64_t divisor, DivKind kind); +inline PureAffineExpr floordiv(PureAffineExpr &&expr, int64_t c) { + return div(std::move(expr), c, DivKind::FloorDiv); +} +inline PureAffineExpr operator%(PureAffineExpr &&expr, int64_t c) { + return div(std::move(expr), c, DivKind::Mod); +} +inline PureAffineExpr operator*(int64_t c, PureAffineExpr &&expr) { + return std::move(expr) * c; +} +inline PureAffineExpr operator-(PureAffineExpr &&lhs, PureAffineExpr &&rhs) { + return std::move(lhs) + std::move(rhs) * -1; +} +inline PureAffineExpr operator+(int64_t c, PureAffineExpr &&expr) { + return std::move(expr) + c; +} +inline PureAffineExpr operator-(PureAffineExpr &&expr, int64_t c) { + return std::move(expr) + (-c); +} +inline PureAffineExpr operator-(int64_t c, PureAffineExpr &&expr) { + return -1 * std::move(expr) + c; +} +inline PureAffineExpr ceildiv(PureAffineExpr &&expr, int64_t c) { + // expr ceildiv c <=> (expr + c - 1) floordiv c + return floordiv(std::move(expr) + (c - 1), c); +} + +// Our final canonical expression, the outermost div, should have a divisor +// of 1. +inline PureAffineExpr canonicalize(PureAffineExpr &&expr) { + if (expr->hasDivisor()) + expr->addLinearTerm(CoefficientVector(expr->info)); + return expr; +} +} // namespace mlir::presburger + +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_PARSESTRUCTS_H diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp new file mode 100644 index 00000000000000..20b2b6d08638e2 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.cpp @@ -0,0 +1,640 @@ +//===- ParserImpl.cpp - Presburger Parser Implementation --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the ParserImpl class for the Presburger textual form. +// +//===----------------------------------------------------------------------===// + +#include "ParserImpl.h" +#include "Flattener.h" +#include "ParseStructs.h" +#include "mlir/Analysis/Presburger/Parser.h" + +using namespace mlir; +using namespace presburger; +using llvm::MemoryBuffer; +using llvm::SourceMgr; + +static bool isIdentifier(const Token &token) { + return token.is(Token::bare_identifier) || token.isKeyword(); +} + +bool ParserImpl::parseToken(Token::Kind expectedToken, const Twine &message) { + if (consumeIf(expectedToken)) + return true; + return emitError(message); +} + +bool ParserImpl::parseCommaSepeatedList(Delimiter delimiter, + function_ref parseElementFn, + StringRef contextMessage) { + switch (delimiter) { + case Delimiter::None: + break; + case Delimiter::OptionalParen: + if (getToken().isNot(Token::l_paren)) + return true; + [[fallthrough]]; + case Delimiter::Paren: + if (!parseToken(Token::l_paren, "expected '('" + contextMessage)) + return false; + if (consumeIf(Token::r_paren)) + return true; + break; + case Delimiter::OptionalLessGreater: + if (getToken().isNot(Token::less)) + return true; + [[fallthrough]]; + case Delimiter::LessGreater: + if (!parseToken(Token::less, "expected '<'" + contextMessage)) + return true; + if (consumeIf(Token::greater)) + return true; + break; + case Delimiter::OptionalSquare: + if (getToken().isNot(Token::l_square)) + return true; + [[fallthrough]]; + case Delimiter::Square: + if (!parseToken(Token::l_square, "expected '['" + contextMessage)) + return false; + if (consumeIf(Token::r_square)) + return true; + break; + case Delimiter::OptionalBraces: + if (getToken().isNot(Token::l_brace)) + return true; + [[fallthrough]]; + case Delimiter::Braces: + if (!parseToken(Token::l_brace, "expected '{'" + contextMessage)) + return false; + if (consumeIf(Token::r_brace)) + return true; + break; + } + + if (!parseElementFn()) + return false; + + while (consumeIf(Token::comma)) + if (!parseElementFn()) + return false; + + switch (delimiter) { + case Delimiter::None: + return true; + case Delimiter::OptionalParen: + case Delimiter::Paren: + return parseToken(Token::r_paren, "expected ')'" + contextMessage); + case Delimiter::OptionalLessGreater: + case Delimiter::LessGreater: + return parseToken(Token::greater, "expected '>'" + contextMessage); + case Delimiter::OptionalSquare: + case Delimiter::Square: + return parseToken(Token::r_square, "expected ']'" + contextMessage); + case Delimiter::OptionalBraces: + case Delimiter::Braces: + return parseToken(Token::r_brace, "expected '}'" + contextMessage); + } + llvm_unreachable("Unknown delimiter"); +} + +bool ParserImpl::emitError(SMLoc loc, const Twine &message) { + // If we hit a parse error in response to a lexer error, then the lexer + // already reported the error. + if (getToken().isNot(Token::error)) + sourceMgr.PrintMessage(loc, SourceMgr::DK_Error, message); + return false; +} + +bool ParserImpl::emitError(const Twine &message) { + SMLoc loc = state.curToken.getLoc(); + if (state.curToken.isNot(Token::eof)) + return emitError(loc, message); + + // If the error is to be emitted at EOF, move it back one character. + return emitError(SMLoc::getFromPointer(loc.getPointer() - 1), message); +} + +/// Parse a bare id that may appear in an affine expression. +/// +/// affine-expr ::= bare-id +PureAffineExpr ParserImpl::parseBareIdExpr() { + if (!isIdentifier(getToken())) + return emitError("expected bare identifier"), nullptr; + + StringRef sRef = getTokenSpelling(); + for (const auto &entry : dimsAndSymbols) { + if (entry.first == sRef) { + consumeToken(); + return std::make_unique(info, entry.second); + } + } + + return emitError("use of undeclared identifier"), nullptr; +} + +/// Parse an affine expression inside parentheses. +/// +/// affine-expr ::= `(` affine-expr `)` +PureAffineExpr ParserImpl::parseParentheticalExpr() { + if (!parseToken(Token::l_paren, "expected '('")) + return nullptr; + if (getToken().is(Token::r_paren)) { + return emitError("no expression inside parentheses"), nullptr; + } + + PureAffineExpr expr = parseAffineExpr(); + if (!expr || !parseToken(Token::r_paren, "expected ')'")) + return nullptr; + + return expr; +} + +/// Parse the negation expression. +/// +/// affine-expr ::= `-` affine-expr +PureAffineExpr ParserImpl::parseNegateExpression(const PureAffineExpr &lhs) { + if (!parseToken(Token::minus, "expected '-'")) + return nullptr; + + PureAffineExpr operand = parseAffineOperandExpr(lhs); + // Since negation has the highest precedence of all ops (including high + // precedence ops) but lower than parentheses, we are only going to use + // parseAffineOperandExpr instead of parseAffineExpr here. + if (!operand) { + // Extra error message although parseAffineOperandExpr would have + // complained. Leads to a better diagnostic. + return emitError("missing operand of negation"), nullptr; + } + return -1 * std::move(operand); +} + +/// Parse a positive integral constant appearing in an affine expression. +/// +/// affine-expr ::= integer-literal +PureAffineExpr ParserImpl::parseIntegerExpr() { + std::optional val = getToken().getUInt64IntegerValue(); + if (!val) + return emitError("failed to parse constant"), nullptr; + int64_t ret = static_cast(*val); + if (ret < 0) + return emitError("constant too large"), nullptr; + + consumeToken(Token::integer); + return std::make_unique(info, ret); +} + +/// Parses an expression that can be a valid operand of an affine expression. +/// lhs: if non-null, lhs is an affine expression that is the lhs of a binary +/// operator, the rhs of which is being parsed. This is used to determine +/// whether an error should be emitted for a missing right operand. +// Eg: for an expression without parentheses (like i + j + k + l), each +// of the four identifiers is an operand. For i + j*k + l, j*k is not an +// operand expression, it's an op expression and will be parsed via +// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and +// -l are valid operands that will be parsed by this function. +PureAffineExpr ParserImpl::parseAffineOperandExpr(const PureAffineExpr &lhs) { + switch (getToken().getKind()) { + case Token::integer: + return parseIntegerExpr(); + case Token::l_paren: + return parseParentheticalExpr(); + case Token::minus: + return parseNegateExpression(lhs); + case Token::kw_ceildiv: + case Token::kw_floordiv: + case Token::kw_mod: + return parseBareIdExpr(); + case Token::plus: + case Token::star: + if (lhs) + emitError("missing right operand of binary operator"); + else + emitError("missing left operand of binary operator"); + return nullptr; + default: + if (isIdentifier(getToken())) + return parseBareIdExpr(); + + if (lhs) + emitError("missing right operand of binary operator"); + else + emitError("expected affine expression"); + return nullptr; + } +} + +PureAffineExpr ParserImpl::getAffineBinaryOpExpr(AffineHighPrecOp op, + PureAffineExpr &&lhs, + PureAffineExpr &&rhs, + SMLoc opLoc) { + switch (op) { + case Mul: + if (!lhs->isConstant() && !rhs->isConstant()) { + return emitError(opLoc, + "non-affine expression: at least one of the multiply " + "operands has to be a constant"), + nullptr; + } + if (rhs->isConstant()) + return std::move(lhs) * rhs->getConstant(); + return std::move(rhs) * lhs->getConstant(); + case FloorDiv: + if (!rhs->isConstant()) { + return emitError(opLoc, + "non-affine expression: right operand of floordiv " + "has to be a constant"), + nullptr; + } + return floordiv(std::move(lhs), rhs->getConstant()); + case CeilDiv: + if (!rhs->isConstant()) { + return emitError(opLoc, "non-affine expression: right operand of ceildiv " + "has to be a constant"), + nullptr; + } + return ceildiv(std::move(lhs), rhs->getConstant()); + case Mod: + if (!rhs->isConstant()) { + return emitError(opLoc, "non-affine expression: right operand of mod " + "has to be a constant"), + nullptr; + } + return std::move(lhs) % rhs->getConstant(); + case HNoOp: + llvm_unreachable("can't create affine expression for null high prec op"); + return nullptr; + } + llvm_unreachable("Unknown AffineHighPrecOp"); +} + +PureAffineExpr ParserImpl::getAffineBinaryOpExpr(AffineLowPrecOp op, + PureAffineExpr &&lhs, + PureAffineExpr &&rhs) { + switch (op) { + case AffineLowPrecOp::Add: + return std::move(lhs) + std::move(rhs); + case AffineLowPrecOp::Sub: + return std::move(lhs) - std::move(rhs); + case AffineLowPrecOp::LNoOp: + llvm_unreachable("can't create affine expression for null low prec op"); + return nullptr; + } + llvm_unreachable("Unknown AffineLowPrecOp"); +} + +AffineLowPrecOp ParserImpl::consumeIfLowPrecOp() { + switch (getToken().getKind()) { + case Token::plus: + consumeToken(Token::plus); + return AffineLowPrecOp::Add; + case Token::minus: + consumeToken(Token::minus); + return AffineLowPrecOp::Sub; + default: + return AffineLowPrecOp::LNoOp; + } +} + +/// Consume this token if it is a higher precedence affine op (there are only +/// two precedence levels) +AffineHighPrecOp ParserImpl::consumeIfHighPrecOp() { + switch (getToken().getKind()) { + case Token::star: + consumeToken(Token::star); + return Mul; + case Token::kw_floordiv: + consumeToken(Token::kw_floordiv); + return FloorDiv; + case Token::kw_ceildiv: + consumeToken(Token::kw_ceildiv); + return CeilDiv; + case Token::kw_mod: + consumeToken(Token::kw_mod); + return Mod; + default: + return HNoOp; + } +} + +/// Parse a high precedence op expression list: mul, div, and mod are high +/// precedence binary ops, i.e., parse a +/// expr_1 op_1 expr_2 op_2 ... expr_n +/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). +/// All affine binary ops are left associative. +/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is +/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is +/// null. llhsOpLoc is the location of the llhsOp token that will be used to +/// report an error for non-conforming expressions. +PureAffineExpr ParserImpl::parseAffineHighPrecOpExpr(PureAffineExpr &&llhs, + AffineHighPrecOp llhsOp, + SMLoc llhsOpLoc) { + PureAffineExpr lhs = parseAffineOperandExpr(llhs); + if (!lhs) + return nullptr; + + // Found an LHS. Parse the remaining expression. + SMLoc opLoc = getToken().getLoc(); + if (AffineHighPrecOp op = consumeIfHighPrecOp()) { + if (llhs) { + PureAffineExpr expr = + getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs), opLoc); + if (!expr) + return nullptr; + return parseAffineHighPrecOpExpr(std::move(expr), op, opLoc); + } + // No LLHS, get RHS + return parseAffineHighPrecOpExpr(std::move(lhs), op, opLoc); + } + + // This is the last operand in this expression. + if (llhs) + return getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs), + llhsOpLoc); + + return lhs; +} + +/// Parse affine expressions that are bare-id's, integer constants, +/// parenthetical affine expressions, and affine op expressions that are a +/// composition of those. +/// +/// All binary op's associate from left to right. +/// +/// {add, sub} have lower precedence than {mul, div, and mod}. +/// +/// Add, sub'are themselves at the same precedence level. Mul, floordiv, +/// ceildiv, and mod are at the same higher precedence level. Negation has +/// higher precedence than any binary op. +/// +/// llhs: the affine expression appearing on the left of the one being parsed. +/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, +/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned +/// if llhs is non-null; otherwise lhs is returned. This is to deal with left +/// associativity. +/// +/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function +/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where +/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). +PureAffineExpr ParserImpl::parseAffineLowPrecOpExpr(PureAffineExpr &&llhs, + AffineLowPrecOp llhsOp) { + PureAffineExpr lhs = parseAffineOperandExpr(llhs); + if (!lhs) + return nullptr; + + // Found an LHS. Deal with the ops. + if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { + if (llhs) { + PureAffineExpr sum = + getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs)); + return parseAffineLowPrecOpExpr(std::move(sum), lOp); + } + + return parseAffineLowPrecOpExpr(std::move(lhs), lOp); + } + SMLoc opLoc = getToken().getLoc(); + if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { + // We have a higher precedence op here. Get the rhs operand for the llhs + // through parseAffineHighPrecOpExpr. + PureAffineExpr highRes = + parseAffineHighPrecOpExpr(std::move(lhs), hOp, opLoc); + if (!highRes) + return nullptr; + + // If llhs is null, the product forms the first operand of the yet to be + // found expression. If non-null, the op to associate with llhs is llhsOp. + PureAffineExpr expr = llhs ? getAffineBinaryOpExpr(llhsOp, std::move(llhs), + std::move(highRes)) + : std::move(highRes); + + // Recurse for subsequent low prec op's after the affine high prec op + // expression. + if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) + return parseAffineLowPrecOpExpr(std::move(expr), nextOp); + return expr; + } + // Last operand in the expression list. + if (llhs) + return getAffineBinaryOpExpr(llhsOp, std::move(llhs), std::move(lhs)); + + return lhs; +} + +/// Parse an affine expression. +/// affine-expr ::= `(` affine-expr `)` +/// | `-` affine-expr +/// | affine-expr `+` affine-expr +/// | affine-expr `-` affine-expr +/// | affine-expr `*` affine-expr +/// | affine-expr `floordiv` affine-expr +/// | affine-expr `ceildiv` affine-expr +/// | affine-expr `mod` affine-expr +/// | bare-id +/// | integer-literal +/// +/// Additional conditions are checked depending on the production. For eg., +/// one of the operands for `*` has to be either constant/symbolic; the second +/// operand for floordiv, ceildiv, and mod has to be a positive integer. +PureAffineExpr ParserImpl::parseAffineExpr() { + return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); +} + +/// Parse a dim or symbol from the lists appearing before the actual +/// expressions of the affine map. Update our state to store the +/// dimensional/symbolic identifier. +bool ParserImpl::parseIdentifierDefinition(DimOrSymbolExpr idExpr) { + if (!isIdentifier(getToken())) + return emitError("expected bare identifier"); + + StringRef name = getTokenSpelling(); + for (const auto &entry : dimsAndSymbols) { + if (entry.first == name) + return emitError("redefinition of identifier '" + name + "'"); + } + consumeToken(); + + dimsAndSymbols.emplace_back(name, idExpr); + return true; +} + +/// Parse the list of dimensional identifiers to an affine map. +bool ParserImpl::parseDimIdList() { + auto parseElt = [&]() -> bool { + return parseIdentifierDefinition({DimOrSymbolKind::DimId, info.numDims++}); + }; + return parseCommaSepeatedList(Delimiter::Paren, parseElt, + " in dimensional identifier list"); +} + +/// Parse the list of symbolic identifiers to an affine map. +bool ParserImpl::parseSymbolIdList() { + auto parseElt = [&]() -> bool { + return parseIdentifierDefinition( + {DimOrSymbolKind::Symbol, info.numSymbols++}); + }; + return parseCommaSepeatedList(Delimiter::Square, parseElt, " in symbol list"); +} + +/// Parse the list of symbolic identifiers to an affine map. +bool ParserImpl::parseDimAndOptionalSymbolIdList() { + if (!parseDimIdList()) + return false; + if (getToken().isNot(Token::l_square)) { + info.numSymbols = 0; + return true; + } + return parseSymbolIdList(); +} + +/// Parse an affine constraint. +/// affine-constraint ::= affine-expr `>=` `affine-expr` +/// | affine-expr `<=` `affine-expr` +/// | affine-expr `==` `affine-expr` +/// +/// The constraint is normalized to +/// affine-constraint ::= affine-expr `>=` `0` +/// | affine-expr `==` `0` +/// before returning. +/// +/// isEq is set to true if the parsed constraint is an equality, false if it +/// is an inequality (greater than or equal). +PureAffineExpr ParserImpl::parseAffineConstraint(bool *isEq) { + PureAffineExpr lhsExpr = parseAffineExpr(); + if (!lhsExpr) + return nullptr; + + if (consumeIf(Token::greater) && consumeIf(Token::equal)) { + PureAffineExpr rhsExpr = parseAffineExpr(); + if (!rhsExpr) + return nullptr; + *isEq = false; + return std::move(lhsExpr) - std::move(rhsExpr); + } + + if (consumeIf(Token::less) && consumeIf(Token::equal)) { + PureAffineExpr rhsExpr = parseAffineExpr(); + if (!rhsExpr) + return nullptr; + *isEq = false; + return std::move(rhsExpr) - std::move(lhsExpr); + } + + if (consumeIf(Token::equal) && consumeIf(Token::equal)) { + PureAffineExpr rhsExpr = parseAffineExpr(); + if (!rhsExpr) + return nullptr; + *isEq = true; + return std::move(lhsExpr) - std::move(rhsExpr); + } + + return emitError("expected '==', '<=' or '>='"), nullptr; +} + +FinalParseResult ParserImpl::parseAffineMapOrIntegerSet() { + SmallVector exprs; + SmallVector eqFlags; + + if (!parseDimAndOptionalSymbolIdList()) + llvm_unreachable("expected dim and symbol list"); + + auto parseExpr = [&]() -> bool { + PureAffineExpr elt = canonicalize(parseAffineExpr()); + if (elt) { + exprs.emplace_back(std::move(elt)); + return true; + } + return false; + }; + + auto parseConstraint = [&]() -> bool { + bool isEq; + PureAffineExpr elt = canonicalize(parseAffineConstraint(&isEq)); + if (elt) { + exprs.emplace_back(std::move(elt)); + eqFlags.push_back(isEq); + return true; + } + return false; + }; + + if (consumeIf(Token::arrow)) { + /// Parse the range and sizes affine map definition inline. + /// + /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr + /// + /// multi-dim-affine-expr ::= `(` `)` + /// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)` + if (!parseCommaSepeatedList(Delimiter::Paren, parseExpr, + " in affine map range")) + llvm_unreachable("expected affine map range"); + + return {info, std::move(exprs), eqFlags}; + } + + if (!parseToken(Token::colon, "expected '->' or ':'")) + llvm_unreachable("Unexpected token"); + + /// Parse the constraints that are part of an integer set definition. + /// integer-set-inline + /// ::= dim-and-symbol-id-lists `:` + /// '(' affine-constraint-conjunction? ')' + /// affine-constraint-conjunction ::= affine-constraint (`,` + /// affine-constraint)* + /// + if (!parseCommaSepeatedList(Delimiter::Paren, parseConstraint, + " in integer set constraint list")) + llvm_unreachable("expected integer set"); + + return {info, std::move(exprs), eqFlags}; +} + +static MultiAffineFunction getMAF(FinalParseResult &&parseResult) { + auto info = parseResult.info; + auto [flatMatrix, cst] = Flattener(std::move(parseResult)).flatten(); + + DivisionRepr divs = cst.getLocalReprs(); + assert(divs.hasAllReprs() && + "AffineMap cannot produce divs without local representation"); + + return MultiAffineFunction( + PresburgerSpace::getRelationSpace(info.numDims, info.numExprs, + info.numSymbols, divs.getNumDivs()), + flatMatrix, divs); +} + +static IntegerPolyhedron getPoly(FinalParseResult &&parseResult) { + auto eqFlags = parseResult.eqFlags; + auto [flatMatrix, cst] = Flattener(std::move(parseResult)).flatten(); + + for (unsigned i = 0; i < flatMatrix.getNumRows(); ++i) { + if (eqFlags[i]) + cst.addEquality(flatMatrix.getRow(i)); + else + cst.addInequality(flatMatrix.getRow(i)); + } + + return cst; +} + +static FinalParseResult parseAffineMapOrIntegerSet(StringRef str) { + SourceMgr sourceMgr; + auto memBuffer = + MemoryBuffer::getMemBuffer(str, "", false); + sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); + ParserImpl parser(sourceMgr); + return parser.parseAffineMapOrIntegerSet(); +} + +IntegerPolyhedron mlir::presburger::parseIntegerPolyhedron(StringRef str) { + return getPoly(parseAffineMapOrIntegerSet(str)); +} + +MultiAffineFunction mlir::presburger::parseMultiAffineFunction(StringRef str) { + return getMAF(parseAffineMapOrIntegerSet(str)); +} diff --git a/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h new file mode 100644 index 00000000000000..7fef4ded216733 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/ParserImpl.h @@ -0,0 +1,160 @@ +//===- ParserImpl.h - Presburger Parser Implementation ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H + +#include "Lexer.h" +#include "ParseStructs.h" + +namespace mlir::presburger { +template +using function_ref = llvm::function_ref; +using llvm::SourceMgr; +using llvm::Twine; + +/// This class refers to the lexing-related state for the parser. +struct ParserState { + ParserState(const llvm::SourceMgr &sourceMgr) + : lex(sourceMgr), curToken(lex.lexToken()), lastToken(Token::error, "") {} + ParserState(const ParserState &) = delete; + + Lexer lex; + Token curToken; + Token lastToken; +}; + +/// These are the supported delimiters around operand lists and region +/// argument lists, used by parseOperandList. +enum class Delimiter { + /// Zero or more operands with no delimiters. + None, + /// Parens surrounding zero or more operands. + Paren, + /// Square brackets surrounding zero or more operands. + Square, + /// <> brackets surrounding zero or more operands. + LessGreater, + /// {} brackets surrounding zero or more operands. + Braces, + /// Parens supporting zero or more operands, or nothing. + OptionalParen, + /// Square brackets supporting zero or more ops, or nothing. + OptionalSquare, + /// <> brackets supporting zero or more ops, or nothing. + OptionalLessGreater, + /// {} brackets surrounding zero or more operands, or nothing. + OptionalBraces, +}; + +/// Lower precedence ops (all at the same precedence level). LNoOp is false in +/// the boolean sense. +enum AffineLowPrecOp { + LNoOp, // Null value. + Add, + Sub +}; + +/// Higher precedence ops - all at the same precedence level. HNoOp is false +/// in the boolean sense. +enum AffineHighPrecOp { + HNoOp, // Null value. + Mul, + FloorDiv, + CeilDiv, + Mod +}; + +class ParserImpl { +public: + ParserImpl(const SourceMgr &sourceMgr) + : sourceMgr(sourceMgr), state(sourceMgr) {} + + FinalParseResult parseAffineMapOrIntegerSet(); + +private: + /// Parse a list of comma-separated items with an optional delimiter. If a + /// delimiter is provided, then an empty list is allowed. If not, then at + /// least one element will be parsed. + bool parseCommaSepeatedList(Delimiter delimiter, + function_ref parseElementFn, + StringRef contextMessage); + + //===--------------------------------------------------------------------===// + // Error Handling + //===--------------------------------------------------------------------===// + bool emitError(const Twine &message); + bool emitError(SMLoc loc, const Twine &message); + + //===--------------------------------------------------------------------===// + // Token Parsing + //===--------------------------------------------------------------------===// + const Token &getToken() const { return state.curToken; } + StringRef getTokenSpelling() const { return state.curToken.getSpelling(); } + const Token &getLastToken() const { return state.lastToken; } + bool consumeIf(Token::Kind kind) { + if (state.curToken.isNot(kind)) + return false; + consumeToken(kind); + return true; + } + void consumeToken() { + assert(state.curToken.isNot(Token::eof, Token::error) && + "shouldn't advance past EOF or errors"); + state.lastToken = state.curToken; + state.curToken = state.lex.lexToken(); + } + void consumeToken(Token::Kind kind) { + assert(state.curToken.is(kind) && "consumed an unexpected token"); + consumeToken(); + } + void resetToken(const char *tokPos) { + state.lex.resetPointer(tokPos); + state.lastToken = state.curToken; + state.curToken = state.lex.lexToken(); + } + bool parseToken(Token::Kind expectedToken, const Twine &message); + + //===--------------------------------------------------------------------===// + // Affine Parsing + //===--------------------------------------------------------------------===// + AffineLowPrecOp consumeIfLowPrecOp(); + AffineHighPrecOp consumeIfHighPrecOp(); + + bool parseDimIdList(); + bool parseSymbolIdList(); + bool parseDimAndOptionalSymbolIdList(); + bool parseIdentifierDefinition(DimOrSymbolExpr idExpr); + + PureAffineExpr parseAffineExpr(); + PureAffineExpr parseParentheticalExpr(); + PureAffineExpr parseNegateExpression(const PureAffineExpr &lhs); + PureAffineExpr parseIntegerExpr(); + PureAffineExpr parseBareIdExpr(); + + PureAffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, + PureAffineExpr &&lhs, + PureAffineExpr &&rhs, SMLoc opLoc); + PureAffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, PureAffineExpr &&lhs, + PureAffineExpr &&rhs); + PureAffineExpr parseAffineOperandExpr(const PureAffineExpr &lhs); + PureAffineExpr parseAffineLowPrecOpExpr(PureAffineExpr &&llhs, + AffineLowPrecOp llhsOp); + PureAffineExpr parseAffineHighPrecOpExpr(PureAffineExpr &&llhs, + AffineHighPrecOp llhsOp, + SMLoc llhsOpLoc); + PureAffineExpr parseAffineConstraint(bool *isEq); + + const SourceMgr &sourceMgr; + ParserState state; + ParseInfo info; + SmallVector, 4> dimsAndSymbols; +}; +} // namespace mlir::presburger + +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_PARSERIMPL_H diff --git a/mlir/lib/Analysis/Presburger/Parser/Token.cpp b/mlir/lib/Analysis/Presburger/Parser/Token.cpp new file mode 100644 index 00000000000000..d675e0662c88a4 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Token.cpp @@ -0,0 +1,64 @@ +//===- Token.cpp - Presburger Token Implementation --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Token class for the Presburger textual form. +// +//===----------------------------------------------------------------------===// + +#include "Token.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/SMLoc.h" +#include + +using namespace mlir::presburger; + +SMLoc Token::getLoc() const { return SMLoc::getFromPointer(spelling.data()); } + +SMLoc Token::getEndLoc() const { + return SMLoc::getFromPointer(spelling.data() + spelling.size()); +} + +SMRange Token::getLocRange() const { return SMRange(getLoc(), getEndLoc()); } + +/// For an integer token, return its value as a uint64_t. If it doesn't fit, +/// return std::nullopt. +std::optional Token::getUInt64IntegerValue(StringRef spelling) { + uint64_t result = 0; + if (spelling.getAsInteger(10, result)) + return std::nullopt; + return result; +} + +/// Given a punctuation or keyword token kind, return the spelling of the +/// token as a string. Warning: This will abort on markers, identifiers and +/// literal tokens since they have no fixed spelling. +StringRef Token::getTokenSpelling(Kind kind) { + switch (kind) { + default: + llvm_unreachable("This token kind has no fixed spelling"); +#define TOK_PUNCTUATION(NAME, SPELLING) \ + case NAME: \ + return SPELLING; +#define TOK_KEYWORD(SPELLING) \ + case kw_##SPELLING: \ + return #SPELLING; +#include "TokenKinds.def" + } +} + +/// Return true if this is one of the keyword token kinds (e.g. kw_if). +bool Token::isKeyword() const { + switch (kind) { + default: + return false; +#define TOK_KEYWORD(SPELLING) \ + case kw_##SPELLING: \ + return true; +#include "TokenKinds.def" + } +} diff --git a/mlir/lib/Analysis/Presburger/Parser/Token.h b/mlir/lib/Analysis/Presburger/Parser/Token.h new file mode 100644 index 00000000000000..5c84c1ba820eed --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/Token.h @@ -0,0 +1,91 @@ +//===- Token.h - Presburger Token Interface ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_PRESBURGER_PARSER_TOKEN_H +#define MLIR_ANALYSIS_PRESBURGER_PARSER_TOKEN_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/SMLoc.h" +#include + +namespace mlir::presburger { +using llvm::SMLoc; +using llvm::SMRange; +using llvm::StringRef; + +class Token { +public: + enum Kind { +#define TOK_MARKER(NAME) NAME, +#define TOK_IDENTIFIER(NAME) NAME, +#define TOK_LITERAL(NAME) NAME, +#define TOK_PUNCTUATION(NAME, SPELLING) NAME, +#define TOK_KEYWORD(SPELLING) kw_##SPELLING, +#include "TokenKinds.def" + }; + + Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} + + // Return the bytes that make up this token. + StringRef getSpelling() const { return spelling; } + + // Token classification. + Kind getKind() const { return kind; } + bool is(Kind k) const { return kind == k; } + + bool isAny(Kind k1, Kind k2) const { return is(k1) || is(k2); } + + /// Return true if this token is one of the specified kinds. + template + bool isAny(Kind k1, Kind k2, Kind k3, T... others) const { + if (is(k1)) + return true; + return isAny(k2, k3, others...); + } + + bool isNot(Kind k) const { return kind != k; } + + /// Return true if this token isn't one of the specified kinds. + template + bool isNot(Kind k1, Kind k2, T... others) const { + return !isAny(k1, k2, others...); + } + + /// Return true if this is one of the keyword token kinds (e.g. kw_if). + bool isKeyword() const; + + // Helpers to decode specific sorts of tokens. + + /// For an integer token, return its value as an uint64_t. If it doesn't fit, + /// return std::nullopt. + static std::optional getUInt64IntegerValue(StringRef spelling); + std::optional getUInt64IntegerValue() const { + return getUInt64IntegerValue(getSpelling()); + } + + // Location processing. + SMLoc getLoc() const; + SMLoc getEndLoc() const; + SMRange getLocRange() const; + + /// Given a punctuation or keyword token kind, return the spelling of the + /// token as a string. Warning: This will abort on markers, identifiers and + /// literal tokens since they have no fixed spelling. + static StringRef getTokenSpelling(Kind kind); + +private: + /// Discriminator that indicates the sort of token this is. + Kind kind; + + /// A reference to the entire token contents; this is always a pointer into + /// a memory buffer owned by the source manager. + StringRef spelling; +}; +} // namespace mlir::presburger + +#endif // MLIR_ANALYSIS_PRESBURGER_PARSER_TOKEN_H diff --git a/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def b/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def new file mode 100644 index 00000000000000..3285e77a409755 --- /dev/null +++ b/mlir/lib/Analysis/Presburger/Parser/TokenKinds.def @@ -0,0 +1,72 @@ +//===- TokenKinds.def - Presburger Parser Token Description -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file is intended to be #include'd multiple times to extract information +// about tokens for various clients in the lexer. +// +//===----------------------------------------------------------------------===// + +#if !defined(TOK_MARKER) && !defined(TOK_IDENTIFIER) && \ + !defined(TOK_LITERAL) && !defined(TOK_PUNCTUATION) && \ + !defined(TOK_KEYWORD) +#error Must define one of the TOK_ macros. +#endif + +#ifndef TOK_MARKER +#define TOK_MARKER(X) +#endif +#ifndef TOK_IDENTIFIER +#define TOK_IDENTIFIER(NAME) +#endif +#ifndef TOK_LITERAL +#define TOK_LITERAL(NAME) +#endif +#ifndef TOK_PUNCTUATION +#define TOK_PUNCTUATION(NAME, SPELLING) +#endif +#ifndef TOK_KEYWORD +#define TOK_KEYWORD(SPELLING) +#endif + +// Markers +TOK_MARKER(eof) +TOK_MARKER(error) + +// Identifiers. +TOK_IDENTIFIER(bare_identifier) // foo + +// Literals +TOK_LITERAL(integer) // 42 + +// Punctuation. +TOK_PUNCTUATION(arrow, "->") +TOK_PUNCTUATION(colon, ":") +TOK_PUNCTUATION(comma, ",") +TOK_PUNCTUATION(equal, "=") +TOK_PUNCTUATION(greater, ">") +TOK_PUNCTUATION(l_brace, "{") +TOK_PUNCTUATION(l_paren, "(") +TOK_PUNCTUATION(l_square, "[") +TOK_PUNCTUATION(less, "<") +TOK_PUNCTUATION(minus, "-") +TOK_PUNCTUATION(plus, "+") +TOK_PUNCTUATION(r_brace, "}") +TOK_PUNCTUATION(r_paren, ")") +TOK_PUNCTUATION(r_square, "]") +TOK_PUNCTUATION(star, "*") + +// Keywords. These turn "foo" into Token::kw_foo enums. +TOK_KEYWORD(ceildiv) +TOK_KEYWORD(floordiv) +TOK_KEYWORD(mod) + +#undef TOK_MARKER +#undef TOK_IDENTIFIER +#undef TOK_LITERAL +#undef TOK_PUNCTUATION +#undef TOK_KEYWORD diff --git a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp index 5e279b542fdf95..135697ec2d6a58 100644 --- a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp +++ b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp @@ -1,6 +1,6 @@ #include "mlir/Analysis/Presburger/Barvinok.h" -#include "./Utils.h" -#include "Parser.h" +#include "Utils.h" +#include "mlir/Analysis/Presburger/Parser.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt index b69f514711337b..d65e32917a35d7 100644 --- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt +++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt @@ -6,7 +6,6 @@ add_mlir_unittest(MLIRPresburgerTests IntegerRelationTest.cpp LinearTransformTest.cpp MatrixTest.cpp - Parser.h ParserTest.cpp PresburgerSetTest.cpp PresburgerRelationTest.cpp @@ -19,6 +18,5 @@ add_mlir_unittest(MLIRPresburgerTests target_link_libraries(MLIRPresburgerTests PRIVATE MLIRPresburger - MLIRAffineAnalysis - MLIRParser + MLIRPresburgerParser ) diff --git a/mlir/unittests/Analysis/Presburger/FractionTest.cpp b/mlir/unittests/Analysis/Presburger/FractionTest.cpp index c9fad953dacd56..92bdd57d6cf82b 100644 --- a/mlir/unittests/Analysis/Presburger/FractionTest.cpp +++ b/mlir/unittests/Analysis/Presburger/FractionTest.cpp @@ -1,5 +1,4 @@ #include "mlir/Analysis/Presburger/Fraction.h" -#include "./Utils.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp index 3fc68cddaad007..759f4d6cc5194d 100644 --- a/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp +++ b/mlir/unittests/Analysis/Presburger/GeneratingFunctionTest.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/GeneratingFunction.h" -#include "./Utils.h" +#include "Utils.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp index f64bb240b4ee48..bc0f7a0426d08f 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" #include "Utils.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/PWMAFunction.h" +#include "mlir/Analysis/Presburger/Parser.h" #include "mlir/Analysis/Presburger/Simplex.h" #include diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp index 7df500bc9568aa..a6e02f3982936b 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/IntegerRelation.h" -#include "Parser.h" +#include "mlir/Analysis/Presburger/Parser.h" #include "mlir/Analysis/Presburger/PresburgerSpace.h" #include "mlir/Analysis/Presburger/Simplex.h" diff --git a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp index cb8df8b346011a..2da49d7793b722 100644 --- a/mlir/unittests/Analysis/Presburger/MatrixTest.cpp +++ b/mlir/unittests/Analysis/Presburger/MatrixTest.cpp @@ -7,8 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/Matrix.h" -#include "./Utils.h" +#include "Utils.h" #include "mlir/Analysis/Presburger/Fraction.h" +#include "mlir/Analysis/Presburger/PresburgerRelation.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp index ee2931e78185ca..22557414b8cb09 100644 --- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp @@ -10,11 +10,8 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" - #include "mlir/Analysis/Presburger/PWMAFunction.h" -#include "mlir/Analysis/Presburger/PresburgerRelation.h" -#include "mlir/IR/MLIRContext.h" +#include "mlir/Analysis/Presburger/Parser.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/ParserTest.cpp b/mlir/unittests/Analysis/Presburger/ParserTest.cpp index 06b728cd1a8fa1..81026cf4ee0a5b 100644 --- a/mlir/unittests/Analysis/Presburger/ParserTest.cpp +++ b/mlir/unittests/Analysis/Presburger/ParserTest.cpp @@ -13,7 +13,7 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" +#include "mlir/Analysis/Presburger/Parser.h" #include diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp index ad71bb32a06880..b8e21dcf762004 100644 --- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/PresburgerRelation.h" -#include "Parser.h" #include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/Presburger/Parser.h" #include "mlir/Analysis/Presburger/Simplex.h" #include diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp index 8e31a8bb2030b6..ec265a71b38cdb 100644 --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -14,10 +14,9 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" #include "Utils.h" +#include "mlir/Analysis/Presburger/Parser.h" #include "mlir/Analysis/Presburger/PresburgerRelation.h" -#include "mlir/IR/MLIRContext.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp b/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp index a84f0234067ab7..bc3d788aeb4c33 100644 --- a/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp +++ b/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/QuasiPolynomial.h" -#include "./Utils.h" +#include "Utils.h" #include "mlir/Analysis/Presburger/Fraction.h" #include #include @@ -137,4 +137,4 @@ TEST(QuasiPolynomialTest, simplify) { {{{Fraction(1, 1), Fraction(3, 4), Fraction(5, 3)}, {Fraction(2, 1), Fraction(0, 1), Fraction(0, 1)}}, {{Fraction(1, 1), Fraction(4, 5), Fraction(6, 5)}}})); -} \ No newline at end of file +} diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp index 63d02438085559..59fe453eb6a052 100644 --- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp +++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp @@ -6,11 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "Parser.h" -#include "Utils.h" - #include "mlir/Analysis/Presburger/Simplex.h" -#include "mlir/IR/MLIRContext.h" +#include "mlir/Analysis/Presburger/Parser.h" #include #include diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h index 58b92671685404..89d1b0e9127f67 100644 --- a/mlir/unittests/Analysis/Presburger/Utils.h +++ b/mlir/unittests/Analysis/Presburger/Utils.h @@ -14,7 +14,6 @@ #define MLIR_UNITTESTS_ANALYSIS_PRESBURGER_UTILS_H #include "mlir/Analysis/Presburger/GeneratingFunction.h" -#include "mlir/Analysis/Presburger/IntegerRelation.h" #include "mlir/Analysis/Presburger/Matrix.h" #include "mlir/Analysis/Presburger/QuasiPolynomial.h"