In [None]:
from pathlib import Path

import Metashape

from result import Ok, Err, Result

from benthoscan.registration import PointCloud, PointCloudLoader
from benthoscan.backends import metashape as backend
from benthoscan.utils.log import logger

In [None]:
PATHS: dict[str, Path] = {
    "DOCUMENT_IN": Path(
        "/data/kingston_snv_01/acfr_revisits_metashape_projects/r23685bc_working_version.psz"
    ),
    "DOCUMENT_OUT": Path(
        "/data/kingston_snv_01/acfr_revisits_metashape_projects_test/r23685bc_working_version_saved.psz"
    ),
    "CACHE": Path("/home/martin/dev/benthoscan/.cache/"),
}

result: Result[str, str] = backend.load_project(PATHS.get("DOCUMENT_IN"))
match result:
    case Ok(message):
        logger.info(message)
    case Err(message):
        logger.error(message)

backend.log_internal_data()

### Define disparity estimator

In [None]:
import cv2
import numpy as np

def sgbm_default_parameters() -> dict:
    """
    Build the default keyword arguments for ``cv2.StereoSGBM_create``.

    Values follow the reference script at
    https://github.com/eborboihuc/stereo-matching/blob/master/sgbm.py

    Returns:
        Dictionary mapping ``StereoSGBM_create`` keyword names to their values.
    """
    # Window size used only to derive the smoothness penalties P1/P2; it is not
    # itself passed to StereoSGBM_create.
    smoothing_window: int = 3
    # The factor 3 is the channel count assumed by the reference script.
    # NOTE(review): OpenCV suggests 8*cn*blockSize**2 and 32*cn*blockSize**2,
    # but this uses window size 3 and cn=3 although blockSize is 5 and the
    # images matched in this notebook are converted to grayscale — confirm intent.
    penalty_base: int = 3 * smoothing_window**2

    return {
        "minDisparity": 0,
        "numDisparities": 16,
        "blockSize": 5,
        "P1": 8 * penalty_base,
        "P2": 32 * penalty_base,
        "disp12MaxDiff": 1,
        "uniquenessRatio": 15,
        "speckleWindowSize": 0,
        "speckleRange": 2,
        "preFilterCap": 63,
        "mode": cv2.STEREO_SGBM_MODE_SGBM_3WAY,
    }

def filter_disparity(
    estimator: object,
    image_left: np.ndarray,
    disparity_left: np.ndarray,
    disparity_right: np.ndarray,
    normalize: bool = True,
) -> np.ndarray:
    """Apply a WLS disparity filter and optionally rescale the result for display.

    Args:
        estimator: Disparity filter exposing
            ``filter(disparity_left, image_left, filtered_out, disparity_right)``,
            e.g. the object returned by ``cv2.ximgproc.createDisparityWLSFilter``.
        image_left: Guide image for the view whose disparity is being filtered.
        disparity_left: Disparity map to filter.
        disparity_right: Disparity map of the opposite view, passed through to the
            filter (OpenCV uses it for its left-right confidence estimate).
        normalize: If True (the default, matching the previous behaviour), the
            filtered map is min-max rescaled in place into [0, 255]. The result is
            then only suitable for visualization, because the metric disparity
            scale is lost. Pass False to get the raw filter output instead,
            e.g. for range computation.

    Returns:
        The filtered disparity map. Its dtype is whatever the filter returns
        (presumably int16 fixed-point for SGBM input — confirm). The in-place
        ``cv2.normalize`` keeps that dtype.
    """
    filtered_disparity_left = estimator.filter(
        disparity_left,
        image_left,
        None,  # let the filter allocate the output array
        disparity_right,
    )
    if not normalize:
        return filtered_disparity_left

    # NORM_MINMAX maps [min, max] onto [min(alpha, beta), max(alpha, beta)] = [0, 255].
    filtered_disparity_left = cv2.normalize(
        src=filtered_disparity_left,
        dst=filtered_disparity_left,
        beta=0,
        alpha=255,
        norm_type=cv2.NORM_MINMAX,
    )
    return filtered_disparity_left
    

