Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nalgebra compatibility for DualVec #59

Merged
merged 10 commits into from
May 8, 2023
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ nalgebra = "0.32"
pyo3 = { version = "0.18", optional = true, features = ["multiple-pymethods"] }
ndarray = { version = "0.15", optional = true }
numpy = { version = "0.18", optional = true }
approx = "0.5"
simba = "0.8"

[profile.release]
lto = true
Expand All @@ -30,7 +32,6 @@ linalg = ["ndarray"]

[dev-dependencies]
criterion = "0.4"
approx = "0.5"

[[bench]]
name = "benchmark"
Expand Down
244 changes: 244 additions & 0 deletions src/derivative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use nalgebra::constraint::{SameNumberOfRows, ShapeConstraint};
use nalgebra::*;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clippy has some suggestions in this file that you might want to consider.

use std::fmt;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::ops::{Add, AddAssign, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use num_traits::Zero;

#[derive(PartialEq, Eq, Clone, Debug)]
pub struct Derivative<T: DualNum<F>, F, R: Dim, C: Dim>(
Expand Down Expand Up @@ -35,6 +37,84 @@ where
Self::new(None)
}

pub(crate) fn map<T2, F2>(&self, mut f: impl FnMut(T) -> T2) -> Derivative<T2, F2, R, C>
where
T2: DualNum<F2>,
DefaultAllocator: Allocator<T2, R, C>,
{
let opt = self.0.as_ref().map(move |eps| eps.map(|e| f(e)));
Derivative::new(opt)
}

// A version of map that doesn't clone values before mapping. Useful for the SimdValue impl,
// which would be redundantly cloning all the lanes of each epsilon value before extracting
// just one of them.
//
// To implement, we inline a copy of Matrix::map, which implicitly clones values, and remove
// the cloning.
pub(crate) fn map_borrowed<T2, F2>(
&self,
mut f: impl FnMut(&T) -> T2,
) -> Derivative<T2, F2, R, C>
where
T2: DualNum<F2>,
DefaultAllocator: Allocator<T2, R, C>,
{
let opt = self.0.as_ref().map(move |eps| {
let ref this = eps;
let mut f = |e| f(e);
let (nrows, ncols) = this.shape_generic();
let mut res: Matrix<MaybeUninit<T2>, R, C, _> = Matrix::uninit(nrows, ncols);

for j in 0..ncols.value() {
for i in 0..nrows.value() {
// Safety: all indices are in range.
unsafe {
let a = this.data.get_unchecked(i, j);
*res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a));
}
}
}

// Safety: res is now fully initialized.
unsafe { res.assume_init() }
});
Derivative::new(opt)
}

/// Same but bails out if the closure returns None
pub(crate) fn try_map_borrowed<T2, F2>(
&self,
mut f: impl FnMut(&T) -> Option<T2>,
) -> Option<Derivative<T2, F2, R, C>>
where
T2: DualNum<F2>,
DefaultAllocator: Allocator<T2, R, C>,
{
self.0
.as_ref()
.and_then(move |eps| {
let ref this = eps;
let mut f = |e| f(e);
let (nrows, ncols) = this.shape_generic();
let mut res: Matrix<MaybeUninit<T2>, R, C, _> = Matrix::uninit(nrows, ncols);

for j in 0..ncols.value() {
for i in 0..nrows.value() {
// Safety: all indices are in range.
unsafe {
let a = this.data.get_unchecked(i, j);
*res.data.get_unchecked_mut(i, j) = MaybeUninit::new(f(a)?);
}
}
}

// Safety: res is now fully initialized.
Some(unsafe { res.assume_init() })
})
.map(Derivative::some)
}

pub fn derivative_generic(r: R, c: C, i: usize) -> Self {
let mut m = OMatrix::zeros_generic(r, c);
m[i] = T::one();
Expand Down Expand Up @@ -304,3 +384,167 @@ where
}
}
}

impl<T, R: Dim, C: Dim> nalgebra::SimdValue for Derivative<T, T::Element, R, C>
where
DefaultAllocator: Allocator<T, R, C> + Allocator<T::Element, R, C>,
T: DualNum<T::Element> + SimdValue + Scalar,
T::Element: DualNum<T::Element> + Scalar + Zero,
{
type Element = Derivative<T::Element, T::Element, R, C>;

type SimdBool = T::SimdBool;

#[inline]
fn lanes() -> usize {
T::lanes()
}

#[inline]
fn splat(val: Self::Element) -> Self {
val.map(|e| T::splat(e))
}

#[inline]
fn extract(&self, i: usize) -> Self::Element {
self.map_borrowed(|e| T::extract(e, i))
}

#[inline]
unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
let opt = self
.map_borrowed(|e| T::extract_unchecked(e, i))
.0
// Now check it's all zeros.
// Unfortunately there is no way to use the vectorized version of `is_zero`, which is
// only for matrices with statically known dimensions. Specialization would be
// required.
.filter(|x| Iterator::any(&mut x.iter(), |e| !e.is_zero()));
Derivative::new(opt)
}

