Skip to content

Commit

Permalink
[spirv] Add bit ops
Browse files Browse the repository at this point in the history
This CL added op definitions for a few bit operations:

* OpShiftLeftLogical
* OpShiftRightArithmetic
* OpShiftRightLogical
* OpBitCount
* OpBitReverse
* OpNot

Also moved the definition of spv.BitwiseAnd to follow the
lexicographical order.

Closes tensorflow/mlir#215

COPYBARA_INTEGRATE_REVIEW=tensorflow/mlir#215 from denis0x0D:sandbox/bit_ops d9b0852b689ac6c4879a9740b1740a2357f44d24
PiperOrigin-RevId: 279350470
  • Loading branch information
denis0x0D authored and tensorflower-gardener committed Nov 8, 2019
1 parent 24f306a commit 4697d65
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 12 deletions.
10 changes: 9 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
Expand Up @@ -166,9 +166,15 @@ def SPV_OC_OpFOrdLessThanEqual : I32EnumAttrCase<"OpFOrdLessThanEqual", 188
def SPV_OC_OpFUnordLessThanEqual : I32EnumAttrCase<"OpFUnordLessThanEqual", 189>;
def SPV_OC_OpFOrdGreaterThanEqual : I32EnumAttrCase<"OpFOrdGreaterThanEqual", 190>;
def SPV_OC_OpFUnordGreaterThanEqual : I32EnumAttrCase<"OpFUnordGreaterThanEqual", 191>;
def SPV_OC_OpShiftRightLogical : I32EnumAttrCase<"OpShiftRightLogical", 194>;
def SPV_OC_OpShiftRightArithmetic : I32EnumAttrCase<"OpShiftRightArithmetic", 195>;
def SPV_OC_OpShiftLeftLogical : I32EnumAttrCase<"OpShiftLeftLogical", 196>;
def SPV_OC_OpBitwiseOr : I32EnumAttrCase<"OpBitwiseOr", 197>;
def SPV_OC_OpBitwiseXor : I32EnumAttrCase<"OpBitwiseXor", 198>;
def SPV_OC_OpBitwiseAnd : I32EnumAttrCase<"OpBitwiseAnd", 199>;
def SPV_OC_OpNot : I32EnumAttrCase<"OpNot", 200>;
def SPV_OC_OpBitReverse : I32EnumAttrCase<"OpBitReverse", 204>;
def SPV_OC_OpBitCount : I32EnumAttrCase<"OpBitCount", 205>;
def SPV_OC_OpControlBarrier : I32EnumAttrCase<"OpControlBarrier", 224>;
def SPV_OC_OpMemoryBarrier : I32EnumAttrCase<"OpMemoryBarrier", 225>;
def SPV_OC_OpPhi : I32EnumAttrCase<"OpPhi", 245>;
Expand Down Expand Up @@ -213,7 +219,9 @@ def SPV_OpcodeAttr :
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd,
SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor,
SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitReverse, SPV_OC_OpBitCount,
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpPhi,
SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,
SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue,
Expand Down
264 changes: 253 additions & 11 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVBitOps.td
Expand Up @@ -33,6 +33,124 @@ class SPV_BitBinaryOp<string mnemonic, list<OpTrait> traits = []> :
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultType])>;

class SPV_BitUnaryOp<string mnemonic, list<OpTrait> traits = []> :
SPV_UnaryOp<mnemonic, SPV_Integer, SPV_Integer,
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultType])>;

class SPV_ShiftOp<string mnemonic, list<OpTrait> traits = []> :
SPV_BinaryOp<mnemonic, SPV_Integer, SPV_Integer,
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultShape])> {
let parser = [{ return ::parseShiftOp(parser, result); }];
let printer = [{ ::printShiftOp(this->getOperation(), p); }];
let verifier = [{ return ::verifyShiftOp(this->getOperation()); }];
}

// -----

def SPV_BitCountOp : SPV_BitUnaryOp<"BitCount", []> {
let summary = "Count the number of set bits in an object.";

let description = [{
Results are computed per component.

Result Type must be a scalar or vector of integer type. The components
must be wide enough to hold the unsigned Width of Base as an unsigned
value. That is, no sign bit is needed or counted when checking for a
wide enough result width.

Base must be a scalar or vector of integer type. It must have the same
number of components as Result Type.

The result is the unsigned value that is the number of bits in Base that
are 1.

### Custom assembly form

``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
bitcount-op ::= ssa-id `=` `spv.BitCount` ssa-use
`:` integer-scalar-vector-type
```

For example:

```
%2 = spv.BitCount %0: i32
%3 = spv.BitCount %1: vector<4xi32>
```
}];
}

// -----

def SPV_BitReverseOp : SPV_BitUnaryOp<"BitReverse", []> {
let summary = "Reverse the bits in an object.";

let description = [{
Results are computed per component.

Result Type must be a scalar or vector of integer type.

The type of Base must be the same as Result Type.

The bit-number n of the result will be taken from bit-number Width - 1 -
n of Base, where Width is the OpTypeInt operand of the Result Type.

### Custom assembly form

``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
bitreverse-op ::= ssa-id `=` `spv.BitReverse` ssa-use
`:` integer-scalar-vector-type
```

For example:

```
%2 = spv.BitReverse %0 : i32
%3 = spv.BitReverse %1 : vector<4xi32>
```
}];
}

