diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index eca309aab4463..9b6e1c958aec7 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -727,8 +727,10 @@ struct ConvertOpConversion : public FIROpConversion { mlir::LogicalResult matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - auto fromTy = convertType(convert.getValue().getType()); - auto toTy = convertType(convert.getRes().getType()); + auto fromFirTy = convert.getValue().getType(); + auto toFirTy = convert.getRes().getType(); + auto fromTy = convertType(fromFirTy); + auto toTy = convertType(toFirTy); mlir::Value op0 = adaptor.getOperands()[0]; if (fromTy == toTy) { rewriter.replaceOp(convert, op0); @@ -750,8 +752,7 @@ struct ConvertOpConversion : public FIROpConversion { return rewriter.create(loc, toTy, val); }; // Complex to complex conversion. - if (fir::isa_complex(convert.getValue().getType()) && - fir::isa_complex(convert.getRes().getType())) { + if (fir::isa_complex(fromFirTy) && fir::isa_complex(toFirTy)) { // Special case: handle the conversion of a complex such that both the // real and imaginary parts are converted together. auto zero = mlir::ArrayAttr::get(convert.getContext(), @@ -773,6 +774,22 @@ struct ConvertOpConversion : public FIROpConversion { ic, one); return mlir::success(); } + + // Follow UNIX F77 convention for logicals: + // 1. underlying integer is not zero => logical is .TRUE. + // 2. logical is .TRUE. => set underlying integer to 1. + auto i1Type = mlir::IntegerType::get(convert.getContext(), 1); + if (fromFirTy.isa() && toFirTy == i1Type) { + mlir::Value zero = genConstantIndex(loc, fromTy, rewriter, 0); + rewriter.replaceOpWithNewOp( + convert, mlir::LLVM::ICmpPredicate::ne, op0, zero); + return mlir::success(); + } + if (fromFirTy == i1Type && toFirTy.isa()) { + rewriter.replaceOpWithNewOp(convert, toTy, op0); + return mlir::success(); + } + // Floating point to floating point conversion. if (isFloatingPointTy(fromTy)) { if (isFloatingPointTy(toTy)) { diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir index aae5275f5ae19..6bf9a026cda69 100644 --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -1288,7 +1288,8 @@ func @select_case_logical(%arg0: !fir.ref>) { // CHECK-LABEL: llvm.func @select_case_logical( // CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr // CHECK: %[[LOAD_ARG0:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -// CHECK: %[[SELECT_VALUE:.*]] = llvm.trunc %[[LOAD_ARG0]] : i32 to i1 +// CHECK: %[[CST_ZERO:.*]] = llvm.mlir.constant(0 : i64) : i32 +// CHECK: %[[SELECT_VALUE:.*]] = llvm.icmp "ne" %[[LOAD_ARG0]], %[[CST_ZERO]] : i32 // CHECK: %[[CST_FALSE:.*]] = llvm.mlir.constant(false) : i1 // CHECK: %[[CST_TRUE:.*]] = llvm.mlir.constant(true) : i1 // CHECK: %[[CMPEQ:.*]] = llvm.icmp "eq" %[[SELECT_VALUE]], %[[CST_FALSE]] : i1