[SVE][SelectionDAG] Use INDEX to generate matching instances of BUILD_VECTOR.

This patch starts small, only detecting sequences of the form
<a, a+n, a+2n, a+3n, ...> where a and n are ConstantSDNodes.

Differential Revision: https://reviews.llvm.org/D125194
paulwalker-arm committed Jul 26, 2022
1 parent a97bb48 commit e5c892d
Showing 5 changed files with 228 additions and 0 deletions.
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -53,6 +53,7 @@
#include <iterator>
#include <string>
#include <tuple>
#include <utility>

namespace llvm {

@@ -2079,6 +2080,11 @@ class BuildVectorSDNode : public SDNode {

  bool isConstant() const;

  /// If this BuildVector is constant and represents the numerical series
  /// <a, a+n, a+2n, a+3n, ...> where a is an integer and n is a non-zero
  /// integer, the value <a,n> is returned.
  Optional<std::pair<APInt, APInt>> isConstantSequence() const;

  /// Recast bit data \p SrcBitElements to \p DstEltSizeInBits wide elements.
  /// Undef elements are treated as zero, and entirely undefined elements are
  /// flagged in \p DstUndefElements.
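A rough usage sketch for the new accessor (illustrative, not part of this commit; `BV` is an assumed `BuildVectorSDNode *`):

if (Optional<std::pair<APInt, APInt>> Seq = BV->isConstantSequence()) {
  const APInt &Start = Seq->first;   // 'a' in <a, a+n, a+2n, ...>
  const APInt &Stride = Seq->second; // 'n'; non-zero by contract
  // ... lower to a strided-index idiom such as SVE's INDEX ...
}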
29 changes: 29 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -11690,6 +11690,35 @@ bool BuildVectorSDNode::isConstant() const {
  return true;
}

Optional<std::pair<APInt, APInt>>
BuildVectorSDNode::isConstantSequence() const {
  unsigned NumOps = getNumOperands();
  if (NumOps < 2)
    return None;

  if (!isa<ConstantSDNode>(getOperand(0)) ||
      !isa<ConstantSDNode>(getOperand(1)))
    return None;

  unsigned EltSize = getValueType(0).getScalarSizeInBits();
  APInt Start = getConstantOperandAPInt(0).trunc(EltSize);
  APInt Stride = getConstantOperandAPInt(1).trunc(EltSize) - Start;

  if (Stride.isZero())
    return None;

  for (unsigned i = 2; i < NumOps; ++i) {
    if (!isa<ConstantSDNode>(getOperand(i)))
      return None;

    APInt Val = getConstantOperandAPInt(i).trunc(EltSize);
    if (Val != (Start + (Stride * i)))
      return None;
  }

  return std::make_pair(Start, Stride);
}

bool ShuffleVectorSDNode::isSplatMask(const int *Mask, EVT VT) {
  // Find the first non-undef value in the shuffle mask.
  unsigned i, e;
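To make the detection concrete, here is a standalone sketch of the same check over plain 64-bit integers (illustration only; the committed code operates on the node's APInt operands, truncated to the element width):

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Returns <start, stride> if Elts forms <a, a+n, a+2n, ...> with n != 0.
std::optional<std::pair<int64_t, int64_t>>
detectSequence(const std::vector<int64_t> &Elts) {
  if (Elts.size() < 2)
    return std::nullopt;
  int64_t Start = Elts[0];
  int64_t Stride = Elts[1] - Start;
  if (Stride == 0)
    return std::nullopt; // splats are handled by other lowering paths
  for (size_t i = 2; i < Elts.size(); ++i)
    if (Elts[i] != Start + Stride * static_cast<int64_t>(i))
      return std::nullopt;
  return std::make_pair(Start, Stride);
}

// detectSequence({7, 9, 11, 13}) yields {7, 2}.
// detectSequence({0, 4, 1, 8}) yields nullopt (constant, but no stride),
// matching the build_vector_no_stride_v4i64 test below.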
15 changes: 15 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1634,6 +1634,7 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
  setOperationAction(ISD::ANY_EXTEND, VT, Custom);
  setOperationAction(ISD::BITCAST, VT, Custom);
  setOperationAction(ISD::BITREVERSE, VT, Custom);
  setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
  setOperationAction(ISD::BSWAP, VT, Custom);
  setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
  setOperationAction(ISD::CTLZ, VT, Custom);
@@ -11141,6 +11142,20 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                                 SelectionDAG &DAG) const {
  EVT VT = Op.getValueType();

  if (useSVEForFixedLengthVectorVT(VT)) {
    if (auto SeqInfo = cast<BuildVectorSDNode>(Op)->isConstantSequence()) {
      SDLoc DL(Op);
      EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
      SDValue Start = DAG.getConstant(SeqInfo->first, DL, ContainerVT);
      SDValue Steps = DAG.getStepVector(DL, ContainerVT, SeqInfo->second);
      SDValue Seq = DAG.getNode(ISD::ADD, DL, ContainerVT, Start, Steps);
      return convertFromScalableVector(DAG, Op.getValueType(), Seq);
    }

    // Revert to common legalisation for all other variants.
    return SDValue();
  }

  // Try to build a simple constant vector.
  Op = NormalizeBuildVector(Op, DAG);
  if (VT.isInteger()) {
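For orientation, a sketch (assumed, not actual compiler output) of what this composes for the v32i8 constant <7, 8, ..., 38> in the new tests:

// Assuming ContainerVT is nxv16i8:
//   Start = DAG.getConstant(7, DL, ContainerVT)   // splat(7)
//   Steps = DAG.getStepVector(DL, ContainerVT, 1) // <0, 1, 2, 3, ...>
//   Seq   = Start + Steps                         // <7, 8, 9, 10, ...>
// which the SVE patterns match to: index z0.b, #7, #1

INDEX's immediate operands are 5-bit signed (-16 to 15), so strides outside that range are first materialized in a register, as in the `mov x8, #-32` sequences in the tests below.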
105 changes: 105 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll
@@ -0,0 +1,105 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s -check-prefixes=CHECK,VBITS_GE_256
; RUN: llc -aarch64-sve-vector-bits-min=512 < %s | FileCheck %s -check-prefixes=CHECK,VBITS_GE_512

target triple = "aarch64-unknown-linux-gnu"

define void @masked_gather_base_plus_stride_v8f32(ptr %dst, ptr %src) #0 {
; VBITS_GE_256-LABEL: masked_gather_base_plus_stride_v8f32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z0.d, #0, #7
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: mov z1.d, z0.d
; VBITS_GE_256-NEXT: ld1w { z0.d }, p0/z, [x1, z0.d, lsl #2]
; VBITS_GE_256-NEXT: add z1.d, z1.d, #28 // =0x1c
; VBITS_GE_256-NEXT: ld1w { z1.d }, p0/z, [x1, z1.d, lsl #2]
; VBITS_GE_256-NEXT: ptrue p0.s, vl4
; VBITS_GE_256-NEXT: uzp1 z0.s, z0.s, z0.s
; VBITS_GE_256-NEXT: uzp1 z1.s, z1.s, z1.s
; VBITS_GE_256-NEXT: splice z0.s, p0, z0.s, z1.s
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_256-NEXT: ret
;
; VBITS_GE_512-LABEL: masked_gather_base_plus_stride_v8f32:
; VBITS_GE_512: // %bb.0:
; VBITS_GE_512-NEXT: index z0.d, #0, #7
; VBITS_GE_512-NEXT: ptrue p0.d, vl8
; VBITS_GE_512-NEXT: ld1w { z0.d }, p0/z, [x1, z0.d, lsl #2]
; VBITS_GE_512-NEXT: ptrue p0.s, vl8
; VBITS_GE_512-NEXT: uzp1 z0.s, z0.s, z0.s
; VBITS_GE_512-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_512-NEXT: ret
%ptrs = getelementptr float, ptr %src, <8 x i64> <i64 0, i64 7, i64 14, i64 21, i64 28, i64 35, i64 42, i64 49>
%data = tail call <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x float> undef)
store <8 x float> %data, ptr %dst, align 4
ret void
}

define void @masked_gather_base_plus_stride_v4f64(ptr %dst, ptr %src) #0 {
; CHECK-LABEL: masked_gather_base_plus_stride_v4f64:
; CHECK: // %bb.0:
; CHECK-NEXT: mov x8, #-32
; CHECK-NEXT: ptrue p0.d, vl4
; CHECK-NEXT: index z0.d, #-2, x8
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x1, z0.d, lsl #3]
; CHECK-NEXT: st1d { z0.d }, p0, [x0]
; CHECK-NEXT: ret
%ptrs = getelementptr double, ptr %src, <4 x i64> <i64 -2, i64 -34, i64 -66, i64 -98>
%data = tail call <4 x double> @llvm.masked.gather.v4f64.v4p0(<4 x ptr> %ptrs, i32 8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x double> undef)
store <4 x double> %data, ptr %dst, align 8
ret void
}

define void @masked_scatter_base_plus_stride_v8f32(ptr %dst, ptr %src) #0 {
; VBITS_GE_256-LABEL: masked_scatter_base_plus_stride_v8f32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: mov z1.d, #-28 // =0xffffffffffffffe4
; VBITS_GE_256-NEXT: ld1w { z0.s }, p0/z, [x1]
; VBITS_GE_256-NEXT: index z2.d, #0, #-7
; VBITS_GE_256-NEXT: add z1.d, z2.d, z1.d
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: uunpklo z3.d, z0.s
; VBITS_GE_256-NEXT: ext z0.b, z0.b, z0.b, #16
; VBITS_GE_256-NEXT: uunpklo z0.d, z0.s
; VBITS_GE_256-NEXT: st1w { z3.d }, p0, [x0, z2.d, lsl #2]
; VBITS_GE_256-NEXT: st1w { z0.d }, p0, [x0, z1.d, lsl #2]
; VBITS_GE_256-NEXT: ret
;
; VBITS_GE_512-LABEL: masked_scatter_base_plus_stride_v8f32:
; VBITS_GE_512: // %bb.0:
; VBITS_GE_512-NEXT: ptrue p0.s, vl8
; VBITS_GE_512-NEXT: index z1.d, #0, #-7
; VBITS_GE_512-NEXT: ld1w { z0.s }, p0/z, [x1]
; VBITS_GE_512-NEXT: ptrue p0.d, vl8
; VBITS_GE_512-NEXT: uunpklo z0.d, z0.s
; VBITS_GE_512-NEXT: st1w { z0.d }, p0, [x0, z1.d, lsl #2]
; VBITS_GE_512-NEXT: ret
%data = load <8 x float>, ptr %src, align 4
%ptrs = getelementptr float, ptr %dst, <8 x i64> <i64 0, i64 -7, i64 -14, i64 -21, i64 -28, i64 -35, i64 -42, i64 -49>
tail call void @llvm.masked.scatter.v8f32.v8p0(<8 x float> %data, <8 x ptr> %ptrs, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
ret void
}

define void @masked_scatter_base_plus_stride_v4f64(ptr %dst, ptr %src) #0 {
; CHECK-LABEL: masked_scatter_base_plus_stride_v4f64:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.d, vl4
; CHECK-NEXT: index z1.d, #-2, #3
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x1]
; CHECK-NEXT: st1d { z0.d }, p0, [x0, z1.d, lsl #3]
; CHECK-NEXT: ret
%data = load <4 x double>, ptr %src, align 8
%ptrs = getelementptr double, ptr %dst, <4 x i64> <i64 -2, i64 1, i64 4, i64 7>
tail call void @llvm.masked.scatter.v4f64.v4p0(<4 x double> %data, <4 x ptr> %ptrs, i32 8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
ret void
}

declare <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr>, i32 immarg, <8 x i1>, <8 x float>)
declare <4 x double> @llvm.masked.gather.v4f64.v4p0(<4 x ptr>, i32 immarg, <4 x i1>, <4 x double>)

declare void @llvm.masked.scatter.v8f32.v8p0(<8 x float>, <8 x ptr>, i32 immarg, <8 x i1>)
declare void @llvm.masked.scatter.v4f64.v4p0(<4 x double>, <4 x ptr>, i32 immarg, <4 x i1>)

attributes #0 = { "target-features"="+sve" }
73 changes: 73 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-fixed-length-build-vector.ll
@@ -0,0 +1,73 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s -check-prefix=VBITS_GE_256
; RUN: llc -aarch64-sve-vector-bits-min=512 < %s | FileCheck %s -check-prefix=VBITS_GE_256

target triple = "aarch64-unknown-linux-gnu"

define void @build_vector_7_inc1_v32i8(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_7_inc1_v32i8:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: ptrue p0.b, vl32
; VBITS_GE_256-NEXT: index z0.b, #7, #1
; VBITS_GE_256-NEXT: st1b { z0.b }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <32 x i8> <i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 17, i8 18, i8 19, i8 20, i8 21, i8 22, i8 23, i8 24, i8 25, i8 26, i8 27, i8 28, i8 29, i8 30, i8 31, i8 32, i8 33, i8 34, i8 35, i8 36, i8 37, i8 38>, ptr %a, align 1
ret void
}

define void @build_vector_0_inc2_v16i16(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_0_inc2_v16i16:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: ptrue p0.h, vl16
; VBITS_GE_256-NEXT: index z0.h, #0, #2
; VBITS_GE_256-NEXT: st1h { z0.h }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <16 x i16> <i16 0, i16 2, i16 4, i16 6, i16 8, i16 10, i16 12, i16 14, i16 16, i16 18, i16 20, i16 22, i16 24, i16 26, i16 28, i16 30>, ptr %a, align 2
ret void
}

; Negative const stride.
define void @build_vector_0_dec3_v8i32(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_0_dec3_v8i32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: index z0.s, #0, #-3
; VBITS_GE_256-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <8 x i32> <i32 0, i32 -3, i32 -6, i32 -9, i32 -12, i32 -15, i32 -18, i32 -21>, ptr %a, align 4
ret void
}

; Constant stride that's too big to be directly encoded into the index.
define void @build_vector_minus2_dec32_v4i64(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_minus2_dec32_v4i64:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: mov x8, #-32
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: index z0.d, #-2, x8
; VBITS_GE_256-NEXT: st1d { z0.d }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <4 x i64> <i64 -2, i64 -34, i64 -66, i64 -98>, ptr %a, align 8
ret void
}

; Constant but not a sequence.
define void @build_vector_no_stride_v4i64(ptr %a) #0 {
; VBITS_GE_256-LABEL: .LCPI4_0:
; VBITS_GE_256: .xword 0
; VBITS_GE_256-NEXT: .xword 4
; VBITS_GE_256-NEXT: .xword 1
; VBITS_GE_256-NEXT: .xword 8
; VBITS_GE_256-LABEL: build_vector_no_stride_v4i64:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: adrp x8, .LCPI4_0
; VBITS_GE_256-NEXT: add x8, x8, :lo12:.LCPI4_0
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: ld1d { z0.d }, p0/z, [x8]
; VBITS_GE_256-NEXT: st1d { z0.d }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <4 x i64> <i64 0, i64 4, i64 1, i64 8>, ptr %a, align 8
ret void
}

attributes #0 = { "target-features"="+sve" }
