Skip to content

Commit

Permalink
[MLIR][SPIRV] Support identified and recursive structs.
Browse files Browse the repository at this point in the history
This PR adds support for identified and recursive structs.
This includes: parsing, printing, serializing, and
deserializing such structs.

The following C struct:

```C
struct A {
  A* next;
};
```

which is translated to the following MLIR code as:

```mlir
!spv.struct<A, (!spv.ptr<!spv.struct<A>, Generic>)>
```

would be represented in the SPIR-V module as:

```spirv
OpName %A "A"
OpTypeForwardPointer %APtr Generic
%A = OpTypeStruct %APtr
%APtr = OpTypePointer Generic %A
```

In particular the following changes are included:
- SPIR-V structs can now be either identified or literal
  (i.e. non-identified).
- All structs now have their members surrounded by a ()-pair.
- For recursive references,
  (1) an OpTypeForwardPointer instruction is emitted before
  the OpTypeStruct instruction defining the recursive struct
  (2) an OpTypePointer instruction is emitted after the
  OpTypeStruct instruction which actually defines the recursive
  pointer to struct type.

Reviewed By: antiagainst, rriddle, ftynse

Differential Revision: https://reviews.llvm.org/D87206
  • Loading branch information
ergawy authored and antiagainst committed Oct 13, 2020
1 parent 08e4e08 commit bddaa7a
Show file tree
Hide file tree
Showing 41 changed files with 1,052 additions and 409 deletions.
31 changes: 16 additions & 15 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
Expand Up @@ -3155,6 +3155,7 @@ def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 2
def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>;
def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>;
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
def SPV_OC_OpTypeForwardPointer : I32EnumAttrCase<"OpTypeForwardPointer", 39>;
def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>;
def SPV_OC_OpConstantFalse : I32EnumAttrCase<"OpConstantFalse", 42>;
def SPV_OC_OpConstant : I32EnumAttrCase<"OpConstant", 43>;
Expand Down Expand Up @@ -3302,21 +3303,21 @@ def SPV_OpcodeAttr :
SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix,
SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpTypeForwardPointer,
SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
SPV_OC_OpBitcast, SPV_OC_OpSNegate, SPV_OC_OpFNegate, SPV_OC_OpIAdd,
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
Expand Down
53 changes: 49 additions & 4 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
Expand Up @@ -262,7 +262,24 @@ class RuntimeArrayType
Optional<StorageClass> storage = llvm::None);
};

// SPIR-V struct type
/// SPIR-V struct type. Two kinds of struct types are supported:
/// - Literal: a literal struct type is uniqued by its fields (types + offset
/// info + decoration info).
/// - Identified: an indentified struct type is uniqued by its string identifier
/// (name). This is useful in representing recursive structs. For example, the
/// following C struct:
///
/// struct A {
/// A* next;
/// };
///
/// would be represented in MLIR as:
///
/// !spv.struct<A, (!spv.ptr<!spv.struct<A>, Generic>)>
///
/// In the above, expressing recursive struct types is accomplished by giving a
/// recursive struct a unique identified and using that identifier in the struct
/// definition for recursive references.
class StructType : public Type::TypeBase<StructType, CompositeType,
detail::StructTypeStorage> {
public:
Expand Down Expand Up @@ -297,13 +314,34 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
}
};

/// Construct a StructType with at least one member.
/// Construct a literal StructType with at least one member.
static StructType get(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo = {},
ArrayRef<MemberDecorationInfo> memberDecorations = {});

/// Construct a struct with no members.
static StructType getEmpty(MLIRContext *context);
/// Construct an identified StructType. This creates a StructType whose body
/// (member types, offset info, and decorations) is not set yet. A call to
/// StructType::trySetBody(...) must follow when the StructType contents are
/// available (e.g. parsed or deserialized).
///
/// Note: If another thread creates (or had already created) a struct with the
/// same identifier, that struct will be returned as a result.
static StructType getIdentified(MLIRContext *context, StringRef identifier);

/// Construct a (possibly identified) StructType with no members.
///
/// Note: this method might fail in a multi-threaded setup if another thread
/// created an identified struct with the same identifier but with different
/// contents before returning. In which case, an empty (default-constructed)
/// StructType is returned.
static StructType getEmpty(MLIRContext *context, StringRef identifier = "");

/// For literal structs, return an empty string.
/// For identified structs, return the struct's identifier.
StringRef getIdentifier() const;

/// Returns true if the StructType is identified.
bool isIdentified() const;

unsigned getNumElements() const;

Expand Down Expand Up @@ -346,6 +384,13 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
SmallVectorImpl<StructType::MemberDecorationInfo>
&decorationsInfo) const;

/// Sets the contents of an incomplete identified StructType. This method must
/// be called only for identified StructTypes and it must be called only once
/// per instance. Otherwise, failure() is returned.
LogicalResult
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
ArrayRef<MemberDecorationInfo> memberDecorations = {});

