In [None]:
from pathlib import Path

import cv2
import numpy as np
from src.api.models.gaze import GazePoint
from src.api.services import gaze_service
from src.logic.models.efficientvit_sam import EfficientVitSAMCheckpoint, EfficientVitSAMModel

In [2]:
# Build the segmentation model from the XL0 checkpoint.
# NOTE(review): the second argument is presumably the directory holding the
# checkpoint files -- confirm against EfficientVitSAMModel's constructor.
model = EfficientVitSAMModel(
    EfficientVitSAMCheckpoint.EFFICIENTVIT_SAM_XL0,
    Path("../checkpoints"),
)

In [None]:
# Recording under analysis. The video and its gaze export share one ID, so the
# ID is spelled once and both paths are derived from it.
RECORDINGS_DIR = Path("../data/recordings")
RECORDING_ID = "af47ccce-c344-49d9-9916-5729e2ddc021"
VIDEO_PATH = RECORDINGS_DIR / f"{RECORDING_ID}.mp4"
GAZE_DATA_PATH = RECORDINGS_DIR / f"{RECORDING_ID}.tsv"

# Fail fast with a clear message instead of letting a downstream helper
# stumble over a missing input.
for required_path in (VIDEO_PATH, GAZE_DATA_PATH):
    if not required_path.is_file():
        raise FileNotFoundError(f"Recording file not found: {required_path}")

# NOTE(review): cv2_video_resolution, cv2_video_frame_count and cv2_video_fps
# are not imported in any visible cell; on a fresh kernel this cell raises
# NameError unless they are imported in the imports cell -- confirm their
# module and add the import there.
resolution = cv2_video_resolution(VIDEO_PATH)
frame_count = cv2_video_frame_count(VIDEO_PATH)
fps = cv2_video_fps(VIDEO_PATH)

# Parse the gaze export, convert it to gaze points using the video resolution,
# then group the points per video frame.
gaze_data = gaze_service.parse_gazedata_file(GAZE_DATA_PATH)
gaze_points = gaze_service.get_gaze_points(gaze_data, resolution)
frame_gaze_mapping = gaze_service.match_frames_to_gaze(num_frames=frame_count, gaze_points=gaze_points, fps=fps)

In [4]:
def overlay_masks(frame, masks, alpha=0.5, colors=None):
    """
    Overlay binary masks on a copy of the frame.

    Args:
        frame: Original frame (ndarray, expected shape (H, W, 3)). The caller's
            array is not modified.
        masks: Iterable of binary masks, each of shape (H, W), e.g. an array of
            shape (N, H, W). Non-boolean masks are converted to boolean first
            (any non-zero value counts as "inside the mask").
        alpha: Weight of the colored overlay added on top of the frame (0-1).
        colors: Optional sequence of 3-channel colors, one per mask, given in
            the frame's channel order (BGR for frames read with cv2). Colors
            are reused cyclically when there are more masks than colors.
            Defaults to three primary colors.
    Returns:
        frame: uint8 copy of the frame with the masks overlaid
    Raises:
        ValueError: If ``colors`` is given but empty.
    """
    if colors is None:
        colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
    elif len(colors) == 0:
        raise ValueError("colors must contain at least one color")

    for i, mask in enumerate(masks):
        # A 0/1 integer mask would be interpreted by NumPy as *integer* (row)
        # indexing and paint whole rows; force boolean indexing instead.
        mask = np.asarray(mask, dtype=bool)

        # Solid-color image that is non-zero only inside the mask
        colored_mask = np.zeros_like(frame)
        colored_mask[mask] = colors[i % len(colors)]

        # Saturating add of the weighted color layer; pixels outside the mask
        # receive +0 and therefore stay unchanged.
        frame = cv2.addWeighted(frame, 1, colored_mask, alpha, 0)

    return frame.astype(np.uint8)


