diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index c691d5901529b..531feccccb032 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -330,10 +330,34 @@ class StructType bool hasValue() const { return !isa(decorationValue); } }; + // Type for specifying the decoration(s) on the struct itself. + struct StructDecorationInfo { + Decoration decoration; + Attribute decorationValue; + + StructDecorationInfo(Decoration decoration, Attribute decorationValue) + : decoration(decoration), decorationValue(decorationValue) {} + + friend bool operator==(const StructDecorationInfo &lhs, + const StructDecorationInfo &rhs) { + return lhs.decoration == rhs.decoration && + lhs.decorationValue == rhs.decorationValue; + } + + friend bool operator<(const StructDecorationInfo &lhs, + const StructDecorationInfo &rhs) { + return llvm::to_underlying(lhs.decoration) < + llvm::to_underlying(rhs.decoration); + } + + bool hasValue() const { return !isa(decorationValue); } + }; + /// Construct a literal StructType with at least one member. static StructType get(ArrayRef memberTypes, ArrayRef offsetInfo = {}, - ArrayRef memberDecorations = {}); + ArrayRef memberDecorations = {}, + ArrayRef structDecorations = {}); /// Construct an identified StructType. This creates a StructType whose body /// (member types, offset info, and decorations) is not set yet. A call to @@ -367,6 +391,9 @@ class StructType bool hasOffset() const; + /// Returns true if the struct has a specified decoration. + bool hasDecoration(spirv::Decoration decoration) const; + uint64_t getMemberOffset(unsigned) const; // Returns in `memberDecorations` the Decorations (apart from Offset) @@ -380,12 +407,18 @@ class StructType unsigned i, SmallVectorImpl &decorationsInfo) const; + // Returns in `structDecorations` the Decorations associated with the + // StructType. + void getStructDecorations(SmallVectorImpl + &structDecorations) const; + /// Sets the contents of an incomplete identified StructType. This method must /// be called only for identified StructTypes and it must be called only once /// per instance. Otherwise, failure() is returned. LogicalResult trySetBody(ArrayRef memberTypes, ArrayRef offsetInfo = {}, - ArrayRef memberDecorations = {}); + ArrayRef memberDecorations = {}, + ArrayRef structDecorations = {}); void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional storage = std::nullopt); @@ -396,6 +429,9 @@ class StructType llvm::hash_code hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); +llvm::hash_code +hash_value(const StructType::StructDecorationInfo &structDecorationInfo); + // SPIR-V KHR cooperative matrix type class CooperativeMatrixType : public Type::TypeBase` +// `)` +// (`,` struct-decoration)? +// `>` static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser) { // TODO: This function is quite lengthy. Break it down into smaller chunks. @@ -765,17 +767,48 @@ static Type parseStructType(SPIRVDialect const &dialect, return Type(); } - if (failed(parser.parseRParen()) || failed(parser.parseGreater())) + if (failed(parser.parseRParen())) + return Type(); + + SmallVector structDecorationInfo; + + auto parseStructDecoration = [&]() { + std::optional decoration = + parseAndVerify(dialect, parser); + if (!decoration) + return failure(); + + // Parse decoration value if it exists. + if (succeeded(parser.parseOptionalEqual())) { + Attribute decorationValue; + if (failed(parser.parseAttribute(decorationValue))) + return failure(); + + structDecorationInfo.emplace_back(decoration.value(), decorationValue); + } else { + structDecorationInfo.emplace_back(decoration.value(), + UnitAttr::get(dialect.getContext())); + } + return success(); + }; + + while (succeeded(parser.parseOptionalComma())) + if (failed(parseStructDecoration())) + return Type(); + + if (failed(parser.parseGreater())) return Type(); if (!identifier.empty()) { if (failed(idStructTy.trySetBody(memberTypes, offsetInfo, - memberDecorationInfo))) + memberDecorationInfo, + structDecorationInfo))) return Type(); return idStructTy; } - return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); + return StructType::get(memberTypes, offsetInfo, memberDecorationInfo, + structDecorationInfo); } // spirv-type ::= array-type @@ -891,7 +924,23 @@ static void print(StructType type, DialectAsmPrinter &os) { }; llvm::interleaveComma(llvm::seq(0, type.getNumElements()), os, printMember); - os << ")>"; + os << ")"; + + SmallVector decorations; + type.getStructDecorations(decorations); + if (!decorations.empty()) { + os << ", "; + auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) { + os << stringifyDecoration(decoration.decoration); + if (decoration.hasValue()) { + os << "="; + os.printAttributeWithoutType(decoration.decorationValue); + } + }; + llvm::interleaveComma(decorations, os, eachFn); + } + + os << ">"; } static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 46739bcd79b8a..ddb342621f371 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -835,12 +835,14 @@ void SampledImageType::getCapabilities( /// - for literal structs: /// - a list of member types; /// - a list of member offset info; -/// - a list of member decoration info. +/// - a list of member decoration info; +/// - a list of struct decoration info. /// /// Identified structures only have a mutable component consisting of: /// - a list of member types; /// - a list of member offset info; -/// - a list of member decoration info. +/// - a list of member decoration info; +/// - a list of struct decoration info. struct spirv::detail::StructTypeStorage : public TypeStorage { /// Construct a storage object for an identified struct type. A struct type /// associated with such storage must call StructType::trySetBody(...) later @@ -848,6 +850,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage(StringRef identifier) : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr), numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr), + numStructDecorations(0), structDecorationsInfo(nullptr), identifier(identifier) {} /// Construct a storage object for a literal struct type. A struct type @@ -855,10 +858,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage( unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, - StructType::MemberDecorationInfo const *memberDecorationsInfo) + StructType::MemberDecorationInfo const *memberDecorationsInfo, + unsigned numStructDecorations, + StructType::StructDecorationInfo const *structDecorationsInfo) : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo), numMembers(numMembers), numMemberDecorations(numMemberDecorations), - memberDecorationsInfo(memberDecorationsInfo) {} + memberDecorationsInfo(memberDecorationsInfo), + numStructDecorations(numStructDecorations), + structDecorationsInfo(structDecorationsInfo) {} /// A storage key is divided into 2 parts: /// - for identified structs: @@ -867,16 +874,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { /// - an ArrayRef for member types; /// - an ArrayRef for member offset info; /// - an ArrayRef for member decoration + /// info; + /// - an ArrayRef for struct decoration /// info. /// /// An identified struct type is uniqued only by the first part (field 0) /// of the key. /// - /// A literal struct type is uniqued only by the second part (fields 1, 2, and - /// 3) of the key. The identifier field (field 0) must be empty. + /// A literal struct type is uniqued only by the second part (fields 1, 2, 3 + /// and 4) of the key. The identifier field (field 0) must be empty. using KeyTy = std::tuple, ArrayRef, - ArrayRef>; + ArrayRef, + ArrayRef>; /// For identified structs, return true if the given key contains the same /// identifier. @@ -890,7 +900,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { } return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(), - getMemberDecorationsInfo()); + getMemberDecorationsInfo(), getStructDecorationsInfo()); } /// If the given key contains a non-empty identifier, this method constructs @@ -937,9 +947,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); } - return new (allocator.allocate()) - StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, - numMemberDecorations, memberDecorationList); + const StructType::StructDecorationInfo *structDecorationList = nullptr; + unsigned numStructDecorations = 0; + if (!std::get<4>(key).empty()) { + auto keyStructDecorations = std::get<4>(key); + numStructDecorations = keyStructDecorations.size(); + structDecorationList = allocator.copyInto(keyStructDecorations).data(); + } + + return new (allocator.allocate()) StructTypeStorage( + keyTypes.size(), typesList, offsetInfoList, numMemberDecorations, + memberDecorationList, numStructDecorations, structDecorationList); } ArrayRef getMemberTypes() const { @@ -961,6 +979,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { return {}; } + ArrayRef getStructDecorationsInfo() const { + if (structDecorationsInfo) + return ArrayRef(structDecorationsInfo, + numStructDecorations); + return {}; + } + StringRef getIdentifier() const { return identifier; } bool isIdentified() const { return !identifier.empty(); } @@ -973,17 +998,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { /// - If called for an identified struct whose body was set before (through a /// call to this method) but with different contents from the passed /// arguments. - LogicalResult mutate( - TypeStorageAllocator &allocator, ArrayRef structMemberTypes, - ArrayRef structOffsetInfo, - ArrayRef structMemberDecorationInfo) { + LogicalResult + mutate(TypeStorageAllocator &allocator, ArrayRef structMemberTypes, + ArrayRef structOffsetInfo, + ArrayRef structMemberDecorationInfo, + ArrayRef structDecorationInfo) { if (!isIdentified()) return failure(); if (memberTypesAndIsBodySet.getInt() && (getMemberTypes() != structMemberTypes || getOffsetInfo() != structOffsetInfo || - getMemberDecorationsInfo() != structMemberDecorationInfo)) + getMemberDecorationsInfo() != structMemberDecorationInfo || + getStructDecorationsInfo() != structDecorationInfo)) return failure(); memberTypesAndIsBodySet.setInt(true); @@ -1007,6 +1034,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { allocator.copyInto(structMemberDecorationInfo).data(); } + if (!structDecorationInfo.empty()) { + numStructDecorations = structDecorationInfo.size(); + structDecorationsInfo = allocator.copyInto(structDecorationInfo).data(); + } + return success(); } @@ -1015,21 +1047,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { unsigned numMembers; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; + unsigned numStructDecorations; + StructType::StructDecorationInfo const *structDecorationsInfo; StringRef identifier; }; StructType StructType::get(ArrayRef memberTypes, ArrayRef offsetInfo, - ArrayRef memberDecorations) { + ArrayRef memberDecorations, + ArrayRef structDecorations) { assert(!memberTypes.empty() && "Struct needs at least one member type"); // Sort the decorations. - SmallVector sortedDecorations( + SmallVector sortedMemberDecorations( memberDecorations); - llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); + llvm::array_pod_sort(sortedMemberDecorations.begin(), + sortedMemberDecorations.end()); + SmallVector sortedStructDecorations( + structDecorations); + llvm::array_pod_sort(sortedStructDecorations.begin(), + sortedStructDecorations.end()); + return Base::get(memberTypes.vec().front().getContext(), /*identifier=*/StringRef(), memberTypes, offsetInfo, - sortedDecorations); + sortedMemberDecorations, sortedStructDecorations); } StructType StructType::getIdentified(MLIRContext *context, @@ -1039,18 +1080,21 @@ StructType StructType::getIdentified(MLIRContext *context, return Base::get(context, identifier, ArrayRef(), ArrayRef(), - ArrayRef()); + ArrayRef(), + ArrayRef()); } StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) { StructType newStructType = Base::get( context, identifier, ArrayRef(), ArrayRef(), - ArrayRef()); + ArrayRef(), + ArrayRef()); // Set an empty body in case this is a identified struct. if (newStructType.isIdentified() && failed(newStructType.trySetBody( ArrayRef(), ArrayRef(), - ArrayRef()))) + ArrayRef(), + ArrayRef()))) return StructType(); return newStructType; @@ -1074,6 +1118,15 @@ TypeRange StructType::getElementTypes() const { bool StructType::hasOffset() const { return getImpl()->offsetInfo; } +bool StructType::hasDecoration(spirv::Decoration decoration) const { + for (StructType::StructDecorationInfo info : + getImpl()->getStructDecorationsInfo()) + if (info.decoration == decoration) + return true; + + return false; +} + uint64_t StructType::getMemberOffset(unsigned index) const { assert(getNumElements() > index && "member index out of range"); return getImpl()->offsetInfo[index]; @@ -1105,11 +1158,21 @@ void StructType::getMemberDecorations( } } +void StructType::getStructDecorations( + SmallVectorImpl &structDecorations) + const { + structDecorations.clear(); + auto implDecorations = getImpl()->getStructDecorationsInfo(); + structDecorations.append(implDecorations.begin(), implDecorations.end()); +} + LogicalResult StructType::trySetBody(ArrayRef memberTypes, ArrayRef offsetInfo, - ArrayRef memberDecorations) { - return Base::mutate(memberTypes, offsetInfo, memberDecorations); + ArrayRef memberDecorations, + ArrayRef structDecorations) { + return Base::mutate(memberTypes, offsetInfo, memberDecorations, + structDecorations); } void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, @@ -1131,6 +1194,11 @@ llvm::hash_code spirv::hash_value( memberDecorationInfo.decoration); } +llvm::hash_code spirv::hash_value( + const StructType::StructDecorationInfo &structDecorationInfo) { + return llvm::hash_value(structDecorationInfo.decoration); +} + //===----------------------------------------------------------------------===// // MatrixType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index e5934bb9943fd..88931b53a6889 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef words) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target "; } - // Block decoration does not affect spirv.struct type, but is still stored - // for verification. - // TODO: Update StructType to contain this information since - // it is needed for many validation rules. decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); break; case spirv::Decoration::Location: @@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef operands) { if (failed(structType.trySetBody( deferredStructIt->memberTypes, deferredStructIt->offsetInfo, - deferredStructIt->memberDecorationsInfo))) + deferredStructIt->memberDecorationsInfo, + deferredStructIt->structDecorationsInfo))) return failure(); deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); @@ -1203,24 +1200,37 @@ spirv::Deserializer::processStructType(ArrayRef operands) { } } + SmallVector structDecorationsInfo; + if (decorations.count(operands[0])) { + NamedAttrList &allDecorations = decorations[operands[0]]; + for (NamedAttribute &decorationAttr : allDecorations) { + std::optional decoration = spirv::symbolizeDecoration( + llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true)); + assert(decoration.has_value()); + structDecorationsInfo.emplace_back(decoration.value(), + decorationAttr.getValue()); + } + } + uint32_t structID = operands[0]; std::string structIdentifier = nameMap.lookup(structID).str(); if (structIdentifier.empty()) { assert(unresolvedMemberTypes.empty() && "didn't expect unresolved member types"); - typeMap[structID] = - spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); + typeMap[structID] = spirv::StructType::get( + memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo); } else { auto structTy = spirv::StructType::getIdentified(context, structIdentifier); typeMap[structID] = structTy; if (!unresolvedMemberTypes.empty()) - deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, - memberTypes, offsetInfo, - memberDecorationsInfo}); + deferredStructTypesInfos.push_back( + {structTy, unresolvedMemberTypes, memberTypes, offsetInfo, + memberDecorationsInfo, structDecorationsInfo}); else if (failed(structTy.trySetBody(memberTypes, offsetInfo, - memberDecorationsInfo))) + memberDecorationsInfo, + structDecorationsInfo))) return failure(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 20482bd2bf501..db1cc3f8d79c2 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -95,6 +95,7 @@ struct DeferredStructTypeInfo { SmallVector memberTypes; SmallVector offsetInfo; SmallVector memberDecorationsInfo; + SmallVector structDecorationsInfo; }; /// A struct that collects the info needed to materialize/emit a diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 6aab6bb668c78..04d698748a796 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -319,6 +319,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, case spirv::Decoration::RestrictPointer: case spirv::Decoration::NoContraction: case spirv::Decoration::Constant: + case spirv::Decoration::Block: // For unit attributes and decoration attributes, the args list // has no values so we do nothing. if (isa(attr)) @@ -618,11 +619,16 @@ LogicalResult Serializer::prepareBasicType( operands.push_back(static_cast(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); + // TODO: Now struct decorations are supported this code may not be + // necessary. However, it is left to support backwards compatibility. + // Ideally, Block decorations should be inserted when converting to SPIR-V. if (isInterfaceStructPtrType(ptrType)) { - if (failed(emitDecoration(getTypeID(pointeeStruct), - spirv::Decoration::Block))) - return emitError(loc, "cannot decorate ") - << pointeeStruct << " with Block decoration"; + auto structType = cast(ptrType.getPointeeType()); + if (!structType.hasDecoration(spirv::Decoration::Block)) + if (failed(emitDecoration(getTypeID(pointeeStruct), + spirv::Decoration::Block))) + return emitError(loc, "cannot decorate ") + << pointeeStruct << " with Block decoration"; } return success(); @@ -692,6 +698,20 @@ LogicalResult Serializer::prepareBasicType( } } + SmallVector structDecorations; + structType.getStructDecorations(structDecorations); + + for (spirv::StructType::StructDecorationInfo &structDecoration : + structDecorations) { + if (failed(processDecorationAttr(loc, resultID, + structDecoration.decoration, + structDecoration.decorationValue))) { + return emitError(loc, "cannot decorate struct ") + << structType << " with " + << stringifyDecoration(structDecoration.decoration); + } + } + typeEnum = spirv::Opcode::OpTypeStruct; if (structType.isIdentified()) diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir index 5d05a65414969..6d321afebf7f8 100644 --- a/mlir/test/Dialect/SPIRV/IR/types.mlir +++ b/mlir/test/Dialect/SPIRV/IR/types.mlir @@ -296,6 +296,12 @@ func.func private @struct_type_with_matrix_2(!spirv.struct<(!spirv.matrix<3 x ve // CHECK: func private @struct_empty(!spirv.struct<()>) func.func private @struct_empty(!spirv.struct<()>) +// CHECK: func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>) +func.func private @struct_block(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block>) + +// CHECK: func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>) +func.func private @struct_two_dec(!spirv.struct<(vector<4xf32> [BuiltIn=0], f32 [BuiltIn=1]), Block, CPacked>) + // ----- // expected-error @+1 {{offset specification must be given for all members}} diff --git a/mlir/test/Target/SPIRV/memory-ops.mlir b/mlir/test/Target/SPIRV/memory-ops.mlir index 6b50c3921d427..786d07a218c66 100644 --- a/mlir/test/Target/SPIRV/memory-ops.mlir +++ b/mlir/test/Target/SPIRV/memory-ops.mlir @@ -37,32 +37,32 @@ spirv.module Logical GLSL450 requires #spirv.vce { // ----- spirv.module Logical GLSL450 requires #spirv.vce { - spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr [0])>, StorageBuffer>, %arg1: !spirv.ptr [0])>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr [0])> + spirv.func @load_store_zero_rank_float(%arg0: !spirv.ptr [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr [0]), Block>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr [0]), Block>, StorageBuffer> // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : f32 %0 = spirv.Constant 0 : i32 - %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr + %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr %2 = spirv.Load "StorageBuffer" %1 : f32 - // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr [0])> + // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr [0]), Block>, StorageBuffer> // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32 %3 = spirv.Constant 0 : i32 - %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr + %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr spirv.Store "StorageBuffer" %4, %2 : f32 spirv.Return } - spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr [0])>, StorageBuffer>, %arg1: !spirv.ptr [0])>, StorageBuffer>) "None" { - // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr [0])> + spirv.func @load_store_zero_rank_int(%arg0: !spirv.ptr [0]), Block>, StorageBuffer>, %arg1: !spirv.ptr [0]), Block>, StorageBuffer>) "None" { + // CHECK: [[LOAD_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr [0]), Block>, StorageBuffer> // CHECK-NEXT: [[VAL:%.*]] = spirv.Load "StorageBuffer" [[LOAD_PTR]] : i32 %0 = spirv.Constant 0 : i32 - %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr + %1 = spirv.AccessChain %arg0[%0, %0] : !spirv.ptr [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr %2 = spirv.Load "StorageBuffer" %1 : i32 - // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr [0])> + // CHECK: [[STORE_PTR:%.*]] = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr [0]), Block>, StorageBuffer> // CHECK-NEXT: spirv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32 %3 = spirv.Constant 0 : i32 - %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr + %4 = spirv.AccessChain %arg1[%3, %3] : !spirv.ptr [0]), Block>, StorageBuffer>, i32, i32 -> !spirv.ptr spirv.Store "StorageBuffer" %4, %2 : i32 spirv.Return } diff --git a/mlir/test/Target/SPIRV/struct.mlir b/mlir/test/Target/SPIRV/struct.mlir index 0db0c0bfa2660..4984ee79f903d 100644 --- a/mlir/test/Target/SPIRV/struct.mlir +++ b/mlir/test/Target/SPIRV/struct.mlir @@ -7,23 +7,23 @@ spirv.module Logical GLSL450 requires #spirv.vce { // CHECK: !spirv.ptr [4])> [4])>, Input> spirv.GlobalVariable @var1 bind(0, 2) : !spirv.ptr [4])> [4])>, Input> - // CHECK: !spirv.ptr, StorageBuffer> - spirv.GlobalVariable @var2 : !spirv.ptr, StorageBuffer> + // CHECK: !spirv.ptr, StorageBuffer> + spirv.GlobalVariable @var2 : !spirv.ptr, StorageBuffer> - // CHECK: !spirv.ptr [0])>, stride=512> [0])>, StorageBuffer> - spirv.GlobalVariable @var3 : !spirv.ptr [0])>, stride=512> [0])>, StorageBuffer> + // CHECK: !spirv.ptr [0])>, stride=512> [0]), Block>, StorageBuffer> + spirv.GlobalVariable @var3 : !spirv.ptr [0])>, stride=512> [0]), Block>, StorageBuffer> - // CHECK: !spirv.ptr, StorageBuffer> - spirv.GlobalVariable @var4 : !spirv.ptr, StorageBuffer> + // CHECK: !spirv.ptr, StorageBuffer> + spirv.GlobalVariable @var4 : !spirv.ptr, StorageBuffer> - // CHECK: !spirv.ptr, StorageBuffer> - spirv.GlobalVariable @var5 : !spirv.ptr, StorageBuffer> + // CHECK: !spirv.ptr, StorageBuffer> + spirv.GlobalVariable @var5 : !spirv.ptr, StorageBuffer> - // CHECK: !spirv.ptr, StorageBuffer> - spirv.GlobalVariable @var6 : !spirv.ptr, StorageBuffer> + // CHECK: !spirv.ptr, StorageBuffer> + spirv.GlobalVariable @var6 : !spirv.ptr, StorageBuffer> - // CHECK: !spirv.ptr> [0, ColMajor, MatrixStride=16])>, StorageBuffer> - spirv.GlobalVariable @var7 : !spirv.ptr> [0, ColMajor, MatrixStride=16])>, StorageBuffer> + // CHECK: !spirv.ptr> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer> + spirv.GlobalVariable @var7 : !spirv.ptr> [0, ColMajor, MatrixStride=16]), Block>, StorageBuffer> // CHECK: !spirv.ptr, StorageBuffer> spirv.GlobalVariable @empty : !spirv.ptr, StorageBuffer> @@ -34,15 +34,17 @@ spirv.module Logical GLSL450 requires #spirv.vce { // CHECK: !spirv.ptr [0])>, Input> spirv.GlobalVariable @id_var0 : !spirv.ptr [0])>, Input> + // CHECK: !spirv.ptr, StorageBuffer>), Block>, StorageBuffer> + spirv.GlobalVariable @recursive_simple : !spirv.ptr, StorageBuffer>), Block>, StorageBuffer> - // CHECK: !spirv.ptr, StorageBuffer>)>, StorageBuffer> - spirv.GlobalVariable @recursive_simple : !spirv.ptr, StorageBuffer>)>, StorageBuffer> + // CHECK: !spirv.ptr, Uniform>), Block>, Uniform>), Block>, Uniform> + spirv.GlobalVariable @recursive_2 : !spirv.ptr, Uniform>), Block>, Uniform>), Block>, Uniform> - // CHECK: !spirv.ptr, Uniform>)>, Uniform>)>, Uniform> - spirv.GlobalVariable @recursive_2 : !spirv.ptr, Uniform>)>, Uniform>)>, Uniform> + // CHECK: !spirv.ptr, Uniform>, !spirv.ptr, Uniform>), Block>, Uniform>), Block>, Uniform> + spirv.GlobalVariable @recursive_3 : !spirv.ptr, Uniform>, !spirv.ptr, Uniform>), Block>, Uniform>), Block>, Uniform> - // CHECK: !spirv.ptr, Uniform>, !spirv.ptr, Uniform>)>, Uniform>)>, Uniform> - spirv.GlobalVariable @recursive_3 : !spirv.ptr, Uniform>, !spirv.ptr, Uniform>)>, Uniform>)>, Uniform> + // CHECK: spirv.GlobalVariable @block : !spirv.ptr [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output> + spirv.GlobalVariable @block : !spirv.ptr [BuiltIn=0], f32 [BuiltIn=1]), Block>, Output> // CHECK: !spirv.ptr [0])>, Input>, // CHECK-SAME: !spirv.ptr [0])>, Output> diff --git a/mlir/test/Target/SPIRV/undef.mlir b/mlir/test/Target/SPIRV/undef.mlir index b9044fe8b40af..8889b80e86f95 100644 --- a/mlir/test/Target/SPIRV/undef.mlir +++ b/mlir/test/Target/SPIRV/undef.mlir @@ -13,10 +13,10 @@ spirv.module Logical GLSL450 requires #spirv.vce { // CHECK: {{%.*}} = spirv.Undef : !spirv.array<4 x !spirv.array<4 x i32>> %5 = spirv.Undef : !spirv.array<4x!spirv.array<4xi32>> %6 = spirv.CompositeExtract %5[1 : i32, 2 : i32] : !spirv.array<4x!spirv.array<4xi32>> - // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr, StorageBuffer> - %7 = spirv.Undef : !spirv.ptr, StorageBuffer> + // CHECK: {{%.*}} = spirv.Undef : !spirv.ptr, StorageBuffer> + %7 = spirv.Undef : !spirv.ptr, StorageBuffer> %8 = spirv.Constant 0 : i32 - %9 = spirv.AccessChain %7[%8] : !spirv.ptr, StorageBuffer>, i32 -> !spirv.ptr + %9 = spirv.AccessChain %7[%8] : !spirv.ptr, StorageBuffer>, i32 -> !spirv.ptr spirv.Return } }