Skip to content

Commit

Permalink
[spirv] Add support for capability (de)serialization
Browse files Browse the repository at this point in the history
This CL pulls in capabilities defined in the spec and adds
support for (de)serialize capabilities of a spv.module.

PiperOrigin-RevId: 264877413
  • Loading branch information
antiagainst authored and tensorflower-gardener committed Aug 22, 2019
1 parent b1ce4df commit 27ed82f
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 20 deletions.
241 changes: 225 additions & 16 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Expand Up @@ -1026,6 +1026,17 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
"functions in 'spv.module' can only contain spv.* ops");
}
}

// Verify capabilities. ODS already guarantees that we have an array of
// string attributes.
if (auto caps = moduleOp.getAttrOfType<ArrayAttr>("capabilities")) {
for (auto cap : caps.getValue()) {
auto capStr = cap.cast<StringAttr>().getValue();
if (!spirv::symbolizeCapability(capStr))
return moduleOp.emitOpError("uses unknown capability: ") << capStr;
}
}

return success();
}

Expand Down
42 changes: 42 additions & 0 deletions mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
Expand Up @@ -29,6 +29,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/bit.h"

Expand Down Expand Up @@ -76,6 +77,13 @@ class Deserializer {
/// Processes SPIR-V module header in `binary`.
LogicalResult processHeader();

/// Processes the SPIR-V OpCapability with `operands` and updates bookkeeping
/// in the deserializer.
LogicalResult processCapability(ArrayRef<uint32_t> operands);

/// Attaches all collected capabilites to `module` as an attribute.
void attachCapabilities();

/// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`.
LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);

Expand Down Expand Up @@ -225,6 +233,9 @@ class Deserializer {

OpBuilder opBuilder;

/// The list of capabilities used by the module.
llvm::SmallSetVector<spirv::Capability, 4> capabilities;

// Result <id> to type mapping.
DenseMap<uint32_t, Type> typeMap;

Expand Down Expand Up @@ -305,6 +316,9 @@ LogicalResult Deserializer::deserialize() {
}
}

// Attaches the capabilities as an attribute to the module.
attachCapabilities();

return success();
}

Expand Down Expand Up @@ -337,6 +351,32 @@ LogicalResult Deserializer::processHeader() {
return success();
}

LogicalResult Deserializer::processCapability(ArrayRef<uint32_t> operands) {
if (operands.size() != 1)
return emitError(unknownLoc, "OpMemoryModel must have one parameter");

auto cap = spirv::symbolizeCapability(operands[0]);
if (!cap)
return emitError(unknownLoc, "unknown capability: ") << operands[0];

capabilities.insert(*cap);
return success();
}

void Deserializer::attachCapabilities() {
if (capabilities.empty())
return;

SmallVector<StringRef, 2> caps;
caps.reserve(capabilities.size());

for (auto cap : capabilities) {
caps.push_back(spirv::stringifyCapability(cap));
}

module->setAttr("capabilities", opBuilder.getStrArrayAttr(caps));
}

LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
return emitError(unknownLoc, "OpMemoryModel must have two operands");
Expand Down Expand Up @@ -1102,6 +1142,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
// First dispatch all the instructions whose opcode does not correspond to
// those that have a direct mirror in the SPIR-V dialect
switch (opcode) {
case spirv::Opcode::OpCapability:
return processCapability(operands);
case spirv::Opcode::OpMemoryModel:
return processMemoryModel(operands);
case spirv::Opcode::OpEntryPoint:
Expand Down
23 changes: 19 additions & 4 deletions mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
Expand Up @@ -130,7 +130,9 @@ class Serializer {
return funcIDMap.lookup(fnName);
}

LogicalResult processMemoryModel();
void processCapability();

void processMemoryModel();

LogicalResult processConstantOp(spirv::ConstantOp op);

Expand Down Expand Up @@ -318,6 +320,7 @@ LogicalResult Serializer::serialize() {
return failure();

// TODO(antiagainst): handle the other sections
processCapability();
processMemoryModel();

// Iterate over the module body to serialze it. Assumptions are that there is
Expand Down Expand Up @@ -356,12 +359,24 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
// Module structure
//===----------------------------------------------------------------------===//

LogicalResult Serializer::processMemoryModel() {
void Serializer::processCapability() {
auto caps = module.getAttrOfType<ArrayAttr>("capabilities");
if (!caps)
return;

for (auto cap : caps.getValue()) {
auto capStr = cap.cast<StringAttr>().getValue();
auto capVal = spirv::symbolizeCapability(capStr);
encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
{static_cast<uint32_t>(*capVal)});
}
}

void Serializer::processMemoryModel() {
uint32_t mm = module.getAttrOfType<IntegerAttr>("memory_model").getInt();
uint32_t am = module.getAttrOfType<IntegerAttr>("addressing_model").getInt();

return encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel,
{am, mm});
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
}

LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/Dialect/SPIRV/Serialization/capability.mlir
@@ -0,0 +1,7 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s

spv.module "Logical" "GLSL450" {
} attributes {
// CHECK: capabilities = ["Shader", "Float16"]
capabilities = ["Shader", "Float16"]
}
8 changes: 8 additions & 0 deletions mlir/test/Dialect/SPIRV/structure-ops.mlir
Expand Up @@ -367,6 +367,14 @@ spv.module "Logical" "VulkanKHR" {

// -----

// expected-error @+1 {{uses unknown capability: MyAwesomeCapability}}
spv.module "Logical" "GLSL450" {
} attributes {
capabilities = ["MyAwesomeCapability"]
}

// -----

//===----------------------------------------------------------------------===//
// spv._module_end
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 27ed82f

Please sign in to comment.