In [None]:
from src.icp import icp

import torch
import open3d as o3d
import numpy as np
import copy
from chamferdist import ChamferDistance
from pytorch3d.transforms import euler_angles_to_matrix
from tqdm import tqdm_notebook as tqdm


In [None]:
def draw_registration_result(source, target, transformation):
    """Show `source` (yellow) moved by `transformation` over `target` (blue).

    Both clouds are deep-copied first, so the caller's geometries are neither
    recoloured nor transformed. Opens an Open3D viewer with a fixed camera pose.

    Args:
        source: point cloud to move.
        target: reference point cloud, drawn untransformed.
        transformation: 4x4 homogeneous transform applied to the copy of `source`.
    """
    moved, fixed = copy.deepcopy(source), copy.deepcopy(target)
    moved.paint_uniform_color([1, 0.706, 0])
    fixed.paint_uniform_color([0, 0.651, 0.929])
    moved.transform(transformation)
    # Camera pose tuned for the Open3D ICP demo clouds.
    camera = {
        "zoom": 0.4459,
        "front": [0.9288, -0.2951, -0.2242],
        "lookat": [1.6784, 2.0612, 1.4451],
        "up": [-0.3402, -0.9189, -0.1996],
    }
    o3d.visualization.draw_geometries([moved, fixed], **camera)

In [None]:
# load and subsample
# Open3D's bundled ICP demo pair: paths[0] is used as the source scan, paths[1] as the target.
demo_icp_pcds = o3d.data.DemoICPPointClouds().paths

# Known rigid transform; it is baked into the source cloud when it is loaded below.
trans_init = np.array([
    [0.862, 0.011, -0.507, 0.5],
    [-0.139, 0.967, -0.215, 0.7],
    [0.487, 0.255, 0.835, -1.4],
    [0.0, 0.0, 0.0, 1.0],
])

target_cloud = o3d.io.read_point_cloud(demo_icp_pcds[1])
source_cloud = o3d.io.read_point_cloud(demo_icp_pcds[0]).transform(trans_init)

# Voxel edge length used for down-sampling both clouds.
VOXEL_SIZE = 0.02
target_cloud_down, source_cloud_down = (
    cloud.voxel_down_sample(voxel_size=VOXEL_SIZE) for cloud in (target_cloud, source_cloud)
)

# Passed to src.icp as its third argument; presumably the max correspondence distance — confirm.
threshold = 0.2

In [None]:
# icp: refine the alignment of the down-sampled clouds, starting from the identity.
t_init = np.identity(4)
# Pass the initial guess `t_init`. The original passed `t`, which is only assigned
# by this very line (NameError on a fresh kernel, stale value on re-run).
t, _ = icp(source_cloud_down, target_cloud_down, threshold, t_init)
print(t)

In [None]:
# Down-sampled clouds aligned by the ICP result `t`.
draw_registration_result(source=source_cloud_down, target=target_cloud_down, transformation=t)

In [None]:
# Starting alignment of the full-resolution clouds. `trans_init` is already applied to
# source_cloud when it is loaded, so the identity shows the starting state; passing
# trans_init here would apply that transform a second time.
draw_registration_result(source_cloud, target_cloud, np.identity(4))

In [None]:
# convert to tensors of shape (1, N, 3) — a leading batch dimension of 1 — keeping numpy's float64.
# `cuda` is the compute device. It falls back to the CPU when no GPU is visible so the
# notebook still runs (chamferdist must support the chosen device).
cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build each tensor from a single ndarray: torch.tensor on a list of ndarrays is very slow.
source = torch.tensor(np.asarray(source_cloud_down.points), device=cuda).unsqueeze(0)
target = torch.tensor(np.asarray(target_cloud_down.points), device=cuda).unsqueeze(0)


In [None]:
# Expect (1, num_points, 3): batch of one down-sampled source cloud.
print(source.size())

