diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h index 51318c9c2736d..a8748965eb2e8 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h @@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size); LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx, unsigned Size); +/// True iff the specified type index is a fixed vector with an element count +/// that's greater than the given size. +LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx, + unsigned Size); + +/// True iff the specified type index is a fixed vector with an element count +/// that's less than or equal to the given size. +LLVM_ABI LegalityPredicate +vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size); + /// True iff the specified type index is a scalar or a vector with an element /// type that's wider than the given size. 
LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx, diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp index 30c2d089c3121..5e7cd5fd5d9ad 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp @@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx, }; } +LegalityPredicate +LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx, + unsigned Size) { + + return [=](const LegalityQuery &Query) { + const LLT QueryTy = Query.Types[TypeIdx]; + return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size; + }; +} + +LegalityPredicate +LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, + unsigned Size) { + + return [=](const LegalityQuery &Query) { + const LLT QueryTy = Query.Types[TypeIdx]; + return QueryTy.isFixedVector() && QueryTy.getNumElements() <= Size; + }; +} + LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx, unsigned Size) { return [=](const LegalityQuery &Query) { diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td index 2fde2b0bc0b1f..f93240dc35993 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrFormats.td @@ -25,6 +25,11 @@ class Op Opcode, dag outs, dag ins, string asmstr, list pattern = let Pattern = pattern; } +class PureOp Opcode, dag outs, dag ins, string asmstr, + list pattern = []> : Op { + let hasSideEffects = 0; +} + class UnknownOp pattern = []> : Op<0, outs, ins, asmstr, pattern> { let isPseudo = 1; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index a61351eba03f8..799a82c96b0f0 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -163,52 +163,74 @@ def OpExecutionModeId: Op<331, (outs), (ins ID:$entry, 
ExecutionMode:$mode, vari // 3.42.6 Type-Declaration Instructions -def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">; -def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">; -def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness), - "$type = OpTypeInt $width $signedness">; -def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops), - "$type = OpTypeFloat $width">; -def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount), - "$type = OpTypeVector $compType $compCount">; -def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount), - "$type = OpTypeMatrix $colType $colCount">; -def OpTypeImage: Op<25, (outs TYPE:$res), (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth, - i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, ImageFormat:$imFormat, variable_ops), - "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS $sampled $imFormat">; -def OpTypeSampler: Op<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">; -def OpTypeSampledImage: Op<27, (outs TYPE:$res), (ins TYPE:$imageType), - "$res = OpTypeSampledImage $imageType">; -def OpTypeArray: Op<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length), - "$type = OpTypeArray $elementType $length">; -def OpTypeRuntimeArray: Op<29, (outs TYPE:$type), (ins TYPE:$elementType), - "$type = OpTypeRuntimeArray $elementType">; -def OpTypeStruct: Op<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">; -def OpTypeStructContinuedINTEL: Op<6090, (outs), (ins variable_ops), - "OpTypeStructContinuedINTEL">; -def OpTypeOpaque: Op<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops), - "$res = OpTypeOpaque $name">; -def OpTypePointer: Op<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type), - "$res = OpTypePointer $storage $type">; -def OpTypeFunction: Op<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops), - "$funcType = OpTypeFunction $returnType">; -def OpTypeEvent: Op<34, 
(outs TYPE:$res), (ins), "$res = OpTypeEvent">; -def OpTypeDeviceEvent: Op<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">; -def OpTypeReserveId: Op<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">; -def OpTypeQueue: Op<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">; -def OpTypePipe: Op<38, (outs TYPE:$res), (ins AccessQualifier:$a), "$res = OpTypePipe $a">; -def OpTypeForwardPointer: Op<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass), - "OpTypeForwardPointer $ptrType $storageClass">; -def OpTypePipeStorage: Op<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">; -def OpTypeNamedBarrier: Op<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">; -def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins), - "$res = OpTypeAccelerationStructureNV">; -def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res), - (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols), - "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">; -def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res), - (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use), - "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">; +def OpTypeVoid : PureOp<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">; +def OpTypeBool : PureOp<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">; +def OpTypeInt + : PureOp<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness), + "$type = OpTypeInt $width $signedness">; +def OpTypeFloat + : PureOp<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops), + "$type = OpTypeFloat $width">; +def OpTypeVector + : PureOp<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount), + "$type = OpTypeVector $compType $compCount">; +def OpTypeMatrix + : PureOp<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount), + "$type = OpTypeMatrix $colType $colCount">; +def OpTypeImage : PureOp<25, (outs TYPE:$res), + (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth, + i32imm:$arrayed, i32imm:$MS, 
i32imm:$sampled, + ImageFormat:$imFormat, variable_ops), + "$res = OpTypeImage $sampTy $dim $depth $arrayed $MS " + "$sampled $imFormat">; +def OpTypeSampler : PureOp<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">; +def OpTypeSampledImage : PureOp<27, (outs TYPE:$res), (ins TYPE:$imageType), + "$res = OpTypeSampledImage $imageType">; +def OpTypeArray + : PureOp<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length), + "$type = OpTypeArray $elementType $length">; +def OpTypeRuntimeArray : PureOp<29, (outs TYPE:$type), (ins TYPE:$elementType), + "$type = OpTypeRuntimeArray $elementType">; +def OpTypeStruct + : PureOp<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">; +def OpTypeStructContinuedINTEL + : PureOp<6090, (outs), (ins variable_ops), "OpTypeStructContinuedINTEL">; +def OpTypeOpaque + : PureOp<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops), + "$res = OpTypeOpaque $name">; +def OpTypePointer + : PureOp<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type), + "$res = OpTypePointer $storage $type">; +def OpTypeFunction + : PureOp<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops), + "$funcType = OpTypeFunction $returnType">; +def OpTypeEvent : PureOp<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">; +def OpTypeDeviceEvent + : PureOp<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">; +def OpTypeReserveId + : PureOp<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">; +def OpTypeQueue : PureOp<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">; +def OpTypePipe : PureOp<38, (outs TYPE:$res), (ins AccessQualifier:$a), + "$res = OpTypePipe $a">; +def OpTypeForwardPointer + : PureOp<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass), + "OpTypeForwardPointer $ptrType $storageClass">; +def OpTypePipeStorage + : PureOp<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">; +def OpTypeNamedBarrier + : PureOp<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">; +def 
OpTypeAccelerationStructureNV + : PureOp<5341, (outs TYPE:$res), (ins), + "$res = OpTypeAccelerationStructureNV">; +def OpTypeCooperativeMatrixNV + : PureOp<5358, (outs TYPE:$res), + (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols), + "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">; +def OpTypeCooperativeMatrixKHR + : PureOp<4456, (outs TYPE:$res), + (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use), + "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols " + "$use">; // 3.42.7 Constant-Creation Instructions @@ -222,31 +244,46 @@ defm OpConstant: IntFPImm<43, "OpConstant">; def ConstPseudoTrue: IntImmLeaf; def ConstPseudoFalse: IntImmLeaf; -def OpConstantTrue: Op<41, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantTrue $src_ty", - [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>; -def OpConstantFalse: Op<42, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantFalse $src_ty", - [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>; - -def OpConstantComposite: Op<44, (outs ID:$res), (ins TYPE:$type, variable_ops), - "$res = OpConstantComposite $type">; -def OpConstantCompositeContinuedINTEL: Op<6091, (outs), (ins variable_ops), - "OpConstantCompositeContinuedINTEL">; - -def OpConstantSampler: Op<45, (outs ID:$res), - (ins TYPE:$t, SamplerAddressingMode:$s, i32imm:$p, SamplerFilterMode:$f), - "$res = OpConstantSampler $t $s $p $f">; -def OpConstantNull: Op<46, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantNull $src_ty">; - -def OpSpecConstantTrue: Op<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">; -def OpSpecConstantFalse: Op<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">; -def OpSpecConstant: Op<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops), - "$res = OpSpecConstant $type $imm">; -def OpSpecConstantComposite: Op<51, (outs ID:$res), (ins TYPE:$type, variable_ops), - "$res = OpSpecConstantComposite $type">; -def 
OpSpecConstantCompositeContinuedINTEL: Op<6092, (outs), (ins variable_ops), - "OpSpecConstantCompositeContinuedINTEL">; -def OpSpecConstantOp: Op<52, (outs ID:$res), (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops), - "$res = OpSpecConstantOp $t $c $o">; +def OpConstantTrue + : PureOp<41, (outs iID:$dst), (ins TYPE:$src_ty), + "$dst = OpConstantTrue $src_ty", + [(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>; +def OpConstantFalse + : PureOp<42, (outs iID:$dst), (ins TYPE:$src_ty), + "$dst = OpConstantFalse $src_ty", + [(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>; + +def OpConstantComposite + : PureOp<44, (outs ID:$res), (ins TYPE:$type, variable_ops), + "$res = OpConstantComposite $type">; +def OpConstantCompositeContinuedINTEL + : PureOp<6091, (outs), (ins variable_ops), + "OpConstantCompositeContinuedINTEL">; + +def OpConstantSampler : PureOp<45, (outs ID:$res), + (ins TYPE:$t, SamplerAddressingMode:$s, + i32imm:$p, SamplerFilterMode:$f), + "$res = OpConstantSampler $t $s $p $f">; +def OpConstantNull : PureOp<46, (outs ID:$dst), (ins TYPE:$src_ty), + "$dst = OpConstantNull $src_ty">; + +def OpSpecConstantTrue + : PureOp<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">; +def OpSpecConstantFalse + : PureOp<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">; +def OpSpecConstant + : PureOp<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops), + "$res = OpSpecConstant $type $imm">; +def OpSpecConstantComposite + : PureOp<51, (outs ID:$res), (ins TYPE:$type, variable_ops), + "$res = OpSpecConstantComposite $type">; +def OpSpecConstantCompositeContinuedINTEL + : PureOp<6092, (outs), (ins variable_ops), + "OpSpecConstantCompositeContinuedINTEL">; +def OpSpecConstantOp + : PureOp<52, (outs ID:$res), + (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops), + "$res = OpSpecConstantOp $t $c $o">; // 3.42.8 Memory Instructions diff --git 
a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 021353ab716f7..f9e6a224f581b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -1526,33 +1526,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const { unsigned ArgI = I.getNumOperands() - 1; Register SrcReg = I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0); - SPIRVType *DefType = + SPIRVType *SrcType = SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr; - if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector) + if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector) report_fatal_error( "cannot select G_UNMERGE_VALUES with a non-vector argument"); SPIRVType *ScalarType = - GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); + GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg()); MachineBasicBlock &BB = *I.getParent(); bool Res = false; + unsigned CurrentIndex = 0; for (unsigned i = 0; i < I.getNumDefs(); ++i) { Register ResVReg = I.getOperand(i).getReg(); SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg); if (!ResType) { - // There was no "assign type" actions, let's fix this now - ResType = ScalarType; + LLT ResLLT = MRI->getType(ResVReg); + assert(ResLLT.isValid()); + if (ResLLT.isVector()) { + ResType = GR.getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), I, TII); + } else { + ResType = ScalarType; + } MRI->setRegClass(ResVReg, GR.getRegClass(ResType)); - MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType))); GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF); } - auto MIB = - BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(SrcReg) - .addImm(static_cast(i)); - Res |= MIB.constrainAllUses(TII, TRI, RBI); + + if (ResType->getOpcode() == SPIRV::OpTypeVector) { + Register UndefReg = 
GR.getOrCreateUndef(I, SrcType, TII); + auto MIB = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .addUse(UndefReg); + unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType); + for (unsigned j = 0; j < NumElements; ++j) { + MIB.addImm(CurrentIndex + j); + } + CurrentIndex += NumElements; + Res |= MIB.constrainAllUses(TII, TRI, RBI); + } else { + auto MIB = + BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(SrcReg) + .addImm(CurrentIndex); + CurrentIndex++; + Res |= MIB.constrainAllUses(TII, TRI, RBI); + } } return Res; } @@ -3119,6 +3143,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectInsertElt(ResVReg, ResType, I); case Intrinsic::spv_gep: return selectGEP(ResVReg, ResType, I); + case Intrinsic::spv_bitcast: { + Register OpReg = I.getOperand(2).getReg(); + SPIRVType *OpType = + OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr; + if (!GR.isBitcastCompatible(ResType, OpType)) + report_fatal_error("incompatible result and operand types in a bitcast"); + return selectOpWithSrcs(ResVReg, ResType, I, {OpReg}, SPIRV::OpBitcast); + } case Intrinsic::spv_unref_global: case Intrinsic::spv_init_global: { MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg()); diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp index 28a1690ef0be1..a692c24363310 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp @@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass { // Returns the loaded value. 
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType, FixedVectorType *TargetType, Value *Source) { - assert(TargetType->getNumElements() <= SourceType->getNumElements()); LoadInst *NewLoad = B.CreateLoad(SourceType, Source); buildAssignType(B, SourceType, NewLoad); Value *AssignValue = NewLoad; if (TargetType->getElementType() != SourceType->getElementType()) { + const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout(); + [[maybe_unused]] TypeSize TargetTypeSize = + DL.getTypeSizeInBits(TargetType); + [[maybe_unused]] TypeSize SourceTypeSize = + DL.getTypeSizeInBits(SourceType); + assert(TargetTypeSize == SourceTypeSize); AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast, {TargetType, SourceType}, {NewLoad}); buildAssignType(B, TargetType, AssignValue); + return AssignValue; } + assert(TargetType->getNumElements() < SourceType->getNumElements()); SmallVector Mask(/* Size= */ TargetType->getNumElements()); for (unsigned I = 0; I < TargetType->getNumElements(); ++I) Mask[I] = I; diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp index 53074ea3b2597..a7d6bde3c5f1a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp @@ -14,16 +14,22 @@ #include "SPIRV.h" #include "SPIRVGlobalRegistry.h" #include "SPIRVSubtarget.h" +#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" using namespace llvm; using namespace llvm::LegalizeActions; using namespace llvm::LegalityPredicates; +#define DEBUG_TYPE "spirv-legalizer" + LegalityPredicate typeOfExtendedScalars(unsigned TypeIdx, bool 
IsExtendedInts) { return [IsExtendedInts, TypeIdx](const LegalityQuery &Query) { const LLT Ty = Query.Types[TypeIdx]; @@ -101,6 +107,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; + auto allShaderVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, + v3s1, v3s8, v3s16, v3s32, v3s64, + v4s1, v4s8, v4s16, v4s32, v4s64}; + auto allScalarsAndVectors = { s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, @@ -126,6 +136,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8, p10, p11, p12}; + auto &allowedVectorTypes = ST.isShader() ? allShaderVectors : allVectors; + bool IsExtendedInts = ST.canUseExtension( SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) || @@ -148,14 +160,65 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { return IsExtendedInts && Ty.isValid(); }; - for (auto Opc : getTypeFoldingSupportedOpcodes()) - getActionDefinitionsBuilder(Opc).custom(); + uint32_t MaxVectorSize = ST.isShader() ? 4 : 16; - getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); + for (auto Opc : getTypeFoldingSupportedOpcodes()) { + if (Opc != G_EXTRACT_VECTOR_ELT) + getActionDefinitionsBuilder(Opc).custom(); + } - // TODO: add proper rules for vectors legalization. 
- getActionDefinitionsBuilder( - {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR}) + getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom(); + + getActionDefinitionsBuilder(G_SHUFFLE_VECTOR) + .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes) + .moreElementsToNextPow2(0) + .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize)) + .moreElementsToNextPow2(1) + .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize)) + .alwaysLegal(); + + getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT) + .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize)) + .moreElementsToNextPow2(1) + .fewerElementsIf(vectorElementCountIsGreaterThan(1, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 1, ElementCount::getFixed(MaxVectorSize))) + .custom(); + + // Illegal G_UNMERGE_VALUES instructions should be handled + // during the combine phase. + getActionDefinitionsBuilder(G_BUILD_VECTOR) + .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize)) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(MaxVectorSize))); + + // When entering the legalizer, there should be no G_BITCAST instructions. + // They should all be calls to the `spv_bitcast` intrinsic. The call to + // the intrinsic will be converted to a G_BITCAST during legalization if + // the vectors are not legal. After using the rules to legalize a G_BITCAST, + // we turn it back into a call to the intrinsic with a custom rule to avoid + // potential machine verifier failures. 
+ getActionDefinitionsBuilder(G_BITCAST) + .moreElementsToNextPow2(0) + .moreElementsToNextPow2(1) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(MaxVectorSize))) + .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize)) + .custom(); + + getActionDefinitionsBuilder(G_CONCAT_VECTORS) + .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize)) + .moreElementsToNextPow2(0) + .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize)) + .alwaysLegal(); + + getActionDefinitionsBuilder(G_SPLAT_VECTOR) + .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize)) + .moreElementsToNextPow2(0) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementSizeTo(0, MaxVectorSize)) .alwaysLegal(); // Vector Reduction Operations @@ -164,7 +227,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN, G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM, G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR}) - .legalFor(allVectors) + .legalFor(allowedVectorTypes) .scalarize(1) .lower(); @@ -172,9 +235,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { .scalarize(2) .lower(); - // Merge/Unmerge - // TODO: add proper legalization rules. - getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal(); + // Illegal G_UNMERGE_VALUES instructions should be handled + // during the combine phase. 
+ getActionDefinitionsBuilder(G_UNMERGE_VALUES) + .legalIf(vectorElementCountIsLessThanOrEqualTo(1, MaxVectorSize)); getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) .legalIf(all(typeInSet(0, allPtrs), typeInSet(1, allPtrs))); @@ -228,7 +292,14 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { all(typeInSet(0, allPtrsScalarsAndVectors), typeInSet(1, allPtrsScalarsAndVectors))); - getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal(); + getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) + .legalFor({s1}) + .legalFor(allFloatAndIntScalarsAndPtrs) + .legalFor(allowedVectorTypes) + .moreElementsToNextPow2(0) + .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize), + LegalizeMutations::changeElementCountTo( + 0, ElementCount::getFixed(MaxVectorSize))); getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal(); @@ -287,6 +358,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { // Pointer-handling. getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); + getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor(allPtrs); + // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32}); @@ -374,6 +447,12 @@ bool SPIRVLegalizerInfo::legalizeCustom( default: // TODO: implement legalization for other opcodes. 
return true; + case TargetOpcode::G_BITCAST: + return legalizeBitcast(Helper, MI); + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: + return legalizeIntrinsic(Helper, MI); + case TargetOpcode::G_IS_FPCLASS: return legalizeIsFPClass(Helper, MI, LocObserver); case TargetOpcode::G_ICMP: { @@ -400,6 +479,70 @@ bool SPIRVLegalizerInfo::legalizeCustom( } } +bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper, + MachineInstr &MI) const { + LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI); + + MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; + MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); + const SPIRVSubtarget &ST = MI.getMF()->getSubtarget(); + + auto IntrinsicID = cast(MI).getIntrinsicID(); + if (IntrinsicID == Intrinsic::spv_bitcast) { + LLVM_DEBUG(dbgs() << "Found a bitcast instruction\n"); + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(2).getReg(); + LLT DstTy = MRI.getType(DstReg); + LLT SrcTy = MRI.getType(SrcReg); + + int32_t MaxVectorSize = ST.isShader() ? 
4 : 16; + + bool DstNeedsLegalization = false; + bool SrcNeedsLegalization = false; + + if (DstTy.isVector()) { + if (DstTy.getNumElements() > 4 && + !isPowerOf2_32(DstTy.getNumElements())) { + DstNeedsLegalization = true; + } + + if (DstTy.getNumElements() > MaxVectorSize) { + DstNeedsLegalization = true; + } + } + + if (SrcTy.isVector()) { + if (SrcTy.getNumElements() > 4 && + !isPowerOf2_32(SrcTy.getNumElements())) { + SrcNeedsLegalization = true; + } + + if (SrcTy.getNumElements() > MaxVectorSize) { + SrcNeedsLegalization = true; + } + } + + if (DstNeedsLegalization || SrcNeedsLegalization) { + LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n"); + MIRBuilder.buildBitcast(DstReg, SrcReg); + MI.eraseFromParent(); + } + return true; + } + return true; +} + +bool SPIRVLegalizerInfo::legalizeBitcast(LegalizerHelper &Helper, + MachineInstr &MI) const { + MachineIRBuilder &MIRBuilder = Helper.MIRBuilder; + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(1).getReg(); + SmallVector DstRegs = {DstReg}; + MIRBuilder.buildIntrinsic(Intrinsic::spv_bitcast, DstRegs).addUse(SrcReg); + MI.eraseFromParent(); + return true; +} + // Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted // to ensure that all instructions created during the lowering have SPIR-V types // assigned to them. 
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h index eeefa4239c778..86e7e711caa60 100644 --- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h @@ -29,11 +29,15 @@ class SPIRVLegalizerInfo : public LegalizerInfo { public: bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override; + bool legalizeIntrinsic(LegalizerHelper &Helper, + MachineInstr &MI) const override; + SPIRVLegalizerInfo(const SPIRVSubtarget &ST); private: bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const; + bool legalizeBitcast(LegalizerHelper &Helper, MachineInstr &MI) const; }; } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp index d17528dd882bf..644e010d8cf94 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp @@ -17,7 +17,8 @@ #include "SPIRV.h" #include "SPIRVSubtarget.h" #include "SPIRVUtils.h" -#include "llvm/IR/Attributes.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "spirv-postlegalizer" @@ -45,6 +46,11 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB, static bool mayBeInserted(unsigned Opcode) { switch (Opcode) { + case TargetOpcode::G_CONSTANT: + case TargetOpcode::G_UNMERGE_VALUES: + case TargetOpcode::G_EXTRACT_VECTOR_ELT: + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: case TargetOpcode::G_SMAX: case TargetOpcode::G_UMAX: case TargetOpcode::G_SMIN: @@ -53,73 +59,457 @@ static bool mayBeInserted(unsigned Opcode) { case TargetOpcode::G_FMINIMUM: case TargetOpcode::G_FMAXNUM: case TargetOpcode::G_FMAXIMUM: + case TargetOpcode::G_IMPLICIT_DEF: + case TargetOpcode::G_BUILD_VECTOR: + case 
TargetOpcode::G_ICMP: + case TargetOpcode::G_SHUFFLE_VECTOR: + case TargetOpcode::G_ANYEXT: return true; default: return isTypeFoldingSupported(Opcode); } } -static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR, - MachineIRBuilder MIB) { +static SPIRVType *deduceTypeForGConstant(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { MachineRegisterInfo &MRI = MF.getRegInfo(); + const LLT &Ty = MRI.getType(ResVReg); + unsigned BitWidth = Ty.getScalarSizeInBits(); + return GR->getOrCreateSPIRVIntegerType(BitWidth, MIB); +} + +static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg(); + SPIRVType *ScalarType = nullptr; + if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) { + assert(DefType->getOpcode() == SPIRV::OpTypeVector); + ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); + } + + if (!ScalarType) { + // If we could not deduce the type from the source, try to deduce it from + // the uses of the results. + for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) { + for (const auto &Use : + MRI.use_nodbg_instructions(I->getOperand(i).getReg())) { + assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR && + "Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR"); + if (auto *VecType = + GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) { + ScalarType = GR->getScalarOrVectorComponentType(VecType); + break; + } + } + } + } + + if (!ScalarType) + return false; + + for (unsigned i = 0; i < I->getNumDefs(); ++i) { + Register DefReg = I->getOperand(i).getReg(); + if (GR->getSPIRVTypeForVReg(DefReg)) + continue; + + LLT DefLLT = MRI.getType(DefReg); + SPIRVType *ResType = + DefLLT.isVector() + ? 
GR->getOrCreateSPIRVVectorType( + ScalarType, DefLLT.getNumElements(), *I, + *MF.getSubtarget().getInstrInfo()) + : ScalarType; + setRegClassType(DefReg, ResType, GR, &MRI, MF); + } + return true; +} +static SPIRVType *deduceTypeForGExtractVectorElt(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + Register VecReg = I->getOperand(1).getReg(); + if (SPIRVType *VecType = GR->getSPIRVTypeForVReg(VecReg)) { + assert(VecType->getOpcode() == SPIRV::OpTypeVector); + return GR->getScalarOrVectorComponentType(VecType); + } + + // If not handled yet, then check if it is used in a G_BUILD_VECTOR. + // If so get the type from there. + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + if (Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR) { + Register BuildVecResReg = Use.getOperand(0).getReg(); + if (SPIRVType *BuildVecType = GR->getSPIRVTypeForVReg(BuildVecResReg)) + return GR->getScalarOrVectorComponentType(BuildVecType); + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGBuildVector(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Processing " << *I << "\n"); + // First check if any of the operands have a type. + for (unsigned i = 1; i < I->getNumOperands(); ++i) { + if (SPIRVType *OpType = + GR->getSPIRVTypeForVReg(I->getOperand(i).getReg())) { + const LLT &ResLLT = MRI.getType(ResVReg); + LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found operand type " + << *OpType << ", returning vector type\n"); + return GR->getOrCreateSPIRVVectorType(OpType, ResLLT.getNumElements(), + MIB, false); + } + } + // If that did not work, then check the uses. 
+ for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + if (Use.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) { + Register ExtractResReg = Use.getOperand(0).getReg(); + if (SPIRVType *ScalarType = GR->getSPIRVTypeForVReg(ExtractResReg)) { + const LLT &ResLLT = MRI.getType(ResVReg); + LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Found use type " + << *ScalarType << ", returning vector type\n"); + return GR->getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), MIB, false); + } + } + } + LLVM_DEBUG(dbgs() << "deduceTypeForGBuildVector: Could not deduce type\n"); + return nullptr; +} + +static SPIRVType *deduceTypeForGShuffleVector(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + const LLT &ResLLT = MRI.getType(ResVReg); + assert(ResLLT.isVector() && "G_SHUFFLE_VECTOR result must be a vector"); + + // The result element type should be the same as the input vector element + // types. + for (unsigned i = 1; i <= 2; ++i) { + Register VReg = I->getOperand(i).getReg(); + if (auto *VType = GR->getSPIRVTypeForVReg(VReg)) { + if (auto *ScalarType = GR->getScalarOrVectorComponentType(VType)) + return GR->getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), MIB, false); + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGImplicitDef(MachineInstr *I, + MachineFunction &MF, + SPIRVGlobalRegistry *GR, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + for (const MachineInstr &Use : MRI.use_nodbg_instructions(ResVReg)) { + SPIRVType *ScalarType = nullptr; + switch (Use.getOpcode()) { + case TargetOpcode::G_BUILD_VECTOR: + case TargetOpcode::G_UNMERGE_VALUES: + // It's possible that the use instruction has not been processed yet. + // We should look at the operands of the use to determine the type. 
+ for (unsigned i = 1; i < Use.getNumOperands(); ++i) { + if (SPIRVType *OpType = + GR->getSPIRVTypeForVReg(Use.getOperand(i).getReg())) + ScalarType = GR->getScalarOrVectorComponentType(OpType); + } + break; + case TargetOpcode::G_SHUFFLE_VECTOR: + // For G_SHUFFLE_VECTOR, only look at the vector input operands. + if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(1).getReg())) + ScalarType = GR->getScalarOrVectorComponentType(Type); + if (auto *Type = GR->getSPIRVTypeForVReg(Use.getOperand(2).getReg())) + ScalarType = GR->getScalarOrVectorComponentType(Type); + break; + } + if (ScalarType) { + const LLT &ResLLT = MRI.getType(ResVReg); + if (!ResLLT.isVector()) + return ScalarType; + return GR->getOrCreateSPIRVVectorType( + ScalarType, ResLLT.getNumElements(), *I, + *MF.getSubtarget().getInstrInfo()); + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGIntrinsic(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + if (!isSpvIntrinsic(*I, Intrinsic::spv_bitcast)) + return nullptr; + + for (const auto &Use : MRI.use_nodbg_instructions(ResVReg)) { + const unsigned UseOpc = Use.getOpcode(); + assert(UseOpc == TargetOpcode::G_EXTRACT_VECTOR_ELT || + UseOpc == TargetOpcode::G_SHUFFLE_VECTOR || + UseOpc == TargetOpcode::G_BUILD_VECTOR || + UseOpc == TargetOpcode::G_UNMERGE_VALUES); + Register UseResultReg = Use.getOperand(0).getReg(); + if (SPIRVType *UseResType = GR->getSPIRVTypeForVReg(UseResultReg)) { + SPIRVType *ScalarType = GR->getScalarOrVectorComponentType(UseResType); + const LLT &BitcastLLT = MRI.getType(ResVReg); + if (BitcastLLT.isVector()) + return GR->getOrCreateSPIRVVectorType( + ScalarType, BitcastLLT.getNumElements(), MIB, false); + return ScalarType; + } + } + return nullptr; +} + +static SPIRVType *deduceTypeForGAnyExt(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB, + Register ResVReg) { + 
// The result type of G_ANYEXT cannot be inferred from its operand. + // We use the result register's LLT to determine the correct integer type. + const LLT &ResLLT = MIB.getMRI()->getType(ResVReg); + if (!ResLLT.isScalar()) + return nullptr; + return GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB); +} + +static SPIRVType *deduceTypeForDefault(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR) { + if (I->getNumDefs() != 1 || I->getNumOperands() <= 1 || + !I->getOperand(1).isReg()) + return nullptr; + + SPIRVType *OpType = GR->getSPIRVTypeForVReg(I->getOperand(1).getReg()); + if (!OpType) + return nullptr; + return OpType; +} + +static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF, + SPIRVGlobalRegistry *GR, + MachineIRBuilder &MIB) { + LLVM_DEBUG(dbgs() << "Processing instruction: " << *I); + MachineRegisterInfo &MRI = MF.getRegInfo(); + const unsigned Opcode = I->getOpcode(); + Register ResVReg = I->getOperand(0).getReg(); + SPIRVType *ResType = nullptr; + + switch (Opcode) { + case TargetOpcode::G_CONSTANT: { + ResType = deduceTypeForGConstant(I, MF, GR, MIB, ResVReg); + break; + } + case TargetOpcode::G_UNMERGE_VALUES: { + // This one is special as it defines multiple registers. 
+ if (deduceAndAssignTypeForGUnmerge(I, MF, GR)) + return true; + break; + } + case TargetOpcode::G_EXTRACT_VECTOR_ELT: { + ResType = deduceTypeForGExtractVectorElt(I, MF, GR, ResVReg); + break; + } + case TargetOpcode::G_BUILD_VECTOR: { + ResType = deduceTypeForGBuildVector(I, MF, GR, MIB, ResVReg); + break; + } + case TargetOpcode::G_SHUFFLE_VECTOR: { + ResType = deduceTypeForGShuffleVector(I, MF, GR, MIB, ResVReg); + break; + } + case TargetOpcode::G_ANYEXT: { + ResType = deduceTypeForGAnyExt(I, MF, GR, MIB, ResVReg); + break; + } + case TargetOpcode::G_IMPLICIT_DEF: { + ResType = deduceTypeForGImplicitDef(I, MF, GR, ResVReg); + break; + } + case TargetOpcode::G_INTRINSIC: + case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: { + ResType = deduceTypeForGIntrinsic(I, MF, GR, MIB, ResVReg); + break; + } + default: + ResType = deduceTypeForDefault(I, MF, GR); + break; + } + + if (ResType) { + LLVM_DEBUG(dbgs() << "Assigned type to " << *I << ": " << *ResType << "\n"); + GR->assignSPIRVTypeToVReg(ResType, ResVReg, MF); + + if (!MRI.getRegClassOrNull(ResVReg)) { + LLVM_DEBUG(dbgs() << "Updating the register class.\n"); + setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true); + } + return true; + } + return false; +} + +static bool requiresSpirvType(MachineInstr &I, SPIRVGlobalRegistry *GR, + MachineRegisterInfo &MRI) { + LLVM_DEBUG(dbgs() << "Checking if instruction requires a SPIR-V type: " + << I;); + if (I.getNumDefs() == 0) { + LLVM_DEBUG(dbgs() << "Instruction does not have a definition.\n"); + return false; + } + if (!mayBeInserted(I.getOpcode())) { + LLVM_DEBUG(dbgs() << "Instruction may not be inserted.\n"); + return false; + } + + Register ResultRegister = I.defs().begin()->getReg(); + if (GR->getSPIRVTypeForVReg(ResultRegister)) { + LLVM_DEBUG(dbgs() << "Instruction already has a SPIR-V type.\n"); + if (!MRI.getRegClassOrNull(ResultRegister)) { + LLVM_DEBUG(dbgs() << "Updating the register class.\n"); + setRegClassType(ResultRegister, 
GR->getSPIRVTypeForVReg(ResultRegister), + GR, &MRI, *GR->CurMF, true); + } + return false; + } + + return true; +} + +static void registerSpirvTypeForNewInstructions(MachineFunction &MF, + SPIRVGlobalRegistry *GR) { + MachineRegisterInfo &MRI = MF.getRegInfo(); + SmallVector Worklist; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &I : MBB) { - const unsigned Opcode = I.getOpcode(); - if (Opcode == TargetOpcode::G_UNMERGE_VALUES) { - unsigned ArgI = I.getNumOperands() - 1; - Register SrcReg = I.getOperand(ArgI).isReg() - ? I.getOperand(ArgI).getReg() - : Register(0); - SPIRVType *DefType = - SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr; - if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector) - report_fatal_error( - "cannot select G_UNMERGE_VALUES with a non-vector argument"); - SPIRVType *ScalarType = - GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg()); - for (unsigned i = 0; i < I.getNumDefs(); ++i) { - Register ResVReg = I.getOperand(i).getReg(); - SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg); - if (!ResType) { - // There was no "assign type" actions, let's fix this now - ResType = ScalarType; - setRegClassType(ResVReg, ResType, GR, &MRI, *GR->CurMF, true); - } - } - } else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 && - I.getNumOperands() > 1 && I.getOperand(1).isReg()) { - // Legalizer may have added a new instructions and introduced new - // registers, we must decorate them as if they were introduced in a - // non-automatic way - Register ResVReg = I.getOperand(0).getReg(); - // Check if the register defined by the instruction is newly generated - // or already processed - // Check if we have type defined for operands of the new instruction - bool IsKnownReg = MRI.getRegClassOrNull(ResVReg); - SPIRVType *ResVType = GR->getSPIRVTypeForVReg( - IsKnownReg ? 
ResVReg : I.getOperand(1).getReg()); - if (!ResVType) - continue; - // Set type & class - if (!IsKnownReg) - setRegClassType(ResVReg, ResVType, GR, &MRI, *GR->CurMF, true); - // If this is a simple operation that is to be reduced by TableGen - // definition we must apply some of pre-legalizer rules here - if (isTypeFoldingSupported(Opcode)) { - processInstr(I, MIB, MRI, GR, GR->getSPIRVTypeForVReg(ResVReg)); - if (IsKnownReg && MRI.hasOneUse(ResVReg)) { - MachineInstr &UseMI = *MRI.use_instr_begin(ResVReg); - if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE) - continue; - } - insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI); + if (requiresSpirvType(I, GR, MRI)) { + Worklist.push_back(&I); + } + } + } + + if (Worklist.empty()) { + LLVM_DEBUG(dbgs() << "Initial worklist is empty.\n"); + return; + } + + LLVM_DEBUG(dbgs() << "Initial worklist:\n"; + for (auto *I : Worklist) { I->dump(); }); + + bool Changed = true; + while (Changed) { + Changed = false; + SmallVector NextWorklist; + + for (MachineInstr *I : Worklist) { + MachineIRBuilder MIB(*I); + if (deduceAndAssignSpirvType(I, MF, GR, MIB)) { + Changed = true; + } else { + NextWorklist.push_back(I); + } + } + Worklist = NextWorklist; + LLVM_DEBUG(dbgs() << "Worklist size: " << Worklist.size() << "\n"); + } + + if (!Worklist.empty()) { + LLVM_DEBUG(dbgs() << "Remaining worklist:\n"; + for (auto *I : Worklist) { I->dump(); }); + for (auto *I : Worklist) { + MachineIRBuilder MIB(*I); + Register ResVReg = I->getOperand(0).getReg(); + const LLT &ResLLT = MRI.getType(ResVReg); + SPIRVType *ResType = nullptr; + if (ResLLT.isVector()) { + SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType( + ResLLT.getElementType().getSizeInBits(), MIB); + ResType = GR->getOrCreateSPIRVVectorType( + CompType, ResLLT.getNumElements(), MIB, false); + } else { + ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB); + } + setRegClassType(ResVReg, ResType, GR, &MRI, MF, true); + } + } +} + +static void 
ensureAssignTypeForTypeFolding(MachineFunction &MF, + SPIRVGlobalRegistry *GR) { + LLVM_DEBUG(dbgs() << "Entering ensureAssignTypeForTypeFolding for function " + << MF.getName() << "\n"); + MachineRegisterInfo &MRI = MF.getRegInfo(); + for (MachineBasicBlock &MBB : MF) { + for (MachineInstr &MI : MBB) { + if (!isTypeFoldingSupported(MI.getOpcode())) + continue; + if (MI.getNumOperands() == 1 || !MI.getOperand(1).isReg()) + continue; + + LLVM_DEBUG(dbgs() << "Processing instruction: " << MI); + + // Check uses of MI to see if it already has an use in SPIRV::ASSIGN_TYPE + bool HasAssignType = false; + Register ResultRegister = MI.defs().begin()->getReg(); + // All uses of Result register + for (MachineInstr &UseInstr : + MRI.use_nodbg_instructions(ResultRegister)) { + if (UseInstr.getOpcode() == SPIRV::ASSIGN_TYPE) { + HasAssignType = true; + LLVM_DEBUG(dbgs() << " Instruction already has an ASSIGN_TYPE use: " + << UseInstr); + break; } } + + if (!HasAssignType) { + Register ResultRegister = MI.defs().begin()->getReg(); + SPIRVType *ResultType = GR->getSPIRVTypeForVReg(ResultRegister); + LLVM_DEBUG( + dbgs() << " Adding ASSIGN_TYPE for ResultRegister: " + << printReg(ResultRegister, MRI.getTargetRegisterInfo()) + << " with type: " << *ResultType); + MachineIRBuilder MIB(MI); + insertAssignInstr(ResultRegister, nullptr, ResultType, GR, MIB, MRI); + } } } } +static void lowerExtractVectorElements(MachineFunction &MF) { + SmallVector ExtractInstrs; + for (MachineBasicBlock &MBB : MF) { + for (MachineInstr &MI : MBB) { + if (MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) { + ExtractInstrs.push_back(&MI); + } + } + } + + for (MachineInstr *MI : ExtractInstrs) { + MachineIRBuilder MIB(*MI); + Register Dst = MI->getOperand(0).getReg(); + Register Vec = MI->getOperand(1).getReg(); + Register Idx = MI->getOperand(2).getReg(); + + auto Intr = MIB.buildIntrinsic(Intrinsic::spv_extractelt, Dst, true, false); + Intr.addUse(Vec); + Intr.addUse(Idx); + + 
MI->eraseFromParent(); + } +} + // Do a preorder traversal of the CFG starting from the BB |Start|. // point. Calls |op| on each basic block encountered during the traversal. void visit(MachineFunction &MF, MachineBasicBlock &Start, @@ -155,9 +545,9 @@ bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) { const SPIRVSubtarget &ST = MF.getSubtarget(); SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); GR->setCurrentFunc(MF); - MachineIRBuilder MIB(MF); - - processNewInstrs(MF, GR, MIB); + registerSpirvTypeForNewInstructions(MF, GR); + ensureAssignTypeForTypeFolding(MF, GR); + lowerExtractVectorElements(MF); return true; } diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index db6f2d61e8f29..d538009f0ecbe 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -192,31 +192,43 @@ static void buildOpBitcast(SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB, .addUse(OpReg); } -// We do instruction selections early instead of calling MIB.buildBitcast() -// generating the general op code G_BITCAST. When MachineVerifier validates -// G_BITCAST we see a check of a kind: if Source Type is equal to Destination -// Type then report error "bitcast must change the type". This doesn't take into -// account the notion of a typed pointer that is important for SPIR-V where a -// user may and should use bitcast between pointers with different pointee types -// (https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpBitcast). -// It's important for correct lowering in SPIR-V, because interpretation of the -// data type is not left to instructions that utilize the pointer, but encoded -// by the pointer declaration, and the SPIRV target can and must handle the -// declaration and use of pointers that specify the type of data they point to. 
-// It's not feasible to improve validation of G_BITCAST using just information -// provided by low level types of source and destination. Therefore we don't -// produce G_BITCAST as the general op code with semantics different from -// OpBitcast, but rather lower to OpBitcast immediately. As for now, the only -// difference would be that CombinerHelper couldn't transform known patterns -// around G_BUILD_VECTOR. See discussion -// in https://github.com/llvm/llvm-project/pull/110270 for even more context. -static void selectOpBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, - MachineIRBuilder MIB) { +// We lower G_BITCAST to OpBitcast here to avoid a MachineVerifier error. +// The verifier checks if the source and destination LLTs of a G_BITCAST are +// different, but this check is too strict for SPIR-V's typed pointers, which +// may have the same LLT but different SPIRVType (e.g. pointers to different +// pointee types). By lowering to OpBitcast here, we bypass the verifier's +// check. See discussion in https://github.com/llvm/llvm-project/pull/110270 +// for more context. +// +// We also handle the llvm.spv.bitcast intrinsic here. If the source and +// destination SPIR-V types are the same, we lower it to a COPY to enable +// further optimizations like copy propagation. 
+static void lowerBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, + MachineIRBuilder MIB) { SmallVector ToErase; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &MI : MBB) { + if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(2).getReg(); + SPIRVType *DstType = GR->getSPIRVTypeForVReg(DstReg); + assert( + DstType && + "Expected destination SPIR-V type to have been assigned already."); + SPIRVType *SrcType = GR->getSPIRVTypeForVReg(SrcReg); + assert(SrcType && + "Expected source SPIR-V type to have been assigned already."); + if (DstType == SrcType) { + MIB.setInsertPt(*MI.getParent(), MI); + MIB.buildCopy(DstReg, SrcReg); + ToErase.push_back(&MI); + continue; + } + } + if (MI.getOpcode() != TargetOpcode::G_BITCAST) continue; + MIB.setInsertPt(*MI.getParent(), MI); buildOpBitcast(GR, MIB, MI.getOperand(0).getReg(), MI.getOperand(1).getReg()); @@ -237,16 +249,11 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR, SmallVector ToErase; for (MachineBasicBlock &MBB : MF) { for (MachineInstr &MI : MBB) { - if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) && - !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) + if (!isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) continue; assert(MI.getOperand(2).isReg()); MIB.setInsertPt(*MI.getParent(), MI); ToErase.push_back(&MI); - if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) { - MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg()); - continue; - } Register Def = MI.getOperand(0).getReg(); Register Source = MI.getOperand(2).getReg(); Type *ElemTy = getMDOperandAsType(MI.getOperand(3).getMetadata(), 0); @@ -1089,7 +1096,7 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) { removeImplicitFallthroughs(MF, MIB); insertSpirvDecorations(MF, GR, MIB); insertInlineAsm(MF, GR, ST, MIB); - selectOpBitcasts(MF, GR, MIB); + lowerBitcasts(MF, GR, MIB); return true; } diff --git 
a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll index a97492b8453ea..a15d628cc3614 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/AddUint64.ll @@ -63,7 +63,7 @@ entry: ; CHECK: %[[#a_high:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#a]] %[[#undef_v4i32]] 1 3 ; CHECK: %[[#b_low:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#b]] %[[#undef_v4i32]] 0 2 ; CHECK: %[[#b_high:]] = OpVectorShuffle %[[#vec2_int_32]] %[[#b]] %[[#undef_v4i32]] 1 3 -; CHECK: %[[#iaddcarry:]] = OpIAddCarry %[[#struct_v2i32_v2i32]] %[[#a_low]] %[[#vec2_int_32]] +; CHECK: %[[#iaddcarry:]] = OpIAddCarry %[[#struct_v2i32_v2i32]] %[[#a_low]] %[[#b_low]] ; CHECK: %[[#lowsum:]] = OpCompositeExtract %[[#vec2_int_32]] %[[#iaddcarry]] 0 ; CHECK: %[[#carry:]] = OpCompositeExtract %[[#vec2_int_32]] %[[#iaddcarry]] 1 ; CHECK: %[[#carry_ne0:]] = OpINotEqual %[[#vec2_bool]] %[[#carry]] %[[#const_v2i32_0_0]] diff --git a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll index ed67344842b11..4817e7450ac2e 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-resources/issue-146942-ptr-cast.ll @@ -16,7 +16,6 @@ define void @case1() local_unnamed_addr { ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16 ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]] - ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3 %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str) %2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.2) %3 
= tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0) @@ -29,8 +28,7 @@ define void @case1() local_unnamed_addr { define void @case2() local_unnamed_addr { ; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16 ; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]] - ; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3 - ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#VEC_SHUFFLE]] %[[#UNDEF_INT4]] 0 1 2 + ; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#CAST_LOAD]] %[[#UNDEF_INT4]] 0 1 2 %1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str) %2 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_1t(i32 0, i32 5, i32 1, i32 0, ptr nonnull @.str.3) %3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0) diff --git a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll new file mode 100644 index 0000000000000..fbfec1b3ee7cf --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll @@ -0,0 +1,84 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion" +; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4 +; CHECK-DAG: 
%[[#double:]] = OpTypeFloat 64 +; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4 +; CHECK-DAG: %[[#v2i32:]] = OpTypeVector %[[#int]] 2 +; CHECK-DAG: %[[#ptr_private_v4i32:]] = OpTypePointer Private %[[#v4i32]] +; CHECK-DAG: %[[#ptr_private_v4f64:]] = OpTypePointer Private %[[#v4f64]] +; CHECK-DAG: %[[#global_double:]] = OpVariable %[[#ptr_private_v4f64]] Private + +@G_16 = internal addrspace(10) global [16 x i32] zeroinitializer +@G_4_double = internal addrspace(10) global <4 x double> zeroinitializer +@G_4_int = internal addrspace(10) global <4 x i32> zeroinitializer + + +; This is the way matrices will be represented in HLSL. The memory type will be +; an array, but it will be loaded as a vector. +; TODO: Legalization for loads and stores of long vectors is not implemented yet. +;define spir_func void @test_load_store_global() { +;entry: +; %0 = load <16 x i32>, ptr addrspace(10) @G_16, align 64 +; store <16 x i32> %0, ptr addrspace(10) @G_16, align 64 +; ret void +;} + +; This is the code pattern that can be generated from the `asuint` and `asdouble` +; HLSL intrinsics. + +; TODO: This code is not the best because instruction selection is not folding an +; extract from other instruction. That needs to be handled. 
+define spir_func void @test_int32_double_conversion() { +; CHECK: %[[#test_int32_double_conversion]] = OpFunction +entry: + ; CHECK: %[[#LOAD:]] = OpLoad %[[#v4f64]] %[[#global_double]] + ; CHECK: %[[#VEC_SHUF1:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 1 + ; CHECK: %[[#VEC_SHUF2:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 2 3 + ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v4i32]] %[[#VEC_SHUF1]] + ; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#v4i32]] %[[#VEC_SHUF2]] + ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 0 + ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 2 + ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 0 + ; CHECK: %[[#EXTRACT4:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 2 + ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4i32]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %[[#EXTRACT4]] + ; CHECK: %[[#EXTRACT5:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 1 + ; CHECK: %[[#EXTRACT6:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 3 + ; CHECK: %[[#EXTRACT7:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 1 + ; CHECK: %[[#EXTRACT8:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 3 + ; CHECK: %[[#CONSTRUCT2:]] = OpCompositeConstruct %[[#v4i32]] %[[#EXTRACT5]] %[[#EXTRACT6]] %[[#EXTRACT7]] %[[#EXTRACT8]] + ; CHECK: %[[#EXTRACT9:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 0 + ; CHECK: %[[#EXTRACT10:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 0 + ; CHECK: %[[#EXTRACT11:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 1 + ; CHECK: %[[#EXTRACT12:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 1 + ; CHECK: %[[#EXTRACT13:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 2 + ; CHECK: %[[#EXTRACT14:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 2 + ; CHECK: %[[#EXTRACT15:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 3 + ; CHECK: %[[#EXTRACT16:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 3 + ; 
CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT9]] %[[#EXTRACT10]] + ; CHECK: %[[#CONSTRUCT4:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT11]] %[[#EXTRACT12]] + ; CHECK: %[[#CONSTRUCT5:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT13]] %[[#EXTRACT14]] + ; CHECK: %[[#CONSTRUCT6:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT15]] %[[#EXTRACT16]] + ; CHECK: %[[#BITCAST3:]] = OpBitcast %[[#double]] %[[#CONSTRUCT3]] + ; CHECK: %[[#BITCAST4:]] = OpBitcast %[[#double]] %[[#CONSTRUCT4]] + ; CHECK: %[[#BITCAST5:]] = OpBitcast %[[#double]] %[[#CONSTRUCT5]] + ; CHECK: %[[#BITCAST6:]] = OpBitcast %[[#double]] %[[#CONSTRUCT6]] + ; CHECK: %[[#CONSTRUCT7:]] = OpCompositeConstruct %[[#v4f64]] %[[#BITCAST3]] %[[#BITCAST4]] %[[#BITCAST5]] %[[#BITCAST6]] + ; CHECK: OpStore %[[#global_double]] %[[#CONSTRUCT7]] Aligned 32 + + %0 = load <8 x i32>, ptr addrspace(10) @G_4_double + %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> + store <8 x i32> %3, ptr addrspace(10) @G_4_double + ret void +} + +; Add a main function to make it a valid module for spirv-val +define void @main() #1 { + ret void +} + +attributes #1 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll new file mode 100644 index 0000000000000..3e39cb78800ee --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/legalization/vector-legalization-kernel.ll @@ -0,0 +1,70 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion" +; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#v8i32:]] = 
OpTypeVector %[[#int]] 8 +; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4 +; CHECK-DAG: %[[#ptr_func_v8i32:]] = OpTypePointer Function %[[#v8i32]] + +; CHECK-DAG: OpName %[[#test_v3f64_conversion:]] "test_v3f64_conversion" +; CHECK-DAG: %[[#double:]] = OpTypeFloat 64 +; CHECK-DAG: %[[#v3f64:]] = OpTypeVector %[[#double]] 3 +; CHECK-DAG: %[[#ptr_func_v3f64:]] = OpTypePointer Function %[[#v3f64]] +; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4 + +define spir_kernel void @test_int32_double_conversion(ptr %G_vec) { +; CHECK: %[[#test_int32_double_conversion]] = OpFunction +; CHECK: %[[#param:]] = OpFunctionParameter %[[#ptr_func_v8i32]] +entry: + ; CHECK: %[[#LOAD:]] = OpLoad %[[#v8i32]] %[[#param]] + ; CHECK: %[[#SHUF1:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 2 4 6 + ; CHECK: %[[#SHUF2:]] = OpVectorShuffle %[[#v4i32]] %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 1 3 5 7 + ; CHECK: %[[#SHUF3:]] = OpVectorShuffle %[[#v8i32]] %[[#SHUF1]] %[[#SHUF2]] 0 4 1 5 2 6 3 7 + ; CHECK: OpStore %[[#param]] %[[#SHUF3]] + + %0 = load <8 x i32>, ptr %G_vec + %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> + %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> + store <8 x i32> %3, ptr %G_vec + ret void +} + +define spir_kernel void @test_v3f64_conversion(ptr %G_vec) { +; CHECK: %[[#test_v3f64_conversion:]] = OpFunction +; CHECK: %[[#param_v3f64:]] = OpFunctionParameter %[[#ptr_func_v3f64]] +entry: + ; CHECK: %[[#LOAD:]] = OpLoad %[[#v3f64]] %[[#param_v3f64]] + ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 0 + ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 1 + ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#double]] %[[#LOAD]] 2 + ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4f64]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %{{[a-zA-Z0-9_]+}} + ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v8i32]] %[[#CONSTRUCT1]] + ; CHECK: 
%[[#SHUFFLE1:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 0 2 4 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF + ; CHECK: %[[#EXTRACT4:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE1]] 0 + ; CHECK: %[[#EXTRACT5:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE1]] 1 + ; CHECK: %[[#EXTRACT6:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE1]] 2 + ; CHECK: %[[#SHUFFLE2:]] = OpVectorShuffle %[[#v8i32]] %[[#BITCAST1]] %{{[a-zA-Z0-9_]+}} 1 3 5 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF 0xFFFFFFFF + ; CHECK: %[[#EXTRACT7:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE2]] 0 + ; CHECK: %[[#EXTRACT8:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE2]] 1 + ; CHECK: %[[#EXTRACT9:]] = OpCompositeExtract %[[#int]] %[[#SHUFFLE2]] 2 + ; CHECK: %[[#CONSTRUCT2:]] = OpCompositeConstruct %[[#v8i32]] %[[#EXTRACT4]] %[[#EXTRACT5]] %[[#EXTRACT6]] %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} + ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v8i32]] %[[#EXTRACT7]] %[[#EXTRACT8]] %[[#EXTRACT9]] %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} %{{[a-zA-Z0-9_]+}} + ; CHECK: %[[#SHUFFLE2:]] = OpVectorShuffle %[[#v8i32]] %[[#CONSTRUCT2]] %[[#CONSTRUCT3]] 0 8 1 9 2 10 0xFFFFFFFF 0xFFFFFFFF + ; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#v4f64]] %[[#SHUFFLE2]] + ; CHECK: %[[#EXTRACT10:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 0 + ; CHECK: %[[#EXTRACT11:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 1 + ; CHECK: %[[#EXTRACT12:]] = OpCompositeExtract %[[#double]] %[[#BITCAST2]] 2 + ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v3f64]] %[[#EXTRACT10]] %[[#EXTRACT11]] %[[#EXTRACT12]] + ; CHECK: OpStore %[[#param_v3f64]] %[[#CONSTRUCT3]] + %0 = load <3 x double>, ptr %G_vec + %1 = bitcast <3 x double> %0 to <6 x i32> + %2 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> + %3 = shufflevector <6 x i32> %1, <6 x i32> poison, <3 x i32> + %4 = shufflevector <3 x i32> %2, <3 x i32> 
%3, <6 x i32> + %5 = bitcast <6 x i32> %4 to <3 x double> + store <3 x double> %5, ptr %G_vec + ret void +} + diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll index 84913283f6868..a1ec2cd1cfdd2 100644 --- a/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll +++ b/llvm/test/CodeGen/SPIRV/pointers/ptrcast-bitcast.ll @@ -26,3 +26,25 @@ entry: store <4 x i32> %6, ptr addrspace(11) %7, align 16 ret void } + +; This tests a load from a pointer that has been bitcast between vector types +; which share the same total bit-width but have different numbers of elements. +; Tests that legalize-pointer-casts works correctly by moving the bitcast to +; the element that was loaded. + +define void @main2() local_unnamed_addr #0 { +entry: +; CHECK: %[[LOAD:[0-9]+]] = OpLoad %[[#v2_double]] {{.*}} +; CHECK: %[[BITCAST1:[0-9]+]] = OpBitcast %[[#v4_uint]] %[[LOAD]] +; CHECK: %[[BITCAST2:[0-9]+]] = OpBitcast %[[#v2_double]] %[[BITCAST1]] +; CHECK: OpStore {{%[0-9]+}} %[[BITCAST2]] {{.*}} + + %0 = tail call target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v2f64_12_1t(i32 0, i32 2, i32 1, i32 0, ptr nonnull @.str.2) + %2 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 0) + %3 = load <4 x i32>, ptr addrspace(11) %2 + %4 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v2f64_12_1t(target("spirv.VulkanBuffer", [0 x <2 x double>], 12, 1) %0, i32 1) + store <4 x i32> %3, ptr addrspace(11) %4 + ret void +} + +attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" } diff --git a/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll b/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll index 
7548f4757dbe6..6fc03a386d14d 100644 --- a/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll +++ b/llvm/test/CodeGen/SPIRV/pointers/resource-vector-load-store.ll @@ -4,18 +4,23 @@ @.str = private unnamed_addr constant [7 x i8] c"buffer\00", align 1 +; The i64 values in the extracts will be turned +; into immediate values. There should be no 64-bit +; integers in the module. +; CHECK-NOT: OpTypeInt 64 0 + define void @main() "hlsl.shader"="pixel" { -; CHECK: %24 = OpFunction %2 None %3 ; -- Begin function main -; CHECK-NEXT: %1 = OpLabel -; CHECK-NEXT: %25 = OpVariable %13 Function %22 -; CHECK-NEXT: %26 = OpLoad %7 %23 -; CHECK-NEXT: %27 = OpImageRead %5 %26 %15 -; CHECK-NEXT: %28 = OpCompositeExtract %4 %27 0 -; CHECK-NEXT: %29 = OpCompositeExtract %4 %27 1 -; CHECK-NEXT: %30 = OpFAdd %4 %29 %28 -; CHECK-NEXT: %31 = OpCompositeInsert %5 %30 %27 0 -; CHECK-NEXT: %32 = OpLoad %7 %23 -; CHECK-NEXT: OpImageWrite %32 %15 %31 +; CHECK: %[[FUNC:[0-9]+]] = OpFunction %[[VOID:[0-9]+]] None %[[FNTYPE:[0-9]+]] ; -- Begin function main +; CHECK-NEXT: %[[LABEL:[0-9]+]] = OpLabel +; CHECK-NEXT: %[[VAR:[0-9]+]] = OpVariable %[[PTR_FN:[a-zA-Z0-9_]+]] Function %[[INIT:[a-zA-Z0-9_]+]] +; CHECK-NEXT: %[[LOAD1:[0-9]+]] = OpLoad %[[IMG_TYPE:[a-zA-Z0-9_]+]] %[[IMG_VAR:[a-zA-Z0-9_]+]] +; CHECK-NEXT: %[[READ:[0-9]+]] = OpImageRead %[[VEC4:[a-zA-Z0-9_]+]] %[[LOAD1]] %[[COORD:[a-zA-Z0-9_]+]] +; CHECK-NEXT: %[[EXTRACT1:[0-9]+]] = OpCompositeExtract %[[FLOAT:[a-zA-Z0-9_]+]] %[[READ]] 0 +; CHECK-NEXT: %[[EXTRACT2:[0-9]+]] = OpCompositeExtract %[[FLOAT]] %[[READ]] 1 +; CHECK-NEXT: %[[ADD:[0-9]+]] = OpFAdd %[[FLOAT]] %[[EXTRACT2]] %[[EXTRACT1]] +; CHECK-NEXT: %[[INSERT:[0-9]+]] = OpCompositeInsert %[[VEC4]] %[[ADD]] %[[READ]] 0 +; CHECK-NEXT: %[[LOAD2:[0-9]+]] = OpLoad %[[IMG_TYPE]] %[[IMG_VAR]] +; CHECK-NEXT: OpImageWrite %[[LOAD2]] %[[COORD]] %[[INSERT]] ; CHECK-NEXT: OpReturn ; CHECK-NEXT: OpFunctionEnd entry: