Skip to content

Commit

Permalink
[SPIR-V] Fix vloadn OpenCL builtin lowering (#81148)
Browse files Browse the repository at this point in the history
This pull request fixes an issue with the missing vector element count
immediate in OpExtInst calls and adds a case for generating bitcasts
before GEPs for kernel arguments of a non-matching pointer type. The new
LIT tests are based on the basic/vload_local and basic/vload_global OpenCL
CTS tests. The tests after this change pass SPIR-V validation.
  • Loading branch information
michalpaszkowski committed Feb 21, 2024
1 parent c02b0d0 commit 03203b7
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 114 deletions.
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ struct VectorLoadStoreBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
uint32_t Number;
uint32_t ElementCount;
bool IsRounded;
FPRoundingMode::FPRoundingMode RoundingMode;
};
Expand Down Expand Up @@ -2042,6 +2043,7 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
.addImm(Builtin->Number);
for (auto Argument : Call->Arguments)
MIB.addUse(Argument);
MIB.addImm(Builtin->ElementCount);

// The rounding mode should be passed as the last argument in the MI for
// builtins like "vstorea_halfn_r".
Expand Down
16 changes: 11 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -1236,18 +1236,24 @@ class VectorLoadStoreBuiltin<string name, InstructionSet set, int number> {
string Name = name;
InstructionSet Set = set;
bits<32> Number = number;
bits<32> ElementCount = !cond(!not(!eq(!find(name, "2"), -1)) : 2,
!not(!eq(!find(name, "3"), -1)) : 3,
!not(!eq(!find(name, "4"), -1)) : 4,
!not(!eq(!find(name, "8"), -1)) : 8,
!not(!eq(!find(name, "16"), -1)) : 16,
true : 1);
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
!not(!eq(!find(name, "_rtn"), -1)) : RTN,
true : RTE);
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
!not(!eq(!find(name, "_rtn"), -1)) : RTN,
true : RTE);
}

// Table gathering all the vector data load/store builtins.
def VectorLoadStoreBuiltins : GenericTable {
let FilterClass = "VectorLoadStoreBuiltin";
let Fields = ["Name", "Set", "Number", "IsRounded", "RoundingMode"];
let Fields = ["Name", "Set", "Number", "ElementCount", "IsRounded", "RoundingMode"];
string TypeOf_Set = "InstructionSet";
string TypeOf_RoundingMode = "FPRoundingMode";
}
Expand Down
27 changes: 13 additions & 14 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,25 +290,14 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
Value *Pointer;
Type *ExpectedElementType;
unsigned OperandToReplace;
bool AllowCastingToChar = false;

StoreInst *SI = dyn_cast<StoreInst>(I);
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
SI->getValueOperand()->getType()->isPointerTy() &&
isa<Argument>(SI->getValueOperand())) {
Argument *Arg = cast<Argument>(SI->getValueOperand());
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
if (!ArgType || ArgType->getString().starts_with("uchar*"))
return;

// Handle the special case where the StoreInst's value operand is a kernel
// argument of a pointer type. Since these arguments could have either a basic
// element type (e.g. float*) or an OpenCL builtin type (sampler_t), bitcast
// the StoreInst's value operand to the default pointer element type (i8).
Pointer = Arg;
Pointer = SI->getValueOperand();
ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
OperandToReplace = 0;
AllowCastingToChar = true;
} else if (SI) {
Pointer = SI->getPointerOperand();
ExpectedElementType = SI->getValueOperand()->getType();
Expand Down Expand Up @@ -390,10 +379,20 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
}

// Do not emit spv_ptrcast if it would cast to the default pointer element
// type (i8) of the same address space.
if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
// type (i8) of the same address space. In case of OpenCL kernels, make sure
// i8 is the pointer element type defined for the given kernel argument.
if (ExpectedElementType->isIntegerTy(8) &&
F->getCallingConv() != CallingConv::SPIR_KERNEL)
return;

Argument *Arg = dyn_cast<Argument>(Pointer);
if (ExpectedElementType->isIntegerTy(8) &&
F->getCallingConv() == CallingConv::SPIR_KERNEL && Arg) {
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
if (ArgType && ArgType->getString().starts_with("uchar*"))
return;
}

// If this would be the first spv_ptrcast, the pointer's defining instruction
// requires spv_assign_ptr_type and does not already have one, do not emit
// spv_ptrcast and emit spv_assign_ptr_type instead.
Expand Down
95 changes: 0 additions & 95 deletions llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll

This file was deleted.

40 changes: 40 additions & 0 deletions llvm/test/CodeGen/SPIRV/opencl/vload2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; This test is only intended to check the vloadn builtin name resolution.
; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.

