diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index a4ed62bb5715c..69aa748f0f4f1 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -511,6 +511,9 @@ def vecreduce_smax : SDNode<"ISD::VECREDUCE_SMAX", SDTVecReduce>;
 def vecreduce_umax : SDNode<"ISD::VECREDUCE_UMAX", SDTVecReduce>;
 def vecreduce_smin : SDNode<"ISD::VECREDUCE_SMIN", SDTVecReduce>;
 def vecreduce_umin : SDNode<"ISD::VECREDUCE_UMIN", SDTVecReduce>;
+def vecreduce_and : SDNode<"ISD::VECREDUCE_AND", SDTVecReduce>;
+def vecreduce_or : SDNode<"ISD::VECREDUCE_OR", SDTVecReduce>;
+def vecreduce_xor : SDNode<"ISD::VECREDUCE_XOR", SDTVecReduce>;
 def vecreduce_fadd : SDNode<"ISD::VECREDUCE_FADD", SDTFPVecReduce>;
 def vecreduce_fmin : SDNode<"ISD::VECREDUCE_FMIN", SDTFPVecReduce>;
 def vecreduce_fmax : SDNode<"ISD::VECREDUCE_FMAX", SDTFPVecReduce>;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 35d5c3ed90c91..db292ab32f8b2 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -3386,6 +3386,75 @@ static SDValue TryMatchTrue(SDNode *N, EVT VecVT, SelectionDAG &DAG) {
   return DAG.getZExtOrTrunc(Ret, DL, N->getValueType(0));
 }
 
+// Combine a setcc of a vecreduce, for example:
+//
+// setcc (vecreduce_or(v4i32 V128:$vec)), (i32 0), SETNE
+//   ==> ANYTRUE V128:$vec
+//
+// setcc (i32 (vecreduce_and(v4i32 V128:$vec))), (i32 -1), SETEQ
+//   ==> ALLTRUE_I32x4 V128:$vec
+static SDValue combineSetCCVecReduce(SDNode *SetCC,
+                                     TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue Reduce = SetCC->getOperand(0);
+  SDValue Constant = SetCC->getOperand(1);
+  SDValue Cond = SetCC->getOperand(2);
+  unsigned ReduceIntrinsic;
+
+  // i8 and i16 truncate the vecreduce result.
+  if (Reduce->getOpcode() == ISD::AND) {
+    SDValue L = Reduce->getOperand(0), R = Reduce->getOperand(1);
+
+    ConstantSDNode *C = dyn_cast<ConstantSDNode>(R);
+    if (!C)
+      return SDValue();
+
+    EVT VT = Reduce->getValueType(0);
+    if (VT == MVT::v16i8 && C->getZExtValue() == 255) {
+      Reduce = L;
+    } else if (VT == MVT::v8i16 && C->getZExtValue() == 65535) {
+      Reduce = L;
+    } else {
+      return SDValue();
+    }
+  }
+
+  switch (Reduce->getOpcode()) {
+  case ISD::VECREDUCE_OR: {
+    ReduceIntrinsic = Intrinsic::wasm_anytrue;
+
+    if (cast<CondCodeSDNode>(Cond)->get() != ISD::SETNE)
+      return SDValue();
+
+    if (cast<ConstantSDNode>(Constant)->getSExtValue() != 0)
+      return SDValue();
+
+    break;
+  }
+  case ISD::VECREDUCE_AND: {
+    ReduceIntrinsic = Intrinsic::wasm_alltrue;
+
+    if (cast<CondCodeSDNode>(Cond)->get() != ISD::SETEQ)
+      return SDValue();
+
+    if (cast<ConstantSDNode>(Constant)->getSExtValue() != -1)
+      return SDValue();
+
+    break;
+  }
+  default:
+    return SDValue();
+  }
+
+  SDLoc DL(SetCC);
+  auto &DAG = DCI.DAG;
+  SDValue Match = Reduce->getOperand(0);
+
+  return DAG.getZExtOrTrunc(
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
+                  {DAG.getConstant(ReduceIntrinsic, DL, MVT::i32), Match}),
+      DL, MVT::i1);
+}
+
 /// Try to convert a i128 comparison to a v16i8 comparison before type
 /// legalization splits it up into chunks
 static SDValue
