Skip to content

Commit

Permalink
[RISCV][ISel] Fold extensions when all the users can consume them
Browse files Browse the repository at this point in the history
This patch allows the combines that fold extensions in binary operations
to have more than one use.
The approach here is pretty conservative: if all the users of an
extension can fold the extension, then the folding is done, otherwise we
don't fold.
This is the first step towards avoiding the one-use limitation.

As a result, we make a decision to fold/don't fold for a web of
instructions. An instruction is part of the web of instructions as soon
as it consumes an extension that needs to be folded for all its users.

Because of how SDISel works a web of instructions can be visited over
and over. More precisely, if the folding happens, it happens for the
whole web and that's the end of it, but if the folding fails, the whole
web may be revisited when another member of the web is visited.

To avoid a compile time explosion in pathological cases, we bail out
earlier for webs that are bigger than a given threshold (arbitrarily set
at 18 for now.) This size can be changed using
`--riscv-lower-ext-max-web-size=<maxWebSize>`.

At the current time, I didn't see a better scheme for that. Assuming we
want to stick with doing that in SDISel.

Differential Revision: https://reviews.llvm.org/D133739
  • Loading branch information
qcolombet committed Oct 5, 2022
1 parent 03b1454 commit c5c2de2
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 29 deletions.
102 changes: 80 additions & 22 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -46,6 +46,12 @@ using namespace llvm;

STATISTIC(NumTailCalls, "Number of tail calls");

static cl::opt<unsigned> ExtensionMaxWebSize(
DEBUG_TYPE "-ext-max-web-size", cl::Hidden,
cl::desc("Give the maximum size (in number of nodes) of the web of "
"instructions that we will consider for VW expansion"),
cl::init(18));

