
Conversation

AlexMaclean (Member)

No description provided.

llvmbot (Member) commented Aug 25, 2025

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Patch is 21.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155198.diff

5 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+3-24)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+164-127)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp (+3-3)
  • (modified) llvm/test/CodeGen/NVPTX/load-store-vectors-256.ll (+66)
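Since the PR itself carries no description, a brief recap of the intent, inferred from the diff and tests below: the datalayout gains a natural alignment for i256, and the load/store lowering learns to treat a sufficiently aligned 256-bit scalar integer access as a 4 x i64 vector access wherever the subtarget allows 256-bit vector loads/stores (the STI.has256BitVectorLoadStore(AddressSpace) check). The pattern this improves is the one in the added test_i256_global test:

; Recap of the test added below: per the SM90/SM100 CHECK lines, this pair
; lowers on sm_100 to a single ld.global.v4.b64 / st.global.v4.b64 instead of
; two v2.b64 pairs.
define void @test_i256_global(ptr addrspace(1) %a, ptr addrspace(1) %b) {
  %a.load = load i256, ptr addrspace(1) %a, align 32
  store i256 %a.load, ptr addrspace(1) %b, align 32
  ret void
}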
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 3300ed9a5a81c..964b93ed2527c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1097,11 +1097,6 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   if (PlainLoad && PlainLoad->isIndexed())
     return false;
 
-  const EVT LoadedEVT = LD->getMemoryVT();
-  if (!LoadedEVT.isSimple())
-    return false;
-  const MVT LoadedVT = LoadedEVT.getSimpleVT();
-
   // Address Space Setting
   const auto CodeAddrSpace = getAddrSpace(LD);
   if (canLowerToLDG(*LD, *Subtarget, CodeAddrSpace))
@@ -1111,7 +1106,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   SDValue Chain = N->getOperand(0);
   const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, LD);
 
-  const unsigned FromTypeWidth = LoadedVT.getSizeInBits();
+  const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits();
 
   // Vector Setting
   const unsigned FromType =
@@ -1165,9 +1160,6 @@ static unsigned getStoreVectorNumElts(SDNode *N) {
 
 bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   MemSDNode *LD = cast<MemSDNode>(N);
-  const EVT MemEVT = LD->getMemoryVT();
-  if (!MemEVT.isSimple())
-    return false;
 
   // Address Space Setting
   const auto CodeAddrSpace = getAddrSpace(LD);
@@ -1237,10 +1229,6 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
 }
 
 bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
-  const EVT LoadedEVT = LD->getMemoryVT();
-  if (!LoadedEVT.isSimple())
-    return false;
-
   SDLoc DL(LD);
 
   unsigned ExtensionType;
@@ -1357,10 +1345,6 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
   if (PlainStore && PlainStore->isIndexed())
     return false;
 
-  const EVT StoreVT = ST->getMemoryVT();
-  if (!StoreVT.isSimple())
-    return false;
-
   // Address Space Setting
   const auto CodeAddrSpace = getAddrSpace(ST);
 
@@ -1369,7 +1353,7 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
   const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
 
   // Vector Setting
-  const unsigned ToTypeWidth = StoreVT.getSimpleVT().getSizeInBits();
+  const unsigned ToTypeWidth = ST->getMemoryVT().getSizeInBits();
 
   // Create the machine instruction DAG
   SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
@@ -1406,8 +1390,7 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
 
 bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   MemSDNode *ST = cast<MemSDNode>(N);
-  const EVT StoreVT = ST->getMemoryVT();
-  assert(StoreVT.isSimple() && "Store value is not simple");
+  const unsigned TotalWidth = ST->getMemoryVT().getSizeInBits();
 
   // Address Space Setting
   const auto CodeAddrSpace = getAddrSpace(ST);
@@ -1420,10 +1403,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   SDValue Chain = ST->getChain();
   const auto [Ordering, Scope] = insertMemoryInstructionFence(DL, Chain, ST);
 
-  // Type Setting: toType + toTypeWidth
-  // - for integer type, always use 'u'
-  const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
-
   const unsigned NumElts = getStoreVectorNumElts(ST);
 
   SmallVector<SDValue, 16> Ops;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index bb4bb1195f78b..4d0dea6d92cd3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -198,6 +198,12 @@ static bool IsPTXVectorType(MVT VT) {
 static std::optional<std::pair<unsigned int, MVT>>
 getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
                        unsigned AddressSpace) {
+  const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);
+
+  if (CanLowerTo256Bit && VectorEVT.isScalarInteger() &&
+      VectorEVT.getSizeInBits() == 256)
+    return {{4, MVT::i64}};
+
   if (!VectorEVT.isSimple())
     return std::nullopt;
   const MVT VectorVT = VectorEVT.getSimpleVT();
@@ -214,8 +220,6 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
   // The size of the PTX virtual register that holds a packed type.
   unsigned PackRegSize;
 
-  bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);
-
   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
   // legal.  We can (and should) split that into 2 stores of <2 x double> here
   // but I'm leaving that as a TODO for now.
@@ -3088,9 +3092,114 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
                       MachinePointerInfo(SV));
 }
 
+/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
+static std::optional<std::pair<SDValue, SDValue>>
+replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+  const EVT ResVT = LD->getValueType(0);
+  const EVT MemVT = LD->getMemoryVT();
+
+  // If we're doing sign/zero extension as part of the load, avoid lowering to
+  // a LoadV node. TODO: consider relaxing this restriction.
+  if (ResVT != MemVT)
+    return std::nullopt;
+
+  const auto NumEltsAndEltVT =
+      getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
+  if (!NumEltsAndEltVT)
+    return std::nullopt;
+  const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
+
+  Align Alignment = LD->getAlign();
+  const auto &TD = DAG.getDataLayout();
+  Align PrefAlign = TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext()));
+  if (Alignment < PrefAlign) {
+    // This load is not sufficiently aligned, so bail out and let this vector
+    // load be scalarized.  Note that we may still be able to emit smaller
+    // vector loads.  For example, if we are loading a <4 x float> with an
+    // alignment of 8, this check will fail but the legalizer will try again
+    // with 2 x <2 x float>, which will succeed with an alignment of 8.
+    return std::nullopt;
+  }
+
+  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
+  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
+  // loaded type to i16 and propagate the "real" type as the memory type.
+  const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
+
+  unsigned Opcode;
+  switch (NumElts) {
+  default:
+    return std::nullopt;
+  case 2:
+    Opcode = NVPTXISD::LoadV2;
+    break;
+  case 4:
+    Opcode = NVPTXISD::LoadV4;
+    break;
+  case 8:
+    Opcode = NVPTXISD::LoadV8;
+    break;
+  }
+  auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
+  ListVTs.push_back(MVT::Other);
+  SDVTList LdResVTs = DAG.getVTList(ListVTs);
+
+  SDLoc DL(LD);
+
+  // Copy regular operands
+  SmallVector<SDValue, 8> OtherOps(LD->ops());
+
+  // The select routine does not have access to the LoadSDNode instance, so
+  // pass along the extension information
+  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+
+  SDValue NewLD = DAG.getMemIntrinsicNode(
+      Opcode, DL, LdResVTs, OtherOps, MemVT, LD->getMemOperand());
+
+  SmallVector<SDValue> ScalarRes;
+  if (EltVT.isVector()) {
+    assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
+    assert(NumElts * EltVT.getVectorNumElements() ==
+           ResVT.getVectorNumElements());
+    // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
+    // into individual elements.
+    for (const unsigned I : llvm::seq(NumElts)) {
+      SDValue SubVector = NewLD.getValue(I);
+      DAG.ExtractVectorElements(SubVector, ScalarRes);
+    }
+  } else {
+    for (const unsigned I : llvm::seq(NumElts)) {
+      SDValue Res = NewLD.getValue(I);
+      if (LoadEltVT != EltVT)
+        Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
+      ScalarRes.push_back(Res);
+    }
+  }
+
+  SDValue LoadChain = NewLD.getValue(NumElts);
+
+  const MVT BuildVecVT =
+      MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size());
+  SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes);
+  SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec);
+
+  return {{LoadValue, LoadChain}};
+}
+
 static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
                               SmallVectorImpl<SDValue> &Results,
-                              const NVPTXSubtarget &STI);
+                              const NVPTXSubtarget &STI) {
+  if (auto Res = replaceLoadVector(N, DAG, STI))
+    Results.append({Res->first, Res->second});
+}
+
+static SDValue lowerLoadVector(SDNode *N, SelectionDAG &DAG,
+                               const NVPTXSubtarget &STI) {
+  if (auto Res = replaceLoadVector(N, DAG, STI))
+    return DAG.getMergeValues({Res->first, Res->second}, SDLoc(N));
+  return SDValue();
+}
 
 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   if (Op.getValueType() == MVT::i1)
@@ -3137,31 +3246,8 @@ SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getMergeValues(Ops, dl);
 }
 
-SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
-  StoreSDNode *Store = cast<StoreSDNode>(Op);
-  EVT VT = Store->getMemoryVT();
-
-  if (VT == MVT::i1)
-    return LowerSTOREi1(Op, DAG);
-
-  // v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to
-  // handle unaligned stores and have to handle it here.
-  if (NVPTX::isPackedVectorTy(VT) &&
-      !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
-                                      VT, *Store->getMemOperand()))
-    return expandUnalignedStore(Store, DAG);
-
-  // v2f16/v2bf16/v2i16 don't need special handling.
-  if (NVPTX::isPackedVectorTy(VT) && VT.is32BitVector())
-    return SDValue();
-
-  // Lower store of any other vector type, including v2f32 as we want to break
-  // it apart since this is not a widely-supported type.
-  return LowerSTOREVector(Op, DAG);
-}
-
-SDValue
-NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
+static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
+                                const NVPTXSubtarget &STI) {
   MemSDNode *N = cast<MemSDNode>(Op.getNode());
   SDValue Val = N->getOperand(1);
   SDLoc DL(N);
@@ -3253,6 +3339,29 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
   return NewSt;
 }
 
+SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
+  StoreSDNode *Store = cast<StoreSDNode>(Op);
+  EVT VT = Store->getMemoryVT();
+
+  if (VT == MVT::i1)
+    return LowerSTOREi1(Op, DAG);
+
+  // v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to
+  // handle unaligned stores and have to handle it here.
+  if (NVPTX::isPackedVectorTy(VT) &&
+      !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
+                                      VT, *Store->getMemOperand()))
+    return expandUnalignedStore(Store, DAG);
+
+  // v2f16/v2bf16/v2i16 don't need special handling.
+  if (NVPTX::isPackedVectorTy(VT) && VT.is32BitVector())
+    return SDValue();
+
+  // Lower store of any other vector type, including v2f32 as we want to break
+  // it apart since this is not a widely-supported type.
+  return lowerSTOREVector(Op, DAG, STI);
+}
+
 // st i1 v, addr
 //    =>
 // v1 = zxt v to i16
@@ -5152,11 +5261,34 @@ static SDValue combinePackingMovIntoStore(SDNode *N,
                                      ST->getMemoryVT(), ST->getMemOperand());
 }
 
-static SDValue PerformStoreCombine(SDNode *N,
-                                   TargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineSTORE(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                            const NVPTXSubtarget &STI) {
+
+  if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::STORE) {
+    // Here is our chance to custom lower a store with a non-simple type.
+    // Unfortunately, we can't do this in the legalizer because there is no
+    // way to setOperationAction for an non-simple type.
+    StoreSDNode *ST = cast<StoreSDNode>(N);
+    if (!ST->getValue().getValueType().isSimple())
+      return lowerSTOREVector(SDValue(ST, 0), DCI.DAG, STI);
+  }
+
   return combinePackingMovIntoStore(N, DCI, 1, 2);
 }
 
+static SDValue combineLOAD(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                           const NVPTXSubtarget &STI) {
+  if (DCI.isBeforeLegalize() && N->getOpcode() == ISD::LOAD) {
+    // Here is our chance to custom lower a load with a non-simple type.
+    // Unfortunately, we can't do this in the legalizer because there is no
+    // way to setOperationAction for an non-simple type.
+    if (!N->getValueType(0).isSimple())
+      return lowerLoadVector(N, DCI.DAG, STI);
+  }
+
+  return combineUnpackingMovIntoLoad(N, DCI);
+}
+
 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
 ///
 static SDValue PerformADDCombine(SDNode *N,
@@ -5884,7 +6016,7 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::LOAD:
   case NVPTXISD::LoadV2:
   case NVPTXISD::LoadV4:
-    return combineUnpackingMovIntoLoad(N, DCI);
+    return combineLOAD(N, DCI, STI);
   case ISD::MUL:
     return PerformMULCombine(N, DCI, OptLevel);
   case NVPTXISD::PRMT:
@@ -5901,7 +6033,7 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::STORE:
   case NVPTXISD::StoreV2:
   case NVPTXISD::StoreV4:
-    return PerformStoreCombine(N, DCI);
+    return combineSTORE(N, DCI, STI);
   case ISD::VSELECT:
     return PerformVSELECTCombine(N, DCI);
   }
@@ -5930,102 +6062,7 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
       DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
 }
 
-/// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
-static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
-                              SmallVectorImpl<SDValue> &Results,
-                              const NVPTXSubtarget &STI) {
-  LoadSDNode *LD = cast<LoadSDNode>(N);
-  const EVT ResVT = LD->getValueType(0);
-  const EVT MemVT = LD->getMemoryVT();
-
-  // If we're doing sign/zero extension as part of the load, avoid lowering to
-  // a LoadV node. TODO: consider relaxing this restriction.
-  if (ResVT != MemVT)
-    return;
-
-  const auto NumEltsAndEltVT =
-      getVectorLoweringShape(ResVT, STI, LD->getAddressSpace());
-  if (!NumEltsAndEltVT)
-    return;
-  const auto [NumElts, EltVT] = NumEltsAndEltVT.value();
-
-  Align Alignment = LD->getAlign();
-  const auto &TD = DAG.getDataLayout();
-  Align PrefAlign = TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DAG.getContext()));
-  if (Alignment < PrefAlign) {
-    // This load is not sufficiently aligned, so bail out and let this vector
-    // load be scalarized.  Note that we may still be able to emit smaller
-    // vector loads.  For example, if we are loading a <4 x float> with an
-    // alignment of 8, this check will fail but the legalizer will try again
-    // with 2 x <2 x float>, which will succeed with an alignment of 8.
-    return;
-  }
-
-  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
-  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
-  // loaded type to i16 and propagate the "real" type as the memory type.
-  const MVT LoadEltVT = (EltVT.getSizeInBits() < 16) ? MVT::i16 : EltVT;
-
-  unsigned Opcode;
-  switch (NumElts) {
-  default:
-    return;
-  case 2:
-    Opcode = NVPTXISD::LoadV2;
-    break;
-  case 4:
-    Opcode = NVPTXISD::LoadV4;
-    break;
-  case 8:
-    Opcode = NVPTXISD::LoadV8;
-    break;
-  }
-  auto ListVTs = SmallVector<EVT, 9>(NumElts, LoadEltVT);
-  ListVTs.push_back(MVT::Other);
-  SDVTList LdResVTs = DAG.getVTList(ListVTs);
-
-  SDLoc DL(LD);
-
-  // Copy regular operands
-  SmallVector<SDValue, 8> OtherOps(LD->ops());
-
-  // The select routine does not have access to the LoadSDNode instance, so
-  // pass along the extension information
-  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
-
-  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
-                                          LD->getMemoryVT(),
-                                          LD->getMemOperand());
-
-  SmallVector<SDValue> ScalarRes;
-  if (EltVT.isVector()) {
-    assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
-    assert(NumElts * EltVT.getVectorNumElements() ==
-           ResVT.getVectorNumElements());
-    // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
-    // into individual elements.
-    for (const unsigned I : llvm::seq(NumElts)) {
-      SDValue SubVector = NewLD.getValue(I);
-      DAG.ExtractVectorElements(SubVector, ScalarRes);
-    }
-  } else {
-    for (const unsigned I : llvm::seq(NumElts)) {
-      SDValue Res = NewLD.getValue(I);
-      if (LoadEltVT != EltVT)
-        Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
-      ScalarRes.push_back(Res);
-    }
-  }
 
-  SDValue LoadChain = NewLD.getValue(NumElts);
-
-  const MVT BuildVecVT =
-      MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size());
-  SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes);
-  SDValue LoadValue = DAG.getBitcast(ResVT, BuildVec);
-
-  Results.append({LoadValue, LoadChain});
-}
 
 // Lower vector return type of tcgen05.ld intrinsics
 static void ReplaceTcgen05Ld(SDNode *N, SelectionDAG &DAG,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 27f099e220976..c559d27a2abd4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -313,7 +313,6 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
-  SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 0603994606d71..833f014a4c870 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -126,12 +126,12 @@ static std::string computeDataLayout(bool is64Bit, bool UseShortPointers) {
   // (addrspace:3).
   if (!is64Bit)
     Ret += "-p:32:32-p6:32:32-p7:32:32";
-  else if (UseShortPointers) {
+  else if (UseShortPointers)
     Ret += "-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32";
-  } else
+  else
     Ret += "-p6:32:32";
 
-  Ret += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
+  Ret += "-i64:64-i128:128-i256:256-v16:16-v32:32-n16:32:64";
 
   return Ret;
 }
diff --git a/llvm/test/CodeGen/NVPTX/load-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/load-store-vectors-256.ll
index a846607d816c5..60dd5d9308d2a 100644
--- a/llvm/test/CodeGen/NVPTX/load-store-vectors-256.ll
+++ b/llvm/test/CodeGen/NVPTX/load-store-vectors-256.ll
@@ -1506,3 +1506,69 @@ define void @local_volatile_4xdouble(ptr addrspace(5) %a, ptr addrspace(5) %b) {
   store volatile <4 x double> %a.load, ptr addrspace(5) %b
   ret void
 }
+
+define void @test_i256_global(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: test_i256_global(
+; SM90:       {
+; SM90-NEXT:    .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT:  // %bb.0:
+; SM90-NEXT:    ld.param.b64 %rd1, [test_i256_global_param_0];
+; SM90-NEXT:    ld.global.v2.b64 {%rd2, %rd3}, [%rd1];
+; SM90-NEXT:    ld.global.v2.b64 {%rd4, %rd5}, [%rd1+16];
+; SM90-NEXT:    ld.param.b64 %rd6, [test_i256_global_param_1];
+; SM90-NEXT:    st.global.v2.b64 [%rd6+16], {%rd4, %rd5};
+; SM90-NEXT:    st.global.v2.b64 [%rd6], {%rd2, %rd3};
+; SM90-NEXT:    ret;
+;
+; SM100-LABEL: test_i256_global(
+; SM100:       {
+; SM100-NEXT:    .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT:  // %bb.0:
+; SM100-NEXT:    ld.param.b64 %rd1, [test_i256_global_param_0];
+; SM100-NEXT:    ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT:    ld.param.b64 %rd6, [test_i256_global_param_1];
+; SM100-NEXT:    st.global.v4.b64 [%rd6], {%rd2, %rd3, %rd4, %rd5};
+; SM100-NEXT:    ret;
+  %a.load = load i256, ptr addrspace(1) %a, align 32
+  store i256 %a.load, ptr addrspace(1) %b, align 32
+  ret void
+}
+
+
+define void @test_i256_global_unaligned(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: test_i256...
[truncated]
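One small but load-bearing piece of the patch is the datalayout change in NVPTXTargetMachine.cpp: adding i256:256 gives i256 a 256-bit (32-byte) ABI alignment, which also becomes its preferred alignment since no separate preferred value is specified. This matters because replaceLoadVector bails out when the access alignment is below the type's preferred alignment. A minimal sketch of the consequence (hypothetical function name; only the behavior implied by the diff):

; With i256:256 in the datalayout, a load with no explicit alignment is
; assigned the ABI alignment of 32 bytes, which satisfies the
; getPrefTypeAlign check in replaceLoadVector and permits the v4.b64
; lowering on targets that support it.
define i256 @load_i256_default_align(ptr addrspace(1) %p) {
  %v = load i256, ptr addrspace(1) %p
  ret i256 %v
}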

github-actions bot commented Aug 25, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

AlexMaclean force-pushed the dev/amaclean/upstream-i256 branch 2 times, most recently from 9e7aafe to e8e640b, on Aug 25, 2025 at 03:45
llvmbot added the clang (Clang issues not falling into any other category) and clang:frontend (Language frontend issues, e.g. anything involving "Sema") labels on Aug 25, 2025
AlexMaclean force-pushed the dev/amaclean/upstream-i256 branch from e8e640b to 5a4d50f on Aug 25, 2025 at 05:54
%a.load = load i256, ptr %a, align 32
store i256 %a.load, ptr %b, align 32
ret void
}
gonzalobg (Contributor) commented Aug 25, 2025:

Could we add some tests checking that combining atomic or volatile with i256 produces an error for load and store?

AlexMaclean (Member, Author) replied:

I've added tests for both atomic and volatile. We don't support any atomic loads/stores of size greater than 64 bits.
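For reference, a minimal sketch of what the atomic negative test might look like (hypothetical function name; the expectation, per the reply above, is that atomics wider than 64 bits fail to lower rather than becoming a 256-bit access). A volatile variant would instead use load/store volatile i256 and check the emitted accesses, since volatile alone does not make codegen fail:

; Hypothetical sketch: NVPTX supports no atomic loads/stores wider than
; 64 bits, so codegen of this is expected to fail.
define i256 @test_i256_atomic(ptr addrspace(1) %a) {
  %v = load atomic i256, ptr addrspace(1) %a seq_cst, align 32
  ret i256 %v
}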

AlexMaclean force-pushed the dev/amaclean/upstream-i256 branch from 5a4d50f to bd18e21 on Aug 25, 2025 at 18:02
dakersnar (Contributor) left a comment:

LGTM with some small questions and nits.

Artem-B (Member) left a comment:

LGTM overall, with a few nits.

Comment on lines +201 to +204
const bool CanLowerTo256Bit = STI.has256BitVectorLoadStore(AddressSpace);

if (CanLowerTo256Bit && VectorEVT.isScalarInteger() &&
VectorEVT.getSizeInBits() == 256)
Artem-B (Member) commented:

Nit: there's no point checking whether 256-bit ld/st is supported if we're operating on the wrong size. I'd just fold the STI call into the if as the last condition.
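In code, the suggested reordering would look roughly like this (a sketch of the nit, not the committed change; identifiers are the ones from the hunk above):

// Run the cheap type checks first; only query the subtarget for
// 256-bit scalar integers.
if (VectorEVT.isScalarInteger() && VectorEVT.getSizeInBits() == 256 &&
    STI.has256BitVectorLoadStore(AddressSpace))
  return {{4, MVT::i64}};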

@@ -1506,3 +1506,98 @@ define void @local_volatile_4xdouble(ptr addrspace(5) %a, ptr addrspace(5) %b) {
store volatile <4 x double> %a.load, ptr addrspace(5) %b
ret void
}

define void @test_i256_global(ptr addrspace(1) %a, ptr addrspace(1) %b) {
Artem-B (Member) commented:

I'd add a test case for a non-global AS to make sure 256-bit loads/stores are not used there, and maybe a test case for some other 256-bit type (e.g. <2 x i128> or <4 x double>) to have a reference for what's expected to happen (or not).
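A sketch of the suggested extra coverage (hypothetical function names; since has256BitVectorLoadStore takes the address space as a parameter, the generic-AS case is expected not to use 256-bit ld/st):

; Generic address space: expected to remain two v2.b64 (or equivalent)
; accesses rather than a single v4.b64.
define void @test_i256_generic(ptr %a, ptr %b) {
  %v = load i256, ptr %a, align 32
  store i256 %v, ptr %b, align 32
  ret void
}

; A different 256-bit type as a reference point for the same lowering path.
define void @test_2xi128_global(ptr addrspace(1) %a, ptr addrspace(1) %b) {
  %v = load <2 x i128>, ptr addrspace(1) %a, align 32
  store <2 x i128> %v, ptr addrspace(1) %b, align 32
  ret void
}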

Labels: backend:NVPTX, clang (Clang issues not falling into any other category), clang:frontend (Language frontend issues, e.g. anything involving "Sema")