Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Fix vloadn OpenCL builtin lowering #81148

Merged
merged 3 commits into from
Feb 21, 2024

Conversation

michalpaszkowski
Copy link
Member

@michalpaszkowski michalpaszkowski commented Feb 8, 2024

This pull request fixes an issue with missing vector element count immediate in OpExtInst calls and adds a case for generating bitcasts before GEPs for kernel arguments of non-matching pointer type. The new LITs are based on basic/vload_local and basic/vload_global OpenCL CTS tests. The tests after this change pass SPIR-V validation.

@llvmbot
Copy link
Collaborator

llvmbot commented Feb 8, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Michal Paszkowski (michalpaszkowski)

Changes

This pull request fixes an issue with missing vector element count immediate in OpExtInst calls such as:

%call = OpExtInst %v2uchar %1 vloadn %conv1 %add_ptr 2

Full diff: https://github.com/llvm/llvm-project/pull/81148.diff

4 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+3-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+11-5)
  • (removed) llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll (-95)
  • (added) llvm/test/CodeGen/SPIRV/opencl/vload2.ll (+40)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index e4593e7db90e8b..7a83ea77f199f8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -114,6 +114,7 @@ struct VectorLoadStoreBuiltin {
   StringRef Name;
   InstructionSet::InstructionSet Set;
   uint32_t Number;
+  uint32_t ElementCount;
   bool IsRounded;
   FPRoundingMode::FPRoundingMode RoundingMode;
 };
@@ -1851,7 +1852,8 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
           .addImm(Builtin->Number);
   for (auto Argument : Call->Arguments)
     MIB.addUse(Argument);
-
+  MIB.addImm(Builtin->ElementCount);
+  
   // Rounding mode should be passed as a last argument in the MI for builtins
   // like "vstorea_halfn_r".
   if (Builtin->IsRounded)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 8acd4691787e4c..63ca0a909b69c3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1046,18 +1046,24 @@ class VectorLoadStoreBuiltin<string name, InstructionSet set, int number> {
   string Name = name;
   InstructionSet Set = set;
   bits<32> Number = number;
+  bits<32> ElementCount = !cond(!not(!eq(!find(name, "2"), -1)) : 2,
+                                !not(!eq(!find(name, "3"), -1)) : 3,
+                                !not(!eq(!find(name, "4"), -1)) : 4,
+                                !not(!eq(!find(name, "8"), -1)) : 8,
+                                !not(!eq(!find(name, "16"), -1)) : 16,
+                                true : 1);
   bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
   FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
-                                  !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
-                                  !not(!eq(!find(name, "_rtp"), -1)) : RTP,
-                                  !not(!eq(!find(name, "_rtn"), -1)) : RTN,
-                                  true : RTE);
+                                      !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
+                                      !not(!eq(!find(name, "_rtp"), -1)) : RTP,
+                                      !not(!eq(!find(name, "_rtn"), -1)) : RTN,
+                                      true : RTE);
 }
 
 // Table gathering all the vector data load/store builtins.
 def VectorLoadStoreBuiltins : GenericTable {
   let FilterClass = "VectorLoadStoreBuiltin";
-  let Fields = ["Name", "Set", "Number", "IsRounded", "RoundingMode"];
+  let Fields = ["Name", "Set", "Number", "ElementCount", "IsRounded", "RoundingMode"];
   string TypeOf_Set = "InstructionSet";
   string TypeOf_RoundingMode = "FPRoundingMode";
 }