def overlay_gaze_points(frame, gaze_points: list[GazePoint], radius=15, color=(255, 0, 0), thickness=2):
    """
    Draw a circle on the frame for every gaze point.

    The frame is drawn on in place and also returned, so the function works
    both as a statement and in an expression.

    Args:
        frame: Frame to draw on (ndarray); modified in place
        gaze_points: Gaze points whose ``position`` is an (x, y) pair in pixel
            coordinates
        radius: Circle radius in pixels
        color: Circle color in the frame's channel order (BGR for cv2 frames)
        thickness: Circle outline thickness in pixels
    Returns:
        frame: The same frame object, with the gaze points drawn on it
    """
    for gaze_point in gaze_points:
        # cv2.circle only accepts integer pixel coordinates
        center = tuple(int(round(coord)) for coord in gaze_point.position)
        cv2.circle(frame, center, radius, color, thickness)

    return frame

In [5]:
# Directory that receives one annotated PNG per frame.
OUTPUT_PATH = Path("./output")
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Remove files left over from a previous run. Sub-directories are skipped:
# Path.unlink() cannot remove a directory and would abort the cell.
for stale_file in OUTPUT_PATH.iterdir():
    if stale_file.is_file():
        stale_file.unlink()

# Per frame, take the position / depth of the first gaze point matched to that
# frame; frames without any gaze point get an empty list as placeholder.
# (The comprehension variable is deliberately not called `gaze_points`, which
# is the notebook-level list of all gaze points.)
frame_sample_points = [
    frame_points[0].position if len(frame_points) > 0 else [] for frame_points in frame_gaze_mapping
]
frame_gaze_depths = [
    frame_points[0].depth if len(frame_points) > 0 else [] for frame_points in frame_gaze_mapping
]

In [6]:
# Number of frames handed to the model per predict_batch call.
BATCH_SIZE = 50
batch_start = 0  # index (within the video) of the first frame in the current batch
batch_end = 0  # exclusive end index of the most recently processed batch

# Pass a plain string, consistent with the cv2.imwrite call below.
cap = cv2.VideoCapture(str(VIDEO_PATH))
try:
    if not cap.isOpened():
        raise ValueError(f"Video file not found: {VIDEO_PATH}")

    if len(frame_sample_points) != frame_count:
        raise ValueError(
            f"Number of sample point batches ({len(frame_sample_points)}) does not match the number of frames ({frame_count})"
        )

    frames = []
    while cap.isOpened():
        ret, frame = cap.read()

        if ret:
            frames.append(frame)

        # Flush when the batch is full or the video ended -- but never with an
        # empty batch, which happens when the last read fails right after a
        # full batch was flushed (frame total is a multiple of BATCH_SIZE).
        if frames and (len(frames) == BATCH_SIZE or not ret):
            # Derive the slice end from the frames actually read so the
            # per-frame inputs always line up with `frames`.
            batch_end = batch_start + len(frames)
            if batch_end > len(frame_sample_points):
                raise ValueError(
                    f"Video yielded more frames ({batch_end}) than there are sample points ({len(frame_sample_points)})"
                )

            frame_masks = model.predict_batch(
                frames,
                frame_sample_points[batch_start:batch_end],
                frame_gaze_depths[batch_start:batch_end],
                resolution=resolution,
            )

            for offset, batch_frame in enumerate(frames):
                frame_index = batch_start + offset
                annotated = overlay_masks(batch_frame, frame_masks[offset])
                overlay_gaze_points(annotated, frame_gaze_mapping[frame_index])

                output_file = OUTPUT_PATH / f"{frame_index:04d}.png"
                # cv2.imwrite reports failure through its return value only
                if not cv2.imwrite(str(output_file), annotated):
                    raise OSError(f"Could not write frame to {output_file}")

            # Drop the processed frames before reading the next batch
            frames = []
            batch_start = batch_end

        if not ret:
            # End of video
            break
finally:
    # Release the capture even when inference or writing fails
    cap.release()

IndexError: index 100 is out of bounds for dimension 0 with size 50