In [1]:
from pathlib import Path

import Metashape

from result import Ok, Err, Result

from benthoscan.backends import metashape as backend
from benthoscan.utils.log import logger

  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


### Define workflow for obtaining stereo rectification results and image pair loaders from Metashape

In [2]:
import plotly.express as px
import plotly.graph_objects as go
import tqdm

from benthoscan.data.image import Image, ImagePair

from benthoscan.backends.metashape.data_types import (
    SensorPair, 
    CameraPair,
    StereoGroup,
)

from benthoscan.backends.metashape.camera_helpers import (
    compute_stereo_calibration,
    get_stereo_groups,
)

from benthoscan.backends.metashape.image_helpers import (
    ImagePairLoader,
    generate_image_loaders,
)

from benthoscan.geometry.stereo import (
    StereoCalibration,
    RectificationResult,
    compute_stereo_rectification,
    rectify_image_pair,
)


def generate_stereo_rectification(chunk: Metashape.Chunk) -> tuple[RectificationResult, list[ImagePairLoader]]:
    """Returns rectification results and image loaders for a stereo group in a Metashape chunk.

    Only the first stereo group returned by ``get_stereo_groups`` is processed for now; a
    warning is logged when further groups are skipped so that data is not dropped silently.

    Args:
        chunk: Metashape chunk containing the stereo sensors and cameras.

    Returns:
        A tuple of the rectification computed from the group's sensor pair and the image
        loaders generated for the group's camera pairs.

    Raises:
        ValueError: If the chunk contains no stereo groups.
    """

    # Get pairs of sensors and cameras
    stereo_groups: list[StereoGroup] = get_stereo_groups(chunk)
    if not stereo_groups:
        raise ValueError(f"no stereo groups found in chunk: {chunk.label}")

    # NOTE: Process the first stereo group for now
    if len(stereo_groups) > 1:
        logger.warning(
            f"found {len(stereo_groups)} stereo groups in chunk {chunk.label}, only processing the first"
        )
    group: StereoGroup = stereo_groups[0]

    # Compute the stereo calibration for the sensor pair
    calibration: StereoCalibration = compute_stereo_calibration(group.sensor_pair)

    # Compute homographies, image transformations and camera matrices for stereo rectification
    rectification: RectificationResult = compute_stereo_rectification(calibration)

    # Set up image loaders for all camera pairs
    image_loaders: list[ImagePairLoader] = generate_image_loaders(group.camera_pairs)

    return rectification, image_loaders

### Define HITNET functionality

In [3]:
from pathlib import Path
from typing import NamedTuple

import cv2
import numpy as np
import onnxruntime as onnxrt

from benthoscan.data.image import Image, ImageFormat

MODEL_REPOSITORY: str = "https://github.com/nburrus/stereodemo/releases/download/v0.1-hitnet"

# Every model is published as a release asset directly under MODEL_REPOSITORY, so each
# download URL is the repository prefix joined with the model file name (insertion order:
# dataset eth3d -> middlebury -> sceneflow, each from the smallest to the largest resolution).
URLS: dict[str, str] = {
    model_file: f"{MODEL_REPOSITORY}/{model_file}"
    for model_file in (
        "hitnet_eth3d_120x160.onnx",
        "hitnet_eth3d_240x320.onnx",
        "hitnet_eth3d_480x640.onnx",
        "hitnet_eth3d_720x1280.onnx",
        "hitnet_middlebury_120x160.onnx",
        "hitnet_middlebury_240x320.onnx",
        "hitnet_middlebury_480x640.onnx",
        "hitnet_middlebury_720x1280.onnx",
        "hitnet_sceneflow_120x160.onnx",
        "hitnet_sceneflow_240x320.onnx",
        "hitnet_sceneflow_480x640.onnx",
        "hitnet_sceneflow_720x1280.onnx",
    )
}


class Argument(NamedTuple):
    """Signature (name, shape, element type) of one input or output of an ONNX Runtime session."""

    # Argument name as reported by the session
    name: str
    # Tensor dimensions; ONNX Runtime reports fixed dimensions as ints and dynamic ones as
    # symbolic names (str) or None, so the elements are not guaranteed to be ints
    shape: tuple
    # Element type as reported by ONNX Runtime, a string such as "tensor(float)" (not a Python type)
    type: str
    

