### Import registration modules

In [1]:
import copy

from functools import partial
from pathlib import Path
from time import time

from collections.abc import Callable

import numpy as np

import open3d
import open3d.geometry as geom
import open3d.pipelines.registration as reg
import open3d.visualization as vis
import open3d.utility as util

from result import Ok, Err, Result

from benthoscan.io import read_toml
from benthoscan.runtime import Environment, load_environment

from benthoscan.registration import PointCloud, RegistrationResult
from benthoscan.registration import PointCloudLoader, read_point_cloud
from benthoscan.registration import MultiTargetIndex, generate_cascade_indices

from benthoscan.registration import (
    downsample_point_cloud,
    estimate_point_cloud_normals,
)

from benthoscan.registration import (
    PointCloudProcessor,
    GlobalRegistrator,
    IncrementalRegistrator,
)
from benthoscan.registration import (
    build_point_cloud_processor,
    build_feature_registrator,
    build_icp_registrator,
)

from benthoscan.registration import (
    register_icp,
    register_colored_icp,
    build_pose_graph,
    optimize_pose_graph,
)

from benthoscan.spatial import decompose_transformation

from benthoscan.visualization import (
    visualize_registration,
    create_subplots,
    trace_registration_result,
)

from benthoscan.utils.log import logger

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


Can't use license /var/tmp/agisoft/licensing/licenses/metashape-pro.lic: License expired
No license server found


### Load environment and configure data loaders

In [2]:
environment: Environment = load_environment()

# Directory holding the cached point clouds. Resolved relative to the user's home directory
# instead of a hardcoded /home/<user> path so the notebook runs for other users too.
DATA_DIR: Path = Path.home() / "dev" / "benthoscan" / ".cache"

# Point cloud files keyed by the integer identifier used throughout the notebook.
point_cloud_files: dict[int, Path] = {
    0: DATA_DIR / "qdc5ghs3_20100430_024508.ply",
    1: DATA_DIR / "qdc5ghs3_20120501_033336.ply",
    2: DATA_DIR / "qdc5ghs3_20130405_103429.ply",
    3: DATA_DIR / "qdc5ghs3_20210315_230947.ply",
}

# Report missing files up front instead of failing later inside a registration cell.
for missing_path in (path for path in point_cloud_files.values() if not path.exists()):
    logger.error(f"point cloud file not found: {missing_path}")

# Lazy loaders: each entry reads its point cloud only when called.
loaders: dict[int, PointCloudLoader] = {
    key: partial(read_point_cloud, path=path) for key, path in point_cloud_files.items()
}

count = len(loaders)
if count < 2:
    logger.error(f"invalid number of point clouds for registration: {count}")

### Configure preprocessor and registrator

In [3]:
def registration_worker(
    source_loader: PointCloudLoader,
    target_loader: PointCloudLoader,
    preprocessor: Callable | list[Callable] | None,
    registrator: Callable,
) -> RegistrationResult:
    """Load, preprocess and register a single source/target point cloud pair.

    Args:
        source_loader: Zero-argument callable whose return value's ``unwrap()`` yields the
            source point cloud.
        target_loader: Same as ``source_loader``, for the target point cloud.
        preprocessor: A single processing step, a sequence of steps, or ``None`` for no
            preprocessing. Each step is called with the cloud as its only positional argument
            and must return the processed cloud. Steps run in order, first over the source and
            then over the target.
        registrator: Callable invoked as ``registrator(source=..., target=...)``.

    Returns:
        Whatever ``registrator`` returns for the preprocessed pair.

    Raises:
        Whatever ``unwrap()`` raises when a loader reports an error.
    """
    # Normalize the argument into an ordered list of steps. The previous implementation ignored
    # this parameter and silently used the global ``preprocessors`` list instead.
    if preprocessor is None:
        steps: list[Callable] = []
    elif callable(preprocessor):
        steps = [preprocessor]
    else:
        steps = list(preprocessor)

    preprocessed_source: PointCloud = source_loader().unwrap()
    preprocessed_target: PointCloud = target_loader().unwrap()

    for step in steps:
        preprocessed_source = step(preprocessed_source)

    for step in steps:
        preprocessed_target = step(preprocessed_target)

    result: RegistrationResult = registrator(
        source=preprocessed_source,
        target=preprocessed_target,
    )

    return result

### Registration logging helpers

