Skip to content

Commit

Permalink
[flang][runtime] Accept 128-bit integer SHIFT values in CSHIFT/EOSHIFT (
Browse files Browse the repository at this point in the history
#75246)

It would surprise me if this case ever arose outside a couple of tests
in llvm-test-suite/Fortran/gfortran/regression (namely
cshift_large_1.f90 and eoshift_large_1.f90), but now at least those
tests will pass.
  • Loading branch information
klausler authored Dec 26, 2023
1 parent befdfae commit 8fc045e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 14 deletions.
25 changes: 25 additions & 0 deletions flang/runtime/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,31 @@ static inline RT_API_ATTRS std::int64_t GetInt64(
}
}

static inline RT_API_ATTRS std::optional<std::int64_t> GetInt64Safe(
const char *p, std::size_t bytes, Terminator &terminator) {
switch (bytes) {
case 1:
return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 1> *>(p);
case 2:
return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 2> *>(p);
case 4:
return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 4> *>(p);
case 8:
return *reinterpret_cast<const CppTypeFor<TypeCategory::Integer, 8> *>(p);
case 16: {
using Int128 = CppTypeFor<TypeCategory::Integer, 16>;
auto n{*reinterpret_cast<const Int128 *>(p)};
std::int64_t result = n;
if (result == n) {
return result;
}
return std::nullopt;
}
default:
terminator.Crash("GetInt64Safe: no case for %zd bytes", bytes);
}
}

template <typename INT>
inline RT_API_ATTRS bool SetInteger(INT &x, int kind, std::int64_t value) {
switch (kind) {
Expand Down
37 changes: 23 additions & 14 deletions flang/runtime/transformational.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ class ShiftControl {
}
}
}
} else if (auto count{GetInt64Safe(
shift_.OffsetElement<char>(), shiftElemLen_, terminator_)}) {
shiftCount_ = *count;
} else {
shiftCount_ =
GetInt64(shift_.OffsetElement<char>(), shiftElemLen_, terminator_);
terminator_.Crash("%s: SHIFT= value exceeds 64 bits", which);
}
}
RT_API_ATTRS SubscriptValue GetShift(const SubscriptValue resultAt[]) const {
Expand All @@ -67,8 +69,10 @@ class ShiftControl {
++k;
}
}
return GetInt64(
shift_.Element<char>(shiftAt), shiftElemLen_, terminator_);
auto count{GetInt64Safe(
shift_.Element<char>(shiftAt), shiftElemLen_, terminator_)};
RUNTIME_CHECK(terminator_, count.has_value());
return *count;
} else {
return shiftCount_; // invariant count extracted in Init()
}
Expand Down Expand Up @@ -719,12 +723,15 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
std::size_t resultElements{1};
SubscriptValue shapeSubscript{shape.GetDimension(0).LowerBound()};
for (int j{0}; j < resultRank; ++j, ++shapeSubscript) {
resultExtent[j] = GetInt64(
shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator);
if (resultExtent[j] < 0) {
auto extent{GetInt64Safe(
shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator)};
if (!extent) {
terminator.Crash("RESHAPE: value of SHAPE(%d) exceeds 64 bits", j + 1);
} else if (*extent < 0) {
terminator.Crash("RESHAPE: bad value for SHAPE(%d)=%jd", j + 1,
static_cast<std::intmax_t>(resultExtent[j]));
static_cast<std::intmax_t>(*extent));
}
resultExtent[j] = *extent;
resultElements *= resultExtent[j];
}

Expand Down Expand Up @@ -762,14 +769,16 @@ void RTDEF(Reshape)(Descriptor &result, const Descriptor &source,
SubscriptValue orderSubscript{order->GetDimension(0).LowerBound()};
std::size_t orderElementBytes{order->ElementBytes()};
for (SubscriptValue j{0}; j < resultRank; ++j, ++orderSubscript) {
auto k{GetInt64(order->Element<char>(&orderSubscript), orderElementBytes,
terminator)};
if (k < 1 || k > resultRank || ((values >> k) & 1)) {
auto k{GetInt64Safe(order->Element<char>(&orderSubscript),
orderElementBytes, terminator)};
if (!k) {
terminator.Crash("RESHAPE: ORDER element value exceeds 64 bits");
} else if (*k < 1 || *k > resultRank || ((values >> *k) & 1)) {
terminator.Crash("RESHAPE: bad value for ORDER element (%jd)",
static_cast<std::intmax_t>(k));
static_cast<std::intmax_t>(*k));
}
values |= std::uint64_t{1} << k;
dimOrder[j] = k - 1;
values |= std::uint64_t{1} << *k;
dimOrder[j] = *k - 1;
}
} else {
for (int j{0}; j < resultRank; ++j) {
Expand Down

0 comments on commit 8fc045e

Please sign in to comment.