Skip to content

Commit f57abf5

Browse files
authored
[SPIRV] Promote scalar arguments to vector for OpExtInst in generateExtInst instead of SPIRVRegularizer (#170155)
This patch consists of 2 parts: * a first part that removes the scalar-to-vector promotion for built-ins from the `SPIRVRegularizer`; * and a second part that implements the scalar-to-vector promotion for built-ins in `generateExtInst`. The implementation in `SPIRVRegularizer` had several issues: * It rolled its own built-in pattern matching, which was extremely permissive. * The compiler would crash if the built-in had a definition. * The compiler would crash if the built-in had no arguments. * The compiler would crash if there were more than 2 function definitions in the module. * It would have been better implemented as a module pass that iterates over the users of the function, instead of scanning the whole module for callers. This patch does the scalar-to-vector promotion just before the `OpExtInst` is generated, without relying on the IR transformation. One change in the generated code from the previous implementation is that this version uses a single `OpCompositeConstruct` operation to convert the scalar into a vector. The old implementation inserted an element at position 0 of an `undef` vector (using `OpCompositeInsert`), then copied that element into every vector element using `OpVectorShuffle`. This patch also adds a test (`OpExtInst_vector_promotion_bug.ll`) that highlights an issue in the builtin pattern matching that we're using: our pattern matching doesn't consider the number of arguments, only the demangled name and the first and last arguments (`min(int,int,int)` matches the same builtin as `min(int, int)`).
1 parent 794551d commit f57abf5

File tree

5 files changed

+263
-117
lines changed

5 files changed

+263
-117
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,10 +1154,63 @@ static unsigned getNumSizeComponents(SPIRVType *imgType) {
11541154
return arrayed ? numComps + 1 : numComps;
11551155
}
11561156

1157+
/// Returns true when \p BuiltinNumber is one of the OpenCL extended
/// instructions for which the call arguments may have to be widened to the
/// (vector) result type before the OpExtInst is emitted: the integer and
/// floating-point min/max family, the clamp family, and mix/step/smoothstep.
/// Every other instruction number yields false.
static bool builtinMayNeedPromotionToVec(uint32_t BuiltinNumber) {
  switch (BuiltinNumber) {
  // Integer min/max.
  case SPIRV::OpenCLExtInst::s_min:
  case SPIRV::OpenCLExtInst::u_min:
  case SPIRV::OpenCLExtInst::s_max:
  case SPIRV::OpenCLExtInst::u_max:
  // Floating-point min/max.
  case SPIRV::OpenCLExtInst::fmax:
  case SPIRV::OpenCLExtInst::fmin:
  case SPIRV::OpenCLExtInst::fmax_common:
  case SPIRV::OpenCLExtInst::fmin_common:
  // Clamp.
  case SPIRV::OpenCLExtInst::s_clamp:
  case SPIRV::OpenCLExtInst::fclamp:
  case SPIRV::OpenCLExtInst::u_clamp:
  // Interpolation / thresholding.
  case SPIRV::OpenCLExtInst::mix:
  case SPIRV::OpenCLExtInst::step:
  case SPIRV::OpenCLExtInst::smoothstep:
    return true;
  default:
    return false;
  }
}
1179+
11571180
//===----------------------------------------------------------------------===//
11581181
// Implementation functions for each builtin group
11591182
//===----------------------------------------------------------------------===//
11601183

1184+
/// Builds the list of operand registers to attach to the OpExtInst generated
/// for \p Call.
///
/// When \p BuiltinNumber is not accepted by builtinMayNeedPromotionToVec(), or
/// the call's return type has a single component, the call's arguments are
/// returned unchanged. Otherwise each argument whose SPIR-V type is not the
/// call's return type (expected to be a scalar operand of a vector builtin) is
/// replaced by a fresh virtual register defined by one OpCompositeConstruct
/// that repeats the argument once per result component.
///
/// \param Call          Incoming builtin call (arguments and return type).
/// \param BuiltinNumber OpenCL extended instruction number being emitted.
/// \param MIRBuilder    Builder used to emit any OpCompositeConstruct.
/// \param GR            Registry queried for SPIR-V type information.
/// \returns The registers to use as the extended instruction's operands, in
///          the same order as the call's arguments.
static SmallVector<Register>
getBuiltinCallArguments(const SPIRV::IncomingCall *Call, uint32_t BuiltinNumber,
                        MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) {
  Register ReturnTypeId = GR->getSPIRVTypeID(Call->ReturnType);
  unsigned ResultElementCount =
      GR->getScalarOrVectorComponentCount(ReturnTypeId);

  // Nothing to widen: wrong builtin family, or a single-component result.
  if (!builtinMayNeedPromotionToVec(BuiltinNumber) || ResultElementCount <= 1)
    return {Call->Arguments.begin(), Call->Arguments.end()};

  SmallVector<Register> Promoted;
  for (Register Operand : Call->Arguments) {
    // Operands that already have the result type are forwarded as-is.
    if (GR->getSPIRVTypeForVReg(Operand) == Call->ReturnType) {
      Promoted.push_back(Operand);
      continue;
    }

    // Splat the operand: one constituent per component of the result type.
    Register Splat = createVirtualRegister(Call->ReturnType, GR, MIRBuilder);
    auto Construct = MIRBuilder.buildInstr(SPIRV::OpCompositeConstruct)
                         .addDef(Splat)
                         .addUse(ReturnTypeId);
    for (unsigned Lane = 0; Lane < ResultElementCount; ++Lane)
      Construct.addUse(Operand);
    Promoted.push_back(Splat);
  }
  return Promoted;
}
1213+
11611214
static bool generateExtInst(const SPIRV::IncomingCall *Call,
11621215
MachineIRBuilder &MIRBuilder,
11631216
SPIRVGlobalRegistry *GR, const CallBase &CB) {
@@ -1179,16 +1232,21 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,
11791232
: SPIRV::OpenCLExtInst::fmax;
11801233
}
11811234

1235+
Register ReturnTypeId = GR->getSPIRVTypeID(Call->ReturnType);
1236+
SmallVector<Register> Arguments =
1237+
getBuiltinCallArguments(Call, Number, MIRBuilder, GR);
1238+
11821239
// Build extended instruction.
11831240
auto MIB =
11841241
MIRBuilder.buildInstr(SPIRV::OpExtInst)
11851242
.addDef(Call->ReturnRegister)
1186-
.addUse(GR->getSPIRVTypeID(Call->ReturnType))
1243+
.addUse(ReturnTypeId)
11871244
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
11881245
.addImm(Number);
11891246

1190-
for (auto Argument : Call->Arguments)
1247+
for (Register Argument : Arguments)
11911248
MIB.addUse(Argument);
1249+
11921250
MIB.getInstr()->copyIRFlags(CB);
11931251
if (OrigNumber == SPIRV::OpenCLExtInst::fmin_common ||
11941252
OrigNumber == SPIRV::OpenCLExtInst::fmax_common) {

llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp

Lines changed: 3 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "SPIRV.h"
15-
#include "llvm/Demangle/Demangle.h"
15+
#include "llvm/IR/Constants.h"
1616
#include "llvm/IR/InstIterator.h"
17-
#include "llvm/IR/InstVisitor.h"
17+
#include "llvm/IR/Instructions.h"
1818
#include "llvm/IR/PassManager.h"
19-
#include "llvm/Transforms/Utils/Cloning.h"
2019

2120
#include <list>
2221

@@ -25,9 +24,7 @@
2524
using namespace llvm;
2625

2726
namespace {
28-
struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
29-
DenseMap<Function *, Function *> Old2NewFuncs;
30-
27+
struct SPIRVRegularizer : public FunctionPass {
3128
public:
3229
static char ID;
3330
SPIRVRegularizer() : FunctionPass(ID) {}
@@ -37,11 +34,8 @@ struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
3734
void getAnalysisUsage(AnalysisUsage &AU) const override {
3835
FunctionPass::getAnalysisUsage(AU);
3936
}
40-
void visitCallInst(CallInst &CI);
4137

4238
private:
43-
void visitCallScalToVec(CallInst *CI, StringRef MangledName,
44-
StringRef DemangledName);
4539
void runLowerConstExpr(Function &F);
4640
};
4741
} // namespace
@@ -157,98 +151,8 @@ void SPIRVRegularizer::runLowerConstExpr(Function &F) {
157151
}
158152
}
159153

160-
// It fixes calls to OCL builtins that accept vector arguments and one of them
161-
// is actually a scalar splat.
162-
void SPIRVRegularizer::visitCallInst(CallInst &CI) {
163-
auto F = CI.getCalledFunction();
164-
if (!F)
165-
return;
166-
167-
auto MangledName = F->getName();
168-
char *NameStr = itaniumDemangle(F->getName().data());
169-
if (!NameStr)
170-
return;
171-
StringRef DemangledName(NameStr);
172-
173-
// TODO: add support for other builtins.
174-
if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
175-
DemangledName.starts_with("min") || DemangledName.starts_with("max"))
176-
visitCallScalToVec(&CI, MangledName, DemangledName);
177-
free(NameStr);
178-
}
179-
180-
void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
181-
StringRef DemangledName) {
182-
// Check if all arguments have the same type - it's simple case.
183-
auto Uniform = true;
184-
Type *Arg0Ty = CI->getOperand(0)->getType();
185-
auto IsArg0Vector = isa<VectorType>(Arg0Ty);
186-
for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
187-
Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
188-
if (Uniform)
189-
return;
190-
191-
auto *OldF = CI->getCalledFunction();
192-
Function *NewF = nullptr;
193-
auto [It, Inserted] = Old2NewFuncs.try_emplace(OldF);
194-
if (Inserted) {
195-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
196-
SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
197-
auto *NewFTy =
198-
FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
199-
NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
200-
*OldF->getParent());
201-
ValueToValueMapTy VMap;
202-
auto NewFArgIt = NewF->arg_begin();
203-
for (auto &Arg : OldF->args()) {
204-
auto ArgName = Arg.getName();
205-
NewFArgIt->setName(ArgName);
206-
VMap[&Arg] = &(*NewFArgIt++);
207-
}
208-
SmallVector<ReturnInst *, 8> Returns;
209-
CloneFunctionInto(NewF, OldF, VMap,
210-
CloneFunctionChangeType::LocalChangesOnly, Returns);
211-
NewF->setAttributes(Attrs);
212-
It->second = NewF;
213-
} else {
214-
NewF = It->second;
215-
}
216-
assert(NewF);
217-
218-
// This produces an instruction sequence that implements a splat of
219-
// CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
220-
// and ShuffleVectorInst to generate the same code as the SPIR-V translator.
221-
// For instance (transcoding/OpMin.ll), this call
222-
// call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
223-
// is translated to
224-
// %8 = OpUndef %v2uint
225-
// %14 = OpConstantComposite %v2uint %uint_1 %uint_10
226-
// ...
227-
// %10 = OpCompositeInsert %v2uint %uint_5 %8 0
228-
// %11 = OpVectorShuffle %v2uint %10 %8 0 0
229-
// %call = OpExtInst %v2uint %1 s_min %14 %11
230-
auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
231-
PoisonValue *PVal = PoisonValue::get(Arg0Ty);
232-
Instruction *Inst = InsertElementInst::Create(
233-
PVal, CI->getOperand(1), ConstInt, "", CI->getIterator());
234-
ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
235-
Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
236-
Value *NewVec =
237-
new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI->getIterator());
238-
CI->setOperand(1, NewVec);
239-
CI->replaceUsesOfWith(OldF, NewF);
240-
CI->mutateFunctionType(NewF->getFunctionType());
241-
}
242-
243154
bool SPIRVRegularizer::runOnFunction(Function &F) {
244155
runLowerConstExpr(F);
245-
visit(F);
246-
for (auto &OldNew : Old2NewFuncs) {
247-
Function *OldF = OldNew.first;
248-
Function *NewF = OldNew.second;
249-
NewF->takeName(OldF);
250-
OldF->eraseFromParent();
251-
}
252156
return true;
253157
}
254158

0 commit comments

Comments
 (0)