Skip to content

Commit

Permalink
[SVE] Extend findMoreOptimalIndexType so BUILD_VECTORs do not force 6…
Browse files Browse the repository at this point in the history
…4bit indices.

Extends findMoreOptimalIndexType to allow ISD::BUILD_VECTOR based
indices to be truncated when such truncation is lossless. This can
enable the use of 32bit gather/scatter indices thus making it less
likely to have to split a gather/scatter in two.

Depends on D125194

Differential Revision: https://reviews.llvm.org/D130533
  • Loading branch information
paulwalker-arm committed Aug 18, 2022
1 parent 2b3be87 commit 96c8d61
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 55 deletions.
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Expand Up @@ -118,6 +118,10 @@ bool isBuildVectorOfConstantSDNodes(const SDNode *N);
/// ConstantFPSDNode or undef.
bool isBuildVectorOfConstantFPSDNodes(const SDNode *N);

/// Returns true if the specified node is a vector where all elements can
/// be truncated to the specified element size without a loss in meaning.
bool isVectorShrinkable(const SDNode *N, unsigned NewEltSize, bool Signed);

/// Return true if the node has at least one operand and all operands of the
/// specified node are ISD::UNDEF.
bool allOperandsUndef(const SDNode *N);
Expand Down
25 changes: 25 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Expand Up @@ -291,6 +291,31 @@ bool ISD::isBuildVectorOfConstantFPSDNodes(const SDNode *N) {
return true;
}

bool ISD::isVectorShrinkable(const SDNode *N, unsigned NewEltSize,
bool Signed) {
if (N->getOpcode() != ISD::BUILD_VECTOR)
return false;

unsigned EltSize = N->getValueType(0).getScalarSizeInBits();
if (EltSize <= NewEltSize)
return false;

for (const SDValue &Op : N->op_values()) {
if (Op.isUndef())
continue;
if (!isa<ConstantSDNode>(Op))
return false;

APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().trunc(EltSize);
if (Signed && C.trunc(NewEltSize).sext(EltSize) != C)
return false;
if (!Signed && C.trunc(NewEltSize).zext(EltSize) != C)
return false;
}

return true;
}

