### Include registration modules

In [1]:
import copy

from functools import partial
from pathlib import Path
from time import time
from typing import Callable

import numpy as np
import open3d.geometry as geom
import open3d.pipelines.registration as regi
import open3d.visualization as vis

from benthoscan.runtime import Environment, load_environment

from benthoscan.spatial import (
    MultiTargetIndex,
    PointCloud,
    PointCloudLoader,
    read_point_cloud,
    downsample_point_cloud,
    estimate_point_cloud_normals,
    generate_cascade_indices,
)

from benthoscan.spatial import (
    ExtendedRegistrationResult,
    register_point_cloud_fphp_fast,
    register_point_cloud_fphp_ransac,
    register_point_cloud_icp,
    register_point_cloud_graph,
)

# from benthoscan.tasks.registration import RegistrationTaskConfig

from benthoscan.utils.log import logger

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


### Load environment and configure data loaders

In [2]:
# Resolve the runtime environment up front.
environment: Environment = load_environment()

# NOTE(review): machine-specific absolute path - consider deriving it from configuration.
DATA_DIR: Path = Path("/home/martin/dev/benthoscan/.cache")

# Point cloud files keyed by consecutive integer indices starting at 0.
point_cloud_files: dict = dict(
    enumerate(
        DATA_DIR / file_name
        for file_name in (
            "qdc5ghs3_20100430_024508.ply",
            "qdc5ghs3_20120501_033336.ply",
            "qdc5ghs3_20130405_103429.ply",
            "qdc5ghs3_20210315_230947.ply",
        )
    )
)

# One deferred reader per file; nothing is read from disk until a loader is called.
loaders: dict = {
    index: partial(read_point_cloud, path=file_path)
    for index, file_path in point_cloud_files.items()
}

# Registration needs at least one source/target pair.
# NOTE(review): this only logs - execution continues even when the check fails.
count = len(loaders)
if count < 2:
    logger.error(f"invalid number of point clouds for registration: {count}")

### Configure preprocessor and registrator

In [33]:
# NOTE: Tuning parameters for this notebook - move to config

# Spacing forwarded to the downsampling step of the preprocessor
VOXEL_SIZE: float = 0.2

# FPFH feature parameters
FEATURE_RADIUS: float = 2.0  # NOTE: 2.00
FEATURE_NEIGHBOURS: int = 500

# Validation and estimation - forwarded to the registrator as
# distance_threshold, edge_check, normal_check and scaling respectively
CORRESPONDENCE_DISTANCE: float = 0.3
EDGE_LENGTH: float = 0.9
NORMAL_ANGLE: float = 5.0  # presumably degrees - TODO confirm against the registrator
ESTIMATE_SCALE: bool = True

# RANSAC parameters
SAMPLE_COUNT: int = 3
MAX_ITERATIONS: int = 200_000
CONFIDENCE: float = 0.999


def downsample_and_estimate_normals(cloud: PointCloud, voxel_size: float) -> PointCloud:
    """Downsample a point cloud and estimate normals on the reduced cloud.

    Args:
        cloud: The point cloud to preprocess.
        voxel_size: Spacing forwarded to the downsampling step.

    Returns:
        The result of normal estimation applied to the downsampled cloud.
    """
    return estimate_point_cloud_normals(
        downsample_point_cloud(cloud, spacing=voxel_size)
    )
    

