diff --git a/src/driver/safe/launch.rs b/src/driver/safe/launch.rs
index 7d53ca0..654f399 100644
--- a/src/driver/safe/launch.rs
+++ b/src/driver/safe/launch.rs
@@ -207,6 +207,72 @@ pub unsafe trait LaunchAsync {
     ) -> Result<(), result::DriverError>;
 }
 
+/// Kernel launching from an already-packed slice of raw argument pointers.
+///
+/// Both methods hand `args` to the underlying launch helper unchanged; this impl
+/// does not inspect the pointers.
+///
+/// # Safety
+///
+/// NOTE(review): presumably each pointer must refer to a live value matching the
+/// kernel parameter at the same position - confirm against `LaunchAsync`.
+unsafe impl LaunchAsync<&mut [*mut std::ffi::c_void]> for CudaFunction {
+    /// Forwards `cfg` and `args` to `launch_async_impl`.
+    #[inline(always)]
+    unsafe fn launch(
+        self,
+        cfg: LaunchConfig,
+        args: &mut [*mut std::ffi::c_void],
+    ) -> Result<(), result::DriverError> {
+        self.launch_async_impl(cfg, args)
+    }
+
+    /// Forwards `stream`, `cfg` and `args` to `par_launch_async_impl`.
+    #[inline(always)]
+    unsafe fn launch_on_stream(
+        self,
+        stream: &CudaStream,
+        cfg: LaunchConfig,
+        args: &mut [*mut std::ffi::c_void],
+    ) -> Result<(), result::DriverError> {
+        self.par_launch_async_impl(stream, cfg, args)
+    }
+}
+
+/// Kernel launching for callers that keep the packed raw pointers in a `Vec`.
+///
+/// The vector is only viewed as a mutable slice; nothing in this impl pushes to
+/// or removes from it.
+///
+/// # Safety
+///
+/// NOTE(review): presumably each pointer must refer to a live value matching the
+/// kernel parameter at the same position - confirm against `LaunchAsync`.
+unsafe impl LaunchAsync<&mut Vec<*mut std::ffi::c_void>> for CudaFunction {
+    /// Views `args` as a slice and forwards it with `cfg` to `launch_async_impl`.
+    #[inline(always)]
+    unsafe fn launch(
+        self,
+        cfg: LaunchConfig,
+        args: &mut Vec<*mut std::ffi::c_void>,
+    ) -> Result<(), result::DriverError> {
+        let raw_params = args.as_mut_slice();
+        self.launch_async_impl(cfg, raw_params)
+    }
+
+    /// Views `args` as a slice and forwards it with `stream` and `cfg` to
+    /// `par_launch_async_impl`.
+    #[inline(always)]
+    unsafe fn launch_on_stream(
+        self,
+        stream: &CudaStream,
+        cfg: LaunchConfig,
+        args: &mut Vec<*mut std::ffi::c_void>,
+    ) -> Result<(), result::DriverError> {
+        let raw_params = args.as_mut_slice();
+        self.par_launch_async_impl(stream, cfg, raw_params)
+    }
+}
+
 macro_rules! impl_launch {
     ([$($Vars:tt),*], [$($Idx:tt),*]) => {
         unsafe impl<$($Vars: DeviceRepr),*> LaunchAsync<($($Vars, )*)> for CudaFunction {