# Registration Pipeline

### The point cloud registration pipeline takes the following inputs:
- Point clouds: the outputs of the depth estimation pipeline, in *any format compatible with Open3D*. **Results are more robust when the clouds carry colors and normals, and when they share a uniform scale (if applicable).**

### Process:

$\rightarrow$ *Point clouds are downsampled and points are randomly selected using Latin Hypercube Sampling*

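$\rightarrow$ *Sampled points from both clouds are matched on local features (height, surface normal, color hue), and a rigid transform is estimated from the matched pairs using entropic optimal transport followed by Procrustes analysis*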

### Tunable parameters:

The following imports are required for the registration:

In [None]:
import os
import time
import numpy as np
import open3d as o3d
from scipy.stats import qmc
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment

def procrustes(X, Y):
    """
    Find the rigid motion (rotation + translation) that best maps X onto Y.
    X, Y: (N, 3) arrays; row i of X corresponds to row i of Y.
    Returns R (3x3) and t (3,) such that Y ≈ X @ R.T + t.
    """
    mu_x = X.mean(axis=0)
    mu_y = Y.mean(axis=0)

    # Cross-covariance of the mean-free point sets
    cross_cov = (X - mu_x).T @ (Y - mu_y)
    left, _, right_t = np.linalg.svd(cross_cov)

    rotation = right_t.T @ left.T
    if np.linalg.det(rotation) < 0:
        # A negative determinant means the SVD produced a reflection;
        # flipping the last right-singular vector turns it into a rotation.
        right_t[-1, :] = -right_t[-1, :]
        rotation = right_t.T @ left.T

    # Translation that carries the rotated source centroid onto the target centroid
    translation = mu_y - mu_x @ rotation.T
    return rotation, translation

def sinkhorn(a, b, C, reg, numItermax=1000, tol=1e-9):
    """
    Compute the entropic regularized OT plan using Sinkhorn iterations.
    a, b: weight vectors (summing to 1) for source (n,) and target (m,).
    C: cost matrix (n x m)
    reg: regularization strength (epsilon)
    numItermax: maximum number of scaling iterations
    tol: stop once the L1 change of the source scaling vector drops below this
    Returns the (n x m) transport plan diag(u) K diag(v).

    Note: K = exp(-C / reg) underflows to zero when C / reg is large, which
    makes the scaling updates divide by zero; choose reg relative to the
    magnitude of C.
    """
    K = np.exp(-C / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(numItermax):
        # u is rebound (never mutated) below, so no copy is needed here
        u_prev = u
        u = a / (K @ v)
        v = b / (K.T @ u)
        if np.linalg.norm(u - u_prev, 1) < tol:
            break
    # Row/column scaling by broadcasting: same values as
    # diag(u) @ K @ diag(v) without building dense n x n and m x m matrices.
    return u[:, None] * K * v[None, :]

def register_point_clouds(X, Y):
    """
    Estimate the rigid transform taking point cloud X onto Y from a hard
    one-to-one assignment (Hungarian algorithm) followed by Procrustes analysis.
    X: (N, 3), Y: (M, 3)
    Returns:
      - R: (3, 3) rotation matrix
      - t: (3,) translation vector
    The transform is only estimated, not applied; the caller aligns X with
    X @ R.T + t.
    """
    # Squared Euclidean distance between every source/target pair
    pairwise_cost = cdist(X, Y, metric='sqeuclidean')
    # Minimum-cost one-to-one pairing of source rows with target rows
    src_idx, tgt_idx = linear_sum_assignment(pairwise_cost)
    # Best-fit rigid motion over the paired points
    return procrustes(X[src_idx], Y[tgt_idx])

def register_sinkhorn(X, Y, reg=0.1):
    """
    Estimate the rigid transform taking point cloud X onto Y using
    entropic OT + Procrustes.
    X: (N, 3), Y: (M, 3); reg is the entropic regularization strength
    passed to sinkhorn.
    Returns rotation R (3x3) and translation t (3,). The transform is not
    applied and the transport plan is not returned.
    """
    n_src = X.shape[0]
    n_tgt = Y.shape[0]
    # Uniform mass on every source and target point
    src_weights = np.full(n_src, 1.0 / n_src)
    tgt_weights = np.full(n_tgt, 1.0 / n_tgt)

    sq_dists = cdist(X, Y, metric='sqeuclidean')
    plan = sinkhorn(src_weights, tgt_weights, sq_dists, reg)

    # Barycentric projection: each source point is paired with the
    # plan-weighted average of the target points it sends mass to.
    row_mass = plan.sum(axis=1, keepdims=True)
    soft_targets = (plan @ Y) / row_mass
    return procrustes(X, soft_targets)

def compute_similarity_transform(source, target):
    """
    Estimate scale s, rotation R, and translation t such that, for
    corresponding rows of source and target:
    target ≈ s * R @ source + t
    Returns (s, R, t). The scale falls back to 1.0 when the source points
    have no spread around their centroid.
    """
    mu_src = np.mean(source, axis=0)
    mu_tgt = np.mean(target, axis=0)
    offsets_src = source - mu_src
    offsets_tgt = target - mu_tgt

    # Scale = ratio of the two clouds' spread around their centroids
    spread_src = np.square(offsets_src).sum()
    spread_tgt = np.square(offsets_tgt).sum()
    if spread_src > 0:
        scale = np.sqrt(spread_tgt / spread_src)
    else:
        scale = 1.0

    # Rotation from the SVD of the cross-covariance matrix
    left, _, right_t = np.linalg.svd(offsets_src.T @ offsets_tgt)
    R = right_t.T @ left.T
    if np.linalg.det(R) < 0:
        # Reflection detected: flip the last right-singular vector
        right_t[-1, :] = -right_t[-1, :]
        R = right_t.T @ left.T

    # Translation mapping the scaled, rotated source centroid onto the target centroid
    t = mu_tgt - scale * (R @ mu_src)
    return scale, R, t

In [None]:
# Input clouds, processed in sorted filename order
pointclouds = sorted(os.listdir('IGP_filtered'))

# --- Tunable parameters ---
voxel_size = 1e-6          # voxel edge length used for downsampling
LHS_count = 10000          # number of Latin Hypercube samples drawn per cloud
radius = voxel_size*5      # neighbourhood radius for local normal estimation
s_tol = 0.01               # similarity tolerance: max feature distance for a match
min_cutoff = 1e-5          # minimum distance to camera

# --- Initial state ---
# The first cloud, downsampled, is the initial registration target
pcd2 = o3d.io.read_point_cloud(
    os.path.join('IGP_filtered', pointclouds[0])
).voxel_down_sample(voxel_size)
R_c = np.eye(3)            # starts as identity; intended as the cumulative rotation
original_normals = []

def _hue_angle(rgb):
    """Return the hue angle (radians) of every row of an (N, 3) RGB array."""
    red, green, blue = rgb[:, 0], rgb[:, 1], rgb[:, 2]
    return np.arctan2(np.sqrt(3) * (green - blue), 2 * red - green - blue)


def _local_normals(tree, cloud_points, queries, search_radius):
    """
    Estimate a normal at every query point by PCA over its neighbourhood.
    tree: cKDTree built on cloud_points; queries: (K, 3) points.
    Returns (normals (K, 3), neighbour counts (K, 1)). Queries with fewer
    than 3 neighbours inside search_radius get a zero normal.
    """
    normals = []
    counts = []
    for query in queries:
        neighbors = cloud_points[tree.query_ball_point(query, r=search_radius)]
        counts.append(len(neighbors))
        if len(neighbors) >= 3:
            centered = neighbors - neighbors.mean(axis=0)
            # eigh returns eigenvalues in ascending order, so column 0 is the
            # direction of least spread, i.e. the surface normal.
            _, eigvecs = np.linalg.eigh(centered.T @ centered)
            normals.append(eigvecs[:, 0])
        else:
            normals.append(np.zeros(3))
    return np.asarray(normals).reshape(-1, 3), np.asarray(counts).reshape(-1, 1)


def _rows_occurring_once(rows):
    """Return, in sorted order, the rows of a 2-D array that occur exactly once."""
    unique_rows, counts = np.unique(rows, axis=0, return_counts=True)
    return unique_rows[counts == 1]


i = 0
for path in pointclouds:
    print('processing pointcloud: ',path)
    path = os.path.join('IGP_filtered',path)
    sampler = qmc.LatinHypercube(3)
    start = time.time()

    # Target: everything merged so far
    points2 = np.asarray(pcd2.points)
    gnd_level = points2[:,2].min()

    # Source: the new cloud, shifted in z so its lowest point sits at the
    # target's lowest point
    pcd1 = o3d.io.read_point_cloud(path)
    pcd1 = pcd1.voxel_down_sample(voxel_size)
    points1 = np.asarray(pcd1.points)
    points1[:,2] = points1[:,2]-(points1[:,2].min()-gnd_level)

    # Latin Hypercube samples inside the source bounding box
    src_min = points1.min(axis=0)
    src_max = points1.max(axis=0)
    lhs_samples1 = sampler.random(LHS_count)
    samples_scaled1 = src_min + lhs_samples1 * (src_max - src_min)

    # adaptive search box: the whole target on the first pass, afterwards the
    # bounding box of the previously registered cloud
    min_bounds = points2.min(axis=0) if i == 0 else newmin_bounds
    max_bounds = points2.max(axis=0) if i == 0 else newmax_bounds
    #TODO: truncate points close to camera (more likely to be distorted)

    lhs_samples2 = sampler.random(LHS_count)
    samples_scaled2 = min_bounds + lhs_samples2 * (max_bounds - min_bounds)

    # Snap every sample to its nearest real point and look up that point's hue
    tree = cKDTree(points1)
    _, idx1 = tree.query(samples_scaled1, k=1)
    sampled_points = points1[idx1]
    colors = _hue_angle(np.asarray(pcd1.colors))
    sampled_colors = colors[idx1]

    tree2 = cKDTree(points2)
    _, idx2 = tree2.query(samples_scaled2, k=1)
    sampled_points2 = points2[idx2]
    colors2 = _hue_angle(np.asarray(pcd2.colors))
    sampled_colors2 = colors2[idx2]

    # Local normals and neighbour counts around every sampled point.
    # The neighbour counts (densities) are computed but are currently not part
    # of the feature vector; a rank-based normalisation was tried previously.
    normals, densities = _local_normals(tree, points1, sampled_points, radius)
    normals2, densities2 = _local_normals(tree2, points2, sampled_points2, radius)

    #normals = normals @ R_c.T

    sampled_colors = np.asarray(sampled_colors).reshape(-1, 1)*100
    sampled_colors2 = np.asarray(sampled_colors2).reshape(-1, 1)*100

    # Feature rows: [x, y, z, nx, ny, nz, hue*100]  (7 columns)
    combined1 = np.hstack([sampled_points, normals, sampled_colors])
    combined2 = np.hstack([sampled_points2, normals2, sampled_colors2])

    # Several samples can snap to the same point; keep only rows that occur
    # exactly once (all copies of a repeated row are dropped).
    combined1 = _rows_occurring_once(combined1)
    combined2 = _rows_occurring_once(combined2)

    # Map similar points based on normals, hue
    # NOTE(review): the feature slice starts at column 2, so z takes part in
    # the comparison together with the normal and hue - confirm this is intended.
    matches = []
    for p in combined1:
        if len(combined2) == 0:
            # every target candidate has been consumed
            break
        diff = combined2[:,2:] - p[2:]
        c_temp = np.sqrt(np.sum(diff**2,axis=1))
        positive = c_temp > 0
        if not positive.any():
            # only exact duplicates remain; nothing usable to match against
            continue
        ms = np.min(c_temp[positive])
        idx = np.where(c_temp == ms)[0][0]
        if ms < s_tol:
            # columns 0:7 = source row, columns 7:14 = matched target row
            matches.append(np.hstack((p, combined2[idx])))
        #TODO: update coordinate system after each transform --> rotate normals to match new point cloud (or make register of original normals)
        # NOTE(review): the closest candidate is removed even when the match
        # was rejected by s_tol - confirm this is intended.
        combined2 = np.delete(combined2,idx,axis=0)
    matches = np.array(matches)
    print(len(matches))
    if len(matches) == 0:
        # Fewer than 3 matches leave the rotation under-determined; with none
        # at all there is nothing to fit.
        raise RuntimeError(f'no feature matches found for {path}; cannot estimate a transform')

    ##Calculate transform
    src_matched = matches[:,0:3]   # matched sample positions in the source (pcd1)
    tgt_matched = matches[:,7:10]  # matched sample positions in the target (pcd2)

    s = 1
    #R, t = register_point_clouds(src_matched, tgt_matched) #optimal transport

    R,t = register_sinkhorn(src_matched, tgt_matched, reg=1.8) #entropic optimal transport
    #TODO: match camera level for all instances

    #s, R, t = compute_similarity_transform(src_matched, tgt_matched)

    #R_c = R@R_c #cumulative rotation
    transformed_points = s * (points1 @ R.T) + t
    pcd1.points = o3d.utility.Vector3dVector(transformed_points)

    # Debug clouds for the final visualisation
    pcd3 = o3d.geometry.PointCloud() #target matches (green)
    pcd3.points = o3d.utility.Vector3dVector(tgt_matched)
    pcd3 = pcd3.paint_uniform_color([0,1,0])
    pcd4_points = s * (src_matched @ R.T) + t
    pcd4 = o3d.geometry.PointCloud() #source matches, moved (yellow)
    pcd4.points = o3d.utility.Vector3dVector(pcd4_points)
    pcd4 = pcd4.paint_uniform_color([1,1,0])
    pcd5 = o3d.geometry.PointCloud() #source matches before the transform (red)
    pcd5.points = o3d.utility.Vector3dVector(src_matched)
    pcd5 = pcd5.paint_uniform_color([1,0,0])

    # update search bounds (pcd1 moves)
    newmin_bounds = transformed_points.min(axis=0)
    newmax_bounds = transformed_points.max(axis=0)

    merged = pcd1 + pcd2
    # remove_statistical_outlier returns (filtered cloud, kept indices) and
    # does not filter in place, so the result has to be kept.
    merged, _ = merged.remove_statistical_outlier(nb_neighbors=10,std_ratio=0.1)
    pcd2 = merged
    end = time.time()
    print('processing time:',end-start,'seconds')
    i+=1
    if i > 15:
        break

o3d.visualization.draw_geometries([merged,pcd3,pcd4,pcd5])

## workflow
# get points with LHS
# match based on density, normals (max.neighbors) and color
# minimize KL divergence of similar points based on location
# adaptive sample count based on point cloud dimensions