diff --git a/examples/print/src/main.rs b/examples/print/src/main.rs index 7423f06a..7cd9ab3f 100644 --- a/examples/print/src/main.rs +++ b/examples/print/src/main.rs @@ -55,7 +55,7 @@ fn main() -> rust_cuda::deps::rustacuda::error::CudaResult<()> { ); // Create a new CUDA stream to submit kernels to - let stream = + let mut stream = rust_cuda::host::CudaDropWrapper::from(rust_cuda::deps::rustacuda::stream::Stream::new( rust_cuda::deps::rustacuda::stream::StreamFlags::NON_BLOCKING, None, @@ -70,12 +70,14 @@ fn main() -> rust_cuda::deps::rustacuda::error::CudaResult<()> { }; // Launch the CUDA kernel on the stream and synchronise to its completion - println!("Launching print kernel ..."); - kernel.launch1(&stream, &config, Action::Print)?; - println!("Launching panic kernel ..."); - kernel.launch1(&stream, &config, Action::Panic)?; - println!("Launching alloc error kernel ..."); - kernel.launch1(&stream, &config, Action::AllocError)?; + rust_cuda::host::Stream::with(&mut stream, |stream| { + println!("Launching print kernel ..."); + kernel.launch1(stream, &config, Action::Print)?; + println!("Launching panic kernel ..."); + kernel.launch1(stream, &config, Action::Panic)?; + println!("Launching alloc error kernel ..."); + kernel.launch1(stream, &config, Action::AllocError) + })?; Ok(()) } diff --git a/rust-cuda-derive/src/rust_to_cuda/impl.rs b/rust-cuda-derive/src/rust_to_cuda/impl.rs index 40dd3487..e45a0e28 100644 --- a/rust-cuda-derive/src/rust_to_cuda/impl.rs +++ b/rust-cuda-derive/src/rust_to_cuda/impl.rs @@ -191,7 +191,7 @@ pub fn rust_to_cuda_async_trait( unsafe fn borrow_async<'stream, CudaAllocType: #crate_path::alloc::CudaAlloc>( &self, alloc: CudaAllocType, - stream: &'stream #crate_path::deps::rustacuda::stream::Stream, + stream: &'stream #crate_path::host::Stream, ) -> #crate_path::deps::rustacuda::error::CudaResult<( #crate_path::utils::r#async::Async< '_, 'stream, @@ -219,7 +219,7 @@ pub fn rust_to_cuda_async_trait( alloc: 
#crate_path::alloc::CombinedCudaAlloc< Self::CudaAllocationAsync, CudaAllocType >, - stream: &'stream #crate_path::deps::rustacuda::stream::Stream, + stream: &'stream #crate_path::host::Stream, ) -> #crate_path::deps::rustacuda::error::CudaResult<( #crate_path::utils::r#async::Async< 'a, 'stream, diff --git a/rust-cuda-kernel/src/kernel/wrapper/generate/cuda_generic_function.rs b/rust-cuda-kernel/src/kernel/wrapper/generate/cuda_generic_function.rs index 4084db0e..62cb3456 100644 --- a/rust-cuda-kernel/src/kernel/wrapper/generate/cuda_generic_function.rs +++ b/rust-cuda-kernel/src/kernel/wrapper/generate/cuda_generic_function.rs @@ -82,6 +82,9 @@ pub(in super::super) fn quote_cuda_generic_function( ) .collect::>(); + let generic_start_token = generic_start_token.unwrap_or_default(); + let generic_close_token = generic_close_token.unwrap_or_default(); + quote! { #[cfg(target_os = "cuda")] #(#func_attrs)* diff --git a/src/host/mod.rs b/src/host/mod.rs index ef45511e..25fd73a8 100644 --- a/src/host/mod.rs +++ b/src/host/mod.rs @@ -11,7 +11,6 @@ use rustacuda::{ event::Event, memory::{CopyDestination, DeviceBox, DeviceBuffer, LockedBox, LockedBuffer}, module::Module, - stream::Stream, }; use crate::{ @@ -26,6 +25,33 @@ use crate::{ }, }; +#[repr(transparent)] +pub struct Stream { + stream: rustacuda::stream::Stream, +} + +impl Deref for Stream { + type Target = rustacuda::stream::Stream; + + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl Stream { + pub fn with( + stream: &mut rustacuda::stream::Stream, + inner: impl for<'stream> FnOnce(&'stream Self) -> O, + ) -> O { + // Safety: + // - Stream is a newtype wrapper around rustacuda::stream::Stream + // - we forge a unique lifetime for a unique reference + let stream = unsafe { &*std::ptr::from_ref(stream).cast() }; + + inner(stream) + } +} + pub trait CudaDroppable: Sized { #[allow(clippy::missing_errors_doc)] fn drop(val: Self) -> Result<(), (rustacuda::error::CudaError, Self)>; @@ -88,7 +114,7 @@ 
impl CudaDroppable for LockedBuffer { } macro_rules! impl_sealed_drop_value { - ($type:ident) => { + ($type:ty) => { impl CudaDroppable for $type { fn drop(val: Self) -> Result<(), (CudaError, Self)> { Self::drop(val) @@ -98,7 +124,7 @@ macro_rules! impl_sealed_drop_value { } impl_sealed_drop_value!(Module); -impl_sealed_drop_value!(Stream); +impl_sealed_drop_value!(rustacuda::stream::Stream); impl_sealed_drop_value!(Context); impl_sealed_drop_value!(Event); @@ -142,7 +168,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> { /// # Safety /// /// `device_box` must contain EXACTLY the device copy of `host_ref` - pub unsafe fn new_unchecked( + pub(crate) unsafe fn new_unchecked( device_box: &'a mut DeviceBox>, host_ref: &'a mut T, ) -> Self { @@ -180,7 +206,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> { } #[must_use] - pub fn as_mut<'b>(&'b mut self) -> HostAndDeviceMutRef<'b, T> + pub fn into_mut<'b>(self) -> HostAndDeviceMutRef<'b, T> where 'a: 'b, { @@ -191,20 +217,14 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> { } #[must_use] - pub fn as_async<'b, 'stream>( - &'b mut self, + pub fn into_async<'b, 'stream>( + self, stream: &'stream Stream, ) -> Async<'b, 'stream, HostAndDeviceMutRef<'b, T>, NoCompletion> where 'a: 'b, { - Async::ready( - HostAndDeviceMutRef { - device_box: self.device_box, - host_ref: self.host_ref, - }, - stream, - ) + Async::ready(self.into_mut(), stream) } } @@ -253,7 +273,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceConstRef<'a, T> /// # Safety /// /// `device_box` must contain EXACTLY the device copy of `host_ref` - pub const unsafe fn new_unchecked( + pub(crate) const unsafe fn new_unchecked( device_box: &'a DeviceBox>, host_ref: &'a T, ) -> Self { diff --git a/src/kernel/mod.rs b/src/kernel/mod.rs index a27ed5b7..b5fea0af 100644 --- a/src/kernel/mod.rs +++ b/src/kernel/mod.rs @@ -11,7 +11,6 @@ use 
rustacuda::{ error::{CudaError, CudaResult}, function::Function, module::Module, - stream::Stream, }; #[cfg(feature = "kernel")] @@ -27,6 +26,8 @@ mod ptx_jit; #[cfg(feature = "host")] use ptx_jit::{PtxJITCompiler, PtxJITResult}; +#[cfg(feature = "host")] +use crate::host::Stream; use crate::safety::PortableBitSemantics; pub mod param; @@ -109,7 +110,7 @@ pub trait CudaKernelParameter: sealed::Sealed { #[allow(clippy::missing_errors_doc)] // FIXME fn with_new_async<'stream, 'param, O, E: From>( param: Self::SyncHostType, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: impl WithNewAsync<'stream, Self, O, E>, ) -> Result where @@ -206,7 +207,9 @@ macro_rules! impl_launcher_launch { pub fn $launch_async<$($T: CudaKernelParameter),*>( &mut self, $($arg: $T::AsyncHostType<'stream, '_>),* - ) -> CudaResult<()> + ) -> CudaResult> where Kernel: FnOnce(&mut Launcher<'stream, '_, Kernel>, $($T),*), { @@ -375,13 +378,10 @@ macro_rules! impl_typed_kernel_launch { config, $($arg,)* |kernel, stream, config, $($arg),*| { - let result = kernel.$launch_async::<$($T),*>(stream, config, $($arg),*); + let r#async = kernel.$launch_async::<$($T),*>(stream, config, $($arg),*)?; // important: always synchronise here, this function is sync! - match (stream.synchronize(), result) { - (Ok(()), result) => result, - (Err(_), Err(err)) | (Err(err), Ok(())) => Err(err), - } + r#async.synchronize() }, ) } @@ -422,7 +422,29 @@ macro_rules! 
impl_typed_kernel_launch { stream: &'stream Stream, config: &LaunchConfig, $($arg: $T::AsyncHostType<'stream, '_>),* - ) -> CudaResult<()> + ) -> CudaResult> + // launch_async does not need to capture its parameters until kernel completion: + // - moved parameters are moved and cannot be used again, deallocation will sync + // - immutably borrowed parameters can be shared across multiple kernel launches + // - mutably borrowed parameters are more tricky: + // - Rust's borrowing rules ensure that a single mutable reference cannot be + // passed into multiple parameters of the kernel (no mutable aliasing) + // - CUDA guarantees that kernels launched on the same stream are executed + // sequentially, so even immediate resubmissions for the same mutable data + // will not have temporally overlapping mutation on the same stream + // - however, we have to guarantee that mutable data cannot be used on several + // different streams at the same time + // - Async::move_to_stream always adds a synchronisation barrier between the + // old and the new stream to ensure that all uses on the old stream happen + // strictly before all uses on the new stream + // - async launches take AsyncProj<&mut HostAndDeviceMutRef<..>>, which either + // captures an Async, which must be moved to a different stream explicitly, + // or contains data that cannot async move to a different stream without first synchronising + // - any use of a mutable borrow in an async kernel launch adds a sync barrier + // on the launch stream s.t. the borrow is only complete once the kernel has + // completed where Kernel: FnOnce(&mut Launcher<'stream, 'kernel, Kernel>, $($T),*), { @@ -454,7 +476,11 @@ macro_rules! impl_typed_kernel_launch { &mut $T::async_to_ffi($arg, sealed::Token)? 
).cast::()),* ], - ) } + ) }?; + + crate::utils::r#async::Async::pending( + (), stream, crate::utils::r#async::NoCompletion, + ) } }; (impl $func:ident () + ($($other:expr),*) $inner:block) => { diff --git a/src/kernel/param.rs b/src/kernel/param.rs index 6be634b2..a5a3cf45 100644 --- a/src/kernel/param.rs +++ b/src/kernel/param.rs @@ -81,7 +81,7 @@ impl< #[cfg(feature = "host")] fn with_new_async<'stream, 'b, O, E: From>( param: Self::SyncHostType, - _stream: &'stream rustacuda::stream::Stream, + _stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where @@ -167,7 +167,7 @@ impl< #[cfg(feature = "host")] fn with_new_async<'stream, 'b, O, E: From>( param: Self::SyncHostType, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where @@ -251,7 +251,7 @@ impl< #[cfg(feature = "host")] fn with_new_async<'stream, 'b, O, E: From>( param: Self::SyncHostType, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where @@ -272,7 +272,7 @@ impl< where Self: 'b, { - let param = unsafe { param.unwrap_unchecked() }; + let param = unsafe { param.unwrap_ref_unchecked() }; inner(Some(&param_as_raw_bytes(param.for_host()))) } @@ -373,7 +373,7 @@ impl< #[cfg(feature = "host")] fn with_new_async<'stream, 'b, O, E: From>( param: Self::SyncHostType, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where @@ -509,7 +509,7 @@ impl< #[cfg(feature = "host")] fn with_new_async<'stream, 'b, O, E: From>( param: Self::SyncHostType, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where @@ -595,7 +595,7 @@ impl<'a, T: Sync + RustToCuda> 
CudaKernelParameter for &'a DeepPerThreadBorrow>( param: Self::SyncHostType, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where @@ -678,16 +678,20 @@ impl<'a, T: Sync + RustToCuda + SafeMutableAliasing> CudaKernelParameter #[cfg(feature = "host")] fn with_new_async<'stream, 'b, O, E: From>( param: Self::SyncHostType, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where Self: 'b, { - crate::lend::LendToCuda::lend_to_cuda_mut(param, |mut param| { - // FIXME: express the same with param.as_async(stream).as_mut() + crate::lend::LendToCuda::lend_to_cuda_mut(param, |param| { + // FIXME: express the same with param.into_async(stream).as_mut() let _ = stream; - inner.with(crate::utils::r#async::AsyncProj::new(&mut param.as_mut())) + inner.with({ + // Safety: this projection cannot be moved to a different stream + // without first exiting lend_to_cuda_mut and synchronizing + unsafe { crate::utils::r#async::AsyncProj::new(&mut param.into_mut(), None) } + }) }) } @@ -716,12 +720,13 @@ impl<'a, T: Sync + RustToCuda + SafeMutableAliasing> CudaKernelParameter #[cfg(feature = "host")] fn async_to_ffi<'stream, 'b, E: From>( - param: Self::AsyncHostType<'stream, 'b>, + mut param: Self::AsyncHostType<'stream, 'b>, _token: sealed::Token, ) -> Result, E> where Self: 'b, { + param.record_mut_use()?; let param = unsafe { param.unwrap_unchecked() }; Ok(param.for_device()) } @@ -763,7 +768,7 @@ impl< #[cfg(feature = "host")] fn with_new_async<'stream, 'b, O, E: From>( param: Self::SyncHostType, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where @@ -846,7 +851,7 @@ impl<'a, T: Sync + RustToCuda> CudaKernelParameter for &'a PtxJit>( param: Self::SyncHostType, - stream: &'stream 
rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where @@ -867,7 +872,7 @@ impl<'a, T: Sync + RustToCuda> CudaKernelParameter for &'a PtxJit CudaKernelParameter #[cfg(feature = "host")] fn with_new_async<'stream, 'b, O, E: From>( param: Self::SyncHostType, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where Self: 'b, { // FIXME: forward impl - crate::lend::LendToCuda::lend_to_cuda_mut(param, |mut param| { + crate::lend::LendToCuda::lend_to_cuda_mut(param, |param| { // FIXME: express the same with param.as_async(stream).as_mut() let _ = stream; - inner.with(crate::utils::r#async::AsyncProj::new(&mut param.as_mut())) + inner.with({ + // Safety: this projection cannot be moved to a different stream + // without first exiting lend_to_cuda_mut and synchronizing + unsafe { crate::utils::r#async::AsyncProj::new(&mut param.into_mut(), None) } + }) }) } @@ -1049,7 +1058,7 @@ impl<'a, T: 'static> CudaKernelParameter for &'a mut crate::utils::shared::Threa #[cfg(feature = "host")] fn with_new_async<'stream, 'b, O, E: From>( param: Self::SyncHostType, - _stream: &'stream rustacuda::stream::Stream, + _stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where @@ -1126,7 +1135,7 @@ impl<'a, T: 'static + PortableBitSemantics + TypeGraphLayout> CudaKernelParamete #[cfg(feature = "host")] fn with_new_async<'stream, 'b, O, E: From>( param: Self::SyncHostType, - _stream: &'stream rustacuda::stream::Stream, + _stream: &'stream crate::host::Stream, inner: impl super::WithNewAsync<'stream, Self, O, E>, ) -> Result where diff --git a/src/lend/impls/box.rs b/src/lend/impls/box.rs index 121fe390..fff0bb8d 100644 --- a/src/lend/impls/box.rs +++ b/src/lend/impls/box.rs @@ -90,7 +90,7 @@ unsafe impl RustToCudaAsync for Box( &self, alloc: A, - stream: 
&'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( Async<'_, 'stream, DeviceAccessible>, CombinedCudaAlloc, @@ -131,7 +131,7 @@ unsafe impl RustToCudaAsync for Box( this: owning_ref::BoxRefMut<'a, O, Self>, alloc: CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> CudaResult<( Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>, A, diff --git a/src/lend/impls/boxed_slice.rs b/src/lend/impls/boxed_slice.rs index 09a612c9..c275a6d1 100644 --- a/src/lend/impls/boxed_slice.rs +++ b/src/lend/impls/boxed_slice.rs @@ -96,7 +96,7 @@ unsafe impl RustToCudaAsync for Box<[ unsafe fn borrow_async<'stream, A: CudaAlloc>( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( Async<'_, 'stream, DeviceAccessible>, CombinedCudaAlloc, @@ -140,7 +140,7 @@ unsafe impl RustToCudaAsync for Box<[ unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>( this: owning_ref::BoxRefMut<'a, O, Self>, alloc: CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> CudaResult<( Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>, A, diff --git a/src/lend/impls/final.rs b/src/lend/impls/final.rs index 6235a58f..845424ef 100644 --- a/src/lend/impls/final.rs +++ b/src/lend/impls/final.rs @@ -49,7 +49,7 @@ unsafe impl RustToCudaAsync for Final { unsafe fn borrow_async<'stream, A: crate::alloc::CudaAlloc>( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( crate::utils::r#async::Async<'_, 'stream, DeviceAccessible>, crate::alloc::CombinedCudaAlloc, @@ -76,7 +76,7 @@ unsafe impl RustToCudaAsync for Final { unsafe fn restore_async<'a, 'stream, A: crate::alloc::CudaAlloc, O>( this: owning_ref::BoxRefMut<'a, 
O, Self>, alloc: crate::alloc::CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( crate::utils::r#async::Async< 'a, diff --git a/src/lend/impls/option.rs b/src/lend/impls/option.rs index b1c51b9a..76be7e76 100644 --- a/src/lend/impls/option.rs +++ b/src/lend/impls/option.rs @@ -89,7 +89,7 @@ unsafe impl RustToCudaAsync for Option { unsafe fn borrow_async<'stream, A: CudaAlloc>( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> CudaResult<( Async<'_, 'stream, DeviceAccessible>, CombinedCudaAlloc, @@ -135,7 +135,7 @@ unsafe impl RustToCudaAsync for Option { unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>( mut this: owning_ref::BoxRefMut<'a, O, Self>, alloc: CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> CudaResult<( Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>, A, diff --git a/src/lend/impls/ref.rs b/src/lend/impls/ref.rs index 501393f6..3ce47231 100644 --- a/src/lend/impls/ref.rs +++ b/src/lend/impls/ref.rs @@ -85,7 +85,7 @@ unsafe impl<'a, T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for & unsafe fn borrow_async<'stream, A: CudaAlloc>( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( Async<'_, 'stream, DeviceAccessible>, CombinedCudaAlloc, @@ -127,7 +127,7 @@ unsafe impl<'a, T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for & unsafe fn restore_async<'b, 'stream, A: CudaAlloc, O>( this: owning_ref::BoxRefMut<'b, O, Self>, alloc: CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> CudaResult<( Async<'b, 'stream, owning_ref::BoxRefMut<'b, O, Self>, CompletionFnMut<'b, Self>>, A, diff --git a/src/lend/impls/slice_ref.rs 
b/src/lend/impls/slice_ref.rs index 4f8a3ecd..07271a75 100644 --- a/src/lend/impls/slice_ref.rs +++ b/src/lend/impls/slice_ref.rs @@ -88,7 +88,7 @@ unsafe impl<'a, T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for & unsafe fn borrow_async<'stream, A: CudaAlloc>( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( Async<'_, 'stream, DeviceAccessible>, CombinedCudaAlloc, @@ -132,7 +132,7 @@ unsafe impl<'a, T: PortableBitSemantics + TypeGraphLayout> RustToCudaAsync for & unsafe fn restore_async<'b, 'stream, A: CudaAlloc, O>( this: owning_ref::BoxRefMut<'b, O, Self>, alloc: CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> CudaResult<( Async<'b, 'stream, owning_ref::BoxRefMut<'b, O, Self>, CompletionFnMut<'b, Self>>, A, diff --git a/src/lend/mod.rs b/src/lend/mod.rs index 7a3934aa..6c0467fd 100644 --- a/src/lend/mod.rs +++ b/src/lend/mod.rs @@ -101,7 +101,7 @@ pub unsafe trait RustToCudaAsync: RustToCuda { unsafe fn borrow_async<'stream, A: CudaAlloc>( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( Async<'_, 'stream, DeviceAccessible>, CombinedCudaAlloc, @@ -127,7 +127,7 @@ pub unsafe trait RustToCudaAsync: RustToCuda { unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>( this: owning_ref::BoxRefMut<'a, O, Self>, alloc: CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>, A, @@ -324,7 +324,7 @@ pub trait LendToCudaAsync: RustToCudaAsync { ) -> Result, >( &self, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: F, ) -> Result where @@ -357,7 +357,7 @@ pub trait LendToCudaAsync: RustToCudaAsync { T: 'a, >( this: 
owning_ref::BoxRefMut<'a, T, Self>, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: F, ) -> Result< ( @@ -393,7 +393,7 @@ pub trait LendToCudaAsync: RustToCudaAsync { ) -> Result, >( self, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: F, ) -> Result where @@ -416,7 +416,7 @@ impl LendToCudaAsync for T { ) -> Result, >( &self, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: F, ) -> Result where @@ -458,7 +458,7 @@ impl LendToCudaAsync for T { S: 'a, >( this: owning_ref::BoxRefMut<'a, S, Self>, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: F, ) -> Result< ( @@ -505,7 +505,7 @@ impl LendToCudaAsync for T { ) -> Result, >( self, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, inner: F, ) -> Result where diff --git a/src/utils/adapter.rs b/src/utils/adapter.rs index f7d04108..84aa2856 100644 --- a/src/utils/adapter.rs +++ b/src/utils/adapter.rs @@ -156,7 +156,7 @@ unsafe impl RustToCudaAsync unsafe fn borrow_async<'stream, A: CudaAlloc>( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( crate::utils::r#async::Async<'_, 'stream, DeviceAccessible>, CombinedCudaAlloc, @@ -172,7 +172,7 @@ unsafe impl RustToCudaAsync unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>( this: owning_ref::BoxRefMut<'a, O, Self>, alloc: CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( crate::utils::r#async::Async< 'a, @@ -346,7 +346,7 @@ unsafe impl RustToCudaAsync unsafe fn borrow_async<'stream, A: CudaAlloc>( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( crate::utils::r#async::Async<'_, 'stream, 
DeviceAccessible>, CombinedCudaAlloc, @@ -362,7 +362,7 @@ unsafe impl RustToCudaAsync unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>( this: owning_ref::BoxRefMut<'a, O, Self>, alloc: CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( crate::utils::r#async::Async< 'a, diff --git a/src/utils/aliasing/const.rs b/src/utils/aliasing/const.rs index 3ca7b059..24178131 100644 --- a/src/utils/aliasing/const.rs +++ b/src/utils/aliasing/const.rs @@ -222,7 +222,7 @@ unsafe impl RustToCudaAsync unsafe fn borrow_async<'stream, A: crate::alloc::CudaAlloc>( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( crate::utils::r#async::Async<'_, 'stream, DeviceAccessible>, crate::alloc::CombinedCudaAlloc, @@ -250,7 +250,7 @@ unsafe impl RustToCudaAsync unsafe fn restore_async<'a, 'stream, A: crate::alloc::CudaAlloc, O>( this: owning_ref::BoxRefMut<'a, O, Self>, alloc: crate::alloc::CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( crate::utils::r#async::Async< 'a, diff --git a/src/utils/aliasing/dynamic.rs b/src/utils/aliasing/dynamic.rs index 2c663e9d..c16d4bf4 100644 --- a/src/utils/aliasing/dynamic.rs +++ b/src/utils/aliasing/dynamic.rs @@ -200,7 +200,7 @@ unsafe impl RustToCudaAsync for SplitSliceOverCudaThreadsDyn unsafe fn borrow_async<'stream, A: crate::alloc::CudaAlloc>( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( crate::utils::r#async::Async<'_, 'stream, DeviceAccessible>, crate::alloc::CombinedCudaAlloc, @@ -230,7 +230,7 @@ unsafe impl RustToCudaAsync for SplitSliceOverCudaThreadsDyn unsafe fn restore_async<'a, 'stream, A: crate::alloc::CudaAlloc, O>( this: owning_ref::BoxRefMut<'a, O, Self>, alloc: 
crate::alloc::CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( crate::utils::r#async::Async< 'a, diff --git a/src/utils/async.rs b/src/utils/async.rs index b008ac55..7a33da8d 100644 --- a/src/utils/async.rs +++ b/src/utils/async.rs @@ -3,12 +3,12 @@ use std::{borrow::BorrowMut, future::Future, future::IntoFuture, marker::Phantom #[cfg(feature = "host")] use rustacuda::{ - error::CudaError, error::CudaResult, event::Event, event::EventFlags, stream::Stream, + error::CudaError, error::CudaResult, event::Event, event::EventFlags, stream::StreamWaitEventFlags, }; #[cfg(feature = "host")] -use crate::host::CudaDropWrapper; +use crate::host::{CudaDropWrapper, Stream}; #[cfg(feature = "host")] pub struct NoCompletion; @@ -19,6 +19,8 @@ pub type CompletionFnMut<'a, T> = Box CudaResult<()> + 'a> pub trait Completion>: sealed::Sealed { type Completed: ?Sized; + fn no_op() -> Self; + #[doc(hidden)] fn synchronize_on_drop(&self) -> bool; @@ -34,6 +36,11 @@ mod sealed { impl Completion for NoCompletion { type Completed = T; + #[inline] + fn no_op() -> Self { + Self + } + #[inline] fn synchronize_on_drop(&self) -> bool { false @@ -51,6 +58,11 @@ impl sealed::Sealed for NoCompletion {} impl<'a, T: ?Sized + BorrowMut, B: ?Sized> Completion for CompletionFnMut<'a, B> { type Completed = B; + #[inline] + fn no_op() -> Self { + Box::new(|_value| Ok(())) + } + #[inline] fn synchronize_on_drop(&self) -> bool { true @@ -68,6 +80,11 @@ impl<'a, T: ?Sized> sealed::Sealed for CompletionFnMut<'a, T> {} impl, C: Completion> Completion for Option { type Completed = C::Completed; + #[inline] + fn no_op() -> Self { + None + } + #[inline] fn synchronize_on_drop(&self) -> bool { self.as_ref().map_or(false, Completion::synchronize_on_drop) @@ -83,7 +100,7 @@ impl sealed::Sealed for Option {} #[cfg(feature = "host")] pub struct Async<'a, 'stream, T: BorrowMut, C: Completion = NoCompletion> { - _stream: 
PhantomData<&'stream Stream>, + stream: &'stream Stream, value: T, status: AsyncStatus<'a, T, C>, _capture: PhantomData<&'a ()>, @@ -95,7 +112,7 @@ enum AsyncStatus<'a, T: BorrowMut, C: Completion> { Processing { receiver: oneshot::Receiver>, completion: C, - event: CudaDropWrapper, + event: Option>, _capture: PhantomData<&'a T>, }, Completed { @@ -108,10 +125,8 @@ impl<'a, 'stream, T: BorrowMut, C: Completion> Async<'a, 'strea /// Wraps a `value` which is ready on `stream`. #[must_use] pub const fn ready(value: T, stream: &'stream Stream) -> Self { - let _ = stream; - Self { - _stream: PhantomData::<&'stream Stream>, + stream, value, status: AsyncStatus::Completed { result: Ok(()) }, _capture: PhantomData::<&'a ()>, @@ -125,20 +140,16 @@ impl<'a, 'stream, T: BorrowMut, C: Completion> Async<'a, 'strea /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside /// CUDA. pub fn pending(value: T, stream: &'stream Stream, completion: C) -> CudaResult { - let event = CudaDropWrapper::from(Event::new(EventFlags::DISABLE_TIMING)?); - let (sender, receiver) = oneshot::channel(); - stream.add_callback(Box::new(|result| std::mem::drop(sender.send(result))))?; - event.record(stream)?; Ok(Self { - _stream: PhantomData::<&'stream Stream>, + stream, value, status: AsyncStatus::Processing { receiver, completion, - event, + event: None, _capture: PhantomData::<&'a T>, }, _capture: PhantomData::<&'a ()>, @@ -157,7 +168,7 @@ impl<'a, 'stream, T: BorrowMut, C: Completion> Async<'a, 'strea /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside /// CUDA. 
pub fn synchronize(self) -> CudaResult { - let (mut value, status) = self.destructure_into_parts(); + let (_stream, mut value, status) = self.destructure_into_parts(); let (receiver, completion) = match status { AsyncStatus::Completed { result } => return result.map(|()| value), @@ -182,6 +193,11 @@ impl<'a, 'stream, T: BorrowMut, C: Completion> Async<'a, 'strea /// Moves the asynchronous data move to a different [`Stream`]. /// + /// This method always adds a synchronisation barrier between the old and + /// the new [`Stream`] to ensure that any usages of this [`Async`]'s + /// computations on the old [`Stream`] have completed before they can be + /// used on the new one. + /// /// # Errors /// Returns a [`rustacuda::error::CudaError`] iff an error occurs inside /// CUDA. @@ -189,52 +205,45 @@ impl<'a, 'stream, T: BorrowMut, C: Completion> Async<'a, 'strea self, stream: &'stream_new Stream, ) -> CudaResult> { - let (mut value, status) = self.destructure_into_parts(); - - let (receiver, completion, event) = match status { - AsyncStatus::Completed { .. 
} => { - return Ok(Async { - _stream: PhantomData::<&'stream_new Stream>, - value, - status, - _capture: PhantomData::<&'a ()>, - }) + let (old_stream, mut value, status) = self.destructure_into_parts(); + + let completion = match status { + AsyncStatus::Completed { result } => { + result?; + C::no_op() }, AsyncStatus::Processing { receiver, completion, - event, + event: _, _capture, - } => (receiver, completion, event), - }; - - match receiver.try_recv() { - Ok(Ok(())) => (), - Ok(Err(err)) => return Err(err), - Err(oneshot::TryRecvError::Empty) => { - stream.wait_event(&event, StreamWaitEventFlags::DEFAULT)?; - - return Ok(Async { - _stream: PhantomData::<&'stream_new Stream>, - value, - status: AsyncStatus::Processing { - receiver, - completion, - event, - _capture: PhantomData::<&'a T>, - }, - _capture: PhantomData::<&'a ()>, - }); + } => match receiver.try_recv() { + Ok(Ok(())) => { + completion.complete(value.borrow_mut())?; + C::no_op() + }, + Ok(Err(err)) => return Err(err), + Err(oneshot::TryRecvError::Empty) => completion, + Err(oneshot::TryRecvError::Disconnected) => return Err(CudaError::AlreadyAcquired), }, - Err(oneshot::TryRecvError::Disconnected) => return Err(CudaError::AlreadyAcquired), }; - completion.complete(value.borrow_mut())?; + let event = CudaDropWrapper::from(Event::new(EventFlags::DISABLE_TIMING)?); + event.record(old_stream)?; + stream.wait_event(&event, StreamWaitEventFlags::DEFAULT)?; + + let (sender, receiver) = oneshot::channel(); + stream.add_callback(Box::new(|result| std::mem::drop(sender.send(result))))?; Ok(Async { - _stream: PhantomData::<&'stream_new Stream>, + stream, value, - status: AsyncStatus::Completed { result: Ok(()) }, + status: AsyncStatus::Processing { + receiver, + completion, + event: Some(event), + _capture: PhantomData::<&'a T>, + }, _capture: PhantomData::<&'a ()>, }) } @@ -249,7 +258,7 @@ impl<'a, 'stream, T: BorrowMut, C: Completion> Async<'a, 'strea /// computation out of smaller ones that have all been 
submitted to the /// same [`Stream`]. pub unsafe fn unwrap_unchecked(self) -> CudaResult<(T, Option)> { - let (value, status) = self.destructure_into_parts(); + let (_stream, value, status) = self.destructure_into_parts(); match status { AsyncStatus::Completed { result: Ok(()) } => Ok((value, None)), @@ -264,20 +273,63 @@ impl<'a, 'stream, T: BorrowMut, C: Completion> Async<'a, 'strea } pub const fn as_ref(&self) -> AsyncProj<'_, 'stream, &T> { - AsyncProj::new(&self.value) + // Safety: this projection captures this async + unsafe { AsyncProj::new(&self.value, None) } } pub fn as_mut(&mut self) -> AsyncProj<'_, 'stream, &mut T> { - AsyncProj::new(&mut self.value) + // Safety: this projection captures this async + unsafe { + AsyncProj::new( + &mut self.value, + Some(Box::new(|| { + let completion = match &mut self.status { + AsyncStatus::Completed { result } => { + (*result)?; + C::no_op() + }, + AsyncStatus::Processing { + receiver: _, + completion, + event: _, + _capture, + } => std::mem::replace(completion, C::no_op()), + }; + + let event = CudaDropWrapper::from(Event::new(EventFlags::DISABLE_TIMING)?); + + let (sender, receiver) = oneshot::channel(); + + self.stream + .add_callback(Box::new(|result| std::mem::drop(sender.send(result))))?; + event.record(self.stream)?; + + self.status = AsyncStatus::Processing { + receiver, + completion, + event: Some(event), + _capture: PhantomData::<&'a T>, + }; + + Ok(()) + })), + ) + } } #[must_use] - fn destructure_into_parts(self) -> (T, AsyncStatus<'a, T, C>) { + fn destructure_into_parts(self) -> (&'stream Stream, T, AsyncStatus<'a, T, C>) { let this = std::mem::ManuallyDrop::new(self); // Safety: we destructure self into its droppable components, // value and status, without dropping self itself - unsafe { (std::ptr::read(&this.value), (std::ptr::read(&this.status))) } + unsafe { + ( + this.stream, + std::ptr::read(&this.value), + (std::ptr::read(&this.status)), + ) + } } } @@ -360,7 +412,7 @@ impl<'a, 'stream, T: 
BorrowMut, C: Completion> IntoFuture type IntoFuture = impl Future; fn into_future(self) -> Self::IntoFuture { - let (value, status) = self.destructure_into_parts(); + let (_stream, value, status) = self.destructure_into_parts(); let (completion, status): (Option, AsyncStatus<'a, T, NoCompletion>) = match status { AsyncStatus::Completed { result } => { @@ -422,21 +474,30 @@ impl<'a, 'stream, T: BorrowMut, C: Completion> Drop #[cfg(feature = "host")] #[allow(clippy::module_name_repetitions)] -#[derive(Copy, Clone)] pub struct AsyncProj<'a, 'stream, T: 'a> { _capture: PhantomData<&'a ()>, _stream: PhantomData<&'stream Stream>, value: T, + use_callback: Option CudaResult<()> + 'a>>, } #[cfg(feature = "host")] impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, T> { #[must_use] - pub(crate) const fn new(value: T) -> Self { + /// # Safety + /// + /// This projection must either capture an existing [`Async`] or come from + /// a source that ensures that the projected value can never (async) move + /// to a different [`Stream`]. + pub(crate) const unsafe fn new( + value: T, + use_callback: Option CudaResult<()> + 'a>>, + ) -> Self { Self { _capture: PhantomData::<&'a ()>, _stream: PhantomData::<&'stream Stream>, value, + use_callback, } } @@ -452,6 +513,22 @@ impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, T> { pub(crate) unsafe fn unwrap_unchecked(self) -> T { self.value } + + #[allow(clippy::type_complexity)] + /// # Safety + /// + /// The returned reference to the inner value of type `T` may not yet have + /// completed its asynchronous work and may thus be in an inconsistent + /// state. + /// + /// This method must only be used to construct a larger asynchronous + /// computation out of smaller ones that have all been submitted to the + /// same [`Stream`]. 
+ pub(crate) unsafe fn unwrap_unchecked_with_use( + self, + ) -> (T, Option CudaResult<()> + 'a>>) { + (self.value, self.use_callback) + } } #[cfg(feature = "host")] @@ -465,6 +542,7 @@ impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, T> { _capture: PhantomData::<&'b ()>, _stream: PhantomData::<&'stream Stream>, value: &self.value, + use_callback: None, } } @@ -477,8 +555,18 @@ impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, T> { _capture: PhantomData::<&'b ()>, _stream: PhantomData::<&'stream Stream>, value: &mut self.value, + use_callback: self.use_callback.as_mut().map(|use_callback| { + let use_callback: Box CudaResult<()>> = Box::new(use_callback); + use_callback + }), } } + + pub(crate) fn record_mut_use(&mut self) -> CudaResult<()> { + self.use_callback + .as_mut() + .map_or(Ok(()), |use_callback| use_callback()) + } } #[cfg(feature = "host")] @@ -492,8 +580,22 @@ impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, &'a T> { _capture: PhantomData::<&'b ()>, _stream: PhantomData::<&'stream Stream>, value: self.value, + use_callback: None, } } + + /// # Safety + /// + /// The returned reference to the inner value of type `&T` may not yet have + /// completed its asynchronous work and may thus be in an inconsistent + /// state. + /// + /// This method must only be used to construct a larger asynchronous + /// computation out of smaller ones that have all been submitted to the + /// same [`Stream`]. 
+ pub(crate) const unsafe fn unwrap_ref_unchecked(&self) -> &T { + self.value + } } #[cfg(feature = "host")] @@ -507,6 +609,7 @@ impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, &'a mut T> { _capture: PhantomData::<&'b ()>, _stream: PhantomData::<&'stream Stream>, value: self.value, + use_callback: None, } } @@ -519,6 +622,38 @@ impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, &'a mut T> { _capture: PhantomData::<&'b ()>, _stream: PhantomData::<&'stream Stream>, value: self.value, + use_callback: self.use_callback.as_mut().map(|use_callback| { + let use_callback: Box CudaResult<()>> = Box::new(use_callback); + use_callback + }), } } + + #[allow(dead_code)] // FIXME + /// # Safety + /// + /// The returned reference to the inner value of type `&T` may not yet have + /// completed its asynchronous work and may thus be in an inconsistent + /// state. + /// + /// This method must only be used to construct a larger asynchronous + /// computation out of smaller ones that have all been submitted to the + /// same [`Stream`]. + pub(crate) unsafe fn unwrap_ref_unchecked(&self) -> &T { + self.value + } + + #[allow(dead_code)] // FIXME + /// # Safety + /// + /// The returned reference to the inner value of type `&T` may not yet have + /// completed its asynchronous work and may thus be in an inconsistent + /// state. + /// + /// This method must only be used to construct a larger asynchronous + /// computation out of smaller ones that have all been submitted to the + /// same [`Stream`]. 
+ pub(crate) unsafe fn unwrap_mut_unchecked(&mut self) -> &mut T { + self.value + } } diff --git a/src/utils/exchange/buffer/host.rs b/src/utils/exchange/buffer/host.rs index 184de1ac..7db5ba3a 100644 --- a/src/utils/exchange/buffer/host.rs +++ b/src/utils/exchange/buffer/host.rs @@ -180,7 +180,7 @@ impl( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( Async<'_, 'stream, DeviceAccessible>>, CombinedCudaAlloc, @@ -217,7 +217,7 @@ impl( mut this: owning_ref::BoxRefMut<'a, O, Self>, alloc: CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>, A, diff --git a/src/utils/exchange/buffer/mod.rs b/src/utils/exchange/buffer/mod.rs index 28ee028d..80fa09bb 100644 --- a/src/utils/exchange/buffer/mod.rs +++ b/src/utils/exchange/buffer/mod.rs @@ -146,7 +146,7 @@ unsafe impl( &self, alloc: A, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( Async<'_, 'stream, DeviceAccessible>, CombinedCudaAlloc, @@ -159,7 +159,7 @@ unsafe impl( this: owning_ref::BoxRefMut<'a, O, Self>, alloc: CombinedCudaAlloc, - stream: &'stream rustacuda::stream::Stream, + stream: &'stream crate::host::Stream, ) -> rustacuda::error::CudaResult<( Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>, A, diff --git a/src/utils/exchange/wrapper.rs b/src/utils/exchange/wrapper.rs index 1f3326c5..faa9c5b4 100644 --- a/src/utils/exchange/wrapper.rs +++ b/src/utils/exchange/wrapper.rs @@ -3,12 +3,11 @@ use std::ops::{Deref, DerefMut}; use rustacuda::{ error::CudaResult, memory::{AsyncCopyDestination, CopyDestination, DeviceBox, LockedBox}, - stream::Stream, }; use crate::{ alloc::{EmptyCudaAlloc, NoCudaAlloc}, - host::{CudaDropWrapper, HostAndDeviceConstRef, 
HostAndDeviceMutRef}, + host::{CudaDropWrapper, HostAndDeviceConstRef, HostAndDeviceMutRef, Stream}, lend::{RustToCuda, RustToCudaAsync}, safety::SafeMutableAliasing, utils::{ @@ -195,22 +194,6 @@ impl> ExchangeWrapperOnDevice { ) } } - - #[must_use] - pub fn as_mut( - &mut self, - ) -> HostAndDeviceMutRef::CudaRepresentation>> - where - T: SafeMutableAliasing, - { - // Safety: `device_box` contains exactly the device copy of `locked_cuda_repr` - unsafe { - HostAndDeviceMutRef::new_unchecked( - &mut self.device_box, - (**self.locked_cuda_repr).into_mut(), - ) - } - } } impl> @@ -339,12 +322,16 @@ impl< > { let this = unsafe { self.as_ref().unwrap_unchecked() }; - AsyncProj::new(unsafe { - HostAndDeviceConstRef::new_unchecked( - &*(this.device_box), - (**(this.locked_cuda_repr)).into_ref(), + // Safety: this projection captures this async + unsafe { + AsyncProj::new( + HostAndDeviceConstRef::new_unchecked( + &*(this.device_box), + (**(this.locked_cuda_repr)).into_ref(), + ), + None, ) - }) + } } #[must_use] @@ -358,13 +345,17 @@ impl< where T: SafeMutableAliasing, { - let this = unsafe { self.as_mut().unwrap_unchecked() }; + let (this, use_callback) = unsafe { self.as_mut().unwrap_unchecked_with_use() }; - AsyncProj::new(unsafe { - HostAndDeviceMutRef::new_unchecked( - &mut *(this.device_box), - (**(this.locked_cuda_repr)).into_mut(), + // Safety: this projection captures this async + unsafe { + AsyncProj::new( + HostAndDeviceMutRef::new_unchecked( + &mut *(this.device_box), + (**(this.locked_cuda_repr)).into_mut(), + ), + use_callback, ) - }) + } } }