Skip to content

Commit

Permalink
Support acospi, asinpi, atan2pi
Browse files Browse the repository at this point in the history
Fixes google#84
Fixes google#85
Fixes google#86
  • Loading branch information
dneto0 committed Oct 16, 2017
1 parent e9a0351 commit 3fbb407
Show file tree
Hide file tree
Showing 13 changed files with 1,003 additions and 52 deletions.
201 changes: 149 additions & 52 deletions lib/SPIRVProducerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ using namespace mdconst;

namespace {

// Reciprocal of pi (1/pi); the scale factor applied to the results of acos,
// asin and atan2 when implementing acospi, asinpi and atan2pi.
// The digits come from the MSDN math constants page:
// https://msdn.microsoft.com/en-us/library/4hwaceh6.aspx
constexpr double kOneOverPi = 0.318309886183790671538;

// Sentinel meaning "no GLSL extended instruction": the zero value of
// glsl::ExtInst.
constexpr glsl::ExtInst kGlslExtInstBad = static_cast<glsl::ExtInst>(0);

// By default, reuse the same descriptor set number for all arguments.
// To turn that off, use -distinct-kernel-descriptor-sets
llvm::cl::opt<bool> distinct_kernel_descriptor_sets(
Expand Down Expand Up @@ -222,7 +227,17 @@ struct SPIRVProducerPass final : public ModulePass {
bool is4xi8vec(Type *Ty) const;
spv::StorageClass GetStorageClass(unsigned AddrSpace) const;
spv::BuiltIn GetBuiltin(StringRef globalVarName) const;
// Returns the GLSL extended instruction enum that the given function
// call maps to directly. If none, then returns kGlslExtInstBad, the 0
// value, i.e. GLSLstd450Bad.
glsl::ExtInst getExtInstEnum(StringRef Name);
// Returns the GLSL extended instruction enum indirectly used by the given
// function. That is, to implement the given function, we use an extended
// instruction plus one more instruction (for example, clz uses FindUMsb
// followed by an OpISub). If none, then returns kGlslExtInstBad, the 0
// value, i.e. GLSLstd450Bad.
glsl::ExtInst getIndirectExtInstEnum(StringRef Name);
// Returns the single GLSL extended instruction used directly or
// indirectly by the given function call. A direct mapping takes
// precedence over an indirect one. If there is neither, then returns
// kGlslExtInstBad.
glsl::ExtInst getDirectOrIndirectExtInstEnum(StringRef Name);
void PrintResID(SPIRVInstruction *Inst);
void PrintOpcode(SPIRVInstruction *Inst);
void PrintOperand(SPIRVOperand *Op);
Expand Down Expand Up @@ -865,24 +880,51 @@ bool SPIRVProducerPass::FindExtInst(Module &M) {
if (CallInst *Call = dyn_cast<CallInst>(&I)) {
Function *Callee = Call->getCalledFunction();
// Check whether this call is for extend instructions.
glsl::ExtInst EInst = getExtInstEnum(Callee->getName());
if (EInst) {
// clz needs OpExtInst and OpISub with constant 31, or splat vector
// of 31. Add it to the constant list here.
if (EInst == glsl::ExtInstFindUMsb) {
Type *IdxTy = Type::getInt32Ty(Context);
auto Idx = ConstantInt::get(IdxTy, 31);
FindType(IdxTy);
FindConstant(Idx);
if (auto* vectorTy = dyn_cast<VectorType>(I.getType())) {
// Register the splat vector with element 31.
FindConstant(ConstantVector::getSplat(
static_cast<unsigned>(vectorTy->getNumElements()), Idx));
FindType(vectorTy);
auto callee_name = Callee->getName();
const glsl::ExtInst EInst = getExtInstEnum(callee_name);
const glsl::ExtInst IndirectEInst =
getIndirectExtInstEnum(callee_name);

HasExtInst |=
(EInst != kGlslExtInstBad) || (IndirectEInst != kGlslExtInstBad);

if (IndirectEInst) {
// Register extra constants if needed.

// Registers a type and constant for computing the result of the
// given instruction. If the result of the instruction is a vector,
// then make a splat vector constant with the same number of
// elements.
auto register_constant = [this, &I](Constant *constant) {
FindType(constant->getType());
FindConstant(constant);
if (auto *vectorTy = dyn_cast<VectorType>(I.getType())) {
// Register the splat vector of the value with the same
// width as the result of the instruction.
auto *vec_constant = ConstantVector::getSplat(
static_cast<unsigned>(vectorTy->getNumElements()),
constant);
FindConstant(vec_constant);
FindType(vec_constant->getType());
}
};
switch (IndirectEInst) {
case glsl::ExtInstFindUMsb:
// clz needs OpExtInst and OpISub with constant 31, or splat
// vector of 31. Add it to the constant list here.
register_constant(
ConstantInt::get(Type::getInt32Ty(Context), 31));
break;
case glsl::ExtInstAcos:
case glsl::ExtInstAsin:
case glsl::ExtInstAtan2:
// We need 1/pi for acospi, asinpi, atan2pi.
register_constant(
ConstantFP::get(Type::getFloatTy(Context), kOneOverPi));
break;
default:
assert(false && "internally inconsistent");
}

HasExtInst = true;
}
}
}
Expand Down Expand Up @@ -4984,10 +5026,12 @@ void SPIRVProducerPass::GenerateInstruction(Instruction &I) {
DeferredInsts.push_back(
std::make_tuple(&I, --SPIRVInstList.end(), nextID++));

// Check whether this call is for extend instructions.
glsl::ExtInst EInst = getExtInstEnum(Callee->getName());
if (EInst == glsl::ExtInstFindUMsb) {
// clz needs OpExtInst and OpISub with constant 31 or vector constant 31.
// Check whether the implementation of this call uses an extended
// instruction plus one more value-producing instruction. If so, then
// reserve the id for the extra value-producing slot.
glsl::ExtInst EInst = getIndirectExtInstEnum(Callee->getName());
if (EInst != kGlslExtInstBad) {
// Reserve a spot for the extra value.
// Increase nextID.
VMap[&I] = nextID;
nextID++;
Expand Down Expand Up @@ -5254,7 +5298,8 @@ void SPIRVProducerPass::HandleDeferredInstruction() {
std::get<2>(*DeferredInst), Ops));
} else if (CallInst *Call = dyn_cast<CallInst>(Inst)) {
Function *Callee = Call->getCalledFunction();
glsl::ExtInst EInst = getExtInstEnum(Callee->getName());
auto callee_name = Callee->getName();
glsl::ExtInst EInst = getDirectOrIndirectExtInstEnum(callee_name);

if (EInst) {
uint32_t &ExtInstImportID = getOpExtInstImportID();
Expand Down Expand Up @@ -5299,42 +5344,67 @@ void SPIRVProducerPass::HandleDeferredInstruction() {
WordCount, spv::OpExtInst, std::get<2>(*DeferredInst), Ops);
SPIRVInstList.insert(InsertPoint, ExtInst);

// clz needs OpExtInst and OpISub with constant 31.
if (EInst == glsl::ExtInstFindUMsb) {
const auto IndirectExtInst = getIndirectExtInstEnum(callee_name);
if (IndirectExtInst != kGlslExtInstBad) {
// Generate one more instruction that uses the result of the extended
// instruction. Its result id is one more than the id of the
// extended instruction.
LLVMContext &Context =
Call->getParent()->getParent()->getParent()->getContext();
//
// Generate OpISub with constant 31.
//
// Ops[0] = Result Type ID
// Ops[1] = Operand 0
// Ops[2] = Operand 1
Ops.clear();

Type *resultTy = Call->getType();
Ops.push_back(new SPIRVOperand(SPIRVOperandType::NUMBERID,
lookupType(resultTy)));
auto generate_extra_inst = [this, &Context, &Call, &DeferredInst,
&VMap, &SPIRVInstList, &InsertPoint](
spv::Op opcode, Constant *constant) {
//
// Generate instruction like:
// result = opcode constant <extinst-result>
//
// Ops[0] = Result Type ID
// Ops[1] = Operand 0 ;; the constant, suitably splatted
// Ops[2] = Operand 1 ;; the result of the extended instruction
SPIRVOperandList Ops;

Type *IdxTy = Type::getInt32Ty(Context);
Constant *minuend = ConstantInt::get(IdxTy, 31);
if (auto *vectorTy = dyn_cast<VectorType>(resultTy)) {
minuend = ConstantVector::getSplat(
static_cast<unsigned>(vectorTy->getNumElements()), minuend);
}
uint32_t Op0ID = VMap[minuend];
SPIRVOperand *Op0IDOp =
new SPIRVOperand(SPIRVOperandType::NUMBERID, Op0ID);
Ops.push_back(Op0IDOp);
Type *resultTy = Call->getType();
Ops.push_back(new SPIRVOperand(SPIRVOperandType::NUMBERID,
lookupType(resultTy)));

SPIRVOperand *Op1IDOp = new SPIRVOperand(SPIRVOperandType::NUMBERID,
std::get<2>(*DeferredInst));
Ops.push_back(Op1IDOp);
if (auto *vectorTy = dyn_cast<VectorType>(resultTy)) {
constant = ConstantVector::getSplat(
static_cast<unsigned>(vectorTy->getNumElements()), constant);
}
uint32_t Op0ID = VMap[constant];
SPIRVOperand *Op0IDOp =
new SPIRVOperand(SPIRVOperandType::NUMBERID, Op0ID);
Ops.push_back(Op0IDOp);

SPIRVOperand *Op1IDOp = new SPIRVOperand(
SPIRVOperandType::NUMBERID, std::get<2>(*DeferredInst));
Ops.push_back(Op1IDOp);

SPIRVInstList.insert(
InsertPoint,
new SPIRVInstruction(5, opcode, std::get<2>(*DeferredInst) + 1,
Ops));
};

switch (IndirectExtInst) {
case glsl::ExtInstFindUMsb: // Implementing clz
generate_extra_inst(
spv::OpISub, ConstantInt::get(Type::getInt32Ty(Context), 31));
break;
case glsl::ExtInstAcos: // Implementing acospi
case glsl::ExtInstAsin: // Implementing asinpi
case glsl::ExtInstAtan2: // Implementing atan2pi
generate_extra_inst(
spv::OpFMul,
ConstantFP::get(Type::getFloatTy(Context), kOneOverPi));
break;

SPIRVInstList.insert(
InsertPoint,
new SPIRVInstruction(5, spv::OpISub,
std::get<2>(*DeferredInst) + 1, Ops));
default:
assert(false && "internally inconsistent");
}
}

} else if (Callee->getName().equals("_Z8popcounti") ||
Callee->getName().equals("_Z8popcountj") ||
Callee->getName().equals("_Z8popcountDv2_i") ||
Expand Down Expand Up @@ -5469,7 +5539,6 @@ glsl::ExtInst SPIRVProducerPass::getExtInstEnum(StringRef Name) {
.Case("_Z5clampDv2_fS_S_", glsl::ExtInst::ExtInstFClamp)
.Case("_Z5clampDv3_fS_S_", glsl::ExtInst::ExtInstFClamp)
.Case("_Z5clampDv4_fS_S_", glsl::ExtInst::ExtInstFClamp)
.StartsWith("_Z3clz", glsl::ExtInst::ExtInstFindUMsb)
.Case("_Z3maxii", glsl::ExtInst::ExtInstSMax)
.Case("_Z3maxDv2_iS_", glsl::ExtInst::ExtInstSMax)
.Case("_Z3maxDv3_iS_", glsl::ExtInst::ExtInstSMax)
Expand Down Expand Up @@ -5556,7 +5625,35 @@ glsl::ExtInst SPIRVProducerPass::getExtInstEnum(StringRef Name) {
.StartsWith("llvm.fmuladd.", glsl::ExtInst::ExtInstFma)
.Case("spirv.unpack.v2f16", glsl::ExtInst::ExtInstUnpackHalf2x16)
.Case("spirv.pack.v2f16", glsl::ExtInst::ExtInstPackHalf2x16)
.Default(static_cast<glsl::ExtInst>(0));
.Default(kGlslExtInstBad);
}

// Maps the mangled name of a builtin to the GLSL extended instruction that
// does most of its work, for those builtins that also need one follow-up
// instruction applied to the extended instruction's result. Any other name
// yields kGlslExtInstBad.
glsl::ExtInst SPIRVProducerPass::getIndirectExtInstEnum(StringRef Name) {
  return StringSwitch<glsl::ExtInst>(Name)
      // Every clz overload goes through FindUMsb.
      .StartsWith("_Z3clz", glsl::ExtInst::ExtInstFindUMsb)
      // The *pi builtins are matched exactly on their float overloads
      // (scalar and 2-, 3-, 4-element vectors), since the follow-up
      // multiply needs a constant of the matching floating point type.
      .Cases("_Z6acospif", "_Z6acospiDv2_f", "_Z6acospiDv3_f",
             "_Z6acospiDv4_f", glsl::ExtInst::ExtInstAcos)
      .Cases("_Z6asinpif", "_Z6asinpiDv2_f", "_Z6asinpiDv3_f",
             "_Z6asinpiDv4_f", glsl::ExtInst::ExtInstAsin)
      .Cases("_Z7atan2piff", "_Z7atan2piDv2_fS_", "_Z7atan2piDv3_fS_",
             "_Z7atan2piDv4_fS_", glsl::ExtInst::ExtInstAtan2)
      .Default(kGlslExtInstBad);
}

// Resolves a builtin name to its GLSL extended instruction, preferring the
// direct mapping. The indirect table is consulted only when no direct
// mapping exists, so the result is kGlslExtInstBad when neither table knows
// the name.
glsl::ExtInst
SPIRVProducerPass::getDirectOrIndirectExtInstEnum(StringRef Name) {
  const glsl::ExtInst DirectInst = getExtInstEnum(Name);
  return (DirectInst == kGlslExtInstBad) ? getIndirectExtInstEnum(Name)
                                         : DirectInst;
}

void SPIRVProducerPass::PrintResID(SPIRVInstruction *Inst) {
Expand Down
70 changes: 70 additions & 0 deletions test/MathBuiltins/float2_acospi.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RUN: clspv %s -S -o %t.spvasm
// RUN: FileCheck %s < %t.spvasm
// RUN: clspv %s -o %t.spv
// RUN: spirv-dis -o %t2.spvasm %t.spv
// RUN: FileCheck %s < %t2.spvasm
// RUN: spirv-val --target-env vulkan1.0 %t.spv

// Exercises acospi on a float2 argument. The expected SPIR-V below computes
// it as the GLSL.std.450 Acos extended instruction followed by an OpFMul
// with a two-element constant whose components are both 1/pi (0.31831).
void kernel foo(global float2* A, float2 x)
{
*A = acospi(x);
}
// CHECK: ; SPIR-V
// CHECK: ; Version: 1.0
// CHECK: ; Generator: Codeplay; 0
// CHECK: ; Bound: 32
// CHECK: ; Schema: 0
// CHECK: OpCapability Shader
// CHECK: OpCapability VariablePointers
// CHECK: OpExtension "SPV_KHR_storage_buffer_storage_class"
// CHECK: OpExtension "SPV_KHR_variable_pointers"
// CHECK: [[_1:%[a-zA-Z0-9_]+]] = OpExtInstImport "GLSL.std.450"
// CHECK: OpMemoryModel Logical GLSL450
// CHECK: OpEntryPoint GLCompute [[_25:%[a-zA-Z0-9_]+]] "foo"
// CHECK: OpSource OpenCL_C 120
// CHECK: OpDecorate [[_18:%[a-zA-Z0-9_]+]] SpecId 0
// CHECK: OpDecorate [[_19:%[a-zA-Z0-9_]+]] SpecId 1
// CHECK: OpDecorate [[_20:%[a-zA-Z0-9_]+]] SpecId 2
// CHECK: OpDecorate [[__runtimearr_v2float:%[a-zA-Z0-9_]+]] ArrayStride 8
// CHECK: OpMemberDecorate [[__struct_6:%[a-zA-Z0-9_]+]] 0 Offset 0
// CHECK: OpDecorate [[__struct_6]] Block
// CHECK: OpMemberDecorate [[__struct_8:%[a-zA-Z0-9_]+]] 0 Offset 0
// CHECK: OpDecorate [[__struct_8]] Block
// CHECK: OpDecorate [[_gl_WorkGroupSize:%[a-zA-Z0-9_]+]] BuiltIn WorkgroupSize
// CHECK: OpDecorate [[_23:%[a-zA-Z0-9_]+]] DescriptorSet 0
// CHECK: OpDecorate [[_23]] Binding 0
// CHECK: OpDecorate [[_24:%[a-zA-Z0-9_]+]] DescriptorSet 0
// CHECK: OpDecorate [[_24]] Binding 1
// CHECK: [[_float:%[a-zA-Z0-9_]+]] = OpTypeFloat 32
// CHECK: [[_v2float:%[a-zA-Z0-9_]+]] = OpTypeVector [[_float]] 2
// CHECK: [[__ptr_StorageBuffer_v2float:%[a-zA-Z0-9_]+]] = OpTypePointer StorageBuffer [[_v2float]]
// CHECK: [[__runtimearr_v2float]] = OpTypeRuntimeArray [[_v2float]]
// CHECK: [[__struct_6]] = OpTypeStruct [[__runtimearr_v2float]]
// CHECK: [[__ptr_StorageBuffer__struct_6:%[a-zA-Z0-9_]+]] = OpTypePointer StorageBuffer [[__struct_6]]
// CHECK: [[__struct_8]] = OpTypeStruct [[_v2float]]
// CHECK: [[__ptr_StorageBuffer__struct_8:%[a-zA-Z0-9_]+]] = OpTypePointer StorageBuffer [[__struct_8]]
// CHECK: [[_uint:%[a-zA-Z0-9_]+]] = OpTypeInt 32 0
// CHECK: [[_void:%[a-zA-Z0-9_]+]] = OpTypeVoid
// CHECK: [[_12:%[a-zA-Z0-9_]+]] = OpTypeFunction [[_void]]
// CHECK: [[_v3uint:%[a-zA-Z0-9_]+]] = OpTypeVector [[_uint]] 3
// CHECK: [[__ptr_Private_v3uint:%[a-zA-Z0-9_]+]] = OpTypePointer Private [[_v3uint]]
// CHECK: [[_uint_0:%[a-zA-Z0-9_]+]] = OpConstant [[_uint]] 0
// CHECK: [[_float_0_31831:%[a-zA-Z0-9_]+]] = OpConstant [[_float]] 0.31831
// CHECK: [[_17:%[a-zA-Z0-9_]+]] = OpConstantComposite [[_v2float]] [[_float_0_31831]] [[_float_0_31831]]
// CHECK: [[_18]] = OpSpecConstant [[_uint]] 1
// CHECK: [[_19]] = OpSpecConstant [[_uint]] 1
// CHECK: [[_20]] = OpSpecConstant [[_uint]] 1
// CHECK: [[_gl_WorkGroupSize]] = OpSpecConstantComposite [[_v3uint]] [[_18]] [[_19]] [[_20]]
// CHECK: [[_22:%[a-zA-Z0-9_]+]] = OpVariable [[__ptr_Private_v3uint]] Private [[_gl_WorkGroupSize]]
// CHECK: [[_23]] = OpVariable [[__ptr_StorageBuffer__struct_6]] StorageBuffer
// CHECK: [[_24]] = OpVariable [[__ptr_StorageBuffer__struct_8]] StorageBuffer
// CHECK: [[_25]] = OpFunction [[_void]] None [[_12]]
// CHECK: [[_26:%[a-zA-Z0-9_]+]] = OpLabel
// CHECK: [[_27:%[a-zA-Z0-9_]+]] = OpAccessChain [[__ptr_StorageBuffer_v2float]] [[_23]] [[_uint_0]] [[_uint_0]]
// CHECK: [[_28:%[a-zA-Z0-9_]+]] = OpAccessChain [[__ptr_StorageBuffer_v2float]] [[_24]] [[_uint_0]]
// CHECK: [[_29:%[a-zA-Z0-9_]+]] = OpLoad [[_v2float]] [[_28]]
// CHECK: [[_30:%[a-zA-Z0-9_]+]] = OpExtInst [[_v2float]] [[_1]] Acos [[_29]]
// CHECK: [[_31:%[a-zA-Z0-9_]+]] = OpFMul [[_v2float]] [[_17]] [[_30]]
// CHECK: OpStore [[_27]] [[_31]]
// CHECK: OpReturn
// CHECK: OpFunctionEnd

0 comments on commit 3fbb407

Please sign in to comment.