Skip to content

Commit

Permalink
Merge pull request #16483 from jakevdp:rot-tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 541829344
  • Loading branch information
jax authors committed Jun 20, 2023
2 parents c2935bf + 6bfcc46 commit 7e5e501
Showing 1 changed file with 55 additions and 37 deletions.
92 changes: 55 additions & 37 deletions tests/scipy_spatial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ def testRotationApply(self, shape, vector_shape, dtype, inverse):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), rng(vector_shape, dtype),)
jnp_fn = lambda q, v: jsp_Rotation.from_quat(q).apply(v, inverse=inverse)
np_fn = lambda q, v: osp_Rotation.from_quat(q).apply(v, inverse=inverse).astype(dtype) # HACK
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q, v: osp_Rotation.from_quat(q).apply(v, inverse=inverse).astype(dtype) # HACK
tol = 5e-2 if jtu.device_under_test() == 'tpu' else 1e-4
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

@jtu.sample_product(
Expand All @@ -65,8 +66,9 @@ def testRotationAsEuler(self, shape, dtype, seq, degrees):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees)
np_fn = lambda q: osp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees).astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q: osp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees).astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -77,8 +79,9 @@ def testRotationAsMatrix(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_matrix()
np_fn = lambda q: osp_Rotation.from_quat(q).as_matrix().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q: osp_Rotation.from_quat(q).as_matrix().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -89,8 +92,9 @@ def testRotationAsMrp(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_mrp()
np_fn = lambda q: osp_Rotation.from_quat(q).as_mrp().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q: osp_Rotation.from_quat(q).as_mrp().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -102,9 +106,10 @@ def testRotationAsRotvec(self, shape, dtype, degrees):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_rotvec(degrees=degrees)
np_fn = lambda q: osp_Rotation.from_quat(q).as_rotvec(degrees=degrees).astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True,
tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q: osp_Rotation.from_quat(q).as_rotvec(degrees=degrees).astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True,
# tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -115,8 +120,9 @@ def testRotationAsQuat(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat()
np_fn = lambda q: osp_Rotation.from_quat(q).as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q: osp_Rotation.from_quat(q).as_quat().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -130,8 +136,9 @@ def testRotationConcatenate(self, shape, other_shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype),)
jnp_fn = lambda q, o: jsp_Rotation.concatenate([jsp_Rotation.from_quat(q), jsp_Rotation.from_quat(o)]).as_quat()
np_fn = lambda q, o: osp_Rotation.concatenate([osp_Rotation.from_quat(q), osp_Rotation.from_quat(o)]).as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q, o: osp_Rotation.concatenate([osp_Rotation.from_quat(q), osp_Rotation.from_quat(o)]).as_quat().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -143,8 +150,9 @@ def testRotationGetItem(self, shape, dtype, indexer):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q)[indexer].as_quat()
np_fn = lambda q: osp_Rotation.from_quat(q)[indexer].as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q: osp_Rotation.from_quat(q)[indexer].as_quat().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -158,8 +166,9 @@ def testRotationFromEuler(self, size, dtype, seq, degrees):
shape = (size, len(seq))
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda a: jsp_Rotation.from_euler(seq, a, degrees).as_rotvec()
np_fn = lambda a: osp_Rotation.from_euler(seq, a, degrees).as_rotvec().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda a: osp_Rotation.from_euler(seq, a, degrees).as_rotvec().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -170,8 +179,9 @@ def testRotationFromMatrix(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda m: jsp_Rotation.from_matrix(m).as_rotvec()
np_fn = lambda m: osp_Rotation.from_matrix(m).as_rotvec().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda m: osp_Rotation.from_matrix(m).as_rotvec().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -182,8 +192,9 @@ def testRotationFromMrp(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda m: jsp_Rotation.from_mrp(m).as_rotvec()
np_fn = lambda m: osp_Rotation.from_mrp(m).as_rotvec().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda m: osp_Rotation.from_mrp(m).as_rotvec().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -194,8 +205,9 @@ def testRotationFromRotvec(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda r: jsp_Rotation.from_rotvec(r).as_quat()
np_fn = lambda r: osp_Rotation.from_rotvec(r).as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda r: osp_Rotation.from_rotvec(r).as_quat().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -205,8 +217,9 @@ def testRotationFromRotvec(self, shape, dtype):
def testRotationIdentity(self, num, dtype):
args_maker = lambda: (num,)
jnp_fn = lambda n: jsp_Rotation.identity(n, dtype).as_quat()
np_fn = lambda n: osp_Rotation.identity(n).as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda n: osp_Rotation.identity(n).as_quat().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -230,9 +243,10 @@ def testRotationMean(self, shape, dtype, rng_weights):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), jnp.abs(rng(shape[0], dtype)) if rng_weights else None)
jnp_fn = lambda q, w: jsp_Rotation.from_quat(q).mean(w).as_rotvec()
np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK
tol = 5e-3 if jtu.device_under_test() == 'tpu' else 1e-4
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

@jtu.sample_product(
Expand All @@ -244,8 +258,9 @@ def testRotationMultiply(self, shape, other_shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype))
jnp_fn = lambda q, o: (jsp_Rotation.from_quat(q) * jsp_Rotation.from_quat(o)).as_rotvec()
np_fn = lambda q, o: (osp_Rotation.from_quat(q) * osp_Rotation.from_quat(o)).as_rotvec().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q, o: (osp_Rotation.from_quat(q) * osp_Rotation.from_quat(o)).as_rotvec().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -256,8 +271,9 @@ def testRotationInv(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).inv().as_quat()
np_fn = lambda q: osp_Rotation.from_quat(q).inv().as_quat().astype(dtype) # HACK
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q: osp_Rotation.from_quat(q).inv().as_quat().astype(dtype) # HACK
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -268,8 +284,9 @@ def testRotationLen(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: len(jsp_Rotation.from_quat(q))
np_fn = lambda q: len(osp_Rotation.from_quat(q))
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q: len(osp_Rotation.from_quat(q))
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand All @@ -280,8 +297,9 @@ def testRotationSingle(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng(shape, dtype),)
jnp_fn = lambda q: jsp_Rotation.from_quat(q).single
np_fn = lambda q: osp_Rotation.from_quat(q).single
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
# TODO(chrisflesher): re-enable this after accounting for sign degeneracy
# np_fn = lambda q: osp_Rotation.from_quat(q).single
# self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4)
self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

@jtu.sample_product(
Expand Down

0 comments on commit 7e5e501

Please sign in to comment.