In [1]:
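# Reference: http://nghiaho.com/?page_id=671 — least-squares rigid transform (rotation R, translation t) between corresponding 3D point sets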
#!/usr/bin/env python3

import numpy as np


def rigid_transform_3D(A, B):
    """Estimate the rigid transform (R, t) that maps point set A onto B.

    Solves the least-squares problem ``B ~= R @ A + t`` for corresponding
    points using the SVD of the 3x3 cross-covariance matrix of the centred
    point sets (Kabsch-style), including the correction that turns a
    reflection solution into a proper rotation.

    Degenerate inputs (e.g. collinear points) make the cross-covariance
    matrix rank-deficient, and the rotation is then not uniquely determined.
    This function does not check for that case.

    Parameters
    ----------
    A, B : numpy.ndarray
        Arrays of shape (3, N). Column ``i`` of ``A`` corresponds to
        column ``i`` of ``B``.

    Returns
    -------
    R : numpy.ndarray
        3x3 orthogonal matrix with determinant +1 (a rotation, never a
        reflection).
    t : numpy.ndarray
        3x1 translation column vector.

    Raises
    ------
    AssertionError
        If ``A`` and ``B`` have different shapes.
    Exception
        If ``A`` or ``B`` does not have exactly 3 rows.
    """
    # NOTE: kept as an assert for backward compatibility of the raised type.
    # Under `python -O` it is stripped; a column-count mismatch then surfaces
    # as a ValueError from the matrix product below instead.
    assert A.shape == B.shape
    num_rows, num_cols = A.shape
    if num_rows != 3:
        raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")

    num_rows, num_cols = B.shape
    if num_rows != 3:
        raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")

    # Column-wise means, kept as 3x1 so they broadcast against the 3xN inputs.
    centroid_A = np.mean(A, axis=1, keepdims=True)
    centroid_B = np.mean(B, axis=1, keepdims=True)

    # Centre both point sets on the origin so only the rotation remains.
    Am = A - centroid_A
    Bm = B - centroid_B

    # 3x3 cross-covariance matrix of the centred point sets.
    H = Am @ Bm.T

    # Optimal orthogonal matrix from the SVD of H.
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # An orthogonal R with negative determinant is a reflection. Flipping the
    # sign of the last right-singular vector gives the best proper rotation.
    # np.linalg.svd sorts singular values in descending order, so that vector
    # belongs to the smallest singular value.
    if np.linalg.det(R) < 0:
        print("det(R) < 0, reflection detected!, correcting for it ...")
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    # Translation that maps the rotated centroid of A onto the centroid of B.
    t = centroid_B - R @ centroid_A

    return R, t

# --- Self-test on synthetic data ---------------------------------------------
# Build a random ground-truth rigid transform, apply it to random points, then
# check that rigid_transform_3D recovers it from the correspondences.

# Random 3x3 matrix (orthonormalised below) and random 3x1 translation.
R = np.random.rand(3, 3)
t = np.random.rand(3, 1)

# Orthonormalise via SVD (R = U @ Vt). If that product is a reflection
# (det < 0), negate the last row of Vt so R becomes a proper rotation.
U, S, Vt = np.linalg.svd(R)
if np.linalg.det(U @ Vt) < 0:
    Vt[2, :] *= -1
R = U @ Vt

n = 8  # number of point correspondences
A = np.random.rand(3, n)
B = R @ A + t

# Estimate the transform back from the correspondences.
ret_R, ret_t = rigid_transform_3D(A, B)

# Re-apply the estimate and measure the root-mean-squared per-point error.
B2 = ret_R @ A + ret_t
err = np.sum((B2 - B) ** 2)
rmse = np.sqrt(err / n)

# Each entry: (heading, value to print, whether a blank line follows).
for heading, shown, blank_after in (
    ("Points A", A, True),
    ("Points B", B, True),
    ("Ground truth rotation", R, False),
    ("Recovered rotation", ret_R, True),
    ("Ground truth translation", t, False),
    ("Recovered translation", ret_t, True),
):
    print(heading)
    print(shown)
    if blank_after:
        print("")

print("RMSE:", rmse)
print("Everything looks good!" if rmse < 1e-5 else "Hmm something doesn't look right ...")

Points A
[[0.54353087 0.57959465 0.81008188 0.34982259 0.88681727 0.23576768
  0.65678765 0.66229526]
 [0.60376219 0.68297583 0.42841209 0.94985135 0.77995557 0.31729744
  0.55127746 0.6405488 ]
 [0.00908108 0.58413107 0.45067231 0.65478524 0.05409284 0.92200252
  0.7609395  0.82168456]]

Points B
[[ 0.79166404  0.64067153  0.51390048  0.80362656  1.00195565  0.10831704
   0.46841882  0.52511821]
 [ 0.09361117  0.4740465   0.13233697  0.78032266 -0.01876615  0.75054514
   0.48167048  0.55576321]
 [ 0.81290499  1.22609756  1.27996776  1.13309041  1.11960434  1.14709122
   1.38368982  1.43832602]]

Ground truth rotation
[[ 0.20618003  0.89364687 -0.3986039 ]
 [-0.63342335  0.43238159  0.64173282]
 [ 0.74583152  0.12017253  0.65520524]]
Recovered rotation
[[ 0.20618003  0.89364687 -0.3986039 ]
 [-0.63342335  0.43238159  0.64173282]
 [ 0.74583152  0.12017253  0.65520524]]

Ground truth translation
[[0.1436684 ]
 [0.17101303]
 [0.32901693]]
Recovered translation
[[0.1436684 ]
 [0.17101303]
