### Import registration modules

In [None]:
import copy

from functools import partial
from pathlib import Path
from time import time

from collections.abc import Callable

import numpy as np

import open3d
import open3d.geometry as geom
import open3d.pipelines.registration as reg
import open3d.visualization as vis
import open3d.utility as util

from mynd.io import read_toml
from mynd.runtime import Environment, load_environment

from mynd.registration import PointCloud, RegistrationResult
from mynd.registration import PointCloudLoader, create_point_cloud_loader
from mynd.registration import MultiTargetIndex, generate_cascade_indices

from mynd.registration import (
    downsample_point_cloud,
    estimate_point_cloud_normals,
)

from mynd.registration import (
    PointCloudProcessor,
    GlobalRegistrator,
    IncrementalRegistrator,
)
from mynd.registration import (
    build_point_cloud_processor,
    build_feature_registrator,
    build_icp_registrator,
)

from mynd.registration import (
    register_icp,
    register_colored_icp,
    build_pose_graph,
    optimize_pose_graph,
)

from mynd.registration.pipeline import (
    Module,
    ModuleList,
    apply_registration_modules,
)

from mynd.spatial import decompose_transformation

from mynd.visualization import (
    visualize_registration,
    create_subplots,
    trace_registration_result,
)

from mynd.utils.log import logger
from mynd.utils.result import Ok, Err, Result

### Load environment and configure data loaders

In [None]:
environment: Environment = load_environment()

DATA_DIR: Path = Path("/home/martin/dev/benthoscan/.cache")

# Point cloud files keyed by the integer identifier used throughout the notebook.
point_cloud_files: dict[int, Path] = {
    0: DATA_DIR / "qdc5ghs3_20100430_024508.ply",
    1: DATA_DIR / "qdc5ghs3_20120501_033336.ply",
    2: DATA_DIR / "qdc5ghs3_20130405_103429.ply",
    3: DATA_DIR / "qdc5ghs3_20210315_230947.ply",
}

loaders: dict[int, PointCloudLoader] = {
    key: create_point_cloud_loader(path) for key, path in point_cloud_files.items()
}

# Sanity check: invoke every loader once and log what it produced. A loader
# returns a Result (later cells call `.unwrap()` on it), not a bare PointCloud.
# The Result is logged as-is rather than unwrapped, so a bad file is reported
# without aborting the cell.
for key, loader in loaders.items():
    loaded: Result = loader()
    logger.info(f"key {key}, loader {loader}, loaded {loaded}")


count = len(loaders)
if count < 2:
    logger.error(f"invalid number of point clouds for registration: {count}")

### Logging Helpers

In [None]:
def log_registration(source: int, target: int, result: RegistrationResult) -> None:
    """Log a human-readable summary of one pairwise registration.

    Emits one log line each for the source and target identifiers, the number
    of correspondences, the fitness, the inlier RMSE, and the scale,
    translation and rotation obtained by decomposing
    ``result.transformation``. The summary is framed by empty log lines.

    Args:
        source: Identifier of the source point cloud.
        target: Identifier of the target point cloud.
        result: Registration result to summarize.
    """
    scale, rotation, translation = decompose_transformation(result.transformation)
    correspondence_count: int = len(result.correspondence_set)

    summary: list[str] = [
        "",
        f"Source:       {source}",
        f"Target:       {target}",
        f"Corresp.:     {correspondence_count}",
        f"Fitness:      {result.fitness}",
        f"Inlier RMSE:  {result.inlier_rmse}",
        f"Trans. scale:    {scale}",
        f"Trans. trans.:   {translation}",
        f"Trans. rot.:     {rotation}",
        "",
    ]

    for line in summary:
        logger.info(line)

### Build registration modules from config

In [None]:
CONFIG_FILE: Path = Path(
    "/home/martin/dev/benthoscan/config/advanced_registration.toml"
)

config: dict = read_toml(CONFIG_FILE).unwrap()

modules: list[Module] = list()