@@ -3446,6 +3515,9 @@ static SDValue performSETCCCombine(SDNode *N,
   if (SDValue V = combineVectorSizedSetCCEquality(N, DCI, Subtarget))
     return V;
 
+  if (SDValue V = combineSetCCVecReduce(N, DCI))
+    return V;
+
   SDValue LHS = N->getOperand(0);
   if (LHS->getOpcode() != ISD::BITCAST)
     return SDValue();
@@ -3460,9 +3532,9 @@ static SDValue performSETCCCombine(SDNode *N,
 
   if (!cast<ConstantSDNode>(N->getOperand(1)))
     return SDValue();
 
-  EVT VecVT = FromVT.changeVectorElementType(MVT::getIntegerVT(128 / NumElts));
   auto &DAG = DCI.DAG;
+
   // setcc (iN (bitcast (vNi1 X))), 0, ne
   //   ==> any_true (vNi1 X)
   if (auto Match = TryMatchTrue<0, ISD::SETNE, false, Intrinsic::wasm_anytrue>(
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 08fb7586d215e..efba2f8c8f805 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -327,3 +327,19 @@ bool WebAssemblyTTIImpl::isProfitableToSinkOperands(
 
   return false;
 }
+
+bool WebAssemblyTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
+  // Always expand on Subtargets without vector instructions.
+  if (!ST->hasSIMD128())
+    return true;
+
+  // Whether or not to expand is a per-intrinsic decision.
+  switch (II->getIntrinsicID()) {
+  default:
+    return true;
+  case Intrinsic::vector_reduce_and:
+    return false;
+  case Intrinsic::vector_reduce_or:
+    return false;
+  }
+}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
index c915eeb07d4fd..996b5e45daad1 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h
@@ -100,6 +100,7 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
 
   bool isProfitableToSinkOperands(Instruction *I,
                                   SmallVectorImpl<Use *> &Ops) const override;
+  bool shouldExpandReduction(const IntrinsicInst *II) const override;
 
   /// @}
 };
diff --git a/llvm/test/CodeGen/WebAssembly/any-all-true.ll b/llvm/test/CodeGen/WebAssembly/any-all-true.ll
new file mode 100644
index 0000000000000..b6fd0cde83bec
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/any-all-true.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+
+; RUN: llc < %s -verify-machineinstrs -mattr=+simd128 | FileCheck %s
+
+target triple = "wasm32-unknown-unknown"
+
+declare i8 @llvm.vector.reduce.and.v16i8(<16 x i8>)
+declare i8 @llvm.vector.reduce.or.v16i8(<16 x i8>)
+declare i16 @llvm.vector.reduce.and.v8i16(<8 x i16>)
+declare i16 @llvm.vector.reduce.or.v8i16(<8 x i16>)
+declare i32 @llvm.vector.reduce.and.v4i32(<4 x i32>)
+declare i32 @llvm.vector.reduce.or.v4i32(<4 x i32>)
+declare i64 @llvm.vector.reduce.and.v2i64(<2 x i64>)
+declare i64 @llvm.vector.reduce.or.v2i64(<2 x i64>)
+
+define zeroext i1 @manual_i8x16_all_true(<4 x i32> %a) {
+; CHECK-LABEL: manual_i8x16_all_true:
+; CHECK:         .functype manual_i8x16_all_true (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0: # %start
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    i8x16.all_true
+; CHECK-NEXT:    # fallthrough-return
+start:
+  %_3 = bitcast <4 x i32> %a to <16 x i8>
+  %0 = tail call i8 @llvm.vector.reduce.and.v16i8(<16 x i8> %_3)
+  %_0 = icmp eq i8 %0, -1
+  ret i1 %_0
+}
+
+define zeroext i1 @manual_i16x8_all_true(<4 x i32> %a) {
+; CHECK-LABEL: manual_i16x8_all_true:
+; CHECK:         .functype manual_i16x8_all_true (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0: # %start
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    i16x8.all_true
+; CHECK-NEXT:    # fallthrough-return
+start:
+  %_3 = bitcast <4 x i32> %a to <8 x i16>
+  %0 = tail call i16 @llvm.vector.reduce.and.v8i16(<8 x i16> %_3)
+  %_0 = icmp eq i16 %0, -1
+  ret i1 %_0
+}
+
+define zeroext i1 @manual_i32x4_all_true(<4 x i32> %a) {
+; CHECK-LABEL: manual_i32x4_all_true:
+; CHECK:         .functype manual_i32x4_all_true (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0: # %start
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    i32x4.all_true
+; CHECK-NEXT:    # fallthrough-return
+start:
+  %0 = tail call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> %a)
+  %_0 = icmp eq i32 %0, -1
+  ret i1 %_0
+}
+
+define zeroext i1 @manual_i64x2_all_true(<2 x i64> %a) {
+; CHECK-LABEL: manual_i64x2_all_true:
+; CHECK:         .functype manual_i64x2_all_true (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0: # %start
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    i64x2.all_true
+; CHECK-NEXT:    # fallthrough-return
+start:
+  %0 = tail call i64 @llvm.vector.reduce.and.v2i64(<2 x i64> %a)
+  %_0 = icmp eq i64 %0, -1
+  ret i1 %_0
+}
+
+; ---
+
+define zeroext i1 @manual_i8x16_any_true(<4 x i32> %a) {
+; CHECK-LABEL: manual_i8x16_any_true:
+; CHECK:         .functype manual_i8x16_any_true (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0: # %start
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    v128.any_true
+; CHECK-NEXT:    # fallthrough-return
+start:
+  %_3 = bitcast <4 x i32> %a to <16 x i8>
+  %0 = tail call i8 @llvm.vector.reduce.or.v16i8(<16 x i8> %_3)
+  %_0 = icmp ne i8 %0, 0
+  ret i1 %_0
+}
+
+define i1 @i16x8_any_true(<4 x i32> %a) {
+; CHECK-LABEL: i16x8_any_true:
+; CHECK:         .functype i16x8_any_true (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0: # %start
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    v128.any_true
+; CHECK-NEXT:    # fallthrough-return
+start:
+  %_3 = bitcast <4 x i32> %a to <8 x i16>
+  %0 = tail call i16 @llvm.vector.reduce.or.v8i16(<8 x i16> %_3)
+  %_0 = icmp ne i16 %0, 0
+  ret i1 %_0
+}
+
+define i1 @manual_i32x4_any_true(<4 x i32> %a) {
+; CHECK-LABEL: manual_i32x4_any_true:
+; CHECK:         .functype manual_i32x4_any_true (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0: # %start
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    v128.any_true
+; CHECK-NEXT:    # fallthrough-return
+start:
+  %0 = tail call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> %a)
+  %_0 = icmp ne i32 %0, 0
+  ret i1 %_0
+}
+
+
+define i1 @manual_i64x2_any_true(<2 x i64> %a) {
+; CHECK-LABEL: manual_i64x2_any_true:
+; CHECK:         .functype manual_i64x2_any_true (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0: # %start
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    v128.any_true
+; CHECK-NEXT:    # fallthrough-return
+start:
+  %0 = tail call i64 @llvm.vector.reduce.or.v2i64(<2 x i64> %a)
+  %_0 = icmp ne i64 %0, 0
+  ret i1 %_0
+}