def estimate_disparity_sgbm(
    image_left: np.ndarray,
    image_right: np.ndarray,
    wls_lambda: float = 70000.0,
    wls_sigma: float = 1.7,
) -> tuple[np.ndarray, np.ndarray]:
    """Estimate WLS-filtered disparity maps for a rectified stereo pair using SGBM.

    Args:
        image_left: Rectified left (master) image.
        image_right: Rectified right (slave) image.
        wls_lambda: Regularization strength of the WLS filter (``setLambda``).
        wls_sigma: Color sensitivity of the WLS filter (``setSigmaColor``).

    Returns:
        ``(filtered_left, filtered_right)``. Both maps are produced by
        ``filter_disparity`` with its default normalization, i.e. min-max
        rescaled into [0, 255] for display rather than in metric disparity units.

    Raises:
        AssertionError: If the matchers do not return int16 disparities.
    """
    # Block-size guidance from the reference: 3 by default; 5 or 7 for
    # reduced-size images; 15 for full-size images (1300 px and above).
    # 5 works nicely here.
    parameters: dict = sgbm_default_parameters()

    left_matcher = cv2.StereoSGBM_create(**parameters)
    # Matcher for the right view that mirrors the left matcher's settings.
    right_matcher = cv2.ximgproc.createRightMatcher(left_matcher)

    # SGBM returns fixed-point disparities (true disparity * 16) as int16.
    disparity_left: np.ndarray = left_matcher.compute(image_left, image_right)
    disparity_right: np.ndarray = right_matcher.compute(image_right, image_left)

    assert disparity_left.dtype == np.int16, "invalid disparity type: expected int16"
    assert disparity_right.dtype == np.int16, "invalid disparity type: expected int16"

    disparity_filter = cv2.ximgproc.createDisparityWLSFilter(matcher_left=left_matcher)
    disparity_filter.setLambda(wls_lambda)
    disparity_filter.setSigmaColor(wls_sigma)

    filtered_left: np.ndarray = filter_disparity(
        disparity_filter, image_left, disparity_left, disparity_right
    )
    # NOTE(review): the WLS filter was configured from the *left* matcher. Reusing
    # it with the right view and swapped disparities may not give a valid
    # right-view result — confirm against the OpenCV ximgproc documentation.
    filtered_right: np.ndarray = filter_disparity(
        disparity_filter, image_right, disparity_right, disparity_left
    )

    return filtered_left, filtered_right

### Define stereo visualization functions

In [None]:
import numpy as np
import plotly as pl
import plotly.express as px
import plotly.graph_objects as go

from plotly.subplots import make_subplots

def _as_image_trace(pixels: np.ndarray) -> "go.Image":
    """Wrap an array in a ``go.Image``, expanding single-channel data to RGB.

    ``go.Image`` expects every pixel to be a 3- or 4-element color, so 2-D
    (grayscale) arrays and (H, W, 1) arrays are replicated across three channels.
    Values are assumed to lie already in the default [0, 255] RGB range; no
    rescaling is done.
    """
    array = np.asarray(pixels)
    if array.ndim == 3 and array.shape[-1] == 1:
        array = array[..., 0]
    if array.ndim == 2:
        array = np.repeat(array[..., np.newaxis], 3, axis=-1)
    return go.Image(z=array)


def trace_stereo_recitification(
    unrectified_master: np.ndarray,
    unrectified_slave: np.ndarray,
    rectified_master: np.ndarray,
    rectified_slave: np.ndarray,
) -> dict:
    """Build image traces for stereo rectification results.

    Grayscale inputs are expanded to RGB so ``go.Image`` can render them.

    Returns:
        Dict with keys "unrectified_master", "unrectified_slave",
        "rectified_master" and "rectified_slave", each mapping to a ``go.Image``.
    """
    return {
        "unrectified_master": _as_image_trace(unrectified_master),
        "unrectified_slave": _as_image_trace(unrectified_slave),
        "rectified_master": _as_image_trace(rectified_master),
        "rectified_slave": _as_image_trace(rectified_slave),
    }


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
trace_stereo_rectification = trace_stereo_recitification


