### Import modules and registration helpers

In [106]:
import copy
import os

from functools import partial
from pathlib import Path
from time import time
from typing import Callable

import numpy as np
import open3d.geometry as geom
import open3d.pipelines.registration as reg
import open3d.visualization as vis
import open3d.utility as util

from benthoscan.runtime import Environment, load_environment

from benthoscan.spatial import (
    MultiTargetIndex,
    PointCloud,
    PointCloudLoader,
    read_point_cloud,
    downsample_point_cloud,
    estimate_point_cloud_normals,
    generate_cascade_indices,
)

from benthoscan.spatial import (
    ExtendedRegistrationResult,
    register_point_cloud_fphp_fast,
    register_point_cloud_fphp_ransac,
    register_point_cloud_icp,
    register_point_cloud_graph,
)

# from benthoscan.tasks.registration import RegistrationTaskConfig

from benthoscan.utils.log import logger

### Load environment and configure data loaders

In [2]:
environment: Environment = load_environment()

# Directory with the point cloud files. Set the BENTHOSCAN_DATA_DIR environment variable to
# use another location; the fallback is the original author's local cache directory.
DATA_DIR: Path = Path(
    os.environ.get("BENTHOSCAN_DATA_DIR", "/home/martin/dev/benthoscan/.cache")
)

# Point cloud id -> file. The ids are used as loader keys and later as pose graph node ids.
point_cloud_files: dict[int, Path] = {
    0: DATA_DIR / "qdc5ghs3_20100430_024508.ply",
    1: DATA_DIR / "qdc5ghs3_20120501_033336.ply",
    2: DATA_DIR / "qdc5ghs3_20130405_103429.ply",
    3: DATA_DIR / "qdc5ghs3_20210315_230947.ply",
}

# Lazy loaders: every call invokes read_point_cloud on the file again.
loaders: dict = {
    key: partial(read_point_cloud, path=path) for key, path in point_cloud_files.items()
}

count = len(loaders)
if count < 2:
    # Registration needs at least one pair; stop here instead of failing later with a less
    # obvious error in the registration cells.
    message = f"invalid number of point clouds for registration: {count}"
    logger.error(message)
    raise ValueError(message)

### Configure preprocessor and registrator

In [39]:
# Registration tuning parameters. NOTE: move these to a config file.

# Voxel spacing used when downsampling the clouds before feature matching
VOXEL_SIZE: float = 0.2

# FPFH feature parameters (search radius and neighbour count)
FEATURE_RADIUS: float = 2.0  # NOTE: author also noted 2.00 here
FEATURE_NEIGHBOURS: int = 500

# Correspondence validation and transformation estimation
CORRESPONDENCE_DISTANCE: float = 0.3
EDGE_LENGTH: float = 0.95
# NOTE(review): Open3D's CorrespondenceCheckerBasedOnNormal takes its threshold in radians;
# confirm whether register_point_cloud_fphp_ransac converts this value from degrees.
NORMAL_ANGLE: float = 5.0
ESTIMATE_SCALE: bool = True

# RANSAC parameters
SAMPLE_COUNT: int = 3
MAX_ITERATIONS: int = 200_000
CONFIDENCE: float = 0.9999  # NOTE: author also noted 0.999 as an alternative


def downsample_and_estimate_normals(cloud: PointCloud, voxel_size: float) -> PointCloud:
    """Downsample a point cloud and estimate normals on the downsampled result.

    Args:
        cloud: Input point cloud.
        voxel_size: Spacing forwarded to ``downsample_point_cloud``.

    Returns:
        The point cloud returned by ``estimate_point_cloud_normals`` for the downsampled cloud.
    """
    return estimate_point_cloud_normals(
        downsample_point_cloud(cloud, spacing=voxel_size)
    )
    

