In [20]:
import os

os.environ["CHECKPOINTS_PATH"] = "../checkpoints"

import json
import shutil
import traceback
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from sqlalchemy.orm import Session
from src.api.db import engine
from src.api.models.db import Recording
from src.api.services import gaze_service
from src.config import CHECKPOINTS_PATH, GAZE_FOV, TOBII_FOV_X, TOBII_GLASSES_FPS
from src.utils import cv2_video_fps, cv2_video_frame_count, cv2_video_resolution
from torchvision.ops import masks_to_boxes
from torchvision.transforms import InterpolationMode
from tqdm import tqdm
from ultralytics import FastSAM

from experiment.settings import (
    FULLY_LABELED_RECORDINGS,
    GAZE_SEGMENTATION_RESULTS_PATH,
)
import tempfile
from src.utils import extract_frames_to_dir

In [21]:
# Load every fully labeled recording while the DB session is open.
with Session(engine) as session:
    labeled_query = session.query(Recording).filter(
        Recording.id.in_(FULLY_LABELED_RECORDINGS)
    )
    trial_recordings = labeled_query.all()

# Segmenting and Tracking based on Gaze Data, and grounding based on previously built Vector Index

There are a few considerations that might be worth exploring experimentally:
1. Selection of `k` for the top-k results retrieved from the database.
2. Segmentation quality thresholds (IoU? confidence?).
3. Adding padding to the bounding boxes.
4. Indexing and search parameters (which ones are available?).
5. Whether or not to merge ROIs from the same frame.
6. Which aggregate metrics matter (average, min, max, variance, others?).

