Skip to content

Commit

Permalink
Add mechanism to specify extended instruction sets in SPIR-V.
Browse files Browse the repository at this point in the history
Add support for specifying extended instructions sets. The operations
in SPIR-V dialect are named as 'spv.<extension-name>.<op-name>'. Use
this mechanism to define a 'Exp' operation from GLSL(450)
instructions.
Later CLs will add support for (de)serialization of these operations,
and update the dialect generation scripts to auto-generate the
specification using the spec directly.

Additional changes:
Add a Type Constraint to OpBase.td to check for vector of specified
lengths. This is used to check that the vector type used in SPIR-V
dialect are of lengths 2, 3 or 4.
Update SPIRVBase.td to use this Type constraints for vectors.

PiperOrigin-RevId: 269234377
  • Loading branch information
Mahesh Ravishankar authored and tensorflower-gardener committed Sep 16, 2019
1 parent faaa1ce commit 9814b3f
Show file tree
Hide file tree
Showing 10 changed files with 302 additions and 22 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
Expand Up @@ -3,6 +3,11 @@ mlir_tablegen(SPIRVOps.h.inc -gen-op-decls)
mlir_tablegen(SPIRVOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRSPIRVOpsIncGen)

set(LLVM_TARGET_DEFINITIONS SPIRVGLSLOps.td)
mlir_tablegen(SPIRVGLSLOps.h.inc -gen-op-decls)
mlir_tablegen(SPIRVGLSLOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRSPIRVGLSLOpsIncGen)

set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
Expand Up @@ -202,7 +202,8 @@ def SPV_Void : TypeAlias<NoneType, "void type">;
def SPV_Bool : IntOfWidths<[1]>;
def SPV_Integer : IntOfWidths<[8, 16, 32, 64]>;
def SPV_Float : FloatOfWidths<[16, 32, 64]>;
def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>;
def SPV_Vector : VectorOfLengthAndType<[2, 3, 4],
[SPV_Bool, SPV_Integer, SPV_Float]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
Expand All @@ -219,7 +220,10 @@ def SPV_Type : AnyTypeOf<[
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
]>;

class SPV_ScalarOrVectorOf<Type type> : AnyTypeOf<[type, VectorOf<[type]>]>;
class SPV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>;

def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;

// TODO(antiagainst): Use a more appropriate way to model optional operands
class SPV_Optional<Type type> : Variadic<type>;
Expand Down
37 changes: 37 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.h
@@ -0,0 +1,37 @@
//===- SPIRVGLSLOps.h - MLIR SPIR-V extended ops for GLSL --------*- C++-*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file declares the extended operations for GLSL in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPIRV_SPIRVGLSLOPS_H_
#define MLIR_DIALECT_SPIRV_SPIRVGLSLOPS_H_

#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/OpDefinition.h"

namespace mlir {
namespace spirv {

#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.h.inc"

} // namespace spirv
} // namespace mlir

#endif // MLIR_DIALECT_SPIRV_SPIRVGLSLOPS_H_
107 changes: 107 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td
@@ -0,0 +1,107 @@
//===- SPIRVGLSLOps.td - GLSL extended insts spec file -----*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is the op definition spec of GLSL extension ops.
//
//===----------------------------------------------------------------------===//

#ifdef SPIRV_GLSL_OPS
#else
#define SPIRV_GLSL_OPS

#ifdef SPIRV_BASE
#else
include "mlir/Dialect/SPIRV/SPIRVBase.td"
#endif // SPIRV_BASE

//===----------------------------------------------------------------------===//
// SPIR-V GLSL 4.50 opcode specification.
//===----------------------------------------------------------------------===//

// Base class for all GLSL ops.
class SPV_GLSLOp<string mnemonic, int opcode, list<OpTrait> traits = []> :
SPV_Op<"glsl." # mnemonic, traits> {

// Do not use the default auto-generation serializer/deserializer.
let hasOpcode = 0;

// Opcode within the extended instruction set.
int glslOpcode = opcode;

