# Geodesic computation example
This example shows how to compute shortest paths with respect to the learned Riemannian metric (*i.e.* geodesics).

In [None]:

...  # Code used before (i.e. by running the "making_your_own_autoencoder.ipynb" tutorial)

# Retrieve your trained RHVAE model: it was saved under
# my_model_with_custom_archi/<last_training>/final_model
model_rec = RHVAE.load_from_folder(
    os.path.join("my_model_with_custom_archi", last_training, "final_model")
)
model_rec

## Define the geodesic models
This model is optimized to minimize the length of the path $\gamma$ traveling from $z_1 \in \mathcal{M}$ to $z_2 \in \mathcal{M}$, two points of the manifold:

$$
\inf \limits _{\gamma} L(\gamma), \hspace{5mm} L(\gamma) = \int \limits _{0}^1 \lVert \dot{\gamma}(t)\rVert_{\gamma(t)} \, dt = \int \limits _{0}^1 \sqrt{\dot{\gamma}(t)^{\top} \mathbf{G}(\gamma(t))\, \dot{\gamma}(t)} \, dt \hspace{5mm} \text{s.t.} \hspace{5mm} \gamma(0)=z_1, \; \gamma(1)=z_2
$$
where $\mathbf{G}(z)$ is the learned Riemannian metric evaluated at the point $z$ of the latent space.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from torch.autograd import grad


