In [None]:
import cv2
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance

In [None]:
# Choose one MiDaS checkpoint (accuracy vs. inference-speed trade-off):
#   "DPT_Large"   - MiDaS v3 Large:   highest accuracy, slowest inference
#   "DPT_Hybrid"  - MiDaS v3 Hybrid:  medium accuracy, medium inference speed
#   "MiDaS_small" - MiDaS v2.1 Small: lowest accuracy, fastest inference
model_type = "DPT_Large"

midas = torch.hub.load("intel-isl/MiDaS", model_type)

In [None]:
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
# Use the preprocessing that matches the loaded checkpoint so the smoke test below
# exercises the same pipeline inference uses (DPT checkpoints need dpt_transform).
if model_type in ("DPT_Large", "DPT_Hybrid"):
    transform = midas_transforms.dpt_transform
else:
    transform = midas_transforms.small_transform

In [None]:
# Smoke-test the preprocessing on a random 640x640 RGB image.
# uint8 values in [0, 255] (`high` is exclusive) match what cv2.imread returns;
# a fixed seed keeps the check reproducible.
rng = np.random.default_rng(0)
inp = rng.integers(low=0, high=256, size=(640, 640, 3), dtype=np.uint8)
out = transform(inp)
out.shape

In [None]:
midas.eval();
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

# Both DPT checkpoints share one preprocessing pipeline; MiDaS_small uses its own.
transform = (
    midas_transforms.dpt_transform
    if model_type in ("DPT_Large", "DPT_Hybrid")
    else midas_transforms.small_transform
)

