From 9883b14cd1a4ea2dec8d7ed30df632671f56c69b Mon Sep 17 00:00:00 2001 From: Denis Khalikov Date: Tue, 7 Jan 2020 21:40:42 -0500 Subject: [PATCH] [mlir][spirv] Add lowering for standard bit ops Differential Revision: https://reviews.llvm.org/D72205 --- .../StandardToSPIRV/StandardToSPIRV.td | 8 +++ .../StandardToSPIRV/std-to-spirv.mlir | 66 ++++++++++++++----- .../Dialect/SPIRV/Serialization/bit-ops.mlir | 9 +++ 3 files changed, 65 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td index fc09cf604fc56..77f9305b4a937 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td @@ -26,6 +26,14 @@ class UnaryOpPattern : def : BinaryOpPattern; def : BinaryOpPattern; +def : BinaryOpPattern; +def : BinaryOpPattern; +def : BinaryOpPattern; +def : BinaryOpPattern; +def : BinaryOpPattern; +def : BinaryOpPattern; def : BinaryOpPattern; def : BinaryOpPattern; def : BinaryOpPattern; diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir index c07bdd4a97994..e356c5d63f8b0 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -94,6 +94,54 @@ func @div_rem(%arg0 : i32, %arg1 : i32) { return } +//===----------------------------------------------------------------------===// +// std bit ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @bitwise_scalar +func @bitwise_scalar(%arg0 : i32, %arg1 : i32) { + // CHECK: spv.BitwiseAnd + %0 = and %arg0, %arg1 : i32 + // CHECK: spv.BitwiseOr + %1 = or %arg0, %arg1 : i32 + // CHECK: spv.BitwiseXor + %2 = xor %arg0, %arg1 : i32 + return +} + +// CHECK-LABEL: @bitwise_vector +func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + // CHECK: spv.BitwiseAnd + %0 = and %arg0, %arg1 : vector<4xi32> + // CHECK: spv.BitwiseOr + %1 = or %arg0, %arg1 : vector<4xi32> + // CHECK: spv.BitwiseXor + %2 = xor %arg0, %arg1 : vector<4xi32> + return +} + +// CHECK-LABEL: @shift_scalar +func @shift_scalar(%arg0 : i32, %arg1 : i32) { + // CHECK: spv.ShiftLeftLogical + %0 = shift_left %arg0, %arg1 : i32 + // CHECK: spv.ShiftRightArithmetic + %1 = shift_right_signed %arg0, %arg1 : i32 + // CHECK: spv.ShiftRightLogical + %2 = shift_right_unsigned %arg0, %arg1 : i32 + return +} + +// CHECK-LABEL: @shift_vector +func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { + // CHECK: spv.ShiftLeftLogical + %0 = shift_left %arg0, %arg1 : vector<4xi32> + // CHECK: spv.ShiftRightArithmetic + %1 = shift_right_signed %arg0, %arg1 : vector<4xi32> + // CHECK: spv.ShiftRightLogical + %2 = shift_right_unsigned %arg0, %arg1 : vector<4xi32> + return +} + //===----------------------------------------------------------------------===// // std.cmpi //===----------------------------------------------------------------------===// @@ -156,24 +204,6 @@ func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { return } -// CHECK-LABEL: @logical_scalar_fail -func @logical_scalar_fail(%arg0 : i32, %arg1 : i32) { - // CHECK-NOT: spv.LogicalAnd - %0 = and %arg0, %arg1 : i32 - // CHECK-NOT: spv.LogicalOr - %1 = or %arg0, %arg1 : i32 - return -} - -// CHECK-LABEL: @logical_vector_fail -func @logical_vector_fail(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { - // CHECK-NOT: spv.LogicalAnd - %0 = and %arg0, %arg1 : vector<4xi32> - // CHECK-NOT: spv.LogicalOr - %1 = or %arg0, %arg1 : vector<4xi32> - return -} - //===----------------------------------------------------------------------===// // std.fpext //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir index fa823ea8a9d58..9ca6174ec16a7 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/bit-ops.mlir @@ -31,6 +31,15 @@ spv.module "Logical" "GLSL450" { %0 = spv.Not %arg : i32 spv.ReturnValue %0 : i32 } + func @bitwise_scalar(%arg0 : i32, %arg1 : i32) { + // CHECK: spv.BitwiseAnd + %0 = spv.BitwiseAnd %arg0, %arg1 : i32 + // CHECK: spv.BitwiseOr + %1 = spv.BitwiseOr %arg0, %arg1 : i32 + // CHECK: spv.BitwiseXor + %2 = spv.BitwiseXor %arg0, %arg1 : i32 + spv.Return + } func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 { // CHECK: {{%.*}} = spv.ShiftLeftLogical {{%.*}}, {{%.*}} : i32, i16 %0 = spv.ShiftLeftLogical %arg0, %arg1: i32, i16