class HitnetConfig(NamedTuple):
    """Wrapper around an ONNX Runtime inference session for a HITNET model.

    Exposes the session's input and output signatures and the spatial input size the
    model expects.
    """

    # Inference session created from the HITNET ONNX file
    session: onnxrt.InferenceSession

    @property
    def inputs(self) -> list[Argument]:
        """Returns the input arguments (name, shape, type) of the session."""
        return [
            Argument(node.name, tuple(node.shape), node.type)
            for node in self.session.get_inputs()
        ]

    @property
    def outputs(self) -> list[Argument]:
        """Returns the output arguments (name, shape, type) of the session."""
        return [
            Argument(node.name, tuple(node.shape), node.type)
            for node in self.session.get_outputs()
        ]

    @property
    def input_size(self) -> tuple[int, int]:
        """Returns the expected input size for the model as (H, W).

        Assumes the first input is laid out as (batch, channels, height, width).

        Raises:
            ValueError: If the first input is not 4-D, or its height/width are not fixed
                integer dimensions.
        """
        shape: tuple = self.inputs[0].shape
        if len(shape) != 4:
            raise ValueError(f"expected a 4-D (B, C, H, W) model input, got shape: {shape}")

        _, _, height, width = shape

        # Dynamic dimensions come back as names or None, which cannot serve as a resize target
        if not isinstance(height, int) or not isinstance(width, int):
            raise ValueError(f"model input has no fixed spatial size: {shape}")

        return (height, width)


def load_hitnet(
    path: Path | str,
    providers: tuple[str, ...] = ("CUDAExecutionProvider", "CPUExecutionProvider"),
) -> HitnetConfig:
    """Loads a Hitnet model from an ONNX file into an inference session.

    Args:
        path: Path to the ``.onnx`` model file; strings are converted to ``Path``.
        providers: ONNX Runtime execution providers in order of preference. The default
            tries CUDA first and falls back to the CPU.

    Returns:
        A HitnetConfig wrapping the created inference session.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If ``path`` does not have the ``.onnx`` suffix.
    """
    path = Path(path)

    # Raise instead of returning an Err: callers use the return value directly as a
    # HitnetConfig, so an Err would only fail later with an unrelated AttributeError
    if not path.exists():
        raise FileNotFoundError(f"model path does not exist: {path}")
    if path.suffix != ".onnx":
        raise ValueError(f"model path is not an ONNX file: {path}")

    session: onnxrt.InferenceSession = onnxrt.InferenceSession(
        str(path),
        providers=list(providers),
    )

    return HitnetConfig(session=session)


def _preprocess_images(config: HitnetConfig, left: Image, right: Image) -> tuple[np.ndarray, np.ndarray]:
    """Preprocess input images for HITNET."""

    match left.format:
        case ImageFormat.RGB:
            left_array: np.ndarray = cv2.cvtColor(left.to_array(), cv2.COLOR_RGB2GRAY)
        case ImageFormat.BGR:
            left_array: np.ndarray = cv2.cvtColor(left.to_array(), cv2.COLOR_BGR2GRAY)
        case ImageFormat.GRAY:
            left_array: np.ndarray = left.to_array()
        case _:
            raise NotImplementedError(f"invalid image format: {left.format}")

    match right.format:
        case ImageFormat.RGB:
            right_array: np.ndarray = cv2.cvtColor(right.to_array(), cv2.COLOR_RGB2GRAY)
        case ImageFormat.BGR:
            right_array: np.ndarray = cv2.cvtColor(right.to_array(), cv2.COLOR_BGR2GRAY)
        case ImageFormat.GRAY:
            right_array: np.ndarray = right.to_array()
        case _:
            raise NotImplementedError(f"invalid image format: {right.format}")

    # NOTE: Images should now be grayscale

    assert len(config.inputs) == 1, f"invalid number of inputs: {len(config.inputs)}"
    assert len(config.outputs) == 1, f"invalid number of outputs: {len(config.outputs)}"

    height, width = config.input_size

    left_array: np.ndarray = cv2.resize(left_array, (width, height), cv2.INTER_AREA)
    right_array: np.ndarray = cv2.resize(right_array, (width, height), cv2.INTER_AREA)
    
    # Grayscale needs expansion to reach H,W,C.
    # Need to do that now because resize would change the shape.
    if left_array.ndim == 2:
        left_array: np.ndarray = np.expand_dims(left_array, axis=-1)
    if right_array.ndim == 2:
        right_array: np.ndarray = np.expand_dims(right_array, axis=-1)

    # TODO: Get normalization value based on image dtype

    # -> H,W,C=2 or 6 , normalized to [0,1]
    tensor = np.concatenate((left_array, right_array), axis=-1) / 255.0
    # -> C,H,W
    tensor = tensor.transpose(2, 0, 1)
    # -> B=1,C,H,W
    tensor = np.expand_dims(tensor, 0).astype(np.float32)
    
    return tensor

