In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.animation import FFMpegWriter
from scipy.stats import special_ortho_group

from lieflow.groups import SO3
# %matplotlib widget

In [None]:
# Scratch check: arccos evaluated on 10 evenly spaced points spanning [-1, 1].
torch.linspace(-1., 1., 10).arccos()

In [None]:
# Shared SO3 group object; the cells below call its log, exp, L, L_inv and
# lie_algebra_components methods.
so3 = SO3()

In [None]:
def _axis_rotations(axis, angles):
    """Build a batch of 3x3 rotation matrices about one coordinate axis.

    Args:
        axis: One of "x", "y" or "z".
        angles: Array-like of N angles in radians.

    Returns:
        Float array of shape (N, 3, 3), one matrix per angle.

    Raises:
        ValueError: If `axis` is not "x", "y" or "z".
    """
    angles = np.asarray(angles, dtype=float)
    cos_a, sin_a = np.cos(angles), np.sin(angles)
    zero, one = np.zeros_like(cos_a), np.ones_like(cos_a)
    if axis == "x":
        rows = ((one, zero, zero), (zero, cos_a, -sin_a), (zero, sin_a, cos_a))
    elif axis == "y":
        # Sign layout kept exactly as in the notebook: -sin sits in the
        # top-right entry and +sin in the bottom-left one.
        rows = ((cos_a, zero, -sin_a), (zero, one, zero), (sin_a, zero, cos_a))
    elif axis == "z":
        rows = ((cos_a, -sin_a, zero), (sin_a, cos_a, zero), (zero, zero, one))
    else:
        raise ValueError(f"axis must be 'x', 'y' or 'z', got {axis!r}")
    # `rows` stacks to shape (3, 3, N); move the two matrix axes to the end.
    return np.moveaxis(np.array(rows), (0, 1), (-2, -1))


def generate_circle(N, d=0., w=np.pi/4, ε=0.05, gap=0., rng=None):
    """Sample N noisy rotations of the form R_z(d) @ R_x(θ) @ R_y(w).

    The x-angle θ is drawn uniformly from [-(1.5 - gap)·π, 0.5·π), i.e. an arc
    of length (2 - gap)·π. The y-angle and z-angle are `w` and `d` perturbed by
    independent Gaussian noise with standard deviation `ε`.

    Args:
        N: Number of samples.
        d: Mean angle (radians) of the z-axis factor.
        w: Mean angle (radians) of the y-axis factor.
        ε: Standard deviation of the Gaussian noise added to `w` and `d`.
        gap: Shortens the sampled x-angle arc by gap·π radians.
        rng: Optional `numpy.random.Generator` (or `RandomState`) used for all
            draws. When omitted, NumPy's global random state is used.

    Returns:
        `torch.Tensor` of shape (N, 3, 3) holding the sampled matrices.
    """
    if rng is None:
        draw_normal, draw_uniform = np.random.randn, np.random.rand
    else:
        draw_normal, draw_uniform = rng.standard_normal, rng.random

    # Keep the draw order y-noise -> x-angle -> z-noise so that a seeded
    # global state yields the same samples as before the refactor.
    w_ε = w + draw_normal(N) * ε
    θs_x = (draw_uniform(N) * (2 - gap) - (1.5 - gap)) * np.pi
    d_ε = d + draw_normal(N) * ε

    R_y = _axis_rotations("y", w_ε)
    R_x = _axis_rotations("x", θs_x)
    R_z = _axis_rotations("z", d_ε)
    return torch.Tensor(R_z @ R_x @ R_y)

In [None]:
# Sample count passed to generate_R below (2**14 = 16384).
N = 1 << 14

In [None]:
def generate_R(N):
    """Return N samples from generate_circle with the gap disabled (gap=0)."""
    return generate_circle(N, gap=0.)

In [None]:
# Draw the batch of N sampled rotations analysed in the following cells.
R = generate_R(N)

In [None]:
# Take the group logarithm of every sample and extract its Lie-algebra
# components (assumes one component vector per sample along the last axis,
# which is how the norm cell below reduces it — TODO confirm against lieflow).
a = so3.lie_algebra_components(so3.log(R))

In [None]:
# Euclidean norm of `a`, reduced over its last axis.
norms = torch.sqrt(a.pow(2).sum(dim=-1))

In [None]:
# Log of the norms, followed by the number of samples whose log-norm exceeds 1.
log_norms = norms.log()
torch.sum(log_norms > 1.)

In [None]:
# Scratch value e**2 (≈ 7.389); presumably the norm that corresponds to a
# log-norm of 2 in the cells around it — confirm or delete.
np.exp(2)

In [None]:
# Histogram (20 bins) of the log-norms restricted to the values above 1.
fig, ax = plt.subplots(nrows=1, ncols=1)
ax.hist(log_norms[torch.gt(log_norms, 1.)], bins=20);

In [None]:
# One random rotation: display R, its logarithm, and the largest entry-wise
# deviation of exp(log(R)) from R.
R = torch.Tensor(special_ortho_group.rvs(3, size=1))
R, so3.log(R), torch.max(torch.abs(R - so3.exp(so3.log(R))))

In [None]:
# Batch of 50 random rotations: largest entry-wise deviation of exp(log(R))
# from R. The samples are wrapped into a tensor once; wrapping the resulting
# tensor a second time added nothing.
R = torch.Tensor(special_ortho_group.rvs(3, size=50))
(R - so3.exp(so3.log(R))).abs().max()

In [None]:
# Curve from the identity R_0 to a random rotation R_1, sampled at 100 values
# of t in [0, 1]: R_t = L(R_0, exp(t * A)) with A = log(L_inv(R_0, R_1)).
# NOTE(review): L / L_inv look like left multiplication by R_0 and by its
# inverse — confirm against lieflow.
R_0 = torch.eye(3)
R_1 = torch.Tensor(special_ortho_group.rvs(3, size=1))
A = so3.log(so3.L_inv(R_0, R_1))
ts = torch.linspace(0., 1., steps=100)
# Add two trailing singleton axes to t so it broadcasts against the 3x3 matrix A.
R_t = so3.L(R_0, so3.exp(ts[:, None, None] * A))

In [None]:
# Draw the columns of R_t as red/green/blue arrows anchored at (t, 0, 0).
skip = 10
axis_colors = ("red", "green", "blue")
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
# Every `skip`-th sample along the curve.
for axis_idx, axis_color in enumerate(axis_colors):
    ax.quiver(
        ts[::skip], 0., 0.,
        R_t[::skip, 0, axis_idx], R_t[::skip, 1, axis_idx], R_t[::skip, 2, axis_idx],
        length=0.1, color=axis_color,
    )
# The strided slice stops before the last sample, so the end point is drawn separately.
for axis_idx, axis_color in enumerate(axis_colors):
    ax.quiver(
        ts[-1], 0., 0.,
        R_t[-1, 0, axis_idx], R_t[-1, 1, axis_idx], R_t[-1, 2, axis_idx],
        length=0.1, color=axis_color,
    )
ax.set_xlim(0, 1)
ax.set_xlabel("x = t")
# Only the x axis carries information (t); hide ticks on the other two.
ax.set_ylim(-0.1, 0.1)
ax.set_yticks([])
ax.set_zlim(-0.1, 0.1)
ax.set_zticks([])
ax.set_aspect("equal");

In [None]:
# q_0 is a 3x2 matrix whose columns are (1, 0, 0) and (0, 0, 1); the plots
# below use column 0 as a point and column 1 as the arrow drawn at that point.
q_0 = torch.Tensor(((1., 0., 0.), (0., 0., 1.))).t()
# Apply every rotation along the curve, and the end rotation, to both columns.
q_t = torch.matmul(R_t, q_0)
q_1 = torch.matmul(R_1, q_0)

In [None]:
# Unit sphere on a 100 x 50 grid of (phi, theta) values.
theta = np.linspace(0, np.pi, 50)
phi = np.linspace(0, 2 * np.pi, 100)
theta, phi = np.meshgrid(theta, phi)

sin_theta = np.sin(theta)
x = sin_theta * np.cos(phi)
y = sin_theta * np.sin(phi)
z = np.cos(theta)
# Points are scaled by r before plotting, i.e. drawn slightly off the unit
# sphere (presumably so the surface does not hide them).
r = 1.01

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=15, azim=15)
ax.plot_surface(x, y, z, color='cyan', alpha=0.25, edgecolor=None)

# Path traced by the first column of q_t.
ax.plot(*(r * q_t[:, i, 0] for i in range(3)))

# Arrows (second column of q_t) at every 10th interior sample of the path.
inter = slice(10, -2, 10)
ax.quiver(
    *(r * q_t[inter, i, 0] for i in range(3)),
    *(q_t[inter, i, 1] for i in range(3)),
    length=0.25
)

# Start pair in blue, end pair in green.
for pair, arrow_color in ((q_0, "blue"), (q_1, "green")):
    ax.quiver(
        *(r * pair[:, 0]),
        *pair[:, 1],
        length=0.25, color=arrow_color
    )

# Same limits on all three axes, no ticks.
for set_lim, set_ticks in (
    (ax.set_xlim, ax.set_xticks),
    (ax.set_ylim, ax.set_yticks),
    (ax.set_zlim, ax.set_zticks),
):
    set_lim([-1, 1])
    set_ticks([])
