In [None]:
import json
import shutil
import traceback
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as F
from sqlalchemy.orm import Session
from src.api.controllers.gaze_segmentation import (
    get_gaze_points,
    mask_was_viewed,
    match_frames_to_gaze,
    parse_gazedata_file,
)
from src.config import CHECKPOINTS_PATH, GAZE_FOVEA_FOV, TOBII_FOV_X
from src.db import engine
from src.db.models import Recording
from src.utils import cv2_video_fps, cv2_video_frame_count, cv2_video_resolution
from torchvision.ops import masks_to_boxes
from torchvision.transforms import InterpolationMode
from tqdm import tqdm
from ultralytics import FastSAM

from controlled_experiment.settings import FULLY_LABELED_RECORDINGS, GAZE_SEGMENTATION_RESULTS_PATH

In [2]:
# Read the experiment metadata; only the file read itself needs the open handle.
with open("experiment_metadata.json") as file:
    experiment_metadata = json.load(file)

# Per-trial metadata is keyed by recording UUID.
trial_recordings_metadata = experiment_metadata["trial_recordings_metadata"]
trial_recording_uuids = list(trial_recordings_metadata)
labeling_same_background_uuid = experiment_metadata["labeling_same_background_uuid"]
labeling_diff_background_uuid = experiment_metadata["labeling_diff_background_uuid"]

# Fetch the Recording rows that belong to the trial UUIDs.
with Session(engine) as session:
    trial_recordings = (
        session.query(Recording)
        .filter(Recording.uuid.in_(trial_recording_uuids))
        .all()
    )

# Segmenting and Tracking based on Gaze Data, and grounding based on previously built Vector Index

There are a few considerations that might be interesting in an experimental context:
1. Selection of `k` in top-k results from the database?
2. Segmentation quality (IoU? Confidence?)
3. Adding padding to the bounding boxes?
4. Indexing and search parameters? (Which ones exist?)
5. Merging of same-frame ROIs or not?
6. Importance of metrics (average, min, max, variance, ?)