diff --git a/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll b/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll
deleted file mode 100644
index 40f1d59e4365e1..00000000000000
--- a/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll
+++ /dev/null
@@ -1,95 +0,0 @@
-; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
-
-; TODO(#60133): Requires updates following opaque pointer migration.
-; XFAIL: *
-
-; CHECK: %[[#i16_ty:]] = OpTypeInt 16 0
-; CHECK: %[[#v4xi16_ty:]] = OpTypeVector %[[#i16_ty]] 4
-; CHECK: %[[#pv4xi16_ty:]] = OpTypePointer Function %[[#v4xi16_ty]]
-; CHECK: %[[#i16_const0:]] = OpConstant %[[#i16_ty]] 0
-; CHECK: %[[#i16_undef:]] = OpUndef %[[#i16_ty]]
-; CHECK: %[[#comp_const:]] = OpConstantComposite %[[#v4xi16_ty]] %[[#i16_const0]] %[[#i16_const0]] %[[#i16_const0]] %[[#i16_undef]]
-
-; CHECK: %[[#r:]] = OpInBoundsPtrAccessChain
-; CHECK: %[[#r2:]] = OpBitcast %[[#pv4xi16_ty]] %[[#r]]
-; CHECK: OpStore %[[#r2]] %[[#comp_const]] Aligned 8
-
-define spir_kernel void @test_fn(i16 addrspace(1)* %srcValues, i32 addrspace(1)* %offsets, <3 x i16> addrspace(1)* %destBuffer, i32 %alignmentOffset) {
-entry:
-  %sPrivateStorage = alloca [42 x <3 x i16>], align 8
-  %0 = bitcast [42 x <3 x i16>]* %sPrivateStorage to i8*
-  %1 = bitcast i8* %0 to i8*
-  call void @llvm.lifetime.start.p0i8(i64 336, i8* %1)
-  %2 = call spir_func <3 x i64> @BuiltInGlobalInvocationId()
-  %call = extractelement <3 x i64> %2, i32 0
-  %conv = trunc i64 %call to i32
-  %idxprom = sext i32 %conv to i64
-  %arrayidx = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 %idxprom
-  %storetmp = bitcast <3 x i16>* %arrayidx to <4 x i16>*
-  store <4 x i16> <i16 0, i16 0, i16 0, i16 undef>, <4 x i16>* %storetmp, align 8
-  %conv1 = sext i32 %conv to i64
-  %call2 = call spir_func <3 x i16> @OpenCL_vload3_i64_p1i16_i32(i64 %conv1, i16 addrspace(1)* %srcValues, i32 3)
-  %idxprom3 = sext i32 %conv to i64
-  %arrayidx4 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom3
-  %3 = load i32, i32 addrspace(1)* %arrayidx4, align 4
-  %conv5 = zext i32 %3 to i64
-  %arraydecay = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 0
-  %4 = bitcast <3 x i16>* %arraydecay to i16*
-  %idx.ext = zext i32 %alignmentOffset to i64
-  %add.ptr = getelementptr inbounds i16, i16* %4, i64 %idx.ext
-  call spir_func void @OpenCL_vstore3_v3i16_i64_p0i16(<3 x i16> %call2, i64 %conv5, i16* %add.ptr)
-  %arraydecay6 = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 0
-  %5 = bitcast <3 x i16>* %arraydecay6 to i16*
-  %idxprom7 = sext i32 %conv to i64
-  %arrayidx8 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom7
-  %6 = load i32, i32 addrspace(1)* %arrayidx8, align 4
-  %mul = mul i32 3, %6
-  %idx.ext9 = zext i32 %mul to i64
-  %add.ptr10 = getelementptr inbounds i16, i16* %5, i64 %idx.ext9
-  %idx.ext11 = zext i32 %alignmentOffset to i64
-  %add.ptr12 = getelementptr inbounds i16, i16* %add.ptr10, i64 %idx.ext11
-  %7 = bitcast <3 x i16> addrspace(1)* %destBuffer to i16 addrspace(1)*
-  %idxprom13 = sext i32 %conv to i64
-  %arrayidx14 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom13
-  %8 = load i32, i32 addrspace(1)* %arrayidx14, align 4
-  %mul15 = mul i32 3, %8
-  %idx.ext16 = zext i32 %mul15 to i64
-  %add.ptr17 = getelementptr inbounds i16, i16 addrspace(1)* %7, i64 %idx.ext16
-  %idx.ext18 = zext i32 %alignmentOffset to i64
-  %add.ptr19 = getelementptr inbounds i16, i16 addrspace(1)* %add.ptr17, i64 %idx.ext18
-  br label %for.cond
-
-for.cond:                                         ; preds = %for.inc, %entry
-  %i.0 = phi i32 [ 0, %entry ], [ %inc, %for.inc ]
-  %cmp = icmp ult i32 %i.0, 3
-  br i1 %cmp, label %for.body, label %for.end
-
-for.body:                                         ; preds = %for.cond
-  %idxprom21 = zext i32 %i.0 to i64
-  %arrayidx22 = getelementptr inbounds i16, i16* %add.ptr12, i64 %idxprom21
-  %9 = load i16, i16* %arrayidx22, align 2
-  %idxprom23 = zext i32 %i.0 to i64
-  %arrayidx24 = getelementptr inbounds i16, i16 addrspace(1)* %add.ptr19, i64 %idxprom23
-  store i16 %9, i16 addrspace(1)* %arrayidx24, align 2
-  br label %for.inc
-
-for.inc:                                          ; preds = %for.body
-  %inc = add i32 %i.0, 1
-  br label %for.cond
-
-for.end:                                          ; preds = %for.cond
-  %10 = bitcast [42 x <3 x i16>]* %sPrivateStorage to i8*
-  %11 = bitcast i8* %10 to i8*
-  call void @llvm.lifetime.end.p0i8(i64 336, i8* %11)
-  ret void
-}
-
-declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture)
-
-declare spir_func <3 x i16> @OpenCL_vload3_i64_p1i16_i32(i64, i16 addrspace(1)*, i32)
-
-declare spir_func void @OpenCL_vstore3_v3i16_i64_p0i16(<3 x i16>, i64, i16*)
-
-declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture)
-
-declare spir_func <3 x i64> @BuiltInGlobalInvocationId()
diff --git a/llvm/test/CodeGen/SPIRV/opencl/vload2.ll b/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
new file mode 100644
index 00000000000000..f7d380b96a3ef0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
@@ -0,0 +1,40 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; This test only itends to check the vloadn builtin lowering. 
+; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.
+
+; CHECK-DAG: %[[#IMPORT:]] = OpExtInstImport "OpenCL.std"
+
+; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#INT16:]] = OpTypeInt 16 0
+; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#FLOAT:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
+; CHECK-DAG: %[[#VINT16:]] = OpTypeVector %[[#INT16]] 2
+; CHECK-DAG: %[[#VINT32:]] = OpTypeVector %[[#INT32]] 2
+; CHECK-DAG: %[[#VINT64:]] = OpTypeVector %[[#INT64]] 2
+; CHECK-DAG: %[[#VFLOAT:]] = OpTypeVector %[[#FLOAT]] 2
+; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]
+
+; CHECK: %[[#OFFSET:]] = OpFunctionParameter %[[#INT64]]
+; CHECK: %[[#ADDRESS:]] = OpFunctionParameter %[[#PTRINT8]]
+
+define spir_kernel void @test_fn(i64 %offset, ptr addrspace(1) %address) {
+; CHECK: %[[#]] = OpExtInst %[[#VINT8]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call1 = call spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT16]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call2 = call spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT32]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call3 = call spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT64]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call4 = call spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VFLOAT]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call5 = call spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64 %offset, ptr addrspace(1) %address)
+  ret void
+}
+
+declare spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64, ptr addrspace(1))
+declare spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64, ptr addrspace(1))
+declare spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64, ptr addrspace(1))
+declare spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64, ptr addrspace(1))
+declare spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64, ptr addrspace(1))

Copy link

github-actions bot commented Feb 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@michalpaszkowski michalpaszkowski changed the title [SPIR-V] Explicitly emit vector element count for OpenCL vloadn calls [SPIR-V] Fix vloadn OpenCL builtin lowering Feb 10, 2024
Copy link
Contributor

@iliya-diyachkov iliya-diyachkov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The patch looks good to me.

@@ -0,0 +1,40 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: why not running spirv-val on this too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a comment below. It looks like the SPIR-V validator is doing additional checks on the arguments of those OpExtInst calls and those are not valid here. I want this test to only check the vloadn builtin name resolution to keep the tests as small as possible.

; This test only intends to check the vloadn builtin name resolution.
; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.

@michalpaszkowski michalpaszkowski merged commit 03203b7 into llvm:main Feb 21, 2024
4 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants