Skip to content

Commit

Permalink
Some progress on shared slices
Browse files Browse the repository at this point in the history
  • Loading branch information
juntyr committed Jan 7, 2023
1 parent 8648dfa commit 3adaa9c
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 81 deletions.
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
any(all(not(feature = "host"), target_os = "cuda"), doc),
feature(asm_const)
)]
#![cfg_attr(target_os = "cuda", feature(ptr_metadata))]
#![cfg_attr(any(feature = "alloc", doc), feature(allocator_api))]
#![feature(doc_cfg)]
#![feature(cfg_version)]
Expand Down
6 changes: 4 additions & 2 deletions src/safety/stack_only.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ mod sealed {
impl<T> !StackOnly for &mut T {}

impl<T: 'static> !StackOnly for crate::utils::shared::r#static::ThreadBlockShared<T> {}
// impl<T: 'static> !StackOnly for
// crate::utils::shared::slice::ThreadBlockSharedSlice<T> {}
// Negative impl: explicitly opts `ThreadBlockSharedSlice<T>` out of the
// `StackOnly` trait, mirroring the `ThreadBlockShared<T>` opt-out directly
// above. The bounds repeat the ones on the `ThreadBlockSharedSlice` struct
// definition.
impl<T: 'static + ~const const_type_layout::TypeGraphLayout> !StackOnly
for crate::utils::shared::slice::ThreadBlockSharedSlice<T>
{
}

impl<T> StackOnly for core::marker::PhantomData<T> {}
}
2 changes: 1 addition & 1 deletion src/utils/shared/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
// pub mod slice;
pub mod slice;
pub mod r#static;
160 changes: 119 additions & 41 deletions src/utils/shared/slice.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,151 @@
#[cfg(not(target_os = "cuda"))]
use core::marker::PhantomData;

use const_type_layout::TypeGraphLayout;
use rustacuda_core::DeviceCopy;

use crate::common::{CudaAsRust, DeviceAccessible, RustToCuda};

/// Variant of `ThreadBlockSharedSlice` compiled for every target other than
/// `target_os = "cuda"`.
///
/// This variant holds no pointer or element storage: it only records the
/// number of elements and carries the element type `T` as a zero-sized marker.
#[cfg(not(target_os = "cuda"))]
#[allow(clippy::module_name_repetitions)]
#[repr(transparent)]
pub struct ThreadBlockSharedSlice<T: 'static + ~const TypeGraphLayout> {
// Number of `T` elements in the slice.
len: usize,
// Zero-sized marker that ties the struct to `T` without storing a value.
marker: PhantomData<T>,
}

/// Variant of `ThreadBlockSharedSlice` compiled only for
/// `target_os = "cuda"`.
///
/// This variant is a transparent wrapper around a raw slice pointer; the
/// slice length is carried in the pointer's metadata rather than in a
/// separate field.
#[cfg(target_os = "cuda")]
#[allow(clippy::module_name_repetitions)]
#[repr(transparent)]
pub struct ThreadBlockSharedSlice<T: 'static + ~const TypeGraphLayout> {
// Raw (fat) pointer to the slice elements.
// NOTE(review): presumably points into thread-block shared memory, as the
// type name suggests -- confirm where this pointer is produced.
shared: *mut [T],
}

#[doc(hidden)]
#[derive(TypeLayout)]
#[layout(bound = "T: 'static + ~const TypeGraphLayout")]
#[repr(C)]
pub struct ThreadBlockSharedSlice<T: 'static> {
pub struct ThreadBlockSharedSliceCudaRepresentation<T: 'static + ~const TypeGraphLayout> {
len: usize,
byte_offset: usize,
// Note: uses a zero-element array instead of PhantomData here so that
// TypeLayout can still observe T's layout
marker: [T; 0],
}

unsafe impl<T: 'static> DeviceCopy for ThreadBlockSharedSlice<T> {}
// Marks the CUDA representation as `DeviceCopy`. This is an `unsafe impl`
// with an empty body: nothing is checked here, the impl itself is the
// assertion that the type satisfies `DeviceCopy`'s requirements.
// NOTE(review): presumably justified because the representation stores only
// plain data and a zero-length array -- confirm against the struct definition.
unsafe impl<T: 'static + ~const TypeGraphLayout> DeviceCopy
for ThreadBlockSharedSliceCudaRepresentation<T>
{
}

