# Rotations performance

## starry

Let's start by setting up and timing a single rotation as it is currently implemented in starry.

In [68]:
import starry
import numpy as np

l_max = 5
# Length of a Ylm coefficient vector up to degree l_max: (l_max + 1)^2.
n_max = (l_max + 1) ** 2
# Arbitrary, evenly spaced test coefficients.
y = np.linspace(-1, 1, n_max)
# Low-level starry Ylm operator for degree l_max.
# NOTE(review): the meaning of the trailing (0, 0, 1) arguments is not visible here — confirm against starry._core.
starry_op = starry._core.core.OpsYlm(l_max, 0, 0, 1)

Pre-computing some matrices... Done.


Here are some angles:

In [26]:
# Arbitrary benchmark angles, reused by the starry and jaxoplanet timings below.
inc, obl, theta = 60.0, 30.0, 45.0

and the timing estimate:

In [27]:
# Warm-up call: the first call compiles `left_project` (see the "Compiling" output),
# so it is made before timing. y[:, None] passes the coefficients as a single column.
starry_op.left_project(y[:, None], inc, obl, theta)

%timeit starry_op.left_project(y[:, None], inc, obl, theta)

Compiling `left_project`... Done.


43.1 µs ± 484 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


For an array of theta values:

In [28]:
# 1000 evenly spaced theta values; `thetas` is also reused by the jaxoplanet vmap timings.
thetas = np.linspace(0, 360, 1000)

# NOTE(review): no explicit warm-up call here — compilation for this call (see the
# "Compiling" output) happens inside %timeit; a warm-up like the cell above would be cleaner.
%timeit starry_op.left_project(y[:, None], inc, obl, thetas)

Compiling `left_project`... Done.


207 µs ± 2.86 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## jaxoplanet

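We now define the different rotations in jaxoplanet (`left_project_all`, `left_project_dotrz` and `left_project_classic`). The last two omit $\theta_z$ since it is not present in starry's `left_project`; `left_project_all` takes it explicitly and we pass 0 for it.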

In [29]:
import jax

# Enable 64-bit floats in JAX. This is set right after importing jax, before any
# jax array is created in this notebook.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jaxoplanet.experimental.starry.rotation import (
    right_project_axis_angle,
    dot_rotation_matrix,
    dot_rz,
)


# Variant with every angle (including theta_z) baked into a single rotation.
def left_project_all(ydeg, inc, obl, theta, theta_z, x):
    """Left-project ``x`` with one composite rotation.

    The axis-angle pair comes from ``right_project_axis_angle``. The axis is
    negated before building the rotation operator, presumably to turn the
    right projection into the corresponding left projection.
    """
    ux, uy, uz, rot_angle = right_project_axis_angle(inc, obl, theta, theta_z)
    rotate = dot_rotation_matrix(ydeg, -ux, -uy, -uz, rot_angle)
    return rotate(x)


# Variant that uses the specialised dot_rz for the two z rotations.
def left_project_dotrz(ydeg, inc, obl, theta, x):
    """Left-project ``x`` step by step, using ``dot_rz`` for z rotations.

    Same rotation sequence as ``left_project_classic``:
    -pi/2 about x, -theta about z, +pi/2 about x, -obl about z, and finally
    (pi/2 - inc) about the axis (-cos(obl), -sin(obl), 0).
    """
    half_pi = 0.5 * jnp.pi
    tilted = dot_rotation_matrix(ydeg, 1.0, 0.0, 0.0, -half_pi)(x)
    spun = dot_rz(ydeg, -theta)(tilted)
    untilted = dot_rotation_matrix(ydeg, 1.0, 0.0, 0.0, half_pi)(spun)
    obliqued = dot_rz(ydeg, -obl)(untilted)
    inclined = dot_rotation_matrix(
        ydeg, -jnp.cos(obl), -jnp.sin(obl), 0.0, half_pi - inc
    )(obliqued)
    return inclined


# The most straightforward variant: one generic rotation per step.
def left_project_classic(ydeg, inc, obl, theta, x):
    """Left-project ``x`` with one ``dot_rotation_matrix`` call per step.

    Steps, in order: -pi/2 about x, -theta about z, +pi/2 about x,
    -obl about z, then (pi/2 - inc) about (-cos(obl), -sin(obl), 0).
    A ``None`` axis component is passed through unchanged to
    ``dot_rotation_matrix``.
    """
    half_pi = 0.5 * jnp.pi
    steps = (
        (1.0, 0.0, 0.0, -half_pi),
        (None, None, 1.0, -theta),
        (1.0, 0.0, 0.0, half_pi),
        (None, None, 1.0, -obl),
        (-jnp.cos(obl), -jnp.sin(obl), 0.0, half_pi - inc),
    )
    for ux, uy, uz, rot_angle in steps:
        x = dot_rotation_matrix(ydeg, ux, uy, uz, rot_angle)(x)
    return x

Let's look at the performance for a single angle $\theta$:

In [30]:
import timeit
from jax import block_until_ready as bur

n = 10000  # calls per measurement

# Time each jaxoplanet variant on a single scalar theta. Each function is jitted
# with ydeg static and called once beforehand, so compilation is not timed.
# Results are mean wall time per call, in microseconds.
timings = []
for fn, args in (
    (left_project_all, (l_max, inc, obl, theta, 0.0, y)),
    (left_project_dotrz, (l_max, inc, obl, theta, y)),
    (left_project_classic, (l_max, inc, obl, theta, y)),
):
    # ("ydeg",) is a real 1-tuple; ("ydeg") would just be the string "ydeg".
    f = jax.jit(fn, static_argnames=("ydeg",))
    bur(f(*args))  # warm-up: trigger compilation outside the timed region
    # The lambda is called only inside this iteration, so late binding of f/args is safe.
    timings.append(timeit.timeit(lambda: bur(f(*args)), number=n) / n * 1e6)
t_all, t_1, t_2 = timings

In [13]:
# Tabulate the mean time per call (microseconds) for each variant.
# NOTE(review): the saved output below has execution count In [13], lower than the
# timing cell's In [30], and is identical to the vmap table further down — it is
# probably stale; re-run after the timing cell to refresh it.
for label, micros in (("all", t_all), ("dot_rz", t_1), ("classic", t_2)):
    print(f"{label}\t{micros:.2f} us")

all	5864.69 us
dot_rz	447.63 us
classic	445.72 us


In [8]:
n = 1000  # calls per measurement

# Same comparison, vectorised over the array `thetas` (vmap over the theta
# argument only). Each variant is jitted with ydeg static and warmed up once so
# compilation is excluded. Results are mean wall time per call, in microseconds.
timings = []
for fn, in_axes, args in (
    (
        left_project_all,
        (None, None, None, 0, None, None),
        (l_max, inc, obl, thetas, 0.0, y),
    ),
    (
        left_project_dotrz,
        (None, None, None, 0, None),
        (l_max, inc, obl, thetas, y),
    ),
    (
        left_project_classic,
        (None, None, None, 0, None),
        (l_max, inc, obl, thetas, y),
    ),
):
    # ("ydeg",) is a real 1-tuple; ("ydeg") would just be the string "ydeg".
    f = jax.jit(jax.vmap(fn, in_axes=in_axes), static_argnames=("ydeg",))
    bur(f(*args))  # warm-up: trigger compilation outside the timed region
    # The lambda is called only inside this iteration, so late binding of f/args is safe.
    timings.append(timeit.timeit(lambda: bur(f(*args)), number=n) / n * 1e6)
t_all, t_1, t_2 = timings

In [12]:
# Tabulate the mean time per call (microseconds) for each vmapped variant.
for label, micros in (("all", t_all), ("dot_rz", t_1), ("classic", t_2)):
    print(f"{label}\t{micros:.2f} us")

all	5864.69 us
dot_rz	447.63 us
classic	445.72 us


## A few things to notice

- As I say in the paper, rotating about $y$ instead of switching to $z$ is a bad idea:

In [31]:
# Variant whose first rotation applies theta directly about the y axis, instead of
# conjugating a z rotation with -/+ pi/2 rotations about x as the other variants do.
def left_project_y(ydeg, inc, obl, theta, x):
    """Left-project ``x``: -theta about y, -obl about z, then (pi/2 - inc)
    about (-cos(obl), -sin(obl), 0)."""
    rotation_specs = (
        (None, 1.0, None, -theta),
        (None, None, 1.0, -obl),
        (-jnp.cos(obl), -jnp.sin(obl), 0.0, 0.5 * jnp.pi - inc),
    )
    for ux, uy, uz, rot_angle in rotation_specs:
        x = dot_rotation_matrix(ydeg, ux, uy, uz, rot_angle)(x)
    return x

# Jit the y-axis variant vectorised over theta only (4th argument), with ydeg static.
f = jax.jit(
    jax.vmap(left_project_y, in_axes=(None, None, None, 0, None)),
    static_argnames=("ydeg"),  # a bare string here, equivalent to ("ydeg",)
)
# Warm-up call compiles the function so %timeit measures execution only.
bur(f(l_max, inc, obl, thetas, y))
%timeit bur(f(l_max, inc, obl, thetas, y))

5.72 ms ± 242 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


- `dot_rz` does not necessarily help here; I suspect JAX already does a great job on its own.

# A new version

In [32]:
import sympy as sp

s_inc, s_obl = sp.symbols("inc, obl")

# Quaternions of the fixed rotations applied after the theta rotation, listed in
# the same order as in left_project_classic and multiplied left to right.
quarter_turn_x = sp.Quaternion.from_axis_angle((1.0, 0.0, 0.0), 0.5 * sp.pi)
obliquity_z = sp.Quaternion.from_axis_angle((0, 0, 1.0), -s_obl)
inclination_tilt = sp.Quaternion.from_axis_angle(
    (-sp.cos(s_obl), -sp.sin(s_obl), 0.0), 0.5 * sp.pi - s_inc
)
q = quarter_turn_x * obliquity_z * inclination_tilt

# Collapse the product into a single symbolic axis-angle rotation.
axis, angle = q.to_axis_angle()

In [33]:
axis[0]  # x-component of the composite rotation axis (sympy expression)

1.0*sin(inc/2)*cos(obl/2)/sqrt(-cos(inc/2)**2*cos(obl/2)**2 + 1)

In [34]:
axis[1]  # y-component of the composite rotation axis (sympy expression)

1.0*sin(inc/2)*sin(obl/2)/sqrt(-cos(inc/2)**2*cos(obl/2)**2 + 1)

In [35]:
axis[2]  # z-component of the composite rotation axis (sympy expression)

-1.0*sin(obl/2)*cos(inc/2)/sqrt(-cos(inc/2)**2*cos(obl/2)**2 + 1)

In [36]:
angle  # rotation angle of the composite rotation (sympy expression)

2*acos(1.0*cos(inc/2)*cos(obl/2))

In [37]:
# LaTeX for each axis component followed by the angle, one per line.
print("\n".join(sp.latex(expr) for expr in (*axis, angle)))

\frac{1.0 \sin{\left(\frac{inc}{2} \right)} \cos{\left(\frac{obl}{2} \right)}}{\sqrt{- \cos^{2}{\left(\frac{inc}{2} \right)} \cos^{2}{\left(\frac{obl}{2} \right)} + 1}}
\frac{1.0 \sin{\left(\frac{inc}{2} \right)} \sin{\left(\frac{obl}{2} \right)}}{\sqrt{- \cos^{2}{\left(\frac{inc}{2} \right)} \cos^{2}{\left(\frac{obl}{2} \right)} + 1}}
- \frac{1.0 \sin{\left(\frac{obl}{2} \right)} \cos{\left(\frac{inc}{2} \right)}}{\sqrt{- \cos^{2}{\left(\frac{inc}{2} \right)} \cos^{2}{\left(\frac{obl}{2} \right)} + 1}}
2 \operatorname{acos}{\left(1.0 \cos{\left(\frac{inc}{2} \right)} \cos{\left(\frac{obl}{2} \right)} \right)}


In [82]:
# New variant: fold the fixed inc/obl rotations into one closed-form axis-angle rotation.
def left_project_new(ydeg, inc, obl, theta, x):
    """Left-project ``x`` with two simple rotations plus one composite rotation.

    The first two steps match ``left_project_classic`` (-pi/2 about x, then
    -theta about z). The remaining inc/obl rotations are collapsed into a single
    rotation. It is built from the closed-form quaternion derived with sympy in
    this notebook: scalar part ``w = cos(inc/2) cos(obl/2)`` and vector part
    ``v = (sin(inc/2) cos(obl/2), sin(inc/2) sin(obl/2), -sin(obl/2) cos(inc/2))``.

    Parameters
    ----------
    ydeg : int
        Maximum degree passed to ``dot_rotation_matrix`` (jitted as a static
        argument in this notebook).
    inc, obl, theta : float or scalar array
        Angles, interpreted as radians (they go straight into ``jnp.cos``/``jnp.sin``).
    x : array
        Coefficients to rotate.

    Returns
    -------
    array
        The rotated coefficients.
    """
    x = dot_rotation_matrix(ydeg, 1.0, None, None, -0.5 * jnp.pi)(x)
    x = dot_rotation_matrix(ydeg, None, None, 1.0, -theta)(x)

    co = jnp.cos(obl / 2)
    so = jnp.sin(obl / 2)
    ci = jnp.cos(inc / 2)
    si = jnp.sin(inc / 2)

    # Unit quaternion of the composite rotation: scalar part w, vector part v.
    w = ci * co
    vx = si * co
    vy = si * so
    vz = -so * ci

    # |v| is computed from its components rather than as sqrt(1 - w**2). The latter
    # cancels catastrophically for small angles: e.g. inc ~ 1e-8, obl = 0 gives
    # w == 1.0 exactly, so that denominator is 0 while v is not (-> inf axis).
    norm_sq = vx**2 + vy**2 + vz**2
    has_axis = norm_sq > 0
    # Double-where: keep sqrt/divisions finite on the degenerate branch so the
    # unused branch of jnp.where cannot inject NaN/inf into gradients.
    safe_norm = jnp.sqrt(jnp.where(has_axis, norm_sq, 1.0))

    # Zero rotation (identity): any axis works, so fall back to +x.
    axis_x = jnp.where(has_axis, vx / safe_norm, 1.0)
    axis_y = jnp.where(has_axis, vy / safe_norm, 0.0)
    axis_z = jnp.where(has_axis, vz / safe_norm, 0.0)

    # For a unit quaternion, 2 * atan2(|v|, w) == 2 * arccos(w). The atan2 form
    # stays accurate when |v| is tiny, whereas arccos loses precision near w = 1.
    angle = 2 * jnp.arctan2(jnp.where(has_axis, safe_norm, 0.0), w)

    x = dot_rotation_matrix(ydeg, axis_x, axis_y, axis_z, angle)(x)
    return x

In [83]:
# Benchmark angles for the new implementation. These overwrite the inc/obl/theta
# set at the top of the notebook, so earlier cells re-run afterwards will use them.
inc, obl, theta = 0.5 * np.pi + 0.5, -0.2, 0.2

In [84]:
# Check that the collapsed rotation in left_project_new matches the step-by-step
# left_project_classic. The angles are plain floats fed to jnp.cos/jnp.sin, i.e.
# they are interpreted as radians (30.0 means 30 rad, not 30 degrees).
theta_check = 0.5
for deg in [2, 5, 10]:
    w = np.linspace(-1, 1, deg**2 + 2 * deg + 1)
    for i in [0.0, 30.0, 60.0, 90.0]:
        for o in [-30.0, 0.0, 30.0]:
            # Same tolerances and NaN handling as np.allclose(classic, new), but a
            # failure now reports the offending parameters and mismatching entries.
            np.testing.assert_allclose(
                left_project_classic(deg, i, o, theta_check, w),
                left_project_new(deg, i, o, theta_check, w),
                rtol=1e-5,
                atol=1e-8,
                equal_nan=False,
                err_msg=f"ydeg={deg}, inc={i}, obl={o}, theta={theta_check}",
            )

print("all passed!")

all passed!


In [85]:
# Single-theta timing of left_project_new, jitted with ydeg static; the warm-up call
# compiles it so %timeit measures execution only.
# NOTE(review): inc/obl/theta were reassigned earlier in this section (pi/2 + 0.5,
# -0.2, 0.2), so the input differs from the earlier single-theta timings.
f = jax.jit(left_project_new, static_argnames=("ydeg"))
bur(f(l_max, inc, obl, theta, y))
%timeit bur(f(l_max, inc, obl, theta, y))

35.3 µs ± 955 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [86]:
# left_project_new vectorised over `thetas` only (4th argument), jitted with ydeg static.
# NOTE(review): inc/obl here are the values reassigned earlier in this section.
f = jax.jit(
    jax.vmap(left_project_new, in_axes=(None, None, None, 0, None)),
    static_argnames=("ydeg"),  # a bare string here, equivalent to ("ydeg",)
)
# Warm-up call compiles the function so %timeit measures execution only.
bur(f(l_max, inc, obl, thetas, y))
%timeit bur(f(l_max, inc, obl, thetas, y))

320 µs ± 4.35 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