def _postprocess_disparity(disparity: np.ndarray, image: Image, flip: bool=False) -> np.ndarray:
    """Converts a raw HITNET disparity output into a disparity map matching ``image``.

    The output is squeezed to 2D. Its values are multiplied by the width ratio between
    ``image`` and the model resolution, because disparities are measured in pixels. It is
    then resized to the image size and optionally mirrored horizontally.

    Args:
        disparity: Raw model output; assumed to squeeze to (H, W) - TODO confirm for other models.
            Not modified.
        image: Image whose width and height the disparity map is resized to.
        flip: If True, flip the result around the vertical axis (horizontally).

    Returns:
        The disparity map with shape (image.height, image.width).
    """

    # Squeeze disparity to a 2D array
    disparity_2d: np.ndarray = np.squeeze(disparity)

    # Scale disparities by the width ratio between the original image and the disparity map.
    # Multiply out of place: np.squeeze returns a view, so an in-place update would modify
    # the caller's array (the model output)
    scale: float = float(image.width) / float(disparity_2d.shape[1])
    scaled: np.ndarray = disparity_2d * scale

    # Resize to the original image size. interpolation must be passed by keyword (the third
    # positional argument of cv2.resize is ``dst``). Linear interpolation is what the previous
    # positional call effectively used, and it is the usual choice for enlarging.
    resized: np.ndarray = cv2.resize(
        scaled, (image.width, image.height), interpolation=cv2.INTER_LINEAR
    )

    # If enabled, flip disparity map around y-axis (horizontally)
    if flip:
        resized = cv2.flip(resized, 1)

    return resized


def compute_disparity(config: HitnetConfig, left: Image, right: Image) -> tuple[np.ndarray, np.ndarray]:
    """Computes the left and right disparity maps for a pair of stereo images.

    The images need to be rectified prior to disparity estimation. The model is only queried
    for the disparity of the reference (first) image. The right disparity is therefore
    estimated from horizontally mirrored images with swapped roles (the mirrored right image
    becomes the reference) and mirrored back afterwards.

    Args:
        config: Loaded HITNET model with a single input and a single output.
        left: Rectified left image.
        right: Rectified right image.

    Returns:
        ``(left_disparity, right_disparity)``, each resized to its source image. The dtype
        follows the model output (presumably float32 - confirm for the chosen model).
    """

    # Mirror both images so the right view can act as the reference view
    flipped_left: Image = Image(
        data=cv2.flip(left.to_array(), 1),
        format=left.format,
    )
    flipped_right: Image = Image(
        data=cv2.flip(right.to_array(), 1),
        format=right.format,
    )

    # _preprocess_images asserts exactly one input and one output
    tensor: np.ndarray = _preprocess_images(config, left, right)
    flipped_tensor: np.ndarray = _preprocess_images(config, flipped_right, flipped_left)

    # Query the argument names from the session instead of hard-coding them
    input_name: str = config.inputs[0].name
    output_names: list[str] = [config.outputs[0].name]

    left_outputs: list[np.ndarray] = config.session.run(output_names, {input_name: tensor})
    right_outputs: list[np.ndarray] = config.session.run(output_names, {input_name: flipped_tensor})

    # The right disparity was estimated from the mirrored images, so flip it back to the
    # perspective of the original right image
    left_disparity: np.ndarray = _postprocess_disparity(left_outputs[0], left, flip=False)
    right_disparity: np.ndarray = _postprocess_disparity(right_outputs[0], right, flip=True)

    return left_disparity, right_disparity

### Define stereo geometry estimation with HITNET and export of range and normal maps

In [4]:
import os

import cv2
import tqdm

from benthoscan.io import write_image
from benthoscan.runtime import Environment, load_environment
from benthoscan.geometry.range_maps import compute_range_from_disparity, compute_normals_from_range


