Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,12 @@ enum NodeType {
/// separately rounded operations.
FMAD,

/// FMULADD - Performs a * b + c, with or without intermediate rounding.
/// It is expected that this will be illegal for most targets, as it usually
/// makes sense to split this or use an FMA. But some targets, such as
/// WebAssembly, can directly support these semantics.
FMULADD,

/// FCOPYSIGN(X, Y) - Return the value of X with the sign of Y. NOTE: This
/// DAG node does not require that X and Y have the same type, just that
/// they are both floating point. X and the result must have the same type.
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def fdiv : SDNode<"ISD::FDIV" , SDTFPBinOp>;
def frem : SDNode<"ISD::FREM" , SDTFPBinOp>;
def fma : SDNode<"ISD::FMA" , SDTFPTernaryOp, [SDNPCommutative]>;
def fmad : SDNode<"ISD::FMAD" , SDTFPTernaryOp, [SDNPCommutative]>;
def fmuladd : SDNode<"ISD::FMULADD" , SDTFPTernaryOp, [SDNPCommutative]>;
def fabs : SDNode<"ISD::FABS" , SDTFPUnaryOp>;
def fminnum : SDNode<"ISD::FMINNUM" , SDTFPBinOp,
[SDNPCommutative, SDNPAssociative]>;
Expand Down
17 changes: 17 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ namespace {
SDValue visitFMUL(SDNode *N);
template <class MatchContextClass> SDValue visitFMA(SDNode *N);
SDValue visitFMAD(SDNode *N);
SDValue visitFMULADD(SDNode *N);
SDValue visitFDIV(SDNode *N);
SDValue visitFREM(SDNode *N);
SDValue visitFSQRT(SDNode *N);
Expand Down Expand Up @@ -1991,6 +1992,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::FMUL: return visitFMUL(N);
case ISD::FMA: return visitFMA<EmptyMatchContext>(N);
case ISD::FMAD: return visitFMAD(N);
case ISD::FMULADD: return visitFMULADD(N);
case ISD::FDIV: return visitFDIV(N);
case ISD::FREM: return visitFREM(N);
case ISD::FSQRT: return visitFSQRT(N);
Expand Down Expand Up @@ -18444,6 +18446,21 @@ SDValue DAGCombiner::visitFMAD(SDNode *N) {
return SDValue();
}

/// Try to simplify an ISD::FMULADD node.
///
/// The only combine attempted here is constant folding of the node's three
/// operands via SelectionDAG::FoldConstantArithmetic. Returns the folded
/// value when folding succeeds, or an empty SDValue when it does not.
SDValue DAGCombiner::visitFMULADD(SDNode *N) {
  // Operands in node order: operand 0, operand 1, operand 2.
  SDValue Operands[] = {N->getOperand(0), N->getOperand(1), N->getOperand(2)};
  EVT ResultVT = N->getValueType(0);
  SDLoc Loc(N);

  // Constant fold FMULADD.
  SDValue Folded =
      DAG.FoldConstantArithmetic(ISD::FMULADD, Loc, ResultVT, Operands);
  if (Folded)
    return Folded;

  return SDValue();
}

// Combine multiple FDIVs with the same divisor into multiple FMULs by the
// reciprocal.
// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5786,6 +5786,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
case ISD::FCOPYSIGN:
case ISD::FMA:
case ISD::FMAD:
case ISD::FMULADD:
case ISD::FP_EXTEND:
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
Expand Down Expand Up @@ -5904,6 +5905,7 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, const APInt &DemandedElts,
case ISD::FCOSH:
case ISD::FTANH:
case ISD::FMA:
case ISD::FMULADD:
case ISD::FMAD: {
if (SNaN)
return true;
Expand Down Expand Up @@ -7231,7 +7233,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
}

// Handle fma/fmad/fmuladd special cases.
if (Opcode == ISD::FMA || Opcode == ISD::FMAD) {
if (Opcode == ISD::FMA || Opcode == ISD::FMAD || Opcode == ISD::FMULADD) {
assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
assert(Ops[0].getValueType() == VT && Ops[1].getValueType() == VT &&
Ops[2].getValueType() == VT && "FMA types must match!");
Expand All @@ -7242,7 +7244,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
APFloat V1 = C1->getValueAPF();
const APFloat &V2 = C2->getValueAPF();
const APFloat &V3 = C3->getValueAPF();
if (Opcode == ISD::FMAD) {
if (Opcode == ISD::FMAD || Opcode == ISD::FMULADD) {
V1.multiply(V2, APFloat::rmNearestTiesToEven);
V1.add(V3, APFloat::rmNearestTiesToEven);
} else
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6996,6 +6996,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
getValue(I.getArgOperand(0)),
getValue(I.getArgOperand(1)),
getValue(I.getArgOperand(2)), Flags));
} else if (TLI.isOperationLegalOrCustom(ISD::FMULADD, VT)) {
// TODO: Support splitting the vector.
setValue(&I, DAG.getNode(ISD::FMULADD, sdl,
getValue(I.getArgOperand(0)).getValueType(),
getValue(I.getArgOperand(0)),
getValue(I.getArgOperand(1)),
getValue(I.getArgOperand(2)), Flags));
} else {
// TODO: Intrinsic calls should have fast-math-flags.
SDValue Mul = DAG.getNode(
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::FMA: return "fma";
case ISD::STRICT_FMA: return "strict_fma";
case ISD::FMAD: return "fmad";
case ISD::FMULADD: return "fmuladd";
case ISD::FREM: return "frem";
case ISD::STRICT_FREM: return "strict_frem";
case ISD::FCOPYSIGN: return "fcopysign";
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7676,6 +7676,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
break;
}
case ISD::FMA:
case ISD::FMULADD:
case ISD::FMAD: {
if (!Flags.hasNoSignedZeros())
break;
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,8 @@ void TargetLoweringBase::initActions() {
ISD::FTAN, ISD::FACOS,
ISD::FASIN, ISD::FATAN,
ISD::FCOSH, ISD::FSINH,
ISD::FTANH, ISD::FATAN2},
ISD::FTANH, ISD::FATAN2,
ISD::FMULADD},
VT, Expand);

// Overflow operations default to expand
Expand Down
21 changes: 21 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,15 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
}

if (Subtarget->hasFP16()) {
setOperationAction(ISD::FMA, MVT::v8f16, Legal);
}

if (Subtarget->hasRelaxedSIMD()) {
setOperationAction(ISD::FMULADD, MVT::v4f32, Legal);
setOperationAction(ISD::FMULADD, MVT::v2f64, Legal);
}

// Partial MLA reductions.
for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal);
Expand Down Expand Up @@ -1120,6 +1129,18 @@ WebAssemblyTargetLowering::getPreferredVectorAction(MVT VT) const {
return TargetLoweringBase::getPreferredVectorAction(VT);
}

