Skip to content

Commit

Permalink
[mlir][Arith] Add arith.is_nan and arith.is_inf predicates
Browse files Browse the repository at this point in the history
Both LLVM and SPIR-V have some form of "is this float a NaN/Inf"
operation (though LLVM's uses the rather opaque "is.fpclass"
intrinsic), which is not exposed in MLIR.

This has lead to awkward workarounds in -arith-expands-ops where a NaN
test was performed by comparing an operation to itself. This commit
resolves that issue.

Reviewed By: dcaballe, kuhar

Differential Revision: https://reviews.llvm.org/D156169
  • Loading branch information
krzysz00 committed Aug 2, 2023
1 parent d6f1880 commit 7c349c3
Show file tree
Hide file tree
Showing 19 changed files with 294 additions and 66 deletions.
40 changes: 20 additions & 20 deletions flang/test/Lower/Intrinsics/ieee_class_queries.f90
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,38 @@
real(10) :: x10 = -10.0
real(16) :: x16 = -16.0

! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 504 : i32}> : (f16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 60 : i32}> : (f16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (f16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 504 : i32}> : (f16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 60 : i32}> : (f16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (f16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f16) -> i1
print*, ieee_is_finite(x2), ieee_is_negative(x2), ieee_is_normal(x2), &
ieee_is_nan(x2)

! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 504 : i32}> : (bf16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 60 : i32}> : (bf16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (bf16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (bf16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 504 : i32}> : (bf16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 60 : i32}> : (bf16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (bf16) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (bf16) -> i1
print*, ieee_is_finite(x3), ieee_is_negative(x3), ieee_is_normal(x3), &
ieee_is_nan(x3)

! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 504 : i32}> : (f32) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 60 : i32}> : (f32) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (f32) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f32) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 504 : i32}> : (f32) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 60 : i32}> : (f32) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (f32) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f32) -> i1
print*, ieee_is_finite(x4), ieee_is_negative(x4), ieee_is_normal(x4), &
ieee_is_nan(x4)

! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 504 : i32}> : (f64) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 60 : i32}> : (f64) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (f64) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f64) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 504 : i32}> : (f64) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 60 : i32}> : (f64) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (f64) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f64) -> i1
print*, ieee_is_finite(x8), ieee_is_negative(x8), ieee_is_normal(x8), &
ieee_is_nan(x8)

! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 504 : i32}> : (f128) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 60 : i32}> : (f128) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (f128) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f128) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 504 : i32}> : (f128) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 60 : i32}> : (f128) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (f128) -> i1
! CHECK: "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f128) -> i1
print*, ieee_is_finite(x16), ieee_is_negative(x16), ieee_is_normal(x16), &
ieee_is_nan(x16)

Expand Down
8 changes: 4 additions & 4 deletions flang/test/Lower/Intrinsics/ieee_is_finite.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ subroutine is_finite_test(x, y)
real(8) y

