Skip to content

Commit

Permalink
[flang] Check constant arguments to bit manipulation intrinsics even …
Browse files Browse the repository at this point in the history
…if not foldable

When some arguments that specify bit positions, shift counts, and field sizes are
constant at compilation time, but other arguments are not constant, the compiler
should still validate the constant ones.  In the current sources, validation is
only performed for intrinsic references that can be folded to constants.

Differential Revision: https://reviews.llvm.org/D140152
  • Loading branch information
klausler committed Dec 17, 2022
1 parent 8f0df9f commit 4244cab
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 105 deletions.
238 changes: 167 additions & 71 deletions flang/lib/Evaluate/fold-integer.cpp
Expand Up @@ -577,6 +577,21 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
name == "dshiftl" ? &Scalar<T>::DSHIFTL : &Scalar<T>::DSHIFTR};
// Third argument can be of any kind. However, it must be smaller or equal
// than BIT_SIZE. It can be converted to Int4 to simplify.
if (const auto *shiftCon{Folder<Int4>(context).Folding(args[2])}) {
for (const auto &scalar : shiftCon->values()) {
std::int64_t shiftVal{scalar.ToInt64()};
if (shiftVal < 0) {
context.messages().Say("SHIFT=%jd count for %s is negative"_err_en_US,
std::intmax_t{shiftVal}, name);
break;
} else if (shiftVal > T::Scalar::bits) {
context.messages().Say(
"SHIFT=%jd count for %s is greater than %d"_err_en_US,
std::intmax_t{shiftVal}, name, T::Scalar::bits);
break;
}
}
}
return FoldElementalIntrinsic<T, T, T, Int4>(context, std::move(funcRef),
ScalarFunc<T, T, T, Int4>(
[&fptr](const Scalar<T> &i, const Scalar<T> &j,
Expand Down Expand Up @@ -662,42 +677,69 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else {
common::die("missing case to fold intrinsic function %s", name.c_str());
}
if (const auto *posCon{Folder<Int4>(context).Folding(args[1])}) {
for (const auto &scalar : posCon->values()) {
std::int64_t posVal{scalar.ToInt64()};
if (posVal < 0) {
context.messages().Say(
"bit position for %s (%jd) is negative"_err_en_US, name,
std::intmax_t{posVal});
break;
} else if (posVal >= T::Scalar::bits) {
context.messages().Say(
"bit position for %s (%jd) is not less than %d"_err_en_US, name,
std::intmax_t{posVal}, T::Scalar::bits);
break;
}
}
}
return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
const Scalar<Int4> &pos) -> Scalar<T> {
auto posVal{static_cast<int>(pos.ToInt64())};
if (posVal < 0) {
context.messages().Say(
"bit position for %s (%d) is negative"_err_en_US, name, posVal);
} else if (posVal >= i.bits) {
context.messages().Say(
"bit position for %s (%d) is not less than %d"_err_en_US, name,
posVal, i.bits);
}
return std::invoke(fptr, i, posVal);
}));
ScalarFunc<T, T, Int4>(
[&](const Scalar<T> &i, const Scalar<Int4> &pos) -> Scalar<T> {
return std::invoke(fptr, i, static_cast<int>(pos.ToInt64()));
}));
} else if (name == "ibits") {
const auto *posCon{Folder<Int4>(context).Folding(args[1])};
const auto *lenCon{Folder<Int4>(context).Folding(args[2])};
if (posCon && lenCon &&
(posCon->size() == 1 || lenCon->size() == 1 ||
posCon->size() == lenCon->size())) {
auto posIter{posCon->values().begin()};
auto lenIter{lenCon->values().begin()};
for (; posIter != posCon->values().end() &&
lenIter != lenCon->values().end();
++posIter, ++lenIter) {
posIter = posIter == posCon->values().end() ? posCon->values().begin()
: posIter;
lenIter = lenIter == lenCon->values().end() ? lenCon->values().begin()
: lenIter;
auto posVal{static_cast<int>(posIter->ToInt64())};
auto lenVal{static_cast<int>(lenIter->ToInt64())};
if (posVal < 0) {
context.messages().Say(
"bit position for IBITS(POS=%jd,LEN=%jd) is negative"_err_en_US,
std::intmax_t{posVal}, std::intmax_t{lenVal});
break;
} else if (lenVal < 0) {
context.messages().Say(
"bit length for IBITS(POS=%jd,LEN=%jd) is negative"_err_en_US,
std::intmax_t{posVal}, std::intmax_t{lenVal});
break;
} else if (posVal + lenVal > T::Scalar::bits) {
context.messages().Say(
"IBITS(POS=%jd,LEN=%jd) must have POS+LEN no greater than %d"_err_en_US,
std::intmax_t{posVal}, std::intmax_t{lenVal}, T::Scalar::bits);
break;
}
}
}
return FoldElementalIntrinsic<T, T, Int4, Int4>(context, std::move(funcRef),
ScalarFunc<T, T, Int4, Int4>([&](const Scalar<T> &i,
const Scalar<Int4> &pos,
const Scalar<Int4> &len) -> Scalar<T> {
auto posVal{static_cast<int>(pos.ToInt64())};
auto lenVal{static_cast<int>(len.ToInt64())};
if (posVal < 0) {
context.messages().Say(
"bit position for IBITS(POS=%d,LEN=%d) is negative"_err_en_US,
posVal, lenVal);
} else if (lenVal < 0) {
context.messages().Say(
"bit length for IBITS(POS=%d,LEN=%d) is negative"_err_en_US,
posVal, lenVal);
} else if (posVal + lenVal > i.bits) {
context.messages().Say(
"IBITS(POS=%d,LEN=%d) must have POS+LEN no greater than %d"_err_en_US,
posVal + lenVal, i.bits);
}
return i.IBITS(posVal, lenVal);
}));
ScalarFunc<T, T, Int4, Int4>(
[&](const Scalar<T> &i, const Scalar<Int4> &pos,
const Scalar<Int4> &len) -> Scalar<T> {
return i.IBITS(static_cast<int>(pos.ToInt64()),
static_cast<int>(len.ToInt64()));
}));
} else if (name == "index" || name == "scan" || name == "verify") {
if (auto *charExpr{UnwrapExpr<Expr<SomeCharacter>>(args[0])}) {
return common::visit(
Expand Down Expand Up @@ -761,24 +803,79 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else if (name == "iparity") {
return FoldBitReduction(
context, std::move(funcRef), &Scalar<T>::IEOR, Scalar<T>{});
} else if (name == "ishft") {
return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
ScalarFunc<T, T, Int4>(
[&](const Scalar<T> &i, const Scalar<Int4> &pos) -> Scalar<T> {
auto posVal{static_cast<int>(pos.ToInt64())};
if (posVal < -i.bits) {
context.messages().Say(
"SHIFT=%d count for ishft is less than %d"_err_en_US,
posVal, -i.bits);
} else if (posVal > i.bits) {
context.messages().Say(
"SHIFT=%d count for ishft is greater than %d"_err_en_US,
posVal, i.bits);
}
return i.ISHFT(posVal);
}));
} else if (name == "ishftc") {
if (args.at(2)) { // SIZE= is present
} else if (name == "ishft" || name == "ishftc") {
const auto *shiftCon{Folder<Int4>(context).Folding(args[1])};
if (shiftCon) {
for (const auto &scalar : shiftCon->values()) {
std::int64_t shiftVal{scalar.ToInt64()};
if (shiftVal < -T::Scalar::bits) {
context.messages().Say(
"SHIFT=%jd count for %s is less than %d"_err_en_US,
std::intmax_t{shiftVal}, name, -T::Scalar::bits);
break;
} else if (shiftVal > T::Scalar::bits) {
context.messages().Say(
"SHIFT=%jd count for %s is greater than %d"_err_en_US,
std::intmax_t{shiftVal}, name, T::Scalar::bits);
break;
}
}
}
if (args.size() == 3) { // ISHFTC
if (const auto *sizeCon{Folder<Int4>(context).Folding(args[2])}) {
for (const auto &scalar : sizeCon->values()) {
std::int64_t sizeVal{scalar.ToInt64()};
if (sizeVal <= 0) {
context.messages().Say(
"SIZE=%jd count for ishftc is not positive"_err_en_US,
std::intmax_t{sizeVal}, name);
break;
} else if (sizeVal > T::Scalar::bits) {
context.messages().Say(
"SIZE=%jd count for ishftc is greater than %d"_err_en_US,
std::intmax_t{sizeVal}, T::Scalar::bits);
break;
}
}
if (shiftCon &&
(shiftCon->size() == 1 || sizeCon->size() == 1 ||
shiftCon->size() == sizeCon->size())) {
auto shiftIter{shiftCon->values().begin()};
auto sizeIter{sizeCon->values().begin()};
for (; shiftIter != shiftCon->values().end() &&
sizeIter != sizeCon->values().end();
++shiftIter, ++sizeIter) {
shiftIter = shiftIter == shiftCon->values().end()
? shiftCon->values().begin()
: shiftIter;
sizeIter = sizeIter == sizeCon->values().end()
? sizeCon->values().begin()
: sizeIter;
auto shiftVal{static_cast<int>(shiftIter->ToInt64())};
auto sizeVal{static_cast<int>(sizeIter->ToInt64())};
if (sizeVal > 0 && std::abs(shiftVal) > sizeVal) {
context.messages().Say(
"SHIFT=%jd count for ishftc is greater in magnitude than SIZE=%jd"_err_en_US,
std::intmax_t{shiftVal}, std::intmax_t{sizeVal});
break;
}
}
}
}
}
if (name == "ishft") {
return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
ScalarFunc<T, T, Int4>(
[&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> {
return i.ISHFT(static_cast<int>(shift.ToInt64()));
}));
} else if (!args.at(2)) { // ISHFTC(no SIZE=)
return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
ScalarFunc<T, T, Int4>(
[&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> {
return i.ISHFTC(static_cast<int>(shift.ToInt64()));
}));
} else { // ISHFTC(with SIZE=)
return FoldElementalIntrinsic<T, T, Int4, Int4>(context,
std::move(funcRef),
ScalarFunc<T, T, Int4, Int4>(
Expand All @@ -789,13 +886,6 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
auto sizeVal{static_cast<int>(size.ToInt64())};
return i.ISHFTC(shiftVal, sizeVal);
}));
} else { // no SIZE=
return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
ScalarFunc<T, T, Int4>(
[&](const Scalar<T> &i, const Scalar<Int4> &count) -> Scalar<T> {
auto countVal{static_cast<int>(count.ToInt64())};
return i.ISHFTC(countVal);
}));
}
} else if (name == "izext" || name == "jzext") {
if (args.size() == 1) {
Expand Down Expand Up @@ -1045,20 +1135,26 @@ Expr<Type<TypeCategory::Integer, KIND>> FoldIntrinsicFunction(
} else {
common::die("missing case to fold intrinsic function %s", name.c_str());
}
if (const auto *shiftCon{Folder<Int4>(context).Folding(args[1])}) {
for (const auto &scalar : shiftCon->values()) {
std::int64_t shiftVal{scalar.ToInt64()};
if (shiftVal < 0) {
context.messages().Say("SHIFT=%jd count for %s is negative"_err_en_US,
std::intmax_t{shiftVal}, name, -T::Scalar::bits);
break;
} else if (shiftVal > T::Scalar::bits) {
context.messages().Say(
"SHIFT=%jd count for %s is greater than %d"_err_en_US,
std::intmax_t{shiftVal}, name, T::Scalar::bits);
break;
}
}
}
return FoldElementalIntrinsic<T, T, Int4>(context, std::move(funcRef),
ScalarFunc<T, T, Int4>([&](const Scalar<T> &i,
const Scalar<Int4> &pos) -> Scalar<T> {
auto posVal{static_cast<int>(pos.ToInt64())};
if (posVal < 0) {
context.messages().Say(
"SHIFT=%d count for %s is negative"_err_en_US, posVal, name);
} else if (posVal > i.bits) {
context.messages().Say(
"SHIFT=%d count for %s is greater than %d"_err_en_US, posVal,
name, i.bits);
}
return std::invoke(fptr, i, posVal);
}));
ScalarFunc<T, T, Int4>(
[&](const Scalar<T> &i, const Scalar<Int4> &shift) -> Scalar<T> {
return std::invoke(fptr, i, static_cast<int>(shift.ToInt64()));
}));
} else if (name == "sign") {
return FoldElementalIntrinsic<T, T, T>(context, std::move(funcRef),
ScalarFunc<T, T, T>([&context](const Scalar<T> &j,
Expand Down
18 changes: 0 additions & 18 deletions flang/lib/Evaluate/intrinsics.cpp
Expand Up @@ -2918,24 +2918,6 @@ static bool ApplySpecificChecks(SpecificCall &call, FoldingContext &context) {
if (const auto &arg{call.arguments[0]}) {
ok = CheckForNonPositiveValues(context, *arg, name, "image");
}
} else if (name == "ishftc") {
if (const auto &sizeArg{call.arguments[2]}) {
ok = CheckForNonPositiveValues(context, *sizeArg, name, "size");
if (ok) {
if (auto sizeVal{ToInt64(sizeArg->UnwrapExpr())}) {
if (const auto &shiftArg{call.arguments[1]}) {
if (auto shiftVal{ToInt64(shiftArg->UnwrapExpr())}) {
if (std::abs(*shiftVal) > *sizeVal) {
ok = false;
context.messages().Say(shiftArg->sourceLocation(),
"The absolute value of the 'shift=' argument for intrinsic '%s' must be less than or equal to the 'size=' argument"_err_en_US,
name);
}
}
}
}
}
}
} else if (name == "lcobound") {
return CheckDimAgainstCorank(call, context);
} else if (name == "loc") {
Expand Down
6 changes: 3 additions & 3 deletions flang/test/Evaluate/fold-ishftc.f90
@@ -1,9 +1,9 @@
! RUN: %python %S/test_folding.py %s %flang_fc1
! Tests folding of ISHFTC
module m
integer, parameter :: shift8s(*) = ishftc(257, shift = [(ict, ict = -9, 9)], 8)
integer, parameter :: expect1(*) = 256 + [128, 1, 2, 4, 8, 16, 32, 64, 128, &
1, 2, 4, 8, 16, 32, 64, 128, 1, 2]
integer, parameter :: shift8s(*) = ishftc(257, shift = [(ict, ict = -8, 8)], 8)
integer, parameter :: expect1(*) = 256 + [1, 2, 4, 8, 16, 32, 64, 128, &
1, 2, 4, 8, 16, 32, 64, 128, 1]
logical, parameter :: test_1 = all(shift8s == expect1)
integer, parameter :: sizes(*) = [(ishftc(257, ict, [(isz, isz = 1, 8)]), ict = -1, 1)]
integer, parameter :: expect2(*) = 256 + [[1, 2, 4, 8, 16, 32, 64, 128], &
Expand Down
26 changes: 13 additions & 13 deletions flang/test/Semantics/ishftc.f90
Expand Up @@ -18,31 +18,31 @@ program test_ishftc
n = ishftc(3, 2, 3)
array_result = ishftc([3,3], [2,2], [3,3])

!ERROR: 'size=' argument for intrinsic 'ishftc' must be a positive value, but is -3
!ERROR: SIZE=-3 count for ishftc is not positive
n = ishftc(3, 2, -3)
!ERROR: 'size=' argument for intrinsic 'ishftc' must be a positive value, but is 0
!ERROR: SIZE=0 count for ishftc is not positive
n = ishftc(3, 2, 0)
!ERROR: The absolute value of the 'shift=' argument for intrinsic 'ishftc' must be less than or equal to the 'size=' argument
!ERROR: SHIFT=2 count for ishftc is greater in magnitude than SIZE=1
n = ishftc(3, 2, 1)
!ERROR: The absolute value of the 'shift=' argument for intrinsic 'ishftc' must be less than or equal to the 'size=' argument
!ERROR: SHIFT=-2 count for ishftc is greater in magnitude than SIZE=1
n = ishftc(3, -2, 1)
!ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
!ERROR: SIZE=-3 count for ishftc is not positive
array_result = ishftc([3,3], [2,2], [-3,3])
!ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
!ERROR: SIZE=-3 count for ishftc is not positive
array_result = ishftc([3,3], [2,2], [-3,-3])
!ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
!ERROR: SIZE=-3 count for ishftc is not positive
array_result = ishftc([3,3], [-2,-2], const_arr1)
!ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
!ERROR: SIZE=0 count for ishftc is not positive
array_result = ishftc([3,3], [-2,-2], const_arr2)
!ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
!ERROR: SIZE=0 count for ishftc is not positive
array_result = ishftc([3,3], [-2,-2], const_arr3)
!ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
!ERROR: SIZE=0 count for ishftc is not positive
array_result = ishftc([3,3], [-2,-2], const_arr4)
!ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
!ERROR: SIZE=0 count for ishftc is not positive
array_result = ishftc([3,3], [-2,-2], const_arr5)
!ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
!ERROR: SIZE=0 count for ishftc is not positive
array_result = ishftc([3,3], [-2,-2], const_arr6)
!ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values
!ERROR: SIZE=0 count for ishftc is not positive
array_result = ishftc([3,3], [-2,-2], const_arr7)

end program test_ishftc

0 comments on commit 4244cab

Please sign in to comment.