// Name used to refer to the extended instruction set.
string extensionSetName = "GLSL.std.450";
}

// Base class for GLSL unary ops.
class SPV_GLSLUnaryOp<string mnemonic, Type resultType, Type operandType,
int opcode, list<OpTrait> traits = []> :
SPV_GLSLOp<mnemonic, opcode, traits> {

let arguments = (ins
SPV_ScalarOrVectorOf<operandType>:$operand
);

let results = (outs
SPV_ScalarOrVectorOf<resultType>:$result
);

let parser = [{ return parseGLSLUnaryOp(parser, result); }];

let printer = [{ return printGLSLUnaryOp(getOperation(), p); }];

let verifier = [{ return success(); }];
}

// Base class for GLSL Unary arithmatic ops where return type matches
// the operand type.
class SPV_GLSLUnaryArithmaticOp<string mnemonic, int opcode, Type type,
list<OpTrait> traits = []> :
SPV_GLSLUnaryOp<mnemonic, type, type, opcode, traits>;

// -----

def SPV_GLSLExpOp : SPV_GLSLUnaryArithmaticOp<"Exp", 27, FloatOfWidths<[16, 32]>> {
let summary = "Exponentiation of Operand 1";

let description = [{
Result is the natural exponentiation of x; e^x.

The operand x must be a scalar or vector whose component type is
16-bit or 32-bit floating-point.

Result Type and the type of x must be the same type. Results are
computed per component.";

### Custom assembly format
``` {.ebnf}
restricted-float-scalar-type ::= `f16` | `f32`
restricted-float-scalar-vector-type ::=
restricted-float-scalar-type |
`vector<` integer-literal `x` restricted-float-scalar-type `>`
exp-op ::= ssa-id `=` `spv.glsl.Exp` ssa-use `:`
restricted-float-scalar-vector-type
```
For example:

```
%2 = spv.glsl.Exp %0 : f32
%3 = spv.glsl.Exp %1 : vector<3xf16>
```
}];
}