#[cfg(not(any(all(not(feature = "host"), target_os = "cuda"), doc)))]
#[doc(cfg(not(all(not(feature = "host"), target_os = "cuda"))))]
impl<T: 'static> ThreadBlockSharedSlice<T> {
// #[cfg(not(any(all(not(feature = "host"), target_os = "cuda"), doc)))]
// #[doc(cfg(not(all(not(feature = "host"), target_os = "cuda"))))]
impl<T: 'static + ~const TypeGraphLayout> ThreadBlockSharedSlice<T> {
#[cfg(any(not(target_os = "cuda"), doc))]
#[doc(cfg(not(target_os = "cuda")))]
#[must_use]
pub fn with_len(len: usize) -> Self {
pub fn new_uninit_with_len(len: usize) -> Self {
Self {
len,
byte_offset: 0,
marker: [],
marker: PhantomData::<T>,
}
}

/// Returns the number of elements in the slice.
///
/// Non-CUDA variant: reads the length stored in the `len` field.
#[cfg(not(target_os = "cuda"))]
#[must_use]
pub fn len(&self) -> usize {
self.len
}

/// Returns the number of elements in the slice.
///
/// CUDA variant: there is no `len` field, so the length is read from the
/// metadata of the `*mut [T]` pointer (for slice pointers the metadata is
/// the element count). The pointer is not dereferenced.
#[cfg(target_os = "cuda")]
#[must_use]
pub fn len(&self) -> usize {
core::ptr::metadata(self.shared)
}

#[must_use]
pub fn is_empty(&self) -> bool {
self.len == 0
self.len() == 0
}

/// Returns the stored raw slice pointer, including its length metadata.
///
/// Only available when compiling for `target_os = "cuda"` (or when building
/// documentation). The pointer is returned as-is and not dereferenced.
#[cfg(any(target_os = "cuda", doc))]
#[doc(cfg(target_os = "cuda"))]
#[must_use]
pub fn as_mut_slice_ptr(&self) -> *mut [T] {
self.shared
}

/// Returns a raw pointer to the first element of the slice.
///
/// Only available when compiling for `target_os = "cuda"` (or when building
/// documentation). The stored `*mut [T]` is cast to a thin `*mut T`, which
/// drops the length metadata; use `len()` to recover the element count.
#[cfg(any(target_os = "cuda", doc))]
#[doc(cfg(target_os = "cuda"))]
#[must_use]
pub fn as_mut_ptr(&self) -> *mut T {
self.shared.cast()
}
}

