From 44a0220beb8eb321208da120c7c65cd5a49104be Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Sun, 14 May 2023 09:30:51 +0000 Subject: [PATCH] Simplified the kernel parameter layout extraction from PTX --- rust-cuda-derive/src/kernel/link/mod.rs | 179 +++++++++--------- rust-cuda-derive/src/kernel/mod.rs | 3 + .../kernel/wrapper/generate/cuda_wrapper.rs | 23 +-- rust-cuda-ptx-jit/src/device.rs | 2 +- 4 files changed, 103 insertions(+), 104 deletions(-) diff --git a/rust-cuda-derive/src/kernel/link/mod.rs b/rust-cuda-derive/src/kernel/link/mod.rs index 220218aa7..b66efd76a 100644 --- a/rust-cuda-derive/src/kernel/link/mod.rs +++ b/rust-cuda-derive/src/kernel/link/mod.rs @@ -21,6 +21,7 @@ use ptx_builder::{ use super::{ lints::{LintLevel, PtxLint}, utils::skip_kernel_compilation, + KERNEL_TYPE_USE_END_CANARY, KERNEL_TYPE_USE_START_CANARY, }; mod config; @@ -66,14 +67,14 @@ pub fn check_kernel(tokens: TokenStream) -> TokenStream { quote!(::core::result::Result::Ok(())).into() } -#[allow(clippy::module_name_repetitions, clippy::too_many_lines)] +#[allow(clippy::module_name_repetitions)] pub fn link_kernel(tokens: TokenStream) -> TokenStream { proc_macro_error::set_dummy(quote! { const PTX_STR: &'static str = "ERROR in this PTX compilation"; }); let LinkKernelConfig { - kernel, + kernel: _kernel, kernel_hash, args, crate_name, @@ -111,116 +112,110 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream { .into() }; - let kernel_layout_name = if specialisation.is_empty() { - format!("{kernel}_type_layout_kernel") - } else { - format!( - "{kernel}_type_layout_kernel_{:016x}", - seahash::hash(specialisation.as_bytes()) - ) - }; + let type_layouts = extract_ptx_kernel_layout(&mut kernel_ptx); + remove_kernel_type_use_from_ptx(&mut kernel_ptx); - let mut type_layouts = Vec::new(); + check_kernel_ptx_and_report( + &kernel_ptx, + Specialisation::Link(&specialisation), + &kernel_hash, + &ptx_lint_levels, + ); + + (quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into() +} - let type_layout_start_pattern = format!("\n\t// .globl\t{kernel_layout_name}"); +fn extract_ptx_kernel_layout(kernel_ptx: &mut String) -> Vec { + const BEFORE_PARAM_PATTERN: &str = "\n.global .align 1 .b8 "; + const PARAM_LEN_PATTERN: &str = "["; + const LEN_BYTES_PATTERN: &str = "] = {"; + const AFTER_BYTES_PATTERN: &str = "};\n"; + const BYTES_PARAM_PATTERN: &str = "};"; - if let Some(type_layout_start) = kernel_ptx.find(&type_layout_start_pattern) { - const BEFORE_PARAM_PATTERN: &str = "\n.global .align 1 .b8 "; - const PARAM_LEN_PATTERN: &str = "["; - const LEN_BYTES_PATTERN: &str = "] = {"; - const AFTER_BYTES_PATTERN: &str = "};\n"; - const BYTES_PARAM_PATTERN: &str = "};"; + let mut type_layouts = Vec::new(); - let after_type_layout_start = type_layout_start + type_layout_start_pattern.len(); + while let Some(type_layout_start) = kernel_ptx.find(BEFORE_PARAM_PATTERN) { + let param_start = type_layout_start + BEFORE_PARAM_PATTERN.len(); - let Some(type_layout_middle) = kernel_ptx[after_type_layout_start..] - .find(&format!(".visible .entry {kernel_layout_name}")).map(|i| after_type_layout_start + i) - else { + let Some(len_start_offset) = kernel_ptx[param_start..].find(PARAM_LEN_PATTERN) else { abort_call_site!( - "Kernel compilation generated invalid PTX: incomplete type layout information" + "Kernel compilation generated invalid PTX: missing type layout data" ) }; + let len_start = param_start + len_start_offset + PARAM_LEN_PATTERN.len(); - let mut next_type_layout = after_type_layout_start; + let Some(bytes_start_offset) = kernel_ptx[len_start..].find(LEN_BYTES_PATTERN) else { + abort_call_site!( + "Kernel compilation generated invalid PTX: missing type layout length" + ) + }; + let bytes_start = len_start + bytes_start_offset + LEN_BYTES_PATTERN.len(); - while let Some(param_start_offset) = - kernel_ptx[next_type_layout..type_layout_middle].find(BEFORE_PARAM_PATTERN) - { - let param_start = next_type_layout + param_start_offset + BEFORE_PARAM_PATTERN.len(); + let Some(bytes_end_offset) = kernel_ptx[bytes_start..].find(AFTER_BYTES_PATTERN) else { + abort_call_site!( + "Kernel compilation generated invalid PTX: invalid type layout data" + ) + }; + let param = &kernel_ptx[param_start..(param_start + len_start_offset)]; + let len = &kernel_ptx[len_start..(len_start + bytes_start_offset)]; + let bytes = &kernel_ptx[bytes_start..(bytes_start + bytes_end_offset)]; - if let Some(len_start_offset) = - kernel_ptx[param_start..type_layout_middle].find(PARAM_LEN_PATTERN) - { - let len_start = param_start + len_start_offset + PARAM_LEN_PATTERN.len(); + let param = quote::format_ident!("{}", param); - if let Some(bytes_start_offset) = - kernel_ptx[len_start..type_layout_middle].find(LEN_BYTES_PATTERN) - { - let bytes_start = len_start + bytes_start_offset + LEN_BYTES_PATTERN.len(); + let Ok(len) = len.parse::() else { + abort_call_site!( + "Kernel compilation generated invalid PTX: invalid type layout length" + ) + }; + let Ok(bytes) = bytes.split(", ").map(std::str::FromStr::from_str).collect::, _>>() else { + abort_call_site!( + "Kernel compilation generated invalid PTX: invalid type layout byte" + ) + }; - if let Some(bytes_end_offset) = - kernel_ptx[bytes_start..type_layout_middle].find(AFTER_BYTES_PATTERN) - { - let param = &kernel_ptx[param_start..(param_start + len_start_offset)]; - let len = &kernel_ptx[len_start..(len_start + bytes_start_offset)]; - let bytes = &kernel_ptx[bytes_start..(bytes_start + bytes_end_offset)]; - - let param = quote::format_ident!("{}", param); - - let Ok(len) = len.parse::() else { - abort_call_site!( - "Kernel compilation generated invalid PTX: invalid type layout length" - ) - }; - let Ok(bytes) = bytes.split(", ").map(std::str::FromStr::from_str).collect::, _>>() else { - abort_call_site!( - "Kernel compilation generated invalid PTX: invalid type layout byte" - ) - }; - - if bytes.len() != len { - abort_call_site!( - "Kernel compilation generated invalid PTX: type layout length \ - mismatch" - ); - } - - let byte_str = syn::LitByteStr::new(&bytes, proc_macro2::Span::call_site()); - - type_layouts.push(quote! { - const #param: &[u8; #len] = #byte_str; - }); - - next_type_layout = - bytes_start + bytes_end_offset + BYTES_PARAM_PATTERN.len(); - } else { - next_type_layout = bytes_start; - } - } else { - next_type_layout = len_start; - } - } else { - next_type_layout = param_start; - } + if bytes.len() != len { + abort_call_site!( + "Kernel compilation generated invalid PTX: type layout length mismatch" + ); } - let Some(type_layout_end) = kernel_ptx[type_layout_middle..].find('}').map(|i| { - type_layout_middle + i + '}'.len_utf8() - }) else { - abort_call_site!("Kernel compilation generated invalid PTX") - }; + let byte_str = syn::LitByteStr::new(&bytes, proc_macro2::Span::call_site()); + + type_layouts.push(quote! { + const #param: &[u8; #len] = #byte_str; + }); + + let type_layout_end = bytes_start + bytes_end_offset + BYTES_PARAM_PATTERN.len(); kernel_ptx.replace_range(type_layout_start..type_layout_end, ""); } - check_kernel_ptx_and_report( - &kernel_ptx, - Specialisation::Link(&specialisation), - &kernel_hash, - &ptx_lint_levels, - ); + type_layouts +} - (quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into() +fn remove_kernel_type_use_from_ptx(kernel_ptx: &mut String) { + while let Some(kernel_type_layout_start) = kernel_ptx.find(KERNEL_TYPE_USE_START_CANARY) { + let kernel_type_layout_start = kernel_ptx[..kernel_type_layout_start] + .rfind('\n') + .unwrap_or(kernel_type_layout_start); + + let Some(kernel_type_layout_end_offset) = kernel_ptx[ + kernel_type_layout_start.. + ].find(KERNEL_TYPE_USE_END_CANARY) else { + abort_call_site!( + "Kernel compilation generated invalid PTX: incomplete type layout use section" + ); + }; + + let kernel_type_layout_end_offset = kernel_type_layout_end_offset + + kernel_ptx[kernel_type_layout_start + kernel_type_layout_end_offset..] + .find('\n') + .unwrap_or(KERNEL_TYPE_USE_END_CANARY.len()); + + let kernel_type_layout_end = kernel_type_layout_start + kernel_type_layout_end_offset; + + kernel_ptx.replace_range(kernel_type_layout_start..kernel_type_layout_end, ""); + } } #[allow(clippy::too_many_lines)] diff --git a/rust-cuda-derive/src/kernel/mod.rs b/rust-cuda-derive/src/kernel/mod.rs index 6dff13380..9e3a80789 100644 --- a/rust-cuda-derive/src/kernel/mod.rs +++ b/rust-cuda-derive/src/kernel/mod.rs @@ -4,3 +4,6 @@ pub mod wrapper; mod lints; mod utils; + +const KERNEL_TYPE_USE_START_CANARY: &str = "// //"; +const KERNEL_TYPE_USE_END_CANARY: &str = "// //"; diff --git a/rust-cuda-derive/src/kernel/wrapper/generate/cuda_wrapper.rs b/rust-cuda-derive/src/kernel/wrapper/generate/cuda_wrapper.rs index 29473858e..04e396d70 100644 --- a/rust-cuda-derive/src/kernel/wrapper/generate/cuda_wrapper.rs +++ b/rust-cuda-derive/src/kernel/wrapper/generate/cuda_wrapper.rs @@ -2,7 +2,10 @@ use proc_macro2::TokenStream; use quote::quote_spanned; use syn::spanned::Spanned; -use super::super::{FuncIdent, FunctionInputs, InputCudaType, KernelConfig}; +use super::super::{ + super::{KERNEL_TYPE_USE_END_CANARY, KERNEL_TYPE_USE_START_CANARY}, + FuncIdent, FunctionInputs, InputCudaType, KernelConfig, +}; #[allow(clippy::too_many_lines)] pub(in super::super) fn quote_cuda_wrapper( @@ -96,29 +99,27 @@ pub(in super::super) fn quote_cuda_wrapper( syn::FnArg::Receiver(_) => unreachable!(), }); - let func_type_layout_ident = quote::format_ident!("{}_type_layout", func_ident); - quote! { #[cfg(target_os = "cuda")] #[#crate_path::device::specialise_kernel_entry(#args)] #[no_mangle] #(#func_attrs)* - pub unsafe extern "ptx-kernel" fn #func_type_layout_ident(#(#func_params: &mut &[u8]),*) { + pub unsafe extern "ptx-kernel" fn #func_ident_hash(#(#ptx_func_inputs),*) { + unsafe { + ::core::arch::asm!(#KERNEL_TYPE_USE_START_CANARY); + } #( #[no_mangle] static #func_layout_params: [ u8; #crate_path::const_type_layout::serialised_type_graph_len::<#ptx_func_types>() ] = #crate_path::const_type_layout::serialise_type_graph::<#ptx_func_types>(); - *#func_params = &#func_layout_params; + unsafe { ::core::ptr::read_volatile(&#func_layout_params[0]) }; )* - } + unsafe { + ::core::arch::asm!(#KERNEL_TYPE_USE_END_CANARY); + } - #[cfg(target_os = "cuda")] - #[#crate_path::device::specialise_kernel_entry(#args)] - #[no_mangle] - #(#func_attrs)* - pub unsafe extern "ptx-kernel" fn #func_ident_hash(#(#ptx_func_inputs),*) { #[deny(improper_ctypes)] mod __rust_cuda_ffi_safe_assert { use super::#args; diff --git a/rust-cuda-ptx-jit/src/device.rs b/rust-cuda-ptx-jit/src/device.rs index 533021b90..c647a65eb 100644 --- a/rust-cuda-ptx-jit/src/device.rs +++ b/rust-cuda-ptx-jit/src/device.rs @@ -5,7 +5,7 @@ macro_rules! PtxJITConstLoad { ([$index:literal] => $reference:expr) => { unsafe { ::core::arch::asm!( - concat!("// //"), + ::core::concat!("// //"), in(reg32) *($reference as *const _ as *const u32), ) }