In [None]:
class GazeSegmentationJob:
    """
    Segment and track objects in a recording and keep those hit by the wearer's gaze.

    Every extracted frame is run through FastSAM tracking. For frames that have a
    gaze position, each predicted mask is kept when it is not too large (see
    ``mask_too_large``) and ``gaze_service.mask_was_viewed`` accepts it. The kept
    objects of a frame are written to ``<results_path>/<frame_idx>.npz``.
    """

    def __init__(
        self,
        video_path: Path,
        gaze_data_path: Path,
        results_path: Path,
        fovea_fov: float = GAZE_FOV,
        fov_x: float = TOBII_FOV_X,
        checkpoint_path: str = "checkpoints/FastSAM-x.pt",
        frames_path: Path | None = None,
    ):
        """
        Args:
            video_path: Recording video; used for frame extraction and video properties.
            gaze_data_path: Gaze data file, parsed with ``gaze_service.parse_gazedata_file``.
            results_path: Output directory. Any existing contents are deleted.
            fovea_fov: Angular size of the fovea (same unit as ``fov_x``).
            fov_x: Horizontal field of view of the scene camera.
            checkpoint_path: FastSAM checkpoint to load.
            frames_path: Directory of pre-extracted frames named ``<frame_idx>.<ext>``.
                When None, frames are extracted into a new temporary directory.

        Raises:
            FileNotFoundError: If ``frames_path`` is given but does not exist.
        """
        self.video_path = video_path
        self.gaze_data_path = gaze_data_path
        self.fovea_fov = fovea_fov
        self.fov_x = fov_x

        # Start from an empty results directory so files from a previous run
        # cannot mix with this one.
        self.results_path = results_path
        if self.results_path.exists():
            shutil.rmtree(self.results_path, ignore_errors=True)
        self.results_path.mkdir(parents=True, exist_ok=True)

        if frames_path is None:
            # NOTE(review): this temporary directory is never removed by the job;
            # callers extracting many recordings may want to clean it up.
            self.frames_path = Path(tempfile.mkdtemp())
            extract_frames_to_dir(
                video_path=self.video_path, frames_path=self.frames_path
            )
        else:
            self.frames_path = frames_path
            if not self.frames_path.exists():
                raise FileNotFoundError(f"Frames path {self.frames_path} does not exist.")

        self.model = FastSAM(checkpoint_path)

        # Video properties. ``resolution`` is indexed as (H, W): it is passed to
        # ``F.resize`` as the target size in ``_select_viewed_objects``.
        self.resolution = cv2_video_resolution(self.video_path)
        self.aspect_ratio = self.resolution[1] / self.resolution[0]  # W / H
        self.fps = cv2_video_fps(self.video_path)
        # Fovea angle scaled linearly to pixels via the horizontal FOV and frame width.
        # NOTE(review): whether this is a radius or a diameter depends on how
        # GAZE_FOV is defined — confirm.
        self.viewed_radius = int((self.fovea_fov / self.fov_x) * self.resolution[1])
        self.frame_count = cv2_video_frame_count(self.video_path)

        # Parse gaze data and project it into pixel coordinates.
        self.gaze_data = gaze_service.parse_gazedata_file(self.gaze_data_path)
        self.gaze_points = gaze_service.get_gaze_points(self.gaze_data, self.resolution)

        # Map frame indexes to the gaze points that fall within each frame.
        self.frame_gaze_mapping = gaze_service.match_frames_to_gaze(
            self.frame_count, self.gaze_points, self.fps
        )

    def get_gaze_position(self, frame_idx: int) -> tuple[int, int] | None:
        """
        Return the position of the first gaze point mapped to ``frame_idx``.

        Returns:
            The ``position`` of the first matched gaze point, or None when no gaze
            point was matched to this frame.
        """
        gaze_points = self.frame_gaze_mapping[frame_idx]
        if len(gaze_points) == 0:
            return None
        return gaze_points[0].position

    def mask_too_large(self, mask: torch.Tensor) -> bool:
        """
        Check whether a mask covers too much of the frame to be kept.

        Args:
            mask: A single mask of shape (H, W) with values 0/1.

        Returns:
            bool: True if the mask's area is at least 10% of the frame area,
            False otherwise.
        """
        MAX_MASK_AREA = 0.1  # fraction of the frame area
        height, width = mask.shape
        max_mask_area = MAX_MASK_AREA * height * width
        return bool(mask.sum() >= max_mask_area)

    @staticmethod
    def _as_object_array(items: list) -> np.ndarray:
        """Pack arrays into a 1-D object array without stacking them into one ndarray."""
        # np.array(items, dtype=object) would build an N-D array when all items share
        # a shape (e.g. masks), so fill element by element instead.
        packed = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            packed[i] = item
        return packed

    def _select_viewed_objects(self, results, gaze_position) -> dict[str, list]:
        """
        Collect per-object data for tracked masks that pass the size and gaze checks.

        Returns:
            A dict of equally long lists keyed ``boxes``, ``rois``, ``masks``,
            ``object_ids`` and ``confidences``; entry ``k`` of each list describes
            the same object.
        """
        selected = {"boxes": [], "rois": [], "masks": [], "object_ids": [], "confidences": []}

        track_ids = results.boxes.id
        if track_ids is None:
            # The tracker assigned no IDs on this frame, so no object can be saved
            # with an ID (previously this raised a TypeError per frame).
            return selected

        for i in range(len(results.boxes)):
            # NOTE(review): results.masks[i].data is presumably at the model's inference
            # (possibly letterboxed) resolution; stretching it to the video resolution may
            # misalign masks slightly if padding was added — confirm, or consider
            # `retina_masks=True` in `model.track`.
            mask = F.resize(
                results.masks[i].data,
                self.resolution,
                interpolation=InterpolationMode.NEAREST,
            ).squeeze()

            if self.mask_too_large(mask) or not gaze_service.mask_was_viewed(
                mask, gaze_position
            ):
                continue

            box = masks_to_boxes(mask.unsqueeze(0)).int().cpu().numpy()[0]
            x1, y1, x2, y2 = box
            # masks_to_boxes returns inclusive max coordinates, so extend the slice by
            # one to include the last row/column (and avoid empty crops for thin masks).
            roi = results.orig_img[y1 : y2 + 1, x1 : x2 + 1, :]
            roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)

            selected["boxes"].append(box)
            selected["rois"].append(roi)
            selected["masks"].append(mask.cpu().numpy().astype(np.uint8))
            selected["object_ids"].append(int(track_ids[i]))
            # Recorded only for kept objects so it stays aligned with the other lists.
            selected["confidences"].append(float(results.boxes[i].conf))

        return selected

    def run(self):
        """
        Track every frame in numeric order and save the gaze-selected objects.

        ``model.track`` is called on every frame before the gaze check, presumably so
        the tracker (``persist=True``) sees an uninterrupted frame sequence. For each
        frame with a gaze position and at least one selected object,
        ``<results_path>/<frame_idx>.npz`` is written with keys ``boxes``, ``rois``,
        ``masks``, ``object_ids``, ``frame_idx``, ``gaze_position`` and
        ``confidences``. Per-frame processing and saving errors are printed with a
        traceback instead of aborting the whole job.
        """
        frame_paths = sorted(self.frames_path.iterdir(), key=lambda p: int(p.stem))
        pending_saves = []

        with ThreadPoolExecutor() as executor:
            for frame_path in frame_paths:
                frame_idx = int(frame_path.stem)
                results = self.model.track(
                    source=str(frame_path), imgsz=640, verbose=False, persist=True
                )[0]

                try:
                    gaze_position = self.get_gaze_position(frame_idx)
                    if gaze_position is None:
                        continue

                    selected = self._select_viewed_objects(results, gaze_position)
                    if not selected["boxes"]:
                        continue

                    # Offload compression/writing to the thread pool.
                    future = executor.submit(
                        np.savez_compressed,
                        self.results_path / f"{frame_idx}.npz",
                        boxes=selected["boxes"],
                        rois=self._as_object_array(selected["rois"]),
                        masks=self._as_object_array(selected["masks"]),
                        object_ids=selected["object_ids"],
                        frame_idx=frame_idx,
                        gaze_position=gaze_position,
                        confidences=selected["confidences"],
                    )
                    pending_saves.append((frame_idx, future))

                except Exception as e:
                    print(f"Error processing frame {frame_idx}: {e}")
                    traceback.print_exc()

        # The executor has finished all work here; surface any exception raised
        # while writing, which would otherwise be lost inside the futures.
        for frame_idx, future in pending_saves:
            try:
                future.result()
            except Exception as e:
                print(f"Error saving frame {frame_idx}: {e}")
                traceback.print_exc()

