Skip to content

Commit

Permalink
[WebAssembly] Re-land 8392bf6
Browse files Browse the repository at this point in the history
Correctly handle single-element vectors to fix an assertion failure. Add tests
that were missing from the original commit.

Differential Revision: D151782
  • Loading branch information
calebzulawski authored and tlively committed Jun 9, 2023
1 parent 6adb1ca commit 18077e9
Show file tree
Hide file tree
Showing 7 changed files with 994 additions and 99 deletions.
141 changes: 141 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(

// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {
// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);

// Convert vector to integer bitcasts to bitmask
setTargetDAGCombine(ISD::BITCAST);

// Hoist bitcasts out of shuffles
setTargetDAGCombine(ISD::VECTOR_SHUFFLE);

Expand Down Expand Up @@ -258,6 +264,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// But saturating fp_to_int converstions are
for (auto Op : {ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT})
setOperationAction(Op, MVT::v4i32, Custom);

// Support vector extending
for (auto T : MVT::integer_fixedlen_vector_valuetypes()) {
setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
}
}

// As a special case, these operators use the type to mean the type to
Expand Down Expand Up @@ -1374,6 +1386,11 @@ void WebAssemblyTargetLowering::ReplaceNodeResults(
// SIGN_EXTEND_INREG, but for non-vector sign extends the result might be an
// illegal type.
break;
case ISD::SIGN_EXTEND_VECTOR_INREG:
case ISD::ZERO_EXTEND_VECTOR_INREG:
// Do not add any results, signifying that N should not be custom lowered.
// EXTEND_VECTOR_INREG is implemented for some vectors, but not all.
break;
default:
llvm_unreachable(
"ReplaceNodeResults not implemented for this op for WebAssembly!");
Expand Down Expand Up @@ -1424,6 +1441,9 @@ SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op,
return LowerIntrinsic(Op, DAG);
case ISD::SIGN_EXTEND_INREG:
return LowerSIGN_EXTEND_INREG(Op, DAG);
case ISD::ZERO_EXTEND_VECTOR_INREG:
case ISD::SIGN_EXTEND_VECTOR_INREG:
return LowerEXTEND_VECTOR_INREG(Op, DAG);
case ISD::BUILD_VECTOR:
return LowerBUILD_VECTOR(Op, DAG);
case ISD::VECTOR_SHUFFLE:
Expand Down Expand Up @@ -1877,6 +1897,48 @@ WebAssemblyTargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op,
Op.getOperand(1));
}

SDValue
WebAssemblyTargetLowering::LowerEXTEND_VECTOR_INREG(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
EVT VT = Op.getValueType();
SDValue Src = Op.getOperand(0);
EVT SrcVT = Src.getValueType();

if (SrcVT.getVectorElementType() == MVT::i1 ||
SrcVT.getVectorElementType() == MVT::i64)
return SDValue();

assert(VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits() == 0 &&
"Unexpected extension factor.");
unsigned Scale = VT.getScalarSizeInBits() / SrcVT.getScalarSizeInBits();

if (Scale != 2 && Scale != 4 && Scale != 8)
return SDValue();

unsigned Ext;
switch (Op.getOpcode()) {
case ISD::ZERO_EXTEND_VECTOR_INREG:
Ext = WebAssemblyISD::EXTEND_LOW_U;
break;
case ISD::SIGN_EXTEND_VECTOR_INREG:
Ext = WebAssemblyISD::EXTEND_LOW_S;
break;
}

SDValue Ret = Src;
while (Scale != 1) {
Ret = DAG.getNode(Ext, DL,
Ret.getValueType()
.widenIntegerVectorElementType(*DAG.getContext())
.getHalfNumVectorElementsVT(*DAG.getContext()),
Ret);
Scale /= 2;
}
assert(Ret.getValueType() == VT);
return Ret;
}