ax.set_aspect("equal");


In [None]:
# Render a 50-frame video; each frame shows the exponential curve from the
# identity to a freshly sampled random rotation as a row of coordinate frames.
metadata = {'title': 'Exponential Curve Interpolation SO(3)', 'artist': 'Matplotlib'}
writer = FFMpegWriter(fps=2, metadata=metadata)

# Make sure the target directory exists before the writer opens the file.
Path("output").mkdir(parents=True, exist_ok=True)

R_0 = torch.eye(3)
# The time grid is the same for every frame, so build it once.
ts = torch.linspace(0., 1., 10)
axis_colors = ("red", "green", "blue")
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")
with writer.saving(fig, "output/exponential_curves_SO3.mp4", dpi=150):
    for frame in range(50):
        R_1 = torch.Tensor(special_ortho_group.rvs(3, size=1))

        A = so3.log(so3.L_inv(R_0, R_1))
        R_t = so3.L(R_0, so3.exp(ts[..., None, None] * A))

        # clear() also resets limits, labels and ticks, so they are set again below.
        ax.clear()
        for axis_idx, axis_color in enumerate(axis_colors):
            ax.quiver(
                ts, 0., 0.,
                R_t[:, 0, axis_idx], R_t[:, 1, axis_idx], R_t[:, 2, axis_idx],
                length=0.1, color=axis_color,
            )
        ax.set_xlim(0, 1)
        ax.set_xlabel("t")
        ax.set_ylim(-0.1, 0.1)
        ax.set_yticks([])
        ax.set_zlim(-0.1, 0.1)
        ax.set_zticks([])
        ax.set_aspect("equal")
        writer.grab_frame()

In [None]:
# Render a 50-frame video on the sphere; each frame shows the path of q_0
# under the exponential curve from the identity to a fresh random rotation.
# Uses the sphere mesh (x, y, z) and the scale r defined in the sphere-plot cell.
metadata = {'title': 'Exponential Curve Interpolation W_2', 'artist': 'Matplotlib'}
writer = FFMpegWriter(fps=2, metadata=metadata)

# Make sure the target directory exists before the writer opens the file.
Path("output").mkdir(parents=True, exist_ok=True)

R_0 = torch.eye(3)
q_0 = torch.Tensor(((1., 0., 0.), (0., 0., 1.))).T
# The time grid is the same for every frame, so build it once.
ts = torch.linspace(0., 1., 100)
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=15, azim=15)
# Static background: the sphere and the blue start arrow are drawn once.
ax.plot_surface(x, y, z, color='cyan', alpha=0.25, edgecolor=None)
ax.quiver(r*q_0[0, 0], r*q_0[1, 0], r*q_0[2, 0], q_0[0, 1], q_0[1, 1], q_0[2, 1], length=0.25, color="blue")
# Empty placeholders so the first loop iteration has artists to remove.
quiver_inter = ax.quiver([], [], [], [], [], [])
quiver_end = ax.quiver([], [], [], [], [], [], color="green")
plot, = ax.plot([], [], [])

ax.set_xlim([-1, 1])
ax.set_xticks([])
ax.set_ylim([-1, 1])
ax.set_yticks([])
ax.set_zlim([-1, 1])
ax.set_zticks([])
ax.set_aspect("equal")
with writer.saving(fig, "output/exponential_curves_W2.mp4", dpi=150):
    for frame in range(50):
        R_1 = torch.Tensor(special_ortho_group.rvs(3, size=1))

        A = so3.log(so3.L_inv(R_0, R_1))
        R_t = so3.L(R_0, so3.exp(ts[..., None, None] * A))
        q_t = R_t @ q_0
        q_1 = R_1 @ q_0

        # Drop the previous frame's dynamic artists; the background stays.
        plot.remove()
        quiver_inter.remove()
        quiver_end.remove()

        plot, = ax.plot(r*q_t[:, 0, 0], r*q_t[:, 1, 0], r*q_t[:, 2, 0])
        # Arrows at every 10th interior sample of the path.
        quiver_inter = ax.quiver(
            r*q_t[10:-2:10, 0, 0], r*q_t[10:-2:10, 1, 0], r*q_t[10:-2:10, 2, 0],
            q_t[10:-2:10, 0, 1], q_t[10:-2:10, 1, 1], q_t[10:-2:10, 2, 1],
            length=0.25
        )
        quiver_end = ax.quiver(
            r*q_1[0, 0], r*q_1[1, 0], r*q_1[2, 0],
            q_1[0, 1], q_1[1, 1], q_1[2, 1],
            length=0.25, color="green"
        )
        writer.grab_frame()