In [23]:
# Wipe results from earlier runs; rmtree with ignore_errors is a no-op when the
# directory is missing.
shutil.rmtree(GAZE_SEGMENTATION_RESULTS_PATH, ignore_errors=True)
GAZE_SEGMENTATION_RESULTS_PATH.mkdir(parents=True, exist_ok=True)

# Root directory holding one pre-extracted frame folder per recording id.
FRAMES_PATHS = Path("data/recording_frames")


def process_recording(recording: Recording):
    """
    Run gaze-based segmentation for one recording.

    Frames are read from ``FRAMES_PATHS / <recording id>`` and results are written to
    ``GAZE_SEGMENTATION_RESULTS_PATH / <recording id>``.

    Args:
        recording: Recording row providing ``id``, ``video_path`` and ``gaze_data_path``.

    Raises:
        FileNotFoundError: From ``GazeSegmentationJob`` when the frames directory
            for this recording does not exist.
    """
    # str() so the id can be joined onto paths regardless of its column type.
    recording_id = str(recording.id)

    job = GazeSegmentationJob(
        video_path=Path(recording.video_path),
        gaze_data_path=Path(recording.gaze_data_path),
        # GazeSegmentationJob clears and recreates this directory itself.
        results_path=GAZE_SEGMENTATION_RESULTS_PATH / recording_id,
        fovea_fov=GAZE_FOV,
        fov_x=TOBII_FOV_X,
        checkpoint_path=CHECKPOINTS_PATH / "FastSAM-x.pt",
        frames_path=FRAMES_PATHS / recording_id,
    )
    job.run()


# Run the segmentation job for each fully labeled recording.
for recording in tqdm(trial_recordings, desc="Processing recordings"):
    if recording.id not in FULLY_LABELED_RECORDINGS:
        continue
    process_recording(recording)

