diff --git a/build.rs b/build.rs index 31e05f4..6f921ef 100644 --- a/build.rs +++ b/build.rs @@ -2,9 +2,12 @@ use core::panic; use std::env; use std::path::PathBuf; +const CXX_STANDARD: &str = "c++17"; + const WRAPPER_NAME: &str = "xsf_wrapper"; const WRAPPER_PREAMBLE: &str = "// generated by build.rs -- do not edit\n\n"; +const XSF_DIR: &str = "xsf"; const XSF_HEADERS: &[&str] = &[ "airy.h", "alg.h", @@ -325,6 +328,62 @@ const XSF_RENAME: &[(&str, &str)] = &[ ("log1p", "log1p_"), ]; +// complex helper functions +const DECL_COMPLEX_HELPERS: &str = r#" +std::complex complex__new(double re, double im); +void complex__values(std::complex z, double &re, double &im); +"#; + +const IMPL_COMPLEX_HELPERS: &str = r#" +std::complex complex__new(double re, double im) { + return std::complex(re, im); +} +void complex__values(std::complex z, double &re, double &im) { + re = std::real(z); + im = std::imag(z); +} +"#; + +// cevalpoly +const DECL_CEVALPOLY: &str = r#" +std::complex cevalpoly(const double *coeffs, int degree, std::complex z); +"#; + +const IMPL_CEVALPOLY: &str = r#" +std::complex cevalpoly(const double *coeffs, int degree, std::complex z) { + return xsf::cevalpoly(coeffs, degree, z); +} +"#; + +// assoc_legendre_p +const DECL_ASSOC_LEGENDRE_P: &str = r#" +double assoc_legendre_p_0(int n, int m, double z, int bc); +double assoc_legendre_p_1(int n, int m, double z, int bc); +std::complex assoc_legendre_p_0_1(int n, int m, std::complex z, int bc); +std::complex assoc_legendre_p_1_1(int n, int m, std::complex z, int bc); +"#; + +const IMPL_ASSOC_LEGENDRE_P: &str = r#" +double assoc_legendre_p_0(int n, int m, double z, int bc) { + return xsf::assoc_legendre_p(xsf::assoc_legendre_unnorm, n, m, z, bc); +} +double assoc_legendre_p_1(int n, int m, double z, int bc) { + return xsf::assoc_legendre_p(xsf::assoc_legendre_norm, n, m, z, bc); +} +std::complex assoc_legendre_p_0_1(int n, int m, std::complex z, int bc) { + return xsf::assoc_legendre_p(xsf::assoc_legendre_unnorm, n, m, z, bc); +} +std::complex assoc_legendre_p_1_1(int n, int m, std::complex z, int bc) { + return xsf::assoc_legendre_p(xsf::assoc_legendre_norm, n, m, z, bc); +} +"#; + +const ALLOWLIST_EXTRA: &[&str] = &[ + "cevalpoly", + "assoc_legendre_p_[01](_1)?", + "complex__(new|values)", +]; + fn get_ctype(code: char) -> &'static str { match code { 'i' => "int", @@ -338,51 +397,59 @@ fn get_ctype(code: char) -> &'static str { } } -fn fmt_params(spec: &str, types: bool, use_pointers: bool) -> String { +fn split_typespec(spec: &str) -> (&str, &str) { let parts: Vec<&str> = spec.split("->").collect(); - let inputs = parts[0]; - let outputs = parts[1]; - - let mut params = Vec::new(); + assert!(parts.len() == 2); + (parts[0], parts[1]) +} - for (i, p) in inputs.chars().enumerate() { - if types { - params.push(format!("{} x{}", get_ctype(p), i)); - } else { - params.push(format!("x{}", i)); - } +fn fmt_return(spec: &str) -> &str { + let chars = split_typespec(spec).1; + if chars.len() > 1 { + "void" + } else { + get_ctype(chars.chars().next().unwrap()) } +} - if outputs.len() > 1 { - for (i, p) in outputs.chars().enumerate() { +fn fmt_params(spec: &str, types: bool, do_deref: bool) -> String { + let (inputs, outputs) = split_typespec(spec); + + let mut params = inputs + .chars() + .map(get_ctype) + .enumerate() + .map(|(i, ct)| { if types { - if use_pointers { - params.push(format!("{} *y{}", get_ctype(p), i)); - } else { - params.push(format!("{} &y{}", get_ctype(p), i)); - } + format!("{ct} x{i}") } else { - params.push(format!("y{}", i)); + format!("x{i}") } + }) + .collect::>(); + + if outputs.len() > 1 { + let whence = if do_deref { '*' } else { '&' }; + if types { + params.extend( + outputs + .chars() + .map(get_ctype) + .enumerate() + .map(|(i, ct)| format!("{ct} {whence}y{i}")), + ); + } else { + params.extend((0..outputs.len()).map(|i| format!("y{i}"))); } } params.join(", ") } -fn fmt_return(types: &str) -> String { - let chars = types.split("->").last().unwrap(); - if chars.len() > 1 { - "void".to_string() - } else { - get_ctype(chars.chars().next().unwrap()).to_string() - } -} - -fn fmt_func(name: &str, types: &str, suffix: &str) -> String { - let ret = fmt_return(types); - let use_pointers = name.ends_with('*'); - let params = fmt_params(types, true, use_pointers); +fn fmt_func(name: &str, spec: &str, suffix: &str) -> String { + let ret = fmt_return(spec); + let do_deref = name.ends_with('*'); + let params = fmt_params(spec, true, do_deref); let base_name = XSF_RENAME .iter() .find(|(n, _)| *n == name.trim_end_matches('*')) @@ -391,58 +458,42 @@ fn fmt_func(name: &str, types: &str, suffix: &str) -> String { let func_name = if suffix.is_empty() { base_name.to_string() } else { - format!("{}_{}", base_name, suffix) + format!("{base_name}_{suffix}") }; - format!("{} {}({})", ret, func_name, params) + format!("{ret} {func_name}({params})") } -fn fmt_call(name: &str, types: &str) -> String { - let parts: Vec<&str> = types.split("->").collect(); - let inputs = parts[0]; - let outputs = parts[1]; +fn fmt_call(name: &str, spec: &str) -> String { + let (inputs, outputs) = split_typespec(spec); let clean_name = name.trim_end_matches('*'); - let mut args = Vec::new(); - - for i in 0..inputs.len() { - args.push(format!("x{}", i)); - } + let mut args = (0..inputs.len()) + .map(|i| format!("x{i}")) + .collect::>(); if outputs.len() > 1 { - for i in 0..outputs.len() { - args.push(format!("y{}", i)); - } + args.extend((0..outputs.len()).map(|i| format!("y{i}"))); } format!("xsf::{}({})", clean_name, args.join(", ")) } fn push_line(source: &mut String, line: &str) { - source.push_str(line); + source.push_str(line.trim()); source.push('\n'); } fn generate_header(dir_out: &str) -> String { let mut source = String::from(WRAPPER_PREAMBLE); - push_line(&mut source, "#pragma once"); + // includes push_line(&mut source, "#include "); - // Put wrapper functions in their own namespace - push_line(&mut source, ""); - push_line(&mut source, "namespace xsf_wrapper {"); - push_line(&mut source, ""); - push_line( - &mut source, - "std::complex complex__new(double re, double im);", - ); - push_line( - &mut source, - "void complex__values(std::complex z, double &re, double &im);", - ); + // namespace push_line(&mut source, ""); + push_line(&mut source, &format!("namespace {WRAPPER_NAME} {{")); - // Generate unique function names for overloads + // let mut name_counts = std::collections::HashMap::new(); for (name, types) in XSF_TYPES { let count = name_counts.entry(*name).or_insert(0); @@ -452,32 +503,15 @@ fn generate_header(dir_out: &str) -> String { *count += 1; } - // `cevalpoly` requires special-casing - push_line( - &mut source, - "std::complex cevalpoly(const double *coeffs, int degree, std::complex z);", - ); - - // `assoc_legendre_p` requires special-casing - push_line( - &mut source, - "double assoc_legendre_p_0(int n, int m, double z, int bc);", - ); - push_line( - &mut source, - "std::complex assoc_legendre_p_0_1(int n, int m, std::complex z, int bc);", - ); - push_line( - &mut source, - "double assoc_legendre_p_1(int n, int m, double z, int bc);", - ); - push_line( - &mut source, - "std::complex assoc_legendre_p_1_1(int n, int m, std::complex z, int bc);", - ); + // special-casing + push_line(&mut source, DECL_CEVALPOLY); + push_line(&mut source, DECL_ASSOC_LEGENDRE_P); - push_line(&mut source, ""); - push_line(&mut source, "} // namespace xsf_wrapper"); + // helper functions + push_line(&mut source, DECL_COMPLEX_HELPERS); + + // close namespace + push_line(&mut source, "}"); let file = format!("{dir_out}/{WRAPPER_NAME}.hpp"); std::fs::write(&file, source).unwrap(); @@ -487,82 +521,43 @@ fn generate_header(dir_out: &str) -> String { fn build_wrapper(dir_out: &str, include: &str) { let mut source = String::from(WRAPPER_PREAMBLE); - push_line(&mut source, &format!("#include \"{WRAPPER_NAME}.hpp\"")); + // includes + push_line(&mut source, &format!(r#"#include "{WRAPPER_NAME}.hpp""#)); for xsf_header in XSF_HEADERS { - push_line(&mut source, &format!("#include \"xsf/{xsf_header}\"")); + push_line(&mut source, &format!(r#"#include "xsf/{xsf_header}""#)); } - // Put wrapper implementations in the same namespace - push_line(&mut source, ""); - push_line(&mut source, "namespace xsf_wrapper {"); - push_line(&mut source, ""); - push_line( - &mut source, - "std::complex complex__new(double re, double im) { - return std::complex(re, im); - }", - ); - push_line( - &mut source, - "void complex__values(std::complex z, double &re, double &im) { - re = std::real(z); - im = std::imag(z); - }", - ); + // namespace push_line(&mut source, ""); + push_line(&mut source, &format!("namespace {WRAPPER_NAME} {{")); // Generate unique function implementations for overloads let mut name_counts = std::collections::HashMap::new(); for (name, types) in XSF_TYPES { let count = name_counts.entry(*name).or_insert(0); let suffix = if *count == 0 { "" } else { &count.to_string() }; - let func_decl = fmt_func(name, types, suffix); + *count += 1; + + let decl = fmt_func(name, types, suffix); let call = fmt_call(name, types); - let ret_type = fmt_return(types); - if ret_type == "void" { - push_line(&mut source, &format!("{func_decl} {{ {call}; }}")); + + let stmt = if fmt_return(types) == "void" { + call.to_string() } else { - push_line(&mut source, &format!("{func_decl} {{ return {call}; }}")); - } - *count += 1; + format!("return {call}") + }; + push_line(&mut source, &format!("{decl} {{ {stmt}; }}")); } - // `cevalpoly` requires special-casing - push_line( - &mut source, - "std::complex cevalpoly(const double *coeffs, int degree, std::complex z) { - return xsf::cevalpoly(coeffs, degree, z); - }", - ); - - // `assoc_legendre_p` requires special-casing - push_line( - &mut source, - "double assoc_legendre_p_0(int n, int m, double z, int bc) { - return xsf::assoc_legendre_p(xsf::assoc_legendre_unnorm, n, m, z, bc); - }", - ); - push_line( - &mut source, - "std::complex assoc_legendre_p_0_1(int n, int m, std::complex z, int bc) { - return xsf::assoc_legendre_p(xsf::assoc_legendre_unnorm, n, m, z, bc); - }", - ); - push_line( - &mut source, - "double assoc_legendre_p_1(int n, int m, double z, int bc) { - return xsf::assoc_legendre_p(xsf::assoc_legendre_norm, n, m, z, bc); - }", - ); - push_line( - &mut source, - "std::complex assoc_legendre_p_1_1(int n, int m, std::complex z, int bc) { - return xsf::assoc_legendre_p(xsf::assoc_legendre_norm, n, m, z, bc); - }", - ); + // special-casing + push_line(&mut source, IMPL_CEVALPOLY); + push_line(&mut source, IMPL_ASSOC_LEGENDRE_P); - push_line(&mut source, ""); - push_line(&mut source, "} // namespace xsf_wrapper"); + // helper functions + push_line(&mut source, IMPL_COMPLEX_HELPERS); + + // close namespace + push_line(&mut source, "}"); let file_cpp = format!("{dir_out}/{WRAPPER_NAME}.cpp"); std::fs::write(&file_cpp, source).unwrap(); @@ -578,50 +573,48 @@ fn build_wrapper(dir_out: &str, include: &str) { if build.get_compiler().is_like_msvc() { // windows - build.flag("/std:c++17"); + build.flag(format!("/std:{CXX_STANDARD}")); } else { - build.std("c++17"); + build.std(CXX_STANDARD); } build.compile(WRAPPER_NAME); } -fn generate_bindings(dir_out: &str, header: &str) { - // Generate allowlist pattern including numbered overloads - let mut allowlist_functions = Vec::new(); - allowlist_functions.push("xsf_wrapper::complex__.+".to_string()); - - let mut name_counts = std::collections::HashMap::new(); - for (name, _) in XSF_TYPES { - let count = name_counts.entry(*name).or_insert(0); - let clean_name = name.trim_end_matches('*'); - let base_name = XSF_RENAME - .iter() - .find(|(n, _)| n == &clean_name) - .map(|(_, r)| *r) - .unwrap_or(clean_name); - - if *count == 0 { - allowlist_functions.push(format!("xsf_wrapper::{}", base_name)); - } else { - allowlist_functions.push(format!("xsf_wrapper::{}_{}", base_name, count)); - } - *count += 1; +fn get_allowlist() -> String { + // regex pattern of allowed functions + fn format_entry(name: &str) -> String { + format!(r"{WRAPPER_NAME}::{name}(_\d)?") } - allowlist_functions.push("xsf_wrapper::cevalpoly".to_string()); - allowlist_functions.push("xsf_wrapper::assoc_legendre_p_.+".to_string()); - let allowlist_pattern = allowlist_functions.join("|"); + let mut entries = XSF_TYPES + .iter() + .map(|(name, _)| { + let name_orig = name.trim_end_matches('*'); + let name_safe = XSF_RENAME + .iter() + .find(|(n, _)| n == &name_orig) + .map(|(_, r)| *r) + .unwrap_or(name_orig); + + format_entry(name_safe) + }) + .chain(ALLOWLIST_EXTRA.iter().map(|s| format_entry(s))) + .collect::>(); + + entries.dedup(); + entries.join("|") +} +fn generate_bindings(dir_out: &str, header: &str) { bindgen::Builder::default() .header(header) + .enable_cxx_namespaces() .size_t_is_usize(true) .sort_semantically(true) - .derive_copy(false) // for consistency across platforms - .opaque_type("std::*") - .allowlist_function(&allowlist_pattern) + .derive_copy(false) + .allowlist_function(get_allowlist()) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) - .enable_cxx_namespaces() .generate() .unwrap() .write_to_file(PathBuf::from(dir_out).join("bindings.rs")) @@ -630,11 +623,10 @@ fn generate_bindings(dir_out: &str, header: &str) { fn main() { let out_dir = env::var("OUT_DIR").unwrap(); - let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - - let include = format!("{manifest_dir}/xsf/include"); let header = generate_header(&out_dir); + let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); + let include = format!("{manifest_dir}/{XSF_DIR}/include"); println!("cargo:rerun-if-changed={include}"); build_wrapper(&out_dir, &include);