In [None]:
def _params_to_transform(params, dtype):
    """Build batched 4x4 homogeneous transforms from 6-DoF parameters.

    `params` is a (B, 6) tensor laid out as
    [x_rotation, y_rotation, z_rotation, x_translation, y_translation, z_translation],
    with the rotations given as Euler angles (radians, "XYZ" convention) for
    pytorch3d's euler_angles_to_matrix.

    Returns a (B, 4, 4) tensor [[R^T, t], [0, 0, 0, 1]] cast to `dtype`, where R is
    the Euler rotation matrix. R^T is still a rotation; using it keeps the layout this
    notebook has always produced.
    """
    rotation = euler_angles_to_matrix(params[:, :3], "XYZ").transpose(1, 2)  # (B, 3, 3)
    translation = params[:, 3:].unsqueeze(-1)                                   # (B, 3, 1)
    top = torch.cat((rotation, translation), dim=-1)                           # (B, 3, 4)
    bottom = torch.zeros((params.shape[0], 1, 4), dtype=params.dtype, device=params.device)
    bottom[:, :, 3] = 1.0                                                      # row [0, 0, 0, 1]
    return torch.cat((top, bottom), dim=-2).to(dtype)


def _apply_transform(transform, source_h):
    """Apply (B, 4, 4) transforms to (B, 4, N) homogeneous points; return (B, N, 3) points."""
    return torch.bmm(transform, source_h).transpose(1, 2)[:, :, :3]


#TODO: modify chamfer distance to ignore points beyond a threshold
def chamfer_registration(source, target, iterations, step_size, cuda=None):
    """Rigidly register `source` onto `target` by minimising a one-directional Chamfer loss.

    One 6-DoF rigid transform per batch element is optimised with Adam. The parameters
    start at zero, which gives the identity transform.

    Args:
        source: (B, N, 3) tensor of source points.
        target: (B, M, 3) tensor of target points, with the same batch size B.
        iterations: number of Adam steps.
        step_size: Adam learning rate.
        cuda: device for the optimised parameters; defaults to `source.device`.

    Returns:
        (source_t, transform): `source` moved by the final transforms, shape (B, N, 3),
        and the final transforms, shape (B, 4, 4), both in `source.dtype`. Both reflect
        the parameters *after* the last optimiser step and are detached from autograd.
    """
    device = source.device if cuda is None else cuda
    params = torch.zeros((source.shape[0], 6), requires_grad=True, device=device)
    optimiser = torch.optim.Adam([params], lr=step_size)
    chamferDist = ChamferDistance()

    # Homogeneous source points (B, 4, N) never change, so build them once. ones_like
    # matches source's dtype/device instead of relying on float32/float64 promotion.
    source_h = torch.cat((source, torch.ones_like(source[..., :1])), dim=-1).transpose(1, 2)

    progress = tqdm(range(iterations))
    for _ in progress:
        optimiser.zero_grad()
        transform = _params_to_transform(params, source.dtype)
        source_t = _apply_transform(transform, source_h)
        # One-directional: `target` is the first argument, so (per chamferdist's
        # argument order) each target point is matched to its nearest moved source point.
        chamfer_loss = chamferDist(target, source_t, bidirectional=False)
        chamfer_loss.backward()
        optimiser.step()
        progress.set_postfix(loss=chamfer_loss.item())

    # Re-evaluate with the final parameters: in-loop values lag one optimiser step behind,
    # and with iterations == 0 this still returns the identity result.
    with torch.no_grad():
        transform = _params_to_transform(params, source.dtype)
        source_t = _apply_transform(transform, source_h)
    return source_t, transform
    
        

In [None]:
# 10 Adam steps at learning rate 0.01, on the device held in `cuda`.
transformed_source, trans = chamfer_registration(source, target, iterations=10, step_size=0.01, cuda=cuda)

In [None]:
# Keep batch element 0 as a 4x4 numpy array for Open3D. The guard makes the cell safe
# to re-run: after the first run `trans` is already a numpy array with no `.detach`.
if torch.is_tensor(trans):
    trans = trans[0].detach().cpu().numpy()

In [None]:
# Final transform found by the Chamfer registration.
print(f"{trans}")

In [None]:
# Down-sampled clouds aligned by the Chamfer-registration transform.
draw_registration_result(source=source_cloud_down, target=target_cloud_down, transformation=trans)

In [None]:
# Reference view: down-sampled clouds under the identity transform `t_init`.
draw_registration_result(source=source_cloud_down, target=target_cloud_down, transformation=t_init)