// -----

def SPV_BitwiseAndOp : SPV_BitBinaryOp<"BitwiseAnd", [Commutative]> {
let summary = [{
Result is 1 if both Operand 1 and Operand 2 are 1. Result is 0 if either
Operand 1 or Operand 2 are 0.
}];

let description = [{
Results are computed per component, and within each component, per bit.

Result Type must be a scalar or vector of integer type. The type of
Operand 1 and Operand 2 must be a scalar or vector of integer type.
They must have the same number of components as Result Type. They must
have the same component width as Result Type.

### Custom assembly form

``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
bitwise-and-op ::= ssa-id `=` `spv.BitwiseAnd` ssa-use, ssa-use
`:` integer-scalar-vector-type
```

For example:

```
%2 = spv.BitwiseAnd %0, %1 : i32
%2 = spv.BitwiseAnd %0, %1 : vector<4xi32>
```
}];
}

// -----

def SPV_BitwiseOrOp : SPV_BitBinaryOp<"BitwiseOr", [Commutative]> {
Expand Down Expand Up @@ -103,34 +221,158 @@ def SPV_BitwiseXorOp : SPV_BitBinaryOp<"BitwiseXor", [Commutative]> {

// -----

def SPV_BitwiseAndOp : SPV_BitBinaryOp<"BitwiseAnd", [Commutative]> {
def SPV_ShiftLeftLogicalOp : SPV_ShiftOp<"ShiftLeftLogical", []> {
let summary = [{
Result is 1 if both Operand 1 and Operand 2 are 1. Result is 0 if either
Operand 1 or Operand 2 are 0.
Shift the bits in Base left by the number of bits specified in Shift.
The least-significant bits will be zero filled.
}];

let description = [{
Result Type must be a scalar or vector of integer type.

The type of each Base and Shift must be a scalar or vector of integer
type. Base and Shift must have the same number of components. The
number of components and bit width of the type of Base must be the same
as in Result Type.

Shift is treated as unsigned. The result is undefined if Shift is
greater than or equal to the bit width of the components of Base.

The number of components and bit width of Result Type must match those
Base type. All types must be integer types.

Results are computed per component.

### Custom assembly form

``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
shift-left-logical-op ::= ssa-id `=` `spv.ShiftLeftLogical`
ssa-use `,` ssa-use `:`
integer-scalar-vector-type `,`
integer-scalar-vector-type
```

For example:

```
%2 = spv.ShiftLeftLogical %0, %1 : i32, i16
%5 = spv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
}

// -----

def SPV_ShiftRightArithmeticOp : SPV_ShiftOp<"ShiftRightArithmetic", []> {
let summary = [{
Shift the bits in Base right by the number of bits specified in Shift.
The most-significant bits will be filled with the sign bit from Base.
}];

let description = [{
Result Type must be a scalar or vector of integer type.

The type of each Base and Shift must be a scalar or vector of integer
type. Base and Shift must have the same number of components. The
number of components and bit width of the type of Base must be the same
as in Result Type.

Shift is treated as unsigned. The result is undefined if Shift is
greater than or equal to the bit width of the components of Base.

Results are computed per component.

### Custom assembly form

``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
shift-right-arithmetic-op ::= ssa-id `=` `spv.ShiftRightArithmetic`
ssa-use `,` ssa-use `:`
integer-scalar-vector-type `,`
integer-scalar-vector-type
```

For example:

```
%2 = spv.ShiftRightArithmetic %0, %1 : i32, i16
%5 = spv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
}

// -----

def SPV_ShiftRightLogicalOp : SPV_ShiftOp<"ShiftRightLogical", []> {
let summary = [{
Shift the bits in Base right by the number of bits specified in Shift.
The most-significant bits will be zero filled.
}];

let description = [{
Result Type must be a scalar or vector of integer type.

The type of each Base and Shift must be a scalar or vector of integer
type. Base and Shift must have the same number of components. The
number of components and bit width of the type of Base must be the same
as in Result Type.

Shift is consumed as an unsigned integer. The result is undefined if
Shift is greater than or equal to the bit width of the components of
Base.

Results are computed per component.

### Custom assembly form

``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
shift-right-logical-op ::= ssa-id `=` `spv.ShiftRightLogical`
ssa-use `,` ssa-use `:`
integer-scalar-vector-type `,`
integer-scalar-vector-type
```

For example:

```
%2 = spv.ShiftRightLogical %0, %1 : i32, i16
%5 = spv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
}

// -----

def SPV_NotOp : SPV_BitUnaryOp<"Not", []> {
let summary = "Complement the bits of Operand.";

let description = [{
Results are computed per component, and within each component, per bit.

Result Type must be a scalar or vector of integer type. The type of
Operand 1 and Operand 2 must be a scalar or vector of integer type.
They must have the same number of components as Result Type. They must
have the same component width as Result Type.
Result Type must be a scalar or vector of integer type.

Operand’s type must be a scalar or vector of integer type. It must
have the same number of components as Result Type. The component width
must equal the component width in Result Type.

### Custom assembly form

``` {.ebnf}
integer-scalar-vector-type ::= integer-type |
`vector<` integer-literal `x` integer-type `>`
bitwise-and-op ::= ssa-id `=` `spv.BitwiseAnd` ssa-use, ssa-use
`:` integer-scalar-vector-type
not-op ::= ssa-id `=` `spv.BitNot` ssa-use `:` integer-scalar-vector-type
```

For example:

```
%2 = spv.BitwiseAnd %0, %1 : i32
%2 = spv.BitwiseAnd %0, %1 : vector<4xi32>
%2 = spv.Not %0 : i32
%3 = spv.Not %1 : vector<4xi32>
```
}];
}
Expand Down
34 changes: 34 additions & 0 deletions mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Expand Up @@ -449,6 +449,40 @@ static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
printer << " : " << logicalOp->getOperand(0)->getType();
}

static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
SmallVector<OpAsmParser::OperandType, 2> operandInfo;
Type baseType;
Type shiftType;
auto loc = parser.getCurrentLocation();

if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
parser.parseType(baseType) || parser.parseComma() ||
parser.parseType(shiftType) ||
parser.resolveOperands(operandInfo, {baseType, shiftType}, loc,
state.operands)) {
return failure();
}
state.addTypes(baseType);
return success();
}

static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
Value *base = op->getOperand(0);
Value *shift = op->getOperand(1);
printer << op->getName() << ' ' << *base << ", " << *shift << " : "
<< base->getType() << ", " << shift->getType();
}

static LogicalResult verifyShiftOp(Operation *op) {
if (op->getOperand(0)->getType() != op->getResult(0)->getType()) {
return op->emitError("expected the same type for the first operand and "
"result, but provided ")
<< op->getOperand(0)->getType() << " and "
<< op->getResult(0)->getType();
}
return success();
}

//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//
Expand Down
34 changes: 34 additions & 0 deletions mlir/test/Dialect/SPIRV/Serialization/bit-ops.td
@@ -0,0 +1,34 @@
// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s

spv.module "Logical" "GLSL450" {
func @bitcount(%arg: i32) -> i32 {
// CHECK: spv.BitCount {{%.*}} : i32
%0 = spv.BitCount %arg : i32
spv.ReturnValue %0 : i32
}
func @bitreverse(%arg: i32) -> i32 {
// CHECK: spv.BitReverse {{%.*}} : i32
%0 = spv.BitReverse %arg : i32
spv.ReturnValue %0 : i32
}
func @not(%arg: i32) -> i32 {
// CHECK: spv.Not {{%.*}} : i32
%0 = spv.Not %arg : i32
spv.ReturnValue %0 : i32
}
func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 {
// CHECK: {{%.*}} = spv.ShiftLeftLogical {{%.*}}, {{%.*}} : i32, i16
%0 = spv.ShiftLeftLogical %arg0, %arg1: i32, i16
spv.ReturnValue %0 : i32
}
func @shift_right_aritmethic(%arg0: vector<4xi32>, %arg1 : vector<4xi8>) -> vector<4xi32> {
// CHECK: {{%.*}} = spv.ShiftRightArithmetic {{%.*}}, {{%.*}} : vector<4xi32>, vector<4xi8>
%0 = spv.ShiftRightArithmetic %arg0, %arg1: vector<4xi32>, vector<4xi8>
spv.ReturnValue %0 : vector<4xi32>
}
func @shift_right_logical(%arg0: vector<2xi32>, %arg1 : vector<2xi8>) -> vector<2xi32> {
// CHECK: {{%.*}} = spv.ShiftRightLogical {{%.*}}, {{%.*}} : vector<2xi32>, vector<2xi8>
%0 = spv.ShiftRightLogical %arg0, %arg1: vector<2xi32>, vector<2xi8>
spv.ReturnValue %0 : vector<2xi32>
}
}

0 comments on commit 4697d65

Please sign in to comment.