for module_config in config["modules"]:

    module_name: str = module_config["name"]
    module_type: str = module_config["type"]

    preprocessors = [
        build_point_cloud_processor(model["method"], model["parameters"]).unwrap()
        for model in module_config["preprocessors"]
    ]

    match module_type:
        case "feature":
            registrator: GlobalRegistrator = build_feature_registrator(
                module_config["registrator"],
            ).unwrap()
        case "icp":
            registrator: IncrementalRegistrator = build_icp_registrator(
                module_config["registrator"],
            ).unwrap()
        case other:
            logger.error(f"unknown module type: {module_type}")
            registrator = None

    modules.append(Module(preprocessors=preprocessors, registrator=registrator))

# TODO: Verify inference of global/incremental module
modules: ModuleList = ModuleList(modules)

In [None]:
from time import time
from typing import Optional

# Identifiers (loader keys) of the pair to register.
source: int = 0
target: int = 3

source_cloud: PointCloud = loaders[source]().unwrap()
target_cloud: PointCloud = loaders[target]().unwrap()

module_results: list[RegistrationResult] = list()
current_result: Optional[RegistrationResult] = None


def on_result(
    source: PointCloud,
    target: PointCloud,
    result: RegistrationResult,
    *,
    source_index: int = source,
    target_index: int = target,
) -> None:
    """Callback that is executed for every pair-wise registration.

    Logs the result under the identifiers selected above and appends it to
    ``module_results``.

    The ``source``/``target`` parameters receive point clouds and shadow the
    module-level integer identifiers. The identifiers are therefore captured
    as keyword-only defaults when the function is defined, instead of being
    hard-coded, so changing the selected pair keeps the log correct.
    """
    log_registration(source=source_index, target=target_index, result=result)
    module_results.append(result)


results: dict[int, RegistrationResult] = apply_registration_modules(
    modules=modules,
    source=source_cloud,
    target=target_cloud,
    callback=on_result,
)

In [None]:
def visualize_plotly(
    source: PointCloud,
    target: PointCloud,
    transformation: np.ndarray,
    source_color: list | np.ndarray | None = None,
    target_color: list | np.ndarray | None = None,
    title: str = "",
    window_width: int = 1024,
    window_height: int = 768,
) -> None:
    """Draw a source/target pair with Plotly after transforming the source.

    Both clouds are deep-copied, so the caller's geometry is neither
    recolored nor moved. Only the source copy is transformed.

    Args:
        source: Point cloud that is transformed before drawing.
        target: Point cloud drawn in its original frame.
        transformation: Transformation applied to the source copy.
        source_color: Optional uniform color painted on the source copy.
        target_color: Optional uniform color painted on the target copy.
        title: Window title.
        window_width: Figure width in pixels.
        window_height: Figure height in pixels.
    """
    source_temp = copy.deepcopy(source)
    target_temp = copy.deepcopy(target)

    # Compare against None instead of relying on truthiness: a color given as
    # a numpy array would raise "truth value ... is ambiguous" under `if color:`.
    if source_color is not None:
        source_temp.paint_uniform_color(source_color)
    if target_color is not None:
        target_temp.paint_uniform_color(target_color)

    source_temp.transform(transformation)

    vis.draw_plotly(
        geometry_list=[source_temp, target_temp],
        window_name=title,
        width=window_width,
        height=window_height,
    )


visualize: bool = True

# Downsample the source and the target with the same 0.10 spacing.
source_down: PointCloud
target_down: PointCloud
source_down, target_down = (
    downsample_point_cloud(cloud=cloud, spacing=0.10)
    for cloud in (source_cloud, target_cloud)
)

### Plot registration results

In [None]:
import plotly.express as px
import plotly.graph_objects as go


figure: go.Figure = create_subplots(rows=2, cols=3)


# NOTE: Alternative color sequences
colors: list[str] = px.colors.sequential.Plasma_r * 2
# colors: list[str] = px.colors.qualitative.G10

# Results produced by the `apply_registration_modules` call in the
# registration cell above (presumably one entry per module, in module order).
# The previous `test_results` name was never defined anywhere in the notebook.
registration_results: list[RegistrationResult] = list(results.values())

for index, result in enumerate(registration_results):

    # Cycle through the palette so results are never silently dropped when
    # there are more results than colors (zip() would truncate).
    color: str = colors[index % len(colors)]

    traces: dict = trace_registration_result(
        result,
        name=f"Regi. {index}",
        legendgroup=index,
        color=color,
    )

    figure.add_trace(traces["fitness"], row=1, col=1)
    figure.add_trace(traces["rmse"], row=1, col=2)
    figure.add_trace(traces["correspondences"], row=1, col=3)

    figure.add_trace(traces["scale"], row=2, col=1)
    figure.add_trace(traces["rotation"], row=2, col=2)
    figure.add_trace(traces["translation"], row=2, col=3)


