Skip to content

Commit

Permalink
[flang] Normalize logical values during type conversions.
Browse files Browse the repository at this point in the history
Flang was missing value normalization for logical<->integer conversions
which is required by Flang specification. The shrinking logical<->logical
conversions were also incorrectly truncating the input.
This change performs value normalization for all logical<->integer
conversions and logical<->logical conversions between different kinds.

Note that value normalization is not strictly required for
logical(kind=k1)->logical(kind=k2) conversions when k1 < k2.

Differential Revision: https://reviews.llvm.org/D147019
  • Loading branch information
vzakhari committed Mar 28, 2023
1 parent dbd99cf commit f9e995b
Show file tree
Hide file tree
Showing 5 changed files with 591 additions and 25 deletions.
4 changes: 3 additions & 1 deletion flang/docs/Extensions.md
Expand Up @@ -188,7 +188,9 @@ end
relax enforcement of some requirements on actual arguments that must otherwise
hold true for definable arguments.
* Assignment of `LOGICAL` to `INTEGER` and vice versa (but not other types) is
allowed. The values are normalized.
allowed. The values are normalized to canonical `.TRUE.`/`.FALSE.`.
The values are also normalized for assignments of `LOGICAL(KIND=K1)` to
`LOGICAL(KIND=K2)`, when `K1 != K2`.
* Static initialization of `LOGICAL` with `INTEGER` is allowed in `DATA` statements
and object initializers.
The results are *not* normalized to canonical `.TRUE.`/`.FALSE.`.
Expand Down
92 changes: 72 additions & 20 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Expand Up @@ -79,17 +79,28 @@ static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter,
mlir::Region::iterator(insertBefore));
}

/// Extract constant from a value that must be the result of one of the
/// ConstantOp operations.
static int64_t getConstantIntValue(mlir::Value val) {
assert(val && val.dyn_cast<mlir::OpResult>() && "must not be null value");
/// Extract constant from a value if it is a result of one of the
/// ConstantOp operations, otherwise, return std::nullopt.
static std::optional<int64_t> getIfConstantIntValue(mlir::Value val) {
if (!val || !val.dyn_cast<mlir::OpResult>())
return {};

mlir::Operation *defop = val.getDefiningOp();

if (auto constOp = mlir::dyn_cast<mlir::arith::ConstantIntOp>(defop))
return constOp.value();
if (auto llConstOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(defop))
if (auto attr = llConstOp.getValue().dyn_cast<mlir::IntegerAttr>())
return attr.getValue().getSExtValue();

return {};
}

/// Extract constant from a value that must be the result of one of the
/// ConstantOp operations.
static int64_t getConstantIntValue(mlir::Value val) {
if (auto constVal = getIfConstantIntValue(val))
return *constVal;
fir::emitFatalError(val.getLoc(), "must be a constant");
}

Expand Down Expand Up @@ -858,11 +869,67 @@ struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
auto fromTy = convertType(fromFirTy);
auto toTy = convertType(toFirTy);
mlir::Value op0 = adaptor.getOperands()[0];
if (fromTy == toTy) {

if (fromFirTy == toFirTy) {
rewriter.replaceOp(convert, op0);
return mlir::success();
}

auto loc = convert.getLoc();
auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);

if (fromFirTy.isa<fir::LogicalType>() || toFirTy.isa<fir::LogicalType>()) {
// By specification fir::LogicalType value may be any number,
// where non-zero value represents .true. and zero value represents
// .false.
//
// integer<->logical conversion requires value normalization.
// Conversion from wide logical to narrow logical must set the result
// to non-zero iff the input is non-zero - the easiest way to implement
// it is to compare the input agains zero and set the result to
// the canonical 0/1.
// Conversion from narrow logical to wide logical may be implemented
// as a zero or sign extension of the input, but it may use value
// normalization as well.
if (!fromTy.isa<mlir::IntegerType>() || !toTy.isa<mlir::IntegerType>())
return mlir::emitError(loc)
<< "unsupported types for logical conversion: " << fromTy
<< " -> " << toTy;

// Do folding for constant inputs.
if (auto constVal = getIfConstantIntValue(op0)) {
mlir::Value normVal =
genConstantIndex(loc, toTy, rewriter, *constVal ? 1 : 0);
rewriter.replaceOp(convert, normVal);
return mlir::success();
}

// If the input is i1, then we can just zero extend it, and
// the result will be normalized.
if (fromTy == i1Type) {
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, op0);
return mlir::success();
}

// Compare the input with zero.
mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0);
auto isTrue = rewriter.create<mlir::LLVM::ICmpOp>(
loc, mlir::LLVM::ICmpPredicate::ne, op0, zero);

