Skip to content

Commit

Permalink
[flang] Convert RUNTIME_CHECK to better error for user errors in tran…
Browse files Browse the repository at this point in the history
…sformational.cpp

In flang/runtime/transformational.cpp, there are many RUNTIME_CHECK assertions
for errors that should have been caught in semantics, but there are alno others
that signify program errors that in principle cannot be detected until
execution.  Convert this second group into readable fatal error messages.
Also clean up some missing braces and incorrect printf formats found
along the way.

Differential Revision: https://reviews.llvm.org/D122037
  • Loading branch information
klausler committed Mar 18, 2022
1 parent 09ff41a commit 251d062
Showing 1 changed file with 81 additions and 35 deletions.
116 changes: 81 additions & 35 deletions flang/runtime/transformational.cpp
Expand Up @@ -31,7 +31,7 @@ class ShiftControl {
public:
ShiftControl(const Descriptor &s, Terminator &t, int dim)
: shift_{s}, terminator_{t}, shiftRank_{s.rank()}, dim_{dim} {}
void Init(const Descriptor &source) {
void Init(const Descriptor &source, const char *which) {
int rank{source.rank()};
RUNTIME_CHECK(terminator_, shiftRank_ == 0 || shiftRank_ == rank - 1);
auto catAndKind{shift_.type().GetCategoryAndKind()};
Expand All @@ -44,8 +44,12 @@ class ShiftControl {
if (j + 1 != dim_) {
const Dimension &shiftDim{shift_.GetDimension(k)};
lb_[k++] = shiftDim.LowerBound();
RUNTIME_CHECK(terminator_,
shiftDim.Extent() == source.GetDimension(j).Extent());
if (shiftDim.Extent() != source.GetDimension(j).Extent()) {
terminator_.Crash("%s: on dimension %d, SHIFT= has extent %jd but "
"SOURCE= has extent %jd",
which, k, static_cast<std::intmax_t>(shiftDim.Extent()),
static_cast<std::intmax_t>(source.GetDimension(j).Extent()));
}
}
}
} else {
Expand Down Expand Up @@ -137,9 +141,12 @@ void RTNAME(Cshift)(Descriptor &result, const Descriptor &source,
Terminator terminator{sourceFile, line};
int rank{source.rank()};
RUNTIME_CHECK(terminator, rank > 1);
RUNTIME_CHECK(terminator, dim >= 1 && dim <= rank);
if (dim < 1 || dim > rank) {
terminator.Crash(
"CSHIFT: DIM=%d must be >= 1 and <= SOURCE= rank %d", dim, rank);
}
ShiftControl shiftControl{shift, terminator, dim};
shiftControl.Init(source);
shiftControl.Init(source, "CSHIFT");
SubscriptValue extent[maxRank];
source.GetShape(extent);
AllocateResult(result, source, rank, extent, terminator, "CSHIFT");
Expand Down Expand Up @@ -200,29 +207,39 @@ void RTNAME(Eoshift)(Descriptor &result, const Descriptor &source,
SubscriptValue extent[maxRank];
int rank{source.GetShape(extent)};
RUNTIME_CHECK(terminator, rank > 1);
RUNTIME_CHECK(terminator, dim >= 1 && dim <= rank);
if (dim < 1 || dim > rank) {
terminator.Crash(
"EOSHIFT: DIM=%d must be >= 1 and <= SOURCE= rank %d", dim, rank);
}
std::size_t elementLen{
AllocateResult(result, source, rank, extent, terminator, "EOSHIFT")};
int boundaryRank{-1};
if (boundary) {
boundaryRank = boundary->rank();
RUNTIME_CHECK(terminator, boundaryRank == 0 || boundaryRank == rank - 1);
RUNTIME_CHECK(terminator,
boundary->type() == source.type() &&
boundary->ElementBytes() == elementLen);
RUNTIME_CHECK(terminator, boundary->type() == source.type());
if (boundary->ElementBytes() != elementLen) {
terminator.Crash("EOSHIFT: BOUNDARY= has element byte length %zd, but "
"SOURCE= has length %zd",
boundary->ElementBytes(), elementLen);
}
if (boundaryRank > 0) {
int k{0};
for (int j{0}; j < rank; ++j) {
if (j != dim - 1) {
RUNTIME_CHECK(
terminator, boundary->GetDimension(k).Extent() == extent[j]);
if (boundary->GetDimension(k).Extent() != extent[j]) {
terminator.Crash("EOSHIFT: BOUNDARY= has extent %jd on dimension "
"%d but must conform with extent %jd of SOURCE=",
static_cast<std::intmax_t>(boundary->GetDimension(k).Extent()),
k + 1, static_cast<std::intmax_t>(extent[j]));
}
++k;
}
}
}
}
ShiftControl shiftControl{shift, terminator, dim};
shiftControl.Init(source);
shiftControl.Init(source, "EOSHIFT");
SubscriptValue resultAt[maxRank];
for (int j{0}; j < rank; ++j) {
resultAt[j] = 1;
Expand Down Expand Up @@ -273,9 +290,12 @@ void RTNAME(EoshiftVector)(Descriptor &result, const Descriptor &source,
AllocateResult(result, source, 1, &extent, terminator, "EOSHIFT")};
if (boundary) {
RUNTIME_CHECK(terminator, boundary->rank() == 0);
RUNTIME_CHECK(terminator,
boundary->type() == source.type() &&
boundary->ElementBytes() == elementLen);
RUNTIME_CHECK(terminator, boundary->type() == source.type());
if (boundary->ElementBytes() != elementLen) {
terminator.Crash("EOSHIFT: BOUNDARY= has element byte length %zd but "
"SOURCE= has length %zd",
boundary->ElementBytes(), elementLen);
}
}
if (!boundary) {
DefaultInitialize(result, terminator);
Expand Down Expand Up @@ -318,11 +338,19 @@ void RTNAME(Pack)(Descriptor &result, const Descriptor &source,
SubscriptValue extent{trues};
if (vector) {
RUNTIME_CHECK(terminator, vector->rank() == 1);
RUNTIME_CHECK(terminator,
source.type() == vector->type() &&
source.ElementBytes() == vector->ElementBytes());
RUNTIME_CHECK(terminator, source.type() == vector->type());
if (source.ElementBytes() != vector->ElementBytes()) {
terminator.Crash("PACK: SOURCE= has element byte length %zd, but VECTOR= "
"has length %zd",
source.ElementBytes(), vector->ElementBytes());
}
extent = vector->GetDimension(0).Extent();
RUNTIME_CHECK(terminator, extent >= trues);
if (extent < trues) {
terminator.Crash("PACK: VECTOR= has extent %jd but there are %jd MASK= "
"elements that are .TRUE.",
static_cast<std::intmax_t>(extent),
static_cast<std::intmax_t>(trues));
}
}
AllocateResult(result, source, 1, &extent, terminator, "PACK");
SubscriptValue sourceAt[maxRank], resultAt{1};
Expand Down Expand Up @@ -366,20 +394,24 @@ void RTNAME(Reshape)(Descriptor &result, const Descriptor &source,
RUNTIME_CHECK(terminator, shape.rank() == 1);
RUNTIME_CHECK(terminator, shape.type().IsInteger());
SubscriptValue resultRank{shape.GetDimension(0).Extent()};
RUNTIME_CHECK(terminator,
resultRank >= 0 && resultRank <= static_cast<SubscriptValue>(maxRank));
if (resultRank < 0 || resultRank > static_cast<SubscriptValue>(maxRank)) {
terminator.Crash(
"RESHAPE: SHAPE= vector length %jd implies a bad result rank",
static_cast<std::intmax_t>(resultRank));
}

// Extract and check the shape of the result; compute its element count.
SubscriptValue resultExtent[maxRank];
std::size_t shapeElementBytes{shape.ElementBytes()};
std::size_t resultElements{1};
SubscriptValue shapeSubscript{shape.GetDimension(0).LowerBound()};
for (SubscriptValue j{0}; j < resultRank; ++j, ++shapeSubscript) {
for (int j{0}; j < resultRank; ++j, ++shapeSubscript) {
resultExtent[j] = GetInt64(
shape.Element<char>(&shapeSubscript), shapeElementBytes, terminator);
if (resultExtent[j] < 0)
terminator.Crash(
"RESHAPE: bad value for SHAPE(%d)=%d", j + 1, resultExtent[j]);
if (resultExtent[j] < 0) {
terminator.Crash("RESHAPE: bad value for SHAPE(%d)=%jd", j + 1,
static_cast<std::intmax_t>(resultExtent[j]));
}
resultElements *= resultExtent[j];
}

Expand All @@ -389,10 +421,16 @@ void RTNAME(Reshape)(Descriptor &result, const Descriptor &source,
std::size_t sourceElements{source.Elements()};
std::size_t padElements{pad ? pad->Elements() : 0};
if (resultElements > sourceElements) {
if (padElements <= 0)
terminator.Crash("RESHAPE: not eough elements, need %d but only have %d",
if (padElements <= 0) {
terminator.Crash(
"RESHAPE: not enough elements, need %zd but only have %zd",
resultElements, sourceElements);
RUNTIME_CHECK(terminator, pad->ElementBytes() == elementBytes);
}
if (pad->ElementBytes() != elementBytes) {
terminator.Crash("RESHAPE: PAD= has element byte length %zd but SOURCE= "
"has length %zd",
pad->ElementBytes(), elementBytes);
}
}

// Extract and check the optional ORDER= argument, which must be a
Expand All @@ -401,18 +439,22 @@ void RTNAME(Reshape)(Descriptor &result, const Descriptor &source,
if (order) {
RUNTIME_CHECK(terminator, order->rank() == 1);
RUNTIME_CHECK(terminator, order->type().IsInteger());
if (order->GetDimension(0).Extent() != resultRank)
terminator.Crash("RESHAPE: the extent of ORDER (%d) must match the rank"
if (order->GetDimension(0).Extent() != resultRank) {
terminator.Crash("RESHAPE: the extent of ORDER (%jd) must match the rank"
" of the SHAPE (%d)",
order->GetDimension(0).Extent(), resultRank);
static_cast<std::intmax_t>(order->GetDimension(0).Extent()),
resultRank);
}
std::uint64_t values{0};
SubscriptValue orderSubscript{order->GetDimension(0).LowerBound()};
std::size_t orderElementBytes{order->ElementBytes()};
for (SubscriptValue j{0}; j < resultRank; ++j, ++orderSubscript) {
auto k{GetInt64(order->Element<char>(&orderSubscript), orderElementBytes,
terminator)};
if (k < 1 || k > resultRank || ((values >> k) & 1))
terminator.Crash("RESHAPE: bad value for ORDER element (%d)", k);
if (k < 1 || k > resultRank || ((values >> k) & 1)) {
terminator.Crash("RESHAPE: bad value for ORDER element (%jd)",
static_cast<std::intmax_t>(k));
}
values |= std::uint64_t{1} << k;
dimOrder[j] = k - 1;
}
Expand Down Expand Up @@ -516,8 +558,12 @@ void RTNAME(Unpack)(Descriptor &result, const Descriptor &vector,
CheckConformability(mask, field, terminator, "UNPACK", "MASK=", "FIELD=");
std::size_t elementLen{
AllocateResult(result, field, rank, extent, terminator, "UNPACK")};
RUNTIME_CHECK(terminator,
vector.type() == field.type() && vector.ElementBytes() == elementLen);
RUNTIME_CHECK(terminator, vector.type() == field.type());
if (vector.ElementBytes() != elementLen) {
terminator.Crash(
"UNPACK: VECTOR= has element byte length %zd but FIELD= has length %zd",
vector.ElementBytes(), elementLen);
}
SubscriptValue resultAt[maxRank], maskAt[maxRank], fieldAt[maxRank],
vectorAt{vector.GetDimension(0).LowerBound()};
for (int j{0}; j < rank; ++j) {
Expand Down

0 comments on commit 251d062

Please sign in to comment.