diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index d9144d0c5e228..64446fd8146f3 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1544,6 +1544,26 @@ class TanhPattern : public SPIRVToLLVMConversion { } }; +// `llvm.intr.abs` requires an `is_int_min_poison` immarg that `spirv.GL.SAbs` +// does not carry; default to `false` to preserve SPIR-V's well-defined +// behavior on INT_MIN. +class SAbsPattern : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::GLSAbsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return rewriter.notifyMatchFailure(op, "type conversion failed"); + + rewriter.replaceOpWithNewOp(op, dstType, adaptor.getOperand(), + /*is_int_min_poison=*/false); + return success(); + } +}; + class VariablePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; @@ -1867,14 +1887,18 @@ void mlir::populateSPIRVToLLVMConversionPatterns( DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, + DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, + DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, - InverseSqrtPattern, TanPattern, TanhPattern, + DirectConversionPattern, + DirectConversionPattern, + InverseSqrtPattern, SAbsPattern, TanPattern, TanhPattern, // Logical ops DirectConversionPattern, diff --git a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir index b17e1c40cb9a7..048519e391cb7 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir @@ -117,6 +117,45 @@ spirv.func @sin(%arg0: f32, %arg1: vector<3xf16>) "None" { spirv.Return } +//===----------------------------------------------------------------------===// +// spirv.GL.Pow +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @pow +spirv.func @pow(%arg0: f32, %arg1: vector<3xf16>) "None" { + // CHECK: llvm.intr.pow(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 + %0 = spirv.GL.Pow %arg0, %arg0 : f32 + // CHECK: llvm.intr.pow(%{{.*}}, %{{.*}}) : (vector<3xf16>, vector<3xf16>) -> vector<3xf16> + %1 = spirv.GL.Pow %arg1, %arg1 : vector<3xf16> + spirv.Return +} + +//===----------------------------------------------------------------------===// +// spirv.GL.Fma +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @fma +spirv.func @fma(%arg0: f32, %arg1: vector<3xf16>) "None" { + // CHECK: llvm.intr.fma(%{{.*}}, %{{.*}}, %{{.*}}) : (f32, f32, f32) -> f32 + %0 = spirv.GL.Fma %arg0, %arg0, %arg0 : f32 + // CHECK: llvm.intr.fma(%{{.*}}, %{{.*}}, %{{.*}}) : (vector<3xf16>, vector<3xf16>, vector<3xf16>) -> vector<3xf16> + %1 = spirv.GL.Fma %arg1, %arg1, %arg1 : vector<3xf16> + spirv.Return +} + +//===----------------------------------------------------------------------===// +// spirv.GL.SAbs +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @sabs +spirv.func @sabs(%arg0: i16, %arg1: vector<3xi32>) "None" { + // CHECK: "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (i16) -> i16 + %0 = spirv.GL.SAbs %arg0 : i16 + // CHECK: "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (vector<3xi32>) -> vector<3xi32> + %1 = spirv.GL.SAbs %arg1 : vector<3xi32> + spirv.Return +} + //===----------------------------------------------------------------------===// // spirv.GL.SMax //===----------------------------------------------------------------------===// @@ -143,6 +182,32 @@ spirv.func @smin(%arg0: i16, %arg1: vector<3xi32>) "None" { spirv.Return } +//===----------------------------------------------------------------------===// +// spirv.GL.UMax +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @umax +spirv.func @umax(%arg0: i16, %arg1: vector<3xi32>) "None" { + // CHECK: llvm.intr.umax(%{{.*}}, %{{.*}}) : (i16, i16) -> i16 + %0 = spirv.GL.UMax %arg0, %arg0 : i16 + // CHECK: llvm.intr.umax(%{{.*}}, %{{.*}}) : (vector<3xi32>, vector<3xi32>) -> vector<3xi32> + %1 = spirv.GL.UMax %arg1, %arg1 : vector<3xi32> + spirv.Return +} + +//===----------------------------------------------------------------------===// +// spirv.GL.UMin +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @umin +spirv.func @umin(%arg0: i16, %arg1: vector<3xi32>) "None" { + // CHECK: llvm.intr.umin(%{{.*}}, %{{.*}}) : (i16, i16) -> i16 + %0 = spirv.GL.UMin %arg0, %arg0 : i16 + // CHECK: llvm.intr.umin(%{{.*}}, %{{.*}}) : (vector<3xi32>, vector<3xi32>) -> vector<3xi32> + %1 = spirv.GL.UMin %arg1, %arg1 : vector<3xi32> + spirv.Return +} + //===----------------------------------------------------------------------===// // spirv.GL.Sqrt //===----------------------------------------------------------------------===//