# Run Batch Inference With Segmentation Model

In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules are
# reloaded before each cell executes and edits to project source files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import sys
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from os import cpu_count
from pathlib import Path
from typing import Any

# Put the directory two levels above the current working directory on the
# import path (only once) — presumably the repository root that holds the
# project packages imported below; confirm against the repo layout.
repo_root = str(Path.cwd().parent.parent)
if sys.path.count(repo_root) == 0:
    sys.path.append(repo_root)

import cv2
import numpy as np
import torch
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from sam2util import convert_images_to_mp4
from tqdm import tqdm

from model.common import crop_driver_image_contains
from model.inference import (
    filter_small_segments,
    init_dagshub,
    load_model_from_dagshub,
)


In [None]:
# --- Run parameters -------------------------------------------------------
# Repository name handed to `init_dagshub`.
repo_name = 'driver-seg'

# Model name and version/alias handed to `load_model_from_dagshub`.
model_name = 'pytorch-2025-02-28-173314-unetplusplus-efficientnet-b1'
model_version = 'champion'

# Folder with the input frames; masks are written to a sibling `masks` folder.
input_dir = Path.home().joinpath(
    'source/driver-dataset/2024-10-28-driver-all-frames/2021_08_31_geordi_enyaq/normal/images'
)

# Number of images per inference batch.
batch_size = 16

In [None]:
# Normalise the `input_dir` parameter (str or Path) to a Path.
INPUT_DIR = Path(input_dir)

# Validate the input *before* touching the filesystem: creating the output
# folder first (with parents=True) would leave a stray directory tree behind
# whenever the input path is mistyped. An explicit raise is used instead of
# `assert` so the check survives `python -O`.
if not INPUT_DIR.is_dir():
    raise FileNotFoundError(f'Input directory `{INPUT_DIR}` does not exist.')

# Masks are written to a `masks` folder next to the input folder.
OUTPUT_DIR = INPUT_DIR.parent / 'masks'
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Fall back to 16 when the parameter is not a usable (positive integer) batch
# size; zero or negative values would break the batching `range(...)` later.
BATCH_SIZE = batch_size if isinstance(batch_size, int) and batch_size > 0 else 16

In [None]:
# Initialise DagsHub for the `repo_name` repository. Presumably this has to run
# before the model can be fetched — confirm in `model.inference`.
init_dagshub(repo_name=repo_name)
# Fetch the model identified by `model_name` / `model_version`. It is moved to
# the inference device in a later cell.
model_ = load_model_from_dagshub(model_name, model_version)

In [None]:
def load_and_process_image(
    image_path: Path, resize: tuple[int, int], transforms: Any
) -> torch.Tensor:
    """Read one image from disk and convert it to a tensor.

    The image is opened with PIL, cropped by ``crop_driver_image_contains``
    (which also receives the path), resized to ``resize`` with nearest-neighbour
    resampling, and finally passed through ``transforms``.

    Args:
        image_path: Location of the image file.
        resize: Target size forwarded to ``PIL.Image.Image.resize``.
        transforms: Callable invoked as ``transforms(image=<ndarray>)`` that
            returns a mapping with an ``'image'`` entry.

    Returns:
        The ``'image'`` entry produced by ``transforms``.
    """
    source_image = Image.open(image_path)
    cropped = crop_driver_image_contains(source_image, image_path)
    resized = cropped.resize(resize, resample=Image.NEAREST)
    pixels = np.asarray(resized)
    return transforms(image=pixels)['image']


def data_loader(
    image_paths: list[Path],
    batch_size: int = 32,
    resize: tuple[int, int] = (256, 256),
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
) -> Generator[tuple[torch.Tensor, list[Path]]]:
    """Lazily yield batches of preprocessed images together with their paths.

    Images of one batch are loaded concurrently on a thread pool, stacked into
    a single tensor and moved to ``device``. Batches follow the order of
    ``image_paths``; the last batch may be smaller than ``batch_size``.

    Args:
        image_paths: Image files to load.
        batch_size: Maximum number of images per yielded batch.
        resize: Target size passed to ``load_and_process_image``.
        device: Device the stacked batch tensor is moved to.

    Yields:
        ``(batch_tensor, batch_paths)`` where ``batch_paths`` is the slice of
        ``image_paths`` the tensor was built from.
    """
    to_tensor = ToTensorV2()

    # Use at most 8 loader threads; fall back to 1 if the CPU count is unknown.
    max_threads = min(cpu_count() or 1, 8)

    with ThreadPoolExecutor(max_threads) as pool:
        for start in range(0, len(image_paths), batch_size):
            batch_paths = image_paths[start : start + batch_size]
            # `map` preserves input order, so tensors line up with batch_paths.
            tensors = list(
                pool.map(
                    lambda path: load_and_process_image(path, resize, to_tensor),
                    batch_paths,
                )
            )
            batch = torch.stack(tensors).to(device)
            yield batch, batch_paths

