[AArch64][SVE] Add basic support for @llvm.masked.compressstore
#168350
base: main
Conversation
This patch adds SVE support for the `masked.compressstore` intrinsic via the existing `VECTOR_COMPRESS` lowering and compressing the store mask via `VECREDUCE_ADD`. Currently, only `nxv4[i32|f32]` and `nxv2[i64|f64]` are directly supported, with other types promoted to these, where possible. This is done in preparation for LV support of this intrinsic, which is currently being worked on in llvm#140723.
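For illustration, a compressing store such as the following (the same pattern as test_compressstore_nxv4i32 in the new test file) is lowered to compact (the VECTOR_COMPRESS), cntp (counting the active lanes), whilelo (building the new store predicate) and a contiguous st1w:

define void @example(ptr %p, <vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask) {
  tail call void @llvm.masked.compressstore.nxv4i32(<vscale x 4 x i32> %vec, ptr align 4 %p, <vscale x 4 x i1> %mask)
  ret void
}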
@llvm/pr-subscribers-backend-aarch64

Author: Benjamin Maxwell (MacDue)

Changes

This patch adds SVE support for the `masked.compressstore` intrinsic via the existing `VECTOR_COMPRESS` lowering and compressing the store mask via `VECREDUCE_ADD`. Currently, only `nxv4[i32|f32]` and `nxv2[i64|f64]` are directly supported, with other types promoted to these, where possible.

This is done in preparation for LV support of this intrinsic, which is currently being worked on in #140723.

Full diff: https://github.com/llvm/llvm-project/pull/168350.diff

5 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 35836af3c874b..67bccb46939dc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1962,10 +1962,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
// We can lower types that have <vscale x {2|4}> elements to compact.
- for (auto VT :
- {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
- MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+ for (auto VT : {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64,
+ MVT::nxv2f32, MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16,
+ MVT::nxv4i32, MVT::nxv4f32}) {
+ setOperationAction(ISD::MSTORE, VT, Custom);
+ // Use a custom lowering for masked stores that could be a supported
+ // compressing store. Note: These types still use the normal (Legal)
+ // lowering for non-compressing masked stores.
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
+ }
// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
@@ -7740,7 +7745,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::STORE:
return LowerSTORE(Op, DAG);
case ISD::MSTORE:
- return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
+ return LowerMSTORE(Op, DAG);
case ISD::MGATHER:
return LowerMGATHER(Op, DAG);
case ISD::MSCATTER:
@@ -30180,6 +30185,36 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
Store->isTruncatingStore());
}
+SDValue AArch64TargetLowering::LowerMSTORE(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ auto *Store = cast<MaskedStoreSDNode>(Op);
+ EVT VT = Store->getValue().getValueType();
+ if (VT.isFixedLengthVector())
+ return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
+
+ if (!Store->isCompressingStore())
+ return SDValue();
+
+ EVT MaskVT = Store->getMask().getValueType();
+
+ SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+ SDValue CntActive =
+ DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i64, Store->getMask());
+ SDValue CompressedValue =
+ DAG.getNode(ISD::VECTOR_COMPRESS, DL, VT, Store->getValue(),
+ Store->getMask(), DAG.getPOISON(VT));
+ SDValue CompressedMask =
+ DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, MaskVT, Zero, CntActive);
+
+ return DAG.getMaskedStore(Store->getChain(), DL, CompressedValue,
+ Store->getBasePtr(), Store->getOffset(),
+ CompressedMask, Store->getMemoryVT(),
+ Store->getMemOperand(), Store->getAddressingMode(),
+ Store->isTruncatingStore(),
+ /*isCompressing=*/false);
+}
+
SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
SDValue Op, SelectionDAG &DAG) const {
auto *Store = cast<MaskedStoreSDNode>(Op);
@@ -30194,7 +30229,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
return DAG.getMaskedStore(
Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
Mask, Store->getMemoryVT(), Store->getMemOperand(),
- Store->getAddressingMode(), Store->isTruncatingStore());
+ Store->getAddressingMode(), Store->isTruncatingStore(),
+ Store->isCompressingStore());
}
SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 70bfae717fb76..8fcef502cdab7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -755,6 +755,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerMSTORE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 50a3a4ab8d8b6..7e3643da4a2fd 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -645,29 +645,34 @@ def nontrunc_masked_store :
(masked_st node:$val, node:$ptr, undef, node:$pred), [{
return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
cast<MaskedStoreSDNode>(N)->isUnindexed() &&
- !cast<MaskedStoreSDNode>(N)->isNonTemporal();
+ !cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
+ !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
// truncating masked store fragments.
def trunc_masked_store :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(masked_st node:$val, node:$ptr, undef, node:$pred), [{
return cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
- cast<MaskedStoreSDNode>(N)->isUnindexed();
+ cast<MaskedStoreSDNode>(N)->isUnindexed() &&
+ !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
def trunc_masked_store_i8 :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(trunc_masked_store node:$val, node:$ptr, node:$pred), [{
- return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+ return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8 &&
+ !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
def trunc_masked_store_i16 :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(trunc_masked_store node:$val, node:$ptr, node:$pred), [{
- return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+ return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16 &&
+ !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
def trunc_masked_store_i32 :
PatFrag<(ops node:$val, node:$ptr, node:$pred),
(trunc_masked_store node:$val, node:$ptr, node:$pred), [{
- return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+ return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32 &&
+ !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
def non_temporal_store :
@@ -675,7 +680,8 @@ def non_temporal_store :
(masked_st node:$val, node:$ptr, undef, node:$pred), [{
return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
cast<MaskedStoreSDNode>(N)->isUnindexed() &&
- cast<MaskedStoreSDNode>(N)->isNonTemporal();
+ cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
+ !cast<MaskedStoreSDNode>(N)->isCompressingStore();
}]>;
multiclass masked_gather_scatter<PatFrags GatherScatterOp> {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index d189f563f99a1..15ec8629de787 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -333,6 +333,29 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
return isLegalMaskedLoadStore(DataType, Alignment);
}
+ bool isElementTypeLegalForCompressStore(Type *Ty) const {
+ if (Ty->isFloatTy() || Ty->isDoubleTy())
+ return true;
+
+ if (Ty->isIntegerTy(8) || Ty->isIntegerTy(16) || Ty->isIntegerTy(32) ||
+ Ty->isIntegerTy(64))
+ return true;
+
+ return false;
+ }
+
+ bool isLegalMaskedCompressStore(Type *DataType,
+ Align Alignment) const override {
+ ElementCount EC = cast<VectorType>(DataType)->getElementCount();
+ if (EC.getKnownMinValue() != 2 && EC.getKnownMinValue() != 4)
+ return false;
+
+ if (!isElementTypeLegalForCompressStore(DataType->getScalarType()))
+ return false;
+
+ return isLegalMaskedLoadStore(DataType, Alignment);
+ }
+
bool isLegalMaskedGatherScatter(Type *DataType) const {
if (!ST->isSVEAvailable())
return false;
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-compressstore.ll b/llvm/test/CodeGen/AArch64/sve-masked-compressstore.ll
new file mode 100644
index 0000000000000..1be5b1d1fbb6d
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-masked-compressstore.ll
@@ -0,0 +1,141 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s
+
+;; Full SVE vectors (supported with +sve)
+
+define void @test_compressstore_nxv4i32(ptr %p, <vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv4i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: cntp x8, p1, p0.s
+; CHECK-NEXT: whilelo p0.s, xzr, x8
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+; CHECK-NEXT: ret
+ tail call void @llvm.masked.compressstore.nxv4i32(<vscale x 4 x i32> %vec, ptr align 4 %p, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @test_compressstore_nxv2i64(ptr %p, <vscale x 2 x i64> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv2i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: compact z0.d, p0, z0.d
+; CHECK-NEXT: cntp x8, p1, p0.d
+; CHECK-NEXT: whilelo p0.d, xzr, x8
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+; CHECK-NEXT: ret
+ tail call void @llvm.masked.compressstore.nxv2i64(<vscale x 2 x i64> %vec, ptr align 8 %p, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @test_compressstore_nxv4f32(ptr %p, <vscale x 4 x float> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv4f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: cntp x8, p1, p0.s
+; CHECK-NEXT: whilelo p0.s, xzr, x8
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+; CHECK-NEXT: ret
+ tail call void @llvm.masked.compressstore.nxv4f32(<vscale x 4 x float> %vec, ptr align 4 %p, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+; TODO: Legal and nonstreaming check
+define void @test_compressstore_nxv2f64(ptr %p, <vscale x 2 x double> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv2f64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: compact z0.d, p0, z0.d
+; CHECK-NEXT: cntp x8, p1, p0.d
+; CHECK-NEXT: whilelo p0.d, xzr, x8
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+; CHECK-NEXT: ret
+ tail call void @llvm.masked.compressstore.nxv2f64(<vscale x 2 x double> %vec, ptr align 8 %p, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+;; Promoted SVE vector types promoted to 32/64-bit (non-exhaustive)
+
+define void @test_compressstore_nxv2i8(ptr %p, <vscale x 2 x i8> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv2i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: compact z0.d, p0, z0.d
+; CHECK-NEXT: cntp x8, p1, p0.d
+; CHECK-NEXT: whilelo p0.d, xzr, x8
+; CHECK-NEXT: st1b { z0.d }, p0, [x0]
+; CHECK-NEXT: ret
+ tail call void @llvm.masked.compressstore.nxv2i8(<vscale x 2 x i8> %vec, ptr align 1 %p, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @test_compressstore_nxv4i16(ptr %p, <vscale x 4 x i16> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv4i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: cntp x8, p1, p0.s
+; CHECK-NEXT: whilelo p0.s, xzr, x8
+; CHECK-NEXT: st1h { z0.s }, p0, [x0]
+; CHECK-NEXT: ret
+ tail call void @llvm.masked.compressstore.nxv4i16(<vscale x 4 x i16> %vec, ptr align 2 %p, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+;; NEON vector types (promoted to SVE)
+
+define void @test_compressstore_v2f32(ptr %p, <2 x double> %vec, <2 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_v2f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.2d, v1.2s, #0
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: shl v1.2d, v1.2d, #63
+; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT: cntp x8, p1, p0.d
+; CHECK-NEXT: compact z0.d, p0, z0.d
+; CHECK-NEXT: whilelo p0.d, xzr, x8
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+; CHECK-NEXT: ret
+ tail call void @llvm.masked.compressstore.v2f64(<2 x double> %vec, ptr align 8 %p, <2 x i1> %mask)
+ ret void
+}
+
+define void @test_compressstore_v4i32(ptr %p, <4 x i32> %vec, <4 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_v4i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: shl v1.4s, v1.4s, #31
+; CHECK-NEXT: cmpne p0.s, p0/z, z1.s, #0
+; CHECK-NEXT: cntp x8, p1, p0.s
+; CHECK-NEXT: compact z0.s, p0, z0.s
+; CHECK-NEXT: whilelo p0.s, xzr, x8
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+; CHECK-NEXT: ret
+ tail call void @llvm.masked.compressstore.v4i32(<4 x i32> %vec, ptr align 4 %p, <4 x i1> %mask)
+ ret void
+}
+
+define void @test_compressstore_v2i64(ptr %p, <2 x i64> %vec, <2 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_v2i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.2d, v1.2s, #0
+; CHECK-NEXT: ptrue p0.d, vl2
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: shl v1.2d, v1.2d, #63
+; CHECK-NEXT: cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT: cntp x8, p1, p0.d
+; CHECK-NEXT: compact z0.d, p0, z0.d
+; CHECK-NEXT: whilelo p0.d, xzr, x8
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+; CHECK-NEXT: ret
+ tail call void @llvm.masked.compressstore.v2i64(<2 x i64> %vec, ptr align 8 %p, <2 x i1> %mask)
+ ret void
+}
bool isLegalMaskedCompressStore(Type *DataType,
                                Align Alignment) const override {
  ElementCount EC = cast<VectorType>(DataType)->getElementCount();
  if (EC.getKnownMinValue() != 2 && EC.getKnownMinValue() != 4)
That disqualifies compress stores for larger vectors with -msve-vector-bits, right? Can you loosen the restriction and also add some tests for SVE vectors > 128 bits?
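A test along those lines might look like the following sketch (hypothetical, not part of this patch), using the existing -aarch64-sve-vector-bits-min llc flag so that wider fixed-length vectors map onto SVE registers; with the current EC == 2/4 check it would presumably still be scalarized rather than use compact:

; RUN: llc -mtriple=aarch64 -mattr=+sve -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s
define void @test_compressstore_v8i32(ptr %p, <8 x i32> %vec, <8 x i1> %mask) {
  tail call void @llvm.masked.compressstore.v8i32(<8 x i32> %vec, ptr align 4 %p, <8 x i1> %mask)
  ret void
}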
SDValue CntActive =
    DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i64, Store->getMask());
While ISD::VECREDUCE_ADD supports a larger result type than its operand's element type, the extra bits are undefined. In this case that means only the bottom bit of the result can be relied upon, with the operation likely converted to an ISD::VECREDUCE_XOR. Perhaps this is another use case that ISD::PARTIAL_REDUCE_#MLA, with its implicit operand extension, can solve?
Fixing the above might make this request impossible, but can this be done as a target-agnostic expansion?
I'm not sure about using ISD::PARTIAL_REDUCE_#MLA, but the following works:
EVT MaskExtVT = getPromotedVTForPredicate(MaskVT);
EVT MaskReduceVT = MaskExtVT.getScalarType();
SDValue MaskExt =
    DAG.getNode(ISD::ZERO_EXTEND, DL, MaskExtVT, Store->getMask());
SDValue CntActive =
    DAG.getNode(ISD::VECREDUCE_ADD, DL, MaskReduceVT, MaskExt);
if (MaskReduceVT != MVT::i64)
  CntActive = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, CntActive);
If we also define that the demanded bits for aarch64_sve_cntp are at most 9 (max value 256, AFAIK), the ZERO_EXTEND folds away.
(This actually improves code gen as the ptrue is folded into the cntp)