Skip to content

Commit

Permalink
[spirv] Add support for extension (de)serialization
Browse files Browse the repository at this point in the history
Only a few important KHR extensions are registered to the
SPIR-V dialect for now.

PiperOrigin-RevId: 264939428
  • Loading branch information
antiagainst authored and tensorflower-gardener committed Aug 22, 2019
1 parent 986f930 commit 51cbf97
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 19 deletions.
62 changes: 45 additions & 17 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
Expand Up @@ -74,6 +74,7 @@ class SPV_OpCode<string name, int val> {

def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>;
def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>;
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
Expand Down Expand Up @@ -135,23 +136,24 @@ def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;

def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint,
SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid,
SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector,
SPV_OC_OpTypeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub,
SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv,
SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem,
SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn, SPV_OC_OpReturnValue
SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpExtension, SPV_OC_OpMemoryModel,
SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability,
SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat,
SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, SPV_OC_OpTypeStruct,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpReturn, SPV_OC_OpReturnValue
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
Expand Down Expand Up @@ -205,6 +207,32 @@ def SPV_IsEntryPointType :
CPred<"$_self.isa<::mlir::spirv::EntryPointType>()">;
def SPV_EntryPoint : Type<SPV_IsEntryPointType, "SPIR-V entry point type">;

//===----------------------------------------------------------------------===//
// SPIR-V extension definitions
//===----------------------------------------------------------------------===//

// Extensions known to the SPIR-V dialect.
// https://github.com/KhronosGroup/SPIRV-Registry has the full list.
def SPV_KHR_16bit_storage : StrEnumAttrCase<"SPV_KHR_16bit_storage">;
def SPV_KHR_8bit_storage : StrEnumAttrCase<"SPV_KHR_8bit_storage">;
def SPV_KHR_float_controls : StrEnumAttrCase<"SPV_KHR_float_controls">;
def SPV_KHR_shader_atomic_counter_ops : StrEnumAttrCase<"SPV_KHR_shader_atomic_counter_ops">;
def SPV_KHR_shader_ballot : StrEnumAttrCase<"SPV_KHR_shader_ballot">;
def SPV_KHR_storage_buffer_storage_class : StrEnumAttrCase<"SPV_KHR_storage_buffer_storage_class">;
def SPV_KHR_subgroup_vote : StrEnumAttrCase<"SPV_KHR_subgroup_vote">;
def SPV_KHR_variable_pointers : StrEnumAttrCase<"SPV_KHR_variable_pointers">;
def SPV_KHR_vulkan_memory_model : StrEnumAttrCase<"SPV_KHR_vulkan_memory_model">;

def SPV_ExtensionAttr :
StrEnumAttr<"Extension", "supported SPIR-V extensions", [
SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_float_controls,
SPV_KHR_shader_atomic_counter_ops, SPV_KHR_shader_ballot,
SPV_KHR_storage_buffer_storage_class, SPV_KHR_subgroup_vote,
SPV_KHR_variable_pointers, SPV_KHR_vulkan_memory_model
]> {
let cppNamespace = "::mlir::spirv";
}

//===----------------------------------------------------------------------===//
// SPIR-V enum definitions
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Expand Up @@ -1037,6 +1037,16 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
}
}

// Verify extensions. ODS already guarantees that we have an array of
// string attributes.
if (auto exts = moduleOp.getAttrOfType<ArrayAttr>("extensions")) {
for (auto ext : exts.getValue()) {
auto extStr = ext.cast<StringAttr>().getValue();
if (!spirv::symbolizeExtension(extStr))
return moduleOp.emitOpError("uses unknown extension: ") << extStr;
}
}

return success();
}

Expand Down
42 changes: 40 additions & 2 deletions mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -1,4 +1,3 @@
//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===//
//
// Copyright 2019 The MLIR Authors.
//
Expand Down Expand Up @@ -84,6 +83,13 @@ class Deserializer {
/// Attaches all collected capabilites to `module` as an attribute.
void attachCapabilities();

/// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping
/// in the deserializer.
LogicalResult processExtension(ArrayRef<uint32_t> operands);

/// Attaches all collected extensions to `module` as an attribute.
void attachExtensions();

/// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`.
LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);

