Skip to content

Commit

Permalink
[RISCV] Make fixed-point instructions commutable (#90035)
Browse files Browse the repository at this point in the history
This PR includes:
* vsadd.vv/vsaddu.vv
* vaadd.vv/vaaddu.vv
* vsmul.vv
  • Loading branch information
wangpc-pp committed Apr 28, 2024
1 parent c705c68 commit 2c1c887
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 13 deletions.
5 changes: 5 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3132,6 +3132,11 @@ bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI,
case CASE_RVV_OPCODE_WIDEN(VWMACC_VV):
case CASE_RVV_OPCODE_WIDEN(VWMACCU_VV):
case CASE_RVV_OPCODE_UNMASK(VADC_VVM):
case CASE_RVV_OPCODE(VSADD_VV):
case CASE_RVV_OPCODE(VSADDU_VV):
case CASE_RVV_OPCODE(VAADD_VV):
case CASE_RVV_OPCODE(VAADDU_VV):
case CASE_RVV_OPCODE(VSMUL_VV):
// Operands 2 and 3 are commutable.
return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 2, 3);
case CASE_VFMA_SPLATS(FMADD):
Expand Down
29 changes: 16 additions & 13 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
Original file line number Diff line number Diff line change
Expand Up @@ -2146,8 +2146,9 @@ multiclass VPseudoBinaryRoundingMode<VReg RetClass,
string Constraint = "",
int sew = 0,
int UsesVXRM = 1,
int TargetConstraintType = 1> {
let VLMul = MInfo.value, SEW=sew in {
int TargetConstraintType = 1,
bit Commutable = 0> {
let VLMul = MInfo.value, SEW=sew, isCommutable = Commutable in {
defvar suffix = !if(sew, "_" # MInfo.MX # "_E" # sew, "_" # MInfo.MX);
def suffix : VPseudoBinaryNoMaskRoundingMode<RetClass, Op1Class, Op2Class,
Constraint, UsesVXRM,
Expand Down Expand Up @@ -2232,8 +2233,9 @@ multiclass VPseudoBinaryV_VV<LMULInfo m, string Constraint = "", int sew = 0, bi
defm _VV : VPseudoBinary<m.vrclass, m.vrclass, m.vrclass, m, Constraint, sew, Commutable=Commutable>;
}

multiclass VPseudoBinaryV_VV_RM<LMULInfo m, string Constraint = ""> {
defm _VV : VPseudoBinaryRoundingMode<m.vrclass, m.vrclass, m.vrclass, m, Constraint>;
multiclass VPseudoBinaryV_VV_RM<LMULInfo m, string Constraint = "", bit Commutable = 0> {
defm _VV : VPseudoBinaryRoundingMode<m.vrclass, m.vrclass, m.vrclass, m, Constraint,
Commutable=Commutable>;
}

// Similar to VPseudoBinaryV_VV, but uses MxListF.
Expand Down Expand Up @@ -2715,10 +2717,11 @@ multiclass VPseudoVGTR_VV_VX_VI<Operand ImmType = simm5, string Constraint = "">
}
}

multiclass VPseudoVSALU_VV_VX_VI<Operand ImmType = simm5, string Constraint = ""> {
multiclass VPseudoVSALU_VV_VX_VI<Operand ImmType = simm5, string Constraint = "",
bit Commutable = 0> {
foreach m = MxList in {
defvar mx = m.MX;
defm "" : VPseudoBinaryV_VV<m, Constraint>,
defm "" : VPseudoBinaryV_VV<m, Constraint, Commutable=Commutable>,
SchedBinary<"WriteVSALUV", "ReadVSALUV", "ReadVSALUX", mx,
forceMergeOpRead=true>;
defm "" : VPseudoBinaryV_VX<m, Constraint>,
Expand Down Expand Up @@ -2788,7 +2791,7 @@ multiclass VPseudoVSALU_VV_VX {
multiclass VPseudoVSMUL_VV_VX_RM {
foreach m = MxList in {
defvar mx = m.MX;
defm "" : VPseudoBinaryV_VV_RM<m>,
defm "" : VPseudoBinaryV_VV_RM<m, Commutable=1>,
SchedBinary<"WriteVSMulV", "ReadVSMulV", "ReadVSMulV", mx,
forceMergeOpRead=true>;
defm "" : VPseudoBinaryV_VX_RM<m>,
Expand All @@ -2797,10 +2800,10 @@ multiclass VPseudoVSMUL_VV_VX_RM {
}
}

multiclass VPseudoVAALU_VV_VX_RM {
multiclass VPseudoVAALU_VV_VX_RM<bit Commutable = 0> {
foreach m = MxList in {
defvar mx = m.MX;
defm "" : VPseudoBinaryV_VV_RM<m>,
defm "" : VPseudoBinaryV_VV_RM<m, Commutable=Commutable>,
SchedBinary<"WriteVAALUV", "ReadVAALUV", "ReadVAALUV", mx,
forceMergeOpRead=true>;
defm "" : VPseudoBinaryV_VX_RM<m>,
Expand Down Expand Up @@ -6448,17 +6451,17 @@ defm PseudoVMV_V : VPseudoUnaryVMV_V_X_I;
// 12.1. Vector Single-Width Saturating Add and Subtract
//===----------------------------------------------------------------------===//
let Defs = [VXSAT], hasSideEffects = 1 in {
defm PseudoVSADDU : VPseudoVSALU_VV_VX_VI;
defm PseudoVSADD : VPseudoVSALU_VV_VX_VI;
defm PseudoVSADDU : VPseudoVSALU_VV_VX_VI<Commutable=1>;
defm PseudoVSADD : VPseudoVSALU_VV_VX_VI<Commutable=1>;
defm PseudoVSSUBU : VPseudoVSALU_VV_VX;
defm PseudoVSSUB : VPseudoVSALU_VV_VX;
}

//===----------------------------------------------------------------------===//
// 12.2. Vector Single-Width Averaging Add and Subtract
//===----------------------------------------------------------------------===//
defm PseudoVAADDU : VPseudoVAALU_VV_VX_RM;
defm PseudoVAADD : VPseudoVAALU_VV_VX_RM;
defm PseudoVAADDU : VPseudoVAALU_VV_VX_RM<Commutable=1>;
defm PseudoVAADD : VPseudoVAALU_VV_VX_RM<Commutable=1>;
defm PseudoVASUBU : VPseudoVAALU_VV_VX_RM;
defm PseudoVASUB : VPseudoVAALU_VV_VX_RM;

Expand Down
173 changes: 173 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/commutable.ll
Original file line number Diff line number Diff line change
Expand Up @@ -649,3 +649,176 @@ entry:
ret <vscale x 1 x i64> %ret
}

; vsadd.vv
declare <vscale x 1 x i64> @llvm.riscv.vsadd.nxv1i64.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i64>, iXLen);
define <vscale x 1 x i64> @commutable_vsadd_vv(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen %2) nounwind {
; CHECK-LABEL: commutable_vsadd_vv:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: vsadd.vv v10, v8, v9
; CHECK-NEXT: vsadd.vv v8, v9, v8
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: ret
entry:
%a = call <vscale x 1 x i64> @llvm.riscv.vsadd.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen %2)
%b = call <vscale x 1 x i64> @llvm.riscv.vsadd.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %1, <vscale x 1 x i64> %0, iXLen %2)
%ret = add <vscale x 1 x i64> %a, %b
ret <vscale x 1 x i64> %ret
}

declare <vscale x 1 x i64> @llvm.riscv.vsadd.mask.nxv1i64.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i1>, iXLen, iXLen);
define <vscale x 1 x i64> @commutable_vsadd_vv_masked(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1, <vscale x 1 x i1> %mask, iXLen %2) {
; CHECK-LABEL: commutable_vsadd_vv_masked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: vsadd.vv v10, v8, v9, v0.t
; CHECK-NEXT: vsadd.vv v8, v9, v8, v0.t
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: ret
%a = call <vscale x 1 x i64> @llvm.riscv.vsadd.mask.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, <vscale x 1 x i1> %mask, iXLen %2, iXLen 1)
%b = call <vscale x 1 x i64> @llvm.riscv.vsadd.mask.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %1, <vscale x 1 x i64> %0, <vscale x 1 x i1> %mask, iXLen %2, iXLen 1)
%ret = add <vscale x 1 x i64> %a, %b
ret <vscale x 1 x i64> %ret
}

; vsaddu.vv
declare <vscale x 1 x i64> @llvm.riscv.vsaddu.nxv1i64.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i64>, iXLen);
define <vscale x 1 x i64> @commutable_vsaddu_vv(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen %2) nounwind {
; CHECK-LABEL: commutable_vsaddu_vv:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: vsaddu.vv v10, v8, v9
; CHECK-NEXT: vsaddu.vv v8, v9, v8
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: ret
entry:
%a = call <vscale x 1 x i64> @llvm.riscv.vsaddu.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen %2)
%b = call <vscale x 1 x i64> @llvm.riscv.vsaddu.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %1, <vscale x 1 x i64> %0, iXLen %2)
%ret = add <vscale x 1 x i64> %a, %b
ret <vscale x 1 x i64> %ret
}

declare <vscale x 1 x i64> @llvm.riscv.vsaddu.mask.nxv1i64.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i1>, iXLen, iXLen);
define <vscale x 1 x i64> @commutable_vsaddu_vv_masked(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1, <vscale x 1 x i1> %mask, iXLen %2) {
; CHECK-LABEL: commutable_vsaddu_vv_masked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: vsaddu.vv v10, v8, v9, v0.t
; CHECK-NEXT: vsaddu.vv v8, v9, v8, v0.t
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: ret
%a = call <vscale x 1 x i64> @llvm.riscv.vsaddu.mask.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, <vscale x 1 x i1> %mask, iXLen %2, iXLen 1)
%b = call <vscale x 1 x i64> @llvm.riscv.vsaddu.mask.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %1, <vscale x 1 x i64> %0, <vscale x 1 x i1> %mask, iXLen %2, iXLen 1)
%ret = add <vscale x 1 x i64> %a, %b
ret <vscale x 1 x i64> %ret
}

