Skip to content

Commit 2c6adc9

Browse files
authored
[Clang] Add vector gather / scatter builtins to clang (#157895)
Summary: This patch exposes `__builtin_masked_gather` and `__builtin_masked_scatter` to clang. These map to the underlying LLVM intrinsics relatively cleanly, needing only one level of indirection: a base pointer plus a vector of integer indices is turned into the vector of pointers that the intrinsics expect.
1 parent 45a0843 commit 2c6adc9

File tree

7 files changed

+259
-3
lines changed

7 files changed

+259
-3
lines changed

clang/docs/LanguageExtensions.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,11 @@ builtins have the same interface but store the result in consecutive indices.
957957
Effectively this performs the ``if (mask[i]) val[i] = ptr[j++]`` and ``if
958958
(mask[i]) ptr[j++] = val[i]`` pattern respectively.
959959

960+
The ``__builtin_masked_gather`` and ``__builtin_masked_scatter`` builtins handle
961+
non-sequential memory access for vector types. These use a base pointer and a
962+
vector of integer indices to gather memory into a vector type or scatter it to
963+
separate indices.
964+
960965
Example:
961966

962967
.. code-block:: c++
@@ -978,6 +983,14 @@ Example:
978983
__builtin_masked_compress_store(mask, val, ptr);
979984
}
980985
986+
v8i gather(v8b mask, v8i idx, int *ptr) {
987+
return __builtin_masked_gather(mask, idx, ptr);
988+
}
989+
990+
void scatter(v8b mask, v8i val, v8i idx, int *ptr) {
991+
__builtin_masked_scatter(mask, idx, val, ptr);
992+
}
993+
981994

982995
Matrix Types
983996
============

clang/docs/ReleaseNotes.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,10 @@ Non-comprehensive list of changes in this release
213213
conditional memory loads from vectors. Binds to the LLVM intrinsics of the
214214
same name.
215215

216+
- Added ``__builtin_masked_gather`` and ``__builtin_masked_scatter`` for
217+
conditional gathering and scattering operations on vectors. Binds to the LLVM
218+
intrinsics of the same name.
219+
216220
- The ``__builtin_popcountg``, ``__builtin_ctzg``, and ``__builtin_clzg``
217221
functions now accept fixed-size boolean vectors.
218222

clang/include/clang/Basic/Builtins.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,18 @@ def MaskedCompressStore : Builtin {
12561256
let Prototype = "void(...)";
12571257
}
12581258

1259+
// __builtin_masked_gather(mask, idx, ptr[, passthru])
// Conditionally loads ptr[idx[i]] for each enabled mask lane; lowered to
// llvm.masked.gather. CustomTypeChecking: the result vector type is derived
// in Sema from the scalar pointee, so the prototype is variadic. Pure: reads
// memory but has no other side effects.
def MaskedGather : Builtin {
  let Spellings = ["__builtin_masked_gather"];
  let Attributes = [NoThrow, Pure, CustomTypeChecking];
  let Prototype = "void(...)";
}