! CHECK: %[[V_3:[0-9]+]] = fir.load %arg0 : !fir.ref<f32>
! CHECK: %[[V_4:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_3]]) <{bit = 504 : i32}> : (f32) -> i1
! CHECK: %[[V_4:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_3]]) <{fastmathFlags = #llvm.fastmath<none>, kinds = 504 : i32}> : (f32) -> i1
! CHECK: %[[V_5:[0-9]+]] = fir.convert %[[V_4]] : (i1) -> !fir.logical<4>
! CHECK: %[[V_6:[0-9]+]] = fir.convert %[[V_5]] : (!fir.logical<4>) -> i1
! CHECK: %[[V_7:[0-9]+]] = fir.call @_FortranAioOutputLogical(%{{.*}}, %[[V_6]]) {{.*}} : (!fir.ref<i8>, i1) -> i1
Expand All @@ -16,14 +16,14 @@ subroutine is_finite_test(x, y)
! CHECK: %[[V_12:[0-9]+]] = fir.load %arg0 : !fir.ref<f32>
! CHECK: %[[V_13:[0-9]+]] = fir.load %arg0 : !fir.ref<f32>
! CHECK: %[[V_14:[0-9]+]] = arith.addf %[[V_12]], %[[V_13]] {{.*}} : f32
! CHECK: %[[V_15:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_14]]) <{bit = 504 : i32}> : (f32) -> i1
! CHECK: %[[V_15:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_14]]) <{fastmathFlags = #llvm.fastmath<none>, kinds = 504 : i32}> : (f32) -> i1
! CHECK: %[[V_16:[0-9]+]] = fir.convert %[[V_15]] : (i1) -> !fir.logical<4>
! CHECK: %[[V_17:[0-9]+]] = fir.convert %[[V_16]] : (!fir.logical<4>) -> i1
! CHECK: %[[V_18:[0-9]+]] = fir.call @_FortranAioOutputLogical(%{{.*}}, %[[V_17]]) {{.*}} : (!fir.ref<i8>, i1) -> i1
print*, ieee_is_finite(x+x)

! CHECK: %[[V_23:[0-9]+]] = fir.load %arg1 : !fir.ref<f64>
! CHECK: %[[V_24:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_23]]) <{bit = 504 : i32}> : (f64) -> i1
! CHECK: %[[V_24:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_23]]) <{fastmathFlags = #llvm.fastmath<none>, kinds = 504 : i32}> : (f64) -> i1
! CHECK: %[[V_25:[0-9]+]] = fir.convert %[[V_24]] : (i1) -> !fir.logical<4>
! CHECK: %[[V_26:[0-9]+]] = fir.convert %[[V_25]] : (!fir.logical<4>) -> i1
! CHECK: %[[V_27:[0-9]+]] = fir.call @_FortranAioOutputLogical(%{{.*}}, %[[V_26]]) {{.*}} : (!fir.ref<i8>, i1) -> i1
Expand All @@ -32,7 +32,7 @@ subroutine is_finite_test(x, y)
! CHECK: %[[V_32:[0-9]+]] = fir.load %arg1 : !fir.ref<f64>
! CHECK: %[[V_33:[0-9]+]] = fir.load %arg1 : !fir.ref<f64>
! CHECK: %[[V_34:[0-9]+]] = arith.addf %[[V_32]], %[[V_33]] {{.*}} : f64
! CHECK: %[[V_35:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_34]]) <{bit = 504 : i32}> : (f64) -> i1
! CHECK: %[[V_35:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_34]]) <{fastmathFlags = #llvm.fastmath<none>, kinds = 504 : i32}> : (f64) -> i1
! CHECK: %[[V_36:[0-9]+]] = fir.convert %[[V_35]] : (i1) -> !fir.logical<4>
! CHECK: %[[V_37:[0-9]+]] = fir.convert %[[V_36]] : (!fir.logical<4>) -> i1
! CHECK: %[[V_38:[0-9]+]] = fir.call @_FortranAioOutputLogical(%{{.*}}, %[[V_37]]) {{.*}} : (!fir.ref<i8>, i1) -> i1
Expand Down
12 changes: 6 additions & 6 deletions flang/test/Lower/Intrinsics/ieee_is_normal.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ subroutine ieee_is_normal_f16(r)
use ieee_arithmetic
real(KIND=2) :: r
i = ieee_is_normal(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (f16) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (f16) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine ieee_is_normal_f16

Expand All @@ -15,7 +15,7 @@ subroutine ieee_is_normal_bf16(r)
use ieee_arithmetic
real(KIND=3) :: r
i = ieee_is_normal(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (bf16) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (bf16) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine ieee_is_normal_bf16

Expand All @@ -26,7 +26,7 @@ subroutine ieee_is_normal_f32(r)
use ieee_arithmetic
real :: r
i = ieee_is_normal(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (f32) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (f32) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine ieee_is_normal_f32

Expand All @@ -35,7 +35,7 @@ subroutine ieee_is_normal_f64(r)
use ieee_arithmetic
real(KIND=8) :: r
i = ieee_is_normal(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (f64) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (f64) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine ieee_is_normal_f64

Expand All @@ -44,7 +44,7 @@ subroutine ieee_is_normal_f80(r)
use ieee_arithmetic
real(KIND=10) :: r
i = ieee_is_normal(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (f80) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (f80) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine ieee_is_normal_f80

Expand All @@ -53,6 +53,6 @@ subroutine ieee_is_normal_f128(r)
use ieee_arithmetic
real(KIND=16) :: r
i = ieee_is_normal(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 360 : i32}> : (f128) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 360 : i32}> : (f128) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine ieee_is_normal_f128
12 changes: 6 additions & 6 deletions flang/test/Lower/Intrinsics/ieee_unordered.f90
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,26 @@

! CHECK: %[[V_40:[0-9]+]] = fir.load %[[V_2]] : !fir.ref<f128>
! CHECK: %[[V_41:[0-9]+]] = fir.load %[[V_3]] : !fir.ref<f128>
! CHECK: %[[V_42:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_40]]) <{bit = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_43:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_41]]) <{bit = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_42:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_40]]) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_43:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_41]]) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_44:[0-9]+]] = arith.ori %[[V_42]], %[[V_43]] : i1
! CHECK: %[[V_45:[0-9]+]] = fir.convert %[[V_44]] : (i1) -> !fir.logical<4>
! CHECK: %[[V_46:[0-9]+]] = fir.convert %[[V_45]] : (!fir.logical<4>) -> i1
! CHECK: %[[V_47:[0-9]+]] = fir.call @_FortranAioOutputLogical(%{{.*}}, %[[V_46]]) {{.*}} : (!fir.ref<i8>, i1) -> i1

