Skip to content

Conversation

luciechoi
Copy link
Contributor

Generate OpBitCast instruction for pointer cast operation if the element type is different.

The HLSL for the unit test is

StructuredBuffer<uint2> In : register(t0);

RWStructuredBuffer<double2> Out : register(u2);


[numthreads(1,1,1)]
void main() {
  Out[0] = asdouble(In[0], In[1]);
}

Resolves #153513

@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Lucie Choi (luciechoi)

Changes

Generate OpBitCast instruction for pointer cast operation if the element type is different.

The HLSL for the unit test is

StructuredBuffer&lt;uint2&gt; In : register(t0);

RWStructuredBuffer&lt;double2&gt; Out : register(u2);


[numthreads(1,1,1)]
void main() {
  Out[0] = asdouble(In[0], In[1]);
}

Resolves #153513


Full diff: https://github.com/llvm/llvm-project/pull/161891.diff

2 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp (+24-1)
  • (added) llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll (+29)
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 75055072573b3..ebd957c42762c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -188,8 +188,31 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     FixedVectorType *SrcType = cast<FixedVectorType>(Src->getType());
     FixedVectorType *DstType =
         cast<FixedVectorType>(GR->findDeducedElementType(Dst));
-    assert(DstType->getNumElements() >= SrcType->getNumElements());
+    auto dstNumElements = DstType->getNumElements();
+    auto srcNumElements = SrcType->getNumElements();
+
+    // if the element type differs, it is a bitcast.
+    if (DstType->getElementType() != SrcType->getElementType()) {
+      // Support bitcast between vectors of different sizes only if
+      // the total bitwidth is the same.
+      auto dstBitWidth =
+          DstType->getElementType()->getScalarSizeInBits() * dstNumElements;
+      auto srcBitWidth =
+          SrcType->getElementType()->getScalarSizeInBits() * srcNumElements;
+      assert(dstBitWidth == srcBitWidth &&
+             "Unsupported bitcast between vectors of different sizes.");
+
+      Src =
+          B.CreateIntrinsic(Intrinsic::spv_bitcast, {DstType, SrcType}, {Src});
+      buildAssignType(B, DstType, Src);
+      SrcType = DstType;
+
+      StoreInst *SI = B.CreateStore(Src, Dst);
+      SI->setAlignment(Alignment);
+      return SI;
+    }
 
+    assert(DstType->getNumElements() >= SrcType->getNumElements());
     LoadInst *LI = B.CreateLoad(DstType, Dst);
     LI->setAlignment(Alignment);
     Value *OldValues = LI;
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
new file mode 100644
index 0000000000000..d9374833887f0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll
@@ -0,0 +1,29 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s --match-full-lines
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG:        %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:     %[[#v2_uint:]] = OpTypeVector %[[#uint]] 2
+; CHECK-DAG:      %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG:   %[[#v2_double:]] = OpTypeVector %[[#double]] 2
+; CHECK-DAG:     %[[#v4_uint:]] = OpTypeVector %[[#uint]] 4
+@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
+@.str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
+
+
+define void @main() local_unnamed_addr #0 {
+entry:
+  %0 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
+  %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2)
+  %2 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %0, i32 0)
+  %3 = load <2 x i32>, ptr addrspace(11) %2, align 8
+  %4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %0, i32 1)
+  %5 = load <2 x i32>, ptr addrspace(11) %4, align 8
+; CHECK: %[[#tmp:]] = OpVectorShuffle %[[#v4_uint]] {{%[0-9]+}} {{%[0-9]+}} 0 2 1 3
+  %6 = shufflevector <2 x i32> %3, <2 x i32> %5, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+; CHECK: %[[#access:]] = OpAccessChain {{.*}}
+  %7 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %1, i32 0)
+; CHECK: %[[#bitcast:]] = OpBitcast %[[#v2_double]] %[[#tmp]]
+; CHECK: OpStore %[[#access]] %[[#bitcast]] Aligned 16
+  store <4 x i32> %6, ptr addrspace(11) %7, align 16
+  ret void
+}

%3 = load <2 x i32>, ptr addrspace(11) %2, align 8
%4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %0, i32 1)
%5 = load <2 x i32>, ptr addrspace(11) %4, align 8
; CHECK: %[[#tmp:]] = OpVectorShuffle %[[#v4_uint]] {{%[0-9]+}} {{%[0-9]+}} 0 2 1 3
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The final SPIRV for this function had lots of verbose OpCopyObject, so I left some variables unmatched/ignored.

        %36 = OpFunction %2 DontInline %3       ; -- Begin function main
        %1 = OpLabel
        %37 = OpVariable %23 Function %32
        %38 = OpVariable %22 Function %33
        %39 = OpCopyObject %16 %34
        %40 = OpCopyObject %13 %35
        %41 = OpAccessChain %10 %39 %24 %24
        %42 = OpLoad %9 %41 Aligned 8
        %43 = OpAccessChain %10 %39 %24 %25
        %44 = OpLoad %9 %43 Aligned 8
        %45 = OpVectorShuffle %8 %42 %44 0 2 1 3
        %46 = OpAccessChain %6 %40 %24 %24
        %47 = OpBitcast %5 %45
        %48 = OpCopyObject %13 %35
        OpStore %46 %47 Aligned 16
        OpReturn
        OpFunctionEnd
                                        ; -- End function

@luciechoi
Copy link
Contributor Author

@s-perron

@s-perron s-perron merged commit b0ad9c2 into llvm:main Oct 3, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[HLSL][SPIR-V] Hitting assert when compiling program with asdouble in HLSL

3 participants