Skip to content

Commit

Permalink
[AArch64] Emit AssertZExt for i1 arguments
Browse files Browse the repository at this point in the history
AAPCS requires an i1 argument to be zero-extended to 8 bits by the
caller. Emit a new AArch64ISD::ASSERT_ZEXT_BOOL hint (or AssertZExt
for GlobalISel) to enable some optimization opportunities. In
particular, when the argument is forwarded to a callee, we can avoid
the zero-extension and use the value as-is.

Differential Revision: https://reviews.llvm.org/D107160
  • Loading branch information
asavonic committed Oct 11, 2021
1 parent 342d7b6 commit 7ae8f39
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 18 deletions.
50 changes: 48 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -1799,6 +1799,11 @@ void AArch64TargetLowering::computeKnownBitsForTargetNode(
Known.Zero = APInt::getHighBitsSet(64, 32);
break;
}
case AArch64ISD::ASSERT_ZEXT_BOOL: {
Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
Known.Zero |= APInt(Known.getBitWidth(), 0xFE);
break;
}
case ISD::INTRINSIC_W_CHAIN: {
ConstantSDNode *CN = cast<ConstantSDNode>(Op->getOperand(1));
Intrinsic::ID IntID = static_cast<Intrinsic::ID>(CN->getZExtValue());
Expand Down Expand Up @@ -2190,6 +2195,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::INDEX_VECTOR)
MAKE_CASE(AArch64ISD::UADDLP)
MAKE_CASE(AArch64ISD::CALL_RVMARKER)
MAKE_CASE(AArch64ISD::ASSERT_ZEXT_BOOL)
}
#undef MAKE_CASE
return nullptr;
Expand Down Expand Up @@ -5369,6 +5375,19 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
if (Subtarget->isTargetILP32() && Ins[i].Flags.isPointer())
ArgValue = DAG.getNode(ISD::AssertZext, DL, ArgValue.getValueType(),
ArgValue, DAG.getValueType(MVT::i32));

// i1 arguments are zero-extended to i8 by the caller. Emit a
// hint to reflect this.
if (Ins[i].isOrigArg()) {
Argument *OrigArg = MF.getFunction().getArg(Ins[i].getOrigArgIndex());
if (OrigArg->getType()->isIntegerTy(1)) {
if (!Ins[i].Flags.isZExt()) {
ArgValue = DAG.getNode(AArch64ISD::ASSERT_ZEXT_BOOL, DL,
ArgValue.getValueType(), ArgValue);
}
}
}

InVals.push_back(ArgValue);
}
}
Expand Down Expand Up @@ -5807,6 +5826,19 @@ bool AArch64TargetLowering::DoesCalleeRestoreStack(CallingConv::ID CallCC,
CallCC == CallingConv::Tail || CallCC == CallingConv::SwiftTail;
}

/// Check whether \p Arg is already known to be an i1 value zero-extended to
/// i8, i.e. whether bits [7:1] of the value are known to be zero.
///
/// \param Arg  The value to inspect.
/// \param DAG  The SelectionDAG used for the known-bits query.
/// \returns true if bits 1..7 of \p Arg are known zero; false if the value is
///          narrower than 8 bits or those bits cannot be proven zero.
static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
  unsigned SizeInBits = Arg.getValueType().getSizeInBits();
  if (SizeInBits < 8)
    return false;

  // Bits [7:1] must be zero for a bool that is zero-extended to 8 bits.
  const APInt RequiredZero(SizeInBits, 0xFE);

  // Use the (Op, Depth) overload. The APInt parameter of the three-argument
  // computeKnownBits is DemandedElts, which selects vector *elements* rather
  // than bits, so passing a bit mask there is incorrect.
  // The starting depth of 4 is kept from the original call; presumably it
  // limits how far the analysis recurses.
  const KnownBits Bits = DAG.computeKnownBits(Arg, 4);
  return (Bits.Zero & RequiredZero) == RequiredZero;
}

/// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
/// and add input and output parameter nodes.
SDValue
Expand Down Expand Up @@ -6004,8 +6036,22 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
case CCValAssign::AExt:
if (Outs[i].ArgVT == MVT::i1) {
// AAPCS requires i1 to be zero-extended to 8-bits by the caller.
Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg);
Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i8, Arg);
//
// Check if we actually have to do this, because the value may
// already be zero-extended.
//
// We cannot just emit a (zext i8 (trunc (assert-zext i8)))
// and rely on DAGCombiner to fold this, because the following
// (anyext i32) is combined with (zext i8) in DAG.getNode:
//
// (ext (zext x)) -> (zext x)
//
// This will give us (zext i32), which we cannot remove, so
// try to check this beforehand.
if (!checkZExtBool(Arg, DAG)) {
Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg);
Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i8, Arg);
}
}
Arg = DAG.getNode(ISD::ANY_EXTEND, DL, VA.getLocVT(), Arg);
break;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Expand Up @@ -405,6 +405,10 @@ enum NodeType : unsigned {
SSTNT1_PRED,
SSTNT1_INDEX_PRED,

// Asserts that a function argument (i32) is zero-extended to i8 by
// the caller
ASSERT_ZEXT_BOOL,

// Strict (exception-raising) floating point comparison
STRICT_FCMP = ISD::FIRST_TARGET_STRICTFP_OPCODE,
STRICT_FCMPE,
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Expand Up @@ -8170,6 +8170,10 @@ def StoreSwiftAsyncContext
: Pseudo<(outs), (ins GPR64:$ctx, GPR64sp:$base, simm9:$offset),
[]>, Sched<[]>;

// SelectionDAG node for AArch64ISD::ASSERT_ZEXT_BOOL, the hint recording that
// an i1 argument has been zero-extended to 8 bits by the caller.
// The pattern below selects the node on a GPR32 value to its operand
// unchanged, so the pattern itself produces no instruction.
def AArch64AssertZExtBool : SDNode<"AArch64ISD::ASSERT_ZEXT_BOOL", SDT_assert>;
def : Pat<(AArch64AssertZExtBool GPR32:$op),
(i32 GPR32:$op)>;

include "AArch64InstrAtomics.td"
include "AArch64SVEInstrInfo.td"
include "AArch64SMEInstrInfo.td"
Expand Down
44 changes: 42 additions & 2 deletions llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
Expand Up @@ -531,6 +531,7 @@ bool AArch64CallLowering::lowerFormalArguments(
auto &DL = F.getParent()->getDataLayout();

SmallVector<ArgInfo, 8> SplitArgs;
SmallVector<std::pair<Register, Register>> BoolArgs;
unsigned i = 0;
for (auto &Arg : F.args()) {
if (DL.getTypeStoreSize(Arg.getType()).isZero())
Expand All @@ -539,6 +540,22 @@ bool AArch64CallLowering::lowerFormalArguments(
ArgInfo OrigArg{VRegs[i], Arg, i};
setArgFlags(OrigArg, i + AttributeList::FirstArgIndex, DL, F);

// i1 arguments are zero-extended to i8 by the caller. Emit a
// hint to reflect this.
if (OrigArg.Ty->isIntegerTy(1)) {
assert(OrigArg.Regs.size() == 1 &&
MRI.getType(OrigArg.Regs[0]).getSizeInBits() == 1 &&
"Unexpected registers used for i1 arg");

if (!OrigArg.Flags[0].isZExt()) {
// Lower i1 argument as i8, and insert AssertZExt + Trunc later.
Register OrigReg = OrigArg.Regs[0];
Register WideReg = MRI.createGenericVirtualRegister(LLT::scalar(8));
OrigArg.Regs[0] = WideReg;
BoolArgs.push_back({OrigReg, WideReg});
}
}

if (Arg.hasAttribute(Attribute::SwiftAsync))
MF.getInfo<AArch64FunctionInfo>()->setHasSwiftAsyncContext(true);

Expand All @@ -559,6 +576,18 @@ bool AArch64CallLowering::lowerFormalArguments(
F.getCallingConv(), F.isVarArg()))
return false;

if (!BoolArgs.empty()) {
for (auto &KV : BoolArgs) {
Register OrigReg = KV.first;
Register WideReg = KV.second;
LLT WideTy = MRI.getType(WideReg);
assert(MRI.getType(OrigReg).getScalarSizeInBits() == 1 &&
"Unexpected bit size of a bool arg");
MIRBuilder.buildTrunc(
OrigReg, MIRBuilder.buildAssertZExt(WideTy, WideReg, 1).getReg(0));
}
}

AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
uint64_t StackOffset = Assigner.StackOffset;
if (F.isVarArg()) {
Expand Down Expand Up @@ -1051,8 +1080,19 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
for (auto &OrigArg : Info.OrigArgs) {
splitToValueTypes(OrigArg, OutArgs, DL, Info.CallConv);
// AAPCS requires that we zero-extend i1 to 8 bits by the caller.
if (OrigArg.Ty->isIntegerTy(1))
OutArgs.back().Flags[0].setZExt();
if (OrigArg.Ty->isIntegerTy(1)) {
ArgInfo &OutArg = OutArgs.back();
assert(OutArg.Regs.size() == 1 &&
MRI.getType(OutArg.Regs[0]).getSizeInBits() == 1 &&
"Unexpected registers used for i1 arg");

// We cannot use a ZExt ArgInfo flag here, because it will
// zero-extend the argument to i32 instead of just i8.
OutArg.Regs[0] =
MIRBuilder.buildZExt(LLT::scalar(8), OutArg.Regs[0]).getReg(0);
LLVMContext &Ctx = MF.getFunction().getContext();
OutArg.Ty = Type::getInt8Ty(Ctx);
}
}

SmallVector<ArgInfo, 8> InArgs;
Expand Down
Expand Up @@ -1422,10 +1422,12 @@ define i1 @i1_value_cmp_is_signed(i1) {
; CHECK: successors: %bb.3(0x40000000), %bb.2(0x40000000)
; CHECK: liveins: $w0
; CHECK: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
; CHECK: [[TRUNC:%[0-9]+]]:_(s1) = G_TRUNC [[COPY]](s32)
; CHECK: [[TRUNC:%[0-9]+]]:_(s8) = G_TRUNC [[COPY]](s32)
; CHECK: [[ASSERT_ZEXT:%[0-9]+]]:_(s8) = G_ASSERT_ZEXT [[TRUNC]], 1
; CHECK: [[TRUNC1:%[0-9]+]]:_(s1) = G_TRUNC [[ASSERT_ZEXT]](s8)
; CHECK: [[C:%[0-9]+]]:_(s1) = G_CONSTANT i1 true
; CHECK: [[C1:%[0-9]+]]:_(s1) = G_CONSTANT i1 false
; CHECK: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(sle), [[TRUNC]](s1), [[C1]]
; CHECK: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(sle), [[TRUNC1]](s1), [[C1]]
; CHECK: G_BRCOND [[ICMP]](s1), %bb.3
; CHECK: G_BR %bb.2
; CHECK: bb.2.BadValue:
Expand All @@ -1434,7 +1436,7 @@ define i1 @i1_value_cmp_is_signed(i1) {
; CHECK: BL @bar, csr_aarch64_aapcs, implicit-def $lr, implicit $sp
; CHECK: ADJCALLSTACKUP 0, 0, implicit-def $sp, implicit $sp
; CHECK: bb.3.OkValue:
; CHECK: [[ZEXT:%[0-9]+]]:_(s8) = G_ZEXT [[TRUNC]](s1)
; CHECK: [[ZEXT:%[0-9]+]]:_(s8) = G_ZEXT [[TRUNC1]](s1)
; CHECK: [[ANYEXT:%[0-9]+]]:_(s32) = G_ANYEXT [[ZEXT]](s8)
; CHECK: $w0 = COPY [[ANYEXT]](s32)
; CHECK: RET_ReallyLR implicit $w0
Expand Down
22 changes: 16 additions & 6 deletions llvm/test/CodeGen/AArch64/GlobalISel/arm64-irtranslator.ll
Expand Up @@ -929,9 +929,11 @@ define void @test_insertvalue_agg(%struct.nested* %addr, {i8, i32}* %addr2) {

; CHECK-LABEL: name: test_select
; CHECK: [[TST_C:%[0-9]+]]:_(s32) = COPY $w0
; CHECK: [[TST:%[0-9]+]]:_(s1) = G_TRUNC [[TST_C]]
; CHECK: [[TSTEXT:%[0-9]+]]:_(s8) = G_TRUNC [[TST_C]]
; CHECK: [[LHS:%[0-9]+]]:_(s32) = COPY $w1
; CHECK: [[RHS:%[0-9]+]]:_(s32) = COPY $w2
; CHECK: [[TSTASSERT:%[0-9]+]]:_(s8) = G_ASSERT_ZEXT [[TSTEXT]], 1
; CHECK: [[TST:%[0-9]+]]:_(s1) = G_TRUNC [[TSTASSERT]]
; CHECK: [[RES:%[0-9]+]]:_(s32) = G_SELECT [[TST]](s1), [[LHS]], [[RHS]]
; CHECK: $w0 = COPY [[RES]]
define i32 @test_select(i1 %tst, i32 %lhs, i32 %rhs) {
Expand All @@ -941,9 +943,11 @@ define i32 @test_select(i1 %tst, i32 %lhs, i32 %rhs) {

; CHECK-LABEL: name: test_select_flags
; CHECK: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
; CHECK: [[TRUNC:%[0-9]+]]:_(s1) = G_TRUNC [[COPY]](s32)
; CHECK: [[TRUNC8:%[0-9]+]]:_(s8) = G_TRUNC [[COPY]]
; CHECK: [[COPY1:%[0-9]+]]:_(s32) = COPY $s0
; CHECK: [[COPY2:%[0-9]+]]:_(s32) = COPY $s1
; CHECK: [[TRUNCASSERT:%[0-9]+]]:_(s8) = G_ASSERT_ZEXT [[TRUNC8]], 1
; CHECK: [[TRUNC:%[0-9]+]]:_(s1) = G_TRUNC [[TRUNCASSERT]]
; CHECK: [[SELECT:%[0-9]+]]:_(s32) = nnan G_SELECT [[TRUNC]](s1), [[COPY1]], [[COPY2]]
define float @test_select_flags(i1 %tst, float %lhs, float %rhs) {
%res = select nnan i1 %tst, float %lhs, float %rhs
Expand All @@ -966,9 +970,11 @@ define float @test_select_cmp_flags(float %cmp0, float %cmp1, float %lhs, float

; CHECK-LABEL: name: test_select_ptr
; CHECK: [[TST_C:%[0-9]+]]:_(s32) = COPY $w0
; CHECK: [[TST:%[0-9]+]]:_(s1) = G_TRUNC [[TST_C]]
; CHECK: [[TSTEXT:%[0-9]+]]:_(s8) = G_TRUNC [[TST_C]]
; CHECK: [[LHS:%[0-9]+]]:_(p0) = COPY $x1
; CHECK: [[RHS:%[0-9]+]]:_(p0) = COPY $x2
; CHECK: [[TSTASSERT:%[0-9]+]]:_(s8) = G_ASSERT_ZEXT [[TSTEXT]], 1
; CHECK: [[TST:%[0-9]+]]:_(s1) = G_TRUNC [[TSTASSERT]]
; CHECK: [[RES:%[0-9]+]]:_(p0) = G_SELECT [[TST]](s1), [[LHS]], [[RHS]]
; CHECK: $x0 = COPY [[RES]]
define i8* @test_select_ptr(i1 %tst, i8* %lhs, i8* %rhs) {
Expand All @@ -978,9 +984,11 @@ define i8* @test_select_ptr(i1 %tst, i8* %lhs, i8* %rhs) {

; CHECK-LABEL: name: test_select_vec
; CHECK: [[TST_C:%[0-9]+]]:_(s32) = COPY $w0
; CHECK: [[TST:%[0-9]+]]:_(s1) = G_TRUNC [[TST_C]]
; CHECK: [[TSTEXT:%[0-9]+]]:_(s8) = G_TRUNC [[TST_C]]
; CHECK: [[LHS:%[0-9]+]]:_(<4 x s32>) = COPY $q0
; CHECK: [[RHS:%[0-9]+]]:_(<4 x s32>) = COPY $q1
; CHECK: [[TSTASSERT:%[0-9]+]]:_(s8) = G_ASSERT_ZEXT [[TSTEXT]], 1
; CHECK: [[TST:%[0-9]+]]:_(s1) = G_TRUNC [[TSTASSERT]]
; CHECK: [[RES:%[0-9]+]]:_(<4 x s32>) = G_SELECT [[TST]](s1), [[LHS]], [[RHS]]
; CHECK: $q0 = COPY [[RES]]
define <4 x i32> @test_select_vec(i1 %tst, <4 x i32> %lhs, <4 x i32> %rhs) {
Expand Down Expand Up @@ -1842,8 +1850,10 @@ define void @test_phi_diamond({ i8, i16, i32 }* %a.ptr, { i8, i16, i32 }* %b.ptr
; CHECK: [[ARG1:%[0-9]+]]:_(p0) = COPY $x0
; CHECK: [[ARG2:%[0-9]+]]:_(p0) = COPY $x1
; CHECK: [[ARG3:%[0-9]+]]:_(s32) = COPY $w2
; CHECK: [[TRUNC:%[0-9]+]]:_(s1) = G_TRUNC [[ARG3]](s32)
; CHECK: [[TRUNC8:%[0-9]+]]:_(s8) = G_TRUNC [[ARG3]]
; CHECK: [[ARG4:%[0-9]+]]:_(p0) = COPY $x3
; CHECK: [[TRUNCASSERT:%[0-9]+]]:_(s8) = G_ASSERT_ZEXT [[TRUNC8]], 1
; CHECK: [[TRUNC:%[0-9]+]]:_(s1) = G_TRUNC [[TRUNCASSERT]]
; CHECK: G_BRCOND [[TRUNC]](s1), %bb.2
; CHECK: G_BR %bb.3

Expand Down Expand Up @@ -2351,7 +2361,7 @@ define void @test_i1_arg_zext(void (i1)* %f) {
; CHECK-LABEL: name: test_i1_arg_zext
; CHECK: [[I1:%[0-9]+]]:_(s1) = G_CONSTANT i1 true
; CHECK: [[ZEXT0:%[0-9]+]]:_(s8) = G_ZEXT [[I1]](s1)
; CHECK: [[ZEXT1:%[0-9]+]]:_(s32) = G_ZEXT [[ZEXT0]](s8)
; CHECK: [[ZEXT1:%[0-9]+]]:_(s32) = G_ANYEXT [[ZEXT0]](s8)
; CHECK: $w0 = COPY [[ZEXT1]](s32)
call void %f(i1 true)
ret void
Expand Down
6 changes: 4 additions & 2 deletions llvm/test/CodeGen/AArch64/GlobalISel/call-lowering-signext.ll
Expand Up @@ -57,8 +57,10 @@ define i32 @signext_param_stack(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e, i64 %f,
; CHECK: [[FRAME_INDEX1:%[0-9]+]]:_(p0) = G_FRAME_INDEX %fixed-stack.0
; CHECK: [[SEXTLOAD:%[0-9]+]]:_(s32) = G_SEXTLOAD [[FRAME_INDEX1]](p0) :: (invariant load (s8) from %fixed-stack.0, align 8)
; CHECK: [[ASSERT_SEXT:%[0-9]+]]:_(s32) = G_ASSERT_SEXT [[SEXTLOAD]], 1
; CHECK: [[TRUNC:%[0-9]+]]:_(s1) = G_TRUNC [[ASSERT_SEXT]](s32)
; CHECK: [[ZEXT:%[0-9]+]]:_(s32) = G_ZEXT [[TRUNC]](s1)
; CHECK: [[TRUNC:%[0-9]+]]:_(s8) = G_TRUNC [[ASSERT_SEXT]](s32)
; CHECK: [[ASSERT_ZEXT:%[0-9]+]]:_(s8) = G_ASSERT_ZEXT [[TRUNC]], 1
; CHECK: [[TRUNC1:%[0-9]+]]:_(s1) = G_TRUNC [[ASSERT_ZEXT]](s8)
; CHECK: [[ZEXT:%[0-9]+]]:_(s32) = G_ZEXT [[TRUNC1]](s1)
; CHECK: $w0 = COPY [[ZEXT]](s32)
; CHECK: RET_ReallyLR implicit $w0
i64 %g, i64 %h, i64 %i, i1 signext %j) {
Expand Down
5 changes: 4 additions & 1 deletion llvm/test/CodeGen/AArch64/GlobalISel/call-translator.ll
Expand Up @@ -254,7 +254,10 @@ define void @test_call_stack() {
; CHECK-NEXT: isImmutable: true,
; CHECK: [[ADDR:%[0-9]+]]:_(p0) = G_FRAME_INDEX %fixed-stack.[[SLOT]]
; CHECK: [[LOAD:%[0-9]+]]:_(s32) = G_LOAD [[ADDR]](p0) :: (invariant load (s8) from %fixed-stack.[[SLOT]], align 16)
; CHECK-NEXT: {{%[0-9]+}}:_(s1) = G_TRUNC [[LOAD]]
; CHECK: [[TRUNC8:%[0-9]+]]:_(s8) = G_TRUNC [[LOAD]]
; CHECK: [[TRUNCASSERT:%[0-9]+]]:_(s8) = G_ASSERT_ZEXT [[TRUNC8]], 1
; CHECK: {{%[0-9]+}}:_(s1) = G_TRUNC [[TRUNCASSERT]]

define void @test_mem_i1([8 x i64], i1 %in) {
ret void
}
Expand Down
1 change: 0 additions & 1 deletion llvm/test/CodeGen/AArch64/arm64-aapcs.ll
Expand Up @@ -33,7 +33,6 @@ define dso_local void @test_stack_slots([8 x i64], i1 %bool, i8 %char, i16 %shor
; CHECK-DAG: ldrb w[[ext3:[0-9]+]], [sp, #8]
; CHECK-DAG: ldr x[[ext4:[0-9]+]], [sp, #32]
; CHECK-DAG: ldrb w[[ext5:[0-9]+]], [sp]
; CHECK-DAG: and x[[ext5]], x[[ext5]], #0x1

%ext_bool = zext i1 %bool to i64
store volatile i64 %ext_bool, i64* @var64, align 8
Expand Down
31 changes: 30 additions & 1 deletion llvm/test/CodeGen/AArch64/i1-contents.ll
@@ -1,4 +1,5 @@
; RUN: llc -mtriple=aarch64-linux-gnu -o - %s | FileCheck %s
; RUN: llc -mtriple=aarch64-linux-gnu -o - %s | FileCheck %s --check-prefixes CHECK,CHECK-SDAG
; RUN: llc -global-isel -mtriple=aarch64-linux-gnu -o - %s | FileCheck %s --check-prefixes CHECK,CHECK-GISEL
%big = type i32

@var = dso_local global %big 0
Expand Down Expand Up @@ -49,6 +50,34 @@ define dso_local void @produce_i1_arg() {
}


define dso_local void @forward_i1_arg1(i1 %in) {
; CHECK-LABEL: forward_i1_arg1:
; CHECK-NOT: and
; CHECK: bl consume_i1_arg
call void @consume_i1_arg(i1 %in)
ret void
}

define dso_local void @forward_i1_arg2(i1 %in, i1 %cond) {
; CHECK-LABEL: forward_i1_arg2:
;
; The optimization in SelectionDAG currently fails to recognize that
; %in is already zero-extended to i8 if the call is not in the entry
; block.
;
; CHECK-SDAG: and
; CHECK-GISEL-NOT: and
;
; CHECK: bl consume_i1_arg
br i1 %cond, label %true, label %false
true:
call void @consume_i1_arg(i1 %in)
ret void

false:
ret void
}

;define zeroext i1 @foo(i8 %in) {
; %val = trunc i8 %in to i1
; ret i1 %val
Expand Down

0 comments on commit 7ae8f39

Please sign in to comment.