Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flang] Attempt to fix Nan handling in Minloc/Maxloc intrinsic simplification #82313

Merged
merged 1 commit into from
Feb 21, 2024

Conversation

davemgreen
Copy link
Collaborator

In certain case "extreme" values like Nan, Inf and 0xffffffff could lead to generating different code via the inline-generated intrinsics vs the versions in the runtimes (and other compilers like gfortran). There are some examples I was using for testing in https://godbolt.org/z/x4EfqEss5.

This changes the generation for the intrinsics to be more like the runtimes, using a condition that is similar to:
isFirst || (prev != prev && elem == elem) || elem < prev
The middle part is only used for floating point operations, and checks if the values are Nan. This should then hopefully make the logic closer to - return the first element with the lowest value, with Nans ignored unless there are only Nans. The initial limit value for floats are also changed from the largest float to Inf, to make sure it is handled correctly.

The integer reductions are also changed to use a similar scheme to make sure they work with masked values. This means that the preamble after the loop can be removed.

…fication.

In certain case "extreme" values like Nan, Inf and 0xffffffff could lead to
generating different code via the inline-generated intrinsics vs the versions
in the runtimes (and other compilers like gfortran). There are some examples I
was using for testing in https://godbolt.org/z/x4EfqEss5.

This changes the generation for the intrinsics to be more like the runtimes,
using a condition that is similar to:
  isFirst || (prev != prev && elem == elem) || elem < prev
The middle part is only used for floating point operations, and checks if the
values are Nan. This should then hopefully make the logic closer to - return
the first element with the lowest value, with Nans ignored unless there are
only Nans. The initial limit value for floats are also changed from the largest
float to Inf, to make sure it is handled correctly.

The integer reductions are also changed to use a similar scheme to make sure
they work with masked values. This means that the preamble after the loop can
be removed.
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Feb 20, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 20, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: David Green (davemgreen)

Changes

In certain case "extreme" values like Nan, Inf and 0xffffffff could lead to generating different code via the inline-generated intrinsics vs the versions in the runtimes (and other compilers like gfortran). There are some examples I was using for testing in https://godbolt.org/z/x4EfqEss5.

This changes the generation for the intrinsics to be more like the runtimes, using a condition that is similar to:
isFirst || (prev != prev && elem == elem) || elem < prev
The middle part is only used for floating point operations, and checks if the values are Nan. This should then hopefully make the logic closer to - return the first element with the lowest value, with Nans ignored unless there are only Nans. The initial limit value for floats are also changed from the largest float to Inf, to make sure it is handled correctly.

The integer reductions are also changed to use a similar scheme to make sure they work with masked values. This means that the preamble after the loop can be removed.


Patch is 45.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82313.diff

5 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp (+21-4)
  • (modified) flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp (+20-39)
  • (modified) flang/test/HLFIR/maxloc-elemental.fir (+19-15)
  • (modified) flang/test/HLFIR/minloc-elemental.fir (+33-37)
  • (modified) flang/test/Transforms/simplifyintrinsics.fir (+43-71)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index c2512c7df32f46..685c73d6762570 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -852,9 +852,8 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
                         mlir::Type elementType) {
       if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
         const llvm::fltSemantics &sem = ty.getFloatSemantics();
-        return builder.createRealConstant(
-            loc, elementType,
-            llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
+        llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
+        return builder.createRealConstant(loc, elementType, limit);
       }
       unsigned bits = elementType.getIntOrFloatBitWidth();
       int64_t limitInt =
@@ -895,7 +894,7 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
       // Set flag that mask was true at some point
       mlir::Value flagSet = builder.createIntegerConstant(
           loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
-      builder.create<fir::StoreOp>(loc, flagSet, flagRef);
+      mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
       mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array},
                                              oneBasedIndices);
       mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
