Expand Up
@@ -464,33 +464,45 @@ def SHL2MUL16 : SDNodeXForm<imm, [{
return CurDAG->getTargetConstant(temp.shl(v), MVT::i16);
}]>;
// Signed widening multiply producing a 64-bit result: every def in this group
// prints "mul.wide.s32", writes an Int64Regs destination and takes an Int32Regs
// first operand. The second operand is a register, an i32imm, or an i64imm.
// None of these defs carries an inline selection pattern (the list is []).
// NOTE(review): this span appears to interleave pre- and post-change lines of
// a diff (a def header is directly followed by another def header) - confirm
// which lines are live before editing.
def MULWIDES64 : NVPTXInst<(outs Int64Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b),
def MULWIDES64
: NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
"mul.wide.s32 \t$dst, $a, $b;", []>;
// Register x i32imm form.
def MULWIDES64Imm
: NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
"mul.wide.s32 \t$dst, $a, $b;", []>;
// Register x i64imm form.
def MULWIDES64Imm : NVPTXInst<(outs Int64Regs:$dst),
(ins Int32Regs:$a, i64imm:$b),
def MULWIDES64Imm64
: NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i64imm:$b),
"mul.wide.s32 \t$dst, $a, $b;", []>;
// Unsigned counterpart of the MULWIDES64 group: every def prints
// "mul.wide.u32", writes an Int64Regs destination and takes an Int32Regs first
// operand; the second operand is a register, an i32imm, or an i64imm.
// No inline selection patterns are attached (the list is []).
// NOTE(review): pre- and post-change diff lines appear to be interleaved here
// - confirm which lines are live.
def MULWIDEU64 : NVPTXInst<(outs Int64Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b),
def MULWIDEU64
: NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
"mul.wide.u32 \t$dst, $a, $b;", []>;
// Register x i32imm form.
def MULWIDEU64Imm
: NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
"mul.wide.u32 \t$dst, $a, $b;", []>;
// Register x i64imm form.
def MULWIDEU64Imm : NVPTXInst<(outs Int64Regs:$dst),
(ins Int32Regs:$a, i64imm:$b),
def MULWIDEU64Imm64
: NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i64imm:$b),
"mul.wide.u32 \t$dst, $a, $b;", []>;
// Signed widening multiply producing a 32-bit result: every def prints
// "mul.wide.s16", writes an Int32Regs destination and takes an Int16Regs first
// operand; the second operand is a register, an i16imm, or an i32imm.
// No inline selection patterns are attached (the list is []).
// NOTE(review): pre- and post-change diff lines appear to be interleaved here
// - confirm which lines are live.
def MULWIDES32 : NVPTXInst<(outs Int32Regs:$dst),
(ins Int16Regs:$a, Int16Regs:$b),
def MULWIDES32
: NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b),
"mul.wide.s16 \t$dst, $a, $b;", []>;
def MULWIDES32Imm : NVPTXInst<(outs Int32Regs:$dst),
(ins Int16Regs:$a, i32imm:$b),
// Register x i16imm form.
def MULWIDES32Imm
: NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i16imm:$b),
"mul.wide.s16 \t$dst, $a, $b;", []>;
// Register x i32imm form.
def MULWIDES32Imm32
: NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i32imm:$b),
"mul.wide.s16 \t$dst, $a, $b;", []>;
// Unsigned counterpart of the MULWIDES32 group: every def prints
// "mul.wide.u16", writes an Int32Regs destination and takes an Int16Regs first
// operand; the second operand is a register, an i16imm, or an i32imm.
// No inline selection patterns are attached (the list is []).
// NOTE(review): MULWIDEU32 appears twice and an Imm header is left dangling,
// which looks like interleaved pre-/post-change diff lines - confirm which
// lines are live.
def MULWIDEU32 : NVPTXInst<(outs Int32Regs:$dst),
(ins Int16Regs:$a, Int16Regs:$b),
"mul.wide.u16 \t$dst, $a, $b;", []>;
def MULWIDEU32Imm : NVPTXInst<(outs Int32Regs:$dst),
(ins Int16Regs:$a, i32imm :$b),
def MULWIDEU32
: NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b),
"mul.wide.u16 \t$dst, $a, $b;", []>;
// Register x i16imm form.
def MULWIDEU32Imm
: NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i16imm :$b),
"mul.wide.u16 \t$dst, $a, $b;", []>;
// Register x i32imm form.
def MULWIDEU32Imm32
: NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i32imm:$b),
"mul.wide.u16 \t$dst, $a, $b;", []>;
def : Pat<(shl (sext Int32Regs:$a), (i32 Int5Const:$b)),
(MULWIDES64Imm Int32Regs:$a, (SHL2MUL32 node:$b))>,
Expand All
@@ -510,25 +522,63 @@ def : Pat<(mul (sext Int32Regs:$a), (sext Int32Regs:$b)),
(MULWIDES64 Int32Regs:$a, Int32Regs:$b)>,
Requires<[doMulWide]>;
// Selection patterns that turn a plain `mul` into a widening multiply when its
// first operand is an extension of a narrower register:
//   sext operand -> signed MULWIDES* instruction
//   zext operand -> unsigned MULWIDEU* instruction
// The second operand is either a matching extension or a constant matched by
// SInt32Const / UInt32Const / SInt16Const / UInt16Const (defined elsewhere).
// Every pattern is gated on the doMulWide predicate.
// NOTE(review): several patterns below show two result lines (e.g. a
// MULWIDES64Imm line followed by a MULWIDES64Imm64 line), which looks like
// interleaved pre-/post-change diff lines - confirm which result is live.
def : Pat<(mul (sext Int32Regs:$a), (i64 SInt32Const:$b)),
(MULWIDES64Imm Int32Regs:$a, (i64 SInt32Const:$b))>,
(MULWIDES64Imm64 Int32Regs:$a, (i64 SInt32Const:$b))>,
Requires<[doMulWide]>;
def : Pat<(mul (zext Int32Regs:$a), (zext Int32Regs:$b)),
(MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>, Requires<[doMulWide]>;
(MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>,
Requires<[doMulWide]>;
def : Pat<(mul (zext Int32Regs:$a), (i64 UInt32Const:$b)),
(MULWIDEU64Imm Int32Regs:$a, (i64 UInt32Const:$b))>,
(MULWIDEU64Imm64 Int32Regs:$a, (i64 UInt32Const:$b))>,
Requires<[doMulWide]>;
def : Pat<(mul (sext Int16Regs:$a), (sext Int16Regs:$b)),
(MULWIDES32 Int16Regs:$a, Int16Regs:$b)>, Requires<[doMulWide]>;
(MULWIDES32 Int16Regs:$a, Int16Regs:$b)>,
Requires<[doMulWide]>;
def : Pat<(mul (sext Int16Regs:$a), (i32 SInt16Const:$b)),
(MULWIDES32Imm Int16Regs:$a, (i32 SInt16Const:$b))>,
(MULWIDES32Imm32 Int16Regs:$a, (i32 SInt16Const:$b))>,
Requires<[doMulWide]>;
def : Pat<(mul (zext Int16Regs:$a), (zext Int16Regs:$b)),
(MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>, Requires<[doMulWide]>;
(MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>,
Requires<[doMulWide]>;
def : Pat<(mul (zext Int16Regs:$a), (i32 UInt16Const:$b)),
(MULWIDEU32Imm Int16Regs:$a, (i32 UInt16Const:$b))>,
(MULWIDEU32Imm32 Int16Regs:$a, (i32 UInt16Const:$b))>,
Requires<[doMulWide]>;
// Type profile shared by the widening-multiply DAG nodes: one result and two
// operands, with the two operands constrained to the same type.
def SDTMulWide : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>;

// Target DAG nodes for the signed and unsigned widening multiply.
def mul_wide_signed   : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
def mul_wide_unsigned : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
// Select the mul_wide_signed / mul_wide_unsigned nodes onto the MULWIDE*
// instructions. Each width has a register-register and a register-immediate
// form, and every pattern is gated on the doMulWide predicate.

// 16-bit operands, 32-bit result.
def : Pat<(i32 (mul_wide_signed Int16Regs:$a, Int16Regs:$b)),
          (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>, Requires<[doMulWide]>;
def : Pat<(i32 (mul_wide_signed Int16Regs:$a, imm:$b)),
          (MULWIDES32Imm Int16Regs:$a, imm:$b)>, Requires<[doMulWide]>;
def : Pat<(i32 (mul_wide_unsigned Int16Regs:$a, Int16Regs:$b)),
          (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>, Requires<[doMulWide]>;
def : Pat<(i32 (mul_wide_unsigned Int16Regs:$a, imm:$b)),
          (MULWIDEU32Imm Int16Regs:$a, imm:$b)>, Requires<[doMulWide]>;

// 32-bit operands, 64-bit result.
def : Pat<(i64 (mul_wide_signed Int32Regs:$a, Int32Regs:$b)),
          (MULWIDES64 Int32Regs:$a, Int32Regs:$b)>, Requires<[doMulWide]>;
def : Pat<(i64 (mul_wide_signed Int32Regs:$a, imm:$b)),
          (MULWIDES64Imm Int32Regs:$a, imm:$b)>, Requires<[doMulWide]>;
def : Pat<(i64 (mul_wide_unsigned Int32Regs:$a, Int32Regs:$b)),
          (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>, Requires<[doMulWide]>;
def : Pat<(i64 (mul_wide_unsigned Int32Regs:$a, imm:$b)),
          (MULWIDEU64Imm Int32Regs:$a, imm:$b)>, Requires<[doMulWide]>;
// Instantiates the I3 multiclass (defined elsewhere) with the "mul.lo.s"
// mnemonic prefix and the generic `mul` node.
defm MULT : I3<"mul.lo.s", mul>;
Expand All
@@ -544,69 +594,75 @@ defm SREM : I3<"rem.s", srem>;
// Instantiates the I3 multiclass (defined elsewhere) with the "rem.u"
// mnemonic prefix and the generic `urem` node.
defm UREM : I3<"rem.u", urem>;
// The ri version will not be selected as DAGCombiner::visitUREM will lower it.
// Type profile for the integer multiply-add node: one integer result and
// three operands, each constrained to the same type as the result.
def SDTIMAD : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>,
                                   SDTCisInt<2>, SDTCisSameAs<0, 2>,
                                   SDTCisSameAs<0, 3>]>;

// Target DAG node matched by the MAD* instruction patterns.
def imad : SDNode<"NVPTXISD::IMAD", SDTIMAD>;
// 16-bit multiply-add instructions: every def prints "mad.lo.s16" with an
// Int16Regs destination and three sources $a, $b, $c. The name suffix gives
// the kind of each source in order (r = Int16Regs, i = i16imm).
// NOTE(review): each def below carries two pattern bodies, one built from
// (add (mul ...)) and one from the `imad` node, which looks like interleaved
// pre-/post-change diff lines - confirm which pattern is live.
def MAD16rrr : NVPTXInst<(outs Int16Regs:$dst),
(ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
"mad.lo.s16 \t$dst, $a, $b, $c;",
[(set Int16Regs:$dst, (add
(mul Int16Regs:$a, Int16Regs:$b) , Int16Regs:$c))]>;
[(set Int16Regs:$dst,
(imad Int16Regs:$a, Int16Regs:$b, Int16Regs:$c))]>;
def MAD16rri : NVPTXInst<(outs Int16Regs:$dst),
(ins Int16Regs:$a, Int16Regs:$b, i16imm:$c),
"mad.lo.s16 \t$dst, $a, $b, $c;",
[(set Int16Regs:$dst, (add
(mul Int16Regs:$a, Int16Regs:$b) , imm:$c))]>;
[(set Int16Regs:$dst,
(imad Int16Regs:$a, Int16Regs:$b, imm:$c))]>;
def MAD16rir : NVPTXInst<(outs Int16Regs:$dst),
(ins Int16Regs:$a, i16imm:$b, Int16Regs:$c),
"mad.lo.s16 \t$dst, $a, $b, $c;",
[(set Int16Regs:$dst, (add
(mul Int16Regs:$a, imm:$b) , Int16Regs:$c))]>;
[(set Int16Regs:$dst,
(imad Int16Regs:$a, imm:$b, Int16Regs:$c))]>;
def MAD16rii : NVPTXInst<(outs Int16Regs:$dst),
(ins Int16Regs:$a, i16imm:$b, i16imm:$c),
"mad.lo.s16 \t$dst, $a, $b, $c;",
[(set Int16Regs:$dst, (add (mul Int16Regs:$a, imm:$b),
imm:$c))]>;
[(set Int16Regs:$dst,
(imad Int16Regs:$a, imm:$b, imm:$c))]>;
// 32-bit multiply-add instructions: every def prints "mad.lo.s32" with an
// Int32Regs destination and three sources $a, $b, $c. The name suffix gives
// the kind of each source in order (r = Int32Regs, i = i32imm).
// NOTE(review): each def below carries two pattern bodies, one built from
// (add (mul ...)) and one from the `imad` node, which looks like interleaved
// pre-/post-change diff lines - confirm which pattern is live.
def MAD32rrr : NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
"mad.lo.s32 \t$dst, $a, $b, $c;",
[(set Int32Regs:$dst, (add
(mul Int32Regs:$a, Int32Regs:$b) , Int32Regs:$c))]>;
[(set Int32Regs:$dst,
(imad Int32Regs:$a, Int32Regs:$b, Int32Regs:$c))]>;
def MAD32rri : NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b, i32imm:$c),
"mad.lo.s32 \t$dst, $a, $b, $c;",
[(set Int32Regs:$dst, (add
(mul Int32Regs:$a, Int32Regs:$b) , imm:$c))]>;
[(set Int32Regs:$dst,
(imad Int32Regs:$a, Int32Regs:$b, imm:$c))]>;
def MAD32rir : NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
"mad.lo.s32 \t$dst, $a, $b, $c;",
[(set Int32Regs:$dst, (add
(mul Int32Regs:$a, imm:$b) , Int32Regs:$c))]>;
[(set Int32Regs:$dst,
(imad Int32Regs:$a, imm:$b, Int32Regs:$c))]>;
def MAD32rii : NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, i32imm:$b, i32imm:$c),
"mad.lo.s32 \t$dst, $a, $b, $c;",
[(set Int32Regs:$dst, (add
(mul Int32Regs:$a, imm:$b) , imm:$c))]>;
[(set Int32Regs:$dst,
(imad Int32Regs:$a, imm:$b, imm:$c))]>;
// 64-bit multiply-add instructions: every def prints "mad.lo.s64" with an
// Int64Regs destination and three sources $a, $b, $c. The name suffix gives
// the kind of each source in order (r = Int64Regs, i = i64imm).
// NOTE(review): each def below carries two pattern bodies, one built from
// (add (mul ...)) and one from the `imad` node, which looks like interleaved
// pre-/post-change diff lines - confirm which pattern is live.
def MAD64rrr : NVPTXInst<(outs Int64Regs:$dst),
(ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
"mad.lo.s64 \t$dst, $a, $b, $c;",
[(set Int64Regs:$dst, (add
(mul Int64Regs:$a, Int64Regs:$b) , Int64Regs:$c))]>;
[(set Int64Regs:$dst,
(imad Int64Regs:$a, Int64Regs:$b, Int64Regs:$c))]>;
def MAD64rri : NVPTXInst<(outs Int64Regs:$dst),
(ins Int64Regs:$a, Int64Regs:$b, i64imm:$c),
"mad.lo.s64 \t$dst, $a, $b, $c;",
[(set Int64Regs:$dst, (add
(mul Int64Regs:$a, Int64Regs:$b) , imm:$c))]>;
[(set Int64Regs:$dst,
(imad Int64Regs:$a, Int64Regs:$b, imm:$c))]>;
def MAD64rir : NVPTXInst<(outs Int64Regs:$dst),
(ins Int64Regs:$a, i64imm:$b, Int64Regs:$c),
"mad.lo.s64 \t$dst, $a, $b, $c;",
[(set Int64Regs:$dst, (add
(mul Int64Regs:$a, imm:$b) , Int64Regs:$c))]>;
[(set Int64Regs:$dst,
(imad Int64Regs:$a, imm:$b, Int64Regs:$c))]>;
def MAD64rii : NVPTXInst<(outs Int64Regs:$dst),
(ins Int64Regs:$a, i64imm:$b, i64imm:$c),
"mad.lo.s64 \t$dst, $a, $b, $c;",
[(set Int64Regs:$dst, (add
(mul Int64Regs:$a, imm:$b), imm:$c))]>;
[(set Int64Regs:$dst,
(imad Int64Regs:$a, imm:$b, imm:$c))]>;
def INEG16 : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
"neg.s16 \t$dst, $src;",
Expand Down
Expand Up
@@ -812,110 +868,59 @@ multiclass FPCONTRACT32<string OpcStr, Predicate Pred> {
// Members of the 32-bit FP multiply-add multiclass. Each def prints
// OpcStr followed by " \t$dst, $a, $b, $c;", writes a Float32Regs destination
// and is gated on Pred. The suffix gives the kind of each source in order
// (r = Float32Regs, i = f32imm matched with fpimm).
// NOTE(review): both (fadd (fmul ...)) and (fma ...) pattern bodies appear
// below, and rrr2 ends with two pattern lists, which looks like interleaved
// pre-/post-change diff lines - confirm which lines are live.
def rrr : NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
[(set Float32Regs:$dst, (fadd
(fmul Float32Regs:$a, Float32Regs:$b),
Float32Regs:$c))]>, Requires<[Pred]>;
// This works around a TableGen bug: the permuted rule rrr2 below is not
// generated automatically from the rrr rule above, so it is added
// explicitly here. This happens to FMA32 only.
// See the comments at FMAD32 and FMA32 for more information.
def rrr2 : NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
[(set Float32Regs:$dst, (fadd Float32Regs:$c,
(fmul Float32Regs:$a, Float32Regs:$b)))]>,
[(set Float32Regs:$dst,
(fma Float32Regs:$a, Float32Regs:$b, Float32Regs:$c))]>,
Requires<[Pred]>;
// Immediate addend.
def rri : NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
[(set Float32Regs:$dst, (fadd
(fmul Float32Regs:$a, Float32Regs:$b) , fpimm:$c))]>,
[(set Float32Regs:$dst,
(fma Float32Regs:$a, Float32Regs:$b, fpimm:$c))]>,
Requires<[Pred]>;
// Immediate multiplier.
def rir : NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, f32imm:$b, Float32Regs:$c),
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
[(set Float32Regs:$dst, (fadd
(fmul Float32Regs:$a, fpimm:$b) , Float32Regs:$c))]>,
[(set Float32Regs:$dst,
(fma Float32Regs:$a, fpimm:$b, Float32Regs:$c))]>,
Requires<[Pred]>;
// Immediate multiplier and immediate addend.
def rii : NVPTXInst<(outs Float32Regs:$dst),
(ins Float32Regs:$a, f32imm:$b, f32imm:$c),
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
[(set Float32Regs:$dst, (fadd
(fmul Float32Regs:$a, fpimm:$b) , fpimm:$c))]>,
[(set Float32Regs:$dst,
(fma Float32Regs:$a, fpimm:$b, fpimm:$c))]>,
Requires<[Pred]>;
}
// 64-bit FP multiply-add multiclass. Each member prints OpcStr followed by
// " \t$dst, $a, $b, $c;", writes a Float64Regs destination and is gated on
// Pred. The suffix gives the kind of each source in order
// (r = Float64Regs, i = f64imm matched with fpimm).
//   OpcStr - mnemonic string placed in front of the operand list.
//   Pred   - predicate that every member requires.
// NOTE(review): both (fadd (fmul ...)) and (fma ...) pattern bodies appear in
// each member, which looks like interleaved pre-/post-change diff lines -
// confirm which pattern is live.
multiclass FPCONTRACT64<string OpcStr, Predicate Pred> {
def rrr : NVPTXInst<(outs Float64Regs:$dst),
(ins Float64Regs:$a, Float64Regs:$b, Float64Regs:$c),
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
[(set Float64Regs:$dst, (fadd
(fmul Float64Regs:$a, Float64Regs:$b) ,
Float64Regs:$c))]>, Requires<[Pred]>;
[(set Float64Regs:$dst,
(fma Float64Regs:$a, Float64Regs:$b, Float64Regs:$c))]> ,
Requires<[Pred]>;
// Immediate addend.
def rri : NVPTXInst<(outs Float64Regs:$dst),
(ins Float64Regs:$a, Float64Regs:$b, f64imm:$c),
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
[(set Float64Regs:$dst, (fadd (fmul Float64Regs:$a,
Float64Regs:$b), fpimm:$c))]>, Requires<[Pred]>;
[(set Float64Regs:$dst,
(fma Float64Regs:$a, Float64Regs:$b, fpimm:$c))]>,
Requires<[Pred]>;
// Immediate multiplier.
def rir : NVPTXInst<(outs Float64Regs:$dst),
(ins Float64Regs:$a, f64imm:$b, Float64Regs:$c),
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
[(set Float64Regs:$dst, (fadd
(fmul Float64Regs:$a, fpimm:$b) , Float64Regs:$c))]>,
[(set Float64Regs:$dst,
(fma Float64Regs:$a, fpimm:$b, Float64Regs:$c))]>,
Requires<[Pred]>;
// Immediate multiplier and immediate addend.
def rii : NVPTXInst<(outs Float64Regs:$dst),
(ins Float64Regs:$a, f64imm:$b, f64imm:$c),
!strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
[(set Float64Regs:$dst, (fadd
(fmul Float64Regs:$a, fpimm:$b) , fpimm:$c))]>,
[(set Float64Regs:$dst,
(fma Float64Regs:$a, fpimm:$b, fpimm:$c))]>,
Requires<[Pred]>;
}
// Due to an unknown reason (most likely a bug in tablegen), tablegen does not
// automatically generate the rrr2 rule from
// the rrr rule (see FPCONTRACT32) for FMA32, though it does for FMAD32.
// If we reverse the order of the following two lines, then rrr2 rule will be
// generated for FMA32, but not for rrr.
// Therefore, we manually write the rrr2 rule in FPCONTRACT32.
// Instantiate the FP multiply-add families from FPCONTRACT32 / FPCONTRACT64,
// each gated on its own doFMAF* predicate (predicates are defined elsewhere).
// NOTE(review): FMA32_ftz, FMA32 and FMA64 are instantiated a second time
// further down with different predicates - confirm which set is live.
defm FMA32_ftz : FPCONTRACT32<"fma.rn.ftz.f32", doFMAF32_ftz>;
defm FMA32 : FPCONTRACT32<"fma.rn.f32", doFMAF32>;
defm FMA64 : FPCONTRACT64<"fma.rn.f64", doFMAF64>;
// Rewrites b*c - a as Inst(b, c, -a): the subtrahend is negated with FNEGf32
// and passed as the third operand.
//   Inst - instruction that receives the three operands.
//   Pred - predicate the pattern requires.
multiclass FPCONTRACT32_SUB_PAT_MAD<NVPTXInst Inst, Predicate Pred> {
  def : Pat<(fsub (fmul Float32Regs:$b, Float32Regs:$c), Float32Regs:$a),
            (Inst Float32Regs:$b, Float32Regs:$c, (FNEGf32 Float32Regs:$a))>,
        Requires<[Pred]>;
}
// Folds an f32 subtraction that involves a product into a single Inst:
//   a - b*c  =>  Inst(-b, c, a)   since a - b*c == a + (-b)*c
//   b*c - a  =>  Inst(b, c, -a)   since b*c - a == b*c + (-a)
// Negation is done with FNEGf32.
//   Inst - instruction that receives the three operands.
//   Pred - predicate both patterns require.
multiclass FPCONTRACT32_SUB_PAT<NVPTXInst Inst, Predicate Pred> {
  def : Pat<(fsub Float32Regs:$a, (fmul Float32Regs:$b, Float32Regs:$c)),
            (Inst (FNEGf32 Float32Regs:$b), Float32Regs:$c, Float32Regs:$a)>,
        Requires<[Pred]>;

  def : Pat<(fsub (fmul Float32Regs:$b, Float32Regs:$c), Float32Regs:$a),
            (Inst Float32Regs:$b, Float32Regs:$c, (FNEGf32 Float32Regs:$a))>,
        Requires<[Pred]>;
}
// Folds an f64 subtraction that involves a product into a single Inst:
//   a - b*c  =>  Inst(-b, c, a)
//   b*c - a  =>  Inst(b, c, -a)
// Negation is done with FNEGf64.
//   Inst - instruction that receives the three operands.
//   Pred - predicate both patterns require.
multiclass FPCONTRACT64_SUB_PAT<NVPTXInst Inst, Predicate Pred> {
  def : Pat<(fsub Float64Regs:$a, (fmul Float64Regs:$b, Float64Regs:$c)),
            (Inst (FNEGf64 Float64Regs:$b), Float64Regs:$c, Float64Regs:$a)>,
        Requires<[Pred]>;

  def : Pat<(fsub (fmul Float64Regs:$b, Float64Regs:$c), Float64Regs:$a),
            (Inst Float64Regs:$b, Float64Regs:$c, (FNEGf64 Float64Regs:$a))>,
        Requires<[Pred]>;
}
// Attach the fsub-folding patterns to the register-only (rrr) FMA
// instructions, each gated on the matching doFMAF*AGG predicate (predicates
// are defined elsewhere).
defm FMAF32ext_ftz : FPCONTRACT32_SUB_PAT<FMA32_ftzrrr, doFMAF32AGG_ftz>;
defm FMAF32ext : FPCONTRACT32_SUB_PAT<FMA32rrr, doFMAF32AGG>;
defm FMAF64ext : FPCONTRACT64_SUB_PAT<FMA64rrr, doFMAF64AGG>;
// Instantiate the FP multiply-add families from FPCONTRACT32 / FPCONTRACT64.
// The ftz variant is gated on doF32FTZ; the other two are gated on doNoF32FTZ.
// NOTE(review): the f64 family is also gated on doNoF32FTZ, an f32
// flush-to-zero predicate by name - confirm this is intentional.
defm FMA32_ftz : FPCONTRACT32<"fma.rn.ftz.f32", doF32FTZ>;
defm FMA32 : FPCONTRACT32<"fma.rn.f32", doNoF32FTZ>;
defm FMA64 : FPCONTRACT64<"fma.rn.f64", doNoF32FTZ>;
def SINF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
"sin.approx.f32 \t$dst, $src;",
Expand Down