diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td index 55b3a67bc336c..437805b508ecf 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td @@ -22,7 +22,15 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // Base class for all GLSL ops. class SPV_GLSLOp traits = []> : - SPV_ExtInstOp; + SPV_ExtInstOp { + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Shader]> + ]; +} // Base class for GLSL unary ops. class SPV_GLSLUnaryOp traits = []> : - SPV_ExtInstOp; + SPV_ExtInstOp { + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Kernel]> + ]; +} // Base class for OpenCL unary ops. class SPV_OCLUnaryOp { + let summary = "Compute hyperbolic tangent of x radians."; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + tanh-op ::= ssa-id `=` `spv.OCL.tanh` ssa-use `:` + float-scalar-vector-type + ```mlir + + #### Example: + + ``` + %2 = spv.OCL.tanh %0 : f32 + %3 = spv.OCL.tanh %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_OCLCeilOp : SPV_OCLUnaryArithmeticOp<"ceil", 12, SPV_Float> { + let summary = [{ + Round x to integral value using the round to positive infinity rounding + mode. + }]; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + ceil-op ::= ssa-id `=` `spv.OCL.ceil` ssa-use `:` + float-scalar-vector-type + ```mlir + + #### Example: + + ``` + %2 = spv.OCL.ceil %0 : f32 + %3 = spv.OCL.ceil %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> { let summary = "Compute the cosine of x radians."; @@ -93,7 +164,7 @@ def SPV_OCLCosOp : SPV_OCLUnaryArithmeticOp<"cos", 14, SPV_Float> { ``` float-scalar-vector-type ::= float-type | `vector<` integer-literal `x` float-type `>` - abs-op ::= ssa-id `=` `spv.OCL.cos` ssa-use `:` + cos-op ::= ssa-id `=` `spv.OCL.cos` ssa-use `:` float-scalar-vector-type ```mlir @@ -168,6 +239,39 @@ def SPV_OCLFAbsOp : SPV_OCLUnaryArithmeticOp<"fabs", 23, SPV_Float> { // ----- +def SPV_OCLFloorOp : SPV_OCLUnaryArithmeticOp<"floor", 25, SPV_Float> { + let summary = [{ + Round x to the integral value using the round to negative infinity + rounding mode. + }]; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + floor-op ::= ssa-id `=` `spv.OCL.floor` ssa-use `:` + float-scalar-vector-type + ```mlir + + #### Example: + + ``` + %2 = spv.OCL.floor %0 : f32 + %3 = spv.OCL.ceifloorl %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> { let summary = "Compute the natural logarithm of x."; @@ -183,7 +287,7 @@ def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> { ``` float-scalar-vector-type ::= float-type | `vector<` integer-literal `x` float-type `>` - abs-op ::= ssa-id `=` `spv.OCL.log` ssa-use `:` + log-op ::= ssa-id `=` `spv.OCL.log` ssa-use `:` float-scalar-vector-type ```mlir @@ -198,6 +302,67 @@ def SPV_OCLLogOp : SPV_OCLUnaryArithmeticOp<"log", 37, SPV_Float> { // ----- +def SPV_OCLPowOp : SPV_OCLBinaryArithmeticOp<"pow", 48, SPV_Float> { + let summary = "Compute x to the power y."; + + let description = [{ + Result Type, x and y must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + restricted-float-scalar-type ::= `f16` | `f32` + restricted-float-scalar-vector-type ::= + restricted-float-scalar-type | + `vector<` integer-literal `x` restricted-float-scalar-type `>` + pow-op ::= ssa-id `=` `spv.OCL.pow` ssa-use `:` + restricted-float-scalar-vector-type + ``` + #### Example: + + ```mlir + %2 = spv.OCL.pow %0, %1 : f32 + %3 = spv.OCL.pow %0, %1 : vector<3xf16> + ``` + }]; +} + +// ----- + +def SPV_OCLRsqrtOp : SPV_OCLUnaryArithmeticOp<"rsqrt", 56, SPV_Float> { + let summary = "Compute inverse square root of x."; + + let description = [{ + Result Type and x must be floating-point or vector(2,3,4,8,16) of + floating-point values. + + All of the operands, including the Result Type operand, must be of the + same type. + + + + ``` + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + rsqrt-op ::= ssa-id `=` `spv.OCL.rsqrt` ssa-use `:` + float-scalar-vector-type + ```mlir + + #### Example: + + ``` + %2 = spv.OCL.rsqrt %0 : f32 + %3 = spv.OCL.rsqrt %1 : vector<3xf16> + ``` + }]; +} + +// ----- + def SPV_OCLSinOp : SPV_OCLUnaryArithmeticOp<"sin", 57, SPV_Float> { let summary = "Compute sine of x radians."; @@ -213,7 +378,7 @@ def SPV_OCLSinOp : SPV_OCLUnaryArithmeticOp<"sin", 57, SPV_Float> { ``` float-scalar-vector-type ::= float-type | `vector<` integer-literal `x` float-type `>` - abs-op ::= ssa-id `=` `spv.OCL.sin` ssa-use `:` + sin-op ::= ssa-id `=` `spv.OCL.sin` ssa-use `:` float-scalar-vector-type ```mlir @@ -243,7 +408,7 @@ def SPV_OCLSqrtOp : SPV_OCLUnaryArithmeticOp<"sqrt", 61, SPV_Float> { ``` float-scalar-vector-type ::= float-type | `vector<` integer-literal `x` float-type `>` - abs-op ::= ssa-id `=` `spv.OCL.sqrt` ssa-use `:` + sqrt-op ::= ssa-id `=` `spv.OCL.sqrt` ssa-use `:` float-scalar-vector-type ```mlir diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index d1aa7865f88b1..9e96829e79c1d 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -34,6 +34,7 @@ namespace { /// /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to /// these operations. +template class Log1pOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -48,7 +49,7 @@ class Log1pOpPattern final : public OpConversionPattern { auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); auto onePlus = rewriter.create(loc, one, adaptor.getOperands()[0]); - rewriter.replaceOpWithNewOp(operation, type, onePlus); + rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); } }; @@ -61,8 +62,10 @@ class Log1pOpPattern final : public OpConversionPattern { namespace mlir { void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { + + // GLSL patterns patterns.add< - Log1pOpPattern, + Log1pOpPattern, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, @@ -75,6 +78,21 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern>( typeConverter, patterns.getContext()); + + // OpenCL patterns + patterns.add, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern>( + typeConverter, patterns.getContext()); } } // namespace mlir diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir index 8a41a90a2fc0b..5ab13e780d0d4 100644 --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -6,7 +6,7 @@ module attributes { spv.target_env = #spv.target_env< - #spv.vce, {}> + #spv.vce, {}> } { // Check integer operation conversions. diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir similarity index 95% rename from mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir rename to mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir index 65bdc18909e35..ad32a88a876ea 100644 --- a/mlir/test/Conversion/MathToSPIRV/math-to-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s +module attributes { spv.target_env = #spv.target_env<#spv.vce, {}> } { + // CHECK-LABEL: @float32_unary_scalar func @float32_unary_scalar(%arg0: f32) { // CHECK: spv.GLSL.Cos %{{.*}}: f32 @@ -59,3 +61,5 @@ func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) { %0 = math.powf %lhs, %rhs : vector<4xf32> return } + +} // end module diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir new file mode 100644 index 0000000000000..8a1a3acc5f0cd --- /dev/null +++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s + +module attributes { spv.target_env = #spv.target_env<#spv.vce, {}> } { + +// CHECK-LABEL: @float32_unary_scalar +func @float32_unary_scalar(%arg0: f32) { + // CHECK: spv.OCL.cos %{{.*}}: f32 + %0 = math.cos %arg0 : f32 + // CHECK: spv.OCL.exp %{{.*}}: f32 + %1 = math.exp %arg0 : f32 + // CHECK: spv.OCL.log %{{.*}}: f32 + %2 = math.log %arg0 : f32 + // CHECK: %[[ONE:.+]] = spv.Constant 1.000000e+00 : f32 + // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} + // CHECK: spv.OCL.log %[[ADDONE]] + %3 = math.log1p %arg0 : f32 + // CHECK: spv.OCL.rsqrt %{{.*}}: f32 + %4 = math.rsqrt %arg0 : f32 + // CHECK: spv.OCL.sqrt %{{.*}}: f32 + %5 = math.sqrt %arg0 : f32 + // CHECK: spv.OCL.tanh %{{.*}}: f32 + %6 = math.tanh %arg0 : f32 + // CHECK: spv.OCL.sin %{{.*}}: f32 + %7 = math.sin %arg0 : f32 + return +} + +// CHECK-LABEL: @float32_unary_vector +func @float32_unary_vector(%arg0: vector<3xf32>) { + // CHECK: spv.OCL.cos %{{.*}}: vector<3xf32> + %0 = math.cos %arg0 : vector<3xf32> + // CHECK: spv.OCL.exp %{{.*}}: vector<3xf32> + %1 = math.exp %arg0 : vector<3xf32> + // CHECK: spv.OCL.log %{{.*}}: vector<3xf32> + %2 = math.log %arg0 : vector<3xf32> + // CHECK: %[[ONE:.+]] = spv.Constant dense<1.000000e+00> : vector<3xf32> + // CHECK: %[[ADDONE:.+]] = spv.FAdd %[[ONE]], %{{.+}} + // CHECK: spv.OCL.log %[[ADDONE]] + %3 = math.log1p %arg0 : vector<3xf32> + // CHECK: spv.OCL.rsqrt %{{.*}}: vector<3xf32> + %4 = math.rsqrt %arg0 : vector<3xf32> + // CHECK: spv.OCL.sqrt %{{.*}}: vector<3xf32> + %5 = math.sqrt %arg0 : vector<3xf32> + // CHECK: spv.OCL.tanh %{{.*}}: vector<3xf32> + %6 = math.tanh %arg0 : vector<3xf32> + // CHECK: spv.OCL.sin %{{.*}}: vector<3xf32> + %7 = math.sin %arg0 : vector<3xf32> + return +} + +// CHECK-LABEL: @float32_binary_scalar +func @float32_binary_scalar(%lhs: f32, %rhs: f32) { + // CHECK: spv.OCL.pow %{{.*}}: f32 + %0 = math.powf %lhs, %rhs : f32 + return +} + +// CHECK-LABEL: @float32_binary_vector +func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) { + // CHECK: spv.OCL.pow %{{.*}}: vector<4xf32> + %0 = math.powf %lhs, %rhs : vector<4xf32> + return +} + +} // end module diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index db11f41c11e43..029c161b3df45 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -6,7 +6,7 @@ module attributes { spv.target_env = #spv.target_env< - #spv.vce, {}> + #spv.vce, {}> } { // Check integer operation conversions. diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir index 37ab7d90e2bc5..c6c9af28f97d6 100644 --- a/mlir/test/Target/SPIRV/ocl-ops.mlir +++ b/mlir/test/Target/SPIRV/ocl-ops.mlir @@ -14,6 +14,14 @@ spv.module Physical64 OpenCL requires #spv.vce { %4 = spv.OCL.log %arg0 : f32 // CHECK: {{%.*}} = spv.OCL.sqrt {{%.*}} : f32 %5 = spv.OCL.sqrt %arg0 : f32 + // CHECK: {{%.*}} = spv.OCL.ceil {{%.*}} : f32 + %6 = spv.OCL.ceil %arg0 : f32 + // CHECK: {{%.*}} = spv.OCL.floor {{%.*}} : f32 + %7 = spv.OCL.floor %arg0 : f32 + // CHECK: {{%.*}} = spv.OCL.pow {{%.*}}, {{%.*}} : f32 + %8 = spv.OCL.pow %arg0, %arg0 : f32 + // CHECK: {{%.*}} = spv.OCL.rsqrt {{%.*}} : f32 + %9 = spv.OCL.rsqrt %arg0 : f32 spv.Return }