#[cfg(all(not(feature = "host"), target_os = "cuda"))]
#[doc(cfg(all(not(feature = "host"), target_os = "cuda")))]
impl<T: 'static> ThreadBlockSharedSlice<T> {
/// # Safety
///
/// The thread-block shared dynamic memory must be initialised once and
/// only once per kernel.
pub unsafe fn init() {
unsafe {
core::arch::asm!(
".shared .align {align} .b8 rust_cuda_dynamic_shared[];",
align = const(core::mem::align_of::<T>()),
);
}
unsafe impl<T: 'static + ~const TypeGraphLayout> RustToCuda for ThreadBlockSharedSlice<T> {
#[cfg(feature = "host")]
#[doc(cfg(feature = "host"))]
type CudaAllocation = crate::host::NullCudaAlloc;
type CudaRepresentation = ThreadBlockSharedSliceCudaRepresentation<T>;

/// Builds the CUDA representation of this slice (host feature only).
///
/// The representation carries only the slice length and a zero-length
/// `[T; 0]` marker. No new allocation is created here: a `NullCudaAlloc`
/// is combined with the caller-provided `alloc` and returned alongside the
/// representation. This implementation always returns `Ok`.
///
/// # Safety
///
/// The safety contract is the one declared on the `RustToCuda::borrow`
/// trait method, which is defined elsewhere.
#[cfg(feature = "host")]
#[doc(cfg(feature = "host"))]
unsafe fn borrow<A: crate::host::CudaAlloc>(
&self,
alloc: A,
) -> rustacuda::error::CudaResult<(
DeviceAccessible<Self::CudaRepresentation>,
crate::host::CombinedCudaAlloc<Self::CudaAllocation, A>,
)> {
Ok((
DeviceAccessible::from(ThreadBlockSharedSliceCudaRepresentation {
// Only the length is transferred.
len: self.len,
marker: [],
}),
// Prepend the (empty) allocation of this type to the caller's chain.
crate::host::CombinedCudaAlloc::new(crate::host::NullCudaAlloc, alloc),
))
}

/// # Safety
///
/// Exposing the [`ThreadBlockSharedSlice`] must be preceded by exactly one
/// call to [`ThreadBlockSharedSlice::init`] for the type `T` amongst
/// all `ThreadBlockSharedSlice<T>` that has the largest alignment.
pub unsafe fn with_uninit<F: FnOnce(*mut [T]) -> Q, Q>(self, inner: F) -> Q {
let base: *mut u8;

unsafe {
core::arch::asm!(
"cvta.shared.u64 {reg}, rust_cuda_dynamic_shared;",
reg = out(reg64) base,
);
}
/// Counterpart to `borrow` (host feature only).
///
/// Splits the combined allocation into this type's `NullCudaAlloc` and the
/// caller's remaining allocation, discards the former, and returns the
/// latter. `self` is not modified, and this implementation always returns
/// `Ok`.
///
/// # Safety
///
/// The safety contract is the one declared on the `RustToCuda::restore`
/// trait method, which is defined elsewhere.
#[cfg(feature = "host")]
#[doc(cfg(feature = "host"))]
unsafe fn restore<A: crate::host::CudaAlloc>(
&mut self,
alloc: crate::host::CombinedCudaAlloc<Self::CudaAllocation, A>,
) -> rustacuda::error::CudaResult<A> {
// The null allocation has nothing to release; only the tail is kept.
let (_null, alloc): (crate::host::NullCudaAlloc, A) = alloc.split();

Ok(alloc)
}
}

unsafe impl<T: 'static + ~const TypeGraphLayout> CudaAsRust
for ThreadBlockSharedSliceCudaRepresentation<T>
{
type RustRepresentation = ThreadBlockSharedSlice<T>;

#[cfg(any(not(feature = "host"), doc))]
#[doc(cfg(not(feature = "host")))]
unsafe fn as_rust(_this: &DeviceAccessible<Self>) -> Self::RustRepresentation {
todo!()

// unsafe {
// core::arch::asm!(
// ".shared .align {align} .b8 rust_cuda_dynamic_shared[];",
// align = const(core::mem::align_of::<T>()),
// );
// }

// let base: *mut u8;

let slice =
core::ptr::slice_from_raw_parts_mut(base.add(self.byte_offset).cast(), self.len);
// unsafe {
// core::arch::asm!(
// "cvta.shared.u64 {reg}, rust_cuda_dynamic_shared;",
// reg = out(reg64) base,
// );
// }

inner(slice)
// let slice = core::ptr::slice_from_raw_parts_mut(
// base.add(self.byte_offset).cast(), self.len,
// );
}
}
72 changes: 35 additions & 37 deletions src/utils/shared/static.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,41 @@ pub struct ThreadBlockSharedCudaRepresentation<T: 'static> {

unsafe impl<T: 'static> DeviceCopy for ThreadBlockSharedCudaRepresentation<T> {}

// Constructors and accessors for `ThreadBlockShared<T>`. `new_uninit` is
// defined twice with mutually exclusive `cfg`s, so exactly one of the two
// definitions is compiled for any given target.
impl<T: 'static> ThreadBlockShared<T> {
/// Non-CUDA variant: creates a value that only carries the type marker.
/// No memory is reserved or touched.
#[cfg(not(target_os = "cuda"))]
#[must_use]
pub fn new_uninit() -> Self {
Self {
marker: PhantomData::<T>,
}
}

/// CUDA variant: declares a `.shared` byte array sized and aligned for one
/// `T` via inline PTX and stores its address in the `shared` pointer.
///
/// Nothing is written to the array here.
/// NOTE(review): the contents are presumably uninitialised, as the name
/// suggests -- confirm before reading through the pointer.
#[cfg(target_os = "cuda")]
#[must_use]
pub fn new_uninit() -> Self {
let shared: *mut T;

// First asm line: declares `{reg}_rust_cuda_static_shared`, a `.b8` array
// of `size_of::<T>()` bytes with `align_of::<T>()` alignment, in the
// `.shared` state space. Second asm line: `cvta.shared.u64` converts the
// address of that symbol and writes it to the output register bound to
// `shared`.
// NOTE(review): the symbol name is prefixed with the substituted register
// name, presumably so that each expansion gets a distinct symbol -- TODO
// confirm.
unsafe {
core::arch::asm!(
".shared .align {align} .b8 {reg}_rust_cuda_static_shared[{size}];",
"cvta.shared.u64 {reg}, {reg}_rust_cuda_static_shared;",
reg = out(reg64) shared,
align = const(core::mem::align_of::<T>()),
size = const(core::mem::size_of::<T>()),
);
}

Self { shared }
}

/// Returns the stored raw pointer without dereferencing it.
///
/// Only available when compiling for `target_os = "cuda"` (or when building
/// documentation).
#[cfg(any(target_os = "cuda", doc))]
#[doc(cfg(target_os = "cuda"))]
#[must_use]
pub fn as_mut_ptr(&self) -> *mut T {
self.shared
}
}