! CHECK: %[[V_48:[0-9]+]] = fir.load %[[V_2]] : !fir.ref<f128>
! CHECK: %[[V_49:[0-9]+]] = fir.load %[[V_4]] : !fir.ref<f128>
! CHECK: %[[V_50:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_48]]) <{bit = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_51:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_49]]) <{bit = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_50:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_48]]) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_51:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_49]]) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_52:[0-9]+]] = arith.ori %[[V_50]], %[[V_51]] : i1
! CHECK: %[[V_53:[0-9]+]] = fir.convert %[[V_52]] : (i1) -> !fir.logical<4>
! CHECK: %[[V_54:[0-9]+]] = fir.convert %[[V_53]] : (!fir.logical<4>) -> i1
! CHECK: %[[V_55:[0-9]+]] = fir.call @_FortranAioOutputLogical(%{{.*}}, %[[V_54]]) {{.*}} : (!fir.ref<i8>, i1) -> i1

! CHECK: %[[V_56:[0-9]+]] = fir.load %[[V_3]] : !fir.ref<f128>
! CHECK: %[[V_57:[0-9]+]] = fir.load %[[V_4]] : !fir.ref<f128>
! CHECK: %[[V_58:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_56]]) <{bit = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_59:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_57]]) <{bit = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_58:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_56]]) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_59:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_57]]) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f128) -> i1
! CHECK: %[[V_60:[0-9]+]] = arith.ori %[[V_58]], %[[V_59]] : i1
! CHECK: %[[V_61:[0-9]+]] = fir.convert %[[V_60]] : (i1) -> !fir.logical<4>
! CHECK: %[[V_62:[0-9]+]] = fir.convert %[[V_61]] : (!fir.logical<4>) -> i1
Expand Down
16 changes: 8 additions & 8 deletions flang/test/Lower/Intrinsics/isnan.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
subroutine isnan_f32(r)
real :: r
i = isnan(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f32) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f32) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine isnan_f32

Expand All @@ -14,15 +14,15 @@ subroutine ieee_is_nan_f32(r)
use ieee_arithmetic
real :: r
i = ieee_is_nan(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f32) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f32) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine ieee_is_nan_f32

! CHECK-LABEL: isnan_f64
subroutine isnan_f64(r)
real(KIND=8) :: r
i = isnan(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f64) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f64) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine isnan_f64

Expand All @@ -31,15 +31,15 @@ subroutine ieee_is_nan_f64(r)
use ieee_arithmetic
real(KIND=8) :: r
i = ieee_is_nan(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f64) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f64) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine ieee_is_nan_f64

! CHECK-LABEL: isnan_f80
subroutine isnan_f80(r)
real(KIND=10) :: r
i = isnan(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f80) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f80) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine isnan_f80

