Skip to content

Commit

Permalink
Avoid unintended calls to sycl default device/context ctor. (#701)
Browse files Browse the repository at this point in the history
  • Loading branch information
kris-rowe committed Aug 21, 2023
1 parent 1590291 commit cf1b0a8
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 42 deletions.
33 changes: 3 additions & 30 deletions src/occa/internal/modes/dpcpp/device.cpp
Expand Up @@ -17,39 +17,12 @@ namespace occa
{
namespace dpcpp
{
device::device(const occa::json &properties_)
: occa::launchedModeDevice_t(properties_)
device::device(const occa::json &properties_,
const ::sycl::device& device_)
: occa::launchedModeDevice_t(properties_), dpcppDevice(device_), dpcppContext(device_)
{
if (!properties.has("wrapped"))
{
OCCA_ERROR(
"[dpcpp] device not given a [platform_id] integer",
properties.has("platform_id") && properties["platform_id"].isNumber());

OCCA_ERROR(
"[dpcpp] device not given a [device_id] integer",
properties.has("device_id") && properties["device_id"].isNumber());

platformID = properties.get<int>("platform_id");
deviceID = properties.get<int>("device_id");

auto platforms{::sycl::platform::get_platforms()};
OCCA_ERROR(
"Invalid platform number (" + toString(platformID) + ")",
(static_cast<size_t>(platformID) < platforms.size()));

auto devices{platforms[platformID].get_devices()};
OCCA_ERROR(
"Invalid device number (" + toString(deviceID) + ")",
(static_cast<size_t>(deviceID) < devices.size()));

dpcppDevice = devices[deviceID];
dpcppContext = ::sycl::context(devices[deviceID]);
}

occa::json &kernelProps = properties["kernel"];
setCompilerLinkerOptions(kernelProps);

arch = dpcppDevice.get_info<::sycl::info::device::name>();
}

Expand Down
6 changes: 3 additions & 3 deletions src/occa/internal/modes/dpcpp/device.hpp
Expand Up @@ -17,12 +17,12 @@ namespace occa
mutable hash_t hash_;

public:
int platformID{-1}, deviceID{-1};

::sycl::device dpcppDevice;
::sycl::context dpcppContext;

device(const occa::json &properties_);
device(const occa::json &properties_,
const ::sycl::device& device_);

virtual ~device() = default;

inline bool hasSeparateMemorySpace() const override { return true; }
Expand Down
25 changes: 24 additions & 1 deletion src/occa/internal/modes/dpcpp/registration.cpp
Expand Up @@ -82,7 +82,30 @@ namespace occa {
}

modeDevice_t* dpcppMode::newDevice(const occa::json &props) {
return new occa::dpcpp::device(setModeProp(props));
// Refactor this into a helper function.
OCCA_ERROR(
"[dpcpp] device not given a [platform_id] integer",
props.has("platform_id") && props["platform_id"].isNumber());
int platformID = props.get<int>("platform_id");

auto platforms{::sycl::platform::get_platforms()};
OCCA_ERROR(
"Invalid platform number (" + toString(platformID) + ")",
(static_cast<size_t>(platformID) < platforms.size()));
auto& platform = platforms[platformID];

OCCA_ERROR(
"[dpcpp] device not given a [device_id] integer",
props.has("device_id") && props["device_id"].isNumber());

int deviceID = props.get<int>("device_id");
auto devices{platform.get_devices()};
OCCA_ERROR(
"Invalid device number (" + toString(deviceID) + ")",
(static_cast<size_t>(deviceID) < devices.size()));
auto& dpcppDevice = devices[deviceID];

return new occa::dpcpp::device(setModeProp(props), dpcppDevice);
}

int dpcppMode::getDeviceCount(const occa::json& props) {
Expand Down
9 changes: 2 additions & 7 deletions src/occa/internal/modes/dpcpp/utils.cpp
Expand Up @@ -111,23 +111,18 @@ namespace occa
return *dpcppTag;
}

occa::device wrapDevice(::sycl::device device,
occa::device wrapDevice(::sycl::device sycl_device,
const occa::properties &props)
{
occa::properties allProps;
allProps["mode"] = "dpcpp";
allProps["device_id"] = -1;
allProps["platform_id"] = -1;
allProps["wrapped"] = true;
allProps += props;

auto* wrapper{new dpcpp::device(allProps)};
auto* wrapper{new dpcpp::device(allProps, sycl_device)};
wrapper->dontUseRefs();

wrapper->dpcppDevice = device;
wrapper->dpcppContext = ::sycl::context(device);
wrapper->currentStream = wrapper->createStream(allProps["stream"]);

return occa::device(wrapper);
}

Expand Down
2 changes: 1 addition & 1 deletion src/occa/internal/modes/dpcpp/utils.hpp
Expand Up @@ -37,7 +37,7 @@ namespace occa {
occa::dpcpp::stream& getDpcppStream(const occa::stream& stream_);
occa::dpcpp::streamTag &getDpcppStreamTag(const occa::streamTag& tag);

occa::device wrapDevice(::sycl::device device,
occa::device wrapDevice(::sycl::device sycl_device,
const occa::properties &props = occa::properties());

void warn(const ::sycl::exception &e,
Expand Down

0 comments on commit cf1b0a8

Please sign in to comment.