Skip to content

Commit

Permalink
[flang] Fix UBOUND() folding for constant arrays
Browse files Browse the repository at this point in the history
Similarly to LBOUND in https://reviews.llvm.org/D123237, fix UBOUND() folding
for constant arrays (for both w/ and w/o DIM=): convert
GetConstantArrayLboundHelper into common helper class for both lower/upper
bounds.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D123520
  • Loading branch information
FruitClover committed Apr 27, 2022
1 parent 88bc24a commit d8b4ea4
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 20 deletions.
1 change: 1 addition & 0 deletions flang/include/flang/Evaluate/constant.h
Expand Up @@ -64,6 +64,7 @@ class ConstantBounds {
~ConstantBounds();
const ConstantSubscripts &shape() const { return shape_; }
const ConstantSubscripts &lbounds() const { return lbounds_; }
ConstantSubscripts ComputeUbounds(std::optional<int> dim) const;
void set_lbounds(ConstantSubscripts &&);
void SetLowerBoundsToOne();
int Rank() const { return GetRank(shape_); }
Expand Down
14 changes: 14 additions & 0 deletions flang/lib/Evaluate/constant.cpp
Expand Up @@ -32,6 +32,20 @@ void ConstantBounds::set_lbounds(ConstantSubscripts &&lb) {
}
}

ConstantSubscripts ConstantBounds::ComputeUbounds(
std::optional<int> dim) const {
if (dim) {
CHECK(*dim < Rank());
return {lbounds_[*dim] + shape_[*dim] - 1};
} else {
ConstantSubscripts ubounds(Rank());
for (int i{0}; i < Rank(); ++i) {
ubounds[i] = lbounds_[i] + shape_[i] - 1;
}
return ubounds;
}
}

void ConstantBounds::SetLowerBoundsToOne() {
for (auto &n : lbounds_) {
n = 1;
Expand Down
59 changes: 41 additions & 18 deletions flang/lib/Evaluate/fold-integer.cpp
Expand Up @@ -29,44 +29,66 @@ Expr<T> PackageConstantBounds(
}
}

// Class to retrieve the constant lower bound of an expression which is an
// Class to retrieve the constant bound of an expression which is an
// array that devolves to a type of Constant<T>
class GetConstantArrayLboundHelper {
class GetConstantArrayBoundHelper {
public:
GetConstantArrayLboundHelper(std::optional<ConstantSubscript> dim)
: dim_{dim} {}
template <typename T>
static Expr<T> GetLbound(
const Expr<SomeType> &array, std::optional<int> dim) {
return PackageConstantBounds<T>(
GetConstantArrayBoundHelper(dim, /*getLbound=*/true).Get(array),
dim.has_value());
}

template <typename T>
static Expr<T> GetUbound(
const Expr<SomeType> &array, std::optional<int> dim) {
return PackageConstantBounds<T>(
GetConstantArrayBoundHelper(dim, /*getLbound=*/false).Get(array),
dim.has_value());
}

private:
GetConstantArrayBoundHelper(
std::optional<ConstantSubscript> dim, bool getLbound)
: dim_{dim}, getLbound_{getLbound} {}

template <typename T> ConstantSubscripts GetLbound(const T &) {
template <typename T> ConstantSubscripts Get(const T &) {
// The method is needed for template expansion, but we should never get
// here in practice.
CHECK(false);
return {0};
}

template <typename T> ConstantSubscripts GetLbound(const Constant<T> &x) {
// Return the lower bound
if (dim_) {
return {x.lbounds().at(*dim_)};
template <typename T> ConstantSubscripts Get(const Constant<T> &x) {
if (getLbound_) {
// Return the lower bound
if (dim_) {
return {x.lbounds().at(*dim_)};
} else {
return x.lbounds();
}
} else {
return x.lbounds();
return x.ComputeUbounds(dim_);
}
}

template <typename T> ConstantSubscripts GetLbound(const Parentheses<T> &x) {
template <typename T> ConstantSubscripts Get(const Parentheses<T> &x) {
// LBOUND for (x) is [1, ..., 1] cause of temp variable inside
// parentheses (lower bound is omitted, the default value is 1).
return ConstantSubscripts(x.Rank(), ConstantSubscript{1});
}

template <typename T> ConstantSubscripts GetLbound(const Expr<T> &x) {
template <typename T> ConstantSubscripts Get(const Expr<T> &x) {
// recurse through Expr<T>'a until we hit a constant
return common::visit([&](const auto &inner) { return GetLbound(inner); },
return common::visit([&](const auto &inner) { return Get(inner); },
// [&](const auto &) { return 0; },
x.u);
}

private:
std::optional<ConstantSubscript> dim_;
const std::optional<ConstantSubscript> dim_;
const bool getLbound_;
};

template <int KIND>
Expand Down Expand Up @@ -112,9 +134,7 @@ Expr<Type<TypeCategory::Integer, KIND>> LBOUND(FoldingContext &context,
}
}
if (IsActuallyConstant(*array)) {
const ConstantSubscripts bounds{
GetConstantArrayLboundHelper{dim}.GetLbound(*array)};
return PackageConstantBounds<T>(std::move(bounds), dim.has_value());
return GetConstantArrayBoundHelper::GetLbound<T>(*array, dim);
}
if (lowerBoundsAreOne) {
ConstantSubscripts ones(rank, ConstantSubscript{1});
Expand Down Expand Up @@ -178,6 +198,9 @@ Expr<Type<TypeCategory::Integer, KIND>> UBOUND(FoldingContext &context,
takeBoundsFromShape = symbol.Rank() == 0; // UBOUND(array%component)
}
}
if (IsActuallyConstant(*array)) {
return GetConstantArrayBoundHelper::GetUbound<T>(*array, dim);
}
if (takeBoundsFromShape) {
if (auto shape{GetContextFreeShape(context, *array)}) {
if (dim) {
Expand Down
14 changes: 12 additions & 2 deletions flang/test/Evaluate/folding08.f90
Expand Up @@ -77,23 +77,33 @@ subroutine test2
end block
end associate
end subroutine
subroutine test3_lbound_parameter
! Test lbound with constant arrays
subroutine test3_bound_parameter
! Test [ul]bound with parameter arrays
integer, parameter :: a1(1) = 0
integer, parameter :: lba1(*) = lbound(a1)
logical, parameter :: test_lba1 = all(lba1 == [1])
integer, parameter :: uba1(*) = ubound(a1)
logical, parameter :: test_uba1 = all(lba1 == [1])

integer, parameter :: a2(0:1) = 0
integer, parameter :: lba2(*) = lbound(a2)
logical, parameter :: test_lba2 = all(lba2 == [0])
integer, parameter :: uba2(*) = ubound(a2)
logical, parameter :: test_uba2 = all(uba2 == [1])

integer, parameter :: a3(-10:-5,1,4:6) = 0
integer, parameter :: lba3(*) = lbound(a3)
logical, parameter :: test_lba3 = all(lba3 == [-10, 1, 4])
integer, parameter :: uba3(*) = ubound(a3)
logical, parameter :: test_uba3 = all(uba3 == [-5, 1, 6])

! Exercise with DIM=
logical, parameter :: test_lba3_dim = lbound(a3, 1) == -10 .and. &
lbound(a3, 2) == 1 .and. &
lbound(a3, 3) == 4
logical, parameter :: test_uba3_dim = ubound(a3, 1) == -5 .and. &
ubound(a3, 2) == 1 .and. &
ubound(a3, 3) == 6
end subroutine
subroutine test4_lbound_parentheses
! Test lbound with (x) expressions
Expand Down

0 comments on commit d8b4ea4

Please sign in to comment.