@@ -903,11 +902,22 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
       // Compare with the max reduction value
       mlir::Value cmp;
       if (elementType.isa<mlir::FloatType>()) {
+        // For FP reductions we want the first smallest value to be used, that
+        // is not NaN. A OGL/OLT condition will usually work for this unless all
+        // the values are Nan or Inf. This follows the same logic as
+        // NumericCompare for Minloc/Maxlox in extrema.cpp.
         cmp = builder.create<mlir::arith::CmpFOp>(
             loc,
             isMax ? mlir::arith::CmpFPredicate::OGT
                   : mlir::arith::CmpFPredicate::OLT,
             elem, reduction);
+
+        mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
+            loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
+        mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
+            loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
+        cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
+        cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
       } else if (elementType.isa<mlir::IntegerType>()) {
         cmp = builder.create<mlir::arith::CmpIOp>(
             loc,
@@ -918,11 +928,18 @@ class MinMaxlocElementalConversion : public mlir::OpRewritePattern<Op> {
         llvm_unreachable("unsupported type");
       }
 
+      // The condition used for the loop is isFirst || <the condition above>.
+      isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
+      isFirst = builder.create<mlir::arith::XOrIOp>(
+          loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
+      cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
+
       // Set the new coordinate to the result
       fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
                                                  /*withElseRegion*/ true);
 
       builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      builder.create<fir::StoreOp>(loc, flagSet, flagRef);
       mlir::Type resultElemTy =
           hlfir::getFortranElementType(resultArr.getType());
       mlir::Type returnRefTy = builder.getRefType(resultElemTy);
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 86343e23c6e5db..f483651a68dc17 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -649,42 +649,6 @@ void fir::genMinMaxlocReductionLoop(
       reductionVal = ifOp.getResult(0);
     }
   }
-
-  // Check for case where array was full of max values.
-  // flag will be 0 if mask was never true, 1 if mask was true as some point,
-  // this is needed to avoid catching cases where we didn't access any elements
-  // e.g. mask=.FALSE.
-  mlir::Value flagValue =
-      builder.create<fir::LoadOp>(loc, resultElemType, flagRef);
-  mlir::Value flagCmp = builder.create<mlir::arith::CmpIOp>(
-      loc, mlir::arith::CmpIPredicate::eq, flagValue, flagSet);
-  fir::IfOp ifMaskTrueOp =
-      builder.create<fir::IfOp>(loc, flagCmp, /*withElseRegion=*/false);
-  builder.setInsertionPointToStart(&ifMaskTrueOp.getThenRegion().front());
-
-  mlir::Value testInit = initVal(builder, loc, elementType);
-  fir::IfOp ifMinSetOp;
-  if (elementType.isa<mlir::FloatType>()) {
-    mlir::Value cmp = builder.create<mlir::arith::CmpFOp>(
-        loc, mlir::arith::CmpFPredicate::OEQ, testInit, reductionVal);
-    ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
-                                           /*withElseRegion*/ false);
-  } else {
-    mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
-        loc, mlir::arith::CmpIPredicate::eq, testInit, reductionVal);
-    ifMinSetOp = builder.create<fir::IfOp>(loc, cmp,
-                                           /*withElseRegion*/ false);
-  }
-  builder.setInsertionPointToStart(&ifMinSetOp.getThenRegion().front());
-
-  // Load output array with 1s instead of 0s
-  for (unsigned int i = 0; i < rank; ++i) {
-    mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
-    mlir::Value resultElemAddr =
-        getAddrFn(builder, loc, resultElemType, resultArr, index);
-    builder.create<fir::StoreOp>(loc, flagSet, resultElemAddr);
-  }
-  builder.setInsertionPointAfter(ifMaskTrueOp);
 }
 
 static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
