Skip to content

Commit

Permalink
Pix: Cope with group shared AS->MS payload (#6619)
Browse files Browse the repository at this point in the history
This change copes with the AS->MS payload being placed in group-shared
by the application (and MSFT's samples do indeed do this). (TIL, thanks
to pow2clk, that the spec says that the payload counts against the
group-shared total, implying, if not explicitly stating, that at least
on some platforms, the payload will be in group-shared anyway.)

The MS pass needs to be given data from the AS about the AS's thread
group topology, and this is done by extending the payload struct to add
three uints. This can't be done when the payload is resident in
group-shared, of course, because that would change the layout of
group-shared memory.
So the new approach here is to copy the payload to a new alloca (in the
default address space) struct with the members of the base struct plus
the extended data the MS needs, and then to copy piece-wise because
llvm.memcpy isn't appropriate for group-shared-to-normal address space
copies.
  • Loading branch information
jeffnn committed May 15, 2024
1 parent d9caef5 commit fd7e54b
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 122 deletions.
243 changes: 127 additions & 116 deletions lib/DxilPIXPasses/DxilPIXAddTidToAmplificationShaderPayload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,143 +45,154 @@ void DxilPIXAddTidToAmplificationShaderPayload::applyOptions(PassOptions O) {
}

void AddValueToExpandedPayload(OP *HlslOP, llvm::IRBuilder<> &B,
ExpandedStruct &expanded,
AllocaInst *NewStructAlloca,
unsigned int expandedValueIndex, Value *value) {
Constant *Zero32Arg = HlslOP->GetU32Const(0);
SmallVector<Value *, 2> IndexToAppendedValue;
IndexToAppendedValue.push_back(Zero32Arg);
IndexToAppendedValue.push_back(HlslOP->GetU32Const(expandedValueIndex));
auto *PointerToEmbeddedNewValue = B.CreateInBoundsGEP(
expanded.ExpandedPayloadStructType, NewStructAlloca, IndexToAppendedValue,
NewStructAlloca, IndexToAppendedValue,
"PointerToEmbeddedNewValue" + std::to_string(expandedValueIndex));
B.CreateStore(value, PointerToEmbeddedNewValue);
}

bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule(Module &M) {
void CopyAggregate(IRBuilder<> &B, Type *Ty, Value *Source, Value *Dest,
ArrayRef<Value *> GEPIndices) {
if (StructType *ST = dyn_cast<StructType>(Ty)) {
SmallVector<Value *, 16> StructIndices;
StructIndices.append(GEPIndices.begin(), GEPIndices.end());
StructIndices.push_back(nullptr);
for (unsigned j = 0; j < ST->getNumElements(); ++j) {
StructIndices.back() = B.getInt32(j);
CopyAggregate(B, ST->getElementType(j), Source, Dest, StructIndices);
}
} else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
SmallVector<Value *, 16> StructIndices;
StructIndices.append(GEPIndices.begin(), GEPIndices.end());
StructIndices.push_back(nullptr);
for (unsigned j = 0; j < AT->getNumElements(); ++j) {
StructIndices.back() = B.getInt32(j);
CopyAggregate(B, AT->getArrayElementType(), Source, Dest, StructIndices);
}
} else {
auto *SourceGEP = B.CreateGEP(Source, GEPIndices, "CopyStructSourceGEP");
Value *Val = B.CreateLoad(SourceGEP, "CopyStructLoad");
auto *DestGEP = B.CreateGEP(Dest, GEPIndices, "CopyStructDestGEP");
B.CreateStore(Val, DestGEP, "CopyStructStore");
}
}

bool DxilPIXAddTidToAmplificationShaderPayload::runOnModule(Module &M) {
DxilModule &DM = M.GetOrCreateDxilModule();
LLVMContext &Ctx = M.getContext();
OP *HlslOP = DM.GetOP();

Type *OriginalPayloadStructPointerType = nullptr;
Type *OriginalPayloadStructType = nullptr;
ExpandedStruct expanded;
llvm::Function *entryFunction = PIXPassHelpers::GetEntryFunction(DM);
for (inst_iterator I = inst_begin(entryFunction), E = inst_end(entryFunction);
I != E; ++I) {
if (auto *Instr = llvm::cast<Instruction>(&*I)) {
if (hlsl::OP::IsDxilOpFuncCallInst(Instr,
hlsl::OP::OpCode::DispatchMesh)) {
DxilInst_DispatchMesh DispatchMesh(Instr);
OriginalPayloadStructPointerType =
DispatchMesh.get_payload()->getType();
OriginalPayloadStructType =
OriginalPayloadStructPointerType->getPointerElementType();
expanded = ExpandStructType(Ctx, OriginalPayloadStructType);
}
}
}

AllocaInst *OldStructAlloca = nullptr;
AllocaInst *NewStructAlloca = nullptr;
std::vector<AllocaInst *> allocasOfPayloadType;
for (inst_iterator I = inst_begin(entryFunction), E = inst_end(entryFunction);
I != E; ++I) {
auto *Inst = &*I;
if (llvm::isa<AllocaInst>(Inst)) {
auto *Alloca = llvm::cast<AllocaInst>(Inst);
if (Alloca->getType() == OriginalPayloadStructPointerType) {
allocasOfPayloadType.push_back(Alloca);
}
if (hlsl::OP::IsDxilOpFuncCallInst(&*I, hlsl::OP::OpCode::DispatchMesh)) {
DxilInst_DispatchMesh DispatchMesh(&*I);
Type *OriginalPayloadStructPointerType =
DispatchMesh.get_payload()->getType();
Type *OriginalPayloadStructType =
OriginalPayloadStructPointerType->getPointerElementType();
ExpandedStruct expanded =
ExpandStructType(Ctx, OriginalPayloadStructType);

llvm::IRBuilder<> B(&*I);

auto *NewStructAlloca =
B.CreateAlloca(expanded.ExpandedPayloadStructType,
HlslOP->GetU32Const(1), "NewPayload");
NewStructAlloca->setAlignment(4);
auto PayloadType =
llvm::dyn_cast<PointerType>(DispatchMesh.get_payload()->getType());
SmallVector<Value *, 16> GEPIndices;
GEPIndices.push_back(B.getInt32(0));
CopyAggregate(B, PayloadType->getPointerElementType(),
DispatchMesh.get_payload(), NewStructAlloca, GEPIndices);

Constant *Zero32Arg = HlslOP->GetU32Const(0);
Constant *One32Arg = HlslOP->GetU32Const(1);
Constant *Two32Arg = HlslOP->GetU32Const(2);

auto GroupIdFunc =
HlslOP->GetOpFunc(DXIL::OpCode::GroupId, Type::getInt32Ty(Ctx));
Constant *GroupIdOpcode =
HlslOP->GetU32Const((unsigned)DXIL::OpCode::GroupId);
auto *GroupIdX =
B.CreateCall(GroupIdFunc, {GroupIdOpcode, Zero32Arg}, "GroupIdX");
auto *GroupIdY =
B.CreateCall(GroupIdFunc, {GroupIdOpcode, One32Arg}, "GroupIdY");
auto *GroupIdZ =
B.CreateCall(GroupIdFunc, {GroupIdOpcode, Two32Arg}, "GroupIdZ");

// FlatGroupID = z + y*numZ + x*numY*numZ
// Where x,y,z are the group ID components, and numZ and numY are the
// corresponding AS group-count arguments to the DispatchMesh Direct3D API
auto *GroupYxNumZ = B.CreateMul(
GroupIdY, HlslOP->GetU32Const(m_DispatchArgumentZ), "GroupYxNumZ");
auto *FlatGroupNumZY =
B.CreateAdd(GroupIdZ, GroupYxNumZ, "FlatGroupNumZY");
auto *GroupXxNumYZ = B.CreateMul(
GroupIdX,
HlslOP->GetU32Const(m_DispatchArgumentY * m_DispatchArgumentZ),
"GroupXxNumYZ");
auto *FlatGroupID =
B.CreateAdd(GroupXxNumYZ, FlatGroupNumZY, "FlatGroupID");

// The ultimate goal is a single unique thread ID for this AS thread.
// So take the flat group number, multiply it by the number of
// threads per group...
auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul(
FlatGroupID,
HlslOP->GetU32Const(DM.GetNumThreads(0) * DM.GetNumThreads(1) *
DM.GetNumThreads(2)),
"FlatGroupIDWithSpaceForThreadInGroupId");

auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc(
DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty(Ctx));
Constant *FlattenedThreadIdInGroupOpcode =
HlslOP->GetU32Const((unsigned)DXIL::OpCode::FlattenedThreadIdInGroup);
auto FlatThreadIdInGroup = B.CreateCall(FlattenedThreadIdInGroupFunc,
{FlattenedThreadIdInGroupOpcode},
"FlattenedThreadIdInGroup");

// ...and add the flat thread id:
auto *FlatId = B.CreateAdd(FlatGroupIDWithSpaceForThreadInGroupId,
FlatThreadIdInGroup, "FlatId");

AddValueToExpandedPayload(
HlslOP, B, NewStructAlloca,
expanded.ExpandedPayloadStructType->getStructNumElements() - 3,
FlatId);
AddValueToExpandedPayload(
HlslOP, B, NewStructAlloca,
expanded.ExpandedPayloadStructType->getStructNumElements() - 2,
DispatchMesh.get_threadGroupCountY());
AddValueToExpandedPayload(
HlslOP, B, NewStructAlloca,
expanded.ExpandedPayloadStructType->getStructNumElements() - 1,
DispatchMesh.get_threadGroupCountZ());

auto DispatchMeshFn = HlslOP->GetOpFunc(
DXIL::OpCode::DispatchMesh, expanded.ExpandedPayloadStructPtrType);
Constant *DispatchMeshOpcode =
HlslOP->GetU32Const((unsigned)DXIL::OpCode::DispatchMesh);
B.CreateCall(DispatchMeshFn,
{DispatchMeshOpcode, DispatchMesh.get_threadGroupCountX(),
DispatchMesh.get_threadGroupCountY(),
DispatchMesh.get_threadGroupCountZ(), NewStructAlloca});
I->removeFromParent();
delete &*I;
// Validation requires exactly one DispatchMesh in an AS, so we can exit
// after the first one:
DM.ReEmitDxilResources();
return true;
}
}
for (auto &Alloca : allocasOfPayloadType) {
OldStructAlloca = Alloca;
llvm::IRBuilder<> B(Alloca->getContext());
NewStructAlloca = B.CreateAlloca(expanded.ExpandedPayloadStructType,
HlslOP->GetU32Const(1), "NewPayload");
NewStructAlloca->setAlignment(Alloca->getAlignment());
NewStructAlloca->insertAfter(Alloca);

ReplaceAllUsesOfInstructionWithNewValueAndDeleteInstruction(
Alloca, NewStructAlloca, expanded.ExpandedPayloadStructType);
}

auto F = HlslOP->GetOpFunc(DXIL::OpCode::DispatchMesh,
expanded.ExpandedPayloadStructPtrType);
for (auto FI = F->user_begin(); FI != F->user_end();) {
auto *FunctionUser = *FI++;
auto *UserInstruction = llvm::cast<Instruction>(FunctionUser);
DxilInst_DispatchMesh DispatchMesh(UserInstruction);

llvm::IRBuilder<> B(UserInstruction);

Constant *Zero32Arg = HlslOP->GetU32Const(0);
Constant *One32Arg = HlslOP->GetU32Const(1);
Constant *Two32Arg = HlslOP->GetU32Const(2);

auto GroupIdFunc =
HlslOP->GetOpFunc(DXIL::OpCode::GroupId, Type::getInt32Ty(Ctx));
Constant *GroupIdOpcode =
HlslOP->GetU32Const((unsigned)DXIL::OpCode::GroupId);
auto *GroupIdX =
B.CreateCall(GroupIdFunc, {GroupIdOpcode, Zero32Arg}, "GroupIdX");
auto *GroupIdY =
B.CreateCall(GroupIdFunc, {GroupIdOpcode, One32Arg}, "GroupIdY");
auto *GroupIdZ =
B.CreateCall(GroupIdFunc, {GroupIdOpcode, Two32Arg}, "GroupIdZ");

// FlatGroupID = z + y*numZ + x*numY*numZ
// Where x,y,z are the group ID components, and numZ and numY are the
// corresponding AS group-count arguments to the DispatchMesh Direct3D API
auto *GroupYxNumZ = B.CreateMul(
GroupIdY, HlslOP->GetU32Const(m_DispatchArgumentZ), "GroupYxNumZ");
auto *FlatGroupNumZY = B.CreateAdd(GroupIdZ, GroupYxNumZ, "FlatGroupNumZY");
auto *GroupXxNumYZ = B.CreateMul(
GroupIdX,
HlslOP->GetU32Const(m_DispatchArgumentY * m_DispatchArgumentZ),
"GroupXxNumYZ");
auto *FlatGroupID =
B.CreateAdd(GroupXxNumYZ, FlatGroupNumZY, "FlatGroFlatGroupIDupNum");

// The ultimate goal is a single unique thread ID for this AS thread.
// So take the flat group number, multiply it by the number of
// threads per group...
auto *FlatGroupIDWithSpaceForThreadInGroupId = B.CreateMul(
FlatGroupID,
HlslOP->GetU32Const(DM.GetNumThreads(0) * DM.GetNumThreads(1) *
DM.GetNumThreads(2)),
"FlatGroupIDWithSpaceForThreadInGroupId");

auto *FlattenedThreadIdInGroupFunc = HlslOP->GetOpFunc(
DXIL::OpCode::FlattenedThreadIdInGroup, Type::getInt32Ty(Ctx));
Constant *FlattenedThreadIdInGroupOpcode =
HlslOP->GetU32Const((unsigned)DXIL::OpCode::FlattenedThreadIdInGroup);
auto FlatThreadIdInGroup = B.CreateCall(FlattenedThreadIdInGroupFunc,
{FlattenedThreadIdInGroupOpcode},
"FlattenedThreadIdInGroup");

// ...and add the flat thread id:
auto *FlatId = B.CreateAdd(FlatGroupIDWithSpaceForThreadInGroupId,
FlatThreadIdInGroup, "FlatId");

AddValueToExpandedPayload(HlslOP, B, expanded, NewStructAlloca,
OriginalPayloadStructType->getStructNumElements(),
FlatId);
AddValueToExpandedPayload(
HlslOP, B, expanded, NewStructAlloca,
OriginalPayloadStructType->getStructNumElements() + 1,
DispatchMesh.get_threadGroupCountY());
AddValueToExpandedPayload(
HlslOP, B, expanded, NewStructAlloca,
OriginalPayloadStructType->getStructNumElements() + 2,
DispatchMesh.get_threadGroupCountZ());
}

DM.ReEmitDxilResources();

return true;
return false;
}

char DxilPIXAddTidToAmplificationShaderPayload::ID = 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// RUN: %dxc -enable-16bit-types -Od -Emain -Tas_6_6 %s | %opt -S -hlsl-dxil-PIX-add-tid-to-as-payload,dispatchArgY=3,dispatchArgZ=7 | %FileCheck %s

// Check that the payload was piece-wise copied into a local copy from group-shared:
// There are 28 elements:

// CHECK: [[LOAD0:%.*]] = load [[TYPE0:.*]], [[TYPE0]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE0]] [[LOAD0]]
// CHECK: [[LOAD1:%.*]] = load [[TYPE1:.*]], [[TYPE1]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE1]] [[LOAD1]]
// CHECK: [[LOAD2:%.*]] = load [[TYPE2:.*]], [[TYPE2]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE2]] [[LOAD2]]
// CHECK: [[LOAD3:%.*]] = load [[TYPE3:.*]], [[TYPE3]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE3]] [[LOAD3]]
// CHECK: [[LOAD4:%.*]] = load [[TYPE4:.*]], [[TYPE4]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE4]] [[LOAD4]]
// CHECK: [[LOAD5:%.*]] = load [[TYPE5:.*]], [[TYPE5]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE5]] [[LOAD5]]
// CHECK: [[LOAD6:%.*]] = load [[TYPE6:.*]], [[TYPE6]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE6]] [[LOAD6]]
// CHECK: [[LOAD7:%.*]] = load [[TYPE7:.*]], [[TYPE7]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE7]] [[LOAD7]]
// CHECK: [[LOAD8:%.*]] = load [[TYPE8:.*]], [[TYPE8]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE8]] [[LOAD8]]
// CHECK: [[LOAD9:%.*]] = load [[TYPE9:.*]], [[TYPE9]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE9]] [[LOAD9]]

// CHECK: [[LOAD10:%.*]] = load [[TYPE10:.*]], [[TYPE10]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE10]] [[LOAD10]]
// CHECK: [[LOAD11:%.*]] = load [[TYPE11:.*]], [[TYPE11]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE11]] [[LOAD11]]
// CHECK: [[LOAD12:%.*]] = load [[TYPE12:.*]], [[TYPE12]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE12]] [[LOAD12]]
// CHECK: [[LOAD13:%.*]] = load [[TYPE13:.*]], [[TYPE13]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE13]] [[LOAD13]]
// CHECK: [[LOAD14:%.*]] = load [[TYPE14:.*]], [[TYPE14]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE14]] [[LOAD14]]
// CHECK: [[LOAD15:%.*]] = load [[TYPE15:.*]], [[TYPE15]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE15]] [[LOAD15]]
// CHECK: [[LOAD16:%.*]] = load [[TYPE16:.*]], [[TYPE16]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE16]] [[LOAD16]]
// CHECK: [[LOAD17:%.*]] = load [[TYPE17:.*]], [[TYPE17]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE17]] [[LOAD17]]
// CHECK: [[LOAD18:%.*]] = load [[TYPE18:.*]], [[TYPE18]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE18]] [[LOAD18]]
// CHECK: [[LOAD19:%.*]] = load [[TYPE19:.*]], [[TYPE19]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE19]] [[LOAD19]]

// CHECK: [[LOAD20:%.*]] = load [[TYPE20:.*]], [[TYPE20]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE20]] [[LOAD20]]
// CHECK: [[LOAD21:%.*]] = load [[TYPE21:.*]], [[TYPE21]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE21]] [[LOAD21]]
// CHECK: [[LOAD22:%.*]] = load [[TYPE22:.*]], [[TYPE22]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE22]] [[LOAD22]]
// CHECK: [[LOAD23:%.*]] = load [[TYPE23:.*]], [[TYPE23]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE23]] [[LOAD23]]
// CHECK: [[LOAD24:%.*]] = load [[TYPE24:.*]], [[TYPE24]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE24]] [[LOAD24]]
// CHECK: [[LOAD25:%.*]] = load [[TYPE25:.*]], [[TYPE25]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE25]] [[LOAD25]]
// CHECK: [[LOAD26:%.*]] = load [[TYPE26:.*]], [[TYPE26]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE26]] [[LOAD26]]
// CHECK: [[LOAD27:%.*]] = load [[TYPE27:.*]], [[TYPE27]] addrspace(3)* getelementptr inbounds
// CHECK:store volatile [[TYPE27]] [[LOAD27]]

// And no more:
// CHECK-NOT: [[LOAD28:%.*]] = load [[TYPE28:.*]], [[TYPE28]] addrspace(3)* getelementptr inbounds

struct Contained {
uint j;
float af[3];
};

struct Bigger {
half h;
Contained a[2];
};

struct MyPayload {
uint i;
Bigger big[3];
};

groupshared MyPayload payload;

[numthreads(1, 1, 1)] void main(uint gid
: SV_GroupID) {
DispatchMesh(1, 1, 1, payload);
}
21 changes: 21 additions & 0 deletions tools/clang/test/HLSLFileCheck/pix/DebugAsGroupSharedPayload.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %dxc -Od -Emain -Tas_6_6 %s | %opt -S -hlsl-dxil-PIX-add-tid-to-as-payload,dispatchArgY=3,dispatchArgZ=7 | %FileCheck %s

// Check that the payload was piece-wise copied into a local copy
// CHECK: [[LOADGEP:%.*]] = getelementptr %struct.MyPayload
// CHECK: [[LOAD:%.*]] = load i32, i32* [[LOADGEP]]
// CHECK: store volatile i32 [[LOAD]]

struct MyPayload
{
uint i;
};

groupshared MyPayload payload;

[numthreads(1, 1, 1)]
void main(uint gid : SV_GroupID)
{
MyPayload copy;
copy = payload;
DispatchMesh(1, 1, 1, copy);
}
Loading

0 comments on commit fd7e54b

Please sign in to comment.