In [None]:
from pathlib import Path

import Metashape

from benthoscan.backends import metashape as backend
from benthoscan.utils.log import logger
from benthoscan.utils.result import Ok, Err, Result

In [None]:
# Absolute, machine-specific locations used by this notebook. Only
# "DOCUMENT_IN" is read in this cell; "DOCUMENT_OUT" and "CACHE" are defined
# here but not referenced by the code visible in this notebook section.
PATHS: dict[str, Path] = {
    "DOCUMENT_IN": Path(
        "/data/kingston_snv_01/acfr_revisits_metashape_projects/r23685bc_working_version.psz"
    ),
    "DOCUMENT_OUT": Path(
        "/data/kingston_snv_01/acfr_revisits_metashape_projects_test/r23685bc_working_version_saved.psz"
    ),
    "CACHE": Path("/home/martin/dev/benthoscan/.cache/"),
}

# Load the project through the backend and log the message carried by the
# returned Ok / Err value. An Err is only logged: execution continues with the
# following statements either way.
result: Result[str, str] = backend.load_project(PATHS.get("DOCUMENT_IN"))
match result:
    case Ok(message):
        logger.info(message)
    case Err(message):
        logger.error(message)

# NOTE(review): presumably logs the backend's internal state after loading;
# confirm against the benthoscan metashape backend.
backend.log_internal_data()

### Define disparity estimator

In [None]:
import cv2
import numpy as np


def sgbm_default_parameters() -> dict:
    """
    Returns default parameters for semi-global block matching.
    Adopted from: https://github.com/eborboihuc/stereo-matching/blob/master/sgbm.py

    The keys are the keyword names used when the dictionary is unpacked into
    cv2.StereoSGBM_create (see compute_disparity_sgbm). A new dictionary is
    built on every call, so callers may modify the result freely.
    """

    # The block size also drives the two smoothness penalties below.
    block_size: int = 5

    # Smoothness penalties (P2 > P1). NOTE(review): the factor 1 presumably
    # stands for the number of image channels in OpenCV's suggested
    # 8 * channels * blockSize**2 / 32 * channels * blockSize**2 formula - confirm.
    p1: int = 8 * 1 * block_size**2
    p2: int = 32 * 1 * block_size**2

    return {
        "minDisparity": 0,
        "numDisparities": 6 * 16,  # 96 disparity levels
        "blockSize": block_size,
        "P1": p1,
        "P2": p2,
        "disp12MaxDiff": 1,
        "uniquenessRatio": 10,
        "speckleWindowSize": 50,
        "speckleRange": 1,
        "preFilterCap": 63,
        "mode": cv2.STEREO_SGBM_MODE_SGBM_3WAY,
    }


def filter_disparity(
    estimator: object,
    image_left: np.ndarray,
    disparity_left: np.ndarray,
    disparity_right: np.ndarray,
) -> np.ndarray:
    """
    Runs a disparity filter (e.g. a WLS filter) on a disparity map.

    The estimator only needs a `filter` method. It is invoked positionally with
    the disparity map to filter, the guide image, `None`, and the disparity map
    of the opposite view - in that order.

    No conversion happens here: the value produced by the estimator is returned
    unchanged, so dtype and scale are whatever the estimator yields.
    """

    # NOTE(review): the third positional slot is presumably the optional output
    # buffer of OpenCV's disparity filter interface - confirm.
    return estimator.filter(disparity_left, image_left, None, disparity_right)


