Skip to content

Commit

Permalink
[spirv] Add Block decoration for spv.struct.
Browse files Browse the repository at this point in the history
Add Block decoration for top-level spv.struct.

Closes tensorflow/mlir#102

PiperOrigin-RevId: 265716241
  • Loading branch information
denis0x0D authored and tensorflower-gardener committed Aug 27, 2019
1 parent 2f59f76 commit 8f2dfb5
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 0 deletions.
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
Expand Up @@ -468,6 +468,13 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
}
typeDecorations[words[0]] = static_cast<uint32_t>(words[2]);
break;
case spirv::Decoration::Block:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
}
// Block decoration does not affect spv.struct type.
break;
default:
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
}
Expand Down
31 changes: 31 additions & 0 deletions mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
Expand Up @@ -174,6 +174,10 @@ class Serializer {

bool isVoidType(Type type) const { return type.isa<NoneType>(); }

/// Returns true if the given type is a pointer type to a struct in Uniform or
/// StorageBuffer storage class.
bool isInterfaceStructPtrType(Type type) const;

/// Main dispatch method for serializing a type. The result <id> of the
/// serialized type will be returned as `typeID`.
LogicalResult processType(Location loc, Type type, uint32_t &typeID);
Expand Down Expand Up @@ -558,6 +562,22 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
return failure();
}

if (isInterfaceStructPtrType(varOp.type())) {
auto structType = varOp.type()
.cast<spirv::PointerType>()
.getPointeeType()
.cast<spirv::StructType>();
SmallVector<uint32_t, 2> args{
findTypeID(structType),
static_cast<uint32_t>(spirv::Decoration::Block)};
if (failed(encodeInstructionInto(decorations, spirv::Opcode::OpDecorate,
args))) {
return varOp.emitError("cannot decorate ")
<< structType << " with Block decoration";
}
}

elidedAttrs.push_back("type");
SmallVector<uint32_t, 4> operands;
operands.push_back(resultTypeID);
Expand Down Expand Up @@ -609,6 +629,17 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
// Type
//===----------------------------------------------------------------------===//

bool Serializer::isInterfaceStructPtrType(Type type) const {
if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
auto storageClass = ptrType.getStorageClass();
if (storageClass == spirv::StorageClass::Uniform ||
storageClass == spirv::StorageClass::StorageBuffer) {
return ptrType.getPointeeType().isa<spirv::StructType>();
}
}
return false;
}

LogicalResult Serializer::processType(Location loc, Type type,
uint32_t &typeID) {
typeID = findTypeID(type);
Expand Down
3 changes: 3 additions & 0 deletions mlir/unittests/Dialect/SPIRV/CMakeLists.txt
@@ -1,8 +1,11 @@
add_mlir_unittest(MLIRSPIRVTests
DeserializationTest.cpp
SerializationTest.cpp
)
target_link_libraries(MLIRSPIRVTests
PRIVATE
MLIRSPIRV
MLIRSPIRVSerialization)

whole_archive_link(MLIRSPIRVTests MLIRSPIRV)

124 changes: 124 additions & 0 deletions mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -0,0 +1,124 @@
//===- SerializationTest.cpp - SPIR-V Seserialization Tests -------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains corner case tests for the SPIR-V serializer that are not
// covered by normal serialization and deserialization roundtripping.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "gmock/gmock.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//

class SerializationTest : public ::testing::Test {
protected:
SerializationTest() { createModuleOp(); }

void createModuleOp() {
Builder builder(&context);
OperationState state(UnknownLoc::get(&context),
spirv::ModuleOp::getOperationName());
state.addAttribute("addressing_model",
builder.getI32IntegerAttr(static_cast<uint32_t>(
spirv::AddressingModel::Logical)));
state.addAttribute("memory_model",
builder.getI32IntegerAttr(
static_cast<uint32_t>(spirv::MemoryModel::GLSL450)));
spirv::ModuleOp::build(&builder, &state);
module = cast<spirv::ModuleOp>(Operation::create(state));
}

Type getFloatStructType() {
OpBuilder opBuilder(module.body());
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
llvm::SmallVector<spirv::StructType::LayoutInfo, 1> layoutInfo{0};
auto structType = spirv::StructType::get(elementTypes, layoutInfo);
return structType;
}

void addGlobalVar(Type type, llvm::StringRef name) {
OpBuilder opBuilder(module.body());
auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
opBuilder.create<spirv::GlobalVariableOp>(
UnknownLoc::get(&context), opBuilder.getTypeAttr(ptrType),
opBuilder.getStringAttr(name), nullptr);
}

bool findInstruction(llvm::function_ref<bool(spirv::Opcode opcode,
ArrayRef<uint32_t> operands)>
matchFn) {
auto binarySize = binary.size();
auto begin = binary.begin();
auto currOffset = spirv::kHeaderWordCount;

while (currOffset < binarySize) {
auto wordCount = binary[currOffset] >> 16;
if (!wordCount || (currOffset + wordCount > binarySize)) {
return false;
}
spirv::Opcode opcode =
static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);

if (matchFn(opcode,
llvm::ArrayRef<uint32_t>(begin + currOffset + 1,
begin + currOffset + wordCount))) {
return true;
}
currOffset += wordCount;
}
return false;
}

protected:
MLIRContext context;
spirv::ModuleOp module;
SmallVector<uint32_t, 0> binary;
};

//===----------------------------------------------------------------------===//
// Block decoration
//===----------------------------------------------------------------------===//

TEST_F(SerializationTest, BlockDecorationTest) {
auto structType = getFloatStructType();
addGlobalVar(structType, "var0");
ASSERT_TRUE(succeeded(spirv::serialize(module, binary)));
auto hasBlockDecoration = [](spirv::Opcode opcode,
ArrayRef<uint32_t> operands) -> bool {
if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2)
return false;
return operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
};
EXPECT_TRUE(findInstruction(hasBlockDecoration));
}

0 comments on commit 8f2dfb5

Please sign in to comment.