In [None]:
import torch
from lieflow.groups import M3
import matplotlib.pyplot as plt
%matplotlib widget

In [None]:
# Seed the global torch RNG once, before any random draw in this notebook
# (random poses, random generators, and the random fibre rotation inside
# `lift_to_SE3`), so that Restart & Run All reproduces every figure.
SEED = 0
torch.manual_seed(SEED)

# Generator passed to the M3 constructor; presumably the default used by
# `m3.get_generator` when no `generator=` argument is given — TODO confirm.
# Alternative: m3 = M3(generator="mav")
m3 = M3(generator="pure_rotation")

# Reference poses at the origin, oriented along the +x, +y and +z axes.
x_0 = torch.zeros(3)
n_0_x = torch.tensor([1., 0., 0.])
n_0_y = torch.tensor([0., 1., 0.])
n_0_z = torch.tensor([0., 0., 1.])
p_0_x = m3.pack_position_orientation(x_0, n_0_x)
p_0_y = m3.pack_position_orientation(x_0, n_0_y)
p_0_z = m3.pack_position_orientation(x_0, n_0_z)

def g_to_frame(g):
    """Push the three reference poses p_0_x, p_0_y, p_0_z through g.

    Returns the tuple ``(g·p_0_x, g·p_0_y, g·p_0_z)``, each computed with
    ``m3.act``. The notebook plots these in red/green/blue to draw g as a frame.
    """
    reference_poses = (p_0_x, p_0_y, p_0_z)
    return tuple(m3.act(g, p_ref) for p_ref in reference_poses)

def lift_to_SE3(p):
    """Lift pose(s) p to SE(3), picking a random element of the fibre over p.

    The result is ``m3.se3.L(exp(A), R)``, where:
    - ``A = m3.get_generator(p_0_z, p)`` uses m3's configured generator, and
    - ``R = exp(α · lie_algebra_basis[-1])`` with one α ~ U[0, 2π) per batch
      element. The batch shape is ``p.shape[:-2]``, so p is assumed to carry
      two trailing pose dimensions.

    NOTE(review): ``m3.se3.L`` is presumably the group product exp(A)·R. The
    result projects back to p only if ``lie_algebra_basis[-1]`` generates
    rotations that fix p_0_z (presumably rotation about z) — confirm.
    """
    generator = m3.get_generator(p_0_z, p)
    batch_shape = p.shape[:-2]
    angle = 2 * torch.pi * torch.rand(*batch_shape, 1, 1)
    fibre_rotation = m3.se3.exp(angle * m3.se3.lie_algebra_basis[-1])
    to_p = m3.se3.exp(generator)
    return m3.se3.L(to_p, fibre_rotation)
    
def project_to_M3(g):
    """Project group element(s) g onto M3.

    Acts on the reference pose p_0_z (origin, oriented along +z).
    """
    reference_pose = p_0_z
    projected = m3.act(g, reference_pose)
    return projected

def plot_p(ax, p, **kwargs):
    """Draw pose(s) p as arrows on a matplotlib 3D axis.

    Each arrow starts at the pose's position and points along its orientation,
    both unpacked with ``m3.get_position_orientation``. Only components 0..2 of
    the last dimension are used. Extra keyword arguments are forwarded to
    ``ax.quiver``; the quiver artist is returned.
    """
    position, orientation = m3.get_position_orientation(p)
    tails = [position[..., i] for i in range(3)]
    directions = [orientation[..., i] for i in range(3)]
    return ax.quiver(*tails, *directions, **kwargs)

In [None]:
# Sample random poses (Gaussian positions, uniformly distributed unit
# orientations), lift each to SE(3) and draw the lifted frames.
n_samples, spread = 32, 1.
x = spread * torch.randn(n_samples, 3)
n = torch.randn(n_samples, 3)
n = n / n.norm(dim=-1, keepdim=True)
p = m3.pack_position_orientation(x, n)
g = lift_to_SE3(p)
frame = g_to_frame(g)
p_x, p_y, p_z = frame

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
# The sampled poses in black, the lifted frame axes in red/green/blue.
plot_p(ax, p, length=0.1, color="k")
for frame_axis, axis_colour in zip((p_x, p_y, p_z), ("red", "green", "blue")):
    plot_p(ax, frame_axis, length=0.2, color=axis_colour)