In [4]:
def log_registration(source: int, target: int, result: RegistrationResult) -> None:
    """Log the quality metrics and decomposed transformation of one registration.

    Args:
        source: Identifier of the source point cloud (display only).
        target: Identifier of the target point cloud (display only).
        result: Registration result; its ``transformation``, ``correspondence_set``,
            ``fitness`` and ``inlier_rmse`` attributes are read.
    """
    scale, rotation, translation = decompose_transformation(result.transformation)

    def report_lines():
        # A generator keeps each message computed right before it is logged.
        yield ""
        yield f"Source:       {source}"
        yield f"Target:       {target}"
        yield f"Corresp.:     {len(result.correspondence_set)}"
        yield f"Fitness:      {result.fitness}"
        yield f"Inlier RMSE:  {result.inlier_rmse}"
        yield f"Trans. scale:    {scale}"
        yield f"Trans. trans.:   {translation}"
        yield f"Trans. rot.:     {rotation}"
        yield ""

    for line in report_lines():
        logger.info(line)
    logger.info("")

In [5]:
CONFIG_FILE: Path = Path("/home/martin/dev/benthoscan/config/registration.toml")

config: dict = read_toml(CONFIG_FILE).unwrap()


registrators: dict[str, list] = {
    "global": list(),
    "refinement": list(),
}

for module_config in config["modules"]:

    module_name: str = module_config["name"]
    module_type: str = module_config["type"]

    match module_type:
        case "feature":
            registrator: GlobalRegistrator = build_feature_registrator(
                module_config["registrator"],
            ).unwrap()
            registrators["global"].append(registrator)

        case "icp":
            registrator: IncrementalRegistrator = build_icp_registrator(
                module_config["registrator"],
            ).unwrap()
            registrators["refinement"].append(registrator)
        case other:
            logger.error(f"unknown module type: {module_type}")


for registrator in registrators:
    logger.info(f"Registrator: {registrator}")

# TODO: Create preprocessors
"""
preprocessors: list[PointCloudProcessor] = [
    build_point_cloud_processor(item).unwrap()
    for item in module_config["preprocessors"]
]
"""

# TODO: Create registration pipeline

