Skip to content

Commit

Permalink
Backup of progress on compile-time PTX checking
Browse files Browse the repository at this point in the history
  • Loading branch information
juntyr committed Mar 1, 2023
1 parent 8f0f0cf commit b90a92f
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 8 deletions.
15 changes: 9 additions & 6 deletions examples/single-source/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

extern crate alloc;

#[cfg(target_os = "cuda")]
use rc::utils::shared::r#static::ThreadBlockShared;

#[cfg(not(target_os = "cuda"))]
Expand Down Expand Up @@ -50,23 +51,25 @@ pub fn kernel<'a, T: rc::common::RustToCuda>(
#[kernel(pass = LendRustToCuda)] _z: &ShallowCopy<Wrapper<T>>,
#[kernel(pass = SafeDeviceCopy, jit)] _v @ _w: &'a core::sync::atomic::AtomicU64,
#[kernel(pass = LendRustToCuda)] _: Wrapper<T>,
#[kernel(pass = SafeDeviceCopy)] Tuple(_s, mut __t): Tuple,
#[kernel(pass = LendRustToCuda)] shared3: ThreadBlockShared<u32>,
#[kernel(pass = SafeDeviceCopy)] Tuple(s, mut __t): Tuple,
// #[kernel(pass = LendRustToCuda)] shared3: ThreadBlockShared<u32>,
) where
<T as rc::common::RustToCuda>::CudaRepresentation: rc::safety::StackOnly,
{
let shared: ThreadBlockShared<[Tuple; 3]> = ThreadBlockShared::new_uninit();
let shared2: ThreadBlockShared<[Tuple; 3]> = ThreadBlockShared::new_uninit();

#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
unsafe {
(*shared.as_mut_ptr().cast::<Tuple>().add(1)).0 = 42;
(*shared.as_mut_ptr().cast::<Tuple>().add(1)).0 = (f64::from(s) * 2.0) as u32;
}
unsafe {
(*shared2.as_mut_ptr().cast::<Tuple>().add(2)).1 = 24;
}
unsafe {
*shared3.as_mut_ptr() = 12;
}
unsafe { core::arch::asm!("hi") }
// unsafe {
// *shared3.as_mut_ptr() = 12;
// }
}

#[cfg(not(target_os = "cuda"))]
Expand Down
2 changes: 2 additions & 0 deletions rust-cuda-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.1.0"
authors = ["Juniper Tyree <juniper.langenstein@helsinki.fi>"]
license = "MIT OR Apache-2.0"
edition = "2021"
links = "libnvptxcompiler_static"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -24,3 +25,4 @@ colored = "2.0"