def compute_disparity_sgbm(
    image_left: np.ndarray,
    image_right: np.ndarray,
    parameters: dict = sgbm_default_parameters(),
    smooth: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Computes the disparity for an image pair using OpenCVs SGBM algorithm."""

    left_matcher = cv2.StereoSGBM_create(**parameters)
    right_matcher = cv2.ximgproc.createRightMatcher(left_matcher)

    # Disparity is represented with 16-bit integers. To convert to pixel value cast to float
    # and divide by 16, i.e.:  disparity.astype(np.float32) / 16.0
    disparity_left: np.ndarray = left_matcher.compute(image_left, image_right)
    disparity_right: np.ndarray = right_matcher.compute(image_right, image_left)

    disparity_left: np.ndarray = disparity_left.astype(np.float32) / 16.0
    disparity_right: np.ndarray = disparity_right.astype(np.float32) / 16.0

    if smooth:
        # FILTER Parameters
        lmbda: int = 100  # regularization parameter 500
        sigma: float = 0.2

        disparity_filter = cv2.ximgproc.createDisparityWLSFilter(
            matcher_left=left_matcher
        )
        disparity_filter.setLambda(lmbda)
        disparity_filter.setSigmaColor(sigma)

        disparity_left: np.ndarray = filter_disparity(
            disparity_filter, image_left, disparity_left, disparity_right
        )
        disparity_right: np.ndarray = filter_disparity(
            disparity_filter, image_right, disparity_right, disparity_left
        )

    assert disparity_left.dtype == np.float32, "invalid disparity type: expected int16"
    assert disparity_right.dtype == np.float32, "invalid disparity type: expected int16"

    return disparity_left, disparity_right


def disparity_to_range(
    disparity_map: np.ndarray, focal_length: float, baseline: float
) -> np.ndarray:
    """
    Converts disparity to range.

    The range is computed per pixel as (1 / disparity) * focal_length * baseline.
    Pixels whose result is negative, infinite (zero disparity) or NaN carry no
    usable range and are set to 0.0.

    Arguments:
        disparity_map: disparity values (array-like).
        focal_length:  focal length used in the conversion.
        baseline:      stereo baseline used in the conversion.

    Returns:
        Array with the same shape as the input holding the range per pixel.
    """
    INVALID_RANGE: float = 0.0

    disparity: np.ndarray = np.asarray(disparity_map)

    # Zero or NaN disparities are expected for unmatched pixels. Suppress the
    # floating point warnings here and mask those pixels explicitly below.
    with np.errstate(divide="ignore", invalid="ignore"):
        range_map: np.ndarray = (1.0 / disparity) * focal_length * baseline

    is_valid: np.ndarray = np.isfinite(range_map) & (range_map >= 0.0)

    return np.where(is_valid, range_map, INVALID_RANGE)

### Define overall stereo depth estimation process

In [None]:
import cv2
import plotly.express as px
import plotly.graph_objects as go
import tqdm


from benthoscan.backends.metashape.data_types import (
    ImagePair,
    SensorPair,
    CameraPair,
    StereoGroup,
)

from benthoscan.backends.metashape.camera_helpers import (
    compute_camera_calibration,
    compute_stereo_calibration,
    get_stereo_groups,
)

from benthoscan.backends.metashape.image import (
    ImagePairLoader,
    image_to_numpy,
    generate_image_loaders,
)

from benthoscan.geometry.stereo import (
    StereoCalibration,
    RectificationResult,
    compute_stereo_rectification,
    rectify_image_pair,
)


def generate_stereo_rectification(
    chunk: Metashape.Chunk,
) -> tuple[RectificationResult, list[ImagePairLoader]]:
    """
    Prepares stereo rectification and image loading for a chunk.

    Only the first stereo group found in the chunk is processed. Nothing is
    exported or written by this function.

    Returns:
        The rectification computed from the group's sensor pair calibration,
        and the image pair loaders generated from the group's camera pairs.

    Raises:
        IndexError: if the chunk contains no stereo groups.
    """

    # Get pairs of sensors and cameras (master-slaves)
    stereo_groups: list[StereoGroup] = get_stereo_groups(chunk)

    # Fail with a message that names the chunk instead of a bare
    # "list index out of range".
    if not stereo_groups:
        raise IndexError(f"no stereo groups found in chunk '{chunk.label}'")

    # NOTE: Process the first stereo group for now
    group: StereoGroup = stereo_groups[0]

    # Compute the stereo calibration for the sensor pair
    calibration: StereoCalibration = compute_stereo_calibration(group.sensor_pair)

    # Compute homographies, image transformations and camera matrices for stereo rectification
    rectification: RectificationResult = compute_stereo_rectification(calibration)

    # Set up image loaders for all camera pairs
    image_loaders: list[ImagePairLoader] = generate_image_loaders(group.camera_pairs)

    return rectification, image_loaders

### Test stereo depth estimation on camera pairs from Metashape chunks

In [None]:
from collections.abc import Callable
from functools import partial
from typing import NamedTuple

import cv2

from benthoscan.utils.log import logger


class WindowHandle(NamedTuple):
    """
    Immutable record describing a window created by initialize_stereo_tuner.

    The handle stores the values the window was created with; it does not track
    later changes to the window.
    """

    # Window name; used to address the window and its trackbars in cv2 calls.
    name: str
    # Width the window was resized to on creation (as passed to cv2.resizeWindow).
    width: int
    # Height the window was resized to on creation (as passed to cv2.resizeWindow).
    height: int


def on_parameter_change(value, name: str = "") -> None:
    """
    Trackbar callback that logs a parameter change.

    Arguments:
        value: the new value reported for the parameter.
        name:  label of the parameter; bound per trackbar via functools.partial.
    """
    message: str = f"{name}: {value}"
    logger.info(message)


def initialize_stereo_tuner(
    window_name: str = "Window", width: int = 800, height: int = 1200
) -> WindowHandle:
    """
    Creates a resizable OpenCV window with one trackbar per tunable parameter.

    Every trackbar reports changes through on_parameter_change, bound to the
    trackbar's own name.

    Arguments:
        window_name: name of the window to create.
        width:       width the window is resized to.
        height:      height the window is resized to.

    Returns:
        A WindowHandle holding the window name and the requested size.
    """

    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(window_name, width, height)

    # (trackbar name, initial position, maximum position)
    # NOTE(review): "preFilterType", "preFilterSize" and "textureThreshold" do
    # not appear among the keys returned by sgbm_default_parameters - check that
    # the consumer of these values accepts them.
    sliders: tuple[tuple[str, int, int], ...] = (
        ("minDisparity", 0, 200),
        ("numDisparities", 1, 50),
        ("blockSize", 5, 50),
        ("preFilterType", 1, 1),
        ("preFilterSize", 2, 25),
        ("preFilterCap", 5, 62),
        ("textureThreshold", 10, 100),
        ("uniquenessRatio", 15, 100),
        ("speckleRange", 0, 100),
        ("speckleWindowSize", 3, 25),
        ("disp12MaxDiff", 0, 100),
    )

    for label, initial, maximum in sliders:
        callback = partial(on_parameter_change, name=label)
        cv2.createTrackbar(label, window_name, initial, maximum, callback)

    return WindowHandle(name=window_name, width=width, height=height)


def get_tuner_parameters(handle: WindowHandle) -> dict:
    """
    Reads the current position of every tuner trackbar.

    Arguments:
        handle: handle of the window that owns the trackbars.

    Returns:
        Dictionary mapping each trackbar name to the position returned by
        cv2.getTrackbarPos. The positions are returned as read; no scaling or
        validation is applied here.
    """

    # Each name is listed exactly once (the earlier dict literal repeated
    # "minDisparity"); the order matches the previous key order.
    trackbar_names: tuple[str, ...] = (
        "minDisparity",
        "numDisparities",
        "blockSize",
        "preFilterType",
        "preFilterSize",
        "preFilterCap",
        "textureThreshold",
        "uniquenessRatio",
        "speckleRange",
        "speckleWindowSize",
        "disp12MaxDiff",
    )

    return {
        trackbar: cv2.getTrackbarPos(trackbar, handle.name)
        for trackbar in trackbar_names
    }


def disparity_worker(
    image_pair: ImagePair,
    rectification: RectificationResult,
    parameters: dict,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Rectifies an image pair and computes its disparity maps.

    Arguments:
        image_pair:    pair whose master and slave images are rectified.
        rectification: rectification applied to both images.
        parameters:    passed on unchanged to compute_disparity_sgbm.

    Returns:
        The value returned by compute_disparity_sgbm for the rectified master
        (first) and slave (second) images.
    """
    rectified = rectify_image_pair(
        image_pair.master.image, image_pair.slave.image, rectification
    )

    # The first two entries of the rectification output are used as the left
    # and right input of the matcher.
    rectified_master, rectified_slave = rectified[0], rectified[1]

    return compute_disparity_sgbm(rectified_master, rectified_slave, parameters)


def visualize_stereo_results(
    rectification: RectificationResult, image_loaders: list[ImagePairLoader]
) -> None:
    """TODO"""

    handle: WindowHandle = initialize_stereo_tuner(
        window_name="test_window", width=1600, height=900
    )

    focal_length: float = rectification.master.calibration.focal_length
    baseline: float = np.linalg.norm(rectification.slave.location)

    load_index: int = 0
    load_image: bool = True

    # TODO: Initialize disparity estimator
    #

    image_pair = None

    # While the window has not been closed
    while True:

        if load_image:
            loader: ImagePairLoader = image_loaders[load_index]
            image_pair: ImagePair = loader()
            load_action: bool = False

        parameters: dict = get_tuner_parameters(handle)

        if image_pair:
            disparity_maps: tuple = disparity_worker(
                image_pair, rectification, parameters
            )

            range_maps: tuple = (
                disparity_to_range(
                    disparity_maps[0], focal_length=focal_length, baseline=baseline
                ),
                disparity_to_range(
                    -disparity_maps[1], focal_length=focal_length, baseline=baseline
                ),
            )

            logger.info(range_maps[1])

            MAX_RANGE, MIN_RANGE = 3.0, 0.0
            scaled: tuple = (
                np.uint8(255 * (range_maps[0] - MIN_RANGE) / (MAX_RANGE - MIN_RANGE)),
                np.uint8(
                    255 * (-1.0 * range_maps[1] - MIN_RANGE) / (MAX_RANGE - MIN_RANGE)
                ),
            )

            merged: np.ndarray = np.hstack((scaled[0], scaled[1]))

            colored: np.ndarray = cv2.applyColorMap(merged, cv2.COLORMAP_AUTUMN)

            cv2.imshow(handle.name, colored)

        key: int = cv2.waitKey(100)

        match key:
            # Close window using esc key
            case 27:
                cv2.destroyWindow(handle.name)
                break
            case 32:
                load_image: bool = True
            # If no key is pressed, continue
            case -1:
                continue

In [None]:
# Fetch the loaded document from the backend context.
# NOTE(review): this reads a private attribute (_backend_data), and .get()
# yields None when no "document" entry exists, in which case document.chunks
# below fails - confirm a document is always present at this point.
document: Metashape.Document = backend.context._backend_data.get("document")

# Select the chunks to process by label
target_labels: list[str] = ["r23685bc_20100605_021022"]
target_chunks: list[Metashape.Chunk] = [
    chunk for chunk in document.chunks if chunk.label in target_labels
]

output_root: Path = Path("/data/kingston_snv_01/stereo_range_maps")

# NOTE(review): logs the code point of the newline character (10); this looks
# like a leftover from debugging key codes - confirm it is still needed.
logger.info(ord("\n"))

# Open the interactive stereo tuner for every selected chunk
for chunk in target_chunks:
    # NOTE(review): the output directory is computed but not used in this cell.
    output_directory: Path = output_root / Path(f"{chunk.label}_range_maps")

    rectification, image_loaders = generate_stereo_rectification(chunk)

    visualize_stereo_results(rectification, image_loaders)