Skip to content

Commit

Permalink
DxilValidation Restore any missing alignment from RDAT
Browse files Browse the repository at this point in the history
Since DxilValidation for RDAT part expects exact binary comparison match, we need to restore any missing alignment from the RDAT before we generate RDAT and compare.

Alignment for records used with GetNodeRecordPtr() will be restored from llvm, and not overwritten by this alignment from RDAT, so it still varifies that alignment matches type llvm module when available.
  • Loading branch information
tex3d committed Feb 20, 2024
1 parent 86d70b3 commit 1acf3ff
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions lib/HLSL/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6334,6 +6334,66 @@ bool ValidateCompilerVersionPart(const void *pBlobPtr, UINT blobSize) {
return true;
}

static bool UpdateNodeRecordAlignment(
DxilFunctionProps &props,
RDAT::RecordArrayReader<RDAT::NodeShaderIOAttrib_Reader> attribs,
unsigned index, bool bIsInput) {
for (unsigned iAttrib = 0; iAttrib < attribs.Count(); iAttrib++) {
auto &attrib = attribs[iAttrib];
if (attrib.hasRecordAlignmentInBytes()) {
// check whether our alignment is 0. If our alignment is 0, set it to the
// alignment from the RDAT.
auto &nodes = bIsInput ? props.InputNodes : props.OutputNodes;
if (nodes.size() <= index)
return false;
unsigned &ourAlignment = nodes[index].RecordType.alignment;
if (ourAlignment == 0)
ourAlignment = attrib.getRecordAlignmentInBytes();
else if (ourAlignment != attrib.getRecordAlignmentInBytes())
return false;
}
}
}

static bool RestoreNodeRecordAlignment(DxilModule &DM,
RDAT::DxilRuntimeData &rdat,
const char *PartName) {
// Records used with GetNodeRecordPtr() will have alignment set from llvm
// struct. Restore missing alignment from RDAT for unused node records.
if (auto fnTable = rdat.GetFunctionTable()) {
for (unsigned iFunction = 0; iFunction < fnTable.Count(); iFunction++) {
auto fn = fnTable[iFunction];
if (fn.getShaderKind() == DXIL::ShaderKind::Node) {
// Get function properties from DxilModule.
Function *F = DM.GetModule()->getFunction(fn.getName());
if (!F)
return false;
if (!DM.HasDxilFunctionProps(F))
return false;
DxilFunctionProps &props = DM.GetDxilFunctionProps(F);
if (props.shaderKind != DXIL::ShaderKind::Node)
return false;
// Iterate through input nodes and output nodes.
// If input or output node record has an alignment, then update
// alignment as necessary.
auto fn2 = RDAT::RuntimeDataFunctionInfo2_Reader(fn);
auto node = fn2.getNode();
for (unsigned i = 0; i < node.getInputs().Count(); i++) {
auto input = node.getInputs()[i];
if (!UpdateNodeRecordAlignment(props, input.getAttribs(), i, true))
return false;
}
for (unsigned i = 0; i < node.getOutputs().Count(); i++) {
auto output = node.getOutputs()[i];
if (!UpdateNodeRecordAlignment(props, output.getAttribs(), i, false))
return false;
}
}
}
}
return true;
}

static void VerifyRDATMatches(ValidationContext &ValCtx, const void *pRDATData,
uint32_t RDATSize) {
const char *PartName = "Runtime Data (RDAT)";
Expand All @@ -6357,6 +6417,11 @@ static void VerifyRDATMatches(ValidationContext &ValCtx, const void *pRDATData,
}
}

if (!RestoreNodeRecordAlignment(ValCtx.DxilMod, rdat, PartName)) {
ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {PartName});
return;
}

unique_ptr<DxilPartWriter> pWriter(NewRDATWriter(ValCtx.DxilMod));
VerifyBlobPartMatches(ValCtx, PartName, pWriter.get(), pRDATData, RDATSize);
}
Expand Down

0 comments on commit 1acf3ff

Please sign in to comment.