[32m2024-07-29 18:33:53.426[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m33[0m - [1mRegistrator: global[0m
[32m2024-07-29 18:33:53.427[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m33[0m - [1mRegistrator: refinement[0m


'\npreprocessors: list[PointCloudProcessor] = [\n    build_point_cloud_processor(item).unwrap()\n    for item in module_config["preprocessors"]\n]\n'

### Debug builder functionality

In [6]:
# Point cloud pair used to debug the configured registrators.
debug_source: int = 0
debug_target: int = 3

# Preprocessing applied to every cloud before registration: downsample, then estimate normals.
preprocessors: list[Callable] = [
    partial(downsample_point_cloud, spacing=0.20),
    partial(estimate_point_cloud_normals, radius=0.40),
]

source_cloud: PointCloud = loaders[debug_source]().unwrap()
target_cloud: PointCloud = loaders[debug_target]().unwrap()

for preprocessor in preprocessors:
    source_cloud: PointCloud = preprocessor(cloud=source_cloud)
    target_cloud: PointCloud = preprocessor(cloud=target_cloud)


# Global registration: estimates a transformation without an initial guess.
assert registrators["global"], "no global registrator configured"
global_result: RegistrationResult = registrators["global"][0](
    source=source_cloud,
    target=target_cloud,
)

log_registration(source=debug_source, target=debug_target, result=global_result)


# Refinement starting from the global estimate; `result` holds the refined registration.
assert registrators["refinement"], "no refinement registrator configured"
result: RegistrationResult = registrators["refinement"][0](
    source=source_cloud,
    target=target_cloud,
    transformation=global_result.transformation,
)

log_registration(source=debug_source, target=debug_target, result=result)

[32m2024-07-29 16:49:09.180[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m6[0m - [1m[0m
[32m2024-07-29 16:49:09.181[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m7[0m - [1mSource:       0[0m
[32m2024-07-29 16:49:09.182[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m8[0m - [1mTarget:       3[0m
[32m2024-07-29 16:49:09.183[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m9[0m - [1mCorresp.:     12476[0m
[32m2024-07-29 16:49:09.185[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m10[0m - [1mFitness:      0.42510562900368[0m
[32m2024-07-29 16:49:09.186[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m11[0m - [1mInlier RMSE:  0.09300614680802845[0m
[32m2024-07-29 16:49:09.187[0m | [1mINFO    [0m | [36m__main__[0m:[36mlog_registration[0m:[36m12[0m - [1mTrans. scale:    0.9948146923102374[0m
[32m2024-

### Execute several registration runs for a single case

In [None]:
# Select some point clouds to tune the registration - 0-3 is the really hard case, but 1-3 and 2-3 are also challenging
test_source: int = 0
test_target: int = 3
test_runs: int = 10

test_results: list[RegistrationResult] = list()

source_loader: PointCloudLoader = loaders[test_source]
target_loader: PointCloudLoader = loaders[test_target]

# Repeat the global registration to inspect run-to-run variation of the RANSAC-based matcher.
assert registrators["global"], "no global registrator configured"
test_registrator: Callable = registrators["global"][0]

for run in range(test_runs):

    start: float = time()

    result: RegistrationResult = registration_worker(
        source_loader=source_loader,
        target_loader=target_loader,
        preprocessor=preprocessors,
        registrator=test_registrator,
    )

    end: float = time()

    logger.info(
        f"RANSAC run {run+1}/{test_runs} - Elapsed time {end - start:.3f} seconds"
    )

    log_registration(test_source, test_target, result)

    test_results.append(result)

### Plot registration results of the test runs

In [None]:
import plotly.express as px
import plotly.graph_objects as go


# Subplot position (row, col) for each trace produced by trace_registration_result.
TRACE_PANELS: dict[str, tuple[int, int]] = {
    "fitness": (1, 1),
    "rmse": (1, 2),
    "correspondences": (1, 3),
    "scale": (2, 1),
    "rotation": (2, 2),
    "translation": (2, 3),
}

figure: go.Figure = create_subplots(rows=2, cols=3)


# NOTE: Alternative color sequences
colors: list[str] = px.colors.sequential.Plasma_r
# colors: list[str] = px.colors.qualitative.G10

for index, result in enumerate(test_results):

    # Cycle through the palette. Zipping with a fixed-length palette silently dropped every
    # run beyond the palette length from the figure.
    color: str = colors[index % len(colors)]

    traces: dict = trace_registration_result(
        result,
        name=f"Regi. {index}",
        legendgroup=index,
        color=color,
    )

    for trace_name, (row, col) in TRACE_PANELS.items():
        figure.add_trace(traces[trace_name], row=row, col=col)


figure.update_layout(height=800, width=1000, title_text="Registration Results")
figure.show()


visualize_test: bool = False

if visualize_test:

    source_cloud: PointCloud = loaders[test_source]().unwrap()
    target_cloud: PointCloud = loaders[test_target]().unwrap()

    for index, result in enumerate(test_results):
        visualize_registration(
            source=source_cloud,
            target=target_cloud,
            transformation=result.transformation,
            title=f"Test case: {test_source}, {test_target}, {index}",
        )

### Generate indices and perform registration

In [None]:
indices: list[MultiTargetIndex] = generate_cascade_indices(len(loaders))

# Registration results per stage: result_storage[stage][source_id][target_id] -> result.
result_storage = {stage: dict() for stage in ("feature_matching", "incremental_coarse")}

In [None]:
# Global (feature-based) registration for every source/target pair of the cascade. The
# preprocessing list and registrator are passed explicitly. Previously the cell relied on
# leaked loop variables: the last preprocessing step only, and the stage-name string
# "refinement" in place of a registrator.
assert registrators["global"], "no global registrator configured"
feature_registrator: Callable = registrators["global"][0]

for index in indices:
    source = index.source

    results: dict[int, RegistrationResult] = dict()

    for target in index.targets:

        result: RegistrationResult = registration_worker(
            source_loader=loaders[source],
            target_loader=loaders[target],
            preprocessor=preprocessors,
            registrator=feature_registrator,
        )

        log_registration(source, target, result)

        results[target] = result

    result_storage["feature_matching"][source] = results

### Draw registration and plot results
- TODO: Draw registered point clouds
- TODO: Plot point clouds, correspondences, and error distribution

In [None]:
visualize_global_results: bool = False

if visualize_global_results:
    # Show each feature-matching registration; the source cloud is loaded once per source.
    feature_results = result_storage["feature_matching"]
    for source_id in feature_results:
        source_view: PointCloud = loaders[source_id]().unwrap()
        for target_id, registration in feature_results[source_id].items():
            visualize_registration(
                source=source_view,
                target=loaders[target_id]().unwrap(),
                transformation=registration.transformation,
                title=f"Registration - source: {source_id}, target: {target_id}",
            )

### Perform coarse incremental registration

In [None]:
# Coarse ICP settings: downsampling voxel size, correspondence distance threshold, and the
# transformation estimator (point-to-plane with a robust Tukey loss).
# NOTE(review): point-to-plane ICP needs normals on the target cloud, but icp_preprocessor only
# downsamples. Confirm the loaded clouds already carry normals, or add a normal-estimation step.
icp_coarse_parameters: dict = {
    "voxel": 0.08,
    "distance": 0.04,
    "estimator": reg.TransformationEstimationPointToPlane(reg.TukeyLoss(k=1.00)),
}

icp_preprocessor: Callable[[PointCloud], PointCloud] = partial(
    downsample_point_cloud,
    spacing=icp_coarse_parameters["voxel"],
)

# `register_point_cloud_icp` is never imported in this notebook (NameError); the imported
# ICP helper is `register_icp`.
# NOTE(review): confirm register_icp accepts `distance_threshold` and `distance_measure`.
icp_registrator = partial(
    register_icp,
    distance_threshold=icp_coarse_parameters["distance"],
    distance_measure=icp_coarse_parameters["estimator"],
)

In [None]:
result_storage["incremental_coarse"] = dict()

for index in indices:
    source = index.source

    results: dict[int, RegistrationResult] = dict()

    for target in index.targets:

        preliminary_result = result_storage["feature_matching"][source][target]

        source_cloud: PointCloud = loaders[source]().unwrap()
        target_cloud: PointCloud = loaders[target]().unwrap()

        source_prepped: PointCloud = icp_preprocessor(source_cloud)
        target_prepped: PointCloud = icp_preprocessor(target_cloud)

        refined_result: RegistrationResult = icp_registrator(
            source=source_prepped,
            target=target_prepped,
            transformation=preliminary_result.transformation,
        )

        logger.info(f"ICP - source {source}, target {target}")

        scale, rotation, translation = decompose_transformation(
            preliminary_result.transformation
        )
        logger.info(f" - Initial: {translation}")

        scale, rotation, translation = decompose_transformation(
            refined_result.transformation
        )
        logger.info(f" - Refined: {translation}")

        results[target] = refined_result

    result_storage["incremental_coarse"][source] = results

In [None]:
visualize_incremental_results: bool = False

if visualize_incremental_results:
    # Show each coarse-ICP registration; the source cloud is loaded once per source.
    coarse_results = result_storage["incremental_coarse"]
    for source_id in coarse_results:
        source_view: PointCloud = loaders[source_id]().unwrap()
        for target_id, registration in coarse_results[source_id].items():
            visualize_registration(
                source=source_view,
                target=loaders[target_id]().unwrap(),
                transformation=registration.transformation,
                title=f"Registration - source: {source_id}, target: {target_id}",
            )

### TODO: Perform fine / colored registration

In [None]:
# TODO: Add colored ICP

### TODO: Perform multiway registration

In [None]:
initial_graph: reg.PoseGraph = build_pose_graph(result_storage["incremental_coarse"])

# Options forwarded to optimize_pose_graph.
# NOTE(review): `reference_node` presumably maps to Open3D's
# GlobalOptimizationOption.reference_node (the node held fixed); confirm in the wrapper.
pose_graph_options: dict = dict(
    correspondence_distance=0.075,
    prune_threshold=0.25,
    preference_loop_closure=1.0,
    reference_node=0,
)

# Debug verbosity so Open3D prints the optimizer's progress.
with util.VerbosityContextManager(util.VerbosityLevel.Debug):
    optimized_graph: reg.PoseGraph = optimize_pose_graph(initial_graph, **pose_graph_options)

### Draw final registration results

In [None]:
transformed_clouds: list[PointCloud] = list()

# Transform every node of the optimized pose graph into the common frame. The keys of
# result_storage["incremental_coarse"] only cover cascade sources, so a cloud that never acts
# as a source would be missing from the drawing.
# NOTE(review): assumes node i of the pose graph corresponds to loaders[i] (the original
# indexing assumed the same). Confirm against build_pose_graph.
for identifier, node in enumerate(optimized_graph.nodes):

    cloud: PointCloud = loaders[identifier]().unwrap()
    cloud.transform(node.pose)
    transformed_clouds.append(cloud)

vis.draw_geometries(transformed_clouds)