seahash = "4.1"
ptx-builder = { git = "https://github.com/juntyr/rust-ptx-builder", rev = "1f1f49d" }
ptx_compiler = "0.1"
3 changes: 3 additions & 0 deletions rust-cuda-derive/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/// Build script entry point.
///
/// Prints a single Cargo build-script directive on stdout that asks rustc to
/// link this crate against the `nvptxcompiler_static` library.
fn main() {
    // Cargo reads `cargo:`-prefixed lines from the build script's stdout.
    let link_directive = "cargo:rustc-link-lib=nvptxcompiler_static";
    println!("{link_directive}");
}
3 changes: 3 additions & 0 deletions rust-cuda-derive/src/kernel/link/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::path::PathBuf;
#[allow(clippy::module_name_repetitions)]
pub(super) struct LinkKernelConfig {
pub(super) kernel: syn::Ident,
pub(super) kernel_hash: syn::Ident,
pub(super) args: syn::Ident,
pub(super) crate_name: String,
pub(super) crate_path: PathBuf,
Expand All @@ -12,6 +13,7 @@ pub(super) struct LinkKernelConfig {
impl syn::parse::Parse for LinkKernelConfig {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let kernel: syn::Ident = input.parse()?;
let kernel_hash: syn::Ident = input.parse()?;
let args: syn::Ident = input.parse()?;
let name: syn::LitStr = input.parse()?;
let path: syn::LitStr = input.parse()?;
Expand All @@ -37,6 +39,7 @@ impl syn::parse::Parse for LinkKernelConfig {

Ok(Self {
kernel,
kernel_hash,
args,
crate_name: name.value(),
crate_path: PathBuf::from(path.value()),
Expand Down
122 changes: 121 additions & 1 deletion rust-cuda-derive/src/kernel/link/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use std::{
env, fs,
env,
ffi::CString,
fs,
io::{Read, Write},
mem::MaybeUninit,
os::raw::c_int,
path::{Path, PathBuf},
ptr::addr_of_mut,
sync::atomic::{AtomicBool, Ordering},
};

Expand All @@ -11,6 +16,7 @@ use ptx_builder::{
builder::{BuildStatus, Builder, MessageFormat, Profile},
error::{BuildErrorKind, Error, Result},
};
use ptx_compiler::sys::size_t;

use super::utils::skip_kernel_compilation;

Expand Down Expand Up @@ -56,6 +62,7 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {

let LinkKernelConfig {
kernel,
kernel_hash,
args,
crate_name,
crate_path,
Expand Down Expand Up @@ -192,6 +199,119 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
kernel_ptx.replace_range(type_layout_start..type_layout_end, "");
}

let mut compiler = MaybeUninit::uninit();
let r = unsafe {
ptx_compiler::sys::nvPTXCompilerCreate(
compiler.as_mut_ptr(),
kernel_ptx.len() as size_t,
kernel_ptx.as_ptr().cast(),
)
};
emit_call_site_warning!("PTX compiler create result {}", r);
let compiler = unsafe { compiler.assume_init() };

let mut major = 0;
let mut minor = 0;
let r = unsafe {
ptx_compiler::sys::nvPTXCompilerGetVersion(addr_of_mut!(major), addr_of_mut!(minor))
};
emit_call_site_warning!("PTX version result {}", r);
emit_call_site_warning!("PTX compiler version {}.{}", major, minor);

let kernel_name = if specialisation.is_empty() {
format!("{kernel_hash}_kernel")
} else {
format!(
"{kernel_hash}_kernel_{:016x}",
seahash::hash(specialisation.as_bytes())
)
};

let options = vec![
CString::new("--entry").unwrap(),
CString::new(kernel_name).unwrap(),
CString::new("--verbose").unwrap(),
CString::new("--warn-on-double-precision-use").unwrap(),
CString::new("--warn-on-local-memory-usage").unwrap(),
CString::new("--warn-on-spills").unwrap(),
];
let options_ptrs = options.iter().map(|o| o.as_ptr()).collect::<Vec<_>>();

let r = unsafe {
ptx_compiler::sys::nvPTXCompilerCompile(
compiler,
options_ptrs.len() as c_int,
options_ptrs.as_ptr().cast(),
)
};
emit_call_site_warning!("PTX compile result {}", r);

let mut info_log_size = 0;
let r = unsafe {
ptx_compiler::sys::nvPTXCompilerGetInfoLogSize(compiler, addr_of_mut!(info_log_size))
};
emit_call_site_warning!("PTX info log size result {}", r);
#[allow(clippy::cast_possible_truncation)]
let mut info_log: Vec<u8> = Vec::with_capacity(info_log_size as usize);
if info_log_size > 0 {
let r = unsafe {
ptx_compiler::sys::nvPTXCompilerGetInfoLog(compiler, info_log.as_mut_ptr().cast())
};
emit_call_site_warning!("PTX info log content result {}", r);
#[allow(clippy::cast_possible_truncation)]
unsafe {
info_log.set_len(info_log_size as usize);
}
}
let info_log = String::from_utf8_lossy(&info_log);

let mut error_log_size = 0;
let r = unsafe {
ptx_compiler::sys::nvPTXCompilerGetErrorLogSize(compiler, addr_of_mut!(error_log_size))
};
emit_call_site_warning!("PTX error log size result {}", r);
#[allow(clippy::cast_possible_truncation)]
let mut error_log: Vec<u8> = Vec::with_capacity(error_log_size as usize);
if error_log_size > 0 {
let r = unsafe {
ptx_compiler::sys::nvPTXCompilerGetErrorLog(compiler, error_log.as_mut_ptr().cast())
};
emit_call_site_warning!("PTX error log content result {}", r);
#[allow(clippy::cast_possible_truncation)]
unsafe {
error_log.set_len(error_log_size as usize);
}
}
let error_log = String::from_utf8_lossy(&error_log);

// Ensure the compiler is not dropped
let mut compiler = MaybeUninit::new(compiler);
let r = unsafe { ptx_compiler::sys::nvPTXCompilerDestroy(compiler.as_mut_ptr()) };
emit_call_site_warning!("PTX compiler destroy result {}", r);

if !info_log.is_empty() {
emit_call_site_warning!("PTX compiler info log:\n{}", info_log);
}
if !error_log.is_empty() {
let mut max_lines = kernel_ptx.chars().filter(|c| *c == '\n').count() + 1;
let mut indent = 0;
while max_lines > 0 {
max_lines /= 10;
indent += 1;
}

abort_call_site!(
"PTX compiler error log:\n{}\nPTX source:\n{}",
error_log,
kernel_ptx
.lines()
.enumerate()
.map(|(i, l)| format!("{:indent$}| {l}", i + 1))
.collect::<Vec<_>>()
.join("\n")
);
}

(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ pub(super) fn quote_get_ptx_str(
quote! {
fn get_ptx_str() -> &'static str {
#crate_path::host::link_kernel!{
#func_ident #args #crate_name #crate_manifest_dir #generic_start_token
#func_ident #func_ident_hash #args #crate_name #crate_manifest_dir #generic_start_token
#($#macro_type_ids),*
#generic_close_token
}
Expand Down
7 changes: 7 additions & 0 deletions src/safety/device_copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,11 @@ mod sealed {
for crate::utils::device_copy::SafeDeviceCopyWrapper<T>
{
}

// Only unsafe aliasing is possible since both only expose raw pointers
// impl<T: 'static> SafeDeviceCopy for
// crate::utils::shared::r#static::ThreadBlockShared<T> {}
// impl<T: 'static + ~const const_type_layout::TypeGraphLayout>
// SafeDeviceCopy for crate::utils::shared::slice::ThreadBlockSharedSlice<T>
// {}
}
6 changes: 6 additions & 0 deletions src/safety/no_aliasing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,10 @@ mod private {
{
}
impl<T> NoAliasing for crate::utils::aliasing::SplitSliceOverCudaThreadsDynamicStride<T> {}

// Only unsafe aliasing is possible since both only expose raw pointers
// impl<T: 'static> NoAliasing for
// crate::utils::shared::r#static::ThreadBlockShared<T> {}
// impl<T: 'static + ~const const_type_layout::TypeGraphLayout> NoAliasing
// for crate::utils::shared::slice::ThreadBlockSharedSlice<T> {}
}

0 comments on commit b90a92f

Please sign in to comment.