; vaadd.vv
declare <vscale x 1 x i64> @llvm.riscv.vaadd.nxv1i64.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i64>, iXLen, iXLen);
define <vscale x 1 x i64> @commutable_vaadd_vv(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen %2) nounwind {
; CHECK-LABEL: commutable_vaadd_vv:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: csrwi vxrm, 0
; CHECK-NEXT: vaadd.vv v8, v8, v9
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v8, v8
; CHECK-NEXT: ret
entry:
%a = call <vscale x 1 x i64> @llvm.riscv.vaadd.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen 0, iXLen %2)
%b = call <vscale x 1 x i64> @llvm.riscv.vaadd.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %1, <vscale x 1 x i64> %0, iXLen 0, iXLen %2)
%ret = add <vscale x 1 x i64> %a, %b
ret <vscale x 1 x i64> %ret
}

declare <vscale x 1 x i64> @llvm.riscv.vaadd.mask.nxv1i64.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i1>, iXLen, iXLen, iXLen);
define <vscale x 1 x i64> @commutable_vaadd_vv_masked(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1, <vscale x 1 x i1> %mask, iXLen %2) {
; CHECK-LABEL: commutable_vaadd_vv_masked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: csrwi vxrm, 0
; CHECK-NEXT: vaadd.vv v10, v8, v9, v0.t
; CHECK-NEXT: vaadd.vv v8, v8, v9, v0.t
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: ret
%a = call <vscale x 1 x i64> @llvm.riscv.vaadd.mask.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, <vscale x 1 x i1> %mask, iXLen 0, iXLen %2, iXLen 1)
%b = call <vscale x 1 x i64> @llvm.riscv.vaadd.mask.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %1, <vscale x 1 x i64> %0, <vscale x 1 x i1> %mask, iXLen 0, iXLen %2, iXLen 1)
%ret = add <vscale x 1 x i64> %a, %b
ret <vscale x 1 x i64> %ret
}