figure.update_layout(height=800, width=1000, title_text="Registration Results")
figure.show()


visualize_test: bool = False

if visualize_test:

    # `source` / `target` are the identifiers selected in the registration cell.
    source_cloud: PointCloud = loaders[source]().unwrap()
    target_cloud: PointCloud = loaders[target]().unwrap()

    for index, result in enumerate(registration_results):
        visualize_registration(
            source=source_cloud,
            target=target_cloud,
            transformation=result.transformation,
            title=f"Test case: {source}, {target}, {index}",
        )

### Generate indices and perform registration

In [None]:
indices: list[MultiTargetIndex] = generate_cascade_indices(len(loaders))

# One independent storage dict per registration stage; later cells fill them
# as storage[stage][source][target] = result.
result_storage = {
    stage: dict() for stage in ("feature_matching", "incremental_coarse")
}

In [None]:
# Register every cascade source against each of its targets and store the
# results as result_storage["feature_matching"][source][target].
# NOTE(review): `registration_worker` and `preprocessor` are not defined
# anywhere in this notebook, so this cell raises NameError as written, and
# `registrator` is only whatever value the module-building loop assigned last.
# This looks like a leftover from before the module pipeline
# (`apply_registration_modules`) was introduced -- TODO port it.
# NOTE(review): result_storage["incremental_coarse"] is never filled here,
# yet the pose-graph and drawing cells below read from it -- confirm intent.
for index in indices:
    source = index.source

    # Results for this source, keyed by target identifier.
    results: dict[int, RegistrationResult] = dict()

    for target in index.targets:

        result: RegistrationResult = registration_worker(
            source_loader=loaders[source],
            target_loader=loaders[target],
            preprocessor=preprocessor,
            registrator=registrator,
        )

        log_registration(source, target, result)

        results[target] = result

    result_storage["feature_matching"][source] = results

### Draw registration and plot results
- TODO: Draw registered point clouds
- TODO: Plot point clouds, correspondences, and error distribution

In [None]:
visualize_global_results: bool = False

if visualize_global_results:

    feature_results = result_storage["feature_matching"]

    # Overlay every stored feature-matching result on its source/target pair;
    # each source cloud is loaded once and reused for all of its targets.
    for source, pair_results in feature_results.items():
        source_cloud: PointCloud = loaders[source]().unwrap()

        for target, result in pair_results.items():
            target_cloud: PointCloud = loaders[target]().unwrap()
            visualize_registration(
                source=source_cloud,
                target=target_cloud,
                transformation=result.transformation,
                title=f"Registration - source: {source}, target: {target}",
            )

### Perform multiway registration

In [None]:
# Settings forwarded unchanged to optimize_pose_graph.
pose_graph_options: dict = {
    "correspondence_distance": 0.075,
    "prune_threshold": 0.25,
    "preference_loop_closure": 1.0,
    "reference_node": 0,
}

# NOTE(review): no visible cell populates result_storage["incremental_coarse"];
# confirm the graph is meant to be built from it.
coarse_registrations = result_storage["incremental_coarse"]
initial_graph: reg.PoseGraph = build_pose_graph(coarse_registrations)

# Raise Open3D's verbosity to Debug for the duration of the optimization.
with util.VerbosityContextManager(util.VerbosityLevel.Debug) as cm:
    optimized_graph: reg.PoseGraph = optimize_pose_graph(
        initial_graph, **pose_graph_options
    )

### Draw final registration results

In [None]:
transformed_clouds: list[PointCloud] = list()

# Draw every node of the optimized pose graph. The storage dict is keyed only
# by clouds that acted as a registration *source*, so iterating it can miss
# clouds that only ever appear as targets. Like the original code, this
# assumes pose-graph node i corresponds to loader key i -- TODO confirm
# against build_pose_graph.
for identifier, node in enumerate(optimized_graph.nodes):

    cloud: PointCloud = loaders[identifier]().unwrap()
    cloud.transform(node.pose)
    transformed_clouds.append(cloud)

vis.draw_geometries(transformed_clouds)