Skip to content
118 changes: 93 additions & 25 deletions cranelift/codegen/src/isa/aarch64/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1246,21 +1246,41 @@ pub(crate) fn maybe_input_insn_via_conv<C: LowerCtx<I = Inst>>(
/// Specifies what [lower_icmp] should do when lowering
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum IcmpOutput {
/// Only sets flags, discarding the results
Flags,
/// Materializes the results into a register. The flags set may be incorrect
/// Lowers the comparison into a cond code, discarding the results. The cond code emitted can
/// be checked in the resulting [IcmpResult].
CondCode,
/// Materializes the results into a register. This may overwrite any flags previously set.
Register(Writable<Reg>),
}

impl IcmpOutput {
pub fn reg(&self) -> Option<Writable<Reg>> {
match self {
IcmpOutput::Flags => None,
IcmpOutput::CondCode => None,
IcmpOutput::Register(reg) => Some(*reg),
}
}
}

/// The output of an Icmp lowering.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum IcmpResult {
/// The result was output into the given [Cond]. Callers may perform operations using this [Cond]
/// and its inverse, other [Cond]'s are not guaranteed to be correct.
CondCode(Cond),
/// The result was materialized into the output register.
Register,
}

impl IcmpResult {
pub fn unwrap_cond(&self) -> Cond {
match self {
IcmpResult::CondCode(c) => *c,
_ => panic!("Unwrapped cond, but IcmpResult was {:?}", self),
}
}
}

/// Lower an icmp comparision
///
/// We can lower into the status flags, or materialize the result into a register
Expand All @@ -1270,7 +1290,7 @@ pub(crate) fn lower_icmp<C: LowerCtx<I = Inst>>(
insn: IRInst,
condcode: IntCC,
output: IcmpOutput,
) -> CodegenResult<()> {
) -> CodegenResult<IcmpResult> {
debug!(
"lower_icmp: insn {}, condcode: {}, output: {:?}",
insn, condcode, output
Expand All @@ -1288,8 +1308,9 @@ pub(crate) fn lower_icmp<C: LowerCtx<I = Inst>>(
(false, true) => NarrowValueMode::SignExtend64,
(false, false) => NarrowValueMode::ZeroExtend64,
};
let mut should_materialize = output.reg().is_some();

if ty == I128 {
let out_condcode = if ty == I128 {
let lhs = put_input_in_regs(ctx, inputs[0]);
let rhs = put_input_in_regs(ctx, inputs[1]);

Expand Down Expand Up @@ -1321,10 +1342,6 @@ pub(crate) fn lower_icmp<C: LowerCtx<I = Inst>>(
rn: tmp1.to_reg(),
rm: tmp2.to_reg(),
});

if let IcmpOutput::Register(rd) = output {
materialize_bool_result(ctx, insn, rd, cond);
}
}
IntCC::Overflow | IntCC::NotOverflow => {
// We can do an 128bit add while throwing away the results
Expand All @@ -1346,10 +1363,6 @@ pub(crate) fn lower_icmp<C: LowerCtx<I = Inst>>(
rn: lhs.regs()[1],
rm: rhs.regs()[1],
});

if let IcmpOutput::Register(rd) = output {
materialize_bool_result(ctx, insn, rd, cond);
}
}
_ => {
// cmp lhs_lo, rhs_lo
Expand Down Expand Up @@ -1382,7 +1395,7 @@ pub(crate) fn lower_icmp<C: LowerCtx<I = Inst>>(
rm: tmp2.to_reg(),
});

if output == IcmpOutput::Flags {
if output == IcmpOutput::CondCode {
// We only need to guarantee that the flags for `cond` are correct, so we can
// compare rd with 0 or 1

Expand Down Expand Up @@ -1413,24 +1426,79 @@ pub(crate) fn lower_icmp<C: LowerCtx<I = Inst>>(
rm,
});
}

// Prevent a second materialize_bool_result to be emitted at the end of the function
should_materialize = false;
}
}
} else if !ty.is_vector() {
let alu_op = choose_32_64(ty, ALUOp::SubS32, ALUOp::SubS64);
let rn = put_input_in_reg(ctx, inputs[0], narrow_mode);
let rm = put_input_in_rse_imm12(ctx, inputs[1], narrow_mode);
ctx.emit(alu_inst_imm12(alu_op, writable_zero_reg(), rn, rm));
cond
} else if ty.is_vector() {
assert_ne!(output, IcmpOutput::CondCode);
should_materialize = false;

if let IcmpOutput::Register(rd) = output {
materialize_bool_result(ctx, insn, rd, cond);
}
} else {
let rn = put_input_in_reg(ctx, inputs[0], narrow_mode);
let rm = put_input_in_reg(ctx, inputs[1], narrow_mode);
lower_vector_compare(ctx, rd, rn, rm, ty, cond)?;
cond
} else {
let rn = put_input_in_reg(ctx, inputs[0], narrow_mode);
let rm = put_input_in_rse_imm12(ctx, inputs[1], narrow_mode);

let is_overflow = condcode == IntCC::Overflow || condcode == IntCC::NotOverflow;
let is_small_type = ty == I8 || ty == I16;
let (cond, rn, rm) = if is_overflow && is_small_type {
// Overflow checks for non native types require additional instructions, other than
// just the extend op.
//
// TODO: Codegen improvements: Merge the second sxt{h,b} into the following sub instruction.
//
// sxt{h,b} w0, w0
// sxt{h,b} w1, w1
// sub w0, w0, w1
// cmp w0, w0, sxt{h,b}
//
// The result of this comparison is either the EQ or NE condition code, so we need to
// signal that to the caller

let extend_op = if ty == I8 {
ExtendOp::SXTB
} else {
ExtendOp::SXTH
};
let tmp1 = ctx.alloc_tmp(I32).only_reg().unwrap();
ctx.emit(alu_inst_imm12(ALUOp::Sub32, tmp1, rn, rm));

let out_cond = match condcode {
IntCC::Overflow => Cond::Ne,
IntCC::NotOverflow => Cond::Eq,
_ => unreachable!(),
};
(
out_cond,
tmp1.to_reg(),
ResultRSEImm12::RegExtend(tmp1.to_reg(), extend_op),
)
} else {
(cond, rn, rm)
};

let alu_op = choose_32_64(ty, ALUOp::SubS32, ALUOp::SubS64);
ctx.emit(alu_inst_imm12(alu_op, writable_zero_reg(), rn, rm));
cond
};