// __builtin_masked_scatter(mask, idx, val, ptr)
// Conditionally stores val[i] to ptr[idx[i]] for each enabled mask lane;
// lowered to llvm.masked.scatter. Writes memory, so deliberately not Pure.
def MaskedScatter : Builtin {
  let Spellings = ["__builtin_masked_scatter"];
  let Attributes = [NoThrow, CustomTypeChecking];
  let Prototype = "void(...)";
}
1270+
12591271
def AllocaUninitialized : Builtin {
12601272
let Spellings = ["__builtin_alloca_uninitialized"];
12611273
let Attributes = [FunctionWithBuiltinPrefix, NoThrow];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4298,6 +4298,33 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
42984298
}
42994299
return RValue::get(Result);
43004300
};
4301+
case Builtin::BI__builtin_masked_gather: {
4302+
llvm::Value *Mask = EmitScalarExpr(E->getArg(0));
4303+
llvm::Value *Idx = EmitScalarExpr(E->getArg(1));
4304+
llvm::Value *Ptr = EmitScalarExpr(E->getArg(2));
4305+
4306+
llvm::Type *RetTy = CGM.getTypes().ConvertType(E->getType());
4307+
CharUnits Align = CGM.getNaturalTypeAlignment(
4308+
E->getType()->getAs<VectorType>()->getElementType(), nullptr);
4309+
llvm::Value *AlignVal =
4310+
llvm::ConstantInt::get(Int32Ty, Align.getQuantity());
4311+
4312+
llvm::Value *PassThru = llvm::PoisonValue::get(RetTy);
4313+
if (E->getNumArgs() > 3)
4314+
PassThru = EmitScalarExpr(E->getArg(3));
4315+
4316+
llvm::Type *ElemTy = CGM.getTypes().ConvertType(
4317+
E->getType()->getAs<VectorType>()->getElementType());
4318+
llvm::Value *PtrVec = Builder.CreateGEP(ElemTy, Ptr, Idx);
4319+
4320+
llvm::Value *Result;
4321+
Function *F =
4322+
CGM.getIntrinsic(Intrinsic::masked_gather, {RetTy, PtrVec->getType()});
4323+
4324+
Result = Builder.CreateCall(F, {PtrVec, AlignVal, Mask, PassThru},
4325+
"masked_gather");
4326+
return RValue::get(Result);
4327+
}
43014328
case Builtin::BI__builtin_masked_store:
43024329
case Builtin::BI__builtin_masked_compress_store: {
43034330
llvm::Value *Mask = EmitScalarExpr(E->getArg(0));
@@ -4323,7 +4350,28 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
43234350
}
43244351
return RValue::get(nullptr);
43254352
}
4353+
case Builtin::BI__builtin_masked_scatter: {
4354+
llvm::Value *Mask = EmitScalarExpr(E->getArg(0));
4355+
llvm::Value *Idx = EmitScalarExpr(E->getArg(1));
4356+
llvm::Value *Val = EmitScalarExpr(E->getArg(2));
4357+
llvm::Value *Ptr = EmitScalarExpr(E->getArg(3));
43264358

4359+
CharUnits Align = CGM.getNaturalTypeAlignment(
4360+
E->getArg(2)->getType()->getAs<VectorType>()->getElementType(),
4361+
nullptr);
4362+
llvm::Value *AlignVal =
4363+
llvm::ConstantInt::get(Int32Ty, Align.getQuantity());
4364+
4365+
llvm::Type *ElemTy = CGM.getTypes().ConvertType(
4366+
E->getArg(1)->getType()->getAs<VectorType>()->getElementType());
4367+
llvm::Value *PtrVec = Builder.CreateGEP(ElemTy, Ptr, Idx);
4368+
4369+
Function *F = CGM.getIntrinsic(Intrinsic::masked_scatter,
4370+
{Val->getType(), PtrVec->getType()});
4371+
4372+
Builder.CreateCall(F, {Val, PtrVec, AlignVal, Mask});
4373+
return RValue();
4374+
}
43274375
case Builtin::BI__builtin_isinf_sign: {
43284376
// isinf_sign(x) -> fabs(x) == infinity ? (signbit(x) ? -1 : 1) : 0
43294377
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(*this, E);

clang/lib/Sema/SemaChecking.cpp

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2268,17 +2268,19 @@ static bool BuiltinCountZeroBitsGeneric(Sema &S, CallExpr *TheCall) {
22682268
}
22692269

22702270
static bool CheckMaskedBuiltinArgs(Sema &S, Expr *MaskArg, Expr *PtrArg,
2271-
unsigned Pos) {
2271+
unsigned Pos, bool Vector = true) {
22722272
QualType MaskTy = MaskArg->getType();
22732273
if (!MaskTy->isExtVectorBoolType())
22742274
return S.Diag(MaskArg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
22752275
<< 1 << /* vector of */ 4 << /* booleans */ 6 << /* no fp */ 0
22762276
<< MaskTy;
22772277

22782278
QualType PtrTy = PtrArg->getType();
2279-
if (!PtrTy->isPointerType() || !PtrTy->getPointeeType()->isVectorType())
2279+
if (!PtrTy->isPointerType() ||
2280+
(Vector && !PtrTy->getPointeeType()->isVectorType()) ||
2281+
(!Vector && PtrTy->getPointeeType()->isVectorType()))
22802282
return S.Diag(PtrArg->getExprLoc(), diag::err_vec_masked_load_store_ptr)
2281-
<< Pos << "pointer to vector";
2283+
<< Pos << (Vector ? "pointer to vector" : "scalar pointer");
22822284
return false;
22832285
}
22842286

@@ -2359,6 +2361,101 @@ static ExprResult BuiltinMaskedStore(Sema &S, CallExpr *TheCall) {
23592361
return TheCall;
23602362
}
23612363

2364+
/// Semantic checking for __builtin_masked_gather(mask, idx, ptr[, passthru]).
///
/// Requires: 3 or 4 arguments; 'mask' an ext-vector of booleans and 'ptr' a
/// scalar (non-vector) pointer (both delegated to CheckMaskedBuiltinArgs);
/// 'idx' an integer ext-vector with the same lane count as 'mask'; and an
/// optional 'passthru' whose type is exactly the result type. On success the
/// call's type is set to an ext-vector of the pointee type, one element per
/// mask lane.
static ExprResult BuiltinMaskedGather(Sema &S, CallExpr *TheCall) {
  if (S.checkArgCountRange(TheCall, 3, 4))
    return ExprError();

  Expr *Mask = TheCall->getArg(0);
  Expr *Idx = TheCall->getArg(1);
  Expr *Ptr = TheCall->getArg(2);

  if (CheckMaskedBuiltinArgs(S, Mask, Ptr, 3, /*Vector=*/false))
    return ExprError();

  // The index operand must be an integer ext-vector.
  // NOTE(review): this diagnostic is anchored on the mask argument and reports
  // "1st argument" even though it describes the 2nd (index) argument. The lit
  // tests in this commit codify that wording, so fixing the position requires
  // updating them in the same change.
  QualType IdxTy = Idx->getType();
  const VectorType *IdxVec = IdxTy->getAs<VectorType>();
  if (!IdxTy->isExtVectorType() || !IdxVec->getElementType()->isIntegerType())
    return S.Diag(Mask->getBeginLoc(), diag::err_builtin_invalid_arg_type)
           << 1 << /* vector of */ 4 << /* integer */ 1 << /* no fp */ 0
           << IdxTy;

  QualType MaskTy = Mask->getType();
  const VectorType *MaskVec = MaskTy->getAs<VectorType>();
  unsigned NumElts = MaskVec->getNumElements();

  // Mask and index vectors must agree on the number of lanes.
  if (NumElts != IdxVec->getNumElements())
    return ExprError(
        S.Diag(TheCall->getBeginLoc(), diag::err_vec_masked_load_store_size)
        << S.getASTContext().BuiltinInfo.getQuotedName(
               TheCall->getBuiltinCallee())
        << MaskTy << IdxTy);

  // The result gathers one pointee element per lane.
  QualType PointeeTy = Ptr->getType()->getPointeeType();
  QualType RetTy = S.Context.getExtVectorType(PointeeTy, NumElts);

  // The optional 4th argument is the pass-through value for disabled lanes;
  // it must have exactly the result vector type.
  if (TheCall->getNumArgs() == 4) {
    Expr *PassThru = TheCall->getArg(3);
    if (!S.Context.hasSameType(PassThru->getType(), RetTy))
      return S.Diag(PassThru->getExprLoc(),
                    diag::err_vec_masked_load_store_ptr)
             << /* fourth argument */ 4 << RetTy;
  }

  TheCall->setType(RetTy);
  return TheCall;
}
2406+
2407+
/// Semantic checking for __builtin_masked_scatter(mask, idx, val, ptr).
///
/// Requires: exactly 4 arguments; 'mask' an ext-vector of booleans and 'ptr'
/// a scalar (non-vector) pointer (both delegated to CheckMaskedBuiltinArgs);
/// 'idx' an integer ext-vector; all three vectors with the same lane count;
/// and 'val' of exactly "ext-vector of pointee type". The call's type is set
/// to void on success.
static ExprResult BuiltinMaskedScatter(Sema &S, CallExpr *TheCall) {
  if (S.checkArgCount(TheCall, 4))
    return ExprError();

  Expr *Mask = TheCall->getArg(0);
  Expr *Idx = TheCall->getArg(1);
  Expr *Val = TheCall->getArg(2);
  Expr *Ptr = TheCall->getArg(3);

  // NOTE(review): the pointer is the 4th argument of this builtin but is
  // reported as argument 3 here; the lit tests expect "3rd argument", so any
  // correction must update both together.
  if (CheckMaskedBuiltinArgs(S, Mask, Ptr, 3, /*Vector=*/false))
    return ExprError();

  // The index operand must be an integer ext-vector.
  // NOTE(review): diagnostic anchored on the mask argument rather than the
  // index argument — both sit on the same call line, so -verify tests still
  // match, but the caret points at the wrong expression.
  QualType IdxTy = Idx->getType();
  const VectorType *IdxVec = IdxTy->getAs<VectorType>();
  if (!IdxTy->isExtVectorType() || !IdxVec->getElementType()->isIntegerType())
    return S.Diag(Mask->getBeginLoc(), diag::err_builtin_invalid_arg_type)
           << 2 << /* vector of */ 4 << /* integer */ 1 << /* no fp */ 0
           << IdxTy;

  QualType MaskTy = Mask->getType();
  QualType ValTy = Val->getType();
  const VectorType *MaskVec = MaskTy->castAs<VectorType>();
  const VectorType *ValVec = ValTy->castAs<VectorType>();

  // Every vector operand must have the same number of lanes as the mask.
  if (MaskVec->getNumElements() != IdxVec->getNumElements())
    return ExprError(
        S.Diag(TheCall->getBeginLoc(), diag::err_vec_masked_load_store_size)
        << S.getASTContext().BuiltinInfo.getQuotedName(
               TheCall->getBuiltinCallee())
        << MaskTy << IdxTy);
  if (MaskVec->getNumElements() != ValVec->getNumElements())
    return ExprError(
        S.Diag(TheCall->getBeginLoc(), diag::err_vec_masked_load_store_size)
        << S.getASTContext().BuiltinInfo.getQuotedName(
               TheCall->getBuiltinCallee())
        << MaskTy << ValTy);

  // The value vector must be exactly one pointee element per lane.
  QualType PointeeTy = Ptr->getType()->getPointeeType();
  QualType ExpectedValTy =
      S.Context.getExtVectorType(PointeeTy, MaskVec->getNumElements());
  if (!S.Context.hasSameType(ValTy, ExpectedValTy))
    return ExprError(S.Diag(TheCall->getBeginLoc(),
                            diag::err_vec_builtin_incompatible_vector)
                     << TheCall->getDirectCallee() << /*isMoreThanTwoArgs*/ 2
                     << SourceRange(TheCall->getArg(1)->getBeginLoc(),
                                    TheCall->getArg(1)->getEndLoc()));

  TheCall->setType(S.Context.VoidTy);
  return TheCall;
}
2458+
23622459
static ExprResult BuiltinInvoke(Sema &S, CallExpr *TheCall) {
23632460
SourceLocation Loc = TheCall->getBeginLoc();
23642461
MutableArrayRef Args(TheCall->getArgs(), TheCall->getNumArgs());
@@ -2617,6 +2714,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
26172714
case Builtin::BI__builtin_masked_store:
26182715
case Builtin::BI__builtin_masked_compress_store:
26192716
return BuiltinMaskedStore(*this, TheCall);
2717+
case Builtin::BI__builtin_masked_gather:
2718+
return BuiltinMaskedGather(*this, TheCall);
2719+
case Builtin::BI__builtin_masked_scatter:
2720+
return BuiltinMaskedScatter(*this, TheCall);
26202721
case Builtin::BI__builtin_invoke:
26212722
return BuiltinInvoke(*this, TheCall);
26222723
case Builtin::BI__builtin_prefetch:

clang/test/CodeGen/builtin-masked.c

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,61 @@ void test_store(v8b m, v8i v, v8i *p) {
129129
void test_compress_store(v8b m, v8i v, v8i *p) {
130130
__builtin_masked_compress_store(m, v, p);
131131
}
132+
133+
// CHECK-LABEL: define dso_local <8 x i32> @test_gather(
134+
// CHECK-SAME: i8 noundef [[MASK_COERCE:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP0:%.*]], ptr noundef [[PTR:%.*]]) #[[ATTR0]] {
135+
// CHECK-NEXT: [[ENTRY:.*:]]
136+
// CHECK-NEXT: [[MASK:%.*]] = alloca i8, align 1
137+
// CHECK-NEXT: [[MASK_ADDR:%.*]] = alloca i8, align 1
138+
// CHECK-NEXT: [[IDX_ADDR:%.*]] = alloca <8 x i32>, align 32
139+
// CHECK-NEXT: [[PTR_ADDR:%.*]] = alloca ptr, align 8
140+
// CHECK-NEXT: store i8 [[MASK_COERCE]], ptr [[MASK]], align 1
141+
// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[MASK]], align 1
142+
// CHECK-NEXT: [[MASK1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
143+
// CHECK-NEXT: [[IDX:%.*]] = load <8 x i32>, ptr [[TMP0]], align 32
144+
// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[MASK1]] to i8
145+
// CHECK-NEXT: store i8 [[TMP1]], ptr [[MASK_ADDR]], align 1
146+
// CHECK-NEXT: store <8 x i32> [[IDX]], ptr [[IDX_ADDR]], align 32
147+
// CHECK-NEXT: store ptr [[PTR]], ptr [[PTR_ADDR]], align 8
148+
// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[MASK_ADDR]], align 1
149+
// CHECK-NEXT: [[TMP2:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
150+
// CHECK-NEXT: [[TMP3:%.*]] = load <8 x i32>, ptr [[IDX_ADDR]], align 32
151+
// CHECK-NEXT: [[TMP4:%.*]] = load ptr, ptr [[PTR_ADDR]], align 8
152+
// CHECK-NEXT: [[TMP5:%.*]] = getelementptr i32, ptr [[TMP4]], <8 x i32> [[TMP3]]
153+
// CHECK-NEXT: [[MASKED_GATHER:%.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> [[TMP5]], i32 4, <8 x i1> [[TMP2]], <8 x i32> poison)
154+
// CHECK-NEXT: ret <8 x i32> [[MASKED_GATHER]]
155+
//
156+
v8i test_gather(v8b mask, v8i idx, int *ptr) {
157+
return __builtin_masked_gather(mask, idx, ptr);
158+
}
159+
160+
// CHECK-LABEL: define dso_local void @test_scatter(
161+
// CHECK-SAME: i8 noundef [[MASK_COERCE:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP0:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP1:%.*]], ptr noundef [[PTR:%.*]]) #[[ATTR3]] {
162+
// CHECK-NEXT: [[ENTRY:.*:]]
163+
// CHECK-NEXT: [[MASK:%.*]] = alloca i8, align 1
164+
// CHECK-NEXT: [[MASK_ADDR:%.*]] = alloca i8, align 1
165+
// CHECK-NEXT: [[VAL_ADDR:%.*]] = alloca <8 x i32>, align 32
166+
// CHECK-NEXT: [[IDX_ADDR:%.*]] = alloca <8 x i32>, align 32
167+
// CHECK-NEXT: [[PTR_ADDR:%.*]] = alloca ptr, align 8
168+
// CHECK-NEXT: store i8 [[MASK_COERCE]], ptr [[MASK]], align 1
169+
// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[MASK]], align 1
170+
// CHECK-NEXT: [[MASK1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
171+
// CHECK-NEXT: [[VAL:%.*]] = load <8 x i32>, ptr [[TMP0]], align 32
172+
// CHECK-NEXT: [[IDX:%.*]] = load <8 x i32>, ptr [[TMP1]], align 32
173+
// CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i1> [[MASK1]] to i8
174+
// CHECK-NEXT: store i8 [[TMP2]], ptr [[MASK_ADDR]], align 1
175+
// CHECK-NEXT: store <8 x i32> [[VAL]], ptr [[VAL_ADDR]], align 32
176+
// CHECK-NEXT: store <8 x i32> [[IDX]], ptr [[IDX_ADDR]], align 32
177+
// CHECK-NEXT: store ptr [[PTR]], ptr [[PTR_ADDR]], align 8
178+
// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[MASK_ADDR]], align 1
179+
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
180+
// CHECK-NEXT: [[TMP4:%.*]] = load <8 x i32>, ptr [[VAL_ADDR]], align 32
181+
// CHECK-NEXT: [[TMP5:%.*]] = load <8 x i32>, ptr [[IDX_ADDR]], align 32
182+
// CHECK-NEXT: [[TMP6:%.*]] = load ptr, ptr [[PTR_ADDR]], align 8
183+
// CHECK-NEXT: [[TMP7:%.*]] = getelementptr i32, ptr [[TMP6]], <8 x i32> [[TMP4]]
184+
// CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p0(<8 x i32> [[TMP5]], <8 x ptr> [[TMP7]], i32 4, <8 x i1> [[TMP3]])
185+
// CHECK-NEXT: ret void
186+
//
187+
void test_scatter(v8b mask, v8i val, v8i idx, int *ptr) {
188+
__builtin_masked_scatter(mask, val, idx, ptr);
189+
}

clang/test/Sema/builtin-masked.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,23 @@ void test_masked_compress_store(v8i *pf, v8f *pf2, v8b mask, v2b mask2) {
4444
__builtin_masked_compress_store(mask2, *pf, pf); // expected-error {{all arguments to '__builtin_masked_compress_store' must have the same number of elements}}
4545
__builtin_masked_compress_store(mask, *pf, pf2); // expected-error {{last two arguments to '__builtin_masked_compress_store' must have the same type}}
4646
}
47+
48+
void test_masked_gather(int *p, v8i idx, v8b mask, v2b mask2, v2b thru) {
49+
__builtin_masked_gather(mask); // expected-error {{too few arguments to function call, expected 3, have 1}}
50+
__builtin_masked_gather(mask, p, p, p, p, p); // expected-error {{too many arguments to function call, expected at most 4, have 6}}
51+
__builtin_masked_gather(p, p, p); // expected-error {{1st argument must be a vector of boolean types (was 'int *')}}
52+
__builtin_masked_gather(mask, p, p); // expected-error {{1st argument must be a vector of integer types (was 'int *')}}
53+
__builtin_masked_gather(mask2, idx, p); // expected-error {{all arguments to '__builtin_masked_gather' must have the same number of elements (was 'v2b'}}
54+
__builtin_masked_gather(mask, idx, p, thru); // expected-error {{4th argument must be a 'int __attribute__((ext_vector_type(8)))' (vector of 8 'int' values)}}
55+
__builtin_masked_gather(mask, idx, &idx); // expected-error {{3rd argument must be a scalar pointer}}
56+
}
57+
58+
void test_masked_scatter(int *p, v8i idx, v8b mask, v2b mask2, v8i val) {
59+
__builtin_masked_scatter(mask); // expected-error {{too few arguments to function call, expected 4, have 1}}
60+
__builtin_masked_scatter(mask, p, p, p, p, p); // expected-error {{too many arguments to function call, expected 4, have 6}}
61+
__builtin_masked_scatter(p, p, p, p); // expected-error {{1st argument must be a vector of boolean types (was 'int *')}}
62+
__builtin_masked_scatter(mask, p, p, p); // expected-error {{2nd argument must be a vector of integer types (was 'int *')}}
63+
__builtin_masked_scatter(mask, idx, mask, p); // expected-error {{last two arguments to '__builtin_masked_scatter' must have the same type}}
64+
__builtin_masked_scatter(mask, idx, val, idx); // expected-error {{3rd argument must be a scalar pointer}}
65+
__builtin_masked_scatter(mask, idx, val, &idx); // expected-error {{3rd argument must be a scalar pointer}}
66+
}

0 commit comments

Comments
 (0)