@@ -697,8 +661,8 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
                       mlir::Type elementType) {
     if (auto ty = elementType.dyn_cast<mlir::FloatType>()) {
       const llvm::fltSemantics &sem = ty.getFloatSemantics();
-      return builder.createRealConstant(
-          loc, elementType, llvm::APFloat::getLargest(sem, /*Negative=*/isMax));
+      llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax);
+      return builder.createRealConstant(loc, elementType, limit);
     }
     unsigned bits = elementType.getIntOrFloatBitWidth();
     int64_t initValue = (isMax ? llvm::APInt::getSignedMinValue(bits)
@@ -770,7 +734,7 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
     // Set flag that mask was true at some point
     mlir::Value flagSet = builder.createIntegerConstant(
         loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
-    builder.create<fir::StoreOp>(loc, flagSet, flagRef);
+    mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
     mlir::Type eleRefTy = builder.getRefType(elementType);
     mlir::Value addr =
         builder.create<fir::CoordinateOp>(loc, eleRefTy, array, indices);
@@ -778,11 +742,22 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
 
     mlir::Value cmp;
     if (elementType.isa<mlir::FloatType>()) {
+      // For FP reductions we want the first smallest value to be used, that
+      // is not NaN. A OGL/OLT condition will usually work for this unless all
+      // the values are Nan or Inf. This follows the same logic as
+      // NumericCompare for Minloc/Maxlox in extrema.cpp.
       cmp = builder.create<mlir::arith::CmpFOp>(
           loc,
           isMax ? mlir::arith::CmpFPredicate::OGT
                 : mlir::arith::CmpFPredicate::OLT,
           elem, reduction);
+
+      mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
+          loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
+      mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
+          loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
+      cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
+      cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
     } else if (elementType.isa<mlir::IntegerType>()) {
       cmp = builder.create<mlir::arith::CmpIOp>(
           loc,
@@ -793,10 +768,16 @@ static void genRuntimeMinMaxlocBody(fir::FirOpBuilder &builder,
       llvm_unreachable("unsupported type");
     }
 
+    // The condition used for the loop is isFirst || <the condition above>.
+    isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst);
+    isFirst = builder.create<mlir::arith::XOrIOp>(
+        loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1));
+    cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
     fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp,
                                                /*withElseRegion*/ true);
 
     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    builder.create<fir::StoreOp>(loc, flagSet, flagRef);
     mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
     mlir::Type returnRefTy = builder.getRefType(resultElemTy);
     mlir::IndexType idxTy = builder.getIndexType();
diff --git a/flang/test/HLFIR/maxloc-elemental.fir b/flang/test/HLFIR/maxloc-elemental.fir
index b4a3ca0d86068f..c97117dd10de13 100644
--- a/flang/test/HLFIR/maxloc-elemental.fir
+++ b/flang/test/HLFIR/maxloc-elemental.fir
@@ -23,6 +23,7 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
   return
 }
 // CHECK-LABEL: func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
+// CHECK-NEXT:    %true = arith.constant true
 // CHECK-NEXT:    %c-2147483648_i32 = arith.constant -2147483648 : i32
 // CHECK-NEXT:    %c1_i32 = arith.constant 1 : i32
 // CHECK-NEXT:    %c0 = arith.constant 0 : index
@@ -45,14 +46,18 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
 // CHECK-NEXT:      %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<i32>
 // CHECK-NEXT:      %[[V17:.*]] = arith.cmpi sge, %[[V16]], %[[V4]] : i32
 // CHECK-NEXT:      %[[V18:.*]] = fir.if %[[V17]] -> (i32) {
-// CHECK-NEXT:        fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
+// CHECK-NEXT:        %[[ISFIRST:.*]] = fir.load %[[V0]] : !fir.ref<i32>
 // CHECK-NEXT:        %[[DIMS:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK-NEXT:        %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
 // CHECK-NEXT:        %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
 // CHECK-NEXT:        %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
 // CHECK-NEXT:        %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<i32>
 // CHECK-NEXT:        %[[V21:.*]] = arith.cmpi sgt, %[[V20]], %arg4 : i32
-// CHECK-NEXT:        %[[V22:.*]] = fir.if %[[V21]] -> (i32) {
+// CHECK-NEXT:        %[[ISFIRSTL:.*]] = fir.convert %[[ISFIRST]] : (i32) -> i1
+// CHECK-NEXT:        %[[ISFIRSTNOT:.*]] = arith.xori %[[ISFIRSTL]], %true : i1
+// CHECK-NEXT:        %[[ORCOND:.*]] = arith.ori %[[V21]], %[[ISFIRSTNOT]] : i1
+// CHECK-NEXT:        %[[V22:.*]] = fir.if %[[ORCOND]] -> (i32) {
+// CHECK-NEXT:          fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
 // CHECK-NEXT:          %[[V23:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
 // CHECK-NEXT:          %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
 // CHECK-NEXT:          fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
@@ -66,15 +71,6 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
 // CHECK-NEXT:      }
 // CHECK-NEXT:      fir.result %[[V18]] : i32
 // CHECK-NEXT:    }
-// CHECK-NEXT:    %[[V12:.*]] = fir.load %[[V0]] : !fir.ref<i32>
-// CHECK-NEXT:    %[[V13:.*]] = arith.cmpi eq, %[[V12]], %c1_i32 : i32
-// CHECK-NEXT:    fir.if %[[V13]] {
-// CHECK-NEXT:      %[[V14:.*]] = arith.cmpi eq, %[[V11]], %c-2147483648_i32 : i32
-// CHECK-NEXT:      fir.if %[[V14]] {
-// CHECK-NEXT:        %[[V15:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
-// CHECK-NEXT:        fir.store %c1_i32 to %[[V15]] : !fir.ref<i32>
-// CHECK-NEXT:      }
-// CHECK-NEXT:    }
 // CHECK-NEXT:    %[[BD:.*]]:3 = fir.box_dims %[[V2]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK-NEXT:    fir.do_loop %arg3 = %c1 to %[[BD]]#1 step %c1 unordered {
 // CHECK-NEXT:      %[[V13:.*]] = hlfir.designate %[[RES]] (%arg3)  : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
@@ -110,21 +106,29 @@ func.func @_QPtest_float(%arg0: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "a
   return
 }
 // CHECK-LABEL: _QPtest_float
-// CHECK:        %cst = arith.constant -3.40282347E+38 : f32
+// CHECK:        %cst = arith.constant 0xFF800000 : f32
 // CHECK:        %[[V11:.*]] = fir.do_loop %arg3 = %c0 to %[[V10:.*]] step %c1 iter_args(%arg4 = %cst) -> (f32) {
 // CHECK-NEXT:     %[[V14:.*]] = arith.addi %arg3, %c1 : index
 // CHECK-NEXT:     %[[V15:.*]] = hlfir.designate %[[V1:.*]]#0 (%[[V14]])  : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
 // CHECK-NEXT:     %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<f32>
 // CHECK-NEXT:     %[[V17:.*]] = arith.cmpf oge, %[[V16]], %[[V4:.*]] : f32
 // CHECK-NEXT:     %[[V18:.*]] = fir.if %[[V17]] -> (f32) {
-// CHECK-NEXT:       fir.store %c1_i32 to %[[V0:.*]] : !fir.ref<i32>
+// CHECK-NEXT:       %[[ISFIRST:.*]] = fir.load %[[V0:.*]] : !fir.ref<i32>
 // CHECK-NEXT:       %[[DIMS:.*]]:3 = fir.box_dims %2#0, %c0 : (!fir.box<!fir.array<?xf32>>, index) -> (index, index, index)
 // CHECK-NEXT:       %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
 // CHECK-NEXT:       %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
 // CHECK-NEXT:       %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
 // CHECK-NEXT:       %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<f32>
