diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 35836af3c874b..67bccb46939dc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1962,10 +1962,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
 
     // We can lower types that have elements to compact.
-    for (auto VT :
-         {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
-          MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+    for (auto VT : {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64,
+                    MVT::nxv2f32, MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16,
+                    MVT::nxv4i32, MVT::nxv4f32}) {
+      setOperationAction(ISD::MSTORE, VT, Custom);
+      // Use a custom lowering for masked stores that could be a supported
+      // compressing store. Note: These types still use the normal (Legal)
+      // lowering for non-compressing masked stores.
       setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
+    }
 
     // If we have SVE, we can use SVE logic for legal (or smaller than legal)
     // NEON vectors in the lowest bits of the SVE register.
@@ -7740,7 +7745,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::MSTORE:
-    return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
+    return LowerMSTORE(Op, DAG);
   case ISD::MGATHER:
     return LowerMGATHER(Op, DAG);
   case ISD::MSCATTER:
@@ -30180,6 +30185,36 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
                       Store->isTruncatingStore());
 }
 
+SDValue AArch64TargetLowering::LowerMSTORE(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  auto *Store = cast<MaskedStoreSDNode>(Op);
+  EVT VT = Store->getValue().getValueType();
+  if (VT.isFixedLengthVector())
+    return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
+
+  if (!Store->isCompressingStore())
+    return SDValue();
+
+  EVT MaskVT = Store->getMask().getValueType();
+
+  SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+  SDValue CntActive =
+      DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i64, Store->getMask());
+  SDValue CompressedValue =
+      DAG.getNode(ISD::VECTOR_COMPRESS, DL, VT, Store->getValue(),
+                  Store->getMask(), DAG.getPOISON(VT));
+  SDValue CompressedMask =
+      DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, MaskVT, Zero, CntActive);
+
+  return DAG.getMaskedStore(Store->getChain(), DL, CompressedValue,
+                            Store->getBasePtr(), Store->getOffset(),
+                            CompressedMask, Store->getMemoryVT(),
+                            Store->getMemOperand(), Store->getAddressingMode(),
+                            Store->isTruncatingStore(),
+                            /*isCompressing=*/false);
+}
+
 SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
     SDValue Op, SelectionDAG &DAG) const {
   auto *Store = cast<MaskedStoreSDNode>(Op);
@@ -30194,7 +30229,8 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
   return DAG.getMaskedStore(
       Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
       Mask, Store->getMemoryVT(), Store->getMemOperand(),
-      Store->getAddressingMode(), Store->isTruncatingStore());
+      Store->getAddressingMode(), Store->isTruncatingStore(),
+      Store->isCompressingStore());
 }
 
 SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 70bfae717fb76..8fcef502cdab7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -755,6 +755,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerWindowsDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerInlineDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerMSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerAVG(SDValue Op, SelectionDAG &DAG, unsigned NewOp) const;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 50a3a4ab8d8b6..7e3643da4a2fd 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -645,29 +645,34 @@ def nontrunc_masked_store :
           (masked_st node:$val, node:$ptr, undef, node:$pred), [{
   return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
          cast<MaskedStoreSDNode>(N)->isUnindexed() &&
-         !cast<MaskedStoreSDNode>(N)->isNonTemporal();
+         !cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 // truncating masked store fragments.
 def trunc_masked_store :
   PatFrag<(ops node:$val, node:$ptr, node:$pred),
           (masked_st node:$val, node:$ptr, undef, node:$pred), [{
   return cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
-         cast<MaskedStoreSDNode>(N)->isUnindexed();
+         cast<MaskedStoreSDNode>(N)->isUnindexed() &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 def trunc_masked_store_i8 :
   PatFrag<(ops node:$val, node:$ptr, node:$pred),
           (trunc_masked_store node:$val, node:$ptr, node:$pred), [{
-  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8 &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 def trunc_masked_store_i16 :
   PatFrag<(ops node:$val, node:$ptr, node:$pred),
           (trunc_masked_store node:$val, node:$ptr, node:$pred), [{
-  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16 &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 def trunc_masked_store_i32 :
   PatFrag<(ops node:$val, node:$ptr, node:$pred),
           (trunc_masked_store node:$val, node:$ptr, node:$pred), [{
-  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+  return cast<MaskedStoreSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32 &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 
 def non_temporal_store :
@@ -675,7 +680,8 @@ def non_temporal_store :
           (masked_st node:$val, node:$ptr, undef, node:$pred), [{
   return !cast<MaskedStoreSDNode>(N)->isTruncatingStore() &&
          cast<MaskedStoreSDNode>(N)->isUnindexed() &&
-         cast<MaskedStoreSDNode>(N)->isNonTemporal();
+         cast<MaskedStoreSDNode>(N)->isNonTemporal() &&
+         !cast<MaskedStoreSDNode>(N)->isCompressingStore();
 }]>;
 
 multiclass masked_gather_scatter {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index d189f563f99a1..15ec8629de787 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -333,6 +333,29 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
+  bool isElementTypeLegalForCompressStore(Type *Ty) const {
+    if (Ty->isFloatTy() || Ty->isDoubleTy())
+      return true;
+
+    if (Ty->isIntegerTy(8) || Ty->isIntegerTy(16) || Ty->isIntegerTy(32) ||
+        Ty->isIntegerTy(64))
+      return true;
+
+    return false;
+  }
+
+  bool isLegalMaskedCompressStore(Type *DataType,
+                                  Align Alignment) const override {
+    ElementCount EC = cast<VectorType>(DataType)->getElementCount();
+    if (EC.getKnownMinValue() != 2 && EC.getKnownMinValue() != 4)
+      return false;
+
+    if (!isElementTypeLegalForCompressStore(DataType->getScalarType()))
+      return false;
+
+    return isLegalMaskedLoadStore(DataType, Alignment);
+  }
+
   bool isLegalMaskedGatherScatter(Type *DataType) const {
     if (!ST->isSVEAvailable())
       return false;
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-compressstore.ll b/llvm/test/CodeGen/AArch64/sve-masked-compressstore.ll
new file mode 100644
index 0000000000000..5efe03a161cc1
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-masked-compressstore.ll
@@ -0,0 +1,141 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s
+
+;; Full SVE vectors (supported with +sve)
+
+define void @test_compressstore_nxv4i32(ptr %p, <vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    cntp x8, p1, p0.s
+; CHECK-NEXT:    whilelo p0.s, xzr, x8
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
+; CHECK-NEXT:    ret
+  tail call void @llvm.masked.compressstore.nxv4i32(<vscale x 4 x i32> %vec, ptr align 4 %p, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @test_compressstore_nxv2i64(ptr %p, <vscale x 2 x i64> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    cntp x8, p1, p0.d
+; CHECK-NEXT:    whilelo p0.d, xzr, x8
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
+; CHECK-NEXT:    ret
+  tail call void @llvm.masked.compressstore.nxv2i64(<vscale x 2 x i64> %vec, ptr align 8 %p, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @test_compressstore_nxv4f32(ptr %p, <vscale x 4 x float> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    cntp x8, p1, p0.s
+; CHECK-NEXT:    whilelo p0.s, xzr, x8
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
+; CHECK-NEXT:    ret
+  tail call void @llvm.masked.compressstore.nxv4f32(<vscale x 4 x float> %vec, ptr align 4 %p, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+; TODO: Legal and nonstreaming check
+define void @test_compressstore_nxv2f64(ptr %p, <vscale x 2 x double> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    cntp x8, p1, p0.d
+; CHECK-NEXT:    whilelo p0.d, xzr, x8
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
+; CHECK-NEXT:    ret
+  tail call void @llvm.masked.compressstore.nxv2f64(<vscale x 2 x double> %vec, ptr align 8 %p, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+;; SVE vector types promoted to 32/64-bit (non-exhaustive)
+
+define void @test_compressstore_nxv2i8(ptr %p, <vscale x 2 x i8> %vec, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv2i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    cntp x8, p1, p0.d
+; CHECK-NEXT:    whilelo p0.d, xzr, x8
+; CHECK-NEXT:    st1b { z0.d }, p0, [x0]
+; CHECK-NEXT:    ret
+  tail call void @llvm.masked.compressstore.nxv2i8(<vscale x 2 x i8> %vec, ptr align 1 %p, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @test_compressstore_nxv4i16(ptr %p, <vscale x 4 x i16> %vec, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_nxv4i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    cntp x8, p1, p0.s
+; CHECK-NEXT:    whilelo p0.s, xzr, x8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x0]
+; CHECK-NEXT:    ret
+  tail call void @llvm.masked.compressstore.nxv4i16(<vscale x 4 x i16> %vec, ptr align 2 %p, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+;; NEON vector types (promoted to SVE)
+
+define void @test_compressstore_v2f32(ptr %p, <2 x double> %vec, <2 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_v2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v1.2d, v1.2s, #0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    shl v1.2d, v1.2d, #63
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT:    cntp x8, p1, p0.d
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    whilelo p0.d, xzr, x8
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
+; CHECK-NEXT:    ret
+  tail call void @llvm.masked.compressstore.v2f64(<2 x double> %vec, ptr align 8 %p, <2 x i1> %mask)
+  ret void
+}
+
+define void @test_compressstore_v4i32(ptr %p, <4 x i32> %vec, <4 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_v4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    shl v1.4s, v1.4s, #31
+; CHECK-NEXT:    cmpne p0.s, p0/z, z1.s, #0
+; CHECK-NEXT:    cntp x8, p1, p0.s
+; CHECK-NEXT:    compact z0.s, p0, z0.s
+; CHECK-NEXT:    whilelo p0.s, xzr, x8
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
+; CHECK-NEXT:    ret
+  tail call void @llvm.masked.compressstore.v4i32(<4 x i32> %vec, ptr align 4 %p, <4 x i1> %mask)
+  ret void
+}
+
+define void @test_compressstore_v2i64(ptr %p, <2 x i64> %vec, <2 x i1> %mask) {
+; CHECK-LABEL: test_compressstore_v2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v1.2d, v1.2s, #0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    shl v1.2d, v1.2d, #63
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, #0
+; CHECK-NEXT:    cntp x8, p1, p0.d
+; CHECK-NEXT:    compact z0.d, p0, z0.d
+; CHECK-NEXT:    whilelo p0.d, xzr, x8
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0]
+; CHECK-NEXT:    ret
+  tail call void @llvm.masked.compressstore.v2i64(<2 x i64> %vec, ptr align 8 %p, <2 x i1> %mask)
+  ret void
+}
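
Note (illustrative sketch, not part of the patch): the new LowerMSTORE hook only handles scalable compressing masked stores; conceptually it rewrites "compressing masked store" as "compact the active lanes to the front, then perform an ordinary masked store of the first CntActive lanes". The IR below is an editor-written sketch of that expansion for nxv4i32; the real transform happens on SelectionDAG nodes, and the intrinsic manglings and the zext+reduce used to model VECREDUCE_ADD over an i1 mask are assumptions for illustration only.

  define void @compressstore_expansion_sketch(ptr %p, <vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask) {
    ; Count the active lanes (ISD::VECREDUCE_ADD of the mask, selected as CNTP).
    %mask.zext = zext <vscale x 4 x i1> %mask to <vscale x 4 x i64>
    %cnt.active = call i64 @llvm.vector.reduce.add.nxv4i64(<vscale x 4 x i64> %mask.zext)
    ; Move the active elements to the front (ISD::VECTOR_COMPRESS, i.e. COMPACT).
    %compressed = call <vscale x 4 x i32> @llvm.experimental.vector.compress.nxv4i32(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, <vscale x 4 x i32> poison)
    ; Build a prefix mask of %cnt.active lanes (ISD::GET_ACTIVE_LANE_MASK, i.e. WHILELO).
    %prefix = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 0, i64 %cnt.active)
    ; Store the compacted value with a normal, non-compressing masked store (ST1W).
    call void @llvm.masked.store.nxv4i32.p0(<vscale x 4 x i32> %compressed, ptr %p, i32 4, <vscale x 4 x i1> %prefix)
    ret void
  }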