46 changes: 41 additions & 5 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1962,10 +1962,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);

// We can lower types that have <vscale x {2|4}> elements to compact.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
for (auto VT : {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64,
MVT::nxv2f32, MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16,
MVT::nxv4i32, MVT::nxv4f32}) {
setOperationAction(ISD::MSTORE, VT, Custom);
// Use a custom lowering for masked stores that could be a supported
// compressing store. Note: These types still use the normal (Legal)
// lowering for non-compressing masked stores.
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
}

// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
@@ -7740,7 +7745,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::STORE:
return LowerSTORE(Op, DAG);
case ISD::MSTORE:
return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
return LowerMSTORE(Op, DAG);
case ISD::MGATHER:
return LowerMGATHER(Op, DAG);
case ISD::MSCATTER:
@@ -30180,6 +30185,36 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
Store->isTruncatingStore());
}

SDValue AArch64TargetLowering::LowerMSTORE(SDValue Op,
                                           SelectionDAG &DAG) const {
  SDLoc DL(Op);
  auto *Store = cast<MaskedStoreSDNode>(Op);
  EVT VT = Store->getValue().getValueType();
  if (VT.isFixedLengthVector())
    return LowerFixedLengthVectorMStoreToSVE(Op, DAG);

  if (!Store->isCompressingStore())
    return SDValue();

  EVT MaskVT = Store->getMask().getValueType();

  SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
  SDValue CntActive =
      DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i64, Store->getMask());
Comment on lines +30202 to +30203

paulwalker-arm (Collaborator) commented on Nov 17, 2025:

While ISD::VECREDUCE_ADD supports a larger result type than its operand's element type, the extra bits are undefined. In this case that means only the bottom bit of the result can be relied upon, with the operation likely converted to an ISD::VECREDUCE_XOR. Perhaps this is another use case that ISD::PARTIAL_REDUCE_#MLA, with its implicit operand extension, can solve?

Fixing the above might make this request impossible, but can this be done as a target-agnostic expansion?

MacDue (Member, PR author) commented on Nov 17, 2025:

I'm not sure about using ISD::PARTIAL_REDUCE_#MLA, but the following works:

  EVT MaskExtVT = getPromotedVTForPredicate(MaskVT);
  EVT MaskReduceVT = MaskExtVT.getScalarType();

  SDValue MaskExt =
      DAG.getNode(ISD::ZERO_EXTEND, DL, MaskExtVT, Store->getMask());
  SDValue CntActive =
      DAG.getNode(ISD::VECREDUCE_ADD, DL, MaskReduceVT, MaskExt);
  if (MaskReduceVT != MVT::i64)
    CntActive = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CntActive);

If we also define that the demanded bits of aarch64_sve_cntp are at most 9 (max value 256, AFAIK), the ZERO_EXTEND folds away.

(This actually improves codegen, as the ptrue is folded into the cntp.)
  SDValue CompressedValue =
      DAG.getNode(ISD::VECTOR_COMPRESS, DL, VT, Store->getValue(),
                  Store->getMask(), DAG.getPOISON(VT));
  SDValue CompressedMask =
      DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, MaskVT, Zero, CntActive);

  return DAG.getMaskedStore(Store->getChain(), DL, CompressedValue,
                            Store->getBasePtr(), Store->getOffset(),
                            CompressedMask, Store->getMemoryVT(),
                            Store->getMemOperand(), Store->getAddressingMode(),
                            Store->isTruncatingStore(),
                            /*isCompressing=*/false);
}

SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
SDValue Op, SelectionDAG &DAG) const {
auto *Store = cast<MaskedStoreSDNode>(Op);
@@ -30194,7 +30229,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
return DAG.getMaskedStore(
Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
Mask, Store->getMemoryVT(), Store->getMemOperand(),
Store->getAddressingMode(), Store->isTruncatingStore());
Store->getAddressingMode(), Store->isTruncatingStore(),
Store->isCompressingStore());
}

SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
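For readers who don't have the SVE instruction set to hand: the DAG sequence LowerMSTORE builds above (VECTOR_COMPRESS, an active-lane count, GET_ACTIVE_LANE_MASK, then an ordinary masked store) corresponds to the compact/cntp/whilelo/st1 sequences visible in the new tests below. As a rough illustration only, not part of the patch, here is a scalar C++ model of the semantics being implemented (the helper name is mine):

  #include <cstddef>

  // Scalar model of llvm.masked.compressstore: the selected lanes of Vec are
  // stored contiguously at P, in lane order, and only that many elements of
  // memory are written. The lowering above expresses the same thing as:
  //   VECTOR_COMPRESS        (compact)   - pack selected lanes to the front
  //   VECREDUCE_ADD of mask  (cntp)      - count the active lanes (CntActive)
  //   GET_ACTIVE_LANE_MASK   (whilelo)   - mask covering the first CntActive lanes
  //   plain masked store     (st1w/st1d) - store the compacted vector
  template <typename T, std::size_t N>
  void CompressStoreModel(const T (&Vec)[N], const bool (&Mask)[N], T *P) {
    std::size_t Out = 0;
    for (std::size_t I = 0; I < N; ++I)
      if (Mask[I])
        P[Out++] = Vec[I];
  }

  // Example: In = {1, 2, 3, 4}, Mask = {true, false, true, false}
  // writes Out[0] = 1, Out[1] = 3 and leaves the rest of Out untouched.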
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -755,6 +755,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMSTORE(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;

18 changes: 12 additions & 6 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -645,37 +645,43 @@ def nontrunc_masked_store :
(masked_st node:$val, node:$ptr, undef, node:$pred), [{
return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
cast<MaskedStoreSDNode>(N)->isUnindexed() &&
!cast<MaskedStoreSDNode>(N)->isNonTemporal();
!cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
// truncating masked store fragments.
def trunc_masked_store :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(masked_st node:$val, node:$ptr, undef, node:$pred), [{
return cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
cast<MaskedStoreSDNode>(N)->isUnindexed();
cast<MaskedStoreSDNode>(N)->isUnindexed() &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
def trunc_masked_store_i8 :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(trunc_masked_store node:$val, node:$ptr, node:$pred), [{
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8 &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
def trunc_masked_store_i16 :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(trunc_masked_store node:$val, node:$ptr, node:$pred), [{
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16 &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
def trunc_masked_store_i32 :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(trunc_masked_store node:$val, node:$ptr, node:$pred), [{
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32 &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;

def non_temporal_store :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(masked_st node:$val, node:$ptr, undef, node:$pred), [{
return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
cast<MaskedStoreSDNode>(N)->isUnindexed() &&
cast<MaskedStoreSDNode>(N)->isNonTemporal();
cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
!cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;

multiclass masked_gather_scatter<PatFrags GatherScatterOp> {
23 changes: 23 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -333,6 +333,29 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
return isLegalMaskedLoadStore(DataType, Alignment);
}

bool isElementTypeLegalForCompressStore(Type *Ty) const {
if (Ty->isFloatTy() || Ty->isDoubleTy())
return true;

if (Ty->isIntegerTy(8) || Ty->isIntegerTy(16) || Ty->isIntegerTy(32) ||
Ty->isIntegerTy(64))
return true;

return false;
}

bool isLegalMaskedCompressStore(Type *DataType,
Align Alignment) const override {
ElementCount EC = cast<VectorType>(DataType)->getElementCount();
if (EC.getKnownMinValue() != 2 && EC.getKnownMinValue() != 4)
Reviewer comment (Collaborator) on the element-count check above:

That disqualifies compress stores for larger vectors with -msve-vector-bits, right? Can you loosen the restriction and also add some tests for SVE vectors > 128 bits?
return false;

if (!isElementTypeLegalForCompressStore(DataType->getScalarType()))
return false;

return isLegalMaskedLoadStore(DataType, Alignment);
}

bool isLegalMaskedGatherScatter(Type *DataType) const {
if (!ST->isSVEAvailable())
return false;
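For context on who consumes the isLegalMaskedCompressStore hook added above: middle-end code queries it through TargetTransformInfo to decide whether an llvm.masked.compressstore call is worth keeping as-is (for example, whether it must be scalarized before instruction selection). A minimal sketch of such a query, for illustration only; the wrapper function and the chosen types are mine, not part of the patch:

  #include "llvm/Analysis/TargetTransformInfo.h"
  #include "llvm/IR/DerivedTypes.h"
  #include "llvm/Support/Alignment.h"

  // With this patch, on an SVE target the first query should return true
  // (<vscale x 4 x i32>, 4-byte alignment), while the second should return
  // false (<vscale x 8 x i16> has an element count of 8, which the hook
  // currently rejects).
  static bool CompressStoreLegalityDemo(const llvm::TargetTransformInfo &TTI,
                                        llvm::LLVMContext &Ctx) {
    auto *I32x4 =
        llvm::ScalableVectorType::get(llvm::Type::getInt32Ty(Ctx), 4);
    auto *I16x8 =
        llvm::ScalableVectorType::get(llvm::Type::getInt16Ty(Ctx), 8);
    return TTI.isLegalMaskedCompressStore(I32x4, llvm::Align(4)) &&
           !TTI.isLegalMaskedCompressStore(I16x8, llvm::Align(2));
  }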
141 changes: 141 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-masked-compressstore.ll
@@ -0,0 +1,141 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s

;; Full SVE vectors (supported with +sve)

define void @test_compressstore_nxv4i32(ptr %p, <vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv4i32:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: compact z0.s, p0, z0.s
; CHECK-NEXT: cntp x8, p1, p0.s
; CHECK-NEXT: whilelo p0.s, xzr, x8
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
; CHECK-NEXT: ret
tail call void @llvm.masked.compressstore.nxv4i32(<vscale x 4 x i32> %vec, ptr align 4 %p, <vscale x 4 x i1> %mask)
ret void
}

define void @test_compressstore_nxv2i64(ptr %p, <vscale x 2 x i64> %vec, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv2i64:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: compact z0.d, p0, z0.d
; CHECK-NEXT: cntp x8, p1, p0.d
; CHECK-NEXT: whilelo p0.d, xzr, x8
; CHECK-NEXT: st1d { z0.d }, p0, [x0]
; CHECK-NEXT: ret
tail call void @llvm.masked.compressstore.nxv2i64(<vscale x 2 x i64> %vec, ptr align 8 %p, <vscale x 2 x i1> %mask)
ret void
}

define void @test_compressstore_nxv4f32(ptr %p, <vscale x 4 x float> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv4f32:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: compact z0.s, p0, z0.s
; CHECK-NEXT: cntp x8, p1, p0.s
; CHECK-NEXT: whilelo p0.s, xzr, x8
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
; CHECK-NEXT: ret
tail call void @llvm.masked.compressstore.nxv4f32(<vscale x 4 x float> %vec, ptr align 4 %p, <vscale x 4 x i1> %mask)
ret void
}

; TODO: Legal and nonstreaming check
define void @test_compressstore_nxv2f64(ptr %p, <vscale x 2 x double> %vec, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv2f64:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: compact z0.d, p0, z0.d
; CHECK-NEXT: cntp x8, p1, p0.d
; CHECK-NEXT: whilelo p0.d, xzr, x8
; CHECK-NEXT: st1d { z0.d }, p0, [x0]
; CHECK-NEXT: ret
tail call void @llvm.masked.compressstore.nxv2f64(<vscale x 2 x double> %vec, ptr align 8 %p, <vscale x 2 x i1> %mask)
ret void
}

;; SVE vector types promoted to 32/64-bit (non-exhaustive)

define void @test_compressstore_nxv2i8(ptr %p, <vscale x 2 x i8> %vec, <vscale x 2 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv2i8:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: compact z0.d, p0, z0.d
; CHECK-NEXT: cntp x8, p1, p0.d
; CHECK-NEXT: whilelo p0.d, xzr, x8
; CHECK-NEXT: st1b { z0.d }, p0, [x0]
; CHECK-NEXT: ret
tail call void @llvm.masked.compressstore.nxv2i8(<vscale x 2 x i8> %vec, ptr align 1 %p, <vscale x 2 x i1> %mask)
ret void
}

define void @test_compressstore_nxv4i16(ptr %p, <vscale x 4 x i16> %vec, <vscale x 4 x i1> %mask) {
; CHECK-LABEL: test_compressstore_nxv4i16:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: compact z0.s, p0, z0.s
; CHECK-NEXT: cntp x8, p1, p0.s
; CHECK-NEXT: whilelo p0.s, xzr, x8
; CHECK-NEXT: st1h { z0.s }, p0, [x0]
; CHECK-NEXT: ret
tail call void @llvm.masked.compressstore.nxv4i16(<vscale x 4 x i16> %vec, ptr align 2 %p, <vscale x 4 x i1> %mask)
ret void
}

;; NEON vector types (promoted to SVE)

define void @test_compressstore_v2f64(ptr %p, <2 x double> %vec, <2 x i1> %mask) {
; CHECK-LABEL: test_compressstore_v2f64:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll v1.2d, v1.2s, #0
; CHECK-NEXT: ptrue p0.d, vl2
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: shl v1.2d, v1.2d, #63
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
; CHECK-NEXT: cntp x8, p1, p0.d
; CHECK-NEXT: compact z0.d, p0, z0.d
; CHECK-NEXT: whilelo p0.d, xzr, x8
; CHECK-NEXT: st1d { z0.d }, p0, [x0]
; CHECK-NEXT: ret
tail call void @llvm.masked.compressstore.v2f64(<2 x double> %vec, ptr align 8 %p, <2 x i1> %mask)
ret void
}

define void @test_compressstore_v4i32(ptr %p, <4 x i32> %vec, <4 x i1> %mask) {
; CHECK-LABEL: test_compressstore_v4i32:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll v1.4s, v1.4h, #0
; CHECK-NEXT: ptrue p0.s, vl4
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: shl v1.4s, v1.4s, #31
; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, #0
; CHECK-NEXT: cntp x8, p1, p0.s
; CHECK-NEXT: compact z0.s, p0, z0.s
; CHECK-NEXT: whilelo p0.s, xzr, x8
; CHECK-NEXT: st1w { z0.s }, p0, [x0]
; CHECK-NEXT: ret
tail call void @llvm.masked.compressstore.v4i32(<4 x i32> %vec, ptr align 4 %p, <4 x i1> %mask)
ret void
}

define void @test_compressstore_v2i64(ptr %p, <2 x i64> %vec, <2 x i1> %mask) {
; CHECK-LABEL: test_compressstore_v2i64:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll v1.2d, v1.2s, #0
; CHECK-NEXT: ptrue p0.d, vl2
; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: shl v1.2d, v1.2d, #63
; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
; CHECK-NEXT: cntp x8, p1, p0.d
; CHECK-NEXT: compact z0.d, p0, z0.d
; CHECK-NEXT: whilelo p0.d, xzr, x8
; CHECK-NEXT: st1d { z0.d }, p0, [x0]
; CHECK-NEXT: ret
tail call void @llvm.masked.compressstore.v2i64(<2 x i64> %vec, ptr align 8 %p, <2 x i1> %mask)
ret void
}