bool ISD::allOperandsUndef(const SDNode *N) {
// Return false if the node has no operands.
// This is "logically inconsistent" with the definition of "all" but
Expand Down
14 changes: 10 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Expand Up @@ -17827,12 +17827,19 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
return Changed;

// Can indices be trivially shrunk?
if (ISD::isVectorShrinkable(Index.getNode(), 32, N->isIndexSigned())) {
EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
Index = DAG.getNode(ISD::TRUNCATE, SDLoc(N), NewIndexVT, Index);
return true;
}

// Match:
// Index = step(const)
int64_t Stride = 0;
if (Index.getOpcode() == ISD::STEP_VECTOR)
if (Index.getOpcode() == ISD::STEP_VECTOR) {
Stride = cast<ConstantSDNode>(Index.getOperand(0))->getSExtValue();

}
// Match:
// Index = step(const) << shift(const)
else if (Index.getOpcode() == ISD::SHL &&
Expand Down Expand Up @@ -17866,8 +17873,7 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
// Stride does not scale explicitly by 'Scale', because it happens in
// the gather/scatter addressing mode.
Index = DAG.getNode(ISD::STEP_VECTOR, SDLoc(N), NewIndexVT,
DAG.getTargetConstant(Stride, SDLoc(N), MVT::i32));
Index = DAG.getStepVector(SDLoc(N), NewIndexVT, APInt(32, Stride));
return true;
}

Expand Down
67 changes: 16 additions & 51 deletions llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll
@@ -1,35 +1,17 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s -check-prefixes=CHECK,VBITS_GE_256
; RUN: llc -aarch64-sve-vector-bits-min=512 < %s | FileCheck %s -check-prefixes=CHECK,VBITS_GE_512
; RUN: llc -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s
; RUN: llc -aarch64-sve-vector-bits-min=512 < %s | FileCheck %s

target triple = "aarch64-unknown-linux-gnu"

define void @masked_gather_base_plus_stride_v8f32(ptr %dst, ptr %src) #0 {
; VBITS_GE_256-LABEL: masked_gather_base_plus_stride_v8f32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z0.d, #0, #7
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: mov z1.d, z0.d
; VBITS_GE_256-NEXT: ld1w { z0.d }, p0/z, [x1, z0.d, lsl #2]
; VBITS_GE_256-NEXT: add z1.d, z1.d, #28 // =0x1c
; VBITS_GE_256-NEXT: ld1w { z1.d }, p0/z, [x1, z1.d, lsl #2]
; VBITS_GE_256-NEXT: ptrue p0.s, vl4
; VBITS_GE_256-NEXT: uzp1 z0.s, z0.s, z0.s
; VBITS_GE_256-NEXT: uzp1 z1.s, z1.s, z1.s
; VBITS_GE_256-NEXT: splice z0.s, p0, z0.s, z1.s
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_256-NEXT: ret
;
; VBITS_GE_512-LABEL: masked_gather_base_plus_stride_v8f32:
; VBITS_GE_512: // %bb.0:
; VBITS_GE_512-NEXT: index z0.d, #0, #7
; VBITS_GE_512-NEXT: ptrue p0.d, vl8
; VBITS_GE_512-NEXT: ld1w { z0.d }, p0/z, [x1, z0.d, lsl #2]
; VBITS_GE_512-NEXT: ptrue p0.s, vl8
; VBITS_GE_512-NEXT: uzp1 z0.s, z0.s, z0.s
; VBITS_GE_512-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_512-NEXT: ret
; CHECK-LABEL: masked_gather_base_plus_stride_v8f32:
; CHECK: // %bb.0:
; CHECK-NEXT: index z0.s, #0, #7
; CHECK-NEXT: ptrue p0.s, vl8
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x1, z0.s, sxtw #2]
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
; CHECK-NEXT: ret
%ptrs = getelementptr float, ptr %src, <8 x i64> <i64 0, i64 7, i64 14, i64 21, i64 28, i64 35, i64 42, i64 49>
%data = tail call <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x float> undef)
store <8 x float> %data, ptr %dst, align 4
Expand All @@ -52,30 +34,13 @@ define void @masked_gather_base_plus_stride_v4f64(ptr %dst, ptr %src) #0 {
}

define void @masked_scatter_base_plus_stride_v8f32(ptr %dst, ptr %src) #0 {
; VBITS_GE_256-LABEL: masked_scatter_base_plus_stride_v8f32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: mov z1.d, #-28 // =0xffffffffffffffe4
; VBITS_GE_256-NEXT: ld1w { z0.s }, p0/z, [x1]
; VBITS_GE_256-NEXT: index z2.d, #0, #-7
; VBITS_GE_256-NEXT: add z1.d, z2.d, z1.d
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: uunpklo z3.d, z0.s
; VBITS_GE_256-NEXT: ext z0.b, z0.b, z0.b, #16
; VBITS_GE_256-NEXT: uunpklo z0.d, z0.s
; VBITS_GE_256-NEXT: st1w { z3.d }, p0, [x0, z2.d, lsl #2]
; VBITS_GE_256-NEXT: st1w { z0.d }, p0, [x0, z1.d, lsl #2]
; VBITS_GE_256-NEXT: ret
;
; VBITS_GE_512-LABEL: masked_scatter_base_plus_stride_v8f32:
; VBITS_GE_512: // %bb.0:
; VBITS_GE_512-NEXT: ptrue p0.s, vl8
; VBITS_GE_512-NEXT: index z1.d, #0, #-7
; VBITS_GE_512-NEXT: ld1w { z0.s }, p0/z, [x1]
; VBITS_GE_512-NEXT: ptrue p0.d, vl8
; VBITS_GE_512-NEXT: uunpklo z0.d, z0.s
; VBITS_GE_512-NEXT: st1w { z0.d }, p0, [x0, z1.d, lsl #2]
; VBITS_GE_512-NEXT: ret
; CHECK-LABEL: masked_scatter_base_plus_stride_v8f32:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.s, vl8
; CHECK-NEXT: index z1.s, #0, #-7
; CHECK-NEXT: ld1w { z0.s }, p0/z, [x1]
; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
; CHECK-NEXT: ret
%data = load <8 x float>, ptr %src, align 4
%ptrs = getelementptr float, ptr %dst, <8 x i64> <i64 0, i64 -7, i64 -14, i64 -21, i64 -28, i64 -35, i64 -42, i64 -49>
tail call void @llvm.masked.scatter.v8f32.v8p0(<8 x float> %data, <8 x ptr> %ptrs, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
Expand Down

0 comments on commit 96c8d61

Please sign in to comment.