// SIMD code will expect to be able to replace one lane with another Self::Element,
// even with a None Derivative, e.g.
//
// let single = Derivative::none();
// let mut x4 = Derivative::splat(single);
// let one = Derivative::some(...);
// x4.replace(1, one);
//
// So the implementation of `replace` will need to auto-upgrade to Some(zeros) in
// order to satisfy requests like that.
fn replace(&mut self, i: usize, val: Self::Element) {
match (&mut self.0, val.0) {
(Some(ours), Some(theirs)) => {
ours.zip_apply(&theirs, |e, replacement| e.replace(i, replacement));
}
(ours @ None, Some(theirs)) => {
let (r, c) = theirs.shape_generic();
let mut init: OMatrix<T, R, C> = OMatrix::zeros_generic(r, c);
init.zip_apply(&theirs, |e, replacement| e.replace(i, replacement));
*ours = Some(init);
}
(Some(ours), None) => {
ours.apply(|e| e.replace(i, T::Element::zero()));
}
_ => {}
}
}

unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
match (&mut self.0, val.0) {
(Some(ours), Some(theirs)) => {
ours.zip_apply(&theirs, |e, replacement| {
e.replace_unchecked(i, replacement)
});
}
(ours @ None, Some(theirs)) => {
let (r, c) = theirs.shape_generic();
let mut init: OMatrix<T, R, C> = OMatrix::zeros_generic(r, c);
init.zip_apply(&theirs, |e, replacement| {
e.replace_unchecked(i, replacement)
});
*ours = Some(init);
}
(Some(ours), None) => {
ours.apply(|e| e.replace_unchecked(i, T::Element::zero()));
}
_ => {}
}
}

fn select(mut self, cond: Self::SimdBool, other: Self) -> Self {
// If cond is mixed, then we may need to generate big zero matrices to do the
// component-wise select on. So check if cond is all-true or all-first to avoid that.
if cond.all() {
self
} else if cond.none() {
other
} else {
match (&mut self.0, other.0) {
(Some(ours), Some(theirs)) => {
ours.zip_apply(&theirs, |e, other_e| {
// this will probably get optimized out
let e_ = std::mem::replace(e, T::zero());
*e = e_.select(cond, other_e)
});
self
}
(Some(ours), None) => {
ours.apply(|e| {
// this will probably get optimized out
let e_ = std::mem::replace(e, T::zero());
*e = e_.select(cond, T::zero());
});
self
}
(ours @ None, Some(mut theirs)) => {
use std::ops::Not;
let inverted: T::SimdBool = cond.not();
theirs.apply(|e| {
// this will probably get optimized out
let e_ = std::mem::replace(e, T::zero());
*e = e_.select(inverted, T::zero());
});
*ours = Some(theirs);
self
}
_ => self,
}
}
}
}

use simba::scalar::{SubsetOf, SupersetOf};

impl<TSuper, FSuper, T, F, R: Dim, C: Dim> SubsetOf<Derivative<TSuper, FSuper, R, C>>
for Derivative<T, F, R, C>
where
TSuper: DualNum<FSuper> + SupersetOf<T>,
T: DualNum<F>,
DefaultAllocator: Allocator<T, R, C>,
DefaultAllocator: Allocator<TSuper, R, C>,
// DefaultAllocator: Allocator<TSuper, D>
// + Allocator<TSuper, U1, D>
// + Allocator<TSuper, D, U1>
// + Allocator<TSuper, D, D>,
{
#[inline(always)]
fn to_superset(&self) -> Derivative<TSuper, FSuper, R, C> {
self.map_borrowed(|elem| TSuper::from_subset(elem))
}
#[inline(always)]
fn from_superset(element: &Derivative<TSuper, FSuper, R, C>) -> Option<Self> {
element.try_map_borrowed(|elem| TSuper::to_subset(elem))
}
#[inline(always)]
fn from_superset_unchecked(element: &Derivative<TSuper, FSuper, R, C>) -> Self {
element.map_borrowed(|elem| TSuper::to_subset_unchecked(elem))
}
#[inline(always)]
fn is_in_subset(element: &Derivative<TSuper, FSuper, R, C>) -> bool {
element.0.as_ref().map_or(true, |matrix| {
matrix.iter().all(|elem| TSuper::is_in_subset(elem))
})
}
}
Loading