class Geodesic_autodiff(nn.Module):
    """Neural parametrization of a latent-space curve, fitted to be a geodesic.

    A small MLP maps a time ``t`` in [0, 1] to a raw latent point. ``fit``
    applies an affine map ``a * net(t) - b`` chosen so that the curve meets its
    boundary conditions, then minimizes a discretized length of the curve
    under ``metric``. The boundary conditions are either start and end points,
    or a start point and an initial velocity.
    """

    def __init__(
        self,
        metric=None,
        starting_pos=None,
        ending_pos=None,
        starting_velo=None,
        latent_dim=2,
        reg=0.0,
        granularity=100,
        early_stopping=100,
        seed=8,
    ):
        """
        Geodesic NN model

        Inputs:
        -------

        metric (function): The metric used to compute the geodesic path. It is
            called on a latent point of shape (1, latent_dim), and its output G
            must be usable in ``v.T @ G @ v`` for a (latent_dim, 1) column v.
        starting_pos (tensor): The starting point of the path, shape
            (1, latent_dim). Defaults to the origin.
        ending_pos (tensor): The ending point of the path. Defaults to a tensor
            of ones shaped like ``starting_pos``.
        starting_velo (tensor) [optional]: The initial velocity. When given, the
            curve is constrained by (starting_pos, starting_velo) instead of
            (starting_pos, ending_pos).
        latent_dim (int): Latent space dimension
        reg (float): Weight of the penalty on the norm of the metric along the
            curve
        granularity (int): The discretization granularity (number of time steps)
        early_stopping (int): Stop fitting after this many epochs without
            improvement (disabled when <= 0)
        seed (int): Seed passed to ``torch.manual_seed``
        """
        torch.manual_seed(seed)
        super().__init__()

        self.compute_with_ending_point = False
        self.compute_with_velo = False
        self.device = 'cpu'
        self.early_stopping = early_stopping

        if starting_pos is None:
            starting_pos = torch.zeros(1, latent_dim).to(self.device)

        if ending_pos is None:
            ending_pos = torch.ones_like(starting_pos).to(self.device)
        else:
            self.compute_with_ending_point = True

        if starting_velo is None:
            starting_velo = torch.zeros(1, latent_dim).to(self.device)
        else:
            self.compute_with_velo = True

        self.starting_pos = starting_pos
        self.ending_pos = ending_pos
        self.starting_velo = starting_velo

        self.metric = metric
        self.reg = reg
        self.gran = granularity
        self.latent_dim = latent_dim
        self.length = None  # set on the model returned by `fit`

        self.fc1 = nn.Linear(1, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, latent_dim)

    def forward(self, t):
        """
        The geodesic model: raw (not yet reparametrized) curve value at time t.
        """
        h1 = torch.tanh(self.fc1(t))
        h2 = torch.tanh(self.fc2(h1))
        out = self.fc3(h2)

        return out

    def loss_function(self, curve_t, gt_):
        """Integrand of the length functional at one curve point.

        curve_t: curve point, shape (1, latent_dim)
        gt_: time derivative of the curve at that point, shape (latent_dim, 1)
        Returns sqrt(gt_^T G gt_) + reg * ||G|| with G = metric(curve_t).
        """
        G = self.metric(curve_t)  # evaluated once: the metric may be costly
        return torch.sqrt(gt_.T @ G @ gt_) + self.reg * G.norm()

    def _time_tensor(self, value, requires_grad=False):
        """Build the 1-element time tensor fed to the network."""
        return torch.tensor(
            [float(value)], device=self.device, requires_grad=requires_grad
        )

    def _time_derivative(self, curve_t, t):
        """Return d curve_t / dt as a (latent_dim, 1) column.

        ``create_graph=True`` keeps the derivative differentiable w.r.t. the
        network parameters, so the velocity term contributes to the gradient
        of the length (without it the velocity is treated as a constant).
        """
        return torch.cat(
            [
                grad(curve_t[0, i], t, create_graph=True)[0]
                for i in range(self.latent_dim)
            ]
        ).reshape(self.latent_dim, 1)

    def _boundary_affine_map(self):
        """Return (a, b) such that t -> a * self(t) - b meets the constraints."""
        if self.compute_with_velo:
            # Enforce a * self'(0) = starting_velo and a * self(0) - b = starting_pos.
            t0 = self._time_tensor(0.0, requires_grad=True)
            curve_t0 = self(t0).reshape(1, self.latent_dim)
            gt_0 = self._time_derivative(curve_t0, t0).reshape(1, self.latent_dim)
            a = self.starting_velo / gt_0
            b = a * curve_t0 - self.starting_pos
            return a, b

        # Enforce a * self(0) - b = starting_pos and a * self(1) - b = ending_pos.
        # Also used when no ending point was given (ending_pos then defaults to ones).
        curve_t0 = self(self._time_tensor(0.0))
        curve_t1 = self(self._time_tensor(1.0))
        a = (self.starting_pos - self.ending_pos) / (curve_t0 - curve_t1)
        b = (self.starting_pos * curve_t1 - self.ending_pos * curve_t0) / (
            curve_t0 - curve_t1
        )
        return a, b

    def fit(self, n_epochs=10, lr=1e-2):
        """Optimize the curve to minimize its discretized length.

        Inputs:
        -------
        n_epochs (int): Maximum number of optimization steps
        lr (float): Adam learning rate

        Returns:
        --------
        (best_curve_model, best_a, best_b, length): a copy of the model at the
        epoch with the lowest loss, the (detached) affine coefficients giving
        the curve as ``best_a * best_curve_model(t) - best_b``, and the
        corresponding detached loss. ``best_a``, ``best_b`` and ``length`` are
        None if no epoch produced a loss below the initial threshold (e.g. NaN).
        """
        optimizer = optim.Adam(self.parameters(), lr=lr)
        self.train()

        # Snapshot only the parameters on improvement instead of deep-copying
        # the whole module (including `metric`) every time.
        best_state = {k: v.clone() for k, v in self.state_dict().items()}
        best_loss = 1e20
        best_a = best_b = length = None
        es_epoch = 0

        for epoch in range(n_epochs):
            optimizer.zero_grad()
            loss = torch.sqrt(
                self.starting_velo
                @ self.metric(self.starting_pos)
                @ self.starting_velo.T
            )

            a, b = self._boundary_affine_map()

            for step in range(self.gran + 1):
                t = self._time_tensor(step / self.gran, requires_grad=True)
                curve_t = a * self(t).reshape(1, self.latent_dim) - b
                gt_ = self._time_derivative(curve_t, t)
                loss = loss + self.loss_function(curve_t, gt_)

            loss = loss / self.gran

            # Detach before storing so the epoch's autograd graph can be freed.
            loss_value = loss.detach()
            if loss_value < best_loss:
                es_epoch = 0
                print("better", loss)
                best_state = {k: v.clone() for k, v in self.state_dict().items()}
                best_loss = loss_value
                best_a, best_b = a.detach(), b.detach()
                length = loss_value

            elif self.early_stopping > 0:
                es_epoch += 1

                if es_epoch >= self.early_stopping:
                    print(
                        f"Early Stopping at epoch {epoch} ! Loss did not improve in {self.early_stopping} epochs"
                    )
                    break

            if epoch % 50 == 0:
                print("-----")
                print(loss)

            loss.backward()
            optimizer.step()

        best_curve_model = deepcopy(self)
        best_curve_model.load_state_dict(best_state)
        best_curve_model.length = length
        return best_curve_model, best_a, best_b, length


In [None]:
# Random start and end points in the 10-dimensional latent space
# (replace them with your own points of interest)
start, end = torch.randn(1, 10), torch.randn(1, 10)

In [None]:
# Geodesic_autodiff is the class defined earlier in this notebook. Do not also
# run `from riemann_tools import Geodesic_autodiff` here: that import would
# replace the notebook's definition.
best_geo = Geodesic_autodiff(
    latent_dim=start.shape[-1],  # stay consistent with the start/end points
    starting_pos=start,
    ending_pos=end,
    metric=model_rec.G,
    reg=0,
    granularity=100,
    seed=8,
)

In [None]:
# Fit the model. The returned network does not output curve points directly:
# the geodesic at time t in [0, 1] is best_a * best_geo(t) - best_b.
best_geo, best_a, best_b, _ = best_geo.fit(1000, lr=1e-3)

In [None]:
# Build the curve
T = torch.linspace(0, 1, 100).reshape(100, 1)
curve = best_a * best_geo(T) - best_b

In [None]:
# You can now decode it. The curve is moved to the decoder's device instead of
# hard-coding `.cuda()`, so this also works when the model lives on the CPU.
# NOTE(review): assumes model_rec.decoder is a torch nn.Module with parameters.
decoder_device = next(model_rec.decoder.parameters()).device
model_rec.decoder(curve.to(decoder_device))