Skip to content

Commit

Permalink
Add nalgebra-API tests
Browse files Browse the repository at this point in the history
eps[0] to eps.unwrap()
  • Loading branch information
Cormac Relf authored and cormac-ainc committed May 2, 2023
1 parent b3b06ac commit 4055007
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions tests/test_dual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,53 @@ fn test_dual_bessel_j2_4() {
assert!((res.re - -0.279979741339189).abs() < 1e-12);
assert!((res.eps.unwrap() - -0.132099570594364).abs() < 1e-12);
}

mod nalgebra_api {
use nalgebra::{Point2, Point3, UnitQuaternion, Vector2, Vector3};
use num_dual::*;
use num_traits::Zero;
use std::f32::consts::*;

fn unit_circle(t: Dual32) -> Point2<Dual32> {
let x_dir = |x: Dual32| Vector2::new(x, Dual32::zero());
let y_dir = |y: Dual32| Vector2::new(Dual32::zero(), y);
let theta = t * TAU;
Point2::from(x_dir(theta.cos()) + y_dir(theta.sin()))
}

// This is testing that you can type-check code that whacks DualVec in
// nalgebra structures and tries to use them.
#[test]
fn use_nalgebra_2d() {
// 1 radian around the circle
let t = Dual32::from_re(0.25).derivative();
let point = unit_circle(t);
let real = point.map(|x| x.re);
let grad = point.map(|x| x.eps.unwrap());
println!("point: {}", point.coords);
approx::assert_relative_eq!(real, Point2::new(0., 1.), epsilon = 1e-3);
approx::assert_relative_eq!(grad, Point2::new(-TAU, 0.), epsilon = 1e-3);
}

#[test]
fn use_nalgebra_3d() {
// First one does nothing to the gradient. Still got no y or z direction.
let rot1 = UnitQuaternion::from_axis_angle(&Vector3::x_axis(), FRAC_PI_8);
// Second one goes pi/8 (22.5deg) about the y axis, which should shift some of the steepness
// in the x direction into the z direction, but not much.
let rot2 = UnitQuaternion::from_axis_angle(&Vector3::y_axis(), FRAC_PI_8);
let rotation = (rot2 * rot1).cast::<Dual32>();
let lifted_3d_circle = |t: Dual32| {
let xy = unit_circle(t); // [0, 1] with derivatives [-1, 0]
Point3::new(xy.x, xy.y, Dual32::zero())
};
let function = |t: Dual32| rotation * lifted_3d_circle(t);
let point = function(Dual32::from_re(0.25).derivative());
let real = point.map(|x| x.re);
let grad = point.map(|x| x.eps.unwrap());
println!("rotated point: {}", point.coords);
approx::assert_relative_eq!(real.coords, real.coords.normalize(), epsilon = 1e-3);
approx::assert_relative_eq!(real, Point3::new(0.146, 0.924, 0.354), epsilon = 1e-3);
approx::assert_relative_eq!(grad, Point3::new(-5.805, 0., 2.404), epsilon = 1e-3);
}
}

0 comments on commit 4055007

Please sign in to comment.