bool WebAssemblyTargetLowering::isFMAFasterThanFMulAndFAdd(
const MachineFunction &MF, EVT VT) const {
if (!Subtarget->hasFP16() || !VT.isVector())
return false;

EVT ScalarVT = VT.getScalarType();
if (!ScalarVT.isSimple())
return false;

return ScalarVT.getSimpleVT().SimpleTy == MVT::f16;
}

bool WebAssemblyTargetLowering::shouldSimplifyDemandedVectorElts(
SDValue Op, const TargetLoweringOpt &TLO) const {
// ISel process runs DAGCombiner after legalization; this step is called
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class WebAssemblyTargetLowering final : public TargetLowering {

TargetLoweringBase::LegalizeTypeAction
getPreferredVectorAction(MVT VT) const override;
bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
EVT VT) const override;

SDValue LowerCall(CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const override;
Expand Down
41 changes: 33 additions & 8 deletions llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
Original file line number Diff line number Diff line change
Expand Up @@ -1626,7 +1626,8 @@ defm "" : RelaxedConvert<I32x4, F64x2, int_wasm_relaxed_trunc_unsigned_zero,
// Relaxed (Negative) Multiply-Add (madd/nmadd)
//===----------------------------------------------------------------------===//

multiclass SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS, list<Predicate> reqs> {
multiclass RELAXED_SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS,
list<Predicate> reqs> {
defm MADD_#vec :
SIMD_I<(outs V128:$dst), (ins V128:$a, V128:$b, V128:$c), (outs), (ins),
[(set (vec.vt V128:$dst), (int_wasm_relaxed_madd
Expand All @@ -1640,16 +1641,40 @@ multiclass SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS, list<Predicate>
vec.prefix#".relaxed_nmadd\t$dst, $a, $b, $c",
vec.prefix#".relaxed_nmadd", simdopS, reqs>;

def : Pat<(fadd_contract (vec.vt V128:$a), (fmul_contract (vec.vt V128:$b), (vec.vt V128:$c))),
(!cast<Instruction>("MADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<[HasRelaxedSIMD]>;
def : Pat<(fadd_contract (fmul_contract (vec.vt V128:$a), (vec.vt V128:$b)), (vec.vt V128:$c)),
(!cast<Instruction>("MADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>;
def : Pat<(fmuladd (vec.vt V128:$a), (vec.vt V128:$b), (vec.vt V128:$c)),
(!cast<Instruction>("MADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>;

def : Pat<(fsub_contract (vec.vt V128:$a), (fmul_contract (vec.vt V128:$b), (vec.vt V128:$c))),
(!cast<Instruction>("NMADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<[HasRelaxedSIMD]>;
def : Pat<(fsub_contract (vec.vt V128:$c), (fmul_contract (vec.vt V128:$a), (vec.vt V128:$b))),
(!cast<Instruction>("NMADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>;
def : Pat<(fmuladd (fneg (vec.vt V128:$a)), (vec.vt V128:$b), (vec.vt V128:$c)),
(!cast<Instruction>("NMADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>;
}

defm "" : SIMDMADD<F32x4, 0x105, 0x106, [HasRelaxedSIMD]>;
defm "" : SIMDMADD<F64x2, 0x107, 0x108, [HasRelaxedSIMD]>;
defm "" : SIMDMADD<F16x8, 0x14e, 0x14f, [HasFP16]>;
defm "" : RELAXED_SIMDMADD<F32x4, 0x105, 0x106, [HasRelaxedSIMD]>;
defm "" : RELAXED_SIMDMADD<F64x2, 0x107, 0x108, [HasRelaxedSIMD]>;

//===----------------------------------------------------------------------===//
// FP16 (Negative) Multiply-Add (madd/nmadd)
//===----------------------------------------------------------------------===//

// Defines a MADD_<vec> / NMADD_<vec> instruction pair for the vector type
// `vec`. Both take three V128 operands ($a, $b, $c) and produce one V128
// result.
//   simdopA - opcode passed to the MADD definition.
//   simdopS - opcode passed to the NMADD definition.
//   reqs    - predicate list passed to both definitions.
multiclass HALF_PRECISION_SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS,
list<Predicate> reqs> {
// MADD: selected directly from the generic `fma` node, (fma $a, $b, $c).
defm MADD_#vec :
SIMD_I<(outs V128:$dst), (ins V128:$a, V128:$b, V128:$c), (outs), (ins),
[(set (vec.vt V128:$dst), (fma
(vec.vt V128:$a), (vec.vt V128:$b), (vec.vt V128:$c)))],
vec.prefix#".madd\t$dst, $a, $b, $c",
vec.prefix#".madd", simdopA, reqs>;
// NMADD: selected from an `fma` whose first operand is negated,
// (fma (fneg $a), $b, $c).
defm NMADD_#vec :
SIMD_I<(outs V128:$dst), (ins V128:$a, V128:$b, V128:$c), (outs), (ins),
[(set (vec.vt V128:$dst), (fma
(fneg (vec.vt V128:$a)), (vec.vt V128:$b), (vec.vt V128:$c)))],
vec.prefix#".nmadd\t$dst, $a, $b, $c",
vec.prefix#".nmadd", simdopS, reqs>;
}
// Instantiate for F16x8 with opcodes 0x14e (madd) and 0x14f (nmadd), gated on
// HasFP16.
defm "" : HALF_PRECISION_SIMDMADD<F16x8, 0x14e, 0x14f, [HasFP16]>;

//===----------------------------------------------------------------------===//
// Laneselect
Expand Down
Loading