Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 64 additions & 11 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ const WRAPPER_SPECS: &[(&str, &str)] = &[
// lambertw.h
("lambertw", "Dld->D"),
// legendre.h
// TODO: `lqmn`: d->[[d]],[[d]]
("legendre_p", "id->d"),
("legendre_p", "iD->D"),
("sph_legendre_p", "iid->d"),
Expand Down Expand Up @@ -349,16 +348,54 @@ cdouble assoc_legendre_p_1_1(int n, int m, cdouble z, int bc) {
return xsf::assoc_legendre_p(xsf::assoc_legendre_norm, n, m, z, bc);
}"#;

const _CPP_LEGENDRE_P_ALL: &str = r#"
void legendre_p_all(size_t n, double x, double *pn) {
xsf::legendre_p_all(x, std::mdspan(pn, n + 1));
}
void legendre_p_all_1(size_t n, cdouble z, cdouble *pn) {
xsf::legendre_p_all(z, std::mdspan(pn, n + 1));
}"#;

const _CPP_SPH_LEGENDRE_P_ALL: &str = r#"
void sph_legendre_p_all(size_t n, size_t m, double x, double *pnm) {
xsf::sph_legendre_p_all(x, std::mdspan(pnm, n + 1, 2 * m + 1));
}
void sph_legendre_p_all_1(size_t n, size_t m, cdouble z, cdouble *pnm) {
xsf::sph_legendre_p_all(z, std::mdspan(pnm, n + 1, 2 * m + 1));
}"#;

const _CPP_ASSOC_LEGENDRE_P_ALL: &str = r#"
void assoc_legendre_p_all_0(size_t n, size_t m, double z, int bc, double *pnm) {
auto res = std::mdspan(pnm, n + 1, 2 * m + 1);
return xsf::assoc_legendre_p_all(xsf::assoc_legendre_unnorm, z, bc, res);
}
void assoc_legendre_p_all_0_1(size_t n, size_t m, cdouble z, int bc, cdouble *pnm) {
auto res = std::mdspan(pnm, n + 1, 2 * m + 1);
return xsf::assoc_legendre_p_all(xsf::assoc_legendre_unnorm, z, bc, res);
}
void assoc_legendre_p_all_1(size_t n, size_t m, double z, int bc, double *pnm) {
auto res = std::mdspan(pnm, n + 1, 2 * m + 1);
return xsf::assoc_legendre_p_all(xsf::assoc_legendre_norm, z, bc, res);
}
void assoc_legendre_p_all_1_1(size_t n, size_t m, cdouble z, int bc, cdouble *pnm) {
auto res = std::mdspan(pnm, n + 1, 2 * m + 1);
return xsf::assoc_legendre_p_all(xsf::assoc_legendre_norm, z, bc, res);
}"#;

const _CPP_LQN: &str = r#"
void lqn(int n, double x, double *qn, double *qd) {
auto qn_wrapper = std::mdspan(qn, n + 1);
auto qd_wrapper = std::mdspan(qd, n + 1);
xsf::lqn(x, qn_wrapper, qd_wrapper);
void lqn(size_t n, double x, double *qn, double *qd) {
xsf::lqn(x, std::mdspan(qn, n + 1), std::mdspan(qd, n + 1));
}
void lqn_1(size_t n, cdouble z, cdouble *cqn, cdouble *cqd) {
xsf::lqn(z, std::mdspan(cqn, n + 1), std::mdspan(cqd, n + 1));
}"#;

const _CPP_LQMN: &str = r#"
void lqmn(size_t m, size_t n, double x, double *qm, double *qd) {
xsf::lqmn(x, std::mdspan(qm, m + 1, n + 1), std::mdspan(qd, m + 1, n + 1));
}
void lqn_1(int n, cdouble z, cdouble *cqn, cdouble *cqd) {
auto cqn_wrapper = std::mdspan(cqn, n + 1);
auto cqd_wrapper = std::mdspan(cqd, n + 1);
xsf::lqn(z, cqn_wrapper, cqd_wrapper);
void lqmn_1(size_t m, size_t n, cdouble z, cdouble *qm, cdouble *qd) {
xsf::lqmn(z, std::mdspan(qm, m + 1, n + 1), std::mdspan(qd, m + 1, n + 1));
}"#;

struct WrapperSpecCustom {
Expand All @@ -384,6 +421,10 @@ impl WrapperSpecCustom {
}

const WRAPPER_SPECS_CUSTOM: &[WrapperSpecCustom] = &[
WrapperSpecCustom {
pattern: r"complex__(new|values)",
cpp: _CPP_COMPLEX_HELPERS,
},
WrapperSpecCustom {
pattern: r"cevalpoly",
cpp: _CPP_CEVALPOLY,
Expand All @@ -392,13 +433,25 @@ const WRAPPER_SPECS_CUSTOM: &[WrapperSpecCustom] = &[
pattern: r"assoc_legendre_p_(0|1)",
cpp: _CPP_ASSOC_LEGENDRE_P,
},
WrapperSpecCustom {
pattern: r"legendre_p_all",
cpp: _CPP_LEGENDRE_P_ALL,
},
WrapperSpecCustom {
pattern: r"sph_legendre_p_all",
cpp: _CPP_SPH_LEGENDRE_P_ALL,
},
WrapperSpecCustom {
pattern: r"assoc_legendre_p_all_(0|1)",
cpp: _CPP_ASSOC_LEGENDRE_P_ALL,
},
WrapperSpecCustom {
pattern: r"lqn",
cpp: _CPP_LQN,
},
WrapperSpecCustom {
pattern: r"complex__(new|values)",
cpp: _CPP_COMPLEX_HELPERS,
pattern: r"lqmn",
cpp: _CPP_LQMN,
},
];

Expand Down
28 changes: 25 additions & 3 deletions src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,31 @@
#![allow(non_camel_case_types)]
#![allow(dead_code)]

use alloc::vec::Vec;
use num_complex::Complex;

include!(concat!(env!("OUT_DIR"), "/bindings.rs"));

pub(crate) use root::std::complex;
pub(crate) use root::xsf_wrapper::*;

pub(crate) type cdouble = complex<f64>;

// C++ std::complex type wrapper

impl cdouble {
pub(crate) fn new(re: f64, im: f64) -> Self {
unsafe { complex__new(re, im) }
}
}

impl From<num_complex::Complex<f64>> for cdouble {
fn from(z: num_complex::Complex<f64>) -> Self {
impl From<Complex<f64>> for cdouble {
fn from(z: Complex<f64>) -> Self {
Self::new(z.re, z.im)
}
}

impl From<cdouble> for num_complex::Complex<f64> {
impl From<cdouble> for Complex<f64> {
fn from(z: cdouble) -> Self {
let mut re: f64 = 0.0;
let mut im: f64 = 0.0;
Expand All @@ -32,11 +37,28 @@ impl From<cdouble> for num_complex::Complex<f64> {
}
}

// complex helper functions

#[inline(always)]
pub(crate) fn complex_nan() -> cdouble {
complex::new(f64::NAN, f64::NAN)
}

#[inline(always)]
pub(crate) fn complex_zeros(n: usize) -> Vec<cdouble> {
(0..n).map(|_| complex::new(0.0, 0.0)).collect()
}

#[inline(always)]
pub(crate) fn cvec_into<T>(xs: Vec<complex<T>>) -> Vec<Complex<T>>
where
Complex<T>: From<complex<T>>,
{
xs.into_iter().map(|c| c.into()).collect()
}

// macros

macro_rules! xsf_impl {
($name:ident, ($($param:ident: $type:ty),*), $docs:expr) => {
#[doc = $docs]
Expand Down
Loading
Loading