diff --git a/flang/include/flang/Lower/Support/Utils.h b/flang/include/flang/Lower/Support/Utils.h index 1cc74521e22d8..baaf644dd6efb 100644 --- a/flang/include/flang/Lower/Support/Utils.h +++ b/flang/include/flang/Lower/Support/Utils.h @@ -14,13 +14,13 @@ #define FORTRAN_LOWER_SUPPORT_UTILS_H #include "flang/Common/indirection.h" +#include "flang/Lower/IterationSpace.h" #include "flang/Parser/char-block.h" #include "flang/Semantics/tools.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "llvm/ADT/StringRef.h" -#include namespace Fortran::lower { using SomeExpr = Fortran::evaluate::Expr; @@ -87,555 +87,13 @@ A flatZip(const A &container1, const A &container2) { } namespace Fortran::lower { -// Fortran::evaluate::Expr are functional values organized like an AST. A -// Fortran::evaluate::Expr is meant to be moved and cloned. Using the front end -// tools can often cause copies and extra wrapper classes to be added to any -// Fortran::evaluate::Expr. These values should not be assumed or relied upon to -// have an *object* identity. They are deeply recursive, irregular structures -// built from a large number of classes which do not use inheritance and -// necessitate a large volume of boilerplate code as a result. -// -// Contrastingly, LLVM data structures make ubiquitous assumptions about an -// object's identity via pointers to the object. An object's location in memory -// is thus very often an identifying relation. - -// This class defines a hash computation of a Fortran::evaluate::Expr tree value -// so it can be used with llvm::DenseMap. The Fortran::evaluate::Expr need not -// have the same address. -class HashEvaluateExpr { -public: - // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an - // identity property. - static unsigned getHashValue(const Fortran::semantics::Symbol &x) { - return static_cast(reinterpret_cast(&x)); - } - template - static unsigned getHashValue(const Fortran::common::Indirection &x) { - return getHashValue(x.value()); - } - template - static unsigned getHashValue(const std::optional &x) { - if (x.has_value()) - return getHashValue(x.value()); - return 0u; - } - static unsigned getHashValue(const Fortran::evaluate::Subscript &x) { - return Fortran::common::visit( - [&](const auto &v) { return getHashValue(v); }, x.u); - } - static unsigned getHashValue(const Fortran::evaluate::Triplet &x) { - return getHashValue(x.lower()) - getHashValue(x.upper()) * 5u - - getHashValue(x.stride()) * 11u; - } - static unsigned getHashValue(const Fortran::evaluate::Component &x) { - return getHashValue(x.base()) * 83u - getHashValue(x.GetLastSymbol()); - } - static unsigned getHashValue(const Fortran::evaluate::ArrayRef &x) { - unsigned subs = 1u; - for (const Fortran::evaluate::Subscript &v : x.subscript()) - subs -= getHashValue(v); - return getHashValue(x.base()) * 89u - subs; - } - static unsigned getHashValue(const Fortran::evaluate::CoarrayRef &x) { - unsigned subs = 1u; - for (const Fortran::evaluate::Subscript &v : x.subscript()) - subs -= getHashValue(v); - unsigned cosubs = 3u; - for (const Fortran::evaluate::Expr &v : - x.cosubscript()) - cosubs -= getHashValue(v); - unsigned syms = 7u; - for (const Fortran::evaluate::SymbolRef &v : x.base()) - syms += getHashValue(v); - return syms * 97u - subs - cosubs + getHashValue(x.stat()) + 257u + - getHashValue(x.team()); - } - static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) { - if (x.IsSymbol()) - return getHashValue(x.GetFirstSymbol()) * 11u; - return getHashValue(x.GetComponent()) * 13u; - } - static unsigned getHashValue(const Fortran::evaluate::DataRef &x) { - return Fortran::common::visit( - [&](const auto &v) { return getHashValue(v); }, x.u); - } - static unsigned getHashValue(const Fortran::evaluate::ComplexPart &x) { - return getHashValue(x.complex()) - static_cast(x.part()); - } - template - static unsigned getHashValue( - const Fortran::evaluate::Convert, TC2> - &x) { - return getHashValue(x.left()) - (static_cast(TC1) + 2u) - - (static_cast(KIND) + 5u); - } - template - static unsigned - getHashValue(const Fortran::evaluate::ComplexComponent &x) { - return getHashValue(x.left()) - - (static_cast(x.isImaginaryPart) + 1u) * 3u; - } - template - static unsigned getHashValue(const Fortran::evaluate::Parentheses &x) { - return getHashValue(x.left()) * 17u; - } - template - static unsigned getHashValue( - const Fortran::evaluate::Negate> &x) { - return getHashValue(x.left()) - (static_cast(TC) + 5u) - - (static_cast(KIND) + 7u); - } - template - static unsigned getHashValue( - const Fortran::evaluate::Add> &x) { - return (getHashValue(x.left()) + getHashValue(x.right())) * 23u + - static_cast(TC) + static_cast(KIND); - } - template - static unsigned getHashValue( - const Fortran::evaluate::Subtract> &x) { - return (getHashValue(x.left()) - getHashValue(x.right())) * 19u + - static_cast(TC) + static_cast(KIND); - } - template - static unsigned getHashValue( - const Fortran::evaluate::Multiply> &x) { - return (getHashValue(x.left()) + getHashValue(x.right())) * 29u + - static_cast(TC) + static_cast(KIND); - } - template - static unsigned getHashValue( - const Fortran::evaluate::Divide> &x) { - return (getHashValue(x.left()) - getHashValue(x.right())) * 31u + - static_cast(TC) + static_cast(KIND); - } - template - static unsigned getHashValue( - const Fortran::evaluate::Power> &x) { - return (getHashValue(x.left()) - getHashValue(x.right())) * 37u + - static_cast(TC) + static_cast(KIND); - } - template - static unsigned getHashValue( - const Fortran::evaluate::Extremum> &x) { - return (getHashValue(x.left()) + getHashValue(x.right())) * 41u + - static_cast(TC) + static_cast(KIND) + - static_cast(x.ordering) * 7u; - } - template - static unsigned getHashValue( - const Fortran::evaluate::RealToIntPower> - &x) { - return (getHashValue(x.left()) - getHashValue(x.right())) * 43u + - static_cast(TC) + static_cast(KIND); - } - template - static unsigned - getHashValue(const Fortran::evaluate::ComplexConstructor &x) { - return (getHashValue(x.left()) - getHashValue(x.right())) * 47u + - static_cast(KIND); - } - template - static unsigned getHashValue(const Fortran::evaluate::Concat &x) { - return (getHashValue(x.left()) - getHashValue(x.right())) * 53u + - static_cast(KIND); - } - template - static unsigned getHashValue(const Fortran::evaluate::SetLength &x) { - return (getHashValue(x.left()) - getHashValue(x.right())) * 59u + - static_cast(KIND); - } - static unsigned getHashValue(const Fortran::semantics::SymbolRef &sym) { - return getHashValue(sym.get()); - } - static unsigned getHashValue(const Fortran::evaluate::Substring &x) { - return 61u * - Fortran::common::visit( - [&](const auto &p) { return getHashValue(p); }, x.parent()) - - getHashValue(x.lower()) - (getHashValue(x.lower()) + 1u); - } - static unsigned - getHashValue(const Fortran::evaluate::StaticDataObject::Pointer &x) { - return llvm::hash_value(x->name()); - } - static unsigned getHashValue(const Fortran::evaluate::SpecificIntrinsic &x) { - return llvm::hash_value(x.name); - } - template - static unsigned getHashValue(const Fortran::evaluate::Constant &x) { - // FIXME: Should hash the content. - return 103u; - } - static unsigned getHashValue(const Fortran::evaluate::ActualArgument &x) { - if (const Fortran::evaluate::Symbol *sym = x.GetAssumedTypeDummy()) - return getHashValue(*sym); - return getHashValue(*x.UnwrapExpr()); - } - static unsigned - getHashValue(const Fortran::evaluate::ProcedureDesignator &x) { - return Fortran::common::visit( - [&](const auto &v) { return getHashValue(v); }, x.u); - } - static unsigned getHashValue(const Fortran::evaluate::ProcedureRef &x) { - unsigned args = 13u; - for (const std::optional &v : - x.arguments()) - args -= getHashValue(v); - return getHashValue(x.proc()) * 101u - args; - } - template - static unsigned - getHashValue(const Fortran::evaluate::ArrayConstructor &x) { - // FIXME: hash the contents. - return 127u; - } - static unsigned getHashValue(const Fortran::evaluate::ImpliedDoIndex &x) { - return llvm::hash_value(toStringRef(x.name).str()) * 131u; - } - static unsigned getHashValue(const Fortran::evaluate::TypeParamInquiry &x) { - return getHashValue(x.base()) * 137u - getHashValue(x.parameter()) * 3u; - } - static unsigned getHashValue(const Fortran::evaluate::DescriptorInquiry &x) { - return getHashValue(x.base()) * 139u - - static_cast(x.field()) * 13u + - static_cast(x.dimension()); - } - static unsigned - getHashValue(const Fortran::evaluate::StructureConstructor &x) { - // FIXME: hash the contents. - return 149u; - } - template - static unsigned getHashValue(const Fortran::evaluate::Not &x) { - return getHashValue(x.left()) * 61u + static_cast(KIND); - } - template - static unsigned - getHashValue(const Fortran::evaluate::LogicalOperation &x) { - unsigned result = getHashValue(x.left()) + getHashValue(x.right()); - return result * 67u + static_cast(x.logicalOperator) * 5u; - } - template - static unsigned getHashValue( - const Fortran::evaluate::Relational> - &x) { - return (getHashValue(x.left()) + getHashValue(x.right())) * 71u + - static_cast(TC) + static_cast(KIND) + - static_cast(x.opr) * 11u; - } - template - static unsigned getHashValue(const Fortran::evaluate::Expr &x) { - return Fortran::common::visit( - [&](const auto &v) { return getHashValue(v); }, x.u); - } - static unsigned getHashValue( - const Fortran::evaluate::Relational &x) { - return Fortran::common::visit( - [&](const auto &v) { return getHashValue(v); }, x.u); - } - template - static unsigned getHashValue(const Fortran::evaluate::Designator &x) { - return Fortran::common::visit( - [&](const auto &v) { return getHashValue(v); }, x.u); - } - template - static unsigned - getHashValue(const Fortran::evaluate::value::Integer &x) { - return static_cast(x.ToSInt()); - } - static unsigned getHashValue(const Fortran::evaluate::NullPointer &x) { - return ~179u; - } -}; +unsigned getHashValue(const Fortran::lower::SomeExpr *x); +unsigned getHashValue(const Fortran::lower::ExplicitIterSpace::ArrayBases &x); -// Define the is equals test for using Fortran::evaluate::Expr values with -// llvm::DenseMap. -class IsEqualEvaluateExpr { -public: - // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an - // identity property. - static bool isEqual(const Fortran::semantics::Symbol &x, - const Fortran::semantics::Symbol &y) { - return isEqual(&x, &y); - } - static bool isEqual(const Fortran::semantics::Symbol *x, - const Fortran::semantics::Symbol *y) { - return x == y; - } - template - static bool isEqual(const Fortran::common::Indirection &x, - const Fortran::common::Indirection &y) { - return isEqual(x.value(), y.value()); - } - template - static bool isEqual(const std::optional &x, const std::optional &y) { - if (x.has_value() && y.has_value()) - return isEqual(x.value(), y.value()); - return !x.has_value() && !y.has_value(); - } - template - static bool isEqual(const std::vector &x, const std::vector &y) { - if (x.size() != y.size()) - return false; - const std::size_t size = x.size(); - for (std::remove_const_t i = 0; i < size; ++i) - if (!isEqual(x[i], y[i])) - return false; - return true; - } - static bool isEqual(const Fortran::evaluate::Subscript &x, - const Fortran::evaluate::Subscript &y) { - return Fortran::common::visit( - [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); - } - static bool isEqual(const Fortran::evaluate::Triplet &x, - const Fortran::evaluate::Triplet &y) { - return isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) && - isEqual(x.stride(), y.stride()); - } - static bool isEqual(const Fortran::evaluate::Component &x, - const Fortran::evaluate::Component &y) { - return isEqual(x.base(), y.base()) && - isEqual(x.GetLastSymbol(), y.GetLastSymbol()); - } - static bool isEqual(const Fortran::evaluate::ArrayRef &x, - const Fortran::evaluate::ArrayRef &y) { - return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript()); - } - static bool isEqual(const Fortran::evaluate::CoarrayRef &x, - const Fortran::evaluate::CoarrayRef &y) { - return isEqual(x.base(), y.base()) && - isEqual(x.subscript(), y.subscript()) && - isEqual(x.cosubscript(), y.cosubscript()) && - isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team()); - } - static bool isEqual(const Fortran::evaluate::NamedEntity &x, - const Fortran::evaluate::NamedEntity &y) { - if (x.IsSymbol() && y.IsSymbol()) - return isEqual(x.GetFirstSymbol(), y.GetFirstSymbol()); - return !x.IsSymbol() && !y.IsSymbol() && - isEqual(x.GetComponent(), y.GetComponent()); - } - static bool isEqual(const Fortran::evaluate::DataRef &x, - const Fortran::evaluate::DataRef &y) { - return Fortran::common::visit( - [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); - } - static bool isEqual(const Fortran::evaluate::ComplexPart &x, - const Fortran::evaluate::ComplexPart &y) { - return isEqual(x.complex(), y.complex()) && x.part() == y.part(); - } - template - static bool isEqual(const Fortran::evaluate::Convert &x, - const Fortran::evaluate::Convert &y) { - return isEqual(x.left(), y.left()); - } - template - static bool isEqual(const Fortran::evaluate::ComplexComponent &x, - const Fortran::evaluate::ComplexComponent &y) { - return isEqual(x.left(), y.left()) && - x.isImaginaryPart == y.isImaginaryPart; - } - template - static bool isEqual(const Fortran::evaluate::Parentheses &x, - const Fortran::evaluate::Parentheses &y) { - return isEqual(x.left(), y.left()); - } - template - static bool isEqual(const Fortran::evaluate::Negate &x, - const Fortran::evaluate::Negate &y) { - return isEqual(x.left(), y.left()); - } - template - static bool isBinaryEqual(const A &x, const A &y) { - return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); - } - template - static bool isEqual(const Fortran::evaluate::Add &x, - const Fortran::evaluate::Add &y) { - return isBinaryEqual(x, y); - } - template - static bool isEqual(const Fortran::evaluate::Subtract &x, - const Fortran::evaluate::Subtract &y) { - return isBinaryEqual(x, y); - } - template - static bool isEqual(const Fortran::evaluate::Multiply &x, - const Fortran::evaluate::Multiply &y) { - return isBinaryEqual(x, y); - } - template - static bool isEqual(const Fortran::evaluate::Divide &x, - const Fortran::evaluate::Divide &y) { - return isBinaryEqual(x, y); - } - template - static bool isEqual(const Fortran::evaluate::Power &x, - const Fortran::evaluate::Power &y) { - return isBinaryEqual(x, y); - } - template - static bool isEqual(const Fortran::evaluate::Extremum &x, - const Fortran::evaluate::Extremum &y) { - return isBinaryEqual(x, y); - } - template - static bool isEqual(const Fortran::evaluate::RealToIntPower &x, - const Fortran::evaluate::RealToIntPower &y) { - return isBinaryEqual(x, y); - } - template - static bool isEqual(const Fortran::evaluate::ComplexConstructor &x, - const Fortran::evaluate::ComplexConstructor &y) { - return isBinaryEqual(x, y); - } - template - static bool isEqual(const Fortran::evaluate::Concat &x, - const Fortran::evaluate::Concat &y) { - return isBinaryEqual(x, y); - } - template - static bool isEqual(const Fortran::evaluate::SetLength &x, - const Fortran::evaluate::SetLength &y) { - return isBinaryEqual(x, y); - } - static bool isEqual(const Fortran::semantics::SymbolRef &x, - const Fortran::semantics::SymbolRef &y) { - return isEqual(x.get(), y.get()); - } - static bool isEqual(const Fortran::evaluate::Substring &x, - const Fortran::evaluate::Substring &y) { - return Fortran::common::visit( - [&](const auto &p, const auto &q) { return isEqual(p, q); }, - x.parent(), y.parent()) && - isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()); - } - static bool isEqual(const Fortran::evaluate::StaticDataObject::Pointer &x, - const Fortran::evaluate::StaticDataObject::Pointer &y) { - return x->name() == y->name(); - } - static bool isEqual(const Fortran::evaluate::SpecificIntrinsic &x, - const Fortran::evaluate::SpecificIntrinsic &y) { - return x.name == y.name; - } - template - static bool isEqual(const Fortran::evaluate::Constant &x, - const Fortran::evaluate::Constant &y) { - return x == y; - } - static bool isEqual(const Fortran::evaluate::ActualArgument &x, - const Fortran::evaluate::ActualArgument &y) { - if (const Fortran::evaluate::Symbol *xs = x.GetAssumedTypeDummy()) { - if (const Fortran::evaluate::Symbol *ys = y.GetAssumedTypeDummy()) - return isEqual(*xs, *ys); - return false; - } - return !y.GetAssumedTypeDummy() && - isEqual(*x.UnwrapExpr(), *y.UnwrapExpr()); - } - static bool isEqual(const Fortran::evaluate::ProcedureDesignator &x, - const Fortran::evaluate::ProcedureDesignator &y) { - return Fortran::common::visit( - [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); - } - static bool isEqual(const Fortran::evaluate::ProcedureRef &x, - const Fortran::evaluate::ProcedureRef &y) { - return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments()); - } - template - static bool isEqual(const Fortran::evaluate::ArrayConstructor &x, - const Fortran::evaluate::ArrayConstructor &y) { - llvm::report_fatal_error("not implemented"); - } - static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x, - const Fortran::evaluate::ImpliedDoIndex &y) { - return toStringRef(x.name) == toStringRef(y.name); - } - static bool isEqual(const Fortran::evaluate::TypeParamInquiry &x, - const Fortran::evaluate::TypeParamInquiry &y) { - return isEqual(x.base(), y.base()) && isEqual(x.parameter(), y.parameter()); - } - static bool isEqual(const Fortran::evaluate::DescriptorInquiry &x, - const Fortran::evaluate::DescriptorInquiry &y) { - return isEqual(x.base(), y.base()) && x.field() == y.field() && - x.dimension() == y.dimension(); - } - static bool isEqual(const Fortran::evaluate::StructureConstructor &x, - const Fortran::evaluate::StructureConstructor &y) { - const auto &xValues = x.values(); - const auto &yValues = y.values(); - if (xValues.size() != yValues.size()) - return false; - if (x.derivedTypeSpec() != y.derivedTypeSpec()) - return false; - for (const auto &[xSymbol, xValue] : xValues) { - auto yIt = yValues.find(xSymbol); - // This should probably never happen, since the derived type - // should be the same. - if (yIt == yValues.end()) - return false; - if (!isEqual(xValue, yIt->second)) - return false; - } - return true; - } - template - static bool isEqual(const Fortran::evaluate::Not &x, - const Fortran::evaluate::Not &y) { - return isEqual(x.left(), y.left()); - } - template - static bool isEqual(const Fortran::evaluate::LogicalOperation &x, - const Fortran::evaluate::LogicalOperation &y) { - return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); - } - template - static bool isEqual(const Fortran::evaluate::Relational &x, - const Fortran::evaluate::Relational &y) { - return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); - } - template - static bool isEqual(const Fortran::evaluate::Expr &x, - const Fortran::evaluate::Expr &y) { - return Fortran::common::visit( - [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); - } - static bool - isEqual(const Fortran::evaluate::Relational &x, - const Fortran::evaluate::Relational &y) { - return Fortran::common::visit( - [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); - } - template - static bool isEqual(const Fortran::evaluate::Designator &x, - const Fortran::evaluate::Designator &y) { - return Fortran::common::visit( - [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); - } - template - static bool isEqual(const Fortran::evaluate::value::Integer &x, - const Fortran::evaluate::value::Integer &y) { - return x == y; - } - static bool isEqual(const Fortran::evaluate::NullPointer &x, - const Fortran::evaluate::NullPointer &y) { - return true; - } - template , bool> = true> - static bool isEqual(const A &, const B &) { - return false; - } -}; - -static inline unsigned getHashValue(const Fortran::lower::SomeExpr *x) { - return HashEvaluateExpr::getHashValue(*x); -} - -static bool isEqual(const Fortran::lower::SomeExpr *x, - const Fortran::lower::SomeExpr *y); +bool isEqual(const Fortran::lower::SomeExpr *x, + const Fortran::lower::SomeExpr *y); +bool isEqual(const Fortran::lower::ExplicitIterSpace::ArrayBases &x, + const Fortran::lower::ExplicitIterSpace::ArrayBases &y); } // end namespace Fortran::lower // DenseMapInfo for pointers to Fortran::lower::SomeExpr. @@ -658,17 +116,4 @@ struct DenseMapInfo { }; } // namespace llvm -namespace Fortran::lower { -static inline bool isEqual(const Fortran::lower::SomeExpr *x, - const Fortran::lower::SomeExpr *y) { - const auto *empty = - llvm::DenseMapInfo::getEmptyKey(); - const auto *tombstone = - llvm::DenseMapInfo::getTombstoneKey(); - if (x == empty || y == empty || x == tombstone || y == tombstone) - return x == y; - return x == y || IsEqualEvaluateExpr::isEqual(*x, *y); -} -} // end namespace Fortran::lower - #endif // FORTRAN_LOWER_SUPPORT_UTILS_H diff --git a/flang/lib/Lower/CMakeLists.txt b/flang/lib/Lower/CMakeLists.txt index f611010765cb5..8a4ed1c067cb2 100644 --- a/flang/lib/Lower/CMakeLists.txt +++ b/flang/lib/Lower/CMakeLists.txt @@ -34,6 +34,7 @@ add_flang_library(FortranLower OpenMP/Utils.cpp PFTBuilder.cpp Runtime.cpp + Support/Utils.cpp SymbolMap.cpp VectorSubscripts.cpp diff --git a/flang/lib/Lower/IterationSpace.cpp b/flang/lib/Lower/IterationSpace.cpp index 63011483022b7..b011b3ab9a248 100644 --- a/flang/lib/Lower/IterationSpace.cpp +++ b/flang/lib/Lower/IterationSpace.cpp @@ -19,36 +19,6 @@ #define DEBUG_TYPE "flang-lower-iteration-space" -unsigned Fortran::lower::getHashValue( - const Fortran::lower::ExplicitIterSpace::ArrayBases &x) { - return Fortran::common::visit( - [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x); -} - -bool Fortran::lower::isEqual( - const Fortran::lower::ExplicitIterSpace::ArrayBases &x, - const Fortran::lower::ExplicitIterSpace::ArrayBases &y) { - return Fortran::common::visit( - Fortran::common::visitors{ - // Fortran::semantics::Symbol * are the exception here. These pointers - // have identity; if two Symbol * values are the same (different) then - // they are the same (different) logical symbol. - [&](Fortran::lower::FrontEndSymbol p, - Fortran::lower::FrontEndSymbol q) { return p == q; }, - [&](const auto *p, const auto *q) { - if constexpr (std::is_same_v) { - LLVM_DEBUG(llvm::dbgs() - << "is equal: " << p << ' ' << q << ' ' - << IsEqualEvaluateExpr::isEqual(*p, *q) << '\n'); - return IsEqualEvaluateExpr::isEqual(*p, *q); - } else { - // Different subtree types are never equal. - return false; - } - }}, - x, y); -} - namespace { /// This class can recover the base array in an expression that contains diff --git a/flang/lib/Lower/Support/Utils.cpp b/flang/lib/Lower/Support/Utils.cpp new file mode 100644 index 0000000000000..5a9a839330364 --- /dev/null +++ b/flang/lib/Lower/Support/Utils.cpp @@ -0,0 +1,605 @@ +//===-- Lower/Support/Utils.cpp -- utilities --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ +// +//===----------------------------------------------------------------------===// + +#include "flang/Lower/Support/Utils.h" + +#include "flang/Common/indirection.h" +#include "flang/Lower/IterationSpace.h" +#include "flang/Semantics/tools.h" +#include +#include +#include + +namespace Fortran::lower { +// Fortran::evaluate::Expr are functional values organized like an AST. A +// Fortran::evaluate::Expr is meant to be moved and cloned. Using the front end +// tools can often cause copies and extra wrapper classes to be added to any +// Fortran::evaluate::Expr. These values should not be assumed or relied upon to +// have an *object* identity. They are deeply recursive, irregular structures +// built from a large number of classes which do not use inheritance and +// necessitate a large volume of boilerplate code as a result. +// +// Contrastingly, LLVM data structures make ubiquitous assumptions about an +// object's identity via pointers to the object. An object's location in memory +// is thus very often an identifying relation. + +// This class defines a hash computation of a Fortran::evaluate::Expr tree value +// so it can be used with llvm::DenseMap. The Fortran::evaluate::Expr need not +// have the same address. +class HashEvaluateExpr { +public: + // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an + // identity property. + static unsigned getHashValue(const Fortran::semantics::Symbol &x) { + return static_cast(reinterpret_cast(&x)); + } + template + static unsigned getHashValue(const Fortran::common::Indirection &x) { + return getHashValue(x.value()); + } + template + static unsigned getHashValue(const std::optional &x) { + if (x.has_value()) + return getHashValue(x.value()); + return 0u; + } + static unsigned getHashValue(const Fortran::evaluate::Subscript &x) { + return Fortran::common::visit( + [&](const auto &v) { return getHashValue(v); }, x.u); + } + static unsigned getHashValue(const Fortran::evaluate::Triplet &x) { + return getHashValue(x.lower()) - getHashValue(x.upper()) * 5u - + getHashValue(x.stride()) * 11u; + } + static unsigned getHashValue(const Fortran::evaluate::Component &x) { + return getHashValue(x.base()) * 83u - getHashValue(x.GetLastSymbol()); + } + static unsigned getHashValue(const Fortran::evaluate::ArrayRef &x) { + unsigned subs = 1u; + for (const Fortran::evaluate::Subscript &v : x.subscript()) + subs -= getHashValue(v); + return getHashValue(x.base()) * 89u - subs; + } + static unsigned getHashValue(const Fortran::evaluate::CoarrayRef &x) { + unsigned subs = 1u; + for (const Fortran::evaluate::Subscript &v : x.subscript()) + subs -= getHashValue(v); + unsigned cosubs = 3u; + for (const Fortran::evaluate::Expr &v : + x.cosubscript()) + cosubs -= getHashValue(v); + unsigned syms = 7u; + for (const Fortran::evaluate::SymbolRef &v : x.base()) + syms += getHashValue(v); + return syms * 97u - subs - cosubs + getHashValue(x.stat()) + 257u + + getHashValue(x.team()); + } + static unsigned getHashValue(const Fortran::evaluate::NamedEntity &x) { + if (x.IsSymbol()) + return getHashValue(x.GetFirstSymbol()) * 11u; + return getHashValue(x.GetComponent()) * 13u; + } + static unsigned getHashValue(const Fortran::evaluate::DataRef &x) { + return Fortran::common::visit( + [&](const auto &v) { return getHashValue(v); }, x.u); + } + static unsigned getHashValue(const Fortran::evaluate::ComplexPart &x) { + return getHashValue(x.complex()) - static_cast(x.part()); + } + template + static unsigned getHashValue( + const Fortran::evaluate::Convert, TC2> + &x) { + return getHashValue(x.left()) - (static_cast(TC1) + 2u) - + (static_cast(KIND) + 5u); + } + template + static unsigned + getHashValue(const Fortran::evaluate::ComplexComponent &x) { + return getHashValue(x.left()) - + (static_cast(x.isImaginaryPart) + 1u) * 3u; + } + template + static unsigned getHashValue(const Fortran::evaluate::Parentheses &x) { + return getHashValue(x.left()) * 17u; + } + template + static unsigned getHashValue( + const Fortran::evaluate::Negate> &x) { + return getHashValue(x.left()) - (static_cast(TC) + 5u) - + (static_cast(KIND) + 7u); + } + template + static unsigned getHashValue( + const Fortran::evaluate::Add> &x) { + return (getHashValue(x.left()) + getHashValue(x.right())) * 23u + + static_cast(TC) + static_cast(KIND); + } + template + static unsigned getHashValue( + const Fortran::evaluate::Subtract> &x) { + return (getHashValue(x.left()) - getHashValue(x.right())) * 19u + + static_cast(TC) + static_cast(KIND); + } + template + static unsigned getHashValue( + const Fortran::evaluate::Multiply> &x) { + return (getHashValue(x.left()) + getHashValue(x.right())) * 29u + + static_cast(TC) + static_cast(KIND); + } + template + static unsigned getHashValue( + const Fortran::evaluate::Divide> &x) { + return (getHashValue(x.left()) - getHashValue(x.right())) * 31u + + static_cast(TC) + static_cast(KIND); + } + template + static unsigned getHashValue( + const Fortran::evaluate::Power> &x) { + return (getHashValue(x.left()) - getHashValue(x.right())) * 37u + + static_cast(TC) + static_cast(KIND); + } + template + static unsigned getHashValue( + const Fortran::evaluate::Extremum> &x) { + return (getHashValue(x.left()) + getHashValue(x.right())) * 41u + + static_cast(TC) + static_cast(KIND) + + static_cast(x.ordering) * 7u; + } + template + static unsigned getHashValue( + const Fortran::evaluate::RealToIntPower> + &x) { + return (getHashValue(x.left()) - getHashValue(x.right())) * 43u + + static_cast(TC) + static_cast(KIND); + } + template + static unsigned + getHashValue(const Fortran::evaluate::ComplexConstructor &x) { + return (getHashValue(x.left()) - getHashValue(x.right())) * 47u + + static_cast(KIND); + } + template + static unsigned getHashValue(const Fortran::evaluate::Concat &x) { + return (getHashValue(x.left()) - getHashValue(x.right())) * 53u + + static_cast(KIND); + } + template + static unsigned getHashValue(const Fortran::evaluate::SetLength &x) { + return (getHashValue(x.left()) - getHashValue(x.right())) * 59u + + static_cast(KIND); + } + static unsigned getHashValue(const Fortran::semantics::SymbolRef &sym) { + return getHashValue(sym.get()); + } + static unsigned getHashValue(const Fortran::evaluate::Substring &x) { + return 61u * + Fortran::common::visit( + [&](const auto &p) { return getHashValue(p); }, x.parent()) - + getHashValue(x.lower()) - (getHashValue(x.lower()) + 1u); + } + static unsigned + getHashValue(const Fortran::evaluate::StaticDataObject::Pointer &x) { + return llvm::hash_value(x->name()); + } + static unsigned getHashValue(const Fortran::evaluate::SpecificIntrinsic &x) { + return llvm::hash_value(x.name); + } + template + static unsigned getHashValue(const Fortran::evaluate::Constant &x) { + // FIXME: Should hash the content. + return 103u; + } + static unsigned getHashValue(const Fortran::evaluate::ActualArgument &x) { + if (const Fortran::evaluate::Symbol *sym = x.GetAssumedTypeDummy()) + return getHashValue(*sym); + return getHashValue(*x.UnwrapExpr()); + } + static unsigned + getHashValue(const Fortran::evaluate::ProcedureDesignator &x) { + return Fortran::common::visit( + [&](const auto &v) { return getHashValue(v); }, x.u); + } + static unsigned getHashValue(const Fortran::evaluate::ProcedureRef &x) { + unsigned args = 13u; + for (const std::optional &v : + x.arguments()) + args -= getHashValue(v); + return getHashValue(x.proc()) * 101u - args; + } + template + static unsigned + getHashValue(const Fortran::evaluate::ArrayConstructor &x) { + // FIXME: hash the contents. + return 127u; + } + static unsigned getHashValue(const Fortran::evaluate::ImpliedDoIndex &x) { + return llvm::hash_value(toStringRef(x.name).str()) * 131u; + } + static unsigned getHashValue(const Fortran::evaluate::TypeParamInquiry &x) { + return getHashValue(x.base()) * 137u - getHashValue(x.parameter()) * 3u; + } + static unsigned getHashValue(const Fortran::evaluate::DescriptorInquiry &x) { + return getHashValue(x.base()) * 139u - + static_cast(x.field()) * 13u + + static_cast(x.dimension()); + } + static unsigned + getHashValue(const Fortran::evaluate::StructureConstructor &x) { + // FIXME: hash the contents. + return 149u; + } + template + static unsigned getHashValue(const Fortran::evaluate::Not &x) { + return getHashValue(x.left()) * 61u + static_cast(KIND); + } + template + static unsigned + getHashValue(const Fortran::evaluate::LogicalOperation &x) { + unsigned result = getHashValue(x.left()) + getHashValue(x.right()); + return result * 67u + static_cast(x.logicalOperator) * 5u; + } + template + static unsigned getHashValue( + const Fortran::evaluate::Relational> + &x) { + return (getHashValue(x.left()) + getHashValue(x.right())) * 71u + + static_cast(TC) + static_cast(KIND) + + static_cast(x.opr) * 11u; + } + template + static unsigned getHashValue(const Fortran::evaluate::Expr &x) { + return Fortran::common::visit( + [&](const auto &v) { return getHashValue(v); }, x.u); + } + static unsigned getHashValue( + const Fortran::evaluate::Relational &x) { + return Fortran::common::visit( + [&](const auto &v) { return getHashValue(v); }, x.u); + } + template + static unsigned getHashValue(const Fortran::evaluate::Designator &x) { + return Fortran::common::visit( + [&](const auto &v) { return getHashValue(v); }, x.u); + } + template + static unsigned + getHashValue(const Fortran::evaluate::value::Integer &x) { + return static_cast(x.ToSInt()); + } + static unsigned getHashValue(const Fortran::evaluate::NullPointer &x) { + return ~179u; + } +}; + +// Define the is equals test for using Fortran::evaluate::Expr values with +// llvm::DenseMap. +class IsEqualEvaluateExpr { +public: + // A Se::Symbol is the only part of an Fortran::evaluate::Expr with an + // identity property. + static bool isEqual(const Fortran::semantics::Symbol &x, + const Fortran::semantics::Symbol &y) { + return isEqual(&x, &y); + } + static bool isEqual(const Fortran::semantics::Symbol *x, + const Fortran::semantics::Symbol *y) { + return x == y; + } + template + static bool isEqual(const Fortran::common::Indirection &x, + const Fortran::common::Indirection &y) { + return isEqual(x.value(), y.value()); + } + template + static bool isEqual(const std::optional &x, const std::optional &y) { + if (x.has_value() && y.has_value()) + return isEqual(x.value(), y.value()); + return !x.has_value() && !y.has_value(); + } + template + static bool isEqual(const std::vector &x, const std::vector &y) { + if (x.size() != y.size()) + return false; + const std::size_t size = x.size(); + for (std::remove_const_t i = 0; i < size; ++i) + if (!isEqual(x[i], y[i])) + return false; + return true; + } + static bool isEqual(const Fortran::evaluate::Subscript &x, + const Fortran::evaluate::Subscript &y) { + return Fortran::common::visit( + [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); + } + static bool isEqual(const Fortran::evaluate::Triplet &x, + const Fortran::evaluate::Triplet &y) { + return isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()) && + isEqual(x.stride(), y.stride()); + } + static bool isEqual(const Fortran::evaluate::Component &x, + const Fortran::evaluate::Component &y) { + return isEqual(x.base(), y.base()) && + isEqual(x.GetLastSymbol(), y.GetLastSymbol()); + } + static bool isEqual(const Fortran::evaluate::ArrayRef &x, + const Fortran::evaluate::ArrayRef &y) { + return isEqual(x.base(), y.base()) && isEqual(x.subscript(), y.subscript()); + } + static bool isEqual(const Fortran::evaluate::CoarrayRef &x, + const Fortran::evaluate::CoarrayRef &y) { + return isEqual(x.base(), y.base()) && + isEqual(x.subscript(), y.subscript()) && + isEqual(x.cosubscript(), y.cosubscript()) && + isEqual(x.stat(), y.stat()) && isEqual(x.team(), y.team()); + } + static bool isEqual(const Fortran::evaluate::NamedEntity &x, + const Fortran::evaluate::NamedEntity &y) { + if (x.IsSymbol() && y.IsSymbol()) + return isEqual(x.GetFirstSymbol(), y.GetFirstSymbol()); + return !x.IsSymbol() && !y.IsSymbol() && + isEqual(x.GetComponent(), y.GetComponent()); + } + static bool isEqual(const Fortran::evaluate::DataRef &x, + const Fortran::evaluate::DataRef &y) { + return Fortran::common::visit( + [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); + } + static bool isEqual(const Fortran::evaluate::ComplexPart &x, + const Fortran::evaluate::ComplexPart &y) { + return isEqual(x.complex(), y.complex()) && x.part() == y.part(); + } + template + static bool isEqual(const Fortran::evaluate::Convert &x, + const Fortran::evaluate::Convert &y) { + return isEqual(x.left(), y.left()); + } + template + static bool isEqual(const Fortran::evaluate::ComplexComponent &x, + const Fortran::evaluate::ComplexComponent &y) { + return isEqual(x.left(), y.left()) && + x.isImaginaryPart == y.isImaginaryPart; + } + template + static bool isEqual(const Fortran::evaluate::Parentheses &x, + const Fortran::evaluate::Parentheses &y) { + return isEqual(x.left(), y.left()); + } + template + static bool isEqual(const Fortran::evaluate::Negate &x, + const Fortran::evaluate::Negate &y) { + return isEqual(x.left(), y.left()); + } + template + static bool isBinaryEqual(const A &x, const A &y) { + return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); + } + template + static bool isEqual(const Fortran::evaluate::Add &x, + const Fortran::evaluate::Add &y) { + return isBinaryEqual(x, y); + } + template + static bool isEqual(const Fortran::evaluate::Subtract &x, + const Fortran::evaluate::Subtract &y) { + return isBinaryEqual(x, y); + } + template + static bool isEqual(const Fortran::evaluate::Multiply &x, + const Fortran::evaluate::Multiply &y) { + return isBinaryEqual(x, y); + } + template + static bool isEqual(const Fortran::evaluate::Divide &x, + const Fortran::evaluate::Divide &y) { + return isBinaryEqual(x, y); + } + template + static bool isEqual(const Fortran::evaluate::Power &x, + const Fortran::evaluate::Power &y) { + return isBinaryEqual(x, y); + } + template + static bool isEqual(const Fortran::evaluate::Extremum &x, + const Fortran::evaluate::Extremum &y) { + return isBinaryEqual(x, y); + } + template + static bool isEqual(const Fortran::evaluate::RealToIntPower &x, + const Fortran::evaluate::RealToIntPower &y) { + return isBinaryEqual(x, y); + } + template + static bool isEqual(const Fortran::evaluate::ComplexConstructor &x, + const Fortran::evaluate::ComplexConstructor &y) { + return isBinaryEqual(x, y); + } + template + static bool isEqual(const Fortran::evaluate::Concat &x, + const Fortran::evaluate::Concat &y) { + return isBinaryEqual(x, y); + } + template + static bool isEqual(const Fortran::evaluate::SetLength &x, + const Fortran::evaluate::SetLength &y) { + return isBinaryEqual(x, y); + } + static bool isEqual(const Fortran::semantics::SymbolRef &x, + const Fortran::semantics::SymbolRef &y) { + return isEqual(x.get(), y.get()); + } + static bool isEqual(const Fortran::evaluate::Substring &x, + const Fortran::evaluate::Substring &y) { + return Fortran::common::visit( + [&](const auto &p, const auto &q) { return isEqual(p, q); }, + x.parent(), y.parent()) && + isEqual(x.lower(), y.lower()) && isEqual(x.upper(), y.upper()); + } + static bool isEqual(const Fortran::evaluate::StaticDataObject::Pointer &x, + const Fortran::evaluate::StaticDataObject::Pointer &y) { + return x->name() == y->name(); + } + static bool isEqual(const Fortran::evaluate::SpecificIntrinsic &x, + const Fortran::evaluate::SpecificIntrinsic &y) { + return x.name == y.name; + } + template + static bool isEqual(const Fortran::evaluate::Constant &x, + const Fortran::evaluate::Constant &y) { + return x == y; + } + static bool isEqual(const Fortran::evaluate::ActualArgument &x, + const Fortran::evaluate::ActualArgument &y) { + if (const Fortran::evaluate::Symbol *xs = x.GetAssumedTypeDummy()) { + if (const Fortran::evaluate::Symbol *ys = y.GetAssumedTypeDummy()) + return isEqual(*xs, *ys); + return false; + } + return !y.GetAssumedTypeDummy() && + isEqual(*x.UnwrapExpr(), *y.UnwrapExpr()); + } + static bool isEqual(const Fortran::evaluate::ProcedureDesignator &x, + const Fortran::evaluate::ProcedureDesignator &y) { + return Fortran::common::visit( + [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); + } + static bool isEqual(const Fortran::evaluate::ProcedureRef &x, + const Fortran::evaluate::ProcedureRef &y) { + return isEqual(x.proc(), y.proc()) && isEqual(x.arguments(), y.arguments()); + } + template + static bool isEqual(const Fortran::evaluate::ArrayConstructor &x, + const Fortran::evaluate::ArrayConstructor &y) { + llvm::report_fatal_error("not implemented"); + } + static bool isEqual(const Fortran::evaluate::ImpliedDoIndex &x, + const Fortran::evaluate::ImpliedDoIndex &y) { + return toStringRef(x.name) == toStringRef(y.name); + } + static bool isEqual(const Fortran::evaluate::TypeParamInquiry &x, + const Fortran::evaluate::TypeParamInquiry &y) { + return isEqual(x.base(), y.base()) && isEqual(x.parameter(), y.parameter()); + } + static bool isEqual(const Fortran::evaluate::DescriptorInquiry &x, + const Fortran::evaluate::DescriptorInquiry &y) { + return isEqual(x.base(), y.base()) && x.field() == y.field() && + x.dimension() == y.dimension(); + } + static bool isEqual(const Fortran::evaluate::StructureConstructor &x, + const Fortran::evaluate::StructureConstructor &y) { + const auto &xValues = x.values(); + const auto &yValues = y.values(); + if (xValues.size() != yValues.size()) + return false; + if (x.derivedTypeSpec() != y.derivedTypeSpec()) + return false; + for (const auto &[xSymbol, xValue] : xValues) { + auto yIt = yValues.find(xSymbol); + // This should probably never happen, since the derived type + // should be the same. + if (yIt == yValues.end()) + return false; + if (!isEqual(xValue, yIt->second)) + return false; + } + return true; + } + template + static bool isEqual(const Fortran::evaluate::Not &x, + const Fortran::evaluate::Not &y) { + return isEqual(x.left(), y.left()); + } + template + static bool isEqual(const Fortran::evaluate::LogicalOperation &x, + const Fortran::evaluate::LogicalOperation &y) { + return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); + } + template + static bool isEqual(const Fortran::evaluate::Relational &x, + const Fortran::evaluate::Relational &y) { + return isEqual(x.left(), y.left()) && isEqual(x.right(), y.right()); + } + template + static bool isEqual(const Fortran::evaluate::Expr &x, + const Fortran::evaluate::Expr &y) { + return Fortran::common::visit( + [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); + } + static bool + isEqual(const Fortran::evaluate::Relational &x, + const Fortran::evaluate::Relational &y) { + return Fortran::common::visit( + [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); + } + template + static bool isEqual(const Fortran::evaluate::Designator &x, + const Fortran::evaluate::Designator &y) { + return Fortran::common::visit( + [&](const auto &v, const auto &w) { return isEqual(v, w); }, x.u, y.u); + } + template + static bool isEqual(const Fortran::evaluate::value::Integer &x, + const Fortran::evaluate::value::Integer &y) { + return x == y; + } + static bool isEqual(const Fortran::evaluate::NullPointer &x, + const Fortran::evaluate::NullPointer &y) { + return true; + } + template , bool> = true> + static bool isEqual(const A &, const B &) { + return false; + } +}; + +unsigned getHashValue(const Fortran::lower::SomeExpr *x) { + return HashEvaluateExpr::getHashValue(*x); +} + +unsigned getHashValue(const Fortran::lower::ExplicitIterSpace::ArrayBases &x) { + return Fortran::common::visit( + [&](const auto *p) { return HashEvaluateExpr::getHashValue(*p); }, x); +} + +bool isEqual(const Fortran::lower::SomeExpr *x, + const Fortran::lower::SomeExpr *y) { + const auto *empty = + llvm::DenseMapInfo::getEmptyKey(); + const auto *tombstone = + llvm::DenseMapInfo::getTombstoneKey(); + if (x == empty || y == empty || x == tombstone || y == tombstone) + return x == y; + return x == y || IsEqualEvaluateExpr::isEqual(*x, *y); +} + +bool isEqual(const Fortran::lower::ExplicitIterSpace::ArrayBases &x, + const Fortran::lower::ExplicitIterSpace::ArrayBases &y) { + return Fortran::common::visit( + Fortran::common::visitors{ + // Fortran::semantics::Symbol * are the exception here. These pointers + // have identity; if two Symbol * values are the same (different) then + // they are the same (different) logical symbol. + [&](Fortran::lower::FrontEndSymbol p, + Fortran::lower::FrontEndSymbol q) { return p == q; }, + [&](const auto *p, const auto *q) { + if constexpr (std::is_same_v) { + return IsEqualEvaluateExpr::isEqual(*p, *q); + } else { + // Different subtree types are never equal. + return false; + } + }}, + x, y); +} +} // end namespace Fortran::lower