In [None]:
from pathlib import Path

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.animation import FFMpegWriter
from tqdm import tqdm

from lieflow.groups import TSn

In [None]:
# Instantiate the group TS(2) from lieflow. The cells below represent its elements as
# length-3 tensors laid out as [t (2 entries), s (1 entry)] and operate on them through
# ts2.L, ts2.L_inv, ts2.log and ts2.exp.
ts2 = TSn(2)

In [None]:
# Sample two random TS(2) elements g_1, g_2 and the Lie-algebra element A that connects them.
# Each element is a tensor [t (2 entries), s (1 entry)] with t ~ N(0, I_2) and s ~ U[-1, 1).
# A fixed seed keeps the notebook reproducible under Restart Kernel -> Run All.
rng = np.random.default_rng(seed=0)

t_1 = rng.standard_normal(2)
s_1 = rng.uniform(-1.0, 1.0, size=1)
# torch.tensor with the default dtype matches the legacy torch.Tensor(...) result (float32 by
# default) and keeps g_1 / g_2 consistent with torch.linspace used later.
g_1 = torch.tensor(np.hstack((t_1, s_1)), dtype=torch.get_default_dtype())
t_2 = rng.standard_normal(2)
s_2 = rng.uniform(-1.0, 1.0, size=1)
g_2 = torch.tensor(np.hstack((t_2, s_2)), dtype=torch.get_default_dtype())

# Algebra element A = log(L_inv(g_1, g_2)); presumably L_inv(g_1, g_2) = g_1^{-1} g_2 so that
# L(g_1, exp(A)) recovers g_2 -- NOTE(review): confirm against lieflow's TSn.
A = ts2.log(ts2.L_inv(g_1, g_2))

In [None]:
# Sample the exponential curve t -> L(g_1, exp(t * A)) at 101 evenly spaced times in [0, 1].
# ts.unsqueeze(-1) * A.unsqueeze(0) broadcasts to one scaled copy of A per time step
# (assumes A is a 1-D algebra vector -- TODO confirm).
ts = torch.linspace(start=0.0, end=1.0, steps=101)
g_t = ts2.L(g_1, ts2.exp(ts.unsqueeze(-1) * A.unsqueeze(0)))

In [None]:
def plot_coordinate_projections(curve, marker_stride=20):
    """Plot the three pairwise coordinate projections of a sampled TS(2) curve.

    Parameters
    ----------
    curve : array-like, indexed as curve[:, k] for k in {0, 1, 2}
        Sampled curve whose columns are [t[0], t[1], s], matching the
        np.hstack((t, s)) layout used to build the group elements in this notebook.
    marker_stride : int, default 20
        A scatter marker is drawn on every `marker_stride`-th sample to show progression.

    Returns
    -------
    (fig, axes)
        The matplotlib figure and its array of three Axes.
    """
    labels = ("t[0]", "t[1]", "s")
    # (x column, y column, y-axis limits) per panel; the x-axis is always limited to (-2, 2).
    panels = ((0, 1, (-2, 2)), (0, 2, (-1, 1)), (1, 2, (-1, 1)))

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for panel_ax, (x_col, y_col, y_lim) in zip(axes, panels):
        panel_ax.plot(curve[:, x_col], curve[:, y_col])
        panel_ax.scatter(curve[::marker_stride, x_col], curve[::marker_stride, y_col])
        panel_ax.set_xlim(-2, 2)
        panel_ax.set_ylim(*y_lim)
        panel_ax.set_xlabel(labels[x_col])
        panel_ax.set_ylabel(labels[y_col])
    fig.suptitle("Exponential curve in TS(2)")
    return fig, axes


fig, ax = plot_coordinate_projections(g_t)

In [None]:
# Render one random exponential curve in TS(2) per video frame and save them as an MP4.
# Each frame samples a fresh pair of elements, computes A = log(L_inv(g_start, g_end)) and
# plots L(g_start, exp(t * A)) projected onto coordinates 0 and 1 (the t part).
# NOTE(review): FFMpegWriter requires an `ffmpeg` executable on PATH.
metadata = {'title': 'Exponential Curve Interpolation TS(2)', 'artist': 'Matplotlib'}
writer = FFMpegWriter(fps=2, metadata=metadata)
video_path = "output/exponential_curves_TS2.mp4"

n_frames = 50
marker_stride = 20
# Loop-invariant sample times, hoisted out of the frame loop.
curve_times = torch.linspace(0., 1., 101)
# Number of markers drawn per curve (6 for 101 samples with stride 20).
n_markers = len(curve_times[::marker_stride])

# Blue -> red colormap with one color per marker, so markers show the direction of travel.
color_step = 1 / max(n_markers - 1, 1)  # guard against division by zero for a single marker
colors = [(j * color_step, 0.1, 1 - j * color_step) for j in range(n_markers)]
cmap = mcolors.ListedColormap(colors)

# Separate, seeded generator: the video is reproducible and this cell does not disturb
# the random state or variables (g_1, A, ts, g_t, ...) of the earlier cells.
frame_rng = np.random.default_rng(seed=1)


def sample_ts2_element(rng):
    """Draw a random TS(2) element as a tensor [t (2 entries), s (1 entry)].

    t ~ N(0, I_2) and s ~ U[-1, 1), matching the sampling used elsewhere in this notebook.
    """
    t = rng.standard_normal(2)
    s = rng.uniform(-1.0, 1.0, size=1)
    return torch.tensor(np.hstack((t, s)), dtype=torch.get_default_dtype())


# The FFmpeg subprocess cannot create missing directories, so make sure output/ exists.
Path(video_path).parent.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
with writer.saving(fig, video_path, dpi=150):
    for _ in tqdm(range(n_frames)):
        g_start = sample_ts2_element(frame_rng)
        g_end = sample_ts2_element(frame_rng)
        A_frame = ts2.log(ts2.L_inv(g_start, g_end))
        curve = ts2.L(g_start, ts2.exp(curve_times[..., None] * A_frame[None, ...]))

        # ax.clear() wipes limits and labels too, so they are re-applied every frame.
        ax.clear()
        ax.plot(curve[:, 0], curve[:, 1])
        ax.scatter(curve[::marker_stride, 0], curve[::marker_stride, 1],
                   c=np.arange(n_markers), cmap=cmap)
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        ax.set_xlabel("t[0]")
        ax.set_ylabel("t[1]")
        writer.grab_frame()

# The video is the output; close the figure so a stray last frame is not rendered inline.
plt.close(fig)