def trace_stereo_matching(
    master_image: np.ndarray,
    slave_image: np.ndarray,
    disparity_map: np.ndarray,
    range_map: np.ndarray,
) -> dict:
    """Build image traces for stereo matching results.

    Grayscale inputs are expanded to RGB. Disparity and range maps are rendered
    as-is, so they should already be scaled into [0, 255] to display sensibly.

    Returns:
        Dict with keys "master_image", "slave_image", "disparity" and "range",
        each mapping to a ``go.Image``.
    """
    return {
        "master_image": _as_image_trace(master_image),
        "slave_image": _as_image_trace(slave_image),
        "disparity": _as_image_trace(disparity_map),
        "range": _as_image_trace(range_map),
    }

def _lock_image_axes(figure: go.Figure, rows: int = 2, cols: int = 2) -> None:
    """Give every subplot image-style axes: origin top-left and square pixels."""
    for row in range(1, rows + 1):
        for col in range(1, cols + 1):
            # make_subplots numbers axes row-major from the top-left: x, x2, x3, ...
            axis_index = (row - 1) * cols + col
            anchor = "x" if axis_index == 1 else f"x{axis_index}"
            # Heatmaps draw row 0 at the bottom; reverse y so images are upright.
            figure.update_yaxes(
                scaleanchor=anchor, autorange="reversed", row=row, col=col
            )


def visualize_stereo_results(
    image_master: np.ndarray,
    image_slave: np.ndarray,
    rectified_master: np.ndarray,
    rectified_slave: np.ndarray,
    disparity_map: np.ndarray,
    range_map: np.ndarray
) -> dict[str, go.Figure]:
    """Build 2x2 figures comparing raw/rectified images and matching results.

    Returns:
        Dict with two figures:
        - "rectification": raw master and raw slave (top row); rectified master
          and rectified slave (bottom row).
        - "matching": rectified master and rectified slave (top row); disparity
          map and range map (bottom row).
    """
    def gray(pixels: np.ndarray) -> go.Heatmap:
        # Colorbars are hidden for plain images; they would all overlap.
        return go.Heatmap(z=pixels, colorscale="gray", showscale=False)

    rectification = make_subplots(rows=2, cols=2, horizontal_spacing=0.01)
    rectification.add_trace(gray(image_master), 1, 1)
    rectification.add_trace(gray(image_slave), 1, 2)
    rectification.add_trace(gray(rectified_master), 2, 1)
    rectification.add_trace(gray(rectified_slave), 2, 2)

    matching = make_subplots(rows=2, cols=2, horizontal_spacing=0.01)
    matching.add_trace(gray(rectified_master), 1, 1)
    matching.add_trace(gray(rectified_slave), 1, 2)
    # Titled colorbars placed side by side so the two scales do not overlap.
    matching.add_trace(
        go.Heatmap(z=disparity_map, colorbar=dict(title="disparity")), 2, 1
    )
    matching.add_trace(
        go.Heatmap(z=range_map, colorbar=dict(title="range", x=1.15)), 2, 2
    )

    for figure in (rectification, matching):
        _lock_image_axes(figure)

    return {"rectification": rectification, "matching": matching}

### Define overall stereo depth estimation process

In [None]:
from benthoscan.backends.metashape.stereo_funcs import (
    SensorPair, 
    CameraPair,
    StereoGroup,
    compute_stereo_extrinsics,
    image_to_numpy,
)

from benthoscan.geometry.stereo import (
    RectifyingHomography,
    RectifyingPixelMap,
    get_stereo_groups,
    compute_rectifying_homographies,
    compute_rectifying_pixel_maps,
    rectify_image_pair,
)


def _load_grayscale(image: Metashape.Image) -> np.ndarray:
    """Convert a Metashape image to a numpy array, reducing "RGB" images to grayscale."""
    pixels: np.ndarray = np.squeeze(image_to_numpy(image))
    if image.channels == "RGB":
        pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
    return pixels


def _disparity_to_range(
    disparity: np.ndarray, focal_length: float, baseline: float
) -> np.ndarray:
    """Compute range = focal_length * baseline / disparity as float32.

    Pixels with non-positive disparity have no valid range and are set to NaN,
    which avoids the inf values and the divide-by-zero warning.
    """
    disparity = np.asarray(disparity, dtype=np.float64)
    range_map = np.full(disparity.shape, np.nan, dtype=np.float64)
    np.divide(focal_length * baseline, disparity, out=range_map, where=disparity > 0)
    return range_map.astype(np.float32)