def estimate_stereo_geometry_and_export(
    rectification: RectificationResult, 
    image_loaders: list[ImagePairLoader],
    model: HitnetConfig,
    directories: dict[str, Path],
) -> None:
    """Estimates range and normal maps for stereo image pairs with HITNET and writes them to disk.

    For every image pair, the pair is loaded and rectified, and left/right disparity is
    estimated. Disparity is then converted to range and range to normals. Both maps are
    written as float16 TIFF files named after the image label.

    Args:
        rectification: Stereo rectification applied to every pair; also provides the focal
            length, baseline and per-camera matrices.
        image_loaders: Callables that return an ImagePair when called.
        model: Loaded HITNET model.
        directories: Output directories under the keys "ranges" and "normals"; they are
            created if missing.

    Raises:
        KeyError: If ``directories`` lacks the "ranges" or "normals" key.
    """

    focal_length: float = rectification.master.calibration.focal_length
    # NOTE(review): the baseline is the x-coordinate of the slave camera location; its sign
    # and units depend on the rectification convention - confirm
    baseline: float = rectification.slave.location[0]

    range_directory: Path = directories["ranges"]
    normal_directory: Path = directories["normals"]

    # Create the output directories up front, in case write_image does not create parent
    # directories itself
    range_directory.mkdir(parents=True, exist_ok=True)
    normal_directory.mkdir(parents=True, exist_ok=True)

    for loader in tqdm.tqdm(image_loaders, desc="Estimating disparity..."):
        images: ImagePair = loader()

        rectified_images: tuple = rectify_image_pair(
            master_image=images.first.to_array(),
            slave_image=images.second.to_array(),
            rectification=rectification,
        )

        left_rect: Image = Image(data=rectified_images[0], format=images.first.format)
        right_rect: Image = Image(data=rectified_images[1], format=images.second.format)

        left_disparity, right_disparity = compute_disparity(model, left=left_rect, right=right_rect)

        # Both views share the focal length and baseline, but each uses its own camera matrix
        views = (
            (images.first.label, left_disparity, rectification.master.calibration.camera_matrix),
            (images.second.label, right_disparity, rectification.slave.calibration.camera_matrix),
        )

        for label, disparity, camera_matrix in views:
            ranges: np.ndarray = compute_range_from_disparity(
                disparity=disparity,
                baseline=baseline,
                focal_length=focal_length,
            )
            normals: np.ndarray = compute_normals_from_range(
                range_map=ranges,
                camera_matrix=camera_matrix,
                flipped=True,
            )

            # Convert to 16 bit to save storage space (float16 keeps ~3 significant digits
            # and overflows above 65504)
            write_image(uri=range_directory / f"{label}.tiff", image=ranges.astype(np.float16))
            write_image(uri=normal_directory / f"{label}.tiff", image=normals.astype(np.float16))


In [None]:
PATHS: dict[str, Path] = {
    "DOCUMENT_IN": Path(
        "/data/kingston_snv_01/acfr_revisits_metashape_projects/r23685bc_working_version.psz"
    ),
    "DOCUMENT_OUT": Path(
        "/data/kingston_snv_01/acfr_revisits_metashape_projects_test/r23685bc_working_version_saved.psz"
    ),
    "CACHE": Path("/home/martin/dev/benthoscan/.cache/"),
}


result: Result[Environment, str] = load_environment()

match result:
    case Ok(environment):
        logger.info("loaded environment")
    case Err(message):
        logger.error(message)

result: Result[str, str] = backend.load_project(PATHS.get("DOCUMENT_IN"))
match result:
    case Ok(message):
        logger.info(message)
    case Err(message):
        logger.error(message)

document: Metashape.Document = backend.context._backend_data.get("document")

target_labels: list[str] = [ "r23685bc_20100605_021022" ]
target_chunks: list[Metashape.Chunk] = [chunk for chunk in document.chunks if chunk.label in target_labels]

directories: dict[str, Path] = {
    "ranges": Path("/data/kingston_snv_01/stereo_range_maps"),
    "normals": Path("/data/kingston_snv_01/stereo_normal_maps"),
}

# Load model
model: HitnetConfig = load_hitnet(environment.resource_directory / Path("hitnet_models/hitnet_eth3d_720x1280.onnx"))

# Generate range maps based on stereo pairs
for chunk in target_chunks:

    output_directories = {key: path / f"{chunk.label}" for key, path in directories.items()}
    
    rectification, image_loaders = generate_stereo_rectification(chunk)

    # NOTE: Test Hitnet workflow - Development purposes only!
    estimate_stereo_geometry_and_export(rectification, image_loaders, model, output_directories)

[32m2024-08-22 16:23:59.324[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m16[0m - [1mloaded environment[0m


LoadProject: path = /data/kingston_snv_01/acfr_revisits_metashape_projects/r23685bc_working_version.psz


[32m2024-08-22 16:24:26.543[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1mloaded document /data/kingston_snv_01/acfr_revisits_metashape_projects/r23685bc_working_version.psz successfully[0m


loaded project in 27.1366 sec


Estimating disparity...:   4%|███████▌                                                                                                                                                                                           | 106/2728 [04:58<2:03:42,  2.83s/it]