diff --git a/build.rs b/build.rs index f7c8426..6694119 100644 --- a/build.rs +++ b/build.rs @@ -1,266 +1,300 @@ +use core::panic; use std::env; use std::path::PathBuf; +const WRAPPER_NAME: &str = "xsf_wrapper"; +const WRAPPER_PREAMBLE: &str = "// generated by build.rs -- do not edit\n\n"; + const XSF_HEADERS: &[&str] = &[ - // "xsf/airy.h", - "xsf/alg.h", - // "xsf/bessel.h", - "xsf/beta.h", - "xsf/binom.h", - "xsf/digamma.h", - // "xsf/ellip.h", - "xsf/erf.h", - "xsf/exp.h", - "xsf/expint.h", - // "xsf/fp_error_metrics.h", - // "xsf/fresnel.h", - "xsf/gamma.h", - "xsf/hyp2f1.h", - "xsf/iv_ratio.h", - "xsf/kelvin.h", - // "xsf/lambertw.h", - "xsf/legendre.h", - "xsf/log_exp.h", - "xsf/log.h", - "xsf/loggamma.h", - "xsf/mathieu.h", - // "xsf/par_cyl.h", - // "xsf/recur.h", - // "xsf/sici.h", - "xsf/specfun.h", - // "xsf/sph_bessel.h", - // "xsf/sph_harm.h", - "xsf/sphd_wave.h", - "xsf/stats.h", - "xsf/struve.h", - "xsf/trig.h", - "xsf/wright_bessel.h", - "xsf/zeta.h", + "config.h", + // "airy.h", + "alg.h", + // "bessel.h", + "beta.h", + "binom.h", + // "cdflib.h", + "digamma.h", + // "ellip.h", + "erf.h", + // "evalpoly.h" + "exp.h", + "expint.h", + // "fresnel.h", + "gamma.h", + "hyp2f1.h", + "iv_ratio.h", + "kelvin.h", + // "lambertw.h", + "legendre.h", + "log_exp.h", + "log.h", + "loggamma.h", + "mathieu.h", + // "par_cyl.h", + // "recur.h", + // "sici.h", + "specfun.h", + // "sph_bessel.h", + // "sph_harm.h", + "sphd_wave.h", + "stats.h", + "struve.h", + "trig.h", + "wright_bessel.h", + "zeta.h", ]; - -const WRAPPER_FUNCTIONS: &[(&str, &str)] = &[ +// e.g. `("spam", "if->d")` becomes `double xsf_spam(int x0, float x1)` +const XSF_TYPES: &[(&str, &str)] = &[ // alg.h - ("cbrt", "double x"), + ("cbrt", "d->d"), // beta.h - ("beta", "double a, double b"), - ("betaln", "double a, double b"), + ("beta", "dd->d"), + ("betaln", "dd->d"), // binom.h - ("binom", "double n, double k"), + // cdflib.h (TODO: `gdtrib`) + ("binom", "dd->d"), // digamma.h - ("digamma", "double z"), + ("digamma", "d->d"), // erf.h - ("erf", "double x"), - ("erfc", "double x"), - ("erfcx", "double x"), - ("erfi", "double x"), - ("voigt_profile", "double x, double sigma, double gamma"), - ("dawsn", "double x"), + ("erf", "d->d"), + ("erfc", "d->d"), + ("erfcx", "d->d"), + ("erfi", "d->d"), + ("voigt_profile", "ddd->d"), + ("dawsn", "d->d"), + // evalpoly.h (TODO: `cevalpoly`) // exp.h - ("expm1", "double x"), - ("exp2", "double x"), - ("exp10", "double x"), + ("expm1", "d->d"), + ("exp2", "d->d"), + ("exp10", "d->d"), // expint.h - ("exp1", "double x"), - ("expi", "double x"), - ("scaled_exp1", "double x"), - // fp_error_metrics.h - // ("extended_absolute_error", "double actual, double desired"), - // ("extended_relative_error", "double actual, double desired"), + ("exp1", "d->d"), + ("expi", "d->d"), + ("scaled_exp1", "d->d"), + // fresnel.h (TODO: `fresnel`, `fcszo`) // gamma.h - ("gamma", "double x"), - ("gammaln", "double x"), - ("gammasgn", "double x"), - ("gammainc", "double a, double x"), - ("gammaincinv", "double a, double p"), - ("gammaincc", "double a, double x"), - ("gammainccinv", "double a, double p"), - ("gamma_ratio", "double a, double b"), + ("gamma", "d->d"), // TODO: complex + ("gammaln", "d->d"), + ("gammasgn", "d->d"), + ("gammainc", "dd->d"), + ("gammaincinv", "dd->d"), + ("gammaincc", "dd->d"), + ("gammainccinv", "dd->d"), + ("gamma_ratio", "dd->d"), // hyp2f1.h - ("hyp2f1", "double a, double b, double c, double x"), + ("hyp2f1", "dddd->d"), // TODO: complex // iv_ratio.h - ("iv_ratio", "double v, double x"), - ("iv_ratio_c", "double v, double x"), + ("iv_ratio", "dd->d"), + ("iv_ratio_c", "dd->d"), // kelvin.h (TODO: `kelvin`, `klvnzo`) - ("ber", "double x"), - ("bei", "double x"), - ("ker", "double x"), - ("kei", "double x"), - ("berp", "double x"), - ("beip", "double x"), - ("kerp", "double x"), - ("keip", "double x"), + ("ber", "d->d"), + ("bei", "d->d"), + ("ker", "d->d"), + ("kei", "d->d"), + ("berp", "d->d"), + ("beip", "d->d"), + ("kerp", "d->d"), + ("keip", "d->d"), // legendre.h (TODO: `assoc_legendre_p`, `lqn`, `lqmn`) - ("legendre_p", "int n, double z"), - ("sph_legendre_p", "int n, int m, double theta"), + ("legendre_p", "id->d"), + ("sph_legendre_p", "iid->d"), // log_exp.h - ("expit", "double x"), - ("exprel", "double x"), - ("logit", "double x"), - ("log_expit", "double x"), - ("log1mexp", "double x"), + ("expit", "d->d"), + ("exprel", "d->d"), + ("logit", "d->d"), + ("log_expit", "d->d"), + ("log1mexp", "d->d"), // log.h - ("log1p", "double x"), - ("log1pmx", "double x"), - ("xlogy", "double x, double y"), - ("xlog1py", "double x, double y"), + ("log1p", "d->d"), + ("log1pmx", "d->d"), + ("xlogy", "dd->d"), + ("xlog1py", "dd->d"), // loggamma.h - ("loggamma", "double x"), - ("rgamma", "double z"), + ("loggamma", "d->d"), + ("rgamma", "d->d"), // mathieu.h (TODO: `cen`, `sem`, `mcm1`, `msm1`, `mcm2`, `msm2`) - ("cem_cva", "double m, double q"), - ("sem_cva", "double m, double q"), + ("cem_cva", "dd->d"), + ("sem_cva", "dd->d"), // specfun.h (TODO: `chyp2f1`, `cerf`) - ("hypu", "double a, double b, double x"), - ("hyp1f1", "double a, double b, double x"), - ("pmv", "double m, double v, double x"), + ("hypu", "ddd->d"), + ("hyp1f1", "ddd->d"), + ("pmv", "ddd->d"), // sphd_wave.h - ("prolate_segv", "double m, double n, double c"), - ("oblate_segv", "double m, double n, double c"), + ("prolate_segv", "ddd->d"), + ("oblate_segv", "ddd->d"), // stats.h - ("bdtr", "double k, int n, double p"), - ("bdtrc", "double k, int n, double p"), - ("bdtri", "double k, int n, double y"), - ("chdtr", "double df, double x"), - ("chdtrc", "double df, double x"), - ("chdtri", "double df, double y"), - ("fdtr", "double a, double b, double x"), - ("fdtrc", "double a, double b, double x"), - ("fdtri", "double a, double b, double y"), - ("gdtr", "double a, double b, double x"), - ("gdtrc", "double a, double b, double x"), - ("kolmogorov", "double x"), - ("kolmogc", "double x"), - ("kolmogi", "double x"), - ("kolmogp", "double x"), - ("ndtr", "double x"), - ("ndtri", "double x"), - ("log_ndtr", "double x"), - ("nbdtr", "int k, int n, double p"), - ("nbdtrc", "int k, int n, double p"), - ("nbdtri", "int k, int n, double p"), - ("owens_t", "double h, double a"), - ("pdtr", "double k, double m"), - ("pdtrc", "double k, double m"), - ("pdtri", "int k, double y"), - ("smirnov", "int n, double x"), - ("smirnovc", "int n, double x"), - ("smirnovi", "int n, double x"), - ("smirnovp", "int n, double x"), - ("tukeylambdacdf", "double x, double lmbda"), + ("bdtr", "did->d"), + ("bdtrc", "did->d"), + ("bdtri", "did->d"), + ("chdtr", "dd->d"), + ("chdtrc", "dd->d"), + ("chdtri", "dd->d"), + ("fdtr", "ddd->d"), + ("fdtrc", "ddd->d"), + ("fdtri", "ddd->d"), + ("gdtr", "ddd->d"), + ("gdtrc", "ddd->d"), + ("kolmogorov", "d->d"), + ("kolmogc", "d->d"), + ("kolmogi", "d->d"), + ("kolmogp", "d->d"), + ("ndtr", "d->d"), + ("ndtri", "d->d"), + ("log_ndtr", "d->d"), + ("nbdtr", "iid->d"), + ("nbdtrc", "iid->d"), + ("nbdtri", "iid->d"), + ("owens_t", "dd->d"), + ("pdtr", "dd->d"), + ("pdtrc", "dd->d"), + ("pdtri", "id->d"), + ("smirnov", "id->d"), + ("smirnovc", "id->d"), + ("smirnovi", "id->d"), + ("smirnovp", "id->d"), + ("tukeylambdacdf", "dd->d"), // struve.h - ("itstruve0", "double x"), - ("it2struve0", "double x"), - ("itmodstruve0", "double x"), - ("struve_h", "double v, double z"), - ("struve_l", "double v, double z"), + ("itstruve0", "d->d"), + ("it2struve0", "d->d"), + ("itmodstruve0", "d->d"), + ("struve_h", "dd->d"), + ("struve_l", "dd->d"), // trig.h - ("sinpi", "double x"), - ("cospi", "double x"), - ("sindg", "double x"), - ("cosdg", "double x"), - ("tandg", "double x"), - ("cotdg", "double x"), - ("radian", "double d, double m, double s"), - ("cosm1", "double x"), + ("sinpi", "d->d"), + ("cospi", "d->d"), + ("sindg", "d->d"), + ("cosdg", "d->d"), + ("tandg", "d->d"), + ("cotdg", "d->d"), + ("radian", "ddd->d"), + ("cosm1", "d->d"), // wright_bessel.h - // ("wright_bessel_t", "double a, double b, double x"), - ("wright_bessel", "double a, double b, double x"), - ("log_wright_bessel", "double a, double b, double x"), + // ("wright_bessel_t", "ddd->d"), + ("wright_bessel", "ddd->d"), + ("log_wright_bessel", "ddd->d"), // zeta.h - ("riemann_zeta", "double x"), - ("zeta", "double x, double q"), - ("zetac", "double x"), + ("riemann_zeta", "d->d"), + ("zeta", "dd->d"), + ("zetac", "d->d"), ]; -fn main() { - let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - let xsf_path = format!("{manifest_dir}/xsf"); - let out_dir = env::var("OUT_DIR").unwrap(); +fn get_ctype(code: char) -> &'static str { + match code { + 'i' => "int", + 'f' => "float", + 'd' => "double", + 'F' => "std::complex", + 'D' => "std::complex", + 'V' => "void", + _ => panic!("Unknown parameter type"), + } +} - setup_build_dependencies(&xsf_path); +fn fmt_params(spec: &str, types: bool) -> String { + // TODO: >1 outputs, e.g. `ii->di` => `void _(int x0, int x1, double &y0, int &y1)` + spec.split("->") + .next() + .unwrap() + .chars() + .enumerate() + .map(if types { + |(i, p)| format!("{} x{}", get_ctype(p), i) + } else { + |(i, _)| format!("x{}", i) + }) + .collect::>() + .join(", ") +} - let wrapper_header = generate_wrapper_header(&out_dir); - let wrapper_cpp = generate_wrapper_cpp(&wrapper_header, &out_dir); - build_cpp_library(&wrapper_cpp, &xsf_path); - generate_bindings(&wrapper_header); +fn fmt_return(types: &str) -> String { + let chars = types.split("->").last().unwrap(); + assert!(chars.len() == 1, "Multiple return values not supported"); + get_ctype(chars.chars().next().unwrap()).to_string() } -fn setup_build_dependencies(xsf_path: &str) { - println!("cargo:rerun-if-changed={xsf_path}/include"); +fn fmt_func(name: &str, types: &str) -> String { + let ret = fmt_return(types); + let params = fmt_params(types, true); + format!("{} xsf_{}({})", ret, name, params) } -fn extract_param_names(params: &str) -> String { - params - .split(',') - .map(|param| param.split_whitespace().last().unwrap_or("")) - .collect::>() - .join(", ") +fn fmt_call(name: &str, types: &str) -> String { + let args = fmt_params(types, false); + format!("xsf::{}({})", name, args) } -fn generate_wrapper_header(out_dir: &str) -> String { - let wrapper_header = format!("{out_dir}/xsf_wrapper.h"); +fn push_line(source: &mut String, line: &str) { + source.push_str(line); + source.push('\n'); +} - let mut header_content = - String::from("#pragma once\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n"); +fn generate_header(dir_out: &str) -> String { + let mut source = String::from(WRAPPER_PREAMBLE); - for (func_name, params) in WRAPPER_FUNCTIONS { - header_content.push_str(&format!("double xsf_{func_name}({params});\n")); + push_line(&mut source, "#pragma once"); + push_line(&mut source, ""); + push_line(&mut source, "extern \"C\" {"); + for func in XSF_TYPES.iter().map(|(n, s)| fmt_func(n, s)) { + push_line(&mut source, &format!(" {func};")); } + push_line(&mut source, "}"); - header_content.push_str("\n#ifdef __cplusplus\n}\n#endif\n"); - - std::fs::write(&wrapper_header, header_content).expect("Failed to write wrapper header"); - - wrapper_header + let file = format!("{dir_out}/{WRAPPER_NAME}.hpp"); + std::fs::write(&file, source).unwrap(); + file } -fn generate_wrapper_cpp(wrapper_header: &str, out_dir: &str) -> String { - let wrapper_cpp = format!("{out_dir}/xsf_wrapper_impl.cpp"); - - let mut cpp_content = format!("#include \"{wrapper_header}\"\n"); +fn build_wrapper(dir_out: &str, include: &str) { + let mut source = String::from(WRAPPER_PREAMBLE); - // Include all specified XSF headers - for header in XSF_HEADERS { - cpp_content.push_str(&format!("#include \"{header}\"\n")); + push_line(&mut source, &format!("#include \"{WRAPPER_NAME}.hpp\"")); + for xsf_header in XSF_HEADERS { + push_line(&mut source, &format!("#include \"xsf/{xsf_header}\"")); } + push_line(&mut source, ""); - cpp_content.push_str("\nextern \"C\" {\n\n"); - - for (func_name, params) in WRAPPER_FUNCTIONS { - let args = extract_param_names(params); - cpp_content.push_str(&format!( - "double xsf_{func_name}({params}) {{ return xsf::{func_name}({args}); }}\n" - )); + push_line(&mut source, "extern \"C\" {"); + for (func, call) in XSF_TYPES + .iter() + .map(|(n, s)| (fmt_func(n, s), fmt_call(n, s))) + { + push_line(&mut source, &format!(" {func} {{ return {call}; }}")); } + push_line(&mut source, "}"); - cpp_content.push_str("\n}\n"); - - std::fs::write(&wrapper_cpp, cpp_content).expect("Failed to write wrapper implementation"); + let file_cpp = format!("{dir_out}/{WRAPPER_NAME}.cpp"); + std::fs::write(&file_cpp, source).unwrap(); - wrapper_cpp -} - -fn build_cpp_library(wrapper_cpp: &str, xsf_path: &str) { cc::Build::new() .cpp(true) .std("c++17") - .file(wrapper_cpp) - .include(format!("{xsf_path}/include")) .flag_if_supported("-Wno-unused-parameter") .flag_if_supported("-Wno-logical-op-parentheses") - .compile("xsf_wrapper_impl"); + .file(file_cpp) + .include(include) + .compile(WRAPPER_NAME); } -fn generate_bindings(wrapper_header: &str) { - let bindings = bindgen::Builder::default() - .header(wrapper_header) +fn generate_bindings(dir_out: &str, file_hpp: &str) { + bindgen::Builder::default() + .header(file_hpp) .allowlist_function("xsf_.*") .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .generate() - .expect("Unable to generate bindings"); + .unwrap() + .write_to_file(PathBuf::from(dir_out).join("bindings.rs")) + .unwrap(); +} + +fn main() { + let out_dir = env::var("OUT_DIR").unwrap(); + let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); + + let include = format!("{manifest_dir}/xsf/include"); + let header = generate_header(&out_dir); + + println!("cargo:rerun-if-changed={include}"); - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - bindings - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings!"); + build_wrapper(&out_dir, &include); + generate_bindings(&out_dir, &header); }