# Introduction

This notebook presents a segmentation-aided stereo matching method with confidence refinement. The method is based on the paper "Segmentation-aided stereo matching with confidence refinement" by Marek S. Tatara, Jan Glinko, Michał Czubenko, Marek Grzegorek, and Zdzisław Kowalczuk. The paper is available at ...

## Assumptions

- Input images are already stereo rectified.

## Description of the methodology

The method is based on the following steps:
1. **Segmentation**: The left input image is segmented using a pre-trained segmentation model, in this case SAM2.
2. **Segment-level Stereo Matching**: The segments found in the left image are propagated to the right image by treating each segment as a single object tracked across consecutive video frames. SAM2 is used here as well.
3. **Pixel-level Stereo Matching**: For each pair of corresponding segments, pixel-level stereo matching is performed. The proposed algorithm uses block matching to compute the SAD (Sum of Absolute Differences) between blocks lying on the same row (image height), and a confidence value is computed for each match.
4. **Confidence Refinement**: The confidence of each match is used to extract high-confidence matches between the left and right segments. These points narrow down the search range for the remaining pixels, assuming that the disparity of a pixel lies within the range of the surrounding disparities. The confidence is calculated as the ratio of the minimum SAD to the second-lowest SAD, and matches with low confidence are filtered out.
5. **Global Refinement** (not implemented yet): After all segments are matched, the disparity of each remaining pixel is calculated as the weighted average of the disparities of the surrounding pixels, with the confidences of those pixels as weights.
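The ratio-based confidence from step 4 can be illustrated for a single left-image pixel with a minimal sketch. It shows the formula only; it is not the notebook's `find_confidence` helper, whose exact formula is not shown here, and the function name `sad_ratio` is hypothetical. Lower ratios indicate a more distinctive match.

```python
import numpy as np


def sad_ratio(sad_row):
    """Ratio of the lowest to the second-lowest SAD for one left-image pixel.

    sad_row: 1-D array of SAD values against every candidate right-image x.
    Returns a value in [0, 1]: near 0 means one clearly best candidate,
    near 1 means at least two candidates are (almost) equally good.
    Hypothetical helper, not part of the notebook's code.
    """
    sad_row = np.asarray(sad_row, dtype=np.float64)
    if sad_row.size < 2:
        return np.nan  # a ratio needs at least two candidates
    # np.partition places the two smallest values at positions 0 and 1
    best, second = np.partition(sad_row, 1)[:2]
    if second == 0.0:
        return 1.0  # two perfect matches: completely ambiguous
    return float(best / second)


print(sad_ratio([10.0, 2.0, 8.0, 4.0]))  # 0.5
print(sad_ratio([3.0, 3.0, 9.0]))        # 1.0
```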

## References

- [SAM2](https://github.com/facebookresearch/sam2)
- [Middlebury Stereo Dataset 2021](https://vision.middlebury.edu/stereo/data/scenes2021/)

# How to run the code

Install Poetry and run the following command:
```bash
poetry install
```

Then run the notebook cells in order. The dataset and the SAM2 checkpoint should be downloaded automatically.

In [None]:
# Base code taken from SAM2 repository, automatic mask generator connected with video predictor

import os
from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

from helpers import setup_device, save_annotations
from sam2.build_sam import build_sam2_video_predictor, build_sam2


In [None]:
# setup parameters and the device

RANDOM_SEED = 42
OUT_PATH = "out_images/"
INPUT_DATA_DIR = "input_data/"
INPUT_FOLDER = "pendulum1"
CHECKPOINT_PATH = "sam2/checkpoints/"

input_dir = Path(INPUT_DATA_DIR) / "data" / INPUT_FOLDER

Path(OUT_PATH).mkdir(parents=True, exist_ok=True)
Path(INPUT_DATA_DIR).mkdir(parents=True, exist_ok=True)
Path(CHECKPOINT_PATH).mkdir(parents=True, exist_ok=True)
device = setup_device()
np.random.seed(RANDOM_SEED)

Download and unzip the dataset

In [None]:
os.listdir(INPUT_DATA_DIR)

In [None]:
# Download the dataset only if the INPUT_DATA_DIR folder is empty
if not os.listdir(INPUT_DATA_DIR):
    print("Downloading Middlebury dataset")
    !wget https://vision.middlebury.edu/stereo/data/scenes2021/zip/all.zip -O all.zip
    !unzip -q all.zip -d {INPUT_DATA_DIR}

Download the checkpoint and config. If the model is too big for your GPU, you can use a smaller one, e.g. `sam2.1_hiera_base_plus.pt` with `sam2.1_hiera_b+.yaml`.

In [None]:
if not os.listdir(CHECKPOINT_PATH):
    !wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt -O {CHECKPOINT_PATH}/sam2.1_hiera_large.pt
    !wget https://raw.githubusercontent.com/facebookresearch/sam2/refs/heads/main/sam2/configs/sam2.1/sam2.1_hiera_l.yaml -O {CHECKPOINT_PATH}/sam2.1_hiera_l.yaml
sam2_checkpoint = Path(CHECKPOINT_PATH) / "sam2.1_hiera_large.pt"
sam2_config = Path(CHECKPOINT_PATH) / ("sam2.1_hiera_l.yaml")
# The leading "/" is needed for build_sam2 to resolve an absolute config path (an empirically found workaround)
sam2_config_full_path = "/"+str(sam2_config.absolute())

Instantiate the automatic mask generator. The parameters can be tweaked, but the values below should work fine. The most important one may be `min_mask_region_area`: disconnected regions and holes in a mask that are smaller than this area (in pixels) are removed in post-processing.

In [None]:

image_predictor = build_sam2(sam2_config_full_path, sam2_checkpoint.absolute(), device=device, apply_postprocessing=False)

mask_generator = SAM2AutomaticMaskGenerator(
    model=image_predictor,
    points_per_side=32,
    points_per_batch=64,
    pred_iou_thresh=0.5,
    stability_score_thresh=0.92,
    stability_score_offset=0.7,
    crop_n_layers=0,
    box_nms_thresh=0.5,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=250.0,
    use_m2m=True,
)

Load the input images and generate masks for the left image

In [None]:

image0 = Image.open(input_dir / 'im0.png')
image0 = np.array(image0.convert("RGB"))
image1 = Image.open(input_dir / 'im1.png')
image1 = np.array(image1.convert("RGB"))


masks0 = mask_generator.generate(image0)
# masks1 = mask_generator.generate(image1)

save_annotations(masks0, path=OUT_PATH + "0.jpg")

Get the masks and their IDs

In [None]:
print(f"Number of masks found: {len(masks0)}")

# collect the binary segmentation masks and give each one an integer ID

masks = []
mask_ids = []
for i, mask in enumerate(masks0):
    masks.append(mask["segmentation"])
    mask_ids.append(i)


Free some memory to avoid CUDA out-of-memory (OOM) errors

In [None]:
import gc
del mask_generator
del image_predictor
gc.collect()  # drop the Python references first...
torch.cuda.empty_cache()  # ...then release the unoccupied cached CUDA memory

### 1. SEGMENTATION

Instantiate the video predictor, which treats the two images as a two-frame video

In [None]:
video_predictor = build_sam2_video_predictor(sam2_config_full_path, sam2_checkpoint, device=device)

Helper functions for visualization

In [None]:
from visualization import show_box, show_mask, show_points

Collect the frames and convert them from PNG to JPEG, because the video predictor expects a folder of JPEG frames

In [None]:
frame_names = [
    p for p in os.listdir(input_dir)
    if os.path.splitext(p)[-1] in [".png", ".PNG"]
]

frames_path = Path(input_dir) / "frames"
frames_path.mkdir(parents=True, exist_ok=True)

for frame_name in frame_names:
    img = Image.open(os.path.join(input_dir, frame_name))
    img.save(os.path.join(frames_path, frame_name.replace('im','').replace('.png', '.jpeg')))

frame_names = [x.replace('im','').replace('png','jpeg') for x in frame_names if 'im' in x]
frame_names.sort(key=lambda p: os.path.splitext(p)[0])
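SAM2's video predictor reads JPEG frames from a folder and orders them by the integer value of their file names, which is why `im0.png` and `im1.png` become `0.jpeg` and `1.jpeg` above. A minimal sketch of the same conversion as a standalone function (the name `export_stereo_pair_as_frames` is hypothetical). Converting to RGB first avoids errors when a PNG carries an alpha channel, which JPEG cannot store, and `quality=95` is an assumption (Pillow's default is 75) meant to reduce compression artifacts before segmentation.

```python
from pathlib import Path

from PIL import Image


def export_stereo_pair_as_frames(left_png, right_png, frames_dir):
    """Write the left/right images as frames 0.jpeg and 1.jpeg.

    Hypothetical helper: the left image must become frame 0 and the right
    image frame 1, because the frames are ordered by int(file stem).
    """
    frames_dir = Path(frames_dir)
    frames_dir.mkdir(parents=True, exist_ok=True)
    frame_names = []
    for idx, src in enumerate([left_png, right_png]):
        name = f"{idx}.jpeg"
        # JPEG has no alpha channel, so force 3-channel RGB before saving
        with Image.open(src) as img:
            img.convert("RGB").save(frames_dir / name, quality=95)
        frame_names.append(name)
    return frame_names
```

Usage: `frame_names = export_stereo_pair_as_frames(input_dir / "im0.png", input_dir / "im1.png", frames_path)`.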

### 2. SEGMENT-LEVEL STEREO MATCHING

Add the masks found in the first image (only those with an ID below `MASK_LIMIT`) and propagate them to the second image

In [None]:
inference_state = video_predictor.init_state(video_path=str(frames_path))
MASK_LIMIT = 45

for mask_id, mask in zip(mask_ids, masks):
    if mask_id < MASK_LIMIT:
        video_predictor.add_new_mask(inference_state, 0, mask_id, mask)

video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# render the segmentation results every few frames
vis_frame_stride = 1
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(os.path.join(frames_path, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

Load the reference disparity maps. They serve as the ground truth used to evaluate the results.

In [None]:
# load the ground-truth disparity maps disp0.pfm and disp1.pfm (unknown disparities are stored as inf)

disp0 = np.array(Image.open(Path(input_dir) / 'disp0.pfm'))
disp1 = np.array(Image.open(Path(input_dir) / 'disp1.pfm'))

# show the images

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("disp0")
plt.imshow(disp0, cmap='jet')
plt.colorbar()
plt.subplot(1, 2, 2)
plt.title("disp1")
plt.imshow(disp1, cmap='jet')
plt.colorbar()
plt.show()

### 3. PIXEL-LEVEL STEREO MATCHING

For the sake of the demonstration, a single mask is selected below. In a real application, all masks would be processed.

**Note**: The same mask ID is used for both images. The IDs are consistent between the images because the masks were propagated from the first image to the second one (done above).

In [None]:
SELECTED_MASK_ID = 20

plt.figure(figsize=(6, 4))

plt.title(f"mask {SELECTED_MASK_ID} - frame left")
mask0 = masks[SELECTED_MASK_ID]
plt.imshow(image0)
show_mask(mask0, plt.gca(), obj_id=SELECTED_MASK_ID, random_color=False)
plt.show()

plt.title(f"mask {SELECTED_MASK_ID} - frame right")
plt.imshow(image1)
mask1 = video_segments[1][SELECTED_MASK_ID][0]
show_mask(mask1, plt.gca(), obj_id=SELECTED_MASK_ID, random_color=False)
plt.show()


Optionally, remove small blobs from the masks (the masks sometimes contain small disconnected fragments)

In [None]:
from skimage import measure

def remove_small_blobs(mask, min_size=100):
    labels = measure.label(mask)
    mask = np.zeros_like(mask)
    for region in measure.regionprops(labels):
        if region.area >= min_size:
            mask[labels == region.label] = 1
    return mask

mask0 = remove_small_blobs(mask0, min_size=100)
mask1 = remove_small_blobs(mask1, min_size=100)

Find the bounding box of each mask and crop each image to it
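The helpers `find_bbox` and `crop_image` live in `helpers.py`, which is not shown here. A minimal sketch of what they could look like, assuming the bounding box is returned as `(y_min, y_max, x_min, x_max)` with inclusive bounds; the real helpers may use a different convention, and the `_sketch` names are hypothetical so they do not shadow the imported ones.

```python
import numpy as np


def find_bbox_sketch(mask):
    """Return (y_min, y_max, x_min, x_max), inclusive, of the non-zero pixels of mask."""
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        raise ValueError("mask is empty")
    return int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max())


def crop_image_sketch(image, bbox):
    """Crop an (H, W) or (H, W, C) array to an inclusive bounding box."""
    y_min, y_max, x_min, x_max = bbox
    return image[y_min:y_max + 1, x_min:x_max + 1]
```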

In [None]:
# find bounding boxes for masks in left and right images

from helpers import find_bbox, crop_image

bbox_left = find_bbox(mask0)
bbox_right = find_bbox(mask1)

# crop the bounding boxes from the images

crop0 = crop_image(image0, bbox_left)
crop1 = crop_image(image1, bbox_right)


plt.figure(figsize=(6, 4))
plt.subplot(1, 2, 1)
plt.title("Crop left")
plt.imshow(crop0)
plt.subplot(1, 2, 2)
plt.title("Crop right")
plt.imshow(crop1)
plt.show()


Merge the bounding boxes so that both cutouts have the same size, and calculate the disparity offset. The offset is a constant that relates the disparity measured between the two cutouts to the disparity in the original images. Equal-size cutouts make the disparity calculation simpler.
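`crop_and_align_masks` is defined in `helpers.py`, which is not shown here. One plausible way to implement the idea is sketched below; it is hypothetical and the helper's actual conventions may differ. Both crops share the union of the two masks' row ranges and the larger of the two mask widths, each crop starts at its own mask's leftmost column, and `offset = x0_left - x0_right`. A point at crop column `xl` in the left crop and `xr` in the right crop then has the original-image disparity `(xl + x0_left) - (xr + x0_right) = (xl - xr) + offset`.

```python
import numpy as np


def crop_and_align_sketch(image0, image1, mask0, mask1):
    """Crop both images to equal-size windows around their masks.

    Hypothetical stand-in for helpers.crop_and_align_masks; assumes both
    images have the same width (a rectified pair). Returns
    (crop0, crop1, mcrop0, mcrop1, offset) with offset = x0_left - x0_right.
    """
    ys0, xs0 = np.nonzero(mask0)
    ys1, xs1 = np.nonzero(mask1)
    # rectified images: use the union of both row ranges for both crops
    y_min = int(min(ys0.min(), ys1.min()))
    y_max = int(max(ys0.max(), ys1.max()))
    # common width = the wider of the two masks
    width = int(max(xs0.max() - xs0.min(), xs1.max() - xs1.min())) + 1
    img_w = image0.shape[1]
    # start each crop at its own mask, shifted left if it would run past the border
    x0_left = min(int(xs0.min()), img_w - width)
    x0_right = min(int(xs1.min()), img_w - width)
    rows = slice(y_min, y_max + 1)
    cols0 = slice(x0_left, x0_left + width)
    cols1 = slice(x0_right, x0_right + width)
    offset = x0_left - x0_right
    return image0[rows, cols0], image1[rows, cols1], mask0[rows, cols0], mask1[rows, cols1], offset
```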

In [None]:
from helpers import crop_and_align_masks

# Same crop as above, but both crops get the same size and the column offset between them is returned (critical for recovering true disparities and for 3D reconstruction)
image0_crop_aligned, image1_crop_aligned, mask0_crop_aligned, mask1_crop_aligned, offset = crop_and_align_masks(image0, image1, mask0, mask1)

plt.figure(figsize=(10, 6))
plt.subplot(1, 2, 1)
plt.title("image left cropped")
plt.imshow(image0_crop_aligned)
plt.subplot(1, 2, 2)
plt.title("image right cropped")
plt.imshow(image1_crop_aligned)
plt.show()


Next, the sum of absolute differences (SAD) is calculated for each pair of equally sized blocks in the left and right images. The position where the difference is lowest gives the match for each source pixel. Other algorithms could be used here, and this simple one can be optimized in many ways. Only blocks in the same row are compared, because the images are assumed to be (correctly) stereo rectified.
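A minimal NumPy sketch of such a row-wise SAD cost volume, with the same (height, left width, right width) layout as `sads` described below. It is not `algo.calculate_sad_values`, which also takes the masks and may normalize the values; the function name is hypothetical and borders are simply edge-padded.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def sad_cost_volume_sketch(left, right, window_size):
    """Row-wise SAD cost volume for two rectified images.

    Returns an array of shape (H, W_left, W_right) where entry [y, xl, xr] is
    the SAD between the window centred at (y, xl) in `left` and the window
    centred at (y, xr) in `right`. Borders are edge-padded; masks are ignored.
    """
    if window_size % 2 != 1:
        raise ValueError("window_size must be odd")
    r = window_size // 2
    # (H, W) -> (H, W, 1) so grayscale and RGB inputs share one code path
    left = np.atleast_3d(left).astype(np.float32)
    right = np.atleast_3d(right).astype(np.float32)
    if left.shape[0] != right.shape[0]:
        raise ValueError("rectified images must have the same height")
    pad = ((r, r), (r, r), (0, 0))
    win = (window_size, window_size)
    # all windows at once, shape (H, W, C, window_size, window_size)
    left_win = sliding_window_view(np.pad(left, pad, mode="edge"), win, axis=(0, 1))
    right_win = sliding_window_view(np.pad(right, pad, mode="edge"), win, axis=(0, 1))
    h, w_left = left.shape[:2]
    sads = np.empty((h, w_left, right.shape[1]), dtype=np.float32)
    for y in range(h):
        for xl in range(w_left):
            # compare one left window with every right window of the same row
            diff = np.abs(right_win[y] - left_win[y, xl])
            sads[y, xl] = diff.sum(axis=(1, 2, 3))
    return sads
```

Usage: `sads = sad_cost_volume_sketch(image0_crop_aligned, image1_crop_aligned, WINDOW_SIZE)`. Time and memory grow with H·W_left·W_right·window_size², so this brute-force form is only practical for small crops.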

In [None]:
WINDOW_SIZE = 21

from algo import calculate_sad_values

# Convert images to np.float32
image0_crop_aligned = image0_crop_aligned.astype(np.float32)
image1_crop_aligned = image1_crop_aligned.astype(np.float32)


sads = calculate_sad_values(image0_crop_aligned, image1_crop_aligned, mask0_crop_aligned, mask1_crop_aligned, WINDOW_SIZE)

Note that `sads` has the shape (image0_height, image0_width, image1_width): for each row and each pixel in the left image, it stores the SAD against each pixel of the same row in the right image. Here, "pixel" means the central pixel of a block.
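To make the indexing concrete, a minimal sketch (hypothetical function name) of turning such a cost volume into crop-level disparities: the best match of the left pixel `(y, xl)` is `xr* = argmin(sads[y, xl, :])`, and its disparity inside the crops is `xl - xr*`. NaN costs are treated as +inf so they are never selected (the notebook does the same later with `np.nan_to_num`), and pixels whose costs are all NaN get a NaN disparity. The notebook itself computes the opposite sign, `xr* - xl`, and takes the absolute value after applying the offset.

```python
import numpy as np


def crop_disparity_from_costs(sads):
    """Convert an (H, W_left, W_right) cost volume into (H, W_left) disparities.

    Disparity = xl - argmin over xr of the cost; NaN where every cost is NaN.
    """
    sads = np.asarray(sads, dtype=np.float64)
    all_nan = np.isnan(sads).all(axis=2)            # pixels with no valid candidate
    costs = np.where(np.isnan(sads), np.inf, sads)  # NaN can never be the minimum
    best_xr = costs.argmin(axis=2).astype(np.float64)
    disparity = np.arange(sads.shape[1])[np.newaxis, :] - best_xr
    disparity[all_nan] = np.nan
    return disparity
```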

Now, visualize the best match for each pixel in the left image, i.e. the right-image position with the lowest SAD. Ideally, the image should be smooth, but in uniformly textured regions it is hard to find the correct match. Here is the visualization for this "naive" approach:

In [None]:
plt.figure(figsize=(12, 6))
plt.title("Matched pixels (at the same height) for each pixel in the left image")
plt.imshow(sads.argmin(axis=2), cmap='jet')
plt.colorbar()
plt.show()


### Additional visualizations

1. The SAD matrix for a given height (row), with the best match for each left-image pixel marked in red. Narrow low-SAD "paths" make it easier to find a match; larger blobs of uniform color denote visually similar regions, where some kind of interpolation should probably be used.
2. Visualization of the compared image slices, to show what is actually being compared.
3. The minimal SAD value for each pixel of the selected row in the left image; theoretically, the lower the value, the better the match.
4. The best-matching right-image position for each pixel of the selected row in the left image.

In [None]:
VISUALIZATION_INDEX = 82 # make sure it's not exceeding the height of the mask

# show sad and mark minimal values in each row with red

from helpers import find_left_right_bounds_for_all_heights

left_bounds0, right_bounds0 = find_left_right_bounds_for_all_heights(mask0_crop_aligned)
left_bounds1, right_bounds1 = find_left_right_bounds_for_all_heights(mask1_crop_aligned)

exemplary_sad = sads[VISUALIZATION_INDEX,left_bounds0[VISUALIZATION_INDEX]:right_bounds0[VISUALIZATION_INDEX],left_bounds1[VISUALIZATION_INDEX]:right_bounds1[VISUALIZATION_INDEX]]
plt.figure(figsize=(12, 8))
plt.title(f"Normalized Sum of Absolute Differences, height={VISUALIZATION_INDEX}")
plt.imshow(exemplary_sad, cmap='jet')
plt.xlabel("right image x")
plt.ylabel("left image x")

argmin_sad = np.argmin(exemplary_sad, axis=1)
min_sad = np.min(exemplary_sad, axis=1)

# overlay the best-matching right-image x for each left-image x (red)
plt.plot(argmin_sad, np.array(range(0, exemplary_sad.shape[0])) , color='red')
plt.colorbar()
plt.show()

# show slice of left image with height denoted by window_size and right image below

plt.figure(figsize=(12, 4))
plt.subplot(2, 1, 1)
plt.title("left image slice")
plt.imshow(image0_crop_aligned[VISUALIZATION_INDEX-WINDOW_SIZE//2:VISUALIZATION_INDEX+WINDOW_SIZE//2+1,  left_bounds0[VISUALIZATION_INDEX]:right_bounds0[VISUALIZATION_INDEX]].astype(np.uint8))
plt.subplot(2, 1, 2)
plt.title("right image slice")
plt.imshow(image1_crop_aligned[VISUALIZATION_INDEX-WINDOW_SIZE//2:VISUALIZATION_INDEX+WINDOW_SIZE//2+1,  left_bounds1[VISUALIZATION_INDEX]:right_bounds1[VISUALIZATION_INDEX]].astype(np.uint8))
plt.show()


# show the minimal sad values

plt.figure(figsize=(12, 4))
plt.title(f"Minimal SAD values, height={VISUALIZATION_INDEX}")
plt.plot(min_sad.T)
plt.xlabel("left image x")
plt.ylabel("minimal SAD")
plt.show()


# show the minimal sad arguments

plt.figure(figsize=(12, 4))
plt.title(f"Arg for Minimal SAD values, height={VISUALIZATION_INDEX}")
plt.plot(argmin_sad.T)
plt.xlabel("left image x")
plt.ylabel("right image x of minimal SAD")
plt.show()


For deeper insight, plot the SAD values of a single left-image pixel in the selected row against every right-image position.

In [None]:
LEFT_INDEX = 20

# SAD of one left-image pixel (row LEFT_INDEX of exemplary_sad) against every right-image x
plt.figure(figsize=(12, 4))
plt.title(f"SAD values for left x={LEFT_INDEX}, height={VISUALIZATION_INDEX}")
plt.plot(exemplary_sad[LEFT_INDEX, :], 'x')
plt.xlabel("right image x")
plt.ylabel("SAD")
plt.show()


### 4. CONFIDENCE REFINEMENT

Now, the confidence is calculated for each match, so that only the high-confidence matches are retained. The introduction defines the confidence as the ratio of the minimum SAD to the second-lowest SAD; the current algorithm, described below, uses an angle-based measure instead. Matches with low confidence are then filtered out.

The current confidence estimation algorithm works as follows:
1. For each pixel in the left image, the SADs are calculated against each pixel in the same row of the right image.
2. For a given pixel in the left image:
    - The minimal SAD is found.
    - The angle of the line (red in the figure) between the minimal-SAD location and every other location is calculated.
    - The confidence is the minimal angle found for that pixel.

The reasoning is that directly neighbouring positions naturally have similar SAD values, especially in uniformly textured patches. Measuring the angle rather than the raw SAD difference means that a near-minimal SAD right next to the minimum lowers the confidence less than an equally low SAD far away, which indicates a genuinely ambiguous match.

The algorithm relies heavily on this confidence measure, so there is room for improvement, and any other function can be used instead. One option is to take the relative distance between the left and right pixels into account; another is to enforce continuity of the disparity.
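Written out, for a left-image pixel with SAD curve s(x) and minimum at x*, every other candidate x gives the angle θ(x) = |atan2(x − x*, s(x*) − s(x))|. When s(x) > s(x*) this equals π − arctan(|x − x*| / (s(x) − s(x*))), and on an exact tie it is π/2; the confidence is the minimum of θ over x ≠ x*. A minimal sketch of that measure (the function name is hypothetical, and `algo.find_confidence` may differ in details such as normalization). The usage lines show the effect described above: a rival low SAD right next to the minimum costs less confidence than the same SAD four pixels away.

```python
import numpy as np


def angle_confidence(sad_row):
    """Minimal angle between the SAD minimum and any other candidate.

    sad_row: 1-D SAD values of one left-image pixel against each right-image x.
    Returns an angle in [pi/2, pi): higher means a more distinctive minimum.
    Hypothetical re-implementation of the idea, not algo.find_confidence.
    """
    sad_row = np.asarray(sad_row, dtype=np.float64)
    if sad_row.size < 2:
        return np.nan
    x_min = int(np.argmin(sad_row))
    dx = np.arange(sad_row.size) - x_min  # horizontal distance to the minimum
    dsad = sad_row[x_min] - sad_row       # <= 0: how much worse each candidate is
    angles = np.abs(np.arctan2(dx, dsad))
    return float(np.delete(angles, x_min).min())  # ignore the minimum itself


near = angle_confidence([0.0, 1.0, 9.0, 9.0, 9.0])  # rival right next to the minimum
far = angle_confidence([0.0, 9.0, 9.0, 9.0, 1.0])   # same rival, 4 px away
print(near > far)  # True
```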

In [None]:
EXEMPLARY_LEFT_INDEX = 10
# draw lines from the minimal SAD to every other candidate of one left-image pixel

index_sad = exemplary_sad[EXEMPLARY_LEFT_INDEX, :]

# find the minimal sad value
min_sad_value = index_sad.min()

# find the minimal sad value index
min_sad_index = index_sad.argmin()
print(min_sad_index)

# for each point other than the minimum in index_sad, draw a line from the minimum to that point

plt.figure(figsize=(12, 8))
plt.title(f"SAD values for left x={EXEMPLARY_LEFT_INDEX}, height={VISUALIZATION_INDEX}")
plt.plot(index_sad, 'x')
plt.xlabel("right image x")
plt.ylabel("SAD")
for i, sadd in enumerate(index_sad):
    if i != min_sad_index:
        plt.plot([min_sad_index, i], [min_sad_value, sadd], color='red')
plt.show()


Next, the angle of each red line is calculated and plotted.

In [None]:

angles = np.abs(np.arctan2(np.array(range(0, exemplary_sad.shape[1])) - min_sad_index, min_sad_value - index_sad))

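# print the minimal angle, excluding the minimum itself (whose angle is 0)
other = np.flatnonzero(angles != 0)
best_other = other[angles[other].argmin()]  # index in the full row, not in the filtered array
print(f"Minimal angle: {angles[best_other]:.4f} for right-image pixel {best_other}")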

plt.figure(figsize=(12, 4))
plt.title(f"Angles for exemplary left image pixel with height {VISUALIZATION_INDEX} and x coordinate {EXEMPLARY_LEFT_INDEX}")
plt.plot(angles)
plt.xlabel("right image x")
plt.ylabel("angle")
plt.show()


Now the confidence is calculated for every pixel in the left image, and the resulting confidence map is shown as a heatmap.

In [None]:
from algo import find_confidence

# TODO: add this part to the find_confidence function
confidences = np.zeros((sads.shape[0], sads.shape[1]))
for j in range(sads.shape[0]):
    if left_bounds0[j] == right_bounds0[j] or left_bounds1[j] == right_bounds1[j]:
        continue
    sad = sads[j, left_bounds0[j]:right_bounds0[j], left_bounds1[j]:right_bounds1[j]]
    for i in range(sad.shape[0]):
        confidences[j,i+left_bounds0[j]] = find_confidence(sad, i)

# plot confidences as heatmap

plt.figure(figsize=(12, 8))
plt.title("Matching confidence map")
plt.imshow(confidences, cmap='jet')
plt.colorbar()
plt.xlabel("Left image width")
plt.ylabel("Left image height")
plt.show()


In [None]:
# find the argument of the minimal value in sads

argmin_sads= np.argmin(np.nan_to_num(sads, nan=float('inf')), axis=2)

print(argmin_sads.shape)
# subtract the location of x from argmin_sads

matches = argmin_sads.astype(np.float64)
# put nan where the mask is not present
matches[mask0_crop_aligned == 0] = np.nan
disparities = argmin_sads - np.arange(sads.shape[2])[np.newaxis, :]  # right x minus left x (both crops have the same width)
print(disparities.shape)

# find maximal value of confidences ignoring nan values

max_confidence = np.nanmax(confidences)
print(max_confidence)

In [None]:

threshold = 2
thresholded_confidences = confidences.copy()
thresholded_confidences[thresholded_confidences < threshold] = 0
plt.figure(figsize=(12, 8))
plt.title("Matching confidence map")
plt.imshow(thresholded_confidences, cmap='jet')
plt.colorbar()
plt.xlabel("Left image width")
plt.ylabel("Left image height")
plt.show()

# keep disparities only where the thresholded confidence is non-zero

disparities_left = disparities.copy().astype(np.float64) # - offset
disparities_left[thresholded_confidences == 0] = np.nan




def get_confidences(sad_matrix, left_bounds0, right_bounds0, left_bounds1, right_bounds1):
    confidences = np.zeros((sad_matrix.shape[0], sad_matrix.shape[1]))
    for j in range(sad_matrix.shape[0]):
        if left_bounds0[j] == right_bounds0[j] or left_bounds1[j] == right_bounds1[j]:
            continue
        sad = sad_matrix[j, left_bounds0[j]:right_bounds0[j], left_bounds1[j]:right_bounds1[j]]
        for i in range(sad.shape[0]):
            confidences[j,i+left_bounds0[j]] = find_confidence(sad, i)
    return confidences

confidences = get_confidences(sads, left_bounds0, right_bounds0, left_bounds1, right_bounds1)

disparities_left[np.isnan(confidences)] = 0  # zero out disparities where the confidence is nan


plt.figure(figsize=(12, 8))
plt.title("Disparities")
plt.imshow(disparities_left, cmap='jet')
plt.colorbar()
plt.xlabel("Left image width")
plt.ylabel("Left image height")
plt.show()

In [None]:
from dataclasses import dataclass


argmin_sads= np.argmin(np.nan_to_num(sads, nan=float('inf')), axis=2)

print(argmin_sads.shape)
# subtract the location of x from argmin_sads

matches = argmin_sads.astype(np.float64)
matches_filled = np.ones_like(matches)*-1
matches_filled[mask0_crop_aligned == 0] = np.nan
matches_filled_0 = np.ones_like(matches)*-1
matches_filled_0[mask0_crop_aligned == 0] = np.nan
sads_c = sads.copy()
confidences_c = get_confidences(sads_c, left_bounds0, right_bounds0, left_bounds1, right_bounds1)

threshold = 2.1


matches_filled_0[confidences_c > threshold] = matches[confidences_c > threshold]
# copy the matches to matches_filled where the confidence is above the threshold
matches_filled[confidences_c > threshold] = matches[confidences_c > threshold]

@dataclass
class Range:
    start: int
    end: int
    min_val: int
    max_val: int

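def get_ranges(non_nan, valid, matches_row):
    """Return the gaps between consecutive confident matches in a row.

    The unmatched pixels of each Range are later restricted to right-image
    positions between min_val and max_val.
    """
    # TODO: add a range in case of [... 135, 135, -1, -1, nan, nan, 135, nan, ...]
    ranges = []
    if non_nan[0] != valid[0]:
        # unmatched pixels before the first confident match
        ranges.append(Range(non_nan[0], valid[0], 0, matches_row[valid[0]]))
    # check every pair of consecutive confident matches, including the last pair
    for curr_left, curr_right in zip(valid[:-1], valid[1:]):
        if curr_right > curr_left + 1:
            # sort so that min_val <= max_val even if the matches are not increasing
            lo, hi = sorted((matches_row[curr_left], matches_row[curr_right]))
            ranges.append(Range(curr_left, curr_right, lo, hi))
    return ranges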

# iteratively fill unmatched pixels, lowering the confidence threshold on each pass
for iteration in range(100):
    for i in range(matches_filled.shape[0]):  # for each row
        row = matches_filled[i]
        # non-NaN pixels lie inside the mask; valid pixels already have a match (> -1)
        non_nan_values = np.where(~np.isnan(row))[0]
        valid_values = np.where(~np.isnan(row) & (row > -1))[0]
        if len(non_nan_values) == 0:
            continue
        if len(valid_values) == len(non_nan_values):
            # nothing to fill, row is already filled
            continue
        # restrict the search range of unmatched pixels lying between matched ones
        if len(valid_values) > 1:
            ranges_to_check = get_ranges(non_nan_values, valid_values, row)
            for range_to_check in ranges_to_check:
                if i==0:
                    print(range_to_check)

                sads_c[i,range_to_check.start:range_to_check.end+1,:np.int64(np.floor(range_to_check.min_val))]=np.inf
                sads_c[i,range_to_check.start:range_to_check.end+1,np.int64(np.ceil(range_to_check.max_val))+1:]=np.inf
    confidences_c = get_confidences(sads_c, left_bounds0, right_bounds0, left_bounds1, right_bounds1)
    # additions to confidences where neighbouring pixels are distant by a distance_threshold
    distance_threshold = 2.5
    conf_increment = 0.3
    for i in range(1, confidences_c.shape[1]-1):
        for j in range(1, confidences_c.shape[0]-1):
            if np.abs(matches_filled[j, i-1] - matches[j, i]) < distance_threshold:
                confidences_c[j, i] += conf_increment
            if np.abs(matches_filled[j, i+1] - matches[j, i]) < distance_threshold:
                confidences_c[j, i] += conf_increment
            if np.abs(matches_filled[j-1, i] - matches[j, i]) < distance_threshold:
                confidences_c[j, i] += conf_increment
            if np.abs(matches_filled[j+1, i] - matches[j, i]) < distance_threshold:
                confidences_c[j, i] += conf_increment

    threshold = threshold - 0.02

    argmin_sads= np.argmin(np.nan_to_num(sads_c, nan=float('inf')), axis=2)

    matches = argmin_sads.astype(np.float64)
    newly_confident = (confidences_c > threshold) & (matches_filled < 0)
    matches_filled[newly_confident] = matches[newly_confident]
print(matches_filled_0[30])
print(matches_filled[30])

print(confidences_c[30])
# show matches filled

plt.figure(figsize=(18, 8))
plt.subplot(1, 2, 1)
plt.title("Matches - before filling")
plt.imshow(matches_filled_0, cmap='jet')
plt.colorbar()
plt.xlabel("Left image width")
plt.ylabel("Left image height")
plt.subplot(1, 2, 2)
plt.title("Matches filled")
plt.imshow(matches_filled, cmap='jet')
plt.colorbar()
plt.xlabel("Left image width")
plt.ylabel("Left image height")
plt.show()
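The cell above implements step 4 iteratively: pixels whose confidence exceeds the threshold keep their match, the search range of each unmatched pixel lying between two matched pixels of the same row is restricted to the right-image positions spanned by those matches, the confidences are recomputed on the restricted costs (with a bonus when neighbouring matches agree), and the threshold is lowered slightly on every pass. A minimal single-row, single-pass sketch of the restriction step (hypothetical function name):

```python
import numpy as np


def restrict_between_anchors(row_costs, anchors):
    """Restrict the search range of unmatched pixels in one image row.

    row_costs: (W_left, W_right) SAD costs for one row.
    anchors:   (W_left,) right-image x of confident matches, -1 for unmatched
               pixels, NaN outside the mask.
    Returns a copy of row_costs where, for every pixel lying between two
    consecutive anchors, costs outside [min, max] of those anchors are inf.
    """
    costs = np.array(row_costs, dtype=np.float64)  # copy so the input stays unchanged
    anchors = np.asarray(anchors, dtype=np.float64)
    valid = np.flatnonzero(anchors >= 0)  # NaN >= 0 is False, so NaNs are skipped
    for left, right in zip(valid[:-1], valid[1:]):
        if right - left < 2:
            continue  # adjacent anchors: no gap to fill
        lo = int(np.floor(min(anchors[left], anchors[right])))
        hi = int(np.ceil(max(anchors[left], anchors[right])))
        costs[left + 1:right, :lo] = np.inf
        costs[left + 1:right, hi + 1:] = np.inf
    return costs
```

Usage: `sads_c[i] = restrict_between_anchors(sads_c[i], matches_filled[i])`, followed by a new argmin over the last axis.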


In [None]:
# change matches to disparities (right x minus left x; both crops have the same width)

disparities_left = matches_filled - np.arange(sads.shape[2])[np.newaxis, :]
disparities_left[mask0_crop_aligned == 0] = np.nan

# shift by the crop offset to get disparities in original-image coordinates
print(offset)
disparities_left -= offset

disparities_left = np.abs(disparities_left)


# show the filled disparities

plt.figure(figsize=(12, 8))
plt.title("Disparities")
plt.imshow(disparities_left, cmap='jet')
plt.colorbar()
plt.xlabel("Left image width")
plt.ylabel("Left image height")
# plt.show()


### 5. GLOBAL REFINEMENT

Step 5 (global refinement) should follow here, but it is not implemented yet.
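The introduction describes step 5 as computing each remaining disparity as the confidence-weighted average of the surrounding disparities. One possible sketch of that idea, under stated assumptions: the 3x3 neighbourhood, the iterative filling from the borders of each hole, the weight given to filled pixels, and the function name are all assumptions, not the authors' design. Before applying it to disparities derived from `matches_filled`, the -1 placeholders of unmatched pixels would first have to be set to NaN.

```python
import numpy as np


def fill_by_weighted_neighbours(disparity, confidence, mask=None, n_iter=50):
    """Fill NaN disparities with the confidence-weighted mean of their 3x3 neighbours.

    disparity:  (H, W) float array, NaN where no disparity is known yet.
    confidence: (H, W) finite, non-negative weights; ignored where disparity is NaN.
    mask:       optional (H, W) bool array; pixels outside it are never filled.
    A filled pixel gets weight (sum of neighbour weights) / 8, so holes are
    filled from their borders inwards over successive iterations.
    """
    disp = np.array(disparity, dtype=np.float64)  # copy
    conf = np.where(np.isnan(disp), 0.0, np.asarray(confidence, dtype=np.float64))
    region = np.ones(disp.shape, dtype=bool) if mask is None else np.asarray(mask, dtype=bool)
    h, w = disp.shape
    for _ in range(n_iter):
        missing = np.isnan(disp) & region
        if not missing.any():
            break
        d_pad = np.pad(np.nan_to_num(disp), 1)  # zero padding; its weight is 0
        c_pad = np.pad(conf, 1)
        weighted_sum = np.zeros((h, w))
        weight_total = np.zeros((h, w))
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                if dy == 0 and dx == 0:
                    continue  # only the surrounding pixels contribute
                c = c_pad[1 + dy:1 + dy + h, 1 + dx:1 + dx + w]
                weighted_sum += c * d_pad[1 + dy:1 + dy + h, 1 + dx:1 + dx + w]
                weight_total += c
        fill = missing & (weight_total > 0)
        if not fill.any():
            break  # the remaining holes have no weighted neighbours
        disp[fill] = weighted_sum[fill] / weight_total[fill]
        conf[fill] = weight_total[fill] / 8.0
    return disp
```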


### 6. EVALUATION

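The evaluation uses the Middlebury Stereo Dataset 2021: the computed disparities are compared with the ground-truth disparity map of the left image (`disp0.pfm`) within the analyzed mask, using the mean absolute error (MAE).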

In [None]:
# load the disp0.pfm and disp1.pfm files

disp0 = np.array(Image.open(Path(input_dir) / 'disp0.pfm'))
disp1 = np.array(Image.open(Path(input_dir) / 'disp1.pfm'))

# show the images

plt.figure(figsize=(20, 9))
plt.subplot(1, 2, 1)
plt.title("disp0")
plt.imshow(disp0, cmap='jet')
plt.colorbar()
plt.subplot(1, 2, 2)
plt.title("disp1")
plt.imshow(disp1, cmap='jet')
plt.colorbar()
plt.show()


Crop the ground-truth disparities to the analyzed mask

In [None]:
# mask and then crop the image


disp0ca, disp1ca, mask0_crop_aligned, mask1_crop_aligned, offset = crop_and_align_masks(disp0, disp1, mask0, mask1)

# show side-by-side disp0ca and disparities_left

disparities_leftl = np.abs(disparities_left.copy())
# treat disparities above 90 px as outliers; 0 values are excluded from the evaluation below
disparities_leftl[disparities_leftl > 90] = 0

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("disp0ca")
plt.imshow(disp0ca, cmap='jet')
plt.colorbar()
plt.subplot(1, 2, 2)
plt.title("disparities_left")
plt.imshow(np.abs(disparities_leftl), cmap='jet')
plt.colorbar()
plt.show()



In [None]:
# get the evaluation metrics

# for existing disparities

indices = (disparities_leftl > 0) & np.isfinite(disparities_leftl)
indices = indices & np.isfinite(disp0ca)

print(disparities_leftl[indices])
print(disp0ca[indices])
# calculate the mean absolute error

mae = np.mean(np.abs(disparities_leftl[indices] - disp0ca[indices]))
print(f"Mean Absolute Error: {mae:.4f}")
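The MAE computed above can be complemented with error statistics of the kind reported in the Middlebury benchmark tables, such as the RMS error and the percentage of "bad" pixels whose absolute error exceeds a threshold (for example 2 px). A minimal sketch with a hypothetical function name; it evaluates only pixels where both maps are finite, the estimate is positive, and (optionally) the mask is set.

```python
import numpy as np


def disparity_errors(estimate, ground_truth, mask=None, bad_thresholds=(1.0, 2.0, 4.0)):
    """Return MAE, RMSE and bad-pixel percentages over the jointly valid pixels."""
    est = np.asarray(estimate, dtype=np.float64)
    gt = np.asarray(ground_truth, dtype=np.float64)
    valid = np.isfinite(est) & np.isfinite(gt) & (est > 0)
    if mask is not None:
        valid &= np.asarray(mask, dtype=bool)
    if not valid.any():
        raise ValueError("no pixels to evaluate")
    err = np.abs(est[valid] - gt[valid])
    results = {
        "n_pixels": int(valid.sum()),
        "mae": float(err.mean()),
        "rmse": float(np.sqrt(np.mean(err ** 2))),
    }
    for t in bad_thresholds:
        # percentage of pixels whose error is strictly larger than t pixels
        results[f"bad_{t:g}"] = float(100.0 * np.mean(err > t))
    return results
```

Usage: `disparity_errors(disparities_leftl, disp0ca, mask=mask0_crop_aligned)`.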