def registration_worker(
    source_loader: PointCloudLoader,
    target_loader: PointCloudLoader,
    preprocessor: Callable[[PointCloud], PointCloud],
    registrator: Callable[[PointCloud, PointCloud], ExtendedRegistrationResult],
) -> ExtendedRegistrationResult:
    """Load, preprocess and register a source point cloud against a target point cloud.

    Args:
        source_loader: Zero-argument callable whose result is unwrapped into the source cloud.
        target_loader: Zero-argument callable whose result is unwrapped into the target cloud.
        preprocessor: Applied independently to the source and to the target cloud.
        registrator: Called with the keyword arguments ``source`` and ``target``.

    Returns:
        The registration result produced by ``registrator``.
    """
    source_cloud: PointCloud = source_loader().unwrap()
    target_cloud: PointCloud = target_loader().unwrap()

    preprocessed_source: PointCloud = preprocessor(source_cloud)
    preprocessed_target: PointCloud = preprocessor(target_cloud)

    result: ExtendedRegistrationResult = registrator(
        source=preprocessed_source,
        target=preprocessed_target,
    )

    return result


# Feature-matching stage: every cloud is downsampled and gets normals estimated, then each
# pair is registered with register_point_cloud_fphp_ransac using the parameter constants.
preprocessor: Callable[[PointCloud], PointCloud] = partial(
    downsample_and_estimate_normals, voxel_size=VOXEL_SIZE
)

registrator = partial(
    register_point_cloud_fphp_ransac,
    # Correspondence distance and FPFH features
    distance_threshold=CORRESPONDENCE_DISTANCE,
    feature_radius=FEATURE_RADIUS,
    feature_neighbours=FEATURE_NEIGHBOURS,
    # RANSAC
    sample_count=SAMPLE_COUNT,
    max_iterations=MAX_ITERATIONS,
    confidence=CONFIDENCE,
    # Correspondence checks and scale estimation
    edge_check=EDGE_LENGTH,
    normal_check=NORMAL_ANGLE,
    scaling=ESTIMATE_SCALE,
)

### Define visualization helpers

In [163]:
def visualize_registration(
    source: PointCloud,
    target: PointCloud,
    transformation: np.ndarray,
    source_color: list=None,
    target_color: list=None,
    title: str="",
    window_width: int=1024,
    window_height: int=768,
) -> None:
    """Draw the source cloud, moved by ``transformation``, together with the target cloud.

    Both inputs are deep-copied first, so neither ``source`` nor ``target`` is recoloured
    or transformed by this function.

    Args:
        source: Point cloud that is transformed before drawing.
        target: Reference point cloud, drawn without transformation.
        transformation: 4x4 homogeneous transformation applied to the source copy.
        source_color: Optional uniform RGB colour for the source; falsy keeps its colours.
        target_color: Optional uniform RGB colour for the target; falsy keeps its colours.
        title: Window title.
        window_width: Window width in pixels.
        window_height: Window height in pixels.
    """
    moved_source = copy.deepcopy(source)
    shown_target = copy.deepcopy(target)

    for cloud, color in ((moved_source, source_color), (shown_target, target_color)):
        if color:
            cloud.paint_uniform_color(color)

    moved_source.transform(transformation)

    vis.draw_geometries(
        geometry_list=[moved_source, shown_target],
        window_name=title,
        width=window_width,
        height=window_height,
    )


def batch_visualize_registration(
    storage: dict[int, dict],
    source_color: list=[0.60, 0.20, 0.20],
    target_color: list=[0.20, 0.20, 0.60],
    cloud_loaders: dict=None,
) -> None:
    """Open one visualization window per stored pairwise registration.

    Args:
        storage: Maps source cloud id -> {target cloud id -> registration result}; each
            result's ``transformation`` is applied to the source cloud.
        source_color: Uniform RGB colour for the source clouds. The default list is only
            read, never modified.
        target_color: Uniform RGB colour for the target clouds (default is likewise read-only).
        cloud_loaders: Maps cloud id -> zero-argument loader whose result is unwrapped into a
            point cloud. Defaults to the notebook-level ``loaders`` dict.
    """
    if cloud_loaders is None:
        cloud_loaders = loaders

    for source, registrations in storage.items():
        # Load the source once and reuse it for all of its targets;
        # visualize_registration draws copies, so this cloud is not modified.
        source_cloud: PointCloud = cloud_loaders[source]().unwrap()

        for target, result in registrations.items():
            target_cloud: PointCloud = cloud_loaders[target]().unwrap()

            visualize_registration(
                source=source_cloud,
                target=target_cloud,
                transformation=result.transformation,
                source_color=source_color,
                target_color=target_color,
                title=f"Source: {source}, target: {target}",
            )