; vaaddu.vv
declare <vscale x 1 x i64> @llvm.riscv.vaaddu.nxv1i64.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i64>, iXLen, iXLen);
define <vscale x 1 x i64> @commutable_vaaddu_vv(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen %2) nounwind {
; CHECK-LABEL: commutable_vaaddu_vv:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: csrwi vxrm, 0
; CHECK-NEXT: vaaddu.vv v8, v8, v9
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v8, v8
; CHECK-NEXT: ret
entry:
%a = call <vscale x 1 x i64> @llvm.riscv.vaaddu.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen 0, iXLen %2)
%b = call <vscale x 1 x i64> @llvm.riscv.vaaddu.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %1, <vscale x 1 x i64> %0, iXLen 0, iXLen %2)
%ret = add <vscale x 1 x i64> %a, %b
ret <vscale x 1 x i64> %ret
}

declare <vscale x 1 x i64> @llvm.riscv.vaaddu.mask.nxv1i64.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i1>, iXLen, iXLen, iXLen);
define <vscale x 1 x i64> @commutable_vaaddu_vv_masked(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1, <vscale x 1 x i1> %mask, iXLen %2) {
; CHECK-LABEL: commutable_vaaddu_vv_masked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: csrwi vxrm, 0
; CHECK-NEXT: vaaddu.vv v10, v8, v9, v0.t
; CHECK-NEXT: vaaddu.vv v8, v8, v9, v0.t
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: ret
%a = call <vscale x 1 x i64> @llvm.riscv.vaaddu.mask.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, <vscale x 1 x i1> %mask, iXLen 0, iXLen %2, iXLen 1)
%b = call <vscale x 1 x i64> @llvm.riscv.vaaddu.mask.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %1, <vscale x 1 x i64> %0, <vscale x 1 x i1> %mask, iXLen 0, iXLen %2, iXLen 1)
%ret = add <vscale x 1 x i64> %a, %b
ret <vscale x 1 x i64> %ret
}

