In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np 
import matplotlib.pyplot as plt

import open3d as o3d

import pytorch3d as p3d
from pytorch3d.transforms.so3 import *

# Report the library versions this notebook was executed with.
for label, module in (('pytorch  :', torch), ('open3d   :', o3d), ('pytorch3d:', p3d)):
    print(label, module.__version__)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
pytorch  : 1.9.0
open3d   : 0.13.0
pytorch3d: 0.3.0


### Helper function 

In [2]:
def draw_two_pointcloud(source, target,
                        source_color=(1, 0.706, 0),
                        target_color=(0, 0.651, 0.929)):
    """Display two point clouds in one viewer, each painted a single color.

    The clouds are recolored IN PLACE via ``paint_uniform_color``, so any
    colors they carried before the call are overwritten.

    Args:
        source: point cloud painted with ``source_color``.
        target: point cloud painted with ``target_color``.
        source_color: RGB triple for ``source`` (defaults to the original
            hard-coded value).
        target_color: RGB triple for ``target`` (defaults to the original
            hard-coded value).
    """
    # Pass plain lists, exactly as the hard-coded version did.
    source.paint_uniform_color(list(source_color))
    target.paint_uniform_color(list(target_color))
    o3d.visualization.draw_geometries([source, target])

### Load original data

In [3]:
# load the original shape (ref: http://graphics.stanford.edu/data/3Dscanrep/)
bunny_original = o3d.io.read_point_cloud('bun_zipper_res2.ply')

# Fail fast if nothing was loaded: every later cell assumes a non-empty cloud,
# and an empty one would only surface as confusing shape errors downstream.
if not bunny_original.has_points():
    raise RuntimeError("Could not load any points from 'bun_zipper_res2.ply'")

# (n, 3) tensor of the point coordinates; torch.Tensor(...) yields float32.
bunny_original_xyz = torch.Tensor(np.asarray(bunny_original.points))
o3d.visualization.draw_geometries([bunny_original])

### Random transformation to the data 

In [17]:
# Seed the torch global RNG so the random transform and the noise are the same
# on every Restart & Run All. torch.rand draws from this generator;
# p3d.transforms.random_rotation is expected to as well (TODO confirm).
SEED = 0
torch.manual_seed(SEED)

TRANSLATION_SCALE = 0.2  # each translation component is drawn from [0, 0.2)
NOISE_SCALE = 0.01       # each coordinate's noise is drawn from [0, 0.01)

# build a random rotation + translation as a 4x4 homogeneous matrix
random_rot = p3d.transforms.random_rotation().squeeze().numpy()
random_trans = TRANSLATION_SCALE * torch.rand(3).numpy()
random_tf = np.eye(4)
random_tf[:3, :3] = random_rot
random_tf[:3, 3] = random_trans
print('Random transform:\n', random_tf)

# Load a second copy from disk so bunny_original itself is left untouched.
bunny_transformed = o3d.io.read_point_cloud('bun_zipper_res2.ply')
bunny_transformed.transform(random_tf) # in-place transformation

# add noise
# NOTE: torch.rand is uniform on [0, 1), so this noise is non-negative
# (mean NOISE_SCALE / 2 per axis) rather than zero-mean.
bunny_transformed_xyz = torch.Tensor(np.asarray(bunny_transformed.points))
bunny_transformed_xyz = bunny_transformed_xyz + NOISE_SCALE*torch.rand(bunny_transformed_xyz.shape)

# draw
bunny_transformed.points = o3d.utility.Vector3dVector(bunny_transformed_xyz.detach().numpy())
draw_two_pointcloud(bunny_original, bunny_transformed)


Random transform:
 [[ 0.15794468  0.98420227 -0.07999662  0.1610954 ]
 [ 0.15273811  0.0556879   0.98669648  0.13217635]
 [ 0.97556376 -0.16806208 -0.14152956  0.11073001]
 [ 0.          0.          0.          1.        ]]


### Define the model 
- to learn the relative transformation that maps the transformed cloud back onto the original cloud