In [58]:
def log_registration(source, target, result, elapsed: float=0.0) -> None:
    """Log a framed summary of one pairwise registration at INFO level.

    Args:
        source: Source cloud identifier (only formatted into the log).
        target: Target cloud identifier (only formatted into the log).
        result: Registration result; its ``inlier_rmse``, ``fitness``,
            ``correspondence_set`` (logged as a count) and ``transformation`` are reported.
        elapsed: Elapsed time, logged as given (callers pass seconds measured with ``time()``).
    """
    info = logger.info

    info("")
    info(f"--------------------- Registration --------------------")
    info(f" - Source, target:        {source}, {target}")
    info(f" - Elapsed time:          {elapsed}")
    info(f" - RMSE:                  {result.inlier_rmse}")
    info(f" - Fitness:               {result.fitness}")
    info(f" - Correspondences:       {len(result.correspondence_set)}")
    info(f" - Transformation:        {result.transformation}")
    info(f"-------------------------------------------------------")
    info("")

### Test registrator parameters on a single case

In [40]:
# Select a pair of point clouds to tune the registration on - 0-3 is the hardest case,
# but 1-3 and 2-3 are also challenging
test_source: int = 0
test_target: int = 3

# Number of repeated runs on the test pair, to see how much the RANSAC-based result varies
TEST_REPEATS: int = 10

# Keep every run for comparison; test_result always holds the most recent one
test_results: list[ExtendedRegistrationResult] = list()

for _ in range(TEST_REPEATS):
    test_result: ExtendedRegistrationResult = registration_worker(
        source_loader=loaders[test_source],
        target_loader=loaders[test_target],
        preprocessor=preprocessor,
        registrator=registrator,
    )
    test_results.append(test_result)

    logger.info(test_result)


