Skip to content

Commit

Permalink
[flang][HLFIR] Relax verifiers of intrinsic operations (#80132)
Browse files Browse the repository at this point in the history
The verifiers are currently very strict: requiring intrinsic operations
to be used only in cases where the Fortran standard permits the
intrinsic to be used.

There have now been a lot of cases where these verifiers have caused
bugs in corner cases. In a recent ticket, @jeanPerier pointed out that
it could be useful for future optimizations if somewhat invalid uses of
these operations could be allowed in dead code. See this comment:
#79995 (comment)

In response to all of this, I have decided to relax the intrinsic
operation verifiers. The intention is now to only disallow operation
uses that are likely to crash the compiler. Other checks are still
available under `-strict-intrinsic-verifier`.

The disadvantage of this approach is that IR can now represent intrinsic
invocations which are incorrect. The lowering and implementation of
these intrinsic functions is unlikely to do the right thing in all of
these cases, and as they should mostly be impossible to generate using
normal Fortran code, these edge cases will see very little testing,
before some new optimization causes them to become more common.

Fixes #79995
  • Loading branch information
tblah committed Feb 1, 2024
1 parent ca7fd25 commit e9e0167
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 29 deletions.
71 changes: 44 additions & 27 deletions flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include <iterator>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <optional>
#include <tuple>

static llvm::cl::opt<bool> useStrictIntrinsicVerifier(
"strict-intrinsic-verifier", llvm::cl::init(false),
llvm::cl::desc("use stricter verifier for HLFIR intrinsic operations"));

/// generic implementation of the memory side effects interface for hlfir
/// transformational intrinsic operations
static void
Expand Down Expand Up @@ -498,7 +503,7 @@ verifyLogicalReductionOp(LogicalReductionOp reductionOp) {
mlir::Type resultType = results[0];
if (mlir::isa<fir::LogicalType>(resultType)) {
// Result is of the same type as MASK
if (resultType != logicalTy)
if ((resultType != logicalTy) && useStrictIntrinsicVerifier)
return reductionOp->emitOpError(
"result must have the same element type as MASK argument");

Expand All @@ -509,7 +514,7 @@ verifyLogicalReductionOp(LogicalReductionOp reductionOp) {
if (!resultExpr.isArray())
return reductionOp->emitOpError("result must be an array");

if (resultExpr.getEleTy() != logicalTy)
if ((resultExpr.getEleTy() != logicalTy) && useStrictIntrinsicVerifier)
return reductionOp->emitOpError(
"result must have the same element type as MASK argument");

Expand Down Expand Up @@ -585,7 +590,7 @@ mlir::LogicalResult hlfir::CountOp::verify() {
if (resultShape.size() != (maskShape.size() - 1))
return emitOpError("result rank must be one less than MASK");
} else {
return emitOpError("result must be of numerical scalar type");
return emitOpError("result must be of numerical array type");
}
} else if (!hlfir::isFortranScalarNumericalType(resultType)) {
return emitOpError("result must be of numerical scalar type");
Expand Down Expand Up @@ -682,15 +687,18 @@ verifyArrayAndMaskForReductionOp(NumericalReductionOp reductionOp) {
if (!maskShape.empty()) {
if (maskShape.size() != arrayShape.size())
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
static_assert(fir::SequenceType::getUnknownExtent() ==
hlfir::ExprType::getUnknownExtent());
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
for (std::size_t i = 0; i < arrayShape.size(); ++i) {
int64_t arrayExtent = arrayShape[i];
int64_t maskExtent = maskShape[i];
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
(maskExtent != unknownExtent))
return reductionOp->emitWarning("MASK must be conformable to ARRAY");
if (useStrictIntrinsicVerifier) {
static_assert(fir::SequenceType::getUnknownExtent() ==
hlfir::ExprType::getUnknownExtent());
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
for (std::size_t i = 0; i < arrayShape.size(); ++i) {
int64_t arrayExtent = arrayShape[i];
int64_t maskExtent = maskShape[i];
if ((arrayExtent != maskExtent) && (arrayExtent != unknownExtent) &&
(maskExtent != unknownExtent))
return reductionOp->emitWarning(
"MASK must be conformable to ARRAY");
}
}
}
}
Expand Down Expand Up @@ -719,7 +727,7 @@ verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
mlir::Type resultType = results[0];
if (hlfir::isFortranScalarNumericalType(resultType)) {
// Result is of the same type as ARRAY
if (resultType != numTy)
if ((resultType != numTy) && useStrictIntrinsicVerifier)
return reductionOp->emitOpError(
"result must have the same element type as ARRAY argument");

Expand All @@ -729,7 +737,7 @@ verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
if (!resultExpr.isArray())
return reductionOp->emitOpError("result must be an array");

if (resultExpr.getEleTy() != numTy)
if ((resultExpr.getEleTy() != numTy) && useStrictIntrinsicVerifier)
return reductionOp->emitOpError(
"result must have the same element type as ARRAY argument");

Expand Down Expand Up @@ -792,7 +800,7 @@ verifyCharacterReductionOp(CharacterReductionOp reductionOp) {
"result must be character");

// Result is of the same type as ARRAY
if (resultType != numTy)
if ((resultType != numTy) && useStrictIntrinsicVerifier)
return reductionOp->emitOpError(
"result must have the same element type as ARRAY argument");

Expand Down Expand Up @@ -823,9 +831,8 @@ mlir::LogicalResult hlfir::MaxvalOp::verify() {
auto resultExpr = mlir::dyn_cast<hlfir::ExprType>(results[0]);
if (resultExpr && mlir::isa<fir::CharacterType>(resultExpr.getEleTy())) {
return verifyCharacterReductionOp<hlfir::MaxvalOp *>(this);
} else {
return verifyNumericalReductionOp<hlfir::MaxvalOp *>(this);
}
return verifyNumericalReductionOp<hlfir::MaxvalOp *>(this);
}

void hlfir::MaxvalOp::getEffects(
Expand All @@ -848,9 +855,8 @@ mlir::LogicalResult hlfir::MinvalOp::verify() {
auto resultExpr = mlir::dyn_cast<hlfir::ExprType>(results[0]);
if (resultExpr && mlir::isa<fir::CharacterType>(resultExpr.getEleTy())) {
return verifyCharacterReductionOp<hlfir::MinvalOp *>(this);
} else {
return verifyNumericalReductionOp<hlfir::MinvalOp *>(this);
}
return verifyNumericalReductionOp<hlfir::MinvalOp *>(this);
}

void hlfir::MinvalOp::getEffects(
Expand Down Expand Up @@ -1007,17 +1013,19 @@ mlir::LogicalResult hlfir::DotProductOp::verify() {

constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
if ((lhsSize != unknownExtent) && (rhsSize != unknownExtent) &&
(lhsSize != rhsSize))
(lhsSize != rhsSize) && useStrictIntrinsicVerifier)
return emitOpError("both arrays must have the same size");

if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
mlir::isa<fir::LogicalType>(rhsEleTy))
return emitOpError("if one array is logical, so should the other be");
if (useStrictIntrinsicVerifier) {
if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
mlir::isa<fir::LogicalType>(rhsEleTy))
return emitOpError("if one array is logical, so should the other be");

if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
mlir::isa<fir::LogicalType>(resultTy))
return emitOpError("the result type should be a logical only if the "
"argument types are logical");
if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
mlir::isa<fir::LogicalType>(resultTy))
return emitOpError("the result type should be a logical only if the "
"argument types are logical");
}

if (!hlfir::isFortranScalarNumericalType(resultTy) &&
!mlir::isa<fir::LogicalType>(resultTy))
Expand Down Expand Up @@ -1067,6 +1075,9 @@ mlir::LogicalResult hlfir::MatmulOp::verify() {
mlir::isa<fir::LogicalType>(rhsEleTy))
return emitOpError("if one array is logical, so should the other be");

if (!useStrictIntrinsicVerifier)
return mlir::success();

int64_t lastLhsDim = lhsShape[lhsRank - 1];
int64_t firstRhsDim = rhsShape[0];
constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
Expand Down Expand Up @@ -1179,6 +1190,9 @@ mlir::LogicalResult hlfir::TransposeOp::verify() {
if (rank != 2 || resultRank != 2)
return emitOpError("input and output arrays should have rank 2");

if (!useStrictIntrinsicVerifier)
return mlir::success();

constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
if ((inShape[0] != resultShape[1]) && (inShape[0] != unknownExtent))
return emitOpError("output shape does not match input array");
Expand Down Expand Up @@ -1226,6 +1240,9 @@ mlir::LogicalResult hlfir::MatmulTransposeOp::verify() {
if ((lhsRank != 2) || ((rhsRank != 1) && (rhsRank != 2)))
return emitOpError("array must have either rank 1 or rank 2");

if (!useStrictIntrinsicVerifier)
return mlir::success();

if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
mlir::isa<fir::LogicalType>(rhsEleTy))
return emitOpError("if one array is logical, so should the other be");
Expand Down
4 changes: 2 additions & 2 deletions flang/test/HLFIR/invalid.fir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// HLFIR ops diagnotic tests

// RUN: fir-opt -split-input-file -verify-diagnostics %s
// RUN: fir-opt -strict-intrinsic-verifier -split-input-file -verify-diagnostics %s

func.func @bad_declare(%arg0: !fir.ref<f32>) {
// expected-error@+1 {{'hlfir.declare' op first result type is inconsistent with variable properties: expected '!fir.ref<f32>'}}
Expand Down Expand Up @@ -382,7 +382,7 @@ func.func @bad_count2(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32){

// -----
func.func @bad_count3(%arg0: !hlfir.expr<?x!fir.logical<4>>, %arg1: i32) {
// expected-error@+1 {{'hlfir.count' op result must be of numerical scalar type}}
// expected-error@+1 {{'hlfir.count' op result must be of numerical array type}}
%0 = hlfir.count %arg0 dim %arg1 : (!hlfir.expr<?x!fir.logical<4>>, i32) -> !hlfir.expr<i32>
}

Expand Down
44 changes: 44 additions & 0 deletions flang/test/Lower/HLFIR/minval.f90
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,47 @@ end subroutine test_unknown_char_len_result
! CHECK-NEXT: hlfir.destroy %[[EXPR]]
! CHECK-NEXT: return
! CHECK-NEXT: }

! Test edge case with missmatch between argument type !fir.char<1,?> and result
! type !fir.char<1,4>
function test_type_mismatch
character(:), allocatable :: test_type_mismatch(:)
character(3) :: char(3,4)
test_type_mismatch = minval(char//' ', dim=1)
end function
! CHECK-LABEL: func.func @_QPtest_type_mismatch() -> !fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>> {
! CHECK: %[[VAL_0:.*]] = arith.constant 3 : index
! CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
! CHECK: %[[VAL_2:.*]] = arith.constant 4 : index
! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.array<3x4x!fir.char<1,3>> {bindc_name = "char", uniq_name = "_QFtest_type_mismatchEchar"}
! CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_1]], %[[VAL_2]] : (index, index) -> !fir.shape<2>
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_3]](%[[VAL_4]]) typeparams %[[VAL_0]] {uniq_name = "_QFtest_type_mismatchEchar"} : (!fir.ref<!fir.array<3x4x!fir.char<1,3>>>, !fir.shape<2>, index) -> (!fir.ref<!fir.array<3x4x!fir.char<1,3>>>, !fir.ref<!fir.array<3x4x!fir.char<1,3>>>)
! CHECK: %[[VAL_6:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>> {bindc_name = "test_type_mismatch", uniq_name = "_QFtest_type_mismatchEtest_type_mismatch"}
! CHECK: %[[VAL_7:.*]] = fir.zero_bits !fir.heap<!fir.array<?x!fir.char<1,?>>>
! CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_8]] : (index) -> !fir.shape<1>
! CHECK: %[[VAL_10:.*]] = arith.constant 0 : index
! CHECK: %[[VAL_11:.*]] = fir.embox %[[VAL_7]](%[[VAL_9]]) typeparams %[[VAL_10]] : (!fir.heap<!fir.array<?x!fir.char<1,?>>>, !fir.shape<1>, index) -> !fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>
! CHECK: fir.store %[[VAL_11]] to %[[VAL_6]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>
! CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_6]] {fortran_attrs = #{{.*}}, uniq_name = "_QFtest_type_mismatchEtest_type_mismatch"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>)
! CHECK: %[[VAL_13:.*]] = fir.address_of(@_QQclX20) : !fir.ref<!fir.char<1>>
! CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_15:.*]]:2 = hlfir.declare %[[VAL_13]] typeparams %[[VAL_14]] {fortran_attrs = {{.*}}, uniq_name = "_QQclX20"} : (!fir.ref<!fir.char<1>>, index) -> (!fir.ref<!fir.char<1>>, !fir.ref<!fir.char<1>>)
! CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_0]], %[[VAL_14]] : index
! CHECK: %[[VAL_17:.*]] = hlfir.elemental %[[VAL_4]] typeparams %[[VAL_16]] unordered : (!fir.shape<2>, index) -> !hlfir.expr<3x4x!fir.char<1,?>> {
! CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
! CHECK: %[[VAL_20:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_18]], %[[VAL_19]]) typeparams %[[VAL_0]] : (!fir.ref<!fir.array<3x4x!fir.char<1,3>>>, index, index, index) -> !fir.ref<!fir.char<1,3>>
! CHECK: %[[VAL_21:.*]] = hlfir.concat %[[VAL_20]], %[[VAL_15]]#0 len %[[VAL_16]] : (!fir.ref<!fir.char<1,3>>, !fir.ref<!fir.char<1>>, index) -> !hlfir.expr<!fir.char<1,4>>
! CHECK: hlfir.yield_element %[[VAL_21]] : !hlfir.expr<!fir.char<1,4>>
! CHECK: }
! CHECK: %[[VAL_22:.*]] = arith.constant 1 : i32
! CHECK: %[[VAL_23:.*]] = hlfir.minval %[[VAL_17]] dim %[[VAL_22]] {fastmath = {{.*}}} : (!hlfir.expr<3x4x!fir.char<1,?>>, i32) -> !hlfir.expr<4x!fir.char<1,4>>
! CHECK: hlfir.assign %[[VAL_23]] to %[[VAL_12]]#0 realloc : !hlfir.expr<4x!fir.char<1,4>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>
! CHECK: hlfir.destroy %[[VAL_23]] : !hlfir.expr<4x!fir.char<1,4>>
! CHECK: hlfir.destroy %[[VAL_17]] : !hlfir.expr<3x4x!fir.char<1,?>>
! CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_12]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>>
! CHECK: %[[VAL_25:.*]] = arith.constant 1 : index
! CHECK: %[[VAL_26:.*]] = fir.shift %[[VAL_25]] : (index) -> !fir.shift<1>
! CHECK: %[[VAL_27:.*]] = fir.rebox %[[VAL_24]](%[[VAL_26]]) : (!fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>, !fir.shift<1>) -> !fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>
! CHECK: return %[[VAL_27]] : !fir.box<!fir.heap<!fir.array<?x!fir.char<1,?>>>>
! CHECK: }

0 comments on commit e9e0167

Please sign in to comment.