14 changes: 14 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -25655,3 +25655,17 @@ bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {

  return VT.getSizeInBits() <= Subtarget.getXLen();
}

bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
                                              SDValue N1) const {
  if (!N0.hasOneUse())
    return false;

  // Avoid reassociating expressions that can be lowered to a vector
  // multiply-accumulate (i.e. add (mul x, y), z).
  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
      (N0.getValueType().isVector() && Subtarget.hasVInstructions()))
    return false;

  return true;
}
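
For context, and not part of the diff: DAGCombiner's generic reassociation rewrites (add (add x, c1), y) into (add (add x, y), c1) unless this hook declines. A minimal IR sketch of the shape the new check preserves, with hypothetical value names:

; Before the (now suppressed) rewrite, every add has a mul operand, so
; selection can form vmadd.vv followed by vmacc.vv.
define <8 x i32> @keep_madd_shape(<8 x i32> %a, <8 x i32> %b, <8 x i32> %c, <8 x i32> %d) {
  %mul0 = mul <8 x i32> %a, %b
  %mul1 = mul <8 x i32> %c, %d
  %add0 = add <8 x i32> %mul0, splat (i32 32) ; N0 = (add x, c1), single use
  %add1 = add <8 x i32> %add0, %mul1          ; N1 is ISD::MUL
  ret <8 x i32> %add1
}
; Reassociation would instead hoist the constant to the root:
;   %sum  = add <8 x i32> %mul0, %mul1
;   %add1 = add <8 x i32> %sum, splat (i32 32)
; leaving the outer add with no mul operand to fuse.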
5 changes: 5 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -470,6 +470,11 @@ class RISCVTargetLowering : public TargetLowering {

  bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;

  /// Control the following reassociation of operands: (op (op x, c1), y) ->
  /// (op (op x, y), c1) where N0 is (op x, c1) and N1 is y.
  bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
                           SDValue N1) const override;

  /// Match a mask which "spreads" the leading elements of a vector evenly
  /// across the result. Factor is the spread amount, and Index is the
  /// offset applied.
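As an aside, the transform the new comment documents is easiest to see on scalars; a minimal illustration with hypothetical names (the constant only moves past y when the hook agrees it is profitable):

define i32 @generic_reassoc(i32 %x, i32 %y) {
  %n0 = add i32 %x, 32  ; N0 = (op x, c1)
  %r  = add i32 %n0, %y ; N1 = y
  ret i32 %r            ; may be rewritten to add (add %x, %y), 32
}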
143 changes: 143 additions & 0 deletions llvm/test/CodeGen/RISCV/vmadd-reassociate.ll
@@ -0,0 +1,143 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=riscv64 -mattr=+m,+v < %s | FileCheck %s

define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
; CHECK-LABEL: madd_scalar:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: mul a0, a0, a1
; CHECK-NEXT: mul a1, a2, a3
; CHECK-NEXT: add a0, a0, a1
; CHECK-NEXT: addiw a0, a0, 32
; CHECK-NEXT: ret
entry:
  %mul0 = mul i32 %m00, %m01
  %mul1 = mul i32 %m10, %m11
  %add0 = add i32 %mul0, 32
  %add1 = add i32 %add0, %mul1
  ret i32 %add1
}

define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11, <8 x i32> %addend) {
; CHECK-LABEL: vmadd_non_constant:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmadd.vv v8, v10, v16
; CHECK-NEXT: vmacc.vv v8, v14, v12
; CHECK-NEXT: ret
entry:
  %mul0 = mul <8 x i32> %m00, %m01
  %mul1 = mul <8 x i32> %m10, %m11
  %add0 = add <8 x i32> %mul0, %addend
  %add1 = add <8 x i32> %add0, %mul1
  ret <8 x i32> %add1
}

define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01) {
; CHECK-LABEL: vmadd_vscale_no_chain:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: li a0, 32
; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
; CHECK-NEXT: vmv.v.x v10, a0
; CHECK-NEXT: vmadd.vv v8, v9, v10
; CHECK-NEXT: ret
entry:
  %mul = mul <vscale x 1 x i32> %m00, %m01
  %add = add <vscale x 1 x i32> %mul, splat (i32 32)
  ret <vscale x 1 x i32> %add
}

define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
; CHECK-LABEL: vmadd_fixed_no_chain:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: li a0, 32
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmv.v.x v12, a0
; CHECK-NEXT: vmadd.vv v8, v10, v12
; CHECK-NEXT: ret
entry:
  %mul = mul <8 x i32> %m00, %m01
  %add = add <8 x i32> %mul, splat (i32 32)
  ret <8 x i32> %add
}

define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
; CHECK-LABEL: vmadd_vscale:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: li a0, 32
; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
; CHECK-NEXT: vmv.v.x v12, a0
; CHECK-NEXT: vmadd.vv v8, v9, v12
; CHECK-NEXT: vmacc.vv v8, v11, v10
; CHECK-NEXT: ret
entry:
  %mul0 = mul <vscale x 1 x i32> %m00, %m01
  %mul1 = mul <vscale x 1 x i32> %m10, %m11
  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
  %add1 = add <vscale x 1 x i32> %add0, %mul1
  ret <vscale x 1 x i32> %add1
}

define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
; CHECK-LABEL: vmadd_fixed:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: li a0, 32
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmv.v.x v16, a0
; CHECK-NEXT: vmadd.vv v8, v10, v16
; CHECK-NEXT: vmacc.vv v8, v14, v12
; CHECK-NEXT: ret
entry:
  %mul0 = mul <8 x i32> %m00, %m01
  %mul1 = mul <8 x i32> %m10, %m11
  %add0 = add <8 x i32> %mul0, splat (i32 32)
  %add1 = add <8 x i32> %add0, %mul1
  ret <8 x i32> %add1
}

define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
; CHECK-LABEL: vmadd_vscale_long:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: li a0, 32
; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
; CHECK-NEXT: vmv.v.x v16, a0
; CHECK-NEXT: vmadd.vv v8, v9, v16
; CHECK-NEXT: vmacc.vv v8, v11, v10
; CHECK-NEXT: vmacc.vv v8, v13, v12
; CHECK-NEXT: vmacc.vv v8, v15, v14
; CHECK-NEXT: ret
<vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
entry:
  %mul0 = mul <vscale x 1 x i32> %m00, %m01
  %mul1 = mul <vscale x 1 x i32> %m10, %m11
  %mul2 = mul <vscale x 1 x i32> %m20, %m21
  %mul3 = mul <vscale x 1 x i32> %m30, %m31
  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
  %add1 = add <vscale x 1 x i32> %add0, %mul1
  %add2 = add <vscale x 1 x i32> %add1, %mul2
  %add3 = add <vscale x 1 x i32> %add2, %mul3
  ret <vscale x 1 x i32> %add3
}

define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
; CHECK-LABEL: vmadd_fixed_long:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: li a0, 32
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT: vmv.v.x v24, a0
; CHECK-NEXT: vmadd.vv v8, v10, v24
; CHECK-NEXT: vmacc.vv v8, v14, v12
; CHECK-NEXT: vmacc.vv v8, v18, v16
; CHECK-NEXT: vmacc.vv v8, v22, v20
; CHECK-NEXT: ret
<8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
entry:
  %mul0 = mul <8 x i32> %m00, %m01
  %mul1 = mul <8 x i32> %m10, %m11
  %mul2 = mul <8 x i32> %m20, %m21
  %mul3 = mul <8 x i32> %m30, %m31
  %add0 = add <8 x i32> %mul0, splat (i32 32)
  %add1 = add <8 x i32> %add0, %mul1
  %add2 = add <8 x i32> %add1, %mul2
  %add3 = add <8 x i32> %add2, %mul3
  ret <8 x i32> %add3
}