From 4244cab23afcee5c3515d67d6d340cf82ce5289f Mon Sep 17 00:00:00 2001 From: Peter Klausler Date: Sat, 17 Dec 2022 10:48:27 -0800 Subject: [PATCH] [flang] Check constant arguments to bit manipulation intrinsics even if not foldable When some arguments that specify bit positions, shift counts, and field sizes are constant at compilation time, but other arguments are not constant, the compiler should still validate the constant ones. In the current sources, validation is only performed for intrinsic references that can be folded to constants. Differential Revision: https://reviews.llvm.org/D140152 --- flang/lib/Evaluate/fold-integer.cpp | 238 +++++++++++++++++++--------- flang/lib/Evaluate/intrinsics.cpp | 18 --- flang/test/Evaluate/fold-ishftc.f90 | 6 +- flang/test/Semantics/ishftc.f90 | 26 +-- 4 files changed, 183 insertions(+), 105 deletions(-) diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp index 0aaf5b182635e..61e76b59b1f9b 100644 --- a/flang/lib/Evaluate/fold-integer.cpp +++ b/flang/lib/Evaluate/fold-integer.cpp @@ -577,6 +577,21 @@ Expr> FoldIntrinsicFunction( name == "dshiftl" ? &Scalar::DSHIFTL : &Scalar::DSHIFTR}; // Third argument can be of any kind. However, it must be smaller or equal // than BIT_SIZE. It can be converted to Int4 to simplify. + if (const auto *shiftCon{Folder(context).Folding(args[2])}) { + for (const auto &scalar : shiftCon->values()) { + std::int64_t shiftVal{scalar.ToInt64()}; + if (shiftVal < 0) { + context.messages().Say("SHIFT=%jd count for %s is negative"_err_en_US, + std::intmax_t{shiftVal}, name); + break; + } else if (shiftVal > T::Scalar::bits) { + context.messages().Say( + "SHIFT=%jd count for %s is greater than %d"_err_en_US, + std::intmax_t{shiftVal}, name, T::Scalar::bits); + break; + } + } + } return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [&fptr](const Scalar &i, const Scalar &j, @@ -662,42 +677,69 @@ Expr> FoldIntrinsicFunction( } else { common::die("missing case to fold intrinsic function %s", name.c_str()); } + if (const auto *posCon{Folder(context).Folding(args[1])}) { + for (const auto &scalar : posCon->values()) { + std::int64_t posVal{scalar.ToInt64()}; + if (posVal < 0) { + context.messages().Say( + "bit position for %s (%jd) is negative"_err_en_US, name, + std::intmax_t{posVal}); + break; + } else if (posVal >= T::Scalar::bits) { + context.messages().Say( + "bit position for %s (%jd) is not less than %d"_err_en_US, name, + std::intmax_t{posVal}, T::Scalar::bits); + break; + } + } + } return FoldElementalIntrinsic(context, std::move(funcRef), - ScalarFunc([&](const Scalar &i, - const Scalar &pos) -> Scalar { - auto posVal{static_cast(pos.ToInt64())}; - if (posVal < 0) { - context.messages().Say( - "bit position for %s (%d) is negative"_err_en_US, name, posVal); - } else if (posVal >= i.bits) { - context.messages().Say( - "bit position for %s (%d) is not less than %d"_err_en_US, name, - posVal, i.bits); - } - return std::invoke(fptr, i, posVal); - })); + ScalarFunc( + [&](const Scalar &i, const Scalar &pos) -> Scalar { + return std::invoke(fptr, i, static_cast(pos.ToInt64())); + })); } else if (name == "ibits") { + const auto *posCon{Folder(context).Folding(args[1])}; + const auto *lenCon{Folder(context).Folding(args[2])}; + if (posCon && lenCon && + (posCon->size() == 1 || lenCon->size() == 1 || + posCon->size() == lenCon->size())) { + auto posIter{posCon->values().begin()}; + auto lenIter{lenCon->values().begin()}; + for (; posIter != posCon->values().end() && + lenIter != lenCon->values().end(); + ++posIter, ++lenIter) { + posIter = posIter == posCon->values().end() ? posCon->values().begin() + : posIter; + lenIter = lenIter == lenCon->values().end() ? lenCon->values().begin() + : lenIter; + auto posVal{static_cast(posIter->ToInt64())}; + auto lenVal{static_cast(lenIter->ToInt64())}; + if (posVal < 0) { + context.messages().Say( + "bit position for IBITS(POS=%jd,LEN=%jd) is negative"_err_en_US, + std::intmax_t{posVal}, std::intmax_t{lenVal}); + break; + } else if (lenVal < 0) { + context.messages().Say( + "bit length for IBITS(POS=%jd,LEN=%jd) is negative"_err_en_US, + std::intmax_t{posVal}, std::intmax_t{lenVal}); + break; + } else if (posVal + lenVal > T::Scalar::bits) { + context.messages().Say( + "IBITS(POS=%jd,LEN=%jd) must have POS+LEN no greater than %d"_err_en_US, + std::intmax_t{posVal}, std::intmax_t{lenVal}, T::Scalar::bits); + break; + } + } + } return FoldElementalIntrinsic(context, std::move(funcRef), - ScalarFunc([&](const Scalar &i, - const Scalar &pos, - const Scalar &len) -> Scalar { - auto posVal{static_cast(pos.ToInt64())}; - auto lenVal{static_cast(len.ToInt64())}; - if (posVal < 0) { - context.messages().Say( - "bit position for IBITS(POS=%d,LEN=%d) is negative"_err_en_US, - posVal, lenVal); - } else if (lenVal < 0) { - context.messages().Say( - "bit length for IBITS(POS=%d,LEN=%d) is negative"_err_en_US, - posVal, lenVal); - } else if (posVal + lenVal > i.bits) { - context.messages().Say( - "IBITS(POS=%d,LEN=%d) must have POS+LEN no greater than %d"_err_en_US, - posVal + lenVal, i.bits); - } - return i.IBITS(posVal, lenVal); - })); + ScalarFunc( + [&](const Scalar &i, const Scalar &pos, + const Scalar &len) -> Scalar { + return i.IBITS(static_cast(pos.ToInt64()), + static_cast(len.ToInt64())); + })); } else if (name == "index" || name == "scan" || name == "verify") { if (auto *charExpr{UnwrapExpr>(args[0])}) { return common::visit( @@ -761,24 +803,79 @@ Expr> FoldIntrinsicFunction( } else if (name == "iparity") { return FoldBitReduction( context, std::move(funcRef), &Scalar::IEOR, Scalar{}); - } else if (name == "ishft") { - return FoldElementalIntrinsic(context, std::move(funcRef), - ScalarFunc( - [&](const Scalar &i, const Scalar &pos) -> Scalar { - auto posVal{static_cast(pos.ToInt64())}; - if (posVal < -i.bits) { - context.messages().Say( - "SHIFT=%d count for ishft is less than %d"_err_en_US, - posVal, -i.bits); - } else if (posVal > i.bits) { - context.messages().Say( - "SHIFT=%d count for ishft is greater than %d"_err_en_US, - posVal, i.bits); - } - return i.ISHFT(posVal); - })); - } else if (name == "ishftc") { - if (args.at(2)) { // SIZE= is present + } else if (name == "ishft" || name == "ishftc") { + const auto *shiftCon{Folder(context).Folding(args[1])}; + if (shiftCon) { + for (const auto &scalar : shiftCon->values()) { + std::int64_t shiftVal{scalar.ToInt64()}; + if (shiftVal < -T::Scalar::bits) { + context.messages().Say( + "SHIFT=%jd count for %s is less than %d"_err_en_US, + std::intmax_t{shiftVal}, name, -T::Scalar::bits); + break; + } else if (shiftVal > T::Scalar::bits) { + context.messages().Say( + "SHIFT=%jd count for %s is greater than %d"_err_en_US, + std::intmax_t{shiftVal}, name, T::Scalar::bits); + break; + } + } + } + if (args.size() == 3) { // ISHFTC + if (const auto *sizeCon{Folder(context).Folding(args[2])}) { + for (const auto &scalar : sizeCon->values()) { + std::int64_t sizeVal{scalar.ToInt64()}; + if (sizeVal <= 0) { + context.messages().Say( + "SIZE=%jd count for ishftc is not positive"_err_en_US, + std::intmax_t{sizeVal}, name); + break; + } else if (sizeVal > T::Scalar::bits) { + context.messages().Say( + "SIZE=%jd count for ishftc is greater than %d"_err_en_US, + std::intmax_t{sizeVal}, T::Scalar::bits); + break; + } + } + if (shiftCon && + (shiftCon->size() == 1 || sizeCon->size() == 1 || + shiftCon->size() == sizeCon->size())) { + auto shiftIter{shiftCon->values().begin()}; + auto sizeIter{sizeCon->values().begin()}; + for (; shiftIter != shiftCon->values().end() && + sizeIter != sizeCon->values().end(); + ++shiftIter, ++sizeIter) { + shiftIter = shiftIter == shiftCon->values().end() + ? shiftCon->values().begin() + : shiftIter; + sizeIter = sizeIter == sizeCon->values().end() + ? sizeCon->values().begin() + : sizeIter; + auto shiftVal{static_cast(shiftIter->ToInt64())}; + auto sizeVal{static_cast(sizeIter->ToInt64())}; + if (sizeVal > 0 && std::abs(shiftVal) > sizeVal) { + context.messages().Say( + "SHIFT=%jd count for ishftc is greater in magnitude than SIZE=%jd"_err_en_US, + std::intmax_t{shiftVal}, std::intmax_t{sizeVal}); + break; + } + } + } + } + } + if (name == "ishft") { + return FoldElementalIntrinsic(context, std::move(funcRef), + ScalarFunc( + [&](const Scalar &i, const Scalar &shift) -> Scalar { + return i.ISHFT(static_cast(shift.ToInt64())); + })); + } else if (!args.at(2)) { // ISHFTC(no SIZE=) + return FoldElementalIntrinsic(context, std::move(funcRef), + ScalarFunc( + [&](const Scalar &i, const Scalar &shift) -> Scalar { + return i.ISHFTC(static_cast(shift.ToInt64())); + })); + } else { // ISHFTC(with SIZE=) return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( @@ -789,13 +886,6 @@ Expr> FoldIntrinsicFunction( auto sizeVal{static_cast(size.ToInt64())}; return i.ISHFTC(shiftVal, sizeVal); })); - } else { // no SIZE= - return FoldElementalIntrinsic(context, std::move(funcRef), - ScalarFunc( - [&](const Scalar &i, const Scalar &count) -> Scalar { - auto countVal{static_cast(count.ToInt64())}; - return i.ISHFTC(countVal); - })); } } else if (name == "izext" || name == "jzext") { if (args.size() == 1) { @@ -1045,20 +1135,26 @@ Expr> FoldIntrinsicFunction( } else { common::die("missing case to fold intrinsic function %s", name.c_str()); } + if (const auto *shiftCon{Folder(context).Folding(args[1])}) { + for (const auto &scalar : shiftCon->values()) { + std::int64_t shiftVal{scalar.ToInt64()}; + if (shiftVal < 0) { + context.messages().Say("SHIFT=%jd count for %s is negative"_err_en_US, + std::intmax_t{shiftVal}, name, -T::Scalar::bits); + break; + } else if (shiftVal > T::Scalar::bits) { + context.messages().Say( + "SHIFT=%jd count for %s is greater than %d"_err_en_US, + std::intmax_t{shiftVal}, name, T::Scalar::bits); + break; + } + } + } return FoldElementalIntrinsic(context, std::move(funcRef), - ScalarFunc([&](const Scalar &i, - const Scalar &pos) -> Scalar { - auto posVal{static_cast(pos.ToInt64())}; - if (posVal < 0) { - context.messages().Say( - "SHIFT=%d count for %s is negative"_err_en_US, posVal, name); - } else if (posVal > i.bits) { - context.messages().Say( - "SHIFT=%d count for %s is greater than %d"_err_en_US, posVal, - name, i.bits); - } - return std::invoke(fptr, i, posVal); - })); + ScalarFunc( + [&](const Scalar &i, const Scalar &shift) -> Scalar { + return std::invoke(fptr, i, static_cast(shift.ToInt64())); + })); } else if (name == "sign") { return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc([&context](const Scalar &j, diff --git a/flang/lib/Evaluate/intrinsics.cpp b/flang/lib/Evaluate/intrinsics.cpp index 308e3e9d61855..6f4914d8fa715 100644 --- a/flang/lib/Evaluate/intrinsics.cpp +++ b/flang/lib/Evaluate/intrinsics.cpp @@ -2918,24 +2918,6 @@ static bool ApplySpecificChecks(SpecificCall &call, FoldingContext &context) { if (const auto &arg{call.arguments[0]}) { ok = CheckForNonPositiveValues(context, *arg, name, "image"); } - } else if (name == "ishftc") { - if (const auto &sizeArg{call.arguments[2]}) { - ok = CheckForNonPositiveValues(context, *sizeArg, name, "size"); - if (ok) { - if (auto sizeVal{ToInt64(sizeArg->UnwrapExpr())}) { - if (const auto &shiftArg{call.arguments[1]}) { - if (auto shiftVal{ToInt64(shiftArg->UnwrapExpr())}) { - if (std::abs(*shiftVal) > *sizeVal) { - ok = false; - context.messages().Say(shiftArg->sourceLocation(), - "The absolute value of the 'shift=' argument for intrinsic '%s' must be less than or equal to the 'size=' argument"_err_en_US, - name); - } - } - } - } - } - } } else if (name == "lcobound") { return CheckDimAgainstCorank(call, context); } else if (name == "loc") { diff --git a/flang/test/Evaluate/fold-ishftc.f90 b/flang/test/Evaluate/fold-ishftc.f90 index 35c6bf3283b78..705134823956c 100644 --- a/flang/test/Evaluate/fold-ishftc.f90 +++ b/flang/test/Evaluate/fold-ishftc.f90 @@ -1,9 +1,9 @@ ! RUN: %python %S/test_folding.py %s %flang_fc1 ! Tests folding of ISHFTC module m - integer, parameter :: shift8s(*) = ishftc(257, shift = [(ict, ict = -9, 9)], 8) - integer, parameter :: expect1(*) = 256 + [128, 1, 2, 4, 8, 16, 32, 64, 128, & - 1, 2, 4, 8, 16, 32, 64, 128, 1, 2] + integer, parameter :: shift8s(*) = ishftc(257, shift = [(ict, ict = -8, 8)], 8) + integer, parameter :: expect1(*) = 256 + [1, 2, 4, 8, 16, 32, 64, 128, & + 1, 2, 4, 8, 16, 32, 64, 128, 1] logical, parameter :: test_1 = all(shift8s == expect1) integer, parameter :: sizes(*) = [(ishftc(257, ict, [(isz, isz = 1, 8)]), ict = -1, 1)] integer, parameter :: expect2(*) = 256 + [[1, 2, 4, 8, 16, 32, 64, 128], & diff --git a/flang/test/Semantics/ishftc.f90 b/flang/test/Semantics/ishftc.f90 index 3e0ebe5a41d06..15d1213999cc9 100644 --- a/flang/test/Semantics/ishftc.f90 +++ b/flang/test/Semantics/ishftc.f90 @@ -18,31 +18,31 @@ program test_ishftc n = ishftc(3, 2, 3) array_result = ishftc([3,3], [2,2], [3,3]) - !ERROR: 'size=' argument for intrinsic 'ishftc' must be a positive value, but is -3 + !ERROR: SIZE=-3 count for ishftc is not positive n = ishftc(3, 2, -3) - !ERROR: 'size=' argument for intrinsic 'ishftc' must be a positive value, but is 0 + !ERROR: SIZE=0 count for ishftc is not positive n = ishftc(3, 2, 0) - !ERROR: The absolute value of the 'shift=' argument for intrinsic 'ishftc' must be less than or equal to the 'size=' argument + !ERROR: SHIFT=2 count for ishftc is greater in magnitude than SIZE=1 n = ishftc(3, 2, 1) - !ERROR: The absolute value of the 'shift=' argument for intrinsic 'ishftc' must be less than or equal to the 'size=' argument + !ERROR: SHIFT=-2 count for ishftc is greater in magnitude than SIZE=1 n = ishftc(3, -2, 1) - !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values + !ERROR: SIZE=-3 count for ishftc is not positive array_result = ishftc([3,3], [2,2], [-3,3]) - !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values + !ERROR: SIZE=-3 count for ishftc is not positive array_result = ishftc([3,3], [2,2], [-3,-3]) - !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values + !ERROR: SIZE=-3 count for ishftc is not positive array_result = ishftc([3,3], [-2,-2], const_arr1) - !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values + !ERROR: SIZE=0 count for ishftc is not positive array_result = ishftc([3,3], [-2,-2], const_arr2) - !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values + !ERROR: SIZE=0 count for ishftc is not positive array_result = ishftc([3,3], [-2,-2], const_arr3) - !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values + !ERROR: SIZE=0 count for ishftc is not positive array_result = ishftc([3,3], [-2,-2], const_arr4) - !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values + !ERROR: SIZE=0 count for ishftc is not positive array_result = ishftc([3,3], [-2,-2], const_arr5) - !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values + !ERROR: SIZE=0 count for ishftc is not positive array_result = ishftc([3,3], [-2,-2], const_arr6) - !ERROR: 'size=' argument for intrinsic 'ishftc' must contain all positive values + !ERROR: SIZE=0 count for ishftc is not positive array_result = ishftc([3,3], [-2,-2], const_arr7) end program test_ishftc