// Most of the comparisons above produce flags by default, if the caller requested the result
// in a register we materialize those flags into a register. Some branches do end up producing
// the result as a register by default, so we ignore those.
if should_materialize {
materialize_bool_result(ctx, insn, rd, cond);
}

Ok(())
Ok(match output {
// We currently never emit a different register than what was asked for
IcmpOutput::Register(_) => IcmpResult::Register,
IcmpOutput::CondCode => IcmpResult::CondCode(out_condcode),
})
}

pub(crate) fn lower_fcmp_or_ffcmp_to_flags<C: LowerCtx<I = Inst>>(ctx: &mut C, insn: IRInst) {
Expand Down
31 changes: 12 additions & 19 deletions cranelift/codegen/src/isa/aarch64/lower_inst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1658,9 +1658,7 @@ pub(crate) fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
maybe_input_insn_via_conv(ctx, flag_input, Opcode::Icmp, Opcode::Bint)
{
let condcode = ctx.data(icmp_insn).cond_code().unwrap();
let cond = lower_condcode(condcode);
lower_icmp(ctx, icmp_insn, condcode, IcmpOutput::Flags)?;
cond
lower_icmp(ctx, icmp_insn, condcode, IcmpOutput::CondCode)?.unwrap_cond()
} else if let Some(fcmp_insn) =
maybe_input_insn_via_conv(ctx, flag_input, Opcode::Fcmp, Opcode::Bint)
{
Expand Down Expand Up @@ -1723,11 +1721,10 @@ pub(crate) fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(

Opcode::Selectif | Opcode::SelectifSpectreGuard => {
let condcode = ctx.data(insn).cond_code().unwrap();
let cond = lower_condcode(condcode);
// Verification ensures that the input is always a
// single-def ifcmp.
let ifcmp_insn = maybe_input_insn(ctx, inputs[0], Opcode::Ifcmp).unwrap();
lower_icmp(ctx, ifcmp_insn, condcode, IcmpOutput::Flags)?;
let cond = lower_icmp(ctx, ifcmp_insn, condcode, IcmpOutput::CondCode)?.unwrap_cond();

// csel.COND rd, rn, rm
let rd = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
Expand Down Expand Up @@ -2044,12 +2041,10 @@ pub(crate) fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
cond
} else if op == Opcode::Trapif {
let condcode = ctx.data(insn).cond_code().unwrap();
let cond = lower_condcode(condcode);

// Verification ensures that the input is always a single-def ifcmp.
let ifcmp_insn = maybe_input_insn(ctx, inputs[0], Opcode::Ifcmp).unwrap();
lower_icmp(ctx, ifcmp_insn, condcode, IcmpOutput::Flags)?;
cond
lower_icmp(ctx, ifcmp_insn, condcode, IcmpOutput::CondCode)?.unwrap_cond()
} else {
let condcode = ctx.data(insn).fp_cond_code().unwrap();
let cond = lower_fp_condcode(condcode);
Expand Down Expand Up @@ -3603,11 +3598,11 @@ pub(crate) fn lower_branch<C: LowerCtx<I = Inst>>(
maybe_input_insn_via_conv(ctx, flag_input, Opcode::Icmp, Opcode::Bint)
{
let condcode = ctx.data(icmp_insn).cond_code().unwrap();
let cond = lower_condcode(condcode);
let cond =
lower_icmp(ctx, icmp_insn, condcode, IcmpOutput::CondCode)?.unwrap_cond();
let negated = op0 == Opcode::Brz;
let cond = if negated { cond.invert() } else { cond };

lower_icmp(ctx, icmp_insn, condcode, IcmpOutput::Flags)?;
ctx.emit(Inst::CondBr {
taken,
not_taken,
Expand Down Expand Up @@ -3655,32 +3650,30 @@ pub(crate) fn lower_branch<C: LowerCtx<I = Inst>>(
}
Opcode::BrIcmp => {
let condcode = ctx.data(branches[0]).cond_code().unwrap();
let cond = lower_condcode(condcode);
let kind = CondBrKind::Cond(cond);
let cond =
lower_icmp(ctx, branches[0], condcode, IcmpOutput::CondCode)?.unwrap_cond();

lower_icmp(ctx, branches[0], condcode, IcmpOutput::Flags)?;
ctx.emit(Inst::CondBr {
taken,
not_taken,
kind,
kind: CondBrKind::Cond(cond),
});
}

Opcode::Brif => {
let condcode = ctx.data(branches[0]).cond_code().unwrap();
let cond = lower_condcode(condcode);
let kind = CondBrKind::Cond(cond);

let flag_input = InsnInput {
insn: branches[0],
input: 0,
};
if let Some(ifcmp_insn) = maybe_input_insn(ctx, flag_input, Opcode::Ifcmp) {
lower_icmp(ctx, ifcmp_insn, condcode, IcmpOutput::Flags)?;
let cond =
lower_icmp(ctx, ifcmp_insn, condcode, IcmpOutput::CondCode)?.unwrap_cond();
ctx.emit(Inst::CondBr {
taken,
not_taken,
kind,
kind: CondBrKind::Cond(cond),
});
} else {
// If the ifcmp result is actually placed in a
Expand All @@ -3690,7 +3683,7 @@ pub(crate) fn lower_branch<C: LowerCtx<I = Inst>>(
ctx.emit(Inst::CondBr {
taken,
not_taken,
kind,
kind: CondBrKind::Cond(lower_condcode(condcode)),
});
}
}
Expand Down
Loading