diff --git a/flang/runtime/tools.h b/flang/runtime/tools.h index 9811bce25acd3..ff05e76c8bb7b 100644 --- a/flang/runtime/tools.h +++ b/flang/runtime/tools.h @@ -94,6 +94,31 @@ static inline RT_API_ATTRS std::int64_t GetInt64( } } +static inline RT_API_ATTRS std::optional GetInt64Safe( + const char *p, std::size_t bytes, Terminator &terminator) { + switch (bytes) { + case 1: + return *reinterpret_cast *>(p); + case 2: + return *reinterpret_cast *>(p); + case 4: + return *reinterpret_cast *>(p); + case 8: + return *reinterpret_cast *>(p); + case 16: { + using Int128 = CppTypeFor; + auto n{*reinterpret_cast(p)}; + std::int64_t result = n; + if (result == n) { + return result; + } + return std::nullopt; + } + default: + terminator.Crash("GetInt64Safe: no case for %zd bytes", bytes); + } +} + template inline RT_API_ATTRS bool SetInteger(INT &x, int kind, std::int64_t value) { switch (kind) { diff --git a/flang/runtime/transformational.cpp b/flang/runtime/transformational.cpp index da8ec05c884fa..cf1e61c0844d8 100644 --- a/flang/runtime/transformational.cpp +++ b/flang/runtime/transformational.cpp @@ -52,9 +52,11 @@ class ShiftControl { } } } + } else if (auto count{GetInt64Safe( + shift_.OffsetElement(), shiftElemLen_, terminator_)}) { + shiftCount_ = *count; } else { - shiftCount_ = - GetInt64(shift_.OffsetElement(), shiftElemLen_, terminator_); + terminator_.Crash("%s: SHIFT= value exceeds 64 bits", which); } } RT_API_ATTRS SubscriptValue GetShift(const SubscriptValue resultAt[]) const { @@ -67,8 +69,10 @@ class ShiftControl { ++k; } } - return GetInt64( - shift_.Element(shiftAt), shiftElemLen_, terminator_); + auto count{GetInt64Safe( + shift_.Element(shiftAt), shiftElemLen_, terminator_)}; + RUNTIME_CHECK(terminator_, count.has_value()); + return *count; } else { return shiftCount_; // invariant count extracted in Init() } @@ -719,12 +723,15 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source, std::size_t resultElements{1}; SubscriptValue shapeSubscript{shape.GetDimension(0).LowerBound()}; for (int j{0}; j < resultRank; ++j, ++shapeSubscript) { - resultExtent[j] = GetInt64( - shape.Element(&shapeSubscript), shapeElementBytes, terminator); - if (resultExtent[j] < 0) { + auto extent{GetInt64Safe( + shape.Element(&shapeSubscript), shapeElementBytes, terminator)}; + if (!extent) { + terminator.Crash("RESHAPE: value of SHAPE(%d) exceeds 64 bits", j + 1); + } else if (*extent < 0) { terminator.Crash("RESHAPE: bad value for SHAPE(%d)=%jd", j + 1, - static_cast(resultExtent[j])); + static_cast(*extent)); } + resultExtent[j] = *extent; resultElements *= resultExtent[j]; } @@ -762,14 +769,16 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source, SubscriptValue orderSubscript{order->GetDimension(0).LowerBound()}; std::size_t orderElementBytes{order->ElementBytes()}; for (SubscriptValue j{0}; j < resultRank; ++j, ++orderSubscript) { - auto k{GetInt64(order->Element(&orderSubscript), orderElementBytes, - terminator)}; - if (k < 1 || k > resultRank || ((values >> k) & 1)) { + auto k{GetInt64Safe(order->Element(&orderSubscript), + orderElementBytes, terminator)}; + if (!k) { + terminator.Crash("RESHAPE: ORDER element value exceeds 64 bits"); + } else if (*k < 1 || *k > resultRank || ((values >> *k) & 1)) { terminator.Crash("RESHAPE: bad value for ORDER element (%jd)", - static_cast(k)); + static_cast(*k)); } - values |= std::uint64_t{1} << k; - dimOrder[j] = k - 1; + values |= std::uint64_t{1} << *k; + dimOrder[j] = *k - 1; } } else { for (int j{0}; j < resultRank; ++j) {