Skip to content

Commit

Permalink
Simplified the kernel parameter layout extraction from PTX
Browse files Browse the repository at this point in the history
  • Loading branch information
juntyr committed Jul 25, 2023
1 parent f06c86e commit 8a37039
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 104 deletions.
179 changes: 87 additions & 92 deletions rust-cuda-derive/src/kernel/link/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use ptx_builder::{
use super::{
lints::{LintLevel, PtxLint},
utils::skip_kernel_compilation,
KERNEL_TYPE_USE_END_CANARY, KERNEL_TYPE_USE_START_CANARY,
};

mod config;
Expand Down Expand Up @@ -66,14 +67,14 @@ pub fn check_kernel(tokens: TokenStream) -> TokenStream {
quote!(::core::result::Result::Ok(())).into()
}

#[allow(clippy::module_name_repetitions, clippy::too_many_lines)]
#[allow(clippy::module_name_repetitions)]
pub fn link_kernel(tokens: TokenStream) -> TokenStream {
proc_macro_error::set_dummy(quote! {
const PTX_STR: &'static str = "ERROR in this PTX compilation";
});

let LinkKernelConfig {
kernel,
kernel: _kernel,
kernel_hash,
args,
crate_name,
Expand Down Expand Up @@ -111,116 +112,110 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
.into()
};

let kernel_layout_name = if specialisation.is_empty() {
format!("{kernel}_type_layout_kernel")
} else {
format!(
"{kernel}_type_layout_kernel_{:016x}",
seahash::hash(specialisation.as_bytes())
)
};
let type_layouts = extract_ptx_kernel_layout(&mut kernel_ptx);
remove_kernel_type_use_from_ptx(&mut kernel_ptx);

let mut type_layouts = Vec::new();
check_kernel_ptx_and_report(
&kernel_ptx,
Specialisation::Link(&specialisation),
&kernel_hash,
&ptx_lint_levels,
);

(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
}

let type_layout_start_pattern = format!("\n\t// .globl\t{kernel_layout_name}");
fn extract_ptx_kernel_layout(kernel_ptx: &mut String) -> Vec<proc_macro2::TokenStream> {
const BEFORE_PARAM_PATTERN: &str = "\n.global .align 1 .b8 ";
const PARAM_LEN_PATTERN: &str = "[";
const LEN_BYTES_PATTERN: &str = "] = {";
const AFTER_BYTES_PATTERN: &str = "};\n";
const BYTES_PARAM_PATTERN: &str = "};";

if let Some(type_layout_start) = kernel_ptx.find(&type_layout_start_pattern) {
const BEFORE_PARAM_PATTERN: &str = "\n.global .align 1 .b8 ";
const PARAM_LEN_PATTERN: &str = "[";
const LEN_BYTES_PATTERN: &str = "] = {";
const AFTER_BYTES_PATTERN: &str = "};\n";
const BYTES_PARAM_PATTERN: &str = "};";
let mut type_layouts = Vec::new();

let after_type_layout_start = type_layout_start + type_layout_start_pattern.len();
while let Some(type_layout_start) = kernel_ptx.find(BEFORE_PARAM_PATTERN) {
let param_start = type_layout_start + BEFORE_PARAM_PATTERN.len();

let Some(type_layout_middle) = kernel_ptx[after_type_layout_start..]
.find(&format!(".visible .entry {kernel_layout_name}")).map(|i| after_type_layout_start + i)
else {
let Some(len_start_offset) = kernel_ptx[param_start..].find(PARAM_LEN_PATTERN) else {
abort_call_site!(
"Kernel compilation generated invalid PTX: incomplete type layout information"
"Kernel compilation generated invalid PTX: missing type layout data"
)
};
let len_start = param_start + len_start_offset + PARAM_LEN_PATTERN.len();

let mut next_type_layout = after_type_layout_start;
let Some(bytes_start_offset) = kernel_ptx[len_start..].find(LEN_BYTES_PATTERN) else {
abort_call_site!(
"Kernel compilation generated invalid PTX: missing type layout length"
)
};
let bytes_start = len_start + bytes_start_offset + LEN_BYTES_PATTERN.len();

while let Some(param_start_offset) =
kernel_ptx[next_type_layout..type_layout_middle].find(BEFORE_PARAM_PATTERN)
{
let param_start = next_type_layout + param_start_offset + BEFORE_PARAM_PATTERN.len();
let Some(bytes_end_offset) = kernel_ptx[bytes_start..].find(AFTER_BYTES_PATTERN) else {
abort_call_site!(
"Kernel compilation generated invalid PTX: invalid type layout data"
)
};
let param = &kernel_ptx[param_start..(param_start + len_start_offset)];
let len = &kernel_ptx[len_start..(len_start + bytes_start_offset)];
let bytes = &kernel_ptx[bytes_start..(bytes_start + bytes_end_offset)];

if let Some(len_start_offset) =
kernel_ptx[param_start..type_layout_middle].find(PARAM_LEN_PATTERN)
{
let len_start = param_start + len_start_offset + PARAM_LEN_PATTERN.len();
let param = quote::format_ident!("{}", param);

if let Some(bytes_start_offset) =
kernel_ptx[len_start..type_layout_middle].find(LEN_BYTES_PATTERN)
{
let bytes_start = len_start + bytes_start_offset + LEN_BYTES_PATTERN.len();
let Ok(len) = len.parse::<usize>() else {
abort_call_site!(
"Kernel compilation generated invalid PTX: invalid type layout length"
)
};
let Ok(bytes) = bytes.split(", ").map(std::str::FromStr::from_str).collect::<Result<Vec<u8>, _>>() else {
abort_call_site!(
"Kernel compilation generated invalid PTX: invalid type layout byte"
)
};

if let Some(bytes_end_offset) =
kernel_ptx[bytes_start..type_layout_middle].find(AFTER_BYTES_PATTERN)
{
let param = &kernel_ptx[param_start..(param_start + len_start_offset)];
let len = &kernel_ptx[len_start..(len_start + bytes_start_offset)];
let bytes = &kernel_ptx[bytes_start..(bytes_start + bytes_end_offset)];

let param = quote::format_ident!("{}", param);

let Ok(len) = len.parse::<usize>() else {
abort_call_site!(
"Kernel compilation generated invalid PTX: invalid type layout length"
)
};
let Ok(bytes) = bytes.split(", ").map(std::str::FromStr::from_str).collect::<Result<Vec<u8>, _>>() else {
abort_call_site!(
"Kernel compilation generated invalid PTX: invalid type layout byte"
)
};

if bytes.len() != len {
abort_call_site!(
"Kernel compilation generated invalid PTX: type layout length \
mismatch"
);
}

let byte_str = syn::LitByteStr::new(&bytes, proc_macro2::Span::call_site());

type_layouts.push(quote! {
const #param: &[u8; #len] = #byte_str;
});

next_type_layout =
bytes_start + bytes_end_offset + BYTES_PARAM_PATTERN.len();
} else {
next_type_layout = bytes_start;
}
} else {
next_type_layout = len_start;
}
} else {
next_type_layout = param_start;
}
if bytes.len() != len {
abort_call_site!(
"Kernel compilation generated invalid PTX: type layout length mismatch"
);
}

let Some(type_layout_end) = kernel_ptx[type_layout_middle..].find('}').map(|i| {
type_layout_middle + i + '}'.len_utf8()
}) else {
abort_call_site!("Kernel compilation generated invalid PTX")
};
let byte_str = syn::LitByteStr::new(&bytes, proc_macro2::Span::call_site());

type_layouts.push(quote! {
const #param: &[u8; #len] = #byte_str;
});

let type_layout_end = bytes_start + bytes_end_offset + BYTES_PARAM_PATTERN.len();

kernel_ptx.replace_range(type_layout_start..type_layout_end, "");
}

check_kernel_ptx_and_report(
&kernel_ptx,
Specialisation::Link(&specialisation),
&kernel_hash,
&ptx_lint_levels,
);
type_layouts
}

(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
fn remove_kernel_type_use_from_ptx(kernel_ptx: &mut String) {
while let Some(kernel_type_layout_start) = kernel_ptx.find(KERNEL_TYPE_USE_START_CANARY) {
let kernel_type_layout_start = kernel_ptx[..kernel_type_layout_start]
.rfind('\n')
.unwrap_or(kernel_type_layout_start);

let Some(kernel_type_layout_end_offset) = kernel_ptx[
kernel_type_layout_start..
].find(KERNEL_TYPE_USE_END_CANARY) else {
abort_call_site!(
"Kernel compilation generated invalid PTX: incomplete type layout use section"
);
};

let kernel_type_layout_end_offset = kernel_type_layout_end_offset
+ kernel_ptx[kernel_type_layout_start + kernel_type_layout_end_offset..]
.find('\n')
.unwrap_or(KERNEL_TYPE_USE_END_CANARY.len());

let kernel_type_layout_end = kernel_type_layout_start + kernel_type_layout_end_offset;

kernel_ptx.replace_range(kernel_type_layout_start..kernel_type_layout_end, "");
}
}

#[allow(clippy::too_many_lines)]
Expand Down
3 changes: 3 additions & 0 deletions rust-cuda-derive/src/kernel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ pub mod wrapper;

mod lints;
mod utils;

const KERNEL_TYPE_USE_START_CANARY: &str = "// <rust-cuda-kernel-param-type-use-start> //";
const KERNEL_TYPE_USE_END_CANARY: &str = "// <rust-cuda-kernel-param-type-use-end> //";
23 changes: 12 additions & 11 deletions rust-cuda-derive/src/kernel/wrapper/generate/cuda_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use proc_macro2::TokenStream;
use quote::quote_spanned;
use syn::spanned::Spanned;

use super::super::{FuncIdent, FunctionInputs, InputCudaType, KernelConfig};
use super::super::{
super::{KERNEL_TYPE_USE_END_CANARY, KERNEL_TYPE_USE_START_CANARY},
FuncIdent, FunctionInputs, InputCudaType, KernelConfig,
};

#[allow(clippy::too_many_lines)]
pub(in super::super) fn quote_cuda_wrapper(
Expand Down Expand Up @@ -96,29 +99,27 @@ pub(in super::super) fn quote_cuda_wrapper(
syn::FnArg::Receiver(_) => unreachable!(),
});

let func_type_layout_ident = quote::format_ident!("{}_type_layout", func_ident);

quote! {
#[cfg(target_os = "cuda")]
#[#crate_path::device::specialise_kernel_entry(#args)]
#[no_mangle]
#(#func_attrs)*
pub unsafe extern "ptx-kernel" fn #func_type_layout_ident(#(#func_params: &mut &[u8]),*) {
pub unsafe extern "ptx-kernel" fn #func_ident_hash(#(#ptx_func_inputs),*) {
unsafe {
::core::arch::asm!(#KERNEL_TYPE_USE_START_CANARY);
}
#(
#[no_mangle]
static #func_layout_params: [
u8; #crate_path::const_type_layout::serialised_type_graph_len::<#ptx_func_types>()
] = #crate_path::const_type_layout::serialise_type_graph::<#ptx_func_types>();

*#func_params = &#func_layout_params;
unsafe { ::core::ptr::read_volatile(&#func_layout_params[0]) };
)*
}
unsafe {
::core::arch::asm!(#KERNEL_TYPE_USE_END_CANARY);
}

#[cfg(target_os = "cuda")]
#[#crate_path::device::specialise_kernel_entry(#args)]
#[no_mangle]
#(#func_attrs)*
pub unsafe extern "ptx-kernel" fn #func_ident_hash(#(#ptx_func_inputs),*) {
#[deny(improper_ctypes)]
mod __rust_cuda_ffi_safe_assert {
use super::#args;
Expand Down
2 changes: 1 addition & 1 deletion rust-cuda-ptx-jit/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ macro_rules! PtxJITConstLoad {
([$index:literal] => $reference:expr) => {
unsafe {
::core::arch::asm!(
concat!("// <rust-cuda-ptx-jit-const-load-{}-", $index, "> //"),
::core::concat!("// <rust-cuda-ptx-jit-const-load-{}-", $index, "> //"),
in(reg32) *($reference as *const _ as *const u32),
)
}
Expand Down

0 comments on commit 8a37039

Please sign in to comment.