In [8]:
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Literal

import torch
from PIL import Image
from tqdm import tqdm
from transformers import pipeline

from model.common import crop_driver_image_contains, pad_to_square

# The depth pipeline below is created with `device=0`, so fail fast without CUDA.
# An explicit raise (instead of `assert`) keeps the check active under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available')

In [9]:
# Notebook parameters: the directory of input frames and which dataset's
# preprocessing to apply (see `load_and_process_image`).
input_dir = (
    Path.home()
    / 'source/driver-dataset/2024-10-28-driver-all-frames/2021_08_31_geordi_enyaq/normal/images'
)
dataset: Literal['mrl', 'dmd'] = 'mrl'

In [10]:
# Validate the parameters before creating anything on disk, so a typo does not
# leave a stray output directory behind.
if dataset not in ('mrl', 'dmd'):
    raise ValueError('Invalid dataset')
DATASET = dataset

INPUT_DIR = Path(input_dir)
if not INPUT_DIR.is_dir():
    # Without this check, `mkdir(parents=True)` below would silently create the
    # missing parent tree and the run would then process zero images.
    raise FileNotFoundError(f'Input directory does not exist: {INPUT_DIR}')

# Depth maps are written to a `depth/` directory that is a sibling of INPUT_DIR.
OUTPUT_DIR = INPUT_DIR.parent / 'depth'
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

In [11]:
MODEL = 'depth-anything/Depth-Anything-V2-Small-hf'
# `*.jpg` files directly inside INPUT_DIR (non-recursive), in sorted path order.
IMAGE_PATHS = sorted(INPUT_DIR.glob('*.jpg'))
if not IMAGE_PATHS:
    # Fail loudly instead of running the whole notebook as a silent no-op.
    raise FileNotFoundError(f'No *.jpg images found in {INPUT_DIR}')
BATCH_SIZE = 64

## Inference on the whole dataset

In [12]:
# Hugging Face depth-estimation pipeline for MODEL on device index 0 (first GPU).
pipe = pipeline(
    task='depth-estimation',
    model=MODEL,
    device=0,
)

In [13]:
def load_and_process_image(image_path: Path, size: int = 518) -> Image.Image:
    """Load one image from disk and bring it to a square model input size.

    Preprocessing is selected by the module-level ``DATASET``:

    * ``'mrl'``: crop with ``crop_driver_image_contains`` and then resize.
    * ``'dmd'``: pad with ``pad_to_square`` and then resize.

    Args:
        image_path: Path of the image file to load.
        size: Side length in pixels of the returned image. Defaults to 518,
            the default inference resize of Depth-Anything-V2
            (https://github.com/DepthAnything/Depth-Anything-V2).

    Returns:
        The preprocessed image resized to ``(size, size)``.

    Raises:
        ValueError: If ``DATASET`` is neither ``'mrl'`` nor ``'dmd'``.
    """
    # Use a context manager so the source file handle is released immediately
    # rather than whenever the image object is garbage-collected; `resize`
    # returns a new image, so the result stays valid after the file is closed.
    with Image.open(image_path) as image:
        if DATASET == 'mrl':
            prepared = crop_driver_image_contains(image, image_path)
        elif DATASET == 'dmd':
            # TODO: Alternative - just resize (distort): image.resize((size, size))
            prepared = pad_to_square(image)
        else:
            raise ValueError(f'Unsupported dataset: {DATASET!r}')
        return prepared.resize((size, size))


def dataset_batched(
    batch_size: int = 32,
) -> Generator[tuple[list[Image.Image], list[Path]], None, None]:
    """Yield ``(images, paths)`` batches of preprocessed images from ``IMAGE_PATHS``.

    Images are loaded with ``load_and_process_image`` on a thread pool. The
    loads for the next batch are submitted before the current batch is
    yielded, so while the caller is busy with batch ``k`` the workers are
    already preparing batch ``k + 1``. At most two batches are in flight.

    Args:
        batch_size: Maximum number of images per batch; the last batch may be
            smaller.

    Yields:
        The processed images and their source paths, both in ``IMAGE_PATHS``
        order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be positive')

    with ThreadPoolExecutor() as executor:
        # (lazy result iterator, paths) of the most recently submitted batch.
        pending = None
        for start in range(0, len(IMAGE_PATHS), batch_size):
            paths_batch = IMAGE_PATHS[start : start + batch_size]
            # `Executor.map` submits all tasks immediately; results are
            # collected (in order) only when the iterator is consumed.
            loading = executor.map(load_and_process_image, paths_batch)
            if pending is not None:
                yield list(pending[0]), pending[1]
            pending = (loading, paths_batch)
        if pending is not None:
            yield list(pending[0]), pending[1]


def save_image(depth_img: Image.Image, export_path: Path) -> None:
    """Persist a single depth image to ``export_path`` via ``Image.save``."""
    write = depth_img.save
    write(export_path)


def save_images(results: Any, paths: list[Path]) -> None:
    """Save the ``'depth'`` image of each pipeline result next to its source frame.

    ``results[i]`` is paired with ``paths[i]``. Each depth image goes to the
    ``depth/`` directory that sits beside the source image's directory,
    reusing the source file name with a ``.png`` suffix.
    """
    # Pull every depth image out first, before any file is written.
    depth_maps = [item['depth'] for item in results]  # type: ignore

    for depth_map, source_path in zip(depth_maps, paths):
        target_dir = source_path.parent.parent / 'depth'
        target_name = source_path.with_suffix('.png').name
        save_image(depth_map, target_dir / target_name)  # type: ignore

In [14]:
# ~41 FPS with 64 batch size
# Depth PNGs are written on a background thread pool so disk writes overlap
# with GPU inference; futures are kept so save errors are not silently lost.
save_futures = []

with ThreadPoolExecutor() as executor, tqdm(
    total=len(IMAGE_PATHS), desc='Generating depth images'
) as pbar:
    for images, paths in dataset_batched(batch_size=BATCH_SIZE):
        results = pipe(images)
        save_futures.append(executor.submit(save_images, results, paths))
        # The final batch can be smaller than BATCH_SIZE; count what was processed.
        pbar.update(len(images))

# Leaving the `with` block waited for all saves; re-raise the first failure, if any.
for future in save_futures:
    future.result()

Generating depth images:   0%|          | 0/27000 [00:00<?, ?it/s]

Generating depth images: 27008it [10:56, 41.14it/s]                           
