-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[SPIR-V] Fix asdouble
issue in SPIRV codegen to correctly generate OpBitCast
instruction.
#161891
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…`OpBitCast` instruction.
@llvm/pr-subscribers-backend-spir-v Author: Lucie Choi (luciechoi) Changes: Generate an `OpBitcast` instruction for a pointer cast operation if the element type is different. The HLSL for the unit test is: StructuredBuffer<uint2> In : register(t0);
RWStructuredBuffer<double2> Out : register(u2);
[numthreads(1,1,1)]
void main() {
Out[0] = asdouble(In[0], In[1]);
} Resolves #153513 Full diff: https://github.com/llvm/llvm-project/pull/161891.diff 2 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 75055072573b3..ebd957c42762c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -188,8 +188,31 @@ class SPIRVLegalizePointerCast : public FunctionPass {
FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
FixedVectorType *DstType =
cast<FixedVectorType>(GR->findDeducedElementType(Dst));
- assert(DstType->getNumElements() >= SrcType->getNumElements());
+ auto dstNumElements = DstType->getNumElements();
+ auto srcNumElements = SrcType->getNumElements();
+
+ // if the element type differs, it is a bitcast.
+ if (DstType->getElementType() != SrcType->getElementType()) {
+ // Support bitcast between vectors of different sizes only if
+ // the total bitwidth is the same.
+ auto dstBitWidth =
+ DstType->getElementType()->getScalarSizeInBits() * dstNumElements;
+ auto srcBitWidth =
+ SrcType->getElementType()->getScalarSizeInBits() * srcNumElements;
+ assert(dstBitWidth == srcBitWidth &&
+ "Unsupported bitcast between vectors of different sizes.");
+
+ Src =
+ B.CreateIntrinsic(Intrinsic::spv_bitcast, {DstType, SrcType}, {Src});
+ buildAssignType(B, DstType, Src);
+ SrcType = DstType;
+
+ StoreInst *SI = B.CreateStore(Src, Dst);
+ SI->setAlignment(Alignment);
+ return SI;
+ }
+ assert(DstType->getNumElements() >= SrcType->getNumElements());
LoadInst *LI = B.CreateLoad(DstType, Dst);
LI->setAlignment(Alignment);
Value *OldValues = LI;
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
new file mode 100644
index 0000000000000..d9374833887f0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -0,0 +1,29 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s --match-full-lines
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#v2_uint:]] = OpTypeVector %[[#uint]] 2
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#v2_double:]] = OpTypeVector %[[#double]] 2
+; CHECK-DAG: %[[#v4_uint:]] = OpTypeVector %[[#uint]] 4
+@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
+@.str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
+
+
+define void @main() local_unnamed_addr #0 {
+entry:
+ %0 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
+ %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2)
+ %2 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %0, i32 0)
+ %3 = load <2 x i32>, ptr addrspace(11) %2, align 8
+ %4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %0, i32 1)
+ %5 = load <2 x i32>, ptr addrspace(11) %4, align 8
+; CHECK: %[[#tmp:]] = OpVectorShuffle %[[#v4_uint]] {{%[0-9]+}} {{%[0-9]+}} 0 2 1 3
+ %6 = shufflevector <2 x i32> %3, <2 x i32> %5, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+; CHECK: %[[#access:]] = OpAccessChain {{.*}}
+ %7 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %1, i32 0)
+; CHECK: %[[#bitcast:]] = OpBitcast %[[#v2_double]] %[[#tmp]]
+; CHECK: OpStore %[[#access]] %[[#bitcast]] Aligned 16
+ store <4 x i32> %6, ptr addrspace(11) %7, align 16
+ ret void
+}
|
%3 = load <2 x i32>, ptr addrspace(11) %2, align 8 | ||
%4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %0, i32 1) | ||
%5 = load <2 x i32>, ptr addrspace(11) %4, align 8 | ||
; CHECK: %[[#tmp:]] = OpVectorShuffle %[[#v4_uint]] {{%[0-9]+}} {{%[0-9]+}} 0 2 1 3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The final SPIR-V for this function had lots of verbose `OpCopyObject` instructions, so I left some variables unmatched/ignored.
%36 = OpFunction %2 DontInline %3 ; -- Begin function main
%1 = OpLabel
%37 = OpVariable %23 Function %32
%38 = OpVariable %22 Function %33
%39 = OpCopyObject %16 %34
%40 = OpCopyObject %13 %35
%41 = OpAccessChain %10 %39 %24 %24
%42 = OpLoad %9 %41 Aligned 8
%43 = OpAccessChain %10 %39 %24 %25
%44 = OpLoad %9 %43 Aligned 8
%45 = OpVectorShuffle %8 %42 %44 0 2 1 3
%46 = OpAccessChain %6 %40 %24 %24
%47 = OpBitcast %5 %45
%48 = OpCopyObject %13 %35
OpStore %46 %47 Aligned 16
OpReturn
OpFunctionEnd
; -- End function
Generate an `OpBitcast` instruction for a pointer cast operation if the element type is different. The HLSL for the unit test is given in the description above.
Resolves #153513