diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index c92edc3256..84c7eafacb 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -15,7 +15,6 @@ #include "AlignmentSizeCalculator.h" #include "InitListHandler.h" -#include "LowerTypeVisitor.h" #include "RawBufferMethods.h" #include "dxc/DXIL/DxilConstants.h" #include "dxc/HlslIntrinsicOp.h" @@ -543,71 +542,6 @@ bool isVkRawBufferLoadIntrinsic(const clang::FunctionDecl *FD) { return true; } -// Takes an AST member type, and determines its index in the equivalent SPIR-V -// struct type. This is required as the struct layout might change between the -// AST representation and SPIR-V representation. -uint32_t getFieldIndexInStruct(const SpirvCodeGenOptions &spirvOptions, - LowerTypeVisitor &lowerTypeVisitor, - const MemberExpr *expr) { - // If we are accessing a derived struct, we need to account for the number - // of base structs, since they are placed as fields at the beginning of the - // derived struct. - auto baseType = expr->getBase()->getType(); - if (baseType->isPointerType()) { - baseType = baseType->getPointeeType(); - } - - const auto *fieldDecl = - dynamic_cast(expr->getMemberDecl()); - assert(fieldDecl); - const uint32_t indexAST = - getNumBaseClasses(baseType) + fieldDecl->getFieldIndex(); - - // The AST type index is not representative of the SPIR-V type index - // because we might squash some fields (bitfields by ex.). - // What we need is to match each AST node with the squashed field and then, - // determine the real index. - const SpirvType *spvType = lowerTypeVisitor.lowerType( - baseType, spirvOptions.sBufferLayoutRule, llvm::None, SourceLocation()); - assert(spvType); - - const auto st = dynamic_cast(spvType); - assert(st != nullptr); - const auto &fields = st->getFields(); - assert(indexAST <= fields.size()); - - // Some fields in SPIR-V share the same index (bitfields). Computing the final - // index of the requested field. - uint32_t indexSPV = 0; - for (size_t i = 1; i <= indexAST; i++) { - // Do not remove this condition. This is required to support inheritance: - // 1. SPIR-V composite first element is the parent type: - // by ex "OpTypeStruct %base_struct %float". - // 2. if the parent type is an empty class, it's size it zero, hence - // "%float" offset is also 0. - // - // A way to detect such cases is to check for type difference: fields cannot - // be merged if the type is different. - if (fields[i - 1].type != fields[i].type) { - indexSPV++; - continue; - } - - if (fields[i - 1].offset.getValueOr(0) != fields[i].offset.getValueOr(0)) { - indexSPV++; - continue; - } - } - - // TODO(issue #4140): remove once bitfields are implemented. - // This is just a safeguard until bitfield support is in. Before bitfields, - // AST indices were always correct, so this function should not change that - // behavior. Once the bitfield support is in, indices will start to diverge, - // and this assert should be removed. - assert(indexSPV == indexAST); - return indexSPV; -} - } // namespace SpirvEmitter::SpirvEmitter(CompilerInstance &ci) @@ -7680,17 +7614,23 @@ const Expr *SpirvEmitter::collectArrayStructIndices( } } - { - LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions); - const uint32_t fieldIndex = - getFieldIndexInStruct(spirvOptions, lowerTypeVisitor, indexing); - - if (rawIndex) { - rawIndices->push_back(fieldIndex); - } else { - indices->push_back(spvBuilder.getConstantInt( - astContext.IntTy, llvm::APInt(32, fieldIndex, true))); - } + // Append the index of the current level + const auto *fieldDecl = cast(indexing->getMemberDecl()); + assert(fieldDecl); + // If we are accessing a derived struct, we need to account for the number + // of base structs, since they are placed as fields at the beginning of the + // derived struct. + auto baseType = indexing->getBase()->getType(); + if (baseType->isPointerType()) { + baseType = baseType->getPointeeType(); + } + const uint32_t index = + getNumBaseClasses(baseType) + fieldDecl->getFieldIndex(); + if (rawIndex) { + rawIndices->push_back(index); + } else { + indices->push_back(spvBuilder.getConstantInt( + astContext.IntTy, llvm::APInt(32, index, true))); } return base;