### Import dependencies and registration modules

In [5]:
import copy
import os

from functools import partial
from pathlib import Path
from time import time
from typing import Callable

import numpy as np
import open3d.geometry as geom
import open3d.pipelines.registration as reg
import open3d.visualization as vis
import open3d.utility as util

from result import Ok

from benthoscan.runtime import Environment, load_environment

from benthoscan.registration import (
    MultiTargetIndex,
    PointCloud,
    PointCloudLoader,
    read_point_cloud,
    downsample_point_cloud,
    estimate_point_cloud_normals,
    generate_cascade_indices,
)

from benthoscan.registration import (
    ExtendedRegistrationResult,
    register_point_cloud_fphp_ransac,
    register_point_cloud_icp,
    build_pose_graph,
    optimize_pose_graph,
)

from benthoscan.visualization import (
    visualize_registration,
    visualize_registration_batch,
)

# from benthoscan.tasks.registration import RegistrationTaskConfig

from benthoscan.utils.log import logger

### Load environment and configure data loaders

In [6]:
environment: Environment = load_environment()

# Directory holding the input .ply point clouds. Set the BENTHOSCAN_DATA_DIR
# environment variable to point elsewhere instead of editing the notebook; the
# default keeps the original local cache location.
DATA_DIR: Path = Path(
    os.environ.get("BENTHOSCAN_DATA_DIR", "/home/martin/dev/benthoscan/.cache")
)

# Point clouds to register, keyed by integer identifier. These keys are the
# source/target identifiers used by the registration cells below.
point_cloud_files: dict[int, Path] = {
    0: DATA_DIR / "qdc5ghs3_20100430_024508.ply",
    1: DATA_DIR / "qdc5ghs3_20120501_033336.ply",
    2: DATA_DIR / "qdc5ghs3_20130405_103429.ply",
    3: DATA_DIR / "qdc5ghs3_20210315_230947.ply",
}

# Lazy loaders: calling one reads its file via `read_point_cloud`; callers
# `.unwrap()` the returned result.
loaders: dict[int, PointCloudLoader] = {
    key: partial(read_point_cloud, path=path) for key, path in point_cloud_files.items()
}

# Registration needs at least one source/target pair.
count: int = len(loaders)
if count < 2:
    logger.error(f"invalid number of point clouds for registration: {count}")

### Configure preprocessor and registrator

In [15]:
# NOTE: Parameters - move to config
# Tuning parameters for the feature-based (FPFH + RANSAC) registration cell below.
# Distances are presumably in the point clouds' native units (metres?) - TODO confirm.
# Inline "NOTE: <value>" comments record another value for the parameter -
# presumably a previously tried / reference setting; confirm with the author.

# Voxel spacing used to downsample clouds before feature extraction.
VOXEL_SIZE: float = 0.20

# FPFH feature parameters: search radius and maximum neighbour count.
FEATURE_RADIUS: float = 2.00  # NOTE: 2.00
FEATURE_NEIGHBOURS: int = 500

# Validation and estimation
# Passed as `distance_threshold` to the RANSAC registrator.
CORRESPONDENCE_DISTANCE: float = 0.20
# Passed as `edge_check`; presumably the edge-length similarity ratio used to
# reject inconsistent correspondence sets - confirm.
EDGE_LENGTH: float = 0.90 # NOTE: 0.95
# Passed as `normal_check`; presumably an angle in degrees - confirm how
# register_point_cloud_fphp_ransac interprets it.
NORMAL_ANGLE: float = 5.0

# Passed as `scaling`: whether the estimated transformation may include a scale.
ESTIMATE_SCALE: bool = True

# RANSAC parameters: correspondences sampled per hypothesis, iteration cap and
# early-termination confidence.
SAMPLE_COUNT: int = 3
MAX_ITERATIONS: int = 200000
CONFIDENCE: float = 0.99999  # NOTE: 0.999


def downsample_and_estimate_normals(cloud: PointCloud, voxel_size: float) -> PointCloud:
    """Voxel-downsample a point cloud, then estimate normals on the result.

    Args:
        cloud: Input point cloud.
        voxel_size: Spacing forwarded to ``downsample_point_cloud``.

    Returns:
        The downsampled cloud after ``estimate_point_cloud_normals``.
    """
    return estimate_point_cloud_normals(
        downsample_point_cloud(cloud, spacing=voxel_size)
    )


