Skip to content

Commit

Permalink
Fixes #626: Now using cuda::optional for the optional fields in `cu…
Browse files Browse the repository at this point in the history
…da::device::pci_location_t` + comment improvements
  • Loading branch information
eyalroz committed Apr 19, 2024
1 parent 650acfd commit 49c004f
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 17 deletions.
17 changes: 9 additions & 8 deletions src/cuda/api/detail/pci_id.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <ostream>
#include <sstream>


namespace cuda {
namespace device {

Expand All @@ -43,40 +42,42 @@ inline ::std::istream& operator>>(::std::istream& is, cuda::device::pci_location
is >> first_field;
auto get_colon = [&]() {
auto c = is.get();
// if (c == istream::traits_type::eof() or ) {
if (c != ':') {
throw ::std::invalid_argument("Invalid format of a PCI location for a CUDA device 1");
}
};
get_colon();

int second_field;
int function;
is >> second_field;
switch(is.get()) {
case '.':
// It's the third format
pci_id.domain = pci_location_t::unused; // Is this a reasonable choice?
pci_id.domain = {};
pci_id.bus = first_field;
pci_id.device = second_field;
is >> pci_id.function;
is >> function;
if (not is.good()) {
throw ::std::invalid_argument("Failed parsing PCI location ID for a CUDA device 2");
}
pci_id.function = function;
break;
case ':': {
pci_id.domain = first_field;
pci_id.bus = second_field;
is >> pci_id.device;
if (is.peek() != '.') {
// It's the second format.
pci_id.function = pci_location_t::unused; // Is this a reasonable choice? I woudld have liked that...
pci_id.function = {};
is.flags(format_flags);
return is;
}
else {
// It's the first format.
is.get();
is >> pci_id.function;
is >> function;
pci_id.function = function;
is.flags(format_flags);
return is;
}
Expand All @@ -90,9 +91,9 @@ inline ::std::ostream& operator<<(::std::ostream& os, const cuda::device::pci_lo
{
auto format_flags(os.flags());
os << ::std::hex;
if (pci_id.domain != pci_location_t::unused) { os << pci_id.domain << ':'; }
if (pci_id.domain) { os << pci_id.domain.value() << ':'; }
os << pci_id.bus << ':' << pci_id.device;
if (pci_id.function != pci_location_t::unused) { os << '.' << pci_id.function; }
if (pci_id.function) { os << '.' << pci_id.function.value(); }
os.flags(format_flags);
return os;
}
Expand Down
2 changes: 1 addition & 1 deletion src/cuda/api/device_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ struct properties_t : public cudaDeviceProp {
return { { static_cast<unsigned>(major) }, static_cast<unsigned>(minor) };
}
compute_architecture_t compute_architecture() const noexcept { return { static_cast<unsigned>(major) }; };
pci_location_t pci_id() const noexcept { return { pciDomainID, pciBusID, pciDeviceID, pci_location_t::unused }; }
pci_location_t pci_id() const noexcept { return { pciDomainID, pciBusID, pciDeviceID, {} }; }

unsigned long long max_in_flight_threads_on_device() const
{
Expand Down
35 changes: 27 additions & 8 deletions src/cuda/api/pci_id.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,41 @@ namespace device {
* Location "coordinates" for a CUDA device on a PCIe bus
*
* @note can be compiled from individual values from a device's properties;
* see @ref properties_t
* see {@ref properties_t}.
*/
struct pci_location_t {
// These are the values CUDA's API provides us with directly
int domain;
/**
* The four fields of the PCI configuration space.
*
* @note Only the first three are actually used/recognized by the CUDA driver, and
* when querying a CUDA device for its PCI ID, function will be unused. However - we
* have it be able to parse the different common string notations of PCI IDs; see
* @url https://wiki.xenproject.org/wiki/Bus:Device.Function_(BDF)_Notation .
*/
///@{
optional<int> domain;
int bus;
int device;
int function;
optional<int> function;
///@}

operator ::std::string() const;
// This is not a ctor so as to maintain the PODness

/**
* Parse a string representation of a device's PCI location.
*
* @note This is not a ctor so as to maintain the PODness.
*
* @note There are multiple notations for PCI IDs:
*
* domain::bus::device.function
* domain::bus::device
* bus::device.function
*
* and any of them can be used.
*/
static pci_location_t parse(const ::std::string& id_str);
static pci_location_t parse(const char* id_str);
public:
static constexpr const int unused { -1 };
// In lieu of making this class a variant with 3 type combinations.
};

namespace detail_ {
Expand Down

0 comments on commit 49c004f

Please sign in to comment.