Skip to content

Commit

Permalink
[AArch64][SelectionDAG] Support passing/returning scalable vectors wi…
Browse files Browse the repository at this point in the history
…th unusual types.

This adds handling for two cases:

1. A scalable vector where the element type is promoted.
2. A scalable vector where the element count is odd (or more generally,
   not divisble by the element count of the part type).

(Some element types still don't work; for example, <vscale x 2 x i128>,
or <vscale x 2 x fp128>.)

Differential Revision: https://reviews.llvm.org/D105591
  • Loading branch information
efriedma-quic committed Aug 2, 2021
1 parent b40a2a5 commit 1f62af6
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 29 deletions.
44 changes: 25 additions & 19 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Expand Up @@ -399,29 +399,31 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
return Val;

if (PartEVT.isVector()) {
// Vector/Vector bitcast.
if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);

// If the element type of the source/dest vectors are the same, but the
// parts vector has more elements than the value vector, then we have a
// vector widening case (e.g. <2 x float> -> <4 x float>). Extract the
// elements we want.
if (PartEVT.getVectorElementType() == ValueVT.getVectorElementType()) {
if (PartEVT.getVectorElementCount() != ValueVT.getVectorElementCount()) {
assert((PartEVT.getVectorElementCount().getKnownMinValue() >
ValueVT.getVectorElementCount().getKnownMinValue()) &&
(PartEVT.getVectorElementCount().isScalable() ==
ValueVT.getVectorElementCount().isScalable()) &&
"Cannot narrow, it would be a lossy transformation");
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ValueVT, Val,
DAG.getVectorIdxConstant(0, DL));
PartEVT =
EVT::getVectorVT(*DAG.getContext(), PartEVT.getVectorElementType(),
ValueVT.getVectorElementCount());
Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartEVT, Val,
DAG.getVectorIdxConstant(0, DL));
if (PartEVT == ValueVT)
return Val;
}

// Vector/Vector bitcast.
if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);

assert(PartEVT.getVectorElementCount() == ValueVT.getVectorElementCount() &&
"Cannot handle this kind of promotion");
// Promoted vector extract
return DAG.getAnyExtOrTrunc(Val, DL, ValueVT);

}

// Trivial bitcast if the types are the same size and the destination
Expand Down Expand Up @@ -726,15 +728,19 @@ static void getCopyToPartsVector(SelectionDAG &DAG, const SDLoc &DL,
} else if (ValueVT.getSizeInBits() == BuiltVectorTy.getSizeInBits()) {
// Bitconvert vector->vector case.
Val = DAG.getNode(ISD::BITCAST, DL, BuiltVectorTy, Val);
} else if (SDValue Widened =
widenVectorToPartType(DAG, Val, DL, BuiltVectorTy)) {
Val = Widened;
} else if (BuiltVectorTy.getVectorElementType().bitsGE(
ValueVT.getVectorElementType()) &&
BuiltVectorTy.getVectorElementCount() ==
ValueVT.getVectorElementCount()) {
// Promoted vector extract
Val = DAG.getAnyExtOrTrunc(Val, DL, BuiltVectorTy);
} else {
if (BuiltVectorTy.getVectorElementType().bitsGT(
ValueVT.getVectorElementType())) {
// Integer promotion.
ValueVT = EVT::getVectorVT(*DAG.getContext(),
BuiltVectorTy.getVectorElementType(),
ValueVT.getVectorElementCount());
Val = DAG.getNode(ISD::ANY_EXTEND, DL, ValueVT, Val);
}

if (SDValue Widened = widenVectorToPartType(DAG, Val, DL, BuiltVectorTy)) {
Val = Widened;
}
}