def registration_worker(
    source_loader: PointCloudLoader,
    target_loader: PointCloudLoader,
    preprocessor: Callable[[PointCloud], PointCloud],
    registrator: Callable[[PointCloud, PointCloud], ExtendedRegistrationResult],
) -> ExtendedRegistrationResult:
    """Load, preprocess, and register a single source/target point cloud pair.

    Args:
        source_loader: Zero-argument callable whose result is unwrapped to
            obtain the source point cloud.
        target_loader: Zero-argument callable whose result is unwrapped to
            obtain the target point cloud.
        preprocessor: Applied independently to the source and the target cloud
            before registration.
        registrator: Invoked with the preprocessed clouds as the keyword
            arguments ``source`` and ``target``.

    Returns:
        The registration result produced by ``registrator``.
    """
    # Clouds are loaded on demand, so each call reads its own inputs.
    source_cloud: PointCloud = source_loader().unwrap()
    target_cloud: PointCloud = target_loader().unwrap()

    preprocessed_source: PointCloud = preprocessor(source_cloud)
    preprocessed_target: PointCloud = preprocessor(target_cloud)

    result: ExtendedRegistrationResult = registrator(
        source=preprocessed_source,
        target=preprocessed_target,
    )

    return result


# Preprocessing step with the voxel size bound, leaving the cloud as the only argument.
preprocessor: Callable[[PointCloud], PointCloud] = partial(
    downsample_and_estimate_normals, voxel_size=VOXEL_SIZE
)

# Feature-based RANSAC registration with every tuning parameter bound,
# leaving the source and target clouds to be supplied per call.
registrator = partial(
    register_point_cloud_fphp_ransac,
    # Correspondence and feature settings
    distance_threshold=CORRESPONDENCE_DISTANCE,
    feature_radius=FEATURE_RADIUS,
    feature_neighbours=FEATURE_NEIGHBOURS,
    # RANSAC settings
    max_iterations=MAX_ITERATIONS,
    confidence=CONFIDENCE,
    sample_count=SAMPLE_COUNT,
    # Validation checks and scale estimation
    edge_check=EDGE_LENGTH,
    normal_check=NORMAL_ANGLE,
    scaling=ESTIMATE_SCALE,
)

### Define visualization helpers

In [22]:
def visualize_registration(
    source: PointCloud,
    target: PointCloud,
    transformation: np.ndarray,
    title: str = "",
) -> None:
    """Draw the transformed source cloud together with the target cloud.

    Both inputs are deep-copied first, so painting and transforming do not
    modify the clouds passed in by the caller.

    Args:
        source: Cloud that is painted in a red tone and moved by ``transformation``.
        target: Cloud that is painted in a blue tone and left in place.
        transformation: Transformation applied to the copy of ``source``
            (presumably a 4x4 homogeneous matrix - TODO confirm).
        title: Window name used for the visualization.
    """
    # Source copy: uniform color, then moved into the target frame.
    moved_source = copy.deepcopy(source)
    moved_source.paint_uniform_color([0.60, 0.20, 0.20])
    moved_source.transform(transformation)

    # Target copy: uniform color only.
    fixed_target = copy.deepcopy(target)
    fixed_target.paint_uniform_color([0.20, 0.20, 0.60])

    vis.draw_geometries(
        geometry_list=[moved_source, fixed_target],
        window_name=title,
        width=1024,
        height=768,
    )

### Test registrator parameters on a single case

In [29]:
# Pick one pair of point clouds for tuning the registration parameters.
# Pair 0-3 is the really hard case, but 1-3 and 2-3 are also challenging.
test_source: int = 0
test_target: int = 3

# Run the full load -> preprocess -> register pipeline on the selected pair.
test_result: ExtendedRegistrationResult = registration_worker(
    source_loader=loaders[test_source],
    target_loader=loaders[test_target],
    preprocessor=preprocessor,
    registrator=registrator,
)

logger.info(test_result)