ax.set_aspect("equal")

In [None]:
# Random element A of se(3): coefficients ~ N(0, 2²) on the basis generators.
components = 2. * torch.randn(6)
A = (m3.se3.lie_algebra_basis * components[:, None, None]).sum(dim=-3)
g = m3.se3.exp(A)

# One-parameter subgroup t ↦ exp(tA), sampled at 20 points in [0, 1].
ts = torch.linspace(0., 1., 20)
g_t = m3.se3.exp(ts[:, None, None] * A)
p_x_t, p_y_t, p_z_t = g_to_frame(g_t)
p_x, p_y, p_z = g_to_frame(g)

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
axis_colours = ("red", "green", "blue")
# Draw the whole curve first, then the end frame g = exp(A) on top of it.
for frame in ((p_x_t, p_y_t, p_z_t), (p_x, p_y, p_z)):
    for axis_pose, axis_colour in zip(frame, axis_colours):
        plot_p(ax, axis_pose, length=0.1, alpha=0.5, color=axis_colour)
ax.set_aspect("equal");

In [None]:
# Sanity check: exp of the generator from p_x_t to p_y_t should map each p_x_t
# onto the corresponding p_y_t. Show the worst-case residual norm over the batch.
A = m3.get_generator(p_x_t, p_y_t)
g = m3.se3.exp(A)
residual = m3.act(g, p_x_t) - p_y_t
residual.pow(2).sum(dim=(-2, -1)).sqrt().max()

In [None]:
# Compare the curves t ↦ exp(tA)·p_1 from p_1 to p_2 obtained with the "mav"
# generator and with angle-parameterised generators (a float φ passed as
# `generator=`).
x_1 = torch.tensor([0., 0., 0.])
n_1 = torch.tensor([0., 0., 1.])
p_1 = m3.pack_position_orientation(x_1, n_1)
x_2 = torch.randn(3)
n_2 = torch.randn(3)
# Deterministic alternatives for debugging:
#   x_2 = torch.tensor([1., 0., 0.])
#   n_2 = torch.tensor([0., 0.1, 1.])
#   n_2 = n_1 + 0.05 * torch.randn(3)
n_2 = n_2 / (n_2**2).sum().sqrt()
p_2 = m3.pack_position_orientation(x_2, n_2)

ts = torch.linspace(0., 1., 20)

# (generator argument, legend label, colour). φ = 0.499π probes the behaviour
# just below φ = π/2. Labels use three decimals: with two, 0.499π and 0.5π both
# rendered as "φ = 0.50π" and could not be told apart in the legend.
φ_values = (torch.pi / 6., torch.pi / 4., torch.pi / 3., torch.pi * 0.499, torch.pi / 2.)
φ_colours = ("blue", "green", "orange", "yellow", "purple")
curves = [("mav", "mav", "red")] + [
    (φ, f"φ = {φ / torch.pi:.3f}π", colour) for φ, colour in zip(φ_values, φ_colours)
]

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
for generator, label, colour in curves:
    A = m3.get_generator(p_1, p_2, generator=generator)
    g_t = m3.se3.exp(ts[..., None, None] * A)
    plot_p(ax, m3.act(g_t, p_1), length=0.1, alpha=0.5, color=colour, label=label)
# Endpoints in black so they are not confused with the red "mav" curve.
plot_p(ax, p_1, length=0.1, color="k")
plot_p(ax, p_2, length=0.1, color="k")
ax.set_aspect("equal")
ax.legend();

In [None]:
# Fixed start pose p_1 at the origin pointing along +z; random unit end pose p_2.
x_1 = torch.tensor([0., 0., 0.])
n_1 = torch.tensor([0., 0., 1.])
p_1 = m3.pack_position_orientation(x_1, n_1)
x_2 = torch.randn(3)
n_2 = torch.randn(3)
# Deterministic alternatives for debugging:
#   x_2 = torch.tensor([1., 0., 0.])
#   n_2 = torch.tensor([0., 0.1, 1.])
#   n_2 = n_1 + 0.05 * torch.randn(3)
n_2 = n_2 / (n_2**2).sum().sqrt()
p_2 = m3.pack_position_orientation(x_2, n_2)

