Skip to content

Commit 4a1918d

Browse files
committed
[RISCV][DAGCombiner] Restrict VL->VW extension combine to single node
Previously, combineOp_VLToVWOp_VL maintained a private worklist and recursively folded multiple nodes. This could cause inconsistent DAG state and missed updates for some users. This patch simplifies the combine to only attempt folding the input node, leaving replacement and user updates to the outer DAGCombiner worklist. The DEBUG-only option `-ext-max-web-size` is now obsolete and have been removed. A new test (combine-vl-vw-macc.ll) is added to verify that consecutive vwmacc-like operations are generated correctly.
1 parent b196c52 commit 4a1918d

File tree

6 files changed

+106
-510
lines changed

6 files changed

+106
-510
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 18 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,6 @@ using namespace llvm;
5757

5858
STATISTIC(NumTailCalls, "Number of tail calls");
5959

60-
static cl::opt<unsigned> ExtensionMaxWebSize(
61-
DEBUG_TYPE "-ext-max-web-size", cl::Hidden,
62-
cl::desc("Give the maximum size (in number of nodes) of the web of "
63-
"instructions that we will consider for VW expansion"),
64-
cl::init(18));
65-
6660
static cl::opt<bool>
6761
AllowSplatInVW_W(DEBUG_TYPE "-form-vw-w-with-splat", cl::Hidden,
6862
cl::desc("Allow the formation of VW_W operations (e.g., "
@@ -18201,109 +18195,30 @@ static SDValue combineOp_VLToVWOp_VL(SDNode *N,
1820118195
if (!NodeExtensionHelper::isSupportedRoot(N, Subtarget))
1820218196
return SDValue();
1820318197

18204-
SmallVector<SDNode *> Worklist;
18205-
SmallPtrSet<SDNode *, 8> Inserted;
18206-
SmallPtrSet<SDNode *, 8> ExtensionsToRemove;
18207-
Worklist.push_back(N);
18208-
Inserted.insert(N);
18209-
SmallVector<CombineResult> CombinesToApply;
18210-
18211-
while (!Worklist.empty()) {
18212-
SDNode *Root = Worklist.pop_back_val();
18213-
18214-
NodeExtensionHelper LHS(Root, 0, DAG, Subtarget);
18215-
NodeExtensionHelper RHS(Root, 1, DAG, Subtarget);
18216-
auto AppendUsersIfNeeded =
18217-
[&Worklist, &Subtarget, &Inserted,
18218-
&ExtensionsToRemove](const NodeExtensionHelper &Op) {
18219-
if (Op.needToPromoteOtherUsers()) {
18220-
// Remember that we're supposed to remove this extension.
18221-
ExtensionsToRemove.insert(Op.OrigOperand.getNode());
18222-
for (SDUse &Use : Op.OrigOperand->uses()) {
18223-
SDNode *TheUser = Use.getUser();
18224-
if (!NodeExtensionHelper::isSupportedRoot(TheUser, Subtarget))
18225-
return false;
18226-
// We only support the first 2 operands of FMA.
18227-
if (Use.getOperandNo() >= 2)
18228-
return false;
18229-
if (Inserted.insert(TheUser).second)
18230-
Worklist.push_back(TheUser);
18231-
}
18232-
}
18233-
return true;
18234-
};
18198+
NodeExtensionHelper LHS(N, 0, DAG, Subtarget);
18199+
NodeExtensionHelper RHS(N, 1, DAG, Subtarget);
1823518200

18236-
// Control the compile time by limiting the number of node we look at in
18237-
// total.
18238-
if (Inserted.size() > ExtensionMaxWebSize)
18239-
return SDValue();
18201+
SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
18202+
NodeExtensionHelper::getSupportedFoldings(N);
1824018203

18241-
SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
18242-
NodeExtensionHelper::getSupportedFoldings(Root);
18243-
18244-
assert(!FoldingStrategies.empty() && "Nothing to be folded");
18245-
bool Matched = false;
18246-
for (int Attempt = 0;
18247-
(Attempt != 1 + NodeExtensionHelper::isCommutative(Root)) && !Matched;
18248-
++Attempt) {
18249-
18250-
for (NodeExtensionHelper::CombineToTry FoldingStrategy :
18251-
FoldingStrategies) {
18252-
std::optional<CombineResult> Res =
18253-
FoldingStrategy(Root, LHS, RHS, DAG, Subtarget);
18254-
if (Res) {
18255-
// If this strategy wouldn't remove an extension we're supposed to
18256-
// remove, reject it.
18257-
if (!Res->LHSExt.has_value() &&
18258-
ExtensionsToRemove.contains(LHS.OrigOperand.getNode()))
18259-
continue;
18260-
if (!Res->RHSExt.has_value() &&
18261-
ExtensionsToRemove.contains(RHS.OrigOperand.getNode()))
18262-
continue;
18204+
if (FoldingStrategies.empty())
18205+
return SDValue();
1826318206

18264-
Matched = true;
18265-
CombinesToApply.push_back(*Res);
18266-
// All the inputs that are extended need to be folded, otherwise
18267-
// we would be leaving the old input (since it is may still be used),
18268-
// and the new one.
18269-
if (Res->LHSExt.has_value())
18270-
if (!AppendUsersIfNeeded(LHS))
18271-
return SDValue();
18272-
if (Res->RHSExt.has_value())
18273-
if (!AppendUsersIfNeeded(RHS))
18274-
return SDValue();
18275-
break;
18276-
}
18207+
bool IsComm = NodeExtensionHelper::isCommutative(N);
18208+
for (int Attempt = 0; Attempt != 1 + IsComm; ++Attempt) {
18209+
for (NodeExtensionHelper::CombineToTry FoldingStrategy :
18210+
FoldingStrategies) {
18211+
std::optional<CombineResult> Res =
18212+
FoldingStrategy(N, LHS, RHS, DAG, Subtarget);
18213+
if (Res) {
18214+
SDValue NewValue = Res->materialize(DAG, Subtarget);
18215+
return NewValue;
1827718216
}
18278-
std::swap(LHS, RHS);
18279-
}
18280-
// Right now we do an all or nothing approach.
18281-
if (!Matched)
18282-
return SDValue();
18283-
}
18284-
// Store the value for the replacement of the input node separately.
18285-
SDValue InputRootReplacement;
18286-
// We do the RAUW after we materialize all the combines, because some replaced
18287-
// nodes may be feeding some of the yet-to-be-replaced nodes. Put differently,
18288-
// some of these nodes may appear in the NodeExtensionHelpers of some of the
18289-
// yet-to-be-visited CombinesToApply roots.
18290-
SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
18291-
ValuesToReplace.reserve(CombinesToApply.size());
18292-
for (CombineResult Res : CombinesToApply) {
18293-
SDValue NewValue = Res.materialize(DAG, Subtarget);
18294-
if (!InputRootReplacement) {
18295-
assert(Res.Root == N &&
18296-
"First element is expected to be the current node");
18297-
InputRootReplacement = NewValue;
18298-
} else {
18299-
ValuesToReplace.emplace_back(SDValue(Res.Root, 0), NewValue);
1830018217
}
18218+
std::swap(LHS, RHS);
1830118219
}
18302-
for (std::pair<SDValue, SDValue> OldNewValues : ValuesToReplace) {
18303-
DAG.ReplaceAllUsesOfValueWith(OldNewValues.first, OldNewValues.second);
18304-
DCI.AddToWorklist(OldNewValues.second.getNode());
18305-
}
18306-
return InputRootReplacement;
18220+
18221+
return SDValue();
1830718222
}
1830818223

1830918224
// Fold (vwadd(u).wv y, (vmerge cond, x, 0)) -> vwadd(u).wv y, x, y, cond
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
2+
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV32
3+
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64
4+
5+
define void @matmul_min(<32 x i8>* %vptr, i8* %scalars, <32 x i16>* %acc0_ptr, <32 x i16>* %acc1_ptr) {
6+
; CHECK-LABEL: matmul_min:
7+
; CHECK: # %bb.0: # %entry
8+
; CHECK-NEXT: li a4, 64
9+
; CHECK-NEXT: li a5, 32
10+
; CHECK-NEXT: vsetvli zero, a5, e8, m2, ta, ma
11+
; CHECK-NEXT: vle8.v v16, (a0)
12+
; CHECK-NEXT: lb a0, 0(a1)
13+
; CHECK-NEXT: lb a1, 1(a1)
14+
; CHECK-NEXT: vsetvli zero, a4, e8, m4, ta, ma
15+
; CHECK-NEXT: vle8.v v8, (a2)
16+
; CHECK-NEXT: vle8.v v12, (a3)
17+
; CHECK-NEXT: vsetvli zero, a5, e8, m2, ta, ma
18+
; CHECK-NEXT: vwmacc.vx v8, a0, v16
19+
; CHECK-NEXT: vwmacc.vx v12, a1, v16
20+
; CHECK-NEXT: vsetvli zero, a4, e8, m4, ta, ma
21+
; CHECK-NEXT: vse8.v v8, (a2)
22+
; CHECK-NEXT: vse8.v v12, (a3)
23+
; CHECK-NEXT: ret
24+
entry:
25+
%acc0 = load <32 x i16>, <32 x i16>* %acc0_ptr, align 1
26+
%acc1 = load <32 x i16>, <32 x i16>* %acc1_ptr, align 1
27+
28+
%v8 = load <32 x i8>, <32 x i8>* %vptr, align 1
29+
%v16 = sext <32 x i8> %v8 to <32 x i16>
30+
31+
%s0_ptr = getelementptr i8, i8* %scalars, i32 0
32+
%s0_i8 = load i8, i8* %s0_ptr, align 1
33+
%s0_i16 = sext i8 %s0_i8 to i16
34+
%tmp0 = insertelement <32 x i16> undef, i16 %s0_i16, i32 0
35+
%splat0 = shufflevector <32 x i16> %tmp0, <32 x i16> undef, <32 x i32> zeroinitializer
36+
%mul0 = mul <32 x i16> %splat0, %v16
37+
%add0 = add <32 x i16> %mul0, %acc0
38+
39+
%s1_ptr = getelementptr i8, i8* %scalars, i32 1
40+
%s1_i8 = load i8, i8* %s1_ptr, align 1
41+
%s1_i16 = sext i8 %s1_i8 to i16
42+
%tmp1 = insertelement <32 x i16> undef, i16 %s1_i16, i32 0
43+
%splat1 = shufflevector <32 x i16> %tmp1, <32 x i16> undef, <32 x i32> zeroinitializer
44+
%mul1 = mul <32 x i16> %splat1, %v16
45+
%add1 = add <32 x i16> %mul1, %acc1
46+
47+
store <32 x i16> %add0, <32 x i16>* %acc0_ptr, align 1
48+
store <32 x i16> %add1, <32 x i16>* %acc1_ptr, align 1
49+
50+
ret void
51+
}
52+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
53+
; RV32: {{.*}}
54+
; RV64: {{.*}}

0 commit comments

Comments
 (0)