In [18]:
class RealiveTransform(nn.Module):
    """Learnable map ``y = R x + t`` applied to a set of 3-D points.

    ``R`` (3x3) and ``t`` (3x1) are not stored directly. Each is the product
    of three parameter matrices with inner width ``hidden_dim``:

        R = rot_part1 (3x3) @ rot_part2 (3xH) @ rot_part3 (Hx3)
        t = trans_part1 (3x3) @ trans_part2 (3xH) @ trans_part3 (Hx1)

    Nothing constrains ``R`` to be orthonormal, so it is a general 3x3 matrix
    rather than a guaranteed rotation.

    Args:
        hidden_dim: inner width ``H`` of the factorization. The default of
            1024 reproduces the original hard-coded size.

    Raises:
        ValueError: if ``hidden_dim`` is smaller than 1.
    """

    def __init__(self, hidden_dim=1024):
        super().__init__()
        if hidden_dim < 1:
            raise ValueError(f"hidden_dim must be >= 1, got {hidden_dim}")

        # Factors are created in the same order as before, so a seeded run
        # with the default hidden_dim gives the same initial values.
        self.rot_part1 = nn.Parameter(torch.rand(3, 3))
        self.rot_part2 = nn.Parameter(torch.rand(3, hidden_dim))
        self.rot_part3 = nn.Parameter(torch.rand(hidden_dim, 3))

        self.trans_part1 = nn.Parameter(torch.rand(3, 3))
        self.trans_part2 = nn.Parameter(torch.rand(3, hidden_dim))
        self.trans_part3 = nn.Parameter(torch.rand(hidden_dim, 1))

    def forward(self, x):
        """Apply the current transform to ``x`` (n x 3) and return y (n x 3)."""
        tf = self.get_tf()
        # Points are rows, so transpose to 3 x n for R @ x + t (t broadcasts
        # over the n columns), then transpose back to n x 3.
        y = (torch.mm(tf['R'], x.t()) + tf['t']).t()
        return y

    def get_tf(self):
        """Compose and return the current transform.

        Returns:
            dict with ``'R'`` (3 x 3) and ``'t'`` (3 x 1). Both are still
            attached to the autograd graph.
        """
        rotmat = torch.mm(torch.mm(self.rot_part1, self.rot_part2), self.rot_part3)
        transvec = torch.mm(torch.mm(self.trans_part1, self.trans_part2), self.trans_part3)
        return {'R': rotmat, 't': transvec}

### Initial verification of the estimated transformation before the optimization 

In [19]:
# Fresh, untrained model: its parameters are still the random initial values.
realive_transform = RealiveTransform()

# Show the transform the untrained model currently represents.
estimated_relative_tf = realive_transform.get_tf()
print('Estimated R:\n', estimated_relative_tf['R'])
print('Estimated t:\n', estimated_relative_tf['t'].t())

# Apply it to the transformed cloud and view the result next to the original.
bunny_registered_xyz = realive_transform(bunny_transformed_xyz)
bunny_registered = o3d.geometry.PointCloud(
    o3d.utility.Vector3dVector(bunny_registered_xyz.detach().numpy())
)
draw_two_pointcloud(bunny_original, bunny_registered)

Estimated R:
 tensor([[655.1071, 663.0014, 652.4586],
        [356.2054, 359.5293, 356.0751],
        [604.0523, 611.2412, 602.0819]], grad_fn=<MmBackward>)
Estimated t:
 tensor([[281.7417, 476.4532, 647.5089]], grad_fn=<TBackward>)


### Fit the data (i.e., learning of the relative transformation)

In [20]:
# optimizer
optimizer = optim.Adam(realive_transform.parameters(), lr=0.01)

# Loss choice does not change between iterations, so pick it once.
use_l1loss = True

# run optimization
for i in range(50000):
    optimizer.zero_grad()
    bunny_registered_xyz_est = realive_transform(bunny_transformed_xyz)

    # assumption: point correspondences are known, i.e. row k of the estimate
    # and row k of the original are the same physical point, so the loss is a
    # plain row-wise difference.
    residual = bunny_registered_xyz_est - bunny_original_xyz
    if use_l1loss:
        loss = torch.abs(residual).sum(dim=1).sum() # L1
    else:
        # Sum of squared point-to-point distances. The previous expression
        # used undefined names and squared the per-point L1 norm instead.
        loss = torch.square(residual).sum(dim=1).sum() # L2

    loss.backward()
    optimizer.step()

    if i % 10000 == 0:
        # Format the scalar without rebinding `loss`, which stays a tensor.
        print(f"loss: {loss.item():>10f}")

# result
# NOTE: the model maps the TRANSFORMED cloud back onto the ORIGINAL one, so it
# estimates the inverse of the applied transform. Ignoring the added noise, a
# perfect fit gives R ~= random_rot.T and t ~= -random_rot.T @ random_trans,
# not random_rot / random_trans themselves.
estimated_relative_tf = realive_transform.get_tf()
print('True R     :\n', random_rot)
print('True t     :\n', random_trans)
print('Estimated R:\n', estimated_relative_tf['R'])
print('Estimated t:\n', estimated_relative_tf['t'].t())

loss: 17754754.000000
loss: 184.976654
loss:  63.152931
loss:  63.836708
loss:  60.856045
True R     :
 [[ 0.15794468  0.98420227 -0.07999662]
 [ 0.15273811  0.0556879   0.9866965 ]
 [ 0.97556376 -0.16806208 -0.14152956]]
True t     :
 [0.1610954  0.13217635 0.11073001]
Estimated R:
 tensor([[ 0.1535,  0.1491,  0.9681],
        [ 0.9746,  0.0546, -0.1683],
        [-0.0832,  0.9685, -0.1402]], grad_fn=<MmBackward>)
Estimated t:
 tensor([[-0.1577, -0.1490, -0.1023]], grad_fn=<TBackward>)


### Verification of the estimated results (i.e., estimated registered point cloud)

In [21]:
# Run the trained model once more and pull the result out as a NumPy array.
bunny_registered_xyz_est = (
    realive_transform(bunny_transformed_xyz).detach().cpu().numpy()
)

# Wrap the registered points in a point cloud and view it with the original.
bunny_registered_est = o3d.geometry.PointCloud(
    o3d.utility.Vector3dVector(bunny_registered_xyz_est)
)
draw_two_pointcloud(bunny_original, bunny_registered_est)