-// CHECK-NEXT:       %[[V21:.*]] = arith.cmpf ogt, %[[V20]], %arg4 fastmath<contract> : f32
-// CHECK-NEXT:       %[[V22:.*]] = fir.if %[[V21]] -> (f32) {
+// CHECK-NEXT:       %[[NEW_MIN:.*]] = arith.cmpf ogt, %[[V20]], %arg4 fastmath<contract> : f32
+// CHECK-NEXT:       %[[CONDRED:.*]] = arith.cmpf une, %arg4, %arg4 fastmath<contract> : f32
+// CHECK-NEXT:       %[[CONDELEM:.*]] = arith.cmpf oeq, %[[V20]], %[[V20]] fastmath<contract> : f32
+// CHECK-NEXT:       %[[ANDCOND:.*]] = arith.andi %[[CONDRED]], %[[CONDELEM]] : i1
+// CHECK-NEXT:       %[[NEW_MIN2:.*]] = arith.ori %[[NEW_MIN]], %[[ANDCOND]] : i1
+// CHECK-NEXT:       %[[ISFIRSTL:.*]] = fir.convert %[[ISFIRST]] : (i32) -> i1
+// CHECK-NEXT:       %[[ISFIRSTNOT:.*]] = arith.xori %[[ISFIRSTL]], %true : i1
+// CHECK-NEXT:       %[[ORCOND:.*]] = arith.ori %[[NEW_MIN2]], %[[ISFIRSTNOT]] : i1
+// CHECK-NEXT:       %[[V22:.*]] = fir.if %[[ORCOND]] -> (f32) {
+// CHECK-NEXT:         fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
 // CHECK-NEXT:         %[[V23:.*]] = hlfir.designate %{{.}} (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
 // CHECK-NEXT:         %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
 // CHECK-NEXT:         fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
diff --git a/flang/test/HLFIR/minloc-elemental.fir b/flang/test/HLFIR/minloc-elemental.fir
index 5cc608b65be8bc..58cfe3ea012793 100644
--- a/flang/test/HLFIR/minloc-elemental.fir
+++ b/flang/test/HLFIR/minloc-elemental.fir
@@ -23,6 +23,7 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
   return
 }
 // CHECK-LABEL: func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "m"}) {
+// CHECK-NEXT:    %true = arith.constant true
 // CHECK-NEXT:    %c2147483647_i32 = arith.constant 2147483647 : i32
 // CHECK-NEXT:    %c1_i32 = arith.constant 1 : i32
 // CHECK-NEXT:    %c0 = arith.constant 0 : index
@@ -45,14 +46,18 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
 // CHECK-NEXT:      %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<i32>
 // CHECK-NEXT:      %[[V17:.*]] = arith.cmpi sge, %[[V16]], %[[V4]] : i32
 // CHECK-NEXT:      %[[V18:.*]] = fir.if %[[V17]] -> (i32) {
-// CHECK-NEXT:        fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
+// CHECK-NEXT:        %[[ISFIRST:.*]] = fir.load %[[V0]] : !fir.ref<i32>
 // CHECK-NEXT:        %[[DIMS:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK-NEXT:        %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
 // CHECK-NEXT:        %[[ADD:.*]] = arith.addi %[[V14]], %[[SUB]] : index
 // CHECK-NEXT:        %[[V19:.*]] = hlfir.designate %[[V1]]#0 (%[[ADD]]) : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
 // CHECK-NEXT:        %[[V20:.*]] = fir.load %[[V19]] : !fir.ref<i32>
 // CHECK-NEXT:        %[[V21:.*]] = arith.cmpi slt, %[[V20]], %arg4 : i32
-// CHECK-NEXT:        %[[V22:.*]] = fir.if %[[V21]] -> (i32) {
+// CHECK-NEXT:        %[[ISFIRSTL:.*]] = fir.convert %[[ISFIRST]] : (i32) -> i1
+// CHECK-NEXT:        %[[ISFIRSTNOT:.*]] = arith.xori %[[ISFIRSTL]], %true : i1
+// CHECK-NEXT:        %[[ORCOND:.*]] = arith.ori %[[V21]], %[[ISFIRSTNOT]] : i1
+// CHECK-NEXT:        %[[V22:.*]] = fir.if %[[ORCOND]] -> (i32) {
+// CHECK-NEXT:          fir.store %c1_i32 to %[[V0]] : !fir.ref<i32>
 // CHECK-NEXT:          %[[V23:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
 // CHECK-NEXT:          %[[V24:.*]] = fir.convert %[[V14]] : (index) -> i32
 // CHECK-NEXT:          fir.store %[[V24]] to %[[V23]] : !fir.ref<i32>