; CHECK-DAG: %[[#IMPORT:]] = OpExtInstImport "OpenCL.std"

; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
; CHECK-DAG: %[[#INT16:]] = OpTypeInt 16 0
; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
; CHECK-DAG: %[[#FLOAT:]] = OpTypeFloat 32
; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
; CHECK-DAG: %[[#VINT16:]] = OpTypeVector %[[#INT16]] 2
; CHECK-DAG: %[[#VINT32:]] = OpTypeVector %[[#INT32]] 2
; CHECK-DAG: %[[#VINT64:]] = OpTypeVector %[[#INT64]] 2
; CHECK-DAG: %[[#VFLOAT:]] = OpTypeVector %[[#FLOAT]] 2
; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]

; CHECK: %[[#OFFSET:]] = OpFunctionParameter %[[#INT64]]
; CHECK: %[[#ADDRESS:]] = OpFunctionParameter %[[#PTRINT8]]

; Kernel that calls the vload2 builtin declarations on a global
; (addrspace(1)) pointer, once per result type: <2 x i8>, <2 x i16>,
; <2 x i32>, <2 x i64> and <2 x float>. Every call passes the same %offset and
; %address. The FileCheck directives expect each call to become an OpExtInst
; "vloadn" from the OpenCL.std import whose result type is the matching
; 2-element vector and whose operands are the offset, the address and a
; trailing literal 2 (the vector element count).
define spir_kernel void @test_fn(i64 %offset, ptr addrspace(1) %address) {
; CHECK: %[[#]] = OpExtInst %[[#VINT8]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
%call1 = call spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64 %offset, ptr addrspace(1) %address)
; CHECK: %[[#]] = OpExtInst %[[#VINT16]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
%call2 = call spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64 %offset, ptr addrspace(1) %address)
; CHECK: %[[#]] = OpExtInst %[[#VINT32]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
%call3 = call spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64 %offset, ptr addrspace(1) %address)
; CHECK: %[[#]] = OpExtInst %[[#VINT64]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
%call4 = call spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64 %offset, ptr addrspace(1) %address)
; CHECK: %[[#]] = OpExtInst %[[#VFLOAT]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
%call5 = call spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64 %offset, ptr addrspace(1) %address)
ret void
}

declare spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64, ptr addrspace(1))
declare spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64, ptr addrspace(1))
declare spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64, ptr addrspace(1))
declare spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64, ptr addrspace(1))
declare spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64, ptr addrspace(1))
31 changes: 31 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/getelementptr-kernel-arg-char.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer Workgroup %[[#INT8]]
; CHECK-DAG: %[[#PTRVINT8:]] = OpTypePointer Workgroup %[[#VINT8]]
; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#INT64]] 1

; CHECK: %[[#PARAM1:]] = OpFunctionParameter %[[#PTRVINT8]]
; Kernel whose local (addrspace(3)) pointer argument is tagged "char2*"
; through the !kernel_arg_type metadata node !1. The argument goes through a
; pointer-to-same-pointer bitcast and is then indexed by an inbounds i8 GEP
; with offset 1. The FileCheck directives expect the parameter to be bitcast
; to the Workgroup i8 pointer type before an OpInBoundsPtrAccessChain on the
; bitcast result.
define spir_kernel void @test1(ptr addrspace(3) %address) !kernel_arg_type !1 {
; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#PTRINT8]] %[[#PARAM1]]
; CHECK: %[[#]] = OpInBoundsPtrAccessChain %[[#PTRINT8]] %[[#BITCAST1]] %[[#CONST]]
%cast = bitcast ptr addrspace(3) %address to ptr addrspace(3)
%gep = getelementptr inbounds i8, ptr addrspace(3) %cast, i64 1
ret void
}

; CHECK: %[[#PARAM2:]] = OpFunctionParameter %[[#PTRVINT8]]
; Kernel whose local (addrspace(3)) pointer argument is tagged "char2*"
; through the !kernel_arg_type metadata node !1. Here the inbounds i8 GEP with
; offset 1 is applied to the argument directly, with no bitcast in the IR.
; The FileCheck directives still expect an OpBitcast of the parameter to the
; Workgroup i8 pointer type ahead of the OpInBoundsPtrAccessChain.
define spir_kernel void @test2(ptr addrspace(3) %address) !kernel_arg_type !1 {
; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#PTRINT8]] %[[#PARAM2]]
; CHECK: %[[#]] = OpInBoundsPtrAccessChain %[[#PTRINT8]] %[[#BITCAST2]] %[[#CONST]]
%gep = getelementptr inbounds i8, ptr addrspace(3) %address, i64 1
ret void
}

declare spir_func <2 x i8> @_Z6vload2mPU3AS3Kc(i64, ptr addrspace(3))

!1 = !{!"char2*"}

0 comments on commit 03203b7

Please sign in to comment.