In [None]:
import cv2
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance

In [None]:
def get_depth_estimation_model(model_name:str):
    """Load a pretrained MiDaS depth model together with its matching input transform.

    Both the network and the transforms are fetched with ``torch.hub.load`` from the
    ``intel-isl/MiDaS`` repository, so network access (or a populated hub cache) is needed.

    Args:
        model_name: One of ``"DPT_Large"``, ``"DPT_Hybrid"`` or ``"MiDaS_small"``.

    Returns:
        ``(model, transform)``: the model switched to eval mode, and the transform that
        matches its architecture (``dpt_transform`` for both DPT variants,
        ``small_transform`` for ``MiDaS_small``).

    Raises:
        AssertionError: if ``model_name`` is not one of the supported names.
    """
    supported = ("DPT_Large", "DPT_Hybrid", "MiDaS_small")
    assert model_name in supported, f"model_name must be one of {supported}, got {model_name!r}"

    midas = torch.hub.load("intel-isl/MiDaS", model_name)
    # Inference only: put layers such as dropout / batch-norm into evaluation mode.
    midas.eval()

    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
    # Both DPT variants share one preprocessing pipeline; MiDaS_small has its own.
    if model_name in ("DPT_Large", "DPT_Hybrid"):
        transform = midas_transforms.dpt_transform
    else:
        transform = midas_transforms.small_transform
    return midas, transform

In [None]:
def getDisparityMap(model, transform, img_path):
    """Run a depth network on one image and return its prediction at the image's resolution.

    Args:
        model: Network called as ``model(input_batch)``; presumably the model returned by
            ``get_depth_estimation_model`` (already in eval mode).
        transform: Callable mapping an RGB ``numpy`` image to the model's input batch.
        img_path: Path to the image file (``str`` or ``pathlib.Path``).

    Returns:
        ``numpy`` array resized to the input image's height and width; it is 2-D ``(H, W)``
        when the batch holds a single image (size-1 dims are squeezed away).
        NOTE(review): MiDaS predicts relative inverse depth (unknown scale/shift), not metric
        stereo disparity in pixels — confirm before using it as true disparity.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(str(img_path))
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError(f"could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    input_batch = transform(img)

    with torch.no_grad():
        # Use the model that was passed in, not a notebook-global one.
        prediction = model(input_batch)
        # Prediction is (N, H', W'): add a channel dim so 2-D bicubic resizing back to
        # the original (H, W) applies, then drop the size-1 dims again.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    return prediction.cpu().numpy()

In [None]:
# KITTI calibration constants used for the disparity -> depth conversion.
# NOTE(review): presumably baseline is in metres and focal length in pixels — confirm.
baseline, focal, img_scale = 0.54, 721.09, 1  # img_scale multiplies disparity; 1 = unchanged

In [None]:
midas, midas_transform = get_depth_estimation_model(model_name="DPT_Large")

rgb_dir = Path("data/rgb")
depth_dir = Path("data/depth")
# np.save does not create parent directories, so make sure the output folder exists.
depth_dir.mkdir(parents=True, exist_ok=True)

# Small positive floor so the division below never divides by zero.
disp_eps = 1e-3

img_paths = sorted(rgb_dir.glob("*.jpg"))
for path in img_paths:
    disp = getDisparityMap(midas, midas_transform, str(path))
    # Clamp negative predictions to 0, then add the epsilon floor.
    disp = np.maximum(disp, 0) + disp_eps
    # Stereo relation: depth = baseline * focal / disparity.
    # NOTE(review): MiDaS outputs relative inverse depth, not metric disparity in pixels,
    # so this is only metric if the prediction was scale/shift-aligned first — confirm.
    depth = baseline * focal / (disp * img_scale)
    # np.save appends the ".npy" suffix to the file name.
    np.save(depth_dir / path.stem, depth)