def registration_worker(
    source_loader: PointCloudLoader,
    target_loader: PointCloudLoader,
    preprocessor: Callable[[PointCloud], PointCloud],
    registrator: Callable[[PointCloud, PointCloud], ExtendedRegistrationResult],
) -> ExtendedRegistrationResult:
    """Load, preprocess and register a single source/target point cloud pair.

    Args:
        source_loader: Zero-argument callable returning a result wrapper whose
            ``unwrap()`` yields the source point cloud.
        target_loader: Same as ``source_loader``, for the target point cloud.
        preprocessor: Applied to each loaded cloud before registration.
        registrator: Called with the keyword arguments ``source`` and ``target``
            (the preprocessed clouds), so it must accept those parameter names.

    Returns:
        Whatever ``registrator`` returns for the preprocessed pair.

    Raises:
        Propagates the exception raised by ``unwrap()`` when a loader returns an
        error result (presumably ``result.UnwrapError`` for an ``Err``).
    """
    source_cloud: PointCloud = source_loader().unwrap()
    target_cloud: PointCloud = target_loader().unwrap()

    preprocessed_source: PointCloud = preprocessor(source_cloud)
    preprocessed_target: PointCloud = preprocessor(target_cloud)

    result: ExtendedRegistrationResult = registrator(
        source=preprocessed_source, target=preprocessed_target
    )

    return result


# Preprocessing for feature-based registration: downsample to VOXEL_SIZE, then
# estimate normals.
preprocessor: Callable[[PointCloud], PointCloud] = partial(
    downsample_and_estimate_normals, voxel_size=VOXEL_SIZE
)

# Feature-based (FPFH + RANSAC) registration with every tuning parameter bound,
# so callers only pass the `source` and `target` keyword arguments.
registrator: Callable[..., ExtendedRegistrationResult] = partial(
    register_point_cloud_fphp_ransac,
    distance_threshold=CORRESPONDENCE_DISTANCE,
    feature_radius=FEATURE_RADIUS,
    feature_neighbours=FEATURE_NEIGHBOURS,
    max_iterations=MAX_ITERATIONS,
    confidence=CONFIDENCE,
    sample_count=SAMPLE_COUNT,
    edge_check=EDGE_LENGTH,
    normal_check=NORMAL_ANGLE,
    scaling=ESTIMATE_SCALE,
)

