diff --git a/src/occa/internal/modes/dpcpp/device.cpp b/src/occa/internal/modes/dpcpp/device.cpp index 7d39cbd9e..6caacc014 100644 --- a/src/occa/internal/modes/dpcpp/device.cpp +++ b/src/occa/internal/modes/dpcpp/device.cpp @@ -17,39 +17,12 @@ namespace occa { namespace dpcpp { - device::device(const occa::json &properties_) - : occa::launchedModeDevice_t(properties_) + device::device(const occa::json &properties_, + const ::sycl::device& device_) + : occa::launchedModeDevice_t(properties_), dpcppDevice(device_), dpcppContext(device_) { - if (!properties.has("wrapped")) - { - OCCA_ERROR( - "[dpcpp] device not given a [platform_id] integer", - properties.has("platform_id") && properties["platform_id"].isNumber()); - - OCCA_ERROR( - "[dpcpp] device not given a [device_id] integer", - properties.has("device_id") && properties["device_id"].isNumber()); - - platformID = properties.get("platform_id"); - deviceID = properties.get("device_id"); - - auto platforms{::sycl::platform::get_platforms()}; - OCCA_ERROR( - "Invalid platform number (" + toString(platformID) + ")", - (static_cast(platformID) < platforms.size())); - - auto devices{platforms[platformID].get_devices()}; - OCCA_ERROR( - "Invalid device number (" + toString(deviceID) + ")", - (static_cast(deviceID) < devices.size())); - - dpcppDevice = devices[deviceID]; - dpcppContext = ::sycl::context(devices[deviceID]); - } - occa::json &kernelProps = properties["kernel"]; setCompilerLinkerOptions(kernelProps); - arch = dpcppDevice.get_info<::sycl::info::device::name>(); } diff --git a/src/occa/internal/modes/dpcpp/device.hpp b/src/occa/internal/modes/dpcpp/device.hpp index bbd26b54d..d601f61ea 100644 --- a/src/occa/internal/modes/dpcpp/device.hpp +++ b/src/occa/internal/modes/dpcpp/device.hpp @@ -17,12 +17,12 @@ namespace occa mutable hash_t hash_; public: - int platformID{-1}, deviceID{-1}; - ::sycl::device dpcppDevice; ::sycl::context dpcppContext; - device(const occa::json &properties_); + device(const occa::json &properties_, + const ::sycl::device& device_); + virtual ~device() = default; inline bool hasSeparateMemorySpace() const override { return true; } diff --git a/src/occa/internal/modes/dpcpp/registration.cpp b/src/occa/internal/modes/dpcpp/registration.cpp index 0229a5a90..55a21486e 100644 --- a/src/occa/internal/modes/dpcpp/registration.cpp +++ b/src/occa/internal/modes/dpcpp/registration.cpp @@ -82,7 +82,30 @@ namespace occa { } modeDevice_t* dpcppMode::newDevice(const occa::json &props) { - return new occa::dpcpp::device(setModeProp(props)); + // Refactor this into a helper function. + OCCA_ERROR( + "[dpcpp] device not given a [platform_id] integer", + props.has("platform_id") && props["platform_id"].isNumber()); + int platformID = props.get("platform_id"); + + auto platforms{::sycl::platform::get_platforms()}; + OCCA_ERROR( + "Invalid platform number (" + toString(platformID) + ")", + (static_cast(platformID) < platforms.size())); + auto& platform = platforms[platformID]; + + OCCA_ERROR( + "[dpcpp] device not given a [device_id] integer", + props.has("device_id") && props["device_id"].isNumber()); + + int deviceID = props.get("device_id"); + auto devices{platform.get_devices()}; + OCCA_ERROR( + "Invalid device number (" + toString(deviceID) + ")", + (static_cast(deviceID) < devices.size())); + auto& dpcppDevice = devices[deviceID]; + + return new occa::dpcpp::device(setModeProp(props), dpcppDevice); } int dpcppMode::getDeviceCount(const occa::json& props) { diff --git a/src/occa/internal/modes/dpcpp/utils.cpp b/src/occa/internal/modes/dpcpp/utils.cpp index 0ec58ff46..c8a8bc27e 100644 --- a/src/occa/internal/modes/dpcpp/utils.cpp +++ b/src/occa/internal/modes/dpcpp/utils.cpp @@ -111,23 +111,18 @@ namespace occa return *dpcppTag; } - occa::device wrapDevice(::sycl::device device, + occa::device wrapDevice(::sycl::device sycl_device, const occa::properties &props) { occa::properties allProps; allProps["mode"] = "dpcpp"; - allProps["device_id"] = -1; - allProps["platform_id"] = -1; allProps["wrapped"] = true; allProps += props; - auto* wrapper{new dpcpp::device(allProps)}; + auto* wrapper{new dpcpp::device(allProps, sycl_device)}; wrapper->dontUseRefs(); - wrapper->dpcppDevice = device; - wrapper->dpcppContext = ::sycl::context(device); wrapper->currentStream = wrapper->createStream(allProps["stream"]); - return occa::device(wrapper); } diff --git a/src/occa/internal/modes/dpcpp/utils.hpp b/src/occa/internal/modes/dpcpp/utils.hpp index f9b845e3f..939d5a9f5 100644 --- a/src/occa/internal/modes/dpcpp/utils.hpp +++ b/src/occa/internal/modes/dpcpp/utils.hpp @@ -37,7 +37,7 @@ namespace occa { occa::dpcpp::stream& getDpcppStream(const occa::stream& stream_); occa::dpcpp::streamTag &getDpcppStreamTag(const occa::streamTag& tag); - occa::device wrapDevice(::sycl::device device, + occa::device wrapDevice(::sycl::device sycl_device, const occa::properties &props = occa::properties()); void warn(const ::sycl::exception &e,