Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 81 additions & 29 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ const WRAPPER_NAME: &str = "xsf_wrapper";
const WRAPPER_PREAMBLE: &str = "// generated by build.rs -- do not edit\n\n";

const XSF_HEADERS: &[&str] = &[
// "airy.h",
"airy.h",
"alg.h",
"bessel.h",
"beta.h",
Expand Down Expand Up @@ -45,11 +45,17 @@ const XSF_HEADERS: &[&str] = &[
// e.g. `("spam", "if->d")` becomes `double xsf_spam(int x0, float x1)`
const XSF_TYPES: &[(&str, &str)] = &[
// airy.h
// TODO: `airyb`, `airyzo`, `airy`, `airye`, `itairy`
("airy", "d->dddd"),
("airy", "D->DDDD"),
("airye", "d->dddd"),
("airye", "D->DDDD"),
("itairy", "d->dddd"),
("airyb*", "d->dddd"),
("airyzo*", "ii->dddd"),
// alg.h
("cbrt", "d->d"),
// bessel.h
// TODO: `it1j0y0`, `it2j0y0`, `it1i0k0`, `it2i0k0`
// TODO: `rctj`, `rcty`,
("cyl_bessel_j", "dd->d"),
("cyl_bessel_j", "dD->D"),
("cyl_bessel_je", "dd->d"),
Expand Down Expand Up @@ -127,7 +133,7 @@ const XSF_TYPES: &[(&str, &str)] = &[
("expi", "D->D"),
("scaled_exp1", "d->d"),
// fresnel.h
// TODO: `fresnel`, `fcszo`
// TODO
// gamma.h
("gamma", "d->d"),
("gamma", "D->D"),
Expand All @@ -144,7 +150,8 @@ const XSF_TYPES: &[(&str, &str)] = &[
("iv_ratio", "dd->d"),
("iv_ratio_c", "dd->d"),
// kelvin.h
// TODO: `kelvin`, `klvnzo`
// TODO: `kelvin`: d->DDDD
// TODO: `klvnzo`: ii->[d]
("ber", "d->d"),
("bei", "d->d"),
("ker", "d->d"),
Expand All @@ -156,6 +163,8 @@ const XSF_TYPES: &[(&str, &str)] = &[
// lambertw.h
("lambertw", "Dld->D"),
// legendre.h
// TODO: `lqn`: d->[d],[d]
// TODO: `lqmn`: d->[[d]],[[d]]
("legendre_p", "id->d"),
("legendre_p", "iD->D"),
("sph_legendre_p", "iid->d"),
Expand Down Expand Up @@ -210,9 +219,10 @@ const XSF_TYPES: &[(&str, &str)] = &[
("sph_bessel_k_jac", "ld->d"),
("sph_bessel_k_jac", "lD->D"),
// sph_harm.h
// TODO: `sph_harm_y_all`: dd->[[d]]
("sph_harm_y", "iidd->D"),
// sphd_wave.h
// TODO: aswfa and radial
// TODO: `*_aswfa*` and `*_radial*`
("prolate_segv", "ddd->d"),
("oblate_segv", "ddd->d"),
// stats.h
Expand Down Expand Up @@ -302,36 +312,56 @@ fn get_ctype(code: char) -> &'static str {
}
}

fn fmt_params(spec: &str, types: bool) -> String {
// TODO: >1 outputs, e.g. `ii->di` => `void _(int x0, int x1, double &y0, int &y1)`
spec.split("->")
.next()
.unwrap()
.chars()
.enumerate()
.map(if types {
|(i, p)| format!("{} x{}", get_ctype(p), i)
fn fmt_params(spec: &str, types: bool, use_pointers: bool) -> String {
let parts: Vec<&str> = spec.split("->").collect();
let inputs = parts[0];
let outputs = parts[1];

let mut params = Vec::new();

for (i, p) in inputs.chars().enumerate() {
if types {
params.push(format!("{} x{}", get_ctype(p), i));
} else {
|(i, _)| format!("x{}", i)
})
.collect::<Vec<_>>()
.join(", ")
params.push(format!("x{}", i));
}
}

if outputs.len() > 1 {
for (i, p) in outputs.chars().enumerate() {
if types {
if use_pointers {
params.push(format!("{} *y{}", get_ctype(p), i));
} else {
params.push(format!("{} &y{}", get_ctype(p), i));
}
} else {
params.push(format!("y{}", i));
}
}
}

params.join(", ")
}

fn fmt_return(types: &str) -> String {
let chars = types.split("->").last().unwrap();
assert!(chars.len() == 1, "Multiple return values not supported");
get_ctype(chars.chars().next().unwrap()).to_string()
if chars.len() > 1 {
"void".to_string()
} else {
get_ctype(chars.chars().next().unwrap()).to_string()
}
}

fn fmt_func(name: &str, types: &str, suffix: &str) -> String {
let ret = fmt_return(types);
let params = fmt_params(types, true);
let use_pointers = name.ends_with('*');
let params = fmt_params(types, true, use_pointers);
let base_name = XSF_RENAME
.iter()
.find(|(n, _)| *n == name)
.find(|(n, _)| *n == name.trim_end_matches('*'))
.map(|(_, r)| *r)
.unwrap_or(name);
.unwrap_or(name.trim_end_matches('*'));
let func_name = if suffix.is_empty() {
base_name.to_string()
} else {
Expand All @@ -341,8 +371,24 @@ fn fmt_func(name: &str, types: &str, suffix: &str) -> String {
}

fn fmt_call(name: &str, types: &str) -> String {
let args = fmt_params(types, false);
format!("xsf::{}({})", name, args)
let parts: Vec<&str> = types.split("->").collect();
let inputs = parts[0];
let outputs = parts[1];
let clean_name = name.trim_end_matches('*');

let mut args = Vec::new();

for i in 0..inputs.len() {
args.push(format!("x{}", i));
}

if outputs.len() > 1 {
for i in 0..outputs.len() {
args.push(format!("y{}", i));
}
}

format!("xsf::{}({})", clean_name, args.join(", "))
}

fn push_line(source: &mut String, line: &str) {
Expand Down Expand Up @@ -440,7 +486,12 @@ fn build_wrapper(dir_out: &str, include: &str) {
let suffix = if *count == 0 { "" } else { &count.to_string() };
let func_decl = fmt_func(name, types, suffix);
let call = fmt_call(name, types);
push_line(&mut source, &format!("{func_decl} {{ return {call}; }}"));
let ret_type = fmt_return(types);
if ret_type == "void" {
push_line(&mut source, &format!("{func_decl} {{ {call}; }}"));
} else {
push_line(&mut source, &format!("{func_decl} {{ return {call}; }}"));
}
*count += 1;
}

Expand Down Expand Up @@ -503,11 +554,12 @@ fn generate_bindings(dir_out: &str, header: &str) {
let mut name_counts = std::collections::HashMap::new();
for (name, _) in XSF_TYPES {
let count = name_counts.entry(*name).or_insert(0);
let clean_name = name.trim_end_matches('*');
let base_name = XSF_RENAME
.iter()
.find(|(n, _)| n == name)
.find(|(n, _)| n == &clean_name)
.map(|(_, r)| *r)
.unwrap_or(name);
.unwrap_or(clean_name);

if *count == 0 {
allowlist_functions.push(format!("xsf_wrapper::{}", base_name));
Expand Down
Loading
Loading