[32m2024-07-02 09:29:55.839[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m14[0m - [1mExtendedRegistrationResult(fitness=0.48880558868631796, inlier_rmse=0.13679054134846724, correspondence_set=std::vector<Eigen::Vector2i> with 14344 elements.
Use numpy.asarray() to access data., transformation=array([[ 0.49414566, -0.84361364, -0.00942538,  1.74248444],
       [ 0.84366088,  0.49407253,  0.00902175,  2.82559043],
       [-0.00302134, -0.01269256,  0.9776413 , -0.66247948],
       [ 0.        ,  0.        ,  0.        ,  1.        ]]), information=array([[ 3.18231229e+07, -2.05402474e+05, -2.90881371e+04,
         0.00000000e+00,  6.63789273e+05,  8.01238890e+04],
       [-2.05402474e+05,  3.13542570e+07,  3.73102438e+06,
        -6.63789273e+05,  0.00000000e+00, -9.93380155e+02],
       [-2.90881371e+04,  3.73102438e+06,  1.69689677e+06,
        -8.01238890e+04,  9.93380155e+02,  0.00000000e+00],
       [ 0.00000000e+00, -6.63789273e+05, -8.01238890e+04,
        

In [30]:
visualize_test: bool = True

if visualize_test:
    # Fresh, full-resolution copies of the test pair; the registration result only
    # carries the transformation.
    visualize_registration(
        loaders[test_source]().unwrap(),
        loaders[test_target]().unwrap(),
        test_result.transformation,
        title=f"Test case: {test_source}, {test_target}",
    )

### Generate indices and perform registration

In [None]:
indices: list[MultiTargetIndex] = generate_cascade_indices(len(loaders))

# One entry per registration stage; the stage loops fill each with
# source id -> {target id -> registration result}.
result_storage: dict[str, dict] = dict(
    feature_matching=dict(),
    incremental_coarse=dict(),
)

In [36]:
# Reset only this stage's entry. Rebinding result_storage to a new empty dict would drop the
# "feature_matching" key, and the assignment at the end of the loop would raise a KeyError.
result_storage["feature_matching"] = dict()

for index in indices:
    source = index.source

    results: dict[int, ExtendedRegistrationResult] = dict()

    for target in index.targets:
        start: float = time()

        result: ExtendedRegistrationResult = registration_worker(
            source_loader=loaders[source],
            target_loader=loaders[target],
            preprocessor=preprocessor,
            registrator=registrator,
        )

        elapsed: float = time() - start

        log_registration(source, target, result, elapsed)

        results[target] = result

    result_storage["feature_matching"][source] = results

[32m2024-07-01 22:24:47.124[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1m[0m
[32m2024-07-01 22:24:47.125[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1m--------------------- FPFH registration --------------------[0m
[32m2024-07-01 22:24:47.126[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1m - Source, target:        0, 1[0m
[32m2024-07-01 22:24:47.127[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m27[0m - [1m - Elapsed time:          35.294801235198975[0m
[32m2024-07-01 22:24:47.128[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1m - RMSE:                  0.12520877821931495[0m
[32m2024-07-01 22:24:47.128[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1m - Fitness:               0.8745271766910888[0m
[32m2024-07-01 22:24:47.129[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[

### Draw registration and plot results
- TODO: Draw registered point clouds
- TODO: Plot point clouds, correspondences, and error distribution

In [37]:
# One window per registered pair, showing the feature-matching transformation
batch_visualize_registration(result_storage["feature_matching"])

### Perform coarse incremental registration

In [208]:
# Coarse ICP refinement parameters. NOTE: move to config
icp_coarse_parameters: dict = {
    "voxel": 0.05,
    "distance": 0.50,
    "estimator": reg.TransformationEstimationPointToPlane(),
}

# NOTE(review): point-to-plane ICP needs normals on the target cloud, but this preprocessor
# only downsamples; confirm that the input files carry normals or that
# register_point_cloud_icp estimates them.
icp_preprocessor: Callable[[PointCloud], PointCloud] = partial(
    downsample_point_cloud,
    spacing=icp_coarse_parameters["voxel"],
)

icp_registrator = partial(
    register_point_cloud_icp,
    distance_threshold=icp_coarse_parameters["distance"],
    distance_measure=icp_coarse_parameters["estimator"],
)

In [209]:
result_storage["incremental_coarse"] = {}

for index in indices:
    source = index.source

    results: dict[int, ExtendedRegistrationResult] = {}

    for target in index.targets:
        # ICP starts from the transformation that feature matching found for this pair
        transformation: np.ndarray = (
            result_storage["feature_matching"][source][target].transformation
        )

        # Both clouds are reloaded and preprocessed for every pair
        source_cloud: PointCloud = loaders[source]().unwrap()
        target_cloud: PointCloud = loaders[target]().unwrap()

        source_prepped: PointCloud = icp_preprocessor(source_cloud)
        target_prepped: PointCloud = icp_preprocessor(target_cloud)

        result: ExtendedRegistrationResult = icp_registrator(
            source=source_prepped,
            target=target_prepped,
            transformation=transformation,
        )

        # Empty lines frame each pair's entry in the log
        logger.info("")
        logger.info(f"Source {source}, target {target}")
        logger.info(f"Initial transformation: {transformation}")
        logger.info(f"Refined transformation: {result.transformation}")
        logger.info("")

        results[target] = result

    result_storage["incremental_coarse"][source] = results

[32m2024-07-04 22:38:20.073[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1m[0m
[32m2024-07-04 22:38:20.074[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mSource 0, target 1[0m
[32m2024-07-04 22:38:20.076[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m27[0m - [1mInitial transformation: [[ 9.71381407e-01 -1.72230888e-02  5.52931568e-03 -1.17976924e+01]
 [ 1.71974498e-02  9.71387073e-01  4.52185772e-03 -8.28967165e+00]
 [-5.60855042e-03 -4.42319922e-03  9.71523559e-01 -1.21786578e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]][0m
[32m2024-07-04 22:38:20.077[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mRefined transformation: [[ 9.71549401e-01  7.47762972e-04  4.97394756e-04 -1.21686524e+01]
 [-7.48634513e-04  9.71548033e-01  1.70441811e-03 -8.22920207e+00]
 [-4.96082021e-04 -1.70480066e-03  9.71548194e-01 -1.28618127e+00]
 [ 0.00000000e+00 

In [169]:
# One window per registered pair, showing the ICP-refined transformation
# (reddish source, bluish target)
batch_visualize_registration(
    result_storage["incremental_coarse"],
    source_color=[0.60, 0.20, 0.20],
    target_color=[0.20, 0.20, 0.60],
)

### TODO: Perform fine / colored registration

In [200]:
# TODO: Add colored ICP

### Perform multiway registration with pose graph optimization (TODO: work in progress)

In [197]:
def build_pose_graph(
    results: dict[int, dict[int, ExtendedRegistrationResult]],
) -> reg.PoseGraph:
    """Build a pose graph from pairwise registration results.

    Node 0 is placed at the identity. Every odometry pair (``target == source + 1``) chains
    its transformation onto the running odometry and appends a node whose pose is the
    inverse of that odometry; these edges are marked certain. All other pairs become
    uncertain loop-closure edges.

    NOTE(review): when the registrations estimate scale (``scaling=ESTIMATE_SCALE``), the
    transformations are similarity rather than rigid transforms; confirm that Open3D's pose
    graph optimization handles them acceptably.

    Args:
        results: Maps source id -> {target id -> registration result}. Each result must expose
            ``transformation`` and ``information``. Ids are used directly as node indices,
            so they must be 0..N-1 with an odometry pair for every consecutive id.

    Returns:
        The pose graph with one node per cloud and one edge per registration.

    Raises:
        ValueError: If an odometry pair would create a node whose index differs from its
            target id (i.e. the odometry chain is not consecutive from 0).
    """
    odometry = np.identity(4)

    pose_graph = reg.PoseGraph()
    pose_graph.nodes.append(reg.PoseGraphNode(odometry))

    # Node indices are implied by append order, so visit ids in ascending order to keep
    # the odometry chain 0 -> 1 -> 2 -> ... regardless of how the dicts were filled.
    for source_id, registrations in sorted(results.items()):

        for target_id, result in sorted(registrations.items()):

            logger.info(f"{source_id}, {target_id}")

            is_odometry = target_id == source_id + 1

            if is_odometry:
                if len(pose_graph.nodes) != target_id:
                    raise ValueError(
                        f"odometry pair {source_id} -> {target_id} would create node "
                        f"{len(pose_graph.nodes)}; ids must form a consecutive chain from 0"
                    )

                odometry = np.dot(result.transformation, odometry)
                pose_graph.nodes.append(reg.PoseGraphNode(np.linalg.inv(odometry)))

            # Odometry edges are trusted; loop closures may be pruned by the optimizer
            pose_graph.edges.append(reg.PoseGraphEdge(
                source_id,
                target_id,
                result.transformation,
                result.information,
                uncertain=not is_odometry,
            ))

    return pose_graph


def optimize_pose_graph(
    graph: reg.PoseGraph,
    correspondence_distance: float,
    prune_threshold: float,
    preference_loop_closure: float,
    reference_node: int=-1,
) -> reg.PoseGraph:
    """Run Open3D global optimization (Levenberg-Marquardt) on a pose graph.

    ``reg.global_optimization`` updates ``graph`` in place; the same object is returned,
    so any other reference to the input also sees the optimized poses.

    Args:
        graph: Pose graph to optimize (modified in place).
        correspondence_distance: Forwarded as ``max_correspondence_distance``.
        prune_threshold: Forwarded as ``edge_prune_threshold``.
        preference_loop_closure: Forwarded as ``preference_loop_closure``.
        reference_node: Forwarded as ``reference_node``; defaults to -1.

    Returns:
        ``graph`` itself, after optimization.
    """
    settings = reg.GlobalOptimizationOption(
        max_correspondence_distance=correspondence_distance,
        edge_prune_threshold=prune_threshold,
        preference_loop_closure=preference_loop_closure,
        reference_node=reference_node,
    )
    solver = reg.GlobalOptimizationLevenbergMarquardt()
    criteria = reg.GlobalOptimizationConvergenceCriteria()

    reg.global_optimization(graph, solver, criteria, settings)

    return graph

In [221]:
# Pose graph optimization parameters. NOTE: move to config
# NOTE(review): the recorded debug output shows the optimizer pruning the graph from 6 to 3
# edges, presumably dropping every loop closure; revisit these values (e.g. the
# correspondence distance relative to the registration thresholds).
POSE_GRAPH_CORRESPONDENCE_DISTANCE: float = 0.01
POSE_GRAPH_PRUNE_THRESHOLD: float = 0.25
POSE_GRAPH_LOOP_CLOSURE_PREFERENCE: float = 1.0
POSE_GRAPH_REFERENCE_NODE: int = 0

initial_graph: reg.PoseGraph = build_pose_graph(result_storage["feature_matching"])

logger.info(initial_graph)

# optimize_pose_graph returns the graph it was given, so afterwards initial_graph and
# optimized_graph refer to the same, optimized object.
with util.VerbosityContextManager(util.VerbosityLevel.Debug):
    optimized_graph: reg.PoseGraph = optimize_pose_graph(
        initial_graph,
        correspondence_distance=POSE_GRAPH_CORRESPONDENCE_DISTANCE,
        prune_threshold=POSE_GRAPH_PRUNE_THRESHOLD,
        preference_loop_closure=POSE_GRAPH_LOOP_CLOSURE_PREFERENCE,
        reference_node=POSE_GRAPH_REFERENCE_NODE,
    )

# Iterate over the graph nodes rather than the stored source ids: a cloud that is only ever
# a registration target has a node but need not be a key in result_storage.
for identifier, node in enumerate(optimized_graph.nodes):
    logger.info(f"Node {identifier} pose:\n{node.pose}")


[32m2024-07-04 22:43:51.688[0m | [1mINFO    [0m | [36m__main__[0m:[36mbuild_pose_graph[0m:[36m15[0m - [1m0, 1[0m
[32m2024-07-04 22:43:51.689[0m | [1mINFO    [0m | [36m__main__[0m:[36mbuild_pose_graph[0m:[36m15[0m - [1m0, 2[0m
[32m2024-07-04 22:43:51.690[0m | [1mINFO    [0m | [36m__main__[0m:[36mbuild_pose_graph[0m:[36m15[0m - [1m0, 3[0m
[32m2024-07-04 22:43:51.690[0m | [1mINFO    [0m | [36m__main__[0m:[36mbuild_pose_graph[0m:[36m15[0m - [1m1, 2[0m
[32m2024-07-04 22:43:51.691[0m | [1mINFO    [0m | [36m__main__[0m:[36mbuild_pose_graph[0m:[36m15[0m - [1m1, 3[0m
[32m2024-07-04 22:43:51.691[0m | [1mINFO    [0m | [36m__main__[0m:[36mbuild_pose_graph[0m:[36m15[0m - [1m2, 3[0m
[32m2024-07-04 22:43:51.692[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mPoseGraph with 4 nodes and 6 edges.[0m
[32m2024-07-04 22:43:51.693[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m17[0m -

[Open3D DEBUG] Validating PoseGraph - finished.
[Open3D DEBUG] [GlobalOptimizationLM] Optimizing PoseGraph having 4 nodes and 6 edges.
[Open3D DEBUG] Line process weight : 2.265217
[Open3D DEBUG] [Initial     ] residual : 5.193143e+05, lambda : 1.239019e+03
[Open3D DEBUG] Delta.norm() < 1.000000e-06 * (x.norm() + 1.000000e-06)
[Open3D DEBUG] [GlobalOptimizationLM] total time : 0.000 sec.
[Open3D DEBUG] [GlobalOptimizationLM] Optimizing PoseGraph having 4 nodes and 3 edges.
[Open3D DEBUG] Line process weight : 2.575567
[Open3D DEBUG] [Initial     ] residual : 1.784097e-24, lambda : 1.239019e+03
[Open3D DEBUG] Maximum coefficient of right term < 1.000000e-06
[Open3D DEBUG] CompensateReferencePoseGraphNode : reference : 0


In [None]:
# Bring every cloud into the common frame given by the optimized node poses. Node indices
# are the cloud ids (build_pose_graph uses the ids directly as node ids), and iterating over
# the nodes also covers clouds that never appear as a registration source.
transformed_clouds: list[PointCloud] = list()

for identifier, node in enumerate(optimized_graph.nodes):
    cloud: PointCloud = loaders[identifier]().unwrap()
    cloud.transform(node.pose)
    transformed_clouds.append(cloud)

vis.draw_geometries(transformed_clouds)

### TODO: Draw final registration results