In [3]:
class GazeSegmentationJob:
    """
    Segment and track objects in a recording and keep only those the wearer looked at.

    For every video frame that has a matched gaze point, the FastSAM tracker's
    detections are filtered down to masks that are not too large and that
    ``mask_was_viewed`` reports as hit by the gaze position. The surviving
    detections are written to ``<results_path>/<frame_idx>.npz``.

    Each ``.npz`` file contains per-detection, index-aligned entries ``boxes``,
    ``rois``, ``masks``, ``object_ids`` and ``confidences``, plus the per-frame
    entries ``frame_idx`` and ``gaze_position``.
    """

    # A mask whose area is at least this fraction of the frame area is rejected.
    MAX_MASK_AREA_RATIO = 0.1

    def __init__(
        self,
        video_path: Path,
        gaze_data_path: Path,
        results_path: Path,
        fovea_fov: float = GAZE_FOVEA_FOV,
        fov_x: float = TOBII_FOV_X,
        checkpoint_path: str | Path = "checkpoints/FastSAM-x.pt",
        output_video_path: Path | None = None,
    ):
        """
        Prepare the job: reset the results directory, load the model, read the
        video properties and map frames to gaze points.

        Args:
            video_path: Path of the recording's video file.
            gaze_data_path: Path of the gaze data file for the same recording.
            results_path: Directory for the per-frame ``.npz`` files. Any
                existing content is removed.
            fovea_fov: Foveal field of view, in the same unit as ``fov_x``.
            fov_x: Horizontal field of view of the camera.
            checkpoint_path: FastSAM checkpoint to load.
            output_video_path: If given, a video writer is opened at this path.
                No frames are written to it by this class; it is released at
                the end of ``run``.
        """
        self.video_path = video_path
        self.gaze_data_path = gaze_data_path
        self.fovea_fov = fovea_fov
        self.fov_x = fov_x

        # Start from an empty results directory.
        self.results_path = results_path
        if self.results_path.exists():
            shutil.rmtree(self.results_path, ignore_errors=True)
        self.results_path.mkdir(parents=True, exist_ok=True)

        # Load the FastSAM model.
        self.model = FastSAM(checkpoint_path)

        # Video properties. `resolution` is indexed as (height, width).
        self.resolution = cv2_video_resolution(self.video_path)
        self.aspect_ratio = self.resolution[1] / self.resolution[0]  # W / H
        self.fps = cv2_video_fps(self.video_path)
        # Foveal radius in pixels, scaled from the frame width.
        # Not used elsewhere in this class.
        self.viewed_radius = int((self.fovea_fov / self.fov_x) * self.resolution[1])
        self.frame_count = cv2_video_frame_count(self.video_path)

        # Optional output video; cv2 expects the frame size as (width, height).
        if output_video_path is not None:
            self.video_result = cv2.VideoWriter(
                str(output_video_path),
                cv2.VideoWriter_fourcc(*"mp4v"),
                self.fps,
                (self.resolution[1], self.resolution[0]),
            )
        else:
            self.video_result = None

        # Parse gaze data.
        self.gaze_data = parse_gazedata_file(self.gaze_data_path)
        self.gaze_points = get_gaze_points(self.gaze_data, self.resolution)

        # Map frame indexes to gaze points.
        self.frame_gaze_mapping = match_frames_to_gaze(
            self.frame_count, self.gaze_points, self.fps
        )

    def get_gaze_position(self, frame_idx: int) -> tuple[int, int] | None:
        """
        Get the gaze position for a frame index.

        Returns:
            The position of the first gaze point matched to the frame, or
            ``None`` when no gaze point was matched to it.
        """
        gaze_points = self.frame_gaze_mapping[frame_idx]
        if len(gaze_points) == 0:
            return None
        return gaze_points[0].position

    def mask_too_large(self, mask: torch.Tensor) -> bool:
        """
        Check whether a mask covers too much of the frame to be a useful object.

        Args:
            mask: A tensor containing a single mask of shape (H, W).

        Returns:
            bool: True if the mask's area is at least ``MAX_MASK_AREA_RATIO``
            (10%) of the frame area, False otherwise.
        """
        height, width = mask.shape
        max_mask_area = self.MAX_MASK_AREA_RATIO * (height * width)
        # Convert the 0-dim comparison tensor into a plain Python bool.
        return bool(mask.sum() >= max_mask_area)

    @staticmethod
    def _as_object_array(items: list) -> np.ndarray:
        """Pack differently shaped arrays into a 1-D object array, one element per item."""
        packed = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            packed[i] = item
        return packed

    @staticmethod
    def _save_frame_results(npz_path: Path, **arrays) -> None:
        """Write one frame's arrays to disk, reporting failures instead of dropping them."""
        try:
            np.savez_compressed(npz_path, **arrays)
        except Exception as e:
            # This runs on a worker thread whose future is never inspected, so an
            # exception that is not reported here would disappear silently.
            print(f"Error saving results to {npz_path}: {e}")
            traceback.print_exc()

    def _process_frame(self, frame_idx: int, results, executor: ThreadPoolExecutor) -> None:
        """Filter one frame's detections by gaze and queue the survivors for saving."""
        gaze_position = self.get_gaze_position(frame_idx)
        if gaze_position is None:
            return

        boxes = []
        rois = []
        masks = []
        object_ids = []
        confidences = []
        for result in results:
            # No track id is available for this detection, so it cannot be
            # attributed to an object; skip it instead of failing the frame.
            track_ids = result.boxes.id
            if track_ids is None:
                continue

            # Bring the mask back to the video resolution.
            mask = F.resize(
                result.masks[0].data,
                self.resolution,
                interpolation=InterpolationMode.NEAREST,
            ).squeeze()

            if self.mask_too_large(mask) or not mask_was_viewed(mask, gaze_position):
                continue

            box = masks_to_boxes(mask.unsqueeze(0)).int().cpu().numpy()[0]
            x1, y1, x2, y2 = box
            # masks_to_boxes returns the inclusive min/max pixel coordinates, so
            # the slice needs +1 to keep the last row/column (and to avoid an
            # empty ROI for one-pixel-thin masks).
            roi = result.orig_img[y1 : y2 + 1, x1 : x2 + 1, :]
            roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)

            # Append to all lists together so the saved arrays stay index-aligned.
            boxes.append(box)
            masks.append(mask.cpu().numpy().astype(np.uint8))
            rois.append(roi)
            object_ids.append(int(track_ids[0]))
            confidences.append(float(result.boxes[0].conf))

        if not boxes:
            return

        # Offload saving with thread pool (asynchronously)
        executor.submit(
            self._save_frame_results,
            self.results_path / f"{frame_idx}.npz",
            boxes=boxes,
            rois=self._as_object_array(rois),
            masks=self._as_object_array(masks),
            object_ids=object_ids,
            frame_idx=frame_idx,
            gaze_position=gaze_position,
            confidences=confidences,
        )

    def run(self):
        """
        Track the whole video and save the viewed detections of each frame.

        A failure while processing a single frame is printed with its traceback
        and does not stop the job. The output video writer, if one was opened,
        is released when the run ends.
        """
        try:
            with ThreadPoolExecutor() as executor:
                frame_results = self.model.track(
                    source=str(self.video_path), imgsz=640, stream=True, verbose=False
                )
                for frame_idx, results in enumerate(frame_results):
                    try:
                        self._process_frame(frame_idx, results, executor)
                    except Exception as e:
                        print(f"Error processing frame {frame_idx}: {e}")
                        traceback.print_exc()
        finally:
            # Finalize the output file instead of leaking the writer handle.
            if self.video_result is not None:
                self.video_result.release()

In [None]:
# Wipe any previous run's output (a missing directory is ignored) and recreate the root.
shutil.rmtree(GAZE_SEGMENTATION_RESULTS_PATH, ignore_errors=True)
GAZE_SEGMENTATION_RESULTS_PATH.mkdir(parents=True, exist_ok=True)


def process_recording(recording: Recording):
    """
    Run gaze segmentation for one recording.

    The results are written to a directory named after the recording's UUID
    under ``GAZE_SEGMENTATION_RESULTS_PATH``; that directory is emptied first.
    """
    source_video = Path(recording.video_path)
    source_gaze = Path(recording.gaze_data_path)
    output_dir = GAZE_SEGMENTATION_RESULTS_PATH / recording.uuid

    # Make sure the per-recording directory exists and is empty.
    if output_dir.exists():
        shutil.rmtree(output_dir, ignore_errors=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    GazeSegmentationJob(
        video_path=source_video,
        gaze_data_path=source_gaze,
        results_path=output_dir,
        fovea_fov=GAZE_FOVEA_FOV,
        fov_x=TOBII_FOV_X,
        checkpoint_path=CHECKPOINTS_PATH / "FastSAM-x.pt",
    ).run()


# Only recordings that are fully labeled are segmented; the rest are skipped.
for recording in tqdm(trial_recordings, desc="Processing recordings"):
    if recording.uuid not in FULLY_LABELED_RECORDINGS:
        continue
    process_recording(recording)

Processing recordings: 100%|██████████| 14/14 [08:05<00:00, 34.70s/it]
