diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 07e8994df57f6..19e8364876df4 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -549,6 +549,13 @@ graph_impl::add(std::shared_ptr &DynCGImpl, return NodeImpl; } +std::shared_ptr graph_impl::getQueue() const { + std::shared_ptr Return{}; + if (!MRecordingQueues.empty()) + Return = MRecordingQueues.begin()->lock(); + return Return; +} + void graph_impl::addQueue(sycl::detail::queue_impl &RecordingQueue) { MRecordingQueues.insert(RecordingQueue.weak_from_this()); } @@ -870,10 +877,6 @@ exec_graph_impl::exec_graph_impl(sycl::context Context, const std::shared_ptr &GraphImpl, const property_list &PropList) : MSchedule(), MGraphImpl(GraphImpl), MSyncPoints(), - MQueueImpl(sycl::detail::queue_impl::create( - *sycl::detail::getSyclObjImpl(GraphImpl->getDevice()), - *sycl::detail::getSyclObjImpl(Context), sycl::async_handler{}, - sycl::property_list{})), MDevice(GraphImpl->getDevice()), MContext(Context), MRequirements(), MSchedulerDependencies(), MIsUpdatable(PropList.has_property()), @@ -893,6 +896,15 @@ exec_graph_impl::exec_graph_impl(sycl::context Context, } // Copy nodes from GraphImpl and merge any subgraph nodes into this graph. duplicateNodes(); + + if (auto PlaceholderQueuePtr = GraphImpl->getQueue()) { + MQueueImpl = PlaceholderQueuePtr; + } else { + MQueueImpl = sycl::detail::queue_impl::create( + *sycl::detail::getSyclObjImpl(GraphImpl->getDevice()), + *sycl::detail::getSyclObjImpl(Context), sycl::async_handler{}, + sycl::property_list{}); + } } exec_graph_impl::~exec_graph_impl() { diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index eedfcf0506bf3..cabca62958e19 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -172,6 +172,8 @@ class graph_impl : public std::enable_shared_from_this { node_impl &add(std::shared_ptr &DynCGImpl, nodes_range Deps); + std::shared_ptr getQueue() const; + /// Add a queue to the set of queues which are currently recording to this /// graph. /// @param RecordingQueue Queue to add to set. diff --git a/sycl/source/interop_handle.cpp b/sycl/source/interop_handle.cpp index 77bf655e9c12d..91cd9526e8498 100644 --- a/sycl/source/interop_handle.cpp +++ b/sycl/source/interop_handle.cpp @@ -60,7 +60,9 @@ ur_native_handle_t interop_handle::getNativeContext() const { ur_native_handle_t interop_handle::getNativeQueue(int32_t &NativeHandleDesc) const { - return MQueue->getNative(NativeHandleDesc); + if (MQueue != nullptr) + return MQueue->getNative(NativeHandleDesc); + return 0; } ur_native_handle_t interop_handle::getNativeGraph() const {