In [1]:
from pathlib import Path

import torch
from src.core.utils import cv2_video_fps, cv2_video_frame_count, cv2_video_resolution
from src.logic.glasses.gaze import get_gaze_points, match_frames_to_gaze, parse_gazedata_file
from src.logic.models.efficientvit_sam import EfficientVitSAMCheckpoint, EfficientVitSAMModel

In [2]:
model = EfficientVitSAMModel(EfficientVitSAMCheckpoint.EFFICIENTVIT_SAM_XL0, Path("../checkpoints"))

In [3]:
VIDEO_PATH = Path("../data/recordings/af47ccce-c344-49d9-9916-5729e2ddc021.mp4")
GAZE_DATA_PATH = Path("../data/recordings/af47ccce-c344-49d9-9916-5729e2ddc021.tsv")

resolution = cv2_video_resolution(VIDEO_PATH)
frame_count = cv2_video_frame_count(VIDEO_PATH)
fps = cv2_video_fps(VIDEO_PATH)
gaze_data = parse_gazedata_file(GAZE_DATA_PATH)
gaze_points = get_gaze_points(gaze_data, resolution)
frame_gaze_mapping = match_frames_to_gaze(num_frames=frame_count, gaze_points=gaze_points, fps=fps)

In [4]:
def get_circle_grid(p: tuple[int, int], r: int, stride: int = 5):
    """
    Returns a grid of points that are within the pixel_error_radius of the gaze_point.
    Ignores points that are outside of the resolution of the video.

    Args:
        p: Point to get the grid around.
        r: The radius around the point to get the grid.
        stride: The stride to use when getting the grid.

    Returns:
        A list of points that are within the pixel_error_radius of the gaze_point.
    """
    grid = []
    for x in range(-r, r + 1, stride):
        Y = int((r**2 - x**2) ** 0.5)  # bound for y given x
        for y in range(-Y, Y + 1):
            grid.append((p[0] + x, p[1] + y))

    return grid

In [5]:
FOV_X = 95
EYETRACKING_ACCURACY = 0.6  # in degrees
pixel_error_radius = int(EYETRACKING_ACCURACY * resolution[0] / FOV_X)

frame_sample_points = []
for gaze_points in frame_gaze_mapping:
    if len(gaze_points) == 0:
        frame_sample_points.append(torch.tensor([], dtype=torch.float32))
    else:
        sample_grid = get_circle_grid(gaze_points[0], pixel_error_radius, stride=7)
        sample_grid = [
            point for point in sample_grid if 0 <= point[0] < resolution[1] and 0 <= point[1] < resolution[1]
        ]
        print(len(sample_grid))
        frame_sample_points.append(torch.tensor(sample_grid, dtype=torch.float32))

12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
1

In [6]:
model.predict(VIDEO_PATH, frame_sample_points)

  preprocessed_frame, local_point_centers = self.preprocess_frame(frame, resolution, torch.tensor(point_prompts, dtype=torch.float32))


Total inference time: 14.97066330909729


In [9]:
5 * 3.73

18.65