[32m2024-07-01 21:49:38.360[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1mExtendedRegistrationResult(fitness=0.39904583404327826, inlier_rmse=0.13566774874877197, correspondence_set=std::vector<Eigen::Vector2i> with 11710 elements.
Use numpy.asarray() to access data., transformation=array([[ 9.92644762e-01,  1.21034223e-01, -2.66331413e-03,
         8.40572161e+00],
       [-1.21054184e-01,  9.92601214e-01, -9.41879369e-03,
         1.26083306e+01],
       [ 1.50361247e-03,  9.67192154e-03,  9.99952095e-01,
        -1.58223355e-03],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         1.00000000e+00]]), information=array([[ 2.61508524e+07, -1.20186100e+06,  4.95314185e+06,
         0.00000000e+00,  5.35719405e+05,  1.25158923e+05],
       [-1.20186100e+06,  2.60577623e+07,  5.79393799e+06,
        -5.35719405e+05,  0.00000000e+00, -1.08777432e+05],
       [ 4.95314185e+06,  5.79393799e+06,  3.13759162e+06,
        -1.25158923e+05,  1.087

In [30]:
# Toggle for drawing the single test case.
visualize_test: bool = True

if visualize_test:
    # The clouds are loaded again here, so the drawing shows the loader output
    # rather than the preprocessed clouds used during registration.
    visualize_registration(
        source=loaders[test_source]().unwrap(),
        target=loaders[test_target]().unwrap(),
        transformation=test_result.transformation,
        title=f"Test case: {test_source}, {test_target}",
    )

### Generate indices and perform registration

In [None]:
indices: list[MultiTargetIndex] = generate_cascade_indices(len(loaders))

# Registration results keyed by source index, then by target index.
result_storage: dict[int, dict] = {}

for index in indices:
    source = index.source

    results: dict[int, ExtendedRegistrationResult] = {}

    for target in index.targets:
        # Time the complete worker call: loading, preprocessing and registration.
        start: float = time()

        result: ExtendedRegistrationResult = registration_worker(
            source_loader=loaders[source],
            target_loader=loaders[target],
            preprocessor=preprocessor,
            registrator=registrator,
        )

        end: float = time()

        elapsed: float = end - start

        # Report one summary block per registered pair.
        for line in (
            "",
            "--------------------- FPFH registration --------------------",
            f" - Source, target:        {source}, {target}",
            f" - Elapsed time:          {elapsed}",
            f" - RMSE:                  {result.inlier_rmse}",
            f" - Fitness:               {result.fitness}",
            f" - Correspondences:       {len(result.correspondence_set)}",
            f" - Transformation:        {result.transformation}",
            f" - Information:           {result.information}",
            "------------------------------------------------------------",
            "",
        ):
            logger.info(line)

        results[target] = result

    # Store the per-target results once all targets of this source are done.
    result_storage[source] = results

[32m2024-07-01 22:19:34.304[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1m[0m
[32m2024-07-01 22:19:34.305[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1m--------------------- FPFH registration --------------------[0m
[32m2024-07-01 22:19:34.306[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1m - Source, target:        0, 1[0m
[32m2024-07-01 22:19:34.307[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m27[0m - [1m - Elapsed time:          30.079916954040527[0m
[32m2024-07-01 22:19:34.308[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1m - RMSE:                  0.11663528456029898[0m
[32m2024-07-01 22:19:34.308[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1m - Fitness:               0.8561935593797921[0m
[32m2024-07-01 22:19:34.309[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[

### Draw registration and plot results
- TODO: Draw registered point clouds
- TODO: Plot point clouds, correspondences, and error distribution

In [32]:
def batch_visualize_registration(storage: dict[int, dict]) -> None:
    """Draw every stored registration, one visualization call per source/target pair.

    Point clouds are read through the module-level ``loaders`` mapping, using
    the source and target keys found in ``storage``.

    Args:
        storage: Mapping from source key to a mapping of target key to the
            registration result; each result must expose ``transformation``.
    """
    for source, registrations in storage.items():
        # Load the source once and reuse it for all of its targets.
        source_cloud: PointCloud = loaders[source]().unwrap()

        for target, result in registrations.items():
            visualize_registration(
                source=source_cloud,
                target=loaders[target]().unwrap(),
                transformation=result.transformation,
                title=f"Source: {source}, target: {target}",
            )


batch_visualize_registration(storage=result_storage)

# TODO: Implement draw functions

### TODO: Perform incremental registration

### TODO: Perform full graph-based registration

### TODO: Draw final registration results