In [None]:
def getDisparityMap(img_path):
    """Run the MiDaS model on one image and return its prediction at image resolution.

    Uses the module-level ``midas`` model and ``transform`` selected earlier in the
    notebook.

    Args:
        img_path: path to an image readable by ``cv2.imread`` (``str`` or path-like).

    Returns:
        np.ndarray of shape (H, W): the raw model output bicubically resized to the
        input image's height and width. NOTE(review): the values are the network's
        disparity-like prediction, presumably relative (not metric) — confirm for the
        chosen checkpoint.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    # cv2.imread needs a str path and returns None (instead of raising) on failure.
    img = cv2.imread(str(img_path))
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # OpenCV loads BGR; convert to RGB before the MiDaS transform.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    input_batch = transform(img)

    with torch.no_grad():
        prediction = midas(input_batch)

        # The prediction is 3-D (batch, h, w); add a channel axis so 2-D bicubic
        # interpolation can resize it to the original (H, W). Then take the first
        # batch/channel element — assumes the transform yields a batch of one image.
        # Indexing (rather than .squeeze()) keeps H or W even when either equals 1.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        )[0, 0]

    return prediction.cpu().numpy()

In [None]:
# Select a single frame by id; frames are stored as JPEGs under data/rgb/.
img_name = "b1cd1e94-26dd524f"
image_path = "data/rgb/{}.jpg".format(img_name)

In [None]:
# Run MiDaS on the selected frame; the result is an (H, W) array at image resolution.
disp_map = getDisparityMap(img_path=image_path)

In [None]:
# Min-max normalize the disparity map to [0, 255] for display and stack it into three
# identical channels of shape (H, W, 3). A constant map becomes all zeros instead of
# dividing by zero (which would yield NaNs).
depth_min = disp_map.min()
depth_max = disp_map.max()
value_range = depth_max - depth_min
if value_range > 0:
    normalized_depth = 255 * (disp_map - depth_min) / value_range
else:
    normalized_depth = np.zeros_like(disp_map, dtype=float)

right_side = np.repeat(np.expand_dims(normalized_depth, 2), 3, axis=2)

In [None]:
# Invert the normalized 3-channel map (255 - value), cast to 8-bit and save as PNG.
disp_png = (255 - right_side).astype(np.uint8)
Image.fromarray(disp_png).save(f"{img_name}_disp.png")

In [None]:
# Persist the raw disparity map computed above (np.save appends ".npy").
# NOTE(review): this cell referenced an undefined `output` under a different image id
# ("b1ee702d-4a193906"); it now saves `disp_map` under `img_name` — confirm intent.
np.save(img_name, disp_map)

In [None]:
### kitti
# Stereo calibration constants used to convert disparity into depth.
# NOTE(review): presumably baseline in metres and focal length in pixels — confirm.
baseline, focal, img_scale = 0.54, 707.09, 1

In [None]:
# Convert every frame's MiDaS disparity to a pseudo-depth map using the config-cell
# constants: depth = baseline * focal / (disparity * img_scale).
# NOTE(review): MiDaS output is presumably relative (up-to-scale) disparity, so this is
# not metric depth unless a scale is calibrated. This cell previously hard-coded
# 0.54 * 721 instead of `focal` (707.09) — confirm which focal length matches the data.
depth_dir = Path("data/depth")
depth_dir.mkdir(parents=True, exist_ok=True)  # np.save does not create directories

img_paths = sorted(Path("data/rgb/").glob("*.jpg"))
for path in img_paths:
    disp = getDisparityMap(str(path))
    # Clamp negative predictions to 0; the epsilon avoids division by zero.
    disp = np.maximum(disp, 0) + 1e-3
    depth = baseline * focal / (disp * img_scale)
    np.save(depth_dir / path.stem, depth)

In [None]:
# Disparity -> pseudo-depth for a single enhanced frame, using the config-cell constants.
# NOTE(review): this cell previously hard-coded 0.54 * 721 instead of `focal` (707.09);
# confirm which focal length matches the data.
disp = getDisparityMap("b1d3907b-2278601b-enhance.jpg")
disp = np.maximum(disp, 0) + 1e-3  # clamp negatives; epsilon avoids division by zero
depth = baseline * focal / (disp * img_scale)
np.save("b1d3907b-2278601b-enhance", depth)

In [None]:
# NOTE: despite its name, `depth` holds the disparity map returned by getDisparityMap;
# the cells below visualize it directly.
depth = getDisparityMap(img_path="b1d7b3ac-5744370e.jpg")

In [None]:
# Save a visualization of `depth` to disp.png.
grayscale = False  # False -> apply the INFERNO colormap
bits = 2           # bytes per channel for grayscale output (1 -> uint8, 2 -> uint16)

if not grayscale:
    # cv2.applyColorMap only accepts 8-bit input. Scaling to 16 bits first made the
    # out-of-range float -> uint8 cast corrupt the colors, so colormapped output is 8-bit.
    bits = 1

# Work on a separate name so the global `depth` used by later cells is not clobbered.
depth_vis = depth
if not np.isfinite(depth_vis).all():
    depth_vis = np.nan_to_num(depth_vis, nan=0.0, posinf=0.0, neginf=0.0)
    print("WARNING: Non-finite depth values present")

depth_min = depth_vis.min()
depth_max = depth_vis.max()

max_val = (2 ** (8 * bits)) - 1

# Min-max normalize to [0, max_val]; a (near-)constant map becomes all zeros.
if depth_max - depth_min > np.finfo("float").eps:
    out = max_val * (depth_vis - depth_min) / (depth_max - depth_min)
else:
    out = np.zeros(depth_vis.shape, dtype=depth_vis.dtype)

if not grayscale:
    out = cv2.applyColorMap(np.uint8(out), cv2.COLORMAP_INFERNO)

cv2.imwrite("disp.png", out.astype("uint16" if bits == 2 else "uint8"))

In [None]:
def create_side_by_side(image, depth, grayscale):
    """
    Place an image and a normalized visualization of a depth map side by side.

    The depth map is min-max normalized to [0, 255] and replicated to three channels.
    A constant depth map (max == min) is rendered as all zeros instead of producing
    NaNs from a division by zero.

    Args:
        image: the image placed on the left, shape (H, W, 3), or None to return only
            the depth visualization.
        depth: the depth map, shape (H, W).
        grayscale: if True keep the gray 3-channel map; otherwise apply OpenCV's
            INFERNO colormap to it.
    Returns:
        The (H, W, 3) depth visualization if ``image`` is None, otherwise ``image`` and
        the visualization concatenated along the width axis.
    """
    depth_min = depth.min()
    depth_max = depth.max()
    depth_range = depth_max - depth_min
    if depth_range > 0:
        normalized_depth = 255 * (depth - depth_min) / depth_range
    else:
        normalized_depth = np.zeros(depth.shape, dtype=float)

    # Replicate the single channel so it can be concatenated with a 3-channel image.
    right_side = np.repeat(np.expand_dims(normalized_depth, 2), 3, axis=2)
    if not grayscale:
        right_side = cv2.applyColorMap(np.uint8(right_side), cv2.COLORMAP_INFERNO)

    if image is None:
        return right_side
    return np.concatenate((image, right_side), axis=1)

In [None]:
# Build a grayscale side-by-side of the original frame and its disparity map.
# The image is loaded as RGB in [0, 1], flipped back to BGR for OpenCV, and rescaled
# to [0, 255] before concatenation.
source_path = "b1d7b3ac-5744370e.jpg"
bgr_raw = cv2.imread(source_path)
img = cv2.cvtColor(bgr_raw, cv2.COLOR_BGR2RGB).astype(float) / 255.0
original_image_bgr = img[:, :, ::-1]
content = create_side_by_side(original_image_bgr * 255, depth, grayscale=True)
cv2.imwrite("test.png", content)