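# Inference

Load a trained segmentation model from DagsHub and run it on every frame of one driver-dataset recording. Each frame's predicted mask is saved as a red overlay in a timestamped `outputs/` folder. The notebook then measures the model's inference FPS and renders the overlays into an MP4.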

In [1]:
# IPython autoreload: re-import edited project modules (e.g. model.inference) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import time
from pathlib import Path

import numpy as np
from PIL import Image
from sam2util import convert_jpg_to_mp4
from tqdm import tqdm

from model.inference import (
    filter_small_segments,
    init_dagshub,
    load_image,
    load_model_from_dagshub,
    measure_model_fps,
    predict,
)
from model.plot import plot_single_prediction


In [None]:
# Connect to the 'driver-state' DagsHub repo; presumably required before load_model_from_dagshub below — TODO confirm.
init_dagshub(repo_name='driver-state')

In [None]:
# Model to run. MODEL_NAME and VERSION are also embedded in the output directory name.
MODEL_NAME = 'pytorch-2024-09-15-215831-unet-efficientnet-b0'
VERSION = 1  # version passed to load_model_from_dagshub
model = load_model_from_dagshub(MODEL_NAME, VERSION)

In [None]:
# Run the model on every frame in DATASET_PATH and save a red, semi-transparent
# overlay of the predicted mask for each frame into a fresh, timestamped OUTPUT_DIR.
INPUT_SIZE = 256            # passed to predict() as input_size
OUTPUT_SIZE = (1024, 1024)  # passed to predict() as output_size; square, so axis order does not matter here
MASK_THRESHOLD = 0.5        # mask values above this count as foreground
MIN_SEGMENT_AREA = (OUTPUT_SIZE[0] // 4) ** 2  # connected regions smaller than this (pixels) are dropped
OVERLAY_ALPHA = 102         # overlay opacity on a 0-255 scale (~40%)
FRAME_STEP = 30             # output index = int(frame stem) // FRAME_STEP; assumes numeric stems spaced 30 apart — TODO confirm
SHOW_PREDICTIONS = False    # set True to plot every raw prediction inline (slow for long recordings)

OUTPUT_DIR = Path('outputs') / f'{datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S")}-{MODEL_NAME}-v{VERSION}'
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

DATASET_PATH = Path('/home/lanter/source/driver-dataset/images/2022_09_14_stribny_enyaq/normal')

# Provenance: record the source dataset once, up front, instead of rewriting the file for every frame.
(OUTPUT_DIR / 'dataset_path.txt').write_text(str(DATASET_PATH))

for image_path in tqdm(sorted(DATASET_PATH.glob('*.jpg'))):
    image = load_image(image_path)
    mask = predict(model, image, input_size=INPUT_SIZE, output_size=OUTPUT_SIZE)
    if SHOW_PREDICTIONS:
        plot_single_prediction(image, mask)

    # NOTE(review): assumes predict returns a CPU tensor whose squeezed form is 2-D — TODO confirm.
    probs = mask.squeeze().numpy()
    binary_mask = (probs > MASK_THRESHOLD).astype(np.uint8)
    filtered_binary_mask = filter_small_segments(binary_mask, min_area=MIN_SEGMENT_AREA)

    # RGBA layer: solid red, visible only where the filtered mask is set.
    height, width = probs.shape
    mask_rgba = np.zeros((height, width, 4), dtype=np.uint8)
    mask_rgba[..., 0] = 255
    mask_rgba[..., 3] = filtered_binary_mask * OVERLAY_ALPHA
    mask_img = Image.fromarray(mask_rgba, "RGBA")

    # PIL sizes are (width, height) whereas numpy shapes are (height, width).
    image_rgba = image.resize((width, height)).convert("RGBA")
    overlay_img = Image.alpha_composite(image_rgba, mask_img)

    # Output names must be unique per frame; otherwise later frames silently overwrite earlier ones.
    overlay_img.convert("RGB").save(OUTPUT_DIR / f'{int(image_path.stem) // FRAME_STEP:05d}.jpg')

In [None]:
# Benchmark the model on the frames in DATASET_PATH; presumably returns/reports frames per second — TODO confirm.
measure_model_fps(model, DATASET_PATH)

In [None]:
# Display the parent and leaf folder names of DATASET_PATH; the next cell uses them to name the output video.
DATASET_PATH.parent.name, DATASET_PATH.name

In [15]:
# Render the saved overlay frames into an MP4 named "<parent folder>_<dataset folder>_<FPS>fps.mp4".
# convert_jpg_to_mp4 presumably encodes the .jpg files in OUTPUT_DIR in filename order — TODO confirm.
FPS = 15  # frame rate written into the output video
output_video_name = f'{DATASET_PATH.parent.name}_{DATASET_PATH.name}_{FPS}fps.mp4'
output_video_path = OUTPUT_DIR / output_video_name
convert_jpg_to_mp4(OUTPUT_DIR, output_video_path, fps=FPS)