[AArch64][Clang] Refactor code to emit SVE & SME builtins (#70959)
This patch removes duplicated code in EmitAArch64SVEBuiltinExpr and
EmitAArch64SMEBuiltinExpr by creating a new function,
GetAArch64SVEProcessedOperands, which handles splitting up multi-vector
arguments using vector extracts.

These changes are non-functional.
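
For illustration (not text from the commit itself): at the clang builtin
boundary a tuple type such as svfloat32x2_t arrives as one wide scalable
vector, while the underlying llvm.aarch64.sve.st2/3/4 intrinsics take the
individual part vectors. A minimal sketch of an affected call site, with a
hypothetical function name:

    // Sketch only; store_pair is illustrative and not part of the patch.
    #include <arm_sve.h>

    void store_pair(svbool_t pg, float *base, svfloat32x2_t data) {
      // `data` is lowered to a single <vscale x 8 x float>. Before calling
      // llvm.aarch64.sve.st2.nxv4f32, clang splits it into two
      // <vscale x 4 x float> halves with llvm.vector.extract; this patch
      // moves that splitting into GetAArch64SVEProcessedOperands.
      svst2_f32(pg, base, data);
    }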
kmclaughlin-arm committed Nov 2, 2023
1 parent 3aad77e commit 8f59c16
Showing 10 changed files with 1,243 additions and 1,232 deletions.
150 changes: 78 additions & 72 deletions clang/lib/CodeGen/CGBuiltin.cpp
@@ -9617,22 +9617,17 @@ Value *CodeGenFunction::EmitSVEStructStore(const SVETypeFlags &TypeFlags,
   Value *BasePtr = Ops[1];

   // Does the store have an offset?
-  if (Ops.size() > 3)
+  if (Ops.size() > (2 + N))
     BasePtr = Builder.CreateGEP(VTy, BasePtr, Ops[2]);

-  Value *Val = Ops.back();
-
   // The llvm.aarch64.sve.st2/3/4 intrinsics take legal part vectors, so we
   // need to break up the tuple vector.
   SmallVector<llvm::Value*, 5> Operands;
-  unsigned MinElts = VTy->getElementCount().getKnownMinValue();
-  for (unsigned I = 0; I < N; ++I) {
-    Value *Idx = ConstantInt::get(CGM.Int64Ty, I * MinElts);
-    Operands.push_back(Builder.CreateExtractVector(VTy, Val, Idx));
-  }
+  for (unsigned I = Ops.size() - N; I < Ops.size(); ++I)
+    Operands.push_back(Ops[I]);
   Operands.append({Predicate, BasePtr});

   Function *F = CGM.getIntrinsic(IntID, { VTy });

   return Builder.CreateCall(F, Operands);
 }

@@ -9939,26 +9934,24 @@ Value *CodeGenFunction::FormSVEBuiltinResult(Value *Call) {
   return Call;
 }

-Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
-                                                  const CallExpr *E) {
+void CodeGenFunction::GetAArch64SVEProcessedOperands(
+    unsigned BuiltinID, const CallExpr *E, SmallVectorImpl<Value *> &Ops,
+    SVETypeFlags TypeFlags) {
   // Find out if any arguments are required to be integer constant expressions.
   unsigned ICEArguments = 0;
   ASTContext::GetBuiltinTypeError Error;
   getContext().GetBuiltinType(BuiltinID, Error, &ICEArguments);
   assert(Error == ASTContext::GE_None && "Should not codegen an error");

-  llvm::Type *Ty = ConvertType(E->getType());
-  if (BuiltinID >= SVE::BI__builtin_sve_reinterpret_s8_s8 &&
-      BuiltinID <= SVE::BI__builtin_sve_reinterpret_f64_f64) {
-    Value *Val = EmitScalarExpr(E->getArg(0));
-    return EmitSVEReinterpret(Val, Ty);
-  }
+  // Tuple set/get only requires one insert/extract vector, which is
+  // created by EmitSVETupleSetOrGet.
+  bool IsTupleGetOrSet = TypeFlags.isTupleSet() || TypeFlags.isTupleGet();

-  llvm::SmallVector<Value *, 4> Ops;
   for (unsigned i = 0, e = E->getNumArgs(); i != e; i++) {
-    if ((ICEArguments & (1 << i)) == 0)
-      Ops.push_back(EmitScalarExpr(E->getArg(i)));
-    else {
+    bool IsICE = ICEArguments & (1 << i);
+    Value *Arg = EmitScalarExpr(E->getArg(i));
+
+    if (IsICE) {
       // If this is required to be a constant, constant fold it so that we know
       // that the generated intrinsic gets a ConstantInt.
       std::optional<llvm::APSInt> Result =
@@ -9970,12 +9963,49 @@ Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
       // immediate requires more than a handful of bits.
       *Result = Result->extOrTrunc(32);
       Ops.push_back(llvm::ConstantInt::get(getLLVMContext(), *Result));
+      continue;
     }
+
+    if (IsTupleGetOrSet || !isa<ScalableVectorType>(Arg->getType())) {
+      Ops.push_back(Arg);
+      continue;
+    }
+
+    auto *VTy = cast<ScalableVectorType>(Arg->getType());
+    unsigned MinElts = VTy->getMinNumElements();
+    bool IsPred = VTy->getElementType()->isIntegerTy(1);
+    unsigned N = (MinElts * VTy->getScalarSizeInBits()) / (IsPred ? 16 : 128);
+
+    if (N == 1) {
+      Ops.push_back(Arg);
+      continue;
+    }
+
+    for (unsigned I = 0; I < N; ++I) {
+      Value *Idx = ConstantInt::get(CGM.Int64Ty, (I * MinElts) / N);
+      auto *NewVTy =
+          ScalableVectorType::get(VTy->getElementType(), MinElts / N);
+      Ops.push_back(Builder.CreateExtractVector(NewVTy, Arg, Idx));
+    }
   }
+}

+Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
+                                                  const CallExpr *E) {
+  llvm::Type *Ty = ConvertType(E->getType());
+  if (BuiltinID >= SVE::BI__builtin_sve_reinterpret_s8_s8 &&
+      BuiltinID <= SVE::BI__builtin_sve_reinterpret_f64_f64) {
+    Value *Val = EmitScalarExpr(E->getArg(0));
+    return EmitSVEReinterpret(Val, Ty);
+  }
+
   auto *Builtin = findARMVectorIntrinsicInMap(AArch64SVEIntrinsicMap, BuiltinID,
                                               AArch64SVEIntrinsicsProvenSorted);
+
+  llvm::SmallVector<Value *, 4> Ops;
   SVETypeFlags TypeFlags(Builtin->TypeModifier);
+  GetAArch64SVEProcessedOperands(BuiltinID, E, Ops, TypeFlags);
+
   if (TypeFlags.isLoad())
     return EmitSVEMaskedLoad(E, Ty, Ops, Builtin->LLVMIntrinsic,
                              TypeFlags.isZExtReturn());
@@ -9989,14 +10019,14 @@ Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
     return EmitSVEPrefetchLoad(TypeFlags, Ops, Builtin->LLVMIntrinsic);
   else if (TypeFlags.isGatherPrefetch())
     return EmitSVEGatherPrefetch(TypeFlags, Ops, Builtin->LLVMIntrinsic);
-  else if (TypeFlags.isStructLoad())
-    return EmitSVEStructLoad(TypeFlags, Ops, Builtin->LLVMIntrinsic);
-  else if (TypeFlags.isStructStore())
-    return EmitSVEStructStore(TypeFlags, Ops, Builtin->LLVMIntrinsic);
+  else if (TypeFlags.isStructLoad())
+    return EmitSVEStructLoad(TypeFlags, Ops, Builtin->LLVMIntrinsic);
+  else if (TypeFlags.isStructStore())
+    return EmitSVEStructStore(TypeFlags, Ops, Builtin->LLVMIntrinsic);
   else if (TypeFlags.isTupleSet() || TypeFlags.isTupleGet())
-    return EmitSVETupleSetOrGet(TypeFlags, Ty, Ops);
+    return EmitSVETupleSetOrGet(TypeFlags, Ty, Ops);
   else if (TypeFlags.isTupleCreate())
-    return EmitSVETupleCreate(TypeFlags, Ty, Ops);
+    return EmitSVETupleCreate(TypeFlags, Ty, Ops);
   else if (TypeFlags.isUndef())
     return UndefValue::get(Ty);
   else if (Builtin->LLVMIntrinsic != 0) {
@@ -10248,13 +10278,8 @@ Value *CodeGenFunction::EmitAArch64SVEBuiltinExpr(unsigned BuiltinID,
   case SVE::BI__builtin_sve_svtbl2_f64: {
     SVETypeFlags TF(Builtin->TypeModifier);
     auto VTy = cast<llvm::ScalableVectorType>(getSVEType(TF));
-    Value *V0 = Builder.CreateExtractVector(VTy, Ops[0],
-                                            ConstantInt::get(CGM.Int64Ty, 0));
-    unsigned MinElts = VTy->getMinNumElements();
-    Value *V1 = Builder.CreateExtractVector(
-        VTy, Ops[0], ConstantInt::get(CGM.Int64Ty, MinElts));
     Function *F = CGM.getIntrinsic(Intrinsic::aarch64_sve_tbl2, VTy);
-    return Builder.CreateCall(F, {V0, V1, Ops[1]});
+    return Builder.CreateCall(F, Ops);
   }

   case SVE::BI__builtin_sve_svset_neonq_s8:
@@ -10312,35 +10337,13 @@

 Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
                                                   const CallExpr *E) {
-  // Find out if any arguments are required to be integer constant expressions.
-  unsigned ICEArguments = 0;
-  ASTContext::GetBuiltinTypeError Error;
-  getContext().GetBuiltinType(BuiltinID, Error, &ICEArguments);
-  assert(Error == ASTContext::GE_None && "Should not codegen an error");
-
-  llvm::Type *Ty = ConvertType(E->getType());
-  llvm::SmallVector<Value *, 4> Ops;
-  for (unsigned i = 0, e = E->getNumArgs(); i != e; i++) {
-    if ((ICEArguments & (1 << i)) == 0)
-      Ops.push_back(EmitScalarExpr(E->getArg(i)));
-    else {
-      // If this is required to be a constant, constant fold it so that we know
-      // that the generated intrinsic gets a ConstantInt.
-      std::optional<llvm::APSInt> Result =
-          E->getArg(i)->getIntegerConstantExpr(getContext());
-      assert(Result && "Expected argument to be a constant");
-
-      // Immediates for SVE llvm intrinsics are always 32bit. We can safely
-      // truncate because the immediate has been range checked and no valid
-      // immediate requires more than a handful of bits.
-      *Result = Result->extOrTrunc(32);
-      Ops.push_back(llvm::ConstantInt::get(getLLVMContext(), *Result));
-    }
-  }
-
   auto *Builtin = findARMVectorIntrinsicInMap(AArch64SMEIntrinsicMap, BuiltinID,
                                               AArch64SMEIntrinsicsProvenSorted);
+
+  llvm::SmallVector<Value *, 4> Ops;
   SVETypeFlags TypeFlags(Builtin->TypeModifier);
+  GetAArch64SVEProcessedOperands(BuiltinID, E, Ops, TypeFlags);
+
   if (TypeFlags.isLoad() || TypeFlags.isStore())
     return EmitSMELd1St1(TypeFlags, Ops, Builtin->LLVMIntrinsic);
   else if (TypeFlags.isReadZA() || TypeFlags.isWriteZA())
@@ -10353,21 +10356,24 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
            BuiltinID == SME::BI__builtin_sme_svldr_za ||
            BuiltinID == SME::BI__builtin_sme_svstr_za)
     return EmitSMELdrStr(TypeFlags, Ops, Builtin->LLVMIntrinsic);
-  else if (Builtin->LLVMIntrinsic != 0) {
-    // Predicates must match the main datatype.
-    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
-      if (auto PredTy = dyn_cast<llvm::VectorType>(Ops[i]->getType()))
-        if (PredTy->getElementType()->isIntegerTy(1))
-          Ops[i] = EmitSVEPredicateCast(Ops[i], getSVEType(TypeFlags));
-
-    Function *F = CGM.getIntrinsic(Builtin->LLVMIntrinsic,
-                                   getSVEOverloadTypes(TypeFlags, Ty, Ops));
-    Value *Call = Builder.CreateCall(F, Ops);
-    return Call;
-  }
+  // Should not happen!
+  if (Builtin->LLVMIntrinsic == 0)
+    return nullptr;
+
-  /// Should not happen
-  return nullptr;
+  // Predicates must match the main datatype.
+  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+    if (auto PredTy = dyn_cast<llvm::VectorType>(Ops[i]->getType()))
+      if (PredTy->getElementType()->isIntegerTy(1))
+        Ops[i] = EmitSVEPredicateCast(Ops[i], getSVEType(TypeFlags));
+
+  Function *F =
+      TypeFlags.isOverloadNone()
+          ? CGM.getIntrinsic(Builtin->LLVMIntrinsic)
+          : CGM.getIntrinsic(Builtin->LLVMIntrinsic, {getSVEType(TypeFlags)});
+  Value *Call = Builder.CreateCall(F, Ops);
+
+  return FormSVEBuiltinResult(Call);
 }

 Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
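For reference, the part count computed by the new helper works out as follows
(a worked example, not text from the commit):

    // N = (MinElts * ScalarSizeInBits) / (IsPred ? 16 : 128): one part per
    // 128-bit data register, or per 16 x i1 for predicates.
    //   svfloat32x2_t -> <vscale x 8 x float>: N = (8 * 32) / 128 = 2;
    //                    parts <vscale x 4 x float>, indices (I * MinElts) / N = 0, 4
    //   svint8x4_t    -> <vscale x 64 x i8>:   N = (64 * 8) / 128 = 4;
    //                    parts <vscale x 16 x i8>, indices 0, 16, 32, 48
    //   svbool_t      -> <vscale x 16 x i1>:   N = (16 * 1) / 16 = 1;
    //                    passed through unchanged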
5 changes: 5 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
@@ -4311,6 +4311,11 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitSMELdrStr(const SVETypeFlags &TypeFlags,
                              llvm::SmallVectorImpl<llvm::Value *> &Ops,
                              unsigned IntID);
+
+  void GetAArch64SVEProcessedOperands(unsigned BuiltinID, const CallExpr *E,
+                                      SmallVectorImpl<llvm::Value *> &Ops,
+                                      SVETypeFlags TypeFlags);
+
   llvm::Value *EmitAArch64SMEBuiltinExpr(unsigned BuiltinID, const CallExpr *E);

   llvm::Value *EmitAArch64BuiltinExpr(unsigned BuiltinID, const CallExpr *E,
36 changes: 18 additions & 18 deletions clang/test/CodeGen/aarch64-sve-intrinsics/acle_sve_st2-bfloat.c
@@ -16,18 +16,18 @@
 #endif
 // CHECK-LABEL: @test_svst2_bf16(
 // CHECK-NEXT: entry:
-// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
-// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA:%.*]], i64 0)
-// CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA]], i64 8)
-// CHECK-NEXT: tail call void @llvm.aarch64.sve.st2.nxv8bf16(<vscale x 8 x bfloat> [[TMP1]], <vscale x 8 x bfloat> [[TMP2]], <vscale x 8 x i1> [[TMP0]], ptr [[BASE:%.*]])
+// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA:%.*]], i64 0)
+// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA]], i64 8)
+// CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
+// CHECK-NEXT: tail call void @llvm.aarch64.sve.st2.nxv8bf16(<vscale x 8 x bfloat> [[TMP0]], <vscale x 8 x bfloat> [[TMP1]], <vscale x 8 x i1> [[TMP2]], ptr [[BASE:%.*]])
 // CHECK-NEXT: ret void
 //
 // CPP-CHECK-LABEL: @_Z15test_svst2_bf16u10__SVBool_tPu6__bf1614svbfloat16x2_t(
 // CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
-// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA:%.*]], i64 0)
-// CPP-CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA]], i64 8)
-// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sve.st2.nxv8bf16(<vscale x 8 x bfloat> [[TMP1]], <vscale x 8 x bfloat> [[TMP2]], <vscale x 8 x i1> [[TMP0]], ptr [[BASE:%.*]])
+// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA:%.*]], i64 0)
+// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA]], i64 8)
+// CPP-CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
+// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sve.st2.nxv8bf16(<vscale x 8 x bfloat> [[TMP0]], <vscale x 8 x bfloat> [[TMP1]], <vscale x 8 x i1> [[TMP2]], ptr [[BASE:%.*]])
 // CPP-CHECK-NEXT: ret void
 //
 void test_svst2_bf16(svbool_t pg, bfloat16_t *base, svbfloat16x2_t data)
@@ -37,20 +37,20 @@ void test_svst2_bf16(svbool_t pg, bfloat16_t *base, svbfloat16x2_t data)

 // CHECK-LABEL: @test_svst2_vnum_bf16(
 // CHECK-NEXT: entry:
-// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
-// CHECK-NEXT: [[TMP1:%.*]] = getelementptr <vscale x 8 x bfloat>, ptr [[BASE:%.*]], i64 [[VNUM:%.*]]
-// CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA:%.*]], i64 0)
-// CHECK-NEXT: [[TMP3:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA]], i64 8)
-// CHECK-NEXT: tail call void @llvm.aarch64.sve.st2.nxv8bf16(<vscale x 8 x bfloat> [[TMP2]], <vscale x 8 x bfloat> [[TMP3]], <vscale x 8 x i1> [[TMP0]], ptr [[TMP1]])
+// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA:%.*]], i64 0)
+// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA]], i64 8)
+// CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
+// CHECK-NEXT: [[TMP3:%.*]] = getelementptr <vscale x 8 x bfloat>, ptr [[BASE:%.*]], i64 [[VNUM:%.*]]
+// CHECK-NEXT: tail call void @llvm.aarch64.sve.st2.nxv8bf16(<vscale x 8 x bfloat> [[TMP0]], <vscale x 8 x bfloat> [[TMP1]], <vscale x 8 x i1> [[TMP2]], ptr [[TMP3]])
 // CHECK-NEXT: ret void
 //
 // CPP-CHECK-LABEL: @_Z20test_svst2_vnum_bf16u10__SVBool_tPu6__bf16l14svbfloat16x2_t(
 // CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
-// CPP-CHECK-NEXT: [[TMP1:%.*]] = getelementptr <vscale x 8 x bfloat>, ptr [[BASE:%.*]], i64 [[VNUM:%.*]]
-// CPP-CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA:%.*]], i64 0)
-// CPP-CHECK-NEXT: [[TMP3:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA]], i64 8)
-// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sve.st2.nxv8bf16(<vscale x 8 x bfloat> [[TMP2]], <vscale x 8 x bfloat> [[TMP3]], <vscale x 8 x i1> [[TMP0]], ptr [[TMP1]])
+// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA:%.*]], i64 0)
+// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x bfloat> @llvm.vector.extract.nxv8bf16.nxv16bf16(<vscale x 16 x bfloat> [[DATA]], i64 8)
+// CPP-CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
+// CPP-CHECK-NEXT: [[TMP3:%.*]] = getelementptr <vscale x 8 x bfloat>, ptr [[BASE:%.*]], i64 [[VNUM:%.*]]
+// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sve.st2.nxv8bf16(<vscale x 8 x bfloat> [[TMP0]], <vscale x 8 x bfloat> [[TMP1]], <vscale x 8 x i1> [[TMP2]], ptr [[TMP3]])
 // CPP-CHECK-NEXT: ret void
 //
 void test_svst2_vnum_bf16(svbool_t pg, bfloat16_t *base, int64_t vnum, svbfloat16x2_t data)