diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 921d12757d672..e3f9e24555dae 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25655,3 +25655,17 @@ bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
 
   return VT.getSizeInBits() <= Subtarget.getXLen();
 }
+
+bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                                              SDValue N1) const {
+  if (!N0.hasOneUse())
+    return false;
+
+  // Avoid reassociating expressions that can be lowered to a vector
+  // multiply-accumulate (i.e. add (mul x, y), z).
+  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
+      (N0.getValueType().isVector() && Subtarget.hasVInstructions()))
+    return false;
+
+  return true;
+}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 5cc427c867cfd..f4b3faefb1e95 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -470,6 +470,11 @@ class RISCVTargetLowering : public TargetLowering {
 
   bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;
 
+  /// Control the following reassociation of operands: (op (op x, c1), y) ->
+  /// (op (op x, y), c1), where N0 is (op x, c1) and N1 is y.
+  bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                           SDValue N1) const override;
+
   /// Match a mask which "spreads" the leading elements of a vector evenly
   /// across the result. Factor is the spread amount, and Index is the
   /// offset applied.
diff --git a/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
new file mode 100644
index 0000000000000..9fa0cec0ea339
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -mattr=+m,+v < %s | FileCheck %s
+
+define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
+; CHECK-LABEL: madd_scalar:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    mul a1, a2, a3
+; CHECK-NEXT:    add a0, a0, a1
+; CHECK-NEXT:    addiw a0, a0, 32
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul i32 %m00, %m01
+  %mul1 = mul i32 %m10, %m11
+  %add0 = add i32 %mul0, 32
+  %add1 = add i32 %add0, %mul1
+  ret i32 %add1
+}
+
+define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11, <8 x i32> %addend) {
+; CHECK-LABEL: vmadd_non_constant:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmadd.vv v8, v10, v16
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul <8 x i32> %m00, %m01
+  %mul1 = mul <8 x i32> %m10, %m11
+  %add0 = add <8 x i32> %mul0, %addend
+  %add1 = add <8 x i32> %add0, %mul1
+  ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01) {
+; CHECK-LABEL: vmadd_vscale_no_chain:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v10, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v10
+; CHECK-NEXT:    ret
+entry:
+  %mul = mul <vscale x 1 x i32> %m00, %m01
+  %add = add <vscale x 1 x i32> %mul, splat (i32 32)
+  ret <vscale x 1 x i32> %add
+}
+
+define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
+; CHECK-LABEL: vmadd_fixed_no_chain:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.v.x v12, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul = mul <8 x i32> %m00, %m01
+  %add = add <8 x i32> %mul, splat (i32 32)
+  ret <8 x i32> %add
+}
+
+define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
+; CHECK-LABEL: vmadd_vscale:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v12, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v12
+; CHECK-NEXT:    vmacc.vv v8, v11, v10
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul <vscale x 1 x i32> %m10, %m11
+  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+  %add1 = add <vscale x 1 x i32> %add0, %mul1
+  ret <vscale x 1 x i32> %add1
+}
+
+define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
+; CHECK-LABEL: vmadd_fixed:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.v.x v16, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v16
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    ret
+entry:
+  %mul0 = mul <8 x i32> %m00, %m01
+  %mul1 = mul <8 x i32> %m10, %m11
+  %add0 = add <8 x i32> %mul0, splat (i32 32)
+  %add1 = add <8 x i32> %add0, %mul1
+  ret <8 x i32> %add1
+}
+
+define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
+; CHECK-LABEL: vmadd_vscale_long:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vmv.v.x v16, a0
+; CHECK-NEXT:    vmadd.vv v8, v9, v16
+; CHECK-NEXT:    vmacc.vv v8, v11, v10
+; CHECK-NEXT:    vmacc.vv v8, v13, v12
+; CHECK-NEXT:    vmacc.vv v8, v15, v14
+; CHECK-NEXT:    ret
+                                             <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
+entry:
+  %mul0 = mul <vscale x 1 x i32> %m00, %m01
+  %mul1 = mul <vscale x 1 x i32> %m10, %m11
+  %mul2 = mul <vscale x 1 x i32> %m20, %m21
+  %mul3 = mul <vscale x 1 x i32> %m30, %m31
+  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
+  %add1 = add <vscale x 1 x i32> %add0, %mul1
+  %add2 = add <vscale x 1 x i32> %add1, %mul2
+  %add3 = add <vscale x 1 x i32> %add2, %mul3
+  ret <vscale x 1 x i32> %add3
+}
+
+define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
+; CHECK-LABEL: vmadd_fixed_long:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a0, 32
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vmv.v.x v24, a0
+; CHECK-NEXT:    vmadd.vv v8, v10, v24
+; CHECK-NEXT:    vmacc.vv v8, v14, v12
+; CHECK-NEXT:    vmacc.vv v8, v18, v16
+; CHECK-NEXT:    vmacc.vv v8, v22, v20
+; CHECK-NEXT:    ret
+                                   <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
+entry:
+  %mul0 = mul <8 x i32> %m00, %m01
+  %mul1 = mul <8 x i32> %m10, %m11
+  %mul2 = mul <8 x i32> %m20, %m21
+  %mul3 = mul <8 x i32> %m30, %m31
+  %add0 = add <8 x i32> %mul0, splat (i32 32)
+  %add1 = add <8 x i32> %add0, %mul1
+  %add2 = add <8 x i32> %add1, %mul2
+  %add3 = add <8 x i32> %add2, %mul3
+  ret <8 x i32> %add3
+}