Skip to content

Commit

Permalink
Support JIT types in transform_decompose and transform_compose
Browse files Browse the repository at this point in the history
  • Loading branch information
njroussel committed Aug 18, 2023
1 parent b6f0d78 commit 1244530
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 33 deletions.
10 changes: 5 additions & 5 deletions drjit/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def quat_to_euler(q, /):
Returns:
drjit.ArrayBase: A 3D Dr.Jit array containing the Euler angles.
'''
name = _dr.detail.array_name('Array', q.Type, [3], q.IsScalar)
name = _dr.detail.array_name('Array', q.Type, (3, *q.Shape[1:]), q.IsScalar)
module = _modules.get(q.__module__)
Array3f = getattr(module, name)

Expand Down Expand Up @@ -422,7 +422,7 @@ def euler_to_quat(a, /):
Returns:
drjit.ArrayBase: A Dr.Jit quaternion representing the input Euler angles.
'''
name = _dr.detail.array_name('Quaternion', a.Type, [4], a.IsScalar)
name = _dr.detail.array_name('Quaternion', a.Type, (4, *a.Shape[1:]), a.IsScalar)
module = _modules.get(a.__module__)
Quat4f = getattr(module, name)

Expand Down Expand Up @@ -491,11 +491,11 @@ def transform_decompose(a, it=10):
if not _dr.is_matrix_v(a):
raise Exception('Unsupported type!')

name = _dr.detail.array_name('Array', a.Type, [3], a.IsScalar)
name = _dr.detail.array_name('Array', a.Type, (3, *a.Shape[2:]), a.IsScalar)
module = _modules.get(a.__module__)
Array3f = getattr(module, name)

name = _dr.detail.array_name('Matrix', a.Type, (3, 3), a.IsScalar)
name = _dr.detail.array_name('Matrix', a.Type, (3, 3, *a.Shape[2:]), a.IsScalar)
Matrix3f = getattr(module, name)

Q, P = polar_decomp(Matrix3f(a), it)
Expand Down Expand Up @@ -525,7 +525,7 @@ def transform_compose(s, q, t, /):
if not _dr.is_matrix_v(s) or not _dr.is_quaternion_v(q):
raise Exception('Unsupported type!')

name = _dr.detail.array_name('Matrix', q.Type, (4, 4), q.IsScalar)
name = _dr.detail.array_name('Matrix', q.Type, (4, 4, *q.Shape[1:]), q.IsScalar)
module = _modules.get(q.__module__)
Matrix4f = getattr(module, name)

Expand Down
65 changes: 37 additions & 28 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,24 +110,33 @@ def test08_polar():
assert dr.allclose(q@dr.transpose(q), dr.identity(M), atol=1e-6)


def test09_transform_decompose():
m = dr.scalar.Matrix4f([[1, 0, 0, 8], [0, 2, 0, 7], [0, 0, 9, 6], [0, 0, 0, 1]])
@pytest.mark.parametrize("package", ["drjit.scalar", "drjit.cuda", "drjit.llvm"])
def test09_transform_decompose(package):
package = prepare(package)
Quaternion4f, Array3f = package.Quaternion4f, package.Array3f
Matrix3f, Matrix4f = package.Matrix3f, package.Matrix4f

m = Matrix4f([[1, 0, 0, 8], [0, 2, 0, 7], [0, 0, 9, 6], [0, 0, 0, 1]])
s, q, t = dr.transform_decompose(m)

assert dr.allclose(s, dr.scalar.Matrix3f(m))
assert dr.allclose(q, dr.scalar.Quaternion4f(1))
assert dr.allclose(s, Matrix3f(m))
assert dr.allclose(q, Quaternion4f(1))
assert dr.allclose(t, [8, 7, 6])
assert dr.allclose(m, dr.transform_compose(s, q, t))

q2 = dr.rotate(dr.scalar.Quaternion4f, dr.scalar.Array3f(0, 0, 1), 15.0)
q2 = dr.rotate(Quaternion4f, Array3f(0, 0, 1), 15.0)
m @= dr.quat_to_matrix(q2)
s, q, t = dr.transform_decompose(m)

assert dr.allclose(q, q2)


def test10_matrix_to_quat():
q = dr.rotate(dr.scalar.Quaternion4f, dr.scalar.Array3f(0, 0, 1), 15.0)
@pytest.mark.parametrize("package", ["drjit.scalar", "drjit.cuda", "drjit.llvm"])
def test10_matrix_to_quat(package):
package = prepare(package)
Quaternion4f, Array3f = package.Quaternion4f, package.Array3f

q = dr.rotate(Quaternion4f, Array3f(0, 0, 1), 15.0)
m = dr.quat_to_matrix(q)
q2 = dr.matrix_to_quat(m)
assert dr.allclose(q, q2)
Expand Down Expand Up @@ -171,22 +180,24 @@ def test12_matrix_scale(package):
m = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], dtype=np.float32)
m2 = np.float32(2*m)

