From aa1e9a0d2c0fdef0c0e8ddcedc115dba06bddb6e Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Tue, 7 Oct 2025 13:26:47 -0400 Subject: [PATCH 1/9] [SPIRV] Set hasSideEffects flag to false on type and constant opcodes This change sets the hasSideEffects flag to false on type and constant opcodes so that they can be considered trivially dead if their result is unused. This means that instruction selection will now be able to remove them. --- llvm/lib/Target/SPIRV/SPIRVInstrFormats.td | 5 + llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 179 +++++++++++------- .../SPIRV/hlsl-intrinsics/AddUint64.ll | 2 +- .../pointers/resource-vector-load-store.ll | 27 +-- 4 files changed, 130 insertions(+), 83 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td index 2fde2b0bc0b1f..f93240dc35993 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td @@ -25,6 +25,11 @@ class Op Opcode, dag outs, dag ins, string asmstr, list pattern = let Pattern = pattern; } +class PureOp Opcode, dag outs, dag ins, string asmstr, + list pattern = []> : Op { + let hasSideEffects = 0; +} + class UnknownOp pattern = []> : Op<0, outs, ins, asmstr, pattern> { let isPseudo = 1; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index a61351eba03f8..799a82c96b0f0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -163,52 +163,74 @@ def OpExecutionModeId: Op<331, (outs), (ins ID:$entry, ExecutionMode:$mode, vari // 3.42.6 Type-Declaration Instructions -def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">; -def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">; -def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness), - "$type = OpTypeInt $width $signedness">; -def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops), - "$type = OpTypeFloat $width">; -def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount), - "$type = OpTypeVector $compType $compCount">; -def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount), - "$type = OpTypeMatrix $colType $colCount">; -def OpTypeImage: Op<25, (outs TYPE:$res), (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth, - i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, ImageFormat:$imFormat, variable_ops), - "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS $sampled $imFormat">; -def OpTypeSampler: Op<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">; -def OpTypeSampledImage: Op<27, (outs TYPE:$res), (ins TYPE:$imageType), - "$res = OpTypeSampledImage $imageType">; -def OpTypeArray: Op<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length), - "$type = OpTypeArray $elementType $length">; -def OpTypeRuntimeArray: Op<29, (outs TYPE:$type), (ins TYPE:$elementType), - "$type = OpTypeRuntimeArray $elementType">; -def OpTypeStruct: Op<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">; -def OpTypeStructContinuedINTEL: Op<6090, (outs), (ins variable_ops), - "OpTypeStructContinuedINTEL">; -def OpTypeOpaque: Op<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops), - "$res = OpTypeOpaque $name">; -def OpTypePointer: Op<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type), - "$res = OpTypePointer $storage $type">; -def OpTypeFunction: Op<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops), - "$funcType = OpTypeFunction $returnType">; -def OpTypeEvent: Op<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">; -def OpTypeDeviceEvent: Op<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">; -def OpTypeReserveId: Op<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">; -def OpTypeQueue: Op<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">; -def OpTypePipe: Op<38, (outs TYPE:$res), (ins AccessQualifier:$a), "$res = OpTypePipe $a">; -def OpTypeForwardPointer: Op<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass), - "OpTypeForwardPointer $ptrType $storageClass">; -def OpTypePipeStorage: Op<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">; -def OpTypeNamedBarrier: Op<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">; -def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins), - "$res = OpTypeAccelerationStructureNV">; -def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res), - (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols), - "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">; -def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res), - (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use), - "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">; +def OpTypeVoid : PureOp<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">; +def OpTypeBool : PureOp<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">; +def OpTypeInt + : PureOp<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness), + "$type = OpTypeInt $width $signedness">; +def OpTypeFloat + : PureOp<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops), + "$type = OpTypeFloat $width">; +def OpTypeVector + : PureOp<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount), + "$type = OpTypeVector $compType $compCount">; +def OpTypeMatrix + : PureOp<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount), + "$type = OpTypeMatrix $colType $colCount">; +def OpTypeImage : PureOp<25, (outs TYPE:$res), + (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth, + i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, + ImageFormat:$imFormat, variable_ops), + "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS " + "$sampled $imFormat">; +def OpTypeSampler : PureOp<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">; +def OpTypeSampledImage : PureOp<27, (outs TYPE:$res), (ins TYPE:$imageType), + "$res = OpTypeSampledImage $imageType">; +def OpTypeArray + : PureOp<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length), + "$type = OpTypeArray $elementType $length">; +def OpTypeRuntimeArray : PureOp<29, (outs TYPE:$type), (ins TYPE:$elementType), + "$type = OpTypeRuntimeArray $elementType">; +def OpTypeStruct + : PureOp<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">; +def OpTypeStructContinuedINTEL + : PureOp<6090, (outs), (ins variable_ops), "OpTypeStructContinuedINTEL">; +def OpTypeOpaque + : PureOp<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops), + "$res = OpTypeOpaque $name">; +def OpTypePointer + : PureOp<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type), + "$res = OpTypePointer $storage $type">; +def OpTypeFunction + : PureOp<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops), + "$funcType = OpTypeFunction $returnType">; +def OpTypeEvent : PureOp<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">; +def OpTypeDeviceEvent + : PureOp<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">; +def OpTypeReserveId + : PureOp<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">; +def OpTypeQueue : PureOp<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">; +def OpTypePipe : PureOp<38, (outs TYPE:$res), (ins AccessQualifier:$a), + "$res = OpTypePipe $a">; +def OpTypeForwardPointer + : PureOp<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass), + "OpTypeForwardPointer $ptrType $storageClass">; +def OpTypePipeStorage + : PureOp<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">; +def OpTypeNamedBarrier + : PureOp<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">; +def OpTypeAccelerationStructureNV + : PureOp<5341, (outs TYPE:$res), (ins), + "$res = OpTypeAccelerationStructureNV">; +def OpTypeCooperativeMatrixNV + : PureOp<5358, (outs TYPE:$res), + (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols), + "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">; +def OpTypeCooperativeMatrixKHR + : PureOp<4456, (outs TYPE:$res), + (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use), + "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols " + "$use">; // 3.42.7 Constant-Creation Instructions @@ -222,31 +244,46 @@ defm OpConstant: IntFPImm<43, "OpConstant">; def ConstPseudoTrue: IntImmLeaf; def ConstPseudoFalse: IntImmLeaf; -def OpConstantTrue: Op<41, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantTrue $src_ty", - [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>; -def OpConstantFalse: Op<42, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantFalse $src_ty", - [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>; - -def OpConstantComposite: Op<44, (outs ID:$res), (ins TYPE:$type, variable_ops), - "$res = OpConstantComposite $type">; -def OpConstantCompositeContinuedINTEL: Op<6091, (outs), (ins variable_ops), - "OpConstantCompositeContinuedINTEL">; - -def OpConstantSampler: Op<45, (outs ID:$res), - (ins TYPE:$t, SamplerAddressingMode:$s, i32imm:$p, SamplerFilterMode:$f), - "$res = OpConstantSampler $t $s $p $f">; -def OpConstantNull: Op<46, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantNull $src_ty">; - -def OpSpecConstantTrue: Op<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">; -def OpSpecConstantFalse: Op<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">; -def OpSpecConstant: Op<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops), - "$res = OpSpecConstant $type $imm">; -def OpSpecConstantComposite: Op<51, (outs ID:$res), (ins TYPE:$type, variable_ops), - "$res = OpSpecConstantComposite $type">; -def OpSpecConstantCompositeContinuedINTEL: Op<6092, (outs), (ins variable_ops), - "OpSpecConstantCompositeContinuedINTEL">; -def OpSpecConstantOp: Op<52, (outs ID:$res), (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops), - "$res = OpSpecConstantOp $t $c $o">; +def OpConstantTrue + : PureOp<41, (outs iID:$dst), (ins TYPE:$src_ty), + "$dst = OpConstantTrue $src_ty", + [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>; +def OpConstantFalse + : PureOp<42, (outs iID:$dst), (ins TYPE:$src_ty), + "$dst = OpConstantFalse $src_ty", + [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>; + +def OpConstantComposite + : PureOp<44, (outs ID:$res), (ins TYPE:$type, variable_ops), + "$res = OpConstantComposite $type">; +def OpConstantCompositeContinuedINTEL + : PureOp<6091, (outs), (ins variable_ops), + "OpConstantCompositeContinuedINTEL">; + +def OpConstantSampler : PureOp<45, (outs ID:$res), + (ins TYPE:$t, SamplerAddressingMode:$s, + i32imm:$p, SamplerFilterMode:$f), + "$res = OpConstantSampler $t $s $p $f">; +def OpConstantNull : PureOp<46, (outs ID:$dst), (ins TYPE:$src_ty), + "$dst = OpConstantNull $src_ty">; + +def OpSpecConstantTrue + : PureOp<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">; +def OpSpecConstantFalse + : PureOp<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">; +def OpSpecConstant + : PureOp<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops), + "$res = OpSpecConstant $type $imm">; +def OpSpecConstantComposite + : PureOp<51, (outs ID:$res), (ins TYPE:$type, variable_ops), + "$res = OpSpecConstantComposite $type">; +def OpSpecConstantCompositeContinuedINTEL + : PureOp<6092, (outs), (ins variable_ops), + "OpSpecConstantCompositeContinuedINTEL">; +def OpSpecConstantOp + : PureOp<52, (outs ID:$res), + (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops), + "$res = OpSpecConstantOp $t $c $o">; // 3.42.8 Memory Instructions diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll index a97492b8453ea..a15d628cc3614 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll @@ -63,7 +63,7 @@ entry: ; CHECK: %[[#a_high:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#a]] %[[#undef_v4i32]] 1 3 ; CHECK: %[[#b_low:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#b]] %[[#undef_v4i32]] 0 2 ; CHECK: %[[#b_high:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#b]] %[[#undef_v4i32]] 1 3 -; CHECK: %[[#iaddcarry:]] = OpIAddCarry %[[#struct_v2i32_v2i32]] %[[#a_low]] %[[#vec2_int_32]] +; CHECK: %[[#iaddcarry:]] = OpIAddCarry %[[#struct_v2i32_v2i32]] %[[#a_low]] %[[#b_low]] ; CHECK: %[[#lowsum:]] = OpCompositeExtract %[[#vec2_int_32]] %[[#iaddcarry]] 0 ; CHECK: %[[#carry:]] = OpCompositeExtract %[[#vec2_int_32]] %[[#iaddcarry]] 1 ; CHECK: %[[#carry_ne0:]] = OpINotEqual %[[#vec2_bool]] %[[#carry]] %[[#const_v2i32_0_0]] diff --git a/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll b/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll index 7548f4757dbe6..6fc03a386d14d 100644 --- a/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll +++ b/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll @@ -4,18 +4,23 @@ @.str = private unnamed_addr constant [7 x i8] c"buffer\00", align 1 +; The i64 values in the extracts will be turned +; into immidiate values. There should be no 64-bit +; integers in the module. +; CHECK-NOT: OpTypeInt 64 0 + define void @main() "hlsl.shader"="pixel" { -; CHECK: %24 = OpFunction %2 None %3 ; -- Begin function main -; CHECK-NEXT: %1 = OpLabel -; CHECK-NEXT: %25 = OpVariable %13 Function %22 -; CHECK-NEXT: %26 = OpLoad %7 %23 -; CHECK-NEXT: %27 = OpImageRead %5 %26 %15 -; CHECK-NEXT: %28 = OpCompositeExtract %4 %27 0 -; CHECK-NEXT: %29 = OpCompositeExtract %4 %27 1 -; CHECK-NEXT: %30 = OpFAdd %4 %29 %28 -; CHECK-NEXT: %31 = OpCompositeInsert %5 %30 %27 0 -; CHECK-NEXT: %32 = OpLoad %7 %23 -; CHECK-NEXT: OpImageWrite %32 %15 %31 +; CHECK: %[[FUNC:[0-9]+]] = OpFunction %[[VOID:[0-9]+]] None %[[FNTYPE:[0-9]+]] ; -- Begin function main +; CHECK-NEXT: %[[LABEL:[0-9]+]] = OpLabel +; CHECK-NEXT: %[[VAR:[0-9]+]] = OpVariable %[[PTR_FN:[a-zA-Z0-9_]+]] Function %[[INIT:[a-zA-Z0-9_]+]] +; CHECK-NEXT: %[[LOAD1:[0-9]+]] = OpLoad %[[IMG_TYPE:[a-zA-Z0-9_]+]] %[[IMG_VAR:[a-zA-Z0-9_]+]] +; CHECK-NEXT: %[[READ:[0-9]+]] = OpImageRead %[[VEC4:[a-zA-Z0-9_]+]] %[[LOAD1]] %[[COORD:[a-zA-Z0-9_]+]] +; CHECK-NEXT: %[[EXTRACT1:[0-9]+]] = OpCompositeExtract %[[FLOAT:[a-zA-Z0-9_]+]] %[[READ]] 0 +; CHECK-NEXT: %[[EXTRACT2:[0-9]+]] = OpCompositeExtract %[[FLOAT]] %[[READ]] 1 +; CHECK-NEXT: %[[ADD:[0-9]+]] = OpFAdd %[[FLOAT]] %[[EXTRACT2]] %[[EXTRACT1]] +; CHECK-NEXT: %[[INSERT:[0-9]+]] = OpCompositeInsert %[[VEC4]] %[[ADD]] %[[READ]] 0 +; CHECK-NEXT: %[[LOAD2:[0-9]+]] = OpLoad %[[IMG_TYPE]] %[[IMG_VAR]] +; CHECK-NEXT: OpImageWrite %[[LOAD2]] %[[COORD]] %[[INSERT]] ; CHECK-NEXT: OpReturn ; CHECK-NEXT: OpFunctionEnd entry: From dd1d522bee0706efe72638dbfcbcb08b397c534c Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Tue, 7 Oct 2025 13:26:47 -0400 Subject: [PATCH 2/9] [SPIRV] Expand spv_bitcast intrinsic during instruction selection The spv_bitcast intrinsic is currently replaced by an OpBitcast during prelegalization. This will cause a problem when we need to legalize the OpBitcast. The legalizer assumes that instruction already lowered to a target specific opcode is legal. We cannot lower it to a G_BITCAST because the bitcasts sometimes the LLT type will be the same, causing an error in the verifier, even if the SPIR-V types will be different. This commit keeps the intrinsic around until instructoin selection. We can create rules to legalize a G_INTRINISIC* instruction, and it does not create problem for the verifier. --- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 8 +++ llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 56 ++++++++++--------- 2 files changed, 37 insertions(+), 27 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 021353ab716f7..ccc2c0fc467fb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -3119,6 +3119,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectInsertElt(ResVReg, ResType, I); case Intrinsic::spv_gep: return selectGEP(ResVReg, ResType, I); + case Intrinsic::spv_bitcast: { + Register OpReg = I.getOperand(2).getReg(); + SPIRVType *OpType = + OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr; + if (!GR.isBitcastCompatible(ResType, OpType)) + report_fatal_error("incompatible result and operand types in a bitcast"); + return selectOpWithSrcs(ResVReg, ResType, I, {OpReg}, SPIRV::OpBitcast); + } case Intrinsic::spv_unref_global: case Intrinsic::spv_init_global: { MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg()); diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index db6f2d61e8f29..43ded6a71dd6d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -192,31 +192,38 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB, .addUse(OpReg); } -// We do instruction selections early instead of calling MIB.buildBitcast() -// generating the general op code G_BITCAST. When MachineVerifier validates -// G_BITCAST we see a check of a kind: if Source Type is equal to Destination -// Type then report error "bitcast must change the type". This doesn't take into -// account the notion of a typed pointer that is important for SPIR-V where a -// user may and should use bitcast between pointers with different pointee types -// (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast). -// It's important for correct lowering in SPIR-V, because interpretation of the -// data type is not left to instructions that utilize the pointer, but encoded -// by the pointer declaration, and the SPIRV target can and must handle the -// declaration and use of pointers that specify the type of data they point to. -// It's not feasible to improve validation of G_BITCAST using just information -// provided by low level types of source and destination. Therefore we don't -// produce G_BITCAST as the general op code with semantics different from -// OpBitcast, but rather lower to OpBitcast immediately. As for now, the only -// difference would be that CombinerHelper couldn't transform known patterns -// around G_BUILD_VECTOR. See discussion -// in https://github.com/llvm/llvm-project/pull/110270 for even more context. -static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, - MachineIRBuilder MIB) { +// We lower G_BITCAST to OpBitcast here to avoid a MachineVerifier error. +// The verifier checks if the source and destination LLTs of a G_BITCAST are +// different, but this check is too strict for SPIR-V's typed pointers, which +// may have the same LLT but different SPIRVType (e.g. pointers to different +// pointee types). By lowering to OpBitcast here, we bypass the verifier's +// check. See discussion in https://github.com/llvm/llvm-project/pull/110270 +// for more context. +// +// We also handle the llvm.spv.bitcast intrinsic here. If the source and +// destination SPIR-V types are the same, we lower it to a COPY to enable +// further optimizations like copy propagation. +static void lowerBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, + MachineIRBuilder MIB) { SmallVector ToErase; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &MI : MBB) { + if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(2).getReg(); + SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg); + SPIRVType *SrcType = GR->getSPIRVTypeForVReg(SrcReg); + if (DstType && SrcType && DstType == SrcType) { + MIB.setInsertPt(*MI.getParent(), MI); + MIB.buildCopy(DstReg, SrcReg); + ToErase.push_back(&MI); + continue; + } + } + if (MI.getOpcode() != TargetOpcode::G_BITCAST) continue; + MIB.setInsertPt(*MI.getParent(), MI); buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(), MI.getOperand(1).getReg()); @@ -237,16 +244,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, SmallVector ToErase; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &MI : MBB) { - if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) && - !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) + if (!isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) continue; assert(MI.getOperand(2).isReg()); MIB.setInsertPt(*MI.getParent(), MI); ToErase.push_back(&MI); - if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { - MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg()); - continue; - } Register Def = MI.getOperand(0).getReg(); Register Source = MI.getOperand(2).getReg(); Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0); @@ -1089,7 +1091,7 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) { removeImplicitFallthroughs(MF, MIB); insertSpirvDecorations(MF, GR, MIB); insertInlineAsm(MF, GR, ST, MIB); - selectOpBitcasts(MF, GR, MIB); + lowerBitcasts(MF, GR, MIB); return true; } From cd5a1d26985f2a173a490f6445314fabd54f0fa1 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Fri, 24 Oct 2025 09:53:11 -0400 Subject: [PATCH 3/9] Remove unnecessary pointer checks --- llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index 43ded6a71dd6d..d538009f0ecbe 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -212,8 +212,13 @@ static void lowerBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, Register DstReg = MI.getOperand(0).getReg(); Register SrcReg = MI.getOperand(2).getReg(); SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg); + assert( + DstType && + "Expected destination SPIR-V type to have been assigned already."); SPIRVType *SrcType = GR->getSPIRVTypeForVReg(SrcReg); - if (DstType && SrcType && DstType == SrcType) { + assert(SrcType && + "Expected source SPIR-V type to have been assigned already."); + if (DstType == SrcType) { MIB.setInsertPt(*MI.getParent(), MI); MIB.buildCopy(DstReg, SrcReg); ToErase.push_back(&MI); From 9d342d353cc48ec8b73e84686abdf1be15c90a50 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Fri, 24 Oct 2025 10:03:31 -0400 Subject: [PATCH 4/9] [SPIRV] Fix vector bitcast check in LegalizePointerCast The previous check for vector bitcasts in `loadVectorFromVector` only compared the number of elements, which is insufficient when the element types differ. This can lead to incorrect assumptions about the validity of the cast. This commit replaces the element count check with a comparison of the total size of the vectors in bits. This ensures that the bitcast is only performed between vectors of the same size, preventing potential miscompilations. --- .../Target/SPIRV/SPIRVLegalizePointerCast.cpp | 9 +++++++- .../hlsl-resources/issue-146942-ptr-cast.ll | 4 +--- .../CodeGen/SPIRV/pointers/ptrcast-bitcast.ll | 22 +++++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp index 28a1690ef0be1..a692c24363310 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp @@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass { // Returns the loaded value. Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType, FixedVectorType *TargetType, Value *Source) { - assert(TargetType->getNumElements() <= SourceType->getNumElements()); LoadInst *NewLoad = B.CreateLoad(SourceType, Source); buildAssignType(B, SourceType, NewLoad); Value *AssignValue = NewLoad; if (TargetType->getElementType() != SourceType->getElementType()) { + const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout(); + [[maybe_unused]] TypeSize TargetTypeSize = + DL.getTypeSizeInBits(TargetType); + [[maybe_unused]] TypeSize SourceTypeSize = + DL.getTypeSizeInBits(SourceType); + assert(TargetTypeSize == SourceTypeSize); AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast, {TargetType, SourceType}, {NewLoad}); buildAssignType(B, TargetType, AssignValue); + return AssignValue; } + assert(TargetType->getNumElements() < SourceType->getNumElements()); SmallVector Mask(/* Size= */ TargetType->getNumElements()); for (unsigned I = 0; I < TargetType->getNumElements(); ++I) Mask[I] = I; diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll index ed67344842b11..4817e7450ac2e 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll @@ -16,7 +16,6 @@ define void @case1() local_unnamed_addr { ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16 ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]] - ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3 %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str) %2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.2) %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0) @@ -29,8 +28,7 @@ define void @case1() local_unnamed_addr { define void @case2() local_unnamed_addr { ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16 ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]] - ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3 - ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#VEC_SHUFFLE]] %[[#UNDEF_INT4]] 0 1 2 + ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#CAST_LOAD]] %[[#UNDEF_INT4]] 0 1 2 %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str) %2 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.3) %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0) diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll index 84913283f6868..a1ec2cd1cfdd2 100644 --- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll +++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll @@ -26,3 +26,25 @@ entry: store <4 x i32> %6, ptr addrspace(11) %7, align 16 ret void } + +; This tests a load from a pointer that has been bitcast between vector types +; which share the same total bit-width but have different numbers of elements. +; Tests that legalize-pointer-casts works correctly by moving the bitcast to +; the element that was loaded. + +define void @main2() local_unnamed_addr #0 { +entry: +; CHECK: %[[LOAD:[0-9]+]] = OpLoad %[[#v2_double]] {{.*}} +; CHECK: %[[BITCAST1:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD]] +; CHECK: %[[BITCAST2:[0-9]+]] = OpBitcast %[[#v2_double]] %[[BITCAST1]] +; CHECK: OpStore {{%[0-9]+}} %[[BITCAST2]] {{.*}} + + %0 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2) + %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 0) + %3 = load <4 x i32>, ptr addrspace(11) %2 + %4 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 1) + store <4 x i32> %3, ptr addrspace(11) %4 + ret void +} + +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } From c16bc1ef99215d81e79e76772f37f08369b04602 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Fri, 24 Oct 2025 13:21:24 -0400 Subject: [PATCH 5/9] [SPIRV] Use a worklist in the post-legalizer This commit refactors the SPIRV post-legalizer to use a worklist to process new instructions. Previously, the post-legalizer would iterate through all instructions and try to assign types. This could fail if a new instruction depended on another new instruction that had not been processed yet. The new implementation adds all new instructions that require a SPIR-V type to a worklist. It then iteratively processes the worklist until it is empty. This ensures that all dependencies are met before an instruction is processed. This change makes the post-legalizer more robust and fixes potential ordering issues with newly generated instructions. Existing tests cover existing functionality. More tests will be added as the legalizer is modified. Part of #153091 --- llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 388 ++++++++++++++++--- 1 file changed, 334 insertions(+), 54 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp index d17528dd882bf..d11168b70aea8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp @@ -17,7 +17,8 @@ #include "SPIRV.h" #include "SPIRVSubtarget.h" #include "SPIRVUtils.h" -#include "llvm/IR/Attributes.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "spirv-postlegalizer" @@ -45,6 +46,11 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, static bool mayBeInserted(unsigned Opcode) { switch (Opcode) { + case TargetOpcode::G_CONSTANT: + case TargetOpcode::G_UNMERGE_VALUES: + case TargetOpcode::G_EXTRACT_VECTOR_ELT: + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: case TargetOpcode::G_SMAX: case TargetOpcode::G_UMAX: case TargetOpcode::G_SMIN: @@ -53,69 +59,344 @@ static bool mayBeInserted(unsigned Opcode) { case TargetOpcode::G_FMINIMUM: case TargetOpcode::G_FMAXNUM: case TargetOpcode::G_FMAXIMUM: + case TargetOpcode::G_IMPLICIT_DEF: + case TargetOpcode::G_BUILD_VECTOR: + case TargetOpcode::G_ICMP: + case TargetOpcode::G_ANYEXT: return true; default: return isTypeFoldingSupported(Opcode); } } -static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, - MachineIRBuilder MIB) { +static SPIRVType *deduceTypeForGConstant(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { MachineRegisterInfo &MRI = MF.getRegInfo(); + const LLT &Ty = MRI.getType(ResVReg); + unsigned BitWidth = Ty.getScalarSizeInBits(); + return GR->getOrCreateSPIRVIntegerType(BitWidth, MIB); +} - for (MachineBasicBlock &MBB : MF) { - for (MachineInstr &I : MBB) { - const unsigned Opcode = I.getOpcode(); - if (Opcode == TargetOpcode::G_UNMERGE_VALUES) { - unsigned ArgI = I.getNumOperands() - 1; - Register SrcReg = I.getOperand(ArgI).isReg() - ? I.getOperand(ArgI).getReg() - : Register(0); - SPIRVType *DefType = - SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr; - if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector) - report_fatal_error( - "cannot select G_UNMERGE_VALUES with a non-vector argument"); - SPIRVType *ScalarType = - GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); - for (unsigned i = 0; i < I.getNumDefs(); ++i) { - Register ResVReg = I.getOperand(i).getReg(); - SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg); - if (!ResType) { - // There was no "assign type" actions, let's fix this now +static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg(); + if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) { + if (DefType->getOpcode() == SPIRV::OpTypeVector) { + SPIRVType *ScalarType = + GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); + for (unsigned i = 0; i < I->getNumDefs(); ++i) { + Register DefReg = I->getOperand(i).getReg(); + if (!GR->getSPIRVTypeForVReg(DefReg)) { + LLT DefLLT = MRI.getType(DefReg); + SPIRVType *ResType; + if (DefLLT.isVector()) { + const SPIRVInstrInfo *TII = + MF.getSubtarget().getInstrInfo(); + ResType = GR->getOrCreateSPIRVVectorType( + ScalarType, DefLLT.getNumElements(), *I, *TII); + } else { ResType = ScalarType; - setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true); } + setRegClassType(DefReg, ResType, GR, &MRI, MF); } - } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 && - I.getNumOperands() > 1 && I.getOperand(1).isReg()) { - // Legalizer may have added a new instructions and introduced new - // registers, we must decorate them as if they were introduced in a - // non-automatic way - Register ResVReg = I.getOperand(0).getReg(); - // Check if the register defined by the instruction is newly generated - // or already processed - // Check if we have type defined for operands of the new instruction - bool IsKnownReg = MRI.getRegClassOrNull(ResVReg); - SPIRVType *ResVType = GR->getSPIRVTypeForVReg( - IsKnownReg ? ResVReg : I.getOperand(1).getReg()); - if (!ResVType) - continue; - // Set type & class - if (!IsKnownReg) - setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true); - // If this is a simple operation that is to be reduced by TableGen - // definition we must apply some of pre-legalizer rules here - if (isTypeFoldingSupported(Opcode)) { - processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg)); - if (IsKnownReg && MRI.hasOneUse(ResVReg)) { - MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg); - if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE) - continue; - } - insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI); + } + return true; + } + } + return false; +} + +static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + Register VecReg = I->getOperand(1).getReg(); + if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) { + assert(VecType->getOpcode() == SPIRV::OpTypeVector); + return GR->getScalarOrVectorComponentType(VecType); + } + + // If not handled yet, then check if it is used in a G_BUILD_VECTOR. + // If so get the type from there. + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) { + Register BuildVecResReg = Use.getOperand(0).getReg(); + if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg)) + return GR->getScalarOrVectorComponentType(BuildVecType); + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + // First check if any of the operands have a type. + for (unsigned i = 1; i < I->getNumOperands(); ++i) { + if (SPIRVType *OpType = + GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) { + const LLT &ResLLT = MRI.getType(ResVReg); + return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(), + MIB, false); + } + } + // If that did not work, then check the uses. + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) { + Register ExtractResReg = Use.getOperand(0).getReg(); + if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) { + const LLT &ResLLT = MRI.getType(ResVReg); + return GR->getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), MIB, false); + } + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + const unsigned UseOpc = Use.getOpcode(); + assert(UseOpc == TargetOpcode::G_BUILD_VECTOR || + UseOpc == TargetOpcode::G_SHUFFLE_VECTOR); + // It's possible that the use instruction has not been processed yet. + // We should look at the operands of the use to determine the type. + for (unsigned i = 1; i < Use.getNumOperands(); ++i) { + if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg())) + return Type; + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + if (!isSpvIntrinsic(*I, Intrinsic::spv_bitcast)) + return nullptr; + + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + const unsigned UseOpc = Use.getOpcode(); + assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT || + UseOpc == TargetOpcode::G_SHUFFLE_VECTOR); + Register UseResultReg = Use.getOperand(0).getReg(); + if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) { + SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType); + const LLT &BitcastLLT = MRI.getType(ResVReg); + if (BitcastLLT.isVector()) + return GR->getOrCreateSPIRVVectorType( + ScalarType, BitcastLLT.getNumElements(), MIB, false); + return ScalarType; + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGAnyExt(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + // The result type of G_ANYEXT cannot be inferred from its operand. + // We use the result register's LLT to determine the correct integer type. + const LLT &ResLLT = MIB.getMRI()->getType(ResVReg); + if (!ResLLT.isScalar()) + return nullptr; + return GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB); +} + +static SPIRVType *deduceTypeForDefault(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR) { + if (I->getNumDefs() != 1 || I->getNumOperands() <= 1 || + !I->getOperand(1).isReg()) + return nullptr; + + SPIRVType *OpType = GR->getSPIRVTypeForVReg(I->getOperand(1).getReg()); + if (!OpType) + return nullptr; + return OpType; +} + +static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB) { + LLVM_DEBUG(dbgs() << "Processing instruction: " << *I); + MachineRegisterInfo &MRI = MF.getRegInfo(); + const unsigned Opcode = I->getOpcode(); + Register ResVReg = I->getOperand(0).getReg(); + SPIRVType *ResType = nullptr; + + switch (Opcode) { + case TargetOpcode::G_CONSTANT: { + ResType = deduceTypeForGConstant(I, MF, GR, MIB, ResVReg); + break; + } + case TargetOpcode::G_UNMERGE_VALUES: { + // This one is special as it defines multiple registers. + if (deduceAndAssignTypeForGUnmerge(I, MF, GR)) + return true; + break; + } + case TargetOpcode::G_EXTRACT_VECTOR_ELT: { + ResType = deduceTypeForGExtractVectorElt(I, MF, GR, ResVReg); + break; + } + case TargetOpcode::G_BUILD_VECTOR: { + ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg); + break; + } + case TargetOpcode::G_ANYEXT: { + ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg); + break; + } + case TargetOpcode::G_IMPLICIT_DEF: { + ResType = deduceTypeForGImplicitDef(I, MF, GR, ResVReg); + break; + } + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: { + ResType = deduceTypeForGIntrinsic(I, MF, GR, MIB, ResVReg); + break; + } + default: + ResType = deduceTypeForDefault(I, MF, GR); + break; + } + + if (ResType) { + LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n"); + GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF); + + if (!MRI.getRegClassOrNull(ResVReg)) { + LLVM_DEBUG(dbgs() << "Updating the register class.\n"); + setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true); + } + return true; + } + return false; +} + +static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR, + MachineRegisterInfo &MRI) { + LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: " + << I;); + if (I.getNumDefs() == 0) { + LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n"); + return false; + } + if (!mayBeInserted(I.getOpcode())) { + LLVM_DEBUG(dbgs() << "Instruction may not be inserted.\n"); + return false; + } + + Register ResultRegister = I.defs().begin()->getReg(); + if (GR->getSPIRVTypeForVReg(ResultRegister)) { + LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n"); + if (!MRI.getRegClassOrNull(ResultRegister)) { + LLVM_DEBUG(dbgs() << "Updating the register class.\n"); + setRegClassType(ResultRegister, GR->getSPIRVTypeForVReg(ResultRegister), + GR, &MRI, *GR->CurMF, true); + } + return false; + } + + return true; +} + +static void registerSpirvTypeForNewInstructions(MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder MIB) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + SmallVector Worklist; + for (MachineBasicBlock &MBB : MF) { + for (MachineInstr &I : MBB) { + if (requiresSpirvType(I, GR, MRI)) { + Worklist.push_back(&I); + } + } + } + + if (Worklist.empty()) { + LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n"); + return; + } + + LLVM_DEBUG(dbgs() << "Initial worklist:\n"; + for (auto *I : Worklist) { I->dump(); }); + + bool Changed = true; + while (Changed) { + Changed = false; + SmallVector NextWorklist; + + for (MachineInstr *I : Worklist) { + if (deduceAndAssignSpirvType(I, MF, GR, MIB)) { + Changed = true; + } else { + NextWorklist.push_back(I); + } + } + Worklist = NextWorklist; + LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n"); + } + + if (!Worklist.empty()) { + LLVM_DEBUG(dbgs() << "Remaining worklist:\n"; + for (auto *I : Worklist) { I->dump(); }); + assert(Worklist.empty() && "Worklist is not empty"); + } +} + +static void ensureAssignTypeForTypeFolding(MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder MIB) { + LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function " + << MF.getName() << "\n"); + MachineRegisterInfo &MRI = MF.getRegInfo(); + for (MachineBasicBlock &MBB : MF) { + for (MachineInstr &MI : MBB) { + if (!isTypeFoldingSupported(MI.getOpcode())) + continue; + if (MI.getNumOperands() == 1 || !MI.getOperand(1).isReg()) + continue; + + LLVM_DEBUG(dbgs() << "Processing instruction: " << MI); + + // Check uses of MI to see if it already has an use in SPIRV::ASSIGN_TYPE + bool HasAssignType = false; + Register ResultRegister = MI.defs().begin()->getReg(); + // All uses of Result register + for (MachineInstr &UseInstr : + MRI.use_nodbg_instructions(ResultRegister)) { + if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) { + HasAssignType = true; + LLVM_DEBUG(dbgs() << " Instruction already has an ASSIGN_TYPE use: " + << UseInstr); + break; } } + + if (!HasAssignType) { + Register ResultRegister = MI.defs().begin()->getReg(); + SPIRVType *ResultType = GR->getSPIRVTypeForVReg(ResultRegister); + LLVM_DEBUG( + dbgs() << " Adding ASSIGN_TYPE for ResultRegister: " + << printReg(ResultRegister, MRI.getTargetRegisterInfo()) + << " with type: " << *ResultType); + insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI); + } } } } @@ -156,9 +437,8 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) { SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); GR->setCurrentFunc(MF); MachineIRBuilder MIB(MF); - - processNewInstrs(MF, GR, MIB); - + registerSpirvTypeForNewInstructions(MF, GR, MIB); + ensureAssignTypeForTypeFolding(MF, GR, MIB); return true; } From f2f29a52e3c61d52dbcb2f4728318305026016b3 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Fri, 24 Oct 2025 13:21:24 -0400 Subject: [PATCH 6/9] [SPIRV] Use a worklist in the post-legalizer This commit refactors the SPIRV post-legalizer to use a worklist to process new instructions. Previously, the post-legalizer would iterate through all instructions and try to assign types. This could fail if a new instruction depended on another new instruction that had not been processed yet. The new implementation adds all new instructions that require a SPIR-V type to a worklist. It then iteratively processes the worklist until it is empty. This ensures that all dependencies are met before an instruction is processed. This change makes the post-legalizer more robust and fixes potential ordering issues with newly generated instructions. Existing tests cover existing functionality. More tests will be added as the legalizer is modified. Part of #153091 --- llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 412 ++++++++++++++++--- 1 file changed, 359 insertions(+), 53 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp index d17528dd882bf..b6c650c802247 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp @@ -17,7 +17,8 @@ #include "SPIRV.h" #include "SPIRVSubtarget.h" #include "SPIRVUtils.h" -#include "llvm/IR/Attributes.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "spirv-postlegalizer" @@ -45,6 +46,11 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, static bool mayBeInserted(unsigned Opcode) { switch (Opcode) { + case TargetOpcode::G_CONSTANT: + case TargetOpcode::G_UNMERGE_VALUES: + case TargetOpcode::G_EXTRACT_VECTOR_ELT: + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: case TargetOpcode::G_SMAX: case TargetOpcode::G_UMAX: case TargetOpcode::G_SMIN: @@ -53,73 +59,372 @@ static bool mayBeInserted(unsigned Opcode) { case TargetOpcode::G_FMINIMUM: case TargetOpcode::G_FMAXNUM: case TargetOpcode::G_FMAXIMUM: + case TargetOpcode::G_IMPLICIT_DEF: + case TargetOpcode::G_BUILD_VECTOR: + case TargetOpcode::G_ICMP: + case TargetOpcode::G_ANYEXT: return true; default: return isTypeFoldingSupported(Opcode); } } -static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, - MachineIRBuilder MIB) { +static SPIRVType *deduceTypeForGConstant(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { MachineRegisterInfo &MRI = MF.getRegInfo(); + const LLT &Ty = MRI.getType(ResVReg); + unsigned BitWidth = Ty.getScalarSizeInBits(); + return GR->getOrCreateSPIRVIntegerType(BitWidth, MIB); +} - for (MachineBasicBlock &MBB : MF) { - for (MachineInstr &I : MBB) { - const unsigned Opcode = I.getOpcode(); - if (Opcode == TargetOpcode::G_UNMERGE_VALUES) { - unsigned ArgI = I.getNumOperands() - 1; - Register SrcReg = I.getOperand(ArgI).isReg() - ? I.getOperand(ArgI).getReg() - : Register(0); - SPIRVType *DefType = - SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr; - if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector) - report_fatal_error( - "cannot select G_UNMERGE_VALUES with a non-vector argument"); - SPIRVType *ScalarType = - GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); - for (unsigned i = 0; i < I.getNumDefs(); ++i) { - Register ResVReg = I.getOperand(i).getReg(); - SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg); - if (!ResType) { - // There was no "assign type" actions, let's fix this now +static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg(); + if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) { + if (DefType->getOpcode() == SPIRV::OpTypeVector) { + SPIRVType *ScalarType = + GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); + for (unsigned i = 0; i < I->getNumDefs(); ++i) { + Register DefReg = I->getOperand(i).getReg(); + if (!GR->getSPIRVTypeForVReg(DefReg)) { + LLT DefLLT = MRI.getType(DefReg); + SPIRVType *ResType; + if (DefLLT.isVector()) { + const SPIRVInstrInfo *TII = + MF.getSubtarget().getInstrInfo(); + ResType = GR->getOrCreateSPIRVVectorType( + ScalarType, DefLLT.getNumElements(), *I, *TII); + } else { ResType = ScalarType; - setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true); } + setRegClassType(DefReg, ResType, GR, &MRI, MF); } - } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 && - I.getNumOperands() > 1 && I.getOperand(1).isReg()) { - // Legalizer may have added a new instructions and introduced new - // registers, we must decorate them as if they were introduced in a - // non-automatic way - Register ResVReg = I.getOperand(0).getReg(); - // Check if the register defined by the instruction is newly generated - // or already processed - // Check if we have type defined for operands of the new instruction - bool IsKnownReg = MRI.getRegClassOrNull(ResVReg); - SPIRVType *ResVType = GR->getSPIRVTypeForVReg( - IsKnownReg ? ResVReg : I.getOperand(1).getReg()); - if (!ResVType) - continue; - // Set type & class - if (!IsKnownReg) - setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true); - // If this is a simple operation that is to be reduced by TableGen - // definition we must apply some of pre-legalizer rules here - if (isTypeFoldingSupported(Opcode)) { - processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg)); - if (IsKnownReg && MRI.hasOneUse(ResVReg)) { - MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg); - if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE) - continue; - } - insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI); + } + return true; + } + } + return false; +} + +static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + Register VecReg = I->getOperand(1).getReg(); + if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) { + assert(VecType->getOpcode() == SPIRV::OpTypeVector); + return GR->getScalarOrVectorComponentType(VecType); + } + + // If not handled yet, then check if it is used in a G_BUILD_VECTOR. + // If so get the type from there. + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) { + Register BuildVecResReg = Use.getOperand(0).getReg(); + if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg)) + return GR->getScalarOrVectorComponentType(BuildVecType); + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + // First check if any of the operands have a type. + for (unsigned i = 1; i < I->getNumOperands(); ++i) { + if (SPIRVType *OpType = + GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) { + const LLT &ResLLT = MRI.getType(ResVReg); + return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(), + MIB, false); + } + } + // If that did not work, then check the uses. + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) { + Register ExtractResReg = Use.getOperand(0).getReg(); + if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) { + const LLT &ResLLT = MRI.getType(ResVReg); + return GR->getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), MIB, false); + } + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + const unsigned UseOpc = Use.getOpcode(); + assert(UseOpc == TargetOpcode::G_BUILD_VECTOR || + UseOpc == TargetOpcode::G_SHUFFLE_VECTOR); + // It's possible that the use instruction has not been processed yet. + // We should look at the operands of the use to determine the type. + for (unsigned i = 1; i < Use.getNumOperands(); ++i) { + if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg())) + return Type; + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + if (!isSpvIntrinsic(*I, Intrinsic::spv_bitcast)) + return nullptr; + + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + const unsigned UseOpc = Use.getOpcode(); + assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT || + UseOpc == TargetOpcode::G_SHUFFLE_VECTOR); + Register UseResultReg = Use.getOperand(0).getReg(); + if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) { + SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType); + const LLT &BitcastLLT = MRI.getType(ResVReg); + if (BitcastLLT.isVector()) + return GR->getOrCreateSPIRVVectorType( + ScalarType, BitcastLLT.getNumElements(), MIB, false); + return ScalarType; + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGAnyExt(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + // The result type of G_ANYEXT cannot be inferred from its operand. + // We use the result register's LLT to determine the correct integer type. + const LLT &ResLLT = MIB.getMRI()->getType(ResVReg); + if (!ResLLT.isScalar()) + return nullptr; + return GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB); +} + +static SPIRVType *deduceTypeForDefault(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR) { + if (I->getNumDefs() != 1 || I->getNumOperands() <= 1 || + !I->getOperand(1).isReg()) + return nullptr; + + SPIRVType *OpType = GR->getSPIRVTypeForVReg(I->getOperand(1).getReg()); + if (!OpType) + return nullptr; + return OpType; +} + +static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB) { + LLVM_DEBUG(dbgs() << "Processing instruction: " << *I); + MachineRegisterInfo &MRI = MF.getRegInfo(); + const unsigned Opcode = I->getOpcode(); + Register ResVReg = I->getOperand(0).getReg(); + SPIRVType *ResType = nullptr; + + switch (Opcode) { + case TargetOpcode::G_CONSTANT: { + ResType = deduceTypeForGConstant(I, MF, GR, MIB, ResVReg); + break; + } + case TargetOpcode::G_UNMERGE_VALUES: { + // This one is special as it defines multiple registers. + if (deduceAndAssignTypeForGUnmerge(I, MF, GR)) + return true; + break; + } + case TargetOpcode::G_EXTRACT_VECTOR_ELT: { + ResType = deduceTypeForGExtractVectorElt(I, MF, GR, ResVReg); + break; + } + case TargetOpcode::G_BUILD_VECTOR: { + ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg); + break; + } + case TargetOpcode::G_ANYEXT: { + ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg); + break; + } + case TargetOpcode::G_IMPLICIT_DEF: { + ResType = deduceTypeForGImplicitDef(I, MF, GR, ResVReg); + break; + } + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: { + ResType = deduceTypeForGIntrinsic(I, MF, GR, MIB, ResVReg); + break; + } + default: + ResType = deduceTypeForDefault(I, MF, GR); + break; + } + + if (ResType) { + LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n"); + GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF); + + if (!MRI.getRegClassOrNull(ResVReg)) { + LLVM_DEBUG(dbgs() << "Updating the register class.\n"); + setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true); + } + return true; + } + return false; +} + +static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR, + MachineRegisterInfo &MRI) { + LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: " + << I;); + if (I.getNumDefs() == 0) { + LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n"); + return false; + } + if (!mayBeInserted(I.getOpcode())) { + LLVM_DEBUG(dbgs() << "Instruction may not be inserted.\n"); + return false; + } + + Register ResultRegister = I.defs().begin()->getReg(); + if (GR->getSPIRVTypeForVReg(ResultRegister)) { + LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n"); + if (!MRI.getRegClassOrNull(ResultRegister)) { + LLVM_DEBUG(dbgs() << "Updating the register class.\n"); + setRegClassType(ResultRegister, GR->getSPIRVTypeForVReg(ResultRegister), + GR, &MRI, *GR->CurMF, true); + } + return false; + } + + return true; +} + +static void registerSpirvTypeForNewInstructions(MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder MIB) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + SmallVector Worklist; + for (MachineBasicBlock &MBB : MF) { + for (MachineInstr &I : MBB) { + if (requiresSpirvType(I, GR, MRI)) { + Worklist.push_back(&I); + } + } + } + + if (Worklist.empty()) { + LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n"); + return; + } + + LLVM_DEBUG(dbgs() << "Initial worklist:\n"; + for (auto *I : Worklist) { I->dump(); }); + + bool Changed = true; + while (Changed) { + Changed = false; + SmallVector NextWorklist; + + for (MachineInstr *I : Worklist) { + if (deduceAndAssignSpirvType(I, MF, GR, MIB)) { + Changed = true; + } else { + NextWorklist.push_back(I); + } + } + Worklist = NextWorklist; + LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n"); + } + + if (!Worklist.empty()) { + LLVM_DEBUG(dbgs() << "Remaining worklist:\n"; + for (auto *I : Worklist) { I->dump(); }); + assert(Worklist.empty() && "Worklist is not empty"); + } +} + +static void ensureAssignTypeForTypeFolding(MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder MIB) { + LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function " + << MF.getName() << "\n"); + MachineRegisterInfo &MRI = MF.getRegInfo(); + for (MachineBasicBlock &MBB : MF) { + for (MachineInstr &MI : MBB) { + if (!isTypeFoldingSupported(MI.getOpcode())) + continue; + if (MI.getNumOperands() == 1 || !MI.getOperand(1).isReg()) + continue; + + LLVM_DEBUG(dbgs() << "Processing instruction: " << MI); + + // Check uses of MI to see if it already has an use in SPIRV::ASSIGN_TYPE + bool HasAssignType = false; + Register ResultRegister = MI.defs().begin()->getReg(); + // All uses of Result register + for (MachineInstr &UseInstr : + MRI.use_nodbg_instructions(ResultRegister)) { + if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) { + HasAssignType = true; + LLVM_DEBUG(dbgs() << " Instruction already has an ASSIGN_TYPE use: " + << UseInstr); + break; } } + + if (!HasAssignType) { + Register ResultRegister = MI.defs().begin()->getReg(); + SPIRVType *ResultType = GR->getSPIRVTypeForVReg(ResultRegister); + LLVM_DEBUG( + dbgs() << " Adding ASSIGN_TYPE for ResultRegister: " + << printReg(ResultRegister, MRI.getTargetRegisterInfo()) + << " with type: " << *ResultType); + insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI); + } } } } +static void lowerExtractVectorElements(MachineFunction &MF) { + SmallVector ExtractInstrs; + for (MachineBasicBlock &MBB : MF) { + for (MachineInstr &MI : MBB) { + if (MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) { + ExtractInstrs.push_back(&MI); + } + } + } + + for (MachineInstr *MI : ExtractInstrs) { + MachineIRBuilder MIB(*MI); + Register Dst = MI->getOperand(0).getReg(); + Register Vec = MI->getOperand(1).getReg(); + Register Idx = MI->getOperand(2).getReg(); + + auto Intr = MIB.buildIntrinsic(Intrinsic::spv_extractelt, Dst, true, false); + Intr.addUse(Vec); + Intr.addUse(Idx); + + MI->eraseFromParent(); + } +} + // Do a preorder traversal of the CFG starting from the BB |Start|. // point. Calls |op| on each basic block encountered during the traversal. void visit(MachineFunction &MF, MachineBasicBlock &Start, @@ -156,8 +461,9 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) { SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); GR->setCurrentFunc(MF); MachineIRBuilder MIB(MF); - - processNewInstrs(MF, GR, MIB); + registerSpirvTypeForNewInstructions(MF, GR, MIB); + ensureAssignTypeForTypeFolding(MF, GR, MIB); + lowerExtractVectorElements(MF); return true; } From c248de25c59edffdfa2a11e05d78610b10488306 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Tue, 28 Oct 2025 13:01:07 -0400 Subject: [PATCH 7/9] Set insertion point in MIB. --- llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 22 +++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp index b6c650c802247..69de5a6360c66 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp @@ -138,11 +138,14 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I, MachineIRBuilder &MIB, Register ResVReg) { MachineRegisterInfo &MRI = MF.getRegInfo(); + LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Processing " << *I << "\n"); // First check if any of the operands have a type. for (unsigned i = 1; i < I->getNumOperands(); ++i) { if (SPIRVType *OpType = GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) { const LLT &ResLLT = MRI.getType(ResVReg); + LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found operand type " + << *OpType << ", returning vector type\n"); return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(), MIB, false); } @@ -153,11 +156,14 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I, Register ExtractResReg = Use.getOperand(0).getReg(); if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) { const LLT &ResLLT = MRI.getType(ResVReg); + LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found use type " + << *ScalarType << ", returning vector type\n"); return GR->getOrCreateSPIRVVectorType( ScalarType, ResLLT.getNumElements(), MIB, false); } } } + LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Could not deduce type\n"); return nullptr; } @@ -191,7 +197,8 @@ static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF, for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { const unsigned UseOpc = Use.getOpcode(); assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT || - UseOpc == TargetOpcode::G_SHUFFLE_VECTOR); + UseOpc == TargetOpcode::G_SHUFFLE_VECTOR || + UseOpc == TargetOpcode::G_BUILD_VECTOR); Register UseResultReg = Use.getOperand(0).getReg(); if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) { SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType); @@ -316,8 +323,7 @@ static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR, } static void registerSpirvTypeForNewInstructions(MachineFunction &MF, - SPIRVGlobalRegistry *GR, - MachineIRBuilder MIB) { + SPIRVGlobalRegistry *GR) { MachineRegisterInfo &MRI = MF.getRegInfo(); SmallVector Worklist; for (MachineBasicBlock &MBB : MF) { @@ -342,6 +348,7 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF, SmallVector NextWorklist; for (MachineInstr *I : Worklist) { + MachineIRBuilder MIB(*I); if (deduceAndAssignSpirvType(I, MF, GR, MIB)) { Changed = true; } else { @@ -360,8 +367,7 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF, } static void ensureAssignTypeForTypeFolding(MachineFunction &MF, - SPIRVGlobalRegistry *GR, - MachineIRBuilder MIB) { + SPIRVGlobalRegistry *GR) { LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function " << MF.getName() << "\n"); MachineRegisterInfo &MRI = MF.getRegInfo(); @@ -395,6 +401,7 @@ static void ensureAssignTypeForTypeFolding(MachineFunction &MF, dbgs() << " Adding ASSIGN_TYPE for ResultRegister: " << printReg(ResultRegister, MRI.getTargetRegisterInfo()) << " with type: " << *ResultType); + MachineIRBuilder MIB(MI); insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI); } } @@ -460,9 +467,8 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) { const SPIRVSubtarget &ST = MF.getSubtarget(); SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); GR->setCurrentFunc(MF); - MachineIRBuilder MIB(MF); - registerSpirvTypeForNewInstructions(MF, GR, MIB); - ensureAssignTypeForTypeFolding(MF, GR, MIB); + registerSpirvTypeForNewInstructions(MF, GR); + ensureAssignTypeForTypeFolding(MF, GR); lowerExtractVectorElements(MF); return true; From 547301b644c8e17cd4da1bf36e39101cbea30d1d Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Wed, 29 Oct 2025 09:36:48 -0400 Subject: [PATCH 8/9] Handle vector shuffle. --- llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 138 +++++++++++++++---- 1 file changed, 108 insertions(+), 30 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp index 69de5a6360c66..644e010d8cf94 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp @@ -62,6 +62,7 @@ static bool mayBeInserted(unsigned Opcode) { case TargetOpcode::G_IMPLICIT_DEF: case TargetOpcode::G_BUILD_VECTOR: case TargetOpcode::G_ICMP: + case TargetOpcode::G_SHUFFLE_VECTOR: case TargetOpcode::G_ANYEXT: return true; default: @@ -83,30 +84,47 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF, SPIRVGlobalRegistry *GR) { MachineRegisterInfo &MRI = MF.getRegInfo(); Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg(); + SPIRVType *ScalarType = nullptr; if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) { - if (DefType->getOpcode() == SPIRV::OpTypeVector) { - SPIRVType *ScalarType = - GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); - for (unsigned i = 0; i < I->getNumDefs(); ++i) { - Register DefReg = I->getOperand(i).getReg(); - if (!GR->getSPIRVTypeForVReg(DefReg)) { - LLT DefLLT = MRI.getType(DefReg); - SPIRVType *ResType; - if (DefLLT.isVector()) { - const SPIRVInstrInfo *TII = - MF.getSubtarget().getInstrInfo(); - ResType = GR->getOrCreateSPIRVVectorType( - ScalarType, DefLLT.getNumElements(), *I, *TII); - } else { - ResType = ScalarType; - } - setRegClassType(DefReg, ResType, GR, &MRI, MF); + assert(DefType->getOpcode() == SPIRV::OpTypeVector); + ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); + } + + if (!ScalarType) { + // If we could not deduce the type from the source, try to deduce it from + // the uses of the results. + for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) { + for (const auto &Use : + MRI.use_nodbg_instructions(I->getOperand(i).getReg())) { + assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR && + "Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR"); + if (auto *VecType = + GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) { + ScalarType = GR->getScalarOrVectorComponentType(VecType); + break; } } - return true; } } - return false; + + if (!ScalarType) + return false; + + for (unsigned i = 0; i < I->getNumDefs(); ++i) { + Register DefReg = I->getOperand(i).getReg(); + if (GR->getSPIRVTypeForVReg(DefReg)) + continue; + + LLT DefLLT = MRI.getType(DefReg); + SPIRVType *ResType = + DefLLT.isVector() + ? GR->getOrCreateSPIRVVectorType( + ScalarType, DefLLT.getNumElements(), *I, + *MF.getSubtarget().getInstrInfo()) + : ScalarType; + setRegClassType(DefReg, ResType, GR, &MRI, MF); + } + return true; } static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I, @@ -167,20 +185,61 @@ static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I, return nullptr; } +static SPIRVType *deduceTypeForGShuffleVector(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + const LLT &ResLLT = MRI.getType(ResVReg); + assert(ResLLT.isVector() && "G_SHUFFLE_VECTOR result must be a vector"); + + // The result element type should be the same as the input vector element + // types. + for (unsigned i = 1; i <= 2; ++i) { + Register VReg = I->getOperand(i).getReg(); + if (auto *VType = GR->getSPIRVTypeForVReg(VReg)) { + if (auto *ScalarType = GR->getScalarOrVectorComponentType(VType)) + return GR->getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), MIB, false); + } + } + return nullptr; +} + static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I, MachineFunction &MF, SPIRVGlobalRegistry *GR, Register ResVReg) { MachineRegisterInfo &MRI = MF.getRegInfo(); - for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { - const unsigned UseOpc = Use.getOpcode(); - assert(UseOpc == TargetOpcode::G_BUILD_VECTOR || - UseOpc == TargetOpcode::G_SHUFFLE_VECTOR); - // It's possible that the use instruction has not been processed yet. - // We should look at the operands of the use to determine the type. - for (unsigned i = 1; i < Use.getNumOperands(); ++i) { - if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg())) - return Type; + for (const MachineInstr &Use : MRI.use_nodbg_instructions(ResVReg)) { + SPIRVType *ScalarType = nullptr; + switch (Use.getOpcode()) { + case TargetOpcode::G_BUILD_VECTOR: + case TargetOpcode::G_UNMERGE_VALUES: + // It's possible that the use instruction has not been processed yet. + // We should look at the operands of the use to determine the type. + for (unsigned i = 1; i < Use.getNumOperands(); ++i) { + if (SPIRVType *OpType = + GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg())) + ScalarType = GR->getScalarOrVectorComponentType(OpType); + } + break; + case TargetOpcode::G_SHUFFLE_VECTOR: + // For G_SHUFFLE_VECTOR, only look at the vector input operands. + if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(1).getReg())) + ScalarType = GR->getScalarOrVectorComponentType(Type); + if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(2).getReg())) + ScalarType = GR->getScalarOrVectorComponentType(Type); + break; + } + if (ScalarType) { + const LLT &ResLLT = MRI.getType(ResVReg); + if (!ResLLT.isVector()) + return ScalarType; + return GR->getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), *I, + *MF.getSubtarget().getInstrInfo()); } } return nullptr; @@ -198,7 +257,8 @@ static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF, const unsigned UseOpc = Use.getOpcode(); assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT || UseOpc == TargetOpcode::G_SHUFFLE_VECTOR || - UseOpc == TargetOpcode::G_BUILD_VECTOR); + UseOpc == TargetOpcode::G_BUILD_VECTOR || + UseOpc == TargetOpcode::G_UNMERGE_VALUES); Register UseResultReg = Use.getOperand(0).getReg(); if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) { SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType); @@ -264,6 +324,10 @@ static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF, ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg); break; } + case TargetOpcode::G_SHUFFLE_VECTOR: { + ResType = deduceTypeForGShuffleVector(I, MF, GR, MIB, ResVReg); + break; + } case TargetOpcode::G_ANYEXT: { ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg); break; @@ -362,7 +426,21 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF, if (!Worklist.empty()) { LLVM_DEBUG(dbgs() << "Remaining worklist:\n"; for (auto *I : Worklist) { I->dump(); }); - assert(Worklist.empty() && "Worklist is not empty"); + for (auto *I : Worklist) { + MachineIRBuilder MIB(*I); + Register ResVReg = I->getOperand(0).getReg(); + const LLT &ResLLT = MRI.getType(ResVReg); + SPIRVType *ResType = nullptr; + if (ResLLT.isVector()) { + SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType( + ResLLT.getElementType().getSizeInBits(), MIB); + ResType = GR->getOrCreateSPIRVVectorType( + CompType, ResLLT.getNumElements(), MIB, false); + } else { + ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB); + } + setRegClassType(ResVReg, ResType, GR, &MRI, MF, true); + } } } From 0e00c3a495d6dc5ac0786a649d402aa7d98e99a5 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Fri, 24 Oct 2025 15:16:19 -0400 Subject: [PATCH 9/9] [SPIRV] Add legalization for long vectors This patch introduces the necessary infrastructure to legalize vector operations on vectors that are longer than what the SPIR-V target supports. For instance, shaders only support vectors up to 4 elements. The legalization is done by splitting the long vectors into smaller vectors of a legal size. Specifically, this patch does the following: - Introduces `vectorElementCountIsGreaterThan` and `vectorElementCountIsLessThanOrEqualTo` legality predicates. - Adds legalization rules for `G_SHUFFLE_VECTOR`, `G_EXTRACT_VECTOR_ELT`, `G_BUILD_VECTOR`, `G_CONCAT_VECTORS`, `G_SPLAT_VECTOR`, and `G_UNMERGE_VALUES`. - Handles `G_BITCAST` of long vectors by converting them to `@llvm.spv.bitcast` intrinsics which are then legalized. - Updates `selectUnmergeValues` to handle extraction of both scalars and vectors from a larger vector, using `OpCompositeExtract` and `OpVectorShuffle` respectively. - Adds a test case to verify the legalization of a bitcast between a `<8 x i32>` and `<4 x f64>`, which is a pattern generated by HLSL's `asuint` and `asdouble` intrinsics. Fixes: https://github.com/llvm/llvm-project/pull/165444 --- .../llvm/CodeGen/GlobalISel/LegalizerInfo.h | 10 ++ .../CodeGen/GlobalISel/LegalityPredicates.cpp | 20 +++ .../Target/SPIRV/SPIRVInstructionSelector.cpp | 50 ++++-- llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 165 ++++++++++++++++-- llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h | 4 + .../SPIRV/legalization/load-store-global.ll | 84 +++++++++ .../vector-legalization-kernel.ll | 70 ++++++++ 7 files changed, 379 insertions(+), 24 deletions(-) create mode 100644 llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll create mode 100644 llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h index 51318c9c2736d..a8748965eb2e8 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h @@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size); LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx, unsigned Size); +/// True iff the specified type index is a vector with an element size +/// that's greater than the given size. +LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx, + unsigned Size); + +/// True iff the specified type index is a vector with an element size +/// that's less than or equal to the given size. +LLVM_ABI LegalityPredicate +vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size); + /// True iff the specified type index is a scalar or a vector with an element /// type that's wider than the given size. LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx, diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp index 30c2d089c3121..5e7cd5fd5d9ad 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp @@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx, }; } +LegalityPredicate +LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx, + unsigned Size) { + + return [=](const LegalityQuery &Query) { + const LLT QueryTy = Query.Types[TypeIdx]; + return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size; + }; +} + +LegalityPredicate +LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, + unsigned Size) { + + return [=](const LegalityQuery &Query) { + const LLT QueryTy = Query.Types[TypeIdx]; + return QueryTy.isFixedVector() && QueryTy.getNumElements() <= Size; + }; +} + LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx, unsigned Size) { return [=](const LegalityQuery &Query) { diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index ccc2c0fc467fb..f9e6a224f581b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1526,33 +1526,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const { unsigned ArgI = I.getNumOperands() - 1; Register SrcReg = I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0); - SPIRVType *DefType = + SPIRVType *SrcType = SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr; - if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector) + if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector) report_fatal_error( "cannot select G_UNMERGE_VALUES with a non-vector argument"); SPIRVType *ScalarType = - GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); + GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg()); MachineBasicBlock &BB = *I.getParent(); bool Res = false; + unsigned CurrentIndex = 0; for (unsigned i = 0; i < I.getNumDefs(); ++i) { Register ResVReg = I.getOperand(i).getReg(); SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg); if (!ResType) { - // There was no "assign type" actions, let's fix this now - ResType = ScalarType; + LLT ResLLT = MRI->getType(ResVReg); + assert(ResLLT.isValid()); + if (ResLLT.isVector()) { + ResType = GR.getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), I, TII); + } else { + ResType = ScalarType; + } MRI->setRegClass(ResVReg, GR.getRegClass(ResType)); - MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType))); GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF); } - auto MIB = - BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(SrcReg) - .addImm(static_cast(i)); - Res |= MIB.constrainAllUses(TII, TRI, RBI); + + if (ResType->getOpcode() == SPIRV::OpTypeVector) { + Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII); + auto MIB = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .addUse(UndefReg); + unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType); + for (unsigned j = 0; j < NumElements; ++j) { + MIB.addImm(CurrentIndex + j); + } + CurrentIndex += NumElements; + Res |= MIB.constrainAllUses(TII, TRI, RBI); + } else { + auto MIB = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .addImm(CurrentIndex); + CurrentIndex++; + Res |= MIB.constrainAllUses(TII, TRI, RBI); + } } return Res; } diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index 53074ea3b2597..a7d6bde3c5f1a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -14,16 +14,22 @@ #include "SPIRV.h" #include "SPIRVGlobalRegistry.h" #include "SPIRVSubtarget.h" +#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" using namespace llvm; using namespace llvm::LegalizeActions; using namespace llvm::LegalityPredicates; +#define DEBUG_TYPE "spirv-legalizer" + LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) { return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) { const LLT Ty = Query.Types[TypeIdx]; @@ -101,6 +107,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; + auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, + v3s1, v3s8, v3s16, v3s32, v3s64, + v4s1, v4s8, v4s16, v4s32, v4s64}; + auto allScalarsAndVectors = { s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, @@ -126,6 +136,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12}; + auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors; + bool IsExtendedInts = ST.canUseExtension( SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) || @@ -148,14 +160,65 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { return IsExtendedInts && Ty.isValid(); }; - for (auto Opc : getTypeFoldingSupportedOpcodes()) - getActionDefinitionsBuilder(Opc).custom(); + uint32_t MaxVectorSize = ST.isShader() ? 4 : 16; - getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); + for (auto Opc : getTypeFoldingSupportedOpcodes()) { + if (Opc != G_EXTRACT_VECTOR_ELT) + getActionDefinitionsBuilder(Opc).custom(); + } - // TODO: add proper rules for vectors legalization. - getActionDefinitionsBuilder( - {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR}) + getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom(); + + getActionDefinitionsBuilder(G_SHUFFLE_VECTOR) + .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes) + .moreElementsToNextPow2(0) + .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize)) + .moreElementsToNextPow2(1) + .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize)) + .alwaysLegal(); + + getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT) + .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize)) + .moreElementsToNextPow2(1) + .fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 1, ElementCount::getFixed(MaxVectorSize))) + .custom(); + + // Illegal G_UNMERGE_VALUES instructions should be handled + // during the combine phase. + getActionDefinitionsBuilder(G_BUILD_VECTOR) + .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize)) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(MaxVectorSize))); + + // When entering the legalizer, there should be no G_BITCAST instructions. + // They should all be calls to the `spv_bitcast` intrinsic. The call to + // the intrinsic will be converted to a G_BITCAST during legalization if + // the vectors are not legal. After using the rules to legalize a G_BITCAST, + // we turn it back into a call to the intrinsic with a custom ruel to avoid + // potential machines verifier failures. + getActionDefinitionsBuilder(G_BITCAST) + .moreElementsToNextPow2(0) + .moreElementsToNextPow2(1) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(MaxVectorSize))) + .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize)) + .custom(); + + getActionDefinitionsBuilder(G_CONCAT_VECTORS) + .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize)) + .moreElementsToNextPow2(0) + .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize)) + .alwaysLegal(); + + getActionDefinitionsBuilder(G_SPLAT_VECTOR) + .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize)) + .moreElementsToNextPow2(0) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementSizeTo(0, MaxVectorSize)) .alwaysLegal(); // Vector Reduction Operations @@ -164,7 +227,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN, G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM, G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR}) - .legalFor(allVectors) + .legalFor(allowedVectorTypes) .scalarize(1) .lower(); @@ -172,9 +235,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { .scalarize(2) .lower(); - // Merge/Unmerge - // TODO: add proper legalization rules. - getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal(); + // Illegal G_UNMERGE_VALUES instructions should be handled + // during the combine phase. + getActionDefinitionsBuilder(G_UNMERGE_VALUES) + .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize)); getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs))); @@ -228,7 +292,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { all(typeInSet(0, allPtrsScalarsAndVectors), typeInSet(1, allPtrsScalarsAndVectors))); - getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal(); + getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) + .legalFor({s1}) + .legalFor(allFloatAndIntScalarsAndPtrs) + .legalFor(allowedVectorTypes) + .moreElementsToNextPow2(0) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(MaxVectorSize))); getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal(); @@ -287,6 +358,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { // Pointer-handling. getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); + getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs); + // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32}); @@ -374,6 +447,12 @@ bool SPIRVLegalizerInfo::legalizeCustom( default: // TODO: implement legalization for other opcodes. return true; + case TargetOpcode::G_BITCAST: + return legalizeBitcast(Helper, MI); + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: + return legalizeIntrinsic(Helper, MI); + case TargetOpcode::G_IS_FPCLASS: return legalizeIsFPClass(Helper, MI, LocObserver); case TargetOpcode::G_ICMP: { @@ -400,6 +479,70 @@ bool SPIRVLegalizerInfo::legalizeCustom( } } +bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, + MachineInstr &MI) const { + LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI); + + MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; + MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); + const SPIRVSubtarget &ST = MI.getMF()->getSubtarget(); + + auto IntrinsicID = cast(MI).getIntrinsicID(); + if (IntrinsicID == Intrinsic::spv_bitcast) { + LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n"); + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(2).getReg(); + LLT DstTy = MRI.getType(DstReg); + LLT SrcTy = MRI.getType(SrcReg); + + int32_t MaxVectorSize = ST.isShader() ? 4 : 16; + + bool DstNeedsLegalization = false; + bool SrcNeedsLegalization = false; + + if (DstTy.isVector()) { + if (DstTy.getNumElements() > 4 && + !isPowerOf2_32(DstTy.getNumElements())) { + DstNeedsLegalization = true; + } + + if (DstTy.getNumElements() > MaxVectorSize) { + DstNeedsLegalization = true; + } + } + + if (SrcTy.isVector()) { + if (SrcTy.getNumElements() > 4 && + !isPowerOf2_32(SrcTy.getNumElements())) { + SrcNeedsLegalization = true; + } + + if (SrcTy.getNumElements() > MaxVectorSize) { + SrcNeedsLegalization = true; + } + } + + if (DstNeedsLegalization || SrcNeedsLegalization) { + LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n"); + MIRBuilder.buildBitcast(DstReg, SrcReg); + MI.eraseFromParent(); + } + return true; + } + return true; +} + +bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper, + MachineInstr &MI) const { + MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(1).getReg(); + SmallVector DstRegs = {DstReg}; + MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg); + MI.eraseFromParent(); + return true; +} + // Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted // to ensure that all instructions created during the lowering have SPIR-V types // assigned to them. diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h index eeefa4239c778..86e7e711caa60 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h @@ -29,11 +29,15 @@ class SPIRVLegalizerInfo : public LegalizerInfo { public: bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override; + bool legalizeIntrinsic(LegalizerHelper &Helper, + MachineInstr &MI) const override; + SPIRVLegalizerInfo(const SPIRVSubtarget &ST); private: bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const; + bool legalizeBitcast(LegalizerHelper &Helper, MachineInstr &MI) const; }; } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H diff --git a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll new file mode 100644 index 0000000000000..fbfec1b3ee7cf --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll @@ -0,0 +1,84 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion" +; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4 +; CHECK-DAG: %[[#double:]] = OpTypeFloat 64 +; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4 +; CHECK-DAG: %[[#v2i32:]] = OpTypeVector %[[#int]] 2 +; CHECK-DAG: %[[#ptr_private_v4i32:]] = OpTypePointer Private %[[#v4i32]] +; CHECK-DAG: %[[#ptr_private_v4f64:]] = OpTypePointer Private %[[#v4f64]] +; CHECK-DAG: %[[#global_double:]] = OpVariable %[[#ptr_private_v4f64]] Private + +@G_16 = internal addrspace(10) global [16 x i32] zeroinitializer +@G_4_double = internal addrspace(10) global <4 x double> zeroinitializer +@G_4_int = internal addrspace(10) global <4 x i32> zeroinitializer + + +; This is the way matrices will be represented in HLSL. The memory type will be +; an array, but it will be loaded as a vector. +; TODO: Legalization for loads and stores of long vectors is not implemented yet. │ +;define spir_func void @test_load_store_global() { │ +;entry: │ +; %0 = load <16 x i32>, ptr addrspace(10) @G_16, align 64 │ +; store <16 x i32> %0, ptr addrspace(10) @G_16, align 64 │ +; ret void │ +;} + +; This is the code pattern that can be generated from the `asuint` and `asdouble` +; HLSL intrinsics. + +; TODO: This cods not the best because instruction selection is not folding an +; extract from other intstruction. That needs to be handled. +define spir_func void @test_int32_double_conversion() { +; CHECK: %[[#test_int32_double_conversion]] = OpFunction +entry: + ; CHECK: %[[#LOAD:]] = OpLoad %[[#v4f64]] %[[#global_double]] + ; CHECK: %[[#VEC_SHUF1:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 1 + ; CHECK: %[[#VEC_SHUF2:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 2 3 + ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v4i32]] %[[#VEC_SHUF1]] + ; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#v4i32]] %[[#VEC_SHUF2]] + ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 0 + ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 2 + ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 0 + ; CHECK: %[[#EXTRACT4:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 2 + ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4i32]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %[[#EXTRACT4]] + ; CHECK: %[[#EXTRACT5:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 1 + ; CHECK: %[[#EXTRACT6:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 3 + ; CHECK: %[[#EXTRACT7:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 1 + ; CHECK: %[[#EXTRACT8:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 3 + ; CHECK: %[[#CONSTRUCT2:]] = OpCompositeConstruct %[[#v4i32]] %[[#EXTRACT5]] %[[#EXTRACT6]] %[[#EXTRACT7]] %[[#EXTRACT8]] + ; CHECK: %[[#EXTRACT9:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 0 + ; CHECK: %[[#EXTRACT10:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 0 + ; CHECK: %[[#EXTRACT11:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 1 + ; CHECK: %[[#EXTRACT12:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 1 + ; CHECK: %[[#EXTRACT13:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 2 + ; CHECK: %[[#EXTRACT14:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 2 + ; CHECK: %[[#EXTRACT15:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 3 + ; CHECK: %[[#EXTRACT16:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 3 + ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT9]] %[[#EXTRACT10]] + ; CHECK: %[[#CONSTRUCT4:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT11]] %[[#EXTRACT12]] + ; CHECK: %[[#CONSTRUCT5:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT13]] %[[#EXTRACT14]] + ; CHECK: %[[#CONSTRUCT6:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT15]] %[[#EXTRACT16]] + ; CHECK: %[[#BITCAST3:]] = OpBitcast %[[#double]] %[[#CONSTRUCT3]] + ; CHECK: %[[#BITCAST4:]] = OpBitcast %[[#double]] %[[#CONSTRUCT4]] + ; CHECK: %[[#BITCAST5:]] = OpBitcast %[[#double]] %[[#CONSTRUCT5]] + ; CHECK: %[[#BITCAST6:]] = OpBitcast %[[#double]] %[[#CONSTRUCT6]] + ; CHECK: %[[#CONSTRUCT7:]] = OpCompositeConstruct %[[#v4f64]] %[[#BITCAST3]] %[[#BITCAST4]] %[[#BITCAST5]] %[[#BITCAST6]] + ; CHECK: OpStore %[[#global_double]] %[[#CONSTRUCT7]] Aligned 32 + + %0 = load <8 x i32>, ptr addrspace(10) @G_4_double + %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> + store <8 x i32> %3, ptr addrspace(10) @G_4_double + ret void +} + +; Add a main function to make it a valid module for spirv-val +define void @main() #1 { + ret void +} + +attributes #1 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll new file mode 100644 index 0000000000000..3e39cb78800ee --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll @@ -0,0 +1,70 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion" +; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#v8i32:]] = OpTypeVector %[[#int]] 8 +; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4 +; CHECK-DAG: %[[#ptr_func_v8i32:]] = OpTypePointer Function %[[#v8i32]] + +; CHECK-DAG: OpName %[[#test_v3f64_conversion:]] "test_v3f64_conversion" +; CHECK-DAG: %[[#double:]] = OpTypeFloat 64 +; CHECK-DAG: %[[#v3f64:]] = OpTypeVector %[[#double]] 3 +; CHECK-DAG: %[[#ptr_func_v3f64:]] = OpTypePointer Function %[[#v3f64]] +; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4 + +define spir_kernel void @test_int32_double_conversion(ptr %G_vec) { +; CHECK: %[[#test_int32_double_conversion]] = OpFunction +; CHECK: %[[#param:]] = OpFunctionParameter %[[#ptr_func_v8i32]] +entry: + ; CHECK: %[[#LOAD:]] = OpLoad %[[#v8i32]] %[[#param]] + ; CHECK: %[[#SHUF1:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 2 4 6 + ; CHECK: %[[#SHUF2:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 1 3 5 7 + ; CHECK: %[[#SHUF3:]] = OpVectorShuffle %[[#v8i32]] %[[#SHUF1]] %[[#SHUF2]] 0 4 1 5 2 6 3 7 + ; CHECK: OpStore %[[#param]] %[[#SHUF3]] + + %0 = load <8 x i32>, ptr %G_vec + %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> + store <8 x i32> %3, ptr %G_vec + ret void +} + +define spir_kernel void @test_v3f64_conversion(ptr %G_vec) { +; CHECK: %[[#test_v3f64_conversion:]] = OpFunction +; CHECK: %[[#param_v3f64:]] = OpFunctionParameter %[[#ptr_func_v3f64]] +entry: + ; CHECK: %[[#LOAD:]] = OpLoad %[[#v3f64]] %[[#param_v3f64]] + ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 0 + ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 1 + ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 2 + ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4f64]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %{{[a-zA-Z0-9_]+}} + ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v8i32]] %[[#CONSTRUCT1]] + ; CHECK: %[[#SHUFFLE1:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 0 2 4 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF + ; CHECK: %[[#EXTRACT4:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE1]] 0 + ; CHECK: %[[#EXTRACT5:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE1]] 1 + ; CHECK: %[[#EXTRACT6:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE1]] 2 + ; CHECK: %[[#SHUFFLE2:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 1 3 5 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF + ; CHECK: %[[#EXTRACT7:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE2]] 0 + ; CHECK: %[[#EXTRACT8:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE2]] 1 + ; CHECK: %[[#EXTRACT9:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE2]] 2 + ; CHECK: %[[#CONSTRUCT2:]] = OpCompositeConstruct %[[#v8i32]] %[[#EXTRACT4]] %[[#EXTRACT5]] %[[#EXTRACT6]] %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} + ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v8i32]] %[[#EXTRACT7]] %[[#EXTRACT8]] %[[#EXTRACT9]] %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} + ; CHECK: %[[#SHUFFLE2:]] = OpVectorShuffle %[[#v8i32]] %[[#CONSTRUCT2]] %[[#CONSTRUCT3]] 0 8 1 9 2 10 0xFFFFFFFF 0xFFFFFFFF + ; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#v4f64]] %[[#SHUFFLE2]] + ; CHECK: %[[#EXTRACT10:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 0 + ; CHECK: %[[#EXTRACT11:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 1 + ; CHECK: %[[#EXTRACT12:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 2 + ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v3f64]] %[[#EXTRACT10]] %[[#EXTRACT11]] %[[#EXTRACT12]] + ; CHECK: OpStore %[[#param_v3f64]] %[[#CONSTRUCT3]] + %0 = load <3 x double>, ptr %G_vec + %1 = bitcast <3 x double> %0 to <6 x i32> + %2 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> + %3 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> + %4 = shufflevector <3 x i32> %2, <3 x i32> %3, <6 x i32> + %5 = bitcast <6 x i32> %4 to <3 x double> + store <3 x double> %5, ptr %G_vec + ret void +} +