In [4]:
import torch
import torchvision.transforms as transforms
from utils.dataset import read_voc_dataset
from PIL import Image
import cv2
import numpy as np
import os
import json
import matplotlib.pyplot as plt

In [5]:
def generate_depth_map(image_path, model, transform, device):
    """Predict a depth map for one image and resize it to the image's resolution.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
        model: Network called as ``model(input_batch)``. Assumed to return a
            tensor shaped ``(batch, H', W')``, which is what the MiDaS hub
            models do (TODO confirm for any other model).
        transform: Callable mapping an RGB ``uint8`` HxWx3 numpy image to a
            batched input tensor (e.g. ``midas_transforms.small_transform``).
        device: ``torch.device`` the model lives on; the input is moved there.

    Returns:
        numpy.ndarray of shape (H, W) holding the model output, bicubically
        resized to the original image size. Values are returned as-is, with
        no normalisation. NOTE: MiDaS documents its output as relative inverse
        depth (not metric).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # without this check the failure surfaces later as an opaque cv2.error.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # OpenCV loads images as BGR; the MiDaS transforms expect RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    input_batch = transform(img).to(device)

    with torch.no_grad():
        prediction = model(input_batch)
        # interpolate needs an explicit channel axis: (B, H', W') -> (B, 1, H', W').
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        )
    # Drop only the channel axis and a size-1 batch axis. A bare .squeeze()
    # would also collapse a spatial axis for a 1-pixel-high or -wide image.
    prediction = prediction.squeeze(1).squeeze(0)

    return prediction.cpu().numpy()

In [6]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the MiDaS_small network through torch.hub. It is downloaded on the
# first run and served from the local hub cache afterwards. Place it on
# `device` and switch it to eval mode, which affects layers such as dropout
# and batch norm.
# NOTE(review): the hub repo is not pinned to a tag (the cache dir is
# `..._MiDaS_master`), so results can change if upstream master changes.
model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
model = model.to(device)
model.eval()

# The preprocessing pipeline that matches MiDaS_small ships in the same repo.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.small_transform

Using cache found in /home/boweiche/.cache/torch/hub/intel-isl_MiDaS_master


Loading weights:  None


Downloading: "https://github.com/rwightman/gen-efficientnet-pytorch/zipball/master" to /home/boweiche/.cache/torch/hub/master.zip
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth" to /home/boweiche/.cache/torch/hub/checkpoints/tf_efficientnet_lite3-b733e338.pth
Downloading: "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt" to /home/boweiche/.cache/torch/hub/checkpoints/midas_v21_small_256.pt


  0%|          | 0.00/81.8M [00:00<?, ?B/s]

Using cache found in /home/boweiche/.cache/torch/hub/intel-isl_MiDaS_master


In [7]:
# Dataset locations, relative to the notebook's working directory.
data_path = "./data/PascalVOC2012/VOCdevkit/VOC2012"
# Input: the VOC RGB images.
image_path = os.path.join(data_path, "JPEGImages")
# Output: one depth PNG per image; created here if it does not exist yet.
depth_maps_path = os.path.join(data_path, "depth_maps")
os.makedirs(depth_maps_path, exist_ok=True)

In [8]:
# Generate depth maps for all images in image_path.
#
# generate_depth_map returns an unbounded float32 array. cv2.imwrite cannot
# store float32 in a PNG: it silently converts to 8-bit with saturation,
# clipping every value outside [0, 255]. Instead, min-max normalise each map
# to the full 16-bit range (as MiDaS's own utils.write_depth does) and write
# a uint16 PNG. Read it back with cv2.imread(path, cv2.IMREAD_UNCHANGED) to
# keep all 16 bits.


def _depth_to_uint16(depth):
    """Min-max scale a float depth map to uint16 in [0, 65535]; a constant map becomes all zeros."""
    depth = np.asarray(depth, dtype=np.float64)
    d_min, d_max = depth.min(), depth.max()
    if d_max > d_min:
        scaled = (depth - d_min) / (d_max - d_min)
    else:
        # Avoid dividing by zero on a flat prediction.
        scaled = np.zeros_like(depth)
    return np.round(scaled * np.iinfo(np.uint16).max).astype(np.uint16)


# Sorted so the processing order (and any failure point) is deterministic.
for filename in sorted(os.listdir(image_path)):
    if not filename.lower().endswith(".jpg"):
        continue
    image_file = os.path.join(image_path, filename)
    depth_map = generate_depth_map(image_file, model, transform, device)
    depth_map_file = os.path.join(depth_maps_path, os.path.splitext(filename)[0] + ".png")
    # cv2.imwrite reports failure through its return value rather than raising.
    if not cv2.imwrite(depth_map_file, _depth_to_uint16(depth_map)):
        raise OSError(f"Failed to write depth map: {depth_map_file}")

print("Depth maps generated and saved successfully.")

Depth maps generated and saved successfully.
