diff --git a/build.rs b/build.rs index aef532d..170e3ae 100644 --- a/build.rs +++ b/build.rs @@ -6,7 +6,7 @@ const WRAPPER_NAME: &str = "xsf_wrapper"; const WRAPPER_PREAMBLE: &str = "// generated by build.rs -- do not edit\n\n"; const XSF_HEADERS: &[&str] = &[ - // "airy.h", + "airy.h", "alg.h", "bessel.h", "beta.h", @@ -45,11 +45,17 @@ const XSF_HEADERS: &[&str] = &[ // e.g. `("spam", "if->d")` becomes `double xsf_spam(int x0, float x1)` const XSF_TYPES: &[(&str, &str)] = &[ // airy.h - // TODO: `airyb`, `airyzo`, `airy`, `airye`, `itairy` + ("airy", "d->dddd"), + ("airy", "D->DDDD"), + ("airye", "d->dddd"), + ("airye", "D->DDDD"), + ("itairy", "d->dddd"), + ("airyb*", "d->dddd"), + ("airyzo*", "ii->dddd"), + // alg.h ("cbrt", "d->d"), // bessel.h // TODO: `it1j0y0`, `it2j0y0`, `it1i0k0`, `it2i0k0` - // TODO: `rctj`, `rcty`, ("cyl_bessel_j", "dd->d"), ("cyl_bessel_j", "dD->D"), ("cyl_bessel_je", "dd->d"), @@ -127,7 +133,7 @@ const XSF_TYPES: &[(&str, &str)] = &[ ("expi", "D->D"), ("scaled_exp1", "d->d"), // fresnel.h - // TODO: `fresnel`, `fcszo` + // TODO // gamma.h ("gamma", "d->d"), ("gamma", "D->D"), @@ -144,7 +150,8 @@ const XSF_TYPES: &[(&str, &str)] = &[ ("iv_ratio", "dd->d"), ("iv_ratio_c", "dd->d"), // kelvin.h - // TODO: `kelvin`, `klvnzo` + // TODO: `kelvin`: d->DDDD + // TODO: `klvnzo`: ii->[d] ("ber", "d->d"), ("bei", "d->d"), ("ker", "d->d"), @@ -156,6 +163,8 @@ const XSF_TYPES: &[(&str, &str)] = &[ // lambertw.h ("lambertw", "Dld->D"), // legendre.h + // TODO: `lqn`: d->[d],[d] + // TODO: `lqmn`: d->[[d]],[[d]] ("legendre_p", "id->d"), ("legendre_p", "iD->D"), ("sph_legendre_p", "iid->d"), @@ -210,9 +219,10 @@ const XSF_TYPES: &[(&str, &str)] = &[ ("sph_bessel_k_jac", "ld->d"), ("sph_bessel_k_jac", "lD->D"), // sph_harm.h + // TODO: `sph_harm_y_all`: dd->[[d]] ("sph_harm_y", "iidd->D"), // sphd_wave.h - // TODO: aswfa and radial + // TODO: `*_aswfa*` and `*_radial*` ("prolate_segv", "ddd->d"), ("oblate_segv", "ddd->d"), // stats.h @@ -302,36 +312,56 @@ fn get_ctype(code: char) -> &'static str { } } -fn fmt_params(spec: &str, types: bool) -> String { - // TODO: >1 outputs, e.g. `ii->di` => `void _(int x0, int x1, double &y0, int &y1)` - spec.split("->") - .next() - .unwrap() - .chars() - .enumerate() - .map(if types { - |(i, p)| format!("{} x{}", get_ctype(p), i) +fn fmt_params(spec: &str, types: bool, use_pointers: bool) -> String { + let parts: Vec<&str> = spec.split("->").collect(); + let inputs = parts[0]; + let outputs = parts[1]; + + let mut params = Vec::new(); + + for (i, p) in inputs.chars().enumerate() { + if types { + params.push(format!("{} x{}", get_ctype(p), i)); } else { - |(i, _)| format!("x{}", i) - }) - .collect::>() - .join(", ") + params.push(format!("x{}", i)); + } + } + + if outputs.len() > 1 { + for (i, p) in outputs.chars().enumerate() { + if types { + if use_pointers { + params.push(format!("{} *y{}", get_ctype(p), i)); + } else { + params.push(format!("{} &y{}", get_ctype(p), i)); + } + } else { + params.push(format!("y{}", i)); + } + } + } + + params.join(", ") } fn fmt_return(types: &str) -> String { let chars = types.split("->").last().unwrap(); - assert!(chars.len() == 1, "Multiple return values not supported"); - get_ctype(chars.chars().next().unwrap()).to_string() + if chars.len() > 1 { + "void".to_string() + } else { + get_ctype(chars.chars().next().unwrap()).to_string() + } } fn fmt_func(name: &str, types: &str, suffix: &str) -> String { let ret = fmt_return(types); - let params = fmt_params(types, true); + let use_pointers = name.ends_with('*'); + let params = fmt_params(types, true, use_pointers); let base_name = XSF_RENAME .iter() - .find(|(n, _)| *n == name) + .find(|(n, _)| *n == name.trim_end_matches('*')) .map(|(_, r)| *r) - .unwrap_or(name); + .unwrap_or(name.trim_end_matches('*')); let func_name = if suffix.is_empty() { base_name.to_string() } else { @@ -341,8 +371,24 @@ fn fmt_func(name: &str, types: &str, suffix: &str) -> String { } fn fmt_call(name: &str, types: &str) -> String { - let args = fmt_params(types, false); - format!("xsf::{}({})", name, args) + let parts: Vec<&str> = types.split("->").collect(); + let inputs = parts[0]; + let outputs = parts[1]; + let clean_name = name.trim_end_matches('*'); + + let mut args = Vec::new(); + + for i in 0..inputs.len() { + args.push(format!("x{}", i)); + } + + if outputs.len() > 1 { + for i in 0..outputs.len() { + args.push(format!("y{}", i)); + } + } + + format!("xsf::{}({})", clean_name, args.join(", ")) } fn push_line(source: &mut String, line: &str) { @@ -440,7 +486,12 @@ fn build_wrapper(dir_out: &str, include: &str) { let suffix = if *count == 0 { "" } else { &count.to_string() }; let func_decl = fmt_func(name, types, suffix); let call = fmt_call(name, types); - push_line(&mut source, &format!("{func_decl} {{ return {call}; }}")); + let ret_type = fmt_return(types); + if ret_type == "void" { + push_line(&mut source, &format!("{func_decl} {{ {call}; }}")); + } else { + push_line(&mut source, &format!("{func_decl} {{ return {call}; }}")); + } *count += 1; } @@ -503,11 +554,12 @@ fn generate_bindings(dir_out: &str, header: &str) { let mut name_counts = std::collections::HashMap::new(); for (name, _) in XSF_TYPES { let count = name_counts.entry(*name).or_insert(0); + let clean_name = name.trim_end_matches('*'); let base_name = XSF_RENAME .iter() - .find(|(n, _)| n == name) + .find(|(n, _)| n == &clean_name) .map(|(_, r)| *r) - .unwrap_or(name); + .unwrap_or(clean_name); if *count == 0 { allowlist_functions.push(format!("xsf_wrapper::{}", base_name)); diff --git a/src/airy.rs b/src/airy.rs new file mode 100644 index 0000000..2ed72e3 --- /dev/null +++ b/src/airy.rs @@ -0,0 +1,220 @@ +use crate::bindings; +use num_complex::Complex; +use std::os::raw::c_int; + +mod sealed { + use num_complex::Complex; + + pub trait Sealed {} + impl Sealed for f64 {} + impl Sealed for Complex {} +} + +pub trait AiryArg: sealed::Sealed { + type Output; + + fn airy(self) -> (Self::Output, Self::Output, Self::Output, Self::Output); + fn airye(self) -> (Self::Output, Self::Output, Self::Output, Self::Output); +} + +#[inline(always)] +fn c_c64_nan() -> bindings::root::std::complex { + Complex::new(f64::NAN, f64::NAN).into() +} + +impl AiryArg for f64 { + type Output = f64; + + #[inline(always)] + fn airy(self) -> (Self::Output, Self::Output, Self::Output, Self::Output) { + let mut ai = f64::NAN; + let mut aip = f64::NAN; + let mut bi = f64::NAN; + let mut bip = f64::NAN; + + unsafe { + bindings::airy(self, &mut ai, &mut aip, &mut bi, &mut bip); + } + (ai, aip, bi, bip) + } + + #[inline(always)] + fn airye(self) -> (Self::Output, Self::Output, Self::Output, Self::Output) { + let mut ai = f64::NAN; + let mut aip = f64::NAN; + let mut bi = f64::NAN; + let mut bip = f64::NAN; + + unsafe { + bindings::airye(self, &mut ai, &mut aip, &mut bi, &mut bip); + } + (ai, aip, bi, bip) + } +} + +impl AiryArg for Complex { + type Output = Complex; + + #[inline(always)] + fn airy(self) -> (Self::Output, Self::Output, Self::Output, Self::Output) { + let mut ai = c_c64_nan(); + let mut bi = c_c64_nan(); + let mut ad = c_c64_nan(); + let mut bd = c_c64_nan(); + + unsafe { + bindings::airy_1(self.into(), &mut ai, &mut bi, &mut ad, &mut bd); + } + (ai.into(), bi.into(), ad.into(), bd.into()) + } + + #[inline(always)] + fn airye(self) -> (Self::Output, Self::Output, Self::Output, Self::Output) { + let mut ai = c_c64_nan(); + let mut bi = c_c64_nan(); + let mut ad = c_c64_nan(); + let mut bd = c_c64_nan(); + + unsafe { + bindings::airye_1(self.into(), &mut ai, &mut bi, &mut ad, &mut bd); + } + (ai.into(), bi.into(), ad.into(), bd.into()) + } +} + +/// Airy functions and their derivatives. +/// +/// # Arguments +/// +/// - `z` - real (`f64`) or complex (`num_complex::Complex`) argument +/// +/// # Returns +/// +/// A tuple `(Ai, Aip, Bi, Bip)` where: +/// - `Ai` - Ai(z) +/// - `Aip` - Ai'(z) +/// - `Bi` - Bi(z) +/// - `Bip` - Bi'(z) +pub fn airy(z: T) -> (T::Output, T::Output, T::Output, T::Output) { + z.airy() +} + +/// Exponentially scaled Airy functions and their derivatives. +/// +/// Scaling: +/// +/// ```plain +/// eAi(z) = Ai(z) * exp(2/3 * z * sqrt(z)) +/// eAi'(z) = Ai'(z) * exp(2/3 * z * sqrt(z)) +/// eBi(z) = Bi(z) * exp(-abs(2/3 * (z * sqrt(z)).real)) +/// eBi'(z) = Bi'(z) * exp(-abs(2/3 * (z * sqrt(z)).real)) +/// ``` +/// +/// # Arguments +/// +/// - `z` - real (`f64`) or complex (`num_complex::Complex`) argument +/// +/// # Returns +/// +/// A tuple `(eAi, eAip, eBi, eBip)` where: +/// - `eAi` - eAi(z) +/// - `eAip` - eAi'(z) +/// - `eBi` - eBi(z) +/// - `eBip` - eBi'(z) +pub fn airye(z: T) -> (T::Output, T::Output, T::Output, T::Output) { + z.airye() +} + +/// Integrals of Airy functions +/// +/// Calculates the integrals of Airy functions from 0 to `x`. +/// +/// # Arguments +/// +/// - `x` - Upper limit of the integral (x ≥ 0) +/// +/// # Returns +/// +/// A tuple `(Apt, Bpt, Ant, Bnt)` where: +/// - `Apt` - Integral of Ai(t) from 0 to x +/// - `Bpt` - Integral of Bi(t) from 0 to x +/// - `Ant` - Integral of Ai(-t) from 0 to x +/// - `Bnt` - Integral of Bi(-t) from 0 to x +pub fn itairy(x: f64) -> (f64, f64, f64, f64) { + let mut apt = f64::NAN; + let mut bpt = f64::NAN; + let mut ant = f64::NAN; + let mut bnt = f64::NAN; + + unsafe { + bindings::itairy(x, &mut apt, &mut bpt, &mut ant, &mut bnt); + } + (apt, bpt, ant, bnt) +} + +/// Airy functions and their derivatives. +/// +/// # Arguments +/// +/// - `x` - Real argument +/// +/// # Returns +/// +/// A tuple `(Ai, Bi, Aip, Bip)` where: +/// - `Ai` - Ai(x) +/// - `Bi` - Bi(x) +/// - `Aip` - Ai'(x) +/// - `Bip` - Bi'(x) +pub fn airyb(x: f64) -> (f64, f64, f64, f64) { + let mut ai = f64::NAN; + let mut bi = f64::NAN; + let mut aip = f64::NAN; + let mut bip = f64::NAN; + + unsafe { + bindings::airyb(x, &mut ai, &mut bi, &mut aip, &mut bip); + } + (ai, bi, aip, bip) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AiryKind { + Ai = 1, + Bi = 2, +} + +/// Zeros of Airy functions and their associated values. +/// +/// This function computes the first `nt` zeros of Airy functions Ai(x) and Ai'(x), a and a', +/// and the associated values of Ai(a') and Ai'(a); and the first `nt` zeros of Airy functions +/// Bi(x) and Bi'(x), b and b', and the associated values of Bi(b') and Bi'(b). +/// +/// # Arguments +/// +/// - `nt` - Total number of zeros to compute +/// - `kf` - Function code: +/// - `1` for Ai(x) and Ai'(x) +/// - `2` for Bi(x) and Bi'(x) +/// +/// # Returns +/// +/// A tuple `(xa, xb, xc, xd)` where: +/// - `xa` - The m-th zero *a* of Ai(x) or the m-th zero *b* of Bi(x) +/// - `xb` - The m-th zero *a'* of Ai'(x) or the m-th zero *b'* of Bi'(x) +/// - `xc` - Ai(*a'*) or Bi(*b'*) +/// - `xd` - Ai'(*a*) or Bi'(*b*) +/// +/// where m is the serial number of zeros. +pub fn airyzo(nt: u32, kf: AiryKind) -> (f64, f64, f64, f64) { + assert!(nt > 0); + + let mut xa = f64::NAN; + let mut xb = f64::NAN; + let mut xc = f64::NAN; + let mut xd = f64::NAN; + + unsafe { + bindings::airyzo(nt as c_int, kf as c_int, &mut xa, &mut xb, &mut xc, &mut xd); + } + (xa, xb, xc, xd) +} diff --git a/src/lib.rs b/src/lib.rs index 0c1aadc..e4cbcda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,9 @@ mod bindings; +// airy.h +mod airy; +pub use airy::{airy, airyb, airye, airyzo, itairy}; + // alg.h mod alg; pub use alg::cbrt; diff --git a/tests/test_functions.rs b/tests/test_functions.rs index 06911c1..55ab717 100644 --- a/tests/test_functions.rs +++ b/tests/test_functions.rs @@ -59,6 +59,78 @@ mod xsref { } } + impl TestOutput for (f64, f64, f64, f64) { + fn from_parquet_rows(rows: Vec>) -> Vec { + rows.into_iter() + .map(|row| (row[0], row[1], row[2], row[3])) + .collect() + } + + fn error(actual: Self, expected: Self) -> f64 { + let errors = [ + relative_error(actual.0, expected.0), + relative_error(actual.1, expected.1), + relative_error(actual.2, expected.2), + relative_error(actual.3, expected.3), + ]; + errors.iter().fold(0.0, |acc, &x| acc.max(x)) + } + + fn magnitude(self) -> f64 { + (self.0.abs() + self.1.abs() + self.2.abs() + self.3.abs()) / 4.0 + } + + fn format_value(self) -> String { + format!( + "({:.6e}, {:.6e}, {:.6e}, {:.6e})", + self.0, self.1, self.2, self.3 + ) + } + } + + impl TestOutput for (Complex, Complex, Complex, Complex) { + fn from_parquet_rows(rows: Vec>) -> Vec { + rows.into_iter() + .map(|row| { + ( + Complex::new(row[0], row[1]), + Complex::new(row[2], row[3]), + Complex::new(row[4], row[5]), + Complex::new(row[6], row[7]), + ) + }) + .collect() + } + + fn error(actual: Self, expected: Self) -> f64 { + let errors = [ + complex_relative_error(actual.0, expected.0), + complex_relative_error(actual.1, expected.1), + complex_relative_error(actual.2, expected.2), + complex_relative_error(actual.3, expected.3), + ]; + errors.iter().fold(0.0, |acc, &x| acc.max(x)) + } + + fn magnitude(self) -> f64 { + (self.0.norm() + self.1.norm() + self.2.norm() + self.3.norm()) / 4.0 + } + + fn format_value(self) -> String { + format!( + "({:.6e}+{:.6e}i, {:.6e}+{:.6e}i, {:.6e}+{:.6e}i, {:.6e}+{:.6e}i)", + self.0.re, + self.0.im, + self.1.re, + self.1.im, + self.2.re, + self.2.im, + self.3.re, + self.3.im + ) + } + } + #[derive(Debug)] struct TestCase { pub inputs: Vec, @@ -350,6 +422,22 @@ macro_rules! _test { } } }; + ($test_name:ident, $f:ident, $sig:literal, $test_fn:expr, (f64, f64, f64, f64)) => { + paste::paste! { + #[test] + fn $test_name() { + xsref::test::<(f64, f64, f64, f64), _>(stringify!($f), $sig, $test_fn); + } + } + }; + ($test_name:ident, $f:ident, $sig:literal, $test_fn:expr, (Complex, Complex, Complex, Complex)) => { + paste::paste! { + #[test] + fn $test_name() { + xsref::test::<(num_complex::Complex, num_complex::Complex, num_complex::Complex, num_complex::Complex), _>(stringify!($f), $sig, $test_fn); + } + } + }; } /// Generate a test function for xsf functions @@ -440,6 +528,28 @@ macro_rules! xsref_test { ); } }; + (@single $f:ident, "d->dddd") => { + paste::paste! { + _test!( + [], + $f, + "d-d_d_d_d", + |x: &[f64]| xsf::$f(x[0]), + (f64, f64, f64, f64) + ); + } + }; + (@single $f:ident, "D->DDDD") => { + paste::paste! { + _test!( + [], + $f, + "cd-cd_cd_cd_cd", + |x: &[f64]| xsf::$f(num_complex::Complex::new(x[0], x[1])), + (Complex, Complex, Complex, Complex) + ); + } + }; (@single $f:ident, "D->D") => { paste::paste! { _test!( @@ -498,34 +608,42 @@ macro_rules! xsref_test { ); } }; - (@single $f:ident, "dddD->D") => { + (@single $f:ident, "Dld->D") => { paste::paste! { _test!( [], $f, - "d_d_d_cd-cd", - |x: &[f64]| xsf::$f(x[0], x[1], x[2], num_complex::Complex::new(x[3], x[4])), + "cd_p_d-cd", + |x: &[f64]| xsf::$f( + num_complex::Complex::new(x[0], x[1]), + x[2] as std::os::raw::c_long, + x[3], + ), Complex ); } }; - (@single $f:ident, "Dld->D") => { + (@single $f:ident, "dddD->D") => { paste::paste! { _test!( [], $f, - "cd_p_d-cd", - |x: &[f64]| xsf::$f( - num_complex::Complex::new(x[0], x[1]), - x[2] as std::os::raw::c_long, - x[3], - ), + "d_d_d_cd-cd", + |x: &[f64]| xsf::$f(x[0], x[1], x[2], num_complex::Complex::new(x[3], x[4])), Complex ); } }; } +// airy.h +// xsref_test!(airyb, "d->dddd"); // no xsref table +xsref_test!(airy, "d->dddd"); +xsref_test!(airy, "D->DDDD"); +xsref_test!(airye, "d->dddd"); +xsref_test!(airye, "D->DDDD"); +xsref_test!(itairy, "d->dddd"); + // alg.h xsref_test!(cbrt, "d->d");