In [None]:
def process_and_save_mask(mask: torch.Tensor, image_path: Path) -> None:
    """Binarise a predicted mask, drop small segments and save it as a PNG.

    The mask is thresholded at 0.5, passed through ``filter_small_segments``
    with ``min_area=64**2`` and written to ``OUTPUT_DIR`` under the stem of
    ``image_path`` with a ``.png`` suffix (foreground pixels stored as 255).

    Args:
        mask: Predicted mask values for one image. Must live on the CPU,
            because it is converted with ``Tensor.numpy()``; singleton
            dimensions are squeezed away first.
        image_path: Path of the source image; only its stem is used to name
            the output file.

    Raises:
        OSError: If OpenCV reports that the file could not be written.
    """
    mask_numpy = mask.squeeze().numpy()
    binary_mask = (mask_numpy > 0.5).astype(np.uint8)

    # Filter out small segments
    filtered_binary_mask = filter_small_segments(binary_mask, min_area=64**2)

    output_file_path = OUTPUT_DIR / f'{image_path.stem}.png'
    # Rely on imwrite's own success flag rather than `assert path.exists()`:
    # an existence check also passes on a stale file from a previous run, and
    # asserts are stripped under `python -O`.
    if not cv2.imwrite(str(output_file_path), filtered_binary_mask * 255):
        raise OSError(f'Failed to save mask to `{output_file_path}`.')

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model to the device (GPU/CPU)
model_ = model_.to(device)
# Put the model in evaluation mode so layers such as BatchNorm and Dropout use
# their inference behaviour. The loading helper may already do this; calling
# `eval()` again is harmless.
model_.eval()

# Collect the input frames in a deterministic (sorted) order and build the
# lazy batch generator over them.
image_paths = sorted(INPUT_DIR.glob('*.jpg'))
loader = data_loader(image_paths, BATCH_SIZE, (256, 256), device)

In [None]:
# ~135 FPS

# Ceiling division: a trailing partial batch still counts as one iteration, so
# the progress bar total matches the number of batches the loader yields.
n_batches = (len(image_paths) + BATCH_SIZE - 1) // BATCH_SIZE

save_futures = []
# The context manager guarantees the pool is shut down (after pending saves
# finish) even if inference raises part-way through.
with ThreadPoolExecutor(4) as executor:
    # `batch_paths` is deliberately not called `image_paths`, so the full list
    # of inputs is not overwritten by the last batch.
    for images, batch_paths in tqdm(loader, total=n_batches):
        with torch.no_grad():
            # One device-to-host copy per batch instead of one per mask.
            masks = model_(images).sigmoid().cpu()

        # Post-processing and PNG writing run on worker threads so the main
        # thread can move on to the next batch.
        save_futures.extend(
            executor.submit(process_and_save_mask, mask, image_path)
            for mask, image_path in zip(masks, batch_paths)
        )

# Surface any exception raised inside a worker; a Future that is never
# inspected would swallow save failures silently.
for future in save_futures:
    future.result()

In [None]:
# Leave a short provenance note next to the masks: when they were generated,
# from which input folder, and with which model name/version.
readme_lines = [
    f'Timestamp: {datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S")}\n',
    f'Input path: {INPUT_DIR}\n',
    f'Model: {model_name} v{model_version}\n',
]
with open(OUTPUT_DIR / 'readme.txt', 'w') as f:
    f.writelines(readme_lines)

In [None]:
# Build `masks.mp4` inside the output folder from the PNG masks, at 30 fps.
# NOTE(review): frame selection and ordering are decided by
# `convert_images_to_mp4` (from `sam2util`) — confirm it sorts by file name.
convert_images_to_mp4(OUTPUT_DIR, OUTPUT_DIR / 'masks.mp4', fps=30, image_format='png')