Skip to content

Commit

Permalink
[LegalizeTypes][VP] Add widening and splitting support for VP_FMA.
Browse files Browse the repository at this point in the history
Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D120854
  • Loading branch information
topperc committed Mar 8, 2022
1 parent c392b99 commit 29511ec
Show file tree
Hide file tree
Showing 3 changed files with 470 additions and 11 deletions.
41 changes: 34 additions & 7 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Expand Up @@ -1060,7 +1060,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::ROTR:
SplitVecRes_BinOp(N, Lo, Hi);
break;
case ISD::FMA:
case ISD::FMA: case ISD::VP_FMA:
case ISD::FSHL:
case ISD::FSHR:
SplitVecRes_TernaryOp(N, Lo, Hi);
Expand Down Expand Up @@ -1182,10 +1182,28 @@ void DAGTypeLegalizer::SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo,
GetSplitVector(N->getOperand(2), Op2Lo, Op2Hi);
SDLoc dl(N);

Lo = DAG.getNode(N->getOpcode(), dl, Op0Lo.getValueType(), Op0Lo, Op1Lo,
Op2Lo, N->getFlags());
Hi = DAG.getNode(N->getOpcode(), dl, Op0Hi.getValueType(), Op0Hi, Op1Hi,
Op2Hi, N->getFlags());
const SDNodeFlags Flags = N->getFlags();
unsigned Opcode = N->getOpcode();
if (N->getNumOperands() == 3) {
Lo = DAG.getNode(Opcode, dl, Op0Lo.getValueType(), Op0Lo, Op1Lo, Op2Lo, Flags);
Hi = DAG.getNode(Opcode, dl, Op0Hi.getValueType(), Op0Hi, Op1Hi, Op2Hi, Flags);
return;
}

assert(N->getNumOperands() == 5 && "Unexpected number of operands!");
assert(N->isVPOpcode() && "Expected VP opcode");

SDValue MaskLo, MaskHi;
std::tie(MaskLo, MaskHi) = SplitMask(N->getOperand(3));

SDValue EVLLo, EVLHi;
std::tie(EVLLo, EVLHi) =
DAG.SplitEVL(N->getOperand(4), N->getValueType(0), dl);

Lo = DAG.getNode(Opcode, dl, Op0Lo.getValueType(),
{Op0Lo, Op1Lo, Op2Lo, MaskLo, EVLLo}, Flags);
Hi = DAG.getNode(Opcode, dl, Op0Hi.getValueType(),
{Op0Hi, Op1Hi, Op2Hi, MaskHi, EVLHi}, Flags);
}