@@ -66,15 +71,6 @@ func.func @_QPtest(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}
 // CHECK-NEXT:      }
 // CHECK-NEXT:      fir.result %[[V18]] : i32
 // CHECK-NEXT:    }
-// CHECK-NEXT:    %[[V12:.*]] = fir.load %[[V0]] : !fir.ref<i32>
-// CHECK-NEXT:    %[[V13:.*]] = arith.cmpi eq, %[[V12]], %c1_i32 : i32
-// CHECK-NEXT:    fir.if %[[V13]] {
-// CHECK-NEXT:      %[[V14:.*]] = arith.cmpi eq, %[[V11]], %c2147483647_i32 : i32
-// CHECK-NEXT:      fir.if %[[V14]] {
-// CHECK-NEXT:        %[[V15:.*]] = hlfir.designate %[[RES]] (%c1) : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
-// CHECK-NEXT:        fir.store %c1_i32 to %[[V15]] : !fir.ref<i32>
-// CHECK-NEXT:      }
-// CHECK-NEXT:    }
 // CHECK-NEXT:    %[[BD:.*]]:3 = fir.box_dims %[[V2]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK-NEXT:    fir.do_loop %arg3 = %c1 to %[[BD]]#1 step %c1 unordered {
 // CHECK-NEXT:      %[[V13:.*]] = hlfir.designate %[[RES]] (%arg3)  : (!fir.ref<!fir.array<1xi32>>, index) -> !fir.ref<i32>
@@ -109,6 +105,7 @@ func.func @_QPtest_kind2(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a
   return
 }
 // CHECK-LABEL:  func.func @_QPtest_kind2(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "array"}, %arg1: !fir.ref<i32> {fir.bindc_name = "val"}, %arg2: !fir.box<!fir.array<?xi16>> {fir.bindc_name = "m"}) {
+// CHECK-NEXT:    %true = arith.constant true
 // CHECK-NEXT:    %c2147483647_i32 = arith.constant 2147483647 : i32
 // CHECK-NEXT:    %c1_i16 = arith.constant 1 : i16
 // CHECK-NEXT:    %c0 = arith.constant 0 : index
@@ -131,14 +128,18 @@ func.func @_QPtest_kind2(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "a
 // CHECK-NEXT:      %[[V16:.*]] = fir.load %[[V15]] : !fir.ref<i32>
 // CHECK-NEXT:      %[[V17:.*]] = arith.cmpi sge, %[[V16]], %[[V4]] : i32
 // CHECK-NEXT:      %[[V18:.*]] = fir.if %[[V17]] -> (i32) {
-// CHECK-NEXT:        fir.store %c1_i16 to %[[V0]] : !fir.ref<i16>
+// CHECK-NEXT:        %[[ISFIRST:.*]] = fir.load %[[V0]] : !fir.ref<i16>
 // CHECK-NEXT:        %[[DIMS:.*]]:3 = fir.box_dims %[[V1]]#0, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
 // CHECK-NEXT:        %[[SUB:.*]] = arith.subi %[[DIMS]]#0, %c1 : index
 // CHECK-NEXT:        %[[ADD:.*]] = arith.addi %[[V14]], %[[S...
[truncated]

Copy link
Contributor

@psteinfeld psteinfeld left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes pass all of our internal tests.

But you should get someone who understands the code to review and approve before merging.

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good to me. Thank you for the changes, Dave!

Pete, thank you for testing it!

@davemgreen
Copy link
Collaborator Author

It looks good to me. Thank you for the changes, Dave!

Pete, thank you for testing it!

Yeah very much so, Thanks for your patience! #81619 has run into many more issues than I expected, that I didn't see in downstream testing.

@davemgreen davemgreen merged commit 7242896 into llvm:main Feb 21, 2024
7 checks passed
@davemgreen davemgreen deleted the gh-flang-fixnan branch February 21, 2024 09:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants