[RISCV] Add rvv codegen support for vp.fpext.
This patch adds RVV codegen support for vp.fpext. The lowerings of fp_round, vp.fptrunc, fp_extend, and vp.fpext share most of their code, so a common lowering function now handles all four. For scalable vectors, the patch also changes the intermediate cast from ISD::FP_EXTEND/ISD::FP_ROUND to the RVV VL-version ops RISCVISD::FP_EXTEND_VL and RISCVISD::FP_ROUND_VL.
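
In IR terms, the operation this patch enables looks like the minimal sketch below (mirroring the new tests; the function name is illustrative). RVV has no single f16-to-f64 widening conversion, so this case is lowered as two vfwcvt.f.f.v hops through f32:

; Masked f16->f64 extension, now custom-lowered via RISCVISD::FP_EXTEND_VL.
declare <2 x double> @llvm.vp.fpext.v2f64.v2f16(<2 x half>, <2 x i1>, i32)

define <2 x double> @example(<2 x half> %a, <2 x i1> %m, i32 zeroext %vl) {
  %v = call <2 x double> @llvm.vp.fpext.v2f64.v2f16(<2 x half> %a, <2 x i1> %m, i32 %vl)
  ret <2 x double> %v
}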

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D123975
jacquesguan authored and committed May 11, 2022
1 parent d4609ae commit 2509dcd
Showing 4 changed files with 190 additions and 75 deletions.
109 changes: 35 additions & 74 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -443,7 +443,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX,
ISD::VP_MERGE, ISD::VP_SELECT,
ISD::VP_SITOFP, ISD::VP_UITOFP,
ISD::VP_SETCC, ISD::VP_FP_ROUND};
ISD::VP_SETCC, ISD::VP_FP_ROUND,
ISD::VP_FP_EXTEND};

if (!Subtarget.is64Bit()) {
// We must custom-lower certain vXi64 operations on RV32 due to the vector
@@ -2795,21 +2796,6 @@ bool RISCVTargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
isInterleaveShuffle(M, SVT, SwapSources, Subtarget);
}

static SDValue getRVVFPExtendOrRound(SDValue Op, MVT VT, MVT ContainerVT,
SDLoc DL, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
if (VT.isScalableVector())
return DAG.getFPExtendOrRound(Op, DL, VT);
assert(VT.isFixedLengthVector() &&
"Unexpected value type for RVV FP extend/round lowering");
SDValue Mask, VL;
std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
unsigned RVVOpc = ContainerVT.bitsGT(Op.getSimpleValueType())
? RISCVISD::FP_EXTEND_VL
: RISCVISD::FP_ROUND_VL;
return DAG.getNode(RVVOpc, DL, ContainerVT, Op, Mask, VL);
}

// Lower CTLZ_ZERO_UNDEF or CTTZ_ZERO_UNDEF by converting to FP and extracting
// the exponent.
static SDValue lowerCTLZ_CTTZ_ZERO_UNDEF(SDValue Op, SelectionDAG &DAG) {
@@ -3126,50 +3112,11 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
}
return SDValue();
}
case ISD::FP_EXTEND: {
// RVV can only do fp_extend to types double the size as the source. We
// custom-lower f16->f64 extensions to two hops of ISD::FP_EXTEND, going
// via f32.
SDLoc DL(Op);
MVT VT = Op.getSimpleValueType();
SDValue Src = Op.getOperand(0);
MVT SrcVT = Src.getSimpleValueType();

// Prepare any fixed-length vector operands.
MVT ContainerVT = VT;
if (SrcVT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(VT);
MVT SrcContainerVT =
ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
}

if (!VT.isVector() || VT.getVectorElementType() != MVT::f64 ||
SrcVT.getVectorElementType() != MVT::f16) {
// For scalable vectors, we only need to close the gap between
// vXf16->vXf64.
if (!VT.isFixedLengthVector())
return Op;
// For fixed-length vectors, lower the FP_EXTEND to a custom "VL" version.
Src = getRVVFPExtendOrRound(Src, VT, ContainerVT, DL, DAG, Subtarget);
return convertFromScalableVector(VT, Src, DAG, Subtarget);
}

MVT InterVT = VT.changeVectorElementType(MVT::f32);
MVT InterContainerVT = ContainerVT.changeVectorElementType(MVT::f32);
SDValue IntermediateExtend = getRVVFPExtendOrRound(
Src, InterVT, InterContainerVT, DL, DAG, Subtarget);

SDValue Extend = getRVVFPExtendOrRound(IntermediateExtend, VT, ContainerVT,
DL, DAG, Subtarget);
if (VT.isFixedLengthVector())
return convertFromScalableVector(VT, Extend, DAG, Subtarget);
return Extend;
}
case ISD::FP_EXTEND:
case ISD::FP_ROUND:
if (!Op.getValueType().isVector())
return Op;
return lowerVectorFPRoundLike(Op, DAG);
return lowerVectorFPExtendOrRoundLike(Op, DAG);
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
case ISD::SINT_TO_FP:
@@ -3512,8 +3459,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
: RISCVISD::VZEXT_VL);
case ISD::VP_TRUNC:
return lowerVectorTruncLike(Op, DAG);
case ISD::VP_FP_EXTEND:
case ISD::VP_FP_ROUND:
return lowerVectorFPRoundLike(Op, DAG);
return lowerVectorFPExtendOrRoundLike(Op, DAG);
case ISD::VP_FPTOSI:
return lowerVPFPIntConvOp(Op, DAG, RISCVISD::FP_TO_SINT_VL);
case ISD::VP_FPTOUI:
@@ -4281,9 +4229,13 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
return Result;
}

SDValue RISCVTargetLowering::lowerVectorFPRoundLike(SDValue Op,
SDValue
RISCVTargetLowering::lowerVectorFPExtendOrRoundLike(SDValue Op,
SelectionDAG &DAG) const {
bool IsVPFPTrunc = Op.getOpcode() == ISD::VP_FP_ROUND;
bool IsVP =
Op.getOpcode() == ISD::VP_FP_ROUND || Op.getOpcode() == ISD::VP_FP_EXTEND;
bool IsExtend =
Op.getOpcode() == ISD::VP_FP_EXTEND || Op.getOpcode() == ISD::FP_EXTEND;
// RVV can only do truncate fp to types half the size as the source. We
// custom-lower f64->f16 rounds via RVV's round-to-odd float
// conversion instruction.
@@ -4295,17 +4247,21 @@ SDValue RISCVTargetLowering::lowerVectorFPRoundLike(SDValue Op,
SDValue Src = Op.getOperand(0);
MVT SrcVT = Src.getSimpleValueType();

bool IsDirectConv = VT.getVectorElementType() != MVT::f16 ||
SrcVT.getVectorElementType() != MVT::f64;
bool IsDirectExtend = IsExtend && (VT.getVectorElementType() != MVT::f64 ||
SrcVT.getVectorElementType() != MVT::f16);
bool IsDirectTrunc = !IsExtend && (VT.getVectorElementType() != MVT::f16 ||
SrcVT.getVectorElementType() != MVT::f64);

// For FP_ROUND of scalable vectors, leave it to the pattern.
if (!VT.isFixedLengthVector() && !IsVPFPTrunc && IsDirectConv)
bool IsDirectConv = IsDirectExtend || IsDirectTrunc;

// For FP_ROUND/FP_EXTEND of scalable vectors, leave it to the pattern.
if (!VT.isFixedLengthVector() && !IsVP && IsDirectConv)
return Op;

// Prepare any fixed-length vector operands.
MVT ContainerVT = VT;
SDValue Mask, VL;
if (IsVPFPTrunc) {
if (IsVP) {
Mask = Op.getOperand(1);
VL = Op.getOperand(2);
}
@@ -4314,31 +4270,36 @@ SDValue RISCVTargetLowering::lowerVectorFPRoundLike(SDValue Op,
ContainerVT =
SrcContainerVT.changeVectorElementType(VT.getVectorElementType());
Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
if (IsVPFPTrunc) {
if (IsVP) {
MVT MaskVT = getMaskTypeFor(ContainerVT);
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
}
}

if (!IsVPFPTrunc)
if (!IsVP)
std::tie(Mask, VL) =
getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget);

unsigned ConvOpc = IsExtend ? RISCVISD::FP_EXTEND_VL : RISCVISD::FP_ROUND_VL;

if (IsDirectConv) {
Src = DAG.getNode(RISCVISD::FP_ROUND_VL, DL, ContainerVT, Src, Mask, VL);
Src = DAG.getNode(ConvOpc, DL, ContainerVT, Src, Mask, VL);
if (VT.isFixedLengthVector())
Src = convertFromScalableVector(VT, Src, DAG, Subtarget);
return Src;
}

unsigned InterConvOpc =
IsExtend ? RISCVISD::FP_EXTEND_VL : RISCVISD::VFNCVT_ROD_VL;

MVT InterVT = ContainerVT.changeVectorElementType(MVT::f32);
SDValue IntermediateRound =
DAG.getNode(RISCVISD::VFNCVT_ROD_VL, DL, InterVT, Src, Mask, VL);
SDValue Round = DAG.getNode(RISCVISD::FP_ROUND_VL, DL, ContainerVT,
IntermediateRound, Mask, VL);
SDValue IntermediateConv =
DAG.getNode(InterConvOpc, DL, InterVT, Src, Mask, VL);
SDValue Result =
DAG.getNode(ConvOpc, DL, ContainerVT, IntermediateConv, Mask, VL);
if (VT.isFixedLengthVector())
return convertFromScalableVector(VT, Round, DAG, Subtarget);
return Round;
return convertFromScalableVector(VT, Result, DAG, Subtarget);
return Result;
}

// Custom-legalize INSERT_VECTOR_ELT so that the value is inserted into the
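
The shared helper also keeps the narrowing direction it was refactored from: per the code above, an f64-to-f16 vp.fptrunc takes the mirrored two-hop path, with the first hop using RVV's round-to-odd narrowing conversion (RISCVISD::VFNCVT_ROD_VL, i.e. vfncvt.rod.f.f.w) so the result is rounded only once. A hedged sketch of that input, not part of this commit's tests (the intrinsic mangling follows the vp.fpext declarations below; the function name is illustrative):

; Lowered as VFNCVT_ROD_VL (f64->f32, round to odd) then FP_ROUND_VL (f32->f16).
declare <2 x half> @llvm.vp.fptrunc.v2f16.v2f64(<2 x double>, <2 x i1>, i32)

define <2 x half> @trunc_example(<2 x double> %a, <2 x i1> %m, i32 zeroext %vl) {
  %v = call <2 x half> @llvm.vp.fptrunc.v2f16.v2f64(<2 x double> %a, <2 x i1> %m, i32 %vl)
  ret <2 x half> %v
}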
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -614,7 +614,7 @@ class RISCVTargetLowering : public TargetLowering {
int64_t ExtTrueVal) const;
SDValue lowerVectorMaskTruncLike(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVectorTruncLike(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVectorFPRoundLike(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVectorFPExtendOrRoundLike(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
77 changes: 77 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vector-fpext-vp.ll
@@ -0,0 +1,77 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zvfh,+v -riscv-v-vector-bits-min=128 -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zvfh,+v -riscv-v-vector-bits-min=128 -verify-machineinstrs < %s | FileCheck %s

declare <2 x float> @llvm.vp.fpext.v2f32.v2f16(<2 x half>, <2 x i1>, i32)

define <2 x float> @vfpext_v2f16_v2f32(<2 x half> %a, <2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_v2f16_v2f32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v9, v8, v0.t
; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%v = call <2 x float> @llvm.vp.fpext.v2f32.v2f16(<2 x half> %a, <2 x i1> %m, i32 %vl)
ret <2 x float> %v
}

define <2 x float> @vfpext_v2f16_v2f32_unmasked(<2 x half> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_v2f16_v2f32_unmasked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v9, v8
; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%v = call <2 x float> @llvm.vp.fpext.v2f32.v2f16(<2 x half> %a, <2 x i1> shufflevector (<2 x i1> insertelement (<2 x i1> undef, i1 true, i32 0), <2 x i1> undef, <2 x i32> zeroinitializer), i32 %vl)
ret <2 x float> %v
}

declare <2 x double> @llvm.vp.fpext.v2f64.v2f16(<2 x half>, <2 x i1>, i32)

define <2 x double> @vfpext_v2f16_v2f64(<2 x half> %a, <2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_v2f16_v2f64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v9, v8, v0.t
; CHECK-NEXT: vsetvli zero, zero, e32, mf2, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v8, v9, v0.t
; CHECK-NEXT: ret
%v = call <2 x double> @llvm.vp.fpext.v2f64.v2f16(<2 x half> %a, <2 x i1> %m, i32 %vl)
ret <2 x double> %v
}

define <2 x double> @vfpext_v2f16_v2f64_unmasked(<2 x half> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_v2f16_v2f64_unmasked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v9, v8
; CHECK-NEXT: vsetvli zero, zero, e32, mf2, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v8, v9
; CHECK-NEXT: ret
%v = call <2 x double> @llvm.vp.fpext.v2f64.v2f16(<2 x half> %a, <2 x i1> shufflevector (<2 x i1> insertelement (<2 x i1> undef, i1 true, i32 0), <2 x i1> undef, <2 x i32> zeroinitializer), i32 %vl)
ret <2 x double> %v
}

declare <2 x double> @llvm.vp.fpext.v2f64.v2f32(<2 x float>, <2 x i1>, i32)

define <2 x double> @vfpext_v2f32_v2f64(<2 x float> %a, <2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_v2f32_v2f64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e32, mf2, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v9, v8, v0.t
; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%v = call <2 x double> @llvm.vp.fpext.v2f64.v2f32(<2 x float> %a, <2 x i1> %m, i32 %vl)
ret <2 x double> %v
}

define <2 x double> @vfpext_v2f32_v2f64_unmasked(<2 x float> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_v2f32_v2f64_unmasked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e32, mf2, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v9, v8
; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%v = call <2 x double> @llvm.vp.fpext.v2f64.v2f32(<2 x float> %a, <2 x i1> shufflevector (<2 x i1> insertelement (<2 x i1> undef, i1 true, i32 0), <2 x i1> undef, <2 x i32> zeroinitializer), i32 %vl)
ret <2 x double> %v
}
77 changes: 77 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/vfpext-vp.ll
@@ -0,0 +1,77 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zvfh,+v -verify-machineinstrs < %s | FileCheck %s
; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zvfh,+v -verify-machineinstrs < %s | FileCheck %s

declare <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2f16(<vscale x 2 x half>, <vscale x 2 x i1>, i32)

define <vscale x 2 x float> @vfpext_nxv2f16_nxv2f32(<vscale x 2 x half> %a, <vscale x 2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_nxv2f16_nxv2f32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v9, v8, v0.t
; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%v = call <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2f16(<vscale x 2 x half> %a, <vscale x 2 x i1> %m, i32 %vl)
ret <vscale x 2 x float> %v
}

define <vscale x 2 x float> @vfpext_nxv2f16_nxv2f32_unmasked(<vscale x 2 x half> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_nxv2f16_nxv2f32_unmasked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v9, v8
; CHECK-NEXT: vmv1r.v v8, v9
; CHECK-NEXT: ret
%v = call <vscale x 2 x float> @llvm.vp.fpext.nxv2f32.nxv2f16(<vscale x 2 x half> %a, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %vl)
ret <vscale x 2 x float> %v
}

declare <vscale x 2 x double> @llvm.vp.fpext.nxv2f64.nxv2f16(<vscale x 2 x half>, <vscale x 2 x i1>, i32)

define <vscale x 2 x double> @vfpext_nxv2f16_nxv2f64(<vscale x 2 x half> %a, <vscale x 2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_nxv2f16_nxv2f64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v10, v8, v0.t
; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v8, v10, v0.t
; CHECK-NEXT: ret
%v = call <vscale x 2 x double> @llvm.vp.fpext.nxv2f64.nxv2f16(<vscale x 2 x half> %a, <vscale x 2 x i1> %m, i32 %vl)
ret <vscale x 2 x double> %v
}

define <vscale x 2 x double> @vfpext_nxv2f16_nxv2f64_unmasked(<vscale x 2 x half> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_nxv2f16_nxv2f64_unmasked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v10, v8
; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v8, v10
; CHECK-NEXT: ret
%v = call <vscale x 2 x double> @llvm.vp.fpext.nxv2f64.nxv2f16(<vscale x 2 x half> %a, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %vl)
ret <vscale x 2 x double> %v
}

declare <vscale x 2 x double> @llvm.vp.fpext.nxv2f64.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)

define <vscale x 2 x double> @vfpext_nxv2f32_nxv2f64(<vscale x 2 x float> %a, <vscale x 2 x i1> %m, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_nxv2f32_nxv2f64:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e32, m1, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v10, v8, v0.t
; CHECK-NEXT: vmv2r.v v8, v10
; CHECK-NEXT: ret
%v = call <vscale x 2 x double> @llvm.vp.fpext.nxv2f64.nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x i1> %m, i32 %vl)
ret <vscale x 2 x double> %v
}

define <vscale x 2 x double> @vfpext_nxv2f32_nxv2f64_unmasked(<vscale x 2 x float> %a, i32 zeroext %vl) {
; CHECK-LABEL: vfpext_nxv2f32_nxv2f64_unmasked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e32, m1, ta, mu
; CHECK-NEXT: vfwcvt.f.f.v v10, v8
; CHECK-NEXT: vmv2r.v v8, v10
; CHECK-NEXT: ret
%v = call <vscale x 2 x double> @llvm.vp.fpext.nxv2f64.nxv2f32(<vscale x 2 x float> %a, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer), i32 %vl)
ret <vscale x 2 x double> %v
}
