diff --git a/flang/lib/Evaluate/fold-integer.cpp b/flang/lib/Evaluate/fold-integer.cpp index 6795f9007e0e1..98eda5948e4ff 100644 --- a/flang/lib/Evaluate/fold-integer.cpp +++ b/flang/lib/Evaluate/fold-integer.cpp @@ -610,34 +610,21 @@ Expr> FoldIntrinsicFunction( } else if (name == "iparity") { return FoldBitReduction( context, std::move(funcRef), &Scalar::IEOR, Scalar{}); - } else if (name == "ishft" || name == "shifta" || name == "shiftr" || - name == "shiftl") { - // Second argument can be of any kind. However, it must be smaller or - // equal than BIT_SIZE. It can be converted to Int4 to simplify. - auto fptr{&Scalar::ISHFT}; - if (name == "ishft") { // done in fptr definition - } else if (name == "shifta") { - fptr = &Scalar::SHIFTA; - } else if (name == "shiftr") { - fptr = &Scalar::SHIFTR; - } else if (name == "shiftl") { - fptr = &Scalar::SHIFTL; - } else { - common::die("missing case to fold intrinsic function %s", name.c_str()); - } + } else if (name == "ishft") { return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc([&](const Scalar &i, const Scalar &pos) -> Scalar { auto posVal{static_cast(pos.ToInt64())}; - if (posVal < 0) { + if (posVal < -i.bits) { context.messages().Say( - "shift count for %s (%d) is negative"_err_en_US, name, posVal); + "SHIFT=%d count for ishft is less than %d"_err_en_US, posVal, + -i.bits); } else if (posVal > i.bits) { context.messages().Say( - "shift count for %s (%d) is greater than %d"_err_en_US, name, - posVal, i.bits); + "SHIFT=%d count for ishft is greater than %d"_err_en_US, posVal, + i.bits); } - return std::invoke(fptr, i, posVal); + return i.ISHFT(posVal); })); } else if (name == "lbound") { return LBOUND(context, std::move(funcRef)); @@ -856,6 +843,32 @@ Expr> FoldIntrinsicFunction( return Fold(context, ConvertToType(std::move(*shapeExpr))); } } + } else if (name == "shifta" || name == "shiftr" || name == "shiftl") { + // Second argument can be of any kind. However, it must be smaller or + // equal than BIT_SIZE. It can be converted to Int4 to simplify. + auto fptr{&Scalar::SHIFTA}; + if (name == "shifta") { // done in fptr definition + } else if (name == "shiftr") { + fptr = &Scalar::SHIFTR; + } else if (name == "shiftl") { + fptr = &Scalar::SHIFTL; + } else { + common::die("missing case to fold intrinsic function %s", name.c_str()); + } + return FoldElementalIntrinsic(context, std::move(funcRef), + ScalarFunc([&](const Scalar &i, + const Scalar &pos) -> Scalar { + auto posVal{static_cast(pos.ToInt64())}; + if (posVal < 0) { + context.messages().Say( + "SHIFT=%d count for %s is negative"_err_en_US, posVal, name); + } else if (posVal > i.bits) { + context.messages().Say( + "SHIFT=%d count for %s is greater than %d"_err_en_US, posVal, + name, i.bits); + } + return std::invoke(fptr, i, posVal); + })); } else if (name == "sign") { return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( diff --git a/flang/test/Evaluate/fold-ishft.f90 b/flang/test/Evaluate/fold-ishft.f90 new file mode 100644 index 0000000000000..64d23c596044d --- /dev/null +++ b/flang/test/Evaluate/fold-ishft.f90 @@ -0,0 +1,6 @@ +! RUN: %python %S/test_folding.py %s %flang_fc1 +! Tests folding of ISHFT +module m1 + logical :: test_ishft_lsb = all(ishft(1, [-32, -31, -1, 0, 1, 2, 31, 32]) == [0, 0, 0, 1, 2, 4, int(z'80000000'), 0]) + logical :: test_ishft_msb = all(ishft(ishft(1,31), [-32, -31, -1, 0, 1, 2, 31, 32]) == [0, 1, int(z'40000000'), int(z'80000000'), 0, 0, 0, 0]) +end module diff --git a/flang/test/Evaluate/folding19.f90 b/flang/test/Evaluate/folding19.f90 index cbd6d2082b52a..32d4be7e01db2 100644 --- a/flang/test/Evaluate/folding19.f90 +++ b/flang/test/Evaluate/folding19.f90 @@ -56,8 +56,38 @@ subroutine s6 !CHECK: error: POS=32 out of range for BTEST logical, parameter :: bad2 = btest(0, 32) !CHECK-NOT: error: POS=33 out of range for BTEST - logical, parameter :: bad3 = btest(0_8, 33) + logical, parameter :: ok1 = btest(0_8, 33) !CHECK: error: POS=64 out of range for BTEST logical, parameter :: bad4 = btest(0_8, 64) end subroutine + subroutine s7 + !CHECK: error: SHIFT=-33 count for ishft is less than -32 + integer, parameter :: bad1 = ishft(1, -33) + integer, parameter :: ok1 = ishft(1, -32) + integer, parameter :: ok2 = ishft(1, 32) + !CHECK: error: SHIFT=33 count for ishft is greater than 32 + integer, parameter :: bad2 = ishft(1, 33) + !CHECK: error: SHIFT=-65 count for ishft is less than -64 + integer(8), parameter :: bad3 = ishft(1_8, -65) + integer(8), parameter :: ok3 = ishft(1_8, -64) + integer(8), parameter :: ok4 = ishft(1_8, 64) + !CHECK: error: SHIFT=65 count for ishft is greater than 64 + integer(8), parameter :: bad4 = ishft(1_8, 65) + end subroutine + subroutine s8 + !CHECK: error: SHIFT=-33 count for shiftl is negative + integer, parameter :: bad1 = shiftl(1, -33) + !CHECK: error: SHIFT=-32 count for shiftl is negative + integer, parameter :: bad2 = shiftl(1, -32) + integer, parameter :: ok1 = shiftl(1, 32) + !CHECK: error: SHIFT=33 count for shiftl is greater than 32 + integer, parameter :: bad3 = shiftl(1, 33) + !CHECK: error: SHIFT=-65 count for shiftl is negative + integer(8), parameter :: bad4 = shiftl(1_8, -65) + !CHECK: error: SHIFT=-64 count for shiftl is negative + integer(8), parameter :: bad5 = shiftl(1_8, -64) + integer(8), parameter :: ok2 = shiftl(1_8, 64) + !CHECK: error: SHIFT=65 count for shiftl is greater than 64 + integer(8), parameter :: bad6 = shiftl(1_8, 65) + end subroutine end module