diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.h b/llvm/lib/Target/Hexagon/HexagonISelLowering.h
index 7aee7df917b46..ccea0da46e0d5 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.h
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.h
@@ -469,8 +469,7 @@ namespace HexagonISD {
   SDValue LowerHvxExtend(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerHvxShift(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerHvxIntrinsic(SDValue Op, SelectionDAG &DAG) const;
-  SDValue LowerHvxStore(SDValue Op, SelectionDAG &DAG) const;
-  SDValue HvxVecPredBitcastComputation(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerHvxMaskedOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue SplitHvxPairOp(SDValue Op, SelectionDAG &DAG) const;
   SDValue SplitHvxMemOp(SDValue Op, SelectionDAG &DAG) const;
 
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
index 7de7d414bd807..6e0733775ec4a 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
@@ -97,6 +97,8 @@ HexagonTargetLowering::initializeHVXLowering() {
     setOperationAction(ISD::CTTZ, T, Custom);
 
     setOperationAction(ISD::LOAD, T, Custom);
+    setOperationAction(ISD::MLOAD, T, Custom);
+    setOperationAction(ISD::MSTORE, T, Custom);
     setOperationAction(ISD::MUL, T, Custom);
     setOperationAction(ISD::MULHS, T, Custom);
     setOperationAction(ISD::MULHU, T, Custom);
@@ -150,6 +152,8 @@ HexagonTargetLowering::initializeHVXLowering() {
 
     setOperationAction(ISD::LOAD, T, Custom);
     setOperationAction(ISD::STORE, T, Custom);
+    setOperationAction(ISD::MLOAD, T, Custom);
+    setOperationAction(ISD::MSTORE, T, Custom);
     setOperationAction(ISD::CTLZ, T, Custom);
     setOperationAction(ISD::CTTZ, T, Custom);
     setOperationAction(ISD::CTPOP, T, Custom);
@@ -188,6 +192,9 @@ HexagonTargetLowering::initializeHVXLowering() {
     setOperationAction(ISD::AND, BoolW, Custom);
     setOperationAction(ISD::OR, BoolW, Custom);
     setOperationAction(ISD::XOR, BoolW, Custom);
+    // Masked load/store takes a mask that may need splitting.
+    setOperationAction(ISD::MLOAD, BoolW, Custom);
+    setOperationAction(ISD::MSTORE, BoolW, Custom);
   }
 
   for (MVT T : LegalV) {
@@ -1593,7 +1600,7 @@ HexagonTargetLowering::LowerHvxShift(SDValue Op, SelectionDAG &DAG) const {
 
 SDValue
 HexagonTargetLowering::LowerHvxIntrinsic(SDValue Op, SelectionDAG &DAG) const {
-  const SDLoc &dl(Op);
+  const SDLoc &dl(Op);
   MVT ResTy = ty(Op);
 
   unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
@@ -1613,6 +1620,75 @@ HexagonTargetLowering::LowerHvxIntrinsic(SDValue Op, SelectionDAG &DAG) const {
   return Op;
 }
 
+SDValue
+HexagonTargetLowering::LowerHvxMaskedOp(SDValue Op, SelectionDAG &DAG) const {
+  const SDLoc &dl(Op);
+  unsigned HwLen = Subtarget.getVectorLength();
+  auto *MaskN = cast<MaskedLoadStoreSDNode>(Op.getNode());
+  SDValue Mask = MaskN->getMask();
+  SDValue Chain = MaskN->getChain();
+  SDValue Base = MaskN->getBasePtr();
+  auto *MemOp = MaskN->getMemOperand();
+
+  unsigned Opc = Op->getOpcode();
+  assert(Opc == ISD::MLOAD || Opc == ISD::MSTORE);
+
+  if (Opc == ISD::MLOAD) {
+    MVT ValTy = ty(Op);
+    SDValue Load = DAG.getLoad(ValTy, dl, Chain, Base, MaskN->getMemOperand());
+    SDValue Thru = cast<MaskedLoadSDNode>(MaskN)->getPassThru();
+    if (isUndef(Thru))
+      return Load;
+    SDValue VSel = DAG.getNode(ISD::VSELECT, dl, ValTy, Mask, Load, Thru);
+    return DAG.getMergeValues({VSel, Load.getValue(1)}, dl);
+  }
+
+  // MSTORE
+  // HVX only has aligned masked stores.
+
+  // TODO: Fold negations of the mask into the store.
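+  // V6_vS32b_qpred_ai is the predicated vector store
+  // "if (Qv) vmem(Rt+#s4) = Vs": each lane of the Q register enables the
+  // corresponding byte of the store, which is exactly what MSTORE needs.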
+  unsigned StoreOpc = Hexagon::V6_vS32b_qpred_ai;
+  SDValue Value = cast<MaskedStoreSDNode>(MaskN)->getValue();
+  SDValue Offset0 = DAG.getTargetConstant(0, dl, ty(Base));
+
+  if (MaskN->getAlign().value() % HwLen == 0) {
+    SDValue Store = getInstr(StoreOpc, dl, MVT::Other,
+                             {Mask, Base, Offset0, Value, Chain}, DAG);
+    DAG.setNodeMemRefs(cast<MachineSDNode>(Store.getNode()), {MemOp});
+    return Store;
+  }
+
+  // Unaligned case.
+  auto StoreAlign = [&](SDValue V, SDValue A) {
+    SDValue Z = getZero(dl, ty(V), DAG);
+    // TODO: use funnel shifts?
+    // vlalign(Vu,Vv,Rt) rotates the pair Vu:Vv left by Rt and takes the
+    // upper half.
+    SDValue LoV = getInstr(Hexagon::V6_vlalignb, dl, ty(V), {V, Z, A}, DAG);
+    SDValue HiV = getInstr(Hexagon::V6_vlalignb, dl, ty(V), {Z, V, A}, DAG);
+    return std::make_pair(LoV, HiV);
+  };
+
+  MVT ByteTy = MVT::getVectorVT(MVT::i8, HwLen);
+  MVT BoolTy = MVT::getVectorVT(MVT::i1, HwLen);
+  SDValue MaskV = DAG.getNode(HexagonISD::Q2V, dl, ByteTy, Mask);
+  VectorPair Tmp = StoreAlign(MaskV, Base);
+  VectorPair MaskU = {DAG.getNode(HexagonISD::V2Q, dl, BoolTy, Tmp.first),
+                      DAG.getNode(HexagonISD::V2Q, dl, BoolTy, Tmp.second)};
+  VectorPair ValueU = StoreAlign(Value, Base);
+
+  SDValue Offset1 = DAG.getTargetConstant(HwLen, dl, MVT::i32);
+  SDValue StoreLo =
+      getInstr(StoreOpc, dl, MVT::Other,
+               {MaskU.first, Base, Offset0, ValueU.first, Chain}, DAG);
+  SDValue StoreHi =
+      getInstr(StoreOpc, dl, MVT::Other,
+               {MaskU.second, Base, Offset1, ValueU.second, Chain}, DAG);
+  DAG.setNodeMemRefs(cast<MachineSDNode>(StoreLo.getNode()), {MemOp});
+  DAG.setNodeMemRefs(cast<MachineSDNode>(StoreHi.getNode()), {MemOp});
+  return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, {StoreLo, StoreHi});
+}
+
 SDValue
 HexagonTargetLowering::SplitHvxPairOp(SDValue Op, SelectionDAG &DAG) const {
   assert(!Op.isMachineOpcode());
@@ -1648,45 +1724,81 @@ HexagonTargetLowering::SplitHvxPairOp(SDValue Op, SelectionDAG &DAG) const {
 
 SDValue
 HexagonTargetLowering::SplitHvxMemOp(SDValue Op, SelectionDAG &DAG) const {
-  LSBaseSDNode *BN = cast<LSBaseSDNode>(Op.getNode());
-  assert(BN->isUnindexed());
-  MVT MemTy = BN->getMemoryVT().getSimpleVT();
+  auto *MemN = cast<MemSDNode>(Op.getNode());
+
+  MVT MemTy = MemN->getMemoryVT().getSimpleVT();
   if (!isHvxPairTy(MemTy))
     return Op;
 
   const SDLoc &dl(Op);
   unsigned HwLen = Subtarget.getVectorLength();
   MVT SingleTy = typeSplit(MemTy).first;
-  SDValue Chain = BN->getChain();
-  SDValue Base0 = BN->getBasePtr();
+  SDValue Chain = MemN->getChain();
+  SDValue Base0 = MemN->getBasePtr();
   SDValue Base1 = DAG.getMemBasePlusOffset(Base0, TypeSize::Fixed(HwLen), dl);
 
   MachineMemOperand *MOp0 = nullptr, *MOp1 = nullptr;
-  if (MachineMemOperand *MMO = BN->getMemOperand()) {
+  if (MachineMemOperand *MMO = MemN->getMemOperand()) {
     MachineFunction &MF = DAG.getMachineFunction();
     MOp0 = MF.getMachineMemOperand(MMO, 0, HwLen);
     MOp1 = MF.getMachineMemOperand(MMO, HwLen, HwLen);
   }
 
-  unsigned MemOpc = BN->getOpcode();
-  SDValue NewOp;
+  unsigned MemOpc = MemN->getOpcode();
 
   if (MemOpc == ISD::LOAD) {
+    assert(cast<LoadSDNode>(Op)->isUnindexed());
     SDValue Load0 = DAG.getLoad(SingleTy, dl, Chain, Base0, MOp0);
     SDValue Load1 = DAG.getLoad(SingleTy, dl, Chain, Base1, MOp1);
-    NewOp = DAG.getMergeValues(
-        { DAG.getNode(ISD::CONCAT_VECTORS, dl, MemTy, Load0, Load1),
-          DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
-                      Load0.getValue(1), Load1.getValue(1)) }, dl);
-  } else {
-    assert(MemOpc == ISD::STORE);
+    return DAG.getMergeValues(
+        { DAG.getNode(ISD::CONCAT_VECTORS, dl, MemTy, Load0, Load1),
+          DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
+                      Load0.getValue(1), Load1.getValue(1)) }, dl);
+  }
+  if (MemOpc == ISD::STORE) {
+    assert(cast<StoreSDNode>(Op)->isUnindexed());
     VectorPair Vals = opSplit(cast<StoreSDNode>(Op)->getValue(), dl, DAG);
     SDValue Store0 = DAG.getStore(Chain, dl, Vals.first, Base0, MOp0);
     SDValue Store1 = DAG.getStore(Chain, dl, Vals.second, Base1, MOp1);
-    NewOp = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Store0, Store1);
+    return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Store0, Store1);
+  }
+
+  assert(MemOpc == ISD::MLOAD || MemOpc == ISD::MSTORE);
+
+  auto MaskN = cast<MaskedLoadStoreSDNode>(Op);
+  assert(MaskN->isUnindexed());
+  VectorPair Masks = opSplit(MaskN->getMask(), dl, DAG);
+  SDValue Offset = DAG.getUNDEF(MVT::i32);
+
+  if (MemOpc == ISD::MLOAD) {
+    VectorPair Thru =
+        opSplit(cast<MaskedLoadSDNode>(Op)->getPassThru(), dl, DAG);
+    SDValue MLoad0 =
+        DAG.getMaskedLoad(SingleTy, dl, Chain, Base0, Offset, Masks.first,
+                          Thru.first, SingleTy, MOp0, ISD::UNINDEXED,
+                          ISD::NON_EXTLOAD, false);
+    SDValue MLoad1 =
+        DAG.getMaskedLoad(SingleTy, dl, Chain, Base1, Offset, Masks.second,
+                          Thru.second, SingleTy, MOp1, ISD::UNINDEXED,
+                          ISD::NON_EXTLOAD, false);
+    return DAG.getMergeValues(
+        { DAG.getNode(ISD::CONCAT_VECTORS, dl, MemTy, MLoad0, MLoad1),
+          DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
+                      MLoad0.getValue(1), MLoad1.getValue(1)) }, dl);
+  }
+  if (MemOpc == ISD::MSTORE) {
+    VectorPair Vals = opSplit(cast<MaskedStoreSDNode>(Op)->getValue(), dl, DAG);
+    SDValue MStore0 = DAG.getMaskedStore(Chain, dl, Vals.first, Base0, Offset,
+                                         Masks.first, SingleTy, MOp0,
+                                         ISD::UNINDEXED, false, false);
+    SDValue MStore1 = DAG.getMaskedStore(Chain, dl, Vals.second, Base1, Offset,
+                                         Masks.second, SingleTy, MOp1,
+                                         ISD::UNINDEXED, false, false);
+    return DAG.getNode(ISD::TokenFactor, dl, MVT::Other, MStore0, MStore1);
   }
 
-  return NewOp;
+  std::string Name = "Unexpected operation: " + Op->getOperationName(&DAG);
+  llvm_unreachable(Name.c_str());
 }
 
 SDValue
@@ -1749,6 +1861,8 @@ HexagonTargetLowering::LowerHvxOperation(SDValue Op, SelectionDAG &DAG) const {
     case ISD::SETCC:
     case ISD::INTRINSIC_VOID:       return Op;
     case ISD::INTRINSIC_WO_CHAIN:   return LowerHvxIntrinsic(Op, DAG);
+    case ISD::MLOAD:
+    case ISD::MSTORE:               return LowerHvxMaskedOp(Op, DAG);
     // Unaligned loads will be handled by the default lowering.
     case ISD::LOAD:                 return SDValue();
   }
@@ -1761,6 +1875,25 @@ HexagonTargetLowering::LowerHvxOperation(SDValue Op, SelectionDAG &DAG) const {
 
 void
 HexagonTargetLowering::LowerHvxOperationWrapper(SDNode *N,
      SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
+  unsigned Opc = N->getOpcode();
+  SDValue Op(N, 0);
+
+  switch (Opc) {
+    case ISD::MLOAD:
+      if (isHvxPairTy(ty(Op))) {
+        SDValue S = SplitHvxMemOp(Op, DAG);
+        assert(S->getOpcode() == ISD::MERGE_VALUES);
+        Results.push_back(S.getOperand(0));
+        Results.push_back(S.getOperand(1));
+      }
+      break;
+    case ISD::MSTORE:
+      if (isHvxPairTy(ty(Op->getOperand(1)))) {   // Stored value
+        SDValue S = SplitHvxMemOp(Op, DAG);
+        Results.push_back(S);
+      }
+      break;
+  }
 }
 
 void
@@ -1783,6 +1916,8 @@ HexagonTargetLowering::ReplaceHvxNodeResults(SDNode *N,
 
 SDValue
 HexagonTargetLowering::PerformHvxDAGCombine(SDNode *N, DAGCombinerInfo &DCI)
      const {
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
   const SDLoc &dl(N);
   SDValue Op(N, 0);
 
diff --git a/llvm/lib/Target/Hexagon/HexagonInstrInfo.cpp b/llvm/lib/Target/Hexagon/HexagonInstrInfo.cpp
index d1cd23c3be3e5..93215a4b61870 100644
--- a/llvm/lib/Target/Hexagon/HexagonInstrInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonInstrInfo.cpp
@@ -2721,6 +2721,8 @@ bool HexagonInstrInfo::isValidOffset(unsigned Opcode, int Offset,
   case Hexagon::PS_vloadrw_nt_ai:
   case Hexagon::V6_vL32b_ai:
   case Hexagon::V6_vS32b_ai:
+  case Hexagon::V6_vS32b_qpred_ai:
+  case Hexagon::V6_vS32b_nqpred_ai:
   case Hexagon::V6_vL32b_nt_ai:
   case Hexagon::V6_vS32b_nt_ai:
   case Hexagon::V6_vL32Ub_ai:
diff --git a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td
index 078a7135c55be..0e5772bd690f2 100644
--- a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td
+++ b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td
@@ -364,6 +364,14 @@ let Predicates = [UseHVX] in {
            (V6_vasrw (V6_vaslw HVI32:$Vs, (A2_tfrsi 16)), (A2_tfrsi 16))>;
 }
 
+  // Take a pair of vectors Vt:Vs and shift them towards LSB by
+  // (Rt & (HwLen-1)) bytes.
+  def: Pat<(VecI8 (valign HVI8:$Vt, HVI8:$Vs, I32:$Rt)),
+           (LoVec (V6_valignb HvxVR:$Vt, HvxVR:$Vs, I32:$Rt))>;
+  def: Pat<(VecI16 (valign HVI16:$Vt, HVI16:$Vs, I32:$Rt)),
+           (LoVec (V6_valignb HvxVR:$Vt, HvxVR:$Vs, I32:$Rt))>;
+  def: Pat<(VecI32 (valign HVI32:$Vt, HVI32:$Vs, I32:$Rt)),
+           (LoVec (V6_valignb HvxVR:$Vt, HvxVR:$Vs, I32:$Rt))>;
+
   def: Pat<(HexagonVASL HVI8:$Vs, I32:$Rt),
            (V6_vpackeb (V6_vaslh (HiVec (VZxtb HvxVR:$Vs)), I32:$Rt),
                        (V6_vaslh (LoVec (VZxtb HvxVR:$Vs)), I32:$Rt))>;
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
index ce674d638ccb4..cbd60f36d8c6e 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -35,6 +35,9 @@ static cl::opt<bool> EmitLookupTables("hexagon-emit-lookup-tables",
   cl::init(true), cl::Hidden,
   cl::desc("Control lookup table emission on Hexagon target"));
 
+static cl::opt<bool> HexagonMaskedVMem("hexagon-masked-vmem", cl::init(true),
+  cl::Hidden, cl::desc("Enable masked loads/stores for HVX"));
+
 // Constant "cost factor" to make floating point operations more expensive
 // in terms of vectorization cost. This isn't the best way, but it should
 // do. Ultimately, the cost should use cycles.
@@ -45,8 +48,7 @@ bool HexagonTTIImpl::useHVX() const {
 }
 
 bool HexagonTTIImpl::isTypeForHVX(Type *VecTy) const {
-  assert(VecTy->isVectorTy());
-  if (isa<ScalableVectorType>(VecTy))
+  if (!VecTy->isVectorTy() || isa<ScalableVectorType>(VecTy))
     return false;
   // Avoid types like <2 x i32*>.
   if (!cast<VectorType>(VecTy)->getElementType()->isIntegerTy())
@@ -308,6 +310,14 @@ unsigned HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
   return 1;
 }
 
+bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/) {
+  return HexagonMaskedVMem && isTypeForHVX(DataType);
+}
+
+bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/) {
+  return HexagonMaskedVMem && isTypeForHVX(DataType);
+}
+
 /// --- Vector TTI end ---
 
 unsigned HexagonTTIImpl::getPrefetchDistance() const {
diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
index 07e59fb5585e8..b99f512df7665 100644
--- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -155,6 +155,9 @@ class HexagonTTIImpl : public BasicTTIImplBase<HexagonTTIImpl> {
     return 1;
   }
 
+  bool isLegalMaskedStore(Type *DataType, Align Alignment);
+  bool isLegalMaskedLoad(Type *DataType, Align Alignment);
+
   /// @}
 
   int getUserCost(const User *U, ArrayRef<const Value *> Operands,
diff --git a/llvm/test/CodeGen/Hexagon/autohvx/masked-vmem-basic.ll b/llvm/test/CodeGen/Hexagon/autohvx/masked-vmem-basic.ll
new file mode 100644
index 0000000000000..9836d2d5cb5ca
--- /dev/null
+++ b/llvm/test/CodeGen/Hexagon/autohvx/masked-vmem-basic.ll
@@ -0,0 +1,35 @@
+; RUN: llc -march=hexagon < %s | FileCheck %s
+
+; CHECK-LABEL: f0:
+; CHECK: vmemu
+; CHECK: vmux
+define <128 x i8> @f0(<128 x i8>* %a0, i32 %a1, i32 %a2) #0 {
+  %q0 = call <128 x i1> @llvm.hexagon.V6.pred.scalar2.128B(i32 %a2)
+  %v0 = call <32 x i32> @llvm.hexagon.V6.lvsplatb.128B(i32 %a1)
+  %v1 = bitcast <32 x i32> %v0 to <128 x i8>
+  %v2 = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* %a0, i32 4, <128 x i1> %q0, <128 x i8> %v1)
+  ret <128 x i8> %v2
+}
+
+; CHECK-LABEL: f1:
+; CHECK: vlalign
+; CHECK: if (q{{.}}) vmem{{.*}} = v
+define void @f1(<128 x i8>* %a0, i32 %a1, i32 %a2) #0 {
+  %q0 = call <128 x i1> @llvm.hexagon.V6.pred.scalar2.128B(i32 %a2)
+  %v0 = call <32 x i32> @llvm.hexagon.V6.lvsplatb.128B(i32 %a1)
+  %v1 = bitcast <32 x i32> %v0 to <128 x i8>
+  call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> %v1, <128 x i8>* %a0, i32 4, <128 x i1> %q0)
+  ret void
+}
+
+declare <128 x i1> @llvm.hexagon.V6.pred.scalar2.128B(i32) #1
+declare <32 x i32> @llvm.hexagon.V6.lvsplatb.128B(i32) #1
+declare <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>*, i32 immarg, <128 x i1>, <128 x i8>) #2
+declare void @llvm.masked.store.v128i8.p0v128i8(<128 x i8>, <128 x i8>*, i32 immarg, <128 x i1>) #3
+
+attributes #0 = { "target-cpu"="hexagonv65" "target-features"="+hvx,+hvx-length128b" }
+attributes #1 = { nounwind readnone }
+attributes #2 = { argmemonly nounwind readonly willreturn }
+attributes #3 = { argmemonly nounwind willreturn }
+
+
diff --git a/llvm/test/CodeGen/Hexagon/hvx-bitcast-v64i1.ll b/llvm/test/CodeGen/Hexagon/hvx-bitcast-v64i1.ll
index c44e7a863840e..cb135f72448fe 100644
--- a/llvm/test/CodeGen/Hexagon/hvx-bitcast-v64i1.ll
+++ b/llvm/test/CodeGen/Hexagon/hvx-bitcast-v64i1.ll
@@ -1,4 +1,4 @@
-; RUN: llc -march=hexagon -hexagon-instsimplify=0 < %s | FileCheck %s
+; RUN: llc -march=hexagon -hexagon-instsimplify=0 -hexagon-masked-vmem=0 < %s | FileCheck %s
 
 ; Test that LLVM does not assert and that the bitcast of v64i1 to i64 is
 ; lowered without crashing.
diff --git a/llvm/test/CodeGen/Hexagon/store-vector-pred.ll b/llvm/test/CodeGen/Hexagon/store-vector-pred.ll
index a177f87ddfbd5..d9d841cacc5bb 100644
--- a/llvm/test/CodeGen/Hexagon/store-vector-pred.ll
+++ b/llvm/test/CodeGen/Hexagon/store-vector-pred.ll
@@ -1,4 +1,4 @@
-; RUN: llc -march=hexagon -hexagon-instsimplify=0 < %s | FileCheck %s
+; RUN: llc -march=hexagon -hexagon-instsimplify=0 -hexagon-masked-vmem=0 < %s | FileCheck %s
 
 ; This test checks that storing a vector predicate of type v128i1 is lowered
 ; without crashing.
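
A minimal usage sketch (illustrative, not part of the patch): the same masked
intrinsics on 64-byte HVX use the unsuffixed v64i8/v64i1 variants. The function
name and CHECK lines below are assumptions modeled on masked-vmem-basic.ll
above; because the pointer is aligned to HwLen (64), the store should take the
single V6_vS32b_qpred_ai path rather than the unaligned vlalign split.

; RUN: llc -march=hexagon < %s | FileCheck %s
; CHECK-LABEL: sketch:
; CHECK: if (q{{.}}) vmem{{.*}} = v
define void @sketch(<64 x i8>* %a0, i32 %a1, i32 %a2) #0 {
  %q0 = call <64 x i1> @llvm.hexagon.V6.pred.scalar2(i32 %a2)
  %v0 = call <16 x i32> @llvm.hexagon.V6.lvsplatb(i32 %a1)
  %v1 = bitcast <16 x i32> %v0 to <64 x i8>
  ; Alignment 64 == HwLen, so this lowers to one predicated vector store.
  call void @llvm.masked.store.v64i8.p0v64i8(<64 x i8> %v1, <64 x i8>* %a0, i32 64, <64 x i1> %q0)
  ret void
}

declare <64 x i1> @llvm.hexagon.V6.pred.scalar2(i32) #1
declare <16 x i32> @llvm.hexagon.V6.lvsplatb(i32) #1
declare void @llvm.masked.store.v64i8.p0v64i8(<64 x i8>, <64 x i8>*, i32 immarg, <64 x i1>) #2

attributes #0 = { "target-cpu"="hexagonv65" "target-features"="+hvx,+hvx-length64b" }
attributes #1 = { nounwind readnone }
attributes #2 = { argmemonly nounwind willreturn }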