In [20]:
def decompose_transformation(transformation: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
    """Decompose a 4x4 homogeneous similarity transform into scale, rotation and translation.

    The upper-left 3x3 block is treated as a uniformly scaled rotation ``s * R``
    (as produced by registration with scale estimation); a rigid-body transform
    is the special case ``s == 1``.

    Args:
        transformation: 4x4 homogeneous transformation matrix (any array-like).

    Returns:
        ``(scale, rotation, translation)``: the scale as a float, the 3x3
        rotation with the scale divided out, and the translation 3-vector taken
        from the last column.

    Raises:
        ValueError: if the input is not 4x4, or its linear part has a zero or
            non-finite scale.
    """
    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    matrix = np.asarray(transformation, dtype=float)
    if matrix.shape != (4, 4):
        raise ValueError(
            f"transformation is not a 4x4 homogeneous transformation: shape {matrix.shape}"
        )

    scaled_rotation: np.ndarray = matrix[:3, :3]
    translation: np.ndarray = matrix[:3, 3]

    # Every row of s * R has norm s; averaging all rows is less sensitive to
    # numerical noise than reading only the first row.
    scale = float(np.mean(np.linalg.norm(scaled_rotation, axis=1)))
    if not np.isfinite(scale) or scale <= 0.0:
        raise ValueError(f"transformation has a degenerate linear part (scale {scale})")

    rotation: np.ndarray = scaled_rotation / scale

    return scale, rotation, translation

def log_registration(source: int, target: int, result: ExtendedRegistrationResult) -> None:
    """Log a human-readable summary of one pairwise registration result.

    At INFO level, framed by blank lines, logs:
    - the source and target identifiers;
    - the number of correspondences, the fitness and the inlier RMSE;
    - the scale, translation and rotation obtained by decomposing
      ``result.transformation`` with ``decompose_transformation``.

    Args:
        source: Identifier of the source point cloud (a key of ``loaders``).
        target: Identifier of the target point cloud.
        result: Registration result; must expose ``transformation`` (4x4),
            ``correspondence_set``, ``fitness`` and ``inlier_rmse``.
    """

    scale, rotation, translation = decompose_transformation(result.transformation)
    
    logger.info("")
    logger.info(f"Source:       {source}")
    logger.info(f"Target:       {target}")
    logger.info(f"Corresp.:     {len(result.correspondence_set)}")
    logger.info(f"Fitness:      {result.fitness}")
    logger.info(f"Inlier RMSE:  {result.inlier_rmse}")
    logger.info(f"Trans. scale:    {scale}")
    logger.info(f"Trans. trans.:   {translation}")
    logger.info(f"Trans. rot.:     {rotation}")
    logger.info("")

### Test registrator parameters on a single case

In [21]:
# Select a pair of point clouds to tune the registration on - 0-3 is the really
# hard case, but 1-3 and 2-3 are also challenging.
test_source: int = 0
test_target: int = 3

# The registration is repeated, presumably to gauge the run-to-run variation of
# the randomized RANSAC step. Only the last result is kept in `test_result`.
TEST_REPETITIONS: int = 10

# Load the pair from disk once; the loaders passed below just return the
# in-memory clouds.
with util.VerbosityContextManager(util.VerbosityLevel.Debug):
    source_cloud: PointCloud = loaders[test_source]().unwrap()
    target_cloud: PointCloud = loaders[test_target]().unwrap()

for _ in range(TEST_REPETITIONS):
    test_result: ExtendedRegistrationResult = registration_worker(
        source_loader=lambda: Ok(source_cloud),
        target_loader=lambda: Ok(target_cloud),
        preprocessor=preprocessor,
        registrator=registrator,
    )

    # Use the selected identifiers (not hard-coded ones) so the log stays
    # correct when test_source / test_target are changed.
    log_registration(test_source, test_target, test_result)

[Open3D DEBUG] Format auto File /home/martin/dev/benthoscan/.cache/qdc5ghs3_20100430_024508.ply
[Open3D DEBUG] Read geometry::PointCloud: 10457378 vertices.
[Open3D DEBUG] Format auto File /home/martin/dev/benthoscan/.cache/qdc5ghs3_20210315_230947.ply
[Open3D DEBUG] Read geometry::PointCloud: 15048357 vertices.


[32m2024-07-05 22:28:23.389[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m19[0m - [1m[0m
[32m2024-07-05 22:28:23.715[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m20[0m - [1mSource:       0[0m
[32m2024-07-05 22:28:23.742[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m21[0m - [1mTarget:       3[0m
[32m2024-07-05 22:28:23.792[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m22[0m - [1mCorresp.:     13384[0m
[32m2024-07-05 22:28:23.817[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m23[0m - [1mFitness:      0.45604470492026716[0m
[32m2024-07-05 22:28:23.820[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m24[0m - [1mInlier RMSE:  0.10000718211675465[0m
[32m2024-07-05 22:28:23.843[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m25[0m - [1mTrans. scale:    0.983397547097356[0m
[32

In [22]:
# Set to False to skip visualizing the test pair.
visualize_test: bool = True

if visualize_test:
    test_title: str = f"Test case: {test_source}, {test_target}"
    visualize_registration(
        title=test_title,
        transformation=test_result.transformation,
        target=target_cloud,
        source=source_cloud,
    )

### Generate registration pair indices and perform feature-based registration

In [24]:
# One entry per source cloud, each listing the target clouds it is registered against.
indices: list[MultiTargetIndex] = generate_cascade_indices(len(loaders))

# Registration results per stage, nested as result_storage[stage][source][target].
result_storage: dict[str, dict] = {
    "feature_matching": {},
    "incremental_coarse": {},
}

In [25]:
# Feature-based registration of every (source, target) pair; results are stored
# under result_storage["feature_matching"][source][target].
for cascade in indices:
    source_id = cascade.source
    pairwise_results: dict[int, ExtendedRegistrationResult] = {}

    for target_id in cascade.targets:
        pair_result: ExtendedRegistrationResult = registration_worker(
            source_loader=loaders[source_id],
            target_loader=loaders[target_id],
            preprocessor=preprocessor,
            registrator=registrator,
        )
        log_registration(source_id, target_id, pair_result)
        pairwise_results[target_id] = pair_result

    result_storage["feature_matching"][source_id] = pairwise_results

[32m2024-07-05 22:35:02.643[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m19[0m - [1m[0m
[32m2024-07-05 22:35:02.651[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m20[0m - [1mSource:       0[0m
[32m2024-07-05 22:35:02.652[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m21[0m - [1mTarget:       1[0m
[32m2024-07-05 22:35:02.652[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m22[0m - [1mCorresp.:     23990[0m
[32m2024-07-05 22:35:02.653[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m23[0m - [1mFitness:      0.8174321929944118[0m
[32m2024-07-05 22:35:02.653[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m24[0m - [1mInlier RMSE:  0.10092196658495878[0m
[32m2024-07-05 22:35:02.654[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m25[0m - [1mTrans. scale:    0.9653804230033952[0m
[32

### Draw registration and plot results
- TODO: Draw registered point clouds
- TODO: Plot point clouds, correspondences, and error distribution

In [26]:
# Visualize all feature-matching results stored above.
feature_matching_results = result_storage["feature_matching"]
visualize_registration_batch(
    loaders=loaders,
    storage=feature_matching_results,
)

### Perform coarse incremental registration

In [27]:
# Coarse ICP refinement settings: downsampling voxel size, correspondence
# distance threshold, and a point-to-plane estimator with a robust Tukey loss.
# NOTE(review): point-to-plane ICP needs normals on the target, but this
# preprocessor only downsamples - presumably the .ply files carry normals or
# register_point_cloud_icp estimates them; confirm.
loss = reg.TukeyLoss(k=0.20)
icp_coarse_parameters = {
    "voxel": 0.05,
    "distance": 0.075,
    "estimator": reg.TransformationEstimationPointToPlane(loss),
}

icp_preprocessor: Callable[[PointCloud], PointCloud] = partial(
    downsample_point_cloud, spacing=icp_coarse_parameters["voxel"]
)

icp_registrator = partial(
    register_point_cloud_icp,
    distance_measure=icp_coarse_parameters["estimator"],
    distance_threshold=icp_coarse_parameters["distance"],
)

In [28]:
# Refine every feature-matching result with ICP, using it as the initial guess.
result_storage["incremental_coarse"] = dict()


for index in indices:
    source = index.source

    results: dict[int, ExtendedRegistrationResult] = dict()

    for target in index.targets:

        initial_transformation: np.ndarray = result_storage["feature_matching"][source][
            target
        ].transformation

        # Distinct names: reusing `source_cloud` / `target_cloud` would overwrite
        # the test pair that the test-visualization cell displays. Only the
        # downsampled clouds are kept, so the full-resolution clouds can be freed.
        icp_source_cloud: PointCloud = icp_preprocessor(loaders[source]().unwrap())
        icp_target_cloud: PointCloud = icp_preprocessor(loaders[target]().unwrap())

        result: ExtendedRegistrationResult = icp_registrator(
            source=icp_source_cloud,
            target=icp_target_cloud,
            transformation=initial_transformation,
        )

        logger.info("")
        logger.info(f"Source {source}, target {target}")
        logger.info(f"Initial transformation: {initial_transformation}")
        logger.info(f"Refined transformation: {result.transformation}")
        logger.info("")

        results[target] = result

    result_storage["incremental_coarse"][source] = results

[32m2024-07-05 22:48:50.685[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1m[0m
[32m2024-07-05 22:48:50.686[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1mSource 0, target 1[0m
[32m2024-07-05 22:48:50.694[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m27[0m - [1mInitial transformation: [[ 9.64894350e-01  3.06258904e-02 -5.56908680e-04 -1.20365865e+01]
 [-3.06253482e-02  9.64894084e-01  9.24799774e-04 -7.77036298e+00]
 [ 5.85966624e-04 -9.06666981e-04  9.65379819e-01 -1.60739702e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]][0m
[32m2024-07-05 22:48:50.696[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mRefined transformation: [[ 9.64963789e-01  2.83528174e-02  6.04743987e-04 -1.19703909e+01]
 [-2.83521298e-02  9.64963395e-01 -1.07871023e-03 -7.88934414e+00]
 [-6.36164014e-04  1.06048404e-03  9.65379631e-01 -1.61307141e+00]
 [ 0.00000000e+00 

In [None]:
# Distinct colours for source and target (presumably RGB in [0, 1]).
coarse_source_color: list[float] = [0.60, 0.20, 0.20]
coarse_target_color: list[float] = [0.20, 0.20, 0.60]

visualize_registration_batch(
    loaders=loaders,
    storage=result_storage["incremental_coarse"],
    target_color=coarse_target_color,
    source_color=coarse_source_color,
)

### TODO: Perform fine / colored registration

In [None]:
# TODO: Add colored ICP

### Perform multiway registration (pose graph optimization) - TODO: tune; the optimizer log below reports 0 valid edges

In [35]:
# Build a pose graph from the ICP-refined pairwise results and optimize it globally.
initial_graph: reg.PoseGraph = build_pose_graph(result_storage["incremental_coarse"])

with util.VerbosityContextManager(util.VerbosityLevel.Debug):
    optimized_graph: reg.PoseGraph = optimize_pose_graph(
        initial_graph,
        # Same correspondence distance (0.075) as the coarse ICP refinement.
        correspondence_distance=icp_coarse_parameters["distance"],
        prune_threshold=0.25,
        preference_loop_closure=100.0,
        reference_node=0,
    )

# Log every optimized pose. Iterating over the graph nodes, rather than the
# result-storage keys (which only hold clouds that act as a registration source),
# reports every node in the graph.
for node in optimized_graph.nodes:
    logger.info(node.pose)

[32m2024-07-05 22:54:55.846[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1m[[ 1.00000000e+00  6.84616034e-21 -2.71050543e-20  0.00000000e+00]
 [ 3.41253362e-20  1.00000000e+00 -3.38813179e-21  5.42101086e-20]
 [ 5.42101086e-20 -3.38813179e-21  1.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]][0m
[32m2024-07-05 22:54:55.848[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1m[[ 1.03541402e+00 -3.04220709e-02 -6.82600511e-04  1.21531997e+01]
 [ 3.04228087e-02  1.03541360e+00  1.13796666e-03  8.53474548e+00]
 [ 6.48884843e-04 -1.15752323e-03  1.03586022e+00  1.66955158e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  1.00000000e+00]][0m
[32m2024-07-05 22:54:55.848[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1m[[ 1.02600339e+00 -3.31717991e-02  8.92356667e-04  1.77344770e+01]
 [ 3.31714273e-02  1.02600369e+00  4.38886656e-04  4.23181178e+

[Open3D DEBUG] Validating PoseGraph - finished.
[Open3D DEBUG] [GlobalOptimizationLM] Optimizing PoseGraph having 4 nodes and 6 edges.
[Open3D DEBUG] Line process weight : 166235.718750
[Open3D DEBUG] [Initial     ] residual : 6.773989e+07, lambda : 1.731310e+04
[Open3D DEBUG] [Iteration 00] residual : 4.471957e+05, valid edges : 0, time : 0.000 sec.
[Open3D DEBUG] [Iteration 01] residual : 4.437626e+05, valid edges : 0, time : 0.000 sec.
[Open3D DEBUG] [Iteration 02] residual : 4.425179e+05, valid edges : 0, time : 0.000 sec.
[Open3D DEBUG] [Iteration 03] residual : 4.419195e+05, valid edges : 0, time : 0.000 sec.
[Open3D DEBUG] [Iteration 04] residual : 4.415884e+05, valid edges : 0, time : 0.000 sec.
[Open3D DEBUG] [Iteration 05] residual : 4.413892e+05, valid edges : 0, time : 0.000 sec.
[Open3D DEBUG] [Iteration 06] residual : 4.412625e+05, valid edges : 0, time : 0.000 sec.
[Open3D DEBUG] [Iteration 07] residual : 4.411786e+05, valid edges : 0, time : 0.000 sec.
[Open3D DEBUG] [I

### Draw final registration results

In [36]:
# Load every cloud, move it into the common frame given by its optimized pose and
# draw all of them together. Iterating over the graph nodes (not the result-storage
# keys, which only hold clouds that act as a registration source) includes every
# cloud. Node i is assumed to belong to the cloud with loader key i - confirm
# against build_pose_graph.
transformed_clouds: list[PointCloud] = list()

for identifier, node in enumerate(optimized_graph.nodes):
    cloud: PointCloud = loaders[identifier]().unwrap()
    cloud.transform(node.pose)
    transformed_clouds.append(cloud)

vis.draw_geometries(transformed_clouds)