void DAGTypeLegalizer::SplitVecRes_FIX(SDNode *N, SDValue &Lo, SDValue &Hi) {
Expand Down Expand Up @@ -3441,7 +3459,7 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
case ISD::FCANONICALIZE:
Res = WidenVecRes_Unary(N);
break;
case ISD::FMA:
case ISD::FMA: case ISD::VP_FMA:
case ISD::FSHL:
case ISD::FSHR:
Res = WidenVecRes_Ternary(N);
Expand All @@ -3460,7 +3478,16 @@ SDValue DAGTypeLegalizer::WidenVecRes_Ternary(SDNode *N) {
SDValue InOp1 = GetWidenedVector(N->getOperand(0));
SDValue InOp2 = GetWidenedVector(N->getOperand(1));
SDValue InOp3 = GetWidenedVector(N->getOperand(2));
return DAG.getNode(N->getOpcode(), dl, WidenVT, InOp1, InOp2, InOp3);
if (N->getNumOperands() == 3)
return DAG.getNode(N->getOpcode(), dl, WidenVT, InOp1, InOp2, InOp3);

assert(N->getNumOperands() == 5 && "Unexpected number of operands!");
assert(N->isVPOpcode() && "Expected VP opcode");

SDValue Mask =
GetWidenedMask(N->getOperand(3), WidenVT.getVectorElementCount());
return DAG.getNode(N->getOpcode(), dl, WidenVT,
{InOp1, InOp2, InOp3, Mask, N->getOperand(4)});
}

SDValue DAGTypeLegalizer::WidenVecRes_Binary(SDNode *N) {
Expand Down
221 changes: 219 additions & 2 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfma-vp.ll
@@ -1,7 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+v -target-abi=ilp32d -riscv-v-vector-bits-min=128 \
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+v,+m -target-abi=ilp32d -riscv-v-vector-bits-min=128 \
; RUN: -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+v -target-abi=lp64d -riscv-v-vector-bits-min=128 \
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+v,+m -target-abi=lp64d -riscv-v-vector-bits-min=128 \
; RUN: -verify-machineinstrs < %s | FileCheck %s

declare <2 x half> @llvm.vp.fma.v2f16(<2 x half>, <2 x half>, <2 x half>, <2 x i1>, i32)
Expand Down Expand Up @@ -565,6 +565,35 @@ define <8 x double> @vfma_vf_v8f64_unmasked(<8 x double> %va, double %b, <8 x do
ret <8 x double> %v
}

declare <15 x double> @llvm.vp.fma.v15f64(<15 x double>, <15 x double>, <15 x double>, <15 x i1>, i32)

define <15 x double> @vfma_vv_v15f64(<15 x double> %va, <15 x double> %b, <15 x double> %c, <15 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vfma_vv_v15f64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, mu
; CHECK-NEXT: vle64.v v24, (a0)
; CHECK-NEXT: vsetvli zero, a1, e64, m8, tu, mu
; CHECK-NEXT: vfmadd.vv v16, v8, v24, v0.t
; CHECK-NEXT: vmv8r.v v8, v16
; CHECK-NEXT: ret
%v = call <15 x double> @llvm.vp.fma.v15f64(<15 x double> %va, <15 x double> %b, <15 x double> %c, <15 x i1> %m, i32 %evl)
ret <15 x double> %v
}

define <15 x double> @vfma_vv_v15f64_unmasked(<15 x double> %va, <15 x double> %b, <15 x double> %c, i32 zeroext %evl) {
; CHECK-LABEL: vfma_vv_v15f64_unmasked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, mu
; CHECK-NEXT: vle64.v v24, (a0)
; CHECK-NEXT: vsetvli zero, a1, e64, m8, ta, mu
; CHECK-NEXT: vfmadd.vv v8, v16, v24
; CHECK-NEXT: ret
%head = insertelement <15 x i1> poison, i1 true, i32 0
%m = shufflevector <15 x i1> %head, <15 x i1> poison, <15 x i32> zeroinitializer
%v = call <15 x double> @llvm.vp.fma.v15f64(<15 x double> %va, <15 x double> %b, <15 x double> %c, <15 x i1> %m, i32 %evl)
ret <15 x double> %v
}

declare <16 x double> @llvm.vp.fma.v16f64(<16 x double>, <16 x double>, <16 x double>, <16 x i1>, i32)

define <16 x double> @vfma_vv_v16f64(<16 x double> %va, <16 x double> %b, <16 x double> %c, <16 x i1> %m, i32 zeroext %evl) {
Expand Down Expand Up @@ -619,3 +648,191 @@ define <16 x double> @vfma_vf_v16f64_unmasked(<16 x double> %va, double %b, <16
%v = call <16 x double> @llvm.vp.fma.v16f64(<16 x double> %va, <16 x double> %vb, <16 x double> %vc, <16 x i1> %m, i32 %evl)
ret <16 x double> %v
}

declare <32 x double> @llvm.vp.fma.v32f64(<32 x double>, <32 x double>, <32 x double>, <32 x i1>, i32)

define <32 x double> @vfma_vv_v32f64(<32 x double> %va, <32 x double> %b, <32 x double> %c, <32 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vfma_vv_v32f64:
; CHECK: # %bb.0:
; CHECK-NEXT: addi sp, sp, -16
; CHECK-NEXT: .cfi_def_cfa_offset 16
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: li a3, 48
; CHECK-NEXT: mul a1, a1, a3
; CHECK-NEXT: sub sp, sp, a1
; CHECK-NEXT: vmv1r.v v1, v0
; CHECK-NEXT: vsetivli zero, 2, e8, mf4, ta, mu
; CHECK-NEXT: vslidedown.vi v0, v0, 2
; CHECK-NEXT: addi a1, a2, 128
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, mu
; CHECK-NEXT: vle64.v v24, (a1)
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: li a3, 24
; CHECK-NEXT: mul a1, a1, a3
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vs8r.v v24, (a1) # Unknown-size Folded Spill
; CHECK-NEXT: addi a1, a0, 128
; CHECK-NEXT: vle64.v v24, (a1)
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: li a3, 40
; CHECK-NEXT: mul a1, a1, a3
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vs8r.v v24, (a1) # Unknown-size Folded Spill
; CHECK-NEXT: addi a3, a4, -16
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: slli a1, a1, 3
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vs8r.v v16, (a1) # Unknown-size Folded Spill
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: slli a1, a1, 5
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vs8r.v v8, (a1) # Unknown-size Folded Spill
; CHECK-NEXT: li a1, 0
; CHECK-NEXT: bltu a4, a3, .LBB50_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a1, a3
; CHECK-NEXT: .LBB50_2:
; CHECK-NEXT: vle64.v v8, (a2)
; CHECK-NEXT: csrr a2, vlenb
; CHECK-NEXT: slli a2, a2, 4
; CHECK-NEXT: add a2, sp, a2
; CHECK-NEXT: addi a2, a2, 16
; CHECK-NEXT: vs8r.v v8, (a2) # Unknown-size Folded Spill
; CHECK-NEXT: vle64.v v8, (a0)
; CHECK-NEXT: addi a0, sp, 16
; CHECK-NEXT: vs8r.v v8, (a0) # Unknown-size Folded Spill
; CHECK-NEXT: vsetvli zero, a1, e64, m8, tu, mu
; CHECK-NEXT: li a0, 16
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: li a2, 24
; CHECK-NEXT: mul a1, a1, a2
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vl8re8.v v8, (a1) # Unknown-size Folded Reload
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: slli a1, a1, 3
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vl8re8.v v24, (a1) # Unknown-size Folded Reload
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: li a2, 40
; CHECK-NEXT: mul a1, a1, a2
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vl8re8.v v16, (a1) # Unknown-size Folded Reload
; CHECK-NEXT: vfmadd.vv v16, v24, v8, v0.t
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: li a2, 40
; CHECK-NEXT: mul a1, a1, a2
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vs8r.v v16, (a1) # Unknown-size Folded Spill
; CHECK-NEXT: bltu a4, a0, .LBB50_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: li a4, 16
; CHECK-NEXT: .LBB50_4:
; CHECK-NEXT: vsetvli zero, a4, e64, m8, tu, mu
; CHECK-NEXT: vmv1r.v v0, v1
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: slli a0, a0, 5
; CHECK-NEXT: add a0, sp, a0
; CHECK-NEXT: addi a0, a0, 16
; CHECK-NEXT: vl8re8.v v8, (a0) # Unknown-size Folded Reload
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: slli a0, a0, 4
; CHECK-NEXT: add a0, sp, a0
; CHECK-NEXT: addi a0, a0, 16
; CHECK-NEXT: vl8re8.v v24, (a0) # Unknown-size Folded Reload
; CHECK-NEXT: addi a0, sp, 16
; CHECK-NEXT: vl8re8.v v16, (a0) # Unknown-size Folded Reload
; CHECK-NEXT: vfmadd.vv v16, v8, v24, v0.t
; CHECK-NEXT: vmv8r.v v8, v16
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: li a1, 40
; CHECK-NEXT: mul a0, a0, a1
; CHECK-NEXT: add a0, sp, a0
; CHECK-NEXT: addi a0, a0, 16
; CHECK-NEXT: vl8re8.v v16, (a0) # Unknown-size Folded Reload
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: li a1, 48
; CHECK-NEXT: mul a0, a0, a1
; CHECK-NEXT: add sp, sp, a0
; CHECK-NEXT: addi sp, sp, 16
; CHECK-NEXT: ret
%v = call <32 x double> @llvm.vp.fma.v32f64(<32 x double> %va, <32 x double> %b, <32 x double> %c, <32 x i1> %m, i32 %evl)
ret <32 x double> %v
}

define <32 x double> @vfma_vv_v32f64_unmasked(<32 x double> %va, <32 x double> %b, <32 x double> %c, i32 zeroext %evl) {
; CHECK-LABEL: vfma_vv_v32f64_unmasked:
; CHECK: # %bb.0:
; CHECK-NEXT: addi sp, sp, -16
; CHECK-NEXT: .cfi_def_cfa_offset 16
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: li a3, 24
; CHECK-NEXT: mul a1, a1, a3
; CHECK-NEXT: sub sp, sp, a1
; CHECK-NEXT: addi a1, a2, 128
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, mu
; CHECK-NEXT: vle64.v v24, (a1)
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: slli a1, a1, 3
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vs8r.v v24, (a1) # Unknown-size Folded Spill
; CHECK-NEXT: addi a1, a0, 128
; CHECK-NEXT: vle64.v v24, (a1)
; CHECK-NEXT: addi a3, a4, -16
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: slli a1, a1, 4
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vs8r.v v8, (a1) # Unknown-size Folded Spill
; CHECK-NEXT: li a1, 0
; CHECK-NEXT: bltu a4, a3, .LBB51_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a1, a3
; CHECK-NEXT: .LBB51_2:
; CHECK-NEXT: vle64.v v8, (a2)
; CHECK-NEXT: addi a2, sp, 16
; CHECK-NEXT: vs8r.v v8, (a2) # Unknown-size Folded Spill
; CHECK-NEXT: vle64.v v0, (a0)
; CHECK-NEXT: vsetvli zero, a1, e64, m8, ta, mu
; CHECK-NEXT: li a0, 16
; CHECK-NEXT: csrr a1, vlenb
; CHECK-NEXT: slli a1, a1, 3
; CHECK-NEXT: add a1, sp, a1
; CHECK-NEXT: addi a1, a1, 16
; CHECK-NEXT: vl8re8.v v8, (a1) # Unknown-size Folded Reload
; CHECK-NEXT: vfmadd.vv v24, v16, v8
; CHECK-NEXT: bltu a4, a0, .LBB51_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: li a4, 16
; CHECK-NEXT: .LBB51_4:
; CHECK-NEXT: vsetvli zero, a4, e64, m8, ta, mu
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: slli a0, a0, 4
; CHECK-NEXT: add a0, sp, a0
; CHECK-NEXT: addi a0, a0, 16
; CHECK-NEXT: vl8re8.v v16, (a0) # Unknown-size Folded Reload
; CHECK-NEXT: addi a0, sp, 16
; CHECK-NEXT: vl8re8.v v8, (a0) # Unknown-size Folded Reload
; CHECK-NEXT: vfmadd.vv v0, v16, v8
; CHECK-NEXT: vmv.v.v v8, v0
; CHECK-NEXT: vmv8r.v v16, v24
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: li a1, 24
; CHECK-NEXT: mul a0, a0, a1
; CHECK-NEXT: add sp, sp, a0
; CHECK-NEXT: addi sp, sp, 16
; CHECK-NEXT: ret
%head = insertelement <32 x i1> poison, i1 true, i32 0
%m = shufflevector <32 x i1> %head, <32 x i1> poison, <32 x i32> zeroinitializer
%v = call <32 x double> @llvm.vp.fma.v32f64(<32 x double> %va, <32 x double> %b, <32 x double> %c, <32 x i1> %m, i32 %evl)
ret <32 x double> %v
}

0 comments on commit 29511ec

Please sign in to comment.