Skip to content

Commit

Permalink
[ARM] Armv8.2-A FP16 code generation (part 2/3)
Browse files Browse the repository at this point in the history
Half-precision arguments and return values are passed as if they were an int or
float for ARM. This results in truncates and bitcasts to/from i16 and f16
values, which are legalized very early to stack stores/loads. When FullFP16 is
enabled, we want to avoid codegen for these bitcasts as it is unnecessary and
inefficient.

Differential Revision: https://reviews.llvm.org/D42580

llvm-svn: 323861
  • Loading branch information
Sjoerd Meijer committed Jan 31, 2018
1 parent c209180 commit 98d5359
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 44 deletions.
104 changes: 73 additions & 31 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
Expand Up @@ -524,9 +524,9 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,

if (Subtarget->hasFullFP16()) {
addRegisterClass(MVT::f16, &ARM::HPRRegClass);
// Clean up bitcast of incoming arguments if hard float abi is enabled.
if (Subtarget->isTargetHardFloat())
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
setOperationAction(ISD::BITCAST, MVT::i32, Custom);
setOperationAction(ISD::BITCAST, MVT::f16, Custom);
}

for (MVT VT : MVT::vector_valuetypes()) {
Expand Down Expand Up @@ -1273,6 +1273,8 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {

case ARMISD::VMOVRRD: return "ARMISD::VMOVRRD";
case ARMISD::VMOVDRR: return "ARMISD::VMOVDRR";
case ARMISD::VMOVhr: return "ARMISD::VMOVhr";
case ARMISD::VMOVrh: return "ARMISD::VMOVrh";

case ARMISD::EH_SJLJ_SETJMP: return "ARMISD::EH_SJLJ_SETJMP";
case ARMISD::EH_SJLJ_LONGJMP: return "ARMISD::EH_SJLJ_LONGJMP";
Expand Down Expand Up @@ -5051,7 +5053,8 @@ static SDValue CombineVMOVDRRCandidateWithVecOp(const SDNode *BC,
/// use a VMOVDRR or VMOVRRD node. This should not be done when the non-i64
/// operand type is illegal (e.g., v2f32 for a target that doesn't support
/// vectors), since the legalizer won't know what to do with that.
static SDValue ExpandBITCAST(SDNode *N, SelectionDAG &DAG) {
static SDValue ExpandBITCAST(SDNode *N, SelectionDAG &DAG,
const ARMSubtarget *Subtarget) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDLoc dl(N);
SDValue Op = N->getOperand(0);
Expand All @@ -5060,39 +5063,78 @@ static SDValue ExpandBITCAST(SDNode *N, SelectionDAG &DAG) {
// source or destination of the bit convert.
EVT SrcVT = Op.getValueType();
EVT DstVT = N->getValueType(0);
const bool HasFullFP16 = Subtarget->hasFullFP16();

// Half-precision arguments can be passed in like this:
//
// t4: f32,ch = CopyFromReg t0, Register:f32 %1
// t8: i32 = bitcast t4
// t9: i16 = truncate t8
// t10: f16 = bitcast t9 <~~~~ SDNode N
//
// but we want to avoid code generation for the bitcast, so transform this
// into:
//
// t18: f16 = CopyFromReg t0, Register:f32 %0
//
if (SrcVT == MVT::i16 && DstVT == MVT::f16) {
if (Op.getOpcode() != ISD::TRUNCATE)
return SDValue();
if (SrcVT == MVT::f32 && DstVT == MVT::i32) {
// FullFP16: half values are passed in S-registers, and we don't
// need any of the bitcast and moves:
//
// t2: f32,ch = CopyFromReg t0, Register:f32 %0
// t5: i32 = bitcast t2
// t18: f16 = ARMISD::VMOVhr t5
if (Op.getOpcode() != ISD::CopyFromReg ||
Op.getValueType() != MVT::f32)
return SDValue();

auto Move = N->use_begin();
if (Move->getOpcode() != ARMISD::VMOVhr)
return SDValue();

SDValue Ops[] = { Op.getOperand(0), Op.getOperand(1) };
SDValue Copy = DAG.getNode(ISD::CopyFromReg, SDLoc(Op), MVT::f16, Ops);
DAG.ReplaceAllUsesWith(*Move, &Copy);
return Copy;
}

SDValue Bitcast = Op.getOperand(0);
if (Bitcast.getOpcode() != ISD::BITCAST ||
Bitcast.getValueType() != MVT::i32)
if (SrcVT == MVT::i16 && DstVT == MVT::f16) {
if (!HasFullFP16)
return SDValue();
// SoftFP: read half-precision arguments:
//
// t2: i32,ch = ...
// t7: i16 = truncate t2 <~~~~ Op
// t8: f16 = bitcast t7 <~~~~ N
//
if (Op.getOperand(0).getValueType() == MVT::i32)
return DAG.getNode(ARMISD::VMOVhr, SDLoc(Op),
MVT::f16, Op.getOperand(0));

return SDValue();
}

SDValue Copy = Bitcast.getOperand(0);
if (Copy.getOpcode() != ISD::CopyFromReg ||
Copy.getValueType() != MVT::f32)
// Half-precision return values
if (SrcVT == MVT::f16 && DstVT == MVT::i16) {
if (!HasFullFP16)
return SDValue();
//
// t11: f16 = fadd t8, t10
// t12: i16 = bitcast t11 <~~~ SDNode N
// t13: i32 = zero_extend t12
// t16: ch,glue = CopyToReg t0, Register:i32 %r0, t13
// t17: ch = ARMISD::RET_FLAG t16, Register:i32 %r0, t16:1
//
// transform this into:
//
// t20: i32 = ARMISD::VMOVrh t11
// t16: ch,glue = CopyToReg t0, Register:i32 %r0, t20
//
auto ZeroExtend = N->use_begin();
if (N->use_size() != 1 || ZeroExtend->getOpcode() != ISD::ZERO_EXTEND ||
ZeroExtend->getValueType(0) != MVT::i32)
return SDValue();

SDValue Ops[] = { Copy->getOperand(0), Copy->getOperand(1) };
return DAG.getNode(ISD::CopyFromReg, SDLoc(Copy), MVT::f16, Ops);
auto Copy = ZeroExtend->use_begin();
if (Copy->getOpcode() == ISD::CopyToReg &&
Copy->use_begin()->getOpcode() == ARMISD::RET_FLAG) {
SDValue Cvt = DAG.getNode(ARMISD::VMOVrh, SDLoc(Op), MVT::i32, Op);
DAG.ReplaceAllUsesWith(*ZeroExtend, &Cvt);
return Cvt;
}
return SDValue();
}

assert((SrcVT == MVT::i64 || DstVT == MVT::i64) &&
"ExpandBITCAST called for non-i64 type");
if (!(SrcVT == MVT::i64 || DstVT == MVT::i64))
return SDValue();

// Turn i64->f64 into VMOVDRR.
if (SrcVT == MVT::i64 && TLI.isTypeLegal(DstVT)) {
Expand Down Expand Up @@ -7982,7 +8024,7 @@ SDValue ARMTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::EH_SJLJ_SETUP_DISPATCH: return LowerEH_SJLJ_SETUP_DISPATCH(Op, DAG);
case ISD::INTRINSIC_WO_CHAIN: return LowerINTRINSIC_WO_CHAIN(Op, DAG,
Subtarget);
case ISD::BITCAST: return ExpandBITCAST(Op.getNode(), DAG);
case ISD::BITCAST: return ExpandBITCAST(Op.getNode(), DAG, Subtarget);
case ISD::SHL:
case ISD::SRL:
case ISD::SRA: return LowerShift(Op.getNode(), DAG, Subtarget);
Expand Down Expand Up @@ -8084,7 +8126,7 @@ void ARMTargetLowering::ReplaceNodeResults(SDNode *N,
ExpandREAD_REGISTER(N, Results, DAG);
break;
case ISD::BITCAST:
Res = ExpandBITCAST(N, DAG);
Res = ExpandBITCAST(N, DAG, Subtarget);
break;
case ISD::SRL:
case ISD::SRA:
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/ARM/ARMISelLowering.h
Expand Up @@ -171,6 +171,10 @@ class VectorType;
// Vector move f32 immediate:
VMOVFPIMM,

// Move H <-> R, clearing top 16 bits
VMOVrh,
VMOVhr,

// Vector duplicate:
VDUP,
VDUPLANE,
Expand Down
13 changes: 9 additions & 4 deletions llvm/lib/Target/ARM/ARMInstrVFP.td
Expand Up @@ -23,6 +23,11 @@ def arm_cmpfp0 : SDNode<"ARMISD::CMPFPw0", SDT_CMPFP0, [SDNPOutGlue]>;
def arm_fmdrr : SDNode<"ARMISD::VMOVDRR", SDT_VMOVDRR>;
def arm_fmrrd : SDNode<"ARMISD::VMOVRRD", SDT_VMOVRRD>;

// Type profiles for the half <-> GPR move nodes: each has one result and one
// operand. VMOVhr produces an FP value from an i32 operand; VMOVrh is the
// inverse (i32 result from an FP operand).
def SDT_VMOVhr : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVT<1, i32>] >;
def SDT_VMOVrh : SDTypeProfile<1, 1, [SDTCisVT<0, i32>, SDTCisFP<1>] >;
// SelectionDAG nodes for ARMISD::VMOVhr / ARMISD::VMOVrh, matched by the
// VMOVHR / VMOVRH instruction patterns below.
def arm_vmovhr : SDNode<"ARMISD::VMOVhr", SDT_VMOVhr>;
def arm_vmovrh : SDNode<"ARMISD::VMOVrh", SDT_VMOVrh>;

//===----------------------------------------------------------------------===//
// Operand Definitions.
//
Expand Down Expand Up @@ -1171,9 +1176,9 @@ def VMOVSRR : AVConv5I<0b11000100, 0b1010,

// Move H->R, clearing top 16 bits
def VMOVRH : AVConv2I<0b11100001, 0b1001,
(outs GPR:$Rt), (ins SPR:$Sn),
(outs GPR:$Rt), (ins HPR:$Sn),
IIC_fpMOVSI, "vmov", ".f16\t$Rt, $Sn",
[]>,
[(set GPR:$Rt, (arm_vmovrh HPR:$Sn))]>,
Requires<[HasFullFP16]>,
Sched<[WriteFPMOV]> {
// Instruction operands.
Expand All @@ -1191,9 +1196,9 @@ def VMOVRH : AVConv2I<0b11100001, 0b1001,

// Move R->H, clearing top 16 bits
def VMOVHR : AVConv4I<0b11100000, 0b1001,
(outs SPR:$Sn), (ins GPR:$Rt),
(outs HPR:$Sn), (ins GPR:$Rt),
IIC_fpMOVIS, "vmov", ".f16\t$Sn, $Rt",
[]>,
[(set HPR:$Sn, (arm_vmovhr GPR:$Rt))]>,
Requires<[HasFullFP16]>,
Sched<[WriteFPMOV]> {
// Instruction operands.
Expand Down
14 changes: 5 additions & 9 deletions llvm/test/CodeGen/ARM/fp16-instructions.ll
Expand Up @@ -43,14 +43,11 @@ entry:
; CHECK-SOFTFP-FP16: vcvtb.f16.f32 [[S0]], [[S0]]
; CHECK-SOFTFP-FP16: vmov r0, s0

; CHECK-SOFTFP-FULLFP16: strh r1, {{.*}}
; CHECK-SOFTFP-FULLFP16: strh r0, {{.*}}
; CHECK-SOFTFP-FULLFP16: vldr.16 [[S0:s[0-9]]], {{.*}}
; CHECK-SOFTFP-FULLFP16: vldr.16 [[S2:s[0-9]]], {{.*}}
; CHECK-SOFTFP-FULLFP16: vadd.f16 [[S0]], [[S2]], [[S0]]
; CHECK-SOFTFP-FULLFP16: vstr.16 [[S2:s[0-9]]], {{.*}}
; CHECK-SOFTFP-FULLFP16: ldrh r0, {{.*}}
; CHECK-SOFTFP-FULLFP16: mov pc, lr
; CHECK-SOFTFP-FULLFP16: vmov.f16 [[S0:s[0-9]]], r1
; CHECK-SOFTFP-FULLFP16: vmov.f16 [[S2:s[0-9]]], r0
; CHECK-SOFTFP-FULLFP16: vadd.f16 [[S0]], [[S2]], [[S0]]
; CHECK-SOFTFP-FULLFP16-NEXT: vmov.f16 r0, s0
; CHECK-SOFTFP-FULLFP16-NEXT: mov pc, lr

; CHECK-HARDFP-VFP3: vmov r{{.}}, s0
; CHECK-HARDFP-VFP3: vmov{{.*}}, s1
Expand All @@ -69,4 +66,3 @@ entry:
; CHECK-HARDFP-FULLFP16-NEXT: mov pc, lr

}

0 comments on commit 98d5359

Please sign in to comment.