Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 17 additions & 77 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

#include "AlignmentSizeCalculator.h"
#include "InitListHandler.h"
#include "LowerTypeVisitor.h"
#include "RawBufferMethods.h"
#include "dxc/DXIL/DxilConstants.h"
#include "dxc/HlslIntrinsicOp.h"
Expand Down Expand Up @@ -543,71 +542,6 @@ bool isVkRawBufferLoadIntrinsic(const clang::FunctionDecl *FD) {
return true;
}

// Takes an AST member type, and determines its index in the equivalent SPIR-V
// struct type. This is required as the struct layout might change between the
// AST representation and SPIR-V representation.
uint32_t getFieldIndexInStruct(const SpirvCodeGenOptions &spirvOptions,
LowerTypeVisitor &lowerTypeVisitor,
const MemberExpr *expr) {
// If we are accessing a derived struct, we need to account for the number
// of base structs, since they are placed as fields at the beginning of the
// derived struct.
auto baseType = expr->getBase()->getType();
if (baseType->isPointerType()) {
baseType = baseType->getPointeeType();
}

const auto *fieldDecl =
dynamic_cast<const FieldDecl *>(expr->getMemberDecl());
assert(fieldDecl);
const uint32_t indexAST =
getNumBaseClasses(baseType) + fieldDecl->getFieldIndex();

// The AST type index is not representative of the SPIR-V type index
// because we might squash some fields (bitfields by ex.).
// What we need is to match each AST node with the squashed field and then,
// determine the real index.
const SpirvType *spvType = lowerTypeVisitor.lowerType(
baseType, spirvOptions.sBufferLayoutRule, llvm::None, SourceLocation());
assert(spvType);

const auto st = dynamic_cast<const StructType *>(spvType);
assert(st != nullptr);
const auto &fields = st->getFields();
assert(indexAST <= fields.size());

// Some fields in SPIR-V share the same index (bitfields). Computing the final
// index of the requested field.
uint32_t indexSPV = 0;
for (size_t i = 1; i <= indexAST; i++) {
// Do not remove this condition. This is required to support inheritance:
// 1. SPIR-V composite first element is the parent type:
// by ex "OpTypeStruct %base_struct %float".
// 2. if the parent type is an empty class, it's size it zero, hence
// "%float" offset is also 0.
//
// A way to detect such cases is to check for type difference: fields cannot
// be merged if the type is different.
if (fields[i - 1].type != fields[i].type) {
indexSPV++;
continue;
}

if (fields[i - 1].offset.getValueOr(0) != fields[i].offset.getValueOr(0)) {
indexSPV++;
continue;
}
}

// TODO(issue #4140): remove once bitfields are implemented.
// This is just a safeguard until bitfield support is in. Before bitfields,
// AST indices were always correct, so this function should not change that
// behavior. Once the bitfield support is in, indices will start to diverge,
// and this assert should be removed.
assert(indexSPV == indexAST);
return indexSPV;
}

} // namespace

SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
Expand Down Expand Up @@ -7680,17 +7614,23 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
}
}

{
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
const uint32_t fieldIndex =
getFieldIndexInStruct(spirvOptions, lowerTypeVisitor, indexing);

if (rawIndex) {
rawIndices->push_back(fieldIndex);
} else {
indices->push_back(spvBuilder.getConstantInt(
astContext.IntTy, llvm::APInt(32, fieldIndex, true)));
}
// Append the index of the current level
const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
assert(fieldDecl);
// If we are accessing a derived struct, we need to account for the number
// of base structs, since they are placed as fields at the beginning of the
// derived struct.
auto baseType = indexing->getBase()->getType();
if (baseType->isPointerType()) {
baseType = baseType->getPointeeType();
}
const uint32_t index =
getNumBaseClasses(baseType) + fieldDecl->getFieldIndex();
if (rawIndex) {
rawIndices->push_back(index);
} else {
indices->push_back(spvBuilder.getConstantInt(
astContext.IntTy, llvm::APInt(32, index, true)));
}

return base;
Expand Down