From 5470a3b1b4757fb48f353520b7b4486832c4004e Mon Sep 17 00:00:00 2001 From: luciechoi Date: Fri, 3 Oct 2025 18:22:40 +0000 Subject: [PATCH] [SPIR-V] Fix `asdouble` issue in SPIRV codegen to correctly generate `OpBitCast` instruction. --- .../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 25 ++++++++++++++++- .../CodeGen/SPIRV/pointers/ptrcast-bitcast.ll | 28 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) create mode 100644 llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp index 75055072573b3..ebd957c42762c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp @@ -188,8 +188,31 @@ class SPIRVLegalizePointerCast : public FunctionPass { FixedVectorType *SrcType = cast(Src->getType()); FixedVectorType *DstType = cast(GR->findDeducedElementType(Dst)); - assert(DstType->getNumElements() >= SrcType->getNumElements()); + auto dstNumElements = DstType->getNumElements(); + auto srcNumElements = SrcType->getNumElements(); + + // if the element type differs, it is a bitcast. + if (DstType->getElementType() != SrcType->getElementType()) { + // Support bitcast between vectors of different sizes only if + // the total bitwidth is the same. + auto dstBitWidth = + DstType->getElementType()->getScalarSizeInBits() * dstNumElements; + auto srcBitWidth = + SrcType->getElementType()->getScalarSizeInBits() * srcNumElements; + assert(dstBitWidth == srcBitWidth && + "Unsupported bitcast between vectors of different sizes."); + + Src = + B.CreateIntrinsic(Intrinsic::spv_bitcast, {DstType, SrcType}, {Src}); + buildAssignType(B, DstType, Src); + SrcType = DstType; + + StoreInst *SI = B.CreateStore(Src, Dst); + SI->setAlignment(Alignment); + return SI; + } + assert(DstType->getNumElements() >= SrcType->getNumElements()); LoadInst *LI = B.CreateLoad(DstType, Dst); LI->setAlignment(Alignment); Value *OldValues = LI; diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll new file mode 100644 index 0000000000000..84913283f6868 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll @@ -0,0 +1,28 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - | FileCheck %s --match-full-lines +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#v2_uint:]] = OpTypeVector %[[#uint]] 2 +; CHECK-DAG: %[[#double:]] = OpTypeFloat 64 +; CHECK-DAG: %[[#v2_double:]] = OpTypeVector %[[#double]] 2 +; CHECK-DAG: %[[#v4_uint:]] = OpTypeVector %[[#uint]] 4 +@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1 +@.str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1 + +define void @main() local_unnamed_addr #0 { +entry: + %0 = tail call target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2i32_12_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str) + %1 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2) + %2 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %0, i32 0) + %3 = load <2 x i32>, ptr addrspace(11) %2, align 8 + %4 = tail call noundef align 8 dereferenceable(8) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2i32_12_0t(target("spirv.VulkanBuffer", [0 x <2 x i32>], 12, 0) %0, i32 1) + %5 = load <2 x i32>, ptr addrspace(11) %4, align 8 +; CHECK: %[[#tmp:]] = OpVectorShuffle %[[#v4_uint]] {{%[0-9]+}} {{%[0-9]+}} 0 2 1 3 + %6 = shufflevector <2 x i32> %3, <2 x i32> %5, <4 x i32> +; CHECK: %[[#access:]] = OpAccessChain {{.*}} + %7 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %1, i32 0) +; CHECK: %[[#bitcast:]] = OpBitcast %[[#v2_double]] %[[#tmp]] +; CHECK: OpStore %[[#access]] %[[#bitcast]] Aligned 16 + store <4 x i32> %6, ptr addrspace(11) %7, align 16 + ret void +}