Expand All @@ -48,15 +48,15 @@ subroutine ieee_is_nan_f80(r)
use ieee_arithmetic
real(KIND=10) :: r
i = ieee_is_nan(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f80) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f80) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine ieee_is_nan_f80

! CHECK-LABEL: isnan_f128
subroutine isnan_f128(r)
real(KIND=16) :: r
i = isnan(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f128) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f128) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine isnan_f128

Expand All @@ -65,6 +65,6 @@ subroutine ieee_is_nan_f128(r)
use ieee_arithmetic
real(KIND=16) :: r
i = ieee_is_nan(r)
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{bit = 3 : i32}> : (f128) -> i1
! CHECK: %[[l:.*]] = "llvm.intr.is.fpclass"(%{{.*}}) <{fastmathFlags = #llvm.fastmath<none>, kinds = 3 : i32}> : (f128) -> i1
! CHECK: fir.convert %[[l]] : (i1) -> !fir.logical<4>
end subroutine ieee_is_nan_f128
22 changes: 21 additions & 1 deletion mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,29 @@ class AttrConvertFastMathToLLVM {

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }

private:
protected:
NamedAttrList convertedAttr;
};

/// Wrapper around AttrConvertFastMathToLLVM that also sets the "kinds"
/// attribute to the bitmask specified in `Kinds`, which is used for converting
/// operations that lower to llvm.is.fpclass.
template <unsigned Kinds, typename SourceOp, typename TargetOp>
class AttrConvertAddFpclassKinds
: public AttrConvertFastMathToLLVM<SourceOp, TargetOp> {
public:
AttrConvertAddFpclassKinds(SourceOp op)
: AttrConvertFastMathToLLVM<SourceOp, TargetOp>(op) {
convertedAttr.set(
"kinds",
IntegerAttr::get(IntegerType::get(op.getContext(), 32), Kinds));
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }

protected:
using AttrConvertFastMathToLLVM<SourceOp, TargetOp>::convertedAttr;
};
} // namespace arith
} // namespace mlir

Expand Down
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
attr-dict `:` type($result) }];
}

// Base class for floating point unary predicate operations.
class Arith_FloatPredicateOp<string mnemonic, list<Trait> traits = []> :
Arith_Op<mnemonic,
!listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>, Pure,
TypesMatchWith<"result type has i1 element type and same shape as operands",
"operand", "result", "::getI1SameShape($_self)">], traits)>,
Arguments<(ins FloatLike:$operand,
DefaultValuedAttr<
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
Results<(outs BoolLike:$result)> {
let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
attr-dict `:` type($operand) }];
}

// Base class for arithmetic cast operations. Requires a single operand and
// result. If either is a shaped type, then the other must be of the same shape.
class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
Expand Down Expand Up @@ -971,6 +985,22 @@ def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// IsNanOp
//===----------------------------------------------------------------------===//
def Arith_IsNanOp : Arith_FloatPredicateOp<"is_nan"> {
let summary = "Returns true for IEEE NaN inputs";
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// IsInfOp
//===----------------------------------------------------------------------===//
def Arith_IsInfOp : Arith_FloatPredicateOp<"is_inf"> {
let summary = "Returns true for infinite float inputs";
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 9 additions & 5 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,21 @@ def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure]> {
}];
}

def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure]> {
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$in, I32Attr:$bit);
def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure],
/*requiresFastmath=*/1> {
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$in,
I32Attr:$kinds,
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>($_location,
$_resultType, $in, $_int_attr($bit));
$res = op;
$_resultType, $in, $_int_attr($kinds), nullptr);
moduleImport.setFastmathFlagsAttr(inst, op);
$res = op;
}];
string llvmBuilder = [{
auto *inst = createIntrinsicCall(
builder, llvm::Intrinsic::}] # llvmEnumName # [{,
{$in, builder.getInt32(op.getBit())},
{$in, builder.getInt32(op.getKinds())},
}] # declTypes # [{);
$res = inst;
}];
Expand Down

0 comments on commit 7c349c3

Please sign in to comment.