## Problem Set 3, Programming Exercise

In this exercise, we study a Riemannian geometric structure on the space of point cloud data, based on the theory of statistical manifolds and information geometry.   
As an application, we perform geodesic interpolation of point clouds using the latent representation of a pretrained point cloud autoencoder.    
For details, refer to \<A Statistical Manifold Framework for Point Cloud Data\> (Yonghyeon Lee, Seungyeon Kim, Jinwon Choi, and Frank C. Park, ICML 2022).   

#### Statistical Manifold Framework for Point Cloud Data
A point cloud $X$ (with a fixed number of points, $n$) in $\mathbb{R}^D$ can be represented as $X=\{ x_1,\ldots,x_n | x_i\in\mathbb{R}^D, x_i\neq x_j \text{ if } i\neq j \}$.   
Given a point cloud $X$, a probability density function $p(x;X)$ parameterized by the point cloud can be defined as
$$ p(x;X) = \frac{1}{n\sqrt{|\Sigma|}} \sum^n_{i=1}K(\Sigma^{-1/2}(x-x_i)) \quad\text{where}\quad \Sigma=\sigma^2 I, \; K(u)=\frac{1}{\sqrt{(2\pi)^D}}\exp(-\frac{u^Tu}{2}), $$
which is just a kernel density estimate using a Gaussian kernel $K$ and an isotropic bandwidth matrix $\Sigma$.   
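
As a rough illustration of the density above, the following NumPy sketch evaluates $p(x;X)$ at a batch of query points. The function name `kde_density` and the toy cloud `X_toy` are made up for this example and are not part of the exercise code.

```python
import numpy as np

def kde_density(x, X, sigma):
    """Evaluate p(x; X) at queries x of shape (m, D) for a point cloud X of shape (n, D)."""
    x = np.atleast_2d(x)
    X = np.atleast_2d(X)
    n, D = X.shape
    # Squared distances between every query and every cloud point: shape (m, n)
    sq_dist = ((x[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    # Gaussian kernel K(u) evaluated at u = (x - x_i) / sigma
    kernel = np.exp(-0.5 * sq_dist / sigma**2) / np.sqrt((2.0 * np.pi) ** D)
    # Prefactor 1 / (n * sqrt(|Sigma|)), where |Sigma| = sigma^(2D)
    return kernel.sum(axis=1) / (n * sigma**D)

# Toy 2D point cloud (made-up values)
X_toy = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
density_at_origin = kde_density(np.zeros((1, 2)), X_toy, sigma=0.5)
print(density_at_origin)
```

Both `x` and `X` are expected as 2D arrays with one point per row; a smaller `sigma` concentrates the density more tightly around the individual points.
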
Let us denote the space of all point cloud data by ${\cal X}$ and define ${\cal S}:=\{ p(x;X) | X\in{\cal X} \}$.   
It can be shown that the correspondence $X \mapsto p(x;X)$ between the space of point clouds and the statistical manifold ${\cal S}$ is one-to-one.   
Since the statistical manifold ${\cal S}$ has its natural Riemannian metric, namely the Fisher information metric, a Riemannian geometric structure of the space of point clouds is naturally defined as well.
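
For a parametric family $p(x;\theta)$, the Fisher information metric is $g_{ij}(\theta) = \mathbb{E}_{x\sim p(\cdot;\theta)}\left[\partial_i \log p(x;\theta)\,\partial_j \log p(x;\theta)\right]$; here the parameters are the point coordinates. The sketch below estimates this matrix for a one-dimensional point cloud with a Riemann sum over a grid. The name `kde_fisher_1d` and its arguments are hypothetical, and this is not how the pretrained model's `get_Fisher_proj_Riemannian_metric` is implemented.

```python
import numpy as np

def kde_fisher_1d(points, sigma, num_grid=20001, pad=8.0):
    """Fisher information matrix (n, n) of a 1D Gaussian KDE w.r.t. its point locations."""
    points = np.asarray(points, dtype=float)
    n = points.size
    # Integration grid covering all points plus `pad` bandwidths on each side
    grid = np.linspace(points.min() - pad * sigma, points.max() + pad * sigma, num_grid)
    dx = grid[1] - grid[0]
    diff = grid[:, None] - points[None, :]                      # (m, n)
    # Gaussian component N(x; x_i, sigma^2) at every grid location
    comp = np.exp(-0.5 * (diff / sigma) ** 2) / (np.sqrt(2.0 * np.pi) * sigma)
    # Mixture density, clipped away from zero to keep the division safe
    p = np.maximum(comp.mean(axis=1), np.finfo(float).tiny)
    # d log p / d x_i = (1/n) * N(x; x_i, sigma^2) * (x - x_i) / sigma^2 / p(x)
    score = comp * diff / (n * sigma**2 * p[:, None])
    # E_p[score_i * score_j], approximated by a Riemann sum
    return (score * p[:, None]).T @ score * dx

H_single = kde_fisher_1d([0.3], sigma=0.5)
H_pair = kde_fisher_1d([0.0, 1.0], sigma=0.5)
print(H_single)
print(H_pair)
```

With a single point the density is a Gaussian with mean $x_1$, so the result should match the known Fisher information of a Gaussian location parameter, $1/\sigma^2$.
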

<p align="center">
  <img src="figures/pc_to_mfd.png" alt="Point Cloud to Manifold Correspondence" width="750"/>
</p>
<p align="center"><em>Figure 1: The statistical manifold obtained from the one-to-one mapping between point clouds and probability density functions.</em></p>


#### Python Environment
For this exercise, we need to set up the following environment:
- python 3.8
- numpy
- scipy
- matplotlib
- scikit-learn
- pandas
- h5py
- pyyaml
- omegaconf
- tqdm
- torch, torchvision
    - https://pytorch.org/get-started/locally/
- [torchcubicspline](https://github.com/patrick-kidger/torchcubicspline)
    - `pip install git+https://github.com/patrick-kidger/torchcubicspline.git`
- Open3D 0.13.0
    - `pip install open3d==0.13.0 --no-deps`
    - **Warning:** The latest version of Open3D has slightly different syntax, so we highly recommend using Open3D 0.13.0 as noted above.
    - Use the `--no-deps` option to avoid installation errors.
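
One way to check the environment after installation is to list the installed versions with the standard-library `importlib.metadata` module. This is only a convenience sketch: the helper `installed_versions` is made up, and the distribution names are assumed to match the pip package names listed above.

```python
from importlib import metadata

def installed_versions(packages):
    """Return {package: version string, or None if not installed} without importing anything."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

# Distribution names assumed to match the pip names in the list above
required = ["numpy", "scipy", "matplotlib", "scikit-learn", "pandas", "h5py",
            "pyyaml", "omegaconf", "tqdm", "torch", "torchvision",
            "torchcubicspline", "open3d"]
for name, version in installed_versions(required).items():
    print(f"{name:18s} {version or 'NOT INSTALLED'}")
```

For `open3d`, the reported version should read `0.13.0` if the pinned install above succeeded.
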

#### Dataset
- Create a `datasets/` directory.
- Download `interpolation_dataset.zip` from the [Google Drive link](https://drive.google.com/drive/folders/1NuGq2LtWG627r9BNPzb1EegUuIvPUzDr?usp=sharing), and unzip it under the `datasets/` directory.

#### Pretrained Model
- Create a `pretrained/` directory.
- Download `interpolation_config.zip` from the [Google Drive link](https://drive.google.com/drive/folders/1NuYIfyU6kVQ09qPR6rONWrernKMps_FX?usp=sharing), and unzip it under the `pretrained/` directory.


#### Import Python Packages

In [None]:
import numpy as np
import open3d as o3d
import os
import copy
import torch
from torchcubicspline import(natural_cubic_spline_coeffs, 
                             NaturalCubicSpline)
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from datetime import datetime

from loader import get_dataloader
from models import load_pretrained

from functions.util import label_to_color, figure_to_array
from functions.color_assignment import latent_to_color
from functions.util import gallery, render_pointcloud

#### Set Up Your Device

In [None]:
# If you have a GPU, comment out the second line and uncomment the first line
# If you don't have a GPU, comment out the first line and uncomment the second line

####################### Change Device Here #######################
device = 'cuda:0'
# device = 'cpu'
##################################################################

print('Device:', device)

In [3]:
# pretrained autoencoder
config = ('interpolation_config/vanilla', 'interpolation_config.yml', 'model_best.pkl', {})
root = 'pretrained/'

# parameters
num_interpolates_linear = 20
num_interpolates_geodesic = num_interpolates_linear - 1
k = 0.5   
epoch_curve = 5000                                          # 5000 recommended for good results. Set to smaller values (e.g., 1500) if the computation takes too long (e.g., when using cpu).
learning_rate = 1e-3
num_samples = 40
n_control_points = 10
mode = 'smoothed_nn'

# figure parameters
scale = 0.02
scale_ratio = 1.4
kwargs_linear = {'linestyle': 'dotted', 'linewidth': 1.5, 'label': 'linear', 'c': [253/255, 134/255, 18/255]}
kwargs_identity = {'linestyle': 'dashed', 'linewidth': 1.5, 'label': 'identity', 'c': [253/255, 134/255, 18/255]}
kwargs_geodesic = {'linestyle': 'solid', 'linewidth': 1.5, 'label': 'geodesic', 'c': [253/255, 134/255, 18/255]}
label_to_text = ['box', 'cylinder', 'cone', 'ellipsoid', 'interpolates']


In [None]:
# initialize
Z = []
y = []
P = []

# load configuration 
identifier, config_file, ckpt_file, kwargs = config

# load pretrained model
kwargs = {}
model, cfg = load_pretrained(identifier, config_file, ckpt_file, root=root, **kwargs)
model.to(device)

# load test data
print("Load Test Data and Encode!")
cfg_test = cfg['data']['test']
test_dl, mean_MED = get_dataloader(cfg_test)
sample = 0
for data in test_dl:
    P.append(data[0].to(device))
    Z.append(copy.copy(model.encode(data[0].to(device))))
    y.append(data[1]) 
    sample += 1 
P = torch.cat(P, dim=0)
Z = torch.cat(Z, dim=0)
y = torch.cat(y, dim=0)
color_3d = label_to_color(y.squeeze().detach().cpu().numpy())
print(f'Mean MED of the dataset is {mean_MED}.')

# Latent Space Encoding 
f = plt.figure()
plt.scatter(Z[:,0].detach().cpu(), Z[:,1].detach().cpu(), c=color_3d/255.0)
plt.axis('equal')
plt.close()
f_np = np.transpose(figure_to_array(f), (2, 0, 1))

# class-wise latent vectors
Z_cylinder = Z[y.view(-1)==1].detach()
Z_cone = Z[y.view(-1)==2].detach()
z_list = []

# interpolation candidates
data_idx_list = [
    [torch.argsort(Z_cylinder[:,0])[-15], torch.argsort(Z_cylinder[:,0])[1]],
    [torch.argsort(Z_cone[:,0])[-1], torch.argsort(Z_cone[:,0])[0]],
    [torch.argsort(Z_cylinder[:,0])[-185], torch.argsort(Z_cone[:,0])[-110]]
]
z_list.append([Z_cylinder[data_idx_list[0][0]], Z_cylinder[data_idx_list[0][1]]])
z_list.append([Z_cone[data_idx_list[1][0]], Z_cone[data_idx_list[1][1]]])
z_list.append([Z_cylinder[data_idx_list[2][0]], Z_cone[data_idx_list[2][1]]])

# selected example
z_ = z_list[0]
z1 = z_[0]
z2 = z_[1]        


#### (a) Linear Interpolation
Conduct linear interpolation in the latent space.
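
As a reminder of what linear interpolation means, the sketch below interpolates between two made-up 2D endpoints with NumPy broadcasting. The names `linear_interpolate`, `z_a`, and `z_b` are hypothetical; the exercise cell itself works with torch tensors and still needs your own implementation.

```python
import numpy as np

def linear_interpolate(z_start, z_end, num_points):
    """Return num_points evenly spaced points on the segment from z_start to z_end."""
    # Interpolation weights as a column vector so they broadcast over coordinates
    t = np.linspace(0.0, 1.0, num_points)[:, None]       # (num_points, 1)
    return (1.0 - t) * z_start[None, :] + t * z_end[None, :]

# Made-up endpoints, not taken from the dataset
z_a = np.array([0.0, 1.0])
z_b = np.array([2.0, -1.0])
path = linear_interpolate(z_a, z_b, 5)
print(path.shape)
```

The first and last rows coincide with the two endpoints, and the intermediate rows are equally spaced along the straight segment between them.
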

In [None]:
######################## Your Code Here ###########################
# TO DO: Linearly interpolate between z1 and z2.
# z_linear_interpolates should be a (num_interpolates_linear, 2)-shaped tensor

# z_linear_interpolates = ##### Your Code Here #####

###################################################################
x_linear_interpolates = model.decode(z_linear_interpolates)


In [None]:
# latent space visualization - linear interpolation
f = plt.figure()
plt.scatter(Z[:,0].detach().cpu(), Z[:,1].detach().cpu(), c=color_3d/255.0)
plt.scatter(z1[0].detach().cpu(), z1[1].detach().cpu(), c='r', marker='*', s=200)
plt.scatter(z2[0].detach().cpu(), z2[1].detach().cpu(), c='r', marker='*', s=200)
plt.plot(
    z_linear_interpolates[:, 0].detach().cpu(), 
    z_linear_interpolates[:, 1].detach().cpu(), 
    c='k',
    linewidth=3.0
)
plt.axis('equal')
plt.show()
plt.close()


#### (b) Cubic Spline Optimizer for Geodesic Computation
On the latent space of the pretrained point cloud autoencoder, a geodesic curve can be obtained by optimizing an appropriate objective function.   
To this end, we define the `cubic_spline_curve` class, whose `train_step` returns the objective function for the geodesic and whose `compute_length` returns the (Riemannian) length of the spline curve.   
For reference, `get_Identity_proj_Riemannian_metric` and `get_Fisher_proj_Riemannian_metric` are defined in `models/base_arch.py`.   
They compute the projections of the Euclidean metric and the Fisher information metric onto the latent space through the decoder Jacobian.   
Complete `train_step` and `compute_length`.   
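
For a curve $z(t)$, $t\in[0,1]$, with metric $G(z)$, the Riemannian length is $L=\int_0^1 \sqrt{\dot z^T G \dot z}\,dt$ and the energy is $E=\int_0^1 \dot z^T G \dot z\,dt$. The NumPy sketch below shows how both could be estimated from sampled velocities and metric matrices. The function names and toy inputs are made up, and treating the energy as the geodesic objective is one common choice rather than a statement of the intended solution.

```python
import numpy as np

def quadratic_form(G, v):
    """Compute v_n^T G_n v_n for each sample n; G has shape (N, d, d), v has shape (N, d)."""
    return np.einsum('ni,nij,nj->n', v, G, v)

def curve_energy(G, v):
    """Sample average of v^T G v, approximating the energy integral over t in [0, 1]."""
    return quadratic_form(G, v).mean()

def curve_length(G, v):
    """Sample average of sqrt(v^T G v), approximating the length integral over t in [0, 1]."""
    return np.sqrt(quadratic_form(G, v)).mean()

# Toy example: constant velocity (3, 4) under the identity metric
N = 50
G_toy = np.tile(np.eye(2), (N, 1, 1))
v_toy = np.tile(np.array([3.0, 4.0]), (N, 1))
energy_toy = curve_energy(G_toy, v_toy)
length_toy = curve_length(G_toy, v_toy)
print(energy_toy, length_toy)
```

On the unit interval $L^2 \le E$, with equality for constant-speed curves, which is why minimizing the energy is often preferred over minimizing the length directly.
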

In [7]:
class cubic_spline_curve(torch.nn.Module):
    def __init__(self, z_i, z_f, mean_MED, k, device, metric_type, channels=2, lengths=2):
        super(cubic_spline_curve, self).__init__()
        self.channels = channels
        self.z_i = z_i.unsqueeze(0)
        self.z_f = z_f.unsqueeze(0)
        self.mean_MED = mean_MED
        self.k = k
        self.device = device
        self.metric_type = metric_type
        self.z = Parameter(
                    torch.cat(
                        [self.z_i + (self.z_f-self.z_i) * t / (lengths + 1) + torch.randn_like(self.z_i)*0.0 for t in range(1, lengths+1)], dim=0)
        )
        self.t_linspace = torch.linspace(0, 1, lengths + 2).to(self.device)

    def append(self):
        return torch.cat([self.z_i, self.z, self.z_f], dim=0)
    
    def spline_gen(self):
        coeffs = natural_cubic_spline_coeffs(self.t_linspace, self.append())
        spline = NaturalCubicSpline(coeffs)
        return spline
    
    def forward(self, t):
        out = self.spline_gen().evaluate(t)
        return out
    
    def velocity(self, t):
        out = self.spline_gen().derivative(t)
        return out
    
    def train_step(self, model, num_samples):
        t_samples = torch.rand(num_samples).to(self.device)
        z_samples = self(t_samples)
        if self.metric_type == 'identity':
            G = model.get_Identity_proj_Riemannian_metric(z_samples, create_graph=True)
        elif self.metric_type == 'information':
            G = model.get_Fisher_proj_Riemannian_metric(
                    z_samples, create_graph=True, sigma=self.mean_MED * self.k)
        else:
            raise ValueError

        z_dot_samples = self.velocity(t_samples)
        ######################## Your Code Here ###########################
        # TO DO: Define the loss function for the geodesic computation, using the Riemannian metrics G obtained above.
        # geodesic_loss should be a tensor consisting of a single real number.
        
        # geodesic_loss = ##### Your Code Here #####
        
        ###################################################################
        return geodesic_loss

    def compute_length(self, model, num_discretizations=100):
        t_samples = torch.linspace(0, 1, num_discretizations).to(self.device)
        z_samples = self(t_samples)
        if self.metric_type == 'identity':
            G = model.get_Identity_proj_Riemannian_metric(z_samples, create_graph=False)
        elif self.metric_type == 'information':
            G = model.get_Fisher_proj_Riemannian_metric(
                    z_samples, create_graph=False, sigma=self.mean_MED * self.k)
        else:
            # mirror train_step: fail early on an unknown metric type
            raise ValueError
        ######################## Your Code Here ###########################
        # TO DO: Compute the (Riemannian) length of the curve, using the Riemannian metrics G obtained above.
        # length should be a tensor consisting of a single real number.
        
        # length = ##### Your Code Here #####
        
        ###################################################################
        return length

#### Geodesic Interpolation with Euclidean Metric
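
To make "projecting a metric onto the latent space through the decoder Jacobian" concrete, the sketch below computes $G(z) = J(z)^T H J(z)$ for a toy decoder, where $J$ is the Jacobian of the decoder and $H$ is a metric on the output space (the identity for the Euclidean case). The decoder, the finite-difference Jacobian, and all names here are illustrative assumptions; the actual implementation lives in `models/base_arch.py`.

```python
import numpy as np

def toy_decoder(z):
    """Hypothetical decoder mapping a 2D latent code to a point on a paraboloid in R^3."""
    return np.array([z[0], z[1], z[0] ** 2 + z[1] ** 2])

def numerical_jacobian(f, z, eps=1e-5):
    """Central finite-difference Jacobian of f at z, shape (output_dim, latent_dim)."""
    z = np.asarray(z, dtype=float)
    columns = []
    for i in range(z.size):
        step = np.zeros_like(z)
        step[i] = eps
        columns.append((f(z + step) - f(z - step)) / (2.0 * eps))
    return np.stack(columns, axis=1)

def pullback_metric(f, z, H=None):
    """Return J^T H J at z; H defaults to the identity (Euclidean metric on the output)."""
    J = numerical_jacobian(f, z)
    if H is None:
        H = np.eye(J.shape[0])
    return J.T @ H @ J

G = pullback_metric(toy_decoder, [0.5, -1.0])
print(G)
```

For this toy decoder the result can be checked by hand: $J^T J = I + 4zz^T$, so the metric grows away from the origin, where the paraboloid is steeper.
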

In [None]:
# define curve and optimizer
model_curve_identity = cubic_spline_curve(z1, z2, mean_MED, k, device, 'identity', lengths = n_control_points).to(device)
optimizer = torch.optim.Adam(model_curve_identity.parameters(), lr=learning_rate)
for epoch in range(epoch_curve):
    optimizer.zero_grad()
    loss = model_curve_identity.train_step(model, num_samples)
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        length = model_curve_identity.compute_length(model)
        print(f'(identity_geodesic) Epoch {epoch}: loss = {loss.item()}: length = {length}')
t_samples = torch.linspace(0, 1, steps=num_interpolates_linear).to(device)  # same count as the linear interpolates
z_identity_interpolates = model_curve_identity(t_samples).detach().cpu()
x_identity_interpolates = model.decode(z_identity_interpolates.to(torch.float32).to(device))

In [None]:
# latent space visualization - geodesic interpolation with Euclidean metric 
f = plt.figure()
plt.scatter(Z[:,0].detach().cpu(), Z[:,1].detach().cpu(), c=color_3d/255.0)
plt.scatter(z1[0].detach().cpu(), z1[1].detach().cpu(), c='r', marker='*', s=200)
plt.scatter(z2[0].detach().cpu(), z2[1].detach().cpu(), c='r', marker='*', s=200)
plt.plot(
    z_identity_interpolates[:, 0].detach().cpu(), 
    z_identity_interpolates[:, 1].detach().cpu(), 
    c='k',
    linewidth=3.0
)
plt.axis('equal')
plt.show()
plt.close()


#### Geodesic Interpolation with Fisher Information Metric

In [None]:
# define curve and optimizer
model_curve = cubic_spline_curve(z1, z2, mean_MED, k, device, 'information', lengths = n_control_points).to(device)
optimizer = torch.optim.Adam(model_curve.parameters(), lr=learning_rate)

# This part might take a while when using CPU (10~20 mins)
for epoch in range(epoch_curve):
    optimizer.zero_grad()
    loss = model_curve.train_step(model, num_samples)
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        length = model_curve.compute_length(model)
        print(f'(information_geodesic) Epoch {epoch}: loss = {loss.item()}: length = {length}')
t_samples = torch.linspace(0, 1, steps=num_interpolates_linear).to(device)  # same count as the linear interpolates
z_information_interpolates = model_curve(t_samples).detach().cpu()
x_information_interpolates = model.decode(z_information_interpolates.to(torch.float32).to(device))

In [None]:
# latent space visualization - geodesic interpolation with Fisher information metric
f = plt.figure()
plt.scatter(Z[:,0].detach().cpu(), Z[:,1].detach().cpu(), c=color_3d/255.0)
plt.scatter(z1[0].detach().cpu(), z1[1].detach().cpu(), c='r', marker='*', s=200)
plt.scatter(z2[0].detach().cpu(), z2[1].detach().cpu(), c='r', marker='*', s=200)
plt.plot(
    z_information_interpolates[:, 0].detach().cpu(), 
    z_information_interpolates[:, 1].detach().cpu(), 
    c='k',
    linewidth=3.0
)
plt.axis('equal')
plt.show()
plt.close()

#### (c) Can you interpret the latent space geodesics?
(answer here)

#### Visualizing the Results as Point Clouds
Using the decoder of the point cloud autoencoder, we will visualize the obtained interpolation results.

In [None]:
# coloring
color_linear_interp = latent_to_color(model, P, y, z_linear_interpolates.to(torch.float32).to(device), mode=mode)
color_identity_interp = latent_to_color(model, P, y, z_identity_interpolates.to(torch.float32).to(device), mode=mode)
color_information_interp = latent_to_color(model, P, y, z_information_interpolates.to(torch.float32).to(device), mode=mode)

# initialize
data_vis = dict()

# save point clouds               
data_vis['linear_interpolation'] = dict()
data_vis['identity_interpolation'] = dict()
data_vis['information_interpolation'] = dict()
data_vis['linear_interpolation']['pc'] = []
data_vis['linear_interpolation']['color'] = []
data_vis['identity_interpolation']['pc'] = []
data_vis['identity_interpolation']['color'] = []
data_vis['information_interpolation']['pc'] = []
data_vis['information_interpolation']['color'] = []
for pidx in range(num_interpolates_linear):
    data_vis['linear_interpolation']['pc'].append(np.asarray(x_linear_interpolates[pidx,:,:].detach().cpu()))
    data_vis['identity_interpolation']['pc'].append(np.asarray(x_identity_interpolates[pidx,:,:].detach().cpu()))
    data_vis['information_interpolation']['pc'].append(np.asarray(x_information_interpolates[pidx,:,:].detach().cpu()))
    data_vis['linear_interpolation']['color'].append(np.repeat(color_linear_interp[pidx:pidx+1,:].transpose(), x_linear_interpolates.shape[2], axis=1))
    data_vis['identity_interpolation']['color'].append(np.repeat(color_identity_interp[pidx:pidx+1,:].transpose(), x_identity_interpolates.shape[2], axis=1))
    data_vis['information_interpolation']['color'].append(np.repeat(color_information_interp[pidx:pidx+1,:].transpose(), x_information_interpolates.shape[2], axis=1))

In [13]:
# folder for saving images
folder_temp = 'interpolation_results/temp'
if not os.path.exists(folder_temp):
    os.makedirs(folder_temp)   

# ball configuration
radius = 0.03
resolution = 20

# image resolution
img_width = 640
img_height = 900

# make point cloud mesh
meshes_data = []
meshes_for_finding_pose = []
exp_name_list = []

for l, idx_data_ in enumerate(data_vis.keys()):
    exp_name_list.append(idx_data_)
    data_ = data_vis[idx_data_]
    points_list = data_['pc']
    colors_list = data_['color']

    for j in range(len(points_list)):
        points = points_list[j].transpose()
        colors = colors_list[j].transpose()

        for i in range(len(points)):
            mesh = o3d.geometry.TriangleMesh.create_sphere(radius = radius, resolution = resolution).translate((points[i,0], points[i,1], points[i,2]))
            mesh.paint_uniform_color([colors[i,0]/255, colors[i,1]/255, colors[i,2]/255])
            if i == 0:
                mesh_total = mesh
            else:
                mesh_total += mesh
        
        mesh_total.compute_vertex_normals()
        meshes_data.append(mesh_total)

        if l == 0 and (j == 0 or j == len(points_list) - 1):
            for i in range(len(points)):
                mesh = o3d.geometry.TriangleMesh.create_sphere(radius = radius, resolution = resolution).translate((points[i,0], points[i,1], points[i,2]))
                mesh.paint_uniform_color([colors[i,0]/255, colors[i,1]/255, colors[i,2]/255])
                if i == 0:
                    mesh_total = mesh
                else:
                    mesh_total += mesh
        
            mesh_total.compute_vertex_normals()
            meshes_for_finding_pose.append(mesh_total)

# save images
for j in range(len(meshes_data)):
    path_image = os.path.join(folder_temp, f'temp_{j}.png')
    render_pointcloud(meshes_data[j], 
                    visualize=False, 
                    camera_config=None,
                    return_camera_config=False,
                    save_path=path_image,
                    image_size = [img_width, img_height])

# split
splitedSize = num_interpolates_linear  # images per row: one row per interpolation method
meshes_data_splited = [meshes_data[x:x+splitedSize] for x in range(0, len(meshes_data), splitedSize)]

# append images
images = []
counter = 0
for split_idx, split_ in enumerate(meshes_data_splited):  # avoid overwriting the global parameter k

    # figure
    fig = plt.figure()
    images_plt = []
        
    for j in range(len(split_)):
        im = mpimg.imread(os.path.join(folder_temp, f'temp_{counter}.png'))
        images.append(im)
        images_plt.append([plt.imshow(im, animated=True)])
        plt.axis('off')
        counter += 1
    
    plt.close(fig)

# visualize
images_numpy = np.array(images)
image_grid = gallery(images_numpy, ncols=splitedSize)
plt.figure(figsize=(15, 15))
plt.imshow(image_grid)
plt.show()

# 1st row: linear interpolation
# 2nd row: geodesic interpolation using Euclidean metric
# 3rd row: geodesic interpolation using Fisher information metric

#### (d) Can you explain how the shape of the point cloud changes in each interpolation? What is the difference between geodesic interpolation with the Euclidean metric and with the Fisher information metric?
(answer here)