[SDAG] Allow scalable vectors in ComputeKnownBits (try 2)
This was previously reverted due to a hang on a Hexagon bot. The hang turned out to be a bug in the Hexagon backend around how splat_vectors are legalized (which it uses for fixed-length vectors!). I adjusted this patch to remove the implicit-truncation support. This hides the Hexagon bug for now and unblocks the rest of the change.

Original commit message:

This is the SelectionDAG equivalent of D136470, and is thus an alternate patch to D128159.

The basic idea here is that, for scalable vectors, we track a single lane which corresponds to an unknown number of lanes at runtime. This is enough for us to perform lane-wise reasoning on many arithmetic operations.
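As a rough sketch of the core change (mirroring the SelectionDAG.cpp hunk below, using the same names from that hunk), the demanded-elements computation ends up looking like this: a fixed-length vector gets one demanded bit per element, while a scalable vector gets a single bit that implicitly covers all lanes.

    KnownBits SelectionDAG::computeKnownBits(SDValue Op, unsigned Depth) const {
      EVT VT = Op.getValueType();
      // One demanded bit per element for fixed-length vectors; a single bit,
      // implicitly broadcast to every lane, for scalable vectors.
      APInt DemandedElts = VT.isFixedLengthVector()
                               ? APInt::getAllOnes(VT.getVectorNumElements())
                               : APInt(1, 1);
      return computeKnownBits(Op, DemandedElts, Depth);
    }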

This patch also includes an implementation for SPLAT_VECTOR since, without it, the lane-wise reasoning has no base case. The original patch which inspired this (D128159) also included STEP_VECTOR; I plan to handle that in a separate patch.
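For context, the SPLAT_VECTOR base case added below reduces to roughly the following simplified sketch; the committed hunk additionally asserts the implicit-truncation invariant and carries a FIXME for the Hexagon workaround described above.

    case ISD::SPLAT_VECTOR: {
      SDValue SrcOp = Op.getOperand(0);
      // All lanes hold the splatted scalar, so its known bits apply directly
      // to the single tracked lane; wider (implicitly truncating) splat
      // operands are skipped for now.
      if (SrcOp.getValueSizeInBits() == BitWidth)
        Known = computeKnownBits(SrcOp, Depth + 1);
      break;
    }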

Differential Revision: https://reviews.llvm.org/D137140
preames committed Dec 5, 2022
1 parent 6887cfb commit 7969ab8
Showing 4 changed files with 66 additions and 58 deletions.
64 changes: 48 additions & 16 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2910,14 +2910,10 @@ const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
KnownBits SelectionDAG::computeKnownBits(SDValue Op, unsigned Depth) const {
EVT VT = Op.getValueType();

// TOOD: Until we have a plan for how to represent demanded elements for
// scalable vectors, we can just bail out for now.
if (Op.getValueType().isScalableVector()) {
unsigned BitWidth = Op.getScalarValueSizeInBits();
return KnownBits(BitWidth);
}

APInt DemandedElts = VT.isVector()
// Since the number of lanes in a scalable vector is unknown at compile time,
// we track one bit which is implicitly broadcast to all lanes. This means
// that all lanes in a scalable vector are considered demanded.
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
return computeKnownBits(Op, DemandedElts, Depth);
@@ -2932,11 +2928,6 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,

KnownBits Known(BitWidth); // Don't know anything.

// TOOD: Until we have a plan for how to represent demanded elements for
// scalable vectors, we can just bail out for now.
if (Op.getValueType().isScalableVector())
return Known;

if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
// We know all of the bits for a constant!
return KnownBits::makeConstant(C->getAPIntValue());
@@ -2951,7 +2942,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,

KnownBits Known2;
unsigned NumElts = DemandedElts.getBitWidth();
assert((!Op.getValueType().isVector() ||
assert((!Op.getValueType().isFixedLengthVector() ||
NumElts == Op.getValueType().getVectorNumElements()) &&
"Unexpected vector size");

@@ -2963,7 +2954,23 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::MERGE_VALUES:
return computeKnownBits(Op.getOperand(Op.getResNo()), DemandedElts,
Depth + 1);
case ISD::SPLAT_VECTOR: {
SDValue SrcOp = Op.getOperand(0);
if (SrcOp.getValueSizeInBits() != BitWidth) {
assert(SrcOp.getValueSizeInBits() > BitWidth &&
"Expected SPLAT_VECTOR implicit truncation");
// FIXME: We should be able to truncate the known bits here to match
// the official semantics of SPLAT_VECTOR, but doing so exposes a
// Hexagon target bug which results in an infinite loop during
// DAGCombine. (See D137140 for repro). Once that's fixed, we can
// strengthen this.
break;
}
Known = computeKnownBits(SrcOp, Depth + 1);
break;
}
case ISD::BUILD_VECTOR:
assert(!Op.getValueType().isScalableVector());
// Collect the known bits that are shared by every demanded vector element.
Known.Zero.setAllBits(); Known.One.setAllBits();
for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
@@ -2989,6 +2996,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
}
break;
case ISD::VECTOR_SHUFFLE: {
assert(!Op.getValueType().isScalableVector());
// Collect the known bits that are shared by every vector element referenced
// by the shuffle.
APInt DemandedLHS, DemandedRHS;
@@ -3016,6 +3024,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::CONCAT_VECTORS: {
if (Op.getValueType().isScalableVector())
break;
// Split DemandedElts and test each of the demanded subvectors.
Known.Zero.setAllBits(); Known.One.setAllBits();
EVT SubVectorVT = Op.getOperand(0).getValueType();
@@ -3036,6 +3046,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::INSERT_SUBVECTOR: {
if (Op.getValueType().isScalableVector())
break;
// Demand any elements from the subvector and the remainder from the src its
// inserted into.
SDValue Src = Op.getOperand(0);
@@ -3063,7 +3075,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
// Offset the demanded elts by the subvector index.
SDValue Src = Op.getOperand(0);
// Bail until we can represent demanded elements for scalable vectors.
if (Src.getValueType().isScalableVector())
if (Op.getValueType().isScalableVector() || Src.getValueType().isScalableVector())
break;
uint64_t Idx = Op.getConstantOperandVal(1);
unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
@@ -3072,6 +3084,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::SCALAR_TO_VECTOR: {
if (Op.getValueType().isScalableVector())
break;
// We know about scalar_to_vector as much as we know about it source,
// which becomes the first element of otherwise unknown vector.
if (DemandedElts != 1)
@@ -3085,6 +3099,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::BITCAST: {
if (Op.getValueType().isScalableVector())
break;

SDValue N0 = Op.getOperand(0);
EVT SubVT = N0.getValueType();
unsigned SubBitWidth = SubVT.getScalarSizeInBits();
@@ -3406,7 +3423,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
if (ISD::isNON_EXTLoad(LD) && Cst) {
// Determine any common known bits from the loaded constant pool value.
Type *CstTy = Cst->getType();
if ((NumElts * BitWidth) == CstTy->getPrimitiveSizeInBits()) {
if ((NumElts * BitWidth) == CstTy->getPrimitiveSizeInBits() &&
!Op.getValueType().isScalableVector()) {
// If its a vector splat, then we can (quickly) reuse the scalar path.
// NOTE: We assume all elements match and none are UNDEF.
if (CstTy->isVectorTy()) {
@@ -3480,6 +3498,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::ZERO_EXTEND_VECTOR_INREG: {
if (Op.getValueType().isScalableVector())
break;
EVT InVT = Op.getOperand(0).getValueType();
APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3492,6 +3512,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::SIGN_EXTEND_VECTOR_INREG: {
if (Op.getValueType().isScalableVector())
break;
EVT InVT = Op.getOperand(0).getValueType();
APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3508,6 +3530,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::ANY_EXTEND_VECTOR_INREG: {
if (Op.getValueType().isScalableVector())
break;
EVT InVT = Op.getOperand(0).getValueType();
APInt InDemandedElts = DemandedElts.zext(InVT.getVectorNumElements());
Known = computeKnownBits(Op.getOperand(0), InDemandedElts, Depth + 1);
@@ -3673,6 +3697,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
}
case ISD::INSERT_VECTOR_ELT: {
if (Op.getValueType().isScalableVector())
break;

// If we know the element index, split the demand between the
// source vector and the inserted element, otherwise assume we need
// the original demanded vector elements and the value.
@@ -3839,6 +3866,11 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::INTRINSIC_WO_CHAIN:
case ISD::INTRINSIC_W_CHAIN:
case ISD::INTRINSIC_VOID:
// TODO: Probably okay to remove after audit; here to reduce change size
// in initial enablement patch for scalable vectors
if (Op.getValueType().isScalableVector())
break;

// Allow the target to implement this method for its nodes.
TLI->computeKnownBitsForTargetNode(Op, Known, DemandedElts, *this, Depth);
break;
3 changes: 2 additions & 1 deletion llvm/test/CodeGen/AArch64/sve-intrinsics-index.ll
@@ -55,7 +55,8 @@ define <vscale x 2 x i64> @index_ii_range() {
define <vscale x 8 x i16> @index_ii_range_combine(i16 %a) {
; CHECK-LABEL: index_ii_range_combine:
; CHECK: // %bb.0:
; CHECK-NEXT: index z0.h, #2, #8
; CHECK-NEXT: index z0.h, #0, #8
; CHECK-NEXT: orr z0.h, z0.h, #0x2
; CHECK-NEXT: ret
%val = insertelement <vscale x 8 x i16> poison, i16 2, i32 0
%val1 = shufflevector <vscale x 8 x i16> %val, <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/AArch64/sve-intrinsics-perm-select.ll
@@ -574,7 +574,7 @@ define <vscale x 2 x i64> @dupq_i64_range(<vscale x 2 x i64> %a) {
; CHECK: // %bb.0:
; CHECK-NEXT: index z1.d, #0, #1
; CHECK-NEXT: and z1.d, z1.d, #0x1
; CHECK-NEXT: add z1.d, z1.d, #8 // =0x8
; CHECK-NEXT: orr z1.d, z1.d, #0x8
; CHECK-NEXT: tbl z0.d, { z0.d }, z1.d
; CHECK-NEXT: ret
%out = call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %a, i64 4)
55 changes: 15 additions & 40 deletions llvm/test/CodeGen/AArch64/sve-umulo-sdnode.ll
@@ -9,15 +9,10 @@ define <vscale x 2 x i8> @umulo_nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: and z1.d, z1.d, #0xff
; CHECK-NEXT: and z0.d, z0.d, #0xff
; CHECK-NEXT: movprfx z2, z0
; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
; CHECK-NEXT: umulh z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT: lsr z1.d, z2.d, #8
; CHECK-NEXT: cmpne p1.d, p0/z, z0.d, #0
; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT: lsr z1.d, z0.d, #8
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
; CHECK-NEXT: mov z2.d, p0/m, #0 // =0x0
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 2 x i8>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y)
%b = extractvalue { <vscale x 2 x i8>, <vscale x 2 x i1> } %a, 0
@@ -34,15 +29,10 @@ define <vscale x 4 x i8> @umulo_nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: and z1.s, z1.s, #0xff
; CHECK-NEXT: and z0.s, z0.s, #0xff
; CHECK-NEXT: movprfx z2, z0
; CHECK-NEXT: mul z2.s, p0/m, z2.s, z1.s
; CHECK-NEXT: umulh z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT: lsr z1.s, z2.s, #8
; CHECK-NEXT: cmpne p1.s, p0/z, z0.s, #0
; CHECK-NEXT: mul z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT: lsr z1.s, z0.s, #8
; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, #0
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
; CHECK-NEXT: mov z2.s, p0/m, #0 // =0x0
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: mov z0.s, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 4 x i8>, <vscale x 4 x i1> } @llvm.umul.with.overflow.nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %y)
%b = extractvalue { <vscale x 4 x i8>, <vscale x 4 x i1> } %a, 0
@@ -164,15 +154,10 @@ define <vscale x 2 x i16> @umulo_nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i1
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: and z1.d, z1.d, #0xffff
; CHECK-NEXT: and z0.d, z0.d, #0xffff
; CHECK-NEXT: movprfx z2, z0
; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
; CHECK-NEXT: umulh z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT: lsr z1.d, z2.d, #16
; CHECK-NEXT: cmpne p1.d, p0/z, z0.d, #0
; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT: lsr z1.d, z0.d, #16
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
; CHECK-NEXT: mov z2.d, p0/m, #0 // =0x0
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 2 x i16>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y)
%b = extractvalue { <vscale x 2 x i16>, <vscale x 2 x i1> } %a, 0
@@ -189,15 +174,10 @@ define <vscale x 4 x i16> @umulo_nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i1
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: and z1.s, z1.s, #0xffff
; CHECK-NEXT: and z0.s, z0.s, #0xffff
; CHECK-NEXT: movprfx z2, z0
; CHECK-NEXT: mul z2.s, p0/m, z2.s, z1.s
; CHECK-NEXT: umulh z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT: lsr z1.s, z2.s, #16
; CHECK-NEXT: cmpne p1.s, p0/z, z0.s, #0
; CHECK-NEXT: mul z0.s, p0/m, z0.s, z1.s
; CHECK-NEXT: lsr z1.s, z0.s, #16
; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, #0
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
; CHECK-NEXT: mov z2.s, p0/m, #0 // =0x0
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: mov z0.s, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 4 x i16>, <vscale x 4 x i1> } @llvm.umul.with.overflow.nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y)
%b = extractvalue { <vscale x 4 x i16>, <vscale x 4 x i1> } %a, 0
@@ -294,15 +274,10 @@ define <vscale x 2 x i32> @umulo_nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i3
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: and z1.d, z1.d, #0xffffffff
; CHECK-NEXT: and z0.d, z0.d, #0xffffffff
; CHECK-NEXT: movprfx z2, z0
; CHECK-NEXT: mul z2.d, p0/m, z2.d, z1.d
; CHECK-NEXT: umulh z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT: lsr z1.d, z2.d, #32
; CHECK-NEXT: cmpne p1.d, p0/z, z0.d, #0
; CHECK-NEXT: mul z0.d, p0/m, z0.d, z1.d
; CHECK-NEXT: lsr z1.d, z0.d, #32
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
; CHECK-NEXT: sel p0.b, p0, p0.b, p1.b
; CHECK-NEXT: mov z2.d, p0/m, #0 // =0x0
; CHECK-NEXT: mov z0.d, z2.d
; CHECK-NEXT: mov z0.d, p0/m, #0 // =0x0
; CHECK-NEXT: ret
%a = call { <vscale x 2 x i32>, <vscale x 2 x i1> } @llvm.umul.with.overflow.nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i32> %y)
%b = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i1> } %a, 0