static SDValue LowerConvertLow(SDValue Op, SelectionDAG &DAG) {
SDLoc DL(Op);
if (Op.getValueType() != MVT::v2f64)
Expand Down Expand Up @@ -2692,12 +2754,91 @@ static SDValue performTruncateCombine(SDNode *N,
return truncateVectorWithNARROW(OutVT, In, DL, DAG);
}

static SDValue performBitcastCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
auto &DAG = DCI.DAG;
SDLoc DL(N);
SDValue Src = N->getOperand(0);
EVT VT = N->getValueType(0);
EVT SrcVT = Src.getValueType();

// bitcast <N x i1> to iN
// ==> bitmask
if (DCI.isBeforeLegalize() && VT.isScalarInteger() &&
SrcVT.isFixedLengthVector() && SrcVT.getScalarType() == MVT::i1) {
unsigned NumElts = SrcVT.getVectorNumElements();
if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
return SDValue();
EVT Width = MVT::getIntegerVT(128 / NumElts);
return DAG.getZExtOrTrunc(
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
{DAG.getConstant(Intrinsic::wasm_bitmask, DL, MVT::i32),
DAG.getSExtOrTrunc(N->getOperand(0), DL,
SrcVT.changeVectorElementType(Width))}),
DL, VT);
}

return SDValue();
}

static SDValue performSETCCCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
auto &DAG = DCI.DAG;

SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
SDLoc DL(N);
EVT VT = N->getValueType(0);

// setcc (iN (bitcast (vNi1 X))), 0, ne
// ==> any_true (vNi1 X)
// setcc (iN (bitcast (vNi1 X))), 0, eq
// ==> xor (any_true (vNi1 X)), -1
// setcc (iN (bitcast (vNi1 X))), -1, eq
// ==> all_true (vNi1 X)
// setcc (iN (bitcast (vNi1 X))), -1, ne
// ==> xor (all_true (vNi1 X)), -1
if (DCI.isBeforeLegalize() && VT.isScalarInteger() &&
(Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
(isNullConstant(RHS) || isAllOnesConstant(RHS)) &&
LHS->getOpcode() == ISD::BITCAST) {
EVT FromVT = LHS->getOperand(0).getValueType();
if (FromVT.isFixedLengthVector() &&
FromVT.getVectorElementType() == MVT::i1) {
int Intrin = isNullConstant(RHS) ? Intrinsic::wasm_anytrue
: Intrinsic::wasm_alltrue;
unsigned NumElts = FromVT.getVectorNumElements();
assert(NumElts == 2 || NumElts == 4 || NumElts == 8 || NumElts == 16);
EVT Width = MVT::getIntegerVT(128 / NumElts);
SDValue Ret = DAG.getZExtOrTrunc(
DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
{DAG.getConstant(Intrin, DL, MVT::i32),
DAG.getSExtOrTrunc(LHS->getOperand(0), DL,
FromVT.changeVectorElementType(Width))}),
DL, MVT::i1);
if ((isNullConstant(RHS) && (Cond == ISD::SETEQ)) ||
(isAllOnesConstant(RHS) && (Cond == ISD::SETNE))) {
Ret = DAG.getNOT(DL, Ret, MVT::i1);
}
return DAG.getZExtOrTrunc(Ret, DL, VT);
}
}

return SDValue();
}

SDValue
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
switch (N->getOpcode()) {
default:
return SDValue();
case ISD::BITCAST:
return performBitcastCombine(N, DCI);
case ISD::SETCC:
return performSETCCCombine(N, DCI);
case ISD::VECTOR_SHUFFLE:
return performVECTOR_SHUFFLECombine(N, DCI);
case ISD::SIGN_EXTEND:
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
SDValue LowerCopyToReg(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerIntrinsic(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSIGN_EXTEND_INREG(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerEXTEND_VECTOR_INREG(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSETCC(SDValue Op, SelectionDAG &DAG) const;
Expand Down
Loading

0 comments on commit 18077e9

Please sign in to comment.