static cl::opt<bool>
AllowSplatInVW_W(DEBUG_TYPE "-form-vw-w-with-splat", cl::Hidden,
cl::desc("Allow the formation of VW_W operations (e.g., "
Expand Down Expand Up @@ -8547,9 +8553,9 @@ struct CombineResult {
/// Root of the combine.
SDNode *Root;
/// LHS of the TargetOpcode.
const NodeExtensionHelper &LHS;
NodeExtensionHelper LHS;
/// RHS of the TargetOpcode.
const NodeExtensionHelper &RHS;
NodeExtensionHelper RHS;

CombineResult(unsigned TargetOpcode, SDNode *Root,
const NodeExtensionHelper &LHS, Optional<bool> SExtLHS,
Expand Down Expand Up @@ -8728,31 +8734,83 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {

assert(NodeExtensionHelper::isSupportedRoot(N) &&
"Shouldn't have called this method");
SmallVector<SDNode *> Worklist;
SmallSet<SDNode *, 8> Inserted;
Worklist.push_back(N);
Inserted.insert(N);
SmallVector<CombineResult> CombinesToApply;

while (!Worklist.empty()) {
SDNode *Root = Worklist.pop_back_val();
if (!NodeExtensionHelper::isSupportedRoot(Root))
return SDValue();

NodeExtensionHelper LHS(N, 0, DAG);
NodeExtensionHelper RHS(N, 1, DAG);

if (LHS.needToPromoteOtherUsers() && !LHS.OrigOperand.hasOneUse())
return SDValue();

if (RHS.needToPromoteOtherUsers() && !RHS.OrigOperand.hasOneUse())
return SDValue();
NodeExtensionHelper LHS(N, 0, DAG);
NodeExtensionHelper RHS(N, 1, DAG);
auto AppendUsersIfNeeded = [&Worklist,
&Inserted](const NodeExtensionHelper &Op) {
if (Op.needToPromoteOtherUsers()) {
for (SDNode *TheUse : Op.OrigOperand->uses()) {
if (Inserted.insert(TheUse).second)
Worklist.push_back(TheUse);
}
}
};
AppendUsersIfNeeded(LHS);
AppendUsersIfNeeded(RHS);

SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
NodeExtensionHelper::getSupportedFoldings(N);
// Control the compile time by limiting the number of node we look at in
// total.
if (Inserted.size() > ExtensionMaxWebSize)
return SDValue();

assert(!FoldingStrategies.empty() && "Nothing to be folded");
for (int Attempt = 0; Attempt != 1 + NodeExtensionHelper::isCommutative(N);
++Attempt) {
for (NodeExtensionHelper::CombineToTry FoldingStrategy :
FoldingStrategies) {
Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
if (Res)
return Res->materialize(DAG);
SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
NodeExtensionHelper::getSupportedFoldings(N);

assert(!FoldingStrategies.empty() && "Nothing to be folded");
bool Matched = false;
for (int Attempt = 0;
(Attempt != 1 + NodeExtensionHelper::isCommutative(N)) && !Matched;
++Attempt) {

for (NodeExtensionHelper::CombineToTry FoldingStrategy :
FoldingStrategies) {
Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
if (Res) {
Matched = true;
CombinesToApply.push_back(*Res);
break;
}
}
std::swap(LHS, RHS);
}
std::swap(LHS, RHS);
// Right now we do an all or nothing approach.
if (!Matched)
return SDValue();
}
return SDValue();
// Store the value for the replacement of the input node separately.
SDValue InputRootReplacement;
// We do the RAUW after we materialize all the combines, because some replaced
// nodes may be feeding some of the yet-to-be-replaced nodes. Put differently,
// some of these nodes may appear in the NodeExtensionHelpers of some of the
// yet-to-be-visited CombinesToApply roots.
SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
ValuesToReplace.reserve(CombinesToApply.size());
for (CombineResult Res : CombinesToApply) {
SDValue NewValue = Res.materialize(DAG);
if (!InputRootReplacement) {
assert(Res.Root == N &&
"First element is expected to be the current node");
InputRootReplacement = NewValue;
} else {
ValuesToReplace.emplace_back(SDValue(Res.Root, 0), NewValue);
}
}
for (std::pair<SDValue, SDValue> OldNewValues : ValuesToReplace) {
DAG.ReplaceAllUsesOfValueWith(OldNewValues.first, OldNewValues.second);
DCI.AddToWorklist(OldNewValues.second.getNode());
}
return InputRootReplacement;
}

// Fold
Expand Down
60 changes: 60 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
@@ -0,0 +1,60 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
; Check that the default value enables the web folding and
; that it is bigger than 3.
; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING


; Check that the add/sub/mul operations are all promoted into their
; vw counterpart when the folding of the web size is increased to 3.
; We need the web size to be at least 3 for the folding to happen, because
; %c has 3 uses.
define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
; NO_FOLDING-LABEL: vwmul_v2i16_multiple_users:
; NO_FOLDING: # %bb.0:
; NO_FOLDING-NEXT: vsetivli zero, 2, e16, mf4, ta, mu
; NO_FOLDING-NEXT: vle8.v v8, (a0)
; NO_FOLDING-NEXT: vle8.v v9, (a1)
; NO_FOLDING-NEXT: vle8.v v10, (a2)
; NO_FOLDING-NEXT: vsext.vf2 v11, v8
; NO_FOLDING-NEXT: vsext.vf2 v8, v9
; NO_FOLDING-NEXT: vsext.vf2 v9, v10
; NO_FOLDING-NEXT: vmul.vv v8, v11, v8
; NO_FOLDING-NEXT: vadd.vv v10, v11, v9
; NO_FOLDING-NEXT: vsub.vv v9, v11, v9
; NO_FOLDING-NEXT: vor.vv v8, v8, v10
; NO_FOLDING-NEXT: vor.vv v8, v8, v9
; NO_FOLDING-NEXT: ret
;
; FOLDING-LABEL: vwmul_v2i16_multiple_users:
; FOLDING: # %bb.0:
; FOLDING-NEXT: vsetivli zero, 2, e8, mf8, ta, mu
; FOLDING-NEXT: vle8.v v8, (a0)
; FOLDING-NEXT: vle8.v v9, (a1)
; FOLDING-NEXT: vle8.v v10, (a2)
; FOLDING-NEXT: vwmul.vv v11, v8, v9
; FOLDING-NEXT: vwadd.vv v9, v8, v10
; FOLDING-NEXT: vwsub.vv v12, v8, v10
; FOLDING-NEXT: vsetvli zero, zero, e16, mf4, ta, mu
; FOLDING-NEXT: vor.vv v8, v11, v9
; FOLDING-NEXT: vor.vv v8, v8, v12
; FOLDING-NEXT: ret
%a = load <2 x i8>, <2 x i8>* %x
%b = load <2 x i8>, <2 x i8>* %y
%b2 = load <2 x i8>, <2 x i8>* %z
%c = sext <2 x i8> %a to <2 x i16>
%d = sext <2 x i8> %b to <2 x i16>
%d2 = sext <2 x i8> %b2 to <2 x i16>
%e = mul <2 x i16> %c, %d
%f = add <2 x i16> %c, %d2
%g = sub <2 x i16> %c, %d2
%h = or <2 x i16> %e, %f
%i = or <2 x i16> %h, %g
ret <2 x i16> %i
}
12 changes: 5 additions & 7 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
Expand Up @@ -21,16 +21,14 @@ define <2 x i16> @vwmul_v2i16(<2 x i8>* %x, <2 x i8>* %y) {
define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
; CHECK-LABEL: vwmul_v2i16_multiple_users:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, mu
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, mu
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vle8.v v9, (a1)
; CHECK-NEXT: vle8.v v10, (a2)
; CHECK-NEXT: vsext.vf2 v11, v8
; CHECK-NEXT: vsext.vf2 v8, v9
; CHECK-NEXT: vsext.vf2 v9, v10
; CHECK-NEXT: vmul.vv v8, v11, v8
; CHECK-NEXT: vmul.vv v9, v11, v9
; CHECK-NEXT: vor.vv v8, v8, v9
; CHECK-NEXT: vwmul.vv v11, v8, v9
; CHECK-NEXT: vwmul.vv v9, v8, v10
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, mu
; CHECK-NEXT: vor.vv v8, v11, v9
; CHECK-NEXT: ret
%a = load <2 x i8>, <2 x i8>* %x
%b = load <2 x i8>, <2 x i8>* %y
Expand Down

0 comments on commit c5c2de2

Please sign in to comment.