Expand Down Expand Up @@ -236,6 +242,9 @@ class Deserializer {
/// The list of capabilities used by the module.
llvm::SmallSetVector<spirv::Capability, 4> capabilities;

/// The list of extensions used by the module.
llvm::SmallSetVector<StringRef, 2> extensions;

// Result <id> to type mapping.
DenseMap<uint32_t, Type> typeMap;

Expand Down Expand Up @@ -316,8 +325,9 @@ LogicalResult Deserializer::deserialize() {
}
}

// Attaches the capabilities as an attribute to the module.
// Attaches the capabilities/extensions as an attribute to the module.
attachCapabilities();
attachExtensions();

return success();
}
Expand Down Expand Up @@ -377,6 +387,32 @@ void Deserializer::attachCapabilities() {
module->setAttr("capabilities", opBuilder.getStrArrayAttr(caps));
}

LogicalResult Deserializer::processExtension(ArrayRef<uint32_t> operands) {
if (operands.empty()) {
return emitError(
unknownLoc,
"OpExtension must have a literal string for the extension name");
}

unsigned wordIndex = 0;
StringRef extName = decodeStringLiteral(operands, wordIndex);
if (wordIndex != operands.size()) {
return emitError(unknownLoc,
"unexpected trailing words in OpExtension instruction");
}

extensions.insert(extName);
return success();
}

void Deserializer::attachExtensions() {
if (extensions.empty())
return;

module->setAttr("extensions",
opBuilder.getStrArrayAttr(extensions.getArrayRef()));
}

LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
return emitError(unknownLoc, "OpMemoryModel must have two operands");
Expand Down Expand Up @@ -1144,6 +1180,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
switch (opcode) {
case spirv::Opcode::OpCapability:
return processCapability(operands);
case spirv::Opcode::OpExtension:
return processExtension(operands);
case spirv::Opcode::OpMemoryModel:
return processMemoryModel(operands);
case spirv::Opcode::OpEntryPoint:
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
Expand Up @@ -132,6 +132,8 @@ class Serializer {

void processCapability();

void processExtension();

void processMemoryModel();

LogicalResult processConstantOp(spirv::ConstantOp op);
Expand Down Expand Up @@ -321,6 +323,7 @@ LogicalResult Serializer::serialize() {

// TODO(antiagainst): handle the other sections
processCapability();
processExtension();
processMemoryModel();

// Iterate over the module body to serialze it. Assumptions are that there is
Expand Down Expand Up @@ -372,6 +375,20 @@ void Serializer::processCapability() {
}
}

void Serializer::processExtension() {
auto exts = module.getAttrOfType<ArrayAttr>("extensions");
if (!exts)
return;

SmallVector<uint32_t, 16> extName;
for (auto ext : exts.getValue()) {
auto extStr = ext.cast<StringAttr>().getValue();
extName.clear();
encodeStringLiteralInto(extName, extStr);
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
}
}

void Serializer::processMemoryModel() {
uint32_t mm = module.getAttrOfType<IntegerAttr>("memory_model").getInt();
uint32_t am = module.getAttrOfType<IntegerAttr>("addressing_model").getInt();
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/SPIRV/Serialization/extension.mlir
@@ -0,0 +1,8 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s

spv.module "Logical" "GLSL450" {
} attributes {
// CHECK: extensions = ["SPV_KHR_float_controls", "SPV_KHR_subgroup_vote"]
extensions = ["SPV_KHR_float_controls", "SPV_KHR_subgroup_vote"]
}

8 changes: 8 additions & 0 deletions mlir/test/Dialect/SPIRV/structure-ops.mlir
Expand Up @@ -375,6 +375,14 @@ spv.module "Logical" "GLSL450" {

// -----

// expected-error @+1 {{uses unknown extension: MyAwesomeExtension}}
spv.module "Logical" "GLSL450" {
} attributes {
extensions = ["MyAwesomeExtension"]
}

// -----

//===----------------------------------------------------------------------===//
// spv._module_end
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 51cbf97

Please sign in to comment.