In [1]:
from ultralytics import FastSAM
from src.core.utils import cv2_loadvideo
from pathlib import Path
from src.core.utils import cv2_video_resolution, cv2_video_fps, cv2_video_frame_count
from src.logic.glasses.gaze import parse_gazedata_file, get_gaze_points, match_frames_to_gaze
import cv2
from ultralytics.engine.results import Results
from typing import List
import torch
import numpy as np
import time

In [2]:
# Checkpoint is a serialized engine file; the loader log reports it is used for TensorRT inference.
MODEL_PATH = '../checkpoints/FastSAM-s.engine'
model = FastSAM(MODEL_PATH)


In [3]:
# Recording to process. The video and its gaze export share one file stem, so both
# paths are derived from a single ID and cannot silently drift apart.
RECORDINGS_DIR = Path("../data/recordings")
RECORDING_ID = "af47ccce-c344-49d9-9916-5729e2ddc021"
VIDEO_PATH = RECORDINGS_DIR / f"{RECORDING_ID}.mp4"
GAZE_DATA_PATH = RECORDINGS_DIR / f"{RECORDING_ID}.tsv"

# Fail fast with a clear message instead of relying on how the downstream
# helpers react to a missing file.
for required_path in (VIDEO_PATH, GAZE_DATA_PATH):
    if not required_path.is_file():
        raise FileNotFoundError(f"Recording file not found: {required_path}")

# Video metadata. The processing cell indexes `resolution` as (height, width).
resolution = cv2_video_resolution(VIDEO_PATH)
frame_count = cv2_video_frame_count(VIDEO_PATH)
fps = cv2_video_fps(VIDEO_PATH)

# Gaze samples are parsed, converted using the video resolution
# (presumably into pixel coordinates -- confirm in get_gaze_points),
# and grouped so that `frame_gaze_mapping[frame_idx]` yields the gaze points of a frame.
gaze_data = parse_gazedata_file(GAZE_DATA_PATH)
gaze_points = get_gaze_points(gaze_data, resolution)
frame_gaze_mapping = match_frames_to_gaze(frame_count, gaze_points, fps)

In [4]:
def filter_masks(
    masks: List[torch.Tensor],
    gaze_depth: float,
    max_area_pct: float = 0.1,
    max_components: int = 2,
) -> List[torch.Tensor]:
    """
    Filter masks based on relative mask area and number of connected components
    Args:
        masks: List of binary masks with shape (H, W)
        gaze_depth: Depth of the gaze point in millimeters. Currently not used by
            the filter (kept for interface compatibility), so a missing depth
            (e.g. None) no longer makes the call fail.
        max_area_pct: A mask is kept only if it covers strictly less than this
            fraction of the image (default 0.1, i.e. 10%)
        max_components: A mask is kept only if it has at most this many
            foreground connected components (default 2)
    Returns:
        filtered_masks: List of the masks that pass both checks, in input order
    """
    filtered_masks = []
    for mask in masks:
        total_area = mask.shape[0] * mask.shape[1]
        if total_area == 0:
            # Degenerate mask without pixels: nothing to keep
            continue

        # Fraction of the image covered by the mask
        mask_area_pct = float(torch.sum(mask)) / total_area
        if mask_area_pct >= max_area_pct:
            # Cheap check first: skips the device-to-host copy below for large masks
            continue

        # Label 0 is the background, so the foreground count is num_labels - 1
        num_labels, _ = cv2.connectedComponents(mask.cpu().numpy().astype(np.uint8))
        if num_labels - 1 <= max_components:
            filtered_masks.append(mask)

    return filtered_masks


def overlay_gaze_points(
    frame,
    gaze_points: List[tuple[int, int]],
    radius: int = 15,
    color: tuple[int, int, int] = (255, 0, 0),
    thickness: int = 2,
):
    """
    Overlay gaze points on the original frame
    Args:
        frame: Original frame (ndarray). It is drawn on in place.
        gaze_points: List of gaze points (x, y) in pixel coordinates
        radius: Radius of the circle drawn for each gaze point, in pixels
        color: Circle color, passed unchanged to cv2.circle
        thickness: Circle outline thickness, in pixels
    Returns:
        frame_with_gazepoints: The same frame object, with the gaze points drawn on it
    """
    for gaze_point in gaze_points:
        cv2.circle(frame, gaze_point, radius, color, thickness)

    # Returned for convenience and to match the documented contract;
    # callers that ignore the return value still see the in-place drawing.
    return frame