assert(Val.getValueType() == BuiltVectorTy && "Unexpected vector value type");
Expand Down
18 changes: 8 additions & 10 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Expand Up @@ -1556,7 +1556,7 @@ unsigned TargetLoweringBase::getVectorTypeBreakdown(LLVMContext &Context,

// Scalable vectors cannot be scalarized, so handle the legalisation of the
// types like done elsewhere in SelectionDAG.
if (VT.isScalableVector() && !isPowerOf2_32(EltCnt.getKnownMinValue())) {
if (EltCnt.isScalable()) {
LegalizeKind LK;
EVT PartVT = VT;
do {
Expand All @@ -1565,16 +1565,14 @@ unsigned TargetLoweringBase::getVectorTypeBreakdown(LLVMContext &Context,
PartVT = LK.second;
} while (LK.first != TypeLegal);

NumIntermediates = VT.getVectorElementCount().getKnownMinValue() /
PartVT.getVectorElementCount().getKnownMinValue();
if (!PartVT.isVector()) {
report_fatal_error(
"Don't know how to legalize this scalable vector type");
}

// FIXME: This code needs to be extended to handle more complex vector
// breakdowns, like nxv7i64 -> nxv8i64 -> 4 x nxv2i64. Currently the only
// supported cases are vectors that are broken down into equal parts
// such as nxv6i64 -> 3 x nxv2i64.
assert((PartVT.getVectorElementCount() * NumIntermediates) ==
VT.getVectorElementCount() &&
"Expected an integer multiple of PartVT");
NumIntermediates =
divideCeil(VT.getVectorElementCount().getKnownMinValue(),
PartVT.getVectorElementCount().getKnownMinValue());
IntermediateVT = PartVT;
RegisterVT = getRegisterType(Context, IntermediateVT);
return NumIntermediates;
Expand Down
61 changes: 61 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll
Expand Up @@ -689,3 +689,64 @@ L1:
L2:
ret <vscale x 8 x double> %illegal
}

define <vscale x 8 x i63> @wide_8i63(i1 %b, <vscale x 16 x i8> %legal, <vscale x 8 x i63> %illegal) nounwind {
; CHECK-LABEL: wide_8i63:
; CHECK: // %bb.0:
; CHECK-NEXT: tbnz w0, #0, .LBB21_2
; CHECK-NEXT: // %bb.1: // %L2
; CHECK-NEXT: mov z0.d, z1.d
; CHECK-NEXT: mov z1.d, z2.d
; CHECK-NEXT: mov z2.d, z3.d
; CHECK-NEXT: mov z3.d, z4.d
; CHECK-NEXT: ret
; CHECK-NEXT: .LBB21_2: // %L1
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: bl bar
br i1 %b, label %L1, label %L2
L1:
call aarch64_sve_vector_pcs void @bar()
unreachable
L2:
ret <vscale x 8 x i63> %illegal
}

define <vscale x 7 x i63> @wide_7i63(i1 %b, <vscale x 16 x i8> %legal, <vscale x 7 x i63> %illegal) nounwind {
; CHECK-LABEL: wide_7i63:
; CHECK: // %bb.0:
; CHECK-NEXT: tbnz w0, #0, .LBB22_2
; CHECK-NEXT: // %bb.1: // %L2
; CHECK-NEXT: mov z0.d, z1.d
; CHECK-NEXT: mov z1.d, z2.d
; CHECK-NEXT: mov z2.d, z3.d
; CHECK-NEXT: mov z3.d, z4.d
; CHECK-NEXT: ret
; CHECK-NEXT: .LBB22_2: // %L1
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: bl bar
br i1 %b, label %L1, label %L2
L1:
call aarch64_sve_vector_pcs void @bar()
unreachable
L2:
ret <vscale x 7 x i63> %illegal
}

define <vscale x 7 x i31> @wide_7i31(i1 %b, <vscale x 16 x i8> %legal, <vscale x 7 x i31> %illegal) nounwind {
; CHECK-LABEL: wide_7i31:
; CHECK: // %bb.0:
; CHECK-NEXT: tbnz w0, #0, .LBB23_2
; CHECK-NEXT: // %bb.1: // %L2
; CHECK-NEXT: mov z0.d, z1.d
; CHECK-NEXT: mov z1.d, z2.d
; CHECK-NEXT: ret
; CHECK-NEXT: .LBB23_2: // %L1
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: bl bar
br i1 %b, label %L1, label %L2
L1:
call aarch64_sve_vector_pcs void @bar()
unreachable
L2:
ret <vscale x 7 x i31> %illegal
}

0 comments on commit 1f62af6

Please sign in to comment.