# Segmentation Inference

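Runs a trained segmentation model over a directory of image frames, blends each predicted mask onto its frame as a semi-transparent red overlay, and assembles the overlaid frames into a video for visual inspection.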

In [1]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import sys
from pathlib import Path

# Put the directory two levels above the current working directory on the
# import path (only once). Presumably this is the repository root that holds
# the `model` package imported below -- confirm against the repo layout.
repo_root = str(Path.cwd().parent.parent)
if repo_root not in sys.path:
    sys.path.append(repo_root)

import numpy as np
from PIL import Image
from sam2util import convert_images_to_mp4
from tqdm import tqdm

from model.inference import (
    filter_small_segments,
    init_dagshub,
    load_image,
    load_model_from_dagshub,
    measure_model_fps,
    predict,
)
from model.plot import plot_single_prediction  # noqa: F401


In [None]:
def now() -> str:
    """Return the current local time as a ``YYYY-MM-DD-HHMMSS`` stamp.

    The result contains only digits and dashes, which makes it suitable for
    embedding in file and directory names.
    """
    return f'{datetime.datetime.now():%Y-%m-%d-%H%M%S}'

In [3]:
# --- Run settings (edit these) ----------------------------------------------
# The following cell normalises them into the upper-case constants used by the
# rest of the notebook.

# Registered model to load, and which version of it.
model_name = 'pytorch-2024-10-28-094601-unet-efficientnet-b0'
model_version = 1

# Directory containing the input `*.jpg` frames.
input_dir = Path.home().joinpath(
    'source/driver-dataset/images/2022_09_14_stribny_enyaq/normal'
)

# Timestamped output folder, so repeated runs do not overwrite each other.
output_dir = f'outputs/{now()}-{model_name}-v{model_version}'

# Frame rate of the exported video.
fps = 30

# Size passed to `predict` as `output_size`; a falsy value selects the
# (256, 256) default applied in the next cell.
interpolation_shape: tuple[int, int] | None = (1024, 1024)

In [4]:
# Normalise the raw settings into the constants used by the rest of the notebook.
INPUT_DIR = Path(input_dir)
FPS = fps if isinstance(fps, int) else 30  # non-integer fps falls back to 30
INTERPOLATION_SHAPE = interpolation_shape or (256, 256)  # default when unset

# Validate *before* touching the filesystem, so a bad configuration does not
# leave an empty (timestamped) output directory behind.
assert INPUT_DIR.exists(), f'Input directory `{INPUT_DIR}` does not exist.'
assert len(INTERPOLATION_SHAPE) == 2 and all(
    isinstance(x, int) for x in INTERPOLATION_SHAPE
), 'Interpolation shape must be a tuple of two integers.'

# With no explicit output directory, write to a `masks` folder next to the input.
OUTPUT_DIR = Path(output_dir) if output_dir else INPUT_DIR.parent / 'masks'
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
# Initialise the DagsHub integration for the `driver-state` repository, then
# fetch the configured model name/version from it. Both helpers are project
# code (model.inference); presumably this needs network access and valid
# credentials -- confirm before running unattended.
init_dagshub(repo_name='driver-state')
model = load_model_from_dagshub(model_name, model_version)

In [None]:
# Predict a mask for every input frame, blend it onto the frame as a
# semi-transparent red overlay, and save the result into OUTPUT_DIR.

# Loop-invariant: connected segments smaller than a (side // 4)^2 area are
# dropped, where `side` is the first entry of INTERPOLATION_SHAPE.
min_segment_area = (INTERPOLATION_SHAPE[0] // 4) ** 2

for image_path in tqdm(sorted(INPUT_DIR.glob('*.jpg'))):
    image = load_image(image_path)
    mask = predict(model, image, input_size=256, output_size=INTERPOLATION_SHAPE)
    # plot_single_prediction(image, mask)  # uncomment to inspect predictions inline

    # Presumably `mask` is a CPU torch tensor of per-pixel scores -- confirm.
    mask = mask.squeeze().numpy()
    binary_mask = (mask > 0.5).astype(np.uint8)

    # Filter out small segments
    filtered_binary_mask = filter_small_segments(
        binary_mask, min_area=min_segment_area
    )

    # Build a pure-red RGBA layer whose alpha is 102/255 (40 % opacity) inside
    # the mask and 0 elsewhere (assumes the filtered mask holds 0/1 values).
    mask_rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    mask_rgba[..., 0] = 255  # Red
    mask_rgba[..., 3] = filtered_binary_mask * 102  # Alpha channel with transparency
    mask_img = Image.fromarray(mask_rgba, 'RGBA')

    # PIL sizes are (width, height) while array shapes are (rows, cols).
    # Resizing to `mask_img.size` guarantees both layers match, which
    # `alpha_composite` requires; the previous (shape[0], shape[1]) order only
    # worked for square masks.
    image_rgba = image.resize(mask_img.size).convert('RGBA')
    overlay_img = Image.alpha_composite(image_rgba, mask_img)

    # NOTE(review): the output index is the numeric file stem divided by 30,
    # which assumes input frames are named by frame number at a 30-frame
    # stride; with any other naming, several inputs map to one output file
    # and overwrite each other -- confirm against the dataset.
    overlay_img.convert('RGB').save(
        OUTPUT_DIR / f'{int(image_path.stem) // 30:05d}.jpg'
    )

In [7]:
# Record the provenance of this run (when, which input, which model) next to
# the generated frames. The timestamp comes from the shared `now()` helper so
# its format cannot drift from the one used in the output directory name, and
# the encoding is pinned so non-ASCII paths are written identically on every
# platform.
readme_lines = [
    f'Timestamp: {now()}\n',
    f'Input path: {INPUT_DIR}\n',
    f'Model: {model_name} v{model_version}\n',
]
with open(OUTPUT_DIR / 'readme.txt', 'w', encoding='utf-8') as f:
    f.writelines(readme_lines)

In [None]:
# Benchmark the loaded model on the input directory using the project helper;
# presumably it reports inference throughput in frames per second -- confirm
# in model.inference.
measure_model_fps(model, INPUT_DIR)

In [None]:
# Display the two path components that make up the output video's file name
# (parent directory name, then input directory name).
INPUT_DIR.parent.name, INPUT_DIR.name

In [10]:
# Assemble the frames in OUTPUT_DIR into an MP4 written to that same
# directory. The file name encodes the input's parent directory, the input
# directory, and the frame rate.
output_video_name = f'{INPUT_DIR.parent.name}_{INPUT_DIR.name}_{FPS}fps.mp4'
output_video_path = OUTPUT_DIR / output_video_name
convert_images_to_mp4(OUTPUT_DIR, output_video_path, fps=FPS)