assert package.Matrix3f(m) * 2 == package.Matrix3f(m2)
assert package.Matrix3f(m) @ 2 == package.Matrix3f(m2)
assert package.Matrix3f(m) * package.Float(2) == package.Matrix3f(m2)
assert package.Matrix3f(m) @ package.Float(2) == package.Matrix3f(m2)
assert 2 * package.Matrix3f(m) == package.Matrix3f(m2)
assert package.Float(2) * package.Matrix3f(m) == package.Matrix3f(m2)
assert Matrix3f(m) * 2 == Matrix3f(m2)
assert Matrix3f(m) @ 2 == Matrix3f(m2)
assert Matrix3f(m) * Float(2) == Matrix3f(m2)
assert Matrix3f(m) @ Float(2) == Matrix3f(m2)
assert 2 * Matrix3f(m) == Matrix3f(m2)
assert Float(2) * Matrix3f(m) == Matrix3f(m2)


@pytest.mark.parametrize("package", ["drjit.scalar", "drjit.cuda", "drjit.llvm"])
def test12_matrix_vector(package):
np = pytest.importorskip("numpy")

m_ = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], dtype=np.float32)
package = prepare(package)
Float, Matrix3f, Array3f = package.Float, package.Matrix3f, package.Array3f

m_ = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], dtype=np.float32)
m = Matrix3f(m_)

v1 = m @ Array3f(1, 0, 0)
v2 = m @ Array3f(1, 1, 0)
assert dr.allclose(v1, [0.1, 0.4, 0.7])
Expand Down Expand Up @@ -281,28 +292,26 @@ def to_dest(a):
assert(m == m3)


@pytest.mark.parametrize("package", ["drjit.scalar"])
@pytest.mark.parametrize("package", ["drjit.scalar", "drjit.cuda", "drjit.llvm"])
def test15_quat_to_euler(package):
np = pytest.importorskip("numpy")

package = prepare(package)
Quaternion4f, Array3f, Float = package.Quaternion4f, package.Array3f, package.Float
Quaternion4f, Array3f = package.Quaternion4f, package.Array3f

# Gimbal lock at +pi/2
q = Quaternion4f(0, 1.0 / np.sqrt(2), 0, 1.0 / np.sqrt(2))
assert(dr.allclose(dr.quat_to_euler(q), Array3f(0, np.pi / 2, 0)))
q = Quaternion4f(0, 1.0 / dr.sqrt(2), 0, 1.0 / dr.sqrt(2))
assert(dr.allclose(dr.quat_to_euler(q), Array3f(0, dr.pi / 2, 0), atol=1e-3))
# Gimbal lock at -pi/2
q = Quaternion4f(0, -1.0 / np.sqrt(2), 0, 1.0 / np.sqrt(2))
assert(dr.allclose(dr.quat_to_euler(q), Array3f(0, -np.pi / 2, 0)))
q = Quaternion4f(0, -1.0 / dr.sqrt(2), 0, 1.0 / dr.sqrt(2))
assert(dr.allclose(dr.quat_to_euler(q), Array3f(0, -dr.pi / 2, 0), atol=1e-3))
# Gimbal lock at +pi/2, such that computed sinp > 1
q = Quaternion4f(0, 1.0 / np.sqrt(2) + 1e-6, 0, 1.0 / np.sqrt(2))
assert(dr.allclose(dr.quat_to_euler(q), Array3f(0, np.pi / 2, 0)))
q = Quaternion4f(0, 1.0 / dr.sqrt(2) + 1e-6, 0, 1.0 / dr.sqrt(2))
assert(dr.allclose(dr.quat_to_euler(q), Array3f(0, dr.pi / 2, 0), atol=1e-3))
# Gimbal lock at -pi/2, such that computed sinp < -1
q = Quaternion4f(0, -1.0 / np.sqrt(2) - 1e-6, 0, 1.0 / np.sqrt(2))
assert(dr.allclose(dr.quat_to_euler(q), Array3f(0, -np.pi / 2, 0)))
q = Quaternion4f(0, -1.0 / dr.sqrt(2) - 1e-6, 0, 1.0 / dr.sqrt(2))
assert(dr.allclose(dr.quat_to_euler(q), Array3f(0, -dr.pi / 2, 0), atol=1e-3))
# Quaternion without gimbal lock
q = Quaternion4f(0.15849363803863525, 0.5915063619613647, 0.15849363803863525, 0.7745190262794495)
e = Array3f(np.pi / 3, np.pi / 3, np.pi / 3)
e = Array3f(dr.pi / 3, dr.pi / 3, dr.pi / 3)
assert(dr.allclose(dr.quat_to_euler(q), e))
# Round trip
assert(dr.allclose(e, dr.quat_to_euler(dr.euler_to_quat(e))))
Expand Down Expand Up @@ -341,7 +350,7 @@ def test17_quat_to_matrix(package):
assert(dr.allclose(q, dr.matrix_to_quat(m4)))

# pi/2 around z-axis
q = Quaternion4f([ 0, 0, 1/np.sqrt(2), 1/np.sqrt(2) ])
q = Quaternion4f([ 0, 0, 1 / dr.sqrt(2), 1 / dr.sqrt(2) ])
m3 = Matrix3f([ [0, -1, 0], [1, 0, 0], [0, 0, 1] ])
m4 = Matrix4f([ [0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1] ])
assert(dr.allclose(dr.quat_to_matrix(q, size=3), m3, atol=2e-7))
Expand Down

0 comments on commit 1244530

Please sign in to comment.