From 5a2936d19dc34f961fe08bfa410a94c5fac98e25 Mon Sep 17 00:00:00 2001 From: Chris Flesher Date: Tue, 20 Jun 2023 05:17:41 -0500 Subject: [PATCH] Fix problematic rotation tests --- tests/scipy_spatial_test.py | 93 +++++++++++++++---------------------- 1 file changed, 38 insertions(+), 55 deletions(-) diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index 9dec53e6cba0..559b2d283167 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -66,9 +66,8 @@ def testRotationAsEuler(self, shape, dtype, seq, degrees): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees) - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q: osp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees).astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + np_fn = lambda q: osp_Rotation.from_quat(q).as_euler(seq=seq, degrees=degrees).astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -79,9 +78,8 @@ def testRotationAsMatrix(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_matrix() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q: osp_Rotation.from_quat(q).as_matrix().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + np_fn = lambda q: osp_Rotation.from_quat(q).as_matrix().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -92,9 +90,8 @@ def testRotationAsMrp(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_mrp() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q: osp_Rotation.from_quat(q).as_mrp().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + np_fn = lambda q: osp_Rotation.from_quat(q).as_mrp().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -106,10 +103,8 @@ def testRotationAsRotvec(self, shape, dtype, degrees): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_rotvec(degrees=degrees) - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q: osp_Rotation.from_quat(q).as_rotvec(degrees=degrees).astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, - # tol=1e-4) + np_fn = lambda q: osp_Rotation.from_quat(q).as_rotvec(degrees=degrees).astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -119,10 +114,9 @@ def testRotationAsRotvec(self, shape, dtype, degrees): def testRotationAsQuat(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) - jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q: osp_Rotation.from_quat(q).as_quat().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + jnp_fn = lambda q: jsp_Rotation.from_quat(jnp.where(jnp.sum(q, axis=0) > 0, q, -q)).as_quat() + np_fn = lambda q: osp_Rotation.from_quat(onp.where(jnp.sum(q, axis=0) > 0, q, -q)).as_quat().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -135,10 +129,9 @@ def testRotationConcatenate(self, shape, other_shape, dtype): self.skipTest("Scipy 1.8.0 needed for concatenate.") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype),) - jnp_fn = lambda q, o: jsp_Rotation.concatenate([jsp_Rotation.from_quat(q), jsp_Rotation.from_quat(o)]).as_quat() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q, o: osp_Rotation.concatenate([osp_Rotation.from_quat(q), osp_Rotation.from_quat(o)]).as_quat().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + jnp_fn = lambda q, o: jsp_Rotation.concatenate([jsp_Rotation.from_quat(q), jsp_Rotation.from_quat(o)]).as_rotvec() + np_fn = lambda q, o: osp_Rotation.concatenate([osp_Rotation.from_quat(q), osp_Rotation.from_quat(o)]).as_rotvec().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -149,10 +142,9 @@ def testRotationConcatenate(self, shape, other_shape, dtype): def testRotationGetItem(self, shape, dtype, indexer): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) - jnp_fn = lambda q: jsp_Rotation.from_quat(q)[indexer].as_quat() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q: osp_Rotation.from_quat(q)[indexer].as_quat().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + jnp_fn = lambda q: jsp_Rotation.from_quat(jnp.where(jnp.sum(q, axis=0) > 0, q, -q))[indexer].as_quat() + np_fn = lambda q: osp_Rotation.from_quat(onp.where(onp.sum(q, axis=0) > 0, q, -q))[indexer].as_quat().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -166,9 +158,8 @@ def testRotationFromEuler(self, size, dtype, seq, degrees): shape = (size, len(seq)) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda a: jsp_Rotation.from_euler(seq, a, degrees).as_rotvec() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda a: osp_Rotation.from_euler(seq, a, degrees).as_rotvec().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + np_fn = lambda a: osp_Rotation.from_euler(seq, a, degrees).as_rotvec().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -192,9 +183,8 @@ def testRotationFromMrp(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda m: jsp_Rotation.from_mrp(m).as_rotvec() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda m: osp_Rotation.from_mrp(m).as_rotvec().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + np_fn = lambda m: osp_Rotation.from_mrp(m).as_rotvec().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -204,10 +194,9 @@ def testRotationFromMrp(self, shape, dtype): def testRotationFromRotvec(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) - jnp_fn = lambda r: jsp_Rotation.from_rotvec(r).as_quat() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda r: osp_Rotation.from_rotvec(r).as_quat().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + jnp_fn = lambda r: jsp_Rotation.from_rotvec(r).as_rotvec() + np_fn = lambda r: osp_Rotation.from_rotvec(r).as_rotvec().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -216,10 +205,9 @@ def testRotationFromRotvec(self, shape, dtype): ) def testRotationIdentity(self, num, dtype): args_maker = lambda: (num,) - jnp_fn = lambda n: jsp_Rotation.identity(n, dtype).as_quat() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda n: osp_Rotation.identity(n).as_quat().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + jnp_fn = lambda n: jsp_Rotation.identity(n, dtype).as_rotvec() + np_fn = lambda n: osp_Rotation.identity(n).as_rotvec().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -243,10 +231,9 @@ def testRotationMean(self, shape, dtype, rng_weights): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), jnp.abs(rng(shape[0], dtype)) if rng_weights else None) jnp_fn = lambda q, w: jsp_Rotation.from_quat(q).mean(w).as_rotvec() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK + np_fn = lambda q, w: osp_Rotation.from_quat(q).mean(w).as_rotvec().astype(dtype) # HACK tol = 5e-3 if jtu.device_under_test() == 'tpu' else 1e-4 - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=tol) self._CompileAndCheck(jnp_fn, args_maker, tol=tol) @jtu.sample_product( @@ -258,9 +245,8 @@ def testRotationMultiply(self, shape, other_shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype), rng(other_shape, dtype)) jnp_fn = lambda q, o: (jsp_Rotation.from_quat(q) * jsp_Rotation.from_quat(o)).as_rotvec() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q, o: (osp_Rotation.from_quat(q) * osp_Rotation.from_quat(o)).as_rotvec().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + np_fn = lambda q, o: (osp_Rotation.from_quat(q) * osp_Rotation.from_quat(o)).as_rotvec().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -270,10 +256,9 @@ def testRotationMultiply(self, shape, other_shape, dtype): def testRotationInv(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) - jnp_fn = lambda q: jsp_Rotation.from_quat(q).inv().as_quat() - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q: osp_Rotation.from_quat(q).inv().as_quat().astype(dtype) # HACK - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + jnp_fn = lambda q: jsp_Rotation.from_quat(q).inv().as_rotvec() + np_fn = lambda q: osp_Rotation.from_quat(q).inv().as_rotvec().astype(dtype) # HACK + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -284,9 +269,8 @@ def testRotationLen(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: len(jsp_Rotation.from_quat(q)) - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q: len(osp_Rotation.from_quat(q)) - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + np_fn = lambda q: len(osp_Rotation.from_quat(q)) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product( @@ -297,9 +281,8 @@ def testRotationSingle(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda q: jsp_Rotation.from_quat(q).single - # TODO(chrisflesher): re-enable this after accounting for sign degeneracy - # np_fn = lambda q: osp_Rotation.from_quat(q).single - # self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + np_fn = lambda q: osp_Rotation.from_quat(q).single + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) @jtu.sample_product(