void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
Optional<StorageClass> storage = llvm::None);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
Expand Down
8 changes: 7 additions & 1 deletion mlir/lib/Dialect/SPIRV/LayoutUtils.cpp
Expand Up @@ -67,7 +67,13 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
alignment = maxMemberAlignment;
structType.getMemberDecorations(memberDecorations);
return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);

if (!structType.isIdentified())
return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);

// Identified structs are uniqued by identifier so it is not possible
// to create 2 structs with the same name but different decorations.
return nullptr;
}

Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
Expand Down
123 changes: 109 additions & 14 deletions mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSwitch.h"
Expand Down Expand Up @@ -589,15 +590,80 @@ static ParseResult parseStructMemberDecorations(
}

// struct-member-decoration ::= integer-literal? spirv-decoration*
// struct-type ::= `!spv.struct<` spirv-type (`[` struct-member-decoration `]`)?
// (`, ` spirv-type (`[` struct-member-decoration `]`)? `>`
// struct-type ::=
// `!spv.struct<` (id `,`)?
// `(`
// (spirv-type (`[` struct-member-decoration `]`)?)*
// `)>`
static Type parseStructType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
// TODO: This function is quite lengthy. Break it down into smaller chunks.

// To properly resolve recursive references while parsing recursive struct
// types, we need to maintain a list of enclosing struct type names. This set
// maintains the names of struct types in which the type we are about to parse
// is nested.
//
// Note: This has to be thread_local to enable multiple threads to safely
// parse concurrently.
thread_local llvm::SetVector<StringRef> structContext;

static auto removeIdentifierAndFail =
[](llvm::SetVector<StringRef> &structContext, StringRef identifier) {
if (!identifier.empty())
structContext.remove(identifier);

return Type();
};

if (parser.parseLess())
return Type();

if (succeeded(parser.parseOptionalGreater()))
return StructType::getEmpty(dialect.getContext());
StringRef identifier;

// Check if this is an idenitifed struct type.
if (succeeded(parser.parseOptionalKeyword(&identifier))) {
// Check if this is a possible recursive reference.
if (succeeded(parser.parseOptionalGreater())) {
if (structContext.count(identifier) == 0) {
parser.emitError(
parser.getNameLoc(),
"recursive struct reference not nested in struct definition");

return Type();
}

return StructType::getIdentified(dialect.getContext(), identifier);
}

if (failed(parser.parseComma()))
return Type();

if (structContext.count(identifier) != 0) {
parser.emitError(parser.getNameLoc(),
"identifier already used for an enclosing struct");

return removeIdentifierAndFail(structContext, identifier);
}

structContext.insert(identifier);
}

if (failed(parser.parseLParen()))
return removeIdentifierAndFail(structContext, identifier);

if (succeeded(parser.parseOptionalRParen()) &&
succeeded(parser.parseOptionalGreater())) {
if (!identifier.empty())
structContext.remove(identifier);

return StructType::getEmpty(dialect.getContext(), identifier);
}

StructType idStructTy;

if (!identifier.empty())
idStructTy = StructType::getIdentified(dialect.getContext(), identifier);

SmallVector<Type, 4> memberTypes;
SmallVector<StructType::OffsetInfo, 4> offsetInfo;
Expand All @@ -606,24 +672,33 @@ static Type parseStructType(SPIRVDialect const &dialect,
do {
Type memberType;
if (parser.parseType(memberType))
return Type();
return removeIdentifierAndFail(structContext, identifier);
memberTypes.push_back(memberType);

if (succeeded(parser.parseOptionalLSquare())) {
if (succeeded(parser.parseOptionalLSquare()))
if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
memberDecorationInfo)) {
return Type();
}
}
memberDecorationInfo))
return removeIdentifierAndFail(structContext, identifier);
} while (succeeded(parser.parseOptionalComma()));

if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
parser.emitError(parser.getNameLoc(),
"offset specification must be given for all members");
return Type();
return removeIdentifierAndFail(structContext, identifier);
}
if (parser.parseGreater())
return Type();

if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
return removeIdentifierAndFail(structContext, identifier);

if (!identifier.empty()) {
if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
memberDecorationInfo)))
return Type();

structContext.remove(identifier);
return idStructTy;
}

return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
}

Expand Down Expand Up @@ -689,7 +764,24 @@ static void print(ImageType type, DialectAsmPrinter &os) {
}

static void print(StructType type, DialectAsmPrinter &os) {
thread_local llvm::SetVector<StringRef> structContext;

os << "struct<";

if (type.isIdentified()) {
os << type.getIdentifier();

if (structContext.count(type.getIdentifier())) {
os << ">";
return;
}

os << ", ";
structContext.insert(type.getIdentifier());
}

os << "(";

auto printMember = [&](unsigned i) {
os << type.getElementType(i);
SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
Expand All @@ -713,7 +805,10 @@ static void print(StructType type, DialectAsmPrinter &os) {
};
llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
printMember);
os << ">";
os << ")>";

if (type.isIdentified())
structContext.remove(type.getIdentifier());
}

static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
Expand Down

0 comments on commit bddaa7a

Please sign in to comment.