In [17]:
# Import dependencies
import functools
import os

import cv2
import numpy as np
# import matplotlib.pyplot as plt
import plotly.express as px
import torch

# NOTE(review): presumably a workaround for the Intel OpenMP "OMP: Error #15"
# (duplicate OpenMP runtime) abort. The OpenMP runtime itself calls this an
# unsafe workaround that may crash or silently give wrong results, so confirm
# it is still needed in this environment.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

# Notebook configuration
model_type = "DPT_Large"    # MiDaS variant name, passed to torch.hub.load
filepath = "./data/1.jpg"   # input image, relative to the notebook's working directory

In [18]:
# Download the MiDaS
midas = torch.hub.load("intel-isl/MiDaS", model_type)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    print("No GPU found, using CPU instead")
midas.to(device)
midas.eval();

Using cache found in C:\Users\crisp/.cache\torch\hub\intel-isl_MiDaS_master


In [19]:
def get_depth_estimate(filepath, model_type):

    # Input transformation pipeline
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

    if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
        transform = midas_transforms.dpt_transform
    else:
        transform = midas_transforms.small_transform

    # Transform input for midas
    image = cv2.imread(filepath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    input_batch1 = transform(image).to(device)

    # Make a prediction 1
    with torch.no_grad():
        prediction1 = midas(input_batch1)
        prediction1 = torch.nn.functional.interpolate(
            prediction1.unsqueeze(1),
            size=image.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

        output = prediction1.cpu().numpy()

    return output

output = get_depth_estimate(filepath, model_type)

Using cache found in C:\Users\crisp/.cache\torch\hub\intel-isl_MiDaS_master


In [20]:
# output = (1/output)
# output = output*10
output

array([[ 0.71439385,  0.7582299 ,  0.7667054 , ...,  4.780029  ,
         4.754772  ,  4.6577654 ],
       [ 0.7498202 ,  0.7570119 ,  0.77390444, ...,  4.809814  ,
         4.8093925 ,  4.819333  ],
       [ 0.70368993,  0.7356369 ,  0.7856394 , ...,  4.834246  ,
         4.8531437 ,  4.865252  ],
       ...,
       [20.525383  , 20.552588  , 20.55699   , ..., 20.444962  ,
        20.474337  , 20.497026  ],
       [20.572834  , 20.61487   , 20.650307  , ..., 20.538351  ,
        20.499065  , 20.467705  ],
       [20.66379   , 20.69688   , 20.727446  , ..., 20.563587  ,
        20.498095  , 20.445866  ]], dtype=float32)

In [21]:
fig = px.imshow(output)
fig.update_layout(coloraxis_showscale=False)
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.show()