def overlay(
    image: np.ndarray, 
    mask: np.ndarray, 
    color: tuple[int, int, int] = (255, 0, 0), 
    alpha: float = 0.5, 
    resize=None
):
    """Combines image and its segmentation mask into a single image.
    https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay

    Params:
        image: Image to draw on, channels-last with shape (H, W, 3). np.ndarray.
            It is not modified.
        mask: Segmentation mask with shape (H, W); non-zero entries mark the
            pixels to tint. np.ndarray
        color: Color for segmentation mask rendering. The channel order is
            reversed before painting (presumably RGB in, BGR image -- confirm
            against the caller).  tuple[int, int, int] = (255, 0, 0)
        alpha: Weight of the tinted image in the blend; the original image is
            weighted with 1 - alpha. float = 0.5,
        resize: If provided, both image and its tinted copy are resized to this
            size (passed to cv2.resize as dsize) before blending them together.
            tuple[int, int], e.g. (1024, 1024)

    Returns:
        image_combined: The combined image. np.ndarray

    """
    fill_color = color[::-1]

    # Paint the masked pixels on a copy so the caller's image stays untouched.
    mask_bool = np.asarray(mask).astype(bool)
    image_overlay = image.copy()
    image_overlay[mask_bool] = fill_color

    if resize is not None:
        # Both arrays are already channels-last, which is what cv2.resize
        # expects, so no axis transposition is applied here.
        image = cv2.resize(image, resize)
        image_overlay = cv2.resize(image_overlay, resize)

    image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)

    return image_combined

In [5]:
# Segment a square crop around the first gaze point of every frame with FastSAM,
# paste the plotted result back into the frame, mark all gaze points and write
# the annotated frames to a video file.
CROP_SIZE = 512
HALF_CROP = CROP_SIZE // 2

# `resolution` is indexed as (height, width) throughout this cell.
frame_height, frame_width = resolution[0], resolution[1]

video_result = cv2.VideoWriter(
    'filename.mp4',  
    cv2.VideoWriter_fourcc(*'mp4v'), 
    fps, 
    (frame_width, frame_height)
) 

# try/finally makes sure the writer is released (and the file finalized) even
# when the loop is interrupted or a frame raises.
try:
    for frame_idx, frame in cv2_loadvideo(VIDEO_PATH):
        original_frame = frame.copy()
        frame_gaze_points = frame_gaze_mapping[frame_idx]

        # No gaze sample for this frame: pass it through unchanged.
        if len(frame_gaze_points) == 0:
            video_result.write(original_frame)
            continue

        gaze_positions = [gp.position for gp in frame_gaze_points]

        # Only the first gaze point of the frame is used as segmentation prompt.
        gaze_point = frame_gaze_points[0]
        cx, cy = gaze_point.position

        # A gaze point outside the frame would give an empty crop (or a negative
        # slice index), so such frames are written without segmentation.
        if not (0 <= cx < frame_width and 0 <= cy < frame_height):
            overlay_gaze_points(original_frame, gaze_positions)
            video_result.write(original_frame)
            continue

        # Compute bounding-box edges with clamping
        left   = max(cx - HALF_CROP, 0)
        right  = min(cx + HALF_CROP, frame_width)
        top    = max(cy - HALF_CROP, 0)
        bottom = min(cy + HALF_CROP, frame_height)

        # Gaze position inside the crop, capped at HALF_CROP - 1
        crop_x = min(cx - left, HALF_CROP - 1)
        crop_y = min(cy - top, HALF_CROP - 1)

        # The frame stays in OpenCV's BGR order: Ultralytics expects BGR for numpy
        # input, and the plotted result is pasted back into the BGR output frame.
        cropped_frame = frame[top:bottom, left:right]
        crop_height, crop_width = cropped_frame.shape[:2]

        # Pad the bottom right corner to make the frame square
        padded_frame = cv2.copyMakeBorder(
            cropped_frame,
            0, CROP_SIZE - crop_height,
            0, CROP_SIZE - crop_width,
            cv2.BORDER_CONSTANT,
            value=(0, 0, 0)
        )

        result: Results = model(
            source=padded_frame,
            points=[(crop_x, crop_y)],
            labels=[1],
            device='cuda',
            verbose=False,
            imgsz=CROP_SIZE,
            conf=0.5, 
            iou=0.9
        )[0]

        if result.masks is not None:
            # TODO: mask filtering (filter_masks) is not wired into the rendering
            # yet; every predicted mask is plotted.
            plotted_result = result.plot(masks=True)
            # Drop the padding again before pasting the plot into the frame.
            original_frame[top:bottom, left:right] = plotted_result[:crop_height, :crop_width]

        # Rendering
        overlay_gaze_points(original_frame, gaze_positions)
        video_result.write(original_frame)
finally:
    video_result.release()

Loading ../checkpoints/FastSAM-s.engine for TensorRT inference...
[01/29/2025-11:25:10] [TRT] [I] Loaded engine size: 27 MiB
[01/29/2025-11:25:10] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +21, now: CPU 0, GPU 45 (MiB)


KeyboardInterrupt: 