; vsmul.vv
declare <vscale x 1 x i64> @llvm.riscv.vsmul.nxv1i64.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i64>, iXLen, iXLen);
define <vscale x 1 x i64> @commutable_vsmul_vv(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen %2) nounwind {
; CHECK-LABEL: commutable_vsmul_vv:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: csrwi vxrm, 0
; CHECK-NEXT: vsmul.vv v10, v8, v9
; CHECK-NEXT: vsmul.vv v8, v9, v8
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: ret
entry:
%a = call <vscale x 1 x i64> @llvm.riscv.vsmul.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, iXLen 0, iXLen %2)
%b = call <vscale x 1 x i64> @llvm.riscv.vsmul.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %1, <vscale x 1 x i64> %0, iXLen 0, iXLen %2)
%ret = add <vscale x 1 x i64> %a, %b
ret <vscale x 1 x i64> %ret
}

declare <vscale x 1 x i64> @llvm.riscv.vsmul.mask.nxv1i64.nxv1i64(<vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i64>, <vscale x 1 x i1>, iXLen, iXLen, iXLen);
define <vscale x 1 x i64> @commutable_vsmul_vv_masked(<vscale x 1 x i64> %0, <vscale x 1 x i64> %1, <vscale x 1 x i1> %mask, iXLen %2) {
; CHECK-LABEL: commutable_vsmul_vv_masked:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
; CHECK-NEXT: csrwi vxrm, 0
; CHECK-NEXT: vsmul.vv v10, v8, v9, v0.t
; CHECK-NEXT: vsmul.vv v8, v9, v8, v0.t
; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
; CHECK-NEXT: vadd.vv v8, v10, v8
; CHECK-NEXT: ret
%a = call <vscale x 1 x i64> @llvm.riscv.vsmul.mask.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %0, <vscale x 1 x i64> %1, <vscale x 1 x i1> %mask, iXLen 0, iXLen %2, iXLen 1)
%b = call <vscale x 1 x i64> @llvm.riscv.vsmul.mask.nxv1i64.nxv1i64(<vscale x 1 x i64> undef, <vscale x 1 x i64> %1, <vscale x 1 x i64> %0, <vscale x 1 x i1> %mask, iXLen 0, iXLen %2, iXLen 1)
%ret = add <vscale x 1 x i64> %a, %b
ret <vscale x 1 x i64> %ret
}

0 comments on commit 2c1c887

Please sign in to comment.