#endif // SPIRV_GLSL_OPS
26 changes: 26 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Expand Up @@ -377,6 +377,32 @@ class HasAnyRankOfPred<list<int> ranks> : And<[
class VectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">;

// Whether the number of elements of a vector is from the given
// `allowedLengths` list
class IsVectorOfLengthPred<list<int> allowedLengths> :
And<[IsVectorTypePred,
Or<!foreach(allowedlength, allowedLengths,
CPred<[{$_self.cast<VectorType>().getNumElements()
== }]
# allowedlength>)>]>;

// Any vector where the number of elements is from the given
// `allowedLengths` list
class VectorOfLength<list<int> allowedLengths> : Type<
IsVectorOfLengthPred<allowedLengths>,
" of length " # StrJoinInt<allowedLengths, "/">.result>;


// Any vector where the number of elements is from the given
// `allowedLengths` list and the type is from the given `allowedTypes`
// list
class VectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : Type<
And<[VectorOf<allowedTypes>.predicate,
VectorOfLength<allowedLengths>.predicate]>,
VectorOf<allowedTypes>.description #
VectorOfLength<allowedLengths>.description>;

def AnyVector : VectorOf<[AnyType]>;

// Tensor types.
Expand Down
25 changes: 5 additions & 20 deletions mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
Expand Up @@ -24,26 +24,11 @@ include "mlir/Dialect/StandardOps/Ops.td"
include "mlir/Dialect/SPIRV/SPIRVOps.td"
#endif // SPIRV_OPS

def IsScalar : TypeConstraint<CPred<"!($_self.isa<ShapedType>())">, "scalar">;
class BinaryOpPattern<Op src, Op tgt> :
Pat<(src SPV_ScalarOrVector:$l, SPV_ScalarOrVector:$r),
(tgt $l, $r)>;

class IsVectorLengthPred<int vecLength> :
CPred<"($_self.cast<VectorType>().getShape().size() == 1 && " #
"$_self.cast<VectorType>().getShape()[0] == " # vecLength # ")">;

class IsVectorOfLength<int vecLength>:
TypeConstraint<And<[IsVectorTypePred, IsVectorLengthPred<vecLength>]>,
vecLength # "-element vector">;

multiclass BinaryOpPattern<Op src, SPV_Op tgt> {
def : Pat<(src IsScalar:$l, IsScalar:$r), (tgt $l, $r)>;
foreach vecLength = [2, 3, 4] in {
def : Pat<(src IsVectorOfLength<vecLength>:$l,
IsVectorOfLength<vecLength>:$r),
(tgt $l, $r)>;
}
}

defm : BinaryOpPattern<AddFOp, SPV_FAddOp>;
defm : BinaryOpPattern<MulFOp, SPV_FMulOp>;
def : BinaryOpPattern<AddFOp, SPV_FAddOp>;
def : BinaryOpPattern<MulFOp, SPV_FMulOp>;

#endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SPIRV/CMakeLists.txt
@@ -1,6 +1,7 @@
add_llvm_library(MLIRSPIRV
DialectRegistration.cpp
SPIRVDialect.cpp
SPIRVGLSLOps.cpp
SPIRVOps.cpp
SPIRVTypes.cpp

Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -41,11 +42,18 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();

// Add SPIR-V ops.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
>();

// Add SPIR-V extension ops of GLSL.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.cpp.inc"
>();

// Allow unknown operations because SPIR-V is extensible.
allowUnknownOperations();
}
Expand Down
58 changes: 58 additions & 0 deletions mlir/lib/Dialect/SPIRV/SPIRVGLSLOps.cpp
@@ -0,0 +1,58 @@
//===- SPIRVGLSLOps.cpp - MLIR SPIR-V GLSL extended operations ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the operations in the SPIR-V extended instructions set for
// GLSL
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/OpImplementation.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// spv.glsl.UnaryOp
//===----------------------------------------------------------------------===//

static ParseResult parseGLSLUnaryOp(OpAsmParser *parser,
OperationState *state) {
OpAsmParser::OperandType operandInfo;
Type type;
if (parser->parseOperand(operandInfo) || parser->parseColonType(type) ||
parser->resolveOperands(operandInfo, type, state->operands)) {
return failure();
}
state->addTypes(type);
return success();
}

static void printGLSLUnaryOp(Operation *unaryOp, OpAsmPrinter *printer) {
*printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : "
<< unaryOp->getOperand(0)->getType();
}

namespace mlir {
namespace spirv {

#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVGLSLOps.cpp.inc"

} // namespace spirv
} // namespace mlir
49 changes: 49 additions & 0 deletions mlir/test/Dialect/SPIRV/glslops.mlir
@@ -0,0 +1,49 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s

//===----------------------------------------------------------------------===//
// spv.glsl.Exp
//===----------------------------------------------------------------------===//

func @exp(%arg0 : f32) -> () {
// CHECK: spv.glsl.Exp {{%.*}} : f32
%2 = spv.glsl.Exp %arg0 : f32
return
}

func @expvec(%arg0 : vector<3xf16>) -> () {
// CHECK: spv.glsl.Exp {{%.*}} : vector<3xf16>
%2 = spv.glsl.Exp %arg0 : vector<3xf16>
return
}

// -----

func @exp(%arg0 : i32) -> () {
// expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values}}
%2 = spv.glsl.Exp %arg0 : i32
return
}

// -----

func @exp(%arg0 : vector<5xf32>) -> () {
// expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
%2 = spv.glsl.Exp %arg0 : vector<5xf32>
return
}

// -----

func @exp(%arg0 : f32, %arg1 : f32) -> () {
// expected-error @+1 {{expected ':'}}
%2 = spv.glsl.Exp %arg0, %arg1 : i32
return
}

// -----

func @exp(%arg0 : i32) -> () {
// expected-error @+2 {{expected non-function type}}
%2 = spv.glsl.Exp %arg0 :
return
}

0 comments on commit 9814b3f

Please sign in to comment.