diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h index 51318c9c2736d..9324bab3fe656 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h @@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size); LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx, unsigned Size); +/// True iff the specified type index is a vector with a number of elements +/// that's greater than the given size. +LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx, + unsigned Size); + +/// True iff the specified type index is a vector with a number of elements +/// that's less than or equal to the given size. +LLVM_ABI LegalityPredicate +vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size); + /// True iff the specified type index is a scalar or a vector with an element /// type that's wider than the given size. LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx, diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp index 30c2d089c3121..5e7cd5fd5d9ad 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp @@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx, }; } +LegalityPredicate +LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx, + unsigned Size) { + + return [=](const LegalityQuery &Query) { + const LLT QueryTy = Query.Types[TypeIdx]; + return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size; + }; +} + +LegalityPredicate +LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, + unsigned Size) { + + return [=](const LegalityQuery &Query) { + const LLT QueryTy = Query.Types[TypeIdx]; + return QueryTy.isFixedVector() && QueryTy.getNumElements() <= Size; + }; +} + LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx, unsigned Size) { return [=](const LegalityQuery &Query) { diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 2c27289e759eb..a2e29366dc4cc 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1781,33 +1781,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const { unsigned ArgI = I.getNumOperands() - 1; Register SrcReg = I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0); - SPIRVType *DefType = + SPIRVType *SrcType = SrcReg.isValid() ? 
GR.getSPIRVTypeForVReg(SrcReg) : nullptr; - if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector) + if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector) report_fatal_error( "cannot select G_UNMERGE_VALUES with a non-vector argument"); SPIRVType *ScalarType = - GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); + GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg()); MachineBasicBlock &BB = *I.getParent(); bool Res = false; + unsigned CurrentIndex = 0; for (unsigned i = 0; i < I.getNumDefs(); ++i) { Register ResVReg = I.getOperand(i).getReg(); SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg); if (!ResType) { - // There was no "assign type" actions, let's fix this now - ResType = ScalarType; + LLT ResLLT = MRI->getType(ResVReg); + assert(ResLLT.isValid()); + if (ResLLT.isVector()) { + ResType = GR.getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), I, TII); + } else { + ResType = ScalarType; + } MRI->setRegClass(ResVReg, GR.getRegClass(ResType)); - MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType))); GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF); } - auto MIB = - BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(SrcReg) - .addImm(static_cast(i)); - Res |= MIB.constrainAllUses(TII, TRI, RBI); + + if (ResType->getOpcode() == SPIRV::OpTypeVector) { + Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII); + auto MIB = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .addUse(UndefReg); + unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType); + for (unsigned j = 0; j < NumElements; ++j) { + MIB.addImm(CurrentIndex + j); + } + CurrentIndex += NumElements; + Res |= MIB.constrainAllUses(TII, TRI, RBI); + } else { + auto MIB = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .addImm(CurrentIndex); + CurrentIndex++; + Res |= MIB.constrainAllUses(TII, TRI, RBI); + } } return Res; } diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index 53074ea3b2597..32739d7e5e87f 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -14,16 +14,22 @@ #include "SPIRV.h" #include "SPIRVGlobalRegistry.h" #include "SPIRVSubtarget.h" +#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" using namespace llvm; using namespace llvm::LegalizeActions; using namespace llvm::LegalityPredicates; +#define DEBUG_TYPE "spirv-legalizer" + LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool IsExtendedInts) { return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) { const LLT Ty = Query.Types[TypeIdx]; @@ -101,6 +107,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; + auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, + v3s1, v3s8, v3s16, v3s32, v3s64, + v4s1, v4s8, v4s16, v4s32, v4s64}; + auto 
allScalarsAndVectors = { s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, @@ -126,6 +136,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12}; + auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors; + bool IsExtendedInts = ST.canUseExtension( SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) || @@ -148,14 +160,70 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { return IsExtendedInts && Ty.isValid(); }; - for (auto Opc : getTypeFoldingSupportedOpcodes()) - getActionDefinitionsBuilder(Opc).custom(); + // The universal validation rules in the SPIR-V specification state that + // vector sizes are typically limited to 2, 3, or 4. However, larger vector + // sizes (8 and 16) are enabled when the Kernel capability is present. For + // shader execution models, vector sizes are strictly limited to 4. In + // non-shader contexts, vector sizes of 8 and 16 are also permitted, but + // arbitrary sizes (e.g., 6 or 11) are not. + uint32_t MaxVectorSize = ST.isShader() ? 4 : 16; + + for (auto Opc : getTypeFoldingSupportedOpcodes()) { + if (Opc != G_EXTRACT_VECTOR_ELT) + getActionDefinitionsBuilder(Opc).custom(); + } - getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); + getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom(); - // TODO: add proper rules for vectors legalization. - getActionDefinitionsBuilder( - {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR}) + getActionDefinitionsBuilder(G_SHUFFLE_VECTOR) + .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes) + .moreElementsToNextPow2(0) + .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize)) + .moreElementsToNextPow2(1) + .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize)) + .alwaysLegal(); + + getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT) + .moreElementsToNextPow2(1) + .fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 1, ElementCount::getFixed(MaxVectorSize))) + .custom(); + + // Illegal G_UNMERGE_VALUES instructions should be handled + // during the combine phase. + getActionDefinitionsBuilder(G_BUILD_VECTOR) + .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize)) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(MaxVectorSize))); + + // When entering the legalizer, there should be no G_BITCAST instructions. + // They should all be calls to the `spv_bitcast` intrinsic. The call to + // the intrinsic will be converted to a G_BITCAST during legalization if + // the vectors are not legal. After using the rules to legalize a G_BITCAST, + // we turn it back into a call to the intrinsic with a custom rule to avoid + // potential machine verifier failures. 
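  // Illustrative notes, not part of this patch: the two new predicates only
  // ever match fixed vectors, so plain scalars fall through to the other
  // rules. Assuming the shader profile (MaxVectorSize == 4):
  //   vectorElementCountIsGreaterThan(0, 4)       matches <8 x s32>, but
  //                                               neither <4 x s32> nor s32;
  //   vectorElementCountIsLessThanOrEqualTo(0, 4) matches <3 x s32>, but
  //                                               neither <8 x s32> nor s32.
  // Combined with moreElementsToNextPow2, the rules above first widen an odd
  // element count to the next power of two and then split, e.g.
  // <6 x s32> -> <8 x s32> -> two <4 x s32> pieces for shaders, or a single
  // <8 x s32> for kernels, where MaxVectorSize is 16.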
+ getActionDefinitionsBuilder(G_BITCAST) + .moreElementsToNextPow2(0) + .moreElementsToNextPow2(1) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(MaxVectorSize))) + .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize)) + .custom(); + + getActionDefinitionsBuilder(G_CONCAT_VECTORS) + .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize)) + .moreElementsToNextPow2(0) + .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize)) + .alwaysLegal(); + + getActionDefinitionsBuilder(G_SPLAT_VECTOR) + .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize)) + .moreElementsToNextPow2(0) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementSizeTo(0, MaxVectorSize)) .alwaysLegal(); // Vector Reduction Operations @@ -164,7 +232,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN, G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM, G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR}) - .legalFor(allVectors) + .legalFor(allowedVectorTypes) .scalarize(1) .lower(); @@ -172,9 +240,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { .scalarize(2) .lower(); - // Merge/Unmerge - // TODO: add proper legalization rules. - getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal(); + // Illegal G_UNMERGE_VALUES instructions should be handled + // during the combine phase. + getActionDefinitionsBuilder(G_UNMERGE_VALUES) + .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize)); getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs))); @@ -228,7 +297,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { all(typeInSet(0, allPtrsScalarsAndVectors), typeInSet(1, allPtrsScalarsAndVectors))); - getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal(); + getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) + .legalFor({s1}) + .legalFor(allFloatAndIntScalarsAndPtrs) + .legalFor(allowedVectorTypes) + .moreElementsToNextPow2(0) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(MaxVectorSize))); getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal(); @@ -287,6 +363,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { // Pointer-handling. getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); + getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs); + // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. 
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32}); @@ -353,6 +431,21 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { verify(*ST.getInstrInfo()); } +static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI, + SPIRVGlobalRegistry *GR) { + MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(1).getReg(); + Register IdxReg = MI.getOperand(2).getReg(); + + MIRBuilder + .buildIntrinsic(Intrinsic::spv_extractelt, ArrayRef{DstReg}) + .addUse(SrcReg) + .addUse(IdxReg); + MI.eraseFromParent(); + return true; +} + static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, @@ -374,6 +467,13 @@ bool SPIRVLegalizerInfo::legalizeCustom( default: // TODO: implement legalization for other opcodes. return true; + case TargetOpcode::G_BITCAST: + return legalizeBitcast(Helper, MI); + case TargetOpcode::G_EXTRACT_VECTOR_ELT: + return legalizeExtractVectorElt(Helper, MI, GR); + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: + return legalizeIntrinsic(Helper, MI); case TargetOpcode::G_IS_FPCLASS: return legalizeIsFPClass(Helper, MI, LocObserver); case TargetOpcode::G_ICMP: { @@ -400,6 +500,76 @@ bool SPIRVLegalizerInfo::legalizeCustom( } } +bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, + MachineInstr &MI) const { + LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI); + + MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; + MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); + const SPIRVSubtarget &ST = MI.getMF()->getSubtarget(); + + auto IntrinsicID = cast(MI).getIntrinsicID(); + if (IntrinsicID == Intrinsic::spv_bitcast) { + LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n"); + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(2).getReg(); + LLT DstTy = MRI.getType(DstReg); + LLT SrcTy = MRI.getType(SrcReg); + + int32_t MaxVectorSize = ST.isShader() ? 4 : 16; + + bool DstNeedsLegalization = false; + bool SrcNeedsLegalization = false; + + if (DstTy.isVector()) { + if (DstTy.getNumElements() > 4 && + !isPowerOf2_32(DstTy.getNumElements())) { + DstNeedsLegalization = true; + } + + if (DstTy.getNumElements() > MaxVectorSize) { + DstNeedsLegalization = true; + } + } + + if (SrcTy.isVector()) { + if (SrcTy.getNumElements() > 4 && + !isPowerOf2_32(SrcTy.getNumElements())) { + SrcNeedsLegalization = true; + } + + if (SrcTy.getNumElements() > MaxVectorSize) { + SrcNeedsLegalization = true; + } + } + + // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to + // allow using the generic legalization rules. + if (DstNeedsLegalization || SrcNeedsLegalization) { + LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n"); + MIRBuilder.buildBitcast(DstReg, SrcReg); + MI.eraseFromParent(); + } + return true; + } + return true; +} + +bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper, + MachineInstr &MI) const { + // Once the G_BITCAST is using vectors that are allowed, we turn it back into + // an spv_bitcast to avoid verifier problems when the register types are the + // same for the source and the result. Note that the SPIR-V types associated + // with the bitcast can be different even if the register types are the same. 
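  // For illustration only (not part of this patch), the round trip for an
  // illegal shader bitcast looks roughly like:
  //   %d:_(<6 x s32>) = G_INTRINSIC intrinsic(@llvm.spv.bitcast), %s:_(<3 x s64>)
  //     -> legalizeIntrinsic rewrites it as %d:_(<6 x s32>) = G_BITCAST %s
  //     -> the G_BITCAST rules widen/split both sides to at most MaxVectorSize
  //        elements
  //     -> this hook turns each surviving, now-legal G_BITCAST back into an
  //        spv_bitcast intrinsic call.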
+ MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(1).getReg(); + SmallVector DstRegs = {DstReg}; + MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg); + MI.eraseFromParent(); + return true; +} + // Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted // to ensure that all instructions created during the lowering have SPIR-V types // assigned to them. diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h index eeefa4239c778..86e7e711caa60 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h @@ -29,11 +29,15 @@ class SPIRVLegalizerInfo : public LegalizerInfo { public: bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override; + bool legalizeIntrinsic(LegalizerHelper &Helper, + MachineInstr &MI) const override; + SPIRVLegalizerInfo(const SPIRVSubtarget &ST); private: bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const; + bool legalizeBitcast(LegalizerHelper &Helper, MachineInstr &MI) const; }; } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll new file mode 100644 index 0000000000000..4fe6f217dd40f --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll @@ -0,0 +1,69 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion" +; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#v8i32:]] = OpTypeVector %[[#int]] 8 +; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4 +; CHECK-DAG: %[[#ptr_func_v8i32:]] = OpTypePointer Function %[[#v8i32]] + +; CHECK-DAG: OpName %[[#test_v3f64_conversion:]] "test_v3f64_conversion" +; CHECK-DAG: %[[#double:]] = OpTypeFloat 64 +; CHECK-DAG: %[[#v3f64:]] = OpTypeVector %[[#double]] 3 +; CHECK-DAG: %[[#ptr_func_v3f64:]] = OpTypePointer Function %[[#v3f64]] +; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4 + +define spir_kernel void @test_int32_double_conversion(ptr %G_vec) { +; CHECK: %[[#test_int32_double_conversion]] = OpFunction +; CHECK: %[[#param:]] = OpFunctionParameter %[[#ptr_func_v8i32]] +entry: + ; CHECK: %[[#LOAD:]] = OpLoad %[[#v8i32]] %[[#param]] + ; CHECK: %[[#SHUF1:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 2 4 6 + ; CHECK: %[[#SHUF2:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 1 3 5 7 + ; CHECK: %[[#SHUF3:]] = OpVectorShuffle %[[#v8i32]] %[[#SHUF1]] %[[#SHUF2]] 0 4 1 5 2 6 3 7 + ; CHECK: OpStore %[[#param]] %[[#SHUF3]] + + %0 = load <8 x i32>, ptr %G_vec + %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> + store <8 x i32> %3, ptr %G_vec + ret void +} + +define spir_kernel void @test_v3f64_conversion(ptr %G_vec) { +; CHECK: %[[#test_v3f64_conversion:]] = OpFunction +; CHECK: %[[#param_v3f64:]] = OpFunctionParameter %[[#ptr_func_v3f64]] +entry: + ; CHECK: %[[#LOAD:]] = OpLoad %[[#v3f64]] %[[#param_v3f64]] + 
%0 = load <3 x double>, ptr %G_vec + + ; The 6-element vector is not legal. It get expanded to 8. + ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 0 + ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 1 + ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 2 + ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4f64]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %{{[a-zA-Z0-9_]+}} + ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v8i32]] %[[#CONSTRUCT1]] + %1 = bitcast <3 x double> %0 to <6 x i32> + + ; CHECK: %[[#SHUFFLE1:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 0 2 4 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF + %2 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> + + ; CHECK: %[[#SHUFFLE2:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 1 3 5 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF + %3 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> + + ; CHECK: %[[#SHUFFLE3:]] = OpVectorShuffle %[[#v8i32]] %[[#SHUFFLE1]] %[[#SHUFFLE2]] 0 8 1 9 2 10 0xFFFFFFFF 0xFFFFFFFF + %4 = shufflevector <3 x i32> %2, <3 x i32> %3, <6 x i32> + + ; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#v4f64]] %[[#SHUFFLE3]] + ; CHECK: %[[#EXTRACT10:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 0 + ; CHECK: %[[#EXTRACT11:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 1 + ; CHECK: %[[#EXTRACT12:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 2 + ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v3f64]] %[[#EXTRACT10]] %[[#EXTRACT11]] %[[#EXTRACT12]] + %5 = bitcast <6 x i32> %4 to <3 x double> + + ; CHECK: OpStore %[[#param_v3f64]] %[[#CONSTRUCT3]] + store <3 x double> %5, ptr %G_vec + ret void +} + diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-shader.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-shader.ll new file mode 100644 index 0000000000000..438d7ae21283a --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-shader.ll @@ -0,0 +1,133 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %} + +; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#double:]] = OpTypeFloat 64 +; CHECK-DAG: %[[#v4int:]] = OpTypeVector %[[#int]] 4 +; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4 +; CHECK-DAG: %[[#v2int:]] = OpTypeVector %[[#int]] 2 +; CHECK-DAG: %[[#v2double:]] = OpTypeVector %[[#double]] 2 +; CHECK-DAG: %[[#v3int:]] = OpTypeVector %[[#int]] 3 +; CHECK-DAG: %[[#v3double:]] = OpTypeVector %[[#double]] 3 +; CHECK-DAG: %[[#ptr_v4double:]] = OpTypePointer Private %[[#v4double]] +; CHECK-DAG: %[[#ptr_v4int:]] = OpTypePointer Private %[[#v4int]] +; CHECK-DAG: %[[#ptr_v3double:]] = OpTypePointer Private %[[#v3double]] +; CHECK-DAG: %[[#ptr_v3int:]] = OpTypePointer Private %[[#v3int]] +; CHECK-DAG: %[[#GVec4:]] = OpVariable %[[#ptr_v4double]] Private +; CHECK-DAG: %[[#Lows:]] = OpVariable %[[#ptr_v4int]] Private +; CHECK-DAG: %[[#Highs:]] = OpVariable %[[#ptr_v4int]] Private +; CHECK-DAG: %[[#GVec3:]] = OpVariable %[[#ptr_v3double]] Private +; CHECK-DAG: %[[#Lows3:]] = OpVariable %[[#ptr_v3int]] Private +; CHECK-DAG: %[[#Highs3:]] = OpVariable %[[#ptr_v3int]] Private + +@GVec4 = internal addrspace(10) global <4 x double> zeroinitializer +@Lows = internal addrspace(10) global <4 x i32> zeroinitializer +@Highs = internal addrspace(10) global <4 x 
i32> zeroinitializer +@GVec3 = internal addrspace(10) global <3 x double> zeroinitializer +@Lows3 = internal addrspace(10) global <3 x i32> zeroinitializer +@Highs3 = internal addrspace(10) global <3 x i32> zeroinitializer + +; Test splitting a vector of size 8. +define internal void @test_split() { +entry: + ; CHECK: %[[#load_v4double:]] = OpLoad %[[#v4double]] %[[#GVec4]] + ; CHECK: %[[#v2double_01:]] = OpVectorShuffle %[[#v2double]] %[[#load_v4double]] %{{[a-zA-Z0-9_]+}} 0 1 + ; CHECK: %[[#v2double_23:]] = OpVectorShuffle %[[#v2double]] %[[#load_v4double]] %{{[a-zA-Z0-9_]+}} 2 3 + ; CHECK: %[[#v4int_01:]] = OpBitcast %[[#v4int]] %[[#v2double_01]] + ; CHECK: %[[#v4int_23:]] = OpBitcast %[[#v4int]] %[[#v2double_23]] + %0 = load <8 x i32>, ptr addrspace(10) @GVec4, align 32 + + ; CHECK: %[[#l0:]] = OpCompositeExtract %[[#int]] %[[#v4int_01]] 0 + ; CHECK: %[[#l1:]] = OpCompositeExtract %[[#int]] %[[#v4int_01]] 2 + ; CHECK: %[[#l2:]] = OpCompositeExtract %[[#int]] %[[#v4int_23]] 0 + ; CHECK: %[[#l3:]] = OpCompositeExtract %[[#int]] %[[#v4int_23]] 2 + ; CHECK: %[[#res_low:]] = OpCompositeConstruct %[[#v4int]] %[[#l0]] %[[#l1]] %[[#l2]] %[[#l3]] + %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + + ; CHECK: %[[#h0:]] = OpCompositeExtract %[[#int]] %[[#v4int_01]] 1 + ; CHECK: %[[#h1:]] = OpCompositeExtract %[[#int]] %[[#v4int_01]] 3 + ; CHECK: %[[#h2:]] = OpCompositeExtract %[[#int]] %[[#v4int_23]] 1 + ; CHECK: %[[#h3:]] = OpCompositeExtract %[[#int]] %[[#v4int_23]] 3 + ; CHECK: %[[#res_high:]] = OpCompositeConstruct %[[#v4int]] %[[#h0]] %[[#h1]] %[[#h2]] %[[#h3]] + %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + + store <4 x i32> %1, ptr addrspace(10) @Lows, align 16 + store <4 x i32> %2, ptr addrspace(10) @Highs, align 16 + ret void +} + +define internal void @test_recombine() { +entry: + ; CHECK: %[[#l:]] = OpLoad %[[#v4int]] %[[#Lows]] + %0 = load <4 x i32>, ptr addrspace(10) @Lows, align 16 + ; CHECK: %[[#h:]] = OpLoad %[[#v4int]] %[[#Highs]] + %1 = load <4 x i32>, ptr addrspace(10) @Highs, align 16 + + ; CHECK-DAG: %[[#l0:]] = OpCompositeExtract %[[#int]] %[[#l]] 0 + ; CHECK-DAG: %[[#l1:]] = OpCompositeExtract %[[#int]] %[[#l]] 1 + ; CHECK-DAG: %[[#l2:]] = OpCompositeExtract %[[#int]] %[[#l]] 2 + ; CHECK-DAG: %[[#l3:]] = OpCompositeExtract %[[#int]] %[[#l]] 3 + ; CHECK-DAG: %[[#h0:]] = OpCompositeExtract %[[#int]] %[[#h]] 0 + ; CHECK-DAG: %[[#h1:]] = OpCompositeExtract %[[#int]] %[[#h]] 1 + ; CHECK-DAG: %[[#h2:]] = OpCompositeExtract %[[#int]] %[[#h]] 2 + ; CHECK-DAG: %[[#h3:]] = OpCompositeExtract %[[#int]] %[[#h]] 3 + ; CHECK-DAG: %[[#v2i0:]] = OpCompositeConstruct %[[#v2int]] %[[#l0]] %[[#h0]] + ; CHECK-DAG: %[[#d0:]] = OpBitcast %[[#double]] %[[#v2i0]] + ; CHECK-DAG: %[[#v2i1:]] = OpCompositeConstruct %[[#v2int]] %[[#l1]] %[[#h1]] + ; CHECK-DAG: %[[#d1:]] = OpBitcast %[[#double]] %[[#v2i1]] + ; CHECK-DAG: %[[#v2i2:]] = OpCompositeConstruct %[[#v2int]] %[[#l2]] %[[#h2]] + ; CHECK-DAG: %[[#d2:]] = OpBitcast %[[#double]] %[[#v2i2]] + ; CHECK-DAG: %[[#v2i3:]] = OpCompositeConstruct %[[#v2int]] %[[#l3]] %[[#h3]] + ; CHECK-DAG: %[[#d3:]] = OpBitcast %[[#double]] %[[#v2i3]] + ; CHECK-DAG: %[[#res:]] = OpCompositeConstruct %[[#v4double]] %[[#d0]] %[[#d1]] %[[#d2]] %[[#d3]] + %2 = shufflevector <4 x i32> %0, <4 x i32> %1, <8 x i32> + + ; CHECK: OpStore %[[#GVec4]] %[[#res]] + store <8 x i32> %2, ptr addrspace(10) @GVec4, align 32 + ret void +} + +; Test splitting a vector of size 6. It must be expanded to 8, and then split. 
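; Explanatory note: with the shader profile, MaxVectorSize is 4, so the
; <6 x i32> value below is first widened to 8 elements and then carried as two
; <4 x i32> halves. That is why the low lanes (0, 2, 4) come from elements 0
; and 2 of the first half plus element 0 of the second half, and the high
; lanes (1, 3, 5) from elements 1 and 3 of the first half plus element 1 of
; the second half.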
+define internal void @test_bitcast_expand() { +entry: + ; CHECK: %[[#load:]] = OpLoad %[[#v3double]] %[[#GVec3]] + %0 = load <3 x double>, ptr addrspace(10) @GVec3, align 32 + + ; CHECK: %[[#d0:]] = OpCompositeExtract %[[#double]] %[[#load]] 0 + ; CHECK: %[[#d1:]] = OpCompositeExtract %[[#double]] %[[#load]] 1 + ; CHECK: %[[#d2:]] = OpCompositeExtract %[[#double]] %[[#load]] 2 + ; CHECK: %[[#v2d0:]] = OpCompositeConstruct %[[#v2double]] %[[#d0]] %[[#d1]] + ; CHECK: %[[#v2d1:]] = OpCompositeConstruct %[[#v2double]] %[[#d2]] %[[#]] + ; CHECK: %[[#v4i0:]] = OpBitcast %[[#v4int]] %[[#v2d0]] + ; CHECK: %[[#v4i1:]] = OpBitcast %[[#v4int]] %[[#v2d1]] + %1 = bitcast <3 x double> %0 to <6 x i32> + + ; CHECK: %[[#l0:]] = OpCompositeExtract %[[#int]] %[[#v4i0]] 0 + ; CHECK: %[[#l1:]] = OpCompositeExtract %[[#int]] %[[#v4i0]] 2 + ; CHECK: %[[#l2:]] = OpCompositeExtract %[[#int]] %[[#v4i1]] 0 + ; CHECK: %[[#res_low:]] = OpCompositeConstruct %[[#v3int]] %[[#l0]] %[[#l1]] %[[#l2]] + %2 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> + + ; CHECK: %[[#h0:]] = OpCompositeExtract %[[#int]] %[[#v4i0]] 1 + ; CHECK: %[[#h1:]] = OpCompositeExtract %[[#int]] %[[#v4i0]] 3 + ; CHECK: %[[#h2:]] = OpCompositeExtract %[[#int]] %[[#v4i1]] 1 + ; CHECK: %[[#res_high:]] = OpCompositeConstruct %[[#v3int]] %[[#h0]] %[[#h1]] %[[#h2]] + %3 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> + + ; CHECK: OpStore %[[#Lows3]] %[[#res_low]] + store <3 x i32> %2, ptr addrspace(10) @Lows3, align 16 + + ; CHECK: OpStore %[[#Highs3]] %[[#res_high]] + store <3 x i32> %3, ptr addrspace(10) @Highs3, align 16 + ret void +} + +define void @main() local_unnamed_addr #0 { +entry: + call void @test_split() + call void @test_recombine() + call void @test_bitcast_expand() + ret void +} + +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
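
As a closing illustration, here is a minimal sketch (hypothetical, not part of this patch) of how another GlobalISel target could reuse the new predicates to cap vector operands at four elements, mirroring the G_BUILD_VECTOR/G_IMPLICIT_DEF pattern above; the opcode choice is arbitrary and scalar handling is omitted:

  getActionDefinitionsBuilder(G_ADD)
      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, 4))
      .moreElementsToNextPow2(0)
      .fewerElementsIf(vectorElementCountIsGreaterThan(0, 4),
                       LegalizeMutations::changeElementCountTo(
                           0, ElementCount::getFixed(4)));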