// Zero extend the i1 isTrue result to the required type (unless it is i1
// itself).
if (toTy != i1Type)
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, isTrue);
else
rewriter.replaceOp(convert, isTrue.getResult());

return mlir::success();
}

if (fromTy == toTy) {
rewriter.replaceOp(convert, op0);
return mlir::success();
}
auto convertFpToFp = [&](mlir::Value val, unsigned fromBits,
unsigned toBits, mlir::Type toTy) -> mlir::Value {
if (fromBits == toBits) {
Expand Down Expand Up @@ -896,21 +963,6 @@ struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
return mlir::success();
}

// Follow UNIX F77 convention for logicals:
// 1. underlying integer is not zero => logical is .TRUE.
// 2. logical is .TRUE. => set underlying integer to 1.
auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);
if (fromFirTy.isa<fir::LogicalType>() && toFirTy == i1Type) {
mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0);
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
convert, mlir::LLVM::ICmpPredicate::ne, op0, zero);
return mlir::success();
}
if (fromFirTy == i1Type && toFirTy.isa<fir::LogicalType>()) {
rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(convert, toTy, op0);
return mlir::success();
}

// Floating point to floating point conversion.
if (isFloatingPointTy(fromTy)) {
if (isFloatingPointTy(toTy)) {
Expand Down
5 changes: 2 additions & 3 deletions flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Expand Up @@ -490,9 +490,8 @@ func.func @_QPsb() {

// CHECK: omp.reduction.declare @[[EQV_REDUCTION:.*]] : i32 init {
// CHECK: ^bb0(%{{.*}}: i32):
// CHECK: %[[TRUE:.*]] = llvm.mlir.constant(true) : i1
// CHECK: %[[TRUE_EXT:.*]] = llvm.zext %[[TRUE]] : i1 to i32
// CHECK: omp.yield(%[[TRUE_EXT]] : i32)
// CHECK: %[[TRUE:.*]] = llvm.mlir.constant(1 : i64) : i32
// CHECK: omp.yield(%[[TRUE]] : i32)
// CHECK: } combiner {
// CHECK: ^bb0(%[[ARG_1:.*]]: i32, %[[ARG_2:.*]]: i32):
// CHECK: %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Fir/global-initialization.fir
Expand Up @@ -40,7 +40,7 @@ fir.global internal @_QEmasklogical : !fir.array<32768x!fir.logical<4>> {
// CHECK: llvm.mlir.global internal @_QEmasklogical() {addr_space = 0 : i32} : !llvm.array<32768 x i32> {
// CHECK: [[VAL0:%.*]] = llvm.mlir.constant(true) : i1
// CHECK: [[VAL1:%.*]] = llvm.mlir.undef : !llvm.array<32768 x i32>
// CHECK: [[VAL2:%.*]] = llvm.zext [[VAL0]] : i1 to i32
// CHECK: [[VAL2:%.*]] = llvm.mlir.constant(1 : i64) : i32
// CHECK: [[VAL3:%.*]] = llvm.mlir.constant(dense<true> : vector<32768xi1>) : !llvm.array<32768 x i32>
// CHECK: llvm.return [[VAL3]] : !llvm.array<32768 x i32>
// CHECK: }
Expand Down

0 comments on commit f9e995b

Please sign in to comment.