# Curve t ↦ exp(tA)·p_1, t in [0, 1], for the "mav" and "pure_rotation" generators.
ts = torch.linspace(0., 1., 20)
curves = {}
for generator in ("mav", "pure_rotation"):
    A = m3.get_generator(p_1, p_2, generator=generator)
    g_t = m3.se3.exp(ts[..., None, None] * A)
    curves[generator] = m3.act(g_t, p_1)
p_t_mav, p_t_rot = curves["mav"], curves["pure_rotation"]

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
for p_t, colour, label in ((p_t_mav, "blue", "mav"), (p_t_rot, "green", "pure rotation")):
    plot_p(ax, p_t, length=0.1, alpha=0.5, color=colour, label=label)
for endpoint in (p_1, p_2):
    plot_p(ax, endpoint, length=0.1, color="red")
ax.set_aspect("equal")
ax.legend();

In [None]:
def cross_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Cross product along the last dimension of x and y.

    The inputs are first broadcast against each other, so e.g. a batch of
    vectors can be crossed with a single vector.
    """
    x_full, y_full = torch.broadcast_tensors(x, y)
    return torch.linalg.cross(x_full, y_full)

# Random pair of poses (x_1, n_1), (x_2, n_2) with unit orientations.
x_1 = torch.randn(3)
n_1 = torch.randn(3)
n_1 = n_1 / (n_1**2).sum().sqrt()
x_2 = torch.randn(3)
n_2 = torch.randn(3)
n_2 = n_2 / (n_2**2).sum().sqrt()
x_diff = x_2 - x_1

# Orthonormal basis {k0, kpi2} of the plane orthogonal to n_1 - n_2, i.e. of all
# k with k·n_1 == k·n_2. These are exactly the axes about which some rotation
# takes n_1 to n_2. k0 ∝ n_1 × n_2 and kpi2 ∝ n_1 + n_2 are mutually orthogonal.
k0 = cross_product(n_1, n_2)
k0 = k0 / (k0**2).sum().sqrt()
kpi2 = n_1 + n_2
kpi2 = kpi2 / (kpi2**2).sum().sqrt()


# Sweep the unit rotation axis k(φ) = cos φ · k0 + sin φ · kpi2 around that circle.
φs = torch.linspace(-torch.pi, torch.pi, 100)
ks = torch.cos(φs)[:, None] * k0 + torch.sin(φs)[:, None] * kpi2
# thingy = tan θ, where θ is the angle of the rotation about k that takes n_1 to
# n_2. Expanding Rodrigues' formula, the numerator is sin θ (1 - (k·n_1)²) and
# the denominator is cos θ (1 - (k·n_1)²). The ratio diverges where cos θ = 0.
thingy = (ks * cross_product(n_1, n_2)).sum(dim=-1) / ((n_1 * n_2).sum(dim=-1) - (ks * n_1).sum(dim=-1) * (ks * n_2).sum(dim=-1))
# Note sqrt(thingy**2 + 1) == 1 / |cos θ|.
# NOTE(review): the meaning of the sign of `constraint` is not stated here — TODO document.
constraint = (
    -thingy * (cross_product(ks, x_diff) * (n_1 + n_2)).sum(dim=-1) +
    (torch.sqrt(thingy**2 + 1) + 1) * (x_diff * (n_1 - n_2)).sum(dim=-1)
)
# NOTE(review): closes the *current* figure, presumably the one from the previous cell.
plt.close()
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(φs, constraint)
ax.hlines(0, -torch.pi, torch.pi)
ax.set_xlabel("φ (rad), axis k(φ) = cos φ·k0 + sin φ·kpi2")
ax.set_ylabel("constraint")
ax.set_title("constraint along the circle of rotation axes");

In [None]:
# Boolean mask over the sampled φ values: True where the constraint is strictly positive.
constraint.gt(0.)