def process_stereo_images(
    chunk: Metashape.Chunk,
    sensors: SensorPair,
    camera_pairs: list[CameraPair],
    focal_length: float = 1834.0,
):
    """Rectify, match and visualize stereo camera pairs (work in progress).

    For the first camera pair, the images are converted to grayscale, rectified,
    matched with SGBM and turned into a range map. The results are then shown as
    plotly figures.

    Args:
        chunk: Metashape chunk the cameras belong to (currently unused).
        sensors: Master/slave sensor pair used for the extrinsics and rectification.
        camera_pairs: Master/slave camera pairs to process.
        focal_length: Scale factor in the range computation (default 1834.0).
            Presumably the focal length in pixels — confirm against the sensor
            calibration.

    Raises:
        NotImplementedError: Always, after visualizing the first pair. The
            function returns None when ``camera_pairs`` is empty.
    """
    location, _rotation = compute_stereo_extrinsics(sensors)
    # NOTE(review): location[0] is used as the stereo baseline; confirm its sign
    # and units (a negative value flips the sign of every range).
    baseline = location[0]

    homographies: RectifyingHomography = compute_rectifying_homographies(sensors)
    pixel_maps: RectifyingPixelMap = compute_rectifying_pixel_maps(sensors, homographies)

    for camera_pair in camera_pairs:
        converted_master: np.ndarray = _load_grayscale(camera_pair.master.image())
        converted_slave: np.ndarray = _load_grayscale(camera_pair.slave.image())

        rectified_master, rectified_slave = rectify_image_pair(
            converted_master,
            converted_slave,
            pixel_maps,
        )

        disparity_left, _disparity_right = estimate_disparity_sgbm(
            rectified_master, rectified_slave
        )

        # NOTE(review): estimate_disparity_sgbm returns disparities min-max
        # rescaled into [0, 255] for display (and SGBM's raw output is scaled
        # by 16). The range map below is therefore not metric yet.
        range_map: np.ndarray = _disparity_to_range(disparity_left, focal_length, baseline)

        # TODO: look into rendering images with Open3D.
        figures: dict[str, go.Figure] = visualize_stereo_results(
            converted_master,
            converted_slave,
            rectified_master,
            rectified_slave,
            disparity_left,
            range_map,
        )

        for figure in figures.values():
            figure.show()

        # Work in progress: stop after the first pair has been visualized.
        raise NotImplementedError("process_stereo_images is not implemented")


def export_stereo_range_maps(chunk: Metashape.Chunk, directory: Path) -> None:
    """Export range maps for every stereo (master/slave) group in a chunk.

    Work in progress: ``directory`` is not used yet and nothing is written.
    Each stereo group is passed to ``process_stereo_images``, after which this
    function raises ``NotImplementedError``.
    """
    # Each group bundles a master/slave sensor pair with its camera pairs.
    for stereo_group in get_stereo_groups(chunk):
        process_stereo_images(
            chunk,
            sensors=stereo_group.sensor_pair,
            camera_pairs=stereo_group.camera_pairs,
        )

    raise NotImplementedError

### Test stereo depth estimation on camera pairs from Metashape chunks

In [None]:
# Fetch the document loaded by the project-loading cell.
# NOTE(review): `_backend_data` is a private attribute of the backend context.
document: Metashape.Document = backend.context._backend_data.get("document")
if document is None:
    raise RuntimeError("no Metashape document loaded; run the project-loading cell first")

target_labels: list[str] = ["r23685bc_20100605_021022"]
target_chunks: list[Metashape.Chunk] = [
    chunk for chunk in document.chunks if chunk.label in target_labels
]
if not target_chunks:
    # Make a label typo visible instead of silently doing nothing.
    logger.warning(f"no chunks matched target labels: {target_labels}")

output_root: Path = Path("/data/kingston_snv_01/stereo_range_maps")

# Generate range maps based on stereo pairs
for chunk in target_chunks:
    output_directory: Path = output_root / f"{chunk.label}_range_maps"
    export_stereo_range_maps(chunk, output_directory)