Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,16 @@ LLVM_ABI LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
LLVM_ABI LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx,
unsigned Size);

/// True iff the specified type index is a vector with an element size
/// that's greater than the given size.
LLVM_ABI LegalityPredicate vectorElementCountIsGreaterThan(unsigned TypeIdx,
unsigned Size);

/// True iff the specified type index is a vector with an element size
/// that's less than or equal to the given size.
LLVM_ABI LegalityPredicate
vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx, unsigned Size);

/// True iff the specified type index is a scalar or a vector with an element
/// type that's wider than the given size.
LLVM_ABI LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx,
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,26 @@ LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
};
}

LegalityPredicate
LegalityPredicates::vectorElementCountIsGreaterThan(unsigned TypeIdx,
unsigned Size) {

return [=](const LegalityQuery &Query) {
const LLT QueryTy = Query.Types[TypeIdx];
return QueryTy.isFixedVector() && QueryTy.getNumElements() > Size;
};
}

LegalityPredicate
LegalityPredicates::vectorElementCountIsLessThanOrEqualTo(unsigned TypeIdx,
unsigned Size) {

return [=](const LegalityQuery &Query) {
const LLT QueryTy = Query.Types[TypeIdx];
return QueryTy.isFixedVector() && QueryTy.getNumElements() <= Size;
};
}

LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx,
unsigned Size) {
return [=](const LegalityQuery &Query) {
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class Op<bits<16> Opcode, dag outs, dag ins, string asmstr, list<dag> pattern =
let Pattern = pattern;
}

class PureOp<bits<16> Opcode, dag outs, dag ins, string asmstr,
list<dag> pattern = []> : Op<Opcode, outs, ins, asmstr, pattern> {
let hasSideEffects = 0;
}

class UnknownOp<dag outs, dag ins, string asmstr, list<dag> pattern = []>
: Op<0, outs, ins, asmstr, pattern> {
let isPseudo = 1;
Expand Down
179 changes: 108 additions & 71 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -163,52 +163,74 @@ def OpExecutionModeId: Op<331, (outs), (ins ID:$entry, ExecutionMode:$mode, vari

// 3.42.6 Type-Declaration Instructions

def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
"$type = OpTypeInt $width $signedness">;
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
"$type = OpTypeFloat $width">;
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
"$type = OpTypeMatrix $colType $colCount">;
def OpTypeImage: Op<25, (outs TYPE:$res), (ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
i32imm:$arrayed, i32imm:$MS, i32imm:$sampled, ImageFormat:$imFormat, variable_ops),
"$res = OpTypeImage $sampTy $dim $depth $arrayed $MS $sampled $imFormat">;
def OpTypeSampler: Op<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
def OpTypeSampledImage: Op<27, (outs TYPE:$res), (ins TYPE:$imageType),
"$res = OpTypeSampledImage $imageType">;
def OpTypeArray: Op<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
"$type = OpTypeArray $elementType $length">;
def OpTypeRuntimeArray: Op<29, (outs TYPE:$type), (ins TYPE:$elementType),
"$type = OpTypeRuntimeArray $elementType">;
def OpTypeStruct: Op<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
def OpTypeStructContinuedINTEL: Op<6090, (outs), (ins variable_ops),
"OpTypeStructContinuedINTEL">;
def OpTypeOpaque: Op<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
"$res = OpTypeOpaque $name">;
def OpTypePointer: Op<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
"$res = OpTypePointer $storage $type">;
def OpTypeFunction: Op<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
"$funcType = OpTypeFunction $returnType">;
def OpTypeEvent: Op<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
def OpTypeDeviceEvent: Op<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
def OpTypeReserveId: Op<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
def OpTypeQueue: Op<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
def OpTypePipe: Op<38, (outs TYPE:$res), (ins AccessQualifier:$a), "$res = OpTypePipe $a">;
def OpTypeForwardPointer: Op<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
"OpTypeForwardPointer $ptrType $storageClass">;
def OpTypePipeStorage: Op<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
def OpTypeNamedBarrier: Op<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
"$res = OpTypeAccelerationStructureNV">;
def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
"$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
"$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;
def OpTypeVoid : PureOp<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
def OpTypeBool : PureOp<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
def OpTypeInt
: PureOp<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
"$type = OpTypeInt $width $signedness">;
def OpTypeFloat
: PureOp<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
"$type = OpTypeFloat $width">;
def OpTypeVector
: PureOp<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
def OpTypeMatrix
: PureOp<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
"$type = OpTypeMatrix $colType $colCount">;
def OpTypeImage : PureOp<25, (outs TYPE:$res),
(ins TYPE:$sampTy, Dim:$dim, i32imm:$depth,
i32imm:$arrayed, i32imm:$MS, i32imm:$sampled,
ImageFormat:$imFormat, variable_ops),
"$res = OpTypeImage $sampTy $dim $depth $arrayed $MS "
"$sampled $imFormat">;
def OpTypeSampler : PureOp<26, (outs TYPE:$res), (ins), "$res = OpTypeSampler">;
def OpTypeSampledImage : PureOp<27, (outs TYPE:$res), (ins TYPE:$imageType),
"$res = OpTypeSampledImage $imageType">;
def OpTypeArray
: PureOp<28, (outs TYPE:$type), (ins TYPE:$elementType, ID:$length),
"$type = OpTypeArray $elementType $length">;
def OpTypeRuntimeArray : PureOp<29, (outs TYPE:$type), (ins TYPE:$elementType),
"$type = OpTypeRuntimeArray $elementType">;
def OpTypeStruct
: PureOp<30, (outs TYPE:$res), (ins variable_ops), "$res = OpTypeStruct">;
def OpTypeStructContinuedINTEL
: PureOp<6090, (outs), (ins variable_ops), "OpTypeStructContinuedINTEL">;
def OpTypeOpaque
: PureOp<31, (outs TYPE:$res), (ins StringImm:$name, variable_ops),
"$res = OpTypeOpaque $name">;
def OpTypePointer
: PureOp<32, (outs TYPE:$res), (ins StorageClass:$storage, TYPE:$type),
"$res = OpTypePointer $storage $type">;
def OpTypeFunction
: PureOp<33, (outs TYPE:$funcType), (ins TYPE:$returnType, variable_ops),
"$funcType = OpTypeFunction $returnType">;
def OpTypeEvent : PureOp<34, (outs TYPE:$res), (ins), "$res = OpTypeEvent">;
def OpTypeDeviceEvent
: PureOp<35, (outs TYPE:$res), (ins), "$res = OpTypeDeviceEvent">;
def OpTypeReserveId
: PureOp<36, (outs TYPE:$res), (ins), "$res = OpTypeReserveId">;
def OpTypeQueue : PureOp<37, (outs TYPE:$res), (ins), "$res = OpTypeQueue">;
def OpTypePipe : PureOp<38, (outs TYPE:$res), (ins AccessQualifier:$a),
"$res = OpTypePipe $a">;
def OpTypeForwardPointer
: PureOp<39, (outs), (ins TYPE:$ptrType, StorageClass:$storageClass),
"OpTypeForwardPointer $ptrType $storageClass">;
def OpTypePipeStorage
: PureOp<322, (outs TYPE:$res), (ins), "$res = OpTypePipeStorage">;
def OpTypeNamedBarrier
: PureOp<327, (outs TYPE:$res), (ins), "$res = OpTypeNamedBarrier">;
def OpTypeAccelerationStructureNV
: PureOp<5341, (outs TYPE:$res), (ins),
"$res = OpTypeAccelerationStructureNV">;
def OpTypeCooperativeMatrixNV
: PureOp<5358, (outs TYPE:$res),
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
"$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
def OpTypeCooperativeMatrixKHR
: PureOp<4456, (outs TYPE:$res),
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
"$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols "
"$use">;

// 3.42.7 Constant-Creation Instructions

Expand All @@ -222,31 +244,46 @@ defm OpConstant: IntFPImm<43, "OpConstant">;

def ConstPseudoTrue: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 1; }]>;
def ConstPseudoFalse: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 0; }]>;
def OpConstantTrue: Op<41, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantTrue $src_ty",
[(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
def OpConstantFalse: Op<42, (outs iID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantFalse $src_ty",
[(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;

def OpConstantComposite: Op<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
"$res = OpConstantComposite $type">;
def OpConstantCompositeContinuedINTEL: Op<6091, (outs), (ins variable_ops),
"OpConstantCompositeContinuedINTEL">;

def OpConstantSampler: Op<45, (outs ID:$res),
(ins TYPE:$t, SamplerAddressingMode:$s, i32imm:$p, SamplerFilterMode:$f),
"$res = OpConstantSampler $t $s $p $f">;
def OpConstantNull: Op<46, (outs ID:$dst), (ins TYPE:$src_ty), "$dst = OpConstantNull $src_ty">;

def OpSpecConstantTrue: Op<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
def OpSpecConstantFalse: Op<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
def OpSpecConstant: Op<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
"$res = OpSpecConstant $type $imm">;
def OpSpecConstantComposite: Op<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
"$res = OpSpecConstantComposite $type">;
def OpSpecConstantCompositeContinuedINTEL: Op<6092, (outs), (ins variable_ops),
"OpSpecConstantCompositeContinuedINTEL">;
def OpSpecConstantOp: Op<52, (outs ID:$res), (ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
"$res = OpSpecConstantOp $t $c $o">;
def OpConstantTrue
: PureOp<41, (outs iID:$dst), (ins TYPE:$src_ty),
"$dst = OpConstantTrue $src_ty",
[(set iID:$dst, (assigntype ConstPseudoTrue, TYPE:$src_ty))]>;
def OpConstantFalse
: PureOp<42, (outs iID:$dst), (ins TYPE:$src_ty),
"$dst = OpConstantFalse $src_ty",
[(set iID:$dst, (assigntype ConstPseudoFalse, TYPE:$src_ty))]>;

def OpConstantComposite
: PureOp<44, (outs ID:$res), (ins TYPE:$type, variable_ops),
"$res = OpConstantComposite $type">;
def OpConstantCompositeContinuedINTEL
: PureOp<6091, (outs), (ins variable_ops),
"OpConstantCompositeContinuedINTEL">;

def OpConstantSampler : PureOp<45, (outs ID:$res),
(ins TYPE:$t, SamplerAddressingMode:$s,
i32imm:$p, SamplerFilterMode:$f),
"$res = OpConstantSampler $t $s $p $f">;
def OpConstantNull : PureOp<46, (outs ID:$dst), (ins TYPE:$src_ty),
"$dst = OpConstantNull $src_ty">;

def OpSpecConstantTrue
: PureOp<48, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantTrue $t">;
def OpSpecConstantFalse
: PureOp<49, (outs ID:$r), (ins TYPE:$t), "$r = OpSpecConstantFalse $t">;
def OpSpecConstant
: PureOp<50, (outs ID:$res), (ins TYPE:$type, i32imm:$imm, variable_ops),
"$res = OpSpecConstant $type $imm">;
def OpSpecConstantComposite
: PureOp<51, (outs ID:$res), (ins TYPE:$type, variable_ops),
"$res = OpSpecConstantComposite $type">;
def OpSpecConstantCompositeContinuedINTEL
: PureOp<6092, (outs), (ins variable_ops),
"OpSpecConstantCompositeContinuedINTEL">;
def OpSpecConstantOp
: PureOp<52, (outs ID:$res),
(ins TYPE:$t, SpecConstantOpOperands:$c, ID:$o, variable_ops),
"$res = OpSpecConstantOp $t $c $o">;

// 3.42.8 Memory Instructions

Expand Down
58 changes: 45 additions & 13 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1526,33 +1526,57 @@ bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
unsigned ArgI = I.getNumOperands() - 1;
Register SrcReg =
I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
SPIRVType *DefType =
SPIRVType *SrcType =
SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
if (!SrcType || SrcType->getOpcode() != SPIRV::OpTypeVector)
report_fatal_error(
"cannot select G_UNMERGE_VALUES with a non-vector argument");

SPIRVType *ScalarType =
GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
GR.getSPIRVTypeForVReg(SrcType->getOperand(1).getReg());
MachineBasicBlock &BB = *I.getParent();
bool Res = false;
unsigned CurrentIndex = 0;
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
Register ResVReg = I.getOperand(i).getReg();
SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
if (!ResType) {
// There was no "assign type" actions, let's fix this now
ResType = ScalarType;
LLT ResLLT = MRI->getType(ResVReg);
assert(ResLLT.isValid());
if (ResLLT.isVector()) {
ResType = GR.getOrCreateSPIRVVectorType(
ScalarType, ResLLT.getNumElements(), I, TII);
} else {
ResType = ScalarType;
}
MRI->setRegClass(ResVReg, GR.getRegClass(ResType));
MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
}
auto MIB =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(SrcReg)
.addImm(static_cast<int64_t>(i));
Res |= MIB.constrainAllUses(TII, TRI, RBI);

if (ResType->getOpcode() == SPIRV::OpTypeVector) {
Register UndefReg = GR.getOrCreateUndef(I, SrcType, TII);
auto MIB =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpVectorShuffle))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(SrcReg)
.addUse(UndefReg);
unsigned NumElements = GR.getScalarOrVectorComponentCount(ResType);
for (unsigned j = 0; j < NumElements; ++j) {
MIB.addImm(CurrentIndex + j);
}
CurrentIndex += NumElements;
Res |= MIB.constrainAllUses(TII, TRI, RBI);
} else {
auto MIB =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(SrcReg)
.addImm(CurrentIndex);
CurrentIndex++;
Res |= MIB.constrainAllUses(TII, TRI, RBI);
}
}
return Res;
}
Expand Down Expand Up @@ -3119,6 +3143,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectInsertElt(ResVReg, ResType, I);
case Intrinsic::spv_gep:
return selectGEP(ResVReg, ResType, I);
case Intrinsic::spv_bitcast: {
Register OpReg = I.getOperand(2).getReg();
SPIRVType *OpType =
OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr;
if (!GR.isBitcastCompatible(ResType, OpType))
report_fatal_error("incompatible result and operand types in a bitcast");
return selectOpWithSrcs(ResVReg, ResType, I, {OpReg}, SPIRV::OpBitcast);
}
case Intrinsic::spv_unref_global:
case Intrinsic::spv_init_global: {
MachineInstr *MI = MRI->getVRegDef(I.getOperand(1).getReg());
Expand Down
9 changes: 8 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,23 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// Returns the loaded value.
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
FixedVectorType *TargetType, Value *Source) {
assert(TargetType->getNumElements() <= SourceType->getNumElements());
LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
buildAssignType(B, SourceType, NewLoad);
Value *AssignValue = NewLoad;
if (TargetType->getElementType() != SourceType->getElementType()) {
const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
[[maybe_unused]] TypeSize TargetTypeSize =
DL.getTypeSizeInBits(TargetType);
[[maybe_unused]] TypeSize SourceTypeSize =
DL.getTypeSizeInBits(SourceType);
assert(TargetTypeSize == SourceTypeSize);
AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
{TargetType, SourceType}, {NewLoad});
buildAssignType(B, TargetType, AssignValue);
return AssignValue;
}

assert(TargetType->getNumElements() < SourceType->getNumElements());
SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
Mask[I] = I;
Expand Down
Loading
Loading