unsafe impl<T: 'static + ~const TypeGraphLayout> RustToCuda for ThreadBlockShared<T> {
#[cfg(feature = "host")]
#[doc(cfg(feature = "host"))]
Expand Down Expand Up @@ -73,40 +108,3 @@ unsafe impl<T: 'static + ~const TypeGraphLayout> CudaAsRust
ThreadBlockShared::new_uninit()
}
}

// Impl block gated on the negation of
// `any(all(not(feature = "host"), target_os = "cuda"), doc)`: it is compiled
// whenever the build is neither a non-host CUDA build nor a documentation
// build.
#[cfg(not(any(all(not(feature = "host"), target_os = "cuda"), doc)))]
#[doc(cfg(not(all(not(feature = "host"), target_os = "cuda"))))]
impl<T: 'static> ThreadBlockShared<T> {
/// Creates a value that only carries the type marker for `T`; no memory is
/// reserved or touched.
#[must_use]
pub fn new_uninit() -> Self {
Self {
marker: PhantomData::<T>,
}
}
}

// Impl block gated on `any(all(not(feature = "host"), target_os = "cuda"), doc)`:
// it is compiled for non-host CUDA builds and for documentation builds.
#[cfg(any(all(not(feature = "host"), target_os = "cuda"), doc))]
#[doc(cfg(all(not(feature = "host"), target_os = "cuda")))]
impl<T: 'static> ThreadBlockShared<T> {
/// Declares a `.shared` byte array sized and aligned for one `T` via inline
/// PTX and stores its address in the `shared` pointer.
///
/// Nothing is written to the array here.
/// NOTE(review): the contents are presumably uninitialised, as the name
/// suggests -- confirm before reading through the pointer.
#[must_use]
pub fn new_uninit() -> Self {
let shared: *mut T;

// First asm line: declares `{reg}_rust_cuda_static_shared`, a `.b8` array
// of `size_of::<T>()` bytes with `align_of::<T>()` alignment, in the
// `.shared` state space. Second asm line: `cvta.shared.u64` converts the
// address of that symbol and writes it to the output register bound to
// `shared`.
unsafe {
core::arch::asm!(
".shared .align {align} .b8 {reg}_rust_cuda_static_shared[{size}];",
"cvta.shared.u64 {reg}, {reg}_rust_cuda_static_shared;",
reg = out(reg64) shared,
align = const(core::mem::align_of::<T>()),
size = const(core::mem::size_of::<T>()),
);
}

Self { shared }
}

/// Returns the stored raw pointer without dereferencing it.
#[must_use]
pub fn as_mut_ptr(&self) -> *mut T {
self.shared
}
}

0 comments on commit 3adaa9c

Please sign in to comment.