Processing recordings: 100%|██████████| 3/3 [05:19<00:00, 106.52s/it]


# Rendering the results

In [18]:
from src.api.utils import image_utils
import matplotlib.pyplot as plt
import tempfile
from src.utils import extract_frames_to_dir

RECORDING_ID = "32f02db7-adc0-4556-a2da-ed2ba60a58c9"
SEG_RESULTS_PATH = GAZE_SEGMENTATION_RESULTS_PATH / RECORDING_ID
VIDEO_PATH = Path("data/recordings") / f"{RECORDING_ID}.mp4"

# Extract the raw video frames into a fresh scratch directory under the system
# temp dir (any previous contents are removed first).
temp_video_frames_path = Path(tempfile.gettempdir()) / f"{RECORDING_ID}"
shutil.rmtree(temp_video_frames_path, ignore_errors=True)
temp_video_frames_path.mkdir(parents=True, exist_ok=True)

extract_frames_to_dir(frames_path=temp_video_frames_path, video_path=VIDEO_PATH)

In [19]:
# Overlay the saved gaze-selected masks and labeled boxes onto the extracted frames,
# overwriting each annotated frame file in place.
frames = list(temp_video_frames_path.iterdir())
seg_results = sorted(SEG_RESULTS_PATH.iterdir(), key=lambda p: int(p.stem))

frame_id_to_path = {int(frame.stem): frame for frame in frames}

for result_path in tqdm(seg_results):
    frame_idx = int(result_path.stem)
    frame_path = frame_id_to_path[frame_idx]

    frame = cv2.imread(str(frame_path))
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # allow_pickle is required because masks/rois were saved as object arrays; only
    # load result files produced by this notebook. The context manager closes the
    # underlying zip file handle, which np.load otherwise keeps open.
    with np.load(result_path, allow_pickle=True) as results_file:
        boxes = results_file["boxes"]
        masks = results_file["masks"]
        object_ids = results_file["object_ids"]

    # Union of all selected masks on this frame.
    combined_mask = np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)
    for mask in masks:
        combined_mask = np.logical_or(combined_mask, mask)

    image_utils.draw_mask(
        img=frame,
        mask=combined_mask,
        box=(0, 0, frame.shape[1], frame.shape[0]),
    )

    for box, object_id in zip(boxes, object_ids):
        image_utils.draw_labeled_box(
            img=frame, box=tuple(box), label=f"ID: {object_id}", color="#FF0000"
        )

    # Save back to the original frame path (OpenCV expects BGR).
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(frame_path), frame)

import subprocess

# Encode the (partly annotated) frames into test.mp4. Arguments are passed as a list
# without a shell: ffmpeg expands the glob itself via `-pattern_type glob`, so no
# shell quoting of the frames path is needed.
# NOTE(review): ffmpeg presumably returns glob matches in lexicographic order; if frame
# files are not zero-padded (e.g. "2.jpg" vs "10.jpg") the video frames will be out of
# order — confirm the naming used by `extract_frames_to_dir`.
ffmpeg_cmd = [
    "ffmpeg",
    "-hwaccel", "cuda",
    "-y",
    "-pattern_type", "glob",
    "-framerate", str(TOBII_GLASSES_FPS),
    "-i", f"{temp_video_frames_path!s}/*.jpg",
    "-c:v", "libx264",
    "-pix_fmt", "yuv420p",
    "test.mp4",
]
# check=True raises CalledProcessError on failure; stderr is captured (not discarded)
# so the exception carries ffmpeg's error output.
subprocess.run(ffmpeg_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)

100%|██████████| 975/975 [00:23<00:00, 41.73it/s]


CompletedProcess(args='ffmpeg -hwaccel cuda -y -pattern_type glob -framerate 24.95 -i "/tmp/32f02db7-adc0-4556-a2da-ed2ba60a58c9/*.jpg" -c:v libx264 -pix_fmt yuv420p "test.mp4"', returncode=0)