In [1]:
# Reload edited modules (e.g. the so3os package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the local `so3os` checkout importable (the package is imported from it below).
# The path is relative to the working directory, so this assumes the notebook is
# started from the directory that contains `so3os/` — TODO confirm.
# Guard against duplicates so re-running this cell does not keep growing sys.path.
if "so3os" not in sys.path:
    sys.path.append("so3os")

In [3]:
import jax
import jax.numpy as jnp
import jaxtyping as jt

import numpy as np

from functools import partial

In [4]:
from jaxtyping import install_import_hook

with install_import_hook("so3os", ("typeguard", "typechecked")):
    from so3os import euler, geometry, quaternion, moebius
    from so3os.func_utils import compose, unpack_args
    from so3os.jax_utils import key_chain

In [5]:
# Auto-advancing PRNG key chain for the random test inputs below: each
# `next(chain)` supplies the key for one jax.random draw. Seeding it with a fixed
# SEED presumably makes the sampled inputs reproducible across runs.
SEED = 42
chain = key_chain(SEED)

In [6]:
# check conversions: Euler angles -> rotation matrix -> quaternion -> rotation
# matrix -> Euler angles must reproduce the input angles.
# Angles are drawn from [0, 1) (jax.random.uniform default range).
euler_angles = jax.random.uniform(next(chain), shape=(3,))

# `compose` presumably applies its arguments right-to-left (from_euler first);
# that is the only order in which the input/output types line up.
should_be_identity = compose(
    euler.to_euler, quaternion.quat_to_mat, quaternion.mat_to_quat, euler.from_euler
)
recovered_angles = should_be_identity(euler_angles)

assert jnp.allclose(euler_angles, recovered_angles)



In [7]:
# check quaternion rotation: rotating a vector with the quaternion q must agree
# with rotating it by the matrix R that q was derived from.
# `v @ R` treats v as a row vector (it equals R.T @ v), which pins qrot3d to
# that convention.
euler_angles = jax.random.uniform(next(chain), shape=(3,))

R = euler.from_euler(euler_angles)
q = quaternion.mat_to_quat(R)

# Check every standard basis vector: `e_i @ R` is row i of R, so together they
# probe the whole matrix, whereas a single basis vector only checks one row.
for i, v in enumerate(jnp.eye(3)):
    assert jnp.allclose(v @ R, quaternion.qrot3d(q, v)), f"rotation mismatch for basis vector e_{i}"

In [8]:
# check that moebius_project is its own inverse (an involution): applying it
# twice with the same parameter q must return the starting point p.
# p is presumably a random point on the unit sphere in R^4; q has norm r
# (assuming geometry.unit normalizes to unit length).
p = geometry.unit(jax.random.normal(next(chain), shape=(4,)))
r = 0.5
q = r * geometry.unit(jax.random.normal(next(chain), shape=(4,)))

project = partial(moebius.moebius_project, q=q)
should_be_identity = compose(project, project)

assert jnp.allclose(p, should_be_identity(p))

In [9]:
# check moebius transform volume change on points in R^3: the dedicated
# moebius_volume_change must agree with the generic geometry.volume_change
# evaluated for the same map with geometry.tangent_space.
p = geometry.unit(jax.random.normal(next(chain), shape=(3,)))
r = 0.5
q = r * geometry.unit(jax.random.normal(next(chain), shape=(3,)))

project = partial(moebius.moebius_project, q=q)
specialized = moebius.moebius_volume_change(p, q)
generic = geometry.volume_change(project, geometry.tangent_space)(p)

assert jnp.allclose(specialized, generic)

In [10]:
# check double moebius transform and inverse: double_moebius_inverse must undo
# double_moebius_project for the same parameter q.
# p is presumably a random point on the unit sphere in R^4; q has norm r
# (assuming geometry.unit normalizes to unit length).
p = geometry.unit(jax.random.normal(next(chain), shape=(4,)))
r = 0.5
q = r * geometry.unit(jax.random.normal(next(chain), shape=(4,)))

double_project = partial(moebius.double_moebius_project, q=q)
double_inverse = partial(moebius.double_moebius_inverse, q=q)

# Check both composition orders so the inverse is verified as two-sided.
should_be_identity = compose(double_project, double_inverse)
assert jnp.allclose(p, should_be_identity(p)), "project(inverse(p)) != p"
assert jnp.allclose(p, compose(double_inverse, double_project)(p)), "inverse(project(p)) != p"

In [11]:
# check double moebius transform volume change on points in R^4: the dedicated
# double_moebius_volume_change must agree with the generic geometry.volume_change
# evaluated for the same map with geometry.tangent_space.
p = geometry.unit(jax.random.normal(next(chain), shape=(4,)))
r = 0.5
q = r * geometry.unit(jax.random.normal(next(chain), shape=(4,)))

double_project = partial(moebius.double_moebius_project, q=q)
specialized = moebius.double_moebius_volume_change(p, q)
generic = geometry.volume_change(double_project, geometry.tangent_space)(p)

assert jnp.allclose(specialized, generic)