Skip to content

Commit

Permalink
Ensure async launch mutable borrow safety with barriers on use and st…
Browse files Browse the repository at this point in the history
…ream move
  • Loading branch information
juntyr committed Jan 13, 2024
1 parent c74b542 commit 139adce
Show file tree
Hide file tree
Showing 20 changed files with 358 additions and 172 deletions.
16 changes: 9 additions & 7 deletions examples/print/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn main() -> rust_cuda::deps::rustacuda::error::CudaResult<()> {
);

// Create a new CUDA stream to submit kernels to
let stream =
let mut stream =
rust_cuda::host::CudaDropWrapper::from(rust_cuda::deps::rustacuda::stream::Stream::new(
rust_cuda::deps::rustacuda::stream::StreamFlags::NON_BLOCKING,
None,
Expand All @@ -70,12 +70,14 @@ fn main() -> rust_cuda::deps::rustacuda::error::CudaResult<()> {
};

// Launch the CUDA kernel on the stream and synchronise to its completion
println!("Launching print kernel ...");
kernel.launch1(&stream, &config, Action::Print)?;
println!("Launching panic kernel ...");
kernel.launch1(&stream, &config, Action::Panic)?;
println!("Launching alloc error kernel ...");
kernel.launch1(&stream, &config, Action::AllocError)?;
rust_cuda::host::Stream::with(&mut stream, |stream| {
println!("Launching print kernel ...");
kernel.launch1(stream, &config, Action::Print)?;
println!("Launching panic kernel ...");
kernel.launch1(stream, &config, Action::Panic)?;
println!("Launching alloc error kernel ...");
kernel.launch1(stream, &config, Action::AllocError)
})?;

Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions rust-cuda-derive/src/rust_to_cuda/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ pub fn rust_to_cuda_async_trait(
unsafe fn borrow_async<'stream, CudaAllocType: #crate_path::alloc::CudaAlloc>(
&self,
alloc: CudaAllocType,
stream: &'stream #crate_path::deps::rustacuda::stream::Stream,
stream: &'stream #crate_path::host::Stream,
) -> #crate_path::deps::rustacuda::error::CudaResult<(
#crate_path::utils::r#async::Async<
'_, 'stream,
Expand Down Expand Up @@ -219,7 +219,7 @@ pub fn rust_to_cuda_async_trait(
alloc: #crate_path::alloc::CombinedCudaAlloc<
Self::CudaAllocationAsync, CudaAllocType
>,
stream: &'stream #crate_path::deps::rustacuda::stream::Stream,
stream: &'stream #crate_path::host::Stream,
) -> #crate_path::deps::rustacuda::error::CudaResult<(
#crate_path::utils::r#async::Async<
'a, 'stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ pub(in super::super) fn quote_cuda_generic_function(
)
.collect::<Vec<_>>();

let generic_start_token = generic_start_token.unwrap_or_default();
let generic_close_token = generic_close_token.unwrap_or_default();

quote! {
#[cfg(target_os = "cuda")]
#(#func_attrs)*
Expand Down
50 changes: 35 additions & 15 deletions src/host/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use rustacuda::{
event::Event,
memory::{CopyDestination, DeviceBox, DeviceBuffer, LockedBox, LockedBuffer},
module::Module,
stream::Stream,
};

use crate::{
Expand All @@ -26,6 +25,33 @@ use crate::{
},
};

#[repr(transparent)]
pub struct Stream {
stream: rustacuda::stream::Stream,
}

impl Deref for Stream {
type Target = rustacuda::stream::Stream;

fn deref(&self) -> &Self::Target {
&self.stream
}
}

impl Stream {
pub fn with<O>(
stream: &mut rustacuda::stream::Stream,
inner: impl for<'stream> FnOnce(&'stream Self) -> O,
) -> O {
// Safety:
// - Stream is a newtype wrapper around rustacuda::stream::Stream
// - we forge a unique lifetime for a unique reference
let stream = unsafe { &*std::ptr::from_ref(stream).cast() };

inner(stream)
}
}

pub trait CudaDroppable: Sized {
#[allow(clippy::missing_errors_doc)]
fn drop(val: Self) -> Result<(), (rustacuda::error::CudaError, Self)>;
Expand Down Expand Up @@ -88,7 +114,7 @@ impl<T: rustacuda_core::DeviceCopy> CudaDroppable for LockedBuffer<T> {
}

macro_rules! impl_sealed_drop_value {
($type:ident) => {
($type:ty) => {
impl CudaDroppable for $type {
fn drop(val: Self) -> Result<(), (CudaError, Self)> {
Self::drop(val)
Expand All @@ -98,7 +124,7 @@ macro_rules! impl_sealed_drop_value {
}

impl_sealed_drop_value!(Module);
impl_sealed_drop_value!(Stream);
impl_sealed_drop_value!(rustacuda::stream::Stream);
impl_sealed_drop_value!(Context);
impl_sealed_drop_value!(Event);

Expand Down Expand Up @@ -142,7 +168,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> {
/// # Safety
///
/// `device_box` must contain EXACTLY the device copy of `host_ref`
pub unsafe fn new_unchecked(
pub(crate) unsafe fn new_unchecked(
device_box: &'a mut DeviceBox<DeviceCopyWithPortableBitSemantics<T>>,
host_ref: &'a mut T,
) -> Self {
Expand Down Expand Up @@ -180,7 +206,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> {
}

#[must_use]
pub fn as_mut<'b>(&'b mut self) -> HostAndDeviceMutRef<'b, T>
pub fn into_mut<'b>(self) -> HostAndDeviceMutRef<'b, T>
where
'a: 'b,
{
Expand All @@ -191,20 +217,14 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceMutRef<'a, T> {
}

#[must_use]
pub fn as_async<'b, 'stream>(
&'b mut self,
pub fn into_async<'b, 'stream>(
self,
stream: &'stream Stream,
) -> Async<'b, 'stream, HostAndDeviceMutRef<'b, T>, NoCompletion>
where
'a: 'b,
{
Async::ready(
HostAndDeviceMutRef {
device_box: self.device_box,
host_ref: self.host_ref,
},
stream,
)
Async::ready(self.into_mut(), stream)
}
}

Expand Down Expand Up @@ -253,7 +273,7 @@ impl<'a, T: PortableBitSemantics + TypeGraphLayout> HostAndDeviceConstRef<'a, T>
/// # Safety
///
/// `device_box` must contain EXACTLY the device copy of `host_ref`
pub const unsafe fn new_unchecked(
pub(crate) const unsafe fn new_unchecked(
device_box: &'a DeviceBox<DeviceCopyWithPortableBitSemantics<T>>,
host_ref: &'a T,
) -> Self {
Expand Down
46 changes: 36 additions & 10 deletions src/kernel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use rustacuda::{
error::{CudaError, CudaResult},
function::Function,
module::Module,
stream::Stream,
};

#[cfg(feature = "kernel")]
Expand All @@ -27,6 +26,8 @@ mod ptx_jit;
#[cfg(feature = "host")]
use ptx_jit::{PtxJITCompiler, PtxJITResult};

#[cfg(feature = "host")]
use crate::host::Stream;
use crate::safety::PortableBitSemantics;

pub mod param;
Expand Down Expand Up @@ -109,7 +110,7 @@ pub trait CudaKernelParameter: sealed::Sealed {
#[allow(clippy::missing_errors_doc)] // FIXME
fn with_new_async<'stream, 'param, O, E: From<rustacuda::error::CudaError>>(
param: Self::SyncHostType,
stream: &'stream rustacuda::stream::Stream,
stream: &'stream crate::host::Stream,
inner: impl WithNewAsync<'stream, Self, O, E>,
) -> Result<O, E>
where
Expand Down Expand Up @@ -206,7 +207,9 @@ macro_rules! impl_launcher_launch {
pub fn $launch_async<$($T: CudaKernelParameter),*>(
&mut self,
$($arg: $T::AsyncHostType<'stream, '_>),*
) -> CudaResult<()>
) -> CudaResult<crate::utils::r#async::Async<
'static, 'stream, (), crate::utils::r#async::NoCompletion,
>>
where
Kernel: FnOnce(&mut Launcher<'stream, '_, Kernel>, $($T),*),
{
Expand Down Expand Up @@ -375,13 +378,10 @@ macro_rules! impl_typed_kernel_launch {
config,
$($arg,)*
|kernel, stream, config, $($arg),*| {
let result = kernel.$launch_async::<$($T),*>(stream, config, $($arg),*);
let r#async = kernel.$launch_async::<$($T),*>(stream, config, $($arg),*)?;

// important: always synchronise here, this function is sync!
match (stream.synchronize(), result) {
(Ok(()), result) => result,
(Err(_), Err(err)) | (Err(err), Ok(())) => Err(err),
}
r#async.synchronize()
},
)
}
Expand Down Expand Up @@ -422,7 +422,29 @@ macro_rules! impl_typed_kernel_launch {
stream: &'stream Stream,
config: &LaunchConfig,
$($arg: $T::AsyncHostType<'stream, '_>),*
) -> CudaResult<()>
) -> CudaResult<crate::utils::r#async::Async<
'static, 'stream, (), crate::utils::r#async::NoCompletion,
>>
// launch_async does not need to capture its parameters until kernel completion:
// - moved parameters are moved and cannot be used again, deallocation will sync
// - immutably borrowed parameters can be shared across multiple kernel launches
// - mutably borrowed parameters are more tricky:
// - Rust's borrowing rules ensure that a single mutable reference cannot be
// passed into multiple parameters of the kernel (no mutable aliasing)
// - CUDA guarantees that kernels launched on the same stream are executed
// sequentially, so even immediate resubmissions for the same mutable data
// will not have temporally overlapping mutation on the same stream
// - however, we have to guarantee that mutable data cannot be used on several
// different streams at the same time
// - Async::move_to_stream always adds a synchronisation barrier between the
// old and the new stream to ensure that all uses on the old stream happen
// strictly before all uses on the new stream
// - async launches take AsyncProj<&mut HostAndDeviceMutRef<..>>, which either
// captures an Async, which must be moved to a different stream explicitly,
// or contains data that cannot async move to a different stream without
// - any use of a mutable borrow in an async kernel launch adds a sync barrier
// on the launch stream s.t. the borrow is only complete once the kernel has
// completed
where
Kernel: FnOnce(&mut Launcher<'stream, 'kernel, Kernel>, $($T),*),
{
Expand Down Expand Up @@ -454,7 +476,11 @@ macro_rules! impl_typed_kernel_launch {
&mut $T::async_to_ffi($arg, sealed::Token)?
).cast::<core::ffi::c_void>()),*
],
) }
) }?;

crate::utils::r#async::Async::pending(
(), stream, crate::utils::r#async::NoCompletion,
)
}
};
(impl $func:ident () + ($($other:expr